Updated the model: causal self-attention, pre-norm residual blocks, and dropout
@@ -4,22 +4,42 @@ import torch
 import torch.nn as nn


-class GPTBlock(nn.Module):
-    def __init__(self, embed_dim, num_heads):
+class CausalSelfAttention(nn.Module):
+    def __init__(self, embed_dim, num_heads, dropout):
+        super().__init__()
+        self.attn = nn.MultiheadAttention(
+            embed_dim, num_heads,
+            dropout=dropout,
+            batch_first=True,
+            bias=True
+        )
+
+    def forward(self, x):
+        B, T, _ = x.size()
+        # Create causal mask: (T, T) with float('-inf') for future positions
+        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1)
+        mask = mask.masked_fill(mask == 1, float('-inf'))
+
+        # Pass it in as attn_mask
+        return self.attn(x, x, x, attn_mask=mask)[0]
+
+
+class GPTBlock(nn.Module):
+    def __init__(self, embed_dim, num_heads, dropout):
         super().__init__()
-        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
         self.ln1 = nn.LayerNorm(embed_dim)
+        self.attn = CausalSelfAttention(embed_dim, num_heads, dropout)
+        self.ln2 = nn.LayerNorm(embed_dim)
         self.ff = nn.Sequential(
             nn.Linear(embed_dim, 4 * embed_dim),
             nn.GELU(),
-            nn.Linear(4 * embed_dim, embed_dim)
+            nn.Linear(4 * embed_dim, embed_dim),
+            nn.Dropout(dropout)
         )
-        self.ln2 = nn.LayerNorm(embed_dim)

     def forward(self, x):
-        attn_out, _ = self.attn(x, x, x, need_weights=False)
-        x = self.ln1(x + attn_out)
-        x = self.ln2(x + self.ff(x))
+        x = x + self.attn(self.ln1(x))
+        x = x + self.ff(self.ln2(x))
         return x


@@ -28,16 +48,18 @@ class GPT(nn.Module):
         super().__init__()
         self.token_emb = nn.Embedding(vocab_size, embed_dim)
         self.pos_emb = nn.Parameter(torch.zeros(1, context_size, embed_dim))
-        self.dropout = nn.Dropout(dropout)
-        self.blocks = nn.ModuleList([GPTBlock(embed_dim, num_heads) for _ in range(num_layers)])
+        self.drop = nn.Dropout(dropout)
+        self.blocks = nn.ModuleList([
+            GPTBlock(embed_dim, num_heads, dropout)
+            for _ in range(num_layers)
+        ])
         self.ln_f = nn.LayerNorm(embed_dim)
         self.head = nn.Linear(embed_dim, vocab_size)

     def forward(self, x):
         B, T = x.size()
-        tok_emb = self.token_emb(x)
-        x = tok_emb + self.pos_emb[:, :T, :]
-        x = self.dropout(x)
+        x = self.token_emb(x) + self.pos_emb[:, :T, :]
+        x = self.drop(x)
         for block in self.blocks:
             x = block(x)
         x = self.ln_f(x)
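
The substantive change is the explicit causal mask that CausalSelfAttention now builds before calling nn.MultiheadAttention. As a minimal standalone sketch of that mechanism (sizes and values here are illustrative, not taken from the commit), the following reproduces the mask construction and checks that editing a future token leaves earlier outputs untouched:

import torch
import torch.nn as nn

# Illustrative sizes only; the mask construction mirrors CausalSelfAttention above.
T, embed_dim, num_heads = 4, 8, 2
attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Float mask: 0.0 on/below the diagonal, -inf above it, so query t only attends to keys <= t.
mask = torch.triu(torch.ones(T, T), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))

x = torch.randn(1, T, embed_dim)
out, _ = attn(x, x, x, attn_mask=mask)

# Perturb the last token; earlier positions should be unaffected.
x2 = x.clone()
x2[0, -1] = torch.randn(embed_dim)
out2, _ = attn(x2, x2, x2, attn_mask=mask)
print(torch.allclose(out[0, :-1], out2[0, :-1]))  # expected: True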
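For completeness, a hypothetical end-to-end call of the updated model. The GPT constructor signature is not shown in this diff (both hunks start inside __init__), so the argument names and values below are assumptions inferred from the attributes it defines, and the import path is a placeholder:

import torch
from model import GPT  # placeholder module path; adjust to the actual file

# Assumed hyperparameters; the real signature and defaults are not part of this diff.
model = GPT(vocab_size=50257, embed_dim=256, num_heads=8,
            num_layers=4, context_size=128, dropout=0.1)

tokens = torch.randint(0, 50257, (2, 64))  # (batch, sequence); sequence must be <= context_size
logits = model(tokens)                     # presumably (2, 64, vocab_size) via self.head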