""" NOVA Trainer - Training loop with AMP, gradient checkpointing, DDP """ import torch import torch.nn as nn import torch.optim as optim from torch.cuda.amp import autocast, GradScaler from torch.utils.data import DataLoader, DistributedSampler from torch.nn.parallel import DistributedDataParallel as DDP import torch.distributed as dist from pathlib import Path from tqdm import tqdm from typing import Optional, Dict, Any import os import json import time import math from .config import TrainingConfig from nova_core import NovaTransformer, ModelConfig class NovaTrainer: """ Trainer for NOVA models with support for: - Automatic Mixed Precision (AMP) - Gradient checkpointing - Distributed Data Parallel (DDP) - Resume from checkpoint - Early stopping - Cosine learning rate schedule with warmup """ def __init__( self, model: NovaTransformer, train_config: TrainingConfig, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader] = None, ): """ Args: model: NOVA transformer model train_config: Training configuration train_dataloader: Training data loader val_dataloader: Optional validation data loader """ self.config = train_config self.model = model self.train_dataloader = train_dataloader self.val_dataloader = val_dataloader # Setup device self.device = self._setup_device() self.model.to(self.device) # Setup distributed training if needed self.is_ddp = train_config.use_ddp and torch.cuda.device_count() > 1 if self.is_ddp: self.model = DDP(self.model) # Setup optimizer self.optimizer = self._create_optimizer() # Setup learning rate scheduler total_steps = len(train_dataloader) * train_config.num_epochs // train_config.gradient_accumulation_steps self.scheduler = self._create_scheduler(total_steps) # Setup AMP self.use_amp = train_config.use_amp and self.device.type == 'cuda' self.scaler = GradScaler() if self.use_amp else None # Tracking self.global_step = 0 self.current_epoch = 0 self.best_val_loss = float('inf') self.patience_counter = 0 # Create save directory Path(train_config.save_dir).mkdir(parents=True, exist_ok=True) def _setup_device(self) -> torch.device: """Setup training device""" if self.config.device == "auto": if torch.cuda.is_available(): return torch.device("cuda") else: return torch.device("cpu") else: return torch.device(self.config.device) def _create_optimizer(self) -> optim.Optimizer: """Create optimizer""" # Separate parameters with and without weight decay decay_params = [] no_decay_params = [] for name, param in self.model.named_parameters(): if param.requires_grad: # Don't apply weight decay to biases and layer norms if 'bias' in name or 'norm' in name: no_decay_params.append(param) else: decay_params.append(param) param_groups = [ {'params': decay_params, 'weight_decay': self.config.weight_decay}, {'params': no_decay_params, 'weight_decay': 0.0} ] if self.config.optimizer.lower() == "adamw": return optim.AdamW( param_groups, lr=self.config.learning_rate, betas=(self.config.adam_beta1, self.config.adam_beta2), eps=self.config.adam_epsilon ) else: raise ValueError(f"Unknown optimizer: {self.config.optimizer}") def _create_scheduler(self, total_steps: int): """Create learning rate scheduler with warmup""" if self.config.lr_scheduler == "cosine": def lr_lambda(current_step: int): # Warmup if current_step < self.config.warmup_steps: return float(current_step) / float(max(1, self.config.warmup_steps)) # Cosine decay progress = float(current_step - self.config.warmup_steps) / float(max(1, total_steps - self.config.warmup_steps)) return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda) elif self.config.lr_scheduler == "linear": def lr_lambda(current_step: int): if current_step < self.config.warmup_steps: return float(current_step) / float(max(1, self.config.warmup_steps)) return max(0.0, float(total_steps - current_step) / float(max(1, total_steps - self.config.warmup_steps))) return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda) else: # constant return optim.lr_scheduler.LambdaLR(self.optimizer, lambda _: 1.0) def train(self): """Main training loop""" print(f"Starting training on {self.device}") print(f" Num epochs: {self.config.num_epochs}") print(f" Batch size: {self.config.batch_size}") print(f" Gradient accumulation steps: {self.config.gradient_accumulation_steps}") print(f" Learning rate: {self.config.learning_rate}") print(f" Mixed precision: {self.use_amp}") for epoch in range(self.current_epoch, self.config.num_epochs): self.current_epoch = epoch print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}") # Training train_loss = self.train_epoch() print(f" Train loss: {train_loss:.4f}") # Validation if self.val_dataloader is not None: val_loss = self.evaluate() print(f" Val loss: {val_loss:.4f}") # Early stopping check if self.config.early_stopping: if val_loss < self.best_val_loss - self.config.early_stopping_threshold: self.best_val_loss = val_loss self.patience_counter = 0 self.save_checkpoint(is_best=True) else: self.patience_counter += 1 if self.patience_counter >= self.config.early_stopping_patience: print(f"Early stopping triggered after {epoch + 1} epochs") break print("\nTraining complete!") def train_epoch(self) -> float: """Train for one epoch""" self.model.train() total_loss = 0.0 num_batches = 0 progress_bar = tqdm(self.train_dataloader, desc="Training") for batch_idx, batch in enumerate(progress_bar): loss = self.train_step(batch) total_loss += loss num_batches += 1 progress_bar.set_postfix({"loss": f"{loss:.4f}", "lr": f"{self.scheduler.get_last_lr()[0]:.2e}"}) return total_loss / num_batches def train_step(self, batch: Dict[str, torch.Tensor]) -> float: """Single training step""" input_ids = batch['input_ids'].to(self.device) labels = batch.get('labels', input_ids).to(self.device) # Forward pass with AMP with autocast(enabled=self.use_amp): outputs = self.model(input_ids=input_ids) logits = outputs['logits'] # Calculate loss (next token prediction) shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = nn.functional.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100 ) # Scale loss for gradient accumulation loss = loss / self.config.gradient_accumulation_steps # Backward pass with gradient scaling if self.use_amp: self.scaler.scale(loss).backward() else: loss.backward() # Update weights every N accumulation steps if (self.global_step + 1) % self.config.gradient_accumulation_steps == 0: # Gradient clipping if self.use_amp: self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.config.max_grad_norm ) # Optimizer step if self.use_amp: self.scaler.step(self.optimizer) self.scaler.update() else: self.optimizer.step() self.scheduler.step() self.optimizer.zero_grad() self.global_step += 1 # Checkpointing if self.global_step % self.config.save_steps == 0: self.save_checkpoint() return loss.item() * self.config.gradient_accumulation_steps @torch.no_grad() def evaluate(self) -> float: """Evaluate on validation set""" self.model.eval() total_loss = 0.0 num_batches = 0 for batch in tqdm(self.val_dataloader, desc="Evaluating"): input_ids = batch['input_ids'].to(self.device) labels = batch.get('labels', input_ids).to(self.device) with autocast(enabled=self.use_amp): outputs = self.model(input_ids=input_ids) logits = outputs['logits'] shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = nn.functional.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100 ) total_loss += loss.item() num_batches += 1 return total_loss / num_batches def save_checkpoint(self, is_best: bool = False): """Save model checkpoint""" model_to_save = self.model.module if self.is_ddp else self.model checkpoint = { 'model_state_dict': model_to_save.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': self.scheduler.state_dict(), 'global_step': self.global_step, 'epoch': self.current_epoch, 'config': self.config.__dict__, } if self.use_amp: checkpoint['scaler_state_dict'] = self.scaler.state_dict() # Save regular checkpoint checkpoint_path = Path(self.config.save_dir) / f"checkpoint-{self.global_step}.pt" torch.save(checkpoint, checkpoint_path) print(f" Checkpoint saved: {checkpoint_path}") # Save best model if is_best: best_path = Path(self.config.save_dir) / "best_model.pt" torch.save(checkpoint, best_path) print(f" Best model saved: {best_path}") def load_checkpoint(self, checkpoint_path: str): """Load from checkpoint""" checkpoint = torch.load(checkpoint_path, map_location=self.device) model_to_load = self.model.module if self.is_ddp else self.model model_to_load.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) self.global_step = checkpoint['global_step'] self.current_epoch = checkpoint['epoch'] if self.use_amp and 'scaler_state_dict' in checkpoint: self.scaler.load_state_dict(checkpoint['scaler_state_dict']) print(f"Resumed from checkpoint: {checkpoint_path}") print(f" Global step: {self.global_step}") print(f" Epoch: {self.current_epoch}")