55from tensorlayer .models import Model
66from tensorlayer .layers import BatchNorm2d , Conv2d , DepthwiseConv2d , LayerList , MaxPool2d
77from .define import CocoColor
8+ from .utils import pixel_shuffle ,get_meshgrid
89from ..backbones import Resnet50_backbone
910
1011
@@ -40,30 +41,27 @@ def __init__(self,parts,limbs,colors=CocoColor,n_pos=17,n_limbs=19,hin=368,win=3
4041 self .backbone = backbone (data_format = data_format ,scale_size = self .scale_size )
4142 self .hout = int (hin / self .stride )
4243 self .wout = int (win / self .stride )
43- #generate mesh grid
44- x_range = np .linspace (start = 0 ,stop = self .wout - 1 ,num = self .wout )
45- y_range = np .linspace (start = 0 ,stop = self .hout - 1 ,num = self .hout )
46- mesh_x ,mesh_y = np .meshgrid (x_range ,y_range )
47- self .mesh_grid = np .stack ([mesh_x ,mesh_y ])
4844 #construct head
4945 self .pif_head = self .PifHead (input_features = self .backbone .out_channels ,n_pos = self .n_pos ,n_limbs = self .n_limbs ,\
50- quad_size = self .quad_size ,hout = self .hout ,wout = self .wout ,stride = self .stride ,mesh_grid = self . mesh_grid , data_format = self .data_format )
46+ quad_size = self .quad_size ,hout = self .hout ,wout = self .wout ,stride = self .stride ,data_format = self .data_format )
5147 self .paf_head = self .PafHead (input_features = self .backbone .out_channels ,n_pos = self .n_pos ,n_limbs = self .n_limbs ,\
52- quad_size = self .quad_size ,hout = self .hout ,wout = self .wout ,stride = self .stride ,mesh_grid = self . mesh_grid , data_format = self .data_format )
48+ quad_size = self .quad_size ,hout = self .hout ,wout = self .wout ,stride = self .stride ,data_format = self .data_format )
5349
54- @tf .function (experimental_relax_shapes = True )
55- def forward (self ,x ,is_train = False ):
56- x = self .backbone .forward (x )
57- pif_maps = self .pif_head .forward (x ,is_train = is_train )
58- paf_maps = self .paf_head .forward (x ,is_train = is_train )
50+ # @tf.function(experimental_relax_shapes=True)
51+ def forward (self ,x ,is_train = False ,ret_backbone = False ):
52+ backbone_x = self .backbone .forward (x )
53+ pif_maps = self .pif_head .forward (backbone_x ,is_train = is_train )
54+ paf_maps = self .paf_head .forward (backbone_x ,is_train = is_train )
55+ if (ret_backbone ):
56+ return pif_maps ,paf_maps ,backbone_x
5957 return pif_maps ,paf_maps
6058
61- @tf .function (experimental_relax_shapes = True )
59+ # @tf.function(experimental_relax_shapes=True)
6260 def infer (self ,x ):
63- pif_maps ,paf_maps = self .forward (x ,is_train = False )
61+ pif_maps ,paf_maps , backbone_x = self .forward (x ,is_train = False , ret_backbone = True )
6462 pif_conf ,pif_vec ,_ ,pif_scale = pif_maps
6563 paf_conf ,paf_src_vec ,paf_dst_vec ,_ ,_ ,paf_src_scale ,paf_dst_scale = paf_maps
66- return pif_conf ,pif_vec ,pif_scale ,paf_conf ,paf_src_vec ,paf_dst_vec ,paf_src_scale ,paf_dst_scale
64+ return pif_conf ,pif_vec ,pif_scale ,paf_conf ,paf_src_vec ,paf_dst_vec ,paf_src_scale ,paf_dst_scale , backbone_x
6765
6866 def soft_clamp (self ,x ,max_value = 5.0 ):
6967 above_mask = tf .where (x >= max_value ,1.0 ,0.0 )
@@ -157,7 +155,7 @@ def cal_loss(self,pd_pif_maps,pd_paf_maps,gt_pif_maps,gt_paf_maps):
157155 return loss_pif_maps ,loss_paf_maps ,total_loss
158156
159157 class PifHead (Model ):
160- def __init__ (self ,input_features = 2048 ,n_pos = 19 ,n_limbs = 19 ,quad_size = 2 ,hout = 8 ,wout = 8 ,stride = 8 ,mesh_grid = None , data_format = "channels_first" ):
158+ def __init__ (self ,input_features = 2048 ,n_pos = 19 ,n_limbs = 19 ,quad_size = 2 ,hout = 8 ,wout = 8 ,stride = 8 ,data_format = "channels_first" ):
161159 super ().__init__ ()
162160 self .input_features = input_features
163161 self .n_pos = n_pos
@@ -167,29 +165,32 @@ def __init__(self,input_features=2048,n_pos=19,n_limbs=19,quad_size=2,hout=8,wou
167165 self .stride = stride
168166 self .quad_size = quad_size
169167 self .out_features = self .n_pos * 5 * (self .quad_size ** 2 )
170- self .mesh_grid = mesh_grid
171168 self .data_format = data_format
172169 self .tf_data_format = "NCHW" if self .data_format == "channels_first" else "NHWC"
173170 self .main_block = Conv2d (n_filter = self .out_features ,in_channels = self .input_features ,filter_size = (1 ,1 ),data_format = self .data_format )
174171
175172 def forward (self ,x ,is_train = False ):
176173 x = self .main_block .forward (x )
177- x = tf .nn .depth_to_space (x ,block_size = self .quad_size ,data_format = self .tf_data_format )
178- x = tf .reshape (x ,[- 1 ,self .n_pos ,5 ,self .hout ,self .wout ])
174+ x = pixel_shuffle (x ,scale = 2 )
175+ low_cut = int ((self .quad_size - 1 )// 2 )
176+ high_cut = int (tf .math .ceil ((self .quad_size - 1 )/ 2.0 ))
177+ hout ,wout = x .shape [2 ],x .shape [3 ]
178+ x = tf .reshape (x ,[- 1 ,self .n_pos ,5 ,hout ,wout ])
179179 pif_conf = x [:,:,0 ,:,:]
180180 pif_vec = x [:,:,1 :3 ,:,:]
181181 pif_logb = x [:,:,3 ,:,:]
182- pif_scale = tf . exp ( x [:,:,4 ,:,:])
182+ pif_scale = x [:,:,4 ,:,:]
183183 #restore vec_maps in inference
184184 if (is_train == False ):
185+ mesh_grid = get_meshgrid (mesh_h = hout ,mesh_w = wout )+ np .array ([1.5 ,1.5 ])[:,np .newaxis ,np .newaxis ]
185186 infer_pif_conf = tf .nn .sigmoid (pif_conf )
186- infer_pif_vec = (pif_vec [:,:]+ self . mesh_grid )* self .stride
187+ infer_pif_vec = (pif_vec [:,:]+ mesh_grid )* self .stride
187188 infer_pif_scale = tf .math .softplus (pif_scale )* self .stride
188189 return infer_pif_conf ,infer_pif_vec ,pif_logb ,infer_pif_scale
189190 return pif_conf ,pif_vec ,pif_logb ,pif_scale
190191
191192 class PafHead (Model ):
192- def __init__ (self ,input_features = 2048 ,n_pos = 19 ,n_limbs = 19 ,quad_size = 2 ,hout = 46 ,wout = 46 ,stride = 8 ,mesh_grid = None , data_format = "channels_first" ):
193+ def __init__ (self ,input_features = 2048 ,n_pos = 19 ,n_limbs = 19 ,quad_size = 2 ,hout = 46 ,wout = 46 ,stride = 8 ,data_format = "channels_first" ):
193194 super ().__init__ ()
194195 self .input_features = input_features
195196 self .n_pos = n_pos
@@ -199,27 +200,30 @@ def __init__(self,input_features=2048,n_pos=19,n_limbs=19,quad_size=2,hout=46,wo
199200 self .wout = wout
200201 self .stride = stride
201202 self .out_features = self .n_limbs * 9 * (self .quad_size ** 2 )
202- self .mesh_grid = mesh_grid
203203 self .data_format = data_format
204204 self .tf_data_format = "NCHW" if self .data_format == "channels_first" else "NHWC"
205205 self .main_block = Conv2d (n_filter = self .out_features ,in_channels = self .input_features ,filter_size = (1 ,1 ),data_format = self .data_format )
206206
207207 def forward (self ,x ,is_train = False ):
208208 x = self .main_block .forward (x )
209- x = tf .nn .depth_to_space (x ,block_size = self .quad_size ,data_format = self .tf_data_format )
210- x = tf .reshape (x ,[- 1 ,self .n_limbs ,9 ,self .hout ,self .wout ])
209+ x = pixel_shuffle (x ,scale = 2 )
210+ low_cut = int ((self .quad_size - 1 )// 2 )
211+ high_cut = int (tf .math .ceil ((self .quad_size - 1 )/ 2.0 ))
212+ hout ,wout = x .shape [2 ],x .shape [3 ]
213+ x = tf .reshape (x ,[- 1 ,self .n_limbs ,9 ,hout ,wout ])
211214 paf_conf = x [:,:,0 ,:,:]
212215 paf_src_vec = x [:,:,1 :3 ,:,:]
213216 paf_dst_vec = x [:,:,3 :5 ,:,:]
214217 paf_src_logb = x [:,:,5 ,:,:]
215218 paf_dst_logb = x [:,:,6 ,:,:]
216- paf_src_scale = tf . exp ( x [:,:,7 ,:,:])
217- paf_dst_scale = tf . exp ( x [:,:,8 ,:,:])
219+ paf_src_scale = x [:,:,7 ,:,:]
220+ paf_dst_scale = x [:,:,8 ,:,:]
218221 #restore vec_maps in inference
219222 if (is_train == False ):
223+ mesh_grid = get_meshgrid (mesh_h = hout ,mesh_w = wout )+ np .array ([1.5 ,1.5 ])[:,np .newaxis ,np .newaxis ]
220224 infer_paf_conf = tf .nn .sigmoid (paf_conf )
221- infer_paf_src_vec = (paf_src_vec [:,:]+ self . mesh_grid )* self .stride
222- infer_paf_dst_vec = (paf_dst_vec [:,:]+ self . mesh_grid )* self .stride
225+ infer_paf_src_vec = (paf_src_vec [:,:]+ mesh_grid )* self .stride
226+ infer_paf_dst_vec = (paf_dst_vec [:,:]+ mesh_grid )* self .stride
223227 infer_paf_src_scale = tf .math .softplus (paf_src_scale )* self .stride
224228 infer_paf_dst_scale = tf .math .softplus (paf_dst_scale )* self .stride
225229 return infer_paf_conf ,infer_paf_src_vec ,infer_paf_dst_vec ,paf_src_logb ,paf_dst_logb ,infer_paf_src_scale ,infer_paf_dst_scale
0 commit comments