@@ -2,8 +2,140 @@ package ai
2
2
3
3
import (
4
4
"context"
5
+ "fmt"
5
6
7
+ "github.com/anthropics/anthropic-sdk-go"
8
+ anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
9
+ "github.com/coder/coder/v2/codersdk"
6
10
"github.com/kylecarbs/aisdk-go"
11
+ "github.com/openai/openai-go"
12
+ openaioption "github.com/openai/openai-go/option"
13
+ "google.golang.org/genai"
7
14
)
8
15
9
- type Provider func (ctx context.Context , messages []aisdk.Message ) (aisdk.DataStream , error )
16
+ type LanguageModel struct {
17
+ codersdk.LanguageModel
18
+ StreamFunc StreamFunc
19
+ }
20
+
21
+ type StreamOptions struct {
22
+ Model string
23
+ Messages []aisdk.Message
24
+ Thinking bool
25
+ Tools []aisdk.Tool
26
+ }
27
+
28
+ type StreamFunc func (ctx context.Context , options StreamOptions ) (aisdk.DataStream , error )
29
+
30
+ // LanguageModels is a map of language model ID to language model.
31
+ type LanguageModels map [string ]LanguageModel
32
+
33
+ func ModelsFromConfig (ctx context.Context , configs []codersdk.AIProviderConfig ) (LanguageModels , error ) {
34
+ models := make (LanguageModels )
35
+
36
+ for _ , config := range configs {
37
+ var streamFunc StreamFunc
38
+
39
+ switch config .Type {
40
+ case "openai" :
41
+ client := openai .NewClient (openaioption .WithAPIKey (config .APIKey ))
42
+ streamFunc = func (ctx context.Context , options StreamOptions ) (aisdk.DataStream , error ) {
43
+ openaiMessages , err := aisdk .MessagesToOpenAI (options .Messages )
44
+ if err != nil {
45
+ return nil , err
46
+ }
47
+ tools := aisdk .ToolsToOpenAI (options .Tools )
48
+ return aisdk .OpenAIToDataStream (client .Chat .Completions .NewStreaming (ctx , openai.ChatCompletionNewParams {
49
+ Messages : openaiMessages ,
50
+ Model : options .Model ,
51
+ Tools : tools ,
52
+ MaxTokens : openai .Int (8192 ),
53
+ })), nil
54
+ }
55
+ if config .Models == nil {
56
+ models , err := client .Models .List (ctx )
57
+ if err != nil {
58
+ return nil , err
59
+ }
60
+ config .Models = make ([]string , len (models .Data ))
61
+ for i , model := range models .Data {
62
+ config .Models [i ] = model .ID
63
+ }
64
+ }
65
+ break
66
+ case "anthropic" :
67
+ client := anthropic .NewClient (anthropicoption .WithAPIKey (config .APIKey ))
68
+ streamFunc = func (ctx context.Context , options StreamOptions ) (aisdk.DataStream , error ) {
69
+ anthropicMessages , systemMessage , err := aisdk .MessagesToAnthropic (options .Messages )
70
+ if err != nil {
71
+ return nil , err
72
+ }
73
+ return aisdk .AnthropicToDataStream (client .Messages .NewStreaming (ctx , anthropic.MessageNewParams {
74
+ Messages : anthropicMessages ,
75
+ Model : options .Model ,
76
+ System : systemMessage ,
77
+ Tools : aisdk .ToolsToAnthropic (options .Tools ),
78
+ MaxTokens : 8192 ,
79
+ })), nil
80
+ }
81
+ if config .Models == nil {
82
+ models , err := client .Models .List (ctx , anthropic.ModelListParams {})
83
+ if err != nil {
84
+ return nil , err
85
+ }
86
+ config .Models = make ([]string , len (models .Data ))
87
+ for i , model := range models .Data {
88
+ config .Models [i ] = model .ID
89
+ }
90
+ }
91
+ break
92
+ case "google" :
93
+ client , err := genai .NewClient (ctx , & genai.ClientConfig {
94
+ APIKey : config .APIKey ,
95
+ Backend : genai .BackendGeminiAPI ,
96
+ })
97
+ if err != nil {
98
+ return nil , err
99
+ }
100
+ streamFunc = func (ctx context.Context , options StreamOptions ) (aisdk.DataStream , error ) {
101
+ googleMessages , err := aisdk .MessagesToGoogle (options .Messages )
102
+ if err != nil {
103
+ return nil , err
104
+ }
105
+ tools , err := aisdk .ToolsToGoogle (options .Tools )
106
+ if err != nil {
107
+ return nil , err
108
+ }
109
+ return aisdk .GoogleToDataStream (client .Models .GenerateContentStream (ctx , options .Model , googleMessages , & genai.GenerateContentConfig {
110
+ Tools : tools ,
111
+ })), nil
112
+ }
113
+ if config .Models == nil {
114
+ models , err := client .Models .List (ctx , & genai.ListModelsConfig {})
115
+ if err != nil {
116
+ return nil , err
117
+ }
118
+ config .Models = make ([]string , len (models .Items ))
119
+ for i , model := range models .Items {
120
+ config .Models [i ] = model .Name
121
+ }
122
+ }
123
+ break
124
+ default :
125
+ return nil , fmt .Errorf ("unsupported model type: %s" , config .Type )
126
+ }
127
+
128
+ for _ , model := range config .Models {
129
+ models [model ] = LanguageModel {
130
+ LanguageModel : codersdk.LanguageModel {
131
+ ID : model ,
132
+ DisplayName : model ,
133
+ Provider : config .Type ,
134
+ },
135
+ StreamFunc : streamFunc ,
136
+ }
137
+ }
138
+ }
139
+
140
+ return models , nil
141
+ }
0 commit comments