2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
@@ -42,4 +42,4 @@ class FuseFCActOneDNNPass : public FusePassBase {

} // namespace ir
} // namespace framework
} // namespace paddlea
} // namespace paddle
54 changes: 33 additions & 21 deletions paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -218,13 +218,15 @@ class ConvMKLDNNHandlerT
: dnnl::prop_kind::forward_training;

float sum_scale = 1.0f;
float activation_scale = 1.0f;
std::vector<float> output_shift_scale;
if (platform::is_int8<T>())
std::tie(sum_scale, output_shift_scale) = get_int8_scales(ctx);
std::tie(sum_scale, output_shift_scale, activation_scale) =
get_int8_scales(ctx);

const dnnl::primitive_attr conv_attr = CreatePostOps(
fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
output_shift_scale, sum_scale); // for INT8 only!
output_shift_scale, sum_scale, activation_scale); // for INT8 only!

if (bias) {
auto bias_tz = framework::vectorize(bias->dims());
@@ -432,7 +434,7 @@ class ConvMKLDNNHandlerT
return bias_scale_tuple;
}

std::tuple<float, std::vector<float>> get_int8_scales(
std::tuple<float, std::vector<float>, float> get_int8_scales(
const framework::ExecutionContext& ctx) const {
const auto* filter = ctx.Input<Tensor>("Filter");
const auto& weights_tz = framework::vectorize(filter->dims());
@@ -445,8 +447,14 @@ class ConvMKLDNNHandlerT
const auto& scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
bool is_multi_channel = scale_weights_data.size() > 1;
bool has_activation = !ctx.Attr<std::string>("fuse_activation").empty();
float activation_scale =
force_fp32_output ? 1.0f : has_activation ? ctx.Attr<float>("Scale_out")
: 1.0f;
auto scale_out_data =
force_fp32_output ? 1.0f : ctx.Attr<float>("Scale_out");
force_fp32_output ? 1.0f : has_activation
? 1.0f
: ctx.Attr<float>("Scale_out");
float sum_scale =
fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
int count =
@@ -468,13 +476,13 @@ class ConvMKLDNNHandlerT
static_cast<double>(scale_weights_data[i])));
}

return std::make_tuple(sum_scale, output_shift_scale);
return std::make_tuple(sum_scale, output_shift_scale, activation_scale);
}

dnnl::primitive_attr CreatePostOps(
std::string fuse_activation, float fuse_alpha, float fuse_beta,
bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
float sum_scale = 1.0f) {
float sum_scale = 1.0f, float activation_scale = 1.0f) {
dnnl::primitive_attr conv_attr;
dnnl::post_ops post_operations;
if (output_shift_scale.size() > 0) {
@@ -492,30 +500,34 @@ class ConvMKLDNNHandlerT
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
constexpr float scale = 1.0f;
if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_relu,
fuse_alpha, fuse_beta);
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_relu, fuse_alpha,
fuse_beta);
} else if (fuse_activation == "relu6") {
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_bounded_relu, fuse_alpha, fuse_beta);
} else if (fuse_activation == "swish") {
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_swish,
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_bounded_relu,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "swish") {
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_swish, fuse_alpha,
fuse_beta);
} else if (fuse_activation == "hard_swish") {
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_hardswish,
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_hardswish,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "hard_sigmoid") {
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_linear,
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_linear,
fuse_alpha, fuse_beta);
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_clip, 0.0f,
1.0f);
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
} else if (fuse_activation == "gelu_tanh") {
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_tanh,
0.0f, 0.0f);
post_operations.append_eltwise(
activation_scale, dnnl::algorithm::eltwise_gelu_tanh, 0.0f, 0.0f);
} else if (fuse_activation == "gelu_erf") {
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_erf,
0.0f, 0.0f);
post_operations.append_eltwise(
activation_scale, dnnl::algorithm::eltwise_gelu_erf, 0.0f, 0.0f);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
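For illustration, a minimal Python sketch of the scale selection performed in get_int8_scales above: when an activation is fused and the output stays INT8, Scale_out is applied by the activation post-op (activation_scale) rather than by the convolution's output shift. The helper name, the simplified output_shift_scale formula, and the numeric values are assumptions for this sketch, not part of the patch.

# Illustrative sketch only; mirrors the scale selection in get_int8_scales.
def int8_scales(scale_in, scale_weights, scale_out, scale_in_eltwise,
                force_fp32_output, has_activation, fuse_residual_conn):
    # With a fused activation, Scale_out moves onto the eltwise post-op,
    # so the convolution output itself is left at scale 1.0.
    activation_scale = scale_out if (has_activation and not force_fp32_output) else 1.0
    scale_out_data = 1.0 if (force_fp32_output or has_activation) else scale_out
    sum_scale = scale_out_data / scale_in_eltwise if fuse_residual_conn else 1.0
    # Simplified per-channel shift (the zero-weight-scale branch is omitted).
    output_shift_scale = [scale_out_data / (scale_in * sw) for sw in scale_weights]
    return sum_scale, output_shift_scale, activation_scale

# Example: INT8 output with a fused activation and no residual connection.
print(int8_scales(scale_in=0.95, scale_weights=[10.0], scale_out=0.5,
                  scale_in_eltwise=0.6, force_fp32_output=False,
                  has_activation=True, fuse_residual_conn=False))
# -> (1.0, [0.10526...], 0.5): the conv output keeps only the
#    1/(Scale_in*Scale_weights) dequantization, and Scale_out is applied
#    by the fused activation post-op.

Without a fused activation the roles flip: activation_scale stays 1.0 and Scale_out goes back into output_shift_scale, matching the previous behaviour.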
@@ -426,6 +426,7 @@ def _optimize_fp32_graph(self, graph):
graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_hard_swish_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'fc_fuse_pass',
['use_gpu', 'use_fc_padding'], [False, False])
graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
@@ -43,7 +43,7 @@ def setUp(self):
self.init_group()
self.init_dilation()
self.init_test_case()
self.init_fuse_relu()
self.init_fuse_activation()
self.init_fuse_residual()
self.init_data_type()

@@ -54,7 +54,9 @@ def setUp(self):
}
# This implementation of convolution quantization is based on OneDNN documentation
# https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html#doxid-dev-guide-int8-computations-1dg-i8-comp-s11
scale_output_shift = (self.scale_out /
inner_scale = 1. if self.fuse_activation != "" else self.scale_out
activation_scale = self.scale_out if self.fuse_activation != "" else 1.
scale_output_shift = (inner_scale /
(self.scale_in * self.scale_weights[0]))
filter = np.random.random(self.filter_size).astype(self.weighttype)

@@ -78,7 +80,7 @@ def residual_helper(init_low, init_high, output_):
init_low, init_high,
self.input_residual_size).astype(self.srctype)
return (output_ + input_residual_ *
(self.scale_out / self.scale_in_eltwise)), input_residual_
(inner_scale / self.scale_in_eltwise)), input_residual_

if self.srctype == np.int8:
init_low, init_high = (-5, 5)
@@ -101,12 +103,24 @@ def residual_helper(init_low, init_high, output_):
output, input_residual = residual_helper(init_low, init_high,
output)

output = np.round(output)

if self.fuse_activation == "relu":
output = np.maximum(output, 0)
if self.fuse_activation == "":
pass
elif self.fuse_activation == "relu":
output = activation_scale * np.maximum(output, 0)
elif self.fuse_activation == "hard_swish":
output = activation_scale * output / 6. * np.minimum(
np.maximum(0, output + 3.), 6)
elif self.fuse_activation == "relu6":
output = activation_scale * np.maximum(0, np.minimum(6, output))
elif self.fuse_activation == "swish":
output = activation_scale * output / (1. + np.exp(-1. * output))
elif self.fuse_activation == "leaky_relu":
output = activation_scale * np.maximum(output, 0.02 * output)
else:
raise NotImplementedError("test for " + self.fuse_activation +
" activation not implemented")

output = output.astype(self.dsttype)
output = np.round(output).astype(self.dsttype)

self.inputs = {
'Input':
@@ -131,6 +145,8 @@ def residual_helper(init_low, init_high, output_):
'Scale_weights': self.scale_weights,
'Scale_in_eltwise': self.scale_in_eltwise,
'fuse_activation': self.fuse_activation,
'fuse_alpha': self.fuse_alpha,
'fuse_beta': self.fuse_beta,
'fuse_residual_connection': self.fuse_residual,
'mkldnn_data_type': self.mkldnn_data_type
}
@@ -165,8 +181,10 @@ def init_data_type(self):
self.srctype = np.uint8
self.dsttype = np.int8

def init_fuse_relu(self):
def init_fuse_activation(self):
self.fuse_activation = "relu"
self.fuse_alpha = 0
self.fuse_beta = 0

def init_fuse_residual(self):
self.fuse_residual = True
@@ -190,6 +208,34 @@ def init_test_case(self):
self.scale_in_eltwise = 0.6


class TestWithHardSwish(TestConv2D):
def init_fuse_activation(self):
self.fuse_activation = "hard_swish"
self.fuse_alpha = 0
self.fuse_beta = 0


class TestWithRelu6(TestConv2D):
def init_fuse_activation(self):
self.fuse_activation = "relu6"
self.fuse_alpha = 6
self.fuse_beta = 0


class TestWithSwish(TestConv2D):
def init_fuse_activation(self):
self.fuse_activation = "swish"
self.fuse_alpha = 1
self.fuse_beta = 0


class TestWithLeakyRelu(TestConv2D):
def init_fuse_activation(self):
self.fuse_activation = "leaky_relu"
self.fuse_alpha = 0.02
self.fuse_beta = 0


class TestWithPad(TestConv2D):
def init_test_case(self):
TestConv2D.init_test_case(self)
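As a worked example of the reference math added in setUp above, the sketch below reproduces the relu6 case with made-up scale values; the helper name and the concrete numbers are assumptions, only the order of operations (dequantizing shift, round, clipped activation scaled by Scale_out, final round) follows the test.

import numpy as np

# Illustrative reference computation for a fused relu6 (values are assumed).
def relu6_reference(conv_acc, scale_in, scale_weight, scale_out):
    inner_scale = 1.0                              # Scale_out moves to the activation
    output = conv_acc * (inner_scale / (scale_in * scale_weight))
    output = np.round(output)
    output = scale_out * np.maximum(0, np.minimum(6, output))  # relu6 post-op
    return np.round(output).astype(np.int8)

print(relu6_reference(np.array([-40, 3, 90]), scale_in=0.95,
                      scale_weight=10.0, scale_out=0.5))
# -> [0 0 3]: negatives clip to 0 and large accumulators saturate at 6
#    before the Scale_out requantization is applied.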