from typing import Optional

import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Binary discriminator over token-id sequences.

    Embeds integer token ids, encodes the sequence with a single-layer
    LSTM, and projects the LSTM's final hidden state to one unnormalized
    logit per sequence. No sigmoid is applied here — pair the output with
    ``nn.BCEWithLogitsLoss`` (or apply ``torch.sigmoid``) downstream.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: Optional[int] = None,
    ) -> None:
        """
        Args:
            vocab_size: Number of distinct token ids in the input.
            embed_dim: Dimensionality of the token embeddings.
            hidden_dim: LSTM hidden size. ``None`` (the default) ties it
                to ``embed_dim``, preserving the original behavior.
        """
        super().__init__()
        # Original code hard-coded hidden size == embed_dim; keep that as
        # the default so existing callers are unaffected.
        hidden_dim = embed_dim if hidden_dim is None else hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of sequences.

        Args:
            x: Long tensor of token ids, shape ``(batch, seq_len)``.

        Returns:
            Logit tensor of shape ``(batch, 1)``.
        """
        emb = self.embedding(x)          # (batch, seq_len, embed_dim)
        _, (h_n, _) = self.lstm(emb)     # h_n: (num_layers, batch, hidden_dim)
        # h_n[-1] is the last layer's final hidden state — a fixed-size
        # summary of the whole sequence.
        return self.fc(h_n[-1])          # (batch, 1)