diff --git a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def index b506f18eb7050..d9edeff7ac567 100644 --- a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def +++ b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def @@ -75,6 +75,10 @@ PUNCTUATOR(minus, '-') // RootElement Keywords: KEYWORD(RootSignature) // used only for diagnostic messaging KEYWORD(DescriptorTable) +KEYWORD(RootConstants) + +// RootConstants Keywords: +KEYWORD(num32BitConstants) // DescriptorTable Keywords: KEYWORD(CBV) diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h index 91640e8bf0354..efa735ea03d94 100644 --- a/clang/include/clang/Parse/ParseHLSLRootSignature.h +++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h @@ -71,6 +71,7 @@ class RootSignatureParser { // expected, or, there is a lexing error /// Root Element parse methods: + std::optional parseRootConstants(); std::optional parseDescriptorTable(); std::optional parseDescriptorTableClause(); diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp index 042aedbf1af52..48d3e38b0519d 100644 --- a/clang/lib/Parse/ParseHLSLRootSignature.cpp +++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp @@ -27,6 +27,13 @@ RootSignatureParser::RootSignatureParser(SmallVector &Elements, bool RootSignatureParser::parse() { // Iterate as many RootElements as possible do { + if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) { + auto Constants = parseRootConstants(); + if (!Constants.has_value()) + return true; + Elements.push_back(*Constants); + } + if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) { auto Table = parseDescriptorTable(); if (!Table.has_value()) @@ -35,12 +42,27 @@ bool RootSignatureParser::parse() { } } while (tryConsumeExpectedToken(TokenKind::pu_comma)); - if (consumeExpectedToken(TokenKind::end_of_stream, + return consumeExpectedToken(TokenKind::end_of_stream, + diag::err_hlsl_unexpected_end_of_params, + /*param of=*/TokenKind::kw_RootSignature); +} + +std::optional RootSignatureParser::parseRootConstants() { + assert(CurToken.TokKind == TokenKind::kw_RootConstants && + "Expects to only be invoked starting at given keyword"); + + if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after, + CurToken.TokKind)) + return std::nullopt; + + RootConstants Constants; + + if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_hlsl_unexpected_end_of_params, - /*param of=*/TokenKind::kw_RootSignature)) - return true; + /*param of=*/TokenKind::kw_RootConstants)) + return std::nullopt; - return false; + return Constants; } std::optional RootSignatureParser::parseDescriptorTable() { diff --git a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp index ca609b0b2e8b8..a761257149c11 100644 --- a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp +++ b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp @@ -128,7 +128,9 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) { RootSignature - DescriptorTable + DescriptorTable RootConstants + + num32BitConstants CBV SRV UAV Sampler space visibility flags diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp index 585ac051d66a2..0a7d8ac86cc5f 100644 --- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp +++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp @@ -252,6 +252,32 @@ TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) { ASSERT_TRUE(Consumer->isSatisfied()); } +TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) { + const llvm::StringLiteral Source = R"cc( + RootConstants() + )cc"; + + TrivialModuleLoader ModLoader; + auto PP = createPP(Source, ModLoader); + auto TokLoc = SourceLocation(); + + hlsl::RootSignatureLexer Lexer(Source, TokLoc); + SmallVector Elements; + hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); + + // Test no diagnostics produced + Consumer->setNoDiag(); + + ASSERT_FALSE(Parser.parse()); + + ASSERT_EQ(Elements.size(), 1u); + + RootElement Elem = Elements[0]; + ASSERT_TRUE(std::holds_alternative(Elem)); + + ASSERT_TRUE(Consumer->isSatisfied()); +} + TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) { // This test will checks we can handling trailing commas ',' const llvm::StringLiteral Source = R"cc( diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h index 818caccfe1998..05735fa75b318 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -54,6 +54,9 @@ struct Register { uint32_t Number; }; +// Models the parameter values of root constants +struct RootConstants {}; + // Models the end of a descriptor table and stores its visibility struct DescriptorTable { ShaderVisibility Visibility = ShaderVisibility::All; @@ -88,8 +91,9 @@ struct DescriptorTableClause { } }; -// Models RootElement : DescriptorTable | DescriptorTableClause -using RootElement = std::variant; +// Models RootElement : RootConstants | DescriptorTable | DescriptorTableClause +using RootElement = + std::variant; } // namespace rootsig } // namespace hlsl