Skip to content

Commit 5deff64

Browse files
committed
fix matrix_rank
1 parent 333c8a6 commit 5deff64

File tree

13 files changed

+963
-363
lines changed

13 files changed

+963
-363
lines changed

paddle/fluid/operators/eigh_op.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,11 @@ REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker,
145145
ops::EighGradOpMaker<paddle::framework::OpDesc>,
146146
ops::EighGradOpMaker<paddle::imperative::OpBase>);
147147
REGISTER_OPERATOR(eigh_grad, ops::EighGradOp);
148+
149+
REGISTER_OP_CPU_KERNEL(
150+
eigh_grad, ops::EighGradKernel<paddle::platform::CPUDeviceContext, float>,
151+
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double>,
152+
ops::EighGradKernel<paddle::platform::CPUDeviceContext,
153+
paddle::platform::complex<float>>,
154+
ops::EighGradKernel<paddle::platform::CPUDeviceContext,
155+
paddle::platform::complex<double>>);

paddle/fluid/operators/eigh_op.cu

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/fluid/operators/eigh_op.h"
13+
14+
namespace ops = paddle::operators;
15+
16+
REGISTER_OP_CUDA_KERNEL(
17+
eigh_grad, ops::EighGradKernel<paddle::platform::CUDADeviceContext, float>,
18+
ops::EighGradKernel<paddle::platform::CUDADeviceContext, double>,
19+
ops::EighGradKernel<paddle::platform::CUDADeviceContext,
20+
paddle::platform::complex<float>>,
21+
ops::EighGradKernel<paddle::platform::CUDADeviceContext,
22+
paddle::platform::complex<double>>);

paddle/phi/kernels/cpu/eigh_grad_kernel.cc

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments

Comments
 (0)