22 lines
603 B
Python
22 lines
603 B
Python
import torch
|
|
from model.brain import model, tokenizer, DEVICE
|
|
from model.trainer import train_on_message
|
|
from model.dynamic_expand import expand_model_if_needed
|
|
|
|
|
|
def simulate_conversation(*, seed_len=5, max_context=128, min_words=3):
    """Run one self-play step: decode random seed tokens and train on the result.

    Steps:
      1. Call ``expand_model_if_needed()``.
      2. Sample ``seed_len`` random token ids in ``[0, tokenizer.next_id)``.
      3. Run the model once without gradient tracking.
      4. Take the per-position argmax.
      5. Detokenize and, if the text has at least ``min_words``
         whitespace-separated words, pass it to ``train_on_message``.

    Args:
        seed_len: Number of random seed tokens (original hard-coded value: 5).
        max_context: Upper bound on the seed length fed to the model
            (original hard-coded value: 128).
        min_words: Minimum word count required before training on the text
            (original hard-coded value: 3).

    Returns:
        None. The function works only through side effects.

    Raises:
        ValueError: If ``seed_len`` or ``max_context`` is less than 1.
    """
    if seed_len < 1 or max_context < 1:
        raise ValueError("seed_len and max_context must be >= 1")

    # Presumably grows the model/vocabulary.
    # NOTE(review): `model` was bound via `from model.brain import model`.
    # If expand_model_if_needed() *replaces* that object rather than mutating
    # it in place, this module keeps using the old instance. Confirm.
    expand_model_if_needed()

    # assumes tokenizer.next_id is an int vocab size (ids 0..next_id-1).
    # torch.randint needs high > low, so there is nothing to sample from yet.
    vocab_size = tokenizer.next_id
    if vocab_size <= 0:
        return

    n_tokens = min(seed_len, max_context)  # the "safety clamp", applied up front

    # Remember the caller's mode so the shared model is not left in eval mode
    # when train_on_message() runs below.
    was_training = model.training
    model.eval()
    try:
        # Inference only: no autograd graph is needed for argmax decoding.
        with torch.no_grad():
            seed = torch.randint(0, vocab_size, (1, n_tokens), device=DEVICE)
            output = model(seed)
            # assumes output is logits shaped (..., vocab) — TODO confirm.
            # reshape(-1) always yields a 1-D tensor, so tolist() is a list
            # even when only a single position is decoded.
            preds = torch.argmax(output, dim=-1).reshape(-1).tolist()
    finally:
        model.train(was_training)

    text = tokenizer.detokenize(preds)
    if text and len(text.split()) >= min_words:
        train_on_message(text)
|