55import minibatch
66from data_feeder import DataFeeder
77
8- __all__ = ['infer' ]
8+ __all__ = ['infer' , 'Inference' ]
99
1010
1111class Inference (object ):
1212 """
1313 Inference combines neural network output and parameters together
1414 to do inference.
15+
16+ .. code-block:: python
17+
18+ inferer = Inference(output_layer=prediction, parameters=parameters)
19+ for data_batch in batches:
20+ print inferer.infer(data_batch)
21+
1522
16- :param outptut_layer : The neural network that should be inferenced.
23+ :param output_layer : The neural network that should be inferenced.
1724 :type output_layer: paddle.v2.config_base.Layer or the sequence
1825 of paddle.v2.config_base.Layer
1926 :param parameters: The parameters dictionary.
@@ -56,8 +63,14 @@ def iter_infer_field(self, field, **kwargs):
5663 item = [each_result [each_field ] for each_field in field ]
5764 yield item
5865
59- def infer (self , field = 'value' , ** kwargs ):
66+ def infer (self , input , field = 'value' , ** kwargs ):
67+ """
68+ Infer a data by model.
69+ :param input: input data batch. Should be python iterable object.
70+ :param field: output field.
71+ """
6072 retv = None
73+ kwargs ['input' ] = input
6174 for result in self .iter_infer_field (field = field , ** kwargs ):
6275 if retv is None :
6376 retv = [[] for i in xrange (len (result ))]
@@ -79,7 +92,7 @@ def infer(output_layer, parameters, input, feeding=None, field='value'):
7992
8093 .. code-block:: python
8194
82- result = paddle.infer(outptut_layer =prediction,
95+ result = paddle.infer(output_layer =prediction,
8396 parameters=parameters,
8497 input=SomeData)
8598 print result
0 commit comments