Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 3efe176

Browse files
committed
Vectorize TensorPrimitives.Pow, Exp2, Exp10, Cbrt
While there are likely more accurate and better performing options that involve more code, for now we can: - Implement Pow for now as Exp(y * Log(x)) - Implement Exp2 for now as Exp(x * Log(2)) - Implement Exp10 for now as Exp(x * Log(10)) - Implement Cbrt for now as Exp(Log(x) / 3) - (The way Exp2M1 and Exp10M1 are already implemented, they implicitly vectorize now as well.)
1 parent 7a60900 commit 3efe176

File tree

1 file changed

+94
-16
lines changed

1 file changed

+94
-16
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.netcore.cs

Lines changed: 94 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12263,12 +12263,14 @@ public static Vector512<float> Invoke(Vector512<float> x)
1226312263
internal readonly struct Exp2Operator<T> : IUnaryOperator<T, T>
1226412264
where T : IExponentialFunctions<T>
1226512265
{
12266-
public static bool Vectorizable => false; // TODO: Vectorize
12266+
private const double NaturalLog2 = 0.6931471805599453;
12267+
12268+
public static bool Vectorizable => typeof(T) == typeof(float) || typeof(T) == typeof(double);
1226712269

1226812270
public static T Invoke(T x) => T.Exp2(x);
12269-
public static Vector128<T> Invoke(Vector128<T> x) => throw new NotSupportedException();
12270-
public static Vector256<T> Invoke(Vector256<T> x) => throw new NotSupportedException();
12271-
public static Vector512<T> Invoke(Vector512<T> x) => throw new NotSupportedException();
12271+
public static Vector128<T> Invoke(Vector128<T> x) => ExpOperator<T>.Invoke(x * Vector128.Create(T.CreateTruncating(NaturalLog2)));
12272+
public static Vector256<T> Invoke(Vector256<T> x) => ExpOperator<T>.Invoke(x * Vector256.Create(T.CreateTruncating(NaturalLog2)));
12273+
public static Vector512<T> Invoke(Vector512<T> x) => ExpOperator<T>.Invoke(x * Vector512.Create(T.CreateTruncating(NaturalLog2)));
1227212274
}
1227312275

1227412276
/// <summary>T.Exp2M1(x)</summary>
@@ -12287,12 +12289,14 @@ public static Vector512<float> Invoke(Vector512<float> x)
1228712289
internal readonly struct Exp10Operator<T> : IUnaryOperator<T, T>
1228812290
where T : IExponentialFunctions<T>
1228912291
{
12290-
public static bool Vectorizable => false; // TODO: Vectorize
12292+
private const double NaturalLog10 = 2.302585092994046;
12293+
12294+
public static bool Vectorizable => typeof(T) == typeof(float) || typeof(T) == typeof(double);
1229112295

1229212296
public static T Invoke(T x) => T.Exp10(x);
12293-
public static Vector128<T> Invoke(Vector128<T> x) => throw new NotSupportedException();
12294-
public static Vector256<T> Invoke(Vector256<T> x) => throw new NotSupportedException();
12295-
public static Vector512<T> Invoke(Vector512<T> x) => throw new NotSupportedException();
12297+
public static Vector128<T> Invoke(Vector128<T> x) => ExpOperator<T>.Invoke(x * Vector128.Create(T.CreateTruncating(NaturalLog10)));
12298+
public static Vector256<T> Invoke(Vector256<T> x) => ExpOperator<T>.Invoke(x * Vector256.Create(T.CreateTruncating(NaturalLog10)));
12299+
public static Vector512<T> Invoke(Vector512<T> x) => ExpOperator<T>.Invoke(x * Vector512.Create(T.CreateTruncating(NaturalLog10)));
1229612300
}
1229712301

1229812302
/// <summary>T.Exp10M1(x)</summary>
@@ -12311,11 +12315,48 @@ public static Vector512<float> Invoke(Vector512<float> x)
1231112315
internal readonly struct PowOperator<T> : IBinaryOperator<T>
1231212316
where T : IPowerFunctions<T>
1231312317
{
12314-
public static bool Vectorizable => false; // TODO: Vectorize
12318+
public static bool Vectorizable => typeof(T) == typeof(float) || typeof(T) == typeof(double);
12319+
1231512320
public static T Invoke(T x, T y) => T.Pow(x, y);
12316-
public static Vector128<T> Invoke(Vector128<T> x, Vector128<T> y) => throw new NotSupportedException();
12317-
public static Vector256<T> Invoke(Vector256<T> x, Vector256<T> y) => throw new NotSupportedException();
12318-
public static Vector512<T> Invoke(Vector512<T> x, Vector512<T> y) => throw new NotSupportedException();
12321+
12322+
public static Vector128<T> Invoke(Vector128<T> x, Vector128<T> y)
12323+
{
12324+
if (typeof(T) == typeof(float))
12325+
{
12326+
return ExpOperator<float>.Invoke(y.AsSingle() * LogOperator<float>.Invoke(x.AsSingle())).As<float, T>();
12327+
}
12328+
else
12329+
{
12330+
Debug.Assert(typeof(T) == typeof(double));
12331+
return ExpOperator<double>.Invoke(y.AsDouble() * LogOperator<double>.Invoke(x.AsDouble())).As<double, T>();
12332+
}
12333+
}
12334+
12335+
public static Vector256<T> Invoke(Vector256<T> x, Vector256<T> y)
12336+
{
12337+
if (typeof(T) == typeof(float))
12338+
{
12339+
return ExpOperator<float>.Invoke(y.AsSingle() * LogOperator<float>.Invoke(x.AsSingle())).As<float, T>();
12340+
}
12341+
else
12342+
{
12343+
Debug.Assert(typeof(T) == typeof(double));
12344+
return ExpOperator<double>.Invoke(y.AsDouble() * LogOperator<double>.Invoke(x.AsDouble())).As<double, T>();
12345+
}
12346+
}
12347+
12348+
public static Vector512<T> Invoke(Vector512<T> x, Vector512<T> y)
12349+
{
12350+
if (typeof(T) == typeof(float))
12351+
{
12352+
return ExpOperator<float>.Invoke(y.AsSingle() * LogOperator<float>.Invoke(x.AsSingle())).As<float, T>();
12353+
}
12354+
else
12355+
{
12356+
Debug.Assert(typeof(T) == typeof(double));
12357+
return ExpOperator<double>.Invoke(y.AsDouble() * LogOperator<double>.Invoke(x.AsDouble())).As<double, T>();
12358+
}
12359+
}
1231912360
}
1232012361

1232112362
/// <summary>T.Sqrt(x)</summary>
@@ -12333,11 +12374,48 @@ public static Vector512<float> Invoke(Vector512<float> x)
1233312374
internal readonly struct CbrtOperator<T> : IUnaryOperator<T, T>
1233412375
where T : IRootFunctions<T>
1233512376
{
12336-
public static bool Vectorizable => false; // TODO: Vectorize
12377+
public static bool Vectorizable => typeof(T) == typeof(float) || typeof(T) == typeof(double);
12378+
1233712379
public static T Invoke(T x) => T.Cbrt(x);
12338-
public static Vector128<T> Invoke(Vector128<T> x) => throw new NotSupportedException();
12339-
public static Vector256<T> Invoke(Vector256<T> x) => throw new NotSupportedException();
12340-
public static Vector512<T> Invoke(Vector512<T> x) => throw new NotSupportedException();
12380+
12381+
public static Vector128<T> Invoke(Vector128<T> x)
12382+
{
12383+
if (typeof(T) == typeof(float))
12384+
{
12385+
return ExpOperator<float>.Invoke(LogOperator<float>.Invoke(x.AsSingle()) / Vector128.Create(3f)).As<float, T>();
12386+
}
12387+
else
12388+
{
12389+
Debug.Assert(typeof(T) == typeof(double));
12390+
return ExpOperator<double>.Invoke(LogOperator<double>.Invoke(x.AsDouble()) / Vector128.Create(3d)).As<double, T>();
12391+
}
12392+
}
12393+
12394+
public static Vector256<T> Invoke(Vector256<T> x)
12395+
{
12396+
if (typeof(T) == typeof(float))
12397+
{
12398+
return ExpOperator<float>.Invoke(LogOperator<float>.Invoke(x.AsSingle()) / Vector256.Create(3f)).As<float, T>();
12399+
}
12400+
else
12401+
{
12402+
Debug.Assert(typeof(T) == typeof(double));
12403+
return ExpOperator<double>.Invoke(LogOperator<double>.Invoke(x.AsDouble()) / Vector256.Create(3d)).As<double, T>();
12404+
}
12405+
}
12406+
12407+
public static Vector512<T> Invoke(Vector512<T> x)
12408+
{
12409+
if (typeof(T) == typeof(float))
12410+
{
12411+
return ExpOperator<float>.Invoke(LogOperator<float>.Invoke(x.AsSingle()) / Vector512.Create(3f)).As<float, T>();
12412+
}
12413+
else
12414+
{
12415+
Debug.Assert(typeof(T) == typeof(double));
12416+
return ExpOperator<double>.Invoke(LogOperator<double>.Invoke(x.AsDouble()) / Vector512.Create(3d)).As<double, T>();
12417+
}
12418+
}
1234112419
}
1234212420

1234312421
/// <summary>T.Hypot(x, y)</summary>

0 commit comments

Comments
 (0)