diff --git a/cmd/mcptools/commands/root.go b/cmd/mcptools/commands/root.go index fd5f2f9..d67dfc1 100644 --- a/cmd/mcptools/commands/root.go +++ b/cmd/mcptools/commands/root.go @@ -27,7 +27,8 @@ const ( var ( // FormatOption is the format option for the command, valid values are "table", "json", and // "pretty". - FormatOption string + // Default is "table". + FormatOption = "table" // ParamsString is the params for the command. ParamsString string ) diff --git a/cmd/mcptools/commands/shell.go b/cmd/mcptools/commands/shell.go index 4822638..e26eb1c 100644 --- a/cmd/mcptools/commands/shell.go +++ b/cmd/mcptools/commands/shell.go @@ -8,7 +8,7 @@ import ( "path/filepath" "strings" - "github.com/f/mcptools/pkg/jsonutils" + "github.com/f/mcptools/pkg/client" "github.com/peterh/liner" "github.com/spf13/cobra" ) @@ -29,15 +29,12 @@ func ShellCmd() *cobra.Command { //nolint:gocyclo cmdArgs := args parsedArgs := []string{} - i := 0 - for i < len(cmdArgs) { - switch { - case (cmdArgs[i] == FlagFormat || cmdArgs[i] == FlagFormatShort) && i+1 < len(cmdArgs): + for i := 0; i < len(cmdArgs); i++ { + if (cmdArgs[i] == FlagFormat || cmdArgs[i] == FlagFormatShort) && i+1 < len(cmdArgs) { FormatOption = cmdArgs[i+1] - i += 2 - default: - parsedArgs = append(parsedArgs, cmdArgs[i]) i++ + } else { + parsedArgs = append(parsedArgs, cmdArgs[i]) } } @@ -59,53 +56,22 @@ func ShellCmd() *cobra.Command { //nolint:gocyclo os.Exit(1) } - fmt.Printf("mcp > MCP Tools Shell (%s)\n", Version) - fmt.Println("mcp > Connected to Server:", strings.Join(parsedArgs, " ")) - fmt.Println("\nmcp > Type '/h' for help or '/q' to quit") + fmt.Fprintf(thisCmd.OutOrStdout(), "mcp > MCP Tools Shell (%s)\n", Version) + fmt.Fprintf(thisCmd.OutOrStdout(), "mcp > Connected to Server: %s\n", strings.Join(parsedArgs, " ")) + fmt.Fprintf(thisCmd.OutOrStdout(), "\nmcp > Type '/h' for help or '/q' to quit\n") line := liner.NewLiner() + line.SetCtrlCAborts(true) defer func() { _ = line.Close() }() - historyFile := filepath.Join(os.Getenv("HOME"), ".mcp_history") - if f, err := os.Open(filepath.Clean(historyFile)); err == nil { - _, _ = line.ReadHistory(f) - _ = f.Close() - } - - defer func() { - if f, err := os.Create(historyFile); err == nil { - _, _ = line.WriteHistory(f) - _ = f.Close() - } - }() - - line.SetCompleter(func(line string) (c []string) { - commands := []string{ - "tools", - "resources", - "prompts", - "call", - "format", - "help", - "exit", - "/h", - "/q", - "/help", - "/quit", - } - for _, cmd := range commands { - if strings.HasPrefix(cmd, line) { - c = append(c, cmd) - } - } - return - }) + defer setUpHistory(line)() + setUpCompleter(line) for { input, err := line.Prompt("mcp > ") if err != nil { if errors.Is(err, liner.ErrPromptAborted) { - fmt.Println("Exiting MCP shell") + fmt.Fprintln(thisCmd.OutOrStdout(), "Exiting MCP shell") break } fmt.Fprintf(os.Stderr, "Error reading input: %v\n", err) @@ -119,12 +85,12 @@ func ShellCmd() *cobra.Command { //nolint:gocyclo line.AppendHistory(input) if input == "/q" || input == "/quit" || input == "exit" { - fmt.Println("Exiting MCP shell") + fmt.Fprintln(thisCmd.OutOrStdout(), "Exiting MCP shell") break } if input == "/h" || input == "/help" || input == "help" { - printShellHelp() + printShellHelp(thisCmd) continue } @@ -137,211 +103,190 @@ func ShellCmd() *cobra.Command { //nolint:gocyclo commandArgs := parts[1:] var resp map[string]any - var respErr error + var listErr error switch command { case "tools": - resp, respErr = mcpClient.ListTools() - if respErr != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", respErr) - + resp, listErr = mcpClient.ListTools() + if formatErr := FormatAndPrintResponse(thisCmd, resp, listErr); formatErr != nil { + fmt.Fprintf(os.Stderr, "%v\n", formatErr) continue } - - output, formatErr := jsonutils.Format(resp, FormatOption) - if formatErr != nil { - fmt.Fprintf(os.Stderr, "Error formatting output: %v\n", formatErr) - - continue - } - - fmt.Println(output) - case "resources": - resp, respErr = mcpClient.ListResources() - if respErr != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", respErr) - - continue - } - - output, formatErr := jsonutils.Format(resp, FormatOption) - if formatErr != nil { - fmt.Fprintf(os.Stderr, "Error formatting output: %v\n", formatErr) - + resp, listErr = mcpClient.ListResources() + if formatErr := FormatAndPrintResponse(thisCmd, resp, listErr); formatErr != nil { + fmt.Fprintf(os.Stderr, "%v\n", formatErr) continue } - - fmt.Println(output) - case "prompts": - resp, respErr = mcpClient.ListPrompts() - if respErr != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", respErr) - + resp, listErr = mcpClient.ListPrompts() + if formatErr := FormatAndPrintResponse(thisCmd, resp, listErr); formatErr != nil { + fmt.Fprintf(os.Stderr, "%v\n", formatErr) continue } - - output, formatErr := jsonutils.Format(resp, FormatOption) - if formatErr != nil { - fmt.Fprintf(os.Stderr, "Error formatting output: %v\n", formatErr) - + case "format": + if len(commandArgs) < 1 { + fmt.Fprintf(thisCmd.OutOrStdout(), "Current format: %s\n", FormatOption) continue } - fmt.Println(output) - + oldFormat := FormatOption + defer func() { FormatOption = oldFormat }() + newFormat := commandArgs[0] + if IsValidFormat(newFormat) { + FormatOption = newFormat + fmt.Fprintf(thisCmd.OutOrStdout(), "Format set to: %s\n", FormatOption) + } else { + fmt.Fprintln(thisCmd.OutOrStdout(), "Invalid format. Use: table, json, or pretty") + } case "call": if len(commandArgs) < 1 { - fmt.Println("Usage: call [--params '{...}']") - + fmt.Fprintln(thisCmd.OutOrStdout(), "Usage: call [--params '{...}']") continue } - - entityName := commandArgs[0] - entityType := EntityTypeTool - - parts = strings.SplitN(entityName, ":", 2) - if len(parts) == 2 { - entityType = parts[0] - entityName = parts[1] - } - - params := map[string]any{} - for ii := 1; ii < len(commandArgs); ii++ { - if commandArgs[ii] == FlagParams || commandArgs[ii] == FlagParamsShort { - if ii+1 < len(commandArgs) { - if jsonErr := json.Unmarshal([]byte(commandArgs[ii+1]), ¶ms); jsonErr != nil { - fmt.Fprintf(os.Stderr, "Error: invalid JSON for params: %v\n", jsonErr) - - continue - } - break - } - } - } - - var execErr error - - switch entityType { - case EntityTypeTool: - resp, execErr = mcpClient.CallTool(entityName, params) - case EntityTypeRes: - resp, execErr = mcpClient.ReadResource(entityName) - case EntityTypePrompt: - resp, execErr = mcpClient.GetPrompt(entityName) - default: - fmt.Fprintf(os.Stderr, "Error: unsupported entity type: %s\n", entityType) + err := callCommand(thisCmd, mcpClient, commandArgs) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) continue } - - if execErr != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", execErr) + default: + if err := callCommand(thisCmd, mcpClient, append([]string{command}, commandArgs...)); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) continue } + } + } + }, + } +} - output, formatErr := jsonutils.Format(resp, FormatOption) - if formatErr != nil { - fmt.Fprintf(os.Stderr, "Error formatting output: %v\n", formatErr) - continue - } +func callCommand(thisCmd *cobra.Command, mcpClient *client.Client, commandArgs []string) error { + entityName := commandArgs[0] + entityType := EntityTypeTool + parts := strings.SplitN(entityName, ":", 2) + if len(parts) == 2 { + entityType = parts[0] + entityName = parts[1] + } - fmt.Println(output) + params := map[string]any{} + remainingArgs := []string{} + for i := 1; i < len(commandArgs); i++ { + switch commandArgs[i] { + case FlagParams, FlagParamsShort: + continue + case FlagFormat, FlagFormatShort: + if i+1 >= len(commandArgs) { + return fmt.Errorf("no format provided after %s", commandArgs[i]) + } + oldFormat := FormatOption + defer func() { FormatOption = oldFormat }() + newFormat := commandArgs[i+1] + if IsValidFormat(newFormat) { + FormatOption = newFormat + } else { + fmt.Fprintln(thisCmd.OutOrStdout(), "Invalid format. Use: table, json, or pretty") + } + i++ + default: + remainingArgs = append(remainingArgs, commandArgs[i]) + } + } - case "format": - if len(commandArgs) < 1 { - fmt.Printf("Current format: %s\n", FormatOption) - continue - } + if len(remainingArgs) > 0 { + if err := parseJSONBestEffort(strings.Join(remainingArgs, " "), ¶ms); err != nil { + return fmt.Errorf("invalid JSON for params: %w", err) + } + } - newFormat := commandArgs[0] - if newFormat == "json" || newFormat == "j" || - newFormat == "pretty" || newFormat == "p" || - newFormat == "table" || newFormat == "t" { - FormatOption = newFormat - fmt.Printf("Format set to: %s\n", FormatOption) - } else { - fmt.Println("Invalid format. Use: table, json, or pretty") - } + var resp map[string]any + var execErr error + + switch entityType { + case EntityTypeTool: + resp, execErr = mcpClient.CallTool(entityName, params) + case EntityTypeRes: + resp, execErr = mcpClient.ReadResource(entityName) + case EntityTypePrompt: + resp, execErr = mcpClient.GetPrompt(entityName) + default: + fmt.Fprintf(os.Stderr, "Error: unsupported entity type: %s\n", entityType) + } - default: - entityName := command - entityType := EntityTypeTool + if execErr != nil { + return execErr + } - parts = strings.SplitN(entityName, ":", 2) - if len(parts) == 2 { - entityType = parts[0] - entityName = parts[1] - } + formatErr := FormatAndPrintResponse(thisCmd, resp, nil) + if formatErr != nil { + return fmt.Errorf("error formatting output: %w", formatErr) + } - params := map[string]any{} - - if len(commandArgs) > 0 { - firstArg := commandArgs[0] - if strings.HasPrefix(firstArg, "{") && strings.HasSuffix(firstArg, "}") { - if jsonErr := json.Unmarshal([]byte(firstArg), ¶ms); jsonErr != nil { - fmt.Fprintf(os.Stderr, "Error: invalid JSON for params: %v\n", jsonErr) - continue - } - } else { - for iii := 0; iii < len(commandArgs); iii++ { - if commandArgs[iii] == FlagParams || commandArgs[iii] == FlagParamsShort { - if iii+1 < len(commandArgs) { - if jsonErr := json.Unmarshal([]byte(commandArgs[iii+1]), ¶ms); jsonErr != nil { - fmt.Fprintf(os.Stderr, "Error: invalid JSON for params: %v\n", jsonErr) - continue - } - break - } - } - } - } - } + return nil +} - var execErr error - - switch entityType { - case EntityTypeTool: - resp, execErr = mcpClient.CallTool(entityName, params) - case EntityTypeRes: - resp, execErr = mcpClient.ReadResource(entityName) - case EntityTypePrompt: - resp, execErr = mcpClient.GetPrompt(entityName) - default: - fmt.Printf("Unknown command: %s\nType '/h' for help\n", command) - continue - } +func parseJSONBestEffort(jsonString string, params *map[string]any) error { + jsonString = strings.Trim(jsonString, "'\"") + if jsonString == "" { + return nil + } + if err := json.Unmarshal([]byte(jsonString), ¶ms); err != nil { + return err + } + return nil +} - if execErr != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", execErr) - continue - } +func setUpHistory(line *liner.State) func() { + historyFile := filepath.Join(os.Getenv("HOME"), ".mcp_history") + if f, err := os.Open(filepath.Clean(historyFile)); err == nil { + _, _ = line.ReadHistory(f) + _ = f.Close() + } - output, formatErr := jsonutils.Format(resp, FormatOption) - if formatErr != nil { - fmt.Fprintf(os.Stderr, "Error formatting output: %v\n", formatErr) - continue - } + return func() { + if f, err := os.Create(historyFile); err == nil { + _, _ = line.WriteHistory(f) + _ = f.Close() + } + } +} - fmt.Println(output) - } +func setUpCompleter(line *liner.State) { + line.SetCompleter(func(line string) (c []string) { + commands := []string{ + "tools", + "resources", + "prompts", + "call", + "format", + "help", + "exit", + "/h", + "/q", + "/help", + "/quit", + } + for _, cmd := range commands { + if strings.HasPrefix(cmd, line) { + c = append(c, cmd) } - }, - } + } + return + }) } -func printShellHelp() { - fmt.Println("MCP Shell Commands:") - fmt.Println(" tools List available tools") - fmt.Println(" resources List available resources") - fmt.Println(" prompts List available prompts") - fmt.Println(" call [--params '{...}'] Call a tool, resource, or prompt") - fmt.Println(" format [json|pretty|table] Get or set output format") - fmt.Println("Direct Tool Calling:") - fmt.Println(" {\"param\": \"value\"} Call a tool directly with JSON parameters") - fmt.Println(" resource: Read a resource directly") - fmt.Println(" prompt: Get a prompt directly") - fmt.Println("Special Commands:") - fmt.Println(" /h, /help Show this help") - fmt.Println(" /q, /quit, exit Exit the shell") +func printShellHelp(thisCmd *cobra.Command) { + fmt.Fprintln(thisCmd.OutOrStdout(), "MCP Shell Commands:") + fmt.Fprintln(thisCmd.OutOrStdout(), " tools List available tools") + fmt.Fprintln(thisCmd.OutOrStdout(), " resources List available resources") + fmt.Fprintln(thisCmd.OutOrStdout(), " prompts List available prompts") + fmt.Fprintln(thisCmd.OutOrStdout(), " call [--params '{...}'] Call a tool, resource, or prompt") + fmt.Fprintln(thisCmd.OutOrStdout(), " format [json|pretty|table] Get or set output format") + fmt.Fprintln(thisCmd.OutOrStdout(), "Direct Tool Calling:") + fmt.Fprintln(thisCmd.OutOrStdout(), " {\"param\": \"value\"} Call a tool directly with JSON parameters") + fmt.Fprintln(thisCmd.OutOrStdout(), " resource: Read a resource directly") + fmt.Fprintln(thisCmd.OutOrStdout(), " prompt: Get a prompt directly") + fmt.Fprintln(thisCmd.OutOrStdout(), "Special Commands:") + fmt.Fprintln(thisCmd.OutOrStdout(), " /h, /help Show this help") + fmt.Fprintln(thisCmd.OutOrStdout(), " /q, /quit, exit Exit the shell") } diff --git a/cmd/mcptools/commands/shell_test.go b/cmd/mcptools/commands/shell_test.go new file mode 100644 index 0000000..9431dcc --- /dev/null +++ b/cmd/mcptools/commands/shell_test.go @@ -0,0 +1,366 @@ +package commands + +import ( + "bytes" + "os" + "reflect" + "strings" + "testing" + + "github.com/spf13/cobra" +) + +const toolsListMethod = "tools/list" + +// setupTestCommand creates a command with a buffer for output and simulated stdin. +// It returns the command, output buffer, and a cleanup function. +func setupTestCommand(t *testing.T, input string) (*cobra.Command, *bytes.Buffer, func()) { + cmd := ShellCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + + // Create a pipe to simulate stdin + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("Failed to create pipe: %v", err) + } + + // Save original stdin and restore it after test + oldStdin := os.Stdin + os.Stdin = r + + // This is just to prevent line.Prompt to print to stdout + oldStdout := os.Stdout + os.Stdout = nil + + // Write input to the pipe + go func() { + defer func() { + if err := w.Close(); err != nil { + t.Errorf("Failed to close pipe writer: %v", err) + } + }() + _, err := w.Write([]byte(input)) + if err != nil { + t.Errorf("Failed to write to pipe: %v", err) + } + }() + + cleanup := func() { + os.Stdin = oldStdin + os.Stdout = oldStdout + } + + return cmd, buf, cleanup +} + +func TestShellBasicCommands(t *testing.T) { + tests := []struct { //nolint:govet + mockResponses map[string]map[string]interface{} + name string + expectedOutputs []string + input string + }{ + { + name: "tools command", + input: "tools\n/q\n", + expectedOutputs: []string{"test-tool", "A test tool"}, + mockResponses: map[string]map[string]any{ + "tools/list": { + "tools": []any{ + map[string]any{ + "name": "test-tool", + "description": "A test tool", + }, + }, + }, + }, + }, + { + name: "prompts command", + input: "prompts\n/q\n", + expectedOutputs: []string{"test-prompt", "Test prompt description"}, + mockResponses: map[string]map[string]any{ + "prompts/list": { + "prompts": []any{ + map[string]any{ + "name": "test-prompt", + "description": "Test prompt description", + }, + }, + }, + }, + }, + { + name: "resources command", + input: "resources\n/q\n", + expectedOutputs: []string{"test_resource", "A test resource"}, + mockResponses: map[string]map[string]any{ + "resources/list": { + "resources": []any{ + map[string]any{ + "uri": "test_resource", + "description": "A test resource", + }, + }, + }, + }, + }, + { + name: "help command", + input: "/h\n/q\n", + expectedOutputs: []string{"MCP Shell Commands:"}, + mockResponses: map[string]map[string]any{}, + }, + { + name: "quit command with /q", + input: "/q\n", + expectedOutputs: []string{"Exiting MCP shell"}, + mockResponses: map[string]map[string]any{}, + }, + { + name: "quit command with exit", + input: "exit\n", + expectedOutputs: []string{"Exiting MCP shell"}, + mockResponses: map[string]map[string]any{}, + }, + { + name: "quit command with /quit", + input: "/quit\n", + expectedOutputs: []string{"Exiting MCP shell"}, + mockResponses: map[string]map[string]any{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd, buf, cleanupSetup := setupTestCommand(t, tt.input) + defer cleanupSetup() + + cleanupClient := setupMockClient(func(method string, _ any) (map[string]any, error) { + mockResponse, ok := tt.mockResponses[method] + if !ok { + // Tools list is always called to make sure the server is reachable. + if method == toolsListMethod { + return map[string]any{}, nil + } + t.Errorf("expected method %q, got %q", method, mockResponse) + } + return mockResponse, nil + }) + defer cleanupClient() + + err := cmd.Execute() + if err != nil { + t.Errorf("cmd.Execute() error = %v", err) + } + + output := buf.String() + for _, expectedOutput := range tt.expectedOutputs { + if !strings.Contains(output, expectedOutput) { + t.Errorf("Expected output to contain %q, got: \n%s", expectedOutput, output) + } + } + }) + } +} + +func TestShellCallCommand(t *testing.T) { + // 1. call a tool with `` + // 2. call a tool with `call ` + // 3. call a tool with ` ` + // 4. call a tool with `call ` + // 5. call a tool with ` ''` + // 6. call a tool with `call ''` + // 7. call a tool with ` --params ` + // 8. call a tool with `call --params ` + // 9. call a tool with ` --params ''` + // 10. call a tool with `call --params ''` + tests := []struct { + name string + mockResponses map[string]map[string]any + expectedParams map[string]any + input string + expectedOutputs []string + }{ + { + name: "tool_name without params", + input: "test-tool\n/q\n", + expectedOutputs: []string{"Tool executed successfully"}, + expectedParams: map[string]any{"name": "test-tool", "arguments": map[string]any{}}, + mockResponses: map[string]map[string]any{ + "tools/call": { + "content": []any{map[string]any{"type": "text", "text": "Tool executed successfully"}}, + }, + }, + }, + { + name: "call tool without params", + input: "call test-tool\n/q\n", + expectedOutputs: []string{"Tool executed successfully"}, + expectedParams: map[string]any{"name": "test-tool", "arguments": map[string]any{}}, + mockResponses: map[string]map[string]any{ + "tools/call": { + "content": []any{map[string]any{"type": "text", "text": "Tool executed successfully"}}, + }, + }, + }, + { + name: "tool_name with direct params", + input: "test-tool {\"foo\": \"bar\"}\n/q\n", + expectedOutputs: []string{"Tool executed successfully"}, + expectedParams: map[string]any{"name": "test-tool", "arguments": map[string]any{"foo": "bar"}}, + mockResponses: map[string]map[string]any{ + "tools/call": { + "content": []any{map[string]any{"type": "text", "text": "Tool executed successfully"}}, + }, + }, + }, + { + name: "call tool with direct params", + input: "call test-tool {\"foo\": \"bar\"}\n/q\n", + expectedOutputs: []string{"Tool executed successfully"}, + expectedParams: map[string]any{"name": "test-tool", "arguments": map[string]any{"foo": "bar"}}, + mockResponses: map[string]map[string]any{ + "tools/call": { + "content": []any{map[string]any{"type": "text", "text": "Tool executed successfully"}}, + }, + }, + }, + { + name: "tool_name with quoted direct params", + input: "test-tool '{\"foo\": \"bar\"}'\n/q\n", + expectedOutputs: []string{"Tool executed successfully"}, + expectedParams: map[string]any{"name": "test-tool", "arguments": map[string]any{"foo": "bar"}}, + mockResponses: map[string]map[string]any{ + "tools/call": { + "content": []any{map[string]any{"type": "text", "text": "Tool executed successfully"}}, + }, + }, + }, + { + name: "call tool with quoted direct params", + input: "call test-tool '{\"foo\": \"bar\"}'\n/q\n", + expectedOutputs: []string{"Tool executed successfully"}, + expectedParams: map[string]any{"name": "test-tool", "arguments": map[string]any{"foo": "bar"}}, + mockResponses: map[string]map[string]any{ + "tools/call": { + "content": []any{map[string]any{"type": "text", "text": "Tool executed successfully"}}, + }, + }, + }, + { + name: "tool_name with params flag", + input: "test-tool --params {\"foo\": \"bar\"}\n/q\n", + expectedOutputs: []string{"Tool executed successfully"}, + expectedParams: map[string]any{"name": "test-tool", "arguments": map[string]any{"foo": "bar"}}, + mockResponses: map[string]map[string]any{ + "tools/call": { + "content": []any{map[string]any{"type": "text", "text": "Tool executed successfully"}}, + }, + }, + }, + { + name: "call tool with params flag", + input: "call test-tool --params {\"foo\": \"bar\"}\n/q\n", + expectedOutputs: []string{"Tool executed successfully"}, + expectedParams: map[string]any{"name": "test-tool", "arguments": map[string]any{"foo": "bar"}}, + mockResponses: map[string]map[string]any{ + "tools/call": { + "content": []any{map[string]any{"type": "text", "text": "Tool executed successfully"}}, + }, + }, + }, + { + name: "tool_name with quoted params flag", + input: "test-tool --params '{\"foo\": \"bar\"}'\n/q\n", + expectedOutputs: []string{"Tool executed successfully"}, + expectedParams: map[string]any{"name": "test-tool", "arguments": map[string]any{"foo": "bar"}}, + mockResponses: map[string]map[string]any{ + "tools/call": { + "content": []any{map[string]any{"type": "text", "text": "Tool executed successfully"}}, + }, + }, + }, + { + name: "call tool with quoted params flag", + input: "call test-tool --params '{\"foo\": \"bar\"}'\n/q\n", + expectedOutputs: []string{"Tool executed successfully"}, + expectedParams: map[string]any{"name": "test-tool", "arguments": map[string]any{"foo": "bar"}}, + mockResponses: map[string]map[string]any{ + "tools/call": { + "content": []any{map[string]any{"type": "text", "text": "Tool executed successfully"}}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd, buf, cleanupSetup := setupTestCommand(t, tt.input) + defer cleanupSetup() + + cleanupClient := setupMockClient(func(method string, params any) (map[string]any, error) { + mockResponse, ok := tt.mockResponses[method] + if !ok { + // Tools list is always called to make sure the server is reachable. + if method == toolsListMethod { + return map[string]any{}, nil + } + t.Errorf("expected method %q, got %q", method, mockResponse) + } + if !reflect.DeepEqual(params, tt.expectedParams) { + t.Errorf("expected params %v, got %v", tt.expectedParams, params) + } + return mockResponse, nil + }) + defer cleanupClient() + + err := cmd.Execute() + if err != nil { + t.Errorf("cmd.Execute() error = %v", err) + } + + output := buf.String() + for _, expectedOutput := range tt.expectedOutputs { + if !strings.Contains(output, expectedOutput) { + t.Errorf("Expected output to contain %q, got: %s", expectedOutput, output) + } + } + }) + } +} + +func TestShellExit(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"quit with /q", "/q\n"}, + {"quit with /quit", "/quit\n"}, + {"quit with exit", "exit\n"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd, buf, cleanupSetup := setupTestCommand(t, tt.input) + defer cleanupSetup() + + cleanupClient := setupMockClient(func(method string, _ any) (map[string]any, error) { + if method != toolsListMethod { + t.Errorf("Expected method 'tools/list', got %q", method) + } + return map[string]any{}, nil + }) + defer cleanupClient() + + err := cmd.Execute() + if err != nil { + t.Errorf("cmd.Execute() error = %v", err) + } + + output := buf.String() + if !strings.Contains(output, "Exiting MCP shell") { + t.Errorf("Expected output to contain 'Exiting MCP shell', got: %s", output) + } + }) + } +} diff --git a/cmd/mcptools/commands/test_helpers.go b/cmd/mcptools/commands/test_helpers.go index 51e436f..3b27637 100644 --- a/cmd/mcptools/commands/test_helpers.go +++ b/cmd/mcptools/commands/test_helpers.go @@ -19,7 +19,7 @@ func (m *MockTransport) Execute(method string, params any) (map[string]any, erro } // setupMockClient creates a mock client with the given execute function and returns cleanup function. -func setupMockClient(executeFunc func(method string, params any) (map[string]any, error)) func() { +func setupMockClient(executeFunc func(method string, _ any) (map[string]any, error)) func() { // Save original function and restore later originalFunc := CreateClientFunc diff --git a/cmd/mcptools/commands/utils.go b/cmd/mcptools/commands/utils.go index 31bc537..4770757 100644 --- a/cmd/mcptools/commands/utils.go +++ b/cmd/mcptools/commands/utils.go @@ -84,3 +84,10 @@ func FormatAndPrintResponse(cmd *cobra.Command, resp map[string]any, err error) fmt.Fprintln(cmd.OutOrStdout(), output) return nil } + +// IsValidFormat returns true if the format is valid. +func IsValidFormat(format string) bool { + return format == "json" || format == "j" || + format == "pretty" || format == "p" || + format == "table" || format == "t" +}