Corrected a possible 'poisoning' of the model
parent d456c833bf
commit 97b43f832b
@@ -123,7 +123,7 @@ def journal():
 @app.route("/concepts")
 def concepts():
     clusters = cluster_vocab(n_clusters=10)
-    return render_template("concepts.html", clusters=clusters)
+    return render_template("concepts.html", clusters={i: cluster for i, cluster in enumerate(clusters)})


 @app.route("/dreams")
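The fix hands the template a dict keyed by cluster index instead of a bare list, so concepts.html can address each cluster by a stable integer id. A minimal sketch of the conversion, assuming cluster_vocab returns a plain list:

clusters = ["dream sleep night", "food eat hungry"]  # stand-in for cluster_vocab(n_clusters=10) output
clusters_by_id = {i: cluster for i, cluster in enumerate(clusters)}
# {0: 'dream sleep night', 1: 'food eat hungry'} -- integer keys the template can rely on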
@@ -11,34 +11,33 @@ recent_dreams = []


 @torch.inference_mode()
-def generate_response():
+def generate_response(max_tokens: int = 50):
     model.eval()
-    input_ids = torch.tensor([tokenizer.token_to_id("<start>")], device=DEVICE).unsqueeze(0)
-    output_tokens = []
-
-    max_length = 50
-
-    for _ in range(max_length):
+
+    # Start from an empty input: purely organic thought
+    input_ids = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
+    generated = []
+
+    for _ in range(max_tokens):
         output = model(input_ids)
         if torch.isnan(output).any():
             print("[Brain] Detected NaN in output, restarting generation.")
             return "..."

         next_token_logits = output[:, -1, :]
         next_token = torch.argmax(next_token_logits, dim=-1)
+        token_id = next_token.item()

-        # Get token id values for special tokens
-        pad_token_id = tokenizer.vocab.get("<pad>", None)
-        unk_token_id = tokenizer.vocab.get("<unk>", None)
-
-        # Stop if the model predicts <pad> or <unk>
-        if pad_token_id is not None and next_token.item() == pad_token_id:
-            break
-        if unk_token_id is not None and next_token.item() == unk_token_id:
+        # If she outputs <end> token, stop generation
+        if tokenizer.reverse_vocab.get(token_id, "") == "<end>":
             break

-        output_tokens.append(next_token.item())
+        generated.append(token_id)

-        next_token = next_token.unsqueeze(0)
         input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

-    return tokenizer.detokenize(output_tokens)
+    return tokenizer.detokenize(generated)


 def score_sentence(sentence: str) -> float:
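Applied, the rewritten generator reads as below. This is assembled from the added lines of the hunk; the module-level model, tokenizer, and DEVICE are assumed to come from model.brain_state, mirroring the trainer's imports.

import torch

from model.brain_state import model, tokenizer, DEVICE  # assumed import, as in the trainer


@torch.inference_mode()
def generate_response(max_tokens: int = 50):
    model.eval()

    # Start from an empty input: purely organic thought
    input_ids = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
    generated = []

    for _ in range(max_tokens):
        output = model(input_ids)
        if torch.isnan(output).any():
            print("[Brain] Detected NaN in output, restarting generation.")
            return "..."

        # Greedy decode: take the highest-scoring token at the last position
        next_token = torch.argmax(output[:, -1, :], dim=-1)
        token_id = next_token.item()

        # Stop as soon as the model emits the <end> marker
        if tokenizer.reverse_vocab.get(token_id, "") == "<end>":
            break

        generated.append(token_id)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

    return tokenizer.detokenize(generated)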
@@ -19,9 +19,9 @@ def save_vocab(vocab):


 class Tokenizer:
     def __init__(self):
-        self.vocab = load_vocab()
-        self.reverse_vocab = {v: k for k, v in self.vocab.items()}
-        self.next_id = max(self.vocab.values(), default=0) + 1
+        self.vocab = {"<pad>": 0, "<unk>": 1, "<start>": 2, "<end>": 3}
+        self.reverse_vocab = {0: "<pad>", 1: "<unk>", 2: "<start>", 3: "<end>"}
+        self.next_id = 4

     def tokenize(self, text):
         words = re.findall(r"\b\w+\b", text.lower())
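Seeding the four special ids at construction gives generation a reverse map it can trust when checking for <end>. Two details worth noting, both checkable against the lines above (only the Tokenizer import path below is assumed):

from model.tokenizer import Tokenizer  # assumed module path
import re

tok = Tokenizer()
print(tok.vocab["<start>"])   # 2
print(tok.reverse_vocab[3])   # <end>

# The \b\w+\b word pattern in tokenize() drops punctuation, so angle-bracket
# markers embedded in raw text lose their brackets before any vocab lookup:
print(re.findall(r"\b\w+\b", "<start> hello world <end>".lower()))
# ['start', 'hello', 'world', 'end']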
@@ -2,6 +2,7 @@ import torch
 import time
 from model.dynamic_expand import expand_model_if_needed, _last_expansion_time, get_optimizer, expand_lock
 from model.brain_state import model, tokenizer, DEVICE, loss_fn
+from model.brainmap import update_brainmap
 from context.context import add_to_context, get_recent_context

 LOSS_FILE = "data/logs/loss.log"
@@ -34,38 +35,32 @@ def train_on_message(text: str, source: str = "user"):
     try:
         model.train()
         context_texts = get_recent_context(10)
-        augmented_text = " ".join(context_texts + [text])
+
+        # Here's the important change:
+        augmented_text = "<start> " + " ".join(context_texts + [text]) + " <end>"
+
         tokens = tokenizer.tokenize(augmented_text)

-        if not tokens or len(tokens) < 2:
+        if len(tokens) < 2:
             return

         max_token_id = model.head.out_features - 1
-
-        # Clamp each token to be inside model's head size
-        clamped_tokens = []
-        for token in tokens:
-            if token > max_token_id:
-                clamped_tokens.append(max_token_id)
-            elif token < 0:
-                clamped_tokens.append(0)
-            else:
-                clamped_tokens.append(token)
-
-        # Clamp sequence length
-        clamped_tokens = clamped_tokens[:128]
-
-        if len(clamped_tokens) < 2:
-            return
-
-        input_tensor = torch.tensor(clamped_tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
-        target_tensor = torch.tensor(clamped_tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
+        tokens = [t if t <= max_token_id else max_token_id for t in tokens]
+        tokens = tokens[:128]
+
+        if len(tokens) < 2:
+            return
+
+        input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
+        target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)

         opt = get_optimizer()

         output = model(input_tensor)
         loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))
         if torch.isnan(loss):
             print("[Trainer] Detected NaN loss, skipping update.")
             return

         opt.zero_grad()
         loss.backward()
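The input/target tensors implement standard next-token training: the input is the sequence minus its final token, the target is the same sequence shifted left by one, and the loss flattens both before cross-entropy. A self-contained toy illustration (vocab size and logits are made up; only the shapes matter):

import torch
import torch.nn.functional as F

tokens = [2, 17, 5, 9, 3]                       # e.g. <start> ... <end>
inp = torch.tensor(tokens[:-1]).unsqueeze(0)    # shape [1, 4]
tgt = torch.tensor(tokens[1:]).unsqueeze(0)     # shape [1, 4]: each input's next token

vocab_size = 32
logits = torch.randn(1, inp.size(1), vocab_size)  # stand-in for model(inp)

# Same flattening as loss_fn(output.view(-1, output.size(-1)), target.view(-1))
loss = F.cross_entropy(logits.view(-1, vocab_size), tgt.view(-1))
print(loss.item())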
@@ -74,9 +69,6 @@ def train_on_message(text: str, source: str = "user"):
         log_loss(loss.item())
-        log_vocab_growth()
         add_to_context(text, source=source)
-
-    except Exception as e:
-        print(f"[Train] Exception during training: {e}")
+        update_brainmap(augmented_text.split())
     finally:
         expand_lock.release()
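The release in the finally block implies the lock is acquired before the try, outside this hunk; the exact acquire call is not shown, so the sketch below assumes a plain threading.Lock with a timeout, only to illustrate the guard shape:

from threading import Lock

expand_lock = Lock()  # in the repo this comes from model.dynamic_expand

def train_on_message(text: str, source: str = "user"):
    if not expand_lock.acquire(timeout=0.5):  # assumed form of the acquire site
        return
    try:
        pass  # tokenize, clamp, forward/backward, logging -- as in the hunks above
    finally:
        expand_lock.release()  # always freed, even if a step raises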