Updated the model capacity

.gitignore (vendored) · 6 lines changed
@@ -205,4 +205,8 @@ cython_debug/
 marimo/_static/
 marimo/_lsp/
 __marimo__/
 /data/books
+/data/openwebtext
+/catlin_tokenizer.pkl
+/catlin_model.pt
+/catlin_chatlog.txt

config.py
@@ -8,4 +8,4 @@ NUM_LAYERS = 6
 BATCH_SIZE = 16
 LEARNING_RATE = 3e-4
 DEVICE = "cuda"  # fallback handled in trainer
-MAX_TOKENS = 100_000  # Used to cap input corpus size
+MAX_TOKENS = 500_000  # Used to cap input corpus size
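
Note: this cap feeds load_texts() in main.py, which clips the raw corpus at MAX_TOKENS * 10 characters (the code's own rough characters-per-token estimate). A quick check of what the new value admits:

    MAX_TOKENS = 500_000
    char_cap = MAX_TOKENS * 10
    print(f"{char_cap:,} characters")  # 5,000,000 characters, roughly 5 MB of raw text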

File diff suppressed because one or more lines are too long

main.py · 66 lines changed
@@ -2,6 +2,7 @@
 
 import os
 import torch
+import json
 from config import *
 from tokenizers.word_tokenizer import WordTokenizer
 from models.gpt import GPT
@@ -11,45 +12,42 @@ from training.trainer import TextDataset, train
 def load_texts():
     text = ""
 
-    # --- Books ---
-    book_dir = os.path.join("data", "books")
-    os.makedirs(book_dir, exist_ok=True)
-
-    print(f"[INFO] Scanning books from: {book_dir}")
-    for file in os.listdir(book_dir):
-        path = os.path.join(book_dir, file)
-        if file.endswith(".txt"):
-            print(f"  📚 Loading {file}")
-            try:
-                with open(path, encoding="utf-8") as f:
-                    text += f.read() + "\n"
-            except Exception as e:
-                print(f"  ❌ Failed to read {file}: {e}")
-
-    # --- OpenWebText ---
-    owt_path = os.path.join("data/openwebtext", "owt_20000.jsonl")
-    print(f"[INFO] Scanning OpenWebText: {owt_path}")
-
-    if os.path.exists(owt_path):
-        with open(owt_path, encoding="utf-8") as f:
-            for i, line in enumerate(f):
-                if i % 1000 == 0:
-                    print(f"  ⏳ {i} lines read...")
-                try:
-                    text += line.strip() + "\n"
-                except Exception as e:
-                    print(f"  ❌ Line {i} decode error: {e}")
-    else:
-        print(f"[WARN] OpenWebText file not found: {owt_path}")
-
-    # --- Chat logs ---
-    if os.path.exists("catlin_chatlog.txt"):
-        print(f"[INFO] Appending chat log...")
-        with open("catlin_chatlog.txt", encoding="utf-8") as f:
-            text += "\n" + f.read()
-
-    print(f"[INFO] Raw text loaded: {len(text)} characters")
-    return text[:MAX_TOKENS * 10]
+    print("[INFO] Scanning books from: data\\books")
+    books_dir = "data/books"
+    if os.path.exists(books_dir):
+        for file in os.listdir(books_dir):
+            if file.endswith(".txt"):
+                print(f"  📚 Loading {file}")
+                with open(os.path.join(books_dir, file), encoding="utf-8") as f:
+                    text += f.read()
+
+    owt_dir = "data/openwebtext"
+    if os.path.exists(owt_dir):
+        print(f"[INFO] Scanning OpenWebText: {owt_dir}")
+        for fname in os.listdir(owt_dir):
+            if fname.endswith(".jsonl"):
+                path = os.path.join(owt_dir, fname)
+                print(f"  📄 Loading {fname}")
+                try:
+                    with open(path, encoding="utf-8") as f:
+                        for i, line in enumerate(f):
+                            try:
+                                obj = json.loads(line)
+                                text += obj.get("text", "")
+                            except json.JSONDecodeError:
+                                continue
+                            if len(text) >= MAX_TOKENS * 10:
+                                break
+                except Exception as e:
+                    print(f"  ⚠️ Error reading {fname}: {e}")
+
+    print("[INFO] Raw text loaded:", len(text), "characters")
+
+    # Truncate to MAX_TOKENS * 10 (rough estimate)
+    clipped = text[:MAX_TOKENS * 10]
+    print("[INFO] Loaded text:", len(clipped), "characters")
+
+    return clipped
 
 
 def main():
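
Note: the `if len(text) >= MAX_TOKENS * 10: break` above only exits the inner per-line loop, so the outer loop still opens every remaining .jsonl file and can append one more record from each before breaking again. A minimal sketch of a loader that also stops opening new files once the cap is hit (hypothetical helper, not part of this commit):

    import json
    import os

    def read_capped(owt_dir, char_cap):
        # Stop scanning further files once the character cap is reached.
        text = ""
        for fname in sorted(os.listdir(owt_dir)):
            if len(text) >= char_cap:
                break  # outer-loop break: no more files are opened
            if not fname.endswith(".jsonl"):
                continue
            with open(os.path.join(owt_dir, fname), encoding="utf-8") as f:
                for line in f:
                    try:
                        text += json.loads(line).get("text", "")
                    except json.JSONDecodeError:
                        continue
                    if len(text) >= char_cap:
                        break
        return text[:char_cap]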

models/gpt.py
@@ -4,22 +4,42 @@ import torch
 import torch.nn as nn
 
 
-class GPTBlock(nn.Module):
-    def __init__(self, embed_dim, num_heads):
+class CausalSelfAttention(nn.Module):
+    def __init__(self, embed_dim, num_heads, dropout):
+        super().__init__()
+        self.attn = nn.MultiheadAttention(
+            embed_dim, num_heads,
+            dropout=dropout,
+            batch_first=True,
+            bias=True
+        )
+
+    def forward(self, x):
+        B, T, _ = x.size()
+        # Create causal mask: (T, T) with float('-inf') for future positions
+        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1)
+        mask = mask.masked_fill(mask == 1, float('-inf'))
+
+        # Pass it in as attn_mask
+        return self.attn(x, x, x, attn_mask=mask)[0]
+
+
+class GPTBlock(nn.Module):
+    def __init__(self, embed_dim, num_heads, dropout):
         super().__init__()
-        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
         self.ln1 = nn.LayerNorm(embed_dim)
+        self.attn = CausalSelfAttention(embed_dim, num_heads, dropout)
+        self.ln2 = nn.LayerNorm(embed_dim)
         self.ff = nn.Sequential(
             nn.Linear(embed_dim, 4 * embed_dim),
             nn.GELU(),
-            nn.Linear(4 * embed_dim, embed_dim)
+            nn.Linear(4 * embed_dim, embed_dim),
+            nn.Dropout(dropout)
         )
-        self.ln2 = nn.LayerNorm(embed_dim)
 
     def forward(self, x):
-        attn_out, _ = self.attn(x, x, x, need_weights=False)
-        x = self.ln1(x + attn_out)
-        x = self.ln2(x + self.ff(x))
+        x = x + self.attn(self.ln1(x))
+        x = x + self.ff(self.ln2(x))
         return x
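
Note: two behavioral changes land here: attention was previously unmasked (each position could attend to future tokens), and normalization moves from post-norm to pre-norm (LayerNorm applied before attention and feed-forward rather than after each residual sum). A standalone check of the mask construction, which matches nn.MultiheadAttention's additive attn_mask convention (-inf entries are zeroed out by the softmax):

    import torch

    T = 4
    mask = torch.triu(torch.ones(T, T), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    print(mask)
    # tensor([[0., -inf, -inf, -inf],
    #         [0., 0., -inf, -inf],
    #         [0., 0., 0., -inf],
    #         [0., 0., 0., 0.]])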
@@ -28,16 +48,18 @@ class GPT(nn.Module):
         super().__init__()
         self.token_emb = nn.Embedding(vocab_size, embed_dim)
         self.pos_emb = nn.Parameter(torch.zeros(1, context_size, embed_dim))
-        self.dropout = nn.Dropout(dropout)
-        self.blocks = nn.ModuleList([GPTBlock(embed_dim, num_heads) for _ in range(num_layers)])
+        self.drop = nn.Dropout(dropout)
+        self.blocks = nn.ModuleList([
+            GPTBlock(embed_dim, num_heads, dropout)
+            for _ in range(num_layers)
+        ])
         self.ln_f = nn.LayerNorm(embed_dim)
         self.head = nn.Linear(embed_dim, vocab_size)
 
     def forward(self, x):
         B, T = x.size()
-        tok_emb = self.token_emb(x)
-        x = tok_emb + self.pos_emb[:, :T, :]
-        x = self.dropout(x)
+        x = self.token_emb(x) + self.pos_emb[:, :T, :]
+        x = self.drop(x)
         for block in self.blocks:
             x = block(x)
         x = self.ln_f(x)
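
Note: a shape smoke test for the updated model; a sketch that assumes the GPT constructor accepts these keyword arguments (they match the attributes set above) and that forward() ends by projecting through self.head:

    import torch

    model = GPT(vocab_size=1000, embed_dim=64, num_heads=4,
                num_layers=2, context_size=32, dropout=0.1)
    tokens = torch.randint(0, 1000, (2, 32))  # (batch, time)
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 32, 1000])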

tools/openwebtext_fetcher.py
@@ -1,5 +1,3 @@
-# tools/openwebtext_fetcher.py
-
 from datasets import load_dataset
 import os
 from tqdm import tqdm
@@ -8,14 +6,29 @@ TARGET_DIR = "data/openwebtext"
 os.makedirs(TARGET_DIR, exist_ok=True)
 
 
-def fetch_subset(n=10000, split="train"):
+def fetch_subset(total=20000, chunk_size=5000, split="train"):
     ds = load_dataset("stas/openwebtext-10k", split=split)
-    with open(os.path.join(TARGET_DIR, f"owt_{n}.jsonl"), "w", encoding="utf-8") as f:
-        for i, item in tqdm(enumerate(ds), total=n, desc="Writing JSONL"):
-            f.write(f"{item['text'].replace(chr(10),' ')}\n")
-            if i + 1 >= n:
-                break
+    print(f"[INFO] Total to fetch: {total} | Chunk size: {chunk_size}")
+
+    count = 0
+    file_index = 0
+    f = None
+
+    for item in tqdm(ds, desc="Downloading"):
+        if count % chunk_size == 0:
+            if f: f.close()
+            file_path = os.path.join(TARGET_DIR, f"owt_{file_index:05d}.jsonl")
+            f = open(file_path, "w", encoding="utf-8")
+            print(f"[INFO] ➕ Created {file_path}")
+            file_index += 1
+        f.write(f"{item['text'].replace(chr(10), ' ')}\n")
+        count += 1
+        if count >= total:
+            break
+
+    if f:
+        f.close()
+    print(f"[INFO] ✅ Done. {count} samples across {file_index} files.")
 
 
 if __name__ == "__main__":
-    fetch_subset(20000)  # fetch 20K examples (~100 MB)
+    fetch_subset(total=100000, chunk_size=5000)
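
Note: fetch_subset() writes each record as raw text (newlines replaced with spaces) into files named .jsonl, while the new load_texts() in main.py parses every line with json.loads() and skips lines that fail to decode, so these plain-text lines would be silently dropped. A minimal sketch of a writer whose output load_texts() can parse (assumption: one JSON object with a "text" field per line):

    import json

    # Inside the fetch loop, in place of the raw write:
    f.write(json.dumps({"text": item["text"]}) + "\n")

json.dumps also escapes embedded newlines, so the chr(10) replacement becomes unnecessary.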