Skip to content

Commit 269dedd

Browse files
fix bug in amp-bf16 (#60268)
1 parent 27a8184 commit 269dedd

File tree

1 file changed

+3
-1
lines changed
  • python/paddle/distributed/auto_parallel/static

1 file changed

+3
-1
lines changed

python/paddle/distributed/auto_parallel/static/engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,9 @@ def _initialize(self, mode, init_parameters=True):
848848
# for amp
849849
if dest_type == core.VarDesc.VarType.BF16:
850850
buffer_tensor.set(
851-
_convert_float_to_bfloat16(buffer.numpy()),
851+
_convert_float_to_bfloat16(
852+
self._place, buffer.numpy()
853+
),
852854
self._place,
853855
)
854856
elif dest_type == core.VarDesc.VarType.FP16:

0 commit comments

Comments
 (0)