Ruby/main.py
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import discord
from dotenv import load_dotenv
from models.transformer import TransformerGenerator
from utils.tokenizer import HybridTokenizer
# ──────── Setup ────────
load_dotenv()
TOKEN = os.getenv("DISCORD_TOKEN")
if not TOKEN:
    raise RuntimeError("Missing DISCORD_TOKEN in .env")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"[INFO] Using device: {device}")
# ──────── Tokenizer & Vocab ────────
vocab_file = os.path.join("vocab", "vocab.json")
tokenizer = HybridTokenizer(vocab_file)
# If vocab.json doesn't exist yet, build it from your books:
if not tokenizer.char_to_id:
    book_paths = glob.glob(os.path.join("data", "books", "*.txt"))
    texts = []
    for path in book_paths:
        with open(path, "r", encoding="utf-8") as f:
            texts.append(f.read())
    tokenizer.build_vocab(texts)
    print(f"[INFO] Built vocab ({len(tokenizer.word_to_id)} words + "
          f"{len(tokenizer.char_to_id)} chars)")
# ──────── Model Setup ────────
vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4
max_seq_len = 128
model = TransformerGenerator(
    vocab_size, embed_dim, num_heads, mlp_dim, num_layers, max_seq_len
).to(device)
ckpt = os.path.join("models", "best_gen.pt")
if os.path.isfile(ckpt):
    state = torch.load(ckpt, map_location=device)
    model.load_state_dict(state)
    print("[INFO] Loaded checkpoint models/best_gen.pt")
else:
    print("[INFO] No checkpoint found; starting from random weights")
model.eval()
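
# Note: the Discord handler below decodes with one forward pass plus argmax,
# which re-predicts every input position at once rather than generating new
# tokens. A conventional autoregressive loop would look like this hypothetical
# sketch (it assumes the model accepts any (1, seq_len <= max_seq_len)
# LongTensor, as the training shapes in OnlineTrainer imply; the handler below
# does not use it):
def generate_greedy(prompt_ids, max_new_tokens=32):
    ids = list(prompt_ids)[:max_seq_len]
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # feed at most the last max_seq_len tokens back in
            window = torch.tensor([ids[-max_seq_len:]], dtype=torch.long, device=device)
            # greedy pick: most likely token at the final position
            next_id = model(window)[0, -1].argmax(-1).item()
            ids.append(next_id)
    return ids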
# ──────── Online Trainer ────────
class OnlineTrainer:
    """Fine-tune the generator on each new exchange."""

    def __init__(self, model, lr=1e-5):
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        self.device = device

    def train_example(self, text: str):
        # simple causal training: predict each next token in `text`
        token_ids = tokenizer.encode(text)[:max_seq_len]  # keep within the model's positional range
        if len(token_ids) < 2:
            return
        inp = torch.tensor([token_ids[:-1]], device=self.device)
        tgt = torch.tensor([token_ids[1:]], device=self.device)
        self.model.train()
        self.optimizer.zero_grad()
        logits = self.model(inp)  # (1, seq_len-1, vocab_size)
        loss = self.criterion(
            logits.view(-1, logits.size(-1)),
            tgt.view(-1)
        )
        loss.backward()
        self.optimizer.step()
        self.model.eval()
        # persist updated weights
        os.makedirs("models", exist_ok=True)
        torch.save(self.model.state_dict(), ckpt)
trainer = OnlineTrainer(model)
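# Each exchange is fed back through trainer.train_example() below, e.g.
# trainer.train_example("User: hello\nRuby: hi there"): one small gradient
# step per message, after which the weights are re-saved to models/best_gen.pt.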
# ──────── Discord Client ────────
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)
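# message_content is a privileged intent: it must also be enabled on the bot's
# page in the Discord developer portal, or message.content arrives empty.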
@client.event
async def on_ready():
    print(f"Ruby is online as {client.user}")
@client.event
async def on_message(message):
    # ignore Ruby's own messages
    if message.author == client.user:
        return
    content = message.content.strip()
    if not content:
        return
    # → Generate Ruby's reply
    ids = tokenizer.encode(content)[:max_seq_len]  # keep within the model's positional range
    inp = torch.tensor([ids], dtype=torch.long, device=device)
    with torch.no_grad():
        # single forward pass; argmax over the vocab at every position
        out_ids = model(inp).argmax(-1).squeeze(0).cpu().tolist()
    reply = tokenizer.decode(out_ids)
    await message.channel.send(reply)
    # → Optionally train on this new example
    sample = f"User: {content}\nRuby: {reply}"
    trainer.train_example(sample)
# ──────── Run ────────
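# client.run() is blocking: it starts the event loop, logs in with the token
# loaded from .env above, and keeps the bot connected until the process exits.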
client.run(TOKEN)