From 771179a800bcfc9e56e7b38e8140d08885ca952f Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Sat, 13 Jan 2024 04:50:21 +0100 Subject: [PATCH 1/3] Add Avx512 support to ProbabilisticMap --- .../System.Memory/tests/Span/SearchValues.cs | 2 +- .../System/SearchValues/ProbabilisticMap.cs | 307 +++++++++++++++--- 2 files changed, 270 insertions(+), 39 deletions(-) diff --git a/src/libraries/System.Memory/tests/Span/SearchValues.cs b/src/libraries/System.Memory/tests/Span/SearchValues.cs index 1d58c2cf646c39..9ef91acec80dab 100644 --- a/src/libraries/System.Memory/tests/Span/SearchValues.cs +++ b/src/libraries/System.Memory/tests/Span/SearchValues.cs @@ -403,7 +403,7 @@ static int LastIndexOfAnyExceptReferenceImpl(ReadOnlySpan searchSpace, Rea private static class SearchValuesTestHelper { private const int MaxNeedleLength = 10; - private const int MaxHaystackLength = 100; + private const int MaxHaystackLength = 200; private static readonly char[] s_randomAsciiChars; private static readonly char[] s_randomLatin1Chars; diff --git a/src/libraries/System.Private.CoreLib/src/System/SearchValues/ProbabilisticMap.cs b/src/libraries/System.Private.CoreLib/src/System/SearchValues/ProbabilisticMap.cs index bfad77993dbac4..0a7b7900b35c61 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SearchValues/ProbabilisticMap.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SearchValues/ProbabilisticMap.cs @@ -3,7 +3,6 @@ using System.Diagnostics; using System.Numerics; -using System.Runtime; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Intrinsics; @@ -106,6 +105,76 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(values)), (short)ch, values.Length); + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx512Vbmi))] + private static Vector512 ContainsMask64CharsAvx512(Vector512 charMap, ref char searchSpace0, ref char searchSpace1) + { + Vector512 source0 = Vector512.LoadUnsafe(ref searchSpace0); + Vector512 source1 = Vector512.LoadUnsafe(ref searchSpace1); + + Vector512 sourceLower = Avx512BW.PackUnsignedSaturate( + (source0 & Vector512.Create((ushort)255)).AsInt16(), + (source1 & Vector512.Create((ushort)255)).AsInt16()); + + Vector512 sourceUpper = Avx512BW.PackUnsignedSaturate( + (source0 >>> 8).AsInt16(), + (source1 >>> 8).AsInt16()); + + Vector512 resultLower = IsCharBitNotSetAvx512(charMap, sourceLower); + Vector512 resultUpper = IsCharBitNotSetAvx512(charMap, sourceUpper); + + return ~(resultLower | resultUpper); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx512Vbmi))] + private static Vector512 IsCharBitNotSetAvx512(Vector512 charMap, Vector512 values) + { + Vector512 shifted = values >>> VectorizedIndexShift; + + Vector512 bitPositions = Avx512BW.Shuffle(Vector512.Create(0x8040201008040201).AsByte(), shifted); + + Vector512 index = values & Vector512.Create((byte)VectorizedIndexMask); + Vector512 bitMask = Avx512Vbmi.PermuteVar64x8(charMap, index); + + return Vector512.Equals(bitMask & bitPositions, Vector512.Zero); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx512Vbmi.VL))] + private static Vector256 ContainsMask32CharsAvx512(Vector256 charMap, ref char searchSpace0, ref char searchSpace1) + { + Vector256 source0 = Vector256.LoadUnsafe(ref searchSpace0); + Vector256 source1 = Vector256.LoadUnsafe(ref searchSpace1); + + Vector256 sourceLower = Avx2.PackUnsignedSaturate( + (source0 & Vector256.Create((ushort)255)).AsInt16(), + (source1 & Vector256.Create((ushort)255)).AsInt16()); + + Vector256 sourceUpper = Avx2.PackUnsignedSaturate( + (source0 >>> 8).AsInt16(), + (source1 >>> 8).AsInt16()); + + Vector256 resultLower = IsCharBitNotSetAvx512(charMap, sourceLower); + Vector256 resultUpper = IsCharBitNotSetAvx512(charMap, sourceUpper); + + return ~(resultLower | resultUpper); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx512Vbmi.VL))] + private static Vector256 IsCharBitNotSetAvx512(Vector256 charMap, Vector256 values) + { + Vector256 shifted = values >>> VectorizedIndexShift; + + Vector256 bitPositions = Avx2.Shuffle(Vector256.Create(0x8040201008040201).AsByte(), shifted); + + Vector256 index = values & Vector256.Create((byte)VectorizedIndexMask); + Vector256 bitMask = Avx512Vbmi.VL.PermuteVar32x8(charMap, index); + + return Vector256.Equals(bitMask & bitPositions, Vector256.Zero); + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(Avx2))] private static Vector256 ContainsMask32CharsAvx2(Vector256 charMapLower, Vector256 charMapUpper, ref char searchSpace) @@ -121,15 +190,15 @@ private static Vector256 ContainsMask32CharsAvx2(Vector256 charMapLo (source0 >>> 8).AsInt16(), (source1 >>> 8).AsInt16()); - Vector256 resultLower = IsCharBitSetAvx2(charMapLower, charMapUpper, sourceLower); - Vector256 resultUpper = IsCharBitSetAvx2(charMapLower, charMapUpper, sourceUpper); + Vector256 resultLower = IsCharBitNotSetAvx2(charMapLower, charMapUpper, sourceLower); + Vector256 resultUpper = IsCharBitNotSetAvx2(charMapLower, charMapUpper, sourceUpper); - return resultLower & resultUpper; + return ~(resultLower | resultUpper); } [MethodImpl(MethodImplOptions.AggressiveInlining)] [CompExactlyDependsOn(typeof(Avx2))] - private static Vector256 IsCharBitSetAvx2(Vector256 charMapLower, Vector256 charMapUpper, Vector256 values) + private static Vector256 IsCharBitNotSetAvx2(Vector256 charMapLower, Vector256 charMapUpper, Vector256 values) { Vector256 shifted = values >>> VectorizedIndexShift; @@ -141,7 +210,7 @@ private static Vector256 IsCharBitSetAvx2(Vector256 charMapLower, Ve Vector256 mask = Vector256.GreaterThan(index, Vector256.Create((byte)15)); Vector256 bitMask = Vector256.ConditionalSelect(mask, bitMaskUpper, bitMaskLower); - return ~Vector256.Equals(bitMask & bitPositions, Vector256.Zero); + return Vector256.Equals(bitMask & bitPositions, Vector256.Zero); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -160,10 +229,10 @@ private static Vector128 ContainsMask16Chars(Vector128 charMapLower, ? Sse2.PackUnsignedSaturate((source0 >>> 8).AsInt16(), (source1 >>> 8).AsInt16()) : AdvSimd.Arm64.UnzipOdd(source0.AsByte(), source1.AsByte()); - Vector128 resultLower = IsCharBitSet(charMapLower, charMapUpper, sourceLower); - Vector128 resultUpper = IsCharBitSet(charMapLower, charMapUpper, sourceUpper); + Vector128 resultLower = IsCharBitNotSet(charMapLower, charMapUpper, sourceLower); + Vector128 resultUpper = IsCharBitNotSet(charMapLower, charMapUpper, sourceUpper); - return resultLower & resultUpper; + return ~(resultLower | resultUpper); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -172,7 +241,7 @@ private static Vector128 ContainsMask16Chars(Vector128 charMapLower, [CompExactlyDependsOn(typeof(AdvSimd))] [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] [CompExactlyDependsOn(typeof(PackedSimd))] - private static Vector128 IsCharBitSet(Vector128 charMapLower, Vector128 charMapUpper, Vector128 values) + private static Vector128 IsCharBitNotSet(Vector128 charMapLower, Vector128 charMapUpper, Vector128 values) { Vector128 shifted = values >>> VectorizedIndexShift; @@ -193,7 +262,7 @@ private static Vector128 IsCharBitSet(Vector128 charMapLower, Vector bitMask = Vector128.ConditionalSelect(mask, bitMaskUpper, bitMaskLower); } - return ~Vector128.Equals(bitMask & bitPositions, Vector128.Zero); + return Vector128.Equals(bitMask & bitPositions, Vector128.Zero); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -302,7 +371,9 @@ internal static int IndexOfAny(ref uint charMap, ref char searchSpace, int searc { if ((Sse41.IsSupported || AdvSimd.Arm64.IsSupported) && searchSpaceLength >= 16) { - return IndexOfAnyVectorized(ref charMap, ref searchSpace, searchSpaceLength, values); + return Avx512Vbmi.VL.IsSupported + ? IndexOfAnyVectorizedAvx512(ref charMap, ref searchSpace, searchSpaceLength, values) + : IndexOfAnyVectorized(ref charMap, ref searchSpace, searchSpaceLength, values); } ref char searchSpaceEnd = ref Unsafe.Add(ref searchSpace, searchSpaceLength); @@ -313,7 +384,7 @@ internal static int IndexOfAny(ref uint charMap, ref char searchSpace, int searc int ch = cur; if (Contains(ref charMap, values, ch)) { - return (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref cur) / sizeof(char)); + return MatchOffset(ref searchSpace, ref cur); } cur = ref Unsafe.Add(ref cur, 1); @@ -337,6 +408,88 @@ internal static int LastIndexOfAny(ref uint charMap, ref char searchSpace, int s return -1; } + [CompExactlyDependsOn(typeof(Avx512Vbmi.VL))] + private static int IndexOfAnyVectorizedAvx512(ref uint charMap, ref char searchSpace, int searchSpaceLength, ReadOnlySpan values) + { + Debug.Assert(Avx512Vbmi.VL.IsSupported); + Debug.Assert(searchSpaceLength >= 16); + + ref char searchSpaceEnd = ref Unsafe.Add(ref searchSpace, searchSpaceLength); + + Vector256 charMap256 = Vector256.LoadUnsafe(ref Unsafe.As(ref charMap)); + + if (searchSpaceLength > 32) + { + Vector512 charMap512 = Vector512.Create(charMap256, charMap256); + + if (searchSpaceLength > 64) + { + ref char cur = ref searchSpace; + ref char lastStartVector = ref Unsafe.Subtract(ref searchSpaceEnd, 64); + + while (true) + { + Vector512 result = ContainsMask64CharsAvx512(charMap512, ref cur, ref Unsafe.Add(ref cur, Vector512.Count)); + + if (result != Vector512.Zero) + { + if (TryFindMatch(ref cur, PackedSpanHelpers.FixUpPackedVector512Result(result).ExtractMostSignificantBits(), values, out int index)) + { + return MatchOffset(ref searchSpace, ref cur) + index; + } + } + + cur = ref Unsafe.Add(ref cur, 64); + + if (Unsafe.IsAddressGreaterThan(ref cur, ref lastStartVector)) + { + if (Unsafe.AreSame(ref cur, ref searchSpaceEnd)) + { + break; + } + + // Adjust the current vector and do one last iteration. + cur = ref lastStartVector; + } + } + } + else + { + Debug.Assert(searchSpaceLength is > 32 and <= 64); + + // Process the first and last vector in the search space. + // They may overlap, but we'll handle that in the index calculation if we do get a match. + Vector512 result = ContainsMask64CharsAvx512(charMap512, ref searchSpace, ref Unsafe.Subtract(ref searchSpaceEnd, Vector512.Count)); + + if (result != Vector512.Zero) + { + if (TryFindMatchOverlapped(ref searchSpace, searchSpaceLength, PackedSpanHelpers.FixUpPackedVector512Result(result).ExtractMostSignificantBits(), values, out int index)) + { + return index; + } + } + } + } + else + { + Debug.Assert(searchSpaceLength is >= 16 and < 32); + + // Process the first and last vector in the search space. + // They may overlap, but we'll handle that in the index calculation if we do get a match. + Vector256 result = ContainsMask32CharsAvx512(charMap256, ref searchSpace, ref Unsafe.Subtract(ref searchSpaceEnd, Vector256.Count)); + + if (result != Vector256.Zero) + { + if (TryFindMatchOverlapped(ref searchSpace, searchSpaceLength, PackedSpanHelpers.FixUpPackedVector256Result(result).ExtractMostSignificantBits(), values, out int index)) + { + return index; + } + } + } + + return -1; + } + [CompExactlyDependsOn(typeof(AdvSimd.Arm64))] [CompExactlyDependsOn(typeof(Sse41))] private static int IndexOfAnyVectorized(ref uint charMap, ref char searchSpace, int searchSpaceLength, ReadOnlySpan values) @@ -365,21 +518,10 @@ private static int IndexOfAnyVectorized(ref uint charMap, ref char searchSpace, if (result != Vector256.Zero) { - result = PackedSpanHelpers.FixUpPackedVector256Result(result); - - uint mask = result.ExtractMostSignificantBits(); - do + if (TryFindMatch(ref cur, PackedSpanHelpers.FixUpPackedVector256Result(result).ExtractMostSignificantBits(), values, out int index)) { - ref char candidatePos = ref Unsafe.Add(ref cur, BitOperations.TrailingZeroCount(mask)); - - if (Contains(values, candidatePos)) - { - return (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref candidatePos) / sizeof(char)); - } - - mask = BitOperations.ResetLowestSetBit(mask); + return MatchOffset(ref searchSpace, ref cur) + index; } - while (mask != 0); } cur = ref Unsafe.Add(ref cur, 32); @@ -416,19 +558,10 @@ private static int IndexOfAnyVectorized(ref uint charMap, ref char searchSpace, if (result != Vector128.Zero) { - uint mask = result.ExtractMostSignificantBits(); - do + if (TryFindMatch(ref cur, result.ExtractMostSignificantBits(), values, out int index)) { - ref char candidatePos = ref Unsafe.Add(ref cur, BitOperations.TrailingZeroCount(mask)); - - if (Contains(values, candidatePos)) - { - return (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref candidatePos) / sizeof(char)); - } - - mask = BitOperations.ResetLowestSetBit(mask); + return MatchOffset(ref searchSpace, ref cur) + index; } - while (mask != 0); } cur = ref Unsafe.Add(ref cur, 16); @@ -448,6 +581,104 @@ private static int IndexOfAnyVectorized(ref uint charMap, ref char searchSpace, return -1; } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int MatchOffset(ref char searchSpace, ref char cur) => + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref cur) / sizeof(char)); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool TryFindMatch(ref char cur, uint mask, ReadOnlySpan values, out int index) + { + do + { + index = BitOperations.TrailingZeroCount(mask); + + if (Contains(values, Unsafe.Add(ref cur, index))) + { + return true; + } + + mask = BitOperations.ResetLowestSetBit(mask); + } + while (mask != 0); + + index = 0; + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool TryFindMatchOverlapped(ref char cur, int searchSpaceLength, uint mask, ReadOnlySpan values, out int index) + { + do + { + index = BitOperations.TrailingZeroCount(mask); + + if (index >= Vector256.Count) + { + // The potential match is in the second vector. + // Fixup the index to account for how we loaded the second overlapped vector. + index += searchSpaceLength - (2 * Vector256.Count); + } + + if (Contains(values, Unsafe.Add(ref cur, index))) + { + return true; + } + + mask = BitOperations.ResetLowestSetBit(mask); + } + while (mask != 0); + + index = 0; + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool TryFindMatch(ref char cur, ulong mask, ReadOnlySpan values, out int index) + { + do + { + index = BitOperations.TrailingZeroCount(mask); + + if (Contains(values, Unsafe.Add(ref cur, index))) + { + return true; + } + + mask = BitOperations.ResetLowestSetBit(mask); + } + while (mask != 0); + + index = 0; + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool TryFindMatchOverlapped(ref char cur, int searchSpaceLength, ulong mask, ReadOnlySpan values, out int index) + { + do + { + index = BitOperations.TrailingZeroCount(mask); + + if (index >= Vector512.Count) + { + // The potential match is in the second vector. + // Fixup the index to account for how we loaded the second overlapped vector. + index += searchSpaceLength - (2 * Vector512.Count); + } + + if (Contains(values, Unsafe.Add(ref cur, index))) + { + return true; + } + + mask = BitOperations.ResetLowestSetBit(mask); + } + while (mask != 0); + + index = 0; + return false; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static int IndexOfAnySimpleLoop(ref char searchSpace, int searchSpaceLength, ReadOnlySpan values) where TNegator : struct, IndexOfAnyAsciiSearcher.INegator @@ -460,7 +691,7 @@ internal static int IndexOfAnySimpleLoop(ref char searchSpace, int sea char c = cur; if (TNegator.NegateIfNeeded(Contains(values, c))) { - return (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref cur) / sizeof(char)); + return MatchOffset(ref searchSpace, ref cur); } cur = ref Unsafe.Add(ref cur, 1); From d87fc68f97d006a03f175439d19d15c237cc3e48 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Sat, 13 Jan 2024 05:35:11 +0100 Subject: [PATCH 2/3] Add Vector512 condition --- .../src/System/SearchValues/ProbabilisticMap.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/SearchValues/ProbabilisticMap.cs b/src/libraries/System.Private.CoreLib/src/System/SearchValues/ProbabilisticMap.cs index 0a7b7900b35c61..7f9d287791606c 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SearchValues/ProbabilisticMap.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SearchValues/ProbabilisticMap.cs @@ -371,7 +371,7 @@ internal static int IndexOfAny(ref uint charMap, ref char searchSpace, int searc { if ((Sse41.IsSupported || AdvSimd.Arm64.IsSupported) && searchSpaceLength >= 16) { - return Avx512Vbmi.VL.IsSupported + return Vector512.IsHardwareAccelerated && Avx512Vbmi.VL.IsSupported ? IndexOfAnyVectorizedAvx512(ref charMap, ref searchSpace, searchSpaceLength, values) : IndexOfAnyVectorized(ref charMap, ref searchSpace, searchSpaceLength, values); } From 68a5b3f0d462e2baef1b770881c0fcd2f5f9fe35 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Sat, 13 Jan 2024 06:03:48 +0100 Subject: [PATCH 3/3] Fix assert --- .../src/System/SearchValues/ProbabilisticMap.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/SearchValues/ProbabilisticMap.cs b/src/libraries/System.Private.CoreLib/src/System/SearchValues/ProbabilisticMap.cs index 7f9d287791606c..150372914d8b2e 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SearchValues/ProbabilisticMap.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SearchValues/ProbabilisticMap.cs @@ -472,7 +472,7 @@ private static int IndexOfAnyVectorizedAvx512(ref uint charMap, ref char searchS } else { - Debug.Assert(searchSpaceLength is >= 16 and < 32); + Debug.Assert(searchSpaceLength is >= 16 and <= 32); // Process the first and last vector in the search space. // They may overlap, but we'll handle that in the index calculation if we do get a match.