import torch
import torch.nn as nn


class MiniGPT(nn.Module):
    """Small decoder-style language model built from ``nn.TransformerEncoderLayer``s.

    Token ids are embedded, summed with learned absolute position embeddings,
    passed through ``n_layers`` self-attention blocks, layer-normalised, and
    projected to per-position vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids; size of the output logits.
        embed_dim: Model width (``d_model`` of every block).
        n_heads: Attention heads per block; must divide ``embed_dim``.
        n_layers: Number of stacked transformer blocks.
        max_len: Longest sequence the position embedding table supports.
        causal: When True (default), each position may only attend to itself
            and earlier positions, as required for next-token prediction.
            Pass False to get unrestricted (bidirectional) attention.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 128,
        n_heads: int = 4,
        n_layers: int = 2,
        max_len: int = 128,
        *,
        causal: bool = True,
    ) -> None:
        super().__init__()
        # Plain attributes (not buffers) so state_dict keys are unchanged.
        self.max_len = max_len
        self.causal = causal
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        self.blocks = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=embed_dim, nhead=n_heads, batch_first=True
                )
                for _ in range(n_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    @staticmethod
    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        """Return a (seq_len, seq_len) bool mask; True marks a blocked (future) key."""
        return torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(
            diagonal=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute next-token logits.

        Args:
            x: Integer token ids of shape (batch, seq_len), with
                ``seq_len <= max_len``.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            # Fail with a clear message instead of an opaque embedding index error.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )
        # (1, seq_len) so the position embeddings broadcast across the batch.
        pos = torch.arange(seq_len, device=x.device).unsqueeze(0)
        h = self.token_embed(x) + self.pos_embed(pos)
        # Without this mask every position could attend to future tokens,
        # leaking the very targets a GPT-style model is trained to predict.
        mask = self._causal_mask(seq_len, x.device) if self.causal else None
        for block in self.blocks:
            h = block(h, src_mask=mask)
        h = self.ln_f(h)
        return self.head(h)