Attempted to fix more word salad
This commit is contained in:
parent
3674425a44
commit
daf381561e
@ -31,31 +31,39 @@ def generate_response(max_tokens: int = 50, temperature: float = 1.0):
|
||||
|
||||
next_token_logits = output[:, -1, :]
|
||||
|
||||
# 💬 Boost <end> token score after 10+ tokens to encourage ending
|
||||
# ✨ Boost <end> token after 10 tokens to encourage stopping
|
||||
if step >= 10:
|
||||
end_token_id = tokenizer.token_to_id("<end>")
|
||||
next_token_logits[:, end_token_id] += 2.0 # Boost end token's chance
|
||||
next_token_logits[:, end_token_id] += 3.0 # Stronger boost
|
||||
|
||||
probs = torch.softmax(next_token_logits / temperature, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
# Resample if forbidden token
|
||||
# Block forbidden tokens
|
||||
while next_token.item() in forbidden_tokens:
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
token_id = next_token.item()
|
||||
|
||||
# If <end> token, finish
|
||||
if tokenizer.reverse_vocab.get(token_id, "") == "<end>":
|
||||
break
|
||||
|
||||
generated.append(token_id)
|
||||
|
||||
input_ids = torch.cat([input_ids, next_token], dim=1)
|
||||
|
||||
# ✨ If the last 5 words repeat the same word 3+ times, early stop
|
||||
if len(generated) >= 5:
|
||||
recent = generated[-5:]
|
||||
counts = {}
|
||||
for t in recent:
|
||||
counts[t] = counts.get(t, 0) + 1
|
||||
if any(count >= 3 for count in counts.values()):
|
||||
print("[Brain] Detected heavy repetition, stopping early.")
|
||||
break
|
||||
|
||||
text = tokenizer.detokenize(generated)
|
||||
|
||||
# ✨ Simple sanity check: if text has >50% repeated words, retry once
|
||||
# ✨ Simple final sanity check
|
||||
words = text.split()
|
||||
if words:
|
||||
unique_ratio = len(set(words)) / len(words)
|
||||
|
Loading…
x
Reference in New Issue
Block a user