diff --git a/main.py b/main.py
index f600f22..077b2cf 100644
--- a/main.py
+++ b/main.py
@@ -1,17 +1,98 @@
 import os
+import glob
+import torch
+import torch.nn as nn
+import torch.optim as optim
 import discord
 from dotenv import load_dotenv
-from ruby_engine import RubyEngine
+from models.transformer import TransformerGenerator
+from utils.tokenizer import HybridTokenizer
+
+# ──────── Setup ────────
 
 
 load_dotenv()
 TOKEN = os.getenv("DISCORD_TOKEN")
 if not TOKEN:
-    raise RuntimeError("DISCORD_TOKEN missing in .env")
+    raise RuntimeError("Missing DISCORD_TOKEN in .env")
 
-# instantiate your “Ruby” engine
-ruby = RubyEngine()  # uses GPU if available
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# print(f"[INFO] Using device: {device}")
+
+# ──────── Tokenizer & Vocab ────────
+
+vocab_file = os.path.join("vocab", "vocab.json")
+tokenizer = HybridTokenizer(vocab_file)
+
+# If vocab.json doesn’t exist yet, build it from your books:
+if not tokenizer.char_to_id:
+    book_paths = glob.glob(os.path.join("data", "books", "*.txt"))
+    texts = []
+    for path in book_paths:
+        with open(path, "r", encoding="utf-8") as f:
+            texts.append(f.read())
+    tokenizer.build_vocab(texts)
+    print(f"[INFO] Built vocab ({len(tokenizer.word_to_id)} words + "
+          f"{len(tokenizer.char_to_id)} chars)")
+
+# ──────── Model Setup ────────
+
+vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
+embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4
+max_seq_len = 128
+
+model = TransformerGenerator(
+    vocab_size, embed_dim, num_heads, mlp_dim, num_layers, max_seq_len
+).to(device)
+
+ckpt = os.path.join("models", "best_gen.pt")
+if os.path.isfile(ckpt):
+    state = torch.load(ckpt, map_location=device)
+    model.load_state_dict(state)
+    print("[INFO] Loaded checkpoint models/best_gen.pt")
+else:
+    print("[INFO] No checkpoint found; starting from random weights")
+
+model.eval()
+
+# ──────── Online Trainer ────────
+
+class OnlineTrainer:
+    """Fine-tune the generator on each new exchange."""
+
+    def __init__(self, model, lr=1e-5):
+        self.model = model
+        self.optimizer = optim.Adam(model.parameters(), lr=lr)
+        self.criterion = nn.CrossEntropyLoss()
+        self.device = device
+
+    def train_example(self, text: str):
+        # simple causal training: predict each next token in `text`
+        token_ids = tokenizer.encode(text)
+        if len(token_ids) < 2:
+            return
+        inp = torch.tensor([token_ids[:-1]], device=self.device)
+        tgt = torch.tensor([token_ids[1:]], device=self.device)
+
+        self.model.train()
+        self.optimizer.zero_grad()
+        logits = self.model(inp)  # (1, seq_len-1, vocab_size)
+        loss = self.criterion(
+            logits.view(-1, logits.size(-1)),
+            tgt.view(-1)
+        )
+        loss.backward()
+        self.optimizer.step()
+        self.model.eval()
+
+        # persist updated weights
+        os.makedirs("models", exist_ok=True)
+        torch.save(self.model.state_dict(), ckpt)
+
+trainer = OnlineTrainer(model)
+
+# ──────── Discord Client ────────
 
 
 intents = discord.Intents.default()
 intents.message_content = True
@@ -25,16 +106,27 @@ async def on_ready():
 
 @client.event
 async def on_message(message):
+    # ignore Ruby’s own messages
     if message.author == client.user:
         return
+
     content = message.content.strip()
     if not content:
         return
 
-    # generate + train in one call
-    reply = ruby.generate(content)
-    await message.channel.send(reply)
-    ruby.train_on(f"User: {content}\nRuby: {reply}")
+    # → Generate Ruby’s reply
+    ids = tokenizer.encode(content)
+    inp = torch.tensor([ids], dtype=torch.long, device=device)
+    with torch.no_grad():
+        out_ids = model(inp).argmax(-1).squeeze().cpu().tolist()
+    reply = tokenizer.decode(out_ids)
+    await message.channel.send(reply)
+
+    # → Optionally train on this new example
+    sample = f"User: {content}\nRuby: {reply}"
+    trainer.train_example(sample)
+
+# ──────── Run ────────
 
 
 client.run(TOKEN)
diff --git a/models/discriminator.py b/models/discriminator.py
index 1241754..18de477 100644
--- a/models/discriminator.py
+++ b/models/discriminator.py
@@ -1,17 +1,40 @@
-import torch
-import torch.nn as nn
+import os
+
+import discord
+from dotenv import load_dotenv
+
+from ruby_heart import RubyHeart
+
+load_dotenv()
+TOKEN = os.getenv("DISCORD_TOKEN")
+if not TOKEN:
+    raise RuntimeError("DISCORD_TOKEN missing in .env")
+
+# instantiate your “Ruby” engine
+ruby = RubyHeart()  # uses GPU if available
+
+intents = discord.Intents.default()
+intents.message_content = True
+client = discord.Client(intents=intents)
 
 
-class Discriminator(nn.Module):
-    def __init__(self, vocab_size: int, embed_dim: int):
-        super().__init__()
-        self.embedding = nn.Embedding(vocab_size, embed_dim)
-        self.lstm = nn.LSTM(embed_dim, embed_dim, batch_first=True)
-        self.fc = nn.Linear(embed_dim, 1)
+@client.event
+async def on_ready():
+    print(f"Ruby is online as {client.user}")
 
-    def forward(self, x):
-        # x: (batch, seq_len)
-        emb = self.embedding(x)
-        _, (h_n, _) = self.lstm(emb)
-        # h_n[-1]: (batch, embed_dim)
-        return self.fc(h_n[-1])
+
+@client.event
+async def on_message(message):
+    if message.author == client.user:
+        return
+    content = message.content.strip()
+    if not content:
+        return
+
+    # generate + train in one call
+    reply = ruby.generate(content)
+    await message.channel.send(reply)
+    ruby.train_on(f"User: {content}\nRuby: {reply}")
+
+
+client.run(TOKEN)
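
One failure mode worth guarding against in the checkpoint block in main.py: `vocab_size` is derived from whatever vocab the tokenizer builds out of `data/books/*.txt`, so a `best_gen.pt` saved under an older vocab makes `model.load_state_dict(state)` raise a size-mismatch `RuntimeError` after the vocab changes. A minimal defensive variant, sketched with the diff's own names:

```python
import os
import torch

ckpt = os.path.join("models", "best_gen.pt")
if os.path.isfile(ckpt):
    try:
        state = torch.load(ckpt, map_location=device)
        model.load_state_dict(state)
        print("[INFO] Loaded checkpoint models/best_gen.pt")
    except RuntimeError as err:
        # typically an embedding/output-layer size mismatch after a vocab rebuild
        print(f"[WARN] Incompatible checkpoint, starting from random weights: {err}")
else:
    print("[INFO] No checkpoint found; starting from random weights")
```

Since the diff already persists `vocab/vocab.json`, keeping that file alongside the checkpoint (and not changing the books folder between runs) avoids the mismatch in the first place.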
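A note on the decode step in `on_message`: `model(inp).argmax(-1)` runs a single forward pass and takes the highest-scoring token at every position of the input, so the "reply" is a per-position relabeling of the prompt, not a continuation of it. Assuming `TransformerGenerator.forward(ids)` returns logits shaped `(batch, seq_len, vocab_size)` (the model code is not part of this diff), a greedy autoregressive loop would look roughly like the sketch below. `eos_id` is a hypothetical end-of-sequence id that `HybridTokenizer` may not define, and the 128-token window mirrors `max_seq_len` from the diff.

```python
from typing import Optional

import torch


@torch.no_grad()
def generate_reply(model, tokenizer, prompt: str, device,
                   max_new_tokens: int = 64,
                   eos_id: Optional[int] = None) -> str:
    """Greedy autoregressive decoding, one new token per forward pass.

    Assumes model(ids) -> logits of shape (batch, seq_len, vocab_size);
    eos_id is hypothetical -- pass None if the tokenizer has no EOS token.
    """
    ids = tokenizer.encode(prompt)
    generated = list(ids)
    for _ in range(max_new_tokens):
        # stay inside the model's positional window (max_seq_len = 128)
        window = generated[-128:]
        inp = torch.tensor([window], dtype=torch.long, device=device)
        logits = model(inp)                      # (1, len(window), vocab)
        next_id = logits[0, -1].argmax().item()  # best next token
        if eos_id is not None and next_id == eos_id:
            break
        generated.append(next_id)
    # decode only the newly generated tail as the reply
    return tokenizer.decode(generated[len(ids):])
```

The same windowing also protects the direct `torch.tensor([ids], ...)` call in the diff, which otherwise feeds sequences longer than `max_seq_len` straight into the positional embedding.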
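Separately, `trainer.train_example(sample)` performs a synchronous backward pass plus a `torch.save` inside the async `on_message` handler, which stalls the Discord event loop (and its heartbeats) for the duration. A minimal sketch, assuming Python 3.9+ for `asyncio.to_thread`: hand both generation and the training step to a worker thread so the loop stays responsive. `generate_reply` is the hypothetical helper from the previous sketch; handling concurrent messages safely would additionally need a lock around the optimizer step.

```python
import asyncio


@client.event
async def on_message(message):
    if message.author == client.user:
        return
    content = message.content.strip()
    if not content:
        return

    # run the blocking forward passes off the event loop
    reply = await asyncio.to_thread(
        generate_reply, model, tokenizer, content, device
    )
    await message.channel.send(reply)

    # the optimizer step and torch.save are also blocking; offload them too
    sample = f"User: {content}\nRuby: {reply}"
    await asyncio.to_thread(trainer.train_example, sample)
```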