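"""Evolutionary GAN-style training script.

A population of ``TransformerGenerator`` models is scored by how well their
samples fool a jointly trained ``Discriminator``, then evolved each
generation via ``evolution.ga.evolve``. Text data is read from
``data/books/*.txt`` and tokenized with ``HybridTokenizer``.
"""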
import glob
import os

import torch
import torch.nn as nn
import torch.optim as optim

from evolution.ga import evolve
from models.transformer import TransformerGenerator
from models.discriminator import Discriminator
from utils.tokenizer import HybridTokenizer

def chunked(lst, size):
    """Yield successive chunks from a list."""
    for i in range(0, len(lst), size):
        yield lst[i:i + size]

def train():
    # Load the tokenizer and read the raw training corpus.
    vocab_file = os.path.join('vocab', 'vocab.json')
    tokenizer = HybridTokenizer(vocab_file)

    book_paths = glob.glob(os.path.join('data', 'books', '*.txt'))
    texts = []
    for path in book_paths:
        with open(path, 'r', encoding='utf-8') as f:
            texts.append(f.read())

    # Build the vocabulary from the corpus if the tokenizer loaded an empty one.
    if not tokenizer.char_to_id:
        tokenizer.build_vocab(texts)
    # Cut each encoded book into fixed-length, non-overlapping training sequences.
    seq_len = 128
    sequences = []
    for text in texts:
        token_ids = tokenizer.encode(text)
        for i in range(0, len(token_ids) - seq_len, seq_len):
            sequences.append(
                torch.tensor(token_ids[i:i + seq_len], dtype=torch.long)
            )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Evolution and model hyperparameters.
    pop_size, generations = 10, 50
    vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
    embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4
    # Initial generator population and a single shared discriminator.
    population = [
        TransformerGenerator(
            vocab_size, embed_dim, num_heads, mlp_dim, num_layers, seq_len
        ).to(device)
        for _ in range(pop_size)
    ]
    discriminator = Discriminator(vocab_size, embed_dim).to(device)
    disc_opt = optim.Adam(discriminator.parameters(), lr=1e-4)
    bce = nn.BCEWithLogitsLoss()
    for gen_idx in range(generations):
        # Evaluate fitness: how convincingly a sample from each generator
        # fools the discriminator (negated BCE against the "real" label,
        # so higher is better). No gradients are needed here.
        fitnesses = []
        with torch.no_grad():
            for g in population:
                inp = torch.randint(0, vocab_size, (1, seq_len), device=device)
                out = g(inp).argmax(-1)
                score = discriminator(out)
                fitnesses.append(-bce(score, torch.ones_like(score)).item())
        # Train the discriminator on real corpus sequences vs. samples drawn
        # from the first generator in the population.
        for batch in chunked(sequences, 8):
            real = torch.stack(batch).to(device)
            fake_in = torch.randint(0, vocab_size, real.shape, device=device)
            fake = population[0](fake_in).argmax(-1).detach()

            disc_opt.zero_grad()
            loss_r = bce(
                discriminator(real),
                torch.ones(real.size(0), 1, device=device)
            )
            loss_f = bce(
                discriminator(fake),
                torch.zeros(fake.size(0), 1, device=device)
            )
            ((loss_r + loss_f) / 2).backward()
            disc_opt.step()
        # Keep a handle on the fittest generator before the population is
        # replaced, so the checkpoint at the end matches these fitness scores.
        best = population[fitnesses.index(max(fitnesses))]

        # Evolve the population from this generation's fitness scores.
        population = evolve(population, fitnesses)
        print(f'Gen {gen_idx:03d}: best fitness = {max(fitnesses):.4f}')
    # Save the fittest generator from the final generation. (Indexing the
    # freshly evolved population with the previous generation's fitnesses
    # would pick an arbitrary model.)
    os.makedirs('models', exist_ok=True)
    torch.save(best.state_dict(), 'models/best_gen.pt')

if __name__ == '__main__':
    train()