Fix: Adjusting Phoebe's code to prevent 'parroting'
This commit is contained in:
@ -120,9 +120,7 @@ def train_model():
|
|||||||
epochs=epochs,
|
epochs=epochs,
|
||||||
)
|
)
|
||||||
|
|
||||||
writer = SummaryWriter(
|
writer = SummaryWriter(log_dir="runs/phoebe_training")
|
||||||
log_dir="runs/phoebe_training"
|
|
||||||
) # TensorBoard writer # noqa: E501
|
|
||||||
|
|
||||||
best_val_loss = float("inf")
|
best_val_loss = float("inf")
|
||||||
patience_counter = 0
|
patience_counter = 0
|
||||||
@ -152,7 +150,7 @@ def train_model():
|
|||||||
|
|
||||||
xb, yb = get_batch(train_data, block_size, batch_size)
|
xb, yb = get_batch(train_data, block_size, batch_size)
|
||||||
logits, loss = model(xb, yb)
|
logits, loss = model(xb, yb)
|
||||||
loss = loss / accumulation_steps # Scale loss by accumulation steps
|
loss = loss / accumulation_steps
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
if (iter + 1) % accumulation_steps == 0:
|
if (iter + 1) % accumulation_steps == 0:
|
||||||
@ -192,9 +190,19 @@ def process_message(message):
|
|||||||
print("Message could not be processed.")
|
print("Message could not be processed.")
|
||||||
return "Message could not be processed."
|
return "Message could not be processed."
|
||||||
|
|
||||||
response = model.generate(encoded_text, max_new_tokens=50, temperature=0.7)
|
with torch.no_grad():
|
||||||
decoded_response = decode(response[0].tolist(), int_to_string)
|
generated_tokens = model.generate(
|
||||||
|
encoded_text, max_new_tokens=50, temperature=0.7
|
||||||
|
)
|
||||||
|
generated_tokens = generated_tokens[0, len(encoded_text[0]) :]
|
||||||
|
|
||||||
|
decoded_response = decode(generated_tokens.tolist(), int_to_string)
|
||||||
print(f"Generated response: '{decoded_response}'")
|
print(f"Generated response: '{decoded_response}'")
|
||||||
|
|
||||||
|
if decoded_response.startswith(message):
|
||||||
|
decoded_response = decoded_response[len(message) :].strip()
|
||||||
|
|
||||||
|
print(f"Final response: '{decoded_response}'")
|
||||||
return decoded_response
|
return decoded_response
|
||||||
|
|
||||||
|
|
||||||
|
BIN
phoebe_model.pt
BIN
phoebe_model.pt
Binary file not shown.
Reference in New Issue
Block a user