diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs
index e485a75d4b1..81fd4f97778 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
@@ -16,6 +17,7 @@ public static class ChatClientExtensions
/// The client.
/// An optional key that can be used to help identify the target service.
/// The found object, otherwise .
+ /// is .
///
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the ,
/// including itself or any services it might be wrapping.
@@ -24,7 +26,58 @@ public static class ChatClientExtensions
{
_ = Throw.IfNull(client);
- return (TService?)client.GetService(typeof(TService), serviceKey);
+ return client.GetService(typeof(TService), serviceKey) is TService service ? service : default;
+ }
+
+ ///
+ /// Asks the for an object of the specified type
+ /// and throws an exception if one isn't available.
+ ///
+ /// The client.
+ /// The type of object being requested.
+ /// An optional key that can be used to help identify the target service.
+ /// The found object.
+ /// is .
+ /// is .
+ /// No service of the requested type for the specified key is available.
+ ///
+ /// The purpose of this method is to allow for the retrieval of services that are required to be provided by the ,
+ /// including itself or any services it might be wrapping.
+ ///
+ public static object GetRequiredService(this IChatClient client, Type serviceType, object? serviceKey = null)
+ {
+ _ = Throw.IfNull(client);
+ _ = Throw.IfNull(serviceType);
+
+ return
+ client.GetService(serviceType, serviceKey) ??
+ throw Throw.CreateMissingServiceException(serviceType, serviceKey);
+ }
+
+ ///
+ /// Asks the for an object of type
+ /// and throws an exception if one isn't available.
+ ///
+ /// The type of the object to be retrieved.
+ /// The client.
+ /// An optional key that can be used to help identify the target service.
+ /// The found object.
+ /// is .
+ /// No service of the requested type for the specified key is available.
+ ///
+ /// The purpose of this method is to allow for the retrieval of strongly typed services that are required to be provided by the ,
+ /// including itself or any services it might be wrapping.
+ ///
+ public static TService GetRequiredService(this IChatClient client, object? serviceKey = null)
+ {
+ _ = Throw.IfNull(client);
+
+ if (client.GetService(typeof(TService), serviceKey) is not TService service)
+ {
+ throw Throw.CreateMissingServiceException(typeof(TService), serviceKey);
+ }
+
+ return service;
}
/// Sends a user chat text message and returns the response messages.
@@ -33,6 +86,8 @@ public static class ChatClientExtensions
/// The chat options to configure the request.
/// The to monitor for cancellation requests. The default is .
/// The response messages generated by the client.
+ /// is .
+ /// is .
public static Task GetResponseAsync(
this IChatClient client,
string chatMessage,
@@ -51,6 +106,8 @@ public static Task GetResponseAsync(
/// The chat options to configure the request.
/// The to monitor for cancellation requests. The default is .
/// The response messages generated by the client.
+ /// is .
+ /// is .
public static Task GetResponseAsync(
this IChatClient client,
ChatMessage chatMessage,
@@ -69,6 +126,8 @@ public static Task GetResponseAsync(
/// The chat options to configure the request.
/// The to monitor for cancellation requests. The default is .
/// The response messages generated by the client.
+ /// is .
+ /// is .
public static IAsyncEnumerable GetStreamingResponseAsync(
this IChatClient client,
string chatMessage,
@@ -87,6 +146,8 @@ public static IAsyncEnumerable GetStreamingResponseAsync(
/// The chat options to configure the request.
/// The to monitor for cancellation requests. The default is .
/// The response messages generated by the client.
+ /// is .
+ /// is .
public static IAsyncEnumerable GetStreamingResponseAsync(
this IChatClient client,
ChatMessage chatMessage,
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs
index 79e2f658cb9..26a39f05105 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs
@@ -57,6 +57,7 @@ IAsyncEnumerable GetStreamingResponseAsync(
/// The type of object being requested.
/// An optional key that can be used to help identify the target service.
/// The found object, otherwise .
+ /// is .
///
/// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the ,
/// including itself or any services it might be wrapping.
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs
index 1165d299edf..d69952598dd 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs
@@ -9,6 +9,7 @@
using Microsoft.Shared.Diagnostics;
#pragma warning disable S2302 // "nameof" should be used
+#pragma warning disable S4136 // Method overloads should be grouped together
namespace Microsoft.Extensions.AI;
@@ -22,19 +23,80 @@ public static class EmbeddingGeneratorExtensions
/// The generator.
/// An optional key that can be used to help identify the target service.
/// The found object, otherwise .
+ /// is .
///
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the
/// , including itself or any services it might be wrapping.
///
- public static TService? GetService(this IEmbeddingGenerator generator, object? serviceKey = null)
+ public static TService? GetService(
+ this IEmbeddingGenerator generator, object? serviceKey = null)
where TEmbedding : Embedding
{
_ = Throw.IfNull(generator);
- return (TService?)generator.GetService(typeof(TService), serviceKey);
+ return generator.GetService(typeof(TService), serviceKey) is TService service ? service : default;
}
- // The following overload exists purely to work around the lack of partial generic type inference.
+ ///
+ /// Asks the for an object of the specified type
+ /// and throws an exception if one isn't available.
+ ///
+ /// The type from which embeddings will be generated.
+ /// The numeric type of the embedding data.
+ /// The generator.
+ /// The type of object being requested.
+ /// An optional key that can be used to help identify the target service.
+ /// The found object.
+ /// is .
+ /// is .
+ /// No service of the requested type for the specified key is available.
+ ///
+ /// The purpose of this method is to allow for the retrieval of services that are required to be provided by the
+ /// , including itself or any services it might be wrapping.
+ ///
+ public static object GetRequiredService(
+ this IEmbeddingGenerator generator, Type serviceType, object? serviceKey = null)
+ where TEmbedding : Embedding
+ {
+ _ = Throw.IfNull(generator);
+ _ = Throw.IfNull(serviceType);
+
+ return
+ generator.GetService(serviceType, serviceKey) ??
+ throw Throw.CreateMissingServiceException(serviceType, serviceKey);
+ }
+
+ ///
+ /// Asks the for an object of type
+ /// and throws an exception if one isn't available.
+ ///
+ /// The type from which embeddings will be generated.
+ /// The numeric type of the embedding data.
+ /// The type of the object to be retrieved.
+ /// The generator.
+ /// An optional key that can be used to help identify the target service.
+ /// The found object.
+ /// is .
+ /// No service of the requested type for the specified key is available.
+ ///
+ /// The purpose of this method is to allow for the retrieval of strongly typed services that are required to be provided by the
+ /// , including itself or any services it might be wrapping.
+ ///
+ public static TService GetRequiredService(
+ this IEmbeddingGenerator generator, object? serviceKey = null)
+ where TEmbedding : Embedding
+ {
+ _ = Throw.IfNull(generator);
+
+ if (generator.GetService(typeof(TService), serviceKey) is not TService service)
+ {
+ throw Throw.CreateMissingServiceException(typeof(TService), serviceKey);
+ }
+
+ return service;
+ }
+
+ // The following overloads exist purely to work around the lack of partial generic type inference.
// Given an IEmbeddingGenerator generator, to call GetService with TService, you still need
// to re-specify both TInput and TEmbedding, e.g. generator.GetService, TService>.
// The case of string/Embedding is by far the most common case today, so this overload exists as an
@@ -45,6 +107,7 @@ public static class EmbeddingGeneratorExtensions
/// The generator.
/// An optional key that can be used to help identify the target service.
/// The found object, otherwise .
+ /// is .
///
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the
/// , including itself or any services it might be wrapping.
@@ -52,6 +115,23 @@ public static class EmbeddingGeneratorExtensions
public static TService? GetService(this IEmbeddingGenerator> generator, object? serviceKey = null) =>
GetService, TService>(generator, serviceKey);
+ ///
+ /// Asks the for an object of type
+ /// and throws an exception if one isn't available.
+ ///
+ /// The type of the object to be retrieved.
+ /// The generator.
+ /// An optional key that can be used to help identify the target service.
+ /// The found object.
+ /// is .
+ /// No service of the requested type for the specified key is available.
+ ///
+ /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the
+ /// , including itself or any services it might be wrapping.
+ ///
+ public static TService GetRequiredService(this IEmbeddingGenerator> generator, object? serviceKey = null) =>
+ GetRequiredService, TService>(generator, serviceKey);
+
/// Generates an embedding vector from the specified .
/// The type from which embeddings will be generated.
/// The numeric type of the embedding data.
@@ -60,6 +140,9 @@ public static class EmbeddingGeneratorExtensions
/// The embedding generation options to configure the request.
/// The to monitor for cancellation requests. The default is .
/// The generated embedding for the specified .
+ /// is .
+ /// is .
+ /// The generator did not produce exactly one embedding.
///
/// This operation is equivalent to using and returning the
/// resulting 's property.
@@ -84,6 +167,9 @@ public static async Task> GenerateEmbeddingVec
///
/// The generated embedding for the specified .
///
+ /// is .
+ /// is .
+ /// The generator did not produce exactly one embedding.
///
/// This operations is equivalent to using with a
/// collection composed of the single and then returning the first embedding element from the
@@ -125,6 +211,9 @@ public static async Task GenerateEmbeddingAsync(
/// The embedding generation options to configure the request.
/// The to monitor for cancellation requests. The default is .
/// An array containing tuples of the input values and the associated generated embeddings.
+ /// is .
+ /// is .
+ /// The generator did not produce one embedding for each input value.
public static async Task<(TInput Value, TEmbedding Embedding)[]> GenerateAndZipAsync(
this IEmbeddingGenerator generator,
IEnumerable values,
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs
index 02bf1880427..c260708079c 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs
@@ -41,6 +41,7 @@ Task> GenerateAsync(
/// The type of object being requested.
/// An optional key that can be used to help identify the target service.
/// The found object, otherwise .
+ /// is .
///
/// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the
/// , including itself or any services it might be wrapping.
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs
new file mode 100644
index 00000000000..96bbc4d020d
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs
@@ -0,0 +1,15 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+
+namespace Microsoft.Shared.Diagnostics;
+
+internal static partial class Throw
+{
+ /// Throws an exception indicating that a required service is not available.
+ public static InvalidOperationException CreateMissingServiceException(Type serviceType, object? serviceKey) =>
+ new InvalidOperationException(serviceKey is null ?
+ $"No service of type '{serviceType}' is available." :
+ $"No service of type '{serviceType}' for the key '{serviceKey}' is available.");
+}
diff --git a/src/Shared/Throw/Throw.cs b/src/Shared/Throw/Throw.cs
index ee082e2da4a..94c5c8a50db 100644
--- a/src/Shared/Throw/Throw.cs
+++ b/src/Shared/Throw/Throw.cs
@@ -18,13 +18,14 @@ namespace Microsoft.Shared.Diagnostics;
/// messages.
///
[SuppressMessage("Minor Code Smell", "S4136:Method overloads should be grouped together", Justification = "Doesn't work with the region layout")]
+[SuppressMessage("Minor Code Smell", "S2333:Partial is gratuitous in this context", Justification = "Some projects add additional partial parts.")]
[SuppressMessage("Design", "CA1716", Justification = "Not part of an API")]
#if !SHARED_PROJECT
[ExcludeFromCodeCoverage]
#endif
-internal static class Throw
+internal static partial class Throw
{
#region For Object
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs
index 64ea4406be9..5a95f2b3fd0 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs
@@ -17,6 +17,58 @@ public void GetService_InvalidArgs_Throws()
Assert.Throws("client", () => ChatClientExtensions.GetService