Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/DXIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3027,7 +3027,7 @@ DECL.USEDEXTERNALFUNCTION External function must be used
DECL.USEDINTERNAL Internal declaration must be used
FLOW.DEADLOOP Loop must have break.
FLOW.FUNCTIONCALL Function with parameter is not permitted
FLOW.NORECUSION Recursion is not permitted.
FLOW.NORECURSION Recursion is not permitted.
FLOW.REDUCIBLE Execution flow must be reducible.
INSTR.ALLOWED Instructions must be of an allowed type.
INSTR.ATOMICCONST Constant destination to atomic.
Expand Down
4 changes: 2 additions & 2 deletions lib/HLSL/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5595,7 +5595,7 @@ static void ValidateCallGraph(ValidationContext &ValCtx) {
depthMap[entryNode] = 0;
if (CallGraphNode *N = CalculateCallDepth(entryNode, depthMap, callStack,
ValCtx.entryFuncCallSet))
ValCtx.EmitFnError(N->getFunction(), ValidationRule::FlowNoRecusion);
ValCtx.EmitFnError(N->getFunction(), ValidationRule::FlowNoRecursion);
if (ValCtx.DxilMod.GetShaderModel()->IsHS()) {
CallGraphNode *patchConstantNode =
CG[ValCtx.DxilMod.GetPatchConstantFunction()];
Expand All @@ -5604,7 +5604,7 @@ static void ValidateCallGraph(ValidationContext &ValCtx) {
if (CallGraphNode *N =
CalculateCallDepth(patchConstantNode, depthMap, callStack,
ValCtx.patchConstFuncCallSet))
ValCtx.EmitFnError(N->getFunction(), ValidationRule::FlowNoRecusion);
ValCtx.EmitFnError(N->getFunction(), ValidationRule::FlowNoRecursion);
}
}

Expand Down
6 changes: 6 additions & 0 deletions tools/clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -7802,6 +7802,12 @@ def err_hlsl_logical_binop_scalar : Error<
"operands for short-circuiting logical binary operator must be scalar, for non-scalar types use '%select{and|or}0'">;
def err_hlsl_ternary_scalar : Error<
"condition for short-circuiting ternary operator must be scalar, for non-scalar types use 'select'">;
def err_hlsl_no_recursion : Error<
"recursive functions are not allowed: %select{entry|export|patch constant}0 function calls recursive function '%1'">;
def err_hlsl_missing_patch_constant_function : Error<
"patch constant function '%0' must be defined">;
def err_hlsl_patch_reachability_not_allowed : Error<
"%select{patch constant|entry}0 function '%1' should not be reachable from %select{patch constant|entry}2 function '%3'">;
def warn_hlsl_structurize_exits_lifetime_markers_conflict : Warning <
"structurize-returns skipped function '%0' due to incompatibility with lifetime markers. Use -disable-lifetime-markers to enable structurize-exits on this function.">,
InGroup< HLSLStructurizeExitsLifetimeMarkersConflict >;
Expand Down
5 changes: 5 additions & 0 deletions tools/clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ void DiagnosePackingOffset(clang::Sema *self, clang::SourceLocation loc,
void DiagnoseRegisterType(clang::Sema *self, clang::SourceLocation loc,
clang::QualType type, char registerType);

clang::FunctionDecl *ValidateNoRecursion(clang::Sema *self,
clang::FunctionDecl *FD);

void ValidateNoRecursionInTranslationUnit(clang::Sema *self);

void DiagnoseTranslationUnit(clang::Sema *self);

void DiagnoseUnusualAnnotationsForHLSL(
Expand Down
178 changes: 139 additions & 39 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3024,6 +3024,20 @@ class CallGraphWithRecurseGuard {
}
}

// return true if FD2 is reachable from FD1
bool CheckReachability(FunctionDecl *FD1, FunctionDecl *FD2) {
if (FD1 == FD2)
return true;
auto node = m_callNodes.find(FD1);
if (node != m_callNodes.end()) {
for (FunctionDecl *Callee : node->second.CalleeFns) {
if (CheckReachability(Callee, FD2))
return true;
}
}
return false;
}

FunctionDecl *CheckRecursion(FunctionDecl *EntryFnDecl) const {
FnCallStack CallStack;
EntryFnDecl = getFunctionWithBody(EntryFnDecl);
Expand Down Expand Up @@ -11372,6 +11386,31 @@ bool hlsl::DiagnoseNodeStructArgument(Sema *self, TemplateArgumentLoc ArgLoc,
}
}

// validates that for every function in the translation unit, if it
// references a patch constant function, there exists a function declaration
// that could serve as a candidate to that patch constant function.
void ValidatePatchConstantFunctionsExist(clang::Sema *self) {
for (auto decl : self->getASTContext().getTranslationUnitDecl()->decls()) {
// TODO: improve condition so that only exported functions are checked,
// instead of all functions. Issue: #5857
if (FunctionDecl *FD = dyn_cast<FunctionDecl>(decl)) {
// If there is no patch constant function, then we don't need to validate
// anything.
if (const HLSLPatchConstantFuncAttr *Attr =
FD->getAttr<HLSLPatchConstantFuncAttr>()) {
NameLookup NL =
GetSingleFunctionDeclByName(self, Attr->getFunctionName(),
/*checkPatch*/ true);
if (!NL.Found || !NL.Found->hasBody()) {
self->Diag(Attr->getLocation(),
diag::err_hlsl_missing_patch_constant_function)
<< Attr->getFunctionName();
}
}
}
}
}

// This function diagnoses whether or not all entry-point attributes
// should exist on this shader stage
void DiagnoseEntryAttrAllowedOnStage(clang::Sema *self,
Expand Down Expand Up @@ -11421,9 +11460,9 @@ void hlsl::DiagnoseTranslationUnit(clang::Sema *self) {
}
}

// Don't check entry function for library.
if (self->getLangOpts().IsHLSLLibrary) {
// TODO: validate no recursion start from every function.
ValidatePatchConstantFunctionsExist(self);
ValidateNoRecursionInTranslationUnit(self);
return;
}

Expand Down Expand Up @@ -11458,47 +11497,56 @@ void hlsl::DiagnoseTranslationUnit(clang::Sema *self) {
}
}

// Validate that there is no recursion; start with the entry function.
// NOTE: the information gathered here could be used to bypass code generation
// on functions that are unreachable (as an early form of dead code
// elimination).
if (pEntryPointDecl) {
const auto *shaderModel =
hlsl::ShaderModel::GetByName(self->getLangOpts().HLSLProfile.c_str());
FunctionDecl *result = ValidateNoRecursion(self, pEntryPointDecl);

if (shaderModel->IsHS()) {
if (const HLSLPatchConstantFuncAttr *Attr =
pEntryPointDecl->getAttr<HLSLPatchConstantFuncAttr>()) {
NameLookup NL = GetSingleFunctionDeclByName(
self, Attr->getFunctionName(), /*checkPatch*/ true);
if (!NL.Found || !NL.Found->hasBody()) {
unsigned id =
Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error,
"missing patch function definition");
Diags.Report(id);
return;
}
pPatchFnDecl = NL.Found;
if (result) {
self->Diag(result->getSourceRange().getBegin(), diag::err_hlsl_no_recursion)
<< 0 << result->getName();
}

const auto *shaderModel =
hlsl::ShaderModel::GetByName(self->getLangOpts().HLSLProfile.c_str());

if (shaderModel->IsHS()) {
if (const HLSLPatchConstantFuncAttr *attr =
pEntryPointDecl->getAttr<HLSLPatchConstantFuncAttr>()) {
NameLookup NL = GetSingleFunctionDeclByName(self, attr->getFunctionName(),
/*checkPatch*/ true);
if (!NL.Found || !NL.Found->hasBody()) {
self->Diag(attr->getLocation(),
diag::err_hlsl_missing_patch_constant_function)
<< attr->getFunctionName();
}
pPatchFnDecl = NL.Found;
}
}

hlsl::CallGraphWithRecurseGuard CG;
CG.BuildForEntry(pEntryPointDecl);
Decl *pResult = CG.CheckRecursion(pEntryPointDecl);
if (pResult) {
unsigned id =
Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error,
"recursive functions not allowed");
Diags.Report(pResult->getSourceRange().getBegin(), id);
if (pPatchFnDecl) {
FunctionDecl *patchResult = ValidateNoRecursion(self, pPatchFnDecl);

// In this case, recursion was detected in the patch-constant function
if (patchResult) {
self->Diag(patchResult->getSourceRange().getBegin(),
diag::err_hlsl_no_recursion)
<< 2 << patchResult->getName();
}
if (pPatchFnDecl) {

// The patch function decl and the entry function decl should be
// disconnected with respect to the call graph.
// Only check this if neither function decl is recursive
if (!result && !patchResult) {
hlsl::CallGraphWithRecurseGuard CG;
CG.BuildForEntry(pPatchFnDecl);
Decl *pPatchFnDecl = CG.CheckRecursion(pEntryPointDecl);
if (pPatchFnDecl) {
unsigned id = Diags.getCustomDiagID(
clang::DiagnosticsEngine::Level::Error,
"recursive functions not allowed (via patch function)");
Diags.Report(pPatchFnDecl->getSourceRange().getBegin(), id);
if (CG.CheckReachability(pPatchFnDecl, pEntryPointDecl)) {
self->Diag(pEntryPointDecl->getSourceRange().getBegin(),
diag::err_hlsl_patch_reachability_not_allowed)
<< 1 << pEntryPointDecl->getName() << 0 << pPatchFnDecl->getName();
}
CG.BuildForEntry(pEntryPointDecl);
if (CG.CheckReachability(pEntryPointDecl, pPatchFnDecl)) {
self->Diag(pEntryPointDecl->getSourceRange().getBegin(),
diag::err_hlsl_patch_reachability_not_allowed)
<< 0 << pPatchFnDecl->getName() << 1 << pEntryPointDecl->getName();
}
}
}
Expand Down Expand Up @@ -15313,8 +15361,8 @@ void DiagnoseMeshEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName) {
}

void DiagnoseHullEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName) {

if (!(FD->getAttr<HLSLPatchConstantFuncAttr>()))
HLSLPatchConstantFuncAttr *Attr = FD->getAttr<HLSLPatchConstantFuncAttr>();
if (!Attr)
S.Diags.Report(FD->getLocation(), diag::err_hlsl_missing_attr)
<< StageName << "patchconstantfunc";

Expand Down Expand Up @@ -15690,6 +15738,58 @@ void TryAddShaderAttrFromTargetProfile(Sema &S, FunctionDecl *FD,
return;
}

// in the non-library case, this function will be run only once,
// but in the library case, this function will be run for each
// viable top-level function declaration by
// ValidateNoRecursionInTranslationUnit.
// (viable as in, is exported)
clang::FunctionDecl *ValidateNoRecursion(clang::Sema *self,
clang::FunctionDecl *FD) {
// Validate that there is no recursion reachable by this function declaration
// NOTE: the information gathered here could be used to bypass code generation
// on functions that are unreachable (as an early form of dead code
// elimination).
if (FD) {
hlsl::CallGraphWithRecurseGuard CG;
CG.BuildForEntry(FD);
return CG.CheckRecursion(FD);
}
return nullptr;
}

void ValidateNoRecursionInTranslationUnit(clang::Sema *self) {
std::set<FunctionDecl *> FDecls;
std::vector<FunctionDecl *> FDeclsVec;
for (auto decl : self->getASTContext().getTranslationUnitDecl()->decls()) {
// TODO: improve condition so that only exported functions are checked,
// instead of all functions. Issue: #5857
if (FunctionDecl *FD = dyn_cast<FunctionDecl>(decl)) {
// returns the first recursive function declaration detected
// from this function declaration FD, and determines whether
// the recursion was detected in the patch-constant function
FunctionDecl *pResult = ValidateNoRecursion(self, FD);
// if there is a function that was detected to be recursive,
// then make sure it is saved for later to emit a diagnostic
if (pResult) {
FDecls.insert(pResult);
FDeclsVec.push_back(pResult);
}
}
}

// iterate through FDeclsVec to maintain AST order, and delete
// from set FDecls as we go.
for (FunctionDecl *fdecl : FDeclsVec) {
if (FDecls.find(fdecl) == FDecls.end()) {
continue;
}
self->Diag(fdecl->getSourceRange().getBegin(), diag::err_hlsl_no_recursion)
<< 0 << fdecl->getName();

FDecls.erase(fdecl);
}
}

// The DiagnoseEntry function does 2 things:
// 1. Determine whether this function is the current entry point for a
// non-library compilation, add an implicit shader attribute if so.
Expand Down
2 changes: 1 addition & 1 deletion tools/clang/test/CodeGenHLSL/recursive2.hlsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s

// CHECK: error: recursive functions not allowed
// CHECK: error: recursive functions are not allowed: entry function calls recursive function 'test_inout'

struct M {
float m;
Expand Down
2 changes: 1 addition & 1 deletion tools/clang/test/CodeGenHLSL/recursive3.hlsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s

// CHECK: error: recursive functions not allowed
// CHECK: error: recursive functions are not allowed: entry function calls recursive function 'test_ret'

float test_ret()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: %dxc -T lib_6_5 %s | FileCheck %s


// actual selected HSPerPatchFunc1 for HSMain1 and HSMain3
float4 fooey()
{
float4 e;
float4 d;
d.x = 4;

return e;
}

[shader("hull")]
// CHECK: error: patch constant function 'NotFooey' must be defined
[patchconstantfunc("NotFooey")]
float4 main(float a : A, float b:B) : SV_TARGET
{
float4 f = b;
return f;
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: %dxc -E HSmain -T hs_6_0 %s | FileCheck %s

struct HSPerPatchData
{
float edges[3] : SV_TessFactor;
float inside : SV_InsideTessFactor;
};

[shader("hull")]
[domain("tri")]
[partitioning("fractional_odd")]
[outputtopology("triangle_cw")]
[outputcontrolpoints(3)]
[patchconstantfunc("patchfn")]
void HSmain(uint ix : SV_OutputControlPointID);


// actual selected HSPerPatchFunc1 for HSMain1 and HSMain3
HSPerPatchData patchfn()
{
HSPerPatchData d;

d.edges[0] = -5;
d.edges[1] = -6;
d.edges[2] = -7;
d.inside = -8;
// CHECK: error: entry function 'HSmain' should not be reachable from patch constant function 'patchfn'
HSmain(3);
return d;
}

[shader("hull")]
[domain("tri")]
[partitioning("fractional_odd")]
[outputtopology("triangle_cw")]
[outputcontrolpoints(3)]
[patchconstantfunc("patchfn")]
void HSmain(uint ix : SV_OutputControlPointID)
{
return;
}

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// Test for recursion detection. Note that this cannot be a syntax test
// because we detect from the entry point and syntax tests have none.

// CHECK: error: recursive functions not allowed
// CHECK: error: recursive functions are not allowed: entry function calls recursive function 'A'

void B();
void A() { B(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// Test for recursion detection. Note that this cannot be a syntax test
// because we detect from the entry point and syntax tests have none.

// CHECK: error: recursive functions not allowed
// CHECK: error: recursive functions are not allowed: entry function calls recursive function 'A'

struct MyClass
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
// The SCCP pass replaces the recursive call with an undef value,
// which is why validation fails with a non-obvious error.

// CHECK: validation errors
// CHECK: Instructions should not read uninitialized value
// CHECK: error: recursive functions are not allowed: entry function calls recursive function 'func'

struct S
{
Expand Down
Loading