feat: Managed to achieve a loss of 0.285
.gitignore (vendored)
@@ -160,3 +160,4 @@ cython_debug/
 #.idea/
 /openwebtext
 /data_extract.py
+/runs/phoebe_training
@@ -3,12 +3,11 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 # Hyperparameters
-batch_size = 64
 block_size = 256
-num_embed = 384  # Ensure consistency in naming
+num_embed = 512  # Increased embedding size
 num_heads = 8
-num_layers = 8
-dropout = 0.2
+num_layers = 12  # Increased number of layers
+dropout = 0.3
 
 
 class Head(nn.Module):
@@ -131,7 +130,6 @@ class GPT(nn.Module):
 
 
 def encode(s, string_to_int):
-    # Replace unknown characters with a special token (e.g., "<unk>")
    return [string_to_int.get(c, string_to_int["<unk>"]) for c in s]
 
 
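As an aside (not part of the commit), here is a minimal round-trip sketch of how the <unk> fallback in encode() behaves; the tiny vocabulary and the decode helper below are made up for illustration:

# Hypothetical vocabulary for illustration only; the real one is built from vocab.txt.
string_to_int = {"<unk>": 0, "h": 1, "i": 2}
int_to_string = {i: ch for ch, i in string_to_int.items()}

def encode(s, string_to_int):
    return [string_to_int.get(c, string_to_int["<unk>"]) for c in s]

def decode(ids, int_to_string):
    # Assumed inverse mapping; the project's actual decode lives in gpt_model.py.
    return "".join(int_to_string[i] for i in ids)

print(encode("hi!", string_to_int))      # [1, 2, 0] -- "!" falls back to <unk>
print(decode([1, 2, 0], int_to_string))  # "hi<unk>"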
@@ -1,33 +1,37 @@
+# flake8: noqa: E203
+import os
+import random
 import re
 
 import torch
 import torch.optim as optim
-import random
-import os
+from torch.utils.tensorboard import SummaryWriter
 from gpt_model import encode, decode, load_model
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Hyperparameters
-batch_size = 64
+batch_size = 32  # Reduced batch size for gradient accumulation
+accumulation_steps = 4  # Gradient accumulation steps
 block_size = 256
-max_iters = 5000
-learning_rate = 1e-5  # Adjusted learning rate
+max_iters = 100000  # Increased iterations
+learning_rate = 3e-5  # Adjust learning rate
 eval_iters = 100
-dropout = 0.2
-patience = 500  # Number of iterations to wait for improvement before stopping
+dropout = 0.4  # Increased dropout to prevent overfitting
+patience = 20000  # Increased patience for early stopping
+weight_decay = 0.01  # Add weight decay for regularization
 
 # Load the vocabulary and encoded data
 with open("vocab.txt", "r", encoding="utf-8") as f:
     text = f.read()
 chars = sorted(list(set(text)))
 
-# Ensure that space and other special characters are included
 required_chars = " \n\r\t"
 for char in required_chars:
     if char not in chars:
         chars.append(char)
 
-# Add a special token for unknown characters
 special_token = "<unk>"
 if special_token not in chars:
     chars.append(special_token)
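A quick aside on the new settings (not part of the diff): with gradient accumulation, each optimizer update now averages over batch_size * accumulation_steps = 32 * 4 = 128 sequences, i.e. 32 * 256 * 4 = 32,768 tokens per update, versus 64 sequences (16,384 tokens) per update with the old batch_size = 64 and no accumulation.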
@@ -38,14 +42,12 @@ int_to_string = {i: ch for i, ch in enumerate(chars)}
 
 
 def clean_text(text):
-    """Remove special characters and unwanted symbols from the text."""
     text = re.sub(r"[^a-zA-Z0-9\s.,;!?\'\"]+", "", text)
     text = re.sub(r"\s+", " ", text)
     text = text.strip()
     return text
 
 
-# Load and preprocess training and validation data from cleaned .txt files
 def load_and_clean_data(file_path):
     with open(file_path, "r", encoding="utf-8") as f:
         text = f.read()
@@ -96,8 +98,8 @@ def estimate_loss():
         data = train_data if split == "train" else val_data
         losses = torch.zeros(eval_iters)
         for k in range(eval_iters):
-            X, Y = get_batch(data, block_size, batch_size)
-            logits, loss = model(X, Y)
+            x, y = get_batch(data, block_size, batch_size)
+            logits, loss = model(x, y)
             losses[k] = loss.item()
         out[split] = losses.mean().item()
     model.train()
@@ -105,8 +107,23 @@ def estimate_loss():
 
 
 def train_model():
-    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
-    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)
+    optimizer = optim.AdamW(
+        model.parameters(), lr=learning_rate, weight_decay=weight_decay
+    )
+    steps_per_epoch = len(train_data) // (batch_size * block_size)
+    epochs = max_iters // steps_per_epoch
+
+    scheduler = optim.lr_scheduler.OneCycleLR(
+        optimizer,
+        max_lr=learning_rate * 10,
+        steps_per_epoch=steps_per_epoch,
+        epochs=epochs,
+    )
+
+    writer = SummaryWriter(
+        log_dir="runs/phoebe_training"
+    )  # TensorBoard writer  # noqa: E501
+
     best_val_loss = float("inf")
     patience_counter = 0
 
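For context (an aside, not part of the commit): when steps_per_epoch and epochs are passed like this, PyTorch's OneCycleLR infers total_steps = epochs * steps_per_epoch and expects scheduler.step() to be called that many times, warming the learning rate up toward max_lr and then annealing it. A minimal standalone sketch with a toy model (all values here are placeholders, not the project's real settings):

import torch
import torch.optim as optim

model = torch.nn.Linear(8, 8)  # toy stand-in for the GPT model
optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)

total_steps = 100  # stands in for epochs * steps_per_epoch
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=3e-4, total_steps=total_steps
)

lrs = []
for _ in range(total_steps):
    optimizer.step()   # normally preceded by forward/backward on a real batch
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

# By default the LR peaks at max_lr about 30% of the way through, then decays.
print(f"peak lr {max(lrs):.2e}, final lr {lrs[-1]:.2e}")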
@@ -118,7 +135,9 @@ def train_model():
                 f"val loss {losses['val']:.3f}"
             )
 
-            # Check for improvement in validation loss
+            writer.add_scalar("Loss/train", losses["train"], iter)
+            writer.add_scalar("Loss/val", losses["val"], iter)
+
             if losses["val"] < best_val_loss:
                 best_val_loss = losses["val"]
                 patience_counter = 0
@@ -127,21 +146,24 @@ def train_model():
             else:
                 patience_counter += eval_iters
 
-            # Early stopping
             if patience_counter >= patience:
                 print("Early stopping triggered.")
                 break
 
         xb, yb = get_batch(train_data, block_size, batch_size)
         logits, loss = model(xb, yb)
-        optimizer.zero_grad(set_to_none=True)
+        loss = loss / accumulation_steps  # Scale loss by accumulation steps
         loss.backward()
-        optimizer.step()
-        scheduler.step()
+        if (iter + 1) % accumulation_steps == 0:
+            optimizer.step()
+            optimizer.zero_grad(set_to_none=True)
+            scheduler.step()
 
     if patience_counter < patience:
         print("Training completed without early stopping.")
     print(f"Final loss: {loss.item()}")
+    writer.close()
 
 
 def check_input_chars(s, string_to_int):
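To see the accumulation pattern from this hunk in one place, here is a self-contained sketch (toy model and random data; only the names mirror the diff, everything else is assumed):

import torch
import torch.nn.functional as F
import torch.optim as optim

accumulation_steps = 4
model = torch.nn.Linear(16, 1)  # toy stand-in for the GPT model
optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)

optimizer.zero_grad(set_to_none=True)
for step in range(100):
    xb, yb = torch.randn(32, 16), torch.randn(32, 1)  # stand-in for get_batch(...)
    loss = F.mse_loss(model(xb), yb)
    loss = loss / accumulation_steps  # scale so accumulated gradients average out
    loss.backward()                   # gradients add up across iterations
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()              # one parameter update per accumulation window
        optimizer.zero_grad(set_to_none=True)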
@@ -152,12 +174,11 @@ def check_input_chars(s, string_to_int):
 
 
 def process_message(message):
-    print(f"Processing message: '{message}'")  # Debug print
+    print(f"Processing message: '{message}'")
     if not message.strip():
-        print("Message is empty or invalid.")  # Debug print
+        print("Message is empty or invalid.")
         return "Message is empty or invalid."
 
-    # Check for unknown characters
     unknown_chars = check_input_chars(message, string_to_int)
     if unknown_chars:
         print(f"Message contains unknown characters: {unknown_chars}")
@@ -166,14 +187,14 @@ def process_message(message):
     encoded_text = torch.tensor(
         [encode(message, string_to_int)], dtype=torch.long
     ).to(device)
-    print(f"Encoded text shape: {encoded_text.shape}")  # Debug print
+    print(f"Encoded text shape: {encoded_text.shape}")
     if encoded_text.size(1) == 0:
-        print("Message could not be processed.")  # Debug print
+        print("Message could not be processed.")
         return "Message could not be processed."
 
     response = model.generate(encoded_text, max_new_tokens=50, temperature=0.7)
     decoded_response = decode(response[0].tolist(), int_to_string)
-    print(f"Generated response: '{decoded_response}'")  # Debug print
+    print(f"Generated response: '{decoded_response}'")
     return decoded_response
 
 
BIN phoebe_model.pt (binary file not shown)