import time

import torch

from utils import sample_reply, update_model_vocab
from debug import DebugMonitor

debug = DebugMonitor()


def run_dream_loop(model, tokenizer, device, optimizer, train_step, interval=120):
    """Let the model 'dream': generate a reply and train on it every `interval` seconds."""
    print("Ruby is dreaming...")
    while True:
        reply, loss = generate_dream(model, tokenizer, device, optimizer, train_step)
        print(f"[DREAM] {reply} (loss={loss:.4f})")
        time.sleep(interval)


def generate_dream(model, tokenizer, device, optimizer, train_step):
    """Sample a self-generated reply, train on it once, and return (reply, loss)."""
    # Keep the model in sync with the current tokenizer vocabulary.
    update_model_vocab(model, tokenizer)

    prompt = "Ruby: "
    input_ids = tokenizer.encode(prompt, return_tensors=True, freeze=True).to(device)

    # Generate a reply, then wrap it in a synthetic dialogue to train on.
    reply = sample_reply(model, tokenizer, input_ids)
    training_text = f"User: What do you think?\nRuby: {reply}"

    loss = train_step(model, optimizer, tokenizer, training_text, device)

    # Record the dream and its loss before returning.
    debug.log_dream(reply)
    debug.log_loss(loss)

    return reply, loss
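

# Minimal usage sketch (illustrative): assumes a project-specific model,
# tokenizer, and a `train_step(model, optimizer, tokenizer, text, device)`
# callable that returns a float loss. `build_model` and `build_tokenizer`
# are hypothetical placeholders, not functions defined in this module.
#
# if __name__ == "__main__":
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model = build_model().to(device)
#     tokenizer = build_tokenizer()
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#     run_dream_loop(model, tokenizer, device, optimizer, train_step, interval=120)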