def invert_bytes(byte_string: "bytes | None") -> "bytes | None":
    """Return the bitwise complement of every byte in ``byte_string``.

    This is the element-wise Python function handed to ``PyUdfOp`` when an
    ``InvertOp`` is lowered for a BYTES input.

    Args:
        byte_string: A bytes-like value exposing the buffer protocol, or
            ``None`` for a null element.

    Returns:
        A ``bytes`` object of the same length in which every bit is flipped,
        or ``None`` when ``byte_string`` is ``None``.
    """
    if byte_string is None:
        # Null in, null out: a missing value has no complement, and
        # np.frombuffer would otherwise raise TypeError on None.
        return None
    # View the buffer as unsigned bytes and flip all bits in one vectorized
    # operation instead of looping over the bytes in Python.
    inverted = ~np.frombuffer(byte_string, dtype=np.uint8)
    return inverted.tobytes()
@dataclasses.dataclass(frozen=True)
class PyUdfOp(base_ops.NaryOp):
    """Scalar op that applies a local Python callable.

    The result type is not derived from the inputs: it is whatever was
    declared in ``_output_type`` when the op was constructed.
    """

    # Op identifier shared by all instances (ClassVar, so not a dataclass field).
    name: typing.ClassVar[str] = "py_udf"
    # The Python callable to apply.
    fn: typing.Callable
    # Declared result type, handed back unchanged by ``output_type``.
    _output_type: dtypes.ExpressionType

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Return the declared output type; ``input_types`` are ignored."""
        declared_type = self._output_type
        return declared_type
generic_ops.CoalesceOp, + generic_ops.FillNaOp, + generic_ops.CaseWhenOp, + generic_ops.InvertOp, ) _COMPATIBLE_AGG_OPS = ( agg_ops.SizeOp, diff --git a/tests/system/small/engines/test_generic_ops.py b/tests/system/small/engines/test_generic_ops.py index af114991eb..9fdb6bca78 100644 --- a/tests/system/small/engines/test_generic_ops.py +++ b/tests/system/small/engines/test_generic_ops.py @@ -59,6 +59,7 @@ def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine) ops.AsTypeOp(to_type=bigframes.dtypes.INT_DTYPE), excluded_cols=["string_col"], ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -73,6 +74,7 @@ def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue, for val in vals ] ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -83,6 +85,7 @@ def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engin ops.AsTypeOp(to_type=bigframes.dtypes.FLOAT_DTYPE), excluded_cols=["string_col"], ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -99,6 +102,7 @@ def test_engines_astype_string_float( for val in vals ] ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -107,6 +111,7 @@ def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine arr = apply_op( scalars_array_value, ops.AsTypeOp(to_type=bigframes.dtypes.BOOL_DTYPE) ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -118,6 +123,7 @@ def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engi ops.AsTypeOp(to_type=bigframes.dtypes.STRING_DTYPE), excluded_cols=["float64_col"], ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -128,6 +134,7 @@ def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, eng ops.AsTypeOp(to_type=bigframes.dtypes.NUMERIC_DTYPE), excluded_cols=["string_col"], ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -144,6 +151,7 
@@ def test_engines_astype_string_numeric( for val in vals ] ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -154,6 +162,7 @@ def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine ops.AsTypeOp(to_type=bigframes.dtypes.DATE_DTYPE), excluded_cols=["string_col"], ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -170,6 +179,7 @@ def test_engines_astype_string_date( for val in vals ] ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -180,6 +190,7 @@ def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, en ops.AsTypeOp(to_type=bigframes.dtypes.DATETIME_DTYPE), excluded_cols=["string_col"], ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -196,6 +207,7 @@ def test_engines_astype_string_datetime( for val in vals ] ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -206,6 +218,7 @@ def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, e ops.AsTypeOp(to_type=bigframes.dtypes.TIMESTAMP_DTYPE), excluded_cols=["string_col"], ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -226,6 +239,7 @@ def test_engines_astype_string_timestamp( for val in vals ] ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -236,6 +250,7 @@ def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine ops.AsTypeOp(to_type=bigframes.dtypes.TIME_DTYPE), excluded_cols=["string_col", "int64_col", "int64_too"], ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -256,6 +271,7 @@ def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, e ), ] arr, _ = scalars_array_value.compute_values(exprs) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) @@ -265,4 +281,112 @@ def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, e scalars_array_value, 
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_where_op(scalars_array_value: array_value.ArrayValue, engine):
    """where_op on (int64, bool, float64) columns agrees with the reference engine."""
    where_expr = ops.where_op.as_expr(
        expression.deref("int64_col"),
        expression.deref("bool_col"),
        expression.deref("float64_col"),
    )
    computed, _ = scalars_array_value.compute_values([where_expr])

    assert_equivalence_execution(computed.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_coalesce_op(scalars_array_value: array_value.ArrayValue, engine):
    """coalesce_op on (int64, float64) columns agrees with the reference engine."""
    coalesce_expr = ops.coalesce_op.as_expr(
        expression.deref("int64_col"),
        expression.deref("float64_col"),
    )
    computed, _ = scalars_array_value.compute_values([coalesce_expr])

    assert_equivalence_execution(computed.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_fillna_op(scalars_array_value: array_value.ArrayValue, engine):
    """fillna_op on (int64, float64) columns agrees with the reference engine."""
    fillna_expr = ops.fillna_op.as_expr(
        expression.deref("int64_col"),
        expression.deref("float64_col"),
    )
    computed, _ = scalars_array_value.compute_values([fillna_expr])

    assert_equivalence_execution(computed.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_casewhen_op_single_case(
    scalars_array_value: array_value.ArrayValue, engine
):
    """case_when_op with one (condition, value) pair agrees with the reference engine."""
    case_expr = ops.case_when_op.as_expr(
        expression.deref("bool_col"),
        expression.deref("int64_col"),
    )
    computed, _ = scalars_array_value.compute_values([case_expr])

    assert_equivalence_execution(computed.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_casewhen_op_double_case(
    scalars_array_value: array_value.ArrayValue, engine
):
    """case_when_op with two (condition, value) pairs agrees with the reference engine."""
    # Operands are built in the same order they are passed to case_when_op.
    first_condition = ops.gt_op.as_expr(
        expression.deref("int64_col"), expression.const(3)
    )
    first_value = expression.deref("int64_col")
    second_condition = ops.lt_op.as_expr(
        expression.deref("int64_col"), expression.const(-3)
    )
    second_value = expression.deref("int64_too")
    case_expr = ops.case_when_op.as_expr(
        first_condition,
        first_value,
        second_condition,
        second_value,
    )
    computed, _ = scalars_array_value.compute_values([case_expr])

    assert_equivalence_execution(computed.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_isnull_op(scalars_array_value: array_value.ArrayValue, engine):
    """isnull_op on the string column agrees with the reference engine."""
    isnull_expr = ops.isnull_op.as_expr(expression.deref("string_col"))
    computed, _ = scalars_array_value.compute_values([isnull_expr])

    assert_equivalence_execution(computed.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_notnull_op(scalars_array_value: array_value.ArrayValue, engine):
    """notnull_op on the string column agrees with the reference engine."""
    notnull_expr = ops.notnull_op.as_expr(expression.deref("string_col"))
    computed, _ = scalars_array_value.compute_values([notnull_expr])

    assert_equivalence_execution(computed.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine):
    """invert_op on the bytes and bool columns agrees with the reference engine."""
    invert_exprs = [
        ops.invert_op.as_expr(expression.deref(column_id))
        for column_id in ("bytes_col", "bool_col")
    ]
    computed, _ = scalars_array_value.compute_values(invert_exprs)
    assert_equivalence_execution(computed.node, REFERENCE_ENGINE, engine)
def __invert__(self) -> BinaryValue:
    """Support ``~value`` by wrapping this value in a ``BitwiseNot`` operation."""
    bitwise_not = ops.BitwiseNot(self)
    return bitwise_not.to_expr()