Ruby/model/trainer.py


import torch
import time

from model.dynamic_expand import expand_model_if_needed, _last_expansion_time, expand_lock
from model.brain_state import model, tokenizer, DEVICE, loss_fn, optimizer, scheduler
from model.brainmap import add_to_brainmap
from model.journal import record_to_journal
from context.context import add_to_context, get_recent_context

LOSS_FILE = "data/logs/loss.log"
VOCAB_GROWTH_FILE = "data/logs/vocab_growth.log"
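

# Append-only loggers: each call writes one "unix_timestamp,value" row to a CSV-style log file.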
def log_vocab_growth():
    with open(VOCAB_GROWTH_FILE, "a", encoding="utf-8") as f:
        f.write(f"{time.time()},{len(tokenizer.vocab)}\n")


def log_loss(value: float):
    with open(LOSS_FILE, "a", encoding="utf-8") as f:
        f.write(f"{time.time()},{round(value, 4)}\n")
def train_on_message(text: str, source: str = "user"):
    expand_model_if_needed()

    now = time.time()
    if now - _last_expansion_time < 5:
        print("[Trainer] Skipping to stabilize after expansion.")
        return

    if not expand_lock.acquire(timeout=0.5):
        print("[Trainer] Skipped training due to active expansion.")
        return
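
    # Everything below runs while holding the expansion lock, so the model's
    # vocabulary and output head cannot be resized mid-step.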
    try:
        model.train()
        context_texts = get_recent_context(10)

        # Augment the input with recent context
        augmented_text = "<start> " + " ".join(context_texts + [text]) + " <end>"
        tokens = tokenizer.tokenize(augmented_text)

        if len(tokens) < 2:
            print("[Trainer] Message too short after cleaning.")
            return
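
        # The tokenizer can grow faster than the model's output head; any ID at or
        # above head.out_features would index out of range (a common source of CUDA
        # device-side assert errors), so such IDs are clamped to the last valid index.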
        # Clamp any token IDs beyond the model's output size
        max_token_id = model.head.out_features - 1
        if tokenizer.next_id > model.head.out_features:
            expand_model_if_needed()
        tokens = [t if t <= max_token_id else max_token_id for t in tokens]
        tokens = tokens[:128]  # Hard clamp input length

        if len(tokens) < 2:
            print("[Trainer] Message too short after clamping.")
            return
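
        # Next-token objective: the input is tokens[:-1] and the target is the
        # same sequence shifted left by one position.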
        input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
        target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)

        output = model(input_tensor)
        loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))

        if torch.isnan(loss):
            print("[Trainer] Detected NaN loss, skipping update.")
            return
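
        # One optimizer step and one scheduler step per message.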
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Update brainmap and context
        add_to_brainmap(augmented_text.split())
        add_to_context(text, source=source)

        # Log training success to journal
        record_to_journal({
            "timestamp": time.time(),
            "source": source,
            "text": text,
            "loss": round(loss.item(), 4),
            "vocab_size": len(tokenizer.vocab),
        })
    finally:
        expand_lock.release()
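

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how this trainer might be driven; the sample message
# below is made up, and running this file directly is an assumption rather
# than documented behavior.
if __name__ == "__main__":
    train_on_message("hello ruby, how was training today?", source="user")
    log_vocab_growth()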