Skip to content

Commit 1db188f

Browse files
authored
[IPU] update ipu unittests p0 (#39707)
* update ipu UTs part0 * rename UT * sync api changes * update uts for new api * use_ipumodel() as classmethod
1 parent 0c3f7fb commit 1db188f

13 files changed

+950
-1233
lines changed

python/paddle/fluid/tests/unittests/ipu/ernie_training.py

Lines changed: 0 additions & 934 deletions
This file was deleted.

python/paddle/fluid/tests/unittests/ipu/op_test_ipu.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
import random
1617
import unittest
17-
1818
import numpy as np
19-
from paddle.fluid.tests.unittests.op_test import _set_use_system_allocator
20-
from typing import Optional
21-
import paddle.fluid.compiler as compiler
22-
23-
SEED = 2021
19+
from enum import Enum
2420

25-
ipu_compiler_ref: Optional[compiler.IPUCompiledProgram] = None
21+
import paddle
22+
import paddle.static
2623

2724
map_np_dtype_to_fluid_dtype = {
2825
'bool': "bool",
@@ -36,36 +33,84 @@
3633
}
3734

3835

36+
class ExecutionMode(Enum):
37+
CPU_FP32 = 1
38+
IPU_FP32 = 2
39+
# enable_fp16 through ipu_strategy.enable_fp16
40+
IPU_POPART_FP16 = 3
41+
42+
def __lt__(self, other):
43+
return self.value < other.value
44+
45+
def __gt__(self, other):
46+
return self.value > other.value
47+
48+
3949
def np_dtype_to_fluid_str(dtype: np.dtype) -> str:
4050
return map_np_dtype_to_fluid_dtype[dtype.name]
4151

4252

4353
class IPUOpTest(unittest.TestCase):
4454
@classmethod
4555
def setUpClass(cls):
56+
# Get random seeds
4657
cls._np_rand_state = np.random.get_state()
4758
cls._py_rand_state = random.getstate()
4859

49-
cls.SEED = SEED
60+
cls.SEED = 2021
5061
np.random.seed(cls.SEED)
5162
random.seed(cls.SEED)
5263

53-
cls._use_system_allocator = _set_use_system_allocator(True)
64+
# Enable paddle static graph mode
65+
paddle.enable_static()
5466

5567
@classmethod
5668
def tearDownClass(cls):
5769
"""Restore random seeds"""
5870
np.random.set_state(cls._np_rand_state)
5971
random.setstate(cls._py_rand_state)
6072

61-
_set_use_system_allocator(cls._use_system_allocator)
62-
# unittest will to trigger IPUCompiledProgram.__del__ automatically
63-
global ipu_compiler_ref
64-
ipu_compiler_ref is not None and ipu_compiler_ref.clean()
73+
@classmethod
74+
def use_ipumodel(cls):
75+
if 'POPLAR_IPUMODEL' not in os.environ:
76+
return False
77+
else:
78+
flag = os.environ['POPLAR_IPUMODEL']
79+
if flag.upper() in ['1', "TRUE"]:
80+
return True
6581

6682
def set_atol(self):
67-
self.atol = 1e-5
83+
self.atol = 1e-10
84+
self.rtol = 1e-6
85+
self.atol_fp16 = 1e-3
86+
self.rtol_fp16 = 1e-3
6887

6988
def set_training(self):
7089
self.is_training = False
7190
self.epoch = 1
91+
92+
def check(self, outputs, check_shape=False):
93+
cpu_fp32 = outputs[ExecutionMode.CPU_FP32]
94+
ipu_fp32 = outputs[ExecutionMode.IPU_FP32]
95+
max_diff = np.abs(cpu_fp32 - ipu_fp32).max()
96+
fp32_flag = np.allclose(
97+
cpu_fp32, ipu_fp32, rtol=self.rtol, atol=self.atol)
98+
self.assertTrue(fp32_flag, "max diff is %f" % (max_diff))
99+
100+
if check_shape:
101+
self.assertTrue(cpu_fp32.shape == ipu_fp32.shape)
102+
103+
ipu_popart_fp16 = None
104+
if ExecutionMode.IPU_POPART_FP16 in outputs.keys():
105+
ipu_popart_fp16 = outputs[ExecutionMode.IPU_POPART_FP16]
106+
max_diff = np.abs(ipu_popart_fp16.astype(np.float32) -
107+
cpu_fp32).max()
108+
fp16_flag = np.allclose(
109+
ipu_popart_fp16.astype(np.float32),
110+
cpu_fp32,
111+
rtol=self.rtol_fp16,
112+
atol=self.atol_fp16)
113+
self.assertTrue(fp16_flag, "max diff is %f" % (max_diff))
114+
115+
if check_shape:
116+
self.assertTrue(ipu_popart_fp16.shape == cpu_fp32.shape)
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import paddle
19+
import paddle.nn.functional as F
20+
import paddle.static
21+
from paddle.fluid.tests.unittests.ipu.op_test_ipu import (ExecutionMode,
22+
IPUOpTest)
23+
24+
25+
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
26+
"core is not compiled with IPU")
27+
class TestRelu(IPUOpTest):
28+
def setUp(self):
29+
self.set_atol()
30+
self.set_test_op()
31+
self.set_training()
32+
self.set_data_feed()
33+
self.set_feed_attr()
34+
35+
@property
36+
def fp16_enabled(self):
37+
return True
38+
39+
def set_test_op(self):
40+
self.op = paddle.fluid.layers.relu
41+
self.op_attrs = {}
42+
43+
def set_data_feed(self):
44+
data = np.random.uniform(size=[1, 3, 10, 10])
45+
self.feed_fp32 = {'in_0': data.astype(np.float32)}
46+
self.feed_fp16 = {'in_0': data.astype(np.float16)}
47+
48+
def set_feed_attr(self):
49+
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
50+
self.feed_list = list(self.feed_fp32.keys())
51+
52+
def _test_base(self, exec_mode):
53+
scope = paddle.static.Scope()
54+
main_prog = paddle.static.Program()
55+
startup_prog = paddle.static.Program()
56+
main_prog.random_seed = self.SEED
57+
startup_prog.random_seed = self.SEED
58+
59+
with paddle.static.scope_guard(scope):
60+
with paddle.static.program_guard(main_prog, startup_prog):
61+
x = paddle.static.data(
62+
name=self.feed_list[0],
63+
shape=self.feed_shape[0],
64+
dtype='float32')
65+
66+
out = self.op(x, **self.op_attrs)
67+
68+
fetch_list = [out.name]
69+
70+
if exec_mode == ExecutionMode.CPU_FP32:
71+
place = paddle.CPUPlace()
72+
else:
73+
place = paddle.IPUPlace()
74+
75+
exe = paddle.static.Executor(place)
76+
exe.run(startup_prog)
77+
78+
if exec_mode != ExecutionMode.CPU_FP32:
79+
feed_list = self.feed_list
80+
ipu_strategy = paddle.static.IpuStrategy()
81+
82+
ipu_strategy.set_graph_config(is_training=self.is_training)
83+
if exec_mode == ExecutionMode.IPU_POPART_FP16:
84+
ipu_strategy.set_precision_config(enable_fp16=True)
85+
program = paddle.static.IpuCompiledProgram(
86+
main_prog,
87+
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
88+
else:
89+
program = main_prog
90+
91+
feed = self.feed_fp32
92+
if exec_mode > ExecutionMode.IPU_FP32:
93+
feed = self.feed_fp16
94+
95+
result = exe.run(program, feed=feed, fetch_list=fetch_list)
96+
return result[0]
97+
98+
def test(self):
99+
output_dict = {}
100+
for mode in ExecutionMode:
101+
if mode > ExecutionMode.IPU_FP32 and not self.fp16_enabled:
102+
break
103+
output_dict[mode] = self._test_base(mode).flatten()
104+
105+
self.check(output_dict)
106+
107+
108+
class TestTanh(TestRelu):
109+
def set_test_op(self):
110+
self.op = F.tanh
111+
self.op_attrs = {}
112+
113+
114+
class TestLog(TestRelu):
115+
def set_test_op(self):
116+
self.op = paddle.fluid.layers.log
117+
self.op_attrs = {}
118+
119+
120+
class TestSigmoid(TestRelu):
121+
def set_test_op(self):
122+
self.op = F.sigmoid
123+
self.op_attrs = {}
124+
125+
126+
class TestSqrt(TestRelu):
127+
def set_test_op(self):
128+
self.op = paddle.fluid.layers.sqrt
129+
self.op_attrs = {}
130+
131+
132+
if __name__ == "__main__":
133+
unittest.main()
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import paddle
19+
import paddle.static
20+
from paddle.fluid.tests.unittests.ipu.op_test_ipu import (ExecutionMode,
21+
IPUOpTest)
22+
23+
24+
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
25+
"core is not compiled with IPU")
26+
class TestBase(IPUOpTest):
27+
def setUp(self):
28+
self.set_atol()
29+
self.set_training()
30+
self.set_data_feed()
31+
self.set_feed_attr()
32+
self.set_op_attrs()
33+
34+
@property
35+
def fp16_enabled(self):
36+
return True
37+
38+
def set_data_feed(self):
39+
data = np.random.uniform(size=[10, 1000])
40+
self.feed_fp32 = {"in_0": data.astype(np.float32)}
41+
self.feed_fp16 = {"in_0": data.astype(np.float16)}
42+
43+
def set_feed_attr(self):
44+
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
45+
self.feed_list = list(self.feed_fp32.keys())
46+
self.feed_dtype = [x.dtype for x in self.feed_fp32.values()]
47+
48+
def set_op_attrs(self):
49+
self.attrs = {"axis": -1}
50+
51+
def _test_base(self, exec_mode):
52+
scope = paddle.static.Scope()
53+
main_prog = paddle.static.Program()
54+
startup_prog = paddle.static.Program()
55+
main_prog.random_seed = self.SEED
56+
startup_prog.random_seed = self.SEED
57+
58+
with paddle.static.scope_guard(scope):
59+
with paddle.static.program_guard(main_prog, startup_prog):
60+
x = paddle.static.data(
61+
name=self.feed_list[0],
62+
shape=self.feed_shape[0],
63+
dtype='float32')
64+
65+
out = paddle.fluid.layers.argmax(x, **self.attrs)
66+
67+
fetch_list = [out.name]
68+
69+
if exec_mode == ExecutionMode.CPU_FP32:
70+
place = paddle.CPUPlace()
71+
else:
72+
place = paddle.IPUPlace()
73+
74+
exe = paddle.static.Executor(place)
75+
exe.run(startup_prog)
76+
77+
if exec_mode != ExecutionMode.CPU_FP32:
78+
feed_list = self.feed_list
79+
ipu_strategy = paddle.static.IpuStrategy()
80+
ipu_strategy.set_graph_config(is_training=self.is_training)
81+
if exec_mode == ExecutionMode.IPU_POPART_FP16:
82+
ipu_strategy.set_precision_config(enable_fp16=True)
83+
program = paddle.static.IpuCompiledProgram(
84+
main_prog,
85+
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
86+
else:
87+
program = main_prog
88+
89+
feed = self.feed_fp32
90+
if exec_mode > ExecutionMode.IPU_FP32:
91+
feed = self.feed_fp16
92+
93+
result = exe.run(program, feed=feed, fetch_list=fetch_list)
94+
return result[0].astype(np.int32)
95+
96+
def test_base(self):
97+
output_dict_fp32 = {}
98+
output_dict_fp16 = {}
99+
for mode in ExecutionMode:
100+
if mode > ExecutionMode.IPU_FP32 and not self.fp16_enabled:
101+
break
102+
103+
if mode > ExecutionMode.IPU_FP32:
104+
output_dict_fp16[mode] = self._test_base(mode).flatten()
105+
else:
106+
output_dict_fp32[mode] = self._test_base(mode).flatten()
107+
108+
self.check(output_dict_fp32)
109+
110+
111+
class TestCase1(TestBase):
112+
def set_op_attrs(self):
113+
self.attrs = {"axis": 0}
114+
115+
116+
if __name__ == "__main__":
117+
unittest.main()

0 commit comments

Comments
 (0)