import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

from sensory import Sensory
from brain import Brain


class NervousSystem:
"""Wraps the Brain, handles token growth, generation and on-the-fly training.""" # noqa: E501
|
|
|
|
def __init__(self, device: str = "cuda"):
|
|
self.device = torch.device(device if torch.cuda.is_available() else "cpu") # noqa: E501
|
|
self.sensory = Sensory()
|
|
vocab_size = len(self.sensory.stoi)
|
|
self.brain = Brain(vocab_size).to(self.device)
|
|
|
|
self.optimizer = optim.Adam(self.brain.parameters(), lr=1e-4)
|
|
self.criterion = CrossEntropyLoss(ignore_index=0)
|
|
self.meta_steps = 0
|
|
|
|
    def _resize_embeddings(self) -> None:
        new_size = len(self.sensory.stoi)
        old_emb = self.brain.token_emb
        if new_size == old_emb.num_embeddings:
            return  # vocabulary did not grow, nothing to rebuild

        # rebuild token embeddings, copying over the existing rows
        self.brain.token_emb = torch.nn.Embedding(
            new_size, old_emb.embedding_dim
        ).to(self.device)
        with torch.no_grad():
            self.brain.token_emb.weight[: old_emb.num_embeddings] = old_emb.weight  # noqa: E501

        # rebuild output head, copying over the existing rows
        old_out = self.brain.fc_out
        self.brain.fc_out = torch.nn.Linear(
            old_emb.embedding_dim, new_size
        ).to(self.device)
        with torch.no_grad():
            self.brain.fc_out.weight[: old_out.out_features] = old_out.weight
            self.brain.fc_out.bias[: old_out.out_features] = old_out.bias

        # the old optimizer still references the replaced tensors, so rebuild
        # it (this drops Adam's running moments but keeps the current lr)
        self.optimizer = optim.Adam(
            self.brain.parameters(), lr=self.optimizer.param_groups[0]["lr"]
        )

    def generate(self, prompt: str, max_len: int = 50,
                 temperature: float = 0.8, top_k: int = 50) -> str:
        self.brain.eval()
        raw_ids = self.sensory.encode(prompt, grow=False)[-self.brain.max_seq_len:]  # noqa: E501
        out = torch.tensor(raw_ids, dtype=torch.long, device=self.device).unsqueeze(0)  # noqa: E501

        result = []
        with torch.no_grad():
            for _ in range(max_len):
                # never feed the model more than max_seq_len tokens of context
                context = out[:, -self.brain.max_seq_len:]
                logits = self.brain(context)[:, -1, :]
                # apply temperature
                logits = logits / temperature
                # top-k filtering (clamp k so it never exceeds the vocab size)
                values, indices = torch.topk(logits, min(top_k, logits.size(-1)))
                probs = F.softmax(values, dim=-1)
                sample = torch.multinomial(probs, 1)   # (1, 1) index into top-k
                next_tok = indices.gather(-1, sample)  # (1, 1) actual token id
                tok_id = next_tok.item()
                if tok_id == self.sensory.stoi["<eos>"]:
                    break
                result.append(tok_id)
                out = torch.cat([out, next_tok], dim=1)

        return self.sensory.decode(result)

    def train(self, user_text: str, bot_text: str) -> None:
        # 1) grow the vocabulary on training text only
        for txt in (user_text, bot_text):
            _ = self.sensory.encode(txt, grow=True)
        self._resize_embeddings()

        # 2) make sure the <sep> token exists
        if "<sep>" not in self.sensory.stoi:
            idx = len(self.sensory.stoi)
            self.sensory.stoi["<sep>"] = idx
            self.sensory.itos[idx] = "<sep>"
            self._resize_embeddings()

        # 3) build the next-token prediction pair
        combined = f"{user_text} <sep> {bot_text}"
        ids = torch.tensor(
            self.sensory.encode(combined, grow=False), dtype=torch.long, device=self.device  # noqa: E501
        ).unsqueeze(0)

        if ids.size(1) < 2:
            return

        inputs = ids[:, :-1]
        targets = ids[:, 1:]

        self.brain.train()
        logits = self.brain(inputs)
        loss = self.criterion(
            logits.view(-1, logits.size(-1)), targets.view(-1)
        )

        # a tiny meta-learning bump: every 100th update runs at a 10% higher lr
        self.meta_steps += 1
        bumped = self.meta_steps % 100 == 0
        if bumped:
            for g in self.optimizer.param_groups:
                g["lr"] *= 1.1

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if bumped:
            # restore the base learning rate after the boosted step
            for g in self.optimizer.param_groups:
                g["lr"] /= 1.1
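

# A minimal usage sketch: it assumes Sensory provides encode()/decode() and
# stoi/itos with "<eos>" already registered, and that Brain exposes
# max_seq_len, as referenced above. The example texts are illustrative only.
if __name__ == "__main__":
    ns = NervousSystem(device="cuda")  # falls back to CPU automatically
    ns.train("hello there", "hi, how can I help you today?")
    print(ns.generate("hello", max_len=20))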