|
| 1 | +package ai |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + |
| 6 | + "github.com/anthropics/anthropic-sdk-go" |
| 7 | + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" |
| 8 | + "github.com/kylecarbs/aisdk-go" |
| 9 | + "github.com/openai/openai-go" |
| 10 | + openaioption "github.com/openai/openai-go/option" |
| 11 | + "golang.org/x/xerrors" |
| 12 | + "google.golang.org/genai" |
| 13 | + |
| 14 | + "github.com/coder/coder/v2/codersdk" |
| 15 | +) |
| 16 | + |
// LanguageModel pairs a codersdk.LanguageModel (its ID, display name, and
// provider) with the provider-specific function used to stream completions
// from it.
type LanguageModel struct {
	codersdk.LanguageModel
	StreamFunc StreamFunc
}
| 21 | + |
// StreamOptions carries the inputs for a single streaming completion request.
type StreamOptions struct {
	// SystemPrompt, when non-empty, is sent as the provider's system
	// message/instruction for the request.
	SystemPrompt string
	// Model is the provider-specific model identifier to use.
	Model string
	// Messages is the conversation history to send.
	Messages []aisdk.Message
	// Thinking presumably toggles extended reasoning output.
	// NOTE(review): no StreamFunc in this file reads this field — confirm
	// intended use.
	Thinking bool
	// Tools lists the tools the model may invoke during the stream.
	Tools []aisdk.Tool
}
| 29 | + |
// StreamFunc starts a streaming completion request against a provider and
// returns the resulting data stream.
type StreamFunc func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error)

// LanguageModels is a map of language model ID to language model.
type LanguageModels map[string]LanguageModel
| 34 | + |
| 35 | +func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig) (LanguageModels, error) { |
| 36 | + models := make(LanguageModels) |
| 37 | + |
| 38 | + for _, config := range configs { |
| 39 | + var streamFunc StreamFunc |
| 40 | + |
| 41 | + switch config.Type { |
| 42 | + case "openai": |
| 43 | + opts := []openaioption.RequestOption{ |
| 44 | + openaioption.WithAPIKey(config.APIKey), |
| 45 | + } |
| 46 | + if config.BaseURL != "" { |
| 47 | + opts = append(opts, openaioption.WithBaseURL(config.BaseURL)) |
| 48 | + } |
| 49 | + client := openai.NewClient(opts...) |
| 50 | + streamFunc = func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error) { |
| 51 | + openaiMessages, err := aisdk.MessagesToOpenAI(options.Messages) |
| 52 | + if err != nil { |
| 53 | + return nil, err |
| 54 | + } |
| 55 | + tools := aisdk.ToolsToOpenAI(options.Tools) |
| 56 | + if options.SystemPrompt != "" { |
| 57 | + openaiMessages = append([]openai.ChatCompletionMessageParamUnion{ |
| 58 | + openai.SystemMessage(options.SystemPrompt), |
| 59 | + }, openaiMessages...) |
| 60 | + } |
| 61 | + |
| 62 | + return aisdk.OpenAIToDataStream(client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{ |
| 63 | + Messages: openaiMessages, |
| 64 | + Model: options.Model, |
| 65 | + Tools: tools, |
| 66 | + MaxTokens: openai.Int(8192), |
| 67 | + })), nil |
| 68 | + } |
| 69 | + if config.Models == nil { |
| 70 | + models, err := client.Models.List(ctx) |
| 71 | + if err != nil { |
| 72 | + return nil, err |
| 73 | + } |
| 74 | + config.Models = make([]string, len(models.Data)) |
| 75 | + for i, model := range models.Data { |
| 76 | + config.Models[i] = model.ID |
| 77 | + } |
| 78 | + } |
| 79 | + case "anthropic": |
| 80 | + client := anthropic.NewClient(anthropicoption.WithAPIKey(config.APIKey)) |
| 81 | + streamFunc = func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error) { |
| 82 | + anthropicMessages, systemMessage, err := aisdk.MessagesToAnthropic(options.Messages) |
| 83 | + if err != nil { |
| 84 | + return nil, err |
| 85 | + } |
| 86 | + if options.SystemPrompt != "" { |
| 87 | + systemMessage = []anthropic.TextBlockParam{ |
| 88 | + *anthropic.NewTextBlock(options.SystemPrompt).OfRequestTextBlock, |
| 89 | + } |
| 90 | + } |
| 91 | + return aisdk.AnthropicToDataStream(client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{ |
| 92 | + Messages: anthropicMessages, |
| 93 | + Model: options.Model, |
| 94 | + System: systemMessage, |
| 95 | + Tools: aisdk.ToolsToAnthropic(options.Tools), |
| 96 | + MaxTokens: 8192, |
| 97 | + })), nil |
| 98 | + } |
| 99 | + if config.Models == nil { |
| 100 | + models, err := client.Models.List(ctx, anthropic.ModelListParams{}) |
| 101 | + if err != nil { |
| 102 | + return nil, err |
| 103 | + } |
| 104 | + config.Models = make([]string, len(models.Data)) |
| 105 | + for i, model := range models.Data { |
| 106 | + config.Models[i] = model.ID |
| 107 | + } |
| 108 | + } |
| 109 | + case "google": |
| 110 | + client, err := genai.NewClient(ctx, &genai.ClientConfig{ |
| 111 | + APIKey: config.APIKey, |
| 112 | + Backend: genai.BackendGeminiAPI, |
| 113 | + }) |
| 114 | + if err != nil { |
| 115 | + return nil, err |
| 116 | + } |
| 117 | + streamFunc = func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error) { |
| 118 | + googleMessages, err := aisdk.MessagesToGoogle(options.Messages) |
| 119 | + if err != nil { |
| 120 | + return nil, err |
| 121 | + } |
| 122 | + tools, err := aisdk.ToolsToGoogle(options.Tools) |
| 123 | + if err != nil { |
| 124 | + return nil, err |
| 125 | + } |
| 126 | + var systemInstruction *genai.Content |
| 127 | + if options.SystemPrompt != "" { |
| 128 | + systemInstruction = &genai.Content{ |
| 129 | + Parts: []*genai.Part{ |
| 130 | + genai.NewPartFromText(options.SystemPrompt), |
| 131 | + }, |
| 132 | + Role: "model", |
| 133 | + } |
| 134 | + } |
| 135 | + return aisdk.GoogleToDataStream(client.Models.GenerateContentStream(ctx, options.Model, googleMessages, &genai.GenerateContentConfig{ |
| 136 | + SystemInstruction: systemInstruction, |
| 137 | + Tools: tools, |
| 138 | + })), nil |
| 139 | + } |
| 140 | + if config.Models == nil { |
| 141 | + models, err := client.Models.List(ctx, &genai.ListModelsConfig{}) |
| 142 | + if err != nil { |
| 143 | + return nil, err |
| 144 | + } |
| 145 | + config.Models = make([]string, len(models.Items)) |
| 146 | + for i, model := range models.Items { |
| 147 | + config.Models[i] = model.Name |
| 148 | + } |
| 149 | + } |
| 150 | + default: |
| 151 | + return nil, xerrors.Errorf("unsupported model type: %s", config.Type) |
| 152 | + } |
| 153 | + |
| 154 | + for _, model := range config.Models { |
| 155 | + models[model] = LanguageModel{ |
| 156 | + LanguageModel: codersdk.LanguageModel{ |
| 157 | + ID: model, |
| 158 | + DisplayName: model, |
| 159 | + Provider: config.Type, |
| 160 | + }, |
| 161 | + StreamFunc: streamFunc, |
| 162 | + } |
| 163 | + } |
| 164 | + } |
| 165 | + |
| 166 | + return models, nil |
| 167 | +} |
0 commit comments