diff --git a/src/coreclr/System.Private.CoreLib/src/System/String.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/String.CoreCLR.cs index d19cb01034a74e..2f148da67aa575 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/String.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/String.CoreCLR.cs @@ -38,7 +38,7 @@ internal static unsafe void InternalCopy(string src, IntPtr dest, int len) { if (len != 0) { - SpanHelpers.Memmove(ref *(byte*)dest, ref Unsafe.As(ref src.GetRawStringData()), (nuint)len); + SpanHelpers.Memmove(ref *(byte*)dest, ref src.GetRawStringDataAsUInt8(), (nuint)len); } } diff --git a/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/SearchValuesStringFuzzer.cs b/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/SearchValuesStringFuzzer.cs index a99c6c3a4a4740..e122b4d337952d 100644 --- a/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/SearchValuesStringFuzzer.cs +++ b/src/libraries/Fuzzing/DotnetFuzzing/Fuzzers/SearchValuesStringFuzzer.cs @@ -36,8 +36,8 @@ private static void Test(ReadOnlySpan haystack, ReadOnlySpan haystac SearchValues searchValues = SearchValues.Create(needles, comparisonType); int index = haystack.IndexOfAny(searchValues); - Assert.Equal(index, haystackCopy.IndexOfAny(searchValues)); - Assert.Equal(index, IndexOfAnyReferenceImpl(haystack, needles, comparisonType)); + AssertEqual(index, haystackCopy.IndexOfAny(searchValues), searchValues); + AssertEqual(index, IndexOfAnyReferenceImpl(haystack, needles, comparisonType), searchValues); } private static int IndexOfAnyReferenceImpl(ReadOnlySpan haystack, string[] needles, StringComparison comparisonType) @@ -55,4 +55,15 @@ private static int IndexOfAnyReferenceImpl(ReadOnlySpan haystack, string[] return minIndex == int.MaxValue ? -1 : minIndex; } + + private static void AssertEqual(int expected, int actual, SearchValues searchValues) + { + if (expected != actual) + { + Type implType = searchValues.GetType(); + string impl = $"{implType.Name} [{string.Join(", ", implType.GenericTypeArguments.Select(t => t.Name))}]"; + + throw new Exception($"Expected {expected}, got {actual} for impl='{impl}'"); + } + } } diff --git a/src/libraries/System.Memory/tests/Span/StringSearchValues.cs b/src/libraries/System.Memory/tests/Span/StringSearchValues.cs index f702b0bdcad99d..5a3bc8d9a87f06 100644 --- a/src/libraries/System.Memory/tests/Span/StringSearchValues.cs +++ b/src/libraries/System.Memory/tests/Span/StringSearchValues.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; +using System.Runtime.Intrinsics.X86; using System.Threading; using System.Threading.Tasks; using Microsoft.DotNet.RemoteExecutor; @@ -313,6 +314,7 @@ public static void IndexOfAny_InvalidUtf16() IndexOfAny(StringComparison.OrdinalIgnoreCase, -1, " foO\uD801bar", "oo\uD800baR, bar\uD800foo"); // Low surrogate without the high surrogate. + IndexOfAny(StringComparison.OrdinalIgnoreCase, 1, "\uD801\uDCD8\uD8FB\uDCD8", "\uDCD8"); IndexOfAny(StringComparison.OrdinalIgnoreCase, 1, "\uD801\uDCD8\uD8FB\uDCD8", "foo, \uDCD8"); } @@ -337,6 +339,15 @@ public static void IndexOfAny_InvalidUtf16() [InlineData("abcd!")] [InlineData("abcdefgh")] [InlineData("abcdefghi")] + [InlineData("123456789")] + [InlineData("123456789a")] + [InlineData("123456789ab")] + [InlineData("123456789abc")] + [InlineData("123456789abcd")] + [InlineData("123456789abcde")] + [InlineData("123456789abcdef")] + [InlineData("123456789abcdefg")] + [InlineData("123456789abcdefgh")] // Multiple values, but they all share the same prefix [InlineData("abc", "ab", "abcd")] // These should hit the Aho-Corasick implementation @@ -406,9 +417,25 @@ static void TestCore(string[] valuesArray) Values_ImplementsSearchValuesBase(StringComparison.OrdinalIgnoreCase, valuesArray); string values = string.Join(", ", valuesArray); + string text = valuesArray[0]; - IndexOfAny(StringComparison.Ordinal, 0, valuesArray[0], values); - IndexOfAny(StringComparison.OrdinalIgnoreCase, 0, valuesArray[0], values); + IndexOfAny(StringComparison.Ordinal, 0, text, values); + IndexOfAny(StringComparison.OrdinalIgnoreCase, 0, text, values); + + // Replace every position in the text with a different character. + foreach (StringComparison comparisonType in new[] { StringComparison.Ordinal, StringComparison.OrdinalIgnoreCase }) + { + SearchValues stringValues = SearchValues.Create(valuesArray, comparisonType); + + for (int i = 0; i < text.Length - 1; i++) + { + foreach (char replacement in "AaBb _!\u00F6") + { + string newText = $"{text.AsSpan(0, i)}{replacement}{text.AsSpan(i + 1)}"; + Assert.Equal(IndexOfAnyReferenceImpl(newText, valuesArray, comparisonType), newText.IndexOfAny(stringValues)); + } + } + } } } @@ -499,6 +526,20 @@ public static void TestIndexOfAny_RandomInputs_Stress() { RunStress(); + if (RemoteExecutor.IsSupported && Avx512F.IsSupported) + { + var psi = new ProcessStartInfo(); + psi.Environment.Add("DOTNET_EnableAVX512F", "0"); + RemoteExecutor.Invoke(RunStress, new RemoteInvokeOptions { StartInfo = psi, TimeOut = 10 * 60 * 1000 }).Dispose(); + } + + if (RemoteExecutor.IsSupported && Avx2.IsSupported) + { + var psi = new ProcessStartInfo(); + psi.Environment.Add("DOTNET_EnableAVX2", "0"); + RemoteExecutor.Invoke(RunStress, new RemoteInvokeOptions { StartInfo = psi, TimeOut = 10 * 60 * 1000 }).Dispose(); + } + if (CanTestInvariantCulture) { RunUsingInvariantCulture(static () => RunStress()); diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs index 659ac09c7259ef..d6d79bb96b9e7b 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs @@ -80,7 +80,7 @@ public override int GetHashCode(string? obj) // The Ordinal version of Marvin32 operates over bytes. // The multiplication from # chars -> # bytes will never integer overflow. return Marvin.ComputeHash32( - ref Unsafe.As(ref obj.GetRawStringData()), + ref obj.GetRawStringDataAsUInt8(), (uint)obj.Length * 2, _seed.p0, _seed.p1); } diff --git a/src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/Helpers/StringSearchValuesHelper.cs b/src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/Helpers/StringSearchValuesHelper.cs index 9f1ba0ebf1771f..15a608cce70db6 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/Helpers/StringSearchValuesHelper.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/Helpers/StringSearchValuesHelper.cs @@ -62,47 +62,166 @@ public static bool StartsWith(ref char matchStart, int lengthR return false; } - return TCaseSensitivity.Equals(ref matchStart, candidate); + return UnknownLengthEquals(ref matchStart, candidate); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool ScalarEquals(ref char matchStart, string candidate) + private static bool UnknownLengthEquals(ref char matchStart, string candidate) where TCaseSensitivity : struct, ICaseSensitivity { - for (int i = 0; i < candidate.Length; i++) + if (typeof(TCaseSensitivity) == typeof(CaseSensitive)) { - if (TCaseSensitivity.TransformInput(Unsafe.Add(ref matchStart, i)) != candidate[i]) - { - return false; - } + return SpanHelpers.SequenceEqual( + ref Unsafe.As(ref matchStart), + ref candidate.GetRawStringDataAsUInt8(), + (uint)candidate.Length * sizeof(char)); } - return true; - } + if (typeof(TCaseSensitivity) == typeof(CaseInsensitiveAscii) || + typeof(TCaseSensitivity) == typeof(CaseInsensitiveAsciiLetters)) + { + return Ascii.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), (uint)candidate.Length); + } - public interface IValueLength - { - static abstract bool AtLeast4Chars { get; } - static abstract bool AtLeast8CharsOrUnknown { get; } + Debug.Assert(typeof(TCaseSensitivity) == typeof(CaseInsensitiveUnicode)); + return Ordinal.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), candidate.Length); } - public readonly struct ValueLengthLessThan4 : IValueLength - { - public static bool AtLeast4Chars => false; - public static bool AtLeast8CharsOrUnknown => false; - } + public interface IValueLength { } - public readonly struct ValueLength4To7 : IValueLength - { - public static bool AtLeast4Chars => true; - public static bool AtLeast8CharsOrUnknown => false; - } + public readonly struct ValueLengthLessThan4 : IValueLength { } + + public readonly struct ValueLength4To8 : IValueLength { } + + public readonly struct ValueLength9To16 : IValueLength { } // "Unknown" is currently only used by Teddy when confirming matches. - public readonly struct ValueLength8OrLongerOrUnknown : IValueLength + public readonly struct ValueLengthLongOrUnknown : IValueLength { } + + public readonly struct SingleValueState { - public static bool AtLeast4Chars => true; - public static bool AtLeast8CharsOrUnknown => true; + public readonly string Value; + public readonly nint SecondReadByteOffset; + public readonly Vector256 Value256; + public readonly Vector256 ToUpperMask256; + + public readonly ulong Value64_0 => Value256.AsUInt64()[0]; + public readonly ulong Value64_1 => Value256.AsUInt64()[1]; + public readonly uint Value32_0 => Value256.AsUInt32()[0]; + public readonly uint Value32_1 => Value256.AsUInt32()[1]; + + public readonly ulong ToUpperMask64_0 => ToUpperMask256.AsUInt64()[0]; + public readonly ulong ToUpperMask64_1 => ToUpperMask256.AsUInt64()[1]; + public readonly uint ToUpperMask32_0 => ToUpperMask256.AsUInt32()[0]; + public readonly uint ToUpperMask32_1 => ToUpperMask256.AsUInt32()[1]; + + public SingleValueState(string value, bool ignoreCase) + { + Debug.Assert(value.Length >= 2); + + Value = value; + + // We precompute vectors specific to this value to speed up later comparisons. + // We group values depending on their length (2-3, 4-8, 9-16). + // For any of those lengths, we can load the whole value with two overlapped reads (e.g. 2x 8 characters for lengths 9-16). + // For a string "Hello World", we would load + // [Hello Wo] + // [lo World] + // SecondReadByteOffset: 6 bytes (3 characters) + // We then precompute a mask that converts any potential input to the uppercase variant, specific to this value. + // We must ensure that the ASCII letter mask only applies to the letters, not the space character. + // Value256: [HELLO WOLO WORLD] (note that the value is already converted to uppercase if we're ignoring casing) + // ToUpperMask256: [xxxxx xxxx xxxxx] (x = ~0x20 for ASCII letters, 0xFFFF otherwise) + // + // Given a potential match, we can now confirm whether we found a match by loading the candidate in the same way and applying this mask: + // Vector256 input = [Vector128.Load(candidate), Vector128.Load(candidate + 6 bytes)]; + // bool matches = (input & ToUpperMask256) == Value256; + + // The two vectors may overlap completely for Length == 2 or Length == 4, and that's fine. + // The second comparison during validation is redundant in such cases, but the alternative is to introduce more IValueLength specializations. + + if (value.Length <= 16) + { + if (value.Length > 8) + { + SecondReadByteOffset = (value.Length - 8) * sizeof(char); + Value256 = Vector256.Create( + Vector128.LoadUnsafe(ref value.GetRawStringDataAsUInt16()), + Vector128.LoadUnsafe(ref Unsafe.AddByteOffset(ref value.GetRawStringDataAsUInt16(), SecondReadByteOffset))); + } + else if (value.Length >= 4) + { + SecondReadByteOffset = (value.Length - 4) * sizeof(char); + Value256 = Vector256.Create(Vector128.Create( + Unsafe.ReadUnaligned(ref value.GetRawStringDataAsUInt8()), + Unsafe.ReadUnaligned(ref Unsafe.Add(ref value.GetRawStringDataAsUInt8(), SecondReadByteOffset)) + )).AsUInt16(); + } + else + { + Debug.Assert(value.Length is 2 or 3); + + SecondReadByteOffset = (value.Length - 2) * sizeof(char); + Value256 = Vector256.Create(Vector128.Create(Vector64.Create( + Unsafe.ReadUnaligned(ref value.GetRawStringDataAsUInt8()), + Unsafe.ReadUnaligned(ref Unsafe.Add(ref value.GetRawStringDataAsUInt8(), SecondReadByteOffset)) + ))).AsUInt16(); + } + + if (ignoreCase) + { + Vector256 isAsciiLetter = + Vector256.GreaterThanOrEqual(Value256, Vector256.Create((ushort)'A')) & + Vector256.LessThanOrEqual(Value256, Vector256.Create((ushort)'Z')); + + ToUpperMask256 = Vector256.ConditionalSelect(isAsciiLetter, Vector256.Create(unchecked((ushort)~0x20)), Vector256.Create(ushort.MaxValue)); + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool MatchesLength9To16_CaseSensitive(ref char matchStart) + { + Debug.Assert(Value.Length is >= 9 and <= 16); + Debug.Assert(ToUpperMask256 == default); + + if (Vector256.IsHardwareAccelerated) + { + Vector256 input = Vector256.Create( + Vector128.LoadUnsafe(ref matchStart), + Vector128.LoadUnsafe(ref Unsafe.AddByteOffset(ref matchStart, SecondReadByteOffset))); + + return input == Value256; + } + else + { + Vector128 different = Vector128.LoadUnsafe(ref matchStart) ^ Value256.GetLower(); + different |= Vector128.LoadUnsafe(ref Unsafe.AddByteOffset(ref matchStart, SecondReadByteOffset)) ^ Value256.GetUpper(); + return different == Vector128.Zero; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool MatchesLength9To16_CaseInsensitiveAscii(ref char matchStart) + { + Debug.Assert(Value.Length is >= 9 and <= 16); + Debug.Assert(ToUpperMask256 != default); + + if (Vector256.IsHardwareAccelerated) + { + Vector256 input = Vector256.Create( + Vector128.LoadUnsafe(ref matchStart), + Vector128.LoadUnsafe(ref Unsafe.AddByteOffset(ref matchStart, SecondReadByteOffset))); + + return (input & ToUpperMask256) == Value256; + } + else + { + Vector128 different = (Vector128.LoadUnsafe(ref matchStart) & ToUpperMask256.GetLower()) ^ Value256.GetLower(); + different |= (Vector128.LoadUnsafe(ref Unsafe.AddByteOffset(ref matchStart, SecondReadByteOffset)) & ToUpperMask256.GetUpper()) ^ Value256.GetUpper(); + return different == Vector128.Zero; + } + } } public interface ICaseSensitivity @@ -111,7 +230,7 @@ public interface ICaseSensitivity static abstract Vector128 TransformInput(Vector128 input); static abstract Vector256 TransformInput(Vector256 input); static abstract Vector512 TransformInput(Vector512 input); - static abstract bool Equals(ref char matchStart, string candidate) where TValueLength : struct, IValueLength; + static abstract bool Equals(ref char matchStart, ref readonly SingleValueState state) where TValueLength : struct, IValueLength; } // Performs no case transformations. @@ -130,39 +249,33 @@ public interface ICaseSensitivity public static Vector512 TransformInput(Vector512 input) => input; [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool Equals(ref char matchStart, string candidate) + public static bool Equals(ref char matchStart, ref readonly SingleValueState state) where TValueLength : struct, IValueLength { - Debug.Assert(candidate.Length > 1); - - ref byte first = ref Unsafe.As(ref matchStart); - ref byte second = ref Unsafe.As(ref candidate.GetRawStringData()); - nuint byteLength = (nuint)(uint)candidate.Length * 2; - - if (TValueLength.AtLeast8CharsOrUnknown) + if (typeof(TValueLength) == typeof(ValueLengthLongOrUnknown)) { - return SpanHelpers.SequenceEqual(ref first, ref second, byteLength); + return UnknownLengthEquals(ref matchStart, state.Value); } - - Debug.Assert(matchStart == candidate[0], "This should only be called after the first character has been checked"); - - if (TValueLength.AtLeast4Chars) + else if (typeof(TValueLength) == typeof(ValueLength9To16)) { - nuint offset = byteLength - sizeof(ulong); - ulong differentBits = Unsafe.ReadUnaligned(ref first) - Unsafe.ReadUnaligned(ref second); - differentBits |= Unsafe.ReadUnaligned(ref Unsafe.Add(ref first, offset)) - Unsafe.ReadUnaligned(ref Unsafe.Add(ref second, offset)); + return state.MatchesLength9To16_CaseSensitive(ref matchStart); + } + else if (typeof(TValueLength) == typeof(ValueLength4To8)) + { + ref byte matchByteStart = ref Unsafe.As(ref matchStart); + ulong differentBits = Unsafe.ReadUnaligned(ref matchByteStart) - state.Value64_0; + differentBits |= Unsafe.ReadUnaligned(ref Unsafe.Add(ref matchByteStart, state.SecondReadByteOffset)) - state.Value64_1; return differentBits == 0; } else { - Debug.Assert(candidate.Length is 2 or 3); + Debug.Assert(state.Value.Length is 2 or 3); + Debug.Assert(matchStart == state.Value[0], "This should only be called after the first character has been checked"); // We know that the candidate is 2 or 3 characters long, and that the first character has already been checked. - // We only have to to check the last 2 characters also match. - nuint offset = byteLength - sizeof(uint); - - return Unsafe.ReadUnaligned(ref Unsafe.Add(ref first, offset)) - == Unsafe.ReadUnaligned(ref Unsafe.Add(ref second, offset)); + // We only have to to check whether the last 2 characters also match. + ref byte matchByteStart = ref Unsafe.As(ref matchStart); + return Unsafe.ReadUnaligned(ref Unsafe.Add(ref matchByteStart, state.SecondReadByteOffset)) == state.Value32_1; } } } @@ -184,36 +297,35 @@ public static bool Equals(ref char matchStart, string candidate) public static Vector512 TransformInput(Vector512 input) => input & Vector512.Create(unchecked((byte)~0x20)); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool Equals(ref char matchStart, string candidate) + public static bool Equals(ref char matchStart, ref readonly SingleValueState state) where TValueLength : struct, IValueLength { - Debug.Assert(candidate.Length > 1); - Debug.Assert(candidate.ToUpperInvariant() == candidate); - - if (TValueLength.AtLeast8CharsOrUnknown) + if (typeof(TValueLength) == typeof(ValueLengthLongOrUnknown)) { - return Ascii.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), (uint)candidate.Length); + return UnknownLengthEquals(ref matchStart, state.Value); } - - ref byte first = ref Unsafe.As(ref matchStart); - ref byte second = ref Unsafe.As(ref candidate.GetRawStringData()); - nuint byteLength = (nuint)(uint)candidate.Length * 2; - - if (TValueLength.AtLeast4Chars) + else if (typeof(TValueLength) == typeof(ValueLength9To16)) + { + return state.MatchesLength9To16_CaseInsensitiveAscii(ref matchStart); + } + else if (typeof(TValueLength) == typeof(ValueLength4To8)) { const ulong CaseMask = ~0x20002000200020u; - nuint offset = byteLength - sizeof(ulong); - ulong differentBits = (Unsafe.ReadUnaligned(ref first) & CaseMask) - Unsafe.ReadUnaligned(ref second); - differentBits |= (Unsafe.ReadUnaligned(ref Unsafe.Add(ref first, offset)) & CaseMask) - Unsafe.ReadUnaligned(ref Unsafe.Add(ref second, offset)); + ref byte matchByteStart = ref Unsafe.As(ref matchStart); + ulong differentBits = (Unsafe.ReadUnaligned(ref matchByteStart) & CaseMask) - state.Value64_0; + differentBits |= (Unsafe.ReadUnaligned(ref Unsafe.Add(ref matchByteStart, state.SecondReadByteOffset)) & CaseMask) - state.Value64_1; return differentBits == 0; } else { + Debug.Assert(state.Value.Length is 2 or 3); + Debug.Assert(TransformInput(matchStart) == state.Value[0], "This should only be called after the first character has been checked"); + + // We know that the candidate is 2 or 3 characters long, and that the first character has already been checked. + // We only have to to check whether the last 2 characters also match. const uint CaseMask = ~0x200020u; - nuint offset = byteLength - sizeof(uint); - uint differentBits = (Unsafe.ReadUnaligned(ref first) & CaseMask) - Unsafe.ReadUnaligned(ref second); - differentBits |= (Unsafe.ReadUnaligned(ref Unsafe.Add(ref first, offset)) & CaseMask) - Unsafe.ReadUnaligned(ref Unsafe.Add(ref second, offset)); - return differentBits == 0; + ref byte matchByteStart = ref Unsafe.As(ref matchStart); + return (Unsafe.ReadUnaligned(ref Unsafe.Add(ref matchByteStart, state.SecondReadByteOffset)) & CaseMask) == state.Value32_1; } } } @@ -259,15 +371,34 @@ public static Vector512 TransformInput(Vector512 input) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool Equals(ref char matchStart, string candidate) + public static bool Equals(ref char matchStart, ref readonly SingleValueState state) where TValueLength : struct, IValueLength { - if (TValueLength.AtLeast8CharsOrUnknown) + if (typeof(TValueLength) == typeof(ValueLengthLongOrUnknown)) + { + return UnknownLengthEquals(ref matchStart, state.Value); + } + else if (typeof(TValueLength) == typeof(ValueLength9To16)) + { + return state.MatchesLength9To16_CaseInsensitiveAscii(ref matchStart); + } + else if (typeof(TValueLength) == typeof(ValueLength4To8)) { - return Ascii.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), (uint)candidate.Length); + ref byte matchByteStart = ref Unsafe.As(ref matchStart); + ulong differentBits = (Unsafe.ReadUnaligned(ref matchByteStart) & state.ToUpperMask64_0) - state.Value64_0; + differentBits |= (Unsafe.ReadUnaligned(ref Unsafe.Add(ref matchByteStart, state.SecondReadByteOffset)) & state.ToUpperMask64_1) - state.Value64_1; + return differentBits == 0; } + else + { + Debug.Assert(state.Value.Length is 2 or 3); + Debug.Assert((matchStart & ~0x20) == (state.Value[0] & ~0x20)); - return ScalarEquals(ref matchStart, candidate); + ref byte matchByteStart = ref Unsafe.As(ref matchStart); + uint differentBits = (Unsafe.ReadUnaligned(ref matchByteStart) & state.ToUpperMask32_0) - state.Value32_0; + differentBits |= (Unsafe.ReadUnaligned(ref Unsafe.Add(ref matchByteStart, state.SecondReadByteOffset)) & state.ToUpperMask32_1) - state.Value32_1; + return differentBits == 0; + } } } @@ -281,15 +412,17 @@ public static bool Equals(ref char matchStart, string candidate) public static Vector512 TransformInput(Vector512 input) => throw new UnreachableException(); [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool Equals(ref char matchStart, string candidate) + public static bool Equals(ref char matchStart, ref readonly SingleValueState state) where TValueLength : struct, IValueLength { - if (TValueLength.AtLeast8CharsOrUnknown) + if (typeof(TValueLength) == typeof(ValueLengthLongOrUnknown)) { - return Ordinal.EqualsIgnoreCase(ref matchStart, ref candidate.GetRawStringData(), candidate.Length); + return UnknownLengthEquals(ref matchStart, state.Value); + } + else + { + return Ordinal.EqualsIgnoreCase_Scalar(ref matchStart, ref state.Value.GetRawStringData(), state.Value.Length); } - - return Ordinal.EqualsIgnoreCase_Scalar(ref matchStart, ref candidate.GetRawStringData(), candidate.Length); } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/SingleStringSearchValuesThreeChars.cs b/src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/SingleStringSearchValuesThreeChars.cs index 429e690f49a915..c005173b67e143 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/SingleStringSearchValuesThreeChars.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/SingleStringSearchValuesThreeChars.cs @@ -14,14 +14,14 @@ namespace System.Buffers // Based on SpanHelpers.IndexOf(ref char, int, ref char, int) // This implementation uses 3 precomputed anchor points when searching. // This implementation may also be used for length=2 values, in which case two anchors point at the same position. - // Has an O(i * m) worst-case, with the expected time closer to O(n) for most inputs. + // Has an O(i * m) worst-case, with the expected time closer to O(i) for most inputs. internal sealed class SingleStringSearchValuesThreeChars : StringSearchValuesBase where TValueLength : struct, IValueLength where TCaseSensitivity : struct, ICaseSensitivity { private const ushort CaseConversionMask = unchecked((ushort)~0x20); - private readonly string _value; + private readonly SingleValueState _valueState; private readonly nint _minusValueTailLength; private readonly nuint _ch2ByteOffset; private readonly nuint _ch3ByteOffset; @@ -31,7 +31,7 @@ internal sealed class SingleStringSearchValuesThreeChars typeof(TCaseSensitivity) != typeof(CaseSensitive); - // If the value is short (!TValueLength.AtLeast4Chars => 2 or 3 characters), the anchors already represent the whole value. + // If the value is short (ValueLengthLessThan4 => 2 or 3 characters), the anchors already represent the whole value. // With case-sensitive comparisons, we've therefore already confirmed the match. // With case-insensitive comparisons, we've applied the CaseConversionMask to the input, so while the anchors likely matched, we can't be sure. // An exception to that is if we know the value is composed of only ASCII letters, in which case masking the input can't produce false positives. @@ -39,7 +39,7 @@ private static bool CanSkipAnchorMatchVerification { [MethodImpl(MethodImplOptions.AggressiveInlining)] get => - !TValueLength.AtLeast4Chars && + typeof(TValueLength) == typeof(ValueLengthLessThan4) && (typeof(TCaseSensitivity) == typeof(CaseSensitive) || typeof(TCaseSensitivity) == typeof(CaseInsensitiveAsciiLetters)); } @@ -47,13 +47,12 @@ public SingleStringSearchValuesThreeChars(HashSet? uniqueValues, string { // We could have more than one entry in 'uniqueValues' if this value is an exact prefix of all the others. Debug.Assert(value.Length > 1); - Debug.Assert((value.Length >= 8) == TValueLength.AtLeast8CharsOrUnknown); CharacterFrequencyHelper.GetSingleStringMultiCharacterOffsets(value, IgnoreCase, out int ch2Offset, out int ch3Offset); Debug.Assert(ch3Offset == 0 || ch3Offset > ch2Offset); - _value = value; + _valueState = new SingleValueState(value, IgnoreCase); _minusValueTailLength = -(value.Length - 1); _ch1 = value[0]; @@ -233,8 +232,7 @@ private int IndexOf(ref char searchSpace, int searchSpaceLength) } ShortInput: - string value = _value; - char valueHead = value.GetRawStringData(); + char valueHead = _valueState.Value.GetRawStringData(); for (nint i = 0; i < searchSpaceMinusValueTailLength; i++) { @@ -242,7 +240,7 @@ private int IndexOf(ref char searchSpace, int searchSpaceLength) // CaseInsensitiveUnicode doesn't support single-character transformations, so we skip checking the first character first. if ((typeof(TCaseSensitivity) == typeof(CaseInsensitiveUnicode) || TCaseSensitivity.TransformInput(cur) == valueHead) && - TCaseSensitivity.Equals(ref cur, value)) + TCaseSensitivity.Equals(ref cur, in _valueState)) { return (int)i; } @@ -337,9 +335,9 @@ private bool TryMatch(ref char searchSpaceStart, int searchSpaceLength, ref char ref char matchRef = ref Unsafe.AddByteOffset(ref searchSpace, bitPos); - ValidateReadPosition(ref searchSpaceStart, searchSpaceLength, ref matchRef, _value.Length); + ValidateReadPosition(ref searchSpaceStart, searchSpaceLength, ref matchRef, _valueState.Value.Length); - if (CanSkipAnchorMatchVerification || TCaseSensitivity.Equals(ref matchRef, _value)) + if (CanSkipAnchorMatchVerification || TCaseSensitivity.Equals(ref matchRef, in _valueState)) { offsetFromStart = (int)((nuint)Unsafe.ByteOffset(ref searchSpaceStart, ref matchRef) / 2); return true; @@ -365,9 +363,9 @@ private bool TryMatch(ref char searchSpaceStart, int searchSpaceLength, ref char ref char matchRef = ref Unsafe.AddByteOffset(ref searchSpace, bitPos); - ValidateReadPosition(ref searchSpaceStart, searchSpaceLength, ref matchRef, _value.Length); + ValidateReadPosition(ref searchSpaceStart, searchSpaceLength, ref matchRef, _valueState.Value.Length); - if (CanSkipAnchorMatchVerification || TCaseSensitivity.Equals(ref matchRef, _value)) + if (CanSkipAnchorMatchVerification || TCaseSensitivity.Equals(ref matchRef, in _valueState)) { offsetFromStart = (int)((nuint)Unsafe.ByteOffset(ref searchSpaceStart, ref matchRef) / 2); return true; @@ -384,10 +382,10 @@ private bool TryMatch(ref char searchSpaceStart, int searchSpaceLength, ref char internal override bool ContainsCore(string value) => HasUniqueValues ? base.ContainsCore(value) - : _value.Equals(value, IgnoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal); + : _valueState.Value.Equals(value, IgnoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal); internal override string[] GetValues() => HasUniqueValues ? base.GetValues() - : [_value]; + : [_valueState.Value]; } } diff --git a/src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/StringSearchValues.cs b/src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/StringSearchValues.cs index e2ae3c61b04455..42ae98a2b440ce 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/StringSearchValues.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SearchValues/Strings/StringSearchValues.cs @@ -389,8 +389,9 @@ private static SearchValues CreateForSingleValue( SearchValues? searchValues = value.Length switch { < 4 => TryCreateSingleValuesThreeChars(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly), - < 8 => TryCreateSingleValuesThreeChars(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly), - _ => TryCreateSingleValuesThreeChars(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly), + <= 8 => TryCreateSingleValuesThreeChars(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly), + <= 16 => TryCreateSingleValuesThreeChars(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly), + _ => TryCreateSingleValuesThreeChars(value, uniqueValues, ignoreCase, allAscii, asciiLettersOnly), }; if (searchValues is not null) diff --git a/src/libraries/System.Private.CoreLib/src/System/String.Comparison.cs b/src/libraries/System.Private.CoreLib/src/System/String.Comparison.cs index 9eacfefab5d1eb..b2a8a6c4c65ea5 100644 --- a/src/libraries/System.Private.CoreLib/src/System/String.Comparison.cs +++ b/src/libraries/System.Private.CoreLib/src/System/String.Comparison.cs @@ -27,9 +27,9 @@ private static bool EqualsHelper(string strA, string strB) Debug.Assert(strA.Length == strB.Length); return SpanHelpers.SequenceEqual( - ref Unsafe.As(ref strA.GetRawStringData()), - ref Unsafe.As(ref strB.GetRawStringData()), - ((uint)strA.Length) * sizeof(char)); + ref strA.GetRawStringDataAsUInt8(), + ref strB.GetRawStringDataAsUInt8(), + ((uint)strA.Length) * sizeof(char)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -1138,8 +1138,8 @@ public bool StartsWith(string value, StringComparison comparisonType) return (value.Length == 1) ? true : // First char is the same and thats all there is to compare SpanHelpers.SequenceEqual( - ref Unsafe.As(ref this.GetRawStringData()), - ref Unsafe.As(ref value.GetRawStringData()), + ref this.GetRawStringDataAsUInt8(), + ref value.GetRawStringDataAsUInt8(), ((nuint)value.Length) * 2); case StringComparison.OrdinalIgnoreCase: diff --git a/src/libraries/System.Private.CoreLib/src/System/String.cs b/src/libraries/System.Private.CoreLib/src/System/String.cs index 01564cb1cb7aa6..e94d8778f1cab4 100644 --- a/src/libraries/System.Private.CoreLib/src/System/String.cs +++ b/src/libraries/System.Private.CoreLib/src/System/String.cs @@ -525,6 +525,7 @@ public static bool IsNullOrWhiteSpace([NotNullWhen(false)] string? value) public ref readonly char GetPinnableReference() => ref _firstChar; internal ref char GetRawStringData() => ref _firstChar; + internal ref byte GetRawStringDataAsUInt8() => ref Unsafe.As(ref _firstChar); internal ref ushort GetRawStringDataAsUInt16() => ref Unsafe.As(ref _firstChar); // Helper for encodings so they can talk to our buffer directly