import torch

from brain.brain import model, tokenizer, DEVICE
from utils.dynamic_expand import expand_model_if_needed
from ego.trainer import train_on_message


async def simulate_conversation():
    """Generate a short text from random seed tokens and train on it.

    Steps, in order:

    1. Await ``expand_model_if_needed()`` (presumably so ``model.head`` has
       its current size before it is read -- confirm against that helper).
    2. Switch ``model`` to eval mode and draw a random ``(1, 5)`` tensor of
       token ids in ``[0, model.head.out_features - 1]`` on ``DEVICE``.
    3. Run one forward pass, take the argmax over the last dimension and
       detokenize the resulting ids.
    4. If the text is non-empty and has at least three whitespace-separated
       words, await ``train_on_message(text)``.

    Returns:
        None. Returns early, doing nothing further, when the head exposes
        fewer than two output ids.

    Note:
        The model is left in eval mode when this coroutine finishes.
        NOTE(review): presumably ``train_on_message`` switches back to train
        mode itself -- confirm.
    """
    await expand_model_if_needed()
    model.eval()

    max_token_id = model.head.out_features - 1
    if max_token_id < 1:
        # Fewer than two possible ids: nothing meaningful to sample.
        return

    # torch.randint's upper bound is exclusive, hence the + 1.
    seed = torch.randint(0, max_token_id + 1, (1, 5), device=DEVICE)
    # Keep at most the last 128 positions; a no-op for the 5-token seed.
    seed = seed[:, -128:]

    # Pure inference: eval() alone does not disable autograd, so without
    # no_grad() every call would build (and hold) an unused autograd graph.
    with torch.no_grad():
        output = model(seed)
        predicted = torch.argmax(output, dim=-1)

    preds = predicted.squeeze().tolist()
    if isinstance(preds, int):
        # squeeze() collapsed the result to a 0-d tensor, so tolist() gave a
        # bare int rather than a list.
        preds = [preds]
    # Defensive clamp so every id passed to the tokenizer is in range.
    preds = [min(max(p, 0), max_token_id) for p in preds]

    text = tokenizer.detokenize(preds)
    # Only train on outputs with at least three whitespace-separated words.
    if text and len(text.split()) >= 3:
        await train_on_message(text)