From bae6bff5cec69460680673a206b35d6f9dfa763a Mon Sep 17 00:00:00 2001 From: Dani Date: Sun, 27 Apr 2025 14:03:37 -0400 Subject: [PATCH] Fixing how she replies. --- model/brain.py | 42 +++++++++++++++++------------------------- model/tokenizer.py | 2 ++ 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/model/brain.py b/model/brain.py index 2f265ec..a6c3b48 100644 --- a/model/brain.py +++ b/model/brain.py @@ -13,40 +13,32 @@ recent_dreams = [] @torch.no_grad() def generate_response(): model.eval() - context_texts = get_recent_context(10) - seed_text = " ".join(context_texts[-1:]) - tokens = tokenizer.tokenize(seed_text) - - input_tensor = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0) - + seed = torch.randint(0, model.head.out_features, (1, 1), device=DEVICE) + input_ids = seed output_tokens = [] - max_tokens = 32 - for _ in range(max_tokens): - output = model(input_tensor) - logits = output[:, -1, :].squeeze(0) + for _ in range(50): # Max 50 tokens (short sentences) + output = model(input_ids) + next_token_logits = output[:, -1, :] / 0.8 # temperature 0.8 - # Apply temperature (soft randomness) - temperature = 0.8 - logits = logits / temperature + # Top-K Sampling + top_k = 40 + values, indices = torch.topk(next_token_logits, k=top_k) + probs = F.softmax(values, dim=-1) + sampled_idx = torch.multinomial(probs, num_samples=1) - # Top-k sampling - k = 10 - topk_logits, topk_indices = torch.topk(logits, k) - probs = torch.nn.functional.softmax(topk_logits, dim=-1) - next_token = topk_indices[torch.multinomial(probs, 1)].item() + next_token = indices.gather(-1, sampled_idx) - output_tokens.append(next_token) + output_tokens.append(next_token.item()) - input_tensor = torch.cat([input_tensor, torch.tensor([[next_token]], device=DEVICE)], dim=1) + input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=1) - # Optional: stop if next_token maps to period, question mark, or exclamation - next_char = tokenizer.detokenize([next_token]) - if any(p in next_char for p in [".", "?", "!"]): + # Break if punctuation (end of sentence) + word = tokenizer.detokenize(next_token.item()) + if word in [".", "!", "?"]: break - text = tokenizer.detokenize(output_tokens) - return text + return tokenizer.detokenize(output_tokens) def score_sentence(sentence: str) -> float: diff --git a/model/tokenizer.py b/model/tokenizer.py index 5983b3f..663bbb2 100644 --- a/model/tokenizer.py +++ b/model/tokenizer.py @@ -36,4 +36,6 @@ class Tokenizer: return tokens def detokenize(self, tokens): + if isinstance(tokens, int): + tokens = [tokens] return " ".join(self.reverse_vocab.get(t, "") for t in tokens)