@@ -46,7 +46,7 @@ def __init__(self,
4646 use_dark = True ):
4747 """
4848 HRNet network, see https://arxiv.org/abs/1902.09212
49-
49+
5050 Args:
5151 backbone (nn.Layer): backbone instance
5252 post_process (object): `HRNetPostProcess` instance
@@ -132,10 +132,10 @@ def __init__(self, use_dark=True):
132132
133133 def get_max_preds (self , heatmaps ):
134134 '''get predictions from score maps
135-
135+
136136 Args:
137137 heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
138-
138+
139139 Returns:
140140 preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
141141 maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
@@ -220,12 +220,12 @@ def dark_postprocess(self, hm, coords, kernelsize):
220220 def get_final_preds (self , heatmaps , center , scale , kernelsize = 3 ):
221221 """Take the location of the highest heat value, with a quarter offset in the
222222 direction from the highest response to the second highest response.
223-
223+
224224 Args:
225225 heatmaps (numpy.ndarray): The predicted heatmaps
226226 center (numpy.ndarray): The boxes center
227227 scale (numpy.ndarray): The scale factor
228-
228+
229229 Returns:
230230 preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
231231 maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
@@ -341,10 +341,7 @@ def __init__(
341341 self .deploy = False
342342 self .num_joints = num_joints
343343
344- self .final_conv = L .Conv2d (width , num_joints , 1 , 1 , 0 , bias = True )
345- # for heatmap output
346- self .final_conv_new = L .Conv2d (
347- width , num_joints * 32 , 1 , 1 , 0 , bias = True )
344+ self .final_conv = L .Conv2d (width , num_joints * 32 , 1 , 1 , 0 , bias = True )
348345
349346 @classmethod
350347 def from_config (cls , cfg , * args , ** kwargs ):
@@ -356,20 +353,19 @@ def from_config(cls, cfg, *args, **kwargs):
356353 def _forward (self ):
357354 feats = self .backbone (self .inputs ) # feats:[[batch_size, 40, 32, 24]]
358355
359- hrnet_outputs = self .final_conv_new (feats [0 ])
356+ hrnet_outputs = self .final_conv (feats [0 ])
360357 res = soft_argmax (hrnet_outputs , self .num_joints )
361-
362- if self .training :
363- return self .loss (res , self .inputs )
364- else : # export model need
365- return res
358+ return res
366359
367360 def get_loss (self ):
368- return self ._forward ()
361+ pose3d = self ._forward ()
362+ loss = self .loss (pose3d , None , self .inputs )
363+ outputs = {'loss' : loss }
364+ return outputs
369365
370366 def get_pred (self ):
371367 res_lst = self ._forward ()
372- outputs = {'keypoint ' : res_lst }
368+ outputs = {'pose3d ' : res_lst }
373369 return outputs
374370
375371 def flip_back (self , output_flipped , matched_parts ):
@@ -427,16 +423,23 @@ def from_config(cls, cfg, *args, **kwargs):
427423 return {'backbone' : backbone , }
428424
429425 def _forward (self ):
430- feats = self .backbone (self .inputs ) # feats:[[batch_size, 40, 32, 24]]
426+ '''
427+ self.inputs is a dict
428+ '''
429+ feats = self .backbone (
430+ self .inputs ) # feats:[[batch_size, 40, width/4, height/4]]
431+
432+ hrnet_outputs = self .final_conv (
433+ feats [0 ]) # hrnet_outputs: [batch_size, num_joints*32,32,32]
431434
432- hrnet_outputs = self .final_conv (feats [0 ])
433435 flatten_res = self .flatten (
434- hrnet_outputs ) # [batch_size, 24, (height/4)*(width/4)]
436+ hrnet_outputs ) # [batch_size,num_joints*32,32*32]
437+
435438 res = self .fc1 (flatten_res )
436439 res = self .act1 (res )
437440 res = self .fc2 (res )
438441 res = self .act2 (res )
439- res = self .fc3 (res ) # [batch_size, 24, 3]
442+ res = self .fc3 (res )
440443
441444 if self .training :
442445 return self .loss (res , self .inputs )
@@ -448,7 +451,7 @@ def get_loss(self):
448451
449452 def get_pred (self ):
450453 res_lst = self ._forward ()
451- outputs = {'keypoint ' : res_lst }
454+ outputs = {'pose3d ' : res_lst }
452455 return outputs
453456
454457 def flip_back (self , output_flipped , matched_parts ):
0 commit comments