Commit 1c16986

[New Algorithm Support] PifPaf (#322)
* pifpaf model: 1.pifpaf model added 2.pipaf training process adding
* pifpaf training: preprocess finished: 1.pif_map generation 2.paf_map generation
* pifpaf training: loss calculation added
* pifpaf training: visualization processing added
* pifpaf training: 1.pifpaf Config module integration added 3.pifpaf keypoint converter added, MPII left(1)
* pifpaf evaluation: 1pifpaf infer processing added
* pifpaf debug: 1.train pre-processing debug finished 2.infer processing debug finished
* pifpaf evaluating: 1.inference finished! 2.evaluation progressing 3.visualization progressing
* pifpaf evaluation: 1.visualization finished!
* pifpaf training debug: 1.debug progressing(1)
* pifpaf training debug: 1.debug progressing(2)
* pifpaf training debug: 1.debug processing(3)
* pifpaf training debug: 1.debug progressing(4)
* I.pipaf training debug: 1.debug progressing(5) II.code tidy up: 1.pose proposal
* I.pifpaf training debug: 1.debug finished! II.docs tidy up 1.readme accuracy table added!
* I.docs tidy-up 1.readme modified.
* I.compatibility test 1.tensorflow 2.1.2 compatibity tested
* I.pifpaf training 1.pifpaf training adjustment
* I.pifpaf training 1.pifpaf training debuged II.resnet50 debug: 1.resnet50 backbone debug finished
* I.Pifpaf evaluate: 1.pifpaf infer debug finished 2.pifpaf evaluate process debug finished
1 parent ddfc45a commit 1c16986

26 files changed

+2072
-95
lines changed

README.md

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ HyperPose is a library for building human pose estimation systems that can effic
2929

3030
HyperPose has two key features, which are not available in existing libraries:
3131

32-
- **Flexible training platform**: HyperPose provides flexible Python APIs to build many useful pose estimation models (e.g., OpenPose and PoseProposalNetwork). HyperPose users can, for example, customize data augmentation, use parallel GPUs for training, and replace deep neural networks (e.g., changing from ResNet to MobileNet), thus building models specific to their real-world scenarios.
32+
- **Flexible training platform**: HyperPose provides flexible Python APIs for building customised pipelines for various pose estimation models. HyperPose users can:
33+
* use uniform pipelines for training, evaluation, visualization, pre-processing and post-processing across models (e.g., OpenPose, PifPaf, PoseProposal Network)
34+
* customise models and datasets for their own use (e.g., user-defined model, user-defined dataset, multiple-dataset combination)
35+
* train in parallel on multiple GPUs (using the *Kungfu* adaptive distributed training library)
36+
thus building models specific to their real-world scenarios.
3337
- **High-performance pose estimation**: HyperPose achieves real-time pose estimation through a high-performance pose estimation engine. This engine implements numerous system optimizations: pipeline parallelism, model inference with TensorRT, CPU/GPU hybrid scheduling, and many others. This allows HyperPose to **run 4x FASTER than OpenPose and 10x FASTER than TF-Pose**.
3438

3539
## Documentation
@@ -81,6 +85,18 @@ We compare the prediction performance of HyperPose with [OpenPose 1.6](https://g
8185
| OpenPose (MobileNet) | 17.9 MB | 432 x 368 | **84.32 FPS** | 8.5 FPS (TF-Pose) |
8286
| OpenPose (ResNet18) | 45.0 MB | 432 x 368 | **62.52 FPS** | N/A |
8387

88+
## Accuracy
89+
We evaluate the accuracy of pose estimation models developed with HyperPose (mainly on the MSCOCO 2017 dataset). The development environment is Ubuntu 16.04, with 4 V100-DGXs and 24 Intel Xeon CPUs. Training takes 1~2 weeks using 1 V100-DGX for each model. (If you want to train from scratch, loading the pretrained backbone weights is recommended.)
90+
91+
| HyperPose Configuration | DNN Size | Input Size | Evaluation Dataset | Accuracy-HyperPose (IoU=0.50:0.95) | Accuracy-original (IoU=0.50:0.95) |
92+
| -------------------- | ---------- | ------------- | ---------------- | --------------------- | ----------------------- |
93+
| Openpose (vgg19) | 199 MB | 432 x 368 | MSCOCO 2014 (random 1160 images) | 57.0 mAP | 58.4 mAP |
94+
| LightweightOpenpose (dilated mobilenet) | 17.7 MB | 432 x 368 | MSCOCO 2017 (all 5000 images) | 46.1 mAP | 42.8 mAP |
95+
| LightweightOpenpose (mobilenet-thin) | 17.4 MB | 432 x 368 | MSCOCO 2017 (all 5000 images) | 44.2 mAP | 28.06 mAP (MSCOCO 2014) |
96+
| LightweightOpenpose (tinyvgg) | 23.6 MB | 432 x 368 | MSCOCO 2017 (all 5000 images) | 47.3 mAP | - |
97+
| LightweightOpenpose (resnet50) | 42.7 MB | 432 x 368 | MSCOCO 2017 (all 5000 images) | 48.2 mAP | - |
98+
| PoseProposal (resnet18) | 45.2 MB | 384 x 384 | MPII (all 2729 images) | 54.9 mAP (PCKh) | 72.8 mAP (PCKh) |
99+
84100
</a>
85101
<p align="center">
86102
<img src="./docs/markdown/images/demo-xbd.gif", width="600">

hyperpose/Config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -65,6 +65,8 @@ def get_config():
6565
from .config_mbtopps import model,train,eval,data,log
6666
elif(update_model.model_type==MODEL.PoseProposal):
6767
from .config_ppn import model,train,eval,data,log
68+
elif(update_model.model_type==MODEL.Pifpaf):
69+
from .config_pifpaf import model,train,eval,data,log
6870
#merge settings with basic configurations
6971
model.update(update_model)
7072
train.update(update_train)
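
The merge step at the end of `get_config()` overlays user settings on the per-model defaults. The sketch below imitates that step with plain dictionaries (EasyDict is a `dict` subclass, so `update` behaves the same way). The names `default_model`, `update_model` and `merge` are hypothetical stand-ins, not HyperPose code. One thing the sketch makes visible: a path that was formatted from the default model name before the merge keeps the default name afterwards. Whether HyperPose recomputes such paths elsewhere is not visible in this diff.

```python
# Stand-in defaults; the real ones live in config_pifpaf.py as EasyDict objects.
default_model = {"model_name": "default_name", "hin": 368, "win": 432}
# The path is formatted once, while the name is still the default.
default_model["model_dir"] = f"./save_dir/{default_model['model_name']}/model_dir"

# Hypothetical user overrides collected before the merge.
update_model = {"model_name": "my_pifpaf", "hin": 384}


def merge(defaults, overrides):
    """Return a new dict in which user overrides win over the defaults."""
    merged = dict(defaults)
    merged.update(overrides)
    return merged


model = merge(default_model, update_model)
print(model["model_name"], model["hin"], model["win"])
print(model["model_dir"])  # still derived from the default name
```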

hyperpose/Config/config_pifpaf.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
1+
import os
2+
from .define import MODEL,DATA,TRAIN,BACKBONE
3+
from easydict import EasyDict as edict
4+
5+
#model configuration
6+
model = edict()
7+
# number of keypoints (MSCOCO defines 17)
8+
model.n_pos = 17
9+
model.num_channels=128
10+
# input size during training , 240
11+
model.hin = 368
12+
model.win = 432
13+
# output size during training (default 46)
14+
model.hout = 46
15+
model.wout = 54
16+
model.model_type = MODEL.Pifpaf
17+
model.model_name = "default_name"
18+
model.model_backbone=BACKBONE.Default
19+
model.data_format = "channels_first"
20+
# save directory
21+
model.model_dir = f"./save_dir/{model.model_name}/model_dir"
22+
23+
#train configuration
24+
train=edict()
25+
train.batch_size = 4
26+
train.save_interval = 2000
27+
# total number of step
28+
train.n_step = 1000000
29+
# initial learning rate
30+
train.lr_init = 1e-4
31+
# number of steps between lr decays
32+
train.lr_decay_every_step = 136120
33+
# decay lr factor
34+
train.lr_decay_factor = 0.2
35+
train.lr_decay_steps=[420000,630000]
36+
train.weight_decay_factor = 0.0
37+
train.train_type=TRAIN.Single_train
38+
train.vis_dir=f"./save_dir/{model.model_name}/train_vis_dir"
39+
40+
#eval configuration
41+
eval =edict()
42+
eval.batch_size=8
43+
eval.vis_dir= f"./save_dir/{model.model_name}/eval_vis_dir"
44+
45+
#data configuration
46+
data = edict()
47+
data.dataset_type = DATA.MSCOCO # coco, custom, coco_and_custom
48+
data.dataset_version = "2017" # MSCOCO version 2014 or 2017
49+
data.dataset_path = "./data"
50+
data.dataset_filter=None
51+
data.vis_dir=f"./save_dir/data_vis_dir"
52+
53+
#log configuration
54+
log = edict()
55+
log.log_interval = 100
56+
log.log_path= f"./save_dir/{model.model_name}/log.txt"
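
The training settings above pair `train.lr_init`, `train.lr_decay_factor` and `train.lr_decay_steps`. How the PifPaf training loop consumes them is not shown in this hunk. The sketch below assumes a piecewise-constant schedule in which the rate is multiplied by the decay factor each time a listed step is reached. The helper name `lr_at_step` is hypothetical.

```python
# Hypothetical helper, not part of HyperPose: a piecewise-constant schedule
# assumed from the values in config_pifpaf.py.
def lr_at_step(step, lr_init=1e-4, decay_factor=0.2, decay_steps=(420000, 630000)):
    """Return the learning rate in effect at a given global step."""
    # Count how many decay boundaries have already been reached.
    n_decays = sum(1 for boundary in decay_steps if step >= boundary)
    return lr_init * (decay_factor ** n_decays)


if __name__ == "__main__":
    for step in (0, 419999, 420000, 630000, 999999):
        print(step, lr_at_step(step))
```

Under this assumption the rate stays at `lr_init` until step 420000 and is reduced twice over the 1000000 configured steps.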

hyperpose/Config/define.py

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@ class MODEL(Enum):
1515
LightweightOpenpose=1
1616
PoseProposal=2
1717
MobilenetThinOpenpose=3
18+
Pifpaf=4
1819

1920
class DATA(Enum):
2021
MSCOCO=0
@@ -33,5 +34,5 @@ class KUNGFU(Enum):
3334

3435
class OPTIM(Enum):
3536
Adam=0
36-
SGD=2
3737
RMSprop=1
38+
SGD=2

hyperpose/Dataset/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -56,6 +56,9 @@ def get_dataset(config):
5656
elif(model_type==MODEL.PoseProposal):
5757
from .mscoco_dataset.define import ppn_input_converter as input_kpt_cvter
5858
from .mscoco_dataset.define import ppn_output_converter as output_kpt_cvter
59+
elif(model_type==MODEL.Pifpaf):
60+
from .mscoco_dataset.define import pifpaf_input_converter as input_kpt_cvter
61+
from .mscoco_dataset.define import pifpaf_output_converter as output_kpt_cvter
5962
dataset=MSCOCO_dataset(config,input_kpt_cvter,output_kpt_cvter)
6063
dataset.prepare_dataset()
6164
elif(dataset_type==DATA.MPII):

hyperpose/Dataset/mscoco_dataset/dataset.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def visualize(self,vis_num):
5858
'''
5959

6060
train_dataset=self.get_train_dataset()
61-
visualize(self.vis_dir,vis_num,train_dataset,self.parts,self.colors,dataset_name="mpii")
61+
visualize(self.vis_dir,vis_num,train_dataset,self.parts,self.colors,dataset_name="mscoco")
6262

6363
def get_parts(self):
6464
return self.parts

hyperpose/Dataset/mscoco_dataset/define.py

Lines changed: 33 additions & 0 deletions
@@ -86,4 +86,37 @@ def ppn_output_converter(kpt_list):
8686
kpts+=[0.0,0.0,0.0]
8787
else:
8888
kpts+=[x,y,1.0]
89+
return kpts
90+
91+
#convert kpts from pifpaf to mscoco
92+
from_pifpaf_converter={}
93+
for part_idx in range(0,len(CocoPart)):
94+
from_pifpaf_converter[part_idx]=part_idx
95+
#convert kpts from mscoco to pifpaf
96+
to_pifpaf_converter={}
97+
for part_idx in range(0,len(CocoPart)):
98+
to_pifpaf_converter[part_idx]=part_idx
99+
100+
def pifpaf_input_converter(coco_kpts):
101+
xs=coco_kpts[0::3]
102+
ys=coco_kpts[1::3]
103+
vs=coco_kpts[2::3]
104+
lost_idx=np.where(vs<=0)[0]
105+
xs[lost_idx]=-1000
106+
ys[lost_idx]=-1000
107+
cvt_kpts=np.array([xs,ys]).transpose()
108+
return cvt_kpts
109+
110+
def pifpaf_output_converter(kpt_list):
111+
kpts=[]
112+
for coco_idx in range(0,len(CocoPart)):
113+
flag=False
114+
if(coco_idx in to_pifpaf_converter):
115+
model_idx=to_pifpaf_converter[coco_idx]
116+
x,y=kpt_list[model_idx]
117+
if(x>=0 and y>=0):
118+
kpts+=[x,y,1.0]
119+
flag=True
120+
if(not flag):
121+
kpts+=[0.0,0.0,0.0]
89122
return kpts
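
The two converters in this hunk map between the flat MSCOCO annotation layout `[x, y, v, ...]` and a per-keypoint `(x, y)` array. The standalone sketch below restates them so the round trip can be run outside HyperPose. It assumes 17 keypoints (standing in for `len(CocoPart)`), copies the coordinate slices so the caller's array is not modified, and appends visible keypoints with `+=`. It relies on the identity index mapping shown in the hunk and therefore omits the lookup tables.

```python
import numpy as np

N_COCO_KPTS = 17  # assumption: stands in for len(CocoPart)


def pifpaf_input_converter(coco_kpts):
    """Flat COCO [x, y, v, ...] -> (17, 2) array; unlabeled points become -1000."""
    coco_kpts = np.asarray(coco_kpts, dtype=np.float32)
    xs = coco_kpts[0::3].copy()  # copy so the input is left untouched
    ys = coco_kpts[1::3].copy()
    vs = coco_kpts[2::3]
    lost = vs <= 0               # visibility flag 0 means "not labeled"
    xs[lost] = -1000
    ys[lost] = -1000
    return np.stack([xs, ys], axis=1)


def pifpaf_output_converter(kpt_list):
    """(17, 2) model keypoints -> flat COCO [x, y, v, ...] list."""
    kpts = []
    for x, y in kpt_list:
        if x >= 0 and y >= 0:
            kpts += [float(x), float(y), 1.0]  # detected keypoint
        else:
            kpts += [0.0, 0.0, 0.0]            # missing keypoint
    return kpts


if __name__ == "__main__":
    ann = [10.0, 20.0, 2.0] + [0.0, 0.0, 0.0] * (N_COCO_KPTS - 1)
    model_kpts = pifpaf_input_converter(ann)
    print(model_kpts.shape)
    print(pifpaf_output_converter(model_kpts)[:6])
```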

hyperpose/Model/__init__.py

Lines changed: 18 additions & 5 deletions
@@ -77,6 +77,12 @@ def get_model(config):
7777
from .pose_proposal.utils import get_limbs
7878
model.parts=get_parts(dataset_type)
7979
model.limbs=get_limbs(dataset_type)
80+
elif(model_type == MODEL.Pifpaf):
81+
from .pifpaf.utils import get_parts
82+
from .pifpaf.utils import get_limbs
83+
model.parts=get_parts(dataset_type)
84+
model.limbs=get_limbs(dataset_type)
85+
8086
userdef_parts=config.model.userdef_parts
8187
userdef_limbs=config.model.userdef_limbs
8288
if(userdef_parts!=None):
@@ -93,18 +99,21 @@ def get_model(config):
9399
hin=model.hin,win=model.win,hout=model.hout,wout=model.wout,backbone=backbone,pretraining=pretraining,data_format=model.data_format)
94100
elif model_type == MODEL.LightweightOpenpose:
95101
from .openpose import LightWeightOpenPose as model_arch
96-
ret_model=model_arch(parts=model.parts,n_pos=len(model.parts),limbs=model.limbs,n_limbs=len(model.limbs),num_channels=model.num_channels,hin=model.hin,win=model.win,\
97-
hout=model.hout,wout=model.wout,backbone=backbone,pretraining=pretraining,data_format=model.data_format)
102+
ret_model=model_arch(parts=model.parts,n_pos=len(model.parts),limbs=model.limbs,n_limbs=len(model.limbs),num_channels=model.num_channels,\
103+
hin=model.hin,win=model.win,hout=model.hout,wout=model.wout,backbone=backbone,pretraining=pretraining,data_format=model.data_format)
98104
elif model_type == MODEL.MobilenetThinOpenpose:
99105
from .openpose import MobilenetThinOpenpose as model_arch
100-
ret_model=model_arch(parts=model.parts,n_pos=len(model.parts),limbs=model.limbs,n_limbs=len(model.limbs),num_channels=model.num_channels,hin=model.hin,win=model.win,\
101-
hout=model.hout,wout=model.wout,backbone=backbone,pretraining=pretraining,data_format=model.data_format)
106+
ret_model=model_arch(parts=model.parts,n_pos=len(model.parts),limbs=model.limbs,n_limbs=len(model.limbs),num_channels=model.num_channels,\
107+
hin=model.hin,win=model.win,hout=model.hout,wout=model.wout,backbone=backbone,pretraining=pretraining,data_format=model.data_format)
102108
elif model_type == MODEL.PoseProposal:
103109
from .pose_proposal import PoseProposal as model_arch
104110
ret_model=model_arch(parts=model.parts,K_size=len(model.parts),limbs=model.limbs,L_size=len(model.limbs),hnei=model.hnei,wnei=model.wnei,lmd_rsp=model.lmd_rsp,\
105111
lmd_iou=model.lmd_iou,lmd_coor=model.lmd_coor,lmd_size=model.lmd_size,lmd_limb=model.lmd_limb,backbone=backbone,\
106112
pretraining=pretraining,data_format=model.data_format)
107-
#print(f"\n!!!test in get_model: parts:{model.parts} limbs:{model.limbs}\n\n")
113+
elif model_type == MODEL.Pifpaf:
114+
from .pifpaf import Pifpaf as model_arch
115+
ret_model=model_arch(parts=model.parts,n_pos=len(model.parts),limbs=model.limbs,n_limbs=len(model.limbs),hin=model.hin,win=model.win,\
116+
scale_size=32,pretraining=pretraining,data_format=model.data_format)
108117
else:
109118
raise RuntimeError(f'unknown model type {model_type}')
110119
print(f"using {model_type.name} model arch!")
@@ -143,6 +152,8 @@ def get_train(config):
143152
from .openpose import single_train,parallel_train
144153
elif model_type == MODEL.PoseProposal:
145154
from .pose_proposal import single_train,parallel_train
155+
elif model_type == MODEL.Pifpaf:
156+
from .pifpaf import single_train,parallel_train
146157
else:
147158
raise RuntimeError(f'unknown model type {model_type}')
148159
print(f"training {model_type.name} model...")
@@ -193,6 +204,8 @@ def get_evaluate(config):
193204
from .openpose import evaluate
194205
elif model_type == MODEL.PoseProposal:
195206
from .pose_proposal import evaluate
207+
elif model_type == MODEL.Pifpaf:
208+
from .pifpaf import evaluate
196209
else:
197210
raise RuntimeError(f'unknown model type {model_type}')
198211
evaluate=partial(evaluate,config=config)
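
Each of `get_model`, `get_train` and `get_evaluate` gains a `MODEL.Pifpaf` branch in this commit, and each falls through to a `RuntimeError` for unknown types. The sketch below shows the same dispatch idea in a self-contained form. It uses a dictionary lookup instead of the commit's `if/elif` chain with lazy imports, and the builder functions and `BUILDERS` table are hypothetical stand-ins for the real model classes.

```python
from enum import Enum


class MODEL(Enum):
    # Values copied from the define.py hunk; members not shown there are omitted.
    LightweightOpenpose = 1
    PoseProposal = 2
    MobilenetThinOpenpose = 3
    Pifpaf = 4


# Hypothetical stand-ins for the real model constructors.
def build_pifpaf(cfg):
    return {"arch": "Pifpaf", "hin": cfg["hin"], "win": cfg["win"], "scale_size": 32}


def build_pose_proposal(cfg):
    return {"arch": "PoseProposal", "hin": cfg["hin"], "win": cfg["win"]}


BUILDERS = {
    MODEL.Pifpaf: build_pifpaf,
    MODEL.PoseProposal: build_pose_proposal,
}


def get_model(model_type, cfg):
    """Look up the builder for a model type and fail loudly if none is registered."""
    builder = BUILDERS.get(model_type)
    if builder is None:
        raise RuntimeError(f"unknown model type {model_type}")
    return builder(cfg)


if __name__ == "__main__":
    print(get_model(MODEL.Pifpaf, {"hin": 368, "win": 432}))
```

A table like this keeps the three entry points in step when a model is added, at the cost of importing every model module up front unless the builders import lazily themselves.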
