import os
import glob

import torch
import torch.nn as nn
import torch.optim as optim
import discord
from dotenv import load_dotenv

from models.transformer import TransformerGenerator
from utils.tokenizer import HybridTokenizer

# ──────── Setup ────────

load_dotenv()
TOKEN = os.getenv("DISCORD_TOKEN")
if not TOKEN:
    raise RuntimeError("Missing DISCORD_TOKEN in .env")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"[INFO] Using device: {device}")

# ──────── Tokenizer & Vocab ────────
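# HybridTokenizer is assumed to load vocab.json on construction and to mix
# word-level and character-level ids; an empty char_to_id map is taken to
# mean no vocab has been built yet.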

vocab_file = os.path.join("vocab", "vocab.json")
tokenizer = HybridTokenizer(vocab_file)

# If vocab.json doesn’t exist yet, build it from your books:
if not tokenizer.char_to_id:
    book_paths = glob.glob(os.path.join("data", "books", "*.txt"))
    texts = []
    for path in book_paths:
        with open(path, "r", encoding="utf-8") as f:
            texts.append(f.read())
    tokenizer.build_vocab(texts)
    print(f"[INFO] Built vocab ({len(tokenizer.word_to_id)} words + "
          f"{len(tokenizer.char_to_id)} chars)")

# ──────── Model Setup ────────
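# 4 Transformer layers, 8 attention heads, 256-dim embeddings, 512-dim MLP,
# 128-token context window (values set below).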

vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4
max_seq_len = 128

model = TransformerGenerator(
    vocab_size, embed_dim, num_heads, mlp_dim, num_layers, max_seq_len
).to(device)

ckpt = os.path.join("models", "best_gen.pt")
if os.path.isfile(ckpt):
    state = torch.load(ckpt, map_location=device)
    model.load_state_dict(state)
    print("[INFO] Loaded checkpoint models/best_gen.pt")
else:
    print("[INFO] No checkpoint found; starting from random weights")

model.eval()

# ──────── Online Trainer ────────
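# Every exchange becomes one next-token-prediction example, applied as a single
# low-learning-rate gradient step; the weights are checkpointed after each step.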

class OnlineTrainer:
    """Fine-tune the generator on each new exchange."""

    def __init__(self, model, lr=1e-5):
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        self.device = device

    def train_example(self, text: str):
        # simple causal training: predict each next token in `text`
        token_ids = tokenizer.encode(text)
        if len(token_ids) < 2:
            return
        # clip to the context window; assumes the model's positional
        # embeddings only cover max_seq_len positions
        token_ids = token_ids[: max_seq_len + 1]
        inp = torch.tensor([token_ids[:-1]], device=self.device)
        tgt = torch.tensor([token_ids[1:]], device=self.device)

        self.model.train()
        self.optimizer.zero_grad()
        logits = self.model(inp)  # (1, seq_len-1, vocab_size)
        loss = self.criterion(
            logits.view(-1, logits.size(-1)),
            tgt.view(-1),
        )
        loss.backward()
        self.optimizer.step()
        self.model.eval()

        # persist updated weights
        os.makedirs("models", exist_ok=True)
        torch.save(self.model.state_dict(), ckpt)


trainer = OnlineTrainer(model)

# ──────── Discord Client ────────
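# message_content is a privileged intent: it must also be enabled for this bot
# under "Privileged Gateway Intents" in the Discord Developer Portal.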

intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)


@client.event
async def on_ready():
    print(f"Ruby is online as {client.user}")


@client.event
async def on_message(message):
    # ignore Ruby’s own messages
    if message.author == client.user:
        return

    content = message.content.strip()
    if not content:
        return

    # → Generate Ruby’s reply. Note this is a single forward pass: greedy
    # argmax yields one output token per input position, not autoregressive
    # token-by-token generation.
    ids = tokenizer.encode(content)[-max_seq_len:]  # stay within the context window
    inp = torch.tensor([ids], dtype=torch.long, device=device)
    with torch.no_grad():
        # squeeze(0) drops only the batch dim, so a one-token reply still
        # decodes from a list rather than a bare int
        out_ids = model(inp).argmax(-1).squeeze(0).cpu().tolist()
    reply = tokenizer.decode(out_ids)
    if not reply.strip():
        return  # Discord rejects empty messages

    await message.channel.send(reply)

    # → Optionally train on this new example
    sample = f"User: {content}\nRuby: {reply}"
    trainer.train_example(sample)

# ──────── Run ────────
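# client.run() is blocking: it starts the event loop and handles reconnects.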

client.run(TOKEN)