feat: achieve a loss of 0.285
.gitignore (vendored)
@@ -160,3 +160,4 @@ cython_debug/
 #.idea/
 /openwebtext
 /data_extract.py
+/runs/phoebe_training
@@ -3,12 +3,11 @@ import torch.nn as nn
 import torch.nn.functional as F

 # Hyperparameters
 batch_size = 64
 block_size = 256
-num_embed = 384  # Ensure consistency in naming
+num_embed = 512  # Increased embedding size
 num_heads = 8
-num_layers = 8
-dropout = 0.2
+num_layers = 12  # Increased number of layers
+dropout = 0.3


 class Head(nn.Module):
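The wider embedding and extra layers increase the transformer's weight count roughly 2.7x. A rough back-of-the-envelope check, assuming a standard block with multi-head attention (about 4*d^2 weights) plus a 4x feed-forward expansion (about 8*d^2) per layer, ignoring embeddings and layer norms:

def approx_transformer_params(num_embed: int, num_layers: int) -> int:
    # per block: ~4*d^2 (attention projections) + ~8*d^2 (4x MLP)
    return (4 * num_embed**2 + 8 * num_embed**2) * num_layers

print(approx_transformer_params(384, 8) / 1e6)    # ~14.2M (old config)
print(approx_transformer_params(512, 12) / 1e6)   # ~37.7M (new config)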
@@ -131,7 +130,6 @@ class GPT(nn.Module):


 def encode(s, string_to_int):
     # Replace unknown characters with a special token (e.g., "<unk>")
     return [string_to_int.get(c, string_to_int["<unk>"]) for c in s]

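The .get fallback means encode() never raises KeyError on characters outside the vocabulary; they all map to the <unk> id. A quick illustration with a toy mapping (the real string_to_int is built from vocab.txt in the training script below):

string_to_int = {"h": 0, "i": 1, "<unk>": 2}   # toy vocabulary
print([string_to_int.get(c, string_to_int["<unk>"]) for c in "hi~"])
# -> [0, 1, 2]; "~" is not in the vocabulary, so it becomes <unk>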
@@ -1,33 +1,37 @@
 # flake8: noqa: E203
+import os
+import random
+import re
+
 import torch
 import torch.optim as optim
-import random
-import os
 from torch.utils.tensorboard import SummaryWriter

 from gpt_model import encode, decode, load_model

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 # Hyperparameters
-batch_size = 64
+batch_size = 32  # Reduced batch size for gradient accumulation
+accumulation_steps = 4  # Gradient accumulation steps
 block_size = 256
-max_iters = 5000
-learning_rate = 1e-5  # Adjusted learning rate
+max_iters = 100000  # Increased iterations
+learning_rate = 3e-5  # Adjust learning rate
 eval_iters = 100
-dropout = 0.2
-patience = 500  # Number of iterations to wait for improvement before stopping
+dropout = 0.4  # Increased dropout to prevent overfitting
+patience = 20000  # Increased patience for early stopping
+weight_decay = 0.01  # Add weight decay for regularization

 # Load the vocabulary and encoded data
 with open("vocab.txt", "r", encoding="utf-8") as f:
     text = f.read()
 chars = sorted(list(set(text)))

 # Ensure that space and other special characters are included
 required_chars = " \n\r\t"
 for char in required_chars:
     if char not in chars:
         chars.append(char)

 # Add a special token for unknown characters
 special_token = "<unk>"
 if special_token not in chars:
     chars.append(special_token)
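A note on the batch_size change above: it is not a straight reduction, because gradients are now accumulated over accumulation_steps micro-batches before each optimizer update, so the effective batch per update actually grows:

batch_size, accumulation_steps = 32, 4
print(batch_size * accumulation_steps)  # 128 sequences per optimizer step (was 64)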
@@ -38,14 +42,12 @@ int_to_string = {i: ch for i, ch in enumerate(chars)}


 def clean_text(text):
     """Remove special characters and unwanted symbols from the text."""
     text = re.sub(r"[^a-zA-Z0-9\s.,;!?\'\"]+", "", text)
     text = re.sub(r"\s+", " ", text)
     text = text.strip()
     return text


 # Load and preprocess training and validation data from cleaned .txt files
 def load_and_clean_data(file_path):
     with open(file_path, "r", encoding="utf-8") as f:
         text = f.read()
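For reference, the first regex in clean_text keeps only alphanumerics, whitespace, and basic punctuation, and the second collapses whitespace runs. A quick check with made-up input:

import re

def clean_text(text):
    text = re.sub(r"[^a-zA-Z0-9\s.,;!?\'\"]+", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()

print(clean_text("Hello,\tworld!!  ~*~ This  is   a test"))
# -> "Hello, world!! This is a test"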
@@ -96,8 +98,8 @@ def estimate_loss():
         data = train_data if split == "train" else val_data
         losses = torch.zeros(eval_iters)
         for k in range(eval_iters):
-            X, Y = get_batch(data, block_size, batch_size)
-            logits, loss = model(X, Y)
+            x, y = get_batch(data, block_size, batch_size)
+            logits, loss = model(x, y)
             losses[k] = loss.item()
         out[split] = losses.mean().item()
     model.train()
@@ -105,8 +107,23 @@ def estimate_loss():


 def train_model():
-    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
-    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)
+    optimizer = optim.AdamW(
+        model.parameters(), lr=learning_rate, weight_decay=weight_decay
+    )
+    steps_per_epoch = len(train_data) // (batch_size * block_size)
+    epochs = max_iters // steps_per_epoch
+
+    scheduler = optim.lr_scheduler.OneCycleLR(
+        optimizer,
+        max_lr=learning_rate * 10,
+        steps_per_epoch=steps_per_epoch,
+        epochs=epochs,
+    )
+
+    writer = SummaryWriter(
+        log_dir="runs/phoebe_training"
+    )  # TensorBoard writer  # noqa: E501

     best_val_loss = float("inf")
     patience_counter = 0
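OneCycleLR warms the learning rate up from a fraction of max_lr toward max_lr (here learning_rate * 10 = 3e-4) and then anneals it back down over the planned number of steps. A standalone sketch of that schedule's shape, using a toy parameter rather than the project's model; note that in the loop further down scheduler.step() only runs once every accumulation_steps iterations, so the cycle is traversed more slowly than max_iters alone would suggest:

import torch
from torch import optim

p = [torch.nn.Parameter(torch.zeros(1))]           # toy parameter
opt = optim.AdamW(p, lr=3e-5, weight_decay=0.01)
sched = optim.lr_scheduler.OneCycleLR(opt, max_lr=3e-4, total_steps=1000)

lrs = []
for _ in range(1000):
    opt.step()           # normally preceded by backward() on a real loss
    sched.step()
    lrs.append(sched.get_last_lr()[0])
print(f"start {lrs[0]:.2e}, peak {max(lrs):.2e}, end {lrs[-1]:.2e}")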
@@ -118,7 +135,9 @@ def train_model():
                 f"val loss {losses['val']:.3f}"
             )

-            # Check for improvement in validation loss
+            writer.add_scalar("Loss/train", losses["train"], iter)
+            writer.add_scalar("Loss/val", losses["val"], iter)
+
             if losses["val"] < best_val_loss:
                 best_val_loss = losses["val"]
                 patience_counter = 0
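With these scalars logged, the run can be followed in TensorBoard by pointing it at the log directory configured above, e.g. tensorboard --logdir runs/phoebe_training (the same directory the .gitignore change at the top now excludes from version control).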
@@ -127,21 +146,24 @@ def train_model():
         else:
             patience_counter += eval_iters

         # Early stopping
         if patience_counter >= patience:
             print("Early stopping triggered.")
             break

         xb, yb = get_batch(train_data, block_size, batch_size)
         logits, loss = model(xb, yb)
-        optimizer.zero_grad(set_to_none=True)
+        loss = loss / accumulation_steps  # Scale loss by accumulation steps
         loss.backward()
+
+        if (iter + 1) % accumulation_steps == 0:
+            optimizer.step()
+            optimizer.zero_grad(set_to_none=True)
+            scheduler.step()

     if patience_counter < patience:
         print("Training completed without early stopping.")
     print(f"Final loss: {loss.item()}")
     writer.close()


 def check_input_chars(s, string_to_int):
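Why the loss is divided by accumulation_steps before backward(): accumulating gradients of loss/k over k micro-batches reproduces the gradient of the mean loss over the combined batch, rather than k times that gradient. A minimal self-contained check with toy tensors (not the project's model):

import torch

w = torch.ones(1, requires_grad=True)
micro_batches = [torch.tensor([1.0]), torch.tensor([2.0]),
                 torch.tensor([3.0]), torch.tensor([4.0])]
k = len(micro_batches)

for x in micro_batches:
    loss = (w * x).mean() / k   # scale so gradients average rather than sum
    loss.backward()             # gradients accumulate in w.grad

full_loss = (w * torch.stack(micro_batches)).mean()
print(w.grad.item())                                  # 2.5
print(torch.autograd.grad(full_loss, w)[0].item())    # 2.5 as well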
@@ -152,12 +174,11 @@ def check_input_chars(s, string_to_int):


 def process_message(message):
-    print(f"Processing message: '{message}'")  # Debug print
+    print(f"Processing message: '{message}'")
     if not message.strip():
-        print("Message is empty or invalid.")  # Debug print
+        print("Message is empty or invalid.")
         return "Message is empty or invalid."

     # Check for unknown characters
     unknown_chars = check_input_chars(message, string_to_int)
     if unknown_chars:
         print(f"Message contains unknown characters: {unknown_chars}")
@@ -166,14 +187,14 @@ def process_message(message):
     encoded_text = torch.tensor(
         [encode(message, string_to_int)], dtype=torch.long
     ).to(device)
-    print(f"Encoded text shape: {encoded_text.shape}")  # Debug print
+    print(f"Encoded text shape: {encoded_text.shape}")
     if encoded_text.size(1) == 0:
-        print("Message could not be processed.")  # Debug print
+        print("Message could not be processed.")
         return "Message could not be processed."

     response = model.generate(encoded_text, max_new_tokens=50, temperature=0.7)
     decoded_response = decode(response[0].tolist(), int_to_string)
-    print(f"Generated response: '{decoded_response}'")  # Debug print
+    print(f"Generated response: '{decoded_response}'")
     return decoded_response
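On the temperature=0.7 argument passed to generate(): scaling logits by 1/temperature before softmax sharpens the sampling distribution, so lower temperatures favor the highest-probability tokens. The exact generate() implementation lives in the model file and is not shown in this diff; a generic sketch of the effect:

import torch

logits = torch.tensor([2.0, 1.0, 0.0])
for temperature in (1.0, 0.7):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, [round(p, 3) for p in probs.tolist()])
# 1.0 -> [0.665, 0.245, 0.09]
# 0.7 -> [0.771, 0.185, 0.044]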
BIN phoebe_model.pt (binary file not shown)