"""Train a character-level GPT model and generate replies from it.

Importing this module has side effects: it reads ``vocab.txt`` and the
train/eval split files from the working directory, builds the vocabulary
tables, and constructs (or restores) the model on the selected device.
"""

import os
import random
import re

import torch
import torch.optim as optim

from gpt_model import encode, decode, load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
batch_size = 64
block_size = 256
max_iters = 5000
learning_rate = 1e-5  # Adjusted learning rate
eval_iters = 100  # Used both as the evaluation interval and as batches per evaluation
dropout = 0.2  # Not referenced in this file; kept in case other modules import it
patience = 500  # Number of iterations to wait for improvement before stopping

# Load the vocabulary: every distinct character of vocab.txt, in sorted order.
with open("vocab.txt", "r", encoding="utf-8") as f:
    text = f.read()
chars = sorted(set(text))

# Ensure that space and other whitespace characters are included
required_chars = " \n\r\t"
for char in required_chars:
    if char not in chars:
        chars.append(char)

# Add a special token for unknown characters
# NOTE(review): the token is the empty string, so it never equals a real
# character and is always appended (vocab_size grows by one). It may have
# been a visible marker that got lost -- confirm before changing, because
# altering it changes vocab_size and invalidates saved checkpoints.
special_token = ""
if special_token not in chars:
    chars.append(special_token)

vocab_size = len(chars)
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}


def clean_text(text: str) -> str:
    """Remove special characters and unwanted symbols from the text.

    Keeps ASCII letters, digits, whitespace and ``. , ; ! ? ' "``; every
    run of whitespace is then collapsed to a single space and the result
    is stripped at both ends.
    """
    text = re.sub(r"[^a-zA-Z0-9\s.,;!?\'\"]+", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def load_and_clean_data(file_path: str) -> str:
    """Read a UTF-8 text file and return its contents passed through clean_text."""
    with open(file_path, "r", encoding="utf-8") as f:
        raw_text = f.read()
    return clean_text(raw_text)


# Load and preprocess training and validation data from cleaned .txt files,
# then encode them into 1-D long tensors.
# (encode comes from gpt_model; presumably it maps each character through
# string_to_int -- verify against gpt_model.)
train_data = load_and_clean_data("train_split_cleaned.txt")
val_data = load_and_clean_data("eval_split_cleaned.txt")
train_data = torch.tensor(encode(train_data, string_to_int), dtype=torch.long)
val_data = torch.tensor(encode(val_data, string_to_int), dtype=torch.long)


def get_random_chunk(data, chunk_size: int):
    """Return a random contiguous slice of ``data`` of length ``chunk_size``.

    The start index is drawn uniformly from ``[0, len(data) - chunk_size - 1]``.

    Raises:
        ValueError: if ``data`` has fewer than ``chunk_size + 1`` elements.
    """
    max_start = len(data) - chunk_size - 1
    if max_start < 0:
        # random.randint would raise ValueError here anyway, but with an
        # "empty range" message that hides the real cause.
        raise ValueError(
            f"data has {len(data)} elements but at least {chunk_size + 1} "
            f"are required to sample a chunk of size {chunk_size}"
        )
    start = random.randint(0, max_start)
    return data[start : start + chunk_size]


def get_batch(data, block_size: int, batch_size: int):
    """Sample one (input, target) batch from ``data`` and move it to ``device``.

    A single contiguous chunk is drawn and its first
    ``block_size * batch_size`` tokens are viewed as ``batch_size`` rows of
    ``block_size`` tokens; the targets are the same window shifted right by
    one token. Rows within a batch are therefore consecutive pieces of text.

    Returns:
        Tuple ``(x, y)`` of tensors, each of shape ``(batch_size, block_size)``.

    Raises:
        ValueError: if ``data`` is too short (see get_random_chunk).
    """
    chunk_size = block_size * (batch_size + 1)
    chunk = get_random_chunk(data, chunk_size)
    n_tokens = block_size * batch_size
    x = chunk[:n_tokens].view(batch_size, block_size)
    y = chunk[1 : n_tokens + 1].view(batch_size, block_size)
    return x.to(device), y.to(device)


def load_or_initialize_model(vocab_size: int):
    """Build the model and restore weights from ``phoebe_model.pt`` if present.

    The checkpoint is loaded with ``map_location=device`` so that weights
    saved on a GPU machine can still be restored on a CPU-only host.
    """
    model = load_model(vocab_size)
    if os.path.exists("phoebe_model.pt"):
        state_dict = torch.load("phoebe_model.pt", map_location=device)
        model.load_state_dict(state_dict)
        print("Model loaded from phoebe_model.pt")
    else:
        print("Initialized a new model")
    return model


model = load_or_initialize_model(vocab_size).to(device)


@torch.no_grad()
def estimate_loss():
    """Return the mean loss over ``eval_iters`` random batches per split.

    The model is put in eval mode for the measurement and switched back to
    train mode before returning.

    Returns:
        Dict with keys ``"train"`` and ``"val"`` mapping to float mean losses.
    """
    out = {}
    model.eval()
    for split, data in (("train", train_data), ("val", val_data)):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(data, block_size, batch_size)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out


def train_model():
    """Train the global model with AdamW, periodic evaluation and early stopping.

    Every ``eval_iters`` steps the train/val losses are estimated. A new best
    validation loss saves the weights to ``phoebe_model.pt``; otherwise the
    patience counter grows by ``eval_iters`` and training stops once it
    reaches ``patience``.
    """
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    # Multiply the learning rate by 0.1 after every 1000 scheduler steps.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)

    best_val_loss = float("inf")
    patience_counter = 0
    loss = None  # stays None if no training step is ever executed

    for step in range(max_iters):
        if step % eval_iters == 0:
            losses = estimate_loss()
            print(
                f"step {step}: train loss {losses['train']:.3f}, "
                f"val loss {losses['val']:.3f}"
            )

            # Check for improvement in validation loss
            if losses["val"] < best_val_loss:
                best_val_loss = losses["val"]
                patience_counter = 0
                torch.save(model.state_dict(), "phoebe_model.pt")
                print("Model Saved!")
            else:
                patience_counter += eval_iters

            # Early stopping
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

        xb, yb = get_batch(train_data, block_size, batch_size)
        _, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()

    if patience_counter < patience:
        print("Training completed without early stopping.")
    if loss is not None:
        print(f"Final loss: {loss.item()}")


def check_input_chars(s: str, string_to_int: dict) -> list:
    """Return the characters of ``s`` that are not keys of ``string_to_int``.

    Duplicates and order are preserved; a notice is printed when any are found.
    """
    unknown_chars = [c for c in s if c not in string_to_int]
    if unknown_chars:
        print(f"Unknown characters in input: {unknown_chars}")
    return unknown_chars


def process_message(message: str) -> str:
    """Generate a model response for ``message``.

    Returns an explanatory string instead of a generation when the message
    is blank, contains characters outside the vocabulary, or encodes to an
    empty sequence. Otherwise returns the decoded output of
    ``model.generate`` (50 new tokens, temperature 0.7).
    """
    print(f"Processing message: '{message}'")  # Debug print
    if not message.strip():
        print("Message is empty or invalid.")  # Debug print
        return "Message is empty or invalid."

    # Check for unknown characters
    unknown_chars = check_input_chars(message, string_to_int)
    if unknown_chars:
        print(f"Message contains unknown characters: {unknown_chars}")
        return f"Message contains unknown characters: {unknown_chars}"

    encoded_text = torch.tensor(
        [encode(message, string_to_int)], dtype=torch.long
    ).to(device)
    print(f"Encoded text shape: {encoded_text.shape}")  # Debug print
    if encoded_text.size(1) == 0:
        print("Message could not be processed.")  # Debug print
        return "Message could not be processed."

    # Generate in eval mode without autograd: otherwise dropout stays active
    # (the model is left in train mode by estimate_loss) and gradient
    # bookkeeping is built for every sampled token. The previous mode is
    # restored afterwards so an ongoing training run is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            response = model.generate(
                encoded_text, max_new_tokens=50, temperature=0.7
            )
    finally:
        model.train(was_training)

    decoded_response = decode(response[0].tolist(), int_to_string)
    print(f"Generated response: '{decoded_response}'")  # Debug print
    return decoded_response


if __name__ == "__main__":
    train_model()