Commit 2f9f6bd

(1) Add AtK like linear. (2) Update order
1 parent 86be2c9 commit 2f9f6bd

File tree

1 file changed: +92 −93 lines


libmultilabel/nn/metrics.py

Lines changed: 92 additions & 93 deletions
@@ -9,6 +9,37 @@
 from torchmetrics.utilities.data import select_topk
 
 
+class _PrecisonRecallWrapperMetric(Metric):
+    """Encapsulate common functions of RPrecision, PrecisionAtK, and RecallAtK.
+
+    Args:
+        top_k (int): the top k relevant labels to evaluate.
+    """
+
+    # If the metric state of one batch is independent of the state of other batches,
+    # full_state_update can be set to False,
+    # which leads to more efficient computation with calling update() only once.
+    # Please find the detailed explanation here:
+    # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html
+    full_state_update = False
+
+    def __init__(self, top_k):
+        super().__init__()
+        self.top_k = top_k
+        self.add_state("score", default=torch.tensor(0.0, dtype=torch.double), dist_reduce_fx="sum")
+        self.add_state("num_sample", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def compute(self):
+        return self.score / self.num_sample
+
+    def _get_num_relevant(self, preds, target):
+        assert preds.shape == target.shape
+        binary_topk_preds = select_topk(preds, self.top_k)
+        target = target.to(dtype=torch.int)
+        num_relevant = torch.sum(binary_topk_preds & target, dim=-1)
+        return num_relevant
+
+
 class Loss(Metric):
     """Loss records the batch-wise losses
     and then obtains a mean loss from the recorded losses.
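
For orientation, here is a minimal sketch (not part of the commit, illustrative values) of what the shared helper `_get_num_relevant` computes: per sample, the number of truly relevant labels among the top_k highest-scoring predictions, using torchmetrics' select_topk.

    import torch
    from torchmetrics.utilities.data import select_topk

    # Two samples, four labels: decision values and binary ground truth.
    preds = torch.tensor([[0.9, 0.1, 0.8, 0.3],
                          [0.2, 0.7, 0.4, 0.6]])
    target = torch.tensor([[1, 0, 1, 0],
                           [0, 1, 0, 0]], dtype=torch.int)

    # select_topk marks the k highest-scoring labels of each row with 1.
    binary_topk_preds = select_topk(preds, 2)  # [[1, 0, 1, 0], [0, 1, 0, 1]]

    # Per-sample overlap between the top-k set and the relevant labels,
    # which is exactly what _get_num_relevant returns.
    num_relevant = torch.sum(binary_topk_preds & target, dim=-1)  # tensor([2, 1])
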
@@ -35,7 +66,57 @@ def compute(self):
         return self.loss / self.num_sample
 
 
-class NDCG(Metric):
+class MacroF1(Metric):
+    """The macro-f1 score computes the average f1 scores of all labels in the dataset.
+
+    Args:
+        num_classes (int): Total number of classes.
+        metric_threshold (float): The decision value threshold over which a label is predicted as positive.
+        another_macro_f1 (bool, optional): Whether to compute the 'Another-Macro-F1' score.
+            The 'Another-Macro-F1' is the f1 value of macro-precision and macro-recall.
+            This variant of macro-f1 is less preferred but is used in some works.
+            Please refer to Opitz et al. 2019 [https://arxiv.org/pdf/1911.03347.pdf].
+            Defaults to False.
+    """
+
+    # If the metric state of one batch is independent of the state of other batches,
+    # full_state_update can be set to False,
+    # which leads to more efficient computation with calling update() only once.
+    # Please find the detailed explanation here:
+    # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html
+    full_state_update = False
+
+    def __init__(self, num_classes, metric_threshold, another_macro_f1=False, top_k=None):
+        super().__init__()
+        self.metric_threshold = metric_threshold
+        self.another_macro_f1 = another_macro_f1
+        self.top_k = top_k
+        self.add_state("preds_sum", default=torch.zeros(num_classes, dtype=torch.double))
+        self.add_state("target_sum", default=torch.zeros(num_classes, dtype=torch.double))
+        self.add_state("tp_sum", default=torch.zeros(num_classes, dtype=torch.double))
+
+    def update(self, preds, target):
+        assert preds.shape == target.shape
+        if self.top_k:
+            preds = select_topk(preds, self.top_k)
+        else:
+            preds = torch.where(preds > self.metric_threshold, 1, 0)
+
+        self.preds_sum = torch.add(self.preds_sum, preds.sum(dim=0))
+        self.target_sum = torch.add(self.target_sum, target.sum(dim=0))
+        self.tp_sum = torch.add(self.tp_sum, (preds & target).sum(dim=0))
+
+    def compute(self):
+        if self.another_macro_f1:
+            macro_prec = torch.mean(torch.nan_to_num(self.tp_sum / self.preds_sum, posinf=0.0))
+            macro_recall = torch.mean(torch.nan_to_num(self.tp_sum / self.target_sum, posinf=0.0))
+            return 2 * (macro_prec * macro_recall) / (macro_prec + macro_recall + 1e-10)
+        else:
+            label_f1 = 2 * self.tp_sum / (self.preds_sum + self.target_sum + 1e-10)
+            return torch.mean(label_f1)
+
+
+class NDCGAtK(Metric):
     """NDCG (Normalized Discounted Cumulative Gain) sums the true scores
     ranked in the order induced by the predicted scores after applying a logarithmic discount,
     and then divides by the best possible score (Ideal DCG, obtained for a perfect ranking)
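
A minimal usage sketch of the relocated MacroF1 (imported from the file this diff touches; the tensor values are illustrative):

    import torch
    from libmultilabel.nn.metrics import MacroF1

    metric = MacroF1(num_classes=3, metric_threshold=0.5)
    preds = torch.tensor([[0.9, 0.2, 0.7],
                          [0.4, 0.8, 0.1]])
    target = torch.tensor([[1, 0, 1],
                           [0, 1, 1]])
    metric.update(preds, target)
    # Per-label F1 averaged over labels: (1.0 + 1.0 + 2/3) / 3 ~= 0.889.
    print(metric.compute())
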
@@ -98,50 +179,6 @@ def _idcg(self, target, discount):
         return cum_discount[idx]
 
 
-class _PrecisonRecallWrapperMetric(Metric):
-    """Encapsulate common functions of RPrecision, PrecisionAtK, and RecallAtK.
-
-    Args:
-        top_k (int): the top k relevant labels to evaluate.
-    """
-
-    # If the metric state of one batch is independent of the state of other batches,
-    # full_state_update can be set to False,
-    # which leads to more efficient computation with calling update() only once.
-    # Please find the detailed explanation here:
-    # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html
-    full_state_update = False
-
-    def __init__(self, top_k):
-        super().__init__()
-        self.top_k = top_k
-        self.add_state("score", default=torch.tensor(0.0, dtype=torch.double), dist_reduce_fx="sum")
-        self.add_state("num_sample", default=torch.tensor(0), dist_reduce_fx="sum")
-
-    def compute(self):
-        return self.score / self.num_sample
-
-    def _get_num_relevant(self, preds, target):
-        assert preds.shape == target.shape
-        binary_topk_preds = select_topk(preds, self.top_k)
-        target = target.to(dtype=torch.int)
-        num_relevant = torch.sum(binary_topk_preds & target, dim=-1)
-        return num_relevant
-
-
-class RPrecision(_PrecisonRecallWrapperMetric):
-    """R-precision calculates precision at k by adjusting k to the minimum value of the number of
-    relevant labels and k. The definition is given at Appendix C equation (3) of
-    https://aclanthology.org/P19-1636.pdf
-    """
-
-    def update(self, preds, target):
-        num_relevant = super()._get_num_relevant(preds, target)
-        top_ks = torch.tensor([self.top_k] * preds.shape[0]).to(preds.device)
-        self.score += torch.nan_to_num(num_relevant / torch.min(top_ks, target.sum(dim=-1)), posinf=0.0).sum()
-        self.num_sample += len(preds)
-
-
 class PrecisionAtK(_PrecisonRecallWrapperMetric):
     """Precision at k. Please refer to the `implementation document`
     (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
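
PrecisionAtK's update body lies outside this hunk; assuming the standard definition its docstring cites (|top-k predictions ∩ relevant labels| / k, averaged over samples), a hypothetical check could look like:

    import torch
    from libmultilabel.nn.metrics import PrecisionAtK

    metric = PrecisionAtK(top_k=2)
    preds = torch.tensor([[0.9, 0.1, 0.8, 0.3]])
    target = torch.tensor([[1, 0, 0, 1]])
    metric.update(preds, target)
    # One of the two top-ranked labels (index 0) is relevant: P@2 = 1/2.
    print(metric.compute())
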
@@ -164,54 +201,17 @@ def update(self, preds, target):
         self.num_sample += len(preds)
 
 
-class MacroF1(Metric):
-    """The macro-f1 score computes the average f1 scores of all labels in the dataset.
-
-    Args:
-        num_classes (int): Total number of classes.
-        metric_threshold (float): The decision value threshold over which a label is predicted as positive.
-        another_macro_f1 (bool, optional): Whether to compute the 'Another-Macro-F1' score.
-            The 'Another-Macro-F1' is the f1 value of macro-precision and macro-recall.
-            This variant of macro-f1 is less preferred but is used in some works.
-            Please refer to Opitz et al. 2019 [https://arxiv.org/pdf/1911.03347.pdf].
-            Defaults to False.
+class RPrecisionAtK(_PrecisonRecallWrapperMetric):
+    """R-precision calculates precision at k by adjusting k to the minimum value of the number of
+    relevant labels and k. The definition is given at Appendix C equation (3) of
+    https://aclanthology.org/P19-1636.pdf
     """
 
-    # If the metric state of one batch is independent of the state of other batches,
-    # full_state_update can be set to False,
-    # which leads to more efficient computation with calling update() only once.
-    # Please find the detailed explanation here:
-    # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html
-    full_state_update = False
-
-    def __init__(self, num_classes, metric_threshold, another_macro_f1=False, top_k=None):
-        super().__init__()
-        self.metric_threshold = metric_threshold
-        self.another_macro_f1 = another_macro_f1
-        self.top_k = top_k
-        self.add_state("preds_sum", default=torch.zeros(num_classes, dtype=torch.double))
-        self.add_state("target_sum", default=torch.zeros(num_classes, dtype=torch.double))
-        self.add_state("tp_sum", default=torch.zeros(num_classes, dtype=torch.double))
-
     def update(self, preds, target):
-        assert preds.shape == target.shape
-        if self.top_k:
-            preds = select_topk(preds, self.top_k)
-        else:
-            preds = torch.where(preds > self.metric_threshold, 1, 0)
-
-        self.preds_sum = torch.add(self.preds_sum, preds.sum(dim=0))
-        self.target_sum = torch.add(self.target_sum, target.sum(dim=0))
-        self.tp_sum = torch.add(self.tp_sum, (preds & target).sum(dim=0))
-
-    def compute(self):
-        if self.another_macro_f1:
-            macro_prec = torch.mean(torch.nan_to_num(self.tp_sum / self.preds_sum, posinf=0.0))
-            macro_recall = torch.mean(torch.nan_to_num(self.tp_sum / self.target_sum, posinf=0.0))
-            return 2 * (macro_prec * macro_recall) / (macro_prec + macro_recall + 1e-10)
-        else:
-            label_f1 = 2 * self.tp_sum / (self.preds_sum + self.target_sum + 1e-10)
-            return torch.mean(label_f1)
+        num_relevant = super()._get_num_relevant(preds, target)
+        top_ks = torch.tensor([self.top_k] * preds.shape[0]).to(preds.device)
+        self.score += torch.nan_to_num(num_relevant / torch.min(top_ks, target.sum(dim=-1)), posinf=0.0).sum()
+        self.num_sample += len(preds)
 
 
 def get_metrics(metric_threshold, monitor_metrics, num_classes, top_k=None):
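
The renamed RPrecisionAtK keeps RPrecision's behavior: the denominator is min(top_k, number of relevant labels), so a sample with fewer than k relevant labels can still score 1.0. A small sketch with illustrative values:

    import torch
    from libmultilabel.nn.metrics import RPrecisionAtK

    metric = RPrecisionAtK(top_k=5)
    preds = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.5, 0.1]])
    target = torch.tensor([[1, 0, 1, 0, 0, 0]])
    metric.update(preds, target)
    # Both relevant labels land in the top 5 and min(5, 2) = 2, so RP@5 = 2/2 = 1.0.
    print(metric.compute())
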
@@ -257,9 +257,9 @@ def get_metrics(metric_threshold, monitor_metrics, num_classes, top_k=None):
         elif metric_abbr == "R":
             metrics[metric] = RecallAtK(top_k=top_k)
         elif metric_abbr == "RP":
-            metrics[metric] = RPrecision(top_k=top_k)
+            metrics[metric] = RPrecisionAtK(top_k=top_k)
         elif metric_abbr == "nDCG":
-            metrics[metric] = NDCG(top_k=top_k)
+            metrics[metric] = NDCGAtK(top_k=top_k)
         # The implementation in torchmetrics stores the prediction/target of all batches,
         # which can lead to CUDA out of memory.
         # metrics[metric] = RetrievalNormalizedDCG(k=top_k)
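
After the rename, the RP and nDCG branches construct the AtK classes. Assuming monitor metric names of the form "<abbr>@<k>" (e.g. "RP@10"), as the metric_abbr/top_k pair suggests, a hypothetical call might look like:

    from libmultilabel.nn.metrics import get_metrics

    # Illustrative configuration: "RP@10" and "nDCG@10" would now map to
    # RPrecisionAtK(top_k=10) and NDCGAtK(top_k=10) respectively.
    metrics = get_metrics(
        metric_threshold=0.5,
        monitor_metrics=["P@10", "RP@10", "nDCG@10", "Macro-F1"],
        num_classes=100,
    )
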
@@ -278,7 +278,6 @@ def get_metrics(metric_threshold, monitor_metrics, num_classes, top_k=None):
                 threshold=metric_threshold,
                 num_labels=num_classes,
                 average=average_type,
-                top_k=top_k,
             )
         else:
             raise ValueError(
