diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs index e485a75d4b1..81fd4f97778 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -16,6 +17,7 @@ public static class ChatClientExtensions /// The client. /// An optional key that can be used to help identify the target service. /// The found object, otherwise . + /// is . /// /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the , /// including itself or any services it might be wrapping. @@ -24,7 +26,58 @@ public static class ChatClientExtensions { _ = Throw.IfNull(client); - return (TService?)client.GetService(typeof(TService), serviceKey); + return client.GetService(typeof(TService), serviceKey) is TService service ? service : default; + } + + /// + /// Asks the for an object of the specified type + /// and throws an exception if one isn't available. + /// + /// The client. + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// is . + /// is . + /// No service of the requested type for the specified key is available. + /// + /// The purpose of this method is to allow for the retrieval of services that are required to be provided by the , + /// including itself or any services it might be wrapping. + /// + public static object GetRequiredService(this IChatClient client, Type serviceType, object? serviceKey = null) + { + _ = Throw.IfNull(client); + _ = Throw.IfNull(serviceType); + + return + client.GetService(serviceType, serviceKey) ?? + throw Throw.CreateMissingServiceException(serviceType, serviceKey); + } + + /// + /// Asks the for an object of type + /// and throws an exception if one isn't available. + /// + /// The type of the object to be retrieved. + /// The client. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// is . + /// No service of the requested type for the specified key is available. + /// + /// The purpose of this method is to allow for the retrieval of strongly typed services that are required to be provided by the , + /// including itself or any services it might be wrapping. + /// + public static TService GetRequiredService(this IChatClient client, object? serviceKey = null) + { + _ = Throw.IfNull(client); + + if (client.GetService(typeof(TService), serviceKey) is not TService service) + { + throw Throw.CreateMissingServiceException(typeof(TService), serviceKey); + } + + return service; } /// Sends a user chat text message and returns the response messages. @@ -33,6 +86,8 @@ public static class ChatClientExtensions /// The chat options to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. + /// is . + /// is . public static Task GetResponseAsync( this IChatClient client, string chatMessage, @@ -51,6 +106,8 @@ public static Task GetResponseAsync( /// The chat options to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. + /// is . + /// is . public static Task GetResponseAsync( this IChatClient client, ChatMessage chatMessage, @@ -69,6 +126,8 @@ public static Task GetResponseAsync( /// The chat options to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. + /// is . + /// is . public static IAsyncEnumerable GetStreamingResponseAsync( this IChatClient client, string chatMessage, @@ -87,6 +146,8 @@ public static IAsyncEnumerable GetStreamingResponseAsync( /// The chat options to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. + /// is . + /// is . public static IAsyncEnumerable GetStreamingResponseAsync( this IChatClient client, ChatMessage chatMessage, diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs index 79e2f658cb9..26a39f05105 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs @@ -57,6 +57,7 @@ IAsyncEnumerable GetStreamingResponseAsync( /// The type of object being requested. /// An optional key that can be used to help identify the target service. /// The found object, otherwise . + /// is . /// /// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the , /// including itself or any services it might be wrapping. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs index 1165d299edf..d69952598dd 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs @@ -9,6 +9,7 @@ using Microsoft.Shared.Diagnostics; #pragma warning disable S2302 // "nameof" should be used +#pragma warning disable S4136 // Method overloads should be grouped together namespace Microsoft.Extensions.AI; @@ -22,19 +23,80 @@ public static class EmbeddingGeneratorExtensions /// The generator. /// An optional key that can be used to help identify the target service. /// The found object, otherwise . + /// is . /// /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the /// , including itself or any services it might be wrapping. /// - public static TService? GetService(this IEmbeddingGenerator generator, object? serviceKey = null) + public static TService? GetService( + this IEmbeddingGenerator generator, object? serviceKey = null) where TEmbedding : Embedding { _ = Throw.IfNull(generator); - return (TService?)generator.GetService(typeof(TService), serviceKey); + return generator.GetService(typeof(TService), serviceKey) is TService service ? service : default; } - // The following overload exists purely to work around the lack of partial generic type inference. + /// + /// Asks the for an object of the specified type + /// and throws an exception if one isn't available. + /// + /// The type from which embeddings will be generated. + /// The numeric type of the embedding data. + /// The generator. + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// is . + /// is . + /// No service of the requested type for the specified key is available. + /// + /// The purpose of this method is to allow for the retrieval of services that are required to be provided by the + /// , including itself or any services it might be wrapping. + /// + public static object GetRequiredService( + this IEmbeddingGenerator generator, Type serviceType, object? serviceKey = null) + where TEmbedding : Embedding + { + _ = Throw.IfNull(generator); + _ = Throw.IfNull(serviceType); + + return + generator.GetService(serviceType, serviceKey) ?? + throw Throw.CreateMissingServiceException(serviceType, serviceKey); + } + + /// + /// Asks the for an object of type + /// and throws an exception if one isn't available. + /// + /// The type from which embeddings will be generated. + /// The numeric type of the embedding data. + /// The type of the object to be retrieved. + /// The generator. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// is . + /// No service of the requested type for the specified key is available. + /// + /// The purpose of this method is to allow for the retrieval of strongly typed services that are required to be provided by the + /// , including itself or any services it might be wrapping. + /// + public static TService GetRequiredService( + this IEmbeddingGenerator generator, object? serviceKey = null) + where TEmbedding : Embedding + { + _ = Throw.IfNull(generator); + + if (generator.GetService(typeof(TService), serviceKey) is not TService service) + { + throw Throw.CreateMissingServiceException(typeof(TService), serviceKey); + } + + return service; + } + + // The following overloads exist purely to work around the lack of partial generic type inference. // Given an IEmbeddingGenerator generator, to call GetService with TService, you still need // to re-specify both TInput and TEmbedding, e.g. generator.GetService, TService>. // The case of string/Embedding is by far the most common case today, so this overload exists as an @@ -45,6 +107,7 @@ public static class EmbeddingGeneratorExtensions /// The generator. /// An optional key that can be used to help identify the target service. /// The found object, otherwise . + /// is . /// /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the /// , including itself or any services it might be wrapping. @@ -52,6 +115,23 @@ public static class EmbeddingGeneratorExtensions public static TService? GetService(this IEmbeddingGenerator> generator, object? serviceKey = null) => GetService, TService>(generator, serviceKey); + /// + /// Asks the for an object of type + /// and throws an exception if one isn't available. + /// + /// The type of the object to be retrieved. + /// The generator. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// is . + /// No service of the requested type for the specified key is available. + /// + /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the + /// , including itself or any services it might be wrapping. + /// + public static TService GetRequiredService(this IEmbeddingGenerator> generator, object? serviceKey = null) => + GetRequiredService, TService>(generator, serviceKey); + /// Generates an embedding vector from the specified . /// The type from which embeddings will be generated. /// The numeric type of the embedding data. @@ -60,6 +140,9 @@ public static class EmbeddingGeneratorExtensions /// The embedding generation options to configure the request. /// The to monitor for cancellation requests. The default is . /// The generated embedding for the specified . + /// is . + /// is . + /// The generator did not produce exactly one embedding. /// /// This operation is equivalent to using and returning the /// resulting 's property. @@ -84,6 +167,9 @@ public static async Task> GenerateEmbeddingVec /// /// The generated embedding for the specified . /// + /// is . + /// is . + /// The generator did not produce exactly one embedding. /// /// This operations is equivalent to using with a /// collection composed of the single and then returning the first embedding element from the @@ -125,6 +211,9 @@ public static async Task GenerateEmbeddingAsync( /// The embedding generation options to configure the request. /// The to monitor for cancellation requests. The default is . /// An array containing tuples of the input values and the associated generated embeddings. + /// is . + /// is . + /// The generator did not produce one embedding for each input value. public static async Task<(TInput Value, TEmbedding Embedding)[]> GenerateAndZipAsync( this IEmbeddingGenerator generator, IEnumerable values, diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs index 02bf1880427..c260708079c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs @@ -41,6 +41,7 @@ Task> GenerateAsync( /// The type of object being requested. /// An optional key that can be used to help identify the target service. /// The found object, otherwise . + /// is . /// /// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the /// , including itself or any services it might be wrapping. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs new file mode 100644 index 00000000000..96bbc4d020d --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Shared.Diagnostics; + +internal static partial class Throw +{ + /// Throws an exception indicating that a required service is not available. + public static InvalidOperationException CreateMissingServiceException(Type serviceType, object? serviceKey) => + new InvalidOperationException(serviceKey is null ? + $"No service of type '{serviceType}' is available." : + $"No service of type '{serviceType}' for the key '{serviceKey}' is available."); +} diff --git a/src/Shared/Throw/Throw.cs b/src/Shared/Throw/Throw.cs index ee082e2da4a..94c5c8a50db 100644 --- a/src/Shared/Throw/Throw.cs +++ b/src/Shared/Throw/Throw.cs @@ -18,13 +18,14 @@ namespace Microsoft.Shared.Diagnostics; /// messages. /// [SuppressMessage("Minor Code Smell", "S4136:Method overloads should be grouped together", Justification = "Doesn't work with the region layout")] +[SuppressMessage("Minor Code Smell", "S2333:Partial is gratuitous in this context", Justification = "Some projects add additional partial parts.")] [SuppressMessage("Design", "CA1716", Justification = "Not part of an API")] #if !SHARED_PROJECT [ExcludeFromCodeCoverage] #endif -internal static class Throw +internal static partial class Throw { #region For Object diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs index 64ea4406be9..5a95f2b3fd0 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs @@ -17,6 +17,58 @@ public void GetService_InvalidArgs_Throws() Assert.Throws("client", () => ChatClientExtensions.GetService(null!)); } + [Fact] + public void GetRequiredService_InvalidArgs_Throws() + { + Assert.Throws("client", () => ChatClientExtensions.GetRequiredService(null!, typeof(string))); + Assert.Throws("client", () => ChatClientExtensions.GetRequiredService(null!)); + + using var client = new TestChatClient(); + Assert.Throws("serviceType", () => client.GetRequiredService(null!)); + } + + [Fact] + public void GetService_ValidService_Returned() + { + using var client = new TestChatClient + { + GetServiceCallback = (Type serviceType, object? serviceKey) => + { + if (serviceType == typeof(string)) + { + return serviceKey == null ? "null key" : "non-null key"; + } + + if (serviceType == typeof(IChatClient)) + { + return new object(); + } + + return null; + }, + }; + + Assert.Equal("null key", client.GetService()); + Assert.Equal("null key", client.GetService(null)); + Assert.Equal("non-null key", client.GetService("key")); + + Assert.Null(client.GetService()); + Assert.Null(client.GetService("key")); + Assert.Null(client.GetService()); + + Assert.Equal("null key", client.GetRequiredService(typeof(string))); + Assert.Equal("null key", client.GetRequiredService()); + Assert.Equal("null key", client.GetRequiredService(null)); + Assert.Equal("non-null key", client.GetRequiredService(typeof(string), "key")); + Assert.Equal("non-null key", client.GetRequiredService("key")); + + Assert.Throws(() => client.GetRequiredService(typeof(object))); + Assert.Throws(() => client.GetRequiredService()); + Assert.Throws(() => client.GetRequiredService(typeof(object), "key")); + Assert.Throws(() => client.GetRequiredService("key")); + Assert.Throws(() => client.GetRequiredService()); + } + [Fact] public void GetResponseAsync_InvalidArgs_Throws() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs index 4466dd85d1e..8a61fbb0786 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs @@ -17,6 +17,77 @@ public void GetService_InvalidArgs_Throws() Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetService, object>(null!)); } + [Fact] + public void GetRequiredService_InvalidArgs_Throws() + { + Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService(null!)); + Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService>(null!, typeof(string))); + Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService, object>(null!)); + + using var generator = new TestEmbeddingGenerator(); + Assert.Throws("serviceType", () => generator.GetRequiredService(null!)); + } + + [Fact] + public void GetService_ValidService_Returned() + { + using IEmbeddingGenerator> generator = new TestEmbeddingGenerator + { + GetServiceCallback = (Type serviceType, object? serviceKey) => + { + if (serviceType == typeof(string)) + { + return serviceKey == null ? "null key" : "non-null key"; + } + + if (serviceType == typeof(IEmbeddingGenerator>)) + { + return new object(); + } + + return null; + }, + }; + + Assert.Equal("null key", generator.GetService(typeof(string))); + Assert.Equal("null key", generator.GetService()); + Assert.Equal("null key", generator.GetService, string>()); + + Assert.Equal("non-null key", generator.GetService(typeof(string), "key")); + Assert.Equal("non-null key", generator.GetService("key")); + Assert.Equal("non-null key", generator.GetService, string>("key")); + + Assert.Null(generator.GetService(typeof(object))); + Assert.Null(generator.GetService()); + Assert.Null(generator.GetService, object>()); + + Assert.Null(generator.GetService(typeof(object), "key")); + Assert.Null(generator.GetService("key")); + Assert.Null(generator.GetService, object>("key")); + + Assert.Null(generator.GetService()); + Assert.Null(generator.GetService, int?>()); + + Assert.Equal("null key", generator.GetRequiredService(typeof(string))); + Assert.Equal("null key", generator.GetRequiredService()); + Assert.Equal("null key", generator.GetRequiredService, string>()); + + Assert.Equal("non-null key", generator.GetRequiredService(typeof(string), "key")); + Assert.Equal("non-null key", generator.GetRequiredService("key")); + Assert.Equal("non-null key", generator.GetRequiredService, string>("key")); + + Assert.Throws(() => generator.GetRequiredService(typeof(object))); + Assert.Throws(() => generator.GetRequiredService()); + Assert.Throws(() => generator.GetRequiredService, object>()); + + Assert.Throws(() => generator.GetRequiredService(typeof(object), "key")); + Assert.Throws(() => generator.GetRequiredService("key")); + Assert.Throws(() => generator.GetRequiredService, object>("key")); + + Assert.Throws(() => generator.GetRequiredService()); + Assert.Throws(() => generator.GetRequiredService, int?>()); + } + [Fact] public async Task GenerateAsync_InvalidArgs_ThrowsAsync() {