diff --git a/gcloud/iterator.py b/gcloud/iterator.py index 1126e39eef91..f62d28578ea5 100644 --- a/gcloud/iterator.py +++ b/gcloud/iterator.py @@ -141,3 +141,47 @@ def get_items_from_response(self, response): :returns: Items that the iterator should yield. """ raise NotImplementedError + + +class MethodIterator(object): + """Method-based iterator iterating through Cloud JSON APIs list responses. + + :type method: instance method + :param method: ``list_foo`` method of a domain object, taking as arguments + ``page_token``, ``page_size``, and optional additional + keyword arguments. + + :type page_token: string or ``NoneType`` + :param page_token: Initial page token to pass. if ``None``, fetch the + first page from the ``method`` API call. + + :type page_size: integer or ``NoneType`` + :param page_size: Maximum number of items to return from the ``method`` + API call; if ``None``, uses the default for the API. + + :type max_calls: integer or ``NoneType`` + :param max_calls: Maximum number of times to make the ``method`` + API call; if ``None``, applies no limit. + + :type kw: dict + :param kw: optional keyword argments to be passed to ``method``. + """ + def __init__(self, method, page_token=None, page_size=None, + max_calls=None, **kw): + self._method = method + self._token = page_token + self._page_size = page_size + self._kw = kw + self._max_calls = max_calls + self._page_num = 0 + + def __iter__(self): + while self._max_calls is None or self._page_num < self._max_calls: + items, new_token = self._method( + page_token=self._token, page_size=self._page_size, **self._kw) + for item in items: + yield item + if new_token is None: + return + self._page_num += 1 + self._token = new_token diff --git a/gcloud/test_iterator.py b/gcloud/test_iterator.py index 04ea5908acb1..102da9655d53 100644 --- a/gcloud/test_iterator.py +++ b/gcloud/test_iterator.py @@ -172,6 +172,94 @@ def test_get_items_from_response_raises_NotImplementedError(self): iterator.get_items_from_response, object()) +class TestMethodIterator(unittest2.TestCase): + + def _getTargetClass(self): + from gcloud.iterator import MethodIterator + return MethodIterator + + def _makeOne(self, *args, **kw): + return self._getTargetClass()(*args, **kw) + + def test_ctor_defaults(self): + wlm = _WithListMethod() + iterator = self._makeOne(wlm.list_foo) + self.assertEqual(iterator._method, wlm.list_foo) + self.assertEqual(iterator._token, None) + self.assertEqual(iterator._page_size, None) + self.assertEqual(iterator._kw, {}) + self.assertEqual(iterator._max_calls, None) + self.assertEqual(iterator._page_num, 0) + + def test_ctor_explicit(self): + wlm = _WithListMethod() + TOKEN = wlm._letters + SIZE = 4 + CALLS = 2 + iterator = self._makeOne(wlm.list_foo, TOKEN, SIZE, CALLS, + foo_type='Bar') + self.assertEqual(iterator._method, wlm.list_foo) + self.assertEqual(iterator._token, TOKEN) + self.assertEqual(iterator._page_size, SIZE) + self.assertEqual(iterator._kw, {'foo_type': 'Bar'}) + self.assertEqual(iterator._max_calls, CALLS) + self.assertEqual(iterator._page_num, 0) + + def test___iter___defaults(self): + import string + wlm = _WithListMethod() + iterator = self._makeOne(wlm.list_foo) + found = [] + for char in iterator: + found.append(char) + self.assertEqual(found, list(string.printable)) + self.assertEqual(len(wlm._called_with), len(found) // 10) + for i, (token, size, kw) in enumerate(wlm._called_with): + if i == 0: + self.assertEqual(token, None) + else: + self.assertEqual(token, string.printable[i * 10:]) + self.assertEqual(size, None) + self.assertEqual(kw, {}) + + def test___iter___explicit_size_and_maxcalls_and_kw(self): + import string + wlm = _WithListMethod() + iterator = self._makeOne(wlm.list_foo, page_size=2, max_calls=3, + foo_type='Bar') + found = [] + for char in iterator: + found.append(char) + self.assertEqual(found, list(string.printable[:2 * 3])) + self.assertEqual(len(wlm._called_with), len(found) // 2) + for i, (token, size, kw) in enumerate(wlm._called_with): + if i == 0: + self.assertEqual(token, None) + else: + self.assertEqual(token, string.printable[i * 2:]) + self.assertEqual(size, 2) + self.assertEqual(kw, {'foo_type': 'Bar'}) + + +class _WithListMethod(object): + + def __init__(self): + import string + self._called_with = [] + self._letters = string.printable + + def list_foo(self, page_token, page_size, **kw): + if page_token is not None: + assert page_token == self._letters + self._called_with.append((page_token, page_size, kw)) + if page_size is None: + page_size = 10 + page, self._letters = ( + self._letters[:page_size], self._letters[page_size:]) + token = self._letters or None + return page, token + + class _Connection(object): def __init__(self, *responses):