diff --git a/.gitignore b/.gitignore index dc0daa9..94251e3 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ vendor bin .vscode sshcode +sshcode.exe diff --git a/README.md b/README.md index a5db262..3882f94 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,11 @@ # sshcode +**This project has been deprecated in favour of the [code-server install script](https://github.com/cdr/code-server#quick-install)** + +**See the discussion in [#185](https://github.com/cdr/sshcode/issues/185)** + +--- + [!["Open Issues"](https://img.shields.io/github/issues-raw/cdr/sshcode.svg)](https://github.com/cdr/sshcode/issues) [!["Latest Release"](https://img.shields.io/github/release/cdr/sshcode.svg)](https://github.com/cdr/sshcode/releases/latest) [![MIT license](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/cdr/sshcode/blob/master/LICENSE) @@ -36,6 +42,11 @@ We currently support: - MacOS - WSL +For the remote server, we currently only support Linux `x86_64` (64-bit) +servers with `glibc`. `musl` libc (which is most notably used by Alpine Linux) +is currently not supported on the remote server: +[#122](https://github.com/cdr/sshcode/issues/122). + ## Usage ```bash diff --git a/ci/build.sh b/ci/build.sh index 1f14096..9e30b09 100755 --- a/ci/build.sh +++ b/ci/build.sh @@ -10,8 +10,8 @@ build(){ go build -ldflags "-X main.version=${tag}" -o $tmpdir/sshcode pushd $tmpdir - tarname=sshcode-$GOOS-$GOARCH.tar - tar -cf $tarname sshcode + tarname=sshcode-$GOOS-$GOARCH.tar.gz + tar -czf $tarname sshcode popd cp $tmpdir/$tarname bin rm -rf $tmpdir diff --git a/go.mod b/go.mod index 26daea9..e378ffe 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,12 @@ go 1.12 require ( github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 github.com/pkg/errors v0.8.1 // indirect + github.com/spf13/pflag v1.0.3 github.com/stretchr/testify v1.3.0 - go.coder.com/cli v0.1.0 + go.coder.com/cli v0.4.0 go.coder.com/flog v0.0.0-20190129195112-eaed154a0db8 go.coder.com/retry v0.0.0-20180926062817-cf12c95974ac golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd golang.org/x/sys v0.0.0-20190418153312-f0ce4c0180be // indirect - golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 + golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/go.sum b/go.sum index c24ed7b..19ceba8 100644 --- a/go.sum +++ b/go.sum @@ -13,12 +13,14 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -go.coder.com/cli v0.1.0 h1:ZAjpjXJxMnwj1TqXUi7nnXXuxiPRfwfoC2kViN93oMM= -go.coder.com/cli v0.1.0/go.mod h1:pbVagI9YH/HHMManxPFML4P527GDREwsb+yciZ7mtB8= +go.coder.com/cli v0.4.0 h1:PruDGwm/CPFndyK/eMowZG3vzg5CgohRWeXWCTr3zi8= +go.coder.com/cli v0.4.0/go.mod h1:hRTOURCR3LJF1FRW9arecgrzX+AHG7mfYMwThPIgq+w= go.coder.com/flog v0.0.0-20190129195112-eaed154a0db8 h1:PtQ3moPi4EAz3cyQhkUs1IGIXa2QgJpP60yMjOdu0kk= go.coder.com/flog v0.0.0-20190129195112-eaed154a0db8/go.mod h1:83JsYgXYv0EOaXjIMnaZ1Fl6ddNB3fJnDZ/8845mUJ8= go.coder.com/retry v0.0.0-20180926062817-cf12c95974ac h1:ekdpsuykRy/E+SDq5BquFomNhRCk8OOyhtnACW9Bi50= @@ -32,5 +34,5 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190418153312-f0ce4c0180be h1:mI+jhqkn68ybP0ORJqunXn+fq+Eeb4hHKqLQcFICjAc= golang.org/x/sys v0.0.0-20190418153312-f0ce4c0180be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/main.go b/main.go index bc0de4a..674f983 100644 --- a/main.go +++ b/main.go @@ -1,13 +1,15 @@ package main import ( - "flag" "fmt" "math/rand" "os" + "runtime" "strings" "time" + "github.com/spf13/pflag" + "go.coder.com/cli" "go.coder.com/flog" ) @@ -34,11 +36,13 @@ var _ interface { } = new(rootCmd) type rootCmd struct { - skipSync bool - syncBack bool - printVersion bool - bindAddr string - sshFlags string + skipSync bool + syncBack bool + printVersion bool + noReuseConnection bool + bindAddr string + sshFlags string + uploadCodeServer string } func (c *rootCmd) Spec() cli.CommandSpec { @@ -49,15 +53,17 @@ func (c *rootCmd) Spec() cli.CommandSpec { } } -func (c *rootCmd) RegisterFlags(fl *flag.FlagSet) { +func (c *rootCmd) RegisterFlags(fl *pflag.FlagSet) { fl.BoolVar(&c.skipSync, "skipsync", false, "skip syncing local settings and extensions to remote host") fl.BoolVar(&c.syncBack, "b", false, "sync extensions back on termination") fl.BoolVar(&c.printVersion, "version", false, "print version information and exit") - fl.StringVar(&c.bindAddr, "bind", "", "local bind address for ssh tunnel") + fl.BoolVar(&c.noReuseConnection, "no-reuse-connection", false, "do not reuse SSH connection via control socket") + fl.StringVar(&c.bindAddr, "bind", "", "local bind address for SSH tunnel, in [HOST][:PORT] syntax (default: 127.0.0.1)") fl.StringVar(&c.sshFlags, "ssh-flags", "", "custom SSH flags") + fl.StringVar(&c.uploadCodeServer, "upload-code-server", "", "custom code-server binary to upload to the remote host") } -func (c *rootCmd) Run(fl *flag.FlagSet) { +func (c *rootCmd) Run(fl *pflag.FlagSet) { if c.printVersion { fmt.Printf("%v\n", version) os.Exit(0) @@ -75,11 +81,18 @@ func (c *rootCmd) Run(fl *flag.FlagSet) { dir = "~" } + // Get linux relative path if on windows. + if runtime.GOOS == "windows" { + dir = gitbashWindowsDir(dir) + } + err := sshCode(host, dir, options{ - skipSync: c.skipSync, - sshFlags: c.sshFlags, - bindAddr: c.bindAddr, - syncBack: c.syncBack, + skipSync: c.skipSync, + sshFlags: c.sshFlags, + bindAddr: c.bindAddr, + syncBack: c.syncBack, + reuseConnection: !c.noReuseConnection, + uploadCodeServer: c.uploadCodeServer, }) if err != nil { @@ -101,7 +114,7 @@ Environment variables: More info: https://github.com/cdr/sshcode Arguments: -%vHOST is passed into the ssh command. Valid formats are '' or 'gcp:'. +%vHOST is passed into the ssh command. Valid formats are '' or 'gcp:'. %vDIR is optional.`, helpTab, vsCodeConfigDirEnv, helpTab, vsCodeExtensionsDirEnv, diff --git a/settings.go b/settings.go index ad962a3..e88c260 100644 --- a/settings.go +++ b/settings.go @@ -24,6 +24,8 @@ func configDir() (string, error) { path = os.ExpandEnv("$HOME/.config/Code/User/") case "darwin": path = os.ExpandEnv("$HOME/Library/Application Support/Code/User/") + case "windows": + return os.ExpandEnv("/c/Users/$USERNAME/AppData/Roaming/Code/User"), nil default: return "", xerrors.Errorf("unsupported platform: %s", runtime.GOOS) } @@ -39,6 +41,8 @@ func extensionsDir() (string, error) { switch runtime.GOOS { case "linux", "darwin": path = os.ExpandEnv("$HOME/.vscode/extensions/") + case "windows": + return os.ExpandEnv("/c/Users/$USERNAME/.vscode/extensions"), nil default: return "", xerrors.Errorf("unsupported platform: %s", runtime.GOOS) } diff --git a/sshcode.go b/sshcode.go index a31e3d8..5021c09 100644 --- a/sshcode.go +++ b/sshcode.go @@ -10,8 +10,10 @@ import ( "os/exec" "os/signal" "path/filepath" + "runtime" "strconv" "strings" + "syscall" "time" "github.com/pkg/browser" @@ -21,18 +23,24 @@ import ( const codeServerPath = "~/.cache/sshcode/sshcode-server" +const ( + sshDirectory = "~/.ssh" + sshDirectoryUnsafeModeMask = 0022 + sshControlPath = sshDirectory + "/control-%h-%p-%r" +) + type options struct { - skipSync bool - syncBack bool - noOpen bool - bindAddr string - remotePort string - sshFlags string + skipSync bool + syncBack bool + noOpen bool + reuseConnection bool + bindAddr string + remotePort string + sshFlags string + uploadCodeServer string } func sshCode(host, dir string, o options) error { - flog.Info("ensuring code-server is updated...") - host, extraSSHFlags, err := parseHost(host) if err != nil { return xerrors.Errorf("failed to parse host IP: %w", err) @@ -53,22 +61,65 @@ func sshCode(host, dir string, o options) error { return xerrors.Errorf("failed to find available remote port: %w", err) } - dlScript := downloadScript(codeServerPath) + // Check the SSH directory's permissions and warn the user if it is not safe. + o.reuseConnection = checkSSHDirectory(sshDirectory, o.reuseConnection) - // Downloads the latest code-server and allows it to be executed. - sshCmdStr := fmt.Sprintf("ssh %v %v /bin/bash", o.sshFlags, host) + // Start SSH master connection socket. This prevents multiple password prompts from appearing as authentication + // only happens on the initial connection. + if o.reuseConnection { + flog.Info("starting SSH master connection...") + newSSHFlags, cancel, err := startSSHMaster(o.sshFlags, sshControlPath, host) + defer cancel() + if err != nil { + flog.Error("failed to start SSH master connection: %v", err) + o.reuseConnection = false + } else { + o.sshFlags = newSSHFlags + } + } - sshCmd := exec.Command("sh", "-c", sshCmdStr) - sshCmd.Stdout = os.Stdout - sshCmd.Stderr = os.Stderr - sshCmd.Stdin = strings.NewReader(dlScript) - err = sshCmd.Run() - if err != nil { - return xerrors.Errorf("failed to update code-server: \n---ssh cmd---\n%s\n---download script---\n%s: %w", - sshCmdStr, - dlScript, - err, - ) + // Upload local code-server or download code-server from CI server. + if o.uploadCodeServer != "" { + flog.Info("uploading local code-server binary...") + err = copyCodeServerBinary(o.sshFlags, host, o.uploadCodeServer, codeServerPath) + if err != nil { + return xerrors.Errorf("failed to upload local code-server binary to remote server: %w", err) + } + + sshCmdStr := + fmt.Sprintf("ssh %v %v 'chmod +x %v'", + o.sshFlags, host, codeServerPath, + ) + + sshCmd := exec.Command("sh", "-l", "-c", sshCmdStr) + sshCmd.Stdout = os.Stdout + sshCmd.Stderr = os.Stderr + err = sshCmd.Run() + if err != nil { + return xerrors.Errorf("failed to make code-server binary executable:\n---ssh cmd---\n%s: %w", + sshCmdStr, + err, + ) + } + } else { + flog.Info("ensuring code-server is updated...") + dlScript := downloadScript(codeServerPath) + + // Downloads the latest code-server and allows it to be executed. + sshCmdStr := fmt.Sprintf("ssh %v %v '/usr/bin/env bash -l'", o.sshFlags, host) + sshCmd := exec.Command("sh", "-l", "-c", sshCmdStr) + sshCmd.Stdout = os.Stdout + sshCmd.Stderr = os.Stderr + sshCmd.Stdin = strings.NewReader(dlScript) + err = sshCmd.Run() + if err != nil { + return xerrors.Errorf("failed to update code-server:\n---ssh cmd---\n%s"+ + "\n---download script---\n%s: %w", + sshCmdStr, + dlScript, + err, + ) + } } if !o.skipSync { @@ -93,13 +144,12 @@ func sshCode(host, dir string, o options) error { flog.Info("Tunneling remote port %v to %v", o.remotePort, o.bindAddr) - sshCmdStr = - fmt.Sprintf("ssh -tt -q -L %v:localhost:%v %v %v 'cd %v; %v --host 127.0.0.1 --allow-http --no-auth --port=%v'", - o.bindAddr, o.remotePort, o.sshFlags, host, dir, codeServerPath, o.remotePort, + sshCmdStr := + fmt.Sprintf("ssh -tt -q -L %v:localhost:%v %v %v '%v %v --host 127.0.0.1 --auth none --port=%v'", + o.bindAddr, o.remotePort, o.sshFlags, host, codeServerPath, dir, o.remotePort, ) - // Starts code-server and forwards the remote port. - sshCmd = exec.Command("sh", "-c", sshCmdStr) + sshCmd := exec.Command("sh", "-l", "-c", sshCmdStr) sshCmd.Stdin = os.Stdin sshCmd.Stdout = os.Stdout sshCmd.Stderr = os.Stderr @@ -147,8 +197,8 @@ func sshCode(host, dir string, o options) error { case <-c: } + flog.Info("shutting down") if !o.syncBack || o.skipSync { - flog.Info("shutting down") return nil } @@ -161,15 +211,33 @@ func sshCode(host, dir string, o options) error { err = syncUserSettings(o.sshFlags, host, true) if err != nil { - return xerrors.Errorf("failed to sync user settings settings back: %w", err) + return xerrors.Errorf("failed to sync user settings back: %w", err) } return nil } +// expandPath returns an expanded version of path. +func expandPath(path string) string { + path = filepath.Clean(os.ExpandEnv(path)) + + // Replace tilde notation in path with the home directory. You can't replace the first instance of `~` in the + // string with the homedir as having a tilde in the middle of a filename is valid. + homedir := os.Getenv("HOME") + if homedir != "" { + if path == "~" { + path = homedir + } else if strings.HasPrefix(path, "~/") { + path = filepath.Join(homedir, path[2:]) + } + } + + return filepath.Clean(path) +} + func parseBindAddr(bindAddr string) (string, error) { - if bindAddr == "" { - bindAddr = ":" + if !strings.Contains(bindAddr, ":") { + bindAddr += ":" } host, port, err := net.SplitHostPort(bindAddr) @@ -197,9 +265,12 @@ func openBrowser(url string) { const ( macPath = "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" wslPath = "/mnt/c/Program Files (x86)/Google/Chrome/Application/chrome.exe" + winPath = "C:/Program Files (x86)/Google/Chrome/Application/chrome.exe" ) switch { + case commandExists("chrome"): + openCmd = exec.Command("chrome", chromeOptions(url)...) case commandExists("google-chrome"): openCmd = exec.Command("google-chrome", chromeOptions(url)...) case commandExists("google-chrome-stable"): @@ -212,6 +283,8 @@ func openBrowser(url string) { openCmd = exec.Command(macPath, chromeOptions(url)...) case pathExists(wslPath): openCmd = exec.Command(wslPath, chromeOptions(url)...) + case pathExists(winPath): + openCmd = exec.Command(winPath, chromeOptions(url)...) default: err := browser.OpenURL(url) if err != nil { @@ -263,6 +336,119 @@ func randomPort() (string, error) { return "", xerrors.Errorf("max number of tries exceeded: %d", maxTries) } +// checkSSHDirectory performs sanity and safety checks on sshDirectory, and +// returns a new value for o.reuseConnection depending on the checks. +func checkSSHDirectory(sshDirectory string, reuseConnection bool) bool { + if runtime.GOOS == "windows" { + flog.Info("OS is windows, disabling connection reuse feature") + return false + } + + sshDirectoryMode, err := os.Lstat(expandPath(sshDirectory)) + if err != nil { + if reuseConnection { + flog.Info("failed to stat %v directory, disabling connection reuse feature: %v", sshDirectory, err) + } + reuseConnection = false + } else { + if !sshDirectoryMode.IsDir() { + if reuseConnection { + flog.Info("%v is not a directory, disabling connection reuse feature", sshDirectory) + } else { + flog.Info("warning: %v is not a directory", sshDirectory) + } + reuseConnection = false + } + if sshDirectoryMode.Mode().Perm()&sshDirectoryUnsafeModeMask != 0 { + flog.Info("warning: the %v directory has unsafe permissions, they should only be writable by "+ + "the owner (and files inside should be set to 0600)", sshDirectory) + } + } + return reuseConnection +} + +// startSSHMaster starts an SSH master connection and waits for it to be ready. +// It returns a new set of SSH flags for child SSH processes to use. +func startSSHMaster(sshFlags string, sshControlPath string, host string) (string, func(), error) { + ctx, cancel := context.WithCancel(context.Background()) + + newSSHFlags := fmt.Sprintf(`%v -o "ControlPath=%v"`, sshFlags, sshControlPath) + + // -MN means "start a master socket and don't open a session, just connect". + sshCmdStr := fmt.Sprintf(`exec ssh %v -MNq %v`, newSSHFlags, host) + sshMasterCmd := exec.CommandContext(ctx, "sh", "-c", sshCmdStr) + sshMasterCmd.Stdin = os.Stdin + sshMasterCmd.Stderr = os.Stderr + + // Gracefully stop the SSH master. + stopSSHMaster := func() { + if sshMasterCmd.Process != nil { + if sshMasterCmd.ProcessState != nil && sshMasterCmd.ProcessState.Exited() { + return + } + err := sshMasterCmd.Process.Signal(syscall.SIGTERM) + if err != nil { + flog.Error("failed to send SIGTERM to SSH master process: %v", err) + } + } + cancel() + } + + // Start ssh master and wait. Waiting prevents the process from becoming a zombie process if it dies before + // sshcode does, and allows sshMasterCmd.ProcessState to be populated. + err := sshMasterCmd.Start() + go sshMasterCmd.Wait() + if err != nil { + return "", stopSSHMaster, err + } + err = checkSSHMaster(sshMasterCmd, newSSHFlags, host) + if err != nil { + stopSSHMaster() + return "", stopSSHMaster, xerrors.Errorf("SSH master wasn't ready on time: %w", err) + } + return newSSHFlags, stopSSHMaster, nil +} + +// checkSSHMaster polls every second for 30 seconds to check if the SSH master +// is ready. +func checkSSHMaster(sshMasterCmd *exec.Cmd, sshFlags string, host string) error { + var ( + maxTries = 30 + sleepDur = time.Second + err error + ) + for i := 0; i < maxTries; i++ { + // Check if the master is running. + if sshMasterCmd.Process == nil || (sshMasterCmd.ProcessState != nil && sshMasterCmd.ProcessState.Exited()) { + return xerrors.Errorf("SSH master process is not running") + } + + // Check if it's ready. + sshCmdStr := fmt.Sprintf(`ssh %v -O check %v`, sshFlags, host) + sshCmd := exec.Command("sh", "-c", sshCmdStr) + err = sshCmd.Run() + if err == nil { + return nil + } + time.Sleep(sleepDur) + } + return xerrors.Errorf("max number of tries exceeded: %d", maxTries) +} + +// copyCodeServerBinary copies a code-server binary from local to remote. +func copyCodeServerBinary(sshFlags string, host string, localPath string, remotePath string) error { + if err := validateIsFile(localPath); err != nil { + return err + } + + var ( + src = localPath + dest = host + ":" + remotePath + ) + + return rsync(src, dest, sshFlags) +} + func syncUserSettings(sshFlags string, host string, back bool) error { localConfDir, err := configDir() if err != nil { @@ -274,8 +460,10 @@ func syncUserSettings(sshFlags string, host string, back bool) error { return err } - const remoteSettingsDir = "~/.local/share/code-server/User/" - + var remoteSettingsDir = "~/.local/share/code-server/User/" + if runtime.GOOS == "windows" { + remoteSettingsDir = ".local/share/code-server/User/" + } var ( src = localConfDir + "/" dest = host + ":" + remoteSettingsDir @@ -300,7 +488,10 @@ func syncExtensions(sshFlags string, host string, back bool) error { return err } - const remoteExtensionsDir = "~/.local/share/code-server/extensions/" + var remoteExtensionsDir = "~/.local/share/code-server/extensions/" + if runtime.GOOS == "windows" { + remoteExtensionsDir = ".local/share/code-server/extensions/" + } var ( src = localExtensionsDir + "/" @@ -328,6 +519,7 @@ func rsync(src string, dest string, sshFlags string, excludePaths ...string) err // locally in order to properly delete an extension. "--delete", "--copy-unsafe-links", + "-zz", src, dest, )..., ) @@ -345,16 +537,21 @@ func downloadScript(codeServerPath string) string { return fmt.Sprintf( `set -euxo pipefail || exit 1 +[ "$(uname -m)" != "x86_64" ] && echo "Unsupported server architecture $(uname -m). code-server only has releases for x86_64 systems." && exit 1 pkill -f %v || true -mkdir -p ~/.local/share/code-server %v +mkdir -p $HOME/.local/share/code-server %v cd %v -wget -N https://codesrv-ci.cdr.sh/latest-linux +curlflags="-o latest-linux" +if [ -f latest-linux ]; then + curlflags="$curlflags -z latest-linux" +fi +curl $curlflags https://codesrv-ci.cdr.sh/latest-linux [ -f %v ] && rm %v ln latest-linux %v chmod +x %v`, codeServerPath, - filepath.Dir(codeServerPath), - filepath.Dir(codeServerPath), + filepath.ToSlash(filepath.Dir(codeServerPath)), + filepath.ToSlash(filepath.Dir(codeServerPath)), codeServerPath, codeServerPath, codeServerPath, @@ -366,6 +563,11 @@ chmod +x %v`, func ensureDir(path string) error { _, err := os.Stat(path) if os.IsNotExist(err) { + // This fixes a issue where Go reads `/c/` as `C:\c\` and creates + // empty directories on the client that don't need to exist. + if runtime.GOOS == "windows" && strings.HasPrefix(path, "/c/") { + path = "C:" + path[2:] + } err = os.MkdirAll(path, 0750) } @@ -376,6 +578,18 @@ func ensureDir(path string) error { return nil } +// validateIsFile tries to stat the specified path and ensure it's a file. +func validateIsFile(path string) error { + info, err := os.Stat(path) + if err != nil { + return err + } + if info.IsDir() { + return xerrors.New("path is a directory") + } + return nil +} + // parseHost parses the host argument. If 'gcp:' is prefixed to the // host then a lookup is done using gcloud to determine the external IP and any // additional SSH arguments that should be used for ssh commands. Otherwise, host @@ -396,7 +610,7 @@ func parseHost(host string) (parsedHost string, additionalFlags string, err erro func parseGCPSSHCmd(instance string) (ip, sshFlags string, err error) { dryRunCmd := fmt.Sprintf("gcloud compute ssh --dry-run %v", instance) - out, err := exec.Command("sh", "-c", dryRunCmd).CombinedOutput() + out, err := exec.Command("sh", "-l", "-c", dryRunCmd).CombinedOutput() if err != nil { return "", "", xerrors.Errorf("%s: %w", out, err) } @@ -411,17 +625,29 @@ func parseGCPSSHCmd(instance string) (ip, sshFlags string, err error) { // E.g. foo@1.2.3.4. userIP := toks[len(toks)-1] - toks = strings.Split(userIP, "@") - // Assume the '@' is missing. - if len(toks) < 2 { - ip = strings.TrimSpace(toks[0]) - } else { - ip = strings.TrimSpace(toks[1]) + + return strings.TrimSpace(userIP), sshFlags, nil +} + +// gitbashWindowsDir strips a the msys2 install directory from the beginning of +// the path. On msys2, if a user provides `/workspace` sshcode will receive +// `C:/msys64/workspace` which won't work on the remote host. +func gitbashWindowsDir(dir string) string { + + // Don't bother figuring out path if it's relative to home dir. + if strings.HasPrefix(dir, "~/") { + if dir == "~" { + return "~/" + } + return dir } - if net.ParseIP(ip) == nil { - return "", "", xerrors.Errorf("parsed invalid ip address %v", ip) + mingwPrefix, err := exec.Command("sh", "-c", "{ cd / && pwd -W; }").Output() + if err != nil { + // Default to a sane location. + mingwPrefix = []byte("C:/mingw64") } - return ip, sshFlags, nil + prefix := strings.TrimSuffix(string(mingwPrefix), "/\n") + return strings.TrimPrefix(dir, prefix) } diff --git a/sshcode_test.go b/sshcode_test.go index fc6eb7d..096bff6 100644 --- a/sshcode_test.go +++ b/sshcode_test.go @@ -48,7 +48,7 @@ func TestSSHCode(t *testing.T) { waitForSSHCode(t, remotePort, time.Second*30) // Typically we'd do an os.Stat call here but the os package doesn't expand '~' - out, err := exec.Command("sh", "-c", "stat "+codeServerPath).CombinedOutput() + out, err := exec.Command("sh", "-l", "-c", "stat "+codeServerPath).CombinedOutput() require.NoError(t, err, "%s", out) out, err = exec.Command("pkill", filepath.Base(codeServerPath)).CombinedOutput() @@ -200,7 +200,7 @@ func handleSession(ch ssh.Channel, in <-chan *ssh.Request, t *testing.T) { return } - cmd := exec.Command("sh", "-c", exReq.Command) + cmd := exec.Command("sh", "-l", "-c", exReq.Command) stdin, err := cmd.StdinPipe() require.NoError(t, err)