1313与train-with-paddle.py不同,这里不需要重新训练模型,只需要加载训练生成的parameters.tar 
1414文件来获取模型参数,对这组参数也就是训练完的模型进行检测。 
15151.载入数据和预处理:load_data() 
16- 2.初始化 
17- 3.配置网络结构 
18- 4.获取训练和测试数据 
19- 5.从parameters.tar文件直接获取模型参数 
20- 6.根据模型参数和测试数据来预测结果 
16+ 2.从parameters.tar文件直接获取模型参数 
17+ 3.初始化 
18+ 4.配置网络结构 
19+ 5.获取测试数据 
20+ 6.根据测试数据获得预测结果 
21+ 7.将预测结果转化为二分类结果 
22+ 8.预测图片是否为猫 
2123""" 
2224
2325import  os 
@@ -129,18 +131,6 @@ def get_data(data_creator):
129131 return  result 
130132
131133
132- # 获取test_data 
133- def  get_test_data ():
134-  """ 
135-  使用test()来获取测试数据 
136- 
137-  Args: 
138-  Return: 
139-  get_data(test()) -- 包含测试数据(image)和标签(label)的python字典 
140-  """ 
141-  return  get_data (test ())
142- 
143- 
144134# 二分类结果 
145135def  get_binary_result (probs ):
146136 """ 
@@ -160,35 +150,54 @@ def get_binary_result(probs):
160150 return  binary_result 
161151
162152
163- def  main ():
153+ # 配置网络结构和设置参数 
154+ def  netconfig ():
164155 """ 
165-  预测结果并检验模型准确率  
156+  配置网络结构和设置参数  
166157 Args: 
167158 Return: 
159+  image -- 输入层,DATADIM维稠密向量 
160+  y_predict -- 输出层,Sigmoid作为激活函数 
168161 """ 
169-  global  PARAMETERS 
170-  paddle .init (use_gpu = False , trainer_count = 1 )
171-  load_data ()
172-  if  not  os .path .exists ('params_pass_1900.tar' ):
173-  print ("Params file doesn't exists." )
174-  return 
175-  with  open ('params_pass_1900.tar' , 'r' ) as  f :
176-  PARAMETERS  =  paddle .parameters .Parameters .from_tar (f )
177- 
178162 # 输入层,paddle.layer.data表示数据层 
179163 # name=’image’:名称为image 
180164 # type=paddle.data_type.dense_vector(DATADIM):数据类型为DATADIM维稠密向量 
181165 image  =  paddle .layer .data (
182166 name = 'image' , type = paddle .data_type .dense_vector (DATADIM ))
183167
184-  # 输入层,paddle.layer.data表示数据层 
185-  # name=’label’:名称为image 
186-  # type=paddle.data_type.dense_vector(DATADIM):数据类型为DATADIM维向量 
168+  # 输出层,paddle.layer.fc表示全连接层,input=image: 该层输入数据为image 
169+  # size=1:神经元个数,act=paddle.activation.Sigmoid():激活函数为Sigmoid() 
187170 y_predict  =  paddle .layer .fc (
188171 input = image , size = 1 , act = paddle .activation .Sigmoid ())
189172
173+  data  =  [image , y_predict ]
174+ 
175+  return  data 
176+ 
177+ 
178+ def  main ():
179+  """ 
180+  预测结果并检验模型准确率 
181+  Args: 
182+  Return: 
183+  """ 
184+  global  PARAMETERS 
185+ 
186+  # 载入数据 
187+  load_data ()
188+ 
189+  # 载入参数 
190+  with  open ('params_pass_1920.tar' , 'r' ) as  f :
191+  PARAMETERS  =  paddle .parameters .Parameters .from_tar (f )
192+ 
193+  # 初始化 
194+  paddle .init (use_gpu = False , trainer_count = 1 )
195+ 
196+  # 配置网络结构 
197+  image , y_predict  =  netconfig ()
198+ 
190199 # 获取测试数据 
191-  test_data  =  get_test_data ( )
200+  test_data  =  get_data ( test () )
192201
193202 # 根据test_data预测结果 
194203 probs  =  paddle .infer (
@@ -198,6 +207,7 @@ def main():
198207 # 将结果转化为二分类结果 
199208 binary_result  =  get_binary_result (probs )
200209
210+  # 预测图片是否为猫 
201211 index  =  12 
202212 print  ("y = "  +  str (binary_result [index ]) +  ", you predicted that it is a \" "  + 
203213 CLASSES [binary_result [index ]].decode ("utf-8" ) +  "\"  picture." )
0 commit comments