Attempted to fix more word salad
This commit is contained in:
parent
3674425a44
commit
daf381561e
@ -31,31 +31,39 @@ def generate_response(max_tokens: int = 50, temperature: float = 1.0):
|
|||||||
|
|
||||||
next_token_logits = output[:, -1, :]
|
next_token_logits = output[:, -1, :]
|
||||||
|
|
||||||
# 💬 Boost <end> token score after 10+ tokens to encourage ending
|
# ✨ Boost <end> token after 10 tokens to encourage stopping
|
||||||
if step >= 10:
|
if step >= 10:
|
||||||
end_token_id = tokenizer.token_to_id("<end>")
|
end_token_id = tokenizer.token_to_id("<end>")
|
||||||
next_token_logits[:, end_token_id] += 2.0 # Boost end token's chance
|
next_token_logits[:, end_token_id] += 3.0 # Stronger boost
|
||||||
|
|
||||||
probs = torch.softmax(next_token_logits / temperature, dim=-1)
|
probs = torch.softmax(next_token_logits / temperature, dim=-1)
|
||||||
next_token = torch.multinomial(probs, num_samples=1)
|
next_token = torch.multinomial(probs, num_samples=1)
|
||||||
|
|
||||||
# Resample if forbidden token
|
# Block forbidden tokens
|
||||||
while next_token.item() in forbidden_tokens:
|
while next_token.item() in forbidden_tokens:
|
||||||
next_token = torch.multinomial(probs, num_samples=1)
|
next_token = torch.multinomial(probs, num_samples=1)
|
||||||
|
|
||||||
token_id = next_token.item()
|
token_id = next_token.item()
|
||||||
|
|
||||||
# If <end> token, finish
|
|
||||||
if tokenizer.reverse_vocab.get(token_id, "") == "<end>":
|
if tokenizer.reverse_vocab.get(token_id, "") == "<end>":
|
||||||
break
|
break
|
||||||
|
|
||||||
generated.append(token_id)
|
generated.append(token_id)
|
||||||
|
|
||||||
input_ids = torch.cat([input_ids, next_token], dim=1)
|
input_ids = torch.cat([input_ids, next_token], dim=1)
|
||||||
|
|
||||||
|
# ✨ If any single token appears 3+ times within the last 5 generated tokens, stop early
|
||||||
|
if len(generated) >= 5:
|
||||||
|
recent = generated[-5:]
|
||||||
|
counts = {}
|
||||||
|
for t in recent:
|
||||||
|
counts[t] = counts.get(t, 0) + 1
|
||||||
|
if any(count >= 3 for count in counts.values()):
|
||||||
|
print("[Brain] Detected heavy repetition, stopping early.")
|
||||||
|
break
|
||||||
|
|
||||||
text = tokenizer.detokenize(generated)
|
text = tokenizer.detokenize(generated)
|
||||||
|
|
||||||
# ✨ Simple sanity check: if text has >50% repeated words, retry once
|
# ✨ Simple final sanity check
|
||||||
words = text.split()
|
words = text.split()
|
||||||
if words:
|
if words:
|
||||||
unique_ratio = len(set(words)) / len(words)
|
unique_ratio = len(set(words)) / len(words)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user