Updated the model capacity
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@ -205,4 +205,8 @@ cython_debug/
|
||||
marimo/_static/
|
||||
marimo/_lsp/
|
||||
__marimo__/
|
||||
/data/books
|
||||
/data/books
|
||||
/data/openwebtext
|
||||
/catlin_tokenizer.pkl
|
||||
/catlin_model.pt
|
||||
/catlin_chatlog.txt
|
@ -8,4 +8,4 @@ NUM_LAYERS = 6
|
||||
BATCH_SIZE = 16
|
||||
LEARNING_RATE = 3e-4
|
||||
DEVICE = "cuda" # fallback handled in trainer
|
||||
MAX_TOKENS = 100_000 # Used to cap input corpus size
|
||||
MAX_TOKENS = 500_000 # Used to cap input corpus size
|
||||
|
File diff suppressed because one or more lines are too long
66
main.py
66
main.py
@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
from config import *
|
||||
from tokenizers.word_tokenizer import WordTokenizer
|
||||
from models.gpt import GPT
|
||||
@ -11,45 +12,42 @@ from training.trainer import TextDataset, train
|
||||
def load_texts():
|
||||
text = ""
|
||||
|
||||
# --- Books ---
|
||||
book_dir = os.path.join("data", "books")
|
||||
os.makedirs(book_dir, exist_ok=True)
|
||||
print("[INFO] Scanning books from: data\\books")
|
||||
books_dir = "data/books"
|
||||
if os.path.exists(books_dir):
|
||||
for file in os.listdir(books_dir):
|
||||
if file.endswith(".txt"):
|
||||
print(f" 📚 Loading {file}")
|
||||
with open(os.path.join(books_dir, file), encoding="utf-8") as f:
|
||||
text += f.read()
|
||||
|
||||
print(f"[INFO] Scanning books from: {book_dir}")
|
||||
for file in os.listdir(book_dir):
|
||||
path = os.path.join(book_dir, file)
|
||||
if file.endswith(".txt"):
|
||||
print(f" 📚 Loading {file}")
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
text += f.read() + "\n"
|
||||
except Exception as e:
|
||||
print(f" ❌ Failed to read {file}: {e}")
|
||||
|
||||
# --- OpenWebText ---
|
||||
owt_path = os.path.join("data/openwebtext", "owt_20000.jsonl")
|
||||
print(f"[INFO] Scanning OpenWebText: {owt_path}")
|
||||
|
||||
if os.path.exists(owt_path):
|
||||
with open(owt_path, encoding="utf-8") as f:
|
||||
for i, line in enumerate(f):
|
||||
if i % 1000 == 0:
|
||||
print(f" ⏳ {i} lines read...")
|
||||
owt_dir = "data/openwebtext"
|
||||
if os.path.exists(owt_dir):
|
||||
print(f"[INFO] Scanning OpenWebText: {owt_dir}")
|
||||
for fname in os.listdir(owt_dir):
|
||||
if fname.endswith(".jsonl"):
|
||||
path = os.path.join(owt_dir, fname)
|
||||
print(f" 📄 Loading {fname}")
|
||||
try:
|
||||
text += line.strip() + "\n"
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for i, line in enumerate(f):
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
text += obj.get("text", "")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if len(text) >= MAX_TOKENS * 10:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f" ❌ Line {i} decode error: {e}")
|
||||
else:
|
||||
print(f"[WARN] OpenWebText file not found: {owt_path}")
|
||||
print(f" ⚠️ Error reading {fname}: {e}")
|
||||
|
||||
# --- Chat logs ---
|
||||
if os.path.exists("catlin_chatlog.txt"):
|
||||
print(f"[INFO] Appending chat log...")
|
||||
with open("catlin_chatlog.txt", encoding="utf-8") as f:
|
||||
text += "\n" + f.read()
|
||||
print("[INFO] Raw text loaded:", len(text), "characters")
|
||||
|
||||
print(f"[INFO] Raw text loaded: {len(text)} characters")
|
||||
return text[:MAX_TOKENS * 10]
|
||||
# Truncate to MAX_TOKENS * 10 (rough estimate)
|
||||
clipped = text[:MAX_TOKENS * 10]
|
||||
print("[INFO] Loaded text:", len(clipped), "characters")
|
||||
|
||||
return clipped
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -4,22 +4,42 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class GPTBlock(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads):
|
||||
class CausalSelfAttention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, dropout):
|
||||
super().__init__()
|
||||
self.attn = nn.MultiheadAttention(
|
||||
embed_dim, num_heads,
|
||||
dropout=dropout,
|
||||
batch_first=True,
|
||||
bias=True
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, T, _ = x.size()
|
||||
# Create causal mask: (T, T) with float('-inf') for future positions
|
||||
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1)
|
||||
mask = mask.masked_fill(mask == 1, float('-inf'))
|
||||
|
||||
# Pass it in as attn_mask
|
||||
return self.attn(x, x, x, attn_mask=mask)[0]
|
||||
|
||||
|
||||
class GPTBlock(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, dropout):
|
||||
super().__init__()
|
||||
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
|
||||
self.ln1 = nn.LayerNorm(embed_dim)
|
||||
self.attn = CausalSelfAttention(embed_dim, num_heads, dropout)
|
||||
self.ln2 = nn.LayerNorm(embed_dim)
|
||||
self.ff = nn.Sequential(
|
||||
nn.Linear(embed_dim, 4 * embed_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(4 * embed_dim, embed_dim)
|
||||
nn.Linear(4 * embed_dim, embed_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
self.ln2 = nn.LayerNorm(embed_dim)
|
||||
|
||||
def forward(self, x):
|
||||
attn_out, _ = self.attn(x, x, x, need_weights=False)
|
||||
x = self.ln1(x + attn_out)
|
||||
x = self.ln2(x + self.ff(x))
|
||||
x = x + self.attn(self.ln1(x))
|
||||
x = x + self.ff(self.ln2(x))
|
||||
return x
|
||||
|
||||
|
||||
@ -28,16 +48,18 @@ class GPT(nn.Module):
|
||||
super().__init__()
|
||||
self.token_emb = nn.Embedding(vocab_size, embed_dim)
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, context_size, embed_dim))
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.blocks = nn.ModuleList([GPTBlock(embed_dim, num_heads) for _ in range(num_layers)])
|
||||
self.drop = nn.Dropout(dropout)
|
||||
self.blocks = nn.ModuleList([
|
||||
GPTBlock(embed_dim, num_heads, dropout)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
self.ln_f = nn.LayerNorm(embed_dim)
|
||||
self.head = nn.Linear(embed_dim, vocab_size)
|
||||
|
||||
def forward(self, x):
|
||||
B, T = x.size()
|
||||
tok_emb = self.token_emb(x)
|
||||
x = tok_emb + self.pos_emb[:, :T, :]
|
||||
x = self.dropout(x)
|
||||
x = self.token_emb(x) + self.pos_emb[:, :T, :]
|
||||
x = self.drop(x)
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = self.ln_f(x)
|
||||
|
@ -1,5 +1,3 @@
|
||||
# tools/openwebtext_fetcher.py
|
||||
|
||||
from datasets import load_dataset
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
@ -8,14 +6,29 @@ TARGET_DIR = "data/openwebtext"
|
||||
os.makedirs(TARGET_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def fetch_subset(n=10000, split="train"):
|
||||
def fetch_subset(total=20000, chunk_size=5000, split="train"):
|
||||
ds = load_dataset("stas/openwebtext-10k", split=split)
|
||||
with open(os.path.join(TARGET_DIR, f"owt_{n}.jsonl"), "w", encoding="utf-8") as f:
|
||||
for i, item in tqdm(enumerate(ds), total=n, desc="Writing JSONL"):
|
||||
f.write(f"{item['text'].replace(chr(10),' ')}\n")
|
||||
if i + 1 >= n:
|
||||
break
|
||||
print(f"[INFO] Total to fetch: {total} | Chunk size: {chunk_size}")
|
||||
|
||||
count = 0
|
||||
file_index = 0
|
||||
f = None
|
||||
|
||||
for item in tqdm(ds, desc="Downloading"):
|
||||
if count % chunk_size == 0:
|
||||
if f: f.close()
|
||||
file_path = os.path.join(TARGET_DIR, f"owt_{file_index:05d}.jsonl")
|
||||
f = open(file_path, "w", encoding="utf-8")
|
||||
print(f"[INFO] ➕ Created {file_path}")
|
||||
file_index += 1
|
||||
f.write(f"{item['text'].replace(chr(10), ' ')}\n")
|
||||
count += 1
|
||||
if count >= total:
|
||||
break
|
||||
|
||||
if f:
|
||||
f.close()
|
||||
print(f"[INFO] ✅ Done. {count} samples across {file_index} files.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
fetch_subset(20000) # fetch 20K examples (~100 MB)
|
||||
fetch_subset(total=100000, chunk_size=5000)
|
||||
|
Reference in New Issue
Block a user