From 45318a5cfb48c8f98c5af291b1a1b53db785da2f Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Thu, 15 Feb 2024 12:10:47 +0000 Subject: [PATCH 1/3] Fix for missing get_device_usage --- kasa/exceptions.py | 3 +++ kasa/smart/smartdevice.py | 34 +++++++++++++++++++++------- kasa/smartprotocol.py | 7 ++++-- kasa/tests/test_smartdevice.py | 38 +++++++++++++++++++++++++++++++- kasa/tests/test_smartprotocol.py | 13 +++++------ 5 files changed, 76 insertions(+), 19 deletions(-) diff --git a/kasa/exceptions.py b/kasa/exceptions.py index 75f09169f..af9aaaa59 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -45,6 +45,9 @@ class ConnectionException(SmartDeviceException): class SmartErrorCode(IntEnum): """Enum for SMART Error Codes.""" + def __str__(self): + return f"{self.name}({self.value})" + SUCCESS = 0 # Transport Errors diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 0929c418d..5f66a3da9 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -9,7 +9,7 @@ from ..device_type import DeviceType from ..deviceconfig import DeviceConfig from ..emeterstatus import EmeterStatus -from ..exceptions import AuthenticationException, SmartDeviceException +from ..exceptions import AuthenticationException, SmartDeviceException, SmartErrorCode from ..smartprotocol import SmartProtocol _LOGGER = logging.getLogger(__name__) @@ -60,6 +60,24 @@ def children(self) -> Sequence["SmartDevice"]: """Return list of children.""" return list(self._children.values()) + def _try_get_response(self, responses: dict, request: str, default=None) -> dict: + response = responses.get(request) + if isinstance(response, SmartErrorCode): + _LOGGER.debug( + "Error %s getting request %s for device %s", + response, + request, + self.host, + ) + response = None + if response is not None: + return response + if default is not None: + return default + raise SmartDeviceException( + f"{request} not found in {responses} for device {self.host}" + ) + async def update(self, update_children: bool = True): """Update the device.""" if self.credentials is None and self.credentials_hash is None: @@ -86,7 +104,7 @@ async def update(self, update_children: bool = True): "get_current_power": None, } - if self._components["device"] >= 2: + if self._components.get("device", 0) >= 2: extra_reqs = { **extra_reqs, "get_device_usage": None, @@ -100,13 +118,13 @@ async def update(self, update_children: bool = True): resp = await self.protocol.query(req) - self._info = resp["get_device_info"] - self._time = resp["get_device_time"] + self._info = self._try_get_response(resp, "get_device_info") + self._time = self._try_get_response(resp, "get_device_time", {}) # Device usage is not available on older firmware versions - self._usage = resp.get("get_device_usage", {}) + self._usage = self._try_get_response(resp, "get_device_usage", {}) # Emeter is not always available, but we set them still for now. - self._energy = resp.get("get_energy_usage", {}) - self._emeter = resp.get("get_current_power", {}) + self._energy = self._try_get_response(resp, "get_energy_usage", {}) + self._emeter = self._try_get_response(resp, "get_current_power", {}) self._last_update = { "components": self._components_raw, @@ -115,7 +133,7 @@ async def update(self, update_children: bool = True): "time": self._time, "energy": self._energy, "emeter": self._emeter, - "child_info": resp.get("get_child_device_list", {}), + "child_info": self._try_get_response(resp, "get_child_device_list", {}), } if child_info := self._last_update.get("child_info"): diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index f61bac206..7da11c1b5 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -139,7 +139,7 @@ async def _execute_multiple_query(self, request: Dict, retry_count: int) -> Dict self._handle_response_error_code(response_step) responses = response_step["result"]["responses"] for response in responses: - self._handle_response_error_code(response) + self._handle_response_error_code(response, raise_on_error=False) result = response.get("result", None) multi_result[response["method"]] = result return multi_result @@ -179,10 +179,13 @@ async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> D result = response_data.get("result") return {smart_method: result} - def _handle_response_error_code(self, resp_dict: dict): + def _handle_response_error_code(self, resp_dict: dict, raise_on_error=True): error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] if error_code == SmartErrorCode.SUCCESS: return + if not raise_on_error: + resp_dict["result"] = error_code + return msg = ( f"Error querying device: {self._host}: " + f"{error_code.name}({error_code.value})" diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index ba5ebc4fe..aee4b7ebe 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -1,5 +1,6 @@ import importlib import inspect +import logging import pkgutil import re import sys @@ -21,10 +22,18 @@ import kasa from kasa import Credentials, Device, DeviceConfig, SmartDeviceException +from kasa.exceptions import SmartErrorCode from kasa.iot import IotDevice from kasa.smart import SmartChildDevice, SmartDevice -from .conftest import device_iot, handle_turn_on, has_emeter_iot, no_emeter_iot, turn_on +from .conftest import ( + device_iot, + device_smart, + handle_turn_on, + has_emeter_iot, + no_emeter_iot, + turn_on, +) from .fakeprotocol_iot import FakeIotProtocol @@ -300,6 +309,33 @@ async def test_modules_not_supported(dev: IotDevice): assert module.is_supported is not None +@device_smart +async def test_update_sub_errors(dev: SmartDevice, caplog): + mock_response: dict = { + "get_device_info": {}, + "get_device_usage": SmartErrorCode.PARAMS_ERROR, + "get_device_time": {}, + } + caplog.set_level(logging.DEBUG) + with patch.object(dev.protocol, "query", return_value=mock_response): + await dev.update() + msg = "Error PARAMS_ERROR(-1008) getting request get_device_usage for device 127.0.0.123" + assert msg in caplog.text + + +@device_smart +async def test_update_no_device_info(dev: SmartDevice): + mock_response: dict = { + "get_device_usage": {}, + "get_device_time": {}, + } + msg = f"get_device_info not found in {mock_response} for device 127.0.0.123" + with patch.object(dev.protocol, "query", return_value=mock_response), pytest.raises( + SmartDeviceException, match=msg + ): + await dev.update() + + @pytest.mark.parametrize( "device_class, use_class", kasa.deprecated_smart_devices.items() ) diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index 86f554b27..fc780be23 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -40,7 +40,7 @@ async def test_smart_device_errors(dummy_protocol, mocker, error_code): @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) async def test_smart_device_errors_in_multiple_request( - dummy_protocol, mocker, error_code + dummy_protocol, mocker, error_code, caplog ): mock_response = { "result": { @@ -60,13 +60,10 @@ async def test_smart_device_errors_in_multiple_request( send_mock = mocker.patch.object( dummy_protocol._transport, "send", return_value=mock_response ) - with pytest.raises(SmartDeviceException): - await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2) - if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): - expected_calls = 3 - else: - expected_calls = 1 - assert send_mock.call_count == expected_calls + + resp_dict = await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2) + assert resp_dict["foobar2"] == error_code + assert send_mock.call_count == 1 @pytest.mark.parametrize("request_size", [1, 3, 5, 10]) From aa5d79b088f33d9e9df53f6cabf9bb09d21dc3e4 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Thu, 15 Feb 2024 12:35:43 +0000 Subject: [PATCH 2/3] Fix coverage and add methods to exceptions --- kasa/smartprotocol.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 7da11c1b5..54e2fe1c3 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -129,19 +129,21 @@ async def _execute_multiple_query(self, request: Dict, retry_count: int) -> Dict pf(smart_request), ) response_step = await self._transport.send(smart_request) + batch_name = f"multi-request-batch-{i+1}" if debug_enabled: _LOGGER.debug( - "%s multi-request-batch-%s << %s", + "%s %s << %s", self._host, - i + 1, + batch_name, pf(response_step), ) - self._handle_response_error_code(response_step) + self._handle_response_error_code(response_step, batch_name) responses = response_step["result"]["responses"] for response in responses: - self._handle_response_error_code(response, raise_on_error=False) + method = response["method"] + self._handle_response_error_code(response, method, raise_on_error=False) result = response.get("result", None) - multi_result[response["method"]] = result + multi_result[method] = result return multi_result async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: @@ -173,13 +175,13 @@ async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> D pf(response_data), ) - self._handle_response_error_code(response_data) + self._handle_response_error_code(response_data, smart_method) # Single set_ requests do not return a result result = response_data.get("result") return {smart_method: result} - def _handle_response_error_code(self, resp_dict: dict, raise_on_error=True): + def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True): error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] if error_code == SmartErrorCode.SUCCESS: return @@ -189,9 +191,8 @@ def _handle_response_error_code(self, resp_dict: dict, raise_on_error=True): msg = ( f"Error querying device: {self._host}: " + f"{error_code.name}({error_code.value})" + + f" for method: {method}" ) - if method := resp_dict.get("method"): - msg += f" for method: {method}" if error_code in SMART_TIMEOUT_ERRORS: raise TimeoutException(msg, error_code=error_code) if error_code in SMART_RETRYABLE_ERRORS: @@ -341,7 +342,7 @@ async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: result = response.get("control_child") # Unwrap responseData for control_child if result and (response_data := result.get("responseData")): - self._handle_response_error_code(response_data) + self._handle_response_error_code(response_data, "control_child") result = response_data.get("result") # TODO: handle multipleRequest unwrapping From beef5b5beb432c189505cf55ea42008e0a758abd Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Thu, 15 Feb 2024 12:48:22 +0000 Subject: [PATCH 3/3] Remove unused caplog fixture --- kasa/tests/test_smartprotocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index fc780be23..7d677a831 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -40,7 +40,7 @@ async def test_smart_device_errors(dummy_protocol, mocker, error_code): @pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) async def test_smart_device_errors_in_multiple_request( - dummy_protocol, mocker, error_code, caplog + dummy_protocol, mocker, error_code ): mock_response = { "result": {