Skip to content

Conversation

@fabianmcg
Copy link
Contributor

This patch adds ub as a dependent dialect to memref, and uses ub.poison as the default value in AllocaOp::getDefaultValue for the mem2reg pass.

This aligns the behavior of mem2reg with LLVM, where loading a value before having a value should be poison.

Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
@llvmbot
Copy link
Member

llvmbot commented Nov 14, 2025

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Fabian Mora (fabianmcg)

Changes

This patch adds ub as a dependent dialect to memref, and uses ub.poison as the default value in AllocaOp::getDefaultValue for the mem2reg pass.

This aligns the behavior of mem2reg with LLVM, where loading a value before having a value should be poison.


Full diff: https://github.com/llvm/llvm-project/pull/168066.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td (+1-1)
  • (modified) mlir/lib/Dialect/MemRef/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp (+1)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp (+2-17)
  • (modified) mlir/test/Dialect/MemRef/mem2reg.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td index 3be84ae654f6a..7c088935d2f06 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td @@ -19,7 +19,7 @@ def MemRef_Dialect : Dialect { manipulation ops, which are not strongly associated with any particular other dialect or domain abstraction. }]; - let dependentDialects = ["arith::ArithDialect"]; + let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"]; let hasConstantMaterializer = 1; } diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index 1382c7aceea79..d358362f1984b 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRMemorySlotInterfaces MLIRShapedOpInterfaces MLIRSideEffectInterfaces + MLIRUBDialect MLIRValueBoundsOpInterface MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp index 6ff63df258c79..a1e3f10a871c1 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index dfa2e4e0376ed..540423831937e 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -61,15 +62,8 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape, // Interfaces for AllocaOp //===----------------------------------------------------------------------===// -static bool isSupportedElementType(Type type) { - return llvm::isa<MemRefType>(type) || - OpBuilder(type.getContext()).getZeroAttr(type); -} - SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { MemRefType type = getType(); - if (!isSupportedElementType(type.getElementType())) - return {}; if (!type.hasStaticShape()) return {}; // Make sure the memref contains only a single element. @@ -81,16 +75,7 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { - assert(isSupportedElementType(slot.elemType)); - // TODO: support more types. - return TypeSwitch<Type, Value>(slot.elemType) - .Case([&](MemRefType t) { - return memref::AllocaOp::create(builder, getLoc(), t); - }) - .Default([&](Type t) { - return arith::ConstantOp::create(builder, getLoc(), t, - builder.getZeroAttr(t)); - }); + return ub::PoisonOp::create(builder, getLoc(), slot.elemType); } std::optional<PromotableAllocationOpInterface> diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir index d300699f6f342..dd68675cc4441 100644 --- a/mlir/test/Dialect/MemRef/mem2reg.mlir +++ b/mlir/test/Dialect/MemRef/mem2reg.mlir @@ -18,7 +18,7 @@ func.func @basic() -> i32 { // CHECK-LABEL: func.func @basic_default func.func @basic_default() -> i32 { // CHECK-NOT: = memref.alloca - // CHECK: %[[RES:.*]] = arith.constant 0 : i32 + // CHECK: %[[RES:.*]] = ub.poison : i32 // CHECK-NOT: = memref.alloca %0 = arith.constant 5 : i32 %1 = memref.alloca() : memref<i32> 
other dialect or domain abstraction.
}];
let dependentDialects = ["arith::ArithDialect"];
let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you split the line and add a comment describing why each of these are there?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

3 participants