import os
import time
import asyncio

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

import discord

from core.dataset import CharDataset
from core.model import GPT, GPTConfig


class Brain:
    """
    Loads model and dataset, serves generate_response() to Discord,
    and runs an async online training loop whenever Ruby is idle.
    """

    def __init__(
        self,
        books_dir: str = './books',
        model_path: str = './model.pth',
        block_size: int = 128,
        train_batch_size: int = 8,
        idle_threshold: float = 60.0,  # seconds of idle before training
        lr: float = 3e-4,
        client: discord.Client = None,
        status_channel_id: int = None,
    ):
        """
        Build the dataset, loader, model and optimizer.

        Args:
            books_dir: directory handed to CharDataset as its text source.
            model_path: checkpoint file; loaded at startup if it exists and
                overwritten after every online training batch.
            block_size: context length used for both the dataset and GPTConfig.
            train_batch_size: DataLoader batch size for online training.
            idle_threshold: seconds since the last generate_response() call
                before train_online() runs a batch.
            lr: AdamW learning rate.
            client: optional Discord client used for presence updates.
            status_channel_id: optional channel id that receives per-batch loss.
        """
        # device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # dataset + loader
        ds = CharDataset(books_dir, block_size)
        self.stoi, self.itos = ds.stoi, ds.itos
        self.block_size = block_size
        self.train_loader = DataLoader(ds, batch_size=train_batch_size, shuffle=True)
        self._train_iter = iter(self.train_loader)

        # model & optimizer
        config = GPTConfig(
            vocab_size=ds.vocab_size,
            block_size=block_size,
            n_layer=6,
            n_head=6,
            n_embd=384,
        )
        self.model = GPT(config).to(self.device)
        if os.path.exists(model_path):
            # NOTE: torch.load unpickles the file; only point model_path at
            # checkpoints you trust.
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.optimizer = AdamW(self.model.parameters(), lr=lr)
        self.model.train()

        # tracking idle time
        self.last_active = time.time()
        self.idle_threshold = idle_threshold
        self.model_path = model_path

        # discord hooks
        self.client = client
        self.status_channel_id = status_channel_id

    async def generate_response(self, prompt: str, **gen_kwargs) -> str:
        """
        Generate text from *prompt* and mark the bot as active (postponing training).

        Only the last ``block_size`` characters of the prompt are encoded;
        characters not in the training vocabulary are mapped to id 0.
        ``gen_kwargs`` are forwarded unchanged to ``GPT.generate``.

        Returns the decoded first sequence returned by ``GPT.generate``
        (whether that includes the prompt depends on GPT.generate — TODO confirm).

        Note: generation runs synchronously on the event loop.
        """
        self.last_active = time.time()

        ids = [self.stoi.get(ch, 0) for ch in prompt[-self.block_size:]]
        if not ids:
            # An empty prompt would give a (1, 0) tensor with nothing to
            # condition on; seed generation with id 0 instead.
            ids = [0]
        idx = torch.tensor([ids], dtype=torch.long, device=self.device)

        was_training = self.model.training
        self.model.eval()
        try:
            # No gradients are needed for inference; avoids building autograd graphs.
            with torch.no_grad():
                out = self.model.generate(idx, **gen_kwargs)[0]
        finally:
            # Restore the previous mode even if generate() raises, so the
            # online training loop never silently trains in eval mode.
            self.model.train(was_training)
        return ''.join(self.itos[i] for i in out.tolist())

    async def train_online(self):
        """
        Background task: whenever idle >= idle_threshold, perform one
        training batch, save a checkpoint, then loop.

        The idle timer is reset after every batch (successful or not), so at
        most one batch runs per ``idle_threshold`` seconds of inactivity.
        Discord presence/status failures are logged and ignored so they cannot
        kill this task; errors from the training step itself still propagate.
        """
        while True:
            if time.time() - self.last_active >= self.idle_threshold:
                # 1) log & presence
                print("⚙️ [Brain] Idle threshold reached—starting training batch.")
                await self._set_presence(
                    discord.Activity(
                        type=discord.ActivityType.watching,
                        name="Training Ruby…",
                    )
                )
                try:
                    # 2-4) batch, forward/backward, checkpoint
                    loss = self._train_step()
                    print(f"✅ [Brain] Finished batch. Loss: {loss:.4f}")
                    # 5) optional Discord ping
                    await self._notify_status(f"🤖 Trained one batch, loss: {loss:.4f}")
                finally:
                    # 6) reset presence & idle timer even if the step failed,
                    # so the bot is not left showing "Training Ruby…".
                    await self._set_presence(None)
                    self.last_active = time.time()
            await asyncio.sleep(1)

    def _next_batch(self):
        """Return the next (inputs, targets) batch, restarting the loader at the end of an epoch."""
        try:
            return next(self._train_iter)
        except StopIteration:
            self._train_iter = iter(self.train_loader)
        try:
            return next(self._train_iter)
        except StopIteration:
            # A bare StopIteration escaping into a coroutine becomes an opaque
            # RuntimeError; raise a descriptive one instead.
            raise RuntimeError("training DataLoader yielded no batches") from None

    def _train_step(self) -> float:
        """Run one optimizer step on the next batch, checkpoint the weights, and return the loss value."""
        xb, yb = self._next_batch()
        xb, yb = xb.to(self.device), yb.to(self.device)

        _, loss = self.model(xb, yb)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self._save_checkpoint()
        return loss.item()

    def _save_checkpoint(self):
        """Write the model weights to model_path via a temp file and os.replace."""
        tmp_path = f"{self.model_path}.tmp"
        torch.save(self.model.state_dict(), tmp_path)
        # os.replace is atomic on POSIX (same filesystem), so an interrupted
        # save never leaves a truncated checkpoint for __init__ to load.
        os.replace(tmp_path, self.model_path)

    async def _set_presence(self, activity):
        """Best-effort presence update: no-op without a client; logs instead of raising on failure."""
        if not self.client:
            return
        try:
            await self.client.change_presence(activity=activity)
        except Exception as exc:  # presence is cosmetic; never let it kill training
            print(f"⚠️ [Brain] Could not update presence: {exc!r}")

    async def _notify_status(self, message: str):
        """Best-effort post of *message* to the status channel, if one is configured and visible."""
        if not (self.client and self.status_channel_id):
            return
        chan = self.client.get_channel(self.status_channel_id)
        if not chan:
            return
        try:
            await chan.send(message)
        except Exception as exc:  # status pings are optional; log and continue
            print(f"⚠️ [Brain] Could not send status message: {exc!r}")