# Ruby/nervous_system.py
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import CrossEntropyLoss

from brain import Brain
from sensory import Sensory
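
# Interfaces assumed from the usage below (defined in the sibling modules):
#   Sensory.encode(text, grow) -> list[int]; grow=True may add new tokens
#   Sensory.decode(ids)        -> str
#   Sensory.stoi / .itos       -> token <-> id mappings; id 0 is padding,
#                                 and "<eos>" is assumed to exist in stoi
#   Brain(vocab_size)          -> model with .token_emb, .fc_out, .max_seq_len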


class NervousSystem:
    """Wraps the Brain: token growth, generation, on-the-fly training."""

    def __init__(self, device: str = "cuda"):
        self.device = torch.device(
            device if torch.cuda.is_available() else "cpu"
        )
        self.sensory = Sensory()
        vocab_size = len(self.sensory.stoi)
        self.brain = Brain(vocab_size).to(self.device)
        self.optimizer = optim.Adam(self.brain.parameters(), lr=1e-4)
        # id 0 is assumed to be padding, so it is ignored by the loss
        self.criterion = CrossEntropyLoss(ignore_index=0)
        self.meta_steps = 0
    def _resize_embeddings(self) -> None:
        new_size = len(self.sensory.stoi)
        old_emb = self.brain.token_emb
        if new_size == old_emb.num_embeddings:
            return  # vocabulary unchanged, nothing to do
        # rebuild token embeddings, keeping the rows already trained
        self.brain.token_emb = torch.nn.Embedding(
            new_size, old_emb.embedding_dim
        ).to(self.device)
        with torch.no_grad():
            new_w = self.brain.token_emb.weight
            new_w[: old_emb.num_embeddings] = old_emb.weight
        # rebuild the output head the same way
        old_out = self.brain.fc_out
        self.brain.fc_out = torch.nn.Linear(
            old_emb.embedding_dim, new_size
        ).to(self.device)
        with torch.no_grad():
            self.brain.fc_out.weight[: old_out.out_features] = old_out.weight
            self.brain.fc_out.bias[: old_out.out_features] = old_out.bias
        # the optimizer still tracks the old tensors; rebuild it so the new
        # parameters receive updates (note: this resets Adam's moment state)
        self.optimizer = optim.Adam(
            self.brain.parameters(), lr=self.optimizer.param_groups[0]["lr"]
        )
    def generate(self, prompt: str, max_len: int = 50,
                 temperature: float = 0.8, top_k: int = 50) -> str:
        self.brain.eval()
        raw_ids = self.sensory.encode(prompt, grow=False)
        raw_ids = raw_ids[-self.brain.max_seq_len:]
        out = torch.tensor(
            raw_ids, dtype=torch.long, device=self.device
        ).unsqueeze(0)
        result = []
        with torch.no_grad():
            for _ in range(max_len):
                # keep the context inside the model's positional range
                ctx = out[:, -self.brain.max_seq_len:]
                logits = self.brain(ctx)[:, -1, :]
                # apply temperature
                logits = logits / temperature
                # top-k filtering: sample only from the k most likely tokens
                k = min(top_k, logits.size(-1))
                values, indices = torch.topk(logits, k)
                probs = F.softmax(values, dim=-1)
                choice = torch.multinomial(probs, 1)   # [1, 1] index into top-k
                next_tok = indices.gather(-1, choice)  # [1, 1] vocabulary id
                tok_id = next_tok.item()
                if tok_id == self.sensory.stoi["<eos>"]:
                    break
                result.append(tok_id)
                out = torch.cat([out, next_tok], dim=1)
        return self.sensory.decode(result)
    def train(self, user_text: str, bot_text: str) -> None:
        # 1) grow the vocabulary during training only
        for txt in (user_text, bot_text):
            _ = self.sensory.encode(txt, grow=True)
        self._resize_embeddings()
        # 2) make sure the <sep> separator token exists
        if "<sep>" not in self.sensory.stoi:
            idx = len(self.sensory.stoi)
            self.sensory.stoi["<sep>"] = idx
            self.sensory.itos[idx] = "<sep>"
            self._resize_embeddings()
        combined = f"{user_text} <sep> {bot_text}"
        ids = torch.tensor(
            self.sensory.encode(combined, grow=False),
            dtype=torch.long, device=self.device,
        ).unsqueeze(0)
        if ids.size(1) < 2:
            return  # nothing to learn from a single token
        # keep the sequence inside the model's positional range
        ids = ids[:, : self.brain.max_seq_len + 1]
        # next-token prediction: inputs and targets are shifted by one
        inputs = ids[:, :-1]
        targets = ids[:, 1:]
        self.brain.train()
        logits = self.brain(inputs)
        loss = self.criterion(
            logits.view(-1, logits.size(-1)), targets.view(-1)
        )
        self.optimizer.zero_grad()
        loss.backward()
        # a tiny meta-learning bump: every 100th step runs at a briefly
        # raised learning rate, which is restored straight afterwards
        self.meta_steps += 1
        bump = self.meta_steps % 100 == 0
        if bump:
            for g in self.optimizer.param_groups:
                g["lr"] *= 1.1
        self.optimizer.step()
        if bump:
            for g in self.optimizer.param_groups:
                g["lr"] /= 1.1