import torch
import torch.nn as nn


class Cortex(nn.Module):
    """The 'brain': a char-level Transformer encoder for self-supervised learning."""

    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 4,
        num_layers: int = 4,
        ff_dim: int = 512,
        max_seq_len: int = 1024,
    ) -> None:
        super().__init__()
        self.vocab_size = 256  # one token per byte value (covers ASCII and beyond)
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        # Learned token and position embeddings.
        self.token_embedding = nn.Embedding(self.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)

        # Standard Transformer encoder stack.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

        # Projects encoder outputs back to per-character logits.
        self.fc_out = nn.Linear(embed_dim, self.vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # input_ids: (batch, seq_len)
        batch_size, seq_len = input_ids.size()

        # Position indices 0..seq_len-1, broadcast across the batch.
        positions = (
            torch.arange(0, seq_len, device=input_ids.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        x = self.token_embedding(input_ids) + self.position_embedding(positions)

        # nn.TransformerEncoder defaults to (seq_len, batch, embed_dim) inputs.
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # back to (batch, seq_len, embed_dim)

        # Per-position logits over the character vocabulary.
        logits = self.fc_out(x)
        return logits
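

# --- Usage sketch (illustrative only) -------------------------------------
# A minimal smoke test, assuming input_ids are byte values in [0, 255] with
# seq_len <= max_seq_len. The cross-entropy against shifted targets below is
# a hypothetical placeholder for the self-supervised objective mentioned in
# the docstring; it is not the training code from this repository, and the
# encoder applies no causal mask.
if __name__ == "__main__":
    import torch.nn.functional as F

    model = Cortex()
    input_ids = torch.randint(0, 256, (2, 128))  # (batch=2, seq_len=128)

    logits = model(input_ids)
    print(logits.shape)  # torch.Size([2, 128, 256])

    # Hypothetical next-character loss: predict the token at position t+1
    # from the representation at position t.
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, model.vocab_size),
        input_ids[:, 1:].reshape(-1),
    )
    print(loss.item())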