@@ -115,6 +115,10 @@ std::string TensorFormatter::Format(const phi::DenseTensor& print_tensor,
115115 FormatData<phi::dtype::float8_e4m3fn>(print_tensor, log_stream);
116116 } else  if  (dtype == phi::DataType::FLOAT8_E5M2) {
117117 FormatData<phi::dtype::float8_e5m2>(print_tensor, log_stream);
118+  } else  if  (dtype == phi::DataType::COMPLEX64) {
119+  FormatData<phi::dtype::complex <float >>(print_tensor, log_stream);
120+  } else  if  (dtype == phi::DataType::COMPLEX128) {
121+  FormatData<phi::dtype::complex <double >>(print_tensor, log_stream);
118122 } else  {
119123 log_stream << "  - data: unprintable type: " 
120124 }
@@ -143,9 +147,20 @@ void TensorFormatter::FormatData(const phi::DenseTensor& print_tensor,
143147
144148 log_stream << "  - data: [" 
145149 if  (print_size > 0 ) {
146-  log_stream << data[0 ];
150+  auto  print_element = [&log_stream](const  auto & elem) {
151+  if  constexpr  (std::is_same_v<T, phi::dtype::complex <float >> ||
152+  std::is_same_v<T, phi::dtype::complex <double >>) {
153+  log_stream << static_cast <float >(elem.real ) << " +" 
154+  << static_cast <float >(elem.imag ) << " j" 
155+  } else  {
156+  log_stream << static_cast <float >(elem);
157+  }
158+  };
159+ 
160+  print_element (data[0 ]);
147161 for  (int64_t  i = 1 ; i < print_size; ++i) {
148-  log_stream << "  " static_cast <float >(data[i]);
162+  log_stream << "  " 
163+  print_element (data[i]);
149164 }
150165 }
151166 log_stream << " ]" 
@@ -165,6 +180,10 @@ template void TensorFormatter::FormatData<phi::dtype::float16>(
165180 const  phi::DenseTensor& print_tensor, std::stringstream& log_stream);
166181template  void  TensorFormatter::FormatData<phi::dtype::bfloat16>(
167182 const  phi::DenseTensor& print_tensor, std::stringstream& log_stream);
183+ template  void  TensorFormatter::FormatData<phi::dtype::complex <float >>(
184+  const  phi::DenseTensor& print_tensor, std::stringstream& log_stream);
185+ template  void  TensorFormatter::FormatData<phi::dtype::complex <double >>(
186+  const  phi::DenseTensor& print_tensor, std::stringstream& log_stream);
168187
169188} //  namespace funcs
170189} //  namespace paddle
0 commit comments