11#!/usr/bin/env python2.7
22# -*- encoding:utf-8 -*-
33"""
4- 数据集读取脚本
4+ 数据集读取脚本
55Created on 2017-12-2
6- @author: PaddlePaddle CTR Model
7- @copyright: www.baidu.com
6+
87"""
98from utils import logger , TaskMode , load_dnn_input_record , load_lr_input_record
109
1413class Dataset (object ):
def train(self, path):
    '''
    Load the training set.

    Logs the source path, then passes `path` together with the mode
    object returned by TaskMode.create_train() to _parse_creator and
    returns its result unchanged.
    '''
    logger.info("load trainset from %s" % path)
    return self._parse_creator(path, TaskMode.create_train())
2221
def test(self, path):
    '''
    Load the test set.

    Logs the source path, then passes `path` together with the mode
    object returned by TaskMode.create_test() to _parse_creator and
    returns its result unchanged.
    '''
    logger.info("load testset from %s" % path)
    return self._parse_creator(path, TaskMode.create_test())
3029
def infer(self, path):
    '''
    Load the inference (prediction) set.

    Logs the source path, then passes `path` together with the mode
    object returned by TaskMode.create_infer() to _parse_creator and
    returns its result unchanged.
    '''
    logger.info("load inferset from %s" % path)
    return self._parse_creator(path, TaskMode.create_infer())
3837
3938 def _parse_creator (self , path , mode ):
4039 '''
41- 稀疏化数据集
40+ 稀疏化数据集
4241 '''
4342
4443 def _parse ():
@@ -58,7 +57,7 @@ def _parse():
5857
5958def load_data_meta (path ):
6059 '''
61- 从指定路径中读取meta数据,返回lr模型维度和dnn模型维度
60+ 从指定路径中读取meta数据,返回lr模型维度和dnn模型维度
6261 '''
6362 with open (path ) as f :
6463 lines = f .read ().split ('\n ' )
@@ -69,4 +68,4 @@ def load_data_meta(path):
6968 res = map (int , [_ .split (':' )[1 ] for _ in lines ])
7069 logger .info ('dnn input dim: %d' % res [0 ])
7170 logger .info ('lr input dim: %d' % res [1 ])
72- return res
71+ return res
0 commit comments