From 54133fc08e0a6d3079bc2a168b33e768372e3a87 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Tue, 15 Feb 2022 20:35:12 +0000 Subject: [PATCH 01/18] Initial agent --- .vscode/settings.json | 2 + agent/agent.go | 140 ++++++++++++++++++++++++++++++++++++++++++ agent/agent_test.go | 61 ++++++++++++++++++ go.mod | 7 ++- go.sum | 6 +- peer/channel.go | 10 +++ peer/conn_test.go | 17 +++++ 7 files changed, 238 insertions(+), 5 deletions(-) create mode 100644 agent/agent.go create mode 100644 agent/agent_test.go diff --git a/.vscode/settings.json b/.vscode/settings.json index 34ed9fbae2c42..8b869a71b16ba 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -33,6 +33,7 @@ "drpcserver", "fatih", "goleak", + "gossh", "hashicorp", "httpmw", "isatty", @@ -54,6 +55,7 @@ "retrier", "sdkproto", "stretchr", + "tcpip", "tfexec", "tfstate", "unconvert", diff --git a/agent/agent.go b/agent/agent.go new file mode 100644 index 0000000000000..42dfea17b6104 --- /dev/null +++ b/agent/agent.go @@ -0,0 +1,140 @@ +package agent + +import ( + "context" + "errors" + "io" + "sync" + "time" + + "cdr.dev/slog" + "github.com/coder/coder/peer" + "github.com/coder/coder/peerbroker" + "github.com/coder/retry" + + "github.com/gliderlabs/ssh" + gossh "golang.org/x/crypto/ssh" +) + +type Options struct { + Logger slog.Logger +} + +type Dialer func(ctx context.Context) (*peerbroker.Listener, error) + +func Server(dialer Dialer, options *Options) io.Closer { + ctx, cancelFunc := context.WithCancel(context.Background()) + s := &server{ + clientDialer: dialer, + options: options, + closeCancel: cancelFunc, + } + s.init(ctx) + return s +} + +type server struct { + clientDialer Dialer + options *Options + + closeCancel context.CancelFunc + closeMutex sync.Mutex + closed chan struct{} + closeError error + + sshServer *ssh.Server +} + +func (s *server) init(ctx context.Context) { + forwardHandler := &ssh.ForwardedTCPHandler{} + s.sshServer = &ssh.Server{ + LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { + return false + }, + ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { + return false + }, + PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { + return false + }, + ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { + return &gossh.ServerConfig{ + Config: gossh.Config{ + // "arcfour" is the fastest SSH cipher. We prioritize throughput + // over encryption here, because the WebRTC connection is already + // encrypted. If possible, we'd disable encryption entirely here. + Ciphers: []string{"arcfour"}, + }, + NoClientAuth: true, + } + }, + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": forwardHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, + }, + } + + go s.run(ctx) +} + +func (s *server) run(ctx context.Context) { + var peerListener *peerbroker.Listener + var err error + // An exponential back-off occurs when the connection is failing to dial. + // This is to prevent server spam in case of a coderd outage. + for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { + peerListener, err = s.clientDialer(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + return + } + if s.isClosed() { + return + } + s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) + continue + } + s.options.Logger.Debug(context.Background(), "connected") + break + } + + for { + conn, err := peerListener.Accept() + if err != nil { + // This is closed! + return + } + go s.handle(ctx, conn) + } +} + +func (s *server) handle(ctx context.Context, conn *peer.Conn) { + for { + channel, err := conn.Accept(ctx) + if err != nil { + // TODO: Log here! + return + } + + switch channel.Protocol() { + case "ssh": + s.sshServer.HandleConn(channel.NetConn()) + case "proxy": + // Proxy the port provided. + } + } +} + +// isClosed returns whether the API is closed or not. +func (s *server) isClosed() bool { + select { + case <-s.closed: + return true + default: + return false + } +} + +func (s *server) Close() error { + return nil +} diff --git a/agent/agent_test.go b/agent/agent_test.go new file mode 100644 index 0000000000000..a027f7d902b8d --- /dev/null +++ b/agent/agent_test.go @@ -0,0 +1,61 @@ +package agent_test + +import ( + "context" + "net" + "os" + "testing" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/agent" + "github.com/coder/coder/peer" + "github.com/coder/coder/peerbroker" + "github.com/coder/coder/peerbroker/proto" + "github.com/coder/coder/provisionersdk" + "github.com/pion/webrtc/v3" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +func TestAgent(t *testing.T) { + t.Run("asd", func(t *testing.T) { + ctx := context.Background() + client, server := provisionersdk.TransportPipe() + defer client.Close() + defer server.Close() + closer := agent.Server(func(ctx context.Context) (*peerbroker.Listener, error) { + return peerbroker.Listen(server, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil), + }) + }, &agent.Options{ + Logger: slogtest.Make(t, nil), + }) + defer closer.Close() + api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) + stream, err := api.NegotiateConnection(ctx) + require.NoError(t, err) + conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil), + }) + require.NoError(t, err) + defer conn.Close() + channel, err := conn.Dial(ctx, "example", &peer.ChannelOptions{ + Protocol: "ssh", + }) + require.NoError(t, err) + sshConn, channels, requests, err := ssh.NewClientConn(channel.NetConn(), "localhost:22", &ssh.ClientConfig{ + User: "kyle", + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + }) + require.NoError(t, err) + sshClient := ssh.NewClient(sshConn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + session.Stdout = os.Stdout + session.Stderr = os.Stderr + err = session.Run("echo test") + require.NoError(t, err) + }) +} diff --git a/go.mod b/go.mod index 290c7d3758b85..7c86df1308668 100644 --- a/go.mod +++ b/go.mod @@ -16,10 +16,11 @@ replace github.com/chzyer/readline => github.com/kylecarbs/readline v0.0.0-20220 require ( cdr.dev/slog v1.4.1 - github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 github.com/briandowns/spinner v1.18.1 github.com/coder/retry v1.3.0 + github.com/creack/pty v1.1.17 github.com/fatih/color v1.13.0 + github.com/gliderlabs/ssh v0.3.3 github.com/go-chi/chi/v5 v5.0.7 github.com/go-chi/render v1.0.1 github.com/go-playground/validator/v10 v10.10.0 @@ -50,6 +51,7 @@ require ( go.uber.org/goleak v1.1.12 golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838 golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 + golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 google.golang.org/protobuf v1.27.1 nhooyr.io/websocket v1.8.7 @@ -63,11 +65,11 @@ require ( github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/alecthomas/chroma v0.10.0 // indirect + github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/cenkalti/backoff/v4 v4.1.2 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/containerd/continuity v0.2.2 // indirect - github.com/creack/pty v1.1.17 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dhui/dktest v0.3.9 // indirect github.com/dlclark/regexp2 v1.4.0 // indirect @@ -124,7 +126,6 @@ require ( github.com/zeebo/errs v1.2.2 // indirect go.opencensus.io v0.23.0 // indirect golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect - golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect golang.org/x/text v0.3.7 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/go.sum b/go.sum index 8e4c8c18401ae..24b37bc319d34 100644 --- a/go.sum +++ b/go.sum @@ -103,8 +103,6 @@ github.com/Microsoft/hcsshim v0.8.23/go.mod h1:4zegtUJth7lAvFyc6cH2gGQ5B3OFQim01 github.com/Microsoft/hcsshim/test v0.0.0-20201218223536-d3e5debf77da/go.mod h1:5hlzMzRKMLyo42nCZ9oml8AdTlq/0cvIaBv6tK1RehU= github.com/Microsoft/hcsshim/test v0.0.0-20210227013316-43a75bb4edd3/go.mod h1:mw7qgWloBUl75W/gVH3cQszUg1+gUITj7D6NY7ywVnY= github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= -github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63nhn5WAunQHLTznkw5W8b1Xc0dNjp83s= -github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= @@ -134,6 +132,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF github.com/alexflint/go-filemutex v0.0.0-20171022225611-72bdc8eae2ae/go.mod h1:CgnQgUtFrFz9mxFNtED3jI5tLDjKlOM+oUF/sTk6ps0= github.com/andybalholm/crlf v0.0.0-20171020200849-670099aa064f/go.mod h1:k8feO4+kXDxro6ErPXBRTJ/ro2mf0SsFG8s7doP9kJE= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow/go/arrow v0.0.0-20210818145353-234c94e4ce64/go.mod h1:2qMFB56yOP3KzkB3PbYZ4AlUFg3a88F67TIx5lB/WwY= github.com/apache/arrow/go/arrow v0.0.0-20211013220434-5962184e7a30/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs= @@ -443,6 +443,8 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/gliderlabs/ssh v0.3.3 h1:mBQ8NiOgDkINJrZtoizkC3nDNYgSaWtxyem6S2XHBtA= +github.com/gliderlabs/ssh v0.3.3/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= diff --git a/peer/channel.go b/peer/channel.go index b01154bcfaa25..18b6f0b18c57d 100644 --- a/peer/channel.go +++ b/peer/channel.go @@ -2,8 +2,10 @@ package peer import ( "context" + "fmt" "io" "net" + "runtime/debug" "sync" "time" @@ -186,6 +188,7 @@ func (c *Channel) Read(bytes []byte) (int, error) { if c.isClosed() { return 0, c.closeError } + debug.PrintStack() // An EOF always occurs when the connection is closed. // Alternative close errors will occur first if an unexpected // close has occurred. @@ -233,6 +236,8 @@ func (c *Channel) Write(bytes []byte) (n int, err error) { // See: https://github.com/pion/sctp/issues/181 time.Sleep(time.Microsecond) + fmt.Printf("Writing %d\n", len(bytes)) + return c.rwc.Write(bytes) } @@ -246,6 +251,11 @@ func (c *Channel) Label() string { return c.dc.Label() } +// Protocol returns the protocol of the underlying DataChannel. +func (c *Channel) Protocol() string { + return c.dc.Protocol() +} + // NetConn wraps the DataChannel in a struct fulfilling net.Conn. // Read, Write, and Close operations can still be used on the *Channel struct. func (c *Channel) NetConn() net.Conn { diff --git a/peer/conn_test.go b/peer/conn_test.go index 519e5f3b743db..c5584ed0b16b7 100644 --- a/peer/conn_test.go +++ b/peer/conn_test.go @@ -267,6 +267,23 @@ func TestConn(t *testing.T) { _, err := client.Ping() require.NoError(t, err) }) + + t.Run("ShortBuffer", func(t *testing.T) { + t.Parallel() + client, server, _ := createPair(t) + exchange(client, server) + go func() { + channel, err := client.Dial(context.Background(), "test", nil) + require.NoError(t, err) + _, err = channel.Write([]byte{'1', '2'}) + require.NoError(t, err) + }() + + channel, err := server.Accept(context.Background()) + require.NoError(t, err) + _, err = channel.Read(make([]byte, 1)) + require.NoError(t, err) + }) } func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) { From 4da9c893085cdd7c8713094cd4372ca1894ac1bc Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Wed, 16 Feb 2022 04:27:23 +0000 Subject: [PATCH 02/18] fix: Use buffered reader in peer to fix ShortBuffer This prevents a io.ErrShortBuffer from occurring when the byte slice being read is smaller than the chunks sent from the opposite pipe. This makes sense for unordered connections, where transmission is not guarunteed, but does not make sense for TCP-like connections. We use a bufio.Reader when ordered to ensure data isn't lost. --- .vscode/settings.json | 1 + go.mod | 5 ++--- go.sum | 2 -- peer/channel.go | 21 +++++++++++++++++++-- peer/conn_test.go | 21 +++++++++++++++++++++ 5 files changed, 43 insertions(+), 7 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 34ed9fbae2c42..d9b2b88f1798c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -57,6 +57,7 @@ "tfexec", "tfstate", "unconvert", + "webrtc", "xerrors", "yamux" ] diff --git a/go.mod b/go.mod index 290c7d3758b85..d082567bfa1f8 100644 --- a/go.mod +++ b/go.mod @@ -16,9 +16,9 @@ replace github.com/chzyer/readline => github.com/kylecarbs/readline v0.0.0-20220 require ( cdr.dev/slog v1.4.1 - github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 github.com/briandowns/spinner v1.18.1 github.com/coder/retry v1.3.0 + github.com/creack/pty v1.1.17 github.com/fatih/color v1.13.0 github.com/go-chi/chi/v5 v5.0.7 github.com/go-chi/render v1.0.1 @@ -50,6 +50,7 @@ require ( go.uber.org/goleak v1.1.12 golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838 golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 + golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 google.golang.org/protobuf v1.27.1 nhooyr.io/websocket v1.8.7 @@ -67,7 +68,6 @@ require ( github.com/cenkalti/backoff/v4 v4.1.2 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/containerd/continuity v0.2.2 // indirect - github.com/creack/pty v1.1.17 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dhui/dktest v0.3.9 // indirect github.com/dlclark/regexp2 v1.4.0 // indirect @@ -124,7 +124,6 @@ require ( github.com/zeebo/errs v1.2.2 // indirect go.opencensus.io v0.23.0 // indirect golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect - golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect golang.org/x/text v0.3.7 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/go.sum b/go.sum index 8e4c8c18401ae..02d93b18ba563 100644 --- a/go.sum +++ b/go.sum @@ -103,8 +103,6 @@ github.com/Microsoft/hcsshim v0.8.23/go.mod h1:4zegtUJth7lAvFyc6cH2gGQ5B3OFQim01 github.com/Microsoft/hcsshim/test v0.0.0-20201218223536-d3e5debf77da/go.mod h1:5hlzMzRKMLyo42nCZ9oml8AdTlq/0cvIaBv6tK1RehU= github.com/Microsoft/hcsshim/test v0.0.0-20210227013316-43a75bb4edd3/go.mod h1:mw7qgWloBUl75W/gVH3cQszUg1+gUITj7D6NY7ywVnY= github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= -github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63nhn5WAunQHLTznkw5W8b1Xc0dNjp83s= -github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= diff --git a/peer/channel.go b/peer/channel.go index b01154bcfaa25..d1f4930fe31f7 100644 --- a/peer/channel.go +++ b/peer/channel.go @@ -1,6 +1,7 @@ package peer import ( + "bufio" "context" "io" "net" @@ -78,7 +79,8 @@ type Channel struct { dc *webrtc.DataChannel // This field can be nil. It becomes set after the DataChannel // has been opened and is detached. - rwc datachannel.ReadWriteCloser + rwc datachannel.ReadWriteCloser + reader io.Reader closed chan struct{} closeMutex sync.Mutex @@ -130,6 +132,21 @@ func (c *Channel) init() { _ = c.closeWithError(xerrors.Errorf("detach: %w", err)) return } + // pion/webrtc will return an io.ErrShortBuffer when a read + // is triggerred with a buffer size less than the chunks written. + // + // This makes sense when considering UDP connections, because + // bufferring of data that has no transmit guarantees is likely + // to cause unexpected behavior. + // + // When ordered, this adds a bufio.Reader. This ensures additional + // data on TCP-like connections can be read in parts, while still + // being bufferred. + if c.opts.Unordered { + c.reader = c.rwc + } else { + c.reader = bufio.NewReader(c.rwc) + } close(c.opened) }) @@ -181,7 +198,7 @@ func (c *Channel) Read(bytes []byte) (int, error) { } } - bytesRead, err := c.rwc.Read(bytes) + bytesRead, err := c.reader.Read(bytes) if err != nil { if c.isClosed() { return 0, c.closeError diff --git a/peer/conn_test.go b/peer/conn_test.go index 519e5f3b743db..644390ba2ea68 100644 --- a/peer/conn_test.go +++ b/peer/conn_test.go @@ -267,6 +267,27 @@ func TestConn(t *testing.T) { _, err := client.Ping() require.NoError(t, err) }) + + t.Run("ShortBuffer", func(t *testing.T) { + t.Parallel() + client, server, _ := createPair(t) + exchange(client, server) + go func() { + channel, err := client.Dial(context.Background(), "test", nil) + require.NoError(t, err) + _, err = channel.Write([]byte{1, 2}) + require.NoError(t, err) + }() + channel, err := server.Accept(context.Background()) + require.NoError(t, err) + data := make([]byte, 1) + _, err = channel.Read(data) + require.NoError(t, err) + require.Equal(t, uint8(0x1), data[0]) + _, err = channel.Read(data) + require.NoError(t, err) + require.Equal(t, uint8(0x2), data[0]) + }) } func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) { From af5e3c2bcd56bf91d2bc5469655ac72bd30c6ad3 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Wed, 16 Feb 2022 17:34:16 +0000 Subject: [PATCH 03/18] SSH server works! --- agent/{agent.go => server.go} | 39 ++++++++++++++++++++----- agent/{agent_test.go => server_test.go} | 11 +++++-- peer/channel.go | 5 ---- 3 files changed, 40 insertions(+), 15 deletions(-) rename agent/{agent.go => server.go} (80%) rename agent/{agent_test.go => server_test.go} (89%) diff --git a/agent/agent.go b/agent/server.go similarity index 80% rename from agent/agent.go rename to agent/server.go index 42dfea17b6104..f6d580dfce004 100644 --- a/agent/agent.go +++ b/agent/server.go @@ -2,8 +2,12 @@ package agent import ( "context" + "crypto/rand" + "crypto/rsa" "errors" + "fmt" "io" + "net" "sync" "time" @@ -47,15 +51,37 @@ type server struct { func (s *server) init(ctx context.Context) { forwardHandler := &ssh.ForwardedTCPHandler{} + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + signer, err := gossh.NewSignerFromKey(key) + if err != nil { + panic(err) + } s.sshServer = &ssh.Server{ + ChannelHandlers: ssh.DefaultChannelHandlers, + ConnectionFailedCallback: func(conn net.Conn, err error) { + fmt.Printf("Conn failed: %s\n", err) + }, + Handler: func(s ssh.Session) { + fmt.Printf("WE GOT %q %q\n", s.User(), s.RawCommand()) + }, + HostSigners: []ssh.Signer{signer}, LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { + // Allow local port forwarding all! + return true + }, + PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { return false }, ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { - return false + // Allow revere port forwarding all! + return true }, - PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { - return false + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": forwardHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, }, ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { return &gossh.ServerConfig{ @@ -65,13 +91,12 @@ func (s *server) init(ctx context.Context) { // encrypted. If possible, we'd disable encryption entirely here. Ciphers: []string{"arcfour"}, }, + PublicKeyCallback: func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { + return &gossh.Permissions{}, nil + }, NoClientAuth: true, } }, - RequestHandlers: map[string]ssh.RequestHandler{ - "tcpip-forward": forwardHandler.HandleSSHRequest, - "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, - }, } go s.run(ctx) diff --git a/agent/agent_test.go b/agent/server_test.go similarity index 89% rename from agent/agent_test.go rename to agent/server_test.go index a027f7d902b8d..ca3c3605a6d8f 100644 --- a/agent/agent_test.go +++ b/agent/server_test.go @@ -2,7 +2,6 @@ package agent_test import ( "context" - "net" "os" "testing" @@ -14,9 +13,14 @@ import ( "github.com/coder/coder/provisionersdk" "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "golang.org/x/crypto/ssh" ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + func TestAgent(t *testing.T) { t.Run("asd", func(t *testing.T) { ctx := context.Background() @@ -45,9 +49,10 @@ func TestAgent(t *testing.T) { require.NoError(t, err) sshConn, channels, requests, err := ssh.NewClientConn(channel.NetConn(), "localhost:22", &ssh.ClientConfig{ User: "kyle", - HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { - return nil + Config: ssh.Config{ + Ciphers: []string{"arcfour"}, }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), }) require.NoError(t, err) sshClient := ssh.NewClient(sshConn, channels, requests) diff --git a/peer/channel.go b/peer/channel.go index 7bd93a1c5330b..732a6a1c1de2d 100644 --- a/peer/channel.go +++ b/peer/channel.go @@ -3,10 +3,8 @@ package peer import ( "bufio" "context" - "fmt" "io" "net" - "runtime/debug" "sync" "time" @@ -205,7 +203,6 @@ func (c *Channel) Read(bytes []byte) (int, error) { if c.isClosed() { return 0, c.closeError } - debug.PrintStack() // An EOF always occurs when the connection is closed. // Alternative close errors will occur first if an unexpected // close has occurred. @@ -253,8 +250,6 @@ func (c *Channel) Write(bytes []byte) (n int, err error) { // See: https://github.com/pion/sctp/issues/181 time.Sleep(time.Microsecond) - fmt.Printf("Writing %d\n", len(bytes)) - return c.rwc.Write(bytes) } From f8a733c7e2ef5a5c7bf4c07059e8fa4460ea9117 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Wed, 16 Feb 2022 18:14:49 +0000 Subject: [PATCH 04/18] Start Windows support --- agent/server.go | 74 ++++++++++++++++++++---- agent/server_test.go | 2 + console/conpty/conpty.go | 4 ++ console/pty/pty.go | 1 + console/pty/pty_other.go | 4 ++ console/pty/pty_windows.go | 4 ++ wintest/main.go | 113 +++++++++++++++++++++++++++++++++++++ 7 files changed, 190 insertions(+), 12 deletions(-) create mode 100644 wintest/main.go diff --git a/agent/server.go b/agent/server.go index f6d580dfce004..eb087313bb5f5 100644 --- a/agent/server.go +++ b/agent/server.go @@ -8,10 +8,13 @@ import ( "fmt" "io" "net" + "os/exec" "sync" + "syscall" "time" "cdr.dev/slog" + "github.com/coder/coder/console/pty" "github.com/coder/coder/peer" "github.com/coder/coder/peerbroker" "github.com/coder/retry" @@ -50,33 +53,83 @@ type server struct { } func (s *server) init(ctx context.Context) { - forwardHandler := &ssh.ForwardedTCPHandler{} - key, err := rsa.GenerateKey(rand.Reader, 2048) + // Clients' should ignore the host key when connecting. + // The agent needs to authenticate with coderd to SSH, + // so SSH authentication doesn't improve security. + randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { panic(err) } - signer, err := gossh.NewSignerFromKey(key) + randomSigner, err := gossh.NewSignerFromKey(randomHostKey) if err != nil { panic(err) } + sshLogger := s.options.Logger.Named("ssh-server") + forwardHandler := &ssh.ForwardedTCPHandler{} s.sshServer = &ssh.Server{ ChannelHandlers: ssh.DefaultChannelHandlers, ConnectionFailedCallback: func(conn net.Conn, err error) { - fmt.Printf("Conn failed: %s\n", err) + sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) }, - Handler: func(s ssh.Session) { - fmt.Printf("WE GOT %q %q\n", s.User(), s.RawCommand()) + Handler: func(session ssh.Session) { + fmt.Printf("WE GOT %q %q\n", session.User(), session.RawCommand()) + + sshPty, windowSize, isPty := session.Pty() + if isPty { + cmd := exec.CommandContext(ctx, session.Command()[0], session.Command()[1:]...) + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setsid: true, + Setctty: true, + } + pty, err := pty.New() + if err != nil { + panic(err) + } + err = pty.Resize(uint16(sshPty.Window.Width), uint16(sshPty.Window.Height)) + if err != nil { + panic(err) + } + cmd.Stdout = pty.OutPipe() + cmd.Stderr = pty.OutPipe() + cmd.Stdin = pty.InPipe() + err = cmd.Start() + if err != nil { + panic(err) + } + go func() { + for win := range windowSize { + err := pty.Resize(uint16(win.Width), uint16(win.Height)) + if err != nil { + panic(err) + } + } + }() + go func() { + io.Copy(pty.Writer(), session) + }() + fmt.Printf("Got here!\n") + io.Copy(session, pty.Reader()) + fmt.Printf("Done!\n") + cmd.Wait() + } }, - HostSigners: []ssh.Signer{signer}, + HostSigners: []ssh.Signer{randomSigner}, LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { // Allow local port forwarding all! + sshLogger.Debug(ctx, "local port forward", + slog.F("destination-host", destinationHost), + slog.F("destination-port", destinationPort)) return true }, PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { - return false + return true }, ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { - // Allow revere port forwarding all! + // Allow reverse port forwarding all! + sshLogger.Debug(ctx, "local port forward", + slog.F("bind-host", bindHost), + slog.F("bind-port", bindPort)) return true }, RequestHandlers: map[string]ssh.RequestHandler{ @@ -91,9 +144,6 @@ func (s *server) init(ctx context.Context) { // encrypted. If possible, we'd disable encryption entirely here. Ciphers: []string{"arcfour"}, }, - PublicKeyCallback: func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { - return &gossh.Permissions{}, nil - }, NoClientAuth: true, } }, diff --git a/agent/server_test.go b/agent/server_test.go index ca3c3605a6d8f..cc6aaef7d6522 100644 --- a/agent/server_test.go +++ b/agent/server_test.go @@ -58,6 +58,8 @@ func TestAgent(t *testing.T) { sshClient := ssh.NewClient(sshConn, channels, requests) session, err := sshClient.NewSession() require.NoError(t, err) + err = session.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{}) + require.NoError(t, err) session.Stdout = os.Stdout session.Stderr = os.Stderr err = session.Run("echo test") diff --git a/console/conpty/conpty.go b/console/conpty/conpty.go index a57264b8ff195..3b00f31a31765 100644 --- a/console/conpty/conpty.go +++ b/console/conpty/conpty.go @@ -65,6 +65,10 @@ func (c *ConPty) Reader() io.Reader { return c.outFileOurSide } +func (c *ConPty) Writer() io.Writer { + return c.inFileOurSide +} + // InPipe returns input pipe of the pseudo terminal // Note: It is safer to use the Write method to prevent partially-written VT sequences // from corrupting the terminal diff --git a/console/pty/pty.go b/console/pty/pty.go index 86b56e68f922e..3e3384faec6c3 100644 --- a/console/pty/pty.go +++ b/console/pty/pty.go @@ -12,6 +12,7 @@ type Pty interface { Resize(cols uint16, rows uint16) error WriteString(str string) (int, error) Reader() io.Reader + Writer() io.Writer Close() error } diff --git a/console/pty/pty_other.go b/console/pty/pty_other.go index 723a6dbfd748a..f1a21a941cf13 100644 --- a/console/pty/pty_other.go +++ b/console/pty/pty_other.go @@ -38,6 +38,10 @@ func (p *unixPty) Reader() io.Reader { return p.pty } +func (p *unixPty) Writer() io.Writer { + return p.pty +} + func (p *unixPty) WriteString(str string) (int, error) { return p.pty.WriteString(str) } diff --git a/console/pty/pty_windows.go b/console/pty/pty_windows.go index 01fbe39169f04..6a990f6068c4f 100644 --- a/console/pty/pty_windows.go +++ b/console/pty/pty_windows.go @@ -61,6 +61,10 @@ func (p *pipePtyVal) Reader() io.Reader { return p.outFileOurSide } +func (p *pipePtyVal) Writer() io.Writer { + return p.inFileOurSide +} + func (p *pipePtyVal) WriteString(str string) (int, error) { return p.inFileOurSide.WriteString(str) } diff --git a/wintest/main.go b/wintest/main.go new file mode 100644 index 0000000000000..cc0f7c6fd2ea4 --- /dev/null +++ b/wintest/main.go @@ -0,0 +1,113 @@ +package main + +import ( + "context" + "os" + "testing" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/agent" + "github.com/coder/coder/peer" + "github.com/coder/coder/peerbroker" + "github.com/coder/coder/peerbroker/proto" + "github.com/coder/coder/provisionersdk" + "github.com/pion/webrtc/v3" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + "golang.org/x/sys/windows" +) + +func main() { + state, err := MakeOutputRaw(os.Stdout.Fd()) + if err != nil { + panic(err) + } + defer Restore(os.Stdout.Fd(), state) + + t := &testing.T{} + ctx := context.Background() + client, server := provisionersdk.TransportPipe() + defer client.Close() + defer server.Close() + closer := agent.Server(func(ctx context.Context) (*peerbroker.Listener, error) { + return peerbroker.Listen(server, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil), + }) + }, &agent.Options{ + Logger: slogtest.Make(t, nil), + }) + defer closer.Close() + api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) + stream, err := api.NegotiateConnection(ctx) + require.NoError(t, err) + conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil), + }) + require.NoError(t, err) + defer conn.Close() + channel, err := conn.Dial(ctx, "example", &peer.ChannelOptions{ + Protocol: "ssh", + }) + require.NoError(t, err) + sshConn, channels, requests, err := ssh.NewClientConn(channel.NetConn(), "localhost:22", &ssh.ClientConfig{ + User: "kyle", + Config: ssh.Config{ + Ciphers: []string{"arcfour"}, + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + sshClient := ssh.NewClient(sshConn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + err = session.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{ + ssh.ECHO: 1, + }) + require.NoError(t, err) + session.Stdin = os.Stdin + session.Stdout = os.Stdout + session.Stderr = os.Stderr + err = session.Run("bash") + require.NoError(t, err) +} + +// State differs per-platform. +type State struct { + mode uint32 +} + +// makeRaw sets the terminal in raw mode and returns the previous state so it can be restored. +func makeRaw(handle windows.Handle, input bool) (uint32, error) { + var prevState uint32 + if err := windows.GetConsoleMode(handle, &prevState); err != nil { + return 0, err + } + + var raw uint32 + if input { + raw = prevState &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT) + raw |= windows.ENABLE_VIRTUAL_TERMINAL_INPUT + } else { + raw = prevState | windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING + } + + if err := windows.SetConsoleMode(handle, raw); err != nil { + return 0, err + } + return prevState, nil +} + +// MakeOutputRaw sets an output terminal to raw and enables VT100 processing. +func MakeOutputRaw(handle uintptr) (*State, error) { + prevState, err := makeRaw(windows.Handle(handle), false) + if err != nil { + return nil, err + } + + return &State{mode: prevState}, nil +} + +// Restore terminal back to original state. +func Restore(handle uintptr, state *State) error { + return windows.SetConsoleMode(windows.Handle(handle), state.mode) +} From 72b6b095e53552bff996a372caa9494fe435e8ee Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Wed, 16 Feb 2022 18:58:46 -0600 Subject: [PATCH 05/18] Something works --- agent/server.go | 36 ++--- agent/server_test.go | 5 +- console/conpty/spawn.go | 260 +++++++++++++++++++++++++++++++++++ console/conpty/syscall.go | 67 +++++++++ console/console.go | 17 ++- console/expect_test.go | 8 +- console/pty/pty.go | 66 ++++++++- console/pty/pty_other.go | 18 +-- console/pty/pty_windows.go | 114 ++++++++++----- console/pty/run.go | 7 + console/pty/run_other.go | 18 +++ console/pty/start_test.go | 28 ++++ console/pty/start_windows.go | 77 +++++++++++ console/test_console.go | 6 +- go.mod | 1 + go.sum | 2 + wintest/main.go | 56 ++------ 17 files changed, 645 insertions(+), 141 deletions(-) create mode 100644 console/conpty/spawn.go create mode 100644 console/pty/run.go create mode 100644 console/pty/run_other.go create mode 100644 console/pty/start_test.go create mode 100644 console/pty/start_windows.go diff --git a/agent/server.go b/agent/server.go index eb087313bb5f5..632b191c975ee 100644 --- a/agent/server.go +++ b/agent/server.go @@ -5,16 +5,15 @@ import ( "crypto/rand" "crypto/rsa" "errors" - "fmt" "io" "net" - "os/exec" + "os" "sync" "syscall" "time" "cdr.dev/slog" - "github.com/coder/coder/console/pty" + "github.com/ActiveState/termtest/conpty" "github.com/coder/coder/peer" "github.com/coder/coder/peerbroker" "github.com/coder/retry" @@ -72,46 +71,31 @@ func (s *server) init(ctx context.Context) { sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) }, Handler: func(session ssh.Session) { - fmt.Printf("WE GOT %q %q\n", session.User(), session.RawCommand()) - sshPty, windowSize, isPty := session.Pty() if isPty { - cmd := exec.CommandContext(ctx, session.Command()[0], session.Command()[1:]...) - cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) - cmd.SysProcAttr = &syscall.SysProcAttr{ - Setsid: true, - Setctty: true, - } - pty, err := pty.New() - if err != nil { - panic(err) - } - err = pty.Resize(uint16(sshPty.Window.Width), uint16(sshPty.Window.Height)) + cpty, err := conpty.New(int16(sshPty.Window.Width), int16(sshPty.Window.Height)) if err != nil { panic(err) } - cmd.Stdout = pty.OutPipe() - cmd.Stderr = pty.OutPipe() - cmd.Stdin = pty.InPipe() - err = cmd.Start() + _, _, err = cpty.Spawn("C:\\WINDOWS\\System32\\WindowsPowerShell\\v1.0\\powershell.exe", []string{}, &syscall.ProcAttr{ + Env: os.Environ(), + }) if err != nil { panic(err) } go func() { for win := range windowSize { - err := pty.Resize(uint16(win.Width), uint16(win.Height)) + err := cpty.Resize(uint16(win.Width), uint16(win.Height)) if err != nil { panic(err) } } }() + go func() { - io.Copy(pty.Writer(), session) + io.Copy(session, cpty) }() - fmt.Printf("Got here!\n") - io.Copy(session, pty.Reader()) - fmt.Printf("Done!\n") - cmd.Wait() + io.Copy(cpty, session) } }, HostSigners: []ssh.Signer{randomSigner}, diff --git a/agent/server_test.go b/agent/server_test.go index cc6aaef7d6522..8351a25c53257 100644 --- a/agent/server_test.go +++ b/agent/server_test.go @@ -62,7 +62,10 @@ func TestAgent(t *testing.T) { require.NoError(t, err) session.Stdout = os.Stdout session.Stderr = os.Stderr - err = session.Run("echo test") + err = session.Run("cmd.exe /k echo test") require.NoError(t, err) }) } + +// Read + write for input +// Read + write for output diff --git a/console/conpty/spawn.go b/console/conpty/spawn.go new file mode 100644 index 0000000000000..3831e64d07a80 --- /dev/null +++ b/console/conpty/spawn.go @@ -0,0 +1,260 @@ +package conpty + +import ( + "fmt" + "os" + "strings" + "syscall" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +// Spawn spawns a new process attached to the pseudo terminal +func Spawn(conpty *ConPty, argv0 string, argv []string, attr *syscall.ProcAttr) (pid int, handle uintptr, err error) { + startupInfo := &startupInfoEx{} + var attrListSize uint64 + startupInfo.startupInfo.Cb = uint32(unsafe.Sizeof(startupInfo)) + + err = initializeProcThreadAttributeList(0, 1, &attrListSize) + if err != nil { + return 0, 0, fmt.Errorf("could not retrieve list size: %v", err) + } + + attributeListBuffer := make([]byte, attrListSize) + startupInfo.lpAttributeList = windows.Handle(unsafe.Pointer(&attributeListBuffer[0])) + + err = initializeProcThreadAttributeList(uintptr(startupInfo.lpAttributeList), 1, &attrListSize) + if err != nil { + return 0, 0, fmt.Errorf("failed to initialize proc thread attributes for conpty: %v", err) + } + + err = updateProcThreadAttributeList( + startupInfo.lpAttributeList, + procThreadAttributePseudoconsole, + conpty.hpCon, + unsafe.Sizeof(conpty.hpCon)) + if err != nil { + return 0, 0, fmt.Errorf("failed to update proc thread attributes attributes for conpty usage: %v", err) + } + + if attr == nil { + attr = &syscall.ProcAttr{} + } + + if len(attr.Dir) != 0 { + // StartProcess assumes that argv0 is relative to attr.Dir, + // because it implies Chdir(attr.Dir) before executing argv0. + // Windows CreateProcess assumes the opposite: it looks for + // argv0 relative to the current directory, and, only once the new + // process is started, it does Chdir(attr.Dir). We are adjusting + // for that difference here by making argv0 absolute. + var err error + argv0, err = joinExeDirAndFName(attr.Dir, argv0) + if err != nil { + return 0, 0, err + } + } + argv0p, err := windows.UTF16PtrFromString(argv0) + if err != nil { + return 0, 0, err + } + + // Windows CreateProcess takes the command line as a single string: + // use attr.CmdLine if set, else build the command line by escaping + // and joining each argument with spaces + cmdline := makeCmdLine(argv) + + var argvp *uint16 + if len(cmdline) != 0 { + argvp, err = windows.UTF16PtrFromString(cmdline) + if err != nil { + return 0, 0, fmt.Errorf("utf ptr from string: %w", err) + } + } + + var dirp *uint16 + if len(attr.Dir) != 0 { + dirp, err = windows.UTF16PtrFromString(attr.Dir) + if err != nil { + return 0, 0, fmt.Errorf("utf ptr from string: %w", err) + } + } + + startupInfo.startupInfo.Flags = windows.STARTF_USESTDHANDLES + + pi := new(windows.ProcessInformation) + + flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | extendedStartupinfoPresent + + var zeroSec windows.SecurityAttributes + pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} + tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} + + // c.startupInfo.startupInfo.Cb = uint32(unsafe.Sizeof(c.startupInfo)) + err = windows.CreateProcess( + argv0p, + argvp, + pSec, // process handle not inheritable + tSec, // thread handles not inheritable, + false, + flags, + createEnvBlock(addCriticalEnv(dedupEnvCase(true, attr.Env))), + dirp, // use current directory later: dirp, + &startupInfo.startupInfo, + pi) + + if err != nil { + return 0, 0, fmt.Errorf("create process: %w", err) + } + defer windows.CloseHandle(windows.Handle(pi.Thread)) + + return int(pi.ProcessId), uintptr(pi.Process), nil +} + +// makeCmdLine builds a command line out of args by escaping "special" +// characters and joining the arguments with spaces. +func makeCmdLine(args []string) string { + var s string + for _, v := range args { + if s != "" { + s += " " + } + s += windows.EscapeArg(v) + } + return s +} + +func isSlash(c uint8) bool { + return c == '\\' || c == '/' +} + +func normalizeDir(dir string) (name string, err error) { + ndir, err := syscall.FullPath(dir) + if err != nil { + return "", err + } + if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) { + // dir cannot have \\server\share\path form + return "", syscall.EINVAL + } + return ndir, nil +} + +func volToUpper(ch int) int { + if 'a' <= ch && ch <= 'z' { + ch += 'A' - 'a' + } + return ch +} + +func joinExeDirAndFName(dir, p string) (name string, err error) { + if len(p) == 0 { + return "", syscall.EINVAL + } + if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) { + // \\server\share\path form + return p, nil + } + if len(p) > 1 && p[1] == ':' { + // has drive letter + if len(p) == 2 { + return "", syscall.EINVAL + } + if isSlash(p[2]) { + return p, nil + } else { + d, err := normalizeDir(dir) + if err != nil { + return "", err + } + if volToUpper(int(p[0])) == volToUpper(int(d[0])) { + return syscall.FullPath(d + "\\" + p[2:]) + } else { + return syscall.FullPath(p) + } + } + } else { + // no drive letter + d, err := normalizeDir(dir) + if err != nil { + return "", err + } + if isSlash(p[0]) { + return windows.FullPath(d[:2] + p) + } else { + return windows.FullPath(d + "\\" + p) + } + } +} + +// createEnvBlock converts an array of environment strings into +// the representation required by CreateProcess: a sequence of NUL +// terminated strings followed by a nil. +// Last bytes are two UCS-2 NULs, or four NUL bytes. +func createEnvBlock(envv []string) *uint16 { + if len(envv) == 0 { + return &utf16.Encode([]rune("\x00\x00"))[0] + } + length := 0 + for _, s := range envv { + length += len(s) + 1 + } + length += 1 + + b := make([]byte, length) + i := 0 + for _, s := range envv { + l := len(s) + copy(b[i:i+l], []byte(s)) + copy(b[i+l:i+l+1], []byte{0}) + i = i + l + 1 + } + copy(b[i:i+1], []byte{0}) + + return &utf16.Encode([]rune(string(b)))[0] +} + +// dedupEnvCase is dedupEnv with a case option for testing. +// If caseInsensitive is true, the case of keys is ignored. +func dedupEnvCase(caseInsensitive bool, env []string) []string { + out := make([]string, 0, len(env)) + saw := make(map[string]int, len(env)) // key => index into out + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + out = append(out, kv) + continue + } + k := kv[:eq] + if caseInsensitive { + k = strings.ToLower(k) + } + if dupIdx, isDup := saw[k]; isDup { + out[dupIdx] = kv + continue + } + saw[k] = len(out) + out = append(out, kv) + } + return out +} + +// addCriticalEnv adds any critical environment variables that are required +// (or at least almost always required) on the operating system. +// Currently this is only used for Windows. +func addCriticalEnv(env []string) []string { + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + continue + } + k := kv[:eq] + if strings.EqualFold(k, "SYSTEMROOT") { + // We already have it. + return env + } + } + return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) +} diff --git a/console/conpty/syscall.go b/console/conpty/syscall.go index 284603aa8fdc7..39a8904426648 100644 --- a/console/conpty/syscall.go +++ b/console/conpty/syscall.go @@ -18,8 +18,75 @@ var ( procResizePseudoConsole = kernel32.NewProc("ResizePseudoConsole") procCreatePseudoConsole = kernel32.NewProc("CreatePseudoConsole") procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole") + + // Required for executing processes! + procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList") + procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute") + procLocalAlloc = kernel32.NewProc("LocalAlloc") + procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList") + procCreateProcessW = kernel32.NewProc("CreateProcessW") ) +// An extended version of process startup information that points +// to a pseudo terminal object. +type startupInfoEx struct { + startupInfo windows.StartupInfo + lpAttributeList windows.Handle +} + +// Constant in CreateProcessW indicating that extended startup information is present. +const extendedStartupinfoPresent uint32 = 0x00080000 + +type procThreadAttribute uintptr + +// windows constant needed during initialization of extended startupinfo +const procThreadAttributePseudoconsole procThreadAttribute = 22 | 0x00020000 + +func initializeProcThreadAttributeList(attributeList uintptr, attributeCount uint32, listSize *uint64) (err error) { + if attributeList == 0 { + procInitializeProcThreadAttributeList.Call(0, uintptr(attributeCount), 0, uintptr(unsafe.Pointer(listSize))) + return + } + r1, _, e1 := procInitializeProcThreadAttributeList.Call(attributeList, uintptr(attributeCount), 0, uintptr(unsafe.Pointer(listSize))) + + if r1 == 0 { // boolean FALSE + err = e1 + } + + return +} + +func updateProcThreadAttributeList(attributeList windows.Handle, attribute procThreadAttribute, lpValue windows.Handle, lpSize uintptr) (err error) { + + r1, _, e1 := procUpdateProcThreadAttribute.Call(uintptr(attributeList), 0, uintptr(attribute), uintptr(lpValue), lpSize, 0, 0) + + if r1 == 0 { // boolean FALSE + err = e1 + } + + return +} +func deleteProcThreadAttributeList(handle windows.Handle) (err error) { + r1, _, e1 := procDeleteProcThreadAttributeList.Call(uintptr(handle)) + + if r1 == 0 { // boolean FALSE + err = e1 + } + + return +} + +func localAlloc(size uint64) (ptr windows.Handle, err error) { + r1, _, e1 := procLocalAlloc.Call(uintptr(0x0040), uintptr(size)) + if r1 == 0 { + err = e1 + ptr = windows.InvalidHandle + return + } + ptr = windows.Handle(r1) + return +} + func createPseudoConsole(consoleSize uintptr, ptyIn windows.Handle, ptyOut windows.Handle, hpCon *windows.Handle) (err error) { r1, _, e1 := procCreatePseudoConsole.Call( consoleSize, diff --git a/console/console.go b/console/console.go index e5af7fa20977b..4b7daf52d6613 100644 --- a/console/console.go +++ b/console/console.go @@ -20,7 +20,6 @@ import ( "io" "io/ioutil" "log" - "os" "unicode/utf8" "github.com/coder/coder/console/pty" @@ -86,8 +85,8 @@ func WithExpectObserver(observers ...Observer) Opt { } } -// NewConsole returns a new Console with the given options. -func NewConsole(opts ...Opt) (*Console, error) { +// NewWithOptions returns a new Console with the given options. +func NewWithOptions(opts ...Opt) (*Console, error) { options := Opts{ Logger: log.New(ioutil.Discard, "", 0), } @@ -103,7 +102,7 @@ func NewConsole(opts ...Opt) (*Console, error) { return nil, err } closers := []io.Closer{consolePty} - reader := consolePty.Reader() + reader := consolePty.Output() cons := &Console{ opts: options, @@ -116,13 +115,13 @@ func NewConsole(opts ...Opt) (*Console, error) { } // Tty returns an input Tty for accepting input -func (c *Console) InTty() *os.File { - return c.pty.InPipe() +func (c *Console) Input() io.ReadWriter { + return c.pty.Input() } // OutTty returns an output tty for writing -func (c *Console) OutTty() *os.File { - return c.pty.OutPipe() +func (c *Console) Output() io.ReadWriter { + return c.pty.Output() } // Close closes Console's tty. Calling Close will unblock Expect and ExpectEOF. @@ -139,7 +138,7 @@ func (c *Console) Close() error { // Send writes string s to Console's tty. func (c *Console) Send(s string) (int, error) { c.Logf("console send: %q", s) - n, err := c.pty.WriteString(s) + n, err := c.pty.Input().Write([]byte(s)) return n, err } diff --git a/console/expect_test.go b/console/expect_test.go index c80f981717d44..8fcefbf043234 100644 --- a/console/expect_test.go +++ b/console/expect_test.go @@ -75,7 +75,7 @@ func newTestConsole(t *testing.T, opts ...Opt) (*Console, error) { opts = append([]Opt{ expectNoError(t), }, opts...) - return NewConsole(opts...) + return NewWithOptions(opts...) } func expectNoError(t *testing.T) Opt { @@ -123,7 +123,7 @@ func TestExpectf(t *testing.T) { console.SendLine("xilfteN") }() - err = Prompt(console.InTty(), console.OutTty()) + err = Prompt(console.Input(), console.Output()) if err != nil { t.Errorf("Expected no error but got '%s'", err) } @@ -149,7 +149,7 @@ func TestExpect(t *testing.T) { console.SendLine("xilfteN") }() - err = Prompt(console.InTty(), console.OutTty()) + err = Prompt(console.Input(), console.Output()) if err != nil { t.Errorf("Expected no error but got '%s'", err) } @@ -173,7 +173,7 @@ func TestExpectOutput(t *testing.T) { console.SendLine("3") }() - err = Prompt(console.InTty(), console.OutTty()) + err = Prompt(console.Input(), console.Output()) if err == nil || !errors.Is(err, ErrWrongAnswer) { t.Errorf("Expected error '%s' but got '%s' instead", ErrWrongAnswer, err) } diff --git a/console/pty/pty.go b/console/pty/pty.go index 3e3384faec6c3..cf4b4849454ec 100644 --- a/console/pty/pty.go +++ b/console/pty/pty.go @@ -7,12 +7,9 @@ import ( // Pty is the minimal pseudo-tty interface we require. type Pty interface { - InPipe() *os.File - OutPipe() *os.File + Input() io.ReadWriter + Output() io.ReadWriter Resize(cols uint16, rows uint16) error - WriteString(str string) (int, error) - Reader() io.Reader - Writer() io.Writer Close() error } @@ -20,3 +17,62 @@ type Pty interface { func New() (Pty, error) { return newPty() } + +func pipePty() (Pty, error) { + inFilePipeSide, inFileOurSide, err := os.Pipe() + if err != nil { + return nil, err + } + + outFileOurSide, outFilePipeSide, err := os.Pipe() + if err != nil { + return nil, err + } + + return &pipePtyVal{ + inFilePipeSide, + inFileOurSide, + outFileOurSide, + outFilePipeSide, + }, nil +} + +type pipePtyVal struct { + inFilePipeSide, inFileOurSide *os.File + outFileOurSide, outFilePipeSide *os.File +} + +func (p *pipePtyVal) Output() io.ReadWriter { + return readWriter{ + Reader: p.outFilePipeSide, + Writer: p.outFileOurSide, + } +} + +func (p *pipePtyVal) Input() io.ReadWriter { + return readWriter{ + Reader: p.inFilePipeSide, + Writer: p.inFileOurSide, + } +} + +func (p *pipePtyVal) WriteString(str string) (int, error) { + return p.inFileOurSide.WriteString(str) +} + +func (p *pipePtyVal) Resize(uint16, uint16) error { + return nil +} + +func (p *pipePtyVal) Close() error { + p.inFileOurSide.Close() + p.inFilePipeSide.Close() + p.outFilePipeSide.Close() + p.outFileOurSide.Close() + return nil +} + +type readWriter struct { + io.Reader + io.Writer +} diff --git a/console/pty/pty_other.go b/console/pty/pty_other.go index f1a21a941cf13..7149b344c6fcc 100644 --- a/console/pty/pty_other.go +++ b/console/pty/pty_other.go @@ -16,44 +16,44 @@ func newPty() (Pty, error) { return nil, err } - return &unixPty{ + return &otherPty{ pty: ptyFile, tty: ttyFile, }, nil } -type unixPty struct { +type otherPty struct { pty, tty *os.File } -func (p *unixPty) InPipe() *os.File { +func (p *otherPty) InPipe() *os.File { return p.tty } -func (p *unixPty) OutPipe() *os.File { +func (p *otherPty) OutPipe() *os.File { return p.tty } -func (p *unixPty) Reader() io.Reader { +func (p *otherPty) Reader() io.Reader { return p.pty } -func (p *unixPty) Writer() io.Writer { +func (p *otherPty) Writer() io.Writer { return p.pty } -func (p *unixPty) WriteString(str string) (int, error) { +func (p *otherPty) WriteString(str string) (int, error) { return p.pty.WriteString(str) } -func (p *unixPty) Resize(cols uint16, rows uint16) error { +func (p *otherPty) Resize(cols uint16, rows uint16) error { return pty.Setsize(p.tty, &pty.Winsize{ Rows: rows, Cols: cols, }) } -func (p *unixPty) Close() error { +func (p *otherPty) Close() error { err := p.pty.Close() if err != nil { return err diff --git a/console/pty/pty_windows.go b/console/pty/pty_windows.go index 6a990f6068c4f..4e2430772e6b3 100644 --- a/console/pty/pty_windows.go +++ b/console/pty/pty_windows.go @@ -6,12 +6,22 @@ package pty import ( "io" "os" + "sync" + "unsafe" "golang.org/x/sys/windows" - "github.com/coder/coder/console/conpty" + "golang.org/x/xerrors" ) +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + procResizePseudoConsole = kernel32.NewProc("ResizePseudoConsole") + procCreatePseudoConsole = kernel32.NewProc("CreatePseudoConsole") + procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole") +) + +// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session func newPty() (Pty, error) { // We use the CreatePseudoConsole API which was introduced in build 17763 vsn := windows.RtlGetVersion() @@ -22,61 +32,93 @@ func newPty() (Pty, error) { return pipePty() } - return conpty.New(80, 80) -} + ptyWindows := &ptyWindows{} -func pipePty() (Pty, error) { - inFilePipeSide, inFileOurSide, err := os.Pipe() - if err != nil { + // Create the stdin pipe + if err := windows.CreatePipe(&ptyWindows.inputReadSide, &ptyWindows.inputWriteSide, nil, 0); err != nil { return nil, err } - outFileOurSide, outFilePipeSide, err := os.Pipe() - if err != nil { + // Create the stdout pipe + if err := windows.CreatePipe(&ptyWindows.outputReadSide, &ptyWindows.outputWriteSide, nil, 0); err != nil { return nil, err } - return &pipePtyVal{ - inFilePipeSide, - inFileOurSide, - outFileOurSide, - outFilePipeSide, - }, nil -} + consoleSize := uintptr((int32(80) << 16) | int32(80)) + ret, _, err := procCreatePseudoConsole.Call( + consoleSize, + uintptr(ptyWindows.inputReadSide), + uintptr(ptyWindows.outputWriteSide), + 0, + uintptr(unsafe.Pointer(&ptyWindows.console)), + ) + if ret != 0 { + return nil, xerrors.Errorf("create pseudo console (%d): %w", ret, err) + } -type pipePtyVal struct { - inFilePipeSide, inFileOurSide *os.File - outFileOurSide, outFilePipeSide *os.File -} + ptyWindows.outputWriteSideFile = os.NewFile(uintptr(ptyWindows.outputWriteSide), "|0") + ptyWindows.outputReadSideFile = os.NewFile(uintptr(ptyWindows.outputReadSide), "|1") + ptyWindows.inputReadSideFile = os.NewFile(uintptr(ptyWindows.inputReadSide), "|2") + ptyWindows.inputWriteSideFile = os.NewFile(uintptr(ptyWindows.inputWriteSide), "|3") + ptyWindows.closed = false -func (p *pipePtyVal) InPipe() *os.File { - return p.inFilePipeSide + return ptyWindows, nil } -func (p *pipePtyVal) OutPipe() *os.File { - return p.outFilePipeSide -} +type ptyWindows struct { + console windows.Handle + + outputWriteSide windows.Handle + outputReadSide windows.Handle + inputReadSide windows.Handle + inputWriteSide windows.Handle -func (p *pipePtyVal) Reader() io.Reader { - return p.outFileOurSide + outputWriteSideFile *os.File + outputReadSideFile *os.File + inputReadSideFile *os.File + inputWriteSideFile *os.File + + closeMutex sync.Mutex + closed bool } -func (p *pipePtyVal) Writer() io.Writer { - return p.inFileOurSide +func (p *ptyWindows) Input() io.ReadWriter { + return readWriter{ + Writer: p.inputWriteSideFile, + Reader: p.inputReadSideFile, + } } -func (p *pipePtyVal) WriteString(str string) (int, error) { - return p.inFileOurSide.WriteString(str) +func (p *ptyWindows) Output() io.ReadWriter { + return readWriter{ + Writer: p.outputWriteSideFile, + Reader: p.outputReadSideFile, + } } -func (p *pipePtyVal) Resize(uint16, uint16) error { +func (p *ptyWindows) Resize(cols uint16, rows uint16) error { + ret, _, err := procResizePseudoConsole.Call(uintptr(p.console), uintptr(cols)+(uintptr(rows)<<16)) + if ret != 0 { + return err + } return nil } -func (p *pipePtyVal) Close() error { - p.inFileOurSide.Close() - p.inFilePipeSide.Close() - p.outFilePipeSide.Close() - p.outFileOurSide.Close() +func (p *ptyWindows) Close() error { + p.closeMutex.Lock() + defer p.closeMutex.Unlock() + if p.closed { + return nil + } + p.closed = true + + ret, _, err := procClosePseudoConsole.Call(uintptr(p.console)) + if ret != 0 { + return xerrors.Errorf("close pseudo console: %w", err) + } + _ = p.outputWriteSideFile.Close() + _ = p.outputReadSideFile.Close() + _ = p.inputReadSideFile.Close() + _ = p.inputWriteSideFile.Close() return nil } diff --git a/console/pty/run.go b/console/pty/run.go new file mode 100644 index 0000000000000..e6c8e234178be --- /dev/null +++ b/console/pty/run.go @@ -0,0 +1,7 @@ +package pty + +import "os/exec" + +func Run(cmd *exec.Cmd) (Pty, error) { + return runPty(cmd) +} diff --git a/console/pty/run_other.go b/console/pty/run_other.go new file mode 100644 index 0000000000000..79f40d061a788 --- /dev/null +++ b/console/pty/run_other.go @@ -0,0 +1,18 @@ +//go:build !windows +// +build !windows + +package pty + +import ( + "os/exec" + + "github.com/creack/pty" +) + +func runPty(cmd *exec.Cmd) (Pty, error) { + pty, err := pty.Start(cmd) + if err != nil { + return nil, err + } + return &otherPty{pty, pty}, nil +} diff --git a/console/pty/start_test.go b/console/pty/start_test.go new file mode 100644 index 0000000000000..2ad525eefb1c4 --- /dev/null +++ b/console/pty/start_test.go @@ -0,0 +1,28 @@ +package pty_test + +import ( + "fmt" + "os/exec" + "regexp" + "testing" + + "github.com/coder/coder/console/pty" + "github.com/stretchr/testify/require" +) + +var ( + // Used to ensure terminal output doesn't have anything crazy! + // See: https://stackoverflow.com/a/29497680 + stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") +) + +func TestStart(t *testing.T) { + t.Run("Do", func(t *testing.T) { + pty, err := pty.Run(exec.Command("powershell.exe", "echo", "test")) + require.NoError(t, err) + data := make([]byte, 128) + _, err = pty.Output().Read(data) + require.NoError(t, err) + t.Log(fmt.Sprintf("%q", stripAnsi.ReplaceAllString(string(data), ""))) + }) +} diff --git a/console/pty/start_windows.go b/console/pty/start_windows.go new file mode 100644 index 0000000000000..9d163420e772b --- /dev/null +++ b/console/pty/start_windows.go @@ -0,0 +1,77 @@ +//go:build windows +// +build windows + +package pty + +import ( + "os" + "os/exec" + "unsafe" + + "golang.org/x/sys/windows" +) + +func runPty(cmd *exec.Cmd) (Pty, error) { + fullPath, err := exec.LookPath(cmd.Path) + if err != nil { + return nil, err + } + pathPtr, err := windows.UTF16PtrFromString(fullPath) + if err != nil { + return nil, err + } + argsPtr, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(cmd.Args)) + if err != nil { + return nil, err + } + if cmd.Dir == "" { + cmd.Dir, err = os.Getwd() + if err != nil { + return nil, err + } + } + dirPtr, err := windows.UTF16PtrFromString(cmd.Dir) + if err != nil { + return nil, err + } + pty, err := newPty() + if err != nil { + return nil, err + } + winPty := pty.(*ptyWindows) + + attrs, err := windows.NewProcThreadAttributeList(1) + if err != nil { + return nil, err + } + err = attrs.Update(22|0x00020000, unsafe.Pointer(winPty.console), unsafe.Sizeof(winPty.console)) + if err != nil { + return nil, err + } + + startupInfo := &windows.StartupInfoEx{} + startupInfo.StartupInfo.Cb = uint32(unsafe.Sizeof(*startupInfo)) + startupInfo.StartupInfo.Flags = windows.STARTF_USESTDHANDLES + startupInfo.ProcThreadAttributeList = attrs.List() + var processInfo windows.ProcessInformation + err = windows.CreateProcess( + pathPtr, + argsPtr, + nil, + nil, + false, + // https://docs.microsoft.com/en-us/windows/win32/procthread/process-creation-flags#create_unicode_environment + windows.CREATE_UNICODE_ENVIRONMENT|windows.EXTENDED_STARTUPINFO_PRESENT, + // Environment variables can come later! + nil, + dirPtr, + &startupInfo.StartupInfo, + &processInfo, + ) + if err != nil { + return nil, err + } + defer windows.CloseHandle(windows.Handle(processInfo.Thread)) + + return pty, nil +} diff --git a/console/test_console.go b/console/test_console.go index d1d845d6cb4db..961ee63589353 100644 --- a/console/test_console.go +++ b/console/test_console.go @@ -34,12 +34,12 @@ func New(t *testing.T, cmd *cobra.Command) *Console { } }() - console, err := NewConsole(WithStdout(writer)) + console, err := NewWithOptions(WithStdout(writer)) require.NoError(t, err) t.Cleanup(func() { console.Close() }) - cmd.SetIn(console.InTty()) - cmd.SetOut(console.OutTty()) + cmd.SetIn(console.Input()) + cmd.SetOut(console.Output()) return console } diff --git a/go.mod b/go.mod index 7c86df1308668..0b8bcd86e2f8d 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ replace github.com/chzyer/readline => github.com/kylecarbs/readline v0.0.0-20220 require ( cdr.dev/slog v1.4.1 + github.com/ActiveState/termtest/conpty v0.5.0 github.com/briandowns/spinner v1.18.1 github.com/coder/retry v1.3.0 github.com/creack/pty v1.1.17 diff --git a/go.sum b/go.sum index 24b37bc319d34..863c5672d1a90 100644 --- a/go.sum +++ b/go.sum @@ -57,6 +57,8 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8= +github.com/ActiveState/termtest/conpty v0.5.0 h1:JLUe6YDs4Jw4xNPCU+8VwTpniYOGeKzQg4SM2YHQNA8= +github.com/ActiveState/termtest/conpty v0.5.0/go.mod h1:LO4208FLsxw6DcNZ1UtuGUMW+ga9PFtX4ntv8Ymg9og= github.com/Azure/azure-pipeline-go v0.2.3/go.mod h1:x841ezTBIMG6O3lAcl8ATHnsOPVl2bqk7S3ta6S6u4k= github.com/Azure/azure-sdk-for-go v16.2.1+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-storage-blob-go v0.14.0/go.mod h1:SMqIBi+SuiQH32bvyjngEewEeXoPfKMgWlBDaYf6fck= diff --git a/wintest/main.go b/wintest/main.go index cc0f7c6fd2ea4..12831d46c2da1 100644 --- a/wintest/main.go +++ b/wintest/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "log" "os" "testing" @@ -14,15 +15,15 @@ import ( "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" - "golang.org/x/sys/windows" + "golang.org/x/crypto/ssh/terminal" ) func main() { - state, err := MakeOutputRaw(os.Stdout.Fd()) + oldState, err := terminal.MakeRaw(int(os.Stdin.Fd())) if err != nil { - panic(err) + log.Fatalf("Could not put terminal in raw mode: %v\n", err) } - defer Restore(os.Stdout.Fd(), state) + defer terminal.Restore(0, oldState) t := &testing.T{} ctx := context.Background() @@ -60,54 +61,13 @@ func main() { sshClient := ssh.NewClient(sshConn, channels, requests) session, err := sshClient.NewSession() require.NoError(t, err) - err = session.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{ - ssh.ECHO: 1, + err = session.RequestPty("", 1024, 1024, ssh.TerminalModes{ + ssh.IGNCR: 1, }) require.NoError(t, err) session.Stdin = os.Stdin session.Stdout = os.Stdout session.Stderr = os.Stderr - err = session.Run("bash") + err = session.Run("C:\\WINDOWS\\System32\\WindowsPowerShell\\v1.0\\powershell.exe") require.NoError(t, err) } - -// State differs per-platform. -type State struct { - mode uint32 -} - -// makeRaw sets the terminal in raw mode and returns the previous state so it can be restored. -func makeRaw(handle windows.Handle, input bool) (uint32, error) { - var prevState uint32 - if err := windows.GetConsoleMode(handle, &prevState); err != nil { - return 0, err - } - - var raw uint32 - if input { - raw = prevState &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT) - raw |= windows.ENABLE_VIRTUAL_TERMINAL_INPUT - } else { - raw = prevState | windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING - } - - if err := windows.SetConsoleMode(handle, raw); err != nil { - return 0, err - } - return prevState, nil -} - -// MakeOutputRaw sets an output terminal to raw and enables VT100 processing. -func MakeOutputRaw(handle uintptr) (*State, error) { - prevState, err := makeRaw(windows.Handle(handle), false) - if err != nil { - return nil, err - } - - return &State{mode: prevState}, nil -} - -// Restore terminal back to original state. -func Restore(handle uintptr, state *State) error { - return windows.SetConsoleMode(windows.Handle(handle), state.mode) -} From 722be6c98288db4102e7f617c020c13f2edb0ce8 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 00:11:35 -0600 Subject: [PATCH 06/18] Refactor pty package to support Windows spawn --- agent/server.go | 22 +- cli/login_test.go | 15 +- cli/projectcreate_test.go | 22 +- cli/root.go | 18 +- cli/workspacecreate_test.go | 15 +- console/conpty/conpty.go | 111 -------- console/conpty/spawn.go | 260 ------------------ console/conpty/syscall.go | 120 -------- console/console.go | 162 ----------- console/doc.go | 19 -- console/expect.go | 109 -------- console/expect_opt.go | 139 ---------- console/expect_opt_test.go | 163 ----------- console/expect_test.go | 181 ------------ console/pty/pty.go | 78 ------ console/pty/run.go | 7 - console/pty/start_test.go | 28 -- console/test_console.go | 45 --- pty/pty.go | 31 +++ {console/pty => pty}/pty_other.go | 12 +- {console/pty => pty}/pty_windows.go | 68 ++--- pty/ptytest/ptytest.go | 93 +++++++ pty/ptytest/ptytest_test.go | 14 + pty/start.go | 7 + .../pty/run_other.go => pty/start_other.go | 0 {console/pty => pty}/start_windows.go | 40 ++- pty/start_windows_test.go | 18 ++ wintest/main.go | 17 +- 28 files changed, 289 insertions(+), 1525 deletions(-) delete mode 100644 console/conpty/conpty.go delete mode 100644 console/conpty/spawn.go delete mode 100644 console/conpty/syscall.go delete mode 100644 console/console.go delete mode 100644 console/doc.go delete mode 100644 console/expect.go delete mode 100644 console/expect_opt.go delete mode 100644 console/expect_opt_test.go delete mode 100644 console/expect_test.go delete mode 100644 console/pty/pty.go delete mode 100644 console/pty/run.go delete mode 100644 console/pty/start_test.go delete mode 100644 console/test_console.go create mode 100644 pty/pty.go rename {console/pty => pty}/pty_other.go (79%) rename {console/pty => pty}/pty_windows.go (55%) create mode 100644 pty/ptytest/ptytest.go create mode 100644 pty/ptytest/ptytest_test.go create mode 100644 pty/start.go rename console/pty/run_other.go => pty/start_other.go (100%) rename {console/pty => pty}/start_windows.go (58%) create mode 100644 pty/start_windows_test.go diff --git a/agent/server.go b/agent/server.go index 632b191c975ee..746a3d6a63d57 100644 --- a/agent/server.go +++ b/agent/server.go @@ -7,13 +7,12 @@ import ( "errors" "io" "net" - "os" + "os/exec" "sync" - "syscall" "time" "cdr.dev/slog" - "github.com/ActiveState/termtest/conpty" + "github.com/coder/coder/console/pty" "github.com/coder/coder/peer" "github.com/coder/coder/peerbroker" "github.com/coder/retry" @@ -71,31 +70,24 @@ func (s *server) init(ctx context.Context) { sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) }, Handler: func(session ssh.Session) { - sshPty, windowSize, isPty := session.Pty() + _, windowSize, isPty := session.Pty() if isPty { - cpty, err := conpty.New(int16(sshPty.Window.Width), int16(sshPty.Window.Height)) - if err != nil { - panic(err) - } - _, _, err = cpty.Spawn("C:\\WINDOWS\\System32\\WindowsPowerShell\\v1.0\\powershell.exe", []string{}, &syscall.ProcAttr{ - Env: os.Environ(), - }) + pty, err := pty.Start(exec.Command("powershell.exe")) if err != nil { panic(err) } go func() { for win := range windowSize { - err := cpty.Resize(uint16(win.Width), uint16(win.Height)) + err := pty.Resize(uint16(win.Width), uint16(win.Height)) if err != nil { panic(err) } } }() - go func() { - io.Copy(session, cpty) + io.Copy(session, pty.Output()) }() - io.Copy(cpty, session) + io.Copy(pty.Input(), session) } }, HostSigners: []ssh.Signer{randomSigner}, diff --git a/cli/login_test.go b/cli/login_test.go index b6c581cc41f12..02af769e6c49c 100644 --- a/cli/login_test.go +++ b/cli/login_test.go @@ -5,7 +5,7 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/console" + "github.com/coder/coder/pty/ptytest" "github.com/stretchr/testify/require" ) @@ -26,7 +26,9 @@ func TestLogin(t *testing.T) { // accurately detect Windows ptys when they are not attached to a process: // https://github.com/mattn/go-isatty/issues/59 root, _ := clitest.New(t, "login", client.URL.String(), "--force-tty") - cons := console.New(t, root) + pty := ptytest.New(t) + root.SetIn(pty.Input()) + root.SetOut(pty.Output()) go func() { err := root.Execute() require.NoError(t, err) @@ -42,12 +44,9 @@ func TestLogin(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - _, err := cons.ExpectString(match) - require.NoError(t, err) - _, err = cons.SendLine(value) - require.NoError(t, err) + pty.ExpectMatch(match) + pty.WriteLine(value) } - _, err := cons.ExpectString("Welcome to Coder") - require.NoError(t, err) + pty.ExpectMatch("Welcome to Coder") }) } diff --git a/cli/projectcreate_test.go b/cli/projectcreate_test.go index 6311aaf141f30..873a276263e5a 100644 --- a/cli/projectcreate_test.go +++ b/cli/projectcreate_test.go @@ -7,10 +7,10 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/console" "github.com/coder/coder/database" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/pty/ptytest" ) func TestProjectCreate(t *testing.T) { @@ -26,7 +26,9 @@ func TestProjectCreate(t *testing.T) { cmd, root := clitest.New(t, "projects", "create", "--directory", source, "--provisioner", string(database.ProvisionerTypeEcho)) clitest.SetupConfig(t, client, root) _ = coderdtest.NewProvisionerDaemon(t, client) - console := console.New(t, cmd) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) closeChan := make(chan struct{}) go func() { err := cmd.Execute() @@ -43,10 +45,8 @@ func TestProjectCreate(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - _, err := console.ExpectString(match) - require.NoError(t, err) - _, err = console.SendLine(value) - require.NoError(t, err) + pty.ExpectMatch(match) + pty.WriteLine(value) } <-closeChan }) @@ -73,7 +73,9 @@ func TestProjectCreate(t *testing.T) { cmd, root := clitest.New(t, "projects", "create", "--directory", source, "--provisioner", string(database.ProvisionerTypeEcho)) clitest.SetupConfig(t, client, root) coderdtest.NewProvisionerDaemon(t, client) - cons := console.New(t, cmd) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) closeChan := make(chan struct{}) go func() { err := cmd.Execute() @@ -91,10 +93,8 @@ func TestProjectCreate(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - _, err := cons.ExpectString(match) - require.NoError(t, err) - _, err = cons.SendLine(value) - require.NoError(t, err) + pty.ExpectMatch(match) + pty.WriteLine(value) } <-closeChan }) diff --git a/cli/root.go b/cli/root.go index f4e27a49d9e67..e5747f62db961 100644 --- a/cli/root.go +++ b/cli/root.go @@ -139,14 +139,21 @@ func isTTY(cmd *cobra.Command) bool { func prompt(cmd *cobra.Command, prompt *promptui.Prompt) (string, error) { var ok bool - prompt.Stdin, ok = cmd.InOrStdin().(io.ReadCloser) + reader, ok := cmd.InOrStdin().(io.Reader) if !ok { return "", xerrors.New("stdin must be a readcloser") } - prompt.Stdout, ok = cmd.OutOrStdout().(io.WriteCloser) + prompt.Stdin = readWriteCloser{ + Reader: reader, + } + + writer, ok := cmd.OutOrStdout().(io.Writer) if !ok { return "", xerrors.New("stdout must be a readcloser") } + prompt.Stdout = readWriteCloser{ + Writer: writer, + } // The prompt library displays defaults in a jarring way for the user // by attempting to autocomplete it. This sets no default enabling us @@ -199,3 +206,10 @@ func prompt(cmd *cobra.Command, prompt *promptui.Prompt) (string, error) { return value, err } + +// readWriteCloser fakes reads, writes, and closing! +type readWriteCloser struct { + io.Reader + io.Writer + io.Closer +} diff --git a/cli/workspacecreate_test.go b/cli/workspacecreate_test.go index 306caa65c4b0c..8bf683f8f3439 100644 --- a/cli/workspacecreate_test.go +++ b/cli/workspacecreate_test.go @@ -5,9 +5,9 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/console" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/pty/ptytest" "github.com/stretchr/testify/require" ) @@ -36,7 +36,9 @@ func TestWorkspaceCreate(t *testing.T) { cmd, root := clitest.New(t, "workspaces", "create", project.Name) clitest.SetupConfig(t, client, root) - cons := console.New(t, cmd) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) closeChan := make(chan struct{}) go func() { err := cmd.Execute() @@ -51,13 +53,10 @@ func TestWorkspaceCreate(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - _, err := cons.ExpectString(match) - require.NoError(t, err) - _, err = cons.SendLine(value) - require.NoError(t, err) + pty.ExpectMatch(match) + pty.WriteLine(value) } - _, err := cons.ExpectString("Create") - require.NoError(t, err) + pty.ExpectMatch("Create") <-closeChan }) } diff --git a/console/conpty/conpty.go b/console/conpty/conpty.go deleted file mode 100644 index 3b00f31a31765..0000000000000 --- a/console/conpty/conpty.go +++ /dev/null @@ -1,111 +0,0 @@ -//go:build windows -// +build windows - -// Original copyright 2020 ActiveState Software. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file - -package conpty - -import ( - "fmt" - "io" - "os" - - "golang.org/x/sys/windows" -) - -// ConPty represents a windows pseudo console. -type ConPty struct { - hpCon windows.Handle - outPipePseudoConsoleSide windows.Handle - outPipeOurSide windows.Handle - inPipeOurSide windows.Handle - inPipePseudoConsoleSide windows.Handle - consoleSize uintptr - outFilePseudoConsoleSide *os.File - outFileOurSide *os.File - inFilePseudoConsoleSide *os.File - inFileOurSide *os.File - closed bool -} - -// New returns a new ConPty pseudo terminal device -func New(columns int16, rows int16) (*ConPty, error) { - c := &ConPty{ - consoleSize: uintptr(columns) + (uintptr(rows) << 16), - } - - return c, c.createPseudoConsoleAndPipes() -} - -// Close closes the pseudo-terminal and cleans up all attached resources -func (c *ConPty) Close() error { - // Trying to close these pipes multiple times will result in an - // access violation - if c.closed { - return nil - } - - err := closePseudoConsole(c.hpCon) - c.outFilePseudoConsoleSide.Close() - c.outFileOurSide.Close() - c.inFilePseudoConsoleSide.Close() - c.inFileOurSide.Close() - c.closed = true - return err -} - -// OutPipe returns the output pipe of the pseudo terminal -func (c *ConPty) OutPipe() *os.File { - return c.outFilePseudoConsoleSide -} - -func (c *ConPty) Reader() io.Reader { - return c.outFileOurSide -} - -func (c *ConPty) Writer() io.Writer { - return c.inFileOurSide -} - -// InPipe returns input pipe of the pseudo terminal -// Note: It is safer to use the Write method to prevent partially-written VT sequences -// from corrupting the terminal -func (c *ConPty) InPipe() *os.File { - return c.inFilePseudoConsoleSide -} - -func (c *ConPty) WriteString(str string) (int, error) { - return c.inFileOurSide.WriteString(str) -} - -func (c *ConPty) createPseudoConsoleAndPipes() error { - // Create the stdin pipe - if err := windows.CreatePipe(&c.inPipePseudoConsoleSide, &c.inPipeOurSide, nil, 0); err != nil { - return err - } - - // Create the stdout pipe - if err := windows.CreatePipe(&c.outPipeOurSide, &c.outPipePseudoConsoleSide, nil, 0); err != nil { - return err - } - - // Create the pty with our stdin/stdout - if err := createPseudoConsole(c.consoleSize, c.inPipePseudoConsoleSide, c.outPipePseudoConsoleSide, &c.hpCon); err != nil { - return fmt.Errorf("failed to create pseudo console: %d, %v", uintptr(c.hpCon), err) - } - - c.outFilePseudoConsoleSide = os.NewFile(uintptr(c.outPipePseudoConsoleSide), "|0") - c.outFileOurSide = os.NewFile(uintptr(c.outPipeOurSide), "|1") - - c.inFilePseudoConsoleSide = os.NewFile(uintptr(c.inPipePseudoConsoleSide), "|2") - c.inFileOurSide = os.NewFile(uintptr(c.inPipeOurSide), "|3") - c.closed = false - - return nil -} - -func (c *ConPty) Resize(cols uint16, rows uint16) error { - return resizePseudoConsole(c.hpCon, uintptr(cols)+(uintptr(rows)<<16)) -} diff --git a/console/conpty/spawn.go b/console/conpty/spawn.go deleted file mode 100644 index 3831e64d07a80..0000000000000 --- a/console/conpty/spawn.go +++ /dev/null @@ -1,260 +0,0 @@ -package conpty - -import ( - "fmt" - "os" - "strings" - "syscall" - "unicode/utf16" - "unsafe" - - "golang.org/x/sys/windows" -) - -// Spawn spawns a new process attached to the pseudo terminal -func Spawn(conpty *ConPty, argv0 string, argv []string, attr *syscall.ProcAttr) (pid int, handle uintptr, err error) { - startupInfo := &startupInfoEx{} - var attrListSize uint64 - startupInfo.startupInfo.Cb = uint32(unsafe.Sizeof(startupInfo)) - - err = initializeProcThreadAttributeList(0, 1, &attrListSize) - if err != nil { - return 0, 0, fmt.Errorf("could not retrieve list size: %v", err) - } - - attributeListBuffer := make([]byte, attrListSize) - startupInfo.lpAttributeList = windows.Handle(unsafe.Pointer(&attributeListBuffer[0])) - - err = initializeProcThreadAttributeList(uintptr(startupInfo.lpAttributeList), 1, &attrListSize) - if err != nil { - return 0, 0, fmt.Errorf("failed to initialize proc thread attributes for conpty: %v", err) - } - - err = updateProcThreadAttributeList( - startupInfo.lpAttributeList, - procThreadAttributePseudoconsole, - conpty.hpCon, - unsafe.Sizeof(conpty.hpCon)) - if err != nil { - return 0, 0, fmt.Errorf("failed to update proc thread attributes attributes for conpty usage: %v", err) - } - - if attr == nil { - attr = &syscall.ProcAttr{} - } - - if len(attr.Dir) != 0 { - // StartProcess assumes that argv0 is relative to attr.Dir, - // because it implies Chdir(attr.Dir) before executing argv0. - // Windows CreateProcess assumes the opposite: it looks for - // argv0 relative to the current directory, and, only once the new - // process is started, it does Chdir(attr.Dir). We are adjusting - // for that difference here by making argv0 absolute. - var err error - argv0, err = joinExeDirAndFName(attr.Dir, argv0) - if err != nil { - return 0, 0, err - } - } - argv0p, err := windows.UTF16PtrFromString(argv0) - if err != nil { - return 0, 0, err - } - - // Windows CreateProcess takes the command line as a single string: - // use attr.CmdLine if set, else build the command line by escaping - // and joining each argument with spaces - cmdline := makeCmdLine(argv) - - var argvp *uint16 - if len(cmdline) != 0 { - argvp, err = windows.UTF16PtrFromString(cmdline) - if err != nil { - return 0, 0, fmt.Errorf("utf ptr from string: %w", err) - } - } - - var dirp *uint16 - if len(attr.Dir) != 0 { - dirp, err = windows.UTF16PtrFromString(attr.Dir) - if err != nil { - return 0, 0, fmt.Errorf("utf ptr from string: %w", err) - } - } - - startupInfo.startupInfo.Flags = windows.STARTF_USESTDHANDLES - - pi := new(windows.ProcessInformation) - - flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | extendedStartupinfoPresent - - var zeroSec windows.SecurityAttributes - pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} - tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} - - // c.startupInfo.startupInfo.Cb = uint32(unsafe.Sizeof(c.startupInfo)) - err = windows.CreateProcess( - argv0p, - argvp, - pSec, // process handle not inheritable - tSec, // thread handles not inheritable, - false, - flags, - createEnvBlock(addCriticalEnv(dedupEnvCase(true, attr.Env))), - dirp, // use current directory later: dirp, - &startupInfo.startupInfo, - pi) - - if err != nil { - return 0, 0, fmt.Errorf("create process: %w", err) - } - defer windows.CloseHandle(windows.Handle(pi.Thread)) - - return int(pi.ProcessId), uintptr(pi.Process), nil -} - -// makeCmdLine builds a command line out of args by escaping "special" -// characters and joining the arguments with spaces. -func makeCmdLine(args []string) string { - var s string - for _, v := range args { - if s != "" { - s += " " - } - s += windows.EscapeArg(v) - } - return s -} - -func isSlash(c uint8) bool { - return c == '\\' || c == '/' -} - -func normalizeDir(dir string) (name string, err error) { - ndir, err := syscall.FullPath(dir) - if err != nil { - return "", err - } - if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) { - // dir cannot have \\server\share\path form - return "", syscall.EINVAL - } - return ndir, nil -} - -func volToUpper(ch int) int { - if 'a' <= ch && ch <= 'z' { - ch += 'A' - 'a' - } - return ch -} - -func joinExeDirAndFName(dir, p string) (name string, err error) { - if len(p) == 0 { - return "", syscall.EINVAL - } - if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) { - // \\server\share\path form - return p, nil - } - if len(p) > 1 && p[1] == ':' { - // has drive letter - if len(p) == 2 { - return "", syscall.EINVAL - } - if isSlash(p[2]) { - return p, nil - } else { - d, err := normalizeDir(dir) - if err != nil { - return "", err - } - if volToUpper(int(p[0])) == volToUpper(int(d[0])) { - return syscall.FullPath(d + "\\" + p[2:]) - } else { - return syscall.FullPath(p) - } - } - } else { - // no drive letter - d, err := normalizeDir(dir) - if err != nil { - return "", err - } - if isSlash(p[0]) { - return windows.FullPath(d[:2] + p) - } else { - return windows.FullPath(d + "\\" + p) - } - } -} - -// createEnvBlock converts an array of environment strings into -// the representation required by CreateProcess: a sequence of NUL -// terminated strings followed by a nil. -// Last bytes are two UCS-2 NULs, or four NUL bytes. -func createEnvBlock(envv []string) *uint16 { - if len(envv) == 0 { - return &utf16.Encode([]rune("\x00\x00"))[0] - } - length := 0 - for _, s := range envv { - length += len(s) + 1 - } - length += 1 - - b := make([]byte, length) - i := 0 - for _, s := range envv { - l := len(s) - copy(b[i:i+l], []byte(s)) - copy(b[i+l:i+l+1], []byte{0}) - i = i + l + 1 - } - copy(b[i:i+1], []byte{0}) - - return &utf16.Encode([]rune(string(b)))[0] -} - -// dedupEnvCase is dedupEnv with a case option for testing. -// If caseInsensitive is true, the case of keys is ignored. -func dedupEnvCase(caseInsensitive bool, env []string) []string { - out := make([]string, 0, len(env)) - saw := make(map[string]int, len(env)) // key => index into out - for _, kv := range env { - eq := strings.Index(kv, "=") - if eq < 0 { - out = append(out, kv) - continue - } - k := kv[:eq] - if caseInsensitive { - k = strings.ToLower(k) - } - if dupIdx, isDup := saw[k]; isDup { - out[dupIdx] = kv - continue - } - saw[k] = len(out) - out = append(out, kv) - } - return out -} - -// addCriticalEnv adds any critical environment variables that are required -// (or at least almost always required) on the operating system. -// Currently this is only used for Windows. -func addCriticalEnv(env []string) []string { - for _, kv := range env { - eq := strings.Index(kv, "=") - if eq < 0 { - continue - } - k := kv[:eq] - if strings.EqualFold(k, "SYSTEMROOT") { - // We already have it. - return env - } - } - return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) -} diff --git a/console/conpty/syscall.go b/console/conpty/syscall.go deleted file mode 100644 index 39a8904426648..0000000000000 --- a/console/conpty/syscall.go +++ /dev/null @@ -1,120 +0,0 @@ -//go:build windows -// +build windows - -// Copyright 2020 ActiveState Software. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file - -package conpty - -import ( - "unsafe" - - "golang.org/x/sys/windows" -) - -var ( - kernel32 = windows.NewLazySystemDLL("kernel32.dll") - procResizePseudoConsole = kernel32.NewProc("ResizePseudoConsole") - procCreatePseudoConsole = kernel32.NewProc("CreatePseudoConsole") - procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole") - - // Required for executing processes! - procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList") - procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute") - procLocalAlloc = kernel32.NewProc("LocalAlloc") - procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList") - procCreateProcessW = kernel32.NewProc("CreateProcessW") -) - -// An extended version of process startup information that points -// to a pseudo terminal object. -type startupInfoEx struct { - startupInfo windows.StartupInfo - lpAttributeList windows.Handle -} - -// Constant in CreateProcessW indicating that extended startup information is present. -const extendedStartupinfoPresent uint32 = 0x00080000 - -type procThreadAttribute uintptr - -// windows constant needed during initialization of extended startupinfo -const procThreadAttributePseudoconsole procThreadAttribute = 22 | 0x00020000 - -func initializeProcThreadAttributeList(attributeList uintptr, attributeCount uint32, listSize *uint64) (err error) { - if attributeList == 0 { - procInitializeProcThreadAttributeList.Call(0, uintptr(attributeCount), 0, uintptr(unsafe.Pointer(listSize))) - return - } - r1, _, e1 := procInitializeProcThreadAttributeList.Call(attributeList, uintptr(attributeCount), 0, uintptr(unsafe.Pointer(listSize))) - - if r1 == 0 { // boolean FALSE - err = e1 - } - - return -} - -func updateProcThreadAttributeList(attributeList windows.Handle, attribute procThreadAttribute, lpValue windows.Handle, lpSize uintptr) (err error) { - - r1, _, e1 := procUpdateProcThreadAttribute.Call(uintptr(attributeList), 0, uintptr(attribute), uintptr(lpValue), lpSize, 0, 0) - - if r1 == 0 { // boolean FALSE - err = e1 - } - - return -} -func deleteProcThreadAttributeList(handle windows.Handle) (err error) { - r1, _, e1 := procDeleteProcThreadAttributeList.Call(uintptr(handle)) - - if r1 == 0 { // boolean FALSE - err = e1 - } - - return -} - -func localAlloc(size uint64) (ptr windows.Handle, err error) { - r1, _, e1 := procLocalAlloc.Call(uintptr(0x0040), uintptr(size)) - if r1 == 0 { - err = e1 - ptr = windows.InvalidHandle - return - } - ptr = windows.Handle(r1) - return -} - -func createPseudoConsole(consoleSize uintptr, ptyIn windows.Handle, ptyOut windows.Handle, hpCon *windows.Handle) (err error) { - r1, _, e1 := procCreatePseudoConsole.Call( - consoleSize, - uintptr(ptyIn), - uintptr(ptyOut), - 0, - uintptr(unsafe.Pointer(hpCon)), - ) - - if r1 != 0 { // !S_OK - err = e1 - } - return -} - -func resizePseudoConsole(handle windows.Handle, consoleSize uintptr) (err error) { - r1, _, e1 := procResizePseudoConsole.Call(uintptr(handle), consoleSize) - if r1 != 0 { // !S_OK - err = e1 - } - return -} - -func closePseudoConsole(handle windows.Handle) (err error) { - r1, _, e1 := procClosePseudoConsole.Call(uintptr(handle)) - if r1 == 0 { - err = e1 - } - - return -} diff --git a/console/console.go b/console/console.go deleted file mode 100644 index 4b7daf52d6613..0000000000000 --- a/console/console.go +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console - -import ( - "bufio" - "fmt" - "io" - "io/ioutil" - "log" - "unicode/utf8" - - "github.com/coder/coder/console/pty" -) - -// Console is an interface to automate input and output for interactive -// applications. Console can block until a specified output is received and send -// input back on it's tty. Console can also multiplex other sources of input -// and multiplex its output to other writers. -type Console struct { - opts Opts - pty pty.Pty - runeReader *bufio.Reader - closers []io.Closer -} - -// Opt allows setting Console options. -type Opt func(*Opts) error - -// Opts provides additional options on creating a Console. -type Opts struct { - Logger *log.Logger - Stdouts []io.Writer - ExpectObservers []Observer -} - -// Observer provides an interface for a function callback that will -// be called after each Expect operation. -// matchers will be the list of active matchers when an error occurred, -// or a list of matchers that matched `buf` when err is nil. -// buf is the captured output that was matched against. -// err is error that might have occurred. May be nil. -type Observer func(matchers []Matcher, buf string, err error) - -// WithStdout adds writers that Console duplicates writes to, similar to the -// Unix tee(1) command. -// -// Each write is written to each listed writer, one at a time. Console is the -// last writer, writing to it's internal buffer for matching expects. -// If a listed writer returns an error, that overall write operation stops and -// returns the error; it does not continue down the list. -func WithStdout(writers ...io.Writer) Opt { - return func(opts *Opts) error { - opts.Stdouts = append(opts.Stdouts, writers...) - return nil - } -} - -// WithLogger adds a logger for Console to log debugging information to. By -// default Console will discard logs. -func WithLogger(logger *log.Logger) Opt { - return func(opts *Opts) error { - opts.Logger = logger - return nil - } -} - -// WithExpectObserver adds an ExpectObserver to allow monitoring Expect operations. -func WithExpectObserver(observers ...Observer) Opt { - return func(opts *Opts) error { - opts.ExpectObservers = append(opts.ExpectObservers, observers...) - return nil - } -} - -// NewWithOptions returns a new Console with the given options. -func NewWithOptions(opts ...Opt) (*Console, error) { - options := Opts{ - Logger: log.New(ioutil.Discard, "", 0), - } - - for _, opt := range opts { - if err := opt(&options); err != nil { - return nil, err - } - } - - consolePty, err := pty.New() - if err != nil { - return nil, err - } - closers := []io.Closer{consolePty} - reader := consolePty.Output() - - cons := &Console{ - opts: options, - pty: consolePty, - runeReader: bufio.NewReaderSize(reader, utf8.UTFMax), - closers: closers, - } - - return cons, nil -} - -// Tty returns an input Tty for accepting input -func (c *Console) Input() io.ReadWriter { - return c.pty.Input() -} - -// OutTty returns an output tty for writing -func (c *Console) Output() io.ReadWriter { - return c.pty.Output() -} - -// Close closes Console's tty. Calling Close will unblock Expect and ExpectEOF. -func (c *Console) Close() error { - for _, fd := range c.closers { - err := fd.Close() - if err != nil { - c.Logf("failed to close: %s", err) - } - } - return nil -} - -// Send writes string s to Console's tty. -func (c *Console) Send(s string) (int, error) { - c.Logf("console send: %q", s) - n, err := c.pty.Input().Write([]byte(s)) - return n, err -} - -// SendLine writes string s to Console's tty with a trailing newline. -func (c *Console) SendLine(s string) (int, error) { - bytes, err := c.Send(fmt.Sprintf("%s\n", s)) - - return bytes, err -} - -// Log prints to Console's logger. -// Arguments are handled in the manner of fmt.Print. -func (c *Console) Log(v ...interface{}) { - c.opts.Logger.Print(v...) -} - -// Logf prints to Console's logger. -// Arguments are handled in the manner of fmt.Printf. -func (c *Console) Logf(format string, v ...interface{}) { - c.opts.Logger.Printf(format, v...) -} diff --git a/console/doc.go b/console/doc.go deleted file mode 100644 index 7a5fc545cd982..0000000000000 --- a/console/doc.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package expect provides an expect-like interface to automate control of -// applications. It is unlike expect in that it does not spawn or manage -// process lifecycle. This package only focuses on expecting output and sending -// input through it's psuedoterminal. -package console diff --git a/console/expect.go b/console/expect.go deleted file mode 100644 index c2e3f583b0a06..0000000000000 --- a/console/expect.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console - -import ( - "bufio" - "bytes" - "fmt" - "io" - "unicode/utf8" -) - -// Expectf reads from the Console's tty until the provided formatted string -// is read or an error occurs, and returns the buffer read by Console. -func (c *Console) Expectf(format string, args ...interface{}) (string, error) { - return c.Expect(String(fmt.Sprintf(format, args...))) -} - -// ExpectString reads from Console's tty until the provided string is read or -// an error occurs, and returns the buffer read by Console. -func (c *Console) ExpectString(s string) (string, error) { - return c.Expect(String(s)) -} - -// Expect reads from Console's tty until a condition specified from opts is -// encountered or an error occurs, and returns the buffer read by console. -// No extra bytes are read once a condition is met, so if a program isn't -// expecting input yet, it will be blocked. Sends are queued up in tty's -// internal buffer so that the next Expect will read the remaining bytes (i.e. -// rest of prompt) as well as its conditions. -func (c *Console) Expect(opts ...ExpectOpt) (string, error) { - var options ExpectOpts - for _, opt := range opts { - if err := opt(&options); err != nil { - return "", err - } - } - - buf := new(bytes.Buffer) - writer := io.MultiWriter(append(c.opts.Stdouts, buf)...) - runeWriter := bufio.NewWriterSize(writer, utf8.UTFMax) - - var matcher Matcher - var err error - - defer func() { - for _, observer := range c.opts.ExpectObservers { - if matcher != nil { - observer([]Matcher{matcher}, buf.String(), err) - return - } - observer(options.Matchers, buf.String(), err) - } - }() - - for { - var r rune - r, _, err = c.runeReader.ReadRune() - if err != nil { - matcher = options.Match(err) - if matcher != nil { - err = nil - break - } - return buf.String(), err - } - - c.Logf("expect read: %q", string(r)) - _, err = runeWriter.WriteRune(r) - if err != nil { - return buf.String(), err - } - - // Immediately flush rune to the underlying writers. - err = runeWriter.Flush() - if err != nil { - return buf.String(), err - } - - matcher = options.Match(buf) - if matcher != nil { - break - } - } - - if matcher != nil { - cb, ok := matcher.(CallbackMatcher) - if ok { - err = cb.Callback(buf) - if err != nil { - return buf.String(), err - } - } - } - - return buf.String(), err -} diff --git a/console/expect_opt.go b/console/expect_opt.go deleted file mode 100644 index fec0d9b8f3e0b..0000000000000 --- a/console/expect_opt.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console - -import ( - "bytes" - "strings" - "time" -) - -// ExpectOpt allows settings Expect options. -type ExpectOpt func(*ExpectOpts) error - -// Callback is a callback function to execute if a match is found for -// the chained matcher. -type Callback func(buf *bytes.Buffer) error - -// ExpectOpts provides additional options on Expect. -type ExpectOpts struct { - Matchers []Matcher - ReadTimeout *time.Duration -} - -// Match sequentially calls Match on all matchers in ExpectOpts and returns the -// first matcher if a match exists, otherwise nil. -func (eo ExpectOpts) Match(v interface{}) Matcher { - for _, matcher := range eo.Matchers { - if matcher.Match(v) { - return matcher - } - } - return nil -} - -// CallbackMatcher is a matcher that provides a Callback function. -type CallbackMatcher interface { - // Callback executes the matcher's callback with the content buffer at the - // time of match. - Callback(buf *bytes.Buffer) error -} - -// Matcher provides an interface for finding a match in content read from -// Console's tty. -type Matcher interface { - // Match returns true iff a match is found. - Match(v interface{}) bool - Criteria() interface{} -} - -// stringMatcher fulfills the Matcher interface to match strings against a given -// bytes.Buffer. -type stringMatcher struct { - str string -} - -func (sm *stringMatcher) Match(v interface{}) bool { - buf, ok := v.(*bytes.Buffer) - if !ok { - return false - } - if strings.Contains(buf.String(), sm.str) { - return true - } - return false -} - -func (sm *stringMatcher) Criteria() interface{} { - return sm.str -} - -// allMatcher fulfills the Matcher interface to match a group of ExpectOpt -// against any value. -type allMatcher struct { - options ExpectOpts -} - -func (am *allMatcher) Match(v interface{}) bool { - var matchers []Matcher - for _, matcher := range am.options.Matchers { - if matcher.Match(v) { - continue - } - matchers = append(matchers, matcher) - } - - am.options.Matchers = matchers - return len(matchers) == 0 -} - -func (am *allMatcher) Criteria() interface{} { - var criteria []interface{} - for _, matcher := range am.options.Matchers { - criteria = append(criteria, matcher.Criteria()) - } - return criteria -} - -// All adds an Expect condition to exit if the content read from Console's tty -// matches all of the provided ExpectOpt, in any order. -func All(expectOpts ...ExpectOpt) ExpectOpt { - return func(opts *ExpectOpts) error { - var options ExpectOpts - for _, opt := range expectOpts { - if err := opt(&options); err != nil { - return err - } - } - - opts.Matchers = append(opts.Matchers, &allMatcher{ - options: options, - }) - return nil - } -} - -// String adds an Expect condition to exit if the content read from Console's -// tty contains any of the given strings. -func String(strs ...string) ExpectOpt { - return func(opts *ExpectOpts) error { - for _, str := range strs { - opts.Matchers = append(opts.Matchers, &stringMatcher{ - str: str, - }) - } - return nil - } -} diff --git a/console/expect_opt_test.go b/console/expect_opt_test.go deleted file mode 100644 index 91efc935fca4e..0000000000000 --- a/console/expect_opt_test.go +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console_test - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/require" - - . "github.com/coder/coder/console" -) - -func TestExpectOptString(t *testing.T) { - t.Parallel() - - tests := []struct { - title string - opt ExpectOpt - data string - expected bool - }{ - { - "No args", - String(), - "Hello world", - false, - }, - { - "Single arg", - String("Hello"), - "Hello world", - true, - }, - { - "Multiple arg", - String("other", "world"), - "Hello world", - true, - }, - { - "No matches", - String("hello"), - "Hello world", - false, - }, - } - - for _, test := range tests { - test := test - t.Run(test.title, func(t *testing.T) { - t.Parallel() - - var options ExpectOpts - err := test.opt(&options) - require.Nil(t, err) - - buf := new(bytes.Buffer) - _, err = buf.WriteString(test.data) - require.Nil(t, err) - - matcher := options.Match(buf) - if test.expected { - require.NotNil(t, matcher) - } else { - require.Nil(t, matcher) - } - }) - } -} - -func TestExpectOptAll(t *testing.T) { - t.Parallel() - - tests := []struct { - title string - opt ExpectOpt - data string - expected bool - }{ - { - "No opts", - All(), - "Hello world", - true, - }, - { - "Single string match", - All(String("Hello")), - "Hello world", - true, - }, - { - "Single string no match", - All(String("Hello")), - "No match", - false, - }, - { - "Ordered strings match", - All(String("Hello"), String("world")), - "Hello world", - true, - }, - { - "Ordered strings not all match", - All(String("Hello"), String("world")), - "Hello", - false, - }, - { - "Unordered strings", - All(String("world"), String("Hello")), - "Hello world", - true, - }, - { - "Unordered strings not all match", - All(String("world"), String("Hello")), - "Hello", - false, - }, - { - "Repeated strings match", - All(String("Hello"), String("Hello")), - "Hello world", - true, - }, - } - - for _, test := range tests { - test := test - t.Run(test.title, func(t *testing.T) { - t.Parallel() - var options ExpectOpts - err := test.opt(&options) - require.Nil(t, err) - - buf := new(bytes.Buffer) - _, err = buf.WriteString(test.data) - require.Nil(t, err) - - matcher := options.Match(buf) - if test.expected { - require.NotNil(t, matcher) - } else { - require.Nil(t, matcher) - } - }) - } -} diff --git a/console/expect_test.go b/console/expect_test.go deleted file mode 100644 index 8fcefbf043234..0000000000000 --- a/console/expect_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console_test - -import ( - "bufio" - "errors" - "fmt" - "io" - "runtime/debug" - "strings" - "sync" - "testing" - - "golang.org/x/xerrors" - - . "github.com/coder/coder/console" -) - -var ( - ErrWrongAnswer = xerrors.New("wrong answer") -) - -type Survey struct { - Prompt string - Answer string -} - -func Prompt(in io.Reader, out io.Writer) error { - reader := bufio.NewReader(in) - - for _, survey := range []Survey{ - { - "What is 1+1?", "2", - }, - { - "What is Netflix backwards?", "xilfteN", - }, - } { - _, err := fmt.Fprintf(out, "%s: ", survey.Prompt) - if err != nil { - return err - } - text, err := reader.ReadString('\n') - if err != nil { - return err - } - - _, err = fmt.Fprint(out, text) - if err != nil { - return err - } - text = strings.TrimSpace(text) - if text != survey.Answer { - return ErrWrongAnswer - } - } - - return nil -} - -func newTestConsole(t *testing.T, opts ...Opt) (*Console, error) { - opts = append([]Opt{ - expectNoError(t), - }, opts...) - return NewWithOptions(opts...) -} - -func expectNoError(t *testing.T) Opt { - return WithExpectObserver( - func(matchers []Matcher, buf string, err error) { - if err == nil { - return - } - if len(matchers) == 0 { - t.Fatalf("Error occurred while matching %q: %s\n%s", buf, err, string(debug.Stack())) - } else { - var criteria []string - for _, matcher := range matchers { - criteria = append(criteria, fmt.Sprintf("%q", matcher.Criteria())) - } - t.Fatalf("Failed to find [%s] in %q: %s\n%s", strings.Join(criteria, ", "), buf, err, string(debug.Stack())) - } - }, - ) -} - -func testCloser(t *testing.T, closer io.Closer) { - if err := closer.Close(); err != nil { - t.Errorf("Close failed: %s", err) - debug.PrintStack() - } -} - -func TestExpectf(t *testing.T) { - t.Parallel() - - console, err := newTestConsole(t) - if err != nil { - t.Errorf("Expected no error but got'%s'", err) - } - defer testCloser(t, console) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - console.Expectf("What is 1+%d?", 1) - console.SendLine("2") - console.Expectf("What is %s backwards?", "Netflix") - console.SendLine("xilfteN") - }() - - err = Prompt(console.Input(), console.Output()) - if err != nil { - t.Errorf("Expected no error but got '%s'", err) - } - wg.Wait() -} - -func TestExpect(t *testing.T) { - t.Parallel() - - console, err := newTestConsole(t) - if err != nil { - t.Errorf("Expected no error but got'%s'", err) - } - defer testCloser(t, console) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - console.ExpectString("What is 1+1?") - console.SendLine("2") - console.ExpectString("What is Netflix backwards?") - console.SendLine("xilfteN") - }() - - err = Prompt(console.Input(), console.Output()) - if err != nil { - t.Errorf("Expected no error but got '%s'", err) - } - wg.Wait() -} - -func TestExpectOutput(t *testing.T) { - t.Parallel() - - console, err := newTestConsole(t) - if err != nil { - t.Errorf("Expected no error but got'%s'", err) - } - defer testCloser(t, console) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - console.ExpectString("What is 1+1?") - console.SendLine("3") - }() - - err = Prompt(console.Input(), console.Output()) - if err == nil || !errors.Is(err, ErrWrongAnswer) { - t.Errorf("Expected error '%s' but got '%s' instead", ErrWrongAnswer, err) - } - wg.Wait() -} diff --git a/console/pty/pty.go b/console/pty/pty.go deleted file mode 100644 index cf4b4849454ec..0000000000000 --- a/console/pty/pty.go +++ /dev/null @@ -1,78 +0,0 @@ -package pty - -import ( - "io" - "os" -) - -// Pty is the minimal pseudo-tty interface we require. -type Pty interface { - Input() io.ReadWriter - Output() io.ReadWriter - Resize(cols uint16, rows uint16) error - Close() error -} - -// New creates a new Pty. -func New() (Pty, error) { - return newPty() -} - -func pipePty() (Pty, error) { - inFilePipeSide, inFileOurSide, err := os.Pipe() - if err != nil { - return nil, err - } - - outFileOurSide, outFilePipeSide, err := os.Pipe() - if err != nil { - return nil, err - } - - return &pipePtyVal{ - inFilePipeSide, - inFileOurSide, - outFileOurSide, - outFilePipeSide, - }, nil -} - -type pipePtyVal struct { - inFilePipeSide, inFileOurSide *os.File - outFileOurSide, outFilePipeSide *os.File -} - -func (p *pipePtyVal) Output() io.ReadWriter { - return readWriter{ - Reader: p.outFilePipeSide, - Writer: p.outFileOurSide, - } -} - -func (p *pipePtyVal) Input() io.ReadWriter { - return readWriter{ - Reader: p.inFilePipeSide, - Writer: p.inFileOurSide, - } -} - -func (p *pipePtyVal) WriteString(str string) (int, error) { - return p.inFileOurSide.WriteString(str) -} - -func (p *pipePtyVal) Resize(uint16, uint16) error { - return nil -} - -func (p *pipePtyVal) Close() error { - p.inFileOurSide.Close() - p.inFilePipeSide.Close() - p.outFilePipeSide.Close() - p.outFileOurSide.Close() - return nil -} - -type readWriter struct { - io.Reader - io.Writer -} diff --git a/console/pty/run.go b/console/pty/run.go deleted file mode 100644 index e6c8e234178be..0000000000000 --- a/console/pty/run.go +++ /dev/null @@ -1,7 +0,0 @@ -package pty - -import "os/exec" - -func Run(cmd *exec.Cmd) (Pty, error) { - return runPty(cmd) -} diff --git a/console/pty/start_test.go b/console/pty/start_test.go deleted file mode 100644 index 2ad525eefb1c4..0000000000000 --- a/console/pty/start_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package pty_test - -import ( - "fmt" - "os/exec" - "regexp" - "testing" - - "github.com/coder/coder/console/pty" - "github.com/stretchr/testify/require" -) - -var ( - // Used to ensure terminal output doesn't have anything crazy! - // See: https://stackoverflow.com/a/29497680 - stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") -) - -func TestStart(t *testing.T) { - t.Run("Do", func(t *testing.T) { - pty, err := pty.Run(exec.Command("powershell.exe", "echo", "test")) - require.NoError(t, err) - data := make([]byte, 128) - _, err = pty.Output().Read(data) - require.NoError(t, err) - t.Log(fmt.Sprintf("%q", stripAnsi.ReplaceAllString(string(data), ""))) - }) -} diff --git a/console/test_console.go b/console/test_console.go deleted file mode 100644 index 961ee63589353..0000000000000 --- a/console/test_console.go +++ /dev/null @@ -1,45 +0,0 @@ -package console - -import ( - "bufio" - "io" - "regexp" - "testing" - - "github.com/spf13/cobra" - "github.com/stretchr/testify/require" -) - -var ( - // Used to ensure terminal output doesn't have anything crazy! - // See: https://stackoverflow.com/a/29497680 - stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") -) - -// New creates a new TTY bound to the command provided. -// All ANSI escape codes are stripped to provide clean output. -func New(t *testing.T, cmd *cobra.Command) *Console { - reader, writer := io.Pipe() - scanner := bufio.NewScanner(reader) - t.Cleanup(func() { - _ = reader.Close() - _ = writer.Close() - }) - go func() { - for scanner.Scan() { - if scanner.Err() != nil { - return - } - t.Log(stripAnsi.ReplaceAllString(scanner.Text(), "")) - } - }() - - console, err := NewWithOptions(WithStdout(writer)) - require.NoError(t, err) - t.Cleanup(func() { - console.Close() - }) - cmd.SetIn(console.Input()) - cmd.SetOut(console.Output()) - return console -} diff --git a/pty/pty.go b/pty/pty.go new file mode 100644 index 0000000000000..c1e4a092bdb79 --- /dev/null +++ b/pty/pty.go @@ -0,0 +1,31 @@ +package pty + +import ( + "io" +) + +// PTY is a minimal interface for interacting with a TTY. +type PTY interface { + io.Closer + // Output handles PTY output. + // + // cmd.SetOutput(pty.Output()) would be used to specify a command + // uses the output stream for writing. + // + // The same stream could be read to validate output. + Output() io.ReadWriter + // Input handles PTY input. + // + // cmd.SetInput(pty.Input()) would be used to specify a command + // uses the PTY input for reading. + // + // The same stream would be used to provide user input: pty.Input().Write(...) + Input() io.ReadWriter + // Resize sets the size of the PTY. + Resize(cols uint16, rows uint16) error +} + +// New constructs a new Pty. +func New() (PTY, error) { + return newPty() +} diff --git a/console/pty/pty_other.go b/pty/pty_other.go similarity index 79% rename from console/pty/pty_other.go rename to pty/pty_other.go index 7149b344c6fcc..bb71286bead63 100644 --- a/console/pty/pty_other.go +++ b/pty/pty_other.go @@ -26,19 +26,11 @@ type otherPty struct { pty, tty *os.File } -func (p *otherPty) InPipe() *os.File { - return p.tty -} - -func (p *otherPty) OutPipe() *os.File { - return p.tty -} - -func (p *otherPty) Reader() io.Reader { +func (p *otherPty) Input() io.ReadWriter { return p.pty } -func (p *otherPty) Writer() io.Writer { +func (p *otherPty) Output() io.ReadWriter { return p.pty } diff --git a/console/pty/pty_windows.go b/pty/pty_windows.go similarity index 55% rename from console/pty/pty_windows.go rename to pty/pty_windows.go index 4e2430772e6b3..e01117e98bb9c 100644 --- a/console/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -22,77 +22,67 @@ var ( ) // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session -func newPty() (Pty, error) { +func newPty() (PTY, error) { // We use the CreatePseudoConsole API which was introduced in build 17763 vsn := windows.RtlGetVersion() if vsn.MajorVersion < 10 || vsn.BuildNumber < 17763 { // If the CreatePseudoConsole API is not available, we fall back to a simpler // implementation that doesn't create an actual PTY - just uses os.Pipe - return pipePty() + return nil, xerrors.Errorf("pty not supported") } ptyWindows := &ptyWindows{} - // Create the stdin pipe - if err := windows.CreatePipe(&ptyWindows.inputReadSide, &ptyWindows.inputWriteSide, nil, 0); err != nil { + var err error + ptyWindows.inputRead, ptyWindows.inputWrite, err = os.Pipe() + if err != nil { return nil, err } + ptyWindows.outputRead, ptyWindows.outputWrite, err = os.Pipe() - // Create the stdout pipe - if err := windows.CreatePipe(&ptyWindows.outputReadSide, &ptyWindows.outputWriteSide, nil, 0); err != nil { - return nil, err - } - - consoleSize := uintptr((int32(80) << 16) | int32(80)) + consoleSize := uintptr((int32(20) << 16) | int32(20)) ret, _, err := procCreatePseudoConsole.Call( consoleSize, - uintptr(ptyWindows.inputReadSide), - uintptr(ptyWindows.outputWriteSide), + uintptr(ptyWindows.inputRead.Fd()), + uintptr(ptyWindows.outputWrite.Fd()), 0, uintptr(unsafe.Pointer(&ptyWindows.console)), ) - if ret != 0 { - return nil, xerrors.Errorf("create pseudo console (%d): %w", ret, err) + if int32(ret) < 0 { + return nil, xerrors.Errorf("create pseudo console (%d): %w", int32(ret), err) } - - ptyWindows.outputWriteSideFile = os.NewFile(uintptr(ptyWindows.outputWriteSide), "|0") - ptyWindows.outputReadSideFile = os.NewFile(uintptr(ptyWindows.outputReadSide), "|1") - ptyWindows.inputReadSideFile = os.NewFile(uintptr(ptyWindows.inputReadSide), "|2") - ptyWindows.inputWriteSideFile = os.NewFile(uintptr(ptyWindows.inputWriteSide), "|3") - ptyWindows.closed = false - return ptyWindows, nil } type ptyWindows struct { console windows.Handle - outputWriteSide windows.Handle - outputReadSide windows.Handle - inputReadSide windows.Handle - inputWriteSide windows.Handle - - outputWriteSideFile *os.File - outputReadSideFile *os.File - inputReadSideFile *os.File - inputWriteSideFile *os.File + outputWrite *os.File + outputRead *os.File + inputWrite *os.File + inputRead *os.File closeMutex sync.Mutex closed bool } -func (p *ptyWindows) Input() io.ReadWriter { +type readWriter struct { + io.Reader + io.Writer +} + +func (p *ptyWindows) Output() io.ReadWriter { return readWriter{ - Writer: p.inputWriteSideFile, - Reader: p.inputReadSideFile, + Reader: p.outputRead, + Writer: p.outputWrite, } } -func (p *ptyWindows) Output() io.ReadWriter { +func (p *ptyWindows) Input() io.ReadWriter { return readWriter{ - Writer: p.outputWriteSideFile, - Reader: p.outputReadSideFile, + Reader: p.inputRead, + Writer: p.inputWrite, } } @@ -116,9 +106,7 @@ func (p *ptyWindows) Close() error { if ret != 0 { return xerrors.Errorf("close pseudo console: %w", err) } - _ = p.outputWriteSideFile.Close() - _ = p.outputReadSideFile.Close() - _ = p.inputReadSideFile.Close() - _ = p.inputWriteSideFile.Close() + _ = p.outputRead.Close() + _ = p.inputWrite.Close() return nil } diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go new file mode 100644 index 0000000000000..555e3f6d723be --- /dev/null +++ b/pty/ptytest/ptytest.go @@ -0,0 +1,93 @@ +package ptytest + +import ( + "bufio" + "bytes" + "io" + "os/exec" + "regexp" + "strings" + "testing" + "unicode/utf8" + + "github.com/coder/coder/pty" + "github.com/stretchr/testify/require" +) + +var ( + // Used to ensure terminal output doesn't have anything crazy! + // See: https://stackoverflow.com/a/29497680 + stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") +) + +func New(t *testing.T) *PTY { + pty, err := pty.New() + require.NoError(t, err) + return create(t, pty) +} + +func Start(t *testing.T, cmd *exec.Cmd) *PTY { + pty, err := pty.Start(cmd) + require.NoError(t, err) + return create(t, pty) +} + +func create(t *testing.T, pty pty.PTY) *PTY { + reader, writer := io.Pipe() + scanner := bufio.NewScanner(reader) + t.Cleanup(func() { + _ = reader.Close() + _ = writer.Close() + }) + go func() { + for scanner.Scan() { + if scanner.Err() != nil { + return + } + t.Log(stripAnsi.ReplaceAllString(scanner.Text(), "")) + } + }() + + t.Cleanup(func() { + _ = pty.Close() + }) + return &PTY{ + t: t, + PTY: pty, + + outputWriter: writer, + runeReader: bufio.NewReaderSize(pty.Output(), utf8.UTFMax), + } +} + +type PTY struct { + t *testing.T + pty.PTY + + outputWriter io.Writer + runeReader *bufio.Reader +} + +func (p *PTY) ExpectMatch(str string) string { + var buffer bytes.Buffer + multiWriter := io.MultiWriter(&buffer, p.outputWriter) + runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax) + for { + var r rune + r, _, err := p.runeReader.ReadRune() + require.NoError(p.t, err) + _, err = runeWriter.WriteRune(r) + require.NoError(p.t, err) + err = runeWriter.Flush() + require.NoError(p.t, err) + if strings.Contains(buffer.String(), str) { + break + } + } + return buffer.String() +} + +func (p *PTY) WriteLine(str string) { + _, err := p.PTY.Input().Write([]byte(str + "\n")) + require.NoError(p.t, err) +} diff --git a/pty/ptytest/ptytest_test.go b/pty/ptytest/ptytest_test.go new file mode 100644 index 0000000000000..992077cdc2200 --- /dev/null +++ b/pty/ptytest/ptytest_test.go @@ -0,0 +1,14 @@ +package ptytest_test + +import ( + "testing" + + "github.com/coder/coder/pty/ptytest" +) + +func TestPtytest(t *testing.T) { + pty := ptytest.New(t) + pty.Output().Write([]byte("write")) + pty.ExpectMatch("write") + pty.WriteLine("read") +} diff --git a/pty/start.go b/pty/start.go new file mode 100644 index 0000000000000..2b75843ee16c2 --- /dev/null +++ b/pty/start.go @@ -0,0 +1,7 @@ +package pty + +import "os/exec" + +func Start(cmd *exec.Cmd) (PTY, error) { + return startPty(cmd) +} diff --git a/console/pty/run_other.go b/pty/start_other.go similarity index 100% rename from console/pty/run_other.go rename to pty/start_other.go diff --git a/console/pty/start_windows.go b/pty/start_windows.go similarity index 58% rename from console/pty/start_windows.go rename to pty/start_windows.go index 9d163420e772b..4c9d601261039 100644 --- a/console/pty/start_windows.go +++ b/pty/start_windows.go @@ -6,12 +6,15 @@ package pty import ( "os" "os/exec" + "unicode/utf16" "unsafe" "golang.org/x/sys/windows" ) -func runPty(cmd *exec.Cmd) (Pty, error) { +// Allocates a PTY and starts the specified command attached to it. +// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process +func startPty(cmd *exec.Cmd) (PTY, error) { fullPath, err := exec.LookPath(cmd.Path) if err != nil { return nil, err @@ -44,15 +47,15 @@ func runPty(cmd *exec.Cmd) (Pty, error) { if err != nil { return nil, err } - err = attrs.Update(22|0x00020000, unsafe.Pointer(winPty.console), unsafe.Sizeof(winPty.console)) + err = attrs.Update(0x20016, unsafe.Pointer(winPty.console), unsafe.Sizeof(winPty.console)) if err != nil { return nil, err } startupInfo := &windows.StartupInfoEx{} - startupInfo.StartupInfo.Cb = uint32(unsafe.Sizeof(*startupInfo)) - startupInfo.StartupInfo.Flags = windows.STARTF_USESTDHANDLES startupInfo.ProcThreadAttributeList = attrs.List() + startupInfo.StartupInfo.Flags = windows.STARTF_USESTDHANDLES + startupInfo.StartupInfo.Cb = uint32(unsafe.Sizeof(*startupInfo)) var processInfo windows.ProcessInformation err = windows.CreateProcess( pathPtr, @@ -63,7 +66,7 @@ func runPty(cmd *exec.Cmd) (Pty, error) { // https://docs.microsoft.com/en-us/windows/win32/procthread/process-creation-flags#create_unicode_environment windows.CREATE_UNICODE_ENVIRONMENT|windows.EXTENDED_STARTUPINFO_PRESENT, // Environment variables can come later! - nil, + createEnvBlock([]string{"SYSTEMROOT=" + os.Getenv("SYSTEMROOT")}), dirPtr, &startupInfo.StartupInfo, &processInfo, @@ -71,7 +74,32 @@ func runPty(cmd *exec.Cmd) (Pty, error) { if err != nil { return nil, err } - defer windows.CloseHandle(windows.Handle(processInfo.Thread)) + defer windows.CloseHandle(processInfo.Thread) + defer windows.CloseHandle(processInfo.Process) return pty, nil } + +// Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476 +func createEnvBlock(envv []string) *uint16 { + if len(envv) == 0 { + return &utf16.Encode([]rune("\x00\x00"))[0] + } + length := 0 + for _, s := range envv { + length += len(s) + 1 + } + length += 1 + + b := make([]byte, length) + i := 0 + for _, s := range envv { + l := len(s) + copy(b[i:i+l], []byte(s)) + copy(b[i+l:i+l+1], []byte{0}) + i = i + l + 1 + } + copy(b[i:i+1], []byte{0}) + + return &utf16.Encode([]rune(string(b)))[0] +} diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go new file mode 100644 index 0000000000000..4fa10c47f5dfc --- /dev/null +++ b/pty/start_windows_test.go @@ -0,0 +1,18 @@ +//go:build windows +// +build windows + +package pty_test + +import ( + "os/exec" + "testing" + + "github.com/coder/coder/pty/ptytest" +) + +func TestStart(t *testing.T) { + t.Run("Echo", func(t *testing.T) { + pty := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) + pty.ExpectMatch("test") + }) +} diff --git a/wintest/main.go b/wintest/main.go index 12831d46c2da1..6d504b9e67f03 100644 --- a/wintest/main.go +++ b/wintest/main.go @@ -25,6 +25,19 @@ func main() { } defer terminal.Restore(0, oldState) + // if true { + // pty, err := pty.Run(exec.Command("powershell.exe")) + // if err != nil { + // panic(err) + // } + // go func() { + // _, _ = io.Copy(pty.Input(), os.Stdin) + + // }() + // _, _ = io.Copy(os.Stdout, pty.Output()) + // return + // } + t := &testing.T{} ctx := context.Background() client, server := provisionersdk.TransportPipe() @@ -61,9 +74,7 @@ func main() { sshClient := ssh.NewClient(sshConn, channels, requests) session, err := sshClient.NewSession() require.NoError(t, err) - err = session.RequestPty("", 1024, 1024, ssh.TerminalModes{ - ssh.IGNCR: 1, - }) + err = session.RequestPty("", 128, 128, ssh.TerminalModes{}) require.NoError(t, err) session.Stdin = os.Stdin session.Stdout = os.Stdout From 12ffaff17ce186b4db9400e52d57aaae27866e33 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 00:12:51 -0600 Subject: [PATCH 07/18] SSH server now works on Windows --- agent/server.go | 2 +- agent/server_test.go | 3 --- cli/clitest/clitest_test.go | 9 +++++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/agent/server.go b/agent/server.go index 746a3d6a63d57..a80a6e5879cb8 100644 --- a/agent/server.go +++ b/agent/server.go @@ -12,9 +12,9 @@ import ( "time" "cdr.dev/slog" - "github.com/coder/coder/console/pty" "github.com/coder/coder/peer" "github.com/coder/coder/peerbroker" + "github.com/coder/coder/pty" "github.com/coder/retry" "github.com/gliderlabs/ssh" diff --git a/agent/server_test.go b/agent/server_test.go index 8351a25c53257..76c33e0546d06 100644 --- a/agent/server_test.go +++ b/agent/server_test.go @@ -66,6 +66,3 @@ func TestAgent(t *testing.T) { require.NoError(t, err) }) } - -// Read + write for input -// Read + write for output diff --git a/cli/clitest/clitest_test.go b/cli/clitest/clitest_test.go index fa11db7c04e9b..b1bd908bf6128 100644 --- a/cli/clitest/clitest_test.go +++ b/cli/clitest/clitest_test.go @@ -8,7 +8,7 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/console" + "github.com/coder/coder/pty/ptytest" ) func TestMain(m *testing.M) { @@ -21,11 +21,12 @@ func TestCli(t *testing.T) { client := coderdtest.New(t) cmd, config := clitest.New(t) clitest.SetupConfig(t, client, config) - cons := console.New(t, cmd) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) go func() { err := cmd.Execute() require.NoError(t, err) }() - _, err := cons.ExpectString("coder") - require.NoError(t, err) + pty.ExpectMatch("coder") } From e9381fa4f89e6634582ab61e4b8376a0c4659211 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 14:48:49 +0000 Subject: [PATCH 08/18] Fix non-Windows --- .vscode/settings.json | 1 + agent/server_test.go | 3 +++ go.mod | 1 - go.sum | 2 -- pty/pty_other.go | 2 +- pty/start_other.go | 2 +- pty/start_other_test.go | 15 +++++++++++++++ 7 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 pty/start_other_test.go diff --git a/.vscode/settings.json b/.vscode/settings.json index 14599af9e98b6..1e3c14fae03d5 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -52,6 +52,7 @@ "protobuf", "provisionerd", "provisionersdk", + "ptytest", "retrier", "sdkproto", "stretchr", diff --git a/agent/server_test.go b/agent/server_test.go index 76c33e0546d06..4471c4ae47a54 100644 --- a/agent/server_test.go +++ b/agent/server_test.go @@ -22,6 +22,9 @@ func TestMain(m *testing.M) { } func TestAgent(t *testing.T) { + t.Skip() + return + t.Run("asd", func(t *testing.T) { ctx := context.Background() client, server := provisionersdk.TransportPipe() diff --git a/go.mod b/go.mod index 0b8bcd86e2f8d..7c86df1308668 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,6 @@ replace github.com/chzyer/readline => github.com/kylecarbs/readline v0.0.0-20220 require ( cdr.dev/slog v1.4.1 - github.com/ActiveState/termtest/conpty v0.5.0 github.com/briandowns/spinner v1.18.1 github.com/coder/retry v1.3.0 github.com/creack/pty v1.1.17 diff --git a/go.sum b/go.sum index 863c5672d1a90..24b37bc319d34 100644 --- a/go.sum +++ b/go.sum @@ -57,8 +57,6 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8= -github.com/ActiveState/termtest/conpty v0.5.0 h1:JLUe6YDs4Jw4xNPCU+8VwTpniYOGeKzQg4SM2YHQNA8= -github.com/ActiveState/termtest/conpty v0.5.0/go.mod h1:LO4208FLsxw6DcNZ1UtuGUMW+ga9PFtX4ntv8Ymg9og= github.com/Azure/azure-pipeline-go v0.2.3/go.mod h1:x841ezTBIMG6O3lAcl8ATHnsOPVl2bqk7S3ta6S6u4k= github.com/Azure/azure-sdk-for-go v16.2.1+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-storage-blob-go v0.14.0/go.mod h1:SMqIBi+SuiQH32bvyjngEewEeXoPfKMgWlBDaYf6fck= diff --git a/pty/pty_other.go b/pty/pty_other.go index bb71286bead63..74c6217a1e1ca 100644 --- a/pty/pty_other.go +++ b/pty/pty_other.go @@ -10,7 +10,7 @@ import ( "github.com/creack/pty" ) -func newPty() (Pty, error) { +func newPty() (PTY, error) { ptyFile, ttyFile, err := pty.Open() if err != nil { return nil, err diff --git a/pty/start_other.go b/pty/start_other.go index 79f40d061a788..c93b850e9fe60 100644 --- a/pty/start_other.go +++ b/pty/start_other.go @@ -9,7 +9,7 @@ import ( "github.com/creack/pty" ) -func runPty(cmd *exec.Cmd) (Pty, error) { +func startPty(cmd *exec.Cmd) (PTY, error) { pty, err := pty.Start(cmd) if err != nil { return nil, err diff --git a/pty/start_other_test.go b/pty/start_other_test.go new file mode 100644 index 0000000000000..d0796032a9614 --- /dev/null +++ b/pty/start_other_test.go @@ -0,0 +1,15 @@ +package pty_test + +import ( + "os/exec" + "testing" + + "github.com/coder/coder/pty/ptytest" +) + +func TestStart(t *testing.T) { + t.Run("Echo", func(t *testing.T) { + pty := ptytest.Start(t, exec.Command("echo", "test")) + pty.ExpectMatch("test") + }) +} From db73933d6e7e5b9b12a9922fdd6c6771524787c1 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 15:26:26 +0000 Subject: [PATCH 09/18] Fix Linux PTY render --- cli/root.go | 17 ++--------------- pty/pty.go | 12 ++++++++++-- pty/pty_other.go | 14 ++++++++------ pty/pty_windows.go | 5 ----- pty/ptytest/ptytest.go | 3 ++- pty/start_other.go | 20 ++++++++++++++++++-- 6 files changed, 40 insertions(+), 31 deletions(-) diff --git a/cli/root.go b/cli/root.go index e5747f62db961..55e2b4c1d65ef 100644 --- a/cli/root.go +++ b/cli/root.go @@ -12,7 +12,6 @@ import ( "github.com/manifoldco/promptui" "github.com/mattn/go-isatty" "github.com/spf13/cobra" - "golang.org/x/xerrors" "github.com/coder/coder/cli/config" "github.com/coder/coder/coderd" @@ -138,21 +137,9 @@ func isTTY(cmd *cobra.Command) bool { } func prompt(cmd *cobra.Command, prompt *promptui.Prompt) (string, error) { - var ok bool - reader, ok := cmd.InOrStdin().(io.Reader) - if !ok { - return "", xerrors.New("stdin must be a readcloser") - } - prompt.Stdin = readWriteCloser{ - Reader: reader, - } - - writer, ok := cmd.OutOrStdout().(io.Writer) - if !ok { - return "", xerrors.New("stdout must be a readcloser") - } + prompt.Stdin = io.NopCloser(cmd.InOrStdin()) prompt.Stdout = readWriteCloser{ - Writer: writer, + Writer: cmd.OutOrStdout(), } // The prompt library displays defaults in a jarring way for the user diff --git a/pty/pty.go b/pty/pty.go index c1e4a092bdb79..0086bfba56c15 100644 --- a/pty/pty.go +++ b/pty/pty.go @@ -7,20 +7,23 @@ import ( // PTY is a minimal interface for interacting with a TTY. type PTY interface { io.Closer - // Output handles PTY output. + + // Output handles TTY output. // // cmd.SetOutput(pty.Output()) would be used to specify a command // uses the output stream for writing. // // The same stream could be read to validate output. Output() io.ReadWriter - // Input handles PTY input. + + // Input handles TTY input. // // cmd.SetInput(pty.Input()) would be used to specify a command // uses the PTY input for reading. // // The same stream would be used to provide user input: pty.Input().Write(...) Input() io.ReadWriter + // Resize sets the size of the PTY. Resize(cols uint16, rows uint16) error } @@ -29,3 +32,8 @@ type PTY interface { func New() (PTY, error) { return newPty() } + +type readWriter struct { + io.Reader + io.Writer +} diff --git a/pty/pty_other.go b/pty/pty_other.go index 74c6217a1e1ca..dbdda408b1365 100644 --- a/pty/pty_other.go +++ b/pty/pty_other.go @@ -27,15 +27,17 @@ type otherPty struct { } func (p *otherPty) Input() io.ReadWriter { - return p.pty + return readWriter{ + Reader: p.tty, + Writer: p.pty, + } } func (p *otherPty) Output() io.ReadWriter { - return p.pty -} - -func (p *otherPty) WriteString(str string) (int, error) { - return p.pty.WriteString(str) + return readWriter{ + Reader: p.pty, + Writer: p.tty, + } } func (p *otherPty) Resize(cols uint16, rows uint16) error { diff --git a/pty/pty_windows.go b/pty/pty_windows.go index e01117e98bb9c..5cf1dec4ee06f 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -67,11 +67,6 @@ type ptyWindows struct { closed bool } -type readWriter struct { - io.Reader - io.Writer -} - func (p *ptyWindows) Output() io.ReadWriter { return readWriter{ Reader: p.outputRead, diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 555e3f6d723be..b2e0477756125 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -3,6 +3,7 @@ package ptytest import ( "bufio" "bytes" + "fmt" "io" "os/exec" "regexp" @@ -88,6 +89,6 @@ func (p *PTY) ExpectMatch(str string) string { } func (p *PTY) WriteLine(str string) { - _, err := p.PTY.Input().Write([]byte(str + "\n")) + _, err := fmt.Fprintf(p.PTY.Input(), "%s\n", str) require.NoError(p.t, err) } diff --git a/pty/start_other.go b/pty/start_other.go index c93b850e9fe60..4365fa2e70cb7 100644 --- a/pty/start_other.go +++ b/pty/start_other.go @@ -5,14 +5,30 @@ package pty import ( "os/exec" + "syscall" "github.com/creack/pty" ) func startPty(cmd *exec.Cmd) (PTY, error) { - pty, err := pty.Start(cmd) + pty, tty, err := pty.Open() if err != nil { return nil, err } - return &otherPty{pty, pty}, nil + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setsid: true, + Setctty: true, + } + cmd.Stdout = tty + cmd.Stderr = tty + cmd.Stdin = tty + err = cmd.Start() + if err != nil { + _ = pty.Close() + return nil, err + } + return &otherPty{ + pty: pty, + tty: tty, + }, nil } From d00bc29e12879750f6e12084c638cf22785e5623 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 15:28:36 +0000 Subject: [PATCH 10/18] FIx linux build tests --- pty/start_other_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pty/start_other_test.go b/pty/start_other_test.go index d0796032a9614..4c16a8b11a52a 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -1,3 +1,6 @@ +//go:build !windows +// +build !windows + package pty_test import ( From 3121e107a7169b9efcb3ed24a8386a19e7aa5bdf Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 15:31:00 +0000 Subject: [PATCH 11/18] Remove agent and wintest --- agent/server.go | 191 ------------------------------------------- agent/server_test.go | 71 ---------------- wintest/main.go | 84 ------------------- 3 files changed, 346 deletions(-) delete mode 100644 agent/server.go delete mode 100644 agent/server_test.go delete mode 100644 wintest/main.go diff --git a/agent/server.go b/agent/server.go deleted file mode 100644 index a80a6e5879cb8..0000000000000 --- a/agent/server.go +++ /dev/null @@ -1,191 +0,0 @@ -package agent - -import ( - "context" - "crypto/rand" - "crypto/rsa" - "errors" - "io" - "net" - "os/exec" - "sync" - "time" - - "cdr.dev/slog" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/pty" - "github.com/coder/retry" - - "github.com/gliderlabs/ssh" - gossh "golang.org/x/crypto/ssh" -) - -type Options struct { - Logger slog.Logger -} - -type Dialer func(ctx context.Context) (*peerbroker.Listener, error) - -func Server(dialer Dialer, options *Options) io.Closer { - ctx, cancelFunc := context.WithCancel(context.Background()) - s := &server{ - clientDialer: dialer, - options: options, - closeCancel: cancelFunc, - } - s.init(ctx) - return s -} - -type server struct { - clientDialer Dialer - options *Options - - closeCancel context.CancelFunc - closeMutex sync.Mutex - closed chan struct{} - closeError error - - sshServer *ssh.Server -} - -func (s *server) init(ctx context.Context) { - // Clients' should ignore the host key when connecting. - // The agent needs to authenticate with coderd to SSH, - // so SSH authentication doesn't improve security. - randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - panic(err) - } - randomSigner, err := gossh.NewSignerFromKey(randomHostKey) - if err != nil { - panic(err) - } - sshLogger := s.options.Logger.Named("ssh-server") - forwardHandler := &ssh.ForwardedTCPHandler{} - s.sshServer = &ssh.Server{ - ChannelHandlers: ssh.DefaultChannelHandlers, - ConnectionFailedCallback: func(conn net.Conn, err error) { - sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) - }, - Handler: func(session ssh.Session) { - _, windowSize, isPty := session.Pty() - if isPty { - pty, err := pty.Start(exec.Command("powershell.exe")) - if err != nil { - panic(err) - } - go func() { - for win := range windowSize { - err := pty.Resize(uint16(win.Width), uint16(win.Height)) - if err != nil { - panic(err) - } - } - }() - go func() { - io.Copy(session, pty.Output()) - }() - io.Copy(pty.Input(), session) - } - }, - HostSigners: []ssh.Signer{randomSigner}, - LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { - // Allow local port forwarding all! - sshLogger.Debug(ctx, "local port forward", - slog.F("destination-host", destinationHost), - slog.F("destination-port", destinationPort)) - return true - }, - PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { - return true - }, - ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { - // Allow reverse port forwarding all! - sshLogger.Debug(ctx, "local port forward", - slog.F("bind-host", bindHost), - slog.F("bind-port", bindPort)) - return true - }, - RequestHandlers: map[string]ssh.RequestHandler{ - "tcpip-forward": forwardHandler.HandleSSHRequest, - "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, - }, - ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { - return &gossh.ServerConfig{ - Config: gossh.Config{ - // "arcfour" is the fastest SSH cipher. We prioritize throughput - // over encryption here, because the WebRTC connection is already - // encrypted. If possible, we'd disable encryption entirely here. - Ciphers: []string{"arcfour"}, - }, - NoClientAuth: true, - } - }, - } - - go s.run(ctx) -} - -func (s *server) run(ctx context.Context) { - var peerListener *peerbroker.Listener - var err error - // An exponential back-off occurs when the connection is failing to dial. - // This is to prevent server spam in case of a coderd outage. - for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { - peerListener, err = s.clientDialer(ctx) - if err != nil { - if errors.Is(err, context.Canceled) { - return - } - if s.isClosed() { - return - } - s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) - continue - } - s.options.Logger.Debug(context.Background(), "connected") - break - } - - for { - conn, err := peerListener.Accept() - if err != nil { - // This is closed! - return - } - go s.handle(ctx, conn) - } -} - -func (s *server) handle(ctx context.Context, conn *peer.Conn) { - for { - channel, err := conn.Accept(ctx) - if err != nil { - // TODO: Log here! - return - } - - switch channel.Protocol() { - case "ssh": - s.sshServer.HandleConn(channel.NetConn()) - case "proxy": - // Proxy the port provided. - } - } -} - -// isClosed returns whether the API is closed or not. -func (s *server) isClosed() bool { - select { - case <-s.closed: - return true - default: - return false - } -} - -func (s *server) Close() error { - return nil -} diff --git a/agent/server_test.go b/agent/server_test.go deleted file mode 100644 index 4471c4ae47a54..0000000000000 --- a/agent/server_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package agent_test - -import ( - "context" - "os" - "testing" - - "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/agent" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" - "github.com/pion/webrtc/v3" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - "golang.org/x/crypto/ssh" -) - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - -func TestAgent(t *testing.T) { - t.Skip() - return - - t.Run("asd", func(t *testing.T) { - ctx := context.Background() - client, server := provisionersdk.TransportPipe() - defer client.Close() - defer server.Close() - closer := agent.Server(func(ctx context.Context) (*peerbroker.Listener, error) { - return peerbroker.Listen(server, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil), - }) - }, &agent.Options{ - Logger: slogtest.Make(t, nil), - }) - defer closer.Close() - api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := api.NegotiateConnection(ctx) - require.NoError(t, err) - conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil), - }) - require.NoError(t, err) - defer conn.Close() - channel, err := conn.Dial(ctx, "example", &peer.ChannelOptions{ - Protocol: "ssh", - }) - require.NoError(t, err) - sshConn, channels, requests, err := ssh.NewClientConn(channel.NetConn(), "localhost:22", &ssh.ClientConfig{ - User: "kyle", - Config: ssh.Config{ - Ciphers: []string{"arcfour"}, - }, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - }) - require.NoError(t, err) - sshClient := ssh.NewClient(sshConn, channels, requests) - session, err := sshClient.NewSession() - require.NoError(t, err) - err = session.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{}) - require.NoError(t, err) - session.Stdout = os.Stdout - session.Stderr = os.Stderr - err = session.Run("cmd.exe /k echo test") - require.NoError(t, err) - }) -} diff --git a/wintest/main.go b/wintest/main.go deleted file mode 100644 index 6d504b9e67f03..0000000000000 --- a/wintest/main.go +++ /dev/null @@ -1,84 +0,0 @@ -package main - -import ( - "context" - "log" - "os" - "testing" - - "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/agent" - "github.com/coder/coder/peer" - "github.com/coder/coder/peerbroker" - "github.com/coder/coder/peerbroker/proto" - "github.com/coder/coder/provisionersdk" - "github.com/pion/webrtc/v3" - "github.com/stretchr/testify/require" - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/terminal" -) - -func main() { - oldState, err := terminal.MakeRaw(int(os.Stdin.Fd())) - if err != nil { - log.Fatalf("Could not put terminal in raw mode: %v\n", err) - } - defer terminal.Restore(0, oldState) - - // if true { - // pty, err := pty.Run(exec.Command("powershell.exe")) - // if err != nil { - // panic(err) - // } - // go func() { - // _, _ = io.Copy(pty.Input(), os.Stdin) - - // }() - // _, _ = io.Copy(os.Stdout, pty.Output()) - // return - // } - - t := &testing.T{} - ctx := context.Background() - client, server := provisionersdk.TransportPipe() - defer client.Close() - defer server.Close() - closer := agent.Server(func(ctx context.Context) (*peerbroker.Listener, error) { - return peerbroker.Listen(server, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil), - }) - }, &agent.Options{ - Logger: slogtest.Make(t, nil), - }) - defer closer.Close() - api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) - stream, err := api.NegotiateConnection(ctx) - require.NoError(t, err) - conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil), - }) - require.NoError(t, err) - defer conn.Close() - channel, err := conn.Dial(ctx, "example", &peer.ChannelOptions{ - Protocol: "ssh", - }) - require.NoError(t, err) - sshConn, channels, requests, err := ssh.NewClientConn(channel.NetConn(), "localhost:22", &ssh.ClientConfig{ - User: "kyle", - Config: ssh.Config{ - Ciphers: []string{"arcfour"}, - }, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - }) - require.NoError(t, err) - sshClient := ssh.NewClient(sshConn, channels, requests) - session, err := sshClient.NewSession() - require.NoError(t, err) - err = session.RequestPty("", 128, 128, ssh.TerminalModes{}) - require.NoError(t, err) - session.Stdin = os.Stdin - session.Stdout = os.Stdout - session.Stderr = os.Stderr - err = session.Run("C:\\WINDOWS\\System32\\WindowsPowerShell\\v1.0\\powershell.exe") - require.NoError(t, err) -} From cbbae07c2c027f69e066b53bf27fd32c17ed74df Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 15:37:57 +0000 Subject: [PATCH 12/18] Add test for Windows resize --- pty/pty_windows.go | 2 +- pty/start_windows_test.go | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/pty/pty_windows.go b/pty/pty_windows.go index 5cf1dec4ee06f..b6a9f8ae2e5dd 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -41,7 +41,7 @@ func newPty() (PTY, error) { } ptyWindows.outputRead, ptyWindows.outputWrite, err = os.Pipe() - consoleSize := uintptr((int32(20) << 16) | int32(20)) + consoleSize := uintptr(80) + (uintptr(80) << 16) ret, _, err := procCreatePseudoConsole.Call( consoleSize, uintptr(ptyWindows.inputRead.Fd()), diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index 4fa10c47f5dfc..0ff8eb6b529a5 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -10,9 +10,19 @@ import ( "github.com/coder/coder/pty/ptytest" ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + func TestStart(t *testing.T) { t.Run("Echo", func(t *testing.T) { pty := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) pty.ExpectMatch("test") }) + + t.Run("Resize", func(t *testing.T) { + pty := ptytest.Start(t, exec.Command("cmd.exe")) + err := pty.Resize(100, 50) + require.NoError(t, err) + }) } From 91ea40aa332e39158359e478f730ada1f2f99e4a Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 15:42:05 +0000 Subject: [PATCH 13/18] Fix linting errors --- .vscode/settings.json | 1 + cli/login_test.go | 3 ++- cli/workspacecreate_test.go | 3 ++- coderd/projectimport_test.go | 3 ++- codersdk/projectimport_test.go | 5 +++-- pty/ptytest/ptytest.go | 19 ++++++++++--------- pty/ptytest/ptytest_test.go | 1 + pty/start_other.go | 6 +++--- pty/start_other_test.go | 2 ++ pty/start_windows_test.go | 4 +++- 10 files changed, 29 insertions(+), 18 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 1e3c14fae03d5..02c3b05cc42c5 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -52,6 +52,7 @@ "protobuf", "provisionerd", "provisionersdk", + "ptty", "ptytest", "retrier", "sdkproto", diff --git a/cli/login_test.go b/cli/login_test.go index 02af769e6c49c..24caf18e1aa3f 100644 --- a/cli/login_test.go +++ b/cli/login_test.go @@ -3,10 +3,11 @@ package cli_test import ( "testing" + "github.com/stretchr/testify/require" + "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/pty/ptytest" - "github.com/stretchr/testify/require" ) func TestLogin(t *testing.T) { diff --git a/cli/workspacecreate_test.go b/cli/workspacecreate_test.go index 8bf683f8f3439..b3b1ca26915f7 100644 --- a/cli/workspacecreate_test.go +++ b/cli/workspacecreate_test.go @@ -3,12 +3,13 @@ package cli_test import ( "testing" + "github.com/stretchr/testify/require" + "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" "github.com/coder/coder/pty/ptytest" - "github.com/stretchr/testify/require" ) func TestWorkspaceCreate(t *testing.T) { diff --git a/coderd/projectimport_test.go b/coderd/projectimport_test.go index 06140190f51d5..b9df691233576 100644 --- a/coderd/projectimport_test.go +++ b/coderd/projectimport_test.go @@ -5,13 +5,14 @@ import ( "net/http" "testing" + "github.com/stretchr/testify/require" + "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" "github.com/coder/coder/database" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" - "github.com/stretchr/testify/require" ) func TestPostProjectImportByOrganization(t *testing.T) { diff --git a/codersdk/projectimport_test.go b/codersdk/projectimport_test.go index 8cc6b28a23f6c..ccbe01345845a 100644 --- a/codersdk/projectimport_test.go +++ b/codersdk/projectimport_test.go @@ -5,12 +5,13 @@ import ( "testing" "time" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" - "github.com/google/uuid" - "github.com/stretchr/testify/require" ) func TestCreateProjectImportJob(t *testing.T) { diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index b2e0477756125..7ea5b7a119f0d 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -11,8 +11,9 @@ import ( "testing" "unicode/utf8" - "github.com/coder/coder/pty" "github.com/stretchr/testify/require" + + "github.com/coder/coder/pty" ) var ( @@ -22,18 +23,18 @@ var ( ) func New(t *testing.T) *PTY { - pty, err := pty.New() + ptty, err := pty.New() require.NoError(t, err) - return create(t, pty) + return create(t, ptty) } func Start(t *testing.T, cmd *exec.Cmd) *PTY { - pty, err := pty.Start(cmd) + ptty, err := pty.Start(cmd) require.NoError(t, err) - return create(t, pty) + return create(t, ptty) } -func create(t *testing.T, pty pty.PTY) *PTY { +func create(t *testing.T, ptty pty.PTY) *PTY { reader, writer := io.Pipe() scanner := bufio.NewScanner(reader) t.Cleanup(func() { @@ -50,14 +51,14 @@ func create(t *testing.T, pty pty.PTY) *PTY { }() t.Cleanup(func() { - _ = pty.Close() + _ = ptty.Close() }) return &PTY{ t: t, - PTY: pty, + PTY: ptty, outputWriter: writer, - runeReader: bufio.NewReaderSize(pty.Output(), utf8.UTFMax), + runeReader: bufio.NewReaderSize(ptty.Output(), utf8.UTFMax), } } diff --git a/pty/ptytest/ptytest_test.go b/pty/ptytest/ptytest_test.go index 992077cdc2200..6603b35ad59db 100644 --- a/pty/ptytest/ptytest_test.go +++ b/pty/ptytest/ptytest_test.go @@ -7,6 +7,7 @@ import ( ) func TestPtytest(t *testing.T) { + t.Parallel() pty := ptytest.New(t) pty.Output().Write([]byte("write")) pty.ExpectMatch("write") diff --git a/pty/start_other.go b/pty/start_other.go index 4365fa2e70cb7..103f55202efe3 100644 --- a/pty/start_other.go +++ b/pty/start_other.go @@ -11,7 +11,7 @@ import ( ) func startPty(cmd *exec.Cmd) (PTY, error) { - pty, tty, err := pty.Open() + ptty, tty, err := pty.Open() if err != nil { return nil, err } @@ -24,11 +24,11 @@ func startPty(cmd *exec.Cmd) (PTY, error) { cmd.Stdin = tty err = cmd.Start() if err != nil { - _ = pty.Close() + _ = ptty.Close() return nil, err } return &otherPty{ - pty: pty, + pty: ptty, tty: tty, }, nil } diff --git a/pty/start_other_test.go b/pty/start_other_test.go index 4c16a8b11a52a..16fc9c6789932 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -11,7 +11,9 @@ import ( ) func TestStart(t *testing.T) { + t.Parallel() t.Run("Echo", func(t *testing.T) { + t.Parallel() pty := ptytest.Start(t, exec.Command("echo", "test")) pty.ExpectMatch("test") }) diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index 0ff8eb6b529a5..311d65d2c3e5b 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -15,12 +15,14 @@ func TestMain(m *testing.M) { } func TestStart(t *testing.T) { + t.Parallel() t.Run("Echo", func(t *testing.T) { + t.Parallel() pty := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) pty.ExpectMatch("test") }) - t.Run("Resize", func(t *testing.T) { + t.Parallel() pty := ptytest.Start(t, exec.Command("cmd.exe")) err := pty.Resize(100, 50) require.NoError(t, err) From 3187b70e0c13fec9d3e21bfb1599ae74b4ba3a89 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 15:55:08 +0000 Subject: [PATCH 14/18] Add Windows environment variables --- pty/start_windows.go | 46 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/pty/start_windows.go b/pty/start_windows.go index 4c9d601261039..e851af3893b80 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -65,8 +65,7 @@ func startPty(cmd *exec.Cmd) (PTY, error) { false, // https://docs.microsoft.com/en-us/windows/win32/procthread/process-creation-flags#create_unicode_environment windows.CREATE_UNICODE_ENVIRONMENT|windows.EXTENDED_STARTUPINFO_PRESENT, - // Environment variables can come later! - createEnvBlock([]string{"SYSTEMROOT=" + os.Getenv("SYSTEMROOT")}), + createEnvBlock(addCriticalEnv(dedupEnvCase(true, cmd.Env))), dirPtr, &startupInfo.StartupInfo, &processInfo, @@ -103,3 +102,46 @@ func createEnvBlock(envv []string) *uint16 { return &utf16.Encode([]rune(string(b)))[0] } + +// dedupEnvCase is dedupEnv with a case option for testing. +// If caseInsensitive is true, the case of keys is ignored. +func dedupEnvCase(caseInsensitive bool, env []string) []string { + out := make([]string, 0, len(env)) + saw := make(map[string]int, len(env)) // key => index into out + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + out = append(out, kv) + continue + } + k := kv[:eq] + if caseInsensitive { + k = strings.ToLower(k) + } + if dupIdx, isDup := saw[k]; isDup { + out[dupIdx] = kv + continue + } + saw[k] = len(out) + out = append(out, kv) + } + return out +} + +// addCriticalEnv adds any critical environment variables that are required +// (or at least almost always required) on the operating system. +// Currently this is only used for Windows. +func addCriticalEnv(env []string) []string { + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + continue + } + k := kv[:eq] + if strings.EqualFold(k, "SYSTEMROOT") { + // We already have it. + return env + } + } + return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) +} From 60bd0a490f0f6a98a7b894c22dc83199f25a9237 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 10:06:23 -0600 Subject: [PATCH 15/18] Add strings import --- pty/start_windows.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pty/start_windows.go b/pty/start_windows.go index e851af3893b80..4ffe28f892e20 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -6,6 +6,7 @@ package pty import ( "os" "os/exec" + "strings" "unicode/utf16" "unsafe" From 80b443bb32240c5a1bdd45ced699cf642cf81c01 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 16:19:23 +0000 Subject: [PATCH 16/18] Add comment for attrs --- pty/start_windows.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pty/start_windows.go b/pty/start_windows.go index e851af3893b80..a42d916c4ba9a 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -47,6 +47,7 @@ func startPty(cmd *exec.Cmd) (PTY, error) { if err != nil { return nil, err } + // Taken from: https://github.com/microsoft/hcsshim/blob/2314362e977aa03b3ed245a4beb12d00422af0e2/internal/winapi/process.go#L6 err = attrs.Update(0x20016, unsafe.Pointer(winPty.console), unsafe.Sizeof(winPty.console)) if err != nil { return nil, err From 21a8efe6f29107fcffbbc21fdaefde3da937174a Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 16:30:36 +0000 Subject: [PATCH 17/18] Add goleak --- pty/start_other_test.go | 5 +++++ pty/start_windows_test.go | 1 + 2 files changed, 6 insertions(+) diff --git a/pty/start_other_test.go b/pty/start_other_test.go index 16fc9c6789932..a5e7d94b36af1 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -8,8 +8,13 @@ import ( "testing" "github.com/coder/coder/pty/ptytest" + "go.uber.org/goleak" ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + func TestStart(t *testing.T) { t.Parallel() t.Run("Echo", func(t *testing.T) { diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index 311d65d2c3e5b..35ff47c6c1425 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/coder/coder/pty/ptytest" + "go.uber.org/goleak" ) func TestMain(m *testing.M) { From 9edc0140ca5839db4d767a5559a8ea99a8dd46aa Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 16:36:30 +0000 Subject: [PATCH 18/18] Add require import --- pty/start_windows_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index 35ff47c6c1425..faee269776830 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/coder/coder/pty/ptytest" + "github.com/stretchr/testify/require" "go.uber.org/goleak" )