diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index cf1107c45a..e7b00da71c 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -1,7 +1,7 @@ import doctest import unittest from test import support -from test.support import threading_helper +from test.support import threading_helper, script_helper from itertools import * import weakref from decimal import Decimal @@ -15,6 +15,26 @@ import struct import threading import gc +import warnings + +def pickle_deprecated(testfunc): + """ Run the test three times. + First, verify that a Deprecation Warning is raised. + Second, run normally but with DeprecationWarnings temporarily disabled. + Third, run with warnings promoted to errors. + """ + def inner(self): + with self.assertWarns(DeprecationWarning): + testfunc(self) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=DeprecationWarning) + testfunc(self) + with warnings.catch_warnings(): + warnings.simplefilter("error", category=DeprecationWarning) + with self.assertRaises((DeprecationWarning, AssertionError, SystemError)): + testfunc(self) + + return inner maxsize = support.MAX_Py_ssize_t minsize = -maxsize-1 @@ -126,6 +146,7 @@ def expand(it, i=0): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_accumulate(self): self.assertEqual(list(accumulate(range(10))), # one positional arg [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]) @@ -161,6 +182,44 @@ def test_accumulate(self): with self.assertRaises(TypeError): list(accumulate([10, 20], 100)) + def test_batched(self): + self.assertEqual(list(batched('ABCDEFG', 3)), + [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]) + self.assertEqual(list(batched('ABCDEFG', 2)), + [('A', 'B'), ('C', 'D'), ('E', 'F'), ('G',)]) + self.assertEqual(list(batched('ABCDEFG', 1)), + [('A',), ('B',), ('C',), ('D',), ('E',), ('F',), ('G',)]) + + with self.assertRaises(TypeError): # Too few arguments + list(batched('ABCDEFG')) + with self.assertRaises(TypeError): + list(batched('ABCDEFG', 3, None)) # Too many arguments + with self.assertRaises(TypeError): + list(batched(None, 3)) # Non-iterable input + with self.assertRaises(TypeError): + list(batched('ABCDEFG', 'hello')) # n is a string + with self.assertRaises(ValueError): + list(batched('ABCDEFG', 0)) # n is zero + with self.assertRaises(ValueError): + list(batched('ABCDEFG', -1)) # n is negative + + data = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + for n in range(1, 6): + for i in range(len(data)): + s = data[:i] + batches = list(batched(s, n)) + with self.subTest(s=s, n=n, batches=batches): + # Order is preserved and no data is lost + self.assertEqual(''.join(chain(*batches)), s) + # Each batch is an exact tuple + self.assertTrue(all(type(batch) is tuple for batch in batches)) + # All but the last batch is of size n + if batches: + last_batch = batches.pop() + self.assertTrue(all(len(batch) == n for batch in batches)) + self.assertTrue(len(last_batch) <= n) + batches.append(last_batch) + def test_chain(self): def chain2(*iterables): @@ -182,7 +241,11 @@ def test_chain_from_iterable(self): self.assertEqual(list(chain.from_iterable([''])), []) self.assertEqual(take(4, chain.from_iterable(['abc', 'def'])), list('abcd')) self.assertRaises(TypeError, list, chain.from_iterable([2, 3])) + self.assertEqual(list(islice(chain.from_iterable(repeat(range(5))), 2)), [0, 1]) + # TODO: RUSTPYTHON + @unittest.expectedFailure + @pickle_deprecated def test_chain_reducible(self): for oper in [copy.deepcopy] + picklecopiers: it = chain('abc', 'def') @@ -196,6 +259,9 @@ def test_chain_reducible(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, chain('abc', 'def'), compare=list('abcdef')) + # TODO: RUSTPYTHON + @unittest.expectedFailure + @pickle_deprecated def test_chain_setstate(self): self.assertRaises(TypeError, chain().__setstate__, ()) self.assertRaises(TypeError, chain().__setstate__, []) @@ -211,6 +277,7 @@ def test_chain_setstate(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_combinations(self): self.assertRaises(TypeError, combinations, 'abc') # missing r argument self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments @@ -234,7 +301,6 @@ def test_combinations(self): self.assertEqual(list(op(testIntermediate)), [(0,1,3), (0,2,3), (1,2,3)]) - def combinations1(iterable, r): 'Pure python version shown in the docs' pool = tuple(iterable) @@ -304,6 +370,7 @@ def test_combinations_tuple_reuse(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_combinations_with_replacement(self): cwr = combinations_with_replacement self.assertRaises(TypeError, cwr, 'abc') # missing r argument @@ -394,6 +461,7 @@ def test_combinations_with_replacement_tuple_reuse(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_permutations(self): self.assertRaises(TypeError, permutations) # too few arguments self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments @@ -500,6 +568,9 @@ def test_combinatorics(self): self.assertEqual(comb, list(filter(set(perm).__contains__, cwr))) # comb: cwr that is a perm self.assertEqual(comb, sorted(set(cwr) & set(perm))) # comb: both a cwr and a perm + # TODO: RUSTPYTHON + @unittest.expectedFailure + @pickle_deprecated def test_compress(self): self.assertEqual(list(compress(data='ABCDEF', selectors=[1,0,1,0,1,1])), list('ACEF')) self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF')) @@ -533,7 +604,9 @@ def test_compress(self): next(testIntermediate) self.assertEqual(list(op(testIntermediate)), list(result2)) - + # TODO: RUSTPYTHON + @unittest.expectedFailure + @pickle_deprecated def test_count(self): self.assertEqual(lzip('abc',count()), [('a', 0), ('b', 1), ('c', 2)]) self.assertEqual(lzip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)]) @@ -584,6 +657,7 @@ def test_count(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_count_with_stride(self): self.assertEqual(lzip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)]) self.assertEqual(lzip('abc',count(start=2,step=3)), @@ -648,6 +722,7 @@ def test_cycle(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_cycle_copy_pickle(self): # check copy, deepcopy, pickle c = cycle('abc') @@ -686,6 +761,7 @@ def test_cycle_copy_pickle(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_cycle_unpickle_compat(self): testcases = [ b'citertools\ncycle\n(c__builtin__\niter\n((lI1\naI2\naI3\natRI1\nbtR((lI1\naI0\ntb.', @@ -719,6 +795,7 @@ def test_cycle_unpickle_compat(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_cycle_setstate(self): # Verify both modes for restoring state @@ -757,6 +834,7 @@ def test_cycle_setstate(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_groupby(self): # Check whether it accepts arguments correctly self.assertEqual([], list(groupby([]))) @@ -914,6 +992,9 @@ def test_filter(self): c = filter(isEven, range(6)) self.pickletest(proto, c) + # TODO: RUSTPYTHON + @unittest.expectedFailure + @pickle_deprecated def test_filterfalse(self): self.assertEqual(list(filterfalse(isEven, range(6))), [1,3,5]) self.assertEqual(list(filterfalse(None, [0,1,0,2,0])), [0,0,0]) @@ -944,6 +1025,7 @@ def test_zip(self): lzip('abc', 'def')) @support.impl_detail("tuple reuse is specific to CPython") + @pickle_deprecated def test_zip_tuple_reuse(self): ids = list(map(id, zip('abc', 'def'))) self.assertEqual(min(ids), max(ids)) @@ -1019,6 +1101,9 @@ def test_zip_longest_tuple_reuse(self): ids = list(map(id, list(zip_longest('abc', 'def')))) self.assertEqual(len(dict.fromkeys(ids)), len(ids)) + # TODO: RUSTPYTHON + @unittest.expectedFailure + @pickle_deprecated def test_zip_longest_pickling(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, zip_longest("abc", "def")) @@ -1097,6 +1182,82 @@ def test_pairwise(self): with self.assertRaises(TypeError): pairwise(None) # non-iterable argument + # TODO: RUSTPYTHON + @unittest.skip("TODO: RUSTPYTHON, hangs") + def test_pairwise_reenter(self): + def check(reenter_at, expected): + class I: + count = 0 + def __iter__(self): + return self + def __next__(self): + self.count +=1 + if self.count in reenter_at: + return next(it) + return [self.count] # new object + + it = pairwise(I()) + for item in expected: + self.assertEqual(next(it), item) + + check({1}, [ + (([2], [3]), [4]), + ([4], [5]), + ]) + check({2}, [ + ([1], ([1], [3])), + (([1], [3]), [4]), + ([4], [5]), + ]) + check({3}, [ + ([1], [2]), + ([2], ([2], [4])), + (([2], [4]), [5]), + ([5], [6]), + ]) + check({1, 2}, [ + ((([3], [4]), [5]), [6]), + ([6], [7]), + ]) + check({1, 3}, [ + (([2], ([2], [4])), [5]), + ([5], [6]), + ]) + check({1, 4}, [ + (([2], [3]), (([2], [3]), [5])), + ((([2], [3]), [5]), [6]), + ([6], [7]), + ]) + check({2, 3}, [ + ([1], ([1], ([1], [4]))), + (([1], ([1], [4])), [5]), + ([5], [6]), + ]) + + # TODO: RUSTPYTHON + @unittest.skip("TODO: RUSTPYTHON, hangs") + def test_pairwise_reenter2(self): + def check(maxcount, expected): + class I: + count = 0 + def __iter__(self): + return self + def __next__(self): + if self.count >= maxcount: + raise StopIteration + self.count +=1 + if self.count == 1: + return next(it, None) + return [self.count] # new object + + it = pairwise(I()) + self.assertEqual(list(it), expected) + + check(1, []) + check(2, []) + check(3, []) + check(4, [(([2], [3]), [4])]) + def test_product(self): for args, result in [ ([], [()]), # zero iterables @@ -1165,7 +1326,9 @@ def test_product_tuple_reuse(self): self.assertEqual(len(set(map(id, product('abc', 'def')))), 1) self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1) - + # TODO: RUSTPYTHON + @unittest.expectedFailure + @pickle_deprecated def test_product_pickling(self): # check copy, deepcopy, pickle for args, result in [ @@ -1183,6 +1346,7 @@ def test_product_pickling(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_product_issue_25021(self): # test that indices are properly clamped to the length of the tuples p = product((1, 2),(3,)) @@ -1193,6 +1357,9 @@ def test_product_issue_25021(self): p.__setstate__((0, 0, 0x1000)) # will access tuple element 1 if not clamped self.assertRaises(StopIteration, next, p) + # TODO: RUSTPYTHON + @unittest.expectedFailure + @pickle_deprecated def test_repeat(self): self.assertEqual(list(repeat(object='a', times=3)), ['a', 'a', 'a']) self.assertEqual(lzip(range(3),repeat('a')), @@ -1227,6 +1394,7 @@ def test_repeat_with_negative_times(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_map(self): self.assertEqual(list(map(operator.pow, range(3), range(1,7))), [0**1, 1**2, 2**3]) @@ -1257,6 +1425,9 @@ def test_map(self): c = map(tupleize, 'abc', count()) self.pickletest(proto, c) + # TODO: RUSTPYTHON + @unittest.expectedFailure + @pickle_deprecated def test_starmap(self): self.assertEqual(list(starmap(operator.pow, zip(range(3), range(1,7)))), [0**1, 1**2, 2**3]) @@ -1286,6 +1457,7 @@ def test_starmap(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_islice(self): for args in [ # islice(args) should agree with range(args) (10, 20, 3), @@ -1380,6 +1552,9 @@ def __index__(self): self.assertEqual(list(islice(range(100), IntLike(10), IntLike(50), IntLike(5))), list(range(10,50,5))) + # TODO: RUSTPYTHON + @unittest.expectedFailure + @pickle_deprecated def test_takewhile(self): data = [1, 3, 5, 20, 2, 4, 6, 8] self.assertEqual(list(takewhile(underten, data)), [1, 3, 5]) @@ -1402,6 +1577,7 @@ def test_takewhile(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_dropwhile(self): data = [1, 3, 5, 20, 2, 4, 6, 8] self.assertEqual(list(dropwhile(underten, data)), [20, 2, 4, 6, 8]) @@ -1421,6 +1597,7 @@ def test_dropwhile(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_tee(self): n = 200 @@ -1570,6 +1747,14 @@ def test_tee(self): self.pickletest(proto, a, compare=ans) self.pickletest(proto, b, compare=ans) + def test_tee_dealloc_segfault(self): + # gh-115874: segfaults when accessing module state in tp_dealloc. + script = ( + "import typing, copyreg, itertools; " + "copyreg.buggy_tee = itertools.tee(())" + ) + script_helper.assert_python_ok("-c", script) + # Issue 13454: Crash when deleting backward iterator from tee() def test_tee_del_backward(self): forward, backward = tee(repeat(None, 20000000)) @@ -1687,12 +1872,47 @@ def test_zip_longest_result_gc(self): gc.collect() self.assertTrue(gc.is_tracked(next(it))) + @support.cpython_only + def test_immutable_types(self): + from itertools import _grouper, _tee, _tee_dataobject + dataset = ( + accumulate, + batched, + chain, + combinations, + combinations_with_replacement, + compress, + count, + cycle, + dropwhile, + filterfalse, + groupby, + _grouper, + islice, + pairwise, + permutations, + product, + repeat, + starmap, + takewhile, + _tee, + _tee_dataobject, + zip_longest, + ) + for tp in dataset: + with self.subTest(tp=tp): + with self.assertRaisesRegex(TypeError, "immutable"): + tp.foobar = 1 + class TestExamples(unittest.TestCase): def test_accumulate(self): self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15]) + # TODO: RUSTPYTHON + @unittest.expectedFailure + @pickle_deprecated def test_accumulate_reducible(self): # check copy, deepcopy, pickle data = [1, 2, 3, 4, 5] @@ -1710,6 +1930,7 @@ def test_accumulate_reducible(self): # TODO: RUSTPYTHON @unittest.expectedFailure + @pickle_deprecated def test_accumulate_reducible_none(self): # Issue #25718: total is None it = accumulate([None, None, None], operator.is_) @@ -1802,6 +2023,31 @@ def test_takewhile(self): class TestPurePythonRoughEquivalents(unittest.TestCase): + def test_batched_recipe(self): + def batched_recipe(iterable, n): + "Batch data into tuples of length n. The last batch may be shorter." + # batched('ABCDEFG', 3) --> ABC DEF G + if n < 1: + raise ValueError('n must be at least one') + it = iter(iterable) + while batch := tuple(islice(it, n)): + yield batch + + for iterable, n in product( + ['', 'a', 'ab', 'abc', 'abcd', 'abcde', 'abcdef', 'abcdefg', None], + [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, None]): + with self.subTest(iterable=iterable, n=n): + try: + e1, r1 = None, list(batched(iterable, n)) + except Exception as e: + e1, r1 = type(e), None + try: + e2, r2 = None, list(batched_recipe(iterable, n)) + except Exception as e: + e2, r2 = type(e), None + self.assertEqual(r1, r2) + self.assertEqual(e1, e2) + @staticmethod def islice(iterable, *args): s = slice(*args) @@ -1853,6 +2099,10 @@ def test_accumulate(self): a = [] self.makecycle(accumulate([1,2,a,3]), a) + def test_batched(self): + a = [] + self.makecycle(batched([1,2,a,3], 2), a) + def test_chain(self): a = [] self.makecycle(chain(a), a) @@ -2010,6 +2260,20 @@ def __iter__(self): def __next__(self): 3 // 0 +class E2: + 'Test propagation of exceptions after two iterations' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + def __next__(self): + if self.i == 2: + raise ZeroDivisionError + v = self.seqn[self.i] + self.i += 1 + return v + class S: 'Test immediate stop' def __init__(self, seqn): @@ -2037,6 +2301,19 @@ def test_accumulate(self): self.assertRaises(TypeError, accumulate, N(s)) self.assertRaises(ZeroDivisionError, list, accumulate(E(s))) + def test_batched(self): + s = 'abcde' + r = [('a', 'b'), ('c', 'd'), ('e',)] + n = 2 + for g in (G, I, Ig, L, R): + with self.subTest(g=g): + self.assertEqual(list(batched(g(s), n)), r) + self.assertEqual(list(batched(S(s), 2)), []) + self.assertRaises(TypeError, batched, X(s), 2) + self.assertRaises(TypeError, batched, N(s), 2) + self.assertRaises(ZeroDivisionError, list, batched(E(s), 2)) + self.assertRaises(ZeroDivisionError, list, batched(E2(s), 4)) + def test_chain(self): for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)): for g in (G, I, Ig, S, L, R): @@ -2263,6 +2540,7 @@ def gen2(x): self.assertEqual(hist, [0,1]) @support.skip_if_pgo_task + @support.requires_resource('cpu') def test_long_chain_of_empty_iterables(self): # Make sure itertools.chain doesn't run into recursion limits when # dealing with long chains of empty iterables. Even with a high @@ -2296,7 +2574,6 @@ def __eq__(self, other): class SubclassWithKwargsTest(unittest.TestCase): - # TODO: RUSTPYTHON @unittest.expectedFailure def test_keywords_in_subclass(self): diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 3bc26c0a8f..08a3469191 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -22,6 +22,9 @@ mod decl { VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; + use malachite_bigint::BigInt; + use num_traits::One; + use num_traits::{Signed, ToPrimitive}; use std::fmt; @@ -1952,4 +1955,78 @@ mod decl { Ok(PyIterReturn::Return(vm.new_tuple((old, new)).into())) } } + + #[pyattr] + #[pyclass(name = "batched")] + #[derive(Debug, PyPayload)] + struct PyItertoolsBatched { + exhausted: AtomicCell, + iterable: PyIter, + n: AtomicCell, + } + + #[derive(FromArgs)] + struct BatchedNewArgs { + #[pyarg(positional)] + iterable_ref: PyObjectRef, + #[pyarg(positional)] + n: PyIntRef, + } + + impl Constructor for PyItertoolsBatched { + type Args = BatchedNewArgs; + + fn py_new( + cls: PyTypeRef, + Self::Args { iterable_ref, n }: Self::Args, + vm: &VirtualMachine, + ) -> PyResult { + let n = n.as_bigint(); + if n.lt(&BigInt::one()) { + return Err(vm.new_value_error("n must be at least one".to_owned())); + } + let n = n.to_usize().ok_or( + vm.new_overflow_error("Python int too large to convert to usize".to_owned()), + )?; + let iterable = iterable_ref.get_iter(vm)?; + + Self { + iterable, + n: AtomicCell::new(n), + exhausted: AtomicCell::new(false), + } + .into_ref_with_type(vm, cls) + .map(Into::into) + } + } + + #[pyclass(with(IterNext, Iterable, Constructor), flags(BASETYPE, HAS_DICT))] + impl PyItertoolsBatched {} + + impl SelfIter for PyItertoolsBatched {} + + impl IterNext for PyItertoolsBatched { + fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { + if zelf.exhausted.load() { + return Ok(PyIterReturn::StopIteration(None)); + } + let mut result: Vec = Vec::new(); + let n = zelf.n.load(); + for _ in 0..n { + match zelf.iterable.next(vm)? { + PyIterReturn::Return(obj) => { + result.push(obj); + } + PyIterReturn::StopIteration(_) => { + zelf.exhausted.store(true); + break; + } + } + } + match result.len() { + 0 => Ok(PyIterReturn::StopIteration(None)), + _ => Ok(PyIterReturn::Return(vm.ctx.new_tuple(result).into())), + } + } + } } diff --git a/vm/src/stdlib/sre.rs b/vm/src/stdlib/sre.rs index 17a84ac150..93ea46ea87 100644 --- a/vm/src/stdlib/sre.rs +++ b/vm/src/stdlib/sre.rs @@ -471,7 +471,6 @@ mod _sre { } }; - last_pos = iter.state.cursor.position; n += 1; }