# models/gpt.py
import torch
import torch.nn as nn


class GPTBlock(nn.Module):
    """A single Transformer block: causal self-attention followed by a position-wise MLP."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ln1 = nn.LayerNorm(embed_dim)
        # Position-wise feed-forward network with the usual 4x hidden expansion.
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim)
        )
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Causal mask so each position attends only to itself and earlier positions,
        # as required for autoregressive GPT-style decoding.
        T = x.size(1)
        causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        attn_out, _ = self.attn(x, x, x, attn_mask=causal_mask, need_weights=False)
        x = self.ln1(x + attn_out)      # post-norm residual around attention
        x = self.ln2(x + self.ff(x))    # post-norm residual around feed-forward
        return x


class GPT(nn.Module):
    """Minimal GPT-style language model: token and learned positional embeddings,
    a stack of Transformer blocks, and a linear head projecting to vocabulary logits."""

    def __init__(self, vocab_size, context_size, embed_dim, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        # Learned positional embeddings, one vector per position up to context_size.
        self.pos_emb = nn.Parameter(torch.zeros(1, context_size, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([GPTBlock(embed_dim, num_heads) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        # x: (batch, sequence) tensor of token ids.
        B, T = x.size()
        tok_emb = self.token_emb(x)               # (B, T, embed_dim)
        x = tok_emb + self.pos_emb[:, :T, :]      # add positional embeddings for the first T positions
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)                       # (B, T, vocab_size) logits
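

# Usage sketch with illustrative, assumed hyperparameters (not prescribed by this module):
# builds a small model and runs a quick forward-pass shape check when executed directly.
if __name__ == "__main__":
    model = GPT(vocab_size=50257, context_size=128, embed_dim=256, num_heads=8, num_layers=4)
    tokens = torch.randint(0, 50257, (2, 128))  # batch of 2 sequences, 128 token ids each
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 128, 50257])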