@@ -66,7 +66,7 @@ module attributes {transform.with_named_sequence} {
6666// ----- 
6767
6868#map  = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
69- func.func  @vectorize_nd_tensor_extract_constant_idx  (%arg0:  tensor <3 x3 xf32 >, %arg2:  tensor <1 x1 x3 xf32 >) -> tensor <1 x1 x3 xf32 > {
69+ func.func  @vectorize_nd_tensor_extract_scalar_broadcast  (%arg0:  tensor <3 x3 xf32 >, %arg2:  tensor <1 x1 x3 xf32 >) -> tensor <1 x1 x3 xf32 > {
7070 %c0  = arith.constant  1  : index 
7171 %c1  = arith.constant  2  : index 
7272 %2  = linalg.generic  {
@@ -80,17 +80,17 @@ func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg
8080 return  %2  : tensor <1 x1 x3 xf32 >
8181}
8282
83- // CHECK: #[[$MAP:.* ]] = affine_map<(d0, d1) -> (0, 0, 0)> 
84- // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx ( 
83+ // CHECK: #[[$MAP:.+ ]] = affine_map<(d0, d1) -> (0, 0, 0)> 
84+ // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_scalar_broadcast ( 
8585// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x3xf32>, 
8686// CHECK-SAME: %[[ARG_1:.*]]: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> { 
8787// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index 
8888// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index 
89- // CHECK-DAG: %[[C0_f32_2 :.*]] = arith.constant 0.000000e+00  : f32  
90- // CHECK-DAG : %[[C0_f32 :.*]] = arith.constant 0.000000e+00  : f32  
91- // CHECK: %[[READ:.*]] = vector.transfer_read  %[[ARG_0]][%[[C1]], %[[C2]]], %[[C0_f32]]  {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32> 
92- // CHECK: %[[C0_4 :.*]] = arith.constant 0 : index 
93- // CHECK: vector.transfer_write %[[READ]], %[[ARG_1]][ %[[C0_4 ]], %[[C0_4 ]], %[[C0_4 ]]]   : vector<1x1x3xf32>, tensor<1x1x3xf32> 
89+ // CHECK-DAG: %[[C0 :.*]] = arith.constant 0 : index  
90+ // CHECK:   %[[MASK :.*]] = vector.constant_mask [1]  : vector<1xi1>  
91+ // CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read  %[[ARG_0]][%[[C1]], %[[C2]]], {{.*}}  {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32> } : vector<1xi1> ->  vector<1x1x3xf32> 
92+ // CHECK: %[[C0_2 :.*]] = arith.constant 0 : index 
93+ // CHECK: vector.transfer_write %[[READ]], %[[ARG_1]]{{\[}} %[[C0_2 ]], %[[C0_2 ]], %[[C0_2 ]]] : vector<1x1x3xf32>, tensor<1x1x3xf32> 
9494
9595module  attributes  {transform.with_named_sequence } {
9696 transform.named_sequence  @__transform_main (%arg1:  !transform.any_op  {transform.readonly }) {
@@ -823,7 +823,7 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
823823 return  %out:tensor <1 x1 x4 xi32 >
824824}
825825
826- // CHECK: #[[$ATTR_1 :.+]] = affine_map<(d0, d1) -> (0, 0, 0)> 
826+ // CHECK: #[[$MAP :.+]] = affine_map<(d0, d1) -> (0, 0, 0)> 
827827// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor( 
828828// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> { 
829829// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index 
@@ -844,12 +844,14 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
844844// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1> 
845845// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32> 
846846// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index 
847- // CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex> 
848- // CHECK: %[[VAL_21:.*]] = vector.extract %[[VAL_20]][0] : index from vector<4xindex> 
849- // CHECK: %[[VAL_22:.*]] = arith.constant 0 : i32 
850- // CHECK: %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32> 
847+ // CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex> 
848+ // CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_19]][0] : index from vector<4xindex> 
849+ // CHECK: %[[VAL_21:.*]] = arith.constant 0 : i32 
850+ // CHECK: %[[VAL_22:.*]] = vector.constant_mask [1] : vector<1xi1> 
851+ // CHECK: %[[VAL_23:.*]] = vector.mask %[[VAL_22]] { vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_20]], %[[VAL_2]]], %[[VAL_21]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<15x1xi32>, vector<1x1x4xi32> } : vector<1xi1> -> vector<1x1x4xi32> 
851852// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index 
852853// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32> 
854+ // CHECK: return %[[VAL_25]] : tensor<1x1x4xi32> 
853855
854856module  attributes  {transform.with_named_sequence } {
855857 transform.named_sequence  @__transform_main (%arg1:  !transform.any_op  {transform.readonly }) {
0 commit comments