Dimensionality annotations for tensor parameters and return values.
- Clone the repository
- Run:

    python setup.py install

This module uses Python type annotations to provide run-time argument size checking for PyTorch tensors, which makes it possible to write:
- Self-documenting code (in a way which doesn't silently become outdated)
- Fail-fast code, where the error points to the first location where a contract was violated
The only user-facing part of torch_dimcheck is the dimchecked function decorator:

    import torch
    from torch_dimcheck import dimchecked

    @dimchecked
    def matmul(a: ['X', 'Y'], b: ['Y', 'Z']) -> ['X', 'Z']:
        return torch.matmul(a, b)

    a = torch.randn(3, 4)
    b = torch.randn(4, 2)
    c = matmul(a, b)  # works
    c = matmul(b, a)  # throws at function call level

Each function parameter and output value can be annotated with a list where each element is either str, int or ... (Ellipsis). We refer to the elements of the list as labels and say that:
- The tensor will be required to have as many dimensions as the list has labels.
- int labels require the tensor dimension to have a size equal to that value (e.g. f(a: [1, 4]) will accept only tensors of shape [1, 4]).
- str labels create a unique dynamic label, which can have any size but must be consistent across the whole signature. This means that in add(a: ['A'], b: ['A']) the tensors must be 1-dimensional and of equal size.
- Ellipsis (...) is a special value which can stand for any number of dimensions, making it an exception to the first rule. There can be at most one ... per tensor annotation (otherwise the notation would be ambiguous). For example, g(a: ['A', ..., 'B'], b: ['A', ..., 'B']) means that a and b can have an arbitrary number of dimensions as long as the first and last ones agree in size.
- Argument annotations other than lists are ignored, which means that one can still use regular type hints alongside @dimchecked.
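
The labeling rules above can be illustrated with a small stand-alone matcher. This is only a minimal sketch that works on plain shape tuples, not torch_dimcheck's actual implementation. The names match_shape and bindings are hypothetical, and the sketch assumes that ... may also stand for zero dimensions.

```python
def match_shape(labels, shape, bindings):
    """Check `shape` against `labels`, recording str label sizes in `bindings`.

    Hypothetical helper for illustration. Raises ValueError on the first
    violated rule and returns the updated `bindings` dict otherwise.
    """
    labels = list(labels)
    shape = tuple(shape)

    n_ellipsis = sum(1 for label in labels if label is Ellipsis)
    if n_ellipsis > 1:
        raise ValueError("at most one ... is allowed per annotation")

    if n_ellipsis == 1:
        # ... absorbs whatever dimensions the fixed labels leave over
        # (assumption: possibly none at all).
        split = labels.index(Ellipsis)
        head, tail = labels[:split], labels[split + 1:]
        if len(shape) < len(head) + len(tail):
            raise ValueError(f"shape {shape} has too few dimensions for {labels}")
        pairs = list(zip(head, shape[:len(head)]))
        if tail:
            pairs += list(zip(tail, shape[-len(tail):]))
    else:
        # Without ..., the number of dimensions must equal the number of labels.
        if len(shape) != len(labels):
            raise ValueError(f"shape {shape} does not have {len(labels)} dimensions")
        pairs = list(zip(labels, shape))

    for label, size in pairs:
        if isinstance(label, int):
            # int labels demand an exact size.
            if size != label:
                raise ValueError(f"expected size {label}, got {size}")
        elif isinstance(label, str):
            # str labels are bound on first use and must stay consistent.
            if bindings.setdefault(label, size) != size:
                raise ValueError(
                    f"label {label!r} is bound to {bindings[label]}, got {size}"
                )
        else:
            raise ValueError(f"unsupported label {label!r}")
    return bindings


# Sharing one `bindings` dict across arguments enforces consistency
# over the whole signature, as in matmul(a: ['X', 'Y'], b: ['Y', 'Z']).
bindings = {}
match_shape(['X', 'Y'], (3, 4), bindings)
match_shape(['Y', 'Z'], (4, 2), bindings)
```

Passing the same dictionary for every argument and output of one call is what makes a label such as 'Y' mean the same size everywhere in the signature.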
Additionally, function outputs are annotated as a tuple of lists, with each list referring to one function output.

    @dimchecked
    def matmul_two_ways(a: ['X', 'Y'], b: ['Y', 'Z']) -> (['X', 'Z'], ['Z', 'X']):
        ab = torch.matmul(a, b)
        # Transposed product (ab)^T = b^T a^T, which has shape [Z, X];
        # torch.matmul(b, a) would only be valid when Z == X.
        ba = torch.matmul(b.t(), a.t())
        return ab, ba

In this context, ... has a special meaning and can replace a list, meaning that the corresponding output will not be checked. This is useful if only some of the function outputs are tensors.

    @dimchecked
    def load_ith_image(i) -> (['H', 'W', 3], ...):
        path = find_ith_path(i)
        return load_image(path), path

Finally, if there is only a single tensor as an output, the outer tuple can be skipped:

    @dimchecked
    def f() -> ['X', 'Y']:
        pass

    # is equivalent to

    @dimchecked
    def f() -> (['X', 'Y'], ):
        pass
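
The output conventions described above (a tuple of lists, ... to skip an output, and a bare list as shorthand for a one-element tuple) could be normalized along the following lines. This is a hedged sketch rather than the library's code: normalize_outputs and pair_outputs are hypothetical helpers, and treating any other return annotation as unchecked is an assumption made for the example.

```python
def normalize_outputs(annotation):
    """Return the return annotation as a tuple with one entry per output.

    Hypothetical helper for illustration only.
    """
    if isinstance(annotation, list):
        # A bare list is shorthand for a single checked output.
        return (annotation,)
    if isinstance(annotation, tuple):
        return annotation
    # Assumption: any other annotation (e.g. a regular type hint) is unchecked.
    return None


def pair_outputs(annotation, result):
    """Pair each output that should be checked with its list of labels."""
    specs = normalize_outputs(annotation)
    if specs is None:
        return []
    # A function with a single output returns it without an enclosing tuple.
    outputs = result if isinstance(result, tuple) else (result,)
    if len(outputs) != len(specs):
        raise ValueError(
            f"expected {len(specs)} outputs, got {len(outputs)}"
        )
    # Entries annotated with ... are skipped entirely.
    return [(out, spec) for out, spec in zip(outputs, specs)
            if spec is not Ellipsis]


# Strings stand in for the image tensor and the path of load_ith_image.
to_check = pair_outputs((['H', 'W', 3], ...), ("image", "path"))
```

Here to_check contains only the first output together with its labels; the second output is annotated with ... and is therefore never inspected.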