"""Discord bot "Ruby": replies with a transformer generator and learns online.

On import this script loads the Discord token, builds (or loads) the
tokenizer vocabulary, restores the generator checkpoint if one exists, and
then starts the Discord client.  Every incoming message is answered with the
model's greedy per-position prediction and the exchange is immediately used
for one fine-tuning step.
"""
import glob
import os

import discord
import torch
import torch.nn as nn
import torch.optim as optim
from dotenv import load_dotenv

from models.transformer import TransformerGenerator
from utils.tokenizer import HybridTokenizer

# ──────── Setup ────────
load_dotenv()
TOKEN = os.getenv("DISCORD_TOKEN")
if not TOKEN:
    raise RuntimeError("Missing DISCORD_TOKEN in .env")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"[INFO] Using device: {device}")

# Discord rejects messages whose content is empty or longer than this.
DISCORD_MAX_MESSAGE_LEN = 2000

# ──────── Tokenizer & Vocab ────────
vocab_file = os.path.join("vocab", "vocab.json")
tokenizer = HybridTokenizer(vocab_file)

# If vocab.json doesn't exist yet, build it from your books:
if not tokenizer.char_to_id:
    book_paths = glob.glob(os.path.join("data", "books", "*.txt"))
    texts = []
    for path in book_paths:
        with open(path, "r", encoding="utf-8") as f:
            texts.append(f.read())
    tokenizer.build_vocab(texts)
    print(f"[INFO] Built vocab ({len(tokenizer.word_to_id)} words + "
          f"{len(tokenizer.char_to_id)} chars)")

# ──────── Model Setup ────────
vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4
max_seq_len = 128

model = TransformerGenerator(
    vocab_size, embed_dim, num_heads, mlp_dim, num_layers, max_seq_len
).to(device)

ckpt = os.path.join("models", "best_gen.pt")
if os.path.isfile(ckpt):
    state = torch.load(ckpt, map_location=device)
    model.load_state_dict(state)
    print("[INFO] Loaded checkpoint models/best_gen.pt")
else:
    print("[INFO] No checkpoint found; starting from random weights")

model.eval()


# ──────── Online Trainer ────────
class OnlineTrainer:
    """Fine-tune the generator on each new exchange.

    Uses the module-level ``tokenizer``, ``device``, ``max_seq_len`` and
    ``ckpt``; the checkpoint file is rewritten after every training step.
    """

    def __init__(self, model, lr=1e-5):
        """Create an Adam optimizer over ``model`` with learning rate ``lr``."""
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        self.device = device

    def train_example(self, text: str) -> None:
        """Run one next-token-prediction step on ``text`` and save weights.

        Texts that encode to fewer than two tokens are ignored because there
        is no (input, target) pair to learn from.
        """
        token_ids = tokenizer.encode(text)
        # Keep only the most recent tokens: max_seq_len inputs plus the one
        # extra id needed for the shifted target.  The model was constructed
        # with max_seq_len, so longer inputs are presumably unsupported.
        token_ids = token_ids[-(max_seq_len + 1):]
        if len(token_ids) < 2:
            return

        inp = torch.tensor(
            [token_ids[:-1]], dtype=torch.long, device=self.device
        )
        tgt = torch.tensor(
            [token_ids[1:]], dtype=torch.long, device=self.device
        )

        self.model.train()
        try:
            self.optimizer.zero_grad()
            # Expected shape (per the original author's note):
            # (1, seq_len-1, vocab_size)
            logits = self.model(inp)
            loss = self.criterion(
                logits.reshape(-1, logits.size(-1)),
                tgt.reshape(-1),
            )
            loss.backward()
            self.optimizer.step()
        finally:
            # Always return to eval mode, even if the step raised, so later
            # replies are not generated with the model left in train mode.
            self.model.eval()

        self._save_checkpoint()

    def _save_checkpoint(self) -> None:
        """Persist the current weights to ``ckpt`` without partial writes."""
        os.makedirs(os.path.dirname(ckpt), exist_ok=True)
        # Write to a sibling temp file and swap it in, so an interrupted
        # save never leaves a truncated checkpoint behind.
        tmp_path = ckpt + ".tmp"
        torch.save(self.model.state_dict(), tmp_path)
        os.replace(tmp_path, ckpt)


trainer = OnlineTrainer(model)

# ──────── Discord Client ────────
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)


@client.event
async def on_ready() -> None:
    """Announce on stdout that the client has connected."""
    print(f"Ruby is online as {client.user}")


@client.event
async def on_message(message) -> None:
    """Reply to a user message and fine-tune on the resulting exchange."""
    # ignore Ruby's own messages
    if message.author == client.user:
        return

    content = message.content.strip()
    if not content:
        return

    # → Generate Ruby's reply
    # Clip to the model's context window, keeping the end of the message.
    ids = tokenizer.encode(content)[-max_seq_len:]
    if not ids:
        return

    inp = torch.tensor([ids], dtype=torch.long, device=device)
    with torch.no_grad():
        # Greedy pick at every input position.  squeeze(0) drops only the
        # batch axis, so a one-token message still yields a list, not an int.
        out_ids = model(inp).argmax(-1).squeeze(0).cpu().tolist()
    reply = tokenizer.decode(out_ids)

    # Discord refuses empty messages; nothing to send or learn from.
    if not reply.strip():
        return
    reply = reply[:DISCORD_MAX_MESSAGE_LEN]

    await message.channel.send(reply)

    # → Optionally train on this new example
    # NOTE(review): this runs synchronously and blocks the event loop while
    # the optimizer step and checkpoint write complete.
    sample = f"User: {content}\nRuby: {reply}"
    trainer.train_example(sample)


# ──────── Run ────────
client.run(TOKEN)