import torch
import torch.nn as nn


class MiniGPT(nn.Module):
    """A minimal GPT-style language model: token and positional embeddings,
    a stack of Transformer layers, a final layer norm, and a linear head
    that projects back to vocabulary logits."""

    def __init__(self, vocab_size, embed_dim=128, n_heads=4, n_layers=2, max_len=128):
        super().__init__()
        # Learned token embeddings and absolute positional embeddings.
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        # Stack of Transformer layers; batch_first means inputs are (batch, seq, dim).
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=n_heads, batch_first=True)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(embed_dim)           # final layer norm
        self.head = nn.Linear(embed_dim, vocab_size)  # per-token vocabulary logits

    def forward(self, x):
        # x: (batch, seq_len) tensor of token ids.
        seq_len = x.size(1)
        pos = torch.arange(0, seq_len, device=x.device).unsqueeze(0)
        x = self.token_embed(x) + self.pos_embed(pos)
        # Causal (upper-triangular) mask so each position attends only to itself
        # and earlier positions, as GPT-style autoregressive modelling requires.
        mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=x.device), diagonal=1
        )
        for block in self.blocks:
            x = block(x, src_mask=mask)
        x = self.ln_f(x)
        return self.head(x)
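

# Minimal usage sketch: instantiates the model and runs one forward pass on
# random token ids. The vocab size (1000), batch size (2), and sequence
# length (16) below are illustrative assumptions, not values from the model.
if __name__ == "__main__":
    model = MiniGPT(vocab_size=1000)
    tokens = torch.randint(0, 1000, (2, 16))  # (batch=2, seq_len=16) dummy token ids
    logits = model(tokens)
    print(logits.shape)                       # expected: torch.Size([2, 16, 1000])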