Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/mcptools/commands/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func ShellCmd() *cobra.Command { //nolint:gocyclo
os.Exit(1)
}

mcpClient, clientErr := CreateClientFunc(parsedArgs)
mcpClient, clientErr := CreateClientFunc(parsedArgs, client.CloseTransportAfterExecute(false))
if clientErr != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", clientErr)
os.Exit(1)
Expand Down
2 changes: 1 addition & 1 deletion cmd/mcptools/commands/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func setupMockClient(executeFunc func(method string, _ any) (map[string]any, err
mockClient := client.NewWithTransport(mockTransport)

// Override the function that creates clients
CreateClientFunc = func(_ []string) (*client.Client, error) {
CreateClientFunc = func(_ []string, _ ...client.Option) (*client.Client, error) {
return mockClient, nil
}

Expand Down
10 changes: 8 additions & 2 deletions cmd/mcptools/commands/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ var (

// CreateClientFunc is the function used to create MCP clients.
// This can be replaced in tests to use a mock transport.
var CreateClientFunc = func(args []string) (*client.Client, error) {
var CreateClientFunc = func(args []string, opts ...client.Option) (*client.Client, error) {
if len(args) == 0 {
return nil, ErrCommandRequired
}
Expand All @@ -42,7 +42,13 @@ var CreateClientFunc = func(args []string) (*client.Client, error) {
return client.NewHTTP(args[0]), nil
}

return client.NewStdio(args), nil
c := client.NewStdio(args)

for _, opt := range opts {
opt(c)
}

return c, nil
}

// ProcessFlags processes command line flags, sets the format option, and returns the remaining
Expand Down
15 changes: 15 additions & 0 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,21 @@ type Client struct {
transport transport.Transport
}

// Option provides a way for passing options to the Client to change its
// configuration. An Option is a function that mutates the *Client it is
// applied to; see CloseTransportAfterExecute for an example.
type Option func(*Client)

// CloseTransportAfterExecute returns an Option that controls whether the
// client's transport is torn down after every Execute call. Passing false asks
// the transport to keep itself alive between calls.
//
// Only transports that implement SetCloseAfterExecute(bool) are affected; for
// any other transport the returned Option does nothing.
func CloseTransportAfterExecute(closeTransport bool) Option {
	// Local capability interface: the option applies only if the concrete
	// transport opts in by providing this method.
	type closeAfterExecuteSetter interface {
		SetCloseAfterExecute(bool)
	}

	return func(c *Client) {
		setter, supported := c.transport.(closeAfterExecuteSetter)
		if !supported {
			return
		}
		setter.SetCloseAfterExecute(closeTransport)
	}
}

// NewWithTransport creates a new MCP client using the provided transport.
// This allows callers to provide a custom transport implementation.
func NewWithTransport(t transport.Transport) *Client {
Expand Down
74 changes: 58 additions & 16 deletions pkg/transport/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,20 @@ import (
// Stdio implements the Transport interface by executing a command
// and communicating with it via stdin/stdout using JSON-RPC.
//
// While process is nil, each Execute call spawns a fresh process and closes it
// afterwards. After SetCloseAfterExecute(false), a single process is started
// lazily on the first Execute and reused by later calls.
type Stdio struct {
	// process holds the kept-alive subprocess state. nil means Execute spawns
	// and tears down a new process on every call.
	process *stdioProcess
	// command is the command line to execute (NOTE(review): presumably the
	// program followed by its arguments; confirm against setupCommand).
	command []string
	// nextID is the JSON-RPC request id counter; Execute increments it after
	// building each request.
	nextID int
	// debug enables the "DEBUG:" diagnostics written to os.Stderr.
	debug bool
}

// stdioProcess reflects the state of a running command.
// A zero value (cmd == nil) means the command has not been started yet;
// Stdio.Execute starts it via setupCommand on first use.
type stdioProcess struct {
	// stdin is the pipe to the child's standard input; requests are written here.
	stdin io.WriteCloser
	// stdout is the pipe from the child's standard output; responses are read here.
	stdout io.ReadCloser
	// cmd is the started command; nil until setupCommand has run.
	cmd *exec.Cmd
	// stderrBuf holds the child's stderr as captured by setupCommand; it is
	// read for debug output and for error messages when the command fails.
	stderrBuf *bytes.Buffer
}

// NewStdio creates a new Stdio transport that will execute the given command.
// It communicates with the command using JSON-RPC over stdin/stdout.
func NewStdio(command []string) *Stdio {
Expand All @@ -30,23 +39,41 @@ func NewStdio(command []string) *Stdio {
}
}

// SetCloseAfterExecute toggles whether the underlying process should be closed
// or kept alive after each call to Execute.
//
// With v == false the transport switches to keep-alive mode. A process holder
// is installed, the first Execute call starts the command, and later calls
// reuse it. Calling it again with false keeps the existing holder, so a
// process that is already running is not orphaned.
//
// With v == true the transport switches back to one-process-per-call mode. If
// a kept-alive process has been started, it is shut down via closeProcess
// instead of being dropped with nobody waiting on it.
func (t *Stdio) SetCloseAfterExecute(v bool) {
	if !v {
		// Idempotent: never replace a holder that may own a live process.
		if t.process == nil {
			t.process = &stdioProcess{}
		}
		return
	}

	running := t.process
	// Clear the holder first: closeProcess is a no-op while t.process is set.
	t.process = nil
	if running != nil && running.cmd != nil {
		// Best-effort shutdown; this setter has no way to report the error.
		_ = t.closeProcess(running)
	}
}

// Execute implements the Transport interface by spawning a subprocess
// and communicating with it via JSON-RPC over stdin/stdout.
func (t *Stdio) Execute(method string, params any) (map[string]any, error) {
stdin, stdout, cmd, stderrBuf, err := t.setupCommand()
if err != nil {
return nil, err
process := t.process
if process == nil {
process = &stdioProcess{}
}

if process.cmd == nil {
var err error
process.stdin, process.stdout, process.cmd, process.stderrBuf, err = t.setupCommand()
if err != nil {
return nil, err
}
}

if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Starting initialization\n")
}

if initErr := t.initialize(stdin, stdout); initErr != nil {
if initErr := t.initialize(process.stdin, process.stdout); initErr != nil {
if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Initialization failed: %v\n", initErr)
if stderrBuf.Len() > 0 {
fmt.Fprintf(os.Stderr, "DEBUG: stderr during init: %s\n", stderrBuf.String())
if process.stderrBuf.Len() > 0 {
fmt.Fprintf(os.Stderr, "DEBUG: stderr during init: %s\n", process.stderrBuf.String())
}
}
return nil, initErr
Expand All @@ -64,43 +91,58 @@ func (t *Stdio) Execute(method string, params any) (map[string]any, error) {
}
t.nextID++

if sendErr := t.sendRequest(stdin, request); sendErr != nil {
if sendErr := t.sendRequest(process.stdin, request); sendErr != nil {
return nil, sendErr
}
_ = stdin.Close()

response, err := t.readResponse(stdout)
response, err := t.readResponse(process.stdout)
if err != nil {
return nil, err
}

err = t.closeProcess(process)
if err != nil {
return nil, err
}

return response.Result, nil
}

// closeProcess waits for the command to finish, returning any error.
func (t *Stdio) closeProcess(process *stdioProcess) error {
if t.process != nil {
return nil
}

_ = process.stdin.Close()

// Wait for the command to finish with a timeout to prevent zombie processes
done := make(chan error, 1)
go func() {
done <- cmd.Wait()
done <- process.cmd.Wait()
}()

select {
case waitErr := <-done:
if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Command completed with err: %v\n", waitErr)
if stderrBuf.Len() > 0 {
fmt.Fprintf(os.Stderr, "DEBUG: stderr output:\n%s\n", stderrBuf.String())
if process.stderrBuf.Len() > 0 {
fmt.Fprintf(os.Stderr, "DEBUG: stderr output:\n%s\n", process.stderrBuf.String())
}
}

if waitErr != nil && stderrBuf.Len() > 0 {
return nil, fmt.Errorf("command error: %w, stderr: %s", waitErr, stderrBuf.String())
if waitErr != nil && process.stderrBuf.Len() > 0 {
return fmt.Errorf("command error: %w, stderr: %s", waitErr, process.stderrBuf.String())
}
case <-time.After(1 * time.Second):
if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Command timed out after 1 seconds\n")
}
// Kill the process if it times out
_ = cmd.Process.Kill()
_ = process.cmd.Process.Kill()
}

return response.Result, nil
return nil
}

// setupCommand prepares and starts the command, returning the stdin/stdout pipes and any error.
Expand Down