Files
Catlin/models/gpt.py
2025-06-29 12:36:25 -04:00

45 lines
1.4 KiB
Python

# models/gpt.py
import torch
import torch.nn as nn
class GPTBlock(nn.Module):
    """One transformer block: multi-head self-attention followed by an MLP.

    Each sub-layer is wrapped in a residual connection and followed by a
    LayerNorm (post-LN ordering: ``ln(x + sublayer(x))``).

    Args:
        embed_dim: Model width. ``nn.MultiheadAttention`` requires it to be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        causal: When True (the default), position ``t`` may only attend to
            positions ``<= t``. This is what an autoregressive, GPT-style
            next-token model needs; without it, every position can see the
            tokens it is trained to predict. Pass False for bidirectional
            attention.
    """

    def __init__(self, embed_dim, num_heads, causal=True):
        super().__init__()
        self.causal = causal
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers to ``x``.

        Args:
            x: Tensor laid out as ``(batch, seq_len, embed_dim)``
                (``batch_first=True``). Unbatched ``(seq_len, embed_dim)``
                input is also accepted by ``nn.MultiheadAttention``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn_mask = None
        if self.causal:
            # size(-2) is the sequence axis for both batched and unbatched input.
            seq_len = x.size(-2)
            # Boolean mask where True marks a position that may NOT be
            # attended to: everything strictly above the diagonal (the future).
            attn_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )
        attn_out, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = self.ln1(x + attn_out)
        x = self.ln2(x + self.ff(x))
        return x
class GPT(nn.Module):
    """Language model: embeddings -> ``GPTBlock`` stack -> LayerNorm -> logits.

    Token embeddings are summed with learned absolute position embeddings,
    passed through dropout and ``num_layers`` ``GPTBlock`` layers, normalised
    by a final LayerNorm, and projected to ``vocab_size`` logits per position.

    Args:
        vocab_size: Number of token ids. Inputs must lie in ``[0, vocab_size)``.
        context_size: Maximum sequence length covered by the position embedding.
        embed_dim: Model width.
        num_heads: Attention heads per block.
        num_layers: Number of ``GPTBlock`` layers.
        dropout: Dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size, context_size, embed_dim, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        # One learned vector per position; the leading dim of 1 broadcasts
        # over the batch.
        self.pos_emb = nn.Parameter(torch.zeros(1, context_size, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [GPTBlock(embed_dim, num_heads) for _ in range(num_layers)]
        )
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``x`` is not 2-D, or if ``seq_len`` exceeds the
                ``context_size`` the model was built with.
        """
        if x.dim() != 2:
            raise ValueError(
                f"expected token ids of shape (batch, seq_len), got {tuple(x.shape)}"
            )
        seq_len = x.size(1)
        max_len = self.pos_emb.size(1)
        # Without this check, slicing pos_emb past its end silently yields a
        # shorter tensor, and the addition below fails with an opaque
        # broadcasting error.
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds context_size {max_len}"
            )
        h = self.token_emb(x) + self.pos_emb[:, :seq_len, :]
        h = self.dropout(h)
        for block in self.blocks:
            h = block(h)
        return self.head(self.ln_f(h))