import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Sequence discriminator: embedding -> LSTM -> linear scoring head.

    Each input sequence of token indices is embedded, run through a
    single-layer LSTM, and the final hidden state is projected to one
    scalar score per sequence.
    """

    def __init__(self, vocab_size: int, embed_dim: int):
        """Build the embedding, LSTM and linear layers.

        Args:
            vocab_size: Number of rows in the embedding table.
            embed_dim: Embedding width; also used as the LSTM hidden size.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True so the LSTM consumes (batch, seq_len, embed_dim).
        self.lstm = nn.LSTM(embed_dim, embed_dim, batch_first=True)
        self.fc = nn.Linear(embed_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of token-index sequences.

        Args:
            x: Integer token indices of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, 1) holding the raw output of the linear
            head; no sigmoid or other activation is applied here.
        """
        # x: (batch, seq_len)
        emb = self.embedding(x)
        # Only the final hidden state is needed; per-step outputs and the
        # cell state are discarded.
        _, (h_n, _) = self.lstm(emb)
        # h_n[-1]: (batch, embed_dim)
        return self.fc(h_n[-1])