diff --git a/.github/workflows/coder.yaml b/.github/workflows/coder.yaml index cfc7cab5d92fb..c8e7d5356635c 100644 --- a/.github/workflows/coder.yaml +++ b/.github/workflows/coder.yaml @@ -151,7 +151,6 @@ jobs: - run: go install gotest.tools/gotestsum@latest - uses: hashicorp/setup-terraform@v1 - if: runner.os == 'Linux' with: terraform_version: 1.1.2 terraform_wrapper: false diff --git a/.vscode/settings.json b/.vscode/settings.json index db290aedc5202..3886c15fbfa72 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -23,5 +23,32 @@ } ] }, - "cSpell.words": ["coderd", "coderdtest", "codersdk", "httpmw", "oneof", "stretchr", "xerrors"] + "cSpell.words": [ + "coderd", + "coderdtest", + "codersdk", + "drpc", + "drpcconn", + "drpcmux", + "drpcserver", + "goleak", + "hashicorp", + "httpmw", + "moby", + "nhooyr", + "nolint", + "nosec", + "oneof", + "protobuf", + "provisionerd", + "provisionersdk", + "retrier", + "sdkproto", + "stretchr", + "tfexec", + "tfstate", + "unconvert", + "xerrors", + "yamux" + ] } diff --git a/coderd/cmd/root.go b/coderd/cmd/root.go index 705cb60e511da..e63f4a50a901c 100644 --- a/coderd/cmd/root.go +++ b/coderd/cmd/root.go @@ -11,6 +11,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" "github.com/coder/coder/coderd" + "github.com/coder/coder/database" "github.com/coder/coder/database/databasefake" ) @@ -24,6 +25,7 @@ func Root() *cobra.Command { handler := coderd.New(&coderd.Options{ Logger: slog.Make(sloghuman.Sink(os.Stderr)), Database: databasefake.New(), + Pubsub: database.NewPubsubInMemory(), }) listener, err := net.Listen("tcp", address) diff --git a/coderd/coderd.go b/coderd/coderd.go index 05569b8aeb5e1..b59dbb47b669f 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -64,6 +64,10 @@ func New(options *Options) http.Handler { r.Route("/history", func(r chi.Router) { r.Get("/", api.projectHistoryByOrganization) r.Post("/", api.postProjectHistoryByOrganization) + r.Route("/{projecthistory}", func(r chi.Router) { + r.Use(httpmw.ExtractProjectHistoryParam(api.Database)) + r.Get("/", api.projectHistoryByOrganizationAndName) + }) }) }) }) @@ -84,11 +88,19 @@ func New(options *Options) http.Handler { r.Route("/history", func(r chi.Router) { r.Post("/", api.postWorkspaceHistoryByUser) r.Get("/", api.workspaceHistoryByUser) - r.Get("/latest", api.latestWorkspaceHistoryByUser) + r.Route("/{workspacehistory}", func(r chi.Router) { + r.Use(httpmw.ExtractWorkspaceHistoryParam(options.Database)) + r.Get("/", api.workspaceHistoryByName) + }) }) }) }) }) + + r.Route("/provisioners/daemons", func(r chi.Router) { + r.Get("/", api.provisionerDaemons) + r.Get("/serve", api.provisionerDaemonsServe) + }) }) r.NotFound(site.Handler().ServeHTTP) return r diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 1ecf069bce864..6a6f97f3ef090 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -3,13 +3,16 @@ package coderdtest import ( "context" "database/sql" + "io" "net/http/httptest" "net/url" "os" "testing" + "time" "github.com/stretchr/testify/require" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/coderd" "github.com/coder/coder/codersdk" @@ -17,6 +20,10 @@ import ( "github.com/coder/coder/database" "github.com/coder/coder/database/databasefake" "github.com/coder/coder/database/postgres" + "github.com/coder/coder/provisioner/terraform" + "github.com/coder/coder/provisionerd" + "github.com/coder/coder/provisionersdk" + "github.com/coder/coder/provisionersdk/proto" ) // Server represents a test instance of coderd. @@ -57,11 +64,46 @@ func (s *Server) RandomInitialUser(t *testing.T) coderd.CreateInitialUserRequest return req } +// AddProvisionerd launches a new provisionerd instance! +func (s *Server) AddProvisionerd(t *testing.T) io.Closer { + tfClient, tfServer := provisionersdk.TransportPipe() + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(func() { + _ = tfClient.Close() + _ = tfServer.Close() + cancelFunc() + }) + go func() { + err := terraform.Serve(ctx, &terraform.ServeOptions{ + ServeOptions: &provisionersdk.ServeOptions{ + Listener: tfServer, + }, + Logger: slogtest.Make(t, nil).Named("terraform-provisioner").Leveled(slog.LevelDebug), + }) + require.NoError(t, err) + }() + + closer := provisionerd.New(s.Client.ProvisionerDaemonClient, &provisionerd.Options{ + Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug), + PollInterval: 50 * time.Millisecond, + UpdateInterval: 50 * time.Millisecond, + Provisioners: provisionerd.Provisioners{ + string(database.ProvisionerTypeTerraform): proto.NewDRPCProvisionerClient(provisionersdk.Conn(tfClient)), + }, + WorkDirectory: t.TempDir(), + }) + t.Cleanup(func() { + _ = closer.Close() + }) + return closer +} + // New constructs a new coderd test instance. This returned Server // should contain no side-effects. func New(t *testing.T) Server { // This can be hotswapped for a live database instance. db := databasefake.New() + pubsub := database.NewPubsubInMemory() if os.Getenv("DB") != "" { connectionURL, close, err := postgres.Open() require.NoError(t, err) @@ -74,11 +116,18 @@ func New(t *testing.T) Server { err = database.Migrate(sqlDB) require.NoError(t, err) db = database.New(sqlDB) + + pubsub, err = database.NewPubsub(context.Background(), sqlDB, connectionURL) + require.NoError(t, err) + t.Cleanup(func() { + _ = pubsub.Close() + }) } handler := coderd.New(&coderd.Options{ Logger: slogtest.Make(t, nil), Database: db, + Pubsub: pubsub, }) srv := httptest.NewServer(handler) serverURL, err := url.Parse(srv.URL) diff --git a/coderd/coderdtest/coderdtest_test.go b/coderd/coderdtest/coderdtest_test.go index e36d1c1408cd1..b7312f96864fc 100644 --- a/coderd/coderdtest/coderdtest_test.go +++ b/coderd/coderdtest/coderdtest_test.go @@ -16,4 +16,5 @@ func TestNew(t *testing.T) { t.Parallel() server := coderdtest.New(t) _ = server.RandomInitialUser(t) + _ = server.AddProvisionerd(t) } diff --git a/coderd/projecthistory.go b/coderd/projecthistory.go index a5057b8c514f0..55c4cae55ec32 100644 --- a/coderd/projecthistory.go +++ b/coderd/projecthistory.go @@ -4,6 +4,7 @@ import ( "archive/tar" "bytes" "database/sql" + "encoding/json" "errors" "fmt" "net/http" @@ -12,6 +13,7 @@ import ( "github.com/go-chi/render" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" + "golang.org/x/xerrors" "github.com/coder/coder/database" "github.com/coder/coder/httpapi" @@ -26,6 +28,7 @@ type ProjectHistory struct { UpdatedAt time.Time `json:"updated_at"` Name string `json:"name"` StorageMethod database.ProjectStorageMethod `json:"storage_method"` + Import ProvisionerJob `json:"import"` } // CreateProjectHistoryRequest enables callers to create a new Project Version. @@ -50,12 +53,33 @@ func (api *api) projectHistoryByOrganization(rw http.ResponseWriter, r *http.Req } apiHistory := make([]ProjectHistory, 0) for _, version := range history { - apiHistory = append(apiHistory, convertProjectHistory(version)) + job, err := api.Database.GetProvisionerJobByID(r.Context(), version.ImportJobID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner job: %s", err), + }) + return + } + apiHistory = append(apiHistory, convertProjectHistory(version, job)) } render.Status(r, http.StatusOK) render.JSON(rw, r, apiHistory) } +// Return a single project history by organization and name. +func (api *api) projectHistoryByOrganizationAndName(rw http.ResponseWriter, r *http.Request) { + projectHistory := httpmw.ProjectHistoryParam(r) + job, err := api.Database.GetProvisionerJobByID(r.Context(), projectHistory.ImportJobID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner job: %s", err), + }) + return + } + render.Status(r, http.StatusOK) + render.JSON(rw, r, convertProjectHistory(projectHistory, job)) +} + // Creates a new version of the project. An import job is queued to parse // the storage method provided. Once completed, the import job will specify // the version as latest. @@ -82,37 +106,71 @@ func (api *api) postProjectHistoryByOrganization(rw http.ResponseWriter, r *http return } + apiKey := httpmw.APIKey(r) project := httpmw.ProjectParam(r) - history, err := api.Database.InsertProjectHistory(r.Context(), database.InsertProjectHistoryParams{ - ID: uuid.New(), - ProjectID: project.ID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - Name: namesgenerator.GetRandomName(1), - StorageMethod: createProjectVersion.StorageMethod, - StorageSource: createProjectVersion.StorageSource, - // TODO: Make this do something! - ImportJobID: uuid.New(), + + var provisionerJob database.ProvisionerJob + var projectHistory database.ProjectHistory + err := api.Database.InTx(func(db database.Store) error { + projectHistoryID := uuid.New() + input, err := json.Marshal(projectImportJob{ + ProjectHistoryID: projectHistoryID, + }) + if err != nil { + return xerrors.Errorf("marshal import job: %w", err) + } + + provisionerJob, err = db.InsertProvisionerJob(r.Context(), database.InsertProvisionerJobParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + InitiatorID: apiKey.UserID, + Provisioner: project.Provisioner, + Type: database.ProvisionerJobTypeProjectImport, + ProjectID: project.ID, + Input: input, + }) + if err != nil { + return xerrors.Errorf("insert provisioner job: %w", err) + } + + projectHistory, err = api.Database.InsertProjectHistory(r.Context(), database.InsertProjectHistoryParams{ + ID: projectHistoryID, + ProjectID: project.ID, + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + Name: namesgenerator.GetRandomName(1), + StorageMethod: createProjectVersion.StorageMethod, + StorageSource: createProjectVersion.StorageSource, + ImportJobID: provisionerJob.ID, + }) + if err != nil { + return xerrors.Errorf("insert project history: %s", err) + } + return nil }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("insert project history: %s", err), + Message: err.Error(), }) return } - // TODO: A job to process the new version should occur here. - render.Status(r, http.StatusCreated) - render.JSON(rw, r, convertProjectHistory(history)) + render.JSON(rw, r, convertProjectHistory(projectHistory, provisionerJob)) } -func convertProjectHistory(history database.ProjectHistory) ProjectHistory { +func convertProjectHistory(history database.ProjectHistory, job database.ProvisionerJob) ProjectHistory { return ProjectHistory{ ID: history.ID, ProjectID: history.ProjectID, CreatedAt: history.CreatedAt, UpdatedAt: history.UpdatedAt, Name: history.Name, + Import: convertProvisionerJob(job), } } + +func projectHistoryLogsChannel(projectHistoryID uuid.UUID) string { + return fmt.Sprintf("project-history-logs:%s", projectHistoryID) +} diff --git a/coderd/projecthistory_test.go b/coderd/projecthistory_test.go index 4c9b727fbe358..f3a1922b0ea4c 100644 --- a/coderd/projecthistory_test.go +++ b/coderd/projecthistory_test.go @@ -25,7 +25,7 @@ func TestProjectHistory(t *testing.T) { Provisioner: database.ProvisionerTypeTerraform, }) require.NoError(t, err) - versions, err := server.Client.ProjectHistory(context.Background(), user.Organization, project.Name) + versions, err := server.Client.ListProjectHistory(context.Background(), user.Organization, project.Name) require.NoError(t, err) require.Len(t, versions, 0) }) @@ -48,14 +48,17 @@ func TestProjectHistory(t *testing.T) { require.NoError(t, err) _, err = writer.Write(make([]byte, 1<<10)) require.NoError(t, err) - _, err = server.Client.CreateProjectHistory(context.Background(), user.Organization, project.Name, coderd.CreateProjectHistoryRequest{ + history, err := server.Client.CreateProjectHistory(context.Background(), user.Organization, project.Name, coderd.CreateProjectHistoryRequest{ StorageMethod: database.ProjectStorageMethodInlineArchive, StorageSource: buffer.Bytes(), }) require.NoError(t, err) - versions, err := server.Client.ProjectHistory(context.Background(), user.Organization, project.Name) + versions, err := server.Client.ListProjectHistory(context.Background(), user.Organization, project.Name) require.NoError(t, err) require.Len(t, versions, 1) + + _, err = server.Client.ProjectHistory(context.Background(), user.Organization, project.Name, history.Name) + require.NoError(t, err) }) t.Run("CreateHistoryArchiveTooBig", func(t *testing.T) { diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go new file mode 100644 index 0000000000000..1a315402f08fc --- /dev/null +++ b/coderd/provisionerdaemons.go @@ -0,0 +1,619 @@ +package coderd + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/http" + "reflect" + "time" + + "github.com/go-chi/render" + "github.com/google/uuid" + "github.com/hashicorp/yamux" + "github.com/moby/moby/pkg/namesgenerator" + "golang.org/x/xerrors" + "nhooyr.io/websocket" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + + "github.com/coder/coder/coderd/projectparameter" + "github.com/coder/coder/database" + "github.com/coder/coder/httpapi" + "github.com/coder/coder/provisionerd/proto" + sdkproto "github.com/coder/coder/provisionersdk/proto" +) + +type ProvisionerDaemon database.ProvisionerDaemon + +// Lists all registered provisioner daemons. +func (api *api) provisionerDaemons(rw http.ResponseWriter, r *http.Request) { + daemons, err := api.Database.GetProvisionerDaemons(r.Context()) + if errors.Is(err, sql.ErrNoRows) { + err = nil + daemons = []database.ProvisionerDaemon{} + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner daemons: %s", err), + }) + return + } + + render.Status(r, http.StatusOK) + render.JSON(rw, r, daemons) +} + +// Serves the provisioner daemon protobuf API over a WebSocket. +func (api *api) provisionerDaemonsServe(rw http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ + // Need to disable compression to avoid a data-race + CompressionMode: websocket.CompressionDisabled, + }) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("accept websocket: %s", err), + }) + return + } + + daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + Name: namesgenerator.GetRandomName(1), + Provisioners: []database.ProvisionerType{database.ProvisionerTypeCdrBasic, database.ProvisionerTypeTerraform}, + }) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("insert provisioner daemon:% s", err)) + return + } + + // Multiplexes the incoming connection using yamux. + // This allows multiple function calls to occur over + // the same connection. + session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), nil) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("multiplex server: %s", err)) + return + } + mux := drpcmux.New() + err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdServer{ + ID: daemon.ID, + Database: api.Database, + Pubsub: api.Pubsub, + Provisioners: daemon.Provisioners, + }) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("drpc register provisioner daemon: %s", err)) + return + } + server := drpcserver.New(mux) + err = server.Serve(r.Context(), session) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("serve: %s", err)) + } +} + +// The input for a "workspace_provision" job. +type workspaceProvisionJob struct { + WorkspaceHistoryID uuid.UUID `json:"workspace_history_id"` +} + +// The input for a "project_import" job. +type projectImportJob struct { + ProjectHistoryID uuid.UUID `json:"project_history_id"` +} + +// Implementation of the provisioner daemon protobuf server. +type provisionerdServer struct { + ID uuid.UUID + Provisioners []database.ProvisionerType + Database database.Store + Pubsub database.Pubsub +} + +// AcquireJob queries the database to lock a job. +func (server *provisionerdServer) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + // This marks the job as locked in the database. + job, err := server.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ + StartedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + WorkerID: uuid.NullUUID{ + UUID: server.ID, + Valid: true, + }, + Types: server.Provisioners, + }) + if errors.Is(err, sql.ErrNoRows) { + // The provisioner daemon assumes no jobs are available if + // an empty struct is returned. + return &proto.AcquiredJob{}, nil + } + if err != nil { + return nil, xerrors.Errorf("acquire job: %w", err) + } + // Marks the acquired job as failed with the error message provided. + failJob := func(errorMessage string) error { + err = server.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{ + ID: job.ID, + CompletedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + Error: sql.NullString{ + String: errorMessage, + Valid: true, + }, + }) + if err != nil { + return xerrors.Errorf("update provisioner job: %w", err) + } + return xerrors.Errorf("request job was invalidated: %s", errorMessage) + } + + project, err := server.Database.GetProjectByID(ctx, job.ProjectID) + if err != nil { + return nil, failJob(fmt.Sprintf("get project: %s", err)) + } + organization, err := server.Database.GetOrganizationByID(ctx, project.OrganizationID) + if err != nil { + return nil, failJob(fmt.Sprintf("get organization: %s", err)) + } + user, err := server.Database.GetUserByID(ctx, job.InitiatorID) + if err != nil { + return nil, failJob(fmt.Sprintf("get user: %s", err)) + } + + protoJob := &proto.AcquiredJob{ + JobId: job.ID.String(), + CreatedAt: job.CreatedAt.UnixMilli(), + Provisioner: string(job.Provisioner), + OrganizationName: organization.Name, + ProjectName: project.Name, + UserName: user.Username, + } + var projectHistory database.ProjectHistory + switch job.Type { + case database.ProvisionerJobTypeWorkspaceProvision: + var input workspaceProvisionJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) + } + workspaceHistory, err := server.Database.GetWorkspaceHistoryByID(ctx, input.WorkspaceHistoryID) + if err != nil { + return nil, failJob(fmt.Sprintf("get workspace history: %s", err)) + } + workspace, err := server.Database.GetWorkspaceByID(ctx, workspaceHistory.WorkspaceID) + if err != nil { + return nil, failJob(fmt.Sprintf("get workspace: %s", err)) + } + projectHistory, err = server.Database.GetProjectHistoryByID(ctx, workspaceHistory.ProjectHistoryID) + if err != nil { + return nil, failJob(fmt.Sprintf("get project history: %s", err)) + } + + // Compute parameters for the workspace to consume. + parameters, err := projectparameter.Compute(ctx, server.Database, projectparameter.Scope{ + OrganizationID: organization.ID, + ProjectID: project.ID, + ProjectHistoryID: projectHistory.ID, + UserID: user.ID, + WorkspaceID: workspace.ID, + WorkspaceHistoryID: workspaceHistory.ID, + }) + if err != nil { + return nil, failJob(fmt.Sprintf("compute parameters: %s", err)) + } + // Convert parameters to the protobuf type. + protoParameters := make([]*sdkproto.ParameterValue, 0, len(parameters)) + for _, parameter := range parameters { + protoParameters = append(protoParameters, parameter.Proto) + } + + provisionerState := []byte{} + // If workspace history exists before this entry, use that state. + // We can't use the before state everytime, because if a job fails + // for some random reason, the workspace shouldn't be reset. + // + // Maybe we should make state global on a workspace? + if workspaceHistory.BeforeID.Valid { + beforeHistory, err := server.Database.GetWorkspaceHistoryByID(ctx, workspaceHistory.BeforeID.UUID) + if err != nil { + return nil, failJob(fmt.Sprintf("get workspace history: %s", err)) + } + provisionerState = beforeHistory.ProvisionerState + } + + protoJob.Type = &proto.AcquiredJob_WorkspaceProvision_{ + WorkspaceProvision: &proto.AcquiredJob_WorkspaceProvision{ + WorkspaceHistoryId: workspaceHistory.ID.String(), + WorkspaceName: workspace.Name, + State: provisionerState, + ParameterValues: protoParameters, + }, + } + case database.ProvisionerJobTypeProjectImport: + var input projectImportJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) + } + projectHistory, err = server.Database.GetProjectHistoryByID(ctx, input.ProjectHistoryID) + if err != nil { + return nil, failJob(fmt.Sprintf("get project history: %s", err)) + } + + protoJob.Type = &proto.AcquiredJob_ProjectImport_{ + ProjectImport: &proto.AcquiredJob_ProjectImport{ + ProjectHistoryId: projectHistory.ID.String(), + ProjectHistoryName: projectHistory.Name, + }, + } + } + switch projectHistory.StorageMethod { + case database.ProjectStorageMethodInlineArchive: + protoJob.ProjectSourceArchive = projectHistory.StorageSource + default: + return nil, failJob(fmt.Sprintf("unsupported storage source: %q", projectHistory.StorageMethod)) + } + + return protoJob, err +} + +func (server *provisionerdServer) UpdateJob(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + for { + update, err := stream.Recv() + if err != nil { + return err + } + parsedID, err := uuid.Parse(update.JobId) + if err != nil { + return xerrors.Errorf("parse job id: %w", err) + } + job, err := server.Database.GetProvisionerJobByID(stream.Context(), parsedID) + if err != nil { + return xerrors.Errorf("get job: %w", err) + } + if !job.WorkerID.Valid { + return errors.New("job isn't running yet") + } + if job.WorkerID.UUID.String() != server.ID.String() { + return errors.New("you don't own this job") + } + + err = server.Database.UpdateProvisionerJobByID(stream.Context(), database.UpdateProvisionerJobByIDParams{ + ID: parsedID, + UpdatedAt: database.Now(), + }) + if err != nil { + return xerrors.Errorf("update job: %w", err) + } + switch job.Type { + case database.ProvisionerJobTypeProjectImport: + if len(update.ProjectImportLogs) == 0 { + continue + } + var input projectImportJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return xerrors.Errorf("unmarshal job input %q: %s", job.Input, err) + } + insertParams := database.InsertProjectHistoryLogsParams{ + ProjectHistoryID: input.ProjectHistoryID, + } + for _, log := range update.ProjectImportLogs { + logLevel, err := convertLogLevel(log.Level) + if err != nil { + return xerrors.Errorf("convert log level: %w", err) + } + logSource, err := convertLogSource(log.Source) + if err != nil { + return xerrors.Errorf("convert log source: %w", err) + } + insertParams.ID = append(insertParams.ID, uuid.New()) + insertParams.CreatedAt = append(insertParams.CreatedAt, time.UnixMilli(log.CreatedAt)) + insertParams.Level = append(insertParams.Level, logLevel) + insertParams.Source = append(insertParams.Source, logSource) + insertParams.Output = append(insertParams.Output, log.Output) + } + logs, err := server.Database.InsertProjectHistoryLogs(stream.Context(), insertParams) + if err != nil { + return xerrors.Errorf("insert project logs: %w", err) + } + data, err := json.Marshal(logs) + if err != nil { + return xerrors.Errorf("marshal project log: %w", err) + } + err = server.Pubsub.Publish(projectHistoryLogsChannel(input.ProjectHistoryID), data) + if err != nil { + return xerrors.Errorf("publish history log: %w", err) + } + case database.ProvisionerJobTypeWorkspaceProvision: + if len(update.WorkspaceProvisionLogs) == 0 { + continue + } + var input workspaceProvisionJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return xerrors.Errorf("unmarshal job input %q: %s", job.Input, err) + } + insertParams := database.InsertWorkspaceHistoryLogsParams{ + WorkspaceHistoryID: input.WorkspaceHistoryID, + } + for _, log := range update.WorkspaceProvisionLogs { + logLevel, err := convertLogLevel(log.Level) + if err != nil { + return xerrors.Errorf("convert log level: %w", err) + } + logSource, err := convertLogSource(log.Source) + if err != nil { + return xerrors.Errorf("convert log source: %w", err) + } + insertParams.ID = append(insertParams.ID, uuid.New()) + insertParams.CreatedAt = append(insertParams.CreatedAt, time.UnixMilli(log.CreatedAt)) + insertParams.Level = append(insertParams.Level, logLevel) + insertParams.Source = append(insertParams.Source, logSource) + insertParams.Output = append(insertParams.Output, log.Output) + } + logs, err := server.Database.InsertWorkspaceHistoryLogs(stream.Context(), insertParams) + if err != nil { + return xerrors.Errorf("insert workspace logs: %w", err) + } + data, err := json.Marshal(logs) + if err != nil { + return xerrors.Errorf("marshal project log: %w", err) + } + err = server.Pubsub.Publish(workspaceHistoryLogsChannel(input.WorkspaceHistoryID), data) + if err != nil { + return xerrors.Errorf("publish history log: %w", err) + } + } + } +} + +func (server *provisionerdServer) CancelJob(ctx context.Context, cancelJob *proto.CancelledJob) (*proto.Empty, error) { + jobID, err := uuid.Parse(cancelJob.JobId) + if err != nil { + return nil, xerrors.Errorf("parse job id: %w", err) + } + err = server.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{ + ID: jobID, + CancelledAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + UpdatedAt: database.Now(), + Error: sql.NullString{ + String: cancelJob.Error, + Valid: cancelJob.Error != "", + }, + }) + if err != nil { + return nil, xerrors.Errorf("update provisioner job: %w", err) + } + return &proto.Empty{}, nil +} + +// CompleteJob is triggered by a provision daemon to mark a provisioner job as completed. +func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) { + jobID, err := uuid.Parse(completed.JobId) + if err != nil { + return nil, xerrors.Errorf("parse job id: %w", err) + } + job, err := server.Database.GetProvisionerJobByID(ctx, jobID) + if err != nil { + return nil, xerrors.Errorf("get job by id: %w", err) + } + // TODO: Check if the worker ID matches! + // If it doesn't, a provisioner daemon could be impersonating another job! + + switch jobType := completed.Type.(type) { + case *proto.CompletedJob_ProjectImport_: + var input projectImportJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return nil, xerrors.Errorf("unmarshal job data: %w", err) + } + + // Validate that all parameters send from the provisioner daemon + // follow the protocol. + projectParameters := make([]database.InsertProjectParameterParams, 0, len(jobType.ProjectImport.ParameterSchemas)) + for _, protoParameter := range jobType.ProjectImport.ParameterSchemas { + validationTypeSystem, err := convertValidationTypeSystem(protoParameter.ValidationTypeSystem) + if err != nil { + return nil, xerrors.Errorf("convert validation type system for %q: %w", protoParameter.Name, err) + } + + projectParameter := database.InsertProjectParameterParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + ProjectHistoryID: input.ProjectHistoryID, + Name: protoParameter.Name, + Description: protoParameter.Description, + RedisplayValue: protoParameter.RedisplayValue, + ValidationError: protoParameter.ValidationError, + ValidationCondition: protoParameter.ValidationCondition, + ValidationValueType: protoParameter.ValidationValueType, + ValidationTypeSystem: validationTypeSystem, + + AllowOverrideDestination: protoParameter.AllowOverrideDestination, + AllowOverrideSource: protoParameter.AllowOverrideSource, + } + + // It's possible a parameter doesn't define a default source! + if protoParameter.DefaultSource != nil { + parameterSourceScheme, err := convertParameterSourceScheme(protoParameter.DefaultSource.Scheme) + if err != nil { + return nil, xerrors.Errorf("convert parameter source scheme: %w", err) + } + projectParameter.DefaultSourceScheme = parameterSourceScheme + projectParameter.DefaultSourceValue = sql.NullString{ + String: protoParameter.DefaultSource.Value, + Valid: protoParameter.DefaultSource.Value != "", + } + } + + // It's possible a parameter doesn't define a default destination! + if protoParameter.DefaultDestination != nil { + parameterDestinationScheme, err := convertParameterDestinationScheme(protoParameter.DefaultDestination.Scheme) + if err != nil { + return nil, xerrors.Errorf("convert parameter destination scheme: %w", err) + } + projectParameter.DefaultDestinationScheme = parameterDestinationScheme + projectParameter.DefaultDestinationValue = sql.NullString{ + String: protoParameter.DefaultDestination.Value, + Valid: protoParameter.DefaultDestination.Value != "", + } + } + + projectParameters = append(projectParameters, projectParameter) + } + + // This must occur in a transaction in case of failure. + err = server.Database.InTx(func(db database.Store) error { + err = db.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{ + ID: jobID, + UpdatedAt: database.Now(), + CompletedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + }) + if err != nil { + return xerrors.Errorf("update provisioner job: %w", err) + } + // This could be a bulk-insert operation to improve performance. + // See the "InsertWorkspaceHistoryLogs" query. + for _, projectParameter := range projectParameters { + _, err = db.InsertProjectParameter(ctx, projectParameter) + if err != nil { + return xerrors.Errorf("insert project parameter %q: %w", projectParameter.Name, err) + } + } + return nil + }) + if err != nil { + return nil, xerrors.Errorf("complete job: %w", err) + } + case *proto.CompletedJob_WorkspaceProvision_: + var input workspaceProvisionJob + err = json.Unmarshal(job.Input, &input) + if err != nil { + return nil, xerrors.Errorf("unmarshal job data: %w", err) + } + + workspaceHistory, err := server.Database.GetWorkspaceHistoryByID(ctx, input.WorkspaceHistoryID) + if err != nil { + return nil, xerrors.Errorf("get workspace history: %w", err) + } + + err = server.Database.InTx(func(db database.Store) error { + err = db.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{ + ID: jobID, + UpdatedAt: database.Now(), + CompletedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + }) + if err != nil { + return xerrors.Errorf("update provisioner job: %w", err) + } + err = db.UpdateWorkspaceHistoryByID(ctx, database.UpdateWorkspaceHistoryByIDParams{ + ID: workspaceHistory.ID, + UpdatedAt: database.Now(), + ProvisionerState: jobType.WorkspaceProvision.State, + }) + if err != nil { + return xerrors.Errorf("update workspace history: %w", err) + } + // This could be a bulk insert to improve performance. + for _, protoResource := range jobType.WorkspaceProvision.Resources { + _, err = db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + WorkspaceHistoryID: input.WorkspaceHistoryID, + Type: protoResource.Type, + Name: protoResource.Name, + // TODO: Generate this at the variable validation phase. + // Set the value in `default_source`, and disallow overwrite. + WorkspaceAgentToken: uuid.NewString(), + }) + if err != nil { + return xerrors.Errorf("insert workspace resource %q: %w", protoResource.Name, err) + } + } + return nil + }) + if err != nil { + return nil, xerrors.Errorf("complete job: %w", err) + } + default: + return nil, xerrors.Errorf("unknown job type %q; ensure coderd and provisionerd versions match", + reflect.TypeOf(completed.Type).String()) + } + + return &proto.Empty{}, nil +} + +func convertValidationTypeSystem(typeSystem sdkproto.ParameterSchema_TypeSystem) (database.ParameterTypeSystem, error) { + switch typeSystem { + case sdkproto.ParameterSchema_HCL: + return database.ParameterTypeSystemHCL, nil + default: + return database.ParameterTypeSystem(""), xerrors.Errorf("unknown type system: %d", typeSystem) + } +} + +func convertParameterSourceScheme(sourceScheme sdkproto.ParameterSource_Scheme) (database.ParameterSourceScheme, error) { + switch sourceScheme { + case sdkproto.ParameterSource_DATA: + return database.ParameterSourceSchemeData, nil + default: + return database.ParameterSourceScheme(""), xerrors.Errorf("unknown parameter source scheme: %d", sourceScheme) + } +} + +func convertParameterDestinationScheme(destinationScheme sdkproto.ParameterDestination_Scheme) (database.ParameterDestinationScheme, error) { + switch destinationScheme { + case sdkproto.ParameterDestination_ENVIRONMENT_VARIABLE: + return database.ParameterDestinationSchemeEnvironmentVariable, nil + case sdkproto.ParameterDestination_PROVISIONER_VARIABLE: + return database.ParameterDestinationSchemeProvisionerVariable, nil + default: + return database.ParameterDestinationScheme(""), xerrors.Errorf("unknown parameter destination scheme: %d", destinationScheme) + } +} + +func convertLogLevel(logLevel sdkproto.LogLevel) (database.LogLevel, error) { + switch logLevel { + case sdkproto.LogLevel_TRACE: + return database.LogLevelTrace, nil + case sdkproto.LogLevel_DEBUG: + return database.LogLevelDebug, nil + case sdkproto.LogLevel_INFO: + return database.LogLevelInfo, nil + case sdkproto.LogLevel_WARN: + return database.LogLevelWarn, nil + case sdkproto.LogLevel_ERROR: + return database.LogLevelError, nil + default: + return database.LogLevel(""), xerrors.Errorf("unknown log level: %d", logLevel) + } +} + +func convertLogSource(logSource proto.LogSource) (database.LogSource, error) { + switch logSource { + case proto.LogSource_PROVISIONER_DAEMON: + return database.LogSourceProvisionerDaemon, nil + case proto.LogSource_PROVISIONER: + return database.LogSourceProvisioner, nil + default: + return database.LogSource(""), xerrors.Errorf("unknown log source: %d", logSource) + } +} diff --git a/coderd/provisionerdaemons_test.go b/coderd/provisionerdaemons_test.go new file mode 100644 index 0000000000000..5cba701d5a34e --- /dev/null +++ b/coderd/provisionerdaemons_test.go @@ -0,0 +1,26 @@ +package coderd_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/coderdtest" +) + +func TestProvisionerDaemons(t *testing.T) { + t.Parallel() + + t.Run("Register", func(t *testing.T) { + t.Parallel() + server := coderdtest.New(t) + _ = server.AddProvisionerd(t) + require.Eventually(t, func() bool { + daemons, err := server.Client.ProvisionerDaemons(context.Background()) + require.NoError(t, err) + return len(daemons) > 0 + }, time.Second, 10*time.Millisecond) + }) +} diff --git a/coderd/provisioners.go b/coderd/provisioners.go new file mode 100644 index 0000000000000..f2afefa00cbef --- /dev/null +++ b/coderd/provisioners.go @@ -0,0 +1,78 @@ +package coderd + +import ( + "time" + + "github.com/google/uuid" + + "github.com/coder/coder/database" +) + +type ProvisionerJobStatus string + +// Completed returns whether the job is still processing. +func (p ProvisionerJobStatus) Completed() bool { + return p == ProvisionerJobStatusSucceeded || p == ProvisionerJobStatusFailed +} + +const ( + ProvisionerJobStatusPending ProvisionerJobStatus = "pending" + ProvisionerJobStatusRunning ProvisionerJobStatus = "running" + ProvisionerJobStatusSucceeded ProvisionerJobStatus = "succeeded" + ProvisionerJobStatusCancelled ProvisionerJobStatus = "canceled" + ProvisionerJobStatusFailed ProvisionerJobStatus = "failed" +) + +type ProvisionerJob struct { + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + StartedAt *time.Time `json:"started_at,omitempty"` + CancelledAt *time.Time `json:"canceled_at,omitempty"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + Status ProvisionerJobStatus `json:"status"` + Error string `json:"error,omitempty"` + Provisioner database.ProvisionerType `json:"provisioner"` + WorkerID *uuid.UUID `json:"worker_id,omitempty"` +} + +func convertProvisionerJob(provisionerJob database.ProvisionerJob) ProvisionerJob { + job := ProvisionerJob{ + CreatedAt: provisionerJob.CreatedAt, + UpdatedAt: provisionerJob.UpdatedAt, + Error: provisionerJob.Error.String, + Provisioner: provisionerJob.Provisioner, + } + // Applying values optional to the struct. + if provisionerJob.StartedAt.Valid { + job.StartedAt = &provisionerJob.StartedAt.Time + } + if provisionerJob.CancelledAt.Valid { + job.CancelledAt = &provisionerJob.CancelledAt.Time + } + if provisionerJob.CompletedAt.Valid { + job.CompletedAt = &provisionerJob.CompletedAt.Time + } + if provisionerJob.WorkerID.Valid { + job.WorkerID = &provisionerJob.WorkerID.UUID + } + + switch { + case provisionerJob.CancelledAt.Valid: + job.Status = ProvisionerJobStatusCancelled + case !provisionerJob.StartedAt.Valid: + job.Status = ProvisionerJobStatusPending + case provisionerJob.CompletedAt.Valid: + job.Status = ProvisionerJobStatusSucceeded + case database.Now().Sub(provisionerJob.UpdatedAt) > 30*time.Second: + job.Status = ProvisionerJobStatusFailed + job.Error = "Worker failed to update job in time." + default: + job.Status = ProvisionerJobStatusRunning + } + + if job.Error != "" { + job.Status = ProvisionerJobStatusFailed + } + + return job +} diff --git a/coderd/workspacehistory.go b/coderd/workspacehistory.go index 32eba2e98e2da..f9e4c7690b4d3 100644 --- a/coderd/workspacehistory.go +++ b/coderd/workspacehistory.go @@ -2,6 +2,7 @@ package coderd import ( "database/sql" + "encoding/json" "errors" "fmt" "net/http" @@ -22,13 +23,14 @@ type WorkspaceHistory struct { ID uuid.UUID `json:"id"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` - CompletedAt time.Time `json:"completed_at"` WorkspaceID uuid.UUID `json:"workspace_id"` ProjectHistoryID uuid.UUID `json:"project_history_id"` BeforeID uuid.UUID `json:"before_id"` AfterID uuid.UUID `json:"after_id"` + Name string `json:"name"` Transition database.WorkspaceTransition `json:"transition"` Initiator string `json:"initiator"` + Provision ProvisionerJob `json:"provision"` } // CreateWorkspaceHistoryRequest provides options to update the latest workspace history. @@ -37,8 +39,6 @@ type CreateWorkspaceHistoryRequest struct { Transition database.WorkspaceTransition `json:"transition" validate:"oneof=create start stop delete,required"` } -// Begins transitioning a workspace to new state. This queues a provision job to asynchronously -// update the underlying infrastructure. Only one historical transition can occur at a time. func (api *api) postWorkspaceHistoryByUser(rw http.ResponseWriter, r *http.Request) { var createBuild CreateWorkspaceHistoryRequest if !httpapi.Read(rw, r, &createBuild) { @@ -63,12 +63,41 @@ func (api *api) postWorkspaceHistoryByUser(rw http.ResponseWriter, r *http.Reque }) return } + projectHistoryJob, err := api.Database.GetProvisionerJobByID(r.Context(), projectHistory.ImportJobID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner job: %s", err), + }) + return + } + projectHistoryJobStatus := convertProvisionerJob(projectHistoryJob).Status + switch projectHistoryJobStatus { + case ProvisionerJobStatusPending, ProvisionerJobStatusRunning: + httpapi.Write(rw, http.StatusPreconditionFailed, httpapi.Response{ + Message: fmt.Sprintf("The provided project history is %s. Wait for it to complete importing!", projectHistoryJobStatus), + }) + return + case ProvisionerJobStatusFailed: + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("The provided project history %q has failed to import. You cannot create workspaces using it!", projectHistory.Name), + }) + return + } + + project, err := api.Database.GetProjectByID(r.Context(), projectHistory.ProjectID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get project: %s", err), + }) + return + } // Store prior history ID if it exists to update it after we create new! priorHistoryID := uuid.NullUUID{} priorHistory, err := api.Database.GetWorkspaceHistoryByWorkspaceIDWithoutAfter(r.Context(), workspace.ID) if err == nil { - if !priorHistory.CompletedAt.Valid { + priorJob, err := api.Database.GetProvisionerJobByID(r.Context(), priorHistory.ProvisionJobID) + if err == nil && convertProvisionerJob(priorJob).Status.Completed() { httpapi.Write(rw, http.StatusConflict, httpapi.Response{ Message: "a workspace build is already active", }) @@ -87,12 +116,36 @@ func (api *api) postWorkspaceHistoryByUser(rw http.ResponseWriter, r *http.Reque return } + var provisionerJob database.ProvisionerJob var workspaceHistory database.WorkspaceHistory // This must happen in a transaction to ensure history can be inserted, and // the prior history can update it's "after" column to point at the new. err = api.Database.InTx(func(db database.Store) error { + // Generate the ID before-hand so the provisioner job is aware of it! + workspaceHistoryID := uuid.New() + input, err := json.Marshal(workspaceProvisionJob{ + WorkspaceHistoryID: workspaceHistoryID, + }) + if err != nil { + return xerrors.Errorf("marshal provision job: %w", err) + } + + provisionerJob, err = db.InsertProvisionerJob(r.Context(), database.InsertProvisionerJobParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + InitiatorID: user.ID, + Provisioner: project.Provisioner, + Type: database.ProvisionerJobTypeWorkspaceProvision, + ProjectID: project.ID, + Input: input, + }) + if err != nil { + return xerrors.Errorf("insert provisioner job: %w", err) + } + workspaceHistory, err = db.InsertWorkspaceHistory(r.Context(), database.InsertWorkspaceHistoryParams{ - ID: uuid.New(), + ID: workspaceHistoryID, CreatedAt: database.Now(), UpdatedAt: database.Now(), WorkspaceID: workspace.ID, @@ -100,8 +153,7 @@ func (api *api) postWorkspaceHistoryByUser(rw http.ResponseWriter, r *http.Reque BeforeID: priorHistoryID, Initiator: user.ID, Transition: createBuild.Transition, - // This should create a provision job once that gets implemented! - ProvisionJobID: uuid.New(), + ProvisionJobID: provisionerJob.ID, }) if err != nil { return xerrors.Errorf("insert workspace history: %w", err) @@ -132,7 +184,7 @@ func (api *api) postWorkspaceHistoryByUser(rw http.ResponseWriter, r *http.Reque } render.Status(r, http.StatusCreated) - render.JSON(rw, r, convertWorkspaceHistory(workspaceHistory)) + render.JSON(rw, r, convertWorkspaceHistory(workspaceHistory, provisionerJob)) } // Returns all workspace history. This is not sorted. Use before/after to chronologically sort. @@ -152,31 +204,52 @@ func (api *api) workspaceHistoryByUser(rw http.ResponseWriter, r *http.Request) apiHistory := make([]WorkspaceHistory, 0, len(histories)) for _, history := range histories { - apiHistory = append(apiHistory, convertWorkspaceHistory(history)) + job, err := api.Database.GetProvisionerJobByID(r.Context(), history.ProvisionJobID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner job: %s", err), + }) + return + } + apiHistory = append(apiHistory, convertWorkspaceHistory(history, job)) } render.Status(r, http.StatusOK) render.JSON(rw, r, apiHistory) } -// Returns the latest workspace history. This works by querying for history without "after" set. -func (api *api) latestWorkspaceHistoryByUser(rw http.ResponseWriter, r *http.Request) { - workspace := httpmw.WorkspaceParam(r) - - history, err := api.Database.GetWorkspaceHistoryByWorkspaceIDWithoutAfter(r.Context(), workspace.ID) - if errors.Is(err, sql.ErrNoRows) { - httpapi.Write(rw, http.StatusNotFound, httpapi.Response{ - Message: "workspace has no history", - }) - return - } +func (api *api) workspaceHistoryByName(rw http.ResponseWriter, r *http.Request) { + workspaceHistory := httpmw.WorkspaceHistoryParam(r) + job, err := api.Database.GetProvisionerJobByID(r.Context(), workspaceHistory.ProvisionJobID) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get workspace history: %s", err), + Message: fmt.Sprintf("get provisioner job: %s", err), }) return } render.Status(r, http.StatusOK) - render.JSON(rw, r, convertWorkspaceHistory(history)) + render.JSON(rw, r, convertWorkspaceHistory(workspaceHistory, job)) +} + +// Converts the internal history representation to a public external-facing model. +func convertWorkspaceHistory(workspaceHistory database.WorkspaceHistory, provisionerJob database.ProvisionerJob) WorkspaceHistory { + //nolint:unconvert + return WorkspaceHistory(WorkspaceHistory{ + ID: workspaceHistory.ID, + CreatedAt: workspaceHistory.CreatedAt, + UpdatedAt: workspaceHistory.UpdatedAt, + WorkspaceID: workspaceHistory.WorkspaceID, + ProjectHistoryID: workspaceHistory.ProjectHistoryID, + BeforeID: workspaceHistory.BeforeID.UUID, + AfterID: workspaceHistory.AfterID.UUID, + Name: workspaceHistory.Name, + Transition: workspaceHistory.Transition, + Initiator: workspaceHistory.Initiator, + Provision: convertProvisionerJob(provisionerJob), + }) +} + +func workspaceHistoryLogsChannel(workspaceHistoryID uuid.UUID) string { + return fmt.Sprintf("workspace-history-logs:%s", workspaceHistoryID) } diff --git a/coderd/workspacehistory_test.go b/coderd/workspacehistory_test.go index 773de1a5b5a95..66dc5bd444621 100644 --- a/coderd/workspacehistory_test.go +++ b/coderd/workspacehistory_test.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "testing" + "time" "github.com/google/uuid" "github.com/stretchr/testify/require" @@ -32,21 +33,31 @@ func TestWorkspaceHistory(t *testing.T) { return project, workspace } - setupProjectHistory := func(t *testing.T, client *codersdk.Client, user coderd.CreateInitialUserRequest, project coderd.Project) coderd.ProjectHistory { + setupProjectHistory := func(t *testing.T, client *codersdk.Client, user coderd.CreateInitialUserRequest, project coderd.Project, files map[string]string) coderd.ProjectHistory { var buffer bytes.Buffer writer := tar.NewWriter(&buffer) - err := writer.WriteHeader(&tar.Header{ - Name: "file", - Size: 1 << 10, - }) - require.NoError(t, err) - _, err = writer.Write(make([]byte, 1<<10)) + for path, content := range files { + err := writer.WriteHeader(&tar.Header{ + Name: path, + Size: int64(len(content)), + }) + require.NoError(t, err) + _, err = writer.Write([]byte(content)) + require.NoError(t, err) + } + err := writer.Flush() require.NoError(t, err) + projectHistory, err := client.CreateProjectHistory(context.Background(), user.Organization, project.Name, coderd.CreateProjectHistoryRequest{ StorageMethod: database.ProjectStorageMethodInlineArchive, StorageSource: buffer.Bytes(), }) require.NoError(t, err) + require.Eventually(t, func() bool { + hist, err := client.ProjectHistory(context.Background(), user.Organization, project.Name, projectHistory.Name) + require.NoError(t, err) + return hist.Import.Status.Completed() + }, 15*time.Second, 50*time.Millisecond) return projectHistory } @@ -54,17 +65,20 @@ func TestWorkspaceHistory(t *testing.T) { t.Parallel() server := coderdtest.New(t) user := server.RandomInitialUser(t) + _ = server.AddProvisionerd(t) project, workspace := setupProjectAndWorkspace(t, server.Client, user) - history, err := server.Client.WorkspaceHistory(context.Background(), "", workspace.Name) + history, err := server.Client.ListWorkspaceHistory(context.Background(), "", workspace.Name) require.NoError(t, err) require.Len(t, history, 0) - projectVersion := setupProjectHistory(t, server.Client, user, project) + projectVersion := setupProjectHistory(t, server.Client, user, project, map[string]string{ + "example": "file", + }) _, err = server.Client.CreateWorkspaceHistory(context.Background(), "", workspace.Name, coderd.CreateWorkspaceHistoryRequest{ ProjectHistoryID: projectVersion.ID, Transition: database.WorkspaceTransitionCreate, }) require.NoError(t, err) - history, err = server.Client.WorkspaceHistory(context.Background(), "", workspace.Name) + history, err = server.Client.ListWorkspaceHistory(context.Background(), "", workspace.Name) require.NoError(t, err) require.Len(t, history, 1) }) @@ -73,16 +87,19 @@ func TestWorkspaceHistory(t *testing.T) { t.Parallel() server := coderdtest.New(t) user := server.RandomInitialUser(t) + _ = server.AddProvisionerd(t) project, workspace := setupProjectAndWorkspace(t, server.Client, user) - _, err := server.Client.LatestWorkspaceHistory(context.Background(), "", workspace.Name) + _, err := server.Client.WorkspaceHistory(context.Background(), "", workspace.Name, "") require.Error(t, err) - projectVersion := setupProjectHistory(t, server.Client, user, project) + projectHistory := setupProjectHistory(t, server.Client, user, project, map[string]string{ + "some": "file", + }) _, err = server.Client.CreateWorkspaceHistory(context.Background(), "", workspace.Name, coderd.CreateWorkspaceHistoryRequest{ - ProjectHistoryID: projectVersion.ID, + ProjectHistoryID: projectHistory.ID, Transition: database.WorkspaceTransitionCreate, }) require.NoError(t, err) - _, err = server.Client.LatestWorkspaceHistory(context.Background(), "", workspace.Name) + _, err = server.Client.WorkspaceHistory(context.Background(), "", workspace.Name, "") require.NoError(t, err) }) @@ -90,22 +107,36 @@ func TestWorkspaceHistory(t *testing.T) { t.Parallel() server := coderdtest.New(t) user := server.RandomInitialUser(t) + _ = server.AddProvisionerd(t) project, workspace := setupProjectAndWorkspace(t, server.Client, user) - projectHistory := setupProjectHistory(t, server.Client, user, project) - + projectHistory := setupProjectHistory(t, server.Client, user, project, map[string]string{ + "main.tf": `resource "null_resource" "example" {}`, + }) _, err := server.Client.CreateWorkspaceHistory(context.Background(), "", workspace.Name, coderd.CreateWorkspaceHistoryRequest{ ProjectHistoryID: projectHistory.ID, Transition: database.WorkspaceTransitionCreate, }) require.NoError(t, err) + + var workspaceHistory coderd.WorkspaceHistory + require.Eventually(t, func() bool { + workspaceHistory, err = server.Client.WorkspaceHistory(context.Background(), "", workspace.Name, "") + require.NoError(t, err) + return workspaceHistory.Provision.Status.Completed() + }, 15*time.Second, 50*time.Millisecond) + require.Equal(t, "", workspaceHistory.Provision.Error) + require.Equal(t, coderd.ProvisionerJobStatusSucceeded, workspaceHistory.Provision.Status) }) t.Run("CreateHistoryAlreadyInProgress", func(t *testing.T) { t.Parallel() server := coderdtest.New(t) user := server.RandomInitialUser(t) + _ = server.AddProvisionerd(t) project, workspace := setupProjectAndWorkspace(t, server.Client, user) - projectHistory := setupProjectHistory(t, server.Client, user, project) + projectHistory := setupProjectHistory(t, server.Client, user, project, map[string]string{ + "some": "content", + }) _, err := server.Client.CreateWorkspaceHistory(context.Background(), "", workspace.Name, coderd.CreateWorkspaceHistoryRequest{ ProjectHistoryID: projectHistory.ID, @@ -124,6 +155,7 @@ func TestWorkspaceHistory(t *testing.T) { t.Parallel() server := coderdtest.New(t) user := server.RandomInitialUser(t) + _ = server.AddProvisionerd(t) _, workspace := setupProjectAndWorkspace(t, server.Client, user) _, err := server.Client.CreateWorkspaceHistory(context.Background(), "", workspace.Name, coderd.CreateWorkspaceHistoryRequest{ diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 60504fb2cc184..01ef9870cecd4 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -149,20 +149,3 @@ func (*api) workspaceByUser(rw http.ResponseWriter, r *http.Request) { func convertWorkspace(workspace database.Workspace) Workspace { return Workspace(workspace) } - -// Converts the internal history representation to a public external-facing model. -func convertWorkspaceHistory(workspaceHistory database.WorkspaceHistory) WorkspaceHistory { - //nolint:unconvert - return WorkspaceHistory(WorkspaceHistory{ - ID: workspaceHistory.ID, - CreatedAt: workspaceHistory.CreatedAt, - UpdatedAt: workspaceHistory.UpdatedAt, - CompletedAt: workspaceHistory.CompletedAt.Time, - WorkspaceID: workspaceHistory.WorkspaceID, - ProjectHistoryID: workspaceHistory.ProjectHistoryID, - BeforeID: workspaceHistory.BeforeID.UUID, - AfterID: workspaceHistory.AfterID.UUID, - Transition: workspaceHistory.Transition, - Initiator: workspaceHistory.Initiator, - }) -} diff --git a/codersdk/projects.go b/codersdk/projects.go index 4b3a4e90e15d6..a4281849c22e9 100644 --- a/codersdk/projects.go +++ b/codersdk/projects.go @@ -57,8 +57,8 @@ func (c *Client) CreateProject(ctx context.Context, organization string, request return project, json.NewDecoder(res.Body).Decode(&project) } -// ProjectHistory lists history for a project. -func (c *Client) ProjectHistory(ctx context.Context, organization, project string) ([]coderd.ProjectHistory, error) { +// ListProjectHistory lists history for a project. +func (c *Client) ListProjectHistory(ctx context.Context, organization, project string) ([]coderd.ProjectHistory, error) { res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/projects/%s/%s/history", organization, project), nil) if err != nil { return nil, err @@ -67,8 +67,22 @@ func (c *Client) ProjectHistory(ctx context.Context, organization, project strin if res.StatusCode != http.StatusOK { return nil, readBodyAsError(res) } - var projectVersions []coderd.ProjectHistory - return projectVersions, json.NewDecoder(res.Body).Decode(&projectVersions) + var projectHistory []coderd.ProjectHistory + return projectHistory, json.NewDecoder(res.Body).Decode(&projectHistory) +} + +// ProjectHistory returns project history by name. +func (c *Client) ProjectHistory(ctx context.Context, organization, project, history string) (coderd.ProjectHistory, error) { + res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/projects/%s/%s/history/%s", organization, project, history), nil) + if err != nil { + return coderd.ProjectHistory{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return coderd.ProjectHistory{}, readBodyAsError(res) + } + var projectHistory coderd.ProjectHistory + return projectHistory, json.NewDecoder(res.Body).Decode(&projectHistory) } // CreateProjectHistory inserts a new version for the project. diff --git a/codersdk/projects_test.go b/codersdk/projects_test.go index ad61d79110288..a30146b7b97b7 100644 --- a/codersdk/projects_test.go +++ b/codersdk/projects_test.go @@ -71,14 +71,14 @@ func TestProjects(t *testing.T) { require.NoError(t, err) }) - t.Run("UnauthenticatedVersions", func(t *testing.T) { + t.Run("UnauthenticatedHistory", func(t *testing.T) { t.Parallel() server := coderdtest.New(t) - _, err := server.Client.ProjectHistory(context.Background(), "org", "project") + _, err := server.Client.ListProjectHistory(context.Background(), "org", "project") require.Error(t, err) }) - t.Run("Versions", func(t *testing.T) { + t.Run("History", func(t *testing.T) { t.Parallel() server := coderdtest.New(t) user := server.RandomInitialUser(t) @@ -87,11 +87,11 @@ func TestProjects(t *testing.T) { Provisioner: database.ProvisionerTypeTerraform, }) require.NoError(t, err) - _, err = server.Client.ProjectHistory(context.Background(), user.Organization, project.Name) + _, err = server.Client.ListProjectHistory(context.Background(), user.Organization, project.Name) require.NoError(t, err) }) - t.Run("CreateVersionUnauthenticated", func(t *testing.T) { + t.Run("CreateHistoryUnauthenticated", func(t *testing.T) { t.Parallel() server := coderdtest.New(t) _, err := server.Client.CreateProjectHistory(context.Background(), "org", "project", coderd.CreateProjectHistoryRequest{ @@ -101,7 +101,7 @@ func TestProjects(t *testing.T) { require.Error(t, err) }) - t.Run("CreateVersion", func(t *testing.T) { + t.Run("CreateHistory", func(t *testing.T) { t.Parallel() server := coderdtest.New(t) user := server.RandomInitialUser(t) @@ -119,10 +119,13 @@ func TestProjects(t *testing.T) { require.NoError(t, err) _, err = writer.Write(make([]byte, 1<<10)) require.NoError(t, err) - _, err = server.Client.CreateProjectHistory(context.Background(), user.Organization, project.Name, coderd.CreateProjectHistoryRequest{ + history, err := server.Client.CreateProjectHistory(context.Background(), user.Organization, project.Name, coderd.CreateProjectHistoryRequest{ StorageMethod: database.ProjectStorageMethodInlineArchive, StorageSource: buffer.Bytes(), }) require.NoError(t, err) + + _, err = server.Client.ProjectHistory(context.Background(), user.Organization, project.Name, history.Name) + require.NoError(t, err) }) } diff --git a/codersdk/provisioners.go b/codersdk/provisioners.go new file mode 100644 index 0000000000000..cfc908a7d39b3 --- /dev/null +++ b/codersdk/provisioners.go @@ -0,0 +1,50 @@ +package codersdk + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/hashicorp/yamux" + "golang.org/x/xerrors" + "nhooyr.io/websocket" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/provisionerd/proto" + "github.com/coder/coder/provisionersdk" +) + +func (c *Client) ProvisionerDaemons(ctx context.Context) ([]coderd.ProvisionerDaemon, error) { + res, err := c.request(ctx, http.MethodGet, "/api/v2/provisioners/daemons", nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, readBodyAsError(res) + } + var daemons []coderd.ProvisionerDaemon + return daemons, json.NewDecoder(res.Body).Decode(&daemons) +} + +// ProvisionerDaemonClient returns the gRPC service for a provisioner daemon implementation. +func (c *Client) ProvisionerDaemonClient(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + serverURL, err := c.url.Parse("/api/v2/provisioners/daemons/serve") + if err != nil { + return nil, xerrors.Errorf("parse url: %w", err) + } + conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ + HTTPClient: c.httpClient, + }) + if err != nil { + if res == nil { + return nil, err + } + return nil, readBodyAsError(res) + } + session, err := yamux.Client(websocket.NetConn(context.Background(), conn, websocket.MessageBinary), nil) + if err != nil { + return nil, xerrors.Errorf("multiplex client: %w", err) + } + return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(session)), nil +} diff --git a/codersdk/workspaces.go b/codersdk/workspaces.go index 937f58e861b11..122f66cdea906 100644 --- a/codersdk/workspaces.go +++ b/codersdk/workspaces.go @@ -60,8 +60,8 @@ func (c *Client) Workspace(ctx context.Context, owner, name string) (coderd.Work return workspace, json.NewDecoder(res.Body).Decode(&workspace) } -// WorkspaceHistory returns historical data for workspace builds. -func (c *Client) WorkspaceHistory(ctx context.Context, owner, workspace string) ([]coderd.WorkspaceHistory, error) { +// ListWorkspaceHistory returns historical data for workspace builds. +func (c *Client) ListWorkspaceHistory(ctx context.Context, owner, workspace string) ([]coderd.WorkspaceHistory, error) { if owner == "" { owner = "me" } @@ -77,12 +77,16 @@ func (c *Client) WorkspaceHistory(ctx context.Context, owner, workspace string) return workspaceHistory, json.NewDecoder(res.Body).Decode(&workspaceHistory) } -// LatestWorkspaceHistory returns the newest build for a workspace. -func (c *Client) LatestWorkspaceHistory(ctx context.Context, owner, workspace string) (coderd.WorkspaceHistory, error) { +// WorkspaceHistory returns a single workspace history for a workspace. +// If history is "", the latest version is returned. +func (c *Client) WorkspaceHistory(ctx context.Context, owner, workspace, history string) (coderd.WorkspaceHistory, error) { if owner == "" { owner = "me" } - res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaces/%s/%s/history/latest", owner, workspace), nil) + if history == "" { + history = "latest" + } + res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaces/%s/%s/history/%s", owner, workspace, history), nil) if err != nil { return coderd.WorkspaceHistory{}, err } diff --git a/codersdk/workspaces_test.go b/codersdk/workspaces_test.go index b99f3798e93ee..4b5e64d346d25 100644 --- a/codersdk/workspaces_test.go +++ b/codersdk/workspaces_test.go @@ -117,14 +117,14 @@ func TestWorkspaces(t *testing.T) { ProjectID: project.ID, }) require.NoError(t, err) - _, err = server.Client.WorkspaceHistory(context.Background(), "", workspace.Name) + _, err = server.Client.ListWorkspaceHistory(context.Background(), "", workspace.Name) require.NoError(t, err) }) t.Run("HistoryError", func(t *testing.T) { t.Parallel() server := coderdtest.New(t) - _, err := server.Client.WorkspaceHistory(context.Background(), "", "blob") + _, err := server.Client.ListWorkspaceHistory(context.Background(), "", "blob") require.Error(t, err) }) @@ -142,7 +142,7 @@ func TestWorkspaces(t *testing.T) { ProjectID: project.ID, }) require.NoError(t, err) - _, err = server.Client.LatestWorkspaceHistory(context.Background(), "", workspace.Name) + _, err = server.Client.WorkspaceHistory(context.Background(), "", workspace.Name, "") require.Error(t, err) }) diff --git a/database/databasefake/databasefake.go b/database/databasefake/databasefake.go index 19b5b0a0d060e..7ddb71ba04751 100644 --- a/database/databasefake/databasefake.go +++ b/database/databasefake/databasefake.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "strings" + "sync" "github.com/google/uuid" @@ -35,6 +36,8 @@ func New() database.Store { // fakeQuerier replicates database functionality to enable quick testing. type fakeQuerier struct { + mutex sync.Mutex + // Legacy tables apiKeys []database.APIKey organizations []database.Organization @@ -62,6 +65,9 @@ func (q *fakeQuerier) InTx(fn func(database.Store) error) error { } func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for index, provisionerJob := range q.provisionerJobs { if provisionerJob.StartedAt.Valid { continue @@ -87,6 +93,9 @@ func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu } func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, apiKey := range q.apiKeys { if apiKey.ID == id { return apiKey, nil @@ -96,6 +105,9 @@ func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIK } func (q *fakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, user := range q.users { if user.Email == arg.Email || user.Username == arg.Username { return user, nil @@ -105,6 +117,9 @@ func (q *fakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.G } func (q *fakeQuerier) GetUserByID(_ context.Context, id string) (database.User, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, user := range q.users { if user.ID == id { return user, nil @@ -114,10 +129,16 @@ func (q *fakeQuerier) GetUserByID(_ context.Context, id string) (database.User, } func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + return int64(len(q.users)), nil } func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceAgent, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + agents := make([]database.WorkspaceAgent, 0) for _, workspaceAgent := range q.workspaceAgent { for _, id := range ids { @@ -133,6 +154,9 @@ func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, ids []u } func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (database.Workspace, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, workspace := range q.workspace { if workspace.ID.String() == id.String() { return workspace, nil @@ -142,6 +166,9 @@ func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (databas } func (q *fakeQuerier) GetWorkspaceByUserIDAndName(_ context.Context, arg database.GetWorkspaceByUserIDAndNameParams) (database.Workspace, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, workspace := range q.workspace { if workspace.OwnerID != arg.OwnerID { continue @@ -155,6 +182,9 @@ func (q *fakeQuerier) GetWorkspaceByUserIDAndName(_ context.Context, arg databas } func (q *fakeQuerier) GetWorkspaceResourcesByHistoryID(_ context.Context, workspaceHistoryID uuid.UUID) ([]database.WorkspaceResource, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + resources := make([]database.WorkspaceResource, 0) for _, workspaceResource := range q.workspaceResource { if workspaceResource.WorkspaceHistoryID.String() == workspaceHistoryID.String() { @@ -168,6 +198,9 @@ func (q *fakeQuerier) GetWorkspaceResourcesByHistoryID(_ context.Context, worksp } func (q *fakeQuerier) GetWorkspaceHistoryByID(_ context.Context, id uuid.UUID) (database.WorkspaceHistory, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, history := range q.workspaceHistory { if history.ID.String() == id.String() { return history, nil @@ -177,6 +210,9 @@ func (q *fakeQuerier) GetWorkspaceHistoryByID(_ context.Context, id uuid.UUID) ( } func (q *fakeQuerier) GetWorkspaceHistoryByWorkspaceIDWithoutAfter(_ context.Context, workspaceID uuid.UUID) (database.WorkspaceHistory, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, workspaceHistory := range q.workspaceHistory { if workspaceHistory.WorkspaceID.String() != workspaceID.String() { continue @@ -189,6 +225,9 @@ func (q *fakeQuerier) GetWorkspaceHistoryByWorkspaceIDWithoutAfter(_ context.Con } func (q *fakeQuerier) GetWorkspaceHistoryLogsByIDBefore(_ context.Context, arg database.GetWorkspaceHistoryLogsByIDBeforeParams) ([]database.WorkspaceHistoryLog, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + logs := make([]database.WorkspaceHistoryLog, 0) for _, workspaceHistoryLog := range q.workspaceHistoryLog { if workspaceHistoryLog.WorkspaceHistoryID.String() != arg.WorkspaceHistoryID.String() { @@ -206,6 +245,9 @@ func (q *fakeQuerier) GetWorkspaceHistoryLogsByIDBefore(_ context.Context, arg d } func (q *fakeQuerier) GetWorkspaceHistoryByWorkspaceID(_ context.Context, workspaceID uuid.UUID) ([]database.WorkspaceHistory, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + history := make([]database.WorkspaceHistory, 0) for _, workspaceHistory := range q.workspaceHistory { if workspaceHistory.WorkspaceID.String() == workspaceID.String() { @@ -219,6 +261,9 @@ func (q *fakeQuerier) GetWorkspaceHistoryByWorkspaceID(_ context.Context, worksp } func (q *fakeQuerier) GetWorkspaceHistoryByWorkspaceIDAndName(_ context.Context, arg database.GetWorkspaceHistoryByWorkspaceIDAndNameParams) (database.WorkspaceHistory, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, workspaceHistory := range q.workspaceHistory { if workspaceHistory.WorkspaceID.String() != arg.WorkspaceID.String() { continue @@ -232,6 +277,9 @@ func (q *fakeQuerier) GetWorkspaceHistoryByWorkspaceIDAndName(_ context.Context, } func (q *fakeQuerier) GetWorkspacesByProjectAndUserID(_ context.Context, arg database.GetWorkspacesByProjectAndUserIDParams) ([]database.Workspace, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + workspaces := make([]database.Workspace, 0) for _, workspace := range q.workspace { if workspace.OwnerID != arg.OwnerID { @@ -249,6 +297,9 @@ func (q *fakeQuerier) GetWorkspacesByProjectAndUserID(_ context.Context, arg dat } func (q *fakeQuerier) GetWorkspacesByUserID(_ context.Context, ownerID string) ([]database.Workspace, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + workspaces := make([]database.Workspace, 0) for _, workspace := range q.workspace { if workspace.OwnerID != ownerID { @@ -263,6 +314,9 @@ func (q *fakeQuerier) GetWorkspacesByUserID(_ context.Context, ownerID string) ( } func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id string) (database.Organization, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, organization := range q.organizations { if organization.ID == id { return organization, nil @@ -272,6 +326,9 @@ func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id string) (databas } func (q *fakeQuerier) GetOrganizationByName(_ context.Context, name string) (database.Organization, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, organization := range q.organizations { if organization.Name == name { return organization, nil @@ -281,6 +338,9 @@ func (q *fakeQuerier) GetOrganizationByName(_ context.Context, name string) (dat } func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID string) ([]database.Organization, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + organizations := make([]database.Organization, 0) for _, organizationMember := range q.organizationMembers { if organizationMember.UserID != userID { @@ -300,6 +360,9 @@ func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID string) } func (q *fakeQuerier) GetParameterValuesByScope(_ context.Context, arg database.GetParameterValuesByScopeParams) ([]database.ParameterValue, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + parameterValues := make([]database.ParameterValue, 0) for _, parameterValue := range q.parameterValue { if parameterValue.Scope != arg.Scope { @@ -317,6 +380,9 @@ func (q *fakeQuerier) GetParameterValuesByScope(_ context.Context, arg database. } func (q *fakeQuerier) GetProjectByID(_ context.Context, id uuid.UUID) (database.Project, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, project := range q.project { if project.ID.String() == id.String() { return project, nil @@ -326,6 +392,9 @@ func (q *fakeQuerier) GetProjectByID(_ context.Context, id uuid.UUID) (database. } func (q *fakeQuerier) GetProjectByOrganizationAndName(_ context.Context, arg database.GetProjectByOrganizationAndNameParams) (database.Project, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, project := range q.project { if project.OrganizationID != arg.OrganizationID { continue @@ -339,6 +408,9 @@ func (q *fakeQuerier) GetProjectByOrganizationAndName(_ context.Context, arg dat } func (q *fakeQuerier) GetProjectHistoryByProjectID(_ context.Context, projectID uuid.UUID) ([]database.ProjectHistory, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + history := make([]database.ProjectHistory, 0) for _, projectHistory := range q.projectHistory { if projectHistory.ProjectID.String() != projectID.String() { @@ -353,6 +425,9 @@ func (q *fakeQuerier) GetProjectHistoryByProjectID(_ context.Context, projectID } func (q *fakeQuerier) GetProjectHistoryByProjectIDAndName(_ context.Context, arg database.GetProjectHistoryByProjectIDAndNameParams) (database.ProjectHistory, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, projectHistory := range q.projectHistory { if projectHistory.ProjectID.String() != arg.ProjectID.String() { continue @@ -366,6 +441,9 @@ func (q *fakeQuerier) GetProjectHistoryByProjectIDAndName(_ context.Context, arg } func (q *fakeQuerier) GetProjectHistoryLogsByIDBefore(_ context.Context, arg database.GetProjectHistoryLogsByIDBeforeParams) ([]database.ProjectHistoryLog, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + logs := make([]database.ProjectHistoryLog, 0) for _, projectHistoryLog := range q.projectHistoryLog { if projectHistoryLog.ProjectHistoryID.String() != arg.ProjectHistoryID.String() { @@ -383,6 +461,9 @@ func (q *fakeQuerier) GetProjectHistoryLogsByIDBefore(_ context.Context, arg dat } func (q *fakeQuerier) GetProjectHistoryByID(_ context.Context, projectHistoryID uuid.UUID) (database.ProjectHistory, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, projectHistory := range q.projectHistory { if projectHistory.ID.String() != projectHistoryID.String() { continue @@ -393,6 +474,9 @@ func (q *fakeQuerier) GetProjectHistoryByID(_ context.Context, projectHistoryID } func (q *fakeQuerier) GetProjectParametersByHistoryID(_ context.Context, projectHistoryID uuid.UUID) ([]database.ProjectParameter, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + parameters := make([]database.ProjectParameter, 0) for _, projectParameter := range q.projectParameter { if projectParameter.ProjectHistoryID.String() != projectHistoryID.String() { @@ -407,6 +491,9 @@ func (q *fakeQuerier) GetProjectParametersByHistoryID(_ context.Context, project } func (q *fakeQuerier) GetProjectsByOrganizationIDs(_ context.Context, ids []string) ([]database.Project, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + projects := make([]database.Project, 0) for _, project := range q.project { for _, id := range ids { @@ -423,6 +510,9 @@ func (q *fakeQuerier) GetProjectsByOrganizationIDs(_ context.Context, ids []stri } func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, organizationMember := range q.organizationMembers { if organizationMember.OrganizationID != arg.OrganizationID { continue @@ -436,6 +526,9 @@ func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg datab } func (q *fakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.ProvisionerDaemon, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + if len(q.provisionerDaemons) == 0 { return nil, sql.ErrNoRows } @@ -443,6 +536,9 @@ func (q *fakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.Provi } func (q *fakeQuerier) GetProvisionerDaemonByID(_ context.Context, id uuid.UUID) (database.ProvisionerDaemon, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, provisionerDaemon := range q.provisionerDaemons { if provisionerDaemon.ID.String() != id.String() { continue @@ -453,6 +549,9 @@ func (q *fakeQuerier) GetProvisionerDaemonByID(_ context.Context, id uuid.UUID) } func (q *fakeQuerier) GetProvisionerJobByID(_ context.Context, id uuid.UUID) (database.ProvisionerJob, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + for _, provisionerJob := range q.provisionerJobs { if provisionerJob.ID.String() != id.String() { continue @@ -463,6 +562,9 @@ func (q *fakeQuerier) GetProvisionerJobByID(_ context.Context, id uuid.UUID) (da } func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + //nolint:gosimple key := database.APIKey{ ID: arg.ID, @@ -486,6 +588,9 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP } func (q *fakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + organization := database.Organization{ ID: arg.ID, Name: arg.Name, @@ -497,6 +602,9 @@ func (q *fakeQuerier) InsertOrganization(_ context.Context, arg database.InsertO } func (q *fakeQuerier) InsertOrganizationMember(_ context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + //nolint:gosimple organizationMember := database.OrganizationMember{ OrganizationID: arg.OrganizationID, @@ -510,6 +618,9 @@ func (q *fakeQuerier) InsertOrganizationMember(_ context.Context, arg database.I } func (q *fakeQuerier) InsertParameterValue(_ context.Context, arg database.InsertParameterValueParams) (database.ParameterValue, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + //nolint:gosimple parameterValue := database.ParameterValue{ ID: arg.ID, @@ -528,6 +639,9 @@ func (q *fakeQuerier) InsertParameterValue(_ context.Context, arg database.Inser } func (q *fakeQuerier) InsertProject(_ context.Context, arg database.InsertProjectParams) (database.Project, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + project := database.Project{ ID: arg.ID, CreatedAt: arg.CreatedAt, @@ -541,6 +655,9 @@ func (q *fakeQuerier) InsertProject(_ context.Context, arg database.InsertProjec } func (q *fakeQuerier) InsertProjectHistory(_ context.Context, arg database.InsertProjectHistoryParams) (database.ProjectHistory, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + //nolint:gosimple history := database.ProjectHistory{ ID: arg.ID, @@ -558,6 +675,9 @@ func (q *fakeQuerier) InsertProjectHistory(_ context.Context, arg database.Inser } func (q *fakeQuerier) InsertProjectHistoryLogs(_ context.Context, arg database.InsertProjectHistoryLogsParams) ([]database.ProjectHistoryLog, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + logs := make([]database.ProjectHistoryLog, 0) for index, output := range arg.Output { logs = append(logs, database.ProjectHistoryLog{ @@ -574,6 +694,9 @@ func (q *fakeQuerier) InsertProjectHistoryLogs(_ context.Context, arg database.I } func (q *fakeQuerier) InsertProjectParameter(_ context.Context, arg database.InsertProjectParameterParams) (database.ProjectParameter, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + //nolint:gosimple param := database.ProjectParameter{ ID: arg.ID, @@ -599,6 +722,9 @@ func (q *fakeQuerier) InsertProjectParameter(_ context.Context, arg database.Ins } func (q *fakeQuerier) InsertProvisionerDaemon(_ context.Context, arg database.InsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + daemon := database.ProvisionerDaemon{ ID: arg.ID, CreatedAt: arg.CreatedAt, @@ -610,6 +736,9 @@ func (q *fakeQuerier) InsertProvisionerDaemon(_ context.Context, arg database.In } func (q *fakeQuerier) InsertProvisionerJob(_ context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + job := database.ProvisionerJob{ ID: arg.ID, CreatedAt: arg.CreatedAt, @@ -625,6 +754,9 @@ func (q *fakeQuerier) InsertProvisionerJob(_ context.Context, arg database.Inser } func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + user := database.User{ ID: arg.ID, Email: arg.Email, @@ -640,6 +772,9 @@ func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam } func (q *fakeQuerier) InsertWorkspace(_ context.Context, arg database.InsertWorkspaceParams) (database.Workspace, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + //nolint:gosimple workspace := database.Workspace{ ID: arg.ID, @@ -654,6 +789,9 @@ func (q *fakeQuerier) InsertWorkspace(_ context.Context, arg database.InsertWork } func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + //nolint:gosimple workspaceAgent := database.WorkspaceAgent{ ID: arg.ID, @@ -668,6 +806,9 @@ func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.Inser } func (q *fakeQuerier) InsertWorkspaceHistory(_ context.Context, arg database.InsertWorkspaceHistoryParams) (database.WorkspaceHistory, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + workspaceHistory := database.WorkspaceHistory{ ID: arg.ID, CreatedAt: arg.CreatedAt, @@ -686,6 +827,9 @@ func (q *fakeQuerier) InsertWorkspaceHistory(_ context.Context, arg database.Ins } func (q *fakeQuerier) InsertWorkspaceHistoryLogs(_ context.Context, arg database.InsertWorkspaceHistoryLogsParams) ([]database.WorkspaceHistoryLog, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + logs := make([]database.WorkspaceHistoryLog, 0) for index, output := range arg.Output { logs = append(logs, database.WorkspaceHistoryLog{ @@ -702,6 +846,9 @@ func (q *fakeQuerier) InsertWorkspaceHistoryLogs(_ context.Context, arg database } func (q *fakeQuerier) InsertWorkspaceResource(_ context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + workspaceResource := database.WorkspaceResource{ ID: arg.ID, CreatedAt: arg.CreatedAt, @@ -715,6 +862,9 @@ func (q *fakeQuerier) InsertWorkspaceResource(_ context.Context, arg database.In } func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPIKeyByIDParams) error { + q.mutex.Lock() + defer q.mutex.Unlock() + for index, apiKey := range q.apiKeys { if apiKey.ID != arg.ID { continue @@ -731,6 +881,9 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI } func (q *fakeQuerier) UpdateProvisionerDaemonByID(_ context.Context, arg database.UpdateProvisionerDaemonByIDParams) error { + q.mutex.Lock() + defer q.mutex.Unlock() + for index, daemon := range q.provisionerDaemons { if arg.ID.String() != daemon.ID.String() { continue @@ -744,6 +897,9 @@ func (q *fakeQuerier) UpdateProvisionerDaemonByID(_ context.Context, arg databas } func (q *fakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error { + q.mutex.Lock() + defer q.mutex.Unlock() + for index, job := range q.provisionerJobs { if arg.ID.String() != job.ID.String() { continue @@ -759,12 +915,14 @@ func (q *fakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.U } func (q *fakeQuerier) UpdateWorkspaceHistoryByID(_ context.Context, arg database.UpdateWorkspaceHistoryByIDParams) error { + q.mutex.Lock() + defer q.mutex.Unlock() + for index, workspaceHistory := range q.workspaceHistory { if workspaceHistory.ID.String() != arg.ID.String() { continue } workspaceHistory.UpdatedAt = arg.UpdatedAt - workspaceHistory.CompletedAt = arg.CompletedAt workspaceHistory.AfterID = arg.AfterID workspaceHistory.ProvisionerState = arg.ProvisionerState q.workspaceHistory[index] = workspaceHistory diff --git a/database/dump.sql b/database/dump.sql index 0cea42a2355aa..bbbc34658445a 100644 --- a/database/dump.sql +++ b/database/dump.sql @@ -243,7 +243,6 @@ CREATE TABLE workspace_history ( id uuid NOT NULL, created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone NOT NULL, - completed_at timestamp with time zone, workspace_id uuid NOT NULL, project_history_id uuid NOT NULL, name character varying(64) NOT NULL, diff --git a/database/migrations/000003_workspaces.up.sql b/database/migrations/000003_workspaces.up.sql index 60fc1c0d9d8ab..fcc0b8fc3f77b 100644 --- a/database/migrations/000003_workspaces.up.sql +++ b/database/migrations/000003_workspaces.up.sql @@ -20,7 +20,6 @@ CREATE TABLE workspace_history ( id uuid NOT NULL UNIQUE, created_at timestamptz NOT NULL, updated_at timestamptz NOT NULL, - completed_at timestamptz, workspace_id uuid NOT NULL REFERENCES workspace (id) ON DELETE CASCADE, project_history_id uuid NOT NULL REFERENCES project_history (id) ON DELETE CASCADE, name varchar(64) NOT NULL, diff --git a/database/models.go b/database/models.go index 6fd1dad97d4fd..440e11dd71a33 100644 --- a/database/models.go +++ b/database/models.go @@ -422,7 +422,6 @@ type WorkspaceHistory struct { ID uuid.UUID `db:"id" json:"id"` CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - CompletedAt sql.NullTime `db:"completed_at" json:"completed_at"` WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"` ProjectHistoryID uuid.UUID `db:"project_history_id" json:"project_history_id"` Name string `db:"name" json:"name"` diff --git a/database/pubsub_test.go b/database/pubsub_test.go index fb21383d7f24c..55f34896184c6 100644 --- a/database/pubsub_test.go +++ b/database/pubsub_test.go @@ -17,6 +17,11 @@ import ( func TestPubsub(t *testing.T) { t.Parallel() + if testing.Short() { + t.Skip() + return + } + t.Run("Postgres", func(t *testing.T) { t.Parallel() ctx, cancelFunc := context.WithCancel(context.Background()) diff --git a/database/query.sql b/database/query.sql index 42f654a4bce9e..f0b09b4081850 100644 --- a/database/query.sql +++ b/database/query.sql @@ -29,7 +29,7 @@ WHERE AND nested.completed_at IS NULL AND nested.provisioner = ANY(@types :: provisioner_type [ ]) ORDER BY - nested.created FOR + nested.created_at FOR UPDATE SKIP LOCKED LIMIT @@ -429,11 +429,11 @@ INSERT INTO project_history_log SELECT @project_history_id :: uuid AS project_history_id, - unnset(@id :: uuid [ ]) AS id, + unnest(@id :: uuid [ ]) AS id, unnest(@created_at :: timestamptz [ ]) AS created_at, - unnset(@source :: log_source [ ]) as source, - unnset(@level :: log_level [ ]) as level, - unnset(@output :: varchar(1024) [ ]) as output RETURNING *; + unnest(@source :: log_source [ ]) as source, + unnest(@level :: log_level [ ]) as level, + unnest(@output :: varchar(1024) [ ]) as output RETURNING *; -- name: InsertProjectParameter :one INSERT INTO @@ -562,12 +562,12 @@ VALUES INSERT INTO workspace_history_log SELECT + unnest(@id :: uuid [ ]) AS id, @workspace_history_id :: uuid AS workspace_history_id, - unnset(@id :: uuid [ ]) AS id, unnest(@created_at :: timestamptz [ ]) AS created_at, - unnset(@source :: log_source [ ]) as source, - unnset(@level :: log_level [ ]) as level, - unnset(@output :: varchar(1024) [ ]) as output RETURNING *; + unnest(@source :: log_source [ ]) as source, + unnest(@level :: log_level [ ]) as level, + unnest(@output :: varchar(1024) [ ]) as output RETURNING *; -- name: InsertWorkspaceResource :one INSERT INTO @@ -619,8 +619,7 @@ UPDATE workspace_history SET updated_at = $2, - completed_at = $3, - after_id = $4, - provisioner_state = $5 + after_id = $3, + provisioner_state = $4 WHERE id = $1; diff --git a/database/query.sql.go b/database/query.sql.go index 0451cc54580c3..ad322bbc9b392 100644 --- a/database/query.sql.go +++ b/database/query.sql.go @@ -32,7 +32,7 @@ WHERE AND nested.completed_at IS NULL AND nested.provisioner = ANY($3 :: provisioner_type [ ]) ORDER BY - nested.created FOR + nested.created_at FOR UPDATE SKIP LOCKED LIMIT @@ -866,7 +866,7 @@ func (q *sqlQuerier) GetWorkspaceByUserIDAndName(ctx context.Context, arg GetWor const getWorkspaceHistoryByID = `-- name: GetWorkspaceHistoryByID :one SELECT - id, created_at, updated_at, completed_at, workspace_id, project_history_id, name, before_id, after_id, transition, initiator, provisioner_state, provision_job_id + id, created_at, updated_at, workspace_id, project_history_id, name, before_id, after_id, transition, initiator, provisioner_state, provision_job_id FROM workspace_history WHERE @@ -882,7 +882,6 @@ func (q *sqlQuerier) GetWorkspaceHistoryByID(ctx context.Context, id uuid.UUID) &i.ID, &i.CreatedAt, &i.UpdatedAt, - &i.CompletedAt, &i.WorkspaceID, &i.ProjectHistoryID, &i.Name, @@ -898,7 +897,7 @@ func (q *sqlQuerier) GetWorkspaceHistoryByID(ctx context.Context, id uuid.UUID) const getWorkspaceHistoryByWorkspaceID = `-- name: GetWorkspaceHistoryByWorkspaceID :many SELECT - id, created_at, updated_at, completed_at, workspace_id, project_history_id, name, before_id, after_id, transition, initiator, provisioner_state, provision_job_id + id, created_at, updated_at, workspace_id, project_history_id, name, before_id, after_id, transition, initiator, provisioner_state, provision_job_id FROM workspace_history WHERE @@ -918,7 +917,6 @@ func (q *sqlQuerier) GetWorkspaceHistoryByWorkspaceID(ctx context.Context, works &i.ID, &i.CreatedAt, &i.UpdatedAt, - &i.CompletedAt, &i.WorkspaceID, &i.ProjectHistoryID, &i.Name, @@ -944,7 +942,7 @@ func (q *sqlQuerier) GetWorkspaceHistoryByWorkspaceID(ctx context.Context, works const getWorkspaceHistoryByWorkspaceIDAndName = `-- name: GetWorkspaceHistoryByWorkspaceIDAndName :one SELECT - id, created_at, updated_at, completed_at, workspace_id, project_history_id, name, before_id, after_id, transition, initiator, provisioner_state, provision_job_id + id, created_at, updated_at, workspace_id, project_history_id, name, before_id, after_id, transition, initiator, provisioner_state, provision_job_id FROM workspace_history WHERE @@ -964,7 +962,6 @@ func (q *sqlQuerier) GetWorkspaceHistoryByWorkspaceIDAndName(ctx context.Context &i.ID, &i.CreatedAt, &i.UpdatedAt, - &i.CompletedAt, &i.WorkspaceID, &i.ProjectHistoryID, &i.Name, @@ -980,7 +977,7 @@ func (q *sqlQuerier) GetWorkspaceHistoryByWorkspaceIDAndName(ctx context.Context const getWorkspaceHistoryByWorkspaceIDWithoutAfter = `-- name: GetWorkspaceHistoryByWorkspaceIDWithoutAfter :one SELECT - id, created_at, updated_at, completed_at, workspace_id, project_history_id, name, before_id, after_id, transition, initiator, provisioner_state, provision_job_id + id, created_at, updated_at, workspace_id, project_history_id, name, before_id, after_id, transition, initiator, provisioner_state, provision_job_id FROM workspace_history WHERE @@ -997,7 +994,6 @@ func (q *sqlQuerier) GetWorkspaceHistoryByWorkspaceIDWithoutAfter(ctx context.Co &i.ID, &i.CreatedAt, &i.UpdatedAt, - &i.CompletedAt, &i.WorkspaceID, &i.ProjectHistoryID, &i.Name, @@ -1523,11 +1519,11 @@ INSERT INTO project_history_log SELECT $1 :: uuid AS project_history_id, - unnset($2 :: uuid [ ]) AS id, + unnest($2 :: uuid [ ]) AS id, unnest($3 :: timestamptz [ ]) AS created_at, - unnset($4 :: log_source [ ]) as source, - unnset($5 :: log_level [ ]) as level, - unnset($6 :: varchar(1024) [ ]) as output RETURNING id, project_history_id, created_at, source, level, output + unnest($4 :: log_source [ ]) as source, + unnest($5 :: log_level [ ]) as level, + unnest($6 :: varchar(1024) [ ]) as output RETURNING id, project_history_id, created_at, source, level, output ` type InsertProjectHistoryLogsParams struct { @@ -1939,7 +1935,7 @@ INSERT INTO provisioner_state ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id, created_at, updated_at, completed_at, workspace_id, project_history_id, name, before_id, after_id, transition, initiator, provisioner_state, provision_job_id + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id, created_at, updated_at, workspace_id, project_history_id, name, before_id, after_id, transition, initiator, provisioner_state, provision_job_id ` type InsertWorkspaceHistoryParams struct { @@ -1975,7 +1971,6 @@ func (q *sqlQuerier) InsertWorkspaceHistory(ctx context.Context, arg InsertWorks &i.ID, &i.CreatedAt, &i.UpdatedAt, - &i.CompletedAt, &i.WorkspaceID, &i.ProjectHistoryID, &i.Name, @@ -1993,17 +1988,17 @@ const insertWorkspaceHistoryLogs = `-- name: InsertWorkspaceHistoryLogs :many INSERT INTO workspace_history_log SELECT - $1 :: uuid AS workspace_history_id, - unnset($2 :: uuid [ ]) AS id, + unnest($1 :: uuid [ ]) AS id, + $2 :: uuid AS workspace_history_id, unnest($3 :: timestamptz [ ]) AS created_at, - unnset($4 :: log_source [ ]) as source, - unnset($5 :: log_level [ ]) as level, - unnset($6 :: varchar(1024) [ ]) as output RETURNING id, workspace_history_id, created_at, source, level, output + unnest($4 :: log_source [ ]) as source, + unnest($5 :: log_level [ ]) as level, + unnest($6 :: varchar(1024) [ ]) as output RETURNING id, workspace_history_id, created_at, source, level, output ` type InsertWorkspaceHistoryLogsParams struct { - WorkspaceHistoryID uuid.UUID `db:"workspace_history_id" json:"workspace_history_id"` ID []uuid.UUID `db:"id" json:"id"` + WorkspaceHistoryID uuid.UUID `db:"workspace_history_id" json:"workspace_history_id"` CreatedAt []time.Time `db:"created_at" json:"created_at"` Source []LogSource `db:"source" json:"source"` Level []LogLevel `db:"level" json:"level"` @@ -2012,8 +2007,8 @@ type InsertWorkspaceHistoryLogsParams struct { func (q *sqlQuerier) InsertWorkspaceHistoryLogs(ctx context.Context, arg InsertWorkspaceHistoryLogsParams) ([]WorkspaceHistoryLog, error) { rows, err := q.db.QueryContext(ctx, insertWorkspaceHistoryLogs, - arg.WorkspaceHistoryID, pq.Array(arg.ID), + arg.WorkspaceHistoryID, pq.Array(arg.CreatedAt), pq.Array(arg.Source), pq.Array(arg.Level), @@ -2183,9 +2178,8 @@ UPDATE workspace_history SET updated_at = $2, - completed_at = $3, - after_id = $4, - provisioner_state = $5 + after_id = $3, + provisioner_state = $4 WHERE id = $1 ` @@ -2193,7 +2187,6 @@ WHERE type UpdateWorkspaceHistoryByIDParams struct { ID uuid.UUID `db:"id" json:"id"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - CompletedAt sql.NullTime `db:"completed_at" json:"completed_at"` AfterID uuid.NullUUID `db:"after_id" json:"after_id"` ProvisionerState []byte `db:"provisioner_state" json:"provisioner_state"` } @@ -2202,7 +2195,6 @@ func (q *sqlQuerier) UpdateWorkspaceHistoryByID(ctx context.Context, arg UpdateW _, err := q.db.ExecContext(ctx, updateWorkspaceHistoryByID, arg.ID, arg.UpdatedAt, - arg.CompletedAt, arg.AfterID, arg.ProvisionerState, ) diff --git a/go.mod b/go.mod index 31b5af9012165..6eba88227cfa3 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,8 @@ module github.com/coder/coder go 1.17 -// Required until https://github.com/hashicorp/terraform-exec/pull/275 is merged. -replace github.com/hashicorp/terraform-exec => github.com/kylecarbs/terraform-exec v0.15.1-0.20220129210610-65894a884c09 +// Required until https://github.com/hashicorp/terraform-exec/pull/275 and https://github.com/hashicorp/terraform-exec/pull/276 are merged. +replace github.com/hashicorp/terraform-exec => github.com/kylecarbs/terraform-exec v0.15.1-0.20220202050609-a1ce7181b180 // Required until https://github.com/hashicorp/terraform-config-inspect/pull/74 is merged. replace github.com/hashicorp/terraform-config-inspect => github.com/kylecarbs/terraform-config-inspect v0.0.0-20211215004401-bbc517866b88 @@ -35,7 +35,6 @@ require ( go.uber.org/goleak v1.1.12 golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8 golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 - golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 google.golang.org/protobuf v1.27.1 nhooyr.io/websocket v1.8.7 @@ -108,7 +107,6 @@ require ( github.com/zeebo/errs v1.2.2 // indirect go.opencensus.io v0.23.0 // indirect golang.org/x/net v0.0.0-20220121210141-e204ce36a2ba // indirect - golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect golang.org/x/text v0.3.7 // indirect @@ -117,5 +115,4 @@ require ( google.golang.org/grpc v1.44.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect - nhooyr.io/websocket v1.8.7 // indirect ) diff --git a/go.sum b/go.sum index 096b027cb5a69..5363addd6de24 100644 --- a/go.sum +++ b/go.sum @@ -432,7 +432,9 @@ github.com/gabriel-vasile/mimetype v1.4.0/go.mod h1:fA8fi6KUiG7MgQQ+mEWotXoEOvmx github.com/garyburd/redigo v0.0.0-20150301180006-535138d7bcd7/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= @@ -511,8 +513,11 @@ github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWe github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gocql/gocql v0.0.0-20210515062232-b7ef815b4556/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY= github.com/godbus/dbus v0.0.0-20151105175453-c7fdd8b5cd55/go.mod h1:/YcGZj5zSblfDWMMoOzV4fas9FZnQYTkDnsGvmh2Grw= @@ -640,6 +645,7 @@ github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB7 github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= @@ -776,6 +782,7 @@ github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= @@ -824,8 +831,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/ktrysmt/go-bitbucket v0.6.4/go.mod h1:9u0v3hsd2rqCHRIpbir1oP7F58uo5dq19sBYvuMoyQ4= github.com/kylecarbs/terraform-config-inspect v0.0.0-20211215004401-bbc517866b88 h1:tvG/qs5c4worwGyGnbbb4i/dYYLjpFwDMqcIT3awAf8= github.com/kylecarbs/terraform-config-inspect v0.0.0-20211215004401-bbc517866b88/go.mod h1:Z0Nnk4+3Cy89smEbrq+sl1bxc9198gIP4I7wcQF6Kqs= -github.com/kylecarbs/terraform-exec v0.15.1-0.20220129210610-65894a884c09 h1:o+8BFGukFfFmGgOJIWEeDXkXRDdFoZ9ndi/GjqnHTGg= -github.com/kylecarbs/terraform-exec v0.15.1-0.20220129210610-65894a884c09/go.mod h1:lRENyXw1BL5V0FCCE8lsW3XoVLRLnxM54jrlYSyXpvM= +github.com/kylecarbs/terraform-exec v0.15.1-0.20220202050609-a1ce7181b180 h1:yafC0pmxjs18fnO5RdKFLSItJIjYwGfSHTfcUvlZb3E= +github.com/kylecarbs/terraform-exec v0.15.1-0.20220202050609-a1ce7181b180/go.mod h1:lRENyXw1BL5V0FCCE8lsW3XoVLRLnxM54jrlYSyXpvM= github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= @@ -913,9 +920,11 @@ github.com/moby/term v0.0.0-20201216013528-df9cb8a40635/go.mod h1:FBS0z0QWA44HXy github.com/moby/term v0.0.0-20210619224110-3f7ff695adc6 h1:dcztxKSvZ4Id8iPpHERQBbIJfabdt4wUm5qy3wOL2Zc= github.com/moby/term v0.0.0-20210619224110-3f7ff695adc6/go.mod h1:E2VnQOmVuvZB6UYnnDB0qG5Nq/1tD9acaOpo6xmt0Kw= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= @@ -1172,7 +1181,9 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1 github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/tv42/httpunix v0.0.0-20191220191345-2ba4b9c3382c/go.mod h1:hzIxponao9Kjc7aWznkXaL4U4TWaDSs8zcsY4Ka08nM= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= +github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/unrolled/secure v1.0.9 h1:BWRuEb1vDrBFFDdbCnKkof3gZ35I/bnHGyt0LB0TNyQ= github.com/unrolled/secure v1.0.9/go.mod h1:fO+mEan+FLB0CdEnHf6Q4ZZVNqG+5fuLFnP8p0BXDPI= @@ -1444,7 +1455,6 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180224232135-f6cff0780e54/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/httpmw/workspacehistoryparam.go b/httpmw/workspacehistoryparam.go index cd43823b22d9c..ff426faf23c83 100644 --- a/httpmw/workspacehistoryparam.go +++ b/httpmw/workspacehistoryparam.go @@ -36,15 +36,27 @@ func ExtractWorkspaceHistoryParam(db database.Store) func(http.Handler) http.Han }) return } - workspaceHistory, err := db.GetWorkspaceHistoryByWorkspaceIDAndName(r.Context(), database.GetWorkspaceHistoryByWorkspaceIDAndNameParams{ - WorkspaceID: workspace.ID, - Name: workspaceHistoryName, - }) - if errors.Is(err, sql.ErrNoRows) { - httpapi.Write(rw, http.StatusNotFound, httpapi.Response{ - Message: fmt.Sprintf("workspace history %q does not exist", workspaceHistoryName), + var workspaceHistory database.WorkspaceHistory + var err error + if workspaceHistoryName == "latest" { + workspaceHistory, err = db.GetWorkspaceHistoryByWorkspaceIDWithoutAfter(r.Context(), workspace.ID) + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusNotFound, httpapi.Response{ + Message: "there is no workspace history", + }) + return + } + } else { + workspaceHistory, err = db.GetWorkspaceHistoryByWorkspaceIDAndName(r.Context(), database.GetWorkspaceHistoryByWorkspaceIDAndNameParams{ + WorkspaceID: workspace.ID, + Name: workspaceHistoryName, }) - return + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusNotFound, httpapi.Response{ + Message: fmt.Sprintf("workspace history %q does not exist", workspaceHistoryName), + }) + return + } } if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ diff --git a/httpmw/workspacehistoryparam_test.go b/httpmw/workspacehistoryparam_test.go index 6fef05ed91c13..063f2fd7be3ca 100644 --- a/httpmw/workspacehistoryparam_test.go +++ b/httpmw/workspacehistoryparam_test.go @@ -142,4 +142,35 @@ func TestWorkspaceHistoryParam(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) }) + + t.Run("WorkspaceHistoryLatest", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractAPIKey(db, nil), + httpmw.ExtractUserParam(db), + httpmw.ExtractWorkspaceParam(db), + httpmw.ExtractWorkspaceHistoryParam(db), + ) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + _ = httpmw.WorkspaceHistoryParam(r) + rw.WriteHeader(http.StatusOK) + }) + + r, workspace := setupAuthentication(db) + _, err := db.InsertWorkspaceHistory(context.Background(), database.InsertWorkspaceHistoryParams{ + ID: uuid.New(), + WorkspaceID: workspace.ID, + Name: "moo", + }) + require.NoError(t, err) + chi.RouteContext(r.Context()).URLParams.Add("workspacehistory", "latest") + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) } diff --git a/peer/conn.go b/peer/conn.go index e4722b3f3f320..4333c83688b00 100644 --- a/peer/conn.go +++ b/peer/conn.go @@ -145,6 +145,10 @@ func (c *Conn) init() error { c.rtc.OnNegotiationNeeded(c.negotiate) c.rtc.OnICEConnectionStateChange(func(iceConnectionState webrtc.ICEConnectionState) { + if c.isClosed() { + return + } + c.opts.Logger.Debug(context.Background(), "ice connection state updated", slog.F("state", iceConnectionState)) diff --git a/provisioner/terraform/provision.go b/provisioner/terraform/provision.go index fe0e9bec46425..e528abaaf44ea 100644 --- a/provisioner/terraform/provision.go +++ b/provisioner/terraform/provision.go @@ -38,6 +38,12 @@ func (t *terraform) Provision(request *proto.Provision_Request, stream proto.DRP return xerrors.Errorf("terraform version %q is too old. required >= %q", version.String(), minimumTerraformVersion.String()) } + env := map[string]string{ + // Makes sequential runs significantly faster. + // https://github.com/hashicorp/terraform/blob/d35bc0531255b496beb5d932f185cbcdb2d61a99/internal/command/cliconfig/cliconfig.go#L24 + "TF_PLUGIN_CACHE_DIR": os.ExpandEnv("$HOME/.terraform.d/plugin-cache"), + } + reader, writer := io.Pipe() defer reader.Close() defer writer.Close() @@ -55,12 +61,13 @@ func (t *terraform) Provision(request *proto.Provision_Request, stream proto.DRP } }() terraform.SetStdout(writer) + t.logger.Debug(ctx, "running initialization") err = terraform.Init(ctx) if err != nil { return xerrors.Errorf("initialize terraform: %w", err) } + t.logger.Debug(ctx, "ran initialization") - env := map[string]string{} options := []tfexec.ApplyOption{tfexec.JSON(true)} for _, param := range request.ParameterValues { switch param.DestinationScheme { @@ -124,10 +131,12 @@ func (t *terraform) Provision(request *proto.Provision_Request, stream proto.DRP }() terraform.SetStdout(writer) + t.logger.Debug(ctx, "running apply") err = terraform.Apply(ctx, options...) if err != nil { return xerrors.Errorf("apply terraform: %w", err) } + t.logger.Debug(ctx, "ran apply") statefileContent, err := os.ReadFile(statefilePath) if err != nil { diff --git a/provisioner/terraform/serve.go b/provisioner/terraform/serve.go index 55323f393bf00..20b46bd3d625a 100644 --- a/provisioner/terraform/serve.go +++ b/provisioner/terraform/serve.go @@ -7,6 +7,8 @@ import ( "github.com/hashicorp/go-version" "golang.org/x/xerrors" + "cdr.dev/slog" + "github.com/coder/coder/provisionersdk" ) @@ -29,6 +31,7 @@ type ServeOptions struct { // BinaryPath specifies the "terraform" binary to use. // If omitted, the $PATH will attempt to find it. BinaryPath string + Logger slog.Logger } // Serve starts a dRPC server on the provided transport speaking Terraform provisioner. @@ -43,9 +46,11 @@ func Serve(ctx context.Context, options *ServeOptions) error { return provisionersdk.Serve(ctx, &terraform{ binaryPath: options.BinaryPath, + logger: options.Logger, }, options.ServeOptions) } type terraform struct { binaryPath string + logger slog.Logger } diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index 3fe7ce793c65e..94e7e3577800d 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -16,6 +16,8 @@ import ( "go.uber.org/atomic" + "github.com/hashicorp/yamux" + "cdr.dev/slog" "github.com/coder/coder/provisionerd/proto" sdkproto "github.com/coder/coder/provisionersdk/proto" @@ -32,9 +34,10 @@ type Provisioners map[string]sdkproto.DRPCProvisionerClient type Options struct { Logger slog.Logger - PollInterval time.Duration - Provisioners Provisioners - WorkDirectory string + UpdateInterval time.Duration + PollInterval time.Duration + Provisioners Provisioners + WorkDirectory string } // New creates and starts a provisioner daemon. @@ -42,6 +45,9 @@ func New(clientDialer Dialer, opts *Options) io.Closer { if opts.PollInterval == 0 { opts.PollInterval = 5 * time.Second } + if opts.UpdateInterval == 0 { + opts.UpdateInterval = 5 * time.Second + } ctx, ctxCancel := context.WithCancel(context.Background()) daemon := &provisionerDaemon{ clientDialer: clientDialer, @@ -84,10 +90,10 @@ type provisionerDaemon struct { acquiredJobCancel context.CancelFunc acquiredJobCancelled atomic.Bool acquiredJobRunning atomic.Bool - acquiredJobDone chan struct{} + acquiredJobGroup sync.WaitGroup } -// Connnect establishes a connection to coderd. +// Connect establishes a connection to coderd. func (p *provisionerDaemon) connect(ctx context.Context) { p.connectMutex.Lock() defer p.connectMutex.Unlock() @@ -98,7 +104,9 @@ func (p *provisionerDaemon) connect(ctx context.Context) { for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { p.client, err = p.clientDialer(ctx) if err != nil { - // Warn + if errors.Is(err, context.Canceled) { + return + } p.opts.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) continue } @@ -135,7 +143,7 @@ func (p *provisionerDaemon) connect(ctx context.Context) { defer ticker.Stop() for { select { - case <-p.closed: + case <-ctx.Done(): return case <-p.updateStream.Context().Done(): return @@ -160,6 +168,9 @@ func (p *provisionerDaemon) acquireJob(ctx context.Context) { if errors.Is(err, context.Canceled) { return } + if errors.Is(err, yamux.ErrSessionShutdown) { + return + } p.opts.Logger.Warn(context.Background(), "acquire job", slog.Error(err)) return } @@ -173,7 +184,7 @@ func (p *provisionerDaemon) acquireJob(ctx context.Context) { ctx, p.acquiredJobCancel = context.WithCancel(ctx) p.acquiredJobCancelled.Store(false) p.acquiredJobRunning.Store(true) - p.acquiredJobDone = make(chan struct{}) + p.acquiredJobGroup.Add(1) p.opts.Logger.Info(context.Background(), "acquired job", slog.F("organization_name", p.acquiredJob.OrganizationName), @@ -190,12 +201,30 @@ func (p *provisionerDaemon) isRunningJob() bool { } func (p *provisionerDaemon) runJob(ctx context.Context) { + // Prevents p.updateStream from being accessed and + // written to at the same time. + p.connectMutex.Lock() + defer p.connectMutex.Unlock() + go func() { + ticker := time.NewTicker(p.opts.UpdateInterval) + defer ticker.Stop() select { case <-p.closed: + return case <-ctx.Done(): + return + case <-ticker.C: + err := p.updateStream.Send(&proto.JobUpdate{ + JobId: p.acquiredJob.JobId, + }) + if err != nil { + p.cancelActiveJob(fmt.Sprintf("send periodic update: %s", err)) + return + } } - + }() + defer func() { // Cleanup the work directory after execution. err := os.RemoveAll(p.opts.WorkDirectory) if err != nil { @@ -206,7 +235,7 @@ func (p *provisionerDaemon) runJob(ctx context.Context) { p.acquiredJobMutex.Lock() defer p.acquiredJobMutex.Unlock() p.acquiredJobRunning.Store(false) - close(p.acquiredJobDone) + p.acquiredJobGroup.Done() }() // It's safe to cast this ProvisionerType. This data is coming directly from coderd. provisioner, hasProvisioner := p.opts.Provisioners[p.acquiredJob.Provisioner] @@ -215,7 +244,7 @@ func (p *provisionerDaemon) runJob(ctx context.Context) { return } - err := os.MkdirAll(p.opts.WorkDirectory, 0600) + err := os.MkdirAll(p.opts.WorkDirectory, 0700) if err != nil { p.cancelActiveJob(fmt.Sprintf("create work directory %q: %s", p.opts.WorkDirectory, err)) return @@ -253,7 +282,7 @@ func (p *provisionerDaemon) runJob(ctx context.Context) { case tar.TypeReg: file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, mode) if err != nil { - p.cancelActiveJob(fmt.Sprintf("create file %q: %s", path, err)) + p.cancelActiveJob(fmt.Sprintf("create file %q (mode %s): %s", path, mode, err)) return } // Max file size of 10MB. @@ -433,6 +462,9 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provision } func (p *provisionerDaemon) cancelActiveJob(errMsg string) { + if p.isClosed() { + return + } if !p.isRunningJob() { p.opts.Logger.Warn(context.Background(), "skipping job cancel; none running", slog.F("error_message", errMsg)) return @@ -488,7 +520,7 @@ func (p *provisionerDaemon) closeWithError(err error) error { if !p.acquiredJobCancelled.Load() { p.cancelActiveJob(errMsg) } - <-p.acquiredJobDone + p.acquiredJobGroup.Wait() } p.opts.Logger.Debug(context.Background(), "closing server with error", slog.Error(err)) @@ -496,6 +528,8 @@ func (p *provisionerDaemon) closeWithError(err error) error { close(p.closed) p.closeCancel() + p.connectMutex.Lock() + defer p.connectMutex.Unlock() if p.updateStream != nil { _ = p.client.DRPCConn().Close() _ = p.updateStream.Close() diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index 8148c5369d938..376bfd1eaadb1 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -153,6 +153,48 @@ func TestProvisionerd(t *testing.T) { require.NoError(t, closer.Close()) }) + t.Run("RunningPeriodicUpdate", func(t *testing.T) { + t.Parallel() + completeChan := make(chan struct{}) + closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return createProvisionerDaemonClient(t, provisionerDaemonTestServer{ + acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { + return &proto.AcquiredJob{ + JobId: "test", + Provisioner: "someprovisioner", + ProjectSourceArchive: createTar(t, map[string]string{ + "test.txt": "content", + }), + Type: &proto.AcquiredJob_ProjectImport_{ + ProjectImport: &proto.AcquiredJob_ProjectImport{}, + }, + }, nil + }, + updateJob: func(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error { + for { + _, err := stream.Recv() + if err != nil { + return err + } + close(completeChan) + } + }, + cancelJob: func(ctx context.Context, job *proto.CancelledJob) (*proto.Empty, error) { + return &proto.Empty{}, nil + }, + }), nil + }, provisionerd.Provisioners{ + "someprovisioner": createProvisionerClient(t, provisionerTestServer{ + parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error { + <-stream.Context().Done() + return nil + }, + }), + }) + <-completeChan + require.NoError(t, closer.Close()) + }) + t.Run("ProjectImport", func(t *testing.T) { t.Parallel() var ( @@ -331,10 +373,11 @@ func createTar(t *testing.T, files map[string]string) []byte { // Creates a provisionerd implementation with the provided dialer and provisioners. func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, provisioners provisionerd.Provisioners) io.Closer { closer := provisionerd.New(dialer, &provisionerd.Options{ - Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug), - PollInterval: 50 * time.Millisecond, - Provisioners: provisioners, - WorkDirectory: t.TempDir(), + Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug), + PollInterval: 50 * time.Millisecond, + UpdateInterval: 50 * time.Millisecond, + Provisioners: provisioners, + WorkDirectory: t.TempDir(), }) t.Cleanup(func() { _ = closer.Close() diff --git a/provisionersdk/transport.go b/provisionersdk/transport.go index 7fd87839d174b..3933aeb5efd7b 100644 --- a/provisionersdk/transport.go +++ b/provisionersdk/transport.go @@ -57,12 +57,16 @@ func (m *multiplexedDRPC) Closed() <-chan struct{} { return m.session.CloseChan() } -func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { +func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, inMessage, outMessage drpc.Message) error { conn, err := m.session.Open() if err != nil { return err } - return drpcconn.New(conn).Invoke(ctx, rpc, enc, in, out) + dConn := drpcconn.New(conn) + defer func() { + _ = dConn.Close() + }() + return dConn.Invoke(ctx, rpc, enc, inMessage, outMessage) } func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { @@ -70,5 +74,13 @@ func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.En if err != nil { return nil, err } - return drpcconn.New(conn).NewStream(ctx, rpc, enc) + dConn := drpcconn.New(conn) + stream, err := dConn.NewStream(ctx, rpc, enc) + if err == nil { + go func() { + <-stream.Context().Done() + _ = dConn.Close() + }() + } + return stream, err }