From e9dfdf6cf7edf31653fc7d0d8aa80acea6413824 Mon Sep 17 00:00:00 2001 From: Teemu Rytilahti Date: Fri, 23 Feb 2024 02:32:38 +0100 Subject: [PATCH 1/5] Improve smartdevice update module * Expose current and latest firmware as features * Implement update loop that blocks until the update is complete --- kasa/smart/modules/firmware.py | 47 ++++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 626add0f6..d223b0e70 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -1,6 +1,7 @@ """Implementation of firmware module.""" from __future__ import annotations +import asyncio from datetime import date from typing import TYPE_CHECKING, Any, Optional @@ -11,6 +12,11 @@ from ...feature import Feature from ..smartmodule import SmartModule +# When support for cpython older than 3.11 is dropped +# async_timeout can be replaced with asyncio.timeout +from async_timeout import timeout as asyncio_timeout + + if TYPE_CHECKING: from ..smartdevice import SmartDevice @@ -19,7 +25,7 @@ class UpdateInfo(BaseModel): """Update info status object.""" status: int = Field(alias="type") - fw_ver: Optional[str] = None # noqa: UP007 + version: Optional[str] = Field(alias="fw_ver", default=None) # noqa: UP007 release_date: Optional[date] = None # noqa: UP007 release_notes: Optional[str] = Field(alias="release_note", default=None) # noqa: UP007 fw_size: Optional[int] = None # noqa: UP007 @@ -71,6 +77,12 @@ def __init__(self, device: SmartDevice, module: str): category=Feature.Category.Info, ) ) + self._add_feature( + Feature(device, "Current firmware version", container=self, attribute_getter="current_firmware") + ) + self._add_feature( + Feature(device, "Available firmware version", container=self, attribute_getter="latest_firmware") + ) def query(self) -> dict: """Query to execute during the update cycle.""" @@ -80,7 +92,18 @@ def query(self) -> dict: return req @property - def latest_firmware(self): + def current_firmware(self) -> str: + """Return the current firmware version.""" + return self._device.hw_info["sw_ver"] + + + @property + def latest_firmware(self) -> str: + """Return the latest firmware version.""" + return self.firmware_update_info.version + + @property + def firmware_update_info(self): """Return latest firmware information.""" fw = self.data.get("get_latest_fw") or self.data if not self._device.is_cloud_connected or isinstance(fw, SmartErrorCode): @@ -94,7 +117,7 @@ def update_available(self) -> bool | None: """Return True if update is available.""" if not self._device.is_cloud_connected: return None - return self.latest_firmware.update_available + return self.firmware_update_info.update_available async def get_update_state(self): """Return update state.""" @@ -102,7 +125,21 @@ async def get_update_state(self): async def update(self): """Update the device firmware.""" - return await self.call("fw_download") + current_fw = self.current_firmware + _LOGGER.debug("Going to upgrade from %s to %s", current_fw, self.firmware_update_info.version) + resp = await self.call("fw_download") + _LOGGER.debug("Update request response: %s", resp) + # TODO: read timeout from get_auto_update_info or from get_fw_download_state? + async with asyncio_timeout(60*5): + while True: + await asyncio.sleep(0.5) + state = await self.get_update_state() + _LOGGER.debug("Update state: %s" % state) + # TODO: this could await a given callable for progress + + if self.firmware_update_info.version != current_fw: + _LOGGER.info("Updated to %s", self.firmware_update_info.version) + break @property def auto_update_enabled(self): @@ -115,4 +152,4 @@ def auto_update_enabled(self): async def set_auto_update_enabled(self, enabled: bool): """Change autoupdate setting.""" data = {**self.data["get_auto_update_info"], "enable": enabled} - await self.call("set_auto_update_info", data) # {"enable": enabled}) + await self.call("set_auto_update_info", data) From fd8689123f6a28bc5782e8db7f00d263f6b7de4a Mon Sep 17 00:00:00 2001 From: Teemu Rytilahti Date: Fri, 23 Feb 2024 16:12:28 +0100 Subject: [PATCH 2/5] Fix linting --- kasa/smart/modules/firmware.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index d223b0e70..bd5e5ef74 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -8,19 +8,22 @@ from pydantic.v1 import BaseModel, Field, validator -from ...exceptions import SmartErrorCode from ...feature import Feature -from ..smartmodule import SmartModule - # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout +from ...exceptions import SmartErrorCode +from ...feature import Feature, FeatureType +from ..smartmodule import SmartModule if TYPE_CHECKING: from ..smartdevice import SmartDevice +_LOGGER = logging.getLogger(__name__) + + class UpdateInfo(BaseModel): """Update info status object.""" @@ -78,10 +81,20 @@ def __init__(self, device: SmartDevice, module: str): ) ) self._add_feature( - Feature(device, "Current firmware version", container=self, attribute_getter="current_firmware") + Feature( + device, + "Current firmware version", + container=self, + attribute_getter="current_firmware", + ) ) self._add_feature( - Feature(device, "Available firmware version", container=self, attribute_getter="latest_firmware") + Feature( + device, + "Available firmware version", + container=self, + attribute_getter="latest_firmware", + ) ) def query(self) -> dict: @@ -96,7 +109,6 @@ def current_firmware(self) -> str: """Return the current firmware version.""" return self._device.hw_info["sw_ver"] - @property def latest_firmware(self) -> str: """Return the latest firmware version.""" @@ -126,11 +138,15 @@ async def get_update_state(self): async def update(self): """Update the device firmware.""" current_fw = self.current_firmware - _LOGGER.debug("Going to upgrade from %s to %s", current_fw, self.firmware_update_info.version) + _LOGGER.debug( + "Going to upgrade from %s to %s", + current_fw, + self.firmware_update_info.version, + ) resp = await self.call("fw_download") _LOGGER.debug("Update request response: %s", resp) # TODO: read timeout from get_auto_update_info or from get_fw_download_state? - async with asyncio_timeout(60*5): + async with asyncio_timeout(60 * 5): while True: await asyncio.sleep(0.5) state = await self.get_update_state() From 39e6aac63b85c14c9ae5ac034d444367e9a71546 Mon Sep 17 00:00:00 2001 From: Teemu Rytilahti Date: Tue, 7 May 2024 17:41:38 +0200 Subject: [PATCH 3/5] Fix after rebasing --- kasa/smart/modules/firmware.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index bd5e5ef74..16704ef0b 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -1,20 +1,19 @@ """Implementation of firmware module.""" from __future__ import annotations -import asyncio +import asyncio +import logging from datetime import date from typing import TYPE_CHECKING, Any, Optional -from pydantic.v1 import BaseModel, Field, validator - -from ...feature import Feature # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout +from pydantic.v1 import BaseModel, Field, validator from ...exceptions import SmartErrorCode -from ...feature import Feature, FeatureType +from ...feature import Feature from ..smartmodule import SmartModule if TYPE_CHECKING: @@ -83,17 +82,21 @@ def __init__(self, device: SmartDevice, module: str): self._add_feature( Feature( device, - "Current firmware version", + id="current_firmware_version", + name="Current firmware version", container=self, attribute_getter="current_firmware", + category=Feature.Category.Info, ) ) self._add_feature( Feature( device, - "Available firmware version", + id="available_firmware_version", + name="Available firmware version", container=self, attribute_getter="latest_firmware", + category=Feature.Category.Info, ) ) From da83cd5359733ba147b5abc0aae755aa96828f25 Mon Sep 17 00:00:00 2001 From: Teemu Rytilahti Date: Wed, 8 May 2024 18:37:46 +0200 Subject: [PATCH 4/5] Fix update logic & add tests --- kasa/smart/modules/firmware.py | 66 +++++++++++--- kasa/tests/fakeprotocol_smart.py | 2 +- kasa/tests/smart/modules/test_firmware.py | 102 ++++++++++++++++++++++ 3 files changed, 156 insertions(+), 14 deletions(-) create mode 100644 kasa/tests/smart/modules/test_firmware.py diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 16704ef0b..b1cd0b9dd 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -23,6 +23,19 @@ _LOGGER = logging.getLogger(__name__) +class DownloadState(BaseModel): + """Download state.""" + + # Example: + # {'status': 0, 'download_progress': 0, 'reboot_time': 5, + # 'upgrade_time': 5, 'auto_upgrade': False} + status: int + progress: int = Field(alias="download_progress") + reboot_time: int + upgrade_time: int + auto_upgrade: bool + + class UpdateInfo(BaseModel): """Update info status object.""" @@ -86,7 +99,7 @@ def __init__(self, device: SmartDevice, module: str): name="Current firmware version", container=self, attribute_getter="current_firmware", - category=Feature.Category.Info, + category=Feature.Category.Debug, ) ) self._add_feature( @@ -96,7 +109,7 @@ def __init__(self, device: SmartDevice, module: str): name="Available firmware version", container=self, attribute_getter="latest_firmware", - category=Feature.Category.Info, + category=Feature.Category.Debug, ) ) @@ -134,31 +147,58 @@ def update_available(self) -> bool | None: return None return self.firmware_update_info.update_available - async def get_update_state(self): + async def get_update_state(self) -> DownloadState: """Return update state.""" - return await self.call("get_fw_download_state") + resp = await self.call("get_fw_download_state") + state = resp["get_fw_download_state"] + return DownloadState(**state) - async def update(self): + async def update(self, progress_cb=None): """Update the device firmware.""" current_fw = self.current_firmware - _LOGGER.debug( + _LOGGER.info( "Going to upgrade from %s to %s", current_fw, self.firmware_update_info.version, ) - resp = await self.call("fw_download") - _LOGGER.debug("Update request response: %s", resp) + await self.call("fw_download") + # TODO: read timeout from get_auto_update_info or from get_fw_download_state? async with asyncio_timeout(60 * 5): while True: await asyncio.sleep(0.5) - state = await self.get_update_state() - _LOGGER.debug("Update state: %s" % state) - # TODO: this could await a given callable for progress + try: + state = await self.get_update_state() + except Exception as ex: + _LOGGER.warning( + "Got exception, maybe the device is rebooting? %s", ex + ) + continue - if self.firmware_update_info.version != current_fw: - _LOGGER.info("Updated to %s", self.firmware_update_info.version) + _LOGGER.debug("Update state: %s" % state) + if progress_cb is not None: + asyncio.ensure_future(progress_cb(state)) + + if state.status == 0: + _LOGGER.info( + "Update idle, hopefully updated to %s", + self.firmware_update_info.version, + ) + break + elif state.status == 2: + _LOGGER.info("Downloading firmware, progress: %s", state.progress) + elif state.status == 3: + upgrade_sleep = state.upgrade_time + _LOGGER.info( + "Flashing firmware, sleeping for %s before checking status", + upgrade_sleep, + ) + await asyncio.sleep(upgrade_sleep) + elif state.status < 0: + _LOGGER.error("Got error: %s", state.status) break + else: + _LOGGER.warning("Unhandled state code: %s", state) @property def auto_update_enabled(self): diff --git a/kasa/tests/fakeprotocol_smart.py b/kasa/tests/fakeprotocol_smart.py index ae1a7ad66..5ca4a8ae1 100644 --- a/kasa/tests/fakeprotocol_smart.py +++ b/kasa/tests/fakeprotocol_smart.py @@ -234,7 +234,7 @@ def _send_request(self, request_dict: dict): pytest.fixtures_missing_methods[self.fixture_name] = set() pytest.fixtures_missing_methods[self.fixture_name].add(method) return retval - elif method == "set_qs_info": + elif method in ["set_qs_info", "fw_download"]: return {"error_code": 0} elif method == "set_dynamic_light_effect_rule_enable": self._set_light_effect(info, params) diff --git a/kasa/tests/smart/modules/test_firmware.py b/kasa/tests/smart/modules/test_firmware.py new file mode 100644 index 000000000..e413acb6d --- /dev/null +++ b/kasa/tests/smart/modules/test_firmware.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import logging + +import pytest +from pytest_mock import MockerFixture + +from kasa.smart import SmartDevice +from kasa.smart.modules import Firmware +from kasa.smart.modules.firmware import DownloadState +from kasa.tests.device_fixtures import parametrize + +firmware = parametrize( + "has firmware", component_filter="firmware", protocol_filter={"SMART"} +) + + +@firmware +@pytest.mark.parametrize( + "feature, prop_name, type, required_version", + [ + ("auto_update_enabled", "auto_update_enabled", bool, 2), + ("update_available", "update_available", bool, 1), + ("update_available", "update_available", bool, 1), + ("current_firmware_version", "current_firmware", str, 1), + ("available_firmware_version", "latest_firmware", str, 1), + ], +) +async def test_firmware_features( + dev: SmartDevice, feature, prop_name, type, required_version, mocker: MockerFixture +): + """Test light effect.""" + fw = dev.get_module(Firmware) + assert fw + + if not dev.is_cloud_connected: + pytest.skip("Device is not cloud connected, skipping test") + + if fw.supported_version < required_version: + pytest.skip("Feature %s requires newer version" % feature) + + prop = getattr(fw, prop_name) + assert isinstance(prop, type) + + feat = fw._module_features[feature] + assert feat.value == prop + assert isinstance(feat.value, type) + + +@firmware +async def test_update_available_without_cloud(dev: SmartDevice): + """Test that update_available returns None when disconnected.""" + fw = dev.get_module(Firmware) + assert fw + + if dev.is_cloud_connected: + assert isinstance(fw.update_available, bool) + else: + assert fw.update_available is None + + +@firmware +async def test_update( + dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture +): + """Test updating firmware.""" + caplog.set_level(logging.INFO) + + fw = dev.get_module(Firmware) + assert fw + + upgrade_time = 5 + extras = {"reboot_time": 5, "upgrade_time": upgrade_time, "auto_upgrade": False} + update_states = [ + # Unknown 1 + DownloadState(status=1, download_progress=0, **extras), + # Downloading + DownloadState(status=2, download_progress=10, **extras), + DownloadState(status=2, download_progress=100, **extras), + # Flashing + DownloadState(status=3, download_progress=100, **extras), + DownloadState(status=3, download_progress=100, **extras), + # Done + DownloadState(status=0, download_progress=100, **extras), + ] + + sleep = mocker.patch("asyncio.sleep") + mocker.patch.object(fw, "get_update_state", side_effect=update_states) + + cb_mock = mocker.AsyncMock() + + await fw.update(progress_cb=cb_mock) + + assert "Unhandled state code" in caplog.text + assert "Downloading firmware, progress: 10" in caplog.text + assert "Flashing firmware, sleeping" in caplog.text + assert "Update idle" in caplog.text + + cb_mock.assert_called() + + # sleep based on the upgrade_time + sleep.assert_any_call(upgrade_time) From 31a43c27f90aef156e3267f76f637f39f2a9f124 Mon Sep 17 00:00:00 2001 From: Teemu Rytilahti Date: Wed, 8 May 2024 21:05:24 +0200 Subject: [PATCH 5/5] Adjust code based on reviews Also, check that that the callbacks are really awaited --- kasa/smart/modules/firmware.py | 8 +++++--- kasa/tests/smart/modules/test_firmware.py | 10 ++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index b1cd0b9dd..430515e4b 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -5,7 +5,7 @@ import asyncio import logging from datetime import date -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout @@ -153,7 +153,9 @@ async def get_update_state(self) -> DownloadState: state = resp["get_fw_download_state"] return DownloadState(**state) - async def update(self, progress_cb=None): + async def update( + self, progress_cb: Callable[[DownloadState], Coroutine] | None = None + ): """Update the device firmware.""" current_fw = self.current_firmware _LOGGER.info( @@ -177,7 +179,7 @@ async def update(self, progress_cb=None): _LOGGER.debug("Update state: %s" % state) if progress_cb is not None: - asyncio.ensure_future(progress_cb(state)) + asyncio.create_task(progress_cb(state)) if state.status == 0: _LOGGER.info( diff --git a/kasa/tests/smart/modules/test_firmware.py b/kasa/tests/smart/modules/test_firmware.py index e413acb6d..d0df87ca5 100644 --- a/kasa/tests/smart/modules/test_firmware.py +++ b/kasa/tests/smart/modules/test_firmware.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import logging import pytest @@ -60,7 +61,7 @@ async def test_update_available_without_cloud(dev: SmartDevice): @firmware -async def test_update( +async def test_firmware_update( dev: SmartDevice, mocker: MockerFixture, caplog: pytest.LogCaptureFixture ): """Test updating firmware.""" @@ -84,6 +85,7 @@ async def test_update( DownloadState(status=0, download_progress=100, **extras), ] + asyncio_sleep = asyncio.sleep sleep = mocker.patch("asyncio.sleep") mocker.patch.object(fw, "get_update_state", side_effect=update_states) @@ -91,12 +93,16 @@ async def test_update( await fw.update(progress_cb=cb_mock) + # This is necessary to allow the eventloop to process the created tasks + await asyncio_sleep(0) + assert "Unhandled state code" in caplog.text assert "Downloading firmware, progress: 10" in caplog.text assert "Flashing firmware, sleeping" in caplog.text assert "Update idle" in caplog.text - cb_mock.assert_called() + for state in update_states: + cb_mock.assert_any_await(state) # sleep based on the upgrade_time sleep.assert_any_call(upgrade_time)