Ruby/train.py

import torch
from torch.utils.data import DataLoader

from core.dataset import CharDataset
from core.model import GPT, GPTConfig


def train():
    # hyperparameters
    books_dir = './books'
    block_size = 128    # context length in characters
    batch_size = 32
    epochs = 10
    lr = 3e-4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# dataset & model
dataset = CharDataset(books_dir, block_size)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
config = GPTConfig(
vocab_size=dataset.vocab_size,
block_size=block_size,
n_layer=6,
n_head=6,
n_embd=384
)
model = GPT(config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # training loop
    model.train()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            _, loss = model(xb, yb)  # second return value is the training loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg = total_loss / len(loader)
        print(f'Epoch {epoch}/{epochs} — avg loss: {avg:.4f}')
        # save checkpoint each epoch (overwrites the same file)
        torch.save(model.state_dict(), 'model.pth')
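        # A possible variant (hypothetical, not in the original): keep one
        # checkpoint per epoch instead of overwriting a single file:
        #     torch.save(model.state_dict(), f'model_epoch{epoch}.pth')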


if __name__ == '__main__':
    train()
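
# A minimal sketch of loading the saved checkpoint later (hypothetical usage,
# not part of the original script). It assumes the same books_dir so that
# CharDataset rebuilds an identical vocabulary, since vocab_size must match
# the saved embedding weights:
#
#     dataset = CharDataset('./books', 128)
#     config = GPTConfig(vocab_size=dataset.vocab_size, block_size=128,
#                        n_layer=6, n_head=6, n_embd=384)
#     model = GPT(config)
#     model.load_state_dict(torch.load('model.pth', map_location='cpu'))
#     model.eval()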