diff --git a/kasa/exceptions.py b/kasa/exceptions.py index 75f09169f..af9aaaa59 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -45,6 +45,9 @@ class ConnectionException(SmartDeviceException): class SmartErrorCode(IntEnum): """Enum for SMART Error Codes.""" + def __str__(self): + return f"{self.name}({self.value})" + SUCCESS = 0 # Transport Errors diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index dde8634f3..d22594347 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -9,7 +9,7 @@ from ..device_type import DeviceType from ..deviceconfig import DeviceConfig from ..emeterstatus import EmeterStatus -from ..exceptions import AuthenticationException, SmartDeviceException +from ..exceptions import AuthenticationException, SmartDeviceException, SmartErrorCode from ..feature import Feature, FeatureType from ..smartprotocol import SmartProtocol @@ -61,6 +61,24 @@ def children(self) -> Sequence["SmartDevice"]: """Return list of children.""" return list(self._children.values()) + def _try_get_response(self, responses: dict, request: str, default=None) -> dict: + response = responses.get(request) + if isinstance(response, SmartErrorCode): + _LOGGER.debug( + "Error %s getting request %s for device %s", + response, + request, + self.host, + ) + response = None + if response is not None: + return response + if default is not None: + return default + raise SmartDeviceException( + f"{request} not found in {responses} for device {self.host}" + ) + async def update(self, update_children: bool = True): """Update the device.""" if self.credentials is None and self.credentials_hash is None: @@ -87,7 +105,7 @@ async def update(self, update_children: bool = True): "get_current_power": None, } - if self._components["device"] >= 2: + if self._components.get("device", 0) >= 2: extra_reqs = { **extra_reqs, "get_device_usage": None, @@ -101,13 +119,13 @@ async def update(self, update_children: bool = True): resp = await self.protocol.query(req) - self._info = resp["get_device_info"] - self._time = resp["get_device_time"] + self._info = self._try_get_response(resp, "get_device_info") + self._time = self._try_get_response(resp, "get_device_time", {}) # Device usage is not available on older firmware versions - self._usage = resp.get("get_device_usage", {}) + self._usage = self._try_get_response(resp, "get_device_usage", {}) # Emeter is not always available, but we set them still for now. - self._energy = resp.get("get_energy_usage", {}) - self._emeter = resp.get("get_current_power", {}) + self._energy = self._try_get_response(resp, "get_energy_usage", {}) + self._emeter = self._try_get_response(resp, "get_current_power", {}) self._last_update = { "components": self._components_raw, @@ -116,7 +134,7 @@ async def update(self, update_children: bool = True): "time": self._time, "energy": self._energy, "emeter": self._emeter, - "child_info": resp.get("get_child_device_list", {}), + "child_info": self._try_get_response(resp, "get_child_device_list", {}), } if child_info := self._last_update.get("child_info"): diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index f61bac206..54e2fe1c3 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -129,19 +129,21 @@ async def _execute_multiple_query(self, request: Dict, retry_count: int) -> Dict pf(smart_request), ) response_step = await self._transport.send(smart_request) + batch_name = f"multi-request-batch-{i+1}" if debug_enabled: _LOGGER.debug( - "%s multi-request-batch-%s << %s", + "%s %s << %s", self._host, - i + 1, + batch_name, pf(response_step), ) - self._handle_response_error_code(response_step) + self._handle_response_error_code(response_step, batch_name) responses = response_step["result"]["responses"] for response in responses: - self._handle_response_error_code(response) + method = response["method"] + self._handle_response_error_code(response, method, raise_on_error=False) result = response.get("result", None) - multi_result[response["method"]] = result + multi_result[method] = result return multi_result async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: @@ -173,22 +175,24 @@ async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> D pf(response_data), ) - self._handle_response_error_code(response_data) + self._handle_response_error_code(response_data, smart_method) # Single set_ requests do not return a result result = response_data.get("result") return {smart_method: result} - def _handle_response_error_code(self, resp_dict: dict): + def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True): error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] if error_code == SmartErrorCode.SUCCESS: return + if not raise_on_error: + resp_dict["result"] = error_code + return msg = ( f"Error querying device: {self._host}: " + f"{error_code.name}({error_code.value})" + + f" for method: {method}" ) - if method := resp_dict.get("method"): - msg += f" for method: {method}" if error_code in SMART_TIMEOUT_ERRORS: raise TimeoutException(msg, error_code=error_code) if error_code in SMART_RETRYABLE_ERRORS: @@ -338,7 +342,7 @@ async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: result = response.get("control_child") # Unwrap responseData for control_child if result and (response_data := result.get("responseData")): - self._handle_response_error_code(response_data) + self._handle_response_error_code(response_data, "control_child") result = response_data.get("result") # TODO: handle multipleRequest unwrapping diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index efe6995ba..67f8fa84f 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -1,5 +1,6 @@ import importlib import inspect +import logging import pkgutil import re import sys @@ -21,10 +22,18 @@ import kasa from kasa import Credentials, Device, DeviceConfig, SmartDeviceException +from kasa.exceptions import SmartErrorCode from kasa.iot import IotDevice from kasa.smart import SmartChildDevice, SmartDevice -from .conftest import device_iot, handle_turn_on, has_emeter_iot, no_emeter_iot, turn_on +from .conftest import ( + device_iot, + device_smart, + handle_turn_on, + has_emeter_iot, + no_emeter_iot, + turn_on, +) from .fakeprotocol_iot import FakeIotProtocol @@ -300,6 +309,33 @@ async def test_modules_not_supported(dev: IotDevice): assert module.is_supported is not None +@device_smart +async def test_update_sub_errors(dev: SmartDevice, caplog): + mock_response: dict = { + "get_device_info": {}, + "get_device_usage": SmartErrorCode.PARAMS_ERROR, + "get_device_time": {}, + } + caplog.set_level(logging.DEBUG) + with patch.object(dev.protocol, "query", return_value=mock_response): + await dev.update() + msg = "Error PARAMS_ERROR(-1008) getting request get_device_usage for device 127.0.0.123" + assert msg in caplog.text + + +@device_smart +async def test_update_no_device_info(dev: SmartDevice): + mock_response: dict = { + "get_device_usage": {}, + "get_device_time": {}, + } + msg = f"get_device_info not found in {mock_response} for device 127.0.0.123" + with patch.object(dev.protocol, "query", return_value=mock_response), pytest.raises( + SmartDeviceException, match=msg + ): + await dev.update() + + @pytest.mark.parametrize( "device_class, use_class", kasa.deprecated_smart_devices.items() ) diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index 86f554b27..7d677a831 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -60,13 +60,10 @@ async def test_smart_device_errors_in_multiple_request( send_mock = mocker.patch.object( dummy_protocol._transport, "send", return_value=mock_response ) - with pytest.raises(SmartDeviceException): - await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2) - if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): - expected_calls = 3 - else: - expected_calls = 1 - assert send_mock.call_count == expected_calls + + resp_dict = await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2) + assert resp_dict["foobar2"] == error_code + assert send_mock.call_count == 1 @pytest.mark.parametrize("request_size", [1, 3, 5, 10])