# models/gpt.py
import torch
import torch.nn as nn


class GPTBlock(nn.Module):
    """A single transformer decoder block: multi-head self-attention followed by a feed-forward MLP,
    with post-norm residual connections."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x, attn_mask=None):
        # Self-attention; attn_mask restricts which positions each token may attend to.
        attn_out, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = self.ln1(x + attn_out)
        x = self.ln2(x + self.ff(x))
        return x


class GPT(nn.Module):
    """Minimal GPT-style decoder-only transformer with learned positional embeddings."""

    def __init__(self, vocab_size, context_size, embed_dim, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Parameter(torch.zeros(1, context_size, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [GPTBlock(embed_dim, num_heads) for _ in range(num_layers)]
        )
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        B, T = x.size()
        tok_emb = self.token_emb(x)                       # (B, T, embed_dim)
        x = self.dropout(tok_emb + self.pos_emb[:, :T, :])
        # Causal mask: True above the diagonal blocks attention to future positions,
        # which autoregressive language modeling requires.
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1
        )
        for block in self.blocks:
            x = block(x, attn_mask=causal_mask)
        x = self.ln_f(x)
        return self.head(x)                               # logits over the vocabulary, (B, T, vocab_size)
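

# A minimal usage sketch for a quick shape check. The hyperparameters and tensor
# sizes below are illustrative assumptions, not values taken from any config in
# this repo; adjust them to your setup.
if __name__ == "__main__":
    model = GPT(vocab_size=1000, context_size=128, embed_dim=64, num_heads=4, num_layers=2)
    tokens = torch.randint(0, 1000, (2, 16))   # batch of 2 sequences, 16 token ids each
    logits = model(tokens)
    print(logits.shape)                        # expected: torch.Size([2, 16, 1000])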