import random

import torch
import torch.nn.functional as F

from model.memory import save_dream
from model.brain_state import model, tokenizer, DEVICE
from model.journal import record_to_journal
from model.trainer import train_on_message
from context.context import get_recent_context

# NOTE(review): every special-token literal in this module is the empty string
# "", which looks like angle-bracket token names (e.g. "<start>", "<end>") were
# stripped at some point. Confirm against the tokenizer vocabulary.

# Number of tokens sampled per daydream.
_DREAM_LENGTH = 12
# Minimum score_sentence() value for a daydream to be kept and trained on.
_DREAM_SCORE_THRESHOLD = 0.5
# How many accepted daydreams recent_dreams retains (oldest dropped first).
_MAX_RECENT_DREAMS = 10

# Most recently accepted daydream sentences, oldest first.
recent_dreams = []


@torch.inference_mode()
def generate_response(max_tokens: int = 50, temperature: float = 1.0):
    """Sample up to ``max_tokens`` tokens from the model and return decoded text.

    Generation starts from the single id ``tokenizer.token_to_id("")``. Tokens
    in ``forbidden_tokens`` are never emitted. Generation stops early when a
    sampled id's ``tokenizer.reverse_vocab`` entry is ``""`` (or the id is
    missing from ``reverse_vocab``).

    Args:
        max_tokens: Maximum number of tokens to generate.
        temperature: Softmax temperature. Values <= 0 pick the most likely
            token greedily instead of dividing by zero.

    Returns:
        The detokenized text, or ``"..."`` if the model output contained NaN.
    """
    model.eval()
    input_ids = torch.tensor([tokenizer.token_to_id("")], device=DEVICE).unsqueeze(0)
    generated = []

    forbidden_tokens = {
        tokenizer.token_to_id(""),
        tokenizer.token_to_id(""),
        tokenizer.token_to_id(""),
        tokenizer.token_to_id(""),
        tokenizer.token_to_id(""),
    }
    # Drop lookups that failed. Assumes token_to_id may return None for
    # unknown strings; TODO confirm against the tokenizer.
    forbidden_ids = sorted(int(tid) for tid in forbidden_tokens if tid is not None)
    forbidden_index = None  # built once the vocabulary size is known

    for _ in range(max_tokens):
        output = model(input_ids)
        if torch.isnan(output).any():
            # Returns a placeholder; generation is not actually retried.
            print("[Brain] Detected NaN in output, restarting generation.")
            return "..."

        next_token_logits = output[:, -1, :]

        if forbidden_index is None:
            vocab_size = next_token_logits.size(-1)
            forbidden_index = torch.tensor(
                [tid for tid in forbidden_ids if 0 <= tid < vocab_size],
                dtype=torch.long,
                device=next_token_logits.device,
            )
        if forbidden_index.numel():
            # Masking samples from the same renormalised distribution as
            # "resample until not forbidden", but cannot spin forever when the
            # forbidden tokens carry (almost) all of the probability mass.
            next_token_logits = next_token_logits.index_fill(
                1, forbidden_index, float("-inf")
            )
        if not torch.isfinite(next_token_logits).any():
            break  # every candidate is forbidden; nothing left to sample

        if temperature <= 0:
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)
        else:
            probs = torch.softmax(next_token_logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        token_id = next_token.item()
        if tokenizer.reverse_vocab.get(token_id, "") == "":
            break

        generated.append(token_id)
        input_ids = torch.cat([input_ids, next_token], dim=1)

    return tokenizer.detokenize(generated)


def score_sentence(sentence: str) -> float:
    """Score a sentence by word variety and length; the result is in [0, 1).

    Words are whitespace-separated. Sentences with fewer than 5 words, or
    whose unique-word ratio is below 0.5, score 0.0. Otherwise the score is
    the unique-word ratio scaled by ``min(word_count / 20, 1)``.
    """
    words = sentence.strip().split()
    length = len(words)
    if length < 5:
        return 0.0

    # The +1 in the denominator is kept from the original scoring (thresholds
    # elsewhere depend on it): the ratio approaches but never reaches 1.0.
    unique_ratio = len(set(words)) / (length + 1)
    if unique_ratio < 0.5:
        return 0.0
    return unique_ratio * min(length / 20.0, 1.0)


def daydream():
    """Free-run the model from a random token and keep the result if it scores well.

    Samples ``_DREAM_LENGTH`` tokens starting from a random id in
    ``[0, tokenizer.next_id - 1]`` and detokenizes them (the seed id itself is
    not included). If ``score_sentence`` exceeds ``_DREAM_SCORE_THRESHOLD``,
    the sentence is saved, journaled, trained on, and appended to
    ``recent_dreams``, which is trimmed to the last ``_MAX_RECENT_DREAMS``
    entries.
    """
    model.eval()

    # Sampling only needs forward passes; skip building an autograd graph.
    # no_grad (not inference_mode) is used because training follows below.
    with torch.no_grad():
        seed = torch.tensor(
            [random.randint(0, tokenizer.next_id - 1)], device=DEVICE
        ).unsqueeze(0)
        dream = []
        for _ in range(_DREAM_LENGTH):
            out = model(seed)
            logits = out[:, -1, :]
            if torch.isnan(logits).any():
                # multinomial would raise on NaN probabilities; drop this dream.
                print("[Brain] Detected NaN in daydream output, skipping dream.")
                return
            probs = F.softmax(logits, dim=-1)
            token = torch.multinomial(probs, num_samples=1)
            dream.append(token.item())
            seed = torch.cat([seed, token], dim=1)

    sentence = tokenizer.detokenize(dream)
    score = score_sentence(sentence)
    if score > _DREAM_SCORE_THRESHOLD:
        save_dream(sentence, score)
        record_to_journal(sentence)
        train_on_message(sentence)
        # Previously recent_dreams was only ever trimmed, never filled.
        recent_dreams.append(sentence)
        while len(recent_dreams) > _MAX_RECENT_DREAMS:
            recent_dreams.pop(0)