import asyncio
import time

import torch

from model.tokenizer import Tokenizer
from model.brain_state import save_model, DEVICE, model, optimizer

tokenizer = Tokenizer()

# Serializes concurrent expansion attempts so the head is only resized once
# per vocabulary growth, even if several coroutines notice the need at once.
expand_lock = asyncio.Lock()
_last_expansion_time = 0


async def expand_model_if_needed():
    """Grow ``model.head`` so its output dimension covers the tokenizer vocabulary.

    Compares the tokenizer's next unassigned id against the current output
    layer size; if new tokens have appeared, replaces the head with a larger
    ``nn.Linear`` (bias-free, matching the original), copying the old weight
    rows into the new layer so existing token logits are preserved. The new
    rows keep PyTorch's default initialization. Persists the model afterwards
    via ``save_model()``.

    Side effects:
        - Rebinds ``model.head`` and registers the new parameters with the
          shared ``optimizer``.
        - Updates the module-level ``_last_expansion_time`` timestamp.
        - Calls ``save_model()`` on every actual expansion.
    """
    global _last_expansion_time
    async with expand_lock:
        # next_id is the first unassigned token id, i.e. the required
        # number of output logits.
        needed_vocab_size = tokenizer.next_id
        current_vocab_size = model.head.out_features
        if needed_vocab_size <= current_vocab_size:
            return  # head already large enough; nothing to do

        print(f"[Expand] Expanding vocabulary: {current_vocab_size} -> {needed_vocab_size}")

        old_head_weight = model.head.weight.data
        old_out_features = old_head_weight.size(0)
        in_features = model.head.in_features

        # Build the larger head and copy the old rows in-place so previously
        # learned token embeddings/logits are unchanged.
        new_head = torch.nn.Linear(in_features, needed_vocab_size, bias=False).to(DEVICE)
        with torch.no_grad():
            new_head.weight[:old_out_features] = old_head_weight
        model.head = new_head

        # BUG FIX 1: the original code constructed a StepLR scheduler here and
        # immediately discarded it — it was never stepped or referenced, so it
        # had no effect. The dead call has been removed.
        #
        # BUG FIX 2: the optimizer's param groups still point at the OLD head
        # tensor after the swap, so the new head would never be trained.
        # Register the replacement parameters explicitly.
        optimizer.add_param_group({"params": new_head.parameters()})
        # NOTE(review): the old head's weight remains in the optimizer's
        # original param groups. It no longer receives gradients, so this is
        # harmless but leaks a little state per expansion; consider rebuilding
        # the optimizer if expansions happen frequently.

        _last_expansion_time = time.time()
        save_model()