@@ -84,15 +84,11 @@ template <typename T, typename Context>
 void MatrixRankTolKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& atol_tensor,
-                         bool hermitian,
                          bool use_default_tol,
+                         bool hermitian,
                          DenseTensor* out) {
-  // const Tensor* x = context.Input<Tensor>("X");
   auto* x_data = x.data<T>();
-  // auto* out = context.Output<Tensor>("Out");
   dev_ctx.template Alloc<int64_t>(out);
-  // out->mutable_data<int64_t>(context.GetPlace());
-  // bool hermitian = context.Attr<bool>("hermitian");
 
   auto dim_x = x.dims();
   auto dim_out = out->dims();
@@ -103,47 +99,43 @@ void MatrixRankTolKernel(const Context& dev_ctx,
   int batches = numel / (rows * cols);
 
   T rtol_T = 0;
-  // DenseTensor atol_dense_tensor;
-  // DenseTensor temp_tensor;
+
   if (use_default_tol) {
-    // atol_tensor = temp_tensor;
     rtol_T = std::numeric_limits<T>::epsilon() * std::max(rows, cols);
   }
 
   DenseTensor eigenvalue_tensor;
-  // auto* eigenvalue_data = eigenvalue_tensor.mutable_data<T>(
-  //     detail::GetEigenvalueDim(dim_x, k), context.GetPlace());
   eigenvalue_tensor.Resize(detail::GetEigenvalueDim(dim_x, k));
   auto* eigenvalue_data = dev_ctx.template Alloc<T>(&eigenvalue_tensor);
+
   if (hermitian) {
     BatchEigenvalues<T>(x_data, eigenvalue_data, batches, rows, cols, k);
   } else {
     BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols, k);
   }
 
-  // auto dito_T = math::DeviceIndependenceTensorOperations<
-  //     paddle::platform::CPUDeviceContext,
-  //     T>(context);
-  std::vector<int> max_eigenvalue_shape =
-      phi::vectorize<int>(detail::RemoveLastDim(eigenvalue_tensor.dims()));
   DenseTensor max_eigenvalue_tensor;
-  //     =
-  //     dito_T.ReduceMax(eigenvalue_tensor, max_eigenvalue_shape);
+  max_eigenvalue_tensor.Resize(detail::RemoveLastDim(eigenvalue_tensor.dims()));
+  dev_ctx.template Alloc<T>(&max_eigenvalue_tensor);
 
+  ReduceKernelImpl<Context, T, T, phi::funcs::MaxFunctor>(
+      dev_ctx,
+      eigenvalue_tensor,
+      &max_eigenvalue_tensor,
+      std::vector<int64_t>{-1},
+      false,
+      false);
   DenseTensor temp_rtol_tensor;
   paddle::framework::TensorFromVector<T>(std::vector<T>{rtol_T},
                                          &temp_rtol_tensor);
-
   DenseTensor rtol_tensor =
       phi::Multiply<T>(dev_ctx, temp_rtol_tensor, max_eigenvalue_tensor);
-  // DenseTensor rtol_tensor = dito_T.Mul(temp_rtol_tensor,
-  //                                      max_eigenvalue_tensor);
 
   DenseTensor tol_tensor;
   tol_tensor.Resize(detail::NewAxisDim(dim_out, k));
   dev_ctx.template Alloc<T>(&tol_tensor);
-  // tol_tensor.mutable_data<T>(dim_out, context.GetPlace());
-
   funcs::ElementwiseCompute<GreaterElementFunctor<T>, T, T>(
       dev_ctx,
       atol_tensor,
@@ -156,10 +148,8 @@ void MatrixRankTolKernel(const Context& dev_ctx,
 
   DenseTensor compare_result;
   compare_result.Resize(detail::NewAxisDim(dim_out, k));
-  dev_ctx.template Alloc<T>(&compare_result);
-  // compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k),
-  //                                      context.GetPlace());
-
+  dev_ctx.template Alloc<int64_t>(&compare_result);
   int axis = -1;
   if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) {
     funcs::ElementwiseCompute<funcs::GreaterThanFunctor<T, int64_t>, T, int>(
@@ -177,13 +167,29 @@ void MatrixRankTolKernel(const Context& dev_ctx,
         axis,
         funcs::LessThanFunctor<T, int64_t>(),
         &compare_result);
-
-  std::vector<int64_t> result_shape = phi::vectorize<int64_t>(dim_out);
-  DenseTensor result;
-  ReduceKernelImpl<Context, T, T, phi::funcs::SumFunctor>(
-      dev_ctx, compare_result, &result, result_shape, true, false);
-  // DenseTensor result = dito_int.ReduceSum(compare_result, result_shape);
+  // Summing the comparison mask along the last axis counts the values above
+  // tolerance, i.e. the rank of each matrix in the batch.
+  DenseTensor result = phi::Sum<T>(dev_ctx,
+                                   compare_result,
+                                   std::vector<int64_t>{-1},
+                                   compare_result.dtype(),
+                                   false);
   out->ShareDataWith(result);
 }
 }
 }  // namespace phi
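
For reference (not part of the patch): per matrix, the kernel computes tol = max(atol, rtol * largest singular value) and counts the singular values (eigenvalues in the Hermitian case) that exceed it; that count is the rank. Below is a minimal standalone sketch of that rule with a hypothetical helper name and no Paddle dependencies, assuming the singular values have already been computed.

// Hypothetical illustration of the tolerance rule; not Paddle code.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Rank of one matrix given its singular values (or eigenvalue magnitudes):
// count the values strictly greater than tol = max(atol, rtol * sigma_max).
// This mirrors GreaterElementFunctor (picking the tolerance) followed by the
// greater-than comparison and the sum over the last axis in the kernel.
int64_t RankFromSingularValues(const std::vector<double>& sigma,
                               double atol,
                               double rtol) {
  double sigma_max = 0.0;
  for (double s : sigma) sigma_max = std::max(sigma_max, s);
  const double tol = std::max(atol, rtol * sigma_max);
  int64_t rank = 0;
  for (double s : sigma) rank += (s > tol) ? 1 : 0;
  return rank;
}

int main() {
  // Singular values of a numerically rank-2 matrix.
  const std::vector<double> sigma = {5.0, 1.2, 1e-12};
  std::cout << RankFromSingularValues(sigma, /*atol=*/0.0, /*rtol=*/1e-8)
            << std::endl;  // prints 2
  return 0;
}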