# Ruby/core/brain.py
#
# 116 lines
# 4.2 KiB
# Python

import os
import time
import asyncio
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import discord
from core.dataset import CharDataset
from core.model import GPT, GPTConfig
class Brain:
    """
    Loads model and dataset, serves generate_response() to Discord,
    and runs an async online training loop whenever Ruby is idle.

    Text is encoded one character at a time via the dataset's ``stoi`` /
    ``itos`` tables. Only the model weights (not the optimizer state) are
    checkpointed to ``model_path`` after every training batch, and they are
    reloaded from there on construction.
    """

    def __init__(
        self,
        books_dir: str = './books',
        model_path: str = './model.pth',
        block_size: int = 128,
        train_batch_size: int = 8,
        idle_threshold: float = 60.0,  # seconds of idle before training
        lr: float = 3e-4,
        client: "discord.Client | None" = None,
        status_channel_id: "int | None" = None,
    ):
        """
        Build the dataset/loader, the GPT model and its AdamW optimizer.

        Args:
            books_dir: source location passed to ``CharDataset``.
            model_path: checkpoint path; loaded here if it exists and
                overwritten after every training batch.
            block_size: context length used for the dataset, the model
                config and prompt truncation in ``generate_response()``.
            train_batch_size: batch size of the shuffled training DataLoader.
            idle_threshold: seconds since the last activity before
                ``train_online()`` runs a batch.
            lr: AdamW learning rate.
            client: optional Discord client used for presence updates and
                status messages.
            status_channel_id: optional channel id that receives a message
                after every training batch (only used when ``client`` is set).
        """
        # device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # dataset + loader
        ds = CharDataset(books_dir, block_size)
        self.stoi, self.itos = ds.stoi, ds.itos
        self.block_size = block_size
        self.train_loader = DataLoader(ds, batch_size=train_batch_size, shuffle=True)
        self._train_iter = iter(self.train_loader)

        # model & optimizer
        config = GPTConfig(
            vocab_size=ds.vocab_size,
            block_size=block_size,
            n_layer=6,
            n_head=6,
            n_embd=384,
        )
        self.model = GPT(config).to(self.device)
        if os.path.exists(model_path):
            # NOTE: torch.load unpickles the file, so model_path must point
            # at a trusted checkpoint.
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        # The optimizer starts fresh on every construction: its state is
        # never checkpointed.
        self.optimizer = AdamW(self.model.parameters(), lr=lr)
        self.model.train()

        # tracking idle time
        self.last_active = time.time()
        self.idle_threshold = idle_threshold
        self.model_path = model_path

        # discord hooks
        self.client = client
        self.status_channel_id = status_channel_id

    async def generate_response(self, prompt: str, **gen_kwargs) -> str:
        """
        Generate text from ``prompt`` and return it decoded as a string.

        Only the last ``block_size`` characters of the prompt are encoded.
        Characters missing from the vocabulary are mapped to id 0. Calling
        this marks Ruby as active, which postpones idle training.

        NOTE(review): the result decodes whatever ``model.generate`` returns
        for the first batch row. If that is prompt + continuation (the usual
        GPT convention), the reply echoes the truncated prompt; confirm
        against core.model.GPT.generate. An empty prompt yields a
        zero-length input tensor, and whether generate accepts that is
        untested here.

        Args:
            prompt: input text.
            **gen_kwargs: forwarded unchanged to ``self.model.generate``.

        Returns:
            The generated token ids decoded through ``itos``.
        """
        self.last_active = time.time()
        token_ids = [self.stoi.get(ch, 0) for ch in prompt[-self.block_size:]]
        idx = torch.tensor([token_ids], dtype=torch.long, device=self.device)

        self.model.eval()
        try:
            # Inference only: don't let autograd record the sampling loop.
            # This is harmless if generate() already disables gradients.
            with torch.no_grad():
                out = self.model.generate(idx, **gen_kwargs)[0]
        finally:
            # Always restore train mode, even if generation raised, so the
            # idle training loop never runs with the model left in eval mode.
            self.model.train()
        return ''.join(self.itos[i] for i in out.tolist())

    async def train_online(self):
        """
        Background task: whenever idle >= idle_threshold,
        perform one training batch, save checkpoint, then loop.

        The idle state is polled once per second. The idle timer is reset
        after every batch, so while Ruby stays idle this trains at most one
        batch every ``idle_threshold`` seconds.

        A failure inside one iteration (Discord API error, training error,
        failed save) is printed with its traceback and the loop carries on.
        Previously a single failed ``chan.send`` or ``change_presence``
        stopped idle training for the rest of the process. Task cancellation
        still propagates.
        """
        import traceback

        while True:
            if time.time() - self.last_active >= self.idle_threshold:
                try:
                    await self._train_one_batch()
                except asyncio.CancelledError:
                    raise
                except Exception:
                    print("❌ [Brain] Training iteration failed:")
                    traceback.print_exc()
                # 6) reset presence & idle timer (also after a failure, so
                #    the "Training Ruby…" status never sticks)
                await self._reset_presence()
                self.last_active = time.time()
            await asyncio.sleep(1)

    async def _train_one_batch(self):
        """Announce, run one optimizer step, checkpoint, and report the loss."""
        # 1) log & presence
        print("⚙️ [Brain] Idle threshold reached—starting training batch.")
        if self.client:
            await self.client.change_presence(
                activity=discord.Activity(
                    type=discord.ActivityType.watching,
                    name="Training Ruby…",
                )
            )

        # 2) pull next batch
        xb, yb = self._next_batch()
        xb, yb = xb.to(self.device), yb.to(self.device)

        # 3) forward/backward (the model returns (logits, loss) given targets)
        _, loss = self.model(xb, yb)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        loss_value = loss.item()

        # 4) save & log
        self._save_checkpoint()
        print(f"✅ [Brain] Finished batch. Loss: {loss_value:.4f}")

        # 5) optional Discord ping
        if self.client and self.status_channel_id:
            chan = self.client.get_channel(self.status_channel_id)
            if chan:
                await chan.send(f"🤖 Trained one batch, loss: {loss_value:.4f}")

    def _next_batch(self):
        """Return the next batch, starting a new pass over the loader at epoch end.

        Raises:
            RuntimeError: if a fresh pass over the loader yields no batches.
        """
        try:
            return next(self._train_iter)
        except StopIteration:
            # Epoch finished: start a new pass (reshuffled, since the loader
            # is built with shuffle=True).
            self._train_iter = iter(self.train_loader)
        batch = next(self._train_iter, None)
        if batch is None:
            # A fresh iterator that is immediately exhausted means the loader
            # is empty; without this check StopIteration escapes the
            # coroutine as an opaque RuntimeError.
            raise RuntimeError("[Brain] training DataLoader produced no batches (empty dataset?)")
        return batch

    def _save_checkpoint(self):
        """Write the model weights to ``model_path`` without a partial-write window."""
        # Write next to the target, then swap it in with os.replace. A crash
        # or kill during torch.save then leaves the previous checkpoint
        # intact instead of a truncated file that __init__ cannot load.
        tmp_path = f"{self.model_path}.tmp"
        torch.save(self.model.state_dict(), tmp_path)
        os.replace(tmp_path, self.model_path)

    async def _reset_presence(self):
        """Clear the bot's activity if a client is attached; errors are printed, not raised."""
        if not self.client:
            return
        try:
            await self.client.change_presence(activity=None)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            print(f"⚠️ [Brain] Could not reset presence: {exc!r}")