Compare commits
10 Commits: 9db0796905 ... e3e4b7abe6
| Author | SHA1 | Date |
| --- | --- | --- |
|  | e3e4b7abe6 |  |
|  | 1fe54ed1ff |  |
|  | 509670c989 |  |
|  | 47c8cce3dd |  |
|  | 763514e815 |  |
|  | fb8db8a870 |  |
|  | 75f1116b3b |  |
|  | 12071fbf61 |  |
|  | 54c4cf59b0 |  |
|  | adca64bfc8 |  |
3  .gitignore  vendored
@@ -158,3 +158,6 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+/openwebtext
+/data_extract.py
+/runs/phoebe_training
.pre-commit-config.yaml
@@ -1,11 +1,19 @@
 repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.4.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+
   - repo: https://github.com/psf/black
-    rev: 22.3.0
+    rev: 24.4.2
     hooks:
       - id: black
-        language_version: python3.10.6
+        args: [--line-length=79]
 
   - repo: https://github.com/pycqa/flake8
-    rev: 4.0.1
+    rev: 7.0.0  # Use the latest revision
     hooks:
       - id: flake8
+        args: [--max-line-length=79, --ignore=E203]
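For reference, hooks configured this way run through the standard pre-commit CLI; the usual invocation (not part of this diff) is:

```
pip install pre-commit
pre-commit install            # run the hooks automatically on each git commit
pre-commit run --all-files    # run every hook against the whole repository once
```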
15  .vscode/launch.json  vendored  Normal file
@@ -0,0 +1,15 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Phoebe",
            "type": "debugpy",
            "request": "launch",
            "program": "E:\\Development\\AI Development\\Phoebe\\phoebe\\discord_bot.py",
            "console": "integratedTerminal"
        }
    ]
}
README.md
@@ -2,4 +2,4 @@
# About Me
Hi there! My name is Phoebe! I am a 20-year-old college student who is currently working on my degree in Machine Learning. I am a bit of a shy gal, and like to observe everyone from a distance. My best friend is Daniel (@advtech as he goes by on Discord). I am looking forward to getting to know you!
48  clean_data.py  Normal file
@@ -0,0 +1,48 @@
import re


def clean_data(data):
    # Split data into lines and filter out metadata
    lines = data.splitlines()
    clean_lines = []

    # Regex patterns to identify metadata
    metadata_patterns = [
        r"^[0-9a-f]{8}-[0-9a-f]{32}\.txt",  # Pattern to identify metadata
        # lines with .txt file names
        r"^[0-9]+$",  # Pattern to identify lines with only numbers
        r"^[0-9]{7,8}.*$",  # Pattern to identify lines
        # starting with 7 or 8 digit numbers
        r"^[^a-zA-Z]*$",  # Pattern to identify lines
        # without alphabetic characters
        r"^.*ustar.*$",  # Pattern to identify lines containing 'ustar'
    ]

    for line in lines:
        if any(re.match(pattern, line) for pattern in metadata_patterns):
            continue
        clean_lines.append(line)

    return "\n".join(clean_lines)


# Load and clean training data
with open("train_split.txt", "r", encoding="utf-8") as f:
    train_data = f.read()
train_data_cleaned = clean_data(train_data)

# Load and clean validation data
with open("eval_split.txt", "r", encoding="utf-8") as f:
    val_data = f.read()
val_data_cleaned = clean_data(val_data)

# Save cleaned data for inspection (optional)
with open("train_split_cleaned.txt", "w", encoding="utf-8") as f:
    f.write(train_data_cleaned)

with open("eval_split_cleaned.txt", "w", encoding="utf-8") as f:
    f.write(val_data_cleaned)

# Print sample cleaned data
print("Sample cleaned training data:", train_data_cleaned[:1000])
print("Sample cleaned validation data:", val_data_cleaned[:1000])
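For intuition, here is a small check of what these patterns filter; the sample strings are made up for illustration and are not part of the diff:

```python
import re

metadata_patterns = [
    r"^[0-9a-f]{8}-[0-9a-f]{32}\.txt",
    r"^[0-9]+$",
    r"^[0-9]{7,8}.*$",
    r"^[^a-zA-Z]*$",
    r"^.*ustar.*$",
]

samples = [
    "0123abcd-" + "0" * 32 + ".txt",  # metadata file name -> dropped
    "42",                             # digits only -> dropped
    "1234567 some payload",           # 7-8 digit prefix -> dropped
    "!!! ---",                        # no alphabetic characters -> dropped
    "a line of real prose",           # matches nothing -> kept
]
for s in samples:
    dropped = any(re.match(p, s) for p in metadata_patterns)
    print(f"{'drop' if dropped else 'keep'}: {s}")
```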
114  combine_and_clean.py  Normal file
@@ -0,0 +1,114 @@
# combine_and_clean.py
import os
import re
import random
from tqdm import tqdm
import concurrent.futures
import multiprocessing


def clean_data(data):
    lines = data.splitlines()
    clean_lines = []

    metadata_patterns = [
        r"^[0-9a-f]{8}-[0-9a-f]{32}\.txt",
        r"^[0-9]+$",
        r"^[0-9]{7,8}.*$",
        r"^[^a-zA-Z]*$",
        r"^.*ustar.*$",
    ]

    for line in lines:
        if any(re.match(pattern, line) for pattern in metadata_patterns):
            continue
        clean_lines.append(line)

    return "\n".join(clean_lines)


def process_file(args):
    directory, filename, output_file = args
    file_path = os.path.join(directory, filename)
    with open(file_path, "rt", encoding="utf-8") as infile:
        text = infile.read()
    with open(output_file, "a", encoding="utf-8") as outfile:
        outfile.write(text)
    characters = set(text)
    return characters


def files_in_dir(directory):
    return [
        filename
        for filename in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, filename))
    ]


def process_files_in_parallel(files, output_file):
    vocab = set()
    with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
        # files holds (directory, filename) pairs built in main(), so unpack
        # each pair; passing a single folder path for every file would point
        # os.path.join at the wrong directory for the second dataset.
        args = [
            (directory, filename, output_file)
            for directory, filename in files
        ]
        for characters in tqdm(
            executor.map(process_file, args), total=len(files)
        ):
            vocab.update(characters)
    return vocab


def main():
    multiprocessing.freeze_support()

    dataset_dirs = ["datasets/openwebtext", "datasets/other_dataset"]
    output_file_train = "combined_train.txt"
    output_file_val = "combined_eval.txt"
    vocab_file = "vocab.txt"

    all_files = []
    for dir in dataset_dirs:
        all_files.extend([(dir, filename) for filename in files_in_dir(dir)])

    total_files = len(all_files)
    split_index = int(total_files * 0.9)
    files_train = all_files[:split_index]
    files_val = all_files[split_index:]

    sample_rate = 0.01
    files_train_sampled = random.sample(
        files_train, max(1, int(len(files_train) * sample_rate))
    )
    files_val_sampled = random.sample(
        files_val, max(1, int(len(files_val) * sample_rate))
    )

    open(output_file_train, "w").close()
    open(output_file_val, "w").close()

    vocab_train = process_files_in_parallel(
        files_train_sampled, output_file_train
    )
    vocab_val = process_files_in_parallel(
        files_val_sampled, output_file_val
    )

    vocab = vocab_train.union(vocab_val)
    with open(vocab_file, "w", encoding="utf-8") as vfile:
        for char in sorted(vocab):
            vfile.write(char + "\n")

    with open(output_file_train, "r", encoding="utf-8") as f:
        train_data = f.read()
    train_data_cleaned = clean_data(train_data)
    with open("combined_train_cleaned.txt", "w", encoding="utf-8") as f:
        f.write(train_data_cleaned)

    with open(output_file_val, "r", encoding="utf-8") as f:
        val_data = f.read()
    val_data_cleaned = clean_data(val_data)
    with open("combined_eval_cleaned.txt", "w", encoding="utf-8") as f:
        f.write(val_data_cleaned)


if __name__ == "__main__":
    main()
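As a quick sanity check on the 90/10 split and 1% sampling logic above; the corpus size is hypothetical:

```python
total_files = 2000                            # hypothetical corpus size
split_index = int(total_files * 0.9)          # 1800 files go to training
n_train, n_val = split_index, total_files - split_index

sample_rate = 0.01
print(max(1, int(n_train * sample_rate)))     # 18 training files sampled
print(max(1, int(n_val * sample_rate)))       # 2 eval files sampled
```

The `max(1, ...)` guard matters for small directories: without it, a dataset with fewer than 100 validation files would sample zero files and produce an empty eval split.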
BIN  eval_split.txt  Normal file
Binary file not shown.
141562  eval_split_cleaned.txt  Normal file
File diff suppressed because one or more lines are too long
64  phoebe/discord_bot.py  Normal file
@@ -0,0 +1,64 @@
import os
import discord
import torch
from dotenv import load_dotenv
from train_gpt_model import process_message
from gpt_model import load_model

# Load environment variables from .env file
load_dotenv()

# Get the Discord bot token from environment variables
TOKEN = os.getenv("DISCORD_TOKEN")

# Load the vocabulary
with open("vocab.txt", "r", encoding="utf-8") as f:
    text = f.read()
chars = sorted(list(set(text)))

# Ensure that space and other special characters are included
required_chars = " \n\r\t"
for char in required_chars:
    if char not in chars:
        chars.append(char)

# Add a special token for unknown characters
special_token = "<unk>"
if special_token not in chars:
    chars.append(special_token)

vocab_size = len(chars)
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}

# Initialize and load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model(vocab_size, "phoebe_model.pt").to(device)

# Initialize Discord client
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)


@client.event
async def on_ready():
    print(f"We have logged in as {client.user}")


@client.event
async def on_message(message):
    if message.author == client.user:
        return

    # Debug: print the message content
    print(f"Received message: '{message.content}'")

    # Process the message and get a response
    response = process_message(message.content)

    # Send the response back to the Discord channel
    await message.channel.send(response)


client.run(TOKEN)
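Since the bot reads its token via `load_dotenv()` and `os.getenv("DISCORD_TOKEN")`, it expects a `.env` file next to the script. A minimal example, with a placeholder value:

```
# .env (keep out of version control)
DISCORD_TOKEN=your-bot-token-here
```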
148  phoebe/gpt_model.py  Normal file
@@ -0,0 +1,148 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hyperparameters
block_size = 256
num_embed = 512  # Increased embedding size
num_heads = 8
num_layers = 12  # Increased number of layers
dropout = 0.3


class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(num_embed, head_size)
        self.query = nn.Linear(num_embed, head_size)
        self.value = nn.Linear(num_embed, head_size)
        self.register_buffer(
            "tril", torch.tril(torch.ones(block_size, block_size))
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_embed, num_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class FeedForward(nn.Module):
    def __init__(self, num_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_embed, 4 * num_embed),
            nn.ReLU(),
            nn.Linear(4 * num_embed, num_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    def __init__(self, num_embed, num_head):
        super().__init__()
        head_size = num_embed // num_head
        self.sa = MultiHeadAttention(num_head, head_size)
        self.ff = FeedForward(num_embed)
        self.ln1 = nn.LayerNorm(num_embed)
        self.ln2 = nn.LayerNorm(num_embed)

    def forward(self, x):
        y = self.sa(x)
        x = self.ln1(x + y)
        y = self.ff(x)
        x = self.ln2(x + y)
        return x


class GPT(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, num_embed)
        self.position_embedding_table = nn.Embedding(block_size, num_embed)
        self.blocks = nn.Sequential(
            *[Block(num_embed, num_heads) for _ in range(num_layers)]
        )
        self.ln = nn.LayerNorm(num_embed)
        self.lm_head = nn.Linear(num_embed, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=idx.device)
        )
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens, temperature):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


def encode(s, string_to_int):
    return [string_to_int.get(c, string_to_int["<unk>"]) for c in s]


def decode(lst, int_to_string):
    return "".join([int_to_string[i] for i in lst])


def load_model(vocab_size, model_path=None):
    model = GPT(vocab_size)
    if model_path:
        try:
            model.load_state_dict(torch.load(model_path))
            print("Model loaded successfully.")
        except FileNotFoundError:
            print("No pre-trained model found. Initialized a new model.")
    return model
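To see how these pieces fit together, here is a minimal smoke-test sketch. The five-character vocabulary and prompt are made up for illustration; a real run builds the vocabulary from vocab.txt as the bot does, and an untrained model will emit random characters:

```python
import torch
from gpt_model import GPT, encode, decode  # assumes phoebe/ is on sys.path

# Hypothetical tiny vocabulary plus the <unk> token.
chars = ["a", "b", " ", "\n", "<unk>"]
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}

model = GPT(vocab_size=len(chars))
model.eval()

# Encode a prompt to shape (B=1, T), sample 5 new tokens, decode back.
idx = torch.tensor([encode("ab a", string_to_int)], dtype=torch.long)
out = model.generate(idx, max_new_tokens=5, temperature=1.0)
print(decode(out[0].tolist(), int_to_string))  # prompt + 5 sampled chars
```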
229  phoebe/train_gpt_model.py  Normal file
@@ -0,0 +1,229 @@
# flake8: noqa: E203
import os
import random
import re

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from gpt_model import encode, decode, load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
batch_size = 32  # Reduced batch size for gradient accumulation
accumulation_steps = 4  # Gradient accumulation steps
block_size = 256
max_iters = 100000  # Increased iterations
learning_rate = 3e-5  # Adjust learning rate
eval_iters = 100
dropout = 0.4  # Increased dropout to prevent overfitting
patience = 20000  # Increased patience for early stopping
weight_decay = 0.01  # Add weight decay for regularization

# Load the vocabulary and encoded data
with open("vocab.txt", "r", encoding="utf-8") as f:
    text = f.read()
chars = sorted(list(set(text)))

required_chars = " \n\r\t"
for char in required_chars:
    if char not in chars:
        chars.append(char)

special_token = "<unk>"
if special_token not in chars:
    chars.append(special_token)

vocab_size = len(chars)
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}


def clean_text(text):
    text = re.sub(r"[^a-zA-Z0-9\s.,;!?\'\"]+", "", text)
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def load_and_clean_data(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        text = f.read()
    cleaned_text = clean_text(text)
    return cleaned_text


train_data = load_and_clean_data("train_split_cleaned.txt")
val_data = load_and_clean_data("eval_split_cleaned.txt")

train_data = torch.tensor(encode(train_data, string_to_int), dtype=torch.long)
val_data = torch.tensor(encode(val_data, string_to_int), dtype=torch.long)


def get_random_chunk(data, chunk_size):
    start = random.randint(0, len(data) - chunk_size - 1)
    chunk = data[start : start + chunk_size]
    return chunk


def get_batch(data, block_size, batch_size):
    chunk_size = block_size * (batch_size + 1)
    chunk = get_random_chunk(data, chunk_size)
    x = chunk[: block_size * batch_size].view(batch_size, block_size)
    y = chunk[1 : block_size * batch_size + 1].view(batch_size, block_size)
    x, y = x.to(device), y.to(device)
    return x, y


def load_or_initialize_model(vocab_size):
    model = load_model(vocab_size)
    if os.path.exists("phoebe_model.pt"):
        model.load_state_dict(torch.load("phoebe_model.pt"))
        print("Model loaded from phoebe_model.pt")
    else:
        print("Initialized a new model")
    return model


model = load_or_initialize_model(vocab_size).to(device)


@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ["train", "val"]:
        data = train_data if split == "train" else val_data
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(data, block_size, batch_size)
            logits, loss = model(x, y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out


def train_model():
    optimizer = optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    steps_per_epoch = len(train_data) // (batch_size * block_size)
    epochs = max_iters // steps_per_epoch

    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=learning_rate * 10,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
    )

    writer = SummaryWriter(log_dir="runs/phoebe_training")

    best_val_loss = float("inf")
    patience_counter = 0

    for iter in range(max_iters):
        if iter % eval_iters == 0:
            losses = estimate_loss()
            print(
                f"step {iter}: train loss {losses['train']:.3f}, "
                f"val loss {losses['val']:.3f}"
            )

            writer.add_scalar("Loss/train", losses["train"], iter)
            writer.add_scalar("Loss/val", losses["val"], iter)

            if losses["val"] < best_val_loss:
                best_val_loss = losses["val"]
                patience_counter = 0
                torch.save(model.state_dict(), "phoebe_model.pt")
                print("Model Saved!")
            else:
                patience_counter += eval_iters

            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

        xb, yb = get_batch(train_data, block_size, batch_size)
        logits, loss = model(xb, yb)
        loss = loss / accumulation_steps
        loss.backward()

        if (iter + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

    if patience_counter < patience:
        print("Training completed without early stopping.")
    print(f"Final loss: {loss.item()}")
    writer.close()


def check_input_chars(s, string_to_int):
    unknown_chars = [c for c in s if c not in string_to_int]
    if unknown_chars:
        print(f"Unknown characters in input: {unknown_chars}")
    return unknown_chars


# Maintain conversation history
conversation_history = []


def process_message(message):
    global conversation_history

    print(f"Processing message: '{message}'")
    if not message.strip():
        print("Message is empty or invalid.")
        return "Message is empty or invalid."

    unknown_chars = check_input_chars(message, string_to_int)
    if unknown_chars:
        print(f"Message contains unknown characters: {unknown_chars}")
        return f"Message contains unknown characters: {unknown_chars}"

    # Add the new message to the conversation history
    conversation_history.append(message)
    # Limit the history to the last 5 messages to avoid excessive length
    if len(conversation_history) > 5:
        conversation_history = conversation_history[-5:]

    # Concatenate the conversation history to form the input prompt
    context = " ".join(conversation_history)
    encoded_text = torch.tensor(
        [encode(context, string_to_int)], dtype=torch.long
    ).to(device)
    print(f"Encoded text shape: {encoded_text.shape}")

    if encoded_text.size(1) == 0:
        print("Message could not be processed.")
        return "Message could not be processed."

    with torch.no_grad():
        generated_tokens = model.generate(
            encoded_text, max_new_tokens=100, temperature=1.0
        )
        generated_tokens = generated_tokens[0, len(encoded_text[0]) :]

    decoded_response = decode(generated_tokens.tolist(), int_to_string)
    print(f"Generated response: '{decoded_response}'")

    if decoded_response.startswith(context):
        decoded_response = decoded_response[len(context) :].strip()

    print(f"Final response: '{decoded_response}'")

    # Add the response to the conversation history
    conversation_history.append(decoded_response)

    return decoded_response


if __name__ == "__main__":
    train_model()
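A note on the accumulation setup above: with batch_size = 32 and accumulation_steps = 4, the optimizer steps on gradients accumulated over an effective batch of 128 sequences. A quick illustration of the shapes get_batch produces, using a stand-in corpus and a fixed start offset instead of the random one:

```python
import torch

block_size, batch_size = 256, 32
data = torch.randint(0, 100, (10_000,))  # stand-in for the encoded corpus

chunk_size = block_size * (batch_size + 1)
start = 0  # get_batch picks this at random
chunk = data[start : start + chunk_size]
x = chunk[: block_size * batch_size].view(batch_size, block_size)
y = chunk[1 : block_size * batch_size + 1].view(batch_size, block_size)
print(x.shape, y.shape)  # torch.Size([32, 256]) torch.Size([32, 256])
```

The targets y are the inputs x shifted by one position, which is what the cross-entropy loss in GPT.forward expects for next-character prediction.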
BIN  phoebe_model.pt  Normal file
Binary file not shown.
BIN  train_split.txt  Normal file
Binary file not shown.
1282730  train_split_cleaned.txt  Normal file
File diff suppressed because one or more lines are too long