"""
|
|
NOVA Trainer - Training loop with AMP, gradient checkpointing, DDP
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.cuda.amp import autocast, GradScaler
|
|
from torch.utils.data import DataLoader, DistributedSampler
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
import torch.distributed as dist
|
|
from pathlib import Path
|
|
from tqdm import tqdm
|
|
from typing import Optional, Dict, Any
|
|
import os
|
|
import json
|
|
import time
|
|
import math
|
|
|
|
from .config import TrainingConfig
|
|
from nova_core import NovaTransformer, ModelConfig
|
|
|
|
|


class NovaTrainer:
    """
    Trainer for NOVA models with support for:
    - Automatic Mixed Precision (AMP)
    - Gradient checkpointing (expected to be enabled on the model before training)
    - Distributed Data Parallel (DDP)
    - Resume from checkpoint
    - Early stopping
    - Cosine learning rate schedule with warmup
    """

    def __init__(
        self,
        model: NovaTransformer,
        train_config: TrainingConfig,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None,
    ):
        """
        Args:
            model: NOVA transformer model
            train_config: Training configuration
            train_dataloader: Training data loader
            val_dataloader: Optional validation data loader
        """
        self.config = train_config
        self.model = model
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

        # Setup device
        self.device = self._setup_device()
        self.model.to(self.device)

        # Setup distributed training if requested. DDP needs an initialized
        # process group (e.g. via torchrun); wrapping without one raises.
        self.is_ddp = (
            train_config.use_ddp
            and dist.is_available()
            and dist.is_initialized()
        )
        if self.is_ddp:
            device_ids = [self.device.index] if self.device.index is not None else None
            self.model = DDP(self.model, device_ids=device_ids)

        # Setup optimizer
        self.optimizer = self._create_optimizer()

        # Setup learning rate scheduler (total optimizer steps, not micro-batches)
        total_steps = (
            len(train_dataloader) * train_config.num_epochs
            // train_config.gradient_accumulation_steps
        )
        self.scheduler = self._create_scheduler(total_steps)

        # Setup AMP (CUDA only)
        self.use_amp = train_config.use_amp and self.device.type == 'cuda'
        self.scaler = GradScaler() if self.use_amp else None

        # Tracking
        self.global_step = 0
        self.current_epoch = 0
        self.best_val_loss = float('inf')
        self.patience_counter = 0

        # Create save directory
        Path(train_config.save_dir).mkdir(parents=True, exist_ok=True)

    def _setup_device(self) -> torch.device:
        """Setup training device"""
        if self.config.device == "auto":
            if torch.cuda.is_available():
                return torch.device("cuda")
            return torch.device("cpu")
        return torch.device(self.config.device)

    def _create_optimizer(self) -> optim.Optimizer:
        """Create optimizer"""
        # Separate parameters with and without weight decay
        decay_params = []
        no_decay_params = []

        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # Don't apply weight decay to biases and layer norms
                if 'bias' in name or 'norm' in name:
                    no_decay_params.append(param)
                else:
                    decay_params.append(param)

        param_groups = [
            {'params': decay_params, 'weight_decay': self.config.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]

        if self.config.optimizer.lower() == "adamw":
            return optim.AdamW(
                param_groups,
                lr=self.config.learning_rate,
                betas=(self.config.adam_beta1, self.config.adam_beta2),
                eps=self.config.adam_epsilon,
            )
        raise ValueError(f"Unknown optimizer: {self.config.optimizer}")
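
    # Illustrative grouping (parameter names assumed, not taken from nova_core):
    # a weight such as "layers.0.attention.wq.weight" falls in the decay group,
    # while "layers.0.attention_norm.weight" (matches 'norm') and any "*.bias"
    # tensor land in the no-decay group, the usual practice of exempting norms
    # and biases from weight decay.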

    def _create_scheduler(self, total_steps: int):
        """Create learning rate scheduler with warmup"""
        if self.config.lr_scheduler == "cosine":
            def lr_lambda(current_step: int):
                # Linear warmup
                if current_step < self.config.warmup_steps:
                    return float(current_step) / float(max(1, self.config.warmup_steps))
                # Cosine decay
                progress = float(current_step - self.config.warmup_steps) / float(
                    max(1, total_steps - self.config.warmup_steps)
                )
                return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

            return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

        elif self.config.lr_scheduler == "linear":
            def lr_lambda(current_step: int):
                if current_step < self.config.warmup_steps:
                    return float(current_step) / float(max(1, self.config.warmup_steps))
                return max(0.0, float(total_steps - current_step) / float(
                    max(1, total_steps - self.config.warmup_steps)
                ))

            return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

        else:  # constant
            return optim.lr_scheduler.LambdaLR(self.optimizer, lambda _: 1.0)
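
    # Worked example of the cosine schedule (numbers are hypothetical, not
    # TrainingConfig defaults): with learning_rate=3e-4, warmup_steps=100, and
    # total_steps=1000, lr = 3e-4 * step/100 during warmup, then
    # lr = 3e-4 * 0.5 * (1 + cos(pi * (step - 100) / 900)) afterwards:
    # ~3e-4 at step 100, ~1.5e-4 at step 550, and ~0 at step 1000.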

    def train(self):
        """Main training loop"""
        print(f"Starting training on {self.device}")
        print(f" Num epochs: {self.config.num_epochs}")
        print(f" Batch size: {self.config.batch_size}")
        print(f" Gradient accumulation steps: {self.config.gradient_accumulation_steps}")
        print(f" Learning rate: {self.config.learning_rate}")
        print(f" Mixed precision: {self.use_amp}")

        for epoch in range(self.current_epoch, self.config.num_epochs):
            self.current_epoch = epoch
            print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")

            # Training
            train_loss = self.train_epoch()
            print(f" Train loss: {train_loss:.4f}")

            # Validation
            if self.val_dataloader is not None:
                val_loss = self.evaluate()
                print(f" Val loss: {val_loss:.4f}")

                # Early stopping check
                if self.config.early_stopping:
                    if val_loss < self.best_val_loss - self.config.early_stopping_threshold:
                        self.best_val_loss = val_loss
                        self.patience_counter = 0
                        self.save_checkpoint(is_best=True)
                    else:
                        self.patience_counter += 1
                        if self.patience_counter >= self.config.early_stopping_patience:
                            print(f"Early stopping triggered after {epoch + 1} epochs")
                            break

        print("\nTraining complete!")

    def train_epoch(self) -> float:
        """Train for one epoch"""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(self.train_dataloader, desc="Training")

        for batch in progress_bar:
            loss = self.train_step(batch)
            total_loss += loss
            num_batches += 1

            progress_bar.set_postfix({
                "loss": f"{loss:.4f}",
                "lr": f"{self.scheduler.get_last_lr()[0]:.2e}",
            })

        return total_loss / num_batches

    def train_step(self, batch: Dict[str, torch.Tensor]) -> float:
        """Single training step (one micro-batch)"""
        input_ids = batch['input_ids'].to(self.device)
        labels = batch.get('labels', input_ids).to(self.device)

        # Forward pass with AMP
        with autocast(enabled=self.use_amp):
            outputs = self.model(input_ids=input_ids)
            logits = outputs['logits']

            # Calculate loss (next-token prediction): logits at position i
            # are scored against the token at position i + 1
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )

        # Scale loss for gradient accumulation
        loss = loss / self.config.gradient_accumulation_steps

        # Backward pass with gradient scaling
        if self.use_amp:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        # Update weights every N accumulation steps. Note: global_step counts
        # micro-batches, so save_steps is measured in micro-batches too.
        if (self.global_step + 1) % self.config.gradient_accumulation_steps == 0:
            # Gradient clipping (unscale first so the clip threshold applies
            # to the true gradient magnitudes)
            if self.use_amp:
                self.scaler.unscale_(self.optimizer)

            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.max_grad_norm
            )

            # Optimizer step
            if self.use_amp:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()

            self.scheduler.step()
            self.optimizer.zero_grad()

        self.global_step += 1

        # Checkpointing
        if self.global_step % self.config.save_steps == 0:
            self.save_checkpoint()

        # Report the unscaled loss for logging
        return loss.item() * self.config.gradient_accumulation_steps
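
    # Sketch of the accumulation arithmetic (numbers are illustrative): with
    # batch_size=8 and gradient_accumulation_steps=4, gradients from 4
    # micro-batches are summed before each optimizer step, giving an effective
    # batch size of 32 per process (times the world size under DDP).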

    @torch.no_grad()
    def evaluate(self) -> float:
        """Evaluate on validation set"""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for batch in tqdm(self.val_dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(self.device)
            labels = batch.get('labels', input_ids).to(self.device)

            with autocast(enabled=self.use_amp):
                outputs = self.model(input_ids=input_ids)
                logits = outputs['logits']

                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()

                loss = nn.functional.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100
                )

            total_loss += loss.item()
            num_batches += 1

        return total_loss / num_batches

    def save_checkpoint(self, is_best: bool = False):
        """Save model checkpoint"""
        # Under DDP, only rank 0 writes to disk to avoid clobbering files
        if self.is_ddp and dist.get_rank() != 0:
            return

        model_to_save = self.model.module if self.is_ddp else self.model

        checkpoint = {
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'global_step': self.global_step,
            'epoch': self.current_epoch,
            'config': self.config.__dict__,
        }

        if self.use_amp:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        # Save regular checkpoint
        checkpoint_path = Path(self.config.save_dir) / f"checkpoint-{self.global_step}.pt"
        torch.save(checkpoint, checkpoint_path)
        print(f" Checkpoint saved: {checkpoint_path}")

        # Save best model
        if is_best:
            best_path = Path(self.config.save_dir) / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f" Best model saved: {best_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load from checkpoint"""
        # weights_only=False: the checkpoint holds optimizer/scheduler state
        # and the config dict, not just tensors; only load trusted checkpoints
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        model_to_load = self.model.module if self.is_ddp else self.model
        model_to_load.load_state_dict(checkpoint['model_state_dict'])

        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.global_step = checkpoint['global_step']
        self.current_epoch = checkpoint['epoch']

        if self.use_amp and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        print(f"Resumed from checkpoint: {checkpoint_path}")
        print(f" Global step: {self.global_step}")
        print(f" Epoch: {self.current_epoch}")
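

# Minimal usage sketch, not part of the module. Assumptions: TrainingConfig's
# defaults cover the fields not shown, and nova_core exposes NovaTransformer
# and ModelConfig as in this file's original imports; `train_dataset` and
# `val_loader` are hypothetical stand-ins for the project's data pipeline.
#
#   from torch.utils.data import DataLoader
#   from nova_core import NovaTransformer, ModelConfig
#
#   model = NovaTransformer(ModelConfig())
#   config = TrainingConfig(device="auto")
#   train_loader = DataLoader(train_dataset, batch_size=config.batch_size)
#   trainer = NovaTrainer(model, config, train_loader, val_dataloader=val_loader)
#   # trainer.load_checkpoint("checkpoints/checkpoint-1000.pt")  # optional resume
#   trainer.train()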