Skip to content

Commit 781484e

Browse files
committed
run recompute's real backward with amp disabled
1 parent 4204b97 commit 781484e

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

python/paddle/distributed/fleet/utils/recompute.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,10 @@ def backward(ctx, *args):
182182
"none of output has requires_grad=True, this recompute() is not necessary"
183183
)
184184

185-
# actually backward
186-
paddle.autograd.backward(forward_outputs_with_grad,
187-
backward_inputs_with_grad)
185+
# actually backward
186+
with paddle.amp.auto_cast(enable=False):
187+
paddle.autograd.backward(forward_outputs_with_grad,
188+
backward_inputs_with_grad)
188189

189190
grads = list(inp._grad_ivar() for inp in detached_inputs
190191
if isinstance(inp, core.VarBase))

0 commit comments

Comments
 (0)