Skip to content

Commit 0076197

Browse files
authored
Merge pull request PaddlePaddle#2067 from reyoung/feature/expose_inference_in_py_api
Expose Inference in Python V2 API.
2 parents aaef586 + 43493b2 commit 0076197

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

python/paddle/v2/inference.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,22 @@
55
import minibatch
66
from data_feeder import DataFeeder
77

8-
__all__ = ['infer']
8+
__all__ = ['infer', 'Inference']
99

1010

1111
class Inference(object):
1212
"""
1313
Inference combines neural network output and parameters together
1414
to do inference.
15+
16+
.. code-block:: python
17+
18+
inferer = Inference(output_layer=prediction, parameters=parameters)
19+
for data_batch in batches:
20+
print inferer.infer(data_batch)
21+
1522
16-
:param outptut_layer: The neural network that should be inferenced.
23+
:param output_layer: The neural network that should be inferenced.
1724
:type output_layer: paddle.v2.config_base.Layer or the sequence
1825
of paddle.v2.config_base.Layer
1926
:param parameters: The parameters dictionary.
@@ -56,8 +63,14 @@ def iter_infer_field(self, field, **kwargs):
5663
item = [each_result[each_field] for each_field in field]
5764
yield item
5865

59-
def infer(self, field='value', **kwargs):
66+
def infer(self, input, field='value', **kwargs):
67+
"""
68+
Infer a data by model.
69+
:param input: input data batch. Should be python iterable object.
70+
:param field: output field.
71+
"""
6072
retv = None
73+
kwargs['input'] = input
6174
for result in self.iter_infer_field(field=field, **kwargs):
6275
if retv is None:
6376
retv = [[] for i in xrange(len(result))]
@@ -79,7 +92,7 @@ def infer(output_layer, parameters, input, feeding=None, field='value'):
7992
8093
.. code-block:: python
8194
82-
result = paddle.infer(outptut_layer=prediction,
95+
result = paddle.infer(output_layer=prediction,
8396
parameters=parameters,
8497
input=SomeData)
8598
print result

0 commit comments

Comments
 (0)