Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 567d395

Browse files
committed
Add more MCP stuff
1 parent 0c07739 commit 567d395

File tree

13 files changed

+1369
-48
lines changed

13 files changed

+1369
-48
lines changed

coderd/ai/ai.go

+27-5
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ type LanguageModel struct {
1919
}
2020

2121
type StreamOptions struct {
22-
Model string
23-
Messages []aisdk.Message
24-
Thinking bool
25-
Tools []aisdk.Tool
22+
SystemPrompt string
23+
Model string
24+
Messages []aisdk.Message
25+
Thinking bool
26+
Tools []aisdk.Tool
2627
}
2728

2829
type StreamFunc func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error)
@@ -45,6 +46,12 @@ func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig)
4546
return nil, err
4647
}
4748
tools := aisdk.ToolsToOpenAI(options.Tools)
49+
if options.SystemPrompt != "" {
50+
openaiMessages = append([]openai.ChatCompletionMessageParamUnion{
51+
openai.SystemMessage(options.SystemPrompt),
52+
}, openaiMessages...)
53+
}
54+
4855
return aisdk.OpenAIToDataStream(client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
4956
Messages: openaiMessages,
5057
Model: options.Model,
@@ -70,6 +77,11 @@ func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig)
7077
if err != nil {
7178
return nil, err
7279
}
80+
if options.SystemPrompt != "" {
81+
systemMessage = []anthropic.TextBlockParam{
82+
*anthropic.NewTextBlock(options.SystemPrompt).OfRequestTextBlock,
83+
}
84+
}
7385
return aisdk.AnthropicToDataStream(client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
7486
Messages: anthropicMessages,
7587
Model: options.Model,
@@ -106,8 +118,18 @@ func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig)
106118
if err != nil {
107119
return nil, err
108120
}
121+
var systemInstruction *genai.Content
122+
if options.SystemPrompt != "" {
123+
systemInstruction = &genai.Content{
124+
Parts: []*genai.Part{
125+
genai.NewPartFromText(options.SystemPrompt),
126+
},
127+
Role: "model",
128+
}
129+
}
109130
return aisdk.GoogleToDataStream(client.Models.GenerateContentStream(ctx, options.Model, googleMessages, &genai.GenerateContentConfig{
110-
Tools: tools,
131+
SystemInstruction: systemInstruction,
132+
Tools: tools,
111133
})), nil
112134
}
113135
if config.Models == nil {

coderd/chat.go

+55-40
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package coderd
22

33
import (
44
"encoding/json"
5+
"io"
56
"net/http"
67
"time"
78

@@ -12,11 +13,9 @@ import (
1213
"github.com/coder/coder/v2/coderd/httpapi"
1314
"github.com/coder/coder/v2/coderd/httpmw"
1415
"github.com/coder/coder/v2/codersdk"
15-
codermcp "github.com/coder/coder/v2/mcp"
16+
"github.com/coder/coder/v2/codersdk/toolsdk"
1617
"github.com/google/uuid"
1718
"github.com/kylecarbs/aisdk-go"
18-
"github.com/mark3labs/mcp-go/mcp"
19-
"github.com/mark3labs/mcp-go/server"
2019
)
2120

2221
// postChats creates a new chat.
@@ -157,31 +156,17 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
157156
}
158157
messages = append(messages, req.Message)
159158

160-
toolMap := codermcp.AllTools()
161-
toolsByName := make(map[string]server.ToolHandlerFunc)
162159
client := codersdk.New(api.AccessURL)
163160
client.SetSessionToken(httpmw.APITokenFromRequest(r))
164-
toolDeps := codermcp.ToolDeps{
165-
Client: client,
166-
Logger: &api.Logger,
167-
}
168-
for _, tool := range toolMap {
169-
toolsByName[tool.Tool.Name] = tool.MakeHandler(toolDeps)
170-
}
171-
convertedTools := make([]aisdk.Tool, len(toolMap))
172-
for i, tool := range toolMap {
173-
schema := aisdk.Schema{
174-
Required: tool.Tool.InputSchema.Required,
175-
Properties: tool.Tool.InputSchema.Properties,
176-
}
177-
if tool.Tool.InputSchema.Required == nil {
178-
schema.Required = []string{}
179-
}
180-
convertedTools[i] = aisdk.Tool{
181-
Name: tool.Tool.Name,
182-
Description: tool.Tool.Description,
183-
Schema: schema,
161+
162+
tools := make([]aisdk.Tool, len(toolsdk.All))
163+
handlers := map[string]toolsdk.HandlerFunc[any]{}
164+
for i, tool := range toolsdk.All {
165+
if tool.Tool.Schema.Required == nil {
166+
tool.Tool.Schema.Required = []string{}
184167
}
168+
tools[i] = tool.Tool
169+
handlers[tool.Tool.Name] = tool.Handler
185170
}
186171

187172
provider, ok := api.LanguageModels[req.Model]
@@ -192,6 +177,43 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
192177
return
193178
}
194179

180+
// If it's the user's first message, generate a title for the chat.
181+
if len(messages) == 1 {
182+
var acc aisdk.DataStreamAccumulator
183+
stream, err := provider.StreamFunc(ctx, ai.StreamOptions{
184+
Model: req.Model,
185+
SystemPrompt: `- You will generate a short title based on the user's message.
186+
- It should be maximum of 40 characters.
187+
- Do not use quotes, colons, special characters, or emojis.`,
188+
Messages: messages,
189+
Tools: tools,
190+
})
191+
if err != nil {
192+
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
193+
Message: "Failed to create stream",
194+
Detail: err.Error(),
195+
})
196+
}
197+
stream = stream.WithAccumulator(&acc)
198+
err = stream.Pipe(io.Discard)
199+
if err != nil {
200+
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
201+
Message: "Failed to pipe stream",
202+
Detail: err.Error(),
203+
})
204+
}
205+
err = api.Database.UpdateChatByID(ctx, database.UpdateChatByIDParams{
206+
ID: chat.ID,
207+
Title: acc.Messages()[0].Content,
208+
})
209+
if err != nil {
210+
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
211+
Message: "Failed to update chat title",
212+
Detail: err.Error(),
213+
})
214+
}
215+
}
216+
195217
// Write headers for the data stream!
196218
aisdk.WriteDataStreamHeaders(w)
197219

@@ -224,7 +246,11 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
224246
stream, err := provider.StreamFunc(ctx, ai.StreamOptions{
225247
Model: req.Model,
226248
Messages: messages,
227-
Tools: convertedTools,
249+
Tools: tools,
250+
SystemPrompt: `You are a chat assistant for Coder. You will attempt to resolve the user's
251+
request to the maximum utilization of your tools.
252+
253+
Try your best to not ask the user for help - solve the task with your tools!`,
228254
})
229255
if err != nil {
230256
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
@@ -234,28 +260,17 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
234260
return
235261
}
236262
stream = stream.WithToolCalling(func(toolCall aisdk.ToolCall) any {
237-
tool, ok := toolsByName[toolCall.Name]
263+
tool, ok := handlers[toolCall.Name]
238264
if !ok {
239265
return nil
240266
}
241-
result, err := tool(ctx, mcp.CallToolRequest{
242-
Params: struct {
243-
Name string "json:\"name\""
244-
Arguments map[string]interface{} "json:\"arguments,omitempty\""
245-
Meta *struct {
246-
ProgressToken mcp.ProgressToken "json:\"progressToken,omitempty\""
247-
} "json:\"_meta,omitempty\""
248-
}{
249-
Name: toolCall.Name,
250-
Arguments: toolCall.Args,
251-
},
252-
})
267+
result, err := tool(toolsdk.WithClient(ctx, client), toolCall.Args)
253268
if err != nil {
254269
return map[string]any{
255270
"error": err.Error(),
256271
}
257272
}
258-
return result.Content
273+
return result
259274
}).WithAccumulator(&acc)
260275

261276
err = stream.Pipe(w)

coderd/database/dbauthz/dbauthz.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -3993,7 +3993,10 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
39933993
}
39943994

39953995
func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) error {
3996-
panic("not implemented")
3996+
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat.WithID(arg.ID)); err != nil {
3997+
return err
3998+
}
3999+
return q.db.UpdateChatByID(ctx, arg)
39974000
}
39984001

39994002
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {

0 commit comments

Comments
 (0)