diff --git a/tests/test_utils.py b/tests/test_utils.py index 8c7f5c318..7617fdd20 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -25,6 +25,11 @@ def test_count(): assert count([True, False, True, True, False]) == 3 assert count([5 > 1, len("abc") == 3, 3+1 == 5]) == 2 +def test_multimap(): + assert multimap([(1, 2),(1, 3),(1, 4),(2, 3),(2, 4),(4, 5)]) == \ + {1: [2, 3, 4], 2: [3, 4], 4: [5]} + assert multimap([("a", 2), ("a", 3), ("a", 4), ("b", 3), ("b", 4), ("c", 5)]) == \ + {'a': [2, 3, 4], 'b': [3, 4], 'c': [5]} def test_product(): assert product([1, 2, 3, 4]) == 24 diff --git a/utils.py b/utils.py index ab6aa1032..9c86b17ca 100644 --- a/utils.py +++ b/utils.py @@ -42,10 +42,10 @@ def count(seq): # TODO: replace with quantify def multimap(items): """Given (key, val) pairs, return {key: [val, ....], ...}.""" - result = defaultdict(list) + result = collections.defaultdict(list) for (key, val) in items: result[key].append(val) - return result + return dict(result) def multimap_items(mmap): """Yield all (key, val) pairs stored in the multimap."""