import glob
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from models.discriminator import Discriminator
from models.transformer import TransformerGenerator
from utils.tokenizer import HybridTokenizer


class RubyHeart:
    """Wraps a TransformerGenerator with tokenization, sampling, and online training."""

    def __init__(
        self,
        books_dir="data/books",
        vocab_file="vocab/vocab.json",
        model_file="models/best_gen.pt",
        device=None,
    ):
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Tokenizer & vocab: build the vocabulary from the books corpus on first run.
        self.tokenizer = HybridTokenizer(vocab_file)
        if not self.tokenizer.char_to_id:
            self._build_vocab(books_dir)

        # Model init: the vocabulary spans both word-level and char-level tokens.
        vocab_size = len(self.tokenizer.word_to_id) + len(self.tokenizer.char_to_id)
        self.max_seq_len = 128
        self.model = TransformerGenerator(
            vocab_size=vocab_size,
            embed_dim=256,
            num_heads=8,
            mlp_dim=512,
            num_layers=4,
            max_seq_len=self.max_seq_len,
        ).to(self.device)

        self.model_file = model_file
        self._load_checkpoint(model_file)

        # Online trainer for incremental updates via train_on().
        self.trainer = self._make_trainer()

    def _build_vocab(self, books_dir):
        paths = glob.glob(os.path.join(books_dir, "*.txt"))
        texts = []
        for path in paths:
            with open(path, encoding="utf-8") as fh:
                texts.append(fh.read())
        self.tokenizer.build_vocab(texts)

    def _load_checkpoint(self, path):
        if os.path.isfile(path):
            state = torch.load(path, map_location=self.device, weights_only=True)
            self.model.load_state_dict(state)
        # No checkpoint found: start from randomly initialized weights.

    def _make_trainer(self, lr=1e-5):
        opt = optim.Adam(self.model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        return {"opt": opt, "loss": loss_fn}

    @staticmethod
    def _top_k_top_p(logits, top_k=50, top_p=0.9):
        """Filter a (batch, vocab) logits tensor with top-k and nucleus (top-p) sampling."""
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            # Mask out everything below the k-th largest logit.
            kth = torch.topk(logits, top_k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        if top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cum_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            # Drop tokens whose cumulative probability exceeds top_p,
            # but always keep the single most likely token.
            sorted_mask = cum_probs > top_p
            sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
            sorted_mask[..., 0] = False
            # Scatter the mask back to the original (unsorted) vocabulary order.
            mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
            logits = logits.masked_fill(mask, float("-inf"))
        return logits

    def generate(self, prompt, max_len=64, temp=1.0, top_k=50, top_p=0.9):
        self.model.eval()
        ids = self.tokenizer.encode(prompt)
        input_ids = torch.tensor([ids], device=self.device)
        with torch.no_grad():
            for _ in range(max_len):
                # Feed at most max_seq_len tokens so positional embeddings stay in range.
                context = input_ids[:, -self.max_seq_len:]
                logits = self.model(context)[:, -1, :] / temp
                filtered = self._top_k_top_p(logits, top_k, top_p)
                probs = F.softmax(filtered, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_id], dim=-1)
        return self.tokenizer.decode(input_ids[0].cpu().tolist())

    def train_on(self, text):
        ids = self.tokenizer.encode(text)
        if len(ids) < 2:
            return
        # Truncate to the model's context window; shift by one for next-token targets.
        ids = ids[: self.max_seq_len + 1]
        inp = torch.tensor([ids[:-1]], device=self.device)
        tgt = torch.tensor([ids[1:]], device=self.device)

        self.model.train()
        self.trainer["opt"].zero_grad()
        logits = self.model(inp)
        loss = self.trainer["loss"](
            logits.view(-1, logits.size(-1)),
            tgt.view(-1),
        )
        loss.backward()
        self.trainer["opt"].step()

        # Persist the updated weights after every online step.
        torch.save(self.model.state_dict(), self.model_file)
        self.model.eval()
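

# Minimal usage sketch (an illustration, not part of the module's API): exercises
# generate() and train_on() with the default paths above. Assumes the project's
# data/books, vocab/, and models/ directories exist; the prompt strings are
# purely hypothetical examples.
if __name__ == "__main__":
    heart = RubyHeart()

    # Sample a short continuation from a prompt.
    print(heart.generate("Once upon a midnight", max_len=32, temp=0.8))

    # One online training step on a snippet of new text, then sample again.
    heart.train_on("She walked into the library and the lights flickered on.")
    print(heart.generate("She walked into", max_len=32))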