Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Fix discovery cli to print devices not printed during discovery timeout #670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 5, 2024
Merged
7 changes: 5 additions & 2 deletions kasa/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,12 +444,12 @@ async def print_discovered(dev: Device):
_echo_discovery_info(dev._discovery_info)
echo()
else:
discovered[dev.host] = dev.internal_state
ctx.parent.obj = dev
await ctx.parent.invoke(state)
discovered[dev.host] = dev.internal_state
echo()

await Discover.discover(
discovered_devices = await Discover.discover(
target=target,
discovery_timeout=discovery_timeout,
on_discovered=print_discovered,
Expand All @@ -459,6 +459,9 @@ async def print_discovered(dev: Device):
credentials=credentials,
)

for device in discovered_devices.values():
await device.protocol.close()

echo(f"Found {len(discovered)} devices")
if unsupported:
echo(f"Found {len(unsupported)} unsupported devices")
Expand Down
61 changes: 39 additions & 22 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ipaddress
import logging
import socket
from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast
from typing import Awaitable, Callable, Dict, List, Optional, Set, Type, cast

# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
Expand Down Expand Up @@ -46,6 +46,8 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
This is internal class, use :func:`Discover.discover`: instead.
"""

DISCOVERY_START_TIMEOUT = 1

discovered_devices: DeviceDict

def __init__(
Expand All @@ -60,7 +62,6 @@ def __init__(
Callable[[UnsupportedDeviceException], Awaitable[None]]
] = None,
port: Optional[int] = None,
discovered_event: Optional[asyncio.Event] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
Expand All @@ -79,12 +80,32 @@ def __init__(
self.unsupported_device_exceptions: Dict = {}
self.invalid_device_exceptions: Dict = {}
self.on_unsupported = on_unsupported
self.discovered_event = discovered_event
self.credentials = credentials
self.timeout = timeout
self.discovery_timeout = discovery_timeout
self.seen_hosts: Set[str] = set()
self.discover_task: Optional[asyncio.Task] = None
self.callback_tasks: List[asyncio.Task] = []
self.target_discovered: bool = False
self._started_event = asyncio.Event()

def _run_callback_task(self, coro):
task = asyncio.create_task(coro)
self.callback_tasks.append(task)

async def wait_for_discovery_to_complete(self):
"""Wait for the discovery task to complete."""
# Give some time for connection_made event to be received
async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT):
await self._started_event.wait()
try:
await self.discover_task
except asyncio.CancelledError:
# if target_discovered then cancel was called internally
if not self.target_discovered:
raise
# Wait for any pending callbacks to complete
await asyncio.gather(*self.callback_tasks)

def connection_made(self, transport) -> None:
"""Set socket options for broadcasting."""
Expand All @@ -103,20 +124,20 @@ def connection_made(self, transport) -> None:
)

self.discover_task = asyncio.create_task(self.do_discover())
self._started_event.set()

async def do_discover(self) -> None:
"""Send number of discovery datagrams."""
req = json_dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = XorEncryption.encrypt(req)
sleep_between_packets = self.discovery_timeout / self.discovery_packets
for i in range(self.discovery_packets):
for _ in range(self.discovery_packets):
if self.target in self.seen_hosts: # Stop sending for discover_single
break
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
if i < self.discovery_packets - 1:
await asyncio.sleep(sleep_between_packets)
await asyncio.sleep(sleep_between_packets)

def datagram_received(self, data, addr) -> None:
"""Handle discovery responses."""
Expand Down Expand Up @@ -145,7 +166,7 @@ def datagram_received(self, data, addr) -> None:
_LOGGER.debug("Unsupported device found at %s << %s", ip, udex)
self.unsupported_device_exceptions[ip] = udex
if self.on_unsupported is not None:
asyncio.ensure_future(self.on_unsupported(udex))
self._run_callback_task(self.on_unsupported(udex))
self._handle_discovered_event()
return
except SmartDeviceException as ex:
Expand All @@ -157,16 +178,16 @@ def datagram_received(self, data, addr) -> None:
self.discovered_devices[ip] = device

if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))
self._run_callback_task(self.on_discovered(device))

self._handle_discovered_event()

def _handle_discovered_event(self):
"""If discovered_event is available set it and cancel discover_task."""
if self.discovered_event is not None:
"""If target is in seen_hosts cancel discover_task."""
if self.target in self.seen_hosts:
self.target_discovered = True
if self.discover_task:
self.discover_task.cancel()
self.discovered_event.set()

def error_received(self, ex):
"""Handle asyncio.Protocol errors."""
Expand Down Expand Up @@ -289,7 +310,11 @@ async def discover(

try:
_LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout)
await asyncio.sleep(discovery_timeout)
await protocol.wait_for_discovery_to_complete()
except SmartDeviceException as ex:
for device in protocol.discovered_devices.values():
await device.protocol.close()
raise ex
finally:
transport.close()

Expand Down Expand Up @@ -322,7 +347,6 @@ async def discover_single(
:return: Object for querying/controlling found device.
"""
loop = asyncio.get_event_loop()
event = asyncio.Event()

try:
ipaddress.ip_address(host)
Expand Down Expand Up @@ -352,7 +376,6 @@ async def discover_single(
lambda: _DiscoverProtocol(
target=ip,
port=port,
discovered_event=event,
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
Expand All @@ -365,13 +388,7 @@ async def discover_single(
_LOGGER.debug(
"Waiting a total of %s seconds for responses...", discovery_timeout
)

async with asyncio_timeout(discovery_timeout):
await event.wait()
except asyncio.TimeoutError as ex:
raise TimeoutException(
f"Timed out getting discovery response for {host}"
) from ex
await protocol.wait_for_discovery_to_complete()
finally:
transport.close()

Expand All @@ -384,7 +401,7 @@ async def discover_single(
elif ip in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[ip]
else:
raise SmartDeviceException(f"Unable to get discovery response for {host}")
raise TimeoutException(f"Timed out getting discovery response for {host}")

@staticmethod
def _get_device_class(info: dict) -> Type[Device]:
Expand Down
4 changes: 2 additions & 2 deletions kasa/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ class _DiscoveryMock:
login_version,
)

def mock_discover(self):
async def mock_discover(self):
port = (
dm.port_override
if dm.port_override and dm.discovery_port != 20002
Expand Down Expand Up @@ -561,7 +561,7 @@ def unsupported_device_info(request, mocker):
discovery_data = request.param
host = "127.0.0.1"

def mock_discover(self):
async def mock_discover(self):
if discovery_data:
data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
Expand Down
22 changes: 21 additions & 1 deletion kasa/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ async def test_brightness(dev):
@device_iot
async def test_json_output(dev: Device, mocker):
"""Test that the json output produces correct output."""
mocker.patch("kasa.Discover.discover", return_value=[dev])
mocker.patch("kasa.Discover.discover", return_value={"127.0.0.1": dev})
runner = CliRunner()
res = await runner.invoke(cli, ["--json", "state"], obj=dev)
assert res.exit_code == 0
Expand Down Expand Up @@ -415,6 +415,26 @@ async def test_discover(discovery_mock, mocker):
assert res.exit_code == 0


async def test_discover_host(discovery_mock, mocker):
"""Test discovery output."""
runner = CliRunner()
res = await runner.invoke(
cli,
[
"--discovery-timeout",
0,
"--host",
"127.0.0.123",
"--username",
"foo",
"--password",
"bar",
"--verbose",
],
)
assert res.exit_code == 0


async def test_discover_unsupported(unsupported_device_info):
"""Test discovery output."""
runner = CliRunner()
Expand Down
Loading