Ruby/training/train.py
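"""Evolutionary GAN-style training loop.

A population of TransformerGenerator models is scored against a single
Discriminator each generation; the discriminator is trained with
BCEWithLogitsLoss on real vs. generated token sequences, and the
generator population is bred by evolution.ga.evolve. The fittest
generator's weights are saved to models/best_gen.pt.
"""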

import glob
import os

import torch
import torch.nn as nn
import torch.optim as optim

from evolution.ga import evolve
from models.transformer import TransformerGenerator
from models.discriminator import Discriminator
from utils.tokenizer import HybridTokenizer


def chunked(lst, size):
    """Yield successive chunks from a list."""
    for i in range(0, len(lst), size):
        yield lst[i:i + size]
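
# Example: list(chunked([1, 2, 3, 4], 3)) yields [[1, 2, 3], [4]]; the
# final chunk may be shorter than `size`.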


def train():
    vocab_file = os.path.join('vocab', 'vocab.json')
    tokenizer = HybridTokenizer(vocab_file)

    # Read every book in data/books into memory.
    book_paths = glob.glob(os.path.join('data', 'books', '*.txt'))
    texts = []
    for path in book_paths:
        with open(path, 'r', encoding='utf-8') as f:
            texts.append(f.read())

    # Build the vocabulary if the tokenizer did not load one from disk.
    if not tokenizer.char_to_id:
        tokenizer.build_vocab(texts)

    # Slice each book into fixed-length, non-overlapping token windows.
    seq_len = 128
    sequences = []
    for text in texts:
        token_ids = tokenizer.encode(text)
        for i in range(0, len(token_ids) - seq_len, seq_len):
            sequences.append(
                torch.tensor(token_ids[i:i + seq_len], dtype=torch.long)
            )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Evolutionary hyperparameters: a small generator population evolved
    # for a fixed number of generations against one shared discriminator.
    pop_size, generations = 10, 50
    vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
    embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4

    population = [
        TransformerGenerator(
            vocab_size, embed_dim, num_heads, mlp_dim, num_layers, seq_len
        ).to(device)
        for _ in range(pop_size)
    ]
    discriminator = Discriminator(vocab_size, embed_dim).to(device)
    disc_opt = optim.Adam(discriminator.parameters(), lr=1e-4)
    bce = nn.BCEWithLogitsLoss()
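
    # Each generation: (1) score every generator against the current
    # discriminator, (2) train the discriminator on real vs. generated
    # batches, (3) breed the next population from the fitness scores.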
    for gen_idx in range(generations):
        # Fitness = how strongly the discriminator rates a generator's
        # sample as "real" (negated BCE against the all-ones target).
        # No gradients are needed here, so score under no_grad.
        fitnesses = []
        with torch.no_grad():
            for g in population:
                inp = torch.randint(0, vocab_size, (1, seq_len), device=device)
                out = g(inp).argmax(-1)
                score = discriminator(out)
                fitnesses.append(-bce(score, torch.ones_like(score)).item())
        # Train the discriminator: real corpus batches vs. samples drawn
        # from one population member (index 0) as the fake source.
        for batch in chunked(sequences, 8):
            real = torch.stack(batch).to(device)
            fake_in = torch.randint(0, vocab_size, real.shape, device=device)
            with torch.no_grad():
                fake = population[0](fake_in).argmax(-1)
            disc_opt.zero_grad()
            loss_r = bce(
                discriminator(real),
                torch.ones(real.size(0), 1, device=device)
            )
            loss_f = bce(
                discriminator(fake),
                torch.zeros(fake.size(0), 1, device=device)
            )
            # Average the real/fake losses and update the discriminator.
            ((loss_r + loss_f) / 2).backward()
            disc_opt.step()
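
        # evolve() is assumed (from its use here) to consume per-model
        # fitness scores and return a same-size next generation, e.g. via
        # selection plus crossover/mutation of generator weights.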
        # Pick the fittest generator *before* evolve() replaces the
        # population, so the saved checkpoint matches its fitness score.
        best = population[fitnesses.index(max(fitnesses))]
        population = evolve(population, fitnesses)
        print(f'Gen {gen_idx:03d}: best fitness = {max(fitnesses):.4f}')
    os.makedirs('models', exist_ok=True)
    torch.save(best.state_dict(), 'models/best_gen.pt')


if __name__ == '__main__':
    train()
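
# Run from the repository root so evolution/, models/, utils/,
# data/books/ and vocab/vocab.json resolve (layout assumed from the
# imports and paths above), e.g.:
#   python -m training.train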