Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[HLSL] Update Sema Checking Diagnostics for builtins #138429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

spall
Copy link
Contributor

@spall spall commented May 4, 2025

Update how Sema Checking is done for HLSL builtins to allow for better error messages, mainly using 'err_builtin_invalid_arg_type'.
Try to follow the formula outlined in issue #134721
Closes #134721

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels May 4, 2025
@llvmbot
Copy link
Member

llvmbot commented May 4, 2025

@llvm/pr-subscribers-hlsl

Author: Sarah Spall (spall)

Changes

Update how Sema Checking is done for HLSL builtins to allow for better error messages, mainly using 'err_builtin_invalid_arg_type'.
Try to follow the formula outlined in issue #134721
Closes #134721


Patch is 61.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138429.diff

21 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+1-1)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+131-185)
  • (modified) clang/test/SemaHLSL/BuiltIns/AddUint64-errors.hlsl (+3-3)
  • (modified) clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl (+10)
  • (modified) clang/test/SemaHLSL/BuiltIns/clamp-errors.hlsl (+16-11)
  • (modified) clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl (+9-3)
  • (modified) clang/test/SemaHLSL/BuiltIns/degrees-errors.hlsl (+3-3)
  • (modified) clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl (+17-16)
  • (modified) clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl (+4-4)
  • (modified) clang/test/SemaHLSL/BuiltIns/half-float-only-errors.hlsl (+2-2)
  • (modified) clang/test/SemaHLSL/BuiltIns/half-float-only-errors2.hlsl (+2-2)
  • (modified) clang/test/SemaHLSL/BuiltIns/isinf-errors.hlsl (+6-6)
  • (modified) clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl (+10-10)
  • (modified) clang/test/SemaHLSL/BuiltIns/logical-operator-errors.hlsl (+6)
  • (modified) clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl (+11-11)
  • (modified) clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl (+4-4)
  • (modified) clang/test/SemaHLSL/BuiltIns/radians-errors.hlsl (+4-4)
  • (modified) clang/test/SemaHLSL/BuiltIns/rcp-errors.hlsl (+4-5)
  • (modified) clang/test/SemaHLSL/BuiltIns/reversebits-errors.hlsl (+1-1)
  • (modified) clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl (+5-6)
  • (modified) clang/test/SemaHLSL/BuiltIns/step-errors.hlsl (+4-4)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index ccb14e9927adf..c94d34e0259be 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12649,7 +12649,7 @@ def err_builtin_invalid_arg_type: Error<
   // An 'or' if non-empty second and third components are combined
   "%plural{0:|:%plural{0:|:or }2}3"
   // Third component: floating-point types
-  "%select{|floating-point}3"
+  "%select{|floating-point|16 or 32 bit floating-point}3"
   // A space after a non-empty third component
   "%plural{0:|: }3"
   "%plural{[0,3]:type|:types}1 (was %4)">;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 70aacaa2aadbe..6486ef765f32e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2005,68 +2005,6 @@ void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
   DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
 }
 
-// Helper function for CheckHLSLBuiltinFunctionCall
-static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
-  assert(TheCall->getNumArgs() > 1);
-  ExprResult A = TheCall->getArg(0);
-
-  QualType ArgTyA = A.get()->getType();
-
-  auto *VecTyA = ArgTyA->getAs<VectorType>();
-  SourceLocation BuiltinLoc = TheCall->getBeginLoc();
-
-  bool AllBArgAreVectors = true;
-  for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {
-    ExprResult B = TheCall->getArg(i);
-    QualType ArgTyB = B.get()->getType();
-    auto *VecTyB = ArgTyB->getAs<VectorType>();
-    if (VecTyB == nullptr)
-      AllBArgAreVectors &= false;
-    if (VecTyA && VecTyB == nullptr) {
-      // Note: if we get here 'B' is scalar which
-      // requires a VectorSplat on ArgN
-      S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
-          << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-          << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-      return true;
-    }
-    if (VecTyA && VecTyB) {
-      bool retValue = false;
-      if (!S->Context.hasSameUnqualifiedType(VecTyA->getElementType(),
-                                             VecTyB->getElementType())) {
-        // Note: type promotion is intended to be handeled via the intrinsics
-        //  and not the builtin itself.
-        S->Diag(TheCall->getBeginLoc(),
-                diag::err_vec_builtin_incompatible_vector)
-            << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-            << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-        retValue = true;
-      }
-      if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
-        // You should only be hitting this case if you are calling the builtin
-        // directly. HLSL intrinsics should avoid this case via a
-        // HLSLVectorTruncation.
-        S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
-            << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-            << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-        retValue = true;
-      }
-      if (retValue)
-        return retValue;
-    }
-  }
-
-  if (VecTyA == nullptr && AllBArgAreVectors) {
-    // Note: if we get here 'A' is a scalar which
-    // requires a VectorSplat on Arg0
-    S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
-        << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-        << SourceRange(A.get()->getBeginLoc(), A.get()->getEndLoc());
-    return true;
-  }
-  return false;
-}
-
 static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() > 1);
   QualType ArgTy0 = TheCall->getArg(0)->getType();
@@ -2094,63 +2032,46 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
   return false;
 }
 
-static bool CheckArgTypeIsCorrect(
-    Sema *S, Expr *Arg, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-  QualType PassedType = Arg->getType();
-  if (Check(PassedType)) {
-    if (auto *VecTyA = PassedType->getAs<VectorType>())
-      ExpectedType = S->Context.getVectorType(
-          ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
-    S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
-        << PassedType << ExpectedType << 1 << 0 << 0;
-    return true;
-  }
-  return false;
-}
-
 static bool CheckAllArgTypesAreCorrect(
-    Sema *S, CallExpr *TheCall, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-  for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
-    Expr *Arg = TheCall->getArg(i);
-    if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
+                            clang::QualType PassedType)>
+        Check) {
+  for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
+    Expr *Arg = TheCall->getArg(I);
+    if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
       return true;
-    }
   }
   return false;
 }
 
-static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasFloatingRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkAllFloatTypes);
-}
+static bool CheckFloatOrHalfVecRepresentation(Sema *S, SourceLocation Loc,
+                                              int ArgOrdinal,
+                                              clang::QualType PassedType) {
+  QualType EltTy = PassedType;
+  if (auto *VecTy = EltTy->getAs<VectorType>())
+    EltTy = VecTy->getElementType();
 
-static bool CheckUnsignedIntRepresentations(Sema *S, CallExpr *TheCall) {
-  auto checkUnsignedInteger = [](clang::QualType PassedType) -> bool {
-    clang::QualType BaseType =
-        PassedType->isVectorType()
-            ? PassedType->getAs<clang::VectorType>()->getElementType()
-            : PassedType;
-    return !BaseType->isUnsignedIntegerType();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
-                                    checkUnsignedInteger);
+  if (!PassedType->getAs<VectorType>() ||
+      !(EltTy->isHalfType() || EltTy->isFloat32Type()))
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 4 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
 }
 
-static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
-  auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
-    clang::QualType BaseType =
-        PassedType->isVectorType()
-            ? PassedType->getAs<clang::VectorType>()->getElementType()
-            : PassedType;
-    return !BaseType->isHalfType() && !BaseType->isFloat32Type();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkFloatorHalf);
+static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->getAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
 }
 
 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
@@ -2164,30 +2085,49 @@ static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
   return true;
 }
 
-static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
-  auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
-    if (const auto *VecTy = PassedType->getAs<VectorType>())
-      return VecTy->getElementType()->isDoubleType();
-    return false;
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkDoubleVector);
+static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
+                                 clang::QualType PassedType) {
+  if (const auto *VecTy = PassedType->getAs<VectorType>())
+    if (VecTy->getElementType()->isDoubleType())
+      return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
+             << PassedType;
+  return false;
 }
-static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasIntegerRepresentation() &&
-           !PassedType->hasFloatingRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy,
-                                    checkAllSignedTypes);
+
+static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
+                                             int ArgOrdinal,
+                                             clang::QualType PassedType) {
+  if (!PassedType->hasIntegerRepresentation() &&
+      !PassedType->hasFloatingRepresentation())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
+           << /* fp */ 1 << PassedType;
+  return false;
 }
 
-static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasUnsignedIntegerRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
-                                    checkAllUnsignedTypes);
+static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
+                                              int ArgOrdinal,
+                                              clang::QualType PassedType) {
+  QualType EltTy = PassedType;
+  if (auto *VecTy = EltTy->getAs<VectorType>())
+    EltTy = VecTy->getElementType();
+
+  if (!PassedType->getAs<VectorType>() || !EltTy->isUnsignedIntegerType())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
+           << PassedType;
+  return false;
+}
+
+static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  if (!PassedType->hasUnsignedIntegerRepresentation())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
+           << /* no fp */ 0 << PassedType;
+  return false;
 }
 
 static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
@@ -2343,23 +2283,12 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_adduint64: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
-    if (CheckUnsignedIntRepresentations(&SemaRef, TheCall))
-      return true;
-
-    // CheckVectorElementCallArgs(...) guarantees both args are the same type.
-    assert(TheCall->getArg(0)->getType() == TheCall->getArg(1)->getType() &&
-           "Both args must be of the same type");
 
-    // ensure both args are vectors
-    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
-    if (!VTy) {
-      SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_non_vector)
-          << TheCall->getDirectCallee() << /*all*/ 1;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckUnsignedIntVecRepresentation))
       return true;
-    }
 
+    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
     // ensure arg integers are 32-bits
     uint64_t ElementBitCount = getASTContext()
                                    .getTypeSizeInChars(VTy->getElementType())
@@ -2380,6 +2309,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     }
 
+    // ensure first arg and second arg have the same type
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
+      return true;
+
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
     // return type is the same as the input type
@@ -2431,10 +2364,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_or: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
       return true;
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
+      return true;
 
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
@@ -2446,37 +2379,41 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_any: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
+    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
+      return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_asdouble: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
+    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+                            0))
+      return true;
+    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+                            1))
+      return true;
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
 
     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
-    if (SemaRef.checkArgCount(TheCall, 3))
-      return true;
-    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
-        CheckAllArgsHaveSameType(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
             TheCall, /*ArgTyRestr=*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
-                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
-                : Sema::EltwiseBuiltinArgTyRestriction::None))
+            Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_cross: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
+
+    // ensure args are a half3 or float3
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfVecRepresentation))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
     // ensure both args have 3 elements
     int NumElementsArg1 =
@@ -2507,13 +2444,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_dot: {
-    if (SemaRef.checkArgCount(TheCall, 2))
-      return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinVectorToScalarMath(TheCall))
       return true;
-    if (CheckNoDoubleVectors(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, CheckNoDoubleVectors))
       return true;
     break;
   }
@@ -2560,8 +2493,15 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   }
   case Builtin::BI__builtin_hlsl_elementwise_saturate:
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
-    if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
       return true;
+    if (!TheCall->getArg(0)
+             ->getType()
+             ->hasFloatingRepresentation()) // half or float or double
+      return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+                          diag::err_builtin_invalid_arg_type)
+             << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
+             << /* fp */ 1 << TheCall->getArg(0)->getType();
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
     break;
@@ -2570,14 +2510,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_elementwise_radians:
   case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_isinf: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
@@ -2587,34 +2533,28 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_lerp: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
-    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
-        CheckAllArgsHaveSameType(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
-    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_mad: {
-    if (SemaRef.checkArgCount(TheCall, 3))
-      return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
             TheCall, /*ArgTyRestr=*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
-                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
-                : Sema::EltwiseBuiltinArgTyRestriction::None))
+            Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_normalize: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
-      return true;
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
-
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
+      return true;
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
     // return type is the same as the input type
@@ -2622,17 +2562,19 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
-    if (CheckFloatingOrIntRepresentation(&SemaRef, TheCall))
-      return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatingOrIntRepresentation))
+      return true;
     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().IntTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_step: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 4, 2025

@llvm/pr-subscribers-clang

Author: Sarah Spall (spall)

Changes

Update how Sema Checking is done for HLSL builtins to allow for better error messages, mainly using 'err_builtin_invalid_arg_type'.
Try to follow the formula outlined in issue #134721
Closes #134721


Patch is 61.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138429.diff

21 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+1-1)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+131-185)
  • (modified) clang/test/SemaHLSL/BuiltIns/AddUint64-errors.hlsl (+3-3)
  • (modified) clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl (+10)
  • (modified) clang/test/SemaHLSL/BuiltIns/clamp-errors.hlsl (+16-11)
  • (modified) clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl (+9-3)
  • (modified) clang/test/SemaHLSL/BuiltIns/degrees-errors.hlsl (+3-3)
  • (modified) clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl (+17-16)
  • (modified) clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl (+4-4)
  • (modified) clang/test/SemaHLSL/BuiltIns/half-float-only-errors.hlsl (+2-2)
  • (modified) clang/test/SemaHLSL/BuiltIns/half-float-only-errors2.hlsl (+2-2)
  • (modified) clang/test/SemaHLSL/BuiltIns/isinf-errors.hlsl (+6-6)
  • (modified) clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl (+10-10)
  • (modified) clang/test/SemaHLSL/BuiltIns/logical-operator-errors.hlsl (+6)
  • (modified) clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl (+11-11)
  • (modified) clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl (+4-4)
  • (modified) clang/test/SemaHLSL/BuiltIns/radians-errors.hlsl (+4-4)
  • (modified) clang/test/SemaHLSL/BuiltIns/rcp-errors.hlsl (+4-5)
  • (modified) clang/test/SemaHLSL/BuiltIns/reversebits-errors.hlsl (+1-1)
  • (modified) clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl (+5-6)
  • (modified) clang/test/SemaHLSL/BuiltIns/step-errors.hlsl (+4-4)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index ccb14e9927adf..c94d34e0259be 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12649,7 +12649,7 @@ def err_builtin_invalid_arg_type: Error<
   // An 'or' if non-empty second and third components are combined
   "%plural{0:|:%plural{0:|:or }2}3"
   // Third component: floating-point types
-  "%select{|floating-point}3"
+  "%select{|floating-point|16 or 32 bit floating-point}3"
   // A space after a non-empty third component
   "%plural{0:|: }3"
   "%plural{[0,3]:type|:types}1 (was %4)">;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 70aacaa2aadbe..6486ef765f32e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2005,68 +2005,6 @@ void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
   DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
 }
 
-// Helper function for CheckHLSLBuiltinFunctionCall
-static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
-  assert(TheCall->getNumArgs() > 1);
-  ExprResult A = TheCall->getArg(0);
-
-  QualType ArgTyA = A.get()->getType();
-
-  auto *VecTyA = ArgTyA->getAs<VectorType>();
-  SourceLocation BuiltinLoc = TheCall->getBeginLoc();
-
-  bool AllBArgAreVectors = true;
-  for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {
-    ExprResult B = TheCall->getArg(i);
-    QualType ArgTyB = B.get()->getType();
-    auto *VecTyB = ArgTyB->getAs<VectorType>();
-    if (VecTyB == nullptr)
-      AllBArgAreVectors &= false;
-    if (VecTyA && VecTyB == nullptr) {
-      // Note: if we get here 'B' is scalar which
-      // requires a VectorSplat on ArgN
-      S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
-          << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-          << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-      return true;
-    }
-    if (VecTyA && VecTyB) {
-      bool retValue = false;
-      if (!S->Context.hasSameUnqualifiedType(VecTyA->getElementType(),
-                                             VecTyB->getElementType())) {
-        // Note: type promotion is intended to be handeled via the intrinsics
-        //  and not the builtin itself.
-        S->Diag(TheCall->getBeginLoc(),
-                diag::err_vec_builtin_incompatible_vector)
-            << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-            << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-        retValue = true;
-      }
-      if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
-        // You should only be hitting this case if you are calling the builtin
-        // directly. HLSL intrinsics should avoid this case via a
-        // HLSLVectorTruncation.
-        S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
-            << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-            << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-        retValue = true;
-      }
-      if (retValue)
-        return retValue;
-    }
-  }
-
-  if (VecTyA == nullptr && AllBArgAreVectors) {
-    // Note: if we get here 'A' is a scalar which
-    // requires a VectorSplat on Arg0
-    S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
-        << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-        << SourceRange(A.get()->getBeginLoc(), A.get()->getEndLoc());
-    return true;
-  }
-  return false;
-}
-
 static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() > 1);
   QualType ArgTy0 = TheCall->getArg(0)->getType();
@@ -2094,63 +2032,46 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
   return false;
 }
 
-static bool CheckArgTypeIsCorrect(
-    Sema *S, Expr *Arg, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-  QualType PassedType = Arg->getType();
-  if (Check(PassedType)) {
-    if (auto *VecTyA = PassedType->getAs<VectorType>())
-      ExpectedType = S->Context.getVectorType(
-          ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
-    S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
-        << PassedType << ExpectedType << 1 << 0 << 0;
-    return true;
-  }
-  return false;
-}
-
 static bool CheckAllArgTypesAreCorrect(
-    Sema *S, CallExpr *TheCall, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-  for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
-    Expr *Arg = TheCall->getArg(i);
-    if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
+                            clang::QualType PassedType)>
+        Check) {
+  for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
+    Expr *Arg = TheCall->getArg(I);
+    if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
       return true;
-    }
   }
   return false;
 }
 
-static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasFloatingRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkAllFloatTypes);
-}
+static bool CheckFloatOrHalfVecRepresentation(Sema *S, SourceLocation Loc,
+                                              int ArgOrdinal,
+                                              clang::QualType PassedType) {
+  QualType EltTy = PassedType;
+  if (auto *VecTy = EltTy->getAs<VectorType>())
+    EltTy = VecTy->getElementType();
 
-static bool CheckUnsignedIntRepresentations(Sema *S, CallExpr *TheCall) {
-  auto checkUnsignedInteger = [](clang::QualType PassedType) -> bool {
-    clang::QualType BaseType =
-        PassedType->isVectorType()
-            ? PassedType->getAs<clang::VectorType>()->getElementType()
-            : PassedType;
-    return !BaseType->isUnsignedIntegerType();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
-                                    checkUnsignedInteger);
+  if (!PassedType->getAs<VectorType>() ||
+      !(EltTy->isHalfType() || EltTy->isFloat32Type()))
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 4 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
 }
 
-static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
-  auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
-    clang::QualType BaseType =
-        PassedType->isVectorType()
-            ? PassedType->getAs<clang::VectorType>()->getElementType()
-            : PassedType;
-    return !BaseType->isHalfType() && !BaseType->isFloat32Type();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkFloatorHalf);
+static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->getAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
 }
 
 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
@@ -2164,30 +2085,49 @@ static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
   return true;
 }
 
-static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
-  auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
-    if (const auto *VecTy = PassedType->getAs<VectorType>())
-      return VecTy->getElementType()->isDoubleType();
-    return false;
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkDoubleVector);
+static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
+                                 clang::QualType PassedType) {
+  if (const auto *VecTy = PassedType->getAs<VectorType>())
+    if (VecTy->getElementType()->isDoubleType())
+      return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
+             << PassedType;
+  return false;
 }
-static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasIntegerRepresentation() &&
-           !PassedType->hasFloatingRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy,
-                                    checkAllSignedTypes);
+
+static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
+                                             int ArgOrdinal,
+                                             clang::QualType PassedType) {
+  if (!PassedType->hasIntegerRepresentation() &&
+      !PassedType->hasFloatingRepresentation())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
+           << /* fp */ 1 << PassedType;
+  return false;
 }
 
-static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasUnsignedIntegerRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
-                                    checkAllUnsignedTypes);
+static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
+                                              int ArgOrdinal,
+                                              clang::QualType PassedType) {
+  QualType EltTy = PassedType;
+  if (auto *VecTy = EltTy->getAs<VectorType>())
+    EltTy = VecTy->getElementType();
+
+  if (!PassedType->getAs<VectorType>() || !EltTy->isUnsignedIntegerType())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
+           << PassedType;
+  return false;
+}
+
+static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  if (!PassedType->hasUnsignedIntegerRepresentation())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
+           << /* no fp */ 0 << PassedType;
+  return false;
 }
 
 static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
@@ -2343,23 +2283,12 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_adduint64: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
-    if (CheckUnsignedIntRepresentations(&SemaRef, TheCall))
-      return true;
-
-    // CheckVectorElementCallArgs(...) guarantees both args are the same type.
-    assert(TheCall->getArg(0)->getType() == TheCall->getArg(1)->getType() &&
-           "Both args must be of the same type");
 
-    // ensure both args are vectors
-    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
-    if (!VTy) {
-      SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_non_vector)
-          << TheCall->getDirectCallee() << /*all*/ 1;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckUnsignedIntVecRepresentation))
       return true;
-    }
 
+    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
     // ensure arg integers are 32-bits
     uint64_t ElementBitCount = getASTContext()
                                    .getTypeSizeInChars(VTy->getElementType())
@@ -2380,6 +2309,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     }
 
+    // ensure first arg and second arg have the same type
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
+      return true;
+
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
     // return type is the same as the input type
@@ -2431,10 +2364,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_or: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
       return true;
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
+      return true;
 
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
@@ -2446,37 +2379,41 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_any: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
+    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
+      return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_asdouble: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
+    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+                            0))
+      return true;
+    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+                            1))
+      return true;
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
 
     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
-    if (SemaRef.checkArgCount(TheCall, 3))
-      return true;
-    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
-        CheckAllArgsHaveSameType(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
             TheCall, /*ArgTyRestr=*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
-                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
-                : Sema::EltwiseBuiltinArgTyRestriction::None))
+            Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_cross: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
+
+    // ensure args are a half3 or float3
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfVecRepresentation))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
     // ensure both args have 3 elements
     int NumElementsArg1 =
@@ -2507,13 +2444,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_dot: {
-    if (SemaRef.checkArgCount(TheCall, 2))
-      return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinVectorToScalarMath(TheCall))
       return true;
-    if (CheckNoDoubleVectors(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, CheckNoDoubleVectors))
       return true;
     break;
   }
@@ -2560,8 +2493,15 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   }
   case Builtin::BI__builtin_hlsl_elementwise_saturate:
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
-    if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
       return true;
+    if (!TheCall->getArg(0)
+             ->getType()
+             ->hasFloatingRepresentation()) // half or float or double
+      return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+                          diag::err_builtin_invalid_arg_type)
+             << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
+             << /* fp */ 1 << TheCall->getArg(0)->getType();
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
     break;
@@ -2570,14 +2510,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_elementwise_radians:
   case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_isinf: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
@@ -2587,34 +2533,28 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_lerp: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
-    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
-        CheckAllArgsHaveSameType(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
-    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_mad: {
-    if (SemaRef.checkArgCount(TheCall, 3))
-      return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
             TheCall, /*ArgTyRestr=*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
-                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
-                : Sema::EltwiseBuiltinArgTyRestriction::None))
+            Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_normalize: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
-      return true;
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
-
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
+      return true;
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
     // return type is the same as the input type
@@ -2622,17 +2562,19 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
-    if (CheckFloatingOrIntRepresentation(&SemaRef, TheCall))
-      return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatingOrIntRepresentation))
+      return true;
     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().IntTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_step: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall...
[truncated]

@@ -1,4 +1,4 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected=note,warning
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the notes / warnings irrelevant or overbearing? Any way to add them in so this option can be dropped?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe for this test there was 1 warning about truncation; I'm not sure about the notes, but usually if there are notes, there is what I'd consider an overbearing amount of them.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[HLSL] Update HLSL's Sema Checking and Diagnostics for builtins
3 participants