Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 9 additions & 39 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,36 +26,6 @@ import (
"golang.org/x/xerrors"
)

func DialSSH(conn *peer.Conn) (net.Conn, error) {
channel, err := conn.Dial(context.Background(), "ssh", &peer.ChannelOptions{
Protocol: "ssh",
})
if err != nil {
return nil, err
}
return channel.NetConn(), nil
}

func DialSSHClient(conn *peer.Conn) (*gossh.Client, error) {
netConn, err := DialSSH(conn)
if err != nil {
return nil, err
}
sshConn, channels, requests, err := gossh.NewClientConn(netConn, "localhost:22", &gossh.ClientConfig{
Config: gossh.Config{
Ciphers: []string{"arcfour"},
},
// SSH host validation isn't helpful, because obtaining a peer
// connection already signifies user-intent to dial a workspace.
// #nosec
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
})
if err != nil {
return nil, err
}
return gossh.NewClient(sshConn, channels, requests), nil
}

type Options struct {
Logger slog.Logger
}
Expand All @@ -64,7 +34,7 @@ type Dialer func(ctx context.Context, options *peer.ConnOptions) (*peerbroker.Li

func New(dialer Dialer, options *peer.ConnOptions) io.Closer {
ctx, cancelFunc := context.WithCancel(context.Background())
server := &server{
server := &agent{
clientDialer: dialer,
options: options,
closeCancel: cancelFunc,
Expand All @@ -74,7 +44,7 @@ func New(dialer Dialer, options *peer.ConnOptions) io.Closer {
return server
}

type server struct {
type agent struct {
clientDialer Dialer
options *peer.ConnOptions

Expand All @@ -86,7 +56,7 @@ type server struct {
sshServer *ssh.Server
}

func (s *server) run(ctx context.Context) {
func (s *agent) run(ctx context.Context) {
var peerListener *peerbroker.Listener
var err error
// An exponential back-off occurs when the connection is failing to dial.
Expand All @@ -103,7 +73,7 @@ func (s *server) run(ctx context.Context) {
s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
continue
}
s.options.Logger.Debug(context.Background(), "connected")
s.options.Logger.Info(context.Background(), "connected")
break
}
select {
Expand All @@ -129,7 +99,7 @@ func (s *server) run(ctx context.Context) {
}
}

func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) {
func (s *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
go func() {
<-conn.Closed()
s.connCloseWait.Done()
Expand All @@ -156,7 +126,7 @@ func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) {
}
}

func (s *server) init(ctx context.Context) {
func (s *agent) init(ctx context.Context) {
// Clients' should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security.
Expand Down Expand Up @@ -221,7 +191,7 @@ func (s *server) init(ctx context.Context) {
go s.run(ctx)
}

func (*server) handleSSHSession(session ssh.Session) error {
func (*agent) handleSSHSession(session ssh.Session) error {
var (
command string
args = []string{}
Expand Down Expand Up @@ -316,7 +286,7 @@ func (*server) handleSSHSession(session ssh.Session) error {
}

// isClosed returns whether the API is closed or not.
func (s *server) isClosed() bool {
func (s *agent) isClosed() bool {
select {
case <-s.closed:
return true
Expand All @@ -325,7 +295,7 @@ func (s *server) isClosed() bool {
}
}

func (s *server) Close() error {
func (s *agent) Close() error {
s.closeMutex.Lock()
defer s.closeMutex.Unlock()
if s.isClosed() {
Expand Down
6 changes: 4 additions & 2 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ func TestAgent(t *testing.T) {
t.Cleanup(func() {
_ = conn.Close()
})
sshClient, err := agent.DialSSHClient(conn)
client := agent.Conn{conn}
sshClient, err := client.SSHClient()
require.NoError(t, err)
session, err := sshClient.NewSession()
require.NoError(t, err)
Expand All @@ -64,7 +65,8 @@ func TestAgent(t *testing.T) {
t.Cleanup(func() {
_ = conn.Close()
})
sshClient, err := agent.DialSSHClient(conn)
client := &agent.Conn{conn}
sshClient, err := client.SSHClient()
require.NoError(t, err)
session, err := sshClient.NewSession()
require.NoError(t, err)
Expand Down
50 changes: 50 additions & 0 deletions agent/conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package agent

import (
"context"
"net"

"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"

"github.com/coder/coder/peer"
)

// Conn wraps a peer connection with helper functions to
// communicate with the agent.
type Conn struct {
*peer.Conn
}

// SSH dials the built-in SSH server.
func (c *Conn) SSH() (net.Conn, error) {
channel, err := c.Dial(context.Background(), "ssh", &peer.ChannelOptions{
Protocol: "ssh",
})
if err != nil {
return nil, xerrors.Errorf("dial: %w", err)
}
return channel.NetConn(), nil
}

// SSHClient calls SSH to create a client that uses a weak cipher
// for high throughput.
func (c *Conn) SSHClient() (*ssh.Client, error) {
netConn, err := c.SSH()
if err != nil {
return nil, xerrors.Errorf("ssh: %w", err)
}
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
Config: ssh.Config{
Ciphers: []string{"arcfour"},
},
// SSH host validation isn't helpful, because obtaining a peer
// connection already signifies user-intent to dial a workspace.
// #nosec
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
if err != nil {
return nil, xerrors.Errorf("ssh conn: %w", err)
}
return ssh.NewClient(sshConn, channels, requests), nil
}
106 changes: 61 additions & 45 deletions cli/ssh.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
package cli

import (
"fmt"
"os"

"github.com/pion/webrtc/v3"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
"golang.org/x/xerrors"

"github.com/coder/coder/agent"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
)

func workspaceSSH() *cobra.Command {
Expand All @@ -26,58 +24,76 @@ func workspaceSSH() *cobra.Command {
if err != nil {
return err
}
if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete {
return xerrors.New("workspace is deleting...")
}
resources, err := client.WorkspaceResourcesByBuild(cmd.Context(), workspace.LatestBuild.ID)
if err != nil {
return err
}

resourceByAddress := make(map[string]codersdk.WorkspaceResource)
for _, resource := range resources {
_, _ = fmt.Printf("Got resource: %+v\n", resource)
if resource.Agent == nil {
continue
}

dialed, err := client.DialWorkspaceAgent(cmd.Context(), resource.ID)
if err != nil {
return err
}
stream, err := dialed.NegotiateConnection(cmd.Context())
if err != nil {
return err
resourceByAddress[resource.Address] = resource
}
var resourceAddress string
if len(args) >= 2 {
resourceAddress = args[1]
} else {
// No resource name was provided!
if len(resourceByAddress) > 1 {
// List available resources to connect into?
return xerrors.Errorf("multiple agents")
}
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{
URLs: []string{"stun:stun.l.google.com:19302"},
}}, &peer.ConnOptions{})
if err != nil {
return err
for _, resource := range resourceByAddress {
resourceAddress = resource.Address
break
}
client, err := agent.DialSSHClient(conn)
if err != nil {
return err
}
resource, exists := resourceByAddress[resourceAddress]
if !exists {
resourceKeys := make([]string, 0)
for resourceKey := range resourceByAddress {
resourceKeys = append(resourceKeys, resourceKey)
}
return xerrors.Errorf("no sshable agent with address %q: %+v", resourceAddress, resourceKeys)
}
if resource.Agent.LastConnectedAt == nil {
return xerrors.Errorf("agent hasn't connected yet")
}

session, err := client.NewSession()
if err != nil {
return err
}
_, _ = term.MakeRaw(int(os.Stdin.Fd()))
err = session.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{
ssh.OCRNL: 1,
})
if err != nil {
return err
}
session.Stdin = os.Stdin
session.Stdout = os.Stdout
session.Stderr = os.Stderr
err = session.Shell()
if err != nil {
return err
}
err = session.Wait()
if err != nil {
return err
}
conn, err := client.DialWorkspaceAgent(cmd.Context(), resource.ID, nil, nil)
if err != nil {
return err
}
sshClient, err := conn.SSHClient()
if err != nil {
return err
}

sshSession, err := sshClient.NewSession()
if err != nil {
return err
}
_, _ = term.MakeRaw(int(os.Stdin.Fd()))
err = sshSession.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{
ssh.OCRNL: 1,
})
if err != nil {
return err
}
sshSession.Stdin = os.Stdin
sshSession.Stdout = os.Stdout
sshSession.Stderr = os.Stderr
err = sshSession.Shell()
if err != nil {
return err
}
err = sshSession.Wait()
if err != nil {
return err
}

return nil
Expand Down
Loading