Skip to content

Commit 54632b5

Browse files
authored
Fix param@grad type error for amp in run_program (PaddlePaddle#40938)
1 parent 09e5b00 commit 54632b5

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,9 @@ def _train_amp_program(self):
204204
"""
205205
Lazy initialized property of train_amp_program.
206206
"""
207-
return self._append_backward_desc(self._infer_amp_program)
207+
train_amp_program = self._append_backward_desc(self._infer_amp_program)
208+
self._set_grad_type(self._params, train_amp_program)
209+
return train_amp_program
208210

209211
@LazyInitialized
210212
@switch_to_static_graph
@@ -224,7 +226,10 @@ def _train_pure_fp16_program(self):
224226
"""
225227
Lazy initialized property of _train_pure_fp16_program.
226228
"""
227-
return self._append_backward_desc(self._infer_pure_fp16_program)
229+
train_pure_fp16_program = self._append_backward_desc(
230+
self._infer_pure_fp16_program)
231+
self._set_grad_type(self._params, train_pure_fp16_program)
232+
return train_pure_fp16_program
228233

229234
@LazyInitialized
230235
def _infer_program_id(self):

0 commit comments

Comments
 (0)