116 lines
4.2 KiB
Python
116 lines
4.2 KiB
Python
import os
|
|
import time
|
|
import asyncio
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from torch.optim import AdamW
|
|
import discord
|
|
from core.dataset import CharDataset
|
|
from core.model import GPT, GPTConfig
|
|
|
|
|
|
class Brain:
    """Character-level GPT "brain" for the Ruby Discord bot.

    Loads model and dataset, serves ``generate_response()`` to Discord,
    and runs an async online training loop whenever Ruby is idle.

    NOTE(review): generation, the training step and the checkpoint save all
    run synchronously on the event-loop thread, so they block Discord I/O
    while they execute. Because none of them awaits in the middle of touching
    the model, generation and training never interleave on the shared model.
    """

    def __init__(
        self,
        books_dir: str = './books',
        model_path: str = './model.pth',
        block_size: int = 128,
        train_batch_size: int = 8,
        idle_threshold: float = 60.0,  # seconds of idle before training
        lr: float = 3e-4,
        client: "discord.Client | None" = None,
        status_channel_id: "int | None" = None,
    ):
        """Build the dataset, loader, model and optimizer.

        Args:
            books_dir: directory handed to ``CharDataset`` as the text source.
            model_path: checkpoint file; loaded at startup if it exists and
                overwritten after every online training batch.
            block_size: context length (characters) for dataset and model.
            train_batch_size: batch size of the shuffled training loader.
            idle_threshold: seconds without activity before ``train_online``
                runs a training batch.
            lr: AdamW learning rate.
            client: optional Discord client used for presence updates and
                status messages.
            status_channel_id: optional channel id that receives a message
                after each training batch.
        """
        # device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # dataset + loader
        ds = CharDataset(books_dir, block_size)
        self.stoi, self.itos = ds.stoi, ds.itos
        self.block_size = block_size
        self.train_loader = DataLoader(ds, batch_size=train_batch_size, shuffle=True)
        # Kept across training steps so each idle batch continues the current
        # epoch; _next_batch() rewinds it when exhausted.
        self._train_iter = iter(self.train_loader)

        # model & optimizer
        config = GPTConfig(
            vocab_size=ds.vocab_size,
            block_size=block_size,
            n_layer=6,
            n_head=6,
            n_embd=384,
        )
        self.model = GPT(config).to(self.device)
        if os.path.exists(model_path):
            # NOTE(review): torch.load unpickles the file; only point
            # model_path at checkpoints this bot wrote itself.
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.optimizer = AdamW(self.model.parameters(), lr=lr)
        self.model.train()

        # tracking idle time
        self.last_active = time.time()
        self.idle_threshold = idle_threshold
        self.model_path = model_path

        # discord hooks
        self.client = client
        self.status_channel_id = status_channel_id

    async def generate_response(self, prompt: str, **gen_kwargs) -> str:
        """Generate text from ``prompt`` and mark the bot as active.

        Only the last ``block_size`` characters of ``prompt`` are used as
        context; characters not in the vocabulary map to index 0. An empty
        prompt is seeded with that same index 0 so the model always receives
        at least one token.

        Args:
            prompt: user text to continue.
            **gen_kwargs: forwarded unchanged to ``self.model.generate``.

        Returns:
            The decoded sequence returned by ``self.model.generate``
            (presumably the context followed by new tokens — TODO confirm
            against ``GPT.generate``).
        """
        self.last_active = time.time()

        token_ids = [self.stoi.get(ch, 0) for ch in prompt[-self.block_size:]]
        if not token_ids:
            # An empty (1, 0) context gives the model nothing to condition on.
            token_ids = [0]
        idx = torch.tensor([token_ids], dtype=torch.long, device=self.device)

        self.model.eval()
        try:
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                out = self.model.generate(idx, **gen_kwargs)[0]
        finally:
            # Restore training mode even if generation fails, otherwise
            # train_online would keep training a model left in eval mode.
            self.model.train()
        return ''.join(self.itos[i] for i in out.tolist())

    async def train_online(self):
        """Background task: train one batch each time Ruby has been idle.

        Once per second, checks whether ``idle_threshold`` seconds have passed
        since ``last_active``. If so, runs one optimizer step, saves the
        checkpoint, reports the loss, then resets ``last_active`` — so while
        idle at most one batch runs per ``idle_threshold`` seconds. Never
        returns; cancel the task to stop it. Discord failures are logged and
        ignored, but an exception from the training step itself still ends
        the task (after the presence has been reset).
        """
        while True:
            if time.time() - self.last_active >= self.idle_threshold:
                # 1) log & presence
                print("⚙️ [Brain] Idle threshold reached—starting training batch.")
                await self._set_presence(
                    discord.Activity(
                        type=discord.ActivityType.watching,
                        name="Training Ruby…",
                    )
                )
                try:
                    # 2) + 3) pull next batch, forward/backward/step
                    loss_value = self._train_step()

                    # 4) save & log
                    self._save_checkpoint()
                    print(f"✅ [Brain] Finished batch. Loss: {loss_value:.4f}")

                    # 5) optional Discord ping
                    await self._send_status(f"🤖 Trained one batch, loss: {loss_value:.4f}")
                finally:
                    # 6) reset presence & idle timer, even if the step failed
                    await self._set_presence(None)
                    self.last_active = time.time()

            await asyncio.sleep(1)

    def _next_batch(self):
        """Return the next ``(inputs, targets)`` batch moved to ``self.device``.

        Starts a new pass over ``train_loader`` when the current iterator is
        exhausted.
        """
        try:
            xb, yb = next(self._train_iter)
        except StopIteration:
            self._train_iter = iter(self.train_loader)
            xb, yb = next(self._train_iter)
        return xb.to(self.device), yb.to(self.device)

    def _train_step(self) -> float:
        """Run one forward/backward/optimizer step and return the loss value."""
        xb, yb = self._next_batch()
        _, loss = self.model(xb, yb)  # model returns (logits, loss)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _save_checkpoint(self):
        """Write the model's state_dict to ``model_path`` atomically.

        Saves to a temporary file beside the target, then swaps it in with
        ``os.replace``, so a crash mid-write cannot leave a truncated
        checkpoint that ``__init__`` would fail to load.
        """
        tmp_path = f"{self.model_path}.tmp"
        torch.save(self.model.state_dict(), tmp_path)
        os.replace(tmp_path, self.model_path)

    async def _set_presence(self, activity):
        """Best-effort Discord presence update; a no-op without a client."""
        if not self.client:
            return
        try:
            await self.client.change_presence(activity=activity)
        except Exception as err:  # presence is cosmetic; never kill training
            print(f"⚠️ [Brain] Could not update presence: {err!r}")

    async def _send_status(self, message: str):
        """Best-effort post of ``message`` to the configured status channel."""
        if not (self.client and self.status_channel_id):
            return
        chan = self.client.get_channel(self.status_channel_id)
        if not chan:
            return
        try:
            await chan.send(message)
        except Exception as err:  # status pings are optional; keep training
            print(f"⚠️ [Brain] Could not send status message: {err!r}")