1010
1111class Pifpaf (Model ):
1212 def __init__ (self ,parts ,limbs ,colors = CocoColor ,n_pos = 17 ,n_limbs = 19 ,hin = 368 ,win = 368 ,scale_size = 32 ,backbone = None ,pretraining = False ,quad_size = 2 ,quad_num = 1 ,
13- lambda_pif_conf = 30 .0 ,lambda_pif_vec = 2 .0 ,lambda_pif_scale = 2 .0 ,lambda_paf_conf = 50 .0 ,lambda_paf_src_vec = 3 .0 ,lambda_paf_dst_vec = 3 .0 ,
14- lambda_paf_src_scale = 2 .0 ,lambda_paf_dst_scale = 2 .0 ,data_format = "channels_first" ):
13+ lambda_pif_conf = 1 .0 ,lambda_pif_vec = 1 .0 ,lambda_pif_scale = 1 .0 ,lambda_paf_conf = 1 .0 ,lambda_paf_src_vec = 1 .0 ,lambda_paf_dst_vec = 1 .0 ,
14+ lambda_paf_src_scale = 1 .0 ,lambda_paf_dst_scale = 1 .0 ,data_format = "channels_first" ):
1515 super ().__init__ ()
1616 self .parts = parts
1717 self .limbs = limbs
@@ -65,6 +65,13 @@ def infer(self,x):
6565 paf_conf ,paf_src_vec ,paf_dst_vec ,_ ,_ ,paf_src_scale ,paf_dst_scale = paf_maps
6666 return pif_conf ,pif_vec ,pif_scale ,paf_conf ,paf_src_vec ,paf_dst_vec ,paf_src_scale ,paf_dst_scale
6767
def soft_clamp(self, x, max_value=5.0):
    """Softly cap values of ``x`` that reach ``max_value``.

    Elements with ``x < max_value`` are returned unchanged.  Elements with
    ``x >= max_value`` are replaced by
    ``max_value + log(1 + x - max_value)``, so they keep growing (and keep
    a non-zero gradient) but only logarithmically.

    Args:
        x: tensor of values to clamp.
        max_value: threshold above which the logarithmic branch is used.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    is_above = x >= max_value
    # Zero the excess where the log branch is not selected, so log1p never
    # sees a negative argument (and never yields NaN values or gradients)
    # for elements that end up on the identity branch.
    excess = tf.where(is_above, x - max_value, tf.zeros_like(x))
    # Selecting with tf.where (instead of multiplying by a float32 0/1 mask)
    # keeps the dtype of ``x`` and avoids inf * 0 = NaN for infinite inputs.
    return tf.where(is_above, max_value + tf.math.log1p(excess), x)
74+
6875 def Bce_loss (self ,pd_conf ,gt_conf ,focal_gamma = 1.0 ):
6976 #shape conf:[batch,field,h,w]
7077 batch_size = pd_conf .shape [0 ]
@@ -75,15 +82,18 @@ def Bce_loss(self,pd_conf,gt_conf,focal_gamma=1.0):
7582 gt_conf = gt_conf [valid_mask ]
7683 #calculate loss
7784 bce_loss = tf .nn .sigmoid_cross_entropy_with_logits (logits = pd_conf ,labels = gt_conf )
78- bce_loss = tf . clip_by_value (bce_loss , 0.02 , 5.0 )
85+ bce_loss = self . soft_clamp (bce_loss )
7986 if (focal_gamma != 0.0 ):
80- focal = (1 - tf .exp (- bce_loss ))** focal_gamma
81- focal = tf .stop_gradient (focal )
82- bce_loss = focal * bce_loss
87+ p = tf .nn .sigmoid (pd_conf )
88+ pt = p * gt_conf + (1 - p )* (1 - gt_conf )
89+ focal = 1.0 - pt
90+ if (focal_gamma != 1.0 ):
91+ focal = (focal + 1e-4 )** focal_gamma
92+ bce_loss = focal * bce_loss * 0.5
8393 bce_loss = tf .reduce_sum (bce_loss )/ batch_size
8494 return bce_loss
8595
86- def Laplace_loss (self ,pd_vec ,pd_logb ,gt_vec ):
96+ def Laplace_loss (self ,pd_vec ,pd_logb ,gt_vec , gt_bmin ):
8797 #shape vec: [batch,field,2,h,w]
8898 #shape logb: [batch,field,h,w]
8999 batch_size = pd_vec .shape [0 ]
@@ -98,40 +108,46 @@ def Laplace_loss(self,pd_vec,pd_logb,gt_vec):
98108 gt_vec_x = gt_vec [:,:,0 :1 ,:,:][valid_mask ]
99109 gt_vec_y = gt_vec [:,:,1 :2 ,:,:][valid_mask ]
100110 gt_vec = tf .stack ([gt_vec_x ,gt_vec_y ])
111+ #select gt_bmin
112+ gt_bmin = gt_bmin [:,:,np .newaxis ,:,:][valid_mask ]
101113 #calculate loss
102- norm = tf .norm (pd_vec - gt_vec ,axis = 0 )
103- norm = tf .clip_by_value (norm ,0.0 ,5.0 )
104- pd_logb = tf .clip_by_value (pd_logb ,- 3.0 ,np .inf )
105- laplace_loss = pd_logb + (norm + 0.1 )* tf .exp (- pd_logb )
114+ norm = tf .norm (tf .stack ([pd_vec_x - gt_vec_x ,pd_vec_y - gt_vec_y ,gt_bmin ]),axis = 0 )
115+ pd_logb = 3.0 * tf .tanh (pd_logb / 3.0 )
116+ scaled_norm = norm * tf .exp (- pd_logb )
117+ scaled_norm = self .soft_clamp (scaled_norm )
118+ laplace_loss = pd_logb + scaled_norm
106119 laplace_loss = tf .reduce_sum (laplace_loss )/ batch_size
107120 return laplace_loss
108121
def Scale_loss(self, pd_scale, gt_scale, b=1.0):
    """Relative L1 loss between predicted and ground-truth scales.

    Positions where ``gt_scale`` is NaN are excluded.  The raw prediction
    is passed through softplus before comparison.  The absolute error is
    divided by ``b * 10.0 * (0.1 + gt_scale)``, soft-clamped with
    ``self.soft_clamp``, summed, and divided by ``pd_scale.shape[0]``.

    Args:
        pd_scale: predicted (pre-softplus) scale tensor.
        gt_scale: ground-truth scale tensor; NaN marks entries to ignore.
        b: extra divisor applied to the per-element loss.  The default of
            1.0 leaves the loss unscaled.

    Returns:
        Scalar tensor holding the summed loss divided by the size of the
        first dimension of ``pd_scale``.
    """
    # NOTE(review): first dimension is presumably the batch size — confirm.
    batch_size = pd_scale.shape[0]
    # Keep only positions that carry a ground-truth annotation.
    valid_mask = tf.logical_not(tf.math.is_nan(gt_scale))
    pd_valid = tf.nn.softplus(pd_scale[valid_mask])
    gt_valid = gt_scale[valid_mask]
    abs_error = tf.abs(pd_valid - gt_valid)
    # Relative error: normalise by the ground-truth magnitude; the 0.1
    # offset keeps the denominator away from zero.  ``b`` is honoured here
    # so the parameter is no longer silently ignored.
    denominator = b * 10.0 * (0.1 + gt_valid)
    scale_loss = self.soft_clamp(abs_error / denominator)
    return tf.reduce_sum(scale_loss) / batch_size
118134
119135 def cal_loss (self ,pd_pif_maps ,pd_paf_maps ,gt_pif_maps ,gt_paf_maps ):
120136 #calculate pif losses
121137 pd_pif_conf ,pd_pif_vec ,pd_pif_logb ,pd_pif_scale = pd_pif_maps
122- gt_pif_conf ,gt_pif_vec ,gt_pif_scale = gt_pif_maps
138+ gt_pif_conf ,gt_pif_vec ,gt_pif_bmin , gt_pif_scale = gt_pif_maps
123139 loss_pif_conf = self .Bce_loss (pd_pif_conf ,gt_pif_conf )
124- loss_pif_vec = self .Laplace_loss (pd_pif_vec ,pd_pif_logb ,gt_pif_vec )
140+ loss_pif_vec = self .Laplace_loss (pd_pif_vec ,pd_pif_logb ,gt_pif_vec , gt_pif_bmin )
125141 loss_pif_scale = self .Scale_loss (pd_pif_scale ,gt_pif_scale )
126142 loss_pif_maps = [loss_pif_conf ,loss_pif_vec ,loss_pif_scale ]
127143 #calculate paf losses
128144 pd_paf_conf ,pd_paf_src_vec ,pd_paf_dst_vec ,pd_paf_src_logb ,pd_paf_dst_logb ,pd_paf_src_scale ,pd_paf_dst_scale = pd_paf_maps
129- gt_paf_conf ,gt_paf_src_vec ,gt_paf_dst_vec ,gt_paf_src_scale ,gt_paf_dst_scale = gt_paf_maps
145+ gt_paf_conf ,gt_paf_src_vec ,gt_paf_dst_vec ,gt_paf_src_bmin , gt_paf_dst_bmin , gt_paf_src_scale ,gt_paf_dst_scale = gt_paf_maps
130146 loss_paf_conf = self .Bce_loss (pd_paf_conf ,gt_paf_conf )
131147 loss_paf_src_scale = self .Scale_loss (pd_paf_src_scale ,gt_paf_src_scale )
132148 loss_paf_dst_scale = self .Scale_loss (pd_paf_dst_scale ,gt_paf_dst_scale )
133- loss_paf_src_vec = self .Laplace_loss (pd_paf_src_vec ,pd_paf_src_logb ,gt_paf_src_vec )
134- loss_paf_dst_vec = self .Laplace_loss (pd_paf_dst_vec ,pd_paf_dst_logb ,gt_paf_dst_vec )
149+ loss_paf_src_vec = self .Laplace_loss (pd_paf_src_vec ,pd_paf_src_logb ,gt_paf_src_vec , gt_paf_src_bmin )
150+ loss_paf_dst_vec = self .Laplace_loss (pd_paf_dst_vec ,pd_paf_dst_logb ,gt_paf_dst_vec , gt_paf_dst_bmin )
135151 loss_paf_maps = [loss_paf_conf ,loss_paf_src_vec ,loss_paf_dst_vec ,loss_paf_src_scale ,loss_paf_dst_scale ]
136152 #calculate total loss
137153 total_loss = (loss_pif_conf * self .lambda_pif_conf + loss_pif_vec * self .lambda_pif_vec + loss_pif_scale * self .lambda_pif_scale +
@@ -168,7 +184,7 @@ def forward(self,x,is_train=False):
168184 if (is_train == False ):
169185 infer_pif_conf = tf .nn .sigmoid (pif_conf )
170186 infer_pif_vec = (pif_vec [:,:]+ self .mesh_grid )* self .stride
171- infer_pif_scale = pif_scale * self .stride
187+ infer_pif_scale = tf . math . softplus ( pif_scale ) * self .stride
172188 return infer_pif_conf ,infer_pif_vec ,pif_logb ,infer_pif_scale
173189 return pif_conf ,pif_vec ,pif_logb ,pif_scale
174190
@@ -204,7 +220,7 @@ def forward(self,x,is_train=False):
204220 infer_paf_conf = tf .nn .sigmoid (paf_conf )
205221 infer_paf_src_vec = (paf_src_vec [:,:]+ self .mesh_grid )* self .stride
206222 infer_paf_dst_vec = (paf_dst_vec [:,:]+ self .mesh_grid )* self .stride
207- infer_paf_src_scale = paf_src_scale * self .stride
208- infer_paf_dst_scale = paf_dst_scale * self .stride
223+ infer_paf_src_scale = tf . math . softplus ( paf_src_scale ) * self .stride
224+ infer_paf_dst_scale = tf . math . softplus ( paf_dst_scale ) * self .stride
209225 return infer_paf_conf ,infer_paf_src_vec ,infer_paf_dst_vec ,paf_src_logb ,paf_dst_logb ,infer_paf_src_scale ,infer_paf_dst_scale
210226 return paf_conf ,paf_src_vec ,paf_dst_vec ,paf_src_logb ,paf_dst_logb ,paf_src_scale ,paf_dst_scale
0 commit comments