|
| 1 | +import itertools |
1 | 2 |
|
2 | 3 |
|
3 | 4 | def check_in_list(_values, *, _print_supported_values=True, **kwargs):
|
@@ -31,3 +32,64 @@ def check_in_list(_values, *, _print_supported_values=True, **kwargs):
|
31 | 32 | f"supported values are {', '.join(map(repr, values))}")
|
32 | 33 | else:
|
33 | 34 | raise ValueError(f"{val!r} is not a valid value for {key}")
|
| 35 | + |
| 36 | + |
| 37 | +def check_shape(_shape, **kwargs): |
| 38 | + """ |
| 39 | + For each *key, value* pair in *kwargs*, check that *value* has the shape |
| 40 | + *_shape*, if not, raise an appropriate ValueError. |
| 41 | +
|
| 42 | + *None* in the shape is treated as a "free" size that can have any length. |
| 43 | + e.g. (None, 2) -> (N, 2) |
| 44 | +
|
| 45 | + The values checked must be numpy arrays. |
| 46 | +
|
| 47 | + Examples |
| 48 | + -------- |
| 49 | + To check for (N, 2) shaped arrays |
| 50 | +
|
| 51 | + >>> _api.check_shape((None, 2), arg=arg, other_arg=other_arg) |
| 52 | + """ |
| 53 | + target_shape = _shape |
| 54 | + for k, v in kwargs.items(): |
| 55 | + data_shape = v.shape |
| 56 | + |
| 57 | + if len(target_shape) != len(data_shape) or any( |
| 58 | + t not in [s, None] |
| 59 | + for t, s in zip(target_shape, data_shape) |
| 60 | + ): |
| 61 | + dim_labels = iter(itertools.chain( |
| 62 | + 'MNLIJKLH', |
| 63 | + (f"D{i}" for i in itertools.count()))) |
| 64 | + text_shape = ", ".join((str(n) |
| 65 | + if n is not None |
| 66 | + else next(dim_labels) |
| 67 | + for n in target_shape)) |
| 68 | + |
| 69 | + raise ValueError( |
| 70 | + f"{k!r} must be {len(target_shape)}D " |
| 71 | + f"with shape ({text_shape}). " |
| 72 | + f"Your input has shape {v.shape}." |
| 73 | + ) |
| 74 | + |
| 75 | + |
| 76 | +def check_getitem(_mapping, **kwargs): |
| 77 | + """ |
| 78 | + *kwargs* must consist of a single *key, value* pair. If *key* is in |
| 79 | + *_mapping*, return ``_mapping[value]``; else, raise an appropriate |
| 80 | + ValueError. |
| 81 | +
|
| 82 | + Examples |
| 83 | + -------- |
| 84 | + >>> _api.check_getitem({"foo": "bar"}, arg=arg) |
| 85 | + """ |
| 86 | + mapping = _mapping |
| 87 | + if len(kwargs) != 1: |
| 88 | + raise ValueError("check_getitem takes a single keyword argument") |
| 89 | + (k, v), = kwargs.items() |
| 90 | + try: |
| 91 | + return mapping[v] |
| 92 | + except KeyError: |
| 93 | + raise ValueError( |
| 94 | + "{!r} is not a valid value for {}; supported values are {}" |
| 95 | + .format(v, k, ', '.join(map(repr, mapping)))) from None |
0 commit comments