69 lines
2.5 KiB
Python
69 lines
2.5 KiB
Python
# training/trainer.py
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from tqdm import tqdm
|
|
from training.checkpoint import save_checkpoint, load_checkpoint
|
|
|
|
|
|
class TextDataset(Dataset):
    """Next-token prediction windows over a flat token sequence.

    Item ``i`` is the pair ``(x, y)`` where ``x = tokens[i : i + context_size]``
    and ``y = tokens[i + 1 : i + context_size + 1]`` (the same window shifted
    one position to the right), both returned as ``torch.long`` tensors.
    """

    def __init__(self, tokens, context_size):
        """Store the token sequence and the window length.

        Args:
            tokens: sliceable sequence of integer token ids (presumably a
                Python list; anything ``torch.tensor`` accepts after slicing
                should work — confirm with callers).
            context_size: number of tokens in each input window; must be >= 1.

        Raises:
            ValueError: if ``context_size`` is less than 1.
        """
        if context_size < 1:
            raise ValueError(f"context_size must be >= 1, got {context_size}")
        self.tokens = tokens
        self.context_size = context_size

    def __len__(self):
        """Return the number of complete (input, target) windows."""
        # Each item needs context_size + 1 tokens (inputs plus the shifted
        # target). Clamp at 0 so a corpus shorter than that yields an empty
        # dataset instead of a negative length, which len() would reject.
        return max(0, len(self.tokens) - self.context_size)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for window ``idx``; negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        if idx < 0:
            idx += n
        # Without this check an out-of-range idx would silently produce short
        # or empty windows that typically only fail later, when batches are
        # stacked.
        if not 0 <= idx < n:
            raise IndexError(f"index out of range for TextDataset of length {n}")
        # Slice inputs and targets from one window of context_size + 1 tokens.
        window = self.tokens[idx:idx + self.context_size + 1]
        x = torch.tensor(window[:-1], dtype=torch.long)
        y = torch.tensor(window[1:], dtype=torch.long)
        return x, y
|
|
|
|
|
|
def train(model, dataset, device, lr, batch_size, epochs=10, checkpoint_interval=1000, seed=0):
    """Train ``model`` on ``dataset`` with AdamW and cross-entropy loss.

    State is restored through ``load_checkpoint`` before training. Checkpoints
    are written through ``save_checkpoint`` every ``checkpoint_interval``
    batches and at the end of every epoch.

    Args:
        model: module mapping a batch ``x`` to ``logits`` whose last dimension
            is the class (vocabulary) dimension fed to ``CrossEntropyLoss``.
        dataset: map-style dataset yielding ``(x, y)`` pairs (e.g.
            ``TextDataset``).
        device: device the model and every batch are moved to.
        lr: AdamW learning rate.
        batch_size: DataLoader batch size.
        epochs: total number of epochs, counting epochs completed before a
            resume.
        checkpoint_interval: save a mid-epoch checkpoint every this many
            batches. A falsy value (0 or None) disables mid-epoch checkpoints.
        seed: base seed for the shuffle order. Epoch ``e`` is shuffled with
            seed ``seed + e``, so after a resume, skipping the first
            ``start_step`` batches skips exactly the batches that were already
            trained.
    """
    model = model.to(device)
    model.train()

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # The `or 0` fallbacks treat a None/falsy value from load_checkpoint as a
    # fresh start.
    start_epoch, start_step = load_checkpoint(model, optimizer)
    start_epoch = start_epoch or 0
    start_step = start_step or 0

    # A dedicated generator, reseeded every epoch, makes the shuffle order
    # reproducible across restarts. Without it, the batches skipped on resume
    # would be a fresh random subset rather than the ones already trained.
    shuffle_gen = torch.Generator()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=shuffle_gen)
    total_steps = len(loader)

    for epoch in range(start_epoch, epochs):
        shuffle_gen.manual_seed(seed + epoch)
        skip = start_step if epoch == start_epoch else 0
        total_loss = 0.0
        trained_steps = 0

        progress = tqdm(enumerate(loader), total=total_steps, desc=f"Epoch {epoch+1}/{epochs}")
        for i, (x, y) in progress:
            if i < skip:
                continue  # already trained before the checkpoint we resumed from

            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            # reshape (not view) also works when the model returns a
            # non-contiguous tensor.
            loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            loss.backward()
            optimizer.step()

            # One host sync per step; the value is reused below.
            loss_value = loss.item()
            total_loss += loss_value
            trained_steps += 1
            progress.set_postfix(loss=loss_value)

            if checkpoint_interval and (i + 1) % checkpoint_interval == 0:
                save_checkpoint(model, optimizer, epoch, i + 1, filename="latest.pt")
                # NOTE(review): step checkpoints put the 0-based epoch in the
                # filename while "_final" files use epoch + 1. Kept as-is so
                # existing checkpoint names do not change.
                save_checkpoint(model, optimizer, epoch, i + 1,
                                filename=f"epoch_{epoch:02d}_step_{i+1:05d}.pt")
                print(f"[Checkpoint] Epoch {epoch+1}, Step {i+1}, Loss: {loss_value:.4f}")

        # Saving (epoch + 1, 0) makes the next resume start at the beginning of
        # the following epoch.
        save_checkpoint(model, optimizer, epoch + 1, 0, filename="latest.pt")
        save_checkpoint(model, optimizer, epoch + 1, 0, filename=f"epoch_{epoch+1:02d}_final.pt")

        # Average only over batches actually trained in this run, so a resumed
        # epoch is not diluted by skipped batches and an empty loader does not
        # divide by zero.
        avg_loss = total_loss / trained_steps if trained_steps else float("nan")
        print(f"[Epoch {epoch+1}] Avg Loss: {avg_loss:.4f}")