feat: Managed to achieve a loss of 0.285

Author: Dan
Date: 2024-05-23 22:39:46 -04:00
parent 47c8cce3dd
commit 509670c989

4 changed files with 51 additions and 31 deletions

.gitignore

@@ -160,3 +160,4 @@ cython_debug/
 #.idea/
 /openwebtext
 /data_extract.py
+/runs/phoebe_training

gpt_model.py

@@ -3,12 +3,11 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 # Hyperparameters
-batch_size = 64
 block_size = 256
-num_embed = 384  # Ensure consistency in naming
+num_embed = 512  # Increased embedding size
 num_heads = 8
-num_layers = 8
-dropout = 0.2
+num_layers = 12  # Increased number of layers
+dropout = 0.3
 
 
 class Head(nn.Module):
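The widened model still splits cleanly across attention heads: with num_embed = 512 and num_heads = 8, each head gets 64 dimensions. A quick sanity check, assuming the usual head_size = num_embed // num_heads convention in GPT-style attention, which this file appears to follow:

    # Sanity check on the new width (assumes head_size = num_embed // num_heads).
    num_embed, num_heads = 512, 8
    assert num_embed % num_heads == 0, "width must divide evenly across heads"
    head_size = num_embed // num_heads
    print(head_size)  # 64 dimensions per head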
@@ -131,7 +130,6 @@ class GPT(nn.Module):
 def encode(s, string_to_int):
-    # Replace unknown characters with a special token (e.g., "<unk>")
     return [string_to_int.get(c, string_to_int["<unk>"]) for c in s]
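The dict .get() fallback means any character missing from the vocabulary maps to the <unk> id instead of raising a KeyError. A minimal illustration with a hypothetical four-token vocabulary:

    # Hypothetical vocabulary; the real one is built from vocab.txt.
    string_to_int = {"<unk>": 0, "a": 1, "b": 2, "c": 3}
    encoded = [string_to_int.get(c, string_to_int["<unk>"]) for c in "abz"]
    print(encoded)  # [1, 2, 0] -- "z" falls back to <unk>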


@@ -1,33 +1,37 @@
+# flake8: noqa: E203
+import os
+import random
 import re
 import torch
 import torch.optim as optim
-import random
-import os
+from torch.utils.tensorboard import SummaryWriter
 from gpt_model import encode, decode, load_model
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Hyperparameters
-batch_size = 64
+batch_size = 32  # Reduced batch size for gradient accumulation
+accumulation_steps = 4  # Gradient accumulation steps
 block_size = 256
-max_iters = 5000
-learning_rate = 1e-5  # Adjusted learning rate
+max_iters = 100000  # Increased iterations
+learning_rate = 3e-5  # Adjust learning rate
 eval_iters = 100
-dropout = 0.2
-patience = 500  # Number of iterations to wait for improvement before stopping
+dropout = 0.4  # Increased dropout to prevent overfitting
+patience = 20000  # Increased patience for early stopping
+weight_decay = 0.01  # Add weight decay for regularization
 
 # Load the vocabulary and encoded data
 with open("vocab.txt", "r", encoding="utf-8") as f:
     text = f.read()
 
 chars = sorted(list(set(text)))
-# Ensure that space and other special characters are included
 required_chars = " \n\r\t"
 for char in required_chars:
     if char not in chars:
         chars.append(char)
 
-# Add a special token for unknown characters
 special_token = "<unk>"
 if special_token not in chars:
     chars.append(special_token)
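With accumulation, each optimizer update now aggregates batch_size * accumulation_steps = 128 sequences, twice the old batch of 64, while only 32 sequences are resident in memory at a time. The arithmetic, spelled out:

    batch_size, accumulation_steps, block_size = 32, 4, 256
    effective_batch = batch_size * accumulation_steps  # 128 sequences per update
    tokens_per_update = effective_batch * block_size   # 32768 tokens per update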
@@ -38,14 +42,12 @@ int_to_string = {i: ch for i, ch in enumerate(chars)}
 def clean_text(text):
-    """Remove special characters and unwanted symbols from the text."""
     text = re.sub(r"[^a-zA-Z0-9\s.,;!?\'\"]+", "", text)
     text = re.sub(r"\s+", " ", text)
     text = text.strip()
     return text
 
 
-# Load and preprocess training and validation data from cleaned .txt files
 def load_and_clean_data(file_path):
     with open(file_path, "r", encoding="utf-8") as f:
         text = f.read()
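Note that the whitelist regex drops accented letters entirely rather than transliterating them. A quick check of what clean_text does to a noisy sample, with the regexes copied verbatim from the diff:

    import re

    def clean_text(text):
        text = re.sub(r"[^a-zA-Z0-9\s.,;!?\'\"]+", "", text)
        text = re.sub(r"\s+", " ", text)
        return text.strip()

    print(clean_text("Héllo  wörld!\n\t42"))  # "Hllo wrld! 42"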
@@ -96,8 +98,8 @@ def estimate_loss():
         data = train_data if split == "train" else val_data
         losses = torch.zeros(eval_iters)
         for k in range(eval_iters):
-            X, Y = get_batch(data, block_size, batch_size)
-            logits, loss = model(X, Y)
+            x, y = get_batch(data, block_size, batch_size)
+            logits, loss = model(x, y)
             losses[k] = loss.item()
         out[split] = losses.mean().item()
     model.train()
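This hunk only renames the batch variables. For context, the enclosing function averages eval_iters mini-batch losses per split with the model in eval mode; a sketch of the usual shape of such a function, assuming the surrounding code matches the names in the diff and runs without gradient tracking:

    @torch.no_grad()  # evaluation needs no gradients
    def estimate_loss():
        out = {}
        model.eval()  # disable dropout while measuring loss
        for split in ["train", "val"]:
            data = train_data if split == "train" else val_data
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                x, y = get_batch(data, block_size, batch_size)
                logits, loss = model(x, y)
                losses[k] = loss.item()
            out[split] = losses.mean().item()
        model.train()
        return out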
@@ -105,8 +107,23 @@ def estimate_loss():
 def train_model():
-    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
-    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)
+    optimizer = optim.AdamW(
+        model.parameters(), lr=learning_rate, weight_decay=weight_decay
+    )
+    steps_per_epoch = len(train_data) // (batch_size * block_size)
+    epochs = max_iters // steps_per_epoch
+    scheduler = optim.lr_scheduler.OneCycleLR(
+        optimizer,
+        max_lr=learning_rate * 10,
+        steps_per_epoch=steps_per_epoch,
+        epochs=epochs,
+    )
+    writer = SummaryWriter(
+        log_dir="runs/phoebe_training"
+    )  # TensorBoard writer  # noqa: E501
 
     best_val_loss = float("inf")
     patience_counter = 0
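OneCycleLR warms the learning rate up to max_lr = 10 * 3e-5 = 3e-4 and then anneals it back down, completing the cycle only after steps_per_epoch * epochs scheduler steps. One caution worth flagging: the loop below steps the scheduler once per accumulation window, so roughly max_iters / accumulation_steps steps ever happen and the annealing phase may never finish as configured. A standalone sketch of the schedule shape on a hypothetical stand-in model:

    import torch
    from torch import nn, optim

    model = nn.Linear(8, 8)  # hypothetical stand-in model
    opt = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
    sched = optim.lr_scheduler.OneCycleLR(opt, max_lr=3e-4, total_steps=100)
    for step in range(100):
        opt.step()  # a real run would compute a loss and backprop first
        sched.step()
        if step % 25 == 0:
            print(step, sched.get_last_lr())  # rises toward 3e-4, then decays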
@@ -118,7 +135,9 @@ def train_model():
                 f"val loss {losses['val']:.3f}"
             )
 
-            # Check for improvement in validation loss
+            writer.add_scalar("Loss/train", losses["train"], iter)
+            writer.add_scalar("Loss/val", losses["val"], iter)
+
             if losses["val"] < best_val_loss:
                 best_val_loss = losses["val"]
                 patience_counter = 0
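Each evaluation now also logs both losses to TensorBoard; the scalars land in the runs/phoebe_training directory this commit adds to .gitignore, and can be watched live with tensorboard --logdir runs/phoebe_training. A minimal round-trip using the same writer API as the diff:

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir="runs/phoebe_training")
    writer.add_scalar("Loss/val", 0.285, 100000)  # tag, value, global step
    writer.close()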
@@ -127,21 +146,24 @@ def train_model():
             else:
                 patience_counter += eval_iters
 
-            # Early stopping
             if patience_counter >= patience:
                 print("Early stopping triggered.")
                 break
 
         xb, yb = get_batch(train_data, block_size, batch_size)
         logits, loss = model(xb, yb)
-        optimizer.zero_grad(set_to_none=True)
+        loss = loss / accumulation_steps  # Scale loss by accumulation steps
         loss.backward()
-        optimizer.step()
-        scheduler.step()
+
+        if (iter + 1) % accumulation_steps == 0:
+            optimizer.step()
+            optimizer.zero_grad(set_to_none=True)
+            scheduler.step()
 
     if patience_counter < patience:
         print("Training completed without early stopping.")
     print(f"Final loss: {loss.item()}")
+    writer.close()
 
 
 def check_input_chars(s, string_to_int):
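Dividing the loss by accumulation_steps keeps gradient magnitudes comparable to one large batch, because .backward() sums gradients in .grad across the accumulation window before the deferred optimizer.step(). A self-contained sketch of the pattern with a hypothetical model and random data:

    import torch
    from torch import nn, optim

    model = nn.Linear(4, 1)  # hypothetical stand-in model
    opt = optim.AdamW(model.parameters(), lr=1e-3)
    accumulation_steps = 4

    for it in range(16):
        x, y = torch.randn(8, 4), torch.randn(8, 1)
        loss = nn.functional.mse_loss(model(x), y)
        (loss / accumulation_steps).backward()  # gradients accumulate in .grad
        if (it + 1) % accumulation_steps == 0:
            opt.step()                          # one update per 4 micro-batches
            opt.zero_grad(set_to_none=True)

One side effect of the scaling: the final print reports loss / accumulation_steps, not the raw per-batch loss.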
@@ -152,12 +174,11 @@ def check_input_chars(s, string_to_int):
 def process_message(message):
-    print(f"Processing message: '{message}'")  # Debug print
+    print(f"Processing message: '{message}'")
     if not message.strip():
-        print("Message is empty or invalid.")  # Debug print
+        print("Message is empty or invalid.")
         return "Message is empty or invalid."
 
-    # Check for unknown characters
     unknown_chars = check_input_chars(message, string_to_int)
     if unknown_chars:
         print(f"Message contains unknown characters: {unknown_chars}")
@@ -166,14 +187,14 @@ def process_message(message):
     encoded_text = torch.tensor(
         [encode(message, string_to_int)], dtype=torch.long
     ).to(device)
-    print(f"Encoded text shape: {encoded_text.shape}")  # Debug print
+    print(f"Encoded text shape: {encoded_text.shape}")
     if encoded_text.size(1) == 0:
-        print("Message could not be processed.")  # Debug print
+        print("Message could not be processed.")
         return "Message could not be processed."
 
     response = model.generate(encoded_text, max_new_tokens=50, temperature=0.7)
     decoded_response = decode(response[0].tolist(), int_to_string)
-    print(f"Generated response: '{decoded_response}'")  # Debug print
+    print(f"Generated response: '{decoded_response}'")
     return decoded_response
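A usage sketch of the cleaned-up entry point, assuming the model and vocabulary are loaded as elsewhere in the script; the input string here is hypothetical:

    reply = process_message("Hello, Phoebe!")  # hypothetical input message
    print(reply)  # up to 50 new tokens sampled at temperature 0.7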

Binary file not shown.