Skip to content

Commit b0f56db

Browse files
support complex (#73886)
1 parent b3bdf7e commit b0f56db

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

paddle/phi/kernels/funcs/tensor_formatter.cc

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ std::string TensorFormatter::Format(const phi::DenseTensor& print_tensor,
115115
FormatData<phi::dtype::float8_e4m3fn>(print_tensor, log_stream);
116116
} else if (dtype == phi::DataType::FLOAT8_E5M2) {
117117
FormatData<phi::dtype::float8_e5m2>(print_tensor, log_stream);
118+
} else if (dtype == phi::DataType::COMPLEX64) {
119+
FormatData<phi::dtype::complex<float>>(print_tensor, log_stream);
120+
} else if (dtype == phi::DataType::COMPLEX128) {
121+
FormatData<phi::dtype::complex<double>>(print_tensor, log_stream);
118122
} else {
119123
log_stream << " - data: unprintable type: " << dtype << std::endl;
120124
}
@@ -143,9 +147,20 @@ void TensorFormatter::FormatData(const phi::DenseTensor& print_tensor,
143147

144148
log_stream << " - data: [";
145149
if (print_size > 0) {
146-
log_stream << data[0];
150+
auto print_element = [&log_stream](const auto& elem) {
151+
if constexpr (std::is_same_v<T, phi::dtype::complex<float>> ||
152+
std::is_same_v<T, phi::dtype::complex<double>>) {
153+
log_stream << static_cast<float>(elem.real) << "+"
154+
<< static_cast<float>(elem.imag) << "j";
155+
} else {
156+
log_stream << static_cast<float>(elem);
157+
}
158+
};
159+
160+
print_element(data[0]);
147161
for (int64_t i = 1; i < print_size; ++i) {
148-
log_stream << " " << static_cast<float>(data[i]);
162+
log_stream << " ";
163+
print_element(data[i]);
149164
}
150165
}
151166
log_stream << "]" << std::endl;
@@ -165,6 +180,10 @@ template void TensorFormatter::FormatData<phi::dtype::float16>(
165180
const phi::DenseTensor& print_tensor, std::stringstream& log_stream);
166181
template void TensorFormatter::FormatData<phi::dtype::bfloat16>(
167182
const phi::DenseTensor& print_tensor, std::stringstream& log_stream);
183+
template void TensorFormatter::FormatData<phi::dtype::complex<float>>(
184+
const phi::DenseTensor& print_tensor, std::stringstream& log_stream);
185+
template void TensorFormatter::FormatData<phi::dtype::complex<double>>(
186+
const phi::DenseTensor& print_tensor, std::stringstream& log_stream);
168187

169188
} // namespace funcs
170189
} // namespace paddle

0 commit comments

Comments
 (0)