diff --git a/codersdk/workspacesdk/workspacesdk.go b/codersdk/workspacesdk/workspacesdk.go index df851e5ac31e9..25188917dafc9 100644 --- a/codersdk/workspacesdk/workspacesdk.go +++ b/codersdk/workspacesdk/workspacesdk.go @@ -128,12 +128,19 @@ func init() { } } +type resolver interface { + LookupIP(ctx context.Context, network, host string) ([]net.IP, error) +} + type Client struct { client *codersdk.Client + + // overridden in tests + resolver resolver } func New(c *codersdk.Client) *Client { - return &Client{client: c} + return &Client{client: c, resolver: net.DefaultResolver} } // AgentConnectionInfo returns required information for establishing @@ -384,3 +391,46 @@ func (c *Client) AgentReconnectingPTY(ctx context.Context, opts WorkspaceAgentRe } return websocket.NetConn(context.Background(), conn, websocket.MessageBinary), nil } + +type CoderConnectQueryOptions struct { + HostnameSuffix string +} + +// IsCoderConnectRunning checks if Coder Connect (OS level tunnel to workspaces) is running on the system. If you +// already know the hostname suffix your deployment uses, you can pass it in the CoderConnectQueryOptions to avoid an +// API call to AgentConnectionInfoGeneric. +func (c *Client) IsCoderConnectRunning(ctx context.Context, o CoderConnectQueryOptions) (bool, error) { + suffix := o.HostnameSuffix + if suffix == "" { + info, err := c.AgentConnectionInfoGeneric(ctx) + if err != nil { + return false, xerrors.Errorf("get agent connection info: %w", err) + } + suffix = info.HostnameSuffix + } + domainName := fmt.Sprintf(tailnet.IsCoderConnectEnabledFmtString, suffix) + var dnsError *net.DNSError + ips, err := c.resolver.LookupIP(ctx, "ip6", domainName) + if xerrors.As(err, &dnsError) { + if dnsError.IsNotFound { + return false, nil + } + } + if err != nil { + return false, xerrors.Errorf("lookup DNS %s: %w", domainName, err) + } + + // The returned IP addresses are probably from the Coder Connect DNS server, but there are sometimes weird captive + // internet setups where the DNS server is configured to return an address for any IP query. So, to avoid false + // positives, check if we can find an address from our service prefix. + for _, ip := range ips { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + continue + } + if tailnet.CoderServicePrefix.AsNetip().Contains(addr) { + return true, nil + } + } + return false, nil +} diff --git a/codersdk/workspacesdk/workspacesdk_internal_test.go b/codersdk/workspacesdk/workspacesdk_internal_test.go new file mode 100644 index 0000000000000..1b98ebdc2e671 --- /dev/null +++ b/codersdk/workspacesdk/workspacesdk_internal_test.go @@ -0,0 +1,86 @@ +package workspacesdk + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" + + "tailscale.com/net/tsaddr" + + "github.com/coder/coder/v2/tailnet" +) + +func TestClient_IsCoderConnectRunning(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/v2/workspaceagents/connection", r.URL.Path) + httpapi.Write(ctx, rw, http.StatusOK, AgentConnectionInfo{ + HostnameSuffix: "test", + }) + })) + defer srv.Close() + + apiURL, err := url.Parse(srv.URL) + require.NoError(t, err) + sdkClient := codersdk.New(apiURL) + client := New(sdkClient) + + // Right name, right IP + expectedName := fmt.Sprintf(tailnet.IsCoderConnectEnabledFmtString, "test") + client.resolver = &fakeResolver{t: t, hostMap: map[string][]net.IP{ + expectedName: {net.ParseIP(tsaddr.CoderServiceIPv6().String())}, + }} + + result, err := client.IsCoderConnectRunning(ctx, CoderConnectQueryOptions{}) + require.NoError(t, err) + require.True(t, result) + + // Wrong name + result, err = client.IsCoderConnectRunning(ctx, CoderConnectQueryOptions{HostnameSuffix: "coder"}) + require.NoError(t, err) + require.False(t, result) + + // Not found + client.resolver = &fakeResolver{t: t, err: &net.DNSError{IsNotFound: true}} + result, err = client.IsCoderConnectRunning(ctx, CoderConnectQueryOptions{}) + require.NoError(t, err) + require.False(t, result) + + // Some other error + client.resolver = &fakeResolver{t: t, err: xerrors.New("a bad thing happened")} + _, err = client.IsCoderConnectRunning(ctx, CoderConnectQueryOptions{}) + require.Error(t, err) + + // Right name, wrong IP + client.resolver = &fakeResolver{t: t, hostMap: map[string][]net.IP{ + expectedName: {net.ParseIP("2001::34")}, + }} + result, err = client.IsCoderConnectRunning(ctx, CoderConnectQueryOptions{}) + require.NoError(t, err) + require.False(t, result) +} + +type fakeResolver struct { + t testing.TB + hostMap map[string][]net.IP + err error +} + +func (f *fakeResolver) LookupIP(_ context.Context, network, host string) ([]net.IP, error) { + assert.Equal(f.t, "ip6", network) + return f.hostMap[host], f.err +}