Ruby/model/dynamic_expand.py

import threading

import torch

from model.brain_architecture import TinyTransformer
from model.brain_state import model, tokenizer, DEVICE

# Rebuilt whenever the model is expanded; callers should always go through get_optimizer().
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

_last_vocab_size = 0
_expand_lock = threading.Lock()


def get_optimizer():
    """Return the optimizer bound to the current model instance."""
    return optimizer


def expand_model_if_needed():
    """Grow the transformer's output vocabulary once enough new tokens have accumulated."""
    global model, optimizer, _last_vocab_size
    with _expand_lock:
        # Leave a little headroom beyond the tokenizer's current vocabulary.
        current_vocab_size = len(tokenizer.vocab) + 10
        if current_vocab_size - _last_vocab_size < 10:
            return  # Expand only after 10 new words collected
        old_vocab_size = model.head.out_features
        if current_vocab_size <= old_vocab_size:
            return
        # print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")

        # Build a larger model and copy every parameter whose shape is unchanged;
        # the enlarged embedding and output head keep their fresh initialization.
        old_state = model.state_dict()
        new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
        with torch.no_grad():
            for name, param in new_model.named_parameters():
                if name in old_state and old_state[name].shape == param.shape:
                    param.copy_(old_state[name])

        # Swap in the new model (rebinds this module's reference) and pair it
        # with a fresh optimizer so stale parameter groups are not reused.
        model = new_model
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        _last_vocab_size = current_vocab_size
        # print("Expansion complete.")