# training/trainer.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class TextDataset(Dataset):
    """Sliding-window dataset over a flat token list for next-token prediction."""

    def __init__(self, tokens, context_size):
        self.tokens = tokens
        self.context_size = context_size

    def __len__(self):
        return len(self.tokens) - self.context_size

    def __getitem__(self, idx):
        # Input is a window of context_size tokens; target is the same window shifted by one.
        x = torch.tensor(self.tokens[idx:idx + self.context_size], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.context_size + 1], dtype=torch.long)
        return x, y


def train(model, dataset, device, lr, batch_size, epochs=1):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    if len(loader) == 0:
        print("❌ No data to train on. Check your token count or dataset.")
        return

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    model.train()

    # Anomaly detection helps locate NaN/inf gradients but slows training noticeably;
    # disable it once the model trains stably.
    torch.autograd.set_detect_anomaly(True)

    for epoch in range(epochs):
        total_loss = 0.0
        progress = tqdm(loader, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch", dynamic_ncols=True)

        for x, y in progress:
            x, y = x.to(device), y.to(device)

            # Flatten (batch, seq, vocab) logits and (batch, seq) targets for cross-entropy.
            logits = model(x)
            loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))

            if torch.isnan(loss):
                # Raise instead of killing the interpreter so callers can handle the failure.
                raise RuntimeError("❌ Loss is NaN! Aborting training.")

            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to stabilize training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            progress.set_postfix(loss=loss.item())

        print(f"[Epoch {epoch + 1}] Avg Loss: {total_loss / len(loader):.4f}")
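

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring TextDataset and train() together. TinyLM below is a
# hypothetical stand-in for the real model class, which is not defined in this file;
# the vocab size, token stream, and hyperparameters are arbitrary assumptions chosen
# only to make the example self-contained and runnable.
if __name__ == "__main__":
    class TinyLM(nn.Module):
        """Placeholder language model: embedding + linear head, no attention."""

        def __init__(self, vocab_size, dim=64):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dim)
            self.head = nn.Linear(dim, vocab_size)

        def forward(self, x):
            # (batch, seq) -> (batch, seq, vocab_size)
            return self.head(self.embed(x))

    vocab_size = 256
    tokens = torch.randint(0, vocab_size, (10_000,)).tolist()  # fake token stream
    dataset = TextDataset(tokens, context_size=32)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train(TinyLM(vocab_size), dataset, device, lr=3e-4, batch_size=16, epochs=1)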