From fb3fd2fcf13e52bed4518e041ddcec1983dd165b Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 18 Feb 2025 23:51:39 -0500 Subject: [PATCH 1/3] Add GetRequiredService extension for IChatClient/EmbeddingGenerator --- .../ChatCompletion/ChatClientExtensions.cs | 28 ++++++++- .../EmbeddingGeneratorExtensions.cs | 47 ++++++++++++++- .../Throw.cs | 15 +++++ src/Shared/Throw/Throw.cs | 3 +- .../ChatClientExtensionsTests.cs | 39 +++++++++++++ .../EmbeddingGeneratorExtensionsTests.cs | 57 +++++++++++++++++++ 6 files changed, 186 insertions(+), 3 deletions(-) create mode 100644 src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs index e485a75d4b1..99f0532b49b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -24,7 +25,32 @@ public static class ChatClientExtensions { _ = Throw.IfNull(client); - return (TService?)client.GetService(typeof(TService), serviceKey); + return client.GetService(typeof(TService), serviceKey) is TService service ? service : default; + } + + /// + /// Asks the for an object of type + /// and throws an exception if one isn't available. + /// + /// The type of the object to be retrieved. + /// The client. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// No service of the requested type for the specified key is available. + /// + /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the , + /// including itself or any services it might be wrapping. + /// + public static TService GetRequiredService(this IChatClient client, object? serviceKey = null) + { + _ = Throw.IfNull(client); + + if (client.GetService(typeof(TService), serviceKey) is TService service) + { + return service; + } + + throw Throw.CreateMissingServiceException(serviceKey); } /// Sends a user chat text message and returns the response messages. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs index 1165d299edf..978d7484c5a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs @@ -9,6 +9,7 @@ using Microsoft.Shared.Diagnostics; #pragma warning disable S2302 // "nameof" should be used +#pragma warning disable S4136 // Method overloads should be grouped together namespace Microsoft.Extensions.AI; @@ -31,7 +32,35 @@ public static class EmbeddingGeneratorExtensions { _ = Throw.IfNull(generator); - return (TService?)generator.GetService(typeof(TService), serviceKey); + return generator.GetService(typeof(TService), serviceKey) is TService service ? service : default; + } + + /// + /// Asks the for an object of type + /// and throws an exception if one isn't available. + /// + /// The type from which embeddings will be generated. + /// The numeric type of the embedding data. + /// The type of the object to be retrieved. + /// The generator. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// No service of the requested type for the specified key is available. + /// + /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the + /// , including itself or any services it might be wrapping. + /// + public static TService GetRequiredService(this IEmbeddingGenerator generator, object? serviceKey = null) + where TEmbedding : Embedding + { + _ = Throw.IfNull(generator); + + if (generator.GetService(typeof(TService), serviceKey) is TService service) + { + return service; + } + + throw Throw.CreateMissingServiceException(serviceKey); } // The following overload exists purely to work around the lack of partial generic type inference. @@ -52,6 +81,22 @@ public static class EmbeddingGeneratorExtensions public static TService? GetService(this IEmbeddingGenerator> generator, object? serviceKey = null) => GetService, TService>(generator, serviceKey); + /// + /// Asks the for an object of type + /// and throws an exception if one isn't available. + /// + /// The type of the object to be retrieved. + /// The generator. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// No service of the requested type for the specified key is available. + /// + /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the + /// , including itself or any services it might be wrapping. + /// + public static TService GetRequiredService(this IEmbeddingGenerator> generator, object? serviceKey = null) => + GetRequiredService, TService>(generator, serviceKey); + /// Generates an embedding vector from the specified . /// The type from which embeddings will be generated. /// The numeric type of the embedding data. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs new file mode 100644 index 00000000000..5cc2e1a118e --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Shared.Diagnostics; + +internal static partial class Throw +{ + /// Throws an exception indicating that a required service is not available. + public static InvalidOperationException CreateMissingServiceException(object? serviceKey) => + new InvalidOperationException(serviceKey is null ? + $"No service of type '{typeof(TService)}' is available." : + $"No service of type '{typeof(TService)}' for the key '{serviceKey}' is available."); +} diff --git a/src/Shared/Throw/Throw.cs b/src/Shared/Throw/Throw.cs index ee082e2da4a..94c5c8a50db 100644 --- a/src/Shared/Throw/Throw.cs +++ b/src/Shared/Throw/Throw.cs @@ -18,13 +18,14 @@ namespace Microsoft.Shared.Diagnostics; /// messages. /// [SuppressMessage("Minor Code Smell", "S4136:Method overloads should be grouped together", Justification = "Doesn't work with the region layout")] +[SuppressMessage("Minor Code Smell", "S2333:Partial is gratuitous in this context", Justification = "Some projects add additional partial parts.")] [SuppressMessage("Design", "CA1716", Justification = "Not part of an API")] #if !SHARED_PROJECT [ExcludeFromCodeCoverage] #endif -internal static class Throw +internal static partial class Throw { #region For Object diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs index 64ea4406be9..16601a9d62c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs @@ -15,6 +15,45 @@ public class ChatClientExtensionsTests public void GetService_InvalidArgs_Throws() { Assert.Throws("client", () => ChatClientExtensions.GetService(null!)); + Assert.Throws("client", () => ChatClientExtensions.GetRequiredService(null!)); + } + + [Fact] + public void GetService_ValidService_Returned() + { + using var client = new TestChatClient + { + GetServiceCallback = (Type serviceType, object? serviceKey) => + { + if (serviceType == typeof(string)) + { + return serviceKey == null ? "null key" : "non-null key"; + } + + if (serviceType == typeof(IChatClient)) + { + return new object(); + } + + return null; + }, + }; + + Assert.Equal("null key", client.GetService()); + Assert.Equal("null key", client.GetService(null)); + Assert.Equal("non-null key", client.GetService("key")); + + Assert.Null(client.GetService()); + Assert.Null(client.GetService("key")); + Assert.Null(client.GetService()); + + Assert.Equal("null key", client.GetRequiredService()); + Assert.Equal("null key", client.GetRequiredService(null)); + Assert.Equal("non-null key", client.GetRequiredService("key")); + + Assert.Throws(() => client.GetRequiredService()); + Assert.Throws(() => client.GetRequiredService("key")); + Assert.Throws(() => client.GetRequiredService()); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs index 4466dd85d1e..993938c3557 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs @@ -15,6 +15,63 @@ public void GetService_InvalidArgs_Throws() { Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetService(null!)); Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetService, object>(null!)); + + Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService(null!)); + Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService, object>(null!)); + } + + [Fact] + public void GetService_ValidService_Returned() + { + using var generator = new TestEmbeddingGenerator + { + GetServiceCallback = (Type serviceType, object? serviceKey) => + { + if (serviceType == typeof(string)) + { + return serviceKey == null ? "null key" : "non-null key"; + } + + if (serviceType == typeof(IEmbeddingGenerator>)) + { + return new object(); + } + + return null; + }, + }; + + Assert.Equal("null key", generator.GetService()); + Assert.Equal("null key", generator.GetService(null)); + Assert.Equal("non-null key", generator.GetService("key")); + + Assert.Equal("null key", generator.GetService, string>()); + Assert.Equal("null key", generator.GetService, string>(null)); + Assert.Equal("non-null key", generator.GetService, string>("key")); + + Assert.Null(generator.GetService()); + Assert.Null(generator.GetService("key")); + Assert.Null(generator.GetService>>()); + + Assert.Null(generator.GetService, object>()); + Assert.Null(generator.GetService, object>("key")); + Assert.Null(generator.GetService, IEmbeddingGenerator>>()); + + Assert.Equal("null key", generator.GetRequiredService()); + Assert.Equal("null key", generator.GetRequiredService(null)); + Assert.Equal("non-null key", generator.GetRequiredService("key")); + + Assert.Equal("null key", generator.GetRequiredService, string>()); + Assert.Equal("null key", generator.GetRequiredService, string>(null)); + Assert.Equal("non-null key", generator.GetRequiredService, string>("key")); + + Assert.Throws(() => generator.GetRequiredService()); + Assert.Throws(() => generator.GetRequiredService("key")); + Assert.Throws(() => generator.GetRequiredService>>()); + + Assert.Throws(() => generator.GetRequiredService, object>()); + Assert.Throws(() => generator.GetRequiredService, object>("key")); + Assert.Throws(() => generator.GetRequiredService, IEmbeddingGenerator>>()); } [Fact] From ce62cd7fbe7d747d9884ee7e2c64337a8c65831e Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 19 Feb 2025 09:49:55 -0500 Subject: [PATCH 2/3] Add non-generic GetRequiredService --- .../ChatCompletion/ChatClientExtensions.cs | 41 +++++++++++++- .../ChatCompletion/IChatClient.cs | 1 + .../EmbeddingGeneratorExtensions.cs | 56 +++++++++++++++++-- .../Embeddings/IEmbeddingGenerator.cs | 1 + .../Throw.cs | 6 +- .../ChatClientExtensionsTests.cs | 13 +++++ .../EmbeddingGeneratorExtensionsTests.cs | 48 ++++++++++------ 7 files changed, 137 insertions(+), 29 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs index 99f0532b49b..da943bcb6e2 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs @@ -17,6 +17,7 @@ public static class ChatClientExtensions /// The client. /// An optional key that can be used to help identify the target service. /// The found object, otherwise . + /// is . /// /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the , /// including itself or any services it might be wrapping. @@ -28,6 +29,31 @@ public static class ChatClientExtensions return client.GetService(typeof(TService), serviceKey) is TService service ? service : default; } + /// + /// Asks the for an object of the specified type + /// and throws an exception if one isn't available. + /// + /// The client. + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// is . + /// is . + /// No service of the requested type for the specified key is available. + /// + /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the , + /// including itself or any services it might be wrapping. + /// + public static object GetRequiredService(this IChatClient client, Type serviceType, object? serviceKey = null) + { + _ = Throw.IfNull(client); + _ = Throw.IfNull(serviceType); + + return + client.GetService(serviceType, serviceKey) ?? + throw Throw.CreateMissingServiceException(serviceType, serviceKey); + } + /// /// Asks the for an object of type /// and throws an exception if one isn't available. @@ -36,6 +62,7 @@ public static class ChatClientExtensions /// The client. /// An optional key that can be used to help identify the target service. /// The found object. + /// is . /// No service of the requested type for the specified key is available. /// /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the , @@ -45,12 +72,12 @@ public static TService GetRequiredService(this IChatClient client, obj { _ = Throw.IfNull(client); - if (client.GetService(typeof(TService), serviceKey) is TService service) + if (client.GetService(typeof(TService), serviceKey) is not TService service) { - return service; + throw Throw.CreateMissingServiceException(typeof(TService), serviceKey); } - throw Throw.CreateMissingServiceException(serviceKey); + return service; } /// Sends a user chat text message and returns the response messages. @@ -59,6 +86,8 @@ public static TService GetRequiredService(this IChatClient client, obj /// The chat options to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. + /// is . + /// is . public static Task GetResponseAsync( this IChatClient client, string chatMessage, @@ -77,6 +106,8 @@ public static Task GetResponseAsync( /// The chat options to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. + /// is . + /// is . public static Task GetResponseAsync( this IChatClient client, ChatMessage chatMessage, @@ -95,6 +126,8 @@ public static Task GetResponseAsync( /// The chat options to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. + /// is . + /// is . public static IAsyncEnumerable GetStreamingResponseAsync( this IChatClient client, string chatMessage, @@ -113,6 +146,8 @@ public static IAsyncEnumerable GetStreamingResponseAsync( /// The chat options to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. + /// is . + /// is . public static IAsyncEnumerable GetStreamingResponseAsync( this IChatClient client, ChatMessage chatMessage, diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs index 79e2f658cb9..26a39f05105 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs @@ -57,6 +57,7 @@ IAsyncEnumerable GetStreamingResponseAsync( /// The type of object being requested. /// An optional key that can be used to help identify the target service. /// The found object, otherwise . + /// is . /// /// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the , /// including itself or any services it might be wrapping. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs index 978d7484c5a..759ad101ac6 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs @@ -23,11 +23,13 @@ public static class EmbeddingGeneratorExtensions /// The generator. /// An optional key that can be used to help identify the target service. /// The found object, otherwise . + /// is . /// /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the /// , including itself or any services it might be wrapping. /// - public static TService? GetService(this IEmbeddingGenerator generator, object? serviceKey = null) + public static TService? GetService( + this IEmbeddingGenerator generator, object? serviceKey = null) where TEmbedding : Embedding { _ = Throw.IfNull(generator); @@ -35,6 +37,35 @@ public static class EmbeddingGeneratorExtensions return generator.GetService(typeof(TService), serviceKey) is TService service ? service : default; } + /// + /// Asks the for an object of the specified type + /// and throws an exception if one isn't available. + /// + /// The type from which embeddings will be generated. + /// The numeric type of the embedding data. + /// The generator. + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// is . + /// is . + /// No service of the requested type for the specified key is available. + /// + /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the + /// , including itself or any services it might be wrapping. + /// + public static object GetRequiredService( + this IEmbeddingGenerator generator, Type serviceType, object? serviceKey = null) + where TEmbedding : Embedding + { + _ = Throw.IfNull(generator); + _ = Throw.IfNull(serviceType); + + return + generator.GetService(serviceType, serviceKey) ?? + throw Throw.CreateMissingServiceException(serviceType, serviceKey); + } + /// /// Asks the for an object of type /// and throws an exception if one isn't available. @@ -45,25 +76,27 @@ public static class EmbeddingGeneratorExtensions /// The generator. /// An optional key that can be used to help identify the target service. /// The found object. + /// is . /// No service of the requested type for the specified key is available. /// /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the /// , including itself or any services it might be wrapping. /// - public static TService GetRequiredService(this IEmbeddingGenerator generator, object? serviceKey = null) + public static TService GetRequiredService( + this IEmbeddingGenerator generator, object? serviceKey = null) where TEmbedding : Embedding { _ = Throw.IfNull(generator); - if (generator.GetService(typeof(TService), serviceKey) is TService service) + if (generator.GetService(typeof(TService), serviceKey) is not TService service) { - return service; + throw Throw.CreateMissingServiceException(typeof(TService), serviceKey); } - throw Throw.CreateMissingServiceException(serviceKey); + return service; } - // The following overload exists purely to work around the lack of partial generic type inference. + // The following overloads exist purely to work around the lack of partial generic type inference. // Given an IEmbeddingGenerator generator, to call GetService with TService, you still need // to re-specify both TInput and TEmbedding, e.g. generator.GetService, TService>. // The case of string/Embedding is by far the most common case today, so this overload exists as an @@ -74,6 +107,7 @@ public static TService GetRequiredService(this IEm /// The generator. /// An optional key that can be used to help identify the target service. /// The found object, otherwise . + /// is . /// /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the /// , including itself or any services it might be wrapping. @@ -89,6 +123,7 @@ public static TService GetRequiredService(this IEm /// The generator. /// An optional key that can be used to help identify the target service. /// The found object. + /// is . /// No service of the requested type for the specified key is available. /// /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the @@ -105,6 +140,9 @@ public static TService GetRequiredService(this IEmbeddingGeneratorThe embedding generation options to configure the request. /// The to monitor for cancellation requests. The default is . /// The generated embedding for the specified . + /// is . + /// is . + /// The generator did not produce exactly one embedding. /// /// This operation is equivalent to using and returning the /// resulting 's property. @@ -129,6 +167,9 @@ public static async Task> GenerateEmbeddingVec /// /// The generated embedding for the specified . /// + /// is . + /// is . + /// The generator did not produce exactly one embedding. /// /// This operations is equivalent to using with a /// collection composed of the single and then returning the first embedding element from the @@ -170,6 +211,9 @@ public static async Task GenerateEmbeddingAsync( /// The embedding generation options to configure the request. /// The to monitor for cancellation requests. The default is . /// An array containing tuples of the input values and the associated generated embeddings. + /// is . + /// is . + /// The generator did not produce one embedding for each input value. public static async Task<(TInput Value, TEmbedding Embedding)[]> GenerateAndZipAsync( this IEmbeddingGenerator generator, IEnumerable values, diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs index 02bf1880427..c260708079c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs @@ -41,6 +41,7 @@ Task> GenerateAsync( /// The type of object being requested. /// An optional key that can be used to help identify the target service. /// The found object, otherwise . + /// is . /// /// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the /// , including itself or any services it might be wrapping. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs index 5cc2e1a118e..96bbc4d020d 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs @@ -8,8 +8,8 @@ namespace Microsoft.Shared.Diagnostics; internal static partial class Throw { /// Throws an exception indicating that a required service is not available. - public static InvalidOperationException CreateMissingServiceException(object? serviceKey) => + public static InvalidOperationException CreateMissingServiceException(Type serviceType, object? serviceKey) => new InvalidOperationException(serviceKey is null ? - $"No service of type '{typeof(TService)}' is available." : - $"No service of type '{typeof(TService)}' for the key '{serviceKey}' is available."); + $"No service of type '{serviceType}' is available." : + $"No service of type '{serviceType}' for the key '{serviceKey}' is available."); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs index 16601a9d62c..5a95f2b3fd0 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs @@ -15,7 +15,16 @@ public class ChatClientExtensionsTests public void GetService_InvalidArgs_Throws() { Assert.Throws("client", () => ChatClientExtensions.GetService(null!)); + } + + [Fact] + public void GetRequiredService_InvalidArgs_Throws() + { + Assert.Throws("client", () => ChatClientExtensions.GetRequiredService(null!, typeof(string))); Assert.Throws("client", () => ChatClientExtensions.GetRequiredService(null!)); + + using var client = new TestChatClient(); + Assert.Throws("serviceType", () => client.GetRequiredService(null!)); } [Fact] @@ -47,11 +56,15 @@ public void GetService_ValidService_Returned() Assert.Null(client.GetService("key")); Assert.Null(client.GetService()); + Assert.Equal("null key", client.GetRequiredService(typeof(string))); Assert.Equal("null key", client.GetRequiredService()); Assert.Equal("null key", client.GetRequiredService(null)); + Assert.Equal("non-null key", client.GetRequiredService(typeof(string), "key")); Assert.Equal("non-null key", client.GetRequiredService("key")); + Assert.Throws(() => client.GetRequiredService(typeof(object))); Assert.Throws(() => client.GetRequiredService()); + Assert.Throws(() => client.GetRequiredService(typeof(object), "key")); Assert.Throws(() => client.GetRequiredService("key")); Assert.Throws(() => client.GetRequiredService()); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs index 993938c3557..8a61fbb0786 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs @@ -15,15 +15,23 @@ public void GetService_InvalidArgs_Throws() { Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetService(null!)); Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetService, object>(null!)); + } + [Fact] + public void GetRequiredService_InvalidArgs_Throws() + { Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService(null!)); + Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService>(null!, typeof(string))); Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService, object>(null!)); + + using var generator = new TestEmbeddingGenerator(); + Assert.Throws("serviceType", () => generator.GetRequiredService(null!)); } [Fact] public void GetService_ValidService_Returned() { - using var generator = new TestEmbeddingGenerator + using IEmbeddingGenerator> generator = new TestEmbeddingGenerator { GetServiceCallback = (Type serviceType, object? serviceKey) => { @@ -41,37 +49,43 @@ public void GetService_ValidService_Returned() }, }; + Assert.Equal("null key", generator.GetService(typeof(string))); Assert.Equal("null key", generator.GetService()); - Assert.Equal("null key", generator.GetService(null)); - Assert.Equal("non-null key", generator.GetService("key")); - Assert.Equal("null key", generator.GetService, string>()); - Assert.Equal("null key", generator.GetService, string>(null)); + + Assert.Equal("non-null key", generator.GetService(typeof(string), "key")); + Assert.Equal("non-null key", generator.GetService("key")); Assert.Equal("non-null key", generator.GetService, string>("key")); + Assert.Null(generator.GetService(typeof(object))); Assert.Null(generator.GetService()); - Assert.Null(generator.GetService("key")); - Assert.Null(generator.GetService>>()); - Assert.Null(generator.GetService, object>()); + + Assert.Null(generator.GetService(typeof(object), "key")); + Assert.Null(generator.GetService("key")); Assert.Null(generator.GetService, object>("key")); - Assert.Null(generator.GetService, IEmbeddingGenerator>>()); - Assert.Equal("null key", generator.GetRequiredService()); - Assert.Equal("null key", generator.GetRequiredService(null)); - Assert.Equal("non-null key", generator.GetRequiredService("key")); + Assert.Null(generator.GetService()); + Assert.Null(generator.GetService, int?>()); + Assert.Equal("null key", generator.GetRequiredService(typeof(string))); + Assert.Equal("null key", generator.GetRequiredService()); Assert.Equal("null key", generator.GetRequiredService, string>()); - Assert.Equal("null key", generator.GetRequiredService, string>(null)); + + Assert.Equal("non-null key", generator.GetRequiredService(typeof(string), "key")); + Assert.Equal("non-null key", generator.GetRequiredService("key")); Assert.Equal("non-null key", generator.GetRequiredService, string>("key")); + Assert.Throws(() => generator.GetRequiredService(typeof(object))); Assert.Throws(() => generator.GetRequiredService()); - Assert.Throws(() => generator.GetRequiredService("key")); - Assert.Throws(() => generator.GetRequiredService>>()); - Assert.Throws(() => generator.GetRequiredService, object>()); + + Assert.Throws(() => generator.GetRequiredService(typeof(object), "key")); + Assert.Throws(() => generator.GetRequiredService("key")); Assert.Throws(() => generator.GetRequiredService, object>("key")); - Assert.Throws(() => generator.GetRequiredService, IEmbeddingGenerator>>()); + + Assert.Throws(() => generator.GetRequiredService()); + Assert.Throws(() => generator.GetRequiredService, int?>()); } [Fact] From 3a57e748b94bbfe6da5186964bd27e76f14ab555 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 19 Feb 2025 10:46:08 -0500 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: Eric Erhardt --- .../ChatCompletion/ChatClientExtensions.cs | 4 ++-- .../Embeddings/EmbeddingGeneratorExtensions.cs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs index da943bcb6e2..81fd4f97778 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs @@ -41,7 +41,7 @@ public static class ChatClientExtensions /// is . /// No service of the requested type for the specified key is available. /// - /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the , + /// The purpose of this method is to allow for the retrieval of services that are required to be provided by the , /// including itself or any services it might be wrapping. /// public static object GetRequiredService(this IChatClient client, Type serviceType, object? serviceKey = null) @@ -65,7 +65,7 @@ public static object GetRequiredService(this IChatClient client, Type serviceTyp /// is . /// No service of the requested type for the specified key is available. /// - /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the , + /// The purpose of this method is to allow for the retrieval of strongly typed services that are required to be provided by the , /// including itself or any services it might be wrapping. /// public static TService GetRequiredService(this IChatClient client, object? serviceKey = null) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs index 759ad101ac6..d69952598dd 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs @@ -51,7 +51,7 @@ public static class EmbeddingGeneratorExtensions /// is . /// No service of the requested type for the specified key is available. /// - /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the + /// The purpose of this method is to allow for the retrieval of services that are required to be provided by the /// , including itself or any services it might be wrapping. /// public static object GetRequiredService( @@ -79,7 +79,7 @@ public static object GetRequiredService( /// is . /// No service of the requested type for the specified key is available. /// - /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the + /// The purpose of this method is to allow for the retrieval of strongly typed services that are required to be provided by the /// , including itself or any services it might be wrapping. /// public static TService GetRequiredService(