Ruby/model/brain.py
Dani 0674d51471 fixed an error in the tokenizer.
Updated brain.py to be a tad more aggressive, and added a cleanup function for the brainmap to cleanup.py
2025-04-27 17:03:26 -04:00

92 lines
2.5 KiB
Python

import random
import torch
import torch.nn.functional as F
from model.memory import save_dream
from model.brain_state import model, tokenizer, DEVICE
from model.journal import record_to_journal
from model.trainer import train_on_message
from context.context import get_recent_context
recent_dreams = []
@torch.inference_mode()
def generate_response(max_tokens: int = 50, temperature: float = 1.0) -> str:
    """Sample a reply from the model one token at a time.

    Starts from the ``<start>`` token and samples with temperature until the
    model emits ``<end>`` or ``max_tokens`` tokens have been produced.

    Args:
        max_tokens: Hard cap on the number of generated tokens.
        temperature: Softmax temperature (> 0); higher values sample more
            uniformly.

    Returns:
        The detokenized reply string, or ``"..."`` if the model produced NaNs.
    """
    model.eval()
    end_id = tokenizer.token_to_id("<end>")
    input_ids = torch.tensor([tokenizer.token_to_id("<start>")], device=DEVICE).unsqueeze(0)
    generated = []
    # Special tokens that must never be emitted.  NOTE: <end> is deliberately
    # NOT in this set -- the previous version forbade it, which made the <end>
    # stopping check below unreachable and forced every reply to run the full
    # max_tokens.
    forbidden_tokens = {
        tokenizer.token_to_id("<unk>"),
        tokenizer.token_to_id("<start>"),
        tokenizer.token_to_id("<pad>"),
        tokenizer.token_to_id("<sep>"),
    }
    for _ in range(max_tokens):
        output = model(input_ids)
        if torch.isnan(output).any():
            print("[Brain] Detected NaN in output, restarting generation.")
            return "..."
        next_token_logits = output[:, -1, :].clone()
        # Mask forbidden ids up front instead of rejection-resampling: the old
        # `while next_token in forbidden` loop could spin forever when nearly
        # all probability mass sat on forbidden tokens.
        for tid in forbidden_tokens:
            if tid is not None:
                next_token_logits[:, tid] = float("-inf")
        probs = torch.softmax(next_token_logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        token_id = next_token.item()
        if token_id == end_id:
            break
        generated.append(token_id)
        input_ids = torch.cat([input_ids, next_token], dim=1)
    return tokenizer.detokenize(generated)
def score_sentence(sentence: str) -> float:
    """Score a sentence's quality in [0, 1).

    Rewards longer sentences (length credit capped at 20 words) that have a
    high proportion of distinct words.  Sentences shorter than 5 words, or
    with fewer than half their (smoothed) words distinct, score 0.0.
    """
    tokens = sentence.strip().split()
    n_words = len(tokens)
    if n_words < 5:
        return 0.0
    # +1 smoothing in the denominator keeps the ratio strictly below 1.0.
    distinct_ratio = len(set(tokens)) / (n_words + 1)
    if distinct_ratio < 0.5:
        return 0.0
    length_credit = min(n_words / 20.0, 1.0)
    return distinct_ratio * length_credit
def daydream():
    """Generate one unprompted "dream" sentence and learn from it.

    Seeds the model with a random token id, samples 12 continuation tokens,
    and -- if the detokenized sentence scores above 0.5 -- saves it, records
    it to the journal, trains on it, and remembers it in ``recent_dreams``
    (kept to at most 10 entries).
    """
    model.eval()
    seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
    dream = []
    for _ in range(12):
        out = model(seed)
        logits = out[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        token = torch.multinomial(probs, num_samples=1)
        dream.append(token.item())
        seed = torch.cat([seed, token], dim=1)
    sentence = tokenizer.detokenize(dream)
    score = score_sentence(sentence)
    if score > 0.5:
        save_dream(sentence, score)
        record_to_journal(sentence)
        train_on_message(sentence)
        # BUG FIX: the previous version never appended to recent_dreams, so
        # the pop-based trim below was dead code and the buffer stayed empty.
        recent_dreams.append(sentence)
        if len(recent_dreams) > 10:
            recent_dreams.pop(0)