@@ -837,6 +837,14 @@ def dispatch(cls):
837837 dispatch_cache [cls ] = impl
838838 return impl
839839
840+ def _is_union_type (cls ):
841+ from typing import get_origin , Union
842+ return get_origin (cls ) in {Union , types .UnionType }
843+
844+ def _is_valid_union_type (cls ):
845+ from typing import get_args
846+ return _is_union_type (cls ) and all (isinstance (arg , type ) for arg in get_args (cls ))
847+
840848 def register (cls , func = None ):
841849 """generic_func.register(cls, func) -> func
842850
@@ -845,7 +853,7 @@ def register(cls, func=None):
845853 """
846854 nonlocal cache_token
847855 if func is None :
848- if isinstance (cls , type ):
856+ if isinstance (cls , type ) or _is_valid_union_type ( cls ) :
849857 return lambda f : register (cls , f )
850858 ann = getattr (cls , '__annotations__' , {})
851859 if not ann :
@@ -859,12 +867,25 @@ def register(cls, func=None):
859867 # only import typing if annotation parsing is necessary
860868 from typing import get_type_hints
861869 argname , cls = next (iter (get_type_hints (func ).items ()))
862- if not isinstance (cls , type ):
863- raise TypeError (
864- f"Invalid annotation for { argname !r} . "
865- f"{ cls !r} is not a class."
866- )
867- registry [cls ] = func
870+ if not isinstance (cls , type ) and not _is_valid_union_type (cls ):
871+ if _is_union_type (cls ):
872+ raise TypeError (
873+ f"Invalid annotation for { argname !r} . "
874+ f"{ cls !r} not all arguments are classes."
875+ )
876+ else :
877+ raise TypeError (
878+ f"Invalid annotation for { argname !r} . "
879+ f"{ cls !r} is not a class."
880+ )
881+
882+ if _is_union_type (cls ):
883+ from typing import get_args
884+
885+ for arg in get_args (cls ):
886+ registry [arg ] = func
887+ else :
888+ registry [cls ] = func
868889 if cache_token is None and hasattr (cls , '__abstractmethods__' ):
869890 cache_token = get_cache_token ()
870891 dispatch_cache .clear ()
0 commit comments