42 lines
1.4 KiB
Python
42 lines
1.4 KiB
Python
import torch
|
|
from model.brain_architecture import TinyTransformer
|
|
from model.brain_state import model, tokenizer, DEVICE, optimizer
|
|
import copy
|
|
|
|
|
|
def _copy_overlapping_state(old_state, new_model):
    """Copy every tensor of ``old_state`` into ``new_model`` wherever it fits.

    How each entry is handled:

    * If the shape matches exactly, the entry is copied whole.
    * If the old shape fits inside the new one (same rank, no dimension
      larger), it is copied into the leading slice of the new tensor. For
      example, the existing rows of a vocabulary-sized weight survive a
      vocab expansion, while the appended rows keep their fresh
      initialisation.
    * Anything else is reported and left as ``new_model`` initialised it.
    * Names missing from ``new_model`` are ignored.

    Args:
        old_state: A ``state_dict()`` of the module being replaced. It is only
            read, never modified.
        new_model: The freshly built module to copy into. It is modified in
            place.
    """
    # With keep_vars=False, state_dict() returns detached tensors that share
    # storage with the live parameters/buffers. An in-place copy_ into them
    # therefore updates new_model itself.
    new_state = new_model.state_dict()
    with torch.no_grad():
        for name, old_tensor in old_state.items():
            new_tensor = new_state.get(name)
            # Skip names the new model lacks, and any non-tensor extra state.
            if not isinstance(old_tensor, torch.Tensor) or not isinstance(
                new_tensor, torch.Tensor
            ):
                continue
            try:
                if old_tensor.shape == new_tensor.shape:
                    new_tensor.copy_(old_tensor)
                elif old_tensor.dim() == new_tensor.dim() and all(
                    old_size <= new_size
                    for old_size, new_size in zip(old_tensor.shape, new_tensor.shape)
                ):
                    # Grown tensor: keep the old values in the leading slice.
                    region = tuple(slice(0, size) for size in old_tensor.shape)
                    new_tensor[region].copy_(old_tensor)
                    print(f"Partially copied resized param: {name}")
                else:
                    print(f"Skipping mismatched param: {name}")
            except RuntimeError as e:
                # Best effort: report the failure and keep copying the rest.
                print(f"Error copying param: {name} — {e}")


def expand_model_if_needed(vocab_buffer=10, lr=1e-4):
    """Rebuild the shared model with a larger vocabulary when the tokenizer outgrows it.

    The target size is ``len(tokenizer.vocab) + vocab_buffer``. If the current
    head already has at least that many outputs, nothing happens. Otherwise:

    1. A new ``TinyTransformer`` is built on ``DEVICE``.
    2. Every weight and buffer of the old model that fits is copied over.
       Weights along a grown dimension keep their leading slice.
    3. A fresh Adam optimizer is created for the new parameters.
    4. The new model and optimizer are published both on the
       ``model.brain_state`` module and in this module's globals.

    Args:
        vocab_buffer: Extra output slots reserved beyond the current vocabulary
            size, so that every new token does not trigger a rebuild.
        lr: Learning rate for the newly created Adam optimizer.

    Returns:
        None.
    """
    # Import the shared-state module itself rather than its attributes. That
    # way the swap below is seen by every module that reads brain_state.model,
    # not only by this one.
    from model import brain_state as _brain_state

    current_model = _brain_state.model

    current_vocab_size = len(tokenizer.vocab) + vocab_buffer  # Tiny buffer
    old_vocab_size = current_model.head.out_features

    if current_vocab_size <= old_vocab_size:
        return  # No expansion needed

    print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")

    # Only vocab_size is passed, so every other hyperparameter uses the
    # TinyTransformer defaults.
    # NOTE(review): assumes the live model was built with those same
    # defaults. Otherwise the non-vocab weights are reported as mismatched
    # and keep their fresh initialisation.
    new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
    new_model.train(current_model.training)  # preserve train/eval mode

    # The live state_dict is only read, so no deepcopy of the old model is
    # needed.
    _copy_overlapping_state(current_model.state_dict(), new_model)

    # A fresh optimizer: Adam moment estimates from the old optimizer are
    # not carried over.
    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=lr)

    # Publish on the shared-state module, and keep this module's own names in
    # sync for any code that reads them.
    _brain_state.model = new_model
    _brain_state.optimizer = new_optimizer
    globals()["model"] = new_model
    globals()["optimizer"] = new_optimizer

    print("Expansion complete.")
|