diff --git a/phoebe/train_gpt_model.py b/phoebe/train_gpt_model.py index 95291da..1e99d29 100644 --- a/phoebe/train_gpt_model.py +++ b/phoebe/train_gpt_model.py @@ -120,9 +120,7 @@ def train_model(): epochs=epochs, ) - writer = SummaryWriter( - log_dir="runs/phoebe_training" - ) # TensorBoard writer # noqa: E501 + writer = SummaryWriter(log_dir="runs/phoebe_training") best_val_loss = float("inf") patience_counter = 0 @@ -152,7 +150,7 @@ def train_model(): xb, yb = get_batch(train_data, block_size, batch_size) logits, loss = model(xb, yb) - loss = loss / accumulation_steps # Scale loss by accumulation steps + loss = loss / accumulation_steps loss.backward() if (iter + 1) % accumulation_steps == 0: @@ -192,9 +190,19 @@ def process_message(message): print("Message could not be processed.") return "Message could not be processed." - response = model.generate(encoded_text, max_new_tokens=50, temperature=0.7) - decoded_response = decode(response[0].tolist(), int_to_string) + with torch.no_grad(): + generated_tokens = model.generate( + encoded_text, max_new_tokens=50, temperature=0.7 + ) + generated_tokens = generated_tokens[0, len(encoded_text[0]) :] + + decoded_response = decode(generated_tokens.tolist(), int_to_string) print(f"Generated response: '{decoded_response}'") + + if decoded_response.startswith(message): + decoded_response = decoded_response[len(message) :].strip() + + print(f"Final response: '{decoded_response}'") return decoded_response diff --git a/phoebe_model.pt b/phoebe_model.pt index bdccfde..3036d7e 100644 Binary files a/phoebe_model.pt and b/phoebe_model.pt differ