Fixing another string of CUDA errors

This commit is contained in:
Dani 2025-04-27 13:42:29 -04:00
parent 684bf33675
commit 99fddcab4d
2 changed files with 24 additions and 8 deletions

View File

@ -1,4 +1,5 @@
import random
import asyncio
from context.context import load_context
from model.trainer import train_on_message
from model.dynamic_expand import expand_model_if_needed

View File

@ -23,7 +23,7 @@ def train_on_message(text: str, source: str = "user"):
expand_model_if_needed()
now = time.time()
if now - _last_expansion_time < 5: # If expansion happened within the last 5 seconds if now - _last_expansion_time < 5:
print("[Train] Skipping to stabilize after expansion.")
return
@ -35,21 +35,32 @@ def train_on_message(text: str, source: str = "user"):
model.train()
context_texts = get_recent_context(10)
augmented_text = " ".join(context_texts + [text])
tokens = tokenizer.tokenize(augmented_text)
if len(tokens) < 2: if not tokens or len(tokens) < 2:
return return
max_token_id = model.head.out_features - 1
tokens = [min(t, max_token_id) for t in tokens]
if len(tokens) < 2: # Clamp each token to be inside model's head size
clamped_tokens = []
for token in tokens:
if token > max_token_id:
clamped_tokens.append(max_token_id)
elif token < 0:
clamped_tokens.append(0)
else:
clamped_tokens.append(token)
# Clamp sequence length
clamped_tokens = clamped_tokens[:128]
if len(clamped_tokens) < 2:
return
tokens = tokens[:128] input_tensor = torch.tensor(clamped_tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
target_tensor = torch.tensor(clamped_tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
opt = get_optimizer()
@ -63,5 +74,9 @@ def train_on_message(text: str, source: str = "user"):
log_loss(loss.item())
log_vocab_growth()
add_to_context(text, source=source)
except Exception as e:
print(f"[Train] Exception during training: {e}")
finally:
    expand_lock.release()