|
16 | 16 | import contextlib |
17 | 17 | from framework import Program, default_main_program, Variable |
18 | 18 | from . import core |
| 19 | +import sys |
19 | 20 |
|
20 | 21 | __all__ = [ |
21 | 22 | 'Executor', 'global_scope', 'scope_guard', 'switch_scope', 'fetch_var' |
@@ -207,7 +208,7 @@ def _add_program_cache(self, program_cache_key, program): |
207 | 208 | def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name, |
208 | 209 | fetch_var_name): |
209 | 210 | tmp_program = program.clone() |
210 | | - |
| 211 | + """ |
211 | 212 | global_block = tmp_program.global_block() |
212 | 213 |
|
213 | 214 | if feed_var_name in global_block.vars: |
@@ -246,7 +247,7 @@ def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name, |
246 | 247 | inputs={'X': [var]}, |
247 | 248 | outputs={'Out': [fetch_var]}, |
248 | 249 | attrs={'col': i}) |
249 | | - |
| 250 | + """ |
250 | 251 | return tmp_program |
251 | 252 |
|
252 | 253 | def _feed_data(self, program, feed, feed_var_name, scope): |
@@ -277,7 +278,8 @@ def run(self, |
277 | 278 | fetch_var_name='fetch', |
278 | 279 | scope=None, |
279 | 280 | return_numpy=True, |
280 | | - use_program_cache=False): |
| 281 | + use_program_cache=False, |
| 282 | + keep_create=False): |
281 | 283 | """ Run program by this Executor. Feed data by feed map, fetch result by fetch_list. |
282 | 284 |
|
283 | 285 | Python executor takes a program, add feed operators and fetch operators to this program according |
@@ -329,12 +331,14 @@ def run(self, |
329 | 331 | program = cached_program |
330 | 332 | else: |
331 | 333 | self.program_caches.pop(cache_key, None) |
332 | | - program = self._add_feed_fetch_ops( |
333 | | - program=program, |
334 | | - feed=feed, |
335 | | - fetch_list=fetch_list, |
336 | | - feed_var_name=feed_var_name, |
337 | | - fetch_var_name=fetch_var_name) |
| 334 | + while keep_create: |
| 335 | + program = self._add_feed_fetch_ops( |
| 336 | + program=program, |
| 337 | + feed=feed, |
| 338 | + fetch_list=fetch_list, |
| 339 | + feed_var_name=feed_var_name, |
| 340 | + fetch_var_name=fetch_var_name) |
| 341 | + sys.stderr.write('created a program\n') |
338 | 342 |
|
339 | 343 | self._feed_data(program, feed, feed_var_name, scope) |
340 | 344 | self.executor.run(program.desc, scope, 0, True, True) |
|
0 commit comments