From 9ee247994eb75df094b51a50049936db8e832e81 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Sat, 29 Apr 2023 10:55:18 -0700 Subject: [PATCH 1/2] Additional test for message field accesses --- ext/strings_test.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/ext/strings_test.go b/ext/strings_test.go index 0f8fe55a9..cec3ce566 100644 --- a/ext/strings_test.go +++ b/ext/strings_test.go @@ -941,6 +941,19 @@ func TestStringFormat(t *testing.T) { expectedRuntimeCost: 13, expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13}, }, + { + name: "message field support", + format: "message field msg.single_int32: %d, msg.single_double: %.1f", + formatArgs: `msg.single_int32, msg.single_double`, + dynArgs: map[string]any{ + "msg": &proto3pb.TestAllTypes{ + SingleInt32: 2, + SingleDouble: 1.0, + }, + }, + locale: "en_US", + expectedOutput: `message field msg.single_int32: 2, msg.single_double: 1.0`, + }, { name: "unrecognized formatting clause", format: "%a", @@ -1266,7 +1279,7 @@ func TestStringFormat(t *testing.T) { checkCase(out, tt.expectedOutput, err, tt.err, t) if tt.locale == "" { // if the test has no locale specified, then that means it - // should have the same output regardless of lcoale + // should have the same output regardless of locale t.Run("no change on locale", func(t *testing.T) { out, err := runCase(tt.format, tt.formatArgs, "da_DK", tt.dynArgs, tt.skipCompileCheck, tt.expectedRuntimeCost, tt.expectedEstimatedCost, t) checkCase(out, tt.expectedOutput, err, tt.err, t) From 54eedf5f8aea00e6a7f207cda4081e1355fa056b Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Mon, 1 May 2023 10:17:31 -0700 Subject: [PATCH 2/2] Improvements in attribute identification yield more accurate state tracking and correlation to type ids from checked expressions --- ext/strings_test.go | 30 +++++++++++++++++-- interpreter/attributes.go | 19 ++++++++++-- interpreter/attributes_test.go | 53 ++++++++++++++++++++++++++++------ 3 
files changed, 88 insertions(+), 14 deletions(-) diff --git a/ext/strings_test.go b/ext/strings_test.go index cec3ce566..e6d5e441a 100644 --- a/ext/strings_test.go +++ b/ext/strings_test.go @@ -23,10 +23,13 @@ import ( "time" "unicode/utf8" + "google.golang.org/protobuf/proto" + "github.com/google/cel-go/cel" "github.com/google/cel-go/checker" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + proto3pb "github.com/google/cel-go/test/proto3pb" ) @@ -1222,8 +1225,31 @@ func TestStringFormat(t *testing.T) { buildVariables := func(vars map[string]any) []cel.EnvOption { opts := make([]cel.EnvOption, len(vars)) i := 0 - for name := range vars { - opts[i] = cel.Variable(name, cel.DynType) + for name, value := range vars { + varType := cel.DynType + switch v := value.(type) { + case proto.Message: + varType = cel.ObjectType(string(v.ProtoReflect().Descriptor().FullName())) + case types.Bool: + varType = cel.BoolType + case types.Bytes: + varType = cel.BytesType + case types.Double: + varType = cel.DoubleType + case types.Duration: + varType = cel.DurationType + case types.Int: + varType = cel.IntType + case types.Null: + varType = cel.NullType + case types.String: + varType = cel.StringType + case types.Timestamp: + varType = cel.TimestampType + case types.Uint: + varType = cel.UintType + } + opts[i] = cel.Variable(name, varType) i++ } return opts diff --git a/interpreter/attributes.go b/interpreter/attributes.go index 5c8107ab7..4004e32e9 100644 --- a/interpreter/attributes.go +++ b/interpreter/attributes.go @@ -231,7 +231,11 @@ type absoluteAttribute struct { // ID implements the Attribute interface method. func (a *absoluteAttribute) ID() int64 { - return a.id + qualCount := len(a.qualifiers) + if qualCount == 0 { + return a.id + } + return a.qualifiers[qualCount-1].ID() } // IsOptional returns trivially false for an attribute as the attribute represents a fully @@ -315,6 +319,11 @@ type conditionalAttribute struct { // ID is an implementation of the Attribute interface method.
func (a *conditionalAttribute) ID() int64 { + // When both branches resolve to the same id (e.g. a field access follows the conditional), use it. + if a.truthy.ID() == a.falsy.ID() { + return a.truthy.ID() + } + // Otherwise return the conditional id as the consistent id being tracked. return a.id } @@ -379,7 +388,7 @@ type maybeAttribute struct { // ID is an implementation of the Attribute interface method. func (a *maybeAttribute) ID() int64 { - return a.id + return a.attrs[0].ID() } // IsOptional returns trivially false for an attribute as the attribute represents a fully @@ -494,7 +503,11 @@ type relativeAttribute struct { // ID is an implementation of the Attribute interface method. func (a *relativeAttribute) ID() int64 { - return a.id + qualCount := len(a.qualifiers) + if qualCount == 0 { + return a.id + } + return a.qualifiers[qualCount-1].ID() } // IsOptional returns trivially false for an attribute as the attribute represents a fully diff --git a/interpreter/attributes_test.go b/interpreter/attributes_test.go index 023aa7841..6d2abeb60 100644 --- a/interpreter/attributes_test.go +++ b/interpreter/attributes_test.go @@ -869,8 +869,8 @@ func TestAttributeStateTracking(t *testing.T) { in: map[string]any{}, out: types.True, state: map[int64]any{ - // overall expression - 1: true, + // [{"field": true}] + 1: []ref.Val{types.DefaultTypeAdapter.NativeToValue(map[ref.Val]ref.Val{types.String("field"): types.True})}, // [{"field": true}][0] 6: map[ref.Val]ref.Val{types.String("field"): types.True}, // [{"field": true}][0].field @@ -893,8 +893,6 @@ func TestAttributeStateTracking(t *testing.T) { }, out: types.True, state: map[int64]any{ - // overall expression - 1: true, // a[1] 2: map[string]bool{"two": true}, // a[1]["two"] @@ -918,8 +916,6 @@ func TestAttributeStateTracking(t *testing.T) { }, out: types.String("dex"), state: map[int64]any{ - // overall expression - 1: "dex", // a[1] 2: map[int64]any{ 1: 0, @@ -948,8 +944,6 @@ func TestAttributeStateTracking(t *testing.T) { }, out: types.String("index"),
state: map[int64]any{ - // overall expression - 1: "index", // a[1] 2: map[int64]any{ 1: 0, @@ -969,6 +963,46 @@ func TestAttributeStateTracking(t *testing.T) { 10: int64(0), }, }, + { + expr: `true ? a : b`, + env: []*exprpb.Decl{ + decls.NewVar("a", decls.String), + decls.NewVar("b", decls.String), + }, + in: map[string]any{ + "a": "hello", + "b": "world", + }, + out: types.String("hello"), + state: map[int64]any{ + // 'hello' + 2: types.String("hello"), + }, + }, + { + expr: `(a.size() != 0 ? a : b)[0]`, + env: []*exprpb.Decl{ + decls.NewVar("a", decls.NewListType(decls.String)), + decls.NewVar("b", decls.NewListType(decls.String)), + }, + in: map[string]any{ + "a": []string{"hello", "world"}, + "b": []string{"world", "hello"}, + }, + out: types.String("hello"), + state: map[int64]any{ + // ["hello", "world"] + 1: types.DefaultTypeAdapter.NativeToValue([]string{"hello", "world"}), + // ["hello", "world"].size() // 2 + 2: types.Int(2), + // ["hello", "world"].size() != 0 + 3: types.True, + // constant 0 + 4: types.IntZero, + // 'hello' + 8: types.String("hello"), + }, + }, } for _, test := range tests { tc := test @@ -1014,7 +1048,8 @@ func TestAttributeStateTracking(t *testing.T) { t.Errorf("state not found for %d=%v", id, val) continue } - if !reflect.DeepEqual(stVal.Value(), val) { + wantStVal := types.DefaultTypeAdapter.NativeToValue(val) + if wantStVal.Equal(stVal) != types.True { t.Errorf("got %v, wanted %v for id: %d", stVal.Value(), val, id) } }