Updated the model capacity

This commit is contained in:
2025-06-30 18:08:11 -04:00
parent 159be1eb82
commit 6366f72716
6 changed files with 95 additions and 10058 deletions

View File

@ -4,22 +4,42 @@ import torch
import torch.nn as nn
class GPTBlock(nn.Module):
def __init__(self, embed_dim, num_heads):
class CausalSelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout):
super().__init__()
self.attn = nn.MultiheadAttention(
embed_dim, num_heads,
dropout=dropout,
batch_first=True,
bias=True
)
def forward(self, x):
B, T, _ = x.size()
# Create causal mask: (T, T) with float('-inf') for future positions
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
# Pass it in as attn_mask
return self.attn(x, x, x, attn_mask=mask)[0]
class GPTBlock(nn.Module):
def __init__(self, embed_dim, num_heads, dropout):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
self.ln1 = nn.LayerNorm(embed_dim)
self.attn = CausalSelfAttention(embed_dim, num_heads, dropout)
self.ln2 = nn.LayerNorm(embed_dim)
self.ff = nn.Sequential(
nn.Linear(embed_dim, 4 * embed_dim),
nn.GELU(),
nn.Linear(4 * embed_dim, embed_dim)
nn.Linear(4 * embed_dim, embed_dim),
nn.Dropout(dropout)
)
self.ln2 = nn.LayerNorm(embed_dim)
def forward(self, x):
attn_out, _ = self.attn(x, x, x, need_weights=False)
x = self.ln1(x + attn_out)
x = self.ln2(x + self.ff(x))
x = x + self.attn(self.ln1(x))
x = x + self.ff(self.ln2(x))
return x
@ -28,16 +48,18 @@ class GPT(nn.Module):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, embed_dim)
self.pos_emb = nn.Parameter(torch.zeros(1, context_size, embed_dim))
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([GPTBlock(embed_dim, num_heads) for _ in range(num_layers)])
self.drop = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
GPTBlock(embed_dim, num_heads, dropout)
for _ in range(num_layers)
])
self.ln_f = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, vocab_size)
def forward(self, x):
B, T = x.size()
tok_emb = self.token_emb(x)
x = tok_emb + self.pos_emb[:, :T, :]
x = self.dropout(x)
x = self.token_emb(x) + self.pos_emb[:, :T, :]
x = self.drop(x)
for block in self.blocks:
x = block(x)
x = self.ln_f(x)