import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import CrossEntropyLoss

from sensory import Sensory
from brain import Brain

# End-of-sequence sentinel. The exact token string is an assumption; it only
# has to match whatever Sensory stores in stoi/itos for the stop marker.
EOS = "<eos>"


class NervousSystem:
    """Wraps the Brain: grows the vocabulary, generates text and trains on the fly."""

    def __init__(self, device: str = "cuda"):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.sensory = Sensory()
        vocab_size = len(self.sensory.stoi)
        self.brain = Brain(vocab_size).to(self.device)

        self.optimizer = optim.Adam(self.brain.parameters(), lr=1e-4)
        self.criterion = CrossEntropyLoss(ignore_index=0)  # index 0 is padding

        self.meta_steps = 0

    def _resize_embeddings(self) -> None:
        """Grow token_emb / fc_out to the current vocabulary, keeping old weights."""
        new_size = len(self.sensory.stoi)
        old_emb = self.brain.token_emb
        if new_size == old_emb.num_embeddings:
            return  # nothing to do

        # rebuild token embeddings, copying the rows already trained
        self.brain.token_emb = torch.nn.Embedding(
            new_size, old_emb.embedding_dim
        ).to(self.device)
        with torch.no_grad():
            self.brain.token_emb.weight[: old_emb.num_embeddings] = old_emb.weight

        # rebuild the output head the same way
        old_out = self.brain.fc_out
        self.brain.fc_out = torch.nn.Linear(
            old_emb.embedding_dim, new_size
        ).to(self.device)
        with torch.no_grad():
            self.brain.fc_out.weight[: old_out.out_features] = old_out.weight
            self.brain.fc_out.bias[: old_out.out_features] = old_out.bias

        # the optimizer still holds references to the replaced parameters,
        # so rebuild it (Adam moment estimates are reset in the process)
        lr = self.optimizer.param_groups[0]["lr"]
        self.optimizer = optim.Adam(self.brain.parameters(), lr=lr)

    @torch.no_grad()
    def generate(self, prompt: str, max_len: int = 50,
                 temperature: float = 0.8, top_k: int = 50) -> str:
        self.brain.eval()
        raw_ids = self.sensory.encode(prompt, grow=False)[-self.brain.max_seq_len:]
        out = torch.tensor(raw_ids, dtype=torch.long, device=self.device).unsqueeze(0)
        eos_id = self.sensory.stoi.get(EOS)

        result = []
        for _ in range(max_len):
            logits = self.brain(out)[:, -1, :]      # (1, vocab): last position only
            logits = logits / temperature           # apply temperature

            # top-k filtering: sample only among the k most likely tokens
            k = min(top_k, logits.size(-1))
            values, indices = torch.topk(logits, k)  # both (1, k)
            probs = F.softmax(values, dim=-1)
            choice = torch.multinomial(probs, 1)     # (1, 1) index into the top-k
            next_tok = indices.gather(-1, choice)    # (1, 1) vocabulary id

            tok_id = next_tok.item()
            if tok_id == eos_id:
                break
            result.append(tok_id)
            # append the new token, keeping the context inside the model window
            out = torch.cat([out, next_tok], dim=1)[:, -self.brain.max_seq_len:]
        return self.sensory.decode(result)

    def train(self, user_text: str, bot_text: str) -> None:
        # 1) grow the vocabulary on training text only
        for txt in (user_text, bot_text):
            self.sensory.encode(txt, grow=True)

        # ensure the EOS sentinel exists, then resize once for all new tokens
        if EOS not in self.sensory.stoi:
            idx = len(self.sensory.stoi)
            self.sensory.stoi[EOS] = idx
            self.sensory.itos[idx] = EOS
        self._resize_embeddings()

        # 2) one next-token-prediction step on the exchange; EOS is appended
        # so the model learns where replies end
        combined = f"{user_text} {bot_text} {EOS}"
        ids = torch.tensor(
            self.sensory.encode(combined, grow=False),
            dtype=torch.long,
            device=self.device,
        ).unsqueeze(0)[:, : self.brain.max_seq_len]
        if ids.size(1) < 2:
            return

        # shift by one: predict token t+1 from tokens up to t
        inputs = ids[:, :-1]
        targets = ids[:, 1:]

        self.brain.train()
        logits = self.brain(inputs)
        loss = self.criterion(
            logits.view(-1, logits.size(-1)), targets.view(-1)
        )
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # a tiny meta-learning bump: every 100 updates, nudge the learning
        # rate up 10% so continued vocabulary growth keeps being absorbed
        self.meta_steps += 1
        if self.meta_steps % 100 == 0:
            for g in self.optimizer.param_groups:
                g["lr"] *= 1.1
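

if __name__ == "__main__":
    # Minimal usage sketch, not part of the class proper. It assumes Sensory()
    # seeds stoi/itos with at least the padding token at index 0 and that
    # Brain runs on CPU when CUDA is unavailable; the chat strings below are
    # illustrative only.
    ns = NervousSystem(device="cuda")  # falls back to CPU automatically
    ns.train("hello there", "hi, how can I help you today?")  # grows vocab, one update
    print(ns.generate("hello", max_len=20))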