There was an error while loading. Please reload this page.
1 parent 4204b97 commit 781484eCopy full SHA for 781484e
python/paddle/distributed/fleet/utils/recompute.py
@@ -182,9 +182,10 @@ def backward(ctx, *args):
182
"none of output has requires_grad=True, this recompute() is not necessary"
183
)
184
185
- # actually backward
186
- paddle.autograd.backward(forward_outputs_with_grad,
187
- backward_inputs_with_grad)
+ # actually backward
+ with paddle.amp.auto_cast(enable=False):
+ paddle.autograd.backward(forward_outputs_with_grad,
188
+ backward_inputs_with_grad)
189
190
grads = list(inp._grad_ivar() for inp in detached_inputs
191
if isinstance(inp, core.VarBase))
0 commit comments