# training/trainer.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from training.checkpoint import save_checkpoint, load_checkpoint
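
# Assumed interface for training.checkpoint (not shown in this file): a sketch
# inferred from the call sites below; the real module may differ.
#
#   save_checkpoint(model, optimizer, epoch, step, filename)
#       Persist model/optimizer state plus the (epoch, step) position.
#   load_checkpoint(model, optimizer) -> (epoch, step), or (None, None)
#       Restore state in place when a saved checkpoint exists.
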
class TextDataset(Dataset):
    """Sliding-window language-modeling dataset: each sample is a block of
    `context_size` token ids, with targets shifted one position right."""

    def __init__(self, tokens, context_size):
        self.tokens = tokens
        self.context_size = context_size

    def __len__(self):
        # Number of full (input, target) windows that fit in the token stream.
        return len(self.tokens) - self.context_size

    def __getitem__(self, idx):
        x = torch.tensor(self.tokens[idx:idx + self.context_size], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.context_size + 1], dtype=torch.long)
        return x, y
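
# Worked example of the windowing (hypothetical values): with
# tokens = [5, 6, 7, 8, 9] and context_size = 3, dataset[0] gives
# x = [5, 6, 7] and y = [6, 7, 8], so the model is trained to predict
# the next token at every position in the window.
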
def train(model, dataset, device, lr, batch_size, epochs=10, checkpoint_interval=1000):
    model = model.to(device)
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Resume from the latest checkpoint if one exists.
    start_epoch, start_step = load_checkpoint(model, optimizer)
    start_epoch = start_epoch or 0
    start_step = start_step or 0

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    total_steps = len(loader)

    for epoch in range(start_epoch, epochs):
        total_loss = 0.0
        steps_done = 0
        progress = tqdm(enumerate(loader), total=total_steps, desc=f"Epoch {epoch+1}/{epochs}")
        for i, (x, y) in progress:
            # Fast-forward past steps already trained in a resumed epoch. With
            # shuffle=True the skipped batches are not the same ones seen before
            # the interruption; this only preserves the step count.
            if epoch == start_epoch and i < start_step:
                continue
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            # Flatten (batch, seq, vocab) logits and (batch, seq) targets for
            # token-level cross-entropy.
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            steps_done += 1
            progress.set_postfix(loss=loss.item())

            # Save every N steps.
            if (i + 1) % checkpoint_interval == 0:
                save_checkpoint(model, optimizer, epoch, i + 1, filename="latest.pt")
                save_checkpoint(model, optimizer, epoch, i + 1,
                                filename=f"epoch_{epoch:02d}_step_{i+1:05d}.pt")
                print(f"[Checkpoint] Epoch {epoch+1}, Step {i+1}, Loss: {loss.item():.4f}")

        # Save at the end of each epoch; step 0 of the next epoch marks it complete.
        save_checkpoint(model, optimizer, epoch + 1, 0, filename="latest.pt")
        save_checkpoint(model, optimizer, epoch + 1, 0, filename=f"epoch_{epoch+1:02d}_final.pt")

        # Average over the steps actually run, so a resumed epoch's average is
        # not diluted by the skipped steps.
        avg_loss = total_loss / max(steps_done, 1)
        print(f"[Epoch {epoch+1}] Avg Loss: {avg_loss:.4f}")
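

# Minimal usage sketch (an assumption, not part of the original module): wires a
# toy model with the required (batch, seq) -> (batch, seq, vocab) output shape
# into TextDataset and train(). TinyLM and the random token stream are
# hypothetical stand-ins for a real model and corpus.
if __name__ == "__main__":
    import random

    class TinyLM(nn.Module):
        """Toy next-token model: embedding + linear head, for smoke-testing train()."""
        def __init__(self, vocab_size, dim=64):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dim)
            self.head = nn.Linear(dim, vocab_size)

        def forward(self, x):
            return self.head(self.embed(x))  # (batch, seq, vocab)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    vocab_size = 1000
    tokens = [random.randrange(vocab_size) for _ in range(10_000)]
    dataset = TextDataset(tokens, context_size=32)
    train(TinyLM(vocab_size), dataset, device, lr=3e-4, batch_size=16, epochs=1)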