1- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
1+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
@@ -68,22 +68,6 @@ def _forward(self):
6868 return head_outs , targets
6969
7070 def get_loss (self ):
71- # batch_gt_class = self.inputs["gt_class"]
72- # batch_gt_box = self.inputs["gt_bbox"]
73- # batch_whwh = self.inputs["img_whwh"]
74- # targets = []
75-
76- # for i in range(len(batch_gt_class)):
77- # boxes = batch_gt_box[i]
78- # labels = batch_gt_class[i].squeeze(-1)
79- # img_whwh = batch_whwh[i]
80- # img_whwh_tgt = img_whwh.unsqueeze(0).tile([int(boxes.shape[0]), 1])
81- # targets.append({
82- # "boxes": boxes,
83- # "labels": labels,
84- # "img_whwh": img_whwh,
85- # "img_whwh_tgt": img_whwh_tgt
86- # })
8771
8872 outputs , targets = self ._forward ()
8973 loss_dict = self .head .get_loss (outputs , targets )
@@ -92,6 +76,7 @@ def get_loss(self):
9276 return loss_dict
9377
9478 def get_pred (self ):
79+
9580 bbox_pred , bbox_num = self ._forward ()
9681 output = {'bbox' : bbox_pred , 'bbox_num' : bbox_num }
9782 return output
0 commit comments