NOVA/nova_train/trainer.py
Dani a7f091aa45 Initial commit: NOVA - Neuro-Optimizing Versatile Agent
Complete transformer LLM built from scratch with:

Core Features:
- Full transformer architecture (RoPE, RMSNorm, SwiGLU, KV-cache)
- SentencePiece tokenizer (BPE/Unigram)
- Training pipeline (AMP, gradient checkpointing, DDP)
- Persona system with personality matrix (NO AI disclosure by default)
- Genetic evolution (NOVA-EVO) for hyperparameter optimization
- Legal-only data pipeline with license tracking
- Chat interface (CLI + REST API)
- Conversation memory (SQLite)

Model Sizes:
- 125M, 350M, 1.3B, 3B parameters
- Local-first, runs on CPU or GPU
- Python 3.10.6+, PyTorch 2.0+

Personas:
- girlfriend_gentle (high warmth, high empathy)
- girlfriend_playful (high humor, high playfulness)
- girlfriend_supportive (balanced, default)

Documentation:
- Complete README with quickstart
- Model card with ethical considerations
- Privacy documentation (local-first, zero telemetry)
- Data licenses and attribution
- Contributing guide

Infrastructure:
- GitHub Actions CI/CD
- Comprehensive test suite
- Quickstart script
- CLI tool

License: Apache 2.0

🤖 Generated with Claude Code
https://claude.com/claude-code

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-12 20:56:37 -04:00


"""
NOVA Trainer - Training loop with AMP, gradient checkpointing, DDP
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from pathlib import Path
from tqdm import tqdm
from typing import Optional, Dict, Any
import os
import json
import time
import math
from .config import TrainingConfig
from nova_core import NovaTransformer, ModelConfig
class NovaTrainer:
"""
Trainer for NOVA models with support for:
- Automatic Mixed Precision (AMP)
- Gradient checkpointing
- Distributed Data Parallel (DDP)
- Resume from checkpoint
- Early stopping
- Cosine learning rate schedule with warmup
"""
def __init__(
self,
model: NovaTransformer,
train_config: TrainingConfig,
train_dataloader: DataLoader,
val_dataloader: Optional[DataLoader] = None,
):
"""
Args:
model: NOVA transformer model
train_config: Training configuration
train_dataloader: Training data loader
val_dataloader: Optional validation data loader
"""
self.config = train_config
self.model = model
self.train_dataloader = train_dataloader
self.val_dataloader = val_dataloader
# Setup device
self.device = self._setup_device()
self.model.to(self.device)
# Setup distributed training if needed
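        # (DDP wrapping here assumes the default process group has already been
        #  initialized, e.g. by launching with torchrun; otherwise it will fail)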
self.is_ddp = train_config.use_ddp and torch.cuda.device_count() > 1
if self.is_ddp:
self.model = DDP(self.model)
# Setup optimizer
self.optimizer = self._create_optimizer()
# Setup learning rate scheduler
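        # total_steps counts optimizer updates (one per gradient_accumulation_steps
        # micro-batches); the scheduler below is stepped at that same granularity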
total_steps = len(train_dataloader) * train_config.num_epochs // train_config.gradient_accumulation_steps
self.scheduler = self._create_scheduler(total_steps)
# Setup AMP
self.use_amp = train_config.use_amp and self.device.type == 'cuda'
self.scaler = GradScaler() if self.use_amp else None
# Tracking
        self.global_step = 0   # optimizer updates completed
        self.micro_step = 0    # raw micro-batches seen (drives gradient accumulation)
self.current_epoch = 0
self.best_val_loss = float('inf')
self.patience_counter = 0
# Create save directory
Path(train_config.save_dir).mkdir(parents=True, exist_ok=True)
def _setup_device(self) -> torch.device:
"""Setup training device"""
if self.config.device == "auto":
if torch.cuda.is_available():
return torch.device("cuda")
else:
return torch.device("cpu")
else:
return torch.device(self.config.device)
def _create_optimizer(self) -> optim.Optimizer:
"""Create optimizer"""
# Separate parameters with and without weight decay
decay_params = []
no_decay_params = []
for name, param in self.model.named_parameters():
if param.requires_grad:
# Don't apply weight decay to biases and layer norms
if 'bias' in name or 'norm' in name:
no_decay_params.append(param)
else:
decay_params.append(param)
param_groups = [
{'params': decay_params, 'weight_decay': self.config.weight_decay},
{'params': no_decay_params, 'weight_decay': 0.0}
]
if self.config.optimizer.lower() == "adamw":
return optim.AdamW(
param_groups,
lr=self.config.learning_rate,
betas=(self.config.adam_beta1, self.config.adam_beta2),
eps=self.config.adam_epsilon
)
else:
raise ValueError(f"Unknown optimizer: {self.config.optimizer}")
def _create_scheduler(self, total_steps: int):
"""Create learning rate scheduler with warmup"""
if self.config.lr_scheduler == "cosine":
def lr_lambda(current_step: int):
# Warmup
if current_step < self.config.warmup_steps:
return float(current_step) / float(max(1, self.config.warmup_steps))
# Cosine decay
progress = float(current_step - self.config.warmup_steps) / float(max(1, total_steps - self.config.warmup_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
elif self.config.lr_scheduler == "linear":
def lr_lambda(current_step: int):
if current_step < self.config.warmup_steps:
return float(current_step) / float(max(1, self.config.warmup_steps))
return max(0.0, float(total_steps - current_step) / float(max(1, total_steps - self.config.warmup_steps)))
return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
else: # constant
return optim.lr_scheduler.LambdaLR(self.optimizer, lambda _: 1.0)
def train(self):
"""Main training loop"""
print(f"Starting training on {self.device}")
print(f" Num epochs: {self.config.num_epochs}")
print(f" Batch size: {self.config.batch_size}")
print(f" Gradient accumulation steps: {self.config.gradient_accumulation_steps}")
print(f" Learning rate: {self.config.learning_rate}")
print(f" Mixed precision: {self.use_amp}")
for epoch in range(self.current_epoch, self.config.num_epochs):
self.current_epoch = epoch
print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")
# Training
train_loss = self.train_epoch()
print(f" Train loss: {train_loss:.4f}")
# Validation
if self.val_dataloader is not None:
val_loss = self.evaluate()
print(f" Val loss: {val_loss:.4f}")
# Early stopping check
if self.config.early_stopping:
if val_loss < self.best_val_loss - self.config.early_stopping_threshold:
self.best_val_loss = val_loss
self.patience_counter = 0
self.save_checkpoint(is_best=True)
else:
self.patience_counter += 1
if self.patience_counter >= self.config.early_stopping_patience:
print(f"Early stopping triggered after {epoch + 1} epochs")
break
print("\nTraining complete!")
def train_epoch(self) -> float:
"""Train for one epoch"""
self.model.train()
total_loss = 0.0
num_batches = 0
progress_bar = tqdm(self.train_dataloader, desc="Training")
for batch_idx, batch in enumerate(progress_bar):
loss = self.train_step(batch)
total_loss += loss
num_batches += 1
progress_bar.set_postfix({"loss": f"{loss:.4f}", "lr": f"{self.scheduler.get_last_lr()[0]:.2e}"})
return total_loss / num_batches
def train_step(self, batch: Dict[str, torch.Tensor]) -> float:
"""Single training step"""
input_ids = batch['input_ids'].to(self.device)
labels = batch.get('labels', input_ids).to(self.device)
# Forward pass with AMP
with autocast(enabled=self.use_amp):
outputs = self.model(input_ids=input_ids)
logits = outputs['logits']
# Calculate loss (next token prediction)
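                # Logits at position i predict token i+1, so drop the final logit
                # and the first label before flattening for cross-entropy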
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100
)
# Scale loss for gradient accumulation
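        # Dividing keeps the accumulated gradient equal to the mean over the window;
        # the undivided value is recovered in the return statement for logging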
loss = loss / self.config.gradient_accumulation_steps
# Backward pass with gradient scaling
if self.use_amp:
self.scaler.scale(loss).backward()
else:
loss.backward()
        # Update weights every N accumulation steps
        # (count raw micro-batches; global_step only advances on optimizer updates,
        #  so it cannot drive the accumulation modulus)
        self.micro_step += 1
        if self.micro_step % self.config.gradient_accumulation_steps == 0:
# Gradient clipping
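            # Under AMP, unscale first so clipping sees true gradient magnitudes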
if self.use_amp:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config.max_grad_norm
)
# Optimizer step
if self.use_amp:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
self.global_step += 1
# Checkpointing
if self.global_step % self.config.save_steps == 0:
self.save_checkpoint()
return loss.item() * self.config.gradient_accumulation_steps
@torch.no_grad()
def evaluate(self) -> float:
"""Evaluate on validation set"""
self.model.eval()
total_loss = 0.0
num_batches = 0
for batch in tqdm(self.val_dataloader, desc="Evaluating"):
input_ids = batch['input_ids'].to(self.device)
labels = batch.get('labels', input_ids).to(self.device)
with autocast(enabled=self.use_amp):
outputs = self.model(input_ids=input_ids)
logits = outputs['logits']
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100
)
total_loss += loss.item()
num_batches += 1
return total_loss / num_batches
def save_checkpoint(self, is_best: bool = False):
"""Save model checkpoint"""
model_to_save = self.model.module if self.is_ddp else self.model
checkpoint = {
'model_state_dict': model_to_save.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'global_step': self.global_step,
'epoch': self.current_epoch,
'config': self.config.__dict__,
}
if self.use_amp:
checkpoint['scaler_state_dict'] = self.scaler.state_dict()
# Save regular checkpoint
checkpoint_path = Path(self.config.save_dir) / f"checkpoint-{self.global_step}.pt"
torch.save(checkpoint, checkpoint_path)
print(f" Checkpoint saved: {checkpoint_path}")
# Save best model
if is_best:
best_path = Path(self.config.save_dir) / "best_model.pt"
torch.save(checkpoint, best_path)
print(f" Best model saved: {best_path}")
def load_checkpoint(self, checkpoint_path: str):
"""Load from checkpoint"""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
model_to_load = self.model.module if self.is_ddp else self.model
model_to_load.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.global_step = checkpoint['global_step']
self.current_epoch = checkpoint['epoch']
if self.use_amp and 'scaler_state_dict' in checkpoint:
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
print(f"Resumed from checkpoint: {checkpoint_path}")
print(f" Global step: {self.global_step}")
print(f" Epoch: {self.current_epoch}")