From 704ea99c1e8dafeb2b2b48b206e02d25912c336c Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Mon, 29 Apr 2024 08:30:35 +0100 Subject: [PATCH 1/6] Put modules back on children for wall switches --- kasa/device.py | 7 +++- kasa/iot/iotdevice.py | 38 +++++++++++++--------- kasa/iot/iotstrip.py | 2 +- kasa/module.py | 5 ++- kasa/smart/smartdevice.py | 59 ++++++++++++++++++++-------------- kasa/tests/test_smartdevice.py | 4 +-- 6 files changed, 70 insertions(+), 45 deletions(-) diff --git a/kasa/device.py b/kasa/device.py index dda7822f9..8a81030f8 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -15,6 +15,7 @@ from .exceptions import KasaException from .feature import Feature from .iotprotocol import IotProtocol +from .module import Module from .protocol import BaseProtocol from .xortransport import XorTransport @@ -72,7 +73,6 @@ def __init__( self._last_update: Any = None self._discovery_info: dict[str, Any] | None = None - self.modules: dict[str, Any] = {} self._features: dict[str, Feature] = {} self._parent: Device | None = None self._children: Mapping[str, Device] = {} @@ -111,6 +111,11 @@ async def disconnect(self): """Disconnect and close any underlying connection resources.""" await self.protocol.close() + @property + @abstractmethod + def modules(self) -> Mapping[str, Module]: + """Return the device modules.""" + @property @abstractmethod def is_on(self) -> bool: diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index d4551d0db..6ade57db9 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -189,12 +189,18 @@ def __init__( self._supported_modules: dict[str, IotModule] | None = None self._legacy_features: set[str] = set() self._children: Mapping[str, IotDevice] = {} + self._modules: dict[str, IotModule] = {} @property def children(self) -> Sequence[IotDevice]: """Return list of children.""" return list(self._children.values()) + @property + def modules(self) -> dict[str, IotModule]: + """Return the device modules.""" + return self._modules + def add_module(self, name: str, module: IotModule): """Register a module.""" if name in self.modules: @@ -420,31 +426,31 @@ async def set_alias(self, alias: str) -> None: """Set the device name (alias).""" return await self._query_helper("system", "set_dev_alias", {"alias": alias}) - @property # type: ignore + @property @requires_update def time(self) -> datetime: """Return current time from the device.""" - return self.modules["time"].time + return self.modules["time"].time # type: ignore[attr-defined] - @property # type: ignore + @property @requires_update def timezone(self) -> dict: """Return the current timezone.""" - return self.modules["time"].timezone + return self.modules["time"].timezone # type: ignore[attr-defined] async def get_time(self) -> datetime | None: """Return current time from the device, if available.""" _LOGGER.warning( "Use `time` property instead, this call will be removed in the future." ) - return await self.modules["time"].get_time() + return await self.modules["time"].get_time() # type: ignore[attr-defined] async def get_timezone(self) -> dict: """Return timezone information.""" _LOGGER.warning( "Use `timezone` property instead, this call will be removed in the future." ) - return await self.modules["time"].get_timezone() + return await self.modules["time"].get_timezone() # type: ignore[attr-defined] @property # type: ignore @requires_update @@ -520,31 +526,31 @@ async def set_mac(self, mac): """ return await self._query_helper("system", "set_mac_addr", {"mac": mac}) - @property # type: ignore + @property @requires_update def emeter_realtime(self) -> EmeterStatus: """Return current energy readings.""" self._verify_emeter() - return EmeterStatus(self.modules["emeter"].realtime) + return EmeterStatus(self.modules["emeter"].realtime) # type: ignore[attr-defined] async def get_emeter_realtime(self) -> EmeterStatus: """Retrieve current energy readings.""" self._verify_emeter() - return EmeterStatus(await self.modules["emeter"].get_realtime()) + return EmeterStatus(await self.modules["emeter"].get_realtime()) # type: ignore[attr-defined] - @property # type: ignore + @property @requires_update def emeter_today(self) -> float | None: """Return today's energy consumption in kWh.""" self._verify_emeter() - return self.modules["emeter"].emeter_today + return self.modules["emeter"].emeter_today # type: ignore[attr-defined] - @property # type: ignore + @property @requires_update def emeter_this_month(self) -> float | None: """Return this month's energy consumption in kWh.""" self._verify_emeter() - return self.modules["emeter"].emeter_this_month + return self.modules["emeter"].emeter_this_month # type: ignore[attr-defined] async def get_emeter_daily( self, year: int | None = None, month: int | None = None, kwh: bool = True @@ -558,7 +564,7 @@ async def get_emeter_daily( :return: mapping of day of month to value """ self._verify_emeter() - return await self.modules["emeter"].get_daystat(year=year, month=month, kwh=kwh) + return await self.modules["emeter"].get_daystat(year=year, month=month, kwh=kwh) # type: ignore[attr-defined] @requires_update async def get_emeter_monthly( @@ -571,13 +577,13 @@ async def get_emeter_monthly( :return: dict: mapping of month to value """ self._verify_emeter() - return await self.modules["emeter"].get_monthstat(year=year, kwh=kwh) + return await self.modules["emeter"].get_monthstat(year=year, kwh=kwh) # type: ignore[attr-defined] @requires_update async def erase_emeter_stats(self) -> dict: """Erase energy meter statistics.""" self._verify_emeter() - return await self.modules["emeter"].erase_stats() + return await self.modules["emeter"].erase_stats() # type: ignore[attr-defined] @requires_update async def current_consumption(self) -> float: diff --git a/kasa/iot/iotstrip.py b/kasa/iot/iotstrip.py index 99f5913d6..9e99a0748 100755 --- a/kasa/iot/iotstrip.py +++ b/kasa/iot/iotstrip.py @@ -253,7 +253,7 @@ def __init__(self, host: str, parent: IotStrip, child_id: str) -> None: self._last_update = parent._last_update self._set_sys_info(parent.sys_info) self._device_type = DeviceType.StripSocket - self.modules = {} + self._modules = {} self.protocol = parent.protocol # Must use the same connection as the parent self.add_module("time", Time(self, "time")) diff --git a/kasa/module.py b/kasa/module.py index ad0b5562a..213a2e0ac 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -4,11 +4,14 @@ import logging from abc import ABC, abstractmethod +from typing import TYPE_CHECKING -from .device import Device from .exceptions import KasaException from .feature import Feature +if TYPE_CHECKING: + from .device import Device + _LOGGER = logging.getLogger(__name__) diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index b325614be..69ee44f3f 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -47,7 +47,7 @@ def __init__( self._components_raw: dict[str, Any] | None = None self._components: dict[str, int] = {} self._state_information: dict[str, Any] = {} - self.modules: dict[str, SmartModule] = {} + self._modules: dict[str, SmartModule] = {} self._parent: SmartDevice | None = None self._children: Mapping[str, SmartDevice] = {} self._last_update = {} @@ -84,11 +84,19 @@ async def _initialize_children(self): @property def children(self) -> Sequence[SmartDevice]: """Return list of children.""" - # Wall switches with children report all modules on the parent only - if self.device_type == DeviceType.WallSwitch: - return [] return list(self._children.values()) + @property + def modules(self) -> dict[str, SmartModule]: + """Return the device modules.""" + if self._device_type == DeviceType.WallSwitch and self._children: + modules = {k: v for k, v in self._modules.items()} + for child in self._children.values(): + for modname, mod in child._modules.items(): + modules[modname] = mod + return modules + return self._modules + def _try_get_response(self, responses: dict, request: str, default=None) -> dict: response = responses.get(request) if isinstance(response, SmartErrorCode): @@ -148,7 +156,7 @@ async def update(self, update_children: bool = True): req: dict[str, Any] = {} # TODO: this could be optimized by constructing the query only once - for module in self.modules.values(): + for module in self._modules.values(): req.update(module.query()) self._last_update = resp = await self.protocol.query(req) @@ -174,19 +182,22 @@ async def _initialize_modules(self): # Some wall switches (like ks240) are internally presented as having child # devices which report the child's components on the parent's sysinfo, even # when they need to be accessed through the children. - # The logic below ensures that such devices report all but whitelisted, the - # child modules at the parent level to create an illusion of a single device. + # The logic below ensures that such devices add all but whitelisted, only on + # the child device. + skip_parent_only_modules = False + child_modules_to_skip = set() if self._parent and self._parent.device_type == DeviceType.WallSwitch: - modules = self._parent.modules skip_parent_only_modules = True - else: - modules = self.modules - skip_parent_only_modules = False + elif self._children and self.device_type == DeviceType.WallSwitch: + for child in self._children.values(): + child_modules_to_skip.update(set(child.modules.values())) for mod in SmartModule.REGISTERED_MODULES.values(): _LOGGER.debug("%s requires %s", mod, mod.REQUIRED_COMPONENT) - if skip_parent_only_modules and mod in WALL_SWITCH_PARENT_ONLY_MODULES: + if ( + skip_parent_only_modules and mod in WALL_SWITCH_PARENT_ONLY_MODULES + ) or mod in child_modules_to_skip: continue if mod.REQUIRED_COMPONENT in self._components: _LOGGER.debug( @@ -195,8 +206,8 @@ async def _initialize_modules(self): mod.__name__, ) module = mod(self, mod.REQUIRED_COMPONENT) - if module.name not in modules and await module._check_supported(): - modules[module.name] = module + if module.name not in self._modules and await module._check_supported(): + self._modules[module.name] = module async def _initialize_features(self): """Initialize device features.""" @@ -278,16 +289,16 @@ async def _initialize_features(self): ) ) - for module in self.modules.values(): + for module in self._modules.values(): for feat in module._module_features.values(): self._add_feature(feat) @property def is_cloud_connected(self): """Returns if the device is connected to the cloud.""" - if "CloudModule" not in self.modules: + if "CloudModule" not in self._modules: return False - return self.modules["CloudModule"].is_connected + return self._modules["CloudModule"].is_connected @property def sys_info(self) -> dict[str, Any]: @@ -311,10 +322,10 @@ def alias(self) -> str | None: def time(self) -> datetime: """Return the time.""" # TODO: Default to parent's time module for child devices - if self._parent and "TimeModule" in self.modules: + if self._parent and "TimeModule" in self._modules: _timemod = cast(TimeModule, self._parent.modules["TimeModule"]) # noqa: F405 else: - _timemod = cast(TimeModule, self.modules["TimeModule"]) # noqa: F405 + _timemod = cast(TimeModule, self._modules["TimeModule"]) # noqa: F405 return _timemod.time @@ -391,7 +402,7 @@ def ssid(self) -> str: @property def has_emeter(self) -> bool: """Return if the device has emeter.""" - return "EnergyModule" in self.modules + return "EnergyModule" in self._modules @property def is_on(self) -> bool: @@ -428,19 +439,19 @@ async def get_emeter_realtime(self) -> EmeterStatus: @property def emeter_realtime(self) -> EmeterStatus: """Get the emeter status.""" - energy = cast(EnergyModule, self.modules["EnergyModule"]) # noqa: F405 + energy = cast(EnergyModule, self._modules["EnergyModule"]) # noqa: F405 return energy.emeter_realtime @property def emeter_this_month(self) -> float | None: """Get the emeter value for this month.""" - energy = cast(EnergyModule, self.modules["EnergyModule"]) # noqa: F405 + energy = cast(EnergyModule, self._modules["EnergyModule"]) # noqa: F405 return energy.emeter_this_month @property def emeter_today(self) -> float | None: """Get the emeter value for today.""" - energy = cast(EnergyModule, self.modules["EnergyModule"]) # noqa: F405 + energy = cast(EnergyModule, self._modules["EnergyModule"]) # noqa: F405 return energy.emeter_today @property @@ -452,7 +463,7 @@ def on_since(self) -> datetime | None: ): return None on_time = cast(float, on_time) - if (timemod := self.modules.get("TimeModule")) is not None: + if (timemod := self._modules.get("TimeModule")) is not None: timemod = cast(TimeModule, timemod) # noqa: F405 return timemod.time - timedelta(seconds=on_time) else: # We have no device time, use current local time. diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 037edaf90..899ed6e14 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -103,10 +103,10 @@ async def test_negotiate(dev: SmartDevice, mocker: MockerFixture): async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): """Test that the regular update uses queries from all supported modules.""" # We need to have some modules initialized by now - assert dev.modules + assert dev._modules device_queries: dict[SmartDevice, dict[str, Any]] = {} - for mod in dev.modules.values(): + for mod in dev._modules.values(): device_queries.setdefault(mod._device, {}).update(mod.query()) spies = {} From 8c13c6a2d77682eebba60a4b52c163955e128862 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Mon, 29 Apr 2024 09:33:28 +0100 Subject: [PATCH 2/6] Fix mypy issues --- kasa/cli.py | 10 +++++----- kasa/tests/smart/modules/test_fan.py | 7 +++++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/kasa/cli.py b/kasa/cli.py index 66eb89368..bb8fdd371 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -833,20 +833,20 @@ async def usage(dev: Device, year, month, erase): if erase: echo("Erasing usage statistics..") - return await usage.erase_stats() + return await usage.erase_stats() # type: ignore[attr-defined] if year: echo(f"== For year {year.year} ==") echo("Month, usage (minutes)") - usage_data = await usage.get_monthstat(year=year.year) + usage_data = await usage.get_monthstat(year=year.year) # type: ignore[attr-defined] elif month: echo(f"== For month {month.month} of {month.year} ==") echo("Day, usage (minutes)") - usage_data = await usage.get_daystat(year=month.year, month=month.month) + usage_data = await usage.get_daystat(year=month.year, month=month.month) # type: ignore[attr-defined] else: # Call with no argument outputs summary data and returns - echo("Today: %s minutes" % usage.usage_today) - echo("This month: %s minutes" % usage.usage_this_month) + echo("Today: %s minutes" % usage.usage_today) # type: ignore[attr-defined] + echo("This month: %s minutes" % usage.usage_this_month) # type: ignore[attr-defined] return usage diff --git a/kasa/tests/smart/modules/test_fan.py b/kasa/tests/smart/modules/test_fan.py index 559ffefe0..41d5706cc 100644 --- a/kasa/tests/smart/modules/test_fan.py +++ b/kasa/tests/smart/modules/test_fan.py @@ -1,6 +1,9 @@ +from typing import cast + from pytest_mock import MockerFixture from kasa import SmartDevice +from kasa.smart.modules import FanModule from kasa.tests.device_fixtures import parametrize fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"SMART"}) @@ -9,7 +12,7 @@ @fan async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): """Test fan speed feature.""" - fan = dev.modules.get("FanModule") + fan = cast(FanModule, dev.modules.get("FanModule")) assert fan level_feature = fan._module_features["fan_speed_level"] @@ -32,7 +35,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): @fan async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): """Test sleep mode feature.""" - fan = dev.modules.get("FanModule") + fan = cast(FanModule, dev.modules.get("FanModule")) assert fan sleep_feature = fan._module_features["fan_sleep_mode"] assert isinstance(sleep_feature.value, bool) From 3be3abd98d294c1791063bee244b923942264b26 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Mon, 29 Apr 2024 14:59:15 +0100 Subject: [PATCH 3/6] Update post review --- kasa/smart/smartdevice.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 69ee44f3f..92edbc18d 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -48,6 +48,7 @@ def __init__( self._components: dict[str, int] = {} self._state_information: dict[str, Any] = {} self._modules: dict[str, SmartModule] = {} + self._combined_modules: dict[str, SmartModule] | None = None self._parent: SmartDevice | None = None self._children: Mapping[str, SmartDevice] = {} self._last_update = {} @@ -90,11 +91,12 @@ def children(self) -> Sequence[SmartDevice]: def modules(self) -> dict[str, SmartModule]: """Return the device modules.""" if self._device_type == DeviceType.WallSwitch and self._children: - modules = {k: v for k, v in self._modules.items()} - for child in self._children.values(): - for modname, mod in child._modules.items(): - modules[modname] = mod - return modules + if self._combined_modules is None: + self._combined_modules = {k: v for k, v in self._modules.items()} + for child in self._children.values(): + for modname, mod in child._modules.items(): + self._combined_modules[modname] = mod + return self._combined_modules return self._modules def _try_get_response(self, responses: dict, request: str, default=None) -> dict: @@ -206,7 +208,7 @@ async def _initialize_modules(self): mod.__name__, ) module = mod(self, mod.REQUIRED_COMPONENT) - if module.name not in self._modules and await module._check_supported(): + if await module._check_supported(): self._modules[module.name] = module async def _initialize_features(self): @@ -296,9 +298,9 @@ async def _initialize_features(self): @property def is_cloud_connected(self): """Returns if the device is connected to the cloud.""" - if "CloudModule" not in self._modules: + if "CloudModule" not in self.modules: return False - return self._modules["CloudModule"].is_connected + return self.modules["CloudModule"].is_connected @property def sys_info(self) -> dict[str, Any]: @@ -322,10 +324,10 @@ def alias(self) -> str | None: def time(self) -> datetime: """Return the time.""" # TODO: Default to parent's time module for child devices - if self._parent and "TimeModule" in self._modules: + if self._parent and "TimeModule" in self.modules: _timemod = cast(TimeModule, self._parent.modules["TimeModule"]) # noqa: F405 else: - _timemod = cast(TimeModule, self._modules["TimeModule"]) # noqa: F405 + _timemod = cast(TimeModule, self.modules["TimeModule"]) # noqa: F405 return _timemod.time @@ -402,7 +404,7 @@ def ssid(self) -> str: @property def has_emeter(self) -> bool: """Return if the device has emeter.""" - return "EnergyModule" in self._modules + return "EnergyModule" in self.modules @property def is_on(self) -> bool: @@ -439,19 +441,19 @@ async def get_emeter_realtime(self) -> EmeterStatus: @property def emeter_realtime(self) -> EmeterStatus: """Get the emeter status.""" - energy = cast(EnergyModule, self._modules["EnergyModule"]) # noqa: F405 + energy = cast(EnergyModule, self.modules["EnergyModule"]) # noqa: F405 return energy.emeter_realtime @property def emeter_this_month(self) -> float | None: """Get the emeter value for this month.""" - energy = cast(EnergyModule, self._modules["EnergyModule"]) # noqa: F405 + energy = cast(EnergyModule, self.modules["EnergyModule"]) # noqa: F405 return energy.emeter_this_month @property def emeter_today(self) -> float | None: """Get the emeter value for today.""" - energy = cast(EnergyModule, self._modules["EnergyModule"]) # noqa: F405 + energy = cast(EnergyModule, self.modules["EnergyModule"]) # noqa: F405 return energy.emeter_today @property @@ -463,7 +465,7 @@ def on_since(self) -> datetime | None: ): return None on_time = cast(float, on_time) - if (timemod := self._modules.get("TimeModule")) is not None: + if (timemod := self.modules.get("TimeModule")) is not None: timemod = cast(TimeModule, timemod) # noqa: F405 return timemod.time - timedelta(seconds=on_time) else: # We have no device time, use current local time. From 94de66858795f6a4b28d43598356b8ca4585a6f5 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Mon, 29 Apr 2024 16:09:19 +0100 Subject: [PATCH 4/6] Consolidate logic inside _initialize_modules --- kasa/smart/smartdevice.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 92edbc18d..e07a37a1f 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -48,7 +48,8 @@ def __init__( self._components: dict[str, int] = {} self._state_information: dict[str, Any] = {} self._modules: dict[str, SmartModule] = {} - self._combined_modules: dict[str, SmartModule] | None = None + self._exposes_child_modules = False + self._combined_modules: dict[str, SmartModule] = {} self._parent: SmartDevice | None = None self._children: Mapping[str, SmartDevice] = {} self._last_update = {} @@ -90,12 +91,7 @@ def children(self) -> Sequence[SmartDevice]: @property def modules(self) -> dict[str, SmartModule]: """Return the device modules.""" - if self._device_type == DeviceType.WallSwitch and self._children: - if self._combined_modules is None: - self._combined_modules = {k: v for k, v in self._modules.items()} - for child in self._children.values(): - for modname, mod in child._modules.items(): - self._combined_modules[modname] = mod + if self._exposes_child_modules: return self._combined_modules return self._modules @@ -187,19 +183,21 @@ async def _initialize_modules(self): # The logic below ensures that such devices add all but whitelisted, only on # the child device. skip_parent_only_modules = False - child_modules_to_skip = set() + child_modules_to_skip = {} if self._parent and self._parent.device_type == DeviceType.WallSwitch: skip_parent_only_modules = True elif self._children and self.device_type == DeviceType.WallSwitch: + # _initialize_modules is called on the parent after the children + self._exposes_child_modules = True for child in self._children.values(): - child_modules_to_skip.update(set(child.modules.values())) + child_modules_to_skip.update(**child.modules) for mod in SmartModule.REGISTERED_MODULES.values(): _LOGGER.debug("%s requires %s", mod, mod.REQUIRED_COMPONENT) if ( skip_parent_only_modules and mod in WALL_SWITCH_PARENT_ONLY_MODULES - ) or mod in child_modules_to_skip: + ) or mod.__name__ in child_modules_to_skip: continue if mod.REQUIRED_COMPONENT in self._components: _LOGGER.debug( @@ -211,6 +209,10 @@ async def _initialize_modules(self): if await module._check_supported(): self._modules[module.name] = module + if self._exposes_child_modules: + self._combined_modules = {k: v for k, v in self._modules.items()} + self._combined_modules.update(**child_modules_to_skip) + async def _initialize_features(self): """Initialize device features.""" self._add_feature( From df4a3109603042d873bf1e625e1ddb29acfe8761 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Mon, 29 Apr 2024 16:42:24 +0100 Subject: [PATCH 5/6] Drop combined_modules and add child modules directly to parent modules --- kasa/smart/smartdevice.py | 6 +----- kasa/tests/test_smartdevice.py | 12 ++++++------ 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index e07a37a1f..80528fe44 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -49,7 +49,6 @@ def __init__( self._state_information: dict[str, Any] = {} self._modules: dict[str, SmartModule] = {} self._exposes_child_modules = False - self._combined_modules: dict[str, SmartModule] = {} self._parent: SmartDevice | None = None self._children: Mapping[str, SmartDevice] = {} self._last_update = {} @@ -91,8 +90,6 @@ def children(self) -> Sequence[SmartDevice]: @property def modules(self) -> dict[str, SmartModule]: """Return the device modules.""" - if self._exposes_child_modules: - return self._combined_modules return self._modules def _try_get_response(self, responses: dict, request: str, default=None) -> dict: @@ -210,8 +207,7 @@ async def _initialize_modules(self): self._modules[module.name] = module if self._exposes_child_modules: - self._combined_modules = {k: v for k, v in self._modules.items()} - self._combined_modules.update(**child_modules_to_skip) + self._modules.update(**child_modules_to_skip) async def _initialize_features(self): """Initialize device features.""" diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 899ed6e14..2b39e105a 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -110,15 +110,15 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): device_queries.setdefault(mod._device, {}).update(mod.query()) spies = {} - for dev in device_queries: - spies[dev] = mocker.spy(dev.protocol, "query") + for device in device_queries: + spies[device] = mocker.spy(device.protocol, "query") await dev.update() - for dev in device_queries: - if device_queries[dev]: - spies[dev].assert_called_with(device_queries[dev]) + for device in device_queries: + if device_queries[device]: + spies[device].assert_called_with(device_queries[device]) else: - spies[dev].assert_not_called() + spies[device].assert_not_called() @bulb_smart From 1cdd3e7d94bf580f2902e9631214e7eef6272d53 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Mon, 29 Apr 2024 17:26:29 +0100 Subject: [PATCH 6/6] Replace mypy ignores with cast --- kasa/cli.py | 13 +++++++------ kasa/iot/iotdevice.py | 30 +++++++++++++++++------------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/kasa/cli.py b/kasa/cli.py index bb8fdd371..317bf0383 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -39,6 +39,7 @@ IotStrip, IotWallSwitch, ) +from kasa.iot.modules import Usage from kasa.smart import SmartBulb, SmartDevice try: @@ -829,24 +830,24 @@ async def usage(dev: Device, year, month, erase): Daily and monthly data provided in CSV format. """ echo("[bold]== Usage ==[/bold]") - usage = dev.modules["usage"] + usage = cast(Usage, dev.modules["usage"]) if erase: echo("Erasing usage statistics..") - return await usage.erase_stats() # type: ignore[attr-defined] + return await usage.erase_stats() if year: echo(f"== For year {year.year} ==") echo("Month, usage (minutes)") - usage_data = await usage.get_monthstat(year=year.year) # type: ignore[attr-defined] + usage_data = await usage.get_monthstat(year=year.year) elif month: echo(f"== For month {month.month} of {month.year} ==") echo("Day, usage (minutes)") - usage_data = await usage.get_daystat(year=month.year, month=month.month) # type: ignore[attr-defined] + usage_data = await usage.get_daystat(year=month.year, month=month.month) else: # Call with no argument outputs summary data and returns - echo("Today: %s minutes" % usage.usage_today) # type: ignore[attr-defined] - echo("This month: %s minutes" % usage.usage_this_month) # type: ignore[attr-defined] + echo("Today: %s minutes" % usage.usage_today) + echo("This month: %s minutes" % usage.usage_this_month) return usage diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index 6ade57db9..81b5eddac 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -19,7 +19,7 @@ import inspect import logging from datetime import datetime, timedelta -from typing import Any, Mapping, Sequence +from typing import Any, Mapping, Sequence, cast from ..device import Device, WifiNetwork from ..deviceconfig import DeviceConfig @@ -28,7 +28,7 @@ from ..feature import Feature from ..protocol import BaseProtocol from .iotmodule import IotModule -from .modules import Emeter +from .modules import Emeter, Time _LOGGER = logging.getLogger(__name__) @@ -430,27 +430,27 @@ async def set_alias(self, alias: str) -> None: @requires_update def time(self) -> datetime: """Return current time from the device.""" - return self.modules["time"].time # type: ignore[attr-defined] + return cast(Time, self.modules["time"]).time @property @requires_update def timezone(self) -> dict: """Return the current timezone.""" - return self.modules["time"].timezone # type: ignore[attr-defined] + return cast(Time, self.modules["time"]).timezone async def get_time(self) -> datetime | None: """Return current time from the device, if available.""" _LOGGER.warning( "Use `time` property instead, this call will be removed in the future." ) - return await self.modules["time"].get_time() # type: ignore[attr-defined] + return await cast(Time, self.modules["time"]).get_time() async def get_timezone(self) -> dict: """Return timezone information.""" _LOGGER.warning( "Use `timezone` property instead, this call will be removed in the future." ) - return await self.modules["time"].get_timezone() # type: ignore[attr-defined] + return await cast(Time, self.modules["time"]).get_timezone() @property # type: ignore @requires_update @@ -531,26 +531,26 @@ async def set_mac(self, mac): def emeter_realtime(self) -> EmeterStatus: """Return current energy readings.""" self._verify_emeter() - return EmeterStatus(self.modules["emeter"].realtime) # type: ignore[attr-defined] + return EmeterStatus(cast(Emeter, self.modules["emeter"]).realtime) async def get_emeter_realtime(self) -> EmeterStatus: """Retrieve current energy readings.""" self._verify_emeter() - return EmeterStatus(await self.modules["emeter"].get_realtime()) # type: ignore[attr-defined] + return EmeterStatus(await cast(Emeter, self.modules["emeter"]).get_realtime()) @property @requires_update def emeter_today(self) -> float | None: """Return today's energy consumption in kWh.""" self._verify_emeter() - return self.modules["emeter"].emeter_today # type: ignore[attr-defined] + return cast(Emeter, self.modules["emeter"]).emeter_today @property @requires_update def emeter_this_month(self) -> float | None: """Return this month's energy consumption in kWh.""" self._verify_emeter() - return self.modules["emeter"].emeter_this_month # type: ignore[attr-defined] + return cast(Emeter, self.modules["emeter"]).emeter_this_month async def get_emeter_daily( self, year: int | None = None, month: int | None = None, kwh: bool = True @@ -564,7 +564,9 @@ async def get_emeter_daily( :return: mapping of day of month to value """ self._verify_emeter() - return await self.modules["emeter"].get_daystat(year=year, month=month, kwh=kwh) # type: ignore[attr-defined] + return await cast(Emeter, self.modules["emeter"]).get_daystat( + year=year, month=month, kwh=kwh + ) @requires_update async def get_emeter_monthly( @@ -577,13 +579,15 @@ async def get_emeter_monthly( :return: dict: mapping of month to value """ self._verify_emeter() - return await self.modules["emeter"].get_monthstat(year=year, kwh=kwh) # type: ignore[attr-defined] + return await cast(Emeter, self.modules["emeter"]).get_monthstat( + year=year, kwh=kwh + ) @requires_update async def erase_emeter_stats(self) -> dict: """Erase energy meter statistics.""" self._verify_emeter() - return await self.modules["emeter"].erase_stats() # type: ignore[attr-defined] + return await cast(Emeter, self.modules["emeter"]).erase_stats() @requires_update async def current_consumption(self) -> float: