Skip to content

Commit 751bcd2

Browse files
authored
fix mnist example, test=develop (#119)
1 parent 64fea5a commit 751bcd2

File tree

3 files changed

+19
-17
lines changed

3 files changed

+19
-17
lines changed

example/distill/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ train_reader = dr.set_sample_list_generator(train_reader)
8282
```
8383
Then run the student code.
8484
``` python
85-
python train_student.py
85+
python train_with_fleet.py --use_distill_service True
8686
```
8787

8888
## On Kubernetes

example/distill/mnist_distill/README_CN.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ python train_with_fleet.py --use_distill_service True --distill_teachers 127.0.0
4343
teacher服务使用paddle_serving部署,需保存成serving模型。可以有两种方式获取(详见[如何保存Serving模型](https://github.com/PaddlePaddle/Serving/blob/develop/doc/SAVE.md))。
4444
1. 直接在训练中保存serving模型。
4545
``` bash
46-
python train_with_fleet.py --save_serving_model
46+
python train_with_fleet.py --save_serving_model True
4747
```
48-
保存的代码见[train_with_fleet.py](train_with_fleet.py)。模型输入为img,模型输出为prediction,mnist_model为serving模型的目录
49-
serving_conf为保存的client配置文件。
48+
保存的代码见[train_with_fleet.py](train_with_fleet.py)。模型输入为img,模型输出为prediction。
49+
模型保存到output目录,mnist_cnn_model为保存的serving模型,serving_conf为保存的client配置文件。
5050
``` bash
51-
serving_io.save_model("mnist_cnn_model", "serving_conf",
51+
serving_io.save_model("output/mnist_cnn_model", "output/serving_conf",
5252
{img.name: img}, {prediction.name: prediction},
5353
test_program)
5454
```
@@ -103,7 +103,7 @@ dr.set_dynamic_teacher(discovery_servers, teacher_service_name)
103103
``` python
104104
python -m paddle_edl.distill.redis.balance_server \
105105
--server 127.0.0.1:7001 \
106-
--db_endpoints 127.0.0.1:6379```
106+
--db_endpoints 127.0.0.1:6379
107107
```
108108
#### 4.2 服务注册
109109
在已启动好teacher后,需要往redis数据库注册teacher服务。

example/distill/mnist_distill/train_with_fleet.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,11 @@ def train(nn_type,
154154
train_nranks = fleet.worker_num()
155155

156156
optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
157-
if train_nranks != 1:
157+
if use_cuda:
158158
optimizer = fleet.distributed_optimizer(optimizer)
159159
optimizer.minimize(loss)
160160

161-
main_program = fleet.main_program if train_nranks != 1 else main_program
161+
main_program = fleet.main_program if use_cuda else main_program
162162
gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
163163

164164
def train_test(train_test_program, train_test_reader):
@@ -183,20 +183,20 @@ def train_test(train_test_program, train_test_reader):
183183
py_train_reader.set_sample_list_generator(train_reader, reader_places)
184184
py_test_reader = fluid.io.DataLoader.from_generator(
185185
feed_list=test_inputs, capacity=64)
186-
py_test_reader.set_sample_list_generator(test_reader, reader_places)
186+
py_test_reader.set_sample_list_generator(test_reader, place)
187187

188188
exe = fluid.Executor(place)
189189
exe.run(startup_program)
190-
epochs = [epoch_id for epoch_id in range(PASS_NUM)]
190+
epochs = [epoch_id for epoch_id in range(NUM_EPOCHS)]
191191

192192
lists = []
193193
step = 0
194194
for epoch_id in epochs:
195195
for step_id, data in enumerate(py_train_reader()):
196196
metrics = exe.run(main_program, feed=data, fetch_list=[loss, acc])
197197
if step % 100 == 0:
198-
print("Pass {}, Epoch {}, Cost {}".format(step, epoch_id,
199-
metrics[0]))
198+
print("Pass {}, Step {}, Cost {}".format(epoch_id, step,
199+
metrics[0].mean()))
200200
step += 1
201201

202202
if train_rank == 0:
@@ -205,7 +205,7 @@ def train_test(train_test_program, train_test_reader):
205205
train_test_program=test_program,
206206
train_test_reader=py_test_reader)
207207

208-
print("Test with Epoch %d, avg_cost: %s, acc: %s" %
208+
print("Test with Pass %d, avg_cost: %s, acc: %s" %
209209
(epoch_id, avg_loss_val, acc_val))
210210
lists.append((epoch_id, avg_loss_val, acc_val))
211211
if save_dirname is not None:
@@ -218,15 +218,17 @@ def train_test(train_test_program, train_test_reader):
218218
if train_rank == 0:
219219
if args.save_serving_model:
220220
import paddle_serving_client.io as serving_io
221-
serving_io.save_model("mnist_cnn_model", "serving_conf",
222-
{img.name: img},
221+
if not os.path.isdir('output'):
222+
os.mkdir('output')
223+
serving_io.save_model("output/mnist_cnn_model",
224+
"output/serving_conf", {img.name: img},
223225
{prediction.name: prediction}, test_program)
224226
print('save serving model, feed_names={}, fetch_names={}'.format(
225227
[img.name], [prediction.name]))
226228

227229
# find the best pass
228230
best = sorted(lists, key=lambda list: float(list[1]))[0]
229-
print('Best pass is %s, testing Avgcost is %s' % (best[0], best[1]))
231+
print('Best pass is %s, testing Avg cost is %s' % (best[0], best[1]))
230232
print('The classification accuracy is %.2f%%' % (float(best[2]) * 100))
231233

232234

@@ -292,7 +294,7 @@ def main(use_cuda, nn_type):
292294
if __name__ == '__main__':
293295
args = parse_args()
294296
BATCH_SIZE = 64
295-
PASS_NUM = args.num_epochs
297+
NUM_EPOCHS = args.num_epochs
296298
use_cuda = args.use_gpu
297299
# predict = 'softmax_regression' # uncomment for Softmax
298300
#predict = 'multilayer_perceptron' # uncomment for MLP

0 commit comments

Comments
 (0)