Fixing the CUDA errors
Fixing replies.
This commit is contained in:
parent 99fddcab4d
commit cde0068725
@@ -10,28 +10,43 @@ from context.context import get_recent_context
 recent_dreams = []
 
 
+@torch.no_grad()
 def generate_response():
     model.eval()
-    context_texts = get_recent_context(5)
-    if context_texts:
-        start = random.choice(context_texts)
-        seed_tokens = tokenizer.tokenize(start)
-        if seed_tokens:
-            seed = torch.tensor([seed_tokens[-1]], device=DEVICE).unsqueeze(0)
-            seed = seed[:, -128:]
-        else:
-            seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
-    else:
-        seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
+    context_texts = get_recent_context(10)
+    seed_text = " ".join(context_texts[-1:])
+    tokens = tokenizer.tokenize(seed_text)
 
-    output = model(seed)
-    pred = torch.argmax(output, dim=-1).squeeze().item()
+    input_tensor = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
 
-    # Clamp prediction into known vocab range
-    if pred >= tokenizer.next_id:
-        pred = random.randint(0, tokenizer.next_id - 1)
+    output_tokens = []
+    max_tokens = 32
 
-    return tokenizer.detokenize([pred])
+    for _ in range(max_tokens):
+        output = model(input_tensor)
+        logits = output[:, -1, :].squeeze(0)
+
+        # Apply temperature (soft randomness)
+        temperature = 0.8
+        logits = logits / temperature
+
+        # Top-k sampling
+        k = 10
+        topk_logits, topk_indices = torch.topk(logits, k)
+        probs = torch.nn.functional.softmax(topk_logits, dim=-1)
+        next_token = topk_indices[torch.multinomial(probs, 1)].item()
+
+        output_tokens.append(next_token)
+
+        input_tensor = torch.cat([input_tensor, torch.tensor([[next_token]], device=DEVICE)], dim=1)
+
+        # Optional: stop if next_token maps to period, question mark, or exclamation
+        next_char = tokenizer.detokenize([next_token])
+        if any(p in next_char for p in [".", "?", "!"]):
+            break
+
+    text = tokenizer.detokenize(output_tokens)
+    return text
 
 
 def score_sentence(sentence: str) -> float:
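For reference, the reply loop this hunk adds is plain temperature-scaled top-k sampling. Below is a minimal standalone sketch of just that step so it can be tested outside the model; the function name is ours and the logits are synthetic, not part of this repo.

```python
import torch

def sample_top_k(logits: torch.Tensor, k: int = 10, temperature: float = 0.8) -> int:
    """Pick one token id from the k most likely entries of a 1-D logits vector."""
    # Temperature < 1 sharpens the distribution; > 1 flattens it
    scaled = logits / temperature
    # Keep only the k highest-scoring tokens
    topk_logits, topk_indices = torch.topk(scaled, k)
    # Renormalize the survivors into a probability distribution
    probs = torch.nn.functional.softmax(topk_logits, dim=-1)
    # Draw one sample and map it back to a vocabulary index
    return topk_indices[torch.multinomial(probs, 1)].item()

# Synthetic logits over a 50-token vocabulary, purely to exercise the function
print(sample_top_k(torch.randn(50)))
```

With k=10 and temperature 0.8 the sampler stays close to the model's top choices, which is why the multi-token loop produces steadier replies than the single argmax token it replaces; raising either value trades coherence for variety.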
@@ -28,7 +28,12 @@ def expand_model_if_needed():
     # print(f"[Expand] Expanding model from {old_vocab_size} -> {current_vocab_size}")
 
     old_state = model.state_dict()
-    new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
+    new_model = TinyTransformer(
+        vocab_size=current_vocab_size,
+        embed_dim=model.token_embed.embedding_dim,
+        depth=len(model.blocks),
+        heads=model.blocks[0].attn.heads
+    ).to(DEVICE)
 
     with torch.no_grad():
         for name, param in new_model.named_parameters():
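The change here mirrors the live model's embed_dim, depth, and heads into the replacement model instead of relying on TinyTransformer's constructor defaults, which could otherwise produce shape mismatches when the old state is copied over. The hunk cuts off right where that copy happens; as a rough sketch of the usual pattern (assumed, since the rest of expand_model_if_needed() isn't shown), each new parameter keeps its fresh initialization except for the slice that overlaps the old shape:

```python
import torch
import torch.nn as nn

def copy_overlapping_weights(old_state: dict, new_model: nn.Module) -> None:
    """Copy old weights into a grown model wherever names and shapes overlap."""
    with torch.no_grad():
        for name, param in new_model.named_parameters():
            if name not in old_state:
                continue  # parameter introduced only by the new architecture
            old_param = old_state[name]
            # Overlap region, e.g. the first old_vocab rows of the embedding matrix
            slices = tuple(slice(0, min(n, o)) for n, o in zip(param.shape, old_param.shape))
            param[slices].copy_(old_param[slices])
```

Rows of the embedding (and output head) beyond the old vocabulary size keep their random initialization until training updates them.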