import torch
import threading
import time

from model.tokenizer import Tokenizer
from model.brain_state import save_model, DEVICE, model, optimizer

tokenizer = Tokenizer()
expand_lock = threading.Lock()
_last_expansion_time = 0

def expand_model_if_needed():
    global _last_expansion_time

    with expand_lock:
        # Check if expansion is actually needed
        needed_vocab_size = tokenizer.next_id
        current_vocab_size = model.head.out_features

        if needed_vocab_size <= current_vocab_size:
            return  # ✅ No expansion needed

# print(f"[Expand] Expanding vocabulary: {current_vocab_size} -> {needed_vocab_size}")
|
|
|
|
# Expand the head layer safely without rebuilding everything
|
|
old_head_weight = model.head.weight.data
|
|
old_out_features = old_head_weight.size(0)
|
|
in_features = model.head.in_features
|
|
|
|
new_head = torch.nn.Linear(in_features, needed_vocab_size, bias=False)
|
|
new_head = new_head.to(DEVICE)
|
|
|
|
# Copy old weights into the new head
|
|
with torch.no_grad():
|
|
new_head.weight[:old_out_features] = old_head_weight
|
|
|
|
model.head = new_head
|
|
|
|
        # Register the new head's parameters with the existing optimizer so
        # optimizer.step() actually updates them; the old head's parameters
        # remain in the optimizer's original group but are no longer part
        # of the model
        optimizer.add_param_group({"params": model.head.parameters()})

        # Rebuild the LR scheduler around the updated optimizer. Note that
        # this scheduler is local to the function: the training loop must
        # store and step it for the decay to take effect
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.95)

        _last_expansion_time = time.time()
        save_model()