import threading
import time

import torch

from model.brain_architecture import TinyTransformer
from model import brain_state
from model.brain_state import model, tokenizer, DEVICE

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
expand_lock = threading.Lock()
_last_expansion_time = 0


def get_optimizer():
    """Return the current optimizer; expansion swaps in a fresh instance."""
    return optimizer


def expand_model_if_needed():
    """Grow the model when the tokenizer vocabulary outpaces its output head.

    Builds a new TinyTransformer sized for the current vocabulary (plus 10
    ids of headroom), copies every parameter whose shape is unchanged, and
    replaces the optimizer, whose moment estimates are tied to the old
    parameter tensors.
    """
    global model, optimizer, _last_expansion_time
    with expand_lock:
        current_vocab_size = len(tokenizer.vocab) + 10  # headroom for new tokens
        old_vocab_size = model.head.out_features
        if current_vocab_size <= old_vocab_size:
            return

        # print(f"[Expand] Expanding model from {old_vocab_size} -> {current_vocab_size}")
        old_state = model.state_dict()
        new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)

        # Copy every weight that kept its shape; resized tensors (e.g. the
        # embedding and output head) keep their fresh initialization.
        with torch.no_grad():
            for name, param in new_model.named_parameters():
                if name in old_state and old_state[name].shape == param.shape:
                    param.copy_(old_state[name])

        model = new_model
        # Also rebind through the module object so code reading
        # brain_state.model sees the expanded network; `global model` alone
        # only updates this module's copy of the reference.
        brain_state.model = new_model

        # Fresh Adam over the new parameters; note the lower learning rate
        # (1e-4 vs. the initial 3e-4).
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        _last_expansion_time = time.time()
        # print("[Expand] Expansion complete.")
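

# Usage sketch (illustrative, not required by the module): a training loop
# would call expand_model_if_needed() whenever the tokenizer may have learned
# new tokens, then re-fetch the optimizer, since expansion replaces both the
# model and the optimizer. The smoke test below only touches attributes this
# module already uses (model.head.out_features, tokenizer.vocab).
if __name__ == "__main__":
    before = model.head.out_features
    expand_model_if_needed()
    after = brain_state.model.head.out_features  # re-read through brain_state
    print(f"[Expand] head out_features: {before} -> {after} "
          f"(vocab size: {len(tokenizer.vocab)}, "
          f"optimizer: {type(get_optimizer()).__name__})")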