Goofed up what file was being edited

parent 23800bf323
commit ddc8db5aa4

main.py (106 changed lines)
@@ -1,98 +1,17 @@
 import os
-import glob
-
-import torch
-import torch.nn as nn
-import torch.optim as optim
 import discord
 from dotenv import load_dotenv
-
-from models.transformer import TransformerGenerator
-from utils.tokenizer import HybridTokenizer
-
-# ──────── Setup ────────
-
+from ruby_engine import RubyEngine
+
 load_dotenv()
 TOKEN = os.getenv("DISCORD_TOKEN")
 if not TOKEN:
-    raise RuntimeError("Missing DISCORD_TOKEN in .env")
+    raise RuntimeError("DISCORD_TOKEN missing in .env")
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-# print(f"[INFO] Using device: {device}")
-
-# ──────── Tokenizer & Vocab ────────
-
-vocab_file = os.path.join("vocab", "vocab.json")
-tokenizer = HybridTokenizer(vocab_file)
-
-# If vocab.json doesn’t exist yet, build it from your books:
-if not tokenizer.char_to_id:
-    book_paths = glob.glob(os.path.join("data", "books", "*.txt"))
-    texts = []
-    for path in book_paths:
-        with open(path, "r", encoding="utf-8") as f:
-            texts.append(f.read())
-    tokenizer.build_vocab(texts)
-    print(f"[INFO] Built vocab ({len(tokenizer.word_to_id)} words + "
-          f"{len(tokenizer.char_to_id)} chars)")
-
-# ──────── Model Setup ────────
-
-vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
-embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4
-max_seq_len = 128
-
-model = TransformerGenerator(
-    vocab_size, embed_dim, num_heads, mlp_dim, num_layers, max_seq_len
-).to(device)
-
-ckpt = os.path.join("models", "best_gen.pt")
-if os.path.isfile(ckpt):
-    state = torch.load(ckpt, map_location=device)
-    model.load_state_dict(state)
-    print("[INFO] Loaded checkpoint models/best_gen.pt")
-else:
-    print("[INFO] No checkpoint found; starting from random weights")
-
-model.eval()
-
-# ──────── Online Trainer ────────
-
-class OnlineTrainer:
-    """Fine-tune the generator on each new exchange."""
-
-    def __init__(self, model, lr=1e-5):
-        self.model = model
-        self.optimizer = optim.Adam(model.parameters(), lr=lr)
-        self.criterion = nn.CrossEntropyLoss()
-        self.device = device
-
-    def train_example(self, text: str):
-        # simple causal training: predict each next token in `text`
-        token_ids = tokenizer.encode(text)
-        if len(token_ids) < 2:
-            return
-        inp = torch.tensor([token_ids[:-1]], device=self.device)
-        tgt = torch.tensor([token_ids[1:]], device=self.device)
-
-        self.model.train()
-        self.optimizer.zero_grad()
-        logits = self.model(inp)  # (1, seq_len-1, vocab_size)
-        loss = self.criterion(
-            logits.view(-1, logits.size(-1)),
-            tgt.view(-1)
-        )
-        loss.backward()
-        self.optimizer.step()
-        self.model.eval()
-
-        # persist updated weights
-        os.makedirs("models", exist_ok=True)
-        torch.save(self.model.state_dict(), ckpt)
-
-trainer = OnlineTrainer(model)
-
-# ──────── Discord Client ────────
-
+# instantiate your “Ruby” engine
+ruby = RubyEngine()  # uses GPU if available
+
 intents = discord.Intents.default()
 intents.message_content = True
@@ -106,27 +25,16 @@ async def on_ready():
 
 @client.event
 async def on_message(message):
-    # ignore Ruby’s own messages
     if message.author == client.user:
         return
 
     content = message.content.strip()
     if not content:
         return
 
-    # → Generate Ruby’s reply
-    ids = tokenizer.encode(content)
-    inp = torch.tensor([ids], dtype=torch.long, device=device)
-    with torch.no_grad():
-        out_ids = model(inp).argmax(-1).squeeze().cpu().tolist()
-    reply = tokenizer.decode(out_ids)
-
+    # generate + train in one call
+    reply = ruby.generate(content)
     await message.channel.send(reply)
-
-    # → Optionally train on this new example
-    sample = f"User: {content}\nRuby: {reply}"
-    trainer.train_example(sample)
-
-# ──────── Run ────────
-
+    ruby.train_on(f"User: {content}\nRuby: {reply}")
+
 client.run(TOKEN)
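Both hunks of main.py now route everything through two calls on the engine, ruby.generate(content) and ruby.train_on(...). Since ruby_engine.py is not part of this commit, the following is only a hypothetical sketch of the interface main.py appears to assume, roughly folding the tokenizer, generator, and OnlineTrainer wiring removed above behind those two methods; every name and default in it is an assumption, not repository code.

import torch


class RubyEngine:
    def __init__(self, ckpt: str = "models/best_gen.pt"):
        # mirror the device selection removed from main.py: use the GPU when available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.ckpt = ckpt
        # a real implementation would build the tokenizer, the generator model,
        # and the optimizer/loss here, as the old main.py did inline
        ...

    def generate(self, prompt: str) -> str:
        """Encode the prompt, run the generator, decode and return a reply."""
        ...

    def train_on(self, sample: str) -> None:
        """Run one online fine-tuning step on a single user/Ruby exchange,
        then persist the updated weights (e.g. back to self.ckpt)."""
        ...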
@@ -1,40 +1,17 @@
-import os
-import discord
-from dotenv import load_dotenv
-
-from ruby_heart import RubyHeart
-
-load_dotenv()
-TOKEN = os.getenv("DISCORD_TOKEN")
-if not TOKEN:
-    raise RuntimeError("DISCORD_TOKEN missing in .env")
-
-# instantiate your “Ruby” engine
-ruby = RubyHeart()  # uses GPU if available
-
-intents = discord.Intents.default()
-intents.message_content = True
-client = discord.Client(intents=intents)
-
-
-@client.event
-async def on_ready():
-    print(f"Ruby is online as {client.user}")
-
-
-@client.event
-async def on_message(message):
-    if message.author == client.user:
-        return
-    content = message.content.strip()
-    if not content:
-        return
-
-    # generate + train in one call
-    reply = ruby.generate(content)
-    await message.channel.send(reply)
-    ruby.train_on(f"User: {content}\nRuby: {reply}")
-
-
-client.run(TOKEN)
+import torch
+import torch.nn as nn
+
+
+class Discriminator(nn.Module):
+    def __init__(self, vocab_size: int, embed_dim: int):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.lstm = nn.LSTM(embed_dim, embed_dim, batch_first=True)
+        self.fc = nn.Linear(embed_dim, 1)
+
+    def forward(self, x):
+        # x: (batch, seq_len)
+        emb = self.embedding(x)
+        _, (h_n, _) = self.lstm(emb)
+        # h_n[-1]: (batch, embed_dim)
+        return self.fc(h_n[-1])
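The new Discriminator reduces a token sequence to a single score: tokens are embedded, run through an LSTM, and the final hidden state is mapped to one logit per sequence, the usual shape for a real/fake discriminator head. A minimal shape check, assuming the class above is in scope; the sizes here are illustrative only, not values from this repository:

import torch

disc = Discriminator(vocab_size=5000, embed_dim=256)  # illustrative sizes
tokens = torch.randint(0, 5000, (4, 32))              # batch of 4 sequences, 32 token ids each
scores = disc(tokens)                                  # embed -> LSTM -> linear head
print(scores.shape)                                    # torch.Size([4, 1])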