Commit b76308e

add max functor

1 parent 48d2f24 commit b76308e

12 files changed: +91 -183 lines

paddle/phi/infermeta/binary.cc

Lines changed: 0 additions & 49 deletions
@@ -688,53 +688,4 @@ void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
-namespace detail {
-static DDim CheckAndGetOutputDim(const DDim& dim_x) {
-  auto x_vec = phi::vectorize(dim_x);
-  if (x_vec.size() == 2) {
-    return phi::make_ddim({1});
-  }
-  x_vec.erase(x_vec.end() - 2, x_vec.end());
-  return phi::make_ddim(x_vec);
-}
-}  // namespace detail
-
-// void MatrixRankTolMeta(const MetaTensor& x,
-//                        const MetaTensor& tol_tensor,
-//                        bool use_default_tol,
-//                        bool hermitian,
-//                        MetaTensor* out){
-//   auto dim_x = x.dims();
-//   PADDLE_ENFORCE_GE(dim_x.size(), 2,
-//                     phi::errors::InvalidArgument(
-//                         "The dims of input must be greater than 2"));
-
-//   if (hermitian) {
-//     int rows = dim_x[dim_x.size() - 2];
-//     int cols = dim_x[dim_x.size() - 1];
-//     PADDLE_ENFORCE_EQ(rows, cols,
-//                       phi::errors::InvalidArgument(
-//                           "if hermitian == true, matrix should be n*n"));
-//   }
-
-//   DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
-//   auto dim_tol = tol_tensor.dims();
-//   if (dim_x_batch == dim_tol) {
-//     out->set_dims(dim_x_batch);
-//   } else {
-//     int max_dim = std::max(dim_x_batch.size(), dim_tol.size());
-//     int axis = std::abs(dim_x_batch.size() - dim_tol.size());
-//     std::vector<int> x_batch_dims_array(max_dim);
-//     std::vector<int> tol_dims_array(max_dim);
-//     std::vector<int> out_dims_array(max_dim);
-//     funcs::GetBroadcastDimsArrays(dim_x_batch, dim_tol,
-//                                   x_batch_dims_array.data(),
-//                                   tol_dims_array.data(), out_dims_array.data(),
-//                                   max_dim, axis);
-//     out->set_dims(phi::make_ddim(out_dims_array));
-//   }
-//   out->set_dtype(x.dtype());
-//   out->share_lod(x);
-// }
-
 }  // namespace phi
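Note: the detail::CheckAndGetOutputDim helper deleted above (and again from unary.cc below) computes the batch shape of the rank output by dropping the trailing (rows, cols) dims. A minimal standalone sketch of the same logic, with std::vector standing in for phi::DDim (hypothetical helper, for illustration only):

#include <cstdint>
#include <vector>

// Drop the trailing (rows, cols) pair to get the batch shape of the rank
// output; a single 2-D matrix collapses to {1}.
std::vector<int64_t> CheckAndGetOutputDim(std::vector<int64_t> x_vec) {
  if (x_vec.size() == 2) {
    return {1};
  }
  x_vec.erase(x_vec.end() - 2, x_vec.end());  // e.g. {6, 4, 5} -> {6}
  return x_vec;
}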

paddle/phi/infermeta/binary.h

Lines changed: 0 additions & 6 deletions
@@ -118,10 +118,4 @@ void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
                                             MetaTensor* out,
                                             MetaConfig config = MetaConfig());
 
-void MatrixRankTolMeta(const MetaTensor& x,
-                       const MetaTensor& tol_tensor,
-                       bool use_default_tol,
-                       bool hermitian,
-                       MetaTensor* out);
-
 }  // namespace phi

paddle/phi/infermeta/unary.cc

Lines changed: 0 additions & 37 deletions
@@ -1106,43 +1106,6 @@ void TransposeInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
 
-namespace detail {
-static DDim CheckAndGetOutputDim(const DDim& dim_x) {
-  auto x_vec = phi::vectorize(dim_x);
-  if (x_vec.size() == 2) {
-    return phi::make_ddim({1});
-  }
-  x_vec.erase(x_vec.end() - 2, x_vec.end());
-  return phi::make_ddim(x_vec);
-}
-}  // namespace detail
-
-void MatrixRankMeta(const MetaTensor& x,
-                    bool use_default_tol,
-                    bool hermitian,
-                    float tol,
-                    MetaTensor* out) {
-  auto dim_x = x.dims();
-  PADDLE_ENFORCE_GE(
-      dim_x.size(),
-      2,
-      phi::errors::InvalidArgument("The dims of input must be greater than 2"));
-
-  if (hermitian) {
-    int rows = dim_x[dim_x.size() - 2];
-    int cols = dim_x[dim_x.size() - 1];
-    PADDLE_ENFORCE_EQ(rows,
-                      cols,
-                      phi::errors::InvalidArgument(
-                          "if hermitian == true, matrix should be n*n"));
-  }
-
-  DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
-  out->set_dims(dim_x_batch);
-  out->set_dtype(x.dtype());
-  out->share_lod(x);
-}
-
 }  // namespace phi
 
 PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);

paddle/phi/infermeta/unary.h

Lines changed: 0 additions & 6 deletions
@@ -161,10 +161,4 @@ void TransposeInferMeta(const MetaTensor& x,
                         const std::vector<int>& axis,
                         MetaTensor* out);
 
-void MatrixRankMeta(const MetaTensor& x,
-                    bool use_default_tol,
-                    bool hermitian,
-                    float tol,
-                    MetaTensor* out);
-
 }  // namespace phi

paddle/phi/kernels/cpu/matrix_rank_kernel.cc

Lines changed: 3 additions & 3 deletions
@@ -21,9 +21,9 @@ namespace phi {
 template <typename T, typename Context>
 void MatrixRankKernel(const Context& dev_ctx,
                       const DenseTensor& x,
-                      bool hermitian,
-                      bool use_default_tol,
                       float tol,
+                      bool use_default_tol,
+                      bool hermitian,
                       DenseTensor* out) {
   DenseTensor atol_tensor;
   if (use_default_tol) {
@@ -34,7 +34,7 @@ void MatrixRankKernel(const Context& dev_ctx,
         std::vector<T>{tol}, dev_ctx, &atol_tensor);
   }
   MatrixRankTolKernel<T, Context>(
-      dev_ctx, x, atol_tensor, hermitian, use_default_tol, out);
+      dev_ctx, x, atol_tensor, use_default_tol, hermitian, out);
 }
 
 }  // namespace phi
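The two hunks above keep the wrapper positionally in sync with the reordered MatrixRankTolKernel: both now take their flags as (…, use_default_tol, hermitian, out). A hedged sketch of the tolerance rule this dispatch implements, in plain C++ rather than the phi API (ResolveTol is a hypothetical name; the atol = 0 default is an assumption, since that branch is outside the hunk):

#include <algorithm>
#include <limits>

struct Tolerances {
  float atol;  // absolute tolerance, compared directly against eigenvalues
  float rtol;  // relative tolerance, later scaled by the largest eigenvalue
};

// Hypothetical helper mirroring the dispatch: default mode means atol = 0 and
// rtol = eps * max(rows, cols); explicit mode means atol = tol and rtol = 0.
Tolerances ResolveTol(float tol, bool use_default_tol, int rows, int cols) {
  if (use_default_tol) {
    return {0.0f, std::numeric_limits<float>::epsilon() *
                      static_cast<float>(std::max(rows, cols))};
  }
  return {tol, 0.0f};
}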

paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc

Lines changed: 38 additions & 32 deletions
@@ -84,15 +84,11 @@ template <typename T, typename Context>
 void MatrixRankTolKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& atol_tensor,
-                         bool hermitian,
                          bool use_default_tol,
+                         bool hermitian,
                          DenseTensor* out) {
-  // const Tensor* x = context.Input<Tensor>("X");
   auto* x_data = x.data<T>();
-  // auto* out = context.Output<Tensor>("Out");
   dev_ctx.template Alloc<int64_t>(out);
-  // out->mutable_data<int64_t>(context.GetPlace());
-  // bool hermitian = context.Attr<bool>("hermitian");
 
   auto dim_x = x.dims();
   auto dim_out = out->dims();
@@ -103,47 +99,43 @@ void MatrixRankTolKernel(const Context& dev_ctx,
   int batches = numel / (rows * cols);
 
   T rtol_T = 0;
-  // DenseTensor atol_dense_tensor;
-  // DenseTensor temp_tensor;
+
   if (use_default_tol) {
-    // atol_tensor = temp_tensor;
     rtol_T = std::numeric_limits<T>::epsilon() * std::max(rows, cols);
   }
 
   DenseTensor eigenvalue_tensor;
-  // auto* eigenvalue_data = eigenvalue_tensor.mutable_data<T>(
-  //     detail::GetEigenvalueDim(dim_x, k), context.GetPlace());
   eigenvalue_tensor.Resize(detail::GetEigenvalueDim(dim_x, k));
   auto* eigenvalue_data = dev_ctx.template Alloc<T>(&eigenvalue_tensor);
+
   if (hermitian) {
     BatchEigenvalues<T>(x_data, eigenvalue_data, batches, rows, cols, k);
   } else {
     BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols, k);
   }
 
-  // auto dito_T = math::DeviceIndependenceTensorOperations<
-  //     paddle::platform::CPUDeviceContext,
-  //     T>(context);
-  std::vector<int> max_eigenvalue_shape =
-      phi::vectorize<int>(detail::RemoveLastDim(eigenvalue_tensor.dims()));
   DenseTensor max_eigenvalue_tensor;
-  // =
-  //     dito_T.ReduceMax(eigenvalue_tensor, max_eigenvalue_shape);
+  max_eigenvalue_tensor.Resize(detail::RemoveLastDim(eigenvalue_tensor.dims()));
+  dev_ctx.template Alloc<T>(&max_eigenvalue_tensor);
 
+  ReduceKernelImpl<Context, T, T, phi::funcs::MaxFunctor>(
+      dev_ctx,
+      eigenvalue_tensor,
+      &max_eigenvalue_tensor,
+      std::vector<int64_t>{-1},
+      false,
+      false);
   DenseTensor temp_rtol_tensor;
   paddle::framework::TensorFromVector<T>(std::vector<T>{rtol_T},
                                          &temp_rtol_tensor);
-
+  std::cout << "\n1111111111111\n";
   DenseTensor rtol_tensor =
      phi::Multiply<T>(dev_ctx, temp_rtol_tensor, max_eigenvalue_tensor);
-  // DenseTensor rtol_tensor = dito_T.Mul(temp_rtol_tensor,
-  //                                      max_eigenvalue_tensor);
 
   DenseTensor tol_tensor;
   tol_tensor.Resize(detail::NewAxisDim(dim_out, k));
   dev_ctx.template Alloc<T>(&tol_tensor);
-  // tol_tensor.mutable_data<T>(dim_out, context.GetPlace());
-
+  std::cout << "\n1111111111112\n";
   funcs::ElementwiseCompute<GreaterElementFunctor<T>, T, T>(
       dev_ctx,
       atol_tensor,
@@ -156,10 +148,8 @@ void MatrixRankTolKernel(const Context& dev_ctx,
 
   DenseTensor compare_result;
   compare_result.Resize(detail::NewAxisDim(dim_out, k));
-  dev_ctx.template Alloc<T>(&compare_result);
-  // compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k),
-  //                                      context.GetPlace());
-
+  dev_ctx.template Alloc<int64_t>(&compare_result);
+  std::cout << "\n1111111111113\n";
   int axis = -1;
   if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) {
     funcs::ElementwiseCompute<funcs::GreaterThanFunctor<T, int64_t>, T, int>(
@@ -177,13 +167,29 @@ void MatrixRankTolKernel(const Context& dev_ctx,
         axis,
         funcs::LessThanFunctor<T, int64_t>(),
         &compare_result);
-
-  std::vector<int64_t> result_shape = phi::vectorize<int64_t>(dim_out);
-  DenseTensor result;
-  ReduceKernelImpl<Context, T, T, phi::funcs::SumFunctor>(
-      dev_ctx, compare_result, &result, result_shape, true, false);
-  // DenseTensor result = dito_int.ReduceSum(compare_result, result_shape);
+  std::cout << "\n1111111111144414\n";
+  // DenseTensor result;
+  // result.Resize(dim_out);
+  // dev_ctx.template Alloc<T>(&result);
+  std::cout << "\n1111111111144416677\n";
+  std::cout << "compare_result: " << compare_result << "\n";
+  DenseTensor result = phi::Sum<T>(dev_ctx,
+                                   compare_result,
+                                   std::vector<int64_t>{-1},
+                                   compare_result.dtype(),
+                                   false);
+  // SumKernel<T, Context>(dev_ctx, compare_result, std::vector<int64_t>{-1},
+  //                       compare_result.type(), );
+  // ReduceKernelImpl<Context, T, T, phi::funcs::SumFunctor>(
+  //     dev_ctx,
+  //     compare_result,
+  //     &result,
+  //     std::vector<int64_t>{-1},
+  //     true,
+  //     false);
+  std::cout << "\n1111111111116\n";
   out->ShareDataWith(result);
+  std::cout << "\n1111111111115\n";
   }
 }
 }  // namespace phi
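These hunks replace the old dito_T device-independent helpers with phi primitives: ReduceKernelImpl with the new phi::funcs::MaxFunctor takes the per-batch maximum eigenvalue over the last axis (dims {-1}, keep_dim = false), phi::Multiply scales rtol by that maximum, GreaterElementFunctor forms tol = max(atol, rtol * max_eigenvalue), and phi::Sum counts the eigenvalues above tol; the numbered std::cout lines are debug probes left in this work-in-progress commit. A scalar, per-batch sketch of the same pipeline (hypothetical helper, not the phi API):

#include <algorithm>
#include <cstdint>
#include <vector>

// Per-batch analogue of the kernel's tensor pipeline: the rank is the number
// of eigenvalues (or singular values) strictly above the combined tolerance.
int64_t MatrixRankFromEigenvalues(const std::vector<float>& sigma,
                                  float atol, float rtol) {
  float max_sigma = *std::max_element(sigma.begin(), sigma.end());
  float tol = std::max(atol, rtol * max_sigma);  // GreaterElementFunctor
  int64_t rank = 0;
  for (float s : sigma) {
    if (s > tol) ++rank;  // GreaterThanFunctor, then phi::Sum over axis -1
  }
  return rank;
}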

paddle/phi/kernels/funcs/reduce_functor.h

Lines changed: 8 additions & 0 deletions
@@ -41,5 +41,13 @@ struct ProdFunctor {
   }
 };
 
+//////// Max Functor ///////
+struct MaxFunctor {
+  template <typename DeviceContext, typename X, typename Y, typename Dim>
+  void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
+    y->device(place) = x->maximum(dim);
+  }
+};
+
 }  // namespace funcs
 }  // namespace phi
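The new MaxFunctor follows the shape of the neighboring SumFunctor and ProdFunctor: it reduces an Eigen-mapped tensor with maximum along the given dims. A self-contained sketch of the underlying Eigen operation, using the unsupported Tensor module directly (illustrative shapes, not taken from the commit):

#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>

int main() {
  // A 2 x 3 input; reducing over dim 1 yields the row-wise maxima, which is
  // what y->device(place) = x->maximum(dim) computes inside MaxFunctor.
  Eigen::Tensor<float, 2> x(2, 3);
  x.setValues({{1.f, 5.f, 3.f}, {4.f, 2.f, 6.f}});

  Eigen::array<Eigen::Index, 1> dim = {1};
  Eigen::Tensor<float, 1> y = x.maximum(dim);

  std::cout << y(0) << " " << y(1) << std::endl;  // prints: 5 6
  return 0;
}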

paddle/phi/kernels/gpu/matrix_rank_kernel.cu

Lines changed: 3 additions & 3 deletions
@@ -24,9 +24,9 @@ namespace phi {
 template <typename T, typename Context>
 void MatrixRankKernel(const Context& dev_ctx,
                       const DenseTensor& x,
-                      bool hermitian,
-                      bool use_default_tol,
                       float tol,
+                      bool use_default_tol,
+                      bool hermitian,
                       DenseTensor* out) {
   DenseTensor atol_tensor;
   if (use_default_tol) {
@@ -37,7 +37,7 @@ void MatrixRankKernel(const Context& dev_ctx,
         std::vector<T>{tol}, dev_ctx, &atol_tensor);
   }
   MatrixRankTolKernel<T, Context>(
-      dev_ctx, x, atol_tensor, hermitian, use_default_tol, out);
+      dev_ctx, x, atol_tensor, use_default_tol, hermitian, out);
 }
 
 }  // namespace phi
