Fix: add support for combining additional data sources
combine_and_clean.py (new file, +114)

@@ -0,0 +1,114 @@
# combine_and_clean.py

import os
import re
import random
import concurrent.futures
import multiprocessing

from tqdm import tqdm

def clean_data(data):
    """Remove lines that look like archive metadata rather than text."""
    lines = data.splitlines()
    clean_lines = []

    metadata_patterns = [
        r"^[0-9a-f]{8}-[0-9a-f]{32}\.txt",  # archive-member-style filenames
        r"^[0-9]+$",  # bare numbers
        r"^[0-9]{7,8}.*$",  # long numeric prefixes
        r"^[^a-zA-Z]*$",  # lines containing no letters at all
        r"^.*ustar.*$",  # tar header residue
    ]

    for line in lines:
        if any(re.match(pattern, line) for pattern in metadata_patterns):
            continue
        clean_lines.append(line)

    return "\n".join(clean_lines)

def process_file(args):
    """Append one source file to output_file; return the set of characters it contains."""
    directory, filename, output_file = args
    file_path = os.path.join(directory, filename)
    with open(file_path, "rt", encoding="utf-8") as infile:
        text = infile.read()
    # Workers append concurrently, so ordering between files is not guaranteed.
    with open(output_file, "a", encoding="utf-8") as outfile:
        outfile.write(text)
    characters = set(text)
    return characters

def files_in_dir(directory):
    return [
        filename
        for filename in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, filename))
    ]

def process_files_in_parallel(files, output_file):
    """Combine (directory, filename) pairs into output_file; return the character set."""
    vocab = set()
    with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
        # Each entry carries its own directory, so files can come from any data source.
        args = [(directory, filename, output_file) for directory, filename in files]
        for characters in tqdm(
            executor.map(process_file, args), total=len(files)
        ):
            vocab.update(characters)
    return vocab

def main():
    multiprocessing.freeze_support()

    dataset_dirs = ["datasets/openwebtext", "datasets/other_dataset"]
    output_file_train = "combined_train.txt"
    output_file_val = "combined_eval.txt"
    vocab_file = "vocab.txt"

    all_files = []
    for directory in dataset_dirs:
        all_files.extend(
            [(directory, filename) for filename in files_in_dir(directory)]
        )

    # Shuffle before splitting so train and validation both draw from every
    # data source, not just the directory listed last.
    random.shuffle(all_files)

    total_files = len(all_files)
    split_index = int(total_files * 0.9)
    files_train = all_files[:split_index]
    files_val = all_files[split_index:]

    sample_rate = 0.01
    files_train_sampled = random.sample(
        files_train, max(1, int(len(files_train) * sample_rate))
    )
    files_val_sampled = random.sample(
        files_val, max(1, int(len(files_val) * sample_rate))
    )

    # Truncate any previous outputs.
    open(output_file_train, "w").close()
    open(output_file_val, "w").close()

    vocab_train = process_files_in_parallel(files_train_sampled, output_file_train)
    vocab_val = process_files_in_parallel(files_val_sampled, output_file_val)

    vocab = vocab_train.union(vocab_val)
    with open(vocab_file, "w", encoding="utf-8") as vfile:
        for char in sorted(vocab):
            vfile.write(char + "\n")

    with open(output_file_train, "r", encoding="utf-8") as f:
        train_data = f.read()
    train_data_cleaned = clean_data(train_data)
    with open("combined_train_cleaned.txt", "w", encoding="utf-8") as f:
        f.write(train_data_cleaned)

    with open(output_file_val, "r", encoding="utf-8") as f:
        val_data = f.read()
    val_data_cleaned = clean_data(val_data)
    with open("combined_eval_cleaned.txt", "w", encoding="utf-8") as f:
        f.write(val_data_cleaned)


if __name__ == "__main__":
    main()
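
A quick, self-contained way to sanity-check the cleaning rules above (not part of the commit; the sample lines are invented for illustration):

# Hypothetical check of clean_data's metadata filters.
sample = "\n".join([
    "0123abcd-" + "0" * 32 + ".txt",  # archive-member-style filename -> dropped
    "1234567",                        # bare number -> dropped
    "!!! ???",                        # no letters at all -> dropped
    "A real sentence survives.",      # kept
])
print(clean_data(sample))  # -> "A real sentence survives."
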
@@ -118,7 +118,7 @@ class GPT(nn.Module):
         loss = F.cross_entropy(logits, targets)
         return logits, loss

-    def generate(self, idx, max_new_tokens, temperature=1.0):
+    def generate(self, idx, max_new_tokens, temperature):
         for _ in range(max_new_tokens):
             idx_cond = idx[:, -block_size:]
             logits, _ = self(idx_cond)
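
The hunk cuts off before the sampling step. For context, a loop like this typically divides the last position's logits by temperature before sampling; the sketch below assumes the usual layout (F.softmax, torch.multinomial), none of which is shown in the diff:

# Assumed remainder of the generate loop (a sketch, not the committed code).
logits = logits[:, -1, :] / temperature             # last position; temperature scaling
probs = F.softmax(logits, dim=-1)                   # <1 sharpens, >1 flattens the distribution
idx_next = torch.multinomial(probs, num_samples=1)  # sample one token id
idx = torch.cat((idx, idx_next), dim=1)             # append and continue the loop
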
@@ -171,7 +171,13 @@ def check_input_chars(s, string_to_int):
     return unknown_chars


+# Maintain conversation history
+conversation_history = []
+
+
 def process_message(message):
+    global conversation_history
+
     print(f"Processing message: '{message}'")
     if not message.strip():
         print("Message is empty or invalid.")
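
The global declaration matters because the next hunk reassigns conversation_history inside the function; without it, Python would treat the name as a local and raise UnboundLocalError. A standalone illustration with hypothetical names:

# Why `global` is needed when a function rebinds a module-level name.
history = []

def add(msg):
    global history                    # required: the next line rebinds the name
    history = (history + [msg])[-2:]  # keep only the last two entries

add("a"); add("b"); add("c")
print(history)  # ['b', 'c']
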
@@ -182,27 +188,40 @@ def process_message(message):
         print(f"Message contains unknown characters: {unknown_chars}")
         return f"Message contains unknown characters: {unknown_chars}"

+    # Add the new message to the conversation history
+    conversation_history.append(message)
+    # Limit the conversation history to the last 5 messages to avoid excessive length
+    if len(conversation_history) > 5:
+        conversation_history = conversation_history[-5:]
+
+    # Concatenate the conversation history to form the input prompt
+    context = " ".join(conversation_history)
     encoded_text = torch.tensor(
-        [encode(message, string_to_int)], dtype=torch.long
+        [encode(context, string_to_int)], dtype=torch.long
     ).to(device)
     print(f"Encoded text shape: {encoded_text.shape}")

     if encoded_text.size(1) == 0:
         print("Message could not be processed.")
         return "Message could not be processed."

     with torch.no_grad():
         generated_tokens = model.generate(
-            encoded_text, max_new_tokens=50, temperature=0.7
+            encoded_text, max_new_tokens=100, temperature=1.0
         )
         generated_tokens = generated_tokens[0, len(encoded_text[0]) :]

         decoded_response = decode(generated_tokens.tolist(), int_to_string)
         print(f"Generated response: '{decoded_response}'")

-    if decoded_response.startswith(message):
-        decoded_response = decoded_response[len(message) :].strip()
+    if decoded_response.startswith(context):
+        decoded_response = decoded_response[len(context) :].strip()

     print(f"Final response: '{decoded_response}'")

+    # Add the response to the conversation history
+    conversation_history.append(decoded_response)
+
     return decoded_response
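
A design note on the truncation pattern above: a bounded deque would cap the history without rebinding a global, at the cost of a slightly different type. A sketch of that alternative (not what the commit does):

# Alternative: collections.deque evicts old messages automatically.
from collections import deque

conversation_history = deque(maxlen=5)
conversation_history.append("hello")      # entries beyond 5 fall off the front
context = " ".join(conversation_history)  # join works the same as with a list
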