From ec1d8276204b8cacda97cf5269bc1583cabda101 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 3 Jul 2024 06:12:29 +0100 Subject: [PATCH 1/9] Add rule params to get_preset_rules --- kasa/smart/modules/lightpreset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kasa/smart/modules/lightpreset.py b/kasa/smart/modules/lightpreset.py index 8e5cae209..7635a5f86 100644 --- a/kasa/smart/modules/lightpreset.py +++ b/kasa/smart/modules/lightpreset.py @@ -140,7 +140,7 @@ def query(self) -> dict: """Query to execute during the update cycle.""" if self._state_in_sysinfo: # Child lights can have states in the child info return {} - return {self.QUERY_GETTER_NAME: None} + return {self.QUERY_GETTER_NAME: {"start_index": 0}} async def _check_supported(self): """Additional check to see if the module is supported by the device. From 3cfb6893c76b43b2d96f876db2d8f209994c9f07 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 3 Jul 2024 07:28:32 +0100 Subject: [PATCH 2/9] Handle module errors more robustly --- kasa/smart/modules/cloud.py | 7 +++++++ kasa/smart/smartdevice.py | 30 ++++++++++++++++++++++++++---- kasa/smart/smartmodule.py | 15 ++++++++++++++- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/kasa/smart/modules/cloud.py b/kasa/smart/modules/cloud.py index 1b64f090a..25238124a 100644 --- a/kasa/smart/modules/cloud.py +++ b/kasa/smart/modules/cloud.py @@ -18,6 +18,13 @@ class Cloud(SmartModule): QUERY_GETTER_NAME = "get_connect_cloud_state" REQUIRED_COMPONENT = "cloud_connect" + def _post_update_hook(self): + """Perform actions after a device update. + + Overrides the default behaviour to disable a module if the query returns + an error because the logic here is to treat that as not connected. + """ + def __init__(self, device: SmartDevice, module: str): super().__init__(device, module) diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 408ba0278..676b7c13f 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -177,11 +177,20 @@ async def update(self, update_children: bool = False): self._children[info["device_id"]]._update_internal_state(info) # Call handle update for modules that want to update internal data - for module in self._modules.values(): - module._post_update_hook() + errors = [] + for module_name, module in self._modules.items(): + if not self._handle_module_post_update_hook(module): + errors.append(module_name) + for error in errors: + self._modules.pop(error) + for child in self._children.values(): - for child_module in child._modules.values(): - child_module._post_update_hook() + errors = [] + for child_module_name, child_module in child._modules.items(): + if not self._handle_module_post_update_hook(child_module): + errors.append(child_module_name) + for error in errors: + child._modules.pop(error) # We can first initialize the features after the first update. # We make here an assumption that every device has at least a single feature. @@ -190,6 +199,19 @@ async def update(self, update_children: bool = False): _LOGGER.debug("Got an update: %s", self._last_update) + def _handle_module_post_update_hook(self, module: SmartModule) -> bool: + try: + module._post_update_hook() + return True + except DeviceError as ex: + _LOGGER.error( + "Error processing %s for device %s, module will be unavailable: %s", + module.name, + self.host, + ex, + ) + return False + async def _initialize_modules(self): """Initialize modules based on component negotiation response.""" from .smartmodule import SmartModule diff --git a/kasa/smart/smartmodule.py b/kasa/smart/smartmodule.py index e78f43933..e5b05b989 100644 --- a/kasa/smart/smartmodule.py +++ b/kasa/smart/smartmodule.py @@ -5,7 +5,7 @@ import logging from typing import TYPE_CHECKING -from ..exceptions import KasaException +from ..exceptions import DeviceError, KasaException, SmartErrorCode from ..module import Module if TYPE_CHECKING: @@ -41,6 +41,14 @@ def name(self) -> str: """Name of the module.""" return getattr(self, "NAME", self.__class__.__name__) + def _post_update_hook(self): # noqa: B027 + """Perform actions after a device update. + + Any modules overriding this should ensure that self.data is + accessed + """ + assert self.data # noqa: S101 + def query(self) -> dict: """Query to execute during the update cycle. @@ -87,6 +95,11 @@ def data(self): filtered_data = {k: v for k, v in dev._last_update.items() if k in q_keys} + for data_item in filtered_data: + if isinstance(filtered_data[data_item], SmartErrorCode): + raise DeviceError( + f"{data_item} for {self.name}", error_code=filtered_data[data_item] + ) if len(filtered_data) == 1: return next(iter(filtered_data.values())) From 291e102b0e81ae3e4ed1b38cec70f3cfd88a79e5 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 3 Jul 2024 08:30:11 +0100 Subject: [PATCH 3/9] Add check for data errors and ensure sysinfo only modules return empty queries --- kasa/smart/modules/batterysensor.py | 4 ++++ kasa/smart/modules/cloud.py | 3 +-- kasa/smart/modules/firmware.py | 5 ++--- kasa/smart/modules/frostprotection.py | 4 ++++ kasa/smart/modules/humiditysensor.py | 4 ++++ kasa/smart/modules/reportmode.py | 4 ++++ kasa/smart/modules/temperaturesensor.py | 4 ++++ kasa/smart/smartmodule.py | 9 ++++++++- 8 files changed, 31 insertions(+), 6 deletions(-) diff --git a/kasa/smart/modules/batterysensor.py b/kasa/smart/modules/batterysensor.py index 415e47d1e..7ff7df2d8 100644 --- a/kasa/smart/modules/batterysensor.py +++ b/kasa/smart/modules/batterysensor.py @@ -43,6 +43,10 @@ def _initialize_features(self): ) ) + def query(self) -> dict: + """Query to execute during the update cycle.""" + return {} + @property def battery(self): """Return battery level.""" diff --git a/kasa/smart/modules/cloud.py b/kasa/smart/modules/cloud.py index 25238124a..8346af57a 100644 --- a/kasa/smart/modules/cloud.py +++ b/kasa/smart/modules/cloud.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING -from ...exceptions import SmartErrorCode from ...feature import Feature from ..smartmodule import SmartModule @@ -44,6 +43,6 @@ def __init__(self, device: SmartDevice, module: str): @property def is_connected(self): """Return True if device is connected to the cloud.""" - if isinstance(self.data, SmartErrorCode): + if self._has_data_error(): return False return self.data["status"] == 0 diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 3dcaddd66..74052268f 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -13,7 +13,6 @@ from async_timeout import timeout as asyncio_timeout from pydantic.v1 import BaseModel, Field, validator -from ...exceptions import SmartErrorCode from ...feature import Feature from ..smartmodule import SmartModule @@ -136,11 +135,11 @@ def latest_firmware(self) -> str: @property def firmware_update_info(self): """Return latest firmware information.""" - fw = self.data.get("get_latest_fw") or self.data - if not self._device.is_cloud_connected or isinstance(fw, SmartErrorCode): + if not self._device.is_cloud_connected or self._has_data_error(): # Error in response, probably disconnected from the cloud. return UpdateInfo(type=0, need_to_upgrade=False) + fw = self.data.get("get_latest_fw") or self.data return UpdateInfo.parse_obj(fw) @property diff --git a/kasa/smart/modules/frostprotection.py b/kasa/smart/modules/frostprotection.py index f1811012f..440e1ed1b 100644 --- a/kasa/smart/modules/frostprotection.py +++ b/kasa/smart/modules/frostprotection.py @@ -14,6 +14,10 @@ class FrostProtection(SmartModule): REQUIRED_COMPONENT = "frost_protection" QUERY_GETTER_NAME = "get_frost_protection" + def query(self) -> dict: + """Query to execute during the update cycle.""" + return {} + @property def enabled(self) -> bool: """Return True if frost protection is on.""" diff --git a/kasa/smart/modules/humiditysensor.py b/kasa/smart/modules/humiditysensor.py index f0dcc18a4..b137736ff 100644 --- a/kasa/smart/modules/humiditysensor.py +++ b/kasa/smart/modules/humiditysensor.py @@ -45,6 +45,10 @@ def __init__(self, device: SmartDevice, module: str): ) ) + def query(self) -> dict: + """Query to execute during the update cycle.""" + return {} + @property def humidity(self): """Return current humidity in percentage.""" diff --git a/kasa/smart/modules/reportmode.py b/kasa/smart/modules/reportmode.py index 79c8ae621..8d210a5b3 100644 --- a/kasa/smart/modules/reportmode.py +++ b/kasa/smart/modules/reportmode.py @@ -32,6 +32,10 @@ def __init__(self, device: SmartDevice, module: str): ) ) + def query(self) -> dict: + """Query to execute during the update cycle.""" + return {} + @property def report_interval(self): """Reporting interval of a sensor device.""" diff --git a/kasa/smart/modules/temperaturesensor.py b/kasa/smart/modules/temperaturesensor.py index d98501508..a61859cdc 100644 --- a/kasa/smart/modules/temperaturesensor.py +++ b/kasa/smart/modules/temperaturesensor.py @@ -58,6 +58,10 @@ def __init__(self, device: SmartDevice, module: str): ) ) + def query(self) -> dict: + """Query to execute during the update cycle.""" + return {} + @property def temperature(self): """Return current humidity in percentage.""" diff --git a/kasa/smart/smartmodule.py b/kasa/smart/smartmodule.py index e5b05b989..fb946a8b3 100644 --- a/kasa/smart/smartmodule.py +++ b/kasa/smart/smartmodule.py @@ -45,7 +45,7 @@ def _post_update_hook(self): # noqa: B027 """Perform actions after a device update. Any modules overriding this should ensure that self.data is - accessed + accessed unless the module should remain active despite errors. """ assert self.data # noqa: S101 @@ -123,3 +123,10 @@ async def _check_supported(self) -> bool: color_temp_range but only supports one value. """ return True + + def _has_data_error(self) -> bool: + try: + assert self.data # noqa: S101 + return False + except DeviceError: + return True From ce4aaf933e8f04e6b08f2e4f38abcb8b3fe85c35 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 3 Jul 2024 09:21:40 +0100 Subject: [PATCH 4/9] Fix test with child device query --- kasa/tests/test_smartdevice.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 48475a900..d025e0a7b 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -181,6 +181,9 @@ async def test_smartdevice_cloud_connection(dev: SmartDevice, mocker: MockerFixt assert dev.is_cloud_connected == is_connected last_update = dev._last_update + for child in dev.children: + mocker.patch.object(child.protocol, "query", return_value=child._last_update) + last_update["get_connect_cloud_state"] = {"status": 0} with patch.object(dev.protocol, "query", return_value=last_update): await dev.update() @@ -207,21 +210,18 @@ async def test_smartdevice_cloud_connection(dev: SmartDevice, mocker: MockerFixt "get_connect_cloud_state": last_update["get_connect_cloud_state"], "get_device_info": last_update["get_device_info"], } - # Child component list is not stored on the device - if "get_child_device_list" in last_update: - child_component_list = await dev.protocol.query( - "get_child_device_component_list" - ) - last_update["get_child_device_component_list"] = child_component_list[ - "get_child_device_component_list" - ] + new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) first_call = True - def side_effect_func(*_, **__): + async def side_effect_func(*args, **kwargs): nonlocal first_call - resp = initial_response if first_call else last_update + resp = ( + initial_response + if first_call + else await new_dev.protocol._query(*args, **kwargs) + ) first_call = False return resp From 39e3d05718e721f43bc6238f7fff28cb4352908e Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 3 Jul 2024 11:28:30 +0100 Subject: [PATCH 5/9] Add tests for disabling modules --- kasa/smart/modules/devicemodule.py | 7 +++ kasa/smart/modules/firmware.py | 7 +++ kasa/smartprotocol.py | 6 ++- kasa/tests/test_smartdevice.py | 74 +++++++++++++++++++++++++++++- 4 files changed, 92 insertions(+), 2 deletions(-) diff --git a/kasa/smart/modules/devicemodule.py b/kasa/smart/modules/devicemodule.py index 6a846d542..3203e82fa 100644 --- a/kasa/smart/modules/devicemodule.py +++ b/kasa/smart/modules/devicemodule.py @@ -10,6 +10,13 @@ class DeviceModule(SmartModule): REQUIRED_COMPONENT = "device" + def _post_update_hook(self): + """Perform actions after a device update. + + Overrides the default behaviour to disable a module if the query returns + an error because this module is critical. + """ + def query(self) -> dict: """Query to execute during the update cycle.""" query = { diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 74052268f..10a6b8245 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -122,6 +122,13 @@ def query(self) -> dict: req["get_auto_update_info"] = None return req + def _post_update_hook(self): + """Perform actions after a device update. + + Overrides the default behaviour to disable a module if the query returns + an error because some of the module still functions. + """ + @property def current_firmware(self) -> str: """Return the current firmware version.""" diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index e6741bc47..5356e9188 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -416,6 +416,10 @@ def _get_method_and_params_for_request(self, request): return smart_method, smart_params async def query(self, request: str | dict, retry_count: int = 3) -> dict: + """Wrap request inside control_child envelope.""" + return await self._query(request, retry_count) + + async def _query(self, request: str | dict, retry_count: int = 3) -> dict: """Wrap request inside control_child envelope.""" method, params = self._get_method_and_params_for_request(request) request_data = { @@ -429,7 +433,7 @@ async def query(self, request: str | dict, retry_count: int = 3) -> dict: } } - response = await self._protocol.query(wrapped_payload, retry_count) + response = await self._protocol._query(wrapped_payload, retry_count) result = response.get("control_child") # Unwrap responseData for control_child if result and (response_data := result.get("responseData")): diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index d025e0a7b..44fabc715 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import Any +from typing import Any, cast from unittest.mock import patch import pytest @@ -132,6 +132,78 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): spies[device].assert_not_called() +@device_smart +async def test_update_module_errors(dev: SmartDevice, mocker: MockerFixture): + """Test that modules that error are disabled / removed.""" + # We need to have some modules initialized by now + assert dev._modules + + critical_modules = {Module.DeviceModule, Module.ChildDevice} + not_disabling_modules = {Module.Firmware, Module.Cloud} + + new_dev = SmartDevice("127.0.0.1", protocol=dev.protocol) + + module_queries = { + modname: q + for modname, module in dev._modules.items() + if (q := module.query()) and modname not in critical_modules + } + child_module_queries = { + modname: q + for child in dev.children + for modname, module in child._modules.items() + if (q := module.query()) and modname not in critical_modules + } + all_queries_names = { + key for mod_query in module_queries.values() for key in mod_query + } + all_child_queries_names = { + key for mod_query in child_module_queries.values() for key in mod_query + } + + async def _query(request, *args, **kwargs): + responses = await dev.protocol._query(request, *args, **kwargs) + for k in responses: + if k in all_queries_names: + responses[k] = SmartErrorCode.PARAMS_ERROR + return responses + + async def _child_query(self, request, *args, **kwargs): + responses = await child_protocols[self._device_id]._query( + request, *args, **kwargs + ) + for k in responses: + if k in all_child_queries_names: + responses[k] = SmartErrorCode.PARAMS_ERROR + return responses + + mocker.patch.object(new_dev.protocol, "query", side_effect=_query) + + from kasa.smartprotocol import _ChildProtocolWrapper + + child_protocols = { + cast(_ChildProtocolWrapper, child.protocol)._device_id: child.protocol + for child in dev.children + } + # children not created yet so cannot patch.object + mocker.patch("kasa.smartprotocol._ChildProtocolWrapper.query", new=_child_query) + + await new_dev.update() + for modname in module_queries: + no_disable = modname in not_disabling_modules + mod_present = modname in new_dev._modules + assert ( + mod_present is no_disable + ), f"{modname} present {mod_present} when no_disable {no_disable}" + + for modname in child_module_queries: + no_disable = modname in not_disabling_modules + mod_present = any(modname in child._modules for child in new_dev.children) + assert ( + mod_present is no_disable + ), f"{modname} present {mod_present} when no_disable {no_disable}" + + async def test_get_modules(): """Test getting modules for child and parent modules.""" dummy_device = await get_device_for_fixture_protocol( From 92d8031e589b34ad7845cd7babc09b147282ae04 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 3 Jul 2024 11:51:40 +0100 Subject: [PATCH 6/9] Fix test --- kasa/smartprotocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 5356e9188..3085714c4 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -433,7 +433,7 @@ async def _query(self, request: str | dict, retry_count: int = 3) -> dict: } } - response = await self._protocol._query(wrapped_payload, retry_count) + response = await self._protocol.query(wrapped_payload, retry_count) result = response.get("control_child") # Unwrap responseData for control_child if result and (response_data := result.get("responseData")): From 4ede2481719101a1a6182af1cc57e7211078730c Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 3 Jul 2024 12:21:07 +0100 Subject: [PATCH 7/9] Fix coverage --- kasa/smart/modules/autooff.py | 6 ------ kasa/tests/test_smartprotocol.py | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/kasa/smart/modules/autooff.py b/kasa/smart/modules/autooff.py index 0004aec43..5e4b100f8 100644 --- a/kasa/smart/modules/autooff.py +++ b/kasa/smart/modules/autooff.py @@ -19,12 +19,6 @@ class AutoOff(SmartModule): def _initialize_features(self): """Initialize features after the initial update.""" - if not isinstance(self.data, dict): - _LOGGER.warning( - "No data available for module, skipping %s: %s", self, self.data - ) - return - self._add_feature( Feature( self._device, diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index d362fd00a..71125ca83 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -1,6 +1,7 @@ import logging import pytest +import pytest_mock from ..exceptions import ( SMART_RETRYABLE_ERRORS, @@ -19,6 +20,21 @@ ERRORS = [e for e in SmartErrorCode if e != 0] +async def test_smart_queries(dummy_protocol, mocker: pytest_mock.MockerFixture): + mock_response = {"result": {"great": "success"}, "error_code": 0} + + mocker.patch.object(dummy_protocol._transport, "send", return_value=mock_response) + # test sending a method name as a string + resp = await dummy_protocol.query("foobar") + assert "foobar" in resp + assert resp["foobar"] == mock_response["result"] + + # test sending a method name as a dict + resp = await dummy_protocol.query(DUMMY_QUERY) + assert "foobar" in resp + assert resp["foobar"] == mock_response["result"] + + @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) async def test_smart_device_errors(dummy_protocol, mocker, error_code): mock_response = {"result": {"great": "success"}, "error_code": error_code.value} From 79e1f66a92e17dc238b4c7406e90c0888e00b60c Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 3 Jul 2024 12:56:16 +0100 Subject: [PATCH 8/9] Try fix on off gradually --- devtools/helpers/smartrequests.py | 11 ++++++++++- kasa/smart/modules/lighttransition.py | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/devtools/helpers/smartrequests.py b/devtools/helpers/smartrequests.py index 881488b5e..4db1f7a1c 100644 --- a/devtools/helpers/smartrequests.py +++ b/devtools/helpers/smartrequests.py @@ -284,6 +284,15 @@ def get_preset_rules(params: GetRulesParams | None = None) -> SmartRequest: """Get preset rules.""" return SmartRequest("get_preset_rules", params or SmartRequest.GetRulesParams()) + @staticmethod + def get_on_off_gradually_info( + params: SmartRequestParams | None = None, + ) -> SmartRequest: + """Get preset rules.""" + return SmartRequest( + "get_on_off_gradually_info", params or SmartRequest.SmartRequestParams() + ) + @staticmethod def get_auto_light_info() -> SmartRequest: """Get auto light info.""" @@ -382,7 +391,7 @@ def get_component_requests(component_id, ver_code): "auto_light": [SmartRequest.get_auto_light_info()], "light_effect": [SmartRequest.get_dynamic_light_effect_rules()], "bulb_quick_control": [], - "on_off_gradually": [SmartRequest.get_raw_request("get_on_off_gradually_info")], + "on_off_gradually": [SmartRequest.get_on_off_gradually_info()], "light_strip": [], "light_strip_lighting_effect": [ SmartRequest.get_raw_request("get_lighting_effect") diff --git a/kasa/smart/modules/lighttransition.py b/kasa/smart/modules/lighttransition.py index 29a4bb055..ca0eca867 100644 --- a/kasa/smart/modules/lighttransition.py +++ b/kasa/smart/modules/lighttransition.py @@ -230,7 +230,7 @@ def query(self) -> dict: if self._state_in_sysinfo: return {} else: - return {self.QUERY_GETTER_NAME: None} + return {self.QUERY_GETTER_NAME: {}} async def _check_supported(self): """Additional check to see if the module is supported by the device.""" From ba951567cd3c9430a1edbb6dae24dee6d6e5ff2f Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Thu, 4 Jul 2024 07:54:29 +0100 Subject: [PATCH 9/9] Catch broad exceptions when trying post update hook --- kasa/smart/smartdevice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 676b7c13f..5bf2400b7 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -203,7 +203,7 @@ def _handle_module_post_update_hook(self, module: SmartModule) -> bool: try: module._post_update_hook() return True - except DeviceError as ex: + except Exception as ex: _LOGGER.error( "Error processing %s for device %s, module will be unavailable: %s", module.name,