Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Let caller handle SMART errors on multi-requests #754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions kasa/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class ConnectionException(SmartDeviceException):
class SmartErrorCode(IntEnum):
"""Enum for SMART Error Codes."""

def __str__(self):
return f"{self.name}({self.value})"

SUCCESS = 0

# Transport Errors
Expand Down
34 changes: 26 additions & 8 deletions kasa/smart/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..device_type import DeviceType
from ..deviceconfig import DeviceConfig
from ..emeterstatus import EmeterStatus
from ..exceptions import AuthenticationException, SmartDeviceException
from ..exceptions import AuthenticationException, SmartDeviceException, SmartErrorCode
from ..feature import Feature, FeatureType
from ..smartprotocol import SmartProtocol

Expand Down Expand Up @@ -61,6 +61,24 @@ def children(self) -> Sequence["SmartDevice"]:
"""Return list of children."""
return list(self._children.values())

def _try_get_response(self, responses: dict, request: str, default=None) -> dict:
response = responses.get(request)
if isinstance(response, SmartErrorCode):
_LOGGER.debug(
"Error %s getting request %s for device %s",
response,
request,
self.host,
)
response = None
if response is not None:
return response
if default is not None:
return default
raise SmartDeviceException(
f"{request} not found in {responses} for device {self.host}"
)

async def update(self, update_children: bool = True):
"""Update the device."""
if self.credentials is None and self.credentials_hash is None:
Expand All @@ -87,7 +105,7 @@ async def update(self, update_children: bool = True):
"get_current_power": None,
}

if self._components["device"] >= 2:
if self._components.get("device", 0) >= 2:
extra_reqs = {
**extra_reqs,
"get_device_usage": None,
Expand All @@ -101,13 +119,13 @@ async def update(self, update_children: bool = True):

resp = await self.protocol.query(req)

self._info = resp["get_device_info"]
self._time = resp["get_device_time"]
self._info = self._try_get_response(resp, "get_device_info")
self._time = self._try_get_response(resp, "get_device_time", {})
# Device usage is not available on older firmware versions
self._usage = resp.get("get_device_usage", {})
self._usage = self._try_get_response(resp, "get_device_usage", {})
# Emeter is not always available, but we set them still for now.
self._energy = resp.get("get_energy_usage", {})
self._emeter = resp.get("get_current_power", {})
self._energy = self._try_get_response(resp, "get_energy_usage", {})
self._emeter = self._try_get_response(resp, "get_current_power", {})

self._last_update = {
"components": self._components_raw,
Expand All @@ -116,7 +134,7 @@ async def update(self, update_children: bool = True):
"time": self._time,
"energy": self._energy,
"emeter": self._emeter,
"child_info": resp.get("get_child_device_list", {}),
"child_info": self._try_get_response(resp, "get_child_device_list", {}),
}

if child_info := self._last_update.get("child_info"):
Expand Down
24 changes: 14 additions & 10 deletions kasa/smartprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,21 @@ async def _execute_multiple_query(self, request: Dict, retry_count: int) -> Dict
pf(smart_request),
)
response_step = await self._transport.send(smart_request)
batch_name = f"multi-request-batch-{i+1}"
if debug_enabled:
_LOGGER.debug(
"%s multi-request-batch-%s << %s",
"%s %s << %s",
self._host,
i + 1,
batch_name,
pf(response_step),
)
self._handle_response_error_code(response_step)
self._handle_response_error_code(response_step, batch_name)
responses = response_step["result"]["responses"]
for response in responses:
self._handle_response_error_code(response)
method = response["method"]
self._handle_response_error_code(response, method, raise_on_error=False)
result = response.get("result", None)
multi_result[response["method"]] = result
multi_result[method] = result
return multi_result

async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
Expand Down Expand Up @@ -173,22 +175,24 @@ async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> D
pf(response_data),
)

self._handle_response_error_code(response_data)
self._handle_response_error_code(response_data, smart_method)

# Single set_ requests do not return a result
result = response_data.get("result")
return {smart_method: result}

def _handle_response_error_code(self, resp_dict: dict):
def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
return
if not raise_on_error:
resp_dict["result"] = error_code
return
msg = (
f"Error querying device: {self._host}: "
+ f"{error_code.name}({error_code.value})"
+ f" for method: {method}"
)
if method := resp_dict.get("method"):
msg += f" for method: {method}"
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg, error_code=error_code)
if error_code in SMART_RETRYABLE_ERRORS:
Expand Down Expand Up @@ -338,7 +342,7 @@ async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
result = response.get("control_child")
# Unwrap responseData for control_child
if result and (response_data := result.get("responseData")):
self._handle_response_error_code(response_data)
self._handle_response_error_code(response_data, "control_child")
result = response_data.get("result")

# TODO: handle multipleRequest unwrapping
Expand Down
38 changes: 37 additions & 1 deletion kasa/tests/test_smartdevice.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib
import inspect
import logging
import pkgutil
import re
import sys
Expand All @@ -21,10 +22,18 @@

import kasa
from kasa import Credentials, Device, DeviceConfig, SmartDeviceException
from kasa.exceptions import SmartErrorCode
from kasa.iot import IotDevice
from kasa.smart import SmartChildDevice, SmartDevice

from .conftest import device_iot, handle_turn_on, has_emeter_iot, no_emeter_iot, turn_on
from .conftest import (
device_iot,
device_smart,
handle_turn_on,
has_emeter_iot,
no_emeter_iot,
turn_on,
)
from .fakeprotocol_iot import FakeIotProtocol


Expand Down Expand Up @@ -300,6 +309,33 @@ async def test_modules_not_supported(dev: IotDevice):
assert module.is_supported is not None


@device_smart
async def test_update_sub_errors(dev: SmartDevice, caplog):
mock_response: dict = {
"get_device_info": {},
"get_device_usage": SmartErrorCode.PARAMS_ERROR,
"get_device_time": {},
}
caplog.set_level(logging.DEBUG)
with patch.object(dev.protocol, "query", return_value=mock_response):
await dev.update()
msg = "Error PARAMS_ERROR(-1008) getting request get_device_usage for device 127.0.0.123"
assert msg in caplog.text


@device_smart
async def test_update_no_device_info(dev: SmartDevice):
mock_response: dict = {
"get_device_usage": {},
"get_device_time": {},
}
msg = f"get_device_info not found in {mock_response} for device 127.0.0.123"
with patch.object(dev.protocol, "query", return_value=mock_response), pytest.raises(
SmartDeviceException, match=msg
):
await dev.update()


@pytest.mark.parametrize(
"device_class, use_class", kasa.deprecated_smart_devices.items()
)
Expand Down
11 changes: 4 additions & 7 deletions kasa/tests/test_smartprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,10 @@ async def test_smart_device_errors_in_multiple_request(
send_mock = mocker.patch.object(
dummy_protocol._transport, "send", return_value=mock_response
)
with pytest.raises(SmartDeviceException):
await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS):
expected_calls = 3
else:
expected_calls = 1
assert send_mock.call_count == expected_calls

resp_dict = await dummy_protocol.query(DUMMY_MULTIPLE_QUERY, retry_count=2)
assert resp_dict["foobar2"] == error_code
assert send_mock.call_count == 1


@pytest.mark.parametrize("request_size", [1, 3, 5, 10])
Expand Down