Skip to content

Commit 2005b98

Browse files
authored
fix recompute no grad warning (#38293)
1 parent 06cf314 commit 2005b98

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

python/paddle/distributed/fleet/utils/recompute.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def swith_rng_state(rng_state):
6363
class RecomputeFunction(PyLayer):
6464
@staticmethod
6565
def forward(ctx, run_function, preserve_rng_state, *args):
66-
check_recompute_necessary(args)
66+
if framework._dygraph_tracer()._has_grad:
67+
check_recompute_necessary(args)
6768

6869
# store for recomputing
6970
ctx.run_function = run_function

0 commit comments

Comments
 (0)