diff --git a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
index 19ebbfb9ad92..e3d6056a5de9 100644
--- a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
+++ b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
@@ -1,3 +1,4 @@
 {
-  "https://github.com/apache/beam/pull/35951": "triggering sideinput test"
+  "comment": "Modify this file in a trivial way to cause this test suite to run",
+  "modification": 1
 }
diff --git a/sdks/python/apache_beam/coders/coder_impl.pxd b/sdks/python/apache_beam/coders/coder_impl.pxd
index 27cffe7b62df..2db7b963d151 100644
--- a/sdks/python/apache_beam/coders/coder_impl.pxd
+++ b/sdks/python/apache_beam/coders/coder_impl.pxd
@@ -81,6 +81,7 @@ cdef class FastPrimitivesCoderImpl(StreamCoderImpl):
   cdef CoderImpl iterable_coder_impl
   cdef object requires_deterministic_step_label
   cdef bint warn_deterministic_fallback
+  cdef bint force_use_dill

   @cython.locals(dict_value=dict, int_value=libc.stdint.int64_t,
                  unicode_value=unicode)
@@ -88,6 +89,7 @@ cdef class FastPrimitivesCoderImpl(StreamCoderImpl):
   @cython.locals(t=int)
   cpdef decode_from_stream(self, InputStream stream, bint nested)
   cdef encode_special_deterministic(self, value, OutputStream stream)
+  cdef encode_type_2_67_0(self, t, OutputStream stream)
   cdef encode_type(self, t, OutputStream stream)
   cdef decode_type(self, InputStream stream)
diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py
index 807d083d8a38..4f28fb3c916b 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -50,7 +50,6 @@
 from typing import Tuple
 from typing import Type

-import dill
 import numpy as np
 from fastavro import parse_schema
 from fastavro import schemaless_reader
@@ -58,6 +57,7 @@

 from apache_beam.coders import observable
 from apache_beam.coders.avro_record import AvroRecord
+from apache_beam.internal import cloudpickle_pickler
 from apache_beam.typehints.schemas import named_tuple_from_schema
 from apache_beam.utils import proto_utils
 from apache_beam.utils import windowed_value
@@ -71,6 +71,11 @@
 except ImportError:
   dataclasses = None  # type: ignore

+try:
+  import dill
+except ImportError:
+  dill = None
+
 if TYPE_CHECKING:
   import proto
   from apache_beam.transforms import userstate
@@ -354,14 +359,30 @@ def decode(self, value):
 _ITERABLE_LIKE_TYPES = set()  # type: Set[Type]


+def _verify_dill_compat():
+  base_error = (
+      "This pipeline runs with the pipeline option "
+      "--update_compatibility_version=2.67.0 or earlier. "
+      "When running with this option on SDKs 2.68.0 or "
+      "later, you must ensure dill==0.3.1.1 is installed.")
+  if not dill:
+    raise RuntimeError(base_error + " Dill is not installed.")
+  if dill.__version__ != "0.3.1.1":
+    raise RuntimeError(base_error + f" Found dill version '{dill.__version__}'")
+
+
 class FastPrimitivesCoderImpl(StreamCoderImpl):
   """For internal use only; no backwards-compatibility guarantees."""
   def __init__(
-      self, fallback_coder_impl, requires_deterministic_step_label=None):
+      self,
+      fallback_coder_impl,
+      requires_deterministic_step_label=None,
+      force_use_dill=False):
     self.fallback_coder_impl = fallback_coder_impl
     self.iterable_coder_impl = IterableCoderImpl(self)
     self.requires_deterministic_step_label = requires_deterministic_step_label
     self.warn_deterministic_fallback = True
+    self.force_use_dill = force_use_dill

   @staticmethod
   def register_iterable_like_type(t):
@@ -525,10 +546,23 @@ def _deterministic_encoding_error_msg(self, value):
         "please provide a type hint for the input of '%s'" %
         (value, type(value), self.requires_deterministic_step_label))

-  def encode_type(self, t, stream):
+  def encode_type_2_67_0(self, t, stream):
+    """
+    Encode special type with <=2.67.0 compatibility.
+    """
+    _verify_dill_compat()
     stream.write(dill.dumps(t), True)

+  def encode_type(self, t, stream):
+    if self.force_use_dill:
+      return self.encode_type_2_67_0(t, stream)
+    bs = cloudpickle_pickler.dumps(
+        t, config=cloudpickle_pickler.NO_DYNAMIC_CLASS_TRACKING_CONFIG)
+    stream.write(bs, True)
+
   def decode_type(self, stream):
+    if self.force_use_dill:
+      return _unpickle_type_2_67_0(stream.read_all(True))
     return _unpickle_type(stream.read_all(True))

   def decode_from_stream(self, stream, nested):
@@ -589,19 +623,35 @@ def decode_from_stream(self, stream, nested):
 _unpickled_types = {}  # type: Dict[bytes, type]


-def _unpickle_type(bs):
+def _unpickle_type_2_67_0(bs):
+  """
+  Decode special type with <=2.67.0 compatibility.
+  """
   t = _unpickled_types.get(bs, None)
   if t is None:
+    _verify_dill_compat()
     t = _unpickled_types[bs] = dill.loads(bs)
     # Fix unpicklable anonymous named tuples for Python 3.6.
     if t.__base__ is tuple and hasattr(t, '_fields'):
       try:
         pickle.loads(pickle.dumps(t))
       except pickle.PicklingError:
-        t.__reduce__ = lambda self: (_unpickle_named_tuple, (bs, tuple(self)))
+        t.__reduce__ = lambda self: (
+            _unpickle_named_tuple_2_67_0, (bs, tuple(self)))
   return t


+def _unpickle_named_tuple_2_67_0(bs, items):
+  return _unpickle_type_2_67_0(bs)(*items)
+
+
+def _unpickle_type(bs):
+  if not _unpickled_types.get(bs, None):
+    _unpickled_types[bs] = cloudpickle_pickler.loads(bs)
+
+  return _unpickled_types[bs]
+
+
 def _unpickle_named_tuple(bs, items):
   return _unpickle_type(bs)(*items)

@@ -837,6 +887,7 @@ def decode_from_stream(self, in_, nested):
     if IntervalWindow is None:
       from apache_beam.transforms.window import IntervalWindow
     # instantiating with None is not part of the public interface
+    # pylint: disable=too-many-function-args
     typed_value = IntervalWindow(None, None)  # type: ignore[arg-type]
     typed_value._end_micros = (
         1000 * self._to_normal_time(in_.read_bigendian_uint64()))
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py
index 2691857bf0a6..e527185bd571 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -911,6 +911,44 @@ def _create_impl(self):
         cloudpickle_pickler.dumps, cloudpickle_pickler.loads)


+class DeterministicFastPrimitivesCoderV2(FastCoder):
+  """Throws runtime errors when encoding non-deterministic values."""
+  def __init__(self, coder, step_label):
+    self._underlying_coder = coder
+    self._step_label = step_label
+
+  def _create_impl(self):
+
+    return coder_impl.FastPrimitivesCoderImpl(
+        self._underlying_coder.get_impl(),
+        requires_deterministic_step_label=self._step_label,
+        force_use_dill=False)
+
+  def is_deterministic(self):
+    # type: () -> bool
+    return True
+
+  def is_kv_coder(self):
+    # type: () -> bool
+    return True
+
+  def key_coder(self):
+    return self
+
+  def value_coder(self):
+    return self
+
+  def to_type_hint(self):
+    return Any
+
+  def to_runner_api_parameter(self, context):
+    # type: (Optional[PipelineContext]) -> Tuple[str, Any, Sequence[Coder]]
+    return (
+        python_urns.PICKLED_CODER,
+        google.protobuf.wrappers_pb2.BytesValue(value=serialize_coder(self)),
+        ())
+
+
 class DeterministicFastPrimitivesCoder(FastCoder):
   """Throws runtime errors when encoding non-deterministic values."""
   def __init__(self, coder, step_label):
@@ -920,7 +958,8 @@ def __init__(self, coder, step_label):
   def _create_impl(self):
     return coder_impl.FastPrimitivesCoderImpl(
         self._underlying_coder.get_impl(),
-        requires_deterministic_step_label=self._step_label)
+        requires_deterministic_step_label=self._step_label,
+        force_use_dill=True)

   def is_deterministic(self):
     # type: () -> bool
@@ -940,6 +979,34 @@ def to_type_hint(self):
     return Any


+def _should_force_use_dill():
+  from apache_beam.coders import typecoders
+  from apache_beam.transforms.util import is_v1_prior_to_v2
+  update_compat_version = typecoders.registry.update_compatibility_version
+
+  if not update_compat_version:
+    return False
+
+  if not is_v1_prior_to_v2(v1=update_compat_version, v2="2.68.0"):
+    return False
+
+  try:
+    import dill
+    assert dill.__version__ == "0.3.1.1"
+  except Exception as e:
+    raise RuntimeError("This pipeline runs with the pipeline option " \
+      "--update_compatibility_version=2.67.0 or earlier. When running with " \
+      "this option on SDKs 2.68.0 or later, you must ensure dill==0.3.1.1 " \
+      f"is installed. Error {e}")
+  return True
+
+
+def _update_compatible_deterministic_fast_primitives_coder(coder, step_label):
+  if _should_force_use_dill():
+    return DeterministicFastPrimitivesCoder(coder, step_label)
+  return DeterministicFastPrimitivesCoderV2(coder, step_label)
+
+
 class FastPrimitivesCoder(FastCoder):
   """Encodes simple primitives (e.g. str, int) efficiently.

@@ -960,7 +1027,8 @@ def as_deterministic_coder(self, step_label, error_message=None):
     if self.is_deterministic():
       return self
     else:
-      return DeterministicFastPrimitivesCoder(self, step_label)
+      return _update_compatible_deterministic_fast_primitives_coder(
+          self, step_label)

   def to_type_hint(self):
     return Any
diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py
index dbd0a301bb0d..587e5d87522e 100644
--- a/sdks/python/apache_beam/coders/coders_test_common.py
+++ b/sdks/python/apache_beam/coders/coders_test_common.py
@@ -34,6 +34,8 @@
 from typing import NamedTuple

 import pytest
+from parameterized import param
+from parameterized import parameterized

 from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
 from apache_beam.coders import coders
@@ -58,6 +60,7 @@
   dataclasses = None  # type: ignore

 MyNamedTuple = collections.namedtuple('A', ['x', 'y'])  # type: ignore[name-match]
+AnotherNamedTuple = collections.namedtuple('AnotherNamedTuple', ['x', 'y'])
 MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), ('f2', str)])
@@ -175,6 +178,9 @@ def tearDownClass(cls):
     assert not standard - cls.seen, str(standard - cls.seen)
     assert not cls.seen_nested - standard, str(cls.seen_nested - standard)

+  def tearDown(self):
+    typecoders.registry.update_compatibility_version = None
+
   @classmethod
   def _observe(cls, coder):
     cls.seen.add(type(coder))
@@ -230,9 +236,15 @@ def test_memoizing_pickle_coder(self):
     coder = coders._MemoizingPickleCoder()
     self.check_coder(coder, *self.test_values)

-  def test_deterministic_coder(self):
+  @parameterized.expand([
+      param(compat_version=None),
+      param(compat_version="2.67.0"),
+  ])
+  def test_deterministic_coder(self, compat_version):
+    typecoders.registry.update_compatibility_version = compat_version
     coder = coders.FastPrimitivesCoder()
-    deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 'step')
+    deterministic_coder = coder.as_deterministic_coder(step_label="step")
+
     self.check_coder(deterministic_coder, *self.test_values_deterministic)
     for v in self.test_values_deterministic:
       self.check_coder(coders.TupleCoder((deterministic_coder, )), (v, ))
@@ -254,8 +266,16 @@ def test_deterministic_coder(self):
     self.check_coder(deterministic_coder, test_message.MessageA(field1='value'))

+    # Only run this check on the dill (compat) path. Dill monkey patches the
+    # __reduce__ method of anonymous named tuples (MyNamedTuple), which is not
+    # picklable. Since the test is parameterized, the type gets clobbered.
+    if compat_version:
+      self.check_coder(
+          deterministic_coder, [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])
+
     self.check_coder(
-        deterministic_coder, [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])
+        deterministic_coder,
+        [AnotherNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])

     if dataclasses is not None:
       self.check_coder(deterministic_coder, FrozenDataClass(1, 2))
@@ -265,9 +285,10 @@ def test_deterministic_coder(self):
     with self.assertRaises(TypeError):
       self.check_coder(
           deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3))
-    with self.assertRaises(TypeError):
-      self.check_coder(
-          deterministic_coder, MyNamedTuple(UnFrozenDataClass(1, 2), 3))
+      with self.assertRaises(TypeError):
+        self.check_coder(
+            deterministic_coder,
+            AnotherNamedTuple(UnFrozenDataClass(1, 2), 3))

     self.check_coder(deterministic_coder, list(MyEnum))
     self.check_coder(deterministic_coder, list(MyIntEnum))
@@ -286,6 +307,29 @@ def test_deterministic_coder(self):
             1: 'x', 'y': 2
         }))

+  @parameterized.expand([
+      param(compat_version=None),
+      param(compat_version="2.67.0"),
+  ])
+  def test_deterministic_map_coder_is_update_compatible(self, compat_version):
+    typecoders.registry.update_compatibility_version = compat_version
+    values = [{
+        MyTypedNamedTuple(i, 'a'): MyTypedNamedTuple('a', i)
+        for i in range(10)
+    }]
+
+    coder = coders.MapCoder(
+        coders.FastPrimitivesCoder(), coders.FastPrimitivesCoder())
+
+    deterministic_coder = coder.as_deterministic_coder(step_label="step")
+
+    assert isinstance(
+        deterministic_coder._key_coder,
+        coders.DeterministicFastPrimitivesCoderV2
+        if not compat_version else coders.DeterministicFastPrimitivesCoder)
+
+    self.check_coder(deterministic_coder, *values)
+
   def test_dill_coder(self):
     cell_value = (lambda x: lambda: x)(0).__closure__[0]
     self.check_coder(coders.DillCoder(), 'a', 1, cell_value)
@@ -610,15 +654,21 @@ def test_param_windowed_value_coder(self):
             1, (window.IntervalWindow(11, 21), ),
             PaneInfo(True, False, 1, 2, 3))))

-  def test_cross_process_encoding_of_special_types_is_deterministic(self):
+  @parameterized.expand([
+      param(compat_version=None),
+      param(compat_version="2.67.0"),
+  ])
+  def test_cross_process_encoding_of_special_types_is_deterministic(
+      self, compat_version):
     """Test cross-process determinism for all special deterministic types"""
     if sys.executable is None:
       self.skipTest('No Python interpreter found')

+    typecoders.registry.update_compatibility_version = compat_version
     # pylint: disable=line-too-long
     script = textwrap.dedent(
-        '''\
+        f'''\
      import pickle
      import sys
      import collections
@@ -626,13 +676,19 @@ def test_cross_process_encoding_of_special_types_is_deterministic(self):
      import logging

      from apache_beam.coders import coders
-      from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
-      from typing import NamedTuple
+      from apache_beam.coders import typecoders
+      from apache_beam.coders.coders_test_common import MyNamedTuple
+      from apache_beam.coders.coders_test_common import MyTypedNamedTuple
+      from apache_beam.coders.coders_test_common import MyEnum
+      from apache_beam.coders.coders_test_common import MyIntEnum
+      from apache_beam.coders.coders_test_common import MyIntFlag
+      from apache_beam.coders.coders_test_common import MyFlag
+      from apache_beam.coders.coders_test_common import DefinesGetState
+      from apache_beam.coders.coders_test_common import DefinesGetAndSetState
+      from apache_beam.coders.coders_test_common import FrozenDataClass

-      try:
-        import dataclasses
-      except ImportError:
-        dataclasses = None
+
+      from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message

      logging.basicConfig(
        level=logging.INFO,
        format='%(message)s',
        stream=sys.stderr,
        force=True
      )
@@ -640,38 +696,6 @@ def test_cross_process_encoding_of_special_types_is_deterministic(self):
-
-      # Define all the special types that encode_special_deterministic handles
-      MyNamedTuple = collections.namedtuple('A', ['x', 'y'])
-      MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), ('f2', str)])
-
-      class MyEnum(enum.Enum):
-        E1 = 5
-        E2 = enum.auto()
-        E3 = 'abc'
-
-      MyIntEnum = enum.IntEnum('MyIntEnum', 'I1 I2 I3')
-      MyIntFlag = enum.IntFlag('MyIntFlag', 'F1 F2 F3')
-      MyFlag = enum.Flag('MyFlag', 'F1 F2 F3')
-
-      if dataclasses is not None:
-        @dataclasses.dataclass(frozen=True)
-        class FrozenDataClass:
-          a: int
-          b: int
-
-      class DefinesGetAndSetState:
-        def __init__(self, value):
-          self.value = value
-
-        def __getstate__(self):
-          return self.value
-
-        def __setstate__(self, value):
-          self.value = value
-
-        def __eq__(self, other):
-          return type(other) is type(self) and other.value == self.value

      # Test cases for all special deterministic types
      # NOTE: When this script run in a subprocess the module is considered
@@ -683,26 +707,28 @@ def __eq__(self, other):
        ("named_tuple_simple", MyNamedTuple(1, 2)),
        ("typed_named_tuple", MyTypedNamedTuple(1, 'a')),
        ("named_tuple_list", [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')]),
-        # ("enum_single", MyEnum.E1),
-        # ("enum_list", list(MyEnum)),
-        # ("int_enum_list", list(MyIntEnum)),
-        # ("int_flag_list", list(MyIntFlag)),
-        # ("flag_list", list(MyFlag)),
+        ("enum_single", MyEnum.E1),
+        ("enum_list", list(MyEnum)),
+        ("int_enum_list", list(MyIntEnum)),
+        ("int_flag_list", list(MyIntFlag)),
+        ("flag_list", list(MyFlag)),
        ("getstate_setstate_simple", DefinesGetAndSetState(1)),
        ("getstate_setstate_complex", DefinesGetAndSetState((1, 2, 3))),
        ("getstate_setstate_list", [DefinesGetAndSetState(1), DefinesGetAndSetState((1, 2, 3))]),
      ]

-      if dataclasses is not None:
-        test_cases.extend([
-          ("frozen_dataclass", FrozenDataClass(1, 2)),
-          ("frozen_dataclass_list", [FrozenDataClass(1, 2), FrozenDataClass(3, 4)]),
-        ])
+
+      test_cases.extend([
+        ("frozen_dataclass", FrozenDataClass(1, 2)),
+        ("frozen_dataclass_list", [FrozenDataClass(1, 2), FrozenDataClass(3, 4)]),
+      ])

+      compat_version = {'"'+ compat_version +'"' if compat_version else None}
+      typecoders.registry.update_compatibility_version = compat_version
      coder = coders.FastPrimitivesCoder()
-      deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 'step')
+      deterministic_coder = coder.as_deterministic_coder("step")

-      results = {}
+      results = dict()
      for test_name, value in test_cases:
        try:
          encoded = deterministic_coder.encode(value)
@@ -730,7 +756,7 @@ def run_subprocess():
     results2 = run_subprocess()

     coder = coders.FastPrimitivesCoder()
-    deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 'step')
+    deterministic_coder = coder.as_deterministic_coder("step")

     for test_name in results1:
       data1 = results1[test_name]
@@ -861,7 +887,7 @@ def test_map_coder(self):
         {
             i: str(i)
             for i in range(5000)
-        }
+        },
     ]
     map_coder = coders.MapCoder(coders.VarIntCoder(), coders.StrUtf8Coder())
     self.check_coder(map_coder, *values)
diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py
index 19300c675596..779c65dc772c 100644
--- a/sdks/python/apache_beam/coders/typecoders.py
+++ b/sdks/python/apache_beam/coders/typecoders.py
@@ -84,6 +84,7 @@ def __init__(self, fallback_coder=None):
     self._coders: Dict[Any, Type[coders.Coder]] = {}
     self.custom_types: List[Any] = []
     self.register_standard_coders(fallback_coder)
+    self.update_compatibility_version = None

   def register_standard_coders(self, fallback_coder):
     """Register coders for all basic and composite types."""
diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler.py b/sdks/python/apache_beam/internal/cloudpickle_pickler.py
index 63038e770f27..e55818bfb226 100644
--- a/sdks/python/apache_beam/internal/cloudpickle_pickler.py
+++ b/sdks/python/apache_beam/internal/cloudpickle_pickler.py
@@ -39,6 +39,8 @@

 DEFAULT_CONFIG = cloudpickle.CloudPickleConfig(
     skip_reset_dynamic_type_state=True)
+NO_DYNAMIC_CLASS_TRACKING_CONFIG = cloudpickle.CloudPickleConfig(
+    id_generator=None, skip_reset_dynamic_type_state=True)

 try:
   from absl import flags
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 83a0bee81456..0ed5a435e788 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -76,6 +76,7 @@
 from google.protobuf import message

 from apache_beam import pvalue
+from apache_beam.coders import typecoders
 from apache_beam.internal import pickler
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.options.pipeline_options import CrossLanguageOptions
@@ -83,6 +84,7 @@
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import SetupOptions
 from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.options.pipeline_options import StreamingOptions
 from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.options.pipeline_options_validator import PipelineOptionsValidator
 from apache_beam.portability import common_urns
@@ -229,6 +231,9 @@ def __init__(
       raise ValueError(
           'Pipeline has validations errors: \n' + '\n'.join(errors))

+    typecoders.registry.update_compatibility_version = self._options.view_as(
+        StreamingOptions).update_compatibility_version
+
     # set default experiments for portable runners
     # (needs to occur prior to pipeline construction)
     if runner.is_fnapi_compatible():
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index e1c84c7dc9ae..39d216c4b3b4 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -34,16 +34,20 @@
 import hamcrest as hc
 import numpy as np
 import pytest
+from parameterized import param
+from parameterized import parameterized
 from parameterized import parameterized_class

 import apache_beam as beam
 import apache_beam.transforms.combiners as combine
 from apache_beam import pvalue
 from apache_beam import typehints
+from apache_beam.coders import coders_test_common
 from apache_beam.io.iobase import Read
 from apache_beam.metrics import Metrics
 from apache_beam.metrics.metric import MetricsFilter
 from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import StreamingOptions
 from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.portability import common_urns
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -572,7 +576,7 @@ def encode(self, o):
     def decode(self, encoded):
       return MyObject(pickle.loads(encoded)[0])

-    def as_deterministic_coder(self, *args):
+    def as_deterministic_coder(self, *args, **kwargs):
       return MydeterministicObjectCoder()

     def to_type_hint(self):
@@ -719,6 +723,67 @@ def test_flatten_one_single_pcollection(self):
     result = (pcoll, ) | 'Single Flatten' >> beam.Flatten()
     assert_that(result, equal_to(input))

+  @parameterized.expand([
+      param(compat_version=None),
+      param(compat_version="2.66.0"),
+  ])
+  @pytest.mark.it_validatesrunner
+  def test_group_by_key_importable_special_types(self, compat_version):
+    def generate(_):
+      for _ in range(100):
+        yield (coders_test_common.MyTypedNamedTuple(1, 'a'), 1)
+
+    pipeline = TestPipeline(is_integration_test=True)
+    if compat_version:
+      pipeline.get_pipeline_options().view_as(
+          StreamingOptions).update_compatibility_version = compat_version
+    with pipeline as p:
+      result = (
+          p
+          | 'Create' >> beam.Create([i for i in range(100)])
+          | 'Generate' >> beam.ParDo(generate)
+          | 'Reshuffle' >> beam.Reshuffle()
+          | 'GBK' >> beam.GroupByKey())
+      assert_that(
+          result,
+          equal_to([(
+              coders_test_common.MyTypedNamedTuple(1, 'a'),
+              [1 for i in range(10000)])]))
+
+  @pytest.mark.it_validatesrunner
+  def test_group_by_key_dynamic_special_types(self):
+    def create_dynamic_named_tuple():
+      return collections.namedtuple('DynamicNamedTuple', ['x', 'y'])
+
+    dynamic_named_tuple = create_dynamic_named_tuple()
+
+    # Standard FastPrimitivesCoder falls back to the Python PickleCoder, which
+    # cannot serialize dynamic types or types defined in __main__. Use
+    # CloudpickleCoder as the fallback coder for non-deterministic steps.
+    class FastPrimitivesCoderV2(beam.coders.FastPrimitivesCoder):
+      def __init__(self):
+        super().__init__(fallback_coder=beam.coders.CloudpickleCoder())
+
+    beam.coders.typecoders.registry.register_coder(
+        dynamic_named_tuple, FastPrimitivesCoderV2)
+
+    def generate(_):
+      for _ in range(100):
+        yield (dynamic_named_tuple(1, 'a'), 1)
+
+    pipeline = TestPipeline(is_integration_test=True)
+
+    with pipeline as p:
+      result = (
+          p
+          | 'Create' >> beam.Create([i for i in range(100)])
+          | 'Reshuffle' >> beam.Reshuffle()
+          | 'Generate' >> beam.ParDo(generate).with_output_types(
+              tuple[dynamic_named_tuple, int])
+          | 'GBK' >> beam.GroupByKey()
+          | 'Count Elements' >> beam.Map(lambda x: len(x[1])))
+      assert_that(result, equal_to([10000]))
+
   # TODO(https://github.com/apache/beam/issues/20067): Does not work in
   # streaming mode on Dataflow.
   @pytest.mark.no_sickbay_streaming
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index c60ded52df26..2df66aadcc64 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -932,6 +932,15 @@ def get_window_coder(self):
     return self._window_coder


+def is_v1_prior_to_v2(*, v1, v2):
+  if v1 is None:
+    return False
+
+  v1_parts = (v1.split('.') + ['0', '0', '0'])[:3]
+  v2_parts = (v2.split('.') + ['0', '0', '0'])[:3]
+  return tuple(map(int, v1_parts)) < tuple(map(int, v2_parts))
+
+
 def is_compat_version_prior_to(options, breaking_change_version):
   # This function is used in a branch statement to determine whether we should
   # keep the old behavior prior to a breaking change or use the new behavior.
@@ -940,15 +949,8 @@ def is_compat_version_prior_to(options, breaking_change_version):
   update_compatibility_version = options.view_as(
       pipeline_options.StreamingOptions).update_compatibility_version

-  if update_compatibility_version is None:
-    return False
-
-  compat_version = tuple(map(int, update_compatibility_version.split('.')[0:3]))
-  change_version = tuple(map(int, breaking_change_version.split('.')[0:3]))
-  for i in range(min(len(compat_version), len(change_version))):
-    if compat_version[i] < change_version[i]:
-      return True
-  return False
+  return is_v1_prior_to_v2(
+      v1=update_compatibility_version, v2=breaking_change_version)


 def reify_metadata_default_window(
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index ad185ac6a6d1..b365d9b22090 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -2193,6 +2193,68 @@ def record(tag):
         label='result')


+class CompatCheckTest(unittest.TestCase):
+  def test_is_v1_prior_to_v2(self):
+    test_cases = [
+        # Basic comparison cases
+        ("1.0.0", "2.0.0", True),  # v1 < v2 in major
+        ("2.0.0", "1.0.0", False),  # v1 > v2 in major
+        ("1.1.0", "1.2.0", True),  # v1 < v2 in minor
+        ("1.2.0", "1.1.0", False),  # v1 > v2 in minor
+        ("1.0.1", "1.0.2", True),  # v1 < v2 in patch
+        ("1.0.2", "1.0.1", False),  # v1 > v2 in patch
+
+        # Equal versions
+        ("1.0.0", "1.0.0", False),  # Identical
+        ("0.0.0", "0.0.0", False),  # Both zero
+
+        # Different lengths - shorter vs longer
+        ("1.0", "1.0.0", False),  # Should be equal (1.0 = 1.0.0)
+        ("1.0", "1.0.1", True),  # 1.0.0 < 1.0.1
+        ("1.2", "1.2.0", False),  # Should be equal (1.2 = 1.2.0)
+        ("1.2", "1.2.3", True),  # 1.2.0 < 1.2.3
+        ("2", "2.0.0", False),  # Should be equal (2 = 2.0.0)
+        ("2", "2.0.1", True),  # 2.0.0 < 2.0.1
+        ("1", "2.0", True),  # 1.0.0 < 2.0.0
+
+        # Different lengths - longer vs shorter
+        ("1.0.0", "1.0", False),  # Should be equal
+        ("1.0.1", "1.0", False),  # 1.0.1 > 1.0.0
+        ("1.2.0", "1.2", False),  # Should be equal
+        ("1.2.3", "1.2", False),  # 1.2.3 > 1.2.0
+        ("2.0.0", "2", False),  # Should be equal
+        ("2.0.1", "2", False),  # 2.0.1 > 2.0.0
+        ("2.0", "1", False),  # 2.0.0 > 1.0.0
+
+        # Mixed length comparisons
+        ("1.0", "2.0.0", True),  # 1.0.0 < 2.0.0
+        ("2.0", "1.0.0", False),  # 2.0.0 > 1.0.0
+        ("1", "1.0.1", True),  # 1.0.0 < 1.0.1
+        ("1.1", "1.0.9", False),  # 1.1.0 > 1.0.9
+
+        # Large numbers
+        ("1.9.9", "2.0.0", True),  # 1.9.9 < 2.0.0
+        ("10.0.0", "9.9.9", False),  # 10.0.0 > 9.9.9
+        ("1.10.0", "1.9.0", False),  # 1.10.0 > 1.9.0
+        ("1.2.10", "1.2.9", False),  # 1.2.10 > 1.2.9
+
+        # Sequential versions
+        ("1.0.0", "1.0.1", True),
+        ("1.0.1", "1.0.2", True),
+        ("1.0.9", "1.1.0", True),
+        ("1.9.9", "2.0.0", True),
+
+        # Null/None cases
+        (None, "1.0.0", False),  # v1 is None
+    ]
+
+    for v1, v2, expected in test_cases:
+      self.assertEqual(
+          util.is_v1_prior_to_v2(v1=v1, v2=v2),
+          expected,
+          msg=f"Failed {v1} < {v2} == {expected}")
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()