Skip to content

Commit f57b21e

Browse files
authored
[bf16] support printing bf16 tensor (#39375)
1 parent eacfc1e commit f57b21e

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

python/paddle/fluid/tests/unittests/test_var_base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,20 @@ def test_tensor_str_linewidth2(self):
10941094
self.assertEqual(a_str, expected)
10951095
paddle.enable_static()
10961096

1097+
def test_tensor_str_bf16(self):
1098+
paddle.disable_static(paddle.CPUPlace())
1099+
a = paddle.to_tensor([[1.5, 1.0], [0, 0]])
1100+
a = paddle.cast(a, dtype=core.VarDesc.VarType.BF16)
1101+
paddle.set_printoptions(precision=4)
1102+
a_str = str(a)
1103+
1104+
expected = '''Tensor(shape=[2, 2], dtype=bfloat16, place=Place(cpu), stop_gradient=True,
1105+
[[1.5000, 1. ],
1106+
[0. , 0. ]])'''
1107+
1108+
self.assertEqual(a_str, expected)
1109+
paddle.enable_static()
1110+
10971111
def test_print_tensor_dtype(self):
10981112
paddle.disable_static(paddle.CPUPlace())
10991113
a = paddle.rand([1])

python/paddle/tensor/to_string.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,18 @@ def _format_tensor(var, summary, indent=0, max_width=0, signed=False):
223223
def to_string(var, prefix='Tensor'):
224224
indent = len(prefix) + 1
225225

226+
dtype = convert_dtype(var.dtype)
227+
if var.dtype == core.VarDesc.VarType.BF16:
228+
dtype = 'bfloat16'
229+
226230
_template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})"
227231

228232
tensor = var.value().get_tensor()
229233
if not tensor._is_initialized():
230234
return "Tensor(Not initialized)"
231235

236+
if var.dtype == core.VarDesc.VarType.BF16:
237+
var = var.astype('float32')
232238
np_var = var.numpy()
233239

234240
if len(var.shape) == 0:
@@ -250,7 +256,7 @@ def to_string(var, prefix='Tensor'):
250256
return _template.format(
251257
prefix=prefix,
252258
shape=var.shape,
253-
dtype=convert_dtype(var.dtype),
259+
dtype=dtype,
254260
place=var._place_str,
255261
stop_gradient=var.stop_gradient,
256262
indent=' ' * indent,

0 commit comments

Comments
 (0)