@@ -133,7 +133,7 @@ class PlacementPassTest {
133133
134134 auto  pass = PassRegistry::Instance ().Get (" onednn_placement_pass" 
135135
136-  pass->Set (" mkldnn_enabled_op_types " 
136+  pass->Set (" onednn_enabled_op_types " 
137137 new  std::unordered_set<std::string>(onednn_enabled_op_types));
138138
139139 graph.reset (pass->Apply (graph.release ()));
@@ -143,8 +143,10 @@ class PlacementPassTest {
143143 for  (auto * node : graph->Nodes ()) {
144144 if  (node->IsOp ()) {
145145 auto * op = node->Op ();
146-  if  (op->HasAttr (" use_mkldnn" 
147-  PADDLE_GET_CONST (bool , op->GetAttr (" use_mkldnn" 
146+  if  ((op->HasAttr (" use_mkldnn" 
147+  PADDLE_GET_CONST (bool , op->GetAttr (" use_mkldnn" 
148+  (op->HasAttr (" use_onednn" 
149+  PADDLE_GET_CONST (bool , op->GetAttr (" use_onednn" 
148150 ++use_onednn_true_count;
149151 }
150152 }
@@ -156,27 +158,27 @@ class PlacementPassTest {
156158 void  PlacementNameTest () {
157159 auto  pass = PassRegistry::Instance ().Get (" onednn_placement_pass" 
158160 EXPECT_EQ (static_cast <PlacementPassBase*>(pass.get ())->GetPlacementName (),
159-  " MKLDNN " 
161+  " ONEDNN " 
160162 }
161163};
162164
163- TEST (MKLDNNPlacementPass , enable_conv_relu) {
165+ TEST (ONEDNNPlacementPass , enable_conv_relu) {
164166 //  2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool
165167 PlacementPassTest ().MainTest ({" conv2d" " relu" 4 );
166168}
167169
168- TEST (MKLDNNPlacementPass , enable_relu_pool) {
170+ TEST (ONEDNNPlacementPass , enable_relu_pool) {
169171 //  1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
170172 PlacementPassTest ().MainTest ({" relu" " pool2d" 4 );
171173}
172174
173- TEST (MKLDNNPlacementPass , enable_all) {
175+ TEST (ONEDNNPlacementPass , enable_all) {
174176 //  2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool +
175177 //  1 concat
176178 PlacementPassTest ().MainTest ({}, 6 );
177179}
178180
179- TEST (MKLDNNPlacementPass , placement_name) {
181+ TEST (ONEDNNPlacementPass , placement_name) {
180182 PlacementPassTest ().PlacementNameTest ();
181183}
182184
0 commit comments