diff --git a/conftest.py b/conftest.py index b1c3bdacd..721f99a4d 100644 --- a/conftest.py +++ b/conftest.py @@ -64,6 +64,7 @@ RunSettings, SrunSettings, ) +from collections.abc import Callable, Collection logger = get_logger(__name__) @@ -79,7 +80,7 @@ test_alloc_specs_path = os.getenv("SMARTSIM_TEST_ALLOC_SPEC_SHEET_PATH", None) test_ports = CONFIG.test_ports test_account = CONFIG.test_account or "" -test_batch_resources: t.Dict[t.Any, t.Any] = CONFIG.test_batch_resources +test_batch_resources: dict[t.Any, t.Any] = CONFIG.test_batch_resources test_output_dirs = 0 mpi_app_exe = None built_mpi_app = False @@ -169,7 +170,7 @@ def pytest_sessionfinish( kill_all_test_spawned_processes() -def build_mpi_app() -> t.Optional[pathlib.Path]: +def build_mpi_app() -> pathlib.Path | None: global built_mpi_app built_mpi_app = True cc = shutil.which("cc") @@ -190,7 +191,7 @@ def build_mpi_app() -> t.Optional[pathlib.Path]: return None @pytest.fixture(scope="session") -def mpi_app_path() -> t.Optional[pathlib.Path]: +def mpi_app_path() -> pathlib.Path | None: """Return path to MPI app if it was built return None if it could not or will not be built @@ -223,7 +224,7 @@ def kill_all_test_spawned_processes() -> None: -def get_hostlist() -> t.Optional[t.List[str]]: +def get_hostlist() -> list[str] | None: global test_hostlist if not test_hostlist: if "PBS_NODEFILE" in os.environ and test_launcher == "pals": @@ -251,14 +252,14 @@ def get_hostlist() -> t.Optional[t.List[str]]: return test_hostlist -def _parse_hostlist_file(path: str) -> t.List[str]: +def _parse_hostlist_file(path: str) -> list[str]: with open(path, "r", encoding="utf-8") as nodefile: return list({line.strip() for line in nodefile.readlines()}) @pytest.fixture(scope="session") -def alloc_specs() -> t.Dict[str, t.Any]: - specs: t.Dict[str, t.Any] = {} +def alloc_specs() -> dict[str, t.Any]: + specs: dict[str, t.Any] = {} if test_alloc_specs_path: try: with open(test_alloc_specs_path, encoding="utf-8") as spec_file: @@ -293,7 +294,7 @@ def _reset(): ) -def _find_free_port(ports: t.Collection[int]) -> int: +def _find_free_port(ports: Collection[int]) -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: for port in ports: try: @@ -310,7 +311,7 @@ def _find_free_port(ports: t.Collection[int]) -> int: @pytest.fixture(scope="session") -def wlmutils() -> t.Type[WLMUtils]: +def wlmutils() -> type[WLMUtils]: return WLMUtils @@ -335,22 +336,22 @@ def get_test_account() -> str: return get_account() @staticmethod - def get_test_interface() -> t.List[str]: + def get_test_interface() -> list[str]: return test_nic @staticmethod - def get_test_hostlist() -> t.Optional[t.List[str]]: + def get_test_hostlist() -> list[str] | None: return get_hostlist() @staticmethod - def get_batch_resources() -> t.Dict: + def get_batch_resources() -> dict: return test_batch_resources @staticmethod def get_base_run_settings( - exe: str, args: t.List[str], nodes: int = 1, ntasks: int = 1, **kwargs: t.Any + exe: str, args: list[str], nodes: int = 1, ntasks: int = 1, **kwargs: t.Any ) -> RunSettings: - run_args: t.Dict[str, t.Union[int, str, float, None]] = {} + run_args: dict[str, int, str | float | None] = {} if test_launcher == "slurm": run_args = {"--nodes": nodes, "--ntasks": ntasks, "--time": "00:10:00"} @@ -391,9 +392,9 @@ def get_base_run_settings( @staticmethod def get_run_settings( - exe: str, args: t.List[str], nodes: int = 1, ntasks: int = 1, **kwargs: t.Any + exe: str, args: list[str], nodes: int = 1, ntasks: int = 1, **kwargs: t.Any ) -> RunSettings: - run_args: t.Dict[str, t.Union[int, str, float, None]] = {} + run_args: dict[str, int, str | float | None] = {} if test_launcher == "slurm": run_args = {"nodes": nodes, "ntasks": ntasks, "time": "00:10:00"} @@ -423,7 +424,7 @@ def get_run_settings( return RunSettings(exe, args) @staticmethod - def choose_host(rs: RunSettings) -> t.Optional[str]: + def choose_host(rs: RunSettings) -> str | None: if isinstance(rs, (MpirunSettings, MpiexecSettings)): hl = get_hostlist() if hl is not None: @@ -450,13 +451,13 @@ def check_output_dir() -> None: @pytest.fixture -def dbutils() -> t.Type[DBUtils]: +def dbutils() -> type[DBUtils]: return DBUtils class DBUtils: @staticmethod - def get_db_configs() -> t.Dict[str, t.Any]: + def get_db_configs() -> dict[str, t.Any]: config_settings = { "enable_checkpoints": 1, "set_max_memory": "3gb", @@ -470,7 +471,7 @@ def get_db_configs() -> t.Dict[str, t.Any]: return config_settings @staticmethod - def get_smartsim_error_db_configs() -> t.Dict[str, t.Any]: + def get_smartsim_error_db_configs() -> dict[str, t.Any]: bad_configs = { "save": [ "-1", # frequency must be positive @@ -497,8 +498,8 @@ def get_smartsim_error_db_configs() -> t.Dict[str, t.Any]: return bad_configs @staticmethod - def get_type_error_db_configs() -> t.Dict[t.Union[int, str], t.Any]: - bad_configs: t.Dict[t.Union[int, str], t.Any] = { + def get_type_error_db_configs() -> dict[int | str, t.Any]: + bad_configs: dict[int | str, t.Any] = { "save": [2, True, ["2"]], # frequency must be specified as a string "maxmemory": [99, True, ["99"]], # memory form must be a string "maxclients": [3, True, ["3"]], # number of clients must be a string @@ -519,9 +520,9 @@ def get_type_error_db_configs() -> t.Dict[t.Union[int, str], t.Any]: @staticmethod def get_config_edit_method( db: Orchestrator, config_setting: str - ) -> t.Optional[t.Callable[..., None]]: + ) -> Callable[..., None] | None: """Get a db configuration file edit method from a str""" - config_edit_methods: t.Dict[str, t.Callable[..., None]] = { + config_edit_methods: dict[str, Callable[..., None]] = { "enable_checkpoints": db.enable_checkpoints, "set_max_memory": db.set_max_memory, "set_eviction_strategy": db.set_eviction_strategy, @@ -564,7 +565,7 @@ def test_dir(request: pytest.FixtureRequest) -> str: @pytest.fixture -def fileutils() -> t.Type[FileUtils]: +def fileutils() -> type[FileUtils]: return FileUtils @@ -589,7 +590,7 @@ def get_test_dir_path(dirname: str) -> str: @staticmethod def make_test_file( - file_name: str, file_dir: str, file_content: t.Optional[str] = None + file_name: str, file_dir: str, file_content: str | None = None ) -> str: """Create a dummy file in the test output directory. @@ -609,7 +610,7 @@ def make_test_file( @pytest.fixture -def mlutils() -> t.Type[MLUtils]: +def mlutils() -> type[MLUtils]: return MLUtils @@ -624,21 +625,21 @@ def get_test_num_gpus() -> int: @pytest.fixture -def coloutils() -> t.Type[ColoUtils]: +def coloutils() -> type[ColoUtils]: return ColoUtils class ColoUtils: @staticmethod def setup_test_colo( - fileutils: t.Type[FileUtils], + fileutils: type[FileUtils], db_type: str, exp: Experiment, application_file: str, - db_args: t.Dict[str, t.Any], - colo_settings: t.Optional[RunSettings] = None, + db_args: dict[str, t.Any], + colo_settings: RunSettings | None = None, colo_model_name: str = "colocated_model", - port: t.Optional[int] = None, + port: int | None = None, on_wlm: bool = False, ) -> Model: """Setup database needed for the colo pinning tests""" @@ -666,7 +667,7 @@ def setup_test_colo( socket_name = f"{colo_model_name}_{socket_suffix}.socket" db_args["unix_socket"] = os.path.join(tmp_dir, socket_name) - colocate_fun: t.Dict[str, t.Callable[..., None]] = { + colocate_fun: dict[str, Callable[..., None]] = { "tcp": colo_model.colocate_db_tcp, "deprecated": colo_model.colocate_db, "uds": colo_model.colocate_db_uds, @@ -708,7 +709,7 @@ def config() -> Config: class CountingCallable: def __init__(self) -> None: self._num: int = 0 - self._details: t.List[t.Tuple[t.Tuple[t.Any, ...], t.Dict[str, t.Any]]] = [] + self._details: list[tuple[tuple[t.Any, ...], dict[str, t.Any]]] = [] def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: self._num += 1 @@ -719,12 +720,12 @@ def num_calls(self) -> int: return self._num @property - def details(self) -> t.List[t.Tuple[t.Tuple[t.Any, ...], t.Dict[str, t.Any]]]: + def details(self) -> list[tuple[tuple[t.Any, ...], dict[str, t.Any]]]: return self._details ## Reuse database across tests -database_registry: t.DefaultDict[str, t.Optional[Orchestrator]] = defaultdict(lambda: None) +database_registry: defaultdict[str, Orchestrator | None] = defaultdict(lambda: None) @pytest.fixture(scope="function") def local_experiment(test_dir: str) -> smartsim.Experiment: @@ -758,13 +759,13 @@ class DBConfiguration: name: str launcher: str num_nodes: int - interface: t.Union[str,t.List[str]] - hostlist: t.Optional[t.List[str]] + interface: str | list[str] + hostlist: list[str] | None port: int @dataclass class PrepareDatabaseOutput: - orchestrator: t.Optional[Orchestrator] # The actual orchestrator object + orchestrator: Orchestrator | None # The actual orchestrator object new_db: bool # True if a new database was created when calling prepare_db # Reuse databases @@ -817,7 +818,7 @@ def clustered_db(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None] @pytest.fixture -def register_new_db() -> t.Callable[[DBConfiguration], Orchestrator]: +def register_new_db() -> Callable[[DBConfiguration], Orchestrator]: def _register_new_db( config: DBConfiguration ) -> Orchestrator: @@ -845,11 +846,11 @@ def _register_new_db( @pytest.fixture(scope="function") def prepare_db( - register_new_db: t.Callable[ + register_new_db: Callable[ [DBConfiguration], Orchestrator ] -) -> t.Callable[ +) -> Callable[ [DBConfiguration], PrepareDatabaseOutput ]: diff --git a/doc/changelog.md b/doc/changelog.md index 7c7ec4c78..2500a317a 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -11,6 +11,9 @@ To be released at some point in the future Description +- Modernize typing syntax to Python 3.10+ standards +- **BREAKING CHANGE**: Removed telemetry functionality, LaunchedManifest tracking + classes, and SmartDashboard integration - Removed telemetry functionality, LaunchedManifest tracking classes, and SmartDashboard integration - Update copyright headers from 2021-2024 to 2021-2025 across the entire codebase @@ -24,7 +27,11 @@ Description Detailed Notes -- Removed telemetry functionality, LaunchedManifest tracking +- Modernized typing syntax to use Python 3.10+ standards, replacing + `Union[X, Y]` with `X | Y`, `Optional[X]` with `X | None`, and generic + collections (`List[X]` → `list[X]`, `Dict[X, Y]` → `dict[X, Y]`, etc.). + ([SmartSim-PR791](https://github.com/CrayLabs/SmartSim/pull/791)) +- **BREAKING CHANGE**: Removed telemetry functionality, LaunchedManifest tracking system, and SmartDashboard integration. This includes complete removal of the telemetry monitor and collection system, telemetry configuration classes (`TelemetryConfiguration`, @@ -45,10 +52,6 @@ Detailed Notes `.smartsim/metadata/` structure for job output files with entity-specific subdirectories (`ensemble/{name}`, `model/{name}`, `database/{name}`) and proper symlink management. - Added new `CONFIG` path helpers (`smartsim_base_dir`, `metadata_subdir`, - `dragon_default_subdir`, `dragon_logs_subdir`) that now return - `pathlib.Path` instances to provide a single source of truth for SmartSim's - hidden workspace directories and Dragon launcher log locations. ([SmartSim-PR789](https://github.com/CrayLabs/SmartSim/pull/789)) - Copyright headers have been updated from "2021-2024" to "2021-2025" across 271 files including Python source files, configuration files, documentation, diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py index 18863e7d1..e3ce64f23 100644 --- a/smartsim/_core/_cli/build.py +++ b/smartsim/_core/_cli/build.py @@ -31,7 +31,7 @@ import re import shutil import textwrap -import typing as t +from collections.abc import Callable, Collection from pathlib import Path from tabulate import tabulate @@ -139,7 +139,7 @@ def build_redis_ai( def parse_requirement( requirement: str, -) -> t.Tuple[str, t.Optional[str], t.Callable[[Version_], bool]]: +) -> tuple[str, str | None, Callable[[Version_], bool]]: operators = { "==": operator.eq, "<=": operator.le, @@ -199,10 +199,10 @@ def check_ml_python_packages(packages: MLPackageCollection) -> None: def _format_incompatible_python_env_message( - missing: t.Collection[str], conflicting: t.Collection[str] + missing: Collection[str], conflicting: Collection[str] ) -> str: indent = "\n\t" - fmt_list: t.Callable[[str, t.Collection[str]], str] = lambda n, l: ( + fmt_list: Callable[[str, Collection[str]], str] = lambda n, l: ( f"{n}:{indent}{indent.join(l)}" if l else "" ) missing_str = fmt_list("Missing", missing) @@ -237,7 +237,7 @@ def _configure_keydb_build(versions: Versioner) -> None: # pylint: disable-next=too-many-statements def execute( - args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, / + args: argparse.Namespace, _unparsed_args: list[str] | None = None, / ) -> int: # Unpack various arguments diff --git a/smartsim/_core/_cli/clean.py b/smartsim/_core/_cli/clean.py index 2a60e7b36..eec3549e2 100644 --- a/smartsim/_core/_cli/clean.py +++ b/smartsim/_core/_cli/clean.py @@ -25,7 +25,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import argparse -import typing as t from smartsim._core._cli.utils import clean, get_install_path @@ -41,13 +40,13 @@ def configure_parser(parser: argparse.ArgumentParser) -> None: def execute( - args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, / + args: argparse.Namespace, _unparsed_args: list[str] | None = None, / ) -> int: return clean(get_install_path() / "_core", _all=args.clobber) def execute_all( - args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, / + args: argparse.Namespace, _unparsed_args: list[str] | None = None, / ) -> int: args.clobber = True return execute(args) diff --git a/smartsim/_core/_cli/cli.py b/smartsim/_core/_cli/cli.py index a19037158..ce7a49011 100644 --- a/smartsim/_core/_cli/cli.py +++ b/smartsim/_core/_cli/cli.py @@ -28,7 +28,6 @@ import argparse import os -import typing as t from smartsim._core._cli.build import configure_parser as build_parser from smartsim._core._cli.build import execute as build_execute @@ -47,8 +46,8 @@ class SmartCli: - def __init__(self, menu: t.List[MenuItemConfig]) -> None: - self.menu: t.Dict[str, MenuItemConfig] = {} + def __init__(self, menu: list[MenuItemConfig]) -> None: + self.menu: dict[str, MenuItemConfig] = {} self.parser = argparse.ArgumentParser( prog="smart", description="SmartSim command line interface", @@ -66,7 +65,7 @@ def __init__(self, menu: t.List[MenuItemConfig]) -> None: plugin_items = [plugin() for plugin in plugins] self.register_menu_items(plugin_items) - def execute(self, cli_args: t.List[str]) -> int: + def execute(self, cli_args: list[str]) -> int: if len(cli_args) < 2: self.parser.print_help() return os.EX_USAGE @@ -101,7 +100,7 @@ def _register_menu_item(self, item: MenuItemConfig) -> None: self.menu[item.command] = item - def register_menu_items(self, menu_items: t.List[MenuItemConfig]) -> None: + def register_menu_items(self, menu_items: list[MenuItemConfig]) -> None: for item in menu_items: self._register_menu_item(item) diff --git a/smartsim/_core/_cli/dbcli.py b/smartsim/_core/_cli/dbcli.py index cbf7f59b0..53f980301 100644 --- a/smartsim/_core/_cli/dbcli.py +++ b/smartsim/_core/_cli/dbcli.py @@ -26,13 +26,12 @@ import argparse import os -import typing as t from smartsim._core._cli.utils import get_db_path def execute( - _args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, / + _args: argparse.Namespace, _unparsed_args: list[str] | None = None, / ) -> int: if db_path := get_db_path(): print(db_path) diff --git a/smartsim/_core/_cli/info.py b/smartsim/_core/_cli/info.py index c08fcb1a3..a72c73f64 100644 --- a/smartsim/_core/_cli/info.py +++ b/smartsim/_core/_cli/info.py @@ -2,7 +2,6 @@ import importlib.metadata import os import pathlib -import typing as t from tabulate import tabulate @@ -14,7 +13,7 @@ def execute( - _args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, / + _args: argparse.Namespace, _unparsed_args: list[str] | None = None, / ) -> int: print("\nSmart Python Packages:") print( @@ -72,7 +71,7 @@ def execute( return os.EX_OK -def _fmt_installed_db(db_path: t.Optional[pathlib.Path]) -> str: +def _fmt_installed_db(db_path: pathlib.Path | None) -> str: if db_path is None: return _MISSING_DEP db_name, _ = db_path.name.split("-", 1) diff --git a/smartsim/_core/_cli/plugin.py b/smartsim/_core/_cli/plugin.py index 9540aa2e0..f59db0201 100644 --- a/smartsim/_core/_cli/plugin.py +++ b/smartsim/_core/_cli/plugin.py @@ -3,7 +3,7 @@ import os import subprocess as sp import sys -import typing as t +from collections.abc import Callable import smartsim.log from smartsim._core._cli.utils import SMART_LOGGER_FORMAT, MenuItemConfig @@ -14,10 +14,8 @@ def dynamic_execute( cmd: str, plugin_name: str -) -> t.Callable[[argparse.Namespace, t.List[str]], int]: - def process_execute( - _args: argparse.Namespace, unparsed_args: t.List[str], / - ) -> int: +) -> Callable[[argparse.Namespace, list[str]], int]: + def process_execute(_args: argparse.Namespace, unparsed_args: list[str], /) -> int: try: spec = importlib.util.find_spec(cmd) if spec is None: @@ -39,4 +37,4 @@ def process_execute( # No plugins currently available -plugins: t.Tuple[t.Callable[[], MenuItemConfig], ...] = () +plugins: tuple[Callable[[], MenuItemConfig], ...] = () diff --git a/smartsim/_core/_cli/scripts/dragon_install.py b/smartsim/_core/_cli/scripts/dragon_install.py index cfdc51a9b..45a06f6e5 100644 --- a/smartsim/_core/_cli/scripts/dragon_install.py +++ b/smartsim/_core/_cli/scripts/dragon_install.py @@ -2,6 +2,7 @@ import pathlib import sys import typing as t +from collections.abc import Collection from github import Github from github.GitReleaseAsset import GitReleaseAsset @@ -83,7 +84,7 @@ def _pin_filter(asset_name: str) -> bool: return f"dragon-{dragon_pin()}" in asset_name -def _get_release_assets() -> t.Collection[GitReleaseAsset]: +def _get_release_assets() -> Collection[GitReleaseAsset]: """Retrieve a collection of available assets for all releases that satisfy the dragon version pin @@ -107,7 +108,7 @@ def _get_release_assets() -> t.Collection[GitReleaseAsset]: return assets -def filter_assets(assets: t.Collection[GitReleaseAsset]) -> t.Optional[GitReleaseAsset]: +def filter_assets(assets: Collection[GitReleaseAsset]) -> GitReleaseAsset | None: """Filter the available release assets so that HSTA agents are used when run on a Cray EX platform @@ -191,7 +192,7 @@ def install_package(asset_dir: pathlib.Path) -> int: def cleanup( - archive_path: t.Optional[pathlib.Path] = None, + archive_path: pathlib.Path | None = None, ) -> None: """Delete the downloaded asset and any files extracted during installation @@ -201,7 +202,7 @@ def cleanup( logger.debug(f"Deleted archive: {archive_path}") -def install_dragon(extraction_dir: t.Union[str, os.PathLike[str]]) -> int: +def install_dragon(extraction_dir: str | os.PathLike[str]) -> int: """Retrieve a dragon runtime appropriate for the current platform and install to the current python environment :param extraction_dir: path for download and extraction of assets @@ -211,8 +212,8 @@ def install_dragon(extraction_dir: t.Union[str, os.PathLike[str]]) -> int: return 1 extraction_dir = pathlib.Path(extraction_dir) - filename: t.Optional[pathlib.Path] = None - asset_dir: t.Optional[pathlib.Path] = None + filename: pathlib.Path | None = None + asset_dir: pathlib.Path | None = None try: asset_info = retrieve_asset_info() diff --git a/smartsim/_core/_cli/site.py b/smartsim/_core/_cli/site.py index 076fc0de7..e2c8e2813 100644 --- a/smartsim/_core/_cli/site.py +++ b/smartsim/_core/_cli/site.py @@ -26,11 +26,10 @@ import argparse import os -import typing as t from smartsim._core._cli.utils import get_install_path -def execute(_args: argparse.Namespace, _unparsed_args: t.List[str], /) -> int: +def execute(_args: argparse.Namespace, _unparsed_args: list[str], /) -> int: print(get_install_path()) return os.EX_OK diff --git a/smartsim/_core/_cli/teardown.py b/smartsim/_core/_cli/teardown.py index 8e900b0e6..9d4d32572 100644 --- a/smartsim/_core/_cli/teardown.py +++ b/smartsim/_core/_cli/teardown.py @@ -27,7 +27,6 @@ import argparse import os import subprocess -import typing as t from smartsim._core.config import CONFIG @@ -66,7 +65,7 @@ def _do_dragon_teardown() -> int: def execute( - args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, / + args: argparse.Namespace, _unparsed_args: list[str] | None = None, / ) -> int: if args.dragon: return _do_dragon_teardown() diff --git a/smartsim/_core/_cli/utils.py b/smartsim/_core/_cli/utils.py index 1e55c9017..44a668b6e 100644 --- a/smartsim/_core/_cli/utils.py +++ b/smartsim/_core/_cli/utils.py @@ -29,8 +29,8 @@ import shutil import subprocess as sp import sys -import typing as t from argparse import ArgumentParser, Namespace +from collections.abc import Callable from pathlib import Path from smartsim._core._install.buildenv import SetupError @@ -118,7 +118,7 @@ def clean(core_path: Path, _all: bool = False) -> int: return os.EX_OK -def get_db_path() -> t.Optional[Path]: +def get_db_path() -> Path | None: bin_path = get_install_path() / "_core" / "bin" for option in bin_path.iterdir(): if option.name in ("redis-cli", "keydb-cli"): @@ -126,8 +126,8 @@ def get_db_path() -> t.Optional[Path]: return None -_CliHandler = t.Callable[[Namespace, t.List[str]], int] -_CliParseConfigurator = t.Callable[[ArgumentParser], None] +_CliHandler = Callable[[Namespace, list[str]], int] +_CliParseConfigurator = Callable[[ArgumentParser], None] class MenuItemConfig: @@ -136,7 +136,7 @@ def __init__( cmd: str, description: str, handler: _CliHandler, - configurator: t.Optional[_CliParseConfigurator] = None, + configurator: _CliParseConfigurator | None = None, is_plugin: bool = False, ): self.command = cmd diff --git a/smartsim/_core/_cli/validate.py b/smartsim/_core/_cli/validate.py index da382f93f..bf1c48eed 100644 --- a/smartsim/_core/_cli/validate.py +++ b/smartsim/_core/_cli/validate.py @@ -31,6 +31,7 @@ import os.path import tempfile import typing as t +from collections.abc import Callable, Mapping from types import TracebackType import numpy as np @@ -68,9 +69,9 @@ class _VerificationTempDir(_TemporaryDirectory): def __exit__( self, - exc: t.Optional[t.Type[BaseException]], - value: t.Optional[BaseException], - tb: t.Optional[TracebackType], + exc: type[BaseException] | None, + value: BaseException | None, + tb: TracebackType | None, ) -> None: if not value: # Yay, no error! Clean up as normal super().__exit__(exc, value, tb) @@ -79,7 +80,7 @@ def __exit__( def execute( - args: argparse.Namespace, _unparsed_args: t.Optional[t.List[str]] = None, / + args: argparse.Namespace, _unparsed_args: list[str] | None = None, / ) -> int: """Validate the SmartSim installation works as expected given a simple experiment @@ -143,7 +144,7 @@ def configure_parser(parser: argparse.ArgumentParser) -> None: def test_install( location: str, - port: t.Optional[int], + port: int | None, device: Device, with_tf: bool, with_pt: bool, @@ -169,9 +170,7 @@ def test_install( @contextlib.contextmanager -def _env_vars_set_to( - evars: t.Mapping[str, t.Optional[str]] -) -> t.Generator[None, None, None]: +def _env_vars_set_to(evars: Mapping[str, str | None]) -> t.Generator[None, None, None]: envvars = tuple((var, os.environ.pop(var, None), val) for var, val in evars.items()) for var, _, tmpval in envvars: _set_or_del_env_var(var, tmpval) @@ -182,7 +181,7 @@ def _env_vars_set_to( _set_or_del_env_var(var, origval) -def _set_or_del_env_var(var: str, val: t.Optional[str]) -> None: +def _set_or_del_env_var(var: str, val: str | None) -> None: if val is not None: os.environ[var] = val else: @@ -221,7 +220,7 @@ def _test_tf_install(client: Client, tmp_dir: str, device: Device) -> None: client.get_tensor("keras-output") -def _build_tf_frozen_model(tmp_dir: str) -> t.Tuple[str, t.List[str], t.List[str]]: +def _build_tf_frozen_model(tmp_dir: str) -> tuple[str, list[str], list[str]]: from tensorflow import keras # pylint: disable=no-name-in-module @@ -250,7 +249,7 @@ def _test_torch_install(client: Client, device: Device) -> None: class Net(nn.Module): def __init__(self) -> None: super().__init__() - self.conv: t.Callable[..., torch.Tensor] = nn.Conv2d(1, 1, 3) + self.conv: Callable[..., torch.Tensor] = nn.Conv2d(1, 1, 3) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.conv(x) diff --git a/smartsim/_core/_install/buildenv.py b/smartsim/_core/_install/buildenv.py index 463b9c413..f453187e7 100644 --- a/smartsim/_core/_install/buildenv.py +++ b/smartsim/_core/_install/buildenv.py @@ -64,7 +64,7 @@ class Version_(str): @staticmethod def _convert_to_version( - vers: t.Union[str, Iterable[Version], Version], + vers: str | Iterable[Version] | Version, ) -> t.Any: if isinstance(vers, Version): return vers @@ -172,7 +172,7 @@ class Versioner: ) REDISAI_BRANCH = get_env("SMARTSIM_REDISAI_BRANCH", f"v{REDISAI}") - def as_dict(self, db_name: DbEngine = "REDIS") -> t.Dict[str, t.Tuple[str, ...]]: + def as_dict(self, db_name: DbEngine = "REDIS") -> dict[str, tuple[str, ...]]: pkg_map = { "SMARTSIM": self.SMARTSIM, db_name: self.REDIS, @@ -259,7 +259,7 @@ def check_dependencies(self) -> None: for dep in deps: self.check_build_dependency(dep) - def __call__(self) -> t.Dict[str, str]: + def __call__(self) -> dict[str, str]: # return the build env for the build process env = os.environ.copy() env.update( @@ -272,8 +272,8 @@ def __call__(self) -> t.Dict[str, str]: ) return env - def as_dict(self) -> t.Dict[str, t.List[str]]: - variables: t.List[str] = [ + def as_dict(self) -> dict[str, list[str]]: + variables: list[str] = [ "CC", "CXX", "CFLAGS", @@ -283,7 +283,7 @@ def as_dict(self) -> t.Dict[str, t.List[str]]: "PYTHON_VERSION", "PLATFORM", ] - values: t.List[str] = [ + values: list[str] = [ self.CC, self.CXX, self.CFLAGS, @@ -316,7 +316,7 @@ def is_macos(cls) -> bool: return cls.PLATFORM == "darwin" @staticmethod - def get_cudnn_env() -> t.Optional[t.Dict[str, str]]: + def get_cudnn_env() -> dict[str, str] | None: """Collect the environment variables needed for Caffe (Pytorch) and throw an error if they are not found diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py index 2bb5a9902..bae2db896 100644 --- a/smartsim/_core/_install/builder.py +++ b/smartsim/_core/_install/builder.py @@ -38,12 +38,9 @@ from smartsim._core._install.utils import retrieve from smartsim._core.utils import expand_exe_path -if t.TYPE_CHECKING: - from typing_extensions import Never - # TODO: check cmake version and use system if possible to avoid conflicts -_PathLike = t.Union[str, "os.PathLike[str]"] +_PathLike = str | os.PathLike[str] _T = t.TypeVar("_T") _U = t.TypeVar("_U") @@ -67,7 +64,7 @@ class Builder: def __init__( self, - env: t.Dict[str, str], + env: dict[str, str], jobs: int = 1, verbose: bool = False, ) -> None: @@ -99,7 +96,7 @@ def __init__( self.jobs = jobs @property - def out(self) -> t.Optional[int]: + def out(self) -> int | None: return None if self.verbose else subprocess.DEVNULL # implemented in base classes @@ -115,16 +112,12 @@ def binary_path(binary: str) -> str: raise BuildError(f"{binary} not found in PATH") @staticmethod - def copy_file( - src: t.Union[str, Path], dst: t.Union[str, Path], set_exe: bool = False - ) -> None: + def copy_file(src: str | Path, dst: str | Path, set_exe: bool = False) -> None: shutil.copyfile(src, dst) if set_exe: Path(dst).chmod(stat.S_IXUSR | stat.S_IWUSR | stat.S_IRUSR) - def copy_dir( - self, src: t.Union[str, Path], dst: t.Union[str, Path], set_exe: bool = False - ) -> None: + def copy_dir(self, src: str | Path, dst: str | Path, set_exe: bool = False) -> None: src = Path(src) dst = Path(dst) dst.mkdir(exist_ok=True) @@ -144,10 +137,10 @@ def cleanup(self) -> None: def run_command( self, - cmd: t.List[str], + cmd: list[str], shell: bool = False, - out: t.Optional[int] = None, - cwd: t.Union[str, Path, None] = None, + out: int | None = None, + cwd: str | Path | None = None, ) -> None: # option to manually disable output if necessary if not out: @@ -179,7 +172,7 @@ class DatabaseBuilder(Builder): def __init__( self, - build_env: t.Optional[t.Dict[str, str]] = None, + build_env: dict[str, str] | None = None, malloc: str = "libc", jobs: int = 1, verbose: bool = False, diff --git a/smartsim/_core/_install/mlpackages.py b/smartsim/_core/_install/mlpackages.py index b5bae5845..baf978d36 100644 --- a/smartsim/_core/_install/mlpackages.py +++ b/smartsim/_core/_install/mlpackages.py @@ -31,7 +31,7 @@ import subprocess import sys import typing as t -from collections.abc import MutableMapping +from collections.abc import MutableMapping, Sequence from dataclasses import dataclass from tabulate import tabulate @@ -73,9 +73,9 @@ class MLPackage: name: str version: str pip_index: str - python_packages: t.List[str] + python_packages: list[str] lib_source: PathLike - rai_patches: t.Tuple[RAIPatch, ...] = () + rai_patches: tuple[RAIPatch, ...] = () def retrieve(self, destination: PathLike) -> None: """Retrieve an archive and/or repository for the package @@ -105,7 +105,7 @@ class MLPackageCollection(MutableMapping[str, MLPackage]): Define a collection of MLPackages available for a specific platform """ - def __init__(self, platform: Platform, ml_packages: t.Sequence[MLPackage]): + def __init__(self, platform: Platform, ml_packages: Sequence[MLPackage]): self.platform = platform self._ml_packages = {pkg.name: pkg for pkg in ml_packages} @@ -173,7 +173,7 @@ def __str__(self, tablefmt: str = "github") -> str: def load_platform_configs( config_file_path: pathlib.Path, -) -> t.Dict[Platform, MLPackageCollection]: +) -> dict[Platform, MLPackageCollection]: """Create MLPackageCollections from JSON files in directory :param config_file_path: Directory with JSON files describing the diff --git a/smartsim/_core/_install/platform.py b/smartsim/_core/_install/platform.py index 60d704101..0b5fe6142 100644 --- a/smartsim/_core/_install/platform.py +++ b/smartsim/_core/_install/platform.py @@ -29,7 +29,6 @@ import os import pathlib import platform -import typing as t from dataclasses import dataclass from typing_extensions import Self @@ -98,7 +97,7 @@ def from_str(cls, str_: str) -> "Device": return cls(str_) @classmethod - def detect_cuda_version(cls) -> t.Optional["Device"]: + def detect_cuda_version(cls) -> "Device | None": """Find the enum based on environment CUDA :return: Enum for the version of CUDA currently available @@ -112,7 +111,7 @@ def detect_cuda_version(cls) -> t.Optional["Device"]: return None @classmethod - def detect_rocm_version(cls) -> t.Optional["Device"]: + def detect_rocm_version(cls) -> "Device | None": """Find the enum based on environment ROCm :return: Enum for the version of ROCm currently available @@ -149,7 +148,7 @@ def is_rocm(self) -> bool: return self in cls.rocm_enums() @classmethod - def cuda_enums(cls) -> t.Tuple["Device", ...]: + def cuda_enums(cls) -> tuple["Device", ...]: """Detect all CUDA devices supported by SmartSim :return: all enums associated with CUDA @@ -157,7 +156,7 @@ def cuda_enums(cls) -> t.Tuple["Device", ...]: return tuple(device for device in cls if "cuda" in device.value) @classmethod - def rocm_enums(cls) -> t.Tuple["Device", ...]: + def rocm_enums(cls) -> tuple["Device", ...]: """Detect all ROCm devices supported by SmartSim :return: all enums associated with ROCm diff --git a/smartsim/_core/_install/redisaiBuilder.py b/smartsim/_core/_install/redisaiBuilder.py index dc8872e03..253d00eeb 100644 --- a/smartsim/_core/_install/redisaiBuilder.py +++ b/smartsim/_core/_install/redisaiBuilder.py @@ -59,9 +59,9 @@ def __init__( build_env: BuildEnv, main_build_path: pathlib.Path, verbose: bool = False, - source: t.Union[ - str, pathlib.Path - ] = "https://github.com/RedisAI/redis-inference-optimization.git", + source: ( + str | pathlib.Path + ) = "https://github.com/RedisAI/redis-inference-optimization.git", version: str = "v1.2.7", ) -> None: @@ -196,7 +196,7 @@ def _set_execute(target: pathlib.Path) -> None: @staticmethod def _find_closest_object( start_path: pathlib.Path, target_obj: str - ) -> t.Optional[pathlib.Path]: + ) -> pathlib.Path | None: queue = deque([start_path]) while queue: current_dir = queue.popleft() @@ -234,7 +234,7 @@ def _prepare_packages(self) -> None: for file in actual_root.iterdir(): file.rename(target_dir / file.name) - def run_command(self, cmd: t.Union[str, t.List[str]], cwd: pathlib.Path) -> None: + def run_command(self, cmd: str | list[str], cwd: pathlib.Path) -> None: """Executor of commands usedi in the build :param cmd: The actual command to execute @@ -252,7 +252,7 @@ def run_command(self, cmd: t.Union[str, t.List[str]], cwd: pathlib.Path) -> None f"RedisAI build failed during command: {' '.join(cmd)}" ) - def _rai_cmake_cmd(self) -> t.List[str]: + def _rai_cmake_cmd(self) -> list[str]: """Build the CMake configuration command :return: CMake command with correct options @@ -281,7 +281,7 @@ def on_off(expression: bool) -> t.Literal["ON", "OFF"]: return cmd @property - def _rai_build_cmd(self) -> t.List[str]: + def _rai_build_cmd(self) -> list[str]: """Shell command to build RedisAI and modules With the CMake based install, very little needs to be done here. @@ -293,7 +293,7 @@ def _rai_build_cmd(self) -> t.List[str]: """ return "make install -j VERBOSE=1".split(" ") - def _patch_source_files(self, patches: t.Tuple[RAIPatch, ...]) -> None: + def _patch_source_files(self, patches: tuple[RAIPatch, ...]) -> None: """Apply specified RedisAI patches""" for patch in patches: with fileinput.input( diff --git a/smartsim/_core/_install/types.py b/smartsim/_core/_install/types.py index 9f57b928b..c3b2e6c83 100644 --- a/smartsim/_core/_install/types.py +++ b/smartsim/_core/_install/types.py @@ -25,6 +25,5 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import pathlib -import typing as t -PathLike = t.Union[str, pathlib.Path] +PathLike = str | pathlib.Path diff --git a/smartsim/_core/_install/utils/retrieve.py b/smartsim/_core/_install/utils/retrieve.py index bc1da7d3e..b5f019576 100644 --- a/smartsim/_core/_install/utils/retrieve.py +++ b/smartsim/_core/_install/utils/retrieve.py @@ -51,8 +51,8 @@ class _TqdmUpTo(tqdm): # type: ignore[type-arg] """ def update_to( - self, num_blocks: int = 1, bsize: int = 1, tsize: t.Optional[int] = None - ) -> t.Optional[bool]: + self, num_blocks: int = 1, bsize: int = 1, tsize: int | None = None + ) -> bool | None: """Update progress in tqdm-like way :param b: number of blocks transferred so far, defaults to 1 diff --git a/smartsim/_core/config/config.py b/smartsim/_core/config/config.py index ab063eea6..ee416f7de 100644 --- a/smartsim/_core/config/config.py +++ b/smartsim/_core/config/config.py @@ -27,6 +27,7 @@ import json import os import typing as t +from collections.abc import Sequence from functools import lru_cache from pathlib import Path @@ -175,7 +176,7 @@ def dragon_dotenv(self) -> Path: return Path(self.conf_dir / "dragon" / ".env") @property - def dragon_server_path(self) -> t.Optional[str]: + def dragon_server_path(self) -> str | None: return os.getenv( "SMARTSIM_DRAGON_SERVER_PATH", os.getenv("SMARTSIM_DRAGON_SERVER_PATH_EXP", None), @@ -218,7 +219,7 @@ def test_num_gpus(self) -> int: # pragma: no cover return int(os.environ.get("SMARTSIM_TEST_NUM_GPUS") or 1) @property - def test_ports(self) -> t.Sequence[int]: # pragma: no cover + def test_ports(self) -> Sequence[int]: # pragma: no cover min_required_ports = 25 first_port = int(os.environ.get("SMARTSIM_TEST_PORT", 6780)) num_ports = max( @@ -228,7 +229,7 @@ def test_ports(self) -> t.Sequence[int]: # pragma: no cover return range(first_port, first_port + num_ports) @property - def test_batch_resources(self) -> t.Dict[t.Any, t.Any]: # pragma: no cover + def test_batch_resources(self) -> dict[t.Any, t.Any]: # pragma: no cover resource_str = os.environ.get("SMARTSIM_TEST_BATCH_RESOURCES", "{}") resources = json.loads(resource_str) if not isinstance(resources, dict): @@ -242,7 +243,7 @@ def test_batch_resources(self) -> t.Dict[t.Any, t.Any]: # pragma: no cover return resources @property - def test_interface(self) -> t.List[str]: # pragma: no cover + def test_interface(self) -> list[str]: # pragma: no cover if interfaces_cfg := os.environ.get("SMARTSIM_TEST_INTERFACE", None): return interfaces_cfg.split(",") @@ -262,7 +263,7 @@ def test_interface(self) -> t.List[str]: # pragma: no cover return ["lo"] @property - def test_account(self) -> t.Optional[str]: # pragma: no cover + def test_account(self) -> str | None: # pragma: no cover # no account by default return os.environ.get("SMARTSIM_TEST_ACCOUNT", None) diff --git a/smartsim/_core/control/controller.py b/smartsim/_core/control/controller.py index 8e0ef7c38..88631cf08 100644 --- a/smartsim/_core/control/controller.py +++ b/smartsim/_core/control/controller.py @@ -34,7 +34,6 @@ import signal import threading import time -import typing as t from smartredis import Client, ConfigOptions @@ -135,7 +134,7 @@ def start( self.poll(5, True, kill_on_interrupt=kill_on_interrupt) @property - def active_orchestrator_jobs(self) -> t.Dict[str, Job]: + def active_orchestrator_jobs(self) -> dict[str, Job]: """Return active orchestrator jobs.""" return {**self._jobs.db_jobs} @@ -167,9 +166,7 @@ def poll( for job in to_monitor.values(): logger.info(job) - def finished( - self, entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] - ) -> bool: + def finished(self, entity: SmartSimEntity | EntitySequence[SmartSimEntity]) -> bool: """Return a boolean indicating wether a job has finished or not :param entity: object launched by SmartSim. @@ -194,7 +191,7 @@ def finished( ) from None def stop_entity( - self, entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] + self, entity: SmartSimEntity | EntitySequence[SmartSimEntity] ) -> None: """Stop an instance of an entity @@ -265,7 +262,7 @@ def stop_entity_list(self, entity_list: EntitySequence[SmartSimEntity]) -> None: for entity in entity_list.entities: self.stop_entity(entity) - def get_jobs(self) -> t.Dict[str, Job]: + def get_jobs(self) -> dict[str, Job]: """Return a dictionary of completed job data :returns: dict[str, Job] @@ -274,7 +271,7 @@ def get_jobs(self) -> t.Dict[str, Job]: return self._jobs.completed def get_entity_status( - self, entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] + self, entity: SmartSimEntity | EntitySequence[SmartSimEntity] ) -> SmartSimStatus: """Get the status of an entity @@ -291,7 +288,7 @@ def get_entity_status( def get_entity_list_status( self, entity_list: EntitySequence[SmartSimEntity] - ) -> t.List[SmartSimStatus]: + ) -> list[SmartSimStatus]: """Get the statuses of an entity list :param entity_list: entity list containing entities to @@ -320,7 +317,7 @@ def init_launcher(self, launcher: str) -> None: a supported launcher :raises TypeError: if no launcher argument is provided. """ - launcher_map: t.Dict[str, t.Type[Launcher]] = { + launcher_map: dict[str, type[Launcher]] = { "slurm": SlurmLauncher, "pbs": PBSLauncher, "pals": PBSLauncher, @@ -342,7 +339,7 @@ def init_launcher(self, launcher: str) -> None: @staticmethod def symlink_output_files( - job_step: Step, entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] + job_step: Step, entity: SmartSimEntity | EntitySequence[SmartSimEntity] ) -> None: """Create symlinks for entity output files that point to the output files under the .smartsim directory @@ -414,12 +411,10 @@ def _launch(self, _exp_name: str, exp_path: str, manifest: Manifest) -> None: self._set_dbobjects(manifest) # create all steps prior to launch - steps: t.List[ - t.Tuple[Step, t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]]] - ] = [] + steps: list[tuple[Step, SmartSimEntity | EntitySequence[SmartSimEntity]]] = [] - symlink_substeps: t.List[ - t.Tuple[Step, t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]]] + symlink_substeps: list[ + tuple[Step, SmartSimEntity | EntitySequence[SmartSimEntity]] ] = [] for elist in manifest.ensembles: @@ -483,6 +478,12 @@ def _launch_orchestrator( # if the orchestrator was launched as a batch workload if orchestrator.batch: + metadata_dir = ( + pathlib.Path(orchestrator.path) + / CONFIG.metadata_subdir + / "database" + / orchestrator.name + ) orc_batch_step, substeps = self._create_batch_job_step( orchestrator, metadata_dir ) @@ -495,6 +496,12 @@ def _launch_orchestrator( # if orchestrator was run on existing allocation, locally, or in allocation else: + metadata_dir = ( + pathlib.Path(orchestrator.path) + / CONFIG.metadata_subdir + / "database" + / orchestrator.name + ) db_steps = [ (self._create_job_step(db, metadata_dir), db) for db in orchestrator.entities @@ -537,7 +544,7 @@ def _launch_orchestrator( def _launch_step( self, job_step: Step, - entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]], + entity: SmartSimEntity | EntitySequence[SmartSimEntity], ) -> None: """Use the launcher to launch a job step @@ -594,9 +601,9 @@ def _launch_step( def _create_batch_job_step( self, - entity_list: t.Union[Orchestrator, Ensemble, _AnonymousBatchJob], + entity_list: Orchestrator | Ensemble | _AnonymousBatchJob, metadata_dir: pathlib.Path, - ) -> t.Tuple[Step, t.List[Step]]: + ) -> tuple[Step, list[Step]]: """Use launcher to create batch job step :param entity_list: EntityList to launch as batch @@ -655,7 +662,7 @@ def _prep_entity_client_env(self, entity: Model) -> None: :param entity: The entity to retrieve connections from """ - client_env: t.Dict[str, t.Union[str, int, float, bool]] = {} + client_env: dict[str, str | int | float | bool] = {} address_dict = self._jobs.get_db_host_addresses() for db_id, addresses in address_dict.items(): @@ -787,9 +794,7 @@ def _orchestrator_launch_wait(self, orchestrator: Orchestrator) -> None: # launch explicitly raise - def reload_saved_db( - self, checkpoint_file: t.Union[str, os.PathLike[str]] - ) -> Orchestrator: + def reload_saved_db(self, checkpoint_file: str | os.PathLike[str]) -> Orchestrator: with JM_LOCK: if not osp.exists(checkpoint_file): diff --git a/smartsim/_core/control/job.py b/smartsim/_core/control/job.py index f095b61ec..c96960cfc 100644 --- a/smartsim/_core/control/job.py +++ b/smartsim/_core/control/job.py @@ -25,7 +25,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import time -import typing as t from ...entity import EntitySequence, SmartSimEntity from ...status import SmartSimStatus @@ -41,8 +40,8 @@ class Job: def __init__( self, job_name: str, - job_id: t.Optional[str], - entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]], + job_id: str | None, + entity: SmartSimEntity | EntitySequence[SmartSimEntity], launcher: str, is_task: bool, ) -> None: @@ -59,12 +58,12 @@ def __init__( self.entity = entity self.status = SmartSimStatus.STATUS_NEW # status before smartsim status mapping is applied - self.raw_status: t.Optional[str] = None - self.returncode: t.Optional[int] = None + self.raw_status: str | None = None + self.returncode: int | None = None # output is only populated if it's system related (e.g. cmd failed immediately) - self.output: t.Optional[str] = None - self.error: t.Optional[str] = None # same as output - self.hosts: t.List[str] = [] # currently only used for DB jobs + self.output: str | None = None + self.error: str | None = None # same as output + self.hosts: list[str] = [] # currently only used for DB jobs self.launched_with = launcher self.is_task = is_task self.start_time = time.time() @@ -79,9 +78,9 @@ def set_status( self, new_status: SmartSimStatus, raw_status: str, - returncode: t.Optional[int], - error: t.Optional[str] = None, - output: t.Optional[str] = None, + returncode: int | None, + error: str | None = None, + output: str | None = None, ) -> None: """Set the status of a job. @@ -105,9 +104,7 @@ def record_history(self) -> None: """Record the launching history of a job.""" self.history.record(self.jid, self.status, self.returncode, self.elapsed) - def reset( - self, new_job_name: str, new_job_id: t.Optional[str], is_task: bool - ) -> None: + def reset(self, new_job_name: str, new_job_id: str | None, is_task: bool) -> None: """Reset the job in order to be able to restart it. :param new_job_name: name of the new job step @@ -168,16 +165,16 @@ def __init__(self, runs: int = 0) -> None: :param runs: number of runs so far """ self.runs = runs - self.jids: t.Dict[int, t.Optional[str]] = {} - self.statuses: t.Dict[int, SmartSimStatus] = {} - self.returns: t.Dict[int, t.Optional[int]] = {} - self.job_times: t.Dict[int, float] = {} + self.jids: dict[int, str | None] = {} + self.statuses: dict[int, SmartSimStatus] = {} + self.returns: dict[int, int | None] = {} + self.job_times: dict[int, float] = {} def record( self, - job_id: t.Optional[str], + job_id: str | None, status: SmartSimStatus, - returncode: t.Optional[int], + returncode: int | None, job_time: float, ) -> None: """record the history of a job""" diff --git a/smartsim/_core/control/jobmanager.py b/smartsim/_core/control/jobmanager.py index 8bf0804c3..d253c02c8 100644 --- a/smartsim/_core/control/jobmanager.py +++ b/smartsim/_core/control/jobmanager.py @@ -27,7 +27,6 @@ import itertools import time -import typing as t from collections import ChainMap from threading import RLock, Thread from types import FrameType @@ -57,19 +56,19 @@ class JobManager: wlm to query information about jobs that the user requests. """ - def __init__(self, lock: RLock, launcher: t.Optional[Launcher] = None) -> None: + def __init__(self, lock: RLock, launcher: Launcher | None = None) -> None: """Initialize a Jobmanager :param launcher: a Launcher object to manage jobs """ - self.monitor: t.Optional[Thread] = None + self.monitor: Thread | None = None # active jobs - self.jobs: t.Dict[str, Job] = {} - self.db_jobs: t.Dict[str, Job] = {} + self.jobs: dict[str, Job] = {} + self.db_jobs: dict[str, Job] = {} # completed jobs - self.completed: t.Dict[str, Job] = {} + self.completed: dict[str, Job] = {} self.actively_monitoring = False # on/off flag self._launcher = launcher # reference to launcher @@ -145,7 +144,7 @@ def __getitem__(self, entity_name: str) -> Job: entities = ChainMap(self.db_jobs, self.jobs, self.completed) return entities[entity_name] - def __call__(self) -> t.Dict[str, Job]: + def __call__(self) -> dict[str, Job]: """Returns dictionary all jobs for () operator :returns: Dictionary of all jobs @@ -163,8 +162,8 @@ def __contains__(self, key: str) -> bool: def add_job( self, job_name: str, - job_id: t.Optional[str], - entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]], + job_id: str | None, + entity: SmartSimEntity | EntitySequence[SmartSimEntity], is_task: bool = True, ) -> None: """Add a job to the job manager which holds specific jobs by type. @@ -225,7 +224,7 @@ def check_jobs(self) -> None: def get_status( self, - entity: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]], + entity: SmartSimEntity | EntitySequence[SmartSimEntity], ) -> SmartSimStatus: """Return the status of a job. @@ -262,7 +261,7 @@ def query_restart(self, entity_name: str) -> bool: def restart_job( self, job_name: str, - job_id: t.Optional[str], + job_id: str | None, entity_name: str, is_task: bool = True, ) -> None: @@ -285,14 +284,14 @@ def restart_job( else: self.jobs[entity_name] = job - def get_db_host_addresses(self) -> t.Dict[str, t.List[str]]: + def get_db_host_addresses(self) -> dict[str, list[str]]: """Retrieve the list of hosts for the database for corresponding database identifiers :return: dictionary of host ip addresses """ - address_dict: t.Dict[str, t.List[str]] = {} + address_dict: dict[str, list[str]] = {} for db_job in self.db_jobs.values(): addresses = [] if isinstance(db_job.entity, (DBNode, Orchestrator)): @@ -301,7 +300,7 @@ def get_db_host_addresses(self) -> t.Dict[str, t.List[str]]: ip_addr = get_ip_from_host(combine[0]) addresses.append(":".join((ip_addr, str(combine[1])))) - dict_entry: t.List[str] = address_dict.get(db_entity.db_identifier, []) + dict_entry: list[str] = address_dict.get(db_entity.db_identifier, []) dict_entry.extend(addresses) address_dict[db_entity.db_identifier] = dict_entry @@ -325,7 +324,7 @@ def set_db_hosts(self, orchestrator: Orchestrator) -> None: else: self.db_jobs[dbnode.name].hosts = dbnode.hosts - def signal_interrupt(self, signo: int, _frame: t.Optional[FrameType]) -> None: + def signal_interrupt(self, signo: int, _frame: FrameType | None) -> None: """Custom handler for whenever SIGINT is received""" if not signo: logger.warning("Received SIGINT with no signal number") diff --git a/smartsim/_core/control/manifest.py b/smartsim/_core/control/manifest.py index 0ba0e6f79..5154f7620 100644 --- a/smartsim/_core/control/manifest.py +++ b/smartsim/_core/control/manifest.py @@ -26,6 +26,7 @@ import itertools import typing as t +from collections.abc import Iterable from ...database import Orchestrator from ...entity import Ensemble, EntitySequence, Model, SmartSimEntity @@ -43,16 +44,14 @@ class Manifest: can all be passed as arguments """ - def __init__( - self, *args: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] - ) -> None: + def __init__(self, *args: SmartSimEntity | EntitySequence[SmartSimEntity]) -> None: self._deployables = list(args) self._check_types(self._deployables) self._check_names(self._deployables) self._check_entity_lists_nonempty() @property - def dbs(self) -> t.List[Orchestrator]: + def dbs(self) -> list[Orchestrator]: """Return a list of Orchestrator instances in Manifest :raises SmartSimError: if user added to databases to manifest @@ -62,18 +61,18 @@ def dbs(self) -> t.List[Orchestrator]: return dbs @property - def models(self) -> t.List[Model]: + def models(self) -> list[Model]: """Return Model instances in Manifest :return: model instances """ - _models: t.List[Model] = [ + _models: list[Model] = [ item for item in self._deployables if isinstance(item, Model) ] return _models @property - def ensembles(self) -> t.List[Ensemble]: + def ensembles(self) -> list[Ensemble]: """Return Ensemble instances in Manifest :return: list of ensembles @@ -81,13 +80,13 @@ def ensembles(self) -> t.List[Ensemble]: return [e for e in self._deployables if isinstance(e, Ensemble)] @property - def all_entity_lists(self) -> t.List[EntitySequence[SmartSimEntity]]: + def all_entity_lists(self) -> list[EntitySequence[SmartSimEntity]]: """All entity lists, including ensembles and exceptional ones like Orchestrator :return: list of entity lists """ - _all_entity_lists: t.List[EntitySequence[SmartSimEntity]] = list(self.ensembles) + _all_entity_lists: list[EntitySequence[SmartSimEntity]] = list(self.ensembles) for db in self.dbs: _all_entity_lists.append(db) @@ -103,7 +102,7 @@ def has_deployable(self) -> bool: return bool(self._deployables) @staticmethod - def _check_names(deployables: t.List[t.Any]) -> None: + def _check_names(deployables: list[t.Any]) -> None: used = [] for deployable in deployables: name = getattr(deployable, "name", None) @@ -114,7 +113,7 @@ def _check_names(deployables: t.List[t.Any]) -> None: used.append(name) @staticmethod - def _check_types(deployables: t.List[t.Any]) -> None: + def _check_types(deployables: list[t.Any]) -> None: for deployable in deployables: if not isinstance(deployable, (SmartSimEntity, EntitySequence)): raise TypeError( @@ -172,7 +171,7 @@ def __str__(self) -> str: @property def has_db_objects(self) -> bool: """Check if any entity has DBObjects to set""" - ents: t.Iterable[t.Union[Model, Ensemble]] = itertools.chain( + ents: Iterable[Model | Ensemble] = itertools.chain( self.models, self.ensembles, (member for ens in self.ensembles for member in ens.entities), diff --git a/smartsim/_core/control/previewrenderer.py b/smartsim/_core/control/previewrenderer.py index dfda4285a..d871a3aeb 100644 --- a/smartsim/_core/control/previewrenderer.py +++ b/smartsim/_core/control/previewrenderer.py @@ -64,7 +64,7 @@ def as_toggle(_eval_ctx: u.F, value: bool) -> str: @pass_eval_context -def get_ifname(_eval_ctx: u.F, value: t.List[str]) -> str: +def get_ifname(_eval_ctx: u.F, value: list[str]) -> str: """Extract Network Interface from orchestrator run settings.""" if value: for val in value: @@ -108,11 +108,11 @@ def render_to_file(content: str, filename: str) -> None: def render( exp: "Experiment", - manifest: t.Optional[Manifest] = None, + manifest: Manifest | None = None, verbosity_level: Verbosity = Verbosity.INFO, output_format: Format = Format.PLAINTEXT, - output_filename: t.Optional[str] = None, - active_dbjobs: t.Optional[t.Dict[str, Job]] = None, + output_filename: str | None = None, + active_dbjobs: dict[str, Job] | None = None, ) -> str: """ Render the template from the supplied entities. diff --git a/smartsim/_core/entrypoints/colocated.py b/smartsim/_core/entrypoints/colocated.py index 6615c9c76..539bc298e 100644 --- a/smartsim/_core/entrypoints/colocated.py +++ b/smartsim/_core/entrypoints/colocated.py @@ -30,7 +30,6 @@ import socket import sys import tempfile -import typing as t from pathlib import Path from subprocess import STDOUT from types import FrameType @@ -52,13 +51,13 @@ SIGNALS = [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT, signal.SIGABRT] -def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None: +def handle_signal(signo: int, _frame: FrameType | None) -> None: if not signo: logger.warning("Received signal with no signo") cleanup() -def launch_db_model(client: Client, db_model: t.List[str]) -> str: +def launch_db_model(client: Client, db_model: list[str]) -> str: """Parse options to launch model on local cluster :param client: SmartRedis client connected to local DB @@ -122,7 +121,7 @@ def launch_db_model(client: Client, db_model: t.List[str]) -> str: return name -def launch_db_script(client: Client, db_script: t.List[str]) -> str: +def launch_db_script(client: Client, db_script: list[str]) -> str: """Parse options to launch script on local cluster :param client: SmartRedis client connected to local DB @@ -166,9 +165,9 @@ def launch_db_script(client: Client, db_script: t.List[str]) -> str: def main( network_interface: str, db_cpus: int, - command: t.List[str], - db_models: t.List[t.List[str]], - db_scripts: t.List[t.List[str]], + command: list[str], + db_models: list[list[str]], + db_scripts: list[list[str]], db_identifier: str, ) -> None: # pylint: disable=too-many-statements @@ -226,13 +225,13 @@ def main( logger.error(f"Failed to start database process: {str(e)}") raise SSInternalError("Colocated process failed to start") from e - def launch_models(client: Client, db_models: t.List[t.List[str]]) -> None: + def launch_models(client: Client, db_models: list[list[str]]) -> None: for i, db_model in enumerate(db_models): logger.debug("Uploading model") model_name = launch_db_model(client, db_model) logger.debug(f"Added model {model_name} ({i+1}/{len(db_models)})") - def launch_db_scripts(client: Client, db_scripts: t.List[t.List[str]]) -> None: + def launch_db_scripts(client: Client, db_scripts: list[list[str]]) -> None: for i, db_script in enumerate(db_scripts): logger.debug("Uploading script") script_name = launch_db_script(client, db_script) diff --git a/smartsim/_core/entrypoints/dragon.py b/smartsim/_core/entrypoints/dragon.py index 4bc4c0e3b..3ae1aca9f 100644 --- a/smartsim/_core/entrypoints/dragon.py +++ b/smartsim/_core/entrypoints/dragon.py @@ -68,7 +68,7 @@ class DragonEntrypointArgs: interface: str -def handle_signal(signo: int, _frame: t.Optional[FrameType] = None) -> None: +def handle_signal(signo: int, _frame: FrameType | None = None) -> None: if not signo: logger.info("Received signal with no signo") else: @@ -99,7 +99,7 @@ def print_summary(network_interface: str, ip_address: str) -> None: def start_updater( - backend: DragonBackend, updater: t.Optional[ContextThread] + backend: DragonBackend, updater: ContextThread | None ) -> ContextThread: """Start the ``DragonBackend`` updater thread. @@ -302,7 +302,7 @@ def register_signal_handlers() -> None: signal.signal(sig, handle_signal) -def parse_arguments(args: t.List[str]) -> DragonEntrypointArgs: +def parse_arguments(args: list[str]) -> DragonEntrypointArgs: parser = argparse.ArgumentParser( prefix_chars="+", description="SmartSim Dragon Head Process" ) @@ -326,7 +326,7 @@ def parse_arguments(args: t.List[str]) -> DragonEntrypointArgs: return DragonEntrypointArgs(args_.launching_address, args_.interface) -def main(args_: t.List[str]) -> int: +def main(args_: list[str]) -> int: """Execute the dragon entrypoint as a module""" os.environ["PYTHONUNBUFFERED"] = "1" logger.info("Dragon server started") diff --git a/smartsim/_core/entrypoints/dragon_client.py b/smartsim/_core/entrypoints/dragon_client.py index c4b77b90f..eb12f9aee 100644 --- a/smartsim/_core/entrypoints/dragon_client.py +++ b/smartsim/_core/entrypoints/dragon_client.py @@ -31,7 +31,6 @@ import signal import sys import time -import typing as t from pathlib import Path from types import FrameType @@ -66,13 +65,13 @@ def cleanup() -> None: logger.debug("Cleaning up") -def parse_requests(request_filepath: Path) -> t.List[DragonRequest]: +def parse_requests(request_filepath: Path) -> list[DragonRequest]: """Parse serialized requests from file :param request_filepath: Path to file with serialized requests :return: Deserialized requests """ - requests: t.List[DragonRequest] = [] + requests: list[DragonRequest] = [] try: with open(request_filepath, "r", encoding="utf-8") as request_file: req_strings = json.load(fp=request_file) @@ -91,7 +90,7 @@ def parse_requests(request_filepath: Path) -> t.List[DragonRequest]: return requests -def parse_arguments(args: t.List[str]) -> DragonClientEntrypointArgs: +def parse_arguments(args: list[str]) -> DragonClientEntrypointArgs: """Parse arguments used to run entrypoint script :param args: Arguments without name of executable @@ -111,7 +110,7 @@ def parse_arguments(args: t.List[str]) -> DragonClientEntrypointArgs: return DragonClientEntrypointArgs(submit=Path(args_.submit)) -def handle_signal(signo: int, _frame: t.Optional[FrameType] = None) -> None: +def handle_signal(signo: int, _frame: FrameType | None = None) -> None: """Handle signals sent to this process :param signo: Signal number @@ -176,7 +175,7 @@ def execute_entrypoint(args: DragonClientEntrypointArgs) -> int: return os.EX_OK -def main(args_: t.List[str]) -> int: +def main(args_: list[str]) -> int: """Execute the dragon client entrypoint as a module""" os.environ["PYTHONUNBUFFERED"] = "1" diff --git a/smartsim/_core/entrypoints/redis.py b/smartsim/_core/entrypoints/redis.py index 130b3ce91..88e45da0c 100644 --- a/smartsim/_core/entrypoints/redis.py +++ b/smartsim/_core/entrypoints/redis.py @@ -29,7 +29,6 @@ import os import signal import textwrap -import typing as t from subprocess import PIPE, STDOUT from types import FrameType @@ -45,19 +44,19 @@ Redis/KeyDB entrypoint script """ -DBPID: t.Optional[int] = None +DBPID: int | None = None # kill is not catchable SIGNALS = [signal.SIGINT, signal.SIGQUIT, signal.SIGTERM, signal.SIGABRT] -def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None: +def handle_signal(signo: int, _frame: FrameType | None) -> None: if not signo: logger.warning("Received signal with no signo") cleanup() -def build_bind_args(source_addr: str, *addrs: str) -> t.Tuple[str, ...]: +def build_bind_args(source_addr: str, *addrs: str) -> tuple[str, ...]: return ( "--bind", source_addr, @@ -68,14 +67,14 @@ def build_bind_args(source_addr: str, *addrs: str) -> t.Tuple[str, ...]: ) -def build_cluster_args(shard_data: LaunchedShardData) -> t.Tuple[str, ...]: +def build_cluster_args(shard_data: LaunchedShardData) -> tuple[str, ...]: if cluster_conf_file := shard_data.cluster_conf_file: return ("--cluster-enabled", "yes", "--cluster-config-file", cluster_conf_file) return () def print_summary( - cmd: t.List[str], network_interface: str, shard_data: LaunchedShardData + cmd: list[str], network_interface: str, shard_data: LaunchedShardData ) -> None: print( textwrap.dedent(f"""\ diff --git a/smartsim/_core/generation/generator.py b/smartsim/_core/generation/generator.py index 5e937a69b..95b85f9b4 100644 --- a/smartsim/_core/generation/generator.py +++ b/smartsim/_core/generation/generator.py @@ -108,7 +108,7 @@ def generate_experiment(self, *args: t.Any) -> None: self._gen_entity_list_dir(generator_manifest.ensembles) self._gen_entity_dirs(generator_manifest.models) - def set_tag(self, tag: str, regex: t.Optional[str] = None) -> None: + def set_tag(self, tag: str, regex: str | None = None) -> None: """Set the tag used for tagging input files Set a tag or a regular expression for the @@ -153,7 +153,7 @@ def _gen_exp_dir(self) -> None: dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S") log_file.write(f"Generation start date and time: {dt_string}\n") - def _gen_orc_dir(self, orchestrator_list: t.List[Orchestrator]) -> None: + def _gen_orc_dir(self, orchestrator_list: list[Orchestrator]) -> None: """Create the directory that will hold the error, output and configuration files for the orchestrator. @@ -169,7 +169,7 @@ def _gen_orc_dir(self, orchestrator_list: t.List[Orchestrator]) -> None: shutil.rmtree(orc_path, ignore_errors=True) pathlib.Path(orc_path).mkdir(exist_ok=self.overwrite, parents=True) - def _gen_entity_list_dir(self, entity_lists: t.List[Ensemble]) -> None: + def _gen_entity_list_dir(self, entity_lists: list[Ensemble]) -> None: """Generate directories for Ensemble instances :param entity_lists: list of Ensemble instances @@ -192,8 +192,8 @@ def _gen_entity_list_dir(self, entity_lists: t.List[Ensemble]) -> None: def _gen_entity_dirs( self, - entities: t.List[Model], - entity_list: t.Optional[Ensemble] = None, + entities: list[Model], + entity_list: Ensemble | None = None, ) -> None: """Generate directories for Entity instances @@ -269,7 +269,7 @@ def _build_tagged_files(tagged: TaggedFilesHierarchy) -> None: self._log_params(entity, files_to_params) def _log_params( - self, entity: Model, files_to_params: t.Dict[str, t.Dict[str, str]] + self, entity: Model, files_to_params: dict[str, dict[str, str]] ) -> None: """Log which files were modified during generation @@ -278,8 +278,8 @@ def _log_params( :param entity: the model being generated :param files_to_params: a dict connecting each file to its parameter settings """ - used_params: t.Dict[str, str] = {} - file_to_tables: t.Dict[str, str] = {} + used_params: dict[str, str] = {} + file_to_tables: dict[str, str] = {} for file, params in files_to_params.items(): used_params.update(params) table = tabulate(params.items(), headers=["Name", "Value"]) diff --git a/smartsim/_core/generation/modelwriter.py b/smartsim/_core/generation/modelwriter.py index 7502a1622..b7bee66e7 100644 --- a/smartsim/_core/generation/modelwriter.py +++ b/smartsim/_core/generation/modelwriter.py @@ -26,7 +26,7 @@ import collections import re -import typing as t +from collections import defaultdict from smartsim.error.errors import SmartSimError @@ -40,9 +40,9 @@ class ModelWriter: def __init__(self) -> None: self.tag = ";" self.regex = "(;[^;]+;)" - self.lines: t.List[str] = [] + self.lines: list[str] = [] - def set_tag(self, tag: str, regex: t.Optional[str] = None) -> None: + def set_tag(self, tag: str, regex: str | None = None) -> None: """Set the tag for the modelwriter to search for within tagged files attached to an entity. @@ -59,10 +59,10 @@ def set_tag(self, tag: str, regex: t.Optional[str] = None) -> None: def configure_tagged_model_files( self, - tagged_files: t.List[str], - params: t.Dict[str, str], + tagged_files: list[str], + params: dict[str, str], make_missing_tags_fatal: bool = False, - ) -> t.Dict[str, t.Dict[str, str]]: + ) -> dict[str, dict[str, str]]: """Read, write and configure tagged files attached to a Model instance. @@ -71,7 +71,7 @@ def configure_tagged_model_files( :param make_missing_tags_fatal: raise an error if a tag is missing :returns: A dict connecting each file to its parameter settings """ - files_to_tags: t.Dict[str, t.Dict[str, str]] = {} + files_to_tags: dict[str, dict[str, str]] = {} for tagged_file in tagged_files: self._set_lines(tagged_file) used_tags = self._replace_tags(params, make_missing_tags_fatal) @@ -105,8 +105,8 @@ def _write_changes(self, file_path: str) -> None: raise ParameterWriterError(file_path, read=False) from e def _replace_tags( - self, params: t.Dict[str, str], make_fatal: bool = False - ) -> t.Dict[str, str]: + self, params: dict[str, str], make_fatal: bool = False + ) -> dict[str, str]: """Replace the tagged parameters within the file attached to this model. The tag defaults to ";" @@ -116,8 +116,8 @@ def _replace_tags( :returns: A dict of parameter names and values set for the file """ edited = [] - unused_tags: t.DefaultDict[str, t.List[int]] = collections.defaultdict(list) - used_params: t.Dict[str, str] = {} + unused_tags: defaultdict[str, list[int]] = collections.defaultdict(list) + used_params: dict[str, str] = {} for i, line in enumerate(self.lines, 1): while search := re.search(self.regex, line): tagged_line = search.group(0) @@ -144,9 +144,7 @@ def _replace_tags( self.lines = edited return used_params - def _is_ensemble_spec( - self, tagged_line: str, model_params: t.Dict[str, str] - ) -> bool: + def _is_ensemble_spec(self, tagged_line: str, model_params: dict[str, str]) -> bool: split_tag = tagged_line.split(self.tag) prev_val = split_tag[1] if prev_val in model_params.keys(): diff --git a/smartsim/_core/launcher/colocated.py b/smartsim/_core/launcher/colocated.py index 4de156b65..3f7e7cfd2 100644 --- a/smartsim/_core/launcher/colocated.py +++ b/smartsim/_core/launcher/colocated.py @@ -34,7 +34,7 @@ def write_colocated_launch_script( - file_name: str, db_log: str, colocated_settings: t.Dict[str, t.Any] + file_name: str, db_log: str, colocated_settings: dict[str, t.Any] ) -> None: """Write the colocated launch script @@ -80,11 +80,11 @@ def write_colocated_launch_script( def _build_colocated_wrapper_cmd( db_log: str, cpus: int = 1, - rai_args: t.Optional[t.Dict[str, str]] = None, - extra_db_args: t.Optional[t.Dict[str, str]] = None, + rai_args: dict[str, str] | None = None, + extra_db_args: dict[str, str] | None = None, port: int = 6780, - ifname: t.Optional[t.Union[str, t.List[str]]] = None, - custom_pinning: t.Optional[str] = None, + ifname: str | list[str] | None = None, + custom_pinning: str | None = None, **kwargs: t.Any, ) -> str: """Build the command use to run a colocated DB application @@ -189,7 +189,7 @@ def _build_colocated_wrapper_cmd( return " ".join(cmd) -def _build_db_model_cmd(db_models: t.List[DBModel]) -> t.List[str]: +def _build_db_model_cmd(db_models: list[DBModel]) -> list[str]: cmd = [] for db_model in db_models: cmd.append("+db_model") @@ -219,7 +219,7 @@ def _build_db_model_cmd(db_models: t.List[DBModel]) -> t.List[str]: return cmd -def _build_db_script_cmd(db_scripts: t.List[DBScript]) -> t.List[str]: +def _build_db_script_cmd(db_scripts: list[DBScript]) -> list[str]: cmd = [] for db_script in db_scripts: cmd.append("+db_script") diff --git a/smartsim/_core/launcher/dragon/dragonBackend.py b/smartsim/_core/launcher/dragon/dragonBackend.py index 2f8704be2..18364676e 100644 --- a/smartsim/_core/launcher/dragon/dragonBackend.py +++ b/smartsim/_core/launcher/dragon/dragonBackend.py @@ -78,19 +78,19 @@ def __str__(self) -> str: class ProcessGroupInfo: status: SmartSimStatus """Status of step""" - process_group: t.Optional[dragon_process_group.ProcessGroup] = None + process_group: dragon_process_group.ProcessGroup | None = None """Internal Process Group object, None for finished or not started steps""" - puids: t.Optional[t.List[t.Optional[int]]] = None # puids can be None + puids: list[int | None] | None = None # puids can be None """List of Process UIDS belonging to the ProcessGroup""" - return_codes: t.Optional[t.List[int]] = None + return_codes: list[int] | None = None """List of return codes of completed processes""" - hosts: t.List[str] = field(default_factory=list) + hosts: list[str] = field(default_factory=list) """List of hosts on which the Process Group """ - redir_workers: t.Optional[dragon_process_group.ProcessGroup] = None + redir_workers: dragon_process_group.ProcessGroup | None = None """Workers used to redirect stdout and stderr to file""" @property - def smartsim_info(self) -> t.Tuple[SmartSimStatus, t.Optional[t.List[int]]]: + def smartsim_info(self) -> tuple[SmartSimStatus, list[int] | None]: """Information needed by SmartSim Launcher and Job Manager""" return (self.status, self.return_codes) @@ -145,7 +145,7 @@ class DragonBackend: def __init__(self, pid: int) -> None: self._pid = pid """PID of dragon executable which launched this server""" - self._group_infos: t.Dict[str, ProcessGroupInfo] = {} + self._group_infos: dict[str, ProcessGroupInfo] = {} """ProcessGroup execution state information""" self._queue_lock = RLock() """Lock that needs to be acquired to access internal queues""" @@ -159,9 +159,9 @@ def __init__(self, pid: int) -> None: """Steps waiting for execution""" self._stop_requests: t.Deque[DragonStopRequest] = collections.deque() """Stop requests which have not been processed yet""" - self._running_steps: t.List[str] = [] + self._running_steps: list[str] = [] """List of currently running steps""" - self._completed_steps: t.List[str] = [] + self._completed_steps: list[str] = [] """List of completed steps""" self._last_beat: float = 0.0 """Time at which the last heartbeat was set""" @@ -174,7 +174,7 @@ def __init__(self, pid: int) -> None: """Whether the server can shut down""" self._frontend_shutdown: bool = False """Whether the server frontend should shut down when the backend does""" - self._shutdown_initiation_time: t.Optional[float] = None + self._shutdown_initiation_time: float | None = None """The time at which the server initiated shutdown""" self._cooldown_period = 5 """Time in seconds needed to server to complete shutdown""" @@ -207,14 +207,14 @@ def _initialize_hosts(self) -> None: self._nodes = [ dragon_machine.Node(node) for node in dragon_machine.System().nodes ] - self._hosts: t.List[str] = sorted(node.hostname for node in self._nodes) + self._hosts: list[str] = sorted(node.hostname for node in self._nodes) self._cpus = [node.num_cpus for node in self._nodes] self._gpus = [node.num_gpus for node in self._nodes] """List of hosts available in allocation""" self._free_hosts: t.Deque[str] = collections.deque(self._hosts) """List of hosts on which steps can be launched""" - self._allocated_hosts: t.Dict[str, str] = {} + self._allocated_hosts: dict[str, str] = {} """Mapping of hosts on which a step is already running to step ID""" def __str__(self) -> str: @@ -282,9 +282,7 @@ def current_time(self) -> float: """Current time for DragonBackend object, in seconds since the Epoch""" return time.time() - def _can_honor_policy( - self, request: DragonRunRequest - ) -> t.Tuple[bool, t.Optional[str]]: + def _can_honor_policy(self, request: DragonRunRequest) -> tuple[bool, str | None]: """Check if the policy can be honored with resources available in the allocation. :param request: DragonRunRequest containing policy information @@ -310,7 +308,7 @@ def _can_honor_policy( return True, None - def _can_honor(self, request: DragonRunRequest) -> t.Tuple[bool, t.Optional[str]]: + def _can_honor(self, request: DragonRunRequest) -> tuple[bool, str | None]: """Check if request can be honored with resources available in the allocation. Currently only checks for total number of nodes, @@ -333,7 +331,7 @@ def _can_honor(self, request: DragonRunRequest) -> t.Tuple[bool, t.Optional[str] def _allocate_step( self, step_id: str, request: DragonRunRequest - ) -> t.Optional[t.List[str]]: + ) -> list[str] | None: num_hosts: int = request.nodes with self._queue_lock: @@ -349,10 +347,10 @@ def _allocate_step( @staticmethod def _create_redirect_workers( global_policy: dragon_policy.Policy, - policies: t.List[dragon_policy.Policy], - puids: t.List[int], - out_file: t.Optional[str], - err_file: t.Optional[str], + policies: list[dragon_policy.Policy], + puids: list[int], + out_file: str | None, + err_file: str | None, ) -> dragon_process_group.ProcessGroup: grp_redir = dragon_process_group.ProcessGroup( restart=False, policy=global_policy, pmi_enabled=False @@ -433,8 +431,8 @@ def create_run_policy( run_request: DragonRunRequest = request affinity = dragon_policy.Policy.Affinity.DEFAULT - cpu_affinity: t.List[int] = [] - gpu_affinity: t.List[int] = [] + cpu_affinity: list[int] = [] + gpu_affinity: list[int] = [] # Customize policy only if the client requested it, otherwise use default if run_request.policy is not None: @@ -737,7 +735,7 @@ def host_desc(self) -> str: @staticmethod def _proc_group_info_table_line( step_id: str, proc_group_info: ProcessGroupInfo - ) -> t.List[str]: + ) -> list[str]: table_line = [step_id, f"{proc_group_info.status.value}"] if proc_group_info.hosts is not None: diff --git a/smartsim/_core/launcher/dragon/dragonConnector.py b/smartsim/_core/launcher/dragon/dragonConnector.py index e43865b28..72a2512f7 100644 --- a/smartsim/_core/launcher/dragon/dragonConnector.py +++ b/smartsim/_core/launcher/dragon/dragonConnector.py @@ -35,6 +35,7 @@ import sys import typing as t from collections import defaultdict +from collections.abc import Iterable from pathlib import Path from threading import RLock @@ -59,7 +60,7 @@ logger = get_logger(__name__) -_SchemaT = t.TypeVar("_SchemaT", bound=t.Union[DragonRequest, DragonResponse]) +_SchemaT = t.TypeVar("_SchemaT", bound=DragonRequest | DragonResponse) DRG_LOCK = RLock() @@ -73,17 +74,17 @@ def __init__(self) -> None: self._context: zmq.Context[t.Any] = zmq.Context.instance() self._context.setsockopt(zmq.REQ_CORRELATE, 1) self._context.setsockopt(zmq.REQ_RELAXED, 1) - self._authenticator: t.Optional[zmq.auth.thread.ThreadAuthenticator] = None + self._authenticator: zmq.auth.thread.ThreadAuthenticator | None = None config = get_config() self._reset_timeout(config.dragon_server_timeout) - self._dragon_head_socket: t.Optional[zmq.Socket[t.Any]] = None - self._dragon_head_process: t.Optional[subprocess.Popen[bytes]] = None + self._dragon_head_socket: zmq.Socket[t.Any] | None = None + self._dragon_head_process: subprocess.Popen[bytes] | None = None # Returned by dragon head, useful if shutdown is to be requested # but process was started by another connector - self._dragon_head_pid: t.Optional[int] = None + self._dragon_head_pid: int | None = None self._dragon_server_path = config.dragon_server_path logger.debug(f"Dragon Server path was set to {self._dragon_server_path}") - self._env_vars: t.Dict[str, str] = {} + self._env_vars: dict[str, str] = {} if self._dragon_server_path is None: raise SmartSimError( "DragonConnector could not find the dragon server path. " @@ -218,7 +219,7 @@ def _connect_to_existing_server(self, path: Path) -> None: def _start_connector_socket(self, socket_addr: str) -> zmq.Socket[t.Any]: config = get_config() - connector_socket: t.Optional[zmq.Socket[t.Any]] = None + connector_socket: zmq.Socket[t.Any] | None = None self._reset_timeout(config.dragon_server_startup_timeout) self._get_new_authenticator(-1) connector_socket = dragonSockets.get_secure_socket(self._context, zmq.REP, True) @@ -229,7 +230,7 @@ def _start_connector_socket(self, socket_addr: str) -> zmq.Socket[t.Any]: return connector_socket - def load_persisted_env(self) -> t.Dict[str, str]: + def load_persisted_env(self) -> dict[str, str]: """Load key-value pairs from a .env file created during dragon installation :return: Key-value pairs stored in .env file""" @@ -251,7 +252,7 @@ def load_persisted_env(self) -> t.Dict[str, str]: return self._env_vars - def merge_persisted_env(self, current_env: t.Dict[str, str]) -> t.Dict[str, str]: + def merge_persisted_env(self, current_env: dict[str, str]) -> dict[str, str]: """Combine the current environment variable set with the dragon .env by adding Dragon-specific values and prepending any new values to existing keys @@ -259,7 +260,7 @@ def merge_persisted_env(self, current_env: t.Dict[str, str]) -> t.Dict[str, str] :return: Merged environment """ # ensure we start w/a complete env from current env state - merged_env: t.Dict[str, str] = {**current_env} + merged_env: dict[str, str] = {**current_env} # copy all the values for dragon straight into merged_env merged_env.update( @@ -416,8 +417,8 @@ def send_request(self, request: DragonRequest, flags: int = 0) -> DragonResponse @staticmethod def _parse_launched_dragon_server_info_from_iterable( - stream: t.Iterable[str], num_dragon_envs: t.Optional[int] = None - ) -> t.List[t.Dict[str, str]]: + stream: Iterable[str], num_dragon_envs: int | None = None + ) -> list[dict[str, str]]: lines = (line.strip() for line in stream) lines = (line for line in lines if line) tokenized = (line.split(maxsplit=1) for line in lines) @@ -441,9 +442,9 @@ def _parse_launched_dragon_server_info_from_iterable( @classmethod def _parse_launched_dragon_server_info_from_files( cls, - file_paths: t.List[t.Union[str, "os.PathLike[str]"]], - num_dragon_envs: t.Optional[int] = None, - ) -> t.List[t.Dict[str, str]]: + file_paths: list[str | os.PathLike[str]], + num_dragon_envs: int | None = None, + ) -> list[dict[str, str]]: with fileinput.FileInput(file_paths) as ifstream: dragon_envs = cls._parse_launched_dragon_server_info_from_iterable( ifstream, num_dragon_envs @@ -468,16 +469,16 @@ def _send_req_with_socket( return response -def _assert_schema_type(obj: object, typ: t.Type[_SchemaT], /) -> _SchemaT: +def _assert_schema_type(obj: object, typ: type[_SchemaT], /) -> _SchemaT: if not isinstance(obj, typ): raise TypeError(f"Expected schema of type `{typ}`, but got {type(obj)}") return obj def _dragon_cleanup( - server_socket: t.Optional[zmq.Socket[t.Any]] = None, - server_process_pid: t.Optional[int] = 0, - server_authenticator: t.Optional[zmq.auth.thread.ThreadAuthenticator] = None, + server_socket: zmq.Socket[t.Any] | None = None, + server_process_pid: int | None = 0, + server_authenticator: zmq.auth.thread.ThreadAuthenticator | None = None, ) -> None: """Clean up resources used by the launcher. :param server_socket: (optional) Socket used to connect to dragon environment @@ -519,7 +520,7 @@ def _dragon_cleanup( print("Authenticator shutdown is complete") -def _resolve_dragon_path(fallback: t.Union[str, "os.PathLike[str]"]) -> Path: +def _resolve_dragon_path(fallback: str | os.PathLike[str]) -> Path: dragon_server_path = get_config().dragon_server_path or os.path.join( fallback, ".smartsim", "dragon" ) diff --git a/smartsim/_core/launcher/dragon/dragonLauncher.py b/smartsim/_core/launcher/dragon/dragonLauncher.py index 911625800..666f09104 100644 --- a/smartsim/_core/launcher/dragon/dragonLauncher.py +++ b/smartsim/_core/launcher/dragon/dragonLauncher.py @@ -27,7 +27,6 @@ from __future__ import annotations import os -import typing as t from smartsim._core.schemas.dragonRequests import DragonRunPolicy @@ -92,7 +91,7 @@ def cleanup(self) -> None: # RunSettings types supported by this launcher @property - def supported_rs(self) -> t.Dict[t.Type[SettingsBase], t.Type[Step]]: + def supported_rs(self) -> dict[type[SettingsBase], type[Step]]: # RunSettings types supported by this launcher return { DragonRunSettings: DragonStep, @@ -106,7 +105,7 @@ def add_step_to_mapping_table(self, name: str, step_map: StepMap) -> None: if step_map.step_id is None: return - sublauncher: t.Optional[t.Union[SlurmLauncher, PBSLauncher]] = None + sublauncher: SlurmLauncher | PBSLauncher | None = None if step_map.step_id.startswith("SLURM-"): sublauncher = self._slurm_launcher elif step_map.step_id.startswith("PBS-"): @@ -121,7 +120,7 @@ def add_step_to_mapping_table(self, name: str, step_map: StepMap) -> None: ) sublauncher.add_step_to_mapping_table(name, sublauncher_step_map) - def run(self, step: Step) -> t.Optional[str]: + def run(self, step: Step) -> str | None: """Run a job step through Slurm :param step: a job step instance @@ -140,7 +139,7 @@ def run(self, step: Step) -> t.Optional[str]: if isinstance(step, DragonBatchStep): # wait for batch step to submit successfully - sublauncher_step_id: t.Optional[str] = None + sublauncher_step_id: str | None = None return_code, out, err = self.task_manager.start_and_wait(cmd, step.cwd) if return_code != 0: raise LauncherError(f"Sbatch submission failed\n {out}\n {err}") @@ -241,7 +240,7 @@ def stop(self, step_name: str) -> StepInfo: def _unprefix_step_id(step_id: str) -> str: return step_id.split("-", maxsplit=1)[1] - def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: + def _get_managed_step_update(self, step_ids: list[str]) -> list[StepInfo]: """Get step updates for Dragon-managed jobs :param step_ids: list of job step ids @@ -250,9 +249,9 @@ def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: step_id_updates: dict[str, StepInfo] = {} - dragon_step_ids: t.List[str] = [] - slurm_step_ids: t.List[str] = [] - pbs_step_ids: t.List[str] = [] + dragon_step_ids: list[str] = [] + slurm_step_ids: list[str] = [] + pbs_step_ids: list[str] = [] for step_id in step_ids: if step_id.startswith("SLURM-"): slurm_step_ids.append(step_id) @@ -321,7 +320,7 @@ def __str__(self) -> str: return "Dragon" -def _assert_schema_type(obj: object, typ: t.Type[_SchemaT], /) -> _SchemaT: +def _assert_schema_type(obj: object, typ: type[_SchemaT], /) -> _SchemaT: if not isinstance(obj, typ): raise TypeError(f"Expected schema of type `{typ}`, but got {type(obj)}") return obj diff --git a/smartsim/_core/launcher/dragon/dragonSockets.py b/smartsim/_core/launcher/dragon/dragonSockets.py index ae669acdd..6b2dcb96a 100644 --- a/smartsim/_core/launcher/dragon/dragonSockets.py +++ b/smartsim/_core/launcher/dragon/dragonSockets.py @@ -42,7 +42,7 @@ logger = get_logger(__name__) -AUTHENTICATOR: t.Optional["zmq.auth.thread.ThreadAuthenticator"] = None +AUTHENTICATOR: "zmq.auth.thread.ThreadAuthenticator | None" = None def as_server( diff --git a/smartsim/_core/launcher/launcher.py b/smartsim/_core/launcher/launcher.py index 87ab468cd..70e7900d5 100644 --- a/smartsim/_core/launcher/launcher.py +++ b/smartsim/_core/launcher/launcher.py @@ -25,7 +25,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import abc -import typing as t from ..._core.launcher.stepMapping import StepMap from ...error import AllocationError, LauncherError, SSUnsupportedError @@ -54,16 +53,16 @@ def create_step(self, name: str, cwd: str, step_settings: SettingsBase) -> Step: @abc.abstractmethod def get_step_update( - self, step_names: t.List[str] - ) -> t.List[t.Tuple[str, t.Union[StepInfo, None]]]: + self, step_names: list[str] + ) -> list[tuple[str, StepInfo | None]]: raise NotImplementedError @abc.abstractmethod - def get_step_nodes(self, step_names: t.List[str]) -> t.List[t.List[str]]: + def get_step_nodes(self, step_names: list[str]) -> list[list[str]]: raise NotImplementedError @abc.abstractmethod - def run(self, step: Step) -> t.Optional[str]: + def run(self, step: Step) -> str | None: raise NotImplementedError @abc.abstractmethod @@ -93,7 +92,7 @@ def __init__(self) -> None: @property @abc.abstractmethod - def supported_rs(self) -> t.Dict[t.Type[SettingsBase], t.Type[Step]]: + def supported_rs(self) -> dict[type[SettingsBase], type[Step]]: raise NotImplementedError # every launcher utilizing this interface must have a map @@ -125,19 +124,19 @@ def create_step( # don't need to be covered here. def get_step_nodes( - self, step_names: t.List[str] - ) -> t.List[t.List[str]]: # pragma: no cover + self, step_names: list[str] + ) -> list[list[str]]: # pragma: no cover raise SSUnsupportedError("Node acquisition not supported for this launcher") def get_step_update( - self, step_names: t.List[str] - ) -> t.List[t.Tuple[str, t.Union[StepInfo, None]]]: # cov-wlm + self, step_names: list[str] + ) -> list[tuple[str, StepInfo | None]]: # cov-wlm """Get update for a list of job steps :param step_names: list of job steps to get updates for :return: list of name, job update tuples """ - updates: t.List[t.Tuple[str, t.Union[StepInfo, None]]] = [] + updates: list[tuple[str, StepInfo | None]] = [] # get updates of jobs managed by workload manager (PBS, Slurm, etc) # this is primarily batch jobs. @@ -161,8 +160,8 @@ def get_step_update( return updates def _get_unmanaged_step_update( - self, task_ids: t.List[str] - ) -> t.List[UnmanagedStepInfo]: # cov-wlm + self, task_ids: list[str] + ) -> list[UnmanagedStepInfo]: # cov-wlm """Get step updates for Popen managed jobs :param task_ids: task id to check @@ -178,6 +177,6 @@ def _get_unmanaged_step_update( # pylint: disable-next=no-self-use def _get_managed_step_update( self, - step_ids: t.List[str], # pylint: disable=unused-argument - ) -> t.List[StepInfo]: # pragma: no cover + step_ids: list[str], # pylint: disable=unused-argument + ) -> list[StepInfo]: # pragma: no cover return [] diff --git a/smartsim/_core/launcher/local/local.py b/smartsim/_core/launcher/local/local.py index 2fc470021..6cff067ce 100644 --- a/smartsim/_core/launcher/local/local.py +++ b/smartsim/_core/launcher/local/local.py @@ -24,7 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t from ....settings import RunSettings, SettingsBase from ..launcher import Launcher @@ -54,8 +53,8 @@ def create_step(self, name: str, cwd: str, step_settings: SettingsBase) -> Step: return LocalStep(name, cwd, step_settings) def get_step_update( - self, step_names: t.List[str] - ) -> t.List[t.Tuple[str, t.Optional[StepInfo]]]: + self, step_names: list[str] + ) -> list[tuple[str, StepInfo | None]]: """Get status updates of each job step name provided :param step_names: list of step_names @@ -63,7 +62,7 @@ def get_step_update( """ # step ids are process ids of the tasks # as there is no WLM intermediary - updates: t.List[t.Tuple[str, t.Optional[StepInfo]]] = [] + updates: list[tuple[str, StepInfo | None]] = [] s_names, s_ids = self.step_mapping.get_ids(step_names, managed=False) for step_name, step_id in zip(s_names, s_ids): status, ret_code, out, err = self.task_manager.get_task_update(str(step_id)) @@ -72,7 +71,7 @@ def get_step_update( updates.append(update) return updates - def get_step_nodes(self, step_names: t.List[str]) -> t.List[t.List[str]]: + def get_step_nodes(self, step_names: list[str]) -> list[list[str]]: """Return the address of nodes assigned to the step :param step_names: list of step_names diff --git a/smartsim/_core/launcher/pbs/pbsCommands.py b/smartsim/_core/launcher/pbs/pbsCommands.py index a0eb8a988..de3f402f5 100644 --- a/smartsim/_core/launcher/pbs/pbsCommands.py +++ b/smartsim/_core/launcher/pbs/pbsCommands.py @@ -24,12 +24,11 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t from ...utils.shell import execute_cmd -def qstat(args: t.List[str]) -> t.Tuple[str, str]: +def qstat(args: list[str]) -> tuple[str, str]: """Calls PBS qstat with args :param args: List of command arguments @@ -40,7 +39,7 @@ def qstat(args: t.List[str]) -> t.Tuple[str, str]: return out, error -def qsub(args: t.List[str]) -> t.Tuple[str, str]: +def qsub(args: list[str]) -> tuple[str, str]: """Calls PBS qsub with args :param args: List of command arguments @@ -51,7 +50,7 @@ def qsub(args: t.List[str]) -> t.Tuple[str, str]: return out, error -def qdel(args: t.List[str]) -> t.Tuple[int, str, str]: +def qdel(args: list[str]) -> tuple[int, str, str]: """Calls PBS qdel with args. returncode is also supplied in this function. diff --git a/smartsim/_core/launcher/pbs/pbsLauncher.py b/smartsim/_core/launcher/pbs/pbsLauncher.py index 6907c13de..f3d312fbe 100644 --- a/smartsim/_core/launcher/pbs/pbsLauncher.py +++ b/smartsim/_core/launcher/pbs/pbsLauncher.py @@ -25,7 +25,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import time -import typing as t from ....error import LauncherError from ....log import get_logger @@ -76,7 +75,7 @@ class PBSLauncher(WLMLauncher): # init in WLMLauncher, launcher.py @property - def supported_rs(self) -> t.Dict[t.Type[SettingsBase], t.Type[Step]]: + def supported_rs(self) -> dict[type[SettingsBase], type[Step]]: # RunSettings types supported by this launcher return { AprunSettings: AprunStep, @@ -88,7 +87,7 @@ def supported_rs(self) -> t.Dict[t.Type[SettingsBase], t.Type[Step]]: PalsMpiexecSettings: MpiexecStep, } - def run(self, step: Step) -> t.Optional[str]: + def run(self, step: Step) -> str | None: """Run a job step through PBSPro :param step: a job step instance @@ -99,8 +98,8 @@ def run(self, step: Step) -> t.Optional[str]: self.task_manager.start() cmd_list = step.get_launch_cmd() - step_id: t.Optional[str] = None - task_id: t.Optional[str] = None + step_id: str | None = None + task_id: str | None = None if isinstance(step, QsubBatchStep): # wait for batch step to submit successfully return_code, out, err = self.task_manager.start_and_wait(cmd_list, step.cwd) @@ -162,7 +161,7 @@ def _get_pbs_step_id(step: Step, interval: int = 2) -> str: TODO: change this to use ``qstat -a -u user`` """ time.sleep(interval) - step_id: t.Optional[str] = None + step_id: str | None = None trials = CONFIG.wlm_trials while trials > 0: output, _ = qstat(["-f", "-F", "json"]) @@ -176,13 +175,13 @@ def _get_pbs_step_id(step: Step, interval: int = 2) -> str: raise LauncherError("Could not find id of launched job step") return step_id - def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: + def _get_managed_step_update(self, step_ids: list[str]) -> list[StepInfo]: """Get step updates for WLM managed jobs :param step_ids: list of job step ids :return: list of updates for managed jobs """ - updates: t.List[StepInfo] = [] + updates: list[StepInfo] = [] qstat_out, _ = qstat(step_ids) stats = [parse_qstat_jobid(qstat_out, str(step_id)) for step_id in step_ids] diff --git a/smartsim/_core/launcher/pbs/pbsParser.py b/smartsim/_core/launcher/pbs/pbsParser.py index 8ded7c380..4439c52fa 100644 --- a/smartsim/_core/launcher/pbs/pbsParser.py +++ b/smartsim/_core/launcher/pbs/pbsParser.py @@ -57,7 +57,7 @@ def parse_qsub_error(output: str) -> str: return base_err -def parse_qstat_jobid(output: str, job_id: str) -> t.Optional[str]: +def parse_qstat_jobid(output: str, job_id: str) -> str | None: """Parse and return output of the qstat command run with options to obtain job status. @@ -76,7 +76,7 @@ def parse_qstat_jobid(output: str, job_id: str) -> t.Optional[str]: return result -def parse_qstat_jobid_json(output: str, job_id: str) -> t.Optional[str]: +def parse_qstat_jobid_json(output: str, job_id: str) -> str | None: """Parse and return output of the qstat command run with JSON options to obtain job status. @@ -89,13 +89,13 @@ def parse_qstat_jobid_json(output: str, job_id: str) -> t.Optional[str]: if "Jobs" not in out_json: return None jobs: dict[str, t.Any] = out_json["Jobs"] - job: t.Optional[dict[str, t.Any]] = jobs.get(job_id, None) + job: dict[str, t.Any] | None = jobs.get(job_id, None) if job is None: return None return str(job.get("job_state", None)) -def parse_qstat_nodes(output: str) -> t.List[str]: +def parse_qstat_nodes(output: str) -> list[str]: """Parse and return the qstat command run with options to obtain node list. @@ -107,7 +107,7 @@ def parse_qstat_nodes(output: str) -> t.List[str]: :param output: output of the qstat command in JSON format :return: compute nodes of the allocation or job """ - nodes: t.List[str] = [] + nodes: list[str] = [] out_json = load_and_clean_json(output) if "Jobs" not in out_json: return nodes @@ -122,14 +122,14 @@ def parse_qstat_nodes(output: str) -> t.List[str]: return list(sorted(set(nodes))) -def parse_step_id_from_qstat(output: str, step_name: str) -> t.Optional[str]: +def parse_step_id_from_qstat(output: str, step_name: str) -> str | None: """Parse and return the step id from a qstat command :param output: output qstat :param step_name: the name of the step to query :return: the step_id """ - step_id: t.Optional[str] = None + step_id: str | None = None out_json = load_and_clean_json(output) if "Jobs" not in out_json: diff --git a/smartsim/_core/launcher/sge/sgeCommands.py b/smartsim/_core/launcher/sge/sgeCommands.py index c9160b6ac..710b4ec7c 100644 --- a/smartsim/_core/launcher/sge/sgeCommands.py +++ b/smartsim/_core/launcher/sge/sgeCommands.py @@ -24,12 +24,11 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t from ...utils.shell import execute_cmd -def qstat(args: t.List[str]) -> t.Tuple[str, str]: +def qstat(args: list[str]) -> tuple[str, str]: """Calls SGE qstat with args :param args: List of command arguments @@ -40,7 +39,7 @@ def qstat(args: t.List[str]) -> t.Tuple[str, str]: return out, error -def qsub(args: t.List[str]) -> t.Tuple[str, str]: +def qsub(args: list[str]) -> tuple[str, str]: """Calls SGE qsub with args :param args: List of command arguments @@ -51,7 +50,7 @@ def qsub(args: t.List[str]) -> t.Tuple[str, str]: return out, error -def qdel(args: t.List[str]) -> t.Tuple[int, str, str]: +def qdel(args: list[str]) -> tuple[int, str, str]: """Calls SGE qdel with args. returncode is also supplied in this function. @@ -64,7 +63,7 @@ def qdel(args: t.List[str]) -> t.Tuple[int, str, str]: return returncode, out, error -def qacct(args: t.List[str]) -> t.Tuple[int, str, str]: +def qacct(args: list[str]) -> tuple[int, str, str]: """Calls SGE qacct with args. returncode is also supplied in this function. diff --git a/smartsim/_core/launcher/sge/sgeLauncher.py b/smartsim/_core/launcher/sge/sgeLauncher.py index 920fab4d7..f6b4558ce 100644 --- a/smartsim/_core/launcher/sge/sgeLauncher.py +++ b/smartsim/_core/launcher/sge/sgeLauncher.py @@ -25,7 +25,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import time -import typing as t from ....error import LauncherError from ....log import get_logger @@ -69,7 +68,7 @@ class SGELauncher(WLMLauncher): # init in WLMLauncher, launcher.py @property - def supported_rs(self) -> t.Dict[t.Type[SettingsBase], t.Type[Step]]: + def supported_rs(self) -> dict[type[SettingsBase], type[Step]]: # RunSettings types supported by this launcher return { SgeQsubBatchSettings: SgeQsubBatchStep, @@ -79,7 +78,7 @@ def supported_rs(self) -> t.Dict[t.Type[SettingsBase], t.Type[Step]]: RunSettings: LocalStep, } - def run(self, step: Step) -> t.Optional[str]: + def run(self, step: Step) -> str | None: """Run a job step through SGE :param step: a job step instance @@ -90,8 +89,8 @@ def run(self, step: Step) -> t.Optional[str]: self.task_manager.start() cmd_list = step.get_launch_cmd() - step_id: t.Optional[str] = None - task_id: t.Optional[str] = None + step_id: str | None = None + task_id: str | None = None if isinstance(step, SgeQsubBatchStep): # wait for batch step to submit successfully return_code, out, err = self.task_manager.start_and_wait(cmd_list, step.cwd) @@ -141,13 +140,13 @@ def stop(self, step_name: str) -> StepInfo: ) # set status to cancelled instead of failed return step_info - def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: + def _get_managed_step_update(self, step_ids: list[str]) -> list[StepInfo]: """Get step updates for WLM managed jobs :param step_ids: list of job step ids :return: list of updates for managed jobs """ - updates: t.List[StepInfo] = [] + updates: list[StepInfo] = [] qstat_out, _ = qstat(["-xml"]) stats = [parse_qstat_jobid_xml(qstat_out, str(step_id)) for step_id in step_ids] diff --git a/smartsim/_core/launcher/sge/sgeParser.py b/smartsim/_core/launcher/sge/sgeParser.py index ec811d53b..de03c5416 100644 --- a/smartsim/_core/launcher/sge/sgeParser.py +++ b/smartsim/_core/launcher/sge/sgeParser.py @@ -24,7 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t import xml.etree.ElementTree as ET @@ -57,7 +56,7 @@ def parse_qsub_error(output: str) -> str: return base_err -def parse_qstat_jobid_xml(output: str, job_id: str) -> t.Optional[str]: +def parse_qstat_jobid_xml(output: str, job_id: str) -> str | None: """Parse and return output of the qstat command run with XML options to obtain job status. @@ -78,7 +77,7 @@ def parse_qstat_jobid_xml(output: str, job_id: str) -> t.Optional[str]: return None -def parse_qacct_job_output(output: str, field_name: str) -> t.Union[str, int]: +def parse_qacct_job_output(output: str, field_name: str) -> str | int: """Parse the output from qacct for a single job :param output: The raw text output from qacct diff --git a/smartsim/_core/launcher/slurm/slurmCommands.py b/smartsim/_core/launcher/slurm/slurmCommands.py index ee043c759..08da33fc1 100644 --- a/smartsim/_core/launcher/slurm/slurmCommands.py +++ b/smartsim/_core/launcher/slurm/slurmCommands.py @@ -24,7 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t from ....error import LauncherError from ....log import get_logger @@ -34,7 +33,7 @@ logger = get_logger(__name__) -def sstat(args: t.List[str], *, raise_on_err: bool = False) -> t.Tuple[str, str]: +def sstat(args: list[str], *, raise_on_err: bool = False) -> tuple[str, str]: """Calls sstat with args :param args: List of command arguments @@ -44,7 +43,7 @@ def sstat(args: t.List[str], *, raise_on_err: bool = False) -> t.Tuple[str, str] return out, err -def sacct(args: t.List[str], *, raise_on_err: bool = False) -> t.Tuple[str, str]: +def sacct(args: list[str], *, raise_on_err: bool = False) -> tuple[str, str]: """Calls sacct with args :param args: List of command arguments @@ -54,7 +53,7 @@ def sacct(args: t.List[str], *, raise_on_err: bool = False) -> t.Tuple[str, str] return out, err -def salloc(args: t.List[str], *, raise_on_err: bool = False) -> t.Tuple[str, str]: +def salloc(args: list[str], *, raise_on_err: bool = False) -> tuple[str, str]: """Calls slurm salloc with args :param args: List of command arguments @@ -64,7 +63,7 @@ def salloc(args: t.List[str], *, raise_on_err: bool = False) -> t.Tuple[str, str return out, err -def sinfo(args: t.List[str], *, raise_on_err: bool = False) -> t.Tuple[str, str]: +def sinfo(args: list[str], *, raise_on_err: bool = False) -> tuple[str, str]: """Calls slurm sinfo with args :param args: List of command arguments @@ -74,7 +73,7 @@ def sinfo(args: t.List[str], *, raise_on_err: bool = False) -> t.Tuple[str, str] return out, err -def scontrol(args: t.List[str], *, raise_on_err: bool = False) -> t.Tuple[str, str]: +def scontrol(args: list[str], *, raise_on_err: bool = False) -> tuple[str, str]: """Calls slurm scontrol with args :param args: List of command arguments @@ -84,7 +83,7 @@ def scontrol(args: t.List[str], *, raise_on_err: bool = False) -> t.Tuple[str, s return out, err -def scancel(args: t.List[str], *, raise_on_err: bool = False) -> t.Tuple[int, str, str]: +def scancel(args: list[str], *, raise_on_err: bool = False) -> tuple[int, str, str]: """Calls slurm scancel with args. returncode is also supplied in this function. @@ -106,8 +105,8 @@ def _find_slurm_command(cmd: str) -> str: def _execute_slurm_cmd( - command: str, args: t.List[str], raise_on_err: bool = False -) -> t.Tuple[int, str, str]: + command: str, args: list[str], raise_on_err: bool = False +) -> tuple[int, str, str]: cmd_exe = _find_slurm_command(command) cmd = [cmd_exe] + args returncode, out, error = execute_cmd(cmd) diff --git a/smartsim/_core/launcher/slurm/slurmLauncher.py b/smartsim/_core/launcher/slurm/slurmLauncher.py index dba0cd5ed..5b8bda6f5 100644 --- a/smartsim/_core/launcher/slurm/slurmLauncher.py +++ b/smartsim/_core/launcher/slurm/slurmLauncher.py @@ -26,7 +26,6 @@ import os import time -import typing as t from shutil import which from ....error import LauncherError @@ -74,7 +73,7 @@ class SlurmLauncher(WLMLauncher): # RunSettings types supported by this launcher @property - def supported_rs(self) -> t.Dict[t.Type[SettingsBase], t.Type[Step]]: + def supported_rs(self) -> dict[type[SettingsBase], type[Step]]: # RunSettings types supported by this launcher return { SrunSettings: SrunStep, @@ -85,7 +84,7 @@ def supported_rs(self) -> t.Dict[t.Type[SettingsBase], t.Type[Step]]: RunSettings: LocalStep, } - def get_step_nodes(self, step_names: t.List[str]) -> t.List[t.List[str]]: + def get_step_nodes(self, step_names: list[str]) -> list[list[str]]: """Return the compute nodes of a specific job or allocation This function returns the compute nodes of a specific job or allocation @@ -116,7 +115,7 @@ def get_step_nodes(self, step_names: t.List[str]) -> t.List[t.List[str]]: raise LauncherError("Failed to retrieve nodelist from stat") return node_lists - def run(self, step: Step) -> t.Optional[str]: + def run(self, step: Step) -> str | None: """Run a job step through Slurm :param step: a job step instance @@ -230,7 +229,7 @@ def _get_slurm_step_id(step: Step, interval: int = 2) -> str: m2-119225.1|119225.1| """ time.sleep(interval) - step_id: t.Optional[str] = None + step_id: str | None = None trials = CONFIG.wlm_trials while trials > 0: output, _ = sacct( @@ -247,7 +246,7 @@ def _get_slurm_step_id(step: Step, interval: int = 2) -> str: raise LauncherError("Could not find id of launched job step") return step_id - def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: + def _get_managed_step_update(self, step_ids: list[str]) -> list[StepInfo]: """Get step updates for WLM managed jobs :param step_ids: list of job step ids @@ -262,7 +261,7 @@ def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: stat_tuples = [parse_sacct(sacct_out, step_id) for step_id in step_ids] # create SlurmStepInfo objects to return - updates: t.List[StepInfo] = [] + updates: list[StepInfo] = [] for stat_tuple, step_id in zip(stat_tuples, step_ids): _rc = int(stat_tuple[1]) if stat_tuple[1] else None info = SlurmStepInfo(stat_tuple[0], _rc) @@ -301,5 +300,5 @@ def __str__(self) -> str: return "Slurm" -def _create_step_id_str(step_ids: t.List[str]) -> str: +def _create_step_id_str(step_ids: list[str]) -> str: return ",".join(step_ids) diff --git a/smartsim/_core/launcher/slurm/slurmParser.py b/smartsim/_core/launcher/slurm/slurmParser.py index 29ce00317..ee1732b36 100644 --- a/smartsim/_core/launcher/slurm/slurmParser.py +++ b/smartsim/_core/launcher/slurm/slurmParser.py @@ -24,7 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t from shutil import which """ @@ -32,14 +31,14 @@ """ -def parse_salloc(output: str) -> t.Optional[str]: +def parse_salloc(output: str) -> str | None: for line in output.split("\n"): if line.startswith("salloc: Granted job allocation"): return line.split()[-1] return None -def parse_salloc_error(output: str) -> t.Optional[str]: +def parse_salloc_error(output: str) -> str | None: """Parse and return error output of a failed salloc command :param output: stderr output of salloc command @@ -81,14 +80,14 @@ def jobid_exact_match(parsed_id: str, job_id: str) -> bool: return parsed_id.split(".")[0] == job_id -def parse_sacct(output: str, job_id: str) -> t.Tuple[str, t.Optional[str]]: +def parse_sacct(output: str, job_id: str) -> tuple[str, str | None]: """Parse and return output of the sacct command :param output: output of the sacct command :param job_id: allocation id or job step id :return: status and returncode """ - result: t.Tuple[str, t.Optional[str]] = ("PENDING", None) + result: tuple[str, str | None] = ("PENDING", None) for line in output.split("\n"): parts = line.split("|") if len(parts) >= 3: @@ -100,7 +99,7 @@ def parse_sacct(output: str, job_id: str) -> t.Tuple[str, t.Optional[str]]: return result -def parse_sstat_nodes(output: str, job_id: str) -> t.List[str]: +def parse_sstat_nodes(output: str, job_id: str) -> list[str]: """Parse and return the sstat command This function parses and returns the nodes of @@ -121,7 +120,7 @@ def parse_sstat_nodes(output: str, job_id: str) -> t.List[str]: return list(set(nodes)) -def parse_step_id_from_sacct(output: str, step_name: str) -> t.Optional[str]: +def parse_step_id_from_sacct(output: str, step_name: str) -> str | None: """Parse and return the step id from a sacct command :param output: output of sacct --noheader -p diff --git a/smartsim/_core/launcher/step/alpsStep.py b/smartsim/_core/launcher/step/alpsStep.py index ff0ef69b6..d102f5333 100644 --- a/smartsim/_core/launcher/step/alpsStep.py +++ b/smartsim/_core/launcher/step/alpsStep.py @@ -26,7 +26,6 @@ import os import shutil -import typing as t from shlex import split as sh_split from ....error import AllocationError @@ -46,18 +45,18 @@ def __init__(self, name: str, cwd: str, run_settings: AprunSettings) -> None: :param run_settings: run settings for entity """ super().__init__(name, cwd, run_settings) - self.alloc: t.Optional[str] = None + self.alloc: str | None = None if not run_settings.in_batch: self._set_alloc() self.run_settings = run_settings - def _get_mpmd(self) -> t.List[RunSettings]: + def _get_mpmd(self) -> list[RunSettings]: """Temporary convenience function to return a typed list of attached RunSettings """ return self.run_settings.mpmd - def get_launch_cmd(self) -> t.List[str]: + def get_launch_cmd(self) -> list[str]: """Get the command to launch this step :return: launch command @@ -113,7 +112,7 @@ def _set_alloc(self) -> None: "No allocation specified or found and not running in batch" ) - def _build_exe(self) -> t.List[str]: + def _build_exe(self) -> list[str]: """Build the executable for this step :return: executable list @@ -125,7 +124,7 @@ def _build_exe(self) -> t.List[str]: args = self.run_settings._exe_args # pylint: disable=protected-access return exe + args - def _make_mpmd(self) -> t.List[str]: + def _make_mpmd(self) -> list[str]: """Build Aprun (MPMD) executable""" exe = self.run_settings.exe diff --git a/smartsim/_core/launcher/step/dragonStep.py b/smartsim/_core/launcher/step/dragonStep.py index a5c851c4e..60d9eefa5 100644 --- a/smartsim/_core/launcher/step/dragonStep.py +++ b/smartsim/_core/launcher/step/dragonStep.py @@ -63,7 +63,7 @@ def __init__(self, name: str, cwd: str, run_settings: DragonRunSettings) -> None def run_settings(self) -> DragonRunSettings: return t.cast(DragonRunSettings, self.step_settings) - def get_launch_cmd(self) -> t.List[str]: + def get_launch_cmd(self) -> list[str]: """Get stringified version of request needed to launch this step @@ -93,12 +93,12 @@ def get_launch_cmd(self) -> t.List[str]: return exe_cmd_and_args @staticmethod - def _get_exe_args_list(run_setting: DragonRunSettings) -> t.List[str]: + def _get_exe_args_list(run_setting: DragonRunSettings) -> list[str]: """Convenience function to encapsulate checking the runsettings.exe_args type to always return a list """ exe_args = run_setting.exe_args - args: t.List[str] = exe_args if isinstance(exe_args, list) else [exe_args] + args: list[str] = exe_args if isinstance(exe_args, list) else [exe_args] return args @@ -107,7 +107,7 @@ def __init__( self, name: str, cwd: str, - batch_settings: t.Union[SbatchSettings, QsubBatchSettings], + batch_settings: SbatchSettings | QsubBatchSettings, ) -> None: """Initialize a Slurm Sbatch step @@ -116,12 +116,12 @@ def __init__( :param batch_settings: batch settings for entity """ super().__init__(name, cwd, batch_settings) - self.steps: t.List[Step] = [] + self.steps: list[Step] = [] self.managed = True self.batch_settings = batch_settings self._request_file_name = "requests.json" - def get_launch_cmd(self) -> t.List[str]: + def get_launch_cmd(self) -> list[str]: """Get the launch command for the batch :return: launch command for the batch diff --git a/smartsim/_core/launcher/step/localStep.py b/smartsim/_core/launcher/step/localStep.py index cd527f1dd..9ad104473 100644 --- a/smartsim/_core/launcher/step/localStep.py +++ b/smartsim/_core/launcher/step/localStep.py @@ -26,7 +26,6 @@ import os import shutil -import typing as t from ....settings import Singularity from ....settings.base import RunSettings @@ -40,10 +39,10 @@ def __init__(self, name: str, cwd: str, run_settings: RunSettings): self._env = self._set_env() @property - def env(self) -> t.Dict[str, str]: + def env(self) -> dict[str, str]: return self._env - def get_launch_cmd(self) -> t.List[str]: + def get_launch_cmd(self) -> list[str]: cmd = [] # Add run command and args if user specified @@ -72,7 +71,7 @@ def get_launch_cmd(self) -> t.List[str]: cmd.extend(self.run_settings.exe_args) return cmd - def _set_env(self) -> t.Dict[str, str]: + def _set_env(self) -> dict[str, str]: env = os.environ.copy() if self.run_settings.env_vars: for k, v in self.run_settings.env_vars.items(): diff --git a/smartsim/_core/launcher/step/mpiStep.py b/smartsim/_core/launcher/step/mpiStep.py index 8972c9b5e..c272f59f4 100644 --- a/smartsim/_core/launcher/step/mpiStep.py +++ b/smartsim/_core/launcher/step/mpiStep.py @@ -26,7 +26,6 @@ import os import shutil -import typing as t from shlex import split as sh_split from ....error import AllocationError, SmartSimError @@ -49,14 +48,14 @@ def __init__(self, name: str, cwd: str, run_settings: RunSettings) -> None: super().__init__(name, cwd, run_settings) - self.alloc: t.Optional[str] = None + self.alloc: str | None = None if not run_settings.in_batch: self._set_alloc() self.run_settings = run_settings _supported_launchers = ["PBS", "SLURM", "LSB", "SGE"] - def get_launch_cmd(self) -> t.List[str]: + def get_launch_cmd(self) -> list[str]: """Get the command to launch this step :return: launch command @@ -115,16 +114,16 @@ def _set_alloc(self) -> None: "No allocation specified or found and not running in batch" ) - def _get_mpmd(self) -> t.List[RunSettings]: + def _get_mpmd(self) -> list[RunSettings]: """Temporary convenience function to return a typed list of attached RunSettings """ if hasattr(self.run_settings, "mpmd") and self.run_settings.mpmd: - rs_mpmd: t.List[RunSettings] = self.run_settings.mpmd + rs_mpmd: list[RunSettings] = self.run_settings.mpmd return rs_mpmd return [] - def _build_exe(self) -> t.List[str]: + def _build_exe(self) -> list[str]: """Build the executable for this step :return: executable list @@ -136,7 +135,7 @@ def _build_exe(self) -> t.List[str]: args = self.run_settings._exe_args # pylint: disable=protected-access return exe + args - def _make_mpmd(self) -> t.List[str]: + def _make_mpmd(self) -> list[str]: """Build mpiexec (MPMD) executable""" exe = self.run_settings.exe args = self.run_settings._exe_args # pylint: disable=protected-access diff --git a/smartsim/_core/launcher/step/pbsStep.py b/smartsim/_core/launcher/step/pbsStep.py index bc96659b4..124fb2660 100644 --- a/smartsim/_core/launcher/step/pbsStep.py +++ b/smartsim/_core/launcher/step/pbsStep.py @@ -24,7 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t from ....log import get_logger from ....settings import QsubBatchSettings @@ -42,11 +41,11 @@ def __init__(self, name: str, cwd: str, batch_settings: QsubBatchSettings) -> No :param batch_settings: batch settings for entity """ super().__init__(name, cwd, batch_settings) - self.step_cmds: t.List[t.List[str]] = [] + self.step_cmds: list[list[str]] = [] self.managed = True self.batch_settings = batch_settings - def get_launch_cmd(self) -> t.List[str]: + def get_launch_cmd(self) -> list[str]: """Get the launch command for the batch :return: launch command for the batch diff --git a/smartsim/_core/launcher/step/sgeStep.py b/smartsim/_core/launcher/step/sgeStep.py index 14225e07c..1dc889be9 100644 --- a/smartsim/_core/launcher/step/sgeStep.py +++ b/smartsim/_core/launcher/step/sgeStep.py @@ -24,7 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t from ....log import get_logger from ....settings import SgeQsubBatchSettings @@ -44,11 +43,11 @@ def __init__( :param batch_settings: batch settings for entity """ super().__init__(name, cwd, batch_settings) - self.step_cmds: t.List[t.List[str]] = [] + self.step_cmds: list[list[str]] = [] self.managed = True self.batch_settings = batch_settings - def get_launch_cmd(self) -> t.List[str]: + def get_launch_cmd(self) -> list[str]: """Get the launch command for the batch :return: launch command for the batch diff --git a/smartsim/_core/launcher/step/slurmStep.py b/smartsim/_core/launcher/step/slurmStep.py index 5b5db499e..a14e9b110 100644 --- a/smartsim/_core/launcher/step/slurmStep.py +++ b/smartsim/_core/launcher/step/slurmStep.py @@ -26,7 +26,6 @@ import os import shutil -import typing as t from shlex import split as sh_split from ....error import AllocationError @@ -46,11 +45,11 @@ def __init__(self, name: str, cwd: str, batch_settings: SbatchSettings) -> None: :param batch_settings: batch settings for entity """ super().__init__(name, cwd, batch_settings) - self.step_cmds: t.List[t.List[str]] = [] + self.step_cmds: list[list[str]] = [] self.managed = True self.batch_settings = batch_settings - def get_launch_cmd(self) -> t.List[str]: + def get_launch_cmd(self) -> list[str]: """Get the launch command for the batch :return: launch command for the batch @@ -106,13 +105,13 @@ def __init__(self, name: str, cwd: str, run_settings: SrunSettings) -> None: :param run_settings: run settings for entity """ super().__init__(name, cwd, run_settings) - self.alloc: t.Optional[str] = None + self.alloc: str | None = None self.managed = True self.run_settings = run_settings if not self.run_settings.in_batch: self._set_alloc() - def get_launch_cmd(self) -> t.List[str]: + def get_launch_cmd(self) -> list[str]: """Get the command to launch this step :return: launch command @@ -124,7 +123,7 @@ def get_launch_cmd(self) -> t.List[str]: output, error = self.get_output_files() srun_cmd = [srun, "--output", output, "--error", error, "--job-name", self.name] - compound_env: t.Set[str] = set() + compound_env: set[str] = set() if self.alloc: srun_cmd += ["--jobid", str(self.alloc)] @@ -177,22 +176,22 @@ def _set_alloc(self) -> None: "No allocation specified or found and not running in batch" ) - def _get_mpmd(self) -> t.List[RunSettings]: + def _get_mpmd(self) -> list[RunSettings]: """Temporary convenience function to return a typed list of attached RunSettings """ return self.run_settings.mpmd @staticmethod - def _get_exe_args_list(run_setting: RunSettings) -> t.List[str]: + def _get_exe_args_list(run_setting: RunSettings) -> list[str]: """Convenience function to encapsulate checking the runsettings.exe_args type to always return a list """ exe_args = run_setting.exe_args - args: t.List[str] = exe_args if isinstance(exe_args, list) else [exe_args] + args: list[str] = exe_args if isinstance(exe_args, list) else [exe_args] return args - def _build_exe(self) -> t.List[str]: + def _build_exe(self) -> list[str]: """Build the executable for this step :return: executable list @@ -204,7 +203,7 @@ def _build_exe(self) -> t.List[str]: args = self._get_exe_args_list(self.run_settings) return exe + args - def _make_mpmd(self) -> t.List[str]: + def _make_mpmd(self) -> list[str]: """Build Slurm multi-prog (MPMD) executable""" exe = self.run_settings.exe args = self._get_exe_args_list(self.run_settings) diff --git a/smartsim/_core/launcher/step/step.py b/smartsim/_core/launcher/step/step.py index 4af8054ce..b7bb43e7d 100644 --- a/smartsim/_core/launcher/step/step.py +++ b/smartsim/_core/launcher/step/step.py @@ -30,7 +30,6 @@ import os.path as osp import pathlib import time -import typing as t from os import makedirs from smartsim.error.errors import SmartSimError @@ -50,14 +49,14 @@ def __init__(self, name: str, cwd: str, step_settings: SettingsBase) -> None: self.cwd = cwd self.managed = False self.step_settings = copy.deepcopy(step_settings) - self.meta: t.Dict[str, str] = {} + self.meta: dict[str, str] = {} @property - def env(self) -> t.Optional[t.Dict[str, str]]: + def env(self) -> dict[str, str] | None: """Overridable, read only property for step to specify its environment""" return None - def get_launch_cmd(self) -> t.List[str]: + def get_launch_cmd(self) -> list[str]: raise NotImplementedError @staticmethod @@ -71,7 +70,7 @@ def _ensure_output_directory_exists(output_dir: str) -> None: if not osp.exists(output_dir): pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True) - def get_output_files(self) -> t.Tuple[str, str]: + def get_output_files(self) -> tuple[str, str]: """Return two paths to error and output files based on metadata directory""" try: output_dir = self.meta["metadata_dir"] @@ -82,9 +81,7 @@ def get_output_files(self) -> t.Tuple[str, str]: error = osp.join(output_dir, f"{self.entity_name}.err") return output, error - def get_step_file( - self, ending: str = ".sh", script_name: t.Optional[str] = None - ) -> str: + def get_step_file(self, ending: str = ".sh", script_name: str | None = None) -> str: """Get the name for a file/script created by the step class Used for Batch scripts, mpmd scripts, etc. diff --git a/smartsim/_core/launcher/stepInfo.py b/smartsim/_core/launcher/stepInfo.py index ad72f7131..79ba9e56c 100644 --- a/smartsim/_core/launcher/stepInfo.py +++ b/smartsim/_core/launcher/stepInfo.py @@ -24,7 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t import psutil @@ -36,9 +35,9 @@ def __init__( self, status: SmartSimStatus, launcher_status: str = "", - returncode: t.Optional[int] = None, - output: t.Optional[str] = None, - error: t.Optional[str] = None, + returncode: int | None = None, + output: str | None = None, + error: str | None = None, ) -> None: self.status = status self.launcher_status = launcher_status @@ -53,11 +52,11 @@ def __str__(self) -> str: return info_str @property - def mapping(self) -> t.Dict[str, SmartSimStatus]: + def mapping(self) -> dict[str, SmartSimStatus]: raise NotImplementedError def _get_smartsim_status( - self, status: str, returncode: t.Optional[int] = None + self, status: str, returncode: int | None = None ) -> SmartSimStatus: """ Map the status of the WLM step to a smartsim-specific status @@ -73,7 +72,7 @@ def _get_smartsim_status( class UnmanagedStepInfo(StepInfo): @property - def mapping(self) -> t.Dict[str, SmartSimStatus]: + def mapping(self) -> dict[str, SmartSimStatus]: # see https://github.com/giampaolo/psutil/blob/master/psutil/_pslinux.py # see https://github.com/giampaolo/psutil/blob/master/psutil/_common.py return { @@ -96,9 +95,9 @@ def mapping(self) -> t.Dict[str, SmartSimStatus]: def __init__( self, status: str = "", - returncode: t.Optional[int] = None, - output: t.Optional[str] = None, - error: t.Optional[str] = None, + returncode: int | None = None, + output: str | None = None, + error: str | None = None, ) -> None: smartsim_status = self._get_smartsim_status(status) super().__init__( @@ -138,9 +137,9 @@ class SlurmStepInfo(StepInfo): # cov-slurm def __init__( self, status: str = "", - returncode: t.Optional[int] = None, - output: t.Optional[str] = None, - error: t.Optional[str] = None, + returncode: int | None = None, + output: str | None = None, + error: str | None = None, ) -> None: smartsim_status = self._get_smartsim_status(status) super().__init__( @@ -150,7 +149,7 @@ def __init__( class PBSStepInfo(StepInfo): # cov-pbs @property - def mapping(self) -> t.Dict[str, SmartSimStatus]: + def mapping(self) -> dict[str, SmartSimStatus]: # pylint: disable-next=line-too-long # see http://nusc.nsu.ru/wiki/lib/exe/fetch.php/doc/pbs/PBSReferenceGuide19.2.1.pdf#M11.9.90788.PBSHeading1.81.Job.States return { @@ -176,9 +175,9 @@ def mapping(self) -> t.Dict[str, SmartSimStatus]: def __init__( self, status: str = "", - returncode: t.Optional[int] = None, - output: t.Optional[str] = None, - error: t.Optional[str] = None, + returncode: int | None = None, + output: str | None = None, + error: str | None = None, ) -> None: if status == "NOTFOUND": if returncode is not None: @@ -200,7 +199,7 @@ def __init__( class SGEStepInfo(StepInfo): # cov-pbs @property - def mapping(self) -> t.Dict[str, SmartSimStatus]: + def mapping(self) -> dict[str, SmartSimStatus]: # pylint: disable-next=line-too-long # see https://manpages.ubuntu.com/manpages/jammy/man5/sge_status.5.html return { @@ -250,9 +249,9 @@ def mapping(self) -> t.Dict[str, SmartSimStatus]: def __init__( self, status: str = "", - returncode: t.Optional[int] = None, - output: t.Optional[str] = None, - error: t.Optional[str] = None, + returncode: int | None = None, + output: str | None = None, + error: str | None = None, ) -> None: if status == "NOTFOUND": if returncode is not None: diff --git a/smartsim/_core/launcher/stepMapping.py b/smartsim/_core/launcher/stepMapping.py index 50c12f8bd..b52af18a7 100644 --- a/smartsim/_core/launcher/stepMapping.py +++ b/smartsim/_core/launcher/stepMapping.py @@ -24,7 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t from ...log import get_logger @@ -34,9 +33,9 @@ class StepMap: def __init__( self, - step_id: t.Optional[str] = None, - task_id: t.Optional[str] = None, - managed: t.Optional[bool] = None, + step_id: str | None = None, + task_id: str | None = None, + managed: bool | None = None, ) -> None: self.step_id = step_id self.task_id = task_id @@ -46,7 +45,7 @@ def __init__( class StepMapping: def __init__(self) -> None: # step_name : wlm_id, pid, wlm_managed? - self.mapping: t.Dict[str, StepMap] = {} + self.mapping: dict[str, StepMap] = {} def __getitem__(self, step_name: str) -> StepMap: return self.mapping[step_name] @@ -57,8 +56,8 @@ def __setitem__(self, step_name: str, step_map: StepMap) -> None: def add( self, step_name: str, - step_id: t.Optional[str] = None, - task_id: t.Optional[str] = None, + step_id: str | None = None, + task_id: str | None = None, managed: bool = True, ) -> None: try: @@ -68,7 +67,7 @@ def add( msg = f"Could not add step {step_name} to mapping: {e}" logger.exception(msg) - def get_task_id(self, step_id: str) -> t.Optional[str]: + def get_task_id(self, step_id: str) -> str | None: """Get the task id from the step id""" task_id = None for stepmap in self.mapping.values(): @@ -78,9 +77,9 @@ def get_task_id(self, step_id: str) -> t.Optional[str]: return task_id def get_ids( - self, step_names: t.List[str], managed: bool = True - ) -> t.Tuple[t.List[str], t.List[t.Union[str, None]]]: - ids: t.List[t.Union[str, None]] = [] + self, step_names: list[str], managed: bool = True + ) -> tuple[list[str], list[str | None]]: + ids: list[str | None] = [] names = [] for name in step_names: if name in self.mapping: diff --git a/smartsim/_core/launcher/taskManager.py b/smartsim/_core/launcher/taskManager.py index a2e9393ab..59093166c 100644 --- a/smartsim/_core/launcher/taskManager.py +++ b/smartsim/_core/launcher/taskManager.py @@ -27,7 +27,6 @@ from __future__ import annotations import time -import typing as t from subprocess import PIPE from threading import RLock @@ -62,10 +61,8 @@ class TaskManager: def __init__(self) -> None: """Initialize a task manager thread.""" self.actively_monitoring = False - self.task_history: t.Dict[ - str, t.Tuple[t.Optional[int], t.Optional[str], t.Optional[str]] - ] = {} - self.tasks: t.List[Task] = [] + self.task_history: dict[str, tuple[int | None, str | None, str | None]] = {} + self.tasks: list[Task] = [] self._lock = RLock() def start(self) -> None: @@ -102,9 +99,9 @@ def run(self) -> None: def start_task( self, - cmd_list: t.List[str], + cmd_list: list[str], cwd: str, - env: t.Optional[t.Dict[str, str]] = None, + env: dict[str, str] | None = None, out: int = PIPE, err: int = PIPE, ) -> str: @@ -131,11 +128,11 @@ def start_task( @staticmethod def start_and_wait( - cmd_list: t.List[str], + cmd_list: list[str], cwd: str, - env: t.Optional[t.Dict[str, str]] = None, - timeout: t.Optional[int] = None, - ) -> t.Tuple[int, str, str]: + env: dict[str, str] | None = None, + timeout: int | None = None, + ) -> tuple[int, str, str]: """Start a task not managed by the TaskManager This method is used by launchers to launch managed tasks @@ -193,7 +190,7 @@ def remove_task(self, task_id: str) -> None: def get_task_update( self, task_id: str - ) -> t.Tuple[str, t.Optional[int], t.Optional[str], t.Optional[str]]: + ) -> tuple[str, int | None, str | None, str | None]: """Get the update of a task :param task_id: task id @@ -227,9 +224,9 @@ def get_task_update( def add_task_history( self, task_id: str, - returncode: t.Optional[int] = None, - out: t.Optional[str] = None, - err: t.Optional[str] = None, + returncode: int | None = None, + out: str | None = None, + err: str | None = None, ) -> None: """Add a task to the task history @@ -263,7 +260,7 @@ def __init__(self, process: psutil.Process) -> None: self.process = process self.pid = str(self.process.pid) - def check_status(self) -> t.Optional[int]: + def check_status(self) -> int | None: """Ping the job and return the returncode if finished :return: returncode if finished otherwise None @@ -277,7 +274,7 @@ def check_status(self) -> t.Optional[int]: # have to rely on .kill() to stop. return self.returncode - def get_io(self) -> t.Tuple[t.Optional[str], t.Optional[str]]: + def get_io(self) -> tuple[str | None, str | None]: """Get the IO from the subprocess :return: output and error from the Popen @@ -341,7 +338,7 @@ def wait(self) -> None: self.process.wait() @property - def returncode(self) -> t.Optional[int]: + def returncode(self) -> int | None: if self.owned and isinstance(self.process, psutil.Popen): if self.process.returncode is not None: return int(self.process.returncode) diff --git a/smartsim/_core/launcher/util/launcherUtil.py b/smartsim/_core/launcher/util/launcherUtil.py index 0307bc51b..a58eaf2e4 100644 --- a/smartsim/_core/launcher/util/launcherUtil.py +++ b/smartsim/_core/launcher/util/launcherUtil.py @@ -24,8 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t - class ComputeNode: # cov-slurm """The ComputeNode class holds resource information @@ -33,15 +31,15 @@ class ComputeNode: # cov-slurm """ def __init__( - self, node_name: t.Optional[str] = None, node_ppn: t.Optional[int] = None + self, node_name: str | None = None, node_ppn: int | None = None ) -> None: """Initialize a ComputeNode :param node_name: the name of the node :param node_ppn: the number of ppn """ - self.name: t.Optional[str] = node_name - self.ppn: t.Optional[int] = node_ppn + self.name: str | None = node_name + self.ppn: int | None = node_ppn def _is_valid_node(self) -> bool: """Check if the node is complete @@ -66,9 +64,9 @@ class Partition: # cov-slurm def __init__(self) -> None: """Initialize a system partition""" - self.name: t.Optional[str] = None - self.min_ppn: t.Optional[int] = None - self.nodes: t.Set[ComputeNode] = set() + self.name: str | None = None + self.min_ppn: int | None = None + self.nodes: set[ComputeNode] = set() def _is_valid_partition(self) -> bool: """Check if the partition is valid diff --git a/smartsim/_core/schemas/dragonRequests.py b/smartsim/_core/schemas/dragonRequests.py index 28ff30b55..f3990f4c0 100644 --- a/smartsim/_core/schemas/dragonRequests.py +++ b/smartsim/_core/schemas/dragonRequests.py @@ -43,14 +43,14 @@ class DragonRequest(BaseModel): ... class DragonRunPolicy(BaseModel): """Policy specifying hardware constraints when running a Dragon job""" - cpu_affinity: t.List[NonNegativeInt] = Field(default_factory=list) + cpu_affinity: list[NonNegativeInt] = Field(default_factory=list) """List of CPU indices to which the job should be pinned""" - gpu_affinity: t.List[NonNegativeInt] = Field(default_factory=list) + gpu_affinity: list[NonNegativeInt] = Field(default_factory=list) """List of GPU indices to which the job should be pinned""" @staticmethod def from_run_args( - run_args: t.Dict[str, t.Union[int, str, float, None]] + run_args: dict[str, int | str | float | None] ) -> "DragonRunPolicy": """Create a DragonRunPolicy with hardware constraints passed from a dictionary of run arguments @@ -79,23 +79,23 @@ def from_run_args( class DragonRunRequestView(DragonRequest): exe: t.Annotated[str, Field(min_length=1)] - exe_args: t.List[t.Annotated[str, Field(min_length=1)]] = [] + exe_args: list[t.Annotated[str, Field(min_length=1)]] = [] path: t.Annotated[str, Field(min_length=1)] nodes: PositiveInt = 1 tasks: PositiveInt = 1 tasks_per_node: PositiveInt = 1 - hostlist: t.Optional[t.Annotated[str, Field(min_length=1)]] = None - output_file: t.Optional[t.Annotated[str, Field(min_length=1)]] = None - error_file: t.Optional[t.Annotated[str, Field(min_length=1)]] = None - env: t.Dict[str, t.Optional[str]] = {} - name: t.Optional[t.Annotated[str, Field(min_length=1)]] = None + hostlist: t.Annotated[str, Field(min_length=1)] | None = None + output_file: t.Annotated[str, Field(min_length=1)] | None = None + error_file: t.Annotated[str, Field(min_length=1)] | None = None + env: dict[str, str | None] = {} + name: t.Annotated[str, Field(min_length=1)] | None = None pmi_enabled: bool = True @request_registry.register("run") class DragonRunRequest(DragonRunRequestView): - current_env: t.Dict[str, t.Optional[str]] = {} - policy: t.Optional[DragonRunPolicy] = None + current_env: dict[str, str | None] = {} + policy: DragonRunPolicy | None = None def __str__(self) -> str: return str(DragonRunRequestView.parse_obj(self.dict(exclude={"current_env"}))) @@ -103,7 +103,7 @@ def __str__(self) -> str: @request_registry.register("update_status") class DragonUpdateStatusRequest(DragonRequest): - step_ids: t.List[t.Annotated[str, Field(min_length=1)]] + step_ids: list[t.Annotated[str, Field(min_length=1)]] @request_registry.register("stop") diff --git a/smartsim/_core/schemas/dragonResponses.py b/smartsim/_core/schemas/dragonResponses.py index 318a4eabf..14ffd797c 100644 --- a/smartsim/_core/schemas/dragonResponses.py +++ b/smartsim/_core/schemas/dragonResponses.py @@ -25,6 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import typing as t +from collections.abc import Mapping from pydantic import BaseModel, Field @@ -38,7 +39,7 @@ class DragonResponse(BaseModel): - error_message: t.Optional[str] = None + error_message: str | None = None @response_registry.register("run") @@ -49,9 +50,9 @@ class DragonRunResponse(DragonResponse): @response_registry.register("status_update") class DragonUpdateStatusResponse(DragonResponse): # status is a dict: {step_id: (is_alive, returncode)} - statuses: t.Mapping[ + statuses: Mapping[ t.Annotated[str, Field(min_length=1)], - t.Tuple[SmartSimStatus, t.Optional[t.List[int]]], + tuple[SmartSimStatus, list[int] | None], ] = {} diff --git a/smartsim/_core/schemas/utils.py b/smartsim/_core/schemas/utils.py index 508ef34ed..47daf1e05 100644 --- a/smartsim/_core/schemas/utils.py +++ b/smartsim/_core/schemas/utils.py @@ -26,6 +26,7 @@ import dataclasses import typing as t +from collections.abc import Callable, Mapping import pydantic import pydantic.dataclasses @@ -54,7 +55,7 @@ def __str__(self) -> str: def from_str( cls, str_: str, - payload_type: t.Type[_SchemaT], + payload_type: type[_SchemaT], delimiter: str = _DEFAULT_MSG_DELIM, ) -> "_Message[_SchemaT]": header, payload = str_.split(delimiter, 1) @@ -63,11 +64,11 @@ def from_str( class SchemaRegistry(t.Generic[_SchemaT]): def __init__( - self, init_map: t.Optional[t.Mapping[str, t.Type[_SchemaT]]] = None + self, init_map: t.Optional[Mapping[str, type[_SchemaT]]] = None ) -> None: self._map = dict(init_map) if init_map else {} - def register(self, key: str) -> t.Callable[[t.Type[_SchemaT]], t.Type[_SchemaT]]: + def register(self, key: str) -> Callable[[type[_SchemaT]], type[_SchemaT]]: if _DEFAULT_MSG_DELIM in key: _msg = f"Registry key cannot contain delimiter `{_DEFAULT_MSG_DELIM}`" raise ValueError(_msg) @@ -76,7 +77,7 @@ def register(self, key: str) -> t.Callable[[t.Type[_SchemaT]], t.Type[_SchemaT]] if key in self._map: raise KeyError(f"Key `{key}` has already been registered for this parser") - def _register(cls: t.Type[_SchemaT]) -> t.Type[_SchemaT]: + def _register(cls: type[_SchemaT]) -> type[_SchemaT]: self._map[key] = cls return cls diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py index b4caf6d71..9e0be29b7 100644 --- a/smartsim/_core/utils/helpers.py +++ b/smartsim/_core/utils/helpers.py @@ -33,6 +33,7 @@ import subprocess import typing as t import uuid +from collections.abc import Callable, Iterable, Sequence from datetime import datetime from functools import lru_cache from pathlib import Path @@ -43,10 +44,10 @@ _TRedisAIBackendStr = t.Literal["tensorflow", "torch", "onnxruntime"] -_TSignalHandlerFn = t.Callable[[int, t.Optional["FrameType"]], object] +_TSignalHandlerFn = Callable[[int, "FrameType | None"], object] -def unpack_db_identifier(db_id: str, token: str) -> t.Tuple[str, str]: +def unpack_db_identifier(db_id: str, token: str) -> tuple[str, str]: """Unpack the unformatted database identifier and format for env variable suffix using the token :param db_id: the unformatted database identifier eg. identifier_1 @@ -85,7 +86,7 @@ def check_dev_log_level() -> bool: return lvl == "developer" -def fmt_dict(value: t.Dict[str, t.Any]) -> str: +def fmt_dict(value: dict[str, t.Any]) -> str: fmt_str = "" for k, v in value.items(): fmt_str += "\t" + str(k) + " = " + str(v) @@ -129,7 +130,7 @@ def expand_exe_path(exe: str) -> str: return os.path.abspath(in_path) -def is_valid_cmd(command: t.Union[str, None]) -> bool: +def is_valid_cmd(command: str | None) -> bool: try: if command: expand_exe_path(command) @@ -172,7 +173,7 @@ def colorize( return f"\x1b[{';'.join(attr)}m{string}\x1b[0m" -def delete_elements(dictionary: t.Dict[str, t.Any], key_list: t.List[str]) -> None: +def delete_elements(dictionary: dict[str, t.Any], key_list: list[str]) -> None: """Delete elements from a dictionary. :param dictionary: the dictionary from which the elements must be deleted. :param key_list: the list of keys to delete from the dictionary. @@ -224,7 +225,7 @@ def _installed(base_path: Path, backend: str) -> bool: return backend_so.is_file() -def redis_install_base(backends_path: t.Optional[str] = None) -> Path: +def redis_install_base(backends_path: str | None = None) -> Path: # pylint: disable-next=import-outside-toplevel,cyclic-import from ..._core.config import CONFIG @@ -235,8 +236,8 @@ def redis_install_base(backends_path: t.Optional[str] = None) -> Path: def installed_redisai_backends( - backends_path: t.Optional[str] = None, -) -> t.Set[_TRedisAIBackendStr]: + backends_path: str | None = None, +) -> set[_TRedisAIBackendStr]: """Check which ML backends are available for the RedisAI module. The optional argument ``backends_path`` is needed if the backends @@ -251,7 +252,7 @@ def installed_redisai_backends( """ # import here to avoid circular import base_path = redis_install_base(backends_path) - backends: t.Set[_TRedisAIBackendStr] = { + backends: set[_TRedisAIBackendStr] = { "tensorflow", "torch", "onnxruntime", @@ -281,7 +282,7 @@ def check_for_utility(util_name: str) -> str: return utility -def execute_platform_cmd(cmd: str) -> t.Tuple[str, int]: +def execute_platform_cmd(cmd: str) -> tuple[str, int]: """Execute the platform check command as a subprocess :param cmd: the command to execute @@ -297,9 +298,9 @@ def execute_platform_cmd(cmd: str) -> t.Tuple[str, int]: class CrayExPlatformResult: locate_msg = "Unable to locate `{0}`." - def __init__(self, ldconfig: t.Optional[str], fi_info: t.Optional[str]) -> None: - self.ldconfig: t.Optional[str] = ldconfig - self.fi_info: t.Optional[str] = fi_info + def __init__(self, ldconfig: str | None, fi_info: str | None) -> None: + self.ldconfig: str | None = ldconfig + self.fi_info: str | None = fi_info self.has_pmi: bool = False self.has_pmi2: bool = False self.has_cxi: bool = False @@ -325,7 +326,7 @@ def is_cray(self) -> bool: ) @property - def failures(self) -> t.List[str]: + def failures(self) -> list[str]: """Return a list of messages describing all failed validations""" failure_messages = [] @@ -397,7 +398,7 @@ class SignalInterceptionStack(collections.abc.Collection[_TSignalHandlerFn]): def __init__( self, signalnum: int, - callbacks: t.Optional[t.Iterable[_TSignalHandlerFn]] = None, + callbacks: Iterable[_TSignalHandlerFn] | None = None, ) -> None: """Set up a ``SignalInterceptionStack`` for particular signal number. @@ -414,7 +415,7 @@ def __init__( self._callbacks = list(callbacks) if callbacks else [] self._original = signal.signal(signalnum, self) - def __call__(self, signalnum: int, frame: t.Optional["FrameType"]) -> None: + def __call__(self, signalnum: int, frame: "FrameType | None") -> None: """Handle the signal on which the interception stack was registered. End by calling the originally registered signal hander (if present). diff --git a/smartsim/_core/utils/network.py b/smartsim/_core/utils/network.py index 7c2b6f5e1..1c08c0e00 100644 --- a/smartsim/_core/utils/network.py +++ b/smartsim/_core/utils/network.py @@ -35,8 +35,8 @@ class IFConfig(t.NamedTuple): - interface: t.Optional[str] - address: t.Optional[str] + interface: str | None + address: str | None def get_ip_from_host(host: str) -> str: diff --git a/smartsim/_core/utils/redis.py b/smartsim/_core/utils/redis.py index ab7ecdea0..9b290eac2 100644 --- a/smartsim/_core/utils/redis.py +++ b/smartsim/_core/utils/redis.py @@ -46,7 +46,7 @@ logger = get_logger(__name__) -def create_cluster(hosts: t.List[str], ports: t.List[int]) -> None: # cov-wlm +def create_cluster(hosts: list[str], ports: list[int]) -> None: # cov-wlm """Connect launched cluster instances. Should only be used in the case where cluster initialization @@ -78,7 +78,7 @@ def create_cluster(hosts: t.List[str], ports: t.List[int]) -> None: # cov-wlm def check_cluster_status( - hosts: t.List[str], ports: t.List[int], trials: int = 10 + hosts: list[str], ports: list[int], trials: int = 10 ) -> None: # cov-wlm """Check that a Redis/KeyDB cluster is up and running @@ -117,7 +117,7 @@ def check_cluster_status( raise SSInternalError("Cluster setup could not be verified") -def db_is_active(hosts: t.List[str], ports: t.List[int], num_shards: int) -> bool: +def db_is_active(hosts: list[str], ports: list[int], num_shards: int) -> bool: """Check if a DB is running if the DB is clustered, check cluster status, otherwise @@ -212,7 +212,7 @@ def set_script(db_script: DBScript, client: Client) -> None: raise error -def shutdown_db_node(host_ip: str, port: int) -> t.Tuple[int, str, str]: # cov-wlm +def shutdown_db_node(host_ip: str, port: int) -> tuple[int, str, str]: # cov-wlm """Send shutdown signal to DB node. Should only be used in the case where cluster deallocation diff --git a/smartsim/_core/utils/security.py b/smartsim/_core/utils/security.py index c3f460074..a65466dea 100644 --- a/smartsim/_core/utils/security.py +++ b/smartsim/_core/utils/security.py @@ -28,7 +28,6 @@ import dataclasses import pathlib import stat -import typing as t from enum import IntEnum import zmq @@ -216,7 +215,7 @@ def _load_keypair(cls, locator: _KeyLocator, in_context: bool) -> KeyPair: key_path = locator.private if in_context else locator.public pub_key: bytes = b"" - priv_key: t.Optional[bytes] = b"" + priv_key: bytes | None = b"" if key_path.exists(): logger.debug(f"Existing key files located at {key_path}") @@ -227,7 +226,7 @@ def _load_keypair(cls, locator: _KeyLocator, in_context: bool) -> KeyPair: # avoid a `None` value in the private key when it isn't loaded return KeyPair(pub_key, priv_key or b"") - def _load_keys(self) -> t.Tuple[KeyPair, KeyPair]: + def _load_keys(self) -> tuple[KeyPair, KeyPair]: """Use ZMQ auth to load public/private key pairs for the server and client components from the standard key paths for the associated experiment @@ -270,7 +269,7 @@ def _create_keys(self) -> None: locator.private.chmod(_KeyPermissions.PRIVATE_KEY) locator.public.chmod(_KeyPermissions.PUBLIC_KEY) - def get_keys(self, create: bool = True) -> t.Tuple[KeyPair, KeyPair]: + def get_keys(self, create: bool = True) -> tuple[KeyPair, KeyPair]: """Use ZMQ auth to generate a public/private key pair for the server and client components. diff --git a/smartsim/_core/utils/shell.py b/smartsim/_core/utils/shell.py index 32ff0b86f..b1b3f3572 100644 --- a/smartsim/_core/utils/shell.py +++ b/smartsim/_core/utils/shell.py @@ -25,7 +25,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import time -import typing as t from subprocess import PIPE, TimeoutExpired import psutil @@ -39,13 +38,13 @@ def execute_cmd( - cmd_list: t.List[str], + cmd_list: list[str], shell: bool = False, - cwd: t.Optional[str] = None, - env: t.Optional[t.Dict[str, str]] = None, + cwd: str | None = None, + env: dict[str, str] | None = None, proc_input: str = "", - timeout: t.Optional[int] = None, -) -> t.Tuple[int, str, str]: + timeout: int | None = None, +) -> tuple[int, str, str]: """Execute a command locally :param cmd_list: list of command with arguments @@ -86,9 +85,9 @@ def execute_cmd( def execute_async_cmd( - cmd_list: t.List[str], + cmd_list: list[str], cwd: str, - env: t.Optional[t.Dict[str, str]] = None, + env: dict[str, str] | None = None, out: int = PIPE, err: int = PIPE, ) -> psutil.Popen: diff --git a/smartsim/database/orchestrator.py b/smartsim/database/orchestrator.py index 728d12d04..25ec48f4e 100644 --- a/smartsim/database/orchestrator.py +++ b/smartsim/database/orchestrator.py @@ -68,7 +68,7 @@ logger = get_logger(__name__) -by_launcher: t.Dict[str, t.List[str]] = { +by_launcher: dict[str, list[str]] = { "dragon": [""], "slurm": ["srun", "mpirun", "mpiexec"], "pbs": ["aprun", "mpirun", "mpiexec"], @@ -93,7 +93,7 @@ def _detect_command(launcher: str) -> str: raise SmartSimError(msg) -def _autodetect(launcher: str, run_command: str) -> t.Tuple[str, str]: +def _autodetect(launcher: str, run_command: str) -> tuple[str, str]: """Automatically detect the launcher and run command to use""" if launcher == "auto": launcher = detect_launcher() @@ -163,22 +163,22 @@ class Orchestrator(EntityList[DBNode]): def __init__( self, - path: t.Optional[str] = getcwd(), + path: str | None = getcwd(), port: int = 6379, - interface: t.Union[str, t.List[str]] = "lo", + interface: str | list[str] = "lo", launcher: str = "local", run_command: str = "auto", db_nodes: int = 1, batch: bool = False, - hosts: t.Optional[t.Union[t.List[str], str]] = None, - account: t.Optional[str] = None, - time: t.Optional[str] = None, - alloc: t.Optional[str] = None, + hosts: list[str] | str | None = None, + account: str | None = None, + time: str | None = None, + alloc: str | None = None, single_cmd: bool = False, *, - threads_per_queue: t.Optional[int] = None, - inter_op_threads: t.Optional[int] = None, - intra_op_threads: t.Optional[int] = None, + threads_per_queue: int | None = None, + inter_op_threads: int | None = None, + intra_op_threads: int | None = None, db_identifier: str = "orchestrator", **kwargs: t.Any, ) -> None: @@ -213,9 +213,9 @@ def __init__( single_cmd = _get_single_command( self.run_command, self.launcher, batch, single_cmd ) - self.ports: t.List[int] = [] - self._hosts: t.List[str] = [] - self._user_hostlist: t.List[str] = [] + self.ports: list[int] = [] + self._hosts: list[str] = [] + self._user_hostlist: list[str] = [] if isinstance(interface, str): interface = [interface] self._interfaces = interface @@ -224,8 +224,8 @@ def __init__( self.inter_threads = inter_op_threads self.intra_threads = intra_op_threads - gpus_per_shard: t.Optional[int] = None - cpus_per_shard: t.Optional[int] = None + gpus_per_shard: int | None = None + cpus_per_shard: int | None = None super().__init__( name=db_identifier, @@ -284,8 +284,8 @@ def __init__( "Orchestrator with mpirun", ) ) - self._reserved_run_args: t.Dict[t.Type[RunSettings], t.List[str]] = {} - self._reserved_batch_args: t.Dict[t.Type[BatchSettings], t.List[str]] = {} + self._reserved_run_args: dict[type[RunSettings], list[str]] = {} + self._reserved_batch_args: dict[type[BatchSettings], list[str]] = {} self._fill_reserved() def _mpi_has_sge_support(self) -> bool: @@ -334,7 +334,7 @@ def db_nodes(self) -> int: return self.num_shards @property - def hosts(self) -> t.List[str]: + def hosts(self) -> list[str]: """Return the hostnames of Orchestrator instance hosts Note that this will only be populated after the orchestrator @@ -360,7 +360,7 @@ def remove_stale_files(self) -> None: for db in self.entities: db.remove_stale_dbnode_files() - def get_address(self) -> t.List[str]: + def get_address(self) -> list[str]: """Return database addresses :return: addresses @@ -373,7 +373,7 @@ def get_address(self) -> t.List[str]: raise SmartSimError("Database is not active") return self._get_address() - def _get_address(self) -> t.List[str]: + def _get_address(self) -> list[str]: return [ f"{host}:{port}" for host, port in itertools.product(self._hosts, self.ports) @@ -391,7 +391,7 @@ def is_active(self) -> bool: return db_is_active(hosts, self.ports, self.num_shards) @property - def _rai_module(self) -> t.Tuple[str, ...]: + def _rai_module(self) -> tuple[str, ...]: """Get the RedisAI module from third-party installations :return: Tuple of args to pass to the orchestrator exe @@ -460,7 +460,7 @@ def set_walltime(self, walltime: str) -> None: if hasattr(self, "batch_settings") and self.batch_settings: self.batch_settings.set_walltime(walltime) - def set_hosts(self, host_list: t.Union[t.List[str], str]) -> None: + def set_hosts(self, host_list: list[str] | str) -> None: """Specify the hosts for the ``Orchestrator`` to launch on :param host_list: list of host (compute node names) @@ -496,7 +496,7 @@ def set_hosts(self, host_list: t.Union[t.List[str], str]) -> None: for i, mpmd_runsettings in enumerate(db.run_settings.mpmd, 1): mpmd_runsettings.set_hostlist(host_list[i]) - def set_batch_arg(self, arg: str, value: t.Optional[str] = None) -> None: + def set_batch_arg(self, arg: str, value: str | None = None) -> None: """Set a batch argument the orchestrator should launch with Some commonly used arguments such as --job-name are used @@ -517,7 +517,7 @@ def set_batch_arg(self, arg: str, value: t.Optional[str] = None) -> None: else: self.batch_settings.batch_args[arg] = value - def set_run_arg(self, arg: str, value: t.Optional[str] = None) -> None: + def set_run_arg(self, arg: str, value: str | None = None) -> None: """Set a run argument the orchestrator should launch each node with (it will be passed to `jrun`) @@ -654,9 +654,9 @@ def _build_batch_settings( account: str, time: str, *, - launcher: t.Optional[str] = None, + launcher: str | None = None, **kwargs: t.Any, - ) -> t.Optional[BatchSettings]: + ) -> BatchSettings | None: batch_settings = None if launcher is None: @@ -674,9 +674,9 @@ def _build_batch_settings( def _build_run_settings( self, exe: str, - exe_args: t.List[t.List[str]], + exe_args: list[list[str]], *, - run_args: t.Optional[t.Dict[str, t.Any]] = None, + run_args: dict[str, t.Any] | None = None, db_nodes: int = 1, single_cmd: bool = True, **kwargs: t.Any, @@ -769,7 +769,7 @@ def _initialize_entities_mpmd( ) -> None: cluster = db_nodes >= 3 mpmd_node_name = self.name + "_0" - exe_args_mpmd: t.List[t.List[str]] = [] + exe_args_mpmd: list[list[str]] = [] for db_id in range(db_nodes): db_shard_name = "_".join((self.name, str(db_id))) @@ -780,7 +780,7 @@ def _initialize_entities_mpmd( ) exe_args = " ".join(start_script_args) exe_args_mpmd.append(sh_split(exe_args)) - run_settings: t.Optional[RunSettings] = None + run_settings: RunSettings | None = None run_settings = self._build_run_settings( sys.executable, exe_args_mpmd, db_nodes=db_nodes, port=port, **kwargs @@ -799,9 +799,7 @@ def _initialize_entities_mpmd( self.entities.append(node) self.ports = [port] - def _get_start_script_args( - self, name: str, port: int, cluster: bool - ) -> t.List[str]: + def _get_start_script_args(self, name: str, port: int, cluster: bool) -> list[str]: cmd = [ "-m", "smartsim._core.entrypoints.redis", # entrypoint @@ -818,7 +816,7 @@ def _get_start_script_args( return cmd - def _get_db_hosts(self) -> t.List[str]: + def _get_db_hosts(self) -> list[str]: hosts = [] for db in self.entities: if not db.is_mpmd: diff --git a/smartsim/entity/dbnode.py b/smartsim/entity/dbnode.py index 98f7baed6..9dd32d764 100644 --- a/smartsim/entity/dbnode.py +++ b/smartsim/entity/dbnode.py @@ -31,6 +31,7 @@ import os.path as osp import time import typing as t +from collections.abc import Iterable from dataclasses import dataclass from .._core.config import CONFIG @@ -56,14 +57,14 @@ def __init__( name: str, path: str, run_settings: RunSettings, - ports: t.List[int], - output_files: t.List[str], + ports: list[int], + output_files: list[str], db_identifier: str = "", ) -> None: """Initialize a database node within an orchestrator.""" super().__init__(name, path, run_settings) self.ports = ports - self._hosts: t.Optional[t.List[str]] = None + self._hosts: list[str] | None = None if not output_files: raise ValueError("output_files cannot be empty") @@ -93,7 +94,7 @@ def host(self) -> str: return host @property - def hosts(self) -> t.List[str]: + def hosts(self) -> list[str]: if not self._hosts: self._hosts = self._parse_db_hosts() return self._hosts @@ -109,7 +110,7 @@ def is_mpmd(self) -> bool: return bool(self.run_settings.mpmd) - def set_hosts(self, hosts: t.List[str]) -> None: + def set_hosts(self, hosts: list[str]) -> None: self._hosts = [str(host) for host in hosts] def remove_stale_dbnode_files(self) -> None: @@ -140,7 +141,7 @@ def remove_stale_dbnode_files(self) -> None: if osp.exists(file_name): os.remove(file_name) - def _get_cluster_conf_filenames(self, port: int) -> t.List[str]: + def _get_cluster_conf_filenames(self, port: int) -> list[str]: """Returns the .conf file name for the given port number This function should bu used if and only if ``_mpmd==True`` @@ -157,8 +158,8 @@ def _get_cluster_conf_filenames(self, port: int) -> t.List[str]: @staticmethod def _parse_launched_shard_info_from_iterable( - stream: t.Iterable[str], num_shards: t.Optional[int] = None - ) -> "t.List[LaunchedShardData]": + stream: Iterable[str], num_shards: int | None = None + ) -> "list[LaunchedShardData]": lines = (line.strip() for line in stream) lines = (line for line in lines if line) tokenized = (line.split(maxsplit=1) for line in lines) @@ -167,7 +168,7 @@ def _parse_launched_shard_info_from_iterable( kwjson for first, kwjson in tokenized if "SMARTSIM_ORC_SHARD_INFO" in first ) shard_data_kwargs = (json.loads(kwjson) for kwjson in shard_data_jsons) - shard_data: "t.Iterable[LaunchedShardData]" = ( + shard_data: "Iterable[LaunchedShardData]" = ( LaunchedShardData(**kwargs) for kwargs in shard_data_kwargs ) if num_shards: @@ -176,18 +177,18 @@ def _parse_launched_shard_info_from_iterable( @classmethod def _parse_launched_shard_info_from_files( - cls, file_paths: t.List[str], num_shards: t.Optional[int] = None - ) -> "t.List[LaunchedShardData]": + cls, file_paths: list[str], num_shards: int | None = None + ) -> "list[LaunchedShardData]": with fileinput.FileInput(file_paths) as ifstream: return cls._parse_launched_shard_info_from_iterable(ifstream, num_shards) - def get_launched_shard_info(self) -> "t.List[LaunchedShardData]": + def get_launched_shard_info(self) -> "list[LaunchedShardData]": """Parse the launched database shard info from the output files :raises SSDBFilesNotParseable: if all shard info could not be found :return: The found launched shard info """ - ips: "t.List[LaunchedShardData]" = [] + ips: "list[LaunchedShardData]" = [] trials = CONFIG.database_file_parse_trials interval = CONFIG.database_file_parse_interval output_files = [osp.join(self.path, file) for file in self._output_files] @@ -214,7 +215,7 @@ def get_launched_shard_info(self) -> "t.List[LaunchedShardData]": raise SSDBFilesNotParseable(msg) return ips - def _parse_db_hosts(self) -> t.List[str]: + def _parse_db_hosts(self) -> list[str]: """Parse the database hosts/IPs from the output files The IP address is preferred, but if hostname is only present @@ -236,8 +237,8 @@ class LaunchedShardData: cluster: bool @property - def cluster_conf_file(self) -> t.Optional[str]: + def cluster_conf_file(self) -> str | None: return f"nodes-{self.name}-{self.port}.conf" if self.cluster else None - def to_dict(self) -> t.Dict[str, t.Any]: + def to_dict(self) -> dict[str, t.Any]: return dict(self.__dict__) diff --git a/smartsim/entity/dbobject.py b/smartsim/entity/dbobject.py index 3c0e216b4..e0239c7df 100644 --- a/smartsim/entity/dbobject.py +++ b/smartsim/entity/dbobject.py @@ -45,17 +45,15 @@ class DBObject(t.Generic[_DBObjectFuncT]): def __init__( self, name: str, - func: t.Optional[_DBObjectFuncT], - file_path: t.Optional[str], + func: _DBObjectFuncT | None, + file_path: str | None, device: str, devices_per_node: int, first_device: int, ) -> None: self.name = name - self.func: t.Optional[_DBObjectFuncT] = func - self.file: t.Optional[Path] = ( - None # Need to have this explicitly to check on it - ) + self.func: _DBObjectFuncT | None = func + self.file: Path | None = None # Need to have this explicitly to check on it if file_path: self.file = self._check_filepath(file_path) self.device = self._check_device(device) @@ -64,7 +62,7 @@ def __init__( self._check_devices(device, devices_per_node, first_device) @property - def devices(self) -> t.List[str]: + def devices(self) -> list[str]: return self._enumerate_devices() @property @@ -73,9 +71,9 @@ def is_file(self) -> bool: @staticmethod def _check_tensor_args( - inputs: t.Union[str, t.Optional[t.List[str]]], - outputs: t.Union[str, t.Optional[t.List[str]]], - ) -> t.Tuple[t.List[str], t.List[str]]: + inputs: str | list[str] | None, + outputs: str | list[str] | None, + ) -> tuple[list[str], list[str]]: if isinstance(inputs, str): inputs = [inputs] if isinstance(outputs, str): @@ -107,7 +105,7 @@ def _check_device(device: str) -> str: raise ValueError("Device argument must start with either CPU or GPU") return device - def _enumerate_devices(self) -> t.List[str]: + def _enumerate_devices(self) -> list[str]: """Enumerate devices for a DBObject :param dbobject: DBObject to enumerate @@ -154,8 +152,8 @@ class DBScript(DBObject[str]): def __init__( self, name: str, - script: t.Optional[str] = None, - script_path: t.Optional[str] = None, + script: str | None = None, + script_path: str | None = None, device: str = Device.CPU.value.upper(), devices_per_node: int = 1, first_device: int = 0, @@ -187,7 +185,7 @@ def __init__( raise ValueError("Either script or script_path must be provided") @property - def script(self) -> t.Optional[t.Union[bytes, str]]: + def script(self) -> bytes | str | None: return self.func def __str__(self) -> str: @@ -210,8 +208,8 @@ def __init__( self, name: str, backend: str, - model: t.Optional[bytes] = None, - model_file: t.Optional[str] = None, + model: bytes | None = None, + model_file: str | None = None, device: str = Device.CPU.value.upper(), devices_per_node: int = 1, first_device: int = 0, @@ -219,8 +217,8 @@ def __init__( min_batch_size: int = 0, min_batch_timeout: int = 0, tag: str = "", - inputs: t.Optional[t.List[str]] = None, - outputs: t.Optional[t.List[str]] = None, + inputs: list[str] | None = None, + outputs: list[str] | None = None, ) -> None: """A TF, TF-lite, PT, or ONNX model to load into the DB at runtime @@ -254,7 +252,7 @@ def __init__( self.inputs, self.outputs = self._check_tensor_args(inputs, outputs) @property - def model(self) -> t.Optional[bytes]: + def model(self) -> bytes | None: return self.func def __str__(self) -> str: diff --git a/smartsim/entity/ensemble.py b/smartsim/entity/ensemble.py index cbf36c431..8ec9a0c0a 100644 --- a/smartsim/entity/ensemble.py +++ b/smartsim/entity/ensemble.py @@ -26,6 +26,7 @@ import os.path as osp import typing as t +from collections.abc import Callable, Collection from copy import deepcopy from os import getcwd @@ -49,9 +50,7 @@ logger = get_logger(__name__) -StrategyFunction = t.Callable[ - [t.List[str], t.List[t.List[str]], int], t.List[t.Dict[str, str]] -] +StrategyFunction = Callable[[list[str], list[list[str]], int], list[dict[str, str]]] class Ensemble(EntityList[Model]): @@ -62,11 +61,11 @@ class Ensemble(EntityList[Model]): def __init__( self, name: str, - params: t.Dict[str, t.Any], - path: t.Optional[str] = getcwd(), - params_as_args: t.Optional[t.List[str]] = None, - batch_settings: t.Optional[BatchSettings] = None, - run_settings: t.Optional[RunSettings] = None, + params: dict[str, t.Any], + path: str | None = getcwd(), + params_as_args: list[str] | None = None, + batch_settings: BatchSettings | None = None, + run_settings: RunSettings | None = None, perm_strat: str = "all_perm", **kwargs: t.Any, ) -> None: @@ -100,7 +99,7 @@ def __init__( super().__init__(name, str(path), perm_strat=perm_strat, **kwargs) @property - def models(self) -> t.Collection[Model]: + def models(self) -> Collection[Model]: """An alias for a shallow copy of the ``entities`` attribute""" return list(self.entities) @@ -235,9 +234,9 @@ def query_key_prefixing(self) -> bool: def attach_generator_files( self, - to_copy: t.Optional[t.List[str]] = None, - to_symlink: t.Optional[t.List[str]] = None, - to_configure: t.Optional[t.List[str]] = None, + to_copy: list[str] | None = None, + to_symlink: list[str] | None = None, + to_configure: list[str] | None = None, ) -> None: """Attach files to each model within the ensemble for generation @@ -307,7 +306,7 @@ def _set_strategy(strategy: str) -> StrategyFunction: f"Permutation strategy given is not supported: {strategy}" ) - def _read_model_parameters(self) -> t.Tuple[t.List[str], t.List[t.List[str]]]: + def _read_model_parameters(self) -> tuple[list[str], list[list[str]]]: """Take in the parameters given to the ensemble and prepare to create models for the ensemble @@ -320,8 +319,8 @@ def _read_model_parameters(self) -> t.Tuple[t.List[str], t.List[t.List[str]]]: "Ensemble initialization argument 'params' must be of type dict" ) - param_names: t.List[str] = [] - parameters: t.List[t.List[str]] = [] + param_names: list[str] = [] + parameters: list[list[str]] = [] for name, val in self.params.items(): param_names.append(name) @@ -341,8 +340,8 @@ def add_ml_model( self, name: str, backend: str, - model: t.Optional[bytes] = None, - model_path: t.Optional[str] = None, + model: bytes | None = None, + model_path: str | None = None, device: str = Device.CPU.value.upper(), devices_per_node: int = 1, first_device: int = 0, @@ -350,8 +349,8 @@ def add_ml_model( min_batch_size: int = 0, min_batch_timeout: int = 0, tag: str = "", - inputs: t.Optional[t.List[str]] = None, - outputs: t.Optional[t.List[str]] = None, + inputs: list[str] | None = None, + outputs: list[str] | None = None, ) -> None: """A TF, TF-lite, PT, or ONNX model to load into the DB at runtime @@ -411,8 +410,8 @@ def add_ml_model( def add_script( self, name: str, - script: t.Optional[str] = None, - script_path: t.Optional[str] = None, + script: str | None = None, + script_path: str | None = None, device: str = Device.CPU.value.upper(), devices_per_node: int = 1, first_device: int = 0, @@ -466,7 +465,7 @@ def add_script( def add_function( self, name: str, - function: t.Optional[str] = None, + function: str | None = None, device: str = Device.CPU.value.upper(), devices_per_node: int = 1, first_device: int = 0, @@ -517,7 +516,7 @@ def add_function( self._extend_entity_db_scripts(entity, [db_script]) @staticmethod - def _extend_entity_db_models(model: Model, db_models: t.List[DBModel]) -> None: + def _extend_entity_db_models(model: Model, db_models: list[DBModel]) -> None: """ Ensures that the Machine Learning model names being added to the Ensemble are unique. @@ -545,7 +544,7 @@ def _extend_entity_db_models(model: Model, db_models: t.List[DBModel]) -> None: model.add_ml_model_object(add_ml_model) @staticmethod - def _extend_entity_db_scripts(model: Model, db_scripts: t.List[DBScript]) -> None: + def _extend_entity_db_scripts(model: Model, db_scripts: list[DBScript]) -> None: """ Ensures that the script/function names being added to the Ensemble are unique. diff --git a/smartsim/entity/entityList.py b/smartsim/entity/entityList.py index c5eb7571c..1eccc470c 100644 --- a/smartsim/entity/entityList.py +++ b/smartsim/entity/entityList.py @@ -25,6 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import typing as t +from collections.abc import Iterable, Sequence from .entity import SmartSimEntity @@ -67,9 +68,9 @@ def __init__(self, name: str, path: str, **kwargs: t.Any) -> None: # object construction into the class' constructor. # --------------------------------------------------------------------- # - self.entities: t.Sequence[_T_co] = [] - self._db_models: t.Sequence["smartsim.entity.DBModel"] = [] - self._db_scripts: t.Sequence["smartsim.entity.DBScript"] = [] + self.entities: Sequence[_T_co] = [] + self._db_models: Sequence["smartsim.entity.DBModel"] = [] + self._db_scripts: Sequence["smartsim.entity.DBScript"] = [] # # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< @@ -80,12 +81,12 @@ def _initialize_entities(self, **kwargs: t.Any) -> None: raise NotImplementedError @property - def db_models(self) -> t.Iterable["smartsim.entity.DBModel"]: + def db_models(self) -> Iterable["smartsim.entity.DBModel"]: """Return an immutable collection of attached models""" return (model for model in self._db_models) @property - def db_scripts(self) -> t.Iterable["smartsim.entity.DBScript"]: + def db_scripts(self) -> Iterable["smartsim.entity.DBScript"]: """Return an immutable collection of attached scripts""" return (script for script in self._db_scripts) @@ -110,7 +111,7 @@ def set_path(self, new_path: str) -> None: for entity in self.entities: entity.path = new_path - def __getitem__(self, name: str) -> t.Optional[_T_co]: + def __getitem__(self, name: str) -> _T_co | None: for entity in self.entities: if entity.name == name: return entity @@ -129,9 +130,9 @@ class EntityList(EntitySequence[_T]): def __init__(self, name: str, path: str, **kwargs: t.Any) -> None: super().__init__(name, path, **kwargs) # Change container types to be invariant ``list``s - self.entities: t.List[_T] = list(self.entities) - self._db_models: t.List["smartsim.entity.DBModel"] = list(self._db_models) - self._db_scripts: t.List["smartsim.entity.DBScript"] = list(self._db_scripts) + self.entities: list[_T] = list(self.entities) + self._db_models: list["smartsim.entity.DBModel"] = list(self._db_models) + self._db_scripts: list["smartsim.entity.DBScript"] = list(self._db_scripts) def _initialize_entities(self, **kwargs: t.Any) -> None: """Initialize the SmartSimEntity objects in the container""" diff --git a/smartsim/entity/files.py b/smartsim/entity/files.py index 5eaca8c65..35868098f 100644 --- a/smartsim/entity/files.py +++ b/smartsim/entity/files.py @@ -51,9 +51,9 @@ class EntityFiles: def __init__( self, - tagged: t.Optional[t.List[str]] = None, - copy: t.Optional[t.List[str]] = None, - symlink: t.Optional[t.List[str]] = None, + tagged: list[str] | None = None, + copy: list[str] | None = None, + symlink: list[str] | None = None, ) -> None: """Initialize an EntityFiles instance @@ -93,9 +93,7 @@ def _check_files(self) -> None: self.link[i] = self._check_path(value) @staticmethod - def _type_check_files( - file_list: t.Union[t.List[str], None], file_type: str - ) -> t.List[str]: + def _type_check_files(file_list: list[str] | None, file_type: str) -> list[str]: """Check the type of the files provided by the user. :param file_list: either tagged, copy, or symlink files @@ -169,7 +167,7 @@ class TaggedFilesHierarchy: tagged file directory structure can be replicated """ - def __init__(self, parent: t.Optional[t.Any] = None, subdir_name: str = "") -> None: + def __init__(self, parent: t.Any | None = None, subdir_name: str = "") -> None: """Initialize a TaggedFilesHierarchy :param parent: The parent hierarchy of the new hierarchy, @@ -203,8 +201,8 @@ def __init__(self, parent: t.Optional[t.Any] = None, subdir_name: str = "") -> N self._base: str = path.join(parent.base, subdir_name) if parent else "" self.parent: t.Any = parent - self.files: t.Set[str] = set() - self.dirs: t.Set[TaggedFilesHierarchy] = set() + self.files: set[str] = set() + self.dirs: set[TaggedFilesHierarchy] = set() @property def base(self) -> str: @@ -213,7 +211,7 @@ def base(self) -> str: @classmethod def from_list_paths( - cls, path_list: t.List[str], dir_contents_to_base: bool = False + cls, path_list: list[str], dir_contents_to_base: bool = False ) -> t.Any: """Given a list of absolute paths to files and dirs, create and return a TaggedFilesHierarchy instance representing the file hierarchy of @@ -264,7 +262,7 @@ def _add_dir(self, dir_path: str) -> None: [path.join(dir_path, file) for file in os.listdir(dir_path)] ) - def _add_paths(self, paths: t.List[str]) -> None: + def _add_paths(self, paths: list[str]) -> None: """Takes a list of paths and iterates over it, determining if each path is to a file or a dir and then appropriatly adding it to the TaggedFilesHierarchy. diff --git a/smartsim/entity/model.py b/smartsim/entity/model.py index 70bc6c34c..76c60ad1d 100644 --- a/smartsim/entity/model.py +++ b/smartsim/entity/model.py @@ -32,6 +32,7 @@ import sys import typing as t import warnings +from collections.abc import Iterable, Mapping from os import getcwd from os import path as osp @@ -48,13 +49,13 @@ logger = get_logger(__name__) -def _parse_model_parameters(params_dict: t.Dict[str, t.Any]) -> t.Dict[str, str]: +def _parse_model_parameters(params_dict: dict[str, t.Any]) -> dict[str, str]: """Convert the values in a params dict to strings :raises TypeError: if params are of the wrong type :return: param dictionary with values and keys cast as strings """ - param_names: t.List[str] = [] - parameters: t.List[str] = [] + param_names: list[str] = [] + parameters: list[str] = [] for name, val in params_dict.items(): param_names.append(name) if isinstance(val, (str, numbers.Number)): @@ -71,11 +72,11 @@ class Model(SmartSimEntity): def __init__( self, name: str, - params: t.Dict[str, str], + params: dict[str, str], run_settings: RunSettings, - path: t.Optional[str] = getcwd(), - params_as_args: t.Optional[t.List[str]] = None, - batch_settings: t.Optional[BatchSettings] = None, + path: str | None = getcwd(), + params_as_args: list[str] | None = None, + batch_settings: BatchSettings | None = None, ): """Initialize a ``Model`` @@ -93,15 +94,15 @@ def __init__( super().__init__(name, str(path), run_settings) self.params = _parse_model_parameters(params) self.params_as_args = params_as_args - self.incoming_entities: t.List[SmartSimEntity] = [] + self.incoming_entities: list[SmartSimEntity] = [] self._key_prefixing_enabled = False self.batch_settings = batch_settings - self._db_models: t.List[DBModel] = [] - self._db_scripts: t.List[DBScript] = [] - self.files: t.Optional[EntityFiles] = None + self._db_models: list[DBModel] = [] + self._db_scripts: list[DBScript] = [] + self.files: EntityFiles | None = None @property - def db_models(self) -> t.Iterable[DBModel]: + def db_models(self) -> Iterable[DBModel]: """Retrieve an immutable collection of attached models :return: Return an immutable collection of attached models @@ -109,7 +110,7 @@ def db_models(self) -> t.Iterable[DBModel]: return (model for model in self._db_models) @property - def db_scripts(self) -> t.Iterable[DBScript]: + def db_scripts(self) -> Iterable[DBScript]: """Retrieve an immutable collection attached of scripts :return: Return an immutable collection of attached scripts @@ -161,9 +162,9 @@ def query_key_prefixing(self) -> bool: def attach_generator_files( self, - to_copy: t.Optional[t.List[str]] = None, - to_symlink: t.Optional[t.List[str]] = None, - to_configure: t.Optional[t.List[str]] = None, + to_copy: list[str] | None = None, + to_symlink: list[str] | None = None, + to_configure: list[str] | None = None, ) -> None: """Attach files to an entity for generation @@ -235,7 +236,7 @@ def colocate_db_uds( unix_socket: str = "/tmp/redis.socket", socket_permissions: int = 755, db_cpus: int = 1, - custom_pinning: t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]] = None, + custom_pinning: Iterable[int | Iterable[int]] | None = None, debug: bool = False, db_identifier: str = "", **kwargs: t.Any, @@ -276,7 +277,7 @@ def colocate_db_uds( f"Invalid name for unix socket: {unix_socket}. Must only " "contain alphanumeric characters or . : _ - /" ) - uds_options: t.Dict[str, t.Union[int, str]] = { + uds_options: dict[str, int | str] = { "unix_socket": unix_socket, "socket_permissions": socket_permissions, # This is hardcoded to 0 as recommended by redis for UDS @@ -294,9 +295,9 @@ def colocate_db_uds( def colocate_db_tcp( self, port: int = 6379, - ifname: t.Union[str, list[str]] = "lo", + ifname: str | list[str] = "lo", db_cpus: int = 1, - custom_pinning: t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]] = None, + custom_pinning: Iterable[int | Iterable[int]] | None = None, debug: bool = False, db_identifier: str = "", **kwargs: t.Any, @@ -343,18 +344,12 @@ def colocate_db_tcp( def _set_colocated_db_settings( self, - connection_options: t.Mapping[str, t.Union[int, t.List[str], str]], - common_options: t.Dict[ + connection_options: Mapping[str, int | list[str] | str], + common_options: dict[ str, - t.Union[ - t.Union[t.Iterable[t.Union[int, t.Iterable[int]]], None], - bool, - int, - str, - None, - ], + Iterable[int | Iterable[int]] | None | bool | int | str | None, ], - **kwargs: t.Union[int, None], + **kwargs: int | None, ) -> None: """ Ingest the connection-specific options (UDS/TCP) and set the final settings @@ -378,7 +373,7 @@ def _set_colocated_db_settings( # TODO list which db settings can be extras custom_pinning_ = t.cast( - t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]], + Iterable[int | Iterable[int]] | None, common_options.get("custom_pinning"), ) cpus_ = t.cast(int, common_options.get("cpus")) @@ -386,20 +381,20 @@ def _set_colocated_db_settings( custom_pinning_, cpus_ ) - colo_db_config: t.Dict[ + colo_db_config: dict[ str, - t.Union[ - bool, - int, - str, - None, - t.List[str], - t.Iterable[t.Union[int, t.Iterable[int]]], - t.List[DBModel], - t.List[DBScript], - t.Dict[str, t.Union[int, None]], - t.Dict[str, str], - ], + ( + bool + | int + | str + | None + | list[str] + | Iterable[int | Iterable[int]] + | list[DBModel] + | list[DBScript] + | dict[str, int | None] + | dict[str, str] + ), ] = {} colo_db_config.update(connection_options) colo_db_config.update(common_options) @@ -423,8 +418,8 @@ def _set_colocated_db_settings( @staticmethod def _create_pinning_string( - pin_ids: t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]], cpus: int - ) -> t.Optional[str]: + pin_ids: Iterable[int | Iterable[int]] | None, cpus: int + ) -> str | None: """Create a comma-separated string of CPU ids. By default, ``None`` returns 0,1,...,cpus-1; an empty iterable will disable pinning altogether, and an iterable constructs a comma separated string of @@ -432,7 +427,7 @@ def _create_pinning_string( """ def _stringify_id(_id: int) -> str: - """Return the cPU id as a string if an int, otherwise raise a ValueError""" + """Return the CPU id as a string if an int, otherwise raise a ValueError""" if isinstance(_id, int): if _id < 0: raise ValueError("CPU id must be a nonnegative number") @@ -491,8 +486,8 @@ def add_ml_model( self, name: str, backend: str, - model: t.Optional[bytes] = None, - model_path: t.Optional[str] = None, + model: bytes | None = None, + model_path: str | None = None, device: str = Device.CPU.value.upper(), devices_per_node: int = 1, first_device: int = 0, @@ -500,8 +495,8 @@ def add_ml_model( min_batch_size: int = 0, min_batch_timeout: int = 0, tag: str = "", - inputs: t.Optional[t.List[str]] = None, - outputs: t.Optional[t.List[str]] = None, + inputs: list[str] | None = None, + outputs: list[str] | None = None, ) -> None: """A TF, TF-lite, PT, or ONNX model to load into the DB at runtime @@ -550,8 +545,8 @@ def add_ml_model( def add_script( self, name: str, - script: t.Optional[str] = None, - script_path: t.Optional[str] = None, + script: str | None = None, + script_path: str | None = None, device: str = Device.CPU.value.upper(), devices_per_node: int = 1, first_device: int = 0, @@ -597,7 +592,7 @@ def add_script( def add_function( self, name: str, - function: t.Optional[str] = None, + function: str | None = None, device: str = Device.CPU.value.upper(), devices_per_node: int = 1, first_device: int = 0, diff --git a/smartsim/entity/strategies.py b/smartsim/entity/strategies.py index 5d0c48a46..923db4113 100644 --- a/smartsim/entity/strategies.py +++ b/smartsim/entity/strategies.py @@ -26,15 +26,14 @@ # Generation Strategies import random -import typing as t from itertools import product # create permutations of all parameters # single model if parameters only have one value def create_all_permutations( - param_names: t.List[str], param_values: t.List[t.List[str]], _n_models: int = 0 -) -> t.List[t.Dict[str, str]]: + param_names: list[str], param_values: list[list[str]], _n_models: int = 0 +) -> list[dict[str, str]]: perms = list(product(*param_values)) all_permutations = [] for permutation in perms: @@ -44,8 +43,8 @@ def create_all_permutations( def step_values( - param_names: t.List[str], param_values: t.List[t.List[str]], _n_models: int = 0 -) -> t.List[t.Dict[str, str]]: + param_names: list[str], param_values: list[list[str]], _n_models: int = 0 +) -> list[dict[str, str]]: permutations = [] for param_value in zip(*param_values): permutations.append(dict(zip(param_names, param_value))) @@ -53,8 +52,8 @@ def step_values( def random_permutations( - param_names: t.List[str], param_values: t.List[t.List[str]], n_models: int = 0 -) -> t.List[t.Dict[str, str]]: + param_names: list[str], param_values: list[list[str]], n_models: int = 0 +) -> list[dict[str, str]]: permutations = create_all_permutations(param_names, param_values) # sample from available permutations if n_models is specified diff --git a/smartsim/error/errors.py b/smartsim/error/errors.py index e62ec4cf0..dd0519dec 100644 --- a/smartsim/error/errors.py +++ b/smartsim/error/errors.py @@ -24,7 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t # Exceptions @@ -124,8 +123,8 @@ class ShellError(LauncherError): def __init__( self, message: str, - command_list: t.Union[str, t.List[str]], - details: t.Optional[t.Union[Exception, str]] = None, + command_list: str | list[str], + details: Exception | str | None = None, ) -> None: msg = self.create_message(message, command_list, details=details) super().__init__(msg) @@ -133,8 +132,8 @@ def __init__( @staticmethod def create_message( message: str, - command_list: t.Union[str, t.List[str]], - details: t.Optional[t.Union[Exception, str]], + command_list: str | list[str], + details: Exception | str | None, ) -> str: if isinstance(command_list, list): command_list = " ".join(command_list) diff --git a/smartsim/experiment.py b/smartsim/experiment.py index 2674682bd..e04ff5fe7 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -78,7 +78,7 @@ class Experiment: def __init__( self, name: str, - exp_path: t.Optional[str] = None, + exp_path: str | None = None, launcher: str = "local", ): """Initialize an Experiment instance. @@ -149,7 +149,7 @@ def __init__( self._control = Controller(launcher=self._launcher) - self.db_identifiers: t.Set[str] = set() + self.db_identifiers: set[str] = set() def _set_dragon_server_path(self) -> None: """Set path for dragon server through environment varialbes""" @@ -161,7 +161,7 @@ def _set_dragon_server_path(self) -> None: @_contextualize def start( self, - *args: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]], + *args: SmartSimEntity | EntitySequence[SmartSimEntity], block: bool = True, summary: bool = False, kill_on_interrupt: bool = True, @@ -228,9 +228,7 @@ def start( raise @_contextualize - def stop( - self, *args: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] - ) -> None: + def stop(self, *args: SmartSimEntity | EntitySequence[SmartSimEntity]) -> None: """Stop specific instances launched by this ``Experiment`` Instances of ``Model``, ``Ensemble`` and ``Orchestrator`` @@ -270,8 +268,8 @@ def stop( @_contextualize def generate( self, - *args: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]], - tag: t.Optional[str] = None, + *args: SmartSimEntity | EntitySequence[SmartSimEntity], + tag: str | None = None, overwrite: bool = False, verbose: bool = False, ) -> None: @@ -365,8 +363,8 @@ def finished(self, entity: SmartSimEntity) -> bool: @_contextualize def get_status( - self, *args: t.Union[SmartSimEntity, EntitySequence[SmartSimEntity]] - ) -> t.List[SmartSimStatus]: + self, *args: SmartSimEntity | EntitySequence[SmartSimEntity] + ) -> list[SmartSimStatus]: """Query the status of launched entity instances Return a smartsim.status string representing @@ -393,7 +391,7 @@ def get_status( """ try: manifest = Manifest(*args) - statuses: t.List[SmartSimStatus] = [] + statuses: list[SmartSimStatus] = [] for entity in manifest.models: statuses.append(self._control.get_entity_status(entity)) for entity_list in manifest.all_entity_lists: @@ -407,12 +405,12 @@ def get_status( def create_ensemble( self, name: str, - params: t.Optional[t.Dict[str, t.Any]] = None, - batch_settings: t.Optional[base.BatchSettings] = None, - run_settings: t.Optional[base.RunSettings] = None, - replicas: t.Optional[int] = None, + params: dict[str, t.Any] | None = None, + batch_settings: base.BatchSettings | None = None, + run_settings: base.RunSettings | None = None, + replicas: int | None = None, perm_strategy: str = "all_perm", - path: t.Optional[str] = None, + path: str | None = None, **kwargs: t.Any, ) -> Ensemble: """Create an ``Ensemble`` of ``Model`` instances @@ -483,10 +481,10 @@ def create_model( self, name: str, run_settings: base.RunSettings, - params: t.Optional[t.Dict[str, t.Any]] = None, - path: t.Optional[str] = None, + params: dict[str, t.Any] | None = None, + path: str | None = None, enable_key_prefixing: bool = False, - batch_settings: t.Optional[base.BatchSettings] = None, + batch_settings: base.BatchSettings | None = None, ) -> Model: """Create a general purpose ``Model`` @@ -591,11 +589,11 @@ def create_model( def create_run_settings( self, exe: str, - exe_args: t.Optional[t.List[str]] = None, + exe_args: list[str] | None = None, run_command: str = "auto", - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - container: t.Optional[Container] = None, + run_args: dict[str, int | str | float | None] | None = None, + env_vars: dict[str, str | None] | None = None, + container: Container | None = None, **kwargs: t.Any, ) -> settings.RunSettings: """Create a ``RunSettings`` instance. @@ -651,7 +649,7 @@ def create_batch_settings( time: str = "", queue: str = "", account: str = "", - batch_args: t.Optional[t.Dict[str, str]] = None, + batch_args: dict[str, str] | None = None, **kwargs: t.Any, ) -> base.BatchSettings: """Create a ``BatchSettings`` instance @@ -703,15 +701,15 @@ def create_batch_settings( def create_database( self, port: int = 6379, - path: t.Optional[str] = None, + path: str | None = None, db_nodes: int = 1, batch: bool = False, - hosts: t.Optional[t.Union[t.List[str], str]] = None, + hosts: list[str] | str | None = None, run_command: str = "auto", - interface: t.Union[str, t.List[str]] = "ipogif0", - account: t.Optional[str] = None, - time: t.Optional[str] = None, - queue: t.Optional[str] = None, + interface: str | list[str] = "ipogif0", + account: str | None = None, + time: str | None = None, + queue: str | None = None, single_cmd: bool = True, db_identifier: str = "orchestrator", **kwargs: t.Any, @@ -798,7 +796,7 @@ def preview( *args: t.Any, verbosity_level: previewrenderer.Verbosity = previewrenderer.Verbosity.INFO, output_format: previewrenderer.Format = previewrenderer.Format.PLAINTEXT, - output_filename: t.Optional[str] = None, + output_filename: str | None = None, ) -> None: """Preview entity information prior to launch. This method aggregates multiple pieces of information to give users insight @@ -909,7 +907,7 @@ def _launch_summary(self, manifest: Manifest) -> None: logger.info(summary) def _create_entity_dir(self, start_manifest: Manifest) -> None: - def create_entity_dir(entity: t.Union[Orchestrator, Model, Ensemble]) -> None: + def create_entity_dir(entity: Orchestrator | Model | Ensemble) -> None: if not os.path.isdir(entity.path): os.makedirs(entity.path) diff --git a/smartsim/log.py b/smartsim/log.py index 50a126bad..9437adb2d 100644 --- a/smartsim/log.py +++ b/smartsim/log.py @@ -31,6 +31,7 @@ import sys import threading import typing as t +from collections.abc import Callable from contextvars import ContextVar, copy_context import coloredlogs @@ -89,7 +90,7 @@ def _translate_log_level(user_log_level: str = "info") -> str: return "info" -def get_exp_log_paths() -> t.Tuple[t.Optional[pathlib.Path], t.Optional[pathlib.Path]]: +def get_exp_log_paths() -> tuple[pathlib.Path | None, pathlib.Path | None]: """Returns the output and error file paths to experiment logs. Returns None for both paths if experiment context is unavailable. @@ -154,7 +155,7 @@ class ContextAwareLogger(logging.Logger): """A logger customized to automatically write experiment logs to a dynamic target directory by inspecting the value of a context var""" - def __init__(self, name: str, level: t.Union[int, str] = 0) -> None: + def __init__(self, name: str, level: int | str = 0) -> None: super().__init__(name, level) self.addFilter(ContextInjectingLogFilter(name="exp-ctx-log-filter")) @@ -163,8 +164,8 @@ def _log( level: int, msg: object, args: t.Any, - exc_info: t.Optional[t.Any] = None, - extra: t.Optional[t.Any] = None, + exc_info: t.Any | None = None, + extra: t.Any | None = None, stack_info: bool = False, stacklevel: int = 1, ) -> None: @@ -189,7 +190,7 @@ def _log( def get_logger( - name: str, log_level: t.Optional[str] = None, fmt: t.Optional[str] = None + name: str, log_level: str | None = None, fmt: str | None = None ) -> logging.Logger: """Return a logger instance @@ -272,8 +273,8 @@ def log_to_exp_file( filename: str, logger: logging.Logger, log_level: str = "warn", - fmt: t.Optional[str] = EXPERIMENT_LOG_FORMAT, - log_filter: t.Optional[logging.Filter] = None, + fmt: str | None = EXPERIMENT_LOG_FORMAT, + log_filter: logging.Filter | None = None, ) -> logging.Handler: """Installs a second filestream handler to the root logger, allowing subsequent logging calls to be sent to filename. @@ -308,10 +309,10 @@ def log_to_exp_file( def method_contextualizer( ctx_var: ContextVar[_ContextT], - ctx_map: t.Callable[[_T], _ContextT], -) -> """t.Callable[ - [t.Callable[Concatenate[_T, _PR], _RT]], - t.Callable[Concatenate[_T, _PR], _RT], + ctx_map: Callable[[_T], _ContextT], +) -> """Callable[ + [Callable[Concatenate[_T, _PR], _RT]], + Callable[Concatenate[_T, _PR], _RT], ]""": """Parameterized-decorator factory that enables a target value to be placed into global context prior to execution of the @@ -325,8 +326,8 @@ def method_contextualizer( """ def _contextualize( - fn: "t.Callable[Concatenate[_T, _PR], _RT]", / - ) -> "t.Callable[Concatenate[_T, _PR], _RT]": + fn: "Callable[Concatenate[_T, _PR], _RT]", / + ) -> "Callable[Concatenate[_T, _PR], _RT]": """Executes the decorated method in a cloned context and ensures `ctx_var` is updated to the value returned by `ctx_map` prior to calling the decorated method""" diff --git a/smartsim/ml/data.py b/smartsim/ml/data.py index 332966bbe..bd49024ff 100644 --- a/smartsim/ml/data.py +++ b/smartsim/ml/data.py @@ -69,7 +69,7 @@ def __init__( list_name: str, sample_name: str = "samples", target_name: str = "targets", - num_classes: t.Optional[int] = None, + num_classes: int | None = None, ) -> None: self.list_name = list_name self.sample_name = sample_name @@ -160,10 +160,10 @@ def __init__( list_name: str = "training_data", sample_name: str = "samples", target_name: str = "targets", - num_classes: t.Optional[int] = None, + num_classes: int | None = None, cluster: bool = True, - address: t.Optional[str] = None, - rank: t.Optional[int] = None, + address: str | None = None, + rank: int | None = None, verbose: bool = False, ) -> None: if not list_name: @@ -190,7 +190,7 @@ def target_name(self) -> str: return self._info.target_name @property - def num_classes(self) -> t.Optional[int]: + def num_classes(self) -> int | None: return self._info.num_classes def publish_info(self) -> None: @@ -199,7 +199,7 @@ def publish_info(self) -> None: def put_batch( self, samples: np.ndarray, # type: ignore[type-arg] - targets: t.Optional[np.ndarray] = None, # type: ignore[type-arg] + targets: np.ndarray | None = None, # type: ignore[type-arg] ) -> None: batch_ds_name = form_name("training_samples", self.rank, self.batch_idx) batch_ds = Dataset(batch_ds_name) @@ -276,12 +276,12 @@ class DataDownloader: def __init__( self, - data_info_or_list_name: t.Union[str, DataInfo], + data_info_or_list_name: str | DataInfo, batch_size: int = 32, dynamic: bool = True, shuffle: bool = True, cluster: bool = True, - address: t.Optional[str] = None, + address: str | None = None, replica_rank: int = 0, num_replicas: int = 1, verbose: bool = False, @@ -292,8 +292,8 @@ def __init__( self.address = address self.cluster = cluster self.verbose = verbose - self.samples: t.Optional["npt.NDArray[t.Any]"] = None - self.targets: t.Optional["npt.NDArray[t.Any]"] = None + self.samples: "npt.NDArray[t.Any] | None" = None + self.targets: "npt.NDArray[t.Any] | None" = None self.num_samples = 0 self.indices = np.arange(0) self.shuffle = shuffle @@ -307,7 +307,7 @@ def __init__( self._info.download(client) else: raise TypeError("data_info_or_list_name must be either DataInfo or str") - self._client: t.Optional[Client] = None + self._client: Client | None = None sskeyin = environ.get("SSKEYIN", "") self.uploader_keys = sskeyin.split(",") @@ -348,7 +348,7 @@ def target_name(self) -> str: return self._info.target_name @property - def num_classes(self) -> t.Optional[int]: + def num_classes(self) -> int | None: return self._info.num_classes @property @@ -368,7 +368,7 @@ def _calc_indices(self, index: int) -> np.ndarray: # type: ignore[type-arg] def __iter__( self, - ) -> t.Iterator[t.Tuple[np.ndarray, np.ndarray]]: # type: ignore[type-arg] + ) -> t.Iterator[tuple[np.ndarray, np.ndarray]]: # type: ignore[type-arg] self.update_data() # Generate data if len(self) < 1: @@ -416,8 +416,8 @@ def _data_exists(self, batch_name: str, target_name: str) -> bool: return bool(self.client.tensor_exists(batch_name)) - def _add_samples(self, indices: t.List[int]) -> None: - datasets: t.List[Dataset] = [] + def _add_samples(self, indices: list[int]) -> None: + datasets: list[Dataset] = [] if self.num_replicas == 1: datasets = self.client.get_dataset_list_range( @@ -483,7 +483,7 @@ def update_data(self) -> None: def _data_generation( self, indices: "npt.NDArray[t.Any]" - ) -> t.Tuple["npt.NDArray[t.Any]", "npt.NDArray[t.Any]"]: + ) -> tuple["npt.NDArray[t.Any]", "npt.NDArray[t.Any]"]: # Initialization if self.samples is None: raise ValueError("Samples have not been initialized") diff --git a/smartsim/ml/tf/data.py b/smartsim/ml/tf/data.py index 23885d505..d58283345 100644 --- a/smartsim/ml/tf/data.py +++ b/smartsim/ml/tf/data.py @@ -38,7 +38,7 @@ class _TFDataGenerationCommon(DataDownloader, keras.utils.Sequence): def __getitem__( self, index: int - ) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg] + ) -> tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg] if len(self) < 1: raise ValueError( "Not enough samples in generator for one batch. Please " @@ -65,7 +65,7 @@ def on_epoch_end(self) -> None: def _data_generation( self, indices: "npt.NDArray[t.Any]" - ) -> t.Tuple["npt.NDArray[t.Any]", "npt.NDArray[t.Any]"]: + ) -> tuple["npt.NDArray[t.Any]", "npt.NDArray[t.Any]"]: # Initialization if self.samples is None: raise ValueError("No samples loaded for data generation") diff --git a/smartsim/ml/tf/utils.py b/smartsim/ml/tf/utils.py index 2de6a0bcf..f334784bc 100644 --- a/smartsim/ml/tf/utils.py +++ b/smartsim/ml/tf/utils.py @@ -36,7 +36,7 @@ def freeze_model( model: keras.Model, output_dir: str, file_name: str -) -> t.Tuple[str, t.List[str], t.List[str]]: +) -> tuple[str, list[str], list[str]]: """Freeze a Keras or TensorFlow Graph to use a Keras or TensorFlow model in SmartSim, the model @@ -78,7 +78,7 @@ def freeze_model( return model_file_path, input_names, output_names -def serialize_model(model: keras.Model) -> t.Tuple[str, t.List[str], t.List[str]]: +def serialize_model(model: keras.Model) -> tuple[str, list[str], list[str]]: """Serialize a Keras or TensorFlow Graph to use a Keras or TensorFlow model in SmartSim, the model diff --git a/smartsim/ml/torch/data.py b/smartsim/ml/torch/data.py index 04e508d34..bd8582bbd 100644 --- a/smartsim/ml/torch/data.py +++ b/smartsim/ml/torch/data.py @@ -44,13 +44,13 @@ def __init__(self, **kwargs: t.Any) -> None: "init_samples=False. Setting it to False automatically." ) - def _add_samples(self, indices: t.List[int]) -> None: + def _add_samples(self, indices: list[int]) -> None: if self.client is None: client = Client(self.cluster, self.address) else: client = self.client - datasets: t.List[Dataset] = [] + datasets: list[Dataset] = [] if self.num_replicas == 1: datasets = client.get_dataset_list_range( self.list_name, start_index=indices[0], end_index=indices[-1] diff --git a/smartsim/settings/alpsSettings.py b/smartsim/settings/alpsSettings.py index 51d99f02a..6059cc193 100644 --- a/smartsim/settings/alpsSettings.py +++ b/smartsim/settings/alpsSettings.py @@ -36,9 +36,9 @@ class AprunSettings(RunSettings): def __init__( self, exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, + exe_args: t.Optional[str | list[str]] = None, + run_args: dict[str, int | str | float | None] | None = None, + env_vars: dict[str, str | None] | None = None, **kwargs: t.Any, ): """Settings to run job with ``aprun`` command @@ -58,7 +58,7 @@ def __init__( env_vars=env_vars, **kwargs, ) - self.mpmd: t.List[RunSettings] = [] + self.mpmd: list[RunSettings] = [] def make_mpmd(self, settings: RunSettings) -> None: """Make job an MPMD job @@ -105,7 +105,7 @@ def set_tasks_per_node(self, tasks_per_node: int) -> None: """ self.run_args["pes-per-node"] = int(tasks_per_node) - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + def set_hostlist(self, host_list: str | list[str]) -> None: """Specify the hostlist for this job :param host_list: hosts to launch on @@ -128,7 +128,7 @@ def set_hostlist_from_file(self, file_path: str) -> None: """ self.run_args["node-list-file"] = file_path - def set_excluded_hosts(self, host_list: t.Union[str, t.List[str]]) -> None: + def set_excluded_hosts(self, host_list: str | list[str]) -> None: """Specify a list of hosts to exclude for launching this job :param host_list: hosts to exclude @@ -142,7 +142,7 @@ def set_excluded_hosts(self, host_list: t.Union[str, t.List[str]]) -> None: raise TypeError("host_list argument must be list of strings") self.run_args["exclude-node-list"] = ",".join(host_list) - def set_cpu_bindings(self, bindings: t.Union[int, t.List[int]]) -> None: + def set_cpu_bindings(self, bindings: int | list[int]) -> None: """Specifies the cores to which MPI processes are bound This sets ``--cpu-binding`` @@ -186,7 +186,7 @@ def set_quiet_launch(self, quiet: bool) -> None: else: self.run_args.pop("quiet", None) - def format_run_args(self) -> t.List[str]: + def format_run_args(self) -> list[str]: """Return a list of ALPS formatted run arguments :return: list of ALPS arguments for these settings @@ -208,7 +208,7 @@ def format_run_args(self) -> t.List[str]: args += ["=".join((prefix + opt, str(value)))] return args - def format_env_vars(self) -> t.List[str]: + def format_env_vars(self) -> list[str]: """Format the environment variables for aprun :return: list of env vars diff --git a/smartsim/settings/base.py b/smartsim/settings/base.py index 03ea0cadf..039d5844e 100644 --- a/smartsim/settings/base.py +++ b/smartsim/settings/base.py @@ -26,6 +26,7 @@ import copy import typing as t +from collections.abc import Iterable from smartsim.settings.containers import Container @@ -48,11 +49,11 @@ class RunSettings(SettingsBase): def __init__( self, exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, + exe_args: str | list[str] | None = None, run_command: str = "", - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - container: t.Optional[Container] = None, + run_args: dict[str, int | str | float | None] | None = None, + env_vars: dict[str, str | None] | None = None, + container: Container | None = None, **_kwargs: t.Any, ) -> None: """Run parameters for a ``Model`` @@ -89,26 +90,27 @@ def __init__( self.container = container self._run_command = run_command self.in_batch = False - self.colocated_db_settings: t.Optional[ - t.Dict[ + self.colocated_db_settings: ( + dict[ str, - t.Union[ - bool, - int, - str, - None, - t.List[str], - t.Iterable[t.Union[int, t.Iterable[int]]], - t.List[DBModel], - t.List[DBScript], - t.Dict[str, t.Union[int, None]], - t.Dict[str, str], - ], + ( + bool + | int + | str + | None + | list[str] + | Iterable[int | Iterable[int]] + | list[DBModel] + | list[DBScript] + | dict[str, int | None] + | dict[str, str] + ), ] - ] = None + | None + ) = None @property - def exe_args(self) -> t.Union[str, t.List[str]]: + def exe_args(self) -> str | list[str]: """Return an immutable list of attached executable arguments. :returns: attached executable arguments @@ -116,7 +118,7 @@ def exe_args(self) -> t.Union[str, t.List[str]]: return self._exe_args @exe_args.setter - def exe_args(self, value: t.Union[str, t.List[str], None]) -> None: + def exe_args(self, value: str | list[str] | None) -> None: """Set the executable arguments. :param value: executable arguments @@ -124,7 +126,7 @@ def exe_args(self, value: t.Union[str, t.List[str], None]) -> None: self._exe_args = self._build_exe_args(value) @property - def run_args(self) -> t.Dict[str, t.Union[int, str, float, None]]: + def run_args(self) -> dict[str, int | str | float | None]: """Return an immutable list of attached run arguments. :returns: attached run arguments @@ -132,7 +134,7 @@ def run_args(self) -> t.Dict[str, t.Union[int, str, float, None]]: return self._run_args @run_args.setter - def run_args(self, value: t.Dict[str, t.Union[int, str, float, None]]) -> None: + def run_args(self, value: dict[str, int | str | float | None]) -> None: """Set the run arguments. :param value: run arguments @@ -140,7 +142,7 @@ def run_args(self, value: t.Dict[str, t.Union[int, str, float, None]]) -> None: self._run_args = copy.deepcopy(value) @property - def env_vars(self) -> t.Dict[str, t.Optional[str]]: + def env_vars(self) -> dict[str, str | None]: """Return an immutable list of attached environment variables. :returns: attached environment variables @@ -148,7 +150,7 @@ def env_vars(self) -> t.Dict[str, t.Optional[str]]: return self._env_vars @env_vars.setter - def env_vars(self, value: t.Dict[str, t.Optional[str]]) -> None: + def env_vars(self, value: dict[str, str | None]) -> None: """Set the environment variables. :param value: environment variables @@ -218,7 +220,7 @@ def set_cpus_per_task(self, cpus_per_task: int) -> None: ) ) - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + def set_hostlist(self, host_list: str | list[str]) -> None: """Specify the hostlist for this job :param host_list: hosts to launch on @@ -242,7 +244,7 @@ def set_hostlist_from_file(self, file_path: str) -> None: ) ) - def set_excluded_hosts(self, host_list: t.Union[str, t.List[str]]) -> None: + def set_excluded_hosts(self, host_list: str | list[str]) -> None: """Specify a list of hosts to exclude for launching this job :param host_list: hosts to exclude @@ -254,7 +256,7 @@ def set_excluded_hosts(self, host_list: t.Union[str, t.List[str]]) -> None: ) ) - def set_cpu_bindings(self, bindings: t.Union[int, t.List[int]]) -> None: + def set_cpu_bindings(self, bindings: int | list[int]) -> None: """Set the cores to which MPI processes are bound :param bindings: List specifing the cores to which MPI processes are bound @@ -302,7 +304,7 @@ def set_quiet_launch(self, quiet: bool) -> None: ) ) - def set_broadcast(self, dest_path: t.Optional[str] = None) -> None: + def set_broadcast(self, dest_path: str | None = None) -> None: """Copy executable file to allocated compute nodes :param dest_path: Path to copy an executable file @@ -325,7 +327,7 @@ def set_time(self, hours: int = 0, minutes: int = 0, seconds: int = 0) -> None: self._fmt_walltime(int(hours), int(minutes), int(seconds)) ) - def set_node_feature(self, feature_list: t.Union[str, t.List[str]]) -> None: + def set_node_feature(self, feature_list: str | list[str]) -> None: """Specify the node feature for this job :param feature_list: node feature to launch on @@ -377,7 +379,7 @@ def set_binding(self, binding: str) -> None: ) ) - def set_mpmd_preamble(self, preamble_lines: t.List[str]) -> None: + def set_mpmd_preamble(self, preamble_lines: list[str]) -> None: """Set preamble to a file to make a job MPMD :param preamble_lines: lines to put at the beginning of a file. @@ -402,7 +404,7 @@ def make_mpmd(self, settings: RunSettings) -> None: ) @property - def run_command(self) -> t.Optional[str]: + def run_command(self) -> str | None: """Return the launch binary used to launch the executable Attempt to expand the path to the executable if possible @@ -421,7 +423,7 @@ def run_command(self) -> t.Optional[str]: # run without run command return None - def update_env(self, env_vars: t.Dict[str, t.Union[str, int, float, bool]]) -> None: + def update_env(self, env_vars: dict[str, str | int | float | bool]) -> None: """Update the job environment variables To fully inherit the current user environment, add the @@ -443,7 +445,7 @@ def update_env(self, env_vars: t.Dict[str, t.Union[str, int, float, bool]]) -> N self.env_vars[env] = str(val) - def add_exe_args(self, args: t.Union[str, t.List[str]]) -> None: + def add_exe_args(self, args: str | list[str]) -> None: """Add executable arguments to executable :param args: executable arguments @@ -451,9 +453,7 @@ def add_exe_args(self, args: t.Union[str, t.List[str]]) -> None: args = self._build_exe_args(args) self._exe_args.extend(args) - def set( - self, arg: str, value: t.Optional[str] = None, condition: bool = True - ) -> None: + def set(self, arg: str, value: str | None = None, condition: bool = True) -> None: """Allows users to set individual run arguments. A method that allows users to set run arguments after object @@ -523,7 +523,7 @@ def set( self.run_args[arg] = value @staticmethod - def _build_exe_args(exe_args: t.Optional[t.Union[str, t.List[str]]]) -> t.List[str]: + def _build_exe_args(exe_args: str | list[str] | None) -> list[str]: """Check and convert exe_args input to a desired collection format""" if not exe_args: return [] @@ -545,7 +545,7 @@ def _build_exe_args(exe_args: t.Optional[t.Union[str, t.List[str]]]) -> t.List[s return exe_args - def format_run_args(self) -> t.List[str]: + def format_run_args(self) -> list[str]: """Return formatted run arguments For ``RunSettings``, the run arguments are passed @@ -559,7 +559,7 @@ def format_run_args(self) -> t.List[str]: formatted.append(str(value)) return formatted - def format_env_vars(self) -> t.List[str]: + def format_env_vars(self) -> list[str]: """Build environment variable string :returns: formatted list of strings to export variables @@ -588,12 +588,12 @@ class BatchSettings(SettingsBase): def __init__( self, batch_cmd: str, - batch_args: t.Optional[t.Dict[str, t.Optional[str]]] = None, + batch_args: dict[str, str | None] | None = None, **kwargs: t.Any, ) -> None: self._batch_cmd = batch_cmd self.batch_args = batch_args or {} - self._preamble: t.List[str] = [] + self._preamble: list[str] = [] nodes = kwargs.get("nodes", None) if nodes: self.set_nodes(nodes) @@ -623,7 +623,7 @@ def batch_cmd(self) -> str: return self._batch_cmd @property - def batch_args(self) -> t.Dict[str, t.Optional[str]]: + def batch_args(self) -> dict[str, str | None]: """Retrieve attached batch arguments :returns: attached batch arguments @@ -631,7 +631,7 @@ def batch_args(self) -> t.Dict[str, t.Optional[str]]: return self._batch_args @batch_args.setter - def batch_args(self, value: t.Dict[str, t.Optional[str]]) -> None: + def batch_args(self, value: dict[str, str | None]) -> None: """Attach batch arguments :param value: dictionary of batch arguments @@ -641,7 +641,7 @@ def batch_args(self, value: t.Dict[str, t.Optional[str]]) -> None: def set_nodes(self, num_nodes: int) -> None: raise NotImplementedError - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + def set_hostlist(self, host_list: str | list[str]) -> None: raise NotImplementedError def set_queue(self, queue: str) -> None: @@ -653,7 +653,7 @@ def set_walltime(self, walltime: str) -> None: def set_account(self, account: str) -> None: raise NotImplementedError - def format_batch_args(self) -> t.List[str]: + def format_batch_args(self) -> list[str]: raise NotImplementedError def set_batch_command(self, command: str) -> None: @@ -663,7 +663,7 @@ def set_batch_command(self, command: str) -> None: """ self._batch_cmd = command - def add_preamble(self, lines: t.List[str]) -> None: + def add_preamble(self, lines: list[str]) -> None: """Add lines to the batch file preamble. The lines are just written (unmodified) at the beginning of the batch file (after the WLM directives) and can be used to e.g. @@ -679,7 +679,7 @@ def add_preamble(self, lines: t.List[str]) -> None: raise TypeError("Expected str or List[str] for lines argument") @property - def preamble(self) -> t.Iterable[str]: + def preamble(self) -> Iterable[str]: """Return an iterable of preamble clauses to be prepended to the batch file :return: attached preamble clauses diff --git a/smartsim/settings/containers.py b/smartsim/settings/containers.py index f187bbb48..05f7f6ac8 100644 --- a/smartsim/settings/containers.py +++ b/smartsim/settings/containers.py @@ -101,7 +101,7 @@ class Singularity(Container): def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) - def _container_cmds(self, default_working_directory: str = "") -> t.List[str]: + def _container_cmds(self, default_working_directory: str = "") -> list[str]: """Return list of container commands to be inserted before exe. Container members are validated during this call. diff --git a/smartsim/settings/dragonRunSettings.py b/smartsim/settings/dragonRunSettings.py index 666f490a0..76939e708 100644 --- a/smartsim/settings/dragonRunSettings.py +++ b/smartsim/settings/dragonRunSettings.py @@ -40,8 +40,8 @@ class DragonRunSettings(RunSettings): def __init__( self, exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, + exe_args: t.Optional[str | list[str]] = None, + env_vars: dict[str, str | None] | None = None, **kwargs: t.Any, ) -> None: """Initialize run parameters for a Dragon process @@ -82,7 +82,7 @@ def set_tasks_per_node(self, tasks_per_node: int) -> None: self.run_args["tasks-per-node"] = tasks_per_node @override - def set_node_feature(self, feature_list: t.Union[str, t.List[str]]) -> None: + def set_node_feature(self, feature_list: str | list[str]) -> None: """Specify the node feature for this job :param feature_list: a collection of strings representing the required @@ -95,14 +95,14 @@ def set_node_feature(self, feature_list: t.Union[str, t.List[str]]) -> None: self.run_args["node-feature"] = ",".join(feature_list) - def set_cpu_affinity(self, devices: t.List[int]) -> None: + def set_cpu_affinity(self, devices: list[int]) -> None: """Set the CPU affinity for this job :param devices: list of CPU indices to execute on """ self.run_args["cpu-affinity"] = ",".join(str(device) for device in devices) - def set_gpu_affinity(self, devices: t.List[int]) -> None: + def set_gpu_affinity(self, devices: list[int]) -> None: """Set the GPU affinity for this job :param devices: list of GPU indices to execute on. diff --git a/smartsim/settings/mpiSettings.py b/smartsim/settings/mpiSettings.py index ff698a9fb..d356c8879 100644 --- a/smartsim/settings/mpiSettings.py +++ b/smartsim/settings/mpiSettings.py @@ -43,10 +43,10 @@ class _BaseMPISettings(RunSettings): def __init__( self, exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, + exe_args: t.Optional[str | list[str]] = None, run_command: str = "mpiexec", - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, + run_args: dict[str, int | str | float | None] | None = None, + env_vars: dict[str, str | None] | None = None, fail_if_missing_exec: bool = True, **kwargs: t.Any, ) -> None: @@ -75,8 +75,8 @@ def __init__( env_vars=env_vars, **kwargs, ) - self.mpmd: t.List[RunSettings] = [] - self.affinity_script: t.List[str] = [] + self.mpmd: list[RunSettings] = [] + self.affinity_script: list[str] = [] if not shutil.which(self._run_command): msg = ( @@ -151,7 +151,7 @@ def set_tasks(self, tasks: int) -> None: """ self.run_args["n"] = int(tasks) - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + def set_hostlist(self, host_list: str | list[str]) -> None: """Set the hostlist for the ``mpirun`` command This sets ``--host`` @@ -200,7 +200,7 @@ def set_quiet_launch(self, quiet: bool) -> None: else: self.run_args.pop("quiet", None) - def set_broadcast(self, dest_path: t.Optional[str] = None) -> None: + def set_broadcast(self, dest_path: str | None = None) -> None: """Copy the specified executable(s) to remote machines This sets ``--preload-binary`` @@ -225,7 +225,7 @@ def set_walltime(self, walltime: str) -> None: """ self.run_args["timeout"] = walltime - def format_run_args(self) -> t.List[str]: + def format_run_args(self) -> list[str]: """Return a list of MPI-standard formatted run arguments :return: list of MPI-standard arguments for these settings @@ -243,7 +243,7 @@ def format_run_args(self) -> t.List[str]: args += [prefix + opt, str(value)] return args - def format_env_vars(self) -> t.List[str]: + def format_env_vars(self) -> list[str]: """Format the environment variables for mpirun :return: list of env vars @@ -264,9 +264,9 @@ class MpirunSettings(_BaseMPISettings): def __init__( self, exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, + exe_args: t.Optional[str | list[str]] = None, + run_args: dict[str, int | str | float | None] | None = None, + env_vars: dict[str, str | None] | None = None, **kwargs: t.Any, ) -> None: """Settings to run job with ``mpirun`` command (MPI-standard) @@ -291,9 +291,9 @@ class MpiexecSettings(_BaseMPISettings): def __init__( self, exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, + exe_args: t.Optional[str | list[str]] = None, + run_args: dict[str, int | str | float | None] | None = None, + env_vars: dict[str, str | None] | None = None, **kwargs: t.Any, ) -> None: """Settings to run job with ``mpiexec`` command (MPI-standard) @@ -327,9 +327,9 @@ class OrterunSettings(_BaseMPISettings): def __init__( self, exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, + exe_args: t.Optional[str | list[str]] = None, + run_args: dict[str, int | str | float | None] | None = None, + env_vars: dict[str, str | None] | None = None, **kwargs: t.Any, ) -> None: """Settings to run job with ``orterun`` command (MPI-standard) diff --git a/smartsim/settings/palsSettings.py b/smartsim/settings/palsSettings.py index 1d6e9bedf..e619bc991 100644 --- a/smartsim/settings/palsSettings.py +++ b/smartsim/settings/palsSettings.py @@ -53,9 +53,9 @@ class PalsMpiexecSettings(_BaseMPISettings): def __init__( self, exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, + exe_args: t.Optional[str | list[str]] = None, + run_args: dict[str, int | str | float | None] | None = None, + env_vars: dict[str, str | None] | None = None, fail_if_missing_exec: bool = True, **kwargs: t.Any, ) -> None: @@ -142,7 +142,7 @@ def set_quiet_launch(self, quiet: bool) -> None: logger.warning("set_quiet_launch not supported under PALS") - def set_broadcast(self, dest_path: t.Optional[str] = None) -> None: + def set_broadcast(self, dest_path: str | None = None) -> None: """Copy the specified executable(s) to remote machines This sets ``--preload-binary`` @@ -174,7 +174,7 @@ def set_gpu_affinity_script(self, affinity: str, *args: t.Any) -> None: for arg in args: self.affinity_script.append(str(arg)) - def format_run_args(self) -> t.List[str]: + def format_run_args(self) -> list[str]: """Return a list of MPI-standard formatted run arguments :return: list of MPI-standard arguments for these settings @@ -196,7 +196,7 @@ def format_run_args(self) -> t.List[str]: return args - def format_env_vars(self) -> t.List[str]: + def format_env_vars(self) -> list[str]: """Format the environment variables for mpirun :return: list of env vars @@ -216,7 +216,7 @@ def format_env_vars(self) -> t.List[str]: return formatted - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + def set_hostlist(self, host_list: str | list[str]) -> None: """Set the hostlist for the PALS ``mpiexec`` command This sets ``--hosts`` diff --git a/smartsim/settings/pbsSettings.py b/smartsim/settings/pbsSettings.py index 8869c2529..2ec952f62 100644 --- a/smartsim/settings/pbsSettings.py +++ b/smartsim/settings/pbsSettings.py @@ -36,13 +36,13 @@ class QsubBatchSettings(BatchSettings): def __init__( self, - nodes: t.Optional[int] = None, - ncpus: t.Optional[int] = None, - time: t.Optional[str] = None, - queue: t.Optional[str] = None, - account: t.Optional[str] = None, - resources: t.Optional[t.Dict[str, t.Union[str, int]]] = None, - batch_args: t.Optional[t.Dict[str, t.Optional[str]]] = None, + nodes: int | None = None, + ncpus: int | None = None, + time: str | None = None, + queue: str | None = None, + account: str | None = None, + resources: dict[str, str | int] | None = None, + batch_args: dict[str, str | None] | None = None, **kwargs: t.Any, ): """Specify ``qsub`` batch parameters for a job @@ -84,14 +84,14 @@ def __init__( **kwargs, ) - self._hosts: t.List[str] = [] + self._hosts: list[str] = [] @property - def resources(self) -> t.Dict[str, t.Union[str, int]]: + def resources(self) -> dict[str, str | int]: return self._resources.copy() @resources.setter - def resources(self, resources: t.Dict[str, t.Union[str, int]]) -> None: + def resources(self, resources: dict[str, str | int]) -> None: self._sanity_check_resources(resources) self._resources = resources.copy() @@ -110,7 +110,7 @@ def set_nodes(self, num_nodes: int) -> None: if num_nodes: self.set_resource("nodes", num_nodes) - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + def set_hostlist(self, host_list: str | list[str]) -> None: """Specify the hostlist for this job :param host_list: hosts to launch on @@ -146,7 +146,7 @@ def set_queue(self, queue: str) -> None: if queue: self.batch_args["q"] = str(queue) - def set_ncpus(self, num_cpus: t.Union[int, str]) -> None: + def set_ncpus(self, num_cpus: int | str) -> None: """Set the number of cpus obtained in each node. If a select argument is provided in @@ -165,7 +165,7 @@ def set_account(self, account: str) -> None: if account: self.batch_args["A"] = str(account) - def set_resource(self, resource_name: str, value: t.Union[str, int]) -> None: + def set_resource(self, resource_name: str, value: str | int) -> None: """Set a resource value for the Qsub batch If a select statement is provided, the nodes and ncpus @@ -181,7 +181,7 @@ def set_resource(self, resource_name: str, value: t.Union[str, int]) -> None: self._sanity_check_resources(updated_dict) self.resources = updated_dict - def format_batch_args(self) -> t.List[str]: + def format_batch_args(self) -> list[str]: """Get the formatted batch arguments for a preview :return: batch arguments for Qsub @@ -196,7 +196,7 @@ def format_batch_args(self) -> t.List[str]: return opts def _sanity_check_resources( - self, resources: t.Optional[t.Dict[str, t.Union[str, int]]] = None + self, resources: dict[str, str | int] | None = None ) -> None: """Check that only select or nodes was specified in resources @@ -233,7 +233,7 @@ def _sanity_check_resources( "and str are allowed." ) - def _create_resource_list(self) -> t.List[str]: + def _create_resource_list(self) -> list[str]: self._sanity_check_resources() res = [] diff --git a/smartsim/settings/settings.py b/smartsim/settings/settings.py index 03c37a685..ecd32f3db 100644 --- a/smartsim/settings/settings.py +++ b/smartsim/settings/settings.py @@ -25,6 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import typing as t +from collections.abc import Callable from .._core.utils.helpers import is_valid_cmd from ..error import SmartSimError @@ -45,16 +46,16 @@ ) from ..wlm import detect_launcher -_TRunSettingsSelector = t.Callable[[str], t.Callable[..., RunSettings]] +_TRunSettingsSelector = Callable[[str], Callable[..., RunSettings]] def create_batch_settings( launcher: str, - nodes: t.Optional[int] = None, + nodes: int | None = None, time: str = "", - queue: t.Optional[str] = None, - account: t.Optional[str] = None, - batch_args: t.Optional[t.Dict[str, str]] = None, + queue: str | None = None, + account: str | None = None, + batch_args: dict[str, str] | None = None, **kwargs: t.Any, ) -> base.BatchSettings: """Create a ``BatchSettings`` instance @@ -72,7 +73,7 @@ def create_batch_settings( :raises SmartSimError: if batch creation fails """ # all supported batch class implementations - by_launcher: t.Dict[str, t.Callable[..., base.BatchSettings]] = { + by_launcher: dict[str, Callable[..., base.BatchSettings]] = { "pbs": QsubBatchSettings, "slurm": SbatchSettings, "pals": QsubBatchSettings, @@ -110,11 +111,11 @@ def create_batch_settings( def create_run_settings( launcher: str, exe: str, - exe_args: t.Optional[t.List[str]] = None, + exe_args: list[str] | None = None, run_command: str = "auto", - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - container: t.Optional[Container] = None, + run_args: dict[str, int | str | float | None] | None = None, + env_vars: dict[str, str | None] | None = None, + container: Container | None = None, **kwargs: t.Any, ) -> RunSettings: """Create a ``RunSettings`` instance. @@ -133,7 +134,7 @@ def create_run_settings( :raises SmartSimError: if run_command=="auto" and detection fails """ # all supported RunSettings child classes - supported: t.Dict[str, _TRunSettingsSelector] = { + supported: dict[str, _TRunSettingsSelector] = { "aprun": lambda launcher: AprunSettings, "srun": lambda launcher: SrunSettings, "mpirun": lambda launcher: MpirunSettings, diff --git a/smartsim/settings/sgeSettings.py b/smartsim/settings/sgeSettings.py index 5a46c9f1b..0bbae9218 100644 --- a/smartsim/settings/sgeSettings.py +++ b/smartsim/settings/sgeSettings.py @@ -36,13 +36,13 @@ class SgeQsubBatchSettings(BatchSettings): def __init__( self, - time: t.Optional[str] = None, - ncpus: t.Optional[int] = None, - pe_type: t.Optional[str] = None, - account: t.Optional[str] = None, + time: str | None = None, + ncpus: int | None = None, + pe_type: str | None = None, + account: str | None = None, shebang: str = "#!/bin/bash -l", - resources: t.Optional[t.Dict[str, t.Union[str, int]]] = None, - batch_args: t.Optional[t.Dict[str, t.Optional[str]]] = None, + resources: dict[str, str | int] | None = None, + batch_args: dict[str, str | None] | None = None, **kwargs: t.Any, ): """Specify SGE batch parameters for a job @@ -75,19 +75,19 @@ def __init__( **kwargs, ) - self._context_variables: t.List[str] = [] - self._env_vars: t.Dict[str, str] = {} + self._context_variables: list[str] = [] + self._env_vars: dict[str, str] = {} @property - def resources(self) -> t.Dict[str, t.Union[str, int]]: + def resources(self) -> dict[str, str | int]: return self._resources.copy() @resources.setter - def resources(self, resources: t.Dict[str, t.Union[str, int]]) -> None: + def resources(self, resources: dict[str, str | int]) -> None: self._sanity_check_resources(resources) self._resources = resources.copy() - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + def set_hostlist(self, host_list: str | list[str]) -> None: raise LauncherUnsupportedFeature( "SGE does not support requesting specific hosts in batch jobs" ) @@ -117,7 +117,7 @@ def set_walltime(self, walltime: str) -> None: if walltime: self.set_resource("h_rt", walltime) - def set_nodes(self, num_nodes: t.Optional[int]) -> None: + def set_nodes(self, num_nodes: int | None) -> None: """Set the number of nodes, invalid for SGE :param nodes: Number of nodes, any integer other than 0 is invalid @@ -127,14 +127,14 @@ def set_nodes(self, num_nodes: t.Optional[int]) -> None: "SGE does not support setting the number of nodes" ) - def set_ncpus(self, num_cpus: t.Union[int, str]) -> None: + def set_ncpus(self, num_cpus: int | str) -> None: """Set the number of cpus obtained in each node. :param num_cpus: number of cpus per node in select """ self.set_resource("ncpus", int(num_cpus)) - def set_ngpus(self, num_gpus: t.Union[int, str]) -> None: + def set_ngpus(self, num_gpus: int | str) -> None: """Set the number of GPUs obtained in each node. :param num_gpus: number of GPUs per node in select @@ -161,7 +161,7 @@ def update_context_variables( self, action: t.Literal["ac", "sc", "dc"], var_name: str, - value: t.Optional[t.Union[int, str]] = None, + value: int | str | None = None, ) -> None: """ Add, set, or delete context variables @@ -214,7 +214,7 @@ def set_threads_per_pe(self, threads_per_core: int) -> None: self._env_vars["OMP_NUM_THREADS"] = str(threads_per_core) - def set_resource(self, resource_name: str, value: t.Union[str, int]) -> None: + def set_resource(self, resource_name: str, value: str | int) -> None: """Set a resource value for the SGE batch If a select statement is provided, the nodes and ncpus @@ -228,7 +228,7 @@ def set_resource(self, resource_name: str, value: t.Union[str, int]) -> None: self._sanity_check_resources(updated_dict) self.resources = updated_dict - def format_batch_args(self) -> t.List[str]: + def format_batch_args(self) -> list[str]: """Get the formatted batch arguments for a preview :return: batch arguments for SGE @@ -243,7 +243,7 @@ def format_batch_args(self) -> t.List[str]: return opts def _sanity_check_resources( - self, resources: t.Optional[t.Dict[str, t.Union[str, int]]] = None + self, resources: dict[str, str | int] | None = None ) -> None: """Check that resources are correctly formatted""" # Note: isinstance check here to avoid collision with default @@ -261,7 +261,7 @@ def _sanity_check_resources( "and str are allowed." ) - def _create_resource_list(self) -> t.List[str]: + def _create_resource_list(self) -> list[str]: self._sanity_check_resources() res = [] diff --git a/smartsim/settings/slurmSettings.py b/smartsim/settings/slurmSettings.py index faffc7837..af30ec8a4 100644 --- a/smartsim/settings/slurmSettings.py +++ b/smartsim/settings/slurmSettings.py @@ -29,6 +29,7 @@ import datetime import os import typing as t +from collections.abc import Iterable from ..error import SSUnsupportedError from ..log import get_logger @@ -41,10 +42,10 @@ class SrunSettings(RunSettings): def __init__( self, exe: str, - exe_args: t.Optional[t.Union[str, t.List[str]]] = None, - run_args: t.Optional[t.Dict[str, t.Union[int, str, float, None]]] = None, - env_vars: t.Optional[t.Dict[str, t.Optional[str]]] = None, - alloc: t.Optional[str] = None, + exe_args: t.Optional[str | list[str]] = None, + run_args: dict[str, int | str | float | None] | None = None, + env_vars: dict[str, str | None] | None = None, + alloc: str | None = None, **kwargs: t.Any, ) -> None: """Initialize run parameters for a slurm job with ``srun`` @@ -69,7 +70,7 @@ def __init__( **kwargs, ) self.alloc = alloc - self.mpmd: t.List[RunSettings] = [] + self.mpmd: list[RunSettings] = [] reserved_run_args = frozenset({"chdir", "D"}) @@ -104,7 +105,7 @@ def make_mpmd(self, settings: RunSettings) -> None: ) self.mpmd.append(settings) - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + def set_hostlist(self, host_list: str | list[str]) -> None: """Specify the hostlist for this job This sets ``--nodelist`` @@ -129,7 +130,7 @@ def set_hostlist_from_file(self, file_path: str) -> None: """ self.run_args["nodefile"] = file_path - def set_excluded_hosts(self, host_list: t.Union[str, t.List[str]]) -> None: + def set_excluded_hosts(self, host_list: str | list[str]) -> None: """Specify a list of hosts to exclude for launching this job :param host_list: hosts to exclude @@ -170,7 +171,7 @@ def set_tasks_per_node(self, tasks_per_node: int) -> None: """ self.run_args["ntasks-per-node"] = int(tasks_per_node) - def set_cpu_bindings(self, bindings: t.Union[int, t.List[int]]) -> None: + def set_cpu_bindings(self, bindings: int | list[int]) -> None: """Bind by setting CPU masks on tasks This sets ``--cpu-bind`` using the ``map_cpu:`` option @@ -216,7 +217,7 @@ def set_quiet_launch(self, quiet: bool) -> None: else: self.run_args.pop("quiet", None) - def set_broadcast(self, dest_path: t.Optional[str] = None) -> None: + def set_broadcast(self, dest_path: str | None = None) -> None: """Copy executable file to allocated compute nodes This sets ``--bcast`` @@ -225,7 +226,7 @@ def set_broadcast(self, dest_path: t.Optional[str] = None) -> None: """ self.run_args["bcast"] = dest_path - def set_node_feature(self, feature_list: t.Union[str, t.List[str]]) -> None: + def set_node_feature(self, feature_list: str | list[str]) -> None: """Specify the node feature for this job This sets ``-C`` @@ -261,7 +262,7 @@ def set_walltime(self, walltime: str) -> None: """ self.run_args["time"] = str(walltime) - def set_het_group(self, het_group: t.Iterable[int]) -> None: + def set_het_group(self, het_group: Iterable[int]) -> None: """Set the heterogeneous group for this job this sets `--het-group` @@ -291,7 +292,7 @@ def set_het_group(self, het_group: t.Iterable[int]) -> None: logger.warning(msg) self.run_args["het-group"] = ",".join(str(group) for group in het_group) - def format_run_args(self) -> t.List[str]: + def format_run_args(self) -> list[str]: """Return a list of slurm formatted run arguments :return: list of slurm arguments for these settings @@ -331,7 +332,7 @@ def check_env_vars(self) -> None: ) logger.warning(msg) - def format_env_vars(self) -> t.List[str]: + def format_env_vars(self) -> list[str]: """Build bash compatible environment variable string for Slurm :returns: the formatted string of environment variables @@ -339,7 +340,7 @@ def format_env_vars(self) -> t.List[str]: self.check_env_vars() return [f"{k}={v}" for k, v in self.env_vars.items() if "," not in str(v)] - def format_comma_sep_env_vars(self) -> t.Tuple[str, t.List[str]]: + def format_comma_sep_env_vars(self) -> tuple[str, list[str]]: """Build environment variable string for Slurm Slurm takes exports in comma separated lists @@ -393,10 +394,10 @@ def fmt_walltime(hours: int, minutes: int, seconds: int) -> str: class SbatchSettings(BatchSettings): def __init__( self, - nodes: t.Optional[int] = None, + nodes: int | None = None, time: str = "", - account: t.Optional[str] = None, - batch_args: t.Optional[t.Dict[str, t.Optional[str]]] = None, + account: str | None = None, + batch_args: dict[str, str | None] | None = None, **kwargs: t.Any, ) -> None: """Specify run parameters for a Slurm batch job @@ -477,7 +478,7 @@ def set_cpus_per_task(self, cpus_per_task: int) -> None: """ self.batch_args["cpus-per-task"] = str(int(cpus_per_task)) - def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: + def set_hostlist(self, host_list: str | list[str]) -> None: """Specify the hostlist for this job :param host_list: hosts to launch on @@ -491,7 +492,7 @@ def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None: raise TypeError("host_list argument must be list of strings") self.batch_args["nodelist"] = ",".join(host_list) - def format_batch_args(self) -> t.List[str]: + def format_batch_args(self) -> list[str]: """Get the formatted batch arguments for a preview :return: batch arguments for Sbatch diff --git a/smartsim/wlm/__init__.py b/smartsim/wlm/__init__.py index 1f70dcf3f..b870de74a 100644 --- a/smartsim/wlm/__init__.py +++ b/smartsim/wlm/__init__.py @@ -66,7 +66,7 @@ def detect_launcher() -> str: return "local" -def get_hosts(launcher: t.Optional[str] = None) -> t.List[str]: +def get_hosts(launcher: str | None = None) -> list[str]: """Get the name of the hosts used in an allocation. :param launcher: Name of the WLM to use to collect allocation info. If no launcher @@ -83,7 +83,7 @@ def get_hosts(launcher: t.Optional[str] = None) -> t.List[str]: raise SSUnsupportedError(f"SmartSim cannot get hosts for launcher `{launcher}`") -def get_queue(launcher: t.Optional[str] = None) -> str: +def get_queue(launcher: str | None = None) -> str: """Get the name of the queue used in an allocation. :param launcher: Name of the WLM to use to collect allocation info. If no launcher @@ -100,7 +100,7 @@ def get_queue(launcher: t.Optional[str] = None) -> str: raise SSUnsupportedError(f"SmartSim cannot get queue for launcher `{launcher}`") -def get_tasks(launcher: t.Optional[str] = None) -> int: +def get_tasks(launcher: str | None = None) -> int: """Get the number of tasks in an allocation. :param launcher: Name of the WLM to use to collect allocation info. If no launcher @@ -117,7 +117,7 @@ def get_tasks(launcher: t.Optional[str] = None) -> int: raise SSUnsupportedError(f"SmartSim cannot get tasks for launcher `{launcher}`") -def get_tasks_per_node(launcher: t.Optional[str] = None) -> t.Dict[str, int]: +def get_tasks_per_node(launcher: str | None = None) -> dict[str, int]: """Get a map of nodes in an allocation to the number of tasks on each node. :param launcher: Name of the WLM to use to collect allocation info. If no launcher diff --git a/smartsim/wlm/pbs.py b/smartsim/wlm/pbs.py index a7e1dae87..0f7133072 100644 --- a/smartsim/wlm/pbs.py +++ b/smartsim/wlm/pbs.py @@ -26,7 +26,6 @@ import json import os -import typing as t from shutil import which from smartsim.error.errors import LauncherError, SmartSimError @@ -34,7 +33,7 @@ from .._core.launcher.pbs.pbsCommands import qstat -def get_hosts() -> t.List[str]: +def get_hosts() -> list[str]: """Get the name of the hosts used in a PBS allocation. :returns: Names of the host nodes @@ -92,7 +91,7 @@ def get_tasks() -> int: ) -def get_tasks_per_node() -> t.Dict[str, int]: +def get_tasks_per_node() -> dict[str, int]: """Get the number of processes on each chunk in a PBS allocation. .. note:: diff --git a/smartsim/wlm/slurm.py b/smartsim/wlm/slurm.py index 490e46b21..f4fd57973 100644 --- a/smartsim/wlm/slurm.py +++ b/smartsim/wlm/slurm.py @@ -25,7 +25,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os -import typing as t from shutil import which from .._core.launcher.slurm.slurmCommands import salloc, scancel, scontrol, sinfo @@ -45,9 +44,9 @@ def get_allocation( nodes: int = 1, - time: t.Optional[str] = None, - account: t.Optional[str] = None, - options: t.Optional[t.Dict[str, str]] = None, + time: str | None = None, + account: str | None = None, + options: dict[str, str] | None = None, ) -> str: """Request an allocation @@ -125,7 +124,7 @@ def release_allocation(alloc_id: str) -> None: logger.info(f"Successfully freed allocation {alloc_id}") -def validate(nodes: int = 1, ppn: int = 1, partition: t.Optional[str] = None) -> bool: +def validate(nodes: int = 1, ppn: int = 1, partition: str | None = None) -> bool: """Check that there are sufficient resources in the provided Slurm partitions. if no partition is provided, the default partition is found and used. @@ -191,14 +190,14 @@ def get_default_partition() -> str: return default -def _get_system_partition_info() -> t.Dict[str, Partition]: +def _get_system_partition_info() -> dict[str, Partition]: """Build a dictionary of slurm partitions :returns: dict of Partition objects """ sinfo_output, _ = sinfo(["--noheader", "--format", "%R %n %c"]) - partitions: t.Dict[str, Partition] = {} + partitions: dict[str, Partition] = {} for line in sinfo_output.split("\n"): line = line.strip() if line == "": @@ -220,10 +219,10 @@ def _get_system_partition_info() -> t.Dict[str, Partition]: def _get_alloc_cmd( nodes: int, - time: t.Optional[str] = None, - account: t.Optional[str] = None, - options: t.Optional[t.Dict[str, str]] = None, -) -> t.List[str]: + time: str | None = None, + account: str | None = None, + options: dict[str, str] | None = None, +) -> list[str]: """Return the command to request an allocation from Slurm with the class variables as the slurm options. """ @@ -278,7 +277,7 @@ def _validate_time_format(time: str) -> str: return fmt_walltime(hours, minutes, seconds) -def get_hosts() -> t.List[str]: +def get_hosts() -> list[str]: """Get the name of the nodes used in a slurm allocation. .. note:: @@ -327,7 +326,7 @@ def get_tasks() -> int: raise SmartSimError("Could not parse number of requested tasks from SLURM_NTASKS") -def get_tasks_per_node() -> t.Dict[str, int]: +def get_tasks_per_node() -> dict[str, int]: """Get the number of tasks per each node in a slurm allocation. .. note:: diff --git a/tests/on_wlm/test_dragon_entrypoint.py b/tests/on_wlm/test_dragon_entrypoint.py index 287088a7f..c0ae04d1f 100644 --- a/tests/on_wlm/test_dragon_entrypoint.py +++ b/tests/on_wlm/test_dragon_entrypoint.py @@ -40,7 +40,7 @@ @pytest.fixture -def mock_argv() -> t.List[str]: +def mock_argv() -> list[str]: """Fixture for returning valid arguments to the entrypoint""" return ["+launching_address", "mock-addr", "+interface", "mock-interface"] @@ -83,7 +83,7 @@ def test_file_removal_on_bad_path(test_dir: str, monkeypatch: pytest.MonkeyPatch def test_dragon_failure( - mock_argv: t.List[str], test_dir: str, monkeypatch: pytest.MonkeyPatch + mock_argv: list[str], test_dir: str, monkeypatch: pytest.MonkeyPatch ): """Verify that the expected cleanup actions are taken when the dragon entrypoint exits""" @@ -110,7 +110,7 @@ def raiser(args_) -> int: def test_dragon_main( - mock_argv: t.List[str], test_dir: str, monkeypatch: pytest.MonkeyPatch + mock_argv: list[str], test_dir: str, monkeypatch: pytest.MonkeyPatch ): """Verify that the expected startup & cleanup actions are taken when the dragon entrypoint exits""" @@ -228,7 +228,7 @@ def increment_counter(*args, **kwargs): def test_signal_handler_registration(test_dir: str, monkeypatch: pytest.MonkeyPatch): """Verify that signal handlers are registered for all expected signals""" - sig_nums: t.List[int] = [] + sig_nums: list[int] = [] def track_args(*args, **kwargs): nonlocal sig_nums diff --git a/tests/test_cli.py b/tests/test_cli.py index 6a4d161cb..a6db1169d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -59,20 +59,20 @@ def mock_execute_custom(msg: str = None, good: bool = True) -> int: def mock_execute_good( - _ns: argparse.Namespace, _unparsed: t.Optional[t.List[str]] = None + _ns: argparse.Namespace, _unparsed: list[str] | None = None ) -> int: return mock_execute_custom("GOOD THINGS", good=True) def mock_execute_fail( - _ns: argparse.Namespace, _unparsed: t.Optional[t.List[str]] = None + _ns: argparse.Namespace, _unparsed: list[str] | None = None ) -> int: return mock_execute_custom("BAD THINGS", good=False) def test_cli_default_args_parsing(capsys): """Test default parser behaviors with no subparsers""" - menu: t.List[cli.MenuItemConfig] = [] + menu: list[cli.MenuItemConfig] = [] smart_cli = cli.SmartCli(menu) captured = capsys.readouterr() # throw away existing output @@ -111,7 +111,7 @@ def test_cli_invalid_command(capsys): def test_cli_bad_default_args_parsing_bad_help(capsys): """Test passing an argument name that is incorrect""" - menu: t.List[cli.MenuItemConfig] = [] + menu: list[cli.MenuItemConfig] = [] smart_cli = cli.SmartCli(menu) captured = capsys.readouterr() # throw away existing output @@ -127,7 +127,7 @@ def test_cli_bad_default_args_parsing_bad_help(capsys): def test_cli_bad_default_args_parsing_good_help(capsys): """Test passing an argument name that is correct""" - menu: t.List[cli.MenuItemConfig] = [] + menu: list[cli.MenuItemConfig] = [] smart_cli = cli.SmartCli(menu) captured = capsys.readouterr() # throw away existing output @@ -388,7 +388,7 @@ def test_cli_plugin_invalid( def test_cli_action(capsys, monkeypatch, command, mock_location, exp_output): """Ensure the default CLI executes the build action""" - def mock_execute(ns: argparse.Namespace, _unparsed: t.Optional[t.List[str]] = None): + def mock_execute(ns: argparse.Namespace, _unparsed: list[str] | None = None): print(exp_output) return 0 @@ -444,7 +444,7 @@ def test_cli_optional_args( ): """Ensure the parser for a command handles expected optional arguments""" - def mock_execute(ns: argparse.Namespace, _unparsed: t.Optional[t.List[str]] = None): + def mock_execute(ns: argparse.Namespace, _unparsed: list[str] | None = None): print(exp_output) return 0 @@ -495,7 +495,7 @@ def test_cli_help_support( ): """Ensure the parser supports help optional for commands as expected""" - def mock_execute(ns: argparse.Namespace, unparsed: t.Optional[t.List[str]] = None): + def mock_execute(ns: argparse.Namespace, unparsed: list[str] | None = None): print(mock_output) return 0 @@ -534,7 +534,7 @@ def test_cli_invalid_optional_args( ): """Ensure the parser throws expected error for an invalid argument""" - def mock_execute(ns: argparse.Namespace, unparsed: t.Optional[t.List[str]] = None): + def mock_execute(ns: argparse.Namespace, unparsed: list[str] | None = None): print(exp_output) return 0 diff --git a/tests/test_config.py b/tests/test_config.py index 55f26df30..16277e834 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -55,9 +55,7 @@ def test_all_config_defaults(): config.test_device -def get_redisai_env( - rai_path: t.Optional[str], lib_path: t.Optional[str] -) -> t.Dict[str, str]: +def get_redisai_env(rai_path: str | None, lib_path: str | None) -> dict[str, str]: """Convenience method to create a set of environment variables that include RedisAI-specific variables :param rai_path: The path to the RedisAI library @@ -149,7 +147,7 @@ def test_redisai_valid_lib_path(test_dir, monkeypatch): def test_redisai_valid_lib_path_null_rai(test_dir, monkeypatch): """Missing RAI_PATH and valid SMARTSIM_DEP_INSTALL_PATH should succeed""" - rai_file_path: t.Optional[str] = None + rai_file_path: str | None = None lib_file_path = os.path.join(test_dir, "lib", "redisai.so") make_file(lib_file_path) env = get_redisai_env(rai_file_path, test_dir) diff --git a/tests/test_dragon_client.py b/tests/test_dragon_client.py index cab35c673..ba2a15ec2 100644 --- a/tests/test_dragon_client.py +++ b/tests/test_dragon_client.py @@ -92,7 +92,7 @@ def dragon_batch_step(test_dir: str) -> "DragonBatchStep": return batch_step -def get_request_path_from_batch_script(launch_cmd: t.List[str]) -> pathlib.Path: +def get_request_path_from_batch_script(launch_cmd: list[str]) -> pathlib.Path: """Helper method for finding the path to a request file from the launch command""" script_path = pathlib.Path(launch_cmd[-1]) batch_script = script_path.read_text(encoding="utf-8") diff --git a/tests/test_dragon_installer.py b/tests/test_dragon_installer.py index 7e233000f..7445d5ff2 100644 --- a/tests/test_dragon_installer.py +++ b/tests/test_dragon_installer.py @@ -29,6 +29,7 @@ import tarfile import typing as t from collections import namedtuple +from collections.abc import Collection import pytest from github.GitReleaseAsset import GitReleaseAsset @@ -84,7 +85,7 @@ def extraction_dir(test_dir: str) -> pathlib.Path: @pytest.fixture -def test_assets(monkeypatch: pytest.MonkeyPatch) -> t.Dict[str, GitReleaseAsset]: +def test_assets(monkeypatch: pytest.MonkeyPatch) -> dict[str, GitReleaseAsset]: requester = Requester( auth=None, base_url="https://github.com", @@ -99,7 +100,7 @@ def test_assets(monkeypatch: pytest.MonkeyPatch) -> t.Dict[str, GitReleaseAsset] attributes = {"mock-attr": "mock-attr-value"} completed = True - assets: t.List[GitReleaseAsset] = [] + assets: list[GitReleaseAsset] = [] mock_archive_name_tpl = "{}-{}.4.1-{}ac132fe95.tar.gz" for python_version in ["py3.10", "py3.11"]: @@ -205,7 +206,7 @@ def test_retrieve_cached( ], ) def test_retrieve_asset_info( - test_assets: t.Collection[GitReleaseAsset], + test_assets: Collection[GitReleaseAsset], monkeypatch: pytest.MonkeyPatch, dragon_pin: str, pyv: str, diff --git a/tests/test_dragon_launcher.py b/tests/test_dragon_launcher.py index 4b59db935..9147296d1 100644 --- a/tests/test_dragon_launcher.py +++ b/tests/test_dragon_launcher.py @@ -701,7 +701,7 @@ def test_run_step_success(test_dir: str) -> None: send_invocation = mock_connector.send_request send_invocation.assert_called_once() - args = send_invocation.call_args[0] # call_args == t.Tuple[args, kwargs] + args = send_invocation.call_args[0] # call_args == tuple[args, kwargs] dragon_run_request = args[0] req_name = dragon_run_request.name # name sent to dragon env diff --git a/tests/test_dragon_run_request.py b/tests/test_dragon_run_request.py index a74ca0e79..c664f66de 100644 --- a/tests/test_dragon_run_request.py +++ b/tests/test_dragon_run_request.py @@ -58,7 +58,7 @@ class NodeMock(MagicMock): def __init__( - self, name: t.Optional[str] = None, num_gpus: int = 2, num_cpus: int = 8 + self, name: str | None = None, num_gpus: int = 2, num_cpus: int = 8 ) -> None: super().__init__() self._mock_id = name @@ -82,7 +82,7 @@ def num_gpus(self) -> str: def _set_id(self, value: str) -> None: self._mock_id = value - def gpus(self, parent: t.Any = None) -> t.List[str]: + def gpus(self, parent: t.Any = None) -> list[str]: if self._num_gpus: return [f"{self.hostname}-gpu{i}" for i in range(NodeMock._num_gpus)] return [] @@ -161,7 +161,7 @@ def get_mock_backend( def set_mock_group_infos( monkeypatch: pytest.MonkeyPatch, dragon_backend: "DragonBackend" -) -> t.Dict[str, "ProcessGroupInfo"]: +) -> dict[str, "ProcessGroupInfo"]: dragon_mock = MagicMock() process_mock = MagicMock() process_mock.configure_mock(**{"returncode": 0}) @@ -518,7 +518,7 @@ def test_can_honor(monkeypatch: pytest.MonkeyPatch, num_nodes: int) -> None: @pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @pytest.mark.parametrize("affinity", [[0], [0, 1], list(range(8))]) def test_can_honor_cpu_affinity( - monkeypatch: pytest.MonkeyPatch, affinity: t.List[int] + monkeypatch: pytest.MonkeyPatch, affinity: list[int] ) -> None: """Verify that valid CPU affinities are accepted""" dragon_backend = get_mock_backend(monkeypatch) @@ -562,7 +562,7 @@ def test_can_honor_cpu_affinity_out_of_range(monkeypatch: pytest.MonkeyPatch) -> @pytest.mark.skipif(not dragon_loaded, reason="Test is only for Dragon WLM systems") @pytest.mark.parametrize("affinity", [[0], [0, 1]]) def test_can_honor_gpu_affinity( - monkeypatch: pytest.MonkeyPatch, affinity: t.List[int] + monkeypatch: pytest.MonkeyPatch, affinity: list[int] ) -> None: """Verify that valid GPU affinities are accepted""" dragon_backend = get_mock_backend(monkeypatch) diff --git a/tests/test_dragon_run_request_nowlm.py b/tests/test_dragon_run_request_nowlm.py index 7a1cd90a2..167489233 100644 --- a/tests/test_dragon_run_request_nowlm.py +++ b/tests/test_dragon_run_request_nowlm.py @@ -81,8 +81,8 @@ def test_run_request_with_empty_policy(monkeypatch: pytest.MonkeyPatch) -> None: ) def test_run_request_with_negative_affinity( device: str, - cpu_affinity: t.List[int], - gpu_affinity: t.List[int], + cpu_affinity: list[int], + gpu_affinity: list[int], ) -> None: """Verify that invalid affinity values fail validation""" with pytest.raises(ValidationError) as ex: diff --git a/tests/test_dragon_step.py b/tests/test_dragon_step.py index 9053e6129..10c4e0598 100644 --- a/tests/test_dragon_step.py +++ b/tests/test_dragon_step.py @@ -94,7 +94,7 @@ def dragon_batch_step(test_dir: str) -> DragonBatchStep: return batch_step -def get_request_path_from_batch_script(launch_cmd: t.List[str]) -> pathlib.Path: +def get_request_path_from_batch_script(launch_cmd: list[str]) -> pathlib.Path: """Helper method for finding the path to a request file from the launch command""" script_path = pathlib.Path(launch_cmd[-1]) batch_script = script_path.read_text(encoding="utf-8") @@ -298,7 +298,7 @@ def test_dragon_batch_step_get_launch_command_meta_fail(test_dir: str) -> None: ) def test_dragon_batch_step_get_launch_command( test_dir: str, - batch_settings_class: t.Type, + batch_settings_class: type, batch_exe: str, batch_header: str, node_spec_tpl: str, @@ -379,7 +379,7 @@ def test_dragon_batch_step_write_request_file( requests_file = get_request_path_from_batch_script(launch_cmd) requests_text = requests_file.read_text(encoding="utf-8") - requests_json: t.List[str] = json.loads(requests_text) + requests_json: list[str] = json.loads(requests_text) # verify that there is an item in file for each step added to the batch assert len(requests_json) == len(dragon_batch_step.steps) diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 998969062..f75f752e8 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -47,8 +47,8 @@ # ---- create entities for testing -------- -_EntityResult = t.Tuple[ - Experiment, t.Tuple[Model, Model], Ensemble, Orchestrator, DBModel, DBScript +_EntityResult = tuple[ + Experiment, tuple[Model, Model], Ensemble, Orchestrator, DBModel, DBScript ] diff --git a/tests/test_manifest_metadata_directories.py b/tests/test_manifest_metadata_directories.py new file mode 100644 index 000000000..95cc3d201 --- /dev/null +++ b/tests/test_manifest_metadata_directories.py @@ -0,0 +1,205 @@ +"""Test the metadata directory functionality added to LaunchedManifestBuilder""" + +# NOTE: This entire test file has been commented out because it tests +# LaunchedManifestBuilder functionality which has been removed. +# All LaunchedManifest-related classes have been deleted from the codebase. +# +# # import pathlib +# # import tempfile +# # import time +# # from unittest.mock import patch +# # +# # import pytest +# # +# # from smartsim._core.config import CONFIG +# # from smartsim._core.control.manifest import LaunchedManifestBuilder +# +# +# class TestLaunchedManifestBuilderMetadataDirectories: +# """Test metadata directory properties and methods of LaunchedManifestBuilder""" +# +# def test_exp_metadata_subdirectory_property(self): +# """Test that exp_metadata_subdirectory returns correct path""" +# with tempfile.TemporaryDirectory() as temp_dir: +# lmb = LaunchedManifestBuilder( +# exp_name="test_exp", +# exp_path=temp_dir, +# launcher_name="local", +# ) +# +# expected_path = pathlib.Path(temp_dir) / CONFIG.metadata_subdir +# assert lmb.exp_metadata_subdirectory == expected_path +# +# def test_run_metadata_subdirectory_property(self): +# """Test that run_metadata_subdirectory returns correct timestamped path""" +# with tempfile.TemporaryDirectory() as temp_dir: +# # Mock the timestamp to make it predictable +# mock_timestamp = "1234567890123" +# with patch.object(time, "time", return_value=1234567890.123): +# lmb = LaunchedManifestBuilder( +# exp_name="test_exp", +# exp_path=temp_dir, +# launcher_name="local", +# ) +# +# expected_path = ( +# pathlib.Path(temp_dir) +# / CONFIG.metadata_subdir +# / f"run_{mock_timestamp}" +# ) +# assert lmb.run_metadata_subdirectory == expected_path +# +# def test_run_metadata_subdirectory_uses_actual_timestamp(self): +# """Test that run_metadata_subdirectory uses actual timestamp from launch""" +# with tempfile.TemporaryDirectory() as temp_dir: +# lmb = LaunchedManifestBuilder( +# exp_name="test_exp", +# exp_path=temp_dir, +# launcher_name="local", +# ) +# +# # Check that the timestamp is reasonable (within last few seconds) +# run_dir_name = lmb.run_metadata_subdirectory.name +# assert run_dir_name.startswith("run_") +# +# # Extract timestamp and verify it's recent +# timestamp_str = run_dir_name[4:] # Remove "run_" prefix +# timestamp_ms = int(timestamp_str) +# current_time_ms = int(time.time() * 1000) +# +# # Should be within 5 seconds of current time +# assert abs(current_time_ms - timestamp_ms) < 5000 +# +# def test_get_entity_metadata_subdirectory_method(self): +# """Test that get_entity_metadata_subdirectory returns correct entity-specific paths""" +# with tempfile.TemporaryDirectory() as temp_dir: +# mock_timestamp = "1234567890123" +# with patch.object(time, "time", return_value=1234567890.123): +# lmb = LaunchedManifestBuilder( +# exp_name="test_exp", +# exp_path=temp_dir, +# launcher_name="local", +# ) +# +# # Test different entity types +# model_dir = lmb.get_entity_metadata_subdirectory("model") +# ensemble_dir = lmb.get_entity_metadata_subdirectory("ensemble") +# database_dir = lmb.get_entity_metadata_subdirectory("database") +# +# base_path = ( +# pathlib.Path(temp_dir) +# / CONFIG.metadata_subdir +# / f"run_{mock_timestamp}" +# ) +# +# assert model_dir == base_path / "model" +# assert ensemble_dir == base_path / "ensemble" +# assert database_dir == base_path / "database" +# +# def test_get_entity_metadata_subdirectory_custom_entity_type(self): +# """Test that get_entity_metadata_subdirectory works with custom entity types""" +# with tempfile.TemporaryDirectory() as temp_dir: +# lmb = LaunchedManifestBuilder( +# exp_name="test_exp", +# exp_path=temp_dir, +# launcher_name="local", +# ) +# +# # Test with custom entity type +# custom_dir = lmb.get_entity_metadata_subdirectory("custom_entity_type") +# +# expected_path = lmb.run_metadata_subdirectory / "custom_entity_type" +# assert custom_dir == expected_path +# +# def test_metadata_directory_hierarchy(self): +# """Test that the metadata directory hierarchy is correct""" +# with tempfile.TemporaryDirectory() as temp_dir: +# lmb = LaunchedManifestBuilder( +# exp_name="test_exp", +# exp_path=temp_dir, +# launcher_name="local", +# ) +# +# # Test that the hierarchy is: exp_path/.smartsim/metadata/run_/entity_type +# model_dir = lmb.get_entity_metadata_subdirectory("model") +# +# # Check path components +# path_parts = model_dir.parts +# # Extract the metadata subdir parts for comparison +# metadata_parts = pathlib.Path(CONFIG.metadata_subdir).parts +# if len(metadata_parts) == 2: # e.g., ".smartsim/metadata" +# assert path_parts[-4] == metadata_parts[0] # ".smartsim" +# assert path_parts[-3] == metadata_parts[1] # "metadata" +# else: # single part, e.g., "metadata" +# assert path_parts[-3] == metadata_parts[0] +# assert path_parts[-2].startswith("run_") +# assert path_parts[-1] == "model" +# +# def test_multiple_instances_have_different_timestamps(self): +# """Test that multiple LaunchedManifestBuilder instances have different timestamps""" +# with tempfile.TemporaryDirectory() as temp_dir: +# lmb1 = LaunchedManifestBuilder( +# exp_name="test_exp1", +# exp_path=temp_dir, +# launcher_name="local", +# ) +# +# # Small delay to ensure different timestamps +# time.sleep(0.001) +# +# lmb2 = LaunchedManifestBuilder( +# exp_name="test_exp2", +# exp_path=temp_dir, +# launcher_name="local", +# ) +# +# # Timestamps should be different +# assert lmb1._launch_timestamp != lmb2._launch_timestamp +# assert lmb1.run_metadata_subdirectory != lmb2.run_metadata_subdirectory +# +# def test_same_instance_consistent_timestamps(self): +# """Test that the same instance always returns consistent timestamps""" +# with tempfile.TemporaryDirectory() as temp_dir: +# lmb = LaunchedManifestBuilder( +# exp_name="test_exp", +# exp_path=temp_dir, +# launcher_name="local", +# ) +# +# # Multiple calls should return the same timestamp +# timestamp1 = lmb._launch_timestamp +# timestamp2 = lmb._launch_timestamp +# assert timestamp1 == timestamp2 +# +# # Multiple calls to run_metadata_subdirectory should be consistent +# run_dir1 = lmb.run_metadata_subdirectory +# run_dir2 = lmb.run_metadata_subdirectory +# assert run_dir1 == run_dir2 +# +# def test_exp_path_with_pathlib(self): +# """Test that metadata directories work correctly when exp_path is a pathlib.Path""" +# with tempfile.TemporaryDirectory() as temp_dir: +# exp_path = pathlib.Path(temp_dir) +# lmb = LaunchedManifestBuilder( +# exp_name="test_exp", +# exp_path=str(exp_path), # LaunchedManifestBuilder expects string +# launcher_name="local", +# ) +# +# expected_exp_metadata = exp_path / CONFIG.metadata_subdir +# assert lmb.exp_metadata_subdirectory == expected_exp_metadata +# +# def test_metadata_paths_are_pathlib_paths(self): +# """Test that all metadata directory methods return pathlib.Path objects""" +# with tempfile.TemporaryDirectory() as temp_dir: +# lmb = LaunchedManifestBuilder( +# exp_name="test_exp", +# exp_path=temp_dir, +# launcher_name="local", +# ) +# +# assert isinstance(lmb.exp_metadata_subdirectory, pathlib.Path) +# assert isinstance(lmb.run_metadata_subdirectory, pathlib.Path) +# assert isinstance( +# lmb.get_entity_metadata_subdirectory("model"), pathlib.Path +# ) diff --git a/tests/test_orchestrator.py b/tests/test_orchestrator.py index 0770ab17e..7e992f3ad 100644 --- a/tests/test_orchestrator.py +++ b/tests/test_orchestrator.py @@ -88,7 +88,7 @@ def test_orc_is_active_functions( def test_multiple_interfaces( - test_dir: str, wlmutils: t.Type["conftest.WLMUtils"] + test_dir: str, wlmutils: type["conftest.WLMUtils"] ) -> None: exp_name = "test_multiple_interfaces" exp = Experiment(exp_name, launcher="local", exp_path=test_dir) @@ -136,7 +136,7 @@ def test_catch_local_db_errors() -> None: ##### PBS ###### -def test_pbs_set_run_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: +def test_pbs_set_run_arg(wlmutils: type["conftest.WLMUtils"]) -> None: orc = Orchestrator( wlmutils.get_test_port(), db_nodes=3, @@ -155,7 +155,7 @@ def test_pbs_set_run_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: ) -def test_pbs_set_batch_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: +def test_pbs_set_batch_arg(wlmutils: type["conftest.WLMUtils"]) -> None: orc = Orchestrator( wlmutils.get_test_port(), db_nodes=3, @@ -184,7 +184,7 @@ def test_pbs_set_batch_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: ##### Slurm ###### -def test_slurm_set_run_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: +def test_slurm_set_run_arg(wlmutils: type["conftest.WLMUtils"]) -> None: orc = Orchestrator( wlmutils.get_test_port(), db_nodes=3, @@ -199,7 +199,7 @@ def test_slurm_set_run_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: ) -def test_slurm_set_batch_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: +def test_slurm_set_batch_arg(wlmutils: type["conftest.WLMUtils"]) -> None: orc = Orchestrator( wlmutils.get_test_port(), db_nodes=3, diff --git a/tests/test_preview.py b/tests/test_preview.py index 4dbe4d8b4..91b26cf7a 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -60,7 +60,7 @@ def _choose_host(wlmutils, index: int = 0): @pytest.fixture -def preview_object(test_dir) -> t.Dict[str, Job]: +def preview_object(test_dir) -> dict[str, Job]: """ Bare bones orch """ @@ -72,12 +72,12 @@ def preview_object(test_dir) -> t.Dict[str, Job]: s.ports = [1235] s.num_shards = 1 job = Job("faux-name", "faux-step-id", s, "slurm", True) - active_dbjobs: t.Dict[str, Job] = {"mock_job": job} + active_dbjobs: dict[str, Job] = {"mock_job": job} return active_dbjobs @pytest.fixture -def preview_object_multidb(test_dir) -> t.Dict[str, Job]: +def preview_object_multidb(test_dir) -> dict[str, Job]: """ Bare bones orch """ @@ -99,7 +99,7 @@ def preview_object_multidb(test_dir) -> t.Dict[str, Job]: s2.num_shards = 1 job2 = Job("faux-name_2", "faux-step-id_2", s2, "slurm", True) - active_dbjobs: t.Dict[str, Job] = {"mock_job": job, "mock_job2": job2} + active_dbjobs: dict[str, Job] = {"mock_job": job, "mock_job2": job2} return active_dbjobs