|
2 | 2 | import sys |
3 | 3 | from string import Template, ascii_lowercase |
4 | 4 | from ..cwrap import cwrap |
5 | | -from ..cwrap.plugins import StandaloneExtension, GenericNN, NullableArguments, AutoGPU |
| 5 | +from ..cwrap.plugins import StandaloneExtension, NullableArguments, AutoGPU |
6 | 6 | from ..shared import import_module |
7 | 7 |
|
8 | 8 | BASE_PATH = os.path.realpath(os.path.join(__file__, '..', '..', '..')) |
@@ -98,7 +98,6 @@ def wrap_function(name, type, arguments): |
98 | 98 | def generate_wrappers(): |
99 | 99 | wrap_nn() |
100 | 100 | wrap_cunn() |
101 | | - wrap_generic() |
102 | 101 |
|
103 | 102 |
|
104 | 103 | def wrap_nn(): |
@@ -129,67 +128,3 @@ def wrap_cunn(): |
129 | 128 | NullableArguments(), |
130 | 129 | AutoGPU(has_self=False), |
131 | 130 | ]) |
132 | | - |
133 | | - |
134 | | -GENERIC_FUNCTION_TEMPLATE = Template("""\ |
135 | | -[[ |
136 | | - name: $name |
137 | | - return: void |
138 | | - options: |
139 | | -""") |
140 | | - |
141 | | - |
142 | | -def wrap_generic_function(name, backends): |
143 | | - declaration = '' |
144 | | - declaration += GENERIC_FUNCTION_TEMPLATE.substitute(name=name) |
145 | | - for backend in backends: |
146 | | - declaration += ' - cname: ' + name + '\n' |
147 | | - declaration += ' backend: ' + backend['name'] + '\n' |
148 | | - declaration += ' arguments:\n' |
149 | | - for arg in backend['arguments']: |
150 | | - declaration += ' - arg: ' + arg.type + ' ' + arg.name + '\n' |
151 | | - if arg.is_optional: |
152 | | - declaration += ' optional: True\n' |
153 | | - declaration += ']]\n\n\n' |
154 | | - return declaration |
155 | | - |
156 | | - |
157 | | -def wrap_generic(): |
158 | | - from collections import OrderedDict |
159 | | - defs = OrderedDict() |
160 | | - |
161 | | - def should_wrap_function(name): |
162 | | - if name.startswith('LookupTable_'): |
163 | | - return False |
164 | | - return (name.endswith('updateOutput') or |
165 | | - name.endswith('updateGradInput') or |
166 | | - name.endswith('accGradParameters') or |
167 | | - name.endswith('backward')) |
168 | | - |
169 | | - def add_functions(name, functions): |
170 | | - for fn in functions: |
171 | | - if not should_wrap_function(fn.name): |
172 | | - continue |
173 | | - if fn.name not in defs: |
174 | | - defs[fn.name] = [] |
175 | | - defs[fn.name] += [{ |
176 | | - 'name': name, |
177 | | - 'arguments': fn.arguments[1:], |
178 | | - }] |
179 | | - |
180 | | - add_functions('nn', thnn_utils.parse_header(thnn_utils.THNN_H_PATH)) |
181 | | - add_functions('cunn', thnn_utils.parse_header(thnn_utils.THCUNN_H_PATH)) |
182 | | - |
183 | | - wrapper = '' |
184 | | - for name, backends in defs.items(): |
185 | | - wrapper += wrap_generic_function(name, backends) |
186 | | - with open('torch/csrc/nn/THNN_generic.cwrap', 'w') as f: |
187 | | - f.write(wrapper) |
188 | | - |
189 | | - cwrap('torch/csrc/nn/THNN_generic.cwrap', plugins=[ |
190 | | - GenericNN(header=True), |
191 | | - ], default_plugins=False, destination='torch/csrc/nn/THNN_generic.h') |
192 | | - |
193 | | - cwrap('torch/csrc/nn/THNN_generic.cwrap', plugins=[ |
194 | | - GenericNN(), |
195 | | - ], default_plugins=False) |
0 commit comments