import torch
import torch.nn as nn


class Cortex(nn.Module):
    """The 'brain': a char-level Transformer encoder for self-supervised learning."""

    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 4,
        num_layers: int = 4,
        ff_dim: int = 512,
        max_seq_len: int = 1024,
    ) -> None:
        super().__init__()
        self.vocab_size = 256  # one token per byte value (covers ASCII)
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        # Learned token and positional embeddings.
        self.token_embedding = nn.Embedding(self.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

        # Project encoder outputs back to per-character logits.
        self.fc_out = nn.Linear(embed_dim, self.vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # input_ids: (batch, seq_len)
        batch_size, seq_len = input_ids.size()
        positions = (
            torch.arange(0, seq_len, device=input_ids.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        x = self.token_embedding(input_ids) + self.position_embedding(positions)

        # nn.TransformerEncoder expects (seq_len, batch, embed_dim) by default.
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # back to (batch, seq_len, embed_dim)

        logits = self.fc_out(x)
        return logits
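

# A minimal usage sketch (not part of the module above): it assumes the class
# is used as-is and simply pushes a random batch of byte IDs through the
# encoder to check the output shape. Batch size and sequence length here are
# arbitrary illustration values.
if __name__ == "__main__":
    model = Cortex()
    # Fake batch: 2 sequences of 128 byte IDs in [0, 255].
    dummy_ids = torch.randint(0, 256, (2, 128))
    logits = model(dummy_ids)
    # Expected shape: (batch, seq_len, vocab_size) == (2, 128, 256).
    print(logits.shape)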