import random
|
|
import re
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from model.memory import save_dream
|
|
from model.brain_state import model, tokenizer, DEVICE
|
|
from model.journal import record_to_journal
|
|
from model.trainer import train_on_message
|
|
from context.context import get_recent_context
|
|
|
|
recent_dreams = []
|
|
|
|
|
|
@torch.inference_mode()
|
|
def generate_response(max_tokens: int = 50, temperature: float = 1.0):
|
|
model.eval()
|
|
|
|
input_ids = torch.tensor([tokenizer.token_to_id("<start>")], device=DEVICE).unsqueeze(0)
|
|
generated = []
|
|
forbidden_tokens = {
|
|
tokenizer.token_to_id("<unk>"),
|
|
tokenizer.token_to_id("<start>"),
|
|
tokenizer.token_to_id("<pad>")
|
|
}
|
|
|
|
for step in range(max_tokens):
|
|
output = model(input_ids)
|
|
if torch.isnan(output).any():
|
|
print("[Brain] Detected NaN in output, restarting generation.")
|
|
return "..."
|
|
|
|
next_token_logits = output[:, -1, :]
|
|
|
|
# 💬 Boost <end> token score after 10+ tokens to encourage ending
|
|
if step >= 10:
|
|
end_token_id = tokenizer.token_to_id("<end>")
|
|
next_token_logits[:, end_token_id] += 2.0 # Boost end token's chance
|
|
|
|
probs = torch.softmax(next_token_logits / temperature, dim=-1)
|
|
next_token = torch.multinomial(probs, num_samples=1)
|
|
|
|
# Resample if forbidden token
|
|
while next_token.item() in forbidden_tokens:
|
|
next_token = torch.multinomial(probs, num_samples=1)
|
|
|
|
token_id = next_token.item()
|
|
|
|
# If <end> token, finish
|
|
if tokenizer.reverse_vocab.get(token_id, "") == "<end>":
|
|
break
|
|
|
|
generated.append(token_id)
|
|
|
|
input_ids = torch.cat([input_ids, next_token], dim=1)
|
|
|
|
text = tokenizer.detokenize(generated)
|
|
|
|
# ✨ Simple sanity check: if text has >50% repeated words, retry once
|
|
words = text.split()
|
|
if words:
|
|
unique_ratio = len(set(words)) / len(words)
|
|
if unique_ratio < 0.5:
|
|
print("[Brain] Word salad detected, retrying generation...")
|
|
return generate_response(max_tokens)
|
|
|
|
return text
|
|
|
|
|
|
def score_sentence(sentence: str) -> float:
    """Score a generated sentence in [0.0, 1.0].

    Hard-rejects (returns 0.0) sentences that are shorter than 6 words,
    more than 50% repeated words, or contain no recognizable verb.
    Otherwise the score is word-diversity times a length bonus that
    saturates at 20 words, plus a flat 0.1 verb bonus, capped at 1.0.
    """
    words = sentence.strip().split()
    unique = set(words)
    length = len(words)

    # Basic hard filters (length < 6 also covers the empty sentence,
    # avoiding a division by zero below).
    if length < 6:
        return 0.0
    if len(unique) / length < 0.5:
        return 0.0

    # Check for verb presence (small closed list of common English verbs).
    verbs = {"am", "is", "are", "was", "were", "be", "being", "been",
             "have", "has", "had", "do", "does", "did", "will", "would",
             "shall", "should", "may", "might", "must", "can", "could",
             "say", "said", "go", "went", "gone", "think", "thought",
             "know", "knew", "make", "made", "give", "gave", "take", "took",
             "find", "found", "see", "saw", "come", "came", "run", "ran", "walk", "walked"}

    if not any(word.lower() in verbs for word in words):
        return 0.0

    # Reward longer and more diverse sentences.
    diversity = len(unique) / length
    length_bonus = min(length / 20.0, 1.0)

    # Flat verb bonus. (The original re-tested `has_verb` here, but every
    # sentence reaching this point contains a verb — the check was dead code.)
    score = diversity * length_bonus + 0.1

    return min(score, 1.0)
|
|
|
|
|
|
def daydream():
|
|
model.eval()
|
|
seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
|
|
dream = []
|
|
|
|
for _ in range(12):
|
|
out = model(seed)
|
|
logits = out[:, -1, :]
|
|
probs = F.softmax(logits, dim=-1)
|
|
token = torch.multinomial(probs, num_samples=1)
|
|
dream.append(token.item())
|
|
seed = torch.cat([seed, token], dim=1)
|
|
|
|
sentence = tokenizer.detokenize(dream)
|
|
score = score_sentence(sentence)
|
|
|
|
if score > 0.5:
|
|
save_dream(sentence, score)
|
|
record_to_journal(sentence)
|
|
train_on_message(sentence)
|
|
|
|
if len(recent_dreams) > 10:
|
|
recent_dreams.pop(0)
|