
Commit 4dfa923 by Stefan Knegt: Update README.md (1 parent 69b384d)

1 file changed: README.md (+35 additions, −2 deletions)
# Probabilistic UNet in PyTorch

A Probabilistic U-Net for segmentation of ambiguous images, implemented in PyTorch. This is a PyTorch implementation of the paper https://arxiv.org/abs/1806.05034, for which the original code can be found here: https://github.com/SimonKohl/probabilistic_unet.

## Adding KL divergence for the Independent distribution

To implement a Gaussian distribution with an axis-aligned covariance matrix in PyTorch, I wrap a Normal distribution in an Independent distribution. For the KL divergence between two such distributions to work, you need to add the following to the PyTorch source code at torch/distributions/kl.py (source: https://github.com/pytorch/pytorch/issues/13545).

```python
@register_kl(Independent, Independent)  # register so kl_divergence dispatches here
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    # sum the per-dimension KLs over the reinterpreted (event) dimensions
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)
```
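As a sanity check on this construction: the KL divergence between two axis-aligned Gaussians factorises into a sum of per-dimension univariate Gaussian KLs, which is exactly what summing the base Normal KL over the rightmost (event) dimensions computes. A minimal pure-Python sketch of that closed form (no PyTorch; `kl_diag_gaussian` is a hypothetical helper, not part of this repo):

```python
import math

def kl_diag_gaussian(mu_p, sigma_p, mu_q, sigma_q):
    """KL(N(mu_p, diag(sigma_p^2)) || N(mu_q, diag(sigma_q^2))),
    computed as the sum of per-dimension univariate Gaussian KLs."""
    total = 0.0
    for mp, sp, mq, sq in zip(mu_p, sigma_p, mu_q, sigma_q):
        total += math.log(sq / sp) + (sp**2 + (mp - mq)**2) / (2 * sq**2) - 0.5
    return total

# KL of a distribution with itself is zero:
print(kl_diag_gaussian([0.0, 1.0], [1.0, 2.0], [0.0, 1.0], [1.0, 2.0]))  # -> 0.0
```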

## Training

To train your own Probabilistic UNet in PyTorch, first write your own data loader. Then you can use the following snippet to train the network:

```python
train_loader = ...  # define this yourself
net = ProbabilisticUnet(no_channels, no_classes, filter_list, latent_dim, no_fcomb_convs, beta)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
for epoch in range(epochs):
    for step, (patch, mask) in enumerate(train_loader):
        patch = patch.to(device)
        mask = mask.to(device)
        mask = torch.unsqueeze(mask, 1)  # add a channel dimension to the mask
        net.forward(patch, mask, training=True)
        elbo = net.elbo(mask)
        reg_loss = l2_regularisation(net.posterior) + l2_regularisation(net.prior) + l2_regularisation(net.fcomb)
        loss = -elbo + 1e-5 * reg_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
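The loop above assumes `train_loader` yields `(patch, mask)` batches. A minimal stand-in built on `torch.utils.data` could look like this (the `ToySegmentationDataset` name and all shapes are illustrative, not part of this repo; replace the random tensors with your real data):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToySegmentationDataset(Dataset):
    """Illustrative dataset yielding (patch, mask) pairs."""
    def __init__(self, n_samples=8, size=32):
        self.patches = torch.randn(n_samples, 1, size, size)               # 1-channel patches
        self.masks = torch.randint(0, 2, (n_samples, size, size)).float()  # binary masks

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        return self.patches[idx], self.masks[idx]

train_loader = DataLoader(ToySegmentationDataset(), batch_size=4, shuffle=True)
patch, mask = next(iter(train_loader))
# patch has shape (4, 1, 32, 32); mask has shape (4, 32, 32),
# which is why the training loop unsqueezes a channel dimension into the mask.
```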
