import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Single projection produces queries, keys, and values in one pass.
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim)
        b, t, e = x.size()
        qkv = self.qkv_proj(x)  # (b, t, 3*e)
        q, k, v = qkv.chunk(3, dim=-1)
        # reshape for multi-head: (b, num_heads, t, head_dim)
        q = q.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        # scaled dot-product attention scores: (b, num_heads, t, t)
        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
        attn = torch.softmax(attn, dim=-1)
        # weighted sum of values, then merge heads back to (b, t, e)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous()
        out = out.view(b, t, e)
        return self.out_proj(out)


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, embed_dim),
        )
        self.ln2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Pre-norm residual connections around attention and the MLP.
        x = x + self.dropout(self.attn(self.ln1(x)))
        x = x + self.dropout(self.ff(self.ln2(x)))
        return x


class TransformerGenerator(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        num_heads: int,
        mlp_dim: int,
        num_layers: int,
        max_seq_len: int,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(max_seq_len, embed_dim)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_dim, num_heads, mlp_dim)
                for _ in range(num_layers)
            ]
        )
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) of token ids
        b, t = x.size()
        # Learned positional embeddings, broadcast across the batch.
        positions = torch.arange(t, device=x.device).unsqueeze(0)
        x = self.token_emb(x) + self.pos_emb(positions)
        for layer in self.layers:
            x = layer(x)
        x = self.ln(x)
        # Per-token logits over the vocabulary: (batch, seq_len, vocab_size)
        return self.head(x)
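

# Usage sketch: a minimal example of how TransformerGenerator above might be
# driven, building a small model and greedily decoding a few tokens by
# repeatedly appending the argmax of the last position's logits. All
# hyperparameter values below are illustrative assumptions, not values taken
# from the code above. Note that MultiHeadSelfAttention applies full
# (unmasked) self-attention; teacher-forced training of a generator would
# normally add a causal mask, but greedy decoding over a prefix as done here
# involves no future tokens.
if __name__ == "__main__":
    model = TransformerGenerator(
        vocab_size=1000,   # assumed vocabulary size
        embed_dim=128,
        num_heads=4,
        mlp_dim=512,
        num_layers=2,
        max_seq_len=64,
    )
    model.eval()

    # Batch of 2 random prompts of length 8.
    tokens = torch.randint(0, 1000, (2, 8))

    # Greedy autoregressive decoding for 16 steps.
    with torch.no_grad():
        for _ in range(16):
            logits = model(tokens)                                   # (b, t, vocab)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # (b, 1)
            tokens = torch.cat([tokens, next_token], dim=1)

    print(tokens.shape)  # torch.Size([2, 24])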