@@ -19,7 +19,8 @@ def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=3e-5):
1919 model_name , num_labels = 2
2020 )
2121 self .num_classes = 2
22- self .accuracy_metric = torchmetrics .Accuracy ()
22+ self .train_accuracy_metric = torchmetrics .Accuracy ()
23+ self .val_accuracy_metric = torchmetrics .Accuracy ()
2324 self .f1_metric = torchmetrics .F1 (num_classes = self .num_classes )
2425 self .precision_macro_metric = torchmetrics .Precision (
2526 average = "macro" , num_classes = self .num_classes
@@ -42,7 +43,7 @@ def training_step(self, batch, batch_idx):
4243 )
4344 # loss = F.cross_entropy(logits, batch["label"])
4445 preds = torch .argmax (outputs .logits , 1 )
45- train_acc = self .accuracy_metric (preds , batch ["label" ])
46+ train_acc = self .train_accuracy_metric (preds , batch ["label" ])
4647 self .log ("train/loss" , outputs .loss , prog_bar = True , on_epoch = True )
4748 self .log ("train/acc" , train_acc , prog_bar = True , on_epoch = True )
4849 return outputs .loss
@@ -55,7 +56,7 @@ def validation_step(self, batch, batch_idx):
5556 preds = torch .argmax (outputs .logits , 1 )
5657
5758 # Metrics
58- valid_acc = self .accuracy_metric (preds , labels )
59+ valid_acc = self .val_accuracy_metric (preds , labels )
5960 precision_macro = self .precision_macro_metric (preds , labels )
6061 recall_macro = self .recall_macro_metric (preds , labels )
6162 precision_micro = self .precision_micro_metric (preds , labels )
0 commit comments