Skip to content

Commit 03eb792

Browse files
authored
【Phi】Migrate bitwise_and/bitwise_or/bitwise_xor/bitwise_not op into phi (#40031)
* Migrate bitwise_and/or/xor/not op into phi * fix CI
1 parent d9dd840 commit 03eb792

File tree

8 files changed

+313
-210
lines changed

8 files changed

+313
-210
lines changed

paddle/fluid/operators/controlflow/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@ endif()
2121

2222
file(APPEND ${pybind_file} "USE_OP_ITSELF(less_than);\nUSE_OP_ITSELF(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n")
2323
file(APPEND ${pybind_file} "USE_OP_ITSELF(logical_and);\nUSE_OP_ITSELF(logical_or);\nUSE_OP_ITSELF(logical_xor);\nUSE_OP_ITSELF(logical_not);\n")
24-
file(APPEND ${pybind_file} "USE_OP(bitwise_and);\nUSE_OP(bitwise_or);\nUSE_OP(bitwise_xor);\nUSE_OP(bitwise_not);\n")
24+
file(APPEND ${pybind_file} "USE_OP_ITSELF(bitwise_and);\nUSE_OP_ITSELF(bitwise_or);\nUSE_OP_ITSELF(bitwise_xor);\nUSE_OP_ITSELF(bitwise_not);\n")

paddle/fluid/operators/controlflow/bitwise_op.cc

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/controlflow/bitwise_op.h"
1615
#include <algorithm>
1716
#include <string>
1817
#include <vector>
1918
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
2020

2121
namespace paddle {
2222
namespace operators {
@@ -75,11 +75,19 @@ It operates ``%s`` on Tensor ``X`` .
7575
}
7676
};
7777

78-
class BitwiseOp : public framework::OperatorWithKernel {
78+
template <typename OpComment>
79+
class UnaryBitwiseOp : public framework::OperatorWithKernel {
7980
public:
8081
using framework::OperatorWithKernel::OperatorWithKernel;
8182

8283
protected:
84+
void InferShape(framework::InferShapeContext *context) const override {
85+
OpComment comment;
86+
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
87+
context->SetOutputDim("Out", context->GetInputDim("X"));
88+
context->ShareLoD("X", "Out");
89+
}
90+
8391
framework::OpKernelType GetExpectedKernelType(
8492
const framework::ExecutionContext &ctx) const override {
8593
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
@@ -90,23 +98,9 @@ class BitwiseOp : public framework::OperatorWithKernel {
9098
};
9199

92100
template <typename OpComment>
93-
class UnaryBitwiseOp : public BitwiseOp {
94-
public:
95-
using BitwiseOp::BitwiseOp;
96-
97-
protected:
98-
void InferShape(framework::InferShapeContext *context) const override {
99-
OpComment comment;
100-
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
101-
context->SetOutputDim("Out", context->GetInputDim("X"));
102-
context->ShareLoD("X", "Out");
103-
}
104-
};
105-
106-
template <typename OpComment>
107-
class BinaryBitwiseOp : public BitwiseOp {
101+
class BinaryBitwiseOp : public framework::OperatorWithKernel {
108102
public:
109-
using BitwiseOp::BitwiseOp;
103+
using framework::OperatorWithKernel::OperatorWithKernel;
110104

111105
protected:
112106
void InferShape(framework::InferShapeContext *context) const override {
@@ -130,6 +124,14 @@ class BinaryBitwiseOp : public BitwiseOp {
130124
}
131125
context->ShareLoD("X", "Out");
132126
}
127+
128+
framework::OpKernelType GetExpectedKernelType(
129+
const framework::ExecutionContext &ctx) const override {
130+
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
131+
// BitwiseOp kernel's device type is decided by input tensor place
132+
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
133+
return kt;
134+
}
133135
};
134136

135137
} // namespace operators
@@ -167,8 +169,3 @@ REGISTER_BINARY_BITWISE_OP(bitwise_and, "Out = X \\& Y");
167169
REGISTER_BINARY_BITWISE_OP(bitwise_or, "Out = X | Y");
168170
REGISTER_BINARY_BITWISE_OP(bitwise_xor, "Out = X ^\\wedge Y");
169171
REGISTER_UNARY_BITWISE_OP(bitwise_not, "Out = \\sim X");
170-
171-
REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CPU, ops::BitwiseAndFunctor);
172-
REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CPU, ops::BitwiseOrFunctor);
173-
REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CPU, ops::BitwiseXorFunctor);
174-
REGISTER_UNARY_BITWISE_KERNEL(bitwise_not, CPU, ops::BitwiseNotFunctor);

paddle/fluid/operators/controlflow/bitwise_op.cu

Lines changed: 0 additions & 74 deletions
This file was deleted.

paddle/fluid/operators/controlflow/bitwise_op.h

Lines changed: 0 additions & 112 deletions
This file was deleted.
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/phi/core/dense_tensor.h"
18+
19+
namespace phi {
20+
21+
template <typename T, typename Context>
22+
void BitwiseAndKernel(const Context& dev_ctx,
23+
const DenseTensor& x,
24+
const DenseTensor& y,
25+
DenseTensor* out);
26+
27+
template <typename T, typename Context>
28+
void BitwiseOrKernel(const Context& dev_ctx,
29+
const DenseTensor& x,
30+
const DenseTensor& y,
31+
DenseTensor* out);
32+
33+
template <typename T, typename Context>
34+
void BitwiseXorKernel(const Context& dev_ctx,
35+
const DenseTensor& x,
36+
const DenseTensor& y,
37+
DenseTensor* out);
38+
39+
template <typename T, typename Context>
40+
void BitwiseNotKernel(const Context& dev_ctx,
41+
const DenseTensor& x,
42+
DenseTensor* out);
43+
44+
} // namespace phi

0 commit comments

Comments
 (0)