diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py
index 44a8d808ff..82a869dac2 100644
--- a/bigframes/core/nodes.py
+++ b/bigframes/core/nodes.py
@@ -14,7 +14,7 @@
 from __future__ import annotations
 
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, fields
 import functools
 import typing
 from typing import Optional, Tuple
 
@@ -66,6 +66,13 @@ def session(self):
             return sessions[0]
         return None
 
+    # BigFrameNode trees can be very deep, so it's important to avoid recalculating the hash from scratch.
+    # Each subclass of BigFrameNode should use this property to implement __hash__,
+    # since the default dataclass-generated __hash__ method is not cached.
+    @functools.cached_property
+    def _node_hash(self):
+        return hash(tuple(hash(getattr(self, field.name)) for field in fields(self)))
+
 
 @dataclass(frozen=True)
 class UnaryNode(BigFrameNode):
@@ -95,6 +102,9 @@ class JoinNode(BigFrameNode):
     def child_nodes(self) -> typing.Sequence[BigFrameNode]:
         return (self.left_child, self.right_child)
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class ConcatNode(BigFrameNode):
@@ -104,6 +114,9 @@ class ConcatNode(BigFrameNode):
     def child_nodes(self) -> typing.Sequence[BigFrameNode]:
         return self.children
 
+    def __hash__(self):
+        return self._node_hash
+
 
 # Input Nodex
 @dataclass(frozen=True)
@@ -111,6 +124,9 @@ class ReadLocalNode(BigFrameNode):
     feather_bytes: bytes
     column_ids: typing.Tuple[str, ...]
 
+    def __hash__(self):
+        return self._node_hash
+
 
 # TODO: Refactor to take raw gbq object reference
 @dataclass(frozen=True)
@@ -125,38 +141,60 @@ class ReadGbqNode(BigFrameNode):
     def session(self):
         return (self.table_session,)
 
+    def __hash__(self):
+        return self._node_hash
+
 
 # Unary nodes
 @dataclass(frozen=True)
 class DropColumnsNode(UnaryNode):
     columns: Tuple[str, ...]
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class PromoteOffsetsNode(UnaryNode):
     col_id: str
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class FilterNode(UnaryNode):
     predicate_id: str
     keep_null: bool = False
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class OrderByNode(UnaryNode):
     by: Tuple[OrderingColumnReference, ...]
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class ReversedNode(UnaryNode):
-    pass
+    # Otherwise-unused field to ensure this node type has a distinct hash.
+    reversed: bool = True
+
+    def __hash__(self):
+        return self._node_hash
 
 
 @dataclass(frozen=True)
 class SelectNode(UnaryNode):
     column_ids: typing.Tuple[str, ...]
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class ProjectUnaryOpNode(UnaryNode):
@@ -164,6 +202,9 @@ class ProjectUnaryOpNode(UnaryNode):
     op: ops.UnaryOp
     output_id: Optional[str] = None
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class ProjectBinaryOpNode(UnaryNode):
@@ -172,6 +213,9 @@ class ProjectBinaryOpNode(UnaryNode):
     op: ops.BinaryOp
     output_id: str
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class ProjectTernaryOpNode(UnaryNode):
@@ -181,6 +225,9 @@ class ProjectTernaryOpNode(UnaryNode):
     op: ops.TernaryOp
     output_id: str
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class AggregateNode(UnaryNode):
@@ -188,12 +235,18 @@ class AggregateNode(UnaryNode):
     by_column_ids: typing.Tuple[str, ...] = tuple([])
     dropna: bool = True
 
+    def __hash__(self):
+        return self._node_hash
+
 
 # TODO: Unify into aggregate
 @dataclass(frozen=True)
 class CorrNode(UnaryNode):
     corr_aggregations: typing.Tuple[typing.Tuple[str, str, str], ...]
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class WindowOpNode(UnaryNode):
@@ -204,10 +257,14 @@ class WindowOpNode(UnaryNode):
     never_skip_nulls: bool = False
     skip_reproject_unsafe: bool = False
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class ReprojectOpNode(UnaryNode):
-    pass
+    def __hash__(self):
+        return self._node_hash
 
 
 @dataclass(frozen=True)
@@ -223,12 +280,18 @@ class UnpivotNode(UnaryNode):
     ] = (pandas.Float64Dtype(),)
     how: typing.Literal["left", "right"] = "left"
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class AssignNode(UnaryNode):
     source_id: str
     destination_id: str
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class AssignConstantNode(UnaryNode):
@@ -236,6 +299,9 @@ class AssignConstantNode(UnaryNode):
     value: typing.Hashable
     dtype: typing.Optional[bigframes.dtypes.Dtype]
 
+    def __hash__(self):
+        return self._node_hash
+
 
 @dataclass(frozen=True)
 class RandomSampleNode(UnaryNode):
@@ -244,3 +310,6 @@ class RandomSampleNode(UnaryNode):
     @property
     def deterministic(self) -> bool:
         return False
+
+    def __hash__(self):
+        return self._node_hash
diff --git a/bigframes/pandas/__init__.py b/bigframes/pandas/__init__.py
index c9640abb23..7386c4a2e7 100644
--- a/bigframes/pandas/__init__.py
+++ b/bigframes/pandas/__init__.py
@@ -18,6 +18,7 @@
 
 from collections import namedtuple
 import inspect
+import sys
 import typing
 from typing import (
     Any,
@@ -657,6 +658,9 @@ def read_gbq_function(function_name: str):
 close_session = global_session.close_session
 reset_session = global_session.close_session
 
+# SQL compilation uses recursive algorithms on deep trees.
+# A 10M recursion limit should be sufficient to generate any SQL that is under the BigQuery length limit.
+sys.setrecursionlimit(max(10000000, sys.getrecursionlimit()))
 
 # Use __all__ to let type checkers know what is part of the public API.
 __all___ = [
diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py
index 57115335dc..663a7ceb49 100644
--- a/tests/system/small/test_dataframe.py
+++ b/tests/system/small/test_dataframe.py
@@ -3667,6 +3667,13 @@ def test_df_dot_operator_series(
     )
 
 
+def test_recursion_limit(scalars_df_index):
+    scalars_df_index = scalars_df_index[["int64_too", "int64_col", "float64_col"]]
+    for _ in range(400):
+        scalars_df_index = scalars_df_index + 4
+    scalars_df_index.to_pandas()
+
+
 def test_to_pandas_downsampling_option_override(session):
     df = session.read_gbq("bigframes-dev.bigframes_tests_sys.batting")
     download_size = 1
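
Reviewer note on the caching pattern above: `functools.cached_property` works on a frozen dataclass because it stores the computed value directly in `instance.__dict__`, which bypasses the frozen class's `__setattr__` guard; and each subclass must still spell out `__hash__`, because `@dataclass(frozen=True)` generates a fresh, uncached `__hash__` for any class that does not define one explicitly. The sketch below is not bigframes code — the `Node` class, its fields, and the chain depth are hypothetical stand-ins for `BigFrameNode` and its subclasses — but it demonstrates both points.

```python
# Minimal sketch of the cached-hash pattern this patch applies to BigFrameNode.
from __future__ import annotations

import functools
from dataclasses import dataclass, fields
from typing import Optional


@dataclass(frozen=True)
class Node:
    # Hypothetical fields standing in for a real expression-tree node.
    name: str
    child: Optional[Node] = None

    @functools.cached_property
    def _node_hash(self) -> int:
        # Hash every declared field; hashing `child` recurses into the
        # subtree, but each node memoizes its own result, so a chain of
        # n nodes is fully hashed only once.
        return hash(tuple(hash(getattr(self, f.name)) for f in fields(self)))

    # Defined explicitly: @dataclass(frozen=True) would otherwise install its
    # own field-based, uncached __hash__ on this class (and on each subclass).
    def __hash__(self) -> int:
        return self._node_hash


# Build a 100-deep chain and hash it twice.
root = Node("leaf")
for i in range(100):
    root = Node(f"op{i}", root)

h = hash(root)                        # walks the chain once, filling each cache
assert "_node_hash" in root.__dict__  # memoized despite frozen=True: cached_property
                                      # writes into instance.__dict__ directly
assert hash(root) == h                # second call is a plain dict lookup
```

The `sys.setrecursionlimit` bump in `bigframes/pandas/__init__.py` is the complementary half of the fix: even with memoized hashes, SQL compilation still recurses once per node, so the interpreter's frame budget has to accommodate the deep trees that the new `test_recursion_limit` test exercises.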