Ruby/models/discriminator.py

18 lines
520 B
Python

import torch
import torch.nn as nn
class Discriminator(nn.Module):
    """Score a batch of token-id sequences with a single scalar each.

    Token ids are embedded, run through a one-layer, unidirectional,
    batch-first LSTM. The final hidden state of the last LSTM layer is
    then projected to one output feature.
    """

    def __init__(self, vocab_size: int, embed_dim: int):
        """Build the discriminator.

        Args:
            vocab_size: Number of rows in the embedding table; input ids
                must lie in ``[0, vocab_size)``.
            embed_dim: Width of the token embeddings. Also used as the
                LSTM hidden size.
        """
        super().__init__()
        # Submodules are created in this exact order so that parameter
        # initialisation under a fixed RNG seed stays the same.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=embed_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=embed_dim, out_features=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one unnormalised score per sequence.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, 1)``. No sigmoid is applied here;
            presumably the caller pairs this with a logits-based loss
            (e.g. ``BCEWithLogitsLoss``), which should be confirmed
            against the training code.
        """
        embedded = self.embedding(x)
        _outputs, (hidden, _cell) = self.lstm(embedded)
        # ``hidden`` is (num_layers * num_directions, batch, embed_dim).
        # Index -1 selects the last layer's final step: (batch, embed_dim).
        last_layer_hidden = hidden[-1]
        return self.fc(last_layer_hidden)