@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and
1313limitations under the License. */
1414
15- #include " paddle/fluid/operators/controlflow/bitwise_op.h"
1615#include < algorithm>
1716#include < string>
1817#include < vector>
1918#include " paddle/fluid/framework/op_registry.h"
19+ #include " paddle/fluid/operators/elementwise/elementwise_op_function.h"
2020
2121namespace paddle {
2222namespace operators {
@@ -75,11 +75,19 @@ It operates ``%s`` on Tensor ``X`` .
7575 }
7676};
7777
78- class BitwiseOp : public framework ::OperatorWithKernel {
78+ template <typename OpComment>
79+ class UnaryBitwiseOp : public framework ::OperatorWithKernel {
7980 public:
8081 using framework::OperatorWithKernel::OperatorWithKernel;
8182
8283 protected:
84+ void InferShape (framework::InferShapeContext *context) const override {
85+ OpComment comment;
86+ OP_INOUT_CHECK (context->HasInput (" X" ), " Input" , " X" , comment.type );
87+ context->SetOutputDim (" Out" , context->GetInputDim (" X" ));
88+ context->ShareLoD (" X" , " Out" );
89+ }
90+
8391 framework::OpKernelType GetExpectedKernelType (
8492 const framework::ExecutionContext &ctx) const override {
8593 framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType (ctx);
@@ -90,23 +98,9 @@ class BitwiseOp : public framework::OperatorWithKernel {
9098};
9199
92100template <typename OpComment>
93- class UnaryBitwiseOp : public BitwiseOp {
94- public:
95- using BitwiseOp::BitwiseOp;
96-
97- protected:
98- void InferShape (framework::InferShapeContext *context) const override {
99- OpComment comment;
100- OP_INOUT_CHECK (context->HasInput (" X" ), " Input" , " X" , comment.type );
101- context->SetOutputDim (" Out" , context->GetInputDim (" X" ));
102- context->ShareLoD (" X" , " Out" );
103- }
104- };
105-
106- template <typename OpComment>
107- class BinaryBitwiseOp : public BitwiseOp {
101+ class BinaryBitwiseOp : public framework ::OperatorWithKernel {
108102 public:
109- using BitwiseOp::BitwiseOp ;
103+ using framework::OperatorWithKernel::OperatorWithKernel ;
110104
111105 protected:
112106 void InferShape (framework::InferShapeContext *context) const override {
@@ -130,6 +124,14 @@ class BinaryBitwiseOp : public BitwiseOp {
130124 }
131125 context->ShareLoD (" X" , " Out" );
132126 }
127+
128+ framework::OpKernelType GetExpectedKernelType (
129+ const framework::ExecutionContext &ctx) const override {
130+ framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType (ctx);
131+ // BitwiseOp kernel's device type is decided by input tensor place
132+ kt.place_ = ctx.Input <framework::LoDTensor>(" X" )->place ();
133+ return kt;
134+ }
133135};
134136
135137} // namespace operators
@@ -167,8 +169,3 @@ REGISTER_BINARY_BITWISE_OP(bitwise_and, "Out = X \\& Y");
167169REGISTER_BINARY_BITWISE_OP (bitwise_or, " Out = X | Y" );
168170REGISTER_BINARY_BITWISE_OP (bitwise_xor, " Out = X ^\\ wedge Y" );
169171REGISTER_UNARY_BITWISE_OP (bitwise_not, " Out = \\ sim X" );
170-
171- REGISTER_BINARY_BITWISE_KERNEL (bitwise_and, CPU, ops::BitwiseAndFunctor);
172- REGISTER_BINARY_BITWISE_KERNEL (bitwise_or, CPU, ops::BitwiseOrFunctor);
173- REGISTER_BINARY_BITWISE_KERNEL (bitwise_xor, CPU, ops::BitwiseXorFunctor);
174- REGISTER_UNARY_BITWISE_KERNEL (bitwise_not, CPU, ops::BitwiseNotFunctor);
0 commit comments