Goofed up what file was being edited

This commit is contained in:
Dani 2025-05-05 17:40:07 -04:00
parent 23800bf323
commit ddc8db5aa4
2 changed files with 21 additions and 136 deletions

106
main.py
View File

@ -1,98 +1,17 @@
import os import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import discord import discord
from dotenv import load_dotenv from dotenv import load_dotenv
from models.transformer import TransformerGenerator from ruby_engine import RubyEngine
from utils.tokenizer import HybridTokenizer
# ──────── Setup ────────
load_dotenv() load_dotenv()
TOKEN = os.getenv("DISCORD_TOKEN") TOKEN = os.getenv("DISCORD_TOKEN")
if not TOKEN: if not TOKEN:
raise RuntimeError("Missing DISCORD_TOKEN in .env") raise RuntimeError("DISCORD_TOKEN missing in .env")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # instantiate your “Ruby” engine
# print(f"[INFO] Using device: {device}") ruby = RubyEngine() # uses GPU if available
# ──────── Tokenizer & Vocab ────────
vocab_file = os.path.join("vocab", "vocab.json")
tokenizer = HybridTokenizer(vocab_file)

# An empty character table is treated as "no vocabulary yet": if vocab.json
# doesn't exist yet, build the vocabulary from the book corpus on disk.
if not tokenizer.char_to_id:
    book_paths = glob.glob(os.path.join("data", "books", "*.txt"))
    texts = []
    for path in book_paths:
        with open(path, "r", encoding="utf-8") as f:
            book_text = f.read()
        texts.append(book_text)
    tokenizer.build_vocab(texts)

    # Report the size of both tables that make up the hybrid vocabulary.
    word_count = len(tokenizer.word_to_id)
    char_count = len(tokenizer.char_to_id)
    print(f"[INFO] Built vocab ({word_count} words + {char_count} chars)")
# ──────── Model Setup ────────
# The model's vocabulary covers the word table and the character table.
vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)

# Architecture hyper-parameters.
embed_dim = 256
num_heads = 8
mlp_dim = 512
num_layers = 4
max_seq_len = 128

model = TransformerGenerator(
    vocab_size,
    embed_dim,
    num_heads,
    mlp_dim,
    num_layers,
    max_seq_len,
)
model = model.to(device)

# Resume from the saved generator weights when a checkpoint file is present.
ckpt = os.path.join("models", "best_gen.pt")
if not os.path.isfile(ckpt):
    print("[INFO] No checkpoint found; starting from random weights")
else:
    state = torch.load(ckpt, map_location=device)
    model.load_state_dict(state)
    print("[INFO] Loaded checkpoint models/best_gen.pt")

# Inference is the default mode; the online trainer toggles train() as needed.
model.eval()
# ──────── Online Trainer ────────
class OnlineTrainer:
    """Fine-tune the generator on each new exchange.

    Every call to ``train_example`` performs one optimizer step on a single
    text sample and then writes the updated weights to the checkpoint path.
    """

    def __init__(self, model, lr=1e-5):
        """Create the optimizer and loss used for online updates.

        Args:
            model: The generator whose parameters are updated in place.
            lr: Learning rate handed to the Adam optimizer.
        """
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        # Reuse the module-level device so inputs land where the model lives.
        self.device = device

    def train_example(self, text: str):
        """Run one next-token-prediction update on ``text`` and save weights.

        The text is encoded with the module-level ``tokenizer``; inputs are
        all tokens but the last and targets are all tokens but the first.
        Texts that encode to fewer than two tokens are skipped because they
        yield no (input, target) pair.

        Args:
            text: Raw text to learn from.
        """
        # simple causal training: predict each next token in `text`
        token_ids = tokenizer.encode(text)
        if len(token_ids) < 2:
            return

        inp = torch.tensor([token_ids[:-1]], device=self.device)
        tgt = torch.tensor([token_ids[1:]], device=self.device)

        self.model.train()
        try:
            self.optimizer.zero_grad()
            logits = self.model(inp)  # (1, seq_len-1, vocab_size)
            loss = self.criterion(
                logits.view(-1, logits.size(-1)),
                tgt.view(-1),
            )
            loss.backward()
            self.optimizer.step()
        finally:
            # The same model instance serves replies, so always put it back
            # in eval mode, even when the update step raises.
            self.model.eval()

        # persist updated weights
        os.makedirs("models", exist_ok=True)
        torch.save(self.model.state_dict(), ckpt)
trainer = OnlineTrainer(model)
# ──────── Discord Client ────────
intents = discord.Intents.default() intents = discord.Intents.default()
intents.message_content = True intents.message_content = True
@ -106,27 +25,16 @@ async def on_ready():
@client.event @client.event
async def on_message(message): async def on_message(message):
# ignore Ruby's own messages
if message.author == client.user: if message.author == client.user:
return return
content = message.content.strip() content = message.content.strip()
if not content: if not content:
return return
# → Generate Rubys reply # generate + train in one call
ids = tokenizer.encode(content) reply = ruby.generate(content)
inp = torch.tensor([ids], dtype=torch.long, device=device)
with torch.no_grad():
out_ids = model(inp).argmax(-1).squeeze().cpu().tolist()
reply = tokenizer.decode(out_ids)
await message.channel.send(reply) await message.channel.send(reply)
ruby.train_on(f"User: {content}\nRuby: {reply}")
# → Optionally train on this new example
sample = f"User: {content}\nRuby: {reply}"
trainer.train_example(sample)
# ──────── Run ────────
client.run(TOKEN) client.run(TOKEN)

View File

@ -1,40 +1,17 @@
import os import torch
import torch.nn as nn
import discord
from dotenv import load_dotenv
from ruby_heart import RubyHeart
load_dotenv()
TOKEN = os.getenv("DISCORD_TOKEN")
if not TOKEN:
raise RuntimeError("DISCORD_TOKEN missing in .env")
# instantiate your “Ruby” engine
ruby = RubyHeart() # uses GPU if available
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)
@client.event class Discriminator(nn.Module):
async def on_ready(): def __init__(self, vocab_size: int, embed_dim: int):
print(f"Ruby is online as {client.user}") super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(embed_dim, embed_dim, batch_first=True)
self.fc = nn.Linear(embed_dim, 1)
def forward(self, x):
@client.event # x: (batch, seq_len)
async def on_message(message): emb = self.embedding(x)
if message.author == client.user: _, (h_n, _) = self.lstm(emb)
return # h_n[-1]: (batch, embed_dim)
content = message.content.strip() return self.fc(h_n[-1])
if not content:
return
# generate + train in one call
reply = ruby.generate(content)
await message.channel.send(reply)
ruby.train_on(f"User: {content}\nRuby: {reply}")
client.run(TOKEN)