@@ -9,23 +9,35 @@ from utils import PathRecover, log
99import persistence as pst
1010import os
1111import repo
12+ import argparse
1213
1314$ceroot = config .workspace
1415os .environ ['ceroot' ] = config .workspace
1516mode = os .environ .get ('mode' , 'evaluation' )
1617
18+ def parse_args ():
19+ parser = argparse .ArgumentParser ("Tool for running CE models" )
20+ parser .add_argument (
21+ '--modified' ,
22+ action = 'store_true' ,
23+ help = 'if set, we will just run modified models.' )
24+ args = parser .parse_args ()
25+ return args
1726
1827def main ():
1928 #try_start_mongod()
20- refresh_baseline_workspace ()
21- suc = evaluate_tasks ()
29+ args = parse_args ()
30+ if not args .modified :
31+ refresh_baseline_workspace ()
32+ suc = evaluate_tasks (args )
2233 if suc :
2334 display_success_info ()
2435 if mode != "baseline_test" :
2536 update_baseline ()
2637 exit 0
2738 else :
28- display_fail_info ()
39+ if not args .modified :
40+ display_fail_info ()
2941 sys .exit (- 1 )
3042 exit - 1
3143
@@ -72,7 +84,7 @@ def refresh_baseline_workspace():
7284 git clone @(config .baseline_repo_url ) @(config .baseline_path )
7385
7486
75- def evaluate_tasks ():
87+ def evaluate_tasks (args ):
7688 '''
7789 Evaluate all the tasks. It will continue to run all the tasks even
7890 if any task is failed to get a summary.
@@ -82,14 +94,18 @@ def evaluate_tasks():
8294 commit_time = repo .get_commit_date (config .paddle_path )
8395 log .warn ('commit' , paddle_commit )
8496 all_passed = True
85- tasks = [v for v in get_tasks ()]
86- for task in get_tasks ():
97+ if args .modified :
98+ tasks = [v for v in get_changed_tasks ()]
99+ else :
100+ tasks = [v for v in get_tasks ()]
101+ for task in tasks :
87102 passed , eval_infos , kpis , kpi_types = evaluate (task )
88103
89104 if mode != "baseline_test" :
90105 log .warn ('add evaluation %s result to mongodb' % task )
91106 kpi_objs = get_kpi_tasks (task )
92- pst .add_evaluation_record (commitid = paddle_commit ,
107+ if not args .modified :
108+ pst .add_evaluation_record (commitid = paddle_commit ,
93109 date = commit_time ,
94110 task = task ,
95111 passed = passed ,
@@ -186,4 +202,17 @@ def get_kpi_tasks(task_name):
186202 tracking_kpis = env ['tracking_kpis' ]
187203 return tracking_kpis
188204
205+
206+ def get_changed_tasks ():
207+ tasks = []
208+ cd @(config .baseline_path )
209+ out = $(git diff master | grep "diff --git" )
210+ out = out .strip ()
211+ for item in out .split ('\n ' ):
212+ task = item .split ()[3 ].split ('/' )[1 ]
213+ if task not in tasks :
214+ tasks .append (task )
215+ log .warn ("changed tasks: %s" % tasks )
216+ return ['resnet50' ]
217+
189218main ()
0 commit comments