import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # one linear layer projects queries, keys, and values in a single pass
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim)
        b, t, e = x.size()
        qkv = self.qkv_proj(x)  # (b, t, 3*e)
        q, k, v = qkv.chunk(3, dim=-1)
        # reshape for multi-head: (b, num_heads, t, head_dim)
        q = q.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        # scaled dot-product attention; no mask is applied, so attention is bidirectional
        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
        attn = torch.softmax(attn, dim=-1)
        # weighted sum over values, then merge heads back to (b, t, e)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous()
        out = out.view(b, t, e)
        return self.out_proj(out)
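

# Shape walkthrough for MultiHeadSelfAttention.forward (illustrative numbers, not
# taken from the code above), assuming batch=2, seq_len=16, embed_dim=64,
# num_heads=8, so head_dim=8:
#   x    (2, 16, 64) -> qkv (2, 16, 192) -> q, k, v each (2, 8, 16, 8)
#   attn (2, 8, 16, 16) -> attn @ v (2, 8, 16, 8) -> merged out (2, 16, 64)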


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, embed_dim),
        )
        self.ln2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # pre-LayerNorm residual connections around attention and the feed-forward MLP
        x = x + self.dropout(self.attn(self.ln1(x)))
        x = x + self.dropout(self.ff(self.ln2(x)))
        return x


class TransformerGenerator(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        num_heads: int,
        mlp_dim: int,
        num_layers: int,
        max_seq_len: int,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(max_seq_len, embed_dim)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_dim, num_heads, mlp_dim)
                for _ in range(num_layers)
            ]
        )
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) of token ids
        b, t = x.size()
        # learned positional embeddings added to token embeddings
        positions = torch.arange(t, device=x.device).unsqueeze(0)
        x = self.token_emb(x) + self.pos_emb(positions)
        for layer in self.layers:
            x = layer(x)
        x = self.ln(x)
        # per-position logits over the vocabulary: (batch, seq_len, vocab_size)
        return self.head(x)
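

# A minimal smoke test (illustrative only: the vocabulary size, model dimensions,
# and input shape below are assumptions, not values taken from the module above).
# Note that attention is unmasked, so every position attends to the full sequence.
if __name__ == "__main__":
    model = TransformerGenerator(
        vocab_size=1000,
        embed_dim=64,
        num_heads=8,
        mlp_dim=256,
        num_layers=2,
        max_seq_len=128,
    )
    tokens = torch.randint(0, 1000, (2, 16))  # (batch=2, seq_len=16)
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])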