90 lines
2.9 KiB
Python
90 lines
2.9 KiB
Python
import random
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from brain.brain import model, tokenizer, DEVICE, score_sentence
|
|
from brain.brainmap import add_to_brainmap
|
|
from ego.trainer import train_on_message
|
|
from ego.context import load_context, get_recent_context
|
|
from ego.dreams import save_dream, load_dreams
|
|
from ego.journal import record_to_journal
|
|
from utils.dynamic_expand import expand_model_if_needed
|
|
|
|
|
|
# Module-level list of recent dreams, created empty at import time.
# NOTE(review): nothing in the visible code reads or writes this list;
# presumably other modules import it — confirm before removing.
recent_dreams = list()
|
|
|
|
|
|
def _build_seed(max_token_id):
    """Build the (1, T) long tensor of token ids that seeds generation.

    The strings returned by ``get_recent_context(5)`` are joined with spaces,
    tokenized, truncated to the first 16 ids, and clamped into
    ``[0, max_token_id]``. The result is a single random id when there is no
    context or it tokenizes to nothing.
    """
    context = get_recent_context(5)
    # NOTE(review): assumes tokenizer.tokenize returns a list of int ids — TODO confirm.
    tokens = tokenizer.tokenize(" ".join(context))[:16] if context else []
    if not tokens:
        return torch.randint(0, max_token_id + 1, (1, 1), device=DEVICE)
    # Clamp so ids from a larger tokenizer vocabulary never index past the model head.
    tokens = [max(0, min(t, max_token_id)) for t in tokens]
    return torch.tensor([tokens], dtype=torch.long, device=DEVICE)


def _sample_dream(seed, max_token_id, max_new_tokens, top_k,
                  repetition_penalty, penalty_window):
    """Sample ``max_new_tokens`` token ids continuing ``seed`` and return them as a list.

    Each step takes the logits at the last position and applies the
    repetition penalty to tokens in the trailing ``penalty_window`` seed
    positions. It then samples from the ``top_k`` most likely tokens.
    """
    dream = []
    for _ in range(max_new_tokens):
        logits = model(seed)[:, -1, :]

        # Sign-aware penalty: shrinking a positive logit or growing the
        # magnitude of a negative one both make the token less likely. Plain
        # multiplication by a factor < 1 would make negative logits *more* likely.
        # Repeated occurrences in the window compound, as in the original loop.
        start = max(0, seed.size(1) - penalty_window)
        for t in seed[0, start:].tolist():
            value = logits[0, t]
            logits[0, t] = (value * repetition_penalty if value > 0
                            else value / repetition_penalty)

        # Softmax over the top-k logits equals renormalizing the top-k
        # probabilities of a full softmax.
        k = max(1, min(top_k, logits.size(-1)))
        top_logits, top_indices = torch.topk(logits, k=k, dim=-1)
        choice = torch.multinomial(F.softmax(top_logits, dim=-1), num_samples=1)
        # Defensive clamp in case the output width ever exceeds the head size.
        token = top_indices.gather(1, choice).clamp(max=max_token_id)

        dream.append(token.item())
        seed = torch.cat([seed, token], dim=1)
    return dream


async def daydream(*, max_new_tokens=6, top_k=20, repetition_penalty=0.75,
                   penalty_window=4, min_unique_ratio=0.5, min_score=0.3):
    """Generate a short "dream" sentence and learn from it if it scores well.

    Generation is seeded from recent context (or a random token) and runs
    without gradient tracking. The decoded sentence is scored with
    ``score_sentence``. It is skipped, with a log line, when its ratio of
    distinct words to words is below ``min_unique_ratio``; an empty sentence
    counts as ratio 0. When the score is at least ``min_score`` the sentence
    is saved as a dream, journaled, added to the brainmap, and trained on.

    Keyword Args:
        max_new_tokens: Tokens to sample; kept short for early-stage models.
        top_k: Sample only among the ``top_k`` most likely next tokens.
        repetition_penalty: Factor in (0, 1] applied to logits of recently
            seen tokens. Positive logits are multiplied by it and negative
            logits divided by it. Must be > 0.
        penalty_window: Number of trailing seed tokens that are penalized.
        min_unique_ratio: Minimum distinct-words / words ratio to keep a dream.
        min_score: Minimum ``score_sentence`` value for the dream to be kept.
    """
    max_token_id = model.head.out_features - 1

    # Generate (and score) in eval mode, then restore the caller's mode so this
    # routine does not permanently flip the shared model out of training mode.
    # NOTE(review): assumes model is a torch.nn.Module (has .training / .train(mode)).
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            dream = _sample_dream(_build_seed(max_token_id), max_token_id,
                                  max_new_tokens, top_k,
                                  repetition_penalty, penalty_window)
        sentence = tokenizer.detokenize(dream)
        score = score_sentence(sentence)
    finally:
        model.train(was_training)

    words = sentence.split()
    # An empty/whitespace-only decode has no variety; avoid ZeroDivisionError.
    unique_ratio = len(set(words)) / len(words) if words else 0.0
    if unique_ratio < min_unique_ratio:
        print(f"[Dreamer] Skipped low-variance dream: '{sentence}'")
        return

    print(f"[Dreamer] Dream: '{sentence}' | Score: {round(score, 3)}")

    if score >= min_score:
        save_dream(sentence, score)
        record_to_journal(sentence)
        add_to_brainmap(words)
        await train_on_message(sentence)
|
|
|
|
|
|
async def replay_dreams(*, max_dreams=5, max_contexts=5, mix_size=3):
    """Train on a random mix of saved dreams and stored context texts.

    First awaits ``expand_model_if_needed()``. It then samples up to
    ``max_dreams`` saved dreams (their ``"sentence"``) and up to
    ``max_contexts`` context entries (their ``"text"``). Up to ``mix_size`` of
    those texts are picked at random, joined with spaces, and trained on with
    ``source="dream"``. Nothing is trained unless both dreams and context are
    available. Records without a non-blank string text are ignored.

    Keyword Args:
        max_dreams: Maximum number of saved dreams to draw from.
        max_contexts: Maximum number of context entries to draw from.
        mix_size: Maximum number of texts joined into the training message.
    """
    await expand_model_if_needed()

    dreams = load_dreams()
    context = load_context()

    # Mixing needs both sources; bail out if either is empty.
    if not dreams or not context:
        return

    picked_dreams = random.sample(dreams, max(0, min(len(dreams), max_dreams)))
    picked_context = random.sample(context, max(0, min(len(context), max_contexts)))

    # Skip malformed records (missing key, None, blank) instead of raising
    # KeyError/TypeError or training on a whitespace-only message.
    candidates = [d.get("sentence") for d in picked_dreams] + [c.get("text") for c in picked_context]
    sources = [text for text in candidates if isinstance(text, str) and text.strip()]

    # random.sample already returns its picks in random order, so no shuffle is needed.
    mixed_sentence = " ".join(random.sample(sources, max(0, min(len(sources), mix_size))))

    if mixed_sentence:
        await train_on_message(mixed_sentence, source="dream")
|