Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
additional_dependencies: [ "flake8-eradicate==0.4.0" ]

- repo: https://github.com/pycqa/isort
rev: 5.6.4
rev: 5.9.3
hooks:
- id: isort
args: ["--profile", "black"]
Expand Down
32 changes: 32 additions & 0 deletions src/histolab/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import functools
import warnings
from functools import singledispatch, update_wrapper
from typing import Any, Callable, List, Tuple

import numpy as np
Expand Down Expand Up @@ -313,3 +314,34 @@ def fget(self):
"""
# pylint: disable=unused-variable
return property(functools.lru_cache(maxsize=100)(f))


def method_dispatch(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator like @singledispatch to dispatch on the second argument of a method.

    It relies on @singledispatch to return a wrapper function that selects which
    registered function to call based on the type of the second positional argument
    (the first one after ``self``/``cls``).

    This implementation is required in order to be compatible with Python versions
    older than 3.8. In the future we could use ``functools.singledispatchmethod``.

    Source: https://stackoverflow.com/a/24602374/7162549

    Parameters
    ----------
    func : Callable[..., Any]
        Method to dispatch. It is used as the fallback implementation for types
        with no more specific registration.

    Returns
    -------
    Callable[..., Any]
        Wrapper that calls the implementation registered for the type of the
        second positional argument. Like a ``singledispatch`` function, it exposes
        ``register``, ``dispatch`` and ``registry``.

    Raises
    ------
    TypeError
        When the returned wrapper is called with fewer than two positional
        arguments, i.e. without an argument to dispatch on.
    """
    dispatcher = singledispatch(func)

    def wrapper(*args, **kw):
        # args[0] is the instance (or class); dispatch on the argument after it.
        # Fail with a clear TypeError (as singledispatch itself does) instead of
        # an opaque IndexError when there is nothing to dispatch on.
        if len(args) < 2:
            raise TypeError(
                f"{getattr(func, '__name__', 'method')} requires at least "
                "1 positional argument after self"
            )
        return dispatcher.dispatch(args[1].__class__)(*args, **kw)

    wrapper.register = dispatcher.register
    # Mirror the rest of the public singledispatch API for introspection.
    wrapper.dispatch = dispatcher.dispatch
    wrapper.registry = dispatcher.registry
    update_wrapper(wrapper, func)
    return wrapper
50 changes: 35 additions & 15 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,12 @@

import numpy as np
import pytest
from tests.base import (
COMPLEX_MASK,
IMAGE1_GRAY,
IMAGE1_RGB,
IMAGE1_RGBA,
IMAGE2_GRAY,
IMAGE2_RGB,
IMAGE2_RGBA,
IMAGE3_GRAY_BLACK,
IMAGE3_RGB_BLACK,
IMAGE3_RGBA_BLACK,
IMAGE4_GRAY_WHITE,
IMAGE4_RGB_WHITE,
IMAGE4_RGBA_WHITE,
)

from histolab.types import CP, Region
from histolab.util import (
apply_mask_image,
lazyproperty,
method_dispatch,
np_to_pil,
random_choice_true_mask2d,
rectangle_to_mask,
Expand All @@ -35,6 +21,21 @@
threshold_to_mask,
)

from ..base import (
COMPLEX_MASK,
IMAGE1_GRAY,
IMAGE1_RGB,
IMAGE1_RGBA,
IMAGE2_GRAY,
IMAGE2_RGB,
IMAGE2_RGBA,
IMAGE3_GRAY_BLACK,
IMAGE3_RGB_BLACK,
IMAGE3_RGBA_BLACK,
IMAGE4_GRAY_WHITE,
IMAGE4_RGB_WHITE,
IMAGE4_RGBA_WHITE,
)
from ..fixtures import MASKNPY, NPY
from ..util import load_expectation, load_python_expression

Expand Down Expand Up @@ -239,3 +240,22 @@ def fget(self):
@pytest.fixture
def obj(self, Obj):
    """Provide a fresh instance built by calling ``Obj`` with no arguments.

    ``Obj`` is injected by pytest by name; presumably it is a fixture defined
    alongside this one (outside the visible range) - confirm.
    """
    instance = Obj()
    return instance


def test_method_dispatch():
    """``method_dispatch`` routes calls on the type of the argument after ``self``.

    A plain string falls through to the base implementation, while a list is
    handled by the implementation registered for ``list``.
    """

    class Obj:
        def __init__(self, **kwargs):
            # Expose every keyword argument as an instance attribute.
            self.__dict__.update(kwargs)

        @method_dispatch
        def get(self, arg):
            return getattr(self, arg, None)

        @get.register(list)
        def _(self, arg):
            return list(map(self.get, arg))

    instance = Obj(a=1, b=2, c=3)

    assert instance.get("b") == 2
    assert instance.get(["a", "c"]) == [1, 3]