Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
29 changes: 25 additions & 4 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,13 @@ def __call__(
sum(
(
event_filter is not None,
every is not None,
once is not None,
(before is not None or after is not None),
(every is not None or before is not None or after is not None),
)
)
!= 1
):
raise ValueError("Only one of the input arguments should be specified except before and after")
raise ValueError("Only one of the input arguments should be specified, except before, after and every")

if (event_filter is not None) and not callable(event_filter):
raise TypeError("Argument event_filter should be a callable")
Expand Down Expand Up @@ -116,7 +115,10 @@ def __call__(
event_filter = self.once_event_filter([once] if isinstance(once, int) else once)

if before is not None or after is not None:
event_filter = self.before_and_after_event_filter(before, after)
if every is not None:
event_filter = self.every_before_and_after_event_filter(every, before, after)
else:
event_filter = self.before_and_after_event_filter(before, after)

# check signature:
if event_filter is not None:
Expand Down Expand Up @@ -159,6 +161,21 @@ def wrapper(engine: "Engine", event: int) -> bool:

return wrapper

@staticmethod
def every_before_and_after_event_filter(
    every: int, before: Optional[int] = None, after: Optional[int] = None
) -> Callable:
    """Build a filter that fires periodically inside an open window of event numbers.

    The returned callable accepts ``(engine, event)`` and returns ``True`` only when
    ``event`` lies strictly between ``after`` and ``before`` and is the first event of
    the window or a multiple of ``every`` events past it (``after + 1``,
    ``after + 1 + every``, ...).

    Args:
        every: stride between two accepted events.
        before: exclusive upper bound; ``None`` means no upper bound.
        after: exclusive lower bound; ``None`` is treated as ``0``.

    Returns:
        The filter function; its ``engine`` argument is not used.
    """
    lower_bound: int = after if after is not None else 0
    upper_bound: Union[int, float] = before if before is not None else float("inf")

    def wrapper(engine: "Engine", event: int) -> bool:
        # Both bounds are exclusive: anything on or outside them never fires.
        if event <= lower_bound or event >= upper_bound:
            return False
        # Offset by one so that the first event inside the window (after + 1) fires.
        return (event - lower_bound - 1) % every == 0

    return wrapper

@staticmethod
def default_event_filter(engine: "Engine", event: int) -> bool:
"""Default event filter. This method is deprecated and will be removed. Please, use None instead"""
Expand Down Expand Up @@ -301,6 +318,10 @@ def call_once(engine):
def call_before(engine):
# do something in 11 to 29 epoch

# e) Mixing "every" and "before" / "after" event filters
@engine.on(Events.EPOCH_STARTED(every=5, before=25, after=8))
def call_every_itr_before_after(engine):
# do something on 9, 14, 19, 24 epochs

Event filter function `event_filter` accepts as input `engine` and `event` and should return True/False.
Argument `event` is the value of iteration or epoch, depending on which type of Events the function is passed.
Expand Down
96 changes: 79 additions & 17 deletions tests/ignite/engine/test_custom_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,23 +129,62 @@ def process_func(engine, batch):


def test_callable_events_with_wrong_inputs():

with pytest.raises(
ValueError, match=r"Only one of the input arguments should be specified except before and after"
):
Events.ITERATION_STARTED()

with pytest.raises(
ValueError, match=r"Only one of the input arguments should be specified except before and after"
):
Events.ITERATION_STARTED(event_filter="123", every=12)

with pytest.raises(
ValueError, match=r"Only one of the input arguments should be specified except before and after"
):
Events.ITERATION_STARTED(after=10, before=30, once=1)

assert Events.ITERATION_STARTED(after=10, before=30)
def ef(e, i):
return 1

expected_raise = {
# event_filter, every, once, before, after
(None, None, None, None, None): True, # raises ValueError
(ef, None, None, None, None): False,
(None, 2, None, None, None): False,
(ef, 2, None, None, None): True,
(None, None, 2, None, None): False,
(ef, None, 2, None, None): True,
(None, 2, 2, None, None): True,
(ef, 2, 2, None, None): True,
(None, None, None, 30, None): False,
(ef, None, None, 30, None): True,
(None, 2, None, 30, None): False,
(ef, 2, None, 30, None): True,
(None, None, 2, 30, None): True,
(ef, None, 2, 30, None): True,
(None, 2, 2, 30, None): True,
(ef, 2, 2, 30, None): True,
# event_filter, every, once, before, after
(None, None, None, None, 10): False,
(ef, None, None, None, 10): True,
(None, 2, None, None, 10): False,
(ef, 2, None, None, 10): True,
(None, None, 2, None, 10): True,
(ef, None, 2, None, 10): True,
(None, 2, 2, None, 10): True,
(ef, 2, 2, None, 10): True,
(None, None, None, 25, 8): False,
(ef, None, None, 25, 8): True,
(None, 2, None, 25, 8): False,
(ef, 2, None, 25, 8): True,
(None, None, 2, 25, 8): True,
(ef, None, 2, 25, 8): True,
(None, 2, 2, 25, 8): True,
(ef, 2, 2, 25, 8): True,
}
for event_filter in [None, ef]:
for every in [None, 2]:
for once in [None, 2]:
for before, after in [(None, None), (None, 10), (30, None), (25, 8)]:
if expected_raise[(event_filter, every, once, before, after)]:
with pytest.raises(
ValueError,
match=r"Only one of the input arguments should be specified, "
"except before, after and every",
):
Events.ITERATION_STARTED(
event_filter=event_filter, once=once, every=every, before=before, after=after
)
else:
Events.ITERATION_STARTED(
event_filter=event_filter, once=once, every=every, before=before, after=after
)

with pytest.raises(TypeError, match=r"Argument event_filter should be a callable"):
Events.ITERATION_STARTED(event_filter="123")
Expand Down Expand Up @@ -408,6 +447,29 @@ def _before_and_after_event():
assert num_calls == expect_calls


@pytest.mark.parametrize(
    "event_name, event_attr, every, before, after, expect_calls",
    [(Events.ITERATION_STARTED, "iteration", 5, 25, 8, 4), (Events.EPOCH_COMPLETED, "epoch", 2, 5, 1, 2)],
)
def test_every_before_and_after_event_filter_with_engine(event_name, event_attr, every, before, after, expect_calls):
    """Check the combined ``every`` / ``before`` / ``after`` filter on a running engine.

    A handler is attached with all three arguments; each time it fires it verifies
    that the current counter (``engine.state.<event_attr>``) is strictly inside
    ``(after, before)`` and on the ``every`` stride counted from ``after + 1``.
    After 5 epochs over 100 items the number of firings must equal ``expect_calls``.
    """
    engine = Engine(lambda e, b: 1)
    num_calls = 0

    @engine.on(event_name(every=every, before=before, after=after))
    def _every_before_and_after_event():
        nonlocal num_calls
        # Read the counter once; it is either the iteration or the epoch number.
        counter = getattr(engine.state, event_attr)
        assert after < counter < before
        assert (counter - after - 1) % every == 0
        num_calls += 1

    engine.run(range(100), max_epochs=5)
    assert num_calls == expect_calls


@pytest.mark.parametrize(
"event_name, event_attr, once, expect_calls",
[
Expand Down