# models/gpt.py
import torch
import torch.nn as nn


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal (lower-triangular) mask."""

    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads,
            dropout=dropout,
            batch_first=True,
            bias=True
        )

    def forward(self, x):
        B, T, _ = x.size()
        # Additive causal mask: (T, T) with -inf above the diagonal, so each
        # position can attend only to itself and earlier positions.
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        # nn.MultiheadAttention returns (output, attn_weights); keep the output.
        return self.attn(x, x, x, attn_mask=mask)[0]


class GPTBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = CausalSelfAttention(embed_dim, num_heads, dropout)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class GPT(nn.Module):
    """Decoder-only transformer language model with learned positional embeddings."""

    def __init__(self, vocab_size, context_size, embed_dim, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Parameter(torch.zeros(1, context_size, embed_dim))
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            GPTBlock(embed_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        B, T = x.size()
        # Sum token and positional embeddings, then apply embedding dropout.
        x = self.token_emb(x) + self.pos_emb[:, :T, :]
        x = self.drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        # Project back to vocabulary logits: (B, T, vocab_size).
        return self.head(x)
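

# Minimal usage sketch (not part of the original module): it only demonstrates
# how the GPT class above is constructed and called. The hyperparameters and
# batch shape below are illustrative assumptions, chosen just to exercise a
# forward pass, not values from any actual training setup.
if __name__ == "__main__":
    model = GPT(
        vocab_size=1000,
        context_size=128,
        embed_dim=64,
        num_heads=4,
        num_layers=2,
        dropout=0.1,
    )
    tokens = torch.randint(0, 1000, (2, 16))  # (batch=2, sequence_length=16) token ids
    logits = model(tokens)                    # expected shape: (2, 16, 1000)
    print(logits.shape)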