26 lines
979 B
Python
26 lines
979 B
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
class TinyGPT(nn.Module):
    """Small Transformer language model mapping token ids to vocabulary logits.

    Tokens are embedded with a single ``nn.Embedding`` shared by the encoder
    and decoder inputs. Fixed sinusoidal position information is added, the
    result goes through ``nn.Transformer`` (``batch_first=True``), and it is
    projected back to ``vocab_size`` logits.

    NOTE(review): despite the name, this is an encoder-decoder model, not a
    decoder-only GPT; ``src`` conditions the decoder via cross-attention.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and number
            of output logits).
        embed_size: Model width (``d_model``); must be divisible by
            ``num_heads`` (checked by ``nn.MultiheadAttention``).
        num_heads: Attention heads per layer.
        num_layers: Number of encoder layers and, separately, of decoder
            layers.
    """

    def __init__(self, vocab_size, embed_size, num_heads, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True,  # tensors are laid out (batch, seq, feature)
        )
        self.fc = nn.Linear(embed_size, vocab_size)

    @staticmethod
    def _positional_encoding(seq_len, dim, device, dtype):
        """Return a ``(seq_len, dim)`` sinusoidal position table.

        The table is computed on the fly rather than stored as a parameter or
        buffer. This keeps the module's ``state_dict`` unchanged and supports
        any sequence length. Even columns hold sines and odd columns hold
        cosines; for odd ``dim`` the last cosine column is dropped.
        """
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.pow(
            10000.0,
            -torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim,
        )
        angles = position * inv_freq  # (seq_len, ceil(dim / 2))
        table = torch.zeros(seq_len, dim, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)[:, : dim // 2]
        return table.to(dtype)

    def _embed(self, tokens):
        """Embed token ids and add sinusoidal position information.

        Without this, self-attention is order-agnostic: identical tokens at
        different positions would produce identical outputs.
        """
        embedded = self.embedding(tokens)  # (..., seq_len, embed_size)
        positions = self._positional_encoding(
            embedded.size(-2), embedded.size(-1), embedded.device, embedded.dtype
        )
        return embedded + positions  # broadcasts over any leading batch dim

    def forward(self, src, tgt):
        """Return vocabulary logits for every decoder (target) position.

        Args:
            src: Integer token ids for the encoder input, presumably
                ``(batch, src_len)`` -- TODO confirm against callers.
            tgt: Integer token ids for the decoder input, presumably
                ``(batch, tgt_len)`` -- TODO confirm against callers.

        Returns:
            Logits whose last dimension is ``vocab_size``. For 2-D inputs the
            shape is ``(batch, tgt_len, vocab_size)``.

        The decoder self-attention is causally masked. The logits at target
        position ``i`` therefore depend only on ``tgt[..., : i + 1]`` and on
        ``src``, as teacher-forced next-token training requires.
        """
        src_embed = self._embed(src)
        tgt_embed = self._embed(tgt)

        # Float mask: -inf above the diagonal blocks attention to later
        # positions. Matching the embeddings' device/dtype avoids mixed-type
        # mask errors (e.g. under half precision or on GPU).
        tgt_len = tgt_embed.size(-2)
        causal_mask = self.transformer.generate_square_subsequent_mask(tgt_len).to(
            device=tgt_embed.device, dtype=tgt_embed.dtype
        )

        transformer_out = self.transformer(src_embed, tgt_embed, tgt_mask=causal_mask)

        # Project each hidden vector to vocabulary-sized logits.
        return self.fc(transformer_out)
|