diff --git a/agent/agent.go b/agent/agent.go index 75787b4cfc5e1..56cf1fa00b253 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -21,9 +21,12 @@ import ( "time" "github.com/armon/circbuf" + "github.com/gliderlabs/ssh" "github.com/google/uuid" - + "github.com/pkg/sftp" "go.uber.org/atomic" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" "cdr.dev/slog" "github.com/coder/coder/agent/usershell" @@ -31,12 +34,12 @@ import ( "github.com/coder/coder/peerbroker" "github.com/coder/coder/pty" "github.com/coder/retry" +) - "github.com/pkg/sftp" - - "github.com/gliderlabs/ssh" - gossh "golang.org/x/crypto/ssh" - "golang.org/x/xerrors" +const ( + ProtocolReconnectingPTY = "reconnecting-pty" + ProtocolSSH = "ssh" + ProtocolDial = "dial" ) type Options struct { @@ -174,17 +177,25 @@ func (*agent) runStartupScript(ctx context.Context, script string) error { defer func() { _ = writer.Close() }() + caller := "-c" if runtime.GOOS == "windows" { caller = "/c" } + cmd := exec.CommandContext(ctx, shell, caller, script) cmd.Stdout = writer cmd.Stderr = writer err = cmd.Run() if err != nil { + // cmd.Run does not return a context canceled error, it returns "signal: killed". + if ctx.Err() != nil { + return ctx.Err() + } + return xerrors.Errorf("run: %w", err) } + return nil } @@ -208,11 +219,11 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) { } switch channel.Protocol() { - case "ssh": + case ProtocolSSH: go a.sshServer.HandleConn(channel.NetConn()) - case "reconnecting-pty": + case ProtocolReconnectingPTY: go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn()) - case "dial": + case ProtocolDial: go a.handleDial(ctx, channel.Label(), channel.NetConn()) default: a.logger.Warn(ctx, "unhandled protocol from channel", @@ -478,8 +489,8 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne a.logger.Warn(ctx, "start reconnecting pty command", slog.F("id", id)) } - // Default to buffer 64KB. - circularBuffer, err := circbuf.NewBuffer(64 * 1024) + // Default to buffer 64KiB. + circularBuffer, err := circbuf.NewBuffer(64 << 10) if err != nil { a.logger.Warn(ctx, "create circular buffer", slog.Error(err)) return diff --git a/agent/agent_test.go b/agent/agent_test.go index c5759f3771d63..cb476116eb435 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -172,8 +172,9 @@ func TestAgent(t *testing.T) { tempPath := filepath.Join(os.TempDir(), "content.txt") content := "somethingnice" setupAgent(t, agent.Metadata{ - StartupScript: "echo " + content + " > " + tempPath, + StartupScript: fmt.Sprintf("echo %s > %s", content, tempPath), }, 0) + var gotContent string require.Eventually(t, func() bool { content, err := os.ReadFile(tempPath) @@ -202,6 +203,7 @@ func TestAgent(t *testing.T) { // it seems like it could be either. t.Skip("ConPTY appears to be inconsistent on Windows.") } + conn := setupAgent(t, agent.Metadata{}, 0) id := uuid.NewString() netConn, err := conn.ReconnectingPTY(id, 100, 100) @@ -228,6 +230,7 @@ func TestAgent(t *testing.T) { } } } + matchEchoCommand := func(line string) bool { return strings.Contains(line, "echo test") } diff --git a/agent/conn.go b/agent/conn.go index 56d3d42ea1784..b63c0d0b0da35 100644 --- a/agent/conn.go +++ b/agent/conn.go @@ -36,7 +36,7 @@ type Conn struct { // be reconnected to via ID. func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) { channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{ - Protocol: "reconnecting-pty", + Protocol: ProtocolReconnectingPTY, }) if err != nil { return nil, xerrors.Errorf("pty: %w", err) @@ -47,7 +47,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error // SSH dials the built-in SSH server. func (c *Conn) SSH() (net.Conn, error) { channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{ - Protocol: "ssh", + Protocol: ProtocolSSH, }) if err != nil { return nil, xerrors.Errorf("dial: %w", err) @@ -87,7 +87,7 @@ func (c *Conn) DialContext(ctx context.Context, network string, addr string) (ne } channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{ - Protocol: "dial", + Protocol: ProtocolDial, Unordered: strings.HasPrefix(network, "udp"), }) if err != nil {