Fixing loss issues

parent c9601f132b
commit 19cac46c0c
@@ -96,7 +96,7 @@
 <div class="section">
   <h2>📉 Recent Loss</h2>
   <ul>
-    {% for loss in loss_data %}
+    {% for loss in loss_data[-10:]|reverse %}
     <li>{{ loss }}</li>
     {% endfor %}
   </ul>
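The template change swaps the full loss history for the last ten entries, newest first. A minimal sketch of how the `[-10:]` slice and the `reverse` filter combine, assuming the dashboard uses Jinja2 (which the `{% ... %}` and filter syntax suggests):

# Minimal sketch: render the last 10 loss values, newest first,
# mirroring the {% for loss in loss_data[-10:]|reverse %} change above.
from jinja2 import Template

template = Template(
    "<ul>{% for loss in loss_data[-10:]|reverse %}"
    "<li>{{ loss }}</li>"
    "{% endfor %}</ul>"
)

loss_data = [round(2.0 - 0.01 * i, 2) for i in range(100)]  # dummy values
print(template.render(loss_data=loss_data))
# The newest entry (index -1) prints first; anything older than the
# last ten values never reaches the page.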
main.py
@@ -72,6 +72,7 @@ def start_brain_loops():
     loop.run_forever()


 threading.Thread(target=run_dashboard, daemon=True).start()
 threading.Thread(target=start_brain_loops, daemon=True).start()

 # Launch Discord bot (blocking)
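The startup sequence above runs the dashboard and the training loops as daemon threads and then lets the Discord client block the main thread. A reduced sketch of that pattern, with placeholder bodies standing in for the real functions:

# Reduced sketch of the startup ordering; run_dashboard and
# start_brain_loops are stand-ins for the functions in the diff.
import threading
import time

def run_dashboard():
    while True:
        time.sleep(1)  # placeholder for the dashboard web server

def start_brain_loops():
    while True:
        time.sleep(1)  # placeholder for the loop.run_forever() worker

threading.Thread(target=run_dashboard, daemon=True).start()
threading.Thread(target=start_brain_loops, daemon=True).start()

# Launch Discord bot (blocking) -- stand-in: any blocking call works here.
# Because the workers are daemon threads, they exit with the main thread
# instead of keeping the process alive after the bot shuts down.
time.sleep(5)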
@@ -1,12 +1,14 @@
 import torch
 import threading
+import time
 from model.brain_architecture import TinyTransformer
 from model.brain_state import model, tokenizer, DEVICE

-optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

-_last_vocab_size = 0
-_expand_lock = threading.Lock()
+expand_lock = threading.Lock()
+_last_expansion_time = 0


 def get_optimizer():
@@ -14,20 +16,16 @@ def get_optimizer():


 def expand_model_if_needed():
-    global model, optimizer, _last_vocab_size
+    global model, optimizer, _last_expansion_time

-    with _expand_lock:
+    with expand_lock:
         current_vocab_size = len(tokenizer.vocab) + 10
-
-        if current_vocab_size - _last_vocab_size < 10:
-            return  # Expand only after 10 new words collected
-
         old_vocab_size = model.head.out_features

         if current_vocab_size <= old_vocab_size:
             return

-        # print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
+        # print(f"[Expand] Expanding model from {old_vocab_size} -> {current_vocab_size}")

         old_state = model.state_dict()
         new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
@@ -39,6 +37,6 @@ def expand_model_if_needed():

         model = new_model
         optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
-        _last_vocab_size = current_vocab_size
+        _last_expansion_time = time.time()

-        # print("Expansion complete.")
+        # print("[Expand] Expansion complete.")
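Between these two hunks the routine copies the old weights into the larger TinyTransformer before swapping it in; that middle part is unchanged and not shown. A rough sketch of what such a copy step typically looks like, with copy_overlapping_weights as an illustrative helper rather than the repo's actual code:

# Hedged sketch of the copy-then-swap expansion step, assuming vocab growth
# only enlarges one dimension of the embedding and head tensors.
import torch

def copy_overlapping_weights(old_state: dict, new_model: torch.nn.Module) -> None:
    """Copy old tensors into the new, larger model wherever shapes overlap."""
    with torch.no_grad():
        new_state = new_model.state_dict()
        for name, old_tensor in old_state.items():
            if name not in new_state:
                continue  # layer renamed or removed; keep the fresh init
            new_tensor = new_state[name]
            # Vocab-sized tensors (embedding rows, head weight/bias) grew in
            # one dimension: copy the overlapping slice, keep new rows as-is.
            overlap = tuple(slice(0, min(o, n))
                            for o, n in zip(old_tensor.shape, new_tensor.shape))
            # state_dict() returns references, so this writes into new_model.
            new_tensor[overlap] = old_tensor[overlap]

Worth noting: the swap path above still rebuilds the optimizer with lr=1e-4, so the module-level change to lr=3e-4 only lasts until the first expansion.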
@@ -1,11 +1,12 @@
 import torch
+import time
-from model.dynamic_expand import expand_model_if_needed, get_optimizer
+from model.dynamic_expand import expand_model_if_needed, _last_expansion_time, get_optimizer, expand_lock
 from model.brain_state import model, tokenizer, DEVICE, loss_fn
 from context.context import add_to_context, get_recent_context

 LOSS_FILE = "data/logs/loss.log"
 VOCAB_GROWTH_FILE = "data/logs/vocab_growth.log"
 scheduler = torch.optim.lr_scheduler.StepLR(get_optimizer(), step_size=500, gamma=0.95)


 def log_vocab_growth():
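One caveat with `from model.dynamic_expand import _last_expansion_time`: Python copies the binding at import time, so when expand_model_if_needed() later rebinds the module global, the trainer keeps seeing the stale initial value (0) and the five-second cooldown below never triggers. A sketch of the usual fix, importing the module instead (assuming the package layout implied by the imports):

# Sketch: read the module attribute at call time instead of importing the
# name, so rebinding inside dynamic_expand stays visible to the trainer.
import time
from model import dynamic_expand

def recently_expanded(window_s: float = 5.0) -> bool:
    # dynamic_expand._last_expansion_time is re-read on every call,
    # so assignments made under expand_lock are observed here.
    return time.time() - dynamic_expand._last_expansion_time < window_s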
@@ -21,29 +22,38 @@ def log_loss(value: float):
 def train_on_message(text: str, source: str = "user"):
     expand_model_if_needed()

+    now = time.time()
+    if now - _last_expansion_time < 5:  # If expansion happened within the last 5 seconds
+        print("[Train] Skipping to stabilize after expansion.")
+        return
+
+    if not expand_lock.acquire(timeout=0.5):
+        print("[Train] Skipped training due to active expansion.")
+        return
+
+    try:
         model.train()
-        context_texts = get_recent_context(3)
+        context_texts = get_recent_context(10)
         augmented_text = " ".join(context_texts + [text])
         tokens = tokenizer.tokenize(augmented_text)

         if len(tokens) < 2:
             return

         # ✋ Clamp to model's known vocab
         max_token_id = model.head.out_features - 1
-        tokens = [t for t in tokens if t <= max_token_id]
+        tokens = [min(t, max_token_id) for t in tokens]

         if len(tokens) < 2:
-            return  # after filtering, too short to train
+            return

-        tokens = tokens[:128]
-
+        tokens = tokens[:128]  # safety clamp
         input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
         target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)

         opt = get_optimizer()

         output = model(input_tensor)

         loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))

         opt.zero_grad()
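The clamp change is subtle: the old line dropped out-of-vocab token ids entirely, shortening the sequence and re-pairing neighbors that were never adjacent, while the new line maps unknown ids onto the last known id and preserves positions. A toy comparison, assuming max_token_id = 4:

# Toy comparison of the two clamping strategies from the diff.
max_token_id = 4
tokens = [1, 9, 2, 7, 3]

filtered = [t for t in tokens if t <= max_token_id]  # old: [1, 2, 3]
clamped = [min(t, max_token_id) for t in tokens]     # new: [1, 4, 2, 4, 3]

# Filtering stitches token 1 directly onto token 2, creating input/target
# pairs that never appeared in the text; min() keeps the sequence length
# and next-token alignment, at the cost of mapping every out-of-vocab
# word onto a single id.
print(filtered, clamped)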
@@ -53,3 +63,5 @@ def train_on_message(text: str, source: str = "user"):
         log_loss(loss.item())
         log_vocab_growth()
         add_to_context(text, source=source)
+    finally:
+        expand_lock.release()
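Taken together, the new protocol is: expansion holds expand_lock for the whole rebuild, while training debounces for five seconds after an expansion, bails out if it cannot take the lock within half a second, and releases it in finally so an exception mid-step cannot deadlock the expander. A condensed sketch of that acquire-with-timeout shape:

# Condensed sketch of the acquire/timeout/finally pattern added above.
import threading

expand_lock = threading.Lock()

def train_step() -> None:
    if not expand_lock.acquire(timeout=0.5):
        print("[Train] Skipped training due to active expansion.")
        return
    try:
        ...  # forward/backward pass runs while the lock is held
    finally:
        expand_lock.release()  # always runs, even if the step raises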