
Incorrect normalization in VAE example loss function #290

@richardzhu

Description


In the loss_function part of the VAE example, I noticed that

```python
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Normalise by same number of elements as in reconstruction
KLD /= args.batch_size * 784
```

But the dimensionality of the latent variables (`mu`, `logvar`) is 20, not 784, so the code should either keep the `torch.sum` and normalise by `args.batch_size * 20`, or just use `torch.mean` directly; otherwise the BCE and KLD terms are not scaled consistently against each other. Changing the normalisation from 784 to 20 does increase the test loss at the end of training, but that is expected: the smaller divisor simply increases the weight of the KLD term.
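For concreteness, here is a minimal sketch of how a corrected `loss_function` might look. The `binary_cross_entropy` reconstruction term and the `latent_dim=20` default are assumptions based on the example's MNIST setup, not a verbatim copy of the repository code:

```python
import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar, latent_dim=20):
    # Reconstruction term: F.binary_cross_entropy defaults to a mean reduction,
    # i.e. it is already normalised per element (batch_size * 784 pixels).
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784))

    # Analytic KL divergence between q(z|x) = N(mu, diag(sigma^2)) and N(0, I):
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Normalise by the number of latent elements (batch_size * 20),
    # not by the number of pixels (batch_size * 784).
    KLD /= x.size(0) * latent_dim

    # Equivalently: KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
```

With both terms expressed per element, the BCE and KLD losses are on comparable scales, which is the point of the fix suggested above.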
