# training/trainer.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class TextDataset(Dataset):
    """Sliding-window dataset over a flat token sequence for next-token prediction."""

    def __init__(self, tokens, context_size):
        self.tokens = tokens
        self.context_size = context_size

    def __len__(self):
        # One sample per starting position that still leaves room for a full window plus target.
        return len(self.tokens) - self.context_size

    def __getitem__(self, idx):
        # Input is a context_size window; the target is the same window shifted right by one token.
        x = torch.tensor(self.tokens[idx:idx + self.context_size], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.context_size + 1], dtype=torch.long)
        return x, y
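
# Illustrative example of the windowing above: with tokens = [10, 11, 12, 13, 14]
# and context_size = 3, TextDataset(tokens, 3)[0] yields x = tensor([10, 11, 12])
# and y = tensor([11, 12, 13]), i.e. the target is the input shifted one position right.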

def train(model, dataset, device, lr, batch_size, epochs=1):
    """Simple next-token training loop with Adam, gradient clipping, and a tqdm progress bar."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    if len(loader) == 0:
        print("❌ No data to train on. Check your token count or dataset.")
        return

    # Move the model to the target device before constructing the optimizer,
    # so optimizer state is created for the on-device parameters.
    model.to(device)
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    # Anomaly detection pinpoints the op that produced NaN/Inf gradients;
    # it slows training noticeably, so leave it on only while debugging.
    torch.autograd.set_detect_anomaly(True)

    for epoch in range(epochs):
        total_loss = 0.0
        progress = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch", dynamic_ncols=True)

        for x, y in progress:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # Flatten (batch, seq, vocab) logits and (batch, seq) targets for cross-entropy.
            loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))

            if torch.isnan(loss):
                raise RuntimeError("❌ Loss is NaN! Aborting training.")

            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to keep updates stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            progress.set_postfix(loss=loss.item())

        print(f"[Epoch {epoch+1}] Avg Loss: {total_loss / len(loader):.4f}")
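

# --- Usage sketch (not part of the original loop) ---
# A minimal example of how TextDataset and train() might be wired together.
# `TinyLM` below is a hypothetical stand-in for whatever model this repo defines
# elsewhere; any nn.Module mapping (batch, seq) token ids to (batch, seq, vocab)
# logits will work with the loop above.
if __name__ == "__main__":
    class TinyLM(nn.Module):
        """Placeholder language model used only to exercise the training loop."""

        def __init__(self, vocab_size=100, d_model=32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, d_model)
            self.proj = nn.Linear(d_model, vocab_size)

        def forward(self, x):
            return self.proj(self.embed(x))

    # Dummy token stream; in practice these would come from a tokenizer.
    tokens = list(range(100)) * 10
    dataset = TextDataset(tokens, context_size=8)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train(TinyLM(), dataset, device, lr=3e-4, batch_size=16, epochs=1)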