"""
Training configuration
"""

from dataclasses import dataclass
from typing import Optional


class ConfigValidationError(ValueError, AssertionError):
    """Raised when a TrainingConfig field holds an invalid value.

    Also derives from ``AssertionError`` because validation used to be done
    with ``assert`` statements; existing ``except AssertionError`` handlers
    therefore keep working.
    """


@dataclass
class TrainingConfig:
    """Configuration for training NOVA models.

    All fields have defaults, so ``TrainingConfig()`` is a valid
    configuration. ``batch_size``, ``learning_rate`` and
    ``gradient_accumulation_steps`` are validated on construction
    (see ``__post_init__``); the other fields are stored as given.
    """

    # Model
    model_name: str = "nova-125m"
    model_config_path: Optional[str] = None

    # Data
    train_data_path: str = "data/train"
    val_data_path: str = "data/val"
    max_seq_length: int = 2048

    # Training hyperparameters
    num_epochs: int = 10
    batch_size: int = 8
    gradient_accumulation_steps: int = 4
    learning_rate: float = 3e-4
    weight_decay: float = 0.1
    max_grad_norm: float = 1.0
    warmup_steps: int = 1000
    lr_scheduler: str = "cosine"  # or "linear", "constant"

    # Optimization
    optimizer: str = "adamw"  # or "lion", "adafactor"
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_epsilon: float = 1e-8

    # Mixed precision and efficiency
    use_amp: bool = True  # Automatic Mixed Precision
    gradient_checkpointing: bool = False
    use_ddp: bool = False  # Distributed Data Parallel

    # Checkpointing
    save_dir: str = "checkpoints"
    save_steps: int = 1000
    save_total_limit: int = 5
    resume_from_checkpoint: Optional[str] = None

    # Evaluation
    eval_steps: int = 500
    eval_strategy: str = "steps"  # or "epoch"
    logging_steps: int = 100

    # Early stopping
    early_stopping: bool = False
    early_stopping_patience: int = 3
    early_stopping_threshold: float = 0.001

    # Reproducibility
    seed: int = 42

    # Device
    device: str = "auto"  # "auto", "cpu", "cuda", "cuda:0", etc.

    # Logging
    log_to_wandb: bool = False
    wandb_project: Optional[str] = None
    wandb_run_name: Optional[str] = None

    def __post_init__(self):
        """Validate configuration.

        Raises:
            ConfigValidationError: If ``batch_size``, ``learning_rate`` or
                ``gradient_accumulation_steps`` is not strictly positive.
                The first offending field (in that order) is reported.
        """
        # Explicit raises instead of ``assert``: assert statements are
        # removed when Python runs with -O, which would silently disable
        # validation.
        for field_name in (
            "batch_size",
            "learning_rate",
            "gradient_accumulation_steps",
        ):
            # ``not (value > 0)`` also rejects NaN, as the assert did.
            if not getattr(self, field_name) > 0:
                raise ConfigValidationError(f"{field_name} must be positive")