# models/gpt.py
"""Minimal GPT-style decoder-only Transformer built on nn.MultiheadAttention."""

import torch
import torch.nn as nn


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal (look-back-only) mask."""

    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True, bias=True
        )

    def forward(self, x):
        B, T, _ = x.size()
        # Causal mask: (T, T) additive float mask with -inf above the diagonal,
        # so position i can only attend to positions <= i.
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        # Self-attention: query = key = value = x; [0] drops the attention weights.
        return self.attn(x, x, x, attn_mask=mask)[0]


class GPTBlock(nn.Module):
    """Pre-LayerNorm Transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = CausalSelfAttention(embed_dim, num_heads, dropout)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class GPT(nn.Module):
    """Token + learned positional embeddings, a stack of GPTBlocks, and a language-model head."""

    def __init__(self, vocab_size, context_size, embed_dim, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Parameter(torch.zeros(1, context_size, embed_dim))
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            GPTBlock(embed_dim, num_heads, dropout) for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        # x: (B, T) batch of token ids, with T <= context_size.
        B, T = x.size()
        x = self.token_emb(x) + self.pos_emb[:, :T, :]
        x = self.drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)  # (B, T, vocab_size) next-token logits
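

# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal smoke test of the forward pass, using small hyperparameter values
# chosen here purely for illustration. It also shows one common way to turn
# the logits into a next-token cross-entropy loss; the module itself does not
# prescribe a training setup, so treat this as a sketch under those assumptions.
if __name__ == "__main__":
    import torch.nn.functional as F

    vocab_size, context_size = 1000, 128  # assumed toy values
    model = GPT(
        vocab_size=vocab_size,
        context_size=context_size,
        embed_dim=64,
        num_heads=4,
        num_layers=2,
        dropout=0.1,
    )

    # Dummy batch of token ids: (batch=2, time=16).
    tokens = torch.randint(0, vocab_size, (2, 16))
    logits = model(tokens)   # (2, 16, vocab_size)
    print(logits.shape)      # torch.Size([2, 16, 1000])

    # Next-token objective: predict tokens[:, 1:] from the logits at positions :-1.
    targets = tokens[:, 1:]
    loss = F.cross_entropy(
        logits[:, :-1, :].reshape(-1, vocab_size),
        targets.reshape(-1),
    )
    print(loss.item())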