Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d03f535

Browse files
authored
Fix discover cli command with host (#1437)
1 parent 1be8767 commit d03f535

File tree

4 files changed

+103
-24
lines changed

4 files changed

+103
-24
lines changed

kasa/cli/common.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from contextlib import contextmanager
1111
from functools import singledispatch, update_wrapper, wraps
1212
from gettext import gettext
13-
from typing import TYPE_CHECKING, Any, Final
13+
from typing import TYPE_CHECKING, Any, Final, NoReturn
1414

1515
import asyncclick as click
1616

@@ -57,7 +57,7 @@ def echo(*args, **kwargs) -> None:
5757
_echo(*args, **kwargs)
5858

5959

60-
def error(msg: str) -> None:
60+
def error(msg: str) -> NoReturn:
6161
"""Print an error and exit."""
6262
echo(f"[bold red]{msg}[/bold red]")
6363
sys.exit(1)
@@ -68,6 +68,16 @@ def json_formatter_cb(result: Any, **kwargs) -> None:
6868
if not kwargs.get("json"):
6969
return
7070

71+
# Calling the discover command directly always returns a DeviceDict so if host
72+
# was specified just format the device json
73+
if (
74+
(host := kwargs.get("host"))
75+
and isinstance(result, dict)
76+
and (dev := result.get(host))
77+
and isinstance(dev, Device)
78+
):
79+
result = dev
80+
7181
@singledispatch
7282
def to_serializable(val):
7383
"""Regular obj-to-string for json serialization.
@@ -85,6 +95,25 @@ def _device_to_serializable(val: Device):
8595
print(json_content)
8696

8797

98+
async def invoke_subcommand(
99+
command: click.BaseCommand,
100+
ctx: click.Context,
101+
args: list[str] | None = None,
102+
**extra: Any,
103+
) -> Any:
104+
"""Invoke a click subcommand.
105+
106+
Calling ctx.Invoke() treats the command like a simple callback and doesn't
107+
process any result_callbacks so we use this pattern from the click docs
108+
https://click.palletsprojects.com/en/stable/exceptions/#what-if-i-don-t-want-that.
109+
"""
110+
if args is None:
111+
args = []
112+
sub_ctx = await command.make_context(command.name, args, parent=ctx, **extra)
113+
async with sub_ctx:
114+
return await command.invoke(sub_ctx)
115+
116+
88117
def pass_dev_or_child(wrapped_function: Callable) -> Callable:
89118
"""Pass the device or child to the click command based on the child options."""
90119
child_help = (

kasa/cli/device.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from pprint import pformat as pf
6+
from typing import TYPE_CHECKING
67

78
import asyncclick as click
89

@@ -82,6 +83,8 @@ async def state(ctx, dev: Device):
8283
echo()
8384
from .discover import _echo_discovery_info
8485

86+
if TYPE_CHECKING:
87+
assert dev._discovery_info
8588
_echo_discovery_info(dev._discovery_info)
8689

8790
return dev.internal_state

kasa/cli/discover.py

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import asyncio
66
from pprint import pformat as pf
7+
from typing import TYPE_CHECKING, cast
78

89
import asyncclick as click
910

@@ -17,8 +18,12 @@
1718
from kasa.discover import (
1819
NEW_DISCOVERY_REDACTORS,
1920
ConnectAttempt,
21+
DeviceDict,
2022
DiscoveredRaw,
2123
DiscoveryResult,
24+
OnDiscoveredCallable,
25+
OnDiscoveredRawCallable,
26+
OnUnsupportedCallable,
2227
)
2328
from kasa.iot.iotdevice import _extract_sys_info
2429
from kasa.protocols.iotprotocol import REDACTORS as IOT_REDACTORS
@@ -30,15 +35,33 @@
3035

3136
@click.group(invoke_without_command=True)
3237
@click.pass_context
33-
async def discover(ctx):
38+
async def discover(ctx: click.Context):
3439
"""Discover devices in the network."""
3540
if ctx.invoked_subcommand is None:
3641
return await ctx.invoke(detail)
3742

3843

44+
@discover.result_callback()
45+
@click.pass_context
46+
async def _close_protocols(ctx: click.Context, discovered: DeviceDict):
47+
"""Close all the device protocols if discover was invoked directly by the user."""
48+
if _discover_is_root_cmd(ctx):
49+
for dev in discovered.values():
50+
await dev.disconnect()
51+
return discovered
52+
53+
54+
def _discover_is_root_cmd(ctx: click.Context) -> bool:
55+
"""Will return true if discover was invoked directly by the user."""
56+
root_ctx = ctx.find_root()
57+
return (
58+
root_ctx.invoked_subcommand is None or root_ctx.invoked_subcommand == "discover"
59+
)
60+
61+
3962
@discover.command()
4063
@click.pass_context
41-
async def detail(ctx):
64+
async def detail(ctx: click.Context) -> DeviceDict:
4265
"""Discover devices in the network using udp broadcasts."""
4366
unsupported = []
4467
auth_failed = []
@@ -59,10 +82,14 @@ async def print_unsupported(unsupported_exception: UnsupportedDeviceError) -> No
5982
from .device import state
6083

6184
async def print_discovered(dev: Device) -> None:
85+
if TYPE_CHECKING:
86+
assert ctx.parent
6287
async with sem:
6388
try:
6489
await dev.update()
6590
except AuthenticationError:
91+
if TYPE_CHECKING:
92+
assert dev._discovery_info
6693
auth_failed.append(dev._discovery_info)
6794
echo("== Authentication failed for device ==")
6895
_echo_discovery_info(dev._discovery_info)
@@ -73,9 +100,11 @@ async def print_discovered(dev: Device) -> None:
73100
echo()
74101

75102
discovered = await _discover(
76-
ctx, print_discovered=print_discovered, print_unsupported=print_unsupported
103+
ctx,
104+
print_discovered=print_discovered if _discover_is_root_cmd(ctx) else None,
105+
print_unsupported=print_unsupported,
77106
)
78-
if ctx.parent.parent.params["host"]:
107+
if ctx.find_root().params["host"]:
79108
return discovered
80109

81110
echo(f"Found {len(discovered)} devices")
@@ -96,7 +125,7 @@ async def print_discovered(dev: Device) -> None:
96125
help="Set flag to redact sensitive data from raw output.",
97126
)
98127
@click.pass_context
99-
async def raw(ctx, redact: bool):
128+
async def raw(ctx: click.Context, redact: bool) -> DeviceDict:
100129
"""Return raw discovery data returned from devices."""
101130

102131
def print_raw(discovered: DiscoveredRaw):
@@ -116,7 +145,7 @@ def print_raw(discovered: DiscoveredRaw):
116145

117146
@discover.command()
118147
@click.pass_context
119-
async def list(ctx):
148+
async def list(ctx: click.Context) -> DeviceDict:
120149
"""List devices in the network in a table using udp broadcasts."""
121150
sem = asyncio.Semaphore()
122151

@@ -147,18 +176,24 @@ async def print_unsupported(unsupported_exception: UnsupportedDeviceError):
147176
f"{'HOST':<15} {'MODEL':<9} {'DEVICE FAMILY':<20} {'ENCRYPT':<7} "
148177
f"{'HTTPS':<5} {'LV':<3} {'ALIAS'}"
149178
)
150-
return await _discover(
179+
discovered = await _discover(
151180
ctx,
152181
print_discovered=print_discovered,
153182
print_unsupported=print_unsupported,
154183
do_echo=False,
155184
)
185+
return discovered
156186

157187

158188
async def _discover(
159-
ctx, *, print_discovered=None, print_unsupported=None, print_raw=None, do_echo=True
160-
):
161-
params = ctx.parent.parent.params
189+
ctx: click.Context,
190+
*,
191+
print_discovered: OnDiscoveredCallable | None = None,
192+
print_unsupported: OnUnsupportedCallable | None = None,
193+
print_raw: OnDiscoveredRawCallable | None = None,
194+
do_echo=True,
195+
) -> DeviceDict:
196+
params = ctx.find_root().params
162197
target = params["target"]
163198
username = params["username"]
164199
password = params["password"]
@@ -170,8 +205,9 @@ async def _discover(
170205
credentials = Credentials(username, password) if username and password else None
171206

172207
if host:
208+
host = cast(str, host)
173209
echo(f"Discovering device {host} for {discovery_timeout} seconds")
174-
return await Discover.discover_single(
210+
dev = await Discover.discover_single(
175211
host,
176212
port=port,
177213
credentials=credentials,
@@ -180,6 +216,12 @@ async def _discover(
180216
on_unsupported=print_unsupported,
181217
on_discovered_raw=print_raw,
182218
)
219+
if dev:
220+
if print_discovered:
221+
await print_discovered(dev)
222+
return {host: dev}
223+
else:
224+
return {}
183225
if do_echo:
184226
echo(f"Discovering devices on {target} for {discovery_timeout} seconds")
185227
discovered_devices = await Discover.discover(
@@ -193,21 +235,18 @@ async def _discover(
193235
on_discovered_raw=print_raw,
194236
)
195237

196-
for device in discovered_devices.values():
197-
await device.protocol.close()
198-
199238
return discovered_devices
200239

201240

202241
@discover.command()
203242
@click.pass_context
204-
async def config(ctx):
243+
async def config(ctx: click.Context) -> DeviceDict:
205244
"""Bypass udp discovery and try to show connection config for a device.
206245
207246
Bypasses udp discovery and shows the parameters required to connect
208247
directly to the device.
209248
"""
210-
params = ctx.parent.parent.params
249+
params = ctx.find_root().params
211250
username = params["username"]
212251
password = params["password"]
213252
timeout = params["timeout"]
@@ -239,6 +278,7 @@ def on_attempt(connect_attempt: ConnectAttempt, success: bool) -> None:
239278
f"--encrypt-type {cparams.encryption_type.value} "
240279
f"{'--https' if cparams.https else '--no-https'}"
241280
)
281+
return {host: dev}
242282
else:
243283
error(f"Unable to connect to {host}")
244284

@@ -251,7 +291,7 @@ def _echo_dictionary(discovery_info: dict) -> None:
251291
echo(f"\t{key_name_and_spaces}{value}")
252292

253293

254-
def _echo_discovery_info(discovery_info) -> None:
294+
def _echo_discovery_info(discovery_info: dict) -> None:
255295
# We don't have discovery info when all connection params are passed manually
256296
if discovery_info is None:
257297
return

kasa/cli/main.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
CatchAllExceptions,
2323
echo,
2424
error,
25+
invoke_subcommand,
2526
json_formatter_cb,
2627
pass_dev_or_child,
2728
)
@@ -295,9 +296,10 @@ async def cli(
295296
echo("No host name given, trying discovery..")
296297
from .discover import discover
297298

298-
return await ctx.invoke(discover)
299+
return await invoke_subcommand(discover, ctx)
299300

300301
device_updated = False
302+
device_discovered = False
301303

302304
if type is not None and type not in {"smart", "camera"}:
303305
from kasa.deviceconfig import DeviceConfig
@@ -351,12 +353,14 @@ async def cli(
351353
return
352354
echo(f"Found hostname by alias: {dev.host}")
353355
device_updated = True
354-
else:
356+
else: # host will be set
355357
from .discover import discover
356358

357-
dev = await ctx.invoke(discover)
358-
if not dev:
359+
discovered = await invoke_subcommand(discover, ctx)
360+
if not discovered:
359361
error(f"Unable to create device for {host}")
362+
dev = discovered[host]
363+
device_discovered = True
360364

361365
# Skip update on specific commands, or if device factory,
362366
# that performs an update was used for the device.
@@ -372,11 +376,14 @@ async def async_wrapped_device(device: Device):
372376

373377
ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev))
374378

375-
if ctx.invoked_subcommand is None:
379+
# discover command has already invoked state
380+
if ctx.invoked_subcommand is None and not device_discovered:
376381
from .device import state
377382

378383
return await ctx.invoke(state)
379384

385+
return dev
386+
380387

381388
@cli.command()
382389
@pass_dev_or_child

0 commit comments

Comments
 (0)