import glob
import os

import torch
import torch.nn as nn
import torch.optim as optim

from evolution.ga import evolve
from models.transformer import TransformerGenerator
from models.discriminator import Discriminator
from utils.tokenizer import HybridTokenizer


def chunked(lst, size):
    """Yield successive ``size``-length chunks from *lst* (last may be shorter)."""
    for i in range(0, len(lst), size):
        yield lst[i:i + size]


def train():
    """Co-evolve a population of transformer generators against a discriminator.

    Loads raw book text from ``data/books/*.txt``, tokenizes it into
    fixed-length sequences, then for each generation:

    1. scores every generator by how well its output fools the discriminator
       (fitness = negative BCE loss against the "real" label),
    2. trains the discriminator on real vs. generated batches,
    3. evolves the population via ``evolve``.

    The best generator's weights are saved to ``models/best_gen.pt``.
    """
    vocab_file = os.path.join('vocab', 'vocab.json')
    tokenizer = HybridTokenizer(vocab_file)

    book_paths = glob.glob(os.path.join('data', 'books', '*.txt'))
    texts = []
    for path in book_paths:
        with open(path, 'r', encoding='utf-8') as f:
            texts.append(f.read())

    # Build the vocabulary only when one wasn't already loaded from disk.
    if not tokenizer.char_to_id:
        tokenizer.build_vocab(texts)

    # Slice each book into non-overlapping training sequences of seq_len tokens.
    seq_len = 128
    sequences = []
    for text in texts:
        token_ids = tokenizer.encode(text)
        for i in range(0, len(token_ids) - seq_len, seq_len):
            sequences.append(
                torch.tensor(token_ids[i:i + seq_len], dtype=torch.long)
            )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pop_size, generations = 10, 50
    vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
    embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4

    population = [
        TransformerGenerator(
            vocab_size, embed_dim, num_heads, mlp_dim, num_layers, seq_len
        ).to(device)
        for _ in range(pop_size)
    ]
    discriminator = Discriminator(vocab_size, embed_dim).to(device)
    disc_opt = optim.Adam(discriminator.parameters(), lr=1e-4)
    bce = nn.BCEWithLogitsLoss()

    best_state = None  # snapshot of the best generator seen in the last generation
    for gen_idx in range(generations):
        # --- Evaluate fitness: lower BCE against the "real" label = better fool.
        # No gradients are needed for scoring, so skip graph construction.
        fitnesses = []
        with torch.no_grad():
            for g in population:
                inp = torch.randint(0, vocab_size, (1, seq_len), device=device)
                out = g(inp).argmax(-1)
                score = discriminator(out)
                fitnesses.append(-bce(score, torch.ones_like(score)).item())

        # BUG FIX: fakes were always drawn from population[0], an arbitrary
        # member of the evolving list; train against the current best scorer.
        best_gen = population[fitnesses.index(max(fitnesses))]

        # --- Train the discriminator on real vs. generated batches.
        for batch in chunked(sequences, 8):
            real = torch.stack(batch).to(device)
            # The generator is not trained here, so generate without autograd
            # (replaces the original post-hoc .detach()).
            with torch.no_grad():
                fake_in = torch.randint(
                    0, vocab_size, real.shape, device=device
                )
                fake = best_gen(fake_in).argmax(-1)

            disc_opt.zero_grad()
            loss_r = bce(
                discriminator(real),
                torch.ones(real.size(0), 1, device=device)
            )
            loss_f = bce(
                discriminator(fake),
                torch.zeros(fake.size(0), 1, device=device)
            )
            ((loss_r + loss_f) / 2).backward()
            disc_opt.step()

        # BUG FIX: the original saved population[argmax(fitnesses)] AFTER
        # evolve() replaced the population, so the pre-evolution index pointed
        # at an arbitrary new individual. Snapshot the best model first.
        best_state = best_gen.state_dict()

        # --- Evolve the population for the next generation.
        population = evolve(population, fitnesses)
        print(f'Gen {gen_idx:03d}: best fitness = {max(fitnesses):.4f}')

    os.makedirs('models', exist_ok=True)
    torch.save(best_state, 'models/best_gen.pt')


# BUG FIX: guard the entry point so importing this module does not
# immediately launch a full training run.
if __name__ == '__main__':
    train()