Skip to content

Commit 69b384d

Browse files
author
Stefan
committed
Initial commit
1 parent 6b93e66 commit 69b384d

File tree

5 files changed

+725
-0
lines changed

5 files changed

+725
-0
lines changed

losses.py

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
2+
from __future__ import print_function, division
3+
4+
import torch
5+
from torch.autograd import Variable
6+
import torch.nn.functional as F
7+
import numpy as np
8+
try:
9+
from itertools import ifilterfalse
10+
except ImportError: # py3k
11+
from itertools import filterfalse as ifilterfalse
12+
13+
def dice_loss(pred, target):
14+
"""This definition generalize to real valued pred and target vector.
15+
This should be differentiable.
16+
pred: tensor with first dimension as batch
17+
target: tensor with first dimension as batch
18+
"""
19+
20+
smooth = 1.
21+
22+
# have to use contiguous since they may from a torch.view op
23+
iflat = pred.contiguous().view(-1)
24+
tflat = target.contiguous().view(-1)
25+
intersection = (iflat * tflat).sum()
26+
27+
A_sum = torch.sum(tflat * iflat)
28+
B_sum = torch.sum(tflat * tflat)
29+
30+
return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth) )
31+
32+
def dice_loss2(input,target):
33+
input = torch.sigmoid(input)
34+
35+
smooth = 1.
36+
37+
iflat = input.view(-1)
38+
tflat = target.view(-1)
39+
intersection = (iflat * tflat).sum()
40+
41+
return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth))
42+
43+
"""
44+
Lovasz-Softmax and Jaccard hinge loss in PyTorch
45+
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
46+
"""
47+
48+
def lovasz_grad(gt_sorted):
49+
"""
50+
Computes gradient of the Lovasz extension w.r.t sorted errors
51+
See Alg. 1 in paper
52+
"""
53+
p = len(gt_sorted)
54+
gts = gt_sorted.sum()
55+
intersection = gts - gt_sorted.float().cumsum(0)
56+
union = gts + (1 - gt_sorted).float().cumsum(0)
57+
jaccard = 1. - intersection / union
58+
if p > 1: # cover 1-pixel case
59+
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
60+
return jaccard
61+
62+
63+
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
64+
"""
65+
IoU for foreground class
66+
binary: 1 foreground, 0 background
67+
"""
68+
if not per_image:
69+
preds, labels = (preds,), (labels,)
70+
ious = []
71+
for pred, label in zip(preds, labels):
72+
intersection = ((label == 1) & (pred == 1)).sum()
73+
union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
74+
if not union:
75+
iou = EMPTY
76+
else:
77+
iou = float(intersection) / union
78+
ious.append(iou)
79+
iou = mean(ious) # mean accross images if per_image
80+
return 100 * iou
81+
82+
83+
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
84+
"""
85+
Array of IoU for each (non ignored) class
86+
"""
87+
if not per_image:
88+
preds, labels = (preds,), (labels,)
89+
ious = []
90+
for pred, label in zip(preds, labels):
91+
iou = []
92+
for i in range(C):
93+
if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
94+
intersection = ((label == i) & (pred == i)).sum()
95+
union = ((label == i) | ((pred == i) & (label != ignore))).sum()
96+
if not union:
97+
iou.append(EMPTY)
98+
else:
99+
iou.append(float(intersection) / union)
100+
ious.append(iou)
101+
ious = map(mean, zip(*ious)) # mean accross images if per_image
102+
return 100 * np.array(ious)
103+
104+
105+
# --------------------------- BINARY LOSSES ---------------------------
106+
107+
108+
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
109+
"""
110+
Binary Lovasz hinge loss
111+
logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
112+
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
113+
per_image: compute the loss per image instead of per batch
114+
ignore: void class id
115+
"""
116+
if per_image:
117+
loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
118+
for log, lab in zip(logits, labels))
119+
else:
120+
loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
121+
return loss
122+
123+
124+
def lovasz_hinge_flat(logits, labels):
125+
"""
126+
Binary Lovasz hinge loss
127+
logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
128+
labels: [P] Tensor, binary ground truth labels (0 or 1)
129+
ignore: label to ignore
130+
"""
131+
if len(labels) == 0:
132+
# only void pixels, the gradients should be 0
133+
return logits.sum() * 0.
134+
signs = 2. * labels.float() - 1.
135+
errors = (1. - logits * Variable(signs))
136+
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
137+
perm = perm.data
138+
gt_sorted = labels[perm]
139+
grad = lovasz_grad(gt_sorted)
140+
loss = torch.dot(F.relu(errors_sorted), Variable(grad))
141+
return loss
142+
143+
144+
def flatten_binary_scores(scores, labels, ignore=None):
145+
"""
146+
Flattens predictions in the batch (binary case)
147+
Remove labels equal to 'ignore'
148+
"""
149+
scores = scores.view(-1)
150+
labels = labels.view(-1)
151+
if ignore is None:
152+
return scores, labels
153+
valid = (labels != ignore)
154+
vscores = scores[valid]
155+
vlabels = labels[valid]
156+
return vscores, vlabels
157+
158+
159+
class StableBCELoss(torch.nn.modules.Module):
160+
def __init__(self):
161+
super(StableBCELoss, self).__init__()
162+
def forward(self, input, target):
163+
neg_abs = - input.abs()
164+
loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
165+
return loss.mean()
166+
167+
168+
def binary_xloss(logits, labels, ignore=None):
169+
"""
170+
Binary Cross entropy loss
171+
logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
172+
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
173+
ignore: void class id
174+
"""
175+
logits, labels = flatten_binary_scores(logits, labels, ignore)
176+
loss = StableBCELoss()(logits, Variable(labels.float()))
177+
return loss
178+
179+
180+
# --------------------------- MULTICLASS LOSSES ---------------------------
181+
182+
183+
def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None):
184+
"""
185+
Multi-class Lovasz-Softmax loss
186+
probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1)
187+
labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
188+
only_present: average only on classes present in ground truth
189+
per_image: compute the loss per image instead of per batch
190+
ignore: void class labels
191+
"""
192+
if per_image:
193+
loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present)
194+
for prob, lab in zip(probas, labels))
195+
else:
196+
loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), only_present=only_present)
197+
return loss
198+
199+
200+
def lovasz_softmax_flat(probas, labels, only_present=False):
201+
"""
202+
Multi-class Lovasz-Softmax loss
203+
probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
204+
labels: [P] Tensor, ground truth labels (between 0 and C - 1)
205+
only_present: average only on classes present in ground truth
206+
"""
207+
if probas.numel() == 0:
208+
# only void pixels, the gradients should be 0
209+
return probas * 0.
210+
C = probas.size(1)
211+
212+
C = probas.size(1)
213+
losses = []
214+
for c in range(C):
215+
fg = (labels == c).float() # foreground for class c
216+
if only_present and fg.sum() == 0:
217+
continue
218+
errors = (Variable(fg) - probas[:, c]).abs()
219+
errors_sorted, perm = torch.sort(errors, 0, descending=True)
220+
perm = perm.data
221+
fg_sorted = fg[perm]
222+
losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
223+
return mean(losses)
224+
225+
226+
def flatten_probas(probas, labels, ignore=None):
227+
"""
228+
Flattens predictions in the batch
229+
"""
230+
B, C, H, W = probas.size()
231+
probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
232+
labels = labels.view(-1)
233+
if ignore is None:
234+
return probas, labels
235+
valid = (labels != ignore)
236+
vprobas = probas[valid.nonzero().squeeze()]
237+
vlabels = labels[valid]
238+
return vprobas, vlabels
239+
240+
def xloss(logits, labels, ignore=None):
241+
"""
242+
Cross entropy loss
243+
"""
244+
return F.cross_entropy(logits, Variable(labels), ignore_index=255)
245+
246+
247+
# --------------------------- HELPER FUNCTIONS ---------------------------
248+
def isnan(x):
249+
return x != x
250+
251+
252+
def mean(l, ignore_nan=True, empty=0):
253+
"""
254+
nanmean compatible with generators.
255+
"""
256+
l = iter(l)
257+
if ignore_nan:
258+
l = ifilterfalse(isnan, l)
259+
try:
260+
n = 1
261+
acc = next(l)
262+
except StopIteration:
263+
if empty == 'raise':
264+
raise ValueError('Empty mean')
265+
return empty
266+
for n, v in enumerate(l, 2):
267+
acc += v
268+
if n == 1:
269+
return acc
270+
return acc / n

0 commit comments

Comments
 (0)