diff --git a/cmd/mcptools/commands/shell.go b/cmd/mcptools/commands/shell.go index e26eb1c..501cfbf 100644 --- a/cmd/mcptools/commands/shell.go +++ b/cmd/mcptools/commands/shell.go @@ -44,7 +44,7 @@ func ShellCmd() *cobra.Command { //nolint:gocyclo os.Exit(1) } - mcpClient, clientErr := CreateClientFunc(parsedArgs) + mcpClient, clientErr := CreateClientFunc(parsedArgs, client.CloseTransportAfterExecute(false)) // reuse one server process for all calls, if the transport supports it if clientErr != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", clientErr) os.Exit(1) diff --git a/cmd/mcptools/commands/test_helpers.go b/cmd/mcptools/commands/test_helpers.go index 3b27637..6cb19da 100644 --- a/cmd/mcptools/commands/test_helpers.go +++ b/cmd/mcptools/commands/test_helpers.go @@ -30,7 +30,7 @@ func setupMockClient(executeFunc func(method string, _ any) (map[string]any, err mockClient := client.NewWithTransport(mockTransport) // Override the function that creates clients - CreateClientFunc = func(_ []string) (*client.Client, error) { + CreateClientFunc = func(_ []string, _ ...client.Option) (*client.Client, error) { return mockClient, nil } diff --git a/cmd/mcptools/commands/utils.go b/cmd/mcptools/commands/utils.go index 4770757..bbf70a6 100644 --- a/cmd/mcptools/commands/utils.go +++ b/cmd/mcptools/commands/utils.go @@ -17,7 +17,7 @@ var ( // CreateClientFunc is the function used to create MCP clients. // This can be replaced in tests to use a mock transport. 
-var CreateClientFunc = func(args []string) (*client.Client, error) { +var CreateClientFunc = func(args []string, opts ...client.Option) (*client.Client, error) { if len(args) == 0 { return nil, ErrCommandRequired } @@ -42,7 +42,13 @@ var CreateClientFunc = func(args []string) (*client.Client, error) { return client.NewHTTP(args[0]), nil } - return client.NewStdio(args), nil + c := client.NewStdio(args) + + for _, opt := range opts { // NOTE(review): opts are applied only to stdio clients; the NewHTTP branch above returns without them + opt(c) + } + + return c, nil } // ProcessFlags processes command line flags, sets the format option, and returns the remaining diff --git a/pkg/client/client.go b/pkg/client/client.go index 41cc64c..e33f923 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -18,6 +18,21 @@ type Client struct { transport transport.Transport } +// Option is a functional option that adjusts the configuration of an +// already-constructed Client. +type Option func(*Client) + +// CloseTransportAfterExecute passes closeTransport to the transport's SetCloseAfterExecute method, controlling whether it is +// torn down after each Execute call; it is a no-op for transports that do not implement that method. +func CloseTransportAfterExecute(closeTransport bool) Option { + return func(c *Client) { + t, ok := c.transport.(interface{ SetCloseAfterExecute(bool) }) + if ok { + t.SetCloseAfterExecute(closeTransport) + } + } +} + // NewWithTransport creates a new MCP client using the provided transport. // This allows callers to provide a custom transport implementation. func NewWithTransport(t transport.Transport) *Client { diff --git a/pkg/transport/stdio.go b/pkg/transport/stdio.go index 5e0f82d..3e9553f 100644 --- a/pkg/transport/stdio.go +++ b/pkg/transport/stdio.go @@ -14,11 +14,20 @@ import ( // Stdio implements the Transport interface by executing a command // and communicating with it via stdin/stdout using JSON-RPC. 
type Stdio struct { + process *stdioProcess command []string nextID int debug bool } +// stdioProcess holds the stdin/stdout pipes, command handle, and stderr buffer of a spawned command; cmd is nil until Execute starts it. 
+type stdioProcess struct { + stdin io.WriteCloser + stdout io.ReadCloser + cmd *exec.Cmd + stderrBuf *bytes.Buffer +} + // NewStdio creates a new Stdio transport that will execute the given command. // It communicates with the command using JSON-RPC over stdin/stdout. func NewStdio(command []string) *Stdio { @@ -30,23 +39,41 @@ func NewStdio(command []string) *Stdio { } } +// SetCloseAfterExecute selects whether Execute spawns and closes a fresh process per call (true) or starts one process and reuses it across calls (false). +// NOTE(review): it replaces t.process without closing a process that is already running — confirm callers only set this before the first Execute. +func (t *Stdio) SetCloseAfterExecute(v bool) { + if v { + t.process = nil + } else { + t.process = &stdioProcess{} + } +} + // Execute implements the Transport interface by spawning a subprocess // and communicating with it via JSON-RPC over stdin/stdout. func (t *Stdio) Execute(method string, params any) (map[string]any, error) { - stdin, stdout, cmd, stderrBuf, err := t.setupCommand() - if err != nil { - return nil, err + process := t.process + if process == nil { // close-after-execute mode: use a per-call process + process = &stdioProcess{} + } + + if process.cmd == nil { // no process started yet: spawn one + var err error + process.stdin, process.stdout, process.cmd, process.stderrBuf, err = t.setupCommand() + if err != nil { + return nil, err + } } if t.debug { fmt.Fprintf(os.Stderr, "DEBUG: Starting initialization\n") } - if initErr := t.initialize(stdin, stdout); initErr != nil { + if initErr := t.initialize(process.stdin, process.stdout); initErr != nil { // NOTE(review): runs on every call, even when a kept-alive process was already initialized — confirm the server accepts a repeated initialize if t.debug { fmt.Fprintf(os.Stderr, "DEBUG: Initialization failed: %v\n", initErr) - if stderrBuf.Len() > 0 { - fmt.Fprintf(os.Stderr, "DEBUG: stderr during init: %s\n", stderrBuf.String()) + if process.stderrBuf.Len() > 0 { + fmt.Fprintf(os.Stderr, "DEBUG: stderr during init: %s\n", process.stderrBuf.String()) } } return nil, initErr @@ -64,43 +91,58 @@ func (t *Stdio) Execute(method string, params any) (map[string]any, error) { } t.nextID++ - if sendErr := t.sendRequest(stdin, request); sendErr != nil { + if sendErr := t.sendRequest(process.stdin, request); sendErr != nil { return nil, sendErr } 
- _ = stdin.Close() - response, err := t.readResponse(stdout) + response, err := t.readResponse(process.stdout) + if err != nil { // NOTE(review): this early return skips closeProcess, so a per-call process is never waited on or killed + return nil, err + } + + err = t.closeProcess(process) if err != nil { return nil, err } + return response.Result, nil +} + +// closeProcess closes the process's stdin and waits up to one second for it to exit (killing it on timeout); it returns an error wrapping the wait error and stderr when the command fails with non-empty stderr. It does nothing while t.process is set (kept-alive mode) — NOTE(review): no visible code ever shuts a kept-alive process down. +func (t *Stdio) closeProcess(process *stdioProcess) error { + if t.process != nil { // kept-alive mode: leave the process running for the next Execute + return nil + } + + _ = process.stdin.Close() + // Wait for the command to finish with a timeout to prevent zombie processes done := make(chan error, 1) go func() { - done <- cmd.Wait() + done <- process.cmd.Wait() }() select { case waitErr := <-done: if t.debug { fmt.Fprintf(os.Stderr, "DEBUG: Command completed with err: %v\n", waitErr) - if stderrBuf.Len() > 0 { - fmt.Fprintf(os.Stderr, "DEBUG: stderr output:\n%s\n", stderrBuf.String()) + if process.stderrBuf.Len() > 0 { + fmt.Fprintf(os.Stderr, "DEBUG: stderr output:\n%s\n", process.stderrBuf.String()) } } - if waitErr != nil && stderrBuf.Len() > 0 { - return nil, fmt.Errorf("command error: %w, stderr: %s", waitErr, stderrBuf.String()) + if waitErr != nil && process.stderrBuf.Len() > 0 { + return fmt.Errorf("command error: %w, stderr: %s", waitErr, process.stderrBuf.String()) } case <-time.After(1 * time.Second): if t.debug { fmt.Fprintf(os.Stderr, "DEBUG: Command timed out after 1 seconds\n") } // Kill the process if it times out - _ = cmd.Process.Kill() + _ = process.cmd.Process.Kill() } - return response.Result, nil + return nil } // setupCommand prepares and starts the command, returning the stdin/stdout pipes and any error.