import torch

from model.brain_state import model, tokenizer, DEVICE

# Maximum number of full regeneration attempts when the sampled text is
# rejected as "word salad".  The original code retried via unbounded
# recursion (and silently dropped the temperature argument on retry).
_MAX_RETRIES = 3


def _sample_once(max_tokens: int, temperature: float):
    """Run one sampling pass of the model.

    Returns the detokenized text, or None if the model produced NaNs
    (caller decides how to recover).

    NOTE(review): every special-token string below is the empty string ""
    — these look stripped (probably "<sos>"/"<eos>"/"<pad>"-style names
    eaten by markup removal).  Confirm against the tokenizer's special
    tokens; the literals are preserved as-is here.
    """
    model.eval()
    # Seed the sequence with the start token, shape (1, 1).
    input_ids = torch.tensor([tokenizer.token_to_id("")], device=DEVICE).unsqueeze(0)
    generated = []

    forbidden_tokens = {
        tokenizer.token_to_id(""),
        tokenizer.token_to_id(""),
        tokenizer.token_to_id(""),
    }
    end_token_id = tokenizer.token_to_id("")
    # Mask forbidden ids out of the distribution up front instead of
    # rejection-resampling: the original `while next_token in forbidden`
    # loop could spin forever once probability mass concentrated on a
    # forbidden id (made worse by the end-token boost below).  The end
    # token must never be masked, or generation could not stop.
    masked_ids = forbidden_tokens - {end_token_id}

    for step in range(max_tokens):
        output = model(input_ids)
        if torch.isnan(output).any():
            print("[Brain] Detected NaN in output, restarting generation.")
            return None

        # .clone(): the slice is a view into the model output; the in-place
        # edits below must not mutate the model's own tensor.
        next_token_logits = output[:, -1, :].clone()

        for token in masked_ids:
            next_token_logits[:, token] = float("-inf")

        # Boost the end token after 10 tokens to encourage stopping.
        if step >= 10:
            next_token_logits[:, end_token_id] += 3.0

        probs = torch.softmax(next_token_logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        token_id = next_token.item()

        # Stop on the end token (or any id missing from reverse_vocab,
        # which the .get default maps to the same stripped "" string).
        if tokenizer.reverse_vocab.get(token_id, "") == "":
            break

        generated.append(token_id)
        input_ids = torch.cat([input_ids, next_token], dim=1)

        # Early-stop if any token appears 3+ times in the last 5 tokens.
        if len(generated) >= 5:
            recent = generated[-5:]
            if max(recent.count(t) for t in set(recent)) >= 3:
                print("[Brain] Detected heavy repetition, stopping early.")
                break

    return tokenizer.detokenize(generated)


@torch.inference_mode()
def generate_response(max_tokens: int = 50, temperature: float = 1.0):
    """Autoregressively sample a response from the model.

    Args:
        max_tokens: hard cap on the number of sampled tokens per attempt.
        temperature: softmax temperature; higher = more random.

    Returns:
        The generated text; "..." if the model emitted NaNs.  Retries up
        to _MAX_RETRIES times when the output is low-diversity "word
        salad" (unlike the original recursive retry, the temperature is
        kept and the retry count is bounded).
    """
    text = "..."
    for _ in range(_MAX_RETRIES):
        text = _sample_once(max_tokens, temperature)
        if text is None:
            return "..."

        # Final sanity check: reject outputs where fewer than half the
        # words are unique.
        words = text.split()
        if words and len(set(words)) / len(words) < 0.5:
            print("[Brain] Word salad detected, retrying generation...")
            continue
        return text

    # Retries exhausted: return the last attempt rather than recursing
    # forever as the original did.
    return text


def score_sentence(sentence: str) -> float:
    """Score a sentence in [0.0, 1.0] by length, diversity, and verb use.

    Hard zeros: fewer than 6 words, less than 50% unique words, or no
    recognized verb.  Otherwise the score is
    (unique ratio) * (length bonus, capped at 20 words) + 0.1 verb bonus,
    clamped to 1.0.
    """
    words = sentence.strip().split()
    unique = set(words)
    length = len(words)

    # Basic hard filters.
    if length < 6:
        return 0.0
    if len(unique) / length < 0.5:
        return 0.0

    # Check for verb presence against a small closed list.
    verbs = {"am", "is", "are", "was", "were",
             "be", "being", "been", "have", "has", "had", "do", "does",
             "did", "will", "would", "shall", "should", "may", "might",
             "must", "can", "could", "say", "said", "go", "went", "gone",
             "think", "thought", "know", "knew", "make", "made", "give",
             "gave", "take", "took", "find", "found", "see", "saw",
             "come", "came", "run", "ran", "walk", "walked"}
    if not any(word.lower() in verbs for word in words):
        return 0.0

    # Reward longer and more diverse sentences.
    diversity = len(unique) / length
    length_bonus = min(length / 20.0, 1.0)
    # The original guarded the +0.1 with `if has_verb:`, but that branch
    # is unreachable when has_verb is False (we returned 0.0 above), so
    # the verb bonus is unconditional here — same behavior, no dead test.
    score = diversity * length_bonus + 0.1

    return min(score, 1.0)