From 2493d27d528cdea777dfe72e344b18e77d35303b Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Tue, 1 Oct 2024 11:40:42 -0500 Subject: [PATCH] feat: add terminal sharing --- cmd/wasm/main.go | 101 +++++++++ cmd/wush/main.go | 1 + cmd/wush/share.go | 273 ++++++++++++++++++++++++ go.mod | 8 +- site/app/components/Terminal.client.tsx | 17 +- site/cors-config.json | 17 +- site/types/wush_js.d.ts | 28 +-- 7 files changed, 423 insertions(+), 22 deletions(-) create mode 100644 cmd/wush/share.go diff --git a/cmd/wasm/main.go b/cmd/wasm/main.go index b0560e6..33dd92d 100644 --- a/cmd/wasm/main.go +++ b/cmd/wasm/main.go @@ -6,6 +6,7 @@ import ( "bytes" "context" "fmt" + "io" "log" "log/slog" "net" @@ -116,6 +117,30 @@ func newWush(jsConfig js.Value) map[string]any { go sess.Run() + return map[string]any{ + "close": js.FuncOf(func(this js.Value, args []js.Value) any { + return sess.Close() != nil + }), + "resize": js.FuncOf(func(this js.Value, args []js.Value) any { + rows := args[0].Int() + cols := args[1].Int() + return sess.Resize(rows, cols) != nil + }), + } + }), + "share": js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) != 1 { + log.Printf("Usage: ssh({})") + return nil + } + + sess := &shareSession{ + ts: ts, + cfg: args[0], + } + + go sess.Run() + return map[string]any{ "close": js.FuncOf(func(this js.Value, args []js.Value) any { return sess.Close() != nil @@ -263,6 +288,82 @@ func (s *sshSession) Run() { } } +type shareSession struct { + ts *tsnet.Server + cfg js.Value + + conn net.Conn + pendingResizeRows int + pendingResizeCols int +} + +func (s *shareSession) Close() error { + if s.conn == nil { + // We never had a chance to open the session, ignore the close request. + return nil + } + return s.conn.Close() +} + +func (s *shareSession) Resize(rows, cols int) error { + if s.conn == nil { + s.pendingResizeRows = rows + s.pendingResizeCols = cols + return nil + } + + return nil + // return s.session.WindowChange(rows, cols) +} + +func (s *shareSession) Run() { + writeFn := s.cfg.Get("writeFn") + writeErrorFn := s.cfg.Get("writeErrorFn") + setReadFn := s.cfg.Get("setReadFn") + // rows := s.cfg.Get("rows").Int() + // cols := s.cfg.Get("cols").Int() + timeoutSeconds := 5.0 + if jsTimeoutSeconds := s.cfg.Get("timeoutSeconds"); jsTimeoutSeconds.Type() == js.TypeNumber { + timeoutSeconds = jsTimeoutSeconds.Float() + } + onConnectionProgress := s.cfg.Get("onConnectionProgress") + onConnected := s.cfg.Get("onConnected") + onDone := s.cfg.Get("onDone") + defer onDone.Invoke() + + writeError := func(label string, err error) { + writeErrorFn.Invoke(fmt.Sprintf("%s Error: %v\r\n", label, err)) + } + reportProgress := func(message string) { + onConnectionProgress.Invoke(message) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds*float64(time.Second))) + defer cancel() + reportProgress(fmt.Sprintf("Connecting...")) + c, err := s.ts.Dial(ctx, "tcp", net.JoinHostPort("100.64.0.0", "33")) + if err != nil { + writeError("Dial", err) + return + } + defer c.Close() + s.conn = c + reportProgress(fmt.Sprintf("Connected")) + + setReadFn.Invoke(js.FuncOf(func(this js.Value, args []js.Value) any { + input := args[0].String() + _, err := c.Write([]byte(input)) + if err != nil { + writeError("Write Input", err) + } + return nil + })) + + onConnected.Invoke() + tw := termWriter{writeFn} + _, _ = io.Copy(tw, c) +} + type termWriter struct { f js.Value } diff --git a/cmd/wush/main.go b/cmd/wush/main.go index cf935a6..7eb4543 100644 --- a/cmd/wush/main.go +++ b/cmd/wush/main.go @@ -52,6 +52,7 @@ func main() { rsyncCmd(), cpCmd(), portForwardCmd(), + shareCmd(), }, Options: []serpent.Option{ { diff --git a/cmd/wush/share.go b/cmd/wush/share.go new file mode 100644 index 0000000..1d572bf --- /dev/null +++ b/cmd/wush/share.go @@ -0,0 +1,273 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "io/fs" + "log" + "log/slog" + "os" + "os/exec" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/coder/serpent" + "github.com/coder/wush/cliui" + "github.com/coder/wush/overlay" + "github.com/coder/wush/tsserver" + "github.com/creack/pty" + "github.com/mattn/go-isatty" + "golang.org/x/crypto/ssh/terminal" + "tailscale.com/net/netns" +) + +func shareCmd() *serpent.Command { + var ( + overlayType string + verbose bool + enabled = []string{} + disabled = []string{} + ) + return &serpent.Command{ + Use: "share", + Aliases: []string{}, + Short: "Share a terminal.", + Long: "Share a terminal.", + Handler: func(inv *serpent.Invocation) error { + ctx := inv.Context() + defer fmt.Println("[exited]") + + // Switch to the alternate screen buffer + fmt.Print("\033[?1049h") + // Reset cursor to the top-left corner + fmt.Print("\033[H") + // Switch back to the main screen buffer on exit + defer fmt.Print("\033[?1049l") + + var logSink io.Writer = io.Discard + if verbose { + logSink = inv.Stderr + } + logger := slog.New(slog.NewTextHandler(logSink, nil)) + hlog := func(format string, args ...any) { + fmt.Fprintf(inv.Stderr, format+"\n", args...) + } + dm, err := tsserver.DERPMapTailscale(ctx) + if err != nil { + return err + } + // r := overlay.NewReceiveOverlay(logger, hlog, dm) + r := overlay.NewReceiveOverlay(logger, func(format string, args ...any) {}, dm) + + switch overlayType { + case "derp": + err = r.PickDERPHome(ctx) + if err != nil { + return err + } + go r.ListenOverlayDERP(ctx) + + case "stun": + waitStun, err := r.ListenOverlaySTUN(ctx) + if err != nil { + return fmt.Errorf("get stun addr: %w", err) + } + <-waitStun + + default: + return fmt.Errorf("unknown overlay type: %s", overlayType) + } + + // Ensure we always print the auth key on stdout + if isatty.IsTerminal(os.Stdout.Fd()) { + hlog("Your auth key is:") + fmt.Println(" |", cliui.Code(r.ClientAuth().AuthKey())) + fmt.Println(" |", cliui.Code("http://localhost:5173/connect#"+r.ClientAuth().AuthKey())) + hlog("Use this key to authenticate other " + cliui.Code("wush") + " commands to this instance.") + } else { + fmt.Println(cliui.Code(r.ClientAuth().AuthKey())) + hlog("The auth key has been printed to stdout") + } + + s, err := tsserver.NewServer(ctx, logger, r) + if err != nil { + return err + } + + go s.ListenAndServe(ctx) + netns.SetDialerOverride(s.Dialer()) + ts, err := newTSNet("receive") + if err != nil { + return err + } + + ts.Up(ctx) + + hlog("WireGuard is ready") + + ll, err := ts.Listen("tcp", ":33") + if err != nil { + return err + } + + shell := os.Getenv("SHELL") + if shell == "" { + shell = "/bin/sh" + } + + // Save the current state of the terminal + oldState, err := terminal.MakeRaw(int(os.Stdin.Fd())) + if err != nil { + panic(err) + } + defer func() { + _ = terminal.Restore(int(os.Stdin.Fd()), oldState) + }() + + cmd := exec.Command(shell) + ptmx, err := pty.Start(cmd) + if err != nil { + log.Fatal(err) + } + defer func() { _ = ptmx.Close() }() + + // Handle pty size. + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGWINCH) + go func() { + for range ch { + if err := pty.InheritSize(os.Stdin, ptmx); err != nil { + log.Printf("error resizing pty: %s", err) + } + } + }() + ch <- syscall.SIGWINCH // Initial resize. + defer func() { signal.Stop(ch); close(ch) }() // Cleanup signals when done. + + // Copy stdin to the pty and the pty to stdout. + // NOTE: The goroutine will keep reading until the next keystroke before returning. + go func() { _, _ = io.Copy(ptmx, os.Stdin) }() + mw := &multiWriter{wrs: map[int64]io.Writer{}} + buf := bytes.NewBuffer(nil) + mw.AddWriter(buf) + mw.AddWriter(unclose{os.Stdout}) + go func() { _, _ = io.Copy(mw, ptmx) }() + + go func() { + cmd.Wait() + mw.Close() + ll.Close() + }() + + for { + conn, err := ll.Accept() + if err != nil { + return nil + } + + mw.lock() + _, _ = io.Copy(conn, buf) + mw.unlock() + + close := mw.AddWriter(conn) + go func() { defer close(); _, _ = io.Copy(ptmx, conn) }() + } + + }, + Options: []serpent.Option{ + { + Flag: "overlay-type", + Default: "derp", + Value: serpent.EnumOf(&overlayType, "derp", "stun"), + }, + { + Flag: "verbose", + FlagShorthand: "v", + Description: "Enable verbose logging.", + Default: "false", + Value: serpent.BoolOf(&verbose), + }, + { + Flag: "enable", + Description: "Server options to enable.", + Default: "ssh,cp,port-forward", + Value: serpent.EnumArrayOf(&enabled, "ssh", "cp", "port-forward"), + }, + { + Flag: "disable", + Description: "Server options to disable.", + Default: "", + Value: serpent.EnumArrayOf(&disabled, "ssh", "cp", "port-forward"), + }, + }, + } +} + +type multiWriter struct { + mu sync.Mutex + wrs map[int64]io.Writer + closed bool +} + +func (mw *multiWriter) lock() { + mw.mu.Lock() +} + +func (mw *multiWriter) unlock() { + mw.mu.Unlock() +} + +func (mw *multiWriter) Close() error { + mw.mu.Lock() + defer mw.mu.Unlock() + if mw.closed { + return nil + } + mw.closed = true + for _, w := range mw.wrs { + if closer, ok := w.(io.Closer); ok { + _ = closer.Close() + } + } + return nil +} + +func (mw *multiWriter) Write(p []byte) (int, error) { + mw.mu.Lock() + defer mw.mu.Unlock() + if mw.closed { + return 0, fs.ErrClosed + } + // var total int + for _, w := range mw.wrs { + n, err := w.Write(p) + if err != nil { + continue + } + _ = n + } + return len(p), nil +} + +func (mw *multiWriter) AddWriter(w io.Writer) func() { + mw.mu.Lock() + defer mw.mu.Unlock() + id := time.Now().UnixNano() + mw.wrs[id] = w + return func() { + mw.mu.Lock() + defer mw.mu.Unlock() + delete(mw.wrs, id) + } +} + +type unclose struct { + w io.Writer +} + +func (u unclose) Write(p []byte) (int, error) { + return u.w.Write(p) +} diff --git a/go.mod b/go.mod index 0bafc78..ba8971d 100644 --- a/go.mod +++ b/go.mod @@ -9,10 +9,14 @@ replace github.com/gliderlabs/ssh => github.com/coder/ssh v0.0.0-20231128192721- require ( cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6 github.com/btcsuite/btcd/btcutil v1.1.6 + github.com/charmbracelet/bubbles v0.20.0 + github.com/charmbracelet/bubbletea v1.1.0 github.com/charmbracelet/huh v0.6.0 + github.com/charmbracelet/lipgloss v0.13.0 github.com/coder/coder/v2 v2.14.2 github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 github.com/coder/serpent v0.8.0 + github.com/creack/pty v1.1.21 github.com/go-chi/chi/v5 v5.1.0 github.com/klauspost/compress v1.17.10 github.com/mattn/go-isatty v0.0.20 @@ -72,16 +76,12 @@ require ( github.com/catppuccin/go v0.2.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/charmbracelet/bubbles v0.20.0 // indirect - github.com/charmbracelet/bubbletea v1.1.0 // indirect - github.com/charmbracelet/lipgloss v0.13.0 // indirect github.com/charmbracelet/x/ansi v0.2.3 // indirect github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect github.com/charmbracelet/x/term v0.2.0 // indirect github.com/coder/terraform-provider-coder v0.23.0 // indirect github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 // indirect github.com/coreos/go-oidc/v3 v3.11.0 // indirect - github.com/creack/pty v1.1.21 // indirect github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e // indirect github.com/dustin/go-humanize v1.0.1 // indirect diff --git a/site/app/components/Terminal.client.tsx b/site/app/components/Terminal.client.tsx index b8a889c..6a1b7b6 100644 --- a/site/app/components/Terminal.client.tsx +++ b/site/app/components/Terminal.client.tsx @@ -17,6 +17,7 @@ const WushTerminal: React.FC = ({ authKey }) => { const fitAddonRef = useRef(); const wushInitialized = useContext(WushContext); const sshSessionRef = useRef(); + const wushRef = useRef(); useEffect(() => { if (!wushInitialized) { @@ -65,7 +66,8 @@ const WushTerminal: React.FC = ({ authKey }) => { resizeObserver.observe(terminalRef.current); newWush({ authKey: authKey }).then((wush) => { - const sshSession = wush.ssh({ + wushRef.current = wush; + const sshSession = wush.share({ writeFn(input) { term.write(input); }, @@ -77,8 +79,13 @@ const WushTerminal: React.FC = ({ authKey }) => { }, rows: term.rows, cols: term.cols, - onConnectionProgress: (msg) => {}, - onConnected: () => {}, + onConnectionProgress: (msg) => { + term.writeln(msg); + }, + onConnected: () => { + term.writeln(""); + term.clear(); + }, onDone() { resizeObserver?.disconnect(); term.dispose(); @@ -107,6 +114,10 @@ const WushTerminal: React.FC = ({ authKey }) => { sshSessionRef.current.close(); sshSessionRef.current = null; } + if (wushRef.current) { + wushRef.current.stop(); + wushRef.current = null; + } }; }, [authKey, wushInitialized]); diff --git a/site/cors-config.json b/site/cors-config.json index 5768cb4..2d75eda 100644 --- a/site/cors-config.json +++ b/site/cors-config.json @@ -1,8 +1,19 @@ [ { - "origin": ["https://*.wush-1n6.pages.dev", "https://dev.wush.dev", "https://wush.dev"], - "method": ["GET", "HEAD", "OPTIONS"], - "responseHeader": ["Content-Type", "Content-Encoding"], + "origin": [ + "https://*.wush-1n6.pages.dev", + "https://stg.wush.dev", + "https://wush.dev" + ], + "method": [ + "GET", + "HEAD", + "OPTIONS" + ], + "responseHeader": [ + "Content-Type", + "Content-Encoding" + ], "maxAgeSeconds": 3600 } ] diff --git a/site/types/wush_js.d.ts b/site/types/wush_js.d.ts index 77cd46b..4cb1a1e 100644 --- a/site/types/wush_js.d.ts +++ b/site/types/wush_js.d.ts @@ -4,20 +4,24 @@ declare global { interface Wush { run(callbacks: WushCallbacks): void; - ssh(termConfig: { - writeFn: (data: string) => void; - writeErrorFn: (err: string) => void; - setReadFn: (readFn: (data: string) => void) => void; - rows: number; - cols: number; - /** Defaults to 5 seconds */ - timeoutSeconds?: number; - onConnectionProgress: (message: string) => void; - onConnected: () => void; - onDone: () => void; - }): WushSSHSession; + stop(): void; + ssh(termConfig: TerminalConfig): WushSSHSession; + share(termConfig: TerminalConfig): WushSSHSession; } + type TerminalConfig = { + writeFn: (data: string) => void; + writeErrorFn: (err: string) => void; + setReadFn: (readFn: (data: string) => void) => void; + rows: number; + cols: number; + /** Defaults to 5 seconds */ + timeoutSeconds?: number; + onConnectionProgress: (message: string) => void; + onConnected: () => void; + onDone: () => void; + }; + interface WushSSHSession { resize(rows: number, cols: number): boolean; close(): boolean;