import torch
import torch.nn as nn


class MiniTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=4, n_layers=4, max_seq_len=512):
        super().__init__()
        # Token embeddings plus a learned absolute positional embedding.
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        # Stack of standard encoder layers; batch_first=True keeps tensors as (B, T, d_model).
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
            for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        B, T = x.size()
        # Add positional embeddings truncated to the current sequence length.
        x = self.token_emb(x) + self.pos_emb[:, :T]
        for layer in self.layers:
            x = layer(x)
        x = self.ln(x)
        return self.head(x)  # logits over the vocabulary, shape (B, T, vocab_size)
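
# Minimal usage sketch (not part of the original): vocab_size=1000 and the random
# batch of 2 sequences of length 128 are illustrative values chosen only to show
# the expected input/output tensor shapes.
model = MiniTransformer(vocab_size=1000)
tokens = torch.randint(0, 1000, (2, 128))   # (batch, seq_len) of integer token ids
logits = model(tokens)                      # -> (2, 128, 1000): per-position vocab logits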