import torch
import torch.nn as nn

class TinyGPT(nn.Module):
    """Small token-level sequence model built on ``nn.Transformer``.

    Despite the name, this is an encoder-decoder (seq2seq) model: ``src`` is
    encoded and ``tgt`` is decoded against that memory. Both token streams share
    one embedding table, receive fixed sinusoidal position codes, and the decoder
    output is projected to per-token vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids (embedding rows / logit width).
        embed_size: Model width (``d_model``); ``nn.MultiheadAttention`` requires
            it to be divisible by ``num_heads``.
        num_heads: Attention heads per layer.
        num_layers: Number of encoder layers and, separately, of decoder layers.
    """

    def __init__(self, vocab_size, embed_size, num_heads, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True,  # batched tensors are (batch, seq, feature)
        )
        self.fc = nn.Linear(embed_size, vocab_size)

    @staticmethod
    def _sinusoidal_positions(seq_len, dim, device, dtype):
        """Return a (seq_len, dim) table of fixed sin/cos position codes.

        Computed on the fly (no parameters or buffers), so adding it does not
        change the module's ``state_dict``. Odd ``dim`` is supported: sine fills
        the ceil(dim/2) even columns, cosine the floor(dim/2) odd columns.
        """
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        inv_freq = 1.0 / (
            10000.0 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )
        angles = position * inv_freq  # (seq_len, ceil(dim / 2))
        table = torch.zeros(seq_len, dim, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles[:, : dim // 2])
        return table.to(dtype)

    @staticmethod
    def _causal_mask(size, device):
        """Return a boolean (size, size) mask; True entries (strictly above the
        diagonal) forbid a position from attending to later positions."""
        return torch.triu(
            torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1
        )

    def _embed(self, tokens):
        """Embed integer token ids and add position codes along the sequence axis.

        Sequence/feature axes are taken from the end so both batched
        (batch, seq) and unbatched (seq,) token tensors keep working.
        """
        x = self.embedding(tokens)
        positions = self._sinusoidal_positions(x.size(-2), x.size(-1), x.device, x.dtype)
        return x + positions  # broadcasts over a leading batch dimension, if any

    def forward(self, src, tgt, *, src_key_padding_mask=None,
                tgt_key_padding_mask=None, causal=True):
        """Compute vocabulary logits for each ``tgt`` position.

        Args:
            src: Integer token ids for the encoder, (batch, src_len).
            tgt: Integer token ids for the decoder, (batch, tgt_len).
            src_key_padding_mask: Optional (batch, src_len) mask following the
                ``nn.Transformer`` convention (bool: True = ignore). Also used as
                the memory padding mask for decoder cross-attention.
            tgt_key_padding_mask: Optional (batch, tgt_len) padding mask, same
                convention.
            causal: When True (default), decoder self-attention is masked so
                position i only sees target positions <= i. This prevents
                leaking future tokens under teacher forcing.

        Returns:
            Logits of shape (batch, tgt_len, vocab_size).
        """
        src_embed = self._embed(src)
        tgt_embed = self._embed(tgt)

        tgt_mask = self._causal_mask(tgt_embed.size(-2), tgt_embed.device) if causal else None

        transformer_out = self.transformer(
            src_embed,
            tgt_embed,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        # Linear projection from model width to vocabulary size.
        return self.fc(transformer_out)