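"""Online trainer: performs one gradient step per incoming message and records
training activity to the journal and metric logs."""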
import time

import torch

from model.dynamic_expand import expand_model_if_needed, _last_expansion_time
from model.brain_state import model, tokenizer, DEVICE, loss_fn, optimizer, scheduler
from model.brainmap import add_to_brainmap
from model.journal import record_to_journal
from context.context import add_to_context, get_recent_context

# Append-only CSV logs: each line is "<unix timestamp>,<value>".
LOSS_FILE = "data/logs/loss.log"
VOCAB_GROWTH_FILE = "data/logs/vocab_growth.log"


def log_vocab_growth():
    """Append the current vocabulary size to the growth log."""
    with open(VOCAB_GROWTH_FILE, "a", encoding="utf-8") as f:
        f.write(f"{time.time()},{len(tokenizer.vocab)}\n")


def log_loss(value: float):
    """Append a training loss value to the loss log."""
    with open(LOSS_FILE, "a", encoding="utf-8") as f:
        f.write(f"{time.time()},{round(value, 4)}\n")
async def train_on_message(text: str, source: str = "user"):
    """Run one gradient step on a single message, prefixed with recent context."""
    await expand_model_if_needed()

    # Give freshly expanded weights a short cooldown before training on them.
    now = time.time()
    if now - _last_expansion_time < 5:
        print("[Trainer] Skipping to stabilize after expansion.")
        return

    model.train()
    context_texts = get_recent_context(10)
    augmented_text = "<start> " + " ".join(context_texts + [text]) + " <end>"
    tokens = tokenizer.tokenize(augmented_text)

    if len(tokens) < 2:
        print("[Trainer] Message too short after cleaning.")
        return

    # Clamp token IDs into the output head's range and cap the sequence length.
    max_token_id = model.head.out_features - 1
    tokens = [max(0, min(t, max_token_id)) for t in tokens][:128]

    # Defensive sanity check; the clamping above should make this unreachable.
    for t in tokens:
        if t > max_token_id or t < 0:
            print(f"[Trainer] Invalid token ID {t} (max={max_token_id})")
            return

    if len(tokens) < 2:
        print("[Trainer] Message too short after clamping.")
        return

    # Next-token prediction: inputs are tokens[:-1], targets are tokens[1:].
    input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
    target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)

    output = model(input_tensor)
    loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))
    if torch.isnan(loss):
        print("[Trainer] Detected NaN loss, skipping update.")
        return

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Track what was learned and fold the message back into the rolling context.
    add_to_brainmap(augmented_text.split())
    add_to_context(text, source=source)

    record_to_journal({
        "timestamp": time.time(),
        "source": source,
        "text": text,
        "loss": round(loss.item(), 4),
        "vocab_size": len(tokenizer.vocab)
    })
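if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original service):
    # train_on_message is a coroutine, so a standalone run needs an event loop.
    # The surrounding application presumably awaits it from its own async code.
    import asyncio
    import sys

    message = " ".join(sys.argv[1:]) or "hello world"
    asyncio.run(train_on_message(message, source="user"))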