Ruby/model/brain.py
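
"""Generation and daydreaming routines for Ruby.

generate_response() samples a short reply with temperature and top-k sampling,
score_sentence() assigns a rough length/diversity score, and daydream() samples
a short "dream" and saves, journals, and trains on it when the score is high
enough.
"""
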
import random

import torch
import torch.nn.functional as F

from model.memory import save_dream
from model.brain_state import model, tokenizer, DEVICE
from model.journal import record_to_journal
from model.trainer import train_on_message
from context.context import get_recent_context

# Rolling list of recent dreams; daydream() drops the oldest entry once it grows past 10.
recent_dreams = []

@torch.no_grad()
def generate_response():
    """Sample a short reply: temperature 0.8, top-k (k=40), at most 50 tokens."""
    model.eval()
    # Start from a single random token id (head.out_features == vocab size).
    seed = torch.randint(0, model.head.out_features, (1, 1), device=DEVICE)
    input_ids = seed
    output_tokens = []

    for _ in range(50):  # max 50 tokens (short sentences)
        output = model(input_ids)
        next_token_logits = output[:, -1, :] / 0.8  # temperature 0.8

        # Top-k sampling: keep the 40 most likely tokens, renormalise, sample.
        top_k = 40
        values, indices = torch.topk(next_token_logits, k=top_k)
        probs = F.softmax(values, dim=-1)
        sampled_idx = torch.multinomial(probs, num_samples=1)
        next_token = indices.gather(-1, sampled_idx)

        output_tokens.append(next_token.item())
        input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=1)

        # Stop at sentence-ending punctuation.
        word = tokenizer.detokenize(next_token.item())
        if word in [".", "!", "?"]:
            break

    return tokenizer.detokenize(output_tokens)
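
# Usage sketch (illustrative only; nothing in this excerpt calls it): the
# function returns whatever tokenizer.detokenize() produces for the sampled
# token ids, so a chat loop could simply do
#
#     reply = generate_response()
#     print(reply)
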

def score_sentence(sentence: str) -> float:
    """Crude quality score: lexical diversity times length, capped at 20 words."""
    words = sentence.strip().split()
    length = len(words)
    diversity = len(set(words)) / (length + 1)
    if length < 4:
        return 0.0
    return diversity * min(length, 20)
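
# Worked example of the scoring rule (plain arithmetic, no project code):
#
#     score_sentence("the quick brown fox jumps over")
#     # 6 words, all unique -> diversity = 6 / 7 ≈ 0.857; score ≈ 0.857 * 6 ≈ 5.14
#
#     score_sentence("hello there")
#     # fewer than 4 words -> 0.0
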

def daydream():
    """Free-run the model for 12 tokens; keep and train on high-scoring dreams."""
    model.eval()
    # Random starting token id from the tokenizer's vocabulary.
    seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
    dream = []

    for _ in range(12):
        out = model(seed)
        logits = out[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        token = torch.multinomial(probs, num_samples=1)
        dream.append(token.item())
        seed = torch.cat([seed, token], dim=1)

    sentence = tokenizer.detokenize(dream)
    score = score_sentence(sentence)

    if score > 0.45:
        # Persist good dreams and reinforce them with a training pass.
        save_dream(sentence, score)
        record_to_journal(sentence)
        train_on_message(sentence)

    if len(recent_dreams) > 10:
        recent_dreams.pop(0)
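
# Hypothetical scheduling sketch (not part of this module): daydream() could be
# driven from a background loop; the names dream_loop and interval_seconds are
# assumptions for illustration.
#
#     import threading, time
#
#     def dream_loop(interval_seconds: float = 60.0):
#         while True:
#             daydream()
#             time.sleep(interval_seconds)
#
#     threading.Thread(target=dream_loop, daemon=True).start()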