diff --git a/cli/vpndaemon_windows.go b/cli/vpndaemon_windows.go index d09733817d787..227bd0fe8e0db 100644 --- a/cli/vpndaemon_windows.go +++ b/cli/vpndaemon_windows.go @@ -41,7 +41,10 @@ func (r *RootCmd) vpnDaemonRun() *serpent.Command { }, Handler: func(inv *serpent.Invocation) error { ctx := inv.Context() - logger := inv.Logger.AppendSinks(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelDebug) + sinks := []slog.Sink{ + sloghuman.Sink(inv.Stderr), + } + logger := inv.Logger.AppendSinks(sinks...).Leveled(slog.LevelDebug) if rpcReadHandleInt < 0 || rpcWriteHandleInt < 0 { return xerrors.Errorf("rpc-read-handle (%v) and rpc-write-handle (%v) must be positive", rpcReadHandleInt, rpcWriteHandleInt) @@ -60,7 +63,11 @@ func (r *RootCmd) vpnDaemonRun() *serpent.Command { defer pipe.Close() logger.Info(ctx, "starting tunnel") - tunnel, err := vpn.NewTunnel(ctx, logger, pipe, vpn.NewClient()) + tunnel, err := vpn.NewTunnel(ctx, logger, pipe, vpn.NewClient(), + vpn.UseOSNetworkingStack(), + vpn.UseAsLogger(), + vpn.UseCustomLogSinks(sinks...), + ) if err != nil { return xerrors.Errorf("create new tunnel for client: %w", err) } diff --git a/go.mod b/go.mod index 3268e221a9020..8e451210530f3 100644 --- a/go.mod +++ b/go.mod @@ -423,7 +423,7 @@ require ( go.opentelemetry.io/proto/otlp v1.4.0 // indirect go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect golang.org/x/time v0.9.0 // indirect - golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect + golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect google.golang.org/appengine v1.6.8 // indirect diff --git a/tailnet/conn.go b/tailnet/conn.go index ff96211702485..6487dff4e8550 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -116,6 +116,9 @@ type Options struct { Router router.Router // TUNDev is optional, and is passed to the underlying wireguard engine. TUNDev tun.Device + // WireguardMonitor is optional, and is passed to the underlying wireguard + // engine. + WireguardMonitor *netmon.Monitor } // TelemetrySink allows tailnet.Conn to send network telemetry to the Coder @@ -171,13 +174,15 @@ func NewConn(options *Options) (conn *Conn, err error) { nodeID = tailcfg.NodeID(uid) } - wireguardMonitor, err := netmon.New(Logger(options.Logger.Named("net.wgmonitor"))) - if err != nil { - return nil, xerrors.Errorf("create wireguard link monitor: %w", err) + if options.WireguardMonitor == nil { + options.WireguardMonitor, err = netmon.New(Logger(options.Logger.Named("net.wgmonitor"))) + if err != nil { + return nil, xerrors.Errorf("create wireguard link monitor: %w", err) + } } defer func() { if err != nil { - wireguardMonitor.Close() + options.WireguardMonitor.Close() } }() @@ -186,7 +191,7 @@ func NewConn(options *Options) (conn *Conn, err error) { } sys := new(tsd.System) wireguardEngine, err := wgengine.NewUserspaceEngine(Logger(options.Logger.Named("net.wgengine")), wgengine.Config{ - NetMon: wireguardMonitor, + NetMon: options.WireguardMonitor, Dialer: dialer, ListenPort: options.ListenPort, SetSubsystem: sys.Set, @@ -293,7 +298,7 @@ func NewConn(options *Options) (conn *Conn, err error) { listeners: map[listenKey]*listener{}, tunDevice: sys.Tun.Get(), netStack: netStack, - wireguardMonitor: wireguardMonitor, + wireguardMonitor: options.WireguardMonitor, wireguardRouter: &router.Config{ LocalAddrs: options.Addresses, }, diff --git a/vpn/client.go b/vpn/client.go index 1ee166c704441..e2d846ca2343d 100644 --- a/vpn/client.go +++ b/vpn/client.go @@ -8,6 +8,7 @@ import ( "golang.org/x/xerrors" "tailscale.com/net/dns" + "tailscale.com/net/netmon" "tailscale.com/wgengine/router" "github.com/google/uuid" @@ -57,12 +58,13 @@ func NewClient() Client { } type Options struct { - Headers http.Header - Logger slog.Logger - DNSConfigurator dns.OSConfigurator - Router router.Router - TUNFileDescriptor *int - UpdateHandler tailnet.UpdatesHandler + Headers http.Header + Logger slog.Logger + DNSConfigurator dns.OSConfigurator + Router router.Router + TUNDevice tun.Device + WireguardMonitor *netmon.Monitor + UpdateHandler tailnet.UpdatesHandler } func (*client) NewConn(initCtx context.Context, serverURL *url.URL, token string, options *Options) (vpnC Conn, err error) { @@ -74,15 +76,6 @@ func (*client) NewConn(initCtx context.Context, serverURL *url.URL, token string options.Headers = http.Header{} } - var dev tun.Device - if options.TUNFileDescriptor != nil { - // No-op on non-Darwin platforms. - dev, err = makeTUN(*options.TUNFileDescriptor) - if err != nil { - return nil, xerrors.Errorf("make TUN: %w", err) - } - } - headers := options.Headers sdk := codersdk.New(serverURL) sdk.SetSessionToken(token) @@ -134,7 +127,8 @@ func (*client) NewConn(initCtx context.Context, serverURL *url.URL, token string BlockEndpoints: connInfo.DisableDirectConnections, DNSConfigurator: options.DNSConfigurator, Router: options.Router, - TUNDev: dev, + TUNDev: options.TUNDevice, + WireguardMonitor: options.WireguardMonitor, }) if err != nil { return nil, xerrors.Errorf("create tailnet: %w", err) diff --git a/vpn/dylib/lib.go b/vpn/dylib/lib.go index 465f8afd07190..de6f91042c7ef 100644 --- a/vpn/dylib/lib.go +++ b/vpn/dylib/lib.go @@ -47,8 +47,7 @@ func OpenTunnel(cReadFD, cWriteFD int32) int32 { } _, err = vpn.NewTunnel(ctx, slog.Make(), conn, vpn.NewClient(), - vpn.UseAsDNSConfig(), - vpn.UseAsRouter(), + vpn.UseOSNetworkingStack(), vpn.UseAsLogger(), ) if err != nil { diff --git a/vpn/tun.go b/vpn/tun.go index f8c51bff34390..1c3ac8e014d15 100644 --- a/vpn/tun.go +++ b/vpn/tun.go @@ -1,10 +1,10 @@ -//go:build !darwin +//go:build !darwin && !windows package vpn -import "github.com/tailscale/wireguard-go/tun" +import "cdr.dev/slog" -// This is a no-op on non-Darwin platforms. -func makeTUN(int) (tun.Device, error) { - return nil, nil +// This is a no-op on every platform except Darwin and Windows. +func GetNetworkingStack(_ *Tunnel, _ *StartRequest, _ slog.Logger) (NetworkStack, error) { + return NetworkStack{}, nil } diff --git a/vpn/tun_darwin.go b/vpn/tun_darwin.go index f710c75575009..607be6c1babfc 100644 --- a/vpn/tun_darwin.go +++ b/vpn/tun_darwin.go @@ -5,26 +5,34 @@ package vpn import ( "os" + "cdr.dev/slog" "github.com/tailscale/wireguard-go/tun" "golang.org/x/sys/unix" "golang.org/x/xerrors" ) -func makeTUN(tunFD int) (tun.Device, error) { - dupTunFd, err := unix.Dup(tunFD) +func GetNetworkingStack(t *Tunnel, req *StartRequest, _ slog.Logger) (NetworkStack, error) { + tunFd := int(req.GetTunnelFileDescriptor()) + dupTunFd, err := unix.Dup(tunFd) if err != nil { - return nil, xerrors.Errorf("dup tun fd: %w", err) + return NetworkStack{}, xerrors.Errorf("dup tun fd: %w", err) } err = unix.SetNonblock(dupTunFd, true) if err != nil { unix.Close(dupTunFd) - return nil, xerrors.Errorf("set nonblock: %w", err) + return NetworkStack{}, xerrors.Errorf("set nonblock: %w", err) } fileTun, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0) if err != nil { unix.Close(dupTunFd) - return nil, xerrors.Errorf("create TUN from File: %w", err) + return NetworkStack{}, xerrors.Errorf("create TUN from File: %w", err) } - return fileTun, nil + + return NetworkStack{ + WireguardMonitor: nil, // default is fine + TUNDevice: fileTun, + Router: NewRouter(t), + DNSConfigurator: NewDNSConfigurator(t), + }, nil } diff --git a/vpn/tun_windows.go b/vpn/tun_windows.go new file mode 100644 index 0000000000000..45897934ccc8f --- /dev/null +++ b/vpn/tun_windows.go @@ -0,0 +1,115 @@ +//go:build windows + +package vpn + +import ( + "context" + "errors" + "time" + + "github.com/coder/retry" + "github.com/tailscale/wireguard-go/tun" + "golang.org/x/sys/windows" + "golang.org/x/xerrors" + "golang.zx2c4.com/wintun" + "tailscale.com/net/dns" + "tailscale.com/net/netmon" + "tailscale.com/net/tstun" + "tailscale.com/types/logger" + "tailscale.com/util/winutil" + "tailscale.com/wgengine/router" + + "cdr.dev/slog" + "github.com/coder/coder/v2/tailnet" +) + +const tunName = "Coder" + +func GetNetworkingStack(t *Tunnel, _ *StartRequest, logger slog.Logger) (NetworkStack, error) { + tun.WintunTunnelType = tunName + guid, err := windows.GUIDFromString("{0ed1515d-04a4-4c46-abae-11ad07cf0e6d}") + if err != nil { + panic(err) + } + tun.WintunStaticRequestedGUID = &guid + + tunDev, tunName, err := tstunNewWithWindowsRetries(tailnet.Logger(logger.Named("net.tun.device")), tunName) + if err != nil { + return NetworkStack{}, xerrors.Errorf("create tun device: %w", err) + } + logger.Info(context.Background(), "tun created", slog.F("name", tunName)) + + wireguardMonitor, err := netmon.New(tailnet.Logger(logger.Named("net.wgmonitor"))) + + coderRouter, err := router.New(tailnet.Logger(logger.Named("net.router")), tunDev, wireguardMonitor) + if err != nil { + return NetworkStack{}, xerrors.Errorf("create router: %w", err) + } + + dnsConfigurator, err := dns.NewOSConfigurator(tailnet.Logger(logger.Named("net.dns")), tunName) + if err != nil { + return NetworkStack{}, xerrors.Errorf("create dns configurator: %w", err) + } + + return NetworkStack{ + WireguardMonitor: nil, // default is fine + TUNDevice: tunDev, + Router: coderRouter, + DNSConfigurator: dnsConfigurator, + }, nil +} + +// tstunNewOrRetry is a wrapper around tstun.New that retries on Windows for certain +// errors. +// +// This is taken from Tailscale: +// https://github.com/tailscale/tailscale/blob/3abfbf50aebbe3ba57dc749165edb56be6715c0a/cmd/tailscaled/tailscaled_windows.go#L107 +func tstunNewWithWindowsRetries(logf logger.Logf, tunName string) (_ tun.Device, devName string, _ error) { + r := retry.New(250*time.Millisecond, 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + for r.Wait(ctx) { + dev, devName, err := tstun.New(logf, tunName) + if err == nil { + return dev, devName, err + } + if errors.Is(err, windows.ERROR_DEVICE_NOT_AVAILABLE) || windowsUptime() < 10*time.Minute { + // Wintun is not installing correctly. Dump the state of NetSetupSvc + // (which is a user-mode service that must be active for network devices + // to install) and its dependencies to the log. + winutil.LogSvcState(logf, "NetSetupSvc") + } + } + + return nil, "", ctx.Err() +} + +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + getTickCount64Proc = kernel32.NewProc("GetTickCount64") +) + +func windowsUptime() time.Duration { + r, _, _ := getTickCount64Proc.Call() + return time.Duration(int64(r)) * time.Millisecond +} + +// TODO(@dean): implement a way to install/uninstall the wintun driver, most +// likely as a CLI command +// +// This is taken from Tailscale: +// https://github.com/tailscale/tailscale/blob/3abfbf50aebbe3ba57dc749165edb56be6715c0a/cmd/tailscaled/tailscaled_windows.go#L543 +func uninstallWinTun(logf logger.Logf) { + dll := windows.NewLazyDLL("wintun.dll") + if err := dll.Load(); err != nil { + logf("Cannot load wintun.dll for uninstall: %v", err) + return + } + + logf("Removing wintun driver...") + err := wintun.Uninstall() + logf("Uninstall: %v", err) +} + +// TODO(@dean): remove +var _ = uninstallWinTun diff --git a/vpn/tunnel.go b/vpn/tunnel.go index 6d6983b03946f..d9f2877ebbc9d 100644 --- a/vpn/tunnel.go +++ b/vpn/tunnel.go @@ -15,16 +15,16 @@ import ( "time" "unicode" + "github.com/google/uuid" + "github.com/tailscale/wireguard-go/tun" "golang.org/x/xerrors" "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/net/dns" + "tailscale.com/net/netmon" "tailscale.com/util/dnsname" "tailscale.com/wgengine/router" - "github.com/google/uuid" - "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/tailnet" "github.com/coder/quartz" ) @@ -51,9 +51,8 @@ type Tunnel struct { // option is used, to avoid the tunnel using itself as a sink for it's own // logs, which could lead to deadlocks. clientLogger slog.Logger - // router and dnsConfigurator may be nil - router router.Router - dnsConfigurator dns.OSConfigurator + // the following may be nil + networkingStackFn func(*Tunnel, *StartRequest, slog.Logger) (NetworkStack, error) } type TunnelOption func(t *Tunnel) @@ -169,21 +168,28 @@ func (t *Tunnel) handleRPC(req *request[*TunnelMessage, *ManagerMessage]) { } } -func UseAsRouter() TunnelOption { +type NetworkStack struct { + WireguardMonitor *netmon.Monitor + TUNDevice tun.Device + Router router.Router + DNSConfigurator dns.OSConfigurator +} + +func UseOSNetworkingStack() TunnelOption { return func(t *Tunnel) { - t.router = NewRouter(t) + t.networkingStackFn = GetNetworkingStack } } func UseAsLogger() TunnelOption { return func(t *Tunnel) { - t.clientLogger = slog.Make(t) + t.clientLogger = t.clientLogger.AppendSinks(t) } } -func UseAsDNSConfig() TunnelOption { +func UseCustomLogSinks(sinks ...slog.Sink) TunnelOption { return func(t *Tunnel) { - t.dnsConfigurator = NewDNSConfigurator(t) + t.clientLogger = t.clientLogger.AppendSinks(sinks...) } } @@ -227,18 +233,28 @@ func (t *Tunnel) start(req *StartRequest) error { for _, h := range req.GetHeaders() { header.Add(h.GetName(), h.GetValue()) } + var networkingStack NetworkStack + if t.networkingStackFn != nil { + networkingStack, err = t.networkingStackFn(t, req, t.clientLogger) + if err != nil { + return xerrors.Errorf("failed to create networking stack dependencies: %w", err) + } + } else { + t.logger.Debug(t.ctx, "using default networking stack as no custom stack was provided") + } conn, err := t.client.NewConn( t.ctx, svrURL, apiToken, &Options{ - Headers: header, - Logger: t.clientLogger, - DNSConfigurator: t.dnsConfigurator, - Router: t.router, - TUNFileDescriptor: ptr.Ref(int(req.GetTunnelFileDescriptor())), - UpdateHandler: t, + Headers: header, + Logger: t.clientLogger, + DNSConfigurator: networkingStack.DNSConfigurator, + Router: networkingStack.Router, + TUNDevice: networkingStack.TUNDevice, + WireguardMonitor: networkingStack.WireguardMonitor, + UpdateHandler: t, }, ) if err != nil {