Corrected a possible 'posioning' of the model

This commit is contained in:
Dani 2025-04-27 14:36:38 -04:00
parent d456c833bf
commit 97b43f832b
4 changed files with 34 additions and 43 deletions

View File

@ -123,7 +123,7 @@ def journal():
@app.route("/concepts") @app.route("/concepts")
def concepts(): def concepts():
clusters = cluster_vocab(n_clusters=10) clusters = cluster_vocab(n_clusters=10)
return render_template("concepts.html", clusters=clusters) return render_template("concepts.html", clusters={i: cluster for i, cluster in enumerate(clusters)})
@app.route("/dreams") @app.route("/dreams")

View File

@ -11,34 +11,33 @@ recent_dreams = []
@torch.inference_mode() @torch.inference_mode()
def generate_response(): def generate_response(max_tokens: int = 50):
model.eval() model.eval()
input_ids = torch.tensor([tokenizer.token_to_id("<start>")], device=DEVICE).unsqueeze(0)
# Start from an empty input: purely organic thought generated = []
input_ids = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
output_tokens = [] for _ in range(max_tokens):
max_length = 50
for _ in range(max_length):
output = model(input_ids) output = model(input_ids)
if torch.isnan(output).any():
print("[Brain] Detected NaN in output, restarting generation.")
return "..."
next_token_logits = output[:, -1, :] next_token_logits = output[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1) next_token = torch.argmax(next_token_logits, dim=-1)
# Get token id values for special tokens token_id = next_token.item()
pad_token_id = tokenizer.vocab.get("<pad>", None)
unk_token_id = tokenizer.vocab.get("<unk>", None)
# Stop if the model predicts <pad> or <unk> # If she outputs <end> token, stop generation
if pad_token_id is not None and next_token.item() == pad_token_id: if tokenizer.reverse_vocab.get(token_id, "") == "<end>":
break
if unk_token_id is not None and next_token.item() == unk_token_id:
break break
output_tokens.append(next_token.item()) generated.append(token_id)
next_token = next_token.unsqueeze(0)
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1) input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
return tokenizer.detokenize(output_tokens) return tokenizer.detokenize(generated)
def score_sentence(sentence: str) -> float: def score_sentence(sentence: str) -> float:

View File

@ -19,9 +19,9 @@ def save_vocab(vocab):
class Tokenizer: class Tokenizer:
def __init__(self): def __init__(self):
self.vocab = load_vocab() self.vocab = {"<pad>": 0, "<unk>": 1, "<start>": 2, "<end>": 3}
self.reverse_vocab = {v: k for k, v in self.vocab.items()} self.reverse_vocab = {0: "<pad>", 1: "<unk>", 2: "<start>", 3: "<end>"}
self.next_id = max(self.vocab.values(), default=0) + 1 self.next_id = 4
def tokenize(self, text): def tokenize(self, text):
words = re.findall(r"\b\w+\b", text.lower()) words = re.findall(r"\b\w+\b", text.lower())

View File

@ -2,6 +2,7 @@ import torch
import time import time
from model.dynamic_expand import expand_model_if_needed, _last_expansion_time, get_optimizer, expand_lock from model.dynamic_expand import expand_model_if_needed, _last_expansion_time, get_optimizer, expand_lock
from model.brain_state import model, tokenizer, DEVICE, loss_fn from model.brain_state import model, tokenizer, DEVICE, loss_fn
from model.brainmap import update_brainmap
from context.context import add_to_context, get_recent_context from context.context import add_to_context, get_recent_context
LOSS_FILE = "data/logs/loss.log" LOSS_FILE = "data/logs/loss.log"
@ -34,38 +35,32 @@ def train_on_message(text: str, source: str = "user"):
try: try:
model.train() model.train()
context_texts = get_recent_context(10) context_texts = get_recent_context(10)
augmented_text = " ".join(context_texts + [text])
# Here's the important change:
augmented_text = "<start> " + " ".join(context_texts + [text]) + " <end>"
tokens = tokenizer.tokenize(augmented_text) tokens = tokenizer.tokenize(augmented_text)
if not tokens or len(tokens) < 2: if len(tokens) < 2:
return return
max_token_id = model.head.out_features - 1 max_token_id = model.head.out_features - 1
tokens = [t if t <= max_token_id else max_token_id for t in tokens]
tokens = tokens[:128]
# Clamp each token to be inside model's head size if len(tokens) < 2:
clamped_tokens = []
for token in tokens:
if token > max_token_id:
clamped_tokens.append(max_token_id)
elif token < 0:
clamped_tokens.append(0)
else:
clamped_tokens.append(token)
# Clamp sequence length
clamped_tokens = clamped_tokens[:128]
if len(clamped_tokens) < 2:
return return
input_tensor = torch.tensor(clamped_tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0) input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
target_tensor = torch.tensor(clamped_tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0) target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
opt = get_optimizer() opt = get_optimizer()
output = model(input_tensor) output = model(input_tensor)
loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1)) loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))
if torch.isnan(loss):
print("[Trainer] Detected NaN loss, skipping update.")
return
opt.zero_grad() opt.zero_grad()
loss.backward() loss.backward()
@ -74,9 +69,6 @@ def train_on_message(text: str, source: str = "user"):
log_loss(loss.item()) log_loss(loss.item())
log_vocab_growth() log_vocab_growth()
add_to_context(text, source=source) add_to_context(text, source=source)
update_brainmap(augmented_text.split())
except Exception as e:
print(f"[Train] Exception during training: {e}")
finally: finally:
expand_lock.release() expand_lock.release()