diff --git a/checker/cost.go b/checker/cost.go
index 8ae8d18bf..ef58df766 100644
--- a/checker/cost.go
+++ b/checker/cost.go
@@ -533,14 +533,34 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
 
 	if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil {
 		callEst := *est
-		return CallEstimate{CostEstimate: callEst.Add(argCostSum())}
+		return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
 	}
 	switch overloadID {
 	// O(n) functions
-	case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString:
-		if overloadID == overloads.ExtFormatString {
+	case overloads.ExtFormatString:
+		if target != nil {
+			// ResultSize not calculated because we can't bound the max size.
 			return CallEstimate{CostEstimate: c.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
 		}
+	case overloads.StringToBytes:
+		if len(args) == 1 {
+			sz := c.sizeEstimate(args[0])
+			// ResultSize max is when each char converts to 4 bytes; saturating math keeps a huge (e.g. unbounded) Max from wrapping to a tiny bound.
+			return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min, Max: multiplyUint64NoOverflow(sz.Max, 4)}}
+		}
+	case overloads.BytesToString:
+		if len(args) == 1 {
+			sz := c.sizeEstimate(args[0])
+			// ResultSize min is when 4 bytes convert to 1 char; max is one char per byte (division cannot overflow).
+			return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min / 4, Max: sz.Max}}
+		}
+	case overloads.ExtQuoteString:
+		if len(args) == 1 {
+			sz := c.sizeEstimate(args[0])
+			// ResultSize max is when each char is escaped. 2 quote chars always added. Saturate: plain uint64 math gives MaxUint64*2+2 == 0.
+			return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: addUint64NoOverflow(sz.Min, 2), Max: addUint64NoOverflow(multiplyUint64NoOverflow(sz.Max, 2), 2)}}
+		}
+	case overloads.StartsWithString, overloads.EndsWithString:
 		if len(args) == 1 {
 			return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
 		}
diff --git a/checker/cost_test.go b/checker/cost_test.go
index 09c457325..9e751f411 100644
--- a/checker/cost_test.go
+++ b/checker/cost_test.go
@@ -261,6 +261,14 @@ func TestCost(t *testing.T) {
 			expr:   `string(input)`,
 			wanted: CostEstimate{Min: 1, Max: 51},
 		},
+		{
+			name:   "bytes to string conversion equality",
+			decls:  []*exprpb.Decl{decls.NewVar("input", decls.Bytes)},
+			hints:  map[string]int64{"input": 500},
+			// equality check ensures that the resultSize calculation is included in cost
+			expr:   `string(input) == string(input)`,
+			wanted: CostEstimate{Min: 3, Max: 152},
+		},
 		{
 			name:   "string to bytes conversion",
 			decls:  []*exprpb.Decl{decls.NewVar("input", decls.String)},
@@ -268,6 +276,14 @@ func TestCost(t *testing.T) {
 			expr:   `bytes(input)`,
 			wanted: CostEstimate{Min: 1, Max: 51},
 		},
+		{
+			name:   "string to bytes conversion equality",
+			decls:  []*exprpb.Decl{decls.NewVar("input", decls.String)},
+			hints:  map[string]int64{"input": 500},
+			// equality check ensures that the resultSize calculation is included in cost
+			expr:   `bytes(input) == bytes(input)`,
+			wanted: CostEstimate{Min: 3, Max: 302},
+		},
 		{
 			name:   "int to string conversion",
 			expr:   `string(1)`,