andykaylor
Contributor

This adds support for handling exact dynamic casts when optimizations are enabled.
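
For readers unfamiliar with the term: an "exact" dynamic cast is the case where the destination class is final (the diff gates this path on isEffectivelyFinal() and an optimization level above zero), so the cast can only succeed when the object's dynamic type is exactly that class. The following is a minimal source-level sketch of that idea, not the code the compiler emits. It uses typeid as a portable stand-in for the vptr comparison, and the names Base, Derived, Other, exactCastSketch, and checkExactCastSketch are invented for illustration.

```cpp
#include <cassert>
#include <typeinfo>

// Hypothetical class hierarchy, for illustration only.
struct Base { virtual ~Base() = default; };
struct Derived final : Base {};
struct Other : Base {};

// Hand-written equivalent of dynamic_cast<Derived *>(ptr) when Derived is
// final: a null check, an exact-type check, then a plain downcast.
// typeid stands in for the vptr comparison done by the generated code.
inline Derived *exactCastSketch(Base *ptr) {
  if (!ptr)
    return nullptr;
  if (typeid(*ptr) != typeid(Derived))
    return nullptr;
  return static_cast<Derived *>(ptr);
}

// Small self-check exercising the three outcomes.
inline void checkExactCastSketch() {
  Derived d;
  Other o;
  assert(exactCastSketch(&d) == &d);       // exact type: succeeds
  assert(exactCastSketch(&o) == nullptr);  // different type: fails
  assert(exactCastSketch(nullptr) == nullptr);
}
```

For a reference cast there is no null result to return, so the failure branch calls __cxa_bad_cast instead, as the ref_cast test in the diff shows.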

@llvmbot added the labels clang (Clang issues not falling into any other category) and ClangIR (Anything related to the ClangIR project) on Oct 17, 2025
@llvmbot
Member

llvmbot commented Oct 17, 2025

@llvm/pr-subscribers-clang

Author: Andy Kaylor (andykaylor)

Changes

This adds support for handling exact dynamic casts when optimizations are enabled.


Full diff: https://github.com/llvm/llvm-project/pull/164007.diff

5 Files Affected:

  • (modified) clang/lib/CIR/CodeGen/CIRGenCall.cpp (+16)
  • (modified) clang/lib/CIR/CodeGen/CIRGenFunction.h (+3)
  • (modified) clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp (+160-2)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+9)
  • (added) clang/test/CIR/CodeGen/dynamic-cast-exact.cpp (+114)
diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
index 61072f0883728..88aef89ddd2b9 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
@@ -690,6 +690,22 @@ void CallArg::copyInto(CIRGenFunction &cgf, Address addr,
   isUsed = true;
 }
 
+mlir::Value CIRGenFunction::emitRuntimeCall(mlir::Location loc,
+                                            cir::FuncOp callee,
+                                            ArrayRef<mlir::Value> args) {
+  // TODO(cir): set the calling convention to this runtime call.
+  assert(!cir::MissingFeatures::opFuncCallingConv());
+
+  cir::CallOp call = builder.createCallOp(loc, callee, args);
+  assert(call->getNumResults() <= 1 &&
+         "runtime functions have at most 1 result");
+
+  if (call->getNumResults() == 0)
+    return nullptr;
+
+  return call->getResult(0);
+}
+
 void CIRGenFunction::emitCallArg(CallArgList &args, const clang::Expr *e,
                                  clang::QualType argType) {
   assert(argType->isReferenceType() == e->isGLValue() &&
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index 3c36f5c697118..84b4ba293b3aa 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -1380,6 +1380,9 @@ class CIRGenFunction : public CIRGenTypeCache {
 
   void emitReturnOfRValue(mlir::Location loc, RValue rv, QualType ty);
 
+  mlir::Value emitRuntimeCall(mlir::Location loc, cir::FuncOp callee,
+                              llvm::ArrayRef<mlir::Value> args = {});
+
   /// Emit the computation of the specified expression of scalar type.
   mlir::Value emitScalarExpr(const clang::Expr *e);
 
diff --git a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
index d54d2e9cb29e5..ef91288ab6155 100644
--- a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
@@ -1869,6 +1869,15 @@ static cir::FuncOp getBadCastFn(CIRGenFunction &cgf) {
   return cgf.cgm.createRuntimeFunction(fnTy, "__cxa_bad_cast");
 }
 
+static void emitCallToBadCast(CIRGenFunction &cgf, mlir::Location loc) {
+  // TODO(cir): set the calling convention to the runtime function.
+  assert(!cir::MissingFeatures::opFuncCallingConv());
+
+  cgf.emitRuntimeCall(loc, getBadCastFn(cgf));
+  cir::UnreachableOp::create(cgf.getBuilder(), loc);
+  cgf.getBuilder().clearInsertionPoint();
+}
+
 // TODO(cir): This could be shared with classic codegen.
 static CharUnits computeOffsetHint(ASTContext &astContext,
                                    const CXXRecordDecl *src,
@@ -1954,6 +1963,136 @@ static Address emitDynamicCastToVoid(CIRGenFunction &cgf, mlir::Location loc,
   return Address{ptr, src.getAlignment()};
 }
 
+static mlir::Value emitExactDynamicCast(CIRGenItaniumCXXABI &abi,
+                                        CIRGenFunction &cgf, mlir::Location loc,
+                                        QualType srcRecordTy,
+                                        QualType destRecordTy,
+                                        cir::PointerType destCIRTy,
+                                        bool isRefCast, Address src) {
+  // Find all the inheritance paths from SrcRecordTy to DestRecordTy.
+  const CXXRecordDecl *srcDecl = srcRecordTy->getAsCXXRecordDecl();
+  const CXXRecordDecl *destDecl = destRecordTy->getAsCXXRecordDecl();
+  CXXBasePaths paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true,
+                     /*DetectVirtual=*/false);
+  (void)destDecl->isDerivedFrom(srcDecl, paths);
+
+  // Find an offset within `destDecl` where a `srcDecl` instance and its vptr
+  // might appear.
+  std::optional<CharUnits> offset;
+  for (const CXXBasePath &path : paths) {
+    // dynamic_cast only finds public inheritance paths.
+    if (path.Access != AS_public)
+      continue;
+
+    CharUnits pathOffset;
+    for (const CXXBasePathElement &pathElement : path) {
+      // Find the offset along this inheritance step.
+      const CXXRecordDecl *base =
+          pathElement.Base->getType()->getAsCXXRecordDecl();
+      if (pathElement.Base->isVirtual()) {
+        // For a virtual base class, we know that the derived class is exactly
+        // destDecl, so we can use the vbase offset from its layout.
+        const ASTRecordLayout &layout =
+            cgf.getContext().getASTRecordLayout(destDecl);
+        pathOffset = layout.getVBaseClassOffset(base);
+      } else {
+        const ASTRecordLayout &layout =
+            cgf.getContext().getASTRecordLayout(pathElement.Class);
+        pathOffset += layout.getBaseClassOffset(base);
+      }
+    }
+
+    if (!offset) {
+      offset = pathOffset;
+    } else if (offset != pathOffset) {
+      // base appears in at least two different places. Find the most-derived
+      // object and see if it's a DestDecl. Note that the most-derived object
+      // must be at least as aligned as this base class subobject, and must
+      // have a vptr at offset 0.
+      src = emitDynamicCastToVoid(cgf, loc, srcRecordTy, src);
+      srcDecl = destDecl;
+      offset = CharUnits::Zero();
+      break;
+    }
+  }
+
+  CIRGenBuilderTy &builder = cgf.getBuilder();
+
+  if (!offset) {
+    // If there are no public inheritance paths, the cast always fails.
+    mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
+    if (isRefCast) {
+      mlir::Region *currentRegion = builder.getBlock()->getParent();
+      emitCallToBadCast(cgf, loc);
+
+      // The call to bad_cast will terminate the block. Create a new block to
+      // hold any follow up code.
+      builder.createBlock(currentRegion, currentRegion->end());
+    }
+
+    return nullPtrValue;
+  }
+
+  // Compare the vptr against the expected vptr for the destination type at
+  // this offset. Note that we do not know what type src points to in the case
+  // where the derived class multiply inherits from the base class so we can't
+  // use getVTablePtr, so we load the vptr directly instead.
+
+  mlir::Value expectedVPtr =
+      abi.getVTableAddressPoint(BaseSubobject(srcDecl, *offset), destDecl);
+
+  // TODO(cir): handle address space here.
+  assert(!cir::MissingFeatures::addressSpace());
+  mlir::Type vptrTy = expectedVPtr.getType();
+  mlir::Type vptrPtrTy = builder.getPointerTo(vptrTy);
+  Address srcVPtrPtr(builder.createBitcast(src.getPointer(), vptrPtrTy),
+                     src.getAlignment());
+  mlir::Value srcVPtr = builder.createLoad(loc, srcVPtrPtr);
+
+  // TODO(cir): decorate SrcVPtr with TBAA info.
+  assert(!cir::MissingFeatures::opTBAA());
+
+  mlir::Value success =
+      builder.createCompare(loc, cir::CmpOpKind::eq, srcVPtr, expectedVPtr);
+
+  auto emitCastResult = [&] {
+    if (offset->isZero())
+      return builder.createBitcast(src.getPointer(), destCIRTy);
+
+    // TODO(cir): handle address space here.
+    assert(!cir::MissingFeatures::addressSpace());
+    mlir::Type u8PtrTy = builder.getUInt8PtrTy();
+
+    mlir::Value strideToApply =
+        builder.getConstInt(loc, builder.getUInt64Ty(), offset->getQuantity());
+    mlir::Value srcU8Ptr = builder.createBitcast(src.getPointer(), u8PtrTy);
+    mlir::Value resultU8Ptr = cir::PtrStrideOp::create(builder, loc, u8PtrTy,
+                                                       srcU8Ptr, strideToApply);
+    return builder.createBitcast(resultU8Ptr, destCIRTy);
+  };
+
+  if (isRefCast) {
+    mlir::Value failed = builder.createNot(success);
+    cir::IfOp::create(builder, loc, failed, /*withElseRegion=*/false,
+                      [&](mlir::OpBuilder &, mlir::Location) {
+                        emitCallToBadCast(cgf, loc);
+                      });
+    return emitCastResult();
+  }
+
+  return cir::TernaryOp::create(
+             builder, loc, success,
+             [&](mlir::OpBuilder &, mlir::Location) {
+               auto result = emitCastResult();
+               builder.createYield(loc, result);
+             },
+             [&](mlir::OpBuilder &, mlir::Location) {
+               mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
+               builder.createYield(loc, nullPtrValue);
+             })
+      .getResult();
+}
+
 static cir::DynamicCastInfoAttr emitDynamicCastInfo(CIRGenFunction &cgf,
                                                     mlir::Location loc,
                                                     QualType srcRecordTy,
@@ -1995,8 +2134,27 @@ mlir::Value CIRGenItaniumCXXABI::emitDynamicCast(CIRGenFunction &cgf,
   // if the dynamic type of the pointer is exactly the destination type.
   if (destRecordTy->getAsCXXRecordDecl()->isEffectivelyFinal() &&
       cgf.cgm.getCodeGenOpts().OptimizationLevel > 0) {
-    cgm.errorNYI(loc, "emitExactDynamicCast");
-    return {};
+    CIRGenBuilderTy &builder = cgf.getBuilder();
+    // If this isn't a reference cast, check the pointer to see if it's null.
+    if (!isRefCast) {
+      mlir::Value srcPtrIsNull = builder.createPtrIsNull(src.getPointer());
+      return cir::TernaryOp::create(
+                 builder, loc, srcPtrIsNull,
+                 [&](mlir::OpBuilder, mlir::Location) {
+                   builder.createYield(
+                       loc, builder.getNullPtr(destCIRTy, loc).getResult());
+                 },
+                 [&](mlir::OpBuilder &, mlir::Location) {
+                   mlir::Value exactCast = emitExactDynamicCast(
+                       *this, cgf, loc, srcRecordTy, destRecordTy, destCIRTy,
+                       isRefCast, src);
+                   builder.createYield(loc, exactCast);
+                 })
+          .getResult();
+    }
+
+    return emitExactDynamicCast(*this, cgf, loc, srcRecordTy, destRecordTy,
+                                destCIRTy, isRefCast, src);
   }
 
   cir::DynamicCastInfoAttr castInfo =
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 0243bf120f396..51dba33338cd6 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -2223,6 +2223,15 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
     return mlir::success();
   }
 
+  if (auto vptrTy = mlir::dyn_cast<cir::VPtrType>(type)) {
+    // !cir.vptr is a special case, but it's just a pointer to LLVM.
+    auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(),
+                                              /* isSigned=*/false);
+    rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
+        cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
+    return mlir::success();
+  }
+
   if (mlir::isa<cir::FPTypeInterface>(type)) {
     mlir::LLVM::FCmpPredicate kind =
         convertCmpKindToFCmpPredicate(cmpOp.getKind());
diff --git a/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp b/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp
new file mode 100644
index 0000000000000..41a70ce53db5e
--- /dev/null
+++ b/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp
@@ -0,0 +1,114 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -clangir-disable-passes -emit-cir -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -emit-llvm -o %t-cir.ll %s
+// RUN: FileCheck --input-file=%t-cir.ll --check-prefix=LLVM %s
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -emit-llvm -o %t.ll %s
+// RUN: FileCheck --input-file=%t.ll --check-prefix=OGCG %s
+
+struct Base1 {
+  virtual ~Base1();
+};
+
+struct Base2 {
+  virtual ~Base2();
+};
+
+struct Derived final : Base1 {};
+
+Derived *ptr_cast(Base1 *ptr) {
+  return dynamic_cast<Derived *>(ptr);
+}
+
+// CIR: cir.func {{.*}} @_Z8ptr_castP5Base1
+// CIR: %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
+// CIR-NEXT: %[[NULL_PTR:.*]] = cir.const #cir.ptr<null>
+// CIR-NEXT: %[[SRC_IS_NULL:.*]] = cir.cmp(eq, %[[SRC]], %[[NULL_PTR]])
+// CIR-NEXT: %[[RESULT:.*]] = cir.ternary(%[[SRC_IS_NULL]], true {
+// CIR-NEXT: %[[NULL_PTR_DEST:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
+// CIR-NEXT: cir.yield %[[NULL_PTR_DEST]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT: }, false {
+// CIR-NEXT: %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
+// CIR-NEXT: %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
+// CIR-NEXT: %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
+// CIR-NEXT: %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
+// CIR-NEXT: %[[EXACT_RESULT:.*]] = cir.ternary(%[[SUCCESS]], true {
+// CIR-NEXT: %[[RES:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
+// CIR-NEXT: cir.yield %[[RES]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT: }, false {
+// CIR-NEXT: %[[NULL:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
+// CIR-NEXT: cir.yield %[[NULL]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT: }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
+// CIR-NEXT: cir.yield %[[EXACT_RESULT]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT: }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
+
+// Note: The LLVM output omits the label for the entry block (which is
+// implicitly %1), so we use %{{.*}} to match the implicit label in the
+// phi check.
+
+// LLVM: define dso_local ptr @_Z8ptr_castP5Base1(ptr{{.*}} %[[SRC:.*]])
+// LLVM-NEXT: %[[SRC_IS_NULL:.*]] = icmp eq ptr %0, null
+// LLVM-NEXT: br i1 %[[SRC_IS_NULL]], label %[[LABEL_END:.*]], label %[[LABEL_NOTNULL:.*]]
+// LLVM: [[LABEL_NOTNULL]]:
+// LLVM-NEXT: %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8
+// LLVM-NEXT: %[[SUCCESS:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
+// LLVM-NEXT: %[[EXACT_RESULT:.*]] = select i1 %[[SUCCESS]], ptr %[[SRC]], ptr null
+// LLVM-NEXT: br label %[[LABEL_END]]
+// LLVM: [[LABEL_END]]:
+// LLVM-NEXT: %[[RESULT:.*]] = phi ptr [ %[[EXACT_RESULT]], %[[LABEL_NOTNULL]] ], [ null, %{{.*}} ]
+// LLVM-NEXT: ret ptr %[[RESULT]]
+// LLVM-NEXT: }
+
+// OGCG: define{{.*}} ptr @_Z8ptr_castP5Base1(ptr {{.*}} %[[SRC:.*]])
+// OGCG-NEXT: entry:
+// OGCG-NEXT: %[[NULL_CHECK:.*]] = icmp eq ptr %[[SRC]], null
+// OGCG-NEXT: br i1 %[[NULL_CHECK]], label %[[LABEL_NULL:.*]], label %[[LABEL_NOTNULL:.*]]
+// OGCG: [[LABEL_NOTNULL]]:
+// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[SRC]], align 8
+// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
+// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL]]
+// OGCG: [[LABEL_NULL]]:
+// OGCG-NEXT: br label %[[LABEL_END]]
+// OGCG: [[LABEL_END]]:
+// OGCG-NEXT: %[[RESULT:.*]] = phi ptr [ %[[SRC]], %[[LABEL_NOTNULL]] ], [ null, %[[LABEL_NULL]] ]
+// OGCG-NEXT: ret ptr %[[RESULT]]
+// OGCG-NEXT: }
+
+Derived &ref_cast(Base1 &ref) {
+  return dynamic_cast<Derived &>(ref);
+}
+
+// CIR: cir.func {{.*}} @_Z8ref_castR5Base1
+// CIR: %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
+// CIR-NEXT: %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
+// CIR-NEXT: %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
+// CIR-NEXT: %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
+// CIR-NEXT: %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
+// CIR-NEXT: %[[FAILED:.*]] = cir.unary(not, %[[SUCCESS]]) : !cir.bool, !cir.bool
+// CIR-NEXT: cir.if %[[FAILED]] {
+// CIR-NEXT: cir.call @__cxa_bad_cast() : () -> ()
+// CIR-NEXT: cir.unreachable
+// CIR-NEXT: }
+// CIR-NEXT: %{{.+}} = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
+
+// LLVM: define{{.*}} ptr @_Z8ref_castR5Base1(ptr{{.*}} %[[SRC:.*]])
+// LLVM-NEXT: %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8
+// LLVM-NEXT: %[[OK:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
+// LLVM-NEXT: br i1 %[[OK]], label %[[LABEL_OK:.*]], label %[[LABEL_FAIL:.*]]
+// LLVM: [[LABEL_FAIL]]:
+// LLVM-NEXT: tail call void @__cxa_bad_cast()
+// LLVM-NEXT: unreachable
+// LLVM: [[LABEL_OK]]:
+// LLVM-NEXT: ret ptr %[[SRC]]
+// LLVM-NEXT: }
+
+// OGCG: define{{.*}} ptr @_Z8ref_castR5Base1(ptr {{.*}} %[[REF:.*]])
+// OGCG-NEXT: entry:
+// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[REF]], align 8
+// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
+// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL:.*]]
+// OGCG: [[LABEL_NULL]]:
+// OGCG-NEXT: {{.*}}call void @__cxa_bad_cast()
+// OGCG-NEXT: unreachable
+// OGCG: [[LABEL_END]]:
+// OGCG-NEXT: ret ptr %[[REF]]
+// OGCG-NEXT: }
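
The "base appears in at least two different places" branch of emitExactDynamicCast can be pictured with a small source-level example. In the sketch below, Base is a non-virtual base of a final class along two paths, so a Base pointer alone does not say which offset applies. dynamic_cast<void *> recovers the most-derived object, which is roughly the role the call to emitDynamicCastToVoid plays in the diff before the vptr at offset zero is compared. The class names Left, Right, and Both and the helpers mostDerived and checkTwoBaseSubobjects are hypothetical and do not come from the test file.

```cpp
#include <cassert>

// Hypothetical hierarchy: Base appears twice inside Both, once via each path.
struct Base { virtual ~Base() = default; };
struct Left : Base {};
struct Right : Base {};
struct Both final : Left, Right {};

// dynamic_cast to (const) void * yields a pointer to the most-derived object,
// regardless of which Base subobject the argument points at.
inline const void *mostDerived(const Base *p) {
  return dynamic_cast<const void *>(p);
}

// Small self-check: the two Base subobjects live at different addresses, but
// both lead back to the same most-derived object.
inline void checkTwoBaseSubobjects() {
  Both b;
  const Base *viaLeft = static_cast<const Left *>(&b);
  const Base *viaRight = static_cast<const Right *>(&b);
  assert(viaLeft != viaRight);
  assert(mostDerived(viaLeft) == static_cast<const void *>(&b));
  assert(mostDerived(viaRight) == static_cast<const void *>(&b));
}
```

Once the pointer refers to the most-derived object, a single comparison against the primary vtable address point of the destination class is enough, which is why the diff resets the offset to zero in that branch.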
@llvmbot
Member

llvmbot commented Oct 17, 2025

@llvm/pr-subscribers-clangir

@ojhunt
Contributor

ojhunt commented Oct 17, 2025

The class traversal logic looks pretty much identical to the non-CIR final-class dynamic_cast optimization (CGCXXABI::getExactDynamicCastInfo). Is it possible to reuse that logic, maybe extending it if additional information is needed? (Also, while trying to find that code again I found yet another near copy of the same logic. sigh :D)
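
For context, the traversal in question boils down to: walk every public inheritance path, accumulate an offset per path, and fall back to a cast to the most-derived object when two paths disagree. A self-contained toy model of that logic is sketched below. It uses plain structs in place of CXXBasePaths and ASTRecordLayout, and every type and function name in it (PathStep, Path, ExactCastPlan, planExactCast, checkPlanExactCast) is hypothetical rather than an existing Clang API.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// One inheritance step along a path (hypothetical model, not a Clang type).
struct PathStep {
  bool isVirtual;                  // step goes through a virtual base
  std::int64_t offsetInClass;      // non-virtual base offset in the derived class
  std::int64_t vbaseOffsetInDest;  // vbase offset in the final destination class
};

// One complete path from the destination class down to the source class.
struct Path {
  bool isPublic;
  std::vector<PathStep> steps;
};

// What an emitter would need to know to build the exact cast.
struct ExactCastPlan {
  std::int64_t offset;
  bool castToMostDerivedFirst;
};

// Returns std::nullopt when no public path exists (the cast always fails).
inline std::optional<ExactCastPlan> planExactCast(const std::vector<Path> &paths) {
  std::optional<std::int64_t> offset;
  for (const Path &path : paths) {
    if (!path.isPublic)
      continue;  // dynamic_cast only considers public paths
    std::int64_t pathOffset = 0;
    for (const PathStep &step : path.steps) {
      if (step.isVirtual)
        pathOffset = step.vbaseOffsetInDest;  // virtual base: restart from dest layout
      else
        pathOffset += step.offsetInClass;
    }
    if (!offset)
      offset = pathOffset;
    else if (*offset != pathOffset)
      return ExactCastPlan{0, true};  // base at two offsets: go via most-derived object
  }
  if (!offset)
    return std::nullopt;
  return ExactCastPlan{*offset, false};
}

// Small self-check covering the "no public path" and "single path" cases.
inline void checkPlanExactCast() {
  assert(!planExactCast({}).has_value());
  auto single = planExactCast({Path{true, {PathStep{false, 8, 0}}}});
  assert(single && single->offset == 8 && !single->castToMostDerivedFirst);
}
```

A result shaped like ExactCastPlan is one possible form for whatever shared helper the two code generators might call; which fields are really needed is the open question in this thread.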

@andykaylor
Contributor Author

> The class traversal logic looks pretty much identical to the non-CIR final-class dynamic_cast optimization (CGCXXABI::getExactDynamicCastInfo). Is it possible to reuse that logic, maybe extending it if additional information is needed? (Also, while trying to find that code again I found yet another near copy of the same logic. sigh :D)

Unfortunately, this is a recurring problem with CIR codegen. We've tried to keep the basic code structure as close to LLVM IR codegen as is practical. There are, of course, many places where we could be sharing the actual code for things that don't involve the target substrate. We've implemented a few of these, and for others we've just left comments saying that it should be done. We're trying to find the right balance between making fast progress on the upstreaming versus accumulating technical debt and baking in fragility.

Is the other "near copy" you were referring to in computeOffsetHint?

Where would you suggest moving this logic? ASTRecordLayout maybe?

@ojhunt
Contributor

ojhunt commented Oct 18, 2025

> > The class traversal logic looks pretty much identical to the non-CIR final-class dynamic_cast optimization (CGCXXABI::getExactDynamicCastInfo). Is it possible to reuse that logic, maybe extending it if additional information is needed? (Also, while trying to find that code again I found yet another near copy of the same logic. sigh :D)

> Unfortunately, this is a recurring problem with CIR codegen. We've tried to keep the basic code structure as close to LLVM IR codegen as is practical. There are, of course, many places where we could be sharing the actual code for things that don't involve the target substrate. We've implemented a few of these, and for others we've just left comments saying that it should be done. We're trying to find the right balance between making fast progress on the upstreaming versus accumulating technical debt and baking in fragility.

> Is the other "near copy" you were referring to in computeOffsetHint?

Possibly - I just noticed it when I searched for base traversal while trying to find the accursed function :D

> Where would you suggest moving this logic? ASTRecordLayout maybe?

Ah, do you not have access to the ABI interface in CIR? (This is just an "Oliver doesn't know where the abstraction boundaries are" question)

Anyway, I might have time this weekend to look into moving it into ASTRecordLayout - would you mind looking at the DynamicCastInfo struct and just noting what additional data you need?

