
Commit 1be468b

Authored by GuoxiaWang, qizhaoaoe, and liuTINA0907

add a model PLSC-ViT (#5697) (#5706)

* add a model and app for VisionTransformer

Co-authored-by: qizhaoaoe <2285142981@qq.com>
Co-authored-by: liuTINA0907 <65896652+liuTINA0907@users.noreply.github.com>

1 parent 5f14157 · commit 1be468b

File tree

13 files changed: +885 −0 lines

modelcenter/PLSC-ViT/APP/__init__.py

Whitespace-only changes.

modelcenter/PLSC-ViT/APP/app.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
import gradio as gr
from predictor import Predictor

model_path = "paddlecv://models/vit/v2.4/imagenet2012-ViT-B_16-224_infer.pdmodel"
params_path = "paddlecv://models/vit/v2.4/imagenet2012-ViT-B_16-224_infer.pdiparams"
label_path = "paddlecv://dataset/imagenet2012_labels.txt"

predictor = None


def model_inference(image):
    global predictor
    if predictor is None:
        predictor = Predictor(
            model_path=model_path,
            params_path=params_path,
            label_path=label_path)
    scores, labels = predictor.predict(image)
    json_out = {"scores": scores.tolist(), "labels": labels.tolist()}
    return image, json_out


def clear_all():
    return None, None, None


with gr.Blocks() as demo:
    gr.Markdown("Classification based on ViT")

    with gr.Column(scale=1, min_width=100):

        img_in = gr.Image(
            value="https://plsc.bj.bcebos.com/dataset/test_images/cat.jpg",
            label="Input")

        with gr.Row():
            btn1 = gr.Button("Clear")
            btn2 = gr.Button("Submit")

        img_out = gr.Image(label="Output")
        json_out = gr.JSON(label="jsonOutput")

    btn2.click(fn=model_inference, inputs=img_in, outputs=[img_out, json_out])
    btn1.click(fn=clear_all, inputs=None, outputs=[img_in, img_out, json_out])
    gr.Button.style(1)

demo.launch()
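The predictor is built lazily inside model_inference, so the model weights are downloaded on the first request rather than at import time. As a quick smoke test of the same inference path outside Gradio (a sketch; the local file cat.jpg is a hypothetical stand-in for the sample image):

import cv2

img = cv2.imread("cat.jpg")                  # OpenCV loads HWC uint8 in BGR order
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)   # Gradio hands the callback RGB arrays
_, result = model_inference(img)
print(result["labels"], result["scores"])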

modelcenter/PLSC-ViT/APP/app.yaml

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
【PLSC-ViT-App-YAML】

APP_Info:
  title: PLSC-ViT-App
  colorFrom: blue
  colorTo: yellow
  sdk: gradio
  sdk_version: 3.9.1
  app_file: app.py
  license: apache-2.0
  device: cpu
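These fields mirror Hugging Face Spaces-style app metadata: display title, theme gradient colors, the SDK and its version, the entry file, license, and target device. The 【...】 banner on the first line is not standard YAML, so a parser would need to skip it first; a hypothetical sketch with PyYAML (already listed in requirements.txt):

import yaml

with open("app.yaml") as f:
    text = "".join(line for line in f if not line.startswith("【"))
cfg = yaml.safe_load(text)
print(cfg["APP_Info"]["title"], cfg["APP_Info"]["sdk_version"])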
modelcenter/PLSC-ViT/APP/download.py

Lines changed: 229 additions & 0 deletions
@@ -0,0 +1,229 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
import sys
import yaml
import time
import shutil
import requests
import tqdm
import hashlib
import base64
import binascii
import tarfile
import zipfile

__all__ = [
    'get_model_path',
    'get_config_path',
    'get_dict_path',
    'get_data_path',
]

WEIGHTS_HOME = osp.expanduser("~/.cache/paddlecv/models/plsc")
CONFIGS_HOME = osp.expanduser("~/.cache/paddlecv/configs/plsc")
DICTS_HOME = osp.expanduser("~/.cache/paddlecv/dicts/plsc/")
DATA_HOME = osp.expanduser("~/.cache/paddlecv/dataset/plsc")
# dict of {dataset_name: (download_info, sub_dirs)}
# download info: [(url, md5sum)]

DOWNLOAD_RETRY_LIMIT = 3

PMP_DOWNLOAD_URL_PREFIX = 'https://plsc.bj.bcebos.com/'


def is_url(path):
    """
    Whether path is a URL.
    Args:
        path (str): path or URL string to check.
    """
    return path.startswith('http://') \
            or path.startswith('https://') \
            or path.startswith('paddlecv://')


def parse_url(url):
    url = url.replace("paddlecv://", PMP_DOWNLOAD_URL_PREFIX)
    return url


def get_model_path(path):
    """Get model path from WEIGHTS_HOME; if it does not exist,
    download it from url.
    """
    if not is_url(path):
        return path
    url = parse_url(path)
    path, _ = get_path(url, WEIGHTS_HOME, path_depth=3)
    return path


def get_data_path(path):
    """Get data path from DATA_HOME; if it does not exist,
    download it from url.
    """
    if not is_url(path):
        return path
    url = parse_url(path)
    path, _ = get_path(url, DATA_HOME, path_depth=1)
    return path


def get_config_path(path):
    """Get config path from CONFIGS_HOME; if it does not exist,
    download it from url.
    """
    if not is_url(path):
        return path
    url = parse_url(path)
    path, _ = get_path(url, CONFIGS_HOME)
    return path


def get_dict_path(path):
    """Get dict path from DICTS_HOME; if it does not exist,
    download it from url.
    """
    if not is_url(path):
        return path
    url = parse_url(path)
    path, _ = get_path(url, DICTS_HOME)
    return path


def map_path(url, root_dir, path_depth=1):
    # parse path after download to decompress under root_dir
    assert path_depth > 0, "path_depth should be a positive integer"
    dirname = url
    for _ in range(path_depth):
        dirname = osp.dirname(dirname)
    fpath = osp.relpath(url, dirname)
    path = osp.join(root_dir, fpath)
    dirname = osp.dirname(path)
    return path, dirname


def get_path(url, root_dir, md5sum=None, check_exist=True, path_depth=1):
    """ Download from the given url to root_dir.
    If the file or directory specified by url exists under
    root_dir, return the path directly; otherwise download
    from url and return the path.
    url (str): download url
    root_dir (str): root dir for downloading, it should be
                    WEIGHTS_HOME
    md5sum (str): md5 sum of the download package
    """
    # parse path after download to decompress under root_dir
    fullpath, dirname = map_path(url, root_dir, path_depth)

    if osp.exists(fullpath) and check_exist:
        if not osp.isfile(fullpath) or \
                _check_exist_file_md5(fullpath, md5sum, url):
            return fullpath, True
        else:
            os.remove(fullpath)

    fullname = _download(url, dirname, md5sum)
    return fullpath, False


def _download(url, path, md5sum=None):
    """
    Download from url, save to path.
    url (str): download url
    path (str): download to given path
    """
    if not osp.exists(path):
        os.makedirs(path)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    retry_cnt = 0

    while not (osp.exists(fullname) and _check_exist_file_md5(fullname, md5sum,
                                                              url)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))

        # NOTE: windows path join may incur \, which is invalid in url
        if sys.platform == "win32":
            url = url.replace('\\', '/')

        req = requests.get(url, stream=True)
        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # To protect against interrupted downloads, download to
        # tmp_fullname first, then move tmp_fullname to fullname
        # once the download finishes
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                for chunk in tqdm.tqdm(
                        req.iter_content(chunk_size=1024),
                        total=(int(total_size) + 1023) // 1024,
                        unit='KB'):
                    f.write(chunk)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)
    return fullname


def _check_exist_file_md5(filename, md5sum, url):
    # if md5sum is None and the file to check is a model file,
    # read md5sum from the url and check; else check md5sum directly
    return _md5check_from_url(filename, url) if md5sum is None \
            and filename.endswith('pdparams') \
            else _md5check(filename, md5sum)


def _md5check_from_url(filename, url):
    # For models at bcebos URLs, the MD5 value is contained
    # in the response header as 'content-md5'
    req = requests.get(url, stream=True)
    content_md5 = req.headers.get('content-md5')
    req.close()
    if not content_md5 or _md5check(
            filename,
            binascii.hexlify(base64.b64decode(content_md5.strip('"'))).decode(
            )):
        return True
    else:
        return False


def _md5check(fullname, md5sum=None):
    if md5sum is None:
        return True

    md5 = hashlib.md5()
    with open(fullname, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b""):
            md5.update(chunk)
    calc_md5sum = md5.hexdigest()

    if calc_md5sum != md5sum:
        return False
    return True
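As a sketch of how these helpers behave (assuming network access to plsc.bj.bcebos.com), resolving a paddlecv:// URI downloads the file once and then serves it from the local cache on subsequent calls:

from download import get_model_path

local = get_model_path(
    "paddlecv://models/vit/v2.4/imagenet2012-ViT-B_16-224_infer.pdmodel")
# With path_depth=3, the last three URL components are kept, so the file
# lands under ~/.cache/paddlecv/models/plsc/vit/v2.4/
print(local)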
modelcenter/PLSC-ViT/APP/predictor.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
import os
import cv2
import numpy as np
import paddle
from download import get_model_path, get_data_path


class Predictor(object):
    def __init__(self,
                 model_type="paddle",
                 model_path=None,
                 params_path=None,
                 label_path=None):
        '''
        model_path: str, http url
        params_path: str, http url, could be downloaded
        '''
        assert model_type in ["paddle"]
        assert model_path is not None and os.path.splitext(model_path)[
            1] == '.pdmodel'
        assert params_path is not None and os.path.splitext(params_path)[
            1] == '.pdiparams'

        import paddle.inference as paddle_infer
        infer_model = get_model_path(model_path)
        infer_params = get_model_path(params_path)
        config = paddle_infer.Config(infer_model, infer_params)
        self.predictor = paddle_infer.create_predictor(config)
        self.input_names = self.predictor.get_input_names()
        self.output_names = self.predictor.get_output_names()
        self.labels = self.parse_labels(get_data_path(label_path))
        self.model_type = model_type

    def predict(self, img):
        if self.preprocess is not None:
            inputs = self.preprocess(img)
        else:
            inputs = img
        for input_name in self.input_names:
            input_tensor = self.predictor.get_input_handle(input_name)
            input_tensor.copy_from_cpu(inputs[input_name])
        self.predictor.run()
        outputs = []
        for output_idx in range(len(self.output_names)):
            output_tensor = self.predictor.get_output_handle(
                self.output_names[output_idx])
            outputs.append(output_tensor.copy_to_cpu())
        if self.postprocess is not None:
            output_data = self.postprocess(outputs)
        else:
            output_data = outputs

        return output_data

    def preprocess(self, img):
        img = cv2.resize(img, (224, 224))
        scale = 1.0 / 255.0
        mean = 0.5
        std = 0.5
        img = (img.astype('float32') * scale - mean) / std
        img = img[np.newaxis, :, :, :]
        img = img.transpose((0, 3, 1, 2))
        return {'x': img}

    @staticmethod
    def parse_labels(label_path):
        with open(label_path, 'r') as f:
            labels = []
            for line in f:
                if len(line) < 2:
                    continue
                label = line.strip().split(',')[0].split(' ')[2]
                labels.append(label)
        return labels

    @staticmethod
    def softmax(x, epsilon=1e-6):
        exp_x = np.exp(x)
        sfm = (exp_x + epsilon) / (np.sum(exp_x) + epsilon)
        return sfm

    def postprocess(self, logits):
        pred = np.array(logits).squeeze()
        pred = self.softmax(pred)
        class_idx = pred.argsort()[::-1]
        return pred[class_idx[:5]], np.array(self.labels)[class_idx[:5]]
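A minimal sketch of driving Predictor directly, outside the Gradio app (the local file test.jpg is a hypothetical input image):

import cv2
from predictor import Predictor

predictor = Predictor(
    model_path="paddlecv://models/vit/v2.4/imagenet2012-ViT-B_16-224_infer.pdmodel",
    params_path="paddlecv://models/vit/v2.4/imagenet2012-ViT-B_16-224_infer.pdiparams",
    label_path="paddlecv://dataset/imagenet2012_labels.txt")

img = cv2.cvtColor(cv2.imread("test.jpg"), cv2.COLOR_BGR2RGB)  # HWC RGB uint8
scores, labels = predictor.predict(img)  # top-5 probabilities and class names
for score, label in zip(scores, labels):
    print(f"{label}: {score:.4f}")

One design note: the softmax here exponentiates raw logits without first subtracting the max, so extreme logits could overflow; the usual numerically stable form is np.exp(x - x.max()) / np.sum(np.exp(x - x.max())).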
modelcenter/PLSC-ViT/APP/requirements.txt

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
gradio
opencv-python
paddlepaddle
PyYAML
shapely
scipy
Cython
numpy
setuptools
pillow
tqdm
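To try the demo locally, installing these with pip install -r requirements.txt and launching python app.py from the APP directory should suffice; demo.launch() serves the Gradio UI on http://127.0.0.1:7860 by default, and paddlepaddle here is the CPU wheel, matching device: cpu in app.yaml.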
