From e68723cd72aa472ff91d5f65957c5efa6ac50602 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Mon, 22 Jan 2024 18:57:18 +0000 Subject: [PATCH 1/8] Fix discovery cli to print devices not printed during discovery --- kasa/cli.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/kasa/cli.py b/kasa/cli.py index d1cb72765..205c73cdd 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -399,7 +399,7 @@ async def discover(ctx): sem = asyncio.Semaphore() discovered = dict() unsupported = [] - auth_failed = [] + auth_failed = {} async def print_unsupported(unsupported_exception: UnsupportedDeviceException): unsupported.append(unsupported_exception) @@ -420,17 +420,17 @@ async def print_discovered(dev: SmartDevice): try: await dev.update() except AuthenticationException: - auth_failed.append(dev._discovery_info) + auth_failed[dev.host] = dev._discovery_info echo("== Authentication failed for device ==") _echo_discovery_info(dev._discovery_info) echo() else: - discovered[dev.host] = dev.internal_state ctx.parent.obj = dev await ctx.parent.invoke(state) + discovered[dev.host] = dev.internal_state echo() - await Discover.discover( + discovered_devices = await Discover.discover( target=target, discovery_timeout=discovery_timeout, on_discovered=print_discovered, @@ -440,6 +440,23 @@ async def print_discovered(dev: SmartDevice): credentials=credentials, ) + # TimeoutError could have cancelled discover() before all + # discovered devices can print + not_printed = [ + device + for ip, device in discovered_devices.items() + if ip not in discovered and ip not in auth_failed + ] + for device in not_printed: + await print_discovered(device) + + for device in discovered_devices.values(): + await device.protocol.close() + # TODO transports need to permit close to avoid + # ERROR:asyncio:Unclosed client session + if hasattr(device.protocol._transport, "_http_client"): + await device.protocol._transport._http_client.close() + echo(f"Found {len(discovered)} devices") if unsupported: echo(f"Found {len(unsupported)} unsupported devices") From 914b9e74054d461e01adc133145a6be20e75653d Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Jan 2024 18:53:05 +0000 Subject: [PATCH 2/8] Fix tests --- kasa/device_factory.py | 2 +- kasa/tests/test_cli.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kasa/device_factory.py b/kasa/device_factory.py index 83db093f4..175552709 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -132,7 +132,7 @@ def get_device_class_from_family(device_type: str) -> Optional[Type[SmartDevice] "SMART.TAPOPLUG": TapoPlug, "SMART.TAPOBULB": TapoBulb, "SMART.KASAPLUG": TapoPlug, - "SMART.KASASWITCH": TapoBulb, + "SMART.KASASWITCH": TapoPlug, "IOT.SMARTPLUGSWITCH": SmartPlug, "IOT.SMARTBULB": SmartBulb, } diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index 3aad37dda..dd99a0e98 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -253,7 +253,7 @@ async def test_brightness(dev): @device_iot async def test_json_output(dev: SmartDevice, mocker): """Test that the json output produces correct output.""" - mocker.patch("kasa.Discover.discover", return_value=[dev]) + mocker.patch("kasa.Discover.discover", return_value={"127.0.0.1": dev}) runner = CliRunner() res = await runner.invoke(cli, ["--json", "state"], obj=dev) assert res.exit_code == 0 From b1d551cafa638c7a44192491806761e972e31d13 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Mon, 29 Jan 2024 15:08:30 +0000 Subject: [PATCH 3/8] Fix print exceptions not being propagated --- kasa/cli.py | 16 +++++++--------- kasa/device_factory.py | 2 +- kasa/discover.py | 19 ++++++++++++++++--- kasa/tapo/tapobulb.py | 3 ++- kasa/tests/conftest.py | 4 ++-- kasa/tests/test_cli.py | 22 +++++++++++++++++++++- kasa/tests/test_discovery.py | 13 +++++++++++++ 7 files changed, 62 insertions(+), 17 deletions(-) diff --git a/kasa/cli.py b/kasa/cli.py index 6797f61b7..d6e906fc1 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -459,15 +459,12 @@ async def print_discovered(dev: SmartDevice): for ip, device in discovered_devices.items() if ip not in discovered and ip not in auth_failed ] - for device in not_printed: - await print_discovered(device) - - for device in discovered_devices.values(): - await device.protocol.close() - # TODO transports need to permit close to avoid - # ERROR:asyncio:Unclosed client session - if hasattr(device.protocol._transport, "_http_client"): - await device.protocol._transport._http_client.close() + try: + for device in not_printed: + await print_discovered(device) + finally: + for device in discovered_devices.values(): + await device.protocol.close() echo(f"Found {len(discovered)} devices") if unsupported: @@ -561,6 +558,7 @@ async def state(ctx, dev: SmartDevice): echo("\n\t[bold]== Device specific information ==[/bold]") for info_name, info_data in dev.state_information.items(): + # raise SmartDeviceException("Test exception msg") if isinstance(info_data, list): echo(f"\t{info_name}:") for item in info_data: diff --git a/kasa/device_factory.py b/kasa/device_factory.py index afe77c201..fdb5b1b49 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -134,7 +134,7 @@ def get_device_class_from_family(device_type: str) -> Optional[Type[SmartDevice] "SMART.TAPOBULB": TapoBulb, "SMART.TAPOSWITCH": TapoBulb, "SMART.KASAPLUG": TapoPlug, - "SMART.KASASWITCH": TapoPlug, + "SMART.KASASWITCH": TapoBulb, "IOT.SMARTPLUGSWITCH": SmartPlug, "IOT.SMARTBULB": SmartBulb, } diff --git a/kasa/discover.py b/kasa/discover.py index 8286387ae..e1860e4b3 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -4,7 +4,7 @@ import ipaddress import logging import socket -from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast +from typing import Awaitable, Callable, Dict, List, Optional, Set, Type, cast # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout @@ -80,6 +80,11 @@ def __init__( self.discovery_timeout = discovery_timeout self.seen_hosts: Set[str] = set() self.discover_task: Optional[asyncio.Task] = None + self.callback_tasks: List[asyncio.Task] = [] + + def _run_callback_task(self, coro): + task = asyncio.create_task(coro) + self.callback_tasks.append(task) def connection_made(self, transport) -> None: """Set socket options for broadcasting.""" @@ -152,7 +157,7 @@ def datagram_received(self, data, addr) -> None: self.discovered_devices[ip] = device if self.on_discovered is not None: - asyncio.ensure_future(self.on_discovered(device)) + self._run_callback_task(self.on_discovered(device)) self._handle_discovered_event() @@ -284,7 +289,15 @@ async def discover( try: _LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout) - await asyncio.sleep(discovery_timeout) + async with asyncio_timeout(discovery_timeout): + while not protocol.discover_task: + await asyncio.sleep(0) + await protocol.discover_task + await asyncio.gather(*protocol.callback_tasks) + except SmartDeviceException as ex: + for device in protocol.discovered_devices.values(): + await device.protocol.close() + raise ex finally: transport.close() diff --git a/kasa/tapo/tapobulb.py b/kasa/tapo/tapobulb.py index bbaf093d6..cfd5768f0 100644 --- a/kasa/tapo/tapobulb.py +++ b/kasa/tapo/tapobulb.py @@ -243,9 +243,10 @@ def state_information(self) -> Dict[str, Any]: info: Dict[str, Any] = { # TODO: re-enable after we don't inherit from smartbulb # **super().state_information - "Brightness": self.brightness, "Is dimmable": self.is_dimmable, } + if self.is_dimmable: + info["Brightness"] = self.brightness if self.is_variable_color_temp: info["Color temperature"] = self.color_temp info["Valid temperature range"] = self.valid_temperature_range diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 24bc3372b..d8d1c2046 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -490,7 +490,7 @@ class _DiscoveryMock: login_version, ) - def mock_discover(self): + async def mock_discover(self): port = ( dm.port_override if dm.port_override and dm.discovery_port != 20002 @@ -543,7 +543,7 @@ def unsupported_device_info(request, mocker): discovery_data = request.param host = "127.0.0.1" - def mock_discover(self): + async def mock_discover(self): if discovery_data: data = ( b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index 703ad533e..7a8681ee6 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -396,7 +396,7 @@ async def test_discover(discovery_mock, mocker): cli, [ "--discovery-timeout", - 0, + 1, "--username", "foo", "--password", @@ -408,6 +408,26 @@ async def test_discover(discovery_mock, mocker): assert res.exit_code == 0 +async def test_discover_host(discovery_mock, mocker): + """Test discovery output.""" + runner = CliRunner() + res = await runner.invoke( + cli, + [ + "--discovery-timeout", + 1, + "--host", + "127.0.0.123", + "--username", + "foo", + "--password", + "bar", + "--verbose", + ], + ) + assert res.exit_code == 0 + + async def test_discover_unsupported(unsupported_device_info): """Test discovery output.""" runner = CliRunner() diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index db4d8fc1c..97cff9595 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -456,3 +456,16 @@ async def test_do_discover_invalid(mocker, port, will_timeout): await asyncio.sleep(0) assert dp.discover_task.done() assert timed_out is will_timeout + + +async def test_discover_propogates_task_exceptions(discovery_mock): + """Make sure that discover propogates callback exceptions.""" + discovery_timeout = 0 + + async def on_discovered(dev): + raise SmartDeviceException("Dummy exception") + + with pytest.raises(SmartDeviceException): + await Discover.discover( + discovery_timeout=discovery_timeout, on_discovered=on_discovered + ) From 4737522d9407a62465ae4ede767020f1d4c5f581 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Mon, 29 Jan 2024 18:31:59 +0000 Subject: [PATCH 4/8] Fix tests --- kasa/cli.py | 15 ++------------- kasa/discover.py | 5 ++++- kasa/tests/test_cli.py | 4 ++-- kasa/tests/test_discovery.py | 10 +++++----- 4 files changed, 13 insertions(+), 21 deletions(-) diff --git a/kasa/cli.py b/kasa/cli.py index d6e906fc1..aa136c871 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -452,19 +452,8 @@ async def print_discovered(dev: SmartDevice): credentials=credentials, ) - # TimeoutError could have cancelled discover() before all - # discovered devices can print - not_printed = [ - device - for ip, device in discovered_devices.items() - if ip not in discovered and ip not in auth_failed - ] - try: - for device in not_printed: - await print_discovered(device) - finally: - for device in discovered_devices.values(): - await device.protocol.close() + for device in discovered_devices.values(): + await device.protocol.close() echo(f"Found {len(discovered)} devices") if unsupported: diff --git a/kasa/discover.py b/kasa/discover.py index e1860e4b3..c16da87fd 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -145,7 +145,7 @@ def datagram_received(self, data, addr) -> None: _LOGGER.debug("Unsupported device found at %s << %s", ip, udex) self.unsupported_device_exceptions[ip] = udex if self.on_unsupported is not None: - asyncio.ensure_future(self.on_unsupported(udex)) + self._run_callback_task(self.on_unsupported(udex)) self._handle_discovered_event() return except SmartDeviceException as ex: @@ -293,6 +293,9 @@ async def discover( while not protocol.discover_task: await asyncio.sleep(0) await protocol.discover_task + # Give the last sent packet time to respond + await asyncio.sleep(discovery_timeout / discovery_packets) + # Wait for any pending callbacks to complete await asyncio.gather(*protocol.callback_tasks) except SmartDeviceException as ex: for device in protocol.discovered_devices.values(): diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index 7a8681ee6..e6f1c6515 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -396,7 +396,7 @@ async def test_discover(discovery_mock, mocker): cli, [ "--discovery-timeout", - 1, + 0, "--username", "foo", "--password", @@ -415,7 +415,7 @@ async def test_discover_host(discovery_mock, mocker): cli, [ "--discovery-timeout", - 1, + 0, "--host", "127.0.0.123", "--username", diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 97cff9595..6894eda0e 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -401,7 +401,7 @@ def sendto(self, data, addr=None): async def test_do_discover_drop_packets(mocker, port, do_not_reply_count): """Make sure that discover_single handles authenticating devices correctly.""" host = "127.0.0.1" - discovery_timeout = 1 + discovery_timeout = 0.1 event = asyncio.Event() dp = _DiscoverProtocol( @@ -432,9 +432,9 @@ async def test_do_discover_drop_packets(mocker, port, do_not_reply_count): ids=["unknownport", "unsupporteddevice"], ) async def test_do_discover_invalid(mocker, port, will_timeout): - """Make sure that discover_single handles authenticating devices correctly.""" + """Make sure that discover_single handles invalid devices correctly.""" host = "127.0.0.1" - discovery_timeout = 1 + discovery_timeout = 0.1 event = asyncio.Event() dp = _DiscoverProtocol( @@ -448,7 +448,7 @@ async def test_do_discover_invalid(mocker, port, will_timeout): timed_out = False try: - async with asyncio_timeout(15): + async with asyncio_timeout(discovery_timeout): await event.wait() except asyncio.TimeoutError: timed_out = True @@ -460,7 +460,7 @@ async def test_do_discover_invalid(mocker, port, will_timeout): async def test_discover_propogates_task_exceptions(discovery_mock): """Make sure that discover propogates callback exceptions.""" - discovery_timeout = 0 + discovery_timeout = 0.1 async def on_discovered(dev): raise SmartDeviceException("Dummy exception") From 0844034e76e9ba3344c5be8c8c1392504f51bdce Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Mon, 29 Jan 2024 18:45:35 +0000 Subject: [PATCH 5/8] Reduce test discover_send time --- kasa/tests/test_discovery.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 6894eda0e..8a3fa2ffc 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -200,7 +200,8 @@ def mock_discover(self): async def test_discover_send(mocker): """Test discovery parameters.""" - proto = _DiscoverProtocol() + discovery_timeout = 0.1 + proto = _DiscoverProtocol(discovery_timeout=discovery_timeout) assert proto.discovery_packets == 3 assert proto.target_1 == ("255.255.255.255", 9999) transport = mocker.patch.object(proto, "transport") From f8fd4999bbdbf1a6d80c3739ece853d9f984b832 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 30 Jan 2024 10:07:01 +0000 Subject: [PATCH 6/8] Simplify wait logic --- kasa/discover.py | 52 ++++++++++++++++++------------------ kasa/tests/test_discovery.py | 34 +++++++---------------- 2 files changed, 35 insertions(+), 51 deletions(-) diff --git a/kasa/discover.py b/kasa/discover.py index c16da87fd..b31bd6160 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -41,6 +41,8 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): This is internal class, use :func:`Discover.discover`: instead. """ + DISCOVERY_START_TIMEOUT = 1 + discovered_devices: DeviceDict def __init__( @@ -55,7 +57,6 @@ def __init__( Callable[[UnsupportedDeviceException], Awaitable[None]] ] = None, port: Optional[int] = None, - discovered_event: Optional[asyncio.Event] = None, credentials: Optional[Credentials] = None, timeout: Optional[int] = None, ) -> None: @@ -74,18 +75,33 @@ def __init__( self.unsupported_device_exceptions: Dict = {} self.invalid_device_exceptions: Dict = {} self.on_unsupported = on_unsupported - self.discovered_event = discovered_event self.credentials = credentials self.timeout = timeout self.discovery_timeout = discovery_timeout self.seen_hosts: Set[str] = set() self.discover_task: Optional[asyncio.Task] = None self.callback_tasks: List[asyncio.Task] = [] + self.target_discovered: bool = False def _run_callback_task(self, coro): task = asyncio.create_task(coro) self.callback_tasks.append(task) + async def wait_for_discovery_to_complete(self): + """Wait for the discovery task to complete.""" + # Give some time for connection_made event to be received + async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT): + while not self.discover_task: + await asyncio.sleep(0) + try: + await self.discover_task + except asyncio.CancelledError: + # if target_discovered then cancel was called internally + if not self.target_discovered: + raise + # Wait for any pending callbacks to complete + await asyncio.gather(*self.callback_tasks) + def connection_made(self, transport) -> None: """Set socket options for broadcasting.""" self.transport = transport @@ -110,13 +126,12 @@ async def do_discover(self) -> None: _LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY) encrypted_req = XorEncryption.encrypt(req) sleep_between_packets = self.discovery_timeout / self.discovery_packets - for i in range(self.discovery_packets): + for _ in range(self.discovery_packets): if self.target in self.seen_hosts: # Stop sending for discover_single break self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore - if i < self.discovery_packets - 1: - await asyncio.sleep(sleep_between_packets) + await asyncio.sleep(sleep_between_packets) def datagram_received(self, data, addr) -> None: """Handle discovery responses.""" @@ -162,11 +177,11 @@ def datagram_received(self, data, addr) -> None: self._handle_discovered_event() def _handle_discovered_event(self): - """If discovered_event is available set it and cancel discover_task.""" - if self.discovered_event is not None: + """If target is in seen_hosts cancel discover_task.""" + if self.target in self.seen_hosts: + self.target_discovered = True if self.discover_task: self.discover_task.cancel() - self.discovered_event.set() def error_received(self, ex): """Handle asyncio.Protocol errors.""" @@ -289,14 +304,7 @@ async def discover( try: _LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout) - async with asyncio_timeout(discovery_timeout): - while not protocol.discover_task: - await asyncio.sleep(0) - await protocol.discover_task - # Give the last sent packet time to respond - await asyncio.sleep(discovery_timeout / discovery_packets) - # Wait for any pending callbacks to complete - await asyncio.gather(*protocol.callback_tasks) + await protocol.wait_for_discovery_to_complete() except SmartDeviceException as ex: for device in protocol.discovered_devices.values(): await device.protocol.close() @@ -333,7 +341,6 @@ async def discover_single( :return: Object for querying/controlling found device. """ loop = asyncio.get_event_loop() - event = asyncio.Event() try: ipaddress.ip_address(host) @@ -363,7 +370,6 @@ async def discover_single( lambda: _DiscoverProtocol( target=ip, port=port, - discovered_event=event, credentials=credentials, timeout=timeout, discovery_timeout=discovery_timeout, @@ -376,13 +382,7 @@ async def discover_single( _LOGGER.debug( "Waiting a total of %s seconds for responses...", discovery_timeout ) - - async with asyncio_timeout(discovery_timeout): - await event.wait() - except asyncio.TimeoutError as ex: - raise TimeoutException( - f"Timed out getting discovery response for {host}" - ) from ex + await protocol.wait_for_discovery_to_complete() finally: transport.close() @@ -395,7 +395,7 @@ async def discover_single( elif ip in protocol.invalid_device_exceptions: raise protocol.invalid_device_exceptions[ip] else: - raise SmartDeviceException(f"Unable to get discovery response for {host}") + raise TimeoutException(f"Timed out getting discovery response for {host}") @staticmethod def _get_device_class(info: dict) -> Type[SmartDevice]: diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index c85012d92..f91801761 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -6,7 +6,6 @@ import aiohttp import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 -from async_timeout import timeout as asyncio_timeout from kasa import ( Credentials, @@ -190,7 +189,7 @@ async def test_discover_invalid_info(msg, data, mocker): """Make sure that invalid discovery information raises an exception.""" host = "127.0.0.1" - def mock_discover(self): + async def mock_discover(self): self.datagram_received( XorEncryption.encrypt(json_dumps(data))[4:], (host, 9999) ) @@ -203,7 +202,7 @@ def mock_discover(self): async def test_discover_send(mocker): """Test discovery parameters.""" - discovery_timeout = 0.1 + discovery_timeout = 0 proto = _DiscoverProtocol(discovery_timeout=discovery_timeout) assert proto.discovery_packets == 3 assert proto.target_1 == ("255.255.255.255", 9999) @@ -405,29 +404,22 @@ def sendto(self, data, addr=None): async def test_do_discover_drop_packets(mocker, port, do_not_reply_count): """Make sure that discover_single handles authenticating devices correctly.""" host = "127.0.0.1" - discovery_timeout = 0.1 + discovery_timeout = 0 - event = asyncio.Event() dp = _DiscoverProtocol( target=host, discovery_timeout=discovery_timeout, discovery_packets=5, - discovered_event=event, ) ft = FakeDatagramTransport(dp, port, do_not_reply_count) dp.connection_made(ft) - timed_out = False - try: - async with asyncio_timeout(discovery_timeout): - await event.wait() - except asyncio.TimeoutError: - timed_out = True + await dp.wait_for_discovery_to_complete() await asyncio.sleep(0) assert ft.send_count == do_not_reply_count + 1 assert dp.discover_task.done() - assert timed_out is False + assert dp.discover_task.cancelled() @pytest.mark.parametrize( @@ -438,33 +430,25 @@ async def test_do_discover_drop_packets(mocker, port, do_not_reply_count): async def test_do_discover_invalid(mocker, port, will_timeout): """Make sure that discover_single handles invalid devices correctly.""" host = "127.0.0.1" - discovery_timeout = 0.1 + discovery_timeout = 0 - event = asyncio.Event() dp = _DiscoverProtocol( target=host, discovery_timeout=discovery_timeout, discovery_packets=5, - discovered_event=event, ) ft = FakeDatagramTransport(dp, port, 0, unsupported=True) dp.connection_made(ft) - timed_out = False - try: - async with asyncio_timeout(discovery_timeout): - await event.wait() - except asyncio.TimeoutError: - timed_out = True - + await dp.wait_for_discovery_to_complete() await asyncio.sleep(0) assert dp.discover_task.done() - assert timed_out is will_timeout + assert dp.discover_task.cancelled() != will_timeout async def test_discover_propogates_task_exceptions(discovery_mock): """Make sure that discover propogates callback exceptions.""" - discovery_timeout = 0.1 + discovery_timeout = 0 async def on_discovered(dev): raise SmartDeviceException("Dummy exception") From cc7b0bbb94af9f1b8e1fc967140be387f1f6d9bd Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 30 Jan 2024 10:30:02 +0000 Subject: [PATCH 7/8] Add tests --- kasa/cli.py | 1 - kasa/tests/test_discovery.py | 45 +++++++++++++++++++++++++++++++++--- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/kasa/cli.py b/kasa/cli.py index aa136c871..29ab85c4c 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -547,7 +547,6 @@ async def state(ctx, dev: SmartDevice): echo("\n\t[bold]== Device specific information ==[/bold]") for info_name, info_data in dev.state_information.items(): - # raise SmartDeviceException("Test exception msg") if isinstance(info_data, list): echo(f"\t{info_name}:") for item in info_data: diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index f91801761..831540933 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -6,6 +6,7 @@ import aiohttp import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 +from async_timeout import timeout as asyncio_timeout from kasa import ( Credentials, @@ -298,6 +299,7 @@ async def test_discover_single_authentication(discovery_mock, mocker): @new_discovery async def test_device_update_from_new_discovery_info(discovery_data): + """Make sure that new discovery devices update from discovery info correctly.""" device = SmartDevice("127.0.0.7") discover_info = DiscoveryResult(**discovery_data["result"]) discover_dump = discover_info.get_dict() @@ -334,7 +336,7 @@ async def test_discover_single_http_client(discovery_mock, mocker): async def test_discover_http_client(discovery_mock, mocker): - """Make sure that discover_single returns an initialized SmartDevice instance.""" + """Make sure that discover returns an initialized SmartDevice instance.""" host = "127.0.0.1" discovery_mock.ip = host @@ -402,7 +404,7 @@ def sendto(self, data, addr=None): @pytest.mark.parametrize("port", [9999, 20002]) @pytest.mark.parametrize("do_not_reply_count", [0, 1, 2, 3, 4]) async def test_do_discover_drop_packets(mocker, port, do_not_reply_count): - """Make sure that discover_single handles authenticating devices correctly.""" + """Make sure that _DiscoverProtocol handles authenticating devices correctly.""" host = "127.0.0.1" discovery_timeout = 0 @@ -428,7 +430,7 @@ async def test_do_discover_drop_packets(mocker, port, do_not_reply_count): ids=["unknownport", "unsupporteddevice"], ) async def test_do_discover_invalid(mocker, port, will_timeout): - """Make sure that discover_single handles invalid devices correctly.""" + """Make sure that _DiscoverProtocol handles invalid devices correctly.""" host = "127.0.0.1" discovery_timeout = 0 @@ -457,3 +459,40 @@ async def on_discovered(dev): await Discover.discover( discovery_timeout=discovery_timeout, on_discovered=on_discovered ) + + +async def test_do_discover_no_connection(mocker): + """Make sure that if the datagram connection doesnt start a TimeoutError is raised.""" + host = "127.0.0.1" + discovery_timeout = 0 + mocker.patch.object(_DiscoverProtocol, "DISCOVERY_START_TIMEOUT", 0) + dp = _DiscoverProtocol( + target=host, + discovery_timeout=discovery_timeout, + discovery_packets=5, + ) + # Normally tests would simulate connection as per below + # ft = FakeDatagramTransport(dp, port, 0, unsupported=True) + # dp.connection_made(ft) + + with pytest.raises(asyncio.TimeoutError): + await dp.wait_for_discovery_to_complete() + + +async def test_do_discover_external_cancel(mocker): + """Make sure that a cancel other than when target is discovered propogates.""" + host = "127.0.0.1" + discovery_timeout = 1 + + dp = _DiscoverProtocol( + target=host, + discovery_timeout=discovery_timeout, + discovery_packets=1, + ) + # Normally tests would simulate connection as per below + ft = FakeDatagramTransport(dp, 9999, 1, unsupported=True) + dp.connection_made(ft) + + with pytest.raises(asyncio.TimeoutError): + async with asyncio_timeout(0): + await dp.wait_for_discovery_to_complete() From 1e3a2ce2ae7e9ec3941e1813eff136c7da065f2a Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Mon, 5 Feb 2024 16:40:58 +0000 Subject: [PATCH 8/8] Remove sleep loop and make auth failed a list --- kasa/cli.py | 4 ++-- kasa/discover.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/kasa/cli.py b/kasa/cli.py index de9781748..53c68adb4 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -418,7 +418,7 @@ async def discover(ctx): sem = asyncio.Semaphore() discovered = dict() unsupported = [] - auth_failed = {} + auth_failed = [] async def print_unsupported(unsupported_exception: UnsupportedDeviceException): unsupported.append(unsupported_exception) @@ -439,7 +439,7 @@ async def print_discovered(dev: Device): try: await dev.update() except AuthenticationException: - auth_failed[dev.host] = dev._discovery_info + auth_failed.append(dev._discovery_info) echo("== Authentication failed for device ==") _echo_discovery_info(dev._discovery_info) echo() diff --git a/kasa/discover.py b/kasa/discover.py index c20769798..f9ce6e0a5 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -87,6 +87,7 @@ def __init__( self.discover_task: Optional[asyncio.Task] = None self.callback_tasks: List[asyncio.Task] = [] self.target_discovered: bool = False + self._started_event = asyncio.Event() def _run_callback_task(self, coro): task = asyncio.create_task(coro) @@ -96,8 +97,7 @@ async def wait_for_discovery_to_complete(self): """Wait for the discovery task to complete.""" # Give some time for connection_made event to be received async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT): - while not self.discover_task: - await asyncio.sleep(0) + await self._started_event.wait() try: await self.discover_task except asyncio.CancelledError: @@ -124,6 +124,7 @@ def connection_made(self, transport) -> None: ) self.discover_task = asyncio.create_task(self.do_discover()) + self._started_event.set() async def do_discover(self) -> None: """Send number of discovery datagrams."""