Revert "Goofed up what file was being edited"
This reverts commit c7a15f63ddc6fd25b2b3bda9f37c05e863ea6285.
parent c7a15f63dd
commit 151ba084f4

main.py: 108 lines changed
@@ -1,17 +1,98 @@
import os
import glob

import torch
import torch.nn as nn
import torch.optim as optim
import discord
from dotenv import load_dotenv

from ruby_engine import RubyEngine
from models.transformer import TransformerGenerator
from utils.tokenizer import HybridTokenizer

# ──────── Setup ────────

load_dotenv()
TOKEN = os.getenv("DISCORD_TOKEN")
if not TOKEN:
    raise RuntimeError("DISCORD_TOKEN missing in .env")
    raise RuntimeError("Missing DISCORD_TOKEN in .env")

# instantiate your “Ruby” engine
ruby = RubyEngine()  # uses GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"[INFO] Using device: {device}")

# ──────── Tokenizer & Vocab ────────

vocab_file = os.path.join("vocab", "vocab.json")
tokenizer = HybridTokenizer(vocab_file)

# If vocab.json doesn’t exist yet, build it from your books:
if not tokenizer.char_to_id:
    book_paths = glob.glob(os.path.join("data", "books", "*.txt"))
    texts = []
    for path in book_paths:
        with open(path, "r", encoding="utf-8") as f:
            texts.append(f.read())
    tokenizer.build_vocab(texts)
    print(f"[INFO] Built vocab ({len(tokenizer.word_to_id)} words + "
          f"{len(tokenizer.char_to_id)} chars)")

# ──────── Model Setup ────────

vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
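# word ids and char ids share one id space, so a single output layer scores
# both (HybridTokenizer presumably falls back to chars for out-of-vocab words)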
embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4
max_seq_len = 128

model = TransformerGenerator(
    vocab_size, embed_dim, num_heads, mlp_dim, num_layers, max_seq_len
).to(device)

ckpt = os.path.join("models", "best_gen.pt")
if os.path.isfile(ckpt):
    state = torch.load(ckpt, map_location=device)
    model.load_state_dict(state)
    print("[INFO] Loaded checkpoint models/best_gen.pt")
else:
    print("[INFO] No checkpoint found; starting from random weights")

model.eval()

# ──────── Online Trainer ────────

class OnlineTrainer:
    """Fine-tune the generator on each new exchange."""

    def __init__(self, model, lr=1e-5):
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        self.device = device

    def train_example(self, text: str):
        # simple causal training: predict each next token in `text`
        token_ids = tokenizer.encode(text)
        if len(token_ids) < 2:
            return
        inp = torch.tensor([token_ids[:-1]], device=self.device)
        tgt = torch.tensor([token_ids[1:]], device=self.device)
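        # inp and tgt are the same sequence offset by one token, so the loss
        # below trains next-token (causal LM) prediction on this one exchange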

        self.model.train()
        self.optimizer.zero_grad()
        logits = self.model(inp)  # (1, seq_len-1, vocab_size)
        loss = self.criterion(
            logits.view(-1, logits.size(-1)),
            tgt.view(-1)
        )
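        # CrossEntropyLoss expects (N, C) logits and (N,) target ids, hence
        # the .view() flattening across the sequence dimension above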
        loss.backward()
        self.optimizer.step()
        self.model.eval()

        # persist updated weights
        os.makedirs("models", exist_ok=True)
        torch.save(self.model.state_dict(), ckpt)
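        # NOTE: this rewrites models/best_gen.pt after every message; cheap
        # for a small model, but worth debouncing if the channel gets busy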

trainer = OnlineTrainer(model)

# ──────── Discord Client ────────

intents = discord.Intents.default()
intents.message_content = True
@@ -25,16 +106,27 @@ async def on_ready():

@client.event
async def on_message(message):
    # ignore Ruby’s own messages
    if message.author == client.user:
        return

    content = message.content.strip()
    if not content:
        return

    # generate + train in one call
    reply = ruby.generate(content)
    await message.channel.send(reply)
    ruby.train_on(f"User: {content}\nRuby: {reply}")
    # → Generate Ruby’s reply
    ids = tokenizer.encode(content)
    inp = torch.tensor([ids], dtype=torch.long, device=device)
    with torch.no_grad():
        out_ids = model(inp).argmax(-1).squeeze().cpu().tolist()
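        # argmax at every position decodes the whole reply in one forward pass
        # (greedy and non-autoregressive, unlike sampling token by token)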
    reply = tokenizer.decode(out_ids)

    await message.channel.send(reply)

    # → Optionally train on this new example
    sample = f"User: {content}\nRuby: {reply}"
    trainer.train_example(sample)

# ──────── Run ────────

client.run(TOKEN)
@@ -1,17 +1,40 @@
import torch
import torch.nn as nn
import os

import discord
from dotenv import load_dotenv

from ruby_heart import RubyHeart

load_dotenv()
TOKEN = os.getenv("DISCORD_TOKEN")
if not TOKEN:
    raise RuntimeError("DISCORD_TOKEN missing in .env")

# instantiate your “Ruby” engine
ruby = RubyHeart()  # uses GPU if available

intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)


class Discriminator(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, embed_dim, batch_first=True)
        self.fc = nn.Linear(embed_dim, 1)
@client.event
async def on_ready():
    print(f"Ruby is online as {client.user}")

    def forward(self, x):
        # x: (batch, seq_len)
        emb = self.embedding(x)
        _, (h_n, _) = self.lstm(emb)
        # h_n[-1]: (batch, embed_dim)
        return self.fc(h_n[-1])
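        # one scalar logit per sequence: a real-vs-fake score if this
        # discriminator is trained adversarially (e.g. with BCEWithLogitsLoss)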

@client.event
async def on_message(message):
    if message.author == client.user:
        return
    content = message.content.strip()
    if not content:
        return

    # generate + train in one call
    reply = ruby.generate(content)
    await message.channel.send(reply)
    ruby.train_on(f"User: {content}\nRuby: {reply}")


client.run(TOKEN)