import torch
import torch.nn as nn


class TinyGPT(nn.Module):
    """Small encoder-decoder language model built on ``nn.Transformer``.

    Token ids are embedded with a single embedding table shared by the source
    and target sequences, passed through ``nn.Transformer`` and projected back
    to vocabulary logits.

    NOTE: no positional encoding is added to the embeddings and
    ``nn.Transformer`` does not add one itself, so apart from the causal mask
    the model has no information about token order.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int,
        num_heads: int,
        num_layers: int,
    ) -> None:
        """Build the embedding table, the transformer and the output head.

        Args:
            vocab_size: Number of distinct token ids; also the size of the
                output logit dimension.
            embed_size: Embedding width, used as the transformer ``d_model``.
            num_heads: Number of attention heads (``nhead``).
            num_layers: Number of layers used for both the encoder stack and
                the decoder stack.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True,  # inputs/outputs are (batch, seq, feature)
        )
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        *,
        causal: bool = True,
    ) -> torch.Tensor:
        """Compute vocabulary logits for every target position.

        Args:
            src: Source token ids, shape ``(batch_size, src_len)``.
            tgt: Target token ids, shape ``(batch_size, tgt_len)``.
            causal: When true (the default) each target position may only
                attend to itself and earlier target positions. Pass ``False``
                to run the decoder self-attention without any mask.

        Returns:
            Logits of shape ``(batch_size, tgt_len, vocab_size)``.
        """
        src_embed = self.embedding(src)  # (batch_size, src_len, embed_size)
        tgt_embed = self.embedding(tgt)  # (batch_size, tgt_len, embed_size)

        tgt_mask = None
        if causal:
            # Additive attention mask: -inf strictly above the diagonal blocks
            # attention to future target tokens, 0 elsewhere leaves the
            # scores untouched. Without it the decoder can read the very
            # tokens it is being trained to predict.
            tgt_len = tgt.size(1)
            tgt_mask = torch.triu(
                torch.full(
                    (tgt_len, tgt_len),
                    float("-inf"),
                    dtype=tgt_embed.dtype,
                    device=tgt_embed.device,
                ),
                diagonal=1,
            )

        transformer_out = self.transformer(
            src_embed, tgt_embed, tgt_mask=tgt_mask
        )

        # Project decoder features to per-token vocabulary logits.
        return self.fc(transformer_out)