131 changes: 114 additions & 17 deletions src/amltk/scheduling/plugins/pynisher.py
@@ -82,14 +82,17 @@ def callback(exception):
""" # noqa: E501
from __future__ import annotations

import traceback
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing.context import BaseContext
from typing import TYPE_CHECKING, ClassVar, Literal, TypeAlias, TypeVar
from typing import TYPE_CHECKING, ClassVar, Generic, Literal, TypeAlias, TypeVar
from typing_extensions import ParamSpec, Self, override

import pynisher
import pynisher.exceptions

from amltk.optimization import Trial
from amltk.scheduling.events import Event
from amltk.scheduling.plugins.plugin import Plugin

@@ -104,6 +107,56 @@ def callback(exception):
R = TypeVar("R")


@dataclass
class _PynisherWrap(Generic[P, R]):
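    """Wrap `fn` with pynisher resource limits when any limit is set.

    If a `Trial` is detected in the call arguments and trial handling is
    enabled, pynisher-specific exceptions are caught and turned into a
    failed `Trial.Report` instead of being raised.
    """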
fn: Callable[P, R]
memory_limit: int | tuple[int, str] | None = None
cputime_limit: int | tuple[float, str] | None = None
walltime_limit: int | tuple[float, str] | None = None
terminate_child_processes: bool = True
context: BaseContext | None = None
disable_trial_handling: bool = False

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
if any(
limit is not None
for limit in (self.memory_limit, self.cputime_limit, self.walltime_limit)
):
fn = pynisher.Pynisher(
self.fn,
memory=self.memory_limit,
cpu_time=self.cputime_limit,
wall_time=self.walltime_limit,
terminate_child_processes=self.terminate_child_processes,
context=self.context,
)
else:
fn = self.fn

trial: Trial | None = None
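        # Auto-detect a `Trial` passed either as the first positional
        # argument or as the `trial=` keyword argument.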
if not self.disable_trial_handling:
if len(args) > 0 and isinstance(args[0], Trial):
trial = args[0]
elif (_trial := kwargs.get("trial")) is not None:
if not isinstance(_trial, Trial):
raise ValueError(
    f"Expected 'trial' to be a Trial instance, got {type(_trial)}"
    f"\n{_trial=}",
)
trial = _trial

if trial is not None:
try:
return fn(*args, **kwargs)
except pynisher.PynisherException as e:
tb = traceback.format_exc()
trial.exception = e
trial.traceback = tb
return trial.fail() # type: ignore
else:
return fn(*args, **kwargs)


class PynisherPlugin(Plugin):
"""A plugin that wraps a task in a pynisher to enforce limits on it.

@@ -259,6 +312,7 @@ def __init__(
cputime_limit: int | tuple[float, str] | None = None,
walltime_limit: int | tuple[float, str] | None = None,
context: BaseContext | None = None,
disable_trial_handling: bool = False,
):
"""Initialize a `PynisherPlugin` instance.

@@ -274,12 +328,61 @@ def __init__(
`("s", "m", "h")`. Defaults to `None`.
context: The context to use for multiprocessing. Defaults to `None`.
See [`multiprocessing.get_context()`][multiprocessing.get_context]
disable_trial_handling: By default, the `PynisherPlugin` will auto-detect
    whether the task is one for a `Trial`. If so, it will catch any
    pynisher-specific exceptions and return a `Trial.Report` with
    `trial.fail()` instead of raising the exception. The report can then
    be caught with `task.on_result`, where it can be handled. This also
    prevents the specific events `@pynisher-timeout`,
    `@pynisher-memory-limit`, `@pynisher-cputime-limit` and
    `@pynisher-walltime-limit` from being emitted.

    If this is `True`, the pynisher exceptions will be raised as normal
    and should be handled with `task.on_exception`, where there is no
    direct access to the submitted `Trial` (see the sketch after the
    note below).

??? note "Auto-Detection"

This is triggered if the first positional argument is a
`Trial` or if a keyword argument named `"trial"` is passed.

```python
from amltk.optimization import Trial
from amltk.scheduling import Scheduler

def trial_evaluator_one(trial: Trial, ...) -> int:
...

def trial_evaluator_two(..., trial: Trial) -> int:
...

scheduler = Scheduler.with_processes(1)

task_one = scheduler.task(
    trial_evaluator_one,
    plugins=PynisherPlugin(memory_limit=(1, "gb")),
)
task_two = scheduler.task(
    trial_evaluator_two,
    plugins=PynisherPlugin(memory_limit=(1, "gb")),
)

# Will auto-detect
trial = Trial(...)
task_one.submit(trial, ...)
task_two.submit(..., trial=trial)

# Will not auto-detect
task_one.submit(42, trial)
```
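
For the `disable_trial_handling=True` path, a minimal sketch of handling the
raised exception (the handler name and its `(future, exception)` signature
are illustrative assumptions, not part of this change):

```python
task = scheduler.task(
    trial_evaluator_one,
    plugins=PynisherPlugin(
        memory_limit=(1, "gb"),
        disable_trial_handling=True,
    ),
)

@task.on_exception
def handle_limit(future, exception: BaseException) -> None:
    # The raw pynisher exception arrives here; the submitted Trial is not
    # directly accessible in this callback.
    if isinstance(exception, PynisherPlugin.MemoryLimitException):
        ...
```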
"""
super().__init__()
self.memory_limit = memory_limit
self.cputime_limit = cputime_limit
self.walltime_limit = walltime_limit
self.context = context
self.disable_trial_handling = disable_trial_handling

self.task: Task

@@ -291,22 +394,16 @@ def pre_submit(
**kwargs: P.kwargs,
) -> tuple[Callable[P, R], tuple, dict]:
"""Wrap a task function in a `Pynisher` instance."""
# If any of our limits is set, we need to wrap it in Pynisher
# to enforce these limits.
if any(
limit is not None
for limit in (self.memory_limit, self.cputime_limit, self.walltime_limit)
):
fn = pynisher.Pynisher(
fn,
memory=self.memory_limit,
cpu_time=self.cputime_limit,
wall_time=self.walltime_limit,
terminate_child_processes=True,
context=self.context,
)

return fn, args, kwargs
_fn = _PynisherWrap(
fn,
disable_trial_handling=self.disable_trial_handling,
memory_limit=self.memory_limit,
cputime_limit=self.cputime_limit,
walltime_limit=self.walltime_limit,
context=self.context,
terminate_child_processes=True,
)
return _fn, args, kwargs

@override
def attach_task(self, task: Task) -> None:
114 changes: 114 additions & 0 deletions tests/scheduling/plugins/test_pynisher_plugin.py
@@ -12,6 +12,7 @@
from distributed.cfexecutor import ClientExecutor
from pytest_cases import case, fixture, parametrize_with_cases

from amltk.optimization import Trial
from amltk.scheduling import ExitState, Scheduler
from amltk.scheduling.plugins.pynisher import PynisherPlugin

@@ -58,6 +59,25 @@ def big_memory_function(mem_in_bytes: int) -> bytearray:
return z # noqa: RET504


def trial_with_big_memory(trial: Trial, mem_in_bytes: int) -> Trial.Report:
with trial.begin():
pass

# We're particularly interested when the memory error happens during the
# task execution, not during the trial begin period
big_memory_function(mem_in_bytes)

return trial.success()


def trial_with_time_wasting(trial: Trial, duration: int) -> Trial.Report:
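    # Mirror `trial_with_big_memory`: waste time both inside and outside the
    # `trial.begin()` block so the walltime limit can trigger during task
    # execution.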
with trial.begin():
time_wasting_function(duration)

time_wasting_function(duration)
return trial.success()


def time_wasting_function(duration: int) -> int:
time.sleep(duration)
return duration
@@ -253,3 +273,97 @@ def start_task() -> None:
)
assert end_status.code == ExitState.Code.EXCEPTION
assert isinstance(end_status.exception, PynisherPlugin.WallTimeoutException)


def test_trial_gets_autodetect_memory(scheduler: Scheduler) -> None:
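    # The trial is passed positionally; the plugin should auto-detect it and
    # turn the memory-limit breach into a failed report delivered via
    # `on_result`, rather than raising.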
if not PynisherPlugin.supports("memory"):
pytest.skip("Pynisher does not support memory limits on this system")

one_half_gb = int(1e9 * 1.5)
two_gb = int(1e9) * 2
task = scheduler.task(
trial_with_big_memory,
plugins=PynisherPlugin(
memory_limit=one_half_gb,
disable_trial_handling=False,
),
)
trial = Trial(name="test_trial", config={})

@scheduler.on_start
def start_task() -> None:
task.submit(trial, mem_in_bytes=two_gb)

reports: list[Trial.Report] = []

@task.on_result
def trial_report(_, report: Trial.Report) -> None:
reports.append(report)

status = scheduler.run(on_exception="raise")
assert status.code == ExitState.Code.EXHAUSTED

assert task.event_counts == Counter(
{task.SUBMITTED: 1, task.DONE: 1, task.RESULT: 1},
)

assert scheduler.event_counts == Counter(
{
scheduler.STARTED: 1,
scheduler.FUTURE_RESULT: 1,
scheduler.FINISHING: 1,
scheduler.FINISHED: 1,
scheduler.EMPTY: 1,
scheduler.FUTURE_SUBMITTED: 1,
scheduler.FUTURE_DONE: 1,
},
)
assert len(reports) == 1
assert reports[0].status == Trial.Status.FAIL
assert isinstance(reports[0].exception, PynisherPlugin.MemoryLimitException)


def test_trial_gets_autodetect_time(scheduler: Scheduler) -> None:
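    # The trial is passed as the `trial=` keyword; the walltime breach should
    # likewise surface as a failed report via `on_result` rather than as an
    # exception.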
if not PynisherPlugin.supports("wall_time"):
pytest.skip("Pynisher does not support wall_time limits on this system")

task = scheduler.task(
trial_with_time_wasting,
plugins=PynisherPlugin(
walltime_limit=1,
disable_trial_handling=False,
),
)
trial = Trial(name="test_trial", config={})

@scheduler.on_start
def start_task() -> None:
task.submit(trial=trial, duration=3)

reports: list[Trial.Report] = []

@task.on_result
def trial_report(_, report: Trial.Report) -> None:
reports.append(report)

status = scheduler.run(on_exception="raise")
assert status.code == ExitState.Code.EXHAUSTED

assert task.event_counts == Counter(
{task.SUBMITTED: 1, task.DONE: 1, task.RESULT: 1},
)

assert scheduler.event_counts == Counter(
{
scheduler.STARTED: 1,
scheduler.FUTURE_RESULT: 1,
scheduler.FINISHING: 1,
scheduler.FINISHED: 1,
scheduler.EMPTY: 1,
scheduler.FUTURE_SUBMITTED: 1,
scheduler.FUTURE_DONE: 1,
},
)
assert len(reports) == 1
assert reports[0].status == Trial.Status.FAIL
assert isinstance(reports[0].exception, PynisherPlugin.WallTimeoutException)