Ruby/model/dynamic_expand.py

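"""Grow the model's output head in place when the tokenizer mints new ids.

(Summary added for orientation.) expand_model_if_needed() compares
tokenizer.next_id against the width of model.head and, under a lock, swaps
in a wider torch.nn.Linear that keeps the existing rows' weights, then
registers the new weight with the shared optimizer and checkpoints via
save_model().
"""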
import threading
import time

import torch

from model.tokenizer import Tokenizer
from model.brain_state import save_model, DEVICE, model, optimizer

tokenizer = Tokenizer()
expand_lock = threading.Lock()  # serializes concurrent expansion attempts
_last_expansion_time = 0        # wall-clock time of the last expansion
def expand_model_if_needed():
    global _last_expansion_time
    with expand_lock:
        # Check whether expansion is actually needed.
        needed_vocab_size = tokenizer.next_id
        current_vocab_size = model.head.out_features
        if needed_vocab_size <= current_vocab_size:
            return None  # ✅ No expansion needed
        # print(f"[Expand] Expanding vocabulary: {current_vocab_size} -> {needed_vocab_size}")
        # Expand the head layer in place instead of rebuilding the whole model.
        old_head_weight = model.head.weight.data
        old_out_features = old_head_weight.size(0)
        in_features = model.head.in_features
        new_head = torch.nn.Linear(in_features, needed_vocab_size, bias=False).to(DEVICE)
        # Copy the old rows into the new head.
        with torch.no_grad():
            new_head.weight[:old_out_features] = old_head_weight
        model.head = new_head
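        # Note: rows at index >= old_out_features were never overwritten, so
        # weights for the newly minted token ids keep nn.Linear's default
        # (Kaiming-uniform) initialization.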
        # Register the new head's weight with the existing optimizer so it is
        # actually trained; the orphaned old weight left in its original param
        # group wastes a little state but is harmless.
        optimizer.add_param_group({"params": [model.head.weight]})
        # Recreate the scheduler: StepLR caches base_lrs at construction, so a
        # scheduler built before the new param group was added would never
        # update the new group's learning rate. Hand it back to the caller.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.95)
        _last_expansion_time = time.time()
        save_model()
        return scheduler
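
# --- Usage sketch (hypothetical caller; the trainer loop below is an
# --- assumption, not part of this module) ------------------------------------
#
#   from model.dynamic_expand import expand_model_if_needed
#
#   # After any step that may have minted new token ids:
#   scheduler = expand_model_if_needed() or scheduler
#
# Calling this every step is cheap: the lock is uncontended in the common
# case and the size check returns immediately when no expansion is needed.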