From 423c8c855e36f85e2e570e1248bded7c8d725653 Mon Sep 17 00:00:00 2001 From: Andrew Boyarshin Date: Sun, 18 Apr 2021 22:07:35 +0700 Subject: [PATCH] Ref fix WIP --- .../Marshallers/RefWrapperMarshaller.cs | 104 +++++++++++++++++ SharpGen/Generator/MarshallingRegistry.cs | 5 +- .../ReverseCallablePrologCodeGenerator.cs | 109 +++++++++++++----- 3 files changed, 185 insertions(+), 33 deletions(-) create mode 100644 SharpGen/Generator/Marshallers/RefWrapperMarshaller.cs diff --git a/SharpGen/Generator/Marshallers/RefWrapperMarshaller.cs b/SharpGen/Generator/Marshallers/RefWrapperMarshaller.cs new file mode 100644 index 00000000..16efb853 --- /dev/null +++ b/SharpGen/Generator/Marshallers/RefWrapperMarshaller.cs @@ -0,0 +1,104 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using SharpGen.Model; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace SharpGen.Generator.Marshallers +{ + internal sealed class RefWrapperMarshaller : MarshallerBase, IMarshaller + { + private readonly IMarshaller implementation; + + public RefWrapperMarshaller(GlobalNamespaceProvider globalNamespace, IMarshaller implementation) + : base(globalNamespace) + { + this.implementation = implementation ?? throw new ArgumentNullException(nameof(implementation)); + } + + public static bool IsApplicable(CsMarshalBase csElement) => + csElement is CsMarshalCallableBase {IsLocalByRef: true, IsArray: false}; + + public IEnumerable GenerateManagedToNativeProlog(CsMarshalCallableBase csElement) => + implementation.GenerateManagedToNativeProlog(csElement); + + public IEnumerable GenerateNativeToManagedExtendedProlog(CsMarshalCallableBase csElement) => + implementation.GenerateNativeToManagedExtendedProlog(csElement); + + public StatementSyntax GenerateManagedToNative(CsMarshalBase csElement, bool singleStackFrame) + { + var statement = implementation.GenerateManagedToNative(csElement, singleStackFrame); + + return csElement switch + { + CsParameter {IsOptional: true} parameter => GenerateManagedToNativeForOptional(parameter, statement), + _ => statement + }; + } + + private static StatementSyntax GenerateManagedToNativeForOptional(CsParameter parameter, + StatementSyntax statement) + { + var refIdentifier = IdentifierName(GetRefLocationIdentifier(parameter)); + + StatementSyntaxList statements = new() + { + statement, + IfStatement( + BinaryExpression( + SyntaxKind.NotEqualsExpression, + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + GlobalNamespaceProvider.GetTypeNameSyntax(BuiltinType.Unsafe), + IdentifierName(nameof(Unsafe.AsPointer)) + ), + ArgumentList( + SingletonSeparatedList( + Argument(refIdentifier).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)) + ) + ) + ), + LiteralExpression(SyntaxKind.DefaultLiteralExpression, Token(SyntaxKind.DefaultKeyword)) + ), + ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + refIdentifier, + IdentifierName(parameter.Name) + ) + ) + ) + }; + + return statements.ToStatement(); + } + + public StatementSyntax GenerateNativeToManaged(CsMarshalBase csElement, bool singleStackFrame) => + implementation.GenerateNativeToManaged(csElement, singleStackFrame); + + public ArgumentSyntax GenerateNativeArgument(CsMarshalCallableBase csElement) => + implementation.GenerateNativeArgument(csElement); + + public ArgumentSyntax GenerateManagedArgument(CsParameter csElement) => + implementation.GenerateManagedArgument(csElement); + + public ParameterSyntax GenerateManagedParameter(CsParameter csElement) => + implementation.GenerateManagedParameter(csElement); + + public StatementSyntax GenerateNativeCleanup(CsMarshalBase csElement, bool singleStackFrame) => + implementation.GenerateNativeCleanup(csElement, singleStackFrame); + + public FixedStatementSyntax GeneratePin(CsParameter csElement) => implementation.GeneratePin(csElement); + + public bool CanMarshal(CsMarshalBase csElement) => implementation.CanMarshal(csElement); + + public bool GeneratesMarshalVariable(CsMarshalCallableBase csElement) => + implementation.GeneratesMarshalVariable(csElement); + + public TypeSyntax GetMarshalTypeSyntax(CsMarshalBase csElement) => + implementation.GetMarshalTypeSyntax(csElement); + } +} \ No newline at end of file diff --git a/SharpGen/Generator/MarshallingRegistry.cs b/SharpGen/Generator/MarshallingRegistry.cs index bc6e2c2a..b5860d07 100644 --- a/SharpGen/Generator/MarshallingRegistry.cs +++ b/SharpGen/Generator/MarshallingRegistry.cs @@ -33,6 +33,7 @@ public MarshallingRegistry(GlobalNamespaceProvider globalNamespace, Logger logge new ValueTypeArrayMarshaller(globalNamespace), new ValueTypeMarshaller(globalNamespace) }; + WrappingMarshallers = Marshallers.Select(x => new RefWrapperMarshaller(globalNamespace, x)).ToArray(); RelationMarshallers = new Dictionary { { typeof(StructSizeRelation), new StructSizeRelationMarshaller(globalNamespace) }, @@ -43,12 +44,14 @@ public MarshallingRegistry(GlobalNamespaceProvider globalNamespace, Logger logge } private IReadOnlyList Marshallers { get; } + private IReadOnlyList WrappingMarshallers { get; } private IReadOnlyDictionary RelationMarshallers { get; } public IMarshaller GetMarshaller(CsMarshalBase csElement) { - var marshaller = Marshallers.FirstOrDefault(m => m.CanMarshal(csElement)); + var list = RefWrapperMarshaller.IsApplicable(csElement) ? WrappingMarshallers : Marshallers; + var marshaller = list.FirstOrDefault(m => m.CanMarshal(csElement)); if (marshaller != null) return marshaller; diff --git a/SharpGen/Generator/ReverseCallablePrologCodeGenerator.cs b/SharpGen/Generator/ReverseCallablePrologCodeGenerator.cs index 7b56b16d..acfeb271 100644 --- a/SharpGen/Generator/ReverseCallablePrologCodeGenerator.cs +++ b/SharpGen/Generator/ReverseCallablePrologCodeGenerator.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Runtime.CompilerServices; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using SharpGen.Generator.Marshallers; @@ -101,57 +103,100 @@ private IEnumerable GenerateNativeByRefProlog(CsMarshalCallable MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, GlobalNamespaceProvider.GetTypeNameSyntax(BuiltinType.Unsafe), - GenericName(Identifier(nameof(Unsafe.AsRef))) - .WithTypeArgumentList( - TypeArgumentList( - SingletonSeparatedList( - marshaller.GetMarshalTypeSyntax(publicElement)))))) - .WithArgumentList( - ArgumentList( - SingletonSeparatedList( - Argument( - nativeParameter)))); + GenericName( + Identifier(nameof(Unsafe.AsRef)), + TypeArgumentList(SingletonSeparatedList(marshaller.GetMarshalTypeSyntax(publicElement))) + ) + ), + ArgumentList(SingletonSeparatedList(Argument(nativeParameter))) + ); var publicType = ParseTypeName(publicElement.PublicType.QualifiedName); + var generatesMarshalVariable = marshaller.GeneratesMarshalVariable(publicElement); if (publicElement.IsLocalByRef) { - if (!marshaller.GeneratesMarshalVariable(publicElement)) - { - publicType = RefType(publicType); - } + Debug.Assert(marshaller is RefWrapperMarshaller); refToNativeExpression = RefExpression(refToNativeExpression); } + else + { + Debug.Assert(publicElement is CsParameter {IsRefIn: true}); + Debug.Assert(marshaller is not RefWrapperMarshaller); + + if (publicElement is CsParameter {IsOptional: true}) + { + var defaultLiteral = LiteralExpression( + SyntaxKind.DefaultLiteralExpression, + Token(SyntaxKind.DefaultKeyword) + ); + + refToNativeExpression = ConditionalExpression( + BinaryExpression(SyntaxKind.NotEqualsExpression, nativeParameter, defaultLiteral), + refToNativeExpression, + defaultLiteral + ); + } + } + + var refToNativeClause = EqualsValueClause(refToNativeExpression); + EqualsValueClauseSyntax publicParamInitializer = default; - if (marshaller.GeneratesMarshalVariable(publicElement)) + if (publicElement is CsParameter {IsOptional: true, IsLocalByRef: true} parameter) { - yield return LocalDeclarationStatement( + var refVariableDeclaration = LocalDeclarationStatement( VariableDeclaration( - RefType(marshaller.GetMarshalTypeSyntax(publicElement))) - .WithVariables( + RefType(marshaller.GetMarshalTypeSyntax(publicElement)), SingletonSeparatedList( - VariableDeclarator( - MarshallerBase.GetMarshalStorageLocationIdentifier(publicElement)) - .WithInitializer( - EqualsValueClause( - refToNativeExpression))))); + VariableDeclarator(MarshallerBase.GetRefLocationIdentifier(publicElement)) + .WithInitializer(refToNativeClause) + ) + ) + ); + if (generatesMarshalVariable && parameter is {IsRef: true}) + { + refVariableDeclaration = refVariableDeclaration.WithLeadingTrivia( + Comment("Optional ref parameter that requires generating marshal variable is unsupported.") + ); + } + else + { + refToNativeClause = default; + } + + yield return refVariableDeclaration; + } + + if (generatesMarshalVariable) + { yield return LocalDeclarationStatement( - VariableDeclaration(publicType) - .WithVariables( + VariableDeclaration( + RefType(marshaller.GetMarshalTypeSyntax(publicElement)), SingletonSeparatedList( - VariableDeclarator( - Identifier(publicElement.Name))))); + VariableDeclarator(MarshallerBase.GetMarshalStorageLocationIdentifier(publicElement)) + .WithInitializer(refToNativeClause) + ) + ) + ); } else { - yield return LocalDeclarationStatement( - VariableDeclaration(publicType) - .AddVariables( - VariableDeclarator(Identifier(publicElement.Name)) - .WithInitializer(EqualsValueClause(refToNativeExpression)))); + publicParamInitializer = refToNativeClause; + + if (publicElement is CsParameter {IsOptional: false, IsLocalByRef: true} or CsReturnValue) + publicType = RefType(publicType); } + + yield return LocalDeclarationStatement( + VariableDeclaration( + publicType, + SingletonSeparatedList( + VariableDeclarator(Identifier(publicElement.Name), default, publicParamInitializer) + ) + ) + ); } } }