Skip to content

Commit 8fb53b7

Browse files
committed
Merge branch 'develop' of github.com:baidu/Paddle into feature/make_recognize_digits_normal_unittest
2 parents d11e7b4 + 270ecbe commit 8fb53b7

File tree

3 files changed

+147
-88
lines changed

3 files changed

+147
-88
lines changed

python/paddle/v2/fluid/tests/book/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
22
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
33

4-
list(REMOVE_ITEM TEST_OPS test_image_classification_train)
5-
py_test(test_image_classification_train_resnet SRCS test_image_classification_train.py ARGS resnet)
6-
py_test(test_image_classification_train_vgg SRCS test_image_classification_train.py ARGS vgg)
7-
84
# default test
95
foreach(src ${TEST_OPS})
106
py_test(${src} SRCS ${src}.py)

python/paddle/v2/fluid/tests/book/test_fit_a_line.py

Lines changed: 59 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,44 +12,74 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import numpy as np
1615
import paddle.v2 as paddle
1716
import paddle.v2.fluid as fluid
17+
import contextlib
18+
import unittest
1819

19-
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
2020

21-
y_predict = fluid.layers.fc(input=x, size=1, act=None)
21+
def main(use_cuda):
22+
if use_cuda and not fluid.core.is_compiled_with_cuda():
23+
return
2224

23-
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
25+
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
2426

25-
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
26-
avg_cost = fluid.layers.mean(x=cost)
27+
y_predict = fluid.layers.fc(input=x, size=1, act=None)
2728

28-
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
29-
sgd_optimizer.minimize(avg_cost)
29+
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
3030

31-
BATCH_SIZE = 20
31+
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
32+
avg_cost = fluid.layers.mean(x=cost)
3233

33-
train_reader = paddle.batch(
34-
paddle.reader.shuffle(
35-
paddle.dataset.uci_housing.train(), buf_size=500),
36-
batch_size=BATCH_SIZE)
34+
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
35+
sgd_optimizer.minimize(avg_cost)
3736

38-
place = fluid.CPUPlace()
39-
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
40-
exe = fluid.Executor(place)
37+
BATCH_SIZE = 20
4138

42-
exe.run(fluid.default_startup_program())
39+
train_reader = paddle.batch(
40+
paddle.reader.shuffle(
41+
paddle.dataset.uci_housing.train(), buf_size=500),
42+
batch_size=BATCH_SIZE)
4343

44-
PASS_NUM = 100
45-
for pass_id in range(PASS_NUM):
46-
fluid.io.save_persistables(exe, "./fit_a_line.model/")
47-
fluid.io.load_persistables(exe, "./fit_a_line.model/")
48-
for data in train_reader():
49-
avg_loss_value, = exe.run(fluid.default_main_program(),
50-
feed=feeder.feed(data),
51-
fetch_list=[avg_cost])
52-
print(avg_loss_value)
53-
if avg_loss_value[0] < 10.0:
54-
exit(0) # if avg cost less than 10.0, we think our code is good.
55-
exit(1)
44+
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
45+
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
46+
exe = fluid.Executor(place)
47+
48+
exe.run(fluid.default_startup_program())
49+
50+
PASS_NUM = 100
51+
for pass_id in range(PASS_NUM):
52+
fluid.io.save_persistables(exe, "./fit_a_line.model/")
53+
fluid.io.load_persistables(exe, "./fit_a_line.model/")
54+
for data in train_reader():
55+
avg_loss_value, = exe.run(fluid.default_main_program(),
56+
feed=feeder.feed(data),
57+
fetch_list=[avg_cost])
58+
print(avg_loss_value)
59+
if avg_loss_value[0] < 10.0:
60+
return
61+
raise AssertionError("Fit a line cost is too large, {0:2.2}".format(
62+
avg_loss_value[0]))
63+
64+
65+
class TestFitALine(unittest.TestCase):
66+
def test_cpu(self):
67+
with self.program_scope_guard():
68+
main(use_cuda=False)
69+
70+
def test_cuda(self):
71+
with self.program_scope_guard():
72+
main(use_cuda=True)
73+
74+
@contextlib.contextmanager
75+
def program_scope_guard(self):
76+
prog = fluid.Program()
77+
startup_prog = fluid.Program()
78+
scope = fluid.core.Scope()
79+
with fluid.scope_guard(scope):
80+
with fluid.program_guard(prog, startup_prog):
81+
yield
82+
83+
84+
if __name__ == '__main__':
85+
unittest.main()

python/paddle/v2/fluid/tests/book/test_image_classification_train.py

Lines changed: 88 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414

1515
from __future__ import print_function
1616

17-
import sys
18-
1917
import paddle.v2 as paddle
2018
import paddle.v2.fluid as fluid
19+
import unittest
20+
import contextlib
2121

2222

2323
def resnet_cifar10(input, depth=32):
@@ -89,56 +89,89 @@ def conv_block(input, num_filter, groups, dropouts):
8989
return fc2
9090

9191

92-
classdim = 10
93-
data_shape = [3, 32, 32]
94-
95-
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
96-
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
97-
98-
net_type = "vgg"
99-
if len(sys.argv) >= 2:
100-
net_type = sys.argv[1]
101-
102-
if net_type == "vgg":
103-
print("train vgg net")
104-
net = vgg16_bn_drop(images)
105-
elif net_type == "resnet":
106-
print("train resnet")
107-
net = resnet_cifar10(images, 32)
108-
else:
109-
raise ValueError("%s network is not supported" % net_type)
110-
111-
predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
112-
cost = fluid.layers.cross_entropy(input=predict, label=label)
113-
avg_cost = fluid.layers.mean(x=cost)
114-
115-
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
116-
opts = optimizer.minimize(avg_cost)
117-
118-
accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
119-
120-
BATCH_SIZE = 128
121-
PASS_NUM = 1
122-
123-
train_reader = paddle.batch(
124-
paddle.reader.shuffle(
125-
paddle.dataset.cifar.train10(), buf_size=128 * 10),
126-
batch_size=BATCH_SIZE)
127-
128-
place = fluid.CPUPlace()
129-
exe = fluid.Executor(place)
130-
feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
131-
exe.run(fluid.default_startup_program())
132-
133-
for pass_id in range(PASS_NUM):
134-
accuracy.reset(exe)
135-
for data in train_reader():
136-
loss, acc = exe.run(fluid.default_main_program(),
137-
feed=feeder.feed(data),
138-
fetch_list=[avg_cost] + accuracy.metrics)
139-
pass_acc = accuracy.eval(exe)
140-
print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
141-
pass_acc))
142-
# this model is slow, so if we can train two mini batch, we think it works properly.
143-
exit(0)
144-
exit(1)
92+
def main(net_type, use_cuda):
93+
if use_cuda and not fluid.core.is_compiled_with_cuda():
94+
return
95+
96+
classdim = 10
97+
data_shape = [3, 32, 32]
98+
99+
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
100+
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
101+
102+
if net_type == "vgg":
103+
print("train vgg net")
104+
net = vgg16_bn_drop(images)
105+
elif net_type == "resnet":
106+
print("train resnet")
107+
net = resnet_cifar10(images, 32)
108+
else:
109+
raise ValueError("%s network is not supported" % net_type)
110+
111+
predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
112+
cost = fluid.layers.cross_entropy(input=predict, label=label)
113+
avg_cost = fluid.layers.mean(x=cost)
114+
115+
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
116+
optimizer.minimize(avg_cost)
117+
118+
accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
119+
120+
BATCH_SIZE = 128
121+
PASS_NUM = 1
122+
123+
train_reader = paddle.batch(
124+
paddle.reader.shuffle(
125+
paddle.dataset.cifar.train10(), buf_size=128 * 10),
126+
batch_size=BATCH_SIZE)
127+
128+
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
129+
exe = fluid.Executor(place)
130+
feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
131+
exe.run(fluid.default_startup_program())
132+
133+
loss = 0.0
134+
for pass_id in range(PASS_NUM):
135+
accuracy.reset(exe)
136+
for data in train_reader():
137+
loss, acc = exe.run(fluid.default_main_program(),
138+
feed=feeder.feed(data),
139+
fetch_list=[avg_cost] + accuracy.metrics)
140+
pass_acc = accuracy.eval(exe)
141+
print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
142+
pass_acc))
143+
return
144+
145+
raise AssertionError(
146+
"Image classification loss is too large, {0:2.2}".format(loss))
147+
148+
149+
class TestImageClassification(unittest.TestCase):
150+
def test_vgg_cuda(self):
151+
with self.scope_prog_guard():
152+
main('vgg', use_cuda=True)
153+
154+
def test_resnet_cuda(self):
155+
with self.scope_prog_guard():
156+
main('resnet', use_cuda=True)
157+
158+
def test_vgg_cpu(self):
159+
with self.scope_prog_guard():
160+
main('vgg', use_cuda=False)
161+
162+
def test_resnet_cpu(self):
163+
with self.scope_prog_guard():
164+
main('resnet', use_cuda=False)
165+
166+
@contextlib.contextmanager
167+
def scope_prog_guard(self):
168+
prog = fluid.Program()
169+
startup_prog = fluid.Program()
170+
scope = fluid.core.Scope()
171+
with fluid.scope_guard(scope):
172+
with fluid.program_guard(prog, startup_prog):
173+
yield
174+
175+
176+
if __name__ == '__main__':
177+
unittest.main()

0 commit comments

Comments
 (0)