import torch
import torch.nn as nn


class MiniTransformer(nn.Module):
    """A small Transformer encoder with a language-model head over token IDs."""

    def __init__(self,
                 vocab_size,
                 d_model=256,
                 n_heads=4,
                 n_layers=4,
                 max_seq_len=512):
        super().__init__()
        # Token embedding plus a learned absolute positional embedding.
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        # Stack of standard encoder layers; batch_first keeps the (B, T, d_model) layout.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model,
                                       nhead=n_heads,
                                       batch_first=True)
            for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: (B, T) integer token IDs, with T <= max_seq_len.
        B, T = x.size()
        x = self.token_emb(x) + self.pos_emb[:, :T]
        for layer in self.layers:
            x = layer(x)
        x = self.ln(x)
        # Per-position logits over the vocabulary: (B, T, vocab_size).
        return self.head(x)
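
# Usage sketch: a minimal smoke test, assuming a toy vocab_size of 1000 and a
# small batch of random token IDs; all values here are illustrative, not from
# the original file.
if __name__ == "__main__":
    model = MiniTransformer(vocab_size=1000)
    tokens = torch.randint(0, 1000, (2, 16))   # (batch=2, seq_len=16) token IDs
    logits = model(tokens)                     # expected shape: (2, 16, 1000)
    print(logits.shape)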