|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from typing import Any, FrozenSet, List, Tuple, Union, cast |
4 | | -from typing_extensions import Final |
| 3 | +from typing import FrozenSet, List, Tuple, Union |
| 4 | +from typing_extensions import Final, TypeGuard |
5 | 5 |
|
6 | 6 | # Supported Python literal types. All tuple / frozenset items must have supported |
7 | 7 | # literal types as well, but we can't represent the type precisely. |
8 | 8 | LiteralValue = Union[ |
9 | 9 | str, bytes, int, bool, float, complex, Tuple[object, ...], FrozenSet[object], None |
10 | 10 | ] |
11 | 11 |
|
| 12 | + |
| 13 | +def _is_literal_value(obj: object) -> TypeGuard[LiteralValue]: |
| 14 | + return isinstance(obj, (str, bytes, int, float, complex, tuple, frozenset, type(None))) |
| 15 | + |
| 16 | + |
12 | 17 | # Some literals are singletons and handled specially (None, False and True) |
13 | 18 | NUM_SINGLETONS: Final = 3 |
14 | 19 |
|
@@ -55,13 +60,15 @@ def record_literal(self, value: LiteralValue) -> None: |
55 | 60 | tuple_literals = self.tuple_literals |
56 | 61 | if value not in tuple_literals: |
57 | 62 | for item in value: |
58 | | - self.record_literal(cast(Any, item)) |
| 63 | + assert _is_literal_value(item) |
| 64 | + self.record_literal(item) |
59 | 65 | tuple_literals[value] = len(tuple_literals) |
60 | 66 | elif isinstance(value, frozenset): |
61 | 67 | frozenset_literals = self.frozenset_literals |
62 | 68 | if value not in frozenset_literals: |
63 | 69 | for item in value: |
64 | | - self.record_literal(cast(Any, item)) |
| 70 | + assert _is_literal_value(item) |
| 71 | + self.record_literal(item) |
65 | 72 | frozenset_literals[value] = len(frozenset_literals) |
66 | 73 | else: |
67 | 74 | assert False, "invalid literal: %r" % value |
@@ -159,7 +166,8 @@ def _encode_collection_values( |
159 | 166 | value = value_by_index[i] |
160 | 167 | result.append(str(len(value))) |
161 | 168 | for item in value: |
162 | | - index = self.literal_index(cast(Any, item)) |
| 169 | + assert _is_literal_value(item) |
| 170 | + index = self.literal_index(item) |
163 | 171 | result.append(str(index)) |
164 | 172 | return result |
165 | 173 |
|
|
0 commit comments