diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ce4e2149c1..7a4257936b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,11 @@ jobs: strategy: matrix: runs-on: ["ubuntu-latest", "windows-latest"] # can add macos-latest - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + exclude: + # duckdb (tiled dep) has no Windows wheels for 3.14 yet + - runs-on: "windows-latest" + python-version: "3.14" include: # Include one that runs in the dev environment - runs-on: "ubuntu-latest" diff --git a/docs/api_changes.rst b/docs/api_changes.rst index 2365e06856..238bfa80f9 100644 --- a/docs/api_changes.rst +++ b/docs/api_changes.rst @@ -2,6 +2,48 @@ Release History ================= +v1.15.0 (2026-04-15) +==================== + +Added +----- + +- Added support and testing for Python 3.13 +- ``event_model`` versions to ``RE.md`` +- More flexible addresses for ZMQ callbacks +- Improved numpy sanitization for ``Msg`` tracing + +Fixed +----- + +- Bug where ``SIGINT`` counting had a data race on very rapid presses, causing unreliable pausing behavior +- Data saved from ``read_configuration()`` is now cached per stream, fixing subtle cache invalidation issues +- ``TypeError`` on ``np.round`` in the ``%wa`` Bluesky magic with a multi-axis ``PsuedoPositioner`` coming from Numpy 2.0 change +- Subtle bug in suspenders based on which thread the Ophyd subscription was run in +- Handle empty motors in ``RunStart`` document in ``BestEffortCallback`` + +Changed +------- + +- Dropped support and testing for Python 3.9 (EOL in 2025-10) +- ``SIGINT`` pause/interrupt behavior now requires 100ms between signal arrival to count toward a hard-pause or a ``KeyboardInterrupt`` + +v1.14.6 (2025-10-08) +==================== + +Added +----- + +Fixed +----- + +- Error when using ``bps.wait`` with a timeout that actually triggered + +Changed +------- + +- Remove the ``'streams'`` namespace 
(container) from the container structure created by ``TiledWriter`` + v1.14.5 (2025-10-03) ==================== diff --git a/docs/hardware.rst b/docs/hardware.rst index 7adaf0178d..73ba309fcb 100644 --- a/docs/hardware.rst +++ b/docs/hardware.rst @@ -167,7 +167,7 @@ While a Datum will be yielded to specify a single frame of data in a Resource: .. autoclass:: bluesky.protocols.Datum :members: -.. seealso:: https://blueskyproject.io/event-model/external.html +.. seealso:: https://blueskyproject.io/event-model/main/explanations/external.html Movable (or "Settable") Device diff --git a/docs/metadata.rst b/docs/metadata.rst index a859a49a9b..33be1a4fbb 100644 --- a/docs/metadata.rst +++ b/docs/metadata.rst @@ -15,7 +15,7 @@ The same exact information can be "data" in one experiment, but "metadata" in a different experiment done on the exact same hardware. The `Document Model -`_ provides a framework +`_ provides a framework for deciding _where_ to record a particular piece of information. There are some things that we know *a priori* before doing an experiment; @@ -23,7 +23,7 @@ where are we? who is the user? what sample are we looking at? what did the user just ask us to do? These are all things that we can, in principle, know independent of the control system. These are the prime candidates for inclusion in the `Start Document -`_. +`_. Downstream DataBroker provides tools to do rich searches on this data. The more information you can include the better. @@ -50,7 +50,7 @@ A third class of information that can be called "metadata" is configuration information of pieces of hardware. These are things like the velocity of a motor or the integration time of a detector. These readings are embedded in the `Descriptor -`_ +`_ and are extracted from the hardware via the `read_configuration `_ method of the hardware. 
We expect that these values will not change over diff --git a/docs/plans.rst b/docs/plans.rst index 3caecc528d..c2440968c7 100644 --- a/docs/plans.rst +++ b/docs/plans.rst @@ -480,6 +480,7 @@ Plans for interacting with hardware: trigger read rd + locate stage unstage configure diff --git a/docs/tutorial.rst b/docs/tutorial.rst index 064a7ef719..d93ece7cbe 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -12,14 +12,14 @@ Before You Begin software, and you can skip the rest of this section. Go to `https://try.nsls2.bnl.gov `_. -* You will need Python 3.9 or newer. From a shell ("Terminal" on OSX, +* You will need Python 3.10 or newer. From a shell ("Terminal" on OSX, "Command Prompt" on Windows), check your current Python version. .. code-block:: bash python3 --version - If that version is less than 3.9, you must update it. + If that version is less than 3.10, you must update it. We recommend install bluesky into a "virtual environment" so this installation will not interfere with any existing Python software: @@ -35,7 +35,7 @@ Before You Begin .. code-block:: bash - conda create -n bluesky-tutorial "python>=3.9" + conda create -n bluesky-tutorial "python>=3.10" conda activate bluesky-tutorial * Install the latest versions of bluesky and ophyd. Also install the databroker diff --git a/pyproject.toml b/pyproject.toml index 27d8fadfd9..b4ea7fac5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,10 +7,10 @@ name = "bluesky" classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] description = "Experiment specification & orchestration." 
dependencies = [ @@ -28,11 +28,12 @@ dependencies = [ dynamic = ["version"] license.file = "LICENSE" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" [project.optional-dependencies] dev = [ "attrs", + "bluesky-tiled-plugins", "cloudpickle", "copier", "coverage", @@ -46,11 +47,13 @@ dev = [ "lmfit", "matplotlib >=3.5.0", "mongoquery", + "mongomock", "multiprocess", "mypy", "myst-parser", "networkx", "numpydoc", + "opentelemetry-sdk", "ophyd", "orjson", "packaging", @@ -59,8 +62,7 @@ dev = [ "pipdeptree", "pre-commit", "pydata-sphinx-theme>=0.12", - "pyepics<=3.5.2;python_version<'3.9'", # Needed to pass CI/CD tests; To be removed once we drop support for py3.8 - "pyepics;python_version>='3.9'", + "pyepics", "pyqt5", "pytest", "pytest-cov", @@ -167,7 +169,10 @@ select = [ "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up ] -ignore = ["UP031"] # Ignore %-format strings until #1848 is done +ignore = [ + "UP031", # Ignore %-format strings until #1848 is done + "B905", # adding strict=False to every zip call is just noise +] [tool.ruff.lint.per-file-ignores] # By default, private member access is allowed in tests @@ -206,4 +211,4 @@ commands = type-checking: mypy src {posargs} tests: pytest --cov=bluesky --cov-report term --cov-report xml:cov.xml {posargs} docs: sphinx-{posargs:build -EW --keep-going} -T docs build/html -""" \ No newline at end of file +""" diff --git a/src/bluesky/__main__.py b/src/bluesky/__main__.py index e149748384..1336f9731d 100644 --- a/src/bluesky/__main__.py +++ b/src/bluesky/__main__.py @@ -2,14 +2,13 @@ from argparse import ArgumentParser from collections.abc import Sequence -from typing import Union from . 
import __version__ __all__ = ["main"] -def main(args: Union[Sequence[str], None] = None) -> None: +def main(args: Sequence[str] | None = None) -> None: """Argument parser for the CLI.""" parser = ArgumentParser() diff --git a/src/bluesky/bundlers.py b/src/bluesky/bundlers.py index 0f72f6b495..df923c7dd7 100644 --- a/src/bluesky/bundlers.py +++ b/src/bluesky/bundlers.py @@ -2,10 +2,11 @@ import inspect import time as ttime from collections import defaultdict, deque -from collections.abc import Iterable +from collections.abc import Callable, Iterable +from dataclasses import dataclass, field from itertools import combinations from logging import LoggerAdapter -from typing import Any, Callable, Literal, Optional, Union, cast +from typing import Any, Literal, TypeAlias, TypeGuard, cast from event_model import ( ComposeDescriptorBundle, @@ -54,12 +55,12 @@ ) ObjDict = dict[Any, dict[str, T]] -ExternalAssetDoc = Union[Datum, Resource, StreamDatum, StreamResource] +ExternalAssetDoc: TypeAlias = Datum | Resource | StreamDatum | StreamResource def _describe_collect_dict_is_valid( - describe_collect_dict: Union[Any, dict[str, Any]], -) -> bool: # TODO: change to TypeGuard[dict[str, DataKey]] after python 3.9 + describe_collect_dict: Any | dict[str, Any], +) -> TypeGuard[dict[str, DataKey]]: """ Check if the describe_collect dictionary contains valid DataKeys. 
""" @@ -71,10 +72,68 @@ def _describe_collect_dict_is_valid( ) +@dataclass +class _StreamCache: + """Helper data class that holds all cached data related to devices for a partiuclar stream.""" + + # Readable + # cache of obj.read() in one Event + read_cache: deque[dict[str, Reading]] = field(default_factory=deque) + # cache of all obj.describe() output + describe_cache: ObjDict[DataKey] = field(default_factory=dict) + + # Configurable + # obj.describe_configuration() + config_desc_cache: ObjDict[DataKey] = field(default_factory=dict) + # obj.read_configuration() values + config_values_cache: ObjDict[Any] = field(default_factory=dict) + # obj.read_configuration() timestamps + config_ts_cache: ObjDict[Any] = field(default_factory=dict) + + async def ensure_cached(self, obj, collect=False): + """Cache objects Readable and Configurable methods. Cache Collectable methods if collect is True.""" + coros = [] + if not collect and obj not in self.describe_cache: + coros.append(self._cache_describe(obj)) + + if obj not in self.config_desc_cache: + coros.append(self._cache_describe_config(obj)) + coros.append(self.cache_read_config(obj)) + await asyncio.gather(*coros) + + async def _cache_describe(self, obj): + "Read the object's describe and cache it." + obj = check_supports(obj, Readable) + self.describe_cache[obj] = await maybe_await(obj.describe()) + + async def _cache_describe_config(self, obj): + "Read the object's describe_configuration and cache it." + if isinstance(obj, Configurable): + conf_keys = await maybe_await(obj.describe_configuration()) + else: + conf_keys = {} + + self.config_desc_cache[obj] = conf_keys + + async def cache_read_config(self, obj): + "Read the object's configuration and cache it." 
+ if isinstance(obj, Configurable): + conf = await maybe_await(obj.read_configuration()) + else: + conf = {} + config_values = {} + config_ts = {} + for key, val in conf.items(): + config_values[key] = val["value"] + config_ts[key] = val["timestamp"] + self.config_values_cache[obj] = config_values + self.config_ts_cache[obj] = config_ts + + class RunBundler: def __init__( self, - md: Optional[dict], + md: dict | None, record_interruptions: bool, emit: Callable, emit_sync: Callable, @@ -89,14 +148,15 @@ def __init__( self._bundle_name = None # name given to event descriptor self._run_start_uid = None # The (future) runstart uid self._objs_read: deque[HasName] = deque() # objects read in one Event - self._read_cache: deque[dict[str, Reading]] = deque() # cache of obj.read() in one Event - self._asset_docs_cache: deque[Union[Asset, StreamAsset]] = deque() # cache of obj.collect_asset_docs() - self._describe_cache: ObjDict[DataKey] = dict() # cache of all obj.describe() output # noqa: C408 - self._describe_collect_cache: dict[Any, Union[dict[str, DataKey], dict[str, dict[str, DataKey]]]] = dict() # noqa: C408 # cache of all obj.describe() output - - self._config_desc_cache: ObjDict[DataKey] = dict() # " obj.describe_configuration() # noqa: C408 - self._config_values_cache: ObjDict[Any] = dict() # " obj.read_configuration() values # noqa: C408 - self._config_ts_cache: ObjDict[Any] = dict() # " obj.read_configuration() timestamps # noqa: C408 + + # Collectable + # cache of all obj.describe_collect() output + self._describe_collect_cache: dict[Any, dict[str, DataKey] | dict[str, dict[str, DataKey]]] = dict() # noqa: C408 + self._saved_stream_cache: dict[str, _StreamCache] = dict() # noqa: C408 + self._current_stream_cache: _StreamCache = _StreamCache() + + self._asset_docs_cache: deque[Asset | StreamAsset] = deque() # cache of obj.collect_asset_docs() + # cache of {name: (doc, compose_event, compose_event_page)} self._descriptors: dict[Any, ComposeDescriptorBundle] = 
dict() # noqa: C408 self._descriptor_objs: dict[str, dict[HasName, dict[str, DataKey]]] = dict() # noqa: C408 @@ -135,6 +195,11 @@ async def open_run(self, msg: Msg): self._compose_stop = run.compose_stop self._compose_stream_resource = run.compose_stream_resource + self._current_stream_cache = _StreamCache() + self._saved_stream_cache.clear() + + self._describe_collect_cache.clear() + await self.emit(DocumentNames.start, doc) doc_logger.debug( "[start] document is emitted (run_uid=%r)", @@ -219,9 +284,9 @@ async def _prepare_stream( dks[key]["object_name"] = obj.name data_keys.update(dks) config[obj.name] = { - "data": self._config_values_cache[obj], - "timestamps": self._config_ts_cache[obj], - "data_keys": self._config_desc_cache[obj], + "data": self._current_stream_cache.config_values_cache[obj], + "timestamps": self._current_stream_cache.config_ts_cache[obj], + "data_keys": self._current_stream_cache.config_desc_cache[obj], } self._descriptors[desc_key] = self._compose_descriptor( @@ -250,17 +315,6 @@ async def _prepare_stream( list(objs_dks), ) - async def _ensure_cached(self, obj, collect=False): - coros = [] - if not collect and obj not in self._describe_cache: - coros.append(self._cache_describe(obj)) - elif collect and obj not in self._describe_collect_cache: - coros.append(self._cache_describe_collect(obj)) - if obj not in self._config_desc_cache: - coros.append(self._cache_describe_config(obj)) - coros.append(self._cache_read_config(obj)) - await asyncio.gather(*coros) - async def declare_stream(self, msg): """Generate and emit an EventDescriptor.""" command, no_obj, objs, kwargs, _ = msg @@ -272,6 +326,8 @@ async def declare_stream(self, msg): objs = frozenset(objs) objs_dks = {} # {collect_object: stream_data_keys} + self._set_current_stream_cache(stream_name) + await asyncio.gather(*[self._ensure_cached(obj, collect=collect) for obj in objs]) for obj in objs: if collect: @@ -286,7 +342,7 @@ async def declare_stream(self, msg): f"not return a single 
Dict[str, DataKey] for the passed in {stream_name}" ) else: - data_keys = self._describe_cache[obj] + data_keys = self._current_stream_cache.describe_cache[obj] objs_dks[obj] = data_keys @@ -295,6 +351,14 @@ async def declare_stream(self, msg): return await self._prepare_stream(stream_name, objs_dks) + def _set_current_stream_cache(self, stream_name: str): + if stream_name in self._saved_stream_cache: + self._current_stream_cache = self._saved_stream_cache[stream_name] + self._current_stream_cache.read_cache.clear() + else: + self._current_stream_cache = _StreamCache() + self._saved_stream_cache[stream_name] = self._current_stream_cache + async def create(self, msg): """ Start bundling future obj.read() calls for an Event document. @@ -317,7 +381,6 @@ async def create(self, msg): "bundle is closed with a 'save' or " "'drop' message." ) - self._read_cache.clear() self._asset_docs_cache.clear() self._objs_read.clear() self.bundling = True @@ -336,6 +399,8 @@ async def create(self, msg): if self._bundle_name not in self._descriptors: raise IllegalMessageSequence("In strict mode you must pre-declare streams.") + self._set_current_stream_cache(self._bundle_name) + async def read(self, msg, reading): """ Add a reading to the open event bundle. @@ -352,14 +417,14 @@ async def read(self, msg, reading): # on the same device you make obj.describe() calls multiple times. # As this is harmless and not an expected use case, we don't guard # against it. Reading multiple devices concurrently works fine. 
- await self._ensure_cached(obj) + await self._current_stream_cache.ensure_cached(obj) # check that current read collides with nothing else in # current event - cur_keys = set(self._describe_cache[obj].keys()) + cur_keys = set(self._current_stream_cache.describe_cache[obj].keys()) for read_obj in self._objs_read: # that is, field names - known_keys = self._describe_cache[read_obj].keys() + known_keys = self._current_stream_cache.describe_cache[read_obj].keys() if set(known_keys) & cur_keys: raise ValueError( f"Data keys (field names) from {obj!r} " @@ -372,7 +437,7 @@ async def read(self, msg, reading): # Stash the results, which will be emitted the next time _save is # called --- or never emitted if _drop is called instead. - self._read_cache.append(reading) + self._current_stream_cache.read_cache.append(reading) # Ask the object for any resource or datum documents is has cached # and cache them as well. Likewise, these will be emitted if and # when _save is called. @@ -381,35 +446,6 @@ async def read(self, msg, reading): return reading - async def _cache_describe(self, obj): - "Read the object's describe and cache it." - obj = check_supports(obj, Readable) - self._describe_cache[obj] = await maybe_await(obj.describe()) - - async def _cache_describe_config(self, obj): - "Read the object's describe_configuration and cache it." - - if isinstance(obj, Configurable): - conf_keys = await maybe_await(obj.describe_configuration()) - else: - conf_keys = {} - - self._config_desc_cache[obj] = conf_keys - - async def _cache_read_config(self, obj): - "Read the object's configuration and cache it." 
- if isinstance(obj, Configurable): - conf = await maybe_await(obj.read_configuration()) - else: - conf = {} - config_values = {} - config_ts = {} - for key, val in conf.items(): - config_values[key] = val["value"] - config_ts[key] = val["timestamp"] - self._config_values_cache[obj] = config_values - self._config_ts_cache[obj] = config_ts - async def monitor(self, msg): """ Monitor a signal. Emit event documents asynchronously. @@ -434,12 +470,12 @@ async def monitor(self, msg): if obj in self._monitor_params: raise IllegalMessageSequence(f"A 'monitor' message was sent for {obj} which is already monitored") - await self._ensure_cached(obj) + await self._current_stream_cache.ensure_cached(obj) - stream_bundle = await self._prepare_stream(name, {obj: self._describe_cache[obj]}) + stream_bundle = await self._prepare_stream(name, {obj: self._current_stream_cache.describe_cache[obj]}) compose_event = stream_bundle[1] - def emit_event(readings: Optional[dict[str, Reading]] = None, *args, **kwargs): + def emit_event(readings: dict[str, Reading] | None = None, *args, **kwargs): if readings is not None: # We were passed something we can use, but check no args or kwargs assert not args and not kwargs, ( @@ -561,8 +597,8 @@ async def save(self, msg): if descriptor_doc is None or d_objs is None: # use the dequeue not the set to preserve order for obj in self._objs_read: - await self._ensure_cached(obj, collect=isinstance(obj, Collectable)) - objs_dks[obj] = self._describe_cache[obj] + await self._current_stream_cache.ensure_cached(obj, collect=isinstance(obj, Collectable)) + objs_dks[obj] = self._current_stream_cache.describe_cache[obj] descriptor_doc, compose_event, d_objs = await self._prepare_stream(desc_key, objs_dks) @@ -580,7 +616,7 @@ async def save(self, msg): ) # Merge list of readings into single dict. 
- readings = {k: v for d in self._read_cache for k, v in d.items()} + readings = {k: v for d in self._current_stream_cache.read_cache for k, v in d.items()} data, timestamps = _rearrange_into_parallel_dicts(readings) # Mark all externally-stored data as not filled so that consumers # know that the corresponding data are identifiers, not dereferenced @@ -672,8 +708,8 @@ async def kickoff(self, msg): def _format_datakeys_with_stream_name( self, - describe_collect_dict: Union[dict[str, DataKey], dict[str, dict[str, DataKey]]], - message_stream_name: Optional[str] = None, + describe_collect_dict: dict[str, DataKey] | dict[str, dict[str, DataKey]], + message_stream_name: str | None = None, ) -> list[tuple[str, dict[str, DataKey]]]: """ Check if the dictionary returned by describe collect is a dict @@ -684,22 +720,20 @@ def _format_datakeys_with_stream_name( """ def _contains_message_stream_name( - describe_collect_dict: Union[Any, dict[str, Any]], - ) -> bool: # TODO: change to TypeGuard[dict[str, dict[str, DataKey]]] after python 3.9 + describe_collect_dict: Any | dict[str, Any], + ) -> TypeGuard[dict[str, dict[str, DataKey]]]: return isinstance(describe_collect_dict, dict) and all( _describe_collect_dict_is_valid(v) for v in describe_collect_dict.values() ) if describe_collect_dict: if _describe_collect_dict_is_valid(describe_collect_dict): - # TODO: remove cast after python 3.9 is no longer supported - flat_describe_collect_dict = cast(dict[str, DataKey], describe_collect_dict) + flat_describe_collect_dict = describe_collect_dict return [(message_stream_name or "primary", flat_describe_collect_dict)] # Validate that all of the values nested values are DataKeys elif _contains_message_stream_name(describe_collect_dict): # We have Dict[str, Dict[str, DataKey]] so return its items - # TODO: remove cast after python 3.9 is no longer supported - nested_describe_collect_dict = cast(dict[str, dict[str, DataKey]], describe_collect_dict) + nested_describe_collect_dict = 
describe_collect_dict if message_stream_name and list(nested_describe_collect_dict) != [message_stream_name]: # The collect contained a name and describe_collect returned a Dict[str, Dict[str, DataKey]], # this is only acceptable if the only key in the parent dict is message_stream_name @@ -717,12 +751,6 @@ def _contains_message_stream_name( # Empty dict, could be either but we don't care return [] - async def _cache_describe_collect(self, obj: Collectable): - "Read the object's describe and cache it." - obj = check_supports(obj, Collectable) - c: Union[dict[str, DataKey], dict[str, dict[str, DataKey]]] = await maybe_await(obj.describe_collect()) - self._describe_collect_cache[obj] = c - async def _describe_collect(self, collect_object: Flyable): """Read an object's describe_collect and cache it. @@ -746,7 +774,7 @@ async def _describe_collect(self, collect_object: Flyable): } """ - await self._ensure_cached(collect_object, collect=True) + await self._cache_describe_collect(collect_object) describe_collect = self._describe_collect_cache[collect_object] describe_collect_items = self._format_datakeys_with_stream_name(describe_collect) @@ -779,6 +807,12 @@ def is_data_key(obj: Any) -> bool: ) for stream_name, stream_data_keys in describe_collect_items: + # Will this be okay left at last stream? + self._set_current_stream_cache(stream_name) + # Should we make the describe_cache the describe_collect_cache? + self._current_stream_cache.describe_cache[collect_object] = stream_data_keys + await self._current_stream_cache.ensure_cached(collect_object, collect=True) + if stream_name not in self._descriptor_objs or ( collect_object not in self._descriptor_objs[stream_name] ): @@ -825,9 +859,8 @@ async def _pack_seq_nums_into_stream_datum( return indices_difference - # message strem name here? 
async def _pack_external_assets( - self, asset_docs: Iterable[Union[Asset, StreamAsset]], message_stream_name: Optional[str] + self, asset_docs: Iterable[Asset | StreamAsset], message_stream_name: str | None ): """Packs some external asset documents with relevant information from the run.""" @@ -920,7 +953,7 @@ async def _collect_events( collect_obj: EventCollectable, local_descriptors, return_payload: bool, - message_stream_name: Optional[str], + message_stream_name: str | None, ): payload = [] pages: dict[frozenset[str], list[Event]] = defaultdict(list) @@ -969,7 +1002,7 @@ async def _collect_event_pages( collect_obj: EventPageCollectable, local_descriptors, return_payload: bool, - message_stream_name: Optional[str], + message_stream_name: str | None, ): payload = [] @@ -1015,7 +1048,7 @@ async def collect(self, msg: Msg): Where there must be at least one collect object. If multiple are used they must obey the WritesStreamAssets protocol. """ - stream_name: Optional[str] = None + stream_name: str | None = None if not self.run_is_open: # sanity check -- 'kickoff' should catch this and make this @@ -1057,7 +1090,7 @@ async def collect(self, msg: Msg): self._uncollected.discard(obj) # Get the provided message stream name for singly nested scans - message_stream_name: Optional[str] = msg.kwargs.get("name", None) + message_stream_name: str | None = msg.kwargs.get("name", None) # Retrive the stream names from pre-declared streams declared_stream_names = self._declared_stream_names.get(frozenset(collect_objects), []) @@ -1173,7 +1206,7 @@ async def configure(self, msg): object.configure(*args, **kwargs) """ obj = msg.obj - await self._cache_read_config(obj) + await self._current_stream_cache.cache_read_config(obj) # Invalidate any event descriptors that include this object. # New event descriptors, with this new configuration, will # be created for any future event documents. 
@@ -1183,3 +1216,16 @@ async def configure(self, msg): del self._descriptors[name] await self._prepare_stream(name, obj_set) continue + + async def _cache_describe_collect(self, obj): + "Read the object's describe_collect and cache it." + if obj not in self._describe_collect_cache: + obj = check_supports(obj, Collectable) + c: dict[str, DataKey] | dict[str, dict[str, DataKey]] = await maybe_await(obj.describe_collect()) + self._describe_collect_cache[obj] = c + + async def _ensure_cached(self, obj, collect: bool = False): + coros = [self._current_stream_cache.ensure_cached(obj, collect)] + if collect: + coros.append(self._cache_describe_collect(obj)) + await asyncio.gather(*coros) diff --git a/src/bluesky/callbacks/best_effort.py b/src/bluesky/callbacks/best_effort.py index e68b900f42..e10c24f09e 100644 --- a/src/bluesky/callbacks/best_effort.py +++ b/src/bluesky/callbacks/best_effort.py @@ -113,7 +113,7 @@ def start(self, doc): # Prepare a guess about the dimensions (independent variables) in case # we need it. 
motors = self._start_doc.get("motors") - if motors is not None: + if motors is not None and len(motors) > 0: GUESS = [([motor], "primary") for motor in motors] else: GUESS = [(["time"], "primary")] diff --git a/src/bluesky/callbacks/buffer.py b/src/bluesky/callbacks/buffer.py index 3820aa5c7a..e5bdd0bcd7 100644 --- a/src/bluesky/callbacks/buffer.py +++ b/src/bluesky/callbacks/buffer.py @@ -1,8 +1,8 @@ import atexit import logging import threading +from collections.abc import Callable from queue import Empty, Full, Queue -from typing import Callable logger = logging.getLogger(__name__) diff --git a/src/bluesky/callbacks/json_writer.py b/src/bluesky/callbacks/json_writer.py index 22d24d3640..df7f5a92b4 100644 --- a/src/bluesky/callbacks/json_writer.py +++ b/src/bluesky/callbacks/json_writer.py @@ -1,7 +1,6 @@ import json from datetime import datetime from pathlib import Path -from typing import Optional class JSONWriter: @@ -15,7 +14,7 @@ class JSONWriter: def __init__( self, dirname: str, - filename: Optional[str] = None, + filename: str | None = None, ): self.dirname = Path(dirname) self.filename = filename @@ -45,7 +44,7 @@ class JSONLinesWriter: If the file already exists, new documents will be appended to it. 
""" - def __init__(self, dirname: str, filename: Optional[str] = None): + def __init__(self, dirname: str, filename: str | None = None): self.dirname = Path(dirname) self.filename = filename diff --git a/src/bluesky/callbacks/tiled_writer.py b/src/bluesky/callbacks/tiled_writer.py index 91a81bd2b8..e5c64f0859 100644 --- a/src/bluesky/callbacks/tiled_writer.py +++ b/src/bluesky/callbacks/tiled_writer.py @@ -2,8 +2,9 @@ import itertools import logging from collections import defaultdict, deque, namedtuple +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable, Optional, Union, cast +from typing import Any, cast from warnings import warn import pyarrow @@ -180,8 +181,8 @@ class RunNormalizer(CallbackBase): def __init__( self, - patches: Optional[dict[str, Callable]] = None, - spec_to_mimetype: Optional[dict[str, str]] = None, + patches: dict[str, Callable] | None = None, + spec_to_mimetype: dict[str, str] | None = None, ): self._token_refs: dict[str, Callable] = {} self.dispatcher = Dispatcher() @@ -199,7 +200,7 @@ def __init__( self._int_keys: set[str] = set() # Names of internal data_keys self._ext_keys: set[str] = set() - def _convert_resource_to_stream_resource(self, doc: Union[Resource, StreamResource]) -> StreamResource: + def _convert_resource_to_stream_resource(self, doc: Resource | StreamResource) -> StreamResource: """Make changes to and return a shallow copy of StreamRsource dictionary adhering to the new structure. 
Kept for back-compatibility with old StreamResource schema from event_model<1.20.0 @@ -241,7 +242,7 @@ def _convert_resource_to_stream_resource(self, doc: Union[Resource, StreamResour def _convert_datum_to_stream_datum( self, datum_doc: Datum, data_key: str, desc_uid: str, seq_num: int - ) -> tuple[Optional[StreamResource], StreamDatum]: + ) -> tuple[StreamResource | None, StreamDatum]: """Convert the Datum document to the StreamDatum format This conversion requires (and is triggered when) the Event document is received. The function also returns @@ -511,7 +512,7 @@ class _RunWriter(CallbackBase): def __init__(self, client: BaseClient, batch_size: int = BATCH_SIZE): self.client = client - self.root_node: Union[None, Container] = None + self.root_node: None | Container = None self._desc_nodes: dict[str, Container] = {} # references to the descriptor nodes by their uid's and names self._sres_nodes: dict[str, BaseClient] = {} self._internal_tables: dict[str, DataFrameClient] = {} # references to the internal tables by desc_names @@ -657,7 +658,7 @@ def event_page(self, doc: EventPage): def stream_resource(self, doc: StreamResource): self._stream_resource_cache[doc["uid"]] = doc - def get_sres_node(self, sres_uid: str, desc_uid: Optional[str] = None) -> tuple[BaseClient, ConsolidatorBase]: + def get_sres_node(self, sres_uid: str, desc_uid: str | None = None) -> tuple[BaseClient, ConsolidatorBase]: """Get the Tiled node and the associate Consolidator corresponding to the data_key in StreamResource If the node does not exist, register it from a cached StreamResource document. 
Keep a reference to the @@ -771,10 +772,10 @@ def __init__( self, client: BaseClient, *, - normalizer: Optional[type[CallbackBase]] = RunNormalizer, - patches: Optional[dict[str, Callable]] = None, - spec_to_mimetype: Optional[dict[str, str]] = None, - backup_directory: Optional[str] = None, + normalizer: type[CallbackBase] | None = RunNormalizer, + patches: dict[str, Callable] | None = None, + spec_to_mimetype: dict[str, str] | None = None, + backup_directory: str | None = None, batch_size: int = BATCH_SIZE, ): self.client = client.include_data_sources() @@ -805,10 +806,10 @@ def from_uri( cls, uri, *, - normalizer: Optional[type[CallbackBase]] = RunNormalizer, - patches: Optional[dict[str, Callable]] = None, - spec_to_mimetype: Optional[dict[str, str]] = None, - backup_directory: Optional[str] = None, + normalizer: type[CallbackBase] | None = RunNormalizer, + patches: dict[str, Callable] | None = None, + spec_to_mimetype: dict[str, str] | None = None, + backup_directory: str | None = None, batch_size: int = BATCH_SIZE, **kwargs, ): @@ -827,10 +828,10 @@ def from_profile( cls, profile, *, - normalizer: Optional[type[CallbackBase]] = RunNormalizer, - patches: Optional[dict[str, Callable]] = None, - spec_to_mimetype: Optional[dict[str, str]] = None, - backup_directory: Optional[str] = None, + normalizer: type[CallbackBase] | None = RunNormalizer, + patches: dict[str, Callable] | None = None, + spec_to_mimetype: dict[str, str] | None = None, + backup_directory: str | None = None, batch_size: int = BATCH_SIZE, **kwargs, ): diff --git a/src/bluesky/callbacks/zmq.py b/src/bluesky/callbacks/zmq.py index 50868675cc..0f3749fe0b 100644 --- a/src/bluesky/callbacks/zmq.py +++ b/src/bluesky/callbacks/zmq.py @@ -20,6 +20,36 @@ from ..run_engine import Dispatcher, DocumentNames +def _normalize_address(inp: str | tuple | int): + if isinstance(inp, str): + if "://" in inp: + protocol, _, rest_str = inp.partition("://") + else: + protocol = "tcp" + rest_str = inp + elif 
isinstance(inp, tuple):
+        if inp[0] in ["tcp", "ipc"]:
+            protocol, *rest = inp
+        else:
+            protocol = "tcp"
+            rest = list(inp)
+        if protocol == "tcp":
+            if len(rest) == 2:
+                rest_str = ":".join(str(r) for r in rest)
+            else:
+                (rest_str,) = rest
+        else:
+            (rest_str,) = rest
+    elif isinstance(inp, int):
+        protocol = "tcp"
+        rest_str = f"0.0.0.0:{inp}"
+
+    else:
+        raise TypeError(f"Input expected to be int, str, or tuple, not {type(inp)}")
+
+    return f"{protocol}://{rest_str}"
+
+
 class Bluesky0MQDecodeError(Exception):
     """Custom exception class for things that go wrong reading message from wire."""
 
@@ -73,15 +103,14 @@ def __init__(self, address, *, prefix=b"", RE=None, zmq=None, serializer=pickle.
             raise ValueError(f"prefix {prefix!r} may not contain b' '")
         if zmq is None:
             import zmq
-        if isinstance(address, str):
-            address = address.split(":", maxsplit=1)
-        self.address = (address[0], int(address[1]))
+
+        self.address = _normalize_address(address)
         self.RE = RE
-        url = "tcp://%s:%d" % self.address
+
         self._prefix = bytes(prefix)
         self._context = zmq.Context()
         self._socket = self._context.socket(zmq.PUB)
-        self._socket.connect(url)
+        self._socket.connect(self.address)
         if RE:
             self._subscription_token = RE.subscribe(self)
         self._serializer = serializer
@@ -102,23 +131,40 @@ class Proxy:
     """
     Start a 0MQ proxy on the local host.
 
+    The addresses can be specified flexibly. It is best to use
+    a domain socket (available on unix):
+
+    - ``'ipc:///tmp/domain_socket'``
+    - ``('ipc', '/tmp/domain_socket')``
+
+    tcp sockets are also supported:
+
+    - ``'tcp://localhost:6557'``
+    - ``6657`` (implicitly binds to ``'tcp://0.0.0.0:6657'``)
+    - ``('tcp', 'localhost', 6657)``
+    - ``('localhost', 6657)``
+
     Parameters
     ----------
-    in_port : int, optional
-        Port that RunEngines should broadcast to. If None, a random port is
-        used.
-    out_port : int, optional
-        Port that subscribers should subscribe to. If None, a random port is
-        used.
+ in_address : str or tuple or int, optional + Address that RunEngines should broadcast to. + + If None, a random tcp port on all interfaces is used. + + out_address : str or tuple or int, optional + Address that subscribers should subscribe to. + + If None, a random tcp port on all interfaces is used. + zmq : object, optional By default, the 'zmq' module is imported and used. Anything else mocking its interface is accepted. Attributes ---------- - in_port : int + in_address: int or str or tuple Port that RunEngines should broadcast to. - out_port : int + out_address : int or str or tuple Port that subscribers should subscribe to. closed : boolean True if the Proxy has already been started and subsequently @@ -146,7 +192,42 @@ class Proxy: >>> proxy.start() # runs until interrupted """ - def __init__(self, in_port=None, out_port=None, *, zmq=None): + def __init__( + self, + in_address=None, + out_address=None, + *, + zmq=None, + in_port=None, + out_port=None, + ): + # Handle backward compatibility for in_port -> in_address + if in_port is not None and in_address is not None: + raise ValueError("Cannot specify both 'in_port' and 'in_address'. Use 'in_address' only.") + if in_port is not None: + warnings.warn( + "The 'in_port' parameter is deprecated and will be removed in a future release. " + "Use 'in_address' instead.", + DeprecationWarning, + stacklevel=2, + ) + in_address = in_port + + # Handle backward compatibility for out_port -> out_address + if out_port is not None and out_address is not None: + raise ValueError("Cannot specify both 'out_port' and 'out_address'. Use 'out_address' only.") + if out_port is not None: + warnings.warn( + "The 'out_port' parameter is deprecated and will be removed in a future release. 
" + "Use 'out_address' instead.", + DeprecationWarning, + stacklevel=2, + ) + out_address = out_port + + # Delete deprecated parameter names + del in_port, out_port + if zmq is None: import zmq self.zmq = zmq @@ -155,19 +236,22 @@ def __init__(self, in_port=None, out_port=None, *, zmq=None): context = zmq.Context(1) # Socket facing clients frontend = context.socket(zmq.SUB) - if in_port is None: - in_port = frontend.bind_to_random_port("tcp://*") + if in_address is None: + in_bind_result = frontend.bind_to_random_port("tcp://*") else: - frontend.bind("tcp://*:%d" % in_port) + in_address = _normalize_address(in_address) + in_bind_result = frontend.bind(in_address) frontend.setsockopt_string(zmq.SUBSCRIBE, "") # Socket facing services backend = context.socket(zmq.PUB) - if out_port is None: - out_port = backend.bind_to_random_port("tcp://*") + if out_address is None: + out_bind_result = backend.bind_to_random_port("tcp://*") else: - backend.bind("tcp://*:%d" % out_port) + out_address = _normalize_address(out_address) + out_bind_result = backend.bind(out_address) + except BaseException: # Clean up whichever components we have defined so far. 
try: @@ -181,8 +265,12 @@ def __init__(self, in_port=None, out_port=None, *, zmq=None): context.destroy() raise else: - self.in_port = in_port - self.out_port = out_port + self.in_port = ( + in_bind_result.addr if hasattr(in_bind_result, "addr") else _normalize_address(in_bind_result) + ) + self.out_port = ( + out_bind_result.addr if hasattr(out_bind_result, "addr") else _normalize_address(out_bind_result) + ) self._frontend = frontend self._backend = backend self._context = context @@ -257,10 +345,8 @@ def __init__( import zmq if zmq_asyncio is None: import zmq.asyncio as zmq_asyncio - if isinstance(address, str): - address = address.split(":", maxsplit=1) self._deserializer = deserializer - self.address = (address[0], int(address[1])) + self.address = _normalize_address(address) if loop is None: loop = asyncio.new_event_loop() @@ -274,8 +360,7 @@ def __finish_setup(): self._context = zmq_asyncio.Context() self._socket = self._context.socket(zmq.SUB) - url = "tcp://%s:%d" % self.address - self._socket.connect(url) + self._socket.connect(self.address) self._socket.setsockopt_string(zmq.SUBSCRIBE, "") self.__factory = __finish_setup diff --git a/src/bluesky/commandline/zmq_proxy.py b/src/bluesky/commandline/zmq_proxy.py index 40d66ea284..4425d63e6f 100644 --- a/src/bluesky/commandline/zmq_proxy.py +++ b/src/bluesky/commandline/zmq_proxy.py @@ -1,13 +1,14 @@ import argparse import logging import threading +import warnings from bluesky.callbacks.zmq import Proxy, RemoteDispatcher logger = logging.getLogger("bluesky") -def start_dispatcher(host, port, logfile=None): +def start_dispatcher(out_address, logfile=None): """The dispatcher function Parameters ---------- @@ -15,7 +16,7 @@ def start_dispatcher(host, port, logfile=None): string come from user command. ex --logfile=temp.log logfile will be "temp.log". logfile could be empty. 
""" - dispatcher = RemoteDispatcher((host, port)) + dispatcher = RemoteDispatcher(out_address) if logfile is not None: raise ValueError( "Parameter 'logfile' is deprecated and will be removed in future releases. " @@ -41,8 +42,15 @@ def log_writer(name, doc): def main(): DESC = "Start a 0MQ proxy for publishing bluesky documents over a network." parser = argparse.ArgumentParser(description=DESC) - parser.add_argument("in_port", type=int, nargs=1, help="port that RunEngines should broadcast to") - parser.add_argument("out_port", type=int, nargs=1, help="port that subscribers should subscribe to") + + # New optional arguments (preferred) + parser.add_argument("--in-address", dest="in_address_opt", help="port that RunEngines should broadcast to") + parser.add_argument("--out-address", dest="out_address_opt", help="port that subscribers should subscribe to") + + # Old positional arguments (deprecated, for backward compatibility) + parser.add_argument("in_port", type=int, nargs="?", help=argparse.SUPPRESS) + parser.add_argument("out_port", type=int, nargs="?", help=argparse.SUPPRESS) + parser.add_argument( "--verbose", "-v", @@ -51,9 +59,52 @@ def main(): ) parser.add_argument("--logfile", type=str, help="Redirect logging output to a file on disk.") args = parser.parse_args() - in_port = args.in_port[0] - out_port = args.out_port[0] + # Handle backward compatibility + in_address = None + out_address = None + + # Check if old positional arguments were used + if args.in_port is not None or args.out_port is not None: + # Validate that both are provided if using positional arguments + if args.in_port is None or args.out_port is None: + raise ValueError( + "Both in_port and out_port positional arguments must be provided together. 
" + "Consider using the new optional arguments instead: " + "--in-address and --out-address" + ) + + if args.in_address_opt is not None or args.out_address_opt is not None: + raise ValueError( + "Cannot mix positional arguments (in_port, out_port) with optional arguments " + "(--in-address, --out-address)." + ) + + warnings.warn( + "Using positional arguments for in_port and out_port is deprecated. " + "Use --in-address and --out-address instead.", + FutureWarning, + stacklevel=1, + ) + in_address = args.in_port + out_address = args.out_port + else: + # Use new optional arguments + in_address = args.in_address_opt + out_address = args.out_address_opt + + print("Connecting...") + try: + in_address = int(in_address) + except (ValueError, TypeError): + pass + try: + out_address = int(out_address) + except (ValueError, TypeError): + pass + + proxy = Proxy(in_address, out_address) + print("Receiving on address %s; publishing to address %s." % (proxy.in_port, proxy.out_port)) if args.verbose: from bluesky.log import config_bluesky_logging @@ -64,11 +115,8 @@ def main(): else: config_bluesky_logging(level=level) # Set daemon to kill all threads upon IPython exit - threading.Thread(target=start_dispatcher, args=("localhost", out_port), daemon=True).start() + threading.Thread(target=start_dispatcher, args=(proxy.out_port,), daemon=True).start() - print("Connecting...") - proxy = Proxy(in_port, out_port) - print("Receiving on port %d; publishing to port %d." 
% (in_port, out_port)) print("Use Ctrl+C to exit.") try: proxy.start() diff --git a/src/bluesky/consolidators.py b/src/bluesky/consolidators.py index b2d33dabed..d767ab8ba4 100644 --- a/src/bluesky/consolidators.py +++ b/src/bluesky/consolidators.py @@ -4,7 +4,7 @@ import os import re import warnings -from typing import Any, Literal, Optional, Union, cast +from typing import Any, Literal, cast import numpy as np from event_model.documents import EventDescriptor, StreamDatum, StreamResource @@ -35,17 +35,17 @@ class Management(str, enum.Enum): class Asset: data_uri: str is_directory: bool - parameter: Optional[str] - num: Optional[int] = None - id: Optional[int] = None + parameter: str | None + num: int | None = None + id: int | None = None @dataclasses.dataclass class DataSource: structure_family: StructureFamily structure: Any - id: Optional[int] = None - mimetype: Optional[str] = None + id: int | None = None + mimetype: str | None = None parameters: dict = dataclasses.field(default_factory=dict) assets: list[Asset] = dataclasses.field(default_factory=list) management: Management = Management.writable @@ -119,7 +119,7 @@ def __init__(self, stream_resource: StreamResource, descriptor: EventDescriptor) # TODO: Check consistency with chunk_shape # Determine the machine data type - self.data_type: Union[BuiltinDtype, StructDtype] + self.data_type: BuiltinDtype | StructDtype dtype_numpy = np.dtype(data_desc.get("dtype_numpy")) # Falls back to np.dtype("float64") if not set if dtype_numpy.kind == "V": self.data_type = StructDtype.from_numpy_dtype(dtype_numpy) diff --git a/src/bluesky/log.py b/src/bluesky/log.py index 7fed6f6bc4..d23fb022ed 100644 --- a/src/bluesky/log.py +++ b/src/bluesky/log.py @@ -3,7 +3,6 @@ import logging import sys from types import ModuleType -from typing import Optional try: import colorama @@ -11,7 +10,7 @@ colorama.init() except ImportError: colorama = None -curses: Optional[ModuleType] +curses: ModuleType | None try: import curses except 
ImportError: diff --git a/src/bluesky/magics.py b/src/bluesky/magics.py index bf14861f9b..fa86d0cecd 100644 --- a/src/bluesky/magics.py +++ b/src/bluesky/magics.py @@ -236,6 +236,19 @@ def is_positioner(dev): return hasattr(dev, "position") +def _round(value, decimals): + """Round value; return str if np.round yields an ndarray. + + PseudoPositioner positions and limits are namedtuples; np.round + converts them to ndarray, which rejects string format specs under + numpy 2.x. + """ + result = np.round(value, decimals=decimals) + if isinstance(result, np.ndarray): + return str(result) + return result + + def _print_positioners(positioners, sort=True, precision=6, prefix=""): """ This will take a list of positioners and try to print them. @@ -272,20 +285,20 @@ def _print_positioners(positioners, sort=True, precision=6, prefix=""): prec = int(p.precision) except Exception: prec = precision - value = np.round(v, decimals=prec) + value = _round(v, decimals=prec) try: low_limit, high_limit = p.limits except Exception as exc: low_limit = high_limit = exc.__class__.__name__ else: - low_limit = np.round(low_limit, decimals=prec) - high_limit = np.round(high_limit, decimals=prec) + low_limit = _round(low_limit, decimals=prec) + high_limit = _round(high_limit, decimals=prec) try: offset = p.user_offset.get() except Exception as exc: offset = exc.__class__.__name__ else: - offset = np.round(offset, decimals=prec) + offset = _round(offset, decimals=prec) else: value = v.__class__.__name__ # e.g. 
'DisconnectedError' low_limit = high_limit = offset = "" diff --git a/src/bluesky/plan_stubs.py b/src/bluesky/plan_stubs.py index c779dc3ace..706b7f3f87 100644 --- a/src/bluesky/plan_stubs.py +++ b/src/bluesky/plan_stubs.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Awaitable, Callable, Hashable, Iterable, Mapping, Sequence from functools import reduce -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from cycler import cycler @@ -193,7 +193,7 @@ def locate(*objs, squeeze=True): @plan -def monitor(obj: Readable, *, name: Optional[str] = None, **kwargs) -> MsgGenerator: +def monitor(obj: Readable, *, name: str | None = None, **kwargs) -> MsgGenerator: """ Asynchronously monitor for new values and emit Event documents. @@ -257,7 +257,7 @@ def null() -> MsgGenerator: def abs_set( obj: Movable, *args: Any, - group: Optional[Hashable] = None, + group: Hashable | None = None, wait: bool = False, **kwargs, ) -> MsgGenerator[Status]: @@ -305,7 +305,7 @@ def abs_set( def rel_set( obj: Movable, *args: Any, - group: Optional[Hashable] = None, + group: Hashable | None = None, wait: bool = False, **kwargs, ) -> MsgGenerator[Status]: @@ -349,9 +349,9 @@ def rel_set( # is not currently able to be represented in python's type system @plan def mv( - *args: Union[Movable, Any], - group: Optional[Hashable] = None, - timeout: Optional[float] = None, + *args: Movable | Any, + group: Hashable | None = None, + timeout: float | None = None, **kwargs, ) -> MsgGenerator[tuple[Status, ...]]: """ @@ -401,7 +401,7 @@ def mv( @plan def mvr( - *args: Union[Movable, Any], group: Optional[Hashable] = None, timeout: Optional[float] = None, **kwargs + *args: Movable | Any, group: Hashable | None = None, timeout: float | None = None, **kwargs ) -> MsgGenerator[tuple[Status, ...]]: """ Move one or more devices to a relative setpoint. Wait for all to complete. 
@@ -572,7 +572,7 @@ def stop(obj: Stoppable) -> MsgGenerator: def trigger( obj: Triggerable, *, - group: Optional[Hashable] = None, + group: Hashable | None = None, wait: bool = False, ) -> MsgGenerator[Status]: """ @@ -627,9 +627,9 @@ def sleep(time: float) -> MsgGenerator: @plan def wait( - group: Optional[Hashable] = None, + group: Hashable | None = None, *, - timeout: Optional[float] = None, + timeout: float | None = None, error_on_timeout: bool = True, watch: Sequence[str] = (), ): @@ -755,7 +755,7 @@ def input_plan(prompt: str = "") -> MsgGenerator[str]: @plan -def prepare(obj: Preparable, *args, group: Optional[Hashable] = None, wait: bool = False, **kwargs): +def prepare(obj: Preparable, *args, group: Hashable | None = None, wait: bool = False, **kwargs): """ Prepare a device ready for trigger or kickoff. @@ -792,7 +792,7 @@ def prepare(obj: Preparable, *args, group: Optional[Hashable] = None, wait: bool def kickoff( obj: Flyable, *, - group: Optional[Hashable] = None, + group: Hashable | None = None, wait: bool = False, **kwargs, ) -> MsgGenerator[Status]: @@ -834,7 +834,7 @@ def kickoff( @plan -def kickoff_all(*args, group: Optional[Hashable] = None, wait: bool = True, **kwargs): +def kickoff_all(*args, group: Hashable | None = None, wait: bool = True, **kwargs): """ Kickoff one or more fly-scanning devices. @@ -878,7 +878,7 @@ def kickoff_all(*args, group: Optional[Hashable] = None, wait: bool = True, **kw def complete( obj: Flyable, *, - group: Optional[Hashable] = None, + group: Hashable | None = None, wait: bool = False, **kwargs, ) -> MsgGenerator[Status]: @@ -927,7 +927,7 @@ def complete( @plan -def complete_all(*args, group: Optional[Hashable] = None, wait: bool = False, **kwargs): +def complete_all(*args, group: Hashable | None = None, wait: bool = False, **kwargs): """ Tell one or more flyable objects, 'stop collecting, whenever you are ready'. 
@@ -975,7 +975,7 @@ def complete_all(*args, group: Optional[Hashable] = None, wait: bool = False, ** @plan def collect( - obj: Flyable, *args, stream: bool = False, return_payload: bool = True, name: Optional[str] = None + obj: Flyable, *args, stream: bool = False, return_payload: bool = True, name: str | None = None ) -> MsgGenerator[list[PartialEvent]]: """ Collect data cached by one or more fly-scanning devices and emit documents. @@ -1081,9 +1081,9 @@ def configure( def stage( obj: Stageable, *, - group: Optional[Hashable] = None, - wait: Optional[bool] = None, -) -> MsgGenerator[Union[Status, list[Any]]]: + group: Hashable | None = None, + wait: bool | None = None, +) -> MsgGenerator[Status | list[Any]]: """ 'Stage' a device (i.e., prepare it for use, 'arm' it). @@ -1129,7 +1129,7 @@ def stage( @plan def stage_all( *args: Stageable, - group: Optional[Hashable] = None, + group: Hashable | None = None, ) -> MsgGenerator[None]: """ 'Stage' one or more devices (i.e., prepare them for use, 'arm' them). @@ -1166,9 +1166,9 @@ def stage_all( def unstage( obj: Stageable, *, - group: Optional[Hashable] = None, - wait: Optional[bool] = None, -) -> MsgGenerator[Union[Status, list[Any]]]: + group: Hashable | None = None, + wait: bool | None = None, +) -> MsgGenerator[Status | list[Any]]: """ 'Unstage' a device (i.e., put it in standby, 'disarm' it). @@ -1212,7 +1212,7 @@ def unstage( @plan -def unstage_all(*args: Stageable, group: Optional[Hashable] = None) -> MsgGenerator[None]: +def unstage_all(*args: Stageable, group: Hashable | None = None) -> MsgGenerator[None]: """ 'Unstage' one or more devices (i.e., put them in standby, 'disarm' them). @@ -1340,7 +1340,7 @@ def remove_suspender(suspender: SuspenderBase) -> MsgGenerator: @plan -def open_run(md: Optional[CustomPlanMetadata] = None) -> MsgGenerator[str]: +def open_run(md: CustomPlanMetadata | None = None) -> MsgGenerator[str]: """ Mark the beginning of a new 'run'. Emit a RunStart document. 
@@ -1367,7 +1367,7 @@ def open_run(md: Optional[CustomPlanMetadata] = None) -> MsgGenerator[str]: @plan -def close_run(exit_status: Optional[str] = None, reason: Optional[str] = None) -> MsgGenerator[str]: +def close_run(exit_status: str | None = None, reason: str | None = None) -> MsgGenerator[str]: """ Mark the end of the current 'run'. Emit a RunStop document. @@ -1522,7 +1522,7 @@ def broadcast_msg( @plan def repeater( - n: Optional[int], + n: int | None, gen_func: Callable[..., MsgGenerator], *args, **kwargs, @@ -1560,7 +1560,7 @@ def repeater( @plan -def caching_repeater(n: Optional[int], plan: MsgGenerator) -> MsgGenerator[None]: +def caching_repeater(n: int | None, plan: MsgGenerator) -> MsgGenerator[None]: """ Generate n chained copies of the messages in a plan. @@ -1594,7 +1594,7 @@ def caching_repeater(n: Optional[int], plan: MsgGenerator) -> MsgGenerator[None] @plan -def one_shot(detectors: Sequence[Readable], take_reading: Optional[TakeReading] = None) -> MsgGenerator[None]: +def one_shot(detectors: Sequence[Readable], take_reading: TakeReading | None = None) -> MsgGenerator[None]: """Inner loop of a count. This is the default function for ``per_shot`` in count plans. 
@@ -1628,7 +1628,7 @@ def one_1d_step( detectors: Sequence[Readable], motor: Movable, step: Any, - take_reading: Optional[TakeReading] = None, + take_reading: TakeReading | None = None, ) -> MsgGenerator[Mapping[str, Reading]]: """ Inner loop of a 1D step scan @@ -1708,7 +1708,7 @@ def one_nd_step( detectors: Sequence[Readable], step: Mapping[Movable, Any], pos_cache: dict[Movable, Any], - take_reading: Optional[TakeReading] = None, + take_reading: TakeReading | None = None, ) -> MsgGenerator[None]: """ Inner loop of an N-dimensional step scan @@ -1746,7 +1746,7 @@ def take_reading(dets, name='primary'): @plan def repeat( plan: Callable[[], MsgGenerator], - num: Optional[int] = 1, + num: int | None = 1, delay: ScalarOrIterableFloat = 0.0, ) -> MsgGenerator[Any]: """ diff --git a/src/bluesky/plans.py b/src/bluesky/plans.py index 9cc197f44b..ad5afafd04 100644 --- a/src/bluesky/plans.py +++ b/src/bluesky/plans.py @@ -4,10 +4,10 @@ import sys import time from collections import defaultdict -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial from itertools import chain, zip_longest -from typing import Any, Callable, Optional, Union +from typing import Any, TypeAlias import numpy as np from cycler import Cycler @@ -31,11 +31,11 @@ ) #: Plan function that can be used for each shot in a detector acquisition involving no actuation -PerShot = Callable[[Sequence[Readable], Optional[bps.TakeReading]], MsgGenerator] +PerShot = Callable[[Sequence[Readable], bps.TakeReading | None], MsgGenerator] #: Plan function that can be used for each step in a scan PerStep1D = Callable[ - [Sequence[Readable], Movable, Any, Optional[bps.TakeReading]], + [Sequence[Readable], Movable, Any, bps.TakeReading | None], MsgGenerator, ] PerStepND = Callable[ @@ -43,11 +43,11 @@ Sequence[Readable], Mapping[Movable, Any], dict[Movable, Any], - Optional[bps.TakeReading], + bps.TakeReading | None, ], 
MsgGenerator, ] -PerStep = Union[PerStep1D, PerStepND] +PerStep: TypeAlias = PerStep1D | PerStepND def _check_detectors_type_input(detectors): @@ -65,11 +65,11 @@ def derive_default_hints(motors: list[Any]) -> dict[str, Sequence]: def count( detectors: Sequence[Readable], - num: Optional[int] = 1, + num: int | None = 1, delay: ScalarOrIterableFloat = 0.0, *, - per_shot: Optional[PerShot] = None, - md: Optional[CustomPlanMetadata] = None, + per_shot: PerShot | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Take one or more readings from detectors. @@ -131,9 +131,9 @@ def inner_count() -> MsgGenerator[str]: def list_scan( detectors: Sequence[Readable], - *args: tuple[Union[Movable, Any], list[Any]], - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + *args: tuple[Movable | Any, list[Any]], + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over one or more variables in steps simultaneously (inner product). @@ -224,9 +224,9 @@ def list_scan( def rel_list_scan( detectors: Sequence[Readable], - *args: Union[Movable, Any], - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + *args: Movable | Any, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over one variable in steps relative to current position. @@ -281,10 +281,10 @@ def inner_relative_list_scan(): def list_grid_scan( detectors: Sequence[Readable], - *args: Union[Movable, Any], + *args: Movable | Any, snake_axes: bool = False, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over a mesh; each motor is on an independent trajectory. 
@@ -358,10 +358,10 @@ def list_grid_scan( def rel_list_grid_scan( detectors: Sequence[Readable], - *args: Union[Movable, Any], + *args: Movable | Any, snake_axes: bool = False, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over a mesh; each motor is on an independent trajectory. Each point is @@ -425,8 +425,8 @@ def _scan_1d( stop: float, num: int, *, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over one variable in equally spaced steps. @@ -502,8 +502,8 @@ def _rel_scan_1d( stop: float, num: int, *, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over one variable in equally spaced steps relative to current positon. @@ -630,8 +630,8 @@ def rel_log_scan( stop: float, num: int, *, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over one variable in log-spaced steps relative to current position. @@ -680,9 +680,9 @@ def adaptive_scan( max_step: float, target_delta: float, backstep: bool, - threshold: Optional[float] = 0.8, + threshold: float | None = 0.8, *, - md: Optional[CustomPlanMetadata] = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over one variable with adaptively tuned step size. 
@@ -809,9 +809,9 @@ def rel_adaptive_scan( max_step: float, target_delta: float, backstep: bool, - threshold: Optional[float] = 0.8, + threshold: float | None = 0.8, *, - md: Optional[CustomPlanMetadata] = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Relative scan over one variable with adaptively tuned step size. @@ -881,7 +881,7 @@ def tune_centroid( step_factor: float = 3.0, snake: bool = False, *, - md: Optional[CustomPlanMetadata] = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: r""" plan: tune a motor to the centroid of signal(motor) @@ -1027,8 +1027,8 @@ def scan_nd( detectors: Sequence[Readable], cycler: Cycler, *, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over an arbitrary N-dimensional trajectory. @@ -1171,9 +1171,9 @@ def inner_scan_nd(): def inner_product_scan( detectors: Sequence[Readable], num: int, - *args: Union[Movable, Any], - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + *args: Movable | Any, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[None]: # For scan, num is the _last_ positional arg instead of the first one. # Notice the swapped order here. @@ -1184,10 +1184,10 @@ def inner_product_scan( def scan( detectors: Sequence[Readable], - *args: Union[Movable, Any], - num: Optional[int] = None, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + *args: Movable | Any, + num: int | None = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over one multi-motor trajectory. 
@@ -1294,9 +1294,9 @@ def scan( def grid_scan( detectors: Sequence[Readable], *args, - snake_axes: Optional[Union[Iterable, bool]] = None, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + snake_axes: Iterable | bool | None = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over a mesh; each motor is on an independent trajectory. @@ -1470,10 +1470,10 @@ def _set_snaking(chunk, value): def rel_grid_scan( detectors: Sequence[Readable], - *args: Union[Movable, Any], - snake_axes: Optional[Union[Iterable, bool]] = None, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + *args: Movable | Any, + snake_axes: Iterable | bool | None = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over a mesh relative to current position. @@ -1530,9 +1530,9 @@ def inner_rel_grid_scan(): def relative_inner_product_scan( # type: ignore detectors: Sequence[Readable], num: int, - *args: Union[Movable, Any], - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + *args: Movable | Any, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: # For rel_scan, num is the _last_ positional arg instead of the first one. # Notice the swapped order here. @@ -1543,10 +1543,10 @@ def relative_inner_product_scan( # type: ignore def rel_scan( detectors: Sequence[Readable], - *args: Union[Movable, Any], + *args: Movable | Any, num=None, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Scan over one multi-motor trajectory relative to current position. 
@@ -1602,7 +1602,7 @@ def tweak( motor: NamedMovable, step: float, *, - md: Optional[CustomPlanMetadata] = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Move and motor and read a detector with an interactive prompt. @@ -1693,10 +1693,10 @@ def spiral_fermat( dr: float, factor: float, *, - dr_y: Optional[float] = None, - tilt: Optional[float] = 0.0, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + dr_y: float | None = None, + tilt: float | None = 0.0, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """Absolute fermat spiral scan, centered around (x_start, y_start) @@ -1797,10 +1797,10 @@ def rel_spiral_fermat( dr: float, factor: float, *, - dr_y: Optional[float] = None, - tilt: Optional[float] = 0.0, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + dr_y: float | None = None, + tilt: float | None = 0.0, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """Relative fermat spiral scan @@ -1876,10 +1876,10 @@ def spiral( dr: float, nth: float, *, - dr_y: Optional[float] = None, - tilt: Optional[float] = 0.0, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + dr_y: float | None = None, + tilt: float | None = 0.0, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """Spiral scan, centered around (x_start, y_start) @@ -1978,10 +1978,10 @@ def rel_spiral( dr: float, nth: float, *, - dr_y: Optional[float] = None, + dr_y: float | None = None, tilt: float = 0.0, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """Relative spiral scan @@ -2054,8 +2054,8 @@ def spiral_square( x_num: float, y_num: float, *, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = 
None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """Absolute square spiral scan, centered around (x_center, y_center) @@ -2150,8 +2150,8 @@ def rel_spiral_square( x_num: float, y_num: float, *, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """Relative square spiral scan, centered around current (x, y) position. @@ -2216,9 +2216,9 @@ def ramp_plan( monitor_sig: Readable, inner_plan_func: Callable[[], MsgGenerator], take_pre_data: bool = True, - timeout: Optional[float] = None, - period: Optional[float] = None, - md: Optional[CustomPlanMetadata] = None, + timeout: float | None = None, + period: float | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """Take data while ramping one or more positioners. @@ -2305,7 +2305,7 @@ def polling_plan(): def fly( flyers: list[Flyable], *, - md: Optional[CustomPlanMetadata] = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Perform a fly scan with one or more 'flyers'. 
@@ -2346,8 +2346,8 @@ def x2x_scan( stop: float, num: int, *, - per_step: Optional[PerStep] = None, - md: Optional[CustomPlanMetadata] = None, + per_step: PerStep | None = None, + md: CustomPlanMetadata | None = None, ) -> MsgGenerator[str]: """ Relatively scan over two motors in a 2:1 ratio diff --git a/src/bluesky/protocols.py b/src/bluesky/protocols.py index bc207ee9ea..a249d019ae 100644 --- a/src/bluesky/protocols.py +++ b/src/bluesky/protocols.py @@ -1,16 +1,6 @@ from abc import abstractmethod -from collections.abc import AsyncIterator, Awaitable, Iterator -from typing import ( - Any, - Callable, - Generic, - Literal, - Optional, - Protocol, - TypeVar, - Union, - runtime_checkable, -) +from collections.abc import AsyncIterator, Awaitable, Callable, Iterator +from typing import Any, Generic, Literal, Protocol, TypeAlias, TypeVar, runtime_checkable from event_model.documents import Datum, StreamDatum, StreamResource from event_model.documents.event import PartialEvent @@ -52,20 +42,16 @@ class Reading(Generic[T], ReadingOptional): timestamp: float -Asset = Union[ - tuple[Literal["resource"], PartialResource], - tuple[Literal["datum"], Datum], -] +Asset: TypeAlias = tuple[Literal["resource"], PartialResource] | tuple[Literal["datum"], Datum] -StreamAsset = Union[ - tuple[Literal["stream_resource"], StreamResource], - tuple[Literal["stream_datum"], StreamDatum], -] +StreamAsset: TypeAlias = ( + tuple[Literal["stream_resource"], StreamResource] | tuple[Literal["stream_datum"], StreamDatum] +) -SyncOrAsync = Union[T, Awaitable[T]] -SyncOrAsyncIterator = Union[Iterator[T], AsyncIterator[T]] +SyncOrAsync: TypeAlias = T | Awaitable[T] +SyncOrAsyncIterator: TypeAlias = Iterator[T] | AsyncIterator[T] @runtime_checkable @@ -82,7 +68,7 @@ def add_callback(self, callback: Callable[["Status"], None]) -> None: ... @abstractmethod - def exception(self, timeout: Optional[float] = 0.0) -> Optional[BaseException]: ... 
+ def exception(self, timeout: float | None = 0.0) -> BaseException | None: ... @property @abstractmethod @@ -104,7 +90,7 @@ class HasName(Protocol): def name(self) -> str: """Used to populate object_keys in the Event DataKey - https://blueskyproject.io/event-model/event-descriptors.html#object-keys""" + https://blueskyproject.io/event-model/main/explanations/event-descriptors.html#object-keys""" ... @@ -112,7 +98,7 @@ def name(self) -> str: class HasParent(Protocol): @property @abstractmethod - def parent(self) -> Optional[Any]: + def parent(self) -> Any | None: """``None``, or a reference to a parent device. Used by the RE to stop duplicate stages. @@ -152,7 +138,7 @@ def collect_asset_docs(self) -> SyncOrAsyncIterator[Asset]: @runtime_checkable class WritesStreamAssets(Protocol): @abstractmethod - def collect_asset_docs(self, index: Optional[int] = None) -> SyncOrAsyncIterator[StreamAsset]: + def collect_asset_docs(self, index: int | None = None) -> SyncOrAsyncIterator[StreamAsset]: """Create the resource and datum documents describing data in external source up to a given index if provided. @@ -304,7 +290,7 @@ def describe(self) -> SyncOrAsync[dict[str, DataKey]]: @runtime_checkable class Collectable(HasName, Protocol): @abstractmethod - def describe_collect(self) -> SyncOrAsync[Union[dict[str, DataKey], dict[str, dict[str, DataKey]]]]: + def describe_collect(self) -> SyncOrAsync[dict[str, DataKey] | dict[str, dict[str, DataKey]]]: """This is like ``describe()`` on readable devices, but with an extra layer of nesting. Since a flyer can potentially return more than one event stream, this is either @@ -396,7 +382,7 @@ class Stageable(Protocol): # TODO: we were going to extend these to be able to return plans, what # signature should they have? @abstractmethod - def stage(self) -> Union[Status, list[Any]]: + def stage(self) -> Status | list[Any]: """An optional hook for "setting up" the device for acquisition. 
It should return a ``Status`` that is marked done when the device is @@ -405,7 +391,7 @@ def stage(self) -> Union[Status, list[Any]]: ... @abstractmethod - def unstage(self) -> Union[Status, list[Any]]: + def unstage(self) -> Status | list[Any]: """A hook for "cleaning up" the device after acquisition. It should return a ``Status`` that is marked done when the device is finished diff --git a/src/bluesky/run_engine.py b/src/bluesky/run_engine.py index 8a16844f16..daf9592936 100644 --- a/src/bluesky/run_engine.py +++ b/src/bluesky/run_engine.py @@ -9,6 +9,7 @@ import typing import weakref from collections import ChainMap, defaultdict, deque +from collections.abc import Callable from contextlib import ExitStack from dataclasses import dataclass from datetime import datetime @@ -17,6 +18,7 @@ from itertools import count from warnings import warn +import event_model from event_model import DocumentNames from opentelemetry import trace from opentelemetry.trace import Span @@ -62,13 +64,14 @@ Subscribers, ensure_generator, normalize_subs_input, + sanitize_np, single_gen, warn_if_msg_args_or_kwargs, ) _SPAN_NAME_PREFIX = "Bluesky RunEngine" -current_task: typing.Callable[[typing.Optional[asyncio.AbstractEventLoop]], typing.Optional[asyncio.Task]] +current_task: typing.Callable[[asyncio.AbstractEventLoop | None], asyncio.Task | None] try: from asyncio import current_task except ImportError: @@ -110,7 +113,7 @@ class RunEngineResult: exit_status: str interrupted: bool reason: str - exception: typing.Optional[Exception] + exception: Exception | None class RunEngineStateMachine(StateMachine): @@ -402,15 +405,15 @@ def deferred_pause_requested(self): def __init__( self, - md: typing.Optional[dict] = None, + md: dict | None = None, *, - loop: typing.Optional[asyncio.AbstractEventLoop] = None, - preprocessors: typing.Optional[list] = None, - context_managers: typing.Optional[list] = None, - md_validator: typing.Optional[typing.Callable] = None, - md_normalizer: 
typing.Optional[typing.Callable] = None, + loop: asyncio.AbstractEventLoop | None = None, + preprocessors: list | None = None, + context_managers: list | None = None, + md_validator: typing.Callable | None = None, + md_normalizer: typing.Callable | None = None, scan_id_source: typing.Callable[[dict], SyncOrAsync[int]] = default_scan_id_source, - during_task: typing.Optional[DuringTask] = None, + during_task: DuringTask | None = None, call_returns_result: bool = False, ): if loop is None: @@ -467,6 +470,7 @@ def setup_run_permit(): from ._version import __version__ self.md["versions"]["bluesky"] = __version__ + self.md["versions"]["event_model"] = event_model.__version__ if preprocessors is None: preprocessors = [] @@ -507,7 +511,9 @@ def setup_run_permit(): self._movable_objs_touched: set[typing.Any] = set() # objects we moved at any point self._run_start_uids: list[typing.Any] = list() # run start uids generated by __call__ # noqa: C408 self._suspenders: set[typing.Any] = set() # set holding suspenders - self._groups: defaultdict[typing.Any, set[typing.Any]] = defaultdict(set) # sets of Events to wait for + self._groups: defaultdict[str, set[Callable[[], asyncio.Future]]] = defaultdict( + set + ) # sets of Events to wait for self._status_objs: defaultdict[typing.Any, set[typing.Any]] = defaultdict( set ) # status objects to wait for @@ -866,10 +872,10 @@ def _create_result(self, plan_return): def __call__( self, plan: typing.Iterable[Msg], - subs: typing.Optional[Subscribers] = None, + subs: Subscribers | None = None, /, **metadata_kw: typing.Any, - ) -> typing.Union[RunEngineResult, tuple[str, ...]]: + ) -> RunEngineResult | tuple[str, ...]: """Execute a plan. 
Any keyword arguments will be interpreted as metadata and recorded with @@ -1908,8 +1914,8 @@ def _close_run_trace(self, msg: Msg): reason = msg.kwargs.get("reason", self._reason) try: _span: Span = self._run_tracing_spans.pop() - _span.set_attribute("exit_status", exit_status) - _span.set_attribute("reason", reason) + _span.set_attribute("exit_status", exit_status if exit_status is not None else "None") + _span.set_attribute("reason", reason if reason is not None else "None") _span.end() except IndexError: logger.warning("No open traces left to close!") @@ -2379,6 +2385,10 @@ def _status_object_completed(self, ret, fut: asyncio.Future, pardon_failures): except Exception as e: self._exception = e fut.set_exception(e) + # We have set the exception, but we don't mind if + # no-one collects it from the future, so fetch it ourselves to + # squash "Future exception was never retrieved" at teardown. + fut.exception() else: fut.set_result(None) @@ -2804,7 +2814,7 @@ def ignore_exceptions(self, val): def _set_span_msg_attributes(span, msg): span.set_attribute("msg.command", msg.command) - span.set_attribute("msg.args", msg.args) + span.set_attribute("msg.args", sanitize_np(msg.args)) span.set_attribute("msg.kwargs", json.dumps(msg.kwargs, default=repr)) span.set_attribute("msg.obj", repr(msg.obj)) if msg.obj else span.set_attribute("msg.no_obj_given", True) @@ -2866,7 +2876,7 @@ def in_bluesky_event_loop() -> bool: return loop is _bluesky_event_loop -def call_in_bluesky_event_loop(coro: typing.Awaitable[T], timeout: typing.Optional[float] = None) -> T: +def call_in_bluesky_event_loop(coro: typing.Awaitable[T], timeout: float | None = None) -> T: if _bluesky_event_loop is None or not _bluesky_event_loop.is_running(): # Quell "coroutine never awaited" warnings if iscoroutine(coro): diff --git a/src/bluesky/simulators.py b/src/bluesky/simulators.py index d09c14326a..ebbb10f244 100644 --- a/src/bluesky/simulators.py +++ b/src/bluesky/simulators.py @@ -1,12 +1,9 @@ -from 
collections.abc import Generator, Sequence +from collections.abc import Callable, Generator, Sequence from itertools import dropwhile from time import time from typing import ( Any, - Callable, Literal, - Optional, - Union, cast, ) from warnings import warn @@ -172,10 +169,10 @@ def add_handler_for_callback_subscribes(self): def add_handler( self, - commands: Union[str, Sequence[str]], + commands: str | Sequence[str], handler: Callable[[Msg], object], - msg_filter: Optional[Union[str, Callable[[Msg], bool]]] = None, - index: Union[int, Literal["end"]] = 0, + msg_filter: str | Callable[[Msg], bool] | None = None, + index: int | Literal["end"] = 0, ): """Add the specified handler for a particular message. @@ -203,17 +200,19 @@ def add_handler( self.message_handlers.insert( cast(int, index if index != END else len(self.message_handlers)), _MessageHandler( - lambda msg: msg.command in commands - and ( - msg_filter is None - or (callable(msg_filter) and msg_filter(msg)) - or (msg.obj and msg.obj.name == msg_filter) + lambda msg: ( + msg.command in commands + and ( + msg_filter is None + or (callable(msg_filter) and msg_filter(msg)) + or (msg.obj and msg.obj.name == msg_filter) + ) ), handler, ), ) - def add_read_handler_for(self, obj: Readable, value: Optional[Any]): + def add_read_handler_for(self, obj: Readable, value: Any | None): """ Convenience method to register a handler to return a result from a single-valued 'read' command. 
@@ -302,7 +301,7 @@ def add_wait_handler(self, handler: Callable[[Msg], None], group: str = GROUP_AN self.add_handler( "wait", handler, - lambda msg: (group == RunEngineSimulator.GROUP_ANY or msg.kwargs["group"] == group), + lambda msg: group == RunEngineSimulator.GROUP_ANY or msg.kwargs["group"] == group, ) def add_callback_handler_for( @@ -310,7 +309,7 @@ def add_callback_handler_for( command: str, document_name: str, document: dict, - msg_filter: Optional[Callable[[Msg], bool]] = None, + msg_filter: Callable[[Msg], bool] | None = None, ): """Add a handler to fire a callback when a matching command is encountered. Equivalent to add_callback_for_multiple(command, [[(document_name, document)]], msg_filter) @@ -331,7 +330,7 @@ def add_callback_handler_for_multiple( self, command: str, docs: Sequence[Sequence[tuple[str, dict]]], - msg_filter: Optional[Callable[[Msg], bool]] = None, + msg_filter: Callable[[Msg], bool] | None = None, ): """Add a handler to fire callbacks in sequence when a matching command is encountered. @@ -409,7 +408,7 @@ def _add_callback(self, msg_args): def assert_message_and_return_remaining( messages: list[Msg], predicate: Callable[[Msg], bool], - group: Optional[str] = None, + group: str | None = None, ): """Find the next message matching the predicate, assert that we found it. 
diff --git a/src/bluesky/suspenders.py b/src/bluesky/suspenders.py index d7bbe18924..d092169d00 100644 --- a/src/bluesky/suspenders.py +++ b/src/bluesky/suspenders.py @@ -138,7 +138,7 @@ def __call__(self, value, **kwargs): if self._ev is None and self.RE is not None: self.__make_event() if self._ev is None: - raise RuntimeError("Could not create the ") + raise RuntimeError("Could not create the suspender event") cb = partial( self.RE.request_suspend, self._ev.wait, @@ -156,15 +156,19 @@ def __make_event(self): """Make or return the asyncio.Event to use as a bridge.""" assert self._lock.locked() if self._ev is None and self.RE is not None: - th_ev = threading.Event() - - def really_make_the_event(): + if threading.get_ident() == getattr(self.RE._loop, "_thread_id", "unknown"): self._ev = asyncio.Event() - th_ev.set() + return self._ev + else: + th_ev = threading.Event() + + def really_make_the_event(): + self._ev = asyncio.Event() + th_ev.set() - h = self.RE._loop.call_soon_threadsafe(really_make_the_event) - if not th_ev.wait(0.1): - h.cancel() + h = self.RE._loop.call_soon_threadsafe(really_make_the_event) + if not th_ev.wait(0.1): + h.cancel() return self._ev def __set_event(self, loop): diff --git a/src/bluesky/tests/__init__.py b/src/bluesky/tests/__init__.py index 55bae520d3..c461ba2931 100644 --- a/src/bluesky/tests/__init__.py +++ b/src/bluesky/tests/__init__.py @@ -1,11 +1,10 @@ import os from types import ModuleType -from typing import Optional import pytest # some module level globals. 
-ophyd: Optional[ModuleType] +ophyd: ModuleType | None ophyd = None reason = "" diff --git a/src/bluesky/tests/conftest.py b/src/bluesky/tests/conftest.py index 68356adad8..f6bdc6c4f2 100644 --- a/src/bluesky/tests/conftest.py +++ b/src/bluesky/tests/conftest.py @@ -1,11 +1,15 @@ import asyncio import os +import signal +import threading +from unittest.mock import patch import numpy as np import packaging import pytest from bluesky.run_engine import RunEngine, TransitionError +from bluesky.utils import SigintHandler @pytest.fixture(scope="function", params=[False, True]) @@ -76,3 +80,64 @@ def cleanup_any_figures(request): "Close any matplotlib figures that were opened during a test." plt.close("all") + + +class DeterministicSigint: + """Sends SIGINT signals with a fake monotonic clock so that every signal + deterministically clears the 100ms debounce in SigintHandler. + + The fake clock advances by 0.2s per ``send()`` call, and each call blocks + until the signal handler has finished, so ``_count`` increments reliably + regardless of real wall-clock jitter. 
+ """ + + def __init__(self): + self._fake_time = 0.0 + self._handler_done = threading.Event() + self._pid = os.getpid() + self._orig_enter = SigintHandler.__enter__ + self._patcher = patch.object(SigintHandler, "__enter__", self._patched_enter) + + def _monotonic(self): + return self._fake_time + + def _patched_enter(self, sigint_handler): + with patch("bluesky.utils.time.monotonic", self._monotonic): + result = self._orig_enter(sigint_handler) + installed = signal.getsignal(signal.SIGINT) + + def synced_handler(signum, frame): + try: + with patch("bluesky.utils.time.monotonic", self._monotonic): + installed(signum, frame) + finally: + self._handler_done.set() + + signal.signal(signal.SIGINT, synced_handler) + return result + + def send(self): + """Send one SIGINT and wait for the handler to finish.""" + self._handler_done.clear() + self._fake_time += 0.2 + os.kill(self._pid, signal.SIGINT) + self._handler_done.wait() + + def __enter__(self): + self._patcher.start() + return self + + def __exit__(self, *exc): + self._patcher.stop() + + +@pytest.fixture +def deterministic_sigint(): + """Fixture providing the ``DeterministicSigint`` class. Tests should use + it as a context manager around the code that runs the RE:: + + with deterministic_sigint() as sigint: + ... 
+ sigint.send() + """ + return DeterministicSigint diff --git a/src/bluesky/tests/test_bec.py b/src/bluesky/tests/test_bec.py index af5dcb2c7b..ea1fbf5424 100644 --- a/src/bluesky/tests/test_bec.py +++ b/src/bluesky/tests/test_bec.py @@ -333,3 +333,17 @@ def test_many_motors(RE, hw): assert not bec._live_grids assert not bec._live_scatters assert bec._table is not None + + +def test_empty_motors(RE, hw): + def simple_count( + readables, + md, + ): + yield from bps.open_run(md=md) + yield from bps.trigger_and_read(readables) + + bec = BestEffortCallback() + RE.subscribe(bec) + # should not raise + RE(simple_count([hw.det], md={"motors": []})) diff --git a/src/bluesky/tests/test_event_loop_funcs.py b/src/bluesky/tests/test_event_loop_funcs.py index 0776b73cff..60561486a5 100644 --- a/src/bluesky/tests/test_event_loop_funcs.py +++ b/src/bluesky/tests/test_event_loop_funcs.py @@ -1,5 +1,4 @@ import asyncio -import sys import pytest @@ -24,12 +23,8 @@ async def check(): nonlocal event_loop event_loop = asyncio.get_running_loop() - if sys.version_info >= (3, 10): - # For some reason asyncio.run reuses the RE loop, then closes - # it at the end on python 3.9 and below, which makes test_examples.py - # fail, so skip this bit for that python - asyncio.run(check()) - assert event_loop and event_loop != RE._loop + asyncio.run(check()) + assert event_loop and event_loop != RE._loop call_in_bluesky_event_loop(check()) assert event_loop == RE._loop diff --git a/src/bluesky/tests/test_examples.py b/src/bluesky/tests/test_examples.py index 7b835bfa79..0804459846 100644 --- a/src/bluesky/tests/test_examples.py +++ b/src/bluesky/tests/test_examples.py @@ -385,7 +385,7 @@ def sim_kill(): assert RE.state == "idle" start = ttime.time() threading.Timer(0.1, sim_kill).start() - threading.Timer(0.2, sim_kill).start() + threading.Timer(0.25, sim_kill).start() threading.Timer(1, done).start() with pytest.raises(RunEngineInterrupted): @@ -429,7 +429,7 @@ def sim_kill(): assert RE.state == 
"idle" start = ttime.time() threading.Timer(0.1, sim_kill).start() - threading.Timer(0.2, sim_kill).start() + threading.Timer(0.25, sim_kill).start() threading.Timer(0.4, done).start() with pytest.raises(RunEngineInterrupted): RE(scan) diff --git a/src/bluesky/tests/test_external_assets_and_paging.py b/src/bluesky/tests/test_external_assets_and_paging.py index eb41afc17d..bab29cad4e 100644 --- a/src/bluesky/tests/test_external_assets_and_paging.py +++ b/src/bluesky/tests/test_external_assets_and_paging.py @@ -1,6 +1,5 @@ import re from collections.abc import Iterator -from typing import Optional import pytest from event_model.documents import Datum @@ -128,7 +127,7 @@ def get_index(self) -> int: return 10 -def collect_asset_docs_stream_datum(self: Named, index: Optional[int] = None) -> Iterator[StreamAsset]: +def collect_asset_docs_stream_datum(self: Named, index: int | None = None) -> Iterator[StreamAsset]: """Produce a StreamResource and StreamDatum for 2 data keys for 0:index""" index = index or 1 for data_key in [f"{self.name}-sd1", f"{self.name}-sd2"]: @@ -166,9 +165,9 @@ def describe_pv(self: Named) -> dict[str, DataKey]: return {f"{self.name}-pv": DataKey(source="pv", dtype="number", shape=[])} -def read_pv(self: Named) -> dict[str, Reading]: +def read_pv(self: Named, value=5.8) -> dict[str, Reading]: """Read a single data_key from a PV""" - return {f"{self.name}-pv": Reading(value=5.8, timestamp=123)} + return {f"{self.name}-pv": Reading(value=value, timestamp=123)} class PvAndDatumReadable(Named, Readable, WritesExternalAssets): diff --git a/src/bluesky/tests/test_flyer.py b/src/bluesky/tests/test_flyer.py index 1a72bfa756..d04f0d0844 100644 --- a/src/bluesky/tests/test_flyer.py +++ b/src/bluesky/tests/test_flyer.py @@ -352,8 +352,8 @@ def test_device_redundent_config_reading(RE): ) assert flyer.call_counts["collect"] == 4 assert flyer.call_counts["describe_collect"] == 1 - assert flyer.call_counts["read_configuration"] == 1 - assert 
flyer.call_counts["describe_configuration"] == 1 + assert flyer.call_counts["read_configuration"] == 2 + assert flyer.call_counts["describe_configuration"] == 2 assert flyer.call_counts["kickoff"] == 1 assert flyer.call_counts["complete"] == 1 diff --git a/src/bluesky/tests/test_magics.py b/src/bluesky/tests/test_magics.py index aa90d4b692..ffd9f9769b 100644 --- a/src/bluesky/tests/test_magics.py +++ b/src/bluesky/tests/test_magics.py @@ -3,13 +3,30 @@ from types import SimpleNamespace import pytest +from ophyd import Component as Cpt +from ophyd import PseudoPositioner, PseudoSingle, SoftPositioner import bluesky.plan_stubs as bps import bluesky.plans as bp -from bluesky.magics import BlueskyMagics +from bluesky.magics import BlueskyMagics, _print_positioners from bluesky.tests import uses_os_kill_sigint +class TwoAxisPseudo(PseudoPositioner): + """Minimal two-axis PseudoPositioner for testing.""" + + px = Cpt(PseudoSingle) + py = Cpt(PseudoSingle) + rx = Cpt(SoftPositioner, init_pos=0) + ry = Cpt(SoftPositioner, init_pos=0) + + def forward(self, pp): + return self.RealPosition(rx=pp.px, ry=pp.py) + + def inverse(self, rp): + return self.PseudoPosition(px=rp.rx, py=rp.ry) + + class FakeIPython: def __init__(self, user_ns): self.user_ns = user_ns @@ -169,3 +186,19 @@ def sim_kill(n=1): sm.RE.loop.call_later(1, sim_kill, 2) sm.mov("motor 1") assert sm.RE.state == "idle" + + +def test_print_positioners_pseudo_positioner(): + # PseudoPositioner.position is a namedtuple; np.round returns an ndarray + # under numpy 2.x, which raises TypeError with string format specs. + dev = TwoAxisPseudo(name="dev") + _print_positioners([dev]) # must not raise TypeError + + +def test_wa_pseudo_positioner(): + # %wa must not raise TypeError for multi-axis PseudoPositioner objects. 
+ dev = TwoAxisPseudo(name="dev", labels=["motors"]) + ip = FakeIPython({"dev": dev}) + sm = BlueskyMagics(ip) + sm.wa("") # must not raise TypeError + sm.wa("motors") # must not raise TypeError diff --git a/src/bluesky/tests/test_metadata.py b/src/bluesky/tests/test_metadata.py index a578e66a0c..a48c650115 100644 --- a/src/bluesky/tests/test_metadata.py +++ b/src/bluesky/tests/test_metadata.py @@ -1,3 +1,4 @@ +import event_model import ophyd import bluesky @@ -12,6 +13,10 @@ def test_ophydversion(RE): assert RE.md["versions"].get("ophyd") == ophyd.__version__ +def test_eventmodelversion(RE): + assert RE.md["versions"].get("event_model") == event_model.__version__ + + def test_old_md_validator(RE): """ Test an old-style md_validator. diff --git a/src/bluesky/tests/test_new_examples.py b/src/bluesky/tests/test_new_examples.py index d1966571a3..1e5aaa5b6b 100644 --- a/src/bluesky/tests/test_new_examples.py +++ b/src/bluesky/tests/test_new_examples.py @@ -3,8 +3,10 @@ import time as ttime from collections import defaultdict from types import SimpleNamespace +from unittest.mock import Mock import pytest +from event_model.documents.event_descriptor import DataKey import bluesky.plans as bp from bluesky import Msg, RunEngineInterrupted @@ -69,7 +71,8 @@ subs_wrapper, suspend_wrapper, ) -from bluesky.protocols import Descriptor, Locatable, Location, Readable, Reading, Status +from bluesky.protocols import Descriptor, Locatable, Location, Movable, Readable, Reading, Status +from bluesky.tests.test_external_assets_and_paging import DocHolder, Named, describe_pv, read_pv from bluesky.utils import IllegalMessageSequence, all_safe_rewind @@ -971,3 +974,85 @@ def new_per_shot(detectors): with pytest.raises(IllegalMessageSequence): RE(count([hw.det], 3, per_shot=one_shot)) + + +class MultiConfiguredDevice(Named, Readable, Movable[int]): + """Device to test that read_configuration cache is updated for a new stream. 
This is done by + "configuring" the device and the read_configuration should show a new value.""" + + def __init__(self, motor, name): + self.motor = motor + self.read_value = 10 + super().__init__(name) + + def set(self, config_value: int) -> Status: + return self.motor.set(config_value) + + def read(self) -> dict[str, Reading]: + return read_pv(self, self.read_value) + + def read_configuration(self) -> dict[str, Reading]: + return read_pv(self, self.motor.position) + + def describe(self) -> dict[str, DataKey]: + return describe_pv(self) + + def describe_configuration(self) -> dict[str, DataKey]: + return describe_pv(self) + + +def multi_stream_plan(device: MultiConfiguredDevice, value_configuration: list[int], iterations: int): + """Plan that configures a device by setting a value and then reading the device. This is done X number + of times. Each time, it saves it to a new stream so we get new device configuration each time.""" + yield from open_run() + for v in value_configuration: + yield from abs_set(device, v, wait=True) + for _ in range(iterations): + yield from trigger_and_read([device], name=f"test{v}") + yield from close_run() + + +def test_device_has_new_read_configuration_once_per_stream(RE, hw): + docs = DocHolder() + device = MultiConfiguredDevice(hw.motor, "device") + pv = f"{device.name}-pv" + + config_values = [0, 1, 2, 3] + iterations = 2 + RE(multi_stream_plan(device, config_values, iterations), docs.append) + + docs.assert_emitted(start=1, descriptor=len(config_values), event=len(config_values) * iterations, stop=1) + for v in config_values: + assert docs["descriptor"][v]["name"] == f"test{v}" + assert docs["descriptor"][v]["configuration"][device.name]["data"] == {pv: v} + for i in range(1, iterations + 1): + assert docs["event"][i + v]["data"][pv] == device.read_value + + +def test_cache_used_correct_number_of_times_for_object(RE, hw): + device = MultiConfiguredDevice(hw.motor, "device") + config_values = [0, 1, 2, 3] + iterations = 2 + + 
expected_config_calls = len(config_values) # New config cache used once per stream. + expected_config_describe_calls = len(config_values) # New config_describe cache used once per stream. + expected_read_calls = expected_config_calls * iterations # New read cache used on each event. + expected_read_describe_calls = len(config_values) # New read_describe used once per stream. + + mock_read_config = Mock(wraps=device.read_configuration) + device.read_configuration = mock_read_config + mock_describe_config = Mock(wraps=device.describe_configuration) + device.describe_configuration = mock_describe_config + + mock_read = Mock(wraps=device.read) + device.read = mock_read + mock_describe = Mock(wraps=device.describe) + device.describe = mock_describe + + RE(multi_stream_plan(device, config_values, iterations)) + + assert mock_read_config.call_count == expected_config_calls + assert mock_describe_config.call_count == expected_config_describe_calls + + assert mock_read.call_count == expected_read_calls + assert mock_describe.call_count == expected_read_describe_calls diff --git a/src/bluesky/tests/test_run_engine.py b/src/bluesky/tests/test_run_engine.py index 2cbb9e7854..3ce44edb36 100644 --- a/src/bluesky/tests/test_run_engine.py +++ b/src/bluesky/tests/test_run_engine.py @@ -1,6 +1,7 @@ import asyncio import os import signal +import sys import threading import time as ttime import types @@ -45,8 +46,10 @@ TransitionError, WaitForTimeoutError, ) +from bluesky.suspenders import SuspendBoolHigh from bluesky.tests import requires_ophyd, uses_os_kill_sigint from bluesky.tests.utils import DocCollector, MsgCollector +from bluesky.utils import SigintHandler from .utils import _careful_event_set, _fabricate_asycio_event @@ -722,95 +725,115 @@ def simple_plan(): @uses_os_kill_sigint -def test_sigint_three_hits(RE, hw): - import time - +def test_sigint_three_hits(RE, hw, deterministic_sigint): motor = hw.motor motor.delay = 0.5 - pid = os.getpid() + event = threading.Event() - def 
sim_kill(n): - for j in range(n): # noqa: B007 - time.sleep(0.05) - os.kill(pid, signal.SIGINT) + def msg_hook(msg): + if msg.command == "set": + event.set() + + RE.msg_hook = msg_hook lp = RE.loop motor.loop = lp def self_sig_int_plan(): - threading.Timer(0.05, sim_kill, (3,)).start() yield from abs_set(motor, 1, wait=True) - start_time = ttime.time() - with pytest.raises(RunEngineInterrupted): - RE(finalize_wrapper(self_sig_int_plan(), abs_set(motor, 0, wait=True))) - end_time = ttime.time() + with deterministic_sigint() as sigint: + + def sim_kill(): + event.wait(timeout=5) + for _ in range(3): + sigint.send() + + threading.Thread(target=sim_kill, daemon=True).start() + start_time = ttime.time() + with pytest.raises(RunEngineInterrupted): + RE(finalize_wrapper(self_sig_int_plan(), abs_set(motor, 0, wait=True))) + end_time = ttime.time() + # not enough time for motor to cleanup, but long enough to start - assert 0.05 < end_time - start_time < 0.2 + assert end_time - start_time < 0.4 RE.abort() # now cleanup done_cleanup_time = ttime.time() # this should be 0.5 (the motor.delay) above, leave sloppy for CI - assert 0.3 < done_cleanup_time - end_time < 0.6 + assert 0.4 < done_cleanup_time - end_time < 0.6 @uses_os_kill_sigint -def test_sigint_many_hits_pln(RE): - pid = os.getpid() - - def sim_kill(n): - for j in range(n): - print("KILL", j) - ttime.sleep(0.05) - os.kill(pid, signal.SIGINT) +def test_sigint_many_hits_pln(RE, deterministic_sigint): + plan_started = threading.Event() def hanging_plan(): "a plan that blocks the RunEngine's normal Ctrl+C handing with a sleep" + plan_started.set() for j in range(100): # noqa: B007 ttime.sleep(0.1) yield Msg("null") - start_time = ttime.time() - timer = threading.Timer(0.2, sim_kill, (11,)) - timer.start() - with pytest.raises(RunEngineInterrupted): - RE(hanging_plan()) + with deterministic_sigint() as sigint: + + def sim_kill(): + plan_started.wait(timeout=5) + for _ in range(11): + sigint.send() + + 
threading.Thread(target=sim_kill, daemon=True).start() + start_time = ttime.time() + with pytest.raises(RunEngineInterrupted): + RE(hanging_plan()) + # Check that hammering SIGINT escaped from that 10-second sleep. assert ttime.time() - start_time < 5 # The KeyboardInterrupt will have been converted to a hard pause that # the test plan can not handle so we abort and go to idle. assert RE.state == "idle" - timer.join() +@pytest.mark.skipif( + sys.version_info < (3, 12), + reason=( + "Hangs on Python <3.12 due to a possible CPython bug: after " + "PyThreadState_SetAsyncExc the main thread deadlocks and " + "never reaches the subsequent blocking_event.wait(). " + "Issue only reproduces on CI." + ), +) @uses_os_kill_sigint -def test_sigint_many_hits_panic(RE): - raise pytest.skip("hangs tests on exit") - pid = os.getpid() +def test_sigint_many_hits_panic(RE, deterministic_sigint): + event = threading.Event() + wait_forever_event = threading.Event() - def sim_kill(n): - for j in range(n): - print("KILL", j, ttime.monotonic() - start_time) - ttime.sleep(0.05) - os.kill(pid, signal.SIGINT) + def msg_hook(msg): + if msg.command == "null": + event.set() + + RE.msg_hook = msg_hook def hanging_plan(): - "a plan that blocks the RunEngine's normal Ctrl+C handing with a sleep" + "a plan that blocks the RunEngine's normal Ctrl+C handing with a wait" yield Msg("null") - ttime.sleep(5) + wait_forever_event.wait() yield Msg("null") - start_time = ttime.monotonic() - timer = threading.Timer(0.2, sim_kill, (11,)) - timer.start() - with pytest.raises(RunEngineInterrupted): - RE(hanging_plan()) - # Check that hammering SIGINT escaped from that 5-second sleep. 
- assert (ttime.monotonic() - start_time) < 2.5 + with deterministic_sigint() as sigint: + + def sim_kill(): + event.wait(timeout=5) + for _ in range(11): + sigint.send() + + threading.Thread(target=sim_kill, daemon=True).start() + with pytest.raises(RunEngineInterrupted): + RE(hanging_plan()) + # The KeyboardInterrupt but because we could not shut down, panic! assert RE.state == "panicked" - timer.join() with pytest.raises(RuntimeError): RE([]) @@ -830,16 +853,12 @@ def hanging_plan(): with pytest.raises(RuntimeError): RE.request_pause() + wait_forever_event.set() -@uses_os_kill_sigint -def test_sigint_many_hits_cb(RE): - pid = os.getpid() - def sim_kill(n): - for j in range(n): # noqa: B007 - print("KILL") - ttime.sleep(0.05) - os.kill(pid, signal.SIGINT) +@uses_os_kill_sigint +def test_sigint_many_hits_cb(RE, deterministic_sigint): + cb_started = threading.Event() @run_decorator() def infinite_plan(): @@ -847,19 +866,26 @@ def infinite_plan(): yield Msg("null") def hanging_callback(name, doc): + cb_started.set() for j in range(100): # noqa: B007 ttime.sleep(0.1) - start_time = ttime.time() - timer = threading.Timer(0.2, sim_kill, (11,)) - timer.start() - with pytest.raises(RunEngineInterrupted): - RE(infinite_plan(), {"start": hanging_callback}) + with deterministic_sigint() as sigint: + + def sim_kill(): + cb_started.wait(timeout=5) + for _ in range(11): + sigint.send() + + threading.Thread(target=sim_kill, daemon=True).start() + start_time = ttime.time() + with pytest.raises(RunEngineInterrupted): + RE(infinite_plan(), {"start": hanging_callback}) + # Check that hammering SIGINT escaped from that 10-second sleep. assert ttime.time() - start_time < 5 # The KeyboardInterrupt will have been converted to a hard pause. 
assert RE.state == "idle" - timer.join() @uses_os_kill_sigint @@ -901,6 +927,428 @@ def hanging_plan(): timer.join() +@uses_os_kill_sigint +def test_single_sigint_interrupt_no_checkpoint(RE): + """A single SIGINT on a plan without a checkpoint continues running""" + pid = os.getpid() + + event = threading.Event() + + def msg_hook(msg): + if msg.command == "null": + event.set() + + RE.msg_hook = msg_hook + + def send_sigint(): + # Wait for event + event.wait() + os.kill(pid, signal.SIGINT) + + def test_plan(): + for _ in range(15): + yield Msg("null") + + # Single SIGINT defers a pause but plan finishes anyway + sigint_thread = threading.Thread(target=send_sigint, daemon=True) + sigint_thread.start() + RE(test_plan()) + + assert RE.state == "idle" + sigint_thread.join(timeout=0.1) + + +@uses_os_kill_sigint +def test_single_sigint_hits_checkpoint(RE): + """A single SIGINT on a plan with a checkpoint interrupts""" + pid = os.getpid() + + event = threading.Event() + + def msg_hook(msg): + if msg.command == "null": + event.set() + + RE.msg_hook = msg_hook + + def send_sigint(): + event.wait() + os.kill(pid, signal.SIGINT) + + def infinite_plan(): + while True: + yield Msg("null") + yield from checkpoint() + + # Single SIGINT reaches checkpoint + sigint_thread = threading.Thread(target=send_sigint, daemon=True) + sigint_thread.start() + with pytest.raises(RunEngineInterrupted): + RE(infinite_plan()) + + assert RE.state == "paused" + sigint_thread.join(timeout=0.1) + + +@uses_os_kill_sigint +def test_single_sigint_no_carry_over(RE): + """A single SIGINT does not pause at the next plan's checkpoint""" + pid = os.getpid() + + running_event = threading.Event() + deferred_pause_done = threading.Event() + + def msg_hook(msg): + if msg.command == "null": + running_event.set() + + RE.msg_hook = msg_hook + + _orig_request_pause = RE.request_pause + + def _tracked_request_pause(defer=False): + result = _orig_request_pause(defer=defer) + if defer: + deferred_pause_done.set() + 
return result
+
+    RE.request_pause = _tracked_request_pause
+
+    def send_sigint():
+        # Wait for event
+        running_event.wait()
+        os.kill(pid, signal.SIGINT)
+
+    def test_plan():
+        first = True
+        for _ in range(5):
+            yield Msg("null")
+            if first:
+                deferred_pause_done.wait()
+            first = False
+
+    # Single SIGINT defers a pause but plan finishes anyway
+    sigint_thread = threading.Thread(target=send_sigint, daemon=True)
+    sigint_thread.start()
+    RE(test_plan())
+    sigint_thread.join(timeout=0.1)
+
+    def checkpoint_plan():
+        for _ in range(10):
+            yield Msg("null")
+            yield from checkpoint()
+
+    RE(checkpoint_plan())
+    assert RE.state == "idle"
+
+
+@uses_os_kill_sigint
+def test_double_sigint_interrupts_now(RE):
+    """Two SIGINTs in succession interrupts immediately"""
+    pid = os.getpid()
+
+    running_event = threading.Event()
+    deferred_pause_done = threading.Event()
+
+    def msg_hook(msg):
+        if msg.command == "null":
+            running_event.set()
+
+    RE.msg_hook = msg_hook
+
+    _orig_request_pause = RE.request_pause
+
+    def _tracked_request_pause(defer=False):
+        result = _orig_request_pause(defer=defer)
+        if defer:
+            deferred_pause_done.set()
+        return result
+
+    RE.request_pause = _tracked_request_pause
+
+    def send_sigint():
+        running_event.wait(timeout=5)
+        os.kill(pid, signal.SIGINT)
+        # Wait at least 100ms to send second SIGINT
+        ttime.sleep(0.15)
+        deferred_pause_done.wait(timeout=5)
+        os.kill(pid, signal.SIGINT)
+
+    def test_plan():
+        while True:
+            yield Msg("null")
+
+    sigint_thread = threading.Thread(target=send_sigint, daemon=True)
+    sigint_thread.start()
+    with pytest.raises(RunEngineInterrupted):
+        RE(test_plan())
+
+    assert RE.state == "paused"
+
+    sigint_thread.join(timeout=0.1)
+
+
+@uses_os_kill_sigint
+def test_sigint_during_suspender_active(RE, hw):
+    pid = os.getpid()
+    states = []
+    running_event = threading.Event()
+    wait_for_reached = threading.Event()
+    deferred_pause_done = threading.Event()
+
+    def state_hook(new_state, old_state):
+        states.append((old_state, 
new_state)) + if new_state == "running" and old_state == "idle": + running_event.set() + + RE.state_hook = state_hook + + def msg_hook(msg): + if msg.command == "wait_for": + wait_for_reached.set() + + RE.msg_hook = msg_hook + + _orig_request_pause = RE.request_pause + + def _tracked_request_pause(defer=False): + result = _orig_request_pause(defer=defer) + if defer: + deferred_pause_done.set() + return result + + RE.request_pause = _tracked_request_pause + + bool_signal = hw.bool_sig + suspender = SuspendBoolHigh(bool_signal) + suspender.install(RE) + bool_signal.put(False) + + def send_sigints(): + wait_for_reached.wait(timeout=5) + os.kill(pid, signal.SIGINT) + # Wait at least 100ms to send second SIGINT + ttime.sleep(0.15) + deferred_pause_done.wait(timeout=5) + os.kill(pid, signal.SIGINT) + + def trigger_suspend(): + running_event.wait(timeout=5) + bool_signal.put(True) + + def infinite_plan(): + while True: + yield Msg("null") + + sigint_thread = threading.Thread(target=send_sigints, daemon=True) + suspend_thread = threading.Thread(target=trigger_suspend, daemon=True) + sigint_thread.start() + suspend_thread.start() + + with pytest.raises(RunEngineInterrupted): + RE(infinite_plan()) + + bool_signal.put(False) + sigint_thread.join(timeout=5) + suspend_thread.join(timeout=5) + + assert ("running", "suspending") in states + assert ("running", "pausing") in states + assert RE.state == "paused" + + +@uses_os_kill_sigint +def test_sigint_pause_during_active_suspension_no_devices(RE, hw): + pid = os.getpid() + states = [] + wait_for_reached = threading.Event() + deferred_pause_done = threading.Event() + + def state_hook(new_state, old_state): + states.append((old_state, new_state)) + + RE.state_hook = state_hook + + def msg_hook(msg): + if msg.command == "wait_for": + wait_for_reached.set() + + RE.msg_hook = msg_hook + + _orig_request_pause = RE.request_pause + + def _tracked_request_pause(defer=False): + result = _orig_request_pause(defer=defer) + if defer: + 
deferred_pause_done.set() + return result + + RE.request_pause = _tracked_request_pause + + bool_signal = hw.bool_sig + suspender = SuspendBoolHigh(bool_signal) + bool_signal.put(True) + RE.install_suspender(suspender) + + def send_sigints(): + wait_for_reached.wait(timeout=5) + os.kill(pid, signal.SIGINT) + # Wait at least 100ms to send second SIGINT + ttime.sleep(0.15) + deferred_pause_done.wait(timeout=5) + os.kill(pid, signal.SIGINT) + + sigint_thread = threading.Thread(target=send_sigints, daemon=True) + sigint_thread.start() + + def plan(): + while True: + yield Msg("null") + + with pytest.raises(RunEngineInterrupted): + RE(plan()) + + bool_signal.put(False) + sigint_thread.join(timeout=5) + + assert ("running", "pausing") in states + assert RE.state == "paused" + + +@uses_os_kill_sigint +def test_sigint_in_not_pausable_state(RE): + """A SIGINT arriving while the RE is not pausable must not crash the + watcher thread. + + Sequence of events: + + 1. ``send_sigints`` thread sends two SIGINTs (soft then hard) to + trigger a hard pause. The RE transitions running -> pausing -> paused. + 2. On the main thread's teardown path, ``GateContext.__exit__`` runs + *before* ``SigintHandler.__exit__`` (LIFO order). It sleeps 0.15 s + to clear the signal handler's 0.1 s debounce window, then sends an + extra SIGINT. Because ``SigintHandler`` is still installed, the + handler routes the signal to the watcher thread. + 3. The watcher thread calls ``request_pause`` while the RE is paused + (not pausable), hitting ``TransitionError``. Our monkey-patched + ``_tracked_request_pause`` sets ``transition_error_hit`` when this + happens. + 4. We assert that the ``TransitionError`` was caught (not raised), and + prove the watcher thread survived by resuming + pausing a second time. 
+ + """ + pid = os.getpid() + + running_event = threading.Event() + deferred_pause_done = threading.Event() + hard_pause_done = threading.Event() + transition_error_hit = threading.Event() + + def msg_hook(msg): + if msg.command == "null": + running_event.set() + + RE.msg_hook = msg_hook + + _orig_request_pause = RE.request_pause + + def _tracked_request_pause(defer=False): + try: + result = _orig_request_pause(defer=defer) + except TransitionError: + transition_error_hit.set() + raise + if defer: + deferred_pause_done.set() + else: + hard_pause_done.set() + return result + + RE.request_pause = _tracked_request_pause + + class GateContext: + """Entered after SigintHandler (exited before it in LIFO order). + + On exit, waits for the hard pause to complete, then sends an + extra SIGINT after 0.15 s (clearing the handler's 0.1 s debounce). + This guarantees the ``SigintHandler`` is still installed when the + extra signal arrives. + """ + + def __init__(self, RE): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + # Wait for the hard pause to land. + hard_pause_done.wait(timeout=5) + # Sleep past the 0.1 s debounce window so the signal handler + # accepts this SIGINT instead of silently dropping it. + ttime.sleep(0.15) + os.kill(pid, signal.SIGINT) + # Wait for the watcher thread to process the request and + # hit TransitionError. 
+ transition_error_hit.wait(timeout=5) + + RE.context_managers = [SigintHandler, GateContext] + + def infinite_plan(): + while True: + yield Msg("null") + + def send_sigints(): + running_event.wait(timeout=5) + os.kill(pid, signal.SIGINT) + ttime.sleep(0.15) + deferred_pause_done.wait(timeout=5) + os.kill(pid, signal.SIGINT) + + sigint_thread = threading.Thread(target=send_sigints, daemon=True) + sigint_thread.start() + with pytest.raises(RunEngineInterrupted): + RE(infinite_plan()) + sigint_thread.join(timeout=5) + + assert transition_error_hit.is_set() + assert RE.state == "paused" + + # Prove the watcher thread survived: resume the RE, then pause it + # again with SIGINTs. If the watcher had died from an uncaught + # TransitionError, request_pause would never be called and the + # infinite plan would run forever. + running_event.clear() + deferred_pause_done.clear() + + # Replace with a simple tracker and remove the gate for the + # resume cycle. + def _simple_tracked_request_pause(defer=False): + result = _orig_request_pause(defer=defer) + if defer: + deferred_pause_done.set() + return result + + RE.request_pause = _simple_tracked_request_pause + RE.context_managers = [SigintHandler] + + def send_sigints_again(): + running_event.wait(timeout=5) + os.kill(pid, signal.SIGINT) + ttime.sleep(0.15) + deferred_pause_done.wait(timeout=5) + os.kill(pid, signal.SIGINT) + + sigint_thread2 = threading.Thread(target=send_sigints_again, daemon=True) + sigint_thread2.start() + with pytest.raises(RunEngineInterrupted): + RE.resume() + sigint_thread2.join(timeout=5) + + assert RE.state == "paused" + RE.abort() + + def test_many_context_managers(RE): class Manager: enters = 0 @@ -2277,3 +2725,20 @@ def check(key_order, name, doc): assert list(doc["data_keys"]) == key_order RE(scan(dets, i1930.charlie, -1, 1, 2), lambda name, doc, key_order=key_order: check(key_order, name, doc)) + + +@requires_ophyd +@pytest.mark.parametrize("wait", [True, False]) +def test_abs_set_fails(RE, 
wait): + from ophyd import Device, StatusBase + + class FailingMovable(Device): + def set(self, value) -> Status: + status = StatusBase() + status.set_exception(LookupError("Movement failed")) + return status + + device = FailingMovable(name="failing_device") + + with pytest.raises(FailedStatus): + RE(abs_set(device, 10, wait=wait)) diff --git a/src/bluesky/tests/test_seq_num_indices_in_sync.py b/src/bluesky/tests/test_seq_num_indices_in_sync.py index f7f288bce9..36d3cf9daf 100644 --- a/src/bluesky/tests/test_seq_num_indices_in_sync.py +++ b/src/bluesky/tests/test_seq_num_indices_in_sync.py @@ -1,7 +1,7 @@ import importlib.metadata from collections.abc import Iterable, Iterator from random import randint, sample -from typing import Literal, Optional, Union +from typing import Literal import packaging.version import pytest @@ -29,14 +29,14 @@ class ExternalAssetDevice: - sequence_counter_at_chunks: Optional[Union[range, list[int]]] = None + sequence_counter_at_chunks: range | list[int] | None = None current_chunk: int = 0 def __init__( self, number_of_chunks: int, number_of_frames: int, - detectors: Optional[list[str]] = None, + detectors: list[str] | None = None, stream_datum_contains_one_index: bool = False, ): self.detectors = detectors or ["det1", "det2", "det3"] diff --git a/src/bluesky/tests/test_simulators.py b/src/bluesky/tests/test_simulators.py index 1624d599f3..fa63f539ac 100644 --- a/src/bluesky/tests/test_simulators.py +++ b/src/bluesky/tests/test_simulators.py @@ -250,9 +250,9 @@ def take_three_readings(): msgs = assert_message_and_return_remaining(msgs, lambda msg: msg.command == "stage" and msg.obj.name == "det") msgs = assert_message_and_return_remaining( msgs, - lambda msg: msg.command == "open_run" - and msg.kwargs["plan_name"] == "count" - and msg.kwargs["num_points"] == 3, + lambda msg: ( + msg.command == "open_run" and msg.kwargs["plan_name"] == "count" and msg.kwargs["num_points"] == 3 + ), ) for _ in range(0, 3): msgs = 
assert_message_and_return_remaining(msgs, lambda msg: msg.command == "checkpoint") diff --git a/src/bluesky/tests/test_suspenders.py b/src/bluesky/tests/test_suspenders.py index b8a805b9da..76e291287b 100644 --- a/src/bluesky/tests/test_suspenders.py +++ b/src/bluesky/tests/test_suspenders.py @@ -68,7 +68,8 @@ def putter(val): # assert we waited at least 2 seconds + the settle time delta = stop - start print(delta) - assert delta > 0.5 + wait_time + 0.2 + # The suspension time is actually 0.5 - 0.1 = 0.4 seconds as timers run in parallel + assert delta > 0.4 + wait_time + 0.2 def test_pretripped(RE, hw): @@ -92,6 +93,32 @@ def accum(msg): assert ["wait_for", "checkpoint"] == [m[0] for m in msg_lst] +def test_suspender_wrapper(RE, hw): + + wait_time = 0.2 + sleep_time = 0.2 + trigger_time = 0.5 + + sig = hw.bool_sig + scan = [Msg("checkpoint"), Msg("sleep", None, sleep_time)] + sig.put(0) + + susp = SuspendBoolHigh(sig, sleep=wait_time) + + RE(suspend_wrapper(scan, susp)) + assert RE.state == "idle" + + sig.put(1) + threading.Timer(trigger_time, sig.put, (0,)).start() + + start = ttime.time() + + RE(suspend_wrapper(scan, susp)) + stop = ttime.time() + delta = stop - start + assert delta > trigger_time + wait_time + sleep_time + + @pytest.mark.parametrize( "pre_plan,post_plan,expected_list", [ diff --git a/src/bluesky/tests/test_tiled_writer.py b/src/bluesky/tests/test_tiled_writer.py index cae4767c09..da34c0ff86 100644 --- a/src/bluesky/tests/test_tiled_writer.py +++ b/src/bluesky/tests/test_tiled_writer.py @@ -3,7 +3,7 @@ import uuid from collections.abc import Iterator from pathlib import Path -from typing import Optional, Union, cast +from typing import cast import h5py import jinja2 @@ -109,7 +109,7 @@ def __init__(self, name: str, root: str) -> None: class StreamDatumReadableCollectable(Named, Readable, Collectable, WritesStreamAssets): """Produces no events, but only StreamResources/StreamDatums and can be read or collected""" - def _get_hdf5_stream(self, 
data_key: str, index: int) -> tuple[Optional[StreamResource], StreamDatum]: + def _get_hdf5_stream(self, data_key: str, index: int) -> tuple[StreamResource | None, StreamDatum]: file_path = os.path.join(self.root, "dataset.h5") uid = f"{data_key}-uid" data_desc = self.describe()[data_key] # Descriptor dictionary for the current data key @@ -158,7 +158,7 @@ def _get_hdf5_stream(self, data_key: str, index: int) -> tuple[Optional[StreamRe return stream_resource, stream_datum - def _get_tiff_stream(self, data_key: str, index: int) -> tuple[Optional[StreamResource], StreamDatum]: + def _get_tiff_stream(self, data_key: str, index: int) -> tuple[StreamResource | None, StreamDatum]: file_path = self.root for data_key in [f"{self.name}-sd3"]: uid = f"{data_key}-uid" @@ -228,10 +228,10 @@ def describe(self) -> dict[str, DataKey]: ), } - def describe_collect(self) -> Union[dict[str, DataKey], dict[str, dict[str, DataKey]]]: + def describe_collect(self) -> dict[str, DataKey] | dict[str, dict[str, DataKey]]: return self.describe() - def collect_asset_docs(self, index: Optional[int] = None) -> Iterator[StreamAsset]: + def collect_asset_docs(self, index: int | None = None) -> Iterator[StreamAsset]: """Produce a StreamResource and StreamDatum for all data keys for 0:index""" index = index or 1 data_keys_methods = { diff --git a/src/bluesky/tests/test_tracing.py b/src/bluesky/tests/test_tracing.py index 1e7239070f..d191231612 100644 --- a/src/bluesky/tests/test_tracing.py +++ b/src/bluesky/tests/test_tracing.py @@ -1,11 +1,33 @@ from collections.abc import Generator +import pytest from opentelemetry.trace import get_current_span +from bluesky import plans as bp from bluesky.plan_stubs import sleep +from bluesky.tests import requires_ophyd from bluesky.tracing import trace_plan, tracer +@pytest.fixture(scope="session") +def with_otl_instrumentation(): + pytest.importorskip("opentelemetry") + + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + 
from opentelemetry.sdk.trace.export import ( + ConsoleSpanExporter, + SimpleSpanProcessor, + ) + + provider = TracerProvider() + processor = SimpleSpanProcessor(ConsoleSpanExporter()) + provider.add_span_processor(processor) + + # Sets the global default tracer provider + trace.set_tracer_provider(provider) + + @trace_plan(tracer, "test_plan") def _test_plan(): yield from sleep(0) @@ -22,3 +44,18 @@ def test_plan_2(): assert not get_current_span().is_recording() RE(test_plan_2()) + + +@requires_ophyd +def test_trace_plan_correctly_sends_sanitized_args(RE, caplog, with_otl_instrumentation): + from ophyd.sim import SynAxis + + _det = SynAxis(name="test_det") + + @trace_plan(tracer, "test_plan_np") + def test_plan_np(): + return (yield from bp.scan([_det], _det, -1, 1, num=3)) + + RE(test_plan_np()) + + assert "Invalid type" not in caplog.text diff --git a/src/bluesky/tests/test_zmq.py b/src/bluesky/tests/test_zmq.py index e2fa22452a..d72cd9c4e6 100644 --- a/src/bluesky/tests/test_zmq.py +++ b/src/bluesky/tests/test_zmq.py @@ -11,7 +11,7 @@ from event_model import sanitize_doc from bluesky import Msg -from bluesky.callbacks.zmq import Proxy, Publisher, RemoteDispatcher +from bluesky.callbacks.zmq import Proxy, Publisher, RemoteDispatcher, _normalize_address from bluesky.plans import count from bluesky.tests import uses_os_kill_sigint @@ -126,21 +126,6 @@ def delayed_sigint(delay): gc.collect() -@pytest.mark.parametrize("host", ["localhost:5555", ("localhost", 5555)]) -def test_zmq_RD_ports_spec(host): - # test that two ways of specifying address are equivalent - d = RemoteDispatcher(host) - assert d.address == ("localhost", 5555) - assert d._socket is None - assert d._context is None - assert not d.closed - d.stop() - assert d._socket is None - assert d._context is None - assert d.closed - del d - - def test_zmq_no_RE(RE): # COMPONENT 1 # Run a 0MQ proxy on a separate process. 
@@ -345,3 +330,46 @@ def local_cb(name, doc):
     ra = sanitize_doc(remote_accumulator)
     la = sanitize_doc(local_accumulator)
     assert ra == la
+
+
+@pytest.mark.parametrize(
+    "host",
+    ["localhost:5555", ("localhost", 5555)],
+)
+def test_zmq_RD_ports_spec(host):
+    # test that two ways of specifying address are equivalent
+    d = RemoteDispatcher(host)
+    assert d.address == "tcp://localhost:5555"
+    assert d._socket is None
+    assert d._context is None
+    assert not d.closed
+    d.stop()
+    assert d._socket is None
+    assert d._context is None
+    assert d.closed
+    del d
+
+
+@pytest.mark.parametrize(
+    "address",
+    [
+        ("localhost", "tcp://localhost"),
+        ("localhost:9", "tcp://localhost:9"),
+        ("remote.host", "tcp://remote.host"),
+        ("remote.host:9", "tcp://remote.host:9"),
+        ("tcp://remote.host", "tcp://remote.host"),
+        ("tcp://localhost", "tcp://localhost"),
+        ("tcp://localhost:9", "tcp://localhost:9"),
+        ("tcp://remote.host:9", "tcp://remote.host:9"),
+        ("ipc:///tmp/path", "ipc:///tmp/path"),
+        (("localhost",), "tcp://localhost"),
+        (("localhost", 9), "tcp://localhost:9"),
+        (("ipc", "/tmp/path"), "ipc:///tmp/path"),
+        (("tcp", "localhost"), "tcp://localhost"),
+        (("tcp", "localhost", 9), "tcp://localhost:9"),
+        (("tcp", "localhost", "9"), "tcp://localhost:9"),
+    ],
+)
+def test_address_normalization(address):
+    inp, outp = address
+    assert _normalize_address(inp) == outp
diff --git a/src/bluesky/tracing.py b/src/bluesky/tracing.py
index b1ee337986..e86eba43c5 100644
--- a/src/bluesky/tracing.py
+++ b/src/bluesky/tracing.py
@@ -1,5 +1,6 @@
 import functools
-from typing import Callable, cast
+from collections.abc import Callable
+from typing import cast
 
 from opentelemetry.trace import Tracer, get_tracer
 
diff --git a/src/bluesky/utils/__init__.py b/src/bluesky/utils/__init__.py
index 0a116f3896..16de6a4755 100644
--- a/src/bluesky/utils/__init__.py
+++ b/src/bluesky/utils/__init__.py
@@ -15,17 +15,16 @@
 import uuid
 import warnings
 from collections import namedtuple
-from 
collections.abc import AsyncIterable, AsyncIterator, Awaitable, Generator, Iterable, Sequence
+from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Generator, Iterable, Sequence
 from collections.abc import Iterable as TypingIterable
+from enum import Enum
 from functools import partial, reduce, wraps
 from inspect import Parameter, Signature
 from typing import (
     Any,
-    Callable,
-    Optional,
+    TypeAlias,
     TypedDict,
     TypeVar,
-    Union,
 )
 from weakref import WeakKeyDictionary, ref
 
@@ -93,12 +92,12 @@ def __repr__(self):
 CustomPlanMetadata = dict[str, Any]
 
 #: Scalar or iterable of values, one to be applied to each point in a scan
-ScalarOrIterableFloat = Union[float, TypingIterable[float]]
+ScalarOrIterableFloat: TypeAlias = float | TypingIterable[float]
 
 # Single function to be used as an event listener
 Subscriber = Callable[[str, P], Any]
 
-OneOrMany = Union[P, Sequence[P]]
+OneOrMany: TypeAlias = P | Sequence[P]
 
 # Mapping from event type to listener or list of listeners
@@ -111,7 +110,7 @@ class SubscriberMap(TypedDict, total=False):
 
 # Single listener, multiple listeners or mapping of listeners by event type
-Subscribers = Union[OneOrMany[Subscriber[Document]], SubscriberMap]
+Subscribers: TypeAlias = OneOrMany[Subscriber[Document]] | SubscriberMap
 
 class RunEngineControlException(Exception):
@@ -226,9 +225,21 @@ class SignalHandler:
     probably only see one of them when you unblock this signal.
 
     https://www.gnu.org/software/libc/manual/html_node/Checking-for-Pending-Signals.html
+
+    .. deprecated:: 1.15.0
+
+        This class is deprecated and will be removed in a future version of Bluesky.
+        See :ref:`SigintHandler` for an example on how to build a custom one.
+
     """
 
     def __init__(self, sig, log=None):
+        warnings.warn(
+            f"{SignalHandler.__name__} is deprecated and will be removed in a future version of Bluesky. 
" + f"See {SigintHandler.__name__} for an example on how to build a custom one.", + DeprecationWarning, + stacklevel=2, + ) self.sig = sig self.interrupted = False self.count = 0 @@ -269,57 +280,116 @@ def release(self): def handle_signals(self): ... -class SigintHandler(SignalHandler): - def __init__(self, RE): - super().__init__(signal.SIGINT, log=RE.log) - self.RE = RE - self.last_sigint_time = None # time most recent SIGINT was processed +class PauseRequest(Enum): + NONE = 0 + SOFT = 1 + HARD = 2 - def __enter__(self): - return super().__enter__() - - def handle_signals(self): - # Check for pause requests from keyboard. - # TODO, there is a possible race condition between the two - # pauses here - if self.RE.state.is_running and (not self.RE._interrupted): - if self.last_sigint_time is None or time.time() - self.last_sigint_time > 10: - # reset the counter to 1 - # It's been 10 seconds since the last SIGINT. Reset. - self.count = 1 - if self.last_sigint_time is not None: - self.log.debug("It has been 10 seconds since the last SIGINT. Resetting SIGINT handler.") - - # weeee push these to threads to not block the main thread - def maybe_defer_pause(): - try: - self.RE.request_pause(True) - except TransitionError: - ... - threading.Thread(target=maybe_defer_pause).start() +class SigintHandler: + """ + Context manager that replaces a `KeyboardInterrupt` with a mechanism to + pause the Bluesky RunEngine. This allows you to press `Ctrl + C` during + a running plan and have the RunEngine pause. + + On the first SIGINT, it will request a 'deferred pause' or 'soft pause'. The RunEngine will + pause at the next checkpoint. + + On each subsequent SIGINT within 10 seconds, it will request a 'hard pause'. The RunEngine + will pause immediately. + + However, if more than 10 SIGINTs are processed within 10 seconds, it will restore + and execute the original signal handler (typically `KeyboardInterrupt`). 
+ + Each SIGINT must be spaced by at least 100ms to count (to represent intentional human input). + + The count will reset after 10 seconds since the last SIGINT processed. + """ + + def __init__(self, RE): + self._RE = RE + self._last_sigint_time = time.monotonic() + self._request = PauseRequest.NONE + self._released = True + self._request_event = threading.Event() + + def _watch_request(self) -> None: + while not self._released: + if self._request == PauseRequest.SOFT: print( "A 'deferred pause' has been requested. The " "RunEngine will pause at the next checkpoint. " "To pause immediately, hit Ctrl+C again in the " "next 10 seconds." ) + try: + self._RE.request_pause(defer=True) + except TransitionError: + ... + elif self._request == PauseRequest.HARD: + print("A 'hard pause' has been requested.") + try: + self._RE.request_pause(defer=False) + except TransitionError: + ... - self.last_sigint_time = time.time() - elif self.count == 2: - print("trying a second time") - # - Ctrl-C twice within 10 seconds -> hard pause - self.log.debug("RunEngine detected two SIGINTs. A hard pause will be requested.") + # Block until next request + self._request_event.wait() + self._request_event.clear() - # weeee push these to threads to not block the main thread - def maybe_prompt_pause(): - try: - self.RE.request_pause(False) - except TransitionError: - ... 
+    def __enter__(self):
+        # Setup internal state tracking
+        self._count = 0
+        self._last_sigint_time = time.monotonic()
+        self._released = False
+        self._request = PauseRequest.NONE
+        self._original_handler = signal.getsignal(signal.SIGINT)
+
+        # Spawn request thread
+        self._request_thread = threading.Thread(target=self._watch_request, daemon=True)
+        self._request_thread.start()
+
+        def handler(signum, frame):
+            """
+            Assumptions:
+            - `self._last_sigint_time` is initialized on __enter__ with timestamp
+            - `self._count` is initialized on __enter__ with 0
+            This callback must run very fast and avoid heavy work (no threads, prints,
+            substantial I/O, etc.). Lightweight lock operations like Event.set() are
+            acceptable.
+            """
+            if self._released:
+                return self._original_handler(signum, frame)
+            now = time.monotonic()
+            time_diff = now - self._last_sigint_time
+
+            if time_diff > 10 or self._count == 0:
+                # First pause request
+                self._last_sigint_time = now
+                self._count = 1
+                self._request = PauseRequest.SOFT
+                self._request_event.set()
+            elif time_diff > 0.1 and self._count > 0:
+                # Second or more pause requests
+                self._last_sigint_time = now
+                self._count += 1
+                if self._count < 11:
+                    self._request = PauseRequest.HARD
+                    self._request_event.set()
+                else:
+                    self._released = True
+                    self._request_event.set()
+                    self._original_handler(signum, frame)
+
+        # Install handler callback
+        signal.signal(signal.SIGINT, handler)
+        return self
 
-        threading.Thread(target=maybe_prompt_pause).start()
-        self.last_sigint_time = time.time()
+    def __exit__(self, type, value, tb) -> None:
+        signal.signal(signal.SIGINT, self._original_handler)
+        if not self._released:
+            self._released = True
+            self._request_event.set()
 
 
 class CallbackRegistry:
@@ -1196,6 +1266,8 @@ def sanitize_np(val):
         if np.isscalar(val):
             return val.item()
         return val.tolist()
+    if type(val) in (list, tuple):
+        return type(val)(sanitize_np(v) for v in val)
     return val
 
 
@@ -1309,15 +1381,15 @@ def update(  # noqa: B027
         self,
         pos: Any,
         *,
-        
name: Optional[str] = None, + name: str | None = None, current: Any = None, initial: Any = None, target: Any = None, unit: str = "units", precision: Any = None, fraction: Any = None, - time_elapsed: Optional[float] = None, - time_remaining: Optional[float] = None, + time_elapsed: float | None = None, + time_remaining: float | None = None, ): ... def clear(self): ... # noqa: B027 @@ -1464,7 +1536,7 @@ def default_progress_bar(status_objs_or_none) -> ProgressBarBase: class ProgressBarManager: pbar_factory: Callable[[Any], ProgressBarBase] - pbar: Optional[ProgressBarBase] + pbar: ProgressBarBase | None def __init__(self, pbar_factory: Callable[[Any], ProgressBarBase] = default_progress_bar): """ @@ -1933,8 +2005,8 @@ async def iterate_maybe_async(iterator: SyncOrAsyncIterator[T]) -> AsyncIterator async def maybe_collect_asset_docs( - msg, obj, index: Optional[int] = None, *args, **kwargs -) -> AsyncIterable[Union[Asset, StreamAsset]]: + msg, obj, index: int | None = None, *args, **kwargs +) -> AsyncIterable[Asset | StreamAsset]: # The if/elif statement must be done in this order because isinstance for protocol # doesn't check for exclusive signatures, and WritesExternalAssets will also # return true for a WritesStreamAsset as they both contain collect_asset_docs diff --git a/src/bluesky/utils/jupyter.py b/src/bluesky/utils/jupyter.py index 0186adab41..2b9af7cdcb 100644 --- a/src/bluesky/utils/jupyter.py +++ b/src/bluesky/utils/jupyter.py @@ -3,7 +3,7 @@ import time from functools import partial from threading import RLock -from typing import Any, Optional, TextIO +from typing import Any, TextIO import numpy as np from IPython.core.display import HTML @@ -53,15 +53,15 @@ def update( self, pos: Any, *, - name: Optional[str] = None, + name: str | None = None, current: Any = None, initial: Any = None, target: Any = None, unit: str = "units", precision: Any = None, fraction: Any = None, - time_elapsed: Optional[float] = None, - time_remaining: Optional[float] = None, + 
time_elapsed: float | None = None, + time_remaining: float | None = None, ): if all(x is not None for x in (current, initial, target)): # In this case there is enough information to draw a progress bar with @@ -110,7 +110,7 @@ def draw( meta: str = "", color: str = "#97d4e8", total: float = 1.0, - value: Optional[float] = None, + value: float | None = None, ) -> None: """ Draws the progress bar or a message if there is