diff --git a/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs b/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs index f82bbb96732c95..292a5eb1038d53 100644 --- a/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs +++ b/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs @@ -9,7 +9,6 @@ namespace System.Formats.Nrbf public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRecord { internal ArrayRecord() { } - public virtual long FlattenedLength { get { throw null; } } public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } } public abstract System.ReadOnlySpan Lengths { get; } public int Rank { get { throw null; } } diff --git a/src/libraries/System.Formats.Nrbf/src/PACKAGE.md b/src/libraries/System.Formats.Nrbf/src/PACKAGE.md index c301459358838b..23e5ac389d3d14 100644 --- a/src/libraries/System.Formats.Nrbf/src/PACKAGE.md +++ b/src/libraries/System.Formats.Nrbf/src/PACKAGE.md @@ -54,7 +54,7 @@ There are more than a dozen different serialization [record types](https://learn - `PrimitiveTypeRecord` derives from the non-generic [PrimitiveTypeRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord), which also exposes a [Value](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord.value) property. But on the base class, the value is returned as `object` (which introduces boxing for value types). - [ClassRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.classrecord): describes all `class` and `struct` besides the aforementioned primitive types. - [ArrayRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.arrayrecord): describes all array records, including jagged and multi-dimensional arrays. -- [`SZArrayRecord`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `ClassRecord`. +- [`SZArrayRecord`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `SerializationRecord`. ```csharp SerializationRecord rootObject = NrbfDecoder.Decode(payload); // payload is a Stream diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs index 063a2430782064..60623ac0dbde3a 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs @@ -28,12 +28,13 @@ internal enum AllowedRecordTypes : uint ArraySingleString = 1 << SerializationRecordType.ArraySingleString, Nulls = ObjectNull | ObjectNullMultiple256 | ObjectNullMultiple, + Arrays = ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray, /// /// Any .NET object (a primitive, a reference type, a reference or single null). /// AnyObject = MemberPrimitiveTyped - | ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray + | Arrays | ClassWithId | ClassWithMembersAndTypes | SystemClassWithMembersAndTypes | BinaryObjectString | MemberReference diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs index 237b7b72a27198..c18208668225f8 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs @@ -4,6 +4,9 @@ using System.Diagnostics.CodeAnalysis; using System.Reflection.Metadata; using System.Formats.Nrbf.Utils; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.Serialization; namespace System.Formats.Nrbf; @@ -27,12 +30,6 @@ private protected ArrayRecord(ArrayInfo arrayInfo) /// A buffer of integers that represent the number of elements in every dimension. public abstract ReadOnlySpan Lengths { get; } - /// - /// When overridden in a derived class, gets the total number of all elements in every dimension. - /// - /// A number that represent the total number of all elements in every dimension. - public virtual long FlattenedLength => ArrayInfo.FlattenedLength; - /// /// Gets the rank of the array. /// @@ -118,4 +115,86 @@ private void HandleNext(object value, NextInfo info, int size) } internal abstract (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType(); + + internal static void Populate(List source, Array destination, int[] lengths, AllowedRecordTypes allowedRecordTypes, bool allowNulls) + { + int[] indices = new int[lengths.Length]; + nuint numElementsWritten = 0; // only for debugging; not used in release builds + + foreach (SerializationRecord record in source) + { + object? value = GetActualValue(record, allowedRecordTypes, out int incrementCount); + if (value is not null) + { + // null is a default element for all array of reference types, so we don't call SetValue for nulls. + destination.SetValue(value, indices); + Debug.Assert(incrementCount == 1, "IncrementCount other than 1 is allowed only for null records."); + } + else if (!allowNulls) + { + ThrowHelper.ThrowArrayContainedNulls(); + } + + while (incrementCount > 0) + { + incrementCount--; + numElementsWritten++; + int dimension = indices.Length - 1; + while (dimension >= 0) + { + indices[dimension]++; + if (indices[dimension] < lengths[dimension]) + { + break; + } + indices[dimension] = 0; + dimension--; + } + + if (dimension < 0) + { + break; + } + } + } + + Debug.Assert(numElementsWritten == (uint)source.Count, "We should have traversed the entirety of the source records collection."); + Debug.Assert(numElementsWritten == (ulong)destination.LongLength, "We should have traversed the entirety of the destination array."); + } + + private static object? GetActualValue(SerializationRecord record, AllowedRecordTypes allowedRecordTypes, out int repeatCount) + { + repeatCount = 1; + + if (record is NullsRecord nullsRecord) + { + repeatCount = nullsRecord.NullCount; + return null; + } + else if (record.RecordType == SerializationRecordType.MemberReference) + { + record = ((MemberReferenceRecord)record).GetReferencedRecord(); + } + + if (allowedRecordTypes == AllowedRecordTypes.BinaryObjectString) + { + if (record is not BinaryObjectStringRecord stringRecord) + { + throw new SerializationException(SR.Serialization_InvalidReference); + } + + return stringRecord.Value; + } + else if (allowedRecordTypes == AllowedRecordTypes.Arrays) + { + if (record is not ArrayRecord arrayRecord) + { + throw new SerializationException(SR.Serialization_InvalidReference); + } + + return arrayRecord; + } + + return record; + } } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRectangularPrimitiveRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRectangularPrimitiveRecord.cs new file mode 100644 index 00000000000000..39c66c5f2af0d9 --- /dev/null +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRectangularPrimitiveRecord.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Formats.Nrbf.Utils; +using System.Linq; +using System.Reflection.Metadata; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; + +namespace System.Formats.Nrbf +{ + internal sealed class ArrayRectangularPrimitiveRecord : ArrayRecord where T : unmanaged + { + private readonly int[] _lengths; + private readonly IReadOnlyList _values; + private TypeName? _typeName; + + internal ArrayRectangularPrimitiveRecord(ArrayInfo arrayInfo, int[] lengths, IReadOnlyList values) : base(arrayInfo) + { + _lengths = lengths; + _values = values; + ValuesToRead = 0; // there is nothing to read anymore + } + + public override ReadOnlySpan Lengths => _lengths; + + public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray; + + public override TypeName TypeName + => _typeName ??= TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.GetPrimitiveType()).MakeArrayTypeName(Rank); + + internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() => throw new InvalidOperationException(); + + private protected override void AddValue(object value) => throw new InvalidOperationException(); + + [RequiresDynamicCode("May call Array.CreateInstance().")] + private protected override Array Deserialize(Type arrayType, bool allowNulls) + { + Array result = +#if NET9_0_OR_GREATER + Array.CreateInstanceFromArrayType(arrayType, _lengths); +#else + Array.CreateInstance(typeof(T), _lengths); +#endif + int[] indices = new int[_lengths.Length]; + nuint numElementsWritten = 0; // only for debugging; not used in release builds + + for (int i = 0; i < _values.Count; i++) + { + result.SetValue(_values[i], indices); + numElementsWritten++; + + int dimension = indices.Length - 1; + while (dimension >= 0) + { + indices[dimension]++; + if (indices[dimension] < Lengths[dimension]) + { + break; + } + indices[dimension] = 0; + dimension--; + } + + if (dimension < 0) + { + break; + } + } + + Debug.Assert(numElementsWritten == (uint)_values.Count, "We should have traversed the entirety of the source values collection."); + Debug.Assert(numElementsWritten == (ulong)result.LongLength, "We should have traversed the entirety of the destination array."); + + return result; + } + } +} diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs index d0276ff3782e3a..2c402af7c35ab0 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs @@ -15,9 +15,9 @@ namespace System.Formats.Nrbf; /// /// ArraySingleObject records are described in [MS-NRBF] 2.4.3.2. /// -internal sealed class ArraySingleObjectRecord : SZArrayRecord +internal sealed class ArraySingleObjectRecord : SZArrayRecord { - private ArraySingleObjectRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = []; + internal ArraySingleObjectRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = []; public override SerializationRecordType RecordType => SerializationRecordType.ArraySingleObject; @@ -27,25 +27,26 @@ public override TypeName TypeName private List Records { get; } /// - public override object?[] GetArray(bool allowNulls = true) - => (object?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false)); + public override SerializationRecord?[] GetArray(bool allowNulls = true) + => (SerializationRecord?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false)); - private object?[] ToArray(bool allowNulls) + private SerializationRecord?[] ToArray(bool allowNulls) { - object?[] values = new object?[Length]; + SerializationRecord?[] values = new SerializationRecord?[Length]; int valueIndex = 0; for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++) { SerializationRecord record = Records[recordIndex]; - int nullCount = record is NullsRecord nullsRecord ? nullsRecord.NullCount : 0; - if (nullCount == 0) + if (record is MemberReferenceRecord referenceRecord) { - // "new object[] { }" is special cased because it allows for storing reference to itself. - values[valueIndex++] = record is MemberReferenceRecord referenceRecord && referenceRecord.Reference.Equals(Id) - ? values // a reference to self, and a way to get StackOverflow exception ;) - : record.GetValue(); + record = referenceRecord.GetReferencedRecord(); + } + + if (record is not NullsRecord nullsRecord) + { + values[valueIndex++] = record; continue; } @@ -54,6 +55,7 @@ public override TypeName TypeName ThrowHelper.ThrowArrayContainedNulls(); } + int nullCount = nullsRecord.NullCount; do { values[valueIndex++] = null; diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs index a13507b97015a0..a28359d9bb13dc 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs @@ -47,6 +47,11 @@ public override T[] GetArray(bool allowNulls = true) internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int count) { + if (count == 0) + { + return Array.Empty(); // Empty arrays are allowed. + } + // For decimals, the input is provided as strings, so we can't compute the required size up-front. if (typeof(T) == typeof(decimal)) { @@ -71,18 +76,15 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c // allocations to be proportional to the amount of data present in the input stream, // which is a sufficient defense against DoS. - long requiredBytes = count; - if (typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan)) - { - // We can't assume DateTime as represented by the runtime is 8 bytes. - // The only assumption we can make is that it's 8 bytes on the wire. - requiredBytes *= 8; - } - else if (typeof(T) != typeof(char)) - { - requiredBytes *= Unsafe.SizeOf(); - } + // We can't assume DateTime as represented by the runtime is 8 bytes. + // The only assumption we can make is that it's 8 bytes on the wire. + int sizeOfT = typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan) + ? 8 + : typeof(T) != typeof(char) + ? Unsafe.SizeOf() + : 1; + long requiredBytes = (long)count * sizeOfT; bool? isDataAvailable = reader.IsDataAvailable(requiredBytes); if (!isDataAvailable.HasValue) { @@ -110,26 +112,49 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c // It's safe to pre-allocate, as we have ensured there is enough bytes in the stream. T[] result = new T[count]; - Span resultAsBytes = MemoryMarshal.AsBytes(result); -#if NET - reader.BaseStream.ReadExactly(resultAsBytes); + + // MemoryMarshal.AsBytes can fail for inputs that need more than int.MaxValue bytes. + // To avoid OverflowException, we read the data in chunks. + int MaxChunkLength = +#if !DEBUG + int.MaxValue / sizeOfT; #else - byte[] bytes = ArrayPool.Shared.Rent((int)Math.Min(requiredBytes, 256_000)); + // Let's use a different value for non-release builds to ensure this code path + // is covered with tests without the need of decoding enormous payloads. + 8_000; +#endif - while (!resultAsBytes.IsEmpty) +#if !NET + byte[] rented = ArrayPool.Shared.Rent((int)Math.Min(requiredBytes, 256_000)); +#endif + + Span valuesToRead = result.AsSpan(); + while (!valuesToRead.IsEmpty) { - int bytesRead = reader.Read(bytes, 0, Math.Min(resultAsBytes.Length, bytes.Length)); - if (bytesRead <= 0) + int sliceSize = Math.Min(valuesToRead.Length, MaxChunkLength); + + Span resultAsBytes = MemoryMarshal.AsBytes(valuesToRead.Slice(0, sliceSize)); +#if NET + reader.BaseStream.ReadExactly(resultAsBytes); +#else + while (!resultAsBytes.IsEmpty) { - ArrayPool.Shared.Return(bytes); - ThrowHelper.ThrowEndOfStreamException(); - } + int bytesRead = reader.Read(rented, 0, Math.Min(resultAsBytes.Length, rented.Length)); + if (bytesRead <= 0) + { + ArrayPool.Shared.Return(rented); + ThrowHelper.ThrowEndOfStreamException(); + } - bytes.AsSpan(0, bytesRead).CopyTo(resultAsBytes); - resultAsBytes = resultAsBytes.Slice(bytesRead); + rented.AsSpan(0, bytesRead).CopyTo(resultAsBytes); + resultAsBytes = resultAsBytes.Slice(bytesRead); + } +#endif + valuesToRead = valuesToRead.Slice(sliceSize); } - ArrayPool.Shared.Return(bytes); +#if !NET + ArrayPool.Shared.Return(rented); #endif if (!BitConverter.IsLittleEndian) @@ -176,7 +201,7 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c { // See DontCastBytesToBooleans test to see what could go wrong. bool[] booleans = (bool[])(object)result; - resultAsBytes = MemoryMarshal.AsBytes(result); + Span resultAsBytes = MemoryMarshal.AsBytes(result); for (int i = 0; i < booleans.Length; i++) { // We don't use the bool array to get the value, as an optimizing compiler or JIT could elide this. diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs index 42b9eadd97bd55..38884aadc54693 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs @@ -17,7 +17,7 @@ namespace System.Formats.Nrbf; /// internal sealed class ArraySingleStringRecord : SZArrayRecord { - private ArraySingleStringRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = []; + internal ArraySingleStringRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = []; public override SerializationRecordType RecordType => SerializationRecordType.ArraySingleString; diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs deleted file mode 100644 index 41b1f73f03550e..00000000000000 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs +++ /dev/null @@ -1,309 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; -using System.IO; -using System.Reflection.Metadata; -using System.Formats.Nrbf.Utils; -using System.Diagnostics; - -namespace System.Formats.Nrbf; - -/// -/// Represents an array other than single dimensional array of primitive types or . -/// -/// -/// BinaryArray records are described in [MS-NRBF] 2.4.3.1. -/// -internal sealed class BinaryArrayRecord : ArrayRecord -{ - private static HashSet PrimitiveTypes { get; } = - [ - typeof(bool), typeof(char), typeof(byte), typeof(sbyte), - typeof(short), typeof(ushort), typeof(int), typeof(uint), - typeof(long), typeof(ulong), typeof(IntPtr), typeof(UIntPtr), - typeof(float), typeof(double), typeof(decimal), typeof(DateTime), - typeof(TimeSpan), typeof(string), typeof(object) - ]; - - private TypeName? _typeName; - private long _totalElementsCount; - - private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo) - : base(arrayInfo) - { - MemberTypeInfo = memberTypeInfo; - Values = []; - // We need to parse all elements of the jagged array to obtain total elements count. - _totalElementsCount = -1; - } - - public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray; - - /// - public override ReadOnlySpan Lengths => new int[1] { Length }; - - /// - public override long FlattenedLength - { - get - { - if (_totalElementsCount < 0) - { - _totalElementsCount = IsJagged - ? GetJaggedArrayFlattenedLength(this) - : ArrayInfo.FlattenedLength; - } - - return _totalElementsCount; - } - } - - public override TypeName TypeName - => _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo); - - private int Length => ArrayInfo.GetSZArrayLength(); - - private MemberTypeInfo MemberTypeInfo { get; } - - private List Values { get; } - - [RequiresDynamicCode("May call Array.CreateInstance() and Type.MakeArrayType().")] - private protected override Array Deserialize(Type arrayType, bool allowNulls) - { - // We can not deserialize non-primitive types. - // This method returns arrays of ClassRecord for arrays of complex types. - Type elementType = MapElementType(arrayType, out bool isClassRecord); - Type actualElementType = arrayType.GetElementType()!; - Array array = -#if NET9_0_OR_GREATER - isClassRecord - ? Array.CreateInstance(elementType, Length) - : Array.CreateInstanceFromArrayType(arrayType, Length); -#else - Array.CreateInstance(elementType, Length); -#endif - - int resultIndex = 0; - foreach (object value in Values) - { - object item = value is MemberReferenceRecord referenceRecord - ? referenceRecord.GetReferencedRecord() - : value; - - if (item is not SerializationRecord record) - { - array.SetValue(item, resultIndex++); - continue; - } - - switch (record.RecordType) - { - case SerializationRecordType.BinaryArray: - case SerializationRecordType.ArraySinglePrimitive: - case SerializationRecordType.ArraySingleObject: - case SerializationRecordType.ArraySingleString: - - // Recursion depth is bounded by the depth of arrayType, which is - // a trustworthy Type instance. Don't need to worry about stack overflow. - - ArrayRecord nestedArrayRecord = (ArrayRecord)record; - Array nestedArray = nestedArrayRecord.GetArray(actualElementType, allowNulls); - array.SetValue(nestedArray, resultIndex++); - break; - case SerializationRecordType.ObjectNull: - case SerializationRecordType.ObjectNullMultiple256: - case SerializationRecordType.ObjectNullMultiple: - if (!allowNulls) - { - ThrowHelper.ThrowArrayContainedNulls(); - } - - int nullCount = ((NullsRecord)item).NullCount; - Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount."); - do - { - array.SetValue(null, resultIndex++); - nullCount--; - } - while (nullCount > 0); - break; - default: - array.SetValue(record.GetValue(), resultIndex++); - break; - } - } - - Debug.Assert(resultIndex == array.Length, "We should have traversed the entirety of the newly created array."); - - return array; - } - - internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, PayloadOptions options) - { - SerializationRecordId objectId = SerializationRecordId.Decode(reader); - BinaryArrayType arrayType = reader.ReadArrayType(); - int rank = reader.ReadInt32(); - - bool isRectangular = arrayType is BinaryArrayType.Rectangular; - - // It is an arbitrary limit in the current CoreCLR type loader. - // Don't change this value without reviewing the loop a few lines below. - const int MaxSupportedArrayRank = 32; - - if (rank < 1 || rank > MaxSupportedArrayRank - || (rank != 1 && !isRectangular) - || (rank == 1 && isRectangular)) - { - ThrowHelper.ThrowInvalidValue(rank); - } - - int[] lengths = new int[rank]; // adversary-controlled, but acceptable since upper limit of 32 - long totalElementCount = 1; // to avoid integer overflow during the multiplication below - for (int i = 0; i < lengths.Length; i++) - { - lengths[i] = ArrayInfo.ParseValidArrayLength(reader); - totalElementCount *= lengths[i]; - - // n.b. This forbids "new T[Array.MaxLength, Array.MaxLength, Array.MaxLength, ..., 0]" - // but allows "new T[0, Array.MaxLength, Array.MaxLength, Array.MaxLength, ...]". But - // that's the same behavior that newarr and Array.CreateInstance exhibit, so at least - // we're consistent. - - if (totalElementCount > ArrayInfo.MaxArrayLength) - { - ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded - } - } - - // Per BinaryReaderExtensions.ReadArrayType, we do not support nonzero offsets, so - // we don't need to read the NRBF stream 'LowerBounds' field here. - - MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, 1, options, recordMap); - ArrayInfo arrayInfo = new(objectId, totalElementCount, arrayType, rank); - - if (isRectangular) - { - return RectangularArrayRecord.Create(reader, arrayInfo, memberTypeInfo, lengths); - } - - return memberTypeInfo.ShouldBeRepresentedAsArrayOfClassRecords() - ? new ArrayOfClassesRecord(arrayInfo, memberTypeInfo) - : new BinaryArrayRecord(arrayInfo, memberTypeInfo); - } - - private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayRecord) - { - long result = 0; - Queue? jaggedArrayRecords = null; - - do - { - if (jaggedArrayRecords is not null) - { - jaggedArrayRecord = jaggedArrayRecords.Dequeue(); - } - - Debug.Assert(jaggedArrayRecord.IsJagged); - - // In theory somebody could create a payload that would represent - // a very nested array with total elements count > long.MaxValue. - // That is why this method is using checked arithmetic. - result = checked(result + jaggedArrayRecord.Length); // count the arrays themselves - - foreach (object value in jaggedArrayRecord.Values) - { - if (value is not SerializationRecord record) - { - continue; - } - - if (record.RecordType == SerializationRecordType.MemberReference) - { - record = ((MemberReferenceRecord)record).GetReferencedRecord(); - } - - switch (record.RecordType) - { - case SerializationRecordType.ArraySinglePrimitive: - case SerializationRecordType.ArraySingleObject: - case SerializationRecordType.ArraySingleString: - case SerializationRecordType.BinaryArray: - ArrayRecord nestedArrayRecord = (ArrayRecord)record; - if (nestedArrayRecord.IsJagged) - { - (jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord); - } - else - { - // Don't call nestedArrayRecord.FlattenedLength to avoid any potential recursion, - // just call nestedArrayRecord.ArrayInfo.FlattenedLength that returns pre-computed value. - result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength); - } - break; - default: - break; - } - } - } - while (jaggedArrayRecords is not null && jaggedArrayRecords.Count > 0); - - return result; - } - - private protected override void AddValue(object value) => Values.Add(value); - - internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() - { - (AllowedRecordTypes allowed, PrimitiveType primitiveType) = MemberTypeInfo.GetNextAllowedRecordType(0); - - if (allowed != AllowedRecordTypes.None) - { - // It's an array, it can also contain multiple nulls - return (allowed | AllowedRecordTypes.Nulls, primitiveType); - } - - return (allowed, primitiveType); - } - - /// - /// Complex types must not be instantiated, but represented as ClassRecord. - /// For arrays of primitive types like int, string and object this method returns the element type. - /// For array of complex types, it returns ClassRecord. - /// It takes arrays of arrays into account: - /// - int[][] => int[] - /// - MyClass[][][] => ClassRecord[][] - /// - [RequiresDynamicCode("May call Type.MakeArrayType().")] - private static Type MapElementType(Type arrayType, out bool isClassRecord) - { - Type elementType = arrayType; - int arrayNestingDepth = 0; - - // Loop iteration counts are bound by the nesting depth of arrayType, - // which is a trustworthy input. No DoS concerns. - - while (elementType.IsArray) - { - elementType = elementType.GetElementType()!; - arrayNestingDepth++; - } - - if (PrimitiveTypes.Contains(elementType) || (Nullable.GetUnderlyingType(elementType) is Type nullable && PrimitiveTypes.Contains(nullable))) - { - isClassRecord = false; - return arrayNestingDepth == 1 ? elementType : arrayType.GetElementType()!; - } - - // Complex types are never instantiated, but represented as ClassRecord - isClassRecord = true; - Type complexType = typeof(ClassRecord); - for (int i = 1; i < arrayNestingDepth; i++) - { - complexType = complexType.MakeArrayType(); - } - - return complexType; - } -} diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs index c643d3ce8c8465..2762be167b1112 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Formats.Nrbf.Utils; using System.IO; using System.Runtime.Serialization; @@ -27,16 +28,57 @@ private ClassWithIdRecord(SerializationRecordId id, ClassRecord metadataClass) : internal ClassRecord MetadataClass { get; } - internal static ClassWithIdRecord Decode( + internal static SerializationRecord Decode( BinaryReader reader, RecordMap recordMap) { SerializationRecordId id = SerializationRecordId.Decode(reader); SerializationRecordId metadataId = SerializationRecordId.Decode(reader); - ClassRecord referencedRecord = recordMap.GetRecord(metadataId); + SerializationRecord metadataRecord = recordMap.GetRecord(metadataId); + if (metadataRecord is ClassRecord referencedClassRecord) + { + return new ClassWithIdRecord(id, referencedClassRecord); + } + else if (metadataRecord is PrimitiveTypeRecord primitiveTypeRecord + && !primitiveTypeRecord.Id.Equals(default) // such records always have Id provided + && metadataRecord is not BinaryObjectStringRecord) // it does not apply to BinaryObjectStringRecord + { + // BinaryFormatter represents primitive types as MemberPrimitiveTypedRecord + // only for arrays of objects. For other arrays, like arrays of some abstraction + // (example: new IComparable[] { int.MaxValue }), it uses SystemClassWithMembersAndTypes. + // SystemClassWithMembersAndTypes.Decode handles that by returning MemberPrimitiveTypedRecord. + // But arrays of such types typically have only one SystemClassWithMembersAndTypes record with + // all the member information and multiple ClassWithIdRecord records that just reuse that information. + return primitiveTypeRecord switch + { + MemberPrimitiveTypedRecord => Create(reader.ReadBoolean()), + MemberPrimitiveTypedRecord => Create(reader.ReadByte()), + MemberPrimitiveTypedRecord => Create(reader.ReadSByte()), + MemberPrimitiveTypedRecord => Create(reader.ParseChar()), + MemberPrimitiveTypedRecord => Create(reader.ReadInt16()), + MemberPrimitiveTypedRecord => Create(reader.ReadUInt16()), + MemberPrimitiveTypedRecord => Create(reader.ReadInt32()), + MemberPrimitiveTypedRecord => Create(reader.ReadUInt32()), + MemberPrimitiveTypedRecord => Create(reader.ReadInt64()), + MemberPrimitiveTypedRecord => Create(reader.ReadUInt64()), + MemberPrimitiveTypedRecord => Create(reader.ReadSingle()), + MemberPrimitiveTypedRecord => Create(reader.ReadDouble()), + MemberPrimitiveTypedRecord => Create(new IntPtr(reader.ReadInt64())), + MemberPrimitiveTypedRecord => Create(new UIntPtr(reader.ReadUInt64())), + MemberPrimitiveTypedRecord => Create(new TimeSpan(reader.ReadInt64())), + MemberPrimitiveTypedRecord => SystemClassWithMembersAndTypesRecord.DecodeDateTime(reader, id), + MemberPrimitiveTypedRecord => SystemClassWithMembersAndTypesRecord.DecodeDecimal(reader, id), + _ => throw new InvalidOperationException() + }; + } + else + { + throw new SerializationException(SR.Serialization_InvalidReference); + } - return new ClassWithIdRecord(id, referencedRecord); + SerializationRecord Create(T value) where T : unmanaged + => new MemberPrimitiveTypedRecord(value, id); } internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetNextAllowedRecordType() diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/JaggedArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/JaggedArrayRecord.cs new file mode 100644 index 00000000000000..6ac97ef40675d6 --- /dev/null +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/JaggedArrayRecord.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Reflection.Metadata; +using System.Formats.Nrbf.Utils; +using System.Diagnostics; +using System.Runtime.Serialization; + +namespace System.Formats.Nrbf; + +/// +/// Represents an array of arrays. +/// +/// +/// BinaryArray records are described in [MS-NRBF] 2.4.3.1. +/// +internal sealed class JaggedArrayRecord : ArrayRecord +{ + private readonly MemberTypeInfo _memberTypeInfo; + private readonly int[] _lengths; + private readonly List _records; + private readonly AllowedRecordTypes _allowedRecordTypes; + private TypeName? _typeName; + + internal JaggedArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo, int[] lengths) + : base(arrayInfo) + { + _memberTypeInfo = memberTypeInfo; + _lengths = lengths; + _records = []; + _allowedRecordTypes = memberTypeInfo.GetNextAllowedRecordType(0).allowed; + + Debug.Assert(TypeName.GetElementType().IsArray, "Jagged arrays are required."); + } + + public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray; + + public override ReadOnlySpan Lengths => _lengths; + + public override TypeName TypeName => _typeName ??= _memberTypeInfo.GetArrayTypeName(ArrayInfo); + + [RequiresDynamicCode("May call Array.CreateInstance().")] + private protected override Array Deserialize(Type arrayType, bool allowNulls) + { + // This method returns arrays of ArrayRecords. + Array array = _lengths.Length switch + { + 1 => new ArrayRecord[_lengths[0]], + 2 => new ArrayRecord[_lengths[0], _lengths[1]], + _ => Array.CreateInstance(typeof(ArrayRecord), _lengths) + }; + + Populate(_records, array, _lengths, AllowedRecordTypes.Arrays, allowNulls); + + return array; + } + + private protected override void AddValue(object value) => _records.Add((SerializationRecord)value); + + internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() + => (_allowedRecordTypes, default); +} diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs index 57e47a02eec688..84c1073b0ef67a 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs @@ -86,10 +86,12 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt // Every class can be a null or a reference and a ClassWithId const AllowedRecordTypes Classes = AllowedRecordTypes.ClassWithId | AllowedRecordTypes.ObjectNull | AllowedRecordTypes.MemberReference - | AllowedRecordTypes.MemberPrimitiveTyped | AllowedRecordTypes.BinaryLibrary; // Classes may be preceded with a library record (System too!) // but System Classes can be expressed only by System records - const AllowedRecordTypes SystemClass = Classes | AllowedRecordTypes.SystemClassWithMembersAndTypes; + const AllowedRecordTypes SystemClass = Classes | AllowedRecordTypes.SystemClassWithMembersAndTypes + // All primitive types can be stored by using one of the interfaces they implement. + // Example: `new IEnumerable[1] { "hello" }` or `new IComparable[1] { int.MaxValue }`. + | AllowedRecordTypes.BinaryObjectString | AllowedRecordTypes.MemberPrimitiveTyped; const AllowedRecordTypes NonSystemClass = Classes | AllowedRecordTypes.ClassWithMembersAndTypes; return binaryType switch @@ -106,43 +108,6 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt }; } - internal bool ShouldBeRepresentedAsArrayOfClassRecords() - { - // This library tries to minimize the number of concepts the users need to learn to use it. - // Since SZArrays are most common, it provides an SZArrayRecord abstraction. - // Every other array (jagged, multi-dimensional etc) is represented using ArrayRecord. - // The goal of this method is to determine whether given array can be represented as SZArrayRecord. - - (BinaryType binaryType, object? additionalInfo) = Infos[0]; - - if (binaryType == BinaryType.Class) - { - // An array of arrays can not be represented as SZArrayRecord. - return !((ClassTypeInfo)additionalInfo!).TypeName.IsArray; - } - else if (binaryType == BinaryType.SystemClass) - { - TypeName typeName = (TypeName)additionalInfo!; - - // An array of arrays can not be represented as SZArrayRecord. - if (typeName.IsArray) - { - return false; - } - - if (!typeName.IsConstructedGenericType) - { - return true; - } - - // Can't use SZArrayRecord for Nullable[] - // as it consists of MemberPrimitiveTypedRecord and NullsRecord - return typeName.GetGenericTypeDefinition().FullName != typeof(Nullable<>).FullName; - } - - return false; - } - internal TypeName GetArrayTypeName(ArrayInfo arrayInfo) { (BinaryType binaryType, object? additionalInfo) = Infos[0]; diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs index a315b37cff0234..65bb6e8beb9c5c 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs @@ -9,6 +9,7 @@ using System.Text; using System.Runtime.Serialization; using System.Runtime.InteropServices; +using System.Reflection.Metadata; namespace System.Formats.Nrbf; @@ -223,7 +224,7 @@ private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap rec SerializationRecordType.ArraySingleObject => ArraySingleObjectRecord.Decode(reader), SerializationRecordType.ArraySinglePrimitive => DecodeArraySinglePrimitiveRecord(reader), SerializationRecordType.ArraySingleString => ArraySingleStringRecord.Decode(reader), - SerializationRecordType.BinaryArray => BinaryArrayRecord.Decode(reader, recordMap, options), + SerializationRecordType.BinaryArray => DecodeBinaryArrayRecord(reader, recordMap, options), SerializationRecordType.BinaryLibrary => BinaryLibraryRecord.Decode(reader, options), SerializationRecordType.BinaryObjectString => BinaryObjectStringRecord.Decode(reader), SerializationRecordType.ClassWithId => ClassWithIdRecord.Decode(reader, recordMap), @@ -269,11 +270,16 @@ private static SerializationRecord DecodeMemberPrimitiveTypedRecord(BinaryReader }; } - private static SerializationRecord DecodeArraySinglePrimitiveRecord(BinaryReader reader) + private static ArrayRecord DecodeArraySinglePrimitiveRecord(BinaryReader reader) { ArrayInfo info = ArrayInfo.Decode(reader); PrimitiveType primitiveType = reader.ReadPrimitiveType(); + return DecodeArraySinglePrimitiveRecord(reader, info, primitiveType); + } + + private static ArrayRecord DecodeArraySinglePrimitiveRecord(BinaryReader reader, ArrayInfo info, PrimitiveType primitiveType) + { return primitiveType switch { PrimitiveType.Boolean => Decode(info, reader), @@ -294,10 +300,171 @@ private static SerializationRecord DecodeArraySinglePrimitiveRecord(BinaryReader _ => throw new InvalidOperationException() }; - static SerializationRecord Decode(ArrayInfo info, BinaryReader reader) where T : unmanaged + static ArrayRecord Decode(ArrayInfo info, BinaryReader reader) where T : unmanaged => new ArraySinglePrimitiveRecord(info, ArraySinglePrimitiveRecord.DecodePrimitiveTypes(reader, info.GetSZArrayLength())); } + private static ArrayRecord DecodeArrayRectangularPrimitiveRecord(PrimitiveType primitiveType, ArrayInfo info, int[] lengths, BinaryReader reader) + { + return primitiveType switch + { + PrimitiveType.Boolean => Decode(info, lengths, reader), + PrimitiveType.Byte => Decode(info, lengths, reader), + PrimitiveType.SByte => Decode(info, lengths, reader), + PrimitiveType.Char => Decode(info, lengths, reader), + PrimitiveType.Int16 => Decode(info, lengths, reader), + PrimitiveType.UInt16 => Decode(info, lengths, reader), + PrimitiveType.Int32 => Decode(info, lengths, reader), + PrimitiveType.UInt32 => Decode(info, lengths, reader), + PrimitiveType.Int64 => Decode(info, lengths, reader), + PrimitiveType.UInt64 => Decode(info, lengths, reader), + PrimitiveType.Single => Decode(info, lengths, reader), + PrimitiveType.Double => Decode(info, lengths, reader), + PrimitiveType.Decimal => Decode(info, lengths, reader), + PrimitiveType.DateTime => Decode(info, lengths, reader), + PrimitiveType.TimeSpan => Decode(info, lengths, reader), + _ => throw new InvalidOperationException() + }; + + static ArrayRecord Decode(ArrayInfo info, int[] lengths, BinaryReader reader) where T : unmanaged + { + // We limit the length of multi-dimensional array to max length of SZArray. + // Because of that, it's possible to re-use the same decoding logic for both MD and SZ arrays. + IReadOnlyList values = ArraySinglePrimitiveRecord.DecodePrimitiveTypes(reader, info.GetSZArrayLength()); + return new ArrayRectangularPrimitiveRecord(info, lengths, values); + } + } + + private static ArrayRecord DecodeBinaryArrayRecord(BinaryReader reader, RecordMap recordMap, PayloadOptions options) + { + SerializationRecordId objectId = SerializationRecordId.Decode(reader); + BinaryArrayType arrayType = reader.ReadArrayType(); + int rank = reader.ReadInt32(); + + bool isRectangular = arrayType is BinaryArrayType.Rectangular; + + // It is an arbitrary limit in the current CoreCLR type loader. + // Don't change this value without reviewing the loop a few lines below. + const int MaxSupportedArrayRank = 32; + + if (rank < 1 || rank > MaxSupportedArrayRank + || (rank != 1 && !isRectangular) + || (rank == 1 && isRectangular)) + { + ThrowHelper.ThrowInvalidValue(rank); + } + + int[] lengths = new int[rank]; // adversary-controlled, but acceptable since upper limit of 32 + long totalElementCount = 1; // to avoid integer overflow during the multiplication below + for (int i = 0; i < lengths.Length; i++) + { + lengths[i] = ArrayInfo.ParseValidArrayLength(reader); + totalElementCount *= lengths[i]; + + // n.b. This forbids "new T[Array.MaxLength, Array.MaxLength, Array.MaxLength, ..., 0]" + // but allows "new T[0, Array.MaxLength, Array.MaxLength, Array.MaxLength, ...]". But + // that's the same behavior that newarr and Array.CreateInstance exhibit, so at least + // we're consistent. + + if (totalElementCount > ArrayInfo.MaxArrayLength) + { + ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded + } + } + + // Per BinaryReaderExtensions.ReadArrayType, we do not support nonzero offsets, so + // we don't need to read the NRBF stream 'LowerBounds' field here. + + MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, 1, options, recordMap); + ArrayInfo arrayInfo = new(objectId, totalElementCount, arrayType, rank); + + (BinaryType binaryType, object? additionalInfo) = memberTypeInfo.Infos[0]; + if (arrayType == BinaryArrayType.Rectangular) + { + if (binaryType == BinaryType.Primitive) + { + return DecodeArrayRectangularPrimitiveRecord((PrimitiveType)additionalInfo!, arrayInfo, lengths, reader); + } + else if (binaryType == BinaryType.String) + { + return new RectangularArrayRecord(typeof(string), arrayInfo, memberTypeInfo, lengths); + } + else if (binaryType == BinaryType.Object) + { + return new RectangularArrayRecord(typeof(SerializationRecord), arrayInfo, memberTypeInfo, lengths); + } + else if (binaryType is BinaryType.SystemClass or BinaryType.Class) + { + TypeName typeName = binaryType == BinaryType.SystemClass ? (TypeName)additionalInfo! : ((ClassTypeInfo)additionalInfo!).TypeName; + // BinaryArrayType.Rectangular can be also a jagged array. + return typeName.IsArray + ? new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths) + : new RectangularArrayRecord(typeof(SerializationRecord), arrayInfo, memberTypeInfo, lengths); + } + else if (binaryType is BinaryType.PrimitiveArray or BinaryType.StringArray or BinaryType.ObjectArray) + { + // A multi-dimensional array of single dimensional arrays. Example: int[][,] + return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths); + } + } + else if (arrayType == BinaryArrayType.Single) + { + if (binaryType is BinaryType.SystemClass or BinaryType.Class) + { + TypeName typeName = binaryType == BinaryType.SystemClass ? (TypeName)additionalInfo! : ((ClassTypeInfo)additionalInfo!).TypeName; + // BinaryArrayType.Single that describes an array is just a jagged array. + return typeName.IsArray + ? new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths) + : new SZArrayOfRecords(arrayInfo, memberTypeInfo); + } + else if (binaryType == BinaryType.String) + { + // BinaryArrayRecord can represent string[] (but BF always uses ArraySingleStringRecord for that). + return new ArraySingleStringRecord(arrayInfo); + } + else if (binaryType == BinaryType.Primitive) + { + // BinaryArrayRecord can represent Primitive[] (but BF always uses ArraySinglePrimitiveRecord for that). + return DecodeArraySinglePrimitiveRecord(reader, arrayInfo, (PrimitiveType)additionalInfo!); + } + else if (binaryType == BinaryType.Object) + { + // BinaryArrayRecord can represent object[] (but BF always uses ArraySingleObjectRecord for that). + return new ArraySingleObjectRecord(arrayInfo); + } + else if (binaryType is BinaryType.ObjectArray or BinaryType.StringArray or BinaryType.PrimitiveArray) + { + // It's a Jagged array that does not use BinaryArrayType.Jagged. + return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths); + } + } + else if (arrayType == BinaryArrayType.Jagged) + { + if (binaryType is BinaryType.ObjectArray or BinaryType.StringArray or BinaryType.PrimitiveArray) + { + // It's a Jagged array that does not use BinaryArrayType.Jagged. + return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths); + } + else if (binaryType == BinaryType.SystemClass && ((TypeName)additionalInfo!).IsArray) + { + // BinaryType.SystemClass can be used to describe arrays of system class records. + // Example: new Exception[] { new Exception("test") }; + return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths); + } + else if (binaryType == BinaryType.Class && ((ClassTypeInfo)additionalInfo!).TypeName.IsArray) + { + // BinaryType.Class can be used to describe arrays of class records. + // Example: new MyCustomType[] { new MyCustomType(0) }; + return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths); + } + + // It's invalid, the element type must be an array. + throw new SerializationException(SR.Format(SR.Serialization_InvalidValue, binaryType)); + } + + throw new InvalidOperationException(); + } + /// /// This method is responsible for pushing only the FIRST read info /// of the NESTED record into the . diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs index eafcbf93249c57..dd5862c7b2b862 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs @@ -61,18 +61,7 @@ internal void Add(SerializationRecord record) } } - internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header) - { - SerializationRecord rootRecord = GetRecord(header.RootId); - - if (rootRecord is SystemClassWithMembersAndTypesRecord systemClass) - { - // update the record map, so it's visible also to those who access it via Id - _map[header.RootId] = rootRecord = systemClass.TryToMapToUserFriendly(); - } - - return rootRecord; - } + internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header) => GetRecord(header.RootId); internal SerializationRecord GetRecord(SerializationRecordId recordId) => _map.TryGetValue(recordId, out SerializationRecord? record) diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs index f64dde36163d69..f10bc3f51efdae 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs @@ -9,24 +9,29 @@ using System.Runtime.InteropServices; using System.Formats.Nrbf.Utils; using System.Diagnostics; +using System.Runtime.Serialization; namespace System.Formats.Nrbf; internal sealed class RectangularArrayRecord : ArrayRecord { + private readonly Type _elementType; private readonly int[] _lengths; - private readonly List _values; + private readonly List _records; + private readonly AllowedRecordTypes _allowedRecordTypes; + private readonly MemberTypeInfo _memberTypeInfo; private TypeName? _typeName; - private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo, - MemberTypeInfo memberTypeInfo, int[] lengths, bool canPreAllocate) : base(arrayInfo) + internal RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo, int[] lengths) : base(arrayInfo) { - ElementType = elementType; - MemberTypeInfo = memberTypeInfo; + _elementType = elementType; _lengths = lengths; + _memberTypeInfo = memberTypeInfo; + _records = new List(Math.Min(4, arrayInfo.GetSZArrayLength())); + _allowedRecordTypes = memberTypeInfo.GetNextAllowedRecordType(0).allowed; - // ArrayInfo.GetSZArrayLength ensures to return a value <= Array.MaxLength - _values = new List(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength())); + Debug.Assert(elementType == typeof(string) || elementType == typeof(SerializationRecord)); + Debug.Assert(!TypeName.GetElementType().IsArray, "Use JaggedArrayRecord instead."); } public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray; @@ -34,230 +39,32 @@ private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo, public override ReadOnlySpan Lengths => _lengths.AsSpan(); public override TypeName TypeName - => _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo); - - private Type ElementType { get; } - - private MemberTypeInfo MemberTypeInfo { get; } + => _typeName ??= _memberTypeInfo.GetArrayTypeName(ArrayInfo); [RequiresDynamicCode("May call Array.CreateInstance() and Type.MakeArrayType().")] private protected override Array Deserialize(Type arrayType, bool allowNulls) { - // We can not deserialize non-primitive types. - // This method returns arrays of ClassRecord for arrays of complex types. + bool storeStrings = _elementType == typeof(string); + + // We can not deserialize non-string types. + // This method returns arrays of SerializationRecord for arrays of complex types. Array result = #if NET9_0_OR_GREATER - ElementType == typeof(ClassRecord) - ? Array.CreateInstance(ElementType, _lengths) - : Array.CreateInstanceFromArrayType(arrayType, _lengths); + storeStrings + ? Array.CreateInstanceFromArrayType(arrayType, _lengths) + : Array.CreateInstance(_elementType, _lengths); #else - Array.CreateInstance(ElementType, _lengths); + Array.CreateInstance(_elementType, _lengths); #endif -#if !NET8_0_OR_GREATER - int[] indices = new int[_lengths.Length]; - nuint numElementsWritten = 0; // only for debugging; not used in release builds - - foreach (object value in _values) - { - result.SetValue(GetActualValue(value), indices); - numElementsWritten++; - - int dimension = indices.Length - 1; - while (dimension >= 0) - { - indices[dimension]++; - if (indices[dimension] < Lengths[dimension]) - { - break; - } - indices[dimension] = 0; - dimension--; - } - - if (dimension < 0) - { - break; - } - } - - Debug.Assert(numElementsWritten == (uint)_values.Count, "We should have traversed the entirety of the source values collection."); - Debug.Assert(numElementsWritten == (ulong)result.LongLength, "We should have traversed the entirety of the destination array."); + AllowedRecordTypes allowedRecordTypes = storeStrings ? AllowedRecordTypes.BinaryObjectString : AllowedRecordTypes.AnyObject; + Populate(_records, result, _lengths, allowedRecordTypes, allowNulls); return result; -#else - // Idea from Array.CoreCLR that maps an array of int indices into - // an internal flat index. - if (ElementType.IsValueType) - { - if (ElementType == typeof(bool)) CopyTo(_values, result); - else if (ElementType == typeof(byte)) CopyTo(_values, result); - else if (ElementType == typeof(sbyte)) CopyTo(_values, result); - else if (ElementType == typeof(short)) CopyTo(_values, result); - else if (ElementType == typeof(ushort)) CopyTo(_values, result); - else if (ElementType == typeof(char)) CopyTo(_values, result); - else if (ElementType == typeof(int)) CopyTo(_values, result); - else if (ElementType == typeof(float)) CopyTo(_values, result); - else if (ElementType == typeof(long)) CopyTo(_values, result); - else if (ElementType == typeof(ulong)) CopyTo(_values, result); - else if (ElementType == typeof(double)) CopyTo(_values, result); - else if (ElementType == typeof(TimeSpan)) CopyTo(_values, result); - else if (ElementType == typeof(DateTime)) CopyTo(_values, result); - else if (ElementType == typeof(decimal)) CopyTo(_values, result); - else throw new InvalidOperationException(); - } - else - { - CopyTo(_values, result); - } - - return result; - - static void CopyTo(List list, Array array) - { - ref byte arrayDataRef = ref MemoryMarshal.GetArrayDataReference(array); - ref T firstElementRef = ref Unsafe.As(ref arrayDataRef); - nuint flattenedIndex = 0; - foreach (object value in list) - { - ref T targetElement = ref Unsafe.Add(ref firstElementRef, flattenedIndex); - targetElement = (T)GetActualValue(value)!; - flattenedIndex++; - } - - Debug.Assert(flattenedIndex == (ulong)array.LongLength, "We should have traversed the entirety of the array."); - } -#endif } - private protected override void AddValue(object value) => _values.Add(value); + private protected override void AddValue(object value) => _records.Add((SerializationRecord)value); internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() - { - (AllowedRecordTypes allowed, PrimitiveType primitiveType) = MemberTypeInfo.GetNextAllowedRecordType(0); - - if (allowed != AllowedRecordTypes.None) - { - // It's an array, it can also contain multiple nulls - return (allowed | AllowedRecordTypes.Nulls, primitiveType); - } - - return (allowed, primitiveType); - } - - internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arrayInfo, - MemberTypeInfo memberTypeInfo, int[] lengths) - { - BinaryType binaryType = memberTypeInfo.Infos[0].BinaryType; - Type elementType = binaryType switch - { - BinaryType.Primitive => MapPrimitive((PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo!), - BinaryType.PrimitiveArray => MapPrimitiveArray((PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo!), - BinaryType.String => typeof(string), - BinaryType.Object => typeof(object), - _ => typeof(ClassRecord) - }; - - bool canPreAllocate = false; - if (binaryType == BinaryType.Primitive) - { - int sizeOfSingleValue = (PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo! switch - { - PrimitiveType.Boolean => sizeof(bool), - PrimitiveType.Byte => sizeof(byte), - PrimitiveType.SByte => sizeof(sbyte), - PrimitiveType.Char => sizeof(byte), // it's UTF8 (see comment below) - PrimitiveType.Int16 => sizeof(short), - PrimitiveType.UInt16 => sizeof(ushort), - PrimitiveType.Int32 => sizeof(int), - PrimitiveType.UInt32 => sizeof(uint), - PrimitiveType.Single => sizeof(float), - PrimitiveType.Int64 => sizeof(long), - PrimitiveType.UInt64 => sizeof(ulong), - PrimitiveType.Double => sizeof(double), - PrimitiveType.TimeSpan => sizeof(ulong), - PrimitiveType.DateTime => sizeof(ulong), - PrimitiveType.Decimal => -1, // represented as variable-length string - _ => throw new InvalidOperationException() - }; - - if (sizeOfSingleValue > 0) - { - // NRBF encodes rectangular char[,,,...] by converting each standalone UTF-16 code point into - // its UTF-8 encoding. This means that surrogate code points (including adjacent surrogate - // pairs) occurring within a char[,,,...] cannot be encoded by NRBF. BinaryReader will detect - // that they're ill-formed and reject them on read. - // - // Per the comment in ArraySinglePrimitiveRecord.DecodePrimitiveTypes, we'll assume best-case - // encoding where 1 UTF-16 char encodes as a single UTF-8 byte, even though this might lead - // to encountering an EOF if we realize later that we actually need to read more bytes in - // order to fully populate the char[,,,...] array. Any such allocation is still linearly - // proportional to the length of the incoming payload, so it's not a DoS vector. - // The multiplication below is guaranteed not to overflow because FlattenedLength is bounded - // to <= Array.MaxLength (see BinaryArrayRecord.Decode) and sizeOfSingleValue is at most 8. - Debug.Assert(arrayInfo.FlattenedLength >= 0 && arrayInfo.FlattenedLength <= long.MaxValue / sizeOfSingleValue); - - long size = arrayInfo.FlattenedLength * sizeOfSingleValue; - bool? isDataAvailable = reader.IsDataAvailable(size); - if (isDataAvailable.HasValue) - { - if (!isDataAvailable.Value) - { - ThrowHelper.ThrowEndOfStreamException(); - } - - canPreAllocate = true; - } - } - } - - return new RectangularArrayRecord(elementType, arrayInfo, memberTypeInfo, lengths, canPreAllocate); - } - - private static Type MapPrimitive(PrimitiveType primitiveType) - => primitiveType switch - { - PrimitiveType.Boolean => typeof(bool), - PrimitiveType.Byte => typeof(byte), - PrimitiveType.Char => typeof(char), - PrimitiveType.Decimal => typeof(decimal), - PrimitiveType.Double => typeof(double), - PrimitiveType.Int16 => typeof(short), - PrimitiveType.Int32 => typeof(int), - PrimitiveType.Int64 => typeof(long), - PrimitiveType.SByte => typeof(sbyte), - PrimitiveType.Single => typeof(float), - PrimitiveType.TimeSpan => typeof(TimeSpan), - PrimitiveType.DateTime => typeof(DateTime), - PrimitiveType.UInt16 => typeof(ushort), - PrimitiveType.UInt32 => typeof(uint), - PrimitiveType.UInt64 => typeof(ulong), - _ => throw new InvalidOperationException() - }; - - private static Type MapPrimitiveArray(PrimitiveType primitiveType) - => primitiveType switch - { - PrimitiveType.Boolean => typeof(bool[]), - PrimitiveType.Byte => typeof(byte[]), - PrimitiveType.Char => typeof(char[]), - PrimitiveType.Decimal => typeof(decimal[]), - PrimitiveType.Double => typeof(double[]), - PrimitiveType.Int16 => typeof(short[]), - PrimitiveType.Int32 => typeof(int[]), - PrimitiveType.Int64 => typeof(long[]), - PrimitiveType.SByte => typeof(sbyte[]), - PrimitiveType.Single => typeof(float[]), - PrimitiveType.TimeSpan => typeof(TimeSpan[]), - PrimitiveType.DateTime => typeof(DateTime[]), - PrimitiveType.UInt16 => typeof(ushort[]), - PrimitiveType.UInt32 => typeof(uint[]), - PrimitiveType.UInt64 => typeof(ulong[]), - _ => throw new InvalidOperationException() - }; - - private static object? GetActualValue(object value) - => value is SerializationRecord serializationRecord - ? serializationRecord.GetValue() - : value; // it must be a primitive type + => (_allowedRecordTypes, default); } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SZArrayOfRecords.cs similarity index 69% rename from src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs rename to src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SZArrayOfRecords.cs index f345292c693a61..b77a4a57a2a348 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SZArrayOfRecords.cs @@ -8,11 +8,15 @@ namespace System.Formats.Nrbf; -internal sealed class ArrayOfClassesRecord : SZArrayRecord +// This library tries to minimize the number of concepts the users need to learn to use it. +// Since SZArrays are most common, it provides an SZArrayRecord abstraction. +// Every other array (jagged, multi-dimensional etc) is represented using ArrayRecord. +// The goal of this class is to let the users use SZArrayRecord abstraction. +internal sealed class SZArrayOfRecords : SZArrayRecord { private TypeName? _typeName; - internal ArrayOfClassesRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo) + internal SZArrayOfRecords(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo) : base(arrayInfo) { MemberTypeInfo = memberTypeInfo; @@ -29,12 +33,12 @@ public override TypeName TypeName => _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo); /// - public override ClassRecord?[] GetArray(bool allowNulls = true) - => (ClassRecord?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false)); + public override SerializationRecord?[] GetArray(bool allowNulls = true) + => (SerializationRecord?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false)); - private ClassRecord?[] ToArray(bool allowNulls) + private SerializationRecord?[] ToArray(bool allowNulls) { - ClassRecord?[] result = new ClassRecord?[Length]; + SerializationRecord?[] result = new SerializationRecord?[Length]; int resultIndex = 0; foreach (SerializationRecord record in Records) @@ -43,9 +47,9 @@ public override TypeName TypeName ? referenceRecord.GetReferencedRecord() : record; - if (actual is ClassRecord classRecord) + if (actual is not NullsRecord nullsRecord) { - result[resultIndex++] = classRecord; + result[resultIndex++] = actual; } else { @@ -54,7 +58,7 @@ public override TypeName TypeName ThrowHelper.ThrowArrayContainedNulls(); } - int nullCount = ((NullsRecord)actual).NullCount; + int nullCount = nullsRecord.NullCount; Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount."); do { diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs index ccecc2246e8c22..0c5193cd92272a 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs @@ -3,6 +3,7 @@ using System.IO; using System.Formats.Nrbf.Utils; +using System.Reflection.Metadata; namespace System.Formats.Nrbf; @@ -21,92 +22,100 @@ private SystemClassWithMembersAndTypesRecord(ClassInfo classInfo, MemberTypeInfo public override SerializationRecordType RecordType => SerializationRecordType.SystemClassWithMembersAndTypes; - internal static SystemClassWithMembersAndTypesRecord Decode(BinaryReader reader, RecordMap recordMap, PayloadOptions options) + internal static SerializationRecord Decode(BinaryReader reader, RecordMap recordMap, PayloadOptions options) { ClassInfo classInfo = ClassInfo.Decode(reader); MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, classInfo.MemberNames.Count, options, recordMap); // the only difference with ClassWithMembersAndTypesRecord is that we don't read library id here classInfo.LoadTypeName(options); - return new SystemClassWithMembersAndTypesRecord(classInfo, memberTypeInfo); - } + TypeName typeName = classInfo.TypeName; - internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetNextAllowedRecordType() - => MemberTypeInfo.GetNextAllowedRecordType(MemberValues.Count); + // BinaryFormatter represents primitive types as MemberPrimitiveTypedRecord + // only for arrays of objects. For other arrays, like arrays of some abstraction + // (example: new IComparable[] { int.MaxValue }), it uses SystemClassWithMembersAndTypes. + // The same goes for root records that turn out to be primitive types. + // We want to have the behavior unified, so we map such records to + // PrimitiveTypeRecord so the users don't need to learn the BF internals + // to get a single primitive value. + // We need to be as strict as possible, as we don't want to map anything else by accident. + // That is why the code below is VERY defensive. - // For the root records that turn out to be primitive types, we map them to - // PrimitiveTypeRecord so the users don't need to learn the BF internals - // to get a single primitive value! - internal SerializationRecord TryToMapToUserFriendly() - { - if (!TypeName.IsSimple) + if (!classInfo.TypeName.IsSimple || classInfo.MemberNames.Count == 0 || memberTypeInfo.Infos[0].BinaryType != BinaryType.Primitive) { - return this; + return new SystemClassWithMembersAndTypesRecord(classInfo, memberTypeInfo); } - - if (MemberValues.Count == 1) + else if (classInfo.MemberNames.Count == 1) { - if (HasMember("m_value")) + PrimitiveType primitiveType = (PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo!; + // Get the member name without allocating on the heap. + Collections.Generic.Dictionary.Enumerator structEnumerator = classInfo.MemberNames.GetEnumerator(); + _ = structEnumerator.MoveNext(); + string memberName = structEnumerator.Current.Key; + // Everything needs to match: primitive type, type name name and member name. + return (primitiveType, typeName.FullName, memberName) switch { - return MemberValues[0] switch - { - // there can be a value match, but no TypeName match - bool value when TypeNameMatches(typeof(bool)) => Create(value), - byte value when TypeNameMatches(typeof(byte)) => Create(value), - sbyte value when TypeNameMatches(typeof(sbyte)) => Create(value), - char value when TypeNameMatches(typeof(char)) => Create(value), - short value when TypeNameMatches(typeof(short)) => Create(value), - ushort value when TypeNameMatches(typeof(ushort)) => Create(value), - int value when TypeNameMatches(typeof(int)) => Create(value), - uint value when TypeNameMatches(typeof(uint)) => Create(value), - long value when TypeNameMatches(typeof(long)) => Create(value), - ulong value when TypeNameMatches(typeof(ulong)) => Create(value), - float value when TypeNameMatches(typeof(float)) => Create(value), - double value when TypeNameMatches(typeof(double)) => Create(value), - _ => this - }; - } - else if (HasMember("value")) - { - return MemberValues[0] switch - { - // there can be a value match, but no TypeName match - long value when TypeNameMatches(typeof(IntPtr)) => Create(new IntPtr(value)), - ulong value when TypeNameMatches(typeof(UIntPtr)) => Create(new UIntPtr(value)), - _ => this - }; - } - else if (HasMember("_ticks") && GetRawValue("_ticks") is long ticks && TypeNameMatches(typeof(TimeSpan))) - { - return Create(new TimeSpan(ticks)); - } + (PrimitiveType.Boolean, "System.Boolean", "m_value") => Create(reader.ReadBoolean()), + (PrimitiveType.Byte, "System.Byte", "m_value") => Create(reader.ReadByte()), + (PrimitiveType.SByte, "System.SByte", "m_value") => Create(reader.ReadSByte()), + (PrimitiveType.Char, "System.Char", "m_value") => Create(reader.ParseChar()), + (PrimitiveType.Int16, "System.Int16", "m_value") => Create(reader.ReadInt16()), + (PrimitiveType.UInt16, "System.UInt16", "m_value") => Create(reader.ReadUInt16()), + (PrimitiveType.Int32, "System.Int32", "m_value") => Create(reader.ReadInt32()), + (PrimitiveType.UInt32, "System.UInt32", "m_value") => Create(reader.ReadUInt32()), + (PrimitiveType.Int64, "System.Int64", "m_value") => Create(reader.ReadInt64()), + (PrimitiveType.Int64, "System.IntPtr", "value") => Create(new IntPtr(reader.ReadInt64())), + (PrimitiveType.Int64, "System.TimeSpan", "_ticks") => Create(new TimeSpan(reader.ReadInt64())), + (PrimitiveType.UInt64, "System.UInt64", "m_value") => Create(reader.ReadUInt64()), + (PrimitiveType.UInt64, "System.UIntPtr", "value") => Create(new UIntPtr(reader.ReadUInt64())), + (PrimitiveType.Single, "System.Single", "m_value") => Create(reader.ReadSingle()), + (PrimitiveType.Double, "System.Double", "m_value") => Create(reader.ReadDouble()), + _ => new SystemClassWithMembersAndTypesRecord(classInfo, memberTypeInfo) + }; } - else if (MemberValues.Count == 2 - && HasMember("ticks") && HasMember("dateData") - && GetRawValue("ticks") is long && GetRawValue("dateData") is ulong dateData - && TypeNameMatches(typeof(DateTime))) + else if (classInfo.MemberNames.Count == 2 && typeName.FullName == "System.DateTime" + && HasMember("ticks", 0, PrimitiveType.Int64) + && HasMember("dateData", 1, PrimitiveType.UInt64)) { - return Create(Utils.BinaryReaderExtensions.CreateDateTimeFromData(dateData)); + return DecodeDateTime(reader, classInfo.Id); } - else if (MemberValues.Count == 4 - && HasMember("lo") && HasMember("mid") && HasMember("hi") && HasMember("flags") - && GetRawValue("lo") is int lo && GetRawValue("mid") is int mid - && GetRawValue("hi") is int hi && GetRawValue("flags") is int flags - && TypeNameMatches(typeof(decimal))) + else if (classInfo.MemberNames.Count == 4 && typeName.FullName == "System.Decimal" + && HasMember("flags", 0, PrimitiveType.Int32) + && HasMember("hi", 1, PrimitiveType.Int32) + && HasMember("lo", 2, PrimitiveType.Int32) + && HasMember("mid", 3, PrimitiveType.Int32)) { - int[] bits = - [ - lo, - mid, - hi, - flags - ]; - - return Create(new decimal(bits)); + return DecodeDecimal(reader, classInfo.Id); } - return this; + return new SystemClassWithMembersAndTypesRecord(classInfo, memberTypeInfo); SerializationRecord Create(T value) where T : unmanaged - => new MemberPrimitiveTypedRecord(value, Id); + => new MemberPrimitiveTypedRecord(value, classInfo.Id); + + bool HasMember(string name, int order, PrimitiveType primitiveType) + => classInfo.MemberNames.TryGetValue(name, out int memberOrder) + && memberOrder == order + && ((PrimitiveType)memberTypeInfo.Infos[order].AdditionalInfo!) == primitiveType; + } + + internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetNextAllowedRecordType() + => MemberTypeInfo.GetNextAllowedRecordType(MemberValues.Count); + + internal static MemberPrimitiveTypedRecord DecodeDateTime(BinaryReader reader, SerializationRecordId id) + { + _ = reader.ReadInt64(); // ticks are not used, but they need to be read as they go first in the payload + ulong dateData = reader.ReadUInt64(); + + return new MemberPrimitiveTypedRecord(BinaryReaderExtensions.CreateDateTimeFromData(dateData), id); + } + + internal static MemberPrimitiveTypedRecord DecodeDecimal(BinaryReader reader, SerializationRecordId id) + { + int flags = reader.ReadInt32(); + int hi = reader.ReadInt32(); + int lo = reader.ReadInt32(); + int mid = reader.ReadInt32(); + + return new MemberPrimitiveTypedRecord(new decimal([lo, mid, hi, flags]), id); } } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs index d5baa09dbd8fc4..8bb3ac3a1107bd 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs @@ -33,7 +33,7 @@ internal static BinaryArrayType ReadArrayType(this BinaryReader reader) { // To simplify the behavior and security review of the BinaryArrayRecord type, we // do not support reading non-zero-offset arrays. If this should change in the - // future, the BinaryArrayRecord.Decode method and supporting infrastructure + // future, the NrbfDecoder.DecodeBinaryArrayRecord method and supporting infrastructure // will need re-review. byte arrayType = reader.ReadByte(); diff --git a/src/libraries/System.Formats.Nrbf/tests/ArrayOfSerializationRecordsTests.cs b/src/libraries/System.Formats.Nrbf/tests/ArrayOfSerializationRecordsTests.cs new file mode 100644 index 00000000000000..18e39a5fd68e1f --- /dev/null +++ b/src/libraries/System.Formats.Nrbf/tests/ArrayOfSerializationRecordsTests.cs @@ -0,0 +1,516 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.Serialization; +using Microsoft.DotNet.XUnitExtensions; +using Xunit; + +namespace System.Formats.Nrbf.Tests +{ + public class ArrayOfSerializationRecordsTests : ReadTests + { + public enum ElementType + { + Object, + NonGeneric, + Generic + } + + [Serializable] + public class CustomClassThatImplementsIEnumerable : IEnumerable + { + public int Field; + + public IEnumerator GetEnumerator() => Array.Empty().GetEnumerator(); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsStringRecord_SZ(ElementType elementType) + { + const string Text = "hello"; + Array input = elementType switch + { + ElementType.Object => new object[] { Text }, + ElementType.NonGeneric => new IEnumerable[] { Text }, + ElementType.Generic => new IEnumerable[] { Text }, + _ => throw new InvalidOperationException() + }; + + SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[] output = arrayRecord.GetArray(); + + Verify(input, arrayRecord, output, recordMap); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output.Single(); + Assert.Equal(Text, stringRecord.Value); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsStringRecord_MD(ElementType elementType) + { + const string Text = "hello"; + Array input = elementType switch + { + ElementType.Object => new object[1, 1], + ElementType.NonGeneric => new IEnumerable[1, 1], + ElementType.Generic => new IEnumerable[1, 1], + _ => throw new InvalidOperationException() + }; + input.SetValue(Text, 0, 0); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[0, 0]; + Assert.Equal(Text, stringRecord.Value); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsStringRecord_Jagged(ElementType elementType) + { + const string Text = "hello"; + Array input = elementType switch + { + ElementType.Object => new object[1][] { [Text] }, + ElementType.NonGeneric => new IEnumerable[1][] { [Text] }, + ElementType.Generic => new IEnumerable[1][] { [Text] }, + _ => throw new InvalidOperationException() + }; + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + + SZArrayRecord contained = (SZArrayRecord)output.Single(); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)contained.GetArray().Single(); + Assert.Equal(Text, stringRecord.Value); + } + + [ConditionalTheory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsMemberPrimitiveTypedRecord_SZ(ElementType elementType) + { + if (elementType != ElementType.Object && !IsPatched) + { + throw new SkipTestException("Current machine has not been patched with the most recent BinaryFormatter fix."); + } + + const int Integer = 123; + Array input = elementType switch + { + ElementType.Object => new object[] { Integer }, + ElementType.NonGeneric => new IComparable[] { Integer }, + ElementType.Generic => new IComparable[] { Integer }, + _ => throw new InvalidOperationException() + }; + + SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[] output = arrayRecord.GetArray(); + + Verify(input, arrayRecord, output, recordMap); + PrimitiveTypeRecord intRecord = (PrimitiveTypeRecord)output.Single(); + Assert.Equal(Integer, intRecord.Value); + } + + [ConditionalTheory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsMemberPrimitiveTypedRecord_MD(ElementType elementType) + { + if (elementType != ElementType.Object && !IsPatched) + { + throw new SkipTestException("Current machine has not been patched with the most recent BinaryFormatter fix."); + } + + const int Integer = 123; + Array input = elementType switch + { + ElementType.Object => new object[1, 1], + ElementType.NonGeneric => new IComparable[1, 1], + ElementType.Generic => new IComparable[1, 1], + _ => throw new InvalidOperationException() + }; + input.SetValue(Integer, 0, 0); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + PrimitiveTypeRecord intRecord = (PrimitiveTypeRecord)output[0, 0]; + Assert.Equal(Integer, intRecord.Value); + } + + [ConditionalTheory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsMemberPrimitiveTypedRecord_Jagged(ElementType elementType) + { + if (elementType != ElementType.Object && !IsPatched) + { + throw new SkipTestException("Current machine has not been patched with the most recent BinaryFormatter fix."); + } + + const int Integer = 123; + Array input = elementType switch + { + ElementType.Object => new object[1][] { [Integer] }, + ElementType.NonGeneric => new IComparable[1][] { [Integer] }, + ElementType.Generic => new IComparable[1][] { [Integer] }, + _ => throw new InvalidOperationException() + }; + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + SZArrayRecord contained = (SZArrayRecord)output.Single(); + PrimitiveTypeRecord intRecord = (PrimitiveTypeRecord)contained.GetArray().Single(); + Assert.Equal(Integer, intRecord.Value); + } + + public static IEnumerable NullAndArrayPermutations() + { + foreach (ElementType elementType in Enum.GetValues(typeof(ElementType))) + { + yield return new object[] { elementType, 1 }; // ObjectNullRecord + yield return new object[] { elementType, 200 }; // ObjectNullMultiple256Record + yield return new object[] { elementType, 1_000 }; // ObjectNullMultipleRecord + } + } + + [Theory] + [MemberData(nameof(NullAndArrayPermutations))] + public void CanReadArrayThatContainsNullRecords_SZ(ElementType elementType, int nullCount) + { + const string Text = "notNull"; + Array input = elementType switch + { + ElementType.Object => new object[nullCount + 1], + ElementType.NonGeneric => new IEnumerable[nullCount + 1], + ElementType.Generic => new IEnumerable[nullCount + 1], + _ => throw new InvalidOperationException() + }; + input.SetValue(Text, nullCount); + + SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord?[] output = arrayRecord.GetArray(); + + Verify(input, arrayRecord, output, recordMap); + Assert.All(output.Take(nullCount), Assert.Null); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[nullCount]; + Assert.Equal(Text, stringRecord.Value); + } + + [Theory] + [MemberData(nameof(NullAndArrayPermutations))] + public void CanReadArrayThatContainsNullRecords_MD(ElementType elementType, int nullCount) + { + const string Text = "notNull"; + Array input = elementType switch + { + ElementType.Object => new object[1, nullCount + 1], + ElementType.NonGeneric => new IEnumerable[1, nullCount + 1], + ElementType.Generic => new IEnumerable[1, nullCount + 1], + _ => throw new InvalidOperationException() + }; + input.SetValue(Text, 0, nullCount); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + for (int i = 0; i < nullCount; i++) + { + Assert.Null(output[0, i]); + } + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[0, nullCount]; + Assert.Equal(Text, stringRecord.Value); + } + + [Theory] + [MemberData(nameof(NullAndArrayPermutations))] + public void CanReadArrayThatContainsNullRecords_Jagged(ElementType elementType, int nullCount) + { + const string Text = "notNull"; + Array input = elementType switch + { + ElementType.Object => new object[1][] { new object[nullCount + 1] }, + ElementType.NonGeneric => new IEnumerable[1][] { new IEnumerable[nullCount + 1] }, + ElementType.Generic => new IEnumerable[1][] { new IEnumerable[nullCount + 1] }, + _ => throw new InvalidOperationException() + }; + ((Array)input.GetValue(0)).SetValue(Text, nullCount); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + SZArrayRecord contained = (SZArrayRecord)output.Single(); + Assert.All(contained.GetArray().Take(nullCount), Assert.Null); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)contained.GetArray()[nullCount]; + Assert.Equal(Text, stringRecord.Value); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsArrayRecord_SZ(ElementType elementType) + { + int[] intArray = [1, 2, 3]; + Array input = elementType switch + { + ElementType.Object => new object[] { intArray }, + ElementType.NonGeneric => new IEnumerable[] { intArray }, + ElementType.Generic => new IEnumerable[] { intArray }, + _ => throw new InvalidOperationException() + }; + + SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[] output = arrayRecord.GetArray(); + + Verify(input, arrayRecord, output, recordMap); + SZArrayRecord intArrayRecord = (SZArrayRecord)output.Single(); + Assert.Equal(intArray, intArrayRecord.GetArray()); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsArrayRecord_MD(ElementType elementType) + { + int[] intArray = [1, 2, 3]; + Array input = elementType switch + { + ElementType.Object => new object[1, 1], + ElementType.NonGeneric => new IEnumerable[1, 1], + ElementType.Generic => new IEnumerable[1, 1], + _ => throw new InvalidOperationException() + }; + input.SetValue(intArray, 0, 0); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + SZArrayRecord intArrayRecord = (SZArrayRecord)output[0, 0]; + Assert.Equal(intArray, intArrayRecord.GetArray()); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + [InlineData(ElementType.Generic)] + public void CanReadArrayThatContainsArrayRecord_Jagged(ElementType elementType) + { + int[] intArray = [1, 2, 3]; + Array input = elementType switch + { + ElementType.Object => new object[1][] { [intArray] }, + ElementType.NonGeneric => new IEnumerable[1][] { [intArray] }, + ElementType.Generic => new IEnumerable[1][] { [intArray] }, + _ => throw new InvalidOperationException() + }; + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + SZArrayRecord contained = (SZArrayRecord)output.Single(); + SZArrayRecord intArrayRecord = (SZArrayRecord)contained.GetArray().Single(); + Assert.Equal(intArray, intArrayRecord.GetArray()); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + public void CanReadArrayThatContainsAllRecordTypes_SZ(ElementType elementType) + { + const string Text = "hello"; + int[] intArray = [1, 2, 3]; + CustomClassThatImplementsIEnumerable classThatImplementsIEnumerable = new() { Field = 456 }; + Array input = elementType switch + { + ElementType.Object => new object[] + { + Text, // BinaryObjectStringRecord + intArray, // ArraySinglePrimitiveRecord + classThatImplementsIEnumerable, // ClassWithMembersAndTypesRecord, + null // ObjectNullRecord + }, + ElementType.NonGeneric => new IEnumerable[] { Text, intArray, classThatImplementsIEnumerable, null }, + _ => throw new InvalidOperationException() + }; + + SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[] output = arrayRecord.GetArray(); + + Verify(input, arrayRecord, output, recordMap); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[0]; + Assert.Equal(Text, stringRecord.Value); + SZArrayRecord intArrayRecord = (SZArrayRecord)output[1]; + Assert.Equal(intArray, intArrayRecord.GetArray()); + ClassRecord classRecord = (ClassRecord)output[2]; + Assert.Equal(classThatImplementsIEnumerable.Field, classRecord.GetInt32(nameof(CustomClassThatImplementsIEnumerable.Field))); + Assert.Null(output[3]); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + public void CanReadArrayThatContainsAllRecordTypes_MD(ElementType elementType) + { + const string Text = "hello"; + int[] intArray = [1, 2, 3]; + CustomClassThatImplementsIEnumerable classThatImplementsIEnumerable = new() { Field = 456 }; + + Array input = elementType switch + { + ElementType.Object => new object[1, 4], + ElementType.NonGeneric => new IEnumerable[1, 4], + _ => throw new InvalidOperationException() + }; + input.SetValue(Text, 0, 0); + input.SetValue(intArray, 0, 1); + input.SetValue(classThatImplementsIEnumerable, 0, 2); + input.SetValue(null, 0, 3); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[0, 0]; + Assert.Equal(Text, stringRecord.Value); + SZArrayRecord intArrayRecord = (SZArrayRecord)output[0, 1]; + Assert.Equal(intArray, intArrayRecord.GetArray()); + ClassRecord classRecord = (ClassRecord)output[0, 2]; + Assert.Equal(classThatImplementsIEnumerable.Field, classRecord.GetInt32(nameof(CustomClassThatImplementsIEnumerable.Field))); + Assert.Null(output[0, 3]); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + public void CanReadArrayThatContainsAllRecordTypes_Jagged(ElementType elementType) + { + const string Text = "hello"; + int[] intArray = [1, 2, 3]; + CustomClassThatImplementsIEnumerable classThatImplementsIEnumerable = new() { Field = 456 }; + + Array input = elementType switch + { + ElementType.Object => new object[1][] { [Text, intArray, classThatImplementsIEnumerable, null] }, + ElementType.NonGeneric => new IEnumerable[1][] { [Text, intArray, classThatImplementsIEnumerable, null] }, + _ => throw new InvalidOperationException() + }; + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + SZArrayRecord contained = (SZArrayRecord)output.Single(); + SerializationRecord[] records = contained.GetArray(); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)records[0]; + Assert.Equal(Text, stringRecord.Value); + SZArrayRecord intArrayRecord = (SZArrayRecord)records[1]; + Assert.Equal(intArray, intArrayRecord.GetArray()); + ClassRecord classRecord = (ClassRecord)records[2]; + Assert.Equal(classThatImplementsIEnumerable.Field, classRecord.GetInt32(nameof(CustomClassThatImplementsIEnumerable.Field))); + Assert.Null(records[3]); + } + + [Theory] + [InlineData(ElementType.Object)] + [InlineData(ElementType.NonGeneric)] + public void CanReadArrayThatContainsAllRecordTypes_Jagged_MD(ElementType elementType) + { + const string Text = "hello"; + int[] intArray = [1, 2, 3]; + CustomClassThatImplementsIEnumerable classThatImplementsIEnumerable = new() { Field = 456 }; + + Array input = elementType switch + { + ElementType.Object => new object[1, 1][,], + ElementType.NonGeneric => new IEnumerable[1, 1][,], + _ => throw new InvalidOperationException() + }; + Array contained = elementType switch + { + ElementType.Object => new object[2, 2], + ElementType.NonGeneric => new IEnumerable[2, 2], + _ => throw new InvalidOperationException() + }; + contained.SetValue(Text, 0, 0); + contained.SetValue(intArray, 0, 1); + contained.SetValue(classThatImplementsIEnumerable, 1, 0); + input.SetValue(contained, 0, 0); + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap); + ArrayRecord[,] output = (ArrayRecord[,])arrayRecord.GetArray(input.GetType()); + + Verify(input, arrayRecord, output, recordMap); + SerializationRecord[,] records = (SerializationRecord[,])output[0, 0].GetArray(contained.GetType()); + PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)records[0, 0]; + Assert.Equal(Text, stringRecord.Value); + SZArrayRecord intArrayRecord = (SZArrayRecord)records[0, 1]; + Assert.Equal(intArray, intArrayRecord.GetArray()); + ClassRecord classRecord = (ClassRecord)records[1, 0]; + Assert.Equal(classThatImplementsIEnumerable.Field, classRecord.GetInt32(nameof(CustomClassThatImplementsIEnumerable.Field))); + Assert.Null(records[1, 1]); + } + + [Fact] + public void TypeMismatch() + { + // An array of strings that contains non-string. + byte[] bytes = Convert.FromBase64String("AAEAAAD/////AQAAAAAAAAAHAQAAAAICAAAAAQAAAAEAAAABCQEAAAAL"); + + ArrayRecord arrRecord = (ArrayRecord)NrbfDecoder.Decode(new MemoryStream(bytes)); + + Assert.Throws(() => arrRecord.GetArray(typeof(string[,]))); + } + + private static void Verify(Array input, ArrayRecord arrayRecord, Array output, + IReadOnlyDictionary recordMap) + { + Assert.Equal(input.Rank, arrayRecord.Rank); + Assert.Equal(input.Rank, output.Rank); + + for (int i = 0; i < input.Rank; i++) + { + Assert.Equal(input.GetLength(i), arrayRecord.Lengths[i]); + Assert.Equal(input.GetLength(i), output.GetLength(i)); + } + + foreach (object? recordOrNull in output) + { + if (recordOrNull is SerializationRecord record && !record.Id.Equals(default)) + { + // An array of abstractions always uses SystemClassWithMembersAndTypesRecord to represent primitive values. + // This requires some non-trivial mapping and we need to ensure that it's reflected not only in what + // has been stored in the array, but also in the record map. + Assert.Same(record, recordMap[record.Id]); + } + } + } + } +} diff --git a/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs b/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs index 49d523088a89fe..7ef801808e4e95 100644 --- a/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs @@ -5,6 +5,7 @@ using System.IO; using System.Runtime.Serialization; using System.Text; +using Microsoft.DotNet.XUnitExtensions; using Xunit; namespace System.Formats.Nrbf.Tests; @@ -71,63 +72,63 @@ public void DontCastBytesToDateTimes() Assert.Throws(() => NrbfDecoder.Decode(stream)); } - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Bool(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Byte(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_SByte(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Char(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Int16(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_UInt16(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Int32(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_UInt32(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Int64(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_UInt64(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Single(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_Double(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_TimeSpan(int size, bool canSeek) => Test(size, canSeek); - [Theory] + [ConditionalTheory] [MemberData(nameof(GetCanReadArrayOfAnySizeArgs))] public void CanReadArrayOfAnySize_DateTime(int size, bool canSeek) => Test(size, canSeek); - private void Test(int size, bool canSeek) + private void Test(int size, bool canSeek) where T : IComparable { Random constSeed = new Random(27644437); T[] input = new T[size]; @@ -136,17 +137,69 @@ private void Test(int size, bool canSeek) input[i] = GenerateValue(constSeed); } + TestSZArrayOfT(input, size, canSeek); + TestSZArrayOfIComparable(input, size, canSeek); + } + + private void TestSZArrayOfT(T[] input, int size, bool canSeek) + { MemoryStream stream = Serialize(input); stream = canSeek ? stream : new NonSeekableStream(stream.ToArray()); SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(stream); Assert.Equal(size, arrayRecord.Length); - Assert.Equal(size, arrayRecord.FlattenedLength); T?[] output = arrayRecord.GetArray(); Assert.Equal(input, output); Assert.Same(output, arrayRecord.GetArray()); } + private void TestSZArrayOfIComparable(T[] input, int size, bool canSeek) where T : IComparable + { + if (!IsPatched) + { + throw new SkipTestException("Current machine has not been patched with the most recent BinaryFormatter fix."); + } + + // Arrays of abstractions that store primitive values (example: new IComparable[1] { int.MaxValue }) + // are represented by BinaryFormatter with a single SystemClassWithMembersAndTypesRecord + // and multiple ClassWithIdRecord that re-use the information from the system record. + // This requires some non-trivial mapping and this test is very important as it covers that code path. + IComparable[] comparables = new IComparable[size]; + for (int i = 0; i < input.Length; i++) + { + comparables[i] = input[i]; + } + + TestArrayOfSerializationRecords(input, comparables, canSeek); + } + + private void TestSZArrayOfObjects(T[] input, int size, bool canSeek) + { + // Arrays of objects that store primitive values (example: new object[1] { int.MaxValue }) + // are represented by BinaryFormatter with MemberPrimitiveTypedRecord instances. + object[] objects = new object[size]; + for (int i = 0; i < input.Length; i++) + { + objects[i] = input[i]; + } + + TestArrayOfSerializationRecords(input, objects, canSeek); + } + + private void TestArrayOfSerializationRecords(T[] values, object input, bool canSeek) + { + MemoryStream stream = Serialize(input); + + stream = canSeek ? stream : new NonSeekableStream(stream.ToArray()); + SZArrayRecord arrayRecordOfPrimitiveRecords = (SZArrayRecord)NrbfDecoder.Decode(stream); + SerializationRecord[] arrayOfPrimitiveRecords = arrayRecordOfPrimitiveRecords.GetArray(); + for (int i = 0; i < values.Length; i++) + { + Assert.Equal(values[i], ((PrimitiveTypeRecord)arrayOfPrimitiveRecords[i]).Value); + Assert.Equal(values[i], ((PrimitiveTypeRecord)arrayOfPrimitiveRecords[i]).Value); + } + } + private static T GenerateValue(Random random) { if (typeof(T) == typeof(byte)) diff --git a/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs b/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs index fe780d94698df0..3a81e3f131c823 100644 --- a/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs @@ -50,18 +50,51 @@ public void CyclicReferencesInSystemClassesDoNotCauseStackOverflow() } [Fact] - public void CyclicReferencesInArraysOfObjectsDoNotCauseStackOverflow() + public void CyclicReferencesInSZArraysOfObjectsDoNotCauseStackOverflow() { object[] input = new object[2]; input[0] = "not an array"; input[1] = input; ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); - object?[] output = ((SZArrayRecord)arrayRecord).GetArray(); + SerializationRecord?[] output = ((SZArrayRecord)arrayRecord).GetArray(); - Assert.Equal(input[0], output[0]); + Assert.Equal(input[0], ((PrimitiveTypeRecord)output[0]).Value); Assert.Same(input, input[1]); - Assert.Same(output, output[1]); + Assert.Same(arrayRecord, output[1]); + } + + [Fact] + public void CyclicReferencesInMDArraysOfObjectsDoNotCauseStackOverflow() + { + object[,] input = new object[2, 2]; + input[0, 0] = "not an array"; + input[1, 1] = input; + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); + SerializationRecord?[,] output = (SerializationRecord?[,])arrayRecord.GetArray(typeof(object[,])); + + Assert.Equal(input[0, 0], ((PrimitiveTypeRecord)output[0, 0]).Value); + Assert.Same(input, input[1, 1]); + Assert.Same(arrayRecord, output[1, 1]); + } + + [Fact] + public void CyclicReferencesInJaggedArraysOfObjectsDoNotCauseStackOverflow() + { + object[][] input = new object[1][]; + input[0] = new object[2]; + input[0][0] = "not an array"; + input[0][1] = input; + + ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(typeof(object[][])); + SZArrayRecord row = (SZArrayRecord)output.Single(); + SerializationRecord[] contained = row.GetArray(); + + Assert.Equal(input[0][0], ((PrimitiveTypeRecord)contained[0]).Value); + Assert.Same(input, input[0][1]); + Assert.Same(arrayRecord, contained[1]); } [Serializable] @@ -81,8 +114,8 @@ public void CyclicClassReferencesInArraysOfObjectsDoNotCauseStackOverflow() ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(Serialize(input)); Assert.Equal(input.Name, classRecord.GetString(nameof(WithCyclicReferenceInArrayOfObjects.Name))); - SZArrayRecord arrayRecord = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(WithCyclicReferenceInArrayOfObjects.ArrayWithReferenceToSelf))!; - object?[] array = arrayRecord.GetArray(); + SZArrayRecord arrayRecord = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(WithCyclicReferenceInArrayOfObjects.ArrayWithReferenceToSelf))!; + SerializationRecord?[] array = arrayRecord.GetArray(); Assert.Same(classRecord, array.Single()); } @@ -103,7 +136,7 @@ public void CyclicClassReferencesInArraysOfTDoNotCauseStackOverflow() ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(Serialize(input)); Assert.Equal(input.Name, classRecord.GetString(nameof(WithCyclicReferenceInArrayOfT.Name))); - SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(WithCyclicReferenceInArrayOfT.ArrayWithReferenceToSelf))!; + SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(WithCyclicReferenceInArrayOfT.ArrayWithReferenceToSelf))!; Assert.Same(classRecord, classRecords.GetArray().Single()); } diff --git a/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs b/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs index 6acb44d03697d2..2d78954d649094 100644 --- a/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs @@ -355,6 +355,36 @@ public void ThrowsForInvalidPositiveArrayRank(int rank, byte arrayType) Assert.Throws(() => NrbfDecoder.Decode(stream)); } + public static IEnumerable AllPrimitiveTypes() + { + foreach (PrimitiveType primitiveType in Enum.GetValues(typeof(PrimitiveType))) + { + yield return new object[] { (byte)primitiveType }; + } + } + + [Theory] + [MemberData(nameof(AllPrimitiveTypes))] + public void ThrowsForInvalidPrimitiveTypeForBinaryArrayRecords(byte primitiveType) + { + using MemoryStream stream = new(); + BinaryWriter writer = new(stream, Encoding.UTF8); + + WriteSerializedStreamHeader(writer); + + writer.Write((byte)SerializationRecordType.BinaryArray); + writer.Write(1); // object Id + writer.Write((byte)BinaryArrayType.Jagged); + writer.Write(1); // rank! + writer.Write(1); // length + writer.Write((byte)BinaryType.Primitive); // A jagged array must consist of other arrays, not primitive values + writer.Write(primitiveType); + writer.Write((byte)SerializationRecordType.MessageEnd); + + stream.Position = 0; + Assert.Throws(() => NrbfDecoder.Decode(stream)); + } + [Theory] [InlineData(SerializationRecordType.ClassWithMembersAndTypes)] [InlineData(SerializationRecordType.SystemClassWithMembersAndTypes)] diff --git a/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs b/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs index 8bb844ff76a586..f02128ab08f99d 100644 --- a/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs @@ -24,40 +24,43 @@ public void CanReadJaggedArraysOfPrimitiveTypes_2D(bool useReferences) var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); + ArrayRecord?[] output = (ArrayRecord?[])arrayRecord.GetArray(input.GetType()); + for (int i = 0; i < input.Length; i++) + { + Assert.Equal(input[i], ((SZArrayRecord)output[i]).GetArray()); + if (useReferences) + { + Assert.Same(((SZArrayRecord)output[0]).GetArray(), ((SZArrayRecord)output[i]).GetArray()); + } + } } [Theory] [InlineData(1)] // SerializationRecordType.ObjectNull [InlineData(200)] // SerializationRecordType.ObjectNullMultiple256 [InlineData(10_000)] // SerializationRecordType.ObjectNullMultiple - public void FlattenedLengthIncludesNullArrays(int nullCount) + public void NullRecordsOfAllKindsAreHandledProperly(int nullCount) { int[][] input = new int[nullCount][]; var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(nullCount, arrayRecord.FlattenedLength); + ArrayRecord?[] output = (ArrayRecord?[])arrayRecord.GetArray(input.GetType()); + Assert.All(output, Assert.Null); } [Fact] public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutBeingMarkedAsJagged() { int[][][] input = new int[3][][]; - long totalElementsCount = 0; for (int i = 0; i < input.Length; i++) { input[i] = new int[4][]; - totalElementsCount++; // count the arrays themselves for (int j = 0; j < input[i].Length; j++) { input[i][j] = [i, j, 0, 1, 2]; - totalElementsCount += input[i][j].Length; - totalElementsCount++; // count the arrays themselves } } @@ -75,57 +78,105 @@ public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutB var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(new MemoryStream(serialized)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(3 + 3 * 4 + 3 * 4 * 5, totalElementsCount); - Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); + ArrayRecord?[] output = (ArrayRecord?[])arrayRecord.GetArray(input.GetType()); + for (int i = 0; i < input.Length; i++) + { + ArrayRecord[] firstLevel = (ArrayRecord[])output[i].GetArray(typeof(int[][])); + + for (int j = 0; j < input[i].Length; j++) + { + Assert.Equal(input[i][j], (int[])firstLevel[j].GetArray(typeof(int[]))); + } + } } [Fact] - public void CanReadJaggedArraysOfPrimitiveTypes_3D() + public void CanReadSZJaggedArrayOfMDArrays() { - int[][][] input = new int[7][][]; - long totalElementsCount = 0; + int[][,] input = new int[7][,]; for (int i = 0; i < input.Length; i++) { - totalElementsCount++; // count the arrays themselves - input[i] = new int[1][]; - totalElementsCount++; // count the arrays themselves - input[i][0] = [i, i, i]; - totalElementsCount += input[i][0].Length; + input[i] = new int[3, 3]; + + for (int j = 0; j < input[i].GetLength(0); j++) + { + for (int k = 0; k < input[i].GetLength(1); k++) + { + input[i][j, k] = i * j * k; + } + } } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(1, arrayRecord.Rank); - Assert.Equal(7 + 7 * 1 + 7 * 1 * 3, totalElementsCount); - Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + for (int i = 0; i < input.Length; i++) + { + Assert.Equal(input[i], output[i].GetArray(typeof(int[,]))); + } } [Fact] - public void CanReadJaggedArrayOfRectangularArrays() + public void CanReadMDJaggedArrayOfSZArrays() { - int[][,] input = new int[7][,]; - for (int i = 0; i < input.Length; i++) - { - input[i] = new int[3,3]; + int[,][] input = new int[2,2][]; + input[0, 0] = [1, 2, 3]; - for (int j = 0; j < input[i].GetLength(0); j++) + var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); + + Verify(input, arrayRecord); + ArrayRecord[,] output = (ArrayRecord[,])arrayRecord.GetArray(input.GetType()); + Assert.Equal(input[0, 0], output[0, 0].GetArray(typeof(int[]))); + Assert.Null(output[0, 1]); + Assert.Null(output[1, 0]); + Assert.Null(output[1, 1]); + } + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Integers() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => x * y); + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Doubles() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => x * y / 10); + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Strings() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => $"{x},{y}"); + + static void MultiDimensionalArrayOfMultiDimensionalArrays(Func valueFactory) + { + T[,][,] input = new T[2, 2][,]; + for (int i = 0; i < input.GetLength(0); i++) + { + for (int j = 0; j < input.GetLength(1); j++) { - for (int k = 0; k < input[i].GetLength(1); k++) + T[,] contained = new T[i + 1, j + 1]; + for (int k = 0; k < contained.GetLength(0); k++) { - input[i][j, k] = i * j * k; + for (int l = 0; l < contained.GetLength(1); l++) + { + contained[k, l] = valueFactory(k, l); + } } + + input[i, j] = contained; } } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(1, arrayRecord.Rank); - Assert.Equal(input.Length + input.Length * 3 * 3, arrayRecord.FlattenedLength); + + ArrayRecord[,] output = (ArrayRecord[,])arrayRecord.GetArray(input.GetType()); + for (int i = 0; i < input.GetLength(0); i++) + { + for (int j = 0; j < input.GetLength(1); j++) + { + Assert.Equal(input[i, j], output[i, j].GetArray(typeof(T[,]))); + } + } } [Fact] @@ -140,8 +191,11 @@ public void CanReadJaggedArraysOfStrings() var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + for (int i = 0; i < input.Length; i++) + { + Assert.Equal(input[i], ((SZArrayRecord)output[i]).GetArray()); + } } [Fact] @@ -156,8 +210,16 @@ public void CanReadJaggedArraysOfObjects() var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(input, arrayRecord.GetArray(input.GetType())); - Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); + + for (int i = 0; i < input.Length; i++) + { + SerializationRecord[] row = (SerializationRecord[])output[i].GetArray(typeof(object[])); + for (int j = 0; j < input[i].Length; j++) + { + Assert.Equal(input[i][j], ((PrimitiveTypeRecord)row[j]).Value); + } + } } [Serializable] @@ -170,32 +232,28 @@ public class ComplexType public void CanReadJaggedArraysOfComplexTypes() { ComplexType[][] input = new ComplexType[3][]; - long totalElementsCount = 0; for (int i = 0; i < input.Length; i++) { input[i] = Enumerable.Range(0, i + 1).Select(j => new ComplexType { SomeField = j }).ToArray(); - totalElementsCount += input[i].Length; - totalElementsCount++; // count the arrays themselves } var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); Verify(input, arrayRecord); - Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); - var output = (ClassRecord?[][])arrayRecord.GetArray(input.GetType()); + ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType()); for (int i = 0; i < input.Length; i++) { + SerializationRecord[] row = ((SZArrayRecord)output[i]).GetArray(); for (int j = 0; j < input[i].Length; j++) { - Assert.Equal(input[i][j].SomeField, output[i][j]!.GetInt32(nameof(ComplexType.SomeField))); + Assert.Equal(input[i][j].SomeField, ((ClassRecord)row[j]!).GetInt32(nameof(ComplexType.SomeField))); } } } private static void Verify(Array input, ArrayRecord arrayRecord) { - Assert.Equal(1, arrayRecord.Lengths.Length); - Assert.Equal(input.Length, arrayRecord.Lengths[0]); + Assert.Equal(input.Rank, arrayRecord.Rank); Assert.True(arrayRecord.TypeName.GetElementType().IsArray); // true only for Jagged arrays Assert.Equal(input.GetType().FullName, arrayRecord.TypeName.FullName); Assert.Equal(input.GetType().GetAssemblyNameIncludingTypeForwards(), arrayRecord.TypeName.AssemblyName!.FullName); diff --git a/src/libraries/System.Formats.Nrbf/tests/ReadAnythingTests.cs b/src/libraries/System.Formats.Nrbf/tests/ReadAnythingTests.cs index 66cbab818dbba7..3dd4d212ca2805 100644 --- a/src/libraries/System.Formats.Nrbf/tests/ReadAnythingTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/ReadAnythingTests.cs @@ -32,8 +32,8 @@ public void UserCanReadAnyValidInputAndCheckTypesUsingStronglyTypedTypeInstances ClassRecord comparerRecord = dictionaryRecord.GetClassRecord(nameof(input.Comparer))!; Assert.True(comparerRecord.TypeNameMatches(input.Comparer.GetType())); - SZArrayRecord arrayRecord = (SZArrayRecord)dictionaryRecord.GetSerializationRecord("KeyValuePairs")!; - ClassRecord[] keyValuePairs = arrayRecord.GetArray()!; + SZArrayRecord arrayRecord = (SZArrayRecord)dictionaryRecord.GetSerializationRecord("KeyValuePairs")!; + ClassRecord[] keyValuePairs = arrayRecord.GetArray().OfType().ToArray(); Assert.True(keyValuePairs[0].TypeNameMatches(typeof(KeyValuePair))); ClassRecord exceptionPair = Find(keyValuePairs, "exception"); @@ -225,8 +225,8 @@ public void UserCanReadEveryPossibleSerializationRecord(object input) case ClassRecord record when record.TypeNameMatches(typeof(Exception)): Assert.Equal(((Exception)input).Message, record.GetString("Message")); break; - case SZArrayRecord record when record.TypeNameMatches(typeof(Exception[])): - Assert.Equal(((Exception[])input)[0].Message, record.GetArray()[0]!.GetString("Message")); + case SZArrayRecord record when record.TypeNameMatches(typeof(Exception[])): + Assert.Equal(((Exception[])input)[0].Message, ((ClassRecord)record.GetArray()[0]!).GetString("Message")); break; case ClassRecord record when record.TypeNameMatches(typeof(JsonException)): Assert.Equal(((JsonException)input).Message, record.GetString("Message")); @@ -241,7 +241,11 @@ public void UserCanReadEveryPossibleSerializationRecord(object input) Assert.Empty(record.MemberNames); break; case ArrayRecord arrayRecord when arrayRecord.TypeNameMatches(typeof(int?[])): - Assert.Equal(input, arrayRecord.GetArray(typeof(int?[]))); + SerializationRecord?[] nullableArray = (SerializationRecord?[])arrayRecord.GetArray(typeof(int?[])); + Assert.Equal(((int?[])input)[0], ((PrimitiveTypeRecord)nullableArray[0]).Value); + Assert.Equal(((int?[])input)[1], ((PrimitiveTypeRecord)nullableArray[1]).Value); + Assert.Equal(((int?[])input)[2], ((PrimitiveTypeRecord)nullableArray[2]).Value); + Assert.Null(nullableArray[3]); break; case ArrayRecord arrayRecord when arrayRecord.TypeNameMatches(typeof(EmptyClass[])): Assert.Equal(0, arrayRecord.Lengths.ToArray().Single()); @@ -262,11 +266,58 @@ public void UserCanReadEveryPossibleSerializationRecord(object input) static void VerifyDictionary(ClassRecord record) { - SZArrayRecord arrayRecord = (SZArrayRecord)record.GetSerializationRecord("KeyValuePairs")!; - ClassRecord[] keyValuePairs = arrayRecord.GetArray()!; + SZArrayRecord arrayRecord = (SZArrayRecord)record.GetSerializationRecord("KeyValuePairs")!; + ClassRecord[] keyValuePairs = arrayRecord.GetArray().OfType().ToArray(); Assert.True(keyValuePairs[0].TypeNameMatches(typeof(KeyValuePair))); } } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void UserCanReadArrayOfBaseType(bool mixed) + { + Mammal[] input = mixed + ? [new Dog { Name = "Buddy" }, new Cat { Name = "Luna" }] + : [new Dog { Name = "Buddy" }, new Dog { Name = "Rocky" }]; + + SZArrayRecord root = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input)); + + SerializationRecord[] output = (SerializationRecord[])root.GetArray(typeof(Mammal[])); + Assert.True(output[0].TypeNameMatches(typeof(Dog))); + Assert.True(output[1].TypeNameMatches(mixed ? typeof(Cat) : typeof(Dog))); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void UserCanReadArrayOfDerivedTypes(bool dogs) + { + Array input = dogs + ? new Dog[] { new Dog { Name = "Buddy" }, new Dog { Name = "Rocky" } } + : new Cat[] { new Cat { Name = "Luna" }, new Cat { Name = "Tiger" } }; + + SZArrayRecord root = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input)); + + Type expected = root.TypeName.GetElementType().FullName == typeof(Dog).FullName + ? typeof(Dog[]) + : typeof(Cat[]); + + SerializationRecord[] output = (SerializationRecord[])root.GetArray(expected); + Assert.All(output, record => record.TypeNameMatches(expected.GetElementType())); + } + + [Serializable] + public class Mammal + { + public string Name; + } + + [Serializable] + public class Cat : Mammal { } + + [Serializable] + public class Dog : Mammal { } } [Serializable] diff --git a/src/libraries/System.Formats.Nrbf/tests/ReadExactTypesTests.cs b/src/libraries/System.Formats.Nrbf/tests/ReadExactTypesTests.cs index ccf1dd402fc7b6..027293bce05a68 100644 --- a/src/libraries/System.Formats.Nrbf/tests/ReadExactTypesTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/ReadExactTypesTests.cs @@ -263,11 +263,11 @@ public void CanRead_ArraysOfComplexTypes() new () { Long = 5 }, ]; - SZArrayRecord arrayRecord = ((SZArrayRecord)NrbfDecoder.Decode(Serialize(input))); + SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input)); Assert.Equal(typeof(CustomTypeWithPrimitiveFields[]).FullName, arrayRecord.TypeName.FullName); Assert.Equal(typeof(CustomTypeWithPrimitiveFields).Assembly.FullName, arrayRecord.TypeName.GetElementType().AssemblyName!.FullName); - ClassRecord?[] classRecords = arrayRecord.GetArray(); + ClassRecord?[] classRecords = arrayRecord.GetArray().OfType().ToArray(); for (int i = 0; i < input.Length; i++) { Verify(input[i], classRecords[i]!); @@ -298,8 +298,8 @@ public void CanRead_TypesWithArraysOfComplexTypes() ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(Serialize(input)); - SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(CustomTypeWithArrayOfComplexTypes.Array))!; - ClassRecord?[] array = classRecords.GetArray(); + SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(CustomTypeWithArrayOfComplexTypes.Array))!; + SerializationRecord?[] array = classRecords.GetArray(); } [Theory] @@ -316,8 +316,8 @@ public void CanRead_TypesWithArraysOfComplexTypes_MultipleNulls(int nullCount) ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(stream); - SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(CustomTypeWithArrayOfComplexTypes.Array))!; - ClassRecord?[] array = classRecords.GetArray(); + SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(CustomTypeWithArrayOfComplexTypes.Array))!; + SerializationRecord?[] array = classRecords.GetArray(); Assert.Equal(nullCount, array.Length); Assert.All(array, Assert.Null); @@ -337,7 +337,10 @@ public void CanRead_ArraysOfObjects() Assert.Equal(typeof(object[]).FullName, arrayRecord.TypeName.FullName); Assert.Equal("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089", arrayRecord.TypeName.GetElementType().AssemblyName!.FullName); - Assert.Equal(input, ((SZArrayRecord)arrayRecord).GetArray()); + SerializationRecord?[] output = ((SZArrayRecord)arrayRecord).GetArray(); + Assert.Equal(input[0], ((PrimitiveTypeRecord)output[0]).Value); + Assert.Equal(input[1], ((PrimitiveTypeRecord)output[1]).Value); + Assert.Null(output[2]); } [Theory] @@ -348,7 +351,7 @@ public void CanRead_ArraysOfObjects_MultipleNulls(int nullCount) object?[] input = Enumerable.Repeat(null!, nullCount).ToArray(); ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input)); - object?[] output = ((SZArrayRecord)arrayRecord).GetArray(); + SerializationRecord?[] output = ((SZArrayRecord)arrayRecord).GetArray(); Assert.Equal(nullCount, output.Length); Assert.All(output, Assert.Null); @@ -374,9 +377,13 @@ public void CanRead_CustomTypeWithArrayOfObjects() }; ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(Serialize(input)); - SZArrayRecord arrayRecord = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(CustomTypeWithArrayOfObjects.Array))!; + SZArrayRecord arrayRecord = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(CustomTypeWithArrayOfObjects.Array))!; + SerializationRecord?[] output = arrayRecord.GetArray(); - Assert.Equal(input.Array, arrayRecord.GetArray()); + Assert.Equal(input.Array[0], ((PrimitiveTypeRecord)output[0]).Value); + Assert.Equal(input.Array[1], ((PrimitiveTypeRecord)output[1]).Value); + Assert.Equal(input.Array[2], ((PrimitiveTypeRecord)output[2]).Value); + Assert.Null(output[3]); } [Theory] diff --git a/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs b/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs index 0c7bd2045fa1f8..8e5ce021db1751 100644 --- a/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs @@ -8,6 +8,40 @@ namespace System.Formats.Nrbf.Tests; [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.IsBinaryFormatterSupported))] public abstract class ReadTests { + public static bool IsPatched +#if NET + => true; +#else + => s_isPatched.Value; + + private static readonly Lazy s_isPatched = new(GetIsPatched); + + private static bool GetIsPatched() + { + Tuple tuple = new Tuple(42, new byte[] { 1, 2, 3, 4 }); +#pragma warning disable SYSLIB0011 // Type or member is obsolete + BinaryFormatter formatter = new(); +#pragma warning restore SYSLIB0011 // Type or member is obsolete + using MemoryStream stream = new(); + + // This particular scenario is going to throw on Full Framework + // if given machine has not installed the July 2024 cumulative update preview: + // https://learn.microsoft.com/dotnet/framework/release-notes/2024/07-25-july-preview-cumulative-update + + try + { + formatter.Serialize(stream, tuple); + stream.Position = 0; + Tuple deserialized = (Tuple)formatter.Deserialize(stream); + return tuple.Item1.Equals(deserialized.Item1); + } + catch (Exception) + { + return false; + } + } +#endif + protected static MemoryStream Serialize(T instance) where T : notnull { MemoryStream ms = new(); diff --git a/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs b/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs index 3191d57ba807cd..b746faccbdb53f 100644 --- a/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/RectangularArraysTests.cs @@ -64,11 +64,18 @@ public void CanReadRectangularArraysOfObjects_2D() using MemoryStream stream = Serialize(array); ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(stream); + SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(array.GetType()); Verify(array, arrayRecord); Assert.True(arrayRecord.TypeNameMatches(typeof(object[,]))); Assert.False(arrayRecord.TypeNameMatches(typeof(int[,]))); - Assert.Equal(array, arrayRecord.GetArray(typeof(object[,]))); + + for (int i = 0; i < array.GetLength(0); i++) + { + Assert.Equal(array[i, 0], ((PrimitiveTypeRecord)output[i, 0]).Value); + Assert.Equal(array[i, 1], ((PrimitiveTypeRecord)output[i, 1]).Value); + Assert.Null(output[i, 2]); + } } [Serializable] @@ -176,7 +183,14 @@ public void CanReadRectangularArraysOfObjects_3D() Assert.True(arrayRecord.TypeNameMatches(typeof(object[,,]))); Assert.False(arrayRecord.TypeNameMatches(typeof(object[,]))); Assert.False(arrayRecord.TypeNameMatches(typeof(int[,,]))); - Assert.Equal(array, arrayRecord.GetArray(typeof(object[,,]))); + SerializationRecord[,,] output = (SerializationRecord[,,])arrayRecord.GetArray(typeof(object[,,])); + + for (int i = 0; i < array.GetLength(0); i++) + { + Assert.Equal(array[i, 0, 0], ((PrimitiveTypeRecord)output[i, 0, 0]).Value); + Assert.Equal(array[i, 1, 0], ((PrimitiveTypeRecord)output[i, 1, 0]).Value); + Assert.Null(output[i, 2, 0]); + } } [Serializable] @@ -223,13 +237,10 @@ public void CanReadRectangularArraysOfComplexTypes_3D() internal static void Verify(Array input, ArrayRecord arrayRecord) { Assert.Equal(input.Rank, arrayRecord.Lengths.Length); - long totalElementsCount = 1; for (int i = 0; i < input.Rank; i++) { Assert.Equal(input.GetLength(i), arrayRecord.Lengths[i]); - totalElementsCount *= input.GetLength(i); } - Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength); Assert.Equal(input.GetType().FullName, arrayRecord.TypeName.FullName); Assert.Equal(input.GetType().GetAssemblyNameIncludingTypeForwards(), arrayRecord.TypeName.AssemblyName!.FullName); } diff --git a/src/libraries/System.Formats.Nrbf/tests/TypeMatchTests.cs b/src/libraries/System.Formats.Nrbf/tests/TypeMatchTests.cs index e0827c1225b421..c4c8018cdc4643 100644 --- a/src/libraries/System.Formats.Nrbf/tests/TypeMatchTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/TypeMatchTests.cs @@ -362,7 +362,7 @@ private static void VerifySZArray(T input) where T : notnull } else { - Assert.True(arrayRecord is SZArrayRecord, userMessage: typeof(T).Name); + Assert.True(arrayRecord is SZArrayRecord, userMessage: typeof(T).Name); Assert.True(arrayRecord.TypeNameMatches(typeof(T[]))); Assert.Equal(arrayRecord.TypeName.GetElementType().AssemblyName.FullName, typeof(T).GetAssemblyNameIncludingTypeForwards()); } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs index 253f617892f8fb..a02723101d683d 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs @@ -70,6 +70,8 @@ private MsQuicApi(QUIC_API_TABLE* apiTable) internal static bool Tls13ClientMayBeDisabled { get; } #pragma warning disable CA1810 // Initialize all static fields in 'MsQuicApi' when those fields are declared and remove the explicit static constructor + [UnconditionalSuppressMessage("SingleFile", "IL3000: Avoid accessing Assembly file path when publishing as a single file", + Justification = "The code handles the Assembly.Location being null/empty by falling back to AppContext.BaseDirectory")] static MsQuicApi() { bool loaded = false; @@ -89,8 +91,23 @@ static MsQuicApi() if (OperatingSystem.IsWindows()) { - // Windows ships msquic in the assembly directory. - loaded = NativeLibrary.TryLoad(Interop.Libraries.MsQuic, typeof(MsQuicApi).Assembly, DllImportSearchPath.AssemblyDirectory, out msQuicHandle); +#pragma warning disable IL3000 // Avoid accessing Assembly file path when publishing as a single file + // Windows ships msquic in the assembly directory next to System.Net.Quic, so load that. + // For single-file deployments, the assembly location is an empty string so we fall back + // to AppContext.BaseDirectory which is the directory containing the single-file executable. + string path = typeof(MsQuicApi).Assembly.Location is string assemblyLocation && !string.IsNullOrEmpty(assemblyLocation) + ? System.IO.Path.GetDirectoryName(assemblyLocation)! + : AppContext.BaseDirectory; +#pragma warning restore IL3000 + + path = System.IO.Path.Combine(path, Interop.Libraries.MsQuic); + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(null, $"Attempting to load MsQuic from {path}"); + } + + loaded = NativeLibrary.TryLoad(path, typeof(MsQuicApi).Assembly, DllImportSearchPath.LegacyBehavior, out msQuicHandle); } else { diff --git a/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/ArrayRecordDeserializer.cs b/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/ArrayRecordDeserializer.cs index 0c8e5d00fe454b..18981a295abdce 100644 --- a/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/ArrayRecordDeserializer.cs +++ b/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/ArrayRecordDeserializer.cs @@ -121,8 +121,14 @@ internal override SerializationRecordId Continue() }; [RequiresUnreferencedCode("Calls System.Windows.Forms.BinaryFormat.BinaryFormattedObject.TypeResolver.GetType(TypeName)")] - internal static Array? GetSimpleBinaryArray(ArrayRecord arrayRecord, BinaryFormattedObject.ITypeResolver typeResolver) + internal static Array? GetRectangularArrayOfPrimitives(ArrayRecord arrayRecord, BinaryFormattedObject.ITypeResolver typeResolver) { + // Only rectangular, non-jagged BinaryArrayRecord can hit the lucky path below. + if (arrayRecord.Rank <= 1 || arrayRecord.TypeName.GetElementType().IsArray) + { + return null; + } + Type arrayRecordElementType = typeResolver.GetType(arrayRecord.TypeName.GetElementType()); Type elementType = arrayRecordElementType; while (elementType.IsArray) @@ -130,17 +136,12 @@ internal override SerializationRecordId Continue() elementType = elementType.GetElementType()!; } - if (!(HasBuiltInSupport(elementType) - || (Nullable.GetUnderlyingType(elementType) is Type nullable && HasBuiltInSupport(nullable)))) + if (!HasBuiltInSupport(elementType)) { return null; } - Type expectedArrayType = arrayRecord.Rank switch - { - 1 => arrayRecordElementType.MakeArrayType(), - _ => arrayRecordElementType.MakeArrayType(arrayRecord.Rank) - }; + Type expectedArrayType = arrayRecordElementType.MakeArrayType(arrayRecord.Rank); return arrayRecord.GetArray(expectedArrayType); diff --git a/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/Deserializer.cs b/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/Deserializer.cs index 8e763fe850d863..eee808bfc36061 100644 --- a/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/Deserializer.cs +++ b/src/libraries/System.Resources.Extensions/src/System/Resources/Extensions/BinaryFormat/Deserializer/Deserializer.cs @@ -229,7 +229,7 @@ object DeserializeNew(SerializationRecordId id) SerializationRecordType.MemberPrimitiveTyped => ((PrimitiveTypeRecord)record).Value, SerializationRecordType.ArraySingleString => ((SZArrayRecord)record).GetArray(), SerializationRecordType.ArraySinglePrimitive => ArrayRecordDeserializer.GetArraySinglePrimitive(record), - SerializationRecordType.BinaryArray => ArrayRecordDeserializer.GetSimpleBinaryArray((ArrayRecord)record, _typeResolver), + SerializationRecordType.BinaryArray => ArrayRecordDeserializer.GetRectangularArrayOfPrimitives((ArrayRecord)record, _typeResolver), _ => null }; diff --git a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/Common/MultidimensionalArrayTests.cs b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/Common/MultidimensionalArrayTests.cs index 8814a3184ef9e5..efac5078557647 100644 --- a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/Common/MultidimensionalArrayTests.cs +++ b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/Common/MultidimensionalArrayTests.cs @@ -52,6 +52,84 @@ public void IntegerArrays_Basic() Assert.Equal(threeDimensions, deserialized); } + [Serializable] + public class CustomComparable : IComparable, IEquatable + { + public int Integer; + + public int CompareTo(object? obj) + { + CustomComparable other = (CustomComparable)obj; + + return other.Integer.CompareTo(other.Integer); + } + + public bool Equals(CustomComparable? other) => Integer == other.Integer; + + public override int GetHashCode() => Integer; + + public override bool Equals(object? obj) => obj is CustomComparable other && Equals(other); + } + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Integers() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => x * y); + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Doubles() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => x * y / 10); + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Strings() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => $"{x},{y}"); + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Abstraction() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => x switch + { + 0 => x * y, // int + 1 => x + (double)y / 10, // double + 2 => $"{x},{y}", // string + _ => new CustomComparable() { Integer = x * y } + }); + + [Fact] + public void MultiDimensionalArrayOfMultiDimensionalArrays_Objects() + => MultiDimensionalArrayOfMultiDimensionalArrays(static (x, y) => x switch + { + 0 => x * y, // int + 1 => x + (double)y / 10, // double + 2 => $"{x},{y}", // string + _ => new CustomComparable() { Integer = x * y } + }); + + private static void MultiDimensionalArrayOfMultiDimensionalArrays(Func valueFactory) + { + TValue[,][,] input = new TValue[3, 3][,]; + for (int i = 0; i < input.GetLength(0); i++) + { + for (int j = 0; j < input.GetLength(1); j++) + { + TValue[,] contained = new TValue[i + 1, j + 1]; + for (int k = 0; k < contained.GetLength(0); k++) + { + for (int l = 0; l < contained.GetLength(1); l++) + { + contained[k, l] = valueFactory(k, l); + } + } + + input[i, j] = contained; + + object deserializedMd = Deserialize(Serialize(contained)); + Assert.Equal(contained, deserializedMd); + } + } + + object deserializedJagged = Deserialize(Serialize(input)); + Assert.Equal(input, deserializedJagged); + } + [Fact] public void EmptyDimensions() { diff --git a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/BinaryFormattedObjectTests.cs b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/BinaryFormattedObjectTests.cs index f59d2bd47c9ea3..b08874b1336203 100644 --- a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/BinaryFormattedObjectTests.cs +++ b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/BinaryFormattedObjectTests.cs @@ -38,9 +38,9 @@ public void ReadEmptyHashTable() ClassRecord systemClass = (ClassRecord)format[format.RootRecord.Id]; VerifyHashTable(systemClass, expectedVersion: 0, expectedHashSize: 3); - SZArrayRecord keys = (SZArrayRecord)systemClass.GetSerializationRecord("Keys")!; + SZArrayRecord keys = (SZArrayRecord)systemClass.GetSerializationRecord("Keys")!; Assert.Equal(0, keys.Length); - SZArrayRecord values = (SZArrayRecord)systemClass.GetSerializationRecord("Values")!; + SZArrayRecord values = (SZArrayRecord)systemClass.GetSerializationRecord("Values")!; Assert.Equal(0, values.Length); } @@ -77,12 +77,12 @@ public void ReadHashTableWithStringPair() ClassRecord systemClass = (ClassRecord)format[format.RootRecord.Id]; VerifyHashTable(systemClass, expectedVersion: 1, expectedHashSize: 3); - SZArrayRecord keys = (SZArrayRecord)format[systemClass.GetArrayRecord("Keys").Id]; + SZArrayRecord keys = (SZArrayRecord)format[systemClass.GetArrayRecord("Keys").Id]; Assert.Equal(1, keys.Length); - Assert.Equal("This", keys.GetArray().Single()); - SZArrayRecord values = (SZArrayRecord)format[systemClass.GetArrayRecord("Values").Id]; + Assert.Equal("This", ((PrimitiveTypeRecord)keys.GetArray().Single()).Value); + SZArrayRecord values = (SZArrayRecord)format[systemClass.GetArrayRecord("Values").Id]; Assert.Equal(1, values.Length); - Assert.Equal("That", values.GetArray().Single()); + Assert.Equal("That", ((PrimitiveTypeRecord)values.GetArray().Single()).Value); } [Fact] @@ -100,8 +100,9 @@ public void ReadHashTableWithRepeatedStrings() // The collections themselves get ids first before the strings do. // Everything in the second keys is a string reference. - SZArrayRecord array = (SZArrayRecord)systemClass.GetSerializationRecord("Keys")!; - Assert.Equivalent(new object[] { "TheOther", "That", "This" }, array.GetArray()); + SZArrayRecord arrayRecord = (SZArrayRecord)systemClass.GetSerializationRecord("Keys")!; + SerializationRecord[] array = arrayRecord.GetArray(); + Assert.Equivalent(new string[] { "TheOther", "That", "This" }, array.OfType>().Select(sr => sr.Value).ToArray()); } [Fact] @@ -119,11 +120,14 @@ public void ReadHashTableWithNullValues() // The collections themselves get ids first before the strings do. // Everything in the second keys is a string reference. - SZArrayRecord keys = (SZArrayRecord)systemClass.GetSerializationRecord("Keys")!; - Assert.Equivalent(new object[] { "Yowza", "Youza", "Meeza" }, keys.GetArray()); - - SZArrayRecord values = (SZArrayRecord)systemClass.GetSerializationRecord("Values")!; - Assert.Equal(new object?[] { null, null, null }, values.GetArray()); + SZArrayRecord keysRecord = (SZArrayRecord)systemClass.GetSerializationRecord("Keys")!; + SerializationRecord[] keysRecords = keysRecord.GetArray(); + Assert.Equivalent(new string[] { "Yowza", "Youza", "Meeza" }, keysRecords.OfType>().Select(sr => sr.Value).ToArray()); + + SZArrayRecord valuesRecord = (SZArrayRecord)systemClass.GetSerializationRecord("Values")!; + SerializationRecord[] valuesRecords = valuesRecord.GetArray(); + Assert.Equal(3, valuesRecords.Length); + Assert.All(valuesRecords, Assert.Null); } [Fact] diff --git a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/HashTableTests.cs b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/HashTableTests.cs index dad76abff8e916..f0b3b5adf83b8a 100644 --- a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/HashTableTests.cs +++ b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/HashTableTests.cs @@ -72,8 +72,10 @@ public void HashTable_CustomComparer() Assert.Equal("System.Collections.Hashtable", systemClass.TypeName.FullName); Assert.Equal("System.OrdinalComparer", systemClass.GetClassRecord("Comparer")!.TypeName.FullName); Assert.Equal("System.Resources.Extensions.Tests.FormattedObject.HashtableTests+CustomHashCodeProvider", systemClass.GetClassRecord("HashCodeProvider")!.TypeName.FullName); - Assert.True(systemClass.GetSerializationRecord("Keys") is SZArrayRecord); - Assert.True(systemClass.GetSerializationRecord("Values") is SZArrayRecord); + Assert.True(systemClass.GetSerializationRecord("Keys") is SZArrayRecord); + Assert.Equal(SerializationRecordType.ArraySingleObject, systemClass.GetSerializationRecord("Keys").RecordType); + Assert.True(systemClass.GetSerializationRecord("Values") is SZArrayRecord); + Assert.Equal(SerializationRecordType.ArraySingleObject, systemClass.GetSerializationRecord("Values").RecordType); } [Serializable] diff --git a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/ListTests.cs b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/ListTests.cs index 6d43e04498dd31..cb3fdc2927641a 100644 --- a/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/ListTests.cs +++ b/src/libraries/System.Resources.Extensions/tests/BinaryFormatTests/FormattedObject/ListTests.cs @@ -19,7 +19,8 @@ public void BinaryFormattedObject_ParseEmptyArrayList() VerifyArrayList((ClassRecord)format[format.RootRecord.Id]); - Assert.True(format[((ClassRecord)format.RootRecord).GetArrayRecord("_items").Id] is SZArrayRecord); + Assert.True(format[((ClassRecord)format.RootRecord).GetArrayRecord("_items").Id] is SZArrayRecord); + Assert.Equal(SerializationRecordType.ArraySingleObject, format[((ClassRecord)format.RootRecord).GetArrayRecord("_items").Id].RecordType); } private static void VerifyArrayList(ClassRecord systemClass) @@ -28,7 +29,8 @@ private static void VerifyArrayList(ClassRecord systemClass) Assert.Equal(typeof(ArrayList).FullName, systemClass.TypeName.FullName); Assert.Equal(["_items", "_size", "_version"], systemClass.MemberNames); - Assert.True(systemClass.GetSerializationRecord("_items") is SZArrayRecord); + Assert.True(systemClass.GetSerializationRecord("_items") is SZArrayRecord); + Assert.Equal(SerializationRecordType.ArraySingleObject, systemClass.GetSerializationRecord("_items").RecordType); } [Theory] @@ -43,9 +45,9 @@ public void BinaryFormattedObject_ParsePrimitivesArrayList(object value) ClassRecord listRecord = (ClassRecord)format[format.RootRecord.Id]; VerifyArrayList(listRecord); - SZArrayRecord array = (SZArrayRecord)format[listRecord.GetArrayRecord("_items").Id]; + SZArrayRecord array = (SZArrayRecord)format[listRecord.GetArrayRecord("_items").Id]; - Assert.Equal(new[] { value }, array.GetArray().Take(listRecord.GetInt32("_size"))); + Assert.Equal(value, ((PrimitiveTypeRecord)array.GetArray().Take(listRecord.GetInt32("_size")).Single()).Value); } [Fact] @@ -59,8 +61,8 @@ public void BinaryFormattedObject_ParseStringArrayList() ClassRecord listRecord = (ClassRecord)format[format.RootRecord.Id]; VerifyArrayList(listRecord); - SZArrayRecord array = (SZArrayRecord)format[listRecord.GetArrayRecord("_items").Id]; - Assert.Equal(new object[] { "JarJar" }, array.GetArray().Take(listRecord.GetInt32("_size"))); + SZArrayRecord array = (SZArrayRecord)format[listRecord.GetArrayRecord("_items").Id]; + Assert.Equal("JarJar", ((PrimitiveTypeRecord)array.GetArray().Take(listRecord.GetInt32("_size")).Single()).Value); } public static TheoryData ArrayList_Primitive_Data => new()