Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5bad41e

Browse files
committed
Merge in r68394 fixing itertools.permutations() and combinations().
1 parent 5e4e427 commit 5bad41e

3 files changed

Lines changed: 47 additions & 22 deletions

File tree

Doc/library/itertools.rst

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ loops that truncate the stream.
104104
# combinations(range(4), 3) --> 012 013 023 123
105105
pool = tuple(iterable)
106106
n = len(pool)
107-
indices = range(r)
107+
if r > n:
108+
return
109+
indices = list(range(r))
108110
yield tuple(pool[i] for i in indices)
109111
while 1:
110112
for i in reversed(range(r)):
@@ -128,6 +130,8 @@ loops that truncate the stream.
128130
if sorted(indices) == list(indices):
129131
yield tuple(pool[i] for i in indices)
130132

133+
The number of items returned is ``n! / r! / (n-r)!`` when ``0 <= r <= n``
134+
or zero when ``r > n``.
131135

132136
.. function:: count([n])
133137

@@ -325,7 +329,9 @@ loops that truncate the stream.
325329
pool = tuple(iterable)
326330
n = len(pool)
327331
r = n if r is None else r
328-
indices = range(n)
332+
if r > n:
333+
return
334+
indices = list(range(n))
329335
cycles = range(n, n-r, -1)
330336
yield tuple(pool[i] for i in indices[:r])
331337
while n:
@@ -354,6 +360,8 @@ loops that truncate the stream.
354360
if len(set(indices)) == r:
355361
yield tuple(pool[i] for i in indices)
356362

363+
The number of items returned is ``n! / (n-r)!`` when ``0 <= r <= n``
364+
or zero when ``r > n``.
357365

358366
.. function:: product(*iterables[, repeat])
359367

@@ -593,7 +601,8 @@ which incur interpreter overhead.
593601
return (d for d, s in zip(data, selectors) if s)
594602

595603
def combinations_with_replacement(iterable, r):
596-
"combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC"
604+
"combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
605+
# number items returned: (n+r-1)! / r! / (n-1)!
597606
pool = tuple(iterable)
598607
n = len(pool)
599608
indices = [0] * r

Lib/test/test_itertools.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,20 @@ def test_chain_from_iterable(self):
7575
self.assertRaises(TypeError, list, chain.from_iterable([2, 3]))
7676

7777
def test_combinations(self):
78-
self.assertRaises(TypeError, combinations, 'abc') # missing r argument
78+
self.assertRaises(TypeError, combinations, 'abc') # missing r argument
7979
self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
8080
self.assertRaises(TypeError, combinations, None) # pool is not iterable
8181
self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative
82-
self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big
82+
self.assertEqual(list(combinations('abc', 32)), []) # r > n
8383
self.assertEqual(list(combinations(range(4), 3)),
8484
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
8585

8686
def combinations1(iterable, r):
8787
'Pure python version shown in the docs'
8888
pool = tuple(iterable)
8989
n = len(pool)
90+
if r > n:
91+
return
9092
indices = list(range(r))
9193
yield tuple(pool[i] for i in indices)
9294
while 1:
@@ -110,9 +112,9 @@ def combinations2(iterable, r):
110112

111113
for n in range(7):
112114
values = [5*x-12 for x in range(n)]
113-
for r in range(n+1):
115+
for r in range(n+2):
114116
result = list(combinations(values, r))
115-
self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs
117+
self.assertEqual(len(result), 0 if r>n else fact(n) / fact(r) / fact(n-r)) # right number of combs
116118
self.assertEqual(len(result), len(set(result))) # no repeats
117119
self.assertEqual(result, sorted(result)) # lexicographic order
118120
for c in result:
@@ -123,7 +125,7 @@ def combinations2(iterable, r):
123125
self.assertEqual(list(c),
124126
[e for e in values if e in c]) # comb is a subsequence of the input iterable
125127
self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
126-
self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version
128+
self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version
127129

128130
# Test implementation detail: tuple re-use
129131
self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
@@ -134,7 +136,7 @@ def test_permutations(self):
134136
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
135137
self.assertRaises(TypeError, permutations, None) # pool is not iterable
136138
self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
137-
self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big
139+
self.assertEqual(list(permutations('abc', 32)), []) # r > n
138140
self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
139141
self.assertEqual(list(permutations(range(3), 2)),
140142
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
@@ -144,6 +146,8 @@ def permutations1(iterable, r=None):
144146
pool = tuple(iterable)
145147
n = len(pool)
146148
r = n if r is None else r
149+
if r > n:
150+
return
147151
indices = list(range(n))
148152
cycles = list(range(n-r+1, n+1))[::-1]
149153
yield tuple(pool[i] for i in indices[:r])
@@ -172,17 +176,17 @@ def permutations2(iterable, r=None):
172176

173177
for n in range(7):
174178
values = [5*x-12 for x in range(n)]
175-
for r in range(n+1):
179+
for r in range(n+2):
176180
result = list(permutations(values, r))
177-
self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms
181+
self.assertEqual(len(result), 0 if r>n else fact(n) / fact(n-r)) # right number of perms
178182
self.assertEqual(len(result), len(set(result))) # no repeats
179183
self.assertEqual(result, sorted(result)) # lexicographic order
180184
for p in result:
181185
self.assertEqual(len(p), r) # r-length permutations
182186
self.assertEqual(len(set(p)), r) # no duplicate elements
183187
self.assert_(all(e in values for e in p)) # elements taken from input iterable
184188
self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version
185-
self.assertEqual(result, list(permutations2(values, r))) # matches first pure python version
189+
self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version
186190
if r == n:
187191
self.assertEqual(result, list(permutations(values, None))) # test r as None
188192
self.assertEqual(result, list(permutations(values))) # test default r
@@ -1384,6 +1388,26 @@ def __init__(self, newarg=None, *args):
13841388
>>> list(combinations_with_replacement('abc', 2))
13851389
[('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]
13861390
1391+
>>> list(combinations_with_replacement('01', 3))
1392+
[('0', '0', '0'), ('0', '0', '1'), ('0', '1', '1'), ('1', '1', '1')]
1393+
1394+
>>> def combinations_with_replacement2(iterable, r):
1395+
... 'Alternate version that filters from product()'
1396+
... pool = tuple(iterable)
1397+
... n = len(pool)
1398+
... for indices in product(range(n), repeat=r):
1399+
... if sorted(indices) == list(indices):
1400+
... yield tuple(pool[i] for i in indices)
1401+
1402+
>>> list(combinations_with_replacement('abc', 2)) == list(combinations_with_replacement2('abc', 2))
1403+
True
1404+
1405+
>>> list(combinations_with_replacement('01', 3)) == list(combinations_with_replacement2('01', 3))
1406+
True
1407+
1408+
>>> list(combinations_with_replacement('2310', 6)) == list(combinations_with_replacement2('2310', 6))
1409+
True
1410+
13871411
>>> list(unique_everseen('AAAABBBCCDAABBB'))
13881412
['A', 'B', 'C', 'D']
13891413

Modules/itertoolsmodule.c

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,10 +1880,6 @@ combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
18801880
PyErr_SetString(PyExc_ValueError, "r must be non-negative");
18811881
goto error;
18821882
}
1883-
if (r > n) {
1884-
PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
1885-
goto error;
1886-
}
18871883

18881884
indices = PyMem_Malloc(r * sizeof(Py_ssize_t));
18891885
if (indices == NULL) {
@@ -1903,7 +1899,7 @@ combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
19031899
co->indices = indices;
19041900
co->result = NULL;
19051901
co->r = r;
1906-
co->stopped = 0;
1902+
co->stopped = r > n ? 1 : 0;
19071903

19081904
return (PyObject *)co;
19091905

@@ -2143,10 +2139,6 @@ permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
21432139
PyErr_SetString(PyExc_ValueError, "r must be non-negative");
21442140
goto error;
21452141
}
2146-
if (r > n) {
2147-
PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
2148-
goto error;
2149-
}
21502142

21512143
indices = PyMem_Malloc(n * sizeof(Py_ssize_t));
21522144
cycles = PyMem_Malloc(r * sizeof(Py_ssize_t));
@@ -2170,7 +2162,7 @@ permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
21702162
po->cycles = cycles;
21712163
po->result = NULL;
21722164
po->r = r;
2173-
po->stopped = 0;
2165+
po->stopped = r > n ? 1 : 0;
21742166

21752167
return (PyObject *)po;
21762168

0 commit comments

Comments
 (0)