Thanks to visit codestin.com
Credit goes to github.com

Skip to content

feat: port forwarding dropdown #1824

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 94 additions & 4 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"time"

"github.com/armon/circbuf"
"github.com/cakturk/go-netstat/netstat"
"github.com/gliderlabs/ssh"
"github.com/google/uuid"
"github.com/pkg/sftp"
Expand All @@ -37,13 +38,15 @@ import (
)

const (
ProtocolNetstat = "netstat"
ProtocolReconnectingPTY = "reconnecting-pty"
ProtocolSSH = "ssh"
ProtocolDial = "dial"
)

type Options struct {
ReconnectingPTYTimeout time.Duration
NetstatInterval time.Duration
EnvironmentVariables map[string]string
Logger slog.Logger
}
Expand All @@ -65,10 +68,14 @@ func New(dialer Dialer, options *Options) io.Closer {
if options.ReconnectingPTYTimeout == 0 {
options.ReconnectingPTYTimeout = 5 * time.Minute
}
if options.NetstatInterval == 0 {
options.NetstatInterval = 5 * time.Second
}
ctx, cancelFunc := context.WithCancel(context.Background())
server := &agent{
dialer: dialer,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
netstatInterval: options.NetstatInterval,
logger: options.Logger,
closeCancel: cancelFunc,
closed: make(chan struct{}),
Expand All @@ -85,6 +92,8 @@ type agent struct {
reconnectingPTYs sync.Map
reconnectingPTYTimeout time.Duration

netstatInterval time.Duration

connCloseWait sync.WaitGroup
closeCancel context.CancelFunc
closeMutex sync.Mutex
Expand Down Expand Up @@ -225,6 +234,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
case ProtocolDial:
go a.handleDial(ctx, channel.Label(), channel.NetConn())
case ProtocolNetstat:
go a.handleNetstat(ctx, channel.Label(), channel.NetConn())
default:
a.logger.Warn(ctx, "unhandled protocol from channel",
slog.F("protocol", channel.Protocol()),
Expand Down Expand Up @@ -359,12 +370,10 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
if err != nil {
return nil, xerrors.Errorf("getting os executable: %w", err)
}
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username))
cmd.Env = append(cmd.Env, fmt.Sprintf(`PATH=%s%c%s`, os.Getenv("PATH"), filepath.ListSeparator, filepath.Dir(executablePath)))
// Git on Windows resolves with UNIX-style paths.
// If using backslashes, it's unable to find the executable.
unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath))
executablePath = strings.ReplaceAll(executablePath, "\\", "/")
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, executablePath))
// These prevent the user from having to specify _anything_ to successfully commit.
// Both author and committer must be set!
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_AUTHOR_EMAIL=%s`, metadata.OwnerEmail))
Expand Down Expand Up @@ -707,6 +716,87 @@ func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
Bicopy(ctx, conn, nconn)
}

type NetstatPort struct {
Name string `json:"name"`
Port uint16 `json:"port"`
}

type NetstatResponse struct {
Ports []NetstatPort `json:"ports"`
Error string `json:"error,omitempty"`
Took time.Duration `json:"took"`
}

func (a *agent) handleNetstat(ctx context.Context, label string, conn net.Conn) {
write := func(resp NetstatResponse) error {
b, err := json.Marshal(resp)
if err != nil {
a.logger.Warn(ctx, "write netstat response", slog.F("label", label), slog.Error(err))
return xerrors.Errorf("marshal agent netstat response: %w", err)
}
_, err = conn.Write(b)
if err != nil {
a.logger.Warn(ctx, "write netstat response", slog.F("label", label), slog.Error(err))
}
return err
}

scan := func() ([]NetstatPort, error) {
if runtime.GOOS != "linux" && runtime.GOOS != "windows" {
return nil, xerrors.New(fmt.Sprintf("Port scanning is not supported on %s", runtime.GOOS))
}

tabs, err := netstat.TCPSocks(func(s *netstat.SockTabEntry) bool {
return s.State == netstat.Listen
})
if err != nil {
return nil, err
}

ports := []NetstatPort{}
for _, tab := range tabs {
ports = append(ports, NetstatPort{
Name: tab.Process.Name,
Port: tab.LocalAddr.Port,
})
}
return ports, nil
}

scanAndWrite := func() {
start := time.Now()
ports, err := scan()
response := NetstatResponse{
Ports: ports,
Took: time.Since(start),
}
if err != nil {
response.Error = err.Error()
}
_ = write(response)
}

scanAndWrite()

// Using a timer instead of a ticker to ensure delay between calls otherwise
// if nestat took longer than the interval we would constantly run it.
timer := time.NewTimer(a.netstatInterval)
go func() {
defer conn.Close()
defer timer.Stop()

for {
select {
case <-ctx.Done():
return
case <-timer.C:
scanAndWrite()
timer.Reset(a.netstatInterval)
}
}
}()
}

// isClosed returns whether the API is closed or not.
func (a *agent) isClosed() bool {
select {
Expand Down
52 changes: 52 additions & 0 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,57 @@ func TestAgent(t *testing.T) {
require.ErrorContains(t, err, "no such file")
require.Nil(t, netConn)
})

t.Run("Netstat", func(t *testing.T) {
t.Parallel()

var ports []agent.NetstatPort
listen := func() {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
_ = listener.Close()
})

tcpAddr, valid := listener.Addr().(*net.TCPAddr)
require.True(t, valid)

name, err := os.Executable()
require.NoError(t, err)

ports = append(ports, agent.NetstatPort{
Name: filepath.Base(name),
Port: uint16(tcpAddr.Port),
})
}

conn := setupAgent(t, agent.Metadata{}, 0)
netConn, err := conn.Netstat(context.Background())
require.NoError(t, err)
t.Cleanup(func() {
_ = netConn.Close()
})

decoder := json.NewDecoder(netConn)

expectNetstat := func() {
var res agent.NetstatResponse
err = decoder.Decode(&res)
require.NoError(t, err)

if runtime.GOOS == "linux" || runtime.GOOS == "windows" {
require.Subset(t, res.Ports, ports)
} else {
require.Equal(t, fmt.Sprintf("Port scanning is not supported on %s", runtime.GOOS), res.Error)
}
}

listen()
expectNetstat()

listen()
expectNetstat()
})
}

func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
Expand Down Expand Up @@ -420,6 +471,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
}, &agent.Options{
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
ReconnectingPTYTimeout: ptyTimeout,
NetstatInterval: 100 * time.Millisecond,
})
t.Cleanup(func() {
_ = client.Close()
Expand Down
11 changes: 11 additions & 0 deletions agent/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,17 @@ func (c *Conn) DialContext(ctx context.Context, network string, addr string) (ne
return channel.NetConn(), nil
}

// Netstat returns a connection that serves a list of listening ports.
func (c *Conn) Netstat(ctx context.Context) (net.Conn, error) {
channel, err := c.CreateChannel(ctx, "netstat", &peer.ChannelOptions{
Protocol: ProtocolNetstat,
})
if err != nil {
return nil, xerrors.Errorf("netsat: %w", err)
}
return channel.NetConn(), nil
}

func (c *Conn) Close() error {
_ = c.Negotiator.DRPCConn().Close()
return c.Conn.Close()
Expand Down
1 change: 1 addition & 0 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ func New(options *Options) *API {
r.Get("/", api.workspaceAgent)
r.Get("/dial", api.workspaceAgentDial)
r.Get("/turn", api.workspaceAgentTurn)
r.Get("/netstat", api.workspaceAgentNetstat)
r.Get("/pty", api.workspaceAgentPTY)
r.Get("/iceservers", api.workspaceAgentICEServers)
})
Expand Down
1 change: 1 addition & 0 deletions coderd/coderd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
"GET:/api/v2/workspaceagents/{workspaceagent}": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/dial": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/netstat": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/pty": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/turn": {NoAuthorize: true},

Expand Down
52 changes: 52 additions & 0 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,55 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency

return workspaceAgent, nil
}

// workspaceAgentNetstat sends listening ports as `agent.NetstatResponse` on an
// interval.
func (api *API) workspaceAgentNetstat(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()

workspaceAgent := httpmw.WorkspaceAgentParam(r)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("convert workspace agent: %s", err),
})
return
}
if apiAgent.Status != codersdk.WorkspaceAgentConnected {
httpapi.Write(rw, http.StatusPreconditionRequired, httpapi.Response{
Message: fmt.Sprintf("agent must be in the connected state: %s", apiAgent.Status),
})
return
}

conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("accept websocket: %s", err),
})
return
}
defer func() {
_ = conn.Close(websocket.StatusNormalClosure, "ended")
}()
wsNetConn := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
agentConn, err := api.dialWorkspaceAgent(r, workspaceAgent.ID)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
return
}
defer agentConn.Close()
ptNetConn, err := agentConn.Netstat(r.Context())
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
return
}
defer ptNetConn.Close()

agent.Bicopy(r.Context(), wsNetConn, ptNetConn)
}
65 changes: 65 additions & 0 deletions coderd/workspaceagents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"context"
"encoding/json"
"fmt"
"runtime"
"strings"
"testing"
Expand Down Expand Up @@ -264,3 +265,67 @@ func TestWorkspaceAgentPTY(t *testing.T) {
expectLine(matchEchoCommand)
expectLine(matchEchoOutput)
}

func TestWorkspaceAgentNetstat(t *testing.T) {
t.Parallel()

client, coderAPI := coderdtest.NewWithAPI(t, nil)
user := coderdtest.CreateFirstUser(t, client)
daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
daemonCloser.Close()

agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
Logger: slogtest.Make(t, nil),
})
t.Cleanup(func() {
_ = agentCloser.Close()
})
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)

conn, err := client.WorkspaceAgentNetstat(context.Background(), resources[0].Agents[0].ID)
require.NoError(t, err)
defer conn.Close()

decoder := json.NewDecoder(conn)

expectNetstat := func() {
var res agent.NetstatResponse
err = decoder.Decode(&res)
require.NoError(t, err)

if runtime.GOOS == "linux" || runtime.GOOS == "windows" {
require.NotNil(t, res.Ports)
} else {
require.Equal(t, fmt.Sprintf("Port scanning is not supported on %s", runtime.GOOS), res.Error)
}
}

expectNetstat()
}
Loading