Skip to content

Commit 6085eff

Browse files
committed
add benchmark for maxout/put_along_axis/take_along_axis
1 parent ccf57d5 commit 6085eff

File tree

6 files changed

+195
-0
lines changed

6 files changed

+195
-0
lines changed

api/dynamic_tests_v2/maxout.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from common_import import *
16+
17+
18+
class MaxoutConfig(APIConfig):
19+
def __init__(self):
20+
super(MaxoutConfig, self).__init__('maxout')
21+
self.feed_spec = {"range": [-1, 1]}
22+
self.run_torch = False
23+
24+
25+
class PaddleMaxout(PaddleDynamicAPIBenchmarkBase):
26+
def build_graph(self, config):
27+
x = self.variable(name="x", shape=config.x_shape, dtype=config.x_dtype)
28+
result = paddle.nn.functional.maxout(x=x,
29+
groups=config.groups,
30+
axis=config.axis)
31+
32+
self.feed_list = [x]
33+
self.fetch_list = [result]
34+
if config.backward:
35+
self.append_gradients(result, [x])
36+
37+
38+
if __name__ == '__main__':
39+
test_main(pd_dy_obj=PaddleMaxout(), config=MaxoutConfig())
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from common_import import *
16+
17+
18+
class PutAlongAxisConfig(APIConfig):
19+
def __init__(self):
20+
super(PutAlongAxisConfig, self).__init__("put_along_axis")
21+
self.feed_spec = [{"range": [-1, 1]}, {"range": [0, 100]}]
22+
self.run_torch = False
23+
24+
25+
class PaddlePutAlongAxis(PaddleDynamicAPIBenchmarkBase):
26+
def build_graph(self, config):
27+
arr = self.variable(name="arr",
28+
shape=config.arr_shape,
29+
dtype=config.arr_dtype)
30+
indices = self.variable(name="indices",
31+
shape=config.indices_shape,
32+
dtype=config.indices_dtype)
33+
result = paddle.put_along_axis(arr=arr,
34+
indices=indices,
35+
values=config.values,
36+
axis=config.axis)
37+
38+
self.feed_list = [arr]
39+
self.fetch_list = [result]
40+
if config.backward:
41+
self.append_gradients(result, [arr])
42+
43+
44+
if __name__ == '__main__':
45+
test_main(pd_dy_obj=PaddlePutAlongAxis(), config=PutAlongAxisConfig())
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from common_import import *
16+
17+
18+
class TakeAlongAxisConfig(APIConfig):
19+
def __init__(self):
20+
super(TakeAlongAxisConfig, self).__init__("take_along_axis")
21+
self.feed_spec = [{"range": [-1, 1]}, {"range": [0, 100]}]
22+
self.run_torch = False
23+
24+
25+
class PaddleTakeAlongAxis(PaddleDynamicAPIBenchmarkBase):
26+
def build_graph(self, config):
27+
arr = self.variable(name="arr",
28+
shape=config.arr_shape,
29+
dtype=config.arr_dtype)
30+
indices = self.variable(name="indices",
31+
shape=config.indices_shape,
32+
dtype=config.indices_dtype)
33+
result = paddle.take_along_axis(arr=arr,
34+
indices=indices,
35+
axis=config.axis)
36+
37+
self.feed_list = [arr, indices]
38+
self.fetch_list = [result]
39+
if config.backward:
40+
self.append_gradients(result, [arr])
41+
42+
43+
if __name__ == '__main__':
44+
test_main(pd_dy_obj=PaddleTakeAlongAxis(), config=TakeAlongAxisConfig())

api/tests_v2/configs/maxout.json

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[
2+
{
3+
"op": "maxout",
4+
"param_info": {
5+
"groups": {
6+
"type": "int",
7+
"value": "2"
8+
},
9+
"axis": {
10+
"type": "int",
11+
"value": "-1"
12+
},
13+
"x": {
14+
"dtype": "float32",
15+
"shape": "[32L, 12L, 128L, 128L]",
16+
"type": "Variable"
17+
}
18+
},
19+
"repeat": 5000
20+
}
21+
]
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
[
2+
{
3+
"op": "put_along_axis",
4+
"param_info": {
5+
"arr": {
6+
"dtype": "float32",
7+
"shape": "[200, 300]",
8+
"type": "Variable"
9+
},
10+
"indices": {
11+
"dtype": "int",
12+
"shape": "[1, 1]",
13+
"type": "Variable"
14+
},
15+
"values": {
16+
"type": "float32",
17+
"value": "99"
18+
},
19+
"axis": {
20+
"type": "int",
21+
"value": "0"
22+
}
23+
}
24+
}
25+
]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[
2+
{
3+
"op": "take_along_axis",
4+
"param_info": {
5+
"arr": {
6+
"dtype": "float32",
7+
"shape": "[200, 300]",
8+
"type": "Variable"
9+
},
10+
"indices": {
11+
"dtype": "int",
12+
"shape": "[1, 1]",
13+
"type": "Variable"
14+
},
15+
"axis": {
16+
"type": "int",
17+
"value": "0"
18+
}
19+
}
20+
}
21+
]

0 commit comments

Comments
 (0)