From 44b1c324d54920056559d0659a62d6421f69935f Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Thu, 21 Nov 2024 19:38:08 +0900 Subject: [PATCH 1/8] feat: implement Speaker and Serdes - Adds protobuf code generation - Adds utilities and wrapping code around RPC message types - Adds RPC protocol code required for sending and receiving messages - Adds a end-to-end test with messages and replies between two Speakers --- .gitignore | 400 ++++++++++++++++++ .idea/.idea.Coder.Desktop/.idea/.name | 1 + .../.idea.Coder.Desktop/.idea/indexLayout.xml | 8 + .../.idea/projectSettingsUpdater.xml | 7 + .idea/.idea.Coder.Desktop/.idea/vcs.xml | 6 + .idea/.idea.Coder.Desktop/.idea/workspace.xml | 96 +++++ Coder.Desktop.sln | 29 ++ Coder.Desktop.sln.DotSettings | 2 + Rpc.Proto/Rpc.Proto.csproj | 22 + Rpc.Proto/RpcMessage.cs | 71 ++++ Rpc.Proto/vpn.proto | 199 +++++++++ Rpc/Rpc.csproj | 14 + Rpc/Serdes.cs | 93 ++++ Rpc/Speaker.cs | 369 ++++++++++++++++ Rpc/Version.cs | 71 ++++ Tests/Rpc/SpeakerTest.cs | 134 ++++++ Tests/Tests.csproj | 30 ++ 17 files changed, 1552 insertions(+) create mode 100644 .gitignore create mode 100644 .idea/.idea.Coder.Desktop/.idea/.name create mode 100644 .idea/.idea.Coder.Desktop/.idea/indexLayout.xml create mode 100644 .idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml create mode 100644 .idea/.idea.Coder.Desktop/.idea/vcs.xml create mode 100644 .idea/.idea.Coder.Desktop/.idea/workspace.xml create mode 100644 Coder.Desktop.sln create mode 100644 Coder.Desktop.sln.DotSettings create mode 100644 Rpc.Proto/Rpc.Proto.csproj create mode 100644 Rpc.Proto/RpcMessage.cs create mode 100644 Rpc.Proto/vpn.proto create mode 100644 Rpc/Rpc.csproj create mode 100644 Rpc/Serdes.cs create mode 100644 Rpc/Speaker.cs create mode 100644 Rpc/Version.cs create mode 100644 Tests/Rpc/SpeakerTest.cs create mode 100644 Tests/Tests.csproj diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..30e6bff --- /dev/null +++ b/.gitignore @@ -0,0 +1,400 @@ +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. +## +## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUnit +*.VisualState.xml +TestResult.xml +nunit-*.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +# but not Directory.Build.rsp, as it configures directory-level build defaults +!Directory.Build.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.log +*.tlog +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Coverlet is a free, cross platform Code Coverage Tool +coverage*.json +coverage*.xml +coverage*.info + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# NuGet Symbol Packages +*.snupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. +!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx +*.appxbundle +*.appxupload + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!?*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser +*- [Bb]ackup.rdl +*- [Bb]ackup ([0-9]).rdl +*- [Bb]ackup ([0-9][0-9]).rdl + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) +*.vbw + +# Visual Studio 6 auto-generated project file (contains which files were open etc.) +*.vbp + +# Visual Studio 6 workspace and project file (working project files containing files to include in project) +*.dsw +*.dsp + +# Visual Studio 6 technical files +*.ncb +*.aps + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + +# Visual Studio History (VSHistory) files +.vshistory/ + +# BeatPulse healthcheck temp database +healthchecksdb + +# Backup folder for Package Reference Convert tool in Visual Studio 2017 +MigrationBackup/ + +# Ionide (cross platform F# VS Code tools) working folder +.ionide/ + +# Fody - auto-generated XML schema +FodyWeavers.xsd + +# VS Code files for those working on multiple tools +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace + +# Local History for Visual Studio Code +.history/ + +# Windows Installer files from build outputs +*.cab +*.msi +*.msix +*.msm +*.msp + +# JetBrains Rider +*.sln.iml diff --git a/.idea/.idea.Coder.Desktop/.idea/.name b/.idea/.idea.Coder.Desktop/.idea/.name new file mode 100644 index 0000000..6e47b44 --- /dev/null +++ b/.idea/.idea.Coder.Desktop/.idea/.name @@ -0,0 +1 @@ +Coder.Desktop \ No newline at end of file diff --git a/.idea/.idea.Coder.Desktop/.idea/indexLayout.xml b/.idea/.idea.Coder.Desktop/.idea/indexLayout.xml new file mode 100644 index 0000000..7b08163 --- /dev/null +++ b/.idea/.idea.Coder.Desktop/.idea/indexLayout.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml b/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml new file mode 100644 index 0000000..64af657 --- /dev/null +++ b/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml @@ -0,0 +1,7 @@ + + + + + \ No newline at end of file diff --git a/.idea/.idea.Coder.Desktop/.idea/vcs.xml b/.idea/.idea.Coder.Desktop/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/.idea.Coder.Desktop/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/.idea.Coder.Desktop/.idea/workspace.xml b/.idea/.idea.Coder.Desktop/.idea/workspace.xml new file mode 100644 index 0000000..08bb464 --- /dev/null +++ b/.idea/.idea.Coder.Desktop/.idea/workspace.xml @@ -0,0 +1,96 @@ + + + + Rpc/Proto/Proto/Proto.csproj + + + + + + + + + + + + + + + + + + + + + + + + + + + + + { + "associatedIndex": 5 +} + + + + + + + { + "keyToString": { + "3ac978bb-8bbe-45e0-8b3e-92a2319baad6.executor": "Debug", + "RunOnceActivity.ShowReadmeOnStart": "true", + "RunOnceActivity.git.unshallow": "true", + "f59fc48a-1c71-44e4-ba56-e50869357557.executor": "Debug", + "git-widget-placeholder": "main", + "ignore.virus.scanning.warn.message": "true", + "node.js.detected.package.eslint": "true", + "node.js.detected.package.tslint": "true", + "node.js.selected.package.eslint": "(autodetect)", + "node.js.selected.package.tslint": "(autodetect)", + "nodejs_package_manager_path": "npm", + "settings.editor.selected.configurable": "preferences.sourceCode.C#", + "vue.rearranger.settings.migration": "true" + } +} + + + C:\Users\dean\AppData\Roaming\Subversion + + + + + 1732080436101 + + + + + + + + + + + - - \ No newline at end of file diff --git a/Rpc/Version.cs b/Rpc.Proto/ApiVersion.cs similarity index 93% rename from Rpc/Version.cs rename to Rpc.Proto/ApiVersion.cs index ee43cf6..16f19e4 100644 --- a/Rpc/Version.cs +++ b/Rpc.Proto/ApiVersion.cs @@ -1,4 +1,4 @@ -namespace Coder.Desktop.Rpc; +namespace Coder.Desktop.Rpc.Proto; /// /// Thrown when the two peers are incompatible with each other. @@ -16,9 +16,9 @@ public class ApiVersion(int major, int minor, params int[] additionalMajors) { public static readonly ApiVersion Current = new(1, 0); - public int Major { get; } = major; - public int Minor { get; } = minor; - public int[] AdditionalMajors { get; } = additionalMajors; + private int Major { get; } = major; + private int Minor { get; } = minor; + private int[] AdditionalMajors { get; } = additionalMajors; /// /// Parse a string in the format "major.minor" into an ApiVersion. diff --git a/Rpc.Proto/RpcHeader.cs b/Rpc.Proto/RpcHeader.cs new file mode 100644 index 0000000..48b4c29 --- /dev/null +++ b/Rpc.Proto/RpcHeader.cs @@ -0,0 +1,46 @@ +using System.Text; + +namespace Coder.Desktop.Rpc.Proto; + +/// +/// A header to write or read from a stream to identify the speaker's role and version. +/// +/// Role of the speaker +/// Version of the speaker +public class RpcHeader(RpcRole role, ApiVersion version) +{ + private const string Preamble = "codervpn"; + + public RpcRole Role { get; } = role; + public ApiVersion Version { get; } = version; + + /// + /// Parse a header string into a SpeakerHeader. + /// + /// Raw header string without trailing newline + /// Parsed header + /// Invalid header string + public static RpcHeader Parse(string header) + { + var parts = header.Split(' '); + if (parts.Length != 3) throw new ArgumentException($"Wrong number of parts in header string '{header}'"); + if (parts[0] != Preamble) throw new ArgumentException($"Invalid preamble in header string '{header}'"); + + var version = ApiVersion.ParseString(parts[1]); + var role = new RpcRole(parts[2]); + return new RpcHeader(role, version); + } + + /// + /// Construct a header string from the role and version with a trailing newline. + /// + public override string ToString() + { + return $"{Preamble} {Version} {Role}\n"; + } + + public ReadOnlyMemory ToBytes() + { + return Encoding.UTF8.GetBytes(ToString()); + } +} diff --git a/Rpc.Proto/RpcMessage.cs b/Rpc.Proto/RpcMessage.cs index 7502ebe..f80fff7 100644 --- a/Rpc.Proto/RpcMessage.cs +++ b/Rpc.Proto/RpcMessage.cs @@ -1,7 +1,14 @@ -using Google.Protobuf; +using System.Reflection; +using Google.Protobuf; namespace Coder.Desktop.Rpc.Proto; +[AttributeUsage(AttributeTargets.Class, Inherited = false)] +public class RpcRoleAttribute(string role) : Attribute +{ + public RpcRole Role { get; } = new(role); +} + /// /// Represents an actual over-the-wire message type. /// @@ -19,8 +26,22 @@ public abstract class RpcMessage where T : IMessage /// contents. /// public abstract T Message { get; } + + /// + /// Gets the RpcRole of the message type from it's RpcRole attribute. + /// + /// + /// The message type does not have an RpcRoleAttribute + public static RpcRole GetRole() + { + var type = typeof(T); + var attr = type.GetCustomAttribute(); + if (attr is null) throw new ArgumentException($"Message type {type} does not have a RpcRoleAttribute"); + return attr.Role; + } } +[RpcRole(RpcRole.Manager)] public partial class ManagerMessage : RpcMessage { public override RPC RpcField @@ -32,6 +53,7 @@ public override RPC RpcField public override ManagerMessage Message => this; } +[RpcRole(RpcRole.Tunnel)] public partial class TunnelMessage : RpcMessage { public override RPC RpcField diff --git a/Rpc.Proto/RpcRole.cs b/Rpc.Proto/RpcRole.cs new file mode 100644 index 0000000..063294a --- /dev/null +++ b/Rpc.Proto/RpcRole.cs @@ -0,0 +1,56 @@ +namespace Coder.Desktop.Rpc.Proto; + +/// +/// Represents a role that either side of the connection can fulfil. +/// +public sealed class RpcRole +{ + internal const string Manager = "manager"; + internal const string Tunnel = "tunnel"; + + public RpcRole(string role) + { + if (role != Manager && role != Tunnel) throw new ArgumentException($"Unknown role '{role}'"); + + Role = role; + } + + private string Role { get; } + + public override string ToString() + { + return Role; + } + + #region SpeakerRole equality + + public static bool operator ==(RpcRole a, RpcRole b) + { + return a.Equals(b); + } + + public static bool operator !=(RpcRole a, RpcRole b) + { + return !a.Equals(b); + } + + private bool Equals(RpcRole other) + { + return Role == other.Role; + } + + public override bool Equals(object? obj) + { + if (obj is null) return false; + if (ReferenceEquals(this, obj)) return true; + if (obj.GetType() != GetType()) return false; + return Equals((RpcRole)obj); + } + + public override int GetHashCode() + { + return Role.GetHashCode(); + } + + #endregion +} diff --git a/Rpc/Speaker.cs b/Rpc/Speaker.cs index fb8a535..d8b69d6 100644 --- a/Rpc/Speaker.cs +++ b/Rpc/Speaker.cs @@ -5,108 +5,6 @@ namespace Coder.Desktop.Rpc; -/// -/// Represents a role that either side of the connection can fulfil. -/// -public class SpeakerRole -{ - private const string ManagerString = "manager"; - private const string TunnelString = "tunnel"; - - public static readonly SpeakerRole Manager = new(ManagerString); - public static readonly SpeakerRole Tunnel = new(TunnelString); - - // TODO: it would be nice if we could expose this on the RpcMessage types instead - public SpeakerRole(string role) - { - if (role != ManagerString && role != TunnelString) throw new ArgumentException($"Unknown role '{role}'"); - - Role = role; - } - - public string Role { get; } - - public override string ToString() - { - return Role; - } - - #region SpeakerRole equality - - public static bool operator ==(SpeakerRole a, SpeakerRole b) - { - return a.Equals(b); - } - - public static bool operator !=(SpeakerRole a, SpeakerRole b) - { - return !a.Equals(b); - } - - private bool Equals(SpeakerRole other) - { - return Role == other.Role; - } - - public override bool Equals(object? obj) - { - if (obj is null) return false; - if (ReferenceEquals(this, obj)) return true; - if (obj.GetType() != GetType()) return false; - return Equals((SpeakerRole)obj); - } - - public override int GetHashCode() - { - return Role.GetHashCode(); - } - - #endregion -} - -/// -/// A header to write or read from a stream to identify the speaker's role and version. -/// -/// Role of the speaker -/// Version of the speaker -public class SpeakerHeader(SpeakerRole role, ApiVersion version) -{ - public const string Preamble = "codervpn"; - - public SpeakerRole Role { get; } = role; - public ApiVersion Version { get; } = version; - - /// - /// Parse a header string into a SpeakerHeader. - /// - /// Raw header string without trailing newline - /// Parsed header - /// Invalid header string - public static SpeakerHeader Parse(string header) - { - var parts = header.Split(' '); - if (parts.Length != 3) throw new ArgumentException($"Wrong number of parts in header string '{header}'"); - if (parts[0] != Preamble) throw new ArgumentException($"Invalid preamble in header string '{header}'"); - - var version = ApiVersion.ParseString(parts[1]); - var role = new SpeakerRole(parts[2]); - return new SpeakerHeader(role, version); - } - - /// - /// Construct a header string from the role and version with a trailing newline. - /// - public override string ToString() - { - return $"{Preamble} {Version} {Role}\n"; - } - - public ReadOnlyMemory ToBytes() - { - return Encoding.UTF8.GetBytes(ToString()); - } -} - /// /// Wraps a RpcMessage to allow easily sending a reply via the Speaker. /// @@ -138,9 +36,9 @@ public async Task SendReply(TS reply, CancellationToken ct = default) /// /// Manages an RPC connection between two peers, allowing messages to be sent and received. /// -/// The wrapped message type for sent messages -/// The wrapped message type for received messages -public class Speaker : IDisposable, IAsyncDisposable +/// The message type for sent messages +/// The message type for received messages +public class Speaker : IAsyncDisposable where TS : RpcMessage, IMessage where TR : RpcMessage, IMessage, new() { @@ -153,78 +51,53 @@ public class Speaker : IDisposable, IAsyncDisposable // _cts is cancelled when Dispose is called and will cause all ongoing I/O // operations to be cancelled. private readonly CancellationTokenSource _cts = new(); - private readonly SpeakerRole _me; - private readonly ConcurrentDictionary> _pendingReplies = new(); - private readonly Task _receiveTask; private readonly Serdes _serdes = new(); - private readonly SpeakerRole _them; // _lastMessageId is incremented using an atomic operation, and as such the // first message ID will actually be 1. private ulong _lastMessageId; + private Task? _receiveTask; /// - /// Instantiates a speaker, performs a handshake with the peer, and starts receiving messages. + /// Instantiates a speaker. The speaker will not perform any I/O until StartAsync is called. /// - /// Stream to use for I/O - will be automatically closed on ctor failure or Dispose - /// The local role - /// The remote role - /// Callback to fire on received messages (except replies) - /// Callback to fire on fatal receive errors - /// Could not complete handshake within 5s - /// Handshake failed - public Speaker(Stream conn, SpeakerRole me, SpeakerRole them, OnReceiveDelegate onReceive, OnErrorDelegate onError) + /// Stream to use for I/O + public Speaker(Stream conn) { _conn = conn; - _me = me; - _them = them; - Receive += onReceive; - Error += onError; - - // Handshake with a hard timeout of 5s. - var handshakeTask = PerformHandshake(); - handshakeTask.Wait(TimeSpan.FromSeconds(5)); - if (!handshakeTask.IsCompleted) - { - _conn.Dispose(); - throw new TimeoutException("RPC handshake timed out"); - } - - if (handshakeTask.IsFaulted) - { - _conn.Dispose(); - throw handshakeTask.Exception!; - } - - _receiveTask = ReceiveLoop(_cts.Token); - _receiveTask.ContinueWith(t => - { - if (t.IsFaulted) Error.Invoke(t.Exception!); - }); } public async ValueTask DisposeAsync() { await _cts.CancelAsync(); - await _receiveTask.WaitAsync(TimeSpan.FromSeconds(5)); + if (_receiveTask is not null) await _receiveTask.WaitAsync(TimeSpan.FromSeconds(5)); await _conn.DisposeAsync(); GC.SuppressFinalize(this); } - public void Dispose() + // TODO: do we want to do events API or channels API? + public event OnReceiveDelegate? Receive; + public event OnErrorDelegate? Error; + + /// + /// Performs a handshake with the peer and starts the async receive loop. The caller should attach it's Receive and + /// Error event handlers before calling this method. + /// + public async Task StartAsync(CancellationToken ct = default) { - _cts.Cancel(); - // Wait up to 5s for _receiveTask to finish, we don't really care about - // the result. - _receiveTask.Wait(TimeSpan.FromSeconds(5)); - _conn.Dispose(); - GC.SuppressFinalize(this); - } + // Handshakes should always finish quickly, so enforce a 5s timeout. + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); + cts.CancelAfter(TimeSpan.FromSeconds(5)); + await PerformHandshake(ct); - // TODO: do we want to do events API or channels API? - private event OnReceiveDelegate Receive; - private event OnErrorDelegate Error; + // Start ReceiveLoop in the background. + _receiveTask = ReceiveLoop(_cts.Token); + _ = _receiveTask.ContinueWith(t => + { + if (t.IsFaulted) Error?.Invoke(t.Exception!); + }, CancellationToken.None); + } private async Task PerformHandshake(CancellationToken ct = default) { @@ -234,15 +107,17 @@ private async Task PerformHandshake(CancellationToken ct = default) var readTask = ReadHeader(ct); await Task.WhenAll(writeTask, readTask); - var header = SpeakerHeader.Parse(await readTask); - if (header.Role != _them) throw new ArgumentException($"Expected peer role '{_them}' but got '{header.Role}'"); + var header = RpcHeader.Parse(await readTask); + var expectedRole = RpcMessage.GetRole(); + if (header.Role != expectedRole) + throw new ArgumentException($"Expected peer role '{expectedRole}' but got '{header.Role}'"); header.Version.Validate(ApiVersion.Current); } private async Task WriteHeader(CancellationToken ct = default) { - var header = new SpeakerHeader(_me, ApiVersion.Current); + var header = new RpcHeader(RpcMessage.GetRole(), ApiVersion.Current); await _conn.WriteAsync(header.ToBytes(), ct); } @@ -279,7 +154,7 @@ private async Task ReceiveLoop(CancellationToken ct = default) // TODO: we should log unknown replies // Start a new task in the background to handle the message. - _ = Task.Run(() => Receive.Invoke(new ReplyableRpcMessage(this, message)), ct); + _ = Task.Run(() => Receive?.Invoke(new ReplyableRpcMessage(this, message)), ct); } } catch (OperationCanceledException) @@ -288,7 +163,7 @@ private async Task ReceiveLoop(CancellationToken ct = default) } catch (Exception e) { - if (!ct.IsCancellationRequested) Error.Invoke(e); + if (!ct.IsCancellationRequested) Error?.Invoke(e); } } diff --git a/Tests/Rpc/SpeakerTest.cs b/Tests/Rpc/SpeakerTest.cs index 973e789..15ae53c 100644 --- a/Tests/Rpc/SpeakerTest.cs +++ b/Tests/Rpc/SpeakerTest.cs @@ -80,32 +80,28 @@ public async Task Ok() { var (stream1, stream2) = BidirectionalPipe.New(); + var speaker1 = new Speaker(stream1); var speaker1Ch = Channel .CreateUnbounded>(); + speaker1.Receive += msg => + { + Console.WriteLine($"speaker1 received message: {msg.RpcField.MsgId}"); + Assert.That(speaker1Ch.Writer.TryWrite(msg), Is.True); + }; + speaker1.Error += ex => { Assert.Fail($"speaker1 error: {ex}"); }; + + var speaker2 = new Speaker(stream2); var speaker2Ch = Channel .CreateUnbounded>(); + speaker2.Receive += msg => + { + Console.WriteLine($"speaker2 received message: {msg.RpcField.MsgId}"); + Assert.That(speaker2Ch.Writer.TryWrite(msg), Is.True); + }; + speaker2.Error += ex => { Assert.Fail($"speaker2 error: {ex}"); }; - // Start two speakers asynchronously as startup is blocking. - var speaker1Task = Task.Run(() => - new Speaker(stream1, - SpeakerRole.Manager, SpeakerRole.Tunnel, - tri => - { - Console.WriteLine($"speaker1 received message: {tri.RpcField.MsgId}"); - Assert.That(speaker1Ch.Writer.TryWrite(tri), Is.True); - }, ex => { Assert.Fail($"speaker1 error: {ex}"); })); - var speaker2Task = Task.Run(() => - new Speaker(stream2, - SpeakerRole.Tunnel, SpeakerRole.Manager, - tri => - { - Console.WriteLine($"speaker2 received message: {tri.RpcField.MsgId}"); - Assert.That(speaker2Ch.Writer.TryWrite(tri), Is.True); - }, ex => { Assert.Fail($"speaker2 error: {ex}"); })); - - Task.WaitAll(speaker1Task, speaker2Task); - await using var speaker1 = await speaker1Task; - await using var speaker2 = await speaker2Task; + // Start both speakers simultaneously + Task.WaitAll(speaker1.StartAsync(), speaker2.StartAsync()); var sendTask = speaker1.SendMessageAwaitReply(new ManagerMessage { From 634764c5859121d3a6bf497fdbd171b1f83b145a Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 26 Nov 2024 03:12:09 +0900 Subject: [PATCH 4/8] Most tests --- Rpc.Proto/ApiVersion.cs | 34 ++++++++++- Rpc.Proto/RpcHeader.cs | 2 +- Rpc.Proto/RpcMessage.cs | 2 +- Rpc.Proto/RpcRole.cs | 4 +- Rpc.Proto/vpn.proto | 1 - Rpc/Serdes.cs | 15 +++-- Rpc/Speaker.cs | 37 +++++++----- Tests/Rpc.Proto/ApiVersionTest.cs | 36 ++++++++++++ Tests/Rpc.Proto/RpcHeaderTest.cs | 41 +++++++++++++ Tests/Rpc.Proto/RpcMessageTest.cs | 39 +++++++++++++ Tests/Rpc.Proto/RpcRoleTest.cs | 22 +++++++ Tests/Rpc/SerdesTest.cs | 96 +++++++++++++++++++++++++++++++ Tests/Rpc/SpeakerTest.cs | 19 ++++-- 13 files changed, 321 insertions(+), 27 deletions(-) create mode 100644 Tests/Rpc.Proto/ApiVersionTest.cs create mode 100644 Tests/Rpc.Proto/RpcHeaderTest.cs create mode 100644 Tests/Rpc.Proto/RpcMessageTest.cs create mode 100644 Tests/Rpc.Proto/RpcRoleTest.cs create mode 100644 Tests/Rpc/SerdesTest.cs diff --git a/Rpc.Proto/ApiVersion.cs b/Rpc.Proto/ApiVersion.cs index 16f19e4..fcadd64 100644 --- a/Rpc.Proto/ApiVersion.cs +++ b/Rpc.Proto/ApiVersion.cs @@ -26,7 +26,7 @@ public class ApiVersion(int major, int minor, params int[] additionalMajors) /// Version string to parse /// Parsed ApiVersion /// The version string is invalid - public static ApiVersion ParseString(string versionString) + public static ApiVersion Parse(string versionString) { var parts = versionString.Split('.'); if (parts.Length != 2) throw new ArgumentException($"Invalid version string '{versionString}'"); @@ -68,4 +68,36 @@ public void Validate(ApiVersion other) if (AdditionalMajors.Any(major => other.Major == major)) return; throw new ApiCompatibilityException(this, other, "Version is no longer supported"); } + + #region ApiVersion Equality + + public static bool operator ==(ApiVersion a, ApiVersion b) + { + return a.Equals(b); + } + + public static bool operator !=(ApiVersion a, ApiVersion b) + { + return !a.Equals(b); + } + + private bool Equals(ApiVersion other) + { + return Major == other.Major && Minor == other.Minor && AdditionalMajors.SequenceEqual(other.AdditionalMajors); + } + + public override bool Equals(object? obj) + { + if (obj is null) return false; + if (ReferenceEquals(this, obj)) return true; + if (obj.GetType() != GetType()) return false; + return Equals((ApiVersion)obj); + } + + public override int GetHashCode() + { + return HashCode.Combine(Major, Minor, AdditionalMajors); + } + + #endregion } diff --git a/Rpc.Proto/RpcHeader.cs b/Rpc.Proto/RpcHeader.cs index 48b4c29..9e3bce5 100644 --- a/Rpc.Proto/RpcHeader.cs +++ b/Rpc.Proto/RpcHeader.cs @@ -26,7 +26,7 @@ public static RpcHeader Parse(string header) if (parts.Length != 3) throw new ArgumentException($"Wrong number of parts in header string '{header}'"); if (parts[0] != Preamble) throw new ArgumentException($"Invalid preamble in header string '{header}'"); - var version = ApiVersion.ParseString(parts[1]); + var version = ApiVersion.Parse(parts[1]); var role = new RpcRole(parts[2]); return new RpcHeader(role, version); } diff --git a/Rpc.Proto/RpcMessage.cs b/Rpc.Proto/RpcMessage.cs index f80fff7..6662a25 100644 --- a/Rpc.Proto/RpcMessage.cs +++ b/Rpc.Proto/RpcMessage.cs @@ -36,7 +36,7 @@ public static RpcRole GetRole() { var type = typeof(T); var attr = type.GetCustomAttribute(); - if (attr is null) throw new ArgumentException($"Message type {type} does not have a RpcRoleAttribute"); + if (attr is null) throw new ArgumentException($"Message type '{type}' does not have a RpcRoleAttribute"); return attr.Role; } } diff --git a/Rpc.Proto/RpcRole.cs b/Rpc.Proto/RpcRole.cs index 063294a..275da24 100644 --- a/Rpc.Proto/RpcRole.cs +++ b/Rpc.Proto/RpcRole.cs @@ -5,8 +5,8 @@ namespace Coder.Desktop.Rpc.Proto; /// public sealed class RpcRole { - internal const string Manager = "manager"; - internal const string Tunnel = "tunnel"; + public const string Manager = "manager"; + public const string Tunnel = "tunnel"; public RpcRole(string role) { diff --git a/Rpc.Proto/vpn.proto b/Rpc.Proto/vpn.proto index e56661a..dda973d 100644 --- a/Rpc.Proto/vpn.proto +++ b/Rpc.Proto/vpn.proto @@ -1,6 +1,5 @@ syntax = "proto3"; option go_package = "github.com/coder/coder/v2/vpn"; -// TODO: add this upstream option csharp_namespace = "Coder.Desktop.Rpc.Proto"; import "google/protobuf/timestamp.proto"; diff --git a/Rpc/Serdes.cs b/Rpc/Serdes.cs index f45f341..c7f5632 100644 --- a/Rpc/Serdes.cs +++ b/Rpc/Serdes.cs @@ -69,7 +69,6 @@ public async Task WriteMessage(Stream conn, TS message, CancellationToken ct = d /// Optional cancellation token /// Decoded message /// Could not decode the message - /// Could not cast the received message to the expected type public async Task ReadMessage(Stream conn, CancellationToken ct = default) { using var _ = await _readLock.LockAsync(ct); @@ -83,8 +82,16 @@ public async Task ReadMessage(Stream conn, CancellationToken ct = default) var msgBytes = new byte[len]; await conn.ReadExactlyAsync(msgBytes, ct); - var msg = _parser.ParseFrom(msgBytes); - if (msg == null) throw new IOException("Failed to parse message"); - return msg; + try + { + var msg = _parser.ParseFrom(msgBytes); + if (msg?.RpcField is null) + throw new IOException("Parsed message is empty or invalid"); + return msg; + } + catch (Exception e) + { + throw new IOException("Failed to parse message", e); + } } } diff --git a/Rpc/Speaker.cs b/Rpc/Speaker.cs index d8b69d6..d5e8b44 100644 --- a/Rpc/Speaker.cs +++ b/Rpc/Speaker.cs @@ -46,6 +46,17 @@ public class Speaker : IAsyncDisposable public delegate void OnReceiveDelegate(ReplyableRpcMessage message); + /// + /// Event that is triggered when a message is received. + /// + public event OnReceiveDelegate? Receive; + + /// + /// Event that is triggered when an error occurs. The handling code should dispose the Speaker after this event is + /// triggered. + /// + public event OnErrorDelegate? Error; + private readonly Stream _conn; // _cts is cancelled when Dispose is called and will cause all ongoing I/O @@ -70,16 +81,13 @@ public Speaker(Stream conn) public async ValueTask DisposeAsync() { + Error = null; await _cts.CancelAsync(); if (_receiveTask is not null) await _receiveTask.WaitAsync(TimeSpan.FromSeconds(5)); await _conn.DisposeAsync(); GC.SuppressFinalize(this); } - // TODO: do we want to do events API or channels API? - public event OnReceiveDelegate? Receive; - public event OnErrorDelegate? Error; - /// /// Performs a handshake with the peer and starts the async receive loop. The caller should attach it's Receive and /// Error event handlers before calling this method. @@ -87,7 +95,7 @@ public async ValueTask DisposeAsync() public async Task StartAsync(CancellationToken ct = default) { // Handshakes should always finish quickly, so enforce a 5s timeout. - using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); cts.CancelAfter(TimeSpan.FromSeconds(5)); await PerformHandshake(ct); @@ -174,23 +182,25 @@ private async Task ReceiveLoop(CancellationToken ct = default) /// Optional cancellation token public async Task SendMessage(TS message, CancellationToken ct = default) { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); message.RpcField = new RPC { MsgId = Interlocked.Add(ref _lastMessageId, 1), ResponseTo = 0, }; - await _serdes.WriteMessage(_conn, message, ct); + await _serdes.WriteMessage(_conn, message, cts.Token); } /// /// Send a message and wait for a reply. The reply will be returned and the callback will not be invoked as long as the /// reply is received before cancellation. /// - /// Message to send + /// Message to send - the Rpc field will be overwritten /// Optional cancellation token /// Received reply public async ValueTask SendMessageAwaitReply(TS message, CancellationToken ct = default) { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); message.RpcField = new RPC { MsgId = Interlocked.Add(ref _lastMessageId, 1), @@ -203,31 +213,32 @@ public async ValueTask SendMessageAwaitReply(TS message, CancellationToken c _pendingReplies[message.RpcField.MsgId] = tcs; try { - await _serdes.WriteMessage(_conn, message, ct); + await _serdes.WriteMessage(_conn, message, cts.Token); // Wait for the reply to be received. - return await tcs.Task.WaitAsync(ct); + return await tcs.Task.WaitAsync(cts.Token); } finally { // Clean up the pending reply if it was not received before - // cancellation. + // cancellation or another exception occurred. _pendingReplies.TryRemove(message.RpcField.MsgId, out _); } } /// - /// Sends a reply to a received request. + /// Sends a reply to a received message. /// - /// Message to reply to + /// Message to reply to - the Rpc field will be overwritten /// Reply message /// Optional cancellation token public async Task SendReply(TR originalMessage, TS reply, CancellationToken ct = default) { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); reply.RpcField = new RPC { MsgId = Interlocked.Add(ref _lastMessageId, 1), ResponseTo = originalMessage.RpcField.MsgId, }; - await _serdes.WriteMessage(_conn, reply, ct); + await _serdes.WriteMessage(_conn, reply, cts.Token); } } diff --git a/Tests/Rpc.Proto/ApiVersionTest.cs b/Tests/Rpc.Proto/ApiVersionTest.cs new file mode 100644 index 0000000..bf33f27 --- /dev/null +++ b/Tests/Rpc.Proto/ApiVersionTest.cs @@ -0,0 +1,36 @@ +using Coder.Desktop.Rpc.Proto; + +namespace Coder.Desktop.Tests.Rpc.Proto; + +[TestFixture] +public class ApiVersionTest +{ + [Test(Description = "Parse a variety of version strings")] + public void Parse() + { + Assert.That(ApiVersion.Parse("2.1"), Is.EqualTo(new ApiVersion(2, 1))); + Assert.That(ApiVersion.Parse("1.0"), Is.EqualTo(new ApiVersion(1, 0))); + + Assert.Throws(() => ApiVersion.Parse("cats")); + Assert.Throws(() => ApiVersion.Parse("cats.dogs")); + Assert.Throws(() => ApiVersion.Parse("1.dogs")); + Assert.Throws(() => ApiVersion.Parse("1.0.1")); + Assert.Throws(() => ApiVersion.Parse("11")); + } + + [Test(Description = "Test that versions are compatible")] + public void Validate() + { + var twoOne = new ApiVersion(2, 1, 1); + Assert.DoesNotThrow(() => twoOne.Validate(twoOne)); + Assert.DoesNotThrow(() => twoOne.Validate(new ApiVersion(2, 0))); + Assert.DoesNotThrow(() => twoOne.Validate(new ApiVersion(1, 0))); + + var ex = Assert.Throws(() => twoOne.Validate(new ApiVersion(2, 2))); + Assert.That(ex.Message, Does.Contain("Peer supports newer minor version")); + ex = Assert.Throws(() => twoOne.Validate(new ApiVersion(3, 1))); + Assert.That(ex.Message, Does.Contain("Peer supports newer major version")); + ex = Assert.Throws(() => twoOne.Validate(new ApiVersion(0, 8))); + Assert.That(ex.Message, Does.Contain("Version is no longer supported")); + } +} diff --git a/Tests/Rpc.Proto/RpcHeaderTest.cs b/Tests/Rpc.Proto/RpcHeaderTest.cs new file mode 100644 index 0000000..69dd627 --- /dev/null +++ b/Tests/Rpc.Proto/RpcHeaderTest.cs @@ -0,0 +1,41 @@ +using System.Text; +using Coder.Desktop.Rpc.Proto; + +namespace Coder.Desktop.Tests.Rpc.Proto; + +[TestFixture] +public class RpcHeaderTest +{ + [Test(Description = "Parse and use some valid header strings")] + public void Valid() + { + var headerStr = "codervpn 2.1 manager"; + var header = RpcHeader.Parse(headerStr); + Assert.That(header.Role.ToString(), Is.EqualTo(RpcRole.Manager)); + Assert.That(header.Version, Is.EqualTo(new ApiVersion(2, 1))); + Assert.That(header.ToString(), Is.EqualTo(headerStr + "\n")); + Assert.That(header.ToBytes().ToArray(), Is.EqualTo(Encoding.UTF8.GetBytes(headerStr + "\n"))); + + headerStr = "codervpn 1.0 tunnel"; + header = RpcHeader.Parse(headerStr); + Assert.That(header.Role.ToString(), Is.EqualTo(RpcRole.Tunnel)); + Assert.That(header.Version, Is.EqualTo(new ApiVersion(1, 0))); + Assert.That(header.ToString(), Is.EqualTo(headerStr + "\n")); + Assert.That(header.ToBytes().ToArray(), Is.EqualTo(Encoding.UTF8.GetBytes(headerStr + "\n"))); + } + + [Test(Description = "Try to parse some invalid header strings")] + public void ParseInvalid() + { + var ex = Assert.Throws(() => RpcHeader.Parse("codervpn")); + Assert.That(ex.Message, Does.Contain("Wrong number of parts")); + ex = Assert.Throws(() => RpcHeader.Parse("codervpn 1.0 manager cats")); + Assert.That(ex.Message, Does.Contain("Wrong number of parts")); + ex = Assert.Throws(() => RpcHeader.Parse("codervpn 1.0")); + Assert.That(ex.Message, Does.Contain("Wrong number of parts")); + ex = Assert.Throws(() => RpcHeader.Parse("cats 1.0 manager")); + Assert.That(ex.Message, Does.Contain("Invalid preamble")); + ex = Assert.Throws(() => RpcHeader.Parse("codervpn 1.0 cats")); + Assert.That(ex.Message, Does.Contain("Unknown role 'cats'")); + } +} diff --git a/Tests/Rpc.Proto/RpcMessageTest.cs b/Tests/Rpc.Proto/RpcMessageTest.cs new file mode 100644 index 0000000..9f9c73f --- /dev/null +++ b/Tests/Rpc.Proto/RpcMessageTest.cs @@ -0,0 +1,39 @@ +using Coder.Desktop.Rpc.Proto; + +namespace Coder.Desktop.Tests.Rpc.Proto; + +[TestFixture] +public class RpcRoleAttributeTest +{ + [Test] + public void Valid() + { + var role = new RpcRoleAttribute(RpcRole.Manager); + Assert.That(role.Role.ToString(), Is.EqualTo(RpcRole.Manager)); + role = new RpcRoleAttribute(RpcRole.Tunnel); + Assert.That(role.Role.ToString(), Is.EqualTo(RpcRole.Tunnel)); + } + + [Test] + public void Invalid() + { + Assert.Throws(() => _ = new RpcRoleAttribute("cats")); + } +} + +[TestFixture] +public class RpcMessageTest +{ + [Test] + public void GetRole() + { + // RpcMessage is not a supported message type and doesn't have an + // RpcRoleAttribute + var ex = Assert.Throws(() => _ = RpcMessage.GetRole()); + Assert.That(ex.Message, + Does.Contain("Message type 'Coder.Desktop.Rpc.Proto.RPC' does not have a RpcRoleAttribute")); + + Assert.That(ManagerMessage.GetRole().ToString(), Is.EqualTo(RpcRole.Manager)); + Assert.That(TunnelMessage.GetRole().ToString(), Is.EqualTo(RpcRole.Tunnel)); + } +} diff --git a/Tests/Rpc.Proto/RpcRoleTest.cs b/Tests/Rpc.Proto/RpcRoleTest.cs new file mode 100644 index 0000000..59ad489 --- /dev/null +++ b/Tests/Rpc.Proto/RpcRoleTest.cs @@ -0,0 +1,22 @@ +using Coder.Desktop.Rpc.Proto; + +namespace Coder.Desktop.Tests.Rpc.Proto; + +[TestFixture] +public class RpcRoleTest +{ + [Test(Description = "Instantiate a RpcRole with a valid name")] + public void ValidRole() + { + var role = new RpcRole(RpcRole.Manager); + Assert.That(role.ToString(), Is.EqualTo(RpcRole.Manager)); + role = new RpcRole(RpcRole.Tunnel); + Assert.That(role.ToString(), Is.EqualTo(RpcRole.Tunnel)); + } + + [Test(Description = "Try to instantiate a RpcRole with an invalid name")] + public void InvalidRole() + { + Assert.Throws(() => _ = new RpcRole("cats")); + } +} diff --git a/Tests/Rpc/SerdesTest.cs b/Tests/Rpc/SerdesTest.cs new file mode 100644 index 0000000..ce7cfce --- /dev/null +++ b/Tests/Rpc/SerdesTest.cs @@ -0,0 +1,96 @@ +using System.Buffers.Binary; +using Coder.Desktop.Rpc; +using Coder.Desktop.Rpc.Proto; + +namespace Coder.Desktop.Tests.Rpc; + +[TestFixture] +public class SerdesTest +{ + [Test(Description = "Tests that writing and reading a message works")] + [Timeout(5_000)] + public async Task WriteReadMessage() + { + var (stream1, stream2) = BidirectionalPipe.New(); + var serdes = new Serdes(); + + var msg = new ManagerMessage + { + Rpc = new RPC + { + MsgId = 1, + }, + }; + await serdes.WriteMessage(stream1, msg); + var got = await serdes.ReadMessage(stream2); + Assert.That(msg, Is.EqualTo(got)); + } + + [Test(Description = "Tests that writing a message larger than 16 MiB throws an exception")] + [Timeout(5_000)] + public void WriteMessageTooLarge() + { + var (stream1, _) = BidirectionalPipe.New(); + var serdes = new Serdes(); + + var msg = new ManagerMessage + { + Rpc = new RPC + { + MsgId = 1, + }, + Start = new StartRequest + { + ApiToken = new string('a', 0x1000001), + CoderUrl = "test", + }, + }; + Assert.ThrowsAsync(() => serdes.WriteMessage(stream1, msg)); + } + + [Test(Description = "Tests that attempting to read a message larger than 16 MiB throws an exception")] + [Timeout(5_000)] + public async Task ReadMessageTooLarge() + { + var (stream1, stream2) = BidirectionalPipe.New(); + var serdes = new Serdes(); + + // In this test we don't actually write a message as the parser should + // bail out immediately after reading the message length + var lenBytes = new byte[4]; + BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 0x1000001); + await stream1.WriteAsync(lenBytes); + Assert.ThrowsAsync(() => serdes.ReadMessage(stream2)); + } + + [Test(Description = "Read an empty (size 0) message from the stream")] + [Timeout(5_000)] + public async Task ReadEmptyMessage() + { + var (stream1, stream2) = BidirectionalPipe.New(); + var serdes = new Serdes(); + + // Write an empty message. + var lenBytes = new byte[4]; + BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 0); + await stream1.WriteAsync(lenBytes); + var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2)); + Assert.That(ex.InnerException, Is.Not.Null); + Assert.That(ex.InnerException?.Message, Does.Contain("Parsed message is empty or invalid")); + } + + [Test(Description = "Read an invalid/corrupt message from the stream")] + [Timeout(5_000)] + public async Task ReadInvalidMessage() + { + var (stream1, stream2) = BidirectionalPipe.New(); + var serdes = new Serdes(); + + var lenBytes = new byte[4]; + BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 1); + await stream1.WriteAsync(lenBytes); + await stream1.WriteAsync(new byte[1]); + var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2)); + Assert.That(ex.Message, Does.Not.Contain("Parsed message is empty or invalid")); + } +} diff --git a/Tests/Rpc/SpeakerTest.cs b/Tests/Rpc/SpeakerTest.cs index 15ae53c..c8ce60e 100644 --- a/Tests/Rpc/SpeakerTest.cs +++ b/Tests/Rpc/SpeakerTest.cs @@ -6,6 +6,8 @@ namespace Coder.Desktop.Tests.Rpc; +#region BidrectionalPipe + internal class BidirectionalPipe(PipeReader reader, PipeWriter writer) : Stream { public override bool CanRead => true; @@ -73,14 +75,18 @@ protected override void Dispose(bool disposing) } } +#endregion + +[TestFixture] public class SpeakerTest { - [Test] - public async Task Ok() + [Test(Description = "Send a message from speaker1 to speaker2, receive it, and send a reply back")] + [Timeout(30_000)] + public async Task SendReceiveReplyReceive() { var (stream1, stream2) = BidirectionalPipe.New(); - var speaker1 = new Speaker(stream1); + await using var speaker1 = new Speaker(stream1); var speaker1Ch = Channel .CreateUnbounded>(); speaker1.Receive += msg => @@ -90,7 +96,7 @@ public async Task Ok() }; speaker1.Error += ex => { Assert.Fail($"speaker1 error: {ex}"); }; - var speaker2 = new Speaker(stream2); + await using var speaker2 = new Speaker(stream2); var speaker2Ch = Channel .CreateUnbounded>(); speaker2.Receive += msg => @@ -103,6 +109,7 @@ public async Task Ok() // Start both speakers simultaneously Task.WaitAll(speaker1.StartAsync(), speaker2.StartAsync()); + // Send a message from speaker1 to speaker2 in the background var sendTask = speaker1.SendMessageAwaitReply(new ManagerMessage { Start = new StartRequest @@ -112,8 +119,11 @@ public async Task Ok() }, }); + // Receive the message in speaker2 var message = await speaker2Ch.Reader.ReadAsync(); Assert.That(message.Message.Start.ApiToken, Is.EqualTo("test")); + + // Send a reply back to speaker1 await message.SendReply(new TunnelMessage { Start = new StartResponse @@ -122,6 +132,7 @@ await message.SendReply(new TunnelMessage }, }); + // Receive the reply in speaker1 by awaiting sendTask var reply = await sendTask; Assert.That(reply.Message.Start.Success, Is.True); } From d095896b2b3e86be2beb84b55b4f132ce06857de Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 26 Nov 2024 16:10:42 +0900 Subject: [PATCH 5/8] More speaker tests --- Coder.Desktop.sln.DotSettings | 253 +++++++++++++++++++++++++++++++++ Rpc/Speaker.cs | 30 ++-- Rpc/Utilities/TaskUtilities.cs | 59 ++++++++ Tests/Rpc/SpeakerTest.cs | 239 +++++++++++++++++++++++++++++++ 4 files changed, 567 insertions(+), 14 deletions(-) create mode 100644 Rpc/Utilities/TaskUtilities.cs diff --git a/Coder.Desktop.sln.DotSettings b/Coder.Desktop.sln.DotSettings index 324ee6d..3859343 100644 --- a/Coder.Desktop.sln.DotSettings +++ b/Coder.Desktop.sln.DotSettings @@ -1,2 +1,255 @@  + <Patterns xmlns="urn:schemas-jetbrains-com:member-reordering-patterns"> + <TypePattern DisplayName="Non-reorderable types" Priority="99999999"> + <TypePattern.Match> + <Or> + <And> + <Kind Is="Interface" /> + <Or> + <HasAttribute Name="System.Runtime.InteropServices.InterfaceTypeAttribute" /> + <HasAttribute Name="System.Runtime.InteropServices.ComImport" /> + </Or> + </And> + <Kind Is="Struct" /> + <HasAttribute Name="System.Runtime.InteropServices.StructLayoutAttribute" /> + <HasAttribute Name="JetBrains.Annotations.NoReorderAttribute" /> + </Or> + </TypePattern.Match> + </TypePattern> + + <TypePattern DisplayName="xUnit.net Test Classes" RemoveRegions="All"> + <TypePattern.Match> + <And> + <Kind Is="Class" /> + <HasMember> + <And> + <Kind Is="Method" /> + <HasAttribute Name="Xunit.FactAttribute" Inherited="True" /> + <HasAttribute Name="Xunit.TheoryAttribute" Inherited="True" /> + </And> + </HasMember> + </And> + </TypePattern.Match> + + <Entry DisplayName="Fields"> + <Entry.Match> + <And> + <Kind Is="Field" /> + <Not> + <Static /> + </Not> + </And> + </Entry.Match> + + <Entry.SortBy> + <Readonly /> + <Name /> + </Entry.SortBy> + </Entry> + + <Entry DisplayName="Constructors"> + <Entry.Match> + <Kind Is="Constructor" /> + </Entry.Match> + + <Entry.SortBy> + <Static/> + </Entry.SortBy> + </Entry> + + <Entry DisplayName="Teardown Methods"> + <Entry.Match> + <And> + <Kind Is="Method" /> + <ImplementsInterface Name="System.IDisposable" /> + </And> + </Entry.Match> + </Entry> + + <Entry DisplayName="All other members" /> + + <Entry DisplayName="Test Methods" Priority="100"> + <Entry.Match> + <And> + <Kind Is="Method" /> + <HasAttribute Name="Xunit.FactAttribute" Inherited="false" /> + <HasAttribute Name="Xunit.TheoryAttribute" Inherited="false" /> + </And> + </Entry.Match> + + <Entry.SortBy> + <Name /> + </Entry.SortBy> + </Entry> + </TypePattern> + + <TypePattern DisplayName="NUnit Test Fixtures" RemoveRegions="All"> + <TypePattern.Match> + <And> + <Kind Is="Class" /> + <Or> + <HasAttribute Name="NUnit.Framework.TestFixtureAttribute" Inherited="true" /> + <HasAttribute Name="NUnit.Framework.TestFixtureSourceAttribute" Inherited="true" /> + <HasMember> + <And> + <Kind Is="Method" /> + <HasAttribute Name="NUnit.Framework.TestAttribute" Inherited="false" /> + <HasAttribute Name="NUnit.Framework.TestCaseAttribute" Inherited="false" /> + <HasAttribute Name="NUnit.Framework.TestCaseSourceAttribute" Inherited="false" /> + </And> + </HasMember> + </Or> + </And> + </TypePattern.Match> + + <Entry DisplayName="Setup/Teardown Methods"> + <Entry.Match> + <And> + <Kind Is="Method" /> + <Or> + <HasAttribute Name="NUnit.Framework.SetUpAttribute" Inherited="true" /> + <HasAttribute Name="NUnit.Framework.TearDownAttribute" Inherited="true" /> + <HasAttribute Name="NUnit.Framework.TestFixtureSetUpAttribute" Inherited="true" /> + <HasAttribute Name="NUnit.Framework.TestFixtureTearDownAttribute" Inherited="true" /> + <HasAttribute Name="NUnit.Framework.OneTimeSetUpAttribute" Inherited="true" /> + <HasAttribute Name="NUnit.Framework.OneTimeTearDownAttribute" Inherited="true" /> + </Or> + </And> + </Entry.Match> + </Entry> + + <Entry DisplayName="All other members" /> + + <Entry DisplayName="Test Methods" Priority="100"> + <Entry.Match> + <And> + <Kind Is="Method" /> + <HasAttribute Name="NUnit.Framework.TestAttribute" Inherited="false" /> + <HasAttribute Name="NUnit.Framework.TestCaseAttribute" Inherited="false" /> + <HasAttribute Name="NUnit.Framework.TestCaseSourceAttribute" Inherited="false" /> + </And> + </Entry.Match> + + <Entry.SortBy> + <Name /> + </Entry.SortBy> + </Entry> + </TypePattern> + + <TypePattern DisplayName="Default Pattern"> + <Entry DisplayName="Public Delegates" Priority="100"> + <Entry.Match> + <And> + <Access Is="Public" /> + <Kind Is="Delegate" /> + </And> + </Entry.Match> + + <Entry.SortBy> + <Name /> + </Entry.SortBy> + </Entry> + + <Entry DisplayName="Public Enums" Priority="100"> + <Entry.Match> + <And> + <Access Is="Public" /> + <Kind Is="Enum" /> + </And> + </Entry.Match> + + <Entry.SortBy> + <Name /> + </Entry.SortBy> + </Entry> + + <Entry DisplayName="Static Fields and Constants"> + <Entry.Match> + <Or> + <Kind Is="Constant" /> + <And> + <Kind Is="Field" /> + <Static /> + </And> + </Or> + </Entry.Match> + + <Entry.SortBy> + <Kind> + <Kind.Order> + <DeclarationKind>Constant</DeclarationKind> + <DeclarationKind>Field</DeclarationKind> + </Kind.Order> + </Kind> + </Entry.SortBy> + </Entry> + + <Entry DisplayName="Fields"> + <Entry.Match> + <And> + <Kind Is="Field" /> + <Not> + <Static /> + </Not> + </And> + </Entry.Match> + + <Entry.SortBy> + <Readonly /> + <Name /> + </Entry.SortBy> + </Entry> + + <Entry DisplayName="Events"> + <Entry.Match> + <Kind Is="Event" /> + </Entry.Match> + + <Entry.SortBy> + <Name /> + </Entry.SortBy> + </Entry> + + <Entry DisplayName="Constructors"> + <Entry.Match> + <Kind Is="Constructor" /> + </Entry.Match> + + <Entry.SortBy> + <Static/> + </Entry.SortBy> + </Entry> + + <Entry DisplayName="Properties, Indexers"> + <Entry.Match> + <Or> + <Kind Is="Property" /> + <Kind Is="Indexer" /> + </Or> + </Entry.Match> + </Entry> + + <Entry DisplayName="Interface Implementations" Priority="100"> + <Entry.Match> + <And> + <Kind Is="Member" /> + <ImplementsInterface /> + </And> + </Entry.Match> + + <Entry.SortBy> + <ImplementsInterface Immediate="true" /> + </Entry.SortBy> + </Entry> + + <Entry DisplayName="All other members" /> + + <Entry DisplayName="Nested Types"> + <Entry.Match> + <Kind Is="Type" /> + </Entry.Match> + </Entry> + </TypePattern> +</Patterns> + True \ No newline at end of file diff --git a/Rpc/Speaker.cs b/Rpc/Speaker.cs index d5e8b44..a6cc3bd 100644 --- a/Rpc/Speaker.cs +++ b/Rpc/Speaker.cs @@ -1,6 +1,7 @@ using System.Collections.Concurrent; using System.Text; using Coder.Desktop.Rpc.Proto; +using Coder.Desktop.Rpc.Utilities; using Google.Protobuf; namespace Coder.Desktop.Rpc; @@ -46,17 +47,6 @@ public class Speaker : IAsyncDisposable public delegate void OnReceiveDelegate(ReplyableRpcMessage message); - /// - /// Event that is triggered when a message is received. - /// - public event OnReceiveDelegate? Receive; - - /// - /// Event that is triggered when an error occurs. The handling code should dispose the Speaker after this event is - /// triggered. - /// - public event OnErrorDelegate? Error; - private readonly Stream _conn; // _cts is cancelled when Dispose is called and will cause all ongoing I/O @@ -70,6 +60,17 @@ public class Speaker : IAsyncDisposable private ulong _lastMessageId; private Task? _receiveTask; + /// + /// Event that is triggered when an error occurs. The handling code should dispose the Speaker after this event is + /// triggered. + /// + public event OnErrorDelegate? Error; + + /// + /// Event that is triggered when a message is received. + /// + public event OnReceiveDelegate? Receive; + /// /// Instantiates a speaker. The speaker will not perform any I/O until StartAsync is called. /// @@ -111,9 +112,10 @@ private async Task PerformHandshake(CancellationToken ct = default) { // Simultaneously write the header string and read the header string in // case the conn is not buffered. - var writeTask = WriteHeader(ct); - var readTask = ReadHeader(ct); - await Task.WhenAll(writeTask, readTask); + var headerCts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); + var writeTask = WriteHeader(headerCts.Token); + var readTask = ReadHeader(headerCts.Token); + await TaskUtilities.CancellableWhenAll(headerCts, writeTask, readTask); var header = RpcHeader.Parse(await readTask); var expectedRole = RpcMessage.GetRole(); diff --git a/Rpc/Utilities/TaskUtilities.cs b/Rpc/Utilities/TaskUtilities.cs new file mode 100644 index 0000000..65f382d --- /dev/null +++ b/Rpc/Utilities/TaskUtilities.cs @@ -0,0 +1,59 @@ +namespace Coder.Desktop.Rpc.Utilities; + +internal static class TaskUtilities +{ + /// + /// Waits for all tasks to complete, but cancels the provided CancellationTokenSource if any task is canceled or + /// faulted. The first cancel or fault will be propagated to the returned Task. All passed in tasks must be using the + /// same CancellationTokenSource. + /// The returned task will wait for all tasks to be completed. + /// + /// + /// + /// var cts = new CancellationTokenSource(); + /// var task1 = Task.Delay(1000, cts.Token); + /// var task2 = Task.Delay(2000, cts.Token); + /// await TaskUtilities.CancellableWhenAll(cts, task1, task2); + /// + /// + /// Tasks to wait on + /// The cancellation token source that was provided to each task + /// + /// A task that completes when all tasks are completed, with the cancellation or exception state of the first + /// non-successful task + /// + public static async Task CancellableWhenAll(CancellationTokenSource cts, params Task[] tasks) + { + var taskList = tasks.ToList(); + if (taskList.Count == 0) return; + var tcs = new TaskCompletionSource(); + + var tasksWithCancellation = taskList.Select(task => + task.ContinueWith(t => + { + if (t.IsFaulted) + { + cts.Cancel(); + tcs.TrySetException(t.Exception.InnerExceptions.First()); + } + else if (t.IsCanceled) + { + cts.Cancel(); + tcs.TrySetCanceled(); + } + })); + + // Wait for all the task continuations to complete. + try + { + await Task.WhenAll(tasksWithCancellation); + tcs.TrySetResult(); + } + catch + { + // Exception was already propagated. + } + + await tcs.Task; + } +} diff --git a/Tests/Rpc/SpeakerTest.cs b/Tests/Rpc/SpeakerTest.cs index c8ce60e..19f58f2 100644 --- a/Tests/Rpc/SpeakerTest.cs +++ b/Tests/Rpc/SpeakerTest.cs @@ -1,5 +1,6 @@ using System.Buffers; using System.IO.Pipelines; +using System.Reflection; using System.Threading.Channels; using Coder.Desktop.Rpc; using Coder.Desktop.Rpc.Proto; @@ -77,6 +78,95 @@ protected override void Dispose(bool disposing) #endregion +#region FailableStream + +internal class FailableStream : Stream +{ + private readonly Stream _inner; + private readonly TaskCompletionSource _readTcs = new(); + + private readonly TaskCompletionSource _writeTcs = new(); + + public FailableStream(Stream inner, Exception? writeException, Exception? readException) + { + _inner = inner; + if (writeException != null) _writeTcs.SetException(writeException); + if (readException != null) _readTcs.SetException(readException); + } + + public override bool CanRead => _inner.CanRead; + public override bool CanSeek => _inner.CanSeek; + public override bool CanWrite => _inner.CanWrite; + public override long Length => _inner.Length; + + public override long Position + { + get => _inner.Position; + set => _inner.Position = value; + } + + public void SetWriteException(Exception ex) + { + _writeTcs.SetException(ex); + } + + public void SetReadException(Exception ex) + { + _readTcs.SetException(ex); + } + + public override void Flush() + { + _inner.Flush(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + return _inner.Seek(offset, origin); + } + + public override void SetLength(long value) + { + _inner.SetLength(value); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return _inner.ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + private void CheckException(TaskCompletionSource tcs) + { + if (tcs.Task.IsFaulted) throw tcs.Task.Exception.InnerException!; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + CheckException(_readTcs); + var readTask = _inner.ReadAsync(buffer, cancellationToken); + await Task.WhenAny(readTask.AsTask(), _readTcs.Task); + CheckException(_readTcs); + return await readTask; + } + + public override void Write(byte[] buffer, int offset, int count) + { + _inner.WriteAsync(buffer, offset, count).Wait(); + } + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, + CancellationToken cancellationToken = default) + { + CheckException(_writeTcs); + var writeTask = _inner.WriteAsync(buffer, cancellationToken); + await Task.WhenAny(writeTask.AsTask(), _writeTcs.Task); + CheckException(_writeTcs); + await writeTask; + } +} + +#endregion + [TestFixture] public class SpeakerTest { @@ -136,4 +226,153 @@ await message.SendReply(new TunnelMessage var reply = await sendTask; Assert.That(reply.Message.Start.Success, Is.True); } + + [Test(Description = "Encounter a write error during handshake")] + [Timeout(30_000)] + public async Task WriteError() + { + var (stream1, _) = BidirectionalPipe.New(); + var writeEx = new IOException("Test write error"); + var failStream = new FailableStream(stream1, writeEx, null); + + await using var speaker = new Speaker(failStream); + + var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync()); + Assert.That(gotEx, Is.EqualTo(writeEx)); + } + + [Test(Description = "Encounter a read error during handshake")] + [Timeout(30_000)] + public async Task ReadError() + { + var (stream1, _) = BidirectionalPipe.New(); + var readEx = new IOException("Test read error"); + var failStream = new FailableStream(stream1, null, readEx); + + await using var speaker = new Speaker(failStream); + + var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync()); + Assert.That(gotEx, Is.EqualTo(readEx)); + } + + [Test(Description = "Receive a header that exceeds 256 bytes")] + [Timeout(30_000)] + public async Task ReadLargeHeader() + { + var (stream1, stream2) = BidirectionalPipe.New(); + await using var speaker1 = new Speaker(stream1); + + var header = new byte[257]; + for (var i = 0; i < header.Length; i++) header[i] = (byte)'a'; + await stream2.WriteAsync(header); + + var gotEx = Assert.ThrowsAsync(() => speaker1.StartAsync()); + Assert.That(gotEx.Message, Does.Contain("Header malformed or too large")); + } + + [Test(Description = "Encounter a write error during message send")] + [Timeout(30_000)] + public async Task SendMessageWriteError() + { + var (stream1, stream2) = BidirectionalPipe.New(); + var failStream = new FailableStream(stream1, null, null); + + await using var speaker1 = new Speaker(failStream); + speaker1.Receive += msg => Assert.Fail($"speaker1 received message: {msg}"); + speaker1.Error += ex => Assert.Fail($"speaker1 error: {ex}"); + await using var speaker2 = new Speaker(stream2); + speaker2.Receive += msg => Assert.Fail($"speaker2 received message: {msg}"); + speaker2.Error += ex => Assert.Fail($"speaker2 error: {ex}"); + await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); + + var writeEx = new IOException("Test write error"); + failStream.SetWriteException(writeEx); + + var gotEx = Assert.ThrowsAsync(() => speaker1.SendMessage(new ManagerMessage())); + Assert.That(gotEx, Is.EqualTo(writeEx)); + } + + [Test(Description = "Encounter a read error during message receive")] + [Timeout(30_000)] + public async Task ReceiveMessageReadError() + { + var (stream1, stream2) = BidirectionalPipe.New(); + var failStream = new FailableStream(stream1, null, null); + + // Speaker1 is bound to failStream and will write an error to errorCh + var errorCh = Channel.CreateUnbounded(); + await using var speaker1 = new Speaker(failStream); + speaker1.Receive += msg => Assert.Fail($"speaker1 received message: {msg}"); + speaker1.Error += ex => errorCh.Writer.TryWrite(ex); + + // Speaker2 is normal and is only used to perform a handshake + await using var speaker2 = new Speaker(stream2); + speaker2.Receive += msg => Assert.Fail($"speaker2 received message: {msg}"); + speaker2.Error += ex => Assert.Fail($"speaker2 error: {ex}"); + await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); + + // Now the handshake is complete, cause all reads to fail + var readEx = new IOException("Test write error"); + failStream.SetReadException(readEx); + + var gotEx = await errorCh.Reader.ReadAsync(); + Assert.That(gotEx, Is.EqualTo(readEx)); + + // The receive loop should be stopped within a timely fashion. + var receiveLoopTask = (Task?)speaker1.GetType() + .GetField("_receiveTask", BindingFlags.NonPublic | BindingFlags.Instance)!.GetValue(speaker1); + if (receiveLoopTask is null) + { + Assert.Fail("Receive loop task not found"); + } + else + { + var delayTask = Task.Delay(TimeSpan.FromSeconds(5)); + await Task.WhenAny(receiveLoopTask, delayTask); + Assert.That(receiveLoopTask.IsCompleted, Is.True); + } + } + + [Test(Description = "Handle dispose while receive loop is running")] + [Timeout(30_000)] + public async Task DisposeWhileReceiveLoopRunning() + { + var (stream1, stream2) = BidirectionalPipe.New(); + var speaker1 = new Speaker(stream1); + await using var speaker2 = new Speaker(stream2); + await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); + + // Dispose should happen in a timely fashion + var disposeTask = speaker1.DisposeAsync(); + var delayTask = Task.Delay(TimeSpan.FromSeconds(5)); + await Task.WhenAny(disposeTask.AsTask(), delayTask); + Assert.That(disposeTask.IsCompleted, Is.True); + + // Receive loop should be stopped + var receiveLoopTask = (Task?)speaker1.GetType() + .GetField("_receiveTask", BindingFlags.NonPublic | BindingFlags.Instance)!.GetValue(speaker1); + if (receiveLoopTask is null) + Assert.Fail("Receive loop task not found"); + else + Assert.That(receiveLoopTask.IsCompleted, Is.True); + } + + [Test(Description = "Handle dispose while a message is awaiting a reply")] + [Timeout(30_000)] + public async Task DisposeWhileAwaitingReply() + { + var (stream1, stream2) = BidirectionalPipe.New(); + var speaker1 = new Speaker(stream1); + await using var speaker2 = new Speaker(stream2); + await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); + + // Send a message from speaker1 to speaker2 + var sendTask = speaker1.SendMessageAwaitReply(new ManagerMessage()); + + // Dispose speaker1 + await speaker1.DisposeAsync(); + + // The send task should complete with an exception + Assert.ThrowsAsync(() => sendTask.AsTask()); + } } From 6d4ba04f3404aae093f2a92f7445d73870f74569 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Wed, 27 Nov 2024 16:47:33 +0900 Subject: [PATCH 6/8] Fix Rpc field management to match Go --- Coder.Desktop.sln.DotSettings | 1 + Rpc.Proto/RpcMessage.cs | 22 +++++++++++++++++--- Rpc/Serdes.cs | 13 +++++++++--- Rpc/Speaker.cs | 39 +++++++++++++++++++++-------------- Tests/Rpc/SerdesTest.cs | 15 ++++---------- Tests/Rpc/SpeakerTest.cs | 39 +++++++++++++++++++++++------------ 6 files changed, 84 insertions(+), 45 deletions(-) diff --git a/Coder.Desktop.sln.DotSettings b/Coder.Desktop.sln.DotSettings index 3859343..636b95d 100644 --- a/Coder.Desktop.sln.DotSettings +++ b/Coder.Desktop.sln.DotSettings @@ -252,4 +252,5 @@ </TypePattern> </Patterns> + True True \ No newline at end of file diff --git a/Rpc.Proto/RpcMessage.cs b/Rpc.Proto/RpcMessage.cs index 6662a25..035e87c 100644 --- a/Rpc.Proto/RpcMessage.cs +++ b/Rpc.Proto/RpcMessage.cs @@ -19,7 +19,7 @@ public abstract class RpcMessage where T : IMessage /// The inner RPC component of the message. This is a separate field as the C# compiler does not allow the existing Rpc /// field to be overridden or implement this abstract property. /// - public abstract RPC RpcField { get; set; } + public abstract RPC? RpcField { get; set; } /// /// The inner message component of the message. This exists so values of type RpcMessage can easily get message @@ -27,6 +27,12 @@ public abstract class RpcMessage where T : IMessage /// public abstract T Message { get; } + /// + /// Check if the message is valid. Checks for empty oneof of fields. + /// + /// Invalid message + public abstract void Validate(); + /// /// Gets the RpcRole of the message type from it's RpcRole attribute. /// @@ -44,23 +50,33 @@ public static RpcRole GetRole() [RpcRole(RpcRole.Manager)] public partial class ManagerMessage : RpcMessage { - public override RPC RpcField + public override RPC? RpcField { get => Rpc; set => Rpc = value; } public override ManagerMessage Message => this; + + public override void Validate() + { + if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type"); + } } [RpcRole(RpcRole.Tunnel)] public partial class TunnelMessage : RpcMessage { - public override RPC RpcField + public override RPC? RpcField { get => Rpc; set => Rpc = value; } public override TunnelMessage Message => this; + + public override void Validate() + { + if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type"); + } } diff --git a/Rpc/Serdes.cs b/Rpc/Serdes.cs index c7f5632..cf19655 100644 --- a/Rpc/Serdes.cs +++ b/Rpc/Serdes.cs @@ -47,12 +47,15 @@ public class Serdes /// Stream to write the encoded message to /// Message to encode and write /// Optional cancellation token - /// If the message exceeds the maximum message size of 16 MiB + /// If the message is invalid public async Task WriteMessage(Stream conn, TS message, CancellationToken ct = default) { + message.Validate(); // throws ArgumentException if invalid using var _ = await _writeLock.LockAsync(ct); var mb = message.ToByteArray(); + if (mb.Length == 0) + throw new ArgumentException("Marshalled message is empty"); if (mb.Length > MaxMessageSize) throw new ArgumentException($"Marshalled message size {mb.Length} exceeds maximum {MaxMessageSize}"); @@ -69,6 +72,7 @@ public async Task WriteMessage(Stream conn, TS message, CancellationToken ct = d /// Optional cancellation token /// Decoded message /// Could not decode the message + /// The message is invalid public async Task ReadMessage(Stream conn, CancellationToken ct = default) { using var _ = await _readLock.LockAsync(ct); @@ -76,6 +80,8 @@ public async Task ReadMessage(Stream conn, CancellationToken ct = default) var lenBytes = new byte[sizeof(uint)]; await conn.ReadExactlyAsync(lenBytes, ct); var len = BinaryPrimitives.ReadUInt32BigEndian(lenBytes); + if (len == 0) + throw new IOException("Received message size 0"); if (len > MaxMessageSize) throw new IOException($"Received message size {len} exceeds maximum {MaxMessageSize}"); @@ -85,8 +91,9 @@ public async Task ReadMessage(Stream conn, CancellationToken ct = default) try { var msg = _parser.ParseFrom(msgBytes); - if (msg?.RpcField is null) - throw new IOException("Parsed message is empty or invalid"); + if (msg is null) + throw new IOException("Parsed message is null"); + msg.Validate(); // throws ArgumentException if invalid return msg; } catch (Exception e) diff --git a/Rpc/Speaker.cs b/Rpc/Speaker.cs index a6cc3bd..73297d7 100644 --- a/Rpc/Speaker.cs +++ b/Rpc/Speaker.cs @@ -15,7 +15,7 @@ public class ReplyableRpcMessage(Speaker speaker, TR message) : where TS : RpcMessage, IMessage where TR : RpcMessage, IMessage, new() { - public override RPC RpcField + public override RPC? RpcField { get => message.RpcField; set => message.RpcField = value; @@ -23,6 +23,11 @@ public override RPC RpcField public override TR Message => message; + public override void Validate() + { + message.Validate(); + } + /// /// Sends a reply to the original message. /// @@ -55,9 +60,9 @@ public class Speaker : IAsyncDisposable private readonly ConcurrentDictionary> _pendingReplies = new(); private readonly Serdes _serdes = new(); - // _lastMessageId is incremented using an atomic operation, and as such the - // first message ID will actually be 1. - private ulong _lastMessageId; + // _lastRequestId is incremented using an atomic operation, and as such the + // first request ID will actually be 1. + private ulong _lastRequestId; private Task? _receiveTask; /// @@ -156,13 +161,17 @@ private async Task ReceiveLoop(CancellationToken ct = default) while (!ct.IsCancellationRequested) { var message = await _serdes.ReadMessage(_conn, ct); - if (message.RpcField.ResponseTo != 0) + if (message is { RpcField.ResponseTo : not 0 }) + { // Look up the TaskCompletionSource for the message ID and // complete it with the message. if (_pendingReplies.TryRemove(message.RpcField.ResponseTo, out var tcs)) tcs.SetResult(message); + else + // TODO: we should log unknown replies + continue; + } - // TODO: we should log unknown replies // Start a new task in the background to handle the message. _ = Task.Run(() => Receive?.Invoke(new ReplyableRpcMessage(this, message)), ct); } @@ -178,18 +187,14 @@ private async Task ReceiveLoop(CancellationToken ct = default) } /// - /// Send a message without waiting for a reply. If a reply is received it will be handled by the callback. + /// Send a message that does not expect a reply. /// /// Message to send /// Optional cancellation token public async Task SendMessage(TS message, CancellationToken ct = default) { using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); - message.RpcField = new RPC - { - MsgId = Interlocked.Add(ref _lastMessageId, 1), - ResponseTo = 0, - }; + message.RpcField = null; await _serdes.WriteMessage(_conn, message, cts.Token); } @@ -200,12 +205,12 @@ public async Task SendMessage(TS message, CancellationToken ct = default) /// Message to send - the Rpc field will be overwritten /// Optional cancellation token /// Received reply - public async ValueTask SendMessageAwaitReply(TS message, CancellationToken ct = default) + public async ValueTask SendRequestAwaitReply(TS message, CancellationToken ct = default) { using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); message.RpcField = new RPC { - MsgId = Interlocked.Add(ref _lastMessageId, 1), + MsgId = Interlocked.Add(ref _lastRequestId, 1), ResponseTo = 0, }; @@ -233,12 +238,16 @@ public async ValueTask SendMessageAwaitReply(TS message, CancellationToken c /// Message to reply to - the Rpc field will be overwritten /// Reply message /// Optional cancellation token + /// The original message is not a request and cannot be replied to public async Task SendReply(TR originalMessage, TS reply, CancellationToken ct = default) { + if (originalMessage.RpcField == null || originalMessage.RpcField.MsgId == 0) + throw new ArgumentException("Original message is not a request"); + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); reply.RpcField = new RPC { - MsgId = Interlocked.Add(ref _lastMessageId, 1), + MsgId = 0, ResponseTo = originalMessage.RpcField.MsgId, }; await _serdes.WriteMessage(_conn, reply, cts.Token); diff --git a/Tests/Rpc/SerdesTest.cs b/Tests/Rpc/SerdesTest.cs index ce7cfce..df34cfa 100644 --- a/Tests/Rpc/SerdesTest.cs +++ b/Tests/Rpc/SerdesTest.cs @@ -1,6 +1,7 @@ using System.Buffers.Binary; using Coder.Desktop.Rpc; using Coder.Desktop.Rpc.Proto; +using Google.Protobuf; namespace Coder.Desktop.Tests.Rpc; @@ -16,10 +17,7 @@ public async Task WriteReadMessage() var msg = new ManagerMessage { - Rpc = new RPC - { - MsgId = 1, - }, + Start = new StartRequest(), }; await serdes.WriteMessage(stream1, msg); var got = await serdes.ReadMessage(stream2); @@ -35,10 +33,6 @@ public void WriteMessageTooLarge() var msg = new ManagerMessage { - Rpc = new RPC - { - MsgId = 1, - }, Start = new StartRequest { ApiToken = new string('a', 0x1000001), @@ -75,8 +69,7 @@ public async Task ReadEmptyMessage() BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 0); await stream1.WriteAsync(lenBytes); var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2)); - Assert.That(ex.InnerException, Is.Not.Null); - Assert.That(ex.InnerException?.Message, Does.Contain("Parsed message is empty or invalid")); + Assert.That(ex.Message, Does.Contain("Received message size 0")); } [Test(Description = "Read an invalid/corrupt message from the stream")] @@ -91,6 +84,6 @@ public async Task ReadInvalidMessage() await stream1.WriteAsync(lenBytes); await stream1.WriteAsync(new byte[1]); var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2)); - Assert.That(ex.Message, Does.Not.Contain("Parsed message is empty or invalid")); + Assert.That(ex.InnerException, Is.TypeOf(typeof(InvalidProtocolBufferException))); } } diff --git a/Tests/Rpc/SpeakerTest.cs b/Tests/Rpc/SpeakerTest.cs index 19f58f2..7b26546 100644 --- a/Tests/Rpc/SpeakerTest.cs +++ b/Tests/Rpc/SpeakerTest.cs @@ -179,28 +179,29 @@ public async Task SendReceiveReplyReceive() await using var speaker1 = new Speaker(stream1); var speaker1Ch = Channel .CreateUnbounded>(); - speaker1.Receive += msg => - { - Console.WriteLine($"speaker1 received message: {msg.RpcField.MsgId}"); - Assert.That(speaker1Ch.Writer.TryWrite(msg), Is.True); - }; + speaker1.Receive += msg => { Assert.That(speaker1Ch.Writer.TryWrite(msg), Is.True); }; speaker1.Error += ex => { Assert.Fail($"speaker1 error: {ex}"); }; await using var speaker2 = new Speaker(stream2); var speaker2Ch = Channel .CreateUnbounded>(); - speaker2.Receive += msg => - { - Console.WriteLine($"speaker2 received message: {msg.RpcField.MsgId}"); - Assert.That(speaker2Ch.Writer.TryWrite(msg), Is.True); - }; + speaker2.Receive += msg => { Assert.That(speaker2Ch.Writer.TryWrite(msg), Is.True); }; speaker2.Error += ex => { Assert.Fail($"speaker2 error: {ex}"); }; // Start both speakers simultaneously Task.WaitAll(speaker1.StartAsync(), speaker2.StartAsync()); + // Send a normal message from speaker2 to speaker1 + await speaker2.SendMessage(new TunnelMessage + { + PeerUpdate = new PeerUpdate(), + }); + var receivedMessage = await speaker1Ch.Reader.ReadAsync(); + Assert.That(receivedMessage.RpcField, Is.Null); // not a request + Assert.That(receivedMessage.Message.PeerUpdate, Is.Not.Null); + // Send a message from speaker1 to speaker2 in the background - var sendTask = speaker1.SendMessageAwaitReply(new ManagerMessage + var sendTask = speaker1.SendRequestAwaitReply(new ManagerMessage { Start = new StartRequest { @@ -211,6 +212,9 @@ public async Task SendReceiveReplyReceive() // Receive the message in speaker2 var message = await speaker2Ch.Reader.ReadAsync(); + Assert.That(message.RpcField, Is.Not.Null); + Assert.That(message.RpcField!.MsgId, Is.Not.EqualTo(0)); + Assert.That(message.RpcField!.ResponseTo, Is.EqualTo(0)); Assert.That(message.Message.Start.ApiToken, Is.EqualTo("test")); // Send a reply back to speaker1 @@ -224,6 +228,9 @@ await message.SendReply(new TunnelMessage // Receive the reply in speaker1 by awaiting sendTask var reply = await sendTask; + Assert.That(message.RpcField, Is.Not.Null); + Assert.That(reply.RpcField!.MsgId, Is.EqualTo(0)); + Assert.That(reply.RpcField!.ResponseTo, Is.EqualTo(message.RpcField!.MsgId)); Assert.That(reply.Message.Start.Success, Is.True); } @@ -288,7 +295,10 @@ public async Task SendMessageWriteError() var writeEx = new IOException("Test write error"); failStream.SetWriteException(writeEx); - var gotEx = Assert.ThrowsAsync(() => speaker1.SendMessage(new ManagerMessage())); + var gotEx = Assert.ThrowsAsync(() => speaker1.SendMessage(new ManagerMessage + { + Start = new StartRequest(), + })); Assert.That(gotEx, Is.EqualTo(writeEx)); } @@ -367,7 +377,10 @@ public async Task DisposeWhileAwaitingReply() await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); // Send a message from speaker1 to speaker2 - var sendTask = speaker1.SendMessageAwaitReply(new ManagerMessage()); + var sendTask = speaker1.SendRequestAwaitReply(new ManagerMessage + { + Start = new StartRequest(), + }); // Dispose speaker1 await speaker1.DisposeAsync(); From 7a00c76f61236c9340089c0136b4bd95d1146320 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Wed, 27 Nov 2024 17:11:13 +0900 Subject: [PATCH 7/8] Rename Coder.Desktop.Rpc to Coder.Desktop.Vpn --- Coder.Desktop.sln | 4 ++-- Tests/Tests.csproj | 2 +- Tests/{Rpc.Proto => Vpn.Proto}/ApiVersionTest.cs | 4 ++-- Tests/{Rpc.Proto => Vpn.Proto}/RpcHeaderTest.cs | 4 ++-- Tests/{Rpc.Proto => Vpn.Proto}/RpcMessageTest.cs | 6 +++--- Tests/{Rpc.Proto => Vpn.Proto}/RpcRoleTest.cs | 4 ++-- Tests/{Rpc => Vpn}/SerdesTest.cs | 6 +++--- Tests/{Rpc => Vpn}/SpeakerTest.cs | 6 +++--- {Rpc.Proto => Vpn.Proto}/ApiVersion.cs | 2 +- {Rpc.Proto => Vpn.Proto}/RpcHeader.cs | 2 +- {Rpc.Proto => Vpn.Proto}/RpcMessage.cs | 2 +- {Rpc.Proto => Vpn.Proto}/RpcRole.cs | 2 +- Rpc.Proto/Rpc.Proto.csproj => Vpn.Proto/Vpn.Proto.csproj | 2 +- {Rpc.Proto => Vpn.Proto}/vpn.proto | 2 +- {Rpc => Vpn}/Serdes.cs | 8 ++++---- {Rpc => Vpn}/Speaker.cs | 6 +++--- {Rpc => Vpn}/Utilities/TaskUtilities.cs | 2 +- Rpc/Rpc.csproj => Vpn/Vpn.csproj | 4 ++-- 18 files changed, 34 insertions(+), 34 deletions(-) rename Tests/{Rpc.Proto => Vpn.Proto}/ApiVersionTest.cs (95%) rename Tests/{Rpc.Proto => Vpn.Proto}/RpcHeaderTest.cs (96%) rename Tests/{Rpc.Proto => Vpn.Proto}/RpcMessageTest.cs (88%) rename Tests/{Rpc.Proto => Vpn.Proto}/RpcRoleTest.cs (88%) rename Tests/{Rpc => Vpn}/SerdesTest.cs (97%) rename Tests/{Rpc => Vpn}/SpeakerTest.cs (99%) rename {Rpc.Proto => Vpn.Proto}/ApiVersion.cs (98%) rename {Rpc.Proto => Vpn.Proto}/RpcHeader.cs (97%) rename {Rpc.Proto => Vpn.Proto}/RpcMessage.cs (98%) rename {Rpc.Proto => Vpn.Proto}/RpcRole.cs (97%) rename Rpc.Proto/Rpc.Proto.csproj => Vpn.Proto/Vpn.Proto.csproj (91%) rename {Rpc.Proto => Vpn.Proto}/vpn.proto (99%) rename {Rpc => Vpn}/Serdes.cs (96%) rename {Rpc => Vpn}/Speaker.cs (99%) rename {Rpc => Vpn}/Utilities/TaskUtilities.cs (98%) rename Rpc/Rpc.csproj => Vpn/Vpn.csproj (68%) diff --git a/Coder.Desktop.sln b/Coder.Desktop.sln index 9f43629..342963b 100644 --- a/Coder.Desktop.sln +++ b/Coder.Desktop.sln @@ -1,9 +1,9 @@  Microsoft Visual Studio Solution File, Format Version 12.00 # -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Rpc", "Rpc\Rpc.csproj", "{B342F896-C721-4AA5-A0F6-0BFA8EBAFACB}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn", "Vpn\Vpn.csproj", "{B342F896-C721-4AA5-A0F6-0BFA8EBAFACB}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Rpc.Proto", "Rpc.Proto\Rpc.Proto.csproj", "{318E78BB-E6AD-410F-8F3F-B680F6880293}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn.Proto", "Vpn.Proto\Vpn.Proto.csproj", "{318E78BB-E6AD-410F-8F3F-B680F6880293}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests", "Tests\Tests.csproj", "{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}" EndProject diff --git a/Tests/Tests.csproj b/Tests/Tests.csproj index 1008942..cccd5dc 100644 --- a/Tests/Tests.csproj +++ b/Tests/Tests.csproj @@ -24,7 +24,7 @@ - + diff --git a/Tests/Rpc.Proto/ApiVersionTest.cs b/Tests/Vpn.Proto/ApiVersionTest.cs similarity index 95% rename from Tests/Rpc.Proto/ApiVersionTest.cs rename to Tests/Vpn.Proto/ApiVersionTest.cs index bf33f27..3536bd2 100644 --- a/Tests/Rpc.Proto/ApiVersionTest.cs +++ b/Tests/Vpn.Proto/ApiVersionTest.cs @@ -1,6 +1,6 @@ -using Coder.Desktop.Rpc.Proto; +using Coder.Desktop.Vpn.Proto; -namespace Coder.Desktop.Tests.Rpc.Proto; +namespace Coder.Desktop.Tests.Vpn.Proto; [TestFixture] public class ApiVersionTest diff --git a/Tests/Rpc.Proto/RpcHeaderTest.cs b/Tests/Vpn.Proto/RpcHeaderTest.cs similarity index 96% rename from Tests/Rpc.Proto/RpcHeaderTest.cs rename to Tests/Vpn.Proto/RpcHeaderTest.cs index 69dd627..17c8636 100644 --- a/Tests/Rpc.Proto/RpcHeaderTest.cs +++ b/Tests/Vpn.Proto/RpcHeaderTest.cs @@ -1,7 +1,7 @@ using System.Text; -using Coder.Desktop.Rpc.Proto; +using Coder.Desktop.Vpn.Proto; -namespace Coder.Desktop.Tests.Rpc.Proto; +namespace Coder.Desktop.Tests.Vpn.Proto; [TestFixture] public class RpcHeaderTest diff --git a/Tests/Rpc.Proto/RpcMessageTest.cs b/Tests/Vpn.Proto/RpcMessageTest.cs similarity index 88% rename from Tests/Rpc.Proto/RpcMessageTest.cs rename to Tests/Vpn.Proto/RpcMessageTest.cs index 9f9c73f..36de12d 100644 --- a/Tests/Rpc.Proto/RpcMessageTest.cs +++ b/Tests/Vpn.Proto/RpcMessageTest.cs @@ -1,6 +1,6 @@ -using Coder.Desktop.Rpc.Proto; +using Coder.Desktop.Vpn.Proto; -namespace Coder.Desktop.Tests.Rpc.Proto; +namespace Coder.Desktop.Tests.Vpn.Proto; [TestFixture] public class RpcRoleAttributeTest @@ -31,7 +31,7 @@ public void GetRole() // RpcRoleAttribute var ex = Assert.Throws(() => _ = RpcMessage.GetRole()); Assert.That(ex.Message, - Does.Contain("Message type 'Coder.Desktop.Rpc.Proto.RPC' does not have a RpcRoleAttribute")); + Does.Contain("Message type 'Coder.Desktop.Vpn.Proto.RPC' does not have a RpcRoleAttribute")); Assert.That(ManagerMessage.GetRole().ToString(), Is.EqualTo(RpcRole.Manager)); Assert.That(TunnelMessage.GetRole().ToString(), Is.EqualTo(RpcRole.Tunnel)); diff --git a/Tests/Rpc.Proto/RpcRoleTest.cs b/Tests/Vpn.Proto/RpcRoleTest.cs similarity index 88% rename from Tests/Rpc.Proto/RpcRoleTest.cs rename to Tests/Vpn.Proto/RpcRoleTest.cs index 59ad489..f39d5cb 100644 --- a/Tests/Rpc.Proto/RpcRoleTest.cs +++ b/Tests/Vpn.Proto/RpcRoleTest.cs @@ -1,6 +1,6 @@ -using Coder.Desktop.Rpc.Proto; +using Coder.Desktop.Vpn.Proto; -namespace Coder.Desktop.Tests.Rpc.Proto; +namespace Coder.Desktop.Tests.Vpn.Proto; [TestFixture] public class RpcRoleTest diff --git a/Tests/Rpc/SerdesTest.cs b/Tests/Vpn/SerdesTest.cs similarity index 97% rename from Tests/Rpc/SerdesTest.cs rename to Tests/Vpn/SerdesTest.cs index df34cfa..7673d6a 100644 --- a/Tests/Rpc/SerdesTest.cs +++ b/Tests/Vpn/SerdesTest.cs @@ -1,9 +1,9 @@ using System.Buffers.Binary; -using Coder.Desktop.Rpc; -using Coder.Desktop.Rpc.Proto; +using Coder.Desktop.Vpn; +using Coder.Desktop.Vpn.Proto; using Google.Protobuf; -namespace Coder.Desktop.Tests.Rpc; +namespace Coder.Desktop.Tests.Vpn; [TestFixture] public class SerdesTest diff --git a/Tests/Rpc/SpeakerTest.cs b/Tests/Vpn/SpeakerTest.cs similarity index 99% rename from Tests/Rpc/SpeakerTest.cs rename to Tests/Vpn/SpeakerTest.cs index 7b26546..3eeebb3 100644 --- a/Tests/Rpc/SpeakerTest.cs +++ b/Tests/Vpn/SpeakerTest.cs @@ -2,10 +2,10 @@ using System.IO.Pipelines; using System.Reflection; using System.Threading.Channels; -using Coder.Desktop.Rpc; -using Coder.Desktop.Rpc.Proto; +using Coder.Desktop.Vpn; +using Coder.Desktop.Vpn.Proto; -namespace Coder.Desktop.Tests.Rpc; +namespace Coder.Desktop.Tests.Vpn; #region BidrectionalPipe diff --git a/Rpc.Proto/ApiVersion.cs b/Vpn.Proto/ApiVersion.cs similarity index 98% rename from Rpc.Proto/ApiVersion.cs rename to Vpn.Proto/ApiVersion.cs index fcadd64..25d96f9 100644 --- a/Rpc.Proto/ApiVersion.cs +++ b/Vpn.Proto/ApiVersion.cs @@ -1,4 +1,4 @@ -namespace Coder.Desktop.Rpc.Proto; +namespace Coder.Desktop.Vpn.Proto; /// /// Thrown when the two peers are incompatible with each other. diff --git a/Rpc.Proto/RpcHeader.cs b/Vpn.Proto/RpcHeader.cs similarity index 97% rename from Rpc.Proto/RpcHeader.cs rename to Vpn.Proto/RpcHeader.cs index 9e3bce5..0aa63ae 100644 --- a/Rpc.Proto/RpcHeader.cs +++ b/Vpn.Proto/RpcHeader.cs @@ -1,6 +1,6 @@ using System.Text; -namespace Coder.Desktop.Rpc.Proto; +namespace Coder.Desktop.Vpn.Proto; /// /// A header to write or read from a stream to identify the speaker's role and version. diff --git a/Rpc.Proto/RpcMessage.cs b/Vpn.Proto/RpcMessage.cs similarity index 98% rename from Rpc.Proto/RpcMessage.cs rename to Vpn.Proto/RpcMessage.cs index 035e87c..c44168c 100644 --- a/Rpc.Proto/RpcMessage.cs +++ b/Vpn.Proto/RpcMessage.cs @@ -1,7 +1,7 @@ using System.Reflection; using Google.Protobuf; -namespace Coder.Desktop.Rpc.Proto; +namespace Coder.Desktop.Vpn.Proto; [AttributeUsage(AttributeTargets.Class, Inherited = false)] public class RpcRoleAttribute(string role) : Attribute diff --git a/Rpc.Proto/RpcRole.cs b/Vpn.Proto/RpcRole.cs similarity index 97% rename from Rpc.Proto/RpcRole.cs rename to Vpn.Proto/RpcRole.cs index 275da24..9190281 100644 --- a/Rpc.Proto/RpcRole.cs +++ b/Vpn.Proto/RpcRole.cs @@ -1,4 +1,4 @@ -namespace Coder.Desktop.Rpc.Proto; +namespace Coder.Desktop.Vpn.Proto; /// /// Represents a role that either side of the connection can fulfil. diff --git a/Rpc.Proto/Rpc.Proto.csproj b/Vpn.Proto/Vpn.Proto.csproj similarity index 91% rename from Rpc.Proto/Rpc.Proto.csproj rename to Vpn.Proto/Vpn.Proto.csproj index 2d9d4c6..5380bd4 100644 --- a/Rpc.Proto/Rpc.Proto.csproj +++ b/Vpn.Proto/Vpn.Proto.csproj @@ -1,7 +1,7 @@  - Coder.Desktop.Rpc.Proto + Coder.Desktop.Vpn.Proto net8.0 enable enable diff --git a/Rpc.Proto/vpn.proto b/Vpn.Proto/vpn.proto similarity index 99% rename from Rpc.Proto/vpn.proto rename to Vpn.Proto/vpn.proto index dda973d..33a3ff4 100644 --- a/Rpc.Proto/vpn.proto +++ b/Vpn.Proto/vpn.proto @@ -1,6 +1,6 @@ syntax = "proto3"; option go_package = "github.com/coder/coder/v2/vpn"; -option csharp_namespace = "Coder.Desktop.Rpc.Proto"; +option csharp_namespace = "Coder.Desktop.Vpn.Proto"; import "google/protobuf/timestamp.proto"; diff --git a/Rpc/Serdes.cs b/Vpn/Serdes.cs similarity index 96% rename from Rpc/Serdes.cs rename to Vpn/Serdes.cs index cf19655..317417b 100644 --- a/Rpc/Serdes.cs +++ b/Vpn/Serdes.cs @@ -1,8 +1,8 @@ using System.Buffers.Binary; -using Coder.Desktop.Rpc.Proto; +using Coder.Desktop.Vpn.Proto; using Google.Protobuf; -namespace Coder.Desktop.Rpc; +namespace Coder.Desktop.Vpn; /// /// RaiiSemaphoreSlim is a wrapper around SemaphoreSlim that provides RAII-style locking. @@ -54,12 +54,12 @@ public async Task WriteMessage(Stream conn, TS message, CancellationToken ct = d using var _ = await _writeLock.LockAsync(ct); var mb = message.ToByteArray(); - if (mb.Length == 0) + if (mb == null || mb.Length == 0) throw new ArgumentException("Marshalled message is empty"); if (mb.Length > MaxMessageSize) throw new ArgumentException($"Marshalled message size {mb.Length} exceeds maximum {MaxMessageSize}"); - var lenBytes = new byte[4]; + var lenBytes = new byte[sizeof(uint)]; BinaryPrimitives.WriteUInt32BigEndian(lenBytes, (uint)mb.Length); await conn.WriteAsync(lenBytes, ct); await conn.WriteAsync(mb, ct); diff --git a/Rpc/Speaker.cs b/Vpn/Speaker.cs similarity index 99% rename from Rpc/Speaker.cs rename to Vpn/Speaker.cs index 73297d7..0bf57eb 100644 --- a/Rpc/Speaker.cs +++ b/Vpn/Speaker.cs @@ -1,10 +1,10 @@ using System.Collections.Concurrent; using System.Text; -using Coder.Desktop.Rpc.Proto; -using Coder.Desktop.Rpc.Utilities; +using Coder.Desktop.Vpn.Proto; +using Coder.Desktop.Vpn.Utilities; using Google.Protobuf; -namespace Coder.Desktop.Rpc; +namespace Coder.Desktop.Vpn; /// /// Wraps a RpcMessage to allow easily sending a reply via the Speaker. diff --git a/Rpc/Utilities/TaskUtilities.cs b/Vpn/Utilities/TaskUtilities.cs similarity index 98% rename from Rpc/Utilities/TaskUtilities.cs rename to Vpn/Utilities/TaskUtilities.cs index 65f382d..8a2bfdb 100644 --- a/Rpc/Utilities/TaskUtilities.cs +++ b/Vpn/Utilities/TaskUtilities.cs @@ -1,4 +1,4 @@ -namespace Coder.Desktop.Rpc.Utilities; +namespace Coder.Desktop.Vpn.Utilities; internal static class TaskUtilities { diff --git a/Rpc/Rpc.csproj b/Vpn/Vpn.csproj similarity index 68% rename from Rpc/Rpc.csproj rename to Vpn/Vpn.csproj index 135f605..bcef1b5 100644 --- a/Rpc/Rpc.csproj +++ b/Vpn/Vpn.csproj @@ -1,14 +1,14 @@  - Coder.Desktop.Rpc + Coder.Desktop.Vpn net8.0 enable enable - + From 0e65c36dce07bffe68478f20b89eb0daee756f71 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Wed, 27 Nov 2024 17:14:32 +0900 Subject: [PATCH 8/8] fixup! Rename Coder.Desktop.Rpc to Coder.Desktop.Vpn --- Vpn/Speaker.cs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/Vpn/Speaker.cs b/Vpn/Speaker.cs index 0bf57eb..030f908 100644 --- a/Vpn/Speaker.cs +++ b/Vpn/Speaker.cs @@ -167,9 +167,8 @@ private async Task ReceiveLoop(CancellationToken ct = default) // complete it with the message. if (_pendingReplies.TryRemove(message.RpcField.ResponseTo, out var tcs)) tcs.SetResult(message); - else - // TODO: we should log unknown replies - continue; + // TODO: we should log unknown replies + continue; } // Start a new task in the background to handle the message.