Fixed vocab issues

parent: b876dede9e
commit: e72678b242

.gitignore (vendored): 3
@@ -172,4 +172,5 @@ cython_debug/
 /data/books/alice_in_wonderland.txt
 /data/books/wizard_of_oz.txt
 /data/memory/context.json
 /data/memory/dreams.json
+/data/memory/vocab.json
@@ -3,13 +3,13 @@ import torch
 import torch.nn.functional as F
 from model.memory import save_dream
 from model.brain_state import model, tokenizer, DEVICE
+from context.context import get_recent_context

 recent_dreams = []


 def generate_response():
     model.eval()
-    # Pick a real known word to seed from context memory
     context_texts = get_recent_context(5)
     if context_texts:
         start = random.choice(context_texts)
@@ -22,12 +22,13 @@ def generate_response():
     seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)

     output = model(seed)
-    pred = torch.argmax(output, dim=-1).squeeze().tolist()
+    pred = torch.argmax(output, dim=-1).squeeze().item()

-    if not isinstance(pred, list):
-        pred = [pred]
+    # Clamp prediction into known vocab range
+    if pred >= tokenizer.next_id:
+        pred = random.randint(0, tokenizer.next_id - 1)

-    return tokenizer.detokenize(pred)
+    return tokenizer.detokenize([pred])


 def score_sentence(sentence: str) -> float:
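Worth spelling out why the clamp is needed: brain_state (next hunk) sizes the output head with a small buffer beyond the ids the tokenizer has actually assigned, so argmax can land on a row that maps to no word yet. A minimal standalone sketch of the same guard, assuming only the tokenizer.next_id counter used above (clamp_to_known_vocab is a hypothetical helper name, not part of this commit):

import random

def clamp_to_known_vocab(pred_id: int, next_id: int) -> int:
    # Rows at or above next_id exist in the model head (buffer slots)
    # but have no tokenizer entry yet; resample a known id instead.
    if pred_id >= next_id:
        return random.randint(0, next_id - 1)
    return pred_id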
@@ -6,7 +6,7 @@ from model.tokenizer import Tokenizer
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 tokenizer = Tokenizer()
-VOCAB_SIZE = 10000  # Expandable if needed
+VOCAB_SIZE = len(tokenizer.vocab) + 10  # Slight buffer

 model = TinyTransformer(vocab_size=VOCAB_SIZE).to(DEVICE)
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
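Sizing the head from len(tokenizer.vocab) + 10 leaves a few slots of slack, and the new model/dynamic_expand.py below compares the same quantity against model.head.out_features to decide when the network must grow. A hedged restatement of that trigger, using only names that appear in this commit:

needed = len(tokenizer.vocab) + 10   # same formula as VOCAB_SIZE above
have = model.head.out_features       # rows currently in the output head
needs_expansion = needed > have      # mirrors the check in expand_model_if_needed()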
model/dynamic_expand.py (new file): 41
@@ -0,0 +1,41 @@
+import torch
+from model.brain_architecture import TinyTransformer
+from model.brain_state import model, tokenizer, DEVICE, optimizer
+import copy
+
+
+def expand_model_if_needed():
+    current_vocab_size = len(tokenizer.vocab) + 10  # Tiny buffer
+    old_vocab_size = model.head.out_features
+
+    if current_vocab_size <= old_vocab_size:
+        return  # No expansion needed
+
+    print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
+
+    # Save old model
+    old_model = copy.deepcopy(model).to('cpu')
+
+    # Create new model
+    new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
+    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-4)
+
+    # Copy parameters
+    with torch.no_grad():
+        for name, param in old_model.named_parameters():
+            if name in dict(new_model.named_parameters()):
+                try:
+                    new_param = dict(new_model.named_parameters())[name]
+                    if param.shape == new_param.shape:
+                        new_param.copy_(param)
+                    else:
+                        print(f"Skipping mismatched param: {name}")
+                except Exception as e:
+                    print(f"Error copying param: {name} — {e}")
+
+    # Replace global references
+    globals()["model"] = new_model
+    globals()["optimizer"] = new_optimizer
+
+    print("Expansion complete.")
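One caveat in the file above: globals()["model"] rebinds the name only inside dynamic_expand's own module namespace, so code that ran from-imports against model.brain_state still holds the old object. A hedged sketch of one alternative, rebinding the attributes on the owning module instead (swap_into_brain_state is a hypothetical helper illustrating intent, not what this commit does):

import model.brain_state as brain_state

def swap_into_brain_state(new_model, new_optimizer):
    # Rebind on the module that owns the names; call sites reading
    # brain_state.model then see the new objects. Plain from-imports
    # elsewhere still hold stale references and must re-read.
    brain_state.model = new_model
    brain_state.optimizer = new_optimizer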
@@ -2,6 +2,7 @@ import torch
 import time
 from model.brain_state import model, tokenizer, DEVICE, optimizer, loss_fn
 from context.context import add_to_context, get_recent_context
+from model.dynamic_expand import expand_model_if_needed

 LOSS_FILE = "data/logs/loss.log"
@@ -12,6 +13,7 @@ def log_loss(value: float):


 def train_on_message(text: str):
+    expand_model_if_needed()
     model.train()
     context_texts = get_recent_context(3)
     augmented_text = " ".join(context_texts + [text])
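Calling expand_model_if_needed() at the top of train_on_message() keeps the head at least as large as the known vocabulary before any forward pass; the +10 buffer then absorbs the handful of new ids a single message can mint. A sketch of that ordering, with tokenizer.tokenize assumed as the counterpart of the detokenize call seen earlier:

def train_on_message(text: str):
    expand_model_if_needed()           # grow head/embeddings first if vocab grew
    model.train()
    tokens = tokenizer.tokenize(text)  # assumed API; may mint a few new ids
    # ...training step runs with every target id inside the head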