Trying to fix assert errors
commit 6ab7b7586a
parent e506032364
model/brain_state.py
@@ -6,8 +6,7 @@ from model.tokenizer import Tokenizer
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 tokenizer = Tokenizer()
-VOCAB_SIZE = len(tokenizer.vocab) + 10  # Slight buffer
+VOCAB_SIZE = len(tokenizer.vocab) + 10  # with a small buffer
 
 model = TinyTransformer(vocab_size=VOCAB_SIZE).to(DEVICE)
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
 loss_fn = nn.CrossEntropyLoss()
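The + 10 headroom matters because every token id that reaches nn.Embedding or CrossEntropyLoss must be strictly less than the vocab size the model was built with; an id past the end raises an IndexError on CPU and a device-side assert on CUDA, which is presumably the "assert errors" the commit message refers to. A minimal sketch with made-up sizes:

import torch
import torch.nn as nn

base_vocab = 100                               # hypothetical tokenizer size
embedding = nn.Embedding(base_vocab + 10, 32)  # sized with the buffer

ids = torch.tensor([0, 99, 105])  # 105: a token learned after sizing
print(embedding(ids).shape)       # torch.Size([3, 32]) -- 105 < 110, so it fits

# With nn.Embedding(base_vocab, 32) instead, id 105 raises
# "IndexError: index out of range in self" on CPU, or a device-side
# assert on CUDA.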
model/dynamic_expand.py
@@ -1,41 +1,36 @@
 import torch
 from model.brain_architecture import TinyTransformer
-from model.brain_state import model, tokenizer, DEVICE, optimizer
-import copy
+from model.brain_state import model, tokenizer, DEVICE
+
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+
+
+def get_optimizer():
+    global optimizer
+    return optimizer
 
 
 def expand_model_if_needed():
-    current_vocab_size = len(tokenizer.vocab) + 10  # Tiny buffer
+    global model, optimizer
+
+    current_vocab_size = len(tokenizer.vocab) + 10  # Buffer
     old_vocab_size = model.head.out_features
 
     if current_vocab_size <= old_vocab_size:
-        return  # No expansion needed
+        return
 
     print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
 
-    # Save old model
-    old_model = copy.deepcopy(model).to('cpu')
-
-    # Create new model
+    old_state = model.state_dict()
     new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
     new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-4)
 
-    # Copy parameters
+    # Transfer matching parameters
    with torch.no_grad():
-        for name, param in old_model.named_parameters():
-            if name in dict(new_model.named_parameters()):
-                try:
-                    new_param = dict(new_model.named_parameters())[name]
-                    if param.shape == new_param.shape:
-                        new_param.copy_(param)
-                    else:
-                        print(f"Skipping mismatched param: {name}")
-                except Exception as e:
-                    print(f"Error copying param: {name} — {e}")
+        for name, param in new_model.named_parameters():
+            if name in old_state and old_state[name].shape == param.shape:
+                param.copy_(old_state[name])
 
-    # Replace global references
-    globals()["model"] = new_model
-    globals()["optimizer"] = new_optimizer
-
-    print("Expansion complete.")
+    model = new_model
+    opt = get_optimizer()
+
+    print("Model expanded and optimizer rebuilt.")
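The rewrite swaps a deepcopy of the whole model for a state_dict snapshot, then keeps only parameters whose names and shapes survive the resize. TinyTransformer itself is not shown in this diff, so the sketch below uses a hypothetical stand-in module with the same .head attribute; the copy loop is the same shape-matched transfer the new expand_model_if_needed performs:

import torch
import torch.nn as nn

class StandInModel(nn.Module):
    # Hypothetical stand-in for TinyTransformer: an embedding, a hidden
    # layer, and a vocab-sized output head.
    def __init__(self, vocab_size, d_model=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.mid = nn.Linear(d_model, d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.head(torch.relu(self.mid(self.embed(x))))

old_model = StandInModel(vocab_size=100)
old_state = old_model.state_dict()

new_model = StandInModel(vocab_size=110)  # expanded vocab

# Shape-matched transfer: mid.* carries over unchanged; embed.* and
# head.* changed shape, so they are skipped and keep their fresh
# initialization.
with torch.no_grad():
    for name, param in new_model.named_parameters():
        if name in old_state and old_state[name].shape == param.shape:
            param.copy_(old_state[name])

print(torch.equal(new_model.mid.weight, old_model.mid.weight))  # True
print(new_model.head.out_features)                              # 110

One consequence worth noting: because the vocab-sized tensors never match shapes, every expansion reinitializes the embedding and head. A common refinement (not in this commit) also copies the overlapping rows, e.g. new_embed.weight[:old_vocab] = old_embed.weight, so already-learned tokens keep their vectors.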
@@ -1,8 +1,8 @@
 import torch
 import time
-from model.brain_state import model, tokenizer, DEVICE, optimizer, loss_fn
+from model.brain_state import model, tokenizer, DEVICE, loss_fn
 from context.context import add_to_context, get_recent_context
-from model.dynamic_expand import expand_model_if_needed
+from model.dynamic_expand import expand_model_if_needed, get_optimizer
 from model.brainmap import update_brainmap
 
 LOSS_FILE = "data/logs/loss.log"
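The import change is doing real work here: from model.brain_state import optimizer binds the trainer to whichever object the name pointed at during import, and rebinding the name later in another module never reaches that copy. Routing through get_optimizer() re-reads the current binding on every call. A small sketch of the difference, with strings standing in for optimizer objects:

opt = "old optimizer"            # stands in for the Adam instance

def get_optimizer():
    return opt                   # re-reads the module global each call

snapshot = opt                   # what `from module import optimizer` binds
opt = "rebuilt optimizer"        # what an expansion intends to do

print(snapshot)          # old optimizer -- frozen at import time
print(get_optimizer())   # rebuilt optimizer -- follows the rebinding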
@@ -31,9 +31,10 @@ def train_on_message(text: str):
     output = model(input_tensor)
     loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))
 
-    optimizer.zero_grad()
+    opt = get_optimizer()
+    opt.zero_grad()
     loss.backward()
-    optimizer.step()
+    opt.step()
 
     log_loss(loss.item())
     add_to_context(text)
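Put together, the training step now fetches the optimizer at use time instead of holding one from import. A condensed, self-contained sketch of that step; the linear model, tensor shapes, and loss here are stand-ins for the real transformer and batch:

import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # stand-in for the transformer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

def get_optimizer():
    return optimizer

def train_step(x, y):
    opt = get_optimizer()   # fetched per step, never cached at import
    out = model(x)
    loss = loss_fn(out, y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

print(train_step(torch.randn(4, 8), torch.randn(4, 8)))

Note that the trainer still imports model by name from model.brain_state, so the staleness the accessor fixes for the optimizer can still bite for the model object itself once an expansion swaps it out.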