-
Notifications
You must be signed in to change notification settings - Fork 19
Add coder agent start #311
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,270 @@ | ||
| package cmd | ||
|
|
||
| import ( | ||
| "context" | ||
| "encoding/json" | ||
| "fmt" | ||
| "io" | ||
| "net" | ||
| "net/url" | ||
| "os" | ||
| "strings" | ||
| "time" | ||
|
|
||
| "cdr.dev/slog" | ||
| "cdr.dev/slog/sloggers/sloghuman" | ||
| "github.com/hashicorp/yamux" | ||
| "github.com/pion/webrtc/v3" | ||
| "github.com/spf13/cobra" | ||
| "golang.org/x/xerrors" | ||
| "nhooyr.io/websocket" | ||
|
|
||
| "cdr.dev/coder-cli/internal/x/xcobra" | ||
| "cdr.dev/coder-cli/internal/x/xwebrtc" | ||
| "cdr.dev/coder-cli/pkg/proto" | ||
| ) | ||
|
|
||
| func agentCmd() *cobra.Command { | ||
| cmd := &cobra.Command{ | ||
| Use: "agent", | ||
| Short: "Run the workspace agent", | ||
| Long: "Connect to Coder and start running a p2p agent", | ||
| Hidden: true, | ||
| } | ||
|
|
||
| cmd.AddCommand( | ||
| startCmd(), | ||
| ) | ||
| return cmd | ||
| } | ||
|
|
||
| func startCmd() *cobra.Command { | ||
| var ( | ||
| token string | ||
| ) | ||
| cmd := &cobra.Command{ | ||
| Use: "start [coderURL] --token=[token]", | ||
| Args: xcobra.ExactArgs(1), | ||
| Short: "starts the coder agent", | ||
| Long: "starts the coder agent", | ||
| Example: `# start the agent and connect with a Coder agent token | ||
|
|
||
| coder agent start https://my-coder.com --token xxxx-xxxx | ||
|
|
||
| # start the agent and use CODER_AGENT_TOKEN env var for auth token | ||
|
|
||
| coder agent start https://my-coder.com | ||
| `, | ||
| RunE: func(cmd *cobra.Command, args []string) error { | ||
| ctx := cmd.Context() | ||
| log := slog.Make(sloghuman.Sink(cmd.OutOrStdout())) | ||
|
|
||
| // Pull the URL from the args and do some sanity check. | ||
| rawURL := args[0] | ||
| if rawURL == "" || !strings.HasPrefix(rawURL, "http") { | ||
| return xerrors.Errorf("invalid URL") | ||
| } | ||
| u, err := url.Parse(rawURL) | ||
| if err != nil { | ||
| return xerrors.Errorf("parse url: %w", err) | ||
| } | ||
| // Remove the trailing '/' if any. | ||
| u.Path = "/api/private/envagent/listen" | ||
|
|
||
| if token == "" { | ||
| var ok bool | ||
| token, ok = os.LookupEnv("CODER_AGENT_TOKEN") | ||
| if !ok { | ||
| return xerrors.New("must pass --token or set the CODER_AGENT_TOKEN env variable") | ||
| } | ||
| } | ||
|
|
||
| q := u.Query() | ||
| q.Set("service_token", token) | ||
| u.RawQuery = q.Encode() | ||
|
|
||
| ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15) | ||
| defer cancelFunc() | ||
| log.Info(ctx, "connecting to broker", slog.F("url", u.String())) | ||
| conn, res, err := websocket.Dial(ctx, u.String(), nil) | ||
| if err != nil { | ||
| return fmt.Errorf("dial: %w", err) | ||
| } | ||
| _ = res.Body.Close() | ||
| nc := websocket.NetConn(context.Background(), conn, websocket.MessageBinary) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know this is an MVP but there's a lot going on in this command so all of this should move into it's own package (or potentially even it's own repo if we want to use it within our monorepo as well).
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I pinky promise I'll do that if this ends up being a thing we stick with ;) |
||
| session, err := yamux.Server(nc, nil) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpicky since it doesn't really matter since yamux is bidirectional, but this should be a client because it's on the client end of a websocket connection
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @kylecarbs Any objections? (since this part is your code) |
||
| if err != nil { | ||
| return fmt.Errorf("open: %w", err) | ||
| } | ||
| log.Info(ctx, "connected to broker. awaiting connection requests") | ||
| for { | ||
| st, err := session.AcceptStream() | ||
| if err != nil { | ||
| return fmt.Errorf("accept stream: %w", err) | ||
| } | ||
| stream := &stream{ | ||
| logger: log.Named(fmt.Sprintf("stream %d", st.StreamID())), | ||
| stream: st, | ||
| } | ||
| go stream.listen() | ||
| } | ||
| }, | ||
| } | ||
|
|
||
| cmd.Flags().StringVar(&token, "token", "", "coder agent token") | ||
| return cmd | ||
| } | ||
|
|
||
| type stream struct { | ||
| stream *yamux.Stream | ||
| logger slog.Logger | ||
|
|
||
| rtc *webrtc.PeerConnection | ||
| } | ||
|
|
||
| // writes an error and closes. | ||
| func (s *stream) fatal(err error) { | ||
| _ = s.write(proto.Message{ | ||
| Error: err.Error(), | ||
| }) | ||
| s.logger.Error(context.Background(), err.Error(), slog.Error(err)) | ||
| _ = s.stream.Close() | ||
| } | ||
|
|
||
| func (s *stream) listen() { | ||
| decoder := json.NewDecoder(s.stream) | ||
| for { | ||
| var msg proto.Message | ||
| err := decoder.Decode(&msg) | ||
| if err == io.EOF { | ||
| break | ||
| } | ||
| if err != nil { | ||
| s.fatal(err) | ||
| return | ||
| } | ||
| s.processMessage(msg) | ||
| } | ||
| } | ||
|
|
||
| func (s *stream) write(msg proto.Message) error { | ||
| d, err := json.Marshal(&msg) | ||
| if err != nil { | ||
| return err | ||
| } | ||
| _, err = s.stream.Write(d) | ||
| if err != nil { | ||
| return err | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| func (s *stream) processMessage(msg proto.Message) { | ||
| s.logger.Debug(context.Background(), "processing message", slog.F("msg", msg)) | ||
|
|
||
| if msg.Error != "" { | ||
| s.fatal(xerrors.New(msg.Error)) | ||
| return | ||
| } | ||
|
|
||
| if msg.Candidate != "" { | ||
| if s.rtc == nil { | ||
| s.fatal(xerrors.New("rtc connection must be started before candidates are sent")) | ||
| return | ||
| } | ||
|
|
||
| s.logger.Debug(context.Background(), "accepted ice candidate", slog.F("candidate", msg.Candidate)) | ||
| err := proto.AcceptICECandidate(s.rtc, &msg) | ||
| if err != nil { | ||
| s.fatal(err) | ||
| return | ||
| } | ||
| } | ||
|
|
||
| if msg.Offer != nil { | ||
| rtc, err := xwebrtc.NewPeerConnection() | ||
| if err != nil { | ||
| s.fatal(fmt.Errorf("create connection: %w", err)) | ||
| return | ||
| } | ||
| flushCandidates := proto.ProxyICECandidates(rtc, s.stream) | ||
|
|
||
| err = rtc.SetRemoteDescription(*msg.Offer) | ||
| if err != nil { | ||
| s.fatal(fmt.Errorf("set remote desc: %w", err)) | ||
| return | ||
| } | ||
| answer, err := rtc.CreateAnswer(nil) | ||
| if err != nil { | ||
| s.fatal(fmt.Errorf("create answer: %w", err)) | ||
| return | ||
| } | ||
| err = rtc.SetLocalDescription(answer) | ||
| if err != nil { | ||
| s.fatal(fmt.Errorf("set local desc: %w", err)) | ||
| return | ||
| } | ||
| flushCandidates() | ||
|
|
||
| err = s.write(proto.Message{ | ||
| Answer: rtc.LocalDescription(), | ||
| }) | ||
| if err != nil { | ||
| s.fatal(fmt.Errorf("send local desc: %w", err)) | ||
| return | ||
| } | ||
|
|
||
| rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { | ||
| s.logger.Info(context.Background(), "state changed", slog.F("new", pcs)) | ||
| }) | ||
| rtc.OnDataChannel(s.processDataChannel) | ||
| s.rtc = rtc | ||
| } | ||
| } | ||
|
|
||
| func (s *stream) processDataChannel(channel *webrtc.DataChannel) { | ||
| if channel.Protocol() == "ping" { | ||
| channel.OnOpen(func() { | ||
| rw, err := channel.Detach() | ||
| if err != nil { | ||
| return | ||
| } | ||
| d := make([]byte, 64) | ||
| _, _ = rw.Read(d) | ||
| _, _ = rw.Write(d) | ||
| }) | ||
| return | ||
| } | ||
|
|
||
| prto, port, err := xwebrtc.ParseProxyDataChannel(channel) | ||
| if err != nil { | ||
| s.fatal(fmt.Errorf("failed to parse proxy data channel: %w", err)) | ||
| return | ||
| } | ||
| if prto != "tcp" { | ||
| s.fatal(fmt.Errorf("client provided unsupported protocol: %s", prto)) | ||
| return | ||
| } | ||
|
|
||
| conn, err := net.Dial(prto, fmt.Sprintf("localhost:%d", port)) | ||
| if err != nil { | ||
| s.fatal(fmt.Errorf("failed to dial client port: %d", port)) | ||
| return | ||
| } | ||
|
|
||
| channel.OnOpen(func() { | ||
| s.logger.Debug(context.Background(), "proxying data channel to local port", slog.F("port", port)) | ||
| rw, err := channel.Detach() | ||
| if err != nil { | ||
| _ = channel.Close() | ||
| s.logger.Error(context.Background(), "detach client data channel", slog.Error(err)) | ||
| return | ||
| } | ||
| go func() { | ||
| _, _ = io.Copy(rw, conn) | ||
| }() | ||
| go func() { | ||
| _, _ = io.Copy(conn, rw) | ||
| }() | ||
| }) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| package xwebrtc | ||
|
|
||
| import ( | ||
| "context" | ||
| "errors" | ||
| "fmt" | ||
| "net" | ||
| "strconv" | ||
| "time" | ||
|
|
||
| "github.com/pion/webrtc/v3" | ||
| ) | ||
|
|
||
| // WaitForDataChannelOpen waits for the data channel to have the open state. | ||
| // By default, it waits 15 seconds. | ||
| func WaitForDataChannelOpen(ctx context.Context, channel *webrtc.DataChannel) error { | ||
| if channel.ReadyState() == webrtc.DataChannelStateOpen { | ||
| return nil | ||
| } | ||
| ctx, cancelFunc := context.WithTimeout(ctx, time.Second*15) | ||
| defer cancelFunc() | ||
| channel.OnOpen(func() { | ||
| cancelFunc() | ||
| }) | ||
| <-ctx.Done() | ||
| if ctx.Err() == context.DeadlineExceeded { | ||
| return ctx.Err() | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| // NewProxyDataChannel creates a new data channel for proxying. | ||
| func NewProxyDataChannel(conn *webrtc.PeerConnection, name, protocol string, port uint16) (*webrtc.DataChannel, error) { | ||
| proto := fmt.Sprintf("%s:%d", protocol, port) | ||
| ordered := true | ||
| return conn.CreateDataChannel(name, &webrtc.DataChannelInit{ | ||
| Protocol: &proto, | ||
| Ordered: &ordered, | ||
| }) | ||
| } | ||
|
|
||
| // ParseProxyDataChannel parses a data channel to get the protocol and port. | ||
| func ParseProxyDataChannel(channel *webrtc.DataChannel) (string, uint16, error) { | ||
| if channel.Protocol() == "" { | ||
| return "", 0, errors.New("data channel is not a proxy") | ||
| } | ||
| host, port, err := net.SplitHostPort(channel.Protocol()) | ||
| if err != nil { | ||
| return "", 0, fmt.Errorf("split protocol: %w", err) | ||
| } | ||
| p, err := strconv.ParseInt(port, 10, 16) | ||
| if err != nil { | ||
| return "", 0, fmt.Errorf("parse port: %w", err) | ||
| } | ||
| return host, uint16(p), nil | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you think
coderURLshould be a--urlflag? I'm imagining a case where this is defaulted tocoder.comin the future for a SaaS offering. However, I know this isn't ideal for now since thecoder agent startcmd would fail without the flags specified, so more so opening this up for discussion. Could potentially defaultcoderURLtocoder.comin the future as well if it's not specified as another option.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah this is suuuper mvp just to get a POC working in coder, a real implementation will follow once we get an idea of how scalable and stable this is.