47 lines
1.4 KiB
Python
47 lines
1.4 KiB
Python
import torch
|
|
from torch import nn
|
|
from torch.optim import Adam
|
|
from model import TinyGPT
|
|
from dataset import create_dataset
|
|
from tokenizer import load_vocab
|
|
|
|
def pad_sequence(seq, max_len, pad_value=0):
    """Right-pad a 1-D sequence with ``pad_value`` up to ``max_len`` elements.

    Args:
        seq: 1-D tensor (or list of numbers, converted with
            ``torch.as_tensor``) to pad.
        max_len: Desired output length; must be >= ``len(seq)``.
        pad_value: Fill value for the padded positions (default 0).

    Returns:
        A tensor of length ``max_len`` holding the elements of ``seq``
        followed by ``max_len - len(seq)`` copies of ``pad_value``. The
        padding shares ``seq``'s dtype and device.

    Raises:
        ValueError: If ``seq`` is longer than ``max_len``.
    """
    seq = torch.as_tensor(seq)
    pad_len = max_len - len(seq)
    if pad_len < 0:
        raise ValueError(
            f"sequence of length {len(seq)} exceeds max_len={max_len}"
        )
    # new_full matches seq's dtype and device, so padding a tensor that
    # already lives on the GPU does not hit a CPU/CUDA mismatch in cat.
    padding = seq.new_full((pad_len,), pad_value)
    return torch.cat([seq, padding], dim=0)
|
|
|
|
def train_model(epochs=100, lr=0.001, save_path="ruby_model.pth", device=None):
    """Train a TinyGPT model on the project dataset and save its weights.

    Loads the vocabulary and the (inputs, targets) sequences, right-pads
    every sequence to a common length with ``pad_sequence``, and trains with
    Adam and cross-entropy. Each epoch is one step on the whole dataset as a
    single batch. Finally writes the model's ``state_dict`` to ``save_path``.

    Args:
        epochs: Number of full-batch optimisation steps (default 100).
        lr: Adam learning rate (default 0.001).
        save_path: File path the trained ``state_dict`` is written to.
        device: Torch device to train on. Defaults to CUDA when available,
            otherwise CPU.

    Raises:
        ValueError: If the dataset contains no input or no target sequences.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocab = load_vocab()
    vocab_size = len(vocab)
    inputs, targets = create_dataset()
    if len(inputs) == 0 or len(targets) == 0:
        raise ValueError("create_dataset() returned no sequences to train on")

    # Pad everything to one shared length so the sequences stack into a
    # single (num_sequences, max_len) batch tensor.
    max_len = max(
        max(len(seq) for seq in inputs),
        max(len(seq) for seq in targets),
    )
    inputs = torch.stack([pad_sequence(seq, max_len) for seq in inputs]).to(device)
    targets = torch.stack([pad_sequence(seq, max_len) for seq in targets]).to(device)

    model = TinyGPT(
        vocab_size=vocab_size, embed_size=32, num_heads=2, num_layers=2
    ).to(device)
    # NOTE(review): padded target positions (filled with 0) are counted in
    # the loss. If id 0 is not a real token, consider
    # nn.CrossEntropyLoss(ignore_index=0) — confirm against the vocab.
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(inputs, targets)
        # reshape (unlike view) also works if the model returns a
        # non-contiguous tensor; flattens to (N * max_len, vocab_size).
        loss = criterion(output.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

    torch.save(model.state_dict(), save_path)
    print(f"Model saved as {save_path}")
|
|
|
|
if __name__ == "__main__":
    # Executing this file directly trains the model and saves its weights;
    # merely importing the module does not start training.
    train_model()
|