Ruby/models/transformer.py

import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention over inputs of shape (batch, seq_len, embed_dim)."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Single projection producing queries, keys, and values in one matmul.
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim)
        b, t, e = x.size()
        qkv = self.qkv_proj(x)  # (b, t, 3*e)
        q, k, v = qkv.chunk(3, dim=-1)
        # Reshape to (b, num_heads, t, head_dim) so attention runs per head.
        q = q.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention; no mask is applied, so attention is
        # bidirectional over the full sequence.
        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous()
        out = out.view(b, t, e)  # merge heads back to (b, t, embed_dim)
        return self.out_proj(out)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: LayerNorm -> sub-layer -> dropout -> residual add."""

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, embed_dim),
        )
        self.ln2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Residual connections around the attention and feed-forward sub-layers.
        x = x + self.dropout(self.attn(self.ln1(x)))
        x = x + self.dropout(self.ff(self.ln2(x)))
        return x


class TransformerGenerator(nn.Module):
    """Token transformer: token + learned positional embeddings -> blocks -> vocab logits."""

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        num_heads: int,
        mlp_dim: int,
        num_layers: int,
        max_seq_len: int,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(max_seq_len, embed_dim)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_dim, num_heads, mlp_dim)
                for _ in range(num_layers)
            ]
        )
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) of token ids
        b, t = x.size()
        positions = torch.arange(t, device=x.device).unsqueeze(0)  # (1, t), broadcast over batch
        x = self.token_emb(x) + self.pos_emb(positions)
        for layer in self.layers:
            x = layer(x)
        x = self.ln(x)
        return self.head(x)  # (batch, seq_len, vocab_size) logits
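

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module): the hyperparameters
    # below are illustrative assumptions, chosen only to check tensor shapes
    # end to end for the classes defined above.
    model = TransformerGenerator(
        vocab_size=100,
        embed_dim=64,
        num_heads=4,
        mlp_dim=128,
        num_layers=2,
        max_seq_len=32,
    )
    tokens = torch.randint(0, 100, (2, 16))  # (batch=2, seq_len=16) token ids
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 100])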