 from torchmetrics.utilities.data import select_topk


+class _PrecisonRecallWrapperMetric(Metric):
+    """Encapsulate common functions of RPrecision, PrecisionAtK, and RecallAtK.
+
+    Args:
+        top_k (int): the top k relevant labels to evaluate.
+    """
+
+    # If the metric state of one batch is independent of the state of other batches,
+    # full_state_update can be set to False,
+    # which leads to more efficient computation with calling update() only once.
+    # Please find the detailed explanation here:
+    # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html
+    full_state_update = False
+
+    def __init__(self, top_k):
+        super().__init__()
+        self.top_k = top_k
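+        # These states accumulate across update() calls; dist_reduce_fx="sum"
+        # makes torchmetrics sum the per-process values when states are
+        # synchronized in distributed runs.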
+        self.add_state("score", default=torch.tensor(0.0, dtype=torch.double), dist_reduce_fx="sum")
+        self.add_state("num_sample", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def compute(self):
+        return self.score / self.num_sample
+
+    def _get_num_relevant(self, preds, target):
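+        # For example, with top_k = 2, preds = [[0.9, 0.8, 0.1]] and
+        # target = [[1, 0, 1]], select_topk binarizes the two highest scores
+        # to [[1, 1, 0]], so the intersection with target gives
+        # num_relevant = [1].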
+        assert preds.shape == target.shape
+        binary_topk_preds = select_topk(preds, self.top_k)
+        target = target.to(dtype=torch.int)
+        num_relevant = torch.sum(binary_topk_preds & target, dim=-1)
+        return num_relevant
+
+
 class Loss(Metric):
     """Loss records the batch-wise losses
     and then obtains a mean loss from the recorded losses.
@@ -35,7 +66,57 @@ def compute(self):
         return self.loss / self.num_sample


-class NDCG(Metric):
+class MacroF1(Metric):
+    """The macro-f1 score computes the average f1 scores of all labels in the dataset.
+
+    Args:
+        num_classes (int): Total number of classes.
+        metric_threshold (float): The decision value threshold over which a label is predicted as positive.
+        another_macro_f1 (bool, optional): Whether to compute the 'Another-Macro-F1' score.
+            The 'Another-Macro-F1' is the f1 value of macro-precision and macro-recall.
+            This variant of macro-f1 is less preferred but is used in some works.
+            Please refer to Opitz et al. 2019 [https://arxiv.org/pdf/1911.03347.pdf].
+            Defaults to False.
+    """
+
+    # If the metric state of one batch is independent of the state of other batches,
+    # full_state_update can be set to False,
+    # which leads to more efficient computation with calling update() only once.
+    # Please find the detailed explanation here:
+    # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html
+    full_state_update = False
+
+    def __init__(self, num_classes, metric_threshold, another_macro_f1=False, top_k=None):
+        super().__init__()
+        self.metric_threshold = metric_threshold
+        self.another_macro_f1 = another_macro_f1
+        self.top_k = top_k
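+        # Per-label running counts accumulated over all batches: predicted
+        # positives, ground-truth positives, and true positives per class.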
+        self.add_state("preds_sum", default=torch.zeros(num_classes, dtype=torch.double))
+        self.add_state("target_sum", default=torch.zeros(num_classes, dtype=torch.double))
+        self.add_state("tp_sum", default=torch.zeros(num_classes, dtype=torch.double))
+
+    def update(self, preds, target):
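+        # Binarize the decision values: keep the top_k highest scores per
+        # sample when top_k is set, otherwise threshold them, e.g. with
+        # metric_threshold = 0.5, preds = [[0.7, 0.4]] becomes [[1, 0]].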
+        assert preds.shape == target.shape
+        if self.top_k:
+            preds = select_topk(preds, self.top_k)
+        else:
+            preds = torch.where(preds > self.metric_threshold, 1, 0)
+
+        self.preds_sum = torch.add(self.preds_sum, preds.sum(dim=0))
+        self.target_sum = torch.add(self.target_sum, target.sum(dim=0))
+        self.tp_sum = torch.add(self.tp_sum, (preds & target).sum(dim=0))
+
+    def compute(self):
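+        # Worked example with two labels whose (tp, predicted, true) counts
+        # are (1, 2, 1) and (1, 1, 2): per-label f1 is [2/3, 2/3], so the
+        # standard macro-f1 is 2/3, while Another-Macro-F1 takes
+        # macro-precision (0.5 + 1.0) / 2 = 0.75 and macro-recall
+        # (1.0 + 0.5) / 2 = 0.75, giving an f1 of 0.75.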
+        if self.another_macro_f1:
+            macro_prec = torch.mean(torch.nan_to_num(self.tp_sum / self.preds_sum, posinf=0.0))
+            macro_recall = torch.mean(torch.nan_to_num(self.tp_sum / self.target_sum, posinf=0.0))
+            return 2 * (macro_prec * macro_recall) / (macro_prec + macro_recall + 1e-10)
+        else:
+            label_f1 = 2 * self.tp_sum / (self.preds_sum + self.target_sum + 1e-10)
+            return torch.mean(label_f1)
+
+
+class NDCGAtK(Metric):
     """NDCG (Normalized Discounted Cumulative Gain) sums the true scores
     ranked in the order induced by the predicted scores after applying a logarithmic discount,
     and then divides by the best possible score (Ideal DCG, obtained for a perfect ranking)
@@ -98,50 +179,6 @@ def _idcg(self, target, discount):
         return cum_discount[idx]


-class _PrecisonRecallWrapperMetric(Metric):
-    """Encapsulate common functions of RPrecision, PrecisionAtK, and RecallAtK.
-
-    Args:
-        top_k (int): the top k relevant labels to evaluate.
-    """
-
-    # If the metric state of one batch is independent of the state of other batches,
-    # full_state_update can be set to False,
-    # which leads to more efficient computation with calling update() only once.
-    # Please find the detailed explanation here:
-    # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html
-    full_state_update = False
-
-    def __init__(self, top_k):
-        super().__init__()
-        self.top_k = top_k
-        self.add_state("score", default=torch.tensor(0.0, dtype=torch.double), dist_reduce_fx="sum")
-        self.add_state("num_sample", default=torch.tensor(0), dist_reduce_fx="sum")
-
-    def compute(self):
-        return self.score / self.num_sample
-
-    def _get_num_relevant(self, preds, target):
-        assert preds.shape == target.shape
-        binary_topk_preds = select_topk(preds, self.top_k)
-        target = target.to(dtype=torch.int)
-        num_relevant = torch.sum(binary_topk_preds & target, dim=-1)
-        return num_relevant
-
-
-class RPrecision(_PrecisonRecallWrapperMetric):
-    """R-precision calculates precision at k by adjusting k to the minimum value of the number of
-    relevant labels and k. The definition is given at Appendix C equation (3) of
-    https://aclanthology.org/P19-1636.pdf
-    """
-
-    def update(self, preds, target):
-        num_relevant = super()._get_num_relevant(preds, target)
-        top_ks = torch.tensor([self.top_k] * preds.shape[0]).to(preds.device)
-        self.score += torch.nan_to_num(num_relevant / torch.min(top_ks, target.sum(dim=-1)), posinf=0.0).sum()
-        self.num_sample += len(preds)
-
-
 class PrecisionAtK(_PrecisonRecallWrapperMetric):
     """Precision at k. Please refer to the `implementation document`
     (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
@@ -164,54 +201,17 @@ def update(self, preds, target):
         self.num_sample += len(preds)


-class MacroF1(Metric):
-    """The macro-f1 score computes the average f1 scores of all labels in the dataset.
-
-    Args:
-        num_classes (int): Total number of classes.
-        metric_threshold (float): The decision value threshold over which a label is predicted as positive.
-        another_macro_f1 (bool, optional): Whether to compute the 'Another-Macro-F1' score.
-            The 'Another-Macro-F1' is the f1 value of macro-precision and macro-recall.
-            This variant of macro-f1 is less preferred but is used in some works.
-            Please refer to Opitz et al. 2019 [https://arxiv.org/pdf/1911.03347.pdf].
-            Defaults to False.
+class RPrecisionAtK(_PrecisonRecallWrapperMetric):
+    """R-precision calculates precision at k by adjusting k to the minimum value of the number of
+    relevant labels and k. The definition is given at Appendix C equation (3) of
+    https://aclanthology.org/P19-1636.pdf
     """

-    # If the metric state of one batch is independent of the state of other batches,
-    # full_state_update can be set to False,
-    # which leads to more efficient computation with calling update() only once.
-    # Please find the detailed explanation here:
-    # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html
-    full_state_update = False
-
-    def __init__(self, num_classes, metric_threshold, another_macro_f1=False, top_k=None):
-        super().__init__()
-        self.metric_threshold = metric_threshold
-        self.another_macro_f1 = another_macro_f1
-        self.top_k = top_k
-        self.add_state("preds_sum", default=torch.zeros(num_classes, dtype=torch.double))
-        self.add_state("target_sum", default=torch.zeros(num_classes, dtype=torch.double))
-        self.add_state("tp_sum", default=torch.zeros(num_classes, dtype=torch.double))
-
     def update(self, preds, target):
-        assert preds.shape == target.shape
-        if self.top_k:
-            preds = select_topk(preds, self.top_k)
-        else:
-            preds = torch.where(preds > self.metric_threshold, 1, 0)
-
-        self.preds_sum = torch.add(self.preds_sum, preds.sum(dim=0))
-        self.target_sum = torch.add(self.target_sum, target.sum(dim=0))
-        self.tp_sum = torch.add(self.tp_sum, (preds & target).sum(dim=0))
-
-    def compute(self):
-        if self.another_macro_f1:
-            macro_prec = torch.mean(torch.nan_to_num(self.tp_sum / self.preds_sum, posinf=0.0))
-            macro_recall = torch.mean(torch.nan_to_num(self.tp_sum / self.target_sum, posinf=0.0))
-            return 2 * (macro_prec * macro_recall) / (macro_prec + macro_recall + 1e-10)
-        else:
-            label_f1 = 2 * self.tp_sum / (self.preds_sum + self.target_sum + 1e-10)
-            return torch.mean(label_f1)
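+        # e.g., with top_k = 5 and a sample that has 3 relevant labels, 2 of
+        # which appear among the top 5 predictions, the sample contributes
+        # 2 / min(5, 3) = 2 / 3 to the score.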
+        num_relevant = super()._get_num_relevant(preds, target)
+        top_ks = torch.tensor([self.top_k] * preds.shape[0]).to(preds.device)
+        self.score += torch.nan_to_num(num_relevant / torch.min(top_ks, target.sum(dim=-1)), posinf=0.0).sum()
+        self.num_sample += len(preds)


 def get_metrics(metric_threshold, monitor_metrics, num_classes, top_k=None):
@@ -257,9 +257,9 @@ def get_metrics(metric_threshold, monitor_metrics, num_classes, top_k=None):
         elif metric_abbr == "R":
             metrics[metric] = RecallAtK(top_k=top_k)
         elif metric_abbr == "RP":
-            metrics[metric] = RPrecision(top_k=top_k)
+            metrics[metric] = RPrecisionAtK(top_k=top_k)
         elif metric_abbr == "nDCG":
-            metrics[metric] = NDCG(top_k=top_k)
+            metrics[metric] = NDCGAtK(top_k=top_k)
             # The implementation in torchmetrics stores the prediction/target of all batches,
             # which can lead to CUDA out of memory.
             # metrics[metric] = RetrievalNormalizedDCG(k=top_k)
@@ -278,7 +278,6 @@ def get_metrics(metric_threshold, monitor_metrics, num_classes, top_k=None):
                 threshold=metric_threshold,
                 num_labels=num_classes,
                 average=average_type,
-                top_k=top_k,
             )
         else:
             raise ValueError(