# Ruby/model.py
#
# 26 lines · 979 B · Python
# (file-viewer header: Raw | Normal View | History)

import torch
import torch.nn as nn
class TinyGPT(nn.Module):
    """Small encoder-decoder Transformer language model built on ``nn.Transformer``.

    Source and target token ids are embedded with one shared embedding table.
    Fixed sinusoidal position encodings are then added, because
    ``nn.Transformer`` itself has no notion of token order. The result runs
    through a batch-first ``nn.Transformer``, and the decoder output is
    projected back to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids; inputs must lie in
            ``[0, vocab_size)``.
        embed_size: Model width (``d_model``). ``nn.MultiheadAttention``
            requires it to be divisible by ``num_heads``.
        num_heads: Attention heads per layer.
        num_layers: Number of encoder layers and, separately, of decoder layers.
        dim_feedforward: Hidden width of each feed-forward sub-layer
            (default matches ``nn.Transformer``'s own default).
        dropout: Dropout probability (default matches ``nn.Transformer``).
    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int,
        num_heads: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,  # tensors are (batch, seq, feature)
        )
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        *,
        src_key_padding_mask: "torch.Tensor | None" = None,
        tgt_key_padding_mask: "torch.Tensor | None" = None,
        causal: bool = True,
    ) -> torch.Tensor:
        """Return vocabulary logits for every target position.

        Args:
            src: Source token ids, shape ``(batch, src_len)`` (or ``(src_len,)``
                for a single unbatched sequence).
            tgt: Target token ids, shape ``(batch, tgt_len)`` (or ``(tgt_len,)``).
            src_key_padding_mask: Optional ``(batch, src_len)`` mask; ``True``
                marks source padding to ignore. It is also applied to the
                decoder's cross-attention over the encoder output.
            tgt_key_padding_mask: Optional ``(batch, tgt_len)`` mask; ``True``
                marks target padding to ignore.
            causal: When ``True`` (default), target position ``i`` may only
                attend to target positions ``<= i``. Without this, teacher-forced
                training lets each position see the tokens it should predict.

        Returns:
            Logits of shape ``(batch, tgt_len, vocab_size)``.
        """
        src_embed = self._embed(src)
        tgt_embed = self._embed(tgt)

        tgt_mask = None
        if causal:
            tgt_len = tgt.size(-1)
            # True strictly above the diagonal = "may not attend" (future tokens).
            tgt_mask = torch.triu(
                torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
                diagonal=1,
            )

        transformer_out = self.transformer(
            src_embed,
            tgt_embed,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        # Linear projection to vocabulary size.
        return self.fc(transformer_out)

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed token ids and add sinusoidal position encodings."""
        embedded = self.embedding(tokens)  # (..., seq_len, embed_size)
        positions = self._sinusoidal_positions(
            tokens.size(-1), embedded.size(-1), embedded.device, embedded.dtype
        )
        return embedded + positions  # (seq_len, embed_size) broadcasts over batch

    @staticmethod
    def _sinusoidal_positions(
        seq_len: int, dim: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return the fixed ``(seq_len, dim)`` sine/cosine position table.

        The table is computed on the fly instead of stored. It therefore adds
        no parameters or buffers (existing ``state_dict`` checkpoints still
        load) and imposes no maximum sequence length.
        """
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        inv_freq = 10000.0 ** (
            -torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
        )
        angles = position * inv_freq  # (seq_len, ceil(dim / 2))
        table = torch.zeros(seq_len, dim, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        # An odd ``dim`` has one fewer cosine column than sine columns.
        table[:, 1::2] = torch.cos(angles[:, : dim // 2])
        return table.to(dtype)