diff --git a/dashboard/templates/index.html b/dashboard/templates/index.html
index 96533ce..ab1bef8 100644
--- a/dashboard/templates/index.html
+++ b/dashboard/templates/index.html
@@ -96,7 +96,7 @@
 📉 Recent Loss

diff --git a/main.py b/main.py
index 1553e27..bb18e03 100644
--- a/main.py
+++ b/main.py
@@ -72,6 +72,7 @@ def start_brain_loops():
     loop.run_forever()
 
+threading.Thread(target=run_dashboard, daemon=True).start()
 threading.Thread(target=start_brain_loops, daemon=True).start()
 
 # Launch Discord bot (blocking)
diff --git a/model/dynamic_expand.py b/model/dynamic_expand.py
index 46a6d1f..ada6331 100644
--- a/model/dynamic_expand.py
+++ b/model/dynamic_expand.py
@@ -1,12 +1,14 @@
 import torch
 import threading
+import time
 from model.brain_architecture import TinyTransformer
 from model.brain_state import model, tokenizer, DEVICE
 
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
 _last_vocab_size = 0
-_expand_lock = threading.Lock()
+expand_lock = threading.Lock()
+_last_expansion_time = 0
 
 def get_optimizer():
@@ -14,20 +16,16 @@ def expand_model_if_needed():
-    global model, optimizer, _last_vocab_size
+    global model, optimizer, _last_expansion_time
 
-    with _expand_lock:
+    with expand_lock:
         current_vocab_size = len(tokenizer.vocab) + 10
-
-        if current_vocab_size - _last_vocab_size < 10:
-            return  # Expand only after 10 new words collected
-
         old_vocab_size = model.head.out_features
 
         if current_vocab_size <= old_vocab_size:
             return
 
-        # print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
+        # print(f"[Expand] Expanding model from {old_vocab_size} -> {current_vocab_size}")
 
         old_state = model.state_dict()
         new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
@@ -39,6 +37,6 @@ def expand_model_if_needed():
 
         model = new_model
         optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
-        _last_vocab_size = current_vocab_size
+        _last_expansion_time = time.time()
 
-    # print("Expansion complete.")
\ No newline at end of file
+    # print("[Expand] Expansion complete.")
diff --git a/model/trainer.py b/model/trainer.py
index 1e9cee3..032c484 100644
--- a/model/trainer.py
+++ b/model/trainer.py
@@ -1,11 +1,12 @@
 import torch
 import time
-from model.dynamic_expand import expand_model_if_needed, get_optimizer
+from model.dynamic_expand import expand_model_if_needed, _last_expansion_time, get_optimizer, expand_lock
 from model.brain_state import model, tokenizer, DEVICE, loss_fn
 from context.context import add_to_context, get_recent_context
 
 LOSS_FILE = "data/logs/loss.log"
 VOCAB_GROWTH_FILE = "data/logs/vocab_growth.log"
+scheduler = torch.optim.lr_scheduler.StepLR(get_optimizer(), step_size=500, gamma=0.95)
 
 
 def log_vocab_growth():
@@ -21,35 +22,46 @@ def log_loss(value: float):
 def train_on_message(text: str, source: str = "user"):
     expand_model_if_needed()
 
-    model.train()
-    context_texts = get_recent_context(3)
-    augmented_text = " ".join(context_texts + [text])
-    tokens = tokenizer.tokenize(augmented_text)
-
-    if len(tokens) < 2:
+    now = time.time()
+    if now - _last_expansion_time < 5:  # If expansion happened within the last 5 seconds
+        print("[Train] Skipping to stabilize after expansion.")
         return
 
-    # ✋ Clamp to model's known vocab
-    max_token_id = model.head.out_features - 1
-    tokens = [t for t in tokens if t <= max_token_id]
+    if not expand_lock.acquire(timeout=0.5):
+        print("[Train] Skipped training due to active expansion.")
+        return
 
-    if len(tokens) < 2:
-        return  # after filtering, too short to train
+    try:
+        model.train()
+        context_texts = get_recent_context(10)
+        augmented_text = " ".join(context_texts + [text])
+        tokens = tokenizer.tokenize(augmented_text)
 
-    tokens = tokens[:128]  # safety clamp
-    input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
-    target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
+        if len(tokens) < 2:
+            return
 
-    opt = get_optimizer()
+        max_token_id = model.head.out_features - 1
+        tokens = [min(t, max_token_id) for t in tokens]
 
-    output = model(input_tensor)
+        if len(tokens) < 2:
+            return
 
-    loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))
+        tokens = tokens[:128]
 
-    opt.zero_grad()
-    loss.backward()
-    opt.step()
+        input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
+        target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
 
-    log_loss(loss.item())
-    log_vocab_growth()
-    add_to_context(text, source=source)
+        opt = get_optimizer()
+
+        output = model(input_tensor)
+        loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))
+
+        opt.zero_grad()
+        loss.backward()
+        opt.step()
+
+        log_loss(loss.item())
+        log_vocab_growth()
+        add_to_context(text, source=source)
+    finally:
+        expand_lock.release()
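Taken together, the dynamic_expand.py and trainer.py changes form a writer-priority handshake: expansion holds expand_lock while it swaps the model, and training both honors a post-expansion cooldown and refuses to block on a busy lock. A minimal sketch of that pattern, with the model work stubbed out (only expand_lock, _last_expansion_time, and the 5 s / 0.5 s constants come from the diff; everything else is illustrative):

    import threading
    import time

    expand_lock = threading.Lock()
    _last_expansion_time = 0.0

    def expand_model():
        global _last_expansion_time
        with expand_lock:                       # expansion owns the lock while swapping the model
            ...                                 # rebuild the model with a larger output head
            _last_expansion_time = time.time()  # stamp for the trainer's cooldown check

    def train_step():
        if time.time() - _last_expansion_time < 5:
            return                              # cooldown: let the fresh weights settle
        if not expand_lock.acquire(timeout=0.5):
            return                              # expansion in flight; drop the step, don't block
        try:
            ...                                 # tokenize, forward, backward, optimizer step
        finally:
            expand_lock.release()

One caveat: `from model.dynamic_expand import _last_expansion_time` binds the value at import time, so trainer.py will keep seeing the initial timestamp even after dynamic_expand.py updates its own global; reading it through a small accessor (as get_optimizer() already does for the optimizer) would avoid that.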
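The StepLR scheduler added to model/trainer.py is constructed but never stepped in this diff, and it binds to whichever optimizer exists at import time. Assuming the intent is a 0.95 learning-rate decay every 500 optimizer steps, the usual wiring looks like the sketch below (the stand-in model and loop are assumptions, not project code):

    import torch

    model = torch.nn.Linear(8, 8)  # stand-in for TinyTransformer
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.95)

    for step in range(1000):
        optimizer.zero_grad()
        loss = model(torch.randn(1, 8)).pow(2).mean()  # dummy loss
        loss.backward()
        optimizer.step()
        scheduler.step()  # decays lr by 0.95 every 500 steps

Because expand_model_if_needed() replaces the optimizer on every expansion, the scheduler would also need to be rebuilt at that point, or it will keep adjusting the discarded optimizer's learning rate.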