Fixing another string of CUDA errors
parent 684bf33675
commit 99fddcab4d
@@ -1,4 +1,5 @@
 import random
 import asyncio
 from context.context import load_context
 from model.trainer import train_on_message
+from model.dynamic_expand import expand_model_if_needed
@@ -23,7 +23,7 @@ def train_on_message(text: str, source: str = "user"):
     expand_model_if_needed()

     now = time.time()
-    if now - _last_expansion_time < 5:  # If expansion happened within the last 5 seconds
+    if now - _last_expansion_time < 5:
         print("[Train] Skipping to stabilize after expansion.")
         return

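For context on the guard above: the five-second window is a cool-down so the trainer does not touch parameters that expand_model_if_needed may have just resized. The diff implies that _last_expansion_time is a module-level timestamp refreshed whenever an expansion actually happens, but that wiring is outside these hunks, so the sketch below is an assumption about it rather than code from this commit.

# Sketch only; not part of this commit. Assumes _last_expansion_time is
# refreshed by the expansion code itself; that wiring is outside these hunks.
import time

_last_expansion_time = 0.0

def expand_model_if_needed():
    global _last_expansion_time
    # ... resize the embedding / output head here when the vocab has grown ...
    _last_expansion_time = time.time()  # start the 5-second cool-down

def should_skip_training(cooldown: float = 5.0) -> bool:
    # Mirrors the guard in train_on_message: hold off while the model stabilizes.
    return time.time() - _last_expansion_time < cooldown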
@@ -35,21 +35,32 @@ def train_on_message(text: str, source: str = "user"):
         model.train()
         context_texts = get_recent_context(10)
         augmented_text = " ".join(context_texts + [text])

         tokens = tokenizer.tokenize(augmented_text)

-        if len(tokens) < 2:
+        if not tokens or len(tokens) < 2:
             return

         max_token_id = model.head.out_features - 1
-        tokens = [min(t, max_token_id) for t in tokens]
-
-        if len(tokens) < 2:
-            return
-
-        tokens = tokens[:128]
+
+        # Clamp each token to be inside model's head size
+        clamped_tokens = []
+        for token in tokens:
+            if token > max_token_id:
+                clamped_tokens.append(max_token_id)
+            elif token < 0:
+                clamped_tokens.append(0)
+            else:
+                clamped_tokens.append(token)
+
+        # Clamp sequence length
+        clamped_tokens = clamped_tokens[:128]
+
+        if len(clamped_tokens) < 2:
+            return

-        input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
-        target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
+        input_tensor = torch.tensor(clamped_tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
+        target_tensor = torch.tensor(clamped_tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)

         opt = get_optimizer()
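This clamp is what actually addresses the CUDA errors in the commit title: a token id at or above the embedding or output-head size fails inside the GPU kernel (the embedding lookup or the cross-entropy target check) and surfaces as an opaque "CUDA error: device-side assert triggered" rather than a Python IndexError. The loop in the hunk is equivalent to clamping every id into [0, max_token_id] and capping the sequence at 128 tokens; a more compact sketch of the same guard follows. The helper names below are made up for illustration and are not taken from the repository.

# Sketch only; not part of this commit. Helper names are illustrative.
import torch

def clamp_token_ids(tokens, max_token_id, max_len=128):
    # Same effect as the loop in the diff: every id lands in [0, max_token_id]
    # and the sequence is capped at max_len tokens.
    return [max(0, min(t, max_token_id)) for t in tokens[:max_len]]

def clamp_token_ids_tensor(tokens, max_token_id, max_len=128, device="cpu"):
    # Tensor-level equivalent of the same guard using torch.clamp.
    ids = torch.tensor(tokens[:max_len], dtype=torch.long, device=device)
    return torch.clamp(ids, min=0, max=max_token_id)

Either form keeps both the input ids and the shifted targets inside the range the model's head can score, which is the usual source of these device-side asserts.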
@@ -63,5 +74,9 @@ def train_on_message(text: str, source: str = "user"):
         log_loss(loss.item())
         log_vocab_growth()
         add_to_context(text, source=source)
+
+    except Exception as e:
+        print(f"[Train] Exception during training: {e}")
+
     finally:
         expand_lock.release()
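The last hunk completes the usual lock-guarded shape: the training body already ran inside a try whose finally releases expand_lock, and the new except keeps one failed message from killing the caller while still guaranteeing the release. A compressed sketch of that shape is below, assuming expand_lock is a threading.Lock acquired earlier in the function; the acquisition is outside this diff.

# Sketch only; not part of this commit. expand_lock and the acquire call are
# assumptions about code outside the shown hunks.
import threading

expand_lock = threading.Lock()

def train_guarded(step):
    expand_lock.acquire()
    try:
        step()  # forward/backward pass, log_loss, add_to_context, ...
    except Exception as e:
        print(f"[Train] Exception during training: {e}")  # report and keep the loop alive
    finally:
        expand_lock.release()  # never leave the lock held after a failure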