diff --git a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs index 031ba109ba6b57..de78d145524b0b 100644 --- a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs @@ -35,6 +35,9 @@ public static void BitwiseOr(System.ReadOnlySpan x, System.ReadOnlySpan public static void BitwiseOr(System.ReadOnlySpan x, T y, System.Span destination) where T : System.Numerics.IBitwiseOperators { } public static void Cbrt(System.ReadOnlySpan x, System.Span destination) where T : System.Numerics.IRootFunctions { } public static void Ceiling(System.ReadOnlySpan x, System.Span destination) where T : System.Numerics.IFloatingPoint { } + public static void ConvertChecked(System.ReadOnlySpan source, System.Span destination) where TFrom : System.Numerics.INumberBase where TTo : System.Numerics.INumberBase { } + public static void ConvertSaturating(System.ReadOnlySpan source, System.Span destination) where TFrom : System.Numerics.INumberBase where TTo : System.Numerics.INumberBase { } + public static void ConvertTruncating(System.ReadOnlySpan source, System.Span destination) where TFrom : System.Numerics.INumberBase where TTo : System.Numerics.INumberBase { } public static void ConvertToHalf(System.ReadOnlySpan source, System.Span destination) { } public static void ConvertToSingle(System.ReadOnlySpan source, System.Span destination) { } public static void CopySign(System.ReadOnlySpan x, System.ReadOnlySpan sign, System.Span destination) where T : System.Numerics.INumber { } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.Single.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.Single.netcore.cs index 1148639d48be02..32723d2b3c3c5e 100644 --- 
a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.Single.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.Single.netcore.cs @@ -43,7 +43,7 @@ public static unsafe partial class TensorPrimitives { private static void InvokeSpanIntoSpan( ReadOnlySpan x, Span destination) - where TSingleUnaryOperator : struct, IUnaryOperator => + where TSingleUnaryOperator : struct, IUnaryOperator => InvokeSpanIntoSpan(x, destination); private static void InvokeSpanSpanIntoSpan( @@ -58,7 +58,7 @@ private static void InvokeSpanScalarIntoSpan( private static unsafe void InvokeSpanScalarIntoSpan( ReadOnlySpan x, float y, Span destination) - where TSingleTransformOperator : struct, IUnaryOperator + where TSingleTransformOperator : struct, IUnaryOperator where TSingleBinaryOperator : struct, IBinaryOperator => InvokeSpanScalarIntoSpan(x, y, destination); @@ -79,7 +79,7 @@ private static void InvokeSpanScalarSpanIntoSpan( private static unsafe float Aggregate( ReadOnlySpan x) - where TSingleTransformOperator : struct, IUnaryOperator + where TSingleTransformOperator : struct, IUnaryOperator where TSingleAggregationOperator : struct, IAggregationOperator => Aggregate(x); diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.T.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.T.cs index 4009257b22d36f..0562d74cc9dc49 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.T.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.T.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. 
+using System.Runtime.CompilerServices; + namespace System.Numerics.Tensors { /// Performs primitive tensor operations over spans of memory. @@ -488,6 +490,249 @@ public static void Ceiling(ReadOnlySpan x, Span destination) where T : IFloatingPoint => InvokeSpanIntoSpan>(x, destination); + /// + /// Copies to , converting each + /// value to a value. + /// + /// The source span from which to copy values. + /// The destination span into which the converted values should be written. + /// Destination is too short. + /// + /// + /// This method effectively computes [i] = TTo.CreateChecked([i]). + /// + /// + public static void ConvertChecked(ReadOnlySpan source, Span destination) + where TFrom : INumberBase + where TTo : INumberBase + { + if (!TryConvertUniversal(source, destination)) + { + InvokeSpanIntoSpan>(source, destination); + } + } + + /// + /// Copies to , converting each + /// value to a value. + /// + /// The source span from which to copy values. + /// The destination span into which the converted values should be written. + /// Destination is too short. + /// + /// + /// This method effectively computes [i] = TTo.CreateSaturating([i]). + /// + /// + public static void ConvertSaturating(ReadOnlySpan source, Span destination) + where TFrom : INumberBase + where TTo : INumberBase + { + if (!TryConvertUniversal(source, destination)) + { + InvokeSpanIntoSpan>(source, destination); + } + } + + /// + /// Copies to , converting each + /// value to a value. + /// + /// The source span from which to copy values. + /// The destination span into which the converted values should be written. + /// Destination is too short. + /// + /// + /// This method effectively computes [i] = TTo.CreateTruncating([i]). 
+ /// + /// + public static void ConvertTruncating(ReadOnlySpan source, Span destination) + where TFrom : INumberBase + where TTo : INumberBase + { + if (TryConvertUniversal(source, destination)) + { + return; + } + + if (((typeof(TFrom) == typeof(byte) || typeof(TFrom) == typeof(sbyte)) && (typeof(TTo) == typeof(byte) || typeof(TTo) == typeof(sbyte))) || + ((typeof(TFrom) == typeof(ushort) || typeof(TFrom) == typeof(short)) && (typeof(TTo) == typeof(ushort) || typeof(TTo) == typeof(short))) || + ((IsUInt32Like() || IsInt32Like()) && (IsUInt32Like() || IsInt32Like())) || + ((IsUInt64Like() || IsInt64Like()) && (IsUInt64Like() || IsInt64Like()))) + { + source.CopyTo(Rename(destination)); + return; + } + + if (typeof(TFrom) == typeof(float) && IsUInt32Like()) + { + InvokeSpanIntoSpan(Rename(source), Rename(destination)); + return; + } + + if (typeof(TFrom) == typeof(float) && IsInt32Like()) + { + InvokeSpanIntoSpan(Rename(source), Rename(destination)); + return; + } + + if (typeof(TFrom) == typeof(double) && IsUInt64Like()) + { + InvokeSpanIntoSpan(Rename(source), Rename(destination)); + return; + } + + if (typeof(TFrom) == typeof(double) && IsInt64Like()) + { + InvokeSpanIntoSpan(Rename(source), Rename(destination)); + return; + } + + if (typeof(TFrom) == typeof(ushort) && typeof(TTo) == typeof(byte)) + { + InvokeSpanIntoSpan_2to1(Rename(source), Rename(destination)); + return; + } + + if (typeof(TFrom) == typeof(short) && typeof(TTo) == typeof(sbyte)) + { + InvokeSpanIntoSpan_2to1(Rename(source), Rename(destination)); + return; + } + + if (IsUInt32Like() && typeof(TTo) == typeof(ushort)) + { + InvokeSpanIntoSpan_2to1(Rename(source), Rename(destination)); + return; + } + + if (IsInt32Like() && typeof(TTo) == typeof(short)) + { + InvokeSpanIntoSpan_2to1(Rename(source), Rename(destination)); + return; + } + + if (IsUInt64Like() && IsUInt32Like()) + { + InvokeSpanIntoSpan_2to1(Rename(source), Rename(destination)); + return; + } + + if (IsInt64Like() && IsInt32Like()) + 
{ + InvokeSpanIntoSpan_2to1(Rename(source), Rename(destination)); + return; + } + + InvokeSpanIntoSpan>(source, destination); + } + + /// Performs conversions that are the same regardless of checked, truncating, or saturation. + [MethodImpl(MethodImplOptions.AggressiveInlining)] // at most one of the branches will be kept + private static bool TryConvertUniversal(ReadOnlySpan source, Span destination) + where TFrom : INumberBase + where TTo : INumberBase + { + if (typeof(TFrom) == typeof(TTo)) + { + if (source.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(source, Rename(destination)); + + source.CopyTo(Rename(destination)); + return true; + } + + if (IsInt32Like() && typeof(TTo) == typeof(float)) + { + InvokeSpanIntoSpan(Rename(source), Rename(destination)); + return true; + } + + if (IsUInt32Like() && typeof(TTo) == typeof(float)) + { + InvokeSpanIntoSpan(Rename(source), Rename(destination)); + return true; + } + + if (IsInt64Like() && typeof(TTo) == typeof(double)) + { + InvokeSpanIntoSpan(Rename(source), Rename(destination)); + return true; + } + + if (IsUInt64Like() && typeof(TTo) == typeof(double)) + { + InvokeSpanIntoSpan(Rename(source), Rename(destination)); + return true; + } + + if (typeof(TFrom) == typeof(float) && typeof(TTo) == typeof(Half)) + { + ConvertToHalf(Rename(source), Rename(destination)); + return true; + } + + if (typeof(TFrom) == typeof(Half) && typeof(TTo) == typeof(float)) + { + ConvertToSingle(Rename(source), Rename(destination)); + return true; + } + + if (typeof(TFrom) == typeof(float) && typeof(TTo) == typeof(double)) + { + InvokeSpanIntoSpan_1to2(Rename(source), Rename(destination)); + return true; + } + + if (typeof(TFrom) == typeof(double) && typeof(TTo) == typeof(float)) + { + InvokeSpanIntoSpan_2to1(Rename(source), Rename(destination)); + return true; + } + + if (typeof(TFrom) == typeof(byte) && typeof(TTo) == typeof(ushort)) + { + 
InvokeSpanIntoSpan_1to2(Rename(source), Rename(destination)); + return true; + } + + if (typeof(TFrom) == typeof(sbyte) && typeof(TTo) == typeof(short)) + { + InvokeSpanIntoSpan_1to2(Rename(source), Rename(destination)); + return true; + } + + if (typeof(TFrom) == typeof(ushort) && IsUInt32Like()) + { + InvokeSpanIntoSpan_1to2(Rename(source), Rename(destination)); + return true; + } + + if (typeof(TFrom) == typeof(short) && IsInt32Like()) + { + InvokeSpanIntoSpan_1to2(Rename(source), Rename(destination)); + return true; + } + + if (IsUInt32Like() && IsUInt64Like()) + { + InvokeSpanIntoSpan_1to2(Rename(source), Rename(destination)); + return true; + } + + if (IsInt32Like() && IsInt64Like()) + { + InvokeSpanIntoSpan_1to2(Rename(source), Rename(destination)); + return true; + } + + return false; + } + /// Computes the element-wise result of copying the sign from one number to another number in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. @@ -963,15 +1208,14 @@ public static void Ieee754Remainder(T x, ReadOnlySpan y, Span destinati public static void ILogB(ReadOnlySpan x, Span destination) where T : IFloatingPointIeee754 { - if (x.Length > destination.Length) + if (typeof(T) == typeof(double)) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + // Special-case double as the only vectorizable floating-point type whose size != sizeof(int). 
+ InvokeSpanIntoSpan_2to1(Rename(x), destination); } - - // TODO: Vectorize - for (int i = 0; i < x.Length; i++) + else { - destination[i] = T.ILogB(x[i]); + InvokeSpanIntoSpan>(x, destination); } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.netcore.cs index bb9285b59d1e2e..fec346b381913f 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.netcore.cs @@ -800,7 +800,7 @@ private static T CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan y) w /// private static T Aggregate( ReadOnlySpan x) - where TTransformOperator : struct, IUnaryOperator + where TTransformOperator : struct, IUnaryOperator where TAggregationOperator : struct, IAggregationOperator { // Since every branch has a cost and since that cost is @@ -2242,7 +2242,7 @@ static T Vectorized512(ref T xRef, ref T yRef, nuint remainder) case 3: { Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), - Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); vresult = TAggregationOperator.Invoke(vresult, vector); goto case 2; } @@ -3021,32 +3021,49 @@ private static int IndexOfFirstMatch(Vector512 mask) => BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); /// Performs an element-wise operation on and writes the results to . - /// The element type. + /// The element input type. /// Specifies the operation to perform on each element loaded from . 
private static void InvokeSpanIntoSpan( ReadOnlySpan x, Span destination) - where TUnaryOperator : struct, IUnaryOperator + where TUnaryOperator : struct, IUnaryOperator => + InvokeSpanIntoSpan(x, destination); + + /// Performs an element-wise operation on and writes the results to . + /// The element input type. + /// The element output type. Must be the same size as TInput if TInput and TOutput both support vectorization. + /// Specifies the operation to perform on each element loaded from . + /// + /// This supports vectorizing the operation if and are the same size. + /// Otherwise, it'll fall back to scalar operations. + /// + private static void InvokeSpanIntoSpan( + ReadOnlySpan x, Span destination) + where TUnaryOperator : struct, IUnaryOperator { if (x.Length > destination.Length) { ThrowHelper.ThrowArgument_DestinationTooShort(); } - ValidateInputOutputSpanNonOverlapping(x, destination); + if (typeof(TInput) == typeof(TOutput)) + { + // This ignores the unsafe case where a developer passes in overlapping spans for distinct types. 
+ ValidateInputOutputSpanNonOverlapping(x, Rename(destination)); + } // Since every branch has a cost and since that cost is // essentially lost for larger inputs, we do branches // in a way that allows us to have the minimum possible // for small sizes - ref T xRef = ref MemoryMarshal.GetReference(x); - ref T dRef = ref MemoryMarshal.GetReference(destination); + ref TInput xRef = ref MemoryMarshal.GetReference(x); + ref TOutput dRef = ref MemoryMarshal.GetReference(destination); nuint remainder = (uint)x.Length; - if (Vector512.IsHardwareAccelerated && Vector512.IsSupported && TUnaryOperator.Vectorizable && Unsafe.SizeOf() >= 4) + if (Vector512.IsHardwareAccelerated && Vector512.IsSupported && Vector512.IsSupported && TUnaryOperator.Vectorizable && Unsafe.SizeOf() >= 4 && Unsafe.SizeOf() == Unsafe.SizeOf()) { - if (remainder >= (uint)Vector512.Count) + if (remainder >= (uint)Vector512.Count) { Vectorized512(ref xRef, ref dRef, remainder); } @@ -3062,9 +3079,9 @@ private static void InvokeSpanIntoSpan( return; } - if (Vector256.IsHardwareAccelerated && Vector256.IsSupported && TUnaryOperator.Vectorizable && Unsafe.SizeOf() >= 4) + if (Vector256.IsHardwareAccelerated && Vector256.IsSupported && Vector256.IsSupported && TUnaryOperator.Vectorizable && Unsafe.SizeOf() >= 4 && Unsafe.SizeOf() == Unsafe.SizeOf()) { - if (remainder >= (uint)Vector256.Count) + if (remainder >= (uint)Vector256.Count) { Vectorized256(ref xRef, ref dRef, remainder); } @@ -3080,9 +3097,9 @@ private static void InvokeSpanIntoSpan( return; } - if (Vector128.IsHardwareAccelerated && Vector128.IsSupported && TUnaryOperator.Vectorizable && Unsafe.SizeOf() >= 4) + if (Vector128.IsHardwareAccelerated && Vector128.IsSupported && Vector128.IsSupported && TUnaryOperator.Vectorizable && Unsafe.SizeOf() >= 4 && Unsafe.SizeOf() == Unsafe.SizeOf()) { - if (remainder >= (uint)Vector128.Count) + if (remainder >= (uint)Vector128.Count) { Vectorized128(ref xRef, ref dRef, remainder); } @@ -3104,7 +3121,7 @@ 
private static void InvokeSpanIntoSpan( SoftwareFallback(ref xRef, ref dRef, remainder); [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void SoftwareFallback(ref T xRef, ref T dRef, nuint length) + static void SoftwareFallback(ref TInput xRef, ref TOutput dRef, nuint length) { for (nuint i = 0; i < length; i++) { @@ -3112,31 +3129,31 @@ static void SoftwareFallback(ref T xRef, ref T dRef, nuint length) } } - static void Vectorized128(ref T xRef, ref T dRef, nuint remainder) + static void Vectorized128(ref TInput xRef, ref TOutput dRef, nuint remainder) { - ref T dRefBeg = ref dRef; + ref TOutput dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); - Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)); + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)); - if (remainder > (uint)(Vector128.Count * 8)) + if (remainder > (uint)(Vector128.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (T* px = &xRef) - fixed (T* pd = &dRef) + fixed (TInput* px = &xRef) + fixed (TOutput* pd = &dRef) { - T* xPtr = px; - T* dPtr = pd; + TInput* xPtr = px; + TOutput* dPtr = pd; // We need to the ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. 
- bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(TInput)) == 0; if (canAlign) { @@ -3146,96 +3163,96 @@ static void Vectorized128(ref T xRef, ref T dRef, nuint remainder) // are more expensive than unaligned loads and aligning both is significantly more // complex. - nuint misalignment = ((uint)sizeof(Vector128) - ((nuint)dPtr % (uint)sizeof(Vector128))) / (uint)sizeof(T); + nuint misalignment = ((uint)sizeof(Vector128) - ((nuint)dPtr % (uint)sizeof(Vector128))) / (uint)sizeof(TInput); xPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector128)) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector128)) == 0); remainder -= misalignment; } - Vector128 vector1; - Vector128 vector2; - Vector128 vector3; - Vector128 vector4; + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; - if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(TInput))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. 
- while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); - vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); - vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); - vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); // We load, process, and store the next four vectors - vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); - vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); - vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); - vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + 
vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } } else { - while (remainder >= (uint)(Vector128.Count * 8)) + while (remainder >= (uint)(Vector128.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); - vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); - vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); - vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); - vector1.Store(dPtr + 
(uint)(Vector128.Count * 0)); - vector2.Store(dPtr + (uint)(Vector128.Count * 1)); - vector3.Store(dPtr + (uint)(Vector128.Count * 2)); - vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); // We load, process, and store the next four vectors - vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); - vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); - vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); - vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); - vector1.Store(dPtr + (uint)(Vector128.Count * 4)); - vector2.Store(dPtr + (uint)(Vector128.Count * 5)); - vector3.Store(dPtr + (uint)(Vector128.Count * 6)); - vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. 
- xPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); - remainder -= (uint)(Vector128.Count * 8); + remainder -= (uint)(Vector128.Count * 8); } } @@ -3254,63 +3271,63 @@ static void Vectorized128(ref T xRef, ref T dRef, nuint remainder) // data before the first aligned address. nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); - switch (remainder / (uint)Vector128.Count) + switch (remainder / (uint)Vector128.Count) { case 8: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); goto case 7; } case 7: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); goto case 6; } case 6: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); goto case 5; } case 5: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - 
(uint)(Vector128.Count * 5))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); goto case 4; } case 4: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); goto case 3; } case 3: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); goto case 2; } case 2: { - Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); goto case 0; } @@ -3323,31 +3340,31 @@ static void Vectorized128(ref T xRef, ref T dRef, nuint remainder) } } - static void Vectorized256(ref T xRef, ref T dRef, nuint remainder) + static void Vectorized256(ref TInput xRef, ref TOutput dRef, nuint 
remainder) { - ref T dRefBeg = ref dRef; + ref TOutput dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); - Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count)); + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count)); - if (remainder > (uint)(Vector256.Count * 8)) + if (remainder > (uint)(Vector256.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (T* px = &xRef) - fixed (T* pd = &dRef) + fixed (TInput* px = &xRef) + fixed (TOutput* pd = &dRef) { - T* xPtr = px; - T* dPtr = pd; + TInput* xPtr = px; + TOutput* dPtr = pd; // We need to the ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(TInput)) == 0; if (canAlign) { @@ -3357,96 +3374,96 @@ static void Vectorized256(ref T xRef, ref T dRef, nuint remainder) // are more expensive than unaligned loads and aligning both is significantly more // complex. 
- nuint misalignment = ((uint)sizeof(Vector256) - ((nuint)dPtr % (uint)sizeof(Vector256))) / (uint)sizeof(T); + nuint misalignment = ((uint)sizeof(Vector256) - ((nuint)dPtr % (uint)sizeof(Vector256))) / (uint)sizeof(TInput); xPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector256)) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector256)) == 0); remainder -= misalignment; } - Vector256 vector1; - Vector256 vector2; - Vector256 vector3; - Vector256 vector4; + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; - if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(TInput))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. - while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); - vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); - vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); - vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); - 
vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); // We load, process, and store the next four vectors - vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); - vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); - vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); - vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. 
- xPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } } else { - while (remainder >= (uint)(Vector256.Count * 8)) + while (remainder >= (uint)(Vector256.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); - vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); - vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); - vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); - vector1.Store(dPtr + (uint)(Vector256.Count * 0)); - vector2.Store(dPtr + (uint)(Vector256.Count * 1)); - vector3.Store(dPtr + (uint)(Vector256.Count * 2)); - vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); // We load, process, and store the next four vectors - vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); - vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); - vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); - vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + 
(uint)(Vector256.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); - vector1.Store(dPtr + (uint)(Vector256.Count * 4)); - vector2.Store(dPtr + (uint)(Vector256.Count * 5)); - vector3.Store(dPtr + (uint)(Vector256.Count * 6)); - vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); - remainder -= (uint)(Vector256.Count * 8); + remainder -= (uint)(Vector256.Count * 8); } } @@ -3465,63 +3482,63 @@ static void Vectorized256(ref T xRef, ref T dRef, nuint remainder) // data before the first aligned address. 
nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); - switch (remainder / (uint)Vector256.Count) + switch (remainder / (uint)Vector256.Count) { case 8: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); goto case 7; } case 7: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); goto case 6; } case 6: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); goto case 5; } case 5: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); goto case 4; } case 4: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, 
remainder - (uint)(Vector256.Count * 4))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); goto case 3; } case 3: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); goto case 2; } case 2: { - Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); goto case 0; } @@ -3534,31 +3551,31 @@ static void Vectorized256(ref T xRef, ref T dRef, nuint remainder) } } - static void Vectorized512(ref T xRef, ref T dRef, nuint remainder) + static void Vectorized512(ref TInput xRef, ref TOutput dRef, nuint remainder) { - ref T dRefBeg = ref dRef; + ref TOutput dRefBeg = ref dRef; // Preload the beginning and end so that overlapping accesses don't negatively impact the data - Vector512 beg = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef)); - Vector512 end = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)Vector512.Count)); + Vector512 beg = 
TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef)); + Vector512 end = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)Vector512.Count)); - if (remainder > (uint)(Vector512.Count * 8)) + if (remainder > (uint)(Vector512.Count * 8)) { // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - fixed (T* px = &xRef) - fixed (T* pd = &dRef) + fixed (TInput* px = &xRef) + fixed (TOutput* pd = &dRef) { - T* xPtr = px; - T* dPtr = pd; + TInput* xPtr = px; + TOutput* dPtr = pd; // We need to the ensure the underlying data can be aligned and only align // it if it can. It is possible we have an unaligned ref, in which case we // can never achieve the required SIMD alignment. - bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; + bool canAlign = ((nuint)dPtr % (nuint)sizeof(TInput)) == 0; if (canAlign) { @@ -3568,96 +3585,96 @@ static void Vectorized512(ref T xRef, ref T dRef, nuint remainder) // are more expensive than unaligned loads and aligning both is significantly more // complex. 
- nuint misalignment = ((uint)sizeof(Vector512) - ((nuint)dPtr % (uint)sizeof(Vector512))) / (uint)sizeof(T); + nuint misalignment = ((uint)sizeof(Vector512) - ((nuint)dPtr % (uint)sizeof(Vector512))) / (uint)sizeof(TInput); xPtr += misalignment; dPtr += misalignment; - Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector512)) == 0); + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector512)) == 0); remainder -= misalignment; } - Vector512 vector1; - Vector512 vector2; - Vector512 vector3; - Vector512 vector4; + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; - if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(T))) && canAlign) + if ((remainder > (NonTemporalByteThreshold / (nuint)sizeof(TInput))) && canAlign) { // This loop stores the data non-temporally, which benefits us when there // is a large amount of data involved as it avoids polluting the cache. - while (remainder >= (uint)(Vector512.Count * 8)) + while (remainder >= (uint)(Vector512.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); - vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); - vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); - vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); - 
vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); // We load, process, and store the next four vectors - vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); - vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); - vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); - vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); - vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); - vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); - vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); - vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. 
- xPtr += (uint)(Vector512.Count * 8); - dPtr += (uint)(Vector512.Count * 8); + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); - remainder -= (uint)(Vector512.Count * 8); + remainder -= (uint)(Vector512.Count * 8); } } else { - while (remainder >= (uint)(Vector512.Count * 8)) + while (remainder >= (uint)(Vector512.Count * 8)) { // We load, process, and store the first four vectors - vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); - vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); - vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); - vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); - vector1.Store(dPtr + (uint)(Vector512.Count * 0)); - vector2.Store(dPtr + (uint)(Vector512.Count * 1)); - vector3.Store(dPtr + (uint)(Vector512.Count * 2)); - vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); // We load, process, and store the next four vectors - vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); - vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); - vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); - vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + 
(uint)(Vector512.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); - vector1.Store(dPtr + (uint)(Vector512.Count * 4)); - vector2.Store(dPtr + (uint)(Vector512.Count * 5)); - vector3.Store(dPtr + (uint)(Vector512.Count * 6)); - vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); // We adjust the source and destination references, then update // the count of remaining elements to process. - xPtr += (uint)(Vector512.Count * 8); - dPtr += (uint)(Vector512.Count * 8); + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); - remainder -= (uint)(Vector512.Count * 8); + remainder -= (uint)(Vector512.Count * 8); } } @@ -3676,63 +3693,63 @@ static void Vectorized512(ref T xRef, ref T dRef, nuint remainder) // data before the first aligned address. 
nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); - switch (remainder / (uint)Vector512.Count) + switch (remainder / (uint)Vector512.Count) { case 8: { - Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); goto case 7; } case 7: { - Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); goto case 6; } case 6: { - Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); goto case 5; } case 5: { - Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); goto case 4; } case 4: { - Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, 
remainder - (uint)(Vector512.Count * 4))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); goto case 3; } case 3: { - Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); goto case 2; } case 2: { - Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))); - vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); goto case 1; } case 1: { // Store the last block, which includes any elements that wouldn't fill a full vector - end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); goto case 0; } @@ -3746,23 +3763,23 @@ static void Vectorized512(ref T xRef, ref T dRef, nuint remainder) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void VectorizedSmall(ref T xRef, ref T dRef, nuint remainder) + static void VectorizedSmall(ref TInput xRef, ref TOutput dRef, nuint remainder) { - if (sizeof(T) == 4) + if (sizeof(TInput) == 4) { VectorizedSmall4(ref xRef, ref dRef, remainder); } else { - Debug.Assert(sizeof(T) == 8); + Debug.Assert(sizeof(TInput) == 8); VectorizedSmall8(ref xRef, ref dRef, remainder); } } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void VectorizedSmall4(ref T xRef, ref T 
dRef, nuint remainder) + static void VectorizedSmall4(ref TInput xRef, ref TOutput dRef, nuint remainder) { - Debug.Assert(sizeof(T) == 4); + Debug.Assert(sizeof(TInput) == 4); switch (remainder) { @@ -3776,11 +3793,11 @@ static void VectorizedSmall4(ref T xRef, ref T dRef, nuint remainder) { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); - Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count)); + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count)); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); break; } @@ -3789,7 +3806,7 @@ static void VectorizedSmall4(ref T xRef, ref T dRef, nuint remainder) { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); beg.StoreUnsafe(ref dRef); break; @@ -3801,11 +3818,11 @@ static void VectorizedSmall4(ref T xRef, ref T dRef, nuint remainder) { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); - Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)); + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); break; } @@ -3814,7 +3831,7 @@ static void VectorizedSmall4(ref T xRef, ref T dRef, nuint remainder) { Debug.Assert(Vector128.IsHardwareAccelerated); - 
Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); beg.StoreUnsafe(ref dRef); break; @@ -3846,9 +3863,9 @@ static void VectorizedSmall4(ref T xRef, ref T dRef, nuint remainder) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void VectorizedSmall8(ref T xRef, ref T dRef, nuint remainder) + static void VectorizedSmall8(ref TInput xRef, ref TOutput dRef, nuint remainder) { - Debug.Assert(sizeof(T) == 8); + Debug.Assert(sizeof(TInput) == 8); switch (remainder) { @@ -3858,11 +3875,11 @@ static void VectorizedSmall8(ref T xRef, ref T dRef, nuint remainder) { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); - Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count)); + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count)); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); break; } @@ -3871,7 +3888,7 @@ static void VectorizedSmall8(ref T xRef, ref T dRef, nuint remainder) { Debug.Assert(Vector256.IsHardwareAccelerated); - Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); beg.StoreUnsafe(ref dRef); break; @@ -3881,11 +3898,11 @@ static void VectorizedSmall8(ref T xRef, ref T dRef, nuint remainder) { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); - Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)); + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = 
TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count)); beg.StoreUnsafe(ref dRef); - end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); + end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); break; } @@ -3894,7 +3911,7 @@ static void VectorizedSmall8(ref T xRef, ref T dRef, nuint remainder) { Debug.Assert(Vector128.IsHardwareAccelerated); - Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); beg.StoreUnsafe(ref dRef); break; @@ -3902,16 +3919,268 @@ static void VectorizedSmall8(ref T xRef, ref T dRef, nuint remainder) case 1: { - dRef = TUnaryOperator.Invoke(xRef); - goto case 0; + dRef = TUnaryOperator.Invoke(xRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// Performs an element-wise operation on and writes the results to . + /// The element input type. + /// The element output type. Must be the same size as TInput if TInput and TOutput both support vectorization. + /// Specifies the operation to perform on each element loaded from . + /// This should only be used when it's known that TInput/TOutput are vectorizable and the size of TInput is twice that of TOutput. 
+ private static void InvokeSpanIntoSpan_2to1( + ReadOnlySpan x, Span destination) + where TUnaryOperator : struct, IUnaryTwoToOneOperator + { + Debug.Assert(sizeof(TInput) == sizeof(TOutput) * 2); + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref TInput xRef = ref MemoryMarshal.GetReference(x); + ref TOutput destinationRef = ref MemoryMarshal.GetReference(destination); + int i = 0, twoVectorsFromEnd; + + if (Vector512.IsHardwareAccelerated && TUnaryOperator.Vectorizable) + { + Debug.Assert(Vector512.IsSupported); + Debug.Assert(Vector512.IsSupported); + + twoVectorsFromEnd = x.Length - (Vector512.Count * 2); + if (i <= twoVectorsFromEnd) + { + // Loop handling two input vectors / one output vector at a time. + do + { + TUnaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, (uint)i), + Vector512.LoadUnsafe(ref xRef, (uint)(i + Vector512.Count))).StoreUnsafe(ref destinationRef, (uint)i); + + i += Vector512.Count * 2; + } + while (i <= twoVectorsFromEnd); + + // Handle any remaining elements with final vectors. + if (i != x.Length) + { + i = x.Length - (Vector512.Count * 2); + + TUnaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, (uint)i), + Vector512.LoadUnsafe(ref xRef, (uint)(i + Vector512.Count))).StoreUnsafe(ref destinationRef, (uint)i); + } + + return; + } + } + + if (Vector256.IsHardwareAccelerated && TUnaryOperator.Vectorizable) + { + Debug.Assert(Vector256.IsSupported); + Debug.Assert(Vector256.IsSupported); + + twoVectorsFromEnd = x.Length - (Vector256.Count * 2); + if (i <= twoVectorsFromEnd) + { + // Loop handling two input vectors / one output vector at a time. + do + { + TUnaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, (uint)i), + Vector256.LoadUnsafe(ref xRef, (uint)(i + Vector256.Count))).StoreUnsafe(ref destinationRef, (uint)i); + + i += Vector256.Count * 2; + } + while (i <= twoVectorsFromEnd); + + // Handle any remaining elements with final vectors. 
+ if (i != x.Length) + { + i = x.Length - (Vector256.Count * 2); + + TUnaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, (uint)i), + Vector256.LoadUnsafe(ref xRef, (uint)(i + Vector256.Count))).StoreUnsafe(ref destinationRef, (uint)i); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated && TUnaryOperator.Vectorizable) + { + Debug.Assert(Vector128.IsSupported); + Debug.Assert(Vector128.IsSupported); + + twoVectorsFromEnd = x.Length - (Vector128.Count * 2); + if (i <= twoVectorsFromEnd) + { + // Loop handling two input vectors / one output vector at a time. + do + { + TUnaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, (uint)i), + Vector128.LoadUnsafe(ref xRef, (uint)(i + Vector128.Count))).StoreUnsafe(ref destinationRef, (uint)i); + + i += Vector128.Count * 2; + } + while (i <= twoVectorsFromEnd); + + // Handle any remaining elements with final vectors. + if (i != x.Length) + { + i = x.Length - (Vector128.Count * 2); + + TUnaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, (uint)i), + Vector128.LoadUnsafe(ref xRef, (uint)(i + Vector128.Count))).StoreUnsafe(ref destinationRef, (uint)i); + } + + return; + } + } + + while (i < x.Length) + { + Unsafe.Add(ref destinationRef, i) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, i)); + i++; + } + } + + /// Performs an element-wise operation on and writes the results to . + /// The element input type. + /// The element output type. Must be the same size as TInput if TInput and TOutput both support vectorization. + /// Specifies the operation to perform on each element loaded from . + /// This should only be used when it's known that TInput/TOutput are vectorizable and the size of TInput is half that of TOutput. 
+ private static void InvokeSpanIntoSpan_1to2( + ReadOnlySpan x, Span destination) + where TUnaryOperator : struct, IUnaryOneToTwoOperator + { + Debug.Assert(sizeof(TInput) * 2 == sizeof(TOutput)); + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref TInput sourceRef = ref MemoryMarshal.GetReference(x); + ref TOutput destinationRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + + if (Vector512.IsHardwareAccelerated && TUnaryOperator.Vectorizable) + { + Debug.Assert(Vector512.IsSupported); + Debug.Assert(Vector512.IsSupported); + + oneVectorFromEnd = x.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one input vector / two output vectors at a time. + do + { + (Vector512 lower, Vector512 upper) = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + lower.StoreUnsafe(ref destinationRef, (uint)i); + upper.StoreUnsafe(ref destinationRef, (uint)(i + Vector512.Count)); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final input vector. + if (i != x.Length) + { + i = x.Length - Vector512.Count; + + (Vector512 lower, Vector512 upper) = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + lower.StoreUnsafe(ref destinationRef, (uint)i); + upper.StoreUnsafe(ref destinationRef, (uint)(i + Vector512.Count)); + } + + return; + } + } + + if (Vector256.IsHardwareAccelerated && TUnaryOperator.Vectorizable) + { + Debug.Assert(Vector256.IsSupported); + Debug.Assert(Vector256.IsSupported); + + oneVectorFromEnd = x.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one input vector / two output vectors at a time. 
+ do + { + (Vector256 lower, Vector256 upper) = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + lower.StoreUnsafe(ref destinationRef, (uint)i); + upper.StoreUnsafe(ref destinationRef, (uint)(i + Vector256.Count)); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final input vector. + if (i != x.Length) + { + i = x.Length - Vector256.Count; + + (Vector256 lower, Vector256 upper) = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + lower.StoreUnsafe(ref destinationRef, (uint)i); + upper.StoreUnsafe(ref destinationRef, (uint)(i + Vector256.Count)); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated && TUnaryOperator.Vectorizable) + { + Debug.Assert(Vector128.IsSupported); + Debug.Assert(Vector128.IsSupported); + + oneVectorFromEnd = x.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one input vector / two output vectors at a time. + do + { + (Vector128 lower, Vector128 upper) = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + lower.StoreUnsafe(ref destinationRef, (uint)i); + upper.StoreUnsafe(ref destinationRef, (uint)(i + Vector128.Count)); + + i += Vector128.Count; } + while (i <= oneVectorFromEnd); - case 0: + // Handle any remaining elements with a final input vector. 
+ if (i != x.Length) { - break; + i = x.Length - Vector128.Count; + + (Vector128 lower, Vector128 upper) = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + lower.StoreUnsafe(ref destinationRef, (uint)i); + upper.StoreUnsafe(ref destinationRef, (uint)(i + Vector128.Count)); } + + return; } } + + while (i < x.Length) + { + Unsafe.Add(ref destinationRef, i) = TUnaryOperator.Invoke(Unsafe.Add(ref sourceRef, i)); + i++; + } } /// @@ -4968,7 +5237,7 @@ private static void InvokeSpanScalarIntoSpan( /// private static void InvokeSpanScalarIntoSpan( ReadOnlySpan x, T y, Span destination) - where TTransformOperator : struct, IUnaryOperator + where TTransformOperator : struct, IUnaryOperator where TBinaryOperator : struct, IBinaryOperator { if (x.Length > destination.Length) @@ -5866,7 +6135,7 @@ static void VectorizedSmall4(ref T xRef, T y, ref T dRef, nuint remainder) case 2: { Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), - y); + y); goto case 1; } @@ -7037,8 +7306,8 @@ static void VectorizedSmall8(ref T xRef, ref T yRef, ref T zRef, ref T dRef, nui Debug.Assert(Vector256.IsHardwareAccelerated); Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.LoadUnsafe(ref yRef), - Vector256.LoadUnsafe(ref zRef)); + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); beg.StoreUnsafe(ref dRef); break; @@ -7049,11 +7318,11 @@ static void VectorizedSmall8(ref T xRef, ref T yRef, ref T zRef, ref T dRef, nui Debug.Assert(Vector128.IsHardwareAccelerated); Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef), - Vector128.LoadUnsafe(ref zRef)); + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count), - Vector128.LoadUnsafe(ref zRef, 
remainder - (uint)Vector128.Count)); + Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)Vector128.Count)); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); @@ -7066,8 +7335,8 @@ static void VectorizedSmall8(ref T xRef, ref T yRef, ref T zRef, ref T dRef, nui Debug.Assert(Vector128.IsHardwareAccelerated); Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef), - Vector128.LoadUnsafe(ref zRef)); + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); beg.StoreUnsafe(ref dRef); break; @@ -7387,8 +7656,8 @@ static void Vectorized128(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 8: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), - zVec); + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); goto case 7; } @@ -7396,8 +7665,8 @@ static void Vectorized128(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 7: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), - zVec); + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); goto case 6; } @@ -7405,8 +7674,8 @@ static void Vectorized128(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 6: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), - zVec); + Vector128.LoadUnsafe(ref yRef, remainder - 
(uint)(Vector128.Count * 6)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); goto case 5; } @@ -7414,8 +7683,8 @@ static void Vectorized128(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 5: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), - zVec); + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); goto case 4; } @@ -7423,8 +7692,8 @@ static void Vectorized128(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 4: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), - zVec); + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); goto case 3; } @@ -7432,8 +7701,8 @@ static void Vectorized128(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 3: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), - zVec); + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); goto case 2; } @@ -7441,8 +7710,8 @@ static void Vectorized128(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 2: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), - zVec); + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + zVec); vector.StoreUnsafe(ref dRef, 
remainder - (uint)(Vector128.Count * 2)); goto case 1; } @@ -7656,8 +7925,8 @@ static void Vectorized256(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 8: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), - zVec); + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); goto case 7; } @@ -7665,8 +7934,8 @@ static void Vectorized256(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 7: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), - zVec); + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); goto case 6; } @@ -7674,8 +7943,8 @@ static void Vectorized256(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 6: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), - zVec); + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); goto case 5; } @@ -7683,8 +7952,8 @@ static void Vectorized256(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 5: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), - zVec); + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); goto case 4; } @@ -7692,8 
+7961,8 @@ static void Vectorized256(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 4: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), - zVec); + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); goto case 3; } @@ -7701,8 +7970,8 @@ static void Vectorized256(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 3: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), - zVec); + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); goto case 2; } @@ -7710,8 +7979,8 @@ static void Vectorized256(ref T xRef, ref T yRef, T z, ref T dRef, nuint remaind case 2: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), - zVec); + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + zVec); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); goto case 1; } @@ -8035,11 +8304,11 @@ static void VectorizedSmall4(ref T xRef, ref T yRef, T z, ref T dRef, nuint rema Vector256 zVec = Vector256.Create(z); Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.LoadUnsafe(ref yRef), - zVec); + Vector256.LoadUnsafe(ref yRef), + zVec); Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256.Count), - zVec); + Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256.Count), + 
zVec); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); @@ -8052,8 +8321,8 @@ static void VectorizedSmall4(ref T xRef, ref T yRef, T z, ref T dRef, nuint rema Debug.Assert(Vector256.IsHardwareAccelerated); Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.LoadUnsafe(ref yRef), - Vector256.Create(z)); + Vector256.LoadUnsafe(ref yRef), + Vector256.Create(z)); beg.StoreUnsafe(ref dRef); break; @@ -8068,11 +8337,11 @@ static void VectorizedSmall4(ref T xRef, ref T yRef, T z, ref T dRef, nuint rema Vector128 zVec = Vector128.Create(z); Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef), - zVec); + Vector128.LoadUnsafe(ref yRef), + zVec); Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count), - zVec); + Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count), + zVec); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); @@ -8085,8 +8354,8 @@ static void VectorizedSmall4(ref T xRef, ref T yRef, T z, ref T dRef, nuint rema Debug.Assert(Vector128.IsHardwareAccelerated); Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef), - Vector128.Create(z)); + Vector128.LoadUnsafe(ref yRef), + Vector128.Create(z)); beg.StoreUnsafe(ref dRef); break; @@ -8137,11 +8406,11 @@ static void VectorizedSmall8(ref T xRef, ref T yRef, T z, ref T dRef, nuint rema Vector256 zVec = Vector256.Create(z); Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.LoadUnsafe(ref yRef), - zVec); + Vector256.LoadUnsafe(ref yRef), + zVec); Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), - Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256.Count), - zVec); + 
Vector256.LoadUnsafe(ref yRef, remainder - (uint)Vector256.Count), + zVec); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); @@ -8154,8 +8423,8 @@ static void VectorizedSmall8(ref T xRef, ref T yRef, T z, ref T dRef, nuint rema Debug.Assert(Vector256.IsHardwareAccelerated); Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.LoadUnsafe(ref yRef), - Vector256.Create(z)); + Vector256.LoadUnsafe(ref yRef), + Vector256.Create(z)); beg.StoreUnsafe(ref dRef); break; @@ -8168,11 +8437,11 @@ static void VectorizedSmall8(ref T xRef, ref T yRef, T z, ref T dRef, nuint rema Vector128 zVec = Vector128.Create(z); Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef), - zVec); + Vector128.LoadUnsafe(ref yRef), + zVec); Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), - Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count), - zVec); + Vector128.LoadUnsafe(ref yRef, remainder - (uint)Vector128.Count), + zVec); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); @@ -8185,8 +8454,8 @@ static void VectorizedSmall8(ref T xRef, ref T yRef, T z, ref T dRef, nuint rema Debug.Assert(Vector128.IsHardwareAccelerated); Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.LoadUnsafe(ref yRef), - Vector128.Create(z)); + Vector128.LoadUnsafe(ref yRef), + Vector128.Create(z)); beg.StoreUnsafe(ref dRef); break; @@ -8506,8 +8775,8 @@ static void Vectorized128(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 8: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), - yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); vector.StoreUnsafe(ref dRef, remainder - 
(uint)(Vector128.Count * 8)); goto case 7; } @@ -8515,8 +8784,8 @@ static void Vectorized128(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 7: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), - yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); goto case 6; } @@ -8524,8 +8793,8 @@ static void Vectorized128(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 6: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), - yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); goto case 5; } @@ -8533,8 +8802,8 @@ static void Vectorized128(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 5: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), - yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); goto case 4; } @@ -8542,8 +8811,8 @@ static void Vectorized128(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 4: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), - yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); goto case 3; } @@ -8551,8 +8820,8 @@ 
static void Vectorized128(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 3: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), - yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); goto case 2; } @@ -8560,8 +8829,8 @@ static void Vectorized128(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 2: { Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), - yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); goto case 1; } @@ -8775,8 +9044,8 @@ static void Vectorized256(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 8: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), - yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); goto case 7; } @@ -8784,8 +9053,8 @@ static void Vectorized256(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 7: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), - yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); goto case 6; } @@ -8793,8 +9062,8 @@ static void Vectorized256(ref T xRef, T y, ref T zRef, ref T dRef, 
nuint remaind case 6: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), - yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); goto case 5; } @@ -8802,8 +9071,8 @@ static void Vectorized256(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 5: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), - yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); goto case 4; } @@ -8811,8 +9080,8 @@ static void Vectorized256(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 4: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), - yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); goto case 3; } @@ -8820,8 +9089,8 @@ static void Vectorized256(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 3: { Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), - yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); goto case 2; } @@ -8829,8 +9098,8 @@ static void Vectorized256(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 2: { Vector256 vector = 
TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), - yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); goto case 1; } @@ -9044,8 +9313,8 @@ static void Vectorized512(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 8: { Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), - yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); goto case 7; } @@ -9053,8 +9322,8 @@ static void Vectorized512(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 7: { Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), - yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); goto case 6; } @@ -9062,8 +9331,8 @@ static void Vectorized512(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 6: { Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), - yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); goto case 5; } @@ -9071,8 +9340,8 @@ static void Vectorized512(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 5: { Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - 
(uint)(Vector512.Count * 5)), - yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); goto case 4; } @@ -9080,8 +9349,8 @@ static void Vectorized512(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 4: { Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), - yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); goto case 3; } @@ -9089,8 +9358,8 @@ static void Vectorized512(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 3: { Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), - yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); goto case 2; } @@ -9098,8 +9367,8 @@ static void Vectorized512(ref T xRef, T y, ref T zRef, ref T dRef, nuint remaind case 2: { Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), - yVec, - Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); goto case 1; } @@ -9154,11 +9423,11 @@ static void VectorizedSmall4(ref T xRef, T y, ref T zRef, ref T dRef, nuint rema Vector256 yVec = Vector256.Create(y); Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - yVec, - Vector256.LoadUnsafe(ref zRef)); + yVec, 
+ Vector256.LoadUnsafe(ref zRef)); Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), - yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)Vector256.Count)); + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)Vector256.Count)); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); @@ -9171,8 +9440,8 @@ static void VectorizedSmall4(ref T xRef, T y, ref T zRef, ref T dRef, nuint rema Debug.Assert(Vector256.IsHardwareAccelerated); Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.Create(y), - Vector256.LoadUnsafe(ref zRef)); + Vector256.Create(y), + Vector256.LoadUnsafe(ref zRef)); beg.StoreUnsafe(ref dRef); break; @@ -9187,11 +9456,11 @@ static void VectorizedSmall4(ref T xRef, T y, ref T zRef, ref T dRef, nuint rema Vector128 yVec = Vector128.Create(y); Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - yVec, - Vector128.LoadUnsafe(ref zRef)); + yVec, + Vector128.LoadUnsafe(ref zRef)); Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), - yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)Vector128.Count)); + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)Vector128.Count)); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); @@ -9204,8 +9473,8 @@ static void VectorizedSmall4(ref T xRef, T y, ref T zRef, ref T dRef, nuint rema Debug.Assert(Vector128.IsHardwareAccelerated); Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.Create(y), - Vector128.LoadUnsafe(ref zRef)); + Vector128.Create(y), + Vector128.LoadUnsafe(ref zRef)); beg.StoreUnsafe(ref dRef); break; @@ -9256,11 +9525,11 @@ static void VectorizedSmall8(ref T xRef, T y, ref T zRef, ref T dRef, nuint rema Vector256 yVec = Vector256.Create(y); Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), 
- yVec, - Vector256.LoadUnsafe(ref zRef)); + yVec, + Vector256.LoadUnsafe(ref zRef)); Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)Vector256.Count), - yVec, - Vector256.LoadUnsafe(ref zRef, remainder - (uint)Vector256.Count)); + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)Vector256.Count)); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector256.Count); @@ -9273,8 +9542,8 @@ static void VectorizedSmall8(ref T xRef, T y, ref T zRef, ref T dRef, nuint rema Debug.Assert(Vector256.IsHardwareAccelerated); Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), - Vector256.Create(y), - Vector256.LoadUnsafe(ref zRef)); + Vector256.Create(y), + Vector256.LoadUnsafe(ref zRef)); beg.StoreUnsafe(ref dRef); break; @@ -9287,11 +9556,11 @@ static void VectorizedSmall8(ref T xRef, T y, ref T zRef, ref T dRef, nuint rema Vector128 yVec = Vector128.Create(y); Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - yVec, - Vector128.LoadUnsafe(ref zRef)); + yVec, + Vector128.LoadUnsafe(ref zRef)); Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)Vector128.Count), - yVec, - Vector128.LoadUnsafe(ref zRef, remainder - (uint)Vector128.Count)); + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)Vector128.Count)); beg.StoreUnsafe(ref dRef); end.StoreUnsafe(ref dRef, remainder - (uint)Vector128.Count); @@ -9304,8 +9573,8 @@ static void VectorizedSmall8(ref T xRef, T y, ref T zRef, ref T dRef, nuint rema Debug.Assert(Vector128.IsHardwareAccelerated); Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), - Vector128.Create(y), - Vector128.LoadUnsafe(ref zRef)); + Vector128.Create(y), + Vector128.LoadUnsafe(ref zRef)); beg.StoreUnsafe(ref dRef); break; @@ -9740,6 +10009,32 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt64Mask_8x9)) throw new NotSupportedException(); } + /// Creates a span of from a 
when they're the same type. + private static unsafe ReadOnlySpan Rename(ReadOnlySpan span) + { + Debug.Assert(sizeof(TFrom) == sizeof(TTo)); + return MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref MemoryMarshal.GetReference(span)), span.Length); + } + + /// Creates a span of TTo from a span of TFrom when they're the same type; the result has the same length, and debug builds assert that sizeof(TFrom) == sizeof(TTo). + private static unsafe Span Rename(Span span) + { + Debug.Assert(sizeof(TFrom) == sizeof(TTo)); + return MemoryMarshal.CreateSpan(ref Unsafe.As(ref MemoryMarshal.GetReference(span)), span.Length); + } + + /// Gets whether T is uint, or nuint if in a 32-bit process (IntPtr.Size == 4). + private static bool IsUInt32Like() => typeof(T) == typeof(uint) || (IntPtr.Size == 4 && typeof(T) == typeof(nuint)); + + /// Gets whether T is int, or nint if in a 32-bit process (IntPtr.Size == 4). + private static bool IsInt32Like() => typeof(T) == typeof(int) || (IntPtr.Size == 4 && typeof(T) == typeof(nint)); + + /// Gets whether T is ulong, or nuint if in a 64-bit process (IntPtr.Size == 8). + private static bool IsUInt64Like() => typeof(T) == typeof(ulong) || (IntPtr.Size == 8 && typeof(T) == typeof(nuint)); + + /// Gets whether T is long, or nint if in a 64-bit process (IntPtr.Size == 8).
+ private static bool IsInt64Like() => typeof(T) == typeof(long) || (IntPtr.Size == 8 && typeof(T) == typeof(nint)); + /// x + y internal readonly struct AddOperator : IAggregationOperator where T : IAdditionOperators, IAdditiveIdentity { @@ -9846,7 +10141,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) // Ieee754Remainder - internal readonly struct ReciprocalOperator : IUnaryOperator where T : IFloatingPoint + internal readonly struct ReciprocalOperator : IUnaryOperator where T : IFloatingPoint { public static bool Vectorizable => true; public static T Invoke(T x) => T.One / x; @@ -9855,7 +10150,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) public static Vector512 Invoke(Vector512 x) => Vector512.One / x; } - private readonly struct ReciprocalSqrtOperator : IUnaryOperator where T : IFloatingPointIeee754 + private readonly struct ReciprocalSqrtOperator : IUnaryOperator where T : IFloatingPointIeee754 { public static bool Vectorizable => true; public static T Invoke(T x) => T.One / T.Sqrt(x); @@ -9864,7 +10159,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) public static Vector512 Invoke(Vector512 x) => Vector512.One / Vector512.Sqrt(x); } - private readonly struct ReciprocalEstimateOperator : IUnaryOperator where T : IFloatingPointIeee754 + private readonly struct ReciprocalEstimateOperator : IUnaryOperator where T : IFloatingPointIeee754 { public static bool Vectorizable => true; @@ -9912,7 +10207,7 @@ public static Vector512 Invoke(Vector512 x) } } - private readonly struct ReciprocalSqrtEstimateOperator : IUnaryOperator where T : IFloatingPointIeee754 + private readonly struct ReciprocalSqrtEstimateOperator : IUnaryOperator where T : IFloatingPointIeee754 { public static bool Vectorizable => true; @@ -9991,7 +10286,7 @@ public static Vector512 Invoke(Vector512 x) } /// ~x - internal readonly struct OnesComplementOperator : IUnaryOperator where T : IBitwiseOperators + internal readonly struct OnesComplementOperator : 
IUnaryOperator where T : IBitwiseOperators { public static bool Vectorizable => true; public static T Invoke(T x) => ~x; @@ -11221,7 +11516,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) } /// -x - internal readonly struct NegateOperator : IUnaryOperator where T : IUnaryNegationOperators + internal readonly struct NegateOperator : IUnaryOperator where T : IUnaryNegationOperators { public static bool Vectorizable => true; public static T Invoke(T x) => -x; @@ -11267,7 +11562,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) } /// x - internal readonly struct IdentityOperator : IUnaryOperator + internal readonly struct IdentityOperator : IUnaryOperator { public static bool Vectorizable => true; public static T Invoke(T x) => x; @@ -11277,7 +11572,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) } /// x * x - internal readonly struct SquaredOperator : IUnaryOperator where T : IMultiplyOperators + internal readonly struct SquaredOperator : IUnaryOperator where T : IMultiplyOperators { public static bool Vectorizable => true; public static T Invoke(T x) => x * x; @@ -11287,7 +11582,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) } /// T.Abs(x) - internal readonly struct AbsoluteOperator : IUnaryOperator where T : INumberBase + internal readonly struct AbsoluteOperator : IUnaryOperator where T : INumberBase { public static bool Vectorizable => true; @@ -11367,7 +11662,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Exp(x) - internal readonly struct ExpOperator : IUnaryOperator + internal readonly struct ExpOperator : IUnaryOperator where T : IExponentialFunctions { public static bool Vectorizable => (typeof(T) == typeof(double)) @@ -11453,7 +11748,7 @@ public static Vector512 Invoke(Vector512 x) #if !NET9_0_OR_GREATER /// double.Exp(x) - internal readonly struct ExpOperatorDouble : IUnaryOperator + internal readonly struct ExpOperatorDouble : IUnaryOperator { // This code is based on `vrd2_exp` from 
amd/aocl-libm-ose // Copyright (C) 2019-2020 Advanced Micro Devices, Inc. All rights reserved. @@ -11672,7 +11967,7 @@ public static Vector512 Invoke(Vector512 x) } /// float.Exp(x) - internal readonly struct ExpOperatorSingle : IUnaryOperator + internal readonly struct ExpOperatorSingle : IUnaryOperator { // This code is based on `vrs4_expf` from amd/aocl-libm-ose // Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. @@ -11953,7 +12248,7 @@ public static Vector512 Invoke(Vector512 x) #endif /// T.ExpM1(x) - internal readonly struct ExpM1Operator : IUnaryOperator + internal readonly struct ExpM1Operator : IUnaryOperator where T : IExponentialFunctions { public static bool Vectorizable => ExpOperator.Vectorizable; @@ -11965,7 +12260,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Exp2(x) - internal readonly struct Exp2Operator : IUnaryOperator + internal readonly struct Exp2Operator : IUnaryOperator where T : IExponentialFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -11977,7 +12272,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Exp2M1(x) - internal readonly struct Exp2M1Operator : IUnaryOperator + internal readonly struct Exp2M1Operator : IUnaryOperator where T : IExponentialFunctions { public static bool Vectorizable => Exp2Operator.Vectorizable; @@ -11989,7 +12284,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Exp10(x) - internal readonly struct Exp10Operator : IUnaryOperator + internal readonly struct Exp10Operator : IUnaryOperator where T : IExponentialFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -12001,7 +12296,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Exp10M1(x) - internal readonly struct Exp10M1Operator : IUnaryOperator + internal readonly struct Exp10M1Operator : IUnaryOperator where T : IExponentialFunctions { public static bool Vectorizable => Exp2Operator.Vectorizable; @@ -12024,7 +12319,7 @@ public static Vector512 
Invoke(Vector512 x) } /// T.Sqrt(x) - internal readonly struct SqrtOperator : IUnaryOperator + internal readonly struct SqrtOperator : IUnaryOperator where T : IRootFunctions { public static bool Vectorizable => true; @@ -12035,7 +12330,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Cbrt(x) - internal readonly struct CbrtOperator : IUnaryOperator + internal readonly struct CbrtOperator : IUnaryOperator where T : IRootFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -12057,7 +12352,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Acos(x) - internal readonly struct AcosOperator : IUnaryOperator + internal readonly struct AcosOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -12068,7 +12363,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Acosh(x) - internal readonly struct AcoshOperator : IUnaryOperator + internal readonly struct AcoshOperator : IUnaryOperator where T : IHyperbolicFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -12079,7 +12374,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.AcosPi(x) - internal readonly struct AcosPiOperator : IUnaryOperator + internal readonly struct AcosPiOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => AcosOperator.Vectorizable; @@ -12090,7 +12385,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Asin(x) - internal readonly struct AsinOperator : IUnaryOperator + internal readonly struct AsinOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -12101,7 +12396,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Asinh(x) - internal readonly struct AsinhOperator : IUnaryOperator + internal readonly struct AsinhOperator : IUnaryOperator where T : IHyperbolicFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -12112,7 
+12407,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.AsinPi(x) - internal readonly struct AsinPiOperator : IUnaryOperator + internal readonly struct AsinPiOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => AsinOperator.Vectorizable; @@ -12123,7 +12418,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Atan(x) - internal readonly struct AtanOperator : IUnaryOperator + internal readonly struct AtanOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -12134,7 +12429,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Atanh(x) - internal readonly struct AtanhOperator : IUnaryOperator + internal readonly struct AtanhOperator : IUnaryOperator where T : IHyperbolicFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -12145,7 +12440,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.AtanPi(x) - internal readonly struct AtanPiOperator : IUnaryOperator + internal readonly struct AtanPiOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => AtanOperator.Vectorizable; @@ -12178,7 +12473,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Cos(x) - internal readonly struct CosOperator : IUnaryOperator + internal readonly struct CosOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -12189,7 +12484,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.CosPi(x) - internal readonly struct CosPiOperator : IUnaryOperator + internal readonly struct CosPiOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => CosOperator.Vectorizable; @@ -12200,7 +12495,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Cosh(x) - internal readonly struct CoshOperator : IUnaryOperator + internal readonly struct CoshOperator : IUnaryOperator where T : 
IHyperbolicFunctions { // This code is based on `vrs4_coshf` from amd/aocl-libm-ose @@ -12264,7 +12559,7 @@ public static Vector512 Invoke(Vector512 t) } /// T.Sin(x) - internal readonly struct SinOperator : IUnaryOperator + internal readonly struct SinOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -12275,7 +12570,7 @@ public static Vector512 Invoke(Vector512 t) } /// T.SinPi(x) - internal readonly struct SinPiOperator : IUnaryOperator + internal readonly struct SinPiOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => SinOperator.Vectorizable; @@ -12286,7 +12581,7 @@ public static Vector512 Invoke(Vector512 t) } /// T.Sinh(x) - internal readonly struct SinhOperator : IUnaryOperator + internal readonly struct SinhOperator : IUnaryOperator where T : IHyperbolicFunctions { // Same as cosh, but with `z -` rather than `z +`, and with the sign @@ -12339,7 +12634,7 @@ public static Vector512 Invoke(Vector512 t) } /// T.Tan(x) - internal readonly struct TanOperator : IUnaryOperator + internal readonly struct TanOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -12350,7 +12645,7 @@ public static Vector512 Invoke(Vector512 t) } /// T.TanPi(x) - internal readonly struct TanPiOperator : IUnaryOperator + internal readonly struct TanPiOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => TanOperator.Vectorizable; @@ -12361,7 +12656,7 @@ public static Vector512 Invoke(Vector512 t) } /// T.Tanh(x) - internal readonly struct TanhOperator : IUnaryOperator + internal readonly struct TanhOperator : IUnaryOperator where T : IHyperbolicFunctions { // This code is based on `vrs4_tanhf` from amd/aocl-libm-ose @@ -12424,7 +12719,7 @@ public static Vector512 Invoke(Vector512 t) } /// T.Log(x) - internal readonly struct LogOperator : IUnaryOperator + 
internal readonly struct LogOperator : IUnaryOperator where T : ILogarithmicFunctions { public static bool Vectorizable => (typeof(T) == typeof(double)) @@ -12510,7 +12805,7 @@ public static Vector512 Invoke(Vector512 x) #if !NET9_0_OR_GREATER /// double.Log(x) - internal readonly struct LogOperatorDouble : IUnaryOperator + internal readonly struct LogOperatorDouble : IUnaryOperator { // This code is based on `vrd2_log` from amd/aocl-libm-ose // Copyright (C) 2018-2020 Advanced Micro Devices, Inc. All rights reserved. @@ -12816,7 +13111,7 @@ public static Vector512 Invoke(Vector512 x) } /// float.Log(x) - internal readonly struct LogOperatorSingle : IUnaryOperator + internal readonly struct LogOperatorSingle : IUnaryOperator { // This code is based on `vrs4_logf` from amd/aocl-libm-ose // Copyright (C) 2018-2019 Advanced Micro Devices, Inc. All rights reserved. @@ -13103,7 +13398,7 @@ public static Vector512 Invoke(Vector512 x) #endif /// T.Log2(x) - internal readonly struct Log2Operator : IUnaryOperator + internal readonly struct Log2Operator : IUnaryOperator where T : ILogarithmicFunctions { public static bool Vectorizable => (typeof(T) == typeof(double)) @@ -13189,7 +13484,7 @@ public static Vector512 Invoke(Vector512 x) #if !NET9_0_OR_GREATER /// double.Log2(x) - internal readonly struct Log2OperatorDouble : IUnaryOperator + internal readonly struct Log2OperatorDouble : IUnaryOperator { // This code is based on `vrd2_log2` from amd/aocl-libm-ose // Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved. @@ -13493,7 +13788,7 @@ public static Vector512 Invoke(Vector512 x) } /// float.Log2(x) - internal readonly struct Log2OperatorSingle : IUnaryOperator + internal readonly struct Log2OperatorSingle : IUnaryOperator { // This code is based on `vrs4_log2f` from amd/aocl-libm-ose // Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved. 
@@ -13775,7 +14070,7 @@ public static Vector512 Invoke(Vector512 x) #endif /// T.Log10(x) - internal readonly struct Log10Operator : IUnaryOperator + internal readonly struct Log10Operator : IUnaryOperator where T : ILogarithmicFunctions { public static bool Vectorizable => false; // TODO: Vectorize @@ -13786,7 +14081,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.LogP1(x) - internal readonly struct LogP1Operator : IUnaryOperator + internal readonly struct LogP1Operator : IUnaryOperator where T : ILogarithmicFunctions { public static bool Vectorizable => LogOperator.Vectorizable; @@ -13797,7 +14092,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Log2P1(x) - internal readonly struct Log2P1Operator : IUnaryOperator + internal readonly struct Log2P1Operator : IUnaryOperator where T : ILogarithmicFunctions { public static bool Vectorizable => Log2Operator.Vectorizable; @@ -13808,7 +14103,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.Log10P1(x) - internal readonly struct Log10P1Operator : IUnaryOperator + internal readonly struct Log10P1Operator : IUnaryOperator where T : ILogarithmicFunctions { public static bool Vectorizable => Log10Operator.Vectorizable; @@ -13879,7 +14174,7 @@ private static Vector512 ElementWiseSelect(Vector512 mask, Vector512 } /// 1 / (1 + T.Exp(-x)) - internal readonly struct SigmoidOperator : IUnaryOperator where T : IExponentialFunctions + internal readonly struct SigmoidOperator : IUnaryOperator where T : IExponentialFunctions { public static bool Vectorizable => typeof(T) == typeof(float); public static T Invoke(T x) => T.One / (T.One + T.Exp(-x)); @@ -13888,7 +14183,7 @@ private static Vector512 ElementWiseSelect(Vector512 mask, Vector512 public static Vector512 Invoke(Vector512 x) => Vector512.Create(T.One) / (Vector512.Create(T.One) + ExpOperator.Invoke(-x)); } - internal readonly struct CeilingOperator : IUnaryOperator where T : IFloatingPoint + internal readonly struct CeilingOperator : IUnaryOperator where 
T : IFloatingPoint { public static bool Vectorizable => typeof(T) == typeof(float) || typeof(T) == typeof(double); @@ -13940,7 +14235,7 @@ public static Vector512 Invoke(Vector512 x) } } - internal readonly struct FloorOperator : IUnaryOperator where T : IFloatingPoint + internal readonly struct FloorOperator : IUnaryOperator where T : IFloatingPoint { public static bool Vectorizable => typeof(T) == typeof(float) || typeof(T) == typeof(double); @@ -13992,7 +14287,7 @@ public static Vector512 Invoke(Vector512 x) } } - private readonly struct TruncateOperator : IUnaryOperator where T : IFloatingPoint + private readonly struct TruncateOperator : IUnaryOperator where T : IFloatingPoint { public static bool Vectorizable => typeof(T) == typeof(float) || typeof(T) == typeof(double); @@ -14071,7 +14366,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.PopCount(x) - internal readonly struct PopCountOperator : IUnaryOperator where T : IBinaryInteger + internal readonly struct PopCountOperator : IUnaryOperator where T : IBinaryInteger { public static bool Vectorizable => false; // TODO: Vectorize public static T Invoke(T x) => T.PopCount(x); @@ -14081,7 +14376,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.LeadingZeroCount(x) - internal readonly struct LeadingZeroCountOperator : IUnaryOperator where T : IBinaryInteger + internal readonly struct LeadingZeroCountOperator : IUnaryOperator where T : IBinaryInteger { public static bool Vectorizable => false; // TODO: Vectorize public static T Invoke(T x) => T.LeadingZeroCount(x); @@ -14091,7 +14386,7 @@ public static Vector512 Invoke(Vector512 x) } /// T.TrailingZeroCount(x) - internal readonly struct TrailingZeroCountOperator : IUnaryOperator where T : IBinaryInteger + internal readonly struct TrailingZeroCountOperator : IUnaryOperator where T : IBinaryInteger { public static bool Vectorizable => false; // TODO: Vectorize public static T Invoke(T x) => T.TrailingZeroCount(x); @@ -14192,7 +14487,7 @@ public static 
Vector512 Invoke(Vector512 x, Vector512 y) } /// T.DegreesToRadians(x) - internal readonly struct DegreesToRadiansOperator : IUnaryOperator where T : ITrigonometricFunctions + internal readonly struct DegreesToRadiansOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => true; public static T Invoke(T x) => T.DegreesToRadians(x); @@ -14202,7 +14497,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) } /// T.RadiansToDegrees(x) - internal readonly struct RadiansToDegreesOperator : IUnaryOperator where T : ITrigonometricFunctions + internal readonly struct RadiansToDegreesOperator : IUnaryOperator where T : ITrigonometricFunctions { public static bool Vectorizable => true; public static T Invoke(T x) => T.RadiansToDegrees(x); @@ -14211,14 +14506,334 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) public static Vector512 Invoke(Vector512 x) => (x * T.CreateChecked(180)) / T.Pi; } + /// T.ILogB(x) + internal readonly struct ILogBOperator : IUnaryOperator where T : IFloatingPointIeee754 + { + public static bool Vectorizable => false; // TODO: vectorize for float + + public static int Invoke(T x) => T.ILogB(x); + public static Vector128 Invoke(Vector128 x) => throw new NotImplementedException(); + public static Vector256 Invoke(Vector256 x) => throw new NotImplementedException(); + public static Vector512 Invoke(Vector512 x) => throw new NotImplementedException(); + } + + /// double.ILogB(x) + internal readonly struct ILogBDoubleOperator : IUnaryTwoToOneOperator + { + public static bool Vectorizable => false; // TODO: vectorize + + public static int Invoke(double x) => double.ILogB(x); + public static Vector128 Invoke(Vector128 lower, Vector128 upper) => throw new NotImplementedException(); + public static Vector256 Invoke(Vector256 lower, Vector256 upper) => throw new NotImplementedException(); + public static Vector512 Invoke(Vector512 lower, Vector512 upper) => throw new NotImplementedException(); + } + + 
/// TTo.CreateChecked(x) + internal readonly struct ConvertCheckedFallbackOperator : IUnaryOperator where TFrom : INumberBase where TTo : INumberBase + { + public static bool Vectorizable => false; + + public static TTo Invoke(TFrom x) => TTo.CreateChecked(x); + public static Vector128 Invoke(Vector128 x) => throw new NotImplementedException(); + public static Vector256 Invoke(Vector256 x) => throw new NotImplementedException(); + public static Vector512 Invoke(Vector512 x) => throw new NotImplementedException(); + } + + /// TTo.CreateSaturating(x) + internal readonly struct ConvertSaturatingFallbackOperator : IUnaryOperator where TFrom : INumberBase where TTo : INumberBase + { + public static bool Vectorizable => false; + + public static TTo Invoke(TFrom x) => TTo.CreateSaturating(x); + public static Vector128 Invoke(Vector128 x) => throw new NotImplementedException(); + public static Vector256 Invoke(Vector256 x) => throw new NotImplementedException(); + public static Vector512 Invoke(Vector512 x) => throw new NotImplementedException(); + } + + /// TTo.CreateTruncating(x) + internal readonly struct ConvertTruncatingFallbackOperator : IUnaryOperator where TFrom : INumberBase where TTo : INumberBase + { + public static bool Vectorizable => false; + + public static TTo Invoke(TFrom x) => TTo.CreateTruncating(x); + public static Vector128 Invoke(Vector128 x) => throw new NotImplementedException(); + public static Vector256 Invoke(Vector256 x) => throw new NotImplementedException(); + public static Vector512 Invoke(Vector512 x) => throw new NotImplementedException(); + } + + /// (float)uint + internal readonly struct ConvertUInt32ToSingle : IUnaryOperator + { + public static bool Vectorizable => true; + + public static float Invoke(uint x) => x; + public static Vector128 Invoke(Vector128 x) => Vector128.ConvertToSingle(x); + public static Vector256 Invoke(Vector256 x) => Vector256.ConvertToSingle(x); + public static Vector512 Invoke(Vector512 x) =>
Vector512.ConvertToSingle(x); + } + + /// (float)int + internal readonly struct ConvertInt32ToSingle : IUnaryOperator + { + public static bool Vectorizable => true; + + public static float Invoke(int x) => x; + public static Vector128 Invoke(Vector128 x) => Vector128.ConvertToSingle(x); + public static Vector256 Invoke(Vector256 x) => Vector256.ConvertToSingle(x); + public static Vector512 Invoke(Vector512 x) => Vector512.ConvertToSingle(x); + } + + /// (uint)float + internal readonly struct ConvertSingleToUInt32 : IUnaryOperator + { + public static bool Vectorizable => false; // TODO https://github.com/dotnet/runtime/pull/97529: make this true once vectorized behavior matches scalar + + public static uint Invoke(float x) => uint.CreateTruncating(x); + public static Vector128 Invoke(Vector128 x) => Vector128.ConvertToUInt32(x); + public static Vector256 Invoke(Vector256 x) => Vector256.ConvertToUInt32(x); + public static Vector512 Invoke(Vector512 x) => Vector512.ConvertToUInt32(x); + } + + /// (int)float + internal readonly struct ConvertSingleToInt32 : IUnaryOperator + { + public static bool Vectorizable => false; // TODO https://github.com/dotnet/runtime/pull/97529: make this true once vectorized behavior matches scalar + + public static int Invoke(float x) => int.CreateTruncating(x); + public static Vector128 Invoke(Vector128 x) => Vector128.ConvertToInt32(x); + public static Vector256 Invoke(Vector256 x) => Vector256.ConvertToInt32(x); + public static Vector512 Invoke(Vector512 x) => Vector512.ConvertToInt32(x); + } + + /// (double)ulong + internal readonly struct ConvertUInt64ToDouble : IUnaryOperator + { + public static bool Vectorizable => true; + + public static double Invoke(ulong x) => x; + public static Vector128 Invoke(Vector128 x) => Vector128.ConvertToDouble(x); + public static Vector256 Invoke(Vector256 x) => Vector256.ConvertToDouble(x); + public static Vector512 Invoke(Vector512 x) => Vector512.ConvertToDouble(x); + } + + /// (double)long +
internal readonly struct ConvertInt64ToDouble : IUnaryOperator + { + public static bool Vectorizable => true; + + public static double Invoke(long x) => x; + public static Vector128 Invoke(Vector128 x) => Vector128.ConvertToDouble(x); + public static Vector256 Invoke(Vector256 x) => Vector256.ConvertToDouble(x); + public static Vector512 Invoke(Vector512 x) => Vector512.ConvertToDouble(x); + } + + /// (ulong)double + internal readonly struct ConvertDoubleToUInt64 : IUnaryOperator + { + public static bool Vectorizable => false; // TODO https://github.com/dotnet/runtime/pull/97529: make this true once vectorized behavior matches scalar + + public static ulong Invoke(double x) => ulong.CreateTruncating(x); + public static Vector128 Invoke(Vector128 x) => Vector128.ConvertToUInt64(x); + public static Vector256 Invoke(Vector256 x) => Vector256.ConvertToUInt64(x); + public static Vector512 Invoke(Vector512 x) => Vector512.ConvertToUInt64(x); + } + + /// (long)double + internal readonly struct ConvertDoubleToInt64 : IUnaryOperator + { + public static bool Vectorizable => false; // TODO https://github.com/dotnet/runtime/pull/97529: make this true once vectorized behavior matches scalar + + public static long Invoke(double x) => long.CreateTruncating(x); + public static Vector128 Invoke(Vector128 x) => Vector128.ConvertToInt64(x); + public static Vector256 Invoke(Vector256 x) => Vector256.ConvertToInt64(x); + public static Vector512 Invoke(Vector512 x) => Vector512.ConvertToInt64(x); + } + + /// (double)float + internal readonly struct WidenSingleToDoubleOperator : IUnaryOneToTwoOperator + { + public static bool Vectorizable => true; + + public static double Invoke(float x) => x; + public static (Vector128 Lower, Vector128 Upper) Invoke(Vector128 x) => Vector128.Widen(x); + public static (Vector256 Lower, Vector256 Upper) Invoke(Vector256 x) => Vector256.Widen(x); + public static (Vector512 Lower, Vector512 Upper) Invoke(Vector512 x) => Vector512.Widen(x); + } + + /// 
(float)double + internal readonly struct NarrowDoubleToSingleOperator : IUnaryTwoToOneOperator + { + public static bool Vectorizable => true; + + public static float Invoke(double x) => (float)x; + public static Vector128 Invoke(Vector128 lower, Vector128 upper) => Vector128.Narrow(lower, upper); + public static Vector256 Invoke(Vector256 lower, Vector256 upper) => Vector256.Narrow(lower, upper); + public static Vector512 Invoke(Vector512 lower, Vector512 upper) => Vector512.Narrow(lower, upper); + } + + /// (ushort)byte + internal readonly struct WidenByteToUInt16Operator : IUnaryOneToTwoOperator + { + public static bool Vectorizable => true; + + public static ushort Invoke(byte x) => x; + public static (Vector128 Lower, Vector128 Upper) Invoke(Vector128 x) => Vector128.Widen(x); + public static (Vector256 Lower, Vector256 Upper) Invoke(Vector256 x) => Vector256.Widen(x); + public static (Vector512 Lower, Vector512 Upper) Invoke(Vector512 x) => Vector512.Widen(x); + } + + /// (byte)ushort + internal readonly struct NarrowUInt16ToByteOperator : IUnaryTwoToOneOperator + { + public static bool Vectorizable => true; + + public static byte Invoke(ushort x) => (byte)x; + public static Vector128 Invoke(Vector128 lower, Vector128 upper) => Vector128.Narrow(lower, upper); + public static Vector256 Invoke(Vector256 lower, Vector256 upper) => Vector256.Narrow(lower, upper); + public static Vector512 Invoke(Vector512 lower, Vector512 upper) => Vector512.Narrow(lower, upper); + } + + /// (short)sbyte + internal readonly struct WidenSByteToInt16Operator : IUnaryOneToTwoOperator + { + public static bool Vectorizable => true; + + public static short Invoke(sbyte x) => x; + public static (Vector128 Lower, Vector128 Upper) Invoke(Vector128 x) => Vector128.Widen(x); + public static (Vector256 Lower, Vector256 Upper) Invoke(Vector256 x) => Vector256.Widen(x); + public static (Vector512 Lower, Vector512 Upper) Invoke(Vector512 x) => Vector512.Widen(x); + } + + /// (sbyte)short + 
internal readonly struct NarrowInt16ToSByteOperator : IUnaryTwoToOneOperator + { + public static bool Vectorizable => true; + + public static sbyte Invoke(short x) => (sbyte)x; + public static Vector128 Invoke(Vector128 lower, Vector128 upper) => Vector128.Narrow(lower, upper); + public static Vector256 Invoke(Vector256 lower, Vector256 upper) => Vector256.Narrow(lower, upper); + public static Vector512 Invoke(Vector512 lower, Vector512 upper) => Vector512.Narrow(lower, upper); + } + + /// (uint)ushort + internal readonly struct WidenUInt16ToUInt32Operator : IUnaryOneToTwoOperator + { + public static bool Vectorizable => true; + + public static uint Invoke(ushort x) => x; + public static (Vector128 Lower, Vector128 Upper) Invoke(Vector128 x) => Vector128.Widen(x); + public static (Vector256 Lower, Vector256 Upper) Invoke(Vector256 x) => Vector256.Widen(x); + public static (Vector512 Lower, Vector512 Upper) Invoke(Vector512 x) => Vector512.Widen(x); + } + + /// (ushort)uint + internal readonly struct NarrowUInt32ToUInt16Operator : IUnaryTwoToOneOperator + { + public static bool Vectorizable => true; + + public static ushort Invoke(uint x) => (ushort)x; + public static Vector128 Invoke(Vector128 lower, Vector128 upper) => Vector128.Narrow(lower, upper); + public static Vector256 Invoke(Vector256 lower, Vector256 upper) => Vector256.Narrow(lower, upper); + public static Vector512 Invoke(Vector512 lower, Vector512 upper) => Vector512.Narrow(lower, upper); + } + + /// (int)short + internal readonly struct WidenInt16ToInt32Operator : IUnaryOneToTwoOperator + { + public static bool Vectorizable => true; + + public static int Invoke(short x) => x; + public static (Vector128 Lower, Vector128 Upper) Invoke(Vector128 x) => Vector128.Widen(x); + public static (Vector256 Lower, Vector256 Upper) Invoke(Vector256 x) => Vector256.Widen(x); + public static (Vector512 Lower, Vector512 Upper) Invoke(Vector512 x) => Vector512.Widen(x); + } + + /// (short)int + internal readonly struct 
NarrowInt32ToInt16Operator : IUnaryTwoToOneOperator + { + public static bool Vectorizable => true; + + public static short Invoke(int x) => (short)x; + public static Vector128 Invoke(Vector128 lower, Vector128 upper) => Vector128.Narrow(lower, upper); + public static Vector256 Invoke(Vector256 lower, Vector256 upper) => Vector256.Narrow(lower, upper); + public static Vector512 Invoke(Vector512 lower, Vector512 upper) => Vector512.Narrow(lower, upper); + } + + /// (ulong)uint + internal readonly struct WidenUInt32ToUInt64Operator : IUnaryOneToTwoOperator + { + public static bool Vectorizable => true; + + public static ulong Invoke(uint x) => x; + public static (Vector128 Lower, Vector128 Upper) Invoke(Vector128 x) => Vector128.Widen(x); + public static (Vector256 Lower, Vector256 Upper) Invoke(Vector256 x) => Vector256.Widen(x); + public static (Vector512 Lower, Vector512 Upper) Invoke(Vector512 x) => Vector512.Widen(x); + } + + /// (uint)ulong + internal readonly struct NarrowUInt64ToUInt32Operator : IUnaryTwoToOneOperator + { + public static bool Vectorizable => true; + + public static uint Invoke(ulong x) => (uint)x; + public static Vector128 Invoke(Vector128 lower, Vector128 upper) => Vector128.Narrow(lower, upper); + public static Vector256 Invoke(Vector256 lower, Vector256 upper) => Vector256.Narrow(lower, upper); + public static Vector512 Invoke(Vector512 lower, Vector512 upper) => Vector512.Narrow(lower, upper); + } + + /// (long)int + internal readonly struct WidenInt32ToInt64Operator : IUnaryOneToTwoOperator + { + public static bool Vectorizable => true; + + public static long Invoke(int x) => x; + public static (Vector128 Lower, Vector128 Upper) Invoke(Vector128 x) => Vector128.Widen(x); + public static (Vector256 Lower, Vector256 Upper) Invoke(Vector256 x) => Vector256.Widen(x); + public static (Vector512 Lower, Vector512 Upper) Invoke(Vector512 x) => Vector512.Widen(x); + } + + /// (int)long + internal readonly struct NarrowInt64ToInt32Operator : 
IUnaryTwoToOneOperator + { + public static bool Vectorizable => true; + + public static int Invoke(long x) => (int)x; + public static Vector128 Invoke(Vector128 lower, Vector128 upper) => Vector128.Narrow(lower, upper); + public static Vector256 Invoke(Vector256 lower, Vector256 upper) => Vector256.Narrow(lower, upper); + public static Vector512 Invoke(Vector512 lower, Vector512 upper) => Vector512.Narrow(lower, upper); + } + + /// Operator that takes one input value and returns a single value. + /// The input and output type must be of the same size if vectorization is desired. + private interface IUnaryOperator + { + static abstract bool Vectorizable { get; } + static abstract TOutput Invoke(TInput x); + static abstract Vector128 Invoke(Vector128 x); + static abstract Vector256 Invoke(Vector256 x); + static abstract Vector512 Invoke(Vector512 x); + } + + /// Operator that takes one input value and returns a single value. + /// The input type must be half the size of the output type. + private interface IUnaryOneToTwoOperator + { + static abstract bool Vectorizable { get; } + static abstract TOutput Invoke(TInput x); + static abstract (Vector128 Lower, Vector128 Upper) Invoke(Vector128 x); + static abstract (Vector256 Lower, Vector256 Upper) Invoke(Vector256 x); + static abstract (Vector512 Lower, Vector512 Upper) Invoke(Vector512 x); + } + /// Operator that takes one input value and returns a single value. - private interface IUnaryOperator + /// The input type must be twice the size of the output type. 
+ private interface IUnaryTwoToOneOperator { static abstract bool Vectorizable { get; } - static abstract T Invoke(T x); - static abstract Vector128 Invoke(Vector128 x); - static abstract Vector256 Invoke(Vector256 x); - static abstract Vector512 Invoke(Vector512 x); + static abstract TOutput Invoke(TInput x); + static abstract Vector128 Invoke(Vector128 lower, Vector128 upper); + static abstract Vector256 Invoke(Vector256 lower, Vector256 upper); + static abstract Vector512 Invoke(Vector512 lower, Vector512 upper); } /// Operator that takes two input values and returns a single value. diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs index b44036b1a74628..4d1c22a402e54e 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Collections.Generic; using System.Linq; +using System.Reflection; using System.Runtime.InteropServices; using Xunit; using Xunit.Sdk; @@ -15,6 +16,198 @@ namespace System.Numerics.Tensors.Tests { + public class ConvertTests + { + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBuiltWithAggressiveTrimming))] + public void ConvertTruncatingAndSaturating() + { + MethodInfo convertTruncatingImpl = typeof(ConvertTests).GetMethod(nameof(ConvertTruncatingImpl), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance); + Assert.NotNull(convertTruncatingImpl); + + MethodInfo convertSaturatingImpl = typeof(ConvertTests).GetMethod(nameof(ConvertSaturatingImpl), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance); + Assert.NotNull(convertSaturatingImpl); + + Type[] types = + [ + typeof(sbyte), typeof(byte), + typeof(short), typeof(ushort), typeof(char), + typeof(int), typeof(uint), + 
typeof(long), typeof(ulong), + typeof(nint), typeof(nuint), + typeof(Half), typeof(float), typeof(double), typeof(NFloat), + typeof(Int128), typeof(UInt128), + ]; + + foreach (Type from in types) + { + foreach (Type to in types) + { + convertTruncatingImpl.MakeGenericMethod(from, to).Invoke(null, null); + convertSaturatingImpl.MakeGenericMethod(from, to).Invoke(null, null); + } + } + } + + [Fact] + public void ConvertChecked() + { + // Conversions that never overflow. This isn't an exhaustive list; just a sampling. + ConvertCheckedImpl(); + ConvertCheckedImpl(); + ConvertCheckedImpl(); + ConvertCheckedImpl(); + ConvertCheckedImpl(); + ConvertCheckedImpl(); + ConvertCheckedImpl(); + ConvertCheckedImpl(); + ConvertCheckedImpl(); + ConvertCheckedImpl(); + ConvertCheckedImpl(); + ConvertCheckedImpl(); + ConvertCheckedImpl(); + + // Conversions that may overflow. This isn't an exhaustive list; just a sampling. + ConvertCheckedImpl(42f, float.MaxValue); + ConvertCheckedImpl(42, int.MaxValue + 1L); + } + + private static void ConvertTruncatingImpl() + where TFrom : unmanaged, INumber + where TTo : unmanaged, INumber + { + AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertTruncating(new TFrom[3], new TTo[2])); + + foreach (int tensorLength in Helpers.TensorLengthsIncluding0) + { + using BoundedMemory source = BoundedMemory.Allocate(tensorLength); + using BoundedMemory destination = BoundedMemory.Allocate(tensorLength); + + Random rand = new(42); + Span sourceSpan = source.Span; + for (int i = 0; i < tensorLength; i++) + { + sourceSpan[i] = TFrom.CreateTruncating(new Int128( + (ulong)rand.NextInt64(long.MinValue, long.MaxValue), + (ulong)rand.NextInt64(long.MinValue, long.MaxValue))); + } + + TensorPrimitives.ConvertTruncating(source.Span, destination.Span); + + for (int i = 0; i < tensorLength; i++) + { + if (!IsEqualWithTolerance(TTo.CreateTruncating(source.Span[i]), destination.Span[i])) + { + throw new XunitException($"{typeof(TFrom).Name} => 
{typeof(TTo).Name}. Input: {source.Span[i]}. Actual: {destination.Span[i]}. Expected: {TTo.CreateTruncating(source.Span[i])}."); + } + } + }; + } + + private static void ConvertSaturatingImpl() + where TFrom : unmanaged, INumber + where TTo : unmanaged, INumber + { + AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertSaturating(new TFrom[3], new TTo[2])); + + foreach (int tensorLength in Helpers.TensorLengthsIncluding0) + { + using BoundedMemory source = BoundedMemory.Allocate(tensorLength); + using BoundedMemory destination = BoundedMemory.Allocate(tensorLength); + + Random rand = new(42); + Span sourceSpan = source.Span; + for (int i = 0; i < tensorLength; i++) + { + sourceSpan[i] = TFrom.CreateTruncating(new Int128( + (ulong)rand.NextInt64(long.MinValue, long.MaxValue), + (ulong)rand.NextInt64(long.MinValue, long.MaxValue))); + } + + TensorPrimitives.ConvertSaturating(source.Span, destination.Span); + + for (int i = 0; i < tensorLength; i++) + { + if (!IsEqualWithTolerance(TTo.CreateSaturating(source.Span[i]), destination.Span[i])) + { + throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Actual: {destination.Span[i]}. 
Expected: {TTo.CreateSaturating(source.Span[i])}."); + } + } + }; + } + + private static void ConvertCheckedImpl() + where TFrom : unmanaged, INumber + where TTo : unmanaged, INumber + { + AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertChecked(new TFrom[3], new TTo[2])); + + foreach (int tensorLength in Helpers.TensorLengthsIncluding0) + { + using BoundedMemory source = BoundedMemory.Allocate(tensorLength); + using BoundedMemory destination = BoundedMemory.Allocate(tensorLength); + + Random rand = new(42); + Span sourceSpan = source.Span; + for (int i = 0; i < tensorLength; i++) + { + sourceSpan[i] = TFrom.CreateTruncating(new Int128( + (ulong)rand.NextInt64(long.MinValue, long.MaxValue), + (ulong)rand.NextInt64(long.MinValue, long.MaxValue))); + } + + TensorPrimitives.ConvertChecked(source.Span, destination.Span); + + for (int i = 0; i < tensorLength; i++) + { + if (!IsEqualWithTolerance(TTo.CreateChecked(source.Span[i]), destination.Span[i])) + { + throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Actual: {destination.Span[i]}. 
Expected: {TTo.CreateChecked(source.Span[i])}."); + } + } + }; + } + + private static void ConvertCheckedImpl(TFrom valid, TFrom invalid) + where TFrom : unmanaged, INumber + where TTo : unmanaged, INumber + { + foreach (int tensorLength in Helpers.TensorLengths) + { + using BoundedMemory source = BoundedMemory.Allocate(tensorLength); + using BoundedMemory destination = BoundedMemory.Allocate(tensorLength); + + // Test with valid + source.Span.Fill(valid); + TensorPrimitives.ConvertChecked(source.Span, destination.Span); + foreach (TTo result in destination.Span) + { + Assert.True(IsEqualWithTolerance(TTo.CreateChecked(valid), result)); + } + + // Test with at least one invalid + foreach (int invalidPosition in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + source.Span.Fill(valid); + source.Span[invalidPosition] = invalid; + Assert.Throws(() => TensorPrimitives.ConvertChecked(source.Span, destination.Span)); + } + }; + } + + private static bool IsEqualWithTolerance(T expected, T actual, T? tolerance = null) where T : unmanaged, INumber + { + tolerance ??= T.CreateTruncating(0.0001); + + T diff = T.Abs(expected - actual); + if (diff > tolerance && diff > T.Max(T.Abs(expected), T.Abs(actual)) * tolerance) + { + return false; + } + + return true; + } + } + public class DoubleGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests { } public class SingleGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests { } public class HalfGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests