From abbd3b9aaf2569aa303e3f02b3c1b3e3ea32062c Mon Sep 17 00:00:00 2001 From: Hugo Dutka Date: Tue, 8 Jul 2025 16:08:42 +0000 Subject: [PATCH 1/2] remove dbmem --- .claude/docs/DATABASE.md | 39 - .claude/docs/OAUTH2.md | 15 +- .claude/docs/TESTING.md | 35 +- .claude/docs/TROUBLESHOOTING.md | 33 +- .claude/docs/WORKFLOWS.md | 10 +- .github/workflows/ci.yaml | 145 +- .golangci.yaml | 1 - CLAUDE.md | 6 +- Makefile | 3 +- cli/server.go | 64 +- cli/server_test.go | 3 - coderd/database/dbauthz/customroles_test.go | 4 +- coderd/database/dbauthz/dbauthz_test.go | 111 +- coderd/database/dbmem/dbmem.go | 14242 ---------------- coderd/database/dbmem/dbmem_test.go | 209 - coderd/database/dbtestutil/db.go | 85 +- coderd/database/oidcclaims_test.go | 1 - coderd/database/querier_test.go | 1 - coderd/notifications/notifications_test.go | 2 +- coderd/users_test.go | 10 - codersdk/deployment.go | 10 - docs/about/contributing/backend.md | 1 - enterprise/cli/server_test.go | 3 - enterprise/coderd/coderd.go | 8 +- .../coderd/coderdenttest/coderdenttest.go | 3 - scripts/dbgen/main.go | 83 +- site/e2e/playwright.config.ts | 2 +- 27 files changed, 143 insertions(+), 14986 deletions(-) delete mode 100644 coderd/database/dbmem/dbmem.go delete mode 100644 coderd/database/dbmem/dbmem_test.go diff --git a/.claude/docs/DATABASE.md b/.claude/docs/DATABASE.md index f6ba4bd78859b..090054772fc32 100644 --- a/.claude/docs/DATABASE.md +++ b/.claude/docs/DATABASE.md @@ -58,31 +58,6 @@ If adding fields to auditable types: - `ActionSecret`: Field contains sensitive data 3. Run `make gen` to verify no audit errors -## In-Memory Database (dbmem) Updates - -### Critical Requirements - -When adding new fields to database structs: - -- **CRITICAL**: Update `coderd/database/dbmem/dbmem.go` in-memory implementations -- The `Insert*` functions must include ALL new fields, not just basic ones -- Common issue: Tests pass with real database but fail with in-memory database due to missing field mappings -- Always verify in-memory database functions match the real database schema after migrations - -### Example Pattern - -```go -// In dbmem.go - ensure ALL fields are included -code := database.OAuth2ProviderAppCode{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - // ... existing fields ... - ResourceUri: arg.ResourceUri, // New field - CodeChallenge: arg.CodeChallenge, // New field - CodeChallengeMethod: arg.CodeChallengeMethod, // New field -} -``` - ## Database Architecture ### Core Components @@ -116,7 +91,6 @@ roles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), user 1. **Nullable field errors**: Use `sql.Null*` types consistently 2. **Missing audit entries**: Update `enterprise/audit/table.go` -3. **dbmem inconsistencies**: Ensure in-memory implementations match schema ### Query Issues @@ -139,19 +113,6 @@ func TestDatabaseFunction(t *testing.T) { } ``` -### In-Memory Testing - -```go -func TestInMemoryDatabase(t *testing.T) { - db := dbmem.New() - - // Test with in-memory database - result, err := db.GetSomething(ctx, param) - require.NoError(t, err) - require.Equal(t, expected, result) -} -``` - ## Best Practices ### Schema Design diff --git a/.claude/docs/OAUTH2.md b/.claude/docs/OAUTH2.md index 2c766dd083516..9fb34f093042a 100644 --- a/.claude/docs/OAUTH2.md +++ b/.claude/docs/OAUTH2.md @@ -112,14 +112,13 @@ Always run the full test suite after OAuth2 changes: ## Common OAuth2 Issues 1. **OAuth2 endpoints returning wrong error format** - Ensure OAuth2 endpoints return RFC 6749 compliant errors -2. **OAuth2 tests failing but scripts working** - Check in-memory database implementations in `dbmem.go` -3. **Resource indicator validation failing** - Ensure database stores and retrieves resource parameters correctly -4. **PKCE tests failing** - Verify both authorization code storage and token exchange handle PKCE fields -5. **RFC compliance failures** - Verify against actual RFC specifications, not assumptions -6. **Authorization context errors in public endpoints** - Use `dbauthz.AsSystemRestricted(ctx)` pattern -7. **Default value mismatches** - Ensure database migrations match application code defaults -8. **Bearer token authentication issues** - Check token extraction precedence and format validation -9. **URI validation failures** - Support both standard schemes and custom schemes per protocol requirements +2. **Resource indicator validation failing** - Ensure database stores and retrieves resource parameters correctly +3. **PKCE tests failing** - Verify both authorization code storage and token exchange handle PKCE fields +4. **RFC compliance failures** - Verify against actual RFC specifications, not assumptions +5. **Authorization context errors in public endpoints** - Use `dbauthz.AsSystemRestricted(ctx)` pattern +6. **Default value mismatches** - Ensure database migrations match application code defaults +7. **Bearer token authentication issues** - Check token extraction precedence and format validation +8. **URI validation failures** - Support both standard schemes and custom schemes per protocol requirements ## Authorization Context Patterns diff --git a/.claude/docs/TESTING.md b/.claude/docs/TESTING.md index b8f92f531bb1c..eff655b0acadc 100644 --- a/.claude/docs/TESTING.md +++ b/.claude/docs/TESTING.md @@ -39,31 +39,6 @@ 2. **Verify information disclosure protections** 3. **Test token security and proper invalidation** -## Database Testing - -### In-Memory Database Testing - -When adding new database fields: - -- **CRITICAL**: Update `coderd/database/dbmem/dbmem.go` in-memory implementations -- The `Insert*` functions must include ALL new fields, not just basic ones -- Common issue: Tests pass with real database but fail with in-memory database due to missing field mappings -- Always verify in-memory database functions match the real database schema after migrations - -Example pattern: - -```go -// In dbmem.go - ensure ALL fields are included -code := database.OAuth2ProviderAppCode{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - // ... existing fields ... - ResourceUri: arg.ResourceUri, // New field - CodeChallenge: arg.CodeChallenge, // New field - CodeChallengeMethod: arg.CodeChallengeMethod, // New field -} -``` - ## Test Organization ### Test File Structure @@ -107,15 +82,13 @@ coderd/ ### Database-Related -1. **Tests passing locally but failing in CI** - Check if `dbmem` implementation needs updating -2. **SQL type errors** - Use `sql.Null*` types for nullable fields -3. **Race conditions in tests** - Use unique identifiers instead of hardcoded names +1. **SQL type errors** - Use `sql.Null*` types for nullable fields +2. **Race conditions in tests** - Use unique identifiers instead of hardcoded names ### OAuth2 Testing -1. **OAuth2 tests failing but scripts working** - Check in-memory database implementations in `dbmem.go` -2. **PKCE tests failing** - Verify both authorization code storage and token exchange handle PKCE fields -3. **Resource indicator validation failing** - Ensure database stores and retrieves resource parameters correctly +1. **PKCE tests failing** - Verify both authorization code storage and token exchange handle PKCE fields +2. **Resource indicator validation failing** - Ensure database stores and retrieves resource parameters correctly ### General Issues diff --git a/.claude/docs/TROUBLESHOOTING.md b/.claude/docs/TROUBLESHOOTING.md index 2b4bb3ee064cc..19c05a7a0cd62 100644 --- a/.claude/docs/TROUBLESHOOTING.md +++ b/.claude/docs/TROUBLESHOOTING.md @@ -21,59 +21,50 @@ } ``` -3. **Tests passing locally but failing in CI** - - **Solution**: Check if `dbmem` implementation needs updating - - Update `coderd/database/dbmem/dbmem.go` for Insert/Update methods - - Missing fields in dbmem can cause tests to fail even if main implementation is correct - ### Testing Issues -4. **"package should be X_test"** +3. **"package should be X_test"** - **Solution**: Use `package_test` naming for test files - Example: `identityprovider_test` for black-box testing -5. **Race conditions in tests** +4. **Race conditions in tests** - **Solution**: Use unique identifiers instead of hardcoded names - Example: `fmt.Sprintf("test-client-%s-%d", t.Name(), time.Now().UnixNano())` - Never use hardcoded names in concurrent tests -6. **Missing newlines** +5. **Missing newlines** - **Solution**: Ensure files end with newline character - Most editors can be configured to add this automatically ### OAuth2 Issues -7. **OAuth2 endpoints returning wrong error format** +6. **OAuth2 endpoints returning wrong error format** - **Solution**: Ensure OAuth2 endpoints return RFC 6749 compliant errors - Use standard error codes: `invalid_client`, `invalid_grant`, `invalid_request` - Format: `{"error": "code", "error_description": "details"}` -8. **OAuth2 tests failing but scripts working** - - **Solution**: Check in-memory database implementations in `dbmem.go` - - Ensure all OAuth2 fields are properly copied in Insert/Update methods - -9. **Resource indicator validation failing** +7. **Resource indicator validation failing** - **Solution**: Ensure database stores and retrieves resource parameters correctly - Check both authorization code storage and token exchange handling -10. **PKCE tests failing** +8. **PKCE tests failing** - **Solution**: Verify both authorization code storage and token exchange handle PKCE fields - Check `CodeChallenge` and `CodeChallengeMethod` field handling ### RFC Compliance Issues -11. **RFC compliance failures** +9. **RFC compliance failures** - **Solution**: Verify against actual RFC specifications, not assumptions - Use WebFetch tool to get current RFC content for compliance verification - Read the actual RFC specifications before implementation -12. **Default value mismatches** +10. **Default value mismatches** - **Solution**: Ensure database migrations match application code defaults - Example: RFC 7591 specifies `client_secret_basic` as default, not `client_secret_post` ### Authorization Issues -13. **Authorization context errors in public endpoints** +11. **Authorization context errors in public endpoints** - **Solution**: Use `dbauthz.AsSystemRestricted(ctx)` pattern - Example: @@ -84,17 +75,17 @@ ### Authentication Issues -14. **Bearer token authentication issues** +12. **Bearer token authentication issues** - **Solution**: Check token extraction precedence and format validation - Ensure proper RFC 6750 Bearer Token Support implementation -15. **URI validation failures** +13. **URI validation failures** - **Solution**: Support both standard schemes and custom schemes per protocol requirements - Native OAuth2 apps may use custom schemes ### General Development Issues -16. **Log message formatting errors** +14. **Log message formatting errors** - **Solution**: Use lowercase, descriptive messages without special characters - Follow Go logging conventions diff --git a/.claude/docs/WORKFLOWS.md b/.claude/docs/WORKFLOWS.md index 1bd595c8a4b34..b846110d589d8 100644 --- a/.claude/docs/WORKFLOWS.md +++ b/.claude/docs/WORKFLOWS.md @@ -81,11 +81,6 @@ - Add each new field with appropriate action (ActionTrack, ActionIgnore, ActionSecret) - Run `make gen` to verify no audit errors -6. **In-memory database (dbmem) updates**: - - When adding new fields to database structs, ensure `dbmem` implementation copies all fields - - Check `coderd/database/dbmem/dbmem.go` for Insert/Update methods - - Missing fields in dbmem can cause tests to fail even if main implementation is correct - ### Database Generation Process 1. Modify SQL files in `coderd/database/queries/` @@ -164,9 +159,8 @@ 1. **Development server won't start** - Use `./scripts/develop.sh` instead of manual commands 2. **Database migration errors** - Check migration file format and use helper scripts -3. **Test failures after database changes** - Update `dbmem` implementations -4. **Audit table errors** - Update `enterprise/audit/table.go` with new fields -5. **OAuth2 compliance issues** - Ensure RFC-compliant error responses +3. **Audit table errors** - Update `enterprise/audit/table.go` with new fields +4. **OAuth2 compliance issues** - Ensure RFC-compliant error responses ### Debug Commands diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8c4e21466c03d..93bd1a9ea65aa 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -311,94 +311,6 @@ jobs: - name: Check for unstaged files run: ./scripts/check_unstaged.sh - test-go: - runs-on: ${{ matrix.os == 'ubuntu-latest' && github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || matrix.os == 'macos-latest' && github.repository_owner == 'coder' && 'depot-macos-latest' || matrix.os == 'windows-2022' && github.repository_owner == 'coder' && 'depot-windows-2022-16' || matrix.os }} - needs: changes - if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main' - timeout-minutes: 20 - strategy: - fail-fast: false - matrix: - os: - - ubuntu-latest - - macos-latest - - windows-2022 - steps: - - name: Harden Runner - # Harden Runner is only supported on Ubuntu runners. - if: runner.os == 'Linux' - uses: step-security/harden-runner@6c439dc8bdf85cadbbce9ed30d1c7b959517bc49 # v2.12.2 - with: - egress-policy: audit - - # Set up RAM disks to speed up the rest of the job. This action is in - # a separate repository to allow its use before actions/checkout. - - name: Setup RAM Disks - if: runner.os == 'Windows' - uses: coder/setup-ramdisk-action@e1100847ab2d7bcd9d14bcda8f2d1b0f07b36f1b - - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 1 - - - name: Setup Go Paths - uses: ./.github/actions/setup-go-paths - - - name: Setup Go - uses: ./.github/actions/setup-go - with: - # Runners have Go baked-in and Go will automatically - # download the toolchain configured in go.mod, so we don't - # need to reinstall it. It's faster on Windows runners. - use-preinstalled-go: ${{ runner.os == 'Windows' }} - - - name: Setup Terraform - uses: ./.github/actions/setup-tf - - - name: Download Test Cache - id: download-cache - uses: ./.github/actions/test-cache/download - with: - key-prefix: test-go-${{ runner.os }}-${{ runner.arch }} - - - name: Test with Mock Database - id: test - shell: bash - run: | - # if macOS, install google-chrome for scaletests. As another concern, - # should we really have this kind of external dependency requirement - # on standard CI? - if [ "${{ matrix.os }}" == "macos-latest" ]; then - brew install google-chrome - fi - - # By default Go will use the number of logical CPUs, which - # is a fine default. - PARALLEL_FLAG="" - - # macOS will output "The default interactive shell is now zsh" - # intermittently in CI... - if [ "${{ matrix.os }}" == "macos-latest" ]; then - touch ~/.bash_profile && echo "export BASH_SILENCE_DEPRECATION_WARNING=1" >> ~/.bash_profile - fi - export TS_DEBUG_DISCO=true - gotestsum --junitfile="gotests.xml" --jsonfile="gotests.json" --rerun-fails=2 \ - --packages="./..." -- $PARALLEL_FLAG -short - - - name: Upload Test Cache - uses: ./.github/actions/test-cache/upload - with: - cache-key: ${{ steps.download-cache.outputs.cache-key }} - - - name: Upload test stats to Datadog - timeout-minutes: 1 - continue-on-error: true - uses: ./.github/actions/upload-datadog - if: success() || failure() - with: - api-key: ${{ secrets.DATADOG_API_KEY }} - test-go-pg: # make sure to adjust NUM_PARALLEL_PACKAGES and NUM_PARALLEL_TESTS below # when changing runner sizes @@ -567,7 +479,7 @@ jobs: # We rerun failing tests to counteract flakiness coming from Postgres # choking on macOS and Windows sometimes. - DB=ci gotestsum --rerun-fails=2 --rerun-fails-max-failures=50 \ + gotestsum --rerun-fails=2 --rerun-fails-max-failures=50 \ --format standard-quiet --packages "./..." \ -- -timeout=20m -v -p $NUM_PARALLEL_PACKAGES -parallel=$NUM_PARALLEL_TESTS $TESTCOUNT @@ -655,55 +567,6 @@ jobs: with: api-key: ${{ secrets.DATADOG_API_KEY }} - test-go-race: - runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-16' || 'ubuntu-latest' }} - needs: changes - if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main' - timeout-minutes: 25 - steps: - - name: Harden Runner - uses: step-security/harden-runner@6c439dc8bdf85cadbbce9ed30d1c7b959517bc49 # v2.12.2 - with: - egress-policy: audit - - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 1 - - - name: Setup Go - uses: ./.github/actions/setup-go - - - name: Setup Terraform - uses: ./.github/actions/setup-tf - - - name: Download Test Cache - id: download-cache - uses: ./.github/actions/test-cache/download - with: - key-prefix: test-go-race-${{ runner.os }}-${{ runner.arch }} - - # We run race tests with reduced parallelism because they use more CPU and we were finding - # instances where tests appear to hang for multiple seconds, resulting in flaky tests when - # short timeouts are used. - # c.f. discussion on https://github.com/coder/coder/pull/15106 - - name: Run Tests - run: | - gotestsum --junitfile="gotests.xml" --packages="./..." --rerun-fails=2 --rerun-fails-abort-on-data-race -- -race -parallel 4 -p 4 - - - name: Upload Test Cache - uses: ./.github/actions/test-cache/upload - with: - cache-key: ${{ steps.download-cache.outputs.cache-key }} - - - name: Upload test stats to Datadog - timeout-minutes: 1 - continue-on-error: true - uses: ./.github/actions/upload-datadog - if: always() - with: - api-key: ${{ secrets.DATADOG_API_KEY }} - test-go-race-pg: runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-16' || 'ubuntu-latest' }} needs: changes @@ -741,7 +604,7 @@ jobs: POSTGRES_VERSION: "17" run: | make test-postgres-docker - DB=ci gotestsum --junitfile="gotests.xml" --packages="./..." --rerun-fails=2 --rerun-fails-abort-on-data-race -- -race -parallel 4 -p 4 + gotestsum --junitfile="gotests.xml" --packages="./..." --rerun-fails=2 --rerun-fails-abort-on-data-race -- -race -parallel 4 -p 4 - name: Upload Test Cache uses: ./.github/actions/test-cache/upload @@ -1037,9 +900,7 @@ jobs: - fmt - lint - gen - - test-go - test-go-pg - - test-go-race - test-go-race-pg - test-js - test-e2e @@ -1060,9 +921,7 @@ jobs: echo "- fmt: ${{ needs.fmt.result }}" echo "- lint: ${{ needs.lint.result }}" echo "- gen: ${{ needs.gen.result }}" - echo "- test-go: ${{ needs.test-go.result }}" echo "- test-go-pg: ${{ needs.test-go-pg.result }}" - echo "- test-go-race: ${{ needs.test-go-race.result }}" echo "- test-go-race-pg: ${{ needs.test-go-race-pg.result }}" echo "- test-js: ${{ needs.test-js.result }}" echo "- test-e2e: ${{ needs.test-e2e.result }}" diff --git a/.golangci.yaml b/.golangci.yaml index 2e1e853a0425a..aeebaf47e29a6 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -181,7 +181,6 @@ linters-settings: issues: exclude-dirs: - - coderd/database/dbmem - node_modules - .git diff --git a/CLAUDE.md b/CLAUDE.md index 4df2514a45863..d5335a6d4d0b3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -44,7 +44,6 @@ 2. Run `make gen` 3. If audit errors: update `enterprise/audit/table.go` 4. Run `make gen` again -5. Update `coderd/database/dbmem/dbmem.go` in-memory implementations ### LSP Navigation (USE FIRST) @@ -116,9 +115,8 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID) 1. **Audit table errors** → Update `enterprise/audit/table.go` 2. **OAuth2 errors** → Return RFC-compliant format -3. **dbmem failures** → Update in-memory implementations -4. **Race conditions** → Use unique test identifiers -5. **Missing newlines** → Ensure files end with newline +3. **Race conditions** → Use unique test identifiers +4. **Missing newlines** → Ensure files end with newline --- diff --git a/Makefile b/Makefile index 0ed464ba23a80..ebfb286a482cc 100644 --- a/Makefile +++ b/Makefile @@ -599,7 +599,6 @@ DB_GEN_FILES := \ coderd/database/dump.sql \ coderd/database/querier.go \ coderd/database/unique_constraint.go \ - coderd/database/dbmem/dbmem.go \ coderd/database/dbmetrics/dbmetrics.go \ coderd/database/dbauthz/dbauthz.go \ coderd/database/dbmock/dbmock.go @@ -973,7 +972,7 @@ sqlc-vet: test-postgres-docker test-postgres: test-postgres-docker # The postgres test is prone to failure, so we limit parallelism for # more consistent execution. - $(GIT_FLAGS) DB=ci gotestsum \ + $(GIT_FLAGS) gotestsum \ --junitfile="gotests.xml" \ --jsonfile="gotests.json" \ $(GOTESTSUM_RETRY_FLAGS) \ diff --git a/cli/server.go b/cli/server.go index 5074bffc3a342..602f05d028b66 100644 --- a/cli/server.go +++ b/cli/server.go @@ -77,7 +77,6 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/awsiamrds" "github.com/coder/coder/v2/coderd/database/dbauthz" - "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/database/dbmetrics" "github.com/coder/coder/v2/coderd/database/dbpurge" "github.com/coder/coder/v2/coderd/database/migrations" @@ -423,7 +422,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. builtinPostgres := false // Only use built-in if PostgreSQL URL isn't specified! - if !vals.InMemoryDatabase && vals.PostgresURL == "" { + if vals.PostgresURL == "" { var closeFunc func() error cliui.Infof(inv.Stdout, "Using built-in PostgreSQL (%s)", config.PostgresPath()) customPostgresCacheDir := "" @@ -726,42 +725,37 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. // nil, that case of the select will just never fire, but it's important not to have a // "bare" read on this channel. var pubsubWatchdogTimeout <-chan struct{} - if vals.InMemoryDatabase { - // This is only used for testing. - options.Database = dbmem.New() - options.Pubsub = pubsub.NewInMemory() - } else { - sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver) - if err != nil { - return xerrors.Errorf("connect to postgres: %w", err) - } - defer func() { - _ = sqlDB.Close() - }() - if options.DeploymentValues.Prometheus.Enable { - // At this stage we don't think the database name serves much purpose in these metrics. - // It requires parsing the DSN to determine it, which requires pulling in another dependency - // (i.e. https://github.com/jackc/pgx), but it's rather heavy. - // The conn string (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) can - // take different forms, which make parsing non-trivial. - options.PrometheusRegistry.MustRegister(collectors.NewDBStatsCollector(sqlDB, "")) - } + sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver) + if err != nil { + return xerrors.Errorf("connect to postgres: %w", err) + } + defer func() { + _ = sqlDB.Close() + }() - options.Database = database.New(sqlDB) - ps, err := pubsub.New(ctx, logger.Named("pubsub"), sqlDB, dbURL) - if err != nil { - return xerrors.Errorf("create pubsub: %w", err) - } - options.Pubsub = ps - if options.DeploymentValues.Prometheus.Enable { - options.PrometheusRegistry.MustRegister(ps) - } - defer options.Pubsub.Close() - psWatchdog := pubsub.NewWatchdog(ctx, logger.Named("pswatch"), ps) - pubsubWatchdogTimeout = psWatchdog.Timeout() - defer psWatchdog.Close() + if options.DeploymentValues.Prometheus.Enable { + // At this stage we don't think the database name serves much purpose in these metrics. + // It requires parsing the DSN to determine it, which requires pulling in another dependency + // (i.e. https://github.com/jackc/pgx), but it's rather heavy. + // The conn string (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) can + // take different forms, which make parsing non-trivial. + options.PrometheusRegistry.MustRegister(collectors.NewDBStatsCollector(sqlDB, "")) + } + + options.Database = database.New(sqlDB) + ps, err := pubsub.New(ctx, logger.Named("pubsub"), sqlDB, dbURL) + if err != nil { + return xerrors.Errorf("create pubsub: %w", err) + } + options.Pubsub = ps + if options.DeploymentValues.Prometheus.Enable { + options.PrometheusRegistry.MustRegister(ps) } + defer options.Pubsub.Close() + psWatchdog := pubsub.NewWatchdog(ctx, logger.Named("pswatch"), ps) + pubsubWatchdogTimeout = psWatchdog.Timeout() + defer psWatchdog.Close() if options.DeploymentValues.Prometheus.Enable && options.DeploymentValues.Prometheus.CollectDBMetrics { options.Database = dbmetrics.NewQueryMetrics(options.Database, options.Logger, options.PrometheusRegistry) diff --git a/cli/server_test.go b/cli/server_test.go index 2d0bbdd24e83b..435ed2879c9a3 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -59,9 +59,6 @@ import ( ) func dbArg(t *testing.T) string { - if !dbtestutil.WillUsePostgres() { - return "--in-memory" - } dbURL, err := dbtestutil.Open(t) require.NoError(t, err) return "--postgres-url=" + dbURL diff --git a/coderd/database/dbauthz/customroles_test.go b/coderd/database/dbauthz/customroles_test.go index 5e19f43ab5376..54541d4670c2c 100644 --- a/coderd/database/dbauthz/customroles_test.go +++ b/coderd/database/dbauthz/customroles_test.go @@ -12,7 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/codersdk" @@ -202,7 +202,7 @@ func TestInsertCustomRoles(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rec := &coderdtest.RecordingAuthorizer{ Wrapped: rbac.NewAuthorizer(prometheus.NewRegistry()), } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index bcbe6bb881dde..84c2067848ec4 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -19,7 +19,6 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database/db2sdk" - "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/codersdk" @@ -3661,148 +3660,119 @@ func (s *MethodTestSuite) TestExtraMethods() { func (s *MethodTestSuite) TestTailnetFunctions() { s.Run("CleanTailnetCoordinators", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("CleanTailnetLostPeers", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("CleanTailnetTunnels", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("DeleteAllTailnetClientSubscriptions", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteAllTailnetClientSubscriptionsParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("DeleteAllTailnetTunnels", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteAllTailnetTunnelsParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("DeleteCoordinator", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("DeleteTailnetAgent", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteTailnetAgentParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate).Errors(sql.ErrNoRows). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate).Errors(sql.ErrNoRows) })) s.Run("DeleteTailnetClient", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteTailnetClientParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete).Errors(sql.ErrNoRows). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete).Errors(sql.ErrNoRows) })) s.Run("DeleteTailnetClientSubscription", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteTailnetClientSubscriptionParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("DeleteTailnetPeer", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteTailnetPeerParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented). - ErrorsWithPG(sql.ErrNoRows) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete).Errors(sql.ErrNoRows) })) s.Run("DeleteTailnetTunnel", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteTailnetTunnelParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented). - ErrorsWithPG(sql.ErrNoRows) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete).Errors(sql.ErrNoRows) })) s.Run("GetAllTailnetAgents", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetTailnetAgents", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetTailnetClientsForAgent", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetTailnetPeers", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetTailnetTunnelPeerBindings", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetTailnetTunnelPeerIDs", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetAllTailnetCoordinators", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetAllTailnetPeers", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetAllTailnetTunnels", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("UpsertTailnetAgent", s.Subtest(func(db database.Store, check *expects) { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) check.Args(database.UpsertTailnetAgentParams{Node: json.RawMessage("{}")}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate) })) s.Run("UpsertTailnetClient", s.Subtest(func(db database.Store, check *expects) { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) check.Args(database.UpsertTailnetClientParams{Node: json.RawMessage("{}")}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate) })) s.Run("UpsertTailnetClientSubscription", s.Subtest(func(db database.Store, check *expects) { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) check.Args(database.UpsertTailnetClientSubscriptionParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate) })) s.Run("UpsertTailnetCoordinator", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate) })) s.Run("UpsertTailnetPeer", s.Subtest(func(db database.Store, check *expects) { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) check.Args(database.UpsertTailnetPeerParams{ Status: database.TailnetStatusOk, }). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionCreate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionCreate) })) s.Run("UpsertTailnetTunnel", s.Subtest(func(db database.Store, check *expects) { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) check.Args(database.UpsertTailnetTunnelParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionCreate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionCreate) })) s.Run("UpdateTailnetPeerStatusByCoordinator", s.Subtest(func(db database.Store, check *expects) { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) check.Args(database.UpdateTailnetPeerStatusByCoordinatorParams{Status: database.TailnetStatusOk}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate) })) } @@ -4787,21 +4757,18 @@ func (s *MethodTestSuite) TestNotifications() { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) user := dbgen.User(s.T(), db, database.User{}) check.Args(user.ID).Asserts(rbac.ResourceNotificationTemplate, policy.ActionRead). - ErrorsWithPG(sql.ErrNoRows). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + ErrorsWithPG(sql.ErrNoRows) })) s.Run("GetNotificationTemplatesByKind", s.Subtest(func(db database.Store, check *expects) { check.Args(database.NotificationTemplateKindSystem). - Asserts(). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts() // TODO(dannyk): add support for other database.NotificationTemplateKind types once implemented. })) s.Run("UpdateNotificationTemplateMethodByID", s.Subtest(func(db database.Store, check *expects) { check.Args(database.UpdateNotificationTemplateMethodByIDParams{ Method: database.NullNotificationMethod{NotificationMethod: database.NotificationMethodWebhook, Valid: true}, ID: notifications.TemplateWorkspaceDormant, - }).Asserts(rbac.ResourceNotificationTemplate, policy.ActionUpdate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + }).Asserts(rbac.ResourceNotificationTemplate, policy.ActionUpdate) })) // Notification preferences @@ -5115,8 +5082,7 @@ func (s *MethodTestSuite) TestPrebuilds() { rbac.ResourceWorkspace.WithOwner(user.ID.String()).InOrg(org.ID), policy.ActionCreate, template, policy.ActionRead, template, policy.ActionUse, - ).ErrorsWithInMemDB(dbmem.ErrUnimplemented). - ErrorsWithPG(sql.ErrNoRows) + ).Errors(sql.ErrNoRows) })) s.Run("GetPrebuildMetrics", s.Subtest(func(_ database.Store, check *expects) { check.Args(). @@ -5130,29 +5096,24 @@ func (s *MethodTestSuite) TestPrebuilds() { })) s.Run("CountInProgressPrebuilds", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead) })) s.Run("GetPresetsAtFailureLimit", s.Subtest(func(_ database.Store, check *expects) { check.Args(int64(0)). - Asserts(rbac.ResourceTemplate.All(), policy.ActionViewInsights). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTemplate.All(), policy.ActionViewInsights) })) s.Run("GetPresetsBackoff", s.Subtest(func(_ database.Store, check *expects) { check.Args(time.Time{}). - Asserts(rbac.ResourceTemplate.All(), policy.ActionViewInsights). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTemplate.All(), policy.ActionViewInsights) })) s.Run("GetRunningPrebuiltWorkspaces", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead) })) s.Run("GetTemplatePresetsWithPrebuilds", s.Subtest(func(db database.Store, check *expects) { user := dbgen.User(s.T(), db, database.User{}) check.Args(uuid.NullUUID{UUID: user.ID, Valid: true}). - Asserts(rbac.ResourceTemplate.All(), policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTemplate.All(), policy.ActionRead) })) s.Run("GetPresetByID", s.Subtest(func(db database.Store, check *expects) { org := dbgen.Organization(s.T(), db, database.Organization{}) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go deleted file mode 100644 index d106e6a5858fb..0000000000000 --- a/coderd/database/dbmem/dbmem.go +++ /dev/null @@ -1,14242 +0,0 @@ -package dbmem - -import ( - "bytes" - "context" - "database/sql" - "encoding/json" - "errors" - "fmt" - "math" - insecurerand "math/rand" //#nosec // this is only used for shuffling an array to pick random jobs to reap - "reflect" - "regexp" - "slices" - "sort" - "strings" - "sync" - "time" - - "github.com/google/uuid" - "github.com/lib/pq" - "golang.org/x/exp/constraints" - "golang.org/x/exp/maps" - "golang.org/x/xerrors" - - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbtime" - "github.com/coder/coder/v2/coderd/notifications/types" - "github.com/coder/coder/v2/coderd/rbac" - "github.com/coder/coder/v2/coderd/rbac/regosql" - "github.com/coder/coder/v2/coderd/util/slice" - "github.com/coder/coder/v2/coderd/workspaceapps/appurl" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/provisionersdk" -) - -var validProxyByHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) - -// A full mapping of error codes from pq v1.10.9 can be found here: -// https://github.com/lib/pq/blob/2a217b94f5ccd3de31aec4152a541b9ff64bed05/error.go#L75 -var ( - errForeignKeyConstraint = &pq.Error{ - Code: "23503", // "foreign_key_violation" - Message: "update or delete on table violates foreign key constraint", - } - errUniqueConstraint = &pq.Error{ - Code: "23505", // "unique_violation" - Message: "duplicate key value violates unique constraint", - } -) - -// New returns an in-memory fake of the database. -func New() database.Store { - q := &FakeQuerier{ - mutex: &sync.RWMutex{}, - data: &data{ - apiKeys: make([]database.APIKey, 0), - auditLogs: make([]database.AuditLog, 0), - customRoles: make([]database.CustomRole, 0), - dbcryptKeys: make([]database.DBCryptKey, 0), - externalAuthLinks: make([]database.ExternalAuthLink, 0), - files: make([]database.File, 0), - gitSSHKey: make([]database.GitSSHKey, 0), - groups: make([]database.Group, 0), - groupMembers: make([]database.GroupMemberTable, 0), - licenses: make([]database.License, 0), - locks: map[int64]struct{}{}, - notificationMessages: make([]database.NotificationMessage, 0), - notificationPreferences: make([]database.NotificationPreference, 0), - organizationMembers: make([]database.OrganizationMember, 0), - organizations: make([]database.Organization, 0), - inboxNotifications: make([]database.InboxNotification, 0), - parameterSchemas: make([]database.ParameterSchema, 0), - presets: make([]database.TemplateVersionPreset, 0), - presetParameters: make([]database.TemplateVersionPresetParameter, 0), - presetPrebuildSchedules: make([]database.TemplateVersionPresetPrebuildSchedule, 0), - provisionerDaemons: make([]database.ProvisionerDaemon, 0), - provisionerJobs: make([]database.ProvisionerJob, 0), - provisionerJobLogs: make([]database.ProvisionerJobLog, 0), - provisionerKeys: make([]database.ProvisionerKey, 0), - runtimeConfig: map[string]string{}, - telemetryItems: make([]database.TelemetryItem, 0), - templateVersions: make([]database.TemplateVersionTable, 0), - templateVersionTerraformValues: make([]database.TemplateVersionTerraformValue, 0), - templates: make([]database.TemplateTable, 0), - users: make([]database.User, 0), - userConfigs: make([]database.UserConfig, 0), - userStatusChanges: make([]database.UserStatusChange, 0), - workspaceAgents: make([]database.WorkspaceAgent, 0), - workspaceResources: make([]database.WorkspaceResource, 0), - workspaceModules: make([]database.WorkspaceModule, 0), - workspaceResourceMetadata: make([]database.WorkspaceResourceMetadatum, 0), - workspaceAgentStats: make([]database.WorkspaceAgentStat, 0), - workspaceAgentLogs: make([]database.WorkspaceAgentLog, 0), - workspaceBuilds: make([]database.WorkspaceBuild, 0), - workspaceApps: make([]database.WorkspaceApp, 0), - workspaceAppAuditSessions: make([]database.WorkspaceAppAuditSession, 0), - workspaces: make([]database.WorkspaceTable, 0), - workspaceProxies: make([]database.WorkspaceProxy, 0), - }, - } - // Always start with a default org. Matching migration 198. - defaultOrg, err := q.InsertOrganization(context.Background(), database.InsertOrganizationParams{ - ID: uuid.New(), - Name: "coder", - DisplayName: "Coder", - Description: "Builtin default organization.", - Icon: "", - CreatedAt: dbtime.Now(), - UpdatedAt: dbtime.Now(), - }) - if err != nil { - panic(xerrors.Errorf("failed to create default organization: %w", err)) - } - - _, err = q.InsertAllUsersGroup(context.Background(), defaultOrg.ID) - if err != nil { - panic(xerrors.Errorf("failed to create default group: %w", err)) - } - - q.defaultProxyDisplayName = "Default" - q.defaultProxyIconURL = "/emojis/1f3e1.png" - - _, err = q.InsertProvisionerKey(context.Background(), database.InsertProvisionerKeyParams{ - ID: codersdk.ProvisionerKeyUUIDBuiltIn, - OrganizationID: defaultOrg.ID, - CreatedAt: dbtime.Now(), - HashedSecret: []byte{}, - Name: codersdk.ProvisionerKeyNameBuiltIn, - Tags: map[string]string{}, - }) - if err != nil { - panic(xerrors.Errorf("failed to create built-in provisioner key: %w", err)) - } - _, err = q.InsertProvisionerKey(context.Background(), database.InsertProvisionerKeyParams{ - ID: codersdk.ProvisionerKeyUUIDUserAuth, - OrganizationID: defaultOrg.ID, - CreatedAt: dbtime.Now(), - HashedSecret: []byte{}, - Name: codersdk.ProvisionerKeyNameUserAuth, - Tags: map[string]string{}, - }) - if err != nil { - panic(xerrors.Errorf("failed to create user-auth provisioner key: %w", err)) - } - _, err = q.InsertProvisionerKey(context.Background(), database.InsertProvisionerKeyParams{ - ID: codersdk.ProvisionerKeyUUIDPSK, - OrganizationID: defaultOrg.ID, - CreatedAt: dbtime.Now(), - HashedSecret: []byte{}, - Name: codersdk.ProvisionerKeyNamePSK, - Tags: map[string]string{}, - }) - if err != nil { - panic(xerrors.Errorf("failed to create psk provisioner key: %w", err)) - } - - q.mutex.Lock() - // We can't insert this user using the interface, because it's a system user. - q.data.users = append(q.data.users, database.User{ - ID: database.PrebuildsSystemUserID, - Email: "prebuilds@coder.com", - Username: "prebuilds", - CreatedAt: dbtime.Now(), - UpdatedAt: dbtime.Now(), - Status: "active", - LoginType: "none", - HashedPassword: []byte{}, - IsSystem: true, - Deleted: false, - }) - q.mutex.Unlock() - - return q -} - -type rwMutex interface { - Lock() - RLock() - Unlock() - RUnlock() -} - -// inTxMutex is a no op, since inside a transaction we are already locked. -type inTxMutex struct{} - -func (inTxMutex) Lock() {} -func (inTxMutex) RLock() {} -func (inTxMutex) Unlock() {} -func (inTxMutex) RUnlock() {} - -// FakeQuerier replicates database functionality to enable quick testing. It's an exported type so that our test code -// can do type checks. -type FakeQuerier struct { - mutex rwMutex - *data -} - -func (*FakeQuerier) Wrappers() []string { - return []string{} -} - -type fakeTx struct { - *FakeQuerier - locks map[int64]struct{} -} - -type data struct { - // Legacy tables - apiKeys []database.APIKey - organizations []database.Organization - organizationMembers []database.OrganizationMember - users []database.User - userLinks []database.UserLink - - // New tables - auditLogs []database.AuditLog - cryptoKeys []database.CryptoKey - dbcryptKeys []database.DBCryptKey - files []database.File - externalAuthLinks []database.ExternalAuthLink - gitSSHKey []database.GitSSHKey - groupMembers []database.GroupMemberTable - groups []database.Group - licenses []database.License - notificationMessages []database.NotificationMessage - notificationPreferences []database.NotificationPreference - notificationReportGeneratorLogs []database.NotificationReportGeneratorLog - inboxNotifications []database.InboxNotification - oauth2ProviderApps []database.OAuth2ProviderApp - oauth2ProviderAppSecrets []database.OAuth2ProviderAppSecret - oauth2ProviderAppCodes []database.OAuth2ProviderAppCode - oauth2ProviderAppTokens []database.OAuth2ProviderAppToken - parameterSchemas []database.ParameterSchema - provisionerDaemons []database.ProvisionerDaemon - provisionerJobLogs []database.ProvisionerJobLog - provisionerJobs []database.ProvisionerJob - provisionerKeys []database.ProvisionerKey - replicas []database.Replica - templateVersions []database.TemplateVersionTable - templateVersionParameters []database.TemplateVersionParameter - templateVersionTerraformValues []database.TemplateVersionTerraformValue - templateVersionVariables []database.TemplateVersionVariable - templateVersionWorkspaceTags []database.TemplateVersionWorkspaceTag - templates []database.TemplateTable - templateUsageStats []database.TemplateUsageStat - userConfigs []database.UserConfig - webpushSubscriptions []database.WebpushSubscription - workspaceAgents []database.WorkspaceAgent - workspaceAgentMetadata []database.WorkspaceAgentMetadatum - workspaceAgentLogs []database.WorkspaceAgentLog - workspaceAgentLogSources []database.WorkspaceAgentLogSource - workspaceAgentPortShares []database.WorkspaceAgentPortShare - workspaceAgentScriptTimings []database.WorkspaceAgentScriptTiming - workspaceAgentScripts []database.WorkspaceAgentScript - workspaceAgentStats []database.WorkspaceAgentStat - workspaceAgentMemoryResourceMonitors []database.WorkspaceAgentMemoryResourceMonitor - workspaceAgentVolumeResourceMonitors []database.WorkspaceAgentVolumeResourceMonitor - workspaceAgentDevcontainers []database.WorkspaceAgentDevcontainer - workspaceApps []database.WorkspaceApp - workspaceAppStatuses []database.WorkspaceAppStatus - workspaceAppAuditSessions []database.WorkspaceAppAuditSession - workspaceAppStatsLastInsertID int64 - workspaceAppStats []database.WorkspaceAppStat - workspaceBuilds []database.WorkspaceBuild - workspaceBuildParameters []database.WorkspaceBuildParameter - workspaceResourceMetadata []database.WorkspaceResourceMetadatum - workspaceResources []database.WorkspaceResource - workspaceModules []database.WorkspaceModule - workspaces []database.WorkspaceTable - workspaceProxies []database.WorkspaceProxy - customRoles []database.CustomRole - provisionerJobTimings []database.ProvisionerJobTiming - runtimeConfig map[string]string - // Locks is a map of lock names. Any keys within the map are currently - // locked. - locks map[int64]struct{} - deploymentID string - derpMeshKey string - lastUpdateCheck []byte - announcementBanners []byte - healthSettings []byte - notificationsSettings []byte - oauth2GithubDefaultEligible *bool - applicationName string - logoURL string - appSecurityKey string - oauthSigningKey string - coordinatorResumeTokenSigningKey string - lastLicenseID int32 - defaultProxyDisplayName string - defaultProxyIconURL string - webpushVAPIDPublicKey string - webpushVAPIDPrivateKey string - userStatusChanges []database.UserStatusChange - telemetryItems []database.TelemetryItem - presets []database.TemplateVersionPreset - presetParameters []database.TemplateVersionPresetParameter - presetPrebuildSchedules []database.TemplateVersionPresetPrebuildSchedule - prebuildsSettings []byte -} - -func tryPercentileCont(fs []float64, p float64) float64 { - if len(fs) == 0 { - return -1 - } - sort.Float64s(fs) - pos := p * (float64(len(fs)) - 1) / 100 - lower, upper := int(pos), int(math.Ceil(pos)) - if lower == upper { - return fs[lower] - } - return fs[lower] + (fs[upper]-fs[lower])*(pos-float64(lower)) -} - -func tryPercentileDisc(fs []float64, p float64) float64 { - if len(fs) == 0 { - return -1 - } - sort.Float64s(fs) - return fs[max(int(math.Ceil(float64(len(fs))*p/100-1)), 0)] -} - -func validateDatabaseTypeWithValid(v reflect.Value) (handled bool, err error) { - if v.Kind() == reflect.Struct { - return false, nil - } - - if v.CanInterface() { - if !strings.Contains(v.Type().PkgPath(), "coderd/database") { - return true, nil - } - if valid, ok := v.Interface().(interface{ Valid() bool }); ok { - if !valid.Valid() { - return true, xerrors.Errorf("invalid %s: %q", v.Type().Name(), v.Interface()) - } - } - return true, nil - } - return false, nil -} - -// validateDatabaseType uses reflect to check if struct properties are types -// with a Valid() bool function set. If so, call it and return an error -// if false. -// -// Note that we only check immediate values and struct fields. We do not -// recurse into nested structs. -func validateDatabaseType(args interface{}) error { - v := reflect.ValueOf(args) - - // Note: database.Null* types don't have a Valid method, we skip them here - // because their embedded types may have a Valid method and we don't want - // to bother with checking both that the Valid field is true and that the - // type it embeds validates to true. We would need to check: - // - // dbNullEnum.Valid && dbNullEnum.Enum.Valid() - if strings.HasPrefix(v.Type().Name(), "Null") { - return nil - } - - if ok, err := validateDatabaseTypeWithValid(v); ok { - return err - } - switch v.Kind() { - case reflect.Struct: - var errs []string - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - if ok, err := validateDatabaseTypeWithValid(field); ok && err != nil { - errs = append(errs, fmt.Sprintf("%s.%s: %s", v.Type().Name(), v.Type().Field(i).Name, err.Error())) - } - } - if len(errs) > 0 { - return xerrors.Errorf("invalid database type fields:\n\t%s", strings.Join(errs, "\n\t")) - } - default: - panic(fmt.Sprintf("unhandled type: %s", v.Type().Name())) - } - return nil -} - -func newUniqueConstraintError(uc database.UniqueConstraint) *pq.Error { - newErr := *errUniqueConstraint - newErr.Constraint = string(uc) - - return &newErr -} - -func (*FakeQuerier) Ping(_ context.Context) (time.Duration, error) { - return 0, nil -} - -func (*FakeQuerier) PGLocks(_ context.Context) (database.PGLocks, error) { - return []database.PGLock{}, nil -} - -func (tx *fakeTx) AcquireLock(_ context.Context, id int64) error { - if _, ok := tx.FakeQuerier.locks[id]; ok { - return xerrors.Errorf("cannot acquire lock %d: already held", id) - } - tx.FakeQuerier.locks[id] = struct{}{} - tx.locks[id] = struct{}{} - return nil -} - -func (tx *fakeTx) TryAcquireLock(_ context.Context, id int64) (bool, error) { - if _, ok := tx.FakeQuerier.locks[id]; ok { - return false, nil - } - tx.FakeQuerier.locks[id] = struct{}{} - tx.locks[id] = struct{}{} - return true, nil -} - -func (tx *fakeTx) releaseLocks() { - for id := range tx.locks { - delete(tx.FakeQuerier.locks, id) - } - tx.locks = map[int64]struct{}{} -} - -// InTx doesn't rollback data properly for in-memory yet. -func (q *FakeQuerier) InTx(fn func(database.Store) error, opts *database.TxOptions) error { - q.mutex.Lock() - defer q.mutex.Unlock() - tx := &fakeTx{ - FakeQuerier: &FakeQuerier{mutex: inTxMutex{}, data: q.data}, - locks: map[int64]struct{}{}, - } - defer tx.releaseLocks() - - if opts != nil { - database.IncrementExecutionCount(opts) - } - return fn(tx) -} - -// getUserByIDNoLock is used by other functions in the database fake. -func (q *FakeQuerier) getUserByIDNoLock(id uuid.UUID) (database.User, error) { - for _, user := range q.users { - if user.ID == id { - return user, nil - } - } - return database.User{}, sql.ErrNoRows -} - -func convertUsers(users []database.User, count int64) []database.GetUsersRow { - rows := make([]database.GetUsersRow, len(users)) - for i, u := range users { - rows[i] = database.GetUsersRow{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Name: u.Name, - HashedPassword: u.HashedPassword, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, - Status: u.Status, - RBACRoles: u.RBACRoles, - LoginType: u.LoginType, - AvatarURL: u.AvatarURL, - Deleted: u.Deleted, - LastSeenAt: u.LastSeenAt, - Count: count, - IsSystem: u.IsSystem, - } - } - - return rows -} - -// mapAgentStatus determines the agent status based on different timestamps like created_at, last_connected_at, disconnected_at, etc. -// The function must be in sync with: coderd/workspaceagents.go:convertWorkspaceAgent. -func mapAgentStatus(dbAgent database.WorkspaceAgent, agentInactiveDisconnectTimeoutSeconds int64) string { - var status string - connectionTimeout := time.Duration(dbAgent.ConnectionTimeoutSeconds) * time.Second - switch { - case !dbAgent.FirstConnectedAt.Valid: - switch { - case connectionTimeout > 0 && dbtime.Now().Sub(dbAgent.CreatedAt) > connectionTimeout: - // If the agent took too long to connect the first time, - // mark it as timed out. - status = "timeout" - default: - // If the agent never connected, it's waiting for the compute - // to start up. - status = "connecting" - } - case dbAgent.DisconnectedAt.Time.After(dbAgent.LastConnectedAt.Time): - // If we've disconnected after our last connection, we know the - // agent is no longer connected. - status = "disconnected" - case dbtime.Now().Sub(dbAgent.LastConnectedAt.Time) > time.Duration(agentInactiveDisconnectTimeoutSeconds)*time.Second: - // The connection died without updating the last connected. - status = "disconnected" - case dbAgent.LastConnectedAt.Valid: - // The agent should be assumed connected if it's under inactivity timeouts - // and last connected at has been properly set. - status = "connected" - default: - panic("unknown agent status: " + status) - } - return status -} - -func (q *FakeQuerier) convertToWorkspaceRowsNoLock(ctx context.Context, workspaces []database.WorkspaceTable, count int64, withSummary bool) []database.GetWorkspacesRow { //nolint:revive // withSummary flag ensures the extra technical row - rows := make([]database.GetWorkspacesRow, 0, len(workspaces)) - for _, w := range workspaces { - extended := q.extendWorkspace(w) - - wr := database.GetWorkspacesRow{ - ID: w.ID, - CreatedAt: w.CreatedAt, - UpdatedAt: w.UpdatedAt, - OwnerID: w.OwnerID, - OrganizationID: w.OrganizationID, - TemplateID: w.TemplateID, - Deleted: w.Deleted, - Name: w.Name, - AutostartSchedule: w.AutostartSchedule, - Ttl: w.Ttl, - LastUsedAt: w.LastUsedAt, - DormantAt: w.DormantAt, - DeletingAt: w.DeletingAt, - AutomaticUpdates: w.AutomaticUpdates, - Favorite: w.Favorite, - NextStartAt: w.NextStartAt, - - OwnerAvatarUrl: extended.OwnerAvatarUrl, - OwnerUsername: extended.OwnerUsername, - OwnerName: extended.OwnerName, - - OrganizationName: extended.OrganizationName, - OrganizationDisplayName: extended.OrganizationDisplayName, - OrganizationIcon: extended.OrganizationIcon, - OrganizationDescription: extended.OrganizationDescription, - - TemplateName: extended.TemplateName, - TemplateDisplayName: extended.TemplateDisplayName, - TemplateIcon: extended.TemplateIcon, - TemplateDescription: extended.TemplateDescription, - - Count: count, - - // These fields are missing! - // Try to resolve them below - TemplateVersionID: uuid.UUID{}, - TemplateVersionName: sql.NullString{}, - LatestBuildCompletedAt: sql.NullTime{}, - LatestBuildCanceledAt: sql.NullTime{}, - LatestBuildError: sql.NullString{}, - LatestBuildTransition: "", - LatestBuildStatus: "", - } - - if build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID); err == nil { - for _, tv := range q.templateVersions { - if tv.ID == build.TemplateVersionID { - wr.TemplateVersionID = tv.ID - wr.TemplateVersionName = sql.NullString{ - Valid: true, - String: tv.Name, - } - break - } - } - - if pj, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID); err == nil { - wr.LatestBuildStatus = pj.JobStatus - wr.LatestBuildCanceledAt = pj.CanceledAt - wr.LatestBuildCompletedAt = pj.CompletedAt - wr.LatestBuildError = pj.Error - } - - wr.LatestBuildTransition = build.Transition - } - - rows = append(rows, wr) - } - if withSummary { - rows = append(rows, database.GetWorkspacesRow{ - Name: "**TECHNICAL_ROW**", - Count: count, - }) - } - return rows -} - -func (q *FakeQuerier) getWorkspaceByIDNoLock(_ context.Context, id uuid.UUID) (database.Workspace, error) { - return q.getWorkspaceNoLock(func(w database.WorkspaceTable) bool { - return w.ID == id - }) -} - -func (q *FakeQuerier) getWorkspaceNoLock(find func(w database.WorkspaceTable) bool) (database.Workspace, error) { - w, found := slice.Find(q.workspaces, find) - if found { - return q.extendWorkspace(w), nil - } - return database.Workspace{}, sql.ErrNoRows -} - -func (q *FakeQuerier) extendWorkspace(w database.WorkspaceTable) database.Workspace { - var extended database.Workspace - // This is a cheeky way to copy the fields over without explicitly listing them all. - d, _ := json.Marshal(w) - _ = json.Unmarshal(d, &extended) - - org, _ := slice.Find(q.organizations, func(o database.Organization) bool { - return o.ID == w.OrganizationID - }) - extended.OrganizationName = org.Name - extended.OrganizationDescription = org.Description - extended.OrganizationDisplayName = org.DisplayName - extended.OrganizationIcon = org.Icon - - tpl, _ := slice.Find(q.templates, func(t database.TemplateTable) bool { - return t.ID == w.TemplateID - }) - extended.TemplateName = tpl.Name - extended.TemplateDisplayName = tpl.DisplayName - extended.TemplateDescription = tpl.Description - extended.TemplateIcon = tpl.Icon - - owner, _ := slice.Find(q.users, func(u database.User) bool { - return u.ID == w.OwnerID - }) - extended.OwnerUsername = owner.Username - extended.OwnerName = owner.Name - extended.OwnerAvatarUrl = owner.AvatarURL - - return extended -} - -func (q *FakeQuerier) getWorkspaceByAgentIDNoLock(_ context.Context, agentID uuid.UUID) (database.Workspace, error) { - var agent database.WorkspaceAgent - for _, _agent := range q.workspaceAgents { - if _agent.ID == agentID { - agent = _agent - break - } - } - if agent.ID == uuid.Nil { - return database.Workspace{}, sql.ErrNoRows - } - - var resource database.WorkspaceResource - for _, _resource := range q.workspaceResources { - if _resource.ID == agent.ResourceID { - resource = _resource - break - } - } - if resource.ID == uuid.Nil { - return database.Workspace{}, sql.ErrNoRows - } - - var build database.WorkspaceBuild - for _, _build := range q.workspaceBuilds { - if _build.JobID == resource.JobID { - build = q.workspaceBuildWithUserNoLock(_build) - break - } - } - if build.ID == uuid.Nil { - return database.Workspace{}, sql.ErrNoRows - } - - return q.getWorkspaceNoLock(func(w database.WorkspaceTable) bool { - return w.ID == build.WorkspaceID - }) -} - -func (q *FakeQuerier) getWorkspaceBuildByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { - for _, build := range q.workspaceBuilds { - if build.ID == id { - return q.workspaceBuildWithUserNoLock(build), nil - } - } - return database.WorkspaceBuild{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getLatestWorkspaceBuildByWorkspaceIDNoLock(_ context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { - var row database.WorkspaceBuild - var buildNum int32 = -1 - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.WorkspaceID == workspaceID && workspaceBuild.BuildNumber > buildNum { - row = q.workspaceBuildWithUserNoLock(workspaceBuild) - buildNum = workspaceBuild.BuildNumber - } - } - if buildNum == -1 { - return database.WorkspaceBuild{}, sql.ErrNoRows - } - return row, nil -} - -func (q *FakeQuerier) getTemplateByIDNoLock(_ context.Context, id uuid.UUID) (database.Template, error) { - for _, template := range q.templates { - if template.ID == id { - return q.templateWithNameNoLock(template), nil - } - } - return database.Template{}, sql.ErrNoRows -} - -func (q *FakeQuerier) templatesWithUserNoLock(tpl []database.TemplateTable) []database.Template { - cpy := make([]database.Template, 0, len(tpl)) - for _, t := range tpl { - cpy = append(cpy, q.templateWithNameNoLock(t)) - } - return cpy -} - -func (q *FakeQuerier) templateWithNameNoLock(tpl database.TemplateTable) database.Template { - var user database.User - for _, _user := range q.users { - if _user.ID == tpl.CreatedBy { - user = _user - break - } - } - - var org database.Organization - for _, _org := range q.organizations { - if _org.ID == tpl.OrganizationID { - org = _org - break - } - } - - var withNames database.Template - // This is a cheeky way to copy the fields over without explicitly listing them all. - d, _ := json.Marshal(tpl) - _ = json.Unmarshal(d, &withNames) - withNames.CreatedByUsername = user.Username - withNames.CreatedByAvatarURL = user.AvatarURL - withNames.OrganizationName = org.Name - withNames.OrganizationDisplayName = org.DisplayName - withNames.OrganizationIcon = org.Icon - return withNames -} - -func (q *FakeQuerier) templateVersionWithUserNoLock(tpl database.TemplateVersionTable) database.TemplateVersion { - var user database.User - for _, _user := range q.users { - if _user.ID == tpl.CreatedBy { - user = _user - break - } - } - var withUser database.TemplateVersion - // This is a cheeky way to copy the fields over without explicitly listing them all. - d, _ := json.Marshal(tpl) - _ = json.Unmarshal(d, &withUser) - withUser.CreatedByUsername = user.Username - withUser.CreatedByAvatarURL = user.AvatarURL - return withUser -} - -func (q *FakeQuerier) workspaceBuildWithUserNoLock(tpl database.WorkspaceBuild) database.WorkspaceBuild { - var user database.User - for _, _user := range q.users { - if _user.ID == tpl.InitiatorID { - user = _user - break - } - } - var withUser database.WorkspaceBuild - // This is a cheeky way to copy the fields over without explicitly listing them all. - d, _ := json.Marshal(tpl) - _ = json.Unmarshal(d, &withUser) - withUser.InitiatorByUsername = user.Username - withUser.InitiatorByAvatarUrl = user.AvatarURL - return withUser -} - -func (q *FakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { - for _, templateVersion := range q.templateVersions { - if templateVersion.ID != templateVersionID { - continue - } - return q.templateVersionWithUserNoLock(templateVersion), nil - } - return database.TemplateVersion{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getWorkspaceAgentByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { - // The schema sorts this by created at, so we iterate the array backwards. - for i := len(q.workspaceAgents) - 1; i >= 0; i-- { - agent := q.workspaceAgents[i] - if !agent.Deleted && agent.ID == id { - return agent, nil - } - } - return database.WorkspaceAgent{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getWorkspaceAgentsByResourceIDsNoLock(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { - workspaceAgents := make([]database.WorkspaceAgent, 0) - for _, agent := range q.workspaceAgents { - if agent.Deleted { - continue - } - for _, resourceID := range resourceIDs { - if agent.ResourceID != resourceID { - continue - } - workspaceAgents = append(workspaceAgents, agent) - } - } - return workspaceAgents, nil -} - -func (q *FakeQuerier) getWorkspaceAppByAgentIDAndSlugNoLock(_ context.Context, arg database.GetWorkspaceAppByAgentIDAndSlugParams) (database.WorkspaceApp, error) { - for _, app := range q.workspaceApps { - if app.AgentID != arg.AgentID { - continue - } - if app.Slug != arg.Slug { - continue - } - return app, nil - } - return database.WorkspaceApp{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getProvisionerJobByIDNoLock(_ context.Context, id uuid.UUID) (database.ProvisionerJob, error) { - for _, provisionerJob := range q.provisionerJobs { - if provisionerJob.ID != id { - continue - } - // clone the Tags before returning, since maps are reference types and - // we don't want the caller to be able to mutate the map we have inside - // dbmem! - provisionerJob.Tags = maps.Clone(provisionerJob.Tags) - return provisionerJob, nil - } - return database.ProvisionerJob{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { - resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.workspaceResources { - if resource.JobID != jobID { - continue - } - resources = append(resources, resource) - } - return resources, nil -} - -func (q *FakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) { - for _, group := range q.groups { - if group.ID == id { - return group, nil - } - } - - return database.Group{}, sql.ErrNoRows -} - -// ErrUnimplemented is returned by methods only used by the enterprise/tailnet.pgCoord. This coordinator explicitly -// depends on postgres triggers that announce changes on the pubsub. Implementing support for this in the fake -// database would strongly couple the FakeQuerier to the pubsub, which is undesirable. Furthermore, it makes little -// sense to directly test the pgCoord against anything other than postgres. The FakeQuerier is designed to allow us to -// test the Coderd API, and for that kind of test, the in-memory, AGPL tailnet coordinator is sufficient. Therefore, -// these methods remain unimplemented in the FakeQuerier. -var ErrUnimplemented = xerrors.New("unimplemented") - -func uniqueSortedUUIDs(uuids []uuid.UUID) []uuid.UUID { - set := make(map[uuid.UUID]struct{}) - for _, id := range uuids { - set[id] = struct{}{} - } - unique := make([]uuid.UUID, 0, len(set)) - for id := range set { - unique = append(unique, id) - } - slices.SortFunc(unique, func(a, b uuid.UUID) int { - return slice.Ascending(a.String(), b.String()) - }) - return unique -} - -func (q *FakeQuerier) getOrganizationMemberNoLock(orgID uuid.UUID) []database.OrganizationMember { - var members []database.OrganizationMember - for _, member := range q.organizationMembers { - if member.OrganizationID == orgID { - members = append(members, member) - } - } - - return members -} - -var errUserDeleted = xerrors.New("user deleted") - -// getGroupMemberNoLock fetches a group member by user ID and group ID. -func (q *FakeQuerier) getGroupMemberNoLock(ctx context.Context, userID, groupID uuid.UUID) (database.GroupMember, error) { - groupName := "Everyone" - orgID := groupID - groupIsEveryone := q.isEveryoneGroup(groupID) - if !groupIsEveryone { - group, err := q.getGroupByIDNoLock(ctx, groupID) - if err != nil { - return database.GroupMember{}, err - } - groupName = group.Name - orgID = group.OrganizationID - } - - user, err := q.getUserByIDNoLock(userID) - if err != nil { - return database.GroupMember{}, err - } - if user.Deleted { - return database.GroupMember{}, errUserDeleted - } - - return database.GroupMember{ - UserID: user.ID, - UserEmail: user.Email, - UserUsername: user.Username, - UserHashedPassword: user.HashedPassword, - UserCreatedAt: user.CreatedAt, - UserUpdatedAt: user.UpdatedAt, - UserStatus: user.Status, - UserRbacRoles: user.RBACRoles, - UserLoginType: user.LoginType, - UserAvatarUrl: user.AvatarURL, - UserDeleted: user.Deleted, - UserLastSeenAt: user.LastSeenAt, - UserQuietHoursSchedule: user.QuietHoursSchedule, - UserName: user.Name, - UserGithubComUserID: user.GithubComUserID, - OrganizationID: orgID, - GroupName: groupName, - GroupID: groupID, - }, nil -} - -// getEveryoneGroupMembersNoLock fetches all the users in an organization. -func (q *FakeQuerier) getEveryoneGroupMembersNoLock(ctx context.Context, orgID uuid.UUID) []database.GroupMember { - var ( - everyone []database.GroupMember - orgMembers = q.getOrganizationMemberNoLock(orgID) - ) - for _, member := range orgMembers { - groupMember, err := q.getGroupMemberNoLock(ctx, member.UserID, orgID) - if errors.Is(err, errUserDeleted) { - continue - } - if err != nil { - return nil - } - everyone = append(everyone, groupMember) - } - return everyone -} - -// isEveryoneGroup returns true if the provided ID matches -// an organization ID. -func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool { - for _, org := range q.organizations { - if org.ID == id { - return true - } - } - return false -} - -func (q *FakeQuerier) GetActiveDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - ks := make([]database.DBCryptKey, 0, len(q.dbcryptKeys)) - for _, k := range q.dbcryptKeys { - if !k.ActiveKeyDigest.Valid { - continue - } - ks = append([]database.DBCryptKey{}, k) - } - return ks, nil -} - -func maxTime(t, u time.Time) time.Time { - if t.After(u) { - return t - } - return u -} - -func minTime(t, u time.Time) time.Time { - if t.Before(u) { - return t - } - return u -} - -func provisionerJobStatus(j database.ProvisionerJob) database.ProvisionerJobStatus { - if isNotNull(j.CompletedAt) { - if j.Error.String != "" { - return database.ProvisionerJobStatusFailed - } - if isNotNull(j.CanceledAt) { - return database.ProvisionerJobStatusCanceled - } - return database.ProvisionerJobStatusSucceeded - } - - if isNotNull(j.CanceledAt) { - return database.ProvisionerJobStatusCanceling - } - if isNull(j.StartedAt) { - return database.ProvisionerJobStatusPending - } - return database.ProvisionerJobStatusRunning -} - -// isNull is only used in dbmem, so reflect is ok. Use this to make the logic -// look more similar to the postgres. -func isNull(v interface{}) bool { - return !isNotNull(v) -} - -func isNotNull(v interface{}) bool { - return reflect.ValueOf(v).FieldByName("Valid").Bool() -} - -// Took the error from the real database. -var deletedUserLinkError = &pq.Error{ - Severity: "ERROR", - // "raise_exception" error - Code: "P0001", - Message: "Cannot create user_link for deleted user", - Where: "PL/pgSQL function insert_user_links_fail_if_user_deleted() line 7 at RAISE", - File: "pl_exec.c", - Line: "3864", - Routine: "exec_stmt_raise", -} - -// m1 and m2 are equal iff |m1| = |m2| ^ m2 ⊆ m1 -func tagsEqual(m1, m2 map[string]string) bool { - return len(m1) == len(m2) && tagsSubset(m1, m2) -} - -// m2 is a subset of m1 if each key in m1 exists in m2 -// with the same value -func tagsSubset(m1, m2 map[string]string) bool { - for k, v1 := range m1 { - if v2, found := m2[k]; !found || v1 != v2 { - return false - } - } - return true -} - -// default tags when no tag is specified for a provisioner or job -var tagsUntagged = provisionersdk.MutateTags(uuid.Nil, nil) - -func least[T constraints.Ordered](a, b T) T { - if a < b { - return a - } - return b -} - -func (q *FakeQuerier) getLatestWorkspaceAppByTemplateIDUserIDSlugNoLock(ctx context.Context, templateID, userID uuid.UUID, slug string) (database.WorkspaceApp, error) { - /* - SELECT - app.display_name, - app.icon, - app.slug - FROM - workspace_apps AS app - JOIN - workspace_agents AS agent - ON - agent.id = app.agent_id - JOIN - workspace_resources AS resource - ON - resource.id = agent.resource_id - JOIN - workspace_builds AS build - ON - build.job_id = resource.job_id - JOIN - workspaces AS workspace - ON - workspace.id = build.workspace_id - WHERE - -- Requires lateral join. - app.slug = app_usage.key - AND workspace.owner_id = tus.user_id - AND workspace.template_id = tus.template_id - ORDER BY - app.created_at DESC - LIMIT 1 - */ - - var workspaces []database.WorkspaceTable - for _, w := range q.workspaces { - if w.TemplateID != templateID || w.OwnerID != userID { - continue - } - workspaces = append(workspaces, w) - } - slices.SortFunc(workspaces, func(a, b database.WorkspaceTable) int { - if a.CreatedAt.Before(b.CreatedAt) { - return 1 - } else if a.CreatedAt.Equal(b.CreatedAt) { - return 0 - } - return -1 - }) - - for _, workspace := range workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - continue - } - - resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, build.JobID) - if err != nil { - continue - } - var resourceIDs []uuid.UUID - for _, resource := range resources { - resourceIDs = append(resourceIDs, resource.ID) - } - - agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) - if err != nil { - continue - } - - for _, agent := range agents { - app, err := q.getWorkspaceAppByAgentIDAndSlugNoLock(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{ - AgentID: agent.ID, - Slug: slug, - }) - if err != nil { - continue - } - return app, nil - } - } - - return database.WorkspaceApp{}, sql.ErrNoRows -} - -// getOrganizationByIDNoLock is used by other functions in the database fake. -func (q *FakeQuerier) getOrganizationByIDNoLock(id uuid.UUID) (database.Organization, error) { - for _, organization := range q.organizations { - if organization.ID == id { - return organization, nil - } - } - return database.Organization{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getWorkspaceAgentScriptsByAgentIDsNoLock(ids []uuid.UUID) ([]database.WorkspaceAgentScript, error) { - scripts := make([]database.WorkspaceAgentScript, 0) - for _, script := range q.workspaceAgentScripts { - for _, id := range ids { - if script.WorkspaceAgentID == id { - scripts = append(scripts, script) - break - } - } - } - return scripts, nil -} - -// getOwnerFromTags returns the lowercase owner from tags, matching SQL's COALESCE(tags ->> 'owner', ”) -func getOwnerFromTags(tags map[string]string) string { - if owner, ok := tags["owner"]; ok { - return strings.ToLower(owner) - } - return "" -} - -// provisionerTagsetContains checks if daemonTags contain all key-value pairs from jobTags -func provisionerTagsetContains(daemonTags, jobTags map[string]string) bool { - for jobKey, jobValue := range jobTags { - if daemonValue, exists := daemonTags[jobKey]; !exists || daemonValue != jobValue { - return false - } - } - return true -} - -// GetProvisionerJobsByIDsWithQueuePosition mimics the SQL logic in pure Go -func (q *FakeQuerier) getProvisionerJobsByIDsWithQueuePositionLockedTagBasedQueue(_ context.Context, jobIDs []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { - // Step 1: Filter provisionerJobs based on jobIDs - filteredJobs := make(map[uuid.UUID]database.ProvisionerJob) - for _, job := range q.provisionerJobs { - for _, id := range jobIDs { - if job.ID == id { - filteredJobs[job.ID] = job - } - } - } - - // Step 2: Identify pending jobs - pendingJobs := make(map[uuid.UUID]database.ProvisionerJob) - for _, job := range q.provisionerJobs { - if job.JobStatus == "pending" { - pendingJobs[job.ID] = job - } - } - - // Step 3: Identify pending jobs that have a matching provisioner - matchedJobs := make(map[uuid.UUID]struct{}) - for _, job := range pendingJobs { - for _, daemon := range q.provisionerDaemons { - if provisionerTagsetContains(daemon.Tags, job.Tags) { - matchedJobs[job.ID] = struct{}{} - break - } - } - } - - // Step 4: Rank pending jobs per provisioner - jobRanks := make(map[uuid.UUID][]database.ProvisionerJob) - for _, job := range pendingJobs { - for _, daemon := range q.provisionerDaemons { - if provisionerTagsetContains(daemon.Tags, job.Tags) { - jobRanks[daemon.ID] = append(jobRanks[daemon.ID], job) - } - } - } - - // Sort jobs per provisioner by CreatedAt - for daemonID := range jobRanks { - sort.Slice(jobRanks[daemonID], func(i, j int) bool { - return jobRanks[daemonID][i].CreatedAt.Before(jobRanks[daemonID][j].CreatedAt) - }) - } - - // Step 5: Compute queue position & max queue size across all provisioners - jobQueueStats := make(map[uuid.UUID]database.GetProvisionerJobsByIDsWithQueuePositionRow) - for _, jobs := range jobRanks { - queueSize := int64(len(jobs)) // Queue size per provisioner - for i, job := range jobs { - queuePosition := int64(i + 1) - - // If the job already exists, update only if this queuePosition is better - if existing, exists := jobQueueStats[job.ID]; exists { - jobQueueStats[job.ID] = database.GetProvisionerJobsByIDsWithQueuePositionRow{ - ID: job.ID, - CreatedAt: job.CreatedAt, - ProvisionerJob: job, - QueuePosition: min(existing.QueuePosition, queuePosition), - QueueSize: max(existing.QueueSize, queueSize), // Take the maximum queue size across provisioners - } - } else { - jobQueueStats[job.ID] = database.GetProvisionerJobsByIDsWithQueuePositionRow{ - ID: job.ID, - CreatedAt: job.CreatedAt, - ProvisionerJob: job, - QueuePosition: queuePosition, - QueueSize: queueSize, - } - } - } - } - - // Step 6: Compute the final results with minimal checks - var results []database.GetProvisionerJobsByIDsWithQueuePositionRow - for _, job := range filteredJobs { - // If the job has a computed rank, use it - if rank, found := jobQueueStats[job.ID]; found { - results = append(results, rank) - } else { - // Otherwise, return (0,0) for non-pending jobs and unranked pending jobs - results = append(results, database.GetProvisionerJobsByIDsWithQueuePositionRow{ - ID: job.ID, - CreatedAt: job.CreatedAt, - ProvisionerJob: job, - QueuePosition: 0, - QueueSize: 0, - }) - } - } - - // Step 7: Sort results by CreatedAt - sort.Slice(results, func(i, j int) bool { - return results[i].CreatedAt.Before(results[j].CreatedAt) - }) - - return results, nil -} - -func (q *FakeQuerier) getProvisionerJobsByIDsWithQueuePositionLockedGlobalQueue(_ context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { - // WITH pending_jobs AS ( - // SELECT - // id, created_at - // FROM - // provisioner_jobs - // WHERE - // started_at IS NULL - // AND - // canceled_at IS NULL - // AND - // completed_at IS NULL - // AND - // error IS NULL - // ), - type pendingJobRow struct { - ID uuid.UUID - CreatedAt time.Time - } - pendingJobs := make([]pendingJobRow, 0) - for _, job := range q.provisionerJobs { - if job.StartedAt.Valid || - job.CanceledAt.Valid || - job.CompletedAt.Valid || - job.Error.Valid { - continue - } - pendingJobs = append(pendingJobs, pendingJobRow{ - ID: job.ID, - CreatedAt: job.CreatedAt, - }) - } - - // queue_position AS ( - // SELECT - // id, - // ROW_NUMBER() OVER (ORDER BY created_at ASC) AS queue_position - // FROM - // pending_jobs - // ), - slices.SortFunc(pendingJobs, func(a, b pendingJobRow) int { - c := a.CreatedAt.Compare(b.CreatedAt) - return c - }) - - queuePosition := make(map[uuid.UUID]int64) - for idx, pj := range pendingJobs { - queuePosition[pj.ID] = int64(idx + 1) - } - - // queue_size AS ( - // SELECT COUNT(*) AS count FROM pending_jobs - // ), - queueSize := len(pendingJobs) - - // SELECT - // sqlc.embed(pj), - // COALESCE(qp.queue_position, 0) AS queue_position, - // COALESCE(qs.count, 0) AS queue_size - // FROM - // provisioner_jobs pj - // LEFT JOIN - // queue_position qp ON pj.id = qp.id - // LEFT JOIN - // queue_size qs ON TRUE - // WHERE - // pj.id IN (...) - jobs := make([]database.GetProvisionerJobsByIDsWithQueuePositionRow, 0) - for _, job := range q.provisionerJobs { - if ids != nil && !slices.Contains(ids, job.ID) { - continue - } - // clone the Tags before appending, since maps are reference types and - // we don't want the caller to be able to mutate the map we have inside - // dbmem! - job.Tags = maps.Clone(job.Tags) - job := database.GetProvisionerJobsByIDsWithQueuePositionRow{ - // sqlc.embed(pj), - ProvisionerJob: job, - // COALESCE(qp.queue_position, 0) AS queue_position, - QueuePosition: queuePosition[job.ID], - // COALESCE(qs.count, 0) AS queue_size - QueueSize: int64(queueSize), - } - jobs = append(jobs, job) - } - - return jobs, nil -} - -// isDeprecated returns true if the template is deprecated. -// A template is considered deprecated when it has a deprecation message. -func isDeprecated(template database.Template) bool { - return template.Deprecated != "" -} - -func (q *FakeQuerier) getWorkspaceBuildParametersNoLock(workspaceBuildID uuid.UUID) ([]database.WorkspaceBuildParameter, error) { - params := make([]database.WorkspaceBuildParameter, 0) - for _, param := range q.workspaceBuildParameters { - if param.WorkspaceBuildID != workspaceBuildID { - continue - } - params = append(params, param) - } - return params, nil -} - -func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error { - return xerrors.New("AcquireLock must only be called within a transaction") -} - -// AcquireNotificationMessages implements the *basic* business logic, but is *not* exhaustive or meant to be 1:1 with -// the real AcquireNotificationMessages query. -func (q *FakeQuerier) AcquireNotificationMessages(_ context.Context, arg database.AcquireNotificationMessagesParams) ([]database.AcquireNotificationMessagesRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - // Shift the first "Count" notifications off the slice (FIFO). - sz := len(q.notificationMessages) - if sz > int(arg.Count) { - sz = int(arg.Count) - } - - list := q.notificationMessages[:sz] - q.notificationMessages = q.notificationMessages[sz:] - - var out []database.AcquireNotificationMessagesRow - for _, nm := range list { - acquirableStatuses := []database.NotificationMessageStatus{database.NotificationMessageStatusPending, database.NotificationMessageStatusTemporaryFailure} - if !slices.Contains(acquirableStatuses, nm.Status) { - continue - } - - // Mimic mutation in database query. - nm.UpdatedAt = sql.NullTime{Time: dbtime.Now(), Valid: true} - nm.Status = database.NotificationMessageStatusLeased - nm.StatusReason = sql.NullString{String: fmt.Sprintf("Enqueued by notifier %d", arg.NotifierID), Valid: true} - nm.LeasedUntil = sql.NullTime{Time: dbtime.Now().Add(time.Second * time.Duration(arg.LeaseSeconds)), Valid: true} - - out = append(out, database.AcquireNotificationMessagesRow{ - ID: nm.ID, - Payload: nm.Payload, - Method: nm.Method, - TitleTemplate: "This is a title with {{.Labels.variable}}", - BodyTemplate: "This is a body with {{.Labels.variable}}", - TemplateID: nm.NotificationTemplateID, - }) - } - - return out, nil -} - -func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ProvisionerJob{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, provisionerJob := range q.provisionerJobs { - if provisionerJob.OrganizationID != arg.OrganizationID { - continue - } - if provisionerJob.StartedAt.Valid { - continue - } - found := false - for _, provisionerType := range arg.Types { - if provisionerJob.Provisioner != provisionerType { - continue - } - found = true - break - } - if !found { - continue - } - tags := map[string]string{} - if arg.ProvisionerTags != nil { - err := json.Unmarshal(arg.ProvisionerTags, &tags) - if err != nil { - return provisionerJob, xerrors.Errorf("unmarshal: %w", err) - } - } - - // Special case for untagged provisioners: only match untagged jobs. - // Ref: coderd/database/queries/provisionerjobs.sql:24-30 - // CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb - // THEN nested.tags :: jsonb = @tags :: jsonb - if tagsEqual(provisionerJob.Tags, tagsUntagged) && !tagsEqual(provisionerJob.Tags, tags) { - continue - } - // ELSE nested.tags :: jsonb <@ @tags :: jsonb - if !tagsSubset(provisionerJob.Tags, tags) { - continue - } - provisionerJob.StartedAt = arg.StartedAt - provisionerJob.UpdatedAt = arg.StartedAt.Time - provisionerJob.WorkerID = arg.WorkerID - provisionerJob.JobStatus = provisionerJobStatus(provisionerJob) - q.provisionerJobs[index] = provisionerJob - // clone the Tags before returning, since maps are reference types and - // we don't want the caller to be able to mutate the map we have inside - // dbmem! - provisionerJob.Tags = maps.Clone(provisionerJob.Tags) - return provisionerJob, nil - } - return database.ProvisionerJob{}, sql.ErrNoRows -} - -func (q *FakeQuerier) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - workspace, err := q.getWorkspaceByIDNoLock(ctx, arg.WorkspaceID) - if err != nil { - return err - } - latestBuild, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, arg.WorkspaceID) - if err != nil { - return err - } - - now := dbtime.Now() - for i := range q.workspaceBuilds { - if q.workspaceBuilds[i].BuildNumber != latestBuild.BuildNumber { - continue - } - // If the build is not active, do not bump. - if q.workspaceBuilds[i].Transition != database.WorkspaceTransitionStart { - return nil - } - // If the provisioner job is not completed, do not bump. - pj, err := q.getProvisionerJobByIDNoLock(ctx, q.workspaceBuilds[i].JobID) - if err != nil { - return err - } - if !pj.CompletedAt.Valid { - return nil - } - // Do not bump if the deadline is not set. - if q.workspaceBuilds[i].Deadline.IsZero() { - return nil - } - - // Check the template default TTL. - template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) - if err != nil { - return err - } - if template.ActivityBump == 0 { - return nil - } - activityBump := time.Duration(template.ActivityBump) - - var ttlDur time.Duration - if now.Add(activityBump).After(arg.NextAutostart) && arg.NextAutostart.After(now) { - // Extend to TTL (NOT activity bump) - add := arg.NextAutostart.Sub(now) - if workspace.Ttl.Valid && template.AllowUserAutostop { - add += time.Duration(workspace.Ttl.Int64) - } else { - add += time.Duration(template.DefaultTTL) - } - ttlDur = add - } else { - // Otherwise, default to regular activity bump duration. - ttlDur = activityBump - } - - // Only bump if 5% of the deadline has passed. - ttlDur95 := ttlDur - (ttlDur / 20) - minBumpDeadline := q.workspaceBuilds[i].Deadline.Add(-ttlDur95) - if now.Before(minBumpDeadline) { - return nil - } - - // Bump. - newDeadline := now.Add(ttlDur) - // Never decrease deadlines from a bump - newDeadline = maxTime(newDeadline, q.workspaceBuilds[i].Deadline) - q.workspaceBuilds[i].UpdatedAt = now - if !q.workspaceBuilds[i].MaxDeadline.IsZero() { - q.workspaceBuilds[i].Deadline = minTime(newDeadline, q.workspaceBuilds[i].MaxDeadline) - } else { - q.workspaceBuilds[i].Deadline = newDeadline - } - return nil - } - - return sql.ErrNoRows -} - -// nolint:revive // It's not a control flag, it's a filter. -func (q *FakeQuerier) AllUserIDs(_ context.Context, includeSystem bool) ([]uuid.UUID, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - userIDs := make([]uuid.UUID, 0, len(q.users)) - for idx := range q.users { - if !includeSystem && q.users[idx].IsSystem { - continue - } - - userIDs = append(userIDs, q.users[idx].ID) - } - return userIDs, nil -} - -func (q *FakeQuerier) ArchiveUnusedTemplateVersions(_ context.Context, arg database.ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - q.mutex.Lock() - defer q.mutex.Unlock() - type latestBuild struct { - Number int32 - Version uuid.UUID - } - latest := make(map[uuid.UUID]latestBuild) - - for _, b := range q.workspaceBuilds { - v, ok := latest[b.WorkspaceID] - if ok || b.BuildNumber < v.Number { - // Not the latest - continue - } - // Ignore deleted workspaces. - if b.Transition == database.WorkspaceTransitionDelete { - continue - } - latest[b.WorkspaceID] = latestBuild{ - Number: b.BuildNumber, - Version: b.TemplateVersionID, - } - } - - usedVersions := make(map[uuid.UUID]bool) - for _, l := range latest { - usedVersions[l.Version] = true - } - for _, tpl := range q.templates { - usedVersions[tpl.ActiveVersionID] = true - } - - var archived []uuid.UUID - for i, v := range q.templateVersions { - if arg.TemplateVersionID != uuid.Nil { - if v.ID != arg.TemplateVersionID { - continue - } - } - if v.Archived { - continue - } - - if _, ok := usedVersions[v.ID]; !ok { - var job *database.ProvisionerJob - for i, j := range q.provisionerJobs { - if v.JobID == j.ID { - job = &q.provisionerJobs[i] - break - } - } - - if arg.JobStatus.Valid { - if job.JobStatus != arg.JobStatus.ProvisionerJobStatus { - continue - } - } - - if job.JobStatus == database.ProvisionerJobStatusRunning || job.JobStatus == database.ProvisionerJobStatusPending { - continue - } - - v.Archived = true - q.templateVersions[i] = v - archived = append(archived, v.ID) - } - } - - return archived, nil -} - -func (q *FakeQuerier) BatchUpdateWorkspaceLastUsedAt(_ context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - // temporary map to avoid O(q.workspaces*arg.workspaceIds) - m := make(map[uuid.UUID]struct{}) - for _, id := range arg.IDs { - m[id] = struct{}{} - } - n := 0 - for i := 0; i < len(q.workspaces); i++ { - if _, found := m[q.workspaces[i].ID]; !found { - continue - } - // WHERE last_used_at < @last_used_at - if !q.workspaces[i].LastUsedAt.Before(arg.LastUsedAt) { - continue - } - q.workspaces[i].LastUsedAt = arg.LastUsedAt - n++ - } - return nil -} - -func (q *FakeQuerier) BatchUpdateWorkspaceNextStartAt(_ context.Context, arg database.BatchUpdateWorkspaceNextStartAtParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, workspace := range q.workspaces { - for j, workspaceID := range arg.IDs { - if workspace.ID != workspaceID { - continue - } - - nextStartAt := arg.NextStartAts[j] - if nextStartAt.IsZero() { - q.workspaces[i].NextStartAt = sql.NullTime{} - } else { - q.workspaces[i].NextStartAt = sql.NullTime{Valid: true, Time: nextStartAt} - } - - break - } - } - - return nil -} - -func (*FakeQuerier) BulkMarkNotificationMessagesFailed(_ context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) { - err := validateDatabaseType(arg) - if err != nil { - return 0, err - } - return int64(len(arg.IDs)), nil -} - -func (*FakeQuerier) BulkMarkNotificationMessagesSent(_ context.Context, arg database.BulkMarkNotificationMessagesSentParams) (int64, error) { - err := validateDatabaseType(arg) - if err != nil { - return 0, err - } - return int64(len(arg.IDs)), nil -} - -func (q *FakeQuerier) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) { - return database.ClaimPrebuiltWorkspaceRow{}, ErrUnimplemented -} - -func (*FakeQuerier) CleanTailnetCoordinators(_ context.Context) error { - return ErrUnimplemented -} - -func (*FakeQuerier) CleanTailnetLostPeers(context.Context) error { - return ErrUnimplemented -} - -func (*FakeQuerier) CleanTailnetTunnels(context.Context) error { - return ErrUnimplemented -} - -func (q *FakeQuerier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) { - return q.CountAuthorizedAuditLogs(ctx, arg, nil) -} - -func (q *FakeQuerier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) { - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) CountUnreadInboxNotificationsByUserID(_ context.Context, userID uuid.UUID) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var count int64 - for _, notification := range q.inboxNotifications { - if notification.UserID != userID { - continue - } - - if notification.ReadAt.Valid { - continue - } - - count++ - } - - return count, nil -} - -func (q *FakeQuerier) CustomRoles(_ context.Context, arg database.CustomRolesParams) ([]database.CustomRole, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - found := make([]database.CustomRole, 0) - for _, role := range q.data.customRoles { - if len(arg.LookupRoles) > 0 { - if !slices.ContainsFunc(arg.LookupRoles, func(pair database.NameOrganizationPair) bool { - if pair.Name != role.Name { - return false - } - - if role.OrganizationID.Valid { - // Expect org match - return role.OrganizationID.UUID == pair.OrganizationID - } - // Expect no org - return pair.OrganizationID == uuid.Nil - }) { - continue - } - } - - if arg.ExcludeOrgRoles && role.OrganizationID.Valid { - continue - } - - if arg.OrganizationID != uuid.Nil && role.OrganizationID.UUID != arg.OrganizationID { - continue - } - - found = append(found, role) - } - - return found, nil -} - -func (q *FakeQuerier) DeleteAPIKeyByID(_ context.Context, id string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, apiKey := range q.apiKeys { - if apiKey.ID != id { - continue - } - q.apiKeys[index] = q.apiKeys[len(q.apiKeys)-1] - q.apiKeys = q.apiKeys[:len(q.apiKeys)-1] - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteAPIKeysByUserID(_ context.Context, userID uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i := len(q.apiKeys) - 1; i >= 0; i-- { - if q.apiKeys[i].UserID == userID { - q.apiKeys = append(q.apiKeys[:i], q.apiKeys[i+1:]...) - } - } - - return nil -} - -func (*FakeQuerier) DeleteAllTailnetClientSubscriptions(_ context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - return ErrUnimplemented -} - -func (*FakeQuerier) DeleteAllTailnetTunnels(_ context.Context, arg database.DeleteAllTailnetTunnelsParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - return ErrUnimplemented -} - -func (q *FakeQuerier) DeleteAllWebpushSubscriptions(_ context.Context) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.webpushSubscriptions = make([]database.WebpushSubscription, 0) - return nil -} - -func (q *FakeQuerier) DeleteApplicationConnectAPIKeysByUserID(_ context.Context, userID uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i := len(q.apiKeys) - 1; i >= 0; i-- { - if q.apiKeys[i].UserID == userID && q.apiKeys[i].Scope == database.APIKeyScopeApplicationConnect { - q.apiKeys = append(q.apiKeys[:i], q.apiKeys[i+1:]...) - } - } - - return nil -} - -func (*FakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error { - return ErrUnimplemented -} - -func (q *FakeQuerier) DeleteCryptoKey(_ context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.CryptoKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, key := range q.cryptoKeys { - if key.Feature == arg.Feature && key.Sequence == arg.Sequence { - q.cryptoKeys[i].Secret.String = "" - q.cryptoKeys[i].Secret.Valid = false - q.cryptoKeys[i].SecretKeyID.String = "" - q.cryptoKeys[i].SecretKeyID.Valid = false - return q.cryptoKeys[i], nil - } - } - return database.CryptoKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteCustomRole(_ context.Context, arg database.DeleteCustomRoleParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - initial := len(q.data.customRoles) - q.data.customRoles = slices.DeleteFunc(q.data.customRoles, func(role database.CustomRole) bool { - return role.OrganizationID.UUID == arg.OrganizationID.UUID && role.Name == arg.Name - }) - if initial == len(q.data.customRoles) { - return sql.ErrNoRows - } - - // Emulate the trigger 'remove_organization_member_custom_role' - for i, mem := range q.organizationMembers { - if mem.OrganizationID == arg.OrganizationID.UUID { - mem.Roles = slices.DeleteFunc(mem.Roles, func(role string) bool { - return role == arg.Name - }) - q.organizationMembers[i] = mem - } - } - return nil -} - -func (q *FakeQuerier) DeleteExternalAuthLink(_ context.Context, arg database.DeleteExternalAuthLinkParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, key := range q.externalAuthLinks { - if key.UserID != arg.UserID { - continue - } - if key.ProviderID != arg.ProviderID { - continue - } - q.externalAuthLinks[index] = q.externalAuthLinks[len(q.externalAuthLinks)-1] - q.externalAuthLinks = q.externalAuthLinks[:len(q.externalAuthLinks)-1] - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, key := range q.gitSSHKey { - if key.UserID != userID { - continue - } - q.gitSSHKey[index] = q.gitSSHKey[len(q.gitSSHKey)-1] - q.gitSSHKey = q.gitSSHKey[:len(q.gitSSHKey)-1] - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteGroupByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, group := range q.groups { - if group.ID == id { - q.groups = append(q.groups[:i], q.groups[i+1:]...) - return nil - } - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteGroupMemberFromGroup(_ context.Context, arg database.DeleteGroupMemberFromGroupParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, member := range q.groupMembers { - if member.UserID == arg.UserID && member.GroupID == arg.GroupID { - q.groupMembers = append(q.groupMembers[:i], q.groupMembers[i+1:]...) - } - } - return nil -} - -func (q *FakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, l := range q.licenses { - if l.ID == id { - q.licenses[index] = q.licenses[len(q.licenses)-1] - q.licenses = q.licenses[:len(q.licenses)-1] - return id, nil - } - } - return 0, sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, app := range q.oauth2ProviderApps { - if app.ID == id { - q.oauth2ProviderApps = append(q.oauth2ProviderApps[:i], q.oauth2ProviderApps[i+1:]...) - - // Also delete related secrets and tokens - for j := len(q.oauth2ProviderAppSecrets) - 1; j >= 0; j-- { - if q.oauth2ProviderAppSecrets[j].AppID == id { - q.oauth2ProviderAppSecrets = append(q.oauth2ProviderAppSecrets[:j], q.oauth2ProviderAppSecrets[j+1:]...) - } - } - - // Delete tokens for the app's secrets - for j := len(q.oauth2ProviderAppTokens) - 1; j >= 0; j-- { - token := q.oauth2ProviderAppTokens[j] - for _, secret := range q.oauth2ProviderAppSecrets { - if secret.AppID == id && token.AppSecretID == secret.ID { - q.oauth2ProviderAppTokens = append(q.oauth2ProviderAppTokens[:j], q.oauth2ProviderAppTokens[j+1:]...) - break - } - } - } - - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteOAuth2ProviderAppByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - index := slices.IndexFunc(q.oauth2ProviderApps, func(app database.OAuth2ProviderApp) bool { - return app.ID == id - }) - - if index < 0 { - return sql.ErrNoRows - } - - q.oauth2ProviderApps[index] = q.oauth2ProviderApps[len(q.oauth2ProviderApps)-1] - q.oauth2ProviderApps = q.oauth2ProviderApps[:len(q.oauth2ProviderApps)-1] - - // Cascade delete secrets associated with the deleted app. - var deletedSecretIDs []uuid.UUID - q.oauth2ProviderAppSecrets = slices.DeleteFunc(q.oauth2ProviderAppSecrets, func(secret database.OAuth2ProviderAppSecret) bool { - matches := secret.AppID == id - if matches { - deletedSecretIDs = append(deletedSecretIDs, secret.ID) - } - return matches - }) - - // Cascade delete tokens through the deleted secrets. - var keyIDsToDelete []string - q.oauth2ProviderAppTokens = slices.DeleteFunc(q.oauth2ProviderAppTokens, func(token database.OAuth2ProviderAppToken) bool { - matches := slice.Contains(deletedSecretIDs, token.AppSecretID) - if matches { - keyIDsToDelete = append(keyIDsToDelete, token.APIKeyID) - } - return matches - }) - - // Cascade delete API keys linked to the deleted tokens. - q.apiKeys = slices.DeleteFunc(q.apiKeys, func(key database.APIKey) bool { - return slices.Contains(keyIDsToDelete, key.ID) - }) - - return nil -} - -func (q *FakeQuerier) DeleteOAuth2ProviderAppCodeByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, code := range q.oauth2ProviderAppCodes { - if code.ID == id { - q.oauth2ProviderAppCodes[index] = q.oauth2ProviderAppCodes[len(q.oauth2ProviderAppCodes)-1] - q.oauth2ProviderAppCodes = q.oauth2ProviderAppCodes[:len(q.oauth2ProviderAppCodes)-1] - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteOAuth2ProviderAppCodesByAppAndUserID(_ context.Context, arg database.DeleteOAuth2ProviderAppCodesByAppAndUserIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, code := range q.oauth2ProviderAppCodes { - if code.AppID == arg.AppID && code.UserID == arg.UserID { - q.oauth2ProviderAppCodes[index] = q.oauth2ProviderAppCodes[len(q.oauth2ProviderAppCodes)-1] - q.oauth2ProviderAppCodes = q.oauth2ProviderAppCodes[:len(q.oauth2ProviderAppCodes)-1] - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteOAuth2ProviderAppSecretByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - index := slices.IndexFunc(q.oauth2ProviderAppSecrets, func(secret database.OAuth2ProviderAppSecret) bool { - return secret.ID == id - }) - - if index < 0 { - return sql.ErrNoRows - } - - q.oauth2ProviderAppSecrets[index] = q.oauth2ProviderAppSecrets[len(q.oauth2ProviderAppSecrets)-1] - q.oauth2ProviderAppSecrets = q.oauth2ProviderAppSecrets[:len(q.oauth2ProviderAppSecrets)-1] - - // Cascade delete tokens created through the deleted secret. - var keyIDsToDelete []string - q.oauth2ProviderAppTokens = slices.DeleteFunc(q.oauth2ProviderAppTokens, func(token database.OAuth2ProviderAppToken) bool { - matches := token.AppSecretID == id - if matches { - keyIDsToDelete = append(keyIDsToDelete, token.APIKeyID) - } - return matches - }) - - // Cascade delete API keys linked to the deleted tokens. - q.apiKeys = slices.DeleteFunc(q.apiKeys, func(key database.APIKey) bool { - return slices.Contains(keyIDsToDelete, key.ID) - }) - - return nil -} - -func (q *FakeQuerier) DeleteOAuth2ProviderAppTokensByAppAndUserID(_ context.Context, arg database.DeleteOAuth2ProviderAppTokensByAppAndUserIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - var keyIDsToDelete []string - q.oauth2ProviderAppTokens = slices.DeleteFunc(q.oauth2ProviderAppTokens, func(token database.OAuth2ProviderAppToken) bool { - // Join secrets and keys to see if the token matches. - secretIdx := slices.IndexFunc(q.oauth2ProviderAppSecrets, func(secret database.OAuth2ProviderAppSecret) bool { - return secret.ID == token.AppSecretID - }) - keyIdx := slices.IndexFunc(q.apiKeys, func(key database.APIKey) bool { - return key.ID == token.APIKeyID - }) - matches := secretIdx != -1 && - q.oauth2ProviderAppSecrets[secretIdx].AppID == arg.AppID && - keyIdx != -1 && q.apiKeys[keyIdx].UserID == arg.UserID - if matches { - keyIDsToDelete = append(keyIDsToDelete, token.APIKeyID) - } - return matches - }) - - // Cascade delete API keys linked to the deleted tokens. - q.apiKeys = slices.DeleteFunc(q.apiKeys, func(key database.APIKey) bool { - return slices.Contains(keyIDsToDelete, key.ID) - }) - - return nil -} - -func (*FakeQuerier) DeleteOldNotificationMessages(_ context.Context) error { - return nil -} - -func (q *FakeQuerier) DeleteOldProvisionerDaemons(_ context.Context) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - now := dbtime.Now() - weekInterval := 7 * 24 * time.Hour - weekAgo := now.Add(-weekInterval) - - var validDaemons []database.ProvisionerDaemon - for _, p := range q.provisionerDaemons { - if (p.CreatedAt.Before(weekAgo) && !p.LastSeenAt.Valid) || (p.LastSeenAt.Valid && p.LastSeenAt.Time.Before(weekAgo)) { - continue - } - validDaemons = append(validDaemons, p) - } - q.provisionerDaemons = validDaemons - return nil -} - -func (q *FakeQuerier) DeleteOldWorkspaceAgentLogs(_ context.Context, threshold time.Time) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - /* - WITH - latest_builds AS ( - SELECT - workspace_id, max(build_number) AS max_build_number - FROM - workspace_builds - GROUP BY - workspace_id - ), - */ - latestBuilds := make(map[uuid.UUID]int32) - for _, wb := range q.workspaceBuilds { - if lastBuildNumber, found := latestBuilds[wb.WorkspaceID]; found && lastBuildNumber > wb.BuildNumber { - continue - } - // not found or newer build number - latestBuilds[wb.WorkspaceID] = wb.BuildNumber - } - - /* - old_agents AS ( - SELECT - wa.id - FROM - workspace_agents AS wa - JOIN - workspace_resources AS wr - ON - wa.resource_id = wr.id - JOIN - workspace_builds AS wb - ON - wb.job_id = wr.job_id - LEFT JOIN - latest_builds - ON - latest_builds.workspace_id = wb.workspace_id - AND - latest_builds.max_build_number = wb.build_number - WHERE - -- Filter out the latest builds for each workspace. - latest_builds.workspace_id IS NULL - AND CASE - -- If the last time the agent connected was before @threshold - WHEN wa.last_connected_at IS NOT NULL THEN - wa.last_connected_at < @threshold :: timestamptz - -- The agent never connected, and was created before @threshold - ELSE wa.created_at < @threshold :: timestamptz - END - ) - */ - oldAgents := make(map[uuid.UUID]struct{}) - for _, wa := range q.workspaceAgents { - for _, wr := range q.workspaceResources { - if wr.ID != wa.ResourceID { - continue - } - for _, wb := range q.workspaceBuilds { - if wb.JobID != wr.JobID { - continue - } - latestBuildNumber, found := latestBuilds[wb.WorkspaceID] - if !found { - panic("workspaceBuilds got modified somehow while q was locked! This is a bug in dbmem!") - } - if latestBuildNumber == wb.BuildNumber { - continue - } - if wa.LastConnectedAt.Valid && wa.LastConnectedAt.Time.Before(threshold) || wa.CreatedAt.Before(threshold) { - oldAgents[wa.ID] = struct{}{} - } - } - } - } - /* - DELETE FROM workspace_agent_logs WHERE agent_id IN (SELECT id FROM old_agents); - */ - var validLogs []database.WorkspaceAgentLog - for _, log := range q.workspaceAgentLogs { - if _, found := oldAgents[log.AgentID]; found { - continue - } - validLogs = append(validLogs, log) - } - q.workspaceAgentLogs = validLogs - return nil -} - -func (q *FakeQuerier) DeleteOldWorkspaceAgentStats(_ context.Context) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - /* - DELETE FROM - workspace_agent_stats - WHERE - created_at < ( - SELECT - COALESCE( - -- When generating initial template usage stats, all the - -- raw agent stats are needed, after that only ~30 mins - -- from last rollup is needed. Deployment stats seem to - -- use between 15 mins and 1 hour of data. We keep a - -- little bit more (1 day) just in case. - MAX(start_time) - '1 days'::interval, - -- Fall back to ~6 months ago if there are no template - -- usage stats so that we don't delete the data before - -- it's rolled up. - NOW() - '180 days'::interval - ) - FROM - template_usage_stats - ) - AND created_at < ( - -- Delete at most in batches of 3 days (with a batch size of 3 days, we - -- can clear out the previous 6 months of data in ~60 iterations) whilst - -- keeping the DB load relatively low. - SELECT - COALESCE(MIN(created_at) + '3 days'::interval, NOW()) - FROM - workspace_agent_stats - ); - */ - - now := dbtime.Now() - var limit time.Time - // MAX - for _, stat := range q.templateUsageStats { - if stat.StartTime.After(limit) { - limit = stat.StartTime.AddDate(0, 0, -1) - } - } - // COALESCE - if limit.IsZero() { - limit = now.AddDate(0, 0, -180) - } - - var validStats []database.WorkspaceAgentStat - var batchLimit time.Time - for _, stat := range q.workspaceAgentStats { - if batchLimit.IsZero() || stat.CreatedAt.Before(batchLimit) { - batchLimit = stat.CreatedAt - } - } - if batchLimit.IsZero() { - batchLimit = time.Now() - } else { - batchLimit = batchLimit.AddDate(0, 0, 3) - } - for _, stat := range q.workspaceAgentStats { - if stat.CreatedAt.Before(limit) && stat.CreatedAt.Before(batchLimit) { - continue - } - validStats = append(validStats, stat) - } - q.workspaceAgentStats = validStats - return nil -} - -func (q *FakeQuerier) DeleteOrganizationMember(ctx context.Context, arg database.DeleteOrganizationMemberParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - deleted := false - q.data.organizationMembers = slices.DeleteFunc(q.data.organizationMembers, func(member database.OrganizationMember) bool { - match := member.OrganizationID == arg.OrganizationID && member.UserID == arg.UserID - deleted = deleted || match - return match - }) - if !deleted { - return sql.ErrNoRows - } - - // Delete group member trigger - q.groupMembers = slices.DeleteFunc(q.groupMembers, func(member database.GroupMemberTable) bool { - if member.UserID != arg.UserID { - return false - } - g, _ := q.getGroupByIDNoLock(ctx, member.GroupID) - return g.OrganizationID == arg.OrganizationID - }) - - return nil -} - -func (q *FakeQuerier) DeleteProvisionerKey(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, key := range q.provisionerKeys { - if key.ID == id { - q.provisionerKeys = append(q.provisionerKeys[:i], q.provisionerKeys[i+1:]...) - return nil - } - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time.Time) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, replica := range q.replicas { - if replica.UpdatedAt.Before(before) { - q.replicas = append(q.replicas[:i], q.replicas[i+1:]...) - } - } - - return nil -} - -func (q *FakeQuerier) DeleteRuntimeConfig(_ context.Context, key string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - delete(q.runtimeConfig, key) - return nil -} - -func (*FakeQuerier) DeleteTailnetAgent(context.Context, database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { - return database.DeleteTailnetAgentRow{}, ErrUnimplemented -} - -func (*FakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { - return database.DeleteTailnetClientRow{}, ErrUnimplemented -} - -func (*FakeQuerier) DeleteTailnetClientSubscription(context.Context, database.DeleteTailnetClientSubscriptionParams) error { - return ErrUnimplemented -} - -func (*FakeQuerier) DeleteTailnetPeer(_ context.Context, arg database.DeleteTailnetPeerParams) (database.DeleteTailnetPeerRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.DeleteTailnetPeerRow{}, err - } - - return database.DeleteTailnetPeerRow{}, ErrUnimplemented -} - -func (*FakeQuerier) DeleteTailnetTunnel(_ context.Context, arg database.DeleteTailnetTunnelParams) (database.DeleteTailnetTunnelRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.DeleteTailnetTunnelRow{}, err - } - - return database.DeleteTailnetTunnelRow{}, ErrUnimplemented -} - -func (q *FakeQuerier) DeleteWebpushSubscriptionByUserIDAndEndpoint(_ context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, subscription := range q.webpushSubscriptions { - if subscription.UserID == arg.UserID && subscription.Endpoint == arg.Endpoint { - q.webpushSubscriptions[i] = q.webpushSubscriptions[len(q.webpushSubscriptions)-1] - q.webpushSubscriptions = q.webpushSubscriptions[:len(q.webpushSubscriptions)-1] - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteWebpushSubscriptions(_ context.Context, ids []uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - for i, subscription := range q.webpushSubscriptions { - if slices.Contains(ids, subscription.ID) { - q.webpushSubscriptions[i] = q.webpushSubscriptions[len(q.webpushSubscriptions)-1] - q.webpushSubscriptions = q.webpushSubscriptions[:len(q.webpushSubscriptions)-1] - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteWorkspaceAgentPortShare(_ context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, share := range q.workspaceAgentPortShares { - if share.WorkspaceID == arg.WorkspaceID && share.AgentName == arg.AgentName && share.Port == arg.Port { - q.workspaceAgentPortShares = append(q.workspaceAgentPortShares[:i], q.workspaceAgentPortShares[i+1:]...) - return nil - } - } - - return nil -} - -func (q *FakeQuerier) DeleteWorkspaceAgentPortSharesByTemplate(_ context.Context, templateID uuid.UUID) error { - err := validateDatabaseType(templateID) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, workspace := range q.workspaces { - if workspace.TemplateID != templateID { - continue - } - for i, share := range q.workspaceAgentPortShares { - if share.WorkspaceID != workspace.ID { - continue - } - q.workspaceAgentPortShares = append(q.workspaceAgentPortShares[:i], q.workspaceAgentPortShares[i+1:]...) - } - } - - return nil -} - -func (q *FakeQuerier) DeleteWorkspaceSubAgentByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, agent := range q.workspaceAgents { - if agent.ID == id && agent.ParentID.Valid { - q.workspaceAgents[i].Deleted = true - return nil - } - } - - return nil -} - -func (*FakeQuerier) DisableForeignKeysAndTriggers(_ context.Context) error { - // This is a no-op in the in-memory database. - return nil -} - -func (q *FakeQuerier) EnqueueNotificationMessage(_ context.Context, arg database.EnqueueNotificationMessageParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - var payload types.MessagePayload - err = json.Unmarshal(arg.Payload, &payload) - if err != nil { - return err - } - - nm := database.NotificationMessage{ - ID: arg.ID, - UserID: arg.UserID, - Method: arg.Method, - Payload: arg.Payload, - NotificationTemplateID: arg.NotificationTemplateID, - Targets: arg.Targets, - CreatedBy: arg.CreatedBy, - // Default fields. - CreatedAt: dbtime.Now(), - Status: database.NotificationMessageStatusPending, - } - - q.notificationMessages = append(q.notificationMessages, nm) - - return err -} - -func (q *FakeQuerier) FavoriteWorkspace(_ context.Context, arg uuid.UUID) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i := 0; i < len(q.workspaces); i++ { - if q.workspaces[i].ID != arg { - continue - } - q.workspaces[i].Favorite = true - return nil - } - return nil -} - -func (q *FakeQuerier) FetchMemoryResourceMonitorsByAgentID(_ context.Context, agentID uuid.UUID) (database.WorkspaceAgentMemoryResourceMonitor, error) { - for _, monitor := range q.workspaceAgentMemoryResourceMonitors { - if monitor.AgentID == agentID { - return monitor, nil - } - } - - return database.WorkspaceAgentMemoryResourceMonitor{}, sql.ErrNoRows -} - -func (q *FakeQuerier) FetchMemoryResourceMonitorsUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.WorkspaceAgentMemoryResourceMonitor, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - monitors := []database.WorkspaceAgentMemoryResourceMonitor{} - for _, monitor := range q.workspaceAgentMemoryResourceMonitors { - if monitor.UpdatedAt.After(updatedAt) { - monitors = append(monitors, monitor) - } - } - return monitors, nil -} - -func (q *FakeQuerier) FetchNewMessageMetadata(_ context.Context, arg database.FetchNewMessageMetadataParams) (database.FetchNewMessageMetadataRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.FetchNewMessageMetadataRow{}, err - } - - user, err := q.getUserByIDNoLock(arg.UserID) - if err != nil { - return database.FetchNewMessageMetadataRow{}, xerrors.Errorf("fetch user: %w", err) - } - - // Mimic COALESCE in query - userName := user.Name - if userName == "" { - userName = user.Username - } - - actions, err := json.Marshal([]types.TemplateAction{{URL: "http://xyz.com", Label: "XYZ"}}) - if err != nil { - return database.FetchNewMessageMetadataRow{}, err - } - - return database.FetchNewMessageMetadataRow{ - UserEmail: user.Email, - UserName: userName, - UserUsername: user.Username, - NotificationName: "Some notification", - Actions: actions, - UserID: arg.UserID, - }, nil -} - -func (q *FakeQuerier) FetchVolumesResourceMonitorsByAgentID(_ context.Context, agentID uuid.UUID) ([]database.WorkspaceAgentVolumeResourceMonitor, error) { - monitors := []database.WorkspaceAgentVolumeResourceMonitor{} - - for _, monitor := range q.workspaceAgentVolumeResourceMonitors { - if monitor.AgentID == agentID { - monitors = append(monitors, monitor) - } - } - - return monitors, nil -} - -func (q *FakeQuerier) FetchVolumesResourceMonitorsUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.WorkspaceAgentVolumeResourceMonitor, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - monitors := []database.WorkspaceAgentVolumeResourceMonitor{} - for _, monitor := range q.workspaceAgentVolumeResourceMonitors { - if monitor.UpdatedAt.After(updatedAt) { - monitors = append(monitors, monitor) - } - } - return monitors, nil -} - -func (q *FakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, apiKey := range q.apiKeys { - if apiKey.ID == id { - return apiKey, nil - } - } - return database.APIKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetAPIKeyByName(_ context.Context, params database.GetAPIKeyByNameParams) (database.APIKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if params.TokenName == "" { - return database.APIKey{}, sql.ErrNoRows - } - for _, apiKey := range q.apiKeys { - if params.UserID == apiKey.UserID && params.TokenName == apiKey.TokenName { - return apiKey, nil - } - } - return database.APIKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetAPIKeysByLoginType(_ context.Context, t database.LoginType) ([]database.APIKey, error) { - if err := validateDatabaseType(t); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - apiKeys := make([]database.APIKey, 0) - for _, key := range q.apiKeys { - if key.LoginType == t { - apiKeys = append(apiKeys, key) - } - } - return apiKeys, nil -} - -func (q *FakeQuerier) GetAPIKeysByUserID(_ context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - apiKeys := make([]database.APIKey, 0) - for _, key := range q.apiKeys { - if key.UserID == params.UserID && key.LoginType == params.LoginType { - apiKeys = append(apiKeys, key) - } - } - return apiKeys, nil -} - -func (q *FakeQuerier) GetAPIKeysLastUsedAfter(_ context.Context, after time.Time) ([]database.APIKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - apiKeys := make([]database.APIKey, 0) - for _, key := range q.apiKeys { - if key.LastUsed.After(after) { - apiKeys = append(apiKeys, key) - } - } - return apiKeys, nil -} - -func (q *FakeQuerier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var activeSchedules []database.TemplateVersionPresetPrebuildSchedule - - // Create a map of active template version IDs for quick lookup - activeTemplateVersions := make(map[uuid.UUID]bool) - for _, template := range q.templates { - if !template.Deleted && template.Deprecated == "" { - activeTemplateVersions[template.ActiveVersionID] = true - } - } - - // Create a map of presets for quick lookup - presetMap := make(map[uuid.UUID]database.TemplateVersionPreset) - for _, preset := range q.presets { - presetMap[preset.ID] = preset - } - - // Filter preset prebuild schedules to only include those for active template versions - for _, schedule := range q.presetPrebuildSchedules { - // Look up the preset using the map - preset, exists := presetMap[schedule.PresetID] - if !exists { - continue - } - - // Check if preset's template version is active - if !activeTemplateVersions[preset.TemplateVersionID] { - continue - } - - activeSchedules = append(activeSchedules, schedule) - } - - return activeSchedules, nil -} - -// nolint:revive // It's not a control flag, it's a filter. -func (q *FakeQuerier) GetActiveUserCount(_ context.Context, includeSystem bool) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - active := int64(0) - for _, u := range q.users { - if !includeSystem && u.IsSystem { - continue - } - - if u.Status == database.UserStatusActive && !u.Deleted { - active++ - } - } - return active, nil -} - -func (q *FakeQuerier) GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]database.WorkspaceBuild, error) { - workspaceIDs := func() []uuid.UUID { - q.mutex.RLock() - defer q.mutex.RUnlock() - - ids := []uuid.UUID{} - for _, workspace := range q.workspaces { - if workspace.TemplateID == templateID { - ids = append(ids, workspace.ID) - } - } - return ids - }() - - builds, err := q.GetLatestWorkspaceBuildsByWorkspaceIDs(ctx, workspaceIDs) - if err != nil { - return nil, err - } - - filteredBuilds := []database.WorkspaceBuild{} - for _, build := range builds { - if build.Transition == database.WorkspaceTransitionStart { - filteredBuilds = append(filteredBuilds, build) - } - } - return filteredBuilds, nil -} - -func (*FakeQuerier) GetAllTailnetAgents(_ context.Context) ([]database.TailnetAgent, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetAllTailnetCoordinators(context.Context) ([]database.TailnetCoordinator, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetAllTailnetPeers(context.Context) ([]database.TailnetPeer, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetAllTailnetTunnels(context.Context) ([]database.TailnetTunnel, error) { - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) GetAnnouncementBanners(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.announcementBanners == nil { - return "", sql.ErrNoRows - } - - return string(q.announcementBanners), nil -} - -func (q *FakeQuerier) GetAppSecurityKey(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.appSecurityKey, nil -} - -func (q *FakeQuerier) GetApplicationName(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.applicationName == "" { - return "", sql.ErrNoRows - } - - return q.applicationName, nil -} - -func (q *FakeQuerier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditLogsOffsetParams) ([]database.GetAuditLogsOffsetRow, error) { - return q.GetAuthorizedAuditLogsOffset(ctx, arg, nil) -} - -func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.UUID) (database.GetAuthorizationUserRolesRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var user *database.User - roles := make([]string, 0) - for _, u := range q.users { - if u.ID == userID { - roles = append(roles, u.RBACRoles...) - roles = append(roles, "member") - user = &u - break - } - } - - for _, mem := range q.organizationMembers { - if mem.UserID == userID { - for _, orgRole := range mem.Roles { - roles = append(roles, orgRole+":"+mem.OrganizationID.String()) - } - roles = append(roles, "organization-member:"+mem.OrganizationID.String()) - } - } - - var groups []string - for _, member := range q.groupMembers { - if member.UserID == userID { - groups = append(groups, member.GroupID.String()) - } - } - - if user == nil { - return database.GetAuthorizationUserRolesRow{}, sql.ErrNoRows - } - - return database.GetAuthorizationUserRolesRow{ - ID: userID, - Username: user.Username, - Status: user.Status, - Roles: roles, - Groups: groups, - }, nil -} - -func (q *FakeQuerier) GetCoordinatorResumeTokenSigningKey(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - if q.coordinatorResumeTokenSigningKey == "" { - return "", sql.ErrNoRows - } - return q.coordinatorResumeTokenSigningKey, nil -} - -func (q *FakeQuerier) GetCryptoKeyByFeatureAndSequence(_ context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.CryptoKey{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, key := range q.cryptoKeys { - if key.Feature == arg.Feature && key.Sequence == arg.Sequence { - // Keys with NULL secrets are considered deleted. - if key.Secret.Valid { - return key, nil - } - return database.CryptoKey{}, sql.ErrNoRows - } - } - - return database.CryptoKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetCryptoKeys(_ context.Context) ([]database.CryptoKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - keys := make([]database.CryptoKey, 0) - for _, key := range q.cryptoKeys { - if key.Secret.Valid { - keys = append(keys, key) - } - } - return keys, nil -} - -func (q *FakeQuerier) GetCryptoKeysByFeature(_ context.Context, feature database.CryptoKeyFeature) ([]database.CryptoKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - keys := make([]database.CryptoKey, 0) - for _, key := range q.cryptoKeys { - if key.Feature == feature && key.Secret.Valid { - keys = append(keys, key) - } - } - // We want to return the highest sequence number first. - slices.SortFunc(keys, func(i, j database.CryptoKey) int { - return int(j.Sequence - i.Sequence) - }) - return keys, nil -} - -func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - ks := make([]database.DBCryptKey, 0) - ks = append(ks, q.dbcryptKeys...) - return ks, nil -} - -func (q *FakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.derpMeshKey == "" { - return "", sql.ErrNoRows - } - return q.derpMeshKey, nil -} - -func (q *FakeQuerier) GetDefaultOrganization(_ context.Context) (database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, org := range q.organizations { - if org.IsDefault { - return org, nil - } - } - return database.Organization{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetDefaultProxyConfig(_ context.Context) (database.GetDefaultProxyConfigRow, error) { - return database.GetDefaultProxyConfigRow{ - DisplayName: q.defaultProxyDisplayName, - IconUrl: q.defaultProxyIconURL, - }, nil -} - -func (q *FakeQuerier) GetDeploymentDAUs(_ context.Context, tzOffset int32) ([]database.GetDeploymentDAUsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - seens := make(map[time.Time]map[uuid.UUID]struct{}) - - for _, as := range q.workspaceAgentStats { - if as.ConnectionCount == 0 { - continue - } - date := as.CreatedAt.UTC().Add(time.Duration(tzOffset) * -1 * time.Hour).Truncate(time.Hour * 24) - - dateEntry := seens[date] - if dateEntry == nil { - dateEntry = make(map[uuid.UUID]struct{}) - } - dateEntry[as.UserID] = struct{}{} - seens[date] = dateEntry - } - - seenKeys := maps.Keys(seens) - sort.Slice(seenKeys, func(i, j int) bool { - return seenKeys[i].Before(seenKeys[j]) - }) - - var rs []database.GetDeploymentDAUsRow - for _, key := range seenKeys { - ids := seens[key] - for id := range ids { - rs = append(rs, database.GetDeploymentDAUsRow{ - Date: key, - UserID: id, - }) - } - } - - return rs, nil -} - -func (q *FakeQuerier) GetDeploymentID(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.deploymentID, nil -} - -func (q *FakeQuerier) GetDeploymentWorkspaceAgentStats(_ context.Context, createdAfter time.Time) (database.GetDeploymentWorkspaceAgentStatsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) { - agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) - } - } - - latestAgentStats := map[uuid.UUID]database.WorkspaceAgentStat{} - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) { - latestAgentStats[agentStat.AgentID] = agentStat - } - } - - stat := database.GetDeploymentWorkspaceAgentStatsRow{} - for _, agentStat := range latestAgentStats { - stat.SessionCountVSCode += agentStat.SessionCountVSCode - stat.SessionCountJetBrains += agentStat.SessionCountJetBrains - stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - stat.SessionCountSSH += agentStat.SessionCountSSH - } - - latencies := make([]float64, 0) - for _, agentStat := range agentStatsCreatedAfter { - if agentStat.ConnectionMedianLatencyMS <= 0 { - continue - } - stat.WorkspaceRxBytes += agentStat.RxBytes - stat.WorkspaceTxBytes += agentStat.TxBytes - latencies = append(latencies, agentStat.ConnectionMedianLatencyMS) - } - - stat.WorkspaceConnectionLatency50 = tryPercentileCont(latencies, 50) - stat.WorkspaceConnectionLatency95 = tryPercentileCont(latencies, 95) - - return stat, nil -} - -func (q *FakeQuerier) GetDeploymentWorkspaceAgentUsageStats(_ context.Context, createdAt time.Time) (database.GetDeploymentWorkspaceAgentUsageStatsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - stat := database.GetDeploymentWorkspaceAgentUsageStatsRow{} - sessions := make(map[uuid.UUID]database.WorkspaceAgentStat) - agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) - for _, agentStat := range q.workspaceAgentStats { - // WHERE workspace_agent_stats.created_at > $1 - if agentStat.CreatedAt.After(createdAt) { - agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) - } - // WHERE - // created_at > $1 - // AND created_at < date_trunc('minute', now()) -- Exclude current partial minute - // AND usage = true - if agentStat.Usage && - (agentStat.CreatedAt.After(createdAt) || agentStat.CreatedAt.Equal(createdAt)) && - agentStat.CreatedAt.Before(time.Now().Truncate(time.Minute)) { - val, ok := sessions[agentStat.AgentID] - if !ok { - sessions[agentStat.AgentID] = agentStat - } else if agentStat.CreatedAt.After(val.CreatedAt) { - sessions[agentStat.AgentID] = agentStat - } else if agentStat.CreatedAt.Truncate(time.Minute).Equal(val.CreatedAt.Truncate(time.Minute)) { - val.SessionCountVSCode += agentStat.SessionCountVSCode - val.SessionCountJetBrains += agentStat.SessionCountJetBrains - val.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - val.SessionCountSSH += agentStat.SessionCountSSH - sessions[agentStat.AgentID] = val - } - } - } - - latencies := make([]float64, 0) - for _, agentStat := range agentStatsCreatedAfter { - if agentStat.ConnectionMedianLatencyMS <= 0 { - continue - } - stat.WorkspaceRxBytes += agentStat.RxBytes - stat.WorkspaceTxBytes += agentStat.TxBytes - latencies = append(latencies, agentStat.ConnectionMedianLatencyMS) - } - stat.WorkspaceConnectionLatency50 = tryPercentileCont(latencies, 50) - stat.WorkspaceConnectionLatency95 = tryPercentileCont(latencies, 95) - - for _, agentStat := range sessions { - stat.SessionCountVSCode += agentStat.SessionCountVSCode - stat.SessionCountJetBrains += agentStat.SessionCountJetBrains - stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - stat.SessionCountSSH += agentStat.SessionCountSSH - } - - return stat, nil -} - -func (q *FakeQuerier) GetDeploymentWorkspaceStats(ctx context.Context) (database.GetDeploymentWorkspaceStatsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - stat := database.GetDeploymentWorkspaceStatsRow{} - for _, workspace := range q.workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return stat, err - } - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return stat, err - } - if !job.StartedAt.Valid { - stat.PendingWorkspaces++ - continue - } - if job.StartedAt.Valid && - !job.CanceledAt.Valid && - time.Since(job.UpdatedAt) <= 30*time.Second && - !job.CompletedAt.Valid { - stat.BuildingWorkspaces++ - continue - } - if job.CompletedAt.Valid && - !job.CanceledAt.Valid && - !job.Error.Valid { - if build.Transition == database.WorkspaceTransitionStart { - stat.RunningWorkspaces++ - } - if build.Transition == database.WorkspaceTransitionStop { - stat.StoppedWorkspaces++ - } - continue - } - if job.CanceledAt.Valid || job.Error.Valid { - stat.FailedWorkspaces++ - continue - } - } - return stat, nil -} - -func (q *FakeQuerier) GetEligibleProvisionerDaemonsByProvisionerJobIDs(_ context.Context, provisionerJobIds []uuid.UUID) ([]database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - results := make([]database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow, 0) - seen := make(map[string]struct{}) // Track unique combinations - - for _, jobID := range provisionerJobIds { - var job database.ProvisionerJob - found := false - for _, j := range q.provisionerJobs { - if j.ID == jobID { - job = j - found = true - break - } - } - if !found { - continue - } - - for _, daemon := range q.provisionerDaemons { - if daemon.OrganizationID != job.OrganizationID { - continue - } - - if !tagsSubset(job.Tags, daemon.Tags) { - continue - } - - provisionerMatches := false - for _, p := range daemon.Provisioners { - if p == job.Provisioner { - provisionerMatches = true - break - } - } - if !provisionerMatches { - continue - } - - key := jobID.String() + "-" + daemon.ID.String() - if _, exists := seen[key]; exists { - continue - } - seen[key] = struct{}{} - - results = append(results, database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow{ - JobID: jobID, - ProvisionerDaemon: daemon, - }) - } - } - - return results, nil -} - -func (q *FakeQuerier) GetExternalAuthLink(_ context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ExternalAuthLink{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - for _, gitAuthLink := range q.externalAuthLinks { - if arg.UserID != gitAuthLink.UserID { - continue - } - if arg.ProviderID != gitAuthLink.ProviderID { - continue - } - return gitAuthLink, nil - } - return database.ExternalAuthLink{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetExternalAuthLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.ExternalAuthLink, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - gals := make([]database.ExternalAuthLink, 0) - for _, gal := range q.externalAuthLinks { - if gal.UserID == userID { - gals = append(gals, gal) - } - } - return gals, nil -} - -func (q *FakeQuerier) GetFailedWorkspaceBuildsByTemplateID(ctx context.Context, arg database.GetFailedWorkspaceBuildsByTemplateIDParams) ([]database.GetFailedWorkspaceBuildsByTemplateIDRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaceBuildStats := []database.GetFailedWorkspaceBuildsByTemplateIDRow{} - for _, wb := range q.workspaceBuilds { - job, err := q.getProvisionerJobByIDNoLock(ctx, wb.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job by ID: %w", err) - } - - if job.JobStatus != database.ProvisionerJobStatusFailed { - continue - } - - if !job.CompletedAt.Valid { - continue - } - - if wb.CreatedAt.Before(arg.Since) { - continue - } - - w, err := q.getWorkspaceByIDNoLock(ctx, wb.WorkspaceID) - if err != nil { - return nil, xerrors.Errorf("get workspace by ID: %w", err) - } - - t, err := q.getTemplateByIDNoLock(ctx, w.TemplateID) - if err != nil { - return nil, xerrors.Errorf("get template by ID: %w", err) - } - - if t.ID != arg.TemplateID { - continue - } - - workspaceOwner, err := q.getUserByIDNoLock(w.OwnerID) - if err != nil { - return nil, xerrors.Errorf("get user by ID: %w", err) - } - - templateVersion, err := q.getTemplateVersionByIDNoLock(ctx, wb.TemplateVersionID) - if err != nil { - return nil, xerrors.Errorf("get template version by ID: %w", err) - } - - workspaceBuildStats = append(workspaceBuildStats, database.GetFailedWorkspaceBuildsByTemplateIDRow{ - WorkspaceID: w.ID, - WorkspaceName: w.Name, - WorkspaceOwnerUsername: workspaceOwner.Username, - TemplateVersionName: templateVersion.Name, - WorkspaceBuildNumber: wb.BuildNumber, - }) - } - - sort.Slice(workspaceBuildStats, func(i, j int) bool { - if workspaceBuildStats[i].TemplateVersionName != workspaceBuildStats[j].TemplateVersionName { - return workspaceBuildStats[i].TemplateVersionName < workspaceBuildStats[j].TemplateVersionName - } - return workspaceBuildStats[i].WorkspaceBuildNumber > workspaceBuildStats[j].WorkspaceBuildNumber - }) - return workspaceBuildStats, nil -} - -func (q *FakeQuerier) GetFileByHashAndCreator(_ context.Context, arg database.GetFileByHashAndCreatorParams) (database.File, error) { - if err := validateDatabaseType(arg); err != nil { - return database.File{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, file := range q.files { - if file.Hash == arg.Hash && file.CreatedBy == arg.CreatedBy { - return file, nil - } - } - return database.File{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetFileByID(_ context.Context, id uuid.UUID) (database.File, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, file := range q.files { - if file.ID == id { - return file, nil - } - } - return database.File{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetFileIDByTemplateVersionID(ctx context.Context, templateVersionID uuid.UUID) (uuid.UUID, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, v := range q.templateVersions { - if v.ID == templateVersionID { - jobID := v.JobID - for _, j := range q.provisionerJobs { - if j.ID == jobID { - if j.StorageMethod == database.ProvisionerStorageMethodFile { - return j.FileID, nil - } - // We found the right job id but it wasn't a proper match. - break - } - } - // We found the right template version but it wasn't a proper match. - break - } - } - - return uuid.Nil, sql.ErrNoRows -} - -func (q *FakeQuerier) GetFileTemplates(_ context.Context, id uuid.UUID) ([]database.GetFileTemplatesRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - rows := make([]database.GetFileTemplatesRow, 0) - var file database.File - for _, f := range q.files { - if f.ID == id { - file = f - break - } - } - if file.Hash == "" { - return rows, nil - } - - for _, job := range q.provisionerJobs { - if job.FileID == id { - for _, version := range q.templateVersions { - if version.JobID == job.ID { - for _, template := range q.templates { - if template.ID == version.TemplateID.UUID { - rows = append(rows, database.GetFileTemplatesRow{ - FileID: file.ID, - FileCreatedBy: file.CreatedBy, - TemplateID: template.ID, - TemplateOrganizationID: template.OrganizationID, - TemplateCreatedBy: template.CreatedBy, - UserACL: template.UserACL, - GroupACL: template.GroupACL, - }) - } - } - } - } - } - } - - return rows, nil -} - -func (q *FakeQuerier) GetFilteredInboxNotificationsByUserID(_ context.Context, arg database.GetFilteredInboxNotificationsByUserIDParams) ([]database.InboxNotification, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - notifications := make([]database.InboxNotification, 0) - // TODO : after using go version >= 1.23 , we can change this one to https://pkg.go.dev/slices#Backward - for idx := len(q.inboxNotifications) - 1; idx >= 0; idx-- { - notification := q.inboxNotifications[idx] - - if notification.UserID == arg.UserID { - if !arg.CreatedAtOpt.IsZero() && !notification.CreatedAt.Before(arg.CreatedAtOpt) { - continue - } - - templateFound := false - for _, template := range arg.Templates { - if notification.TemplateID == template { - templateFound = true - } - } - - if len(arg.Templates) > 0 && !templateFound { - continue - } - - targetsFound := true - for _, target := range arg.Targets { - targetFound := false - for _, insertedTarget := range notification.Targets { - if insertedTarget == target { - targetFound = true - break - } - } - - if !targetFound { - targetsFound = false - break - } - } - - if !targetsFound { - continue - } - - if (arg.LimitOpt == 0 && len(notifications) == 25) || - (arg.LimitOpt != 0 && len(notifications) == int(arg.LimitOpt)) { - break - } - - notifications = append(notifications, notification) - } - } - - return notifications, nil -} - -func (q *FakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, key := range q.gitSSHKey { - if key.UserID == userID { - return key, nil - } - } - return database.GitSSHKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getGroupByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGroupByOrgAndNameParams) (database.Group, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Group{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, group := range q.groups { - if group.OrganizationID == arg.OrganizationID && - group.Name == arg.Name { - return group, nil - } - } - - return database.Group{}, sql.ErrNoRows -} - -//nolint:revive // It's not a control flag, its a filter -func (q *FakeQuerier) GetGroupMembers(ctx context.Context, includeSystem bool) ([]database.GroupMember, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - members := make([]database.GroupMemberTable, 0, len(q.groupMembers)) - members = append(members, q.groupMembers...) - for _, org := range q.organizations { - for _, user := range q.users { - if !includeSystem && user.IsSystem { - continue - } - members = append(members, database.GroupMemberTable{ - UserID: user.ID, - GroupID: org.ID, - }) - } - } - - var groupMembers []database.GroupMember - for _, member := range members { - groupMember, err := q.getGroupMemberNoLock(ctx, member.UserID, member.GroupID) - if errors.Is(err, errUserDeleted) { - continue - } - if err != nil { - return nil, err - } - groupMembers = append(groupMembers, groupMember) - } - - return groupMembers, nil -} - -func (q *FakeQuerier) GetGroupMembersByGroupID(ctx context.Context, arg database.GetGroupMembersByGroupIDParams) ([]database.GroupMember, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.isEveryoneGroup(arg.GroupID) { - return q.getEveryoneGroupMembersNoLock(ctx, arg.GroupID), nil - } - - var groupMembers []database.GroupMember - for _, member := range q.groupMembers { - if member.GroupID == arg.GroupID { - groupMember, err := q.getGroupMemberNoLock(ctx, member.UserID, member.GroupID) - if errors.Is(err, errUserDeleted) { - continue - } - if err != nil { - return nil, err - } - groupMembers = append(groupMembers, groupMember) - } - } - - return groupMembers, nil -} - -func (q *FakeQuerier) GetGroupMembersCountByGroupID(ctx context.Context, arg database.GetGroupMembersCountByGroupIDParams) (int64, error) { - users, err := q.GetGroupMembersByGroupID(ctx, database.GetGroupMembersByGroupIDParams(arg)) - if err != nil { - return 0, err - } - return int64(len(users)), nil -} - -func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) ([]database.GetGroupsRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - userGroupIDs := make(map[uuid.UUID]struct{}) - if arg.HasMemberID != uuid.Nil { - for _, member := range q.groupMembers { - if member.UserID == arg.HasMemberID { - userGroupIDs[member.GroupID] = struct{}{} - } - } - - // Handle the everyone group - for _, orgMember := range q.organizationMembers { - if orgMember.UserID == arg.HasMemberID { - userGroupIDs[orgMember.OrganizationID] = struct{}{} - } - } - } - - orgDetailsCache := make(map[uuid.UUID]struct{ name, displayName string }) - filtered := make([]database.GetGroupsRow, 0) - for _, group := range q.groups { - if len(arg.GroupIds) > 0 { - if !slices.Contains(arg.GroupIds, group.ID) { - continue - } - } - - if arg.OrganizationID != uuid.Nil && group.OrganizationID != arg.OrganizationID { - continue - } - - _, ok := userGroupIDs[group.ID] - if arg.HasMemberID != uuid.Nil && !ok { - continue - } - - if len(arg.GroupNames) > 0 && !slices.Contains(arg.GroupNames, group.Name) { - continue - } - - orgDetails, ok := orgDetailsCache[group.ID] - if !ok { - for _, org := range q.organizations { - if group.OrganizationID == org.ID { - orgDetails = struct{ name, displayName string }{ - name: org.Name, displayName: org.DisplayName, - } - break - } - } - orgDetailsCache[group.ID] = orgDetails - } - - filtered = append(filtered, database.GetGroupsRow{ - Group: group, - OrganizationName: orgDetails.name, - OrganizationDisplayName: orgDetails.displayName, - }) - } - - return filtered, nil -} - -func (q *FakeQuerier) GetHealthSettings(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.healthSettings == nil { - return "{}", nil - } - - return string(q.healthSettings), nil -} - -func (q *FakeQuerier) GetInboxNotificationByID(_ context.Context, id uuid.UUID) (database.InboxNotification, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, notification := range q.inboxNotifications { - if notification.ID == id { - return notification, nil - } - } - - return database.InboxNotification{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetInboxNotificationsByUserID(_ context.Context, params database.GetInboxNotificationsByUserIDParams) ([]database.InboxNotification, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - notifications := make([]database.InboxNotification, 0) - for _, notification := range q.inboxNotifications { - if notification.UserID == params.UserID { - notifications = append(notifications, notification) - } - } - - return notifications, nil -} - -func (q *FakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.lastUpdateCheck == nil { - return "", sql.ErrNoRows - } - return string(q.lastUpdateCheck), nil -} - -func (q *FakeQuerier) GetLatestCryptoKeyByFeature(_ context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var latestKey database.CryptoKey - for _, key := range q.cryptoKeys { - if key.Feature == feature && latestKey.Sequence < key.Sequence { - latestKey = key - } - } - if latestKey.StartsAt.IsZero() { - return database.CryptoKey{}, sql.ErrNoRows - } - return latestKey, nil -} - -func (q *FakeQuerier) GetLatestWorkspaceAppStatusesByWorkspaceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceAppStatus, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Map to track latest status per workspace ID - latestByWorkspace := make(map[uuid.UUID]database.WorkspaceAppStatus) - - // Find latest status for each workspace ID - for _, appStatus := range q.workspaceAppStatuses { - if !slices.Contains(ids, appStatus.WorkspaceID) { - continue - } - - current, exists := latestByWorkspace[appStatus.WorkspaceID] - if !exists || appStatus.CreatedAt.After(current.CreatedAt) { - latestByWorkspace[appStatus.WorkspaceID] = appStatus - } - } - - // Convert map to slice - appStatuses := make([]database.WorkspaceAppStatus, 0, len(latestByWorkspace)) - for _, status := range latestByWorkspace { - appStatuses = append(appStatuses, status) - } - - return appStatuses, nil -} - -func (q *FakeQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspaceID) -} - -func (q *FakeQuerier) GetLatestWorkspaceBuilds(_ context.Context) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - builds := make(map[uuid.UUID]database.WorkspaceBuild) - buildNumbers := make(map[uuid.UUID]int32) - for _, workspaceBuild := range q.workspaceBuilds { - id := workspaceBuild.WorkspaceID - if workspaceBuild.BuildNumber > buildNumbers[id] { - builds[id] = q.workspaceBuildWithUserNoLock(workspaceBuild) - buildNumbers[id] = workspaceBuild.BuildNumber - } - } - var returnBuilds []database.WorkspaceBuild - for i, n := range buildNumbers { - if n > 0 { - b := builds[i] - returnBuilds = append(returnBuilds, b) - } - } - if len(returnBuilds) == 0 { - return nil, sql.ErrNoRows - } - return returnBuilds, nil -} - -func (q *FakeQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - builds := make(map[uuid.UUID]database.WorkspaceBuild) - buildNumbers := make(map[uuid.UUID]int32) - for _, workspaceBuild := range q.workspaceBuilds { - for _, id := range ids { - if id == workspaceBuild.WorkspaceID && workspaceBuild.BuildNumber > buildNumbers[id] { - builds[id] = q.workspaceBuildWithUserNoLock(workspaceBuild) - buildNumbers[id] = workspaceBuild.BuildNumber - } - } - } - var returnBuilds []database.WorkspaceBuild - for i, n := range buildNumbers { - if n > 0 { - b := builds[i] - returnBuilds = append(returnBuilds, b) - } - } - if len(returnBuilds) == 0 { - return nil, sql.ErrNoRows - } - return returnBuilds, nil -} - -func (q *FakeQuerier) GetLicenseByID(_ context.Context, id int32) (database.License, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, license := range q.licenses { - if license.ID == id { - return license, nil - } - } - return database.License{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetLicenses(_ context.Context) ([]database.License, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - results := append([]database.License{}, q.licenses...) - sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID }) - return results, nil -} - -func (q *FakeQuerier) GetLogoURL(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.logoURL == "" { - return "", sql.ErrNoRows - } - - return q.logoURL, nil -} - -func (q *FakeQuerier) GetNotificationMessagesByStatus(_ context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - var out []database.NotificationMessage - for _, m := range q.notificationMessages { - if len(out) > int(arg.Limit) { - return out, nil - } - - if m.Status == arg.Status { - out = append(out, m) - } - } - - return out, nil -} - -func (q *FakeQuerier) GetNotificationReportGeneratorLogByTemplate(_ context.Context, templateID uuid.UUID) (database.NotificationReportGeneratorLog, error) { - err := validateDatabaseType(templateID) - if err != nil { - return database.NotificationReportGeneratorLog{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, record := range q.notificationReportGeneratorLogs { - if record.NotificationTemplateID == templateID { - return record, nil - } - } - return database.NotificationReportGeneratorLog{}, sql.ErrNoRows -} - -func (*FakeQuerier) GetNotificationTemplateByID(_ context.Context, _ uuid.UUID) (database.NotificationTemplate, error) { - // Not implementing this function because it relies on state in the database which is created with migrations. - // We could consider using code-generation to align the database state and dbmem, but it's not worth it right now. - return database.NotificationTemplate{}, ErrUnimplemented -} - -func (*FakeQuerier) GetNotificationTemplatesByKind(_ context.Context, _ database.NotificationTemplateKind) ([]database.NotificationTemplate, error) { - // Not implementing this function because it relies on state in the database which is created with migrations. - // We could consider using code-generation to align the database state and dbmem, but it's not worth it right now. - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) GetNotificationsSettings(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.notificationsSettings == nil { - return "{}", nil - } - - return string(q.notificationsSettings), nil -} - -func (q *FakeQuerier) GetOAuth2GithubDefaultEligible(_ context.Context) (bool, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.oauth2GithubDefaultEligible == nil { - return false, sql.ErrNoRows - } - return *q.oauth2GithubDefaultEligible, nil -} - -func (q *FakeQuerier) GetOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) (database.OAuth2ProviderApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, app := range q.oauth2ProviderApps { - if app.ID == id { - return app, nil - } - } - return database.OAuth2ProviderApp{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppByID(_ context.Context, id uuid.UUID) (database.OAuth2ProviderApp, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, app := range q.oauth2ProviderApps { - if app.ID == id { - return app, nil - } - } - return database.OAuth2ProviderApp{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken sql.NullString) (database.OAuth2ProviderApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, app := range q.data.oauth2ProviderApps { - if app.RegistrationAccessToken.Valid && registrationAccessToken.Valid && - app.RegistrationAccessToken.String == registrationAccessToken.String { - return app, nil - } - } - return database.OAuth2ProviderApp{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppCodeByID(_ context.Context, id uuid.UUID) (database.OAuth2ProviderAppCode, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, code := range q.oauth2ProviderAppCodes { - if code.ID == id { - return code, nil - } - } - return database.OAuth2ProviderAppCode{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppCodeByPrefix(_ context.Context, secretPrefix []byte) (database.OAuth2ProviderAppCode, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, code := range q.oauth2ProviderAppCodes { - if bytes.Equal(code.SecretPrefix, secretPrefix) { - return code, nil - } - } - return database.OAuth2ProviderAppCode{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppSecretByID(_ context.Context, id uuid.UUID) (database.OAuth2ProviderAppSecret, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, secret := range q.oauth2ProviderAppSecrets { - if secret.ID == id { - return secret, nil - } - } - return database.OAuth2ProviderAppSecret{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppSecretByPrefix(_ context.Context, secretPrefix []byte) (database.OAuth2ProviderAppSecret, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, secret := range q.oauth2ProviderAppSecrets { - if bytes.Equal(secret.SecretPrefix, secretPrefix) { - return secret, nil - } - } - return database.OAuth2ProviderAppSecret{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppSecretsByAppID(_ context.Context, appID uuid.UUID) ([]database.OAuth2ProviderAppSecret, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, app := range q.oauth2ProviderApps { - if app.ID == appID { - secrets := []database.OAuth2ProviderAppSecret{} - for _, secret := range q.oauth2ProviderAppSecrets { - if secret.AppID == appID { - secrets = append(secrets, secret) - } - } - - slices.SortFunc(secrets, func(a, b database.OAuth2ProviderAppSecret) int { - if a.CreatedAt.Before(b.CreatedAt) { - return -1 - } else if a.CreatedAt.Equal(b.CreatedAt) { - return 0 - } - return 1 - }) - return secrets, nil - } - } - - return []database.OAuth2ProviderAppSecret{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppTokenByAPIKeyID(_ context.Context, apiKeyID string) (database.OAuth2ProviderAppToken, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, token := range q.oauth2ProviderAppTokens { - if token.APIKeyID == apiKeyID { - return token, nil - } - } - - return database.OAuth2ProviderAppToken{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppTokenByPrefix(_ context.Context, hashPrefix []byte) (database.OAuth2ProviderAppToken, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, token := range q.oauth2ProviderAppTokens { - if bytes.Equal(token.HashPrefix, hashPrefix) { - return token, nil - } - } - return database.OAuth2ProviderAppToken{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderApps(_ context.Context) ([]database.OAuth2ProviderApp, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - slices.SortFunc(q.oauth2ProviderApps, func(a, b database.OAuth2ProviderApp) int { - return slice.Ascending(a.Name, b.Name) - }) - return q.oauth2ProviderApps, nil -} - -func (q *FakeQuerier) GetOAuth2ProviderAppsByUserID(_ context.Context, userID uuid.UUID) ([]database.GetOAuth2ProviderAppsByUserIDRow, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - rows := []database.GetOAuth2ProviderAppsByUserIDRow{} - for _, app := range q.oauth2ProviderApps { - tokens := []database.OAuth2ProviderAppToken{} - for _, secret := range q.oauth2ProviderAppSecrets { - if secret.AppID == app.ID { - for _, token := range q.oauth2ProviderAppTokens { - if token.AppSecretID == secret.ID { - keyIdx := slices.IndexFunc(q.apiKeys, func(key database.APIKey) bool { - return key.ID == token.APIKeyID - }) - if keyIdx != -1 && q.apiKeys[keyIdx].UserID == userID { - tokens = append(tokens, token) - } - } - } - } - } - if len(tokens) > 0 { - rows = append(rows, database.GetOAuth2ProviderAppsByUserIDRow{ - OAuth2ProviderApp: app, - TokenCount: int64(len(tokens)), - }) - } - } - return rows, nil -} - -func (q *FakeQuerier) GetOAuthSigningKey(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.oauthSigningKey, nil -} - -func (q *FakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getOrganizationByIDNoLock(id) -} - -func (q *FakeQuerier) GetOrganizationByName(_ context.Context, params database.GetOrganizationByNameParams) (database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, organization := range q.organizations { - if organization.Name == params.Name && organization.Deleted == params.Deleted { - return organization, nil - } - } - return database.Organization{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uuid.UUID) ([]database.GetOrganizationIDsByMemberIDsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - getOrganizationIDsByMemberIDRows := make([]database.GetOrganizationIDsByMemberIDsRow, 0, len(ids)) - for _, userID := range ids { - userOrganizationIDs := make([]uuid.UUID, 0) - for _, membership := range q.organizationMembers { - if membership.UserID == userID { - userOrganizationIDs = append(userOrganizationIDs, membership.OrganizationID) - } - } - getOrganizationIDsByMemberIDRows = append(getOrganizationIDsByMemberIDRows, database.GetOrganizationIDsByMemberIDsRow{ - UserID: userID, - OrganizationIDs: userOrganizationIDs, - }) - } - return getOrganizationIDsByMemberIDRows, nil -} - -func (q *FakeQuerier) GetOrganizationResourceCountByID(_ context.Context, organizationID uuid.UUID) (database.GetOrganizationResourceCountByIDRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspacesCount := 0 - for _, workspace := range q.workspaces { - if workspace.OrganizationID == organizationID { - workspacesCount++ - } - } - - groupsCount := 0 - for _, group := range q.groups { - if group.OrganizationID == organizationID { - groupsCount++ - } - } - - templatesCount := 0 - for _, template := range q.templates { - if template.OrganizationID == organizationID { - templatesCount++ - } - } - - organizationMembersCount := 0 - for _, organizationMember := range q.organizationMembers { - if organizationMember.OrganizationID == organizationID { - organizationMembersCount++ - } - } - - provKeyCount := 0 - for _, provKey := range q.provisionerKeys { - if provKey.OrganizationID == organizationID { - provKeyCount++ - } - } - - return database.GetOrganizationResourceCountByIDRow{ - WorkspaceCount: int64(workspacesCount), - GroupCount: int64(groupsCount), - TemplateCount: int64(templatesCount), - MemberCount: int64(organizationMembersCount), - ProvisionerKeyCount: int64(provKeyCount), - }, nil -} - -func (q *FakeQuerier) GetOrganizations(_ context.Context, args database.GetOrganizationsParams) ([]database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - tmp := make([]database.Organization, 0) - for _, org := range q.organizations { - if len(args.IDs) > 0 { - if !slices.Contains(args.IDs, org.ID) { - continue - } - } - if args.Name != "" && !strings.EqualFold(org.Name, args.Name) { - continue - } - if args.Deleted != org.Deleted { - continue - } - tmp = append(tmp, org) - } - - return tmp, nil -} - -func (q *FakeQuerier) GetOrganizationsByUserID(_ context.Context, arg database.GetOrganizationsByUserIDParams) ([]database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - organizations := make([]database.Organization, 0) - for _, organizationMember := range q.organizationMembers { - if organizationMember.UserID != arg.UserID { - continue - } - for _, organization := range q.organizations { - if organization.ID != organizationMember.OrganizationID { - continue - } - - if arg.Deleted.Valid && organization.Deleted != arg.Deleted.Bool { - continue - } - organizations = append(organizations, organization) - } - } - - return organizations, nil -} - -func (q *FakeQuerier) GetParameterSchemasByJobID(_ context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - parameters := make([]database.ParameterSchema, 0) - for _, parameterSchema := range q.parameterSchemas { - if parameterSchema.JobID != jobID { - continue - } - parameters = append(parameters, parameterSchema) - } - if len(parameters) == 0 { - return nil, sql.ErrNoRows - } - sort.Slice(parameters, func(i, j int) bool { - return parameters[i].Index < parameters[j].Index - }) - return parameters, nil -} - -func (*FakeQuerier) GetPrebuildMetrics(_ context.Context) ([]database.GetPrebuildMetricsRow, error) { - return make([]database.GetPrebuildMetricsRow, 0), nil -} - -func (q *FakeQuerier) GetPrebuildsSettings(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return string(slices.Clone(q.prebuildsSettings)), nil -} - -func (q *FakeQuerier) GetPresetByID(_ context.Context, presetID uuid.UUID) (database.GetPresetByIDRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - empty := database.GetPresetByIDRow{} - - // Create an index for faster lookup - versionMap := make(map[uuid.UUID]database.TemplateVersionTable) - for _, tv := range q.templateVersions { - versionMap[tv.ID] = tv - } - - for _, preset := range q.presets { - if preset.ID == presetID { - tv, ok := versionMap[preset.TemplateVersionID] - if !ok { - return empty, xerrors.Errorf("template version %v does not exist", preset.TemplateVersionID) - } - return database.GetPresetByIDRow{ - ID: preset.ID, - TemplateVersionID: preset.TemplateVersionID, - Name: preset.Name, - CreatedAt: preset.CreatedAt, - DesiredInstances: preset.DesiredInstances, - InvalidateAfterSecs: preset.InvalidateAfterSecs, - PrebuildStatus: preset.PrebuildStatus, - TemplateID: tv.TemplateID, - OrganizationID: tv.OrganizationID, - }, nil - } - } - - return empty, xerrors.Errorf("preset %v does not exist", presetID) -} - -func (q *FakeQuerier) GetPresetByWorkspaceBuildID(_ context.Context, workspaceBuildID uuid.UUID) (database.TemplateVersionPreset, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.ID != workspaceBuildID { - continue - } - for _, preset := range q.presets { - if preset.TemplateVersionID == workspaceBuild.TemplateVersionID { - return preset, nil - } - } - } - return database.TemplateVersionPreset{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetPresetParametersByPresetID(_ context.Context, presetID uuid.UUID) ([]database.TemplateVersionPresetParameter, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - parameters := make([]database.TemplateVersionPresetParameter, 0) - for _, parameter := range q.presetParameters { - if parameter.TemplateVersionPresetID != presetID { - continue - } - parameters = append(parameters, parameter) - } - - return parameters, nil -} - -func (q *FakeQuerier) GetPresetParametersByTemplateVersionID(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionPresetParameter, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - presets := make([]database.TemplateVersionPreset, 0) - parameters := make([]database.TemplateVersionPresetParameter, 0) - for _, preset := range q.presets { - if preset.TemplateVersionID != templateVersionID { - continue - } - presets = append(presets, preset) - } - for _, parameter := range q.presetParameters { - for _, preset := range presets { - if parameter.TemplateVersionPresetID != preset.ID { - continue - } - parameters = append(parameters, parameter) - } - } - - return parameters, nil -} - -func (q *FakeQuerier) GetPresetsAtFailureLimit(ctx context.Context, hardLimit int64) ([]database.GetPresetsAtFailureLimitRow, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetPresetsBackoff(_ context.Context, _ time.Time) ([]database.GetPresetsBackoffRow, error) { - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) GetPresetsByTemplateVersionID(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionPreset, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - presets := make([]database.TemplateVersionPreset, 0) - for _, preset := range q.presets { - if preset.TemplateVersionID == templateVersionID { - presets = append(presets, preset) - } - } - return presets, nil -} - -func (q *FakeQuerier) GetPreviousTemplateVersion(_ context.Context, arg database.GetPreviousTemplateVersionParams) (database.TemplateVersion, error) { - if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersion{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - var currentTemplateVersion database.TemplateVersion - for _, templateVersion := range q.templateVersions { - if templateVersion.TemplateID != arg.TemplateID { - continue - } - if templateVersion.Name != arg.Name { - continue - } - if templateVersion.OrganizationID != arg.OrganizationID { - continue - } - currentTemplateVersion = q.templateVersionWithUserNoLock(templateVersion) - break - } - - previousTemplateVersions := make([]database.TemplateVersion, 0) - for _, templateVersion := range q.templateVersions { - if templateVersion.ID == currentTemplateVersion.ID { - continue - } - if templateVersion.OrganizationID != arg.OrganizationID { - continue - } - if templateVersion.TemplateID != currentTemplateVersion.TemplateID { - continue - } - - if templateVersion.CreatedAt.Before(currentTemplateVersion.CreatedAt) { - previousTemplateVersions = append(previousTemplateVersions, q.templateVersionWithUserNoLock(templateVersion)) - } - } - - if len(previousTemplateVersions) == 0 { - return database.TemplateVersion{}, sql.ErrNoRows - } - - sort.Slice(previousTemplateVersions, func(i, j int) bool { - return previousTemplateVersions[i].CreatedAt.After(previousTemplateVersions[j].CreatedAt) - }) - - return previousTemplateVersions[0], nil -} - -func (q *FakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.ProvisionerDaemon, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if len(q.provisionerDaemons) == 0 { - // Returning err=nil here for consistency with real querier - return []database.ProvisionerDaemon{}, nil - } - // copy the data so that the caller can't manipulate any data inside dbmem - // after returning - out := make([]database.ProvisionerDaemon, len(q.provisionerDaemons)) - copy(out, q.provisionerDaemons) - for i := range out { - // maps are reference types, so we need to clone them - out[i].Tags = maps.Clone(out[i].Tags) - } - return out, nil -} - -func (q *FakeQuerier) GetProvisionerDaemonsByOrganization(_ context.Context, arg database.GetProvisionerDaemonsByOrganizationParams) ([]database.ProvisionerDaemon, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - daemons := make([]database.ProvisionerDaemon, 0) - for _, daemon := range q.provisionerDaemons { - if daemon.OrganizationID != arg.OrganizationID { - continue - } - // Special case for untagged provisioners: only match untagged jobs. - // Ref: coderd/database/queries/provisionerjobs.sql:24-30 - // CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb - // THEN nested.tags :: jsonb = @tags :: jsonb - if tagsEqual(arg.WantTags, tagsUntagged) && !tagsEqual(arg.WantTags, daemon.Tags) { - continue - } - // ELSE nested.tags :: jsonb <@ @tags :: jsonb - if !tagsSubset(arg.WantTags, daemon.Tags) { - continue - } - daemon.Tags = maps.Clone(daemon.Tags) - daemons = append(daemons, daemon) - } - - return daemons, nil -} - -func (q *FakeQuerier) GetProvisionerDaemonsWithStatusByOrganization(ctx context.Context, arg database.GetProvisionerDaemonsWithStatusByOrganizationParams) ([]database.GetProvisionerDaemonsWithStatusByOrganizationRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - var rows []database.GetProvisionerDaemonsWithStatusByOrganizationRow - for _, daemon := range q.provisionerDaemons { - if daemon.OrganizationID != arg.OrganizationID { - continue - } - if len(arg.IDs) > 0 && !slices.Contains(arg.IDs, daemon.ID) { - continue - } - - if len(arg.Tags) > 0 { - // Special case for untagged provisioners: only match untagged jobs. - // Ref: coderd/database/queries/provisionerjobs.sql:24-30 - // CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb - // THEN nested.tags :: jsonb = @tags :: jsonb - if tagsEqual(arg.Tags, tagsUntagged) && !tagsEqual(arg.Tags, daemon.Tags) { - continue - } - // ELSE nested.tags :: jsonb <@ @tags :: jsonb - if !tagsSubset(arg.Tags, daemon.Tags) { - continue - } - } - - var status database.ProvisionerDaemonStatus - var currentJob database.ProvisionerJob - if !daemon.LastSeenAt.Valid || daemon.LastSeenAt.Time.Before(time.Now().Add(-time.Duration(arg.StaleIntervalMS)*time.Millisecond)) { - status = database.ProvisionerDaemonStatusOffline - } else { - for _, job := range q.provisionerJobs { - if job.WorkerID.Valid && job.WorkerID.UUID == daemon.ID && !job.CompletedAt.Valid && !job.Error.Valid { - currentJob = job - break - } - } - - if currentJob.ID != uuid.Nil { - status = database.ProvisionerDaemonStatusBusy - } else { - status = database.ProvisionerDaemonStatusIdle - } - } - var currentTemplate database.Template - if currentJob.ID != uuid.Nil { - var input codersdk.ProvisionerJobInput - err := json.Unmarshal(currentJob.Input, &input) - if err != nil { - return nil, err - } - if input.WorkspaceBuildID != nil { - b, err := q.getWorkspaceBuildByIDNoLock(ctx, *input.WorkspaceBuildID) - if err != nil { - return nil, err - } - input.TemplateVersionID = &b.TemplateVersionID - } - if input.TemplateVersionID != nil { - v, err := q.getTemplateVersionByIDNoLock(ctx, *input.TemplateVersionID) - if err != nil { - return nil, err - } - currentTemplate, err = q.getTemplateByIDNoLock(ctx, v.TemplateID.UUID) - if err != nil { - return nil, err - } - } - } - - var previousJob database.ProvisionerJob - for _, job := range q.provisionerJobs { - if !job.WorkerID.Valid || job.WorkerID.UUID != daemon.ID { - continue - } - - if job.StartedAt.Valid || - job.CanceledAt.Valid || - job.CompletedAt.Valid || - job.Error.Valid { - if job.CompletedAt.Time.After(previousJob.CompletedAt.Time) { - previousJob = job - } - } - } - var previousTemplate database.Template - if previousJob.ID != uuid.Nil { - var input codersdk.ProvisionerJobInput - err := json.Unmarshal(previousJob.Input, &input) - if err != nil { - return nil, err - } - if input.WorkspaceBuildID != nil { - b, err := q.getWorkspaceBuildByIDNoLock(ctx, *input.WorkspaceBuildID) - if err != nil { - return nil, err - } - input.TemplateVersionID = &b.TemplateVersionID - } - if input.TemplateVersionID != nil { - v, err := q.getTemplateVersionByIDNoLock(ctx, *input.TemplateVersionID) - if err != nil { - return nil, err - } - previousTemplate, err = q.getTemplateByIDNoLock(ctx, v.TemplateID.UUID) - if err != nil { - return nil, err - } - } - } - - // Get the provisioner key name - var keyName string - for _, key := range q.provisionerKeys { - if key.ID == daemon.KeyID { - keyName = key.Name - break - } - } - - rows = append(rows, database.GetProvisionerDaemonsWithStatusByOrganizationRow{ - ProvisionerDaemon: daemon, - Status: status, - KeyName: keyName, - CurrentJobID: uuid.NullUUID{UUID: currentJob.ID, Valid: currentJob.ID != uuid.Nil}, - CurrentJobStatus: database.NullProvisionerJobStatus{ProvisionerJobStatus: currentJob.JobStatus, Valid: currentJob.ID != uuid.Nil}, - CurrentJobTemplateName: currentTemplate.Name, - CurrentJobTemplateDisplayName: currentTemplate.DisplayName, - CurrentJobTemplateIcon: currentTemplate.Icon, - PreviousJobID: uuid.NullUUID{UUID: previousJob.ID, Valid: previousJob.ID != uuid.Nil}, - PreviousJobStatus: database.NullProvisionerJobStatus{ProvisionerJobStatus: previousJob.JobStatus, Valid: previousJob.ID != uuid.Nil}, - PreviousJobTemplateName: previousTemplate.Name, - PreviousJobTemplateDisplayName: previousTemplate.DisplayName, - PreviousJobTemplateIcon: previousTemplate.Icon, - }) - } - - slices.SortFunc(rows, func(a, b database.GetProvisionerDaemonsWithStatusByOrganizationRow) int { - return b.ProvisionerDaemon.CreatedAt.Compare(a.ProvisionerDaemon.CreatedAt) - }) - - if arg.Limit.Valid && arg.Limit.Int32 > 0 && len(rows) > int(arg.Limit.Int32) { - rows = rows[:arg.Limit.Int32] - } - - return rows, nil -} - -func (q *FakeQuerier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getProvisionerJobByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetProvisionerJobByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getProvisionerJobByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetProvisionerJobTimingsByJobID(_ context.Context, jobID uuid.UUID) ([]database.ProvisionerJobTiming, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - timings := make([]database.ProvisionerJobTiming, 0) - for _, timing := range q.provisionerJobTimings { - if timing.JobID == jobID { - timings = append(timings, timing) - } - } - if len(timings) == 0 { - return nil, sql.ErrNoRows - } - sort.Slice(timings, func(i, j int) bool { - return timings[i].StartedAt.Before(timings[j].StartedAt) - }) - - return timings, nil -} - -func (q *FakeQuerier) GetProvisionerJobsByIDs(_ context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - jobs := make([]database.ProvisionerJob, 0) - for _, job := range q.provisionerJobs { - for _, id := range ids { - if id == job.ID { - // clone the Tags before appending, since maps are reference types and - // we don't want the caller to be able to mutate the map we have inside - // dbmem! - job.Tags = maps.Clone(job.Tags) - jobs = append(jobs, job) - break - } - } - } - if len(jobs) == 0 { - return nil, sql.ErrNoRows - } - - return jobs, nil -} - -func (q *FakeQuerier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if arg.IDs == nil { - arg.IDs = []uuid.UUID{} - } - return q.getProvisionerJobsByIDsWithQueuePositionLockedTagBasedQueue(ctx, arg.IDs) -} - -func (q *FakeQuerier) GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner(ctx context.Context, arg database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerParams) ([]database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - /* - -- name: GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner :many - WITH pending_jobs AS ( - SELECT - id, created_at - FROM - provisioner_jobs - WHERE - started_at IS NULL - AND - canceled_at IS NULL - AND - completed_at IS NULL - AND - error IS NULL - ), - queue_position AS ( - SELECT - id, - ROW_NUMBER() OVER (ORDER BY created_at ASC) AS queue_position - FROM - pending_jobs - ), - queue_size AS ( - SELECT COUNT(*) AS count FROM pending_jobs - ) - SELECT - sqlc.embed(pj), - COALESCE(qp.queue_position, 0) AS queue_position, - COALESCE(qs.count, 0) AS queue_size, - array_agg(DISTINCT pd.id) FILTER (WHERE pd.id IS NOT NULL)::uuid[] AS available_workers - FROM - provisioner_jobs pj - LEFT JOIN - queue_position qp ON qp.id = pj.id - LEFT JOIN - queue_size qs ON TRUE - LEFT JOIN - provisioner_daemons pd ON ( - -- See AcquireProvisionerJob. - pj.started_at IS NULL - AND pj.organization_id = pd.organization_id - AND pj.provisioner = ANY(pd.provisioners) - AND provisioner_tagset_contains(pd.tags, pj.tags) - ) - WHERE - (sqlc.narg('organization_id')::uuid IS NULL OR pj.organization_id = @organization_id) - AND (COALESCE(array_length(@status::provisioner_job_status[], 1), 1) > 0 OR pj.job_status = ANY(@status::provisioner_job_status[])) - GROUP BY - pj.id, - qp.queue_position, - qs.count - ORDER BY - pj.created_at DESC - LIMIT - sqlc.narg('limit')::int; - */ - rowsWithQueuePosition, err := q.getProvisionerJobsByIDsWithQueuePositionLockedGlobalQueue(ctx, nil) - if err != nil { - return nil, err - } - - var rows []database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow - for _, rowQP := range rowsWithQueuePosition { - job := rowQP.ProvisionerJob - - if job.OrganizationID != arg.OrganizationID { - continue - } - if len(arg.Status) > 0 && !slices.Contains(arg.Status, job.JobStatus) { - continue - } - if len(arg.IDs) > 0 && !slices.Contains(arg.IDs, job.ID) { - continue - } - if len(arg.Tags) > 0 && !tagsSubset(job.Tags, arg.Tags) { - continue - } - - row := database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow{ - ProvisionerJob: rowQP.ProvisionerJob, - QueuePosition: rowQP.QueuePosition, - QueueSize: rowQP.QueueSize, - } - - // Start add metadata. - var input codersdk.ProvisionerJobInput - err := json.Unmarshal([]byte(job.Input), &input) - if err != nil { - return nil, err - } - templateVersionID := input.TemplateVersionID - if input.WorkspaceBuildID != nil { - workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(ctx, *input.WorkspaceBuildID) - if err != nil { - return nil, err - } - workspace, err := q.getWorkspaceByIDNoLock(ctx, workspaceBuild.WorkspaceID) - if err != nil { - return nil, err - } - row.WorkspaceID = uuid.NullUUID{UUID: workspace.ID, Valid: true} - row.WorkspaceName = workspace.Name - if templateVersionID == nil { - templateVersionID = &workspaceBuild.TemplateVersionID - } - } - if templateVersionID != nil { - templateVersion, err := q.getTemplateVersionByIDNoLock(ctx, *templateVersionID) - if err != nil { - return nil, err - } - row.TemplateVersionName = templateVersion.Name - template, err := q.getTemplateByIDNoLock(ctx, templateVersion.TemplateID.UUID) - if err != nil { - return nil, err - } - row.TemplateID = uuid.NullUUID{UUID: template.ID, Valid: true} - row.TemplateName = template.Name - row.TemplateDisplayName = template.DisplayName - } - // End add metadata. - - if row.QueuePosition > 0 { - var availableWorkers []database.ProvisionerDaemon - for _, daemon := range q.provisionerDaemons { - if daemon.OrganizationID == job.OrganizationID && slices.Contains(daemon.Provisioners, job.Provisioner) { - if tagsEqual(job.Tags, tagsUntagged) { - if tagsEqual(job.Tags, daemon.Tags) { - availableWorkers = append(availableWorkers, daemon) - } - } else if tagsSubset(job.Tags, daemon.Tags) { - availableWorkers = append(availableWorkers, daemon) - } - } - } - slices.SortFunc(availableWorkers, func(a, b database.ProvisionerDaemon) int { - return a.CreatedAt.Compare(b.CreatedAt) - }) - for _, worker := range availableWorkers { - row.AvailableWorkers = append(row.AvailableWorkers, worker.ID) - } - } - - // Add daemon name to provisioner job - for _, daemon := range q.provisionerDaemons { - if job.WorkerID.Valid && job.WorkerID.UUID == daemon.ID { - row.WorkerName = daemon.Name - } - } - rows = append(rows, row) - } - - slices.SortFunc(rows, func(a, b database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow) int { - return b.ProvisionerJob.CreatedAt.Compare(a.ProvisionerJob.CreatedAt) - }) - if arg.Limit.Valid && arg.Limit.Int32 > 0 && len(rows) > int(arg.Limit.Int32) { - rows = rows[:arg.Limit.Int32] - } - return rows, nil -} - -func (q *FakeQuerier) GetProvisionerJobsCreatedAfter(_ context.Context, after time.Time) ([]database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - jobs := make([]database.ProvisionerJob, 0) - for _, job := range q.provisionerJobs { - if job.CreatedAt.After(after) { - // clone the Tags before appending, since maps are reference types and - // we don't want the caller to be able to mutate the map we have inside - // dbmem! - job.Tags = maps.Clone(job.Tags) - jobs = append(jobs, job) - } - } - return jobs, nil -} - -func (q *FakeQuerier) GetProvisionerJobsToBeReaped(_ context.Context, arg database.GetProvisionerJobsToBeReapedParams) ([]database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - maxJobs := arg.MaxJobs - - hungJobs := []database.ProvisionerJob{} - for _, provisionerJob := range q.provisionerJobs { - if !provisionerJob.CompletedAt.Valid { - if (provisionerJob.StartedAt.Valid && provisionerJob.UpdatedAt.Before(arg.HungSince)) || - (!provisionerJob.StartedAt.Valid && provisionerJob.UpdatedAt.Before(arg.PendingSince)) { - // clone the Tags before appending, since maps are reference types and - // we don't want the caller to be able to mutate the map we have inside - // dbmem! - provisionerJob.Tags = maps.Clone(provisionerJob.Tags) - hungJobs = append(hungJobs, provisionerJob) - if len(hungJobs) >= int(maxJobs) { - break - } - } - } - } - insecurerand.Shuffle(len(hungJobs), func(i, j int) { - hungJobs[i], hungJobs[j] = hungJobs[j], hungJobs[i] - }) - return hungJobs, nil -} - -func (q *FakeQuerier) GetProvisionerKeyByHashedSecret(_ context.Context, hashedSecret []byte) (database.ProvisionerKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, key := range q.provisionerKeys { - if bytes.Equal(key.HashedSecret, hashedSecret) { - return key, nil - } - } - - return database.ProvisionerKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetProvisionerKeyByID(_ context.Context, id uuid.UUID) (database.ProvisionerKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, key := range q.provisionerKeys { - if key.ID == id { - return key, nil - } - } - - return database.ProvisionerKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetProvisionerKeyByName(_ context.Context, arg database.GetProvisionerKeyByNameParams) (database.ProvisionerKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, key := range q.provisionerKeys { - if strings.EqualFold(key.Name, arg.Name) && key.OrganizationID == arg.OrganizationID { - return key, nil - } - } - - return database.ProvisionerKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetProvisionerLogsAfterID(_ context.Context, arg database.GetProvisionerLogsAfterIDParams) ([]database.ProvisionerJobLog, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - logs := make([]database.ProvisionerJobLog, 0) - for _, jobLog := range q.provisionerJobLogs { - if jobLog.JobID != arg.JobID { - continue - } - if jobLog.ID <= arg.CreatedAfter { - continue - } - logs = append(logs, jobLog) - } - return logs, nil -} - -func (q *FakeQuerier) GetQuotaAllowanceForUser(_ context.Context, params database.GetQuotaAllowanceForUserParams) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var sum int64 - for _, member := range q.groupMembers { - if member.UserID != params.UserID { - continue - } - if _, err := q.getOrganizationByIDNoLock(member.GroupID); err == nil { - // This should never happen, but it has been reported in customer deployments. - // The SQL handles this case, and omits `group_members` rows in the - // Everyone group. It counts these distinctly via `organization_members` table. - continue - } - for _, group := range q.groups { - if group.ID == member.GroupID { - sum += int64(group.QuotaAllowance) - continue - } - } - } - - // Grab the quota for the Everyone group iff the user is a member of - // said organization. - for _, mem := range q.organizationMembers { - if mem.UserID != params.UserID { - continue - } - - group, err := q.getGroupByIDNoLock(context.Background(), mem.OrganizationID) - if err != nil { - return -1, xerrors.Errorf("failed to get everyone group for org %q", mem.OrganizationID.String()) - } - if group.OrganizationID != params.OrganizationID { - continue - } - sum += int64(group.QuotaAllowance) - } - - return sum, nil -} - -func (q *FakeQuerier) GetQuotaConsumedForUser(_ context.Context, params database.GetQuotaConsumedForUserParams) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var sum int64 - for _, workspace := range q.workspaces { - if workspace.OwnerID != params.OwnerID { - continue - } - if workspace.OrganizationID != params.OrganizationID { - continue - } - if workspace.Deleted { - continue - } - - var lastBuild database.WorkspaceBuild - for _, build := range q.workspaceBuilds { - if build.WorkspaceID != workspace.ID { - continue - } - if build.CreatedAt.After(lastBuild.CreatedAt) { - lastBuild = build - } - } - sum += int64(lastBuild.DailyCost) - } - return sum, nil -} - -func (q *FakeQuerier) GetReplicaByID(_ context.Context, id uuid.UUID) (database.Replica, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, replica := range q.replicas { - if replica.ID == id { - return replica, nil - } - } - - return database.Replica{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetReplicasUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.Replica, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - replicas := make([]database.Replica, 0) - for _, replica := range q.replicas { - if replica.UpdatedAt.After(updatedAt) && !replica.StoppedAt.Valid { - replicas = append(replicas, replica) - } - } - return replicas, nil -} - -func (q *FakeQuerier) GetRunningPrebuiltWorkspaces(ctx context.Context) ([]database.GetRunningPrebuiltWorkspacesRow, error) { - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) GetRuntimeConfig(_ context.Context, key string) (string, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - val, ok := q.runtimeConfig[key] - if !ok { - return "", sql.ErrNoRows - } - - return val, nil -} - -func (*FakeQuerier) GetTailnetAgents(context.Context, uuid.UUID) ([]database.TailnetAgent, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetTailnetClientsForAgent(context.Context, uuid.UUID) ([]database.TailnetClient, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetTailnetPeers(context.Context, uuid.UUID) ([]database.TailnetPeer, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetTailnetTunnelPeerBindings(context.Context, uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsRow, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetTailnetTunnelPeerIDs(context.Context, uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) { - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) GetTelemetryItem(_ context.Context, key string) (database.TelemetryItem, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, item := range q.telemetryItems { - if item.Key == key { - return item, nil - } - } - - return database.TelemetryItem{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTelemetryItems(_ context.Context) ([]database.TelemetryItem, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - return slices.Clone(q.telemetryItems), nil -} - -func (q *FakeQuerier) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - /* - WITH - */ - - /* - -- Create a list of all unique apps by template, this is used to - -- filter out irrelevant template usage stats. - apps AS ( - SELECT DISTINCT ON (ws.template_id, app.slug) - ws.template_id, - app.slug, - app.display_name, - app.icon - FROM - workspaces ws - JOIN - workspace_builds AS build - ON - build.workspace_id = ws.id - JOIN - workspace_resources AS resource - ON - resource.job_id = build.job_id - JOIN - workspace_agents AS agent - ON - agent.resource_id = resource.id - JOIN - workspace_apps AS app - ON - app.agent_id = agent.id - WHERE - -- Partial query parameter filter. - CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN ws.template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - ORDER BY - ws.template_id, app.slug, app.created_at DESC - ), - -- Join apps and template usage stats to filter out irrelevant rows. - -- Note that this way of joining will eliminate all data-points that - -- aren't for "real" apps. That means ports are ignored (even though - -- they're part of the dataset), as well as are "[terminal]" entries - -- which are alternate datapoints for reconnecting pty usage. - template_usage_stats_with_apps AS ( - SELECT - tus.start_time, - tus.template_id, - tus.user_id, - apps.slug, - apps.display_name, - apps.icon, - tus.app_usage_mins - FROM - apps - JOIN - template_usage_stats AS tus - ON - -- Query parameter filter. - tus.start_time >= @start_time::timestamptz - AND tus.end_time <= @end_time::timestamptz - AND CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN tus.template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - -- Primary join condition. - AND tus.template_id = apps.template_id - AND apps.slug IN (SELECT jsonb_object_keys(tus.app_usage_mins)) - ), - -- Group the app insights by interval, user and unique app. This - -- allows us to deduplicate a user using the same app across - -- multiple templates. - app_insights AS ( - SELECT - user_id, - slug, - display_name, - icon, - -- See motivation in GetTemplateInsights for LEAST(SUM(n), 30). - LEAST(SUM(app_usage.value::smallint), 30) AS usage_mins - FROM - template_usage_stats_with_apps, jsonb_each(app_usage_mins) AS app_usage - WHERE - app_usage.key = slug - GROUP BY - start_time, user_id, slug, display_name, icon - ), - -- Analyze the users unique app usage across all templates. Count - -- usage across consecutive intervals as continuous usage. - times_used AS ( - SELECT DISTINCT ON (user_id, slug, display_name, icon, uniq) - slug, - display_name, - icon, - -- Turn start_time into a unique identifier that identifies a users - -- continuous app usage. The value of uniq is otherwise garbage. - -- - -- Since we're aggregating per user app usage across templates, - -- there can be duplicate start_times. To handle this, we use the - -- dense_rank() function, otherwise row_number() would suffice. - start_time - ( - dense_rank() OVER ( - PARTITION BY - user_id, slug, display_name, icon - ORDER BY - start_time - ) * '30 minutes'::interval - ) AS uniq - FROM - template_usage_stats_with_apps - ), - */ - - // Due to query optimizations, this logic is somewhat inverted from - // the above query. - type appInsightsGroupBy struct { - StartTime time.Time - UserID uuid.UUID - Slug string - DisplayName string - Icon string - } - type appTimesUsedGroupBy struct { - UserID uuid.UUID - Slug string - DisplayName string - Icon string - } - type appInsightsRow struct { - appInsightsGroupBy - TemplateIDs []uuid.UUID - AppUsageMins int64 - } - appInsightRows := make(map[appInsightsGroupBy]appInsightsRow) - appTimesUsedRows := make(map[appTimesUsedGroupBy]map[time.Time]struct{}) - // FROM - for _, stat := range q.templateUsageStats { - // WHERE - if stat.StartTime.Before(arg.StartTime) || stat.EndTime.After(arg.EndTime) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, stat.TemplateID) { - continue - } - - // json_each - for slug, appUsage := range stat.AppUsageMins { - // FROM apps JOIN template_usage_stats - app, _ := q.getLatestWorkspaceAppByTemplateIDUserIDSlugNoLock(ctx, stat.TemplateID, stat.UserID, slug) - if app.Slug == "" { - continue - } - - // SELECT - key := appInsightsGroupBy{ - StartTime: stat.StartTime, - UserID: stat.UserID, - Slug: slug, - DisplayName: app.DisplayName, - Icon: app.Icon, - } - row, ok := appInsightRows[key] - if !ok { - row = appInsightsRow{ - appInsightsGroupBy: key, - } - } - row.TemplateIDs = append(row.TemplateIDs, stat.TemplateID) - row.AppUsageMins = least(row.AppUsageMins+appUsage, 30) - appInsightRows[key] = row - - // Prepare to do times_used calculation, distinct start times. - timesUsedKey := appTimesUsedGroupBy{ - UserID: stat.UserID, - Slug: slug, - DisplayName: app.DisplayName, - Icon: app.Icon, - } - if appTimesUsedRows[timesUsedKey] == nil { - appTimesUsedRows[timesUsedKey] = make(map[time.Time]struct{}) - } - // This assigns a distinct time, so we don't need to - // dense_rank() later on, we can simply do row_number(). - appTimesUsedRows[timesUsedKey][stat.StartTime] = struct{}{} - } - } - - appTimesUsedTempRows := make(map[appTimesUsedGroupBy][]time.Time) - for key, times := range appTimesUsedRows { - for t := range times { - appTimesUsedTempRows[key] = append(appTimesUsedTempRows[key], t) - } - } - for _, times := range appTimesUsedTempRows { - slices.SortFunc(times, func(a, b time.Time) int { - return int(a.Sub(b)) - }) - } - for key, times := range appTimesUsedTempRows { - uniq := make(map[time.Time]struct{}) - for i, t := range times { - uniq[t.Add(-(30 * time.Minute * time.Duration(i)))] = struct{}{} - } - appTimesUsedRows[key] = uniq - } - - /* - -- Even though we allow identical apps to be aggregated across - -- templates, we still want to be able to report which templates - -- the data comes from. - templates AS ( - SELECT - slug, - display_name, - icon, - array_agg(DISTINCT template_id)::uuid[] AS template_ids - FROM - template_usage_stats_with_apps - GROUP BY - slug, display_name, icon - ) - */ - - type appGroupBy struct { - Slug string - DisplayName string - Icon string - } - type templateRow struct { - appGroupBy - TemplateIDs []uuid.UUID - } - - templateRows := make(map[appGroupBy]templateRow) - for _, aiRow := range appInsightRows { - key := appGroupBy{ - Slug: aiRow.Slug, - DisplayName: aiRow.DisplayName, - Icon: aiRow.Icon, - } - row, ok := templateRows[key] - if !ok { - row = templateRow{ - appGroupBy: key, - } - } - row.TemplateIDs = uniqueSortedUUIDs(append(row.TemplateIDs, aiRow.TemplateIDs...)) - templateRows[key] = row - } - - /* - SELECT - t.template_ids, - COUNT(DISTINCT ai.user_id) AS active_users, - ai.slug, - ai.display_name, - ai.icon, - (SUM(ai.usage_mins) * 60)::bigint AS usage_seconds - FROM - app_insights AS ai - JOIN - templates AS t - ON - t.slug = ai.slug - AND t.display_name = ai.display_name - AND t.icon = ai.icon - GROUP BY - t.template_ids, ai.slug, ai.display_name, ai.icon; - */ - - type templateAppInsightsRow struct { - TemplateIDs []uuid.UUID - ActiveUserIDs []uuid.UUID - UsageSeconds int64 - } - groupedRows := make(map[appGroupBy]templateAppInsightsRow) - for _, aiRow := range appInsightRows { - key := appGroupBy{ - Slug: aiRow.Slug, - DisplayName: aiRow.DisplayName, - Icon: aiRow.Icon, - } - row := groupedRows[key] - row.ActiveUserIDs = append(row.ActiveUserIDs, aiRow.UserID) - row.UsageSeconds += aiRow.AppUsageMins * 60 - groupedRows[key] = row - } - - var rows []database.GetTemplateAppInsightsRow - for key, gr := range groupedRows { - row := database.GetTemplateAppInsightsRow{ - TemplateIDs: templateRows[key].TemplateIDs, - ActiveUsers: int64(len(uniqueSortedUUIDs(gr.ActiveUserIDs))), - Slug: key.Slug, - DisplayName: key.DisplayName, - Icon: key.Icon, - UsageSeconds: gr.UsageSeconds, - } - for tuk, uniq := range appTimesUsedRows { - if key.Slug == tuk.Slug && key.DisplayName == tuk.DisplayName && key.Icon == tuk.Icon { - row.TimesUsed += int64(len(uniq)) - } - } - rows = append(rows, row) - } - - // NOTE(mafredri): Add sorting if we decide on how to handle PostgreSQL collations. - // ORDER BY slug_or_port, display_name, icon, is_app - return rows, nil -} - -func (q *FakeQuerier) GetTemplateAppInsightsByTemplate(ctx context.Context, arg database.GetTemplateAppInsightsByTemplateParams) ([]database.GetTemplateAppInsightsByTemplateRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - type uniqueKey struct { - TemplateID uuid.UUID - DisplayName string - Slug string - } - - // map (TemplateID + DisplayName + Slug) x time.Time x UserID x - usageByTemplateAppUser := map[uniqueKey]map[time.Time]map[uuid.UUID]int64{} - - // Review agent stats in terms of usage - for _, s := range q.workspaceAppStats { - // (was.session_started_at >= ts.from_ AND was.session_started_at < ts.to_) - // OR (was.session_ended_at > ts.from_ AND was.session_ended_at < ts.to_) - // OR (was.session_started_at < ts.from_ AND was.session_ended_at >= ts.to_) - if !(((s.SessionStartedAt.After(arg.StartTime) || s.SessionStartedAt.Equal(arg.StartTime)) && s.SessionStartedAt.Before(arg.EndTime)) || - (s.SessionEndedAt.After(arg.StartTime) && s.SessionEndedAt.Before(arg.EndTime)) || - (s.SessionStartedAt.Before(arg.StartTime) && (s.SessionEndedAt.After(arg.EndTime) || s.SessionEndedAt.Equal(arg.EndTime)))) { - continue - } - - w, err := q.getWorkspaceByIDNoLock(ctx, s.WorkspaceID) - if err != nil { - return nil, err - } - - app, _ := q.getWorkspaceAppByAgentIDAndSlugNoLock(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{ - AgentID: s.AgentID, - Slug: s.SlugOrPort, - }) - - key := uniqueKey{ - TemplateID: w.TemplateID, - DisplayName: app.DisplayName, - Slug: app.Slug, - } - - t := s.SessionStartedAt.Truncate(time.Minute) - if t.Before(arg.StartTime) { - t = arg.StartTime - } - for t.Before(s.SessionEndedAt) && t.Before(arg.EndTime) { - if _, ok := usageByTemplateAppUser[key]; !ok { - usageByTemplateAppUser[key] = map[time.Time]map[uuid.UUID]int64{} - } - if _, ok := usageByTemplateAppUser[key][t]; !ok { - usageByTemplateAppUser[key][t] = map[uuid.UUID]int64{} - } - if _, ok := usageByTemplateAppUser[key][t][s.UserID]; !ok { - usageByTemplateAppUser[key][t][s.UserID] = 60 // 1 minute - } - t = t.Add(1 * time.Minute) - } - } - - // Sort usage data - usageKeys := make([]uniqueKey, len(usageByTemplateAppUser)) - var i int - for key := range usageByTemplateAppUser { - usageKeys[i] = key - i++ - } - - slices.SortFunc(usageKeys, func(a, b uniqueKey) int { - if a.TemplateID != b.TemplateID { - return slice.Ascending(a.TemplateID.String(), b.TemplateID.String()) - } - if a.DisplayName != b.DisplayName { - return slice.Ascending(a.DisplayName, b.DisplayName) - } - return slice.Ascending(a.Slug, b.Slug) - }) - - // Build result - var result []database.GetTemplateAppInsightsByTemplateRow - for _, usageKey := range usageKeys { - r := database.GetTemplateAppInsightsByTemplateRow{ - TemplateID: usageKey.TemplateID, - DisplayName: usageKey.DisplayName, - SlugOrPort: usageKey.Slug, - } - for _, mUserUsage := range usageByTemplateAppUser[usageKey] { - r.ActiveUsers += int64(len(mUserUsage)) - for _, usage := range mUserUsage { - r.UsageSeconds += usage - } - } - result = append(result, r) - } - return result, nil -} - -func (q *FakeQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { - if err := validateDatabaseType(arg); err != nil { - return database.GetTemplateAverageBuildTimeRow{}, err - } - - var emptyRow database.GetTemplateAverageBuildTimeRow - var ( - startTimes []float64 - stopTimes []float64 - deleteTimes []float64 - ) - q.mutex.RLock() - defer q.mutex.RUnlock() - for _, wb := range q.workspaceBuilds { - version, err := q.getTemplateVersionByIDNoLock(ctx, wb.TemplateVersionID) - if err != nil { - return emptyRow, err - } - if version.TemplateID != arg.TemplateID { - continue - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, wb.JobID) - if err != nil { - return emptyRow, err - } - if job.CompletedAt.Valid { - took := job.CompletedAt.Time.Sub(job.StartedAt.Time).Seconds() - switch wb.Transition { - case database.WorkspaceTransitionStart: - startTimes = append(startTimes, took) - case database.WorkspaceTransitionStop: - stopTimes = append(stopTimes, took) - case database.WorkspaceTransitionDelete: - deleteTimes = append(deleteTimes, took) - } - } - } - - var row database.GetTemplateAverageBuildTimeRow - row.Delete50, row.Delete95 = tryPercentileDisc(deleteTimes, 50), tryPercentileDisc(deleteTimes, 95) - row.Stop50, row.Stop95 = tryPercentileDisc(stopTimes, 50), tryPercentileDisc(stopTimes, 95) - row.Start50, row.Start95 = tryPercentileDisc(startTimes, 50), tryPercentileDisc(startTimes, 95) - return row, nil -} - -func (q *FakeQuerier) GetTemplateByID(ctx context.Context, id uuid.UUID) (database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getTemplateByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetTemplateByOrganizationAndName(_ context.Context, arg database.GetTemplateByOrganizationAndNameParams) (database.Template, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Template{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, template := range q.templates { - if template.OrganizationID != arg.OrganizationID { - continue - } - if !strings.EqualFold(template.Name, arg.Name) { - continue - } - if template.Deleted != arg.Deleted { - continue - } - return q.templateWithNameNoLock(template), nil - } - return database.Template{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTemplateDAUs(_ context.Context, arg database.GetTemplateDAUsParams) ([]database.GetTemplateDAUsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - seens := make(map[time.Time]map[uuid.UUID]struct{}) - - for _, as := range q.workspaceAgentStats { - if as.TemplateID != arg.TemplateID { - continue - } - if as.ConnectionCount == 0 { - continue - } - - date := as.CreatedAt.UTC().Add(time.Duration(arg.TzOffset) * time.Hour * -1).Truncate(time.Hour * 24) - - dateEntry := seens[date] - if dateEntry == nil { - dateEntry = make(map[uuid.UUID]struct{}) - } - dateEntry[as.UserID] = struct{}{} - seens[date] = dateEntry - } - - seenKeys := maps.Keys(seens) - sort.Slice(seenKeys, func(i, j int) bool { - return seenKeys[i].Before(seenKeys[j]) - }) - - var rs []database.GetTemplateDAUsRow - for _, key := range seenKeys { - ids := seens[key] - for id := range ids { - rs = append(rs, database.GetTemplateDAUsRow{ - Date: key, - UserID: id, - }) - } - } - - return rs, nil -} - -func (q *FakeQuerier) GetTemplateInsights(_ context.Context, arg database.GetTemplateInsightsParams) (database.GetTemplateInsightsRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.GetTemplateInsightsRow{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - /* - WITH - */ - - /* - insights AS ( - SELECT - user_id, - -- See motivation in GetTemplateInsights for LEAST(SUM(n), 30). - LEAST(SUM(usage_mins), 30) AS usage_mins, - LEAST(SUM(ssh_mins), 30) AS ssh_mins, - LEAST(SUM(sftp_mins), 30) AS sftp_mins, - LEAST(SUM(reconnecting_pty_mins), 30) AS reconnecting_pty_mins, - LEAST(SUM(vscode_mins), 30) AS vscode_mins, - LEAST(SUM(jetbrains_mins), 30) AS jetbrains_mins - FROM - template_usage_stats - WHERE - start_time >= @start_time::timestamptz - AND end_time <= @end_time::timestamptz - AND CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - GROUP BY - start_time, user_id - ), - */ - - type insightsGroupBy struct { - StartTime time.Time - UserID uuid.UUID - } - type insightsRow struct { - insightsGroupBy - UsageMins int16 - SSHMins int16 - SFTPMins int16 - ReconnectingPTYMins int16 - VSCodeMins int16 - JetBrainsMins int16 - } - insights := make(map[insightsGroupBy]insightsRow) - for _, stat := range q.templateUsageStats { - if stat.StartTime.Before(arg.StartTime) || stat.EndTime.After(arg.EndTime) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, stat.TemplateID) { - continue - } - key := insightsGroupBy{ - StartTime: stat.StartTime, - UserID: stat.UserID, - } - row, ok := insights[key] - if !ok { - row = insightsRow{ - insightsGroupBy: key, - } - } - row.UsageMins = least(row.UsageMins+stat.UsageMins, 30) - row.SSHMins = least(row.SSHMins+stat.SshMins, 30) - row.SFTPMins = least(row.SFTPMins+stat.SftpMins, 30) - row.ReconnectingPTYMins = least(row.ReconnectingPTYMins+stat.ReconnectingPtyMins, 30) - row.VSCodeMins = least(row.VSCodeMins+stat.VscodeMins, 30) - row.JetBrainsMins = least(row.JetBrainsMins+stat.JetbrainsMins, 30) - insights[key] = row - } - - /* - templates AS ( - SELECT - array_agg(DISTINCT template_id) AS template_ids, - array_agg(DISTINCT template_id) FILTER (WHERE ssh_mins > 0) AS ssh_template_ids, - array_agg(DISTINCT template_id) FILTER (WHERE sftp_mins > 0) AS sftp_template_ids, - array_agg(DISTINCT template_id) FILTER (WHERE reconnecting_pty_mins > 0) AS reconnecting_pty_template_ids, - array_agg(DISTINCT template_id) FILTER (WHERE vscode_mins > 0) AS vscode_template_ids, - array_agg(DISTINCT template_id) FILTER (WHERE jetbrains_mins > 0) AS jetbrains_template_ids - FROM - template_usage_stats - WHERE - start_time >= @start_time::timestamptz - AND end_time <= @end_time::timestamptz - AND CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - ) - */ - - type templateRow struct { - TemplateIDs []uuid.UUID - SSHTemplateIDs []uuid.UUID - SFTPTemplateIDs []uuid.UUID - ReconnectingPTYIDs []uuid.UUID - VSCodeTemplateIDs []uuid.UUID - JetBrainsTemplateIDs []uuid.UUID - } - templates := templateRow{} - for _, stat := range q.templateUsageStats { - if stat.StartTime.Before(arg.StartTime) || stat.EndTime.After(arg.EndTime) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, stat.TemplateID) { - continue - } - templates.TemplateIDs = append(templates.TemplateIDs, stat.TemplateID) - if stat.SshMins > 0 { - templates.SSHTemplateIDs = append(templates.SSHTemplateIDs, stat.TemplateID) - } - if stat.SftpMins > 0 { - templates.SFTPTemplateIDs = append(templates.SFTPTemplateIDs, stat.TemplateID) - } - if stat.ReconnectingPtyMins > 0 { - templates.ReconnectingPTYIDs = append(templates.ReconnectingPTYIDs, stat.TemplateID) - } - if stat.VscodeMins > 0 { - templates.VSCodeTemplateIDs = append(templates.VSCodeTemplateIDs, stat.TemplateID) - } - if stat.JetbrainsMins > 0 { - templates.JetBrainsTemplateIDs = append(templates.JetBrainsTemplateIDs, stat.TemplateID) - } - } - - /* - SELECT - COALESCE((SELECT template_ids FROM templates), '{}')::uuid[] AS template_ids, -- Includes app usage. - COALESCE((SELECT ssh_template_ids FROM templates), '{}')::uuid[] AS ssh_template_ids, - COALESCE((SELECT sftp_template_ids FROM templates), '{}')::uuid[] AS sftp_template_ids, - COALESCE((SELECT reconnecting_pty_template_ids FROM templates), '{}')::uuid[] AS reconnecting_pty_template_ids, - COALESCE((SELECT vscode_template_ids FROM templates), '{}')::uuid[] AS vscode_template_ids, - COALESCE((SELECT jetbrains_template_ids FROM templates), '{}')::uuid[] AS jetbrains_template_ids, - COALESCE(COUNT(DISTINCT user_id), 0)::bigint AS active_users, -- Includes app usage. - COALESCE(SUM(usage_mins) * 60, 0)::bigint AS usage_total_seconds, -- Includes app usage. - COALESCE(SUM(ssh_mins) * 60, 0)::bigint AS usage_ssh_seconds, - COALESCE(SUM(sftp_mins) * 60, 0)::bigint AS usage_sftp_seconds, - COALESCE(SUM(reconnecting_pty_mins) * 60, 0)::bigint AS usage_reconnecting_pty_seconds, - COALESCE(SUM(vscode_mins) * 60, 0)::bigint AS usage_vscode_seconds, - COALESCE(SUM(jetbrains_mins) * 60, 0)::bigint AS usage_jetbrains_seconds - FROM - insights; - */ - - var row database.GetTemplateInsightsRow - row.TemplateIDs = uniqueSortedUUIDs(templates.TemplateIDs) - row.SshTemplateIds = uniqueSortedUUIDs(templates.SSHTemplateIDs) - row.SftpTemplateIds = uniqueSortedUUIDs(templates.SFTPTemplateIDs) - row.ReconnectingPtyTemplateIds = uniqueSortedUUIDs(templates.ReconnectingPTYIDs) - row.VscodeTemplateIds = uniqueSortedUUIDs(templates.VSCodeTemplateIDs) - row.JetbrainsTemplateIds = uniqueSortedUUIDs(templates.JetBrainsTemplateIDs) - activeUserIDs := make(map[uuid.UUID]struct{}) - for _, insight := range insights { - activeUserIDs[insight.UserID] = struct{}{} - row.UsageTotalSeconds += int64(insight.UsageMins) * 60 - row.UsageSshSeconds += int64(insight.SSHMins) * 60 - row.UsageSftpSeconds += int64(insight.SFTPMins) * 60 - row.UsageReconnectingPtySeconds += int64(insight.ReconnectingPTYMins) * 60 - row.UsageVscodeSeconds += int64(insight.VSCodeMins) * 60 - row.UsageJetbrainsSeconds += int64(insight.JetBrainsMins) * 60 - } - row.ActiveUsers = int64(len(activeUserIDs)) - - return row, nil -} - -func (q *FakeQuerier) GetTemplateInsightsByInterval(_ context.Context, arg database.GetTemplateInsightsByIntervalParams) ([]database.GetTemplateInsightsByIntervalRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - /* - WITH - ts AS ( - SELECT - d::timestamptz AS from_, - CASE - WHEN (d::timestamptz + (@interval_days::int || ' day')::interval) <= @end_time::timestamptz - THEN (d::timestamptz + (@interval_days::int || ' day')::interval) - ELSE @end_time::timestamptz - END AS to_ - FROM - -- Subtract 1 microsecond from end_time to avoid including the next interval in the results. - generate_series(@start_time::timestamptz, (@end_time::timestamptz) - '1 microsecond'::interval, (@interval_days::int || ' day')::interval) AS d - ) - - SELECT - ts.from_ AS start_time, - ts.to_ AS end_time, - array_remove(array_agg(DISTINCT tus.template_id), NULL)::uuid[] AS template_ids, - COUNT(DISTINCT tus.user_id) AS active_users - FROM - ts - LEFT JOIN - template_usage_stats AS tus - ON - tus.start_time >= ts.from_ - AND tus.end_time <= ts.to_ - AND CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN tus.template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - GROUP BY - ts.from_, ts.to_; - */ - - type interval struct { - From time.Time - To time.Time - } - var ts []interval - for d := arg.StartTime; d.Before(arg.EndTime); d = d.AddDate(0, 0, int(arg.IntervalDays)) { - to := d.AddDate(0, 0, int(arg.IntervalDays)) - if to.After(arg.EndTime) { - to = arg.EndTime - } - ts = append(ts, interval{From: d, To: to}) - } - - type grouped struct { - TemplateIDs map[uuid.UUID]struct{} - UserIDs map[uuid.UUID]struct{} - } - groupedByInterval := make(map[interval]grouped) - for _, tus := range q.templateUsageStats { - for _, t := range ts { - if tus.StartTime.Before(t.From) || tus.EndTime.After(t.To) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, tus.TemplateID) { - continue - } - g, ok := groupedByInterval[t] - if !ok { - g = grouped{ - TemplateIDs: make(map[uuid.UUID]struct{}), - UserIDs: make(map[uuid.UUID]struct{}), - } - } - g.TemplateIDs[tus.TemplateID] = struct{}{} - g.UserIDs[tus.UserID] = struct{}{} - groupedByInterval[t] = g - } - } - - var rows []database.GetTemplateInsightsByIntervalRow - for _, t := range ts { // Ordered by interval. - row := database.GetTemplateInsightsByIntervalRow{ - StartTime: t.From, - EndTime: t.To, - } - row.TemplateIDs = uniqueSortedUUIDs(maps.Keys(groupedByInterval[t].TemplateIDs)) - row.ActiveUsers = int64(len(groupedByInterval[t].UserIDs)) - rows = append(rows, row) - } - - return rows, nil -} - -func (q *FakeQuerier) GetTemplateInsightsByTemplate(_ context.Context, arg database.GetTemplateInsightsByTemplateParams) ([]database.GetTemplateInsightsByTemplateRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // map time.Time x TemplateID x UserID x - appUsageByTemplateAndUser := map[time.Time]map[uuid.UUID]map[uuid.UUID]database.GetTemplateInsightsByTemplateRow{} - - // Review agent stats in terms of usage - templateIDSet := make(map[uuid.UUID]struct{}) - - for _, s := range q.workspaceAgentStats { - if s.CreatedAt.Before(arg.StartTime) || s.CreatedAt.Equal(arg.EndTime) || s.CreatedAt.After(arg.EndTime) { - continue - } - if s.ConnectionCount == 0 { - continue - } - - t := s.CreatedAt.Truncate(time.Minute) - templateIDSet[s.TemplateID] = struct{}{} - - if _, ok := appUsageByTemplateAndUser[t]; !ok { - appUsageByTemplateAndUser[t] = make(map[uuid.UUID]map[uuid.UUID]database.GetTemplateInsightsByTemplateRow) - } - - if _, ok := appUsageByTemplateAndUser[t][s.TemplateID]; !ok { - appUsageByTemplateAndUser[t][s.TemplateID] = make(map[uuid.UUID]database.GetTemplateInsightsByTemplateRow) - } - - if _, ok := appUsageByTemplateAndUser[t][s.TemplateID][s.UserID]; !ok { - appUsageByTemplateAndUser[t][s.TemplateID][s.UserID] = database.GetTemplateInsightsByTemplateRow{} - } - - u := appUsageByTemplateAndUser[t][s.TemplateID][s.UserID] - if s.SessionCountJetBrains > 0 { - u.UsageJetbrainsSeconds = 60 - } - if s.SessionCountVSCode > 0 { - u.UsageVscodeSeconds = 60 - } - if s.SessionCountReconnectingPTY > 0 { - u.UsageReconnectingPtySeconds = 60 - } - if s.SessionCountSSH > 0 { - u.UsageSshSeconds = 60 - } - appUsageByTemplateAndUser[t][s.TemplateID][s.UserID] = u - } - - // Sort used templates - templateIDs := make([]uuid.UUID, 0, len(templateIDSet)) - for templateID := range templateIDSet { - templateIDs = append(templateIDs, templateID) - } - slices.SortFunc(templateIDs, func(a, b uuid.UUID) int { - return slice.Ascending(a.String(), b.String()) - }) - - // Build result - var result []database.GetTemplateInsightsByTemplateRow - for _, templateID := range templateIDs { - r := database.GetTemplateInsightsByTemplateRow{ - TemplateID: templateID, - } - - uniqueUsers := map[uuid.UUID]struct{}{} - - for _, mTemplateUserUsage := range appUsageByTemplateAndUser { - mUserUsage, ok := mTemplateUserUsage[templateID] - if !ok { - continue // template was not used in this time window - } - - for userID, usage := range mUserUsage { - uniqueUsers[userID] = struct{}{} - - r.UsageJetbrainsSeconds += usage.UsageJetbrainsSeconds - r.UsageVscodeSeconds += usage.UsageVscodeSeconds - r.UsageReconnectingPtySeconds += usage.UsageReconnectingPtySeconds - r.UsageSshSeconds += usage.UsageSshSeconds - } - } - - r.ActiveUsers = int64(len(uniqueUsers)) - - result = append(result, r) - } - return result, nil -} - -func (q *FakeQuerier) GetTemplateParameterInsights(ctx context.Context, arg database.GetTemplateParameterInsightsParams) ([]database.GetTemplateParameterInsightsRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // WITH latest_workspace_builds ... - latestWorkspaceBuilds := make(map[uuid.UUID]database.WorkspaceBuild) - for _, wb := range q.workspaceBuilds { - if wb.CreatedAt.Before(arg.StartTime) || wb.CreatedAt.Equal(arg.EndTime) || wb.CreatedAt.After(arg.EndTime) { - continue - } - if latestWorkspaceBuilds[wb.WorkspaceID].BuildNumber < wb.BuildNumber { - latestWorkspaceBuilds[wb.WorkspaceID] = wb - } - } - if len(arg.TemplateIDs) > 0 { - for wsID := range latestWorkspaceBuilds { - ws, err := q.getWorkspaceByIDNoLock(ctx, wsID) - if err != nil { - return nil, err - } - if slices.Contains(arg.TemplateIDs, ws.TemplateID) { - delete(latestWorkspaceBuilds, wsID) - } - } - } - // WITH unique_template_params ... - num := int64(0) - uniqueTemplateParams := make(map[string]*database.GetTemplateParameterInsightsRow) - uniqueTemplateParamWorkspaceBuildIDs := make(map[string][]uuid.UUID) - for _, wb := range latestWorkspaceBuilds { - tv, err := q.getTemplateVersionByIDNoLock(ctx, wb.TemplateVersionID) - if err != nil { - return nil, err - } - for _, tvp := range q.templateVersionParameters { - if tvp.TemplateVersionID != tv.ID { - continue - } - // GROUP BY tvp.name, tvp.type, tvp.display_name, tvp.description, tvp.options - key := fmt.Sprintf("%s:%s:%s:%s:%s", tvp.Name, tvp.Type, tvp.DisplayName, tvp.Description, tvp.Options) - if _, ok := uniqueTemplateParams[key]; !ok { - num++ - uniqueTemplateParams[key] = &database.GetTemplateParameterInsightsRow{ - Num: num, - Name: tvp.Name, - Type: tvp.Type, - DisplayName: tvp.DisplayName, - Description: tvp.Description, - Options: tvp.Options, - } - } - uniqueTemplateParams[key].TemplateIDs = append(uniqueTemplateParams[key].TemplateIDs, tv.TemplateID.UUID) - uniqueTemplateParamWorkspaceBuildIDs[key] = append(uniqueTemplateParamWorkspaceBuildIDs[key], wb.ID) - } - } - // SELECT ... - counts := make(map[string]map[string]int64) - for key, utp := range uniqueTemplateParams { - for _, wbp := range q.workspaceBuildParameters { - if !slices.Contains(uniqueTemplateParamWorkspaceBuildIDs[key], wbp.WorkspaceBuildID) { - continue - } - if wbp.Name != utp.Name { - continue - } - if counts[key] == nil { - counts[key] = make(map[string]int64) - } - counts[key][wbp.Value]++ - } - } - - var rows []database.GetTemplateParameterInsightsRow - for key, utp := range uniqueTemplateParams { - for value, count := range counts[key] { - rows = append(rows, database.GetTemplateParameterInsightsRow{ - Num: utp.Num, - TemplateIDs: uniqueSortedUUIDs(utp.TemplateIDs), - Name: utp.Name, - DisplayName: utp.DisplayName, - Type: utp.Type, - Description: utp.Description, - Options: utp.Options, - Value: value, - Count: count, - }) - } - } - - // NOTE(mafredri): Add sorting if we decide on how to handle PostgreSQL collations. - // ORDER BY utp.name, utp.type, utp.display_name, utp.description, utp.options, wbp.value - return rows, nil -} - -func (*FakeQuerier) GetTemplatePresetsWithPrebuilds(_ context.Context, _ uuid.NullUUID) ([]database.GetTemplatePresetsWithPrebuildsRow, error) { - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) GetTemplateUsageStats(_ context.Context, arg database.GetTemplateUsageStatsParams) ([]database.TemplateUsageStat, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - var stats []database.TemplateUsageStat - for _, stat := range q.templateUsageStats { - // Exclude all chunks that don't fall exactly within the range. - if stat.StartTime.Before(arg.StartTime) || stat.EndTime.After(arg.EndTime) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, stat.TemplateID) { - continue - } - stats = append(stats, stat) - } - - if len(stats) == 0 { - return nil, sql.ErrNoRows - } - - return stats, nil -} - -func (q *FakeQuerier) GetTemplateVersionByID(ctx context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getTemplateVersionByIDNoLock(ctx, templateVersionID) -} - -func (q *FakeQuerier) GetTemplateVersionByJobID(_ context.Context, jobID uuid.UUID) (database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, templateVersion := range q.templateVersions { - if templateVersion.JobID != jobID { - continue - } - return q.templateVersionWithUserNoLock(templateVersion), nil - } - return database.TemplateVersion{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTemplateVersionByTemplateIDAndName(_ context.Context, arg database.GetTemplateVersionByTemplateIDAndNameParams) (database.TemplateVersion, error) { - if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersion{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, templateVersion := range q.templateVersions { - if templateVersion.TemplateID != arg.TemplateID { - continue - } - if !strings.EqualFold(templateVersion.Name, arg.Name) { - continue - } - return q.templateVersionWithUserNoLock(templateVersion), nil - } - return database.TemplateVersion{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTemplateVersionParameters(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - parameters := make([]database.TemplateVersionParameter, 0) - for _, param := range q.templateVersionParameters { - if param.TemplateVersionID != templateVersionID { - continue - } - parameters = append(parameters, param) - } - sort.Slice(parameters, func(i, j int) bool { - if parameters[i].DisplayOrder != parameters[j].DisplayOrder { - return parameters[i].DisplayOrder < parameters[j].DisplayOrder - } - return strings.ToLower(parameters[i].Name) < strings.ToLower(parameters[j].Name) - }) - return parameters, nil -} - -func (q *FakeQuerier) GetTemplateVersionTerraformValues(ctx context.Context, templateVersionID uuid.UUID) (database.TemplateVersionTerraformValue, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, tvtv := range q.templateVersionTerraformValues { - if tvtv.TemplateVersionID == templateVersionID { - return tvtv, nil - } - } - - return database.TemplateVersionTerraformValue{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTemplateVersionVariables(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionVariable, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - variables := make([]database.TemplateVersionVariable, 0) - for _, variable := range q.templateVersionVariables { - if variable.TemplateVersionID != templateVersionID { - continue - } - variables = append(variables, variable) - } - return variables, nil -} - -func (q *FakeQuerier) GetTemplateVersionWorkspaceTags(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionWorkspaceTag, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaceTags := make([]database.TemplateVersionWorkspaceTag, 0) - for _, workspaceTag := range q.templateVersionWorkspaceTags { - if workspaceTag.TemplateVersionID != templateVersionID { - continue - } - workspaceTags = append(workspaceTags, workspaceTag) - } - - sort.Slice(workspaceTags, func(i, j int) bool { - return workspaceTags[i].Key < workspaceTags[j].Key - }) - return workspaceTags, nil -} - -func (q *FakeQuerier) GetTemplateVersionsByIDs(_ context.Context, ids []uuid.UUID) ([]database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - versions := make([]database.TemplateVersion, 0) - for _, version := range q.templateVersions { - for _, id := range ids { - if id == version.ID { - versions = append(versions, q.templateVersionWithUserNoLock(version)) - break - } - } - } - if len(versions) == 0 { - return nil, sql.ErrNoRows - } - - return versions, nil -} - -func (q *FakeQuerier) GetTemplateVersionsByTemplateID(_ context.Context, arg database.GetTemplateVersionsByTemplateIDParams) (version []database.TemplateVersion, err error) { - if err := validateDatabaseType(arg); err != nil { - return version, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, templateVersion := range q.templateVersions { - if templateVersion.TemplateID.UUID != arg.TemplateID { - continue - } - if arg.Archived.Valid && arg.Archived.Bool != templateVersion.Archived { - continue - } - version = append(version, q.templateVersionWithUserNoLock(templateVersion)) - } - - // Database orders by created_at - slices.SortFunc(version, func(a, b database.TemplateVersion) int { - if a.CreatedAt.Equal(b.CreatedAt) { - // Technically the postgres database also orders by uuid. So match - // that behavior - return slice.Ascending(a.ID.String(), b.ID.String()) - } - if a.CreatedAt.Before(b.CreatedAt) { - return -1 - } - return 1 - }) - - if arg.AfterID != uuid.Nil { - found := false - for i, v := range version { - if v.ID == arg.AfterID { - // We want to return all users after index i. - version = version[i+1:] - found = true - break - } - } - - // If no users after the time, then we return an empty list. - if !found { - return nil, sql.ErrNoRows - } - } - - if arg.OffsetOpt > 0 { - if int(arg.OffsetOpt) > len(version)-1 { - return nil, sql.ErrNoRows - } - version = version[arg.OffsetOpt:] - } - - if arg.LimitOpt > 0 { - if int(arg.LimitOpt) > len(version) { - // #nosec G115 - Safe conversion as version slice length is expected to be within int32 range - arg.LimitOpt = int32(len(version)) - } - version = version[:arg.LimitOpt] - } - - if len(version) == 0 { - return nil, sql.ErrNoRows - } - - return version, nil -} - -func (q *FakeQuerier) GetTemplateVersionsCreatedAfter(_ context.Context, after time.Time) ([]database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - versions := make([]database.TemplateVersion, 0) - for _, version := range q.templateVersions { - if version.CreatedAt.After(after) { - versions = append(versions, q.templateVersionWithUserNoLock(version)) - } - } - return versions, nil -} - -func (q *FakeQuerier) GetTemplates(_ context.Context) ([]database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - templates := slices.Clone(q.templates) - slices.SortFunc(templates, func(a, b database.TemplateTable) int { - if a.Name != b.Name { - return slice.Ascending(a.Name, b.Name) - } - return slice.Ascending(a.ID.String(), b.ID.String()) - }) - - return q.templatesWithUserNoLock(templates), nil -} - -func (q *FakeQuerier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - return q.GetAuthorizedTemplates(ctx, arg, nil) -} - -func (q *FakeQuerier) GetUnexpiredLicenses(_ context.Context) ([]database.License, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - now := time.Now() - var results []database.License - for _, l := range q.licenses { - if l.Exp.After(now) { - results = append(results, l) - } - } - sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID }) - return results, nil -} - -func (q *FakeQuerier) GetUserActivityInsights(_ context.Context, arg database.GetUserActivityInsightsParams) ([]database.GetUserActivityInsightsRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - /* - WITH - */ - /* - deployment_stats AS ( - SELECT - start_time, - user_id, - array_agg(template_id) AS template_ids, - -- See motivation in GetTemplateInsights for LEAST(SUM(n), 30). - LEAST(SUM(usage_mins), 30) AS usage_mins - FROM - template_usage_stats - WHERE - start_time >= @start_time::timestamptz - AND end_time <= @end_time::timestamptz - AND CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - GROUP BY - start_time, user_id - ), - */ - - type deploymentStatsGroupBy struct { - StartTime time.Time - UserID uuid.UUID - } - type deploymentStatsRow struct { - deploymentStatsGroupBy - TemplateIDs []uuid.UUID - UsageMins int16 - } - deploymentStatsRows := make(map[deploymentStatsGroupBy]deploymentStatsRow) - for _, stat := range q.templateUsageStats { - if stat.StartTime.Before(arg.StartTime) || stat.EndTime.After(arg.EndTime) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, stat.TemplateID) { - continue - } - key := deploymentStatsGroupBy{ - StartTime: stat.StartTime, - UserID: stat.UserID, - } - row, ok := deploymentStatsRows[key] - if !ok { - row = deploymentStatsRow{ - deploymentStatsGroupBy: key, - } - } - row.TemplateIDs = append(row.TemplateIDs, stat.TemplateID) - row.UsageMins = least(row.UsageMins+stat.UsageMins, 30) - deploymentStatsRows[key] = row - } - - /* - template_ids AS ( - SELECT - user_id, - array_agg(DISTINCT template_id) AS ids - FROM - deployment_stats, unnest(template_ids) template_id - GROUP BY - user_id - ) - */ - - type templateIDsRow struct { - UserID uuid.UUID - TemplateIDs []uuid.UUID - } - templateIDs := make(map[uuid.UUID]templateIDsRow) - for _, dsRow := range deploymentStatsRows { - row, ok := templateIDs[dsRow.UserID] - if !ok { - row = templateIDsRow{ - UserID: row.UserID, - } - } - row.TemplateIDs = uniqueSortedUUIDs(append(row.TemplateIDs, dsRow.TemplateIDs...)) - templateIDs[dsRow.UserID] = row - } - - /* - SELECT - ds.user_id, - u.username, - u.avatar_url, - t.ids::uuid[] AS template_ids, - (SUM(ds.usage_mins) * 60)::bigint AS usage_seconds - FROM - deployment_stats ds - JOIN - users u - ON - u.id = ds.user_id - JOIN - template_ids t - ON - ds.user_id = t.user_id - GROUP BY - ds.user_id, u.username, u.avatar_url, t.ids - ORDER BY - ds.user_id ASC; - */ - - var rows []database.GetUserActivityInsightsRow - groupedRows := make(map[uuid.UUID]database.GetUserActivityInsightsRow) - for _, dsRow := range deploymentStatsRows { - row, ok := groupedRows[dsRow.UserID] - if !ok { - user, err := q.getUserByIDNoLock(dsRow.UserID) - if err != nil { - return nil, err - } - row = database.GetUserActivityInsightsRow{ - UserID: user.ID, - Username: user.Username, - AvatarURL: user.AvatarURL, - TemplateIDs: templateIDs[user.ID].TemplateIDs, - } - } - row.UsageSeconds += int64(dsRow.UsageMins) * 60 - groupedRows[dsRow.UserID] = row - } - for _, row := range groupedRows { - rows = append(rows, row) - } - if len(rows) == 0 { - return nil, sql.ErrNoRows - } - slices.SortFunc(rows, func(a, b database.GetUserActivityInsightsRow) int { - return slice.Ascending(a.UserID.String(), b.UserID.String()) - }) - - return rows, nil -} - -func (q *FakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, user := range q.users { - if !user.Deleted && (strings.EqualFold(user.Email, arg.Email) || strings.EqualFold(user.Username, arg.Username)) { - return user, nil - } - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetUserByID(_ context.Context, id uuid.UUID) (database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getUserByIDNoLock(id) -} - -// nolint:revive // It's not a control flag, it's a filter. -func (q *FakeQuerier) GetUserCount(_ context.Context, includeSystem bool) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - existing := int64(0) - for _, u := range q.users { - if !includeSystem && u.IsSystem { - continue - } - if !u.Deleted { - existing++ - } - - if !includeSystem && u.IsSystem { - continue - } - } - return existing, nil -} - -func (q *FakeQuerier) GetUserLatencyInsights(_ context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - /* - SELECT - tus.user_id, - u.username, - u.avatar_url, - array_agg(DISTINCT tus.template_id)::uuid[] AS template_ids, - COALESCE((PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY tus.median_latency_ms)), -1)::float AS workspace_connection_latency_50, - COALESCE((PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY tus.median_latency_ms)), -1)::float AS workspace_connection_latency_95 - FROM - template_usage_stats tus - JOIN - users u - ON - u.id = tus.user_id - WHERE - tus.start_time >= @start_time::timestamptz - AND tus.end_time <= @end_time::timestamptz - AND CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN tus.template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - GROUP BY - tus.user_id, u.username, u.avatar_url - ORDER BY - tus.user_id ASC; - */ - - latenciesByUserID := make(map[uuid.UUID][]float64) - seenTemplatesByUserID := make(map[uuid.UUID][]uuid.UUID) - for _, stat := range q.templateUsageStats { - if stat.StartTime.Before(arg.StartTime) || stat.EndTime.After(arg.EndTime) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, stat.TemplateID) { - continue - } - - if stat.MedianLatencyMs.Valid { - latenciesByUserID[stat.UserID] = append(latenciesByUserID[stat.UserID], stat.MedianLatencyMs.Float64) - } - seenTemplatesByUserID[stat.UserID] = uniqueSortedUUIDs(append(seenTemplatesByUserID[stat.UserID], stat.TemplateID)) - } - - var rows []database.GetUserLatencyInsightsRow - for userID, latencies := range latenciesByUserID { - user, err := q.getUserByIDNoLock(userID) - if err != nil { - return nil, err - } - row := database.GetUserLatencyInsightsRow{ - UserID: userID, - Username: user.Username, - AvatarURL: user.AvatarURL, - TemplateIDs: seenTemplatesByUserID[userID], - WorkspaceConnectionLatency50: tryPercentileCont(latencies, 50), - WorkspaceConnectionLatency95: tryPercentileCont(latencies, 95), - } - rows = append(rows, row) - } - slices.SortFunc(rows, func(a, b database.GetUserLatencyInsightsRow) int { - return slice.Ascending(a.UserID.String(), b.UserID.String()) - }) - - return rows, nil -} - -func (q *FakeQuerier) GetUserLinkByLinkedID(_ context.Context, id string) (database.UserLink, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, link := range q.userLinks { - user, err := q.getUserByIDNoLock(link.UserID) - if err == nil && user.Deleted { - continue - } - if link.LinkedID == id { - return link, nil - } - } - return database.UserLink{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) { - if err := validateDatabaseType(params); err != nil { - return database.UserLink{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, link := range q.userLinks { - if link.UserID == params.UserID && link.LoginType == params.LoginType { - return link, nil - } - } - return database.UserLink{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetUserLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.UserLink, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - uls := make([]database.UserLink, 0) - for _, ul := range q.userLinks { - if ul.UserID == userID { - uls = append(uls, ul) - } - } - return uls, nil -} - -func (q *FakeQuerier) GetUserNotificationPreferences(_ context.Context, userID uuid.UUID) ([]database.NotificationPreference, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - out := make([]database.NotificationPreference, 0, len(q.notificationPreferences)) - for _, np := range q.notificationPreferences { - if np.UserID != userID { - continue - } - - out = append(out, np) - } - - return out, nil -} - -func (q *FakeQuerier) GetUserStatusCounts(_ context.Context, arg database.GetUserStatusCountsParams) ([]database.GetUserStatusCountsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - result := make([]database.GetUserStatusCountsRow, 0) - for _, change := range q.userStatusChanges { - if change.ChangedAt.Before(arg.StartTime) || change.ChangedAt.After(arg.EndTime) { - continue - } - date := time.Date(change.ChangedAt.Year(), change.ChangedAt.Month(), change.ChangedAt.Day(), 0, 0, 0, 0, time.UTC) - if !slices.ContainsFunc(result, func(r database.GetUserStatusCountsRow) bool { - return r.Status == change.NewStatus && r.Date.Equal(date) - }) { - result = append(result, database.GetUserStatusCountsRow{ - Status: change.NewStatus, - Date: date, - Count: 1, - }) - } else { - for i, r := range result { - if r.Status == change.NewStatus && r.Date.Equal(date) { - result[i].Count++ - break - } - } - } - } - - return result, nil -} - -func (q *FakeQuerier) GetUserTerminalFont(ctx context.Context, userID uuid.UUID) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, uc := range q.userConfigs { - if uc.UserID != userID || uc.Key != "terminal_font" { - continue - } - return uc.Value, nil - } - - return "", sql.ErrNoRows -} - -func (q *FakeQuerier) GetUserThemePreference(_ context.Context, userID uuid.UUID) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, uc := range q.userConfigs { - if uc.UserID != userID || uc.Key != "theme_preference" { - continue - } - return uc.Value, nil - } - - return "", sql.ErrNoRows -} - -func (q *FakeQuerier) GetUserWorkspaceBuildParameters(_ context.Context, params database.GetUserWorkspaceBuildParametersParams) ([]database.GetUserWorkspaceBuildParametersRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - userWorkspaceIDs := make(map[uuid.UUID]struct{}) - for _, ws := range q.workspaces { - if ws.OwnerID != params.OwnerID { - continue - } - if ws.TemplateID != params.TemplateID { - continue - } - userWorkspaceIDs[ws.ID] = struct{}{} - } - - userWorkspaceBuilds := make(map[uuid.UUID]struct{}) - for _, wb := range q.workspaceBuilds { - if _, ok := userWorkspaceIDs[wb.WorkspaceID]; !ok { - continue - } - userWorkspaceBuilds[wb.ID] = struct{}{} - } - - templateVersions := make(map[uuid.UUID]struct{}) - for _, tv := range q.templateVersions { - if tv.TemplateID.UUID != params.TemplateID { - continue - } - templateVersions[tv.ID] = struct{}{} - } - - tvps := make(map[string]struct{}) - for _, tvp := range q.templateVersionParameters { - if _, ok := templateVersions[tvp.TemplateVersionID]; !ok { - continue - } - - if _, ok := tvps[tvp.Name]; !ok && !tvp.Ephemeral { - tvps[tvp.Name] = struct{}{} - } - } - - userWorkspaceBuildParameters := make(map[string]database.GetUserWorkspaceBuildParametersRow) - for _, wbp := range q.workspaceBuildParameters { - if _, ok := userWorkspaceBuilds[wbp.WorkspaceBuildID]; !ok { - continue - } - if _, ok := tvps[wbp.Name]; !ok { - continue - } - userWorkspaceBuildParameters[wbp.Name] = database.GetUserWorkspaceBuildParametersRow{ - Name: wbp.Name, - Value: wbp.Value, - } - } - - return maps.Values(userWorkspaceBuildParameters), nil -} - -func (q *FakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) { - if err := validateDatabaseType(params); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Avoid side-effect of sorting. - users := make([]database.User, len(q.users)) - copy(users, q.users) - - // Database orders by username - slices.SortFunc(users, func(a, b database.User) int { - return slice.Ascending(strings.ToLower(a.Username), strings.ToLower(b.Username)) - }) - - // Filter out deleted since they should never be returned.. - tmp := make([]database.User, 0, len(users)) - for _, user := range users { - if !user.Deleted { - tmp = append(tmp, user) - } - } - users = tmp - - if params.AfterID != uuid.Nil { - found := false - for i, v := range users { - if v.ID == params.AfterID { - // We want to return all users after index i. - users = users[i+1:] - found = true - break - } - } - - // If no users after the time, then we return an empty list. - if !found { - return []database.GetUsersRow{}, nil - } - } - - if params.Search != "" { - tmp := make([]database.User, 0, len(users)) - for i, user := range users { - if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } - } - users = tmp - } - - if len(params.Status) > 0 { - usersFilteredByStatus := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { - return strings.EqualFold(string(a), string(b)) - }) { - usersFilteredByStatus = append(usersFilteredByStatus, users[i]) - } - } - users = usersFilteredByStatus - } - - if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember().String()) { - usersFilteredByRole := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { - usersFilteredByRole = append(usersFilteredByRole, users[i]) - } - } - users = usersFilteredByRole - } - - if len(params.LoginType) > 0 { - usersFilteredByLoginType := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.ContainsCompare(params.LoginType, user.LoginType, func(a, b database.LoginType) bool { - return strings.EqualFold(string(a), string(b)) - }) { - usersFilteredByLoginType = append(usersFilteredByLoginType, users[i]) - } - } - users = usersFilteredByLoginType - } - - if !params.CreatedBefore.IsZero() { - usersFilteredByCreatedAt := make([]database.User, 0, len(users)) - for i, user := range users { - if user.CreatedAt.Before(params.CreatedBefore) { - usersFilteredByCreatedAt = append(usersFilteredByCreatedAt, users[i]) - } - } - users = usersFilteredByCreatedAt - } - - if !params.CreatedAfter.IsZero() { - usersFilteredByCreatedAt := make([]database.User, 0, len(users)) - for i, user := range users { - if user.CreatedAt.After(params.CreatedAfter) { - usersFilteredByCreatedAt = append(usersFilteredByCreatedAt, users[i]) - } - } - users = usersFilteredByCreatedAt - } - - if !params.LastSeenBefore.IsZero() { - usersFilteredByLastSeen := make([]database.User, 0, len(users)) - for i, user := range users { - if user.LastSeenAt.Before(params.LastSeenBefore) { - usersFilteredByLastSeen = append(usersFilteredByLastSeen, users[i]) - } - } - users = usersFilteredByLastSeen - } - - if !params.LastSeenAfter.IsZero() { - usersFilteredByLastSeen := make([]database.User, 0, len(users)) - for i, user := range users { - if user.LastSeenAt.After(params.LastSeenAfter) { - usersFilteredByLastSeen = append(usersFilteredByLastSeen, users[i]) - } - } - users = usersFilteredByLastSeen - } - - if !params.IncludeSystem { - users = slices.DeleteFunc(users, func(u database.User) bool { - return u.IsSystem - }) - } - - if params.GithubComUserID != 0 { - usersFilteredByGithubComUserID := make([]database.User, 0, len(users)) - for i, user := range users { - if user.GithubComUserID.Int64 == params.GithubComUserID { - usersFilteredByGithubComUserID = append(usersFilteredByGithubComUserID, users[i]) - } - } - users = usersFilteredByGithubComUserID - } - - beforePageCount := len(users) - - if params.OffsetOpt > 0 { - if int(params.OffsetOpt) > len(users)-1 { - return []database.GetUsersRow{}, nil - } - users = users[params.OffsetOpt:] - } - - if params.LimitOpt > 0 { - if int(params.LimitOpt) > len(users) { - // #nosec G115 - Safe conversion as users slice length is expected to be within int32 range - params.LimitOpt = int32(len(users)) - } - users = users[:params.LimitOpt] - } - - return convertUsers(users, int64(beforePageCount)), nil -} - -func (q *FakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - users := make([]database.User, 0) - for _, user := range q.users { - for _, id := range ids { - if user.ID != id { - continue - } - users = append(users, user) - } - } - return users, nil -} - -func (q *FakeQuerier) GetWebpushSubscriptionsByUserID(_ context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - out := make([]database.WebpushSubscription, 0) - for _, subscription := range q.webpushSubscriptions { - if subscription.UserID == userID { - out = append(out, subscription) - } - } - - return out, nil -} - -func (q *FakeQuerier) GetWebpushVAPIDKeys(_ context.Context) (database.GetWebpushVAPIDKeysRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.webpushVAPIDPublicKey == "" && q.webpushVAPIDPrivateKey == "" { - return database.GetWebpushVAPIDKeysRow{}, sql.ErrNoRows - } - - return database.GetWebpushVAPIDKeysRow{ - VapidPublicKey: q.webpushVAPIDPublicKey, - VapidPrivateKey: q.webpushVAPIDPrivateKey, - }, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentAndLatestBuildByAuthToken(_ context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - rows := []database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{} - // We want to return the latest build number for each workspace - latestBuildNumber := make(map[uuid.UUID]int32) - - for _, agt := range q.workspaceAgents { - if agt.Deleted { - continue - } - - // get the related workspace and user - for _, res := range q.workspaceResources { - if agt.ResourceID != res.ID { - continue - } - for _, build := range q.workspaceBuilds { - if build.JobID != res.JobID { - continue - } - for _, ws := range q.workspaces { - if build.WorkspaceID != ws.ID { - continue - } - if ws.Deleted { - continue - } - row := database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{ - WorkspaceTable: database.WorkspaceTable{ - ID: ws.ID, - TemplateID: ws.TemplateID, - }, - WorkspaceAgent: agt, - WorkspaceBuild: build, - } - usr, err := q.getUserByIDNoLock(ws.OwnerID) - if err != nil { - return database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{}, sql.ErrNoRows - } - row.WorkspaceTable.OwnerID = usr.ID - - // Keep track of the latest build number - rows = append(rows, row) - if build.BuildNumber > latestBuildNumber[ws.ID] { - latestBuildNumber[ws.ID] = build.BuildNumber - } - } - } - } - } - - for i := range rows { - if rows[i].WorkspaceAgent.AuthToken != authToken { - continue - } - - if rows[i].WorkspaceBuild.BuildNumber != latestBuildNumber[rows[i].WorkspaceTable.ID] { - continue - } - - return rows[i], nil - } - - return database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceAgentByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceID string) (database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - // The schema sorts this by created at, so we iterate the array backwards. - for i := len(q.workspaceAgents) - 1; i >= 0; i-- { - agent := q.workspaceAgents[i] - if !agent.Deleted && agent.AuthInstanceID.Valid && agent.AuthInstanceID.String == instanceID { - return agent, nil - } - } - return database.WorkspaceAgent{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceAgentDevcontainersByAgentID(_ context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentDevcontainer, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - devcontainers := make([]database.WorkspaceAgentDevcontainer, 0) - for _, dc := range q.workspaceAgentDevcontainers { - if dc.WorkspaceAgentID == workspaceAgentID { - devcontainers = append(devcontainers, dc) - } - } - if len(devcontainers) == 0 { - return nil, sql.ErrNoRows - } - return devcontainers, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentLifecycleStateByID(ctx context.Context, id uuid.UUID) (database.GetWorkspaceAgentLifecycleStateByIDRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - agent, err := q.getWorkspaceAgentByIDNoLock(ctx, id) - if err != nil { - return database.GetWorkspaceAgentLifecycleStateByIDRow{}, err - } - return database.GetWorkspaceAgentLifecycleStateByIDRow{ - LifecycleState: agent.LifecycleState, - StartedAt: agent.StartedAt, - ReadyAt: agent.ReadyAt, - }, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentLogSourcesByAgentIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceAgentLogSource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - logSources := make([]database.WorkspaceAgentLogSource, 0) - for _, logSource := range q.workspaceAgentLogSources { - for _, id := range ids { - if logSource.WorkspaceAgentID == id { - logSources = append(logSources, logSource) - break - } - } - } - return logSources, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentLogsAfter(_ context.Context, arg database.GetWorkspaceAgentLogsAfterParams) ([]database.WorkspaceAgentLog, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - logs := []database.WorkspaceAgentLog{} - for _, log := range q.workspaceAgentLogs { - if log.AgentID != arg.AgentID { - continue - } - if arg.CreatedAfter != 0 && log.ID <= arg.CreatedAfter { - continue - } - logs = append(logs, log) - } - return logs, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentMetadata(_ context.Context, arg database.GetWorkspaceAgentMetadataParams) ([]database.WorkspaceAgentMetadatum, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - metadata := make([]database.WorkspaceAgentMetadatum, 0) - for _, m := range q.workspaceAgentMetadata { - if m.WorkspaceAgentID == arg.WorkspaceAgentID { - if len(arg.Keys) > 0 && !slices.Contains(arg.Keys, m.Key) { - continue - } - metadata = append(metadata, m) - } - } - return metadata, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentPortShare(_ context.Context, arg database.GetWorkspaceAgentPortShareParams) (database.WorkspaceAgentPortShare, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceAgentPortShare{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, share := range q.workspaceAgentPortShares { - if share.WorkspaceID == arg.WorkspaceID && share.AgentName == arg.AgentName && share.Port == arg.Port { - return share, nil - } - } - - return database.WorkspaceAgentPortShare{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceAgentScriptTimingsByBuildID(ctx context.Context, id uuid.UUID) ([]database.GetWorkspaceAgentScriptTimingsByBuildIDRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - build, err := q.getWorkspaceBuildByIDNoLock(ctx, id) - if err != nil { - return nil, xerrors.Errorf("get build: %w", err) - } - - resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get resources: %w", err) - } - resourceIDs := make([]uuid.UUID, 0, len(resources)) - for _, res := range resources { - resourceIDs = append(resourceIDs, res.ID) - } - - agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) - if err != nil { - return nil, xerrors.Errorf("get agents: %w", err) - } - agentIDs := make([]uuid.UUID, 0, len(agents)) - for _, agent := range agents { - agentIDs = append(agentIDs, agent.ID) - } - - scripts, err := q.getWorkspaceAgentScriptsByAgentIDsNoLock(agentIDs) - if err != nil { - return nil, xerrors.Errorf("get scripts: %w", err) - } - scriptIDs := make([]uuid.UUID, 0, len(scripts)) - for _, script := range scripts { - scriptIDs = append(scriptIDs, script.ID) - } - - rows := []database.GetWorkspaceAgentScriptTimingsByBuildIDRow{} - for _, t := range q.workspaceAgentScriptTimings { - if !slice.Contains(scriptIDs, t.ScriptID) { - continue - } - - var script database.WorkspaceAgentScript - for _, s := range scripts { - if s.ID == t.ScriptID { - script = s - break - } - } - if script.ID == uuid.Nil { - return nil, xerrors.Errorf("script with ID %s not found", t.ScriptID) - } - - var agent database.WorkspaceAgent - for _, a := range agents { - if a.ID == script.WorkspaceAgentID { - agent = a - break - } - } - if agent.ID == uuid.Nil { - return nil, xerrors.Errorf("agent with ID %s not found", t.ScriptID) - } - - rows = append(rows, database.GetWorkspaceAgentScriptTimingsByBuildIDRow{ - ScriptID: t.ScriptID, - StartedAt: t.StartedAt, - EndedAt: t.EndedAt, - ExitCode: t.ExitCode, - Stage: t.Stage, - Status: t.Status, - DisplayName: script.DisplayName, - WorkspaceAgentID: agent.ID, - WorkspaceAgentName: agent.Name, - }) - } - - // We want to only return the first script run for each Script ID. - slices.SortFunc(rows, func(a, b database.GetWorkspaceAgentScriptTimingsByBuildIDRow) int { - return a.StartedAt.Compare(b.StartedAt) - }) - rows = slices.CompactFunc(rows, func(e1, e2 database.GetWorkspaceAgentScriptTimingsByBuildIDRow) bool { - return e1.ScriptID == e2.ScriptID - }) - - return rows, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentScriptsByAgentIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceAgentScript, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceAgentScriptsByAgentIDsNoLock(ids) -} - -func (q *FakeQuerier) GetWorkspaceAgentStats(_ context.Context, createdAfter time.Time) ([]database.GetWorkspaceAgentStatsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) || agentStat.CreatedAt.Equal(createdAfter) { - agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) - } - } - - latestAgentStats := map[uuid.UUID]database.WorkspaceAgentStat{} - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) || agentStat.CreatedAt.Equal(createdAfter) { - latestAgentStats[agentStat.AgentID] = agentStat - } - } - - statByAgent := map[uuid.UUID]database.GetWorkspaceAgentStatsRow{} - for agentID, agentStat := range latestAgentStats { - stat := statByAgent[agentID] - stat.AgentID = agentStat.AgentID - stat.TemplateID = agentStat.TemplateID - stat.UserID = agentStat.UserID - stat.WorkspaceID = agentStat.WorkspaceID - stat.SessionCountVSCode += agentStat.SessionCountVSCode - stat.SessionCountJetBrains += agentStat.SessionCountJetBrains - stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - stat.SessionCountSSH += agentStat.SessionCountSSH - statByAgent[stat.AgentID] = stat - } - - latenciesByAgent := map[uuid.UUID][]float64{} - minimumDateByAgent := map[uuid.UUID]time.Time{} - for _, agentStat := range agentStatsCreatedAfter { - if agentStat.ConnectionMedianLatencyMS <= 0 { - continue - } - stat := statByAgent[agentStat.AgentID] - minimumDate := minimumDateByAgent[agentStat.AgentID] - if agentStat.CreatedAt.Before(minimumDate) || minimumDate.IsZero() { - minimumDateByAgent[agentStat.AgentID] = agentStat.CreatedAt - } - stat.WorkspaceRxBytes += agentStat.RxBytes - stat.WorkspaceTxBytes += agentStat.TxBytes - statByAgent[agentStat.AgentID] = stat - latenciesByAgent[agentStat.AgentID] = append(latenciesByAgent[agentStat.AgentID], agentStat.ConnectionMedianLatencyMS) - } - - for _, stat := range statByAgent { - stat.AggregatedFrom = minimumDateByAgent[stat.AgentID] - statByAgent[stat.AgentID] = stat - - latencies, ok := latenciesByAgent[stat.AgentID] - if !ok { - continue - } - stat.WorkspaceConnectionLatency50 = tryPercentileCont(latencies, 50) - stat.WorkspaceConnectionLatency95 = tryPercentileCont(latencies, 95) - statByAgent[stat.AgentID] = stat - } - - stats := make([]database.GetWorkspaceAgentStatsRow, 0, len(statByAgent)) - for _, agent := range statByAgent { - stats = append(stats, agent) - } - return stats, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentStatsAndLabels(ctx context.Context, createdAfter time.Time) ([]database.GetWorkspaceAgentStatsAndLabelsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) - latestAgentStats := map[uuid.UUID]database.WorkspaceAgentStat{} - - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) { - agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) - latestAgentStats[agentStat.AgentID] = agentStat - } - } - - statByAgent := map[uuid.UUID]database.GetWorkspaceAgentStatsAndLabelsRow{} - - // Session and connection metrics - for _, agentStat := range latestAgentStats { - stat := statByAgent[agentStat.AgentID] - stat.SessionCountVSCode += agentStat.SessionCountVSCode - stat.SessionCountJetBrains += agentStat.SessionCountJetBrains - stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - stat.SessionCountSSH += agentStat.SessionCountSSH - stat.ConnectionCount += agentStat.ConnectionCount - if agentStat.ConnectionMedianLatencyMS >= 0 && stat.ConnectionMedianLatencyMS < agentStat.ConnectionMedianLatencyMS { - stat.ConnectionMedianLatencyMS = agentStat.ConnectionMedianLatencyMS - } - statByAgent[agentStat.AgentID] = stat - } - - // Tx, Rx metrics - for _, agentStat := range agentStatsCreatedAfter { - stat := statByAgent[agentStat.AgentID] - stat.RxBytes += agentStat.RxBytes - stat.TxBytes += agentStat.TxBytes - statByAgent[agentStat.AgentID] = stat - } - - // Labels - for _, agentStat := range agentStatsCreatedAfter { - stat := statByAgent[agentStat.AgentID] - - user, err := q.getUserByIDNoLock(agentStat.UserID) - if err != nil { - return nil, err - } - - stat.Username = user.Username - - workspace, err := q.getWorkspaceByIDNoLock(ctx, agentStat.WorkspaceID) - if err != nil { - return nil, err - } - stat.WorkspaceName = workspace.Name - - agent, err := q.getWorkspaceAgentByIDNoLock(ctx, agentStat.AgentID) - if err != nil { - return nil, err - } - stat.AgentName = agent.Name - - statByAgent[agentStat.AgentID] = stat - } - - stats := make([]database.GetWorkspaceAgentStatsAndLabelsRow, 0, len(statByAgent)) - for _, agent := range statByAgent { - stats = append(stats, agent) - } - return stats, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentUsageStats(_ context.Context, createdAt time.Time) ([]database.GetWorkspaceAgentUsageStatsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - type agentStatsKey struct { - UserID uuid.UUID - AgentID uuid.UUID - WorkspaceID uuid.UUID - TemplateID uuid.UUID - } - - type minuteStatsKey struct { - agentStatsKey - MinuteBucket time.Time - } - - latestAgentStats := map[agentStatsKey]database.GetWorkspaceAgentUsageStatsRow{} - latestAgentLatencies := map[agentStatsKey][]float64{} - for _, agentStat := range q.workspaceAgentStats { - key := agentStatsKey{ - UserID: agentStat.UserID, - AgentID: agentStat.AgentID, - WorkspaceID: agentStat.WorkspaceID, - TemplateID: agentStat.TemplateID, - } - if agentStat.CreatedAt.After(createdAt) { - val, ok := latestAgentStats[key] - if ok { - val.WorkspaceRxBytes += agentStat.RxBytes - val.WorkspaceTxBytes += agentStat.TxBytes - latestAgentStats[key] = val - } else { - latestAgentStats[key] = database.GetWorkspaceAgentUsageStatsRow{ - UserID: agentStat.UserID, - AgentID: agentStat.AgentID, - WorkspaceID: agentStat.WorkspaceID, - TemplateID: agentStat.TemplateID, - AggregatedFrom: createdAt, - WorkspaceRxBytes: agentStat.RxBytes, - WorkspaceTxBytes: agentStat.TxBytes, - } - } - - latencies, ok := latestAgentLatencies[key] - if !ok { - latestAgentLatencies[key] = []float64{agentStat.ConnectionMedianLatencyMS} - } else { - latestAgentLatencies[key] = append(latencies, agentStat.ConnectionMedianLatencyMS) - } - } - } - - for key, latencies := range latestAgentLatencies { - val, ok := latestAgentStats[key] - if ok { - val.WorkspaceConnectionLatency50 = tryPercentileCont(latencies, 50) - val.WorkspaceConnectionLatency95 = tryPercentileCont(latencies, 95) - } - latestAgentStats[key] = val - } - - type bucketRow struct { - database.GetWorkspaceAgentUsageStatsRow - MinuteBucket time.Time - } - - minuteBuckets := make(map[minuteStatsKey]bucketRow) - for _, agentStat := range q.workspaceAgentStats { - if agentStat.Usage && - (agentStat.CreatedAt.After(createdAt) || agentStat.CreatedAt.Equal(createdAt)) && - agentStat.CreatedAt.Before(time.Now().Truncate(time.Minute)) { - key := minuteStatsKey{ - agentStatsKey: agentStatsKey{ - UserID: agentStat.UserID, - AgentID: agentStat.AgentID, - WorkspaceID: agentStat.WorkspaceID, - TemplateID: agentStat.TemplateID, - }, - MinuteBucket: agentStat.CreatedAt.Truncate(time.Minute), - } - val, ok := minuteBuckets[key] - if ok { - val.SessionCountVSCode += agentStat.SessionCountVSCode - val.SessionCountJetBrains += agentStat.SessionCountJetBrains - val.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - val.SessionCountSSH += agentStat.SessionCountSSH - minuteBuckets[key] = val - } else { - minuteBuckets[key] = bucketRow{ - GetWorkspaceAgentUsageStatsRow: database.GetWorkspaceAgentUsageStatsRow{ - UserID: agentStat.UserID, - AgentID: agentStat.AgentID, - WorkspaceID: agentStat.WorkspaceID, - TemplateID: agentStat.TemplateID, - SessionCountVSCode: agentStat.SessionCountVSCode, - SessionCountSSH: agentStat.SessionCountSSH, - SessionCountJetBrains: agentStat.SessionCountJetBrains, - SessionCountReconnectingPTY: agentStat.SessionCountReconnectingPTY, - }, - MinuteBucket: agentStat.CreatedAt.Truncate(time.Minute), - } - } - } - } - - // Get the latest minute bucket for each agent. - latestBuckets := make(map[uuid.UUID]bucketRow) - for key, bucket := range minuteBuckets { - latest, ok := latestBuckets[key.AgentID] - if !ok || key.MinuteBucket.After(latest.MinuteBucket) { - latestBuckets[key.AgentID] = bucket - } - } - - for key, stat := range latestAgentStats { - bucket, ok := latestBuckets[stat.AgentID] - if ok { - stat.SessionCountVSCode = bucket.SessionCountVSCode - stat.SessionCountJetBrains = bucket.SessionCountJetBrains - stat.SessionCountReconnectingPTY = bucket.SessionCountReconnectingPTY - stat.SessionCountSSH = bucket.SessionCountSSH - } - latestAgentStats[key] = stat - } - return maps.Values(latestAgentStats), nil -} - -func (q *FakeQuerier) GetWorkspaceAgentUsageStatsAndLabels(_ context.Context, createdAt time.Time) ([]database.GetWorkspaceAgentUsageStatsAndLabelsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - type statsKey struct { - AgentID uuid.UUID - UserID uuid.UUID - WorkspaceID uuid.UUID - } - - latestAgentStats := map[statsKey]database.WorkspaceAgentStat{} - maxConnMedianLatency := 0.0 - for _, agentStat := range q.workspaceAgentStats { - key := statsKey{ - AgentID: agentStat.AgentID, - UserID: agentStat.UserID, - WorkspaceID: agentStat.WorkspaceID, - } - // WHERE workspace_agent_stats.created_at > $1 - // GROUP BY user_id, agent_id, workspace_id - if agentStat.CreatedAt.After(createdAt) { - val, ok := latestAgentStats[key] - if !ok { - val = agentStat - val.SessionCountJetBrains = 0 - val.SessionCountReconnectingPTY = 0 - val.SessionCountSSH = 0 - val.SessionCountVSCode = 0 - } else { - val.RxBytes += agentStat.RxBytes - val.TxBytes += agentStat.TxBytes - } - if agentStat.ConnectionMedianLatencyMS > maxConnMedianLatency { - val.ConnectionMedianLatencyMS = agentStat.ConnectionMedianLatencyMS - } - latestAgentStats[key] = val - } - // WHERE usage = true AND created_at > now() - '1 minute'::interval - // GROUP BY user_id, agent_id, workspace_id - if agentStat.Usage && agentStat.CreatedAt.After(dbtime.Now().Add(-time.Minute)) { - val, ok := latestAgentStats[key] - if !ok { - latestAgentStats[key] = agentStat - } else { - val.SessionCountVSCode += agentStat.SessionCountVSCode - val.SessionCountJetBrains += agentStat.SessionCountJetBrains - val.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - val.SessionCountSSH += agentStat.SessionCountSSH - val.ConnectionCount += agentStat.ConnectionCount - latestAgentStats[key] = val - } - } - } - - stats := make([]database.GetWorkspaceAgentUsageStatsAndLabelsRow, 0, len(latestAgentStats)) - for key, agentStat := range latestAgentStats { - user, err := q.getUserByIDNoLock(key.UserID) - if err != nil { - return nil, err - } - workspace, err := q.getWorkspaceByIDNoLock(context.Background(), key.WorkspaceID) - if err != nil { - return nil, err - } - agent, err := q.getWorkspaceAgentByIDNoLock(context.Background(), key.AgentID) - if err != nil { - return nil, err - } - stats = append(stats, database.GetWorkspaceAgentUsageStatsAndLabelsRow{ - Username: user.Username, - AgentName: agent.Name, - WorkspaceName: workspace.Name, - RxBytes: agentStat.RxBytes, - TxBytes: agentStat.TxBytes, - SessionCountVSCode: agentStat.SessionCountVSCode, - SessionCountSSH: agentStat.SessionCountSSH, - SessionCountJetBrains: agentStat.SessionCountJetBrains, - SessionCountReconnectingPTY: agentStat.SessionCountReconnectingPTY, - ConnectionCount: agentStat.ConnectionCount, - ConnectionMedianLatencyMS: agentStat.ConnectionMedianLatencyMS, - }) - } - return stats, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentsByParentID(_ context.Context, parentID uuid.UUID) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaceAgents := make([]database.WorkspaceAgent, 0) - for _, agent := range q.workspaceAgents { - if !agent.ParentID.Valid || agent.ParentID.UUID != parentID || agent.Deleted { - continue - } - - workspaceAgents = append(workspaceAgents, agent) - } - - return workspaceAgents, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) -} - -func (q *FakeQuerier) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]database.WorkspaceAgent, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - build, err := q.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams(arg)) - if err != nil { - return nil, err - } - - resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, build.JobID) - if err != nil { - return nil, err - } - - var resourceIDs []uuid.UUID - for _, resource := range resources { - resourceIDs = append(resourceIDs, resource.ID) - } - - return q.GetWorkspaceAgentsByResourceIDs(ctx, resourceIDs) -} - -func (q *FakeQuerier) GetWorkspaceAgentsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaceAgents := make([]database.WorkspaceAgent, 0) - for _, agent := range q.workspaceAgents { - if agent.Deleted { - continue - } - if agent.CreatedAt.After(after) { - workspaceAgents = append(workspaceAgents, agent) - } - } - return workspaceAgents, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Get latest build for workspace. - workspaceBuild, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspaceID) - if err != nil { - return nil, xerrors.Errorf("get latest workspace build: %w", err) - } - - // Get resources for build. - resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, workspaceBuild.JobID) - if err != nil { - return nil, xerrors.Errorf("get workspace resources: %w", err) - } - if len(resources) == 0 { - return []database.WorkspaceAgent{}, nil - } - - resourceIDs := make([]uuid.UUID, len(resources)) - for i, resource := range resources { - resourceIDs[i] = resource.ID - } - - agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) - if err != nil { - return nil, xerrors.Errorf("get workspace agents: %w", err) - } - - return agents, nil -} - -func (q *FakeQuerier) GetWorkspaceAppByAgentIDAndSlug(ctx context.Context, arg database.GetWorkspaceAppByAgentIDAndSlugParams) (database.WorkspaceApp, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceApp{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceAppByAgentIDAndSlugNoLock(ctx, arg) -} - -func (q *FakeQuerier) GetWorkspaceAppStatusesByAppIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceAppStatus, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - statuses := make([]database.WorkspaceAppStatus, 0) - for _, status := range q.workspaceAppStatuses { - for _, id := range ids { - if status.AppID == id { - statuses = append(statuses, status) - } - } - } - return statuses, nil -} - -func (q *FakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) ([]database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - apps := make([]database.WorkspaceApp, 0) - for _, app := range q.workspaceApps { - if app.AgentID == id { - apps = append(apps, app) - } - } - return apps, nil -} - -func (q *FakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - apps := make([]database.WorkspaceApp, 0) - for _, app := range q.workspaceApps { - for _, id := range ids { - if app.AgentID == id { - apps = append(apps, app) - break - } - } - } - return apps, nil -} - -func (q *FakeQuerier) GetWorkspaceAppsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - apps := make([]database.WorkspaceApp, 0) - for _, app := range q.workspaceApps { - if app.CreatedAt.After(after) { - apps = append(apps, app) - } - } - return apps, nil -} - -func (q *FakeQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceBuildByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetWorkspaceBuildByJobID(_ context.Context, jobID uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, build := range q.workspaceBuilds { - if build.JobID == jobID { - return q.workspaceBuildWithUserNoLock(build), nil - } - } - return database.WorkspaceBuild{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(_ context.Context, arg database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams) (database.WorkspaceBuild, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceBuild{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.WorkspaceID != arg.WorkspaceID { - continue - } - if workspaceBuild.BuildNumber != arg.BuildNumber { - continue - } - return q.workspaceBuildWithUserNoLock(workspaceBuild), nil - } - return database.WorkspaceBuild{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceBuildParameters(_ context.Context, workspaceBuildID uuid.UUID) ([]database.WorkspaceBuildParameter, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceBuildParametersNoLock(workspaceBuildID) -} - -func (q *FakeQuerier) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIDs []uuid.UUID) ([]database.WorkspaceBuildParameter, error) { - // No auth filter. - return q.GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, nil) -} - -func (q *FakeQuerier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - templateStats := map[uuid.UUID]database.GetWorkspaceBuildStatsByTemplatesRow{} - for _, wb := range q.workspaceBuilds { - job, err := q.getProvisionerJobByIDNoLock(ctx, wb.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job by ID: %w", err) - } - - if !job.CompletedAt.Valid { - continue - } - - if wb.CreatedAt.Before(since) { - continue - } - - w, err := q.getWorkspaceByIDNoLock(ctx, wb.WorkspaceID) - if err != nil { - return nil, xerrors.Errorf("get workspace by ID: %w", err) - } - - if _, ok := templateStats[w.TemplateID]; !ok { - t, err := q.getTemplateByIDNoLock(ctx, w.TemplateID) - if err != nil { - return nil, xerrors.Errorf("get template by ID: %w", err) - } - - templateStats[w.TemplateID] = database.GetWorkspaceBuildStatsByTemplatesRow{ - TemplateID: w.TemplateID, - TemplateName: t.Name, - TemplateDisplayName: t.DisplayName, - TemplateOrganizationID: w.OrganizationID, - } - } - - s := templateStats[w.TemplateID] - s.TotalBuilds++ - if job.JobStatus == database.ProvisionerJobStatusFailed { - s.FailedBuilds++ - } - templateStats[w.TemplateID] = s - } - - rows := make([]database.GetWorkspaceBuildStatsByTemplatesRow, 0, len(templateStats)) - for _, ts := range templateStats { - rows = append(rows, ts) - } - - sort.Slice(rows, func(i, j int) bool { - return rows[i].TemplateName < rows[j].TemplateName - }) - return rows, nil -} - -func (q *FakeQuerier) GetWorkspaceBuildsByWorkspaceID(_ context.Context, - params database.GetWorkspaceBuildsByWorkspaceIDParams, -) ([]database.WorkspaceBuild, error) { - if err := validateDatabaseType(params); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - history := make([]database.WorkspaceBuild, 0) - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.CreatedAt.Before(params.Since) { - continue - } - if workspaceBuild.WorkspaceID == params.WorkspaceID { - history = append(history, q.workspaceBuildWithUserNoLock(workspaceBuild)) - } - } - - // Order by build_number - slices.SortFunc(history, func(a, b database.WorkspaceBuild) int { - return slice.Descending(a.BuildNumber, b.BuildNumber) - }) - - if params.AfterID != uuid.Nil { - found := false - for i, v := range history { - if v.ID == params.AfterID { - // We want to return all builds after index i. - history = history[i+1:] - found = true - break - } - } - - // If no builds after the time, then we return an empty list. - if !found { - return nil, sql.ErrNoRows - } - } - - if params.OffsetOpt > 0 { - if int(params.OffsetOpt) > len(history)-1 { - return nil, sql.ErrNoRows - } - history = history[params.OffsetOpt:] - } - - if params.LimitOpt > 0 { - if int(params.LimitOpt) > len(history) { - // #nosec G115 - Safe conversion as history slice length is expected to be within int32 range - params.LimitOpt = int32(len(history)) - } - history = history[:params.LimitOpt] - } - - if len(history) == 0 { - return nil, sql.ErrNoRows - } - return history, nil -} - -func (q *FakeQuerier) GetWorkspaceBuildsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaceBuilds := make([]database.WorkspaceBuild, 0) - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.CreatedAt.After(after) { - workspaceBuilds = append(workspaceBuilds, q.workspaceBuildWithUserNoLock(workspaceBuild)) - } - } - return workspaceBuilds, nil -} - -func (q *FakeQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - w, err := q.getWorkspaceByAgentIDNoLock(ctx, agentID) - if err != nil { - return database.Workspace{}, err - } - - return w, nil -} - -func (q *FakeQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Workspace{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - var found *database.WorkspaceTable - for _, workspace := range q.workspaces { - if workspace.OwnerID != arg.OwnerID { - continue - } - if !strings.EqualFold(workspace.Name, arg.Name) { - continue - } - if workspace.Deleted != arg.Deleted { - continue - } - - // Return the most recent workspace with the given name - if found == nil || workspace.CreatedAt.After(found.CreatedAt) { - found = &workspace - } - } - if found != nil { - return q.extendWorkspace(*found), nil - } - return database.Workspace{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceByResourceID(ctx context.Context, resourceID uuid.UUID) (database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, resource := range q.workspaceResources { - if resource.ID != resourceID { - continue - } - - for _, build := range q.workspaceBuilds { - if build.JobID != resource.JobID { - continue - } - - for _, workspace := range q.workspaces { - if workspace.ID != build.WorkspaceID { - continue - } - - return q.extendWorkspace(workspace), nil - } - } - } - - return database.Workspace{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) { - if err := validateDatabaseType(workspaceAppID); err != nil { - return database.Workspace{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, workspaceApp := range q.workspaceApps { - if workspaceApp.ID == workspaceAppID { - return q.getWorkspaceByAgentIDNoLock(context.Background(), workspaceApp.AgentID) - } - } - return database.Workspace{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceModulesByJobID(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceModule, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - modules := make([]database.WorkspaceModule, 0) - for _, module := range q.workspaceModules { - if module.JobID == jobID { - modules = append(modules, module) - } - } - return modules, nil -} - -func (q *FakeQuerier) GetWorkspaceModulesCreatedAfter(_ context.Context, createdAt time.Time) ([]database.WorkspaceModule, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - modules := make([]database.WorkspaceModule, 0) - for _, module := range q.workspaceModules { - if module.CreatedAt.After(createdAt) { - modules = append(modules, module) - } - } - return modules, nil -} - -func (q *FakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.WorkspaceProxy, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - cpy := make([]database.WorkspaceProxy, 0, len(q.workspaceProxies)) - - for _, p := range q.workspaceProxies { - if !p.Deleted { - cpy = append(cpy, p) - } - } - return cpy, nil -} - -func (q *FakeQuerier) GetWorkspaceProxyByHostname(_ context.Context, params database.GetWorkspaceProxyByHostnameParams) (database.WorkspaceProxy, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Return zero rows if this is called with a non-sanitized hostname. The SQL - // version of this query does the same thing. - if !validProxyByHostnameRegex.MatchString(params.Hostname) { - return database.WorkspaceProxy{}, sql.ErrNoRows - } - - // This regex matches the SQL version. - accessURLRegex := regexp.MustCompile(`[^:]*://` + regexp.QuoteMeta(params.Hostname) + `([:/]?.)*`) - - for _, proxy := range q.workspaceProxies { - if proxy.Deleted { - continue - } - if params.AllowAccessUrl && accessURLRegex.MatchString(proxy.Url) { - return proxy, nil - } - - // Compile the app hostname regex. This is slow sadly. - if params.AllowWildcardHostname { - wildcardRegexp, err := appurl.CompileHostnamePattern(proxy.WildcardHostname) - if err != nil { - return database.WorkspaceProxy{}, xerrors.Errorf("compile hostname pattern %q for proxy %q (%s): %w", proxy.WildcardHostname, proxy.Name, proxy.ID.String(), err) - } - if _, ok := appurl.ExecuteHostnamePattern(wildcardRegexp, params.Hostname); ok { - return proxy, nil - } - } - } - - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (database.WorkspaceProxy, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, proxy := range q.workspaceProxies { - if proxy.ID == id { - return proxy, nil - } - } - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceProxyByName(_ context.Context, name string) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, proxy := range q.workspaceProxies { - if proxy.Deleted { - continue - } - if proxy.Name == name { - return proxy, nil - } - } - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID) (database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, resource := range q.workspaceResources { - if resource.ID == id { - return resource, nil - } - } - return database.WorkspaceResource{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceResourceMetadataByResourceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceResourceMetadatum, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - metadata := make([]database.WorkspaceResourceMetadatum, 0) - for _, metadatum := range q.workspaceResourceMetadata { - for _, id := range ids { - if metadatum.WorkspaceResourceID == id { - metadata = append(metadata, metadatum) - } - } - } - return metadata, nil -} - -func (q *FakeQuerier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, after time.Time) ([]database.WorkspaceResourceMetadatum, error) { - resources, err := q.GetWorkspaceResourcesCreatedAfter(ctx, after) - if err != nil { - return nil, err - } - resourceIDs := map[uuid.UUID]struct{}{} - for _, resource := range resources { - resourceIDs[resource.ID] = struct{}{} - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - metadata := make([]database.WorkspaceResourceMetadatum, 0) - for _, m := range q.workspaceResourceMetadata { - _, ok := resourceIDs[m.WorkspaceResourceID] - if !ok { - continue - } - metadata = append(metadata, m) - } - return metadata, nil -} - -func (q *FakeQuerier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceResourcesByJobIDNoLock(ctx, jobID) -} - -func (q *FakeQuerier) GetWorkspaceResourcesByJobIDs(_ context.Context, jobIDs []uuid.UUID) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.workspaceResources { - for _, jobID := range jobIDs { - if resource.JobID != jobID { - continue - } - resources = append(resources, resource) - } - } - return resources, nil -} - -func (q *FakeQuerier) GetWorkspaceResourcesCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.workspaceResources { - if resource.CreatedAt.After(after) { - resources = append(resources, resource) - } - } - return resources, nil -} - -func (q *FakeQuerier) GetWorkspaceUniqueOwnerCountByTemplateIDs(_ context.Context, templateIds []uuid.UUID) ([]database.GetWorkspaceUniqueOwnerCountByTemplateIDsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaceOwners := make(map[uuid.UUID]map[uuid.UUID]struct{}) - for _, workspace := range q.workspaces { - if workspace.Deleted { - continue - } - if !slices.Contains(templateIds, workspace.TemplateID) { - continue - } - _, ok := workspaceOwners[workspace.TemplateID] - if !ok { - workspaceOwners[workspace.TemplateID] = make(map[uuid.UUID]struct{}) - } - workspaceOwners[workspace.TemplateID][workspace.OwnerID] = struct{}{} - } - resp := make([]database.GetWorkspaceUniqueOwnerCountByTemplateIDsRow, 0) - for _, templateID := range templateIds { - count := len(workspaceOwners[templateID]) - resp = append(resp, database.GetWorkspaceUniqueOwnerCountByTemplateIDsRow{ - TemplateID: templateID, - UniqueOwnersSum: int64(count), - }) - } - - return resp, nil -} - -func (q *FakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.GetWorkspacesRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - // A nil auth filter means no auth filter. - workspaceRows, err := q.GetAuthorizedWorkspaces(ctx, arg, nil) - return workspaceRows, err -} - -func (q *FakeQuerier) GetWorkspacesAndAgentsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error) { - // No auth filter. - return q.GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx, ownerID, nil) -} - -func (q *FakeQuerier) GetWorkspacesByTemplateID(_ context.Context, templateID uuid.UUID) ([]database.WorkspaceTable, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaces := []database.WorkspaceTable{} - for _, workspace := range q.workspaces { - if workspace.TemplateID == templateID { - workspaces = append(workspaces, workspace) - } - } - - return workspaces, nil -} - -func (q *FakeQuerier) GetWorkspacesEligibleForTransition(ctx context.Context, now time.Time) ([]database.GetWorkspacesEligibleForTransitionRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaces := []database.GetWorkspacesEligibleForTransitionRow{} - for _, workspace := range q.workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get workspace build by ID: %w", err) - } - - user, err := q.getUserByIDNoLock(workspace.OwnerID) - if err != nil { - return nil, xerrors.Errorf("get user by ID: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job by ID: %w", err) - } - - template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) - if err != nil { - return nil, xerrors.Errorf("get template by ID: %w", err) - } - - if workspace.Deleted { - continue - } - - if job.JobStatus != database.ProvisionerJobStatusFailed && - !workspace.DormantAt.Valid && - build.Transition == database.WorkspaceTransitionStart && - (user.Status == database.UserStatusSuspended || (!build.Deadline.IsZero() && build.Deadline.Before(now))) { - workspaces = append(workspaces, database.GetWorkspacesEligibleForTransitionRow{ - ID: workspace.ID, - Name: workspace.Name, - }) - continue - } - - if user.Status == database.UserStatusActive && - job.JobStatus != database.ProvisionerJobStatusFailed && - build.Transition == database.WorkspaceTransitionStop && - workspace.AutostartSchedule.Valid && - // We do not know if workspace with a zero next start is eligible - // for autostart, so we accept this false-positive. This can occur - // when a coder version is upgraded and next_start_at has yet to - // be set. - (workspace.NextStartAt.Time.IsZero() || - !now.Before(workspace.NextStartAt.Time)) { - workspaces = append(workspaces, database.GetWorkspacesEligibleForTransitionRow{ - ID: workspace.ID, - Name: workspace.Name, - }) - continue - } - - if !workspace.DormantAt.Valid && - template.TimeTilDormant > 0 && - now.Sub(workspace.LastUsedAt) >= time.Duration(template.TimeTilDormant) { - workspaces = append(workspaces, database.GetWorkspacesEligibleForTransitionRow{ - ID: workspace.ID, - Name: workspace.Name, - }) - continue - } - - if workspace.DormantAt.Valid && - workspace.DeletingAt.Valid && - workspace.DeletingAt.Time.Before(now) && - template.TimeTilDormantAutoDelete > 0 { - if build.Transition == database.WorkspaceTransitionDelete && - job.JobStatus == database.ProvisionerJobStatusFailed { - if job.CanceledAt.Valid && now.Sub(job.CanceledAt.Time) <= 24*time.Hour { - continue - } - - if job.CompletedAt.Valid && now.Sub(job.CompletedAt.Time) <= 24*time.Hour { - continue - } - } - - workspaces = append(workspaces, database.GetWorkspacesEligibleForTransitionRow{ - ID: workspace.ID, - Name: workspace.Name, - }) - continue - } - - if template.FailureTTL > 0 && - build.Transition == database.WorkspaceTransitionStart && - job.JobStatus == database.ProvisionerJobStatusFailed && - job.CompletedAt.Valid && - now.Sub(job.CompletedAt.Time) > time.Duration(template.FailureTTL) { - workspaces = append(workspaces, database.GetWorkspacesEligibleForTransitionRow{ - ID: workspace.ID, - Name: workspace.Name, - }) - continue - } - } - - return workspaces, nil -} - -func (q *FakeQuerier) HasTemplateVersionsWithAITask(_ context.Context) (bool, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, templateVersion := range q.templateVersions { - if templateVersion.HasAITask.Valid && templateVersion.HasAITask.Bool { - return true, nil - } - } - - return false, nil -} - -func (q *FakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { - if err := validateDatabaseType(arg); err != nil { - return database.APIKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - if arg.LifetimeSeconds == 0 { - arg.LifetimeSeconds = 86400 - } - - for _, u := range q.users { - if u.ID == arg.UserID && u.Deleted { - return database.APIKey{}, xerrors.Errorf("refusing to create APIKey for deleted user") - } - } - - //nolint:gosimple - key := database.APIKey{ - ID: arg.ID, - LifetimeSeconds: arg.LifetimeSeconds, - HashedSecret: arg.HashedSecret, - IPAddress: arg.IPAddress, - UserID: arg.UserID, - ExpiresAt: arg.ExpiresAt, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - LastUsed: arg.LastUsed, - LoginType: arg.LoginType, - Scope: arg.Scope, - TokenName: arg.TokenName, - } - q.apiKeys = append(q.apiKeys, key) - return key, nil -} - -func (q *FakeQuerier) InsertAllUsersGroup(ctx context.Context, orgID uuid.UUID) (database.Group, error) { - return q.InsertGroup(ctx, database.InsertGroupParams{ - ID: orgID, - Name: database.EveryoneGroup, - DisplayName: "", - OrganizationID: orgID, - AvatarURL: "", - QuotaAllowance: 0, - }) -} - -func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAuditLogParams) (database.AuditLog, error) { - if err := validateDatabaseType(arg); err != nil { - return database.AuditLog{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - alog := database.AuditLog(arg) - - q.auditLogs = append(q.auditLogs, alog) - slices.SortFunc(q.auditLogs, func(a, b database.AuditLog) int { - if a.Time.Before(b.Time) { - return -1 - } else if a.Time.Equal(b.Time) { - return 0 - } - return 1 - }) - - return alog, nil -} - -func (q *FakeQuerier) InsertCryptoKey(_ context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.CryptoKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - key := database.CryptoKey{ - Feature: arg.Feature, - Sequence: arg.Sequence, - Secret: arg.Secret, - SecretKeyID: arg.SecretKeyID, - StartsAt: arg.StartsAt, - } - - q.cryptoKeys = append(q.cryptoKeys, key) - - return key, nil -} - -func (q *FakeQuerier) InsertCustomRole(_ context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.CustomRole{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - for i := range q.customRoles { - if strings.EqualFold(q.customRoles[i].Name, arg.Name) && - q.customRoles[i].OrganizationID.UUID == arg.OrganizationID.UUID { - return database.CustomRole{}, errUniqueConstraint - } - } - - role := database.CustomRole{ - ID: uuid.New(), - Name: arg.Name, - DisplayName: arg.DisplayName, - OrganizationID: arg.OrganizationID, - SitePermissions: arg.SitePermissions, - OrgPermissions: arg.OrgPermissions, - UserPermissions: arg.UserPermissions, - CreatedAt: dbtime.Now(), - UpdatedAt: dbtime.Now(), - } - q.customRoles = append(q.customRoles, role) - - return role, nil -} - -func (q *FakeQuerier) InsertDBCryptKey(_ context.Context, arg database.InsertDBCryptKeyParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - for _, key := range q.dbcryptKeys { - if key.Number == arg.Number { - return errUniqueConstraint - } - } - - q.dbcryptKeys = append(q.dbcryptKeys, database.DBCryptKey{ - Number: arg.Number, - ActiveKeyDigest: sql.NullString{String: arg.ActiveKeyDigest, Valid: true}, - Test: arg.Test, - }) - return nil -} - -func (q *FakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.derpMeshKey = id - return nil -} - -func (q *FakeQuerier) InsertDeploymentID(_ context.Context, id string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.deploymentID = id - return nil -} - -func (q *FakeQuerier) InsertExternalAuthLink(_ context.Context, arg database.InsertExternalAuthLinkParams) (database.ExternalAuthLink, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ExternalAuthLink{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - // nolint:gosimple - gitAuthLink := database.ExternalAuthLink{ - ProviderID: arg.ProviderID, - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OAuthAccessToken: arg.OAuthAccessToken, - OAuthAccessTokenKeyID: arg.OAuthAccessTokenKeyID, - OAuthRefreshToken: arg.OAuthRefreshToken, - OAuthRefreshTokenKeyID: arg.OAuthRefreshTokenKeyID, - OAuthExpiry: arg.OAuthExpiry, - OAuthExtra: arg.OAuthExtra, - } - q.externalAuthLinks = append(q.externalAuthLinks, gitAuthLink) - return gitAuthLink, nil -} - -func (q *FakeQuerier) InsertFile(_ context.Context, arg database.InsertFileParams) (database.File, error) { - if err := validateDatabaseType(arg); err != nil { - return database.File{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - if slices.ContainsFunc(q.files, func(file database.File) bool { - return file.CreatedBy == arg.CreatedBy && file.Hash == arg.Hash - }) { - return database.File{}, newUniqueConstraintError(database.UniqueFilesHashCreatedByKey) - } - - //nolint:gosimple - file := database.File{ - ID: arg.ID, - Hash: arg.Hash, - CreatedAt: arg.CreatedAt, - CreatedBy: arg.CreatedBy, - Mimetype: arg.Mimetype, - Data: arg.Data, - } - q.files = append(q.files, file) - return file, nil -} - -func (q *FakeQuerier) InsertGitSSHKey(_ context.Context, arg database.InsertGitSSHKeyParams) (database.GitSSHKey, error) { - if err := validateDatabaseType(arg); err != nil { - return database.GitSSHKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - gitSSHKey := database.GitSSHKey{ - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - PrivateKey: arg.PrivateKey, - PublicKey: arg.PublicKey, - } - q.gitSSHKey = append(q.gitSSHKey, gitSSHKey) - return gitSSHKey, nil -} - -func (q *FakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupParams) (database.Group, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Group{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, group := range q.groups { - if group.OrganizationID == arg.OrganizationID && - group.Name == arg.Name { - return database.Group{}, errUniqueConstraint - } - } - - //nolint:gosimple - group := database.Group{ - ID: arg.ID, - Name: arg.Name, - DisplayName: arg.DisplayName, - OrganizationID: arg.OrganizationID, - AvatarURL: arg.AvatarURL, - QuotaAllowance: arg.QuotaAllowance, - Source: database.GroupSourceUser, - } - - q.groups = append(q.groups, group) - - return group, nil -} - -func (q *FakeQuerier) InsertGroupMember(_ context.Context, arg database.InsertGroupMemberParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, member := range q.groupMembers { - if member.GroupID == arg.GroupID && - member.UserID == arg.UserID { - return errUniqueConstraint - } - } - - //nolint:gosimple - q.groupMembers = append(q.groupMembers, database.GroupMemberTable{ - GroupID: arg.GroupID, - UserID: arg.UserID, - }) - - return nil -} - -func (q *FakeQuerier) InsertInboxNotification(_ context.Context, arg database.InsertInboxNotificationParams) (database.InboxNotification, error) { - if err := validateDatabaseType(arg); err != nil { - return database.InboxNotification{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - notification := database.InboxNotification{ - ID: arg.ID, - UserID: arg.UserID, - TemplateID: arg.TemplateID, - Targets: arg.Targets, - Title: arg.Title, - Content: arg.Content, - Icon: arg.Icon, - Actions: arg.Actions, - CreatedAt: arg.CreatedAt, - } - - q.inboxNotifications = append(q.inboxNotifications, notification) - return notification, nil -} - -func (q *FakeQuerier) InsertLicense( - _ context.Context, arg database.InsertLicenseParams, -) (database.License, error) { - if err := validateDatabaseType(arg); err != nil { - return database.License{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - l := database.License{ - ID: q.lastLicenseID + 1, - UploadedAt: arg.UploadedAt, - JWT: arg.JWT, - Exp: arg.Exp, - } - q.lastLicenseID = l.ID - q.licenses = append(q.licenses, l) - return l, nil -} - -func (q *FakeQuerier) InsertMemoryResourceMonitor(_ context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceAgentMemoryResourceMonitor{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:unconvert // The structs field-order differs so this is needed. - monitor := database.WorkspaceAgentMemoryResourceMonitor(database.WorkspaceAgentMemoryResourceMonitor{ - AgentID: arg.AgentID, - Enabled: arg.Enabled, - State: arg.State, - Threshold: arg.Threshold, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - DebouncedUntil: arg.DebouncedUntil, - }) - - q.workspaceAgentMemoryResourceMonitors = append(q.workspaceAgentMemoryResourceMonitors, monitor) - return monitor, nil -} - -func (q *FakeQuerier) InsertMissingGroups(_ context.Context, arg database.InsertMissingGroupsParams) ([]database.Group, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - groupNameMap := make(map[string]struct{}) - for _, g := range arg.GroupNames { - groupNameMap[g] = struct{}{} - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, g := range q.groups { - if g.OrganizationID != arg.OrganizationID { - continue - } - delete(groupNameMap, g.Name) - } - - newGroups := make([]database.Group, 0, len(groupNameMap)) - for k := range groupNameMap { - g := database.Group{ - ID: uuid.New(), - Name: k, - OrganizationID: arg.OrganizationID, - AvatarURL: "", - QuotaAllowance: 0, - DisplayName: "", - Source: arg.Source, - } - q.groups = append(q.groups, g) - newGroups = append(newGroups, g) - } - - return newGroups, nil -} - -func (q *FakeQuerier) InsertOAuth2ProviderApp(_ context.Context, arg database.InsertOAuth2ProviderAppParams) (database.OAuth2ProviderApp, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderApp{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple // Go wants database.OAuth2ProviderApp(arg), but we cannot be sure the structs will remain identical. - app := database.OAuth2ProviderApp{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Name: arg.Name, - Icon: arg.Icon, - CallbackURL: arg.CallbackURL, - RedirectUris: arg.RedirectUris, - ClientType: arg.ClientType, - DynamicallyRegistered: arg.DynamicallyRegistered, - ClientIDIssuedAt: arg.ClientIDIssuedAt, - ClientSecretExpiresAt: arg.ClientSecretExpiresAt, - GrantTypes: arg.GrantTypes, - ResponseTypes: arg.ResponseTypes, - TokenEndpointAuthMethod: arg.TokenEndpointAuthMethod, - Scope: arg.Scope, - Contacts: arg.Contacts, - ClientUri: arg.ClientUri, - LogoUri: arg.LogoUri, - TosUri: arg.TosUri, - PolicyUri: arg.PolicyUri, - JwksUri: arg.JwksUri, - Jwks: arg.Jwks, - SoftwareID: arg.SoftwareID, - SoftwareVersion: arg.SoftwareVersion, - RegistrationAccessToken: arg.RegistrationAccessToken, - RegistrationClientUri: arg.RegistrationClientUri, - } - - // Apply RFC-compliant defaults to match database migration defaults - if !app.ClientType.Valid { - app.ClientType = sql.NullString{String: "confidential", Valid: true} - } - if !app.DynamicallyRegistered.Valid { - app.DynamicallyRegistered = sql.NullBool{Bool: false, Valid: true} - } - if len(app.GrantTypes) == 0 { - app.GrantTypes = []string{"authorization_code", "refresh_token"} - } - if len(app.ResponseTypes) == 0 { - app.ResponseTypes = []string{"code"} - } - if !app.TokenEndpointAuthMethod.Valid { - app.TokenEndpointAuthMethod = sql.NullString{String: "client_secret_basic", Valid: true} - } - if !app.Scope.Valid { - app.Scope = sql.NullString{String: "", Valid: true} - } - if app.Contacts == nil { - app.Contacts = []string{} - } - q.oauth2ProviderApps = append(q.oauth2ProviderApps, app) - - return app, nil -} - -func (q *FakeQuerier) InsertOAuth2ProviderAppCode(_ context.Context, arg database.InsertOAuth2ProviderAppCodeParams) (database.OAuth2ProviderAppCode, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderAppCode{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, app := range q.oauth2ProviderApps { - if app.ID == arg.AppID { - code := database.OAuth2ProviderAppCode{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - ExpiresAt: arg.ExpiresAt, - SecretPrefix: arg.SecretPrefix, - HashedSecret: arg.HashedSecret, - UserID: arg.UserID, - AppID: arg.AppID, - ResourceUri: arg.ResourceUri, - CodeChallenge: arg.CodeChallenge, - CodeChallengeMethod: arg.CodeChallengeMethod, - } - q.oauth2ProviderAppCodes = append(q.oauth2ProviderAppCodes, code) - return code, nil - } - } - - return database.OAuth2ProviderAppCode{}, sql.ErrNoRows -} - -func (q *FakeQuerier) InsertOAuth2ProviderAppSecret(_ context.Context, arg database.InsertOAuth2ProviderAppSecretParams) (database.OAuth2ProviderAppSecret, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderAppSecret{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, app := range q.oauth2ProviderApps { - if app.ID == arg.AppID { - secret := database.OAuth2ProviderAppSecret{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - SecretPrefix: arg.SecretPrefix, - HashedSecret: arg.HashedSecret, - DisplaySecret: arg.DisplaySecret, - AppID: arg.AppID, - } - q.oauth2ProviderAppSecrets = append(q.oauth2ProviderAppSecrets, secret) - return secret, nil - } - } - - return database.OAuth2ProviderAppSecret{}, sql.ErrNoRows -} - -func (q *FakeQuerier) InsertOAuth2ProviderAppToken(_ context.Context, arg database.InsertOAuth2ProviderAppTokenParams) (database.OAuth2ProviderAppToken, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderAppToken{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, secret := range q.oauth2ProviderAppSecrets { - if secret.ID == arg.AppSecretID { - //nolint:gosimple // Go wants database.OAuth2ProviderAppToken(arg), but we cannot be sure the structs will remain identical. - token := database.OAuth2ProviderAppToken{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - ExpiresAt: arg.ExpiresAt, - HashPrefix: arg.HashPrefix, - RefreshHash: arg.RefreshHash, - APIKeyID: arg.APIKeyID, - AppSecretID: arg.AppSecretID, - UserID: arg.UserID, - Audience: arg.Audience, - } - q.oauth2ProviderAppTokens = append(q.oauth2ProviderAppTokens, token) - return token, nil - } - } - - return database.OAuth2ProviderAppToken{}, sql.ErrNoRows -} - -func (q *FakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Organization{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - organization := database.Organization{ - ID: arg.ID, - Name: arg.Name, - DisplayName: arg.DisplayName, - Description: arg.Description, - Icon: arg.Icon, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - IsDefault: len(q.organizations) == 0, - } - q.organizations = append(q.organizations, organization) - return organization, nil -} - -func (q *FakeQuerier) InsertOrganizationMember(_ context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { - if err := validateDatabaseType(arg); err != nil { - return database.OrganizationMember{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - if slices.IndexFunc(q.data.organizationMembers, func(member database.OrganizationMember) bool { - return member.OrganizationID == arg.OrganizationID && member.UserID == arg.UserID - }) >= 0 { - // Error pulled from a live db error - return database.OrganizationMember{}, &pq.Error{ - Severity: "ERROR", - Code: "23505", - Message: "duplicate key value violates unique constraint \"organization_members_pkey\"", - Detail: "Key (organization_id, user_id)=(f7de1f4e-5833-4410-a28d-0a105f96003f, 36052a80-4a7f-4998-a7ca-44cefa608d3e) already exists.", - Table: "organization_members", - Constraint: "organization_members_pkey", - } - } - - //nolint:gosimple - organizationMember := database.OrganizationMember{ - OrganizationID: arg.OrganizationID, - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Roles: arg.Roles, - } - q.organizationMembers = append(q.organizationMembers, organizationMember) - return organizationMember, nil -} - -func (q *FakeQuerier) InsertPreset(_ context.Context, arg database.InsertPresetParams) (database.TemplateVersionPreset, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.TemplateVersionPreset{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple // arg needs to keep its type for interface reasons and that type is not appropriate for preset below. - preset := database.TemplateVersionPreset{ - ID: uuid.New(), - TemplateVersionID: arg.TemplateVersionID, - Name: arg.Name, - CreatedAt: arg.CreatedAt, - DesiredInstances: arg.DesiredInstances, - InvalidateAfterSecs: sql.NullInt32{ - Int32: 0, - Valid: true, - }, - PrebuildStatus: database.PrebuildStatusHealthy, - IsDefault: arg.IsDefault, - } - q.presets = append(q.presets, preset) - return preset, nil -} - -func (q *FakeQuerier) InsertPresetParameters(_ context.Context, arg database.InsertPresetParametersParams) ([]database.TemplateVersionPresetParameter, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - presetParameters := make([]database.TemplateVersionPresetParameter, 0, len(arg.Names)) - for i, v := range arg.Names { - presetParameter := database.TemplateVersionPresetParameter{ - ID: uuid.New(), - TemplateVersionPresetID: arg.TemplateVersionPresetID, - Name: v, - Value: arg.Values[i], - } - presetParameters = append(presetParameters, presetParameter) - q.presetParameters = append(q.presetParameters, presetParameter) - } - - return presetParameters, nil -} - -func (q *FakeQuerier) InsertPresetPrebuildSchedule(ctx context.Context, arg database.InsertPresetPrebuildScheduleParams) (database.TemplateVersionPresetPrebuildSchedule, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.TemplateVersionPresetPrebuildSchedule{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - presetPrebuildSchedule := database.TemplateVersionPresetPrebuildSchedule{ - ID: uuid.New(), - PresetID: arg.PresetID, - CronExpression: arg.CronExpression, - DesiredInstances: arg.DesiredInstances, - } - q.presetPrebuildSchedules = append(q.presetPrebuildSchedules, presetPrebuildSchedule) - return presetPrebuildSchedule, nil -} - -func (q *FakeQuerier) InsertProvisionerJob(_ context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ProvisionerJob{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - job := database.ProvisionerJob{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OrganizationID: arg.OrganizationID, - InitiatorID: arg.InitiatorID, - Provisioner: arg.Provisioner, - StorageMethod: arg.StorageMethod, - FileID: arg.FileID, - Type: arg.Type, - Input: arg.Input, - Tags: maps.Clone(arg.Tags), - TraceMetadata: arg.TraceMetadata, - } - job.JobStatus = provisionerJobStatus(job) - q.provisionerJobs = append(q.provisionerJobs, job) - return job, nil -} - -func (q *FakeQuerier) InsertProvisionerJobLogs(_ context.Context, arg database.InsertProvisionerJobLogsParams) ([]database.ProvisionerJobLog, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - logs := make([]database.ProvisionerJobLog, 0) - id := int64(1) - if len(q.provisionerJobLogs) > 0 { - id = q.provisionerJobLogs[len(q.provisionerJobLogs)-1].ID - } - for index, output := range arg.Output { - id++ - logs = append(logs, database.ProvisionerJobLog{ - ID: id, - JobID: arg.JobID, - CreatedAt: arg.CreatedAt[index], - Source: arg.Source[index], - Level: arg.Level[index], - Stage: arg.Stage[index], - Output: output, - }) - } - q.provisionerJobLogs = append(q.provisionerJobLogs, logs...) - return logs, nil -} - -func (q *FakeQuerier) InsertProvisionerJobTimings(_ context.Context, arg database.InsertProvisionerJobTimingsParams) ([]database.ProvisionerJobTiming, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - insertedTimings := make([]database.ProvisionerJobTiming, 0, len(arg.StartedAt)) - for i := range arg.StartedAt { - timing := database.ProvisionerJobTiming{ - JobID: arg.JobID, - StartedAt: arg.StartedAt[i], - EndedAt: arg.EndedAt[i], - Stage: arg.Stage[i], - Source: arg.Source[i], - Action: arg.Action[i], - Resource: arg.Resource[i], - } - q.provisionerJobTimings = append(q.provisionerJobTimings, timing) - insertedTimings = append(insertedTimings, timing) - } - - return insertedTimings, nil -} - -func (q *FakeQuerier) InsertProvisionerKey(_ context.Context, arg database.InsertProvisionerKeyParams) (database.ProvisionerKey, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.ProvisionerKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, key := range q.provisionerKeys { - if key.ID == arg.ID || (key.OrganizationID == arg.OrganizationID && strings.EqualFold(key.Name, arg.Name)) { - return database.ProvisionerKey{}, newUniqueConstraintError(database.UniqueProvisionerKeysOrganizationIDNameIndex) - } - } - - //nolint:gosimple - provisionerKey := database.ProvisionerKey{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - OrganizationID: arg.OrganizationID, - Name: strings.ToLower(arg.Name), - HashedSecret: arg.HashedSecret, - Tags: arg.Tags, - } - q.provisionerKeys = append(q.provisionerKeys, provisionerKey) - - return provisionerKey, nil -} - -func (q *FakeQuerier) InsertReplica(_ context.Context, arg database.InsertReplicaParams) (database.Replica, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Replica{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - replica := database.Replica{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - StartedAt: arg.StartedAt, - UpdatedAt: arg.UpdatedAt, - Hostname: arg.Hostname, - RegionID: arg.RegionID, - RelayAddress: arg.RelayAddress, - Version: arg.Version, - DatabaseLatency: arg.DatabaseLatency, - Primary: arg.Primary, - } - q.replicas = append(q.replicas, replica) - return replica, nil -} - -func (q *FakeQuerier) InsertTelemetryItemIfNotExists(_ context.Context, arg database.InsertTelemetryItemIfNotExistsParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, item := range q.telemetryItems { - if item.Key == arg.Key { - return nil - } - } - - q.telemetryItems = append(q.telemetryItems, database.TelemetryItem{ - Key: arg.Key, - Value: arg.Value, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }) - return nil -} - -func (q *FakeQuerier) InsertTemplate(_ context.Context, arg database.InsertTemplateParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - template := database.TemplateTable{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OrganizationID: arg.OrganizationID, - Name: arg.Name, - Provisioner: arg.Provisioner, - ActiveVersionID: arg.ActiveVersionID, - Description: arg.Description, - CreatedBy: arg.CreatedBy, - UserACL: arg.UserACL, - GroupACL: arg.GroupACL, - DisplayName: arg.DisplayName, - Icon: arg.Icon, - AllowUserCancelWorkspaceJobs: arg.AllowUserCancelWorkspaceJobs, - AllowUserAutostart: true, - AllowUserAutostop: true, - MaxPortSharingLevel: arg.MaxPortSharingLevel, - UseClassicParameterFlow: arg.UseClassicParameterFlow, - } - q.templates = append(q.templates, template) - return nil -} - -func (q *FakeQuerier) InsertTemplateVersion(_ context.Context, arg database.InsertTemplateVersionParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - if len(arg.Message) > 1048576 { - return xerrors.New("message too long") - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - version := database.TemplateVersionTable{ - ID: arg.ID, - TemplateID: arg.TemplateID, - OrganizationID: arg.OrganizationID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Name: arg.Name, - Message: arg.Message, - Readme: arg.Readme, - JobID: arg.JobID, - CreatedBy: arg.CreatedBy, - SourceExampleID: arg.SourceExampleID, - } - q.templateVersions = append(q.templateVersions, version) - return nil -} - -func (q *FakeQuerier) InsertTemplateVersionParameter(_ context.Context, arg database.InsertTemplateVersionParameterParams) (database.TemplateVersionParameter, error) { - if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersionParameter{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - param := database.TemplateVersionParameter{ - TemplateVersionID: arg.TemplateVersionID, - Name: arg.Name, - DisplayName: arg.DisplayName, - Description: arg.Description, - Type: arg.Type, - FormType: arg.FormType, - Mutable: arg.Mutable, - DefaultValue: arg.DefaultValue, - Icon: arg.Icon, - Options: arg.Options, - ValidationError: arg.ValidationError, - ValidationRegex: arg.ValidationRegex, - ValidationMin: arg.ValidationMin, - ValidationMax: arg.ValidationMax, - ValidationMonotonic: arg.ValidationMonotonic, - Required: arg.Required, - DisplayOrder: arg.DisplayOrder, - Ephemeral: arg.Ephemeral, - } - q.templateVersionParameters = append(q.templateVersionParameters, param) - return param, nil -} - -func (q *FakeQuerier) InsertTemplateVersionTerraformValuesByJobID(_ context.Context, arg database.InsertTemplateVersionTerraformValuesByJobIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - // Find the template version by the job_id - templateVersion, ok := slice.Find(q.templateVersions, func(v database.TemplateVersionTable) bool { - return v.JobID == arg.JobID - }) - if !ok { - return sql.ErrNoRows - } - - if !json.Valid(arg.CachedPlan) { - return xerrors.Errorf("cached plan must be valid json, received %q", string(arg.CachedPlan)) - } - - // Insert the new row - row := database.TemplateVersionTerraformValue{ - TemplateVersionID: templateVersion.ID, - UpdatedAt: arg.UpdatedAt, - CachedPlan: arg.CachedPlan, - CachedModuleFiles: arg.CachedModuleFiles, - ProvisionerdVersion: arg.ProvisionerdVersion, - } - q.templateVersionTerraformValues = append(q.templateVersionTerraformValues, row) - return nil -} - -func (q *FakeQuerier) InsertTemplateVersionVariable(_ context.Context, arg database.InsertTemplateVersionVariableParams) (database.TemplateVersionVariable, error) { - if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersionVariable{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - variable := database.TemplateVersionVariable{ - TemplateVersionID: arg.TemplateVersionID, - Name: arg.Name, - Description: arg.Description, - Type: arg.Type, - Value: arg.Value, - DefaultValue: arg.DefaultValue, - Required: arg.Required, - Sensitive: arg.Sensitive, - } - q.templateVersionVariables = append(q.templateVersionVariables, variable) - return variable, nil -} - -func (q *FakeQuerier) InsertTemplateVersionWorkspaceTag(_ context.Context, arg database.InsertTemplateVersionWorkspaceTagParams) (database.TemplateVersionWorkspaceTag, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.TemplateVersionWorkspaceTag{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - workspaceTag := database.TemplateVersionWorkspaceTag{ - TemplateVersionID: arg.TemplateVersionID, - Key: arg.Key, - Value: arg.Value, - } - q.templateVersionWorkspaceTags = append(q.templateVersionWorkspaceTags, workspaceTag) - return workspaceTag, nil -} - -func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, user := range q.users { - if user.Username == arg.Username && !user.Deleted { - return database.User{}, errUniqueConstraint - } - } - - status := database.UserStatusDormant - if arg.Status != "" { - status = database.UserStatus(arg.Status) - } - - user := database.User{ - ID: arg.ID, - Email: arg.Email, - HashedPassword: arg.HashedPassword, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Username: arg.Username, - Name: arg.Name, - Status: status, - RBACRoles: arg.RBACRoles, - LoginType: arg.LoginType, - IsSystem: false, - } - q.users = append(q.users, user) - sort.Slice(q.users, func(i, j int) bool { - return q.users[i].CreatedAt.Before(q.users[j].CreatedAt) - }) - - q.userStatusChanges = append(q.userStatusChanges, database.UserStatusChange{ - UserID: user.ID, - NewStatus: user.Status, - ChangedAt: user.UpdatedAt, - }) - return user, nil -} - -func (q *FakeQuerier) InsertUserGroupsByID(_ context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - var groupIDs []uuid.UUID - for _, group := range q.groups { - for _, groupID := range arg.GroupIds { - if group.ID == groupID { - q.groupMembers = append(q.groupMembers, database.GroupMemberTable{ - UserID: arg.UserID, - GroupID: groupID, - }) - groupIDs = append(groupIDs, group.ID) - } - } - } - - return groupIDs, nil -} - -func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - var groupIDs []uuid.UUID - for _, group := range q.groups { - for _, groupName := range arg.GroupNames { - if group.Name == groupName { - groupIDs = append(groupIDs, group.ID) - } - } - } - - for _, groupID := range groupIDs { - q.groupMembers = append(q.groupMembers, database.GroupMemberTable{ - UserID: arg.UserID, - GroupID: groupID, - }) - } - - return nil -} - -func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - if u, err := q.getUserByIDNoLock(args.UserID); err == nil && u.Deleted { - return database.UserLink{}, deletedUserLinkError - } - - //nolint:gosimple - link := database.UserLink{ - UserID: args.UserID, - LoginType: args.LoginType, - LinkedID: args.LinkedID, - OAuthAccessToken: args.OAuthAccessToken, - OAuthAccessTokenKeyID: args.OAuthAccessTokenKeyID, - OAuthRefreshToken: args.OAuthRefreshToken, - OAuthRefreshTokenKeyID: args.OAuthRefreshTokenKeyID, - OAuthExpiry: args.OAuthExpiry, - Claims: args.Claims, - } - - q.userLinks = append(q.userLinks, link) - - return link, nil -} - -func (q *FakeQuerier) InsertVolumeResourceMonitor(_ context.Context, arg database.InsertVolumeResourceMonitorParams) (database.WorkspaceAgentVolumeResourceMonitor, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceAgentVolumeResourceMonitor{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - monitor := database.WorkspaceAgentVolumeResourceMonitor{ - AgentID: arg.AgentID, - Path: arg.Path, - Enabled: arg.Enabled, - State: arg.State, - Threshold: arg.Threshold, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - DebouncedUntil: arg.DebouncedUntil, - } - - q.workspaceAgentVolumeResourceMonitors = append(q.workspaceAgentVolumeResourceMonitors, monitor) - return monitor, nil -} - -func (q *FakeQuerier) InsertWebpushSubscription(_ context.Context, arg database.InsertWebpushSubscriptionParams) (database.WebpushSubscription, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WebpushSubscription{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - newSub := database.WebpushSubscription{ - ID: uuid.New(), - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - Endpoint: arg.Endpoint, - EndpointP256dhKey: arg.EndpointP256dhKey, - EndpointAuthKey: arg.EndpointAuthKey, - } - q.webpushSubscriptions = append(q.webpushSubscriptions, newSub) - return newSub, nil -} - -func (q *FakeQuerier) InsertWorkspace(_ context.Context, arg database.InsertWorkspaceParams) (database.WorkspaceTable, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceTable{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - workspace := database.WorkspaceTable{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OwnerID: arg.OwnerID, - OrganizationID: arg.OrganizationID, - TemplateID: arg.TemplateID, - Name: arg.Name, - AutostartSchedule: arg.AutostartSchedule, - Ttl: arg.Ttl, - LastUsedAt: arg.LastUsedAt, - AutomaticUpdates: arg.AutomaticUpdates, - NextStartAt: arg.NextStartAt, - } - q.workspaces = append(q.workspaces, workspace) - return workspace, nil -} - -func (q *FakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceAgent{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - agent := database.WorkspaceAgent{ - ID: arg.ID, - ParentID: arg.ParentID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - ResourceID: arg.ResourceID, - AuthToken: arg.AuthToken, - AuthInstanceID: arg.AuthInstanceID, - EnvironmentVariables: arg.EnvironmentVariables, - Name: arg.Name, - Architecture: arg.Architecture, - OperatingSystem: arg.OperatingSystem, - Directory: arg.Directory, - InstanceMetadata: arg.InstanceMetadata, - ResourceMetadata: arg.ResourceMetadata, - ConnectionTimeoutSeconds: arg.ConnectionTimeoutSeconds, - TroubleshootingURL: arg.TroubleshootingURL, - MOTDFile: arg.MOTDFile, - LifecycleState: database.WorkspaceAgentLifecycleStateCreated, - DisplayApps: arg.DisplayApps, - DisplayOrder: arg.DisplayOrder, - APIKeyScope: arg.APIKeyScope, - } - - q.workspaceAgents = append(q.workspaceAgents, agent) - return agent, nil -} - -func (q *FakeQuerier) InsertWorkspaceAgentDevcontainers(_ context.Context, arg database.InsertWorkspaceAgentDevcontainersParams) ([]database.WorkspaceAgentDevcontainer, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, agent := range q.workspaceAgents { - if agent.ID == arg.WorkspaceAgentID { - var devcontainers []database.WorkspaceAgentDevcontainer - for i, id := range arg.ID { - devcontainers = append(devcontainers, database.WorkspaceAgentDevcontainer{ - WorkspaceAgentID: arg.WorkspaceAgentID, - CreatedAt: arg.CreatedAt, - ID: id, - Name: arg.Name[i], - WorkspaceFolder: arg.WorkspaceFolder[i], - ConfigPath: arg.ConfigPath[i], - }) - } - q.workspaceAgentDevcontainers = append(q.workspaceAgentDevcontainers, devcontainers...) - return devcontainers, nil - } - } - - return nil, errForeignKeyConstraint -} - -func (q *FakeQuerier) InsertWorkspaceAgentLogSources(_ context.Context, arg database.InsertWorkspaceAgentLogSourcesParams) ([]database.WorkspaceAgentLogSource, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - logSources := make([]database.WorkspaceAgentLogSource, 0) - for index, source := range arg.ID { - logSource := database.WorkspaceAgentLogSource{ - ID: source, - WorkspaceAgentID: arg.WorkspaceAgentID, - CreatedAt: arg.CreatedAt, - DisplayName: arg.DisplayName[index], - Icon: arg.Icon[index], - } - logSources = append(logSources, logSource) - } - q.workspaceAgentLogSources = append(q.workspaceAgentLogSources, logSources...) - return logSources, nil -} - -func (q *FakeQuerier) InsertWorkspaceAgentLogs(_ context.Context, arg database.InsertWorkspaceAgentLogsParams) ([]database.WorkspaceAgentLog, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - logs := []database.WorkspaceAgentLog{} - id := int64(0) - if len(q.workspaceAgentLogs) > 0 { - id = q.workspaceAgentLogs[len(q.workspaceAgentLogs)-1].ID - } - outputLength := int32(0) - for index, output := range arg.Output { - id++ - logs = append(logs, database.WorkspaceAgentLog{ - ID: id, - AgentID: arg.AgentID, - CreatedAt: arg.CreatedAt, - Level: arg.Level[index], - LogSourceID: arg.LogSourceID, - Output: output, - }) - // #nosec G115 - Safe conversion as log output length is expected to be within int32 range - outputLength += int32(len(output)) - } - for index, agent := range q.workspaceAgents { - if agent.ID != arg.AgentID { - continue - } - // Greater than 1MB, same as the PostgreSQL constraint! - if agent.LogsLength+outputLength > (1 << 20) { - return nil, &pq.Error{ - Constraint: "max_logs_length", - Table: "workspace_agents", - } - } - agent.LogsLength += outputLength - q.workspaceAgents[index] = agent - break - } - q.workspaceAgentLogs = append(q.workspaceAgentLogs, logs...) - return logs, nil -} - -func (q *FakeQuerier) InsertWorkspaceAgentMetadata(_ context.Context, arg database.InsertWorkspaceAgentMetadataParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - metadatum := database.WorkspaceAgentMetadatum{ - WorkspaceAgentID: arg.WorkspaceAgentID, - Script: arg.Script, - DisplayName: arg.DisplayName, - Key: arg.Key, - Timeout: arg.Timeout, - Interval: arg.Interval, - DisplayOrder: arg.DisplayOrder, - } - - q.workspaceAgentMetadata = append(q.workspaceAgentMetadata, metadatum) - return nil -} - -func (q *FakeQuerier) InsertWorkspaceAgentScriptTimings(_ context.Context, arg database.InsertWorkspaceAgentScriptTimingsParams) (database.WorkspaceAgentScriptTiming, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceAgentScriptTiming{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - timing := database.WorkspaceAgentScriptTiming(arg) - q.workspaceAgentScriptTimings = append(q.workspaceAgentScriptTimings, timing) - - return timing, nil -} - -func (q *FakeQuerier) InsertWorkspaceAgentScripts(_ context.Context, arg database.InsertWorkspaceAgentScriptsParams) ([]database.WorkspaceAgentScript, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - scripts := make([]database.WorkspaceAgentScript, 0) - for index, source := range arg.LogSourceID { - script := database.WorkspaceAgentScript{ - LogSourceID: source, - WorkspaceAgentID: arg.WorkspaceAgentID, - ID: arg.ID[index], - LogPath: arg.LogPath[index], - Script: arg.Script[index], - Cron: arg.Cron[index], - StartBlocksLogin: arg.StartBlocksLogin[index], - RunOnStart: arg.RunOnStart[index], - RunOnStop: arg.RunOnStop[index], - TimeoutSeconds: arg.TimeoutSeconds[index], - CreatedAt: arg.CreatedAt, - } - scripts = append(scripts, script) - } - q.workspaceAgentScripts = append(q.workspaceAgentScripts, scripts...) - return scripts, nil -} - -func (q *FakeQuerier) InsertWorkspaceAgentStats(_ context.Context, arg database.InsertWorkspaceAgentStatsParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - var connectionsByProto []map[string]int64 - if err := json.Unmarshal(arg.ConnectionsByProto, &connectionsByProto); err != nil { - return err - } - for i := 0; i < len(arg.ID); i++ { - cbp, err := json.Marshal(connectionsByProto[i]) - if err != nil { - return xerrors.Errorf("failed to marshal connections_by_proto: %w", err) - } - stat := database.WorkspaceAgentStat{ - ID: arg.ID[i], - CreatedAt: arg.CreatedAt[i], - WorkspaceID: arg.WorkspaceID[i], - AgentID: arg.AgentID[i], - UserID: arg.UserID[i], - ConnectionsByProto: cbp, - ConnectionCount: arg.ConnectionCount[i], - RxPackets: arg.RxPackets[i], - RxBytes: arg.RxBytes[i], - TxPackets: arg.TxPackets[i], - TxBytes: arg.TxBytes[i], - TemplateID: arg.TemplateID[i], - SessionCountVSCode: arg.SessionCountVSCode[i], - SessionCountJetBrains: arg.SessionCountJetBrains[i], - SessionCountReconnectingPTY: arg.SessionCountReconnectingPTY[i], - SessionCountSSH: arg.SessionCountSSH[i], - ConnectionMedianLatencyMS: arg.ConnectionMedianLatencyMS[i], - Usage: arg.Usage[i], - } - q.workspaceAgentStats = append(q.workspaceAgentStats, stat) - } - - return nil -} - -func (q *FakeQuerier) InsertWorkspaceAppStats(_ context.Context, arg database.InsertWorkspaceAppStatsParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - -InsertWorkspaceAppStatsLoop: - for i := 0; i < len(arg.UserID); i++ { - stat := database.WorkspaceAppStat{ - ID: q.workspaceAppStatsLastInsertID + 1, - UserID: arg.UserID[i], - WorkspaceID: arg.WorkspaceID[i], - AgentID: arg.AgentID[i], - AccessMethod: arg.AccessMethod[i], - SlugOrPort: arg.SlugOrPort[i], - SessionID: arg.SessionID[i], - SessionStartedAt: arg.SessionStartedAt[i], - SessionEndedAt: arg.SessionEndedAt[i], - Requests: arg.Requests[i], - } - for j, s := range q.workspaceAppStats { - // Check unique constraint for upsert. - if s.UserID == stat.UserID && s.AgentID == stat.AgentID && s.SessionID == stat.SessionID { - q.workspaceAppStats[j].SessionEndedAt = stat.SessionEndedAt - q.workspaceAppStats[j].Requests = stat.Requests - continue InsertWorkspaceAppStatsLoop - } - } - q.workspaceAppStats = append(q.workspaceAppStats, stat) - q.workspaceAppStatsLastInsertID++ - } - - return nil -} - -func (q *FakeQuerier) InsertWorkspaceAppStatus(_ context.Context, arg database.InsertWorkspaceAppStatusParams) (database.WorkspaceAppStatus, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceAppStatus{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - status := database.WorkspaceAppStatus{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - WorkspaceID: arg.WorkspaceID, - AgentID: arg.AgentID, - AppID: arg.AppID, - State: arg.State, - Message: arg.Message, - Uri: arg.Uri, - } - q.workspaceAppStatuses = append(q.workspaceAppStatuses, status) - return status, nil -} - -func (q *FakeQuerier) InsertWorkspaceBuild(_ context.Context, arg database.InsertWorkspaceBuildParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - workspaceBuild := database.WorkspaceBuild{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - WorkspaceID: arg.WorkspaceID, - TemplateVersionID: arg.TemplateVersionID, - BuildNumber: arg.BuildNumber, - Transition: arg.Transition, - InitiatorID: arg.InitiatorID, - JobID: arg.JobID, - ProvisionerState: arg.ProvisionerState, - Deadline: arg.Deadline, - MaxDeadline: arg.MaxDeadline, - Reason: arg.Reason, - TemplateVersionPresetID: arg.TemplateVersionPresetID, - } - q.workspaceBuilds = append(q.workspaceBuilds, workspaceBuild) - return nil -} - -func (q *FakeQuerier) InsertWorkspaceBuildParameters(_ context.Context, arg database.InsertWorkspaceBuildParametersParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, name := range arg.Name { - q.workspaceBuildParameters = append(q.workspaceBuildParameters, database.WorkspaceBuildParameter{ - WorkspaceBuildID: arg.WorkspaceBuildID, - Name: name, - Value: arg.Value[index], - }) - } - return nil -} - -func (q *FakeQuerier) InsertWorkspaceModule(_ context.Context, arg database.InsertWorkspaceModuleParams) (database.WorkspaceModule, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceModule{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - workspaceModule := database.WorkspaceModule(arg) - q.workspaceModules = append(q.workspaceModules, workspaceModule) - return workspaceModule, nil -} - -func (q *FakeQuerier) InsertWorkspaceProxy(_ context.Context, arg database.InsertWorkspaceProxyParams) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - lastRegionID := int32(0) - for _, p := range q.workspaceProxies { - if !p.Deleted && p.Name == arg.Name { - return database.WorkspaceProxy{}, errUniqueConstraint - } - if p.RegionID > lastRegionID { - lastRegionID = p.RegionID - } - } - - p := database.WorkspaceProxy{ - ID: arg.ID, - Name: arg.Name, - DisplayName: arg.DisplayName, - Icon: arg.Icon, - DerpEnabled: arg.DerpEnabled, - DerpOnly: arg.DerpOnly, - TokenHashedSecret: arg.TokenHashedSecret, - RegionID: lastRegionID + 1, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Deleted: false, - } - q.workspaceProxies = append(q.workspaceProxies, p) - return p, nil -} - -func (q *FakeQuerier) InsertWorkspaceResource(_ context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceResource{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - resource := database.WorkspaceResource{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - JobID: arg.JobID, - Transition: arg.Transition, - Type: arg.Type, - Name: arg.Name, - Hide: arg.Hide, - Icon: arg.Icon, - DailyCost: arg.DailyCost, - ModulePath: arg.ModulePath, - } - q.workspaceResources = append(q.workspaceResources, resource) - return resource, nil -} - -func (q *FakeQuerier) InsertWorkspaceResourceMetadata(_ context.Context, arg database.InsertWorkspaceResourceMetadataParams) ([]database.WorkspaceResourceMetadatum, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - metadata := make([]database.WorkspaceResourceMetadatum, 0) - id := int64(1) - if len(q.workspaceResourceMetadata) > 0 { - id = q.workspaceResourceMetadata[len(q.workspaceResourceMetadata)-1].ID - } - for index, key := range arg.Key { - id++ - value := arg.Value[index] - metadata = append(metadata, database.WorkspaceResourceMetadatum{ - ID: id, - WorkspaceResourceID: arg.WorkspaceResourceID, - Key: key, - Value: sql.NullString{ - String: value, - Valid: value != "", - }, - Sensitive: arg.Sensitive[index], - }) - } - q.workspaceResourceMetadata = append(q.workspaceResourceMetadata, metadata...) - return metadata, nil -} - -func (q *FakeQuerier) ListProvisionerKeysByOrganization(_ context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - keys := make([]database.ProvisionerKey, 0) - for _, key := range q.provisionerKeys { - if key.OrganizationID == organizationID { - keys = append(keys, key) - } - } - - return keys, nil -} - -func (q *FakeQuerier) ListProvisionerKeysByOrganizationExcludeReserved(_ context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - keys := make([]database.ProvisionerKey, 0) - for _, key := range q.provisionerKeys { - if key.ID.String() == codersdk.ProvisionerKeyIDBuiltIn || - key.ID.String() == codersdk.ProvisionerKeyIDUserAuth || - key.ID.String() == codersdk.ProvisionerKeyIDPSK { - continue - } - if key.OrganizationID == organizationID { - keys = append(keys, key) - } - } - - return keys, nil -} - -func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - shares := []database.WorkspaceAgentPortShare{} - for _, share := range q.workspaceAgentPortShares { - if share.WorkspaceID == workspaceID { - shares = append(shares, share) - } - } - - return shares, nil -} - -func (q *FakeQuerier) MarkAllInboxNotificationsAsRead(_ context.Context, arg database.MarkAllInboxNotificationsAsReadParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - for idx, notif := range q.inboxNotifications { - if notif.UserID == arg.UserID && !notif.ReadAt.Valid { - q.inboxNotifications[idx].ReadAt = arg.ReadAt - } - } - - return nil -} - -// nolint:forcetypeassert -func (q *FakeQuerier) OIDCClaimFieldValues(_ context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) { - orgMembers := q.getOrganizationMemberNoLock(args.OrganizationID) - - var values []string - for _, link := range q.userLinks { - if args.OrganizationID != uuid.Nil { - inOrg := slices.ContainsFunc(orgMembers, func(organizationMember database.OrganizationMember) bool { - return organizationMember.UserID == link.UserID - }) - if !inOrg { - continue - } - } - - if link.LoginType != database.LoginTypeOIDC { - continue - } - - if len(link.Claims.MergedClaims) == 0 { - continue - } - - value, ok := link.Claims.MergedClaims[args.ClaimField] - if !ok { - continue - } - switch value := value.(type) { - case string: - values = append(values, value) - case []string: - values = append(values, value...) - case []any: - for _, v := range value { - if sv, ok := v.(string); ok { - values = append(values, sv) - } - } - default: - continue - } - } - - return slice.Unique(values), nil -} - -func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUID) ([]string, error) { - orgMembers := q.getOrganizationMemberNoLock(organizationID) - - var fields []string - for _, link := range q.userLinks { - if organizationID != uuid.Nil { - inOrg := slices.ContainsFunc(orgMembers, func(organizationMember database.OrganizationMember) bool { - return organizationMember.UserID == link.UserID - }) - if !inOrg { - continue - } - } - - if link.LoginType != database.LoginTypeOIDC { - continue - } - - for k := range link.Claims.MergedClaims { - fields = append(fields, k) - } - } - - return slice.Unique(fields), nil -} - -func (q *FakeQuerier) OrganizationMembers(_ context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) { - if err := validateDatabaseType(arg); err != nil { - return []database.OrganizationMembersRow{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - tmp := make([]database.OrganizationMembersRow, 0) - for _, organizationMember := range q.organizationMembers { - if arg.OrganizationID != uuid.Nil && organizationMember.OrganizationID != arg.OrganizationID { - continue - } - - if arg.UserID != uuid.Nil && organizationMember.UserID != arg.UserID { - continue - } - - user, _ := q.getUserByIDNoLock(organizationMember.UserID) - tmp = append(tmp, database.OrganizationMembersRow{ - OrganizationMember: organizationMember, - Username: user.Username, - AvatarURL: user.AvatarURL, - Name: user.Name, - Email: user.Email, - GlobalRoles: user.RBACRoles, - }) - } - return tmp, nil -} - -func (q *FakeQuerier) PaginatedOrganizationMembers(_ context.Context, arg database.PaginatedOrganizationMembersParams) ([]database.PaginatedOrganizationMembersRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // All of the members in the organization - orgMembers := make([]database.OrganizationMember, 0) - for _, mem := range q.organizationMembers { - if mem.OrganizationID != arg.OrganizationID { - continue - } - - orgMembers = append(orgMembers, mem) - } - - selectedMembers := make([]database.PaginatedOrganizationMembersRow, 0) - - skippedMembers := 0 - for _, organizationMember := range orgMembers { - if skippedMembers < int(arg.OffsetOpt) { - skippedMembers++ - continue - } - - // if the limit is set to 0 we treat that as returning all of the org members - if int(arg.LimitOpt) != 0 && len(selectedMembers) >= int(arg.LimitOpt) { - break - } - - user, _ := q.getUserByIDNoLock(organizationMember.UserID) - selectedMembers = append(selectedMembers, database.PaginatedOrganizationMembersRow{ - OrganizationMember: organizationMember, - Username: user.Username, - AvatarURL: user.AvatarURL, - Name: user.Name, - Email: user.Email, - GlobalRoles: user.RBACRoles, - Count: int64(len(orgMembers)), - }) - } - return selectedMembers, nil -} - -func (q *FakeQuerier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(_ context.Context, templateID uuid.UUID) error { - err := validateDatabaseType(templateID) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, workspace := range q.workspaces { - if workspace.TemplateID != templateID { - continue - } - for i, share := range q.workspaceAgentPortShares { - if share.WorkspaceID != workspace.ID { - continue - } - if share.ShareLevel == database.AppSharingLevelPublic { - share.ShareLevel = database.AppSharingLevelAuthenticated - } - q.workspaceAgentPortShares[i] = share - } - } - - return nil -} - -func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.RegisterWorkspaceProxyParams) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, p := range q.workspaceProxies { - if p.ID == arg.ID { - p.Url = arg.Url - p.WildcardHostname = arg.WildcardHostname - p.DerpEnabled = arg.DerpEnabled - p.DerpOnly = arg.DerpOnly - p.Version = arg.Version - p.UpdatedAt = dbtime.Now() - q.workspaceProxies[i] = p - return p, nil - } - } - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - newMembers := q.groupMembers[:0] - for _, member := range q.groupMembers { - if member.UserID == userID { - continue - } - newMembers = append(newMembers, member) - } - q.groupMembers = newMembers - - return nil -} - -func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - removed := make([]uuid.UUID, 0) - q.data.groupMembers = slices.DeleteFunc(q.data.groupMembers, func(groupMember database.GroupMemberTable) bool { - // Delete all group members that match the arguments. - if groupMember.UserID != arg.UserID { - // Not the right user, ignore. - return false - } - - if !slices.Contains(arg.GroupIds, groupMember.GroupID) { - return false - } - - removed = append(removed, groupMember.GroupID) - return true - }) - - return removed, nil -} - -func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i := range q.dbcryptKeys { - key := q.dbcryptKeys[i] - - // Is the key already revoked? - if !key.ActiveKeyDigest.Valid { - continue - } - - if key.ActiveKeyDigest.String != activeKeyDigest { - continue - } - - // Check for foreign key constraints. - for _, ul := range q.userLinks { - if (ul.OAuthAccessTokenKeyID.Valid && ul.OAuthAccessTokenKeyID.String == activeKeyDigest) || - (ul.OAuthRefreshTokenKeyID.Valid && ul.OAuthRefreshTokenKeyID.String == activeKeyDigest) { - return errForeignKeyConstraint - } - } - for _, gal := range q.externalAuthLinks { - if (gal.OAuthAccessTokenKeyID.Valid && gal.OAuthAccessTokenKeyID.String == activeKeyDigest) || - (gal.OAuthRefreshTokenKeyID.Valid && gal.OAuthRefreshTokenKeyID.String == activeKeyDigest) { - return errForeignKeyConstraint - } - } - - // Revoke the key. - q.dbcryptKeys[i].RevokedAt = sql.NullTime{Time: dbtime.Now(), Valid: true} - q.dbcryptKeys[i].RevokedKeyDigest = sql.NullString{String: key.ActiveKeyDigest.String, Valid: true} - q.dbcryptKeys[i].ActiveKeyDigest = sql.NullString{} - return nil - } - - return sql.ErrNoRows -} - -func (*FakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) { - return false, xerrors.New("TryAcquireLock must only be called within a transaction") -} - -func (q *FakeQuerier) UnarchiveTemplateVersion(_ context.Context, arg database.UnarchiveTemplateVersionParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, v := range q.data.templateVersions { - if v.ID == arg.TemplateVersionID { - v.Archived = false - v.UpdatedAt = arg.UpdatedAt - q.data.templateVersions[i] = v - return nil - } - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UnfavoriteWorkspace(_ context.Context, arg uuid.UUID) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i := 0; i < len(q.workspaces); i++ { - if q.workspaces[i].ID != arg { - continue - } - q.workspaces[i].Favorite = false - return nil - } - - return nil -} - -func (q *FakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPIKeyByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, apiKey := range q.apiKeys { - if apiKey.ID != arg.ID { - continue - } - apiKey.LastUsed = arg.LastUsed - apiKey.ExpiresAt = arg.ExpiresAt - apiKey.IPAddress = arg.IPAddress - q.apiKeys[index] = apiKey - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateCryptoKeyDeletesAt(_ context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.CryptoKey{}, err - } - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, key := range q.cryptoKeys { - if key.Feature == arg.Feature && key.Sequence == arg.Sequence { - key.DeletesAt = arg.DeletesAt - q.cryptoKeys[i] = key - return key, nil - } - } - - return database.CryptoKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateCustomRole(_ context.Context, arg database.UpdateCustomRoleParams) (database.CustomRole, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.CustomRole{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - for i := range q.customRoles { - if strings.EqualFold(q.customRoles[i].Name, arg.Name) && - q.customRoles[i].OrganizationID.UUID == arg.OrganizationID.UUID { - q.customRoles[i].DisplayName = arg.DisplayName - q.customRoles[i].OrganizationID = arg.OrganizationID - q.customRoles[i].SitePermissions = arg.SitePermissions - q.customRoles[i].OrgPermissions = arg.OrgPermissions - q.customRoles[i].UserPermissions = arg.UserPermissions - q.customRoles[i].UpdatedAt = dbtime.Now() - return q.customRoles[i], nil - } - } - return database.CustomRole{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateExternalAuthLink(_ context.Context, arg database.UpdateExternalAuthLinkParams) (database.ExternalAuthLink, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ExternalAuthLink{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - for index, gitAuthLink := range q.externalAuthLinks { - if gitAuthLink.ProviderID != arg.ProviderID { - continue - } - if gitAuthLink.UserID != arg.UserID { - continue - } - gitAuthLink.UpdatedAt = arg.UpdatedAt - gitAuthLink.OAuthAccessToken = arg.OAuthAccessToken - gitAuthLink.OAuthAccessTokenKeyID = arg.OAuthAccessTokenKeyID - gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken - gitAuthLink.OAuthRefreshTokenKeyID = arg.OAuthRefreshTokenKeyID - gitAuthLink.OAuthExpiry = arg.OAuthExpiry - gitAuthLink.OAuthExtra = arg.OAuthExtra - q.externalAuthLinks[index] = gitAuthLink - - return gitAuthLink, nil - } - return database.ExternalAuthLink{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateExternalAuthLinkRefreshToken(_ context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - for index, gitAuthLink := range q.externalAuthLinks { - if gitAuthLink.ProviderID != arg.ProviderID { - continue - } - if gitAuthLink.UserID != arg.UserID { - continue - } - gitAuthLink.UpdatedAt = arg.UpdatedAt - gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken - q.externalAuthLinks[index] = gitAuthLink - - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateGitSSHKey(_ context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { - if err := validateDatabaseType(arg); err != nil { - return database.GitSSHKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, key := range q.gitSSHKey { - if key.UserID != arg.UserID { - continue - } - key.UpdatedAt = arg.UpdatedAt - key.PrivateKey = arg.PrivateKey - key.PublicKey = arg.PublicKey - q.gitSSHKey[index] = key - return key, nil - } - return database.GitSSHKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateGroupByID(_ context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Group{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, group := range q.groups { - if group.ID == arg.ID { - group.DisplayName = arg.DisplayName - group.Name = arg.Name - group.AvatarURL = arg.AvatarURL - group.QuotaAllowance = arg.QuotaAllowance - q.groups[i] = group - return group, nil - } - } - return database.Group{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateInactiveUsersToDormant(_ context.Context, params database.UpdateInactiveUsersToDormantParams) ([]database.UpdateInactiveUsersToDormantRow, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - var updated []database.UpdateInactiveUsersToDormantRow - for index, user := range q.users { - if user.Status == database.UserStatusActive && user.LastSeenAt.Before(params.LastSeenAfter) && !user.IsSystem { - q.users[index].Status = database.UserStatusDormant - q.users[index].UpdatedAt = params.UpdatedAt - updated = append(updated, database.UpdateInactiveUsersToDormantRow{ - ID: user.ID, - Email: user.Email, - Username: user.Username, - LastSeenAt: user.LastSeenAt, - }) - q.userStatusChanges = append(q.userStatusChanges, database.UserStatusChange{ - UserID: user.ID, - NewStatus: database.UserStatusDormant, - ChangedAt: params.UpdatedAt, - }) - } - } - - if len(updated) == 0 { - return nil, sql.ErrNoRows - } - - return updated, nil -} - -func (q *FakeQuerier) UpdateInboxNotificationReadStatus(_ context.Context, arg database.UpdateInboxNotificationReadStatusParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i := range q.inboxNotifications { - if q.inboxNotifications[i].ID == arg.ID { - q.inboxNotifications[i].ReadAt = arg.ReadAt - } - } - - return nil -} - -func (q *FakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { - if err := validateDatabaseType(arg); err != nil { - return database.OrganizationMember{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, mem := range q.organizationMembers { - if mem.UserID == arg.UserID && mem.OrganizationID == arg.OrgID { - uniqueRoles := make([]string, 0, len(arg.GrantedRoles)) - exist := make(map[string]struct{}) - for _, r := range arg.GrantedRoles { - if _, ok := exist[r]; ok { - continue - } - exist[r] = struct{}{} - uniqueRoles = append(uniqueRoles, r) - } - sort.Strings(uniqueRoles) - - mem.Roles = uniqueRoles - q.organizationMembers[i] = mem - return mem, nil - } - } - - return database.OrganizationMember{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateMemoryResourceMonitor(_ context.Context, arg database.UpdateMemoryResourceMonitorParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, monitor := range q.workspaceAgentMemoryResourceMonitors { - if monitor.AgentID != arg.AgentID { - continue - } - - monitor.State = arg.State - monitor.UpdatedAt = arg.UpdatedAt - monitor.DebouncedUntil = arg.DebouncedUntil - q.workspaceAgentMemoryResourceMonitors[i] = monitor - return nil - } - - return nil -} - -func (*FakeQuerier) UpdateNotificationTemplateMethodByID(_ context.Context, _ database.UpdateNotificationTemplateMethodByIDParams) (database.NotificationTemplate, error) { - // Not implementing this function because it relies on state in the database which is created with migrations. - // We could consider using code-generation to align the database state and dbmem, but it's not worth it right now. - return database.NotificationTemplate{}, ErrUnimplemented -} - -func (q *FakeQuerier) UpdateOAuth2ProviderAppByClientID(ctx context.Context, arg database.UpdateOAuth2ProviderAppByClientIDParams) (database.OAuth2ProviderApp, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderApp{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, app := range q.oauth2ProviderApps { - if app.ID == arg.ID { - app.UpdatedAt = arg.UpdatedAt - app.Name = arg.Name - app.Icon = arg.Icon - app.CallbackURL = arg.CallbackURL - app.RedirectUris = arg.RedirectUris - app.GrantTypes = arg.GrantTypes - app.ResponseTypes = arg.ResponseTypes - app.TokenEndpointAuthMethod = arg.TokenEndpointAuthMethod - app.Scope = arg.Scope - app.Contacts = arg.Contacts - app.ClientUri = arg.ClientUri - app.LogoUri = arg.LogoUri - app.TosUri = arg.TosUri - app.PolicyUri = arg.PolicyUri - app.JwksUri = arg.JwksUri - app.Jwks = arg.Jwks - app.SoftwareID = arg.SoftwareID - app.SoftwareVersion = arg.SoftwareVersion - - // Apply RFC-compliant defaults to match database migration defaults - if !app.ClientType.Valid { - app.ClientType = sql.NullString{String: "confidential", Valid: true} - } - if !app.DynamicallyRegistered.Valid { - app.DynamicallyRegistered = sql.NullBool{Bool: false, Valid: true} - } - if len(app.GrantTypes) == 0 { - app.GrantTypes = []string{"authorization_code", "refresh_token"} - } - if len(app.ResponseTypes) == 0 { - app.ResponseTypes = []string{"code"} - } - if !app.TokenEndpointAuthMethod.Valid { - app.TokenEndpointAuthMethod = sql.NullString{String: "client_secret_basic", Valid: true} - } - if !app.Scope.Valid { - app.Scope = sql.NullString{String: "", Valid: true} - } - if app.Contacts == nil { - app.Contacts = []string{} - } - - q.oauth2ProviderApps[i] = app - return app, nil - } - } - return database.OAuth2ProviderApp{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateOAuth2ProviderAppByID(_ context.Context, arg database.UpdateOAuth2ProviderAppByIDParams) (database.OAuth2ProviderApp, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderApp{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, app := range q.oauth2ProviderApps { - if app.Name == arg.Name && app.ID != arg.ID { - return database.OAuth2ProviderApp{}, errUniqueConstraint - } - } - - for index, app := range q.oauth2ProviderApps { - if app.ID == arg.ID { - app.UpdatedAt = arg.UpdatedAt - app.Name = arg.Name - app.Icon = arg.Icon - app.CallbackURL = arg.CallbackURL - app.RedirectUris = arg.RedirectUris - app.ClientType = arg.ClientType - app.DynamicallyRegistered = arg.DynamicallyRegistered - app.ClientSecretExpiresAt = arg.ClientSecretExpiresAt - app.GrantTypes = arg.GrantTypes - app.ResponseTypes = arg.ResponseTypes - app.TokenEndpointAuthMethod = arg.TokenEndpointAuthMethod - app.Scope = arg.Scope - app.Contacts = arg.Contacts - app.ClientUri = arg.ClientUri - app.LogoUri = arg.LogoUri - app.TosUri = arg.TosUri - app.PolicyUri = arg.PolicyUri - app.JwksUri = arg.JwksUri - app.Jwks = arg.Jwks - app.SoftwareID = arg.SoftwareID - app.SoftwareVersion = arg.SoftwareVersion - - // Apply RFC-compliant defaults to match database migration defaults - if !app.ClientType.Valid { - app.ClientType = sql.NullString{String: "confidential", Valid: true} - } - if !app.DynamicallyRegistered.Valid { - app.DynamicallyRegistered = sql.NullBool{Bool: false, Valid: true} - } - if len(app.GrantTypes) == 0 { - app.GrantTypes = []string{"authorization_code", "refresh_token"} - } - if len(app.ResponseTypes) == 0 { - app.ResponseTypes = []string{"code"} - } - if !app.TokenEndpointAuthMethod.Valid { - app.TokenEndpointAuthMethod = sql.NullString{String: "client_secret_basic", Valid: true} - } - if !app.Scope.Valid { - app.Scope = sql.NullString{String: "", Valid: true} - } - if app.Contacts == nil { - app.Contacts = []string{} - } - - q.oauth2ProviderApps[index] = app - return app, nil - } - } - return database.OAuth2ProviderApp{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateOAuth2ProviderAppSecretByID(_ context.Context, arg database.UpdateOAuth2ProviderAppSecretByIDParams) (database.OAuth2ProviderAppSecret, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderAppSecret{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, secret := range q.oauth2ProviderAppSecrets { - if secret.ID == arg.ID { - newSecret := database.OAuth2ProviderAppSecret{ - ID: arg.ID, - CreatedAt: secret.CreatedAt, - SecretPrefix: secret.SecretPrefix, - HashedSecret: secret.HashedSecret, - DisplaySecret: secret.DisplaySecret, - AppID: secret.AppID, - LastUsedAt: arg.LastUsedAt, - } - q.oauth2ProviderAppSecrets[index] = newSecret - return newSecret, nil - } - } - return database.OAuth2ProviderAppSecret{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateOrganization(_ context.Context, arg database.UpdateOrganizationParams) (database.Organization, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.Organization{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - // Enforce the unique constraint, because the API endpoint relies on the database catching - // non-unique names during updates. - for _, org := range q.organizations { - if org.Name == arg.Name && org.ID != arg.ID { - return database.Organization{}, errUniqueConstraint - } - } - - for i, org := range q.organizations { - if org.ID == arg.ID { - org.Name = arg.Name - org.DisplayName = arg.DisplayName - org.Description = arg.Description - org.Icon = arg.Icon - q.organizations[i] = org - return org, nil - } - } - return database.Organization{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateOrganizationDeletedByID(_ context.Context, arg database.UpdateOrganizationDeletedByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, organization := range q.organizations { - if organization.ID != arg.ID || organization.IsDefault { - continue - } - organization.Deleted = true - organization.UpdatedAt = arg.UpdatedAt - q.organizations[index] = organization - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdatePresetPrebuildStatus(ctx context.Context, arg database.UpdatePresetPrebuildStatusParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, preset := range q.presets { - if preset.ID == arg.PresetID { - preset.PrebuildStatus = arg.Status - return nil - } - } - - return xerrors.Errorf("preset %v does not exist", arg.PresetID) -} - -func (q *FakeQuerier) UpdateProvisionerDaemonLastSeenAt(_ context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for idx := range q.provisionerDaemons { - if q.provisionerDaemons[idx].ID != arg.ID { - continue - } - if q.provisionerDaemons[idx].LastSeenAt.Time.After(arg.LastSeenAt.Time) { - continue - } - q.provisionerDaemons[idx].LastSeenAt = arg.LastSeenAt - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, job := range q.provisionerJobs { - if arg.ID != job.ID { - continue - } - job.UpdatedAt = arg.UpdatedAt - job.JobStatus = provisionerJobStatus(job) - q.provisionerJobs[index] = job - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateProvisionerJobWithCancelByID(_ context.Context, arg database.UpdateProvisionerJobWithCancelByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, job := range q.provisionerJobs { - if arg.ID != job.ID { - continue - } - job.CanceledAt = arg.CanceledAt - job.CompletedAt = arg.CompletedAt - job.JobStatus = provisionerJobStatus(job) - q.provisionerJobs[index] = job - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateProvisionerJobWithCompleteByID(_ context.Context, arg database.UpdateProvisionerJobWithCompleteByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, job := range q.provisionerJobs { - if arg.ID != job.ID { - continue - } - job.UpdatedAt = arg.UpdatedAt - job.CompletedAt = arg.CompletedAt - job.Error = arg.Error - job.ErrorCode = arg.ErrorCode - job.JobStatus = provisionerJobStatus(job) - q.provisionerJobs[index] = job - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateProvisionerJobWithCompleteWithStartedAtByID(_ context.Context, arg database.UpdateProvisionerJobWithCompleteWithStartedAtByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, job := range q.provisionerJobs { - if arg.ID != job.ID { - continue - } - job.UpdatedAt = arg.UpdatedAt - job.CompletedAt = arg.CompletedAt - job.Error = arg.Error - job.ErrorCode = arg.ErrorCode - job.StartedAt = arg.StartedAt - job.JobStatus = provisionerJobStatus(job) - q.provisionerJobs[index] = job - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplicaParams) (database.Replica, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Replica{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, replica := range q.replicas { - if replica.ID != arg.ID { - continue - } - replica.Hostname = arg.Hostname - replica.StartedAt = arg.StartedAt - replica.StoppedAt = arg.StoppedAt - replica.UpdatedAt = arg.UpdatedAt - replica.RelayAddress = arg.RelayAddress - replica.RegionID = arg.RegionID - replica.Version = arg.Version - replica.Error = arg.Error - replica.DatabaseLatency = arg.DatabaseLatency - replica.Primary = arg.Primary - q.replicas[index] = replica - return replica, nil - } - return database.Replica{}, sql.ErrNoRows -} - -func (*FakeQuerier) UpdateTailnetPeerStatusByCoordinator(context.Context, database.UpdateTailnetPeerStatusByCoordinatorParams) error { - return ErrUnimplemented -} - -func (q *FakeQuerier) UpdateTemplateACLByID(_ context.Context, arg database.UpdateTemplateACLByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, template := range q.templates { - if template.ID == arg.ID { - template.GroupACL = arg.GroupACL - template.UserACL = arg.UserACL - - q.templates[i] = template - return nil - } - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateAccessControlByID(_ context.Context, arg database.UpdateTemplateAccessControlByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for idx, tpl := range q.templates { - if tpl.ID != arg.ID { - continue - } - q.templates[idx].RequireActiveVersion = arg.RequireActiveVersion - q.templates[idx].Deprecated = arg.Deprecated - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateActiveVersionByID(_ context.Context, arg database.UpdateTemplateActiveVersionByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, template := range q.templates { - if template.ID != arg.ID { - continue - } - template.ActiveVersionID = arg.ActiveVersionID - template.UpdatedAt = arg.UpdatedAt - q.templates[index] = template - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateDeletedByID(_ context.Context, arg database.UpdateTemplateDeletedByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, template := range q.templates { - if template.ID != arg.ID { - continue - } - template.Deleted = arg.Deleted - template.UpdatedAt = arg.UpdatedAt - q.templates[index] = template - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.UpdateTemplateMetaByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for idx, tpl := range q.templates { - if tpl.ID != arg.ID { - continue - } - tpl.UpdatedAt = dbtime.Now() - tpl.Name = arg.Name - tpl.DisplayName = arg.DisplayName - tpl.Description = arg.Description - tpl.Icon = arg.Icon - tpl.GroupACL = arg.GroupACL - tpl.AllowUserCancelWorkspaceJobs = arg.AllowUserCancelWorkspaceJobs - tpl.MaxPortSharingLevel = arg.MaxPortSharingLevel - tpl.UseClassicParameterFlow = arg.UseClassicParameterFlow - q.templates[idx] = tpl - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateScheduleByID(_ context.Context, arg database.UpdateTemplateScheduleByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for idx, tpl := range q.templates { - if tpl.ID != arg.ID { - continue - } - tpl.AllowUserAutostart = arg.AllowUserAutostart - tpl.AllowUserAutostop = arg.AllowUserAutostop - tpl.UpdatedAt = dbtime.Now() - tpl.DefaultTTL = arg.DefaultTTL - tpl.ActivityBump = arg.ActivityBump - tpl.AutostopRequirementDaysOfWeek = arg.AutostopRequirementDaysOfWeek - tpl.AutostopRequirementWeeks = arg.AutostopRequirementWeeks - tpl.AutostartBlockDaysOfWeek = arg.AutostartBlockDaysOfWeek - tpl.FailureTTL = arg.FailureTTL - tpl.TimeTilDormant = arg.TimeTilDormant - tpl.TimeTilDormantAutoDelete = arg.TimeTilDormantAutoDelete - q.templates[idx] = tpl - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateVersionAITaskByJobID(_ context.Context, arg database.UpdateTemplateVersionAITaskByJobIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, templateVersion := range q.templateVersions { - if templateVersion.JobID != arg.JobID { - continue - } - templateVersion.HasAITask = arg.HasAITask - templateVersion.UpdatedAt = arg.UpdatedAt - q.templateVersions[index] = templateVersion - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateVersionByID(_ context.Context, arg database.UpdateTemplateVersionByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, templateVersion := range q.templateVersions { - if templateVersion.ID != arg.ID { - continue - } - templateVersion.TemplateID = arg.TemplateID - templateVersion.UpdatedAt = arg.UpdatedAt - templateVersion.Name = arg.Name - templateVersion.Message = arg.Message - q.templateVersions[index] = templateVersion - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateVersionDescriptionByJobID(_ context.Context, arg database.UpdateTemplateVersionDescriptionByJobIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, templateVersion := range q.templateVersions { - if templateVersion.JobID != arg.JobID { - continue - } - templateVersion.Readme = arg.Readme - templateVersion.UpdatedAt = arg.UpdatedAt - q.templateVersions[index] = templateVersion - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateVersionExternalAuthProvidersByJobID(_ context.Context, arg database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, templateVersion := range q.templateVersions { - if templateVersion.JobID != arg.JobID { - continue - } - templateVersion.ExternalAuthProviders = arg.ExternalAuthProviders - templateVersion.UpdatedAt = arg.UpdatedAt - q.templateVersions[index] = templateVersion - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateWorkspacesLastUsedAt(_ context.Context, arg database.UpdateTemplateWorkspacesLastUsedAtParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, ws := range q.workspaces { - if ws.TemplateID != arg.TemplateID { - continue - } - ws.LastUsedAt = arg.LastUsedAt - q.workspaces[i] = ws - } - - return nil -} - -func (q *FakeQuerier) UpdateUserDeletedByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, u := range q.users { - if u.ID == id { - u.Deleted = true - q.users[i] = u - // NOTE: In the real world, this is done by a trigger. - q.apiKeys = slices.DeleteFunc(q.apiKeys, func(u database.APIKey) bool { - return id == u.UserID - }) - - q.userLinks = slices.DeleteFunc(q.userLinks, func(u database.UserLink) bool { - return id == u.UserID - }) - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserGithubComUserID(_ context.Context, arg database.UpdateUserGithubComUserIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, user := range q.users { - if user.ID != arg.ID { - continue - } - user.GithubComUserID = arg.GithubComUserID - q.users[i] = user - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserHashedOneTimePasscode(_ context.Context, arg database.UpdateUserHashedOneTimePasscodeParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, user := range q.users { - if user.ID != arg.ID { - continue - } - user.HashedOneTimePasscode = arg.HashedOneTimePasscode - user.OneTimePasscodeExpiresAt = arg.OneTimePasscodeExpiresAt - q.users[i] = user - } - return nil -} - -func (q *FakeQuerier) UpdateUserHashedPassword(_ context.Context, arg database.UpdateUserHashedPasswordParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, user := range q.users { - if user.ID != arg.ID { - continue - } - user.HashedPassword = arg.HashedPassword - user.HashedOneTimePasscode = nil - user.OneTimePasscodeExpiresAt = sql.NullTime{} - q.users[i] = user - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserLastSeenAt(_ context.Context, arg database.UpdateUserLastSeenAtParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, user := range q.users { - if user.ID != arg.ID { - continue - } - user.LastSeenAt = arg.LastSeenAt - user.UpdatedAt = arg.UpdatedAt - q.users[index] = user - return user, nil - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) { - if err := validateDatabaseType(params); err != nil { - return database.UserLink{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - if u, err := q.getUserByIDNoLock(params.UserID); err == nil && u.Deleted { - return database.UserLink{}, deletedUserLinkError - } - - for i, link := range q.userLinks { - if link.UserID == params.UserID && link.LoginType == params.LoginType { - link.OAuthAccessToken = params.OAuthAccessToken - link.OAuthAccessTokenKeyID = params.OAuthAccessTokenKeyID - link.OAuthRefreshToken = params.OAuthRefreshToken - link.OAuthRefreshTokenKeyID = params.OAuthRefreshTokenKeyID - link.OAuthExpiry = params.OAuthExpiry - link.Claims = params.Claims - - q.userLinks[i] = link - return link, nil - } - } - - return database.UserLink{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserLinkedID(_ context.Context, params database.UpdateUserLinkedIDParams) (database.UserLink, error) { - if err := validateDatabaseType(params); err != nil { - return database.UserLink{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, link := range q.userLinks { - if link.UserID == params.UserID && link.LoginType == params.LoginType { - link.LinkedID = params.LinkedID - - q.userLinks[i] = link - return link, nil - } - } - - return database.UserLink{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserLoginType(_ context.Context, arg database.UpdateUserLoginTypeParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, u := range q.users { - if u.ID == arg.UserID { - u.LoginType = arg.NewLoginType - if arg.NewLoginType != database.LoginTypePassword { - u.HashedPassword = []byte{} - } - q.users[i] = u - return u, nil - } - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserNotificationPreferences(_ context.Context, arg database.UpdateUserNotificationPreferencesParams) (int64, error) { - err := validateDatabaseType(arg) - if err != nil { - return 0, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - var upserted int64 - for i := range arg.NotificationTemplateIds { - var ( - found bool - templateID = arg.NotificationTemplateIds[i] - disabled = arg.Disableds[i] - ) - - for j, np := range q.notificationPreferences { - if np.UserID != arg.UserID { - continue - } - - if np.NotificationTemplateID != templateID { - continue - } - - np.Disabled = disabled - np.UpdatedAt = dbtime.Now() - q.notificationPreferences[j] = np - - upserted++ - found = true - break - } - - if !found { - np := database.NotificationPreference{ - Disabled: disabled, - UserID: arg.UserID, - NotificationTemplateID: templateID, - CreatedAt: dbtime.Now(), - UpdatedAt: dbtime.Now(), - } - q.notificationPreferences = append(q.notificationPreferences, np) - upserted++ - } - } - - return upserted, nil -} - -func (q *FakeQuerier) UpdateUserProfile(_ context.Context, arg database.UpdateUserProfileParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, user := range q.users { - if user.ID != arg.ID { - continue - } - user.Email = arg.Email - user.Username = arg.Username - user.AvatarURL = arg.AvatarURL - user.Name = arg.Name - q.users[index] = user - return user, nil - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserQuietHoursSchedule(_ context.Context, arg database.UpdateUserQuietHoursScheduleParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, user := range q.users { - if user.ID != arg.ID { - continue - } - user.QuietHoursSchedule = arg.QuietHoursSchedule - q.users[index] = user - return user, nil - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserRoles(_ context.Context, arg database.UpdateUserRolesParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, user := range q.users { - if user.ID != arg.ID { - continue - } - - // Set new roles - user.RBACRoles = slice.Unique(arg.GrantedRoles) - // Remove duplicates and sort - uniqueRoles := make([]string, 0, len(user.RBACRoles)) - exist := make(map[string]struct{}) - for _, r := range user.RBACRoles { - if _, ok := exist[r]; ok { - continue - } - exist[r] = struct{}{} - uniqueRoles = append(uniqueRoles, r) - } - sort.Strings(uniqueRoles) - user.RBACRoles = uniqueRoles - - q.users[index] = user - return user, nil - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserStatus(_ context.Context, arg database.UpdateUserStatusParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, user := range q.users { - if user.ID != arg.ID { - continue - } - user.Status = arg.Status - user.UpdatedAt = arg.UpdatedAt - q.users[index] = user - - q.userStatusChanges = append(q.userStatusChanges, database.UserStatusChange{ - UserID: user.ID, - NewStatus: user.Status, - ChangedAt: user.UpdatedAt, - }) - return user, nil - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserTerminalFont(ctx context.Context, arg database.UpdateUserTerminalFontParams) (database.UserConfig, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.UserConfig{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, uc := range q.userConfigs { - if uc.UserID != arg.UserID || uc.Key != "terminal_font" { - continue - } - uc.Value = arg.TerminalFont - q.userConfigs[i] = uc - return uc, nil - } - - uc := database.UserConfig{ - UserID: arg.UserID, - Key: "terminal_font", - Value: arg.TerminalFont, - } - q.userConfigs = append(q.userConfigs, uc) - return uc, nil -} - -func (q *FakeQuerier) UpdateUserThemePreference(_ context.Context, arg database.UpdateUserThemePreferenceParams) (database.UserConfig, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.UserConfig{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, uc := range q.userConfigs { - if uc.UserID != arg.UserID || uc.Key != "theme_preference" { - continue - } - uc.Value = arg.ThemePreference - q.userConfigs[i] = uc - return uc, nil - } - - uc := database.UserConfig{ - UserID: arg.UserID, - Key: "theme_preference", - Value: arg.ThemePreference, - } - q.userConfigs = append(q.userConfigs, uc) - return uc, nil -} - -func (q *FakeQuerier) UpdateVolumeResourceMonitor(_ context.Context, arg database.UpdateVolumeResourceMonitorParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, monitor := range q.workspaceAgentVolumeResourceMonitors { - if monitor.AgentID != arg.AgentID || monitor.Path != arg.Path { - continue - } - - monitor.State = arg.State - monitor.UpdatedAt = arg.UpdatedAt - monitor.DebouncedUntil = arg.DebouncedUntil - q.workspaceAgentVolumeResourceMonitors[i] = monitor - return nil - } - - return nil -} - -func (q *FakeQuerier) UpdateWorkspace(_ context.Context, arg database.UpdateWorkspaceParams) (database.WorkspaceTable, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceTable{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, workspace := range q.workspaces { - if workspace.Deleted || workspace.ID != arg.ID { - continue - } - for _, other := range q.workspaces { - if other.Deleted || other.ID == workspace.ID || workspace.OwnerID != other.OwnerID { - continue - } - if other.Name == arg.Name { - return database.WorkspaceTable{}, errUniqueConstraint - } - } - - workspace.Name = arg.Name - q.workspaces[i] = workspace - - return workspace, nil - } - - return database.WorkspaceTable{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAgentConnectionByID(_ context.Context, arg database.UpdateWorkspaceAgentConnectionByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, agent := range q.workspaceAgents { - if agent.ID != arg.ID { - continue - } - agent.FirstConnectedAt = arg.FirstConnectedAt - agent.LastConnectedAt = arg.LastConnectedAt - agent.DisconnectedAt = arg.DisconnectedAt - agent.UpdatedAt = arg.UpdatedAt - agent.LastConnectedReplicaID = arg.LastConnectedReplicaID - q.workspaceAgents[index] = agent - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAgentLifecycleStateByID(_ context.Context, arg database.UpdateWorkspaceAgentLifecycleStateByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - for i, agent := range q.workspaceAgents { - if agent.ID == arg.ID { - agent.LifecycleState = arg.LifecycleState - agent.StartedAt = arg.StartedAt - agent.ReadyAt = arg.ReadyAt - q.workspaceAgents[i] = agent - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAgentLogOverflowByID(_ context.Context, arg database.UpdateWorkspaceAgentLogOverflowByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - for i, agent := range q.workspaceAgents { - if agent.ID == arg.ID { - agent.LogsOverflowed = arg.LogsOverflowed - q.workspaceAgents[i] = agent - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAgentMetadata(_ context.Context, arg database.UpdateWorkspaceAgentMetadataParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, m := range q.workspaceAgentMetadata { - if m.WorkspaceAgentID != arg.WorkspaceAgentID { - continue - } - for j := 0; j < len(arg.Key); j++ { - if m.Key == arg.Key[j] { - q.workspaceAgentMetadata[i].Value = arg.Value[j] - q.workspaceAgentMetadata[i].Error = arg.Error[j] - q.workspaceAgentMetadata[i].CollectedAt = arg.CollectedAt[j] - return nil - } - } - } - - return nil -} - -func (q *FakeQuerier) UpdateWorkspaceAgentStartupByID(_ context.Context, arg database.UpdateWorkspaceAgentStartupByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - if len(arg.Subsystems) > 0 { - seen := map[database.WorkspaceAgentSubsystem]struct{}{ - arg.Subsystems[0]: {}, - } - for i := 1; i < len(arg.Subsystems); i++ { - s := arg.Subsystems[i] - if _, ok := seen[s]; ok { - return xerrors.Errorf("duplicate subsystem %q", s) - } - seen[s] = struct{}{} - - if arg.Subsystems[i-1] > arg.Subsystems[i] { - return xerrors.Errorf("subsystems not sorted: %q > %q", arg.Subsystems[i-1], arg.Subsystems[i]) - } - } - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, agent := range q.workspaceAgents { - if agent.ID != arg.ID { - continue - } - - agent.Version = arg.Version - agent.APIVersion = arg.APIVersion - agent.ExpandedDirectory = arg.ExpandedDirectory - agent.Subsystems = arg.Subsystems - q.workspaceAgents[index] = agent - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAppHealthByID(_ context.Context, arg database.UpdateWorkspaceAppHealthByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, app := range q.workspaceApps { - if app.ID != arg.ID { - continue - } - app.Health = arg.Health - q.workspaceApps[index] = app - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAutomaticUpdates(_ context.Context, arg database.UpdateWorkspaceAutomaticUpdatesParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.AutomaticUpdates = arg.AutomaticUpdates - q.workspaces[index] = workspace - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAutostart(_ context.Context, arg database.UpdateWorkspaceAutostartParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.AutostartSchedule = arg.AutostartSchedule - workspace.NextStartAt = arg.NextStartAt - q.workspaces[index] = workspace - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceBuildAITaskByID(_ context.Context, arg database.UpdateWorkspaceBuildAITaskByIDParams) error { - if arg.HasAITask.Bool && !arg.SidebarAppID.Valid { - return xerrors.Errorf("ai_task_sidebar_app_id is required when has_ai_task is true") - } - if !arg.HasAITask.Valid && arg.SidebarAppID.Valid { - return xerrors.Errorf("ai_task_sidebar_app_id is can only be set when has_ai_task is true") - } - - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.ID != arg.ID { - continue - } - workspaceBuild.HasAITask = arg.HasAITask - workspaceBuild.AITaskSidebarAppID = arg.SidebarAppID - workspaceBuild.UpdatedAt = dbtime.Now() - q.workspaceBuilds[index] = workspaceBuild - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceBuildCostByID(_ context.Context, arg database.UpdateWorkspaceBuildCostByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.ID != arg.ID { - continue - } - workspaceBuild.DailyCost = arg.DailyCost - q.workspaceBuilds[index] = workspaceBuild - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceBuildDeadlineByID(_ context.Context, arg database.UpdateWorkspaceBuildDeadlineByIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for idx, build := range q.workspaceBuilds { - if build.ID != arg.ID { - continue - } - build.Deadline = arg.Deadline - build.MaxDeadline = arg.MaxDeadline - build.UpdatedAt = arg.UpdatedAt - q.workspaceBuilds[idx] = build - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceBuildProvisionerStateByID(_ context.Context, arg database.UpdateWorkspaceBuildProvisionerStateByIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for idx, build := range q.workspaceBuilds { - if build.ID != arg.ID { - continue - } - build.ProvisionerState = arg.ProvisionerState - build.UpdatedAt = arg.UpdatedAt - q.workspaceBuilds[idx] = build - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceDeletedByID(_ context.Context, arg database.UpdateWorkspaceDeletedByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.Deleted = arg.Deleted - q.workspaces[index] = workspace - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceDormantDeletingAt(_ context.Context, arg database.UpdateWorkspaceDormantDeletingAtParams) (database.WorkspaceTable, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceTable{}, err - } - q.mutex.Lock() - defer q.mutex.Unlock() - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.DormantAt = arg.DormantAt - if workspace.DormantAt.Time.IsZero() { - workspace.LastUsedAt = dbtime.Now() - workspace.DeletingAt = sql.NullTime{} - } - if !workspace.DormantAt.Time.IsZero() { - var template database.TemplateTable - for _, t := range q.templates { - if t.ID == workspace.TemplateID { - template = t - break - } - } - if template.ID == uuid.Nil { - return database.WorkspaceTable{}, xerrors.Errorf("unable to find workspace template") - } - if template.TimeTilDormantAutoDelete > 0 { - workspace.DeletingAt = sql.NullTime{ - Valid: true, - Time: workspace.DormantAt.Time.Add(time.Duration(template.TimeTilDormantAutoDelete)), - } - } - } - q.workspaces[index] = workspace - return workspace, nil - } - return database.WorkspaceTable{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceLastUsedAt(_ context.Context, arg database.UpdateWorkspaceLastUsedAtParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.LastUsedAt = arg.LastUsedAt - q.workspaces[index] = workspace - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceNextStartAt(_ context.Context, arg database.UpdateWorkspaceNextStartAtParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - - workspace.NextStartAt = arg.NextStartAt - q.workspaces[index] = workspace - - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceProxy(_ context.Context, arg database.UpdateWorkspaceProxyParams) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, p := range q.workspaceProxies { - if p.Name == arg.Name && p.ID != arg.ID { - return database.WorkspaceProxy{}, errUniqueConstraint - } - } - - for i, p := range q.workspaceProxies { - if p.ID == arg.ID { - p.Name = arg.Name - p.DisplayName = arg.DisplayName - p.Icon = arg.Icon - if len(p.TokenHashedSecret) > 0 { - p.TokenHashedSecret = arg.TokenHashedSecret - } - q.workspaceProxies[i] = p - return p, nil - } - } - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceProxyDeleted(_ context.Context, arg database.UpdateWorkspaceProxyDeletedParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, p := range q.workspaceProxies { - if p.ID == arg.ID { - p.Deleted = arg.Deleted - p.UpdatedAt = dbtime.Now() - q.workspaceProxies[i] = p - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceTTL(_ context.Context, arg database.UpdateWorkspaceTTLParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.Ttl = arg.Ttl - q.workspaces[index] = workspace - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspacesDormantDeletingAtByTemplateID(_ context.Context, arg database.UpdateWorkspacesDormantDeletingAtByTemplateIDParams) ([]database.WorkspaceTable, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - affectedRows := []database.WorkspaceTable{} - for i, ws := range q.workspaces { - if ws.TemplateID != arg.TemplateID { - continue - } - - if ws.DormantAt.Time.IsZero() { - continue - } - - if !arg.DormantAt.IsZero() { - ws.DormantAt = sql.NullTime{ - Valid: true, - Time: arg.DormantAt, - } - } - - deletingAt := sql.NullTime{ - Valid: arg.TimeTilDormantAutodeleteMs > 0, - } - if arg.TimeTilDormantAutodeleteMs > 0 { - deletingAt.Time = ws.DormantAt.Time.Add(time.Duration(arg.TimeTilDormantAutodeleteMs) * time.Millisecond) - } - ws.DeletingAt = deletingAt - q.workspaces[i] = ws - affectedRows = append(affectedRows, ws) - } - - return affectedRows, nil -} - -func (q *FakeQuerier) UpdateWorkspacesTTLByTemplateID(_ context.Context, arg database.UpdateWorkspacesTTLByTemplateIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, ws := range q.workspaces { - if ws.TemplateID != arg.TemplateID { - continue - } - - q.workspaces[i].Ttl = arg.Ttl - } - - return nil -} - -func (q *FakeQuerier) UpsertAnnouncementBanners(_ context.Context, data string) error { - q.mutex.RLock() - defer q.mutex.RUnlock() - - q.announcementBanners = []byte(data) - return nil -} - -func (q *FakeQuerier) UpsertAppSecurityKey(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.appSecurityKey = data - return nil -} - -func (q *FakeQuerier) UpsertApplicationName(_ context.Context, data string) error { - q.mutex.RLock() - defer q.mutex.RUnlock() - - q.applicationName = data - return nil -} - -func (q *FakeQuerier) UpsertCoordinatorResumeTokenSigningKey(_ context.Context, value string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.coordinatorResumeTokenSigningKey = value - return nil -} - -func (q *FakeQuerier) UpsertDefaultProxy(_ context.Context, arg database.UpsertDefaultProxyParams) error { - q.defaultProxyDisplayName = arg.DisplayName - q.defaultProxyIconURL = arg.IconUrl - return nil -} - -func (q *FakeQuerier) UpsertHealthSettings(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.healthSettings = []byte(data) - return nil -} - -func (q *FakeQuerier) UpsertLastUpdateCheck(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.lastUpdateCheck = []byte(data) - return nil -} - -func (q *FakeQuerier) UpsertLogoURL(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.logoURL = data - return nil -} - -func (q *FakeQuerier) UpsertNotificationReportGeneratorLog(_ context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, record := range q.notificationReportGeneratorLogs { - if arg.NotificationTemplateID == record.NotificationTemplateID { - q.notificationReportGeneratorLogs[i].LastGeneratedAt = arg.LastGeneratedAt - return nil - } - } - - q.notificationReportGeneratorLogs = append(q.notificationReportGeneratorLogs, database.NotificationReportGeneratorLog(arg)) - return nil -} - -func (q *FakeQuerier) UpsertNotificationsSettings(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.notificationsSettings = []byte(data) - return nil -} - -func (q *FakeQuerier) UpsertOAuth2GithubDefaultEligible(_ context.Context, eligible bool) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.oauth2GithubDefaultEligible = &eligible - return nil -} - -func (q *FakeQuerier) UpsertOAuthSigningKey(_ context.Context, value string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.oauthSigningKey = value - return nil -} - -func (q *FakeQuerier) UpsertPrebuildsSettings(_ context.Context, value string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.prebuildsSettings = []byte(value) - return nil -} - -func (q *FakeQuerier) UpsertProvisionerDaemon(_ context.Context, arg database.UpsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ProvisionerDaemon{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - // Look for existing daemon using the same composite key as SQL - for i, d := range q.provisionerDaemons { - if d.OrganizationID == arg.OrganizationID && - d.Name == arg.Name && - getOwnerFromTags(d.Tags) == getOwnerFromTags(arg.Tags) { - d.Provisioners = arg.Provisioners - d.Tags = maps.Clone(arg.Tags) - d.LastSeenAt = arg.LastSeenAt - d.Version = arg.Version - d.APIVersion = arg.APIVersion - d.OrganizationID = arg.OrganizationID - d.KeyID = arg.KeyID - q.provisionerDaemons[i] = d - return d, nil - } - } - d := database.ProvisionerDaemon{ - ID: uuid.New(), - CreatedAt: arg.CreatedAt, - Name: arg.Name, - Provisioners: arg.Provisioners, - Tags: maps.Clone(arg.Tags), - LastSeenAt: arg.LastSeenAt, - Version: arg.Version, - APIVersion: arg.APIVersion, - OrganizationID: arg.OrganizationID, - KeyID: arg.KeyID, - } - q.provisionerDaemons = append(q.provisionerDaemons, d) - return d, nil -} - -func (q *FakeQuerier) UpsertRuntimeConfig(_ context.Context, arg database.UpsertRuntimeConfigParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - q.runtimeConfig[arg.Key] = arg.Value - return nil -} - -func (*FakeQuerier) UpsertTailnetAgent(context.Context, database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { - return database.TailnetAgent{}, ErrUnimplemented -} - -func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetClientParams) (database.TailnetClient, error) { - return database.TailnetClient{}, ErrUnimplemented -} - -func (*FakeQuerier) UpsertTailnetClientSubscription(context.Context, database.UpsertTailnetClientSubscriptionParams) error { - return ErrUnimplemented -} - -func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { - return database.TailnetCoordinator{}, ErrUnimplemented -} - -func (*FakeQuerier) UpsertTailnetPeer(_ context.Context, arg database.UpsertTailnetPeerParams) (database.TailnetPeer, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.TailnetPeer{}, err - } - - return database.TailnetPeer{}, ErrUnimplemented -} - -func (*FakeQuerier) UpsertTailnetTunnel(_ context.Context, arg database.UpsertTailnetTunnelParams) (database.TailnetTunnel, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.TailnetTunnel{}, err - } - - return database.TailnetTunnel{}, ErrUnimplemented -} - -func (q *FakeQuerier) UpsertTelemetryItem(_ context.Context, arg database.UpsertTelemetryItemParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, item := range q.telemetryItems { - if item.Key == arg.Key { - q.telemetryItems[i].Value = arg.Value - q.telemetryItems[i].UpdatedAt = time.Now() - return nil - } - } - - q.telemetryItems = append(q.telemetryItems, database.TelemetryItem{ - Key: arg.Key, - Value: arg.Value, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }) - - return nil -} - -func (q *FakeQuerier) UpsertTemplateUsageStats(ctx context.Context) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - /* - WITH - */ - - /* - latest_start AS ( - SELECT - -- Truncate to hour so that we always look at even ranges of data. - date_trunc('hour', COALESCE( - MAX(start_time) - '1 hour'::interval), - -- Fallback when there are no template usage stats yet. - -- App stats can exist before this, but not agent stats, - -- limit the lookback to avoid inconsistency. - (SELECT MIN(created_at) FROM workspace_agent_stats) - )) AS t - FROM - template_usage_stats - ), - */ - - now := time.Now() - latestStart := time.Time{} - for _, stat := range q.templateUsageStats { - if stat.StartTime.After(latestStart) { - latestStart = stat.StartTime.Add(-time.Hour) - } - } - if latestStart.IsZero() { - for _, stat := range q.workspaceAgentStats { - if latestStart.IsZero() || stat.CreatedAt.Before(latestStart) { - latestStart = stat.CreatedAt - } - } - } - if latestStart.IsZero() { - return nil - } - latestStart = latestStart.Truncate(time.Hour) - - /* - workspace_app_stat_buckets AS ( - SELECT - -- Truncate the minute to the nearest half hour, this is the bucket size - -- for the data. - date_trunc('hour', s.minute_bucket) + trunc(date_part('minute', s.minute_bucket) / 30) * 30 * '1 minute'::interval AS time_bucket, - w.template_id, - was.user_id, - -- Both app stats and agent stats track web terminal usage, but - -- by different means. The app stats value should be more - -- accurate so we don't want to discard it just yet. - CASE - WHEN was.access_method = 'terminal' - THEN '[terminal]' -- Unique name, app names can't contain brackets. - ELSE was.slug_or_port - END AS app_name, - COUNT(DISTINCT s.minute_bucket) AS app_minutes, - -- Store each unique minute bucket for later merge between datasets. - array_agg(DISTINCT s.minute_bucket) AS minute_buckets - FROM - workspace_app_stats AS was - JOIN - workspaces AS w - ON - w.id = was.workspace_id - -- Generate a series of minute buckets for each session for computing the - -- mintes/bucket. - CROSS JOIN - generate_series( - date_trunc('minute', was.session_started_at), - -- Subtract 1 microsecond to avoid creating an extra series. - date_trunc('minute', was.session_ended_at - '1 microsecond'::interval), - '1 minute'::interval - ) AS s(minute_bucket) - WHERE - -- s.minute_bucket >= @start_time::timestamptz - -- AND s.minute_bucket < @end_time::timestamptz - s.minute_bucket >= (SELECT t FROM latest_start) - AND s.minute_bucket < NOW() - GROUP BY - time_bucket, w.template_id, was.user_id, was.access_method, was.slug_or_port - ), - */ - - type workspaceAppStatGroupBy struct { - TimeBucket time.Time - TemplateID uuid.UUID - UserID uuid.UUID - AccessMethod string - SlugOrPort string - } - type workspaceAppStatRow struct { - workspaceAppStatGroupBy - AppName string - AppMinutes int - MinuteBuckets map[time.Time]struct{} - } - workspaceAppStatRows := make(map[workspaceAppStatGroupBy]workspaceAppStatRow) - for _, was := range q.workspaceAppStats { - // Preflight: s.minute_bucket >= (SELECT t FROM latest_start) - if was.SessionEndedAt.Before(latestStart) { - continue - } - // JOIN workspaces - w, err := q.getWorkspaceByIDNoLock(ctx, was.WorkspaceID) - if err != nil { - return err - } - // CROSS JOIN generate_series - for t := was.SessionStartedAt.Truncate(time.Minute); t.Before(was.SessionEndedAt); t = t.Add(time.Minute) { - // WHERE - if t.Before(latestStart) || t.After(now) || t.Equal(now) { - continue - } - - bucket := t.Truncate(30 * time.Minute) - // GROUP BY - key := workspaceAppStatGroupBy{ - TimeBucket: bucket, - TemplateID: w.TemplateID, - UserID: was.UserID, - AccessMethod: was.AccessMethod, - SlugOrPort: was.SlugOrPort, - } - // SELECT - row, ok := workspaceAppStatRows[key] - if !ok { - row = workspaceAppStatRow{ - workspaceAppStatGroupBy: key, - AppName: was.SlugOrPort, - AppMinutes: 0, - MinuteBuckets: make(map[time.Time]struct{}), - } - if was.AccessMethod == "terminal" { - row.AppName = "[terminal]" - } - } - row.MinuteBuckets[t] = struct{}{} - row.AppMinutes = len(row.MinuteBuckets) - workspaceAppStatRows[key] = row - } - } - - /* - agent_stats_buckets AS ( - SELECT - -- Truncate the minute to the nearest half hour, this is the bucket size - -- for the data. - date_trunc('hour', created_at) + trunc(date_part('minute', created_at) / 30) * 30 * '1 minute'::interval AS time_bucket, - template_id, - user_id, - -- Store each unique minute bucket for later merge between datasets. - array_agg( - DISTINCT CASE - WHEN - session_count_ssh > 0 - -- TODO(mafredri): Enable when we have the column. - -- OR session_count_sftp > 0 - OR session_count_reconnecting_pty > 0 - OR session_count_vscode > 0 - OR session_count_jetbrains > 0 - THEN - date_trunc('minute', created_at) - ELSE - NULL - END - ) AS minute_buckets, - COUNT(DISTINCT CASE WHEN session_count_ssh > 0 THEN date_trunc('minute', created_at) ELSE NULL END) AS ssh_mins, - -- TODO(mafredri): Enable when we have the column. - -- COUNT(DISTINCT CASE WHEN session_count_sftp > 0 THEN date_trunc('minute', created_at) ELSE NULL END) AS sftp_mins, - COUNT(DISTINCT CASE WHEN session_count_reconnecting_pty > 0 THEN date_trunc('minute', created_at) ELSE NULL END) AS reconnecting_pty_mins, - COUNT(DISTINCT CASE WHEN session_count_vscode > 0 THEN date_trunc('minute', created_at) ELSE NULL END) AS vscode_mins, - COUNT(DISTINCT CASE WHEN session_count_jetbrains > 0 THEN date_trunc('minute', created_at) ELSE NULL END) AS jetbrains_mins, - -- NOTE(mafredri): The agent stats are currently very unreliable, and - -- sometimes the connections are missing, even during active sessions. - -- Since we can't fully rely on this, we check for "any connection - -- during this half-hour". A better solution here would be preferable. - MAX(connection_count) > 0 AS has_connection - FROM - workspace_agent_stats - WHERE - -- created_at >= @start_time::timestamptz - -- AND created_at < @end_time::timestamptz - created_at >= (SELECT t FROM latest_start) - AND created_at < NOW() - -- Inclusion criteria to filter out empty results. - AND ( - session_count_ssh > 0 - -- TODO(mafredri): Enable when we have the column. - -- OR session_count_sftp > 0 - OR session_count_reconnecting_pty > 0 - OR session_count_vscode > 0 - OR session_count_jetbrains > 0 - ) - GROUP BY - time_bucket, template_id, user_id - ), - */ - - type agentStatGroupBy struct { - TimeBucket time.Time - TemplateID uuid.UUID - UserID uuid.UUID - } - type agentStatRow struct { - agentStatGroupBy - MinuteBuckets map[time.Time]struct{} - SSHMinuteBuckets map[time.Time]struct{} - SSHMins int - SFTPMinuteBuckets map[time.Time]struct{} - SFTPMins int - ReconnectingPTYMinuteBuckets map[time.Time]struct{} - ReconnectingPTYMins int - VSCodeMinuteBuckets map[time.Time]struct{} - VSCodeMins int - JetBrainsMinuteBuckets map[time.Time]struct{} - JetBrainsMins int - HasConnection bool - } - agentStatRows := make(map[agentStatGroupBy]agentStatRow) - for _, was := range q.workspaceAgentStats { - // WHERE - if was.CreatedAt.Before(latestStart) || was.CreatedAt.After(now) || was.CreatedAt.Equal(now) { - continue - } - if was.SessionCountSSH == 0 && was.SessionCountReconnectingPTY == 0 && was.SessionCountVSCode == 0 && was.SessionCountJetBrains == 0 { - continue - } - // GROUP BY - key := agentStatGroupBy{ - TimeBucket: was.CreatedAt.Truncate(30 * time.Minute), - TemplateID: was.TemplateID, - UserID: was.UserID, - } - // SELECT - row, ok := agentStatRows[key] - if !ok { - row = agentStatRow{ - agentStatGroupBy: key, - MinuteBuckets: make(map[time.Time]struct{}), - SSHMinuteBuckets: make(map[time.Time]struct{}), - SFTPMinuteBuckets: make(map[time.Time]struct{}), - ReconnectingPTYMinuteBuckets: make(map[time.Time]struct{}), - VSCodeMinuteBuckets: make(map[time.Time]struct{}), - JetBrainsMinuteBuckets: make(map[time.Time]struct{}), - } - } - minute := was.CreatedAt.Truncate(time.Minute) - row.MinuteBuckets[minute] = struct{}{} - if was.SessionCountSSH > 0 { - row.SSHMinuteBuckets[minute] = struct{}{} - row.SSHMins = len(row.SSHMinuteBuckets) - } - // TODO(mafredri): Enable when we have the column. - // if was.SessionCountSFTP > 0 { - // row.SFTPMinuteBuckets[minute] = struct{}{} - // row.SFTPMins = len(row.SFTPMinuteBuckets) - // } - _ = row.SFTPMinuteBuckets - if was.SessionCountReconnectingPTY > 0 { - row.ReconnectingPTYMinuteBuckets[minute] = struct{}{} - row.ReconnectingPTYMins = len(row.ReconnectingPTYMinuteBuckets) - } - if was.SessionCountVSCode > 0 { - row.VSCodeMinuteBuckets[minute] = struct{}{} - row.VSCodeMins = len(row.VSCodeMinuteBuckets) - } - if was.SessionCountJetBrains > 0 { - row.JetBrainsMinuteBuckets[minute] = struct{}{} - row.JetBrainsMins = len(row.JetBrainsMinuteBuckets) - } - if !row.HasConnection { - row.HasConnection = was.ConnectionCount > 0 - } - agentStatRows[key] = row - } - - /* - stats AS ( - SELECT - stats.time_bucket AS start_time, - stats.time_bucket + '30 minutes'::interval AS end_time, - stats.template_id, - stats.user_id, - -- Sum/distinct to handle zero/duplicate values due union and to unnest. - COUNT(DISTINCT minute_bucket) AS usage_mins, - array_agg(DISTINCT minute_bucket) AS minute_buckets, - SUM(DISTINCT stats.ssh_mins) AS ssh_mins, - SUM(DISTINCT stats.sftp_mins) AS sftp_mins, - SUM(DISTINCT stats.reconnecting_pty_mins) AS reconnecting_pty_mins, - SUM(DISTINCT stats.vscode_mins) AS vscode_mins, - SUM(DISTINCT stats.jetbrains_mins) AS jetbrains_mins, - -- This is what we unnested, re-nest as json. - jsonb_object_agg(stats.app_name, stats.app_minutes) FILTER (WHERE stats.app_name IS NOT NULL) AS app_usage_mins - FROM ( - SELECT - time_bucket, - template_id, - user_id, - 0 AS ssh_mins, - 0 AS sftp_mins, - 0 AS reconnecting_pty_mins, - 0 AS vscode_mins, - 0 AS jetbrains_mins, - app_name, - app_minutes, - minute_buckets - FROM - workspace_app_stat_buckets - - UNION ALL - - SELECT - time_bucket, - template_id, - user_id, - ssh_mins, - -- TODO(mafredri): Enable when we have the column. - 0 AS sftp_mins, - reconnecting_pty_mins, - vscode_mins, - jetbrains_mins, - NULL AS app_name, - NULL AS app_minutes, - minute_buckets - FROM - agent_stats_buckets - WHERE - -- See note in the agent_stats_buckets CTE. - has_connection - ) AS stats, unnest(minute_buckets) AS minute_bucket - GROUP BY - stats.time_bucket, stats.template_id, stats.user_id - ), - */ - - type statsGroupBy struct { - TimeBucket time.Time - TemplateID uuid.UUID - UserID uuid.UUID - } - type statsRow struct { - statsGroupBy - UsageMinuteBuckets map[time.Time]struct{} - UsageMins int - SSHMins int - SFTPMins int - ReconnectingPTYMins int - VSCodeMins int - JetBrainsMins int - AppUsageMinutes map[string]int - } - statsRows := make(map[statsGroupBy]statsRow) - for _, was := range workspaceAppStatRows { - // GROUP BY - key := statsGroupBy{ - TimeBucket: was.TimeBucket, - TemplateID: was.TemplateID, - UserID: was.UserID, - } - // SELECT - row, ok := statsRows[key] - if !ok { - row = statsRow{ - statsGroupBy: key, - UsageMinuteBuckets: make(map[time.Time]struct{}), - AppUsageMinutes: make(map[string]int), - } - } - for t := range was.MinuteBuckets { - row.UsageMinuteBuckets[t] = struct{}{} - } - row.UsageMins = len(row.UsageMinuteBuckets) - row.AppUsageMinutes[was.AppName] = was.AppMinutes - statsRows[key] = row - } - for _, was := range agentStatRows { - // GROUP BY - key := statsGroupBy{ - TimeBucket: was.TimeBucket, - TemplateID: was.TemplateID, - UserID: was.UserID, - } - // SELECT - row, ok := statsRows[key] - if !ok { - row = statsRow{ - statsGroupBy: key, - UsageMinuteBuckets: make(map[time.Time]struct{}), - AppUsageMinutes: make(map[string]int), - } - } - for t := range was.MinuteBuckets { - row.UsageMinuteBuckets[t] = struct{}{} - } - row.UsageMins = len(row.UsageMinuteBuckets) - row.SSHMins += was.SSHMins - row.SFTPMins += was.SFTPMins - row.ReconnectingPTYMins += was.ReconnectingPTYMins - row.VSCodeMins += was.VSCodeMins - row.JetBrainsMins += was.JetBrainsMins - statsRows[key] = row - } - - /* - minute_buckets AS ( - -- Create distinct minute buckets for user-activity, so we can filter out - -- irrelevant latencies. - SELECT DISTINCT ON (stats.start_time, stats.template_id, stats.user_id, minute_bucket) - stats.start_time, - stats.template_id, - stats.user_id, - minute_bucket - FROM - stats, unnest(minute_buckets) AS minute_bucket - ), - latencies AS ( - -- Select all non-zero latencies for all the minutes that a user used the - -- workspace in some way. - SELECT - mb.start_time, - mb.template_id, - mb.user_id, - -- TODO(mafredri): We're doing medians on medians here, we may want to - -- improve upon this at some point. - PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY was.connection_median_latency_ms)::real AS median_latency_ms - FROM - minute_buckets AS mb - JOIN - workspace_agent_stats AS was - ON - date_trunc('minute', was.created_at) = mb.minute_bucket - AND was.template_id = mb.template_id - AND was.user_id = mb.user_id - AND was.connection_median_latency_ms >= 0 - GROUP BY - mb.start_time, mb.template_id, mb.user_id - ) - */ - - type latenciesGroupBy struct { - StartTime time.Time - TemplateID uuid.UUID - UserID uuid.UUID - } - type latenciesRow struct { - latenciesGroupBy - Latencies []float64 - MedianLatencyMS float64 - } - latenciesRows := make(map[latenciesGroupBy]latenciesRow) - for _, stat := range statsRows { - for t := range stat.UsageMinuteBuckets { - // GROUP BY - key := latenciesGroupBy{ - StartTime: stat.TimeBucket, - TemplateID: stat.TemplateID, - UserID: stat.UserID, - } - // JOIN - for _, was := range q.workspaceAgentStats { - if !t.Equal(was.CreatedAt.Truncate(time.Minute)) { - continue - } - if was.TemplateID != stat.TemplateID || was.UserID != stat.UserID { - continue - } - if was.ConnectionMedianLatencyMS < 0 { - continue - } - // SELECT - row, ok := latenciesRows[key] - if !ok { - row = latenciesRow{ - latenciesGroupBy: key, - } - } - row.Latencies = append(row.Latencies, was.ConnectionMedianLatencyMS) - sort.Float64s(row.Latencies) - if len(row.Latencies) == 1 { - row.MedianLatencyMS = was.ConnectionMedianLatencyMS - } else if len(row.Latencies)%2 == 0 { - row.MedianLatencyMS = (row.Latencies[len(row.Latencies)/2-1] + row.Latencies[len(row.Latencies)/2]) / 2 - } else { - row.MedianLatencyMS = row.Latencies[len(row.Latencies)/2] - } - latenciesRows[key] = row - } - } - } - - /* - INSERT INTO template_usage_stats AS tus ( - start_time, - end_time, - template_id, - user_id, - usage_mins, - median_latency_ms, - ssh_mins, - sftp_mins, - reconnecting_pty_mins, - vscode_mins, - jetbrains_mins, - app_usage_mins - ) ( - SELECT - stats.start_time, - stats.end_time, - stats.template_id, - stats.user_id, - stats.usage_mins, - latencies.median_latency_ms, - stats.ssh_mins, - stats.sftp_mins, - stats.reconnecting_pty_mins, - stats.vscode_mins, - stats.jetbrains_mins, - stats.app_usage_mins - FROM - stats - LEFT JOIN - latencies - ON - -- The latencies group-by ensures there at most one row. - latencies.start_time = stats.start_time - AND latencies.template_id = stats.template_id - AND latencies.user_id = stats.user_id - ) - ON CONFLICT - (start_time, template_id, user_id) - DO UPDATE - SET - usage_mins = EXCLUDED.usage_mins, - median_latency_ms = EXCLUDED.median_latency_ms, - ssh_mins = EXCLUDED.ssh_mins, - sftp_mins = EXCLUDED.sftp_mins, - reconnecting_pty_mins = EXCLUDED.reconnecting_pty_mins, - vscode_mins = EXCLUDED.vscode_mins, - jetbrains_mins = EXCLUDED.jetbrains_mins, - app_usage_mins = EXCLUDED.app_usage_mins - WHERE - (tus.*) IS DISTINCT FROM (EXCLUDED.*); - */ - -TemplateUsageStatsInsertLoop: - for _, stat := range statsRows { - // LEFT JOIN latencies - latency, latencyOk := latenciesRows[latenciesGroupBy{ - StartTime: stat.TimeBucket, - TemplateID: stat.TemplateID, - UserID: stat.UserID, - }] - - // SELECT - tus := database.TemplateUsageStat{ - StartTime: stat.TimeBucket, - EndTime: stat.TimeBucket.Add(30 * time.Minute), - TemplateID: stat.TemplateID, - UserID: stat.UserID, - // #nosec G115 - Safe conversion for usage minutes which are expected to be within int16 range - UsageMins: int16(stat.UsageMins), - MedianLatencyMs: sql.NullFloat64{Float64: latency.MedianLatencyMS, Valid: latencyOk}, - // #nosec G115 - Safe conversion for SSH minutes which are expected to be within int16 range - SshMins: int16(stat.SSHMins), - // #nosec G115 - Safe conversion for SFTP minutes which are expected to be within int16 range - SftpMins: int16(stat.SFTPMins), - // #nosec G115 - Safe conversion for ReconnectingPTY minutes which are expected to be within int16 range - ReconnectingPtyMins: int16(stat.ReconnectingPTYMins), - // #nosec G115 - Safe conversion for VSCode minutes which are expected to be within int16 range - VscodeMins: int16(stat.VSCodeMins), - // #nosec G115 - Safe conversion for JetBrains minutes which are expected to be within int16 range - JetbrainsMins: int16(stat.JetBrainsMins), - } - if len(stat.AppUsageMinutes) > 0 { - tus.AppUsageMins = make(map[string]int64, len(stat.AppUsageMinutes)) - for k, v := range stat.AppUsageMinutes { - tus.AppUsageMins[k] = int64(v) - } - } - - // ON CONFLICT - for i, existing := range q.templateUsageStats { - if existing.StartTime.Equal(tus.StartTime) && existing.TemplateID == tus.TemplateID && existing.UserID == tus.UserID { - q.templateUsageStats[i] = tus - continue TemplateUsageStatsInsertLoop - } - } - // INSERT INTO - q.templateUsageStats = append(q.templateUsageStats, tus) - } - - return nil -} - -func (q *FakeQuerier) UpsertWebpushVAPIDKeys(_ context.Context, arg database.UpsertWebpushVAPIDKeysParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - q.webpushVAPIDPublicKey = arg.VapidPublicKey - q.webpushVAPIDPrivateKey = arg.VapidPrivateKey - return nil -} - -func (q *FakeQuerier) UpsertWorkspaceAgentPortShare(_ context.Context, arg database.UpsertWorkspaceAgentPortShareParams) (database.WorkspaceAgentPortShare, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceAgentPortShare{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, share := range q.workspaceAgentPortShares { - if share.WorkspaceID == arg.WorkspaceID && share.Port == arg.Port && share.AgentName == arg.AgentName { - share.ShareLevel = arg.ShareLevel - share.Protocol = arg.Protocol - q.workspaceAgentPortShares[i] = share - return share, nil - } - } - - //nolint:gosimple // casts are not a simplification - psl := database.WorkspaceAgentPortShare{ - WorkspaceID: arg.WorkspaceID, - AgentName: arg.AgentName, - Port: arg.Port, - ShareLevel: arg.ShareLevel, - Protocol: arg.Protocol, - } - q.workspaceAgentPortShares = append(q.workspaceAgentPortShares, psl) - - return psl, nil -} - -func (q *FakeQuerier) UpsertWorkspaceApp(ctx context.Context, arg database.UpsertWorkspaceAppParams) (database.WorkspaceApp, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceApp{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - if arg.SharingLevel == "" { - arg.SharingLevel = database.AppSharingLevelOwner - } - if arg.OpenIn == "" { - arg.OpenIn = database.WorkspaceAppOpenInSlimWindow - } - - buildApp := func(id uuid.UUID, createdAt time.Time) database.WorkspaceApp { - return database.WorkspaceApp{ - ID: id, - CreatedAt: createdAt, - AgentID: arg.AgentID, - Slug: arg.Slug, - DisplayName: arg.DisplayName, - Icon: arg.Icon, - Command: arg.Command, - Url: arg.Url, - External: arg.External, - Subdomain: arg.Subdomain, - SharingLevel: arg.SharingLevel, - HealthcheckUrl: arg.HealthcheckUrl, - HealthcheckInterval: arg.HealthcheckInterval, - HealthcheckThreshold: arg.HealthcheckThreshold, - Health: arg.Health, - Hidden: arg.Hidden, - DisplayOrder: arg.DisplayOrder, - OpenIn: arg.OpenIn, - DisplayGroup: arg.DisplayGroup, - } - } - - for i, app := range q.workspaceApps { - if app.ID == arg.ID { - q.workspaceApps[i] = buildApp(app.ID, app.CreatedAt) - return q.workspaceApps[i], nil - } - } - - workspaceApp := buildApp(arg.ID, arg.CreatedAt) - q.workspaceApps = append(q.workspaceApps, workspaceApp) - return workspaceApp, nil -} - -func (q *FakeQuerier) UpsertWorkspaceAppAuditSession(_ context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (bool, error) { - err := validateDatabaseType(arg) - if err != nil { - return false, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, s := range q.workspaceAppAuditSessions { - if s.AgentID != arg.AgentID { - continue - } - if s.AppID != arg.AppID { - continue - } - if s.UserID != arg.UserID { - continue - } - if s.Ip != arg.Ip { - continue - } - if s.UserAgent != arg.UserAgent { - continue - } - if s.SlugOrPort != arg.SlugOrPort { - continue - } - if s.StatusCode != arg.StatusCode { - continue - } - - staleTime := dbtime.Now().Add(-(time.Duration(arg.StaleIntervalMS) * time.Millisecond)) - fresh := s.UpdatedAt.After(staleTime) - - q.workspaceAppAuditSessions[i].UpdatedAt = arg.UpdatedAt - if !fresh { - q.workspaceAppAuditSessions[i].ID = arg.ID - q.workspaceAppAuditSessions[i].StartedAt = arg.StartedAt - return true, nil - } - return false, nil - } - - q.workspaceAppAuditSessions = append(q.workspaceAppAuditSessions, database.WorkspaceAppAuditSession{ - AgentID: arg.AgentID, - AppID: arg.AppID, - UserID: arg.UserID, - Ip: arg.Ip, - UserAgent: arg.UserAgent, - SlugOrPort: arg.SlugOrPort, - StatusCode: arg.StatusCode, - StartedAt: arg.StartedAt, - UpdatedAt: arg.UpdatedAt, - }) - return true, nil -} - -func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Call this to match the same function calls as the SQL implementation. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL()) - if err != nil { - return nil, err - } - } - - var templates []database.Template - for _, templateTable := range q.templates { - template := q.templateWithNameNoLock(templateTable) - if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil { - continue - } - - if template.Deleted != arg.Deleted { - continue - } - if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID { - continue - } - - if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) { - continue - } - // Filters templates based on the search query filter 'Deprecated' status - // Matching SQL logic: - // -- Filter by deprecated - // AND CASE - // WHEN :deprecated IS NOT NULL THEN - // CASE - // WHEN :deprecated THEN deprecated != '' - // ELSE deprecated = '' - // END - // ELSE true - if arg.Deprecated.Valid && arg.Deprecated.Bool != isDeprecated(template) { - continue - } - if arg.FuzzyName != "" { - if !strings.Contains(strings.ToLower(template.Name), strings.ToLower(arg.FuzzyName)) { - continue - } - } - - if len(arg.IDs) > 0 { - match := false - for _, id := range arg.IDs { - if template.ID == id { - match = true - break - } - } - if !match { - continue - } - } - - if arg.HasAITask.Valid { - tv, err := q.getTemplateVersionByIDNoLock(ctx, template.ActiveVersionID) - if err != nil { - return nil, xerrors.Errorf("get template version: %w", err) - } - tvHasAITask := tv.HasAITask.Valid && tv.HasAITask.Bool - if tvHasAITask != arg.HasAITask.Bool { - continue - } - } - - templates = append(templates, template) - } - if len(templates) > 0 { - slices.SortFunc(templates, func(a, b database.Template) int { - if a.Name != b.Name { - return slice.Ascending(a.Name, b.Name) - } - return slice.Ascending(a.ID.String(), b.ID.String()) - }) - return templates, nil - } - - return nil, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var template database.TemplateTable - for _, t := range q.templates { - if t.ID == id { - template = t - break - } - } - - if template.ID == uuid.Nil { - return nil, sql.ErrNoRows - } - - groups := make([]database.TemplateGroup, 0, len(template.GroupACL)) - for k, v := range template.GroupACL { - group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k)) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get group by ID: %w", err) - } - // We don't delete groups from the map if they - // get deleted so just skip. - if xerrors.Is(err, sql.ErrNoRows) { - continue - } - - groups = append(groups, database.TemplateGroup{ - Group: group, - Actions: v, - }) - } - - return groups, nil -} - -func (q *FakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var template database.TemplateTable - for _, t := range q.templates { - if t.ID == id { - template = t - break - } - } - - if template.ID == uuid.Nil { - return nil, sql.ErrNoRows - } - - users := make([]database.TemplateUser, 0, len(template.UserACL)) - for k, v := range template.UserACL { - user, err := q.getUserByIDNoLock(uuid.MustParse(k)) - if err != nil && xerrors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get user by ID: %w", err) - } - // We don't delete users from the map if they - // get deleted so just skip. - if xerrors.Is(err, sql.ErrNoRows) { - continue - } - - if user.Deleted || user.Status == database.UserStatusSuspended { - continue - } - - users = append(users, database.TemplateUser{ - User: user, - Actions: v, - }) - } - - return users, nil -} - -func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - if prepared != nil { - // Call this to match the same function calls as the SQL implementation. - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return nil, err - } - } - - workspaces := make([]database.WorkspaceTable, 0) - for _, workspace := range q.workspaces { - if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID { - continue - } - - if len(arg.HasParam) > 0 || len(arg.ParamNames) > 0 { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - params := make([]database.WorkspaceBuildParameter, 0) - for _, param := range q.workspaceBuildParameters { - if param.WorkspaceBuildID != build.ID { - continue - } - params = append(params, param) - } - - index := slices.IndexFunc(params, func(buildParam database.WorkspaceBuildParameter) bool { - // If hasParam matches, then we are done. This is a good match. - if slices.ContainsFunc(arg.HasParam, func(name string) bool { - return strings.EqualFold(buildParam.Name, name) - }) { - return true - } - - // Check name + value - match := false - for i := range arg.ParamNames { - matchName := arg.ParamNames[i] - if !strings.EqualFold(matchName, buildParam.Name) { - continue - } - - matchValue := arg.ParamValues[i] - if !strings.EqualFold(matchValue, buildParam.Value) { - continue - } - match = true - break - } - - return match - }) - if index < 0 { - continue - } - } - - if arg.OrganizationID != uuid.Nil { - if workspace.OrganizationID != arg.OrganizationID { - continue - } - } - - if arg.OwnerUsername != "" { - owner, err := q.getUserByIDNoLock(workspace.OwnerID) - if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) { - continue - } - } - - if arg.TemplateName != "" { - template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) - if err == nil && !strings.EqualFold(arg.TemplateName, template.Name) { - continue - } - } - - if arg.UsingActive.Valid { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) - if err != nil { - return nil, xerrors.Errorf("get template: %w", err) - } - - updated := build.TemplateVersionID == template.ActiveVersionID - if arg.UsingActive.Bool != updated { - continue - } - } - - if !arg.Deleted && workspace.Deleted { - continue - } - - if arg.Name != "" && !strings.Contains(strings.ToLower(workspace.Name), strings.ToLower(arg.Name)) { - continue - } - - if !arg.LastUsedBefore.IsZero() { - if workspace.LastUsedAt.After(arg.LastUsedBefore) { - continue - } - } - - if !arg.LastUsedAfter.IsZero() { - if workspace.LastUsedAt.Before(arg.LastUsedAfter) { - continue - } - } - - if arg.Status != "" { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - // This logic should match the logic in the workspace.sql file. - var statusMatch bool - switch database.WorkspaceStatus(arg.Status) { - case database.WorkspaceStatusStarting: - statusMatch = job.JobStatus == database.ProvisionerJobStatusRunning && - build.Transition == database.WorkspaceTransitionStart - case database.WorkspaceStatusStopping: - statusMatch = job.JobStatus == database.ProvisionerJobStatusRunning && - build.Transition == database.WorkspaceTransitionStop - case database.WorkspaceStatusDeleting: - statusMatch = job.JobStatus == database.ProvisionerJobStatusRunning && - build.Transition == database.WorkspaceTransitionDelete - - case "started": - statusMatch = job.JobStatus == database.ProvisionerJobStatusSucceeded && - build.Transition == database.WorkspaceTransitionStart - case database.WorkspaceStatusDeleted: - statusMatch = job.JobStatus == database.ProvisionerJobStatusSucceeded && - build.Transition == database.WorkspaceTransitionDelete - case database.WorkspaceStatusStopped: - statusMatch = job.JobStatus == database.ProvisionerJobStatusSucceeded && - build.Transition == database.WorkspaceTransitionStop - case database.WorkspaceStatusRunning: - statusMatch = job.JobStatus == database.ProvisionerJobStatusSucceeded && - build.Transition == database.WorkspaceTransitionStart - default: - statusMatch = job.JobStatus == database.ProvisionerJobStatus(arg.Status) - } - if !statusMatch { - continue - } - } - - if arg.HasAgent != "" { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID) - if err != nil { - return nil, xerrors.Errorf("get workspace resources: %w", err) - } - - var workspaceResourceIDs []uuid.UUID - for _, wr := range workspaceResources { - workspaceResourceIDs = append(workspaceResourceIDs, wr.ID) - } - - workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs) - if err != nil { - return nil, xerrors.Errorf("get workspace agents: %w", err) - } - - var hasAgentMatched bool - for _, wa := range workspaceAgents { - if mapAgentStatus(wa, arg.AgentInactiveDisconnectTimeoutSeconds) == arg.HasAgent { - hasAgentMatched = true - } - } - - if !hasAgentMatched { - continue - } - } - - if arg.Dormant && !workspace.DormantAt.Valid { - continue - } - - if len(arg.TemplateIDs) > 0 { - match := false - for _, id := range arg.TemplateIDs { - if workspace.TemplateID == id { - match = true - break - } - } - if !match { - continue - } - } - - if len(arg.WorkspaceIds) > 0 { - match := false - for _, id := range arg.WorkspaceIds { - if workspace.ID == id { - match = true - break - } - } - if !match { - continue - } - } - - if arg.HasAITask.Valid { - hasAITask, err := func() (bool, error) { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return false, xerrors.Errorf("get latest build: %w", err) - } - if build.HasAITask.Valid { - return build.HasAITask.Bool, nil - } - // If the build has a nil AI task, check if the job is in progress - // and if it has a non-empty AI Prompt parameter - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return false, xerrors.Errorf("get provisioner job: %w", err) - } - if job.CompletedAt.Valid { - return false, nil - } - parameters, err := q.getWorkspaceBuildParametersNoLock(build.ID) - if err != nil { - return false, xerrors.Errorf("get workspace build parameters: %w", err) - } - for _, param := range parameters { - if param.Name == "AI Prompt" && param.Value != "" { - return true, nil - } - } - return false, nil - }() - if err != nil { - return nil, xerrors.Errorf("get hasAITask: %w", err) - } - if hasAITask != arg.HasAITask.Bool { - continue - } - } - - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil { - continue - } - workspaces = append(workspaces, workspace) - } - - // Sort workspaces (ORDER BY) - isRunning := func(build database.WorkspaceBuild, job database.ProvisionerJob) bool { - return job.CompletedAt.Valid && !job.CanceledAt.Valid && !job.Error.Valid && build.Transition == database.WorkspaceTransitionStart - } - - preloadedWorkspaceBuilds := map[uuid.UUID]database.WorkspaceBuild{} - preloadedProvisionerJobs := map[uuid.UUID]database.ProvisionerJob{} - preloadedUsers := map[uuid.UUID]database.User{} - - for _, w := range workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID) - if err == nil { - preloadedWorkspaceBuilds[w.ID] = build - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err == nil { - preloadedProvisionerJobs[w.ID] = job - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - user, err := q.getUserByIDNoLock(w.OwnerID) - if err == nil { - preloadedUsers[w.ID] = user - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get user: %w", err) - } - } - - sort.Slice(workspaces, func(i, j int) bool { - w1 := workspaces[i] - w2 := workspaces[j] - - // Order by: favorite first - if arg.RequesterID == w1.OwnerID && w1.Favorite { - return true - } - if arg.RequesterID == w2.OwnerID && w2.Favorite { - return false - } - - // Order by: running - w1IsRunning := isRunning(preloadedWorkspaceBuilds[w1.ID], preloadedProvisionerJobs[w1.ID]) - w2IsRunning := isRunning(preloadedWorkspaceBuilds[w2.ID], preloadedProvisionerJobs[w2.ID]) - - if w1IsRunning && !w2IsRunning { - return true - } - - if !w1IsRunning && w2IsRunning { - return false - } - - // Order by: usernames - if strings.Compare(preloadedUsers[w1.ID].Username, preloadedUsers[w2.ID].Username) < 0 { - return true - } - - // Order by: workspace names - return strings.Compare(w1.Name, w2.Name) < 0 - }) - - beforePageCount := len(workspaces) - - if arg.Offset > 0 { - if int(arg.Offset) > len(workspaces) { - return q.convertToWorkspaceRowsNoLock(ctx, []database.WorkspaceTable{}, int64(beforePageCount), arg.WithSummary), nil - } - workspaces = workspaces[arg.Offset:] - } - if arg.Limit > 0 { - if int(arg.Limit) > len(workspaces) { - return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount), arg.WithSummary), nil - } - workspaces = workspaces[:arg.Limit] - } - - return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount), arg.WithSummary), nil -} - -func (q *FakeQuerier) GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx context.Context, ownerID uuid.UUID, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if prepared != nil { - // Call this to match the same function calls as the SQL implementation. - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return nil, err - } - } - workspaces := make([]database.WorkspaceTable, 0) - for _, workspace := range q.workspaces { - if workspace.OwnerID == ownerID && !workspace.Deleted { - workspaces = append(workspaces, workspace) - } - } - - out := make([]database.GetWorkspacesAndAgentsByOwnerIDRow, 0, len(workspaces)) - for _, w := range workspaces { - // these always exist - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - outAgents := make([]database.AgentIDNamePair, 0) - resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID) - if err != nil { - return nil, xerrors.Errorf("get workspace resources: %w", err) - } - if len(resources) > 0 { - agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, []uuid.UUID{resources[0].ID}) - if err != nil { - return nil, xerrors.Errorf("get workspace agents: %w", err) - } - for _, a := range agents { - outAgents = append(outAgents, database.AgentIDNamePair{ - ID: a.ID, - Name: a.Name, - }) - } - } - - out = append(out, database.GetWorkspacesAndAgentsByOwnerIDRow{ - ID: w.ID, - Name: w.Name, - JobStatus: job.JobStatus, - Transition: build.Transition, - Agents: outAgents, - }) - } - - return out, nil -} - -func (q *FakeQuerier) GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIDs []uuid.UUID, prepared rbac.PreparedAuthorized) ([]database.WorkspaceBuildParameter, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if prepared != nil { - // Call this to match the same function calls as the SQL implementation. - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return nil, err - } - } - - filteredParameters := make([]database.WorkspaceBuildParameter, 0) - for _, buildID := range workspaceBuildIDs { - parameters, err := q.GetWorkspaceBuildParameters(ctx, buildID) - if err != nil { - return nil, err - } - filteredParameters = append(filteredParameters, parameters...) - } - - return filteredParameters, nil -} - -func (q *FakeQuerier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - // Call this to match the same function calls as the SQL implementation. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ - VariableConverter: regosql.UserConverter(), - }) - if err != nil { - return nil, err - } - } - - users, err := q.GetUsers(ctx, arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - filteredUsers := make([]database.GetUsersRow, 0, len(users)) - for _, user := range users { - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { - continue - } - - filteredUsers = append(filteredUsers, user) - } - return filteredUsers, nil -} - -func (q *FakeQuerier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg database.GetAuditLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetAuditLogsOffsetRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - // Call this to match the same function calls as the SQL implementation. - // It functionally does nothing for filtering. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ - VariableConverter: regosql.AuditLogConverter(), - }) - if err != nil { - return nil, err - } - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - if arg.LimitOpt == 0 { - // Default to 100 is set in the SQL query. - arg.LimitOpt = 100 - } - - logs := make([]database.GetAuditLogsOffsetRow, 0, arg.LimitOpt) - - // q.auditLogs are already sorted by time DESC, so no need to sort after the fact. - for _, alog := range q.auditLogs { - if arg.OffsetOpt > 0 { - arg.OffsetOpt-- - continue - } - if arg.RequestID != uuid.Nil && arg.RequestID != alog.RequestID { - continue - } - if arg.OrganizationID != uuid.Nil && arg.OrganizationID != alog.OrganizationID { - continue - } - if arg.Action != "" && string(alog.Action) != arg.Action { - continue - } - if arg.ResourceType != "" && !strings.Contains(string(alog.ResourceType), arg.ResourceType) { - continue - } - if arg.ResourceID != uuid.Nil && alog.ResourceID != arg.ResourceID { - continue - } - if arg.Username != "" { - user, err := q.getUserByIDNoLock(alog.UserID) - if err == nil && !strings.EqualFold(arg.Username, user.Username) { - continue - } - } - if arg.Email != "" { - user, err := q.getUserByIDNoLock(alog.UserID) - if err == nil && !strings.EqualFold(arg.Email, user.Email) { - continue - } - } - if !arg.DateFrom.IsZero() { - if alog.Time.Before(arg.DateFrom) { - continue - } - } - if !arg.DateTo.IsZero() { - if alog.Time.After(arg.DateTo) { - continue - } - } - if arg.BuildReason != "" { - workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(context.Background(), alog.ResourceID) - if err == nil && !strings.EqualFold(arg.BuildReason, string(workspaceBuild.Reason)) { - continue - } - } - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, alog.RBACObject()) != nil { - continue - } - - user, err := q.getUserByIDNoLock(alog.UserID) - userValid := err == nil - - org, _ := q.getOrganizationByIDNoLock(alog.OrganizationID) - - cpy := alog - logs = append(logs, database.GetAuditLogsOffsetRow{ - AuditLog: cpy, - OrganizationName: org.Name, - OrganizationDisplayName: org.DisplayName, - OrganizationIcon: org.Icon, - UserUsername: sql.NullString{String: user.Username, Valid: userValid}, - UserName: sql.NullString{String: user.Name, Valid: userValid}, - UserEmail: sql.NullString{String: user.Email, Valid: userValid}, - UserCreatedAt: sql.NullTime{Time: user.CreatedAt, Valid: userValid}, - UserUpdatedAt: sql.NullTime{Time: user.UpdatedAt, Valid: userValid}, - UserLastSeenAt: sql.NullTime{Time: user.LastSeenAt, Valid: userValid}, - UserLoginType: database.NullLoginType{LoginType: user.LoginType, Valid: userValid}, - UserDeleted: sql.NullBool{Bool: user.Deleted, Valid: userValid}, - UserQuietHoursSchedule: sql.NullString{String: user.QuietHoursSchedule, Valid: userValid}, - UserStatus: database.NullUserStatus{UserStatus: user.Status, Valid: userValid}, - UserRoles: user.RBACRoles, - }) - - if len(logs) >= int(arg.LimitOpt) { - break - } - } - - return logs, nil -} - -func (q *FakeQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error) { - if err := validateDatabaseType(arg); err != nil { - return 0, err - } - - // Call this to match the same function calls as the SQL implementation. - // It functionally does nothing for filtering. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ - VariableConverter: regosql.AuditLogConverter(), - }) - if err != nil { - return 0, err - } - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - var count int64 - - // q.auditLogs are already sorted by time DESC, so no need to sort after the fact. - for _, alog := range q.auditLogs { - if arg.RequestID != uuid.Nil && arg.RequestID != alog.RequestID { - continue - } - if arg.OrganizationID != uuid.Nil && arg.OrganizationID != alog.OrganizationID { - continue - } - if arg.Action != "" && string(alog.Action) != arg.Action { - continue - } - if arg.ResourceType != "" && !strings.Contains(string(alog.ResourceType), arg.ResourceType) { - continue - } - if arg.ResourceID != uuid.Nil && alog.ResourceID != arg.ResourceID { - continue - } - if arg.Username != "" { - user, err := q.getUserByIDNoLock(alog.UserID) - if err == nil && !strings.EqualFold(arg.Username, user.Username) { - continue - } - } - if arg.Email != "" { - user, err := q.getUserByIDNoLock(alog.UserID) - if err == nil && !strings.EqualFold(arg.Email, user.Email) { - continue - } - } - if !arg.DateFrom.IsZero() { - if alog.Time.Before(arg.DateFrom) { - continue - } - } - if !arg.DateTo.IsZero() { - if alog.Time.After(arg.DateTo) { - continue - } - } - if arg.BuildReason != "" { - workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(context.Background(), alog.ResourceID) - if err == nil && !strings.EqualFold(arg.BuildReason, string(workspaceBuild.Reason)) { - continue - } - } - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, alog.RBACObject()) != nil { - continue - } - - count++ - } - - return count, nil -} diff --git a/coderd/database/dbmem/dbmem_test.go b/coderd/database/dbmem/dbmem_test.go deleted file mode 100644 index c3df828b95c98..0000000000000 --- a/coderd/database/dbmem/dbmem_test.go +++ /dev/null @@ -1,209 +0,0 @@ -package dbmem_test - -import ( - "context" - "database/sql" - "sort" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" - "github.com/coder/coder/v2/coderd/database/dbtime" -) - -// test that transactions don't deadlock, and that we don't see intermediate state. -func TestInTx(t *testing.T) { - t.Parallel() - - uut := dbmem.New() - - inTx := make(chan any) - queriesDone := make(chan any) - queriesStarted := make(chan any) - go func() { - err := uut.InTx(func(tx database.Store) error { - close(inTx) - _, err := tx.InsertOrganization(context.Background(), database.InsertOrganizationParams{ - Name: "1", - }) - assert.NoError(t, err) - <-queriesStarted - time.Sleep(5 * time.Millisecond) - _, err = tx.InsertOrganization(context.Background(), database.InsertOrganizationParams{ - Name: "2", - }) - assert.NoError(t, err) - return nil - }, nil) - assert.NoError(t, err) - }() - var nums []int - go func() { - <-inTx - for i := 0; i < 20; i++ { - orgs, err := uut.GetOrganizations(context.Background(), database.GetOrganizationsParams{}) - if err != nil { - assert.ErrorIs(t, err, sql.ErrNoRows) - } - nums = append(nums, len(orgs)) - time.Sleep(time.Millisecond) - } - close(queriesDone) - }() - close(queriesStarted) - <-queriesDone - // ensure we never saw 1 org, only 0 or 2. - for i := 0; i < 20; i++ { - assert.NotEqual(t, 1, nums[i]) - } -} - -// TestUserOrder ensures that the fake database returns users sorted by username. -func TestUserOrder(t *testing.T) { - t.Parallel() - - db := dbmem.New() - now := dbtime.Now() - - usernames := []string{"b-user", "d-user", "a-user", "c-user", "e-user"} - for _, username := range usernames { - dbgen.User(t, db, database.User{Username: username, CreatedAt: now}) - } - - users, err := db.GetUsers(context.Background(), database.GetUsersParams{}) - require.NoError(t, err) - require.Lenf(t, users, len(usernames), "expected %d users", len(usernames)) - - sort.Strings(usernames) - for i, user := range users { - require.Equal(t, usernames[i], user.Username) - } -} - -func TestProxyByHostname(t *testing.T) { - t.Parallel() - - db := dbmem.New() - - // Insert a bunch of different proxies. - proxies := []struct { - name string - accessURL string - wildcardHostname string - }{ - { - name: "one", - accessURL: "https://one.coder.com", - wildcardHostname: "*.wildcard.one.coder.com", - }, - { - name: "two", - accessURL: "https://two.coder.com", - wildcardHostname: "*--suffix.two.coder.com", - }, - } - for _, p := range proxies { - dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{ - Name: p.name, - Url: p.accessURL, - WildcardHostname: p.wildcardHostname, - }) - } - - cases := []struct { - name string - testHostname string - allowAccessURL bool - allowWildcardHost bool - matchProxyName string - }{ - { - name: "NoMatch", - testHostname: "test.com", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "", - }, - { - name: "MatchAccessURL", - testHostname: "one.coder.com", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "one", - }, - { - name: "MatchWildcard", - testHostname: "something.wildcard.one.coder.com", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "one", - }, - { - name: "MatchSuffix", - testHostname: "something--suffix.two.coder.com", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "two", - }, - { - name: "ValidateHostname/1", - testHostname: ".*ne.coder.com", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "", - }, - { - name: "ValidateHostname/2", - testHostname: "https://one.coder.com", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "", - }, - { - name: "ValidateHostname/3", - testHostname: "one.coder.com:8080/hello", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "", - }, - { - name: "IgnoreAccessURLMatch", - testHostname: "one.coder.com", - allowAccessURL: false, - allowWildcardHost: true, - matchProxyName: "", - }, - { - name: "IgnoreWildcardMatch", - testHostname: "hi.wildcard.one.coder.com", - allowAccessURL: true, - allowWildcardHost: false, - matchProxyName: "", - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - t.Parallel() - - proxy, err := db.GetWorkspaceProxyByHostname(context.Background(), database.GetWorkspaceProxyByHostnameParams{ - Hostname: c.testHostname, - AllowAccessUrl: c.allowAccessURL, - AllowWildcardHostname: c.allowWildcardHost, - }) - if c.matchProxyName == "" { - require.ErrorIs(t, err, sql.ErrNoRows) - require.Empty(t, proxy) - } else { - require.NoError(t, err) - require.NotEmpty(t, proxy) - require.Equal(t, c.matchProxyName, proxy.Name) - } - }) - } -} diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index fa3567c490826..f67e3206b09d1 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -20,14 +20,15 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/testutil" ) // WillUsePostgres returns true if a call to NewDB() will return a real, postgres-backed Store and Pubsub. +// TODO(hugodutka): since we removed the in-memory database, this is always true, +// and we need to remove this function. https://github.com/coder/internal/issues/758 func WillUsePostgres() bool { - return os.Getenv("DB") != "" + return true } type options struct { @@ -109,52 +110,48 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) { var db database.Store var ps pubsub.Pubsub - if WillUsePostgres() { - connectionURL := os.Getenv("CODER_PG_CONNECTION_URL") - if connectionURL == "" && o.url != "" { - connectionURL = o.url - } - if connectionURL == "" { - var err error - connectionURL, err = Open(t) - require.NoError(t, err) - } - - if o.fixedTimezone == "" { - // To make sure we find timezone-related issues, we set the timezone - // of the database to a non-UTC one. - // The below was picked due to the following properties: - // - It has a non-UTC offset - // - It has a fractional hour UTC offset - // - It includes a daylight savings time component - o.fixedTimezone = DefaultTimezone - } - dbName := dbNameFromConnectionURL(t, connectionURL) - setDBTimezone(t, connectionURL, dbName, o.fixedTimezone) - sqlDB, err := sql.Open("postgres", connectionURL) + connectionURL := os.Getenv("CODER_PG_CONNECTION_URL") + if connectionURL == "" && o.url != "" { + connectionURL = o.url + } + if connectionURL == "" { + var err error + connectionURL, err = Open(t) require.NoError(t, err) - t.Cleanup(func() { - _ = sqlDB.Close() - }) - if o.returnSQLDB != nil { - o.returnSQLDB(sqlDB) - } - if o.dumpOnFailure { - t.Cleanup(func() { DumpOnFailure(t, connectionURL) }) - } - // Unit tests should not retry serial transaction failures. - db = database.New(sqlDB, database.WithSerialRetryCount(1)) + } - ps, err = pubsub.New(context.Background(), o.logger, sqlDB, connectionURL) - require.NoError(t, err) - t.Cleanup(func() { - _ = ps.Close() - }) - } else { - db = dbmem.New() - ps = pubsub.NewInMemory() + if o.fixedTimezone == "" { + // To make sure we find timezone-related issues, we set the timezone + // of the database to a non-UTC one. + // The below was picked due to the following properties: + // - It has a non-UTC offset + // - It has a fractional hour UTC offset + // - It includes a daylight savings time component + o.fixedTimezone = DefaultTimezone + } + dbName := dbNameFromConnectionURL(t, connectionURL) + setDBTimezone(t, connectionURL, dbName, o.fixedTimezone) + + sqlDB, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + t.Cleanup(func() { + _ = sqlDB.Close() + }) + if o.returnSQLDB != nil { + o.returnSQLDB(sqlDB) + } + if o.dumpOnFailure { + t.Cleanup(func() { DumpOnFailure(t, connectionURL) }) } + // Unit tests should not retry serial transaction failures. + db = database.New(sqlDB, database.WithSerialRetryCount(1)) + + ps, err = pubsub.New(context.Background(), o.logger, sqlDB, connectionURL) + require.NoError(t, err) + t.Cleanup(func() { + _ = ps.Close() + }) return db, ps } diff --git a/coderd/database/oidcclaims_test.go b/coderd/database/oidcclaims_test.go index f9fe1711b19b8..fe4a10d83495e 100644 --- a/coderd/database/oidcclaims_test.go +++ b/coderd/database/oidcclaims_test.go @@ -222,7 +222,6 @@ func (g userGenerator) withLink(lt database.LoginType, rawJSON json.RawMessage) err := sql.UpdateUserLinkRawJSON(context.Background(), user.ID, rawJSON) require.NoError(t, err) } else { - // no need to test the json key logic in dbmem. Everything is type safe. var claims database.UserLinkClaims err := json.Unmarshal(rawJSON, &claims) require.NoError(t, err) diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index f80f68115ad2c..4dc7bdd66d30a 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -1400,7 +1400,6 @@ func TestGetUsers_IncludeSystem(t *testing.T) { // Given: a system user // postgres: introduced by migration coderd/database/migrations/00030*_system_user.up.sql - // dbmem: created in dbmem/dbmem.go db, _ := dbtestutil.NewDB(t) other := dbgen.User(t, db, database.User{}) users, err := db.GetUsers(ctx, database.GetUsersParams{ diff --git a/coderd/notifications/notifications_test.go b/coderd/notifications/notifications_test.go index ec9edee4c8514..e213a62df9996 100644 --- a/coderd/notifications/notifications_test.go +++ b/coderd/notifications/notifications_test.go @@ -769,7 +769,7 @@ func TestNotificationTemplates_Golden(t *testing.T) { hello = "localhost" from = "system@coder.com" - hint = "run \"DB=ci make gen/golden-files\" and commit the changes" + hint = "run \"make gen/golden-files\" and commit the changes" ) tests := []struct { diff --git a/coderd/users_test.go b/coderd/users_test.go index bd0f138b6a339..9d695f37c9906 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -32,7 +32,6 @@ import ( "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/util/ptr" @@ -1794,15 +1793,6 @@ func TestUsersFilter(t *testing.T) { } } - // TODO: This can be removed with dbmem - if !dbtestutil.WillUsePostgres() { - for i := range matched.Users { - if len(matched.Users[i].OrganizationIDs) == 0 { - matched.Users[i].OrganizationIDs = nil - } - } - } - require.ElementsMatch(t, exp, matched.Users, "expected users returned") }) } diff --git a/codersdk/deployment.go b/codersdk/deployment.go index b24e321b8e434..266baaee8cd9a 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -354,7 +354,6 @@ type DeploymentValues struct { ProxyTrustedHeaders serpent.StringArray `json:"proxy_trusted_headers,omitempty" typescript:",notnull"` ProxyTrustedOrigins serpent.StringArray `json:"proxy_trusted_origins,omitempty" typescript:",notnull"` CacheDir serpent.String `json:"cache_directory,omitempty" typescript:",notnull"` - InMemoryDatabase serpent.Bool `json:"in_memory_database,omitempty" typescript:",notnull"` EphemeralDeployment serpent.Bool `json:"ephemeral_deployment,omitempty" typescript:",notnull"` PostgresURL serpent.String `json:"pg_connection_url,omitempty" typescript:",notnull"` PostgresAuth string `json:"pg_auth,omitempty" typescript:",notnull"` @@ -2404,15 +2403,6 @@ func (c *DeploymentValues) Options() serpent.OptionSet { Value: &c.CacheDir, YAML: "cacheDir", }, - { - Name: "In Memory Database", - Description: "Controls whether data will be stored in an in-memory database.", - Flag: "in-memory", - Env: "CODER_IN_MEMORY", - Hidden: true, - Value: &c.InMemoryDatabase, - YAML: "inMemoryDatabase", - }, { Name: "Ephemeral Deployment", Description: "Controls whether Coder data, including built-in Postgres, will be stored in a temporary directory and deleted when the server is stopped.", diff --git a/docs/about/contributing/backend.md b/docs/about/contributing/backend.md index fd1a80dc6b73c..1ffafa62fb324 100644 --- a/docs/about/contributing/backend.md +++ b/docs/about/contributing/backend.md @@ -68,7 +68,6 @@ The Coder backend is organized into multiple packages and directories, each with * [dbauthz](https://github.com/coder/coder/tree/main/coderd/database/dbauthz): AuthZ wrappers for database queries, ideally, every query should verify first if the accessor is eligible to see the query results. * [dbfake](https://github.com/coder/coder/tree/main/coderd/database/dbfake): helper functions to quickly prepare the initial database state for testing purposes (e.g. create N healthy workspaces and templates), operates on higher level than [dbgen](https://github.com/coder/coder/tree/main/coderd/database/dbgen) * [dbgen](https://github.com/coder/coder/tree/main/coderd/database/dbgen): helper functions to insert raw records to the database store, used for testing purposes - * [dbmem](https://github.com/coder/coder/tree/main/coderd/database/dbmem): in-memory implementation of the database store, ideally, every real query should have a complimentary Go implementation * [dbmock](https://github.com/coder/coder/tree/main/coderd/database/dbmock): a store wrapper for database queries, useful to verify if the function has been called, used for testing purposes * [dbpurge](https://github.com/coder/coder/tree/main/coderd/database/dbpurge): simple wrapper for periodic database cleanup operations * [migrations](https://github.com/coder/coder/tree/main/coderd/database/migrations): an ordered list of up/down database migrations, use `./create_migration.sh my_migration_name` to modify the database schema diff --git a/enterprise/cli/server_test.go b/enterprise/cli/server_test.go index 7022f0d33b7a7..7489699a6f3dd 100644 --- a/enterprise/cli/server_test.go +++ b/enterprise/cli/server_test.go @@ -18,9 +18,6 @@ import ( ) func dbArg(t *testing.T) string { - if !dbtestutil.WillUsePostgres() { - return "--in-memory" - } dbURL, err := dbtestutil.Open(t) require.NoError(t, err) return "--postgres-url=" + dbURL diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 5f79608275f96..bd128b25e2ef4 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -780,13 +780,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { if initial, changed, enabled := featureChanged(codersdk.FeatureHighAvailability); shouldUpdate(initial, changed, enabled) { var coordinator agpltailnet.Coordinator - // If HA is enabled, but the database is in-memory, we can't actually - // run HA and the PG coordinator. So throw a log line, and continue to use - // the in memory AGPL coordinator. - if enabled && api.DeploymentValues.InMemoryDatabase.Value() { - api.Logger.Warn(ctx, "high availability is enabled, but cannot be configured due to the database being set to in-memory") - } - if enabled && !api.DeploymentValues.InMemoryDatabase.Value() { + if enabled { haCoordinator, err := tailnet.NewPGCoord(api.ctx, api.Logger, api.Pubsub, api.Database) if err != nil { api.Logger.Error(ctx, "unable to set up high availability coordinator", slog.Error(err)) diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go index bd81e5a039599..e4088e83d09f5 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest.go +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -22,7 +22,6 @@ import ( "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/drpcsdk" @@ -149,8 +148,6 @@ func NewWithAPI(t *testing.T, options *Options) ( // we check for the in-memory test types so that the real types don't have to exported _, ok := coderAPI.Pubsub.(*pubsub.MemoryPubsub) require.False(t, ok, "FeatureHighAvailability is incompatible with MemoryPubsub") - _, ok = coderAPI.Database.(*dbmem.FakeQuerier) - require.False(t, ok, "FeatureHighAvailability is incompatible with dbmem") } } _ = AddLicense(t, client, lo) diff --git a/scripts/dbgen/main.go b/scripts/dbgen/main.go index 7396a5140d605..561a46199a6ef 100644 --- a/scripts/dbgen/main.go +++ b/scripts/dbgen/main.go @@ -7,7 +7,6 @@ import ( "go/format" "go/token" "os" - "path" "path/filepath" "reflect" "runtime" @@ -52,14 +51,6 @@ func run() error { return err } databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database") - - err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbmem", "dbmem.go"), "q", "FakeQuerier", func(_ stubParams) string { - return `panic("not implemented")` - }) - if err != nil { - return xerrors.Errorf("stub dbmem: %w", err) - } - err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbmetrics", "querymetrics.go"), "m", "queryMetricsStore", func(params stubParams) string { return fmt.Sprintf(` start := time.Now() @@ -257,13 +248,13 @@ func orderAndStubDatabaseFunctions(filePath, receiver, structName string, stub f contents, err := os.ReadFile(filePath) if err != nil { - return xerrors.Errorf("read dbmem: %w", err) + return xerrors.Errorf("read file: %w", err) } // Required to preserve imports! f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), packageName, goast.New()).Parse(contents) if err != nil { - return xerrors.Errorf("parse dbmem: %w", err) + return xerrors.Errorf("parse file: %w", err) } pointer := false @@ -298,76 +289,6 @@ func orderAndStubDatabaseFunctions(filePath, receiver, structName string, stub f for _, fn := range funcs { var bodyStmts []dst.Stmt - // Add input validation, only relevant for dbmem. - if strings.Contains(filePath, "dbmem") && len(fn.Func.Params.List) == 2 && fn.Func.Params.List[1].Names[0].Name == "arg" { - /* - err := validateDatabaseType(arg) - if err != nil { - return database.User{}, err - } - */ - bodyStmts = append(bodyStmts, &dst.AssignStmt{ - Lhs: []dst.Expr{dst.NewIdent("err")}, - Tok: token.DEFINE, - Rhs: []dst.Expr{ - &dst.CallExpr{ - Fun: &dst.Ident{ - Name: "validateDatabaseType", - }, - Args: []dst.Expr{dst.NewIdent("arg")}, - }, - }, - }) - returnStmt := &dst.ReturnStmt{ - Results: []dst.Expr{}, // Filled below. - } - bodyStmts = append(bodyStmts, &dst.IfStmt{ - Cond: &dst.BinaryExpr{ - X: dst.NewIdent("err"), - Op: token.NEQ, - Y: dst.NewIdent("nil"), - }, - Body: &dst.BlockStmt{ - List: []dst.Stmt{ - returnStmt, - }, - }, - Decs: dst.IfStmtDecorations{ - NodeDecs: dst.NodeDecs{ - After: dst.EmptyLine, - }, - }, - }) - for _, r := range fn.Func.Results.List { - switch typ := r.Type.(type) { - case *dst.StarExpr, *dst.ArrayType, *dst.SelectorExpr: - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("nil")) - case *dst.Ident: - if typ.Path != "" { - returnStmt.Results = append(returnStmt.Results, dst.NewIdent(fmt.Sprintf("%s.%s{}", path.Base(typ.Path), typ.Name))) - } else { - switch typ.Name { - case "uint8", "uint16", "uint32", "uint64", "uint", "uintptr", - "int8", "int16", "int32", "int64", "int", - "byte", "rune", - "float32", "float64", - "complex64", "complex128": - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("0")) - case "string": - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("\"\"")) - case "bool": - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("false")) - case "error": - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("err")) - default: - panic(fmt.Sprintf("unknown ident: %#v", r.Type)) - } - } - default: - panic(fmt.Sprintf("unknown return type: %T", r.Type)) - } - } - } decl, ok := declByName[fn.Name] if !ok { typeName := structName diff --git a/site/e2e/playwright.config.ts b/site/e2e/playwright.config.ts index 436af99240493..4b3e5c5c86fc6 100644 --- a/site/e2e/playwright.config.ts +++ b/site/e2e/playwright.config.ts @@ -72,7 +72,7 @@ export default defineConfig({ "--global-config $(mktemp -d -t e2e-XXXXXXXXXX)", `--access-url=http://localhost:${coderPort}`, `--http-address=0.0.0.0:${coderPort}`, - "--in-memory", + "--ephemeral", "--telemetry=false", "--dangerous-disable-rate-limits", "--provisioner-daemons 10", From b14e99c6c7dcfe45ffc28dd54f5d7c7dff108b06 Mon Sep 17 00:00:00 2001 From: Hugo Dutka Date: Tue, 8 Jul 2025 16:38:10 +0000 Subject: [PATCH 2/2] make gen --- cli/testdata/server-config.yaml.golden | 3 --- coderd/apidoc/docs.go | 3 --- coderd/apidoc/swagger.json | 3 --- docs/reference/api/general.md | 1 - docs/reference/api/schemas.md | 3 --- site/src/api/typesGenerated.ts | 1 - 6 files changed, 14 deletions(-) diff --git a/cli/testdata/server-config.yaml.golden b/cli/testdata/server-config.yaml.golden index dabeab8ca48bd..e23274e442078 100644 --- a/cli/testdata/server-config.yaml.golden +++ b/cli/testdata/server-config.yaml.golden @@ -462,9 +462,6 @@ enableSwagger: false # to be configured as a shared directory across coderd/provisionerd replicas. # (default: [cache dir], type: string) cacheDir: [cache dir] -# Controls whether data will be stored in an in-memory database. -# (default: , type: bool) -inMemoryDatabase: false # Controls whether Coder data, including built-in Postgres, will be stored in a # temporary directory and deleted when the server is stopped. # (default: , type: bool) diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 8dea791f7b763..79cff80b1fbc5 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -12305,9 +12305,6 @@ const docTemplate = `{ "http_cookies": { "$ref": "#/definitions/codersdk.HTTPCookieConfig" }, - "in_memory_database": { - "type": "boolean" - }, "job_hang_detector_interval": { "type": "integer" }, diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 143342c1c2147..5fa1d98030cb5 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -10989,9 +10989,6 @@ "http_cookies": { "$ref": "#/definitions/codersdk.HTTPCookieConfig" }, - "in_memory_database": { - "type": "boolean" - }, "job_hang_detector_interval": { "type": "integer" }, diff --git a/docs/reference/api/general.md b/docs/reference/api/general.md index 8f440c55b42d6..72543f6774dfd 100644 --- a/docs/reference/api/general.md +++ b/docs/reference/api/general.md @@ -265,7 +265,6 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \ "same_site": "string", "secure_auth_cookie": true }, - "in_memory_database": true, "job_hang_detector_interval": 0, "logging": { "human": "string", diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index 305c35e1bebdd..6ca1cfb9dfe51 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -1987,7 +1987,6 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o "same_site": "string", "secure_auth_cookie": true }, - "in_memory_database": true, "job_hang_detector_interval": 0, "logging": { "human": "string", @@ -2475,7 +2474,6 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o "same_site": "string", "secure_auth_cookie": true }, - "in_memory_database": true, "job_hang_detector_interval": 0, "logging": { "human": "string", @@ -2772,7 +2770,6 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o | `hide_ai_tasks` | boolean | false | | | | `http_address` | string | false | | Http address is a string because it may be set to zero to disable. | | `http_cookies` | [codersdk.HTTPCookieConfig](#codersdkhttpcookieconfig) | false | | | -| `in_memory_database` | boolean | false | | | | `job_hang_detector_interval` | integer | false | | | | `logging` | [codersdk.LoggingConfig](#codersdkloggingconfig) | false | | | | `metrics_cache_refresh_interval` | integer | false | | | diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 8a8a5583e9069..53dc919df2df3 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -668,7 +668,6 @@ export interface DeploymentValues { readonly proxy_trusted_headers?: string; readonly proxy_trusted_origins?: string; readonly cache_directory?: string; - readonly in_memory_database?: boolean; readonly ephemeral_deployment?: boolean; readonly pg_connection_url?: string; readonly pg_auth?: string;