# Ruby/model/brain.py — response generation and sentence scoring.
from collections import Counter

import torch

from model.brain_state import model, tokenizer, DEVICE
@torch.inference_mode()
def generate_response(max_tokens: int = 50, temperature: float = 1.0) -> str:
    """Sample a response from the model, starting at the <start> token.

    Args:
        max_tokens: Maximum number of tokens to sample before forcing a stop.
        temperature: Softmax temperature; higher values flatten the distribution.

    Returns:
        The detokenized response text, or "..." if the model produced NaNs.
    """
    model.eval()

    # Hoist special-token lookups out of the sampling loop; <end> is both
    # boosted late in generation and used to detect end-of-sequence.
    # (Assumes token_to_id and reverse_vocab agree — TODO confirm.)
    start_id = tokenizer.token_to_id("<start>")
    end_id = tokenizer.token_to_id("<end>")
    forbidden_ids = [
        tokenizer.token_to_id("<unk>"),
        start_id,
        tokenizer.token_to_id("<pad>"),
    ]

    input_ids = torch.tensor([[start_id]], device=DEVICE)
    generated: list[int] = []

    for step in range(max_tokens):
        output = model(input_ids)
        if torch.isnan(output).any():
            print("[Brain] Detected NaN in output, restarting generation.")
            return "..."
        next_token_logits = output[:, -1, :]

        # ✨ Boost <end> token after 10 tokens to encourage stopping
        if step >= 10:
            next_token_logits[:, end_id] += 3.0  # Stronger boost

        # Mask forbidden tokens BEFORE sampling. This yields the same
        # renormalized distribution as the old rejection-sampling loop but
        # cannot hang when forbidden tokens dominate the probability mass.
        next_token_logits[:, forbidden_ids] = float("-inf")

        probs = torch.softmax(next_token_logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        token_id = next_token.item()
        if token_id == end_id:
            break
        generated.append(token_id)
        input_ids = torch.cat([input_ids, next_token], dim=1)

        # ✨ If the last 5 tokens repeat the same token 3+ times, early stop
        if len(generated) >= 5:
            _, top_count = Counter(generated[-5:]).most_common(1)[0]
            if top_count >= 3:
                print("[Brain] Detected heavy repetition, stopping early.")
                break

    text = tokenizer.detokenize(generated)

    # ✨ Simple final sanity check: retry when fewer than half the words are
    # unique. Pass temperature through — the old code dropped it on retry.
    words = text.split()
    if words:
        unique_ratio = len(set(words)) / len(words)
        if unique_ratio < 0.5:
            print("[Brain] Word salad detected, retrying generation...")
            return generate_response(max_tokens, temperature)
    return text
def score_sentence(sentence: str) -> float:
    """Score a sentence's quality in [0.0, 1.0].

    A sentence scores 0.0 unless it has at least 6 words, at least half of
    its words are unique, and it contains at least one recognized verb.
    Otherwise the score rewards diversity and length, plus a flat verb bonus.

    Args:
        sentence: Raw sentence text; split on whitespace.

    Returns:
        A float in [0.0, 1.0].
    """
    words = sentence.strip().split()
    length = len(words)

    # Basic hard filters: too short, or too repetitive.
    if length < 6:
        return 0.0
    diversity = len(set(words)) / length
    if diversity < 0.5:
        return 0.0

    # Require at least one verb (rough allow-list, case-insensitive).
    verbs = {"am", "is", "are", "was", "were", "be", "being", "been",
             "have", "has", "had", "do", "does", "did", "will", "would",
             "shall", "should", "may", "might", "must", "can", "could",
             "say", "said", "go", "went", "gone", "think", "thought",
             "know", "knew", "make", "made", "give", "gave", "take", "took",
             "find", "found", "see", "saw", "come", "came", "run", "ran",
             "walk", "walked"}
    if not any(word.lower() in verbs for word in words):
        return 0.0

    # Reward longer and more diverse sentences. The +0.1 verb bonus is
    # unconditional here: reaching this point guarantees a verb was found
    # (the old `if has_verb` re-check was always true).
    length_bonus = min(length / 20.0, 1.0)
    score = diversity * length_bonus + 0.1
    return min(score, 1.0)