import torch

from model.brain_state import model, tokenizer, DEVICE


@torch.inference_mode()
def generate_response(max_tokens: int = 50, temperature: float = 1.0):
    """Sample a response autoregressively, starting from the <start> token."""
    model.eval()

    input_ids = torch.tensor([tokenizer.token_to_id("<start>")], device=DEVICE).unsqueeze(0)
    generated = []
    forbidden_tokens = {
        tokenizer.token_to_id("<unk>"),
        tokenizer.token_to_id("<start>"),
        tokenizer.token_to_id("<pad>"),
    }
    # Look up <end> once; it is used for both the boost and the stop check.
    end_token_id = tokenizer.token_to_id("<end>")

    for step in range(max_tokens):
        output = model(input_ids)
        if torch.isnan(output).any():
            print("[Brain] Detected NaN in output, aborting generation.")
            return "..."

        next_token_logits = output[:, -1, :]

        # Block forbidden tokens by masking their logits to -inf so they get
        # zero probability after softmax. (The previous resample-until-valid
        # loop could spin indefinitely if a forbidden token dominated.)
        for forbidden in forbidden_tokens:
            next_token_logits[:, forbidden] = float("-inf")

        # ✨ Boost <end> token after 10 tokens to encourage stopping
        if step >= 10:
            next_token_logits[:, end_token_id] += 3.0  # Stronger boost

        probs = torch.softmax(next_token_logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        token_id = next_token.item()

        if token_id == end_token_id:
            break

        generated.append(token_id)
        input_ids = torch.cat([input_ids, next_token], dim=1)

        # ✨ If any token appears 3+ times among the last 5 tokens, stop early
        if len(generated) >= 5:
            recent = generated[-5:]
            counts = {}
            for t in recent:
                counts[t] = counts.get(t, 0) + 1
            if any(count >= 3 for count in counts.values()):
                print("[Brain] Detected heavy repetition, stopping early.")
                break

    text = tokenizer.detokenize(generated)

    # ✨ Simple final sanity check
    words = text.split()
    if words:
        unique_ratio = len(set(words)) / len(words)
        if unique_ratio < 0.5:
            print("[Brain] Word salad detected, retrying generation...")
            # Pass temperature through so the retry samples the same way.
            return generate_response(max_tokens, temperature)

    return text


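# A minimal sketch (not part of the original module) of why dividing logits by
# `temperature` matters in generate_response: values below 1.0 sharpen the
# softmax toward the argmax, values above 1.0 flatten it toward uniform.
# Kept as a comment so importing this module stays side-effect free:
#
#     import torch
#     logits = torch.tensor([[2.0, 1.0, 0.5]])
#     for temp in (0.5, 1.0, 2.0):
#         print(temp, torch.softmax(logits / temp, dim=-1))
#     # temp=0.5 -> ~[0.84, 0.11, 0.04]  (sharper, near-greedy)
#     # temp=1.0 -> ~[0.63, 0.23, 0.14]
#     # temp=2.0 -> ~[0.48, 0.29, 0.23]  (flatter, more random)

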
def score_sentence(sentence: str) -> float:
    """Heuristically score a sentence in [0, 1] by length, diversity, and verb presence."""
    words = sentence.strip().split()
    unique = set(words)
    length = len(words)

    # Basic hard filters
    if length < 6:
        return 0.0
    if len(unique) / length < 0.5:
        return 0.0

    # Check for verb presence
    verbs = {"am", "is", "are", "was", "were", "be", "being", "been",
             "have", "has", "had", "do", "does", "did", "will", "would",
             "shall", "should", "may", "might", "must", "can", "could",
             "say", "said", "go", "went", "gone", "think", "thought",
             "know", "knew", "make", "made", "give", "gave", "take", "took",
             "find", "found", "see", "saw", "come", "came", "run", "ran", "walk", "walked"}

    has_verb = any(word.lower() in verbs for word in words)

    if not has_verb:
        return 0.0

    # Reward longer and more diverse sentences
    diversity = len(unique) / length
    length_bonus = min(length / 20.0, 1.0)

    score = diversity * length_bonus

    # Flat bonus for containing a verb (has_verb is always True past the
    # filter above, so no re-check is needed)
    score += 0.1  # Bonus if there's an action!

    return min(score, 1.0)
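

# A hedged usage sketch: the module never shows how generate_response and
# score_sentence are meant to be combined, so this best-of-n loop is an
# assumption, not the project's confirmed pipeline. The candidate count and
# temperature below are arbitrary.
if __name__ == "__main__":
    candidates = [generate_response(max_tokens=50, temperature=0.9) for _ in range(5)]
    best = max(candidates, key=score_sentence)
    print(f"[Brain] Best of {len(candidates)} candidates: {best!r}")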