diff --git a/adafruit_itertools/adafruit_itertools_extras.py b/adafruit_itertools/adafruit_itertools_extras.py index 8a41038..a435bbd 100644 --- a/adafruit_itertools/adafruit_itertools_extras.py +++ b/adafruit_itertools/adafruit_itertools_extras.py @@ -41,26 +41,54 @@ import adafruit_itertools as it +try: + from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + ) + from typing_extensions import TypeAlias + + _T = TypeVar("_T") + _N: TypeAlias = Union[int, float, complex] + _Predicate: TypeAlias = Callable[[_T], bool] +except ImportError: + pass + + __version__ = "0.0.0+auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Itertools.git" -def all_equal(iterable): +def all_equal(iterable: Iterable[Any]) -> bool: """Returns True if all the elements are equal to each other. :param iterable: source of values """ g = it.groupby(iterable) - next(g) # should succeed, value isn't relevant try: - next(g) # should fail: only 1 group + next(g) # value isn't relevant + except StopIteration: + # Empty iterable, return True to match cpython behavior. + return True + try: + next(g) + # more than one group, so we have different elements. return False except StopIteration: + # Only one group - all elements must be equal. return True -def dotproduct(vec1, vec2): +def dotproduct(vec1: Iterable[_N], vec2: Iterable[_N]) -> _N: """Compute the dot product of two vectors. :param vec1: the first vector @@ -71,7 +99,11 @@ def dotproduct(vec1, vec2): return sum(map(lambda x, y: x * y, vec1, vec2)) -def first_true(iterable, default=False, pred=None): +def first_true( + iterable: Iterable[_T], + default: Union[bool, _T] = False, + pred: Optional[_Predicate[_T]] = None, +) -> Union[bool, _T]: """Returns the first true value in the iterable. If no true value is found, returns *default* @@ -94,7 +126,7 @@ def first_true(iterable, default=False, pred=None): return default -def flatten(iterable_of_iterables): +def flatten(iterable_of_iterables: Iterable[Iterable[_T]]) -> Iterator[_T]: """Flatten one level of nesting. :param iterable_of_iterables: a sequence of iterables to flatten @@ -104,7 +136,9 @@ def flatten(iterable_of_iterables): return it.chain_from_iterable(iterable_of_iterables) -def grouper(iterable, n, fillvalue=None): +def grouper( + iterable: Iterable[_T], n: int, fillvalue: Optional[_T] = None +) -> Iterator[Tuple[_T, ...]]: """Collect data into fixed-length chunks or blocks. :param iterable: source of values @@ -118,7 +152,7 @@ def grouper(iterable, n, fillvalue=None): return it.zip_longest(*args, fillvalue=fillvalue) -def iter_except(func, exception): +def iter_except(func: Callable[[], _T], exception: Type[BaseException]) -> Iterator[_T]: """Call a function repeatedly, yielding the results, until exception is raised. Converts a call-until-exception interface to an iterator interface. @@ -143,7 +177,7 @@ def iter_except(func, exception): pass -def ncycles(iterable, n): +def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: """Returns the sequence elements a number of times. :param iterable: the source of values @@ -153,7 +187,7 @@ def ncycles(iterable, n): return it.chain_from_iterable(it.repeat(tuple(iterable), n)) -def nth(iterable, n, default=None): +def nth(iterable: Iterable[_T], n: int, default: Optional[_T] = None) -> Optional[_T]: """Returns the nth item or a default value. :param iterable: the source of values @@ -166,7 +200,7 @@ def nth(iterable, n, default=None): return default -def padnone(iterable): +def padnone(iterable: Iterable[_T]) -> Iterator[Optional[_T]]: """Returns the sequence elements and then returns None indefinitely. Useful for emulating the behavior of the built-in map() function. @@ -177,13 +211,17 @@ def padnone(iterable): return it.chain(iterable, it.repeat(None)) -def pairwise(iterable): - """Pair up valuesin the iterable. +def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]: + """Return successive overlapping pairs from the iterable. + + The number of tuples from the output will be one fewer than the + number of values in the input. It will be empty if the input has + fewer than two values. :param iterable: source of values """ - # pairwise(range(11)) -> (1, 2), (3, 4), (5, 6), (7, 8), (9, 10) + # pairwise(range(5)) -> (0, 1), (1, 2), (2, 3), (3, 4) a, b = it.tee(iterable) try: next(b) @@ -192,7 +230,9 @@ def pairwise(iterable): return zip(a, b) -def partition(pred, iterable): +def partition( + pred: _Predicate[_T], iterable: Iterable[_T] +) -> Tuple[Iterator[_T], Iterator[_T]]: """Use a predicate to partition entries into false entries and true entries. :param pred: the predicate that divides the values @@ -204,7 +244,7 @@ def partition(pred, iterable): return it.filterfalse(pred, t1), filter(pred, t2) -def prepend(value, iterator): +def prepend(value: _T, iterator: Iterable[_T]) -> Iterator[_T]: """Prepend a single value in front of an iterator :param value: the value to prepend @@ -215,7 +255,7 @@ def prepend(value, iterator): return it.chain([value], iterator) -def quantify(iterable, pred=bool): +def quantify(iterable: Iterable[_T], pred: _Predicate[_T] = bool) -> int: """Count how many times the predicate is true. :param iterable: source of values @@ -227,7 +267,9 @@ def quantify(iterable, pred=bool): return sum(map(pred, iterable)) -def repeatfunc(func, times=None, *args): +def repeatfunc( + func: Callable[..., _T], times: Optional[int] = None, *args: Any +) -> Iterator[_T]: """Repeat calls to func with specified arguments. Example: repeatfunc(random.random) @@ -242,7 +284,7 @@ def repeatfunc(func, times=None, *args): return it.starmap(func, it.repeat(args, times)) -def roundrobin(*iterables): +def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: """Return an iterable created by repeatedly picking value from each argument in order. @@ -263,18 +305,19 @@ def roundrobin(*iterables): nexts = it.cycle(it.islice(nexts, num_active)) -def tabulate(function, start=0): - """Apply a function to a sequence of consecutive integers. +def tabulate(function: Callable[[int], int], start: int = 0) -> Iterator[int]: + """Apply a function to a sequence of consecutive numbers. - :param function: the function of one integer argument + :param function: the function of one numeric argument. :param start: optional value to start at (default is 0) """ # take(5, tabulate(lambda x: x * x))) -> 0 1 4 9 16 - return map(function, it.count(start)) + counter: Iterator[int] = it.count(start) # type: ignore[assignment] + return map(function, counter) -def tail(n, iterable): +def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: """Return an iterator over the last n items :param n: how many values to return @@ -294,7 +337,7 @@ def tail(n, iterable): return iter(buf) -def take(n, iterable): +def take(n: int, iterable: Iterable[_T]) -> List[_T]: """Return first n items of the iterable as a list :param n: how many values to take diff --git a/optional_requirements.txt b/optional_requirements.txt index d4e27c4..1856c06 100644 --- a/optional_requirements.txt +++ b/optional_requirements.txt @@ -1,3 +1,6 @@ # SPDX-FileCopyrightText: 2022 Alec Delaney, for Adafruit Industries # # SPDX-License-Identifier: Unlicense + +# For comparison when running tests +more-itertools diff --git a/tests/test_itertools_extras.py b/tests/test_itertools_extras.py new file mode 100644 index 0000000..fad8786 --- /dev/null +++ b/tests/test_itertools_extras.py @@ -0,0 +1,285 @@ +# SPDX-FileCopyrightText: KB Sriram +# SPDX-License-Identifier: MIT + +from typing import ( + Callable, + Iterator, + Optional, + Sequence, + TypeVar, +) +from typing_extensions import TypeAlias + +import more_itertools as itextras +import pytest +from adafruit_itertools import adafruit_itertools_extras as aextras + +_K = TypeVar("_K") +_T = TypeVar("_T") +_S = TypeVar("_S") +_Predicate: TypeAlias = Callable[[_T], bool] + + +def _take(n: int, iterator: Iterator[_T]) -> Sequence[_T]: + """Extract the first n elements from a long/infinite iterator.""" + return [v for _, v in zip(range(n), iterator)] + + +@pytest.mark.parametrize( + "data", + [ + "aaaa", + "abcd", + "a", + "", + (1, 2), + (3, 3), + ("", False), + (42, True), + ], +) +def test_all_equal(data: Sequence[_T]) -> None: + assert itextras.all_equal(data) == aextras.all_equal(data) + + +@pytest.mark.parametrize( + ("vec1", "vec2"), + [ + ([1, 2], [3, 4]), + ([], []), + ([1], [2, 3]), + ([4, 5], [6]), + ], +) +def test_dotproduct(vec1: Sequence[int], vec2: Sequence[int]) -> None: + assert itextras.dotproduct(vec1, vec2) == aextras.dotproduct(vec1, vec2) + + +@pytest.mark.parametrize( + ("seq", "dflt", "pred"), + [ + ([0, 2], 0, None), + ([], 10, None), + ([False], True, None), + ([1, 2], -1, lambda _: False), + ([0, 1], -1, lambda _: True), + ([], -1, lambda _: True), + ], +) +def test_first_true( + seq: Sequence[_T], dflt: _T, pred: Optional[_Predicate[_T]] +) -> None: + assert itextras.first_true(seq, dflt, pred) == aextras.first_true(seq, dflt, pred) + + +@pytest.mark.parametrize( + ("seq1", "seq2"), + [ + ("abc", "def"), + ("", "def"), + ("abc", ""), + ("", ""), + ], +) +def test_flatten(seq1: str, seq2: str) -> None: + assert list(itextras.flatten(seq1 + seq2)) == list(aextras.flatten(seq1 + seq2)) + for repeat in range(3): + assert list(itextras.flatten([seq1] * repeat)) == list( + aextras.flatten([seq1] * repeat) + ) + assert list(itextras.flatten([seq2] * repeat)) == list( + aextras.flatten([seq2] * repeat) + ) + + +@pytest.mark.parametrize( + ("seq", "count", "fill"), + [ + ("abc", 3, None), + ("abcd", 3, None), + ("abc", 3, "x"), + ("abcd", 3, "x"), + ("abc", 0, None), + ("", 3, "xy"), + ], +) +def test_grouper(seq: Sequence[str], count: int, fill: Optional[str]) -> None: + assert list(itextras.grouper(seq, count, fillvalue=fill)) == list( + aextras.grouper(seq, count, fillvalue=fill) + ) + + +@pytest.mark.parametrize( + ("data"), + [ + (1, 2, 3), + (), + ], +) +def test_iter_except(data: Sequence[int]) -> None: + assert list(itextras.iter_except(list(data).pop, IndexError)) == list( + aextras.iter_except(list(data).pop, IndexError) + ) + + +@pytest.mark.parametrize( + ("seq", "count"), + [ + ("abc", 4), + ("abc", 0), + ("", 4), + ], +) +def test_ncycles(seq: str, count: int) -> None: + assert list(itextras.ncycles(seq, count)) == list(aextras.ncycles(seq, count)) + + +@pytest.mark.parametrize( + ("seq", "n", "dflt"), + [ + ("abc", 1, None), + ("abc", 10, None), + ("abc", 10, "x"), + ("", 0, None), + ], +) +def test_nth(seq: str, n: int, dflt: Optional[str]) -> None: + assert itextras.nth(seq, n, dflt) == aextras.nth(seq, n, dflt) + + +@pytest.mark.parametrize( + ("seq"), + [ + "abc", + "", + ], +) +def test_padnone(seq: str) -> None: + assert _take(10, itextras.padnone(seq)) == _take(10, aextras.padnone(seq)) + + +@pytest.mark.parametrize( + ("seq"), + [ + (), + (1,), + (1, 2), + (1, 2, 3), + (1, 2, 3, 4), + ], +) +def test_pairwise(seq: Sequence[int]) -> None: + assert list(itextras.pairwise(seq)) == list(aextras.pairwise(seq)) + + +@pytest.mark.parametrize( + ("pred", "seq"), + [ + (lambda x: x % 2, (0, 1, 2, 3)), + (lambda x: x % 2, (0, 2)), + (lambda x: x % 2, ()), + ], +) +def test_partition(pred: _Predicate[int], seq: Sequence[int]) -> None: + # assert list(itextras.partition(pred, seq)) == list(aextras.partition(pred, seq)) + true1, false1 = itextras.partition(pred, seq) + true2, false2 = aextras.partition(pred, seq) + assert list(true1) == list(true2) + assert list(false1) == list(false2) + + +@pytest.mark.parametrize( + ("value", "seq"), + [ + (1, (2, 3)), + (1, ()), + ], +) +def test_prepend(value: int, seq: Sequence[int]) -> None: + assert list(itextras.prepend(value, seq)) == list(aextras.prepend(value, seq)) + + +@pytest.mark.parametrize( + ("seq", "pred"), + [ + ((0, 1), lambda x: x % 2 == 0), + ((1, 1), lambda x: x % 2 == 0), + ((), lambda x: x % 2 == 0), + ], +) +def test_quantify(seq: Sequence[int], pred: _Predicate[int]) -> None: + assert itextras.quantify(seq) == aextras.quantify(seq) + assert itextras.quantify(seq, pred) == aextras.quantify(seq, pred) + + +@pytest.mark.parametrize( + ("func", "times", "args"), + [ + (lambda: 1, 5, []), + (lambda: 1, 0, []), + (lambda x: x + 1, 10, [3]), + (lambda x, y: x + y, 10, [3, 4]), + ], +) +def test_repeatfunc(func: Callable, times: int, args: Sequence[int]) -> None: + assert _take(5, itextras.repeatfunc(func, None, *args)) == _take( + 5, aextras.repeatfunc(func, None, *args) + ) + assert list(itextras.repeatfunc(func, times, *args)) == list( + aextras.repeatfunc(func, times, *args) + ) + + +@pytest.mark.parametrize( + ("seq1", "seq2"), + [ + ("abc", "def"), + ("a", "bc"), + ("ab", "c"), + ("", "abc"), + ("", ""), + ], +) +def test_roundrobin(seq1: str, seq2: str) -> None: + assert list(itextras.roundrobin(seq1)) == list(aextras.roundrobin(seq1)) + assert list(itextras.roundrobin(seq1, seq2)) == list(aextras.roundrobin(seq1, seq2)) + + +@pytest.mark.parametrize( + ("func", "start"), + [ + (lambda x: 2 * x, 17), + (lambda x: -x, -3), + ], +) +def test_tabulate(func: Callable[[int], int], start: int) -> None: + assert _take(5, itextras.tabulate(func)) == _take(5, aextras.tabulate(func)) + assert _take(5, itextras.tabulate(func, start)) == _take( + 5, aextras.tabulate(func, start) + ) + + +@pytest.mark.parametrize( + ("n", "seq"), + [ + (3, "abcdefg"), + (0, "abcdefg"), + (10, "abcdefg"), + (5, ""), + ], +) +def test_tail(n: int, seq: str) -> None: + assert list(itextras.tail(n, seq)) == list(aextras.tail(n, seq)) + + +@pytest.mark.parametrize( + ("n", "seq"), + [ + (3, "abcdefg"), + (0, "abcdefg"), + (10, "abcdefg"), + (5, ""), + ], +) +def test_take(n: int, seq: str) -> None: + assert list(itextras.take(n, seq)) == list(aextras.take(n, seq))