Skip to content

Commit 3480714

Browse files
authored
【complex op No.50】tan_coo/tan_csr(sparse) (#67885)
* add complex for tan_coo/tan_csr * add two missed changes
1 parent 87fbd08 commit 3480714

File tree

7 files changed

+16
-8
lines changed

7 files changed

+16
-8
lines changed

paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
6161
}
6262

63-
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(tan, Tan)
6463
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(asin, Asin)
6564
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(atan, Atan)
6665
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(sinh, Sinh)
@@ -76,6 +75,7 @@ PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(expm1, Expm1)
7675
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(relu6, Relu6)
7776
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(leaky_relu, LeakyRelu)
7877

78+
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL_WITH_COMPLEX(tan, Tan)
7979
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL_WITH_COMPLEX(sin, Sin)
8080
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL_WITH_COMPLEX(abs, Abs)
8181

paddle/phi/kernels/sparse/cpu/unary_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ void DivScalarCsrKernel(const Context& dev_ctx,
101101
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
102102
}
103103

104-
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(tan, Tan)
105104
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(asin, Asin)
106105
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(atan, Atan)
107106
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(sinh, Sinh)
@@ -118,6 +117,7 @@ PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(expm1, Expm1)
118117
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu6, Relu6)
119118
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(leaky_relu, LeakyRelu)
120119

120+
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL_WITH_COMPLEX(tan, Tan)
121121
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL_WITH_COMPLEX(sin, Sin)
122122
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL_WITH_COMPLEX(abs, Abs)
123123

paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
6565
}
6666

67-
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(tan, Tan)
6867
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(asin, Asin)
6968
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(atan, Atan)
7069
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(sinh, Sinh)
@@ -82,6 +81,7 @@ PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(leaky_relu, LeakyRelu)
8281

8382
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL_WITH_COMPLEX(sin, Sin)
8483
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL_WITH_COMPLEX(abs, Abs)
84+
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL_WITH_COMPLEX(tan, Tan)
8585

8686
PD_REGISTER_KERNEL(cast_coo_grad,
8787
GPU,

paddle/phi/kernels/sparse/gpu/unary_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ void DivScalarCsrKernel(const Context& dev_ctx,
9494
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
9595
}
9696

97-
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(tan, Tan)
9897
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(asin, Asin)
9998
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(atan, Atan)
10099
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(sinh, Sinh)
@@ -111,6 +110,7 @@ PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(expm1, Expm1)
111110
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu6, Relu6)
112111
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(leaky_relu, LeakyRelu)
113112

113+
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL_WITH_COMPLEX(tan, Tan)
114114
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL_WITH_COMPLEX(sin, Sin)
115115
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL_WITH_COMPLEX(abs, Abs)
116116

paddle/phi/kernels/sparse/impl/unary_kernel_impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ namespace sparse {
103103
dev_ctx, x.non_zero_elements(), out->mutable_non_zero_elements()); \
104104
}
105105

106-
DEFINE_SPARSE_UNARY_KERNEL(Tan)
107106
DEFINE_SPARSE_UNARY_KERNEL(Asin)
108107
DEFINE_SPARSE_UNARY_KERNEL(Atan)
109108
DEFINE_SPARSE_UNARY_KERNEL(Sinh)
@@ -120,6 +119,7 @@ DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
120119
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)
121120
DEFINE_SPARSE_UNARY_KERNEL_WITH_COMPLEX(Abs)
122121
DEFINE_SPARSE_UNARY_KERNEL_WITH_COMPLEX(Sin)
122+
DEFINE_SPARSE_UNARY_KERNEL_WITH_COMPLEX(Tan)
123123

124124
template <typename T, typename Context>
125125
void ScaleCooKernel(const Context& dev_ctx,

python/paddle/sparse/unary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def tan(x: Tensor, name: str | None = None) -> Tensor:
9191
out = tan(x)
9292
9393
Parameters:
94-
x (Tensor): The input Sparse Tensor with data type float32, float64.
94+
x (Tensor): The input Sparse Tensor with data type float32, float64, complex64, complex128.
9595
name (str|None, optional): Name for the operation (optional, default is None).
9696
For more information, please refer to :ref:`api_guide_Name`.
9797

test/legacy_test/test_sparse_unary_op.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,11 @@ def test_sparse_sin(self):
166166
self.compare_with_dense(paddle.sin, paddle.sparse.sin, 'complex128')
167167

168168
def test_sparse_tan(self):
169-
self.compare_with_dense(paddle.tan, paddle.sparse.tan)
169+
self.compare_with_dense(paddle.tan, paddle.sparse.tan, 'float16')
170+
self.compare_with_dense(paddle.tan, paddle.sparse.tan, 'float32')
171+
self.compare_with_dense(paddle.tan, paddle.sparse.tan, 'float64')
172+
self.compare_with_dense(paddle.tan, paddle.sparse.tan, 'complex64')
173+
self.compare_with_dense(paddle.tan, paddle.sparse.tan, 'complex128')
170174

171175
def test_sparse_asin(self):
172176
self.compare_with_dense(paddle.asin, paddle.sparse.asin)
@@ -400,7 +404,11 @@ def test_sparse_sin(self):
400404
self.compare_with_dense(paddle.sin, paddle.sparse.sin, 'complex128')
401405

402406
def test_sparse_tan(self):
403-
self.compare_with_dense(paddle.tan, paddle.sparse.tan)
407+
self.compare_with_dense(paddle.tan, paddle.sparse.tan, 'float16')
408+
self.compare_with_dense(paddle.tan, paddle.sparse.tan, 'float32')
409+
self.compare_with_dense(paddle.tan, paddle.sparse.tan, 'float64')
410+
self.compare_with_dense(paddle.tan, paddle.sparse.tan, 'complex64')
411+
self.compare_with_dense(paddle.tan, paddle.sparse.tan, 'complex128')
404412

405413
def test_sparse_asin(self):
406414
self.compare_with_dense(paddle.asin, paddle.sparse.asin)

0 commit comments

Comments
 (0)