This typing library is intended to replace jaxtyping for runtime type checking of torch tensors and numpy arrays.
In particular, we support two features that beartype/jaxtyping do not:
- Support for torch.jit.script/torch.compile/torch.jit.trace
- Pydantic model type annotations for torch tensors.

Key features:

- Shape and Type Validation: Validate tensor shapes and types at runtime with symbolic dimension support.
- Pydantic Integration: First-class support for tensor validation in Pydantic models.
- Context-Aware Validation: Ensures consistency across multiple tensors in the same context.
- ONNX/torch.compile Compatible: Works seamlessly with model export and compilation workflows.
- Symbolic Dimensions: Support for named dimensions that enforce consistency.
Install dltype through pip:

```bash
pip3 install dltype
```

Note
dltype does not depend explicitly on torch or numpy, but you must have at least one of them installed at import time; otherwise the import will fail.
Type hints are evaluated in a shared context in source-code order, so any dimension symbol referenced in an expression must already have been bound by an earlier annotation.
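For example, here is a minimal sketch of this ordering rule (the `@dltyped` decorator, the tensor types, and expression syntax such as `cols*2` are all introduced in the sections below; `duplicate_columns` is a hypothetical function used purely for illustration):

```python
from typing import Annotated

import torch

from dltype import FloatTensor, dltyped


@dltyped()
def duplicate_columns(
    # "rows" and "cols" are bound here when the first annotation is evaluated...
    matrix: Annotated[torch.Tensor, FloatTensor["rows cols"]],
    # ...so the return annotation below may reference them in an expression.
) -> Annotated[torch.Tensor, FloatTensor["rows cols*2"]]:
    return torch.cat([matrix, matrix], dim=-1)
```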

DLType supports the following kinds of dimension specifications:

Single-element tensors with no shape:

```python
IntTensor[None]  # An integer tensor with a single value and no axes
```

Simple integer dimensions with fixed sizes:

```python
FloatTensor["3 5"]  # A tensor with shape (3, 5)
FloatTensor["batch channels=3 height width"]  # identifiers set to dimensions for documentation
```

Mathematical expressions combining literals and symbols:

```python
FloatTensor["batch channels*2"]  # If channels=64, shape would be (batch, 128)
FloatTensor["batch-1"]           # One less than the batch dimension
FloatTensor["features/2"]        # Half the features dimension
```

Note
Expressions must never have spaces.

Supported operators in expressions:

| Operator | Meaning |
| --- | --- |
| `+` | Addition |
| `-` | Subtraction |
| `*` | Multiplication |
| `/` | Integer division |
| `^` | Exponentiation |
| `min(a,b)` | Minimum of two expressions |
| `max(a,b)` | Maximum of two expressions |
Warning
While nested function calls like `min(max(a,b),c)` are supported, combining function calls with other operators in the same expression (e.g., `min(1,batch)+max(2,channels)`) is not supported, in order to simplify parsing.
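As a concrete illustration of a function-call expression in a shape annotation, here is a minimal sketch (`pad_to_minimum_width` and its shapes are hypothetical, chosen only to exercise `max(...)`; the `@dltyped` decorator is covered below):

```python
from typing import Annotated

import torch

from dltype import FloatTensor, dltyped


@dltyped()
def pad_to_minimum_width(
    # "batch" and "width" are bound by this annotation.
    x: Annotated[torch.Tensor, FloatTensor["batch width"]],
) -> Annotated[torch.Tensor, FloatTensor["batch max(width,8)"]]:
    # Zero-pad the last axis on the right so it is at least 8 wide.
    if x.shape[-1] >= 8:
        return x
    return torch.nn.functional.pad(x, (0, 8 - x.shape[-1]))
```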

Symbolic Dimensions: Named dimensions that ensure consistency across tensors:

```python
FloatTensor["batch channels"]  # A tensor with two dimensions
```

Multi-dimension identifiers: Named or anonymous dimension identifiers that may cover zero or more dimensions in the actual tensors. Only one multi-dimension identifier is allowed per type hint.

```python
FloatTensor["... channels h w"]           # anonymous dimension will not be matched across tensors
DoubleTensor["batch *channels features"]  # named dimension which can be matched across tensors
```

Decorate functions with `@dltyped()` to validate argument and return tensors against their annotations:

```python
from typing import Annotated

import torch

from dltype import FloatTensor, dltyped


@dltyped()
def add_tensors(
    x: Annotated[torch.Tensor, FloatTensor["batch features"]],
    y: Annotated[torch.Tensor, FloatTensor["batch features"]],
) -> Annotated[torch.Tensor, FloatTensor["batch features"]]:
    return x + y
```
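Continuing the example above, a hedged usage sketch (the exact exception type raised by dltype is not shown in this document, so it is caught broadly here):

```python
# Matching shapes pass validation.
add_tensors(torch.rand(4, 16), torch.rand(4, 16))

# A mismatched "batch" dimension is rejected at call time.
try:
    add_tensors(torch.rand(4, 16), torch.rand(2, 16))
except Exception as err:  # the concrete dltype error class is not documented here
    print(f"shape check failed: {err}")
```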
Tensor fields in pydantic models are validated automatically when the model is constructed:

```python
from typing import Annotated

import torch
from pydantic import BaseModel

from dltype import FloatTensor, IntTensor


class ImageBatch(BaseModel):
    # note the parentheses instead of brackets for pydantic models
    images: Annotated[torch.Tensor, FloatTensor("batch 3 height width")]
    labels: Annotated[torch.Tensor, IntTensor("batch")]

    # All tensor validations happen automatically
    # Shape consistency is enforced across fields
```
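A hedged usage sketch (the tensor shapes are arbitrary, and the exact validation error type is not shown in this document):

```python
# Consistent "batch" dimension across fields: validation passes on construction.
batch = ImageBatch(
    images=torch.rand(8, 3, 32, 32),
    labels=torch.zeros(8, dtype=torch.int64),
)

# Inconsistent batch sizes (8 images vs. 4 labels) are rejected at construction time.
try:
    ImageBatch(images=torch.rand(8, 3, 32, 32), labels=torch.zeros(4, dtype=torch.int64))
except Exception as err:  # the concrete error type is not documented here
    print(f"validation failed: {err}")
```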
We expose @dltyped_namedtuple() for NamedTuples.
NamedTuples are validated upon construction; beware that assignments or manipulations after construction are unchecked.

```python
from typing import Annotated, NamedTuple

import torch

import dltype


@dltype.dltyped_namedtuple()
class MyNamedTuple(NamedTuple):
    tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]]
    mask: Annotated[torch.Tensor, dltype.IntTensor["b h w"]]
    other: int
```
Similar to NamedTuples and pydantic BaseModels, @dataclasses may be decorated and validated.
The normal caveats apply in that we only validate at construction and not on assignment.
Therefore, we recommend using frozen @dataclasses when possible.

```python
from dataclasses import dataclass
from typing import Annotated

import torch

from dltype import FloatTensor, IntTensor, dltyped_dataclass


# Order is important: we raise an error if dltyped_dataclass is applied below dataclass.
# This is because the @dataclass decorator applies a bunch of source code modification
# that we don't want to have to hack around.
@dltyped_dataclass()
@dataclass(frozen=True, slots=True)
class MyDataclass:
    images: Annotated[torch.Tensor, FloatTensor["batch 3 height width"]]
    labels: Annotated[torch.Tensor, IntTensor["batch"]]
```
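A brief hedged sketch of why frozen dataclasses are recommended, reusing the class above (the exact dltype error type is not shown in this document):

```python
# Fields are validated once, at construction time.
item = MyDataclass(
    images=torch.rand(2, 3, 64, 64),
    labels=torch.zeros(2, dtype=torch.int64),
)

# With frozen=True, later assignment raises dataclasses.FrozenInstanceError,
# which also prevents silently bypassing dltype's construction-time checks.
# item.labels = torch.zeros(5, dtype=torch.int64)  # would raise if uncommented
```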
We have no support for general unions of types, to prevent confusing behavior when using runtime shape checking.
DLType only supports optional types (i.e. Type | None).
To annotate a tensor as being optional, see the example below.

```python
from typing import Annotated

import torch

import dltype


@dltype.dltyped()
def optional_tensor_func(tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None) -> torch.Tensor:
    if tensor is None:
        return torch.zeros(1, 3, 5, 5)
    return tensor
```
Numpy arrays and torch tensors may be validated side by side in the same function:

```python
from typing import Annotated

import numpy as np
import torch

from dltype import FloatTensor, dltyped


@dltyped()
def transform_tensors(
    points: Annotated[np.ndarray, FloatTensor["N 3"]],
    transform: Annotated[torch.Tensor, FloatTensor["3 3"]],
) -> Annotated[torch.Tensor, FloatTensor["N 3"]]:
    return torch.from_numpy(points) @ transform
```
There are situations in which a runtime variable may influence the expected shape of a tensor.
To provide external scope to be used by dltype, you may implement the DLTypeScopeProvider protocol.
There are two flavors of this, one for methods and one for free functions; both are shown below.
Using external scope providers for free functions is not an encouraged use case, as it encourages keeping global state.
Additionally, free functions are generally stateless, but this makes the type checking logic stateful and thus makes the execution of the function impure.
We support it nonetheless because there are certain scenarios where loading a configuration from a file and providing it as an expected dimension for a typed function may be useful and necessary.

```python
from typing import Annotated

import torch
from torch import nn

from dltype import FloatTensor, dltyped


# Using `self` as the DLTypeScopeProvider in an object (this is the primary use case)
class MyModule(nn.Module):
    # ... some implementation details
    def __init__(self, config: MyConfig) -> None:
        super().__init__()
        self.cfg = config

    # the DLTypeScopeProvider protocol requires this function to be specified.
    def get_dltype_scope(self) -> dict[str, int]:
        """Return the DLType scope, which is simply a dictionary of 'axis-name' -> dimension size."""
        return {"in_channel": self.cfg.in_channel}

    # "self" is a literal, not a string -- pyright will yell at you if this is wrong.
    # The first argument of the decorated function will be checked to obey the protocol
    # before calling `get_dltype_scope`.
    @dltyped("self")
    def forward(
        self,
        tensor_1: Annotated[torch.Tensor, FloatTensor["batch num_voxel_features z y x"]],
        # NOTE: in_channel comes from the external scope and is used in the expression below
        # to evaluate the expected 'channels' dimension
        tensor_2: Annotated[torch.Tensor, FloatTensor["batch channels=in_channel-num_voxel_features z y x"]],
    ) -> torch.Tensor:
        ...


## Using a scope provider for a free function
class MyProvider:
    def get_dltype_scope(self) -> dict[str, int]:
        # load some_value from a config file in the constructor
        # or fetch it from a singleton
        return {"dim1": self.some_value}


@dltyped(provider=MyProvider())
def free_function(tensor: FloatTensor["batch dim1"]) -> None:
    # ... implementation details, dim1 provided by the external scope
    ...
```

dltype provides the following tensor type hints:

- `FloatTensor`: For any precision floating point tensor. Is a superset of the following:
  - `Float16Tensor`: For any 16 bit floating point type. Is a superset of the following:
    - `IEEE754HalfFloatTensor`: For 16 bit floating point types that comply with the IEEE 754 half-precision specification (notably, does not include `bfloat16`). For numpy arrays, `Float16Tensor` is equal to `IEEE754HalfFloatTensor`. Use it if you need to forbid usage of `bfloat16` for some reason; otherwise prefer the `Float16Tensor` type for mixed precision codebases.
    - `BFloat16Tensor`: For 16 bit floating point tensors following the `bfloat16` format. It is not IEEE 754 compliant and is not supported by NumPy. Use it if you need to write code that is `bfloat16` specific; otherwise prefer `Float16Tensor` for usage within a mixed precision instruction scope (such as `torch.amp`).
  - `Float32Tensor`: For single precision 32 bit floats.
  - `Float64Tensor`: For double precision 64 bit floats. Aliases to `DoubleTensor`.
  - Note that `np.float128` and `np.longdouble` will be considered `FloatTensor`s BUT do not exist as standalone types in dltype, i.e. there is no `Float128Tensor` type. These types are not supported by torch and are only supported by numpy on certain platforms, thus we only "support" them insofar as they are considered floating point types.
- `IntTensor`: For integer tensors of any precision. Is a superset of the following: `Int8Tensor`, `Int16Tensor`, `Int32Tensor`, `Int64Tensor`.
- `BoolTensor`: For boolean tensors.
- `TensorTypeBase`: Base class for any tensor; it does not enforce any specific datatype. Feel free to add custom validation logic by overriding the `check` method (see the sketch below).
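For illustration, here is a hypothetical sketch of a custom tensor type built on `TensorTypeBase`; the name `UnitIntervalTensor`, the exact signature of `check`, and its return convention are assumptions, so verify them against the dltype source before relying on this:

```python
import torch

from dltype import TensorTypeBase


class UnitIntervalTensor(TensorTypeBase):
    """A tensor hint that additionally requires every value to lie in [0, 1]."""

    # ASSUMPTION: `check` receives the runtime tensor and returns True/False;
    # confirm the real signature and return convention in the dltype source.
    def check(self, tensor: torch.Tensor) -> bool:
        in_range = bool(tensor.min() >= 0) and bool(tensor.max() <= 1)
        return super().check(tensor) and in_range


# Hypothetical usage, mirroring the built-in types:
# mask: Annotated[torch.Tensor, UnitIntervalTensor["batch h w"]]
```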
Known limitations:

- In the current implementation, every call will be checked, which may or may not be slow depending on how big the context is (it shouldn't be that slow).
- Pydantic default values are not checked.
- Only symbolic, literal, and expression dimension specifiers are allowed; the f-string syntax from `jaxtyping` is not supported.
- Only torch tensors and numpy arrays are supported for now.
- Static checking is not supported, only runtime checks, though some errors will be caught statically by construction.
- We do not support container types (i.e. `list[TensorTypeBase]`) and we probably never will, because parsing arbitrarily nested containers is very slow to do at runtime.
- We do not support union types, but we do support optionals.