diff --git a/model/brain.py b/model/brain.py
index 8286e5f..41eeeda 100644
--- a/model/brain.py
+++ b/model/brain.py
@@ -31,31 +31,39 @@ def generate_response(max_tokens: int = 50, temperature: float = 1.0):
         next_token_logits = output[:, -1, :]
 
-        # 💬 Boost <end> token score after 10+ tokens to encourage ending
+        # ✨ Boost <end> token after 10 tokens to encourage stopping
         if step >= 10:
             end_token_id = tokenizer.token_to_id("<end>")
-            next_token_logits[:, end_token_id] += 2.0  # Boost end token's chance
+            next_token_logits[:, end_token_id] += 3.0  # Stronger boost
 
         probs = torch.softmax(next_token_logits / temperature, dim=-1)
         next_token = torch.multinomial(probs, num_samples=1)
 
-        # Resample if forbidden token
+        # Block forbidden tokens
         while next_token.item() in forbidden_tokens:
             next_token = torch.multinomial(probs, num_samples=1)
 
         token_id = next_token.item()
 
-        # If <end> token, finish
         if tokenizer.reverse_vocab.get(token_id, "") == "<end>":
             break
 
         generated.append(token_id)
-        input_ids = torch.cat([input_ids, next_token], dim=1)
 
+        # ✨ If the last 5 tokens repeat the same token 3+ times, stop early
+        if len(generated) >= 5:
+            recent = generated[-5:]
+            counts = {}
+            for t in recent:
+                counts[t] = counts.get(t, 0) + 1
+            if any(count >= 3 for count in counts.values()):
+                print("[Brain] Detected heavy repetition, stopping early.")
+                break
+
     text = tokenizer.detokenize(generated)
 
-    # ✨ Simple sanity check: if text has >50% repeated words, retry once
+    # ✨ Simple final sanity check
     words = text.split()
     if words:
         unique_ratio = len(set(words)) / len(words)
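
For reference, the new repetition guard can be exercised in isolation. The sketch below is not part of the patch: the helper name is_heavily_repetitive and its parameters are illustrative, mirroring the 5-token window and 3-repeat threshold used above.

from collections import Counter

# Hypothetical helper mirroring the guard added in generate_response():
# report repetition once any token id appears `limit`+ times in the last `window` tokens.
def is_heavily_repetitive(generated: list[int], window: int = 5, limit: int = 3) -> bool:
    if len(generated) < window:
        return False
    counts = Counter(generated[-window:])
    return max(counts.values()) >= limit

# Example: token id 7 appears 3 times in the last 5 tokens, so generation would stop early.
assert is_heavily_repetitive([1, 7, 2, 7, 3, 7]) is True
assert is_heavily_repetitive([1, 2, 3, 4, 5, 6]) is False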