Attempted to fix more word salad

This commit is contained in:
Dani 2025-04-27 18:15:33 -04:00
parent 3674425a44
commit daf381561e

View File

@@ -31,31 +31,39 @@ def generate_response(max_tokens: int = 50, temperature: float = 1.0):
next_token_logits = output[:, -1, :] next_token_logits = output[:, -1, :]
# 💬 Boost <end> token score after 10+ tokens to encourage ending # ✨ Boost <end> token after 10 tokens to encourage stopping
if step >= 10: if step >= 10:
end_token_id = tokenizer.token_to_id("<end>") end_token_id = tokenizer.token_to_id("<end>")
next_token_logits[:, end_token_id] += 2.0 # Boost end token's chance next_token_logits[:, end_token_id] += 3.0 # Stronger boost
probs = torch.softmax(next_token_logits / temperature, dim=-1) probs = torch.softmax(next_token_logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1) next_token = torch.multinomial(probs, num_samples=1)
# Resample if forbidden token # Block forbidden tokens
while next_token.item() in forbidden_tokens: while next_token.item() in forbidden_tokens:
next_token = torch.multinomial(probs, num_samples=1) next_token = torch.multinomial(probs, num_samples=1)
token_id = next_token.item() token_id = next_token.item()
# If <end> token, finish
if tokenizer.reverse_vocab.get(token_id, "") == "<end>": if tokenizer.reverse_vocab.get(token_id, "") == "<end>":
break break
generated.append(token_id) generated.append(token_id)
input_ids = torch.cat([input_ids, next_token], dim=1) input_ids = torch.cat([input_ids, next_token], dim=1)
# ✨ If the last 5 words repeat the same word 3+ times, early stop
if len(generated) >= 5:
recent = generated[-5:]
counts = {}
for t in recent:
counts[t] = counts.get(t, 0) + 1
if any(count >= 3 for count in counts.values()):
print("[Brain] Detected heavy repetition, stopping early.")
break
text = tokenizer.detokenize(generated) text = tokenizer.detokenize(generated)
# ✨ Simple sanity check: if text has >50% repeated words, retry once # ✨ Simple final sanity check
words = text.split() words = text.split()
if words: if words:
unique_ratio = len(set(words)) / len(words) unique_ratio = len(set(words)) / len(words)