import torch
from torch.utils.data import DataLoader

from core.dataset import CharDataset
from core.model import GPT, GPTConfig


def train():
    # hyperparameters
    books_dir = './books'
    block_size = 128
    batch_size = 32
    epochs = 10
    lr = 3e-4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # dataset & model
    dataset = CharDataset(books_dir, block_size)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    config = GPTConfig(
        vocab_size=dataset.vocab_size,
        block_size=block_size,
        n_layer=6,
        n_head=6,
        n_embd=384,
    )
    model = GPT(config).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # training loop
    model.train()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            _, loss = model(xb, yb)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg = total_loss / len(loader)
        print(f'Epoch {epoch}/{epochs} - avg loss: {avg:.4f}')

        # save checkpoint each epoch
        torch.save(model.state_dict(), 'model.pth')


if __name__ == '__main__':
    train()
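
# --- Hedged sketch: sampling from the saved checkpoint ---
# Not part of the original script; a minimal sketch of how generation could
# look given the training loop above. Two assumptions are made here, since
# core/model.py and core/dataset.py are not shown: (1) GPT's forward pass
# accepts input ids without targets and returns (logits, None) -- training
# calls model(xb, yb) -> (logits, loss), so the targets-optional signature
# is assumed; (2) CharDataset exposes encode()/decode() helpers for the
# character vocabulary. Verify both against the actual modules before use.

@torch.no_grad()
def sample(prompt: str, max_new_tokens: int = 200, temperature: float = 1.0) -> str:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # rebuild the vocabulary and a config matching the training run above
    dataset = CharDataset('./books', 128)
    config = GPTConfig(
        vocab_size=dataset.vocab_size,
        block_size=128,
        n_layer=6,
        n_head=6,
        n_embd=384,
    )
    model = GPT(config).to(device)
    model.load_state_dict(torch.load('model.pth', map_location=device))
    model.eval()

    # encode() is an assumed CharDataset helper mapping chars -> token ids
    idx = torch.tensor([dataset.encode(prompt)], dtype=torch.long, device=device)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -config.block_size:]   # crop to the context window
        logits, _ = model(idx_cond)              # assumed: targets are optional
        probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)

    # decode() is the assumed inverse of encode(): token ids -> string
    return dataset.decode(idx[0].tolist())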