1313# limitations under the License.
1414
1515from py_paddle import DataProviderConverter
16-
16+ import collections
1717import paddle .trainer .PyDataProvider2 as pydp2
1818
1919__all__ = ['DataFeeder' ]
@@ -35,15 +35,30 @@ class DataFeeder(DataProviderConverter):
3535 DataFeeder converts this mini-batch data entries into Arguments in order
3636 to feed it to C++ interface.
3737
38- The example usage:
38+ The simple usage shows below
39+
40+ .. code-block:: python
41+
42+ feeding = ['image', 'label']
43+ data_types = enumerate_data_types_of_data_layers(topology)
44+ feeder = DataFeeder(data_types=data_types, feeding=feeding)
45+
46+ minibatch_data = [([1.0, 2.0, 3.0, ...], 5)]
47+
48+ arg = feeder(minibatch_data)
49+
50+
51+ If mini-batch data and data layers are not one to one mapping, we
52+ could pass a dictionary to feeding parameter to represent the mapping
53+ relationship.
3954
4055
4156 .. code-block:: python
4257
4358 data_types = [('image', paddle.data_type.dense_vector(784)),
4459 ('label', paddle.data_type.integer_value(10))]
45- reader_dict = {'image':0, 'label':1}
46- feeder = DataFeeder(data_types=data_types, reader_dict=reader_dict )
60+ feeding = {'image':0, 'label':1}
61+ feeder = DataFeeder(data_types=data_types, feeding=feeding )
4762 minibatch_data = [
4863 ( [1.0,2.0,3.0,4.0], 5, [6,7,8] ), # first sample
4964 ( [1.0,2.0,3.0,4.0], 5, [6,7,8] ) # second sample
@@ -65,16 +80,23 @@ class DataFeeder(DataProviderConverter):
6580 a tuple of (data_name, data_type).
6681
6782 :type data_types: list
68- :param reader_dict : A dictionary to specify the position of each data
69- in the input data.
70- :type feeding: dict
83+ :param feeding : A dictionary or a sequence to specify the position of each
84+ data in the input data.
85+ :type feeding: dict|collections.Sequence|None
7186 """
7287
7388 def __init__ (self , data_types , feeding = None ):
7489 self .input_names = []
7590 input_types = []
7691 if feeding is None :
7792 feeding = default_feeding_map (data_types )
93+ elif isinstance (feeding , collections .Sequence ):
94+ feed_list = feeding
95+ feeding = dict ()
96+ for i , name in enumerate (feed_list ):
97+ feeding [name ] = i
98+ elif not isinstance (feeding , dict ):
99+ raise TypeError ("Feeding should be dict or sequence or None." )
78100
79101 self .feeding = feeding
80102 for each in data_types :
0 commit comments