Thanks to visit codestin.com
Credit goes to github.com

Skip to content

feat: use JWT ticket to avoid DB queries on apps #6148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 55 additions & 11 deletions cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -587,19 +588,62 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
defer options.Pubsub.Close()
}

deploymentID, err := options.Database.GetDeploymentID(ctx)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
return xerrors.Errorf("get deployment id: %w", err)
}
if deploymentID == "" {
deploymentID = uuid.NewString()
err = options.Database.InsertDeploymentID(ctx, deploymentID)
var deploymentID string
err = options.Database.InTx(func(tx database.Store) error {
// This will block until the lock is acquired, and will be
// automatically released when the transaction ends.
err := tx.AcquireLock(ctx, database.LockIDDeploymentSetup)
if err != nil {
return xerrors.Errorf("acquire lock: %w", err)
}

deploymentID, err = tx.GetDeploymentID(ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get deployment id: %w", err)
}
if deploymentID == "" {
deploymentID = uuid.NewString()
err = tx.InsertDeploymentID(ctx, deploymentID)
if err != nil {
return xerrors.Errorf("set deployment id: %w", err)
}
}

// Read the app signing key from the DB. We store it hex
// encoded since the config table uses strings for the value and
// we don't want to deal with automatic encoding issues.
appSigningKeyStr, err := tx.GetAppSigningKey(ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get app signing key: %w", err)
}
if appSigningKeyStr == "" {
// Generate 64 byte secure random string.
b := make([]byte, 64)
_, err := rand.Read(b)
if err != nil {
return xerrors.Errorf("generate fresh app signing key: %w", err)
}

appSigningKeyStr = hex.EncodeToString(b)
err = tx.InsertAppSigningKey(ctx, appSigningKeyStr)
if err != nil {
return xerrors.Errorf("insert freshly generated app signing key to database: %w", err)
}
}

appSigningKey, err := hex.DecodeString(appSigningKeyStr)
if err != nil {
return xerrors.Errorf("set deployment id: %w", err)
return xerrors.Errorf("decode app signing key from database as hex: %w", err)
}
if len(appSigningKey) != 64 {
return xerrors.Errorf("app signing key must be 64 bytes, key in database is %d bytes", len(appSigningKey))
}

options.AppSigningKey = appSigningKey
return nil
}, nil)
if err != nil {
return err
}

// Disable telemetry if the in-memory database is used unless explicitly defined!
Expand Down
78 changes: 32 additions & 46 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/updatecheck"
"github.com/coder/coder/coderd/util/slice"
"github.com/coder/coder/coderd/workspaceapps"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionerd/proto"
Expand Down Expand Up @@ -120,6 +121,9 @@ type Options struct {
SwaggerEndpoint bool
SetUserGroups func(ctx context.Context, tx database.Store, userID uuid.UUID, groupNames []string) error
TemplateScheduleStore schedule.TemplateScheduleStore
// AppSigningKey denotes the symmetric key to use for signing app tickets.
// The key must be 64 bytes long.
AppSigningKey []byte

// APIRateLimit is the minutely throughput rate limit per user or ip.
// Setting a rate limit <0 will disable the rate limiter across the entire
Expand Down Expand Up @@ -214,6 +218,9 @@ func New(options *Options) *API {
if options.TemplateScheduleStore == nil {
options.TemplateScheduleStore = schedule.NewAGPLTemplateScheduleStore()
}
if len(options.AppSigningKey) != 64 {
panic("coderd: AppSigningKey must be 64 bytes long")
}

siteCacheDir := options.CacheDir
if siteCacheDir != "" {
Expand All @@ -236,6 +243,11 @@ func New(options *Options) *API {
// static files since it only affects browsers.
staticHandler = httpmw.HSTS(staticHandler, options.StrictTransportSecurityCfg)

oauthConfigs := &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
OIDC: options.OIDCConfig,
}

r := chi.NewRouter()
ctx, cancel := context.WithCancel(context.Background())
api := &API{
Expand All @@ -250,6 +262,15 @@ func New(options *Options) *API {
Authorizer: options.Authorizer,
Logger: options.Logger,
},
WorkspaceAppsProvider: workspaceapps.New(
options.Logger.Named("workspaceapps"),
options.AccessURL,
options.Authorizer,
options.Database,
options.DeploymentConfig,
oauthConfigs,
options.AppSigningKey,
),
metricsCache: metricsCache,
Auditor: atomic.Pointer[audit.Auditor]{},
TemplateScheduleStore: atomic.Pointer[schedule.TemplateScheduleStore]{},
Expand All @@ -266,20 +287,16 @@ func New(options *Options) *API {
api.TemplateScheduleStore.Store(&options.TemplateScheduleStore)
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
oauthConfigs := &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
OIDC: options.OIDCConfig,
}

apiKeyMiddleware := httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: false,
DisableSessionExpiryRefresh: options.DeploymentConfig.DisableSessionExpiryRefresh.Value,
Optional: false,
})
// Same as above but it redirects to the login page.
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: true,
Expand All @@ -305,23 +322,9 @@ func New(options *Options) *API {
httpmw.Prometheus(options.PrometheusRegistry),
// handleSubdomainApplications checks if the first subdomain is a valid
// app URL. If it is, it will serve that application.
api.handleSubdomainApplications(
apiRateLimiter,
// Middleware to impose on the served application.
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
// The code handles the the case where the user is not
// authenticated automatically.
RedirectToLogin: false,
DisableSessionExpiryRefresh: options.DeploymentConfig.DisableSessionExpiryRefresh.Value,
Optional: true,
}),
httpmw.AsAuthzSystem(
httpmw.ExtractUserParam(api.Database, false),
httpmw.ExtractWorkspaceAndAgentParam(api.Database),
),
),
//
// Workspace apps do their own auth.
api.handleSubdomainApplications(apiRateLimiter),
// Build-Version is helpful for debugging.
func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -345,26 +348,8 @@ func New(options *Options) *API {
r.Get("/healthz", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("OK")) })

apps := func(r chi.Router) {
r.Use(
apiRateLimiter,
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
// Optional is true to allow for public apps. If an
// authorization check fails and the user is not authenticated,
// they will be redirected to the login page by the app handler.
RedirectToLogin: false,
DisableSessionExpiryRefresh: options.DeploymentConfig.DisableSessionExpiryRefresh.Value,
Optional: true,
}),
httpmw.AsAuthzSystem(
// Redirect to the login page if the user tries to open an app with
// "me" as the username and they are not logged in.
httpmw.ExtractUserParam(api.Database, true),
// Extracts the <workspace.agent> from the url
httpmw.ExtractWorkspaceAndAgentParam(api.Database),
),
)
// Workspace apps do their own auth.
r.Use(apiRateLimiter)
r.HandleFunc("/*", api.workspaceAppsProxyPath)
}
// %40 is the encoded character of the @ symbol. VS Code Web does
Expand Down Expand Up @@ -742,9 +727,10 @@ type API struct {
WebsocketWaitGroup sync.WaitGroup
derpCloseFunc func()

metricsCache *metricscache.Cache
workspaceAgentCache *wsconncache.Cache
updateChecker *updatecheck.Checker
metricsCache *metricscache.Cache
workspaceAgentCache *wsconncache.Cache
updateChecker *updatecheck.Checker
WorkspaceAppsProvider *workspaceapps.Provider

// Experiments contains the list of experiments currently enabled.
// This is used to gate features that are not yet ready for production.
Expand Down
6 changes: 6 additions & 0 deletions coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
Expand Down Expand Up @@ -82,6 +83,10 @@ import (
"github.com/coder/coder/testutil"
)

// AppSigningKey is a 64-byte key used to sign JWTs for workspace app tickets in
// tests.
var AppSigningKey = must(hex.DecodeString("64656164626565666465616462656566646561646265656664656164626565666465616462656566646561646265656664656164626565666465616462656566"))

type Options struct {
// AccessURL denotes a custom access URL. By default we use the httptest
// server's URL. Setting this may result in unexpected behavior (especially
Expand Down Expand Up @@ -330,6 +335,7 @@ func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.Can
DeploymentConfig: options.DeploymentConfig,
UpdateCheckOptions: options.UpdateCheckOptions,
SwaggerEndpoint: options.SwaggerEndpoint,
AppSigningKey: AppSigningKey,
}
}

Expand Down
18 changes: 18 additions & 0 deletions coderd/database/dbauthz/querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ func (q *querier) Ping(ctx context.Context) (time.Duration, error) {
return q.db.Ping(ctx)
}

func (q *querier) AcquireLock(ctx context.Context, id int64) error {
return q.db.AcquireLock(ctx, id)
}

func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
return q.db.TryAcquireLock(ctx, id)
}

// InTx runs the given function in a transaction.
func (q *querier) InTx(function func(querier database.Store) error, txOpts *sql.TxOptions) error {
return q.db.InTx(func(tx database.Store) error {
Expand Down Expand Up @@ -317,6 +325,16 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) {
return q.db.GetLogoURL(ctx)
}

func (q *querier) GetAppSigningKey(ctx context.Context) (string, error) {
// No authz checks
return q.db.GetAppSigningKey(ctx)
}

func (q *querier) InsertAppSigningKey(ctx context.Context, data string) error {
// No authz checks as this is done during startup
return q.db.InsertAppSigningKey(ctx, data)
}

func (q *querier) GetServiceBanner(ctx context.Context) (string, error) {
// No authz checks
return q.db.GetServiceBanner(ctx)
Expand Down
66 changes: 65 additions & 1 deletion coderd/database/dbfake/databasefake.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func New() database.Store {
workspaceApps: make([]database.WorkspaceApp, 0),
workspaces: make([]database.Workspace, 0),
licenses: make([]database.License, 0),
locks: map[int64]struct{}{},
},
}
}
Expand All @@ -89,6 +90,11 @@ type fakeQuerier struct {
*data
}

type fakeTx struct {
*fakeQuerier
locks map[int64]struct{}
}

type data struct {
// Legacy tables
apiKeys []database.APIKey
Expand Down Expand Up @@ -124,11 +130,15 @@ type data struct {
workspaceResources []database.WorkspaceResource
workspaces []database.Workspace

// Locks is a map of lock names. Any keys within the map are currently
// locked.
locks map[int64]struct{}
deploymentID string
derpMeshKey string
lastUpdateCheck []byte
serviceBanner []byte
logoURL string
appSigningKey string
lastLicenseID int32
}

Expand Down Expand Up @@ -196,11 +206,50 @@ func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) {
return 0, nil
}

func (*fakeQuerier) AcquireLock(_ context.Context, _ int64) error {
return xerrors.New("AcquireLock must only be called within a transaction")
}

func (*fakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) {
return false, xerrors.New("TryAcquireLock must only be called within a transaction")
}

func (tx *fakeTx) AcquireLock(_ context.Context, id int64) error {
if _, ok := tx.fakeQuerier.locks[id]; ok {
return xerrors.Errorf("cannot acquire lock %d: already held", id)
}
tx.fakeQuerier.locks[id] = struct{}{}
tx.locks[id] = struct{}{}
return nil
}

func (tx *fakeTx) TryAcquireLock(_ context.Context, id int64) (bool, error) {
if _, ok := tx.fakeQuerier.locks[id]; ok {
return false, nil
}
tx.fakeQuerier.locks[id] = struct{}{}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be protected or use sync.Map? Otherwise we may have concurrent map read/writes between goroutines.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't expect a transaction to be used concurrently from two places at once (and I don't think that even works anyways) but I can add this

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, the lock on the underlying data struct is replaced with a noop lock in a tx, so concurrent database read/writes may panic

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've actually decided against this for the above reason.

tx.locks[id] = struct{}{}
return true, nil
}

func (tx *fakeTx) releaseLocks() {
for id := range tx.locks {
delete(tx.fakeQuerier.locks, id)
}
tx.locks = map[int64]struct{}{}
}

// InTx doesn't rollback data properly for in-memory yet.
func (q *fakeQuerier) InTx(fn func(database.Store) error, _ *sql.TxOptions) error {
q.mutex.Lock()
defer q.mutex.Unlock()
return fn(&fakeQuerier{mutex: inTxMutex{}, data: q.data})
tx := &fakeTx{
fakeQuerier: &fakeQuerier{mutex: inTxMutex{}, data: q.data},
locks: map[int64]struct{}{},
}
defer tx.releaseLocks()

return fn(tx)
}

func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) {
Expand Down Expand Up @@ -4004,6 +4053,21 @@ func (q *fakeQuerier) GetLogoURL(_ context.Context) (string, error) {
return q.logoURL, nil
}

func (q *fakeQuerier) GetAppSigningKey(_ context.Context) (string, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()

return q.appSigningKey, nil
}

func (q *fakeQuerier) InsertAppSigningKey(_ context.Context, data string) error {
q.mutex.Lock()
defer q.mutex.Unlock()

q.appSigningKey = data
return nil
}

func (q *fakeQuerier) InsertLicense(
_ context.Context, arg database.InsertLicenseParams,
) (database.License, error) {
Expand Down
Loading