# RubyOld/model.py
import torch
import torch.nn as nn


class MiniTransformer(nn.Module):
    """A small Transformer: token embeddings plus learned positional
    embeddings, a stack of encoder layers, and a linear head that
    projects back to vocabulary logits."""

    def __init__(self,
                 vocab_size,
                 d_model=256,
                 n_heads=4,
                 n_layers=4,
                 max_seq_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        # Learned absolute positional embeddings, one vector per position,
        # initialized to zeros; shape (1, max_seq_len, d_model) so it
        # broadcasts over the batch dimension.
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model,
                                       nhead=n_heads,
                                       batch_first=True)
            for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: (B, T) tensor of integer token ids, T <= max_seq_len.
        B, T = x.size()
        x = self.token_emb(x) + self.pos_emb[:, :T]
        for layer in self.layers:
            x = layer(x)
        x = self.ln(x)
        return self.head(x)  # (B, T, vocab_size) logits
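

# A minimal usage sketch, not part of the original file: the vocab size,
# batch size, and sequence length below are illustrative assumptions.
if __name__ == "__main__":
    model = MiniTransformer(vocab_size=1000)
    tokens = torch.randint(0, 1000, (2, 16))  # batch of 2, 16 tokens each
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])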