import random

import torch
import torch.nn.functional as F

from model.memory import save_dream
from model.brain_state import model, tokenizer, DEVICE
from model.journal import record_to_journal
from model.trainer import train_on_message
from context.context import get_recent_context

recent_dreams = []


@torch.no_grad()
def generate_response():
    """Generate a short reply seeded from the most recent context message."""
    model.eval()
    context_texts = get_recent_context(10)
    seed_text = " ".join(context_texts[-1:])
    tokens = tokenizer.tokenize(seed_text)

    input_tensor = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)

    output_tokens = []
    max_tokens = 32

    for _ in range(max_tokens):
        output = model(input_tensor)
        logits = output[:, -1, :].squeeze(0)

        # Apply temperature (soft randomness)
        temperature = 0.8
        logits = logits / temperature

        # Top-k sampling
        k = 10
        topk_logits, topk_indices = torch.topk(logits, k)
        probs = F.softmax(topk_logits, dim=-1)
        next_token = topk_indices[torch.multinomial(probs, 1)].item()

        output_tokens.append(next_token)

        # Feed the sampled token back in for the next autoregressive step
        input_tensor = torch.cat([input_tensor, torch.tensor([[next_token]], device=DEVICE)], dim=1)

        # Optional: stop if next_token maps to period, question mark, or exclamation
        next_char = tokenizer.detokenize([next_token])
        if any(p in next_char for p in [".", "?", "!"]):
            break

    text = tokenizer.detokenize(output_tokens)
    return text


def score_sentence(sentence: str) -> float:
    """Score a sentence by word diversity, scaled by its (capped) length."""
    words = sentence.strip().split()
    length = len(words)
    diversity = len(set(words)) / (length + 1)
    if length < 4:
        return 0.0
    return diversity * min(length, 20)
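
# Illustrative check of the scoring above: "the cat sat on the mat" has 6 words,
# 5 of them unique, so diversity = 5 / 7 ≈ 0.714 and the score is about
# 0.714 * 6 ≈ 4.29, while anything shorter than 4 words scores 0.0.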


def daydream():
    """Sample a short random sequence, score it, and keep it if it scores well."""
    model.eval()
    seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
    dream = []

    for _ in range(12):
        out = model(seed)
        logits = out[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        token = torch.multinomial(probs, num_samples=1)
        dream.append(token.item())
        seed = torch.cat([seed, token], dim=1)

    sentence = tokenizer.detokenize(dream)
    score = score_sentence(sentence)

    if score > 0.45:
        save_dream(sentence, score)
        record_to_journal(sentence)
        train_on_message(sentence)
        recent_dreams.append(sentence)  # keep a rolling history of accepted dreams

    # Cap the rolling history at the 10 most recent dreams
    if len(recent_dreams) > 10:
        recent_dreams.pop(0)
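

# Usage sketch: one way this module might be driven, assuming that importing
# model.brain_state has already initialised `model` and `tokenizer` on DEVICE.
# The number of daydream cycles below is an illustrative choice, not a fixed API.
if __name__ == "__main__":
    print("response:", generate_response())

    for _ in range(3):
        # Each cycle may save, journal, and train on a dream if it scores well enough.
        daydream()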