Fixing loss issues,
This commit is contained in:
parent
c9601f132b
commit
19cac46c0c
@ -96,7 +96,7 @@
|
|||||||
<div class="section">
|
<div class="section">
|
||||||
<h2>📉 Recent Loss</h2>
|
<h2>📉 Recent Loss</h2>
|
||||||
<ul>
|
<ul>
|
||||||
{% for loss in loss_data %}
|
{% for loss in loss_data[-10:]|reverse %}
|
||||||
<li>{{ loss }}</li>
|
<li>{{ loss }}</li>
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
</ul>
|
</ul>
|
||||||
|
1
main.py
1
main.py
@ -72,6 +72,7 @@ def start_brain_loops():
|
|||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
|
|
||||||
|
|
||||||
|
threading.Thread(target=run_dashboard, daemon=True).start()
|
||||||
threading.Thread(target=start_brain_loops, daemon=True).start()
|
threading.Thread(target=start_brain_loops, daemon=True).start()
|
||||||
|
|
||||||
# Launch Discord bot (blocking)
|
# Launch Discord bot (blocking)
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
import torch
|
import torch
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from model.brain_architecture import TinyTransformer
|
from model.brain_architecture import TinyTransformer
|
||||||
from model.brain_state import model, tokenizer, DEVICE
|
from model.brain_state import model, tokenizer, DEVICE
|
||||||
|
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
|
||||||
|
|
||||||
_last_vocab_size = 0
|
_last_vocab_size = 0
|
||||||
_expand_lock = threading.Lock()
|
expand_lock = threading.Lock()
|
||||||
|
_last_expansion_time = 0
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer():
|
def get_optimizer():
|
||||||
@ -14,20 +16,16 @@ def get_optimizer():
|
|||||||
|
|
||||||
|
|
||||||
def expand_model_if_needed():
|
def expand_model_if_needed():
|
||||||
global model, optimizer, _last_vocab_size
|
global model, optimizer, _last_expansion_time
|
||||||
|
|
||||||
with _expand_lock:
|
with expand_lock:
|
||||||
current_vocab_size = len(tokenizer.vocab) + 10
|
current_vocab_size = len(tokenizer.vocab) + 10
|
||||||
|
|
||||||
if current_vocab_size - _last_vocab_size < 10:
|
|
||||||
return # Expand only after 10 new words collected
|
|
||||||
|
|
||||||
old_vocab_size = model.head.out_features
|
old_vocab_size = model.head.out_features
|
||||||
|
|
||||||
if current_vocab_size <= old_vocab_size:
|
if current_vocab_size <= old_vocab_size:
|
||||||
return
|
return
|
||||||
|
|
||||||
# print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
|
# print(f"[Expand] Expanding model from {old_vocab_size} -> {current_vocab_size}")
|
||||||
|
|
||||||
old_state = model.state_dict()
|
old_state = model.state_dict()
|
||||||
new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
|
new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
|
||||||
@ -39,6 +37,6 @@ def expand_model_if_needed():
|
|||||||
|
|
||||||
model = new_model
|
model = new_model
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
||||||
_last_vocab_size = current_vocab_size
|
_last_expansion_time = time.time()
|
||||||
|
|
||||||
# print("Expansion complete.")
|
# print("[Expand] Expansion complete.")
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
from model.dynamic_expand import expand_model_if_needed, get_optimizer
|
from model.dynamic_expand import expand_model_if_needed, _last_expansion_time, get_optimizer, expand_lock
|
||||||
from model.brain_state import model, tokenizer, DEVICE, loss_fn
|
from model.brain_state import model, tokenizer, DEVICE, loss_fn
|
||||||
from context.context import add_to_context, get_recent_context
|
from context.context import add_to_context, get_recent_context
|
||||||
|
|
||||||
LOSS_FILE = "data/logs/loss.log"
|
LOSS_FILE = "data/logs/loss.log"
|
||||||
VOCAB_GROWTH_FILE = "data/logs/vocab_growth.log"
|
VOCAB_GROWTH_FILE = "data/logs/vocab_growth.log"
|
||||||
|
scheduler = torch.optim.lr_scheduler.StepLR(get_optimizer(), step_size=500, gamma=0.95)
|
||||||
|
|
||||||
|
|
||||||
def log_vocab_growth():
|
def log_vocab_growth():
|
||||||
@ -21,35 +22,46 @@ def log_loss(value: float):
|
|||||||
def train_on_message(text: str, source: str = "user"):
|
def train_on_message(text: str, source: str = "user"):
|
||||||
expand_model_if_needed()
|
expand_model_if_needed()
|
||||||
|
|
||||||
model.train()
|
now = time.time()
|
||||||
context_texts = get_recent_context(3)
|
if now - _last_expansion_time < 5: # If expansion happened within the last 5 seconds
|
||||||
augmented_text = " ".join(context_texts + [text])
|
print("[Train] Skipping to stabilize after expansion.")
|
||||||
tokens = tokenizer.tokenize(augmented_text)
|
|
||||||
|
|
||||||
if len(tokens) < 2:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# ✋ Clamp to model's known vocab
|
if not expand_lock.acquire(timeout=0.5):
|
||||||
max_token_id = model.head.out_features - 1
|
print("[Train] Skipped training due to active expansion.")
|
||||||
tokens = [t for t in tokens if t <= max_token_id]
|
return
|
||||||
|
|
||||||
if len(tokens) < 2:
|
try:
|
||||||
return # after filtering, too short to train
|
model.train()
|
||||||
|
context_texts = get_recent_context(10)
|
||||||
|
augmented_text = " ".join(context_texts + [text])
|
||||||
|
tokens = tokenizer.tokenize(augmented_text)
|
||||||
|
|
||||||
tokens = tokens[:128] # safety clamp
|
if len(tokens) < 2:
|
||||||
input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
|
return
|
||||||
target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
|
|
||||||
|
|
||||||
opt = get_optimizer()
|
max_token_id = model.head.out_features - 1
|
||||||
|
tokens = [min(t, max_token_id) for t in tokens]
|
||||||
|
|
||||||
output = model(input_tensor)
|
if len(tokens) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))
|
tokens = tokens[:128]
|
||||||
|
|
||||||
opt.zero_grad()
|
input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
|
||||||
loss.backward()
|
target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
|
||||||
opt.step()
|
|
||||||
|
|
||||||
log_loss(loss.item())
|
opt = get_optimizer()
|
||||||
log_vocab_growth()
|
|
||||||
add_to_context(text, source=source)
|
output = model(input_tensor)
|
||||||
|
loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))
|
||||||
|
|
||||||
|
opt.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
opt.step()
|
||||||
|
|
||||||
|
log_loss(loss.item())
|
||||||
|
log_vocab_growth()
|
||||||
|
add_to_context(text, source=source)
|
||||||
|
finally:
|
||||||
|
expand_lock.release()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user