Ruby/train.py

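"""Training script for TinyGPT: pads the dataset to a uniform length, batches it,
trains with token-level cross-entropy, and saves the weights to ruby_model.pth."""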

import torch
from torch import nn
from torch.optim import Adam

from model import TinyGPT
from dataset import create_dataset
from tokenizer import load_vocab


def pad_sequence(seq, max_len):
    """Pad a sequence with zeros up to the given maximum length."""
    return torch.cat([seq, torch.zeros(max_len - len(seq), dtype=torch.long)], dim=0)


def train_model():
    vocab = load_vocab()
    inputs, targets = create_dataset()

    # Determine the maximum sequence length for padding
    max_len = max(len(seq) for seq in inputs + targets)

    # Pad inputs and targets to a uniform length
    inputs = [pad_sequence(seq, max_len) for seq in inputs]
    targets = [pad_sequence(seq, max_len) for seq in targets]

    # Convert to batch tensors and move to the GPU (requires a CUDA-capable device)
    inputs = torch.stack(inputs).cuda()
    targets = torch.stack(targets).cuda()

    # Model setup
    model = TinyGPT(vocab_size=len(vocab), embed_size=32, num_heads=2, num_layers=2).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Training loop
    for epoch in range(100):
        optimizer.zero_grad()
        output = model(inputs, targets)
        # Flatten (batch, seq_len, vocab) logits for token-level cross-entropy
        loss = criterion(output.view(-1, len(vocab)), targets.view(-1))
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

    # Save the trained weights
    torch.save(model.state_dict(), "ruby_model.pth")
    print("Model saved as ruby_model.pth")


if __name__ == "__main__":
    train_model()
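A minimal sketch of reloading the saved checkpoint for later use, assuming the same TinyGPT hyperparameters as above and that load_vocab() returns the vocabulary used during training:

    import torch
    from model import TinyGPT
    from tokenizer import load_vocab

    # Rebuild the model with the same configuration used for training (assumption:
    # embed_size=32, num_heads=2, num_layers=2, matching train.py)
    vocab = load_vocab()
    model = TinyGPT(vocab_size=len(vocab), embed_size=32, num_heads=2, num_layers=2).cuda()

    # Load the saved weights and switch to evaluation mode
    model.load_state_dict(torch.load("ruby_model.pth"))
    model.eval()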