@@ -2,6 +2,7 @@ package coderd
22
33import (
44 "encoding/json"
5+ "io"
56 "net/http"
67 "time"
78
@@ -12,11 +13,9 @@ import (
1213 "github.com/coder/coder/v2/coderd/httpapi"
1314 "github.com/coder/coder/v2/coderd/httpmw"
1415 "github.com/coder/coder/v2/codersdk"
15- codermcp "github.com/coder/coder/v2/mcp "
16+ "github.com/coder/coder/v2/codersdk/toolsdk "
1617 "github.com/google/uuid"
1718 "github.com/kylecarbs/aisdk-go"
18- "github.com/mark3labs/mcp-go/mcp"
19- "github.com/mark3labs/mcp-go/server"
2019)
2120
2221// postChats creates a new chat.
@@ -157,31 +156,17 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
157156 }
158157 messages = append (messages , req .Message )
159158
160- toolMap := codermcp .AllTools ()
161- toolsByName := make (map [string ]server.ToolHandlerFunc )
162159 client := codersdk .New (api .AccessURL )
163160 client .SetSessionToken (httpmw .APITokenFromRequest (r ))
164- toolDeps := codermcp.ToolDeps {
165- Client : client ,
166- Logger : & api .Logger ,
167- }
168- for _ , tool := range toolMap {
169- toolsByName [tool .Tool .Name ] = tool .MakeHandler (toolDeps )
170- }
171- convertedTools := make ([]aisdk.Tool , len (toolMap ))
172- for i , tool := range toolMap {
173- schema := aisdk.Schema {
174- Required : tool .Tool .InputSchema .Required ,
175- Properties : tool .Tool .InputSchema .Properties ,
176- }
177- if tool .Tool .InputSchema .Required == nil {
178- schema .Required = []string {}
179- }
180- convertedTools [i ] = aisdk.Tool {
181- Name : tool .Tool .Name ,
182- Description : tool .Tool .Description ,
183- Schema : schema ,
161+
162+ tools := make ([]aisdk.Tool , len (toolsdk .All ))
163+ handlers := map [string ]toolsdk.HandlerFunc [any ]{}
164+ for i , tool := range toolsdk .All {
165+ if tool .Tool .Schema .Required == nil {
166+ tool .Tool .Schema .Required = []string {}
184167 }
168+ tools [i ] = tool .Tool
169+ handlers [tool .Tool .Name ] = tool .Handler
185170 }
186171
187172 provider , ok := api .LanguageModels [req .Model ]
@@ -192,6 +177,43 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
192177 return
193178 }
194179
180+ // If it's the user's first message, generate a title for the chat.
181+ if len (messages ) == 1 {
182+ var acc aisdk.DataStreamAccumulator
183+ stream , err := provider .StreamFunc (ctx , ai.StreamOptions {
184+ Model : req .Model ,
185+ SystemPrompt : `- You will generate a short title based on the user's message.
186+ - It should be maximum of 40 characters.
187+ - Do not use quotes, colons, special characters, or emojis.` ,
188+ Messages : messages ,
189+ Tools : tools ,
190+ })
191+ if err != nil {
192+ httpapi .Write (ctx , w , http .StatusInternalServerError , codersdk.Response {
193+ Message : "Failed to create stream" ,
194+ Detail : err .Error (),
195+ })
196+ }
197+ stream = stream .WithAccumulator (& acc )
198+ err = stream .Pipe (io .Discard )
199+ if err != nil {
200+ httpapi .Write (ctx , w , http .StatusInternalServerError , codersdk.Response {
201+ Message : "Failed to pipe stream" ,
202+ Detail : err .Error (),
203+ })
204+ }
205+ err = api .Database .UpdateChatByID (ctx , database.UpdateChatByIDParams {
206+ ID : chat .ID ,
207+ Title : acc .Messages ()[0 ].Content ,
208+ })
209+ if err != nil {
210+ httpapi .Write (ctx , w , http .StatusInternalServerError , codersdk.Response {
211+ Message : "Failed to update chat title" ,
212+ Detail : err .Error (),
213+ })
214+ }
215+ }
216+
195217 // Write headers for the data stream!
196218 aisdk .WriteDataStreamHeaders (w )
197219
@@ -224,7 +246,11 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
224246 stream , err := provider .StreamFunc (ctx , ai.StreamOptions {
225247 Model : req .Model ,
226248 Messages : messages ,
227- Tools : convertedTools ,
249+ Tools : tools ,
250+ SystemPrompt : `You are a chat assistant for Coder. You will attempt to resolve the user's
251+ request to the maximum utilization of your tools.
252+
253+ Try your best to not ask the user for help - solve the task with your tools!` ,
228254 })
229255 if err != nil {
230256 httpapi .Write (ctx , w , http .StatusInternalServerError , codersdk.Response {
@@ -234,28 +260,17 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
234260 return
235261 }
236262 stream = stream .WithToolCalling (func (toolCall aisdk.ToolCall ) any {
237- tool , ok := toolsByName [toolCall .Name ]
263+ tool , ok := handlers [toolCall .Name ]
238264 if ! ok {
239265 return nil
240266 }
241- result , err := tool (ctx , mcp.CallToolRequest {
242- Params : struct {
243- Name string "json:\" name\" "
244- Arguments map [string ]interface {} "json:\" arguments,omitempty\" "
245- Meta * struct {
246- ProgressToken mcp.ProgressToken "json:\" progressToken,omitempty\" "
247- } "json:\" _meta,omitempty\" "
248- }{
249- Name : toolCall .Name ,
250- Arguments : toolCall .Args ,
251- },
252- })
267+ result , err := tool (toolsdk .WithClient (ctx , client ), toolCall .Args )
253268 if err != nil {
254269 return map [string ]any {
255270 "error" : err .Error (),
256271 }
257272 }
258- return result . Content
273+ return result
259274 }).WithAccumulator (& acc )
260275
261276 err = stream .Pipe (w )
0 commit comments