Fixing how she replies.
parent cde0068725
commit bae6bff5ce
@@ -13,40 +13,32 @@ recent_dreams = []
 @torch.no_grad()
 def generate_response():
     model.eval()
-    context_texts = get_recent_context(10)
-    seed_text = " ".join(context_texts[-1:])
-    tokens = tokenizer.tokenize(seed_text)
-    input_tensor = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
+    seed = torch.randint(0, model.head.out_features, (1, 1), device=DEVICE)
+    input_ids = seed
 
     output_tokens = []
-    max_tokens = 32
 
-    for _ in range(max_tokens):
-        output = model(input_tensor)
-        logits = output[:, -1, :].squeeze(0)
-
-        # Apply temperature (soft randomness)
-        temperature = 0.8
-        logits = logits / temperature
-
-        # Top-k sampling
-        k = 10
-        topk_logits, topk_indices = torch.topk(logits, k)
-        probs = torch.nn.functional.softmax(topk_logits, dim=-1)
-        next_token = topk_indices[torch.multinomial(probs, 1)].item()
-
-        output_tokens.append(next_token)
-        input_tensor = torch.cat([input_tensor, torch.tensor([[next_token]], device=DEVICE)], dim=1)
-
-        # Optional: stop if next_token maps to period, question mark, or exclamation
-        next_char = tokenizer.detokenize([next_token])
-        if any(p in next_char for p in [".", "?", "!"]):
+    for _ in range(50):  # Max 50 tokens (short sentences)
+        output = model(input_ids)
+        next_token_logits = output[:, -1, :] / 0.8  # temperature 0.8
+
+        # Top-K sampling
+        top_k = 40
+        values, indices = torch.topk(next_token_logits, k=top_k)
+        probs = F.softmax(values, dim=-1)
+        sampled_idx = torch.multinomial(probs, num_samples=1)
+        next_token = indices.gather(-1, sampled_idx)
+
+        output_tokens.append(next_token.item())
+        input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=1)
+
+        # Break if punctuation (end of sentence)
+        word = tokenizer.detokenize(next_token.item())
+        if word in [".", "!", "?"]:
             break
 
-    text = tokenizer.detokenize(output_tokens)
-    return text
+    return tokenizer.detokenize(output_tokens)
 
 
 def score_sentence(sentence: str) -> float:
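For reference, the new sampling step in isolation: a minimal, self-contained sketch of the temperature-plus-top-k logic this hunk introduces. The function name sample_top_k and the toy logits are illustrative only, not part of the repository.

import torch
import torch.nn.functional as F

def sample_top_k(logits: torch.Tensor, temperature: float = 0.8, top_k: int = 40) -> torch.Tensor:
    """Temperature-scaled top-k sampling over a (batch, vocab) logits tensor."""
    scaled = logits / temperature                           # <1 sharpens, >1 flattens the distribution
    values, indices = torch.topk(scaled, k=top_k)           # keep only the k most likely tokens
    probs = F.softmax(values, dim=-1)                       # renormalize over the k survivors
    sampled_idx = torch.multinomial(probs, num_samples=1)   # draw one of the k columns
    return indices.gather(-1, sampled_idx)                  # map back to vocabulary ids, shape (batch, 1)

# Toy check: a vocabulary of 100 tokens, one batch row.
logits = torch.randn(1, 100)
token = sample_top_k(logits)
print(token.shape)  # torch.Size([1, 1])

Keeping next_token as a (1, 1) tensor (rather than calling .item() immediately, as the old code did) is what lets the loop concatenate it straight onto input_ids with torch.cat.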
@@ -36,4 +36,6 @@ class Tokenizer:
         return tokens
 
     def detokenize(self, tokens):
+        if isinstance(tokens, int):
+            tokens = [tokens]
         return " ".join(self.reverse_vocab.get(t, "<unk>") for t in tokens)
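This second hunk supports the loop above: detokenize is now called with a bare int (next_token.item()) instead of a list, so it promotes a single token id to a one-element list before joining. A quick illustration of the widened contract, using a made-up three-entry vocabulary:

# Hypothetical mini-tokenizer mirroring the changed method.
class Tokenizer:
    def __init__(self):
        self.reverse_vocab = {0: "hello", 1: "world", 2: "."}

    def detokenize(self, tokens):
        if isinstance(tokens, int):   # accept a single token id...
            tokens = [tokens]         # ...by promoting it to a one-element list
        return " ".join(self.reverse_vocab.get(t, "<unk>") for t in tokens)

tok = Tokenizer()
print(tok.detokenize(0))        # "hello"        (an int now works)
print(tok.detokenize([0, 1]))   # "hello world"  (lists behave as before)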