class AE(nn.Module):
    """Fully connected autoencoder with a symmetric encoder/decoder.

    The encoder maps ``input_dim`` features down through ``hidden_dims`` to a
    ``latent_dim`` code using ``Linear`` layers separated by ``ReLU``; the
    last encoder layer has no activation.  The decoder mirrors those widths
    back up to ``input_dim`` and ends with a ``Sigmoid``, so every
    reconstructed value lies in the range 0..1.

    The default sizes (784 -> 128 -> 64 -> 36 -> 18 -> 9) reproduce the
    original hard-coded architecture, including the position of every layer
    inside the two ``nn.Sequential`` containers, so ``state_dict`` keys such
    as ``encoder.0.weight`` or ``decoder.8.bias`` are unchanged.

    Args:
        input_dim: Size of the last dimension of the input (default
            ``28 * 28``, i.e. a flattened 28x28 input).
        latent_dim: Size of the bottleneck code (default ``9``).
        hidden_dims: Widths of the intermediate encoder layers, ordered from
            the input side to the bottleneck side.  The decoder uses the same
            widths in reverse order.
    """

    def __init__(self, input_dim=28 * 28, latent_dim=9,
                 hidden_dims=(128, 64, 36, 18)):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # One list of widths drives both halves so they stay mirrored.
        widths = [input_dim, *hidden_dims, latent_dim]

        # Encoder layers are created before decoder layers, matching the
        # original construction order (and therefore seeded initialisation).
        self.encoder = nn.Sequential(*self._dense_stack(widths))
        self.decoder = nn.Sequential(
            *self._dense_stack(widths[::-1]),
            nn.Sigmoid(),
        )

    @staticmethod
    def _dense_stack(widths):
        """Return ``Linear`` layers for consecutive widths with ``ReLU`` between them."""
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # No activation after the final Linear layer of a stack.
        return layers[:-1]

    def forward(self, x):
        """Encode ``x`` to the latent code and return its reconstruction.

        Args:
            x: Tensor whose last dimension has size ``input_dim``; it is fed
                to the encoder as-is (no flattening is performed here).

        Returns:
            Tensor with the same shape as ``x`` holding the decoder output
            after the final ``Sigmoid``.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded