[CIR] Add support for exact dynamic casts #164007
This adds support for handling exact dynamic casts when optimizations are enabled.
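For context, the source pattern this optimization targets, taken from the test added in this patch: a `dynamic_cast` whose destination class is `final`, which lets the compiler replace the `__dynamic_cast` runtime call with a null check plus a single vptr comparison.

```cpp
struct Base1 {
  virtual ~Base1();
};

// Because Derived is final, dynamic_cast<Derived *>(p) can only succeed
// when the dynamic type of *p is exactly Derived, so at -O1 and above no
// runtime library call is needed.
struct Derived final : Base1 {};

Derived *ptr_cast(Base1 *ptr) {
  return dynamic_cast<Derived *>(ptr);
}
```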
@llvm/pr-subscribers-clang

Author: Andy Kaylor (andykaylor)

Changes

This adds support for handling exact dynamic casts when optimizations are enabled.

Full diff: https://github.com/llvm/llvm-project/pull/164007.diff

5 Files Affected:

- clang/lib/CIR/CodeGen/CIRGenCall.cpp (modified)
- clang/lib/CIR/CodeGen/CIRGenFunction.h (modified)
- clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp (modified)
- clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (modified)
- clang/test/CIR/CodeGen/dynamic-cast-exact.cpp (added)

```diff
diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
index 61072f0883728..88aef89ddd2b9 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
@@ -690,6 +690,22 @@ void CallArg::copyInto(CIRGenFunction &cgf, Address addr,
   isUsed = true;
 }
 
+mlir::Value CIRGenFunction::emitRuntimeCall(mlir::Location loc,
+                                            cir::FuncOp callee,
+                                            ArrayRef<mlir::Value> args) {
+  // TODO(cir): set the calling convention to this runtime call.
+  assert(!cir::MissingFeatures::opFuncCallingConv());
+
+  cir::CallOp call = builder.createCallOp(loc, callee, args);
+  assert(call->getNumResults() <= 1 &&
+         "runtime functions have at most 1 result");
+
+  if (call->getNumResults() == 0)
+    return nullptr;
+
+  return call->getResult(0);
+}
+
 void CIRGenFunction::emitCallArg(CallArgList &args, const clang::Expr *e,
                                  clang::QualType argType) {
   assert(argType->isReferenceType() == e->isGLValue() &&
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index 3c36f5c697118..84b4ba293b3aa 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -1380,6 +1380,9 @@ class CIRGenFunction : public CIRGenTypeCache {
 
   void emitReturnOfRValue(mlir::Location loc, RValue rv, QualType ty);
 
+  mlir::Value emitRuntimeCall(mlir::Location loc, cir::FuncOp callee,
+                              llvm::ArrayRef<mlir::Value> args = {});
+
   /// Emit the computation of the specified expression of scalar type.
   mlir::Value emitScalarExpr(const clang::Expr *e);
 
diff --git a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
index d54d2e9cb29e5..ef91288ab6155 100644
--- a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
@@ -1869,6 +1869,15 @@ static cir::FuncOp getBadCastFn(CIRGenFunction &cgf) {
   return cgf.cgm.createRuntimeFunction(fnTy, "__cxa_bad_cast");
 }
 
+static void emitCallToBadCast(CIRGenFunction &cgf, mlir::Location loc) {
+  // TODO(cir): set the calling convention to the runtime function.
+  assert(!cir::MissingFeatures::opFuncCallingConv());
+
+  cgf.emitRuntimeCall(loc, getBadCastFn(cgf));
+  cir::UnreachableOp::create(cgf.getBuilder(), loc);
+  cgf.getBuilder().clearInsertionPoint();
+}
+
 // TODO(cir): This could be shared with classic codegen.
 static CharUnits computeOffsetHint(ASTContext &astContext,
                                    const CXXRecordDecl *src,
@@ -1954,6 +1963,136 @@ static Address emitDynamicCastToVoid(CIRGenFunction &cgf, mlir::Location loc,
   return Address{ptr, src.getAlignment()};
 }
 
+static mlir::Value emitExactDynamicCast(CIRGenItaniumCXXABI &abi,
+                                        CIRGenFunction &cgf, mlir::Location loc,
+                                        QualType srcRecordTy,
+                                        QualType destRecordTy,
+                                        cir::PointerType destCIRTy,
+                                        bool isRefCast, Address src) {
+  // Find all the inheritance paths from SrcRecordTy to DestRecordTy.
+  const CXXRecordDecl *srcDecl = srcRecordTy->getAsCXXRecordDecl();
+  const CXXRecordDecl *destDecl = destRecordTy->getAsCXXRecordDecl();
+  CXXBasePaths paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true,
+                     /*DetectVirtual=*/false);
+  (void)destDecl->isDerivedFrom(srcDecl, paths);
+
+  // Find an offset within `destDecl` where a `srcDecl` instance and its vptr
+  // might appear.
+  std::optional<CharUnits> offset;
+  for (const CXXBasePath &path : paths) {
+    // dynamic_cast only finds public inheritance paths.
+    if (path.Access != AS_public)
+      continue;
+
+    CharUnits pathOffset;
+    for (const CXXBasePathElement &pathElement : path) {
+      // Find the offset along this inheritance step.
+      const CXXRecordDecl *base =
+          pathElement.Base->getType()->getAsCXXRecordDecl();
+      if (pathElement.Base->isVirtual()) {
+        // For a virtual base class, we know that the derived class is exactly
+        // destDecl, so we can use the vbase offset from its layout.
+        const ASTRecordLayout &layout =
+            cgf.getContext().getASTRecordLayout(destDecl);
+        pathOffset = layout.getVBaseClassOffset(base);
+      } else {
+        const ASTRecordLayout &layout =
+            cgf.getContext().getASTRecordLayout(pathElement.Class);
+        pathOffset += layout.getBaseClassOffset(base);
+      }
+    }
+
+    if (!offset) {
+      offset = pathOffset;
+    } else if (offset != pathOffset) {
+      // base appears in at least two different places. Find the most-derived
+      // object and see if it's a DestDecl. Note that the most-derived object
+      // must be at least as aligned as this base class subobject, and must
+      // have a vptr at offset 0.
+      src = emitDynamicCastToVoid(cgf, loc, srcRecordTy, src);
+      srcDecl = destDecl;
+      offset = CharUnits::Zero();
+      break;
+    }
+  }
+
+  CIRGenBuilderTy &builder = cgf.getBuilder();
+
+  if (!offset) {
+    // If there are no public inheritance paths, the cast always fails.
+    mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
+    if (isRefCast) {
+      mlir::Region *currentRegion = builder.getBlock()->getParent();
+      emitCallToBadCast(cgf, loc);
+
+      // The call to bad_cast will terminate the block. Create a new block to
+      // hold any follow up code.
+      builder.createBlock(currentRegion, currentRegion->end());
+    }
+
+    return nullPtrValue;
+  }
+
+  // Compare the vptr against the expected vptr for the destination type at
+  // this offset. Note that we do not know what type src points to in the case
+  // where the derived class multiply inherits from the base class so we can't
+  // use getVTablePtr, so we load the vptr directly instead.
+
+  mlir::Value expectedVPtr =
+      abi.getVTableAddressPoint(BaseSubobject(srcDecl, *offset), destDecl);
+
+  // TODO(cir): handle address space here.
+  assert(!cir::MissingFeatures::addressSpace());
+  mlir::Type vptrTy = expectedVPtr.getType();
+  mlir::Type vptrPtrTy = builder.getPointerTo(vptrTy);
+  Address srcVPtrPtr(builder.createBitcast(src.getPointer(), vptrPtrTy),
+                     src.getAlignment());
+  mlir::Value srcVPtr = builder.createLoad(loc, srcVPtrPtr);
+
+  // TODO(cir): decorate SrcVPtr with TBAA info.
+  assert(!cir::MissingFeatures::opTBAA());
+
+  mlir::Value success =
+      builder.createCompare(loc, cir::CmpOpKind::eq, srcVPtr, expectedVPtr);
+
+  auto emitCastResult = [&] {
+    if (offset->isZero())
+      return builder.createBitcast(src.getPointer(), destCIRTy);
+
+    // TODO(cir): handle address space here.
+    assert(!cir::MissingFeatures::addressSpace());
+    mlir::Type u8PtrTy = builder.getUInt8PtrTy();
+
+    mlir::Value strideToApply =
+        builder.getConstInt(loc, builder.getUInt64Ty(), offset->getQuantity());
+    mlir::Value srcU8Ptr = builder.createBitcast(src.getPointer(), u8PtrTy);
+    mlir::Value resultU8Ptr = cir::PtrStrideOp::create(builder, loc, u8PtrTy,
+                                                       srcU8Ptr, strideToApply);
+    return builder.createBitcast(resultU8Ptr, destCIRTy);
+  };
+
+  if (isRefCast) {
+    mlir::Value failed = builder.createNot(success);
+    cir::IfOp::create(builder, loc, failed, /*withElseRegion=*/false,
+                      [&](mlir::OpBuilder &, mlir::Location) {
+                        emitCallToBadCast(cgf, loc);
+                      });
+    return emitCastResult();
+  }
+
+  return cir::TernaryOp::create(
+             builder, loc, success,
+             [&](mlir::OpBuilder &, mlir::Location) {
+               auto result = emitCastResult();
+               builder.createYield(loc, result);
+             },
+             [&](mlir::OpBuilder &, mlir::Location) {
+               mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
+               builder.createYield(loc, nullPtrValue);
+             })
+      .getResult();
+}
+
 static cir::DynamicCastInfoAttr emitDynamicCastInfo(CIRGenFunction &cgf,
                                                     mlir::Location loc,
                                                     QualType srcRecordTy,
@@ -1995,8 +2134,27 @@ mlir::Value CIRGenItaniumCXXABI::emitDynamicCast(CIRGenFunction &cgf,
   // if the dynamic type of the pointer is exactly the destination type.
   if (destRecordTy->getAsCXXRecordDecl()->isEffectivelyFinal() &&
       cgf.cgm.getCodeGenOpts().OptimizationLevel > 0) {
-    cgm.errorNYI(loc, "emitExactDynamicCast");
-    return {};
+    CIRGenBuilderTy &builder = cgf.getBuilder();
+    // If this isn't a reference cast, check the pointer to see if it's null.
+    if (!isRefCast) {
+      mlir::Value srcPtrIsNull = builder.createPtrIsNull(src.getPointer());
+      return cir::TernaryOp::create(
+                 builder, loc, srcPtrIsNull,
+                 [&](mlir::OpBuilder, mlir::Location) {
+                   builder.createYield(
+                       loc, builder.getNullPtr(destCIRTy, loc).getResult());
+                 },
+                 [&](mlir::OpBuilder &, mlir::Location) {
+                   mlir::Value exactCast = emitExactDynamicCast(
+                       *this, cgf, loc, srcRecordTy, destRecordTy, destCIRTy,
+                       isRefCast, src);
+                   builder.createYield(loc, exactCast);
+                 })
+          .getResult();
+    }
+
+    return emitExactDynamicCast(*this, cgf, loc, srcRecordTy, destRecordTy,
+                                destCIRTy, isRefCast, src);
   }
 
   cir::DynamicCastInfoAttr castInfo =
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 0243bf120f396..51dba33338cd6 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -2223,6 +2223,15 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
     return mlir::success();
   }
 
+  if (auto vptrTy = mlir::dyn_cast<cir::VPtrType>(type)) {
+    // !cir.vptr is a special case, but it's just a pointer to LLVM.
+    auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(),
+                                              /*isSigned=*/false);
+    rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
+        cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
+    return mlir::success();
+  }
+
   if (mlir::isa<cir::FPTypeInterface>(type)) {
     mlir::LLVM::FCmpPredicate kind =
         convertCmpKindToFCmpPredicate(cmpOp.getKind());
diff --git a/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp b/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp
new file mode 100644
index 0000000000000..41a70ce53db5e
--- /dev/null
+++ b/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp
@@ -0,0 +1,114 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -clangir-disable-passes -emit-cir -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -emit-llvm -o %t-cir.ll %s
+// RUN: FileCheck --input-file=%t-cir.ll --check-prefix=LLVM %s
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -emit-llvm -o %t.ll %s
+// RUN: FileCheck --input-file=%t.ll --check-prefix=OGCG %s
+
+struct Base1 {
+  virtual ~Base1();
+};
+
+struct Base2 {
+  virtual ~Base2();
+};
+
+struct Derived final : Base1 {};
+
+Derived *ptr_cast(Base1 *ptr) {
+  return dynamic_cast<Derived *>(ptr);
+}
+
+// CIR: cir.func {{.*}} @_Z8ptr_castP5Base1
+// CIR: %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
+// CIR-NEXT: %[[NULL_PTR:.*]] = cir.const #cir.ptr<null>
+// CIR-NEXT: %[[SRC_IS_NULL:.*]] = cir.cmp(eq, %[[SRC]], %[[NULL_PTR]])
+// CIR-NEXT: %[[RESULT:.*]] = cir.ternary(%[[SRC_IS_NULL]], true {
+// CIR-NEXT: %[[NULL_PTR_DEST:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
+// CIR-NEXT: cir.yield %[[NULL_PTR_DEST]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT: }, false {
+// CIR-NEXT: %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
+// CIR-NEXT: %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
+// CIR-NEXT: %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
+// CIR-NEXT: %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
+// CIR-NEXT: %[[EXACT_RESULT:.*]] = cir.ternary(%[[SUCCESS]], true {
+// CIR-NEXT: %[[RES:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
+// CIR-NEXT: cir.yield %[[RES]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT: }, false {
+// CIR-NEXT: %[[NULL:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
+// CIR-NEXT: cir.yield %[[NULL]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT: }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
+// CIR-NEXT: cir.yield %[[EXACT_RESULT]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT: }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
+
+// Note: The LLVM output omits the label for the entry block (which is
+// implicitly %1), so we use %{{.*}} to match the implicit label in the
+// phi check.
+
+// LLVM: define dso_local ptr @_Z8ptr_castP5Base1(ptr{{.*}} %[[SRC:.*]])
+// LLVM-NEXT: %[[SRC_IS_NULL:.*]] = icmp eq ptr %0, null
+// LLVM-NEXT: br i1 %[[SRC_IS_NULL]], label %[[LABEL_END:.*]], label %[[LABEL_NOTNULL:.*]]
+// LLVM: [[LABEL_NOTNULL]]:
+// LLVM-NEXT: %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8
+// LLVM-NEXT: %[[SUCCESS:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
+// LLVM-NEXT: %[[EXACT_RESULT:.*]] = select i1 %[[SUCCESS]], ptr %[[SRC]], ptr null
+// LLVM-NEXT: br label %[[LABEL_END]]
+// LLVM: [[LABEL_END]]:
+// LLVM-NEXT: %[[RESULT:.*]] = phi ptr [ %[[EXACT_RESULT]], %[[LABEL_NOTNULL]] ], [ null, %{{.*}} ]
+// LLVM-NEXT: ret ptr %[[RESULT]]
+// LLVM-NEXT: }
+
+// OGCG: define{{.*}} ptr @_Z8ptr_castP5Base1(ptr {{.*}} %[[SRC:.*]])
+// OGCG-NEXT: entry:
+// OGCG-NEXT: %[[NULL_CHECK:.*]] = icmp eq ptr %[[SRC]], null
+// OGCG-NEXT: br i1 %[[NULL_CHECK]], label %[[LABEL_NULL:.*]], label %[[LABEL_NOTNULL:.*]]
+// OGCG: [[LABEL_NOTNULL]]:
+// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[SRC]], align 8
+// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
+// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL]]
+// OGCG: [[LABEL_NULL]]:
+// OGCG-NEXT: br label %[[LABEL_END]]
+// OGCG: [[LABEL_END]]:
+// OGCG-NEXT: %[[RESULT:.*]] = phi ptr [ %[[SRC]], %[[LABEL_NOTNULL]] ], [ null, %[[LABEL_NULL]] ]
+// OGCG-NEXT: ret ptr %[[RESULT]]
+// OGCG-NEXT: }
+
+Derived &ref_cast(Base1 &ref) {
+  return dynamic_cast<Derived &>(ref);
+}
+
+// CIR: cir.func {{.*}} @_Z8ref_castR5Base1
+// CIR: %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
+// CIR-NEXT: %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
+// CIR-NEXT: %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
+// CIR-NEXT: %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
+// CIR-NEXT: %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
+// CIR-NEXT: %[[FAILED:.*]] = cir.unary(not, %[[SUCCESS]]) : !cir.bool, !cir.bool
+// CIR-NEXT: cir.if %[[FAILED]] {
+// CIR-NEXT: cir.call @__cxa_bad_cast() : () -> ()
+// CIR-NEXT: cir.unreachable
+// CIR-NEXT: }
+// CIR-NEXT: %{{.+}} = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
+
+// LLVM: define{{.*}} ptr @_Z8ref_castR5Base1(ptr{{.*}} %[[SRC:.*]])
+// LLVM-NEXT: %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8
+// LLVM-NEXT: %[[OK:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
+// LLVM-NEXT: br i1 %[[OK]], label %[[LABEL_OK:.*]], label %[[LABEL_FAIL:.*]]
+// LLVM: [[LABEL_FAIL]]:
+// LLVM-NEXT: tail call void @__cxa_bad_cast()
+// LLVM-NEXT: unreachable
+// LLVM: [[LABEL_OK]]:
+// LLVM-NEXT: ret ptr %[[SRC]]
+// LLVM-NEXT: }
+
+// OGCG: define{{.*}} ptr @_Z8ref_castR5Base1(ptr {{.*}} %[[REF:.*]])
+// OGCG-NEXT: entry:
+// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[REF]], align 8
+// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
+// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL:.*]]
+// OGCG: [[LABEL_NULL]]:
+// OGCG-NEXT: {{.*}}call void @__cxa_bad_cast()
+// OGCG-NEXT: unreachable
+// OGCG: [[LABEL_END]]:
+// OGCG-NEXT: ret ptr %[[REF]]
+// OGCG-NEXT: }
```
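For readers skimming the diff, here is a conceptual C++ rendering of the pointer-cast sequence the patch emits in the simple single-inheritance case. This is only a sketch: the real code is generated as CIR operations, and `expected_vptr_for_Derived` is a hypothetical stand-in for the vtable address point constant that the test above checks as `@_ZTV7Derived, i64 16`.

```cpp
struct Base1 { virtual ~Base1(); };
struct Derived final : Base1 {};

// Hypothetical stand-in for the constant vtable address point the compiler
// materializes; it exists only to make this sketch self-contained.
extern void *expected_vptr_for_Derived;

Derived *exact_ptr_cast(Base1 *src) {
  if (src == nullptr) // null check; emitted for pointer casts only
    return nullptr;
  // Load the object's vptr directly; because Derived is final, the cast
  // succeeds if and only if the vptr equals Derived's address point.
  void *vptr = *reinterpret_cast<void **>(src);
  if (vptr != expected_vptr_for_Derived)
    return nullptr;
  // Base1 sits at offset zero inside Derived, so no pointer adjustment.
  return reinterpret_cast<Derived *>(src);
}
```

For the reference-cast form, the failure path calls `__cxa_bad_cast` and is unreachable afterward instead of producing a null result.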
The class traversal logic looks to be pretty much identical to the non-CIR final-class dynamic_cast optimization (…).
Unfortunately, this is a recurring problem with CIR codegen. We've tried to keep the basic code structure as close to LLVM IR codegen as is practical. There are, of course, many places where we could share the actual code for things that don't involve the target substrate. We've implemented a few of these, and for others we've just left comments saying that it should be done. We're trying to find the right balance between making fast progress on upstreaming and accumulating technical debt and baked-in fragility. Is the other "near copy" you were referring to in …? Where would you suggest moving this logic?
Possibly - I just noticed it when I searched for base traversal while trying to find the accursed function :D
Ah, do you not have access to the ABI interface in CIR? (This is just an "Oliver doesn't know where the abstraction boundaries are" question.) Anyway, I might have time this weekend to look into moving it into ASTRecordLayout. Would you mind looking at the DynamicCastInfo struct and noting what additional data you need?
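For reference, the offset computation in this patch depends only on AST-level information (CXXBasePaths and ASTRecordLayout), so one possible factoring is a shared helper that both classic codegen and CIR could call, leaving only the target-specific vptr comparison in each backend. The name and signature below are invented for illustration and are not part of the patch:

```cpp
#include <optional>

#include "clang/AST/ASTContext.h"
#include "clang/AST/CXXInheritance.h"
#include "clang/AST/RecordLayout.h"

using namespace clang;

// Hypothetical shared helper: returns the unique offset at which a public
// srcDecl subobject appears inside destDecl, or std::nullopt if there is
// no public path (the cast always fails). Sets `ambiguous` when srcDecl
// appears at more than one offset.
static std::optional<CharUnits>
findExactSrcOffset(ASTContext &ctx, const CXXRecordDecl *srcDecl,
                   const CXXRecordDecl *destDecl, bool &ambiguous) {
  CXXBasePaths paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true,
                     /*DetectVirtual=*/false);
  (void)destDecl->isDerivedFrom(srcDecl, paths);

  ambiguous = false;
  std::optional<CharUnits> offset;
  for (const CXXBasePath &path : paths) {
    // dynamic_cast only considers public inheritance paths.
    if (path.Access != AS_public)
      continue;
    CharUnits pathOffset;
    for (const CXXBasePathElement &elem : path) {
      const CXXRecordDecl *base = elem.Base->getType()->getAsCXXRecordDecl();
      if (elem.Base->isVirtual())
        // The derived class is exactly destDecl, so its layout gives the
        // virtual base offset directly.
        pathOffset =
            ctx.getASTRecordLayout(destDecl).getVBaseClassOffset(base);
      else
        pathOffset +=
            ctx.getASTRecordLayout(elem.Class).getBaseClassOffset(base);
    }
    if (!offset) {
      offset = pathOffset;
    } else if (*offset != pathOffset) {
      // srcDecl appears at several offsets; the caller must fall back to a
      // cast-to-void of the most-derived object and re-check at offset 0.
      ambiguous = true;
      return CharUnits::Zero();
    }
  }
  return offset;
}
```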