From 6ab7b7586a229073770a298a2a414f65d5558ed3 Mon Sep 17 00:00:00 2001
From: Dani
Date: Fri, 25 Apr 2025 23:16:18 -0400
Subject: [PATCH] Trying to fix assert errors

---
 model/brain_state.py    |  3 +--
 model/dynamic_expand.py | 45 ++++++++++++++++++++-------------------------
 model/trainer.py        |  9 +++++----
 3 files changed, 26 insertions(+), 31 deletions(-)

diff --git a/model/brain_state.py b/model/brain_state.py
index 7e3b94b..b86a798 100644
--- a/model/brain_state.py
+++ b/model/brain_state.py
@@ -6,8 +6,7 @@ from model.tokenizer import Tokenizer
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 tokenizer = Tokenizer()
-VOCAB_SIZE = len(tokenizer.vocab) + 10  # Slight buffer
+VOCAB_SIZE = len(tokenizer.vocab) + 10  # with a small buffer
 
 model = TinyTransformer(vocab_size=VOCAB_SIZE).to(DEVICE)
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
 loss_fn = nn.CrossEntropyLoss()
diff --git a/model/dynamic_expand.py b/model/dynamic_expand.py
index 7094927..f6f8e92 100644
--- a/model/dynamic_expand.py
+++ b/model/dynamic_expand.py
@@ -1,41 +1,36 @@
 import torch
 from model.brain_architecture import TinyTransformer
-from model.brain_state import model, tokenizer, DEVICE, optimizer
-import copy
+from model.brain_state import model, tokenizer, DEVICE
+
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+
+
+def get_optimizer():
+    global optimizer
+    return optimizer
 
 
 def expand_model_if_needed():
-    current_vocab_size = len(tokenizer.vocab) + 10  # Tiny buffer
+    global model, optimizer
+
+    current_vocab_size = len(tokenizer.vocab) + 10  # Buffer
     old_vocab_size = model.head.out_features
 
     if current_vocab_size <= old_vocab_size:
-        return  # No expansion needed
+        return
 
     print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
 
-    # Save old model
-    old_model = copy.deepcopy(model).to('cpu')
-
-    # Create new model
+    old_state = model.state_dict()
     new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
-    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-4)
 
-    # Copy parameters
+    # Transfer matching parameters
     with torch.no_grad():
-        for name, param in old_model.named_parameters():
-            if name in dict(new_model.named_parameters()):
-                try:
-                    new_param = dict(new_model.named_parameters())[name]
-                    if param.shape == new_param.shape:
-                        new_param.copy_(param)
-                    else:
-                        print(f"Skipping mismatched param: {name}")
-                except Exception as e:
-                    print(f"Error copying param: {name} — {e}")
+        for name, param in new_model.named_parameters():
+            if name in old_state and old_state[name].shape == param.shape:
+                param.copy_(old_state[name])
 
-    # Replace global references
-    globals()["model"] = new_model
-    globals()["optimizer"] = new_optimizer
-
-    print("Expansion complete.")
+    model = new_model
+    optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-4)
+    print("Model expanded and optimizer rebuilt.")
 
diff --git a/model/trainer.py b/model/trainer.py
index 79c67f9..91f51b6 100644
--- a/model/trainer.py
+++ b/model/trainer.py
@@ -1,8 +1,8 @@
 import torch
 import time
-from model.brain_state import model, tokenizer, DEVICE, optimizer, loss_fn
+from model.brain_state import model, tokenizer, DEVICE, loss_fn
 from context.context import add_to_context, get_recent_context
-from model.dynamic_expand import expand_model_if_needed
+from model.dynamic_expand import expand_model_if_needed, get_optimizer
 from model.brainmap import update_brainmap
 
 LOSS_FILE = "data/logs/loss.log"
@@ -31,9 +31,10 @@ def train_on_message(text: str):
         output = model(input_tensor)
         loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))
 
-        optimizer.zero_grad()
+        opt = get_optimizer()
+        opt.zero_grad()
         loss.backward()
-        optimizer.step()
+        opt.step()
 
         log_loss(loss.item())
         add_to_context(text)