# RubyOld/dream.py
# Last modified: 2025-04-08 19:52:01 -04:00
# 30 lines, 921 B, Python

import torch
import time
from utils import sample_reply, update_model_vocab
from debug import DebugMonitor
# Module-level DebugMonitor instance (from the project's debug module);
# referenced below via debug.log_dream / debug.log_loss.
debug = DebugMonitor()
def run_dream_loop(model, tokenizer, device, optimizer, train_step, interval=120,
                   *, max_dreams=None):
    """Repeatedly generate a self-prompted "dream" reply and train on it.

    Each iteration calls ``generate_dream`` (which samples a reply and runs
    ``train_step`` on it), prints the reply with its loss, then sleeps for
    ``interval`` seconds.

    Args:
        model: Model passed through to ``generate_dream``.
        tokenizer: Tokenizer passed through to ``generate_dream``.
        device: Device passed through to ``generate_dream``.
        optimizer: Optimizer passed through to ``train_step``.
        train_step: Callable invoked as
            ``train_step(model, optimizer, tokenizer, text, device)``; its
            return value is printed with ``:.4f`` (presumably a float or a
            0-dim tensor — TODO confirm).
        interval: Seconds to sleep between dreams (default 120).
        max_dreams: Optional cap on the number of dreams. ``None`` (the
            default) loops forever, exactly like the original behaviour.
    """
    print("Ruby is dreaming...")
    dreams = 0
    while max_dreams is None or dreams < max_dreams:
        reply, loss = generate_dream(model, tokenizer, device, optimizer, train_step)
        print(f"[DREAM] {reply} (loss={loss:.4f})")
        dreams += 1
        # When a cap is set, don't sleep after the final dream.
        if max_dreams is not None and dreams >= max_dreams:
            break
        time.sleep(interval)
def generate_dream(model, tokenizer, device, optimizer, train_step):
    """Sample one reply from a bare "Ruby: " prompt and train on it.

    The sampled reply is wrapped in a fixed User/Ruby exchange and passed to
    ``train_step``; the reply and the resulting loss are logged through the
    module-level ``debug`` monitor before being returned.

    Args:
        model: Model to sample from and train.
        tokenizer: Project tokenizer used to encode the prompt.
        device: Device the encoded prompt is moved to.
        optimizer: Optimizer passed through to ``train_step``.
        train_step: Callable invoked as
            ``train_step(model, optimizer, tokenizer, text, device)``.

    Returns:
        ``(reply, loss)`` — the sampled reply and ``train_step``'s return value.
    """
    # NOTE(review): presumably syncs the model's vocabulary with the
    # tokenizer's before sampling — see utils.update_model_vocab.
    update_model_vocab(model, tokenizer)
    prompt = "Ruby: "
    # return_tensors/freeze are options of the project's tokenizer; freeze=True
    # presumably prevents encode() from growing the vocab — TODO confirm.
    input_ids = tokenizer.encode(prompt, return_tensors=True, freeze=True).to(device)
    reply = sample_reply(model, tokenizer, input_ids)
    # Frame the dream as Ruby's answer to a generic question before training.
    training_text = f"User: What do you think?\nRuby: {reply}"
    loss = train_step(model, optimizer, tokenizer, training_text, device)
    # Log before returning — these calls previously followed `return` and
    # therefore never executed.
    debug.log_dream(reply)
    debug.log_loss(loss)
    return reply, loss