4 changes: 3 additions & 1 deletion paddle/fluid/framework/executor.cc
@@ -598,8 +598,10 @@ void Executor::EnableONEDNN(const ProgramDesc& program) {
   for (size_t bid = 0; bid < program.Size(); ++bid) {
     auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
     for (auto* op : block->AllOps()) {
-      if (FoundOneDNNKernel(op) || FoundPhiOneDNNKernel(op))
+      if (FoundOneDNNKernel(op) || FoundPhiOneDNNKernel(op)) {
         op->SetAttr("use_mkldnn", true);
+        op->SetAttr("use_onednn", true);
+      }
     }
   }
 #else
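Context for the change above: this PR introduces `use_onednn` as the successor to the legacy `use_mkldnn` attribute (oneDNN is the current name of MKL-DNN), and keeps the two spellings in sync wherever the attribute is written. A minimal standalone sketch of that dual-write pattern (`OpStub` here is a hypothetical stand-in, not Paddle's `OpDesc`):

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for an op descriptor with boolean attributes.
struct OpStub {
  std::map<std::string, bool> attrs;
  void SetAttr(const std::string& name, bool value) { attrs[name] = value; }
};

// During the mkldnn -> onednn rename, every write goes to both names so
// that code still reading the old attribute keeps working.
void EnableOneDNN(OpStub* op) {
  op->SetAttr("use_mkldnn", true);
  op->SetAttr("use_onednn", true);
}

int main() {
  OpStub op;
  EnableOneDNN(&op);
  std::cout << op.attrs.at("use_mkldnn") << " "
            << op.attrs.at("use_onednn") << "\n";  // prints: 1 1
}
```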
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/op_compat_sensible_pass.cc
@@ -32,6 +32,8 @@ std::unordered_set<std::string> global_extra_attrs = {
     "is_test",
     "use_mkldnn",
     "mkldnn_data_type",
+    "use_onednn",
+    "onednn_data_type",
     "use_quantizer",
     "use_cudnn",
     "name",
@@ -357,6 +357,7 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
   attr_map["out_dtype"] = static_cast<int>(out_dtype);
   // NOTE(Aurelius84): In which case use_mkldnn = true?
   attr_map["use_mkldnn"] = false;
+  attr_map["use_onednn"] = false;

   // 3. Create cast op
   std::string op_type("cast");
@@ -349,6 +349,10 @@ void CreateAllOps(const framework::BlockDesc& block,
       VLOG(4) << "Set use_mkldnn=True for " << op_base->Type();
       op_base->SetAttr("use_mkldnn", true);
     }
+    if (op->HasAttr("use_onednn")) {
+      VLOG(4) << "Set use_onednn=True for " << op_base->Type();
+      op_base->SetAttr("use_onednn", true);
+    }
   }
 #endif

12 changes: 9 additions & 3 deletions paddle/fluid/framework/new_executor/interpreter/static_build.cc
@@ -124,7 +124,7 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
   // in_black_list = (kernelCode >> 5) & 1
   // is_operator_base = (kernelCode >> 4) & 1
   // is_custom_op = (kernelCode >> 3) & 1
-  // use_mkldnn = (kernelCode >> 2) & 1
+  // use_onednn = (kernelCode >> 2) & 1
   // sub_block_can_not_static_build = (kernelCode >> 1) & 1
   using KernelCode = int8_t;
   std::set<std::pair<std::string, KernelCode>> invalid_ops;
@@ -150,6 +150,12 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
       use_mkldnn = attr.index() == 1 ? PADDLE_GET_CONST(int, attr)
                                      : PADDLE_GET_CONST(bool, attr);
     }
+    bool use_onednn = use_mkldnn;
+    if (!use_mkldnn && op->HasAttr("use_onednn")) {
+      Attribute attr = op->GetAttr("use_onednn");
+      use_onednn = attr.index() == 1 ? PADDLE_GET_CONST(int, attr)
+                                     : PADDLE_GET_CONST(bool, attr);
+    }

     bool sub_block_can_not_static_build = false;
     if (op->HasAttr("sub_block")) {
@@ -160,9 +166,9 @@

     KernelCode kernel_code = static_cast<KernelCode>(
         (in_black_list << 5) + (is_operator_base << 4) + (is_custom_op << 3) +
-        (use_mkldnn << 2) + (sub_block_can_not_static_build << 1));
+        (use_onednn << 2) + (sub_block_can_not_static_build << 1));

-    if (in_black_list || is_operator_base || is_custom_op || use_mkldnn ||
+    if (in_black_list || is_operator_base || is_custom_op || use_onednn ||
         sub_block_can_not_static_build) {
       invalid_ops.insert(std::make_pair(op_type, kernel_code));
    }
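For readers decoding the `KernelCode` bookkeeping above: each reason an op cannot be statically built is packed into its own bit, so `invalid_ops` records both the op type and why it was rejected. A standalone sketch of the same packing and decoding arithmetic (only the bit layout is taken from the diff; the surrounding code is illustrative):

```cpp
#include <cstdint>
#include <iostream>

using KernelCode = int8_t;

// One bit per disqualifying condition, matching the comments in the diff.
KernelCode PackKernelCode(bool in_black_list, bool is_operator_base,
                          bool is_custom_op, bool use_onednn,
                          bool sub_block_can_not_static_build) {
  return static_cast<KernelCode>(
      (in_black_list << 5) + (is_operator_base << 4) + (is_custom_op << 3) +
      (use_onednn << 2) + (sub_block_can_not_static_build << 1));
}

int main() {
  // An op that uses oneDNN and has a sub-block that cannot be static-built:
  KernelCode code = PackKernelCode(false, false, false, true, true);
  std::cout << "use_onednn bit: " << ((code >> 2) & 1) << "\n";  // 1
  std::cout << "sub_block bit:  " << ((code >> 1) & 1) << "\n";  // 1
  std::cout << "black list bit: " << ((code >> 5) & 1) << "\n";  // 0
}
```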
3 changes: 2 additions & 1 deletion paddle/fluid/framework/operator.cc
@@ -1599,7 +1599,8 @@ bool OperatorWithKernel::SupportsKernelType(

 bool OperatorWithKernel::CanONEDNNBeUsed(const framework::ExecutionContext& ctx,
                                          phi::DataType data_type) const {
-  return ctx.HasAttr("use_mkldnn") && ctx.Attr<bool>("use_mkldnn") &&
+  return ((ctx.HasAttr("use_mkldnn") && ctx.Attr<bool>("use_mkldnn")) ||
+          (ctx.HasAttr("use_onednn") && ctx.Attr<bool>("use_onednn"))) &&
          phi::is_cpu_place(ctx.GetPlace()) && this->SupportsONEDNN(data_type);
 }

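This is the consumer side of the rename: the oneDNN path is taken when either spelling of the flag is present and true, on a CPU place, with a supported data type. A minimal sketch of the either-name lookup against a plain map (the `AttrMap` alias is a hypothetical stand-in for Paddle's `ExecutionContext`):

```cpp
#include <map>
#include <string>

using AttrMap = std::map<std::string, bool>;

// True if either the legacy or the new flag exists and is set. Checking
// presence before reading mirrors the ctx.HasAttr() + ctx.Attr<bool>() pair.
bool OneDNNRequested(const AttrMap& attrs) {
  auto is_set = [&](const std::string& name) {
    auto it = attrs.find(name);
    return it != attrs.end() && it->second;
  };
  return is_set("use_mkldnn") || is_set("use_onednn");
}
```

Because `||` short-circuits, an op carrying only the legacy attribute never touches the new one, and vice versa.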
8 changes: 4 additions & 4 deletions paddle/fluid/framework/phi_utils.cc
@@ -199,10 +199,10 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
   for (int i = 0; i < op_proto_->attrs_size(); ++i) {
     auto& attr = op_proto_->attrs()[i];
     auto& attr_name = attr.name();
-    if (attr_name == "use_mkldnn" || attr_name == "use_cudnn" ||
-        attr_name == "op_role" || attr_name == "op_role_var" ||
-        attr_name == "op_namescope" || attr_name == "op_callstack" ||
-        attr_name == "op_device") {
+    if (attr_name == "use_mkldnn" || attr_name == "use_onednn" ||
+        attr_name == "use_cudnn" || attr_name == "op_role" ||
+        attr_name == "op_role_var" || attr_name == "op_namescope" ||
+        attr_name == "op_callstack" || attr_name == "op_device") {
       continue;
     }
     if ((attr.has_extra() && attr.extra()) ||
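A side note on the growing skip list: once the chained `==` comparisons pass a handful of names, a set keeps each future addition to one line. A sketch of that alternative (not what this PR does, just a possible refactor):

```cpp
#include <string>
#include <unordered_set>

// Attribute names that are framework plumbing rather than kernel arguments.
// A static set gives an O(1) membership test and a one-line diff per name.
bool IsSkippedAttr(const std::string& attr_name) {
  static const std::unordered_set<std::string> kSkipped = {
      "use_mkldnn",  "use_onednn",   "use_cudnn",    "op_role",
      "op_role_var", "op_namescope", "op_callstack", "op_device"};
  return kSkipped.count(attr_name) > 0;
}
```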
2 changes: 2 additions & 0 deletions paddle/fluid/imperative/tracer.cc
@@ -246,10 +246,12 @@ void Tracer::TraceOpImpl(const std::string& type,
     if (!FLAGS_tracer_onednn_ops_on.empty()) {
       auto is_on = FLAGS_tracer_onednn_ops_on.find(type) != std::string::npos;
       attrs["use_mkldnn"] = is_on;
+      attrs["use_onednn"] = is_on;
     } else {
       // if ops_on list is empty all ops are enabled except types from off_list
       auto is_off = FLAGS_tracer_onednn_ops_off.find(type) != std::string::npos;
       attrs["use_mkldnn"] = !is_off;
+      attrs["use_onednn"] = !is_off;
     }
   }

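One behavior worth knowing when reading this hunk: `std::string::find` performs substring matching on the comma-separated flag value, so a short op name can match a longer entry in the list. A standalone demonstration with a made-up flag value:

```cpp
#include <iostream>
#include <string>

int main() {
  // Stand-in for FLAGS_tracer_onednn_ops_on: a comma-separated type list.
  const std::string ops_on = "conv2d_transpose,matmul";

  // Same test TraceOpImpl() uses: substring search, not exact-token match.
  for (const std::string type : {"conv2d", "conv2d_transpose", "matmul",
                                 "relu"}) {
    bool is_on = ops_on.find(type) != std::string::npos;
    std::cout << type << " -> " << (is_on ? "on" : "off") << "\n";
  }
  // Prints: conv2d -> on (substring hit on "conv2d_transpose"),
  //         conv2d_transpose -> on, matmul -> on, relu -> off.
}
```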
7 changes: 5 additions & 2 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -262,6 +262,7 @@ inline std::string GetPrefix(pir::IrContext* ctx, const OpDesc& op_desc) {
   }
 #ifdef PADDLE_WITH_DNNL
   if (op_desc.GetAttrIfExists<bool>("use_mkldnn") ||
+      op_desc.GetAttrIfExists<bool>("use_onednn") ||
       paddle::dialect::IsOneDNNOnlyOp(op_desc.Type())) {
     if (!HasOpInfo(ctx, op_desc, kOneDNNTargetDialectPrefix)) {
       VLOG(3) << op_desc.Type()
@@ -1838,7 +1839,8 @@ struct MulOpTranscriber : public OpTranscriber {
                              const OpDesc& op_desc,
                              pir::Block* block) override {
 #ifdef PADDLE_WITH_DNNL
-    if (op_desc.GetAttrIfExists<bool>("use_mkldnn")) {
+    if (op_desc.GetAttrIfExists<bool>("use_mkldnn") ||
+        op_desc.GetAttrIfExists<bool>("use_onednn")) {
       return static_cast<OpTranscriber>(*this).operator()(  // NOLINT
           ctx,
           param_map,
@@ -2015,7 +2017,8 @@ struct MulGradOpTranscriber : public OpTranscriber {
                              const OpDesc& op_desc,
                              pir::Block* block) override {
 #ifdef PADDLE_WITH_DNNL
-    if (op_desc.GetAttrIfExists<bool>("use_mkldnn")) {
+    if (op_desc.GetAttrIfExists<bool>("use_mkldnn") ||
+        op_desc.GetAttrIfExists<bool>("use_onednn")) {
       return static_cast<OpTranscriber>(*this).operator()(  // NOLINT
           ctx,
           param_map,
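For orientation, `GetPrefix` decides which dialect prefix an op translates under: either spelling of the flag, or an op that exists only as a oneDNN kernel, selects the oneDNN dialect. A simplified sketch of that decision (the prefix string values are assumptions for illustration, not copied from the translator sources):

```cpp
#include <string>

// Illustrative constants; the real values live in the translator sources.
const std::string kTargetDialectPrefix = "pd_op.";
const std::string kOneDNNTargetDialectPrefix = "onednn_op.";

// Simplified shape of the GetPrefix() decision: either flag spelling, or an
// op implemented only as a oneDNN kernel, routes to the oneDNN dialect.
std::string ChoosePrefix(bool use_mkldnn, bool use_onednn,
                         bool is_onednn_only_op) {
  if (use_mkldnn || use_onednn || is_onednn_only_op) {
    return kOneDNNTargetDialectPrefix;
  }
  return kTargetDialectPrefix;
}
```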