Fixed vocab issues

Dani 2025-04-25 23:00:04 -04:00
parent b876dede9e
commit e72678b242
5 changed files with 52 additions and 7 deletions

.gitignore vendored

@@ -173,3 +173,4 @@ cython_debug/
 /data/books/wizard_of_oz.txt
 /data/memory/context.json
 /data/memory/dreams.json
+/data/memory/vocab.json

@@ -3,13 +3,13 @@ import torch
 import torch.nn.functional as F
 from model.memory import save_dream
 from model.brain_state import model, tokenizer, DEVICE
+from context.context import get_recent_context

 recent_dreams = []

 def generate_response():
     model.eval()
+    # Pick a real known word to seed from context memory
     context_texts = get_recent_context(5)
     if context_texts:
         start = random.choice(context_texts)
@@ -22,12 +22,13 @@ def generate_response():
     seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
     output = model(seed)
-    pred = torch.argmax(output, dim=-1).squeeze().tolist()
-    if not isinstance(pred, list):
-        pred = [pred]
-    return tokenizer.detokenize(pred)
+    pred = torch.argmax(output, dim=-1).squeeze().item()
+    # Clamp prediction into known vocab range
+    if pred >= tokenizer.next_id:
+        pred = random.randint(0, tokenizer.next_id - 1)
+    return tokenizer.detokenize([pred])

 def score_sentence(sentence: str) -> float:
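Note on the clamp above: the output head is sized with a small buffer beyond the known vocabulary (see the brain_state hunk below), so argmax can land on an id the tokenizer has never assigned. A minimal sketch of the failure and the fix, assuming tokenizer.next_id is the first unassigned id and detokenize() does a dict lookup (TinyTokenizer here is hypothetical, not part of the repo):

import random

class TinyTokenizer:
    def __init__(self):
        self.id_to_word = {0: "hello", 1: "world"}
        self.next_id = 2  # ids >= 2 exist in the head but have no word yet

    def detokenize(self, ids):
        return " ".join(self.id_to_word[i] for i in ids)

tok = TinyTokenizer()
pred = 7  # argmax picked a buffer logit with no vocab entry
if pred >= tok.next_id:  # same clamp as generate_response()
    pred = random.randint(0, tok.next_id - 1)
print(tok.detokenize([pred]))  # always a known word; no KeyError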

@@ -6,7 +6,7 @@ from model.tokenizer import Tokenizer
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 tokenizer = Tokenizer()
-VOCAB_SIZE = 10000  # Expandable if needed
+VOCAB_SIZE = len(tokenizer.vocab) + 10  # Slight buffer

 model = TinyTransformer(vocab_size=VOCAB_SIZE).to(DEVICE)
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
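The +10 buffer lets a few new tokens be learned before the head runs out of logits; expand_model_if_needed() (new file below) rebuilds the model once the vocabulary outgrows it. A sketch of that trigger condition, assuming model.head is the output projection as in dynamic_expand.py:

def needs_expansion(vocab_len: int, head_out_features: int, buffer: int = 10) -> bool:
    # Expansion fires once vocab size plus buffer exceeds the head width
    return vocab_len + buffer > head_out_features

print(needs_expansion(80, 100))  # False: the buffer still has room
print(needs_expansion(95, 100))  # True: time to rebuild the model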

model/dynamic_expand.py Normal file

@@ -0,0 +1,41 @@
+import torch
+from model.brain_architecture import TinyTransformer
+from model.brain_state import model, tokenizer, DEVICE, optimizer
+import copy
+
+
+def expand_model_if_needed():
+    current_vocab_size = len(tokenizer.vocab) + 10  # Tiny buffer
+    old_vocab_size = model.head.out_features
+
+    if current_vocab_size <= old_vocab_size:
+        return  # No expansion needed
+
+    print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
+
+    # Save old model
+    old_model = copy.deepcopy(model).to("cpu")
+
+    # Create new model
+    new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
+    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-4)
+
+    # Copy parameters whose shapes match; the resized head is skipped
+    new_params = dict(new_model.named_parameters())
+    with torch.no_grad():
+        for name, param in old_model.named_parameters():
+            if name in new_params:
+                try:
+                    new_param = new_params[name]
+                    if param.shape == new_param.shape:
+                        new_param.copy_(param)
+                    else:
+                        print(f"Skipping mismatched param: {name}")
+                except Exception as e:
+                    print(f"Error copying param: {name}: {e}")
+
+    # Replace global references
+    globals()["model"] = new_model
+    globals()["optimizer"] = new_optimizer
+
+    print("Expansion complete.")

@@ -2,6 +2,7 @@ import torch
 import time
 from model.brain_state import model, tokenizer, DEVICE, optimizer, loss_fn
 from context.context import add_to_context, get_recent_context
+from model.dynamic_expand import expand_model_if_needed

 LOSS_FILE = "data/logs/loss.log"

@@ -12,6 +13,7 @@ def log_loss(value: float):

 def train_on_message(text: str):
+    expand_model_if_needed()
     model.train()
     context_texts = get_recent_context(3)
     augmented_text = " ".join(context_texts + [text])
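With the hook in place, every trained message runs the expansion check before the forward pass. Since the check happens before the incoming text is tokenized (presumably further down in this function, not shown in the hunk), the +10 buffer has to absorb any brand-new words until the next call. A usage sketch, assuming this file is imported as model.trainer (the module path is not shown in the diff):

from model.trainer import train_on_message  # path assumed

train_on_message("the wizard of oz")              # may add new vocab ids
train_on_message("follow the yellow brick road")  # expansion fires if the buffer filled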