diff --git a/README.md b/README.md index 2212567..cd84fbc 100644 --- a/README.md +++ b/README.md @@ -223,6 +223,100 @@ mcp mock prompt greeting "Greeting template" "Hello {{name}}! Welcome to {{locat When a client requests the prompt, it can provide values for these arguments which will be substituted in the response. +## Proxy Mode + +The proxy mode allows you to register shell scripts as MCP tools and proxy MCP requests to them. +The scripts will receive tool parameters as environment variables. + +### Registering a Tool + +To register a shell script as an MCP tool: + +```bash +mcp proxy tool add_operation "Adds a and b" "a:int,b:int" ./examples/add.sh +``` + +This registers a tool named `add_operation` with the description "Adds a and b" and parameters `a` and `b` of type `int`. +When the tool is called, the parameters will be passed as environment variables to the script. + +Parameters are specified in the format `name:type,name:type,...` where `type` can be `string`, `int`, `float`, or `bool`. + +### Starting the Proxy Server + +To start a proxy server with the registered tools: + +```bash +mcp proxy start +``` + +The server will run in stdio mode compatible with the MCP protocol and forward tool call requests to the registered shell scripts. + +### Example Shell Scripts + +#### Adding Numbers + +```bash +#!/bin/bash + +# Get the values from environment variables +if [ -z "$a" ] || [ -z "$b" ]; then + echo "Error: Missing required parameters 'a' or 'b'" + exit 1 +fi + +# Try to convert to integers +a_val=$(($a)) +b_val=$(($b)) + +# Perform the addition +result=$(($a_val + $b_val)) + +# Return the result +echo "The sum of $a and $b is $result" +``` + +#### Customized Greeting + +```bash +#!/bin/bash + +# Get the values from environment variables +if [ -z "$name" ]; then + echo "Error: Missing required parameter 'name'" + exit 1 +fi + +# Set default values if not provided +if [ -z "$greeting" ]; then + greeting="Hello" +fi + +if [ -z "$formal" ]; then + formal=false +fi + +# Customize greeting based on formal flag +if [ "$formal" = "true" ]; then + title="Mr./Ms." + message="${greeting}, ${title} ${name}. How may I assist you today?" +else + message="${greeting}, ${name}! Nice to meet you!" +fi + +# Return the greeting +echo "$message" +``` + +Register with: + +```bash +mcp proxy tool greet "Greets a user" "name:string,greeting:string,formal:bool" ./examples/greet.sh +``` + +### Configuration + +Tools are registered in `~/.mcpt/proxy_config.json`. The proxy server logs all requests and responses to `~/.mcpt/logs/proxy.log`. + ## Examples List tools from a filesystem server: @@ -252,6 +346,47 @@ mcp mock tool file_reader "Reads files" \ resource docs://api "API Documentation" "# API Reference\n\nThis document describes the API." ``` +Using the proxy mode with a simple shell script: + +```bash +# 1. Create a simple shell script for addition +cat > add.sh << 'EOF' +#!/bin/bash +# Get values from environment variables +if [ -z "$a" ] || [ -z "$b" ]; then + echo "Error: Missing required parameters 'a' or 'b'" + exit 1 +fi +result=$(($a + $b)) +echo "The sum of $a and $b is $result" +EOF + +# 2. Make it executable +chmod +x add.sh + +# 3. Register it as an MCP tool +mcp proxy tool add_numbers "Adds two numbers" "a:int,b:int" ./add.sh + +# 4. In one terminal, start the proxy server +mcp proxy start + +# 5. In another terminal, you can call it as an MCP tool +mcp call add_numbers --params '{"a":5,"b":3}' --format pretty +``` + +Tailing the logs to debug your proxy or mock server: + +```bash +# For the mock server logs +tail -f ~/.mcpt/logs/mock.log + +# For the proxy server logs +tail -f ~/.mcpt/logs/proxy.log + +# To watch all logs in real-time (on macOS/Linux) +find ~/.mcpt/logs -name "*.log" -exec tail -f {} \; +``` + ## Contributing We welcome contributions! Please see our [Contributing Guidelines](CONTRIBUTING.md) diff --git a/cmd/mcptools/main.go b/cmd/mcptools/main.go index 2dc1434..b229c06 100644 --- a/cmd/mcptools/main.go +++ b/cmd/mcptools/main.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "os" "path/filepath" "strings" @@ -14,8 +15,10 @@ import ( "github.com/f/mcptools/pkg/client" "github.com/f/mcptools/pkg/jsonutils" "github.com/f/mcptools/pkg/mock" + "github.com/f/mcptools/pkg/proxy" "github.com/peterh/liner" "github.com/spf13/cobra" + "github.com/spf13/viper" ) // version information placeholders. @@ -61,6 +64,7 @@ func main() { newReadResourceCmd(), newShellCmd(), newMockCmd(), + proxyCmd(), ) if err := rootCmd.Execute(); err != nil { @@ -939,3 +943,137 @@ Example: return cmd } + +func proxyCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "proxy", + Short: "Proxy MCP tool requests to shell scripts", + Long: `Proxy MCP tool requests to shell scripts. + +This command allows you to register shell scripts as MCP tools and proxy MCP requests to them. +The scripts will receive tool parameters as environment variables. + +Examples: + # Register a shell script as an MCP tool + mcp proxy tool add_operation "Adds a and b" "a:int,b:int" ./add.sh + + # Start a proxy server with the registered tools + mcp proxy start`, + } + + cmd.AddCommand(proxyToolCmd()) + cmd.AddCommand(proxyStartCmd()) + + return cmd +} + +func proxyToolCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "tool NAME DESCRIPTION PARAMETERS SCRIPT_PATH", + Short: "Register a shell script as an MCP tool", + Long: `Register a shell script as an MCP tool. + +The PARAMETERS argument should be a comma-separated list of "name:type" pairs. +Supported types: string, int, float, bool + +Example: + mcp proxy tool add_operation "Adds a and b" "a:int,b:int" ./add.sh`, + Args: cobra.ExactArgs(4), + Run: func(_ *cobra.Command, args []string) { + name := args[0] + description := args[1] + parameters := args[2] + scriptPath := args[3] + + // Initialize config + viper.SetConfigName("proxy_config") + viper.SetConfigType("json") + viper.AddConfigPath("$HOME/.mcpt") + + // Create config directory if it doesn't exist + configDir := os.ExpandEnv("$HOME/.mcpt") + if err := os.MkdirAll(configDir, 0o750); err != nil { + log.Fatalf("Error creating config directory: %v", err) + } + + // Load existing config if it exists + var config map[string]map[string]string + var configFileNotFound viper.ConfigFileNotFoundError + err := viper.ReadInConfig() + if err != nil { + if errors.As(err, &configFileNotFound) { + // Config file not found, create a new one + config = make(map[string]map[string]string) + } else { + log.Fatalf("Error reading config: %v", err) + } + } else { + // Config file found, unmarshal it + config = make(map[string]map[string]string) + unmarshalErr := viper.Unmarshal(&config) + if unmarshalErr != nil { + log.Fatalf("Error unmarshaling config: %v", unmarshalErr) + } + } + + // Add or update tool config + config[name] = map[string]string{ + "description": description, + "parameters": parameters, + "script": scriptPath, + } + + // Save config + configPath := os.ExpandEnv("$HOME/.mcpt/proxy_config.json") + configJSON, err := json.MarshalIndent(config, "", " ") + if err != nil { + log.Fatalf("Error marshaling config: %v", err) + } + + writeErr := os.WriteFile(configPath, configJSON, 0o600) + if writeErr != nil { + log.Fatalf("Error writing config: %v", writeErr) + } + + fmt.Printf("Registered tool '%s' with script '%s'\n", name, scriptPath) + }, + } + + return cmd +} + +func proxyStartCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "start", + Short: "Start a proxy server with registered tools", + Long: `Start a proxy server that forwards MCP tool requests to shell scripts. + +The server reads tool configurations from $HOME/.mcpt/proxy_config.json. + +Example: + mcp proxy start`, + Run: func(_ *cobra.Command, _ []string) { + // Load tool configurations + viper.SetConfigName("proxy_config") + viper.SetConfigType("json") + viper.AddConfigPath("$HOME/.mcpt") + + if err := viper.ReadInConfig(); err != nil { + log.Fatalf("Error reading config: %v", err) + } + + var config map[string]map[string]string + if err := viper.Unmarshal(&config); err != nil { + log.Fatalf("Error unmarshaling config: %v", err) + } + + // Run proxy server + fmt.Println("Starting proxy server...") + if err := proxy.RunProxyServer(config); err != nil { + log.Fatalf("Error running proxy server: %v", err) + } + }, + } + + return cmd +} diff --git a/examples/add.sh b/examples/add.sh new file mode 100755 index 0000000..c2f5287 --- /dev/null +++ b/examples/add.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# Get the values from environment variables +if [ -z "$a" ] || [ -z "$b" ]; then + echo "Error: Missing required parameters 'a' or 'b'" + exit 1 +fi + +# Try to convert to integers +a_val=$(($a)) +b_val=$(($b)) + +# Perform the addition +result=$(($a_val + $b_val)) + +# Return the result +echo "The sum of $a and $b is $result" \ No newline at end of file diff --git a/examples/greet.sh b/examples/greet.sh new file mode 100755 index 0000000..0a8ab95 --- /dev/null +++ b/examples/greet.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Get the values from environment variables +if [ -z "$name" ]; then + echo "Error: Missing required parameter 'name'" + exit 1 +fi + +# Set default values if not provided +if [ -z "$greeting" ]; then + greeting="Hello" +fi + +if [ -z "$formal" ]; then + formal=false +fi + +# Customize greeting based on formal flag +if [ "$formal" = "true" ]; then + title="Mr./Ms." + message="${greeting}, ${title} ${name}. How may I assist you today?" +else + message="${greeting}, ${name}! Nice to meet you!" +fi + +# Return the greeting +echo "$message" \ No newline at end of file diff --git a/go.mod b/go.mod index 1fa01d9..1dd6431 100644 --- a/go.mod +++ b/go.mod @@ -9,16 +9,28 @@ require ( ) require ( + github.com/fsnotify/fsnotify v1.8.0 // indirect + github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/sagikazarmark/locafero v0.7.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.12.0 // indirect + github.com/spf13/cast v1.7.1 // indirect github.com/spf13/pflag v1.0.6 // indirect + github.com/spf13/viper v1.20.1 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.9.0 // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/sync v0.12.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/term v0.30.0 // indirect golang.org/x/text v0.23.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect mvdan.cc/gofumpt v0.7.0 // indirect ) diff --git a/go.sum b/go.sum index dc618d1..fa7d8bd 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,11 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= +github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= +github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -10,6 +15,8 @@ github.com/jedib0t/go-pretty/v6 v6.6.7/go.mod h1:YwC5CE4fJ1HFUDeivSV1r//AmANFHyq github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= +github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/peterh/liner v1.2.2 h1:aJ4AOodmL+JxOZZEL2u9iJf8omNRpqHc/EbrK+3mAXw= github.com/peterh/liner v1.2.2/go.mod h1:xFwJyiKIXJZUKItq5dGHZSTBRAuG/CpeNpWLyiNRNwI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -18,12 +25,30 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo= +github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= +github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4= +github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go new file mode 100644 index 0000000..20de999 --- /dev/null +++ b/pkg/proxy/proxy.go @@ -0,0 +1,571 @@ +// Package proxy provides functionality for proxying MCP tool requests to shell scripts. +package proxy + +import ( + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "time" +) + +// Parameter represents a tool parameter with a name and type. +type Parameter struct { + Name string + Type string +} + +// Tool represents a proxy tool that executes a shell script. +type Tool struct { + // Fields ordered for optimal memory alignment (8-byte aligned fields first) + ScriptPath string + Name string + Description string + Parameters []Parameter +} + +// Server handles proxying requests to shell scripts. +type Server struct { + // Fields ordered for optimal memory alignment (8-byte aligned fields first) + tools map[string]Tool + logFile *os.File + id int +} + +// NewProxyServer creates a new proxy server. +func NewProxyServer() (*Server, error) { + // Create log directory + homeDir := os.Getenv("HOME") + if homeDir == "" { + return nil, fmt.Errorf("HOME environment variable not set") + } + + logDir := filepath.Join(homeDir, ".mcpt", "logs") + if err := os.MkdirAll(logDir, 0o750); err != nil { + return nil, fmt.Errorf("error creating log directory: %w", err) + } + + // Open log file + logPath := filepath.Join(logDir, "proxy.log") + // Clean the path to avoid any path traversal + logPath = filepath.Clean(logPath) + + // Verify the path is still under the expected log directory + if !strings.HasPrefix(logPath, logDir) { + return nil, fmt.Errorf("invalid log path: outside of log directory") + } + + logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600) + if err != nil { + return nil, fmt.Errorf("error opening log file: %w", err) + } + + fmt.Fprintf(os.Stderr, "Logging to %s\n", logPath) + + return &Server{ + tools: make(map[string]Tool), + id: 0, + logFile: logFile, + }, nil +} + +// log writes a message to the log file with a timestamp. +func (s *Server) log(message string) { + timestamp := time.Now().Format(time.RFC3339) + fmt.Fprintf(s.logFile, "[%s] %s\n", timestamp, message) +} + +// logJSON writes a JSON-formatted message to the log file with a timestamp. +func (s *Server) logJSON(label string, v any) { + jsonBytes, err := json.MarshalIndent(v, "", " ") + if err != nil { + s.log(fmt.Sprintf("Error marshaling %s: %v", label, err)) + return + } + s.log(fmt.Sprintf("%s: %s", label, string(jsonBytes))) +} + +// Close closes the log file. +func (s *Server) Close() error { + if s.logFile != nil { + return s.logFile.Close() + } + return nil +} + +// AddTool adds a new tool to the proxy server. +func (s *Server) AddTool(name, description, paramStr, scriptPath string) error { + // Validate script path + absPath, err := filepath.Abs(scriptPath) + if err != nil { + return fmt.Errorf("invalid script path: %w", err) + } + + // Clean the path to avoid any path traversal + absPath = filepath.Clean(absPath) + + // Check if script exists and is executable + info, err := os.Stat(absPath) + if err != nil { + return fmt.Errorf("script not found: %w", err) + } + + if info.IsDir() { + return fmt.Errorf("not a script: %s is a directory", absPath) + } + + // Additional security check: verify the file is executable + if info.Mode()&0o111 == 0 { + return fmt.Errorf("script is not executable: %s", absPath) + } + + // Parse parameters + params, err := parseParameters(paramStr) + if err != nil { + return fmt.Errorf("invalid parameters: %w", err) + } + + s.tools[name] = Tool{ + Name: name, + Description: description, + Parameters: params, + ScriptPath: absPath, + } + + return nil +} + +// parseParameters parses a comma-separated parameter string in the format "name:type,name:type". +func parseParameters(paramStr string) ([]Parameter, error) { + if paramStr == "" { + return []Parameter{}, nil + } + + params := strings.Split(paramStr, ",") + parameters := make([]Parameter, 0, len(params)) + + for _, param := range params { + parts := strings.Split(strings.TrimSpace(param), ":") + if len(parts) != 2 { + return nil, fmt.Errorf("invalid parameter format: %s, expected name:type", param) + } + + name := strings.TrimSpace(parts[0]) + paramType := strings.TrimSpace(parts[1]) + + // Validate parameter name + if name == "" { + return nil, fmt.Errorf("parameter name cannot be empty") + } + + // Validate parameter type + validTypes := map[string]bool{"string": true, "int": true, "float": true, "bool": true} + if !validTypes[paramType] { + return nil, fmt.Errorf("invalid parameter type: %s, supported types: string, int, float, bool", paramType) + } + + parameters = append(parameters, Parameter{ + Name: name, + Type: paramType, + }) + } + + return parameters, nil +} + +// ExecuteScript executes a shell script with the given parameters. +func (s *Server) ExecuteScript(toolName string, args map[string]interface{}) (string, error) { + tool, exists := s.tools[toolName] + if !exists { + return "", fmt.Errorf("tool not found: %s", toolName) + } + + // Additional runtime validation of script path + scriptPath := filepath.Clean(tool.ScriptPath) + info, err := os.Stat(scriptPath) + if err != nil { + return "", fmt.Errorf("script not found or not accessible: %w", err) + } + if info.IsDir() { + return "", fmt.Errorf("not a script: %s is a directory", scriptPath) + } + if info.Mode()&0o111 == 0 { + return "", fmt.Errorf("script is not executable: %s", scriptPath) + } + + // Set up environment variables for the script + env := os.Environ() + for name, value := range args { + // Convert value to string + strValue := fmt.Sprintf("%v", value) + env = append(env, fmt.Sprintf("%s=%s", name, strValue)) + } + + // Determine which shell to use for executing the script + shell := "/bin/sh" + bashExists, statErr := os.Stat("/bin/bash") + if statErr == nil && !bashExists.IsDir() { + shell = "/bin/bash" + } + + // Instead of using scriptPath directly (which would trigger gosec G204), + // we've already validated that this script is in our allowlist (s.tools) + // and we've performed additional validation above. + // #nosec G204 - scriptPath is validated and comes from a trusted source (config) + cmd := exec.Command(shell, "-c", scriptPath) + cmd.Env = env + cmd.Stderr = os.Stderr + + // Execute and capture output + output, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("error executing script: %w", err) + } + + return string(output), nil +} + +// GetToolSchema generates a JSON schema for the tool's parameters. +func (s *Server) GetToolSchema(toolName string) (map[string]interface{}, error) { + tool, exists := s.tools[toolName] + if !exists { + return nil, fmt.Errorf("tool not found: %s", toolName) + } + + properties := make(map[string]interface{}) + required := make([]string, 0, len(tool.Parameters)) + + for _, param := range tool.Parameters { + var paramSchema map[string]interface{} + + switch param.Type { + case "string": + paramSchema = map[string]interface{}{ + "type": "string", + } + case "int": + paramSchema = map[string]interface{}{ + "type": "integer", + } + case "float": + paramSchema = map[string]interface{}{ + "type": "number", + } + case "bool": + paramSchema = map[string]interface{}{ + "type": "boolean", + } + } + + properties[param.Name] = paramSchema + required = append(required, param.Name) + } + + schema := map[string]interface{}{ + "type": "object", + "properties": properties, + } + + if len(required) > 0 { + schema["required"] = required + } + + return schema, nil +} + +// Start begins listening for JSON-RPC requests on stdin and responding on stdout. +func (s *Server) Start() error { + decoder := json.NewDecoder(os.Stdin) + + s.log("Proxy server started, waiting for requests...") + fmt.Fprintf(os.Stderr, "Proxy server started, waiting for requests...\n") + + // Check error from Close() when deferring + defer func() { + if err := s.Close(); err != nil { + fmt.Fprintf(os.Stderr, "Error closing log file: %v\n", err) + } + }() + + for { + // Request struct with fields ordered for optimal memory alignment + var request struct { + Method string `json:"method"` // string (16 bytes: pointer + len) + Params map[string]interface{} `json:"params,omitempty"` // map (8 bytes) + JSONRPC string `json:"jsonrpc"` // string (16 bytes: pointer + len) + ID int `json:"id"` // int (8 bytes) + } + + fmt.Fprintf(os.Stderr, "Waiting for request...\n") + if err := decoder.Decode(&request); err != nil { + if err == io.EOF { + s.log("Client disconnected (EOF)") + } else { + s.log(fmt.Sprintf("Error decoding request: %v", err)) + } + fmt.Fprintf(os.Stderr, "Error decoding request: %v\n", err) + return fmt.Errorf("error decoding request: %w", err) + } + + // Log the incoming request + s.logJSON("Received request", request) + fmt.Fprintf(os.Stderr, "Received request: %s (ID: %d)\n", request.Method, request.ID) + s.id = request.ID + + // Handle notifications (methods without an ID) + if request.Method == "notifications/initialized" { + fmt.Fprintf(os.Stderr, "Received initialization notification\n") + s.log("Received initialization notification") + continue + } + + var response any + var err error + + switch request.Method { + case "initialize": + response = s.handleInitialize(request.Params) + case "tools/list": + response = s.handleToolsList() + case "tools/call": + response, err = s.handleToolCall(request.Params) + default: + err = fmt.Errorf("method not found") + } + + if err != nil { + fmt.Fprintf(os.Stderr, "Error handling request: %v\n", err) + s.log(fmt.Sprintf("Error handling request: %v", err)) + s.writeError(err) + continue + } + + fmt.Fprintf(os.Stderr, "Sending response\n") + s.writeResponse(response) + } +} + +// handleInitialize handles the initialize request from the client. +func (s *Server) handleInitialize(params map[string]interface{}) map[string]interface{} { + // Log the initialization parameters + if clientInfo, ok := params["clientInfo"].(map[string]interface{}); ok { + clientName, _ := clientInfo["name"].(string) + clientVersion, _ := clientInfo["version"].(string) + fmt.Fprintf(os.Stderr, "Client initialized: %s v%s\n", clientName, clientVersion) + } + + // Extract protocol version from params, defaulting to latest if not provided + protocolVersion := "2024-11-05" + if version, ok := params["protocolVersion"].(string); ok { + protocolVersion = version + } + + // Return server information and capabilities in the format expected by clients + capabilities := map[string]interface{}{ + "tools": map[string]interface{}{}, + } + + return map[string]interface{}{ + "protocolVersion": protocolVersion, + "capabilities": capabilities, + "serverInfo": map[string]interface{}{ + "name": "mcp-proxy-server", + "version": "1.0.0", + }, + } +} + +// handleToolsList returns the list of available tools. +func (s *Server) handleToolsList() map[string]interface{} { + tools := make([]map[string]interface{}, 0, len(s.tools)) + + for _, tool := range s.tools { + // Generate schema directly from the tool parameters + properties := make(map[string]interface{}) + required := make([]string, 0, len(tool.Parameters)) + + for _, param := range tool.Parameters { + var paramSchema map[string]interface{} + + switch param.Type { + case "string": + paramSchema = map[string]interface{}{ + "type": "string", + } + case "int": + paramSchema = map[string]interface{}{ + "type": "integer", + } + case "float": + paramSchema = map[string]interface{}{ + "type": "number", + } + case "bool": + paramSchema = map[string]interface{}{ + "type": "boolean", + } + } + + properties[param.Name] = paramSchema + required = append(required, param.Name) + } + + schema := map[string]interface{}{ + "type": "object", + "properties": properties, + } + + if len(required) > 0 { + schema["required"] = required + } + + tools = append(tools, map[string]interface{}{ + "name": tool.Name, + "description": tool.Description, + "inputSchema": schema, + }) + } + + return map[string]interface{}{ + "tools": tools, + } +} + +// handleToolCall handles a tool call request. +func (s *Server) handleToolCall(params map[string]interface{}) (map[string]interface{}, error) { + nameValue, ok := params["name"] + if !ok { + return nil, fmt.Errorf("missing 'name' parameter") + } + + name, ok := nameValue.(string) + if !ok { + return nil, fmt.Errorf("'name' parameter must be a string") + } + + _, exists := s.tools[name] + if !exists { + return nil, fmt.Errorf("tool not found: %s", name) + } + + // Extract input arguments + inputValue, ok := params["input"] + if !ok { + return nil, fmt.Errorf("missing 'input' parameter") + } + + input, ok := inputValue.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("'input' parameter must be an object") + } + + // Log the input parameters + s.logJSON("Tool input", input) + + // Execute the shell script + output, err := s.ExecuteScript(name, input) + if err != nil { + s.log(fmt.Sprintf("Error executing script: %v", err)) + return nil, fmt.Errorf("error executing script: %w", err) + } + + // Log the output + s.log(fmt.Sprintf("Script output: %s", output)) + + // Return the output in the correct format for the MCP protocol + return map[string]interface{}{ + "content": []map[string]interface{}{ + { + "type": "text", + "text": output, + }, + }, + }, nil +} + +// writeResponse writes a successful JSON-RPC response to stdout. +func (s *Server) writeResponse(result any) { + response := map[string]interface{}{ + "jsonrpc": "2.0", + "id": s.id, + "result": result, + } + + // Log the outgoing response + s.logJSON("Sending response", response) + + err := json.NewEncoder(os.Stdout).Encode(response) + if err != nil { + s.log(fmt.Sprintf("Error encoding response: %v", err)) + fmt.Fprintf(os.Stderr, "Error encoding response: %v\n", err) + } +} + +// writeError writes a JSON-RPC error response to stdout. +func (s *Server) writeError(err error) { + // Use method not found error code for unsupported methods + code := -32000 // Default server error + if err.Error() == "method not found" { + code = -32601 // Method not found error code + } + + response := map[string]interface{}{ + "jsonrpc": "2.0", + "id": s.id, + "error": map[string]interface{}{ + "code": code, + "message": err.Error(), + }, + } + + // Log the outgoing error response + s.logJSON("Sending error response", response) + + encodeErr := json.NewEncoder(os.Stdout).Encode(response) + if encodeErr != nil { + s.log(fmt.Sprintf("Error encoding error response: %v", encodeErr)) + fmt.Fprintf(os.Stderr, "Error encoding error response: %v\n", encodeErr) + } +} + +// RunProxyServer creates and runs a proxy server with the specified tool configs. +func RunProxyServer(toolConfigs map[string]map[string]string) error { + server, err := NewProxyServer() + if err != nil { + return fmt.Errorf("error creating server: %w", err) + } + + // Add tools from configs + for name, config := range toolConfigs { + description := config["description"] + parameters := config["parameters"] + scriptPath := config["script"] + + addErr := server.AddTool(name, description, parameters, scriptPath) + if addErr != nil { + return fmt.Errorf("error adding tool %s: %w", name, addErr) + } + } + + // Print registered tools + fmt.Println("Registered proxy tools:") + for name, tool := range server.tools { + fmt.Printf("- %s: %s (script: %s)\n", name, tool.Description, tool.ScriptPath) + paramStr := "" + for i, param := range tool.Parameters { + if i > 0 { + paramStr += ", " + } + paramStr += param.Name + ":" + param.Type + } + if paramStr != "" { + fmt.Printf(" Parameters: %s\n", paramStr) + } + } + + server.log(fmt.Sprintf("Starting proxy server with %d tools", len(toolConfigs))) + return server.Start() +}