Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ import (
"github.com/google/cel-go/common/types/traits"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
structpb "google.golang.org/protobuf/types/known/structpb"
)

var (
	// nativeObjTraitMask combines the FieldTester and Indexer trait bits for
	// native object values; the code that consumes it is outside this view.
	nativeObjTraitMask = traits.IndexerType | traits.FieldTesterType

	// jsonValueType is the reflected *structpb.Value type. nativeObj.ConvertToNative
	// treats it as a JSON conversion target and wraps the object as a struct value.
	jsonValueType = reflect.TypeOf((*structpb.Value)(nil))

	// jsonStructType is the reflected *structpb.Struct type. nativeObj.ConvertToNative
	// builds one by walking the fields of the underlying Go struct.
	jsonStructType = reflect.TypeOf((*structpb.Struct)(nil))
)

// NativeTypes creates a type provider which uses reflect.Type and reflect.Value instances
Expand Down Expand Up @@ -207,6 +210,9 @@ func (tp *nativeTypeProvider) NativeToValue(val any) ref.Val {
if val == nil {
return types.NullValue
}
if v, ok := val.(ref.Val); ok {
return v
}
rawVal := reflect.ValueOf(val)
refVal := rawVal
if refVal.Kind() == reflect.Ptr {
Expand All @@ -216,7 +222,7 @@ func (tp *nativeTypeProvider) NativeToValue(val any) ref.Val {
// but maybe an acceptable limitation.
switch refVal.Kind() {
case reflect.Array, reflect.Slice:
switch val.(type) {
switch val := val.(type) {
case []byte:
return tp.baseAdapter.NativeToValue(val)
default:
Expand All @@ -226,8 +232,6 @@ func (tp *nativeTypeProvider) NativeToValue(val any) ref.Val {
return types.NewDynamicMap(tp, val)
case reflect.Struct:
switch val := val.(type) {
case ref.Val:
return val
case proto.Message, *pb.Map, protoreflect.List, protoreflect.Message, protoreflect.Value,
time.Time:
return tp.baseAdapter.NativeToValue(val)
Expand Down Expand Up @@ -329,6 +333,32 @@ func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
ptr.Elem().Set(o.refValue)
return ptr.Interface(), nil
}
switch typeDesc {
case jsonValueType:
jsonStruct, err := o.ConvertToNative(jsonStructType)
if err != nil {
return nil, err
}
return structpb.NewStructValue(jsonStruct.(*structpb.Struct)), nil
case jsonStructType:
refVal := reflect.Indirect(o.refValue)
refType := refVal.Type()
fields := make(map[string]*structpb.Value, refVal.NumField())
for i := 0; i < refVal.NumField(); i++ {
fieldType := refType.Field(i)
fieldValue := refVal.Field(i)
if !fieldValue.IsValid() || fieldValue.IsZero() {
continue
}
fieldCelVal := o.NativeToValue(fieldValue.Interface())
fieldJsonVal, err := fieldCelVal.ConvertToNative(jsonValueType)
if err != nil {
return nil, err
}
fields[fieldType.Name] = fieldJsonVal.(*structpb.Value)
}
return &structpb.Struct{Fields: fields}, nil
}
return nil, fmt.Errorf("type conversion error from '%v' to '%v'", o.Type(), typeDesc)
}

Expand Down
85 changes: 82 additions & 3 deletions ext/native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import (
"testing"
"time"

"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
Expand Down Expand Up @@ -84,13 +86,13 @@ func TestNativeTypes(t *testing.T) {
},
},
{
expr: `ext.TestAllTypes{
expr: `ext.TestAllTypes{
PbVal: test.TestAllTypes{single_int32: 123}
}.PbVal`,
out: &proto3pb.TestAllTypes{SingleInt32: 123},
},
{
expr: `ext.TestAllTypes{PbVal: test.TestAllTypes{}} ==
expr: `ext.TestAllTypes{PbVal: test.TestAllTypes{}} ==
ext.TestAllTypes{PbVal: test.TestAllTypes{single_bool: false}}`,
},
{expr: `ext.TestNestedType{} == TestNestedType{}`},
Expand Down Expand Up @@ -211,6 +213,83 @@ func TestNativeTypesStaticErrors(t *testing.T) {
}
}

// TestNativeTypesJsonSerialization evaluates CEL expressions that yield native
// Go values (and lists of them) in three steps:
//   - convert each result to a *structpb.Value;
//   - render that value with protojson;
//   - compare the rendering to the expected JSON text using test.Compare.
func TestNativeTypesJsonSerialization(t *testing.T) {
	cases := []struct {
		expr string
		out  string
	}{
		{
			expr: `[b'string']`,
			out:  `["c3RyaW5n"]`,
		},
		{
			expr: `TestAllTypes{
BoolVal: true,
DurationVal: duration('5s'),
DoubleVal: 1.5,
FloatVal: 2.0,
Int32Val: 23,
Int64Val: 64,
MapVal: {
'map-key': ext.TestAllTypes{
BoolVal: true
}
},
NestedVal: TestNestedType{
NestedListVal: ["first", "second"],
},
StringVal: "string"
}`,
			out: `{
"BoolVal": true,
"DoubleVal": 1.5,
"DurationVal": "5s",
"FloatVal": 2,
"Int32Val": 23,
"Int64Val": 64,
"MapVal": {
"map-key": {
"BoolVal": true
}
},
"NestedVal": {
"NestedListVal": [
"first",
"second"
]
},
"StringVal": "string"
}`,
		},
	}
	env := testNativeEnv(t)
	// Every evaluation result is converted to this protobuf JSON value type.
	valueType := reflect.TypeOf(&structpb.Value{})
	for n := range cases {
		c := cases[n] // per-iteration copy captured by the subtest closure
		t.Run(fmt.Sprintf("%d", n), func(t *testing.T) {
			checked, iss := env.Compile(c.expr)
			if compileErr := iss.Err(); compileErr != nil {
				t.Fatalf("env.Compile(%v) failed: %v", c.expr, compileErr)
			}
			program, err := env.Program(checked)
			if err != nil {
				t.Fatalf("env.Program() failed: %v", err)
			}
			val, _, err := program.Eval(cel.NoVars())
			if err != nil {
				t.Fatalf("prg.Eval() failed: %v", err)
			}
			native, err := val.ConvertToNative(valueType)
			if err != nil {
				t.Fatalf("out.ConvertToNative(Value) failed: %v", err)
			}
			got := protojson.Format(native.(proto.Message))
			if !test.Compare(got, c.out) {
				t.Errorf("expr %v converted to %v, wanted %v", c.expr, got, c.out)
			}
		})
	}
}

func TestNativeTypesRuntimeErrors(t *testing.T) {
var nativeTests = []struct {
expr string
Expand Down Expand Up @@ -578,7 +657,7 @@ type TestAllTypes struct {
DoubleVal float64
FloatVal float32
Int32Val int32
Int64Val int32
Int64Val int64
StringVal string
TimestampVal time.Time
Uint32Val uint32
Expand Down