#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

"""Unit tests for the metric cell types in apache_beam.metrics.cells."""

import copy
import itertools
import random
import threading
import unittest

from apache_beam.metrics.cells import BoundedTrieData
from apache_beam.metrics.cells import CounterCell
from apache_beam.metrics.cells import DistributionCell
from apache_beam.metrics.cells import DistributionData
from apache_beam.metrics.cells import GaugeCell
from apache_beam.metrics.cells import GaugeData
from apache_beam.metrics.cells import HistogramCell
from apache_beam.metrics.cells import HistogramCellFactory
from apache_beam.metrics.cells import HistogramData
from apache_beam.metrics.cells import StringSetCell
from apache_beam.metrics.cells import StringSetData
from apache_beam.metrics.cells import _BoundedTrieNode
from apache_beam.metrics.metricbase import MetricName
from apache_beam.utils.histogram import Histogram
from apache_beam.utils.histogram import LinearBucket


class TestCounterCell(unittest.TestCase):
  """Tests for CounterCell."""

  NUM_THREADS = 5
  NUM_ITERATIONS = 100

  @classmethod
  def _modify_counter(cls, d):
    """Increments counter ``d`` by each of 0..NUM_ITERATIONS-1 in turn."""
    for i in range(cls.NUM_ITERATIONS):
      d.inc(i)

  def test_parallel_access(self):
    """Concurrent increments from several threads must all be counted."""
    # We create NUM_THREADS threads that concurrently modify the counter.
    threads = []
    c = CounterCell()
    for _ in range(TestCounterCell.NUM_THREADS):
      t = threading.Thread(
          target=TestCounterCell._modify_counter, args=(c, ))
      threads.append(t)
      t.start()

    for t in threads:
      t.join()

    # Each thread adds 0 + 1 + ... + (NUM_ITERATIONS - 1).
    total = (
        self.NUM_ITERATIONS * (self.NUM_ITERATIONS - 1) // 2 *
        self.NUM_THREADS)
    self.assertEqual(c.get_cumulative(), total)

  def test_basic_operations(self):
    """inc/dec with and without an explicit amount update the total."""
    c = CounterCell()
    c.inc(2)
    self.assertEqual(c.get_cumulative(), 2)

    c.dec(10)
    self.assertEqual(c.get_cumulative(), -8)

    c.dec()
    self.assertEqual(c.get_cumulative(), -9)

    c.inc()
    self.assertEqual(c.get_cumulative(), -8)

  def test_start_time_set(self):
    """The monitoring info of an updated counter has a start time."""
    c = CounterCell()
    c.inc(2)

    name = MetricName('namespace', 'name1')
    mi = c.to_runner_api_monitoring_info(name, 'transform_id')
    self.assertGreater(mi.start_time.seconds, 0)


class TestDistributionCell(unittest.TestCase):
  """Tests for DistributionCell."""

  NUM_THREADS = 5
  NUM_ITERATIONS = 100

  @classmethod
  def _modify_distribution(cls, d):
    """Updates distribution ``d`` with each of 0..NUM_ITERATIONS-1."""
    for i in range(cls.NUM_ITERATIONS):
      d.update(i)

  def test_parallel_access(self):
    """Concurrent updates from several threads must all be recorded."""
    # We create NUM_THREADS threads that concurrently modify the distribution.
    threads = []
    d = DistributionCell()
    for _ in range(TestDistributionCell.NUM_THREADS):
      t = threading.Thread(
          target=TestDistributionCell._modify_distribution, args=(d, ))
      threads.append(t)
      t.start()

    for t in threads:
      t.join()

    # Each thread contributes 0 + 1 + ... + (NUM_ITERATIONS - 1) to the sum
    # and NUM_ITERATIONS to the count; min and max are 0 and
    # NUM_ITERATIONS - 1.
    total = (
        self.NUM_ITERATIONS * (self.NUM_ITERATIONS - 1) // 2 *
        self.NUM_THREADS)
    count = self.NUM_ITERATIONS * self.NUM_THREADS
    self.assertEqual(
        d.get_cumulative(),
        DistributionData(total, count, 0, self.NUM_ITERATIONS - 1))

  def test_basic_operations(self):
    """Each update adjusts the sum, count, min and max as expected."""
    d = DistributionCell()
    d.update(10)
    self.assertEqual(d.get_cumulative(), DistributionData(10, 1, 10, 10))

    d.update(2)
    self.assertEqual(d.get_cumulative(), DistributionData(12, 2, 2, 10))

    d.update(900)
    self.assertEqual(d.get_cumulative(), DistributionData(912, 3, 2, 900))

  def test_integer_only(self):
    """Float updates are recorded as their integer part."""
    d = DistributionCell()
    d.update(3.1)
    d.update(3.2)
    d.update(3.3)
    self.assertEqual(d.get_cumulative(), DistributionData(9, 3, 3, 3))

  def test_start_time_set(self):
    """The monitoring info of an updated distribution has a start time."""
    d = DistributionCell()
    d.update(3.1)

    name = MetricName('namespace', 'name1')
    mi = d.to_runner_api_monitoring_info(name, 'transform_id')
    self.assertGreater(mi.start_time.seconds, 0)


class TestGaugeCell(unittest.TestCase):
  """Tests for GaugeCell."""

  def test_basic_operations(self):
    """set() replaces the gauge's current value."""
    g = GaugeCell()
    g.set(10)
    self.assertEqual(g.get_cumulative().value, GaugeData(10).value)

    g.set(2)
    self.assertEqual(g.get_cumulative().value, 2)

  def test_integer_only(self):
    """A float passed to set() is recorded as its integer part."""
    g = GaugeCell()
    g.set(3.3)
    self.assertEqual(g.get_cumulative().value, 3)

  def test_combine_appropriately(self):
    """Combining two gauges keeps the most recently set value."""
    g1 = GaugeCell()
    g1.set(3)

    g2 = GaugeCell()
    g2.set(1)

    # The second Gauge, with value 1, was the most recent, so it should be
    # the final result.
    result = g2.combine(g1)
    self.assertEqual(result.data.value, 1)

  def test_start_time_set(self):
    """The monitoring info of an updated gauge has a start time."""
    g1 = GaugeCell()
    g1.set(3)

    name = MetricName('namespace', 'name1')
    mi = g1.to_runner_api_monitoring_info(name, 'transform_id')
    self.assertGreater(mi.start_time.seconds, 0)


class TestStringSetCell(unittest.TestCase):
  """Tests for StringSetCell."""

  def test_not_leak_mutable_set(self):
    """Mutating the result of get_cumulative() must not affect the cell."""
    c = StringSetCell()
    c.add('test')
    c.add('another')
    s = c.get_cumulative()
    # 11 == len('test') + len('another').
    self.assertEqual(s, StringSetData({'test', 'another'}, 11))
    s.add('yet another')
    self.assertEqual(c.get_cumulative(), StringSetData({'test', 'another'}, 11))

  def test_combine_appropriately(self):
    """Combining two cells yields the union of their strings."""
    s1 = StringSetCell()
    s1.add('1')
    s1.add('2')

    s2 = StringSetCell()
    s2.add('1')
    s2.add('3')

    result = s2.combine(s1)
    self.assertEqual(result.data, StringSetData({'1', '2', '3'}))

  def test_add_size_tracked_correctly(self):
    """Re-adding an existing string must not grow the tracked size."""
    s = StringSetCell()
    s.add('1')
    s.add('2')
    self.assertEqual(s.data.string_size, 2)
    s.add('2')
    s.add('3')
    self.assertEqual(s.data.string_size, 3)


class TestBoundedTrieNode(unittest.TestCase):
  """Tests for _BoundedTrieNode and BoundedTrieData.

  Flattened trie output is a collection of tuples whose last element is a
  boolean "truncated" flag and whose preceding elements are the path
  segments.
  """

  @classmethod
  def random_segments_fixed_depth(cls, n, depth, overlap, rand):
    """Yields ``n`` random segment tuples, each exactly ``depth`` long.

    Args:
      n: number of tuples to yield.
      depth: length of every yielded tuple.
      overlap: threshold compared against ``rand.random()``; when the draw
        does not exceed it, an already used component is re-used at this
        level instead of minting a new one.
      rand: a ``random.Random``-like source of randomness.
    """
    if depth == 0:
      yield from ((), ) * n
    else:
      seen = []

      def to_string(ix):
        # 'a'..'z' for the first 26 components, then 'z26', 'z27', ...
        return chr(ord('a') + ix) if ix < 26 else f'z{ix}'

      for suffix in cls.random_segments_fixed_depth(n, depth - 1, overlap,
                                                    rand):
        # Short-circuit matters: no random draw is made while ``seen`` is
        # empty, which keeps the sequence reproducible for a given seed.
        if not seen or rand.random() > overlap:
          prefix = to_string(len(seen))
          seen.append(prefix)
        else:
          prefix = rand.choice(seen)
        yield (prefix, ) + suffix

  @classmethod
  def random_segments(cls, n, min_depth, max_depth, overlap, rand):
    """Yields ``n`` random segment tuples of depth min_depth..max_depth.

    Tuples are generated at ``max_depth`` and then cut to lengths that
    cycle through ``min_depth``..``max_depth`` inclusive.
    """
    for depth, segments in zip(
        itertools.cycle(range(min_depth, max_depth + 1)),
        cls.random_segments_fixed_depth(n, max_depth, overlap, rand)):
      yield segments[:depth]

  def assert_covers(self, node, expected, max_truncated=0):
    """Asserts that the flattened ``node`` covers ``expected``."""
    self.assert_covers_flattened(node.flattened(), expected, max_truncated)

  def assert_covers_flattened(self, flattened, expected, max_truncated=0):
    """Asserts that ``flattened`` represents exactly the ``expected`` paths.

    Args:
      flattened: iterable of tuples ``(*segments, is_truncated)``.
      expected: iterable of segment tuples that must be covered.
      max_truncated: maximum number of truncated entries permitted.

    Every non-truncated entry must be in ``expected``; every expected entry
    that is not listed exactly must have exactly one proper prefix among the
    truncated entries; and every truncated entry must cover something.
    """
    expected = set(expected)
    # Split node into the exact and truncated segments.
    partitioned = {True: set(), False: set()}
    for segments in flattened:
      partitioned[segments[-1]].add(segments[:-1])
    exact, truncated = partitioned[False], partitioned[True]
    # Check we cover both parts.
    self.assertLessEqual(len(truncated), max_truncated, truncated)
    self.assertTrue(exact.issubset(expected), exact - expected)
    seen_truncated = set()
    for segments in expected - exact:
      found = 0
      # Only proper prefixes (including the empty one) are considered.
      for ix in range(len(segments)):
        if segments[:ix] in truncated:
          seen_truncated.add(segments[:ix])
          found += 1
      if found != 1:
        self.fail(
            f"Expected exactly one prefix of {segments} "
            f"to occur in {truncated}, found {found}")
    self.assertEqual(seen_truncated, truncated, truncated - seen_truncated)

  def run_covers_test(self, flattened, expected, max_truncated):
    """Runs assert_covers_flattened using a compact string notation.

    Each string is one path with one character per segment; a trailing
    ``*`` in ``flattened`` marks the entry as truncated.
    """
    def parse(s):
      return tuple(s.strip('*')) + (s.endswith('*'), )

    self.assert_covers_flattened([parse(s) for s in flattened],
                                 [tuple(s) for s in expected],
                                 max_truncated)

  def test_covers_exact(self):
    """With no truncation allowed, only an exact match is accepted."""
    self.run_covers_test(['ab', 'ac', 'cd'], ['ab', 'ac', 'cd'], 0)

    with self.assertRaises(AssertionError):
      self.run_covers_test(['ab', 'ac', 'cd'], ['ac', 'cd'], 0)
    with self.assertRaises(AssertionError):
      self.run_covers_test(['ab', 'ac'], ['ab', 'ac', 'cd'], 0)
    with self.assertRaises(AssertionError):
      self.run_covers_test(['a*', 'cd'], ['ab', 'ac', 'cd'], 0)

  # NOTE: the method name is misspelled ("trunacted"); it is kept as-is so
  # that existing test IDs stay stable.
  def test_covers_trunacted(self):
    """With one truncation allowed, a single truncated prefix may cover."""
    self.run_covers_test(['a*', 'cd'], ['ab', 'ac', 'cd'], 1)
    self.run_covers_test(['a*', 'cd'], ['ab', 'ac', 'abcde', 'cd'], 1)

    with self.assertRaises(AssertionError):
      self.run_covers_test(['ab', 'ac', 'cd'], ['ac', 'cd'], 1)
    with self.assertRaises(AssertionError):
      self.run_covers_test(['ab', 'ac'], ['ab', 'ac', 'cd'], 1)
    with self.assertRaises(AssertionError):
      self.run_covers_test(['a*', 'c*'], ['ab', 'ac', 'cd'], 1)
    with self.assertRaises(AssertionError):
      self.run_covers_test(['a*', 'c*'], ['ab', 'ac'], 1)

  def run_test(self, to_add):
    """Exercises add, merge and trim on a trie built from ``to_add``.

    Args:
      to_add: iterable of segment tuples; duplicates are removed first.
    """
    everything = list(set(to_add))
    # Proper prefixes of any entry; the assertions below expect such entries
    # not to be reported separately once a longer path has been added.
    all_prefixees = {
        segments[:ix]
        for segments in everything for ix in range(len(segments))
    }
    everything_deduped = set(everything) - all_prefixees

    # Check basic addition.
    node = _BoundedTrieNode()
    total_size = node.size()
    self.assertEqual(total_size, 1)
    for segments in everything:
      total_size += node.add(segments)

    self.assertEqual(node.size(), len(everything_deduped), node)
    self.assertEqual(node.size(), total_size, node)
    self.assert_covers(node, everything_deduped)

    # Check merging.
    node0 = _BoundedTrieNode()
    node0.add_all(everything[0::2])
    node1 = _BoundedTrieNode()
    node1.add_all(everything[1::2])
    pre_merge_size = node0.size()
    merge_delta = node0.merge(node1)
    self.assertEqual(node0.size(), pre_merge_size + merge_delta)
    self.assertEqual(node0, node)

    # Check trimming.
    if node.size() > 1:
      trim_delta = node.trim()
      self.assertLess(trim_delta, 0, node)
      self.assertEqual(node.size(), total_size + trim_delta)
      self.assert_covers(node, everything_deduped, max_truncated=1)

      # A second trim is only attempted after a first one, which is what
      # guarantees trim_delta is bound here.
      if node.size() > 1:
        trim2_delta = node.trim()
        self.assertLess(trim2_delta, 0)
        self.assertEqual(node.size(), total_size + trim_delta + trim2_delta)
        self.assert_covers(node, everything_deduped, max_truncated=2)

    # Adding after trimming should be a no-op.
    node_copy = copy.deepcopy(node)
    for segments in everything:
      self.assertEqual(node.add(segments), 0)
    self.assertEqual(node, node_copy)

    # Merging after trimming should be a no-op.
    self.assertEqual(node.merge(node0), 0)
    self.assertEqual(node.merge(node1), 0)
    self.assertEqual(node, node_copy)

    # A truncated root is expected to absorb new entries without growing;
    # otherwise each of the two new values below adds one.
    expected_delta = 0 if node._truncated else 2

    # Adding something new is not.
    new_values = [('new1', ), ('new2', 'new2.1')]
    self.assertEqual(node.add_all(new_values), expected_delta)
    self.assert_covers(
        node, list(everything_deduped) + new_values, max_truncated=2)

    # Nor is merging something new.
    new_values_node = _BoundedTrieNode()
    new_values_node.add_all(new_values)
    self.assertEqual(node_copy.merge(new_values_node), expected_delta)
    self.assert_covers(
        node_copy, list(everything_deduped) + new_values, max_truncated=2)

  def run_fuzz(self, iterations=10, **params):
    """Runs run_test on ``iterations`` randomly generated inputs.

    Args:
      iterations: number of random inputs to try.
      **params: forwarded to random_segments (n, min_depth, max_depth,
        overlap).

    On failure the seed is printed before re-raising so that the failing
    input can be regenerated.
    """
    for _ in range(iterations):
      seed = random.getrandbits(64)
      segments = self.random_segments(**params, rand=random.Random(seed))
      try:
        self.run_test(segments)
      except Exception:
        print("SEED", seed)
        raise

  def test_trivial(self):
    """Two paths sharing a first segment."""
    self.run_test([('a', 'b'), ('a', 'c')])

  def test_flat(self):
    """Three paths with no shared segments."""
    self.run_test([('a', 'a'), ('b', 'b'), ('c', 'c')])

  def test_deep(self):
    """Two long, disjoint paths."""
    self.run_test([('a', ) * 10, ('b', ) * 12])

  def test_small(self):
    """Fuzz with a handful of short paths."""
    self.run_fuzz(n=5, min_depth=2, max_depth=3, overlap=0.5)

  def test_medium(self):
    """Fuzz with a moderate number of paths."""
    self.run_fuzz(n=20, min_depth=2, max_depth=4, overlap=0.5)

  def test_large_sparse(self):
    """Fuzz with many paths and little prefix sharing."""
    self.run_fuzz(n=120, min_depth=2, max_depth=4, overlap=0.2)

  def test_large_dense(self):
    """Fuzz with many paths and heavy prefix sharing."""
    self.run_fuzz(n=120, min_depth=2, max_depth=4, overlap=0.8)

  def test_bounded_trie_data_combine(self):
    """Combining empty, singleton and rooted BoundedTrieData instances."""
    empty = BoundedTrieData()
    # The merging here isn't complicated; we're just ensuring that
    # BoundedTrieData invokes _BoundedTrieNode correctly.
    singletonA = BoundedTrieData(singleton=('a', 'a'))
    singletonB = BoundedTrieData(singleton=('b', 'b'))
    lots_root = _BoundedTrieNode()
    lots_root.add_all([('c', 'c'), ('d', 'd')])
    lots = BoundedTrieData(root=lots_root)

    self.assertEqual(empty.get_result(), set())
    self.assertEqual(
        empty.combine(singletonA).get_result(), {('a', 'a', False)})
    self.assertEqual(
        singletonA.combine(empty).get_result(), {('a', 'a', False)})
    self.assertEqual(
        singletonA.combine(singletonB).get_result(),
        {('a', 'a', False), ('b', 'b', False)})
    self.assertEqual(
        singletonA.combine(lots).get_result(),
        {('a', 'a', False), ('c', 'c', False), ('d', 'd', False)})
    self.assertEqual(
        lots.combine(singletonA).get_result(),
        {('a', 'a', False), ('c', 'c', False), ('d', 'd', False)})

  def test_bounded_trie_data_combine_trim(self):
    """Combining past the bound truncates the shared 'a' subtree."""
    left = _BoundedTrieNode()
    left.add_all([('a', 'x'), ('b', 'd')])
    right = _BoundedTrieNode()
    right.add_all([('a', 'y'), ('c', 'd')])
    self.assertEqual(
        BoundedTrieData(root=left).combine(
            BoundedTrieData(root=right, bound=3)).get_result(),
        {('a', True), ('b', 'd', False), ('c', 'd', False)})

  def test_merge_on_empty_node(self):
    """Merging a populated trie into an empty one copies its contents."""
    root1 = _BoundedTrieNode()
    root2 = _BoundedTrieNode()
    root2.add_all([["a", "b", "c"], ["a", "b", "d"], ["a", "e"]])
    # An empty node already reports size 1, so the delta to size 3 is 2.
    self.assertEqual(2, root1.merge(root2))
    self.assertEqual(3, root1.size())
    self.assertFalse(root1._truncated)

  def test_merge_with_empty_node(self):
    """Merging an empty trie into a populated one changes nothing."""
    root1 = _BoundedTrieNode()
    root1.add_all([["a", "b", "c"], ["a", "b", "d"], ["a", "e"]])
    root2 = _BoundedTrieNode()

    self.assertEqual(0, root1.merge(root2))
    self.assertEqual(3, root1.size())
    self.assertFalse(root1._truncated)


class TestHistogramCell(unittest.TestCase):
  """Tests for HistogramCell and HistogramCellFactory."""

  NUM_THREADS = 5
  NUM_ITERATIONS = 100

  @classmethod
  def _modify_histogram(cls, d):
    """Updates histogram cell ``d`` with each of 0..NUM_ITERATIONS-1."""
    for i in range(cls.NUM_ITERATIONS):
      d.update(i)

  def test_parallel_access(self):
    """Concurrent updates from several threads must all be recorded."""
    # We create NUM_THREADS threads that concurrently modify the histogram.
    threads = []
    bucket_type = LinearBucket(0, 1, 100)
    d = HistogramCell(bucket_type)
    for _ in range(TestHistogramCell.NUM_THREADS):
      t = threading.Thread(
          target=TestHistogramCell._modify_histogram, args=(d, ))
      threads.append(t)
      t.start()

    for t in threads:
      t.join()

    # Build the expected histogram by recording the same values serially.
    histogram = Histogram(bucket_type)
    for _ in range(self.NUM_THREADS):
      for i in range(self.NUM_ITERATIONS):
        histogram.record(i)

    self.assertEqual(d.get_cumulative(), HistogramData(histogram))

  def test_basic_operations(self):
    """The string form reflects count and percentiles after each update."""
    d = HistogramCellFactory(LinearBucket(0, 1, 10))()
    d.update(10)
    self.assertEqual(
        str(d.get_cumulative()),
        'HistogramData(Total count: 1, P99: >=10, P90: >=10, P50: >=10)')
    d.update(0)
    self.assertEqual(
        str(d.get_cumulative()),
        'HistogramData(Total count: 2, P99: >=10, P90: >=10, P50: 1)')
    d.update(5)
    self.assertEqual(
        str(d.get_cumulative()),
        'HistogramData(Total count: 3, P99: >=10, P90: >=10, P50: 6)')


if __name__ == '__main__':
  unittest.main()