Revert "Goofed up what file was being edited"
This reverts commit c7a15f63ddc6fd25b2b3bda9f37c05e863ea6285.
This commit is contained in:
parent
ddc8db5aa4
commit
208b1f190c
108
main.py
108
main.py
main.py

@@ -1,17 +1,98 @@
 import os
+import glob
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
 import discord
 from dotenv import load_dotenv
-from ruby_engine import RubyEngine
+
+from models.transformer import TransformerGenerator
+from utils.tokenizer import HybridTokenizer
+
+# ──────── Setup ────────
 load_dotenv()
 TOKEN = os.getenv("DISCORD_TOKEN")
 if not TOKEN:
-    raise RuntimeError("DISCORD_TOKEN missing in .env")
+    raise RuntimeError("Missing DISCORD_TOKEN in .env")
 
-# instantiate your “Ruby” engine
-ruby = RubyEngine()  # uses GPU if available
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# print(f"[INFO] Using device: {device}")
+
+# ──────── Tokenizer & Vocab ────────
+vocab_file = os.path.join("vocab", "vocab.json")
+tokenizer = HybridTokenizer(vocab_file)
+
+# If vocab.json doesn’t exist yet, build it from your books:
+if not tokenizer.char_to_id:
+    book_paths = glob.glob(os.path.join("data", "books", "*.txt"))
+    texts = []
+    for path in book_paths:
+        with open(path, "r", encoding="utf-8") as f:
+            texts.append(f.read())
+    tokenizer.build_vocab(texts)
+    print(f"[INFO] Built vocab ({len(tokenizer.word_to_id)} words + "
+          f"{len(tokenizer.char_to_id)} chars)")
+
+# ──────── Model Setup ────────
+vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
+embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4
+max_seq_len = 128
+
+model = TransformerGenerator(
+    vocab_size, embed_dim, num_heads, mlp_dim, num_layers, max_seq_len
+).to(device)
+
+ckpt = os.path.join("models", "best_gen.pt")
+if os.path.isfile(ckpt):
+    state = torch.load(ckpt, map_location=device)
+    model.load_state_dict(state)
+    print("[INFO] Loaded checkpoint models/best_gen.pt")
+else:
+    print("[INFO] No checkpoint found; starting from random weights")
+
+model.eval()
+
+# ──────── Online Trainer ────────
+class OnlineTrainer:
+    """Fine-tune the generator on each new exchange."""
+
+    def __init__(self, model, lr=1e-5):
+        self.model = model
+        self.optimizer = optim.Adam(model.parameters(), lr=lr)
+        self.criterion = nn.CrossEntropyLoss()
+        self.device = device
+
+    def train_example(self, text: str):
+        # simple causal training: predict each next token in `text`
+        token_ids = tokenizer.encode(text)
+        if len(token_ids) < 2:
+            return
+        inp = torch.tensor([token_ids[:-1]], device=self.device)
+        tgt = torch.tensor([token_ids[1:]], device=self.device)
+
+        self.model.train()
+        self.optimizer.zero_grad()
+        logits = self.model(inp)  # (1, seq_len-1, vocab_size)
+        loss = self.criterion(
+            logits.view(-1, logits.size(-1)),
+            tgt.view(-1)
+        )
+        loss.backward()
+        self.optimizer.step()
+        self.model.eval()
+
+        # persist updated weights
+        os.makedirs("models", exist_ok=True)
+        torch.save(self.model.state_dict(), ckpt)
+
+trainer = OnlineTrainer(model)
+
+# ──────── Discord Client ────────
 intents = discord.Intents.default()
 intents.message_content = True
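The new main.py relies on a HybridTokenizer that exposes word_to_id and char_to_id maps plus build_vocab, encode, and decode, persisted to vocab/vocab.json. The repository's utils/tokenizer.py is not part of this commit; the stub below is only a rough sketch of that assumed interface (word-level IDs with a character-level fallback), not the actual implementation.

import json
import os

class HybridTokenizerSketch:
    """Hypothetical stand-in for utils.tokenizer.HybridTokenizer."""

    def __init__(self, vocab_file):
        self.vocab_file = vocab_file
        self.word_to_id, self.char_to_id = {}, {}
        if os.path.isfile(vocab_file):
            with open(vocab_file, encoding="utf-8") as f:
                data = json.load(f)
            self.word_to_id = data.get("word_to_id", {})
            self.char_to_id = data.get("char_to_id", {})

    def build_vocab(self, texts):
        # word IDs first, then character IDs offset past the word range
        words = sorted({w for t in texts for w in t.split()})
        chars = sorted({c for t in texts for c in t})
        self.word_to_id = {w: i for i, w in enumerate(words)}
        self.char_to_id = {c: len(words) + i for i, c in enumerate(chars)}
        os.makedirs(os.path.dirname(self.vocab_file) or ".", exist_ok=True)
        with open(self.vocab_file, "w", encoding="utf-8") as f:
            json.dump({"word_to_id": self.word_to_id,
                       "char_to_id": self.char_to_id}, f)

    def encode(self, text):
        # known words map to word IDs; unknown words fall back to characters
        ids = []
        for word in text.split():
            if word in self.word_to_id:
                ids.append(self.word_to_id[word])
            else:
                ids.extend(self.char_to_id[c] for c in word if c in self.char_to_id)
        return ids

    def decode(self, ids):
        # crude inverse mapping; joins recovered tokens with spaces
        id_to_token = {i: w for w, i in self.word_to_id.items()}
        id_to_token.update({i: c for c, i in self.char_to_id.items()})
        return " ".join(id_to_token[i] for i in ids if i in id_to_token)

With an interface like this, the check in main.py (`if not tokenizer.char_to_id:`) simply detects a missing or empty vocab.json before building the vocabulary from the books.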
@@ -25,16 +106,27 @@ async def on_ready():

 @client.event
 async def on_message(message):
+    # ignore Ruby’s own messages
     if message.author == client.user:
         return
 
     content = message.content.strip()
     if not content:
         return
 
-    # generate + train in one call
-    reply = ruby.generate(content)
-    await message.channel.send(reply)
-    ruby.train_on(f"User: {content}\nRuby: {reply}")
+    # → Generate Ruby’s reply
+    ids = tokenizer.encode(content)
+    inp = torch.tensor([ids], dtype=torch.long, device=device)
+    with torch.no_grad():
+        out_ids = model(inp).argmax(-1).squeeze().cpu().tolist()
+    reply = tokenizer.decode(out_ids)
+
+    await message.channel.send(reply)
+
+    # → Optionally train on this new example
+    sample = f"User: {content}\nRuby: {reply}"
+    trainer.train_example(sample)
+
+# ──────── Run ────────
 client.run(TOKEN)
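Note that the new on_message handler builds the reply from a single forward pass, `model(inp).argmax(-1)`, which greedily picks one token per prompt position rather than generating new tokens step by step. A token-by-token greedy loop is sketched below; it is not part of this commit and assumes, as the trainer's shape comment suggests, that model(ids) returns logits of shape (batch, seq_len, vocab_size).

import torch

def generate_reply(model, tokenizer, prompt, device, max_seq_len=128, max_new_tokens=40):
    # hypothetical autoregressive greedy decoding; names and caps are assumptions
    prompt_ids = tokenizer.encode(prompt)[-max_seq_len:]
    ids = list(prompt_ids)
    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            inp = torch.tensor([ids[-max_seq_len:]], dtype=torch.long, device=device)
            logits = model(inp)                      # (1, seq_len, vocab_size)
            next_id = int(logits[0, -1].argmax(-1))  # greedy pick for the last position
            ids.append(next_id)
    return tokenizer.decode(ids[len(prompt_ids):])   # decode only the newly generated part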
@@ -1,17 +1,40 @@
-import torch
-import torch.nn as nn
-
-
-class Discriminator(nn.Module):
-    def __init__(self, vocab_size: int, embed_dim: int):
-        super().__init__()
-        self.embedding = nn.Embedding(vocab_size, embed_dim)
-        self.lstm = nn.LSTM(embed_dim, embed_dim, batch_first=True)
-        self.fc = nn.Linear(embed_dim, 1)
-
-    def forward(self, x):
-        # x: (batch, seq_len)
-        emb = self.embedding(x)
-        _, (h_n, _) = self.lstm(emb)
-        # h_n[-1]: (batch, embed_dim)
-        return self.fc(h_n[-1])
+import os
+import discord
+from dotenv import load_dotenv
+
+from ruby_heart import RubyHeart
+
+load_dotenv()
+TOKEN = os.getenv("DISCORD_TOKEN")
+if not TOKEN:
+    raise RuntimeError("DISCORD_TOKEN missing in .env")
+
+# instantiate your “Ruby” engine
+ruby = RubyHeart()  # uses GPU if available
+
+intents = discord.Intents.default()
+intents.message_content = True
+client = discord.Client(intents=intents)
+
+
+@client.event
+async def on_ready():
+    print(f"Ruby is online as {client.user}")
+
+
+@client.event
+async def on_message(message):
+    if message.author == client.user:
+        return
+    content = message.content.strip()
+    if not content:
+        return
+
+    # generate + train in one call
+    reply = ruby.generate(content)
+    await message.channel.send(reply)
+    ruby.train_on(f"User: {content}\nRuby: {reply}")
+
+
+client.run(TOKEN)
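The file rewritten above previously held an LSTM-based Discriminator: it embeds a batch of token IDs, runs them through an LSTM, and maps the final hidden state to a single logit per sequence. The snippet below only illustrates how such a scorer is typically driven; the loss choice (BCEWithLogitsLoss) and the batch shapes are assumptions, and no such training loop is part of this commit.

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    # mirrors the class removed above
    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, embed_dim, batch_first=True)
        self.fc = nn.Linear(embed_dim, 1)

    def forward(self, x):
        emb = self.embedding(x)        # (batch, seq_len, embed_dim)
        _, (h_n, _) = self.lstm(emb)   # h_n: (1, batch, embed_dim)
        return self.fc(h_n[-1])        # one "realness" logit per sequence

disc = Discriminator(vocab_size=1000, embed_dim=64)
tokens = torch.randint(0, 1000, (8, 32))                          # 8 sequences of 32 token IDs
logits = disc(tokens)                                             # (8, 1)
loss = nn.BCEWithLogitsLoss()(logits, torch.ones_like(logits))    # label "real" = 1
loss.backward()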