|
1 |
| -# probabilistic_unet_pytorch |
2 |
| -A Probabilistic U-Net for segmentation of ambiguous images implemented in PyTorch |
| 1 | +# Probabilistic UNet in PyTorch |
| 2 | +A Probabilistic U-Net for segmentation of ambiguous images implemented in PyTorch. This is a pytorch implementation of this paper https://arxiv.org/abs/1806.05034, for which the code can be found here: https://github.com/SimonKohl/probabilistic_unet. |
| 3 | + |
| 4 | +## Adding KL divergence for Independent distribution |
| 5 | +In order to implement an Gaussian distribution with an axis aligned covariance matrix in PyTorch, I needed to wrap a Normal distribution in a Independent distribution. Therefore you need the add the following to the PyTorch source code at torch/distributions/kl.py (source: https://github.com/pytorch/pytorch/issues/13545). |
| 6 | + |
| 7 | +``` |
| 8 | +def _kl_independent_independent(p, q): |
| 9 | + if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims: |
| 10 | + raise NotImplementedError |
| 11 | + result = kl_divergence(p.base_dist, q.base_dist) |
| 12 | + return _sum_rightmost(result, p.reinterpreted_batch_ndims) |
| 13 | +``` |
| 14 | + |
| 15 | +## Training |
| 16 | +In order to train your own Probabilistic UNet in PyTorch, you should first write your own data loader. Then you can use the following code snippet to train the network |
| 17 | + |
| 18 | +``` |
| 19 | +train_loader = define this yourself |
| 20 | +net = ProbabilisticUnet(no_channels,no_classes,filter_list,latent_dim,no_fcomb_convs,beta) |
| 21 | +net.to(device) |
| 22 | +optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0) |
| 23 | +for epoch in range(epochs): |
| 24 | + for step, (patch, mask) in enumerate(train_loader): |
| 25 | + patch = patch.to(device) |
| 26 | + mask = mask.to(device) |
| 27 | + mask = torch.unsqueeze(mask,1) |
| 28 | + net.forward(patch, mask, training=True) |
| 29 | + elbo = net.elbo(mask) |
| 30 | + reg_loss = l2_regularisation(net.posterior) + l2_regularisation(net.prior) + l2_regularisation(net.fcomb) |
| 31 | + loss = -elbo + 1e-5 * reg_loss |
| 32 | + optimizer.zero_grad() |
| 33 | + loss.backward() |
| 34 | + optimizer.step() |
| 35 | +``` |
0 commit comments