Skip to content

Commit f876320

Browse files
authored
support code auto-gene for sparse backward api (#40196)
1 parent d4b007a commit f876320

File tree

6 files changed

+235
-9
lines changed

6 files changed

+235
-9
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ paddle/fluid/eager/api/generated/*
66
paddle/fluid/op_use_default_grad_maker_DEV.spec
77
paddle/fluid/op_use_default_grad_maker_PR.spec
88
paddle/phi/api/backward/backward_api.h
9+
paddle/phi/api/backward/sparse_bw_api.h
910
paddle/phi/api/include/api.h
1011
paddle/phi/api/include/sparse_api.h
1112
paddle/phi/api/lib/api.cc
1213
paddle/phi/api/lib/dygraph_api.*
1314
paddle/phi/api/lib/backward_api.cc
1415
paddle/phi/api/lib/sparse_api.cc
16+
paddle/phi/api/lib/sparse_bw_api.cc
1517
paddle/phi/extension.h
1618
paddle/phi/include/*
1719
paddle/phi/infermeta/generated.*

paddle/phi/api/lib/CMakeLists.txt

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ set(sparse_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/sparse_api.cc)
4040
set(sparse_api_header_file_tmp ${api_header_file}.tmp)
4141
set(sparse_api_source_file_tmp ${api_source_file}.tmp)
4242

43+
# sparse bw api file
44+
set(sparse_bw_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api_gen.py)
45+
set(sparse_bw_api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api.yaml)
46+
set(sparse_bw_api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/backward/sparse_bw_api.h)
47+
set(sparse_bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/sparse_bw_api.cc)
48+
set(sparse_bw_api_header_file_tmp ${sparse_bw_api_header_file}.tmp)
49+
set(sparse_bw_api_source_file_tmp ${sparse_bw_api_source_file}.tmp)
50+
4351
# wrapped infermeta file
4452
set(wrapped_infermeta_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py)
4553
set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml)
@@ -91,7 +99,20 @@ add_custom_command(
9199
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${sparse_api_header_file_tmp} ${sparse_api_header_file}
92100
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${sparse_api_source_file_tmp} ${sparse_api_source_file}
93101
COMMENT "copy_if_different ${sparse_api_header_file} ${sparse_sparse_api_source_file}"
94-
DEPENDS ${sparse_api_yaml_file} ${sparse_api_gen_file} ${api_gen_base}
102+
DEPENDS ${sparse_api_yaml_file} ${sparse_api_gen_file} ${api_gen_base} ${api_gen_file}
103+
VERBATIM)
104+
105+
# generate backward sparse api
106+
add_custom_command(
107+
OUTPUT ${sparse_bw_api_header_file} ${sparse_bw_api_source_file}
108+
COMMAND ${PYTHON_EXECUTABLE} ${sparse_bw_api_gen_file}
109+
--api_yaml_path ${sparse_bw_api_yaml_file}
110+
--api_header_path ${sparse_bw_api_header_file_tmp}
111+
--api_source_path ${sparse_bw_api_source_file_tmp}
112+
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${sparse_bw_api_header_file_tmp} ${sparse_bw_api_header_file}
113+
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${sparse_bw_api_source_file_tmp} ${sparse_bw_api_source_file}
114+
COMMENT "copy_if_different ${sparse_bw_api_header_file} ${sparse_bw_sparse_api_source_file}"
115+
DEPENDS ${sparse_bw_api_yaml_file} ${sparse_bw_api_gen_file} ${api_gen_base} ${api_gen_file} ${sparse_api_gen_file} ${bw_api_gen_file}
95116
VERBATIM)
96117

97118
# generate wrapped infermeta
@@ -113,9 +134,10 @@ cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfe
113134
cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)
114135
cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)
115136

116-
cc_library(sparse_api SRCS sparse_api.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api_custom_impl)
117137
cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl)
118138
cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)
119139
cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl)
140+
cc_library(sparse_api SRCS ${sparse_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api_custom_impl)
141+
cc_library(sparse_bw_api SRCS ${sparse_bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api sparse_api_custom_impl)
120142

121143
cc_library(phi_tensor SRCS tensor_method.cc DEPS phi_tensor_raw phi_function_api)

python/paddle/utils/code_gen/backward_api_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def parse_forward_config(self, forward_config):
3535
forward_config)
3636
api = result.group('api')
3737
_, outputs, _ = self.parse_output(self.api, result.group('outputs'))
38+
outputs = [item.split('@')[0] for item in outputs]
3839
fw_inputs, fw_attrs, _, = self.parse_input_and_attr(
3940
api, result.group('args'))
4041

python/paddle/utils/code_gen/sparse_api_gen.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
import argparse
1818
import re
1919

20-
from api_base import BaseAPI
20+
from api_gen import ForwardAPI
2121

2222

23-
class SparseAPI(BaseAPI):
23+
class SparseAPI(ForwardAPI):
2424
def __init__(self, api_item_yaml):
2525
super(SparseAPI, self).__init__(api_item_yaml)
2626

@@ -30,11 +30,6 @@ def get_api_name(self, api_item_yaml):
3030
def get_api_func_name(self):
3131
return self.api
3232

33-
def get_return_type(self, out_type_list):
34-
return out_type_list[0] if len(
35-
out_type_list) == 1 else "std::tuple<" + ",".join(
36-
out_type_list) + ">"
37-
3833
def gene_api_declaration(self):
3934
return f"""
4035
// {", ".join(self.outputs['names'])}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
- sparse_bw_api : conv3d_grad
2+
forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
3+
args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups)
4+
output : Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor)
5+
kernel :
6+
func : sparse_conv_grad
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import yaml
17+
import argparse
18+
import re
19+
20+
from sparse_api_gen import SparseAPI
21+
from backward_api_gen import BackwardAPI
22+
23+
24+
class SparseBackwardAPI(SparseAPI, BackwardAPI):
25+
def __init__(self, bw_api_item_yaml):
26+
BackwardAPI.__init__(self, bw_api_item_yaml)
27+
28+
def get_api_name(self, api_item_yaml):
29+
return api_item_yaml['sparse_bw_api']
30+
31+
def get_api_func_name(self):
32+
return self.api
33+
34+
def get_return_type(self, out_type_list):
35+
return BackwardAPI.get_return_type(self, out_type_list)
36+
37+
def gene_api_declaration(self):
38+
return SparseAPI.gene_api_declaration(self)
39+
40+
def gene_output(self,
41+
output_type_list,
42+
set_out_func,
43+
code_indent,
44+
inplace_flag=False):
45+
kernel_output = ""
46+
output_names = []
47+
output_create = ""
48+
49+
if len(output_type_list) == 1:
50+
kernel_output = 'kernel_out'
51+
output_names.append('kernel_out')
52+
inplace_assign = " = " + self.inplace_map[self.outputs['names'][
53+
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
54+
'names'][0] in self.inplace_map else ""
55+
output_create = f"""
56+
{self.outputs['return_type']} out{inplace_assign};
57+
auto kernel_out = {set_out_func}(&out, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});"""
58+
59+
elif len(output_type_list) > 1:
60+
output_create = f"""
61+
{self.outputs['return_type']} out({len(output_type_list)});"""
62+
63+
for i, out_type_item in enumerate(output_type_list):
64+
kernel_output = kernel_output + f'kernel_out_{i}, '
65+
output_names.append(f'kernel_out_{i}')
66+
if out_type_item == 'Tensor':
67+
get_out_code = f'&out[{i}][0]'
68+
if inplace_flag and self.inplace_map is not None and self.outputs[
69+
'names'][i] in self.inplace_map:
70+
output_create = output_create + f"""
71+
out[{i}].emplace_back({self.inplace_map[self.outputs['names'][i]]});"""
72+
73+
else:
74+
output_create = output_create + f"""
75+
out[{i}].emplace_back();"""
76+
77+
else:
78+
get_out_code = f'&out[{i}]'
79+
if inplace_flag and self.inplace_map is not None and self.outputs[
80+
'names'][i] in self.inplace_map:
81+
output_create = output_create + f"""
82+
out[{i}] = {self.inplace_map[self.outputs['names'][i]]};"""
83+
84+
output_create = output_create + f"""
85+
auto kernel_out_{i} = {set_out_func}({get_out_code}, {self.get_kernel_tensor_out_type(self.outputs['names'][i])});"""
86+
87+
kernel_output = kernel_output[:-2]
88+
else:
89+
raise ValueError(
90+
"{} : Output error: the output should not be empty.".format(
91+
self.api))
92+
93+
return kernel_output, output_names, output_create
94+
95+
96+
def header_include():
97+
return """
98+
#include "paddle/phi/api/include/tensor.h"
99+
#include "paddle/phi/common/scalar.h"
100+
#include "paddle/phi/common/scalar_array.h"
101+
#include "paddle/utils/optional.h"
102+
"""
103+
104+
105+
def source_include(header_file_path):
106+
return f"""
107+
#include "{header_file_path}"
108+
#include <memory>
109+
110+
#include "glog/logging.h"
111+
112+
#include "paddle/phi/api/lib/api_registry.h"
113+
#include "paddle/phi/api/lib/api_gen_utils.h"
114+
#include "paddle/phi/api/lib/kernel_dispatch.h"
115+
#include "paddle/phi/api/lib/sparse_api_custom_impl.h"
116+
#include "paddle/phi/core/kernel_registry.h"
117+
#include "paddle/phi/kernels/declarations.h"
118+
"""
119+
120+
121+
def api_register():
122+
return """
123+
PD_REGISTER_API(Test);
124+
"""
125+
126+
127+
def api_namespace():
128+
return ("""
129+
namespace paddle {
130+
namespace experimental {
131+
namespace sparse {
132+
133+
""", """
134+
135+
} // namespace sparse
136+
} // namespace experimental
137+
} // namespace paddle
138+
""")
139+
140+
141+
def generate_api(api_yaml_path, header_file_path, source_file_path):
142+
143+
with open(api_yaml_path, 'r') as f:
144+
apis = yaml.load(f, Loader=yaml.FullLoader)
145+
header_file = open(header_file_path, 'w')
146+
source_file = open(source_file_path, 'w')
147+
148+
namespace = api_namespace()
149+
150+
header_file.write("#pragma once\n")
151+
header_file.write(header_include())
152+
header_file.write(namespace[0])
153+
154+
include_header_file = "paddle/phi/api/backward/sparse_bw_api.h"
155+
source_file.write(source_include(include_header_file))
156+
source_file.write(namespace[0])
157+
158+
for api in apis:
159+
sparse_bw_api = SparseBackwardAPI(api)
160+
header_file.write(sparse_bw_api.gene_api_declaration())
161+
source_file.write(sparse_bw_api.gene_api_code())
162+
163+
header_file.write(namespace[1])
164+
source_file.write(namespace[1])
165+
166+
source_file.write(api_register())
167+
168+
header_file.close()
169+
source_file.close()
170+
171+
172+
def main():
173+
parser = argparse.ArgumentParser(
174+
description='Generate PaddlePaddle C++ Sparse API files')
175+
parser.add_argument(
176+
'--api_yaml_path',
177+
help='path to sparse api yaml file',
178+
default='python/paddle/utils/code_gen/sparse_bw_api.yaml')
179+
180+
parser.add_argument(
181+
'--api_header_path',
182+
help='output of generated api header code file',
183+
default='paddle/phi/api/backward/sparse_bw_api.h')
184+
185+
parser.add_argument(
186+
'--api_source_path',
187+
help='output of generated api source code file',
188+
default='paddle/phi/api/lib/sparse_bw_api.cc')
189+
190+
options = parser.parse_args()
191+
192+
api_yaml_path = options.api_yaml_path
193+
header_file_path = options.api_header_path
194+
source_file_path = options.api_source_path
195+
196+
generate_api(api_yaml_path, header_file_path, source_file_path)
197+
198+
199+
if __name__ == '__main__':
200+
main()

0 commit comments

Comments
 (0)