Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor job update for provisionerd
  • Loading branch information
kylecarbs committed Feb 10, 2022
commit e53f0beb26c83638ce269e3662fdc82d8b736ab2
254 changes: 115 additions & 139 deletions coderd/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,6 @@ type workspaceProvisionJob struct {
DryRun bool `json:"dry_run"`
}

// The input for a "project_import" job.
type projectVersionImportJob struct {
OrganizationID string `json:"organization_id"`
ProjectID uuid.UUID `json:"project_id"`
}

// Implementation of the provisioner daemon protobuf server.
type provisionerdServer struct {
ID uuid.UUID
Expand Down Expand Up @@ -242,39 +236,8 @@ func (server *provisionerdServer) AcquireJob(ctx context.Context, _ *proto.Empty
},
}
case database.ProvisionerJobTypeProjectVersionImport:
var input projectVersionImportJob
err = json.Unmarshal(job.Input, &input)
if err != nil {
return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err))
}

// Compute parameters for the workspace to consume.
parameters, err := parameter.Compute(ctx, server.Database, parameter.ComputeScope{
ProjectImportJobID: job.ID,
OrganizationID: input.OrganizationID,
ProjectID: uuid.NullUUID{
UUID: input.ProjectID,
Valid: input.ProjectID.String() != uuid.Nil.String(),
},
UserID: user.ID,
}, nil)
if err != nil {
return nil, failJob(fmt.Sprintf("compute parameters: %s", err))
}
// Convert parameters to the protobuf type.
protoParameters := make([]*sdkproto.ParameterValue, 0, len(parameters))
for _, parameter := range parameters {
converted, err := convertComputedParameterValue(parameter)
if err != nil {
return nil, failJob(fmt.Sprintf("convert parameter: %s", err))
}
protoParameters = append(protoParameters, converted)
}

protoJob.Type = &proto.AcquiredJob_ProjectImport_{
ProjectImport: &proto.AcquiredJob_ProjectImport{
ParameterValues: protoParameters,
},
ProjectImport: &proto.AcquiredJob_ProjectImport{},
}
}
switch job.StorageMethod {
Expand All @@ -291,119 +254,137 @@ func (server *provisionerdServer) AcquireJob(ctx context.Context, _ *proto.Empty
return protoJob, err
}

func (server *provisionerdServer) UpdateJob(stream proto.DRPCProvisionerDaemon_UpdateJobStream) error {
for {
update, err := stream.Recv()
if err != nil {
return err
func (server *provisionerdServer) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
parsedID, err := uuid.Parse(request.JobId)
if err != nil {
return nil, xerrors.Errorf("parse job id: %w", err)
}
job, err := server.Database.GetProvisionerJobByID(ctx, parsedID)
if err != nil {
return nil, xerrors.Errorf("get job: %w", err)
}
if !job.WorkerID.Valid {
return nil, xerrors.New("job isn't running yet")
}
if job.WorkerID.UUID.String() != server.ID.String() {
return nil, xerrors.New("you don't own this job")
}
err = server.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{
ID: parsedID,
UpdatedAt: database.Now(),
})
if err != nil {
return nil, xerrors.Errorf("update job: %w", err)
}

if len(request.Logs) > 0 {
insertParams := database.InsertProvisionerJobLogsParams{
JobID: parsedID,
}
parsedID, err := uuid.Parse(update.JobId)
if err != nil {
return xerrors.Errorf("parse job id: %w", err)
for _, log := range request.Logs {
logLevel, err := convertLogLevel(log.Level)
if err != nil {
return nil, xerrors.Errorf("convert log level: %w", err)
}
logSource, err := convertLogSource(log.Source)
if err != nil {
return nil, xerrors.Errorf("convert log source: %w", err)
}
insertParams.ID = append(insertParams.ID, uuid.New())
insertParams.CreatedAt = append(insertParams.CreatedAt, time.UnixMilli(log.CreatedAt))
insertParams.Level = append(insertParams.Level, logLevel)
insertParams.Source = append(insertParams.Source, logSource)
insertParams.Output = append(insertParams.Output, log.Output)
}
job, err := server.Database.GetProvisionerJobByID(stream.Context(), parsedID)
logs, err := server.Database.InsertProvisionerJobLogs(context.Background(), insertParams)
if err != nil {
return xerrors.Errorf("get job: %w", err)
return nil, xerrors.Errorf("insert job logs: %w", err)
}
if !job.WorkerID.Valid {
return xerrors.New("job isn't running yet")
}
if job.WorkerID.UUID.String() != server.ID.String() {
return xerrors.New("you don't own this job")
data, err := json.Marshal(logs)
if err != nil {
return nil, xerrors.Errorf("marshal job log: %w", err)
}

err = server.Database.UpdateProvisionerJobByID(stream.Context(), database.UpdateProvisionerJobByIDParams{
ID: parsedID,
UpdatedAt: database.Now(),
})
err = server.Pubsub.Publish(provisionerJobLogsChannel(parsedID), data)
if err != nil {
return xerrors.Errorf("update job: %w", err)
return nil, xerrors.Errorf("publish job log: %w", err)
}
if len(update.Logs) > 0 {
insertParams := database.InsertProvisionerJobLogsParams{
JobID: parsedID,
}
for _, log := range update.Logs {
logLevel, err := convertLogLevel(log.Level)
if err != nil {
return xerrors.Errorf("convert log level: %w", err)
}
logSource, err := convertLogSource(log.Source)
if err != nil {
return xerrors.Errorf("convert log source: %w", err)
}
insertParams.ID = append(insertParams.ID, uuid.New())
insertParams.CreatedAt = append(insertParams.CreatedAt, time.UnixMilli(log.CreatedAt))
insertParams.Level = append(insertParams.Level, logLevel)
insertParams.Source = append(insertParams.Source, logSource)
insertParams.Output = append(insertParams.Output, log.Output)
}
logs, err := server.Database.InsertProvisionerJobLogs(context.Background(), insertParams)
if err != nil {
return xerrors.Errorf("insert job logs: %w", err)
}
data, err := json.Marshal(logs)
}

if len(request.ParameterSchemas) > 0 {
for _, protoParameter := range request.ParameterSchemas {
validationTypeSystem, err := convertValidationTypeSystem(protoParameter.ValidationTypeSystem)
if err != nil {
return xerrors.Errorf("marshal job log: %w", err)
return nil, xerrors.Errorf("convert validation type system for %q: %w", protoParameter.Name, err)
}
err = server.Pubsub.Publish(provisionerJobLogsChannel(parsedID), data)
if err != nil {
return xerrors.Errorf("publish job log: %w", err)

parameterSchema := database.InsertParameterSchemaParams{
ID: uuid.New(),
CreatedAt: database.Now(),
JobID: job.ID,
Name: protoParameter.Name,
Description: protoParameter.Description,
RedisplayValue: protoParameter.RedisplayValue,
ValidationError: protoParameter.ValidationError,
ValidationCondition: protoParameter.ValidationCondition,
ValidationValueType: protoParameter.ValidationValueType,
ValidationTypeSystem: validationTypeSystem,

DefaultSourceScheme: database.ParameterSourceSchemeNone,
DefaultDestinationScheme: database.ParameterDestinationSchemeNone,

AllowOverrideDestination: protoParameter.AllowOverrideDestination,
AllowOverrideSource: protoParameter.AllowOverrideSource,
}
}

if update.GetProjectImport() != nil {
// Validate that all parameters send from the provisioner daemon
// follow the protocol.
parameterSchemas := make([]database.InsertParameterSchemaParams, 0, len(update.GetProjectImport().ParameterSchemas))
for _, protoParameter := range update.GetProjectImport().ParameterSchemas {
validationTypeSystem, err := convertValidationTypeSystem(protoParameter.ValidationTypeSystem)
// It's possible a parameter doesn't define a default source!
if protoParameter.DefaultSource != nil {
parameterSourceScheme, err := convertParameterSourceScheme(protoParameter.DefaultSource.Scheme)
if err != nil {
return xerrors.Errorf("convert validation type system for %q: %w", protoParameter.Name, err)
}

parameterSchema := database.InsertParameterSchemaParams{
ID: uuid.New(),
CreatedAt: database.Now(),
JobID: job.ID,
Name: protoParameter.Name,
Description: protoParameter.Description,
RedisplayValue: protoParameter.RedisplayValue,
ValidationError: protoParameter.ValidationError,
ValidationCondition: protoParameter.ValidationCondition,
ValidationValueType: protoParameter.ValidationValueType,
ValidationTypeSystem: validationTypeSystem,

DefaultSourceScheme: database.ParameterSourceSchemeNone,
DefaultDestinationScheme: database.ParameterDestinationSchemeNone,

AllowOverrideDestination: protoParameter.AllowOverrideDestination,
AllowOverrideSource: protoParameter.AllowOverrideSource,
return nil, xerrors.Errorf("convert parameter source scheme: %w", err)
}
parameterSchema.DefaultSourceScheme = parameterSourceScheme
parameterSchema.DefaultSourceValue = protoParameter.DefaultSource.Value
}

// It's possible a parameter doesn't define a default source!
if protoParameter.DefaultSource != nil {
parameterSourceScheme, err := convertParameterSourceScheme(protoParameter.DefaultSource.Scheme)
if err != nil {
return xerrors.Errorf("convert parameter source scheme: %w", err)
}
parameterSchema.DefaultSourceScheme = parameterSourceScheme
parameterSchema.DefaultSourceValue = protoParameter.DefaultSource.Value
// It's possible a parameter doesn't define a default destination!
if protoParameter.DefaultDestination != nil {
parameterDestinationScheme, err := convertParameterDestinationScheme(protoParameter.DefaultDestination.Scheme)
if err != nil {
return nil, xerrors.Errorf("convert parameter destination scheme: %w", err)
}
parameterSchema.DefaultDestinationScheme = parameterDestinationScheme
}

// It's possible a parameter doesn't define a default destination!
if protoParameter.DefaultDestination != nil {
parameterDestinationScheme, err := convertParameterDestinationScheme(protoParameter.DefaultDestination.Scheme)
if err != nil {
return xerrors.Errorf("convert parameter destination scheme: %w", err)
}
parameterSchema.DefaultDestinationScheme = parameterDestinationScheme
}
_, err = server.Database.InsertParameterSchema(ctx, parameterSchema)
if err != nil {
return nil, xerrors.Errorf("insert parameter schema: %w", err)
}
}

parameterSchemas = append(parameterSchemas, parameterSchema)
parameters, err := parameter.Compute(ctx, server.Database, parameter.ComputeScope{
ProjectImportJobID: job.ID,
OrganizationID: job.OrganizationID,
UserID: job.InitiatorID,
}, nil)
if err != nil {
return nil, xerrors.Errorf("compute parameters: %w", err)
}
// Convert parameters to the protobuf type.
protoParameters := make([]*sdkproto.ParameterValue, 0, len(parameters))
for _, parameter := range parameters {
converted, err := convertComputedParameterValue(parameter)
if err != nil {
return nil, xerrors.Errorf("convert parameter: %s", err)
}
protoParameters = append(protoParameters, converted)
}

return &proto.UpdateJobResponse{
ParameterValues: protoParameters,
}, nil
}

return &proto.UpdateJobResponse{}, nil
}

func (server *provisionerdServer) CancelJob(ctx context.Context, cancelJob *proto.CancelledJob) (*proto.Empty, error) {
Expand Down Expand Up @@ -450,17 +431,12 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *pr
if err != nil {
return nil, xerrors.Errorf("get job by id: %w", err)
}
// TODO: Check if the worker ID matches!
// If it doesn't, a provisioner daemon could be impersonating another job!
if job.WorkerID.UUID.String() != server.ID.String() {
return nil, xerrors.Errorf("you don't have permission to update this job")
}

switch jobType := completed.Type.(type) {
case *proto.CompletedJob_ProjectImport_:
var input projectVersionImportJob
err = json.Unmarshal(job.Input, &input)
if err != nil {
return nil, xerrors.Errorf("unmarshal job data: %w", err)
}

err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: jobID,
UpdatedAt: database.Now(),
Expand Down
12 changes: 0 additions & 12 deletions coderd/provisionerjobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package coderd

import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -80,16 +79,6 @@ func (api *api) postProvisionerImportJobByOrganization(rw http.ResponseWriter, r
return
}

input, err := json.Marshal(projectVersionImportJob{
OrganizationID: organization.ID,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("marshal job: %s", err),
})
return
}

jobID := uuid.New()
for _, parameterValue := range req.ParameterValues {
_, err = api.Database.InsertParameterValue(r.Context(), database.InsertParameterValueParams{
Expand Down Expand Up @@ -121,7 +110,6 @@ func (api *api) postProvisionerImportJobByOrganization(rw http.ResponseWriter, r
StorageMethod: database.ProvisionerStorageMethodFile,
StorageSource: file.Hash,
Type: database.ProvisionerJobTypeProjectVersionImport,
Input: input,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Expand Down
14 changes: 9 additions & 5 deletions coderd/provisionerjobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package coderd_test

import (
"context"
"fmt"
"net/http"
"testing"
"time"
Expand Down Expand Up @@ -89,7 +88,6 @@ func TestPostProvisionerImportJobByOrganization(t *testing.T) {
})
require.NoError(t, err)
job = coderdtest.AwaitProvisionerJob(t, client, user.Organization, job.ID)
fmt.Printf("Job %+v\n", job)
values, err := client.ProvisionerJobParameterValues(context.Background(), user.Organization, job.ID)
require.NoError(t, err)
require.Equal(t, "somevalue", values[0].SourceValue)
Expand Down Expand Up @@ -120,16 +118,23 @@ func TestProvisionerJobParametersByID(t *testing.T) {
Complete: &proto.Parse_Complete{
ParameterSchemas: []*proto.ParameterSchema{{
Name: "example",
DefaultSource: &proto.ParameterSource{
Scheme: proto.ParameterSource_DATA,
Value: "hello",
},
DefaultDestination: &proto.ParameterDestination{
Scheme: proto.ParameterDestination_PROVISIONER_VARIABLE,
},
}},
},
},
}},
Provision: echo.ProvisionComplete,
})
coderdtest.AwaitProvisionerJob(t, client, user.Organization, job.ID)
job = coderdtest.AwaitProvisionerJob(t, client, user.Organization, job.ID)
params, err := client.ProvisionerJobParameterValues(context.Background(), user.Organization, job.ID)
require.NoError(t, err)
require.Len(t, params, 0)
require.Len(t, params, 1)
})

t.Run("ListNoRedisplay", func(t *testing.T) {
Expand All @@ -149,7 +154,6 @@ func TestProvisionerJobParametersByID(t *testing.T) {
},
DefaultDestination: &proto.ParameterDestination{
Scheme: proto.ParameterDestination_PROVISIONER_VARIABLE,
Value: "example",
},
RedisplayValue: false,
}},
Expand Down
1 change: 0 additions & 1 deletion provisioner/terraform/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ func convertVariableToParameter(variable *tfconfig.Variable) (*proto.ParameterSc
}
schema.DefaultDestination = &proto.ParameterDestination{
Scheme: proto.ParameterDestination_PROVISIONER_VARIABLE,
Value: variable.Name,
}
}

Expand Down
Loading