Skip to content

Commit 4fc943c

Browse files
authored
Add TensorPrimitives.ConvertTruncating/Saturating/Checked (#97572)
* Add TensorPrimitives.ConvertTruncating/Saturating/Checked * Fix auto-indentation * Add comment * Fix failures
1 parent 9bffb0f commit 4fc943c

File tree

5 files changed

+1516
-461
lines changed

5 files changed

+1516
-461
lines changed

src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ public static void BitwiseOr<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T>
3535
public static void BitwiseOr<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.IBitwiseOperators<T, T, T> { }
3636
public static void Cbrt<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IRootFunctions<T> { }
3737
public static void Ceiling<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IFloatingPoint<T> { }
38+
public static void ConvertChecked<TFrom, TTo>(System.ReadOnlySpan<TFrom> source, System.Span<TTo> destination) where TFrom : System.Numerics.INumberBase<TFrom> where TTo : System.Numerics.INumberBase<TTo> { }
39+
public static void ConvertSaturating<TFrom, TTo>(System.ReadOnlySpan<TFrom> source, System.Span<TTo> destination) where TFrom : System.Numerics.INumberBase<TFrom> where TTo : System.Numerics.INumberBase<TTo> { }
40+
public static void ConvertTruncating<TFrom, TTo>(System.ReadOnlySpan<TFrom> source, System.Span<TTo> destination) where TFrom : System.Numerics.INumberBase<TFrom> where TTo : System.Numerics.INumberBase<TTo> { }
3841
public static void ConvertToHalf(System.ReadOnlySpan<float> source, System.Span<System.Half> destination) { }
3942
public static void ConvertToSingle(System.ReadOnlySpan<System.Half> source, System.Span<float> destination) { }
4043
public static void CopySign<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> sign, System.Span<T> destination) where T : System.Numerics.INumber<T> { }

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.Single.netcore.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public static unsafe partial class TensorPrimitives
4343
{
4444
private static void InvokeSpanIntoSpan<TSingleUnaryOperator>(
4545
ReadOnlySpan<float> x, Span<float> destination)
46-
where TSingleUnaryOperator : struct, IUnaryOperator<float> =>
46+
where TSingleUnaryOperator : struct, IUnaryOperator<float, float> =>
4747
InvokeSpanIntoSpan<float, TSingleUnaryOperator>(x, destination);
4848

4949
private static void InvokeSpanSpanIntoSpan<TSingleBinaryOperator>(
@@ -58,7 +58,7 @@ private static void InvokeSpanScalarIntoSpan<TSingleBinaryOperator>(
5858

5959
private static unsafe void InvokeSpanScalarIntoSpan<TSingleTransformOperator, TSingleBinaryOperator>(
6060
ReadOnlySpan<float> x, float y, Span<float> destination)
61-
where TSingleTransformOperator : struct, IUnaryOperator<float>
61+
where TSingleTransformOperator : struct, IUnaryOperator<float, float>
6262
where TSingleBinaryOperator : struct, IBinaryOperator<float> =>
6363
InvokeSpanScalarIntoSpan<float, TSingleTransformOperator, TSingleBinaryOperator>(x, y, destination);
6464

@@ -79,7 +79,7 @@ private static void InvokeSpanScalarSpanIntoSpan<TSingleTernaryOperator>(
7979

8080
private static unsafe float Aggregate<TSingleTransformOperator, TSingleAggregationOperator>(
8181
ReadOnlySpan<float> x)
82-
where TSingleTransformOperator : struct, IUnaryOperator<float>
82+
where TSingleTransformOperator : struct, IUnaryOperator<float, float>
8383
where TSingleAggregationOperator : struct, IAggregationOperator<float> =>
8484
Aggregate<float, TSingleTransformOperator, TSingleAggregationOperator>(x);
8585

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.T.cs

Lines changed: 250 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Runtime.CompilerServices;
5+
46
namespace System.Numerics.Tensors
57
{
68
/// <summary>Performs primitive tensor operations over spans of memory.</summary>
@@ -488,6 +490,249 @@ public static void Ceiling<T>(ReadOnlySpan<T> x, Span<T> destination)
488490
where T : IFloatingPoint<T> =>
489491
InvokeSpanIntoSpan<T, CeilingOperator<T>>(x, destination);
490492

493+
/// <summary>
494+
/// Copies <paramref name="source"/> to <paramref name="destination"/>, converting each <typeparamref name="TFrom"/>
495+
/// value to a <typeparamref name="TTo"/> value.
496+
/// </summary>
497+
/// <param name="source">The source span from which to copy values.</param>
498+
/// <param name="destination">The destination span into which the converted values should be written.</param>
499+
/// <exception cref="ArgumentException">Destination is too short.</exception>
500+
/// <remarks>
501+
/// <para>
502+
/// This method effectively computes <c><paramref name="destination" />[i] = TTo.CreateChecked(<paramref name="source"/>[i])</c>.
503+
/// </para>
504+
/// </remarks>
505+
public static void ConvertChecked<TFrom, TTo>(ReadOnlySpan<TFrom> source, Span<TTo> destination)
506+
where TFrom : INumberBase<TFrom>
507+
where TTo : INumberBase<TTo>
508+
{
509+
if (!TryConvertUniversal(source, destination))
510+
{
511+
InvokeSpanIntoSpan<TFrom, TTo, ConvertCheckedFallbackOperator<TFrom, TTo>>(source, destination);
512+
}
513+
}
514+
515+
/// <summary>
516+
/// Copies <paramref name="source"/> to <paramref name="destination"/>, converting each <typeparamref name="TFrom"/>
517+
/// value to a <typeparamref name="TTo"/> value.
518+
/// </summary>
519+
/// <param name="source">The source span from which to copy values.</param>
520+
/// <param name="destination">The destination span into which the converted values should be written.</param>
521+
/// <exception cref="ArgumentException">Destination is too short.</exception>
522+
/// <remarks>
523+
/// <para>
524+
/// This method effectively computes <c><paramref name="destination" />[i] = TTo.CreateSaturating(<paramref name="source"/>[i])</c>.
525+
/// </para>
526+
/// </remarks>
527+
public static void ConvertSaturating<TFrom, TTo>(ReadOnlySpan<TFrom> source, Span<TTo> destination)
528+
where TFrom : INumberBase<TFrom>
529+
where TTo : INumberBase<TTo>
530+
{
531+
if (!TryConvertUniversal(source, destination))
532+
{
533+
InvokeSpanIntoSpan<TFrom, TTo, ConvertSaturatingFallbackOperator<TFrom, TTo>>(source, destination);
534+
}
535+
}
536+
537+
/// <summary>
538+
/// Copies <paramref name="source"/> to <paramref name="destination"/>, converting each <typeparamref name="TFrom"/>
539+
/// value to a <typeparamref name="TTo"/> value.
540+
/// </summary>
541+
/// <param name="source">The source span from which to copy values.</param>
542+
/// <param name="destination">The destination span into which the converted values should be written.</param>
543+
/// <exception cref="ArgumentException">Destination is too short.</exception>
544+
/// <remarks>
545+
/// <para>
546+
/// This method effectively computes <c><paramref name="destination" />[i] = TTo.CreateTruncating(<paramref name="source"/>[i])</c>.
547+
/// </para>
548+
/// </remarks>
549+
public static void ConvertTruncating<TFrom, TTo>(ReadOnlySpan<TFrom> source, Span<TTo> destination)
550+
where TFrom : INumberBase<TFrom>
551+
where TTo : INumberBase<TTo>
552+
{
553+
if (TryConvertUniversal(source, destination))
554+
{
555+
return;
556+
}
557+
558+
if (((typeof(TFrom) == typeof(byte) || typeof(TFrom) == typeof(sbyte)) && (typeof(TTo) == typeof(byte) || typeof(TTo) == typeof(sbyte))) ||
559+
((typeof(TFrom) == typeof(ushort) || typeof(TFrom) == typeof(short)) && (typeof(TTo) == typeof(ushort) || typeof(TTo) == typeof(short))) ||
560+
((IsUInt32Like<TFrom>() || IsInt32Like<TFrom>()) && (IsUInt32Like<TTo>() || IsInt32Like<TTo>())) ||
561+
((IsUInt64Like<TFrom>() || IsInt64Like<TFrom>()) && (IsUInt64Like<TTo>() || IsInt64Like<TTo>())))
562+
{
563+
source.CopyTo(Rename<TTo, TFrom>(destination));
564+
return;
565+
}
566+
567+
if (typeof(TFrom) == typeof(float) && IsUInt32Like<TTo>())
568+
{
569+
InvokeSpanIntoSpan<float, uint, ConvertSingleToUInt32>(Rename<TFrom, float>(source), Rename<TTo, uint>(destination));
570+
return;
571+
}
572+
573+
if (typeof(TFrom) == typeof(float) && IsInt32Like<TTo>())
574+
{
575+
InvokeSpanIntoSpan<float, int, ConvertSingleToInt32>(Rename<TFrom, float>(source), Rename<TTo, int>(destination));
576+
return;
577+
}
578+
579+
if (typeof(TFrom) == typeof(double) && IsUInt64Like<TTo>())
580+
{
581+
InvokeSpanIntoSpan<double, ulong, ConvertDoubleToUInt64>(Rename<TFrom, double>(source), Rename<TTo, ulong>(destination));
582+
return;
583+
}
584+
585+
if (typeof(TFrom) == typeof(double) && IsInt64Like<TTo>())
586+
{
587+
InvokeSpanIntoSpan<double, long, ConvertDoubleToInt64>(Rename<TFrom, double>(source), Rename<TTo, long>(destination));
588+
return;
589+
}
590+
591+
if (typeof(TFrom) == typeof(ushort) && typeof(TTo) == typeof(byte))
592+
{
593+
InvokeSpanIntoSpan_2to1<ushort, byte, NarrowUInt16ToByteOperator>(Rename<TFrom, ushort>(source), Rename<TTo, byte>(destination));
594+
return;
595+
}
596+
597+
if (typeof(TFrom) == typeof(short) && typeof(TTo) == typeof(sbyte))
598+
{
599+
InvokeSpanIntoSpan_2to1<short, sbyte, NarrowInt16ToSByteOperator>(Rename<TFrom, short>(source), Rename<TTo, sbyte>(destination));
600+
return;
601+
}
602+
603+
if (IsUInt32Like<TFrom>() && typeof(TTo) == typeof(ushort))
604+
{
605+
InvokeSpanIntoSpan_2to1<uint, ushort, NarrowUInt32ToUInt16Operator>(Rename<TFrom, uint>(source), Rename<TTo, ushort>(destination));
606+
return;
607+
}
608+
609+
if (IsInt32Like<TFrom>() && typeof(TTo) == typeof(short))
610+
{
611+
InvokeSpanIntoSpan_2to1<int, short, NarrowInt32ToInt16Operator>(Rename<TFrom, int>(source), Rename<TTo, short>(destination));
612+
return;
613+
}
614+
615+
if (IsUInt64Like<TFrom>() && IsUInt32Like<TTo>())
616+
{
617+
InvokeSpanIntoSpan_2to1<ulong, uint, NarrowUInt64ToUInt32Operator>(Rename<TFrom, ulong>(source), Rename<TTo, uint>(destination));
618+
return;
619+
}
620+
621+
if (IsInt64Like<TFrom>() && IsInt32Like<TTo>())
622+
{
623+
InvokeSpanIntoSpan_2to1<long, int, NarrowInt64ToInt32Operator>(Rename<TFrom, long>(source), Rename<TTo, int>(destination));
624+
return;
625+
}
626+
627+
InvokeSpanIntoSpan<TFrom, TTo, ConvertTruncatingFallbackOperator<TFrom, TTo>>(source, destination);
628+
}
629+
630+
/// <summary>Performs conversions that are the same regardless of checked, truncating, or saturation.</summary>
631+
[MethodImpl(MethodImplOptions.AggressiveInlining)] // at most one of the branches will be kept
632+
private static bool TryConvertUniversal<TFrom, TTo>(ReadOnlySpan<TFrom> source, Span<TTo> destination)
633+
where TFrom : INumberBase<TFrom>
634+
where TTo : INumberBase<TTo>
635+
{
636+
if (typeof(TFrom) == typeof(TTo))
637+
{
638+
if (source.Length > destination.Length)
639+
{
640+
ThrowHelper.ThrowArgument_DestinationTooShort();
641+
}
642+
643+
ValidateInputOutputSpanNonOverlapping(source, Rename<TTo, TFrom>(destination));
644+
645+
source.CopyTo(Rename<TTo, TFrom>(destination));
646+
return true;
647+
}
648+
649+
if (IsInt32Like<TFrom>() && typeof(TTo) == typeof(float))
650+
{
651+
InvokeSpanIntoSpan<int, float, ConvertInt32ToSingle>(Rename<TFrom, int>(source), Rename<TTo, float>(destination));
652+
return true;
653+
}
654+
655+
if (IsUInt32Like<TFrom>() && typeof(TTo) == typeof(float))
656+
{
657+
InvokeSpanIntoSpan<uint, float, ConvertUInt32ToSingle>(Rename<TFrom, uint>(source), Rename<TTo, float>(destination));
658+
return true;
659+
}
660+
661+
if (IsInt64Like<TFrom>() && typeof(TTo) == typeof(double))
662+
{
663+
InvokeSpanIntoSpan<long, double, ConvertInt64ToDouble>(Rename<TFrom, long>(source), Rename<TTo, double>(destination));
664+
return true;
665+
}
666+
667+
if (IsUInt64Like<TFrom>() && typeof(TTo) == typeof(double))
668+
{
669+
InvokeSpanIntoSpan<ulong, double, ConvertUInt64ToDouble>(Rename<TFrom, ulong>(source), Rename<TTo, double>(destination));
670+
return true;
671+
}
672+
673+
if (typeof(TFrom) == typeof(float) && typeof(TTo) == typeof(Half))
674+
{
675+
ConvertToHalf(Rename<TFrom, float>(source), Rename<TTo, Half>(destination));
676+
return true;
677+
}
678+
679+
if (typeof(TFrom) == typeof(Half) && typeof(TTo) == typeof(float))
680+
{
681+
ConvertToSingle(Rename<TFrom, Half>(source), Rename<TTo, float>(destination));
682+
return true;
683+
}
684+
685+
if (typeof(TFrom) == typeof(float) && typeof(TTo) == typeof(double))
686+
{
687+
InvokeSpanIntoSpan_1to2<float, double, WidenSingleToDoubleOperator>(Rename<TFrom, float>(source), Rename<TTo, double>(destination));
688+
return true;
689+
}
690+
691+
if (typeof(TFrom) == typeof(double) && typeof(TTo) == typeof(float))
692+
{
693+
InvokeSpanIntoSpan_2to1<double, float, NarrowDoubleToSingleOperator>(Rename<TFrom, double>(source), Rename<TTo, float>(destination));
694+
return true;
695+
}
696+
697+
if (typeof(TFrom) == typeof(byte) && typeof(TTo) == typeof(ushort))
698+
{
699+
InvokeSpanIntoSpan_1to2<byte, ushort, WidenByteToUInt16Operator>(Rename<TFrom, byte>(source), Rename<TTo, ushort>(destination));
700+
return true;
701+
}
702+
703+
if (typeof(TFrom) == typeof(sbyte) && typeof(TTo) == typeof(short))
704+
{
705+
InvokeSpanIntoSpan_1to2<sbyte, short, WidenSByteToInt16Operator>(Rename<TFrom, sbyte>(source), Rename<TTo, short>(destination));
706+
return true;
707+
}
708+
709+
if (typeof(TFrom) == typeof(ushort) && IsUInt32Like<TTo>())
710+
{
711+
InvokeSpanIntoSpan_1to2<ushort, uint, WidenUInt16ToUInt32Operator>(Rename<TFrom, ushort>(source), Rename<TTo, uint>(destination));
712+
return true;
713+
}
714+
715+
if (typeof(TFrom) == typeof(short) && IsInt32Like<TTo>())
716+
{
717+
InvokeSpanIntoSpan_1to2<short, int, WidenInt16ToInt32Operator>(Rename<TFrom, short>(source), Rename<TTo, int>(destination));
718+
return true;
719+
}
720+
721+
if (IsUInt32Like<TTo>() && IsUInt64Like<TTo>())
722+
{
723+
InvokeSpanIntoSpan_1to2<uint, ulong, WidenUInt32ToUInt64Operator>(Rename<TFrom, uint>(source), Rename<TTo, ulong>(destination));
724+
return true;
725+
}
726+
727+
if (IsInt32Like<TFrom>() && IsInt64Like<TTo>())
728+
{
729+
InvokeSpanIntoSpan_1to2<int, long, WidenInt32ToInt64Operator>(Rename<TFrom, int>(source), Rename<TTo, long>(destination));
730+
return true;
731+
}
732+
733+
return false;
734+
}
735+
491736
/// <summary>Computes the element-wise result of copying the sign from one number to another number in the specified tensors.</summary>
492737
/// <param name="x">The first tensor, represented as a span.</param>
493738
/// <param name="sign">The second tensor, represented as a span.</param>
@@ -963,15 +1208,14 @@ public static void Ieee754Remainder<T>(T x, ReadOnlySpan<T> y, Span<T> destinati
9631208
public static void ILogB<T>(ReadOnlySpan<T> x, Span<int> destination)
9641209
where T : IFloatingPointIeee754<T>
9651210
{
966-
if (x.Length > destination.Length)
1211+
if (typeof(T) == typeof(double))
9671212
{
968-
ThrowHelper.ThrowArgument_DestinationTooShort();
1213+
// Special-case double as the only vectorizable floating-point type whose size != sizeof(int).
1214+
InvokeSpanIntoSpan_2to1<double, int, ILogBDoubleOperator>(Rename<T, double>(x), destination);
9691215
}
970-
971-
// TODO: Vectorize
972-
for (int i = 0; i < x.Length; i++)
1216+
else
9731217
{
974-
destination[i] = T.ILogB(x[i]);
1218+
InvokeSpanIntoSpan<T, int, ILogBOperator<T>>(x, destination);
9751219
}
9761220
}
9771221

0 commit comments

Comments
 (0)