Worker killed error #78

Merged (7 commits) on Oct 27, 2024

6 changes: 4 additions & 2 deletions src/torchrunx/__init__.py
@@ -1,8 +1,10 @@
-from .launcher import AgentKilledError, Launcher, LaunchResult, launch
+from .launcher import Launcher, LaunchResult, launch
 from .logging_utils import add_filter_to_handler, file_handler, stream_handler
+from .utils import AgentFailedError, WorkerFailedError

 __all__ = [
-    "AgentKilledError",
+    "AgentFailedError",
+    "WorkerFailedError",
     "Launcher",
     "launch",
     "LaunchResult",
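With these exports, callers can catch the new error types directly from the top-level package. A minimal sketch of the user-facing effect, assuming a hypothetical zero-configuration launch(fn) call (the real launch() parameters are not shown in this diff):

import torchrunx


def train() -> None:
    # runs on every worker process
    ...


try:
    torchrunx.launch(train)  # placeholder call; actual arguments omitted here
except torchrunx.AgentFailedError:
    # an agent process (or the launcher's connection to it) died or timed out
    ...
except torchrunx.WorkerFailedError:
    # a worker process exited abnormally (e.g. segfault or OOM kill) without returning
    ...

Exceptions raised inside the worker function itself are re-raised as-is by the launcher (see launcher.py below), so they are not wrapped by either of these types.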
12 changes: 8 additions & 4 deletions src/torchrunx/agent.py
@@ -19,8 +19,8 @@
 from .utils import (
     AgentPayload,
     AgentStatus,
+    ExceptionFromWorker,
     LauncherAgentGroup,
-    WorkerException,
     get_open_port,
 )

@@ -52,7 +52,7 @@ def deserialize(self) -> WorkerArgs:
         return cloudpickle.loads(self.bytes)


-def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException:
+def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker:
     worker_args: WorkerArgs = serialized_worker_args.deserialize()

     logger = logging.getLogger()
@@ -96,7 +96,7 @@ def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerExce
         return worker_args.function()
     except Exception as e:
         traceback.print_exc()
-        return WorkerException(exception=e)
+        return ExceptionFromWorker(exception=e)
     finally:
         sys.stdout.flush()
         sys.stderr.flush()
@@ -155,7 +155,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
             )
             for i in range(num_workers)
         },
+        # environment variables from agent are already automatically copied to workers
         envs={i: {} for i in range(num_workers)},
+        # we handle logging ourselves, so we can discard these
         **(
             {"logs_specs": dist_mp.DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
             if torch.__version__ >= "2.3"
@@ -167,8 +169,10 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
     status = None
     while True:
         if status is None or status.state == "running":
-            status = AgentStatus.from_result(ctx.wait(5))
+            # status can contain ExceptionFromWorker or WorkerFailedError
+            status = AgentStatus.from_result(result=ctx.wait(5))

+        # can raise AgentFailedError in launcher and all agents
         agent_statuses = launcher_agent_group.sync_agent_statuses(status=status)

         all_done = all(s.state == "done" for s in agent_statuses)
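The entrypoint change keeps the existing pattern of returning worker exceptions instead of letting them crash the process; only the name changes from WorkerException to ExceptionFromWorker. A self-contained sketch of that pattern, independent of the torchrunx internals:

from __future__ import annotations

import traceback
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ExceptionFromWorker:
    # wraps an exception raised by user code inside a worker process
    exception: Exception


def run_user_function(fn: Callable[[], Any]) -> Any | ExceptionFromWorker:
    # The worker never dies on a user exception: it captures and returns it,
    # so the launcher can re-raise it on the caller's side.
    try:
        return fn()
    except Exception as e:
        traceback.print_exc()
        return ExceptionFromWorker(exception=e)

WorkerFailedError covers the complementary case: the process exits without returning anything at all, so there is no exception object to ship back.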
24 changes: 13 additions & 11 deletions src/torchrunx/launcher.py
@@ -22,11 +22,14 @@
 from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
 from .logging_utils import LogRecordSocketReceiver, default_handlers
-from .utils import AgentStatus, LauncherAgentGroup, LauncherPayload, WorkerException, get_open_port
-
-
-class AgentKilledError(Exception):
-    pass
+from .utils import (
+    AgentStatus,
+    ExceptionFromWorker,
+    LauncherAgentGroup,
+    LauncherPayload,
+    WorkerFailedError,
+    get_open_port,
+)


 @dataclass
@@ -141,17 +144,16 @@ def run( # noqa: C901, PLR0912
         # loop to monitor agent statuses (until failed or done)

         while True:
-            try:
-                agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
-            except RuntimeError as e:
-                # occurs if any agent dies and communication times out
-                raise AgentKilledError from e
+            # could raise AgentFailedError
+            agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)

+            # raises specific exception if any agent fails
             for s in agent_statuses:
                 for value in s.return_values:
-                    if isinstance(value, WorkerException):
+                    if isinstance(value, ExceptionFromWorker):
                         raise value.exception
+                    if isinstance(value, WorkerFailedError):
+                        raise value

             if all(s.state == "done" for s in agent_statuses):
                 break
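Because the launcher re-raises value.exception directly, the original exception type raised inside a worker surfaces unchanged in the launcher process, while a worker killed without returning surfaces as WorkerFailedError. A hedged sketch of the caller-side difference (placeholder launch() arguments, as above):

import torchrunx


def worker_that_raises() -> None:
    msg = "bad input on this rank"
    raise ValueError(msg)


try:
    torchrunx.launch(worker_that_raises)  # placeholder call; real arguments omitted
except ValueError as exc:
    # the worker's own ValueError, re-raised verbatim by the launcher
    print(f"worker raised: {exc}")
except torchrunx.WorkerFailedError:
    # raised instead if the worker process died (e.g. SIGKILL) before returning
    pass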
32 changes: 22 additions & 10 deletions src/torchrunx/utils.py
@@ -20,6 +20,14 @@ def get_open_port() -> int:
     return s.getsockname()[1]


+class AgentFailedError(Exception):
+    pass
+
+
+class WorkerFailedError(Exception):
+    pass
+
+
 @dataclass
 class LauncherAgentGroup:
     launcher_hostname: str
@@ -50,11 +58,15 @@ def _deserialize(self, serialized: bytes) -> Any:

     def _all_gather(self, obj: Any) -> list:
         """gather object from every rank to list on every rank"""
-        object_bytes = self._serialize(obj)
-        object_list = [b""] * self.world_size
-        # raises RuntimeError if timeout
-        dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group)
-        return [self._deserialize(o) for o in object_list]
+        try:
+            object_bytes = self._serialize(obj)
+            object_list = [b""] * self.world_size
+            # raises RuntimeError if timeout
+            dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group)
+            return [self._deserialize(o) for o in object_list]
+        except RuntimeError as e:
+            # occurs if launcher or any agent dies and communication times out
+            raise AgentFailedError from e

     def sync_payloads(
         self,
@@ -90,25 +102,25 @@ class AgentPayload:


 @dataclass
-class WorkerException:
+class ExceptionFromWorker:
     exception: Exception


 @dataclass
 class AgentStatus:
     state: Literal["running", "failed", "done"]
-    return_values: list[Any | WorkerException] = field(
+    return_values: list[Any | WorkerFailedError | ExceptionFromWorker] = field(
         default_factory=list
     )  # indexed by local rank

     @classmethod
     def from_result(cls, result: RunProcsResult | None) -> Self:
         if result is None:
             return cls(state="running")

+        for local_rank, failure in result.failures.items():
+            result.return_values[local_rank] = WorkerFailedError(failure.message)
         return_values = list(result.return_values.values())

-        failed = any(isinstance(v, WorkerException) for v in return_values)
+        failed = any(isinstance(v, ExceptionFromWorker) for v in return_values)
         state = "failed" if failed else "done"

         return cls(
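For reference, the new from_result semantics can be summarized without torch's RunProcsResult type. A standalone sketch, using plain dicts in place of RunProcsResult (an illustration of the mapping, not the module's actual API):

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal


class WorkerFailedError(Exception):
    pass


@dataclass
class ExceptionFromWorker:
    exception: Exception


@dataclass
class AgentStatus:
    state: Literal["running", "failed", "done"]
    return_values: list[Any] = field(default_factory=list)  # indexed by local rank


def status_from_result(
    return_values: dict[int, Any] | None,     # per-rank results, or None if still running
    failures: dict[int, str] | None = None,   # per-rank failure messages for dead processes
) -> AgentStatus:
    # Mirrors AgentStatus.from_result: dead processes become WorkerFailedError
    # entries, and any ExceptionFromWorker marks the agent as "failed".
    if return_values is None:
        return AgentStatus(state="running")

    values = dict(return_values)
    for local_rank, message in (failures or {}).items():
        values[local_rank] = WorkerFailedError(message)

    ordered = [values[rank] for rank in sorted(values)]
    failed = any(isinstance(v, ExceptionFromWorker) for v in ordered)
    return AgentStatus(state="failed" if failed else "done", return_values=ordered)

Note that, as in the diff, only ExceptionFromWorker flips the state to "failed"; WorkerFailedError entries ride along in return_values and are raised by the launcher when it scans agent statuses.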