From e8a6a6112b33a84f2feb7c51abf742045e76b4b8 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 3 Apr 2023 20:47:03 -0500 Subject: [PATCH] chore: Add "required" to allow requring url params Going to be used for workspace apps --- coderd/httpapi/queryparams.go | 37 ++++++++++++++++++++++-- coderd/httpapi/queryparams_test.go | 46 ++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 2 deletions(-) diff --git a/coderd/httpapi/queryparams.go b/coderd/httpapi/queryparams.go index b1a66d74184fa..c2def95774470 100644 --- a/coderd/httpapi/queryparams.go +++ b/coderd/httpapi/queryparams.go @@ -24,12 +24,16 @@ type QueryParamParser struct { // Parsed is a map of all query params that were parsed. This is useful // for checking if extra query params were passed in. Parsed map[string]bool + // RequiredParams is a map of all query params that are required. This is useful + // for forcing a value to be provided. + RequiredParams map[string]bool } func NewQueryParamParser() *QueryParamParser { return &QueryParamParser{ - Errors: []codersdk.ValidationError{}, - Parsed: map[string]bool{}, + Errors: []codersdk.ValidationError{}, + Parsed: map[string]bool{}, + RequiredParams: map[string]bool{}, } } @@ -51,6 +55,20 @@ func (p *QueryParamParser) addParsed(key string) { p.Parsed[key] = true } +func (p *QueryParamParser) UInt(vals url.Values, def uint64, queryParam string) uint64 { + v, err := parseQueryParam(p, vals, func(v string) (uint64, error) { + return strconv.ParseUint(v, 10, 64) + }, def, queryParam) + if err != nil { + p.Errors = append(p.Errors, codersdk.ValidationError{ + Field: queryParam, + Detail: fmt.Sprintf("Query param %q must be a valid positive integer (%s)", queryParam, err.Error()), + }) + return 0 + } + return v +} + func (p *QueryParamParser) Int(vals url.Values, def int, queryParam string) int { v, err := parseQueryParam(p, vals, strconv.Atoi, def, queryParam) if err != nil { @@ -62,6 +80,11 @@ func (p *QueryParamParser) Int(vals url.Values, def int, queryParam string) int return v } +func (p *QueryParamParser) Required(queryParam string) *QueryParamParser { + p.RequiredParams[queryParam] = true + return p +} + func (p *QueryParamParser) UUIDorMe(vals url.Values, def uuid.UUID, me uuid.UUID, queryParam string) uuid.UUID { return ParseCustom(p, vals, def, queryParam, func(v string) (uuid.UUID, error) { if v == "me" { @@ -178,6 +201,16 @@ func ParseCustomList[T any](parser *QueryParamParser, vals url.Values, def []T, func parseQueryParam[T any](parser *QueryParamParser, vals url.Values, parse func(v string) (T, error), def T, queryParam string) (T, error) { parser.addParsed(queryParam) + // If the query param is required and not present, return an error. + if parser.RequiredParams[queryParam] && (!vals.Has(queryParam)) { + parser.Errors = append(parser.Errors, codersdk.ValidationError{ + Field: queryParam, + Detail: fmt.Sprintf("Query param %q is required", queryParam), + }) + return def, nil + } + + // If the query param is not present, return the default value. if !vals.Has(queryParam) || vals.Get(queryParam) == "" { return def, nil } diff --git a/coderd/httpapi/queryparams_test.go b/coderd/httpapi/queryparams_test.go index 6232bef22862d..4a7649e704c0b 100644 --- a/coderd/httpapi/queryparams_test.go +++ b/coderd/httpapi/queryparams_test.go @@ -195,6 +195,43 @@ func TestParseQueryParams(t *testing.T) { testQueryParams(t, expParams, parser, parser.Int) }) + t.Run("UInt", func(t *testing.T) { + t.Parallel() + expParams := []queryParamTestCase[uint64]{ + { + QueryParam: "valid_integer", + Value: "100", + Expected: 100, + }, + { + QueryParam: "empty", + Value: "", + Expected: 0, + }, + { + QueryParam: "no_value", + NoSet: true, + Default: 5, + Expected: 5, + }, + { + QueryParam: "negative", + Value: "-10", + Default: 5, + ExpectedErrorContains: "must be a valid positive integer", + }, + { + QueryParam: "invalid_integer", + Value: "bogus", + Expected: 0, + ExpectedErrorContains: "must be a valid positive integer", + }, + } + + parser := httpapi.NewQueryParamParser() + testQueryParams(t, expParams, parser, parser.UInt) + }) + t.Run("UUIDs", func(t *testing.T) { t.Parallel() expParams := []queryParamTestCase[[]uuid.UUID]{ @@ -237,6 +274,15 @@ func TestParseQueryParams(t *testing.T) { parser := httpapi.NewQueryParamParser() testQueryParams(t, expParams, parser, parser.UUIDs) }) + + t.Run("Required", func(t *testing.T) { + t.Parallel() + + parser := httpapi.NewQueryParamParser() + parser.Required("test_value") + parser.UUID(url.Values{}, uuid.New(), "test_value") + require.Len(t, parser.Errors, 1) + }) } func testQueryParams[T any](t *testing.T, testCases []queryParamTestCase[T], parser *httpapi.QueryParamParser, parse func(vals url.Values, def T, queryParam string) T) {