diff --git a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.BIO.cs b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.BIO.cs index ffa82ef001f8..194cecbbb0c7 100644 --- a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.BIO.cs +++ b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.BIO.cs @@ -35,5 +35,27 @@ internal static partial class Crypto [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioCtrlPending")] internal static extern int BioCtrlPending(SafeBioHandle bio); + + [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioSetAppData")] + internal static extern void BioSetAppData(SafeBioHandle bio, IntPtr data); + + [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioSetWriteFlag")] + internal static extern void BioSetWriteFlag(SafeBioHandle bio); + + [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioSetShoudRetryReadFlag")] + internal static extern void BioSetShoudRetryReadFlag(SafeBioHandle bio); + + //These need to be here and private to ensure the static constructor is run to init the bio on the class + [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_CreateManagedSslBio")] + private static extern SafeBioHandle CreateManagedSslBio(); + + [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_InitManagedSslBioMethod")] + private static extern void InitManagedSslBioMethod(WriteDelegate bwrite, ReadDelegate bread); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + private unsafe delegate int ReadDelegate(IntPtr bio, void* buf, int size, IntPtr data); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + private unsafe delegate int WriteDelegate(IntPtr bio, void* buf, int num, IntPtr data); } } diff --git a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.ManagedSslBio.cs b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.ManagedSslBio.cs new file mode 100644 index 000000000000..011f402b6c2d --- /dev/null +++ b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.ManagedSslBio.cs @@ -0,0 +1,57 @@ +using System; +using System.Runtime.InteropServices; +using Microsoft.Win32.SafeHandles; +using System.Diagnostics; + +internal static partial class Interop +{ + internal static partial class Crypto + { + internal static class ManagedSslBio + { + private unsafe readonly static ReadDelegate s_readDelegate; + private unsafe readonly static WriteDelegate s_writeDelegate; + + internal static SafeBioHandle CreateManagedSslBio() => Crypto.CreateManagedSslBio(); + + unsafe static ManagedSslBio() + { + s_writeDelegate = Write; + s_readDelegate = Read; + Crypto.InitManagedSslBioMethod(s_writeDelegate, s_readDelegate); + } + + internal static void BioSetGCHandle(SafeBioHandle bio, GCHandle handle) + { + IntPtr pointer = handle.IsAllocated ? GCHandle.ToIntPtr(handle) : IntPtr.Zero; + Crypto.BioSetAppData(bio, pointer); + } + + private static unsafe int Write(IntPtr bio, void* input, int size, IntPtr data) + { + GCHandle handle = GCHandle.FromIntPtr(data); + Debug.Assert(handle.IsAllocated); + + if (handle.Target is SafeSslHandle.WriteBioBuffer buffer) + { + return buffer.Write(new Span(input, size)); + } + + return -1; + } + + private static unsafe int Read(IntPtr bio, void* output, int size, IntPtr data) + { + GCHandle handle = GCHandle.FromIntPtr(data); + Debug.Assert(handle.IsAllocated); + + if (handle.Target is SafeSslHandle.ReadBioBuffer buffer) + { + return buffer.Read(new Span(output, size)); + } + + return -1; + } + } + } +} diff --git a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs index 5ac8a42b5977..8916ed8596bd 100644 --- a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs +++ b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs @@ -73,6 +73,10 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50 // https://www.openssl.org/docs/manmaster/ssl/SSL_shutdown.html Ssl.SslCtxSetQuietShutdown(innerContext); + // This allows the write buffer to move during a multi call write, this stops us having to pin it + // across multiple calls where there is an async output to the innerstream inbetween + Ssl.SslCtxSetAcceptMovingWriteBuffer(innerContext); + if (!Ssl.SetEncryptionPolicy(innerContext, policy)) { throw new PlatformNotSupportedException(SR.Format(SR.net_ssl_encryptionpolicy_notsupported, policy)); @@ -93,7 +97,7 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50 Ssl.SslCtxSetVerify(innerContext, s_verifyClientCertificate); - //update the client CA list + // update the client CA list UpdateCAListFromRootStore(innerContext); } @@ -135,17 +139,21 @@ internal static bool DoSslHandshake(SafeSslHandle context, byte[] recvBuf, int r { sendBuf = null; sendCount = 0; - if ((recvBuf != null) && (recvCount > 0)) + + sendCount = context.OutputBio.TakeBytes(out sendBuf); + if (recvBuf == null && sendCount > 0) { - BioWrite(context.InputBio, recvBuf, recvOffset, recvCount); + return false; } + context.InputBio.SetData(recvBuf, recvOffset, recvCount); + context.OutputBio.SetData(buffer: null, isHandshake: true); + int retVal = Ssl.SslDoHandshake(context); if (retVal != 1) { - Exception innerError; - Ssl.SslErrorCode error = GetSslError(context, retVal, out innerError); + Ssl.SslErrorCode error = GetSslError(context, retVal, out Exception innerError); if ((retVal != -1) || (error != Ssl.SslErrorCode.SSL_ERROR_WANT_READ)) { @@ -153,31 +161,14 @@ internal static bool DoSslHandshake(SafeSslHandle context, byte[] recvBuf, int r } } - sendCount = Crypto.BioCtrlPending(context.OutputBio); - - if (sendCount > 0) - { - sendBuf = new byte[sendCount]; - - try - { - sendCount = BioRead(context.OutputBio, sendBuf, sendCount); - } - finally - { - if (sendCount <= 0) - { - sendBuf = null; - sendCount = 0; - } - } - } + sendCount = context.OutputBio.TakeBytes(out sendBuf); bool stateOk = Ssl.IsSslStateOK(context); if (stateOk) { context.MarkHandshakeCompleted(); } + return stateOk; } @@ -190,6 +181,7 @@ internal static int Encrypt(SafeSslHandle context, byte[] input, int offset, int Debug.Assert(input.Length - offset >= count); errorCode = Ssl.SslErrorCode.SSL_ERROR_NONE; + context.OutputBio.SetData(output, isHandshake: false); int retVal; unsafe @@ -202,8 +194,7 @@ internal static int Encrypt(SafeSslHandle context, byte[] input, int offset, int if (retVal != count) { - Exception innerError; - errorCode = GetSslError(context, retVal, out innerError); + errorCode = GetSslError(context, retVal, out Exception innerError); retVal = 0; switch (errorCode) @@ -212,52 +203,46 @@ internal static int Encrypt(SafeSslHandle context, byte[] input, int offset, int case Ssl.SslErrorCode.SSL_ERROR_ZERO_RETURN: case Ssl.SslErrorCode.SSL_ERROR_WANT_READ: break; - + // indicates we need to write the out buffer and write again + case Ssl.SslErrorCode.SSL_ERROR_WANT_WRITE: + break; default: throw new SslException(SR.Format(SR.net_ssl_encrypt_failed, errorCode), innerError); } } - else - { - int capacityNeeded = Crypto.BioCtrlPending(context.OutputBio); - - if (output == null || output.Length < capacityNeeded) - { - output = new byte[capacityNeeded]; - } - - retVal = BioRead(context.OutputBio, output, capacityNeeded); - } - return retVal; + int bytesWritten = context.OutputBio.BytesWritten; + context.OutputBio.Reset(); + return bytesWritten; } internal static int Decrypt(SafeSslHandle context, byte[] outBuffer, int offset, int count, out Ssl.SslErrorCode errorCode) { + Debug.Assert(offset >= 0); + Debug.Assert(offset <= outBuffer.Length); + errorCode = Ssl.SslErrorCode.SSL_ERROR_NONE; - int retVal = BioWrite(context.InputBio, outBuffer, offset, count); + context.InputBio.SetData(outBuffer, offset, count); - if (retVal == count) + int retVal; + unsafe { - unsafe + fixed (byte* fixedBuffer = outBuffer) { - fixed (byte* fixedBuffer = outBuffer) - { - retVal = Ssl.SslRead(context, fixedBuffer + offset, outBuffer.Length); - } + retVal = Ssl.SslRead(context, fixedBuffer + offset, outBuffer.Length - offset); } + } - if (retVal > 0) - { - count = retVal; - } + if (retVal > 0) + { + count = retVal; } + if (retVal != count) { - Exception innerError; - errorCode = GetSslError(context, retVal, out innerError); + errorCode = GetSslError(context, retVal, out Exception innerError); retVal = 0; switch (errorCode) @@ -345,10 +330,10 @@ private static void AddX509Names(SafeX509NameStackHandle nameStack, StoreLocatio { store.Open(OpenFlags.ReadOnly); - foreach (var certificate in store.Certificates) + foreach (X509Certificate2 certificate in store.Certificates) { - //Check if issuer name is already present - //Avoiding duplicate names + // Check if issuer name is already present + // Avoiding duplicate names if (!issuerNameHashSet.Add(certificate.Issuer)) { continue; @@ -458,7 +443,7 @@ private static void SetSslCertificate(SafeSslContextHandle contextPtr, SafeX509H throw CreateSslException(SR.net_ssl_use_private_key_failed); } - //check private key + // check private key retVal = Ssl.SslCtxCheckPrivateKey(contextPtr); if (1 != retVal) diff --git a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs old mode 100644 new mode 100755 index c468e1a34ad0..b611d9a319fb --- a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs +++ b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs @@ -4,6 +4,7 @@ using System; using System.Diagnostics; +using System.Runtime; using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; using Microsoft.Win32.SafeHandles; @@ -156,12 +157,12 @@ internal enum SslErrorCode SSL_ERROR_WANT_WRITE = 3, SSL_ERROR_SYSCALL = 5, SSL_ERROR_ZERO_RETURN = 6, - + // NOTE: this SslErrorCode value doesn't exist in OpenSSL, but // we use it to distinguish when a renegotiation is pending. // Choosing an arbitrarily large value that shouldn't conflict // with any actual OpenSSL error codes - SSL_ERROR_RENEGOTIATE = 29304 + SSL_ERROR_RENEGOTIATE = 29304 } } } @@ -170,8 +171,8 @@ namespace Microsoft.Win32.SafeHandles { internal sealed class SafeSslHandle : SafeHandle { - private SafeBioHandle _readBio; - private SafeBioHandle _writeBio; + private ReadBioBuffer _readBio; + private WriteBioBuffer _writeBio; private bool _isServer; private bool _handshakeCompleted = false; @@ -180,21 +181,8 @@ public bool IsServer get { return _isServer; } } - public SafeBioHandle InputBio - { - get - { - return _readBio; - } - } - - public SafeBioHandle OutputBio - { - get - { - return _writeBio; - } - } + public ReadBioBuffer InputBio => _readBio; + public WriteBioBuffer OutputBio => _writeBio; internal void MarkHandshakeCompleted() { @@ -203,8 +191,8 @@ internal void MarkHandshakeCompleted() public static SafeSslHandle Create(SafeSslContextHandle context, bool isServer) { - SafeBioHandle readBio = Interop.Crypto.CreateMemoryBio(); - SafeBioHandle writeBio = Interop.Crypto.CreateMemoryBio(); + SafeBioHandle readBio = Interop.Crypto.ManagedSslBio.CreateManagedSslBio(); + SafeBioHandle writeBio = Interop.Crypto.ManagedSslBio.CreateManagedSslBio(); SafeSslHandle handle = Interop.Ssl.SslCreate(context); if (readBio.IsInvalid || writeBio.IsInvalid || handle.IsInvalid) { @@ -220,8 +208,8 @@ public static SafeSslHandle Create(SafeSslContextHandle context, bool isServer) { readBio.TransferOwnershipToParent(handle); writeBio.TransferOwnershipToParent(handle); - handle._readBio = readBio; - handle._writeBio = writeBio; + handle._readBio = new ReadBioBuffer(readBio); + handle._writeBio = new WriteBioBuffer(writeBio); Interop.Ssl.SslSetBio(handle, readBio, writeBio); } catch (Exception exc) @@ -296,5 +284,148 @@ internal SafeSslHandle(IntPtr validSslPointer, bool ownsHandle) : base(IntPtr.Ze { handle = validSslPointer; } + + internal class ReadBioBuffer : IDisposable + { + private readonly SafeBioHandle _bioHandle; + private GCHandle _handle; + private int _bytesAvailable; + private byte[] _byteArray; + private int _offset; + + internal ReadBioBuffer(SafeBioHandle bioHandle) + { + _bioHandle = bioHandle; + _handle = GCHandle.Alloc(this, GCHandleType.Normal); + Interop.Crypto.ManagedSslBio.BioSetGCHandle(_bioHandle, _handle); + Interop.Crypto.BioSetShoudRetryReadFlag(bioHandle); + } + + public void SetData(byte[] buffer, int offset, int length) + { + Debug.Assert(_bytesAvailable == 0); + + _byteArray = buffer; + _offset = offset; + _bytesAvailable = length; + } + + public int Read(Span output) + { + int bytesToCopy = Math.Min(output.Length, _bytesAvailable); + if (bytesToCopy == 0) + { + return -1; + } + + var span = new Span(_byteArray, _offset, bytesToCopy); + span.CopyTo(output); + _offset += bytesToCopy; + _bytesAvailable -= bytesToCopy; + return bytesToCopy; + } + + // Bio is already released by the ssl object + public void Dispose() + { + if (_handle.IsAllocated) + { + _handle.Free(); + } + } + } + + internal class WriteBioBuffer : IDisposable + { + private readonly SafeBioHandle _bioHandle; + private GCHandle _handle; + private byte[] _byteArray; + private int _bytesWritten; + private bool _isHandshake; + + internal WriteBioBuffer(SafeBioHandle bioHandle) + { + _bioHandle = bioHandle; + _handle = GCHandle.Alloc(this, GCHandleType.Normal); + Interop.Crypto.ManagedSslBio.BioSetGCHandle(_bioHandle, _handle); + } + + public int BytesWritten => _bytesWritten; + + public void SetData(byte[] buffer, bool isHandshake) + { + Debug.Assert(_byteArray == null); + + _byteArray = buffer; + _bytesWritten = 0; + _isHandshake = isHandshake; + } + + public int TakeBytes(out byte[] output) + { + output = _byteArray; + int bytes = _bytesWritten; + Reset(); + return bytes; + } + + public void Reset() + { + _bytesWritten = 0; + _byteArray = null; + } + + public int Write(Span input) + { + // Only for the handshake do we dynamically allocate + // buffers. For normal encrypt operations we use a fixed + // size buffer handed to us and loop to do all the needed + // writes. This should be changed for the handshake as well + // but will require more securechannel/sslstatus changes + if (_isHandshake) + { + if (_byteArray == null) + { + _byteArray = new byte[input.Length]; + _bytesWritten = 0; + } + else if (_byteArray.Length - _bytesWritten < input.Length) + { + byte[] oldArray = _byteArray; + _byteArray = new byte[input.Length + _bytesWritten]; + Buffer.BlockCopy(oldArray, 0, _byteArray, 0, _bytesWritten); + } + } + int bytesToWrite; + if (_byteArray == null) + { + bytesToWrite = -1; + } + else + { + bytesToWrite = Math.Min(input.Length, _byteArray.Length - _bytesWritten); + } + if (bytesToWrite < 1) + { + // We need to return -1 to indicate that it is an async method and + // and the write should retry later rather and a zero indicating EOF + Interop.Crypto.BioSetWriteFlag(_bioHandle); + return -1; + } + + input.Slice(0, bytesToWrite).CopyTo(new Span(_byteArray, _bytesWritten)); + _bytesWritten += bytesToWrite; + return bytesToWrite; + } + + // Bio is already released by the ssl object + public void Dispose() + { + if (_handle.IsAllocated) + { + _handle.Free(); + } + } + } } } diff --git a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtxOptions.cs b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtxOptions.cs index 3ab95d491ed1..29951bc69d33 100644 --- a/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtxOptions.cs +++ b/src/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtxOptions.cs @@ -35,5 +35,8 @@ internal static partial class Ssl [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxSetClientCAList")] internal static extern void SslCtxSetClientCAList(SafeSslContextHandle ctx, SafeX509NameStackHandle x509NameStackPtr); + + [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxSetAcceptMovingWriteBuffer")] + internal static extern void SslCtxSetAcceptMovingWriteBuffer(SafeSslContextHandle ctx); } } diff --git a/src/Native/Unix/System.Security.Cryptography.Native/opensslshim.h b/src/Native/Unix/System.Security.Cryptography.Native/opensslshim.h index 3e405b09da87..1258d00b2860 100644 --- a/src/Native/Unix/System.Security.Cryptography.Native/opensslshim.h +++ b/src/Native/Unix/System.Security.Cryptography.Native/opensslshim.h @@ -68,11 +68,14 @@ int EC_POINT_set_affine_coordinates_GF2m(const EC_GROUP *group, EC_POINT *p, PER_FUNCTION_BLOCK(BIO_ctrl, true) \ PER_FUNCTION_BLOCK(BIO_ctrl_pending, true) \ PER_FUNCTION_BLOCK(BIO_free, true) \ + PER_FUNCTION_BLOCK(BIO_get_ex_data, true) \ PER_FUNCTION_BLOCK(BIO_gets, true) \ PER_FUNCTION_BLOCK(BIO_new, true) \ PER_FUNCTION_BLOCK(BIO_new_file, true) \ PER_FUNCTION_BLOCK(BIO_read, true) \ PER_FUNCTION_BLOCK(BIO_s_mem, true) \ + PER_FUNCTION_BLOCK(BIO_set_ex_data, true) \ + PER_FUNCTION_BLOCK(BIO_set_flags, true) \ PER_FUNCTION_BLOCK(BIO_write, true) \ PER_FUNCTION_BLOCK(BN_bin2bn, true) \ PER_FUNCTION_BLOCK(BN_bn2bin, true) \ @@ -359,11 +362,14 @@ FOR_ALL_OPENSSL_FUNCTIONS #define BIO_ctrl BIO_ctrl_ptr #define BIO_ctrl_pending BIO_ctrl_pending_ptr #define BIO_free BIO_free_ptr +#define BIO_get_ex_data BIO_get_ex_data_ptr #define BIO_gets BIO_gets_ptr #define BIO_new BIO_new_ptr #define BIO_new_file BIO_new_file_ptr #define BIO_read BIO_read_ptr #define BIO_s_mem BIO_s_mem_ptr +#define BIO_set_ex_data BIO_set_ex_data_ptr +#define BIO_set_flags BIO_set_flags_ptr #define BIO_write BIO_write_ptr #define BN_bin2bn BN_bin2bn_ptr #define BN_bn2bin BN_bn2bin_ptr diff --git a/src/Native/Unix/System.Security.Cryptography.Native/pal_bio.cpp b/src/Native/Unix/System.Security.Cryptography.Native/pal_bio.cpp index b4009d5c402e..4311694f3061 100644 --- a/src/Native/Unix/System.Security.Cryptography.Native/pal_bio.cpp +++ b/src/Native/Unix/System.Security.Cryptography.Native/pal_bio.cpp @@ -6,6 +6,72 @@ #include +typedef struct ReadWriteMethodStruct +{ + BioWriteCallback write; + BioReadCallback read; +} ReadWriteMethodStruct; + +static ReadWriteMethodStruct managedMethods = {nullptr, nullptr}; + +static long ControlCallback(BIO* bio, int cmd, long param, void* ptr) +{ + (void)bio, (void)param, (void)ptr; // deliberately unused parameters + switch (cmd) + { + case BIO_CTRL_FLUSH: + case BIO_CTRL_POP: + case BIO_CTRL_PUSH: + return 1; + } + return 0; +} + +static int DestroyCallback(BIO* bio) +{ + (void)bio; // deliberately unused parameter + return -1; +} + +static int CreateCallback(BIO* bio) +{ + bio->init = 1; + return 1; +} + +static int WriteCallback(BIO* b, const char* buf, int32_t len) +{ + void* ptr = BIO_get_ex_data(b, 0); + if (ptr == nullptr) + { + return -1; + } + return managedMethods.write(b, buf, len, ptr); +} + +static int ReadCallback(BIO* b, char* buf, int32_t len) +{ + void* ptr = BIO_get_ex_data(b, 0); + if (ptr == nullptr) + { + return -1; + } + return managedMethods.read(b, buf, len, ptr); +} + +static BIO_METHOD managedSslBio = { + BIO_TYPE_SOURCE_SINK, + "Managed Ssl Bio", + WriteCallback, + ReadCallback, + nullptr, + nullptr, + ControlCallback, + CreateCallback, + DestroyCallback, + nullptr, +}; + extern "C" BIO* CryptoNative_CreateMemoryBio() { return BIO_new(BIO_s_mem()); @@ -52,3 +118,29 @@ extern "C" int32_t CryptoNative_BioCtrlPending(BIO* bio) assert(result <= INT32_MAX); return static_cast(result); } + +extern "C" void CryptoNative_BioSetAppData(BIO* bio, void* data) +{ + BIO_set_ex_data(bio, 0, data); +} + +extern "C" void CryptoNative_BioSetWriteFlag(BIO* bio) +{ + BIO_set_flags(bio, BIO_FLAGS_WRITE); +} + +extern "C" void CryptoNative_BioSetShoudRetryReadFlag(BIO* bio) +{ + BIO_set_flags(bio, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_READ); +} + +extern "C" void CryptoNative_InitManagedSslBioMethod(BioWriteCallback bwrite, BioReadCallback bread) +{ + managedMethods.write = bwrite; + managedMethods.read = bread; +} + +extern "C" BIO* CryptoNative_CreateManagedSslBio() +{ + return BIO_new(&managedSslBio); +} diff --git a/src/Native/Unix/System.Security.Cryptography.Native/pal_bio.h b/src/Native/Unix/System.Security.Cryptography.Native/pal_bio.h index 6897dc14030b..ce77cd11edba 100644 --- a/src/Native/Unix/System.Security.Cryptography.Native/pal_bio.h +++ b/src/Native/Unix/System.Security.Cryptography.Native/pal_bio.h @@ -5,6 +5,9 @@ #include "pal_types.h" #include "opensslshim.h" +typedef int32_t (*BioWriteCallback)(BIO* b, const char* buf, int32_t len, void* appData); +typedef int32_t (*BioReadCallback)(BIO* b, char* buf, int32_t len, void* appData); + /* Creates a new memory-backed BIO instance. */ @@ -54,3 +57,30 @@ Shims the BIO_ctrl_pending method. Returns the number of pending characters in the BIOs read and write buffers. */ extern "C" int32_t CryptoNative_BioCtrlPending(BIO* bio); + +/* +Adds app data to the extension slot of the bio +*/ +extern "C" void CryptoNative_BioSetAppData(BIO* bio, void* data); + +/* +Set write flag for the custom bio +*/ +extern "C" void CryptoNative_BioSetWriteFlag(BIO* bio); + +/* +Set the read and should retry flag for the custom bio +*/ +extern "C" void CryptoNative_BioSetShoudRetryReadFlag(BIO* bio); + +/* +Creates a ManagedSslBio instance. +*/ +extern "C" BIO* CryptoNative_CreateManagedSslBio(); + +/* +Sets the managed callbacks that are used by the ManagedSslBio +for reads and writes +*/ +extern "C" void CryptoNative_InitManagedSslBioMethod(BioWriteCallback bwrite, BioReadCallback bread); + diff --git a/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.cpp b/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.cpp index fda1f4da9537..dc65f0a86c0f 100644 --- a/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.cpp +++ b/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.cpp @@ -41,6 +41,11 @@ extern "C" SSL_CTX* CryptoNative_SslCtxCreate(SSL_METHOD* method) return ctx; } +extern "C" void CryptoNative_SslCtxSetAcceptMovingWriteBuffer(SSL_CTX* ctx) +{ + SSL_CTX_ctrl(ctx, SSL_CTRL_MODE, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER, nullptr); +} + extern "C" void CryptoNative_SetProtocolOptions(SSL_CTX* ctx, SslProtocols protocols) { // protocols may be 0, meaning system default, in which case let OpenSSL do what OpenSSL wants. diff --git a/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.h b/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.h index 4fd7fb59c051..c69adde07d44 100644 --- a/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.h +++ b/src/Native/Unix/System.Security.Cryptography.Native/pal_ssl.h @@ -177,6 +177,11 @@ Always succeeds. */ extern "C" void CryptoNative_SslCtxDestroy(SSL_CTX* ctx); +/* +Shims the SSL_ctx_change_mode with the flag SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER +*/ +extern "C" void CryptoNative_SslCtxSetAcceptMovingWriteBuffer(SSL_CTX* ctx); + /* Shims the SSL_set_connect_state method. */ diff --git a/src/System.Net.Http/src/System.Net.Http.csproj b/src/System.Net.Http/src/System.Net.Http.csproj index 252c82030211..23fe7bbf349e 100644 --- a/src/System.Net.Http/src/System.Net.Http.csproj +++ b/src/System.Net.Http/src/System.Net.Http.csproj @@ -314,6 +314,9 @@ Common\Interop\Unix\System.Security.Cryptography.Native\Interop.BIO.cs + + Common\Interop\Unix\System.Security.Cryptography.Native\Interop.ManagedSslBio.cs + Common\Interop\Unix\System.Security.Cryptography.Native\Interop.ERR.cs diff --git a/src/System.Net.Security/System.Net.Security.sln b/src/System.Net.Security/System.Net.Security.sln index 7661a942233e..234b3b226483 100644 --- a/src/System.Net.Security/System.Net.Security.sln +++ b/src/System.Net.Security/System.Net.Security.sln @@ -1,6 +1,6 @@ Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 14 -VisualStudioVersion = 14.0.25420.1 +# Visual Studio 15 +VisualStudioVersion = 15.0.26730.3 MinimumVisualStudioVersion = 10.0.40219.1 Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "System.Net.Security.Tests", "tests\FunctionalTests\System.Net.Security.Tests.csproj", "{A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}" ProjectSection(ProjectDependencies) = postProject @@ -57,4 +57,7 @@ Global {89F37791-6254-4D60-AB96-ACD3CCA0E771} = {E107E9C1-E893-4E87-987E-04EF0DCEAEFD} {A7488FC0-9A8F-4EF9-BC3E-C5EBA47E13F8} = {2E666815-2EDB-464B-9DF6-380BF4789AD4} EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {21E9B22F-61C2-4AAF-A44C-CF37265A7AC5} + EndGlobalSection EndGlobal diff --git a/src/System.Net.Security/src/System.Net.Security.csproj b/src/System.Net.Security/src/System.Net.Security.csproj index b901f525744b..40e180d13c4d 100644 --- a/src/System.Net.Security/src/System.Net.Security.csproj +++ b/src/System.Net.Security/src/System.Net.Security.csproj @@ -18,7 +18,6 @@ - @@ -273,6 +272,9 @@ Common\Interop\Unix\System.Security.Cryptography.Native\Interop.BIO.cs + + Common\Interop\Unix\System.Security.Cryptography.Native\Interop.ManagedSslBio.cs + Common\Interop\Unix\System.Security.Cryptography.Native\Interop.ERR.cs @@ -386,6 +388,7 @@ + @@ -411,4 +414,4 @@ - \ No newline at end of file + diff --git a/src/System.Net.Security/src/System/Net/Security/SecureChannel.cs b/src/System.Net.Security/src/System/Net/Security/SecureChannel.cs index f623b2fdbf55..a72829e6b46e 100644 --- a/src/System.Net.Security/src/System/Net/Security/SecureChannel.cs +++ b/src/System.Net.Security/src/System/Net/Security/SecureChannel.cs @@ -83,7 +83,7 @@ internal SecureChannel(string hostname, bool serverMode, SslProtocols sslProtoco _certSelectionDelegate = certSelectionDelegate; _refreshCredentialNeeded = true; _encryptionPolicy = encryptionPolicy; - + if (NetEventSource.IsEnabled) NetEventSource.Exit(this); } @@ -845,7 +845,7 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref } output = outgoingSecurity.token; - + return status; } @@ -905,8 +905,6 @@ internal SecurityStatusPal Encrypt(byte[] buffer, int offset, int size, ref byte NetEventSource.DumpBuffer(this, buffer, 0, Math.Min(buffer.Length, 128)); } - byte[] writeBuffer = output; - try { if (offset < 0 || offset > (buffer == null ? 0 : buffer.Length)) @@ -934,17 +932,22 @@ internal SecurityStatusPal Encrypt(byte[] buffer, int offset, int size, ref byte size, _headerSize, _trailerSize, - ref writeBuffer, + ref output, out resultSize); - - if (secStatus.ErrorCode != SecurityStatusPalErrorCode.OK) - { - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, $"ERROR {secStatus}"); - } - else + if (NetEventSource.IsEnabled) { - output = writeBuffer; - if (NetEventSource.IsEnabled) NetEventSource.Exit(this, $"OK data size:{resultSize}"); + switch (secStatus.ErrorCode) + { + case SecurityStatusPalErrorCode.OK: + NetEventSource.Exit(this, $"OK data size:{resultSize}"); + break; + case SecurityStatusPalErrorCode.ContinueNeeded: + NetEventSource.Exit(this, $"OK but more writes needed data size:{resultSize}"); + break; + default: + NetEventSource.Exit(this, $"ERROR {secStatus}"); + break; + } } return secStatus; @@ -1151,7 +1154,7 @@ private ProtocolToken GenerateAlertToken() return token; } - + private static TlsAlertMessage GetAlertMessageFromChain(X509Chain chain) { foreach (X509ChainStatus chainStatus in chain.ChainStatus) @@ -1169,7 +1172,7 @@ private static TlsAlertMessage GetAlertMessageFromChain(X509Chain chain) } if ((chainStatus.Status & - (X509ChainStatusFlags.Revoked | X509ChainStatusFlags.OfflineRevocation )) != 0) + (X509ChainStatusFlags.Revoked | X509ChainStatusFlags.OfflineRevocation)) != 0) { return TlsAlertMessage.CertificateRevoked; } @@ -1183,7 +1186,7 @@ private static TlsAlertMessage GetAlertMessageFromChain(X509Chain chain) if ((chainStatus.Status & X509ChainStatusFlags.CtlNotValidForUsage) != 0) { - return TlsAlertMessage.UnsupportedCert; + return TlsAlertMessage.UnsupportedCert; } if ((chainStatus.Status & diff --git a/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.cs b/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.cs old mode 100644 new mode 100755 index 7b11fb88f6ce..cf71e4cf5f47 --- a/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslStreamInternal.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Buffers; using System.Diagnostics; using System.IO; using System.Runtime.ExceptionServices; @@ -21,24 +22,26 @@ internal class SslStreamInternal private static readonly AsyncProtocolCallback s_readHeaderCallback = new AsyncProtocolCallback(ReadHeaderCallback); private static readonly AsyncProtocolCallback s_readFrameCallback = new AsyncProtocolCallback(ReadFrameCallback); - private const int PinnableReadBufferSize = 4096 * 4 + 32; // We read in 16K chunks + headers. - private static PinnableBufferCache s_PinnableReadBufferCache = new PinnableBufferCache("System.Net.SslStream", PinnableReadBufferSize); - private const int PinnableWriteBufferSize = 4096 + 1024; // We write in 4K chunks + encryption overhead. - private static PinnableBufferCache s_PinnableWriteBufferCache = new PinnableBufferCache("System.Net.SslStream", PinnableWriteBufferSize); + private const int ReadBufferSize = 4096 * 4 + 32; // We read in 16K chunks + headers. + private const int WriteBufferSize = 4096 + 1024; // We write in 4K chunks + encryption overhead. + + private static readonly ArrayPool BufferCache = ArrayPool.Shared; private SslState _sslState; private int _nestedWrite; private int _nestedRead; private AsyncProtocolRequest _readProtocolRequest; // cached, reusable AsyncProtocolRequest used for read operations private AsyncProtocolRequest _writeProtocolRequest; // cached, reusable AsyncProtocolRequest used for write operations + private static Action s_freeBufferAction = (task, buffer) => + { + FreeBuffer((byte[])buffer); + task.GetAwaiter().GetResult(); // propagate any exception + }; // Never updated directly, special properties are used. This is the read buffer. private byte[] _internalBuffer; private bool _internalBufferFromPinnableCache; - private byte[] _pinnableOutputBuffer; // Used for writes when we can do it. - private byte[] _pinnableOutputBufferInUse; // Remembers what UNENCRYPTED buffer is using _PinnableOutputBuffer. - private int _internalOffset; private int _internalBufferCount; @@ -47,11 +50,6 @@ internal class SslStreamInternal internal SslStreamInternal(SslState sslState) { - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage1("CTOR: In System.Net._SslStream.SslStream", this.GetHashCode()); - } - _sslState = sslState; _decryptedBytesOffset = 0; @@ -63,7 +61,7 @@ private void FreeReadBuffer() { if (_internalBufferFromPinnableCache) { - s_PinnableReadBufferCache.FreeBuffer(_internalBuffer); + BufferCache.Return(_internalBuffer); _internalBufferFromPinnableCache = false; } @@ -74,22 +72,8 @@ private void FreeReadBuffer() { if (_internalBufferFromPinnableCache) { - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage2("DTOR: In System.Net._SslStream.~SslStream Freeing Read Buffer", this.GetHashCode(), PinnableBufferCacheEventSource.AddressOfByteArray(_internalBuffer)); - } - FreeReadBuffer(); } - if (_pinnableOutputBuffer != null) - { - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage2("DTOR: In System.Net._SslStream.~SslStream Freeing Write Buffer", this.GetHashCode(), PinnableBufferCacheEventSource.AddressOfByteArray(_pinnableOutputBuffer)); - } - - s_PinnableWriteBufferCache.FreeBuffer(_pinnableOutputBuffer); - } } internal int ReadByte() @@ -231,23 +215,13 @@ private void EnsureInternalBufferSize(int newSize) bool wasPinnable = _internalBufferFromPinnableCache; byte[] saved = _internalBuffer; - if (newSize <= PinnableReadBufferSize) + if (newSize <= ReadBufferSize) { - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage2("In System.Net._SslStream.EnsureInternalBufferSize IS pinnable", this.GetHashCode(), newSize); - } - _internalBufferFromPinnableCache = true; - _internalBuffer = s_PinnableReadBufferCache.AllocateBuffer(); + _internalBuffer = BufferCache.Rent(ReadBufferSize); } else { - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage2("In System.Net._SslStream.EnsureInternalBufferSize NOT pinnable", this.GetHashCode(), newSize); - } - _internalBufferFromPinnableCache = false; _internalBuffer = new byte[newSize]; } @@ -259,7 +233,7 @@ private void EnsureInternalBufferSize(int newSize) if (wasPinnable) { - s_PinnableReadBufferCache.FreeBuffer(saved); + BufferCache.Return(saved); } } else if (_internalOffset > 0 && _internalBufferCount > 0) @@ -325,7 +299,7 @@ private AsyncProtocolRequest GetOrCreateProtocolRequest(ref AsyncProtocolRequest // private void ProcessWrite(byte[] buffer, int offset, int count, LazyAsyncResult asyncResult) { - _sslState.CheckThrow(authSuccessCheck:true, shutdownCheck:true); + _sslState.CheckThrow(authSuccessCheck: true, shutdownCheck: true); ValidateParameters(buffer, offset, count); if (Interlocked.Exchange(ref _nestedWrite, 1) == 1) @@ -373,31 +347,8 @@ private void StartWriting(byte[] buffer, int offset, int count, AsyncProtocolReq // We loop to this method from the callback. // If the last chunk was just completed from async callback (count < 0), we complete user request. - if (count >= 0 ) + if (count >= 0) { - byte[] outBuffer = null; - if (_pinnableOutputBufferInUse == null) - { - if (_pinnableOutputBuffer == null) - { - _pinnableOutputBuffer = s_PinnableWriteBufferCache.AllocateBuffer(); - } - - _pinnableOutputBufferInUse = buffer; - outBuffer = _pinnableOutputBuffer; - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage3("In System.Net._SslStream.StartWriting Trying Pinnable", this.GetHashCode(), count, PinnableBufferCacheEventSource.AddressOfByteArray(outBuffer)); - } - } - else - { - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage2("In System.Net._SslStream.StartWriting BufferInUse", this.GetHashCode(), count); - } - } - do { if (count == 0 && !SslStreamPal.CanEncryptEmptyMessage) @@ -416,20 +367,31 @@ private void StartWriting(byte[] buffer, int offset, int count, AsyncProtocolReq int chunkBytes = Math.Min(count, _sslState.MaxDataSize); int encryptedBytes; - SecurityStatusPal status = _sslState.EncryptData(buffer, offset, chunkBytes, ref outBuffer, out encryptedBytes); - if (status.ErrorCode != SecurityStatusPalErrorCode.OK) + byte[] outBuffer = BufferCache.Rent(WriteBufferSize); + + try { - // Re-handshake status is not supported. - ProtocolToken message = new ProtocolToken(null, status); - throw new IOException(SR.net_io_encrypt, message.GetException()); - } + SecurityStatusPal status = _sslState.EncryptData(buffer, offset, chunkBytes, ref outBuffer, out encryptedBytes); - if (PinnableBufferCacheEventSource.Log.IsEnabled()) + if (status.ErrorCode == SecurityStatusPalErrorCode.ContinueNeeded) + { + chunkBytes = 0; + } + else if (status.ErrorCode != SecurityStatusPalErrorCode.OK) + { + // Re-handshake status is not supported. + ProtocolToken message = new ProtocolToken(null, status); + throw new IOException(SR.net_io_encrypt, message.GetException()); + } + + } + catch { - PinnableBufferCacheEventSource.Log.DebugMessage3("In System.Net._SslStream.StartWriting Got Encrypted Buffer", - this.GetHashCode(), encryptedBytes, PinnableBufferCacheEventSource.AddressOfByteArray(outBuffer)); + FreeBuffer(outBuffer); + throw; } + if (asyncRequest != null) { // Prepare for the next request. @@ -437,10 +399,12 @@ private void StartWriting(byte[] buffer, int offset, int count, AsyncProtocolReq Task t = _sslState.InnerStream.WriteAsync(outBuffer, 0, encryptedBytes); if (t.IsCompleted) { + FreeBuffer(outBuffer); t.GetAwaiter().GetResult(); } else { + t = t.ContinueWith(s_freeBufferAction, outBuffer, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); IAsyncResult ar = TaskToApm.Begin(t, s_writeCallback, asyncRequest); if (!ar.CompletedSynchronously) { @@ -448,10 +412,18 @@ private void StartWriting(byte[] buffer, int offset, int count, AsyncProtocolReq } TaskToApm.End(ar); } + } else { - _sslState.InnerStream.Write(outBuffer, 0, encryptedBytes); + try + { + _sslState.InnerStream.Write(outBuffer, 0, encryptedBytes); + } + finally + { + FreeBuffer(outBuffer); + } } offset += chunkBytes; @@ -461,21 +433,14 @@ private void StartWriting(byte[] buffer, int offset, int count, AsyncProtocolReq _sslState.FinishWrite(); } while (count != 0); - } - if (asyncRequest != null) - { - asyncRequest.CompleteUser(); } + asyncRequest?.CompleteUser(); + } - if (buffer == _pinnableOutputBufferInUse) - { - _pinnableOutputBufferInUse = null; - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage1("In System.Net._SslStream.StartWriting Freeing buffer.", this.GetHashCode()); - } - } + private static void FreeBuffer(byte[] buffer) + { + BufferCache.Return(buffer); } // Fill the buffer up to the minimum specified size (or more, if possible). @@ -652,7 +617,7 @@ private int ProcessRead(byte[] buffer, int offset, int count, BufferAsyncResult if (Interlocked.Exchange(ref _nestedRead, 1) == 1) { - throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, (asyncResult!=null? "BeginRead":"Read"), "read")); + throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, (asyncResult != null ? "BeginRead" : "Read"), "read")); } // If this is an async operation, get the AsyncProtocolRequest to use. @@ -666,9 +631,9 @@ private int ProcessRead(byte[] buffer, int offset, int count, BufferAsyncResult if (_decryptedBytesCount != 0) { int copyBytes = CopyDecryptedData(buffer, offset, count); - + asyncRequest?.CompleteUser(copyBytes); - + return copyBytes; } @@ -743,7 +708,7 @@ private int StartFrameHeader(byte[] buffer, int offset, int count, AsyncProtocol Debug.Assert(asyncRequest != null); return 0; } - + return StartFrameBody(readBytes, buffer, offset, count, asyncRequest); } @@ -772,7 +737,7 @@ private int StartFrameBody(int readBytes, byte[] buffer, int offset, int count, } Debug.Assert(readBytes == 0 || readBytes == SecureChannel.ReadHeaderSize + payloadBytes); - + return ProcessFrameBody(readBytes, buffer, offset, count, asyncRequest); } @@ -829,7 +794,8 @@ private int ProcessFrameBody(int readBytes, byte[] buffer, int offset, int count private int ProcessReadErrorCode(SecurityStatusPal status, AsyncProtocolRequest asyncRequest, byte[] extraBuffer) { ProtocolToken message = new ProtocolToken(null, status); - if (NetEventSource.IsEnabled) NetEventSource.Info(null, $"***Processing an error Status = {message.Status}"); + if (NetEventSource.IsEnabled) + NetEventSource.Info(null, $"***Processing an error Status = {message.Status}"); if (message.Renegotiate) { diff --git a/src/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs b/src/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs index a0a4d65781a1..ad53e0ad87ad 100644 --- a/src/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs +++ b/src/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs @@ -67,6 +67,7 @@ public static SecurityStatusPal DecryptMessage(SafeDeleteContext securityContext { count = resultSize; } + return retVal; } @@ -167,6 +168,8 @@ private static SecurityStatusPal EncryptDecryptHelper(SafeDeleteContext security case Interop.Ssl.SslErrorCode.SSL_ERROR_NONE: case Interop.Ssl.SslErrorCode.SSL_ERROR_WANT_READ: return new SecurityStatusPal(SecurityStatusPalErrorCode.OK); + case Interop.Ssl.SslErrorCode.SSL_ERROR_WANT_WRITE: + return new SecurityStatusPal(SecurityStatusPalErrorCode.ContinueNeeded); default: return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, new Interop.OpenSsl.SslException((int)errorCode)); } diff --git a/src/System.Net.Security/src/System/PinnableBufferCache.cs b/src/System.Net.Security/src/System/PinnableBufferCache.cs deleted file mode 100644 index 7e5889335045..000000000000 --- a/src/System.Net.Security/src/System/PinnableBufferCache.cs +++ /dev/null @@ -1,589 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Diagnostics.Tracing; -using System.Runtime.InteropServices; -using System.Threading; - -namespace System -{ - internal sealed class PinnableBufferCache - { - /// - /// Create a new cache for pinned byte[] buffers. - /// - /// A name used in diagnostic messages - /// The size of byte[] buffers in the cache (they are all the same size) - public PinnableBufferCache(string cacheName, int numberOfElements) : this(cacheName, () => new byte[numberOfElements]) { } - - /// - /// Get a buffer from the buffer manager. If no buffers exist, allocate a new one. - /// - public byte[] AllocateBuffer() { return (byte[])Allocate(); } - - /// - /// Return a buffer back to the buffer manager. - /// - public void FreeBuffer(byte[] buffer) { Free(buffer); } - - /// - /// Create a PinnableBufferCache that works on any object (it is intended for OverlappedData) - /// - internal PinnableBufferCache(string cacheName, Func factory) - { - _notGen2 = new List(DefaultNumberOfBuffers); - _factory = factory; - - PinnableBufferCacheEventSource.Log.Create(cacheName); - _cacheName = cacheName; - } - - /// - /// Get a object from the buffer manager. If no buffers exist, allocate a new one. - /// - [System.Security.SecuritySafeCritical] - internal object Allocate() - { - // Fast path, get it from our Gen2 aged _freeList. - object returnBuffer; - if (!_freeList.TryPop(out returnBuffer)) - { - Restock(out returnBuffer); - } - - // Computing free count is expensive enough that we don't want to compute it unless logging is on. - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - int numAllocCalls = Interlocked.Increment(ref _numAllocCalls); - if (numAllocCalls >= 1024) - { - lock (this) - { - int previousNumAllocCalls = Interlocked.Exchange(ref _numAllocCalls, 0); - if (previousNumAllocCalls >= 1024) - { - int nonGen2Count = 0; - foreach (object o in _freeList) - { - if (GC.GetGeneration(o) < GC.MaxGeneration) - { - nonGen2Count++; - } - } - - PinnableBufferCacheEventSource.Log.WalkFreeListResult(_cacheName, _freeList.Count, nonGen2Count); - } - } - } - - PinnableBufferCacheEventSource.Log.AllocateBuffer(_cacheName, PinnableBufferCacheEventSource.AddressOf(returnBuffer), returnBuffer.GetHashCode(), GC.GetGeneration(returnBuffer), _freeList.Count); - } - return returnBuffer; - } - - /// - /// Return a buffer back to the buffer manager. - /// - [System.Security.SecuritySafeCritical] - internal void Free(object buffer) - { - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.FreeBuffer(_cacheName, PinnableBufferCacheEventSource.AddressOf(buffer), buffer.GetHashCode(), _freeList.Count); - } - - - // After we've done 3 gen1 GCs, assume that all buffers have aged into gen2 on the free path. - if ((_gen1CountAtLastRestock + 3) > GC.CollectionCount(GC.MaxGeneration - 1)) - { - lock (this) - { - if (GC.GetGeneration(buffer) < GC.MaxGeneration) - { - // The buffer is not aged, so put it in the non-aged free list. - _moreThanFreeListNeeded = true; - PinnableBufferCacheEventSource.Log.FreeBufferStillTooYoung(_cacheName, _notGen2.Count); - _notGen2.Add(buffer); - _gen1CountAtLastRestock = GC.CollectionCount(GC.MaxGeneration - 1); - return; - } - } - } - - // If we discovered that it is indeed Gen2, great, put it in the Gen2 list. - _freeList.Push(buffer); - } - - #region Private - - /// - /// Called when we don't have any buffers in our free list to give out. - /// - /// - [System.Security.SecuritySafeCritical] - private void Restock(out object returnBuffer) - { - lock (this) - { - // Try again after getting the lock as another thread could have just filled the free list. If we don't check - // then we unnecessarily grab a new set of buffers because we think we are out. - if (_freeList.TryPop(out returnBuffer)) - { - return; - } - - // Lazy init, Ask that TrimFreeListIfNeeded be called on every Gen 2 GC. - if (_restockSize == 0) - { - Gen2GcCallback.Register(Gen2GcCallbackFunc, this); - } - - // Indicate to the trimming policy that the free list is insufficient. - _moreThanFreeListNeeded = true; - PinnableBufferCacheEventSource.Log.AllocateBufferFreeListEmpty(_cacheName, _notGen2.Count); - - // Get more buffers if needed. - if (_notGen2.Count == 0) - { - CreateNewBuffers(); - } - - // We have no buffers in the aged freelist, so get one from the newer list. Try to pick the best one. - int idx = _notGen2.Count - 1; - if (GC.GetGeneration(_notGen2[idx]) < GC.MaxGeneration && GC.GetGeneration(_notGen2[0]) == GC.MaxGeneration) - { - idx = 0; - } - - returnBuffer = _notGen2[idx]; - _notGen2.RemoveAt(idx); - - // Remember any sub-optimal buffer so we don't put it on the free list when it gets freed. - if (PinnableBufferCacheEventSource.Log.IsEnabled() && GC.GetGeneration(returnBuffer) < GC.MaxGeneration) - { - PinnableBufferCacheEventSource.Log.AllocateBufferFromNotGen2(_cacheName, _notGen2.Count); - } - - // If we have a Gen1 collection, then everything on _notGen2 should have aged. Move them to the _freeList. - if (!AgePendingBuffers()) - { - // Before we could age at set of buffers, we have handed out half of them. - // This implies we should be proactive about allocating more (since we will trim them if we over-allocate). - if (_notGen2.Count == _restockSize / 2) - { - PinnableBufferCacheEventSource.Log.DebugMessage("Proactively adding more buffers to aging pool"); - CreateNewBuffers(); - } - } - } - } - - /// - /// See if we can promote the buffers to the free list. Returns true if successful. - /// - [System.Security.SecuritySafeCritical] - private bool AgePendingBuffers() - { - if (_gen1CountAtLastRestock < GC.CollectionCount(GC.MaxGeneration - 1)) - { - // Allocate a temp list of buffers that are not actually in gen2, and swap it in once - // we're done scanning all buffers. - int promotedCount = 0; - List notInGen2 = new List(); - PinnableBufferCacheEventSource.Log.AllocateBufferAged(_cacheName, _notGen2.Count); - for (int i = 0; i < _notGen2.Count; i++) - { - // We actually check every object to ensure that we aren't putting non-aged buffers into the free list. - object currentBuffer = _notGen2[i]; - if (GC.GetGeneration(currentBuffer) >= GC.MaxGeneration) - { - _freeList.Push(currentBuffer); - promotedCount++; - } - else - { - notInGen2.Add(currentBuffer); - } - } - PinnableBufferCacheEventSource.Log.AgePendingBuffersResults(_cacheName, promotedCount, notInGen2.Count); - _notGen2 = notInGen2; - - return true; - } - return false; - } - - /// - /// Generates some buffers to age into Gen2. - /// - private void CreateNewBuffers() - { - // We choose a very modest number of buffers initially because for the client case. This is often enough. - if (_restockSize == 0) - { - _restockSize = 4; - } - else if (_restockSize < DefaultNumberOfBuffers) - { - _restockSize = DefaultNumberOfBuffers; - } - else if (_restockSize < 256) - { - _restockSize = _restockSize * 2; // Grow quickly at small sizes - } - else if (_restockSize < 4096) - { - _restockSize = _restockSize * 3 / 2; // Less aggressively at large ones - } - else - { - _restockSize = 4096; // Cap how aggressive we are - } - - // Ensure we hit our minimums - if (_minBufferCount > _buffersUnderManagement) - _restockSize = Math.Max(_restockSize, _minBufferCount - _buffersUnderManagement); - - PinnableBufferCacheEventSource.Log.AllocateBufferCreatingNewBuffers(_cacheName, _buffersUnderManagement, _restockSize); - for (int i = 0; i < _restockSize; i++) - { - // Make a new buffer. - object newBuffer = _factory(); - - // Create space between the objects. We do this because otherwise it forms a single plug (group of objects) - // and the GC pins the entire plug making them NOT move to Gen1 and Gen2. by putting space between them - // we ensure that object get a chance to move independently (even if some are pinned). - var dummyObject = new object(); - _notGen2.Add(newBuffer); - } - _buffersUnderManagement += _restockSize; - _gen1CountAtLastRestock = GC.CollectionCount(GC.MaxGeneration - 1); - } - - /// - /// This is the static function that is called from the gen2 GC callback. - /// The input object is the cache itself. - /// NOTE: The reason that we make this function static and take the cache as a parameter is that - /// otherwise, we root the cache to the Gen2GcCallback object, and leak the cache even when - /// the application no longer needs it. - /// - [System.Security.SecuritySafeCritical] - private static bool Gen2GcCallbackFunc(object targetObj) - { - return ((PinnableBufferCache)(targetObj)).TrimFreeListIfNeeded(); - } - - /// - /// This is called on every gen2 GC to see if we need to trim the free list. - /// NOTE: DO NOT CALL THIS DIRECTLY FROM THE GEN2GCCALLBACK. INSTEAD CALL IT VIA A STATIC FUNCTION (SEE ABOVE). - /// If you register a non-static function as a callback, then this object will be leaked. - /// - [System.Security.SecuritySafeCritical] - private bool TrimFreeListIfNeeded() - { - int curMSec = Environment.TickCount; - int deltaMSec = curMSec - _msecNoUseBeyondFreeListSinceThisTime; - PinnableBufferCacheEventSource.Log.TrimCheck(_cacheName, _buffersUnderManagement, _moreThanFreeListNeeded, deltaMSec); - - // If we needed more than just the set of aged buffers since the last time we were called, - // we obviously should not be trimming any memory, so do nothing except reset the flag - if (_moreThanFreeListNeeded) - { - _moreThanFreeListNeeded = false; - _trimmingExperimentInProgress = false; - _msecNoUseBeyondFreeListSinceThisTime = curMSec; - return true; - } - - // We require a minimum amount of clock time to pass (10 seconds) before we trim. Ideally this time - // is larger than the typical buffer hold time. - if (0 <= deltaMSec && deltaMSec < 10000) - { - return true; - } - - // If we got here we have spend the last few second without needing to lengthen the free list. Thus - // we have 'enough' buffers, but maybe we have too many. - // See if we can trim - lock (this) - { - // Hit a race, try again later. - if (_moreThanFreeListNeeded) - { - _moreThanFreeListNeeded = false; - _trimmingExperimentInProgress = false; - _msecNoUseBeyondFreeListSinceThisTime = curMSec; - return true; - } - - var freeCount = _freeList.Count; // This is expensive to fetch, do it once. - - // If there is something in _notGen2 it was not used for the last few seconds, it is trim-able. - if (_notGen2.Count > 0) - { - // If we are not performing an experiment and we have stuff that is waiting to go into the - // free list but has not made it there, it could be because the 'slow path' of restocking - // has not happened, so force this (which should flush the list) and start over. - if (!_trimmingExperimentInProgress) - { - PinnableBufferCacheEventSource.Log.TrimFlush(_cacheName, _buffersUnderManagement, freeCount, _notGen2.Count); - AgePendingBuffers(); - _trimmingExperimentInProgress = true; - return true; - } - - PinnableBufferCacheEventSource.Log.TrimFree(_cacheName, _buffersUnderManagement, freeCount, _notGen2.Count); - _buffersUnderManagement -= _notGen2.Count; - - // Possibly revise the restocking down. We don't want to grow aggressively if we are trimming. - var newRestockSize = _buffersUnderManagement / 4; - if (newRestockSize < _restockSize) - { - _restockSize = Math.Max(newRestockSize, DefaultNumberOfBuffers); - } - - _notGen2.Clear(); - _trimmingExperimentInProgress = false; - return true; - } - - // Set up an experiment where we use 25% less buffers in our free list. We put them in - // _notGen2, and if they are needed they will be put back in the free list again. - var trimSize = freeCount / 4 + 1; - - // We are OK with a 15% overhead, do nothing in that case. - if (freeCount * 15 <= _buffersUnderManagement || _buffersUnderManagement - trimSize <= _minBufferCount) - { - PinnableBufferCacheEventSource.Log.TrimFreeSizeOK(_cacheName, _buffersUnderManagement, freeCount); - return true; - } - - // Move buffers from the free list back to the non-aged list. If we don't use them by next time, then we'll consider trimming them. - PinnableBufferCacheEventSource.Log.TrimExperiment(_cacheName, _buffersUnderManagement, freeCount, trimSize); - object buffer; - for (int i = 0; i < trimSize; i++) - { - if (_freeList.TryPop(out buffer)) - { - _notGen2.Add(buffer); - } - } - _msecNoUseBeyondFreeListSinceThisTime = curMSec; - _trimmingExperimentInProgress = true; - } - - // Indicate that we want to be called back on the next Gen 2 GC. - return true; - } - - private const int DefaultNumberOfBuffers = 16; - private string _cacheName; - private Func _factory; - - /// - /// Contains 'good' buffers to reuse. They are guaranteed to be Gen 2 ENFORCED! - /// - private ConcurrentStack _freeList = new ConcurrentStack(); - /// - /// Contains buffers that are not gen 2 and thus we do not wish to give out unless we have to. - /// To implement trimming we sometimes put aged buffers in here as a place to 'park' them - /// before true deletion. - /// - private List _notGen2; - /// - /// What was the gen 1 count the last time re restocked? If it is now greater, then - /// we know that all objects are in Gen 2 so we don't have to check. Should be updated - /// every time something gets added to the _notGen2 list. - /// - private int _gen1CountAtLastRestock; - - /// - /// Used to ensure we have a minimum time between trimmings. - /// - private int _msecNoUseBeyondFreeListSinceThisTime; - /// - /// To trim, we remove things from the free list (which is Gen 2) and see if we 'hit bottom' - /// This flag indicates that we hit bottom (we really needed a bigger free list). - /// - private bool _moreThanFreeListNeeded; - /// - /// The total number of buffers that this cache has ever allocated. - /// Used in trimming heuristics. - /// - private int _buffersUnderManagement; - /// - /// The number of buffers we added the last time we restocked. - /// - private int _restockSize; - /// - /// Did we put some buffers into _notGen2 to see if we can trim? - /// - private bool _trimmingExperimentInProgress; - /// - /// A forced minimum number of buffers. - /// - private int _minBufferCount = 0; - /// - /// The number of calls to Allocate. - /// - private int _numAllocCalls; - #endregion - } - - /// - /// Schedules a callback roughly every gen 2 GC (you may see a Gen 0 an Gen 1 but only once) - /// (We can fix this by capturing the Gen 2 count at startup and testing, but I mostly don't care) - /// - internal sealed class Gen2GcCallback //: CriticalFinalizerObject - { - [System.Security.SecuritySafeCritical] - public Gen2GcCallback() - : base() - { - } - - /// - /// Schedule 'callback' to be called in the next GC. If the callback returns true it is - /// rescheduled for the next Gen 2 GC. Otherwise the callbacks stop. - /// - /// NOTE: This callback will be kept alive until either the callback function returns false, - /// or the target object dies. - /// - public static void Register(Func callback, object targetObj) - { - // Create a unreachable object that remembers the callback function and target object. - Gen2GcCallback gcCallback = new Gen2GcCallback(); - gcCallback.Setup(callback, targetObj); - } - - #region Private - - private Func _callback; - private GCHandle _weakTargetObj; - - [System.Security.SecuritySafeCritical] - private void Setup(Func callback, object targetObj) - { - _callback = callback; - _weakTargetObj = GCHandle.Alloc(targetObj, GCHandleType.Weak); - } - - [System.Security.SecuritySafeCritical] - ~Gen2GcCallback() - { - // Check to see if the target object is still alive. - object targetObj = _weakTargetObj.Target; - if (targetObj == null) - { - // The target object is dead, so this callback object is no longer needed. - _weakTargetObj.Free(); - return; - } - - // Execute the callback method. - try - { - if (!_callback(targetObj)) - { - // If the callback returns false, this callback object is no longer needed. - return; - } - } - catch - { - // Ensure that we still get a chance to resurrect this object, even if the callback throws an exception. - } - - // Resurrect ourselves by re-registering for finalization. - if (!Environment.HasShutdownStarted) - { - GC.ReRegisterForFinalize(this); - } - } - #endregion - } - - /// - /// PinnableBufferCacheEventSource is a private eventSource that we are using to - /// debug and monitor the effectiveness of PinnableBufferCache - /// - - // The following EventSource Name must be unique per DLL: - [EventSource(Name = "Microsoft-DotNETRuntime-PinnableBufferCache-Networking")] - internal sealed class PinnableBufferCacheEventSource : EventSource - { - public static readonly PinnableBufferCacheEventSource Log = new PinnableBufferCacheEventSource(); - - [Event(1, Level = EventLevel.Verbose)] - public void DebugMessage(string message) { if (IsEnabled()) WriteEvent(1, message); } - [Event(2, Level = EventLevel.Verbose)] - public void DebugMessage1(string message, long value) { if (IsEnabled()) WriteEvent(2, message, value); } - [Event(3, Level = EventLevel.Verbose)] - public void DebugMessage2(string message, long value1, long value2) { if (IsEnabled()) WriteEvent(3, message, value1, value2); } - [Event(18, Level = EventLevel.Verbose)] - public void DebugMessage3(string message, long value1, long value2, long value3) { if (IsEnabled()) WriteEvent(18, message, value1, value2, value3); } - - [Event(4)] - public void Create(string cacheName) { if (IsEnabled()) WriteEvent(4, cacheName); } - - [Event(5, Level = EventLevel.Verbose)] - public void AllocateBuffer(string cacheName, ulong objectId, int objectHash, int objectGen, int freeCountAfter) { if (IsEnabled()) WriteEvent(5, cacheName, objectId, objectHash, objectGen, freeCountAfter); } - [Event(6)] - public void AllocateBufferFromNotGen2(string cacheName, int notGen2CountAfter) { if (IsEnabled()) WriteEvent(6, cacheName, notGen2CountAfter); } - [Event(7)] - public void AllocateBufferCreatingNewBuffers(string cacheName, int totalBuffsBefore, int objectCount) { if (IsEnabled()) WriteEvent(7, cacheName, totalBuffsBefore, objectCount); } - [Event(8)] - public void AllocateBufferAged(string cacheName, int agedCount) { if (IsEnabled()) WriteEvent(8, cacheName, agedCount); } - [Event(9)] - public void AllocateBufferFreeListEmpty(string cacheName, int notGen2CountBefore) { if (IsEnabled()) WriteEvent(9, cacheName, notGen2CountBefore); } - - [Event(10, Level = EventLevel.Verbose)] - public void FreeBuffer(string cacheName, ulong objectId, int objectHash, int freeCountBefore) { if (IsEnabled()) WriteEvent(10, cacheName, objectId, objectHash, freeCountBefore); } - [Event(11)] - public void FreeBufferStillTooYoung(string cacheName, int notGen2CountBefore) { if (IsEnabled()) WriteEvent(11, cacheName, notGen2CountBefore); } - - [Event(13)] - public void TrimCheck(string cacheName, int totalBuffs, bool neededMoreThanFreeList, int deltaMSec) { if (IsEnabled()) WriteEvent(13, cacheName, totalBuffs, neededMoreThanFreeList, deltaMSec); } - [Event(14)] - public void TrimFree(string cacheName, int totalBuffs, int freeListCount, int toBeFreed) { if (IsEnabled()) WriteEvent(14, cacheName, totalBuffs, freeListCount, toBeFreed); } - [Event(15)] - public void TrimExperiment(string cacheName, int totalBuffs, int freeListCount, int numTrimTrial) { if (IsEnabled()) WriteEvent(15, cacheName, totalBuffs, freeListCount, numTrimTrial); } - [Event(16)] - public void TrimFreeSizeOK(string cacheName, int totalBuffs, int freeListCount) { if (IsEnabled()) WriteEvent(16, cacheName, totalBuffs, freeListCount); } - [Event(17)] - public void TrimFlush(string cacheName, int totalBuffs, int freeListCount, int notGen2CountBefore) { if (IsEnabled()) WriteEvent(17, cacheName, totalBuffs, freeListCount, notGen2CountBefore); } - [Event(20)] - public void AgePendingBuffersResults(string cacheName, int promotedToFreeListCount, int heldBackCount) { if (IsEnabled()) WriteEvent(20, cacheName, promotedToFreeListCount, heldBackCount); } - [Event(21)] - public void WalkFreeListResult(string cacheName, int freeListCount, int gen0BuffersInFreeList) { if (IsEnabled()) WriteEvent(21, cacheName, freeListCount, gen0BuffersInFreeList); } - - - internal static ulong AddressOf(object obj) - { - var asByteArray = obj as byte[]; - if (asByteArray != null) - { - return (ulong)AddressOfByteArray(asByteArray); - } - - return 0; - } - - [System.Security.SecuritySafeCritical] - internal static unsafe long AddressOfByteArray(byte[] array) - { - if (array == null) - { - return 0; - } - - fixed (byte* ptr = array) - { - return (long)(ptr - 2 * sizeof(void*)); - } - } - } -} diff --git a/src/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj b/src/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj old mode 100644 new mode 100755 index fd8c1b14f280..87d1947635b0 --- a/src/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj +++ b/src/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj @@ -1,4 +1,4 @@ - + @@ -16,11 +16,9 @@ - - @@ -33,9 +31,6 @@ - - ProductionCode\System\PinnableBufferCache.cs - ProductionCode\System\Net\Security\SslStream.cs @@ -63,4 +58,4 @@ - + \ No newline at end of file diff --git a/src/System.Net.Security/tests/UnitTests/System/PinnableBufferCacheTest.cs b/src/System.Net.Security/tests/UnitTests/System/PinnableBufferCacheTest.cs deleted file mode 100644 index 2d6bc2306211..000000000000 --- a/src/System.Net.Security/tests/UnitTests/System/PinnableBufferCacheTest.cs +++ /dev/null @@ -1,27 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Xunit; - -namespace System.Net.Security.Tests -{ - public class PinnableBufferCacheTest - { - [Fact] - public void PinnableBufferCache_AllocateBuffer_Ok() - { - string cacheName = "Test"; - int numberOfElements = 5; - PinnableBufferCache p = new PinnableBufferCache(cacheName, numberOfElements); - - byte[] a = p.AllocateBuffer(); - - Assert.Equal(numberOfElements, a.Length); - foreach (byte t in a) - { - Assert.Equal(0, t); - } - } - } -}