import torch
import time

from model import dynamic_expand
from model.dynamic_expand import expand_model_if_needed, expand_lock
from model.brain_state import model, tokenizer, DEVICE, loss_fn, optimizer, scheduler
from model.brainmap import add_to_brainmap
from model.journal import record_to_journal
from context.context import add_to_context, get_recent_context

LOSS_FILE = "data/logs/loss.log"
VOCAB_GROWTH_FILE = "data/logs/vocab_growth.log"


def log_vocab_growth():
    with open(VOCAB_GROWTH_FILE, "a", encoding="utf-8") as f:
        f.write(f"{time.time()},{len(tokenizer.vocab)}\n")


def log_loss(value: float):
    with open(LOSS_FILE, "a", encoding="utf-8") as f:
        f.write(f"{time.time()},{round(value, 4)}\n")


def train_on_message(text: str, source: str = "user"):
    expand_model_if_needed()

    # Skip training during the short cool-down window right after an expansion.
    # Read the timestamp through the module so we see updates made after import;
    # a from-import would capture a stale snapshot of the value.
    now = time.time()
    if now - dynamic_expand._last_expansion_time < 5:
        print("[Trainer] Skipping to stabilize after expansion.")
        return

    if not expand_lock.acquire(timeout=0.5):
        print("[Trainer] Skipped training due to active expansion.")
        return

    try:
        model.train()
        context_texts = get_recent_context(10)

        # Augment the input with recent context
        augmented_text = " " + " ".join(context_texts + [text]) + " "
        tokens = tokenizer.tokenize(augmented_text)

        if len(tokens) < 2:
            print("[Trainer] Message too short after cleaning.")
            return

        # Clamp any token IDs beyond the model's output size
        max_token_id = model.head.out_features - 1
        tokens = [min(t, max_token_id) for t in tokens]
        tokens = tokens[:128]  # Hard clamp input length

        if len(tokens) < 2:
            print("[Trainer] Message too short after clamping.")
            return

        # Next-token prediction: inputs are tokens[:-1], targets are tokens[1:]
        input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
        target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)

        output = model(input_tensor)
        loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))

        if torch.isnan(loss):
            print("[Trainer] Detected NaN loss, skipping update.")
            return

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Update brainmap and context
        add_to_brainmap(augmented_text.split())
        add_to_context(text, source=source)

        # Log training success to journal
        record_to_journal({
            "timestamp": time.time(),
            "source": source,
            "text": text,
            "loss": round(loss.item(), 4),
            "vocab_size": len(tokenizer.vocab)
        })

    finally:
        expand_lock.release()
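

# Example usage (a minimal sketch, not part of the original module; it assumes
# the packages imported above are importable and that this file lives at a
# hypothetical path such as model.trainer):
#
#   from model.trainer import train_on_message, log_loss, log_vocab_growth
#
#   train_on_message("hello there", source="user")
#   log_loss(1.2345)
#   log_vocab_growth()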