Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0c07739

Browse files
committed
And we have chat!
1 parent 9ac4643 commit 0c07739

27 files changed

+2902
-305
lines changed

cli/server.go

+41
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ import (
6161
"github.com/coder/serpent"
6262
"github.com/coder/wgtunnel/tunnelsdk"
6363

64+
"github.com/coder/coder/v2/coderd/ai"
6465
"github.com/coder/coder/v2/coderd/entitlements"
6566
"github.com/coder/coder/v2/coderd/notifications/reports"
6667
"github.com/coder/coder/v2/coderd/runtimeconfig"
@@ -610,6 +611,22 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
610611
)
611612
}
612613

614+
aiProviders, err := ReadAIProvidersFromEnv(os.Environ())
615+
if err != nil {
616+
return xerrors.Errorf("read ai providers from env: %w", err)
617+
}
618+
vals.AI.Value.Providers = append(vals.AI.Value.Providers, aiProviders...)
619+
for _, provider := range aiProviders {
620+
logger.Debug(
621+
ctx, "loaded ai provider",
622+
slog.F("type", provider.Type),
623+
)
624+
}
625+
languageModels, err := ai.ModelsFromConfig(ctx, vals.AI.Value.Providers)
626+
if err != nil {
627+
return xerrors.Errorf("create language models: %w", err)
628+
}
629+
613630
realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
614631
if err != nil {
615632
return xerrors.Errorf("parse real ip config: %w", err)
@@ -640,6 +657,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
640657
CacheDir: cacheDir,
641658
GoogleTokenValidator: googleTokenValidator,
642659
ExternalAuthConfigs: externalAuthConfigs,
660+
LanguageModels: languageModels,
643661
RealIPConfig: realIPConfig,
644662
SSHKeygenAlgorithm: sshKeygenAlgorithm,
645663
TracerProvider: tracerProvider,
@@ -2655,6 +2673,29 @@ func ReadAIProvidersFromEnv(environ []string) ([]codersdk.AIProviderConfig, erro
26552673
}
26562674
providers[providerNum] = provider
26572675
}
2676+
for _, envVar := range environ {
2677+
tokens := strings.SplitN(envVar, "=", 2)
2678+
if len(tokens) != 2 {
2679+
continue
2680+
}
2681+
switch tokens[0] {
2682+
case "OPENAI_API_KEY":
2683+
providers = append(providers, codersdk.AIProviderConfig{
2684+
Type: "openai",
2685+
APIKey: tokens[1],
2686+
})
2687+
case "ANTHROPIC_API_KEY":
2688+
providers = append(providers, codersdk.AIProviderConfig{
2689+
Type: "anthropic",
2690+
APIKey: tokens[1],
2691+
})
2692+
case "GOOGLE_API_KEY":
2693+
providers = append(providers, codersdk.AIProviderConfig{
2694+
Type: "google",
2695+
APIKey: tokens[1],
2696+
})
2697+
}
2698+
}
26582699
return providers, nil
26592700
}
26602701

coderd/ai/ai.go

+133-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,140 @@ package ai
22

33
import (
44
"context"
5+
"fmt"
56

7+
"github.com/anthropics/anthropic-sdk-go"
8+
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
9+
"github.com/coder/coder/v2/codersdk"
610
"github.com/kylecarbs/aisdk-go"
11+
"github.com/openai/openai-go"
12+
openaioption "github.com/openai/openai-go/option"
13+
"google.golang.org/genai"
714
)
815

9-
type Provider func(ctx context.Context, messages []aisdk.Message) (aisdk.DataStream, error)
16+
type LanguageModel struct {
17+
codersdk.LanguageModel
18+
StreamFunc StreamFunc
19+
}
20+
21+
type StreamOptions struct {
22+
Model string
23+
Messages []aisdk.Message
24+
Thinking bool
25+
Tools []aisdk.Tool
26+
}
27+
28+
type StreamFunc func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error)
29+
30+
// LanguageModels is a map of language model ID to language model.
31+
type LanguageModels map[string]LanguageModel
32+
33+
func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig) (LanguageModels, error) {
34+
models := make(LanguageModels)
35+
36+
for _, config := range configs {
37+
var streamFunc StreamFunc
38+
39+
switch config.Type {
40+
case "openai":
41+
client := openai.NewClient(openaioption.WithAPIKey(config.APIKey))
42+
streamFunc = func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error) {
43+
openaiMessages, err := aisdk.MessagesToOpenAI(options.Messages)
44+
if err != nil {
45+
return nil, err
46+
}
47+
tools := aisdk.ToolsToOpenAI(options.Tools)
48+
return aisdk.OpenAIToDataStream(client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
49+
Messages: openaiMessages,
50+
Model: options.Model,
51+
Tools: tools,
52+
MaxTokens: openai.Int(8192),
53+
})), nil
54+
}
55+
if config.Models == nil {
56+
models, err := client.Models.List(ctx)
57+
if err != nil {
58+
return nil, err
59+
}
60+
config.Models = make([]string, len(models.Data))
61+
for i, model := range models.Data {
62+
config.Models[i] = model.ID
63+
}
64+
}
65+
break
66+
case "anthropic":
67+
client := anthropic.NewClient(anthropicoption.WithAPIKey(config.APIKey))
68+
streamFunc = func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error) {
69+
anthropicMessages, systemMessage, err := aisdk.MessagesToAnthropic(options.Messages)
70+
if err != nil {
71+
return nil, err
72+
}
73+
return aisdk.AnthropicToDataStream(client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
74+
Messages: anthropicMessages,
75+
Model: options.Model,
76+
System: systemMessage,
77+
Tools: aisdk.ToolsToAnthropic(options.Tools),
78+
MaxTokens: 8192,
79+
})), nil
80+
}
81+
if config.Models == nil {
82+
models, err := client.Models.List(ctx, anthropic.ModelListParams{})
83+
if err != nil {
84+
return nil, err
85+
}
86+
config.Models = make([]string, len(models.Data))
87+
for i, model := range models.Data {
88+
config.Models[i] = model.ID
89+
}
90+
}
91+
break
92+
case "google":
93+
client, err := genai.NewClient(ctx, &genai.ClientConfig{
94+
APIKey: config.APIKey,
95+
Backend: genai.BackendGeminiAPI,
96+
})
97+
if err != nil {
98+
return nil, err
99+
}
100+
streamFunc = func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error) {
101+
googleMessages, err := aisdk.MessagesToGoogle(options.Messages)
102+
if err != nil {
103+
return nil, err
104+
}
105+
tools, err := aisdk.ToolsToGoogle(options.Tools)
106+
if err != nil {
107+
return nil, err
108+
}
109+
return aisdk.GoogleToDataStream(client.Models.GenerateContentStream(ctx, options.Model, googleMessages, &genai.GenerateContentConfig{
110+
Tools: tools,
111+
})), nil
112+
}
113+
if config.Models == nil {
114+
models, err := client.Models.List(ctx, &genai.ListModelsConfig{})
115+
if err != nil {
116+
return nil, err
117+
}
118+
config.Models = make([]string, len(models.Items))
119+
for i, model := range models.Items {
120+
config.Models[i] = model.Name
121+
}
122+
}
123+
break
124+
default:
125+
return nil, fmt.Errorf("unsupported model type: %s", config.Type)
126+
}
127+
128+
for _, model := range config.Models {
129+
models[model] = LanguageModel{
130+
LanguageModel: codersdk.LanguageModel{
131+
ID: model,
132+
DisplayName: model,
133+
Provider: config.Type,
134+
},
135+
StreamFunc: streamFunc,
136+
}
137+
}
138+
}
139+
140+
return models, nil
141+
}

0 commit comments

Comments
 (0)