2 changes: 1 addition & 1 deletion bigframes/core/compile/polars/compiler.py
@@ -168,7 +168,7 @@ def compile_op(self, op: ops.ScalarOp, *args: pl.Expr) -> pl.Expr:

@compile_op.register(gen_ops.InvertOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
return ~input
return input.not_()

@compile_op.register(num_ops.AbsOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
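Note on the handler above: `Expr.not_()` is the Polars method that negates a boolean expression and, for booleans, matches the `~` operator it replaces. A minimal, self-contained sketch of that behavior (the column name and values are illustrative, not from this PR):

```python
import polars as pl

df = pl.DataFrame({"flag": [True, False, None]})

# Expr.not_() negates a boolean expression; nulls stay null.
print(df.select(pl.col("flag").not_()))  # false, true, null
print(df.select(~pl.col("flag")))        # same result via __invert__
```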
31 changes: 30 additions & 1 deletion bigframes/core/compile/polars/lowering.py
@@ -14,10 +14,18 @@

import dataclasses

import numpy as np

from bigframes import dtypes
from bigframes.core import bigframe_node, expression
from bigframes.core.rewrite import op_lowering
from bigframes.operations import comparison_ops, datetime_ops, json_ops, numeric_ops
from bigframes.operations import (
comparison_ops,
datetime_ops,
generic_ops,
json_ops,
numeric_ops,
)
import bigframes.operations as ops

# TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops)
@@ -288,6 +296,26 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
return _lower_cast(expr.op, expr.inputs[0])


def invert_bytes(byte_string):
inverted_bytes = ~np.frombuffer(byte_string, dtype=np.uint8)
return inverted_bytes.tobytes()


class LowerInvertOp(op_lowering.OpLoweringRule):
@property
def op(self) -> type[ops.ScalarOp]:
return generic_ops.InvertOp

def lower(self, expr: expression.OpExpression) -> expression.Expression:
assert isinstance(expr.op, generic_ops.InvertOp)
arg = expr.children[0]
if arg.output_type == dtypes.BYTES_DTYPE:
return generic_ops.PyUdfOp(invert_bytes, dtypes.BYTES_DTYPE).as_expr(
expr.inputs[0]
)
return expr


def _coerce_comparables(
expr1: expression.Expression,
expr2: expression.Expression,
@@ -385,6 +413,7 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
LowerFloorDivRule(),
LowerModRule(),
LowerAsTypeRule(),
LowerInvertOp(),
)


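The new `LowerInvertOp` rule rewrites `InvertOp` over BYTES into a local Python UDF (`PyUdfOp`), keeping the native Polars expression path for other dtypes. A self-contained sketch of what `invert_bytes` computes, on illustrative inputs:

```python
import numpy as np


def invert_bytes(byte_string: bytes) -> bytes:
    # View the payload as uint8, flip every bit, and serialize back to bytes.
    return (~np.frombuffer(byte_string, dtype=np.uint8)).tobytes()


assert invert_bytes(b"\x00\xff") == b"\xff\x00"
assert invert_bytes(b"\x0f") == b"\xf0"  # 0b00001111 -> 0b11110000
assert invert_bytes(b"") == b""          # empty input round-trips
```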
11 changes: 11 additions & 0 deletions bigframes/core/compile/polars/operations/generic_ops.py
@@ -45,3 +45,14 @@ def isnull_op_impl(
input: pl.Expr,
) -> pl.Expr:
return input.is_null()


@polars_compiler.register_op(generic_ops.PyUdfOp)
def py_udf_op_impl(
compiler: polars_compiler.PolarsExpressionCompiler,
op: generic_ops.PyUdfOp, # type: ignore
input: pl.Expr,
) -> pl.Expr:
return input.map_elements(
op.fn, return_dtype=polars_compiler._DTYPE_MAPPING[op._output_type]
)
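
`PyUdfOp` compiles to `Expr.map_elements`, which runs a Python callable on each element and needs an explicit Polars return dtype (looked up from `_DTYPE_MAPPING` above). A standalone sketch of that Polars call with illustrative data; the lambda stands in for whatever callable `op.fn` carries:

```python
import polars as pl

df = pl.DataFrame({"payload": [b"\x00\xff", b"\x0f"]})

# map_elements applies a Python function per element; return_dtype
# declares the output type up front (Binary in the bytes-invert case).
out = df.select(
    pl.col("payload").map_elements(
        lambda b: bytes((~x) & 0xFF for x in b), return_dtype=pl.Binary
    )
)
print(out)  # b"\xff\x00", b"\xf0"
```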
12 changes: 12 additions & 0 deletions bigframes/operations/generic_ops.py
@@ -446,3 +446,15 @@ class SqlScalarOp(base_ops.NaryOp):

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
return self._output_type


@dataclasses.dataclass(frozen=True)
class PyUdfOp(base_ops.NaryOp):
"""Represents a local UDF."""

name: typing.ClassVar[str] = "py_udf"
fn: typing.Callable
_output_type: dtypes.ExpressionType

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
return self._output_type
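
Because `PyUdfOp` is a frozen dataclass, the callable and its fixed output dtype are captured at construction and `output_type` simply ignores the input types. A hedged sketch of building one, mirroring how `LowerInvertOp` uses it above (the callable here is illustrative, not part of the PR):

```python
from bigframes import dtypes
from bigframes.operations import generic_ops


def upper_bytes(b: bytes) -> bytes:
    return b.upper()


op = generic_ops.PyUdfOp(upper_bytes, dtypes.BYTES_DTYPE)
assert op.output_type(dtypes.BYTES_DTYPE) == dtypes.BYTES_DTYPE
# As in LowerInvertOp.lower, the op becomes an expression via op.as_expr(<input expr>).
```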
5 changes: 5 additions & 0 deletions bigframes/session/polars_executor.py
@@ -58,6 +58,11 @@
numeric_ops.FloorDivOp,
numeric_ops.ModOp,
generic_ops.AsTypeOp,
generic_ops.WhereOp,
generic_ops.CoalesceOp,
generic_ops.FillNaOp,
generic_ops.CaseWhenOp,
generic_ops.InvertOp,
)
_COMPATIBLE_AGG_OPS = (
agg_ops.SizeOp,
124 changes: 124 additions & 0 deletions tests/system/small/engines/test_generic_ops.py
@@ -59,6 +59,7 @@ def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine)
ops.AsTypeOp(to_type=bigframes.dtypes.INT_DTYPE),
excluded_cols=["string_col"],
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -73,6 +74,7 @@ def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue,
for val in vals
]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -83,6 +85,7 @@ def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engin
ops.AsTypeOp(to_type=bigframes.dtypes.FLOAT_DTYPE),
excluded_cols=["string_col"],
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -99,6 +102,7 @@ def test_engines_astype_string_float(
for val in vals
]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -107,6 +111,7 @@ def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine
arr = apply_op(
scalars_array_value, ops.AsTypeOp(to_type=bigframes.dtypes.BOOL_DTYPE)
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -118,6 +123,7 @@ def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engi
ops.AsTypeOp(to_type=bigframes.dtypes.STRING_DTYPE),
excluded_cols=["float64_col"],
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -128,6 +134,7 @@ def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, eng
ops.AsTypeOp(to_type=bigframes.dtypes.NUMERIC_DTYPE),
excluded_cols=["string_col"],
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -144,6 +151,7 @@ def test_engines_astype_string_numeric(
for val in vals
]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -154,6 +162,7 @@ def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine
ops.AsTypeOp(to_type=bigframes.dtypes.DATE_DTYPE),
excluded_cols=["string_col"],
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -170,6 +179,7 @@ def test_engines_astype_string_date(
for val in vals
]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -180,6 +190,7 @@ def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, en
ops.AsTypeOp(to_type=bigframes.dtypes.DATETIME_DTYPE),
excluded_cols=["string_col"],
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -196,6 +207,7 @@ def test_engines_astype_string_datetime(
for val in vals
]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -206,6 +218,7 @@ def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, e
ops.AsTypeOp(to_type=bigframes.dtypes.TIMESTAMP_DTYPE),
excluded_cols=["string_col"],
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -226,6 +239,7 @@ def test_engines_astype_string_timestamp(
for val in vals
]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -236,6 +250,7 @@ def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine
ops.AsTypeOp(to_type=bigframes.dtypes.TIME_DTYPE),
excluded_cols=["string_col", "int64_col", "int64_too"],
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -256,6 +271,7 @@ def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, e
),
]
arr, _ = scalars_array_value.compute_values(exprs)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@@ -265,4 +281,112 @@ def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, e
scalars_array_value,
ops.AsTypeOp(to_type=bigframes.dtypes.TIMEDELTA_DTYPE),
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_where_op(scalars_array_value: array_value.ArrayValue, engine):
arr, _ = scalars_array_value.compute_values(
[
ops.where_op.as_expr(
expression.deref("int64_col"),
expression.deref("bool_col"),
expression.deref("float64_col"),
)
]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
Contributor: style nit: could you add an empty line before each "assert" statement? Thanks!

Contributor Author: fixed for whole file


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_coalesce_op(scalars_array_value: array_value.ArrayValue, engine):
arr, _ = scalars_array_value.compute_values(
[
ops.coalesce_op.as_expr(
expression.deref("int64_col"),
expression.deref("float64_col"),
)
]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_fillna_op(scalars_array_value: array_value.ArrayValue, engine):
arr, _ = scalars_array_value.compute_values(
[
ops.fillna_op.as_expr(
expression.deref("int64_col"),
expression.deref("float64_col"),
)
]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_casewhen_op_single_case(
scalars_array_value: array_value.ArrayValue, engine
):
arr, _ = scalars_array_value.compute_values(
[
ops.case_when_op.as_expr(
expression.deref("bool_col"),
expression.deref("int64_col"),
)
]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_casewhen_op_double_case(
scalars_array_value: array_value.ArrayValue, engine
):
arr, _ = scalars_array_value.compute_values(
[
ops.case_when_op.as_expr(
ops.gt_op.as_expr(expression.deref("int64_col"), expression.const(3)),
expression.deref("int64_col"),
ops.lt_op.as_expr(expression.deref("int64_col"), expression.const(-3)),
expression.deref("int64_too"),
)
]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_isnull_op(scalars_array_value: array_value.ArrayValue, engine):
arr, _ = scalars_array_value.compute_values(
[ops.isnull_op.as_expr(expression.deref("string_col"))]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_notnull_op(scalars_array_value: array_value.ArrayValue, engine):
arr, _ = scalars_array_value.compute_values(
[ops.notnull_op.as_expr(expression.deref("string_col"))]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine):
arr, _ = scalars_array_value.compute_values(
[
ops.invert_op.as_expr(expression.deref("bytes_col")),
ops.invert_op.as_expr(expression.deref("bool_col")),
]
)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
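
This last test covers both halves of the new lowering, logical NOT for the BOOL column and a per-byte complement for the BYTES column, each checked against the reference engine for both the Polars and BigQuery backends. The expected per-value semantics, sketched with plain Python/NumPy on illustrative values:

```python
import numpy as np

assert (not True) is False  # BOOL: logical NOT
assert (~np.frombuffer(b"\x0f\xf0", dtype=np.uint8)).tobytes() == b"\xf0\x0f"  # BYTES: flip bits
```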
@@ -326,7 +326,7 @@ class Tan(TrigonometricUnary):
class BitwiseNot(Unary):
"""Bitwise NOT operation."""

arg: Integer
arg: Value[dt.Integer | dt.Binary]

dtype = rlz.numeric_like("args", operator.invert)

3 changes: 3 additions & 0 deletions third_party/bigframes_vendored/ibis/expr/types/binary.py
@@ -32,6 +32,9 @@ def hashbytes(
"""
return ops.HashBytes(self, how).to_expr()

def __invert__(self) -> BinaryValue:
return ops.BitwiseNot(self).to_expr()


@public
class BinaryScalar(Scalar, BinaryValue):
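The vendored Ibis changes round this out: `BitwiseNot` now accepts Binary as well as Integer arguments, and the new `BinaryValue.__invert__` builds a `BitwiseNot` node so `~` works on binary expressions on the Ibis/BigQuery side. For contrast, a quick sketch of why the two type families need different treatment: plain Python integers invert in two's complement, while the BYTES semantics exercised elsewhere in this PR flip each byte within 0-255 (values are illustrative):

```python
import operator

import numpy as np

assert operator.invert(15) == -16  # Python int: two's-complement ~
assert ~np.uint8(15) == 240        # per-byte complement: 0b00001111 -> 0b11110000
```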