Fixed dynamic expand

This commit is contained in:
Dani 2025-04-26 22:42:49 -04:00
parent 5f74b2c64c
commit a9b4871420
4 changed files with 12 additions and 8 deletions

View File

@@ -3,6 +3,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from model.memory import save_dream from model.memory import save_dream
from model.brain_state import model, tokenizer, DEVICE from model.brain_state import model, tokenizer, DEVICE
from model.journal import record_to_journal
from model.trainer import train_on_message
from context.context import get_recent_context from context.context import get_recent_context
recent_dreams = [] recent_dreams = []
@@ -59,9 +61,7 @@ def daydream():
if score > 0.45: if score > 0.45:
save_dream(sentence, score) save_dream(sentence, score)
from model.journal import record_to_journal
record_to_journal(sentence) record_to_journal(sentence)
from model.trainer import train_on_message
train_on_message(sentence) train_on_message(sentence)
if len(recent_dreams) > 10: if len(recent_dreams) > 10:

View File

@@ -4,20 +4,24 @@ from model.brain_state import model, tokenizer, DEVICE
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
_last_expansion_vocab_size = 0
def get_optimizer(): def get_optimizer():
return optimizer return optimizer
def expand_model_if_needed(): def expand_model_if_needed():
global model, optimizer global model, optimizer, _last_expansion_vocab_size
current_vocab_size = len(tokenizer.vocab) + 10 current_vocab_size = len(tokenizer.vocab) + 10
old_vocab_size = model.head.out_features
if current_vocab_size - _last_expansion_vocab_size < 5:
return # Only expand every 5 words
old_vocab_size = model.head.out_features
if current_vocab_size <= old_vocab_size: if current_vocab_size <= old_vocab_size:
return # No expansion needed return # No expansion needed
print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}") print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
old_state = model.state_dict() old_state = model.state_dict()

View File

@@ -12,7 +12,7 @@ def log_loss(value: float):
f.write(f"{time.time()},{round(value, 4)}\n") f.write(f"{time.time()},{round(value, 4)}\n")
def train_on_message(text: str): def train_on_message(text: str, source: str = "user"):
expand_model_if_needed() expand_model_if_needed()
model.train() model.train()
@@ -45,4 +45,4 @@ def train_on_message(text: str):
opt.step() opt.step()
log_loss(loss.item()) log_loss(loss.item())
add_to_context(text) add_to_context(text, source=source)

View File

@@ -46,6 +46,6 @@ async def read_books_forever():
save_progress(progress) save_progress(progress)
if is_valid_line(line): if is_valid_line(line):
train_on_message(line) train_on_message(line, source="book")
set_next_action(READ_DELAY, "Reading") set_next_action(READ_DELAY, "Reading")
await asyncio.sleep(READ_DELAY) await asyncio.sleep(READ_DELAY)