From e3e4b7abe6cedd9462b97063ff5025a257404f76 Mon Sep 17 00:00:00 2001
From: Dan
Date: Sat, 8 Jun 2024 09:21:19 -0400
Subject: [PATCH] Fix: Add support for combining additional data sources

---
 combine_and_clean.py      | 125 +++++++++++++++++++++++++++++++++++++
 phoebe/gpt_model.py       |   2 +-
 phoebe/train_gpt_model.py |  27 +++++++--
 3 files changed, 149 insertions(+), 5 deletions(-)
 create mode 100644 combine_and_clean.py

diff --git a/combine_and_clean.py b/combine_and_clean.py
new file mode 100644
index 0000000..81f7065
--- /dev/null
+++ b/combine_and_clean.py
@@ -0,0 +1,125 @@
+# combine_and_clean.py
+import os
+import re
+import random
+from tqdm import tqdm
+import concurrent.futures
+import multiprocessing
+
+
+def clean_data(data):
+    """Strip metadata lines (archive headers, bare IDs) from raw text."""
+    lines = data.splitlines()
+    clean_lines = []
+
+    metadata_patterns = [
+        r"^[0-9a-f]{8}-[0-9a-f]{32}\.txt",
+        r"^[0-9]+$",
+        r"^[0-9]{7,8}.*$",
+        r"^[^a-zA-Z]*$",
+        r"^.*ustar.*$",
+    ]
+
+    for line in lines:
+        if any(re.match(pattern, line) for pattern in metadata_patterns):
+            continue
+        clean_lines.append(line)
+
+    return "\n".join(clean_lines)
+
+
+def process_file(args):
+    # Append one source file to the shared output and return the set of
+    # characters it contains (used to build the vocabulary file).
+    directory, filename, output_file = args
+    file_path = os.path.join(directory, filename)
+    with open(file_path, "rt", encoding="utf-8") as infile:
+        text = infile.read()
+    with open(output_file, "a", encoding="utf-8") as outfile:
+        outfile.write(text)
+    return set(text)
+
+
+def files_in_dir(directory):
+    return [
+        filename
+        for filename in os.listdir(directory)
+        if os.path.isfile(os.path.join(directory, filename))
+    ]
+
+
+def process_files_in_parallel(files, output_file):
+    # `files` holds (directory, filename) pairs so inputs can come from
+    # any of the configured dataset directories. Workers append to the
+    # shared output concurrently, so document order is nondeterministic.
+    vocab = set()
+    with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
+        args = [
+            (directory, filename, output_file)
+            for directory, filename in files
+        ]
+        for characters in tqdm(
+            executor.map(process_file, args), total=len(files)
+        ):
+            vocab.update(characters)
+    return vocab
+
+
+def main():
+    multiprocessing.freeze_support()
+
+    dataset_dirs = ["datasets/openwebtext", "datasets/other_dataset"]
+    output_file_train = "combined_train.txt"
+    output_file_val = "combined_eval.txt"
+    vocab_file = "vocab.txt"
+
+    all_files = []
+    for dataset_dir in dataset_dirs:
+        all_files.extend(
+            (dataset_dir, filename) for filename in files_in_dir(dataset_dir)
+        )
+
+    total_files = len(all_files)
+    split_index = int(total_files * 0.9)
+    files_train = all_files[:split_index]
+    files_val = all_files[split_index:]
+
+    sample_rate = 0.01
+    files_train_sampled = random.sample(
+        files_train, max(1, int(len(files_train) * sample_rate))
+    )
+    files_val_sampled = random.sample(
+        files_val, max(1, int(len(files_val) * sample_rate))
+    )
+
+    # Truncate any previous outputs before the workers append to them.
+    open(output_file_train, "w").close()
+    open(output_file_val, "w").close()
+
+    vocab_train = process_files_in_parallel(
+        files_train_sampled, output_file_train
+    )
+    vocab_val = process_files_in_parallel(
+        files_val_sampled, output_file_val
+    )
+
+    vocab = vocab_train.union(vocab_val)
+    with open(vocab_file, "w", encoding="utf-8") as vfile:
+        for char in sorted(vocab):
+            vfile.write(char + "\n")
+
+    with open(output_file_train, "r", encoding="utf-8") as f:
+        train_data = f.read()
+    train_data_cleaned = clean_data(train_data)
+    with open("combined_train_cleaned.txt", "w", encoding="utf-8") as f:
+        f.write(train_data_cleaned)
+
+    with open(output_file_val, "r", encoding="utf-8") as f:
+        val_data = f.read()
+    val_data_cleaned = clean_data(val_data)
+    with open("combined_eval_cleaned.txt", "w", encoding="utf-8") as f:
+        f.write(val_data_cleaned)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/phoebe/gpt_model.py b/phoebe/gpt_model.py
index 4b3375f..3fea76f 100644
--- a/phoebe/gpt_model.py
+++ b/phoebe/gpt_model.py
@@ -118,7 +118,7 @@ class GPT(nn.Module):
             loss = F.cross_entropy(logits, targets)
         return logits, loss
 
-    def generate(self, idx, max_new_tokens, temperature=1.0):
+    def generate(self, idx, max_new_tokens, temperature):
         for _ in range(max_new_tokens):
             idx_cond = idx[:, -block_size:]
             logits, _ = self(idx_cond)
diff --git a/phoebe/train_gpt_model.py b/phoebe/train_gpt_model.py
index 1e99d29..aba4a01 100644
--- a/phoebe/train_gpt_model.py
+++ b/phoebe/train_gpt_model.py
@@ -171,7 +171,13 @@ def check_input_chars(s, string_to_int):
     return unknown_chars
 
 
+# Maintain conversation history
+conversation_history = []
+
+
 def process_message(message):
+    global conversation_history
+
     print(f"Processing message: '{message}'")
     if not message.strip():
         print("Message is empty or invalid.")
@@ -182,27 +188,40 @@ def process_message(message):
         print(f"Message contains unknown characters: {unknown_chars}")
         return f"Message contains unknown characters: {unknown_chars}"
 
+    # Add the new message to the conversation history
+    conversation_history.append(message)
+    # Limit the conversation history to the last 5 messages to avoid excessive length
+    if len(conversation_history) > 5:
+        conversation_history = conversation_history[-5:]
+
+    # Concatenate the conversation history to form the input prompt
+    context = " ".join(conversation_history)
     encoded_text = torch.tensor(
-        [encode(message, string_to_int)], dtype=torch.long
+        [encode(context, string_to_int)], dtype=torch.long
     ).to(device)
 
     print(f"Encoded text shape: {encoded_text.shape}")
+
     if encoded_text.size(1) == 0:
         print("Message could not be processed.")
         return "Message could not be processed."
 
     with torch.no_grad():
         generated_tokens = model.generate(
-            encoded_text, max_new_tokens=50, temperature=0.7
+            encoded_text, max_new_tokens=100, temperature=1.0
         )
 
     generated_tokens = generated_tokens[0, len(encoded_text[0]) :]
     decoded_response = decode(generated_tokens.tolist(), int_to_string)
 
     print(f"Generated response: '{decoded_response}'")
 
-    if decoded_response.startswith(message):
-        decoded_response = decoded_response[len(message) :].strip()
+    if decoded_response.startswith(context):
+        decoded_response = decoded_response[len(context) :].strip()
 
     print(f"Final response: '{decoded_response}'")
+
+    # Add the response to the conversation history
+    conversation_history.append(decoded_response)
+
     return decoded_response
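
Note: the short script below is an illustration for reviewers, not part of the patch. It sketches the conversation-history windowing that process_message now performs, with the model call stubbed out; the max_turns parameter and the helper names build_context/strip_echo are hypothetical (the patch hard-codes a window of 5 inline). Capping the history keeps the encoded prompt short, which matters because generate() only attends to the last block_size tokens of its input anyway.

# history_sketch.py: hypothetical illustration, not repository code
conversation_history = []


def build_context(message, max_turns=5):
    """Append a message and return the joined prompt, keeping only the
    last max_turns messages (the patch hard-codes [-5:])."""
    global conversation_history
    conversation_history.append(message)
    conversation_history = conversation_history[-max_turns:]
    return " ".join(conversation_history)


def strip_echo(response, context):
    """Drop the echoed prompt from a generated response, mirroring the
    startswith/len-slice logic in process_message."""
    if response.startswith(context):
        response = response[len(context):].strip()
    return response


if __name__ == "__main__":
    for turn in ["hello", "how are you?", "tell me a story"]:
        ctx = build_context(turn)
        # Stand-in for model.generate + decode, which echo the prompt.
        print(strip_echo(ctx + " ...model text...", ctx))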