
Commit 99f9224

Added stack FP32 FWD oneDNN kernel (#37002)
* added stack oneDNN FP32 op
* minor change
* CI fix
* added skipping for gpus
* fix for stack op
* CI fix
* CI fix
* Added comment
* CI fix
1 parent 643fd2f commit 99f9224

File tree: 5 files changed, +321 -0 lines

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/utils.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+namespace paddle {
+namespace operators {
+
+using framework::DataLayout;
+using framework::Tensor;
+using framework::LoDTensor;
+using mkldnn::memory;
+using mkldnn::primitive;
+using mkldnn::concat;
+using mkldnn::stream;
+using platform::to_void_cast;
+
+template <typename T>
+class StackMKLDNNHandler
+    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::concat> {
+ public:
+  StackMKLDNNHandler(const framework::ExecutionContext& ctx,
+                     const mkldnn::engine mkldnn_engine,
+                     const std::vector<const Tensor*>& inputs, Tensor* output)
+      : platform::MKLDNNHandlerNoCachingT<T, dnnl::concat>(mkldnn_engine,
+                                                           ctx.GetPlace()) {
+    int stack_axis = ctx.Attr<int>("axis");
+
+    int ndims = inputs[0]->dims().size();
+
+    if (stack_axis < 0) {
+      stack_axis = ndims + 1 + stack_axis;  // +1 to match output's ndims
+    }
+
+    // in stack op all inputs must have same dims
+    auto input_dims = framework::vectorize<int64_t>(inputs[0]->dims());
+
+    memory::data_type dt = framework::ToMKLDNNDataType(inputs[0]->type());
+    std::vector<memory::desc> srcs_md;
+    memory::desc dst_md;
+    MKLDNNMemoryFormat dst_fmt;
+
+    srcs_md.reserve(inputs.size());
+
+    // if stack is not done on last(non existing) axis, then we can optimize
+    // concat primitive by not adding additional dimension, since it causes
+    // wrong output format deduction and suboptimal performance as a result
+    if (stack_axis != ndims) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        srcs_md.emplace_back(memory::desc(input_dims, dt, inputs[i]->format()));
+      }
+
+      input_dims[stack_axis] *= inputs.size();
+      dst_md = memory::desc(input_dims, dt, MKLDNNMemoryFormat::any);
+    } else {
+      auto extended_input_dims = framework::vectorize<int64_t>(output->dims());
+      extended_input_dims[stack_axis] = 1;
+
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        srcs_md.emplace_back(memory::desc(input_dims, dt, inputs[i]->format())
+                                 .reshape(extended_input_dims));
+      }
+
+      // concat primitive choses suboptimal format tag because it cannot
+      // distinguish between f.e. abcd and abdc if last dim is equal to 1 so
+      // enforcing is needed for better performance
+      dst_fmt = platform::GetPlainMKLDNNFormat(extended_input_dims.size());
+      dst_md = memory::desc(framework::vectorize(output->dims()), dt, dst_fmt);
+    }
+
+    this->AcquireForwardPrimitiveDescriptor(dst_md, stack_axis, srcs_md);
+  }
+
+  // concat oneDNN prim is not having .desc attribute so we cannot use default
+  // AcquireForwardPrimitiveDescriptor
+  void AcquireForwardPrimitiveDescriptor(
+      const memory::desc& dst_md, const int stack_axis,
+      const std::vector<memory::desc>& srcs_md) {
+    this->fwd_pd_.reset(new dnnl::concat::primitive_desc(
+        dst_md, stack_axis, srcs_md, this->engine_));
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(const Tensor& input, int i) {
+    const T* input_data = input.data<T>();
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
+                                            to_void_cast<T>(input_data));
+  }
+};
+
+template <typename T>
+class StackMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+
+    auto multi_input = ctx.MultiInput<Tensor>("X");
+
+    Tensor* output = ctx.Output<Tensor>("Y");
+
+    StackMKLDNNHandler<T> handler(ctx, mkldnn_engine, multi_input, output);
+
+    std::vector<std::shared_ptr<memory>> srcs;
+    srcs.reserve(multi_input.size());
+
+    auto dst_mem = handler.AcquireDstMemory(output);
+    auto concat_p = handler.AcquireForwardPrimitive();
+
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+    std::unordered_map<int, memory> args;
+    for (size_t i = 0; i < multi_input.size(); ++i) {
+      srcs.push_back(handler.AcquireSrcMemory(*(multi_input[i]), i));
+      args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *(srcs.at(i))});
+    }
+    args.insert({MKLDNN_ARG_DST, *dst_mem});
+
+    concat_p->execute(astream, args);
+    astream.wait();
+
+    output->set_layout(DataLayout::kMKLDNN);
+    output->set_format(platform::GetMKLDNNFormat(
+        dst_mem->get_desc().reshape(framework::vectorize(output->dims()))));
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_KERNEL(stack, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::StackMKLDNNOpKernel<float>);
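
A minimal NumPy sketch (illustrative only, not part of the commit) of the identity the handler above relies on: stacking N equally shaped tensors along axis k equals concatenating them along k after inserting a size-1 dimension at k, which is why the oneDNN concat primitive can implement the stack forward pass.

# Illustrative sketch only: stack-as-concat, the identity StackMKLDNNHandler
# exploits when it reshapes each input to have a size-1 dimension at the
# stack axis and then runs the oneDNN concat primitive.
import numpy as np

inputs = [np.random.rand(2, 3).astype(np.float32) for _ in range(4)]
axis = 1

stacked = np.stack(inputs, axis=axis)  # shape (2, 4, 3)
concatenated = np.concatenate(
    [np.expand_dims(x, axis=axis) for x in inputs], axis=axis)

assert np.array_equal(stacked, concatenated)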

paddle/fluid/operators/stack_op.cc

Lines changed: 20 additions & 0 deletions
@@ -71,6 +71,21 @@ class StackOp : public framework::OperatorWithKernel {
     vec.insert(vec.begin() + axis, input_dims.size());
     ctx->SetOutputDim("Y", framework::make_ddim(vec));
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class StackOpMaker : public framework::OpProtoAndCheckerMaker {

@@ -81,6 +96,11 @@ class StackOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("axis",
                  "The axis along which all of the Inputs(X) should be stacked.")
         .SetDefault(0);
+    AddAttr<bool>(
+        "use_mkldnn",
+        "(bool, default false) Indicates if MKL-DNN kernel will be used")
+        .SetDefault(false)
+        .AsExtra();
     AddComment(R"DOC(
 Stack Operator.
 Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
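
The GetExpectedKernelType override and the new use_mkldnn attribute above are what route the stack op to the oneDNN kernel. A rough Python sketch (not Paddle's API) of the dispatch decision, under the assumption that CanMKLDNNBeUsed reduces to "the build has MKL-DNN, the op's use_mkldnn attribute is true, and the place is a CPU place":

# Illustrative sketch only of the kernel-dispatch decision encoded by the
# GetExpectedKernelType override above (assumptions stated in the lead-in).
def choose_stack_kernel(built_with_mkldnn, use_mkldnn_attr, place_is_cpu):
    if built_with_mkldnn and use_mkldnn_attr and place_is_cpu:
        return "oneDNN stack kernel (kMKLDNN layout/library)"
    return "default stack kernel (plain layout)"

# The new unit tests set {'use_mkldnn': True} and run on CPUPlace, so they
# should take the oneDNN path when Paddle is built with PADDLE_WITH_MKLDNN.
assert choose_stack_kernel(True, True, True).startswith("oneDNN")
assert choose_stack_kernel(True, False, True).startswith("default")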

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 37 additions & 0 deletions
@@ -333,6 +333,43 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(const mkldnn::memory memory) {
   return GetMKLDNNFormat(mem_desc);
 }
 
+inline mkldnn::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) {
+  switch (tensor_rank) {
+    case 1:
+      return mkldnn::memory::format_tag::a;
+      break;
+    case 2:
+      return mkldnn::memory::format_tag::ab;
+      break;
+    case 3:
+      return mkldnn::memory::format_tag::abc;
+      break;
+    case 4:
+      return mkldnn::memory::format_tag::abcd;
+      break;
+    case 5:
+      return mkldnn::memory::format_tag::abcde;
+      break;
+    case 6:
+      return mkldnn::memory::format_tag::abcdef;
+      break;
+    case 7:
+      return mkldnn::memory::format_tag::abcdefg;
+      break;
+    case 8:
+      return mkldnn::memory::format_tag::abcdefgh;
+      break;
+    case 9:
+      return mkldnn::memory::format_tag::abcdefghi;
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Paddle support tensors with rank in range <1, 9>, but received "
+          "tensor with rank: %d",
+          tensor_rank));
+  }
+}
+
 inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
                                               MKLDNNMemoryFormat data_format) {
   if (dims_size == 1) {
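
For reference, a plain Python sketch (not oneDNN's API) of the rank-to-format-tag mapping the new GetPlainMKLDNNFormat helper implements; "plain" here means the natural row-major tag a, ab, abc, and so on:

# Illustrative sketch only: map a tensor rank in <1, 9> to the name of the
# plain, row-major oneDNN format tag, mirroring the switch statement above.
def plain_onednn_format_tag(tensor_rank):
    if not 1 <= tensor_rank <= 9:
        raise NotImplementedError(
            "Paddle supports tensors with rank in range <1, 9>, "
            "but received tensor with rank: {}".format(tensor_rank))
    return "abcdefghi"[:tensor_rank]

assert plain_onednn_format_tag(4) == "abcd"  # e.g. 4-D (NCHW-shaped) data
assert plain_onednn_format_tag(9) == "abcdefghi"
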
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, skip_check_grad_ci
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+
+
+@OpTestTool.skip_if_not_cpu()
+class TestStack2DOneDNNOp(OpTest):
+    def initDefaultParameters(self):
+        self.num_inputs = 4
+        self.input_dim = (2, 2)
+        self.axis = 1
+        self.dtype = np.float32
+
+    def initParameters(self):
+        pass
+
+    def getInputNames(self):
+        input_names = []
+        for i in range(self.num_inputs):
+            input_names.append('x{}'.format(i))
+        return input_names
+
+    def setUp(self):
+        self.initDefaultParameters()
+        self.initParameters()
+        self.op_type = 'stack'
+        self.op_inputs = []
+
+        for i in range(self.num_inputs):
+            self.op_inputs.append(
+                np.random.random(size=self.input_dim).astype(np.float32))
+
+        input_list = []
+        input_names = self.getInputNames()
+        for i in range(self.num_inputs):
+            input_list.append((input_names[i], self.op_inputs[i]))
+
+        self.inputs = {'X': input_list}
+        self.outputs = {'Y': np.stack(self.op_inputs, axis=self.axis)}
+        self.attrs = {'axis': self.axis, 'use_mkldnn': True}
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CPUPlace())
+
+    # JUST FOR CI TO PASS, GRAD IS NOT IMPLEMENTED YET
+    def test_check_grad(self):
+        pass
+
+
+class TestStack1DOneDNNOp(TestStack2DOneDNNOp):
+    def initParameters(self):
+        self.input_dim = (100)
+        self.axis = 0
+
+
+class TestStack1DAxis1OneDNNOp(TestStack2DOneDNNOp):
+    def initParameters(self):
+        self.input_dim = (100)
+        self.axis = 1
+
+
+class TestStack2DAxisLastOneDNNOp(TestStack2DOneDNNOp):
+    def initParameters(self):
+        self.input_dim = (13, 24)
+        self.num_inputs = 5
+        self.axis = -1
+
+
+class TestStack3DAxisNegativeOneDNNOp(TestStack2DOneDNNOp):
+    def initParameters(self):
+        self.input_dim = (10, 128, 128)
+        self.axis = -2
+
+
+class TestStack3DOneDNNOp(TestStack2DOneDNNOp):
+    def initParameters(self):
+        self.input_dim = (10, 128, 128)
+        self.num_inputs = 3
+        self.axis = 1
+
+
+class TestStack4DOneDNNOp(TestStack2DOneDNNOp):
+    def initParameters(self):
+        self.input_dim = (2, 2, 2, 2)
+        self.num_inputs = 3
+        self.axis = 4
+
+
+class TestStack5DOneDNNOp(TestStack2DOneDNNOp):
+    def initParameters(self):
+        self.input_dim = (2, 3, 4, 5, 6)
+        self.num_inputs = 6
+        self.axis = 0
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
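
The negative-axis cases above (axis = -1 and axis = -2) depend on the normalization done in StackMKLDNNHandler, namely stack_axis = ndims + 1 + stack_axis where ndims is the input rank. A small sketch of that normalization, for illustration only:

# Illustrative sketch only: normalize a possibly negative stack axis against
# the output rank (input rank + 1), mirroring the handler's
# `stack_axis = ndims + 1 + stack_axis` for negative axes.
def normalize_stack_axis(axis, input_ndims):
    output_ndims = input_ndims + 1
    return axis + output_ndims if axis < 0 else axis

# axis=-1 on 2-D inputs stacks on the trailing (previously non-existing) axis,
# so the handler takes the reshape-then-concat branch (normalized axis == 2).
assert normalize_stack_axis(-1, 2) == 2
# axis=-2 on 3-D inputs maps to axis 2 of the 4-D output, so the handler can
# use the optimized branch that widens an existing dimension instead.
assert normalize_stack_axis(-2, 3) == 2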

python/paddle/fluid/tests/unittests/op_test.py

Lines changed: 6 additions & 0 deletions
@@ -1832,3 +1832,9 @@ def skip_if_not_cpu_bf16(cls):
             not (isinstance(_current_expected_place(), core.CPUPlace) and
                  core.supports_bfloat16()),
             "Place does not support BF16 evaluation")
+
+    @classmethod
+    def skip_if_not_cpu(cls):
+        return OpTestTool.skip_if(
+            not isinstance(_current_expected_place(), core.CPUPlace),
+            "OneDNN supports only CPU for now")
