Fix: Adjusting Phoebe's code to prevent 'parroting'

This commit is contained in:
Dan
2024-05-25 08:30:55 -04:00
parent 509670c989
commit 1fe54ed1ff
2 changed files with 14 additions and 6 deletions

View File

@@ -120,9 +120,7 @@ def train_model():
epochs=epochs, epochs=epochs,
) )
writer = SummaryWriter( writer = SummaryWriter(log_dir="runs/phoebe_training")
log_dir="runs/phoebe_training"
) # TensorBoard writer # noqa: E501
best_val_loss = float("inf") best_val_loss = float("inf")
patience_counter = 0 patience_counter = 0
@@ -152,7 +150,7 @@ def train_model():
xb, yb = get_batch(train_data, block_size, batch_size) xb, yb = get_batch(train_data, block_size, batch_size)
logits, loss = model(xb, yb) logits, loss = model(xb, yb)
loss = loss / accumulation_steps # Scale loss by accumulation steps loss = loss / accumulation_steps
loss.backward() loss.backward()
if (iter + 1) % accumulation_steps == 0: if (iter + 1) % accumulation_steps == 0:
@@ -192,9 +190,19 @@ def process_message(message):
print("Message could not be processed.") print("Message could not be processed.")
return "Message could not be processed." return "Message could not be processed."
response = model.generate(encoded_text, max_new_tokens=50, temperature=0.7) with torch.no_grad():
decoded_response = decode(response[0].tolist(), int_to_string) generated_tokens = model.generate(
encoded_text, max_new_tokens=50, temperature=0.7
)
generated_tokens = generated_tokens[0, len(encoded_text[0]) :]
decoded_response = decode(generated_tokens.tolist(), int_to_string)
print(f"Generated response: '{decoded_response}'") print(f"Generated response: '{decoded_response}'")
if decoded_response.startswith(message):
decoded_response = decoded_response[len(message) :].strip()
print(f"Final response: '{decoded_response}'")
return decoded_response return decoded_response

Binary file not shown.