diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 0000000..946fd2f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,81 @@ +--- +name: "🐛 Bug Report" +description: Report a bug +title: "(short issue description)" +labels: [bug, needs-triage] +assignees: [] +body: + - type: textarea + id: description + attributes: + label: Describe the bug + description: What is the problem? A clear and concise description of the bug. + validations: + required: true + - type: checkboxes + id: regression + attributes: + label: Regression Issue + description: What is a regression? If it worked in a previous version but doesn't in the latest version, it's considered a regression. In this case, please provide specific version number in the report. + options: + - label: Select this option if this issue appears to be a regression. + required: false + - type: textarea + id: expected + attributes: + label: Expected Behavior + description: | + What did you expect to happen? + validations: + required: true + - type: textarea + id: current + attributes: + label: Current Behavior + description: | + What actually happened? + + Please include full errors, uncaught exceptions, stack traces, and relevant logs. + If service responses are relevant, please include wire logs. + validations: + required: true + - type: textarea + id: reproduction + attributes: + label: Reproduction Steps + description: | + Provide a self-contained, concise snippet of code that can be used to reproduce the issue. + For more complex issues provide a repo with the smallest sample that reproduces the bug. + + Avoid including business logic or unrelated code, it makes diagnosis more difficult. + The code sample should be an SSCCE. See http://sscce.org/ for details. In short, please provide a code sample that we can copy/paste, run and reproduce. 
+ validations: + required: true + - type: textarea + id: solution + attributes: + label: Possible Solution + description: | + Suggest a fix/reason for the bug + validations: + required: false + - type: textarea + id: context + attributes: + label: Additional Information/Context + description: | + Anything else that might be relevant for troubleshooting this bug. Providing context helps us come up with a solution that is most useful in the real world. + validations: + required: false + - type: input + id: sdk-version + attributes: + label: SDK version used + validations: + required: true + - type: input + id: environment + attributes: + label: Environment details (OS name and version, etc.) + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..fe0acce --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,6 @@ +--- +blank_issues_enabled: false +contact_links: + - name: 💬 General Question + url: https://github.com/aws/aws-iot-device-sdk-python/discussions/categories/q-a + about: Please ask and answer questions as a discussion thread diff --git a/.github/ISSUE_TEMPLATE/documentation.yml b/.github/ISSUE_TEMPLATE/documentation.yml new file mode 100644 index 0000000..7d73869 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation.yml @@ -0,0 +1,23 @@ +--- +name: "📕 Documentation Issue" +description: Report an issue in the API Reference documentation or Developer Guide +title: "(short issue description)" +labels: [documentation, needs-triage] +assignees: [] +body: + - type: textarea + id: description + attributes: + label: Describe the issue + description: A clear and concise description of the issue. + validations: + required: true + + - type: textarea + id: links + attributes: + label: Links + description: | + Include links to affected documentation page(s). 
+ validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml new file mode 100644 index 0000000..60d2431 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -0,0 +1,59 @@ +--- +name: 🚀 Feature Request +description: Suggest an idea for this project +title: "(short issue description)" +labels: [feature-request, needs-triage] +assignees: [] +body: + - type: textarea + id: description + attributes: + label: Describe the feature + description: A clear and concise description of the feature you are proposing. + validations: + required: true + - type: textarea + id: use-case + attributes: + label: Use Case + description: | + Why do you need this feature? For example: "I'm always frustrated when..." + validations: + required: true + - type: textarea + id: solution + attributes: + label: Proposed Solution + description: | + Suggest how to implement the addition or change. Please include prototype/workaround/sketch/reference implementation. + validations: + required: false + - type: textarea + id: other + attributes: + label: Other Information + description: | + Any alternative solutions or features you considered, a more detailed explanation, stack traces, related issues, links for context, etc. + validations: + required: false + - type: checkboxes + id: ack + attributes: + label: Acknowledgements + options: + - label: I may be able to implement this feature request + required: false + - label: This feature might incur a breaking change + required: false + - type: input + id: sdk-version + attributes: + label: SDK version used + validations: + required: true + - type: input + id: environment + attributes: + label: Environment details (OS name and version, etc.) 
+ validations: + required: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..195bf2d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,59 @@ +name: CI + +on: + push: + branches: + - '*' + - '!main' + +env: + RUN: ${{ github.run_id }}-${{ github.run_number }} + AWS_DEFAULT_REGION: us-east-1 + CI_SDK_V1_ROLE: arn:aws:iam::180635532705:role/CI_SDK_V1_ROLE + PACKAGE_NAME: aws-iot-device-sdk-python + AWS_EC2_METADATA_DISABLED: true + +jobs: + unit-tests: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.8' + - name: Unit tests + run: | + python3 setup.py install + pip install pytest + pip install mock + python3 -m pytest test + + integration-tests: + runs-on: ubuntu-latest + permissions: + id-token: write # This is required for requesting the JWT + contents: read # This is required for actions/checkout + strategy: + fail-fast: false + matrix: + test-type: [ MutualAuth, Websocket, ALPN ] + python-version: [ '3.8', '3.13' ] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - uses: aws-actions/configure-aws-credentials@v2 + with: + role-to-assume: ${{ env.CI_SDK_V1_ROLE }} + aws-region: ${{ env.AWS_DEFAULT_REGION }} + - name: Integration tests + run: | + pip install pytest + pip install mock + pip install boto3 + python --version + ./test-integration/run/run.sh ${{ matrix.test-type }} 1000 100 7 diff --git a/.github/workflows/closed-issue-message.yml b/.github/workflows/closed-issue-message.yml new file mode 100644 index 0000000..22bf2a7 --- /dev/null +++ b/.github/workflows/closed-issue-message.yml @@ -0,0 +1,19 @@ +name: Closed Issue Message +on: + issues: + types: [closed] +jobs: + auto_comment: + runs-on: ubuntu-latest + permissions: + issues: write + steps: + - uses: aws-actions/closed-issue-message@v1 + with: + # These 
inputs are both required + repo-token: "${{ secrets.GITHUB_TOKEN }}" + message: | + ### ⚠️COMMENT VISIBILITY WARNING⚠️ + Comments on closed issues are hard for our team to see. + If you need more assistance, please either tag a team member or open a new issue that references this one. + If you wish to keep having a conversation with other community members under this issue feel free to do so. diff --git a/.github/workflows/handle-stale-discussions.yml b/.github/workflows/handle-stale-discussions.yml new file mode 100644 index 0000000..4fbcd70 --- /dev/null +++ b/.github/workflows/handle-stale-discussions.yml @@ -0,0 +1,19 @@ +name: HandleStaleDiscussions +on: + schedule: + - cron: '0 */4 * * *' + discussion_comment: + types: [created] + +jobs: + handle-stale-discussions: + name: Handle stale discussions + runs-on: ubuntu-latest + permissions: + discussions: write + steps: + - name: Stale discussions action + uses: aws-github-ops/handle-stale-discussions@v1 + env: + GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} + \ No newline at end of file diff --git a/.github/workflows/issue-regression-labeler.yml b/.github/workflows/issue-regression-labeler.yml new file mode 100644 index 0000000..bd00071 --- /dev/null +++ b/.github/workflows/issue-regression-labeler.yml @@ -0,0 +1,32 @@ +# Apply potential regression label on issues +name: issue-regression-label +on: + issues: + types: [opened, edited] +jobs: + add-regression-label: + runs-on: ubuntu-latest + permissions: + issues: write + steps: + - name: Fetch template body + id: check_regression + uses: actions/github-script@v7 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TEMPLATE_BODY: ${{ github.event.issue.body }} + with: + script: | + const regressionPattern = /\[x\] Select this option if this issue appears to be a regression\./i; + const template = `${process.env.TEMPLATE_BODY}` + const match = regressionPattern.test(template); + core.setOutput('is_regression', match); + - name: Manage regression label + env: + 
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + if [ "${{ steps.check_regression.outputs.is_regression }}" == "true" ]; then + gh issue edit ${{ github.event.issue.number }} --add-label "potential-regression" -R ${{ github.repository }} + else + gh issue edit ${{ github.event.issue.number }} --remove-label "potential-regression" -R ${{ github.repository }} + fi diff --git a/.github/workflows/stale_issue.yml b/.github/workflows/stale_issue.yml new file mode 100644 index 0000000..cbcc8b4 --- /dev/null +++ b/.github/workflows/stale_issue.yml @@ -0,0 +1,49 @@ +name: "Close stale issues" + +# Controls when the action will run. +on: + schedule: + - cron: "0 0 * * *" + +jobs: + cleanup: + runs-on: ubuntu-latest + name: Stale issue job + permissions: + issues: write + pull-requests: write + steps: + - uses: aws-actions/stale-issue-cleanup@v3 + with: + # Setting messages to an empty string will cause the automation to skip + # that category + ancient-issue-message: Greetings! Sorry to say but this is a very old issue that is probably not getting as much attention as it deserves. We encourage you to try V2 and if you find that this is still a problem, please feel free to open a new issue there. + stale-issue-message: Greetings! It looks like this issue hasn’t been active in longer than a week. Because it has been longer than a week since the last update on this, and in the absence of more information, we will be closing this issue soon. If you find that this is still a problem, please feel free to provide a comment or add an upvote to prevent automatic closure, or if the issue is already closed, please feel free to open a new one, also please try V2 as this might be solved there too. + stale-pr-message: Greetings! It looks like this PR hasn’t been active in longer than a week, add a comment or an upvote to prevent automatic closure, or if the issue is already closed, please feel free to open a new one. 
+ + # These labels are required + stale-issue-label: closing-soon + exempt-issue-label: automation-exempt + stale-pr-label: closing-soon + exempt-pr-label: pr/needs-review + response-requested-label: response-requested + + # Don't set closed-for-staleness label to skip closing very old issues + # regardless of label + closed-for-staleness-label: closed-for-staleness + + # Issue timing + days-before-stale: 7 + days-before-close: 4 + days-before-ancient: 36500 + + # If you don't want to mark a issue as being ancient based on a + # threshold of "upvotes", you can set this here. An "upvote" is + # the total number of +1, heart, hooray, and rocket reactions + # on an issue. + minimum-upvotes-to-exempt: 1 + + repo-token: ${{ secrets.GITHUB_TOKEN }} + loglevel: DEBUG + # Set dry-run to true to not perform label or close actions. + dry-run: false diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b1e0672 --- /dev/null +++ b/.gitignore @@ -0,0 +1,534 @@ + +# Created by https://www.gitignore.io/api/git,c++,cmake,python,visualstudio,visualstudiocode + +### C++ ### +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +### CMake ### +CMakeCache.txt +CMakeFiles +CMakeScripts +Testing +Makefile +cmake_install.cmake +install_manifest.txt +compile_commands.json +CTestTestfile.cmake + +### Git ### +# Created by git for backups. 
To disable backups in Git: +# $ git config --global mergetool.keepBackup false +*.orig + +# Created by git when using merge tools for conflicts +*.BACKUP.* +*.BASE.* +*.LOCAL.* +*.REMOTE.* +*_BACKUP_*.txt +*_BASE_*.txt +*_LOCAL_*.txt +*_REMOTE_*.txt + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions + +# Distribution / packaging +.Python +build/ +deps_build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheelhouse/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +### Python Patch ### +.venv/ + +### Python.VirtualEnv Stack ### +# Virtualenv +# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ +[Ii]nclude +[Ll]ib +[Ll]ib64 +[Ll]ocal +[Ss]cripts +pyvenv.cfg +pip-selfcheck.json + +### VisualStudioCode ### +.vscode/* + +### VisualStudio ### +## Ignore Visual Studio temporary 
files, build results, and +## files generated by popular Visual Studio add-ons. +## +## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore + +# User-specific files +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +bld/ +[Oo]bj/ +[Ll]og/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUNIT +*.VisualState.xml +TestResult.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.iobj +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# JustCode is a .NET coding add-in +.JustCode + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# 
MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. 
+!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
+*.vbw + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# JetBrains Rider +.idea/ +*.sln.iml + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +*.pyc + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + + +# End of https://www.gitignore.io/api/git,c++,cmake,python,visualstudio,visualstudiocode + +# credentials +.key +*.pem +.crt + +# deps from build-deps.sh +deps/ diff --git a/AWSIoTPythonSDK/MQTTLib.py b/AWSIoTPythonSDK/MQTTLib.py index 083ad54..6b9f20c 100755 --- a/AWSIoTPythonSDK/MQTTLib.py +++ b/AWSIoTPythonSDK/MQTTLib.py @@ -1,6 +1,6 @@ # #/* -# * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # * # * Licensed under the Apache License, Version 2.0 (the "License"). # * You may not use this file except in compliance with the License. @@ -14,20 +14,24 @@ # * permissions and limitations under the License. 
# */ -# import mqttCore -import AWSIoTPythonSDK.core.protocol.mqttCore as mqttCore -# import shadowManager +from AWSIoTPythonSDK.core.util.providers import CertificateCredentialsProvider +from AWSIoTPythonSDK.core.util.providers import CiphersProvider +from AWSIoTPythonSDK.core.util.providers import IAMCredentialsProvider +from AWSIoTPythonSDK.core.util.providers import EndpointProvider +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicReplyType +from AWSIoTPythonSDK.core.protocol.mqtt_core import MqttCore import AWSIoTPythonSDK.core.shadow.shadowManager as shadowManager -# import deviceShadow import AWSIoTPythonSDK.core.shadow.deviceShadow as deviceShadow +import AWSIoTPythonSDK.core.jobs.thingJobManager as thingJobManager + # Constants # - Protocol types: MQTTv3_1 = 3 MQTTv3_1_1 = 4 -# - OfflinePublishQueueing drop behavior: + DROP_OLDEST = 0 DROP_NEWEST = 1 -# class AWSIoTMQTTClient: @@ -78,14 +82,13 @@ def __init__(self, clientID, protocolType=MQTTv3_1_1, useWebsocket=False, cleanS **Returns** - AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTClient object + :code:`AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTClient` object """ - # mqttCore(clientID, cleanSession, protocol, srcLogManager, srcUseWebsocket=False) - self._mqttCore = mqttCore.mqttCore(clientID, cleanSession, protocolType, useWebsocket) + self._mqtt_core = MqttCore(clientID, cleanSession, protocolType, useWebsocket) # Configuration APIs - def configureLastWill(self, topic, payload, QoS): + def configureLastWill(self, topic, payload, QoS, retain=False): """ **Description** @@ -110,8 +113,7 @@ def configureLastWill(self, topic, payload, QoS): None """ - # mqttCore.setLastWill(srcTopic, srcPayload, srcQos) - self._mqttCore.setLastWill(topic, payload, QoS) + self._mqtt_core.configure_last_will(topic, payload, QoS, retain) def clearLastWill(self): """ @@ -134,8 +136,7 @@ def clearLastWill(self): None """ - #mqttCore.clearLastWill() 
- self._mqttCore.clearLastWill() + self._mqtt_core.clear_last_will() def configureEndpoint(self, hostName, portNumber): """ @@ -155,15 +156,20 @@ def configureEndpoint(self, hostName, portNumber): *hostName* - String that denotes the host name of the user-specific AWS IoT endpoint. *portNumber* - Integer that denotes the port number to connect to. Could be :code:`8883` for - TLSv1.2 Mutual Authentication or :code:`443` for Websocket SigV4. + TLSv1.2 Mutual Authentication or :code:`443` for Websocket SigV4 and TLSv1.2 Mutual Authentication + with ALPN extension. **Returns** None """ - # mqttCore.configEndpoint(srcHost, srcPort) - self._mqttCore.configEndpoint(hostName, portNumber) + endpoint_provider = EndpointProvider() + endpoint_provider.set_host(hostName) + endpoint_provider.set_port(portNumber) + self._mqtt_core.configure_endpoint(endpoint_provider) + if portNumber == 443 and not self._mqtt_core.use_wss(): + self._mqtt_core.configure_alpn_protocols() def configureIAMCredentials(self, AWSAccessKeyID, AWSSecretAccessKey, AWSSessionToken=""): """ @@ -196,10 +202,13 @@ def configureIAMCredentials(self, AWSAccessKeyID, AWSSecretAccessKey, AWSSession None """ - # mqttCore.configIAMCredentials(srcAWSAccessKeyID, srcAWSSecretAccessKey, srcAWSSessionToken) - self._mqttCore.configIAMCredentials(AWSAccessKeyID, AWSSecretAccessKey, AWSSessionToken) + iam_credentials_provider = IAMCredentialsProvider() + iam_credentials_provider.set_access_key_id(AWSAccessKeyID) + iam_credentials_provider.set_secret_access_key(AWSSecretAccessKey) + iam_credentials_provider.set_session_token(AWSSessionToken) + self._mqtt_core.configure_iam_credentials(iam_credentials_provider) - def configureCredentials(self, CAFilePath, KeyPath="", CertificatePath=""): # Should be good for MutualAuth certs config and Websocket rootCA config + def configureCredentials(self, CAFilePath, KeyPath="", CertificatePath="", Ciphers=None): # Should be good for MutualAuth certs config and Websocket rootCA config """ 
**Description** @@ -219,13 +228,22 @@ def configureCredentials(self, CAFilePath, KeyPath="", CertificatePath=""): # S *CertificatePath* - Path to read the certificate. Required for X.509 certificate based connection. + *Ciphers* - String of colon split SSL ciphers to use. If not passed, default ciphers will be used. + **Returns** None """ - # mqttCore.configCredentials(srcCAFile, srcKey, srcCert) - self._mqttCore.configCredentials(CAFilePath, KeyPath, CertificatePath) + cert_credentials_provider = CertificateCredentialsProvider() + cert_credentials_provider.set_ca_path(CAFilePath) + cert_credentials_provider.set_key_path(KeyPath) + cert_credentials_provider.set_cert_path(CertificatePath) + + cipher_provider = CiphersProvider() + cipher_provider.set_ciphers(Ciphers) + + self._mqtt_core.configure_cert_credentials(cert_credentials_provider, cipher_provider) def configureAutoReconnectBackoffTime(self, baseReconnectQuietTimeSecond, maxReconnectQuietTimeSecond, stableConnectionTimeSecond): """ @@ -256,15 +274,14 @@ def configureAutoReconnectBackoffTime(self, baseReconnectQuietTimeSecond, maxRec None """ - # mqttCore.setBackoffTime(srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond) - self._mqttCore.setBackoffTime(baseReconnectQuietTimeSecond, maxReconnectQuietTimeSecond, stableConnectionTimeSecond) + self._mqtt_core.configure_reconnect_back_off(baseReconnectQuietTimeSecond, maxReconnectQuietTimeSecond, stableConnectionTimeSecond) def configureOfflinePublishQueueing(self, queueSize, dropBehavior=DROP_NEWEST): """ **Description** - Used to configure the queue size and drop behavior for the offline publish requests queueing. Should be - called before connect. + Used to configure the queue size and drop behavior for the offline requests queueing. Should be + called before connect. Queueable offline requests include publish, subscribe and unsubscribe. 
**Syntax** @@ -282,15 +299,15 @@ def configureOfflinePublishQueueing(self, queueSize, dropBehavior=DROP_NEWEST): If set to 0, the queue is disabled. If set to -1, the queue size is set to be infinite. *dropBehavior* - the type of drop behavior when the queue is full. - Could be :code:`AWSIoTPythonSDK.MQTTLib.DROP_OLDEST` or :code:`AWSIoTPythonSDK.MQTTLib.DROP_NEWEST`. + Could be :code:`AWSIoTPythonSDK.core.util.enums.DropBehaviorTypes.DROP_OLDEST` or + :code:`AWSIoTPythonSDK.core.util.enums.DropBehaviorTypes.DROP_NEWEST`. **Returns** None """ - # mqttCore.setOfflinePublishQueueing(srcQueueSize, srcDropBehavior=mqtt.MSG_QUEUEING_DROP_NEWEST) - self._mqttCore.setOfflinePublishQueueing(queueSize, dropBehavior) + self._mqtt_core.configure_offline_requests_queue(queueSize, dropBehavior) def configureDrainingFrequency(self, frequencyInHz): """ @@ -320,8 +337,7 @@ def configureDrainingFrequency(self, frequencyInHz): None """ - # mqttCore.setDrainingIntervalSecond(srcDrainingIntervalSecond) - self._mqttCore.setDrainingIntervalSecond(1/float(frequencyInHz)) + self._mqtt_core.configure_draining_interval_sec(1/float(frequencyInHz)) def configureConnectDisconnectTimeout(self, timeoutSecond): """ @@ -346,8 +362,7 @@ def configureConnectDisconnectTimeout(self, timeoutSecond): None """ - # mqttCore.setConnectDisconnectTimeoutSecond(srcConnectDisconnectTimeout) - self._mqttCore.setConnectDisconnectTimeoutSecond(timeoutSecond) + self._mqtt_core.configure_connect_disconnect_timeout_sec(timeoutSecond) def configureMQTTOperationTimeout(self, timeoutSecond): """ @@ -372,457 +387,446 @@ def configureMQTTOperationTimeout(self, timeoutSecond): None """ - # mqttCore.setMQTTOperationTimeoutSecond(srcMQTTOperationTimeout) - self._mqttCore.setMQTTOperationTimeoutSecond(timeoutSecond) + self._mqtt_core.configure_operation_timeout_sec(timeoutSecond) - # MQTT functionality APIs - def connect(self, keepAliveIntervalSecond=30): + def configureUsernamePassword(self, username, password=None): """ 
**Description** - Connect to AWS IoT, with user-specific keeoalive interval configuration. + Used to configure the username and password used in CONNECT packet. **Syntax** .. code:: python - # Connect to AWS IoT with default keepalive set to 30 seconds - myAWSIoTMQTTClient.connect() - # Connect to AWS IoT with keepalive interval set to 55 seconds - myAWSIoTMQTTClient.connect(55) + # Configure user name and password + myAWSIoTMQTTClient.configureUsernamePassword("myUsername", "myPassword") **Parameters** - *keepAliveIntervalSecond* - Time in seconds for interval of sending MQTT ping request. - Default set to 30 seconds. + *username* - Username used in the username field of CONNECT packet. + + *password* - Password used in the password field of CONNECT packet. **Returns** - True if the connect attempt succeeded. False if failed. + None """ - # mqttCore.connect(keepAliveInterval=30) - return self._mqttCore.connect(keepAliveIntervalSecond) + self._mqtt_core.configure_username_password(username, password) - def disconnect(self): + def configureSocketFactory(self, socket_factory): """ **Description** - Disconnect from AWS IoT. + Configure a socket factory to custom configure a different socket type for + mqtt connection. Creating a custom socket allows for configuration of a proxy **Syntax** .. code:: python - myAWSIoTMQTTClient.disconnect() + # Configure socket factory + custom_args = {"arg1": "val1", "arg2": "val2"} + socket_factory = lambda: custom.create_connection((host, port), **custom_args) + myAWSIoTMQTTClient.configureSocketFactory(socket_factory) **Parameters** - None + *socket_factory* - Anonymous function which creates a custom socket to spec. **Returns** - True if the disconnect attempt succeeded. False if failed. 
+ None """ - # mqttCore.disconnect() - return self._mqttCore.disconnect() - - def publish(self, topic, payload, QoS): + self._mqtt_core.configure_socket_factory(socket_factory) + + def enableMetricsCollection(self): """ **Description** - Publish a new message to the desired topic with QoS. + Used to enable SDK metrics collection. Username field in CONNECT packet will be used to append the SDK name + and SDK version in use and communicate to AWS IoT cloud. This metrics collection is enabled by default. **Syntax** .. code:: python - # Publish a QoS0 message "myPayload" to topic "myToppic" - myAWSIoTMQTTClient.publish("myTopic", "myPayload", 0) - # Publish a QoS1 message "myPayloadWithQos1" to topic "myTopic/sub" - myAWSIoTMQTTClient.publish("myTopic/sub", "myPayloadWithQos1", 1) + myAWSIoTMQTTClient.enableMetricsCollection() **Parameters** - *topic* - Topic name to publish to. - - *payload* - Payload to publish. - - *QoS* - Quality of Service. Could be 0 or 1. + None **Returns** - True if the publish request has been sent to paho. False if the request did not reach paho. + None """ - # mqttCore.publish(topic, payload, qos, retain) - return self._mqttCore.publish(topic, payload, QoS, False) # Disable retain for publish by now + self._mqtt_core.enable_metrics_collection() - def subscribe(self, topic, QoS, callback): + def disableMetricsCollection(self): """ **Description** - Subscribe to the desired topic and register a callback. + Used to disable SDK metrics collection. **Syntax** .. code:: python - # Subscribe to "myTopic" with QoS0 and register a callback - myAWSIoTMQTTClient.subscribe("myTopic", 0, customCallback) - # Subscribe to "myTopic/#" with QoS1 and register a callback - myAWSIoTMQTTClient.subscribe("myTopic/#", 1, customCallback) + myAWSIoTMQTTClient.disableMetricsCollection() **Parameters** - *topic* - Topic name or filter to subscribe to. - - *QoS* - Quality of Service. Could be 0 or 1. 
- - *callback* - Function to be called when a new message for the subscribed topic - comes in. Should be in form :code:`customCallback(client, userdata, message)`, where - :code:`message` contains :code:`topic` and :code:`payload`. + None **Returns** - True if the subscribe attempt succeeded. False if failed. + None """ - # mqttCore.subscribe(topic, qos, callback) - return self._mqttCore.subscribe(topic, QoS, callback) + self._mqtt_core.disable_metrics_collection() - def unsubscribe(self, topic): + # MQTT functionality APIs + def connect(self, keepAliveIntervalSecond=600): """ **Description** - Unsubscribed to the desired topic. + Connect to AWS IoT, with user-specific keepalive interval configuration. **Syntax** .. code:: python - myAWSIoTMQTTClient.unsubscribe("myTopic") + # Connect to AWS IoT with default keepalive set to 600 seconds + myAWSIoTMQTTClient.connect() + # Connect to AWS IoT with keepalive interval set to 1200 seconds + myAWSIoTMQTTClient.connect(1200) **Parameters** - *topic* - Topic name or filter to unsubscribe to. + *keepAliveIntervalSecond* - Time in seconds for interval of sending MQTT ping request. + A shorter keep-alive interval allows the client to detect disconnects more quickly. + Default set to 600 seconds. **Returns** - True if the unsubscribe attempt succeeded. False if failed. + True if the connect attempt succeeded. False if failed. """ - # mqttCore.unsubscribe(topic) - return self._mqttCore.unsubscribe(topic) - - -class AWSIoTMQTTShadowClient: + self._load_callbacks() + return self._mqtt_core.connect(keepAliveIntervalSecond) - def __init__(self, clientID, protocolType=MQTTv3_1_1, useWebsocket=False, cleanSession=True): + def connectAsync(self, keepAliveIntervalSecond=600, ackCallback=None): """ + **Description** - The client class that manages device shadow and accesses its functionality in AWS IoT over MQTT v3.1/3.1.1. - - It is built on top of the AWS IoT MQTT Client and exposes devive shadow related operations. 
- It shares the same connection types, synchronous MQTT operations and partial on-top features - with the AWS IoT MQTT Client: - - - Auto reconnect/resubscribe - - Same as AWS IoT MQTT Client. - - - Progressive reconnect backoff - - Same as AWS IoT MQTT Client. - - - Offline publish requests queueing with draining - - Disabled by default. Queueing is not allowed for time-sensitive shadow requests/messages. + Connect asynchronously to AWS IoT, with user-specific keepalive interval configuration and CONNACK callback. **Syntax** .. code:: python - import AWSIoTPythonSDK.MQTTLib as AWSIoTPyMQTT - - # Create an AWS IoT MQTT Shadow Client using TLSv1.2 Mutual Authentication - myAWSIoTMQTTShadowClient = AWSIoTPyMQTT.AWSIoTMQTTShadowClient("testIoTPySDK") - # Create an AWS IoT MQTT Shadow Client using Websocket SigV4 - myAWSIoTMQTTShadowClient = AWSIoTPyMQTT.AWSIoTMQTTShadowClient("testIoTPySDK", useWebsocket=True) + # Connect to AWS IoT with default keepalive set to 600 seconds and a custom CONNACK callback + myAWSIoTMQTTClient.connectAsync(ackCallback=my_connack_callback) + # Connect to AWS IoT with keepalive interval set to 1200 seconds and a custom CONNACK callback + myAWSIoTMQTTClient.connectAsync(keepAliveIntervalSecond=1200, ackCallback=myConnackCallback) **Parameters** - *clientID* - String that denotes the client identifier used to connect to AWS IoT. - If empty string were provided, client id for this connection will be randomly generated - n server side. - - *protocolType* - MQTT version in use for this connection. Could be :code:`AWSIoTPythonSDK.MQTTLib.MQTTv3_1` or :code:`AWSIoTPythonSDK.MQTTLib.MQTTv3_1_1` + *keepAliveIntervalSecond* - Time in seconds for interval of sending MQTT ping request. + Default set to 600 seconds. - *useWebsocket* - Boolean that denotes enabling MQTT over Websocket SigV4 or not. 
Should be in form + :code:`customCallback(mid, data)`, where :code:`mid` is the packet id for the connect request + and :code:`data` is the connect result code. **Returns** - AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTShadowClient object + Connect request packet id, for tracking purpose in the corresponding callback. """ - # AWSIOTMQTTClient instance - self._AWSIoTMQTTClient = AWSIoTMQTTClient(clientID, protocolType, useWebsocket, cleanSession) - # Configure it to disable offline Publish Queueing - self._AWSIoTMQTTClient.configureOfflinePublishQueueing(0) # Disable queueing, no queueing for time-sentive shadow messages - self._AWSIoTMQTTClient.configureDrainingFrequency(10) - # Now retrieve the configured mqttCore and init a shadowManager instance - self._shadowManager = shadowManager.shadowManager(self._AWSIoTMQTTClient._mqttCore) + self._load_callbacks() + return self._mqtt_core.connect_async(keepAliveIntervalSecond, ackCallback) - # Configuration APIs - def configureLastWill(self, topic, payload, QoS): + def _load_callbacks(self): + self._mqtt_core.on_online = self.onOnline + self._mqtt_core.on_offline = self.onOffline + self._mqtt_core.on_message = self.onMessage + + def disconnect(self): """ **Description** - Used to configure the last will topic, payload and QoS of the client. Should be called before connect. + Disconnect from AWS IoT. **Syntax** .. code:: python - myAWSIoTMQTTClient.configureLastWill("last/Will/Topic", "lastWillPayload", 0) + myAWSIoTMQTTClient.disconnect() **Parameters** - *topic* - Topic name that last will publishes to. - - *payload* - Payload to publish for last will. - - *QoS* - Quality of Service. Could be 0 or 1. + None **Returns** - None + True if the disconnect attempt succeeded. False if failed. 
""" - # AWSIoTMQTTClient.configureLastWill(srcTopic, srcPayload, srcQos) - self._AWSIoTMQTTClient.configureLastWill(topic, payload, QoS) + return self._mqtt_core.disconnect() - def clearLastWill(self): + def disconnectAsync(self, ackCallback=None): """ **Description** - Used to clear the last will configuration that is previously set through configureLastWill. + Disconnect asynchronously to AWS IoT. **Syntax** .. code:: python - myAWSIoTShadowMQTTClient.clearLastWill() + myAWSIoTMQTTClient.disconnectAsync(ackCallback=myDisconnectCallback) - **Parameter** + **Parameters** - None + *ackCallback* - Callback to be invoked when the client finishes sending disconnect and internal clean-up. + Should be in form :code:`customCallback(mid, data)`, where :code:`mid` is the packet id for the disconnect + request and :code:`data` is the disconnect result code. **Returns** - None - + Disconnect request packet id, for tracking purpose in the corresponding callback. + """ - # AWSIoTMQTTClient.clearLastWill() - self._AWSIoTMQTTClient.clearLastWill() + return self._mqtt_core.disconnect_async(ackCallback) - def configureEndpoint(self, hostName, portNumber): + def publish(self, topic, payload, QoS): """ **Description** - Used to configure the host name and port number the underneath AWS IoT MQTT Client tries to connect to. Should be called - before connect. + Publish a new message to the desired topic with QoS. **Syntax** .. code:: python - myAWSIoTMQTTShadowClient.configureEndpoint("random.iot.region.amazonaws.com", 8883) + # Publish a QoS0 message "myPayload" to topic "myTopic" + myAWSIoTMQTTClient.publish("myTopic", "myPayload", 0) + # Publish a QoS1 message "myPayloadWithQos1" to topic "myTopic/sub" + myAWSIoTMQTTClient.publish("myTopic/sub", "myPayloadWithQos1", 1) **Parameters** - *hostName* - String that denotes the host name of the user-specific AWS IoT endpoint. + *topic* - Topic name to publish to. - *portNumber* - Integer that denotes the port number to connect to. 
Could be :code:`8883` for - TLSv1.2 Mutual Authentication or :code:`443` for Websocket SigV4. + *payload* - Payload to publish. + + *QoS* - Quality of Service. Could be 0 or 1. **Returns** - None + True if the publish request has been sent to paho. False if the request did not reach paho. """ - # AWSIoTMQTTClient.configureEndpoint - self._AWSIoTMQTTClient.configureEndpoint(hostName, portNumber) + return self._mqtt_core.publish(topic, payload, QoS, False) # Disable retain for publish by now - def configureIAMCredentials(self, AWSAccessKeyID, AWSSecretAccessKey, AWSSTSToken=""): + def publishAsync(self, topic, payload, QoS, ackCallback=None): """ **Description** - Used to configure/update the custom IAM credentials for the underneath AWS IoT MQTT Client - for Websocket SigV4 connection to AWS IoT. Should be called before connect. + Publish a new message asynchronously to the desired topic with QoS and PUBACK callback. Note that the ack + callback configuration for a QoS0 publish request will be ignored as there are no PUBACK reception. **Syntax** .. code:: python - myAWSIoTMQTTShadowClient.configureIAMCredentials(obtainedAccessKeyID, obtainedSecretAccessKey, obtainedSessionToken) - - .. note:: - - Hard-coding credentials into custom script is NOT recommended. Please use AWS Cognito identity service - or other credential provider. + # Publish a QoS0 message "myPayload" to topic "myTopic" + myAWSIoTMQTTClient.publishAsync("myTopic", "myPayload", 0) + # Publish a QoS1 message "myPayloadWithQos1" to topic "myTopic/sub", with custom PUBACK callback + myAWSIoTMQTTClient.publishAsync("myTopic/sub", "myPayloadWithQos1", 1, ackCallback=myPubackCallback) **Parameters** - *AWSAccessKeyID* - AWS Access Key Id from user-specific IAM credentials. + *topic* - Topic name to publish to. - *AWSSecretAccessKey* - AWS Secret Access Key from user-specific IAM credentials. + *payload* - Payload to publish. - *AWSSessionToken* - AWS Session Token for temporary authentication from STS. 
+ *QoS* - Quality of Service. Could be 0 or 1. + + *ackCallback* - Callback to be invoked when the client receives a PUBACK. Should be in form + :code:`customCallback(mid)`, where :code:`mid` is the packet id for the publish request. **Returns** - None + Publish request packet id, for tracking purpose in the corresponding callback. """ - # AWSIoTMQTTClient.configureIAMCredentials - self._AWSIoTMQTTClient.configureIAMCredentials(AWSAccessKeyID, AWSSecretAccessKey, AWSSTSToken) + return self._mqtt_core.publish_async(topic, payload, QoS, False, ackCallback) - def configureCredentials(self, CAFilePath, KeyPath="", CertificatePath=""): # Should be good for MutualAuth and Websocket + def subscribe(self, topic, QoS, callback): """ **Description** + Subscribe to the desired topic and register a callback. + **Syntax** + .. code:: python + + # Subscribe to "myTopic" with QoS0 and register a callback + myAWSIoTMQTTClient.subscribe("myTopic", 0, customCallback) + # Subscribe to "myTopic/#" with QoS1 and register a callback + myAWSIoTMQTTClient.subscribe("myTopic/#", 1, customCallback) + **Parameters** + *topic* - Topic name or filter to subscribe to. + + *QoS* - Quality of Service. Could be 0 or 1. + + *callback* - Function to be called when a new message for the subscribed topic + comes in. Should be in form :code:`customCallback(client, userdata, message)`, where + :code:`message` contains :code:`topic` and :code:`payload`. Note that :code:`client` and :code:`userdata` are + here just to be aligned with the underneath Paho callback function signature. These fields are pending to be + deprecated and should not be depended on. + **Returns** + True if the subscribe attempt succeeded. False if failed. 
+ """ - # AWSIoTMQTTClient.configureCredentials - self._AWSIoTMQTTClient.configureCredentials(CAFilePath, KeyPath, CertificatePath) + return self._mqtt_core.subscribe(topic, QoS, callback) - def configureAutoReconnectBackoffTime(self, baseReconnectQuietTimeSecond, maxReconnectQuietTimeSecond, stableConnectionTimeSecond): + def subscribeAsync(self, topic, QoS, ackCallback=None, messageCallback=None): """ **Description** - Used to configure the rootCA, private key and certificate files. Should be called before connect. + Subscribe to the desired topic and register a message callback with SUBACK callback. **Syntax** .. code:: python - myAWSIoTMQTTShadowClient.configureCredentials("PATH/TO/ROOT_CA", "PATH/TO/PRIVATE_KEY", "PATH/TO/CERTIFICATE") + # Subscribe to "myTopic" with QoS0, custom SUBACK callback and a message callback + myAWSIoTMQTTClient.subscribe("myTopic", 0, ackCallback=mySubackCallback, messageCallback=customMessageCallback) + # Subscribe to "myTopic/#" with QoS1, custom SUBACK callback and a message callback + myAWSIoTMQTTClient.subscribe("myTopic/#", 1, ackCallback=mySubackCallback, messageCallback=customMessageCallback) **Parameters** - *CAFilePath* - Path to read the root CA file. Required for all connection types. + *topic* - Topic name or filter to subscribe to. - *KeyPath* - Path to read the private key. Required for X.509 certificate based connection. + *QoS* - Quality of Service. Could be 0 or 1. - *CertificatePath* - Path to read the certificate. Required for X.509 certificate based connection. + *ackCallback* - Callback to be invoked when the client receives a SUBACK. Should be in form + :code:`customCallback(mid, data)`, where :code:`mid` is the packet id for the disconnect request and + :code:`data` is the granted QoS for this subscription. + + *messageCallback* - Function to be called when a new message for the subscribed topic + comes in. 
Should be in form :code:`customCallback(client, userdata, message)`, where + :code:`message` contains :code:`topic` and :code:`payload`. Note that :code:`client` and :code:`userdata` are + here just to be aligned with the underneath Paho callback function signature. These fields are pending to be + deprecated and should not be depended on. **Returns** - None + Subscribe request packet id, for tracking purpose in the corresponding callback. """ - # AWSIoTMQTTClient.configureBackoffTime - self._AWSIoTMQTTClient.configureAutoReconnectBackoffTime(baseReconnectQuietTimeSecond, maxReconnectQuietTimeSecond, stableConnectionTimeSecond) + return self._mqtt_core.subscribe_async(topic, QoS, ackCallback, messageCallback) - def configureConnectDisconnectTimeout(self, timeoutSecond): + def unsubscribe(self, topic): """ **Description** - Used to configure the time in seconds to wait for a CONNACK or a disconnect to complete. - Should be called before connect. + Unsubscribe to the desired topic. **Syntax** .. code:: python - # Configure connect/disconnect timeout to be 10 seconds - myAWSIoTMQTTShadowClient.configureConnectDisconnectTimeout(10) + myAWSIoTMQTTClient.unsubscribe("myTopic") **Parameters** - *timeoutSecond* - Time in seconds to wait for a CONNACK or a disconnect to complete. + *topic* - Topic name or filter to unsubscribe to. **Returns** - None + True if the unsubscribe attempt succeeded. False if failed. """ - # AWSIoTMQTTClient.configureConnectDisconnectTimeout - self._AWSIoTMQTTClient.configureConnectDisconnectTimeout(timeoutSecond) + return self._mqtt_core.unsubscribe(topic) - def configureMQTTOperationTimeout(self, timeoutSecond): + def unsubscribeAsync(self, topic, ackCallback=None): """ **Description** - Used to configure the timeout in seconds for MQTT QoS 1 publish, subscribe and unsubscribe. - Should be called before connect. + Unsubscribe to the desired topic with UNSUBACK callback. **Syntax** .. 
code:: python - # Configure MQTT operation timeout to be 5 seconds - myAWSIoTMQTTShadowClient.configureMQTTOperationTimeout(5) + myAWSIoTMQTTClient.unsubscribeAsync("myTopic", ackCallback=myUnsubackCallback) **Parameters** - *timeoutSecond* - Time in seconds to wait for a PUBACK/SUBACK/UNSUBACK. + *topic* - Topic name or filter to unsubscribe to. + + *ackCallback* - Callback to be invoked when the client receives a UNSUBACK. Should be in form + :code:`customCallback(mid)`, where :code:`mid` is the packet id for the unsubscribe request. **Returns** - None + Unsubscribe request packet id, for tracking purpose in the corresponding callback. """ - # AWSIoTMQTTClient.configureMQTTOperationTimeout - self._AWSIoTMQTTClient.configureMQTTOperationTimeout(timeoutSecond) + return self._mqtt_core.unsubscribe_async(topic, ackCallback) - # Start the MQTT connection - def connect(self, keepAliveIntervalSecond=30): + def onOnline(self): """ **Description** - Connect to AWS IoT, with user-specific keepalive interval configuration. + Callback that gets called when the client is online. The callback registration should happen before calling + connect/connectAsync. **Syntax** .. code:: python - # Connect to AWS IoT with default keepalive set to 30 seconds - myAWSIoTMQTTShadowClient.connect() - # Connect to AWS IoT with keepalive interval set to 55 seconds - myAWSIoTMQTTShadowClient.connect(55) + # Register an onOnline callback + myAWSIoTMQTTClient.onOnline = myOnOnlineCallback **Parameters** - *keepAliveIntervalSecond* - Time in seconds for interval of sending MQTT ping request. - Default set to 30 seconds. + None **Returns** - True if the connect attempt succeeded. False if failed. + None """ - return self._AWSIoTMQTTClient.connect(keepAliveIntervalSecond) + pass - # End the MQTT connection - def disconnect(self): + def onOffline(self): """ **Description** - Disconnect from AWS IoT. + Callback that gets called when the client is offline. 
The callback registration should happen before calling + connect/connectAsync. **Syntax** .. code:: python - myAWSIoTMQTTShadowClient.disconnect() + # Register an onOffline callback + myAWSIoTMQTTClient.onOffline = myOnOfflineCallback **Parameters** @@ -830,76 +834,953 @@ def disconnect(self): **Returns** - True if the disconnect attempt succeeded. False if failed. + None """ - return self._AWSIoTMQTTClient.disconnect() + pass - # Shadow management API - def createShadowHandlerWithName(self, shadowName, isPersistentSubscribe): + def onMessage(self, message): """ **Description** - Create a device shadow handler using the specified shadow name and isPersistentSubscribe. + Callback that gets called when the client receives a new message. The callback registration should happen before + calling connect/connectAsync. This callback, if present, will always be triggered regardless of whether there is + any message callback registered upon subscribe API call. It is for the purpose to aggregating the processing of + received messages in one function. **Syntax** .. code:: python - # Create a device shadow handler for shadow named "Bot1", using persistent subscription - Bot1Shadow = myAWSIoTMQTTShadowClient.createShadowHandlerWithName("Bot1", True) - # Create a device shadow handler for shadow named "Bot2", using non-persistent subscription - Bot2Shadow = myAWSIoTMQTTShadowClient.createShadowHandlerWithName("Bot2", False) + # Register an onMessage callback + myAWSIoTMQTTClient.onMessage = myOnMessageCallback **Parameters** - *shadowName* - Name of the device shadow. - - *isPersistentSubscribe* - Whether to unsubscribe from shadow response (accepted/rejected) topics - when there is a response. Will subscribe at the first time the shadow request is made and will - not unsubscribe if isPersistentSubscribe is set. + *message* - Received MQTT message. It contains the source topic as :code:`message.topic`, and the payload as + :code:`message.payload`. 
**Returns** - AWSIoTPythonSDK.core.shadow.deviceShadow.deviceShadow object, which exposes the device shadow interface. + None - """ - # Create and return a deviceShadow instance - return deviceShadow.deviceShadow(shadowName, isPersistentSubscribe, self._shadowManager) - # Shadow APIs are accessible in deviceShadow instance": - ### - # deviceShadow.shadowGet - # deviceShadow.shadowUpdate - # deviceShadow.shadowDelete - # deviceShadow.shadowRegisterDelta - # deviceShadow.shadowUnregisterDelta + """ + pass - # MQTT connection management API - def getMQTTConnection(self): +class _AWSIoTMQTTDelegatingClient(object): + + def __init__(self, clientID, protocolType=MQTTv3_1_1, useWebsocket=False, cleanSession=True, awsIoTMQTTClient=None): """ - **Description** - Retrieve the AWS IoT MQTT Client used underneath for shadow operations, making it possible to perform - plain MQTT operations along with shadow operations using the same single connection. + This class is used internally by the SDK and should not be instantiated directly. + + It delegates to a provided AWS IoT MQTT Client or creates a new one given the configuration + parameters and exposes core operations so that subclasses can provide convenience methods **Syntax** - .. code:: python + None - # Retrieve the AWS IoT MQTT Client used in the AWS IoT MQTT Shadow Client - thisAWSIoTMQTTClient = myAWSIoTMQTTShadowClient.getMQTTConnection() - # Perform plain MQTT operations using the same connection - thisAWSIoTMQTTClient.publish("Topic", "Payload", 1) - ... + **Parameters** + + *clientID* - String that denotes the client identifier used to connect to AWS IoT. + If empty string were provided, client id for this connection will be randomly generated + on server side. + + *protocolType* - MQTT version in use for this connection. Could be :code:`AWSIoTPythonSDK.MQTTLib.MQTTv3_1` or :code:`AWSIoTPythonSDK.MQTTLib.MQTTv3_1_1` + + *useWebsocket* - Boolean that denotes enabling MQTT over Websocket SigV4 or not. 
+ + **Returns** + + AWSIoTPythonSDK.MQTTLib._AWSIoTMQTTDelegatingClient object + + """ + # AWSIOTMQTTClient instance + self._AWSIoTMQTTClient = awsIoTMQTTClient if awsIoTMQTTClient is not None else AWSIoTMQTTClient(clientID, protocolType, useWebsocket, cleanSession) + + # Configuration APIs + def configureLastWill(self, topic, payload, QoS): + """ + **Description** + + Used to configure the last will topic, payload and QoS of the client. Should be called before connect. This is a public + facing API inherited by application level public clients. + + **Syntax** + + .. code:: python + + myShadowClient.configureLastWill("last/Will/Topic", "lastWillPayload", 0) + myJobsClient.configureLastWill("last/Will/Topic", "lastWillPayload", 0) **Parameters** + *topic* - Topic name that last will publishes to. + + *payload* - Payload to publish for last will. + + *QoS* - Quality of Service. Could be 0 or 1. + + **Returns** + + None + + """ + # AWSIoTMQTTClient.configureLastWill(srcTopic, srcPayload, srcQos) + self._AWSIoTMQTTClient.configureLastWill(topic, payload, QoS) + + def clearLastWill(self): + """ + **Description** + + Used to clear the last will configuration that is previously set through configureLastWill. This is a public + facing API inherited by application level public clients. + + **Syntax** + + .. code:: python + + myShadowClient.clearLastWill() + myJobsClient.clearLastWill() + + **Parameter** + None **Returns** - AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTClient object + None + + """ + # AWSIoTMQTTClient.clearLastWill() + self._AWSIoTMQTTClient.clearLastWill() - """ - # Return the internal AWSIoTMQTTClient instance - return self._AWSIoTMQTTClient + def configureEndpoint(self, hostName, portNumber): + """ + **Description** + + Used to configure the host name and port number the underneath AWS IoT MQTT Client tries to connect to. Should be called + before connect. This is a public facing API inherited by application level public clients. + + **Syntax** + + .. 
code:: python + + myShadowClient.configureEndpoint("random.iot.region.amazonaws.com", 8883) + myJobsClient.configureEndpoint("random.iot.region.amazonaws.com", 8883) + + **Parameters** + + *hostName* - String that denotes the host name of the user-specific AWS IoT endpoint. + + *portNumber* - Integer that denotes the port number to connect to. Could be :code:`8883` for + TLSv1.2 Mutual Authentication or :code:`443` for Websocket SigV4 and TLSv1.2 Mutual Authentication + with ALPN extension. + + **Returns** + + None + + """ + # AWSIoTMQTTClient.configureEndpoint + self._AWSIoTMQTTClient.configureEndpoint(hostName, portNumber) + + def configureIAMCredentials(self, AWSAccessKeyID, AWSSecretAccessKey, AWSSTSToken=""): + """ + **Description** + + Used to configure/update the custom IAM credentials for the underneath AWS IoT MQTT Client + for Websocket SigV4 connection to AWS IoT. Should be called before connect. This is a public + facing API inherited by application level public clients. + + **Syntax** + + .. code:: python + + myShadowClient.configureIAMCredentials(obtainedAccessKeyID, obtainedSecretAccessKey, obtainedSessionToken) + myJobsClient.configureIAMCredentials(obtainedAccessKeyID, obtainedSecretAccessKey, obtainedSessionToken) + + .. note:: + + Hard-coding credentials into custom script is NOT recommended. Please use AWS Cognito identity service + or other credential provider. + + **Parameters** + + *AWSAccessKeyID* - AWS Access Key Id from user-specific IAM credentials. + + *AWSSecretAccessKey* - AWS Secret Access Key from user-specific IAM credentials. + + *AWSSessionToken* - AWS Session Token for temporary authentication from STS. 
+ + **Returns** + + None + + """ + # AWSIoTMQTTClient.configureIAMCredentials + self._AWSIoTMQTTClient.configureIAMCredentials(AWSAccessKeyID, AWSSecretAccessKey, AWSSTSToken) + + def configureCredentials(self, CAFilePath, KeyPath="", CertificatePath=""): # Should be good for MutualAuth and Websocket + """ + **Description** + + Used to configure the rootCA, private key and certificate files. Should be called before connect. This is a public + facing API inherited by application level public clients. + + **Syntax** + + .. code:: python + + myShadowClient.configureCredentials("PATH/TO/ROOT_CA", "PATH/TO/PRIVATE_KEY", "PATH/TO/CERTIFICATE") + myJobsClient.configureCredentials("PATH/TO/ROOT_CA", "PATH/TO/PRIVATE_KEY", "PATH/TO/CERTIFICATE") + + **Parameters** + + *CAFilePath* - Path to read the root CA file. Required for all connection types. + + *KeyPath* - Path to read the private key. Required for X.509 certificate based connection. + + *CertificatePath* - Path to read the certificate. Required for X.509 certificate based connection. + + **Returns** + + None + + """ + # AWSIoTMQTTClient.configureCredentials + self._AWSIoTMQTTClient.configureCredentials(CAFilePath, KeyPath, CertificatePath) + + def configureAutoReconnectBackoffTime(self, baseReconnectQuietTimeSecond, maxReconnectQuietTimeSecond, stableConnectionTimeSecond): + """ + **Description** + + Used to configure the auto-reconnect backoff timing. Should be called before connect. This is a public + facing API inherited by application level public clients. + + **Syntax** + + .. code:: python + + # Configure the auto-reconnect backoff to start with 1 second and use 128 seconds as a maximum back off time. + # Connection over 20 seconds is considered stable and will reset the back off time back to its base. + myShadowClient.configureAutoReconnectBackoffTime(1, 128, 20) + myJobsClient.configureAutoReconnectBackoffTime(1, 128, 20) + + **Parameters** + + *baseReconnectQuietTimeSecond* - The initial back off time to start with, in seconds. 
+ Should be less than the stableConnectionTime. + + *maxReconnectQuietTimeSecond* - The maximum back off time, in seconds. + + *stableConnectionTimeSecond* - The number of seconds for a connection to last to be considered as stable. + Back off time will be reset to base once the connection is stable. + + **Returns** + + None + + """ + # AWSIoTMQTTClient.configureBackoffTime + self._AWSIoTMQTTClient.configureAutoReconnectBackoffTime(baseReconnectQuietTimeSecond, maxReconnectQuietTimeSecond, stableConnectionTimeSecond) + + def configureConnectDisconnectTimeout(self, timeoutSecond): + """ + **Description** + + Used to configure the time in seconds to wait for a CONNACK or a disconnect to complete. + Should be called before connect. This is a public facing API inherited by application level public clients. + + **Syntax** + + .. code:: python + + # Configure connect/disconnect timeout to be 10 seconds + myShadowClient.configureConnectDisconnectTimeout(10) + myJobsClient.configureConnectDisconnectTimeout(10) + + **Parameters** + + *timeoutSecond* - Time in seconds to wait for a CONNACK or a disconnect to complete. + + **Returns** + + None + + """ + # AWSIoTMQTTClient.configureConnectDisconnectTimeout + self._AWSIoTMQTTClient.configureConnectDisconnectTimeout(timeoutSecond) + + def configureMQTTOperationTimeout(self, timeoutSecond): + """ + **Description** + + Used to configure the timeout in seconds for MQTT QoS 1 publish, subscribe and unsubscribe. + Should be called before connect. This is a public facing API inherited by application level public clients. + + **Syntax** + + .. code:: python + + # Configure MQTT operation timeout to be 5 seconds + myShadowClient.configureMQTTOperationTimeout(5) + myJobsClient.configureMQTTOperationTimeout(5) + + **Parameters** + + *timeoutSecond* - Time in seconds to wait for a PUBACK/SUBACK/UNSUBACK. 
+ + **Returns** + + None + + """ + # AWSIoTMQTTClient.configureMQTTOperationTimeout + self._AWSIoTMQTTClient.configureMQTTOperationTimeout(timeoutSecond) + + def configureUsernamePassword(self, username, password=None): + """ + **Description** + + Used to configure the username and password used in CONNECT packet. This is a public facing API + inherited by application level public clients. + + **Syntax** + + .. code:: python + + # Configure user name and password + myShadowClient.configureUsernamePassword("myUsername", "myPassword") + myJobsClient.configureUsernamePassword("myUsername", "myPassword") + + **Parameters** + + *username* - Username used in the username field of CONNECT packet. + + *password* - Password used in the password field of CONNECT packet. + + **Returns** + + None + + """ + self._AWSIoTMQTTClient.configureUsernamePassword(username, password) + + def configureSocketFactory(self, socket_factory): + """ + **Description** + + Configure a socket factory to custom configure a different socket type for + mqtt connection. Creating a custom socket allows for configuration of a proxy + + **Syntax** + + .. code:: python + + # Configure socket factory + custom_args = {"arg1": "val1", "arg2": "val2"} + socket_factory = lambda: custom.create_connection((host, port), **custom_args) + myAWSIoTMQTTClient.configureSocketFactory(socket_factory) + + **Parameters** + + *socket_factory* - Anonymous function which creates a custom socket to spec. + + **Returns** + + None + + """ + self._AWSIoTMQTTClient.configureSocketFactory(socket_factory) + + def enableMetricsCollection(self): + """ + **Description** + + Used to enable SDK metrics collection. Username field in CONNECT packet will be used to append the SDK name + and SDK version in use and communicate to AWS IoT cloud. This metrics collection is enabled by default. + This is a public facing API inherited by application level public clients. + + **Syntax** + + .. 
code:: python + + myShadowClient.enableMetricsCollection() + myJobsClient.enableMetricsCollection() + + **Parameters** + + None + + **Returns** + + None + + """ + self._AWSIoTMQTTClient.enableMetricsCollection() + + def disableMetricsCollection(self): + """ + **Description** + + Used to disable SDK metrics collection. This is a public facing API inherited by application level public clients. + + **Syntax** + + .. code:: python + + myShadowClient.disableMetricsCollection() + myJobsClient.disableMetricsCollection() + + **Parameters** + + None + + **Returns** + + None + + """ + self._AWSIoTMQTTClient.disableMetricsCollection() + + # Start the MQTT connection + def connect(self, keepAliveIntervalSecond=600): + """ + **Description** + + Connect to AWS IoT, with user-specific keepalive interval configuration. This is a public facing API inherited + by application level public clients. + + **Syntax** + + .. code:: python + + # Connect to AWS IoT with default keepalive set to 600 seconds + myShadowClient.connect() + myJobsClient.connect() + # Connect to AWS IoT with keepalive interval set to 1200 seconds + myShadowClient.connect(1200) + myJobsClient.connect(1200) + + **Parameters** + + *keepAliveIntervalSecond* - Time in seconds for interval of sending MQTT ping request. + Default set to 600 seconds. + + **Returns** + + True if the connect attempt succeeded. False if failed. + + """ + self._load_callbacks() + return self._AWSIoTMQTTClient.connect(keepAliveIntervalSecond) + + def _load_callbacks(self): + self._AWSIoTMQTTClient.onOnline = self.onOnline + self._AWSIoTMQTTClient.onOffline = self.onOffline + + # End the MQTT connection + def disconnect(self): + """ + **Description** + + Disconnect from AWS IoT. This is a public facing API inherited by application level public clients. + + **Syntax** + + .. code:: python + + myShadowClient.disconnect() + myJobsClient.disconnect() + + **Parameters** + + None + + **Returns** + + True if the disconnect attempt succeeded. 
False if failed. + + """ + return self._AWSIoTMQTTClient.disconnect() + + # MQTT connection management API + def getMQTTConnection(self): + """ + **Description** + + Retrieve the AWS IoT MQTT Client used underneath, making it possible to perform + plain MQTT operations along with specialized operations using the same single connection. + This is a public facing API inherited by application level public clients. + + **Syntax** + + .. code:: python + + # Retrieve the AWS IoT MQTT Client used in the AWS IoT MQTT Delegating Client + thisAWSIoTMQTTClient = myShadowClient.getMQTTConnection() + thisAWSIoTMQTTClient = myJobsClient.getMQTTConnection() + # Perform plain MQTT operations using the same connection + thisAWSIoTMQTTClient.publish("Topic", "Payload", 1) + ... + + **Parameters** + + None + + **Returns** + + AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTClient object + + """ + # Return the internal AWSIoTMQTTClient instance + return self._AWSIoTMQTTClient + + def onOnline(self): + """ + **Description** + + Callback that gets called when the client is online. The callback registration should happen before calling + connect. This is a public facing API inherited by application level public clients. + + **Syntax** + + .. code:: python + + # Register an onOnline callback + myShadowClient.onOnline = myOnOnlineCallback + myJobsClient.onOnline = myOnOnlineCallback + + **Parameters** + + None + + **Returns** + + None + + """ + pass + + def onOffline(self): + """ + **Description** + + Callback that gets called when the client is offline. The callback registration should happen before calling + connect. This is a public facing API inherited by application level public clients. + + **Syntax** + + .. 
code:: python + + # Register an onOffline callback + myShadowClient.onOffline = myOnOfflineCallback + myJobsClient.onOffline = myOnOfflineCallback + + **Parameters** + + None + + **Returns** + + None + + """ + pass + + +class AWSIoTMQTTShadowClient(_AWSIoTMQTTDelegatingClient): + + def __init__(self, clientID, protocolType=MQTTv3_1_1, useWebsocket=False, cleanSession=True, awsIoTMQTTClient=None): + """ + + The client class that manages device shadow and accesses its functionality in AWS IoT over MQTT v3.1/3.1.1. + + It delegates to the AWS IoT MQTT Client and exposes device shadow related operations. + It shares the same connection types, synchronous MQTT operations and partial on-top features + with the AWS IoT MQTT Client: + + - Auto reconnect/resubscribe + + Same as AWS IoT MQTT Client. + + - Progressive reconnect backoff + + Same as AWS IoT MQTT Client. + + - Offline publish requests queueing with draining + + Disabled by default. Queueing is not allowed for time-sensitive shadow requests/messages. + + **Syntax** + + .. code:: python + + import AWSIoTPythonSDK.MQTTLib as AWSIoTPyMQTT + + # Create an AWS IoT MQTT Shadow Client using TLSv1.2 Mutual Authentication + myAWSIoTMQTTShadowClient = AWSIoTPyMQTT.AWSIoTMQTTShadowClient("testIoTPySDK") + # Create an AWS IoT MQTT Shadow Client using Websocket SigV4 + myAWSIoTMQTTShadowClient = AWSIoTPyMQTT.AWSIoTMQTTShadowClient("testIoTPySDK", useWebsocket=True) + + **Parameters** + + *clientID* - String that denotes the client identifier used to connect to AWS IoT. + If empty string were provided, client id for this connection will be randomly generated + on server side. + + *protocolType* - MQTT version in use for this connection. Could be :code:`AWSIoTPythonSDK.MQTTLib.MQTTv3_1` or :code:`AWSIoTPythonSDK.MQTTLib.MQTTv3_1_1` + + *useWebsocket* - Boolean that denotes enabling MQTT over Websocket SigV4 or not. 
+ + **Returns** + + AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTShadowClient object + + """ + super(AWSIoTMQTTShadowClient, self).__init__(clientID, protocolType, useWebsocket, cleanSession, awsIoTMQTTClient) + #leave passed in clients alone + if awsIoTMQTTClient is None: + # Configure it to disable offline Publish Queueing + self._AWSIoTMQTTClient.configureOfflinePublishQueueing(0) # Disable queueing, no queueing for time-sensitive shadow messages + self._AWSIoTMQTTClient.configureDrainingFrequency(10) + # Now retrieve the configured mqttCore and init a shadowManager instance + self._shadowManager = shadowManager.shadowManager(self._AWSIoTMQTTClient._mqtt_core) + + # Shadow management API + def createShadowHandlerWithName(self, shadowName, isPersistentSubscribe): + """ + **Description** + + Create a device shadow handler using the specified shadow name and isPersistentSubscribe. + + **Syntax** + + .. code:: python + + # Create a device shadow handler for shadow named "Bot1", using persistent subscription + Bot1Shadow = myAWSIoTMQTTShadowClient.createShadowHandlerWithName("Bot1", True) + # Create a device shadow handler for shadow named "Bot2", using non-persistent subscription + Bot2Shadow = myAWSIoTMQTTShadowClient.createShadowHandlerWithName("Bot2", False) + + **Parameters** + + *shadowName* - Name of the device shadow. + + *isPersistentSubscribe* - Whether to unsubscribe from shadow response (accepted/rejected) topics + when there is a response. Will subscribe at the first time the shadow request is made and will + not unsubscribe if isPersistentSubscribe is set. + + **Returns** + + AWSIoTPythonSDK.core.shadow.deviceShadow.deviceShadow object, which exposes the device shadow interface. 
+ + """ + # Create and return a deviceShadow instance + return deviceShadow.deviceShadow(shadowName, isPersistentSubscribe, self._shadowManager) + # Shadow APIs are accessible in deviceShadow instance": + ### + # deviceShadow.shadowGet + # deviceShadow.shadowUpdate + # deviceShadow.shadowDelete + # deviceShadow.shadowRegisterDelta + # deviceShadow.shadowUnregisterDelta + +class AWSIoTMQTTThingJobsClient(_AWSIoTMQTTDelegatingClient): + + def __init__(self, clientID, thingName, QoS=0, protocolType=MQTTv3_1_1, useWebsocket=False, cleanSession=True, awsIoTMQTTClient=None): + """ + + The client class that specializes in handling jobs messages and accesses its functionality in AWS IoT over MQTT v3.1/3.1.1. + + It delegates to the AWS IoT MQTT Client and exposes jobs related operations. + It shares the same connection types, synchronous MQTT operations and partial on-top features + with the AWS IoT MQTT Client: + + - Auto reconnect/resubscribe + + Same as AWS IoT MQTT Client. + + - Progressive reconnect backoff + + Same as AWS IoT MQTT Client. + + - Offline publish requests queueing with draining + + Same as AWS IoT MQTT Client + + **Syntax** + + .. code:: python + + import AWSIoTPythonSDK.MQTTLib as AWSIoTPyMQTT + + # Create an AWS IoT MQTT Jobs Client using TLSv1.2 Mutual Authentication + myAWSIoTMQTTJobsClient = AWSIoTPyMQTT.AWSIoTMQTTThingJobsClient("testIoTPySDK") + # Create an AWS IoT MQTT Jobs Client using Websocket SigV4 + myAWSIoTMQTTJobsClient = AWSIoTPyMQTT.AWSIoTMQTTThingJobsClient("testIoTPySDK", useWebsocket=True) + + **Parameters** + + *clientID* - String that denotes the client identifier and client token for jobs requests + If empty string is provided, client id for this connection will be randomly generated + on server side. 
If an awsIotMQTTClient is specified, this will not override the client ID + for the existing MQTT connection and only impact the client token for jobs request payloads + + *thingName* - String that represents the thingName used to send requests to proper topics and subscribe + to proper topics. + + *QoS* - QoS used for all requests sent through this client + + *awsIoTMQTTClient* - An instance of AWSIoTMQTTClient to use if not None. If not None, clientID, protocolType, useWebSocket, + and cleanSession parameters are not used. Caller is expected to invoke connect() prior to calling the pub/sub methods on this client. + + *protocolType* - MQTT version in use for this connection. Could be :code:`AWSIoTPythonSDK.MQTTLib.MQTTv3_1` or :code:`AWSIoTPythonSDK.MQTTLib.MQTTv3_1_1` + + *useWebsocket* - Boolean that denotes enabling MQTT over Websocket SigV4 or not. + + **Returns** + + AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTJobsClient object + + """ + # AWSIOTMQTTClient instance + super(AWSIoTMQTTThingJobsClient, self).__init__(clientID, protocolType, useWebsocket, cleanSession, awsIoTMQTTClient) + self._thingJobManager = thingJobManager.thingJobManager(thingName, clientID) + self._QoS = QoS + + def createJobSubscription(self, callback, jobExecutionType=jobExecutionTopicType.JOB_WILDCARD_TOPIC, jobReplyType=jobExecutionTopicReplyType.JOB_REQUEST_TYPE, jobId=None): + """ + **Description** + + Synchronously creates an MQTT subscription to a jobs related topic based on the provided arguments + + **Syntax** + + .. 
code:: python + + #Subscribe to notify-next topic to monitor change in job referred to by $next + myAWSIoTMQTTJobsClient.createJobSubscription(callback, jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) + #Subscribe to notify topic to monitor changes to jobs in pending list + myAWSIoTMQTTJobsClient.createJobSubscription(callback, jobExecutionTopicType.JOB_NOTIFY_TOPIC) + #Subscribe to receive messages for job execution updates + myAWSIoTMQTTJobsClient.createJobSubscription(callback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + #Subscribe to receive messages for describing a job execution + myAWSIoTMQTTJobsClient.createJobSubscription(callback, jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, jobId) + + **Parameters** + + *callback* - Function to be called when a new message for the subscribed job topic + comes in. Should be in form :code:`customCallback(client, userdata, message)`, where + :code:`message` contains :code:`topic` and :code:`payload`. Note that :code:`client` and :code:`userdata` are + here just to be aligned with the underneath Paho callback function signature. These fields are pending to be + deprecated and should not be depended on. + + *jobExecutionType* - Member of the jobExecutionTopicType class specifying the jobs topic to subscribe to + Defaults to jobExecutionTopicType.JOB_WILDCARD_TOPIC + + *jobReplyType* - Member of the jobExecutionTopicReplyType class specifying the (optional) reply sub-topic to subscribe to + Defaults to jobExecutionTopicReplyType.JOB_REQUEST_TYPE which indicates the subscription isn't intended for a jobs reply topic + + *jobId* - JobId string if the topic type requires one. + Defaults to None + + **Returns** + + True if the subscribe attempt succeeded. False if failed. 
+ + """ + topic = self._thingJobManager.getJobTopic(jobExecutionType, jobReplyType, jobId) + return self._AWSIoTMQTTClient.subscribe(topic, self._QoS, callback) + + def createJobSubscriptionAsync(self, ackCallback, callback, jobExecutionType=jobExecutionTopicType.JOB_WILDCARD_TOPIC, jobReplyType=jobExecutionTopicReplyType.JOB_REQUEST_TYPE, jobId=None): + """ + **Description** + + Asynchronously creates an MQTT subscription to a jobs related topic based on the provided arguments + + **Syntax** + + .. code:: python + + #Subscribe to notify-next topic to monitor change in job referred to by $next + myAWSIoTMQTTJobsClient.createJobSubscriptionAsync(callback, jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) + #Subscribe to notify topic to monitor changes to jobs in pending list + myAWSIoTMQTTJobsClient.createJobSubscriptionAsync(callback, jobExecutionTopicType.JOB_NOTIFY_TOPIC) + #Subscribe to receive messages for job execution updates + myAWSIoTMQTTJobsClient.createJobSubscriptionAsync(callback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + #Subscribe to receive messages for describing a job execution + myAWSIoTMQTTJobsClient.createJobSubscriptionAsync(callback, jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, jobId) + + **Parameters** + + *ackCallback* - Callback to be invoked when the client receives a SUBACK. Should be in form + :code:`customCallback(mid, data)`, where :code:`mid` is the packet id for the disconnect request and + :code:`data` is the granted QoS for this subscription. + + *callback* - Function to be called when a new message for the subscribed job topic + comes in. Should be in form :code:`customCallback(client, userdata, message)`, where + :code:`message` contains :code:`topic` and :code:`payload`. Note that :code:`client` and :code:`userdata` are + here just to be aligned with the underneath Paho callback function signature. 
These fields are pending to be + deprecated and should not be depended on. + + *jobExecutionType* - Member of the jobExecutionTopicType class specifying the jobs topic to subscribe to + Defaults to jobExecutionTopicType.JOB_WILDCARD_TOPIC + + *jobReplyType* - Member of the jobExecutionTopicReplyType class specifying the (optional) reply sub-topic to subscribe to + Defaults to jobExecutionTopicReplyType.JOB_REQUEST_TYPE which indicates the subscription isn't intended for a jobs reply topic + + *jobId* - JobId of the topic if the topic type requires one. + Defaults to None + + **Returns** + + Subscribe request packet id, for tracking purpose in the corresponding callback. + + """ + topic = self._thingJobManager.getJobTopic(jobExecutionType, jobReplyType, jobId) + return self._AWSIoTMQTTClient.subscribeAsync(topic, self._QoS, ackCallback, callback) + + def sendJobsQuery(self, jobExecTopicType, jobId=None): + """ + **Description** + + Publishes an MQTT jobs related request for a potentially specific jobId (or wildcard) + + **Syntax** + + .. code:: python + + #send a request to describe the next job + myAWSIoTMQTTJobsClient.sendJobsQuery(jobExecutionTopicType.JOB_DESCRIBE_TOPIC, '$next') + #send a request to get list of pending jobs + myAWSIoTMQTTJobsClient.sendJobsQuery(jobExecutionTopicType.JOB_GET_PENDING_TOPIC) + + **Parameters** + + *jobExecutionType* - Member of the jobExecutionTopicType class that correlates the jobs topic to publish to + + *jobId* - JobId string if the topic type requires one. + Defaults to None + + **Returns** + + True if the publish request has been sent to paho. False if the request did not reach paho. 
+ + """ + topic = self._thingJobManager.getJobTopic(jobExecTopicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, jobId) + payload = self._thingJobManager.serializeClientTokenPayload() + return self._AWSIoTMQTTClient.publish(topic, payload, self._QoS) + + def sendJobsStartNext(self, statusDetails=None, stepTimeoutInMinutes=None): + """ + **Description** + + Publishes an MQTT message to the StartNextJobExecution topic. This will attempt to get the next pending + job execution and change its status to IN_PROGRESS. + + **Syntax** + + .. code:: python + + #Start next job (set status to IN_PROGRESS) and update with optional statusDetails + myAWSIoTMQTTJobsClient.sendJobsStartNext({'StartedBy': 'myClientId'}) + + **Parameters** + + *statusDetails* - Dictionary containing the key value pairs to use for the status details of the job execution + + *stepTimeoutInMinutes* - Specifies the amount of time this device has to finish execution of this job. + + **Returns** + + True if the publish request has been sent to paho. False if the request did not reach paho. + + """ + topic = self._thingJobManager.getJobTopic(jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + payload = self._thingJobManager.serializeStartNextPendingJobExecutionPayload(statusDetails, stepTimeoutInMinutes) + return self._AWSIoTMQTTClient.publish(topic, payload, self._QoS) + + def sendJobsUpdate(self, jobId, status, statusDetails=None, expectedVersion=0, executionNumber=0, includeJobExecutionState=False, includeJobDocument=False, stepTimeoutInMinutes=None): + """ + **Description** + + Publishes an MQTT message to a corresponding job execution specific topic to update its status according to the parameters. + Can be used to change a job from QUEUED to IN_PROGRESS to SUCCEEDED or FAILED. + + **Syntax** + + .. code:: python + + #Update job with id 'jobId123' to succeeded state, specifying new status details, with expectedVersion=1, executionNumber=2. 
+ #For the response, include job execution state and not the job document + myAWSIoTMQTTJobsClient.sendJobsUpdate('jobId123', jobExecutionStatus.JOB_EXECUTION_SUCCEEDED, statusDetailsMap, 1, 2, True, False) + + + #Update job with id 'jobId456' to failed state + myAWSIoTMQTTJobsClient.sendJobsUpdate('jobId456', jobExecutionStatus.JOB_EXECUTION_FAILED) + + **Parameters** + + *jobId* - JobID String of the execution to update the status of + + *status* - job execution status to change the job execution to. Member of jobExecutionStatus + + *statusDetails* - new status details to set on the job execution + + *expectedVersion* - The expected current version of the job execution. IoT jobs increments expectedVersion each time you update the job execution. + If the version of the job execution stored in Jobs does not match, the update is rejected with a VersionMismatch error, and an ErrorResponse + that contains the current job execution status data is returned. (This makes it unnecessary to perform a separate DescribeJobExecution request + in order to obtain the job execution status data.) + + *executionNumber* - A number that identifies a particular job execution on a particular device. If not specified, the latest job execution is used. + + *includeJobExecutionState* - When included and set to True, the response contains the JobExecutionState field. The default is False. + + *includeJobDocument* - When included and set to True, the response contains the JobDocument. The default is False. + + *stepTimeoutInMinutes* - Specifies the amount of time this device has to finish execution of this job. + + **Returns** + + True if the publish request has been sent to paho. False if the request did not reach paho. 
+ + """ + topic = self._thingJobManager.getJobTopic(jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, jobId) + payload = self._thingJobManager.serializeJobExecutionUpdatePayload(status, statusDetails, expectedVersion, executionNumber, includeJobExecutionState, includeJobDocument, stepTimeoutInMinutes) + return self._AWSIoTMQTTClient.publish(topic, payload, self._QoS) + + def sendJobsDescribe(self, jobId, executionNumber=0, includeJobDocument=True): + """ + **Description** + + Publishes a method to the describe topic for a particular job. + + **Syntax** + + .. code:: python + + #Describe job with id 'jobId1' of any executionNumber, job document will be included in response + myAWSIoTMQTTJobsClient.sendJobsDescribe('jobId1') + + #Describe job with id 'jobId2', with execution number of 2, and includeJobDocument in the response + myAWSIoTMQTTJobsClient.sendJobsDescribe('jobId2', 2, True) + + **Parameters** + + *jobId* - jobID to describe. This is allowed to be a wildcard such as '$next' + + *executionNumber* - A number that identifies a particular job execution on a particular device. If not specified, the latest job execution is used. + + *includeJobDocument* - When included and set to True, the response contains the JobDocument. + + **Returns** + + True if the publish request has been sent to paho. False if the request did not reach paho. 
+ + """ + topic = self._thingJobManager.getJobTopic(jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, jobId) + payload = self._thingJobManager.serializeDescribeJobExecutionPayload(executionNumber, includeJobDocument) + return self._AWSIoTMQTTClient.publish(topic, payload, self._QoS) diff --git a/AWSIoTPythonSDK/__init__.py b/AWSIoTPythonSDK/__init__.py index 3925732..3a384fb 100755 --- a/AWSIoTPythonSDK/__init__.py +++ b/AWSIoTPythonSDK/__init__.py @@ -1,6 +1 @@ -import os -import sys - -__version__ = "1.1.2" - - +__version__ = "1.5.4" diff --git a/AWSIoTPythonSDK/core/protocol/paho/securedWebsocket/__init__.py b/AWSIoTPythonSDK/core/greengrass/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from AWSIoTPythonSDK/core/protocol/paho/securedWebsocket/__init__.py rename to AWSIoTPythonSDK/core/greengrass/__init__.py diff --git a/AWSIoTPythonSDK/core/greengrass/discovery/__init__.py b/AWSIoTPythonSDK/core/greengrass/discovery/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/AWSIoTPythonSDK/core/greengrass/discovery/models.py b/AWSIoTPythonSDK/core/greengrass/discovery/models.py new file mode 100644 index 0000000..ed8256d --- /dev/null +++ b/AWSIoTPythonSDK/core/greengrass/discovery/models.py @@ -0,0 +1,466 @@ +# /* +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * +# * Licensed under the Apache License, Version 2.0 (the "License"). +# * You may not use this file except in compliance with the License. +# * A copy of the License is located at +# * +# * http://aws.amazon.com/apache2.0 +# * +# * or in the "license" file accompanying this file. This file is distributed +# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# * express or implied. See the License for the specific language governing +# * permissions and limitations under the License. 
+# */ + +import json + + +KEY_GROUP_LIST = "GGGroups" +KEY_GROUP_ID = "GGGroupId" +KEY_CORE_LIST = "Cores" +KEY_CORE_ARN = "thingArn" +KEY_CA_LIST = "CAs" +KEY_CONNECTIVITY_INFO_LIST = "Connectivity" +KEY_CONNECTIVITY_INFO_ID = "Id" +KEY_HOST_ADDRESS = "HostAddress" +KEY_PORT_NUMBER = "PortNumber" +KEY_METADATA = "Metadata" + + +class ConnectivityInfo(object): + """ + + Class the stores one set of the connectivity information. + This is the data model for easy access to the discovery information from the discovery request function call. No + need to call directly from user scripts. + + """ + + def __init__(self, id, host, port, metadata): + self._id = id + self._host = host + self._port = port + self._metadata = metadata + + @property + def id(self): + """ + + Connectivity Information Id. + + """ + return self._id + + @property + def host(self): + """ + + Host address. + + """ + return self._host + + @property + def port(self): + """ + + Port number. + + """ + return self._port + + @property + def metadata(self): + """ + + Metadata string. + + """ + return self._metadata + + +class CoreConnectivityInfo(object): + """ + + Class that stores the connectivity information for a Greengrass core. + This is the data model for easy access to the discovery information from the discovery request function call. No + need to call directly from user scripts. + + """ + + def __init__(self, coreThingArn, groupId): + self._core_thing_arn = coreThingArn + self._group_id = groupId + self._connectivity_info_dict = dict() + + @property + def coreThingArn(self): + """ + + Thing arn for this Greengrass core. + + """ + return self._core_thing_arn + + @property + def groupId(self): + """ + + Greengrass group id that this Greengrass core belongs to. + + """ + return self._group_id + + @property + def connectivityInfoList(self): + """ + + The list of connectivity information that this Greengrass core has. 
+ + """ + return list(self._connectivity_info_dict.values()) + + def getConnectivityInfo(self, id): + """ + + **Description** + + Used for quickly accessing a certain set of connectivity information by id. + + **Syntax** + + .. code:: python + + myCoreConnectivityInfo.getConnectivityInfo("CoolId") + + **Parameters** + + *id* - The id for the desired connectivity information. + + **Return** + + :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.ConnectivityInfo` object. + + """ + return self._connectivity_info_dict.get(id) + + def appendConnectivityInfo(self, connectivityInfo): + """ + + **Description** + + Used for adding a new set of connectivity information to the list for this Greengrass core. This is used by the + SDK internally. No need to call directly from user scripts. + + **Syntax** + + .. code:: python + + myCoreConnectivityInfo.appendConnectivityInfo(newInfo) + + **Parameters** + + *connectivityInfo* - :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.ConnectivityInfo` object. + + **Returns** + + None + + """ + self._connectivity_info_dict[connectivityInfo.id] = connectivityInfo + + +class GroupConnectivityInfo(object): + """ + + Class that stores the connectivity information for a specific Greengrass group. + This is the data model for easy access to the discovery information from the discovery request function call. No + need to call directly from user scripts. + + """ + def __init__(self, groupId): + self._group_id = groupId + self._core_connectivity_info_dict = dict() + self._ca_list = list() + + @property + def groupId(self): + """ + + Id for this Greengrass group. + + """ + return self._group_id + + @property + def coreConnectivityInfoList(self): + """ + + A list of Greengrass cores + (:code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo` object) that belong to this + Greengrass group. 
+ + """ + return list(self._core_connectivity_info_dict.values()) + + @property + def caList(self): + """ + + A list of CA content strings for this Greengrass group. + + """ + return self._ca_list + + def getCoreConnectivityInfo(self, coreThingArn): + """ + + **Description** + + Used to retrieve the corresponding :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo` + object by core thing arn. + + **Syntax** + + .. code:: python + + myGroupConnectivityInfo.getCoreConnectivityInfo("YourOwnArnString") + + **Parameters** + + coreThingArn - Thing arn for the desired Greengrass core. + + **Returns** + + :code:`AWSIoTPythonSDK.core.greengrass.discovery.CoreConnectivityInfo` object. + + """ + return self._core_connectivity_info_dict.get(coreThingArn) + + def appendCoreConnectivityInfo(self, coreConnectivityInfo): + """ + + **Description** + + Used to append new core connectivity information to this group connectivity information. This is used by the + SDK internally. No need to call directly from user scripts. + + **Syntax** + + .. code:: python + + myGroupConnectivityInfo.appendCoreConnectivityInfo(newCoreConnectivityInfo) + + **Parameters** + + *coreConnectivityInfo* - :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo` object. + + **Returns** + + None + + """ + self._core_connectivity_info_dict[coreConnectivityInfo.coreThingArn] = coreConnectivityInfo + + def appendCa(self, ca): + """ + + **Description** + + Used to append new CA content string to this group connectivity information. This is used by the SDK internally. + No need to call directly from user scripts. + + **Syntax** + + .. code:: python + + myGroupConnectivityInfo.appendCa("CaContentString") + + **Parameters** + + *ca* - Group CA content string. + + **Returns** + + None + + """ + self._ca_list.append(ca) + + +class DiscoveryInfo(object): + """ + + Class that stores the discovery information coming back from the discovery request. 
+ This is the data model for easy access to the discovery information from the discovery request function call. No + need to call directly from user scripts. + + """ + def __init__(self, rawJson): + self._raw_json = rawJson + + @property + def rawJson(self): + """ + + JSON response string that contains the discovery information. This is reserved in case users want to do + some process by themselves. + + """ + return self._raw_json + + def getAllCores(self): + """ + + **Description** + + Used to retrieve the list of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo` + object for this discovery information. The retrieved cores could be from different Greengrass groups. This is + designed for uses who want to iterate through all available cores at the same time, regardless of which group + those cores are in. + + **Syntax** + + .. code:: python + + myDiscoveryInfo.getAllCores() + + **Parameters** + + None + + **Returns** + + List of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivtyInfo` object. + + """ + groups_list = self.getAllGroups() + core_list = list() + + for group in groups_list: + core_list.extend(group.coreConnectivityInfoList) + + return core_list + + def getAllCas(self): + """ + + **Description** + + Used to retrieve the list of :code:`(groupId, caContent)` pair for this discovery information. The retrieved + pairs could be from different Greengrass groups. This is designed for users who want to iterate through all + available cores/groups/CAs at the same time, regardless of which group those CAs belong to. + + **Syntax** + + .. code:: python + + myDiscoveryInfo.getAllCas() + + **Parameters** + + None + + **Returns** + + List of :code:`(groupId, caContent)` string pair, where :code:`caContent` is the CA content string and + :code:`groupId` is the group id that this CA belongs to. 
+ + """ + group_list = self.getAllGroups() + ca_list = list() + + for group in group_list: + for ca in group.caList: + ca_list.append((group.groupId, ca)) + + return ca_list + + def getAllGroups(self): + """ + + **Description** + + Used to retrieve the list of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.GroupConnectivityInfo` + object for this discovery information. This is designed for users who want to iterate through all available + groups that this Greengrass aware device (GGAD) belongs to. + + **Syntax** + + .. code:: python + + myDiscoveryInfo.getAllGroups() + + **Parameters** + + None + + **Returns** + + List of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.GroupConnectivityInfo` object. + + """ + groups_dict = self.toObjectAtGroupLevel() + return list(groups_dict.values()) + + def toObjectAtGroupLevel(self): + """ + + **Description** + + Used to get a dictionary of Greengrass group discovery information, with group id string as key and the + corresponding :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.GroupConnectivityInfo` object as the + value. This is designed for users who know exactly which group, which core and which set of connectivity info + they want to use for the Greengrass aware device to connect. + + **Syntax** + + .. code:: python + + # Get to the targeted connectivity information for a specific core in a specific group + groupLevelDiscoveryInfoObj = myDiscoveryInfo.toObjectAtGroupLevel() + groupConnectivityInfoObj = groupLevelDiscoveryInfoObj.toObjectAtGroupLevel("IKnowMyGroupId") + coreConnectivityInfoObj = groupConnectivityInfoObj.getCoreConnectivityInfo("IKnowMyCoreThingArn") + connectivityInfo = coreConnectivityInfoObj.getConnectivityInfo("IKnowMyConnectivityInfoSetId") + # Now retrieve the detailed information + caList = groupConnectivityInfoObj.caList + host = connectivityInfo.host + port = connectivityInfo.port + metadata = connectivityInfo.metadata + # Actual connecting logic follows... 
+ + """ + groups_object = json.loads(self._raw_json) + groups_dict = dict() + + for group_object in groups_object[KEY_GROUP_LIST]: + group_info = self._decode_group_info(group_object) + groups_dict[group_info.groupId] = group_info + + return groups_dict + + def _decode_group_info(self, group_object): + group_id = group_object[KEY_GROUP_ID] + group_info = GroupConnectivityInfo(group_id) + + for core in group_object[KEY_CORE_LIST]: + core_info = self._decode_core_info(core, group_id) + group_info.appendCoreConnectivityInfo(core_info) + + for ca in group_object[KEY_CA_LIST]: + group_info.appendCa(ca) + + return group_info + + def _decode_core_info(self, core_object, group_id): + core_info = CoreConnectivityInfo(core_object[KEY_CORE_ARN], group_id) + + for connectivity_info_object in core_object[KEY_CONNECTIVITY_INFO_LIST]: + connectivity_info = ConnectivityInfo(connectivity_info_object[KEY_CONNECTIVITY_INFO_ID], + connectivity_info_object[KEY_HOST_ADDRESS], + connectivity_info_object[KEY_PORT_NUMBER], + connectivity_info_object.get(KEY_METADATA,'')) + core_info.appendConnectivityInfo(connectivity_info) + + return core_info diff --git a/AWSIoTPythonSDK/core/greengrass/discovery/providers.py b/AWSIoTPythonSDK/core/greengrass/discovery/providers.py new file mode 100644 index 0000000..192f71a --- /dev/null +++ b/AWSIoTPythonSDK/core/greengrass/discovery/providers.py @@ -0,0 +1,442 @@ +# /* +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * +# * Licensed under the Apache License, Version 2.0 (the "License"). +# * You may not use this file except in compliance with the License. +# * A copy of the License is located at +# * +# * http://aws.amazon.com/apache2.0 +# * +# * or in the "license" file accompanying this file. This file is distributed +# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# * express or implied. 
# Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License"); see
# http://aws.amazon.com/apache2.0. Distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# Project-local imports are guarded so that the pure parsing/matching helpers
# in this class can be imported (and unit-tested) without the full package on
# the path; every guarded name is only referenced at call time.
try:
    from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryInvalidRequestException
    from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryUnauthorizedException
    from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryDataNotFoundException
    from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryThrottlingException
    from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryTimeoutException
    from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryFailure
    from AWSIoTPythonSDK.core.greengrass.discovery.models import DiscoveryInfo
    from AWSIoTPythonSDK.core.protocol.connection.alpn import SSLContextBuilder
except ImportError:  # pragma: no cover - intra-package imports always resolve in a real install
    pass
import re
import sys
import ssl
import time
import errno
import logging
import socket
import platform

# Windows reports "would block" with a different errno constant.
if platform.system() == 'Windows':
    EAGAIN = errno.WSAEWOULDBLOCK
else:
    EAGAIN = errno.EAGAIN


class DiscoveryInfoProvider(object):
    """
    Performs a Greengrass discovery request for a Greengrass aware device
    (GGAD) by speaking minimal HTTP/1.1 over a mutually-authenticated TLS
    socket, and parses the response into a ``DiscoveryInfo`` model object.
    """

    REQUEST_TYPE_PREFIX = "GET "
    PAYLOAD_PREFIX = "/greengrass/discover/thing/"
    PAYLOAD_SUFFIX = " HTTP/1.1\r\n"  # Space in the front
    HOST_PREFIX = "Host: "
    HOST_SUFFIX = "\r\n\r\n"
    HTTP_PROTOCOL = r"HTTP/1.1 "
    CONTENT_LENGTH = r"content-length: "
    CONTENT_LENGTH_PATTERN = CONTENT_LENGTH + r"([0-9]+)\r\n"
    HTTP_RESPONSE_CODE_PATTERN = HTTP_PROTOCOL + r"([0-9]+) "

    HTTP_SC_200 = "200"
    HTTP_SC_400 = "400"
    HTTP_SC_401 = "401"
    HTTP_SC_404 = "404"
    HTTP_SC_429 = "429"

    LOW_LEVEL_RC_COMPLETE = 0
    LOW_LEVEL_RC_TIMEOUT = -1

    _logger = logging.getLogger(__name__)

    def __init__(self, caPath="", certPath="", keyPath="", host="", port=8443, timeoutSec=120):
        """
        Create a discovery information provider.

        *caPath* - Path to read the root CA file.
        *certPath* - Path to read the certificate file.
        *keyPath* - Path to read the private key file.
        *host* - Host name of the user-specific AWS IoT endpoint.
        *port* - Port number to connect to; 8443 by default for discovery.
        *timeoutSec* - Seconds after which a request send/response wait is
        considered timed out.
        """
        self._ca_path = caPath
        self._cert_path = certPath
        self._key_path = keyPath
        self._host = host
        self._port = port
        self._timeout_sec = timeoutSec
        # Maps an HTTP status code to the exception raised for it.
        self._expected_exception_map = {
            self.HTTP_SC_400: DiscoveryInvalidRequestException(),
            self.HTTP_SC_401: DiscoveryUnauthorizedException(),
            self.HTTP_SC_404: DiscoveryDataNotFoundException(),
            self.HTTP_SC_429: DiscoveryThrottlingException()
        }

    def configureEndpoint(self, host, port=8443):
        """
        Configure the host address and port number for the discovery request.
        Must be called before :code:`discover`.
        """
        self._host = host
        self._port = port

    def configureCredentials(self, caPath, certPath, keyPath):
        """
        Configure the root CA, certificate and private key paths for the
        discovery request. Must be called before :code:`discover`.
        """
        self._ca_path = caPath
        self._cert_path = certPath
        self._key_path = keyPath

    def configureTimeout(self, timeoutSec):
        """
        Configure the timeout in seconds for request sending/response waiting.
        Must be called before :code:`discover`.
        """
        self._timeout_sec = timeoutSec

    def discover(self, thingName):
        """
        Perform the discovery request for the given Greengrass aware device
        thing name.

        Returns a :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.DiscoveryInfo`
        object; raises a Discovery* exception on timeout or non-200 response.
        """
        self._logger.info("Starting discover request...")
        self._logger.info("Endpoint: " + self._host + ":" + str(self._port))
        self._logger.info("Target thing: " + thingName)
        sock = self._create_tcp_connection()
        ssl_sock = self._create_ssl_connection(sock)
        self._raise_on_timeout(self._send_discovery_request(ssl_sock, thingName))
        status_code, response_body = self._receive_discovery_response(ssl_sock)
        return self._raise_if_not_200(status_code, response_body)

    def _create_tcp_connection(self):
        self._logger.debug("Creating tcp connection...")
        try:
            # Python < 2.7 / < 3.2 lack the source_address keyword.
            if (sys.version_info[0] == 2 and sys.version_info[1] < 7) or (sys.version_info[0] == 3 and sys.version_info[1] < 2):
                sock = socket.create_connection((self._host, self._port))
            else:
                sock = socket.create_connection((self._host, self._port), source_address=("", 0))
            # Fixed: this log line used to sit after the return and was
            # unreachable on the success path.
            self._logger.debug("Created tcp connection.")
            return sock
        except socket.error as err:
            if err.errno != errno.EINPROGRESS and err.errno != errno.EWOULDBLOCK and err.errno != EAGAIN:
                raise
            # NOTE(review): on a swallowed EINPROGRESS/EWOULDBLOCK this method
            # falls through and returns None, which the SSL wrap step cannot
            # handle -- confirm whether a non-blocking socket can reach here.

    def _create_ssl_connection(self, sock):
        self._logger.debug("Creating ssl connection...")

        ssl_protocol_version = ssl.PROTOCOL_SSLv23

        if self._port == 443:
            # Port 443 requires ALPN ("x-amzn-http-ca") to reach the
            # discovery service through the standard HTTPS port.
            ssl_context = SSLContextBuilder()\
                .with_ca_certs(self._ca_path)\
                .with_cert_key_pair(self._cert_path, self._key_path)\
                .with_cert_reqs(ssl.CERT_REQUIRED)\
                .with_check_hostname(True)\
                .with_ciphers(None)\
                .with_alpn_protocols(['x-amzn-http-ca'])\
                .build()
            ssl_sock = ssl_context.wrap_socket(sock, server_hostname=self._host, do_handshake_on_connect=False)
            ssl_sock.do_handshake()
        else:
            # To keep the SSL context update minimal, only apply a forced
            # SSLContext to Python 3.12+ (where ssl.wrap_socket was removed).
            force_ssl_context = sys.version_info[0] > 3 or (sys.version_info[0] == 3 and sys.version_info[1] >= 12)
            if force_ssl_context:
                ssl_context = ssl.SSLContext(ssl_protocol_version)
                ssl_context.load_cert_chain(self._cert_path, self._key_path)
                ssl_context.load_verify_locations(self._ca_path)
                ssl_context.verify_mode = ssl.CERT_REQUIRED

                ssl_sock = ssl_context.wrap_socket(sock)
            else:
                ssl_sock = ssl.wrap_socket(sock,
                                           certfile=self._cert_path,
                                           keyfile=self._key_path,
                                           ca_certs=self._ca_path,
                                           cert_reqs=ssl.CERT_REQUIRED,
                                           ssl_version=ssl_protocol_version)

        self._logger.debug("Matching host name...")
        if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 2):
            self._tls_match_hostname(ssl_sock)
        elif sys.version_info[0] == 3 and sys.version_info[1] < 7:
            ssl.match_hostname(ssl_sock.getpeercert(), self._host)
        # NOTE(review): for Python 3.7+ on the non-ALPN path no explicit
        # hostname check runs here (check_hostname is not enabled on the
        # manually built context) -- TODO confirm this is covered elsewhere.

        return ssl_sock

    def _tls_match_hostname(self, ssl_sock):
        # Manual hostname verification for pre-3.2 Pythons that lack
        # ssl.match_hostname.
        try:
            cert = ssl_sock.getpeercert()
        except AttributeError:
            # getpeercert can throw AttributeError: object has no attribute
            # 'peer_certificate'. Don't let that crash the whole client.
            # See also: http://bugs.python.org/issue13721
            raise ssl.SSLError('Not connected')

        san = cert.get('subjectAltName')
        if san:
            have_san_dns = False
            for (key, value) in san:
                if key == 'DNS':
                    have_san_dns = True
                    if self._host_matches_cert(self._host.lower(), value.lower()) == True:
                        return
                if key == 'IP Address':
                    have_san_dns = True
                    if value.lower() == self._host.lower():
                        return

            if have_san_dns:
                # Per RFC 6125: only fall back to the subject CN when the
                # certificate carries no DNS/IP subjectAltName entries.
                raise ssl.SSLError('Certificate subject does not match remote hostname.')
        subject = cert.get('subject')
        if subject:
            for ((key, value),) in subject:
                if key == 'commonName':
                    if self._host_matches_cert(self._host.lower(), value.lower()) == True:
                        return

        raise ssl.SSLError('Certificate subject does not match remote hostname.')

    def _host_matches_cert(self, host, cert_host):
        # Single-label wildcard match ("*.example.com"); multiple wildcards
        # are rejected outright.
        if cert_host[0:2] == "*.":
            if cert_host.count("*") != 1:
                return False

            host_match = host.split(".", 1)[1]
            cert_match = cert_host.split(".", 1)[1]
            if host_match == cert_match:
                return True
            else:
                return False
        else:
            if host == cert_host:
                return True
            else:
                return False

    def _send_discovery_request(self, ssl_sock, thing_name):
        # Hand-rolled HTTP/1.1 GET; the discovery service only needs the
        # request line and a Host header.
        request = self.REQUEST_TYPE_PREFIX + \
            self.PAYLOAD_PREFIX + \
            thing_name + \
            self.PAYLOAD_SUFFIX + \
            self.HOST_PREFIX + \
            self._host + ":" + str(self._port) + \
            self.HOST_SUFFIX
        self._logger.debug("Sending discover request: " + request)

        start_time = time.time()
        desired_length_to_write = len(request)
        actual_length_written = 0
        while True:
            try:
                length_written = ssl_sock.write(request.encode("utf-8"))
                actual_length_written += length_written
            except socket.error as err:
                # Fixed: the original swallowed every socket.error and kept
                # spinning until the timeout; only the non-blocking TLS
                # retry conditions should be tolerated.
                if err.errno != ssl.SSL_ERROR_WANT_READ and err.errno != ssl.SSL_ERROR_WANT_WRITE:
                    raise
            if actual_length_written == desired_length_to_write:
                return self.LOW_LEVEL_RC_COMPLETE
            if start_time + self._timeout_sec < time.time():
                return self.LOW_LEVEL_RC_TIMEOUT

    def _receive_discovery_response(self, ssl_sock):
        self._logger.debug("Receiving discover response header...")
        rc1, response_header = self._receive_until(ssl_sock, self._got_two_crlfs)
        status_code, body_length = self._handle_discovery_response_header(rc1, response_header.decode("utf-8"))

        self._logger.debug("Receiving discover response body...")
        rc2, response_body = self._receive_until(ssl_sock, self._got_enough_bytes, body_length)
        response_body = self._handle_discovery_response_body(rc2, response_body.decode("utf-8"))

        return status_code, response_body

    def _receive_until(self, ssl_sock, criteria_function, extra_data=None):
        # Read one byte at a time until criteria_function says we are done or
        # the configured timeout elapses.
        start_time = time.time()
        response = bytearray()
        number_bytes_read = 0
        while True:  # Python does not have do-while
            try:
                read_value = self._convert_to_int_py3(ssl_sock.read(1))
                if isinstance(read_value, list):
                    response.extend(read_value)
                else:
                    response.append(read_value)
                number_bytes_read += 1
            except socket.error as err:
                # Fixed: tolerate only the non-blocking TLS retry conditions;
                # re-raise real socket failures instead of busy-looping.
                if err.errno != ssl.SSL_ERROR_WANT_READ and err.errno != ssl.SSL_ERROR_WANT_WRITE:
                    raise

            if criteria_function((number_bytes_read, response, extra_data)):
                return self.LOW_LEVEL_RC_COMPLETE, response
            if start_time + self._timeout_sec < time.time():
                return self.LOW_LEVEL_RC_TIMEOUT, response

    def _convert_to_int_py3(self, input_char):
        # Py2 ssl.read gives str; ord() converts one byte to its int value.
        # Py3 may hand back data that is already int-like, which falls through.
        try:
            return ord(input_char)
        except TypeError:  # fixed: was a bare except
            return input_char

    def _got_enough_bytes(self, data):
        number_bytes_read, response, target_length = data
        return number_bytes_read == int(target_length)

    def _got_two_crlfs(self, data):
        # The HTTP header section terminates with "\r\n\r\n".
        number_bytes_read, response, extra_data_unused = data
        number_of_crlf = 2
        has_enough_bytes = number_bytes_read > number_of_crlf * 2 - 1
        if has_enough_bytes:
            end_of_received = response[number_bytes_read - number_of_crlf * 2: number_bytes_read]
            expected_end_of_response = b"\r\n" * number_of_crlf
            return end_of_received == expected_end_of_response
        else:
            return False

    def _handle_discovery_response_header(self, rc, response):
        # Returns (status_code, content_length) as strings pulled out of the
        # raw header text.
        self._raise_on_timeout(rc)
        http_status_code_matcher = re.compile(self.HTTP_RESPONSE_CODE_PATTERN)
        http_status_code_matched_groups = http_status_code_matcher.match(response)
        content_length_matcher = re.compile(self.CONTENT_LENGTH_PATTERN)
        content_length_matched_groups = content_length_matcher.search(response)
        return http_status_code_matched_groups.group(1), content_length_matched_groups.group(1)

    def _handle_discovery_response_body(self, rc, response):
        self._raise_on_timeout(rc)
        return response

    def _raise_on_timeout(self, rc):
        if rc == self.LOW_LEVEL_RC_TIMEOUT:
            raise DiscoveryTimeoutException()

    def _raise_if_not_200(self, status_code, response_body):  # response_body here is str in Py3
        # Map known error codes to their dedicated exceptions; anything else
        # unexpected becomes a generic DiscoveryFailure.
        if status_code != self.HTTP_SC_200:
            expected_exception = self._expected_exception_map.get(status_code)
            if expected_exception:
                raise expected_exception
            else:
                raise DiscoveryFailure(response_body)
        return DiscoveryInfo(response_body)
# Copyright 2010-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License"); see
# http://aws.amazon.com/apache2.0. Distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import json

# MQTT topic fragments for the AWS IoT Jobs service
# ($aws/things/<thingName>/jobs/...).
_BASE_THINGS_TOPIC = "$aws/things/"
_NOTIFY_OPERATION = "notify"
_NOTIFY_NEXT_OPERATION = "notify-next"
_GET_OPERATION = "get"
_START_NEXT_OPERATION = "start-next"
_WILDCARD_OPERATION = "+"
_UPDATE_OPERATION = "update"
_ACCEPTED_REPLY = "accepted"
_REJECTED_REPLY = "rejected"
_WILDCARD_REPLY = "#"

# Members of the topic-type enum below are tuples:
# (ordinal, jobIdRequired, operationString).
_JOB_ID_REQUIRED_INDEX = 1
_JOB_OPERATION_INDEX = 2

# JSON payload keys defined by the Jobs API.
_STATUS_KEY = 'status'
_STATUS_DETAILS_KEY = 'statusDetails'
_EXPECTED_VERSION_KEY = 'expectedVersion'
_EXECUTION_NUMBER_KEY = 'executionNumber'
# Backward-compatible alias for the historical misspelling of the name above.
_EXEXCUTION_NUMBER_KEY = _EXECUTION_NUMBER_KEY
_INCLUDE_JOB_EXECUTION_STATE_KEY = 'includeJobExecutionState'
_INCLUDE_JOB_DOCUMENT_KEY = 'includeJobDocument'
_CLIENT_TOKEN_KEY = 'clientToken'
_STEP_TIMEOUT_IN_MINUTES_KEY = 'stepTimeoutInMinutes'


class jobExecutionTopicType(object):
    """The type of job topic. Members are (ordinal, jobIdRequired, operation) tuples."""
    JOB_UNRECOGNIZED_TOPIC = (0, False, '')
    JOB_GET_PENDING_TOPIC = (1, False, _GET_OPERATION)
    JOB_START_NEXT_TOPIC = (2, False, _START_NEXT_OPERATION)
    JOB_DESCRIBE_TOPIC = (3, True, _GET_OPERATION)
    JOB_UPDATE_TOPIC = (4, True, _UPDATE_OPERATION)
    JOB_NOTIFY_TOPIC = (5, False, _NOTIFY_OPERATION)
    JOB_NOTIFY_NEXT_TOPIC = (6, False, _NOTIFY_NEXT_OPERATION)
    JOB_WILDCARD_TOPIC = (7, False, _WILDCARD_OPERATION)


# Members of the reply-type enum below are (ordinal, topicSuffix) tuples.
_JOB_SUFFIX_INDEX = 1


class jobExecutionTopicReplyType(object):
    """The type of reply topic, or JOB_REQUEST_TYPE for topics that are not replies."""
    JOB_UNRECOGNIZED_TOPIC_TYPE = (0, '')
    JOB_REQUEST_TYPE = (1, '')
    JOB_ACCEPTED_REPLY_TYPE = (2, '/' + _ACCEPTED_REPLY)
    JOB_REJECTED_REPLY_TYPE = (3, '/' + _REJECTED_REPLY)
    JOB_WILDCARD_REPLY_TYPE = (4, '/' + _WILDCARD_REPLY)


# Members of the status enum below are (ordinal, statusString) tuples.
_JOB_STATUS_INDEX = 1


class jobExecutionStatus(object):
    """Job execution statuses defined by the Jobs API; value None means 'not serializable'."""
    JOB_EXECUTION_STATUS_NOT_SET = (0, None)
    JOB_EXECUTION_QUEUED = (1, 'QUEUED')
    JOB_EXECUTION_IN_PROGRESS = (2, 'IN_PROGRESS')
    JOB_EXECUTION_FAILED = (3, 'FAILED')
    JOB_EXECUTION_SUCCEEDED = (4, 'SUCCEEDED')
    JOB_EXECUTION_CANCELED = (5, 'CANCELED')
    JOB_EXECUTION_REJECTED = (6, 'REJECTED')
    JOB_EXECUTION_UNKNOWN_STATUS = (99, None)


def _getExecutionStatus(jobStatus):
    # Extract the wire status string from a jobExecutionStatus tuple.
    # Fixed: indexing a tuple raises IndexError/TypeError (never KeyError),
    # so the original `except KeyError` let malformed statuses crash callers.
    try:
        return jobStatus[_JOB_STATUS_INDEX]
    except (IndexError, TypeError):
        return None


def _isWithoutJobIdTopicType(srcJobExecTopicType):
    # Topic types whose topics never embed a job id.
    return (srcJobExecTopicType == jobExecutionTopicType.JOB_GET_PENDING_TOPIC or srcJobExecTopicType == jobExecutionTopicType.JOB_START_NEXT_TOPIC
            or srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_TOPIC or srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC)


class thingJobManager:
    """Builds Jobs MQTT topics and JSON payloads for one IoT thing."""

    def __init__(self, thingName, clientToken=None):
        self._thingName = thingName
        self._clientToken = clientToken

    def getJobTopic(self, srcJobExecTopicType, srcJobExecTopicReplyType=jobExecutionTopicReplyType.JOB_REQUEST_TYPE, jobId=None):
        """
        Return the Jobs MQTT topic for the given topic/reply types and
        optional job id, or None when the combination is invalid.
        """
        if self._thingName is None:
            return None

        # Notify topics only exist as requests, never as accepted/rejected replies.
        if (srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_TOPIC or srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) and srcJobExecTopicReplyType != jobExecutionTopicReplyType.JOB_REQUEST_TYPE:
            return None

        # Topics that explicitly do not take a job id must not get one.
        if jobId is not None and _isWithoutJobIdTopicType(srcJobExecTopicType):
            return None

        # Topics that require a job id must get one.
        if jobId is None and srcJobExecTopicType[_JOB_ID_REQUIRED_INDEX]:
            return None

        # The operation must be a non-empty string (rules out UNRECOGNIZED).
        if srcJobExecTopicType[_JOB_OPERATION_INDEX] == '':
            return None

        if srcJobExecTopicType[_JOB_ID_REQUIRED_INDEX]:
            return '{0}{1}/jobs/{2}/{3}{4}'.format(_BASE_THINGS_TOPIC, self._thingName, str(jobId), srcJobExecTopicType[_JOB_OPERATION_INDEX], srcJobExecTopicReplyType[_JOB_SUFFIX_INDEX])
        elif srcJobExecTopicType == jobExecutionTopicType.JOB_WILDCARD_TOPIC:
            return '{0}{1}/jobs/#'.format(_BASE_THINGS_TOPIC, self._thingName)
        else:
            return '{0}{1}/jobs/{2}{3}'.format(_BASE_THINGS_TOPIC, self._thingName, srcJobExecTopicType[_JOB_OPERATION_INDEX], srcJobExecTopicReplyType[_JOB_SUFFIX_INDEX])

    def serializeJobExecutionUpdatePayload(self, status, statusDetails=None, expectedVersion=0, executionNumber=0, includeJobExecutionState=False, includeJobDocument=False, stepTimeoutInMinutes=None):
        """
        Serialize an UpdateJobExecution request payload as a JSON string, or
        return None when *status* carries no wire status string.
        """
        executionStatus = _getExecutionStatus(status)
        if executionStatus is None:
            return None
        payload = {_STATUS_KEY: executionStatus}
        if statusDetails:
            payload[_STATUS_DETAILS_KEY] = statusDetails
        # expectedVersion/executionNumber are serialized as strings here,
        # matching the original wire behavior.
        if expectedVersion > 0:
            payload[_EXPECTED_VERSION_KEY] = str(expectedVersion)
        if executionNumber > 0:
            payload[_EXECUTION_NUMBER_KEY] = str(executionNumber)
        if includeJobExecutionState:
            payload[_INCLUDE_JOB_EXECUTION_STATE_KEY] = True
        if includeJobDocument:
            payload[_INCLUDE_JOB_DOCUMENT_KEY] = True
        if self._clientToken is not None:
            payload[_CLIENT_TOKEN_KEY] = self._clientToken
        if stepTimeoutInMinutes is not None:
            payload[_STEP_TIMEOUT_IN_MINUTES_KEY] = stepTimeoutInMinutes
        return json.dumps(payload)

    def serializeDescribeJobExecutionPayload(self, executionNumber=0, includeJobDocument=True):
        """Serialize a DescribeJobExecution request payload as a JSON string."""
        payload = {_INCLUDE_JOB_DOCUMENT_KEY: includeJobDocument}
        if executionNumber > 0:
            payload[_EXECUTION_NUMBER_KEY] = executionNumber
        if self._clientToken is not None:
            payload[_CLIENT_TOKEN_KEY] = self._clientToken
        return json.dumps(payload)

    def serializeStartNextPendingJobExecutionPayload(self, statusDetails=None, stepTimeoutInMinutes=None):
        """Serialize a StartNextPendingJobExecution request payload as a JSON string."""
        payload = {}
        if self._clientToken is not None:
            payload[_CLIENT_TOKEN_KEY] = self._clientToken
        if statusDetails is not None:
            payload[_STATUS_DETAILS_KEY] = statusDetails
        if stepTimeoutInMinutes is not None:
            payload[_STEP_TIMEOUT_IN_MINUTES_KEY] = stepTimeoutInMinutes
        return json.dumps(payload)

    def serializeClientTokenPayload(self):
        """Serialize a payload carrying only the client token ('{}' when unset)."""
        return json.dumps({_CLIENT_TOKEN_KEY: self._clientToken}) if self._clientToken is not None else '{}'
try:
    import ssl
except ImportError:  # fixed: was a bare except; only a missing module should be tolerated here
    ssl = None


class SSLContextBuilder(object):
    """
    Fluent builder for an ``ssl.SSLContext`` configured for ALPN, used to
    reach AWS IoT MQTT/HTTPS endpoints over port 443. All ``with_*`` methods
    return ``self`` so calls can be chained, ending with ``build()``.
    """

    def __init__(self):
        # Fail fast if this runtime cannot do TLS + ALPN at all.
        self.check_supportability()
        self._ssl_context = ssl.create_default_context()

    def check_supportability(self):
        """Raise if this Python build lacks SSL, SSLContext, or ALPN support."""
        if ssl is None:
            raise RuntimeError("This platform has no SSL/TLS.")
        if not hasattr(ssl, "SSLContext"):
            raise NotImplementedError("This platform does not support SSLContext. Python 2.7.10+/3.5+ is required.")
        if not hasattr(ssl.SSLContext, "set_alpn_protocols"):
            raise NotImplementedError("This platform does not support ALPN as TLS extensions. Python 2.7.10+/3.5+ is required.")

    def with_ca_certs(self, ca_certs):
        """Load the root CA file used to verify the peer."""
        self._ssl_context.load_verify_locations(ca_certs)
        return self

    def with_cert_key_pair(self, cert_file, key_file):
        """Load the client certificate and private key for mutual TLS."""
        self._ssl_context.load_cert_chain(cert_file, key_file)
        return self

    def with_cert_reqs(self, cert_reqs):
        """Set the peer certificate verification mode (e.g. ssl.CERT_REQUIRED)."""
        self._ssl_context.verify_mode = cert_reqs
        return self

    def with_check_hostname(self, check_hostname):
        """Enable/disable hostname verification during the handshake."""
        self._ssl_context.check_hostname = check_hostname
        return self

    def with_ciphers(self, ciphers):
        """Restrict the cipher suites; None keeps the defaults (set_ciphers() rejects None)."""
        if ciphers is not None:
            self._ssl_context.set_ciphers(ciphers)
        return self

    def with_alpn_protocols(self, alpn_protocols):
        """Advertise the given ALPN protocol names (e.g. ['x-amzn-http-ca'])."""
        self._ssl_context.set_alpn_protocols(alpn_protocols)
        return self

    def build(self):
        """Return the configured ssl.SSLContext."""
        return self._ssl_context
# Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License"); see
# http://aws.amazon.com/apache2.0. Distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# This module implements the progressive backoff logic for auto-reconnect.
# It manages the reconnect wait time for the current reconnect, controlling
# when to increase it and when to reset it. It also carries the SigV4
# signing helpers used to build presigned websocket endpoints.


import re
import sys
import ssl
import errno
import struct
import socket
import base64
import time
import threading
import logging
import os
from datetime import datetime
import hashlib
import hmac
# Project-local imports are guarded so the pure helpers in this module can be
# imported (and unit-tested) without the full package; the names are only
# referenced at call time.
try:
    from AWSIoTPythonSDK.exception.AWSIoTExceptions import ClientError
    from AWSIoTPythonSDK.exception.AWSIoTExceptions import wssNoKeyInEnvironmentError
    from AWSIoTPythonSDK.exception.AWSIoTExceptions import wssHandShakeError
    from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
except ImportError:  # pragma: no cover - intra-package imports always resolve in a real install
    pass
try:
    from urllib.parse import quote  # Python 3+
except ImportError:
    from urllib import quote
# INI config file handling
try:
    from configparser import ConfigParser  # Python 3+
    from configparser import NoOptionError
    from configparser import NoSectionError
except ImportError:
    from ConfigParser import ConfigParser
    from ConfigParser import NoOptionError
    from ConfigParser import NoSectionError


class ProgressiveBackOffCore:
    """
    Exponential backoff with reset-on-stable-connection for auto-reconnect:
    the wait time doubles on each reconnect (capped at a maximum) and resets
    to the base once a connection stays up for the minimum stable period.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, srcBaseReconnectTimeSecond=1, srcMaximumReconnectTimeSecond=32, srcMinimumConnectTimeSecond=20):
        # The base reconnection time in seconds, default 1
        self._baseReconnectTimeSecond = srcBaseReconnectTimeSecond
        # The maximum reconnection time in seconds, default 32
        self._maximumReconnectTimeSecond = srcMaximumReconnectTimeSecond
        # The minimum time in seconds that a connection must be maintained
        # in order to be considered stable, default 20
        self._minimumConnectTimeSecond = srcMinimumConnectTimeSecond
        # Current backoff time in seconds; starts at 1 (fixed stale comment
        # that claimed it starts at 0)
        self._currentBackoffTimeSecond = 1
        # Timer that resets the backoff once the connection proves stable
        self._resetBackoffTimer = None

    def configTime(self, srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond):
        """
        Customize the backoff timing. Raises ValueError for negative values
        or when the base reconnect time is not below the min-connect time.
        """
        if srcBaseReconnectTimeSecond < 0 or srcMaximumReconnectTimeSecond < 0 or srcMinimumConnectTimeSecond < 0:
            self._logger.error("init: Negative time configuration detected.")
            raise ValueError("Negative time configuration detected.")
        if srcBaseReconnectTimeSecond >= srcMinimumConnectTimeSecond:
            self._logger.error("init: Min connect time should be bigger than base reconnect time.")
            raise ValueError("Min connect time should be bigger than base reconnect time.")
        self._baseReconnectTimeSecond = srcBaseReconnectTimeSecond
        self._maximumReconnectTimeSecond = srcMaximumReconnectTimeSecond
        self._minimumConnectTimeSecond = srcMinimumConnectTimeSecond
        self._currentBackoffTimeSecond = 1

    def backOff(self):
        """
        Block for the current backoff period, then double it (capped at the
        maximum). Cancels any in-flight stable-connection reset timer. Should
        only be called when a disconnect/reconnect happens.
        """
        self._logger.debug("backOff: current backoff time is: " + str(self._currentBackoffTimeSecond) + " sec.")
        if self._resetBackoffTimer is not None:
            # Cancel the pending reset; this reconnect proves instability
            self._resetBackoffTimer.cancel()
        # Block the reconnect logic
        time.sleep(self._currentBackoffTimeSecond)
        if self._currentBackoffTimeSecond == 0:
            # NOTE(review): with the current init value of 1 this branch only
            # fires if the backoff was externally reset to 0
            self._currentBackoffTimeSecond = self._baseReconnectTimeSecond
        else:
            # r_cur = min(2 * r_cur, r_max)
            self._currentBackoffTimeSecond = min(self._maximumReconnectTimeSecond, self._currentBackoffTimeSecond * 2)

    def startStableConnectionTimer(self):
        """Arm the timer that resets the backoff after the stable period; cancelled by backOff()."""
        self._resetBackoffTimer = threading.Timer(self._minimumConnectTimeSecond,
                                                  self._connectionStableThenResetBackoffTime)
        self._resetBackoffTimer.start()

    def stopStableConnectionTimer(self):
        """Cancel the pending reset timer, if any."""
        if self._resetBackoffTimer is not None:
            self._resetBackoffTimer.cancel()

    def _connectionStableThenResetBackoffTime(self):
        # Timer callback: the connection survived _minimumConnectTimeSecond,
        # so reset the backoff to the base value.
        self._logger.debug(
            "stableConnection: Resetting the backoff time to: " + str(self._baseReconnectTimeSecond) + " sec.")
        self._currentBackoffTimeSecond = self._baseReconnectTimeSecond


class SigV4Core:
    """
    Derives AWS Signature Version 4 signing material and locates IAM
    credentials, in priority order: explicitly set credentials, then
    environment variables, then the shared ~/.aws/credentials file.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self):
        self._aws_access_key_id = ""
        self._aws_secret_access_key = ""
        self._aws_session_token = ""
        # Shared credential file location, compatible with the AWS CLI.
        self._credentialConfigFilePath = "~/.aws/credentials"

    def setIAMCredentials(self, srcAWSAccessKeyID, srcAWSSecretAccessKey, srcAWSSessionToken):
        """Explicitly provide IAM credentials; these win over env vars and files."""
        self._aws_access_key_id = srcAWSAccessKeyID
        self._aws_secret_access_key = srcAWSSecretAccessKey
        self._aws_session_token = srcAWSSessionToken

    def _createAmazonDate(self):
        # Returns [YYYYMMDD, YYYYMMDD'T'HHMMSS'Z'] as unicode strings in Py3.x.
        amazonDate = []
        currentTime = datetime.utcnow()
        YMDHMS = currentTime.strftime('%Y%m%dT%H%M%SZ')
        YMD = YMDHMS[0:YMDHMS.index('T')]
        amazonDate.append(YMD)
        amazonDate.append(YMDHMS)
        return amazonDate

    def _sign(self, key, message):
        # One HMAC-SHA256 step of the SigV4 key derivation; returns raw bytes.
        return hmac.new(key, message.encode('utf-8'), hashlib.sha256).digest()

    def _getSignatureKey(self, key, dateStamp, regionName, serviceName):
        # Standard SigV4 signing-key derivation chain:
        # HMAC("AWS4"+secret, date) -> region -> service -> "aws4_request".
        kDate = self._sign(('AWS4' + key).encode('utf-8'), dateStamp)
        kRegion = self._sign(kDate, regionName)
        kService = self._sign(kRegion, serviceName)
        kSigning = self._sign(kService, 'aws4_request')
        return kSigning

    def _checkIAMCredentials(self):
        # Resolve credentials: custom config > environment > credential file.
        ret = self._checkKeyInCustomConfig()
        if not ret:
            ret = self._checkKeyInEnv()
        if not ret:
            ret = self._checkKeyInFiles()
        return ret

    def _checkKeyInEnv(self):
        # Read AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN
        # from the environment; the session token is optional.
        ret = dict()
        self._aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')
        self._aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
        self._aws_session_token = os.environ.get('AWS_SESSION_TOKEN')
        if self._aws_access_key_id is not None and self._aws_secret_access_key is not None:
            ret["aws_access_key_id"] = self._aws_access_key_id
            ret["aws_secret_access_key"] = self._aws_secret_access_key
            if self._aws_session_token is not None:
                ret["aws_session_token"] = self._aws_session_token
            self._logger.debug("IAM credentials from env var.")
        return ret

    def _checkKeyInINIDefault(self, srcConfigParser, sectionName):
        # Pull id/secret (required) and session token (optional) out of one
        # INI section of the credential file.
        ret = dict()
        try:
            ret["aws_access_key_id"] = srcConfigParser.get(sectionName, "aws_access_key_id")
            ret["aws_secret_access_key"] = srcConfigParser.get(sectionName, "aws_secret_access_key")
        except NoOptionError:
            # Logger.warn is deprecated; use warning()
            self._logger.warning("Cannot find IAM keyID/secretKey in credential file.")
        # Do not continue searching if we cannot even get IAM id/secret right
        if len(ret) == 2:
            try:
                ret["aws_session_token"] = srcConfigParser.get(sectionName, "aws_session_token")
            except NoOptionError:
                self._logger.debug("No AWS Session Token found.")
        return ret

    def _checkKeyInFiles(self):
        # Read the AWS-CLI-compatible shared credential file (*NIX/Windows).
        credentialFile = None
        credentialConfig = None
        ret = dict()
        try:
            credentialConfig = ConfigParser()
            credentialFilePath = os.path.expanduser(self._credentialConfigFilePath)
            # NOTE(review): ConfigParser.read silently skips missing files,
            # so the IOError handler below is rarely reached -- confirm.
            credentialConfig.read(credentialFilePath)
            # 'default' section first, then the uppercase variant
            ret = self._checkKeyInINIDefault(credentialConfig, "default")
            if not ret:
                ret = self._checkKeyInINIDefault(credentialConfig, "DEFAULT")
            self._logger.debug("IAM credentials from file.")
        except IOError:
            self._logger.debug("No IAM credential configuration file in " + credentialFilePath)
        except NoSectionError:
            self._logger.error("Cannot find IAM 'default' section.")
        return ret
+ # 'default' section + ret = self._checkKeyInINIDefault(credentialConfig, "default") + if not ret: + # 'DEFAULT' section + ret = self._checkKeyInINIDefault(credentialConfig, "DEFAULT") + self._logger.debug("IAM credentials from file.") + except IOError: + self._logger.debug("No IAM credential configuration file in " + credentialFilePath) + except NoSectionError: + self._logger.error("Cannot find IAM 'default' section.") + return ret + + def _checkKeyInCustomConfig(self): + ret = dict() + if self._aws_access_key_id != "" and self._aws_secret_access_key != "": + ret["aws_access_key_id"] = self._aws_access_key_id + ret["aws_secret_access_key"] = self._aws_secret_access_key + # We do not necessarily need session token... + if self._aws_session_token != "": + ret["aws_session_token"] = self._aws_session_token + self._logger.debug("IAM credentials from custom config.") + return ret + + def createWebsocketEndpoint(self, host, port, region, method, awsServiceName, path): + # Return the endpoint as unicode string in 3.x + # Gather all the facts + amazonDate = self._createAmazonDate() + amazonDateSimple = amazonDate[0] # Unicode in 3.x + amazonDateComplex = amazonDate[1] # Unicode in 3.x + allKeys = self._checkIAMCredentials() # Unicode in 3.x + if not self._hasCredentialsNecessaryForWebsocket(allKeys): + raise wssNoKeyInEnvironmentError() + else: + # Because of self._hasCredentialsNecessaryForWebsocket(...), keyID and secretKey should not be None from here + keyID = allKeys["aws_access_key_id"] + secretKey = allKeys["aws_secret_access_key"] + # amazonDateSimple and amazonDateComplex are guaranteed not to be None + queryParameters = "X-Amz-Algorithm=AWS4-HMAC-SHA256" + \ + "&X-Amz-Credential=" + keyID + "%2F" + amazonDateSimple + "%2F" + region + "%2F" + awsServiceName + "%2Faws4_request" + \ + "&X-Amz-Date=" + amazonDateComplex + \ + "&X-Amz-Expires=86400" + \ + "&X-Amz-SignedHeaders=host" # Unicode in 3.x + hashedPayload = 
hashlib.sha256(str("").encode('utf-8')).hexdigest() # Unicode in 3.x + # Create the string to sign + signedHeaders = "host" + canonicalHeaders = "host:" + host + "\n" + canonicalRequest = method + "\n" + path + "\n" + queryParameters + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedPayload # Unicode in 3.x + hashedCanonicalRequest = hashlib.sha256(str(canonicalRequest).encode('utf-8')).hexdigest() # Unicoede in 3.x + stringToSign = "AWS4-HMAC-SHA256\n" + amazonDateComplex + "\n" + amazonDateSimple + "/" + region + "/" + awsServiceName + "/aws4_request\n" + hashedCanonicalRequest # Unicode in 3.x + # Sign it + signingKey = self._getSignatureKey(secretKey, amazonDateSimple, region, awsServiceName) + signature = hmac.new(signingKey, (stringToSign).encode("utf-8"), hashlib.sha256).hexdigest() + # generate url + url = "wss://" + host + ":" + str(port) + path + '?' + queryParameters + "&X-Amz-Signature=" + signature + # See if we have STS token, if we do, add it + awsSessionTokenCandidate = allKeys.get("aws_session_token") + if awsSessionTokenCandidate is not None and len(awsSessionTokenCandidate) != 0: + aws_session_token = allKeys["aws_session_token"] + url += "&X-Amz-Security-Token=" + quote(aws_session_token.encode("utf-8")) # Unicode in 3.x + self._logger.debug("createWebsocketEndpoint: Websocket URL: " + url) + return url + + def _hasCredentialsNecessaryForWebsocket(self, allKeys): + awsAccessKeyIdCandidate = allKeys.get("aws_access_key_id") + awsSecretAccessKeyCandidate = allKeys.get("aws_secret_access_key") + # None value is NOT considered as valid entries + validEntries = awsAccessKeyIdCandidate is not None and awsAccessKeyIdCandidate is not None + if validEntries: + # Empty value is NOT considered as valid entries + validEntries &= (len(awsAccessKeyIdCandidate) != 0 and len(awsSecretAccessKeyCandidate) != 0) + return validEntries + # This is an internal class that buffers the incoming bytes into an # internal buffer until it gets the full desired 
length of bytes. @@ -43,7 +296,7 @@ # For other errors, leave them to the paho _packet_read for error reporting. -class _bufferedReader: +class _BufferedReader: _sslSocket = None _internalBuffer = None _remainedLength = -1 @@ -67,6 +320,10 @@ def read(self, numberOfBytesToBeBuffered): while self._remainedLength > 0: # Read in a loop, always try to read in the remained length # If the data is temporarily not available, socket.error will be raised and catched by paho dataChunk = self._sslSocket.read(self._remainedLength) + # There is a chance where the server terminates the connection without closing the socket. + # If that happens, let's raise an exception and enter the reconnect flow. + if not dataChunk: + raise socket.error(errno.ECONNABORTED, 0) self._internalBuffer.extend(dataChunk) # Buffer the data self._remainedLength -= len(dataChunk) # Update the remained length @@ -76,6 +333,7 @@ def read(self, numberOfBytesToBeBuffered): self._reset() return ret # This should always be bytearray + # This is the internal class that sends requested data out chunk by chunk according # to the availablity of the socket write operation. If the requested bytes of data # (after encoding) needs to be sent out in separate socket write operations (most @@ -89,7 +347,7 @@ def read(self, numberOfBytesToBeBuffered): # For other errors, leave them to the paho _packet_read for error reporting. 
-class _bufferedWriter: +class _BufferedWriter: _sslSocket = None _internalBuffer = None _writingInProgress = False @@ -109,7 +367,7 @@ def _reset(self): # Input data for this function needs to be an encoded wss frame # Always request for packet[pos=0:] (raw MQTT data) def write(self, encodedData, payloadLength): - # encodedData should always be bytearray + # encodedData should always be bytearray # Check if we have a frame that is partially sent if not self._writingInProgress: self._internalBuffer = encodedData @@ -128,7 +386,7 @@ def write(self, encodedData, payloadLength): return 0 # Ensure that the 'pos' inside the MQTT packet never moves since we have not finished the transmission of this encoded frame -class securedWebsocketCore: +class SecuredWebSocketCore: # Websocket Constants _OP_CONTINUATION = 0x0 _OP_TEXT = 0x1 @@ -140,6 +398,8 @@ class securedWebsocketCore: _WebsocketConnectInit = -1 _WebsocketDisconnected = 1 + _logger = logging.getLogger(__name__) + def __init__(self, socket, hostAddress, portNumber, AWSAccessKeyID="", AWSSecretAccessKey="", AWSSessionToken=""): self._connectStatus = self._WebsocketConnectInit # Handlers @@ -170,12 +430,14 @@ def __init__(self, socket, hostAddress, portNumber, AWSAccessKeyID="", AWSSecret raise ValueError("No Access Key/KeyID Error") except wssHandShakeError: raise ValueError("Websocket Handshake Error") + except ClientError as e: + raise ValueError(e.message) # Now we have a socket with secured websocket... 
- self._bufferedReader = _bufferedReader(self._sslSocket) - self._bufferedWriter = _bufferedWriter(self._sslSocket) + self._bufferedReader = _BufferedReader(self._sslSocket) + self._bufferedWriter = _BufferedWriter(self._sslSocket) def _createSigV4Core(self): - return sigV4Core() + return SigV4Core() def _generateMaskKey(self): return bytearray(os.urandom(4)) @@ -220,11 +482,12 @@ def _verifyWSSAcceptKey(self, srcAcceptKey, clientKey): def _handShake(self, hostAddress, portNumber): CRLF = "\r\n" - hostAddressChunks = hostAddress.split('.') # .iot..amazonaws.com - region = hostAddressChunks[2] # XXXX..beta + IOT_ENDPOINT_PATTERN = r"^[0-9a-zA-Z]+(\.ats|-ats)?\.iot\.(.*)\.amazonaws\..*" + matched = re.compile(IOT_ENDPOINT_PATTERN, re.IGNORECASE).match(hostAddress) + if not matched: + raise ClientError("Invalid endpoint pattern for wss: %s" % hostAddress) + region = matched.group(2) signedURL = self._sigV4Handler.createWebsocketEndpoint(hostAddress, portNumber, region, "GET", "iotdata", "/mqtt") - if signedURL == "": - raise wssNoKeyInEnvironmentError() # Now we got a signedURL path = signedURL[signedURL.index("/mqtt"):] # Assemble HTTP request headers @@ -243,14 +506,17 @@ def _handShake(self, hostAddress, portNumber): handshakeBytes = handshakeBytes.encode('utf-8') self._sslSocket.write(handshakeBytes) # Read it back (Non-blocking socket) - # Do we need a timeout here? 
+ timeStart = time.time() wssHandshakeResponse = bytearray() while len(wssHandshakeResponse) == 0: try: wssHandshakeResponse += self._sslSocket.read(1024) # Response is always less than 1024 bytes except socket.error as err: if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE: - pass + if time.time() - timeStart > self._getTimeoutSec(): + raise err # We make sure that reconnect gets retried in Paho upon a wss reconnect response timeout + else: + raise err # Verify response # Now both wssHandshakeResponse and rawSecWebSocketKey are byte strings if not self._verifyWSSResponse(wssHandshakeResponse, rawSecWebSocketKey): @@ -258,6 +524,9 @@ def _handShake(self, hostAddress, portNumber): else: pass + def _getTimeoutSec(self): + return DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC + # Used to create a single wss frame # Assume that the maximum length of a MQTT packet never exceeds the maximum length # for a wss frame. Therefore, the FIN bit for the encoded frame will always be 1. @@ -326,7 +595,7 @@ def read(self, numberOfBytes): # struct.unpack(fmt, buffer) # Py3.x # Here ret is always in bytes (buffer interface) if sys.version_info[0] < 3: # Py2.x - ret = str(ret) + ret = str(ret) return ret # Emmm, We don't. Try to buffer from the socket (It's a new wss frame). 
if not self._hasOpByte: # Check if we need to buffer OpByte @@ -362,9 +631,9 @@ def read(self, numberOfBytes): payloadLengthExtended = self._bufferedReader.read(self._payloadLengthBytesLength) self._hasPayloadLengthExtended = True if sys.version_info[0] < 3: - payloadLengthExtended = str(payloadLengthExtended) + payloadLengthExtended = str(payloadLengthExtended) if self._payloadLengthBytesLength == 2: - self._payloadLength = struct.unpack("!H", payloadLengthExtended)[0] + self._payloadLength = struct.unpack("!H", payloadLengthExtended)[0] else: # _payloadLengthBytesLength == 8 self._payloadLength = struct.unpack("!Q", payloadLengthExtended)[0] @@ -401,7 +670,7 @@ def read(self, numberOfBytes): # struct.unpack(fmt, buffer) # Py3.x # Here ret is always in bytes (buffer interface) if sys.version_info[0] < 3: # Py2.x - ret = str(ret) + ret = str(ret) return ret else: # Fragmented MQTT packets in separate wss frames raise socket.error(ssl.SSL_ERROR_WANT_READ, "Not a complete MQTT packet payload within this wss frame.") @@ -420,6 +689,9 @@ def close(self): self._sslSocket.close() self._sslSocket = None + def getpeercert(self): + return self._sslSocket.getpeercert() + def getSSLSocket(self): if self._connectStatus != self._WebsocketDisconnected: return self._sslSocket diff --git a/AWSIoTPythonSDK/core/protocol/internal/__init__.py b/AWSIoTPythonSDK/core/protocol/internal/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/AWSIoTPythonSDK/core/protocol/internal/clients.py b/AWSIoTPythonSDK/core/protocol/internal/clients.py new file mode 100644 index 0000000..90f48b7 --- /dev/null +++ b/AWSIoTPythonSDK/core/protocol/internal/clients.py @@ -0,0 +1,247 @@ +# /* +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * +# * Licensed under the Apache License, Version 2.0 (the "License"). +# * You may not use this file except in compliance with the License. 
+# * A copy of the License is located at +# * +# * http://aws.amazon.com/apache2.0 +# * +# * or in the "license" file accompanying this file. This file is distributed +# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# * express or implied. See the License for the specific language governing +# * permissions and limitations under the License. +# */ + +import ssl +import logging +from threading import Lock +from numbers import Number +import AWSIoTPythonSDK.core.protocol.paho.client as mqtt +from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS +from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids + + +class ClientStatus(object): + + IDLE = 0 + CONNECT = 1 + RESUBSCRIBE = 2 + DRAINING = 3 + STABLE = 4 + USER_DISCONNECT = 5 + ABNORMAL_DISCONNECT = 6 + + +class ClientStatusContainer(object): + + def __init__(self): + self._status = ClientStatus.IDLE + + def get_status(self): + return self._status + + def set_status(self, status): + if ClientStatus.USER_DISCONNECT == self._status: # If user requests to disconnect, no status updates other than user connect + if ClientStatus.CONNECT == status: + self._status = status + else: + self._status = status + + +class InternalAsyncMqttClient(object): + + _logger = logging.getLogger(__name__) + + def __init__(self, client_id, clean_session, protocol, use_wss): + self._paho_client = self._create_paho_client(client_id, clean_session, None, protocol, use_wss) + self._use_wss = use_wss + self._event_callback_map_lock = Lock() + self._event_callback_map = dict() + + def _create_paho_client(self, client_id, clean_session, user_data, protocol, use_wss): + self._logger.debug("Initializing MQTT layer...") + return mqtt.Client(client_id, clean_session, user_data, protocol, use_wss) + + # TODO: Merge credentials providers configuration into one + def set_cert_credentials_provider(self, cert_credentials_provider, ciphers_provider): + # History issue from Yun SDK where AR9331 
embedded Linux only have Python 2.7.3 + # pre-installed. In this version, TLSv1_2 is not even an option. + # SSLv23 is a work-around which selects the highest TLS version between the client + # and service. If user installs opensslv1.0.1+, this option will work fine for Mutual + # Auth. + # Note that we cannot force TLSv1.2 for Mutual Auth. in Python 2.7.3 and TLS support + # in Python only starts from Python2.7. + # See also: https://docs.python.org/2/library/ssl.html#ssl.PROTOCOL_SSLv23 + if self._use_wss: + ca_path = cert_credentials_provider.get_ca_path() + ciphers = ciphers_provider.get_ciphers() + self._paho_client.tls_set(ca_certs=ca_path, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23, + ciphers=ciphers) + else: + ca_path = cert_credentials_provider.get_ca_path() + cert_path = cert_credentials_provider.get_cert_path() + key_path = cert_credentials_provider.get_key_path() + ciphers = ciphers_provider.get_ciphers() + self._paho_client.tls_set(ca_certs=ca_path,certfile=cert_path, keyfile=key_path, + cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23, ciphers=ciphers) + + def set_iam_credentials_provider(self, iam_credentials_provider): + self._paho_client.configIAMCredentials(iam_credentials_provider.get_access_key_id(), + iam_credentials_provider.get_secret_access_key(), + iam_credentials_provider.get_session_token()) + + def set_endpoint_provider(self, endpoint_provider): + self._endpoint_provider = endpoint_provider + + def configure_last_will(self, topic, payload, qos, retain=False): + self._paho_client.will_set(topic, payload, qos, retain) + + def configure_alpn_protocols(self, alpn_protocols): + self._paho_client.config_alpn_protocols(alpn_protocols) + + def clear_last_will(self): + self._paho_client.will_clear() + + def set_username_password(self, username, password=None): + self._paho_client.username_pw_set(username, password) + + def set_socket_factory(self, socket_factory): + 
self._paho_client.socket_factory_set(socket_factory) + + def configure_reconnect_back_off(self, base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec): + self._paho_client.setBackoffTiming(base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec) + + def connect(self, keep_alive_sec, ack_callback=None): + host = self._endpoint_provider.get_host() + port = self._endpoint_provider.get_port() + + with self._event_callback_map_lock: + self._logger.debug("Filling in fixed event callbacks: CONNACK, DISCONNECT, MESSAGE") + self._event_callback_map[FixedEventMids.CONNACK_MID] = self._create_combined_on_connect_callback(ack_callback) + self._event_callback_map[FixedEventMids.DISCONNECT_MID] = self._create_combined_on_disconnect_callback(None) + self._event_callback_map[FixedEventMids.MESSAGE_MID] = self._create_converted_on_message_callback() + + rc = self._paho_client.connect(host, port, keep_alive_sec) + if MQTT_ERR_SUCCESS == rc: + self.start_background_network_io() + + return rc + + def start_background_network_io(self): + self._logger.debug("Starting network I/O thread...") + self._paho_client.loop_start() + + def stop_background_network_io(self): + self._logger.debug("Stopping network I/O thread...") + self._paho_client.loop_stop() + + def disconnect(self, ack_callback=None): + with self._event_callback_map_lock: + rc = self._paho_client.disconnect() + if MQTT_ERR_SUCCESS == rc: + self._logger.debug("Filling in custom disconnect event callback...") + combined_on_disconnect_callback = self._create_combined_on_disconnect_callback(ack_callback) + self._event_callback_map[FixedEventMids.DISCONNECT_MID] = combined_on_disconnect_callback + return rc + + def _create_combined_on_connect_callback(self, ack_callback): + def combined_on_connect_callback(mid, data): + self.on_online() + if ack_callback: + ack_callback(mid, data) + return combined_on_connect_callback + + def _create_combined_on_disconnect_callback(self, ack_callback): + def 
combined_on_disconnect_callback(mid, data): + self.on_offline() + if ack_callback: + ack_callback(mid, data) + return combined_on_disconnect_callback + + def _create_converted_on_message_callback(self): + def converted_on_message_callback(mid, data): + self.on_message(data) + return converted_on_message_callback + + # For client online notification + def on_online(self): + pass + + # For client offline notification + def on_offline(self): + pass + + # For client message reception notification + def on_message(self, message): + pass + + def publish(self, topic, payload, qos, retain=False, ack_callback=None): + with self._event_callback_map_lock: + rc, mid = self._paho_client.publish(topic, payload, qos, retain) + if MQTT_ERR_SUCCESS == rc and qos > 0 and ack_callback: + self._logger.debug("Filling in custom puback (QoS>0) event callback...") + self._event_callback_map[mid] = ack_callback + return rc, mid + + def subscribe(self, topic, qos, ack_callback=None): + with self._event_callback_map_lock: + rc, mid = self._paho_client.subscribe(topic, qos) + if MQTT_ERR_SUCCESS == rc and ack_callback: + self._logger.debug("Filling in custom suback event callback...") + self._event_callback_map[mid] = ack_callback + return rc, mid + + def unsubscribe(self, topic, ack_callback=None): + with self._event_callback_map_lock: + rc, mid = self._paho_client.unsubscribe(topic) + if MQTT_ERR_SUCCESS == rc and ack_callback: + self._logger.debug("Filling in custom unsuback event callback...") + self._event_callback_map[mid] = ack_callback + return rc, mid + + def register_internal_event_callbacks(self, on_connect, on_disconnect, on_publish, on_subscribe, on_unsubscribe, on_message): + self._logger.debug("Registering internal event callbacks to MQTT layer...") + self._paho_client.on_connect = on_connect + self._paho_client.on_disconnect = on_disconnect + self._paho_client.on_publish = on_publish + self._paho_client.on_subscribe = on_subscribe + self._paho_client.on_unsubscribe = 
on_unsubscribe + self._paho_client.on_message = on_message + + def unregister_internal_event_callbacks(self): + self._logger.debug("Unregistering internal event callbacks from MQTT layer...") + self._paho_client.on_connect = None + self._paho_client.on_disconnect = None + self._paho_client.on_publish = None + self._paho_client.on_subscribe = None + self._paho_client.on_unsubscribe = None + self._paho_client.on_message = None + + def invoke_event_callback(self, mid, data=None): + with self._event_callback_map_lock: + event_callback = self._event_callback_map.get(mid) + # For invoking the event callback, we do not need to acquire the lock + if event_callback: + self._logger.debug("Invoking custom event callback...") + if data is not None: + event_callback(mid=mid, data=data) + else: + event_callback(mid=mid) + if isinstance(mid, Number): # Do NOT remove callbacks for CONNACK/DISCONNECT/MESSAGE + self._logger.debug("This custom event callback is for pub/sub/unsub, removing it after invocation...") + with self._event_callback_map_lock: + del self._event_callback_map[mid] + + def remove_event_callback(self, mid): + with self._event_callback_map_lock: + if mid in self._event_callback_map: + self._logger.debug("Removing custom event callback...") + del self._event_callback_map[mid] + + def clean_up_event_callbacks(self): + with self._event_callback_map_lock: + self._event_callback_map.clear() + + def get_event_callback_map(self): + return self._event_callback_map diff --git a/AWSIoTPythonSDK/core/protocol/internal/defaults.py b/AWSIoTPythonSDK/core/protocol/internal/defaults.py new file mode 100644 index 0000000..66817d3 --- /dev/null +++ b/AWSIoTPythonSDK/core/protocol/internal/defaults.py @@ -0,0 +1,20 @@ +# /* +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * +# * Licensed under the Apache License, Version 2.0 (the "License"). +# * You may not use this file except in compliance with the License. 
+# * A copy of the License is located at +# * +# * http://aws.amazon.com/apache2.0 +# * +# * or in the "license" file accompanying this file. This file is distributed +# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# * express or implied. See the License for the specific language governing +# * permissions and limitations under the License. +# */ + +DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC = 30 +DEFAULT_OPERATION_TIMEOUT_SEC = 5 +DEFAULT_DRAINING_INTERNAL_SEC = 0.5 +METRICS_PREFIX = "?SDK=Python&Version=" +ALPN_PROTCOLS = "x-amzn-mqtt-ca" \ No newline at end of file diff --git a/AWSIoTPythonSDK/core/protocol/internal/events.py b/AWSIoTPythonSDK/core/protocol/internal/events.py new file mode 100644 index 0000000..90f0b70 --- /dev/null +++ b/AWSIoTPythonSDK/core/protocol/internal/events.py @@ -0,0 +1,29 @@ +# /* +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * +# * Licensed under the Apache License, Version 2.0 (the "License"). +# * You may not use this file except in compliance with the License. +# * A copy of the License is located at +# * +# * http://aws.amazon.com/apache2.0 +# * +# * or in the "license" file accompanying this file. This file is distributed +# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# * express or implied. See the License for the specific language governing +# * permissions and limitations under the License. 
+# */ + +class EventTypes(object): + CONNACK = 0 + DISCONNECT = 1 + PUBACK = 2 + SUBACK = 3 + UNSUBACK = 4 + MESSAGE = 5 + + +class FixedEventMids(object): + CONNACK_MID = "CONNECTED" + DISCONNECT_MID = "DISCONNECTED" + MESSAGE_MID = "MESSAGE" + QUEUED_MID = "QUEUED" diff --git a/AWSIoTPythonSDK/core/protocol/internal/queues.py b/AWSIoTPythonSDK/core/protocol/internal/queues.py new file mode 100644 index 0000000..77046a8 --- /dev/null +++ b/AWSIoTPythonSDK/core/protocol/internal/queues.py @@ -0,0 +1,87 @@ +# /* +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * +# * Licensed under the Apache License, Version 2.0 (the "License"). +# * You may not use this file except in compliance with the License. +# * A copy of the License is located at +# * +# * http://aws.amazon.com/apache2.0 +# * +# * or in the "license" file accompanying this file. This file is distributed +# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# * express or implied. See the License for the specific language governing +# * permissions and limitations under the License. 
+# */ + +import logging +from AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes + + +class AppendResults(object): + APPEND_FAILURE_QUEUE_FULL = -1 + APPEND_FAILURE_QUEUE_DISABLED = -2 + APPEND_SUCCESS = 0 + + +class OfflineRequestQueue(list): + _logger = logging.getLogger(__name__) + + def __init__(self, max_size, drop_behavior=DropBehaviorTypes.DROP_NEWEST): + if not isinstance(max_size, int) or not isinstance(drop_behavior, int): + self._logger.error("init: MaximumSize/DropBehavior must be integer.") + raise TypeError("MaximumSize/DropBehavior must be integer.") + if drop_behavior != DropBehaviorTypes.DROP_OLDEST and drop_behavior != DropBehaviorTypes.DROP_NEWEST: + self._logger.error("init: Drop behavior not supported.") + raise ValueError("Drop behavior not supported.") + + list.__init__([]) + self._drop_behavior = drop_behavior + # When self._maximumSize > 0, queue is limited + # When self._maximumSize == 0, queue is disabled + # When self._maximumSize < 0. queue is infinite + self._max_size = max_size + + def _is_enabled(self): + return self._max_size != 0 + + def _need_drop_messages(self): + # Need to drop messages when: + # 1. Queue is limited and full + # 2. Queue is disabled + is_queue_full = len(self) >= self._max_size + is_queue_limited = self._max_size > 0 + is_queue_disabled = not self._is_enabled() + return (is_queue_full and is_queue_limited) or is_queue_disabled + + def set_behavior_drop_newest(self): + self._drop_behavior = DropBehaviorTypes.DROP_NEWEST + + def set_behavior_drop_oldest(self): + self._drop_behavior = DropBehaviorTypes.DROP_OLDEST + + # Override + # Append to a queue with a limited size. 
+ # Return APPEND_SUCCESS if the append is successful + # Return APPEND_FAILURE_QUEUE_FULL if the append failed because the queue is full + # Return APPEND_FAILURE_QUEUE_DISABLED if the append failed because the queue is disabled + def append(self, data): + ret = AppendResults.APPEND_SUCCESS + if self._is_enabled(): + if self._need_drop_messages(): + # We should drop the newest + if DropBehaviorTypes.DROP_NEWEST == self._drop_behavior: + self._logger.warn("append: Full queue. Drop the newest: " + str(data)) + ret = AppendResults.APPEND_FAILURE_QUEUE_FULL + # We should drop the oldest + else: + current_oldest = super(OfflineRequestQueue, self).pop(0) + self._logger.warn("append: Full queue. Drop the oldest: " + str(current_oldest)) + super(OfflineRequestQueue, self).append(data) + ret = AppendResults.APPEND_FAILURE_QUEUE_FULL + else: + self._logger.debug("append: Add new element: " + str(data)) + super(OfflineRequestQueue, self).append(data) + else: + self._logger.debug("append: Queue is disabled. Drop the message: " + str(data)) + ret = AppendResults.APPEND_FAILURE_QUEUE_DISABLED + return ret diff --git a/AWSIoTPythonSDK/core/protocol/internal/requests.py b/AWSIoTPythonSDK/core/protocol/internal/requests.py new file mode 100644 index 0000000..bd2585d --- /dev/null +++ b/AWSIoTPythonSDK/core/protocol/internal/requests.py @@ -0,0 +1,27 @@ +# /* +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * +# * Licensed under the Apache License, Version 2.0 (the "License"). +# * You may not use this file except in compliance with the License. +# * A copy of the License is located at +# * +# * http://aws.amazon.com/apache2.0 +# * +# * or in the "license" file accompanying this file. This file is distributed +# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# * express or implied. See the License for the specific language governing +# * permissions and limitations under the License. 
+# */ + +class RequestTypes(object): + CONNECT = 0 + DISCONNECT = 1 + PUBLISH = 2 + SUBSCRIBE = 3 + UNSUBSCRIBE = 4 + +class QueueableRequest(object): + + def __init__(self, type, data): + self.type = type + self.data = data # Can be a tuple diff --git a/AWSIoTPythonSDK/core/protocol/internal/workers.py b/AWSIoTPythonSDK/core/protocol/internal/workers.py new file mode 100644 index 0000000..e52db3f --- /dev/null +++ b/AWSIoTPythonSDK/core/protocol/internal/workers.py @@ -0,0 +1,296 @@ +# /* +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * +# * Licensed under the Apache License, Version 2.0 (the "License"). +# * You may not use this file except in compliance with the License. +# * A copy of the License is located at +# * +# * http://aws.amazon.com/apache2.0 +# * +# * or in the "license" file accompanying this file. This file is distributed +# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# * express or implied. See the License for the specific language governing +# * permissions and limitations under the License. 
+# */ + +import time +import logging +from threading import Thread +from threading import Event +from AWSIoTPythonSDK.core.protocol.internal.events import EventTypes +from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus +from AWSIoTPythonSDK.core.protocol.internal.queues import OfflineRequestQueue +from AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes +from AWSIoTPythonSDK.core.protocol.paho.client import topic_matches_sub +from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_DRAINING_INTERNAL_SEC + + +class EventProducer(object): + + _logger = logging.getLogger(__name__) + + def __init__(self, cv, event_queue): + self._cv = cv + self._event_queue = event_queue + + def on_connect(self, client, user_data, flags, rc): + self._add_to_queue(FixedEventMids.CONNACK_MID, EventTypes.CONNACK, rc) + self._logger.debug("Produced [connack] event") + + def on_disconnect(self, client, user_data, rc): + self._add_to_queue(FixedEventMids.DISCONNECT_MID, EventTypes.DISCONNECT, rc) + self._logger.debug("Produced [disconnect] event") + + def on_publish(self, client, user_data, mid): + self._add_to_queue(mid, EventTypes.PUBACK, None) + self._logger.debug("Produced [puback] event") + + def on_subscribe(self, client, user_data, mid, granted_qos): + self._add_to_queue(mid, EventTypes.SUBACK, granted_qos) + self._logger.debug("Produced [suback] event") + + def on_unsubscribe(self, client, user_data, mid): + self._add_to_queue(mid, EventTypes.UNSUBACK, None) + self._logger.debug("Produced [unsuback] event") + + def on_message(self, client, user_data, message): + self._add_to_queue(FixedEventMids.MESSAGE_MID, EventTypes.MESSAGE, message) + self._logger.debug("Produced [message] event") + + def _add_to_queue(self, mid, event_type, data): + with self._cv: + self._event_queue.put((mid, event_type, data)) + self._cv.notify() + + +class EventConsumer(object): + + 
class EventConsumer(object):
    """Drains the shared event queue on a dedicated daemon thread.

    Events produced by :class:`EventProducer` are dispatched to per-type
    handlers and then to any user callback registered with the internal async
    client. On a successful CONNACK the consumer also replays recorded
    subscriptions and drains requests queued while offline, on a separate
    recovery thread.
    """

    # Max time (seconds) the dispatch loop waits on the condition variable
    # when the event queue is empty before re-checking the running flag.
    MAX_DISPATCH_INTERNAL_SEC = 0.01
    _logger = logging.getLogger(__name__)

    def __init__(self, cv, event_queue, internal_async_client,
                 subscription_manager, offline_requests_manager, client_status):
        """Wire the consumer to the shared queue/cv and its collaborators.

        cv/event_queue are shared with the EventProducer; the managers and
        client_status are shared with MqttCore.
        """
        self._cv = cv
        self._event_queue = event_queue
        self._internal_async_client = internal_async_client
        self._subscription_manager = subscription_manager
        self._offline_requests_manager = offline_requests_manager
        self._client_status = client_status
        self._is_running = False
        self._draining_interval_sec = DEFAULT_DRAINING_INTERNAL_SEC
        # Event-type -> handler dispatch table, consulted in _dispatch_one.
        self._dispatch_methods = {
            EventTypes.CONNACK : self._dispatch_connack,
            EventTypes.DISCONNECT : self._dispatch_disconnect,
            EventTypes.PUBACK : self._dispatch_puback,
            EventTypes.SUBACK : self._dispatch_suback,
            EventTypes.UNSUBACK : self._dispatch_unsuback,
            EventTypes.MESSAGE : self._dispatch_message
        }
        # Offline request-type -> replay handler, used while draining.
        self._offline_request_handlers = {
            RequestTypes.PUBLISH : self._handle_offline_publish,
            RequestTypes.SUBSCRIBE : self._handle_offline_subscribe,
            RequestTypes.UNSUBSCRIBE : self._handle_offline_unsubscribe
        }
        # Set by the dispatch thread when its loop has fully exited; lets
        # wait_until_it_stops()/is_fully_stopped() observe a clean shutdown.
        self._stopper = Event()

    def update_offline_requests_manager(self, offline_requests_manager):
        """Swap in a different offline requests manager (e.g. after reconfiguration)."""
        self._offline_requests_manager = offline_requests_manager

    def update_draining_interval_sec(self, draining_interval_sec):
        """Set the pause (seconds) between replayed offline requests."""
        self._draining_interval_sec = draining_interval_sec

    def get_draining_interval_sec(self):
        """Return the current draining interval in seconds."""
        return self._draining_interval_sec

    def is_running(self):
        """Return True while the dispatch loop is expected to keep running."""
        return self._is_running

    def start(self):
        """Start the dispatch loop on a new daemon thread and clear the stop flag."""
        self._stopper.clear()
        self._is_running = True
        dispatch_events = Thread(target=self._dispatch)
        dispatch_events.daemon = True
        dispatch_events.start()
        self._logger.debug("Event consuming thread started")

    def stop(self):
        """Signal the dispatch loop to exit and release client-side resources.

        Idempotent: a second call is a no-op because _is_running is already False.
        """
        if self._is_running:
            self._is_running = False
            self._clean_up()
            self._logger.debug("Event consuming thread stopped")

    def _clean_up(self):
        """Discard pending events and tear down network I/O and event callbacks."""
        self._logger.debug("Cleaning up before stopping event consuming")
        # Reaches into Queue internals (mutex + underlying deque) to clear
        # everything atomically; queue.Queue has no public clear().
        with self._event_queue.mutex:
            self._event_queue.queue.clear()
        self._logger.debug("Event queue cleared")
        self._internal_async_client.stop_background_network_io()
        self._logger.debug("Network thread stopped")
        self._internal_async_client.clean_up_event_callbacks()
        self._logger.debug("Event callbacks cleared")

    def wait_until_it_stops(self, timeout_sec):
        """Block until the dispatch thread has fully exited or timeout_sec elapses.

        Returns the Event.wait() result (True if the loop exited in time).
        """
        self._logger.debug("Waiting for event consumer to completely stop")
        return self._stopper.wait(timeout=timeout_sec)

    def is_fully_stopped(self):
        """Return True once the dispatch loop has actually exited (not just been asked to)."""
        return self._stopper.is_set()

    def _dispatch(self):
        """Dispatch-loop body run on the consumer thread.

        Waits briefly on the condition variable when idle, otherwise drains all
        queued events; sets the stopper Event on exit.
        """
        while self._is_running:
            with self._cv:
                if self._event_queue.empty():
                    self._cv.wait(self.MAX_DISPATCH_INTERNAL_SEC)
                else:
                    while not self._event_queue.empty():
                        self._dispatch_one()
        self._stopper.set()
        self._logger.debug("Exiting dispatching loop...")

    def _dispatch_one(self):
        """Pop one event, run its type handler, then the user's event callback.

        Events with a falsy mid are silently discarded.
        """
        mid, event_type, data = self._event_queue.get()
        if mid:
            self._dispatch_methods[event_type](mid, data)
            self._internal_async_client.invoke_event_callback(mid, data=data)
            # We need to make sure disconnect event gets dispatched and then we stop the consumer
            if self._need_to_stop_dispatching(mid):
                self.stop()

    def _need_to_stop_dispatching(self, mid):
        """True when this is a disconnect event during a user disconnect or a connect attempt."""
        status = self._client_status.get_status()
        return (ClientStatus.USER_DISCONNECT == status or ClientStatus.CONNECT == status) \
            and mid == FixedEventMids.DISCONNECT_MID

    def _dispatch_connack(self, mid, rc):
        """Handle CONNACK: either mark the client stable or kick off recovery."""
        status = self._client_status.get_status()
        self._logger.debug("Dispatching [connack] event")
        if self._need_recover():
            if ClientStatus.STABLE != status:  # To avoid multiple connack dispatching
                self._logger.debug("Has recovery job")
                # Resubscribe/draining runs on its own thread so event
                # dispatching is not blocked while it proceeds.
                clean_up_debt = Thread(target=self._clean_up_debt)
                clean_up_debt.start()
        else:
            self._logger.debug("No need for recovery")
            self._client_status.set_status(ClientStatus.STABLE)

    def _need_recover(self):
        """True when there are recorded subscriptions or queued offline requests to replay."""
        return self._subscription_manager.list_records() or self._offline_requests_manager.has_more()

    def _clean_up_debt(self):
        """Recovery job: resubscribe, drain offline requests, then mark the client stable."""
        self._handle_resubscribe()
        self._handle_draining()
        self._client_status.set_status(ClientStatus.STABLE)

    def _handle_resubscribe(self):
        """Re-send SUBSCRIBE for every recorded topic, aborting on user disconnect."""
        subscriptions = self._subscription_manager.list_records()
        if subscriptions and not self._has_user_disconnect_request():
            self._logger.debug("Start resubscribing")
            self._client_status.set_status(ClientStatus.RESUBSCRIBE)
            for topic, (qos, message_callback, ack_callback) in subscriptions:
                # Check before every request so a user disconnect interrupts promptly.
                if self._has_user_disconnect_request():
                    self._logger.debug("User disconnect detected")
                    break
                self._internal_async_client.subscribe(topic, qos, ack_callback)

    def _handle_draining(self):
        """Replay queued offline requests one by one, pausing _draining_interval_sec between them."""
        if self._offline_requests_manager.has_more() and not self._has_user_disconnect_request():
            self._logger.debug("Start draining")
            self._client_status.set_status(ClientStatus.DRAINING)
            while self._offline_requests_manager.has_more():
                if self._has_user_disconnect_request():
                    self._logger.debug("User disconnect detected")
                    break
                offline_request = self._offline_requests_manager.get_next()
                if offline_request:
                    self._offline_request_handlers[offline_request.type](offline_request)
                    time.sleep(self._draining_interval_sec)

    def _has_user_disconnect_request(self):
        """True when the user has asked to disconnect (checked during recovery loops)."""
        return ClientStatus.USER_DISCONNECT == self._client_status.get_status()

    def _dispatch_disconnect(self, mid, rc):
        """Handle DISCONNECT: mark abnormal unless it was user-initiated or mid-connect."""
        self._logger.debug("Dispatching [disconnect] event")
        status = self._client_status.get_status()
        if ClientStatus.USER_DISCONNECT == status or ClientStatus.CONNECT == status:
            pass
        else:
            self._client_status.set_status(ClientStatus.ABNORMAL_DISCONNECT)

    # For puback, suback and unsuback, ack callback invocation is handled in dispatch_one
    # Do nothing in the event dispatching itself
    def _dispatch_puback(self, mid, rc):
        self._logger.debug("Dispatching [puback] event")

    def _dispatch_suback(self, mid, rc):
        self._logger.debug("Dispatching [suback] event")

    def _dispatch_unsuback(self, mid, rc):
        self._logger.debug("Dispatching [unsuback] event")

    def _dispatch_message(self, mid, message):
        """Deliver an inbound message to every matching recorded message callback."""
        self._logger.debug("Dispatching [message] event")
        subscriptions = self._subscription_manager.list_records()
        if subscriptions:
            for topic, (qos, message_callback, _) in subscriptions:
                if topic_matches_sub(topic, message.topic) and message_callback:
                    message_callback(None, None, message)  # message_callback(client, userdata, message)

    def _handle_offline_publish(self, request):
        """Replay a queued publish; request.data is (topic, payload, qos, retain)."""
        topic, payload, qos, retain = request.data
        self._internal_async_client.publish(topic, payload, qos, retain)
        self._logger.debug("Processed offline publish request")

    def _handle_offline_subscribe(self, request):
        """Replay a queued subscribe and record it; request.data is (topic, qos, message_cb, ack_cb)."""
        topic, qos, message_callback, ack_callback = request.data
        self._subscription_manager.add_record(topic, qos, message_callback, ack_callback)
        self._internal_async_client.subscribe(topic, qos, ack_callback)
        self._logger.debug("Processed offline subscribe request")

    def _handle_offline_unsubscribe(self, request):
        """Replay a queued unsubscribe and drop its record; request.data is (topic, ack_cb)."""
        topic, ack_callback = request.data
        self._subscription_manager.remove_record(topic)
        self._internal_async_client.unsubscribe(topic, ack_callback)
        self._logger.debug("Processed offline unsubscribe request")
class SubscriptionManager(object):
    """Book-keeps active subscriptions so they can be replayed after a reconnect.

    Records map topic -> (qos, message_callback, ack_callback); either
    callback may be None.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self):
        self._subscription_map = dict()

    def add_record(self, topic, qos, message_callback, ack_callback):
        """Insert (or overwrite) the subscription record for *topic*."""
        self._logger.debug("Adding a new subscription record: %s qos: %d", topic, qos)
        self._subscription_map[topic] = qos, message_callback, ack_callback  # message_callback and/or ack_callback could be None

    def remove_record(self, topic):
        """Drop the record for *topic*; warn (don't raise) if it was never added."""
        self._logger.debug("Removing subscription record: %s", topic)
        # Membership test instead of truthiness of .get(): one lookup instead of
        # two, and correct even if a stored value were ever falsy.
        if topic in self._subscription_map:  # Ignore topics that are never subscribed to
            del self._subscription_map[topic]
        else:
            # Logger.warn is a deprecated alias; use warning().
            self._logger.warning("Removal attempt for non-existent subscription record: %s", topic)

    def list_records(self):
        """Return all records as a list of (topic, (qos, message_cb, ack_cb)) tuples."""
        return list(self._subscription_map.items())


class OfflineRequestsManager(object):
    """Thin wrapper over OfflineRequestQueue for requests made while offline."""

    _logger = logging.getLogger(__name__)

    def __init__(self, max_size, drop_behavior):
        # max_size < 0 means unbounded (MqttCore passes -1 for an infinite queue);
        # drop_behavior selects which end is dropped when the queue is full.
        self._queue = OfflineRequestQueue(max_size, drop_behavior)

    def has_more(self):
        """Return True when at least one queued request is pending."""
        return len(self._queue) > 0

    def add_one(self, request):
        """Append a request; returns the queue's append result code."""
        return self._queue.append(request)

    def get_next(self):
        """Pop and return the oldest queued request, or None when the queue is empty."""
        if self.has_more():
            return self._queue.pop(0)
        else:
            return None
-# */ - -import sys -import ssl -import time -import logging -import threading -import AWSIoTPythonSDK.core.protocol.paho.client as mqtt -import AWSIoTPythonSDK.core.util.offlinePublishQueue as offlinePublishQueue -from threading import Lock -from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectError -from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectTimeoutException -from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectError -from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectTimeoutException -from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError -from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueFullException -from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueDisabledException -from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError -from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException -from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeError -from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeTimeoutException - -# Class that holds queued publish request details -class _publishRequest: - def __init__(self, srcTopic, srcPayload, srcQos, srcRetain): - self.topic = srcTopic - self.payload = srcPayload - self.qos = srcQos - self.retain = srcRetain - - -class mqttCore: - - def getClientID(self): - return self._clientID - - def setConnectDisconnectTimeoutSecond(self, srcConnectDisconnectTimeout): - self._connectdisconnectTimeout = srcConnectDisconnectTimeout - self._log.debug("Set maximum connect/disconnect timeout to be " + str(self._connectdisconnectTimeout) + " second.") - - def getConnectDisconnectTimeoutSecond(self): - return self._connectdisconnectTimeout - - def setMQTTOperationTimeoutSecond(self, srcMQTTOperationTimeout): - self._mqttOperationTimeout = srcMQTTOperationTimeout - self._log.debug("Set maximum MQTT operation timeout to be " + str(self._mqttOperationTimeout) + " second") - - def 
getMQTTOperationTimeoutSecond(self): - return self._mqttOperationTimeout - - def setUserData(self, srcUserData): - self._pahoClient.user_data_set(srcUserData) - - def createPahoClient(self, clientID, cleanSession, userdata, protocol, useWebsocket): - return mqtt.Client(clientID, cleanSession, userdata, protocol, useWebsocket) # Throw exception when error happens - - def _doResubscribe(self): - if self._subscribePool: - self._resubscribeCount = len(self._subscribePool) # This is the only place where _resubscribeCount gets its count - for key in self._subscribePool.keys(): - qos, callback = self._subscribePool.get(key) - try: - self.subscribe(key, qos, callback) - time.sleep(self._drainingIntervalSecond) # Subscribe requests should also be sent out using the draining interval - except (subscribeError, subscribeTimeoutException): - self._log.warn("Error in re-subscription to topic: " + str(key)) - pass # Subscribe error resulted from network error, will redo subscription in the next re-connect - - # Performed in a seperate thread, draining the offlinePublishQueue at a given draining rate - # Publish theses queued messages to Paho - # Should always pop the queue since Paho has its own queueing and retry logic - # Should exit immediately when there is an error in republishing queued message - # Should leave it to the next round of reconnect/resubscribe/republish logic at mqttCore - def _doPublishDraining(self): - while True: - self._offlinePublishQueueLock.acquire() - # This should be a complete publish requests containing topic, payload, qos, retain information - # This is the only thread that pops the offlinePublishQueue - if self._offlinePublishQueue: - queuedPublishRequest = self._offlinePublishQueue.pop(0) - # Publish it (call paho API directly) - (rc, mid) = self._pahoClient.publish(queuedPublishRequest.topic, queuedPublishRequest.payload, queuedPublishRequest.qos, queuedPublishRequest.retain) - if rc != 0: - self._offlinePublishQueueLock.release() - break - else: 
- self._drainingComplete = True - self._offlinePublishQueueLock.release() - break - self._offlinePublishQueueLock.release() - time.sleep(self._drainingIntervalSecond) - - # Callbacks - def on_connect(self, client, userdata, flags, rc): - self._disconnectResultCode = sys.maxsize - self._connectResultCode = rc - if self._connectResultCode == 0: # If this is a successful connect, do resubscribe - processResubscription = threading.Thread(target=self._doResubscribe) - processResubscription.start() - # If we do not have any topics to resubscribe to, still start a new thread to process queued publish requests - if not self._subscribePool: - offlinePublishQueueDraining = threading.Thread(target=self._doPublishDraining) - offlinePublishQueueDraining.start() - self._log.debug("Connect result code " + str(rc)) - - def on_disconnect(self, client, userdata, rc): - self._connectResultCode = sys.maxsize - self._disconnectResultCode = rc - self._drainingComplete = False # Draining status should be reset when disconnect happens - self._log.debug("Disconnect result code " + str(rc)) - - def on_subscribe(self, client, userdata, mid, granted_qos): - # Execution of this callback is atomic, guaranteed by Paho - # Check if we have got all SUBACKs for all resubscriptions - self._log.debug("_resubscribeCount: " + str(self._resubscribeCount)) - if self._resubscribeCount > 0: # Check if there is actually a need for resubscribe - self._resubscribeCount -= 1 # collect SUBACK for all resubscriptions - if self._resubscribeCount == 0: - # start a thread draining the offline publish queue - offlinePublishQueueDraining = threading.Thread(target=self._doPublishDraining) - offlinePublishQueueDraining.start() - self._resubscribeCount = -1 # Recover the context for resubscribe - self._subscribeSent = True - self._log.debug("Subscribe request " + str(mid) + " sent.") - - def on_unsubscribe(self, client, userdata, mid): - self._unsubscribeSent = True - self._log.debug("Unsubscribe request " + str(mid) + 
" sent.") - - def on_message(self, client, userdata, message): - # Generic message callback - self._log.warn("Received (No custom callback registered) : message: " + str(message.payload) + " from topic: " + str(message.topic)) - - ####### API starts here ####### - def __init__(self, clientID, cleanSession, protocol, srcUseWebsocket=False): - if clientID is None or cleanSession is None or protocol is None: - raise TypeError("None type inputs detected.") - # All internal data member should be unique per mqttCore intance - # Tool handler - self._log = logging.getLogger(__name__) - self._clientID = clientID - self._pahoClient = self.createPahoClient(clientID, cleanSession, None, protocol, srcUseWebsocket) # User data is set to None as default - self._log.debug("Paho MQTT Client init.") - self._log.info("ClientID: " + str(clientID)) - protocolType = "MQTTv3.1.1" - if protocol == 3: - protocolType = "MQTTv3.1" - self._log.info("Protocol: " + protocolType) - self._pahoClient.on_connect = self.on_connect - self._pahoClient.on_disconnect = self.on_disconnect - self._pahoClient.on_message = self.on_message - self._pahoClient.on_subscribe = self.on_subscribe - self._pahoClient.on_unsubscribe = self.on_unsubscribe - self._log.debug("Register Paho MQTT Client callbacks.") - # Tool data structure - self._connectResultCode = sys.maxsize - self._disconnectResultCode = sys.maxsize - self._subscribeSent = False - self._unsubscribeSent = False - self._connectdisconnectTimeout = 30 # Default connect/disconnect timeout set to 30 second - self._mqttOperationTimeout = 5 # Default MQTT operation timeout set to 5 second - # Use Websocket - self._useWebsocket = srcUseWebsocket - # Subscribe record - self._subscribePool = dict() - self._resubscribeCount = -1 # Ensure that initial value for _resubscribeCount does not trigger draining on each SUBACK - # Broker information - self._host = "" - self._port = -1 - self._cafile = "" - self._key = "" - self._cert = "" - self._stsToken = "" - # 
Operation mutex - self._publishLock = Lock() - self._subscribeLock = Lock() - self._unsubscribeLock = Lock() - # OfflinePublishQueue - self._offlinePublishQueueLock = Lock() - self._offlinePublishQueue = offlinePublishQueue.offlinePublishQueue(20, 1) - # Draining interval in seconds - self._drainingIntervalSecond = 0.5 - # Is Draining complete - self._drainingComplete = True - self._log.debug("mqttCore init.") - - def configEndpoint(self, srcHost, srcPort): - if srcHost is None or srcPort is None: - self._log.error("configEndpoint: None type inputs detected.") - raise TypeError("None type inputs detected.") - self._host = srcHost - self._port = srcPort - - def configCredentials(self, srcCAFile, srcKey, srcCert): - if srcCAFile is None or srcKey is None or srcCert is None: - self._log.error("configCredentials: None type inputs detected.") - raise TypeError("None type inputs detected.") - self._cafile = srcCAFile - self._key = srcKey - self._cert = srcCert - self._log.debug("Load CAFile from: " + self._cafile) - self._log.debug("Load Key from: " + self._key) - self._log.debug("Load Cert from: " + self._cert) - - def configIAMCredentials(self, srcAWSAccessKeyID, srcAWSSecretAccessKey, srcAWSSessionToken): - if srcAWSSecretAccessKey is None or srcAWSSecretAccessKey is None or srcAWSSessionToken is None: - self._log.error("configIAMCredentials: None type inputs detected.") - raise TypeError("None type inputs detected.") - self._pahoClient.configIAMCredentials(srcAWSAccessKeyID, srcAWSSecretAccessKey, srcAWSSessionToken) - - def setLastWill(self, srcTopic, srcPayload, srcQos): - if srcTopic is None or srcPayload is None or srcQos is None: - self._log.error("setLastWill: None type inputs detected.") - raise TypeError("None type inputs detected.") - self._pahoClient.will_set(srcTopic, srcPayload, srcQos, False) - - def clearLastWill(self): - self._pahoClient.will_clear() - - def setBackoffTime(self, srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, 
srcMinimumConnectTimeSecond): - if srcBaseReconnectTimeSecond is None or srcMaximumReconnectTimeSecond is None or srcMinimumConnectTimeSecond is None: - self._log.error("setBackoffTime: None type inputs detected.") - raise TypeError("None type inputs detected.") - # Below line could raise ValueError if input params are not properly selected - self._pahoClient.setBackoffTiming(srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond) - self._log.debug("Custom setting for backoff timing: baseReconnectTime = " + str(srcBaseReconnectTimeSecond) + " sec") - self._log.debug("Custom setting for backoff timing: maximumReconnectTime = " + str(srcMaximumReconnectTimeSecond) + " sec") - self._log.debug("Custom setting for backoff timing: minimumConnectTime = " + str(srcMinimumConnectTimeSecond) + " sec") - - def setOfflinePublishQueueing(self, srcQueueSize, srcDropBehavior=mqtt.MSG_QUEUEING_DROP_NEWEST): - if srcQueueSize is None or srcDropBehavior is None: - self._log.error("setOfflinePublishQueueing: None type inputs detected.") - raise TypeError("None type inputs detected.") - self._offlinePublishQueue = offlinePublishQueue.offlinePublishQueue(srcQueueSize, srcDropBehavior) - self._log.debug("Custom setting for publish queueing: queueSize = " + str(srcQueueSize)) - dropBehavior_word = "Drop Oldest" - if srcDropBehavior == 1: - dropBehavior_word = "Drop Newest" - self._log.debug("Custom setting for publish queueing: dropBehavior = " + dropBehavior_word) - - def setDrainingIntervalSecond(self, srcDrainingIntervalSecond): - if srcDrainingIntervalSecond is None: - self._log.error("setDrainingIntervalSecond: None type inputs detected.") - raise TypeError("None type inputs detected.") - if srcDrainingIntervalSecond < 0: - self._log.error("setDrainingIntervalSecond: Draining interval should not be negative.") - raise ValueError("Draining interval should not be negative.") - self._drainingIntervalSecond = srcDrainingIntervalSecond - 
self._log.debug("Custom setting for draining interval: " + str(srcDrainingIntervalSecond) + " sec") - - # MQTT connection - def connect(self, keepAliveInterval=30): - if keepAliveInterval is None : - self._log.error("connect: None type inputs detected.") - raise TypeError("None type inputs detected.") - if not isinstance(keepAliveInterval, int): - self._log.error("connect: Wrong input type detected. KeepAliveInterval must be an integer.") - raise TypeError("Non-integer type inputs detected.") - # Return connect succeeded/failed - ret = False - # TLS configuration - if self._useWebsocket: - # History issue from Yun SDK where AR9331 embedded Linux only have Python 2.7.3 - # pre-installed. In this version, TLSv1_2 is not even an option. - # SSLv23 is a work-around which selects the highest TLS version between the client - # and service. If user installs opensslv1.0.1+, this option will work fine for Mutal - # Auth. - # Note that we cannot force TLSv1.2 for Mutual Auth. in Python 2.7.3 and TLS support - # in Python only starts from Python2.7. - # See also: https://docs.python.org/2/library/ssl.html#ssl.PROTOCOL_SSLv23 - self._pahoClient.tls_set(ca_certs=self._cafile, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23) - self._log.info("Connection type: Websocket") - else: - self._pahoClient.tls_set(self._cafile, self._cert, self._key, ssl.CERT_REQUIRED, ssl.PROTOCOL_SSLv23) # Throw exception... - self._log.info("Connection type: TLSv1.2 Mutual Authentication") - # Connect - self._pahoClient.connect(self._host, self._port, keepAliveInterval) # Throw exception... 
- self._pahoClient.loop_start() - TenmsCount = 0 - while(TenmsCount != self._connectdisconnectTimeout * 100 and self._connectResultCode == sys.maxsize): - TenmsCount += 1 - time.sleep(0.01) - if(self._connectResultCode == sys.maxsize): - self._log.error("Connect timeout.") - self._pahoClient.loop_stop() - raise connectTimeoutException() - elif(self._connectResultCode == 0): - ret = True - self._log.info("Connected to AWS IoT.") - self._log.debug("Connect time consumption: " + str(float(TenmsCount) * 10) + "ms.") - else: - self._log.error("A connect error happened: " + str(self._connectResultCode)) - self._pahoClient.loop_stop() - raise connectError(self._connectResultCode) - return ret - - def disconnect(self): - # Return disconnect succeeded/failed - ret = False - # Disconnect - self._pahoClient.disconnect() # Throw exception... - TenmsCount = 0 - while(TenmsCount != self._connectdisconnectTimeout * 100 and self._disconnectResultCode == sys.maxsize): - TenmsCount += 1 - time.sleep(0.01) - if(self._disconnectResultCode == sys.maxsize): - self._log.error("Disconnect timeout.") - raise disconnectTimeoutException() - elif(self._disconnectResultCode == 0): - ret = True - self._log.info("Disconnected.") - self._log.debug("Disconnect time consumption: " + str(float(TenmsCount) * 10) + "ms.") - self._pahoClient.loop_stop() # Do NOT maintain a background thread for socket communication since it is a successful disconnect - else: - self._log.error("A disconnect error happened: " + str(self._disconnectResultCode)) - raise disconnectError(self._disconnectResultCode) - return ret - - def publish(self, topic, payload, qos, retain): - if(topic is None or payload is None or qos is None or retain is None): - self._log.error("publish: None type inputs detected.") - raise TypeError("None type inputs detected.") - # Return publish succeeded/failed - ret = False - # Queueing should happen when disconnected or draining is in progress - self._offlinePublishQueueLock.acquire() - 
queuedPublishCondition = not self._drainingComplete or self._connectResultCode == sys.maxsize - if queuedPublishCondition: - if self._connectResultCode == sys.maxsize: - self._log.info("Offline publish request detected.") - # If the client is connected but draining is not completed... - elif not self._drainingComplete: - self._log.info("Drainging is still on-going.") - self._log.info("Try queueing up this request...") - # Publish to the queue and report error (raise Exception) - currentQueuedPublishRequest = _publishRequest(topic, payload, qos, retain) - # Try to append the element... - appendResult = self._offlinePublishQueue.append(currentQueuedPublishRequest) - # When the queue is full... - if appendResult == self._offlinePublishQueue.APPEND_FAILURE_QUEUE_FULL: - self._offlinePublishQueueLock.release() - raise publishQueueFullException() - # When the queue is disabled... - elif appendResult == self._offlinePublishQueue.APPEND_FAILURE_QUEUE_DISABLED: - self._offlinePublishQueueLock.release() - raise publishQueueDisabledException() - # When the queue is good... - else: - self._offlinePublishQueueLock.release() - # Publish to Paho - else: - self._offlinePublishQueueLock.release() - self._publishLock.acquire() - # Publish - (rc, mid) = self._pahoClient.publish(topic, payload, qos, retain) # Throw exception... 
- self._log.debug("Try to put a publish request " + str(mid) + " in the TCP stack.") - ret = rc == 0 - if(ret): - self._log.debug("Publish request " + str(mid) + " succeeded.") - else: - self._log.error("Publish request " + str(mid) + " failed with code: " + str(rc)) - self._publishLock.release() # Release the lock when exception is raised - raise publishError(rc) - self._publishLock.release() - return ret - - def subscribe(self, topic, qos, callback): - if(topic is None or qos is None): - self._log.error("subscribe: None type inputs detected.") - raise TypeError("None type inputs detected.") - # Return subscribe succeeded/failed - ret = False - self._subscribeLock.acquire() - # Subscribe - # Register callback - if(callback is not None): - self._pahoClient.message_callback_add(topic, callback) - (rc, mid) = self._pahoClient.subscribe(topic, qos) # Throw exception... - self._log.debug("Started a subscribe request " + str(mid)) - TenmsCount = 0 - while(TenmsCount != self._mqttOperationTimeout * 100 and not self._subscribeSent): - TenmsCount += 1 - time.sleep(0.01) - if(self._subscribeSent): - ret = rc == 0 - if(ret): - self._subscribePool[topic] = (qos, callback) - self._log.debug("Subscribe request " + str(mid) + " succeeded. Time consumption: " + str(float(TenmsCount) * 10) + "ms.") - else: - if(callback is not None): - self._pahoClient.message_callback_remove(topic) - self._log.error("Subscribe request " + str(mid) + " failed with code: " + str(rc)) - self._log.debug("Callback cleaned up.") - self._subscribeLock.release() # Release the lock when exception is raised - raise subscribeError(rc) - else: - # Subscribe timeout - if(callback is not None): - self._pahoClient.message_callback_remove(topic) - self._log.error("No feedback detected for subscribe request " + str(mid) + ". 
Timeout and failed.") - self._log.debug("Callback cleaned up.") - self._subscribeLock.release() # Release the lock when exception is raised - raise subscribeTimeoutException() - self._subscribeSent = False - self._log.debug("Recover subscribe context for the next request: subscribeSent: " + str(self._subscribeSent)) - self._subscribeLock.release() - return ret - - def unsubscribe(self, topic): - if(topic is None): - self._log.error("unsubscribe: None type inputs detected.") - raise TypeError("None type inputs detected.") - self._log.debug("unsubscribe from: " + topic) - # Return unsubscribe succeeded/failed - ret = False - self._unsubscribeLock.acquire() - # Unsubscribe - (rc, mid) = self._pahoClient.unsubscribe(topic) # Throw exception... - self._log.debug("Started an unsubscribe request " + str(mid)) - TenmsCount = 0 - while(TenmsCount != self._mqttOperationTimeout * 100 and not self._unsubscribeSent): - TenmsCount += 1 - time.sleep(0.01) - if(self._unsubscribeSent): - ret = rc == 0 - if(ret): - try: - del self._subscribePool[topic] - except KeyError: - pass # Ignore topics that are never subscribed to - self._log.debug("Unsubscribe request " + str(mid) + " succeeded. Time consumption: " + str(float(TenmsCount) * 10) + "ms.") - self._pahoClient.message_callback_remove(topic) - self._log.debug("Remove the callback.") - else: - self._log.error("Unsubscribe request " + str(mid) + " failed with code: " + str(rc)) - self._unsubscribeLock.release() # Release the lock when exception is raised - raise unsubscribeError(rc) - else: - # Unsubscribe timeout - self._log.error("No feedback detected for unsubscribe request " + str(mid) + ". 
Timeout and failed.") - self._unsubscribeLock.release() # Release the lock when exception is raised - raise unsubscribeTimeoutException() - self._unsubscribeSent = False - self._log.debug("Recover unsubscribe context for the next request: unsubscribeSent: " + str(self._unsubscribeSent)) - self._unsubscribeLock.release() - return ret diff --git a/AWSIoTPythonSDK/core/protocol/mqtt_core.py b/AWSIoTPythonSDK/core/protocol/mqtt_core.py new file mode 100644 index 0000000..fbdd6bf --- /dev/null +++ b/AWSIoTPythonSDK/core/protocol/mqtt_core.py @@ -0,0 +1,373 @@ +# /* +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * +# * Licensed under the Apache License, Version 2.0 (the "License"). +# * You may not use this file except in compliance with the License. +# * A copy of the License is located at +# * +# * http://aws.amazon.com/apache2.0 +# * +# * or in the "license" file accompanying this file. This file is distributed +# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# * express or implied. See the License for the specific language governing +# * permissions and limitations under the License. 
# */

import AWSIoTPythonSDK
from AWSIoTPythonSDK.core.protocol.internal.clients import InternalAsyncMqttClient
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatusContainer
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus
from AWSIoTPythonSDK.core.protocol.internal.workers import EventProducer
from AWSIoTPythonSDK.core.protocol.internal.workers import EventConsumer
from AWSIoTPythonSDK.core.protocol.internal.workers import SubscriptionManager
from AWSIoTPythonSDK.core.protocol.internal.workers import OfflineRequestsManager
from AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes
from AWSIoTPythonSDK.core.protocol.internal.requests import QueueableRequest
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_OPERATION_TIMEOUT_SEC
from AWSIoTPythonSDK.core.protocol.internal.defaults import METRICS_PREFIX
from AWSIoTPythonSDK.core.protocol.internal.defaults import ALPN_PROTCOLS
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS
from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueFullException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueDisabledException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueFullException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueDisabledException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueFullException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueDisabledException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeTimeoutException
from AWSIoTPythonSDK.core.protocol.internal.queues import AppendResults
from AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes
from AWSIoTPythonSDK.core.protocol.paho.client import MQTTv31
from threading import Condition
from threading import Event
import logging
import sys
# Py2/Py3 compatibility shim for the standard-library queue module.
if sys.version_info[0] < 3:
    from Queue import Queue
else:
    from queue import Queue


class MqttCore(object):
    """Core MQTT engine wiring the internal async client to the worker objects.

    Owns the shared event queue/condition variable, the EventProducer/
    EventConsumer pair, the subscription records and the offline request
    queue. (Class continues beyond this excerpt.)
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, client_id, clean_session, protocol, use_wss):
        """Build the client plumbing for *client_id*.

        use_wss selects SigV4 WebSocket auth instead of X.509 mutual auth;
        protocol is a paho MQTT protocol constant (MQTTv31 or MQTTv3.1.1).
        """
        self._use_wss = use_wss
        self._username = ""
        self._password = None
        self._enable_metrics_collection = True
        # Queue + condition variable shared between EventProducer (paho
        # callbacks) and EventConsumer (dispatch thread).
        self._event_queue = Queue()
        self._event_cv = Condition()
        self._event_producer = EventProducer(self._event_cv, self._event_queue)
        self._client_status = ClientStatusContainer()
        self._internal_async_client = InternalAsyncMqttClient(client_id, clean_session, protocol, use_wss)
        self._subscription_manager = SubscriptionManager()
        self._offline_requests_manager = OfflineRequestsManager(-1, DropBehaviorTypes.DROP_NEWEST)  # Infinite queue
        self._event_consumer = EventConsumer(self._event_cv,
                                             self._event_queue,
                                             self._internal_async_client,
                                             self._subscription_manager,
                                             self._offline_requests_manager,
                                             self._client_status)
        self._connect_disconnect_timeout_sec = DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
        self._operation_timeout_sec = DEFAULT_OPERATION_TIMEOUT_SEC
        self._init_offline_request_exceptions()
        self._init_workers()
        self._logger.info("MqttCore initialized")
        self._logger.info("Client id: %s" % client_id)
        self._logger.info("Protocol version: %s" % ("MQTTv3.1" if protocol == MQTTv31 else "MQTTv3.1.1"))
        self._logger.info("Authentication type: %s" % ("SigV4 WebSocket" if use_wss else "TLSv1.2 certificate based Mutual Auth."))

    def _init_offline_request_exceptions(self):
        """Build request-type -> exception maps used when offline queuing fails."""
        # Raised when the offline queue has been explicitly disabled.
        self._offline_request_queue_disabled_exceptions = {
            RequestTypes.PUBLISH : publishQueueDisabledException,
            RequestTypes.SUBSCRIBE : subscribeQueueDisabledException,
            RequestTypes.UNSUBSCRIBE : unsubscribeQueueDisabledException
        }
        # Raised when the offline queue is full and the request is dropped.
        self._offline_request_queue_full_exceptions = {
            RequestTypes.PUBLISH : publishQueueFullException,
            RequestTypes.SUBSCRIBE : subscribeQueueFullException,
            RequestTypes.UNSUBSCRIBE : unsubscribeQueueFullException
        }

    def _init_workers(self):
        """Register the EventProducer's callbacks on the internal async client."""
        self._internal_async_client.register_internal_event_callbacks(self._event_producer.on_connect,
                                                                      self._event_producer.on_disconnect,
                                                                      self._event_producer.on_publish,
                                                                      self._event_producer.on_subscribe,
                                                                      self._event_producer.on_unsubscribe,
                                                                      self._event_producer.on_message)

    def _start_workers(self):
        """Start the event-consuming thread."""
        self._event_consumer.start()

    def use_wss(self):
        """Return True when this core was configured for SigV4 WebSocket auth."""
        return self._use_wss

    # Used for general message event reception
    def on_message(self, message):
        pass

    # Used for general online event notification
    def on_online(self):
        pass

    # Used for general offline event notification
    def on_offline(self):
        pass

    def configure_cert_credentials(self, cert_credentials_provider, ciphers_provider):
        """Hand certificate and cipher providers to the internal async client."""
        self._logger.info("Configuring certificates and ciphers...")
        self._internal_async_client.set_cert_credentials_provider(cert_credentials_provider, ciphers_provider)

    def configure_iam_credentials(self, iam_credentials_provider):
        """Hand an IAM credentials provider to the internal async client (WebSocket auth)."""
        self._logger.info("Configuring custom IAM credentials...")
        self._internal_async_client.set_iam_credentials_provider(iam_credentials_provider)

    def configure_endpoint(self, endpoint_provider):
        """Hand the endpoint (host/port) provider to the internal async client."""
        self._logger.info("Configuring endpoint...")
        self._internal_async_client.set_endpoint_provider(endpoint_provider)

    def configure_connect_disconnect_timeout_sec(self, connect_disconnect_timeout_sec):
        """Set how long connect/disconnect waits for its ack, in seconds."""
        self._logger.info("Configuring connect/disconnect time out: %f sec" % connect_disconnect_timeout_sec)
        self._connect_disconnect_timeout_sec = connect_disconnect_timeout_sec

    def configure_operation_timeout_sec(self, operation_timeout_sec):
        """Set how long publish/subscribe/unsubscribe wait for their acks, in seconds."""
        self._logger.info("Configuring MQTT operation time out: %f sec" % operation_timeout_sec)
        self._operation_timeout_sec = operation_timeout_sec

    def configure_reconnect_back_off(self, base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec):
        """Forward reconnect back-off timing (base/max quiet, stable window) to the client."""
        self._logger.info("Configuring reconnect back off timing...")
        self._logger.info("Base quiet time: %f sec" % base_reconnect_quiet_sec)
        self._logger.info("Max quiet time: %f sec" % max_reconnect_quiet_sec)
        self._logger.info("Stable connection time: %f sec" % stable_connection_sec)
        self._internal_async_client.configure_reconnect_back_off(base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec)

    def configure_alpn_protocols(self):
        """Enable the SDK's default ALPN protocol list on the TLS connection."""
        self._logger.info("Configuring alpn protocols...")
        self._internal_async_client.configure_alpn_protocols([ALPN_PROTCOLS])

    def configure_last_will(self, topic, payload, qos, retain=False):
        """Register an MQTT last-will message to be sent by the broker on abnormal disconnect."""
        self._logger.info("Configuring last will...")
        self._internal_async_client.configure_last_will(topic, payload, qos, retain)

    def clear_last_will(self):
        """Remove any previously configured last-will message."""
        self._logger.info("Clearing last will...")
        self._internal_async_client.clear_last_will()

    def configure_username_password(self, username, password=None):
        """Store MQTT username/password for the connection."""
        self._logger.info("Configuring username and password...")
        self._username = username
        self._password = password

    def configure_socket_factory(self, socket_factory):
        """Hand a custom socket factory to the internal async client."""
        self._logger.info("Configuring socket factory...")
        self._internal_async_client.set_socket_factory(socket_factory)
+ + def enable_metrics_collection(self): + self._enable_metrics_collection = True + + def disable_metrics_collection(self): + self._enable_metrics_collection = False + + def configure_offline_requests_queue(self, max_size, drop_behavior): + self._logger.info("Configuring offline requests queueing: max queue size: %d", max_size) + self._offline_requests_manager = OfflineRequestsManager(max_size, drop_behavior) + self._event_consumer.update_offline_requests_manager(self._offline_requests_manager) + + def configure_draining_interval_sec(self, draining_interval_sec): + self._logger.info("Configuring offline requests queue draining interval: %f sec", draining_interval_sec) + self._event_consumer.update_draining_interval_sec(draining_interval_sec) + + def connect(self, keep_alive_sec): + self._logger.info("Performing sync connect...") + event = Event() + self.connect_async(keep_alive_sec, self._create_blocking_ack_callback(event)) + if not event.wait(self._connect_disconnect_timeout_sec): + self._logger.error("Connect timed out") + raise connectTimeoutException() + return True + + def connect_async(self, keep_alive_sec, ack_callback=None): + self._logger.info("Performing async connect...") + self._logger.info("Keep-alive: %f sec" % keep_alive_sec) + self._start_workers() + self._load_callbacks() + self._load_username_password() + + try: + self._client_status.set_status(ClientStatus.CONNECT) + rc = self._internal_async_client.connect(keep_alive_sec, ack_callback) + if MQTT_ERR_SUCCESS != rc: + self._logger.error("Connect error: %d", rc) + raise connectError(rc) + except Exception as e: + # Provided any error in connect, we should clean up the threads that have been created + self._event_consumer.stop() + if not self._event_consumer.wait_until_it_stops(self._connect_disconnect_timeout_sec): + self._logger.error("Time out in waiting for event consumer to stop") + else: + self._logger.debug("Event consumer stopped") + self._client_status.set_status(ClientStatus.IDLE) + raise 
e + + return FixedEventMids.CONNACK_MID + + def _load_callbacks(self): + self._logger.debug("Passing in general notification callbacks to internal client...") + self._internal_async_client.on_online = self.on_online + self._internal_async_client.on_offline = self.on_offline + self._internal_async_client.on_message = self.on_message + + def _load_username_password(self): + username_candidate = self._username + if self._enable_metrics_collection: + username_candidate += METRICS_PREFIX + username_candidate += AWSIoTPythonSDK.__version__ + self._internal_async_client.set_username_password(username_candidate, self._password) + + def disconnect(self): + self._logger.info("Performing sync disconnect...") + event = Event() + self.disconnect_async(self._create_blocking_ack_callback(event)) + if not event.wait(self._connect_disconnect_timeout_sec): + self._logger.error("Disconnect timed out") + raise disconnectTimeoutException() + if not self._event_consumer.wait_until_it_stops(self._connect_disconnect_timeout_sec): + self._logger.error("Disconnect timed out in waiting for event consumer") + raise disconnectTimeoutException() + return True + + def disconnect_async(self, ack_callback=None): + self._logger.info("Performing async disconnect...") + self._client_status.set_status(ClientStatus.USER_DISCONNECT) + rc = self._internal_async_client.disconnect(ack_callback) + if MQTT_ERR_SUCCESS != rc: + self._logger.error("Disconnect error: %d", rc) + raise disconnectError(rc) + return FixedEventMids.DISCONNECT_MID + + def publish(self, topic, payload, qos, retain=False): + self._logger.info("Performing sync publish...") + ret = False + if ClientStatus.STABLE != self._client_status.get_status(): + self._handle_offline_request(RequestTypes.PUBLISH, (topic, payload, qos, retain)) + else: + if qos > 0: + event = Event() + rc, mid = self._publish_async(topic, payload, qos, retain, self._create_blocking_ack_callback(event)) + if not event.wait(self._operation_timeout_sec): + 
self._internal_async_client.remove_event_callback(mid) + self._logger.error("Publish timed out") + raise publishTimeoutException() + else: + self._publish_async(topic, payload, qos, retain) + ret = True + return ret + + def publish_async(self, topic, payload, qos, retain=False, ack_callback=None): + self._logger.info("Performing async publish...") + if ClientStatus.STABLE != self._client_status.get_status(): + self._handle_offline_request(RequestTypes.PUBLISH, (topic, payload, qos, retain)) + return FixedEventMids.QUEUED_MID + else: + rc, mid = self._publish_async(topic, payload, qos, retain, ack_callback) + return mid + + def _publish_async(self, topic, payload, qos, retain=False, ack_callback=None): + rc, mid = self._internal_async_client.publish(topic, payload, qos, retain, ack_callback) + if MQTT_ERR_SUCCESS != rc: + self._logger.error("Publish error: %d", rc) + raise publishError(rc) + return rc, mid + + def subscribe(self, topic, qos, message_callback=None): + self._logger.info("Performing sync subscribe...") + ret = False + if ClientStatus.STABLE != self._client_status.get_status(): + self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback, None)) + else: + event = Event() + rc, mid = self._subscribe_async(topic, qos, self._create_blocking_ack_callback(event), message_callback) + if not event.wait(self._operation_timeout_sec): + self._internal_async_client.remove_event_callback(mid) + self._logger.error("Subscribe timed out") + raise subscribeTimeoutException() + ret = True + return ret + + def subscribe_async(self, topic, qos, ack_callback=None, message_callback=None): + self._logger.info("Performing async subscribe...") + if ClientStatus.STABLE != self._client_status.get_status(): + self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback, ack_callback)) + return FixedEventMids.QUEUED_MID + else: + rc, mid = self._subscribe_async(topic, qos, ack_callback, message_callback) + return mid + + def 
_subscribe_async(self, topic, qos, ack_callback=None, message_callback=None): + self._subscription_manager.add_record(topic, qos, message_callback, ack_callback) + rc, mid = self._internal_async_client.subscribe(topic, qos, ack_callback) + if MQTT_ERR_SUCCESS != rc: + self._logger.error("Subscribe error: %d", rc) + raise subscribeError(rc) + return rc, mid + + def unsubscribe(self, topic): + self._logger.info("Performing sync unsubscribe...") + ret = False + if ClientStatus.STABLE != self._client_status.get_status(): + self._handle_offline_request(RequestTypes.UNSUBSCRIBE, (topic, None)) + else: + event = Event() + rc, mid = self._unsubscribe_async(topic, self._create_blocking_ack_callback(event)) + if not event.wait(self._operation_timeout_sec): + self._internal_async_client.remove_event_callback(mid) + self._logger.error("Unsubscribe timed out") + raise unsubscribeTimeoutException() + ret = True + return ret + + def unsubscribe_async(self, topic, ack_callback=None): + self._logger.info("Performing async unsubscribe...") + if ClientStatus.STABLE != self._client_status.get_status(): + self._handle_offline_request(RequestTypes.UNSUBSCRIBE, (topic, ack_callback)) + return FixedEventMids.QUEUED_MID + else: + rc, mid = self._unsubscribe_async(topic, ack_callback) + return mid + + def _unsubscribe_async(self, topic, ack_callback=None): + self._subscription_manager.remove_record(topic) + rc, mid = self._internal_async_client.unsubscribe(topic, ack_callback) + if MQTT_ERR_SUCCESS != rc: + self._logger.error("Unsubscribe error: %d", rc) + raise unsubscribeError(rc) + return rc, mid + + def _create_blocking_ack_callback(self, event): + def ack_callback(mid, data=None): + event.set() + return ack_callback + + def _handle_offline_request(self, type, data): + self._logger.info("Offline request detected!") + offline_request = QueueableRequest(type, data) + append_result = self._offline_requests_manager.add_one(offline_request) + if AppendResults.APPEND_FAILURE_QUEUE_DISABLED == 
append_result: + self._logger.error("Offline request queue has been disabled") + raise self._offline_request_queue_disabled_exceptions[type]() + if AppendResults.APPEND_FAILURE_QUEUE_FULL == append_result: + self._logger.error("Offline request queue is full") + raise self._offline_request_queue_full_exceptions[type]() diff --git a/AWSIoTPythonSDK/core/protocol/paho/client.py b/AWSIoTPythonSDK/core/protocol/paho/client.py index 6096aa4..0b637c5 100755 --- a/AWSIoTPythonSDK/core/protocol/paho/client.py +++ b/AWSIoTPythonSDK/core/protocol/paho/client.py @@ -44,10 +44,10 @@ EAGAIN = errno.WSAEWOULDBLOCK else: EAGAIN = errno.EAGAIN -# AWS WSS implementation -import AWSIoTPythonSDK.core.protocol.paho.securedWebsocket.securedWebsocketCore as wssCore -import AWSIoTPythonSDK.core.util.progressiveBackoffCore as backoffCore -import AWSIoTPythonSDK.core.util.offlinePublishQueue as offlinePublishQueue + +from AWSIoTPythonSDK.core.protocol.connection.cores import ProgressiveBackOffCore +from AWSIoTPythonSDK.core.protocol.connection.cores import SecuredWebSocketCore +from AWSIoTPythonSDK.core.protocol.connection.alpn import SSLContextBuilder VERSION_MAJOR=1 VERSION_MINOR=0 @@ -483,6 +483,7 @@ def __init__(self, client_id="", clean_session=True, userdata=None, protocol=MQT self._host = "" self._port = 1883 self._bind_address = "" + self._socket_factory = None self._in_callback = False self._strict_protocol = False self._callback_mutex = threading.Lock() @@ -503,10 +504,11 @@ def __init__(self, client_id="", clean_session=True, userdata=None, protocol=MQT self._tls_version = tls_version self._tls_insecure = False self._useSecuredWebsocket = useSecuredWebsocket # Do we enable secured websocket - self._backoffCore = backoffCore.progressiveBackoffCore() # Init the backoffCore using default configuration + self._backoffCore = ProgressiveBackOffCore() # Init the backoffCore using default configuration self._AWSAccessKeyIDCustomConfig = "" self._AWSSecretAccessKeyCustomConfig = "" 
self._AWSSessionTokenCustomConfig = "" + self._alpn_protocols = None def __del__(self): pass @@ -517,7 +519,7 @@ def setBackoffTiming(self, srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSe Make custom settings for backoff timing for reconnect logic srcBaseReconnectTimeSecond - The base reconnection time in seconds srcMaximumReconnectTimeSecond - The maximum reconnection time in seconds - srcMinimumConnectTimeSecond - The minimum time in milliseconds that a connection must be maintained in order to be considered stable + srcMinimumConnectTimeSecond - The minimum time in seconds that a connection must be maintained in order to be considered stable * Raise ValueError if input params are malformed """ self._backoffCore.configTime(srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond) @@ -533,6 +535,14 @@ def configIAMCredentials(self, srcAWSAccessKeyID, srcAWSSecretAccessKey, srcAWSS self._AWSSecretAccessKeyCustomConfig = srcAWSSecretAccessKey self._AWSSessionTokenCustomConfig = srcAWSSessionToken + def config_alpn_protocols(self, alpn_protocols): + """ + Make custom settings for ALPN protocols + :param alpn_protocols: Array of strings that specifies the alpn protocols to be used + :return: None + """ + self._alpn_protocols = alpn_protocols + def reinitialise(self, client_id="", clean_session=True, userdata=None): if self._ssl: self._ssl.close() @@ -771,7 +781,9 @@ def reconnect(self): self._messages_reconnect_reset() try: - if (sys.version_info[0] == 2 and sys.version_info[1] < 7) or (sys.version_info[0] == 3 and sys.version_info[1] < 2): + if self._socket_factory: + sock = self._socket_factory() + elif (sys.version_info[0] == 2 and sys.version_info[1] < 7) or (sys.version_info[0] == 3 and sys.version_info[1] < 2): sock = socket.create_connection((self._host, self._port)) else: sock = socket.create_connection((self._host, self._port), source_address=(self._bind_address, 0)) @@ -779,29 +791,67 @@ def reconnect(self): if err.errno 
!= errno.EINPROGRESS and err.errno != errno.EWOULDBLOCK and err.errno != EAGAIN: raise + verify_hostname = self._tls_insecure is False # Decide whether we need to verify hostname + + # To keep the SSL Context update minimal, only apply forced ssl context to python3.12+ + force_ssl_context = sys.version_info[0] > 3 or (sys.version_info[0] == 3 and sys.version_info[1] >= 12) + if self._tls_ca_certs is not None: if self._useSecuredWebsocket: # Never assign to ._ssl before wss handshake is finished # Non-None value for ._ssl will allow ops before wss-MQTT connection is established - rawSSL = ssl.wrap_socket(sock, ca_certs=self._tls_ca_certs, cert_reqs=ssl.CERT_REQUIRED) # Add server certificate verification + if force_ssl_context: + ssl_context = ssl.SSLContext() + ssl_context.load_verify_locations(self._tls_ca_certs) + ssl_context.verify_mode = ssl.CERT_REQUIRED + + rawSSL = ssl_context.wrap_socket(sock) + else: + rawSSL = ssl.wrap_socket(sock, ca_certs=self._tls_ca_certs, cert_reqs=ssl.CERT_REQUIRED) # Add server certificate verification + rawSSL.setblocking(0) # Non-blocking socket - self._ssl = wssCore.securedWebsocketCore(rawSSL, self._host, self._port, self._AWSAccessKeyIDCustomConfig, self._AWSSecretAccessKeyCustomConfig, self._AWSSessionTokenCustomConfig) # Overeride the _ssl socket + self._ssl = SecuredWebSocketCore(rawSSL, self._host, self._port, self._AWSAccessKeyIDCustomConfig, self._AWSSecretAccessKeyCustomConfig, self._AWSSessionTokenCustomConfig) # Override the _ssl socket # self._ssl.enableDebug() + elif self._alpn_protocols is not None: + # SSLContext is required to enable ALPN support + # Assuming Python 2.7.10+/3.5+ till the end of this elif branch + ssl_context = SSLContextBuilder()\ + .with_ca_certs(self._tls_ca_certs)\ + .with_cert_key_pair(self._tls_certfile, self._tls_keyfile)\ + .with_cert_reqs(self._tls_cert_reqs)\ + .with_check_hostname(True)\ + .with_ciphers(self._tls_ciphers)\ + .with_alpn_protocols(self._alpn_protocols)\ + .build() + 
self._ssl = ssl_context.wrap_socket(sock, server_hostname=self._host, do_handshake_on_connect=False) + verify_hostname = False # Since check_hostname in SSLContext is already set to True, no need to verify it again + self._ssl.do_handshake() else: - self._ssl = ssl.wrap_socket( - sock, - certfile=self._tls_certfile, - keyfile=self._tls_keyfile, - ca_certs=self._tls_ca_certs, - cert_reqs=self._tls_cert_reqs, - ssl_version=self._tls_version, - ciphers=self._tls_ciphers) - - if self._tls_insecure is False: - if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 2): - self._tls_match_hostname() - else: - ssl.match_hostname(self._ssl.getpeercert(), self._host) + if force_ssl_context: + ssl_context = ssl.SSLContext(self._tls_version) + ssl_context.load_cert_chain(self._tls_certfile, self._tls_keyfile) + ssl_context.load_verify_locations(self._tls_ca_certs) + ssl_context.verify_mode = self._tls_cert_reqs + if self._tls_ciphers is not None: + ssl_context.set_ciphers(self._tls_ciphers) + + self._ssl = ssl_context.wrap_socket(sock) + else: + self._ssl = ssl.wrap_socket( + sock, + certfile=self._tls_certfile, + keyfile=self._tls_keyfile, + ca_certs=self._tls_ca_certs, + cert_reqs=self._tls_cert_reqs, + ssl_version=self._tls_version, + ciphers=self._tls_ciphers) + + if verify_hostname: + if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 5): # No IP host match before 3.5.x + self._tls_match_hostname() + elif sys.version_info[0] == 3 and sys.version_info[1] < 7: + # host name verification is handled internally in Python3.7+ + ssl.match_hostname(self._ssl.getpeercert(), self._host) self._sock = sock @@ -849,6 +899,15 @@ def loop(self, timeout=1.0, max_packets=1): self._out_packet_mutex.release() self._current_out_packet_mutex.release() + # used to check if there are any bytes left in the ssl socket + pending_bytes = 0 + if self._ssl: + pending_bytes = self.socket().pending() + + # if bytes are pending do not wait in 
select + if pending_bytes > 0: + timeout = 0.0 + # sockpairR is used to break out of select() before the timeout, on a # call to publish() etc. rlist = [self.socket(), self._sockpairR] @@ -864,7 +923,7 @@ def loop(self, timeout=1.0, max_packets=1): except: return MQTT_ERR_UNKNOWN - if self.socket() in socklist[0]: + if self.socket() in socklist[0] or pending_bytes > 0: rc = self.loop_read(max_packets) if rc or (self._ssl is None and self._sock is None): return rc @@ -989,6 +1048,14 @@ def username_pw_set(self, username, password=None): self._username = username.encode('utf-8') self._password = password + def socket_factory_set(self, socket_factory): + """Set a socket factory to custom configure a different socket type for + mqtt connection. + Must be called before connect() to have any effect. + socket_factory: create_connection function which creates a socket to user's specification + """ + self._socket_factory = socket_factory + def disconnect(self): """Disconnect a connected client from the broker.""" self._state_mutex.acquire() @@ -2387,7 +2454,7 @@ def _tls_match_hostname(self): return if key == 'IP Address': have_san_dns = True - if value.lower() == self._host.lower(): + if value.lower().strip() == self._host.lower().strip(): return if have_san_dns: diff --git a/AWSIoTPythonSDK/core/shadow/deviceShadow.py b/AWSIoTPythonSDK/core/shadow/deviceShadow.py index 4404aa8..bb5d667 100755 --- a/AWSIoTPythonSDK/core/shadow/deviceShadow.py +++ b/AWSIoTPythonSDK/core/shadow/deviceShadow.py @@ -1,5 +1,5 @@ # /* -# * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # * # * Licensed under the Apache License, Version 2.0 (the "License"). # * You may not use this file except in compliance with the License. 
@@ -23,14 +23,18 @@ class _shadowRequestToken: URN_PREFIX_LENGTH = 9 - def __init__(self, srcShadowName, srcClientID): - self._shadowName = srcShadowName - self._clientID = srcClientID - def getNextToken(self): return uuid.uuid4().urn[self.URN_PREFIX_LENGTH:] # We only need the uuid digits, not the urn prefix +def _validateJSON(jsonString): + try: + json.loads(jsonString) + except ValueError: + return False + return True + + class _basicJSONParser: def setString(self, srcString): @@ -86,7 +90,7 @@ def __init__(self, srcShadowName, srcIsPersistentSubscribe, srcShadowManager): # Tool handler self._shadowManagerHandler = srcShadowManager self._basicJSONParserHandler = _basicJSONParser() - self._tokenHandler = _shadowRequestToken(self._shadowName, self._shadowManagerHandler.getClientID()) + self._tokenHandler = _shadowRequestToken() # Properties self._isPersistentSubscribe = srcIsPersistentSubscribe self._lastVersionInSync = -1 # -1 means not initialized @@ -109,60 +113,59 @@ def _doNonPersistentUnsubscribe(self, currentAction): self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, currentAction) self._logger.info("Unsubscribed to " + currentAction + " accepted/rejected topics for deviceShadow: " + self._shadowName) - def _generalCallback(self, client, userdata, message): + def generalCallback(self, client, userdata, message): # In Py3.x, message.payload comes in as a bytes(string) # json.loads needs a string input - self._dataStructureLock.acquire() - currentTopic = message.topic - currentAction = self._parseTopicAction(currentTopic) # get/delete/update/delta - currentType = self._parseTopicType(currentTopic) # accepted/rejected/delta - payloadUTF8String = message.payload.decode('utf-8') - # get/delete/update: Need to deal with token, timer and unsubscribe - if currentAction in ["get", "delete", "update"]: - # Check for token - self._basicJSONParserHandler.setString(payloadUTF8String) - if self._basicJSONParserHandler.validateJSON(): # Filter out invalid 
JSON - currentToken = self._basicJSONParserHandler.getAttributeValue(u"clientToken") - if currentToken is not None: - self._logger.debug("shadow message clientToken: " + currentToken) - if currentToken is not None and currentToken in self._tokenPool.keys(): # Filter out JSON without the desired token - # Sync local version when it is an accepted response - self._logger.debug("Token is in the pool. Type: " + currentType) - if currentType == "accepted": - incomingVersion = self._basicJSONParserHandler.getAttributeValue(u"version") - # If it is get/update accepted response, we need to sync the local version - if incomingVersion is not None and incomingVersion > self._lastVersionInSync and currentAction != "delete": - self._lastVersionInSync = incomingVersion - # If it is a delete accepted, we need to reset the version - else: - self._lastVersionInSync = -1 # The version will always be synced for the next incoming delta/GU-accepted response - # Cancel the timer and clear the token - self._tokenPool[currentToken].cancel() - del self._tokenPool[currentToken] - # Need to unsubscribe? 
- self._shadowSubscribeStatusTable[currentAction] -= 1 - if not self._isPersistentSubscribe and self._shadowSubscribeStatusTable.get(currentAction) <= 0: - self._shadowSubscribeStatusTable[currentAction] = 0 - processNonPersistentUnsubscribe = Thread(target=self._doNonPersistentUnsubscribe, args=[currentAction]) - processNonPersistentUnsubscribe.start() - # Custom callback - if self._shadowSubscribeCallbackTable.get(currentAction) is not None: - processCustomCallback = Thread(target=self._shadowSubscribeCallbackTable[currentAction], args=[payloadUTF8String, currentType, currentToken]) - processCustomCallback.start() - # delta: Watch for version - else: - currentType += "/" + self._parseTopicShadowName(currentTopic) - # Sync local version - self._basicJSONParserHandler.setString(payloadUTF8String) - if self._basicJSONParserHandler.validateJSON(): # Filter out JSON without version - incomingVersion = self._basicJSONParserHandler.getAttributeValue(u"version") - if incomingVersion is not None and incomingVersion > self._lastVersionInSync: - self._lastVersionInSync = incomingVersion - # Custom callback - if self._shadowSubscribeCallbackTable.get(currentAction) is not None: - processCustomCallback = Thread(target=self._shadowSubscribeCallbackTable[currentAction], args=[payloadUTF8String, currentType, None]) - processCustomCallback.start() - self._dataStructureLock.release() + with self._dataStructureLock: + currentTopic = message.topic + currentAction = self._parseTopicAction(currentTopic) # get/delete/update/delta + currentType = self._parseTopicType(currentTopic) # accepted/rejected/delta + payloadUTF8String = message.payload.decode('utf-8') + # get/delete/update: Need to deal with token, timer and unsubscribe + if currentAction in ["get", "delete", "update"]: + # Check for token + self._basicJSONParserHandler.setString(payloadUTF8String) + if self._basicJSONParserHandler.validateJSON(): # Filter out invalid JSON + currentToken = 
self._basicJSONParserHandler.getAttributeValue(u"clientToken") + if currentToken is not None: + self._logger.debug("shadow message clientToken: " + currentToken) + if currentToken is not None and currentToken in self._tokenPool.keys(): # Filter out JSON without the desired token + # Sync local version when it is an accepted response + self._logger.debug("Token is in the pool. Type: " + currentType) + if currentType == "accepted": + incomingVersion = self._basicJSONParserHandler.getAttributeValue(u"version") + # If it is get/update accepted response, we need to sync the local version + if incomingVersion is not None and incomingVersion > self._lastVersionInSync and currentAction != "delete": + self._lastVersionInSync = incomingVersion + # If it is a delete accepted, we need to reset the version + else: + self._lastVersionInSync = -1 # The version will always be synced for the next incoming delta/GU-accepted response + # Cancel the timer and clear the token + self._tokenPool[currentToken].cancel() + del self._tokenPool[currentToken] + # Need to unsubscribe? 
+ self._shadowSubscribeStatusTable[currentAction] -= 1 + if not self._isPersistentSubscribe and self._shadowSubscribeStatusTable.get(currentAction) <= 0: + self._shadowSubscribeStatusTable[currentAction] = 0 + processNonPersistentUnsubscribe = Thread(target=self._doNonPersistentUnsubscribe, args=[currentAction]) + processNonPersistentUnsubscribe.start() + # Custom callback + if self._shadowSubscribeCallbackTable.get(currentAction) is not None: + processCustomCallback = Thread(target=self._shadowSubscribeCallbackTable[currentAction], args=[payloadUTF8String, currentType, currentToken]) + processCustomCallback.start() + # delta: Watch for version + else: + currentType += "/" + self._parseTopicShadowName(currentTopic) + # Sync local version + self._basicJSONParserHandler.setString(payloadUTF8String) + if self._basicJSONParserHandler.validateJSON(): # Filter out JSON without version + incomingVersion = self._basicJSONParserHandler.getAttributeValue(u"version") + if incomingVersion is not None and incomingVersion > self._lastVersionInSync: + self._lastVersionInSync = incomingVersion + # Custom callback + if self._shadowSubscribeCallbackTable.get(currentAction) is not None: + processCustomCallback = Thread(target=self._shadowSubscribeCallbackTable[currentAction], args=[payloadUTF8String, currentType, None]) + processCustomCallback.start() def _parseTopicAction(self, srcTopic): ret = None @@ -182,19 +185,22 @@ def _parseTopicShadowName(self, srcTopic): return fragments[2] def _timerHandler(self, srcActionName, srcToken): - self._dataStructureLock.acquire() - # Remove the token - del self._tokenPool[srcToken] - # Need to unsubscribe? 
- self._shadowSubscribeStatusTable[srcActionName] -= 1 - if not self._isPersistentSubscribe and self._shadowSubscribeStatusTable.get(srcActionName) <= 0: - self._shadowSubscribeStatusTable[srcActionName] = 0 - self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, srcActionName) - # Notify time-out issue - if self._shadowSubscribeCallbackTable.get(srcActionName) is not None: - self._logger.info("Shadow request with token: " + str(srcToken) + " has timed out.") - self._shadowSubscribeCallbackTable[srcActionName]("REQUEST TIME OUT", "timeout", srcToken) - self._dataStructureLock.release() + with self._dataStructureLock: + # Don't crash if we try to remove an unknown token + if srcToken not in self._tokenPool: + self._logger.warn('Tried to remove non-existent token from pool: %s' % str(srcToken)) + return + # Remove the token + del self._tokenPool[srcToken] + # Need to unsubscribe? + self._shadowSubscribeStatusTable[srcActionName] -= 1 + if not self._isPersistentSubscribe and self._shadowSubscribeStatusTable.get(srcActionName) <= 0: + self._shadowSubscribeStatusTable[srcActionName] = 0 + self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, srcActionName) + # Notify time-out issue + if self._shadowSubscribeCallbackTable.get(srcActionName) is not None: + self._logger.info("Shadow request with token: " + str(srcToken) + " has timed out.") + self._shadowSubscribeCallbackTable[srcActionName]("REQUEST TIME OUT", "timeout", srcToken) def shadowGet(self, srcCallback, srcTimeout): """ @@ -228,28 +234,29 @@ def shadowGet(self, srcCallback, srcTimeout): The token used for tracing in this shadow request. 
""" - self._dataStructureLock.acquire() - # Update callback data structure - self._shadowSubscribeCallbackTable["get"] = srcCallback - # Update number of pending feedback - self._shadowSubscribeStatusTable["get"] += 1 - # clientToken - currentToken = self._tokenHandler.getNextToken() - self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["get", currentToken]) - self._basicJSONParserHandler.setString("{}") - self._basicJSONParserHandler.validateJSON() - self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken) - currentPayload = self._basicJSONParserHandler.regenerateString() - self._dataStructureLock.release() + with self._dataStructureLock: + # Update callback data structure + self._shadowSubscribeCallbackTable["get"] = srcCallback + # Update number of pending feedback + self._shadowSubscribeStatusTable["get"] += 1 + # clientToken + currentToken = self._tokenHandler.getNextToken() + self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["get", currentToken]) + self._basicJSONParserHandler.setString("{}") + self._basicJSONParserHandler.validateJSON() + self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken) + currentPayload = self._basicJSONParserHandler.regenerateString() # Two subscriptions if not self._isPersistentSubscribe or not self._isGetSubscribed: - self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "get", self._generalCallback) + self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "get", self.generalCallback) self._isGetSubscribed = True self._logger.info("Subscribed to get accepted/rejected topics for deviceShadow: " + self._shadowName) # One publish self._shadowManagerHandler.basicShadowPublish(self._shadowName, "get", currentPayload) # Start the timer - self._tokenPool[currentToken].start() + with self._dataStructureLock: + if currentToken in self._tokenPool: + self._tokenPool[currentToken].start() return currentToken def shadowDelete(self, srcCallback, 
srcTimeout): @@ -284,28 +291,29 @@ def shadowDelete(self, srcCallback, srcTimeout): The token used for tracing in this shadow request. """ - self._dataStructureLock.acquire() - # Update callback data structure - self._shadowSubscribeCallbackTable["delete"] = srcCallback - # Update number of pending feedback - self._shadowSubscribeStatusTable["delete"] += 1 - # clientToken - currentToken = self._tokenHandler.getNextToken() - self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["delete", currentToken]) - self._basicJSONParserHandler.setString("{}") - self._basicJSONParserHandler.validateJSON() - self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken) - currentPayload = self._basicJSONParserHandler.regenerateString() - self._dataStructureLock.release() + with self._dataStructureLock: + # Update callback data structure + self._shadowSubscribeCallbackTable["delete"] = srcCallback + # Update number of pending feedback + self._shadowSubscribeStatusTable["delete"] += 1 + # clientToken + currentToken = self._tokenHandler.getNextToken() + self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["delete", currentToken]) + self._basicJSONParserHandler.setString("{}") + self._basicJSONParserHandler.validateJSON() + self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken) + currentPayload = self._basicJSONParserHandler.regenerateString() # Two subscriptions if not self._isPersistentSubscribe or not self._isDeleteSubscribed: - self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "delete", self._generalCallback) + self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "delete", self.generalCallback) self._isDeleteSubscribed = True self._logger.info("Subscribed to delete accepted/rejected topics for deviceShadow: " + self._shadowName) # One publish self._shadowManagerHandler.basicShadowPublish(self._shadowName, "delete", currentPayload) # Start the timer - 
self._tokenPool[currentToken].start() + with self._dataStructureLock: + if currentToken in self._tokenPool: + self._tokenPool[currentToken].start() return currentToken def shadowUpdate(self, srcJSONPayload, srcCallback, srcTimeout): @@ -343,30 +351,30 @@ def shadowUpdate(self, srcJSONPayload, srcCallback, srcTimeout): """ # Validate JSON - JSONPayloadWithToken = None - currentToken = None - self._basicJSONParserHandler.setString(srcJSONPayload) - if self._basicJSONParserHandler.validateJSON(): - self._dataStructureLock.acquire() - # clientToken - currentToken = self._tokenHandler.getNextToken() - self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["update", currentToken]) - self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken) - JSONPayloadWithToken = self._basicJSONParserHandler.regenerateString() - # Update callback data structure - self._shadowSubscribeCallbackTable["update"] = srcCallback - # Update number of pending feedback - self._shadowSubscribeStatusTable["update"] += 1 - self._dataStructureLock.release() + if _validateJSON(srcJSONPayload): + with self._dataStructureLock: + self._basicJSONParserHandler.setString(srcJSONPayload) + self._basicJSONParserHandler.validateJSON() + # clientToken + currentToken = self._tokenHandler.getNextToken() + self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["update", currentToken]) + self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken) + JSONPayloadWithToken = self._basicJSONParserHandler.regenerateString() + # Update callback data structure + self._shadowSubscribeCallbackTable["update"] = srcCallback + # Update number of pending feedback + self._shadowSubscribeStatusTable["update"] += 1 # Two subscriptions if not self._isPersistentSubscribe or not self._isUpdateSubscribed: - self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "update", self._generalCallback) + self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, 
"update", self.generalCallback) self._isUpdateSubscribed = True self._logger.info("Subscribed to update accepted/rejected topics for deviceShadow: " + self._shadowName) # One publish self._shadowManagerHandler.basicShadowPublish(self._shadowName, "update", JSONPayloadWithToken) # Start the timer - self._tokenPool[currentToken].start() + with self._dataStructureLock: + if currentToken in self._tokenPool: + self._tokenPool[currentToken].start() else: raise ValueError("Invalid JSON file.") return currentToken @@ -398,12 +406,11 @@ def shadowRegisterDeltaCallback(self, srcCallback): None """ - self._dataStructureLock.acquire() - # Update callback data structure - self._shadowSubscribeCallbackTable["delta"] = srcCallback - self._dataStructureLock.release() + with self._dataStructureLock: + # Update callback data structure + self._shadowSubscribeCallbackTable["delta"] = srcCallback # One subscription - self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "delta", self._generalCallback) + self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "delta", self.generalCallback) self._logger.info("Subscribed to delta topic for deviceShadow: " + self._shadowName) def shadowUnregisterDeltaCallback(self): @@ -430,10 +437,9 @@ def shadowUnregisterDeltaCallback(self): None """ - self._dataStructureLock.acquire() - # Update callback data structure - del self._shadowSubscribeCallbackTable["delta"] - self._dataStructureLock.release() + with self._dataStructureLock: + # Update callback data structure + del self._shadowSubscribeCallbackTable["delta"] # One unsubscription self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, "delta") self._logger.info("Unsubscribed to delta topics for deviceShadow: " + self._shadowName) diff --git a/AWSIoTPythonSDK/core/shadow/shadowManager.py b/AWSIoTPythonSDK/core/shadow/shadowManager.py index 2572aef..3dafa74 100755 --- a/AWSIoTPythonSDK/core/shadow/shadowManager.py +++ 
b/AWSIoTPythonSDK/core/shadow/shadowManager.py @@ -1,5 +1,5 @@ # /* -# * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # * # * Licensed under the Apache License, Version 2.0 (the "License"). # * You may not use this file except in compliance with the License. @@ -57,32 +57,27 @@ def __init__(self, srcMQTTCore): self._mqttCoreHandler = srcMQTTCore self._shadowSubUnsubOperationLock = Lock() - def getClientID(self): - return self._mqttCoreHandler.getClientID() - def basicShadowPublish(self, srcShadowName, srcShadowAction, srcPayload): currentShadowAction = _shadowAction(srcShadowName, srcShadowAction) self._mqttCoreHandler.publish(currentShadowAction.getTopicGeneral(), srcPayload, 0, False) def basicShadowSubscribe(self, srcShadowName, srcShadowAction, srcCallback): - self._shadowSubUnsubOperationLock.acquire() - currentShadowAction = _shadowAction(srcShadowName, srcShadowAction) - if currentShadowAction.isDelta: - self._mqttCoreHandler.subscribe(currentShadowAction.getTopicDelta(), 0, srcCallback) - else: - self._mqttCoreHandler.subscribe(currentShadowAction.getTopicAccept(), 0, srcCallback) - self._mqttCoreHandler.subscribe(currentShadowAction.getTopicReject(), 0, srcCallback) - time.sleep(2) - self._shadowSubUnsubOperationLock.release() + with self._shadowSubUnsubOperationLock: + currentShadowAction = _shadowAction(srcShadowName, srcShadowAction) + if currentShadowAction.isDelta: + self._mqttCoreHandler.subscribe(currentShadowAction.getTopicDelta(), 0, srcCallback) + else: + self._mqttCoreHandler.subscribe(currentShadowAction.getTopicAccept(), 0, srcCallback) + self._mqttCoreHandler.subscribe(currentShadowAction.getTopicReject(), 0, srcCallback) + time.sleep(2) def basicShadowUnsubscribe(self, srcShadowName, srcShadowAction): - self._shadowSubUnsubOperationLock.acquire() - currentShadowAction = _shadowAction(srcShadowName, srcShadowAction) - if 
currentShadowAction.isDelta: - self._mqttCoreHandler.unsubscribe(currentShadowAction.getTopicDelta()) - else: - self._logger.debug(currentShadowAction.getTopicAccept()) - self._mqttCoreHandler.unsubscribe(currentShadowAction.getTopicAccept()) - self._logger.debug(currentShadowAction.getTopicReject()) - self._mqttCoreHandler.unsubscribe(currentShadowAction.getTopicReject()) - self._shadowSubUnsubOperationLock.release() + with self._shadowSubUnsubOperationLock: + currentShadowAction = _shadowAction(srcShadowName, srcShadowAction) + if currentShadowAction.isDelta: + self._mqttCoreHandler.unsubscribe(currentShadowAction.getTopicDelta()) + else: + self._logger.debug(currentShadowAction.getTopicAccept()) + self._mqttCoreHandler.unsubscribe(currentShadowAction.getTopicAccept()) + self._logger.debug(currentShadowAction.getTopicReject()) + self._mqttCoreHandler.unsubscribe(currentShadowAction.getTopicReject()) diff --git a/AWSIoTPythonSDK/core/util/enums.py b/AWSIoTPythonSDK/core/util/enums.py new file mode 100644 index 0000000..3aa3d2f --- /dev/null +++ b/AWSIoTPythonSDK/core/util/enums.py @@ -0,0 +1,19 @@ +# /* +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * +# * Licensed under the Apache License, Version 2.0 (the "License"). +# * You may not use this file except in compliance with the License. +# * A copy of the License is located at +# * +# * http://aws.amazon.com/apache2.0 +# * +# * or in the "license" file accompanying this file. This file is distributed +# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# * express or implied. See the License for the specific language governing +# * permissions and limitations under the License. 
+# */ + + +class DropBehaviorTypes(object): + DROP_OLDEST = 0 + DROP_NEWEST = 1 diff --git a/AWSIoTPythonSDK/core/util/offlinePublishQueue.py b/AWSIoTPythonSDK/core/util/offlinePublishQueue.py deleted file mode 100755 index 8ba2d44..0000000 --- a/AWSIoTPythonSDK/core/util/offlinePublishQueue.py +++ /dev/null @@ -1,92 +0,0 @@ -# /* -# * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# * -# * Licensed under the Apache License, Version 2.0 (the "License"). -# * You may not use this file except in compliance with the License. -# * A copy of the License is located at -# * -# * http://aws.amazon.com/apache2.0 -# * -# * or in the "license" file accompanying this file. This file is distributed -# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# * express or implied. See the License for the specific language governing -# * permissions and limitations under the License. -# */ - -# This class implements the offline Publish Queue, with configurable length and drop behaviors. -# This queue will be used as the offline Publish Queue for all message outside Paho as an option -# to publish to when the client is offline. -# DROP_OLDEST: Drop the head of the queue when the size limit is reached. -# DROP_NEWEST: Drop the new incoming elements when the size limit is reached. 
- -import logging - -class offlinePublishQueue(list): - - _DROPBEHAVIOR_OLDEST = 0 - _DROPBEHAVIOR_NEWEST = 1 - - APPEND_FAILURE_QUEUE_FULL = -1 - APPEND_FAILURE_QUEUE_DISABLED = -2 - APPEND_SUCCESS = 0 - - _logger = logging.getLogger(__name__) - - def __init__(self, srcMaximumSize, srcDropBehavior=1): - if not isinstance(srcMaximumSize, int) or not isinstance(srcDropBehavior, int): - self._logger.error("init: MaximumSize/DropBehavior must be integer.") - raise TypeError("MaximumSize/DropBehavior must be integer.") - if srcDropBehavior != self._DROPBEHAVIOR_OLDEST and srcDropBehavior != self._DROPBEHAVIOR_NEWEST: - self._logger.error("init: Drop behavior not supported.") - raise ValueError("Drop behavior not supported.") - list.__init__([]) - self._dropBehavior = srcDropBehavior - # When self._maximumSize > 0, queue is limited - # When self._maximumSize == 0, queue is disabled - # When self._maximumSize < 0. queue is infinite - self._maximumSize = srcMaximumSize - - def _isEnabled(self): - return self._maximumSize != 0 - - def _needDropMessages(self): - # Need to drop messages when: - # 1. Queue is limited and full - # 2. Queue is disabled - isQueueFull = len(self) >= self._maximumSize - isQueueLimited = self._maximumSize > 0 - isQueueDisabled = not self._isEnabled() - return (isQueueFull and isQueueLimited) or isQueueDisabled - - def setQueueBehaviorDropNewest(self): - self._dropBehavior = self._DROPBEHAVIOR_NEWEST - - def setQueueBehaviorDropOldest(self): - self._dropBehavior = self._DROPBEHAVIOR_OLDEST - - # Override - # Append to a queue with a limited size. 
- # Return APPEND_SUCCESS if the append is successful - # Return APPEND_FAILURE_QUEUE_FULL if the append failed because the queue is full - # Return APPEND_FAILURE_QUEUE_DISABLED if the append failed because the queue is disabled - def append(self, srcData): - ret = self.APPEND_SUCCESS - if self._isEnabled(): - if self._needDropMessages(): - # We should drop the newest - if self._dropBehavior == self._DROPBEHAVIOR_NEWEST: - self._logger.warn("append: Full queue. Drop the newest: " + str(srcData)) - ret = self.APPEND_FAILURE_QUEUE_FULL - # We should drop the oldest - else: - currentOldest = super(offlinePublishQueue, self).pop(0) - self._logger.warn("append: Full queue. Drop the oldest: " + str(currentOldest)) - super(offlinePublishQueue, self).append(srcData) - ret = self.APPEND_FAILURE_QUEUE_FULL - else: - self._logger.debug("append: Add new element: " + str(srcData)) - super(offlinePublishQueue, self).append(srcData) - else: - self._logger.debug("append: Queue is disabled. Drop the message: " + str(srcData)) - ret = self.APPEND_FAILURE_QUEUE_DISABLED - return ret diff --git a/AWSIoTPythonSDK/core/util/progressiveBackoffCore.py b/AWSIoTPythonSDK/core/util/progressiveBackoffCore.py deleted file mode 100755 index cc56533..0000000 --- a/AWSIoTPythonSDK/core/util/progressiveBackoffCore.py +++ /dev/null @@ -1,91 +0,0 @@ -# /* -# * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# * -# * Licensed under the Apache License, Version 2.0 (the "License"). -# * You may not use this file except in compliance with the License. -# * A copy of the License is located at -# * -# * http://aws.amazon.com/apache2.0 -# * -# * or in the "license" file accompanying this file. This file is distributed -# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# * express or implied. See the License for the specific language governing -# * permissions and limitations under the License. 
-# */ - - # This class implements the progressive backoff logic for auto-reconnect. - # It manages the reconnect wait time for the current reconnect, controling - # when to increase it and when to reset it. - -import time -import threading -import logging - - -class progressiveBackoffCore: - - # Logger - _logger = logging.getLogger(__name__) - - def __init__(self, srcBaseReconnectTimeSecond=1, srcMaximumReconnectTimeSecond=32, srcMinimumConnectTimeSecond=20): - # The base reconnection time in seconds, default 1 - self._baseReconnectTimeSecond = srcBaseReconnectTimeSecond - # The maximum reconnection time in seconds, default 32 - self._maximumReconnectTimeSecond = srcMaximumReconnectTimeSecond - # The minimum time in milliseconds that a connection must be maintained in order to be considered stable - # Default 20 - self._minimumConnectTimeSecond = srcMinimumConnectTimeSecond - # Current backOff time in seconds, init to equal to 0 - self._currentBackoffTimeSecond = 1 - # Handler for timer - self._resetBackoffTimer = None - - # For custom progressiveBackoff timing configuration - def configTime(self, srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond): - if srcBaseReconnectTimeSecond < 0 or srcMaximumReconnectTimeSecond < 0 or srcMinimumConnectTimeSecond < 0: - self._logger.error("init: Negative time configuration detected.") - raise ValueError("Negative time configuration detected.") - if srcBaseReconnectTimeSecond >= srcMinimumConnectTimeSecond: - self._logger.error("init: Min connect time should be bigger than base reconnect time.") - raise ValueError("Min connect time should be bigger than base reconnect time.") - self._baseReconnectTimeSecond = srcBaseReconnectTimeSecond - self._maximumReconnectTimeSecond = srcMaximumReconnectTimeSecond - self._minimumConnectTimeSecond = srcMinimumConnectTimeSecond - self._currentBackoffTimeSecond = 1 - - # Block the reconnect logic for _currentBackoffTimeSecond - # Update the 
currentBackoffTimeSecond for the next reconnect - # Cancel the in-waiting timer for resetting backOff time - # This should get called only when a disconnect/reconnect happens - def backOff(self): - self._logger.debug("backOff: current backoff time is: " + str(self._currentBackoffTimeSecond) + " sec.") - if self._resetBackoffTimer is not None: - # Cancel the timer - self._resetBackoffTimer.cancel() - # Block the reconnect logic - time.sleep(self._currentBackoffTimeSecond) - # Update the backoff time - if self._currentBackoffTimeSecond == 0: - # This is the first attempt to connect, set it to base - self._currentBackoffTimeSecond = self._baseReconnectTimeSecond - else: - # r_cur = min(2^n*r_base, r_max) - self._currentBackoffTimeSecond = min(self._maximumReconnectTimeSecond, self._currentBackoffTimeSecond * 2) - - # Start the timer for resetting _currentBackoffTimeSecond - # Will be cancelled upon calling backOff - def startStableConnectionTimer(self): - self._resetBackoffTimer = threading.Timer(self._minimumConnectTimeSecond, self._connectionStableThenResetBackoffTime) - self._resetBackoffTimer.start() - - def stopStableConnectionTimer(self): - if self._resetBackoffTimer is not None: - # Cancel the timer - self._resetBackoffTimer.cancel() - - # Timer callback to reset _currentBackoffTimeSecond - # If the connection is stable for longer than _minimumConnectTimeSecond, - # reset the currentBackoffTimeSecond to _baseReconnectTimeSecond - def _connectionStableThenResetBackoffTime(self): - self._logger.debug("stableConnection: Resetting the backoff time to: " + str(self._baseReconnectTimeSecond) + " sec.") - self._currentBackoffTimeSecond = self._baseReconnectTimeSecond diff --git a/AWSIoTPythonSDK/core/util/providers.py b/AWSIoTPythonSDK/core/util/providers.py new file mode 100644 index 0000000..d09f8a0 --- /dev/null +++ b/AWSIoTPythonSDK/core/util/providers.py @@ -0,0 +1,102 @@ +# /* +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# * +# * Licensed under the Apache License, Version 2.0 (the "License"). +# * You may not use this file except in compliance with the License. +# * A copy of the License is located at +# * +# * http://aws.amazon.com/apache2.0 +# * +# * or in the "license" file accompanying this file. This file is distributed +# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# * express or implied. See the License for the specific language governing +# * permissions and limitations under the License. +# */ + + +class CredentialsProvider(object): + + def __init__(self): + self._ca_path = "" + + def set_ca_path(self, ca_path): + self._ca_path = ca_path + + def get_ca_path(self): + return self._ca_path + + +class CertificateCredentialsProvider(CredentialsProvider): + + def __init__(self): + CredentialsProvider.__init__(self) + self._cert_path = "" + self._key_path = "" + + def set_cert_path(self,cert_path): + self._cert_path = cert_path + + def set_key_path(self, key_path): + self._key_path = key_path + + def get_cert_path(self): + return self._cert_path + + def get_key_path(self): + return self._key_path + + +class IAMCredentialsProvider(CredentialsProvider): + + def __init__(self): + CredentialsProvider.__init__(self) + self._aws_access_key_id = "" + self._aws_secret_access_key = "" + self._aws_session_token = "" + + def set_access_key_id(self, access_key_id): + self._aws_access_key_id = access_key_id + + def set_secret_access_key(self, secret_access_key): + self._aws_secret_access_key = secret_access_key + + def set_session_token(self, session_token): + self._aws_session_token = session_token + + def get_access_key_id(self): + return self._aws_access_key_id + + def get_secret_access_key(self): + return self._aws_secret_access_key + + def get_session_token(self): + return self._aws_session_token + + +class EndpointProvider(object): + + def __init__(self): + self._host = "" + self._port = -1 + + def set_host(self, host): + self._host = host + + def 
set_port(self, port): + self._port = port + + def get_host(self): + return self._host + + def get_port(self): + return self._port + +class CiphersProvider(object): + def __init__(self): + self._ciphers = None + + def set_ciphers(self, ciphers=None): + self._ciphers = ciphers + + def get_ciphers(self): + return self._ciphers diff --git a/AWSIoTPythonSDK/core/util/sigV4Core.py b/AWSIoTPythonSDK/core/util/sigV4Core.py deleted file mode 100755 index 0b22dab..0000000 --- a/AWSIoTPythonSDK/core/util/sigV4Core.py +++ /dev/null @@ -1,187 +0,0 @@ -# /* -# * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# * -# * Licensed under the Apache License, Version 2.0 (the "License"). -# * You may not use this file except in compliance with the License. -# * A copy of the License is located at -# * -# * http://aws.amazon.com/apache2.0 -# * -# * or in the "license" file accompanying this file. This file is distributed -# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# * express or implied. See the License for the specific language governing -# * permissions and limitations under the License. 
-# */ - -# This class implements the sigV4 signing process and return the signed URL for connection - -import os -import datetime -import hashlib -import hmac -try: - from urllib.parse import quote # Python 3+ -except ImportError: - from urllib import quote -import logging -# INI config file handling -try: - from configparser import ConfigParser # Python 3+ - from configparser import NoOptionError - from configparser import NoSectionError -except ImportError: - from ConfigParser import ConfigParser - from ConfigParser import NoOptionError - from ConfigParser import NoSectionError - -class sigV4Core: - - _logger = logging.getLogger(__name__) - - def __init__(self): - self._aws_access_key_id = "" - self._aws_secret_access_key = "" - self._aws_session_token = "" - self._credentialConfigFilePath = "~/.aws/credentials" - - def setIAMCredentials(self, srcAWSAccessKeyID, srcAWSSecretAccessKey, srcAWSSessionToken): - self._aws_access_key_id = srcAWSAccessKeyID - self._aws_secret_access_key = srcAWSSecretAccessKey - self._aws_session_token = srcAWSSessionToken - - def _createAmazonDate(self): - # Returned as a unicode string in Py3.x - amazonDate = [] - currentTime = datetime.datetime.utcnow() - YMDHMS = currentTime.strftime('%Y%m%dT%H%M%SZ') - YMD = YMDHMS[0:YMDHMS.index('T')] - amazonDate.append(YMD) - amazonDate.append(YMDHMS) - return amazonDate - - def _sign(self, key, message): - # Returned as a utf-8 byte string in Py3.x - return hmac.new(key, message.encode('utf-8'), hashlib.sha256).digest() - - def _getSignatureKey(self, key, dateStamp, regionName, serviceName): - # Returned as a utf-8 byte string in Py3.x - kDate = self._sign(('AWS4' + key).encode('utf-8'), dateStamp) - kRegion = self._sign(kDate, regionName) - kService = self._sign(kRegion, serviceName) - kSigning = self._sign(kService, 'aws4_request') - return kSigning - - def _checkIAMCredentials(self): - # Check custom config - ret = self._checkKeyInCustomConfig() - # Check environment variables - if not ret: 
- ret = self._checkKeyInEnv() - # Check files - if not ret: - ret = self._checkKeyInFiles() - # All credentials returned as unicode strings in Py3.x - return ret - - def _checkKeyInEnv(self): - ret = dict() - self._aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID') - self._aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY') - self._aws_session_token = os.environ.get('AWS_SESSION_TOKEN') - if self._aws_access_key_id is not None and self._aws_secret_access_key is not None: - ret["aws_access_key_id"] = self._aws_access_key_id - ret["aws_secret_access_key"] = self._aws_secret_access_key - # We do not necessarily need session token... - if self._aws_session_token is not None: - ret["aws_session_token"] = self._aws_session_token - self._logger.debug("IAM credentials from env var.") - return ret - - def _checkKeyInINIDefault(self, srcConfigParser, sectionName): - ret = dict() - # Check aws_access_key_id and aws_secret_access_key - try: - ret["aws_access_key_id"] = srcConfigParser.get(sectionName, "aws_access_key_id") - ret["aws_secret_access_key"] = srcConfigParser.get(sectionName, "aws_secret_access_key") - except NoOptionError: - self._logger.warn("Cannot find IAM keyID/secretKey in credential file.") - # We do not continue searching if we cannot even get IAM id/secret right - if len(ret) == 2: - # Check aws_session_token, optional - try: - ret["aws_session_token"] = srcConfigParser.get(sectionName, "aws_session_token") - except NoOptionError: - self._logger.debug("No AWS Session Token found.") - return ret - - def _checkKeyInFiles(self): - credentialFile = None - credentialConfig = None - ret = dict() - # Should be compatible with aws cli default credential configuration - # *NIX/Windows - try: - # See if we get the file - credentialConfig = ConfigParser() - credentialFilePath = os.path.expanduser(self._credentialConfigFilePath) # Is it compatible with windows? 
\/ - credentialConfig.read(credentialFilePath) - # Now we have the file, start looking for credentials... - # 'default' section - ret = self._checkKeyInINIDefault(credentialConfig, "default") - if not ret: - # 'DEFAULT' section - ret = self._checkKeyInINIDefault(credentialConfig, "DEFAULT") - self._logger.debug("IAM credentials from file.") - except IOError: - self._logger.debug("No IAM credential configuration file in " + credentialFilePath) - except NoSectionError: - self._logger.error("Cannot find IAM 'default' section.") - return ret - - def _checkKeyInCustomConfig(self): - ret = dict() - if self._aws_access_key_id != "" and self._aws_secret_access_key != "": - ret["aws_access_key_id"] = self._aws_access_key_id - ret["aws_secret_access_key"] = self._aws_secret_access_key - # We do not necessarily need session token... - if self._aws_session_token != "": - ret["aws_session_token"] = self._aws_session_token - self._logger.debug("IAM credentials from custom config.") - return ret - - def createWebsocketEndpoint(self, host, port, region, method, awsServiceName, path): - # Return the endpoint as unicode string in 3.x - # Gather all the facts - amazonDate = self._createAmazonDate() - amazonDateSimple = amazonDate[0] # Unicode in 3.x - amazonDateComplex = amazonDate[1] # Unicode in 3.x - allKeys = self._checkIAMCredentials() # Unicode in 3.x - hasCredentialsNecessaryForWebsocket = "aws_access_key_id" in allKeys.keys() and "aws_secret_access_key" in allKeys.keys() - if not hasCredentialsNecessaryForWebsocket: - return "" - else: - keyID = allKeys["aws_access_key_id"] - secretKey = allKeys["aws_secret_access_key"] - queryParameters = "X-Amz-Algorithm=AWS4-HMAC-SHA256" + \ - "&X-Amz-Credential=" + keyID + "%2F" + amazonDateSimple + "%2F" + region + "%2F" + awsServiceName + "%2Faws4_request" + \ - "&X-Amz-Date=" + amazonDateComplex + \ - "&X-Amz-Expires=86400" + \ - "&X-Amz-SignedHeaders=host" # Unicode in 3.x - hashedPayload = 
hashlib.sha256(str("").encode('utf-8')).hexdigest() # Unicode in 3.x - # Create the string to sign - signedHeaders = "host" - canonicalHeaders = "host:" + host + "\n" - canonicalRequest = method + "\n" + path + "\n" + queryParameters + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedPayload # Unicode in 3.x - hashedCanonicalRequest = hashlib.sha256(str(canonicalRequest).encode('utf-8')).hexdigest() # Unicoede in 3.x - stringToSign = "AWS4-HMAC-SHA256\n" + amazonDateComplex + "\n" + amazonDateSimple + "/" + region + "/" + awsServiceName + "/aws4_request\n" + hashedCanonicalRequest # Unicode in 3.x - # Sign it - signingKey = self._getSignatureKey(secretKey, amazonDateSimple, region, awsServiceName) - signature = hmac.new(signingKey, (stringToSign).encode("utf-8"), hashlib.sha256).hexdigest() - # generate url - url = "wss://" + host + ":" + str(port) + path + '?' + queryParameters + "&X-Amz-Signature=" + signature - # See if we have STS token, if we do, add it - if "aws_session_token" in allKeys.keys(): - aws_session_token = allKeys["aws_session_token"] - url += "&X-Amz-Security-Token=" + quote(aws_session_token.encode("utf-8")) # Unicode in 3.x - self._logger.debug("createWebsocketEndpoint: Websocket URL: " + url) - return url diff --git a/AWSIoTPythonSDK/exception/AWSIoTExceptions.py b/AWSIoTPythonSDK/exception/AWSIoTExceptions.py index 0ddfa73..0de5401 100755 --- a/AWSIoTPythonSDK/exception/AWSIoTExceptions.py +++ b/AWSIoTPythonSDK/exception/AWSIoTExceptions.py @@ -1,5 +1,5 @@ # /* -# * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # * # * Licensed under the Apache License, Version 2.0 (the "License"). # * You may not use this file except in compliance with the License. 
@@ -80,11 +80,31 @@ def __init__(self, errorCode): self.message = "Subscribe Error: " + str(errorCode) +class subscribeQueueFullException(operationError.operationError): + def __init__(self): + self.message = "Internal Subscribe Queue Full" + + +class subscribeQueueDisabledException(operationError.operationError): + def __init__(self): + self.message = "Offline subscribe request dropped because queueing is disabled" + + class unsubscribeError(operationError.operationError): def __init__(self, errorCode): self.message = "Unsubscribe Error: " + str(errorCode) +class unsubscribeQueueFullException(operationError.operationError): + def __init__(self): + self.message = "Internal Unsubscribe Queue Full" + + +class unsubscribeQueueDisabledException(operationError.operationError): + def __init__(self): + self.message = "Offline unsubscribe request dropped because queueing is disabled" + + # Websocket Error class wssNoKeyInEnvironmentError(operationError.operationError): def __init__(self): @@ -94,3 +114,40 @@ def __init__(self): class wssHandShakeError(operationError.operationError): def __init__(self): self.message = "Error in WSS handshake." 
+ + +# Greengrass Discovery Error +class DiscoveryDataNotFoundException(operationError.operationError): + def __init__(self): + self.message = "No discovery data found" + + +class DiscoveryTimeoutException(operationTimeoutException.operationTimeoutException): + def __init__(self, message="Discovery request timed out"): + self.message = message + + +class DiscoveryInvalidRequestException(operationError.operationError): + def __init__(self): + self.message = "Invalid discovery request" + + +class DiscoveryUnauthorizedException(operationError.operationError): + def __init__(self): + self.message = "Discovery request not authorized" + + +class DiscoveryThrottlingException(operationError.operationError): + def __init__(self): + self.message = "Too many discovery requests" + + +class DiscoveryFailure(operationError.operationError): + def __init__(self, message): + self.message = message + + +# Client Error +class ClientError(Exception): + def __init__(self, message): + self.message = message diff --git a/AWSIoTPythonSDK/exception/operationError.py b/AWSIoTPythonSDK/exception/operationError.py index efbb399..1c86dfc 100755 --- a/AWSIoTPythonSDK/exception/operationError.py +++ b/AWSIoTPythonSDK/exception/operationError.py @@ -1,5 +1,5 @@ # /* -# * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # * # * Licensed under the Apache License, Version 2.0 (the "License"). # * You may not use this file except in compliance with the License. diff --git a/AWSIoTPythonSDK/exception/operationTimeoutException.py b/AWSIoTPythonSDK/exception/operationTimeoutException.py index 48d4f15..737154e 100755 --- a/AWSIoTPythonSDK/exception/operationTimeoutException.py +++ b/AWSIoTPythonSDK/exception/operationTimeoutException.py @@ -1,5 +1,5 @@ # /* -# * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * Copyright 2010-2017 Amazon.com, Inc. 
or its affiliates. All Rights Reserved. # * # * Licensed under the Apache License, Version 2.0 (the "License"). # * You may not use this file except in compliance with the License. diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9c48d50..765c557 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,93 @@ CHANGELOG ========= +1.5.5 +===== +* chore: Update minimum Python version based on current supportable levels + +1.5.4 +===== +* chore: CD pipeline flushing try #1 + +1.5.3 +===== +* improvement: Support Python3.12+ by conditionally removing deprecated API usage + +1.4.9 +===== +* bugfix: Fixing possible race condition with timer in deviceShadow. + +1.4.8 +===== +* improvement: Added support for subscription acknowledgement callbacks while offline or resubscribing + +1.4.7 +===== +* improvement: Added connection establishment control through client socket factory option + +1.4.6 +===== +* bugfix: Use non-deprecated ssl API to specify ALPN when doing Greengrass discovery + +1.4.5 +===== +* improvement: Added validation to mTLS arguments in basicDiscovery + +1.4.3 +===== +* bugfix: [Issue #150](https://github.com/aws/aws-iot-device-sdk-python/issues/150)Fix for ALPN in Python 3.7 + +1.4.2 +===== +* bugfix: Websocket handshake supports Amazon Trust Store (ats) endpoints +* bugfix: Remove default port number in samples, which prevented WebSocket mode from using 443 +* bugfix: jobsSample print statements compatible with Python 3.x +* improvement: Small fixes to IoT Jobs documentation + + +1.4.0 +===== +* bugfix:Issue `#136 ` +* bugfix:Issue:`#124 ` +* improvement:Expose the missing getpeercert() from SecuredWebsocket class +* improvement:Enforce sending host header in the outbound discovery request +* improvement:Ensure credentials non error are properly handled and communicated to application level when creating wss endpoint +* feature:Add support for ALPN, along with API docs, sample and updated README +* feature:Add support for IoT Jobs, along with API docs, 
sample and updated README +* feature:Add command line option to allow port number override + +1.3.1 +===== +* bugfix:Issue:`#67 `__ +* bugfix:Fixed a dead lock issue when client async API is called within the event callback +* bugfix:Updated README and API documentation to provide clear usage information on sync/async API and callbacks +* improvement:Added a new sample to show API usage within callbacks + +1.3.0 +===== +* bugfix:WebSocket handshake response timeout and error escalation +* bugfix:Prevent GG discovery from crashing if Metadata field is None +* bugfix:Fix the client object reusability issue +* bugfix:Prevent NPE due to shadow operation token not found in the pool +* improvement:Split the publish and subscribe operations in basicPubSub.py sample +* improvement:Updated default connection keep-alive interval to 600 seconds +* feature:AWSIoTMQTTClient:New API for username and password configuration +* feature:AWSIoTMQTTShadowClient:New API for username and password configuration +* feature:AWSIoTMQTTClient:New API for enabling/disabling metrics collection +* feature:AWSIoTMQTTShadowClient:New API for enabling/disabling metrics collection + +1.2.0 +===== +* improvement:AWSIoTMQTTClient:Improved synchronous API backend for ACK tracking +* feature:AWSIoTMQTTClient:New API for asynchronous API +* feature:AWSIoTMQTTClient:Expose general notification callbacks for online, offline and message arrival +* feature:AWSIoTMQTTShadowClient:Expose general notification callbacks for online, offline and message arrival +* feature:AWSIoTMQTTClient:Extend offline queueing to include offline subscribe/unsubscribe requests +* feature:DiscoveryInfoProvider:Support for Greengrass discovery +* bugfix:Pull request:`#50 `__ +* bugfix:Pull request:`#51 `__ +* bugfix:Issue:`#52 `__ + 1.1.2 ===== * bugfix:Issue:`#28 `__ diff --git a/README.rst b/README.rst index 6a7eaf1..ba88218 100755 --- a/README.rst +++ b/README.rst @@ -1,3 +1,9 @@ +New Version Available 
+============================= +A new AWS IoT Device SDK is `now available <https://github.com/awslabs/aws-iot-device-sdk-python-v2>`__. It is a complete rework, built to improve reliability, performance, and security. We invite your feedback! + +This SDK will no longer receive feature updates, but will receive security updates. + AWS IoT Device SDK for Python ============================= @@ -40,8 +46,9 @@ IoT: - MQTT (over TLS 1.2) with X.509 certificate-based mutual authentication. - MQTT over the WebSocket protocol with AWS Signature Version 4 authentication. +- MQTT (over TLS 1.2) with X.509 certificate-based mutual authentication with TLS ALPN extension. -For MQTT over TLS (port 8883), a valid certificate and a private key are +For MQTT over TLS (port 8883 and port 443), a valid certificate and a private key are required for authentication. For MQTT over the WebSocket protocol (port 443), a valid AWS Identity and Access Management (IAM) access key ID and secret access key pair are required for authentication. @@ -63,20 +70,10 @@ also allows the use of the same connection for shadow operations and non-shadow, Installation ~~~~~~~~~~~~ -Minimum Requirements +Requirements ____________________ -- Python 2.7+ or Python 3.3+ -- OpenSSL version 1.0.1+ (TLS version 1.2) compiled with the Python executable for - X.509 certificate-based mutual authentication - - To check your version of OpenSSL, use the following command in a Python interpreter: - - .. code-block:: python - - >>> import ssl - >>> ssl.OPENSSL_VERSION - +- Python3.8+. The SDK has worked for older Python versions in the past, but they are no longer formally supported. Over time, expect the minimum Python version to loosely track the minimum non-end-of-life version. Install from pip ________________ @@ -110,6 +107,24 @@ The SDK zip file is available `here `__. + CA `__. Use the AWS IoT console to create and download the certificate and private key. 
You must specify the location of these files when you initialize the client. @@ -130,8 +145,8 @@ types: For the Websocket with Signature Version 4 authentication type. You will need IAM credentials: an access key ID, a secret access key, and an optional session token. You must also download the `AWS IoT root - CA `__. - You can specify the IAM credentails by: + CA `__. + You can specify the IAM credentials by: - Passing method parameters @@ -221,6 +236,8 @@ You can initialize and configure the client like this: myMQTTClient.configureEndpoint("YOUR.ENDPOINT", 8883) # For Websocket # myMQTTClient.configureEndpoint("YOUR.ENDPOINT", 443) + # For TLS mutual authentication with TLS ALPN extension + # myMQTTClient.configureEndpoint("YOUR.ENDPOINT", 443) myMQTTClient.configureCredentials("YOUR/ROOT/CA/PATH", "PRIVATE/KEY/PATH", "CERTIFICATE/PATH") # For Websocket, we only need to configure the root CA # myMQTTClient.configureCredentials("YOUR/ROOT/CA/PATH") @@ -261,6 +278,8 @@ You can initialize and configure the client like this: myShadowClient.configureEndpoint("YOUR.ENDPOINT", 8883) # For Websocket # myShadowClient.configureEndpoint("YOUR.ENDPOINT", 443) + # For TLS mutual authentication with TLS ALPN extension + # myShadowClient.configureEndpoint("YOUR.ENDPOINT", 443) myShadowClient.configureCredentials("YOUR/ROOT/CA/PATH", "PRIVATE/KEY/PATH", "CERTIFICATE/PATH") # For Websocket, we only need to configure the root CA # myShadowClient.configureCredentials("YOUR/ROOT/CA/PATH") @@ -292,13 +311,132 @@ MQTT operations along with shadow operations: myMQTTClient = myShadowClient.getMQTTConnection() myMQTTClient.publish("plainMQTTTopic", "Payload", 1) +AWSIoTMQTTThingJobsClient +__________________ + +This is the client class used for jobs operations with AWS IoT. See docs here: +https://docs.aws.amazon.com/iot/latest/developerguide/iot-jobs.html +You can initialize and configure the client like this: + +.. 
code-block:: python + + from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTThingJobsClient + + # For certificate based connection + myJobsClient = AWSIoTMQTTThingJobsClient("myClientID", "myThingName") + # For Websocket connection + # myJobsClient = AWSIoTMQTTThingJobsClient("myClientID", "myThingName", useWebsocket=True) + # Configurations + # For TLS mutual authentication + myJobsClient.configureEndpoint("YOUR.ENDPOINT", 8883) + # For Websocket + # myJobsClient.configureEndpoint("YOUR.ENDPOINT", 443) + myJobsClient.configureCredentials("YOUR/ROOT/CA/PATH", "PRIVATE/KEY/PATH", "CERTIFICATE/PATH") + # For Websocket, we only need to configure the root CA + # myJobsClient.configureCredentials("YOUR/ROOT/CA/PATH") + myJobsClient.configureConnectDisconnectTimeout(10) # 10 sec + myJobsClient.configureMQTTOperationTimeout(5) # 5 sec + ... + +For job operations, your script will look like this: + +.. code-block:: python + + ... + myJobsClient.connect() + # Create a subscription for $notify-next topic + myJobsClient.createJobSubscription(notifyNextCallback, jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) + # Create a subscription for update-job-execution accepted response topic + myJobsClient.createJobSubscription(updateSuccessfulCallback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, '+') + # Send a message to start the next pending job (if any) + myJobsClient.sendJobsStartNext(statusDetailsDict) + # Send a message to update a successfully completed job + myJobsClient.sendJobsUpdate(jobId, jobExecutionStatus.JOB_EXECUTION_SUCCEEDED, statusDetailsDict) + ... + +You can also retrieve the MQTTClient(MQTT connection) to perform plain +MQTT operations along with shadow operations: + +.. code-block:: python + + myMQTTClient = myJobsClient.getMQTTConnection() + myMQTTClient.publish("plainMQTTTopic", "Payload", 1) + +DiscoveryInfoProvider +_____________________ + +This is the client class for device discovery process with AWS IoT Greengrass. 
+You can initialize and configure the client like this: + +.. code-block:: python + + from AWSIoTPythonSDK.core.greengrass.discovery.providers import DiscoveryInfoProvider + + discoveryInfoProvider = DiscoveryInfoProvider() + discoveryInfoProvider.configureEndpoint("YOUR.IOT.ENDPOINT") + discoveryInfoProvider.configureCredentials("YOUR/ROOT/CA/PATH", "CERTIFICATE/PATH", "PRIVATE/KEY/PATH") + discoveryInfoProvider.configureTimeout(10) # 10 sec + +To perform the discovery process for a Greengrass Aware Device (GGAD) that belongs to a deployed group, your script +should look like this: + +.. code-block:: python + + discoveryInfo = discoveryInfoProvider.discover("myGGADThingName") + # I know nothing about the group/core I want to connect to. I want to iterate through all cores and find out. + coreList = discoveryInfo.getAllCores() + groupIdCAList = discoveryInfo.getAllCas() # list([(groupId, ca), ...]) + # I know nothing about the group/core I want to connect to. I want to iterate through all groups and find out. + groupList = discoveryInfo.getAllGroups() + # I know exactly which group, which core and which connectivity info I need to connect. + connectivityInfo = discoveryInfo.toObjectAtGroupLevel()["YOUR_GROUP_ID"] + .getCoreConnectivityInfo("YOUR_CORE_THING_ARN") + .getConnectivityInfo("YOUR_CONNECTIVITY_ID") + # Connecting logic follows... + ... + +For more information about discovery information access at group/core/connectivity info set level, please refer to the +API documentation for ``AWSIoTPythonSDK.core.greengrass.discovery.models``, +`Greengrass Discovery documentation `__ +or `Greengrass overall documentation `__. 
+ + +Synchronous APIs and Asynchronous APIs +______________________________________ + +Beginning with Release v1.2.0, SDK provides asynchronous APIs and enforces synchronous API behaviors for MQTT operations, +which includes: +- connect/connectAsync +- disconnect/disconnectAsync +- publish/publishAsync +- subscribe/subscribeAsync +- unsubscribe/unsubscribeAsync + +- Asynchronous APIs +Asynchronous APIs translate the invocation into MQTT packet and forward it to the underneath connection to be sent out. +They return immediately once packets are out for delivery, regardless of whether the corresponding ACKs, if any, have +been received. Users can specify their own callbacks for ACK/message (server side PUBLISH) processing for each +individual request. These callbacks will be sequentially dispatched and invoked upon the arrival of ACK/message (server +side PUBLISH) packets. + +- Synchronous APIs +Synchronous API behaviors are enforced by registering blocking ACK callbacks on top of the asynchronous APIs. +Synchronous APIs wait on their corresponding ACK packets, if there is any, before the invocation returns. For example, +a synchronous QoS1 publish call will wait until it gets its PUBACK back. A synchronous subscribe call will wait until +it gets its SUBACK back. Users can configure operation time out for synchronous APIs to stop the waiting. + +Since callbacks are sequentially dispatched and invoked, calling synchronous APIs within callbacks will deadlock the +user application. If users are inclined to utilize the asynchronous mode and perform MQTT operations +within callbacks, asynchronous APIs should be used. For more details, please check out the provided samples at +``samples/basicPubSub/basicPubSub_APICallInCallback.py`` + .. 
_Key_Features: Key Features ~~~~~~~~~~~~ -Progressive Reconnect Backoff -_____________________________ +Progressive Reconnect Back Off +______________________________ When a non-client-side disconnect occurs, the SDK will reconnect automatically. The following APIs are provided for configuration: @@ -332,11 +470,11 @@ default configuration for backoff timing will be performed on initialization: maxReconnectQuietTimeSecond = 32 stableConnectionTimeSecond = 20 -Offline Publish Requests Queueing with Draining -_______________________________________________ +Offline Requests Queueing with Draining +_______________________________________ If the client is temporarily offline and disconnected due to -network failure, publish requests will be added to an internal +network failure, publish/subscribe/unsubscribe requests will be added to an internal queue until the number of queued-up requests reaches the size limit of the queue. This functionality is for plain MQTT operations. Shadow client contains time-sensitive data and is therefore not supported. @@ -347,7 +485,7 @@ The following API is provided for configuration: AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTClient.configureOfflinePublishQueueing(queueSize, dropBehavior) -After the queue is full, offline publish requests will be discarded or +After the queue is full, offline publish/subscribe/unsubscribe requests will be discarded or replaced according to the configuration of the drop behavior: .. code-block:: python @@ -406,7 +544,7 @@ Because the queue is already full, the newest requests ``pub_req6`` and When the client is back online, connected, and resubscribed to all topics it has previously subscribed to, the draining starts. All requests -in the offline publish queue will be resent at the configured draining +in the offline request queue will be resent at the configured draining rate: .. 
code-block:: python @@ -414,7 +552,7 @@ rate: AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTClient.configureDrainingFrequency(frequencyInHz) If no ``configOfflinePublishQueue`` or ``configureDrainingFrequency`` is -called, the following default configuration for offline publish queueing +called, the following default configuration for offline request queueing and draining will be performed on the initialization: .. code-block:: python @@ -423,16 +561,16 @@ and draining will be performed on the initialization: dropBehavior = DROP_NEWEST drainingFrequency = 2Hz -Before the draining process is complete, any new publish request +Before the draining process is complete, any new publish/subscribe/unsubscribe request within this time period will be added to the queue. Therefore, the draining rate -should be higher than the normal publish rate to avoid an endless +should be higher than the normal request rate to avoid an endless draining process after reconnect. The disconnect event is detected based on PINGRESP MQTT -packet loss. Offline publish queueing will not be triggered until the +packet loss. Offline request queueing will not be triggered until the disconnect event is detected. Configuring a shorter keep-alive interval allows the client to detect disconnects more quickly. Any QoS0 -publish requests issued after the network failure and before the +publish, subscribe and unsubscribe requests issued after the network failure and before the detection of the PINGRESP loss will be lost. Persistent/Non-Persistent Subscription @@ -485,6 +623,18 @@ accepted/rejected topics. In all SDK examples, PersistentSubscription is used in consideration of its better performance. +SSL Ciphers Setup +______________________________________ +If custom SSL Ciphers are required for the client, they can be set when configuring the client before +starting the connection. + +To setup specific SSL Ciphers: + +.. 
code-block:: python + + myAWSIoTMQTTClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath, Ciphers="AES128-SHA256") + + .. _Examples: Examples @@ -512,6 +662,12 @@ Run the example like this: python basicPubSub.py -e -r -w # Customize client id and topic python basicPubSub.py -e -r -c -k -id -t + # Customize the message + python basicPubSub.py -e -r -c -k -id -t -M + # Customize the port number + python basicPubSub.py -e -r -c -k -p + # change the run mode to subscribe or publish only (see python basicPubSub.py -h for the available options) + python basicPubSub.py -e -r -c -k -m Source ****** @@ -551,6 +707,72 @@ Source The example is available in ``samples/basicPubSub/``. +BasicPubSub Asynchronous version +________________________________ + +This example demonstrates a simple MQTT publish/subscribe with asynchronous APIs using AWS IoT. +It first registers general notification callbacks for CONNACK reception, disconnect reception and message arrival. +It then registers ACK callbacks for subscribe and publish requests to print out received ack packet ids. +It subscribes to a topic with no specific callback and then publishes to the same topic in a loop. +New messages are printed upon reception by the general message arrival callback, indicating +the callback function has been called. +New ack packet ids are printed upon reception of PUBACK and SUBACK through ACK callbacks registered with asynchronous +API calls, indicating that the client received ACKs for the corresponding asynchronous API calls. + +Instructions +************ + +Run the example like this: + +.. 
code-block:: python + + # Certificate based mutual authentication + python basicPubSubAsync.py -e -r -c -k + # MQTT over WebSocket + python basicPubSubAsync.py -e -r -w + # Customize client id and topic + python basicPubSubAsync.py -e -r -c -k -id -t + # Customize the port number + python basicPubSubAsync.py -e -r -c -k -p + +Source +****** + +The example is available in ``samples/basicPubSub/``. + +BasicPubSub with API invocation in callback +___________________________________________ + +This example demonstrates the usage of asynchronous APIs within callbacks. It first connects to AWS IoT and subscribes +to 2 topics with the corresponding message callbacks registered. One message callback contains client asynchronous API +invocation that republishes the received message from to /republish. The other message callback simply +prints out the received message. It then publishes messages to in an infinite loop. For every message received +from , it will be republished to /republish and be printed out as configured in the simple print-out +message callback. +New ack packet ids are printed upon reception of PUBACK and SUBACK through ACK callbacks registered with asynchronous +API calls, indicating that the client received ACKs for the corresponding asynchronous API calls. + +Instructions +************ + +Run the example like this: + +.. code-block:: python + + # Certificate based mutual authentication + python basicPubSub_APICallInCallback.py -e -r -c -k + # MQTT over WebSocket + python basicPubSub_APICallInCallback.py -e -r -w + # Customize client id and topic + python basicPubSub_APICallInCallback.py -e -r -c -k -id -t + # Customize the port number + python basicPubSub_APICallInCallback.py -e -r -c -k -p + +Source +****** + +The example is available in ``samples/basicPubSub/``. 
+ BasicShadow ___________ @@ -583,6 +805,8 @@ First, start the basicShadowDeltaListener: python basicShadowDeltaListener.py -e -r -c -k # MQTT over WebSocket python basicShadowDeltaListener.py -e -r -w + # Customize the port number + python basicShadowDeltaListener.py -e -r -c -k -p Then, start the basicShadowUpdater: @@ -593,6 +817,8 @@ Then, start the basicShadowUpdater: python basicShadowUpdater.py -e -r -c -k # MQTT over WebSocket python basicShadowUpdater.py -e -r -w + # Customize the port number + python basicShadowUpdater.py -e -r -c -k -p After the basicShadowUpdater starts sending shadow update requests, you @@ -628,6 +854,8 @@ Run the example like this: python ThingShadowEcho.py -e -r -w # Customize client Id and thing name python ThingShadowEcho.py -e -r -c -k -id -n + # Customize the port number + python ThingShadowEcho.py -e -r -c -k -p Now use the `AWS IoT console `__ or other MQTT client to update the shadow desired state only. You should be able to see the reported state is updated to match @@ -638,6 +866,73 @@ Source The example is available in ``samples/ThingShadowEcho/``. +JobsSample +__________ + +This example demonstrates how a device communicates with AWS IoT while +also taking advantage of AWS IoT Jobs functionality. It shows how to +subscribe to Jobs topics in order to receive Job documents on your +device. It also shows how to process those Jobs so that you can see in +the `AWS IoT console `__ which of your devices have received and processed +which Jobs. See the AWS IoT Device Management documentation `here `__ +for more information on creating and deploying Jobs to your fleet of +devices to facilitate management tasks such as deploying software updates +and running diagnostics. + +Instructions +************ + +First use the `AWS IoT console `__ to create and deploy Jobs to your fleet of devices. + +Then run the example like this: + +.. 
code-block:: python + + # Certificate based mutual authentication + python jobsSample.py -e -r -c -k -n + # MQTT over WebSocket + python jobsSample.py -e -r -w -n + # Customize client Id and thing name + python jobsSample.py -e -r -c -k -id -n + # Customize the port number + python jobsSample.py -e -r -c -k -n -p + +Source +****** + +The example is available in ``samples/jobs/``. + +BasicDiscovery +______________ + +This example demonstrates how to perform a discovery process from a Greengrass Aware Device (GGAD) to obtain the required +connectivity/identity information to connect to the Greengrass Core (GGC) deployed within the same group. It uses the +discovery information provider to invoke discover call for a certain GGAD with its thing name. After it gets back a +success response, it picks up the first GGC and the first set of identity information (CA) for the first group, persists \ +it locally and iterates through all connectivity info sets for this GGC to establish a MQTT connection to the designated +GGC. It then publishes messages to the topic, which, on the GGC side, is configured to route the messages back to the +same GGAD. Therefore, it receives the published messages and invokes the corresponding message callbacks. + +Note that in order to get the sample up and running correctly, you need: + +1. Have a successfully deployed Greengrass group. + +2. Use the certificate and private key that have been deployed with the group for the GGAD to perform discovery process. + +3. The subscription records for that deployed group should contain a route that routes messages from the targeted GGAD to itself via a dedicated MQTT topic. + +4. The deployed GGAD thing name, the deployed GGAD certificate/private key and the dedicated MQTT topic should be used as the inputs for this sample. + + +Run the sample like this: + +.. 
code-block:: python + + python basicDiscovery.py -e -r -c -k -n -t + +If the group, GGC, GGAD and group subscription/routes are set up correctly, you should be able to see the sample running +on your GGAD, receiving messages that get published to GGC by itself. + .. _API_Documentation: API Documentation diff --git a/continuous-delivery/pip-install-with-retry.py b/continuous-delivery/pip-install-with-retry.py new file mode 100644 index 0000000..347e0dc --- /dev/null +++ b/continuous-delivery/pip-install-with-retry.py @@ -0,0 +1,39 @@ +import time +import sys +import subprocess + +DOCS = """Given cmdline args, executes: python3 -m pip install [args...] +Keeps retrying until the new version becomes available in pypi (or we time out)""" +if len(sys.argv) < 2: + sys.exit(DOCS) + +RETRY_INTERVAL_SECS = 10 +GIVE_UP_AFTER_SECS = 60 * 15 + +pip_install_args = [sys.executable, '-m', 'pip', 'install'] + sys.argv[1:] + +start_time = time.time() +while True: + print(subprocess.list2cmdline(pip_install_args)) + result = subprocess.run(pip_install_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + stdout = result.stdout.decode().strip() + if stdout: + print(stdout) + + if result.returncode == 0: + # success + sys.exit(0) + + if "could not find a version" in stdout.lower(): + elapsed_secs = time.time() - start_time + if elapsed_secs < GIVE_UP_AFTER_SECS: + # try again + print("Retrying in", RETRY_INTERVAL_SECS, "secs...") + time.sleep(RETRY_INTERVAL_SECS) + continue + else: + print("Giving up on retries after", int(elapsed_secs), "total secs.") + + # fail + sys.exit(result.returncode) diff --git a/continuous-delivery/publish_to_prod_pypi.yml b/continuous-delivery/publish_to_prod_pypi.yml new file mode 100644 index 0000000..905d849 --- /dev/null +++ b/continuous-delivery/publish_to_prod_pypi.yml @@ -0,0 +1,25 @@ +version: 0.2 +# this image assumes Ubuntu 14.04 base image +phases: + install: + commands: + - sudo apt-get update -y + - sudo apt-get install python3 python3-pip 
-y + - export PATH=$PATH:$HOME/.local/bin + - python3 -m pip install --user --upgrade pip + - python3 -m pip install --user --upgrade twine setuptools wheel awscli PyOpenSSL six + pre_build: + commands: + - cd aws-iot-device-sdk-python + - pypirc=$(aws secretsmanager get-secret-value --secret-id "prod/aws-sdk-python-v1/.pypirc" --query "SecretString" | cut -f2 -d\") && echo "$pypirc" > ~/.pypirc + - export PKG_VERSION=$(git describe --tags | cut -f2 -dv) + - echo "Updating package version to ${PKG_VERSION}" + - sed --in-place -E "s/__version__ = \".+\"/__version__ = \"${PKG_VERSION}\"/" AWSIoTPythonSDK/__init__.py + build: + commands: + - echo Build started on `date` + - python3 setup.py sdist bdist_wheel --universal + - python3 -m twine upload -r pypi dist/* + post_build: + commands: + - echo Build completed on `date` diff --git a/continuous-delivery/publish_to_test_pypi.yml b/continuous-delivery/publish_to_test_pypi.yml new file mode 100644 index 0000000..c435e5e --- /dev/null +++ b/continuous-delivery/publish_to_test_pypi.yml @@ -0,0 +1,25 @@ +version: 0.2 +# this image assumes Ubuntu 14.04 base image +phases: + install: + commands: + - sudo apt-get update -y + - sudo apt-get install python3 python3-pip -y + - export PATH=$PATH:$HOME/.local/bin + - python3 -m pip install --user --upgrade pip + - python3 -m pip install --user --upgrade twine setuptools wheel awscli PyOpenSSL six + pre_build: + commands: + - pypirc=$(aws secretsmanager get-secret-value --secret-id "alpha/aws-sdk-python-v1/.pypirc" --query "SecretString" | cut -f2 -d\") && echo "$pypirc" > ~/.pypirc + - cd aws-iot-device-sdk-python + - export PKG_VERSION=$(git describe --tags | cut -f2 -dv) + - echo "Updating package version to ${PKG_VERSION}" + - sed --in-place -E "s/__version__ = \".+\"/__version__ = \"${PKG_VERSION}\"/" AWSIoTPythonSDK/__init__.py + build: + commands: + - echo Build started on `date` + - python3 setup_test.py sdist bdist_wheel --universal + - python3 -m twine upload -r testpypi 
dist/* --verbose + post_build: + commands: + - echo Build completed on `date` diff --git a/continuous-delivery/test_prod_pypi.yml b/continuous-delivery/test_prod_pypi.yml new file mode 100644 index 0000000..4575306 --- /dev/null +++ b/continuous-delivery/test_prod_pypi.yml @@ -0,0 +1,28 @@ +version: 0.2 +# this image assumes Ubuntu 14.04 base image +phases: + install: + commands: + - sudo apt-get update -y + - sudo apt-get install python3 python3-pip -y + - python3 -m pip install --upgrade pip + - python3 -m pip install --upgrade setuptools + + pre_build: + commands: + - curl https://www.amazontrust.com/repository/AmazonRootCA1.pem --output /tmp/AmazonRootCA1.pem + - cert=$(aws secretsmanager get-secret-value --secret-id "unit-test/certificate" --query "SecretString" | cut -f2 -d":" | cut -f2 -d\") && echo "$cert" > /tmp/certificate.pem + - key=$(aws secretsmanager get-secret-value --secret-id "unit-test/privatekey" --query "SecretString" | cut -f2 -d":" | cut -f2 -d\") && echo "$key" > /tmp/privatekey.pem + - ENDPOINT=$(aws secretsmanager get-secret-value --secret-id "unit-test/endpoint" --query "SecretString" | cut -f2 -d":" | sed -e 's/[\\\"\}]//g') + build: + commands: + - echo Build started on `date` + - cd aws-iot-device-sdk-python + - CURRENT_TAG_VERSION=$(git describe --tags | cut -f2 -dv) + - python3 continuous-delivery/pip-install-with-retry.py --no-cache-dir --user AWSIoTPythonSDK==$CURRENT_TAG_VERSION + - python3 samples/greengrass/basicDiscovery.py -e ${ENDPOINT} -c /tmp/certificate.pem -k /tmp/privatekey.pem -r /tmp/AmazonRootCA1.pem --print_discover_resp_only + + post_build: + commands: + - echo Build completed on `date` + diff --git a/continuous-delivery/test_test_pypi.yml b/continuous-delivery/test_test_pypi.yml new file mode 100644 index 0000000..c3aa47d --- /dev/null +++ b/continuous-delivery/test_test_pypi.yml @@ -0,0 +1,30 @@ +version: 0.2 +# this image assumes Ubuntu 14.04 base image +phases: + install: + commands: + - sudo apt-get update -y + 
- sudo apt-get install python3 python3-pip -y + - python3 -m pip install --upgrade pip + - python3 -m pip install --upgrade setuptools + + pre_build: + commands: + - curl https://www.amazontrust.com/repository/AmazonRootCA1.pem --output /tmp/AmazonRootCA1.pem + - cert=$(aws secretsmanager get-secret-value --secret-id "unit-test/certificate" --query "SecretString" | cut -f2 -d":" | cut -f2 -d\") && echo "$cert" > /tmp/certificate.pem + - key=$(aws secretsmanager get-secret-value --secret-id "unit-test/privatekey" --query "SecretString" | cut -f2 -d":" | cut -f2 -d\") && echo "$key" > /tmp/privatekey.pem + - ENDPOINT=$(aws secretsmanager get-secret-value --secret-id "unit-test/endpoint" --query "SecretString" | cut -f2 -d":" | sed -e 's/[\\\"\}]//g') + build: + commands: + - echo Build started on `date` + - cd aws-iot-device-sdk-python + - CURRENT_TAG_VERSION=$(git describe --tags | cut -f2 -dv) + # this is here because typing isn't in testpypi, so pull it from prod instead + - python3 -m pip install typing + - python3 continuous-delivery/pip-install-with-retry.py -i https://testpypi.python.org/simple --user AWSIoTPythonSDK-V1==$CURRENT_TAG_VERSION + - python3 samples/greengrass/basicDiscovery.py -e ${ENDPOINT} -c /tmp/certificate.pem -k /tmp/privatekey.pem -r /tmp/AmazonRootCA1.pem --print_discover_resp_only + + post_build: + commands: + - echo Build completed on `date` + diff --git a/continuous-delivery/test_version_exists b/continuous-delivery/test_version_exists new file mode 100644 index 0000000..3579dbc --- /dev/null +++ b/continuous-delivery/test_version_exists @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -e +set -x +# force a failure if there's no tag +git describe --tags +# now get the tag +CURRENT_TAG=$(git describe --tags | cut -f2 -dv) +# convert v0.2.12-2-g50254a9 to 0.2.12 +CURRENT_TAG_VERSION=$(git describe --tags | cut -f1 -d'-' | cut -f2 -dv) +# if there's a hash on the tag, then this is not a release tagged commit +if [ "$CURRENT_TAG" != 
"$CURRENT_TAG_VERSION" ]; then + echo "Current tag version is not a release tag, cut a new release if you want to publish." + exit 1 +fi + +if python3 -m pip install --no-cache-dir -vvv AWSIoTPythonSDK==$CURRENT_TAG_VERSION; then + echo "$CURRENT_TAG_VERSION is already in pypi, cut a new tag if you want to upload another version." + exit 1 +fi + +echo "$CURRENT_TAG_VERSION currently does not exist in pypi, allowing pipeline to continue." +exit 0 diff --git a/continuous-delivery/test_version_exists.yml b/continuous-delivery/test_version_exists.yml new file mode 100644 index 0000000..2704ba7 --- /dev/null +++ b/continuous-delivery/test_version_exists.yml @@ -0,0 +1,21 @@ +version: 0.2 +#this build spec assumes the ubuntu 14.04 trusty image +#this build run simply verifies we haven't published something at this tag yet. +#if we have we fail the build and stop the pipeline, if we haven't we allow the pipeline to run. +phases: + install: + commands: + - sudo apt-get update -y + - sudo apt-get install python3 python3-pip -y + - pip3 install --upgrade setuptools + pre_build: + commands: + - echo Build start on `date` + build: + commands: + - cd aws-iot-device-sdk-python + - bash ./continuous-delivery/test_version_exists + post_build: + commands: + - echo Build completed on `date` + diff --git a/samples/ThingShadowEcho/ThingShadowEcho.py b/samples/ThingShadowEcho/ThingShadowEcho.py index a026ca9..cd1f17e 100755 --- a/samples/ThingShadowEcho/ThingShadowEcho.py +++ b/samples/ThingShadowEcho/ThingShadowEcho.py @@ -1,6 +1,6 @@ ''' /* - * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. 
@@ -16,28 +16,29 @@ ''' from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTShadowClient -import sys import logging import time import json import argparse + class shadowCallbackContainer: - def __init__(self, deviceShadowInstance): - self.deviceShadowInstance = deviceShadowInstance - - # Custom Shadow callback - def customShadowCallback_Delta(self, payload, responseStatus, token): - # payload is a JSON string ready to be parsed using json.loads(...) - # in both Py2.x and Py3.x - print("Received a delta message:") - payloadDict = json.loads(payload) - deltaMessage = json.dumps(payloadDict["state"]) - print(deltaMessage) - print("Request to update the reported state...") - newPayload = '{"state":{"reported":' + deltaMessage + '}}' - self.deviceShadowInstance.shadowUpdate(newPayload, None, 5) - print("Sent.") + def __init__(self, deviceShadowInstance): + self.deviceShadowInstance = deviceShadowInstance + + # Custom Shadow callback + def customShadowCallback_Delta(self, payload, responseStatus, token): + # payload is a JSON string ready to be parsed using json.loads(...) 
+ # in both Py2.x and Py3.x + print("Received a delta message:") + payloadDict = json.loads(payload) + deltaMessage = json.dumps(payloadDict["state"]) + print(deltaMessage) + print("Request to update the reported state...") + newPayload = '{"state":{"reported":' + deltaMessage + '}}' + self.deviceShadowInstance.shadowUpdate(newPayload, None, 5) + print("Sent.") + # Read in command-line parameters parser = argparse.ArgumentParser() @@ -45,27 +46,36 @@ def customShadowCallback_Delta(self, payload, responseStatus, token): parser.add_argument("-r", "--rootCA", action="store", required=True, dest="rootCAPath", help="Root CA file path") parser.add_argument("-c", "--cert", action="store", dest="certificatePath", help="Certificate file path") parser.add_argument("-k", "--key", action="store", dest="privateKeyPath", help="Private key file path") +parser.add_argument("-p", "--port", action="store", dest="port", type=int, help="Port number override") parser.add_argument("-w", "--websocket", action="store_true", dest="useWebsocket", default=False, help="Use MQTT over WebSocket") parser.add_argument("-n", "--thingName", action="store", dest="thingName", default="Bot", help="Targeted thing name") -parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="ThingShadowEcho", help="Targeted client id") +parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="ThingShadowEcho", + help="Targeted client id") args = parser.parse_args() host = args.host rootCAPath = args.rootCAPath certificatePath = args.certificatePath privateKeyPath = args.privateKeyPath +port = args.port useWebsocket = args.useWebsocket thingName = args.thingName clientId = args.clientId if args.useWebsocket and args.certificatePath and args.privateKeyPath: - parser.error("X.509 cert authentication and WebSocket are mutual exclusive. Please pick one.") - exit(2) + parser.error("X.509 cert authentication and WebSocket are mutual exclusive. 
Please pick one.") + exit(2) if not args.useWebsocket and (not args.certificatePath or not args.privateKeyPath): - parser.error("Missing credentials for authentication.") - exit(2) + parser.error("Missing credentials for authentication.") + exit(2) + +# Port defaults +if args.useWebsocket and not args.port: # When no port override for WebSocket, default to 443 + port = 443 +if not args.useWebsocket and not args.port: # When no port override for non-WebSocket, default to 8883 + port = 8883 # Configure logging logger = logging.getLogger("AWSIoTPythonSDK.core") @@ -78,13 +88,13 @@ def customShadowCallback_Delta(self, payload, responseStatus, token): # Init AWSIoTMQTTShadowClient myAWSIoTMQTTShadowClient = None if useWebsocket: - myAWSIoTMQTTShadowClient = AWSIoTMQTTShadowClient(clientId, useWebsocket=True) - myAWSIoTMQTTShadowClient.configureEndpoint(host, 443) - myAWSIoTMQTTShadowClient.configureCredentials(rootCAPath) + myAWSIoTMQTTShadowClient = AWSIoTMQTTShadowClient(clientId, useWebsocket=True) + myAWSIoTMQTTShadowClient.configureEndpoint(host, port) + myAWSIoTMQTTShadowClient.configureCredentials(rootCAPath) else: - myAWSIoTMQTTShadowClient = AWSIoTMQTTShadowClient(clientId) - myAWSIoTMQTTShadowClient.configureEndpoint(host, 8883) - myAWSIoTMQTTShadowClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath) + myAWSIoTMQTTShadowClient = AWSIoTMQTTShadowClient(clientId) + myAWSIoTMQTTShadowClient.configureEndpoint(host, port) + myAWSIoTMQTTShadowClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath) # AWSIoTMQTTShadowClient configuration myAWSIoTMQTTShadowClient.configureAutoReconnectBackoffTime(1, 32, 20) @@ -103,4 +113,4 @@ def customShadowCallback_Delta(self, payload, responseStatus, token): # Loop forever while True: - time.sleep(1) + time.sleep(1) diff --git a/samples/basicPubSub/basicPubSub.py b/samples/basicPubSub/basicPubSub.py index 1ef4e84..dc823fc 100755 --- a/samples/basicPubSub/basicPubSub.py +++ 
b/samples/basicPubSub/basicPubSub.py @@ -1,6 +1,6 @@ ''' /* - * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -16,18 +16,21 @@ ''' from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient -import sys import logging import time import argparse +import json + +AllowedActions = ['both', 'publish', 'subscribe'] # Custom MQTT message callback def customCallback(client, userdata, message): - print("Received a new message: ") - print(message.payload) - print("from topic: ") - print(message.topic) - print("--------------\n\n") + print("Received a new message: ") + print(message.payload) + print("from topic: ") + print(message.topic) + print("--------------\n\n") + # Read in command-line parameters parser = argparse.ArgumentParser() @@ -35,27 +38,44 @@ def customCallback(client, userdata, message): parser.add_argument("-r", "--rootCA", action="store", required=True, dest="rootCAPath", help="Root CA file path") parser.add_argument("-c", "--cert", action="store", dest="certificatePath", help="Certificate file path") parser.add_argument("-k", "--key", action="store", dest="privateKeyPath", help="Private key file path") +parser.add_argument("-p", "--port", action="store", dest="port", type=int, help="Port number override") parser.add_argument("-w", "--websocket", action="store_true", dest="useWebsocket", default=False, help="Use MQTT over WebSocket") -parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="basicPubSub", help="Targeted client id") +parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="basicPubSub", + help="Targeted client id") parser.add_argument("-t", "--topic", action="store", dest="topic", default="sdk/test/Python", help="Targeted topic") 
+parser.add_argument("-m", "--mode", action="store", dest="mode", default="both", + help="Operation modes: %s"%str(AllowedActions)) +parser.add_argument("-M", "--message", action="store", dest="message", default="Hello World!", + help="Message to publish") args = parser.parse_args() host = args.host rootCAPath = args.rootCAPath certificatePath = args.certificatePath privateKeyPath = args.privateKeyPath +port = args.port useWebsocket = args.useWebsocket clientId = args.clientId topic = args.topic +if args.mode not in AllowedActions: + parser.error("Unknown --mode option %s. Must be one of %s" % (args.mode, str(AllowedActions))) + exit(2) + if args.useWebsocket and args.certificatePath and args.privateKeyPath: - parser.error("X.509 cert authentication and WebSocket are mutual exclusive. Please pick one.") - exit(2) + parser.error("X.509 cert authentication and WebSocket are mutual exclusive. Please pick one.") + exit(2) if not args.useWebsocket and (not args.certificatePath or not args.privateKeyPath): - parser.error("Missing credentials for authentication.") - exit(2) + parser.error("Missing credentials for authentication.") + exit(2) + +# Port defaults +if args.useWebsocket and not args.port: # When no port override for WebSocket, default to 443 + port = 443 +if not args.useWebsocket and not args.port: # When no port override for non-WebSocket, default to 8883 + port = 8883 # Configure logging logger = logging.getLogger("AWSIoTPythonSDK.core") @@ -68,13 +88,13 @@ def customCallback(client, userdata, message): # Init AWSIoTMQTTClient myAWSIoTMQTTClient = None if useWebsocket: - myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId, useWebsocket=True) - myAWSIoTMQTTClient.configureEndpoint(host, 443) - myAWSIoTMQTTClient.configureCredentials(rootCAPath) + myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId, useWebsocket=True) + myAWSIoTMQTTClient.configureEndpoint(host, port) + myAWSIoTMQTTClient.configureCredentials(rootCAPath) else: - myAWSIoTMQTTClient = 
AWSIoTMQTTClient(clientId) - myAWSIoTMQTTClient.configureEndpoint(host, 8883) - myAWSIoTMQTTClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath) + myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId) + myAWSIoTMQTTClient.configureEndpoint(host, port) + myAWSIoTMQTTClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath) # AWSIoTMQTTClient connection configuration myAWSIoTMQTTClient.configureAutoReconnectBackoffTime(1, 32, 20) @@ -85,12 +105,20 @@ def customCallback(client, userdata, message): # Connect and subscribe to AWS IoT myAWSIoTMQTTClient.connect() -myAWSIoTMQTTClient.subscribe(topic, 1, customCallback) +if args.mode == 'both' or args.mode == 'subscribe': + myAWSIoTMQTTClient.subscribe(topic, 1, customCallback) time.sleep(2) # Publish to the same topic in a loop forever loopCount = 0 while True: - myAWSIoTMQTTClient.publish(topic, "New Message " + str(loopCount), 1) - loopCount += 1 - time.sleep(1) + if args.mode == 'both' or args.mode == 'publish': + message = {} + message['message'] = args.message + message['sequence'] = loopCount + messageJson = json.dumps(message) + myAWSIoTMQTTClient.publish(topic, messageJson, 1) + if args.mode == 'publish': + print('Published topic %s: %s\n' % (topic, messageJson)) + loopCount += 1 + time.sleep(1) diff --git a/samples/basicPubSub/basicPubSubAsync.py b/samples/basicPubSub/basicPubSubAsync.py new file mode 100644 index 0000000..25050f4 --- /dev/null +++ b/samples/basicPubSub/basicPubSubAsync.py @@ -0,0 +1,124 @@ +''' +/* + * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. 
This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + ''' + +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient +import logging +import time +import argparse + + +# General message notification callback +def customOnMessage(message): + print("Received a new message: ") + print(message.payload) + print("from topic: ") + print(message.topic) + print("--------------\n\n") + + +# Suback callback +def customSubackCallback(mid, data): + print("Received SUBACK packet id: ") + print(mid) + print("Granted QoS: ") + print(data) + print("++++++++++++++\n\n") + + +# Puback callback +def customPubackCallback(mid): + print("Received PUBACK packet id: ") + print(mid) + print("++++++++++++++\n\n") + + +# Read in command-line parameters +parser = argparse.ArgumentParser() +parser.add_argument("-e", "--endpoint", action="store", required=True, dest="host", help="Your AWS IoT custom endpoint") +parser.add_argument("-r", "--rootCA", action="store", required=True, dest="rootCAPath", help="Root CA file path") +parser.add_argument("-c", "--cert", action="store", dest="certificatePath", help="Certificate file path") +parser.add_argument("-k", "--key", action="store", dest="privateKeyPath", help="Private key file path") +parser.add_argument("-p", "--port", action="store", dest="port", type=int, help="Port number override") +parser.add_argument("-w", "--websocket", action="store_true", dest="useWebsocket", default=False, + help="Use MQTT over WebSocket") +parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="basicPubSub", + help="Targeted client id") +parser.add_argument("-t", "--topic", action="store", dest="topic", default="sdk/test/Python", help="Targeted topic") + +args = parser.parse_args() +host = args.host +rootCAPath = args.rootCAPath +certificatePath = 
args.certificatePath +privateKeyPath = args.privateKeyPath +port = args.port +useWebsocket = args.useWebsocket +clientId = args.clientId +topic = args.topic + +if args.useWebsocket and args.certificatePath and args.privateKeyPath: + parser.error("X.509 cert authentication and WebSocket are mutual exclusive. Please pick one.") + exit(2) + +if not args.useWebsocket and (not args.certificatePath or not args.privateKeyPath): + parser.error("Missing credentials for authentication.") + exit(2) + +# Port defaults +if args.useWebsocket and not args.port: # When no port override for WebSocket, default to 443 + port = 443 +if not args.useWebsocket and not args.port: # When no port override for non-WebSocket, default to 8883 + port = 8883 + +# Configure logging +logger = logging.getLogger("AWSIoTPythonSDK.core") +logger.setLevel(logging.DEBUG) +streamHandler = logging.StreamHandler() +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +streamHandler.setFormatter(formatter) +logger.addHandler(streamHandler) + +# Init AWSIoTMQTTClient +myAWSIoTMQTTClient = None +if useWebsocket: + myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId, useWebsocket=True) + myAWSIoTMQTTClient.configureEndpoint(host, port) + myAWSIoTMQTTClient.configureCredentials(rootCAPath) +else: + myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId) + myAWSIoTMQTTClient.configureEndpoint(host, port) + myAWSIoTMQTTClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath) + +# AWSIoTMQTTClient connection configuration +myAWSIoTMQTTClient.configureAutoReconnectBackoffTime(1, 32, 20) +myAWSIoTMQTTClient.configureOfflinePublishQueueing(-1) # Infinite offline Publish queueing +myAWSIoTMQTTClient.configureDrainingFrequency(2) # Draining: 2 Hz +myAWSIoTMQTTClient.configureConnectDisconnectTimeout(10) # 10 sec +myAWSIoTMQTTClient.configureMQTTOperationTimeout(5) # 5 sec +myAWSIoTMQTTClient.onMessage = customOnMessage + +# Connect and subscribe to AWS IoT 
+myAWSIoTMQTTClient.connect() +# Note that we are not putting a message callback here. We are using the general message notification callback. +myAWSIoTMQTTClient.subscribeAsync(topic, 1, ackCallback=customSubackCallback) +time.sleep(2) + +# Publish to the same topic in a loop forever +loopCount = 0 +while True: + myAWSIoTMQTTClient.publishAsync(topic, "New Message " + str(loopCount), 1, ackCallback=customPubackCallback) + loopCount += 1 + time.sleep(1) diff --git a/samples/basicPubSub/basicPubSubProxy.py b/samples/basicPubSub/basicPubSubProxy.py new file mode 100644 index 0000000..929ef42 --- /dev/null +++ b/samples/basicPubSub/basicPubSubProxy.py @@ -0,0 +1,136 @@ +''' +/* + * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + ''' + +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient +import logging +import time +import argparse +import json + +AllowedActions = ['both', 'publish', 'subscribe'] + +# Custom MQTT message callback +def customCallback(client, userdata, message): + print("Received a new message: ") + print(message.payload) + print("from topic: ") + print(message.topic) + print("--------------\n\n") + + +# Read in command-line parameters +parser = argparse.ArgumentParser() +parser.add_argument("-e", "--endpoint", action="store", required=True, dest="host", help="Your AWS IoT custom endpoint") +parser.add_argument("-r", "--rootCA", action="store", required=True, dest="rootCAPath", help="Root CA file path") +parser.add_argument("-c", "--cert", action="store", dest="certificatePath", help="Certificate file path") +parser.add_argument("-k", "--key", action="store", dest="privateKeyPath", help="Private key file path") +parser.add_argument("-p", "--port", action="store", dest="port", type=int, help="Port number override") +parser.add_argument("-w", "--websocket", action="store_true", dest="useWebsocket", default=False, + help="Use MQTT over WebSocket") +parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="basicPubSub", + help="Targeted client id") +parser.add_argument("-t", "--topic", action="store", dest="topic", default="sdk/test/Python", help="Targeted topic") +parser.add_argument("-m", "--mode", action="store", dest="mode", default="both", + help="Operation modes: %s"%str(AllowedActions)) +parser.add_argument("-M", "--message", action="store", dest="message", default="Hello World!", + help="Message to publish") + +args = parser.parse_args() +host = args.host +rootCAPath = args.rootCAPath +certificatePath = args.certificatePath +privateKeyPath = args.privateKeyPath +port = args.port +useWebsocket = args.useWebsocket +clientId = args.clientId +topic = args.topic + +if args.mode not in AllowedActions: + parser.error("Unknown --mode option %s. 
Must be one of %s" % (args.mode, str(AllowedActions))) + exit(2) + +if args.useWebsocket and args.certificatePath and args.privateKeyPath: + parser.error("X.509 cert authentication and WebSocket are mutual exclusive. Please pick one.") + exit(2) + +if not args.useWebsocket and (not args.certificatePath or not args.privateKeyPath): + parser.error("Missing credentials for authentication.") + exit(2) + +# Port defaults +if args.useWebsocket and not args.port: # When no port override for WebSocket, default to 443 + port = 443 +if not args.useWebsocket and not args.port: # When no port override for non-WebSocket, default to 8883 + port = 8883 + +# Configure logging +logger = logging.getLogger("AWSIoTPythonSDK.core") +logger.setLevel(logging.DEBUG) +streamHandler = logging.StreamHandler() +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +streamHandler.setFormatter(formatter) +logger.addHandler(streamHandler) + +# Init AWSIoTMQTTClient +myAWSIoTMQTTClient = None +if useWebsocket: + myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId, useWebsocket=True) + myAWSIoTMQTTClient.configureEndpoint(host, port) + myAWSIoTMQTTClient.configureCredentials(rootCAPath) +else: + myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId) + myAWSIoTMQTTClient.configureEndpoint(host, port) + myAWSIoTMQTTClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath) + +# AWSIoTMQTTClient connection configuration +myAWSIoTMQTTClient.configureAutoReconnectBackoffTime(1, 32, 20) +myAWSIoTMQTTClient.configureOfflinePublishQueueing(-1) # Infinite offline Publish queueing +myAWSIoTMQTTClient.configureDrainingFrequency(2) # Draining: 2 Hz +myAWSIoTMQTTClient.configureConnectDisconnectTimeout(10) # 10 sec +myAWSIoTMQTTClient.configureMQTTOperationTimeout(5) # 5 sec + +# AWSIoTMQTTClient socket configuration +# import pysocks to help us build a socket that supports a proxy configuration +import socks + +# set proxy arguments (for SOCKS5 proxy: proxy_type=2, for HTTP 
 proxy: proxy_type=3) +proxy_config = {"proxy_addr":<proxy_host>, "proxy_port":<proxy_port>, "proxy_type":<proxy_type>} + +# create anonymous function to handle socket creation +socket_factory = lambda: socks.create_connection((host, port), **proxy_config) +myAWSIoTMQTTClient.configureSocketFactory(socket_factory) + +# Connect and subscribe to AWS IoT +myAWSIoTMQTTClient.connect() +if args.mode == 'both' or args.mode == 'subscribe': + myAWSIoTMQTTClient.subscribe(topic, 1, customCallback) +time.sleep(2) + +# Publish to the same topic in a loop forever +loopCount = 0 +while True: + if args.mode == 'both' or args.mode == 'publish': + message = {} + message['message'] = args.message + message['sequence'] = loopCount + messageJson = json.dumps(message) + myAWSIoTMQTTClient.publish(topic, messageJson, 1) + if args.mode == 'publish': + print('Published topic %s: %s\n' % (topic, messageJson)) + loopCount += 1 + time.sleep(1) + diff --git a/samples/basicPubSub/basicPubSub_APICallInCallback.py b/samples/basicPubSub/basicPubSub_APICallInCallback.py new file mode 100644 index 0000000..710457e --- /dev/null +++ b/samples/basicPubSub/basicPubSub_APICallInCallback.py @@ -0,0 +1,133 @@ +''' +/* + * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + ''' + +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient +import logging +import time +import argparse + + +class CallbackContainer(object): + + def __init__(self, client): + self._client = client + + def messagePrint(self, client, userdata, message): + print("Received a new message: ") + print(message.payload) + print("from topic: ") + print(message.topic) + print("--------------\n\n") + + def messageForward(self, client, userdata, message): + topicRepublish = message.topic + "/republish" + print("Forwarding message from: %s to %s" % (message.topic, topicRepublish)) + print("--------------\n\n") + self._client.publishAsync(topicRepublish, str(message.payload), 1, self.pubackCallback) + + def pubackCallback(self, mid): + print("Received PUBACK packet id: ") + print(mid) + print("++++++++++++++\n\n") + + def subackCallback(self, mid, data): + print("Received SUBACK packet id: ") + print(mid) + print("Granted QoS: ") + print(data) + print("++++++++++++++\n\n") + + +# Read in command-line parameters +parser = argparse.ArgumentParser() +parser.add_argument("-e", "--endpoint", action="store", required=True, dest="host", help="Your AWS IoT custom endpoint") +parser.add_argument("-r", "--rootCA", action="store", required=True, dest="rootCAPath", help="Root CA file path") +parser.add_argument("-c", "--cert", action="store", dest="certificatePath", help="Certificate file path") +parser.add_argument("-k", "--key", action="store", dest="privateKeyPath", help="Private key file path") +parser.add_argument("-p", "--port", action="store", dest="port", type=int, help="Port number override") +parser.add_argument("-w", "--websocket", action="store_true", dest="useWebsocket", default=False, + help="Use MQTT over WebSocket") +parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="basicPubSub", + help="Targeted client id") +parser.add_argument("-t", "--topic", action="store", dest="topic", default="sdk/test/Python", help="Targeted topic") + +args = 
parser.parse_args() +host = args.host +rootCAPath = args.rootCAPath +certificatePath = args.certificatePath +privateKeyPath = args.privateKeyPath +port = args.port +useWebsocket = args.useWebsocket +clientId = args.clientId +topic = args.topic + +if args.useWebsocket and args.certificatePath and args.privateKeyPath: + parser.error("X.509 cert authentication and WebSocket are mutual exclusive. Please pick one.") + exit(2) + +if not args.useWebsocket and (not args.certificatePath or not args.privateKeyPath): + parser.error("Missing credentials for authentication.") + exit(2) + +# Port defaults +if args.useWebsocket and not args.port: # When no port override for WebSocket, default to 443 + port = 443 +if not args.useWebsocket and not args.port: # When no port override for non-WebSocket, default to 8883 + port = 8883 + +# Configure logging +logger = logging.getLogger("AWSIoTPythonSDK.core") +logger.setLevel(logging.DEBUG) +streamHandler = logging.StreamHandler() +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +streamHandler.setFormatter(formatter) +logger.addHandler(streamHandler) + +# Init AWSIoTMQTTClient +myAWSIoTMQTTClient = None +if useWebsocket: + myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId, useWebsocket=True) + myAWSIoTMQTTClient.configureEndpoint(host, port) + myAWSIoTMQTTClient.configureCredentials(rootCAPath) +else: + myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId) + myAWSIoTMQTTClient.configureEndpoint(host, port) + myAWSIoTMQTTClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath) + +# AWSIoTMQTTClient connection configuration +myAWSIoTMQTTClient.configureAutoReconnectBackoffTime(1, 32, 20) +myAWSIoTMQTTClient.configureOfflinePublishQueueing(-1) # Infinite offline Publish queueing +myAWSIoTMQTTClient.configureDrainingFrequency(2) # Draining: 2 Hz +myAWSIoTMQTTClient.configureConnectDisconnectTimeout(10) # 10 sec +myAWSIoTMQTTClient.configureMQTTOperationTimeout(5) # 5 sec + +myCallbackContainer = 
CallbackContainer(myAWSIoTMQTTClient) + +# Connect and subscribe to AWS IoT +myAWSIoTMQTTClient.connect() + +# Perform synchronous subscribes +myAWSIoTMQTTClient.subscribe(topic, 1, myCallbackContainer.messageForward) +myAWSIoTMQTTClient.subscribe(topic + "/republish", 1, myCallbackContainer.messagePrint) +time.sleep(2) + +# Publish to the same topic in a loop forever +loopCount = 0 +while True: + myAWSIoTMQTTClient.publishAsync(topic, "New Message " + str(loopCount), 1, ackCallback=myCallbackContainer.pubackCallback) + loopCount += 1 + time.sleep(1) diff --git a/samples/basicPubSub/basicPubSub_CognitoSTS.py b/samples/basicPubSub/basicPubSub_CognitoSTS.py index f6d20e7..d67e624 100755 --- a/samples/basicPubSub/basicPubSub_CognitoSTS.py +++ b/samples/basicPubSub/basicPubSub_CognitoSTS.py @@ -1,6 +1,6 @@ ''' /* - * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. 
@@ -17,25 +17,28 @@ import boto3 from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient -import sys import logging import time import argparse + # Custom MQTT message callback def customCallback(client, userdata, message): - print("Received a new message: ") - print(message.payload) - print("from topic: ") - print(message.topic) - print("--------------\n\n") + print("Received a new message: ") + print(message.payload) + print("from topic: ") + print(message.topic) + print("--------------\n\n") + # Read in command-line parameters parser = argparse.ArgumentParser() parser.add_argument("-e", "--endpoint", action="store", required=True, dest="host", help="Your AWS IoT custom endpoint") parser.add_argument("-r", "--rootCA", action="store", required=True, dest="rootCAPath", help="Root CA file path") -parser.add_argument("-C", "--CognitoIdentityPoolID", action="store", required=True, dest="cognitoIdentityPoolID", help="Your AWS Cognito Identity Pool ID") -parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="basicPubSub_CognitoSTS", help="Targeted client id") +parser.add_argument("-C", "--CognitoIdentityPoolID", action="store", required=True, dest="cognitoIdentityPoolID", + help="Your AWS Cognito Identity Pool ID") +parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="basicPubSub_CognitoSTS", + help="Targeted client id") parser.add_argument("-t", "--topic", action="store", dest="topic", default="sdk/test/Python", help="Targeted topic") args = parser.parse_args() @@ -89,6 +92,6 @@ def customCallback(client, userdata, message): # Publish to the same topic in a loop forever loopCount = 0 while True: - myAWSIoTMQTTClient.publish(topic, "New Message " + str(loopCount), 1) - loopCount += 1 - time.sleep(1) + myAWSIoTMQTTClient.publish(topic, "New Message " + str(loopCount), 1) + loopCount += 1 + time.sleep(1) diff --git a/samples/basicShadow/basicShadowDeltaListener.py b/samples/basicShadow/basicShadowDeltaListener.py 
index 86d2b5c..73a6b7b 100755 --- a/samples/basicShadow/basicShadowDeltaListener.py +++ b/samples/basicShadow/basicShadowDeltaListener.py @@ -1,6 +1,6 @@ ''' /* - * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -16,12 +16,12 @@ ''' from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTShadowClient -import sys import logging import time import json import argparse + # Shadow JSON schema: # # Name: Bot @@ -31,18 +31,19 @@ # "property": # } # } -#} +# } # Custom Shadow callback def customShadowCallback_Delta(payload, responseStatus, token): - # payload is a JSON string ready to be parsed using json.loads(...) - # in both Py2.x and Py3.x - print(responseStatus) - payloadDict = json.loads(payload) - print("++++++++DELTA++++++++++") - print("property: " + str(payloadDict["state"]["property"])) - print("version: " + str(payloadDict["version"])) - print("+++++++++++++++++++++++\n\n") + # payload is a JSON string ready to be parsed using json.loads(...) 
+ # in both Py2.x and Py3.x + print(responseStatus) + payloadDict = json.loads(payload) + print("++++++++DELTA++++++++++") + print("property: " + str(payloadDict["state"]["property"])) + print("version: " + str(payloadDict["version"])) + print("+++++++++++++++++++++++\n\n") + # Read in command-line parameters parser = argparse.ArgumentParser() @@ -50,27 +51,36 @@ def customShadowCallback_Delta(payload, responseStatus, token): parser.add_argument("-r", "--rootCA", action="store", required=True, dest="rootCAPath", help="Root CA file path") parser.add_argument("-c", "--cert", action="store", dest="certificatePath", help="Certificate file path") parser.add_argument("-k", "--key", action="store", dest="privateKeyPath", help="Private key file path") +parser.add_argument("-p", "--port", action="store", dest="port", type=int, help="Port number override") parser.add_argument("-w", "--websocket", action="store_true", dest="useWebsocket", default=False, help="Use MQTT over WebSocket") parser.add_argument("-n", "--thingName", action="store", dest="thingName", default="Bot", help="Targeted thing name") -parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="basicShadowDeltaListener", help="Targeted client id") +parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="basicShadowDeltaListener", + help="Targeted client id") args = parser.parse_args() host = args.host rootCAPath = args.rootCAPath certificatePath = args.certificatePath privateKeyPath = args.privateKeyPath +port = args.port useWebsocket = args.useWebsocket thingName = args.thingName clientId = args.clientId if args.useWebsocket and args.certificatePath and args.privateKeyPath: - parser.error("X.509 cert authentication and WebSocket are mutual exclusive. Please pick one.") - exit(2) + parser.error("X.509 cert authentication and WebSocket are mutual exclusive. 
Please pick one.") + exit(2) if not args.useWebsocket and (not args.certificatePath or not args.privateKeyPath): - parser.error("Missing credentials for authentication.") - exit(2) + parser.error("Missing credentials for authentication.") + exit(2) + +# Port defaults +if args.useWebsocket and not args.port: # When no port override for WebSocket, default to 443 + port = 443 +if not args.useWebsocket and not args.port: # When no port override for non-WebSocket, default to 8883 + port = 8883 # Configure logging logger = logging.getLogger("AWSIoTPythonSDK.core") @@ -83,13 +93,13 @@ def customShadowCallback_Delta(payload, responseStatus, token): # Init AWSIoTMQTTShadowClient myAWSIoTMQTTShadowClient = None if useWebsocket: - myAWSIoTMQTTShadowClient = AWSIoTMQTTShadowClient(clientId, useWebsocket=True) - myAWSIoTMQTTShadowClient.configureEndpoint(host, 443) - myAWSIoTMQTTShadowClient.configureCredentials(rootCAPath) + myAWSIoTMQTTShadowClient = AWSIoTMQTTShadowClient(clientId, useWebsocket=True) + myAWSIoTMQTTShadowClient.configureEndpoint(host, port) + myAWSIoTMQTTShadowClient.configureCredentials(rootCAPath) else: - myAWSIoTMQTTShadowClient = AWSIoTMQTTShadowClient(clientId) - myAWSIoTMQTTShadowClient.configureEndpoint(host, 8883) - myAWSIoTMQTTShadowClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath) + myAWSIoTMQTTShadowClient = AWSIoTMQTTShadowClient(clientId) + myAWSIoTMQTTShadowClient.configureEndpoint(host, port) + myAWSIoTMQTTShadowClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath) # AWSIoTMQTTShadowClient configuration myAWSIoTMQTTShadowClient.configureAutoReconnectBackoffTime(1, 32, 20) @@ -107,4 +117,4 @@ def customShadowCallback_Delta(payload, responseStatus, token): # Loop forever while True: - time.sleep(1) + time.sleep(1) diff --git a/samples/basicShadow/basicShadowUpdater.py b/samples/basicShadow/basicShadowUpdater.py index 8b7d39f..2f9b9e2 100755 --- a/samples/basicShadow/basicShadowUpdater.py +++ 
b/samples/basicShadow/basicShadowUpdater.py @@ -1,6 +1,6 @@ ''' /* - * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ ''' from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTShadowClient -import sys import logging import time import json @@ -65,6 +64,7 @@ def customShadowCallback_Delete(payload, responseStatus, token): parser.add_argument("-r", "--rootCA", action="store", required=True, dest="rootCAPath", help="Root CA file path") parser.add_argument("-c", "--cert", action="store", dest="certificatePath", help="Certificate file path") parser.add_argument("-k", "--key", action="store", dest="privateKeyPath", help="Private key file path") +parser.add_argument("-p", "--port", action="store", dest="port", type=int, help="Port number override") parser.add_argument("-w", "--websocket", action="store_true", dest="useWebsocket", default=False, help="Use MQTT over WebSocket") parser.add_argument("-n", "--thingName", action="store", dest="thingName", default="Bot", help="Targeted thing name") @@ -75,6 +75,7 @@ def customShadowCallback_Delete(payload, responseStatus, token): rootCAPath = args.rootCAPath certificatePath = args.certificatePath privateKeyPath = args.privateKeyPath +port = args.port useWebsocket = args.useWebsocket thingName = args.thingName clientId = args.clientId @@ -87,6 +88,12 @@ def customShadowCallback_Delete(payload, responseStatus, token): parser.error("Missing credentials for authentication.") exit(2) +# Port defaults +if args.useWebsocket and not args.port: # When no port override for WebSocket, default to 443 + port = 443 +if not args.useWebsocket and not args.port: # When no port override for non-WebSocket, default to 8883 + port = 8883 + # Configure logging logger = 
logging.getLogger("AWSIoTPythonSDK.core") logger.setLevel(logging.DEBUG) @@ -99,11 +106,11 @@ def customShadowCallback_Delete(payload, responseStatus, token): myAWSIoTMQTTShadowClient = None if useWebsocket: myAWSIoTMQTTShadowClient = AWSIoTMQTTShadowClient(clientId, useWebsocket=True) - myAWSIoTMQTTShadowClient.configureEndpoint(host, 443) + myAWSIoTMQTTShadowClient.configureEndpoint(host, port) myAWSIoTMQTTShadowClient.configureCredentials(rootCAPath) else: myAWSIoTMQTTShadowClient = AWSIoTMQTTShadowClient(clientId) - myAWSIoTMQTTShadowClient.configureEndpoint(host, 8883) + myAWSIoTMQTTShadowClient.configureEndpoint(host, port) myAWSIoTMQTTShadowClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath) # AWSIoTMQTTShadowClient configuration diff --git a/samples/greengrass/basicDiscovery.py b/samples/greengrass/basicDiscovery.py new file mode 100644 index 0000000..a6fcd61 --- /dev/null +++ b/samples/greengrass/basicDiscovery.py @@ -0,0 +1,188 @@ +# /* +# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# * +# * Licensed under the Apache License, Version 2.0 (the "License"). +# * You may not use this file except in compliance with the License. +# * A copy of the License is located at +# * +# * http://aws.amazon.com/apache2.0 +# * +# * or in the "license" file accompanying this file. This file is distributed +# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# * express or implied. See the License for the specific language governing +# * permissions and limitations under the License. 
+# */ + + +import os +import sys +import time +import uuid +import json +import logging +import argparse +from AWSIoTPythonSDK.core.greengrass.discovery.providers import DiscoveryInfoProvider +from AWSIoTPythonSDK.core.protocol.connection.cores import ProgressiveBackOffCore +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryInvalidRequestException + +AllowedActions = ['both', 'publish', 'subscribe'] + +# General message notification callback +def customOnMessage(message): + print('Received message on topic %s: %s\n' % (message.topic, message.payload)) + +MAX_DISCOVERY_RETRIES = 10 +GROUP_CA_PATH = "./groupCA/" + +# Read in command-line parameters +parser = argparse.ArgumentParser() +parser.add_argument("-e", "--endpoint", action="store", required=True, dest="host", help="Your AWS IoT custom endpoint") +parser.add_argument("-r", "--rootCA", action="store", required=True, dest="rootCAPath", help="Root CA file path") +parser.add_argument("-c", "--cert", action="store", dest="certificatePath", help="Certificate file path") +parser.add_argument("-k", "--key", action="store", dest="privateKeyPath", help="Private key file path") +parser.add_argument("-n", "--thingName", action="store", dest="thingName", default="Bot", help="Targeted thing name") +parser.add_argument("-t", "--topic", action="store", dest="topic", default="sdk/test/Python", help="Targeted topic") +parser.add_argument("-m", "--mode", action="store", dest="mode", default="both", + help="Operation modes: %s"%str(AllowedActions)) +parser.add_argument("-M", "--message", action="store", dest="message", default="Hello World!", + help="Message to publish") +#--print_discover_resp_only used for delopyment testing. The test run will return 0 as long as the SDK installed correctly. 
+parser.add_argument("-p", "--print_discover_resp_only", action="store_true", dest="print_only", default=False) + +args = parser.parse_args() +host = args.host +rootCAPath = args.rootCAPath +certificatePath = args.certificatePath +privateKeyPath = args.privateKeyPath +clientId = args.thingName +thingName = args.thingName +topic = args.topic +print_only = args.print_only + +if args.mode not in AllowedActions: + parser.error("Unknown --mode option %s. Must be one of %s" % (args.mode, str(AllowedActions))) + exit(2) + +if not args.certificatePath or not args.privateKeyPath: + parser.error("Missing credentials for authentication, you must specify --cert and --key args.") + exit(2) + +if not os.path.isfile(rootCAPath): + parser.error("Root CA path does not exist {}".format(rootCAPath)) + exit(3) + +if not os.path.isfile(certificatePath): + parser.error("No certificate found at {}".format(certificatePath)) + exit(3) + +if not os.path.isfile(privateKeyPath): + parser.error("No private key found at {}".format(privateKeyPath)) + exit(3) + +# Configure logging +logger = logging.getLogger("AWSIoTPythonSDK.core") +logger.setLevel(logging.DEBUG) +streamHandler = logging.StreamHandler() +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +streamHandler.setFormatter(formatter) +logger.addHandler(streamHandler) + +# Progressive back off core +backOffCore = ProgressiveBackOffCore() + +# Discover GGCs +discoveryInfoProvider = DiscoveryInfoProvider() +discoveryInfoProvider.configureEndpoint(host) +discoveryInfoProvider.configureCredentials(rootCAPath, certificatePath, privateKeyPath) +discoveryInfoProvider.configureTimeout(10) # 10 sec + +retryCount = MAX_DISCOVERY_RETRIES if not print_only else 1 +discovered = False +groupCA = None +coreInfo = None +while retryCount != 0: + try: + discoveryInfo = discoveryInfoProvider.discover(thingName) + caList = discoveryInfo.getAllCas() + coreList = discoveryInfo.getAllCores() + + # We only pick the first ca 
and core info + groupId, ca = caList[0] + coreInfo = coreList[0] + print("Discovered GGC: %s from Group: %s" % (coreInfo.coreThingArn, groupId)) + + print("Now we persist the connectivity/identity information...") + groupCA = GROUP_CA_PATH + groupId + "_CA_" + str(uuid.uuid4()) + ".crt" + if not os.path.exists(GROUP_CA_PATH): + os.makedirs(GROUP_CA_PATH) + groupCAFile = open(groupCA, "w") + groupCAFile.write(ca) + groupCAFile.close() + + discovered = True + print("Now proceed to the connecting flow...") + break + except DiscoveryInvalidRequestException as e: + print("Invalid discovery request detected!") + print("Type: %s" % str(type(e))) + print("Error message: %s" % str(e)) + print("Stopping...") + break + except BaseException as e: + print("Error in discovery!") + print("Type: %s" % str(type(e))) + print("Error message: %s" % str(e)) + retryCount -= 1 + print("\n%d/%d retries left\n" % (retryCount, MAX_DISCOVERY_RETRIES)) + print("Backing off...\n") + backOffCore.backOff() + +if not discovered: + # With print_discover_resp_only flag, we only woud like to check if the API get called correctly. + if print_only: + sys.exit(0) + print("Discovery failed after %d retries. 
Exiting...\n" % (MAX_DISCOVERY_RETRIES)) + sys.exit(-1) + +# Iterate through all connection options for the core and use the first successful one +myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId) +myAWSIoTMQTTClient.configureCredentials(groupCA, privateKeyPath, certificatePath) +myAWSIoTMQTTClient.onMessage = customOnMessage + +connected = False +for connectivityInfo in coreInfo.connectivityInfoList: + currentHost = connectivityInfo.host + currentPort = connectivityInfo.port + print("Trying to connect to core at %s:%d" % (currentHost, currentPort)) + myAWSIoTMQTTClient.configureEndpoint(currentHost, currentPort) + try: + myAWSIoTMQTTClient.connect() + connected = True + break + except BaseException as e: + print("Error in connect!") + print("Type: %s" % str(type(e))) + print("Error message: %s" % str(e)) + +if not connected: + print("Cannot connect to core %s. Exiting..." % coreInfo.coreThingArn) + sys.exit(-2) + +# Successfully connected to the core +if args.mode == 'both' or args.mode == 'subscribe': + myAWSIoTMQTTClient.subscribe(topic, 0, None) +time.sleep(2) + +loopCount = 0 +while True: + if args.mode == 'both' or args.mode == 'publish': + message = {} + message['message'] = args.message + message['sequence'] = loopCount + messageJson = json.dumps(message) + myAWSIoTMQTTClient.publish(topic, messageJson, 0) + if args.mode == 'publish': + print('Published topic %s: %s\n' % (topic, messageJson)) + loopCount += 1 + time.sleep(1) diff --git a/samples/jobs/jobsSample.py b/samples/jobs/jobsSample.py new file mode 100644 index 0000000..7cbd27e --- /dev/null +++ b/samples/jobs/jobsSample.py @@ -0,0 +1,178 @@ +''' +/* + * Copyright 2010-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. 
+ * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + ''' + +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTThingJobsClient +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicReplyType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionStatus + +import threading +import logging +import time +import datetime +import argparse +import json + +class JobsMessageProcessor(object): + def __init__(self, awsIoTMQTTThingJobsClient, clientToken): + #keep track of this to correlate request/responses + self.clientToken = clientToken + self.awsIoTMQTTThingJobsClient = awsIoTMQTTThingJobsClient + self.done = False + self.jobsStarted = 0 + self.jobsSucceeded = 0 + self.jobsRejected = 0 + self._setupCallbacks(self.awsIoTMQTTThingJobsClient) + + def _setupCallbacks(self, awsIoTMQTTThingJobsClient): + self.awsIoTMQTTThingJobsClient.createJobSubscription(self.newJobReceived, jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) + self.awsIoTMQTTThingJobsClient.createJobSubscription(self.startNextJobSuccessfullyInProgress, jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + self.awsIoTMQTTThingJobsClient.createJobSubscription(self.startNextRejected, jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + + # '+' indicates a wildcard for jobId in the following subscriptions + self.awsIoTMQTTThingJobsClient.createJobSubscription(self.updateJobSuccessful, jobExecutionTopicType.JOB_UPDATE_TOPIC, 
jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, '+') + self.awsIoTMQTTThingJobsClient.createJobSubscription(self.updateJobRejected, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, '+') + + #call back on successful job updates + def startNextJobSuccessfullyInProgress(self, client, userdata, message): + payload = json.loads(message.payload.decode('utf-8')) + if 'execution' in payload: + self.jobsStarted += 1 + execution = payload['execution'] + self.executeJob(execution) + statusDetails = {'HandledBy': 'ClientToken: {}'.format(self.clientToken)} + threading.Thread(target = self.awsIoTMQTTThingJobsClient.sendJobsUpdate, kwargs = {'jobId': execution['jobId'], 'status': jobExecutionStatus.JOB_EXECUTION_SUCCEEDED, 'statusDetails': statusDetails, 'expectedVersion': execution['versionNumber'], 'executionNumber': execution['executionNumber']}).start() + else: + print('Start next saw no execution: ' + message.payload.decode('utf-8')) + self.done = True + + def executeJob(self, execution): + print('Executing job ID, version, number: {}, {}, {}'.format(execution['jobId'], execution['versionNumber'], execution['executionNumber'])) + print('With jobDocument: ' + json.dumps(execution['jobDocument'])) + + def newJobReceived(self, client, userdata, message): + payload = json.loads(message.payload.decode('utf-8')) + if 'execution' in payload: + self._attemptStartNextJob() + else: + print('Notify next saw no execution') + self.done = True + + def processJobs(self): + self.done = False + self._attemptStartNextJob() + + def startNextRejected(self, client, userdata, message): + print('Start next rejected:' + message.payload.decode('utf-8')) + self.jobsRejected += 1 + + def updateJobSuccessful(self, client, userdata, message): + self.jobsSucceeded += 1 + + def updateJobRejected(self, client, userdata, message): + self.jobsRejected += 1 + + def _attemptStartNextJob(self): + statusDetails = {'StartedBy': 'ClientToken: {} on
{}'.format(self.clientToken, datetime.datetime.now().isoformat())} + threading.Thread(target=self.awsIoTMQTTThingJobsClient.sendJobsStartNext, kwargs = {'statusDetails': statusDetails}).start() + + def isDone(self): + return self.done + + def getStats(self): + stats = {} + stats['jobsStarted'] = self.jobsStarted + stats['jobsSucceeded'] = self.jobsSucceeded + stats['jobsRejected'] = self.jobsRejected + return stats + +# Read in command-line parameters +parser = argparse.ArgumentParser() +parser.add_argument("-n", "--thingName", action="store", dest="thingName", help="Your AWS IoT ThingName to process jobs for") +parser.add_argument("-e", "--endpoint", action="store", required=True, dest="host", help="Your AWS IoT custom endpoint") +parser.add_argument("-r", "--rootCA", action="store", required=True, dest="rootCAPath", help="Root CA file path") +parser.add_argument("-c", "--cert", action="store", dest="certificatePath", help="Certificate file path") +parser.add_argument("-k", "--key", action="store", dest="privateKeyPath", help="Private key file path") +parser.add_argument("-p", "--port", action="store", dest="port", type=int, help="Port number override") +parser.add_argument("-w", "--websocket", action="store_true", dest="useWebsocket", default=False, + help="Use MQTT over WebSocket") +parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="basicJobsSampleClient", + help="Targeted client id") + +args = parser.parse_args() +host = args.host +rootCAPath = args.rootCAPath +certificatePath = args.certificatePath +privateKeyPath = args.privateKeyPath +port = args.port +useWebsocket = args.useWebsocket +clientId = args.clientId +thingName = args.thingName + +if args.useWebsocket and args.certificatePath and args.privateKeyPath: + parser.error("X.509 cert authentication and WebSocket are mutual exclusive. 
Please pick one.") + exit(2) + +if not args.useWebsocket and (not args.certificatePath or not args.privateKeyPath): + parser.error("Missing credentials for authentication.") + exit(2) + +# Port defaults +if args.useWebsocket and not args.port: # When no port override for WebSocket, default to 443 + port = 443 +if not args.useWebsocket and not args.port: # When no port override for non-WebSocket, default to 8883 + port = 8883 + +# Configure logging +logger = logging.getLogger("AWSIoTPythonSDK.core") +logger.setLevel(logging.DEBUG) +streamHandler = logging.StreamHandler() +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +streamHandler.setFormatter(formatter) +logger.addHandler(streamHandler) + +# Init AWSIoTMQTTClient +myAWSIoTMQTTClient = None +if useWebsocket: + myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId, useWebsocket=True) + myAWSIoTMQTTClient.configureEndpoint(host, port) + myAWSIoTMQTTClient.configureCredentials(rootCAPath) +else: + myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId) + myAWSIoTMQTTClient.configureEndpoint(host, port) + myAWSIoTMQTTClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath) + +# AWSIoTMQTTClient connection configuration +myAWSIoTMQTTClient.configureAutoReconnectBackoffTime(1, 32, 20) +myAWSIoTMQTTClient.configureConnectDisconnectTimeout(10) # 10 sec +myAWSIoTMQTTClient.configureMQTTOperationTimeout(10) # 10 sec + +jobsClient = AWSIoTMQTTThingJobsClient(clientId, thingName, QoS=1, awsIoTMQTTClient=myAWSIoTMQTTClient) + +print('Connecting to MQTT server and setting up callbacks...') +jobsClient.connect() +jobsMsgProc = JobsMessageProcessor(jobsClient, clientId) +print('Starting to process jobs...') +jobsMsgProc.processJobs() +while not jobsMsgProc.isDone(): + time.sleep(2) + +print('Done processing jobs') +print('Stats: ' + json.dumps(jobsMsgProc.getStats())) + +jobsClient.disconnect() diff --git a/setup.py b/setup.py index 86ba48a..0ca4cfa 100644 --- a/setup.py +++ b/setup.py @@ 
-3,14 +3,15 @@ import AWSIoTPythonSDK currentVersion = AWSIoTPythonSDK.__version__ -from distutils.core import setup +from setuptools import setup setup( name = 'AWSIoTPythonSDK', - packages = ['AWSIoTPythonSDK', "AWSIoTPythonSDK.core", \ - "AWSIoTPythonSDK.exception", "AWSIoTPythonSDK.core.shadow", \ - "AWSIoTPythonSDK.core.util", \ - "AWSIoTPythonSDK.core.protocol", "AWSIoTPythonSDK.core.protocol.paho", \ - "AWSIoTPythonSDK.core.protocol.paho.securedWebsocket"], + packages=['AWSIoTPythonSDK', 'AWSIoTPythonSDK.core', + 'AWSIoTPythonSDK.core.util', 'AWSIoTPythonSDK.core.shadow', 'AWSIoTPythonSDK.core.protocol', + 'AWSIoTPythonSDK.core.jobs', + 'AWSIoTPythonSDK.core.protocol.paho', 'AWSIoTPythonSDK.core.protocol.internal', + 'AWSIoTPythonSDK.core.protocol.connection', 'AWSIoTPythonSDK.core.greengrass', + 'AWSIoTPythonSDK.core.greengrass.discovery', 'AWSIoTPythonSDK.exception'], version = currentVersion, description = 'SDK for connecting to AWS IoT using Python.', author = 'Amazon Web Service', @@ -19,15 +20,11 @@ download_url = 'https://s3.amazonaws.com/aws-iot-device-sdk-python/aws-iot-device-sdk-python-latest.zip', keywords = ['aws', 'iot', 'mqtt'], classifiers = [ - "Development Status :: 5 - Production/Stable", \ - "Intended Audience :: Developers", \ - "Natural Language :: English", \ - "License :: OSI Approved :: Apache Software License", \ - "Programming Language :: Python", \ - "Programming Language :: Python :: 2.7", \ - "Programming Language :: Python :: 3", \ - "Programming Language :: Python :: 3.3", \ - "Programming Language :: Python :: 3.4", \ - "Programming Language :: Python :: 3.5" + "Development Status :: 6 - Mature", + "Intended Audience :: Developers", + "Natural Language :: English", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3" ] ) diff --git a/setup_test.py b/setup_test.py new file mode 100644 index 0000000..2a4e78e --- /dev/null +++ b/setup_test.py @@ -0,0 
+1,34 @@ +# For test deployment with package AWSIoTPythonSDK. The package name has already taken. Therefore we used an +# alternative name for test pypi. +# prod_pypi : AWSIoTPythonSDK +# test_pypi : AWSIoTPythonSDK-V1 +import sys +sys.path.insert(0, 'AWSIoTPythonSDK') +import AWSIoTPythonSDK +currentVersion = AWSIoTPythonSDK.__version__ + +from distutils.core import setup +setup( + name = 'AWSIoTPythonSDK-V1', + packages=['AWSIoTPythonSDK', 'AWSIoTPythonSDK.core', + 'AWSIoTPythonSDK.core.util', 'AWSIoTPythonSDK.core.shadow', 'AWSIoTPythonSDK.core.protocol', + 'AWSIoTPythonSDK.core.jobs', + 'AWSIoTPythonSDK.core.protocol.paho', 'AWSIoTPythonSDK.core.protocol.internal', + 'AWSIoTPythonSDK.core.protocol.connection', 'AWSIoTPythonSDK.core.greengrass', + 'AWSIoTPythonSDK.core.greengrass.discovery', 'AWSIoTPythonSDK.exception'], + version = currentVersion, + description = 'SDK for connecting to AWS IoT using Python.', + author = 'Amazon Web Service', + author_email = '', + url = 'https://github.com/aws/aws-iot-device-sdk-python.git', + download_url = 'https://s3.amazonaws.com/aws-iot-device-sdk-python/aws-iot-device-sdk-python-latest.zip', + keywords = ['aws', 'iot', 'mqtt'], + classifiers = [ + "Development Status :: 6 - Mature", + "Intended Audience :: Developers", + "Natural Language :: English", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + ] +) diff --git a/test-integration/Credentials/.gitignore b/test-integration/Credentials/.gitignore new file mode 100644 index 0000000..94548af --- /dev/null +++ b/test-integration/Credentials/.gitignore @@ -0,0 +1,3 @@ +* +*/ +!.gitignore diff --git a/test-integration/IntegrationTests/IntegrationTestAsyncAPIGeneralNotificationCallbacks.py b/test-integration/IntegrationTests/IntegrationTestAsyncAPIGeneralNotificationCallbacks.py new file mode 100644 index 0000000..577c5fa --- /dev/null +++ 
b/test-integration/IntegrationTests/IntegrationTestAsyncAPIGeneralNotificationCallbacks.py @@ -0,0 +1,159 @@ +# This integration test verifies the functionality of asynchronous API for plain MQTT operations, as well as general +# notification callbacks. There are 2 phases for this test: +# a) Testing async APIs + onMessage general notification callback +# b) Testing onOnline, onOffline notification callbacks +# To achieve test goal a) and b), the client will follow the routine described below: +# 1. Client does async connect to AWS IoT and captures the CONNACK event and onOnline callback event in the record +# 2. Client does async subscribe to a topic and captures the SUBACK event in the record +# 3. Client does several async publish (QoS1) to the same topic and captures the PUBACK event in the record +# 4. Since client subscribes and publishes to the same topic, onMessage callback should be triggered. We capture these +# events as well in the record. +# 5. Client does async disconnect. This would trigger the offline callback and disconnect event callback. We capture +# them in the record. +# We should be able to receive all ACKs for all operations and corresponding general notification callback triggering +# events. 
+ + +import random +import string +import time +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from TestToolLibrary.checkInManager import checkInManager +from TestToolLibrary.MQTTClientManager import MQTTClientManager +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + + +TOPIC = "topic/test/async_cb/" +MESSAGE_PREFIX = "MagicMessage-" +NUMBER_OF_PUBLISHES = 3 +ROOT_CA = "./test-integration/Credentials/rootCA.crt" +CERT = "./test-integration/Credentials/certificate.pem.crt" +KEY = "./test-integration/Credentials/privateKey.pem.key" +CLIENT_ID = "PySdkIntegTest_AsyncAPI_Callbacks" + +KEY_ON_ONLINE = "OnOnline" +KEY_ON_OFFLINE = "OnOffline" +KEY_ON_MESSAGE = "OnMessage" +KEY_CONNACK = "Connack" +KEY_DISCONNECT = "Disconnect" +KEY_PUBACK = "Puback" +KEY_SUBACK = "Suback" +KEY_UNSUBACK = "Unsuback" + + +class CallbackManager(object): + + def __init__(self): + self.callback_invocation_record = { + KEY_ON_ONLINE : 0, + KEY_ON_OFFLINE : 0, + KEY_ON_MESSAGE : 0, + KEY_CONNACK : 0, + KEY_DISCONNECT : 0, + KEY_PUBACK : 0, + KEY_SUBACK : 0, + KEY_UNSUBACK : 0 + } + + def on_online(self): + print("OMG, I am online!") + self.callback_invocation_record[KEY_ON_ONLINE] += 1 + + def on_offline(self): + print("OMG, I am offline!") + self.callback_invocation_record[KEY_ON_OFFLINE] += 1 + + def on_message(self, message): + print("OMG, I got a message!") + self.callback_invocation_record[KEY_ON_MESSAGE] += 1 + + def connack(self, mid, data): + print("OMG, I got a connack!") + self.callback_invocation_record[KEY_CONNACK] += 1 + + def disconnect(self, mid, data): + print("OMG, I got a disconnect!") + self.callback_invocation_record[KEY_DISCONNECT] += 1 + + def puback(self, mid): + print("OMG, I got a 
puback!") + self.callback_invocation_record[KEY_PUBACK] += 1 + + def suback(self, mid, data): + print("OMG, I got a suback!") + self.callback_invocation_record[KEY_SUBACK] += 1 + + def unsuback(self, mid): + print("OMG, I got an unsuback!") + self.callback_invocation_record[KEY_UNSUBACK] += 1 + + +def get_random_string(length): + return "".join(random.choice(string.ascii_lowercase) for i in range(length)) + + +############################################################################ +# Main # +# Check inputs +my_check_in_manager = checkInManager(2) +my_check_in_manager.verify(sys.argv) +mode = my_check_in_manager.mode +host = my_check_in_manager.host + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." % (mode, sys.version_info[:3])) + +# Performing +############ +print("Connecting...") +callback_manager = CallbackManager() +sdk_mqtt_client = MQTTClientManager()\ + .create_nonconnected_mqtt_client(mode, CLIENT_ID, host, (ROOT_CA, CERT, KEY), callback_manager) +sdk_mqtt_client.connectAsync(keepAliveIntervalSecond=1, ackCallback=callback_manager.connack) # Add callback +print("Wait some time to make sure we are connected...") +time.sleep(10) # 10 sec + +topic = TOPIC + get_random_string(4) +print("Subscribing to topic: " + topic) +sdk_mqtt_client.subscribeAsync(topic, 1, ackCallback=callback_manager.suback, messageCallback=None) +print("Wait some time to make sure we are subscribed...") +time.sleep(3) # 3 sec + +print("Publishing...") +for i in range(NUMBER_OF_PUBLISHES): + sdk_mqtt_client.publishAsync(topic, MESSAGE_PREFIX + str(i), 1, ackCallback=callback_manager.puback) + time.sleep(1) +print("Wait sometime to make sure we finished with publishing...") +time.sleep(2) + +print("Unsubscribing...") +sdk_mqtt_client.unsubscribeAsync(topic, ackCallback=callback_manager.unsuback) +print("Wait sometime to make sure we 
finished with unsubscribing...") +time.sleep(2) + +print("Disconnecting...") +sdk_mqtt_client.disconnectAsync(ackCallback=callback_manager.disconnect) + +print("Wait sometime to let the test result sync...") +time.sleep(3) + +print("Verifying...") +try: + assert callback_manager.callback_invocation_record[KEY_ON_ONLINE] == 1 + assert callback_manager.callback_invocation_record[KEY_CONNACK] == 1 + assert callback_manager.callback_invocation_record[KEY_SUBACK] == 1 + assert callback_manager.callback_invocation_record[KEY_PUBACK] == NUMBER_OF_PUBLISHES + assert callback_manager.callback_invocation_record[KEY_ON_MESSAGE] == NUMBER_OF_PUBLISHES + assert callback_manager.callback_invocation_record[KEY_UNSUBACK] == 1 + assert callback_manager.callback_invocation_record[KEY_DISCONNECT] == 1 + assert callback_manager.callback_invocation_record[KEY_ON_OFFLINE] == 1 +except BaseException as e: + print("Failed! %s" % e.message) +print("Pass!") diff --git a/test-integration/IntegrationTests/IntegrationTestAutoReconnectResubscribe.py b/test-integration/IntegrationTests/IntegrationTestAutoReconnectResubscribe.py new file mode 100644 index 0000000..e6c1bee --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestAutoReconnectResubscribe.py @@ -0,0 +1,202 @@ +# This integration test verifies the functionality in the Python core of Yun/Python SDK +# for auto-reconnect and auto-resubscribe. +# It starts two threads using two different connections to AWS IoT: +# Thread A publishes 10 messages to topicB first, then quiet for a while, and finally +# publishes another 10 messages to topicB. +# Thread B subscribes to topicB and waits to receive messages. Once it receives the first +# 10 messages. It simulates a network error, disconnecting from the broker. In a short time, +# it should automatically reconnect and resubscribe to the previous topic and be able to +# receive the next 10 messages from thread A. 
+# Because of auto-reconnect/resubscribe, thread B should be able to receive all of the +# messages from topicB published by thread A without calling subscribe again in user code +# explicitly. + + +import random +import string +import sys +import time +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +import TestToolLibrary.checkInManager as checkInManager +import TestToolLibrary.MQTTClientManager as MQTTClientManager +from TestToolLibrary import simpleThreadManager +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + +CLIENT_ID_PUB = "integrationTestMQTT_ClientPub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) +CLIENT_ID_SUB = "integrationTestMQTT_ClientSub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) + +# Callback unit +class callbackUnit: + def __init__(self): + self._internalSet = set() + + # Callback fro clientSub + def messageCallback(self, client, userdata, message): + print("Received a new message: " + str(message.payload)) + self._internalSet.add(message.payload.decode('utf-8')) + + def getInternalSet(self): + return self._internalSet + + +# Simulate a network error +def manualNetworkError(srcPyMQTTCore): + # Ensure we close the socket + if srcPyMQTTCore._internal_async_client._paho_client._sock: + srcPyMQTTCore._internal_async_client._paho_client._sock.close() + srcPyMQTTCore._internal_async_client._paho_client._sock = None + if 
srcPyMQTTCore._internal_async_client._paho_client._ssl: + srcPyMQTTCore._internal_async_client._paho_client._ssl.close() + srcPyMQTTCore._internal_async_client._paho_client._ssl = None + # Fake that we have detected the disconnection + srcPyMQTTCore._internal_async_client._paho_client.on_disconnect(None, None, 0) + + +# runFunctionUnit +class runFunctionUnit(): + def __init__(self): + self._messagesPublished = set() + self._topicB = "topicB/" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) + + # ThreadA runtime function: + # 1. Publish 10 messages to topicB. + # 2. Take a nap: 20 sec + # 3. Publish another 10 messages to topicB. + def threadARuntime(self, pyCoreClient): + time.sleep(3) # Ensure a valid subscription + messageCount = 0 + # First 10 messages + while messageCount < 10: + try: + pyCoreClient.publish(self._topicB, str(messageCount), 1, False) + self._messagesPublished.add(str(messageCount)) + except publishError: + print("Publish error!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + messageCount += 1 + time.sleep(0.5) # TPS = 2 + # Take a nap + time.sleep(20) + # Second 10 messages + while messageCount < 20: + try: + pyCoreClient.publish(self._topicB, str(messageCount), 1, False) + self._messagesPublished.add(str(messageCount)) + except publishError: + print("Publish Error!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + messageCount += 1 + time.sleep(0.5) + print("Publish thread terminated.") + + # ThreadB runtime function: + # 1. Subscribe to topicB + # 2. Wait for a while + # 3. Create network blocking, triggering auto-reconnect and auto-resubscribe + # 4. 
On connect, wait for another while + def threadBRuntime(self, pyCoreClient, callback): + try: + # Subscribe to topicB + pyCoreClient.subscribe(self._topicB, 1, callback) + except subscribeTimeoutException: + print("Subscribe timeout!") + except subscribeError: + print("Subscribe error!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + # Wait to get the first 10 messages from thread A + time.sleep(10) + # Block the network for 3 sec + print("Block the network for 3 sec...") + blockingTimeTenMs = 300 + while blockingTimeTenMs != 0: + manualNetworkError(pyCoreClient) + blockingTimeTenMs -= 1 + time.sleep(0.01) + print("Leave it to the main thread to keep waiting...") + + +############################################################################ +# Main # +# Check inputs +myCheckInManager = checkInManager.checkInManager(2) +myCheckInManager.verify(sys.argv) + +rootCA = "./test-integration/Credentials/rootCA.crt" +certificate = "./test-integration/Credentials/certificate.pem.crt" +privateKey = "./test-integration/Credentials/privateKey.pem.key" +mode = myCheckInManager.mode +host = myCheckInManager.host + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." 
% (mode, sys.version_info[:3])) + +# Init Python core and connect +myMQTTClientManager = MQTTClientManager.MQTTClientManager() +clientPub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_PUB, host, rootCA, + certificate, privateKey, mode=mode) +clientSub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_SUB, host, rootCA, + certificate, privateKey, mode=mode) + +if clientPub is None or clientSub is None: + print("Clients not init!") + exit(4) + +print("Two clients are connected!") + +# Configurations +################ +# Callback unit +subCallbackUnit = callbackUnit() +# Threads +mySimpleThreadManager = simpleThreadManager.simpleThreadManager() +myRunFunctionUnit = runFunctionUnit() +publishThreadID = mySimpleThreadManager.createOneTimeThread(myRunFunctionUnit.threadARuntime, [clientPub]) +subscribeThreadID = mySimpleThreadManager.createOneTimeThread(myRunFunctionUnit.threadBRuntime, + [clientSub, subCallbackUnit.messageCallback]) + +# Performing +############ +mySimpleThreadManager.startThreadWithID(subscribeThreadID) +mySimpleThreadManager.startThreadWithID(publishThreadID) +mySimpleThreadManager.joinOneTimeThreadWithID(subscribeThreadID) +mySimpleThreadManager.joinOneTimeThreadWithID(publishThreadID) +time.sleep(3) # Just in case messages arrive slowly + +# Verifying +########### +# Length +print("Check if the length of the two sets are equal...") +print("Received from subscription: " + str(len(subCallbackUnit.getInternalSet()))) +print("Sent through publishes: " + str(len(myRunFunctionUnit._messagesPublished))) +if len(myRunFunctionUnit._messagesPublished) != len(subCallbackUnit.getInternalSet()): + print("Number of messages not equal!") + exit(4) +# Content +print("Check if the content if the two sets are equivalent...") +if myRunFunctionUnit._messagesPublished != subCallbackUnit.getInternalSet(): + print("Sent through publishes:") + print(myRunFunctionUnit._messagesPublished) + print("Received from subscription:") + 
print(subCallbackUnit.getInternalSet()) + print("Set content not equal!") + exit(4) +else: + print("Yes!") diff --git a/test-integration/IntegrationTests/IntegrationTestClientReusability.py b/test-integration/IntegrationTests/IntegrationTestClientReusability.py new file mode 100644 index 0000000..56e77b8 --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestClientReusability.py @@ -0,0 +1,128 @@ +# This integration test verifies the re-usability of SDK MQTT client. +# By saying re-usability, we mean that users should be able to reuse +# the same SDK MQTT client object to connect and invoke other APIs +# after a disconnect API call has been invoked on that client object. +# This test contains 2 clients living 2 separate threads: +# 1. Thread publish: In this thread, a MQTT client will do the following +# in a loop: +# a. Connect to AWS IoT +# b. Publish several messages to a dedicated topic +# c. Disconnect from AWS IoT +# d. Sleep for a while +# 2. Thread subscribe: In this thread, a MQTT client will do nothing +# other than subscribing to a dedicated topic and counting the incoming +# messages. +# Assuming the client is reusable, the subscriber should be able to +# receive all the messages published by the publisher from the same +# client object in different connect sessions. 
+ + +import uuid +import time +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from threading import Event +from TestToolLibrary.checkInManager import checkInManager +from TestToolLibrary.simpleThreadManager import simpleThreadManager +from TestToolLibrary.MQTTClientManager import MQTTClientManager +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + + +TOPIC = "topic/" + str(uuid.uuid1()) +CLIENT_ID_PUB = "publisher" + str(uuid.uuid1()) +CLIENT_ID_SUB = "subscriber" + str(uuid.uuid1()) +MESSAGE_PREFIX = "Message-" +NUMBER_OF_MESSAGES_PER_LOOP = 3 +NUMBER_OF_LOOPS = 3 +SUB_WAIT_TIME_OUT_SEC = 20 +ROOT_CA = "./test-integration/Credentials/rootCA.crt" +CERT = "./test-integration/Credentials/certificate.pem.crt" +KEY = "./test-integration/Credentials/privateKey.pem.key" + + +class ClientTwins(object): + + def __init__(self, client_pub, client_sub): + self._client_pub = client_pub + self._client_sub = client_sub + self._message_publish_set = set() + self._message_receive_set = set() + self._publish_done = Event() + + def run_publisher(self, *params): + self._publish_done.clear() + time.sleep(3) + for i in range(NUMBER_OF_LOOPS): + self._single_publish_loop(i) + time.sleep(2) + self._publish_done.set() + + def _single_publish_loop(self, iteration_count): + print("In loop %d: " % iteration_count) + self._client_pub.connect() + print("Publisher connected!") + for i in range(NUMBER_OF_MESSAGES_PER_LOOP): + message = MESSAGE_PREFIX + str(iteration_count) + "_" + str(i) + self._client_pub.publish(TOPIC, message, 1) + print("Publisher published %s to topic %s" % (message, TOPIC)) + self._message_publish_set.add(message.encode("utf-8")) + time.sleep(1) + self._client_pub.disconnect() + 
print("Publisher disconnected!\n\n") + + def run_subscriber(self, *params): + self._client_sub.connect() + self._client_sub.subscribe(TOPIC, 1, self._callback) + self._publish_done.wait(20) + self._client_sub.disconnect() + + def _callback(self, client, user_data, message): + self._message_receive_set.add(message.payload) + print("Subscriber received %s from topic %s" % (message.payload, message.topic)) + + def verify(self): + assert len(self._message_receive_set) != 0 + assert len(self._message_publish_set) != 0 + assert self._message_publish_set == self._message_receive_set + + +############################################################################ +# Main # +my_check_in_manager = checkInManager(2) +my_check_in_manager.verify(sys.argv) +mode = my_check_in_manager.mode +host = my_check_in_manager.host + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." 
% (mode, sys.version_info[:3])) + +simple_thread_manager = simpleThreadManager() + +client_pub = MQTTClientManager().create_nonconnected_mqtt_client(mode, CLIENT_ID_PUB, host, (ROOT_CA, CERT, KEY)) +print("Client publisher initialized.") +client_sub = MQTTClientManager().create_nonconnected_mqtt_client(mode, CLIENT_ID_SUB, host, (ROOT_CA, CERT, KEY)) +print("Client subscriber initialized.") +client_twins = ClientTwins(client_pub, client_sub) +print("Client twins initialized.") + +publisher_thread_id = simple_thread_manager.createOneTimeThread(client_twins.run_publisher, []) +subscriber_thread_id = simple_thread_manager.createOneTimeThread(client_twins.run_subscriber, []) +simple_thread_manager.startThreadWithID(subscriber_thread_id) +print("Started subscriber thread.") +simple_thread_manager.startThreadWithID(publisher_thread_id) +print("Started publisher thread.") + +print("Main thread starts waiting.") +simple_thread_manager.joinOneTimeThreadWithID(publisher_thread_id) +simple_thread_manager.joinOneTimeThreadWithID(subscriber_thread_id) +print("Main thread waiting is done!") + +print("Verifying...") +client_twins.verify() +print("Pass!") diff --git a/test-integration/IntegrationTests/IntegrationTestConfigurablePublishMessageQueueing.py b/test-integration/IntegrationTests/IntegrationTestConfigurablePublishMessageQueueing.py new file mode 100644 index 0000000..0d78f4f --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestConfigurablePublishMessageQueueing.py @@ -0,0 +1,305 @@ +# This integration test verifies the functionality in the Python core of Yun SDK +# for configurable offline publish message queueing. +# For each offline publish queue to be tested, it starts two threads using +# different connections to AWS IoT: +# Thread A subscribes to TopicOnly and wait to receive messages published to +# TopicOnly from ThreadB. +# Thread B publishes to TopicOnly with a manual network error which triggers the +# offline publish message queueing. 
According to different configurations, the +# internal queue should keep as many publish requests as configured and then +# republish them once the connection is back. +# * After the network is down but before the client gets the notification of being +# * disconnected, QoS0 messages in between this "blind-window" will be lost. However, +# * once the client gets the notification, it should start queueing messages up to +# * its queue size limit. +# * Therefore, all published messages are QoS0, we are verifying the total amount. +# * Configuration to be tested: +# 1. Limited queueing section, limited response (in-flight) section, drop oldest +# 2. Limited queueing section, limited response (in-flight) section, drop newest + + +import threading +import sys +import time +import random +import string +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +import TestToolLibrary.checkInManager as checkInManager +import TestToolLibrary.MQTTClientManager as MQTTClientManager +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueFullException +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + +CLIENT_ID_PUB = "integrationTestMQTT_ClientPub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) +CLIENT_ID_SUB = "integrationTestMQTT_ClientSub" + 
"".join(random.choice(string.ascii_lowercase) for i in range(4)) + +# Class that implements the publishing thread: Thread A, with network failure +# This thread will publish 3 messages first, and then keep publishing +# with a network failure, and then publish another set of 3 messages +# once the connection is resumed. +# * TPS = 1 +class threadPub: + def __init__(self, pyCoreClient, numberOfOfflinePublish, srcTopic): + self._publishMessagePool = list() + self._pyCoreClient = pyCoreClient + self._numberOfOfflinePublish = numberOfOfflinePublish + self._topic = srcTopic + + # Simulate a network error + def _manualNetworkError(self): + # Ensure we close the socket + if self._pyCoreClient._internal_async_client._paho_client._sock: + self._pyCoreClient._internal_async_client._paho_client._sock.close() + self._pyCoreClient._internal_async_client._paho_client._sock = None + if self._pyCoreClient._internal_async_client._paho_client._ssl: + self._pyCoreClient._internal_async_client._paho_client._ssl.close() + self._pyCoreClient._internal_async_client._paho_client._ssl = None + # Fake that we have detected the disconnection + self._pyCoreClient._internal_async_client._paho_client.on_disconnect(None, None, 0) + + def _runtime(self): + messageCount = 0 + # Publish 3 messages + print("Thread A: Publish 3 messages.") + step1PublishCount = 3 + while step1PublishCount != 0: + currentMessage = str(messageCount) + self._publishMessagePool.append(int(currentMessage)) + try: + self._pyCoreClient.publish(self._topic, currentMessage, 0, False) + print("Thread A: Published a message: " + str(currentMessage)) + step1PublishCount -= 1 + messageCount += 1 + except publishError: + print("Publish Error!") + except publishQueueFullException: + print("Internal Publish Queue is FULL!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + time.sleep(1) + # Network Failure, publish #numberOfOfflinePublish# messages + # 
Scanning rate = 100 TPS + print( + "Thread A: Simulate an network error. Keep publishing for " + str(self._numberOfOfflinePublish) + " messages.") + step2LoopCount = self._numberOfOfflinePublish * 100 + while step2LoopCount != 0: + self._manualNetworkError() + if step2LoopCount % 100 == 0: + currentMessage = str(messageCount) + self._publishMessagePool.append(int(currentMessage)) + try: + self._pyCoreClient.publish(self._topic, currentMessage, 0, False) + print("Thread A: Published a message: " + str(currentMessage)) + except publishError: + print("Publish Error!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + messageCount += 1 + step2LoopCount -= 1 + time.sleep(0.01) + # Reconnecting + reconnectTiming = 0 # count per 0.01 seconds + while reconnectTiming <= 1000: + if reconnectTiming % 100 == 0: + print("Thread A: Counting reconnect time: " + str(reconnectTiming / 100) + "seconds.") + reconnectTiming += 1 + time.sleep(0.01) + print("Thread A: Reconnected!") + # Publish another set of 3 messages + print("Thread A: Publish 3 messages again.") + step3PublishCount = 3 + while step3PublishCount != 0: + currentMessage = str(messageCount) + self._publishMessagePool.append(int(currentMessage)) + try: + self._pyCoreClient.publish(self._topic, currentMessage, 0, False) + print("Thread A: Published a message: " + str(currentMessage)) + step3PublishCount -= 1 + messageCount += 1 + except publishError: + print("Publish Error!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + time.sleep(1) + # Wrap up: Sleep for extra 5 seconds + time.sleep(5) + + def startThreadAndGo(self): + threadHandler = threading.Thread(target=self._runtime) + threadHandler.start() + return threadHandler + + def getPublishMessagePool(self): + return self._publishMessagePool + + +# Class that implements the subscribing thread: Thread B. 
+# Basically this thread does nothing but subscribes to TopicOnly and keeps receiving messages. +class threadSub: + def __init__(self, pyCoreClient, srcTopic): + self._keepRunning = True + self._pyCoreClient = pyCoreClient + self._subscribeMessagePool = list() + self._topic = srcTopic + + def _messageCallback(self, client, userdata, message): + print("Thread B: Received a new message from topic: " + str(message.topic)) + print("Thread B: Payload is: " + str(message.payload)) + self._subscribeMessagePool.append(int(message.payload)) + + def _runtime(self): + # Subscribe to self._topic + try: + self._pyCoreClient.subscribe(self._topic, 1, self._messageCallback) + except subscribeTimeoutException: + print("Subscribe timeout!") + except subscribeError: + print("Subscribe error!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + time.sleep(2.2) + print("Thread B: Subscribed to " + self._topic) + print("Thread B: Now wait for Thread A.") + # Scanning rate is 100 TPS + while self._keepRunning: + time.sleep(0.01) + + def startThreadAndGo(self): + threadHandler = threading.Thread(target=self._runtime) + threadHandler.start() + return threadHandler + + def stopRunning(self): + self._keepRunning = False + + def getSubscribeMessagePool(self): + return self._subscribeMessagePool + + +# Generate answer for this integration test using queue configuration +def generateAnswer(data, queueingSize, srcMode): + dataInWork = sorted(data) + dataHead = dataInWork[:3] + dataTail = dataInWork[-3:] + dataRet = dataHead + dataInWork = dataInWork[3:] + dataInWork = dataInWork[:-3] + if srcMode == 0: # DROP_OLDEST + dataInWork = dataInWork[(-1 * queueingSize):] + dataRet.extend(dataInWork) + dataRet.extend(dataTail) + return sorted(dataRet) + elif srcMode == 1: # DROP_NEWEST + dataInWork = dataInWork[:queueingSize] + dataRet.extend(dataInWork) + dataRet.extend(dataTail) + return sorted(dataRet) + else: + 
print("Unsupported drop behavior!") + raise ValueError + + +# Create thread object, load in pyCoreClient and perform the set of integration tests +def performConfigurableOfflinePublishQueueTest(clientPub, clientSub): + print("Test DROP_NEWEST....") + clientPub[0].configure_offline_requests_queue(10, DropBehaviorTypes.DROP_NEWEST) # dropNewest + clientSub[0].configure_offline_requests_queue(10, DropBehaviorTypes.DROP_NEWEST) # dropNewest + # Create Topics + TopicOnly = "TopicOnly/" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) + # Create thread object + threadPubObject = threadPub(clientPub[0], 15, TopicOnly) # Configure to publish 15 messages during network outage + threadSubObject = threadSub(clientSub[0], TopicOnly) + threadSubHandler = threadSubObject.startThreadAndGo() + time.sleep(3) + threadPubHandler = threadPubObject.startThreadAndGo() + threadPubHandler.join() + threadSubObject.stopRunning() + threadSubHandler.join() + # Verify result + print("Verify DROP_NEWEST:") + answer = generateAnswer(threadPubObject.getPublishMessagePool(), 10, 1) + print("ANSWER:") + print(answer) + print("ACTUAL:") + print(threadSubObject.getSubscribeMessagePool()) + # We are doing QoS0 publish. We cannot guarantee when the drop will happen since we cannot guarantee a fixed time out + # of disconnect detection. However, once offline requests queue starts involving, it should queue up to its limit, + # thus the total number of received messages after draining should always match. 
+ if len(threadSubObject.getSubscribeMessagePool()) == len(answer): + print("Passed.") + else: + print("Verify DROP_NEWEST failed!!!") + return False + time.sleep(5) + print("Test DROP_OLDEST....") + clientPub[0].configure_offline_requests_queue(10, DropBehaviorTypes.DROP_OLDEST) # dropOldest + clientSub[0].configure_offline_requests_queue(10, DropBehaviorTypes.DROP_OLDEST) # dropOldest + # Create thread object + threadPubObject = threadPub(clientPub[0], 15, TopicOnly) # Configure to publish 15 messages during network outage + threadSubObject = threadSub(clientSub[0], TopicOnly) + threadSubHandler = threadSubObject.startThreadAndGo() + time.sleep(3) + threadPubHandler = threadPubObject.startThreadAndGo() + threadPubHandler.join() + threadSubObject.stopRunning() + threadSubHandler.join() + # Verify result + print("Verify DROP_OLDEST:") + answer = generateAnswer(threadPubObject.getPublishMessagePool(), 10, 0) + print(answer) + print("ACTUAL:") + print(threadSubObject.getSubscribeMessagePool()) + if len(threadSubObject.getSubscribeMessagePool()) == len(answer): + print("Passed.") + else: + print("Verify DROP_OLDEST failed!!!") + return False + return True + + +# Check inputs +myCheckInManager = checkInManager.checkInManager(2) +myCheckInManager.verify(sys.argv) + +host = myCheckInManager.host +rootCA = "./test-integration/Credentials/rootCA.crt" +certificate = "./test-integration/Credentials/certificate.pem.crt" +privateKey = "./test-integration/Credentials/privateKey.pem.key" +mode = myCheckInManager.mode + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." 
% (mode, sys.version_info[:3])) + +# Init Python core and connect +myMQTTClientManager = MQTTClientManager.MQTTClientManager() +clientPub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_PUB, host, rootCA, + certificate, privateKey, mode=mode) +clientSub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_SUB, host, rootCA, + certificate, privateKey, mode=mode) + +if clientPub is None or clientSub is None: + exit(4) + +print("Two clients are connected!") + +# Functionality test +if not performConfigurableOfflinePublishQueueTest([clientPub], [clientSub]): + print("The above Drop behavior broken!") + exit(4) diff --git a/test-integration/IntegrationTests/IntegrationTestDiscovery.py b/test-integration/IntegrationTests/IntegrationTestDiscovery.py new file mode 100644 index 0000000..2fac25b --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestDiscovery.py @@ -0,0 +1,216 @@ +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.greengrass.discovery.providers import DiscoveryInfoProvider +from TestToolLibrary.checkInManager import checkInManager +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsWebSocket + + +PORT = 8443 +CA = "./test-integration/Credentials/rootCA.crt" +CERT = "./test-integration/Credentials/certificate_drs.pem.crt" +KEY = "./test-integration/Credentials/privateKey_drs.pem.key" +TIME_OUT_SEC = 30 +# This is a pre-generated test data from DRS integration tests +# The test resources point to account # 003261610643 +ID_PREFIX = "Id-" +GGC_ARN = "arn:aws:iot:us-east-1:003261610643:thing/DRS_GGC_0kegiNGA_0" +GGC_PORT_NUMBER_BASE = 8080 +GGC_HOST_ADDRESS_PREFIX = "192.168.101." 
+METADATA_PREFIX = "Description-" +GROUP_ID = "627bf63d-ae64-4f58-a18c-80a44fcf4088" +THING_NAME = "DRS_GGAD_0kegiNGA_0" +EXPECTED_CA_CONTENT = "-----BEGIN CERTIFICATE-----\n" \ + "MIIEFTCCAv2gAwIBAgIVAPZfc4GMLZPmXbnoaZm6jRDqDs4+MA0GCSqGSIb3DQEB\n" \ + "CwUAMIGoMQswCQYDVQQGEwJVUzEYMBYGA1UECgwPQW1hem9uLmNvbSBJbmMuMRww\n" \ + "GgYDVQQLDBNBbWF6b24gV2ViIFNlcnZpY2VzMRMwEQYDVQQIDApXYXNoaW5ndG9u\n" \ + "MRAwDgYDVQQHDAdTZWF0dGxlMTowOAYDVQQDDDEwMDMyNjE2MTA2NDM6NjI3YmY2\n" \ + "M2QtYWU2NC00ZjU4LWExOGMtODBhNDRmY2Y0MDg4MCAXDTE3MDUyNTE4NDI1OVoY\n" \ + "DzIwOTcwNTI1MTg0MjU4WjCBqDELMAkGA1UEBhMCVVMxGDAWBgNVBAoMD0FtYXpv\n" \ + "bi5jb20gSW5jLjEcMBoGA1UECwwTQW1hem9uIFdlYiBTZXJ2aWNlczETMBEGA1UE\n" \ + "CAwKV2FzaGluZ3RvbjEQMA4GA1UEBwwHU2VhdHRsZTE6MDgGA1UEAwwxMDAzMjYx\n" \ + "NjEwNjQzOjYyN2JmNjNkLWFlNjQtNGY1OC1hMThjLTgwYTQ0ZmNmNDA4ODCCASIw\n" \ + "DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKEWtZtKyJUg2VUwZkbkVtltrfam\n" \ + "s9LMIdKNA3Wz4zSLhZjKHiTSkQmpZwKle5ziYs6Q5hfeT8WC0FNAVv1JhnwsuGfT\n" \ + "sG0UO5dSn7wqXOJigKC1CaSGqeFpKB0/a3wR1L6pCGVbLZ86/sPCEPHHJDieQ+Ps\n" \ + "RnOcUGb4CuIBnI2N+lafWNa4F4KRSVJCEeZ6u4iWVVdIEcDLKlakY45jtVvQqwnz\n" \ + "3leFsN7PTLEkVq5u1PXSbT5DWv6p+5NoDnGAT7j7Wbr2yJw7DtpBOL6oWkAdbFAQ\n" \ + "2097e8mIxNYE9xAzRlb5wEr6jZl/8K60v9P83OapMeuOg4JS8FGulHXbDg0CAwEA\n" \ + "AaMyMDAwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU21ELaPCH9Oh001OS0JMv\n" \ + "n8hU8dYwDQYJKoZIhvcNAQELBQADggEBABW66eH/+/v9Nq5jtJzflfrqAfBOzWLj\n" \ + "UTEv6szkYzV5Crr8vnu2P5OlyA0NdiKGiAm0AgoDkf+n9HU3Hc0zm3G/QaAO2UmN\n" \ + "9MwtIp29BSRf+gd1bX/WZTtl5I5xl290BDfr5o08I6TOf0A4P8IAkGwku5j0IQjM\n" \ + "ns2HH5UVki155dtmWDEGX6q35KABbsmv3tO1+geJVYnd1QkHzR5IXA12gxlMw9GJ\n" \ + "+cOw+rwJJ2ZcXo3HFoXBcsPqPOa1SO3vTl3XWQ+jX3vyDsxh/VGoJ4epsjwmJ+dW\n" \ + "sHJoqsa3ZPDW0LcEuYgdzYWRhumGwH9fJJUx0GS4Tdg4ud+6jpuyflU=\n" \ + "-----END CERTIFICATE-----\n" +# The expected response from DRS should be: +''' +{ + "GGGroups": [ + { + "GGGroupId": "627bf63d-ae64-4f58-a18c-80a44fcf4088", + "Cores": [ + { + "thingArn": 
"arn:aws:iot:us-east-1:003261610643:thing\/DRS_GGC_0kegiNGA_0", + "Connectivity": [ + { + "Id": "Id-0", + "HostAddress": "192.168.101.0", + "PortNumber": 8080, + "Metadata": "Description-0" + }, + { + "Id": "Id-1", + "HostAddress": "192.168.101.1", + "PortNumber": 8081, + "Metadata": "Description-1" + }, + { + "Id": "Id-2", + "HostAddress": "192.168.101.2", + "PortNumber": 8082, + "Metadata": "Description-2" + } + ] + } + ], + "CAs": [ + "-----BEGIN CERTIFICATE-----\n + MIIEFTCCAv2gAwIBAgIVAPZfc4GMLZPmXbnoaZm6jRDqDs4+MA0GCSqGSIb3DQEB\n + CwUAMIGoMQswCQYDVQQGEwJVUzEYMBYGA1UECgwPQW1hem9uLmNvbSBJbmMuMRww\n + GgYDVQQLDBNBbWF6b24gV2ViIFNlcnZpY2VzMRMwEQYDVQQIDApXYXNoaW5ndG9u\n + MRAwDgYDVQQHDAdTZWF0dGxlMTowOAYDVQQDDDEwMDMyNjE2MTA2NDM6NjI3YmY2\n + M2QtYWU2NC00ZjU4LWExOGMtODBhNDRmY2Y0MDg4MCAXDTE3MDUyNTE4NDI1OVoY\n + DzIwOTcwNTI1MTg0MjU4WjCBqDELMAkGA1UEBhMCVVMxGDAWBgNVBAoMD0FtYXpv\n + bi5jb20gSW5jLjEcMBoGA1UECwwTQW1hem9uIFdlYiBTZXJ2aWNlczETMBEGA1UE\n + CAwKV2FzaGluZ3RvbjEQMA4GA1UEBwwHU2VhdHRsZTE6MDgGA1UEAwwxMDAzMjYx\n + NjEwNjQzOjYyN2JmNjNkLWFlNjQtNGY1OC1hMThjLTgwYTQ0ZmNmNDA4ODCCASIw\n + DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKEWtZtKyJUg2VUwZkbkVtltrfam\n + s9LMIdKNA3Wz4zSLhZjKHiTSkQmpZwKle5ziYs6Q5hfeT8WC0FNAVv1JhnwsuGfT\n + sG0UO5dSn7wqXOJigKC1CaSGqeFpKB0\/a3wR1L6pCGVbLZ86\/sPCEPHHJDieQ+Ps\n + RnOcUGb4CuIBnI2N+lafWNa4F4KRSVJCEeZ6u4iWVVdIEcDLKlakY45jtVvQqwnz\n + 3leFsN7PTLEkVq5u1PXSbT5DWv6p+5NoDnGAT7j7Wbr2yJw7DtpBOL6oWkAdbFAQ\n + 2097e8mIxNYE9xAzRlb5wEr6jZl\/8K60v9P83OapMeuOg4JS8FGulHXbDg0CAwEA\n + AaMyMDAwDwYDVR0TAQH\/BAUwAwEB\/zAdBgNVHQ4EFgQU21ELaPCH9Oh001OS0JMv\n + n8hU8dYwDQYJKoZIhvcNAQELBQADggEBABW66eH\/+\/v9Nq5jtJzflfrqAfBOzWLj\n + UTEv6szkYzV5Crr8vnu2P5OlyA0NdiKGiAm0AgoDkf+n9HU3Hc0zm3G\/QaAO2UmN\n + 9MwtIp29BSRf+gd1bX\/WZTtl5I5xl290BDfr5o08I6TOf0A4P8IAkGwku5j0IQjM\n + ns2HH5UVki155dtmWDEGX6q35KABbsmv3tO1+geJVYnd1QkHzR5IXA12gxlMw9GJ\n + +cOw+rwJJ2ZcXo3HFoXBcsPqPOa1SO3vTl3XWQ+jX3vyDsxh\/VGoJ4epsjwmJ+dW\n + sHJoqsa3ZPDW0LcEuYgdzYWRhumGwH9fJJUx0GS4Tdg4ud+6jpuyflU=\n + 
-----END CERTIFICATE-----\n" + ] + } + ] +} +''' + +my_check_in_manager = checkInManager(2) +my_check_in_manager.verify(sys.argv) +mode = my_check_in_manager.mode +host = my_check_in_manager.host + +def create_discovery_info_provider(): + discovery_info_provider = DiscoveryInfoProvider() + discovery_info_provider.configureEndpoint(host, PORT) + discovery_info_provider.configureCredentials(CA, CERT, KEY) + discovery_info_provider.configureTimeout(TIME_OUT_SEC) + return discovery_info_provider + + +def perform_integ_test_discovery(): + discovery_info_provider = create_discovery_info_provider() + return discovery_info_provider.discover(THING_NAME) + + +def _verify_connectivity_info(actual_connectivity_info): + info_id = actual_connectivity_info.id + sequence_number_string = info_id[-1:] + assert actual_connectivity_info.host == GGC_HOST_ADDRESS_PREFIX + sequence_number_string + assert actual_connectivity_info.port == GGC_PORT_NUMBER_BASE + int(sequence_number_string) + assert actual_connectivity_info.metadata == METADATA_PREFIX + sequence_number_string + + +def _verify_connectivity_info_list(actual_connectivity_info_list): + for actual_connectivity_info in actual_connectivity_info_list: + _verify_connectivity_info(actual_connectivity_info) + + +def _verify_ggc_info(actual_ggc_info): + assert actual_ggc_info.coreThingArn == GGC_ARN + assert actual_ggc_info.groupId == GROUP_ID + _verify_connectivity_info_list(actual_ggc_info.connectivityInfoList) + + +def _verify_ca_list(ca_list): + assert len(ca_list) == 1 + try: + group_id, ca = ca_list[0] + assert group_id == GROUP_ID + assert ca == EXPECTED_CA_CONTENT + except: + assert ca_list[0] == EXPECTED_CA_CONTENT + + +def verify_all_cores(discovery_info): + print("Verifying \"getAllCores\"...") + ggc_info_list = discovery_info.getAllCores() + assert len(ggc_info_list) == 1 + _verify_ggc_info(ggc_info_list[0]) + print("Pass!") + + +def verify_all_cas(discovery_info): + print("Verifying \"getAllCas\"...") + ca_list = 
discovery_info.getAllCas() + _verify_ca_list(ca_list) + print("Pass!") + + +def verify_all_groups(discovery_info): + print("Verifying \"getAllGroups\"...") + group_list = discovery_info.getAllGroups() + assert len(group_list) == 1 + group_info = group_list[0] + _verify_ca_list(group_info.caList) + _verify_ggc_info(group_info.coreConnectivityInfoList[0]) + print("Pass!") + + +def verify_group_object(discovery_info): + print("Verifying \"toObjectAtGroupLevel\"...") + group_info_object = discovery_info.toObjectAtGroupLevel() + _verify_connectivity_info(group_info_object + .get(GROUP_ID) + .getCoreConnectivityInfo(GGC_ARN) + .getConnectivityInfo(ID_PREFIX + "0")) + _verify_connectivity_info(group_info_object + .get(GROUP_ID) + .getCoreConnectivityInfo(GGC_ARN) + .getConnectivityInfo(ID_PREFIX + "1")) + _verify_connectivity_info(group_info_object + .get(GROUP_ID) + .getCoreConnectivityInfo(GGC_ARN) + .getConnectivityInfo(ID_PREFIX + "2")) + print("Pass!") + + +############################################################################ +# Main # + +skip_when_match(ModeIsWebSocket(mode), "This test is not applicable for mode: %s. Skipping..." % mode) + +# GG Discovery only applies mutual auth with cert +try: + discovery_info = perform_integ_test_discovery() + + verify_all_cores(discovery_info) + verify_all_cas(discovery_info) + verify_all_groups(discovery_info) + verify_group_object(discovery_info) +except BaseException as e: + print("Failed! " + e.message) + exit(4) diff --git a/test-integration/IntegrationTests/IntegrationTestJobsClient.py b/test-integration/IntegrationTests/IntegrationTestJobsClient.py new file mode 100644 index 0000000..3653725 --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestJobsClient.py @@ -0,0 +1,185 @@ +# This integration test verifies the jobs client functionality in the +# Python SDK. +# It performs a number of basic operations without expecting an actual job or +# job execution to be present. 
The callbacks associated with these actions +# are written to accept and pass server responses given when no jobs or job +# executions exist. +# Finally, the tester pumps through all jobs queued for the given thing +# doing a basic echo of the job document and updating the job execution +# to SUCCEEDED + +import random +import string +import time +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from TestToolLibrary import simpleThreadManager +import TestToolLibrary.checkInManager as checkInManager +import TestToolLibrary.MQTTClientManager as MQTTClientManager +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTThingJobsClient +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicReplyType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionStatus + +import threading +import datetime +import argparse +import json + +IOT_JOBS_MQTT_RESPONSE_WAIT_SECONDS = 5 +CLIENT_ID = "integrationTestMQTT_Client" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) + +class JobsMessageProcessor(object): + def __init__(self, awsIoTMQTTThingJobsClient, clientToken): + #keep track of this to correlate request/responses + self.clientToken = clientToken + self.awsIoTMQTTThingJobsClient = awsIoTMQTTThingJobsClient + + def 
_setupCallbacks(self): + print('Creating test subscriptions...') + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.getPendingJobAcceptedCallback, jobExecutionTopicType.JOB_GET_PENDING_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.getPendingJobRejectedCallback, jobExecutionTopicType.JOB_GET_PENDING_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.describeJobExecAcceptedCallback, jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, '+') + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.describeJobExecRejectedCallback, jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, '+') + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.startNextPendingJobAcceptedCallback, jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.startNextPendingJobRejectedCallback, jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.updateJobAcceptedCallback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, '+') + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.updateJobRejectedCallback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, '+') + + def getPendingJobAcceptedCallback(self, client, userdata, message): + self.testResult = (True, 'GetPending accepted callback invoked!') + self.waitEvent.set() + + def getPendingJobRejectedCallback(self, client, userdata, message): + self.testResult = (False, 'GetPending rejection callback 
invoked!') + self.waitEvent.set() + + def describeJobExecAcceptedCallback(self, client, userdata, message): + self.testResult = (True, 'DescribeJobExecution accepted callback invoked!') + self.waitEvent.set() + + def describeJobExecRejectedCallback(self, client, userdata, message): + self.testResult = (False, 'DescribeJobExecution rejected callback invoked!') + self.waitEvent.set() + + def startNextPendingJobAcceptedCallback(self, client, userdata, message): + self.testResult = (True, 'StartNextPendingJob accepted callback invoked!') + payload = json.loads(message.payload.decode('utf-8')) + if 'execution' not in payload: + self.done = True + else: + print('Found job! Document: ' + payload['execution']['jobDocument']) + threading.Thread(target=self.awsIoTMQTTThingJobsClient.sendJobsUpdate(payload['execution']['jobId'], jobExecutionStatus.JOB_EXECUTION_SUCCEEDED)).start() + self.waitEvent.set() + + def startNextPendingJobRejectedCallback(self, client, userdata, message): + self.testResult = (False, 'StartNextPendingJob rejected callback invoked!') + self.waitEvent.set() + + def updateJobAcceptedCallback(self, client, userdata, message): + self.testResult = (True, 'UpdateJob accepted callback invoked!') + self.waitEvent.set() + + def updateJobRejectedCallback(self, client, userdata, message): + #rejection is still a successful test because job IDs may or may not exist, and could exist in unknown state + self.testResult = (True, 'UpdateJob rejected callback invoked!') + self.waitEvent.set() + + def executeJob(self, execution): + print('Executing job ID, version, number: {}, {}, {}'.format(execution['jobId'], execution['versionNumber'], execution['executionNumber'])) + print('With jobDocument: ' + json.dumps(execution['jobDocument'])) + + def runTests(self): + print('Running jobs tests...') + ##create subscriptions + self._setupCallbacks() + + #make publish calls + self._init_test_wait() + 
self._test_send_response_confirm(self.awsIoTMQTTThingJobsClient.sendJobsDescribe('$next')) + + self._init_test_wait() + self._test_send_response_confirm(self.awsIoTMQTTThingJobsClient.sendJobsUpdate('junkJobId', jobExecutionStatus.JOB_EXECUTION_SUCCEEDED)) + + self._init_test_wait() + self._test_send_response_confirm(self.awsIoTMQTTThingJobsClient.sendJobsQuery(jobExecutionTopicType.JOB_GET_PENDING_TOPIC)) + + self._init_test_wait() + self._test_send_response_confirm(self.awsIoTMQTTThingJobsClient.sendJobsStartNext()) + + self.processAllJobs() + + def processAllJobs(self): + #process all enqueued jobs + print('Processing all jobs found in queue for thing...') + self.done = False + while not self.done: + self._attemptStartNextJob() + time.sleep(5) + + def _attemptStartNextJob(self): + statusDetails = {'StartedBy': 'ClientToken: {} on {}'.format(self.clientToken, datetime.datetime.now().isoformat())} + threading.Thread(target=self.awsIoTMQTTThingJobsClient.sendJobsStartNext, kwargs = {'statusDetails': statusDetails}).start() + + def _init_test_wait(self): + self.testResult = (False, 'Callback not invoked') + self.waitEvent = threading.Event() + + def _test_send_response_confirm(self, sendResult): + if not sendResult: + print('Failed to send jobs message') + exit(4) + else: + #wait 25 seconds for expected callback response to happen + if not self.waitEvent.wait(IOT_JOBS_MQTT_RESPONSE_WAIT_SECONDS): + print('Did not receive expected callback within %d second timeout' % IOT_JOBS_MQTT_RESPONSE_WAIT_SECONDS) + exit(4) + elif not self.testResult[0]: + print('Callback result has failed the test with message: ' + self.testResult[1]) + exit(4) + else: + print('Recieved expected result: ' + self.testResult[1]) + + +############################################################################ +# Main # +# Check inputs +myCheckInManager = checkInManager.checkInManager(2) +myCheckInManager.verify(sys.argv) + +host = myCheckInManager.host +rootCA = 
"./test-integration/Credentials/rootCA.crt" +certificate = "./test-integration/Credentials/certificate.pem.crt" +privateKey = "./test-integration/Credentials/privateKey.pem.key" +mode = myCheckInManager.mode + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." % (mode, sys.version_info[:3])) + +# Init Python core and connect +myMQTTClientManager = MQTTClientManager.MQTTClientManager() +client = myMQTTClientManager.create_connected_mqtt_client(mode, CLIENT_ID, host, (rootCA, certificate, privateKey)) + +clientId = 'AWSPythonkSDKTestThingClient' +thingName = 'AWSPythonkSDKTestThing' +jobsClient = AWSIoTMQTTThingJobsClient(clientId, thingName, QoS=1, awsIoTMQTTClient=client) + +print('Connecting to MQTT server and setting up callbacks...') +jobsMsgProc = JobsMessageProcessor(jobsClient, clientId) +print('Starting jobs tests...') +jobsMsgProc.runTests() +print('Done running jobs tests') + +#can call this on the jobsClient, or myAWSIoTMQTTClient directly +jobsClient.disconnect() diff --git a/test-integration/IntegrationTests/IntegrationTestMQTTConnection.py b/test-integration/IntegrationTests/IntegrationTestMQTTConnection.py new file mode 100644 index 0000000..9adc38c --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestMQTTConnection.py @@ -0,0 +1,177 @@ +# This integration test verifies the functionality in the Python core of IoT Yun/Python SDK +# for basic MQTT connection. +# It starts two threads using two different connections to AWS IoT: +# Thread A: publish to "deviceSDK/PyIntegrationTest/Topic", X messages, QoS1, TPS=50 +# Thread B: subscribe to "deviceSDK/PyIntegrationTest/Topic", QoS1 +# Thread B will be started first with extra delay to ensure a valid subscription +# Then thread A will be started. 
+# Verify send/receive messages are equivalent + + +import random +import string +import time +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from TestToolLibrary import simpleThreadManager +import TestToolLibrary.checkInManager as checkInManager +import TestToolLibrary.MQTTClientManager as MQTTClientManager +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + + +API_TYPE_SYNC = "sync" +API_TYPE_ASYNC = "async" + +CLIENT_ID_PUB = "integrationTestMQTT_ClientPub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) +CLIENT_ID_SUB = "integrationTestMQTT_ClientSub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) + + + +# Callback unit for subscribe +class callbackUnit: + def __init__(self, srcSet, apiType): + self._internalSet = srcSet + self._apiType = apiType + + # Callback for clientSub + def messageCallback(self, client, userdata, message): + print(self._apiType + ": Received a new message: " + str(message.payload)) + self._internalSet.add(message.payload.decode('utf-8')) + + def getInternalSet(self): + return self._internalSet + + +# Run function unit +class runFunctionUnit: + def __init__(self, apiType): + self._messagesPublished = set() + self._apiType = apiType + + # Run function for publish thread (one time) + def threadPublish(self, pyCoreClient, numberOfTotalMessages, topic, TPS): + # One time thread + time.sleep(3) # Extra waiting time for valid 
subscription + messagesLeftToBePublished = numberOfTotalMessages + while messagesLeftToBePublished != 0: + try: + currentMessage = str(messagesLeftToBePublished) + self._performPublish(pyCoreClient, topic, 1, currentMessage) + self._messagesPublished.add(currentMessage) + except publishError: + print("Publish Error for message: " + currentMessage) + except Exception as e: + print("Unknown exception: " + str(type(e)) + " " + str(e.message)) + messagesLeftToBePublished -= 1 + time.sleep(1 / float(TPS)) + print("End of publish thread.") + + def _performPublish(self, pyCoreClient, topic, qos, payload): + if self._apiType == API_TYPE_SYNC: + pyCoreClient.publish(topic, payload, qos, False) + if self._apiType == API_TYPE_ASYNC: + pyCoreClient.publish_async(topic, payload, qos, False, None) # TODO: See if we can also check PUBACKs + + +############################################################################ +# Main # +# Check inputs +myCheckInManager = checkInManager.checkInManager(3) +myCheckInManager.verify(sys.argv) + +host = myCheckInManager.host +rootCA = "./test-integration/Credentials/rootCA.crt" +certificate = "./test-integration/Credentials/certificate.pem.crt" +privateKey = "./test-integration/Credentials/privateKey.pem.key" +mode = myCheckInManager.mode + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." 
% (mode, sys.version_info[:3])) + +# Init Python core and connect +myMQTTClientManager = MQTTClientManager.MQTTClientManager() +clientPub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_PUB, host, rootCA, + certificate, privateKey, mode=mode) +clientSub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_SUB, host, rootCA, + certificate, privateKey, mode=mode) + +if clientPub is None or clientSub is None: + exit(4) + +print("Two clients are connected!") + +# Configurations +################ +# Data/Data pool +TPS = 20 +numberOfTotalMessagesAsync = myCheckInManager.customParameter +numberOfTotalMessagesSync = numberOfTotalMessagesAsync / 10 +subSetAsync = set() +subSetSync = set() +subCallbackUnitAsync = callbackUnit(subSetAsync, API_TYPE_ASYNC) +subCallbackUnitSync = callbackUnit(subSetSync, API_TYPE_SYNC) +syncTopic = "YunSDK/PyIntegrationTest/Topic/sync" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) +print(syncTopic) +asyncTopic = "YunSDK/PyIntegrationTest/Topic/async" + "".join(random.choice(string.ascii_lowercase) for j in range(4)) +# clientSub +try: + clientSub.subscribe(asyncTopic, 1, subCallbackUnitAsync.messageCallback) + clientSub.subscribe(syncTopic, 1, subCallbackUnitSync.messageCallback) + time.sleep(3) +except subscribeTimeoutException: + print("Subscribe timeout!") +except subscribeError: + print("Subscribe error!") +except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) +# Threads +mySimpleThreadManager = simpleThreadManager.simpleThreadManager() +myRunFunctionUnitSyncPub = runFunctionUnit(API_TYPE_SYNC) +myRunFunctionUnitAsyncPub = runFunctionUnit(API_TYPE_ASYNC) +publishSyncThreadID = mySimpleThreadManager.createOneTimeThread(myRunFunctionUnitSyncPub.threadPublish, + [clientPub, numberOfTotalMessagesSync, syncTopic, TPS]) +publishAsyncThreadID = mySimpleThreadManager.createOneTimeThread(myRunFunctionUnitAsyncPub.threadPublish, + [clientPub, 
numberOfTotalMessagesAsync, asyncTopic, TPS]) + +# Performing +############ +mySimpleThreadManager.startThreadWithID(publishSyncThreadID) +mySimpleThreadManager.startThreadWithID(publishAsyncThreadID) +mySimpleThreadManager.joinOneTimeThreadWithID(publishSyncThreadID) +mySimpleThreadManager.joinOneTimeThreadWithID(publishAsyncThreadID) +time.sleep(numberOfTotalMessagesAsync / float(TPS) * 0.5) + +# Verifying +########### +# Length +print("Check if the length of the two sets are equal...") +print("Received from subscription (sync pub): " + str(len(subCallbackUnitSync.getInternalSet()))) +print("Received from subscription (async pub): " + str(len(subCallbackUnitAsync.getInternalSet()))) +print("Sent through sync publishes: " + str(len(myRunFunctionUnitSyncPub._messagesPublished))) +print("Sent through async publishes: " + str(len(myRunFunctionUnitAsyncPub._messagesPublished))) +if len(myRunFunctionUnitSyncPub._messagesPublished) != len(subCallbackUnitSync.getInternalSet()): + print("[Sync pub] Number of messages not equal!") + exit(4) +if len(myRunFunctionUnitAsyncPub._messagesPublished) != len(subCallbackUnitAsync.getInternalSet()): + print("[Asyn pub] Number of messages not equal!") + exit(4) +# Content +print("Check if the content if the two sets are equivalent...") +if myRunFunctionUnitSyncPub._messagesPublished != subCallbackUnitSync.getInternalSet(): + print("[Sync pub] Set content not equal!") + exit(4) +elif myRunFunctionUnitAsyncPub._messagesPublished != subCallbackUnitAsync.getInternalSet(): + print("[Async pub] Set content not equal!") +else: + print("Yes!") diff --git a/test-integration/IntegrationTests/IntegrationTestOfflineQueueingForSubscribeUnsubscribe.py b/test-integration/IntegrationTests/IntegrationTestOfflineQueueingForSubscribeUnsubscribe.py new file mode 100644 index 0000000..37c1862 --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestOfflineQueueingForSubscribeUnsubscribe.py @@ -0,0 +1,210 @@ +# This integration test verifies 
the functionality of queueing up subscribe/unsubscribe requests submitted by the
+# client when it is offline, and drain them out when the client is reconnected. The test contains 2 clients running in
+# 2 different threads:
+#
+# In thread A, client_sub_unsub follows the below workflow:
+# 1. Client connects to AWS IoT.
+# 2. Client subscribes to "topic_A".
+# 3. Experience a simulated network error which brings the client offline.
+# 4. While offline, client subscribes to "topic_B" and unsubscribes from "topic_A".
+# 5. Client reconnects, comes back online and drains out all offline queued requests.
+# 6. Client stays and receives messages published in another thread.
+#
+# In thread B, client_pub follows the below workflow:
+# 1. Client in thread B connects to AWS IoT.
+# 2. After client in thread A connects and subscribes to "topic_A", client in thread B publishes messages to "topic_A".
+# 3. Client in thread B keeps sleeping until client in thread A goes back online and reaches a stable state (draining done).
+# 4. Client in thread B then publishes messages to "topic_A" and "topic_B".
+#
+# Since client in thread A does an unsubscribe from "topic_A", it should never receive messages published to "topic_A" after
+# it reconnects and gets stable. It should have the messages from "topic_A" published at the very beginning.
+# Since client in thread A does a subscribe to "topic_B", it should receive messages published to "topic_B" after it
+# reconnects and gets stable. 
+ + +import random +import string +import time +from threading import Event +from threading import Thread +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus +from TestToolLibrary.checkInManager import checkInManager +from TestToolLibrary.MQTTClientManager import MQTTClientManager +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + + +def get_random_string(length): + return "".join(random.choice(string.ascii_lowercase) for i in range(length)) + +TOPIC_A = "topic/test/offline_sub_unsub/a" + get_random_string(4) +TOPIC_B = "topic/test/offline_sub_unsub/b" + get_random_string(4) +MESSAGE_PREFIX = "MagicMessage-" +NUMBER_OF_PUBLISHES = 3 +ROOT_CA = "./test-integration/Credentials/rootCA.crt" +CERT = "./test-integration/Credentials/certificate.pem.crt" +KEY = "./test-integration/Credentials/privateKey.pem.key" +CLIENT_PUB_ID = "PySdkIntegTest_OfflineSubUnsub_pub" + get_random_string(4) +CLIENT_SUB_UNSUB_ID = "PySdkIntegTest_OfflineSubUnsub_subunsub" + get_random_string(4) +KEEP_ALIVE_SEC = 1 +EVENT_WAIT_TIME_OUT_SEC = 5 + + +class DualClientRunner(object): + + def __init__(self, mode): + self._publish_end_flag = Event() + self._stable_flag = Event() + self._received_messages_topic_a = list() + self._received_messages_topic_b = list() + self.__mode = mode + self._client_pub = self._create_connected_client(CLIENT_PUB_ID) + print("Created connected client pub.") + self._client_sub_unsub = self._create_connected_client(CLIENT_SUB_UNSUB_ID) + print("Created connected client sub/unsub.") + self._client_sub_unsub.subscribe(TOPIC_A, 1, self._collect_sub_messages) + print("Client sub/unsub subscribed 
to topic: %s" % TOPIC_A) + time.sleep(2) # Make sure the subscription is valid + + def _create_connected_client(self, id_prefix): + return MQTTClientManager().create_connected_mqtt_client(self.__mode, id_prefix, host, (ROOT_CA, CERT, KEY)) + + def start(self): + thread_client_sub_unsub = Thread(target=self._thread_client_sub_unsub_runtime) + thread_client_pub = Thread(target=self._thread_client_pub_runtime) + thread_client_sub_unsub.start() + thread_client_pub.start() + thread_client_pub.join() + thread_client_sub_unsub.join() + + def _thread_client_sub_unsub_runtime(self): + print("Start client sub/unsub runtime thread...") + print("Client sub/unsub waits on the 1st round of publishes to end...") + if not self._publish_end_flag.wait(EVENT_WAIT_TIME_OUT_SEC): + raise RuntimeError("Timed out in waiting for the publishes to topic: %s" % TOPIC_A) + print("Client sub/unsub gets notified.") + self._publish_end_flag.clear() + + print("Client sub/unsub now goes offline...") + self._go_offline_and_send_requests() + + # Wait until the connection is stable and then notify + print("Client sub/unsub waits on a stable connection...") + self._wait_until_stable_connection() + + print("Client sub/unsub waits on the 2nd round of publishes to end...") + if not self._publish_end_flag.wait(EVENT_WAIT_TIME_OUT_SEC): + raise RuntimeError("Timed out in waiting for the publishes to topic: %s" % TOPIC_B) + print("Client sub/unsub gets notified.") + self._publish_end_flag.clear() + + print("Client sub/unsub runtime thread ends.") + + def _wait_until_stable_connection(self): + reconnect_timing = 0 + while self._client_sub_unsub._mqtt_core._client_status.get_status() != ClientStatus.STABLE: + time.sleep(0.01) + reconnect_timing += 1 + if reconnect_timing % 100 == 0: + print("Client sub/unsub: Counting reconnect time: " + str(reconnect_timing / 100) + " seconds.") + print("Client sub/unsub: Counting reconnect time result: " + str(float(reconnect_timing) / 100) + " seconds.") + 
self._stable_flag.set() + + def _collect_sub_messages(self, client, userdata, message): + if message.topic == TOPIC_A: + print("Client sub/unsub: Got a message from %s" % TOPIC_A) + self._received_messages_topic_a.append(message.payload) + if message.topic == TOPIC_B: + print("Client sub/unsub: Got a message from %s" % TOPIC_B) + self._received_messages_topic_b.append(message.payload) + + def _go_offline_and_send_requests(self): + do_once = True + loop_count = EVENT_WAIT_TIME_OUT_SEC * 100 + while loop_count != 0: + self._manual_network_error() + if loop_count % 100 == 0: + print("Client sub/unsub: Offline time down count: %d sec" % (loop_count / 100)) + if do_once and (loop_count / 100) <= (EVENT_WAIT_TIME_OUT_SEC / 2): + print("Client sub/unsub: Performing offline sub/unsub...") + self._client_sub_unsub.subscribe(TOPIC_B, 1, self._collect_sub_messages) + self._client_sub_unsub.unsubscribe(TOPIC_A) + print("Client sub/unsub: Done with offline sub/unsub.") + do_once = False + loop_count -= 1 + time.sleep(0.01) + + def _manual_network_error(self): + # Ensure we close the socket + if self._client_sub_unsub._mqtt_core._internal_async_client._paho_client._sock: + self._client_sub_unsub._mqtt_core._internal_async_client._paho_client._sock.close() + self._client_sub_unsub._mqtt_core._internal_async_client._paho_client._sock = None + if self._client_sub_unsub._mqtt_core._internal_async_client._paho_client._ssl: + self._client_sub_unsub._mqtt_core._internal_async_client._paho_client._ssl.close() + self._client_sub_unsub._mqtt_core._internal_async_client._paho_client._ssl = None + # Fake that we have detected the disconnection + self._client_sub_unsub._mqtt_core._internal_async_client._paho_client.on_disconnect(None, None, 0) + + def _thread_client_pub_runtime(self): + print("Start client pub runtime thread...") + print("Client pub: 1st round of publishes") + for i in range(NUMBER_OF_PUBLISHES): + self._client_pub.publish(TOPIC_A, MESSAGE_PREFIX + str(i), 1) + print("Client 
pub: Published a message") + time.sleep(0.5) + time.sleep(1) + print("Client pub: Publishes done. Notifying...") + self._publish_end_flag.set() + + print("Client pub waits on client sub/unsub to be stable...") + time.sleep(1) + if not self._stable_flag.wait(EVENT_WAIT_TIME_OUT_SEC * 3): # We wait longer for the reconnect/stabilization + raise RuntimeError("Timed out in waiting for client_sub_unsub to be stable") + self._stable_flag.clear() + + print("Client pub: 2nd round of publishes") + for j in range(NUMBER_OF_PUBLISHES): + self._client_pub.publish(TOPIC_B, MESSAGE_PREFIX + str(j), 1) + print("Client pub: Published a message to %s" % TOPIC_B) + self._client_pub.publish(TOPIC_A, MESSAGE_PREFIX + str(j) + "-dup", 1) + print("Client pub: Published a message to %s" % TOPIC_A) + time.sleep(0.5) + time.sleep(1) + print("Client pub: Publishes done. Notifying...") + self._publish_end_flag.set() + + print("Client pub runtime thread ends.") + + def verify(self): + print("Verifying...") + assert len(self._received_messages_topic_a) == NUMBER_OF_PUBLISHES # We should only receive the first round + assert len(self._received_messages_topic_b) == NUMBER_OF_PUBLISHES # We should only receive the second round + print("Pass!") + + +############################################################################ +# Main # +# Check inputs +my_check_in_manager = checkInManager(2) +my_check_in_manager.verify(sys.argv) +mode = my_check_in_manager.mode +host = my_check_in_manager.host + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." 
% (mode, sys.version_info[:3])) + +# Performing +############ +dual_client_runner = DualClientRunner(mode) +dual_client_runner.start() + +# Verifying +########### +dual_client_runner.verify() diff --git a/test-integration/IntegrationTests/IntegrationTestProgressiveBackoff.py b/test-integration/IntegrationTests/IntegrationTestProgressiveBackoff.py new file mode 100644 index 0000000..fc937ef --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestProgressiveBackoff.py @@ -0,0 +1,289 @@ +# This integration test verifies the functionality in the Python core of Yun/Python SDK +# for progressive backoff logic in auto-reconnect. +# It starts two threads using two different connections to AWS IoT: +# Thread B subscribes to "coolTopic" and waits for incoming messages. Network +# failure will happen occasionally in thread B with a variant interval (connected +# period), simulating stable/unstable connection so as to test the reset logic of +# backoff timing. Once thread B is back online, an internal flag will be set to +# notify the other thread, Thread A, to start publishing to the same topic. +# Thread A will publish a set of messages (a fixed number of messages) to "coolTopic" +# using QoS1 and does nothing in the rest of the time. It will only start publishing +# when it gets ready notification from thread B. No network failure will happen in +# thread A. +# Because thread A is always online and only publishes when thread B is back online, +# all messages published to "coolTopic" should be received by thread B. In meantime, +# thread B should have an increasing amount of backoff waiting period until the +# connected period reaches the length of time for a stable connection. After that, +# the backoff waiting period should be reset. +# The following things will be verified to pass the test: +# 1. All messages are received. +# 2. Backoff waiting period increases as configured before the thread reaches to a +# stable connection. +# 3. 
# Class that implements all the related threads in the test in a controllable manner
class threadPool:
    def __init__(self, srcTotalNumberOfNetworkFailure, clientPub, clientSub):
        # Thread B readiness flag: 0 = not connected, 1 = connected + subscribed,
        # -1 = all failure rounds finished, thread A should exit.
        self._threadBReadyFlag = 0
        self._threadBReadyFlagMutex = threading.Lock()
        self._targetedTopic = "coolTopic" + "".join(random.choice(string.ascii_lowercase) for i in range(4))
        self._publishMessagePool = set()
        self._receiveMessagePool = set()
        self._roundOfNetworkFailure = 1
        self._totalNumberOfNetworkFailure = srcTotalNumberOfNetworkFailure
        self._clientPub = clientPub
        self._clientSub = clientSub
        self._pubCount = 0
        self._reconnectTimeRecord = list()
        self._connectedTimeRecord = list()

    # Message callback for collecting incoming messages from the subscribed topic
    def _messageCallback(self, client, userdata, message):
        print("Thread B: Received new message: " + str(message.payload))
        self._receiveMessagePool.add(message.payload.decode('utf-8'))

    # The one that publishes
    def threadARuntime(self):
        exitNow = False
        while not exitNow:
            self._threadBReadyFlagMutex.acquire()
            # NOTE: the flag tracks Thread B's state (the comments previously
            # said "Thread A", which was misleading).
            # Thread B is still reconnecting, WAIT!
            if self._threadBReadyFlag == 0:
                pass
            # Thread B is connected and subscribed, PUBLISH!
            elif self._threadBReadyFlag == 1:
                self._publish3Messages()
                self._threadBReadyFlag = 0  # Reset the readyFlag
            # Thread B has finished all rounds of network failure/reconnect, EXIT!
            else:
                exitNow = True
            self._threadBReadyFlagMutex.release()
            time.sleep(0.01)  # 0.01 sec scanning

    # Publish a set of messages: 3
    def _publish3Messages(self):
        loopCount = 3
        while loopCount != 0:
            try:
                currentMessage = "Message" + str(self._pubCount)
                print("Test publish to topic : " + self._targetedTopic)
                self._clientPub.publish(self._targetedTopic, currentMessage, 1, False)
                print("Thread A: Published new message: " + str(currentMessage))
                self._publishMessagePool.add(currentMessage)
                self._pubCount += 1
                loopCount -= 1
            except publishError:
                print("Publish error!")
            except Exception as e:
                print("Unknown exception!")
                print("Type: " + str(type(e)))
                # BUGFIX: exceptions have no ".message" attribute in Python 3;
                # printing str(e) works on both Python 2 and 3.
                print("Message: " + str(e))
            time.sleep(0.5)

    # The one that subscribes and has network failures
    def threadBRuntime(self):
        # Subscribe to the topic
        try:
            print("Test subscribe to topic : " + self._targetedTopic)
            self._clientSub.subscribe(self._targetedTopic, 1, self._messageCallback)
        except subscribeTimeoutException:
            print("Subscribe timeout!")
        except subscribeError:
            print("Subscribe error!")
        except Exception as e:
            print("Unknown exception!")
            print("Type: " + str(type(e)))
            # BUGFIX: was "e.message", which raises AttributeError on Python 3
            print("Message: " + str(e))
        print("Thread B: Subscribe request sent. Staring waiting for subscription processing...")
        time.sleep(3)
        print("Thread B: Done waiting.")
        self._threadBReadyFlagMutex.acquire()
        self._threadBReadyFlag = 1
        self._threadBReadyFlagMutex.release()
        # Start looping with network failure
        connectedPeriodSecond = 3
        while self._roundOfNetworkFailure <= self._totalNumberOfNetworkFailure:
            self._connectedTimeRecord.append(connectedPeriodSecond)
            # Wait for connectedPeriodSecond
            print("Thread B: Connected time: " + str(connectedPeriodSecond) + " seconds.")
            print("Thread B: Stable time: 60 seconds.")
            time.sleep(connectedPeriodSecond)
            print("Thread B: Network failure. Round: " + str(self._roundOfNetworkFailure) + ". 0.5 seconds.")
            print("Thread B: Backoff time for this round should be: " + str(
                self._clientSub._internal_async_client._paho_client._backoffCore._currentBackoffTimeSecond) + " second(s).")
            # Set the readyFlag
            self._threadBReadyFlagMutex.acquire()
            self._threadBReadyFlag = 0
            self._threadBReadyFlagMutex.release()
            # Now lose connection for 0.5 seconds, preventing multiple reconnect attempts
            loseConnectionLoopCount = 50
            while loseConnectionLoopCount != 0:
                self._manualNetworkError()
                loseConnectionLoopCount -= 1
                time.sleep(0.01)
            # Wait until the connection/subscription is recovered
            reconnectTiming = 0
            while self._clientSub._client_status.get_status() != ClientStatus.STABLE:
                time.sleep(0.01)
                reconnectTiming += 1
                if reconnectTiming % 100 == 0:
                    print("Thread B: Counting reconnect time: " + str(reconnectTiming / 100) + " seconds.")
            print("Thread B: Counting reconnect time result: " + str(float(reconnectTiming) / 100) + " seconds.")
            self._reconnectTimeRecord.append(reconnectTiming / 100)

            time.sleep(3)  # For valid subscription

            # Update thread B status
            self._threadBReadyFlagMutex.acquire()
            self._threadBReadyFlag = 1
            self._threadBReadyFlagMutex.release()

            # Update connectedPeriodSecond
            connectedPeriodSecond += (2 ** (self._roundOfNetworkFailure - 1))
            # Update roundOfNetworkFailure
            self._roundOfNetworkFailure += 1

        # Notify thread A shouldExit
        self._threadBReadyFlagMutex.acquire()
        self._threadBReadyFlag = -1
        self._threadBReadyFlagMutex.release()

    # Simulate a network error
    def _manualNetworkError(self):
        # Only the subscriber needs the network error
        paho_client = self._clientSub._internal_async_client._paho_client
        if paho_client._sock:
            paho_client._sock.close()
            paho_client._sock = None
        if paho_client._ssl:
            paho_client._ssl.close()
            paho_client._ssl = None
        # Fake that we have detected the disconnection
        paho_client.on_disconnect(None, None, 0)

    def getReconnectTimeRecord(self):
        return self._reconnectTimeRecord

    def getConnectedTimeRecord(self):
        return self._connectedTimeRecord


# Generate the correct backoff timing to compare the test result with.
# The backoff resets to baseTime after a connected period at least as long as
# stableTime, doubles otherwise, and is capped at maximumTime.
def generateCorrectAnswer(baseTime, maximumTime, stableTime, connectedTimeRecord):
    answer = list()
    currentTime = baseTime
    # CLEANUP: removed the unused "nextTime" local from the original.
    for i in range(0, len(connectedTimeRecord)):
        if connectedTimeRecord[i] >= stableTime or i == 0:
            currentTime = baseTime
        else:
            currentTime = min(currentTime * 2, maximumTime)
        answer.append(currentTime)
    return answer


# Verify backoff time
# Corresponding elements must not differ by more than 1.5 seconds
def verifyBackoffTime(answerList, resultList):
    result = True
    for i in range(0, len(answerList)):
        if abs(answerList[i] - resultList[i]) > 1.5:
            result = False
            break
    return result
############################################################################
# Main #
# Check inputs
myCheckInManager = checkInManager.checkInManager(3)
myCheckInManager.verify(sys.argv)

#host via describe-endpoint on this OdinMS: com.amazonaws.iot.device.sdk.credentials.testing.websocket
host = myCheckInManager.host
rootCA = "./test-integration/Credentials/rootCA.crt"
certificate = "./test-integration/Credentials/certificate.pem.crt"
privateKey = "./test-integration/Credentials/privateKey.pem.key"
mode = myCheckInManager.mode

skip_when_match(ModeIsALPN(mode).And(
    Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0)))
), "This test is not applicable for mode %s and Python verison %s. Skipping..." % (mode, sys.version_info[:3]))

# Init Python core and connect both clients
myMQTTClientManager = MQTTClientManager.MQTTClientManager()
clientPub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_PUB, host, rootCA,
                                                           certificate, privateKey, mode=mode)
clientSub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_SUB, host, rootCA,
                                                           certificate, privateKey, mode=mode)

if clientPub is None or clientSub is None:
    exit(4)

# Extra configuration for clients: base 1 s, cap 16 s, stable after 60 s
clientPub.configure_reconnect_back_off(1, 16, 60)
clientSub.configure_reconnect_back_off(1, 16, 60)

print("Two clients are connected!")

# Configurations
################
# Custom parameters
NumberOfNetworkFailure = myCheckInManager.customParameter
# ThreadPool object
threadPoolObject = threadPool(NumberOfNetworkFailure, clientPub, clientSub)
# Threads
mySimpleThreadManager = simpleThreadManager.simpleThreadManager()
threadAID = mySimpleThreadManager.createOneTimeThread(threadPoolObject.threadARuntime, [])
threadBID = mySimpleThreadManager.createOneTimeThread(threadPoolObject.threadBRuntime, [])

# Performing
############
mySimpleThreadManager.startThreadWithID(threadBID)
mySimpleThreadManager.startThreadWithID(threadAID)
mySimpleThreadManager.joinOneTimeThreadWithID(threadBID)
mySimpleThreadManager.joinOneTimeThreadWithID(threadAID)

# Verifying
###########
print("Verify that all messages are received...")
if threadPoolObject._publishMessagePool == threadPoolObject._receiveMessagePool:
    print("Passed. Recv/Pub: " + str(len(threadPoolObject._receiveMessagePool)) + "/" + str(
        len(threadPoolObject._publishMessagePool)))
else:
    print("Not all messages are received!")
    exit(4)
print("Verify reconnect backoff time record...")
connectedTimeRecord = threadPoolObject.getConnectedTimeRecord()
reconnectTimeRecord = threadPoolObject.getReconnectTimeRecord()
print("ConnectedTimeRecord: " + str(connectedTimeRecord))
print("ReconnectTimeRecord: " + str(reconnectTimeRecord))
# generateCorrectAnswer is pure, so compute the expectation once and reuse it
expectedBackoff = generateCorrectAnswer(1, 16, 60, connectedTimeRecord)
print("Answer: " + str(expectedBackoff))
if verifyBackoffTime(expectedBackoff, reconnectTimeRecord):
    print("Passed.")
else:
    print("Backoff time does not match theoretical value!")
    exit(4)
It parses out the sequence number and the chunk, then pack +# them into a dictionary with sequence number as the key and the chunk as the value. +# 4. To verify the result of the test, the random string is re-assembled for both +# the update thread and the delta thread to see if they are equal. +# 5. Since shadow operations are all QoS0 (Pub/Sub), it is still a valid case when +# the re-assembled strings are not equal. Then we need to make sure that the number +# of the missing chunks does not exceed 10% of the total number of chunk transmission +# that succeeds. + +import time +import random +import string +import json +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from TestToolLibrary import simpleThreadManager +import TestToolLibrary.checkInManager as checkInManager +import TestToolLibrary.MQTTClientManager as MQTTClientManager +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.shadow.deviceShadow import deviceShadow +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.shadow.shadowManager import shadowManager +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + + +# Global configuration +TPS = 1 # Update speed, Spectre does not tolerate high TPS shadow operations... 
CLIENT_ID_PUB = "integrationTestMQTT_ClientPub" + "".join(random.choice(string.ascii_lowercase) for i in range(4))
CLIENT_ID_SUB = "integrationTestMQTT_ClientSub" + "".join(random.choice(string.ascii_lowercase) for i in range(4))


# Class that manages the generation and chopping of the random string
class GibberishBox:
    def __init__(self, length):
        self._content = self._generateGibberish(length)

    def getGibberish(self):
        return self._content

    # Random string generator: lower/upper case letter + digits
    def _generateGibberish(self, length):
        s = string.ascii_lowercase + string.digits + string.ascii_uppercase
        return ''.join(random.sample(s * length, length))

    # Spit out the gibberish chunk by chunk (1-10 bytes); returns "" once exhausted
    def gibberishSpitter(self):
        randomLength = random.randrange(1, 11)
        ret = None
        if self._content is not None:
            ret = self._content[0:randomLength]
            self._content = self._content[randomLength:]
        return ret


# Class that manages the callback function and record of chunks for re-assembling
class callbackContainer:
    def __init__(self):
        self._internalDictionary = dict()

    def getInternalDictionary(self):
        return self._internalDictionary

    def testCallback(self, payload, type, token):
        """Parse a shadow JSON payload and record {sequenceNumber: gibberishChunk}.

        "accepted" payloads carry the chunk under state.desired; delta payloads
        carry it directly under state.
        """
        print("Type: " + type)
        print(payload)
        print("&&&&&&&&&&&&&&&&&&&&")
        # This is the shadow delta callback, so the token should be None
        if type == "accepted":
            JsonDict = json.loads(payload)
            try:
                sequenceNumber = int(JsonDict['state']['desired']['sequenceNumber'])
                gibberishChunk = JsonDict['state']['desired']['gibberishChunk']
                self._internalDictionary[sequenceNumber] = gibberishChunk
            except KeyError as e:
                # BUGFIX: "e.message" raises AttributeError on Python 3; str(e)
                # is portable.
                print(str(e))
                print("No such key!")
        else:
            JsonDict = json.loads(payload)
            try:
                sequenceNumber = int(JsonDict['state']['sequenceNumber'])
                gibberishChunk = JsonDict['state']['gibberishChunk']
                self._internalDictionary[sequenceNumber] = gibberishChunk
            except KeyError as e:
                # BUGFIX: same "e.message" portability fix as above
                print(str(e))
                print("No such key!")


# Thread runtime function: push the gibberish to the shadow chunk by chunk
def threadShadowUpdate(deviceShadow, callback, TPS, gibberishBox, maxNumMessage):
    time.sleep(2)
    chunkSequence = 0
    while True:
        currentChunk = gibberishBox.gibberishSpitter()
        if currentChunk != "":
            outboundJsonDict = dict()
            outboundJsonDict["state"] = dict()
            outboundJsonDict["state"]["desired"] = dict()
            outboundJsonDict["state"]["desired"]["sequenceNumber"] = chunkSequence
            outboundJsonDict["state"]["desired"]["gibberishChunk"] = currentChunk
            outboundJSON = json.dumps(outboundJsonDict)
            chunkSequence += 1
            try:
                deviceShadow.shadowUpdate(outboundJSON, callback, 5)
            except publishError:
                print("Publish error!")
            except subscribeTimeoutException:
                print("Subscribe timeout!")
            except subscribeError:
                print("Subscribe error!")
            except Exception as e:
                print("Unknown exception!")
                print("Type: " + str(type(e)))
                # BUGFIX: was "e.message" (crashes on Python 3)
                print("Message: " + str(e))
            time.sleep(1 / TPS)
        else:
            break
    print("Update thread completed.")


# Re-assemble gibberish, skipping sequence numbers that were never received
def reAssembleGibberish(srcDict, maxNumMessage):
    ret = ""
    for i in range(0, maxNumMessage):
        try:
            ret += srcDict[i]
        except KeyError:
            pass
    return ret


# RandomShadowNameSuffix
def randomString(lengthOfString):
    return "".join(random.choice(string.ascii_lowercase) for i in range(lengthOfString))
############################################################################
# Main #
# Check inputs
myCheckInManager = checkInManager.checkInManager(3)
myCheckInManager.verify(sys.argv)

host = myCheckInManager.host
rootCA = "./test-integration/Credentials/rootCA.crt"
certificate = "./test-integration/Credentials/certificate.pem.crt"
privateKey = "./test-integration/Credentials/privateKey.pem.key"
mode = myCheckInManager.mode

skip_when_match(ModeIsALPN(mode).And(
    Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0)))
), "This test is not applicable for mode %s and Python verison %s. Skipping..." % (mode, sys.version_info[:3]))

# Init Python core and connect
myMQTTClientManager = MQTTClientManager.MQTTClientManager()
clientPub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_PUB, host, rootCA,
                                                           certificate, privateKey, mode=mode)
clientSub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_SUB, host, rootCA,
                                                           certificate, privateKey, mode=mode)

if clientPub is None or clientSub is None:
    exit(4)

print("Two clients are connected!")

# Configurations
################
# Data
gibberishLength = myCheckInManager.customParameter
# Init device shadow instance
shadowManager1 = shadowManager(clientPub)
shadowManager2 = shadowManager(clientSub)
shadowName = "GibberChunk" + randomString(5)
deviceShadow1 = deviceShadow(shadowName, True, shadowManager1)
deviceShadow2 = deviceShadow(shadowName, True, shadowManager2)
print("Two device shadow instances are created!")

# Callbacks
callbackHome_Update = callbackContainer()
callbackHome_Delta = callbackContainer()

# Listen on delta topic
try:
    deviceShadow2.shadowRegisterDeltaCallback(callbackHome_Delta.testCallback)
except subscribeError:
    print("Subscribe error!")
except subscribeTimeoutException:
    print("Subscribe timeout!")
except Exception as e:
    print("Unknown exception!")
    print("Type: " + str(type(e)))
    # BUGFIX: exceptions have no ".message" attribute in Python 3; str(e) is
    # portable across Python 2 and 3.
    print("Message: " + str(e))

# Init gibberishBox
cipher = GibberishBox(gibberishLength)
gibberish = cipher.getGibberish()
print("Random string: " + gibberish)

# Threads
mySimpleThreadManager = simpleThreadManager.simpleThreadManager()
updateThreadID = mySimpleThreadManager.createOneTimeThread(threadShadowUpdate,
                                                           [deviceShadow1, callbackHome_Update.testCallback, TPS,
                                                            cipher, gibberishLength])

# Performing
############
# Functionality test
mySimpleThreadManager.startThreadWithID(updateThreadID)
mySimpleThreadManager.joinOneTimeThreadWithID(updateThreadID)
time.sleep(10)  # Just in case

# Now check the gibberish
gibberishUpdateResult = reAssembleGibberish(callbackHome_Update.getInternalDictionary(), gibberishLength)
gibberishDeltaResult = reAssembleGibberish(callbackHome_Delta.getInternalDictionary(), gibberishLength)
print("Update:")
print(gibberishUpdateResult)
print("Delta:")
print(gibberishDeltaResult)
print("Origin:")
print(gibberish)

if gibberishUpdateResult != gibberishDeltaResult:
    # Since shadow operations are on QoS0 (Pub/Sub), there could be a chance
    # where incoming messages are missing on the subscribed side
    # A ratio of 95% must be guaranteed to pass this test
    dictUpdate = callbackHome_Update.getInternalDictionary()
    dictDelta = callbackHome_Delta.getInternalDictionary()
    maxBaseNumber = max(len(dictUpdate), len(dictDelta))
    diff = float(abs(len(dictUpdate) - len(dictDelta))) / maxBaseNumber
    print("Update/Delta string not equal, missing rate is: " + str(diff * 100) + "%.")
    # Largest chunk is 10 bytes, total length is X bytes.
    # Minimum number of chunks is X/10
    # Maximum missing rate = 10%
    if diff > 0.1:
        print("Missing rate too high!")
        exit(4)
CERT_MUTUAL_AUTH = "MutualAuth"
WEBSOCKET = 'Websocket'
CERT_ALPN = "ALPN"


# Class that manages the creation, configuration and connection of MQTT Client
class MQTTClientManager:

    def create_connected_mqtt_client(self, mode, client_id, host, credentials_data, callbacks=None):
        """Build an SDK client for the given mode and connect it; None on failure."""
        new_client = self.create_nonconnected_mqtt_client(mode, client_id, host, credentials_data, callbacks)
        return self._connect_client(new_client)

    def create_nonconnected_mqtt_client(self, mode, client_id, host, credentials_data, callbacks=None):
        """Build and configure (but do not connect) an AWSIoTMQTTClient.

        Raises RuntimeError for an unrecognized mode.
        """
        if mode == CERT_MUTUAL_AUTH:
            new_client = self._create_nonconnected_mqtt_client_with_cert(client_id, host, 8883, credentials_data)
        elif mode == WEBSOCKET:
            root_ca, certificate, private_key = credentials_data
            new_client = AWSIoTMQTTClient(clientID=client_id + "_" + self._random_string(3), useWebsocket=True)
            new_client.configureEndpoint(host, 443)
            new_client.configureCredentials(CAFilePath=root_ca)
        elif mode == CERT_ALPN:
            new_client = self._create_nonconnected_mqtt_client_with_cert(client_id, host, 443, credentials_data)
        else:
            raise RuntimeError("Test mode: " + str(mode) + " not supported!")

        new_client.configureConnectDisconnectTimeout(10)
        new_client.configureMQTTOperationTimeout(5)

        if callbacks is not None:
            new_client.onOnline = callbacks.on_online
            new_client.onOffline = callbacks.on_offline
            new_client.onMessage = callbacks.on_message

        return new_client

    def _create_nonconnected_mqtt_client_with_cert(self, client_id, host, port, credentials_data):
        """Certificate-based client shared by the mutual-auth and ALPN modes."""
        root_ca, certificate, private_key = credentials_data
        cert_client = AWSIoTMQTTClient(clientID=client_id + "_" + self._random_string(3))
        cert_client.configureEndpoint(host, port)
        cert_client.configureCredentials(CAFilePath=root_ca, KeyPath=private_key, CertificatePath=certificate)
        return cert_client

    def create_connected_mqtt_core(self, client_id, host, root_ca, certificate, private_key, mode):
        """Build an MqttCore for the given mode and connect it; None on failure."""
        core_client = self.create_nonconnected_mqtt_core(client_id, host, root_ca, certificate, private_key, mode)
        return self._connect_client(core_client)

    def create_nonconnected_mqtt_core(self, client_id, host, root_ca, certificate, private_key, mode):
        """Build and configure (but do not connect) a raw MqttCore; None on failure."""
        core_client = None
        protocol = None
        port = None
        is_websocket = False
        is_alpn = False

        if mode == CERT_MUTUAL_AUTH:
            protocol, port = paho.MQTTv311, 8883
        elif mode == WEBSOCKET:
            protocol, port = paho.MQTTv31, 443
            is_websocket = True
        elif mode == CERT_ALPN:
            protocol, port = paho.MQTTv311, 443
            is_alpn = True
        else:
            print("Error in creating the client")

        if protocol is None or port is None:
            print("Not enough input parameters")
            return core_client  # stays None when the necessary params are not there

        try:
            core_client = MqttCore(client_id + "_" + self._random_string(3), True, protocol, is_websocket)

            endpoint_provider = EndpointProvider()
            endpoint_provider.set_host(host)
            endpoint_provider.set_port(port)

            # Once is_websocket is True, certificate_credentials_provider will NOT be used
            # by the client even if it is configured
            certificate_credentials_provider = CertificateCredentialsProvider()
            certificate_credentials_provider.set_ca_path(root_ca)
            certificate_credentials_provider.set_cert_path(certificate)
            certificate_credentials_provider.set_key_path(private_key)

            cipher_provider = CiphersProvider()
            cipher_provider.set_ciphers(None)

            core_client.configure_endpoint(endpoint_provider)
            core_client.configure_cert_credentials(certificate_credentials_provider, cipher_provider)
            core_client.configure_connect_disconnect_timeout_sec(10)
            core_client.configure_operation_timeout_sec(5)
            core_client.configure_offline_requests_queue(10, DropBehaviorTypes.DROP_NEWEST)

            if is_alpn:
                core_client.configure_alpn_protocols()
        except Exception as e:
            print("Unknown exception in creating the client: " + str(e))
        finally:
            # Deliberate best-effort: return whatever was built (possibly None)
            return core_client

    def _random_string(self, length):
        """Lowercase suffix so repeated runs get unique client IDs."""
        return "".join(random.choice(string.ascii_lowercase) for i in range(length))

    def _connect_client(self, client):
        """Connect with keep-alive 1 s; print and return None on any failure."""
        if client is None:
            return client

        try:
            client.connect(1)
        except connectTimeoutException as e:
            print("Connect timeout: " + str(e))
            return None
        except connectError as e:
            print("Connect error:" + str(e))
            return None
        except SSLError as e:
            print("Connect SSL error: " + str(e))
            return None
        except IOError as e:
            print("Credentials not found: " + str(e))
            return None
        except Exception as e:
            print("Unknown exception in connect: ")
            traceback.print_exc()
            return None

        return client
# Class that is responsible for input/dependency verification
class checkInManager:

    def __init__(self, numberOfInputParameters):
        """Remember how many command-line parameters this test expects."""
        self._numberOfInputParameters = numberOfInputParameters
        self.mode = None
        self.host = None
        self.customParameter = None

    def verify(self, args):
        """Validate argv length, then record mode, host and (optionally) a custom int.

        Exits the process with code 4 when the argument count is wrong.
        """
        expected_argc = self._numberOfInputParameters + 1  # +1 for the program name
        if len(args) != expected_argc:
            exit(4)
        self.mode = str(args[1])
        self.host = str(args[2])
        if expected_argc > 3:
            self.customParameter = int(args[3])
Join certain thread
+
+
+import time
+import threading
+
+
+# Describes one managed thread: its type, run target and loop-control flag
+class _threadControlUnit:
+    # Constants for the two supported thread types
+    _THREAD_TYPE_ONETIME = 0
+    _THREAD_TYPE_LOOP = 1
+
+    def __init__(self, threadID, threadType, runFunction, runParameters, scanningSpeedSecond=0.01):
+        if threadID is None or threadType is None or runFunction is None or runParameters is None:
+            raise ValueError("None input detected.")
+        if threadType != self._THREAD_TYPE_ONETIME and threadType != self._THREAD_TYPE_LOOP:
+            raise ValueError("Thread type not supported.")
+        self.threadID = threadID
+        self.threadType = threadType
+        self.runFunction = runFunction
+        self.runParameters = runParameters
+        self.threadObject = None  # Holds the real threading.Thread object once started
+        # Configure the control flag; it is only meaningful for the loop thread type
+        if self.threadType == self._THREAD_TYPE_LOOP:
+            self.stopSign = False  # False keeps the loop running (infinite loop by default)
+            self.scanningSpeedSecond = scanningSpeedSecond
+        else:
+            self.stopSign = None  # No control flag for a one-time thread
+            self.scanningSpeedSecond = -1
+
+    def _oneTimeRunFunction(self):
+        self.runFunction(*self.runParameters)
+
+    def _loopRunFunction(self):
+        while not self.stopSign:
+            self.runFunction(*self.runParameters)  # The run function itself should contain no manual delay
+            time.sleep(self.scanningSpeedSecond)
+
+    def _stopMe(self):
+        self.stopSign = True
+
+    def _setThreadObject(self, threadObject):
+        self.threadObject = threadObject
+
+    def _getThreadObject(self):
+        return self.threadObject
+
+
+# Class that manages all threadControlUnit instances
+# Used in a single thread (the manager itself performs no locking)
+class simpleThreadManager:
+    def __init__(self):
+        self._internalCount = 0
+        self._controlCenter = dict()
+
+    def createOneTimeThread(self, runFunction, runParameters):
+        returnID = self._internalCount
+        self._controlCenter[self._internalCount] = _threadControlUnit(self._internalCount,
+                                                                      _threadControlUnit._THREAD_TYPE_ONETIME,
+                                                                      runFunction, runParameters)
+        self._internalCount += 1
+        return returnID
+
+    def createLoopThread(self, runFunction, runParameters, scanningSpeedSecond):
+        returnID = self._internalCount
+        self._controlCenter[self._internalCount] = _threadControlUnit(self._internalCount,
+                                                                      _threadControlUnit._THREAD_TYPE_LOOP, runFunction,
+                                                                      runParameters, scanningSpeedSecond)
+        self._internalCount += 1
+        return returnID
+
+    def stopLoopThreadWithID(self, threadID):
+        threadToStop = self._controlCenter.get(threadID)
+        if threadToStop is None:
+            raise ValueError("No such threadID.")
+        else:
+            if threadToStop.threadType == _threadControlUnit._THREAD_TYPE_LOOP:
+                threadToStop._stopMe()
+                time.sleep(3 * threadToStop.scanningSpeedSecond)
+            else:
+                raise TypeError("Error! Try to stop a one time thread.")
+
+    def startThreadWithID(self, threadID):
+        threadToStart = self._controlCenter.get(threadID)
+        if threadToStart is None:
+            raise ValueError("No such threadID.")
+        else:
+            currentThreadType = threadToStart.threadType
+            newThreadObject = None
+            if currentThreadType == _threadControlUnit._THREAD_TYPE_LOOP:
+                newThreadObject = threading.Thread(target=threadToStart._loopRunFunction)
+            else:  # One time thread
+                newThreadObject = threading.Thread(target=threadToStart._oneTimeRunFunction)
+            newThreadObject.start()
+            threadToStart._setThreadObject(newThreadObject)
+
+    def joinOneTimeThreadWithID(self, threadID):
+        threadToJoin = self._controlCenter.get(threadID)
+        if threadToJoin is None:
+            raise ValueError("No such threadID.")
+        else:
+            if threadToJoin.threadType == _threadControlUnit._THREAD_TYPE_ONETIME:
+                currentThreadObject = threadToJoin._getThreadObject()
+                currentThreadObject.join()
+            else:
+                raise TypeError("Error! 
Try to join a loop thread.")
diff --git a/test-integration/IntegrationTests/TestToolLibrary/skip.py b/test-integration/IntegrationTests/TestToolLibrary/skip.py
new file mode 100644
index 0000000..4d6e5ca
--- /dev/null
+++ b/test-integration/IntegrationTests/TestToolLibrary/skip.py
@@ -0,0 +1,110 @@
+import sys
+from TestToolLibrary.MQTTClientManager import CERT_ALPN
+from TestToolLibrary.MQTTClientManager import WEBSOCKET
+
+# This module manages the skip policy validation for each test
+
+
+def skip_when_match(policy, message):
+    if policy.validate():
+        print(message)
+        sys.exit(0)  # Exit the interpreter; was builtin exit(), which is only guaranteed interactively
+
+
+class Policy(object):
+
+    AND = "and"
+    OR = "or"
+
+    def __init__(self):
+        self._relations = []
+
+    # Use caps to avoid collision with Python built-in and/or keywords
+    def And(self, policy):
+        self._relations.append((self.AND, policy))
+        return self
+
+    def Or(self, policy):
+        self._relations.append((self.OR, policy))
+        return self
+
+    def validate(self):
+        result = self.validate_impl()
+
+        for element in self._relations:
+            operand, policy = element
+            if operand == self.AND:
+                result = result and policy.validate()
+            elif operand == self.OR:
+                result = result or policy.validate()
+            else:
+                raise RuntimeError("Unrecognized operand: " + str(operand))
+
+        return result
+
+    def validate_impl(self):
+        raise RuntimeError("Not implemented")
+
+
+class PythonVersion(Policy):
+
+    HIGHER = "higher"
+    LOWER = "lower"
+    EQUALS = "equals"
+
+    def __init__(self, actual_version, expected_version, operand):
+        Policy.__init__(self)
+        self._actual_version = actual_version
+        self._expected_version = expected_version
+        self._operand = operand
+
+    def validate_impl(self):
+        if self._operand == self.LOWER:
+            return self._actual_version < self._expected_version
+        elif self._operand == self.HIGHER:
+            return self._actual_version > self._expected_version
+        elif self._operand == self.EQUALS:
+            return self._actual_version == self._expected_version
+        else:
+            raise 
RuntimeError("Unsupported operand: " + self._operand)
+
+
+class Python2VersionLowerThan(PythonVersion):
+
+    def __init__(self, version):
+        PythonVersion.__init__(self, sys.version_info[:3], version, PythonVersion.LOWER)
+
+    def validate_impl(self):
+        return sys.version_info[0] == 2 and PythonVersion.validate_impl(self)
+
+
+class Python3VersionLowerThan(PythonVersion):
+
+    def __init__(self, version):
+        PythonVersion.__init__(self, sys.version_info[:3], version, PythonVersion.LOWER)
+
+    def validate_impl(self):
+        return sys.version_info[0] == 3 and PythonVersion.validate_impl(self)
+
+
+class ModeIs(Policy):
+
+    def __init__(self, actual_mode, expected_mode):
+        Policy.__init__(self)
+        self._actual_mode = actual_mode
+        self._expected_mode = expected_mode
+
+    def validate_impl(self):
+        return self._actual_mode == self._expected_mode
+
+
+class ModeIsALPN(ModeIs):
+
+    def __init__(self, actual_mode):
+        ModeIs.__init__(self, actual_mode=actual_mode, expected_mode=CERT_ALPN)
+
+
+class ModeIsWebSocket(ModeIs):
+
+    def __init__(self, actual_mode):
+        ModeIs.__init__(self, actual_mode=actual_mode, expected_mode=WEBSOCKET)
diff --git a/test-integration/Tools/retrieve-key.py b/test-integration/Tools/retrieve-key.py
new file mode 100644
index 0000000..3884d7f
--- /dev/null
+++ b/test-integration/Tools/retrieve-key.py
@@ -0,0 +1,59 @@
+
+import boto3
+import base64
+import sys
+from botocore.exceptions import ClientError
+
+def main():
+    secret_name = sys.argv[1]
+    region_name = "us-east-1"
+
+    # Create a Secrets Manager client in the fixed CI region above
+    session = boto3.session.Session()
+    client = session.client(
+        service_name='secretsmanager',
+        region_name=region_name
+    )
+    # In this sample we only handle the specific exceptions for the 'GetSecretValue' API.
+    # See https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
+    # We rethrow the exception by default. 
+
+    try:
+        get_secret_value_response = client.get_secret_value(
+            SecretId=secret_name
+        )
+
+    except ClientError as e:
+        if e.response['Error']['Code'] == 'DecryptionFailureException':
+            # Secrets Manager can't decrypt the protected secret text using the provided KMS key.
+            # Deal with the exception here, and/or rethrow at your discretion.
+            raise e
+        elif e.response['Error']['Code'] == 'InternalServiceErrorException':
+            # An error occurred on the server side.
+            # Deal with the exception here, and/or rethrow at your discretion.
+            raise e
+        elif e.response['Error']['Code'] == 'InvalidParameterException':
+            # You provided an invalid value for a parameter.
+            # Deal with the exception here, and/or rethrow at your discretion.
+            raise e
+        elif e.response['Error']['Code'] == 'InvalidRequestException':
+            # You provided a parameter value that is not valid for the current state of the resource.
+            # Deal with the exception here, and/or rethrow at your discretion.
+            raise e
+        elif e.response['Error']['Code'] == 'ResourceNotFoundException':
+            # We can't find the resource that you asked for.
+            # Deal with the exception here, and/or rethrow at your discretion.
+            raise e
+        raise e  # Unknown error code: rethrow (was print(e)) so the script exits non-zero instead of silently emitting nothing
+    else:
+        # Decrypts secret using the associated KMS key.
+        # Depending on whether the secret is a string or binary, one of these fields will be populated.
+        if 'SecretString' in get_secret_value_response:
+            secret = get_secret_value_response['SecretString']
+        else:
+            secret = base64.b64decode(get_secret_value_response['SecretBinary'])
+        print(secret)
+
+
+if __name__ == '__main__':
+    sys.exit(main())  # propagate main()'s return value as the process exit status
diff --git a/test-integration/run/run.sh b/test-integration/run/run.sh
new file mode 100755
index 0000000..8e23c91
--- /dev/null
+++ b/test-integration/run/run.sh
@@ -0,0 +1,155 @@
+#!/bin/bash
+#
+# This script manages the start of integration
+# tests for Python core in AWS IoT Arduino Yun
+# SDK. 
The tests should be able to run both in
+# Brazil and ToD Worker environment.
+# The script will perform the following tasks:
+# 1. Retrieve credentials as needed from AWS
+# 2. Obtain ZIP package and unzip it locally
+# 3. Start the integration tests and check results
+# 4. Report any status returned.
+# To start the tests as TodWorker:
+# > run.sh MutualAuth 1000 100 7
+# or
+# > run.sh Websocket 1000 100 7
+# or
+# > run.sh ALPN 1000 100 7
+#
+# To start the tests from desktop:
+# > run.sh MutualAuthT 1000 100 7
+# or
+# > run.sh WebsocketT 1000 100 7
+# or
+# > run.sh ALPNT 1000 100 7
+#
+# 1000 MQTT messages, 100 bytes of random string
+# in length and 7 rounds of network failure for
+# progressive backoff.
+# Test mode (MutualAuth/Websocket) must be
+# specified.
+# Scale number must also be specified (see usage)
+
+# Define const
+USAGE="usage: run.sh <MutualAuth|Websocket|ALPN> <message-count> <message-size-bytes> <backoff-rounds>"
+
+UnitTestHostArn="arn:aws:secretsmanager:us-east-1:180635532705:secret:unit-test/endpoint-HSpeEu"
+GreenGrassHostArn="arn:aws:secretsmanager:us-east-1:180635532705:secret:ci/greengrassv1/endpoint-DgM00X"
+
+AWSMutualAuth_TodWorker_private_key="arn:aws:secretsmanager:us-east-1:180635532705:secret:ci/mqtt5/us/Mqtt5Prod/key-kqgyvf"
+AWSMutualAuth_TodWorker_certificate="arn:aws:secretsmanager:us-east-1:180635532705:secret:ci/mqtt5/us/Mqtt5Prod/cert-VDI1Gd"
+
+AWSGGDiscovery_TodWorker_private_key="arn:aws:secretsmanager:us-east-1:180635532705:secret:V1IotSdkIntegrationTestGGDiscoveryPrivateKey-BsLvNP"
+AWSGGDiscovery_TodWorker_certificate="arn:aws:secretsmanager:us-east-1:180635532705:secret:V1IotSdkIntegrationTestGGDiscoveryCertificate-DSwdhA"
+
+
+SDKLocation="./AWSIoTPythonSDK"
+RetrieveAWSKeys="./test-integration/Tools/retrieve-key.py"
+CREDENTIAL_DIR="./test-integration/Credentials/"
+TEST_DIR="./test-integration/IntegrationTests/"
+CA_CERT_URL="https://www.amazontrust.com/repository/AmazonRootCA1.pem"
+CA_CERT_PATH=${CREDENTIAL_DIR}rootCA.crt
+TestHost=$(python ${RetrieveAWSKeys} ${UnitTestHostArn})
+GreengrassHost=$(python ${RetrieveAWSKeys} ${GreenGrassHostArn})
+
+
+
+
+# If input args not correct, echo usage
+if [ $# -ne 4 ]; then
+    echo ${USAGE}
+else
+# Description
+    echo "[STEP] Start run.sh"
+    echo "***************************************************"
+    echo "About to start integration tests for IoTPySDK..."
+    echo "Test Mode: $1"
+# Determine the Python versions need to test for this SDK
+    pythonExecutableArray=()
+    pythonExecutableArray[0]="3"
+# Retrieve credentials as needed from AWS
+    TestMode=""
+    echo "[STEP] Retrieve credentials from AWS"
+    echo "***************************************************"
+    if [ "$1"x == "MutualAuth"x ]; then
+        AWSSetName_privatekey=${AWSMutualAuth_TodWorker_private_key}
+        AWSSetName_certificate=${AWSMutualAuth_TodWorker_certificate}
+        AWSDRSName_privatekey=${AWSGGDiscovery_TodWorker_private_key}
+        AWSDRSName_certificate=${AWSGGDiscovery_TodWorker_certificate}
+        TestMode="MutualAuth"
+        python ${RetrieveAWSKeys} ${AWSSetName_certificate} > ${CREDENTIAL_DIR}certificate.pem.crt
+        python ${RetrieveAWSKeys} ${AWSSetName_privatekey} > ${CREDENTIAL_DIR}privateKey.pem.key
+        curl -s "${CA_CERT_URL}" > ${CA_CERT_PATH}
+        echo -e "URL retrieved certificate data\n"
+        python ${RetrieveAWSKeys} ${AWSDRSName_certificate} > ${CREDENTIAL_DIR}certificate_drs.pem.crt
+        python ${RetrieveAWSKeys} ${AWSDRSName_privatekey} > ${CREDENTIAL_DIR}privateKey_drs.pem.key
+    elif [ "$1"x == "Websocket"x ]; then
+        TestMode="Websocket"
+        curl -s "${CA_CERT_URL}" > ${CA_CERT_PATH}
+        echo -e "URL retrieved certificate data\n"
+    elif [ "$1"x == "ALPN"x ]; then
+        AWSSetName_privatekey=${AWSMutualAuth_TodWorker_private_key}
+        AWSSetName_certificate=${AWSMutualAuth_TodWorker_certificate}
+        AWSDRSName_privatekey=${AWSGGDiscovery_TodWorker_private_key}
+        AWSDRSName_certificate=${AWSGGDiscovery_TodWorker_certificate}
+        TestMode="ALPN"
+        python ${RetrieveAWSKeys} ${AWSSetName_certificate} > ${CREDENTIAL_DIR}certificate.pem.crt
+        python ${RetrieveAWSKeys} ${AWSSetName_privatekey} > ${CREDENTIAL_DIR}privateKey.pem.key
+        curl -s "${CA_CERT_URL}" > ${CA_CERT_PATH}
+        echo -e "URL retrieved certificate data\n"
+        python ${RetrieveAWSKeys} ${AWSDRSName_certificate} > ${CREDENTIAL_DIR}certificate_drs.pem.crt
+        python ${RetrieveAWSKeys} ${AWSDRSName_privatekey} > ${CREDENTIAL_DIR}privateKey_drs.pem.key
+    else
+        echo "Mode not supported"
+        exit 1
+    fi
+# Obtain ZIP package and unzip it locally
+    echo ${TestMode}
+    echo "[STEP] Obtain ZIP package"
+    echo "***************************************************"
+    cp -R ${SDKLocation} ./test-integration/IntegrationTests/TestToolLibrary/SDKPackage/
+# Obtain Python executable
+
+    echo "***************************************************"
+    for file in `ls ${TEST_DIR}`
+    do
+        if [ ${file##*.}x == "py"x ]; then
+            echo "[SUB] Running test: ${file}..."
+
+            Scale=10
+            Host=${TestHost} # fix: was Host=TestHost (literal string, missing $ expansion)
+            case "$file" in
+                "IntegrationTestMQTTConnection.py") Scale=$2
+                ;;
+                "IntegrationTestShadow.py") Scale=$3
+                ;;
+                "IntegrationTestAutoReconnectResubscribe.py") Scale=""
+                ;;
+                "IntegrationTestProgressiveBackoff.py") Scale=$4
+                ;;
+                "IntegrationTestConfigurablePublishMessageQueueing.py") Scale=""
+                ;;
+                "IntegrationTestDiscovery.py") Scale=""
+                                              Host=${GreengrassHost}
+                ;;
+                "IntegrationTestAsyncAPIGeneralNotificationCallbacks.py") Scale=""
+                ;;
+                "IntegrationTestOfflineQueueingForSubscribeUnsubscribe.py") Scale=""
+                ;;
+                "IntegrationTestClientReusability.py") Scale=""
+                ;;
+                "IntegrationTestJobsClient.py") Scale=""
+            esac
+
+            python ${TEST_DIR}${file} ${TestMode} ${Host} ${Scale} # fix: was ${TestHost}, which ignored the per-test Host selection above
+            currentTestStatus=$?
+            echo "[SUB] Test: ${file} completed. Exiting with status: ${currentTestStatus}"
+            if [ ${currentTestStatus} -ne 0 ]; then
+                echo "!!!!!!!!!!!!!Test: ${file} failed.!!!!!!!!!!!!!" 
+ exit ${currentTestStatus} + fi + echo "" + fi + done + echo "All integration tests passed" +fi diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/__init__.py b/test/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/greengrass/__init__.py b/test/core/greengrass/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/greengrass/discovery/__init__.py b/test/core/greengrass/discovery/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/greengrass/discovery/test_discovery_info_parsing.py b/test/core/greengrass/discovery/test_discovery_info_parsing.py new file mode 100644 index 0000000..318ede3 --- /dev/null +++ b/test/core/greengrass/discovery/test_discovery_info_parsing.py @@ -0,0 +1,127 @@ +from AWSIoTPythonSDK.core.greengrass.discovery.models import DiscoveryInfo + + +DRS_INFO_JSON = "{\"GGGroups\":[{\"GGGroupId\":\"627bf63d-ae64-4f58-a18c-80a44fcf4088\"," \ + "\"Cores\":[{\"thingArn\":\"arn:aws:iot:us-east-1:003261610643:thing/DRS_GGC_0kegiNGA_0\"," \ + "\"Connectivity\":[{\"Id\":\"Id-0\",\"HostAddress\":\"192.168.101.0\",\"PortNumber\":8080," \ + "\"Metadata\":\"Description-0\"}," \ + "{\"Id\":\"Id-1\",\"HostAddress\":\"192.168.101.1\",\"PortNumber\":8081,\"Metadata\":\"Description-1\"}," \ + "{\"Id\":\"Id-2\",\"HostAddress\":\"192.168.101.2\",\"PortNumber\":8082,\"Metadata\":\"Description-2\"}]}]," \ + "\"CAs\":[\"-----BEGIN CERTIFICATE-----\\n" \ + "MIIEFTCCAv2gAwIBAgIVAPZfc4GMLZPmXbnoaZm6jRDqDs4+MA0GCSqGSIb3DQEB\\n" \ + "CwUAMIGoMQswCQYDVQQGEwJVUzEYMBYGA1UECgwPQW1hem9uLmNvbSBJbmMuMRww\\n" \ + "GgYDVQQLDBNBbWF6b24gV2ViIFNlcnZpY2VzMRMwEQYDVQQIDApXYXNoaW5ndG9u\\n" \ + "MRAwDgYDVQQHDAdTZWF0dGxlMTowOAYDVQQDDDEwMDMyNjE2MTA2NDM6NjI3YmY2\\n" \ + "M2QtYWU2NC00ZjU4LWExOGMtODBhNDRmY2Y0MDg4MCAXDTE3MDUyNTE4NDI1OVoY\\n" \ + "DzIwOTcwNTI1MTg0MjU4WjCBqDELMAkGA1UEBhMCVVMxGDAWBgNVBAoMD0FtYXpv\\n" \ + 
"bi5jb20gSW5jLjEcMBoGA1UECwwTQW1hem9uIFdlYiBTZXJ2aWNlczETMBEGA1UE\\n" \ + "CAwKV2FzaGluZ3RvbjEQMA4GA1UEBwwHU2VhdHRsZTE6MDgGA1UEAwwxMDAzMjYx\\n" \ + "NjEwNjQzOjYyN2JmNjNkLWFlNjQtNGY1OC1hMThjLTgwYTQ0ZmNmNDA4ODCCASIw\\n" \ + "DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKEWtZtKyJUg2VUwZkbkVtltrfam\\n" \ + "s9LMIdKNA3Wz4zSLhZjKHiTSkQmpZwKle5ziYs6Q5hfeT8WC0FNAVv1JhnwsuGfT\\n" \ + "sG0UO5dSn7wqXOJigKC1CaSGqeFpKB0/a3wR1L6pCGVbLZ86/sPCEPHHJDieQ+Ps\\n" \ + "RnOcUGb4CuIBnI2N+lafWNa4F4KRSVJCEeZ6u4iWVVdIEcDLKlakY45jtVvQqwnz\\n" \ + "3leFsN7PTLEkVq5u1PXSbT5DWv6p+5NoDnGAT7j7Wbr2yJw7DtpBOL6oWkAdbFAQ\\n" \ + "2097e8mIxNYE9xAzRlb5wEr6jZl/8K60v9P83OapMeuOg4JS8FGulHXbDg0CAwEA\\n" \ + "AaMyMDAwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU21ELaPCH9Oh001OS0JMv\\n" \ + "n8hU8dYwDQYJKoZIhvcNAQELBQADggEBABW66eH/+/v9Nq5jtJzflfrqAfBOzWLj\\n" \ + "UTEv6szkYzV5Crr8vnu2P5OlyA0NdiKGiAm0AgoDkf+n9HU3Hc0zm3G/QaAO2UmN\\n" \ + "9MwtIp29BSRf+gd1bX/WZTtl5I5xl290BDfr5o08I6TOf0A4P8IAkGwku5j0IQjM\\n" \ + "ns2HH5UVki155dtmWDEGX6q35KABbsmv3tO1+geJVYnd1QkHzR5IXA12gxlMw9GJ\\n" \ + "+cOw+rwJJ2ZcXo3HFoXBcsPqPOa1SO3vTl3XWQ+jX3vyDsxh/VGoJ4epsjwmJ+dW\\n" \ + "sHJoqsa3ZPDW0LcEuYgdzYWRhumGwH9fJJUx0GS4Tdg4ud+6jpuyflU=\\n" \ + "-----END CERTIFICATE-----\\n\"]}]}" + +EXPECTED_CORE_THING_ARN = "arn:aws:iot:us-east-1:003261610643:thing/DRS_GGC_0kegiNGA_0" +EXPECTED_GROUP_ID = "627bf63d-ae64-4f58-a18c-80a44fcf4088" +EXPECTED_CONNECTIVITY_INFO_ID_0 = "Id-0" +EXPECTED_CONNECTIVITY_INFO_ID_1 = "Id-1" +EXPECTED_CONNECTIVITY_INFO_ID_2 = "Id-2" +EXPECTED_CA = "-----BEGIN CERTIFICATE-----\n" \ + "MIIEFTCCAv2gAwIBAgIVAPZfc4GMLZPmXbnoaZm6jRDqDs4+MA0GCSqGSIb3DQEB\n" \ + "CwUAMIGoMQswCQYDVQQGEwJVUzEYMBYGA1UECgwPQW1hem9uLmNvbSBJbmMuMRww\n" \ + "GgYDVQQLDBNBbWF6b24gV2ViIFNlcnZpY2VzMRMwEQYDVQQIDApXYXNoaW5ndG9u\n" \ + "MRAwDgYDVQQHDAdTZWF0dGxlMTowOAYDVQQDDDEwMDMyNjE2MTA2NDM6NjI3YmY2\n" \ + "M2QtYWU2NC00ZjU4LWExOGMtODBhNDRmY2Y0MDg4MCAXDTE3MDUyNTE4NDI1OVoY\n" \ + "DzIwOTcwNTI1MTg0MjU4WjCBqDELMAkGA1UEBhMCVVMxGDAWBgNVBAoMD0FtYXpv\n" \ + 
"bi5jb20gSW5jLjEcMBoGA1UECwwTQW1hem9uIFdlYiBTZXJ2aWNlczETMBEGA1UE\n" \ + "CAwKV2FzaGluZ3RvbjEQMA4GA1UEBwwHU2VhdHRsZTE6MDgGA1UEAwwxMDAzMjYx\n" \ + "NjEwNjQzOjYyN2JmNjNkLWFlNjQtNGY1OC1hMThjLTgwYTQ0ZmNmNDA4ODCCASIw\n" \ + "DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKEWtZtKyJUg2VUwZkbkVtltrfam\n" \ + "s9LMIdKNA3Wz4zSLhZjKHiTSkQmpZwKle5ziYs6Q5hfeT8WC0FNAVv1JhnwsuGfT\n" \ + "sG0UO5dSn7wqXOJigKC1CaSGqeFpKB0/a3wR1L6pCGVbLZ86/sPCEPHHJDieQ+Ps\n" \ + "RnOcUGb4CuIBnI2N+lafWNa4F4KRSVJCEeZ6u4iWVVdIEcDLKlakY45jtVvQqwnz\n" \ + "3leFsN7PTLEkVq5u1PXSbT5DWv6p+5NoDnGAT7j7Wbr2yJw7DtpBOL6oWkAdbFAQ\n" \ + "2097e8mIxNYE9xAzRlb5wEr6jZl/8K60v9P83OapMeuOg4JS8FGulHXbDg0CAwEA\n" \ + "AaMyMDAwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU21ELaPCH9Oh001OS0JMv\n" \ + "n8hU8dYwDQYJKoZIhvcNAQELBQADggEBABW66eH/+/v9Nq5jtJzflfrqAfBOzWLj\n" \ + "UTEv6szkYzV5Crr8vnu2P5OlyA0NdiKGiAm0AgoDkf+n9HU3Hc0zm3G/QaAO2UmN\n" \ + "9MwtIp29BSRf+gd1bX/WZTtl5I5xl290BDfr5o08I6TOf0A4P8IAkGwku5j0IQjM\n" \ + "ns2HH5UVki155dtmWDEGX6q35KABbsmv3tO1+geJVYnd1QkHzR5IXA12gxlMw9GJ\n" \ + "+cOw+rwJJ2ZcXo3HFoXBcsPqPOa1SO3vTl3XWQ+jX3vyDsxh/VGoJ4epsjwmJ+dW\n" \ + "sHJoqsa3ZPDW0LcEuYgdzYWRhumGwH9fJJUx0GS4Tdg4ud+6jpuyflU=\n" \ + "-----END CERTIFICATE-----\n" + + +class TestDiscoveryInfoParsing: + + def setup_method(self, test_method): + self.discovery_info = DiscoveryInfo(DRS_INFO_JSON) + + def test_parsing_ggc_list_ca_list(self): + ggc_list = self.discovery_info.getAllCores() + ca_list = self.discovery_info.getAllCas() + + self._verify_core_connectivity_info_list(ggc_list) + self._verify_ca_list(ca_list) + + def test_parsing_group_object(self): + group_object = self.discovery_info.toObjectAtGroupLevel() + self._verify_connectivity_info(group_object + .get(EXPECTED_GROUP_ID) + .getCoreConnectivityInfo(EXPECTED_CORE_THING_ARN) + .getConnectivityInfo(EXPECTED_CONNECTIVITY_INFO_ID_0)) + self._verify_connectivity_info(group_object + .get(EXPECTED_GROUP_ID) + .getCoreConnectivityInfo(EXPECTED_CORE_THING_ARN) + 
.getConnectivityInfo(EXPECTED_CONNECTIVITY_INFO_ID_1))
+        self._verify_connectivity_info(group_object
+                                       .get(EXPECTED_GROUP_ID)
+                                       .getCoreConnectivityInfo(EXPECTED_CORE_THING_ARN)
+                                       .getConnectivityInfo(EXPECTED_CONNECTIVITY_INFO_ID_2))
+
+    def test_parsing_group_list(self):
+        group_list = self.discovery_info.getAllGroups()
+
+        assert len(group_list) == 1
+        group_info = group_list[0]
+        assert group_info.groupId == EXPECTED_GROUP_ID
+        self._verify_ca_list(group_info.caList)
+        self._verify_core_connectivity_info_list(group_info.coreConnectivityInfoList)
+
+    def _verify_ca_list(self, actual_ca_list):
+        assert len(actual_ca_list) == 1
+        try:
+            actual_group_id, actual_ca = actual_ca_list[0]
+            assert actual_group_id == EXPECTED_GROUP_ID
+            assert actual_ca == EXPECTED_CA
+        except (TypeError, ValueError):  # only unpack failures; a bare except also hid AssertionError from the checks above
+            assert actual_ca_list[0] == EXPECTED_CA
+
+    def _verify_core_connectivity_info_list(self, actual_core_connectivity_info_list):
+        assert len(actual_core_connectivity_info_list) == 1
+        actual_core_connectivity_info = actual_core_connectivity_info_list[0]
+        assert actual_core_connectivity_info.coreThingArn == EXPECTED_CORE_THING_ARN
+        assert actual_core_connectivity_info.groupId == EXPECTED_GROUP_ID
+        self._verify_connectivity_info_list(actual_core_connectivity_info.connectivityInfoList)
+
+    def _verify_connectivity_info_list(self, actual_connectivity_info_list):
+        for actual_connectivity_info in actual_connectivity_info_list:
+            self._verify_connectivity_info(actual_connectivity_info)
+
+    def _verify_connectivity_info(self, actual_connectivity_info):
+        info_id = actual_connectivity_info.id
+        sequence_number_string = info_id[-1:]
+        assert actual_connectivity_info.host == "192.168.101." 
+ sequence_number_string + assert actual_connectivity_info.port == int("808" + sequence_number_string) + assert actual_connectivity_info.metadata == "Description-" + sequence_number_string diff --git a/test/core/greengrass/discovery/test_discovery_info_provider.py b/test/core/greengrass/discovery/test_discovery_info_provider.py new file mode 100644 index 0000000..2f11d20 --- /dev/null +++ b/test/core/greengrass/discovery/test_discovery_info_provider.py @@ -0,0 +1,169 @@ +from AWSIoTPythonSDK.core.greengrass.discovery.providers import DiscoveryInfoProvider +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryInvalidRequestException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryUnauthorizedException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryDataNotFoundException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryThrottlingException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryTimeoutException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryFailure +import pytest +try: + from mock import patch + from mock import MagicMock +except: + from unittest.mock import patch + from unittest.mock import MagicMock + + +DUMMY_CA_PATH = "dummy/ca/path" +DUMMY_CERT_PATH = "dummy/cert/path" +DUMMY_KEY_PATH = "dummy/key/path" +DUMMY_HOST = "dummy.host.amazonaws.com" +DUMMY_PORT = "8443" +DUMMY_TIME_OUT_SEC = 3 +DUMMY_GGAD_THING_NAME = "CoolGGAD" +FORMAT_REQUEST = "GET /greengrass/discover/thing/%s HTTP/1.1\r\nHost: " + DUMMY_HOST + ":" + DUMMY_PORT + "\r\n\r\n" +FORMAT_RESPONSE_HEADER = "HTTP/1.1 %s %s\r\n" \ + "content-type: application/json\r\n" \ + "content-length: %d\r\n" \ + "date: Wed, 05 Jul 2017 22:17:19 GMT\r\n" \ + "x-amzn-RequestId: 97408dd9-06a0-73bb-8e00-c4fc6845d555\r\n" \ + "connection: Keep-Alive\r\n\r\n" + +SERVICE_ERROR_MESSAGE_FORMAT = "{\"errorMessage\":\"%s\"}" +SERVICE_ERROR_MESSAGE_400 = SERVICE_ERROR_MESSAGE_FORMAT % "Invalid input detected for this request" 
+SERVICE_ERROR_MESSAGE_401 = SERVICE_ERROR_MESSAGE_FORMAT % "Unauthorized request" +SERVICE_ERROR_MESSAGE_404 = SERVICE_ERROR_MESSAGE_FORMAT % "Resource not found" +SERVICE_ERROR_MESSAGE_429 = SERVICE_ERROR_MESSAGE_FORMAT % "Too many requests" +SERVICE_ERROR_MESSAGE_500 = SERVICE_ERROR_MESSAGE_FORMAT % "Internal server error" +PAYLOAD_200 = "{\"GGGroups\":[{\"GGGroupId\":\"627bf63d-ae64-4f58-a18c-80a44fcf4088\"," \ + "\"Cores\":[{\"thingArn\":\"arn:aws:iot:us-east-1:003261610643:thing/DRS_GGC_0kegiNGA_0\"," \ + "\"Connectivity\":[{\"Id\":\"Id-0\",\"HostAddress\":\"192.168.101.0\",\"PortNumber\":8080," \ + "\"Metadata\":\"Description-0\"}," \ + "{\"Id\":\"Id-1\",\"HostAddress\":\"192.168.101.1\",\"PortNumber\":8081,\"Metadata\":\"Description-1\"}," \ + "{\"Id\":\"Id-2\",\"HostAddress\":\"192.168.101.2\",\"PortNumber\":8082,\"Metadata\":\"Description-2\"}]}]," \ + "\"CAs\":[\"-----BEGIN CERTIFICATE-----\\n" \ + "MIIEFTCCAv2gAwIBAgIVAPZfc4GMLZPmXbnoaZm6jRDqDs4+MA0GCSqGSIb3DQEB\\n" \ + "CwUAMIGoMQswCQYDVQQGEwJVUzEYMBYGA1UECgwPQW1hem9uLmNvbSBJbmMuMRww\\n" \ + "GgYDVQQLDBNBbWF6b24gV2ViIFNlcnZpY2VzMRMwEQYDVQQIDApXYXNoaW5ndG9u\\n" \ + "MRAwDgYDVQQHDAdTZWF0dGxlMTowOAYDVQQDDDEwMDMyNjE2MTA2NDM6NjI3YmY2\\n" \ + "M2QtYWU2NC00ZjU4LWExOGMtODBhNDRmY2Y0MDg4MCAXDTE3MDUyNTE4NDI1OVoY\\n" \ + "DzIwOTcwNTI1MTg0MjU4WjCBqDELMAkGA1UEBhMCVVMxGDAWBgNVBAoMD0FtYXpv\\n" \ + "bi5jb20gSW5jLjEcMBoGA1UECwwTQW1hem9uIFdlYiBTZXJ2aWNlczETMBEGA1UE\\n" \ + "CAwKV2FzaGluZ3RvbjEQMA4GA1UEBwwHU2VhdHRsZTE6MDgGA1UEAwwxMDAzMjYx\\n" \ + "NjEwNjQzOjYyN2JmNjNkLWFlNjQtNGY1OC1hMThjLTgwYTQ0ZmNmNDA4ODCCASIw\\n" \ + "DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKEWtZtKyJUg2VUwZkbkVtltrfam\\n" \ + "s9LMIdKNA3Wz4zSLhZjKHiTSkQmpZwKle5ziYs6Q5hfeT8WC0FNAVv1JhnwsuGfT\\n" \ + "sG0UO5dSn7wqXOJigKC1CaSGqeFpKB0/a3wR1L6pCGVbLZ86/sPCEPHHJDieQ+Ps\\n" \ + "RnOcUGb4CuIBnI2N+lafWNa4F4KRSVJCEeZ6u4iWVVdIEcDLKlakY45jtVvQqwnz\\n" \ + "3leFsN7PTLEkVq5u1PXSbT5DWv6p+5NoDnGAT7j7Wbr2yJw7DtpBOL6oWkAdbFAQ\\n" \ + 
"2097e8mIxNYE9xAzRlb5wEr6jZl/8K60v9P83OapMeuOg4JS8FGulHXbDg0CAwEA\\n" \ + "AaMyMDAwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU21ELaPCH9Oh001OS0JMv\\n" \ + "n8hU8dYwDQYJKoZIhvcNAQELBQADggEBABW66eH/+/v9Nq5jtJzflfrqAfBOzWLj\\n" \ + "UTEv6szkYzV5Crr8vnu2P5OlyA0NdiKGiAm0AgoDkf+n9HU3Hc0zm3G/QaAO2UmN\\n" \ + "9MwtIp29BSRf+gd1bX/WZTtl5I5xl290BDfr5o08I6TOf0A4P8IAkGwku5j0IQjM\\n" \ + "ns2HH5UVki155dtmWDEGX6q35KABbsmv3tO1+geJVYnd1QkHzR5IXA12gxlMw9GJ\\n" \ + "+cOw+rwJJ2ZcXo3HFoXBcsPqPOa1SO3vTl3XWQ+jX3vyDsxh/VGoJ4epsjwmJ+dW\\n" \ + "sHJoqsa3ZPDW0LcEuYgdzYWRhumGwH9fJJUx0GS4Tdg4ud+6jpuyflU=\\n" \ + "-----END CERTIFICATE-----\\n\"]}]}" + + +class TestDiscoveryInfoProvider: + + def setup_class(cls): + cls.service_error_message_dict = { + "400" : SERVICE_ERROR_MESSAGE_400, + "401" : SERVICE_ERROR_MESSAGE_401, + "404" : SERVICE_ERROR_MESSAGE_404, + "429" : SERVICE_ERROR_MESSAGE_429 + } + cls.client_exception_dict = { + "400" : DiscoveryInvalidRequestException, + "401" : DiscoveryUnauthorizedException, + "404" : DiscoveryDataNotFoundException, + "429" : DiscoveryThrottlingException + } + + def setup_method(self, test_method): + self.mock_sock = MagicMock() + self.mock_ssl_sock = MagicMock() + + def test_200_drs_response_should_succeed(self): + with patch.object(DiscoveryInfoProvider, "_create_tcp_connection") as mock_method_create_tcp_connection, \ + patch.object(DiscoveryInfoProvider, "_create_ssl_connection") as mock_method_create_ssl_connection: + mock_method_create_tcp_connection.return_value = self.mock_sock + mock_method_create_ssl_connection.return_value = self.mock_ssl_sock + raw_outbound_request = FORMAT_REQUEST % DUMMY_GGAD_THING_NAME + self._create_test_target() + self.mock_ssl_sock.write.return_value = len(raw_outbound_request) + self.mock_ssl_sock.read.side_effect = \ + list((FORMAT_RESPONSE_HEADER % ("200", "OK", len(PAYLOAD_200)) + PAYLOAD_200).encode("utf-8")) + + discovery_info = self.discovery_info_provider.discover(DUMMY_GGAD_THING_NAME) + 
self.mock_ssl_sock.write.assert_called_with(raw_outbound_request.encode("utf-8")) + assert discovery_info.rawJson == PAYLOAD_200 + + def test_400_drs_response_should_raise(self): + self._internal_test_non_200_drs_response_should_raise("400", "Bad request") + + def test_401_drs_response_should_raise(self): + self._internal_test_non_200_drs_response_should_raise("401", "Unauthorized") + + def test_404_drs_response_should_raise(self): + self._internal_test_non_200_drs_response_should_raise("404", "Not found") + + def test_429_drs_response_should_raise(self): + self._internal_test_non_200_drs_response_should_raise("429", "Throttled") + + def test_unexpected_drs_response_should_raise(self): + self._internal_test_non_200_drs_response_should_raise("500", "Internal server error") + self._internal_test_non_200_drs_response_should_raise("1234", "Gibberish") + + def _internal_test_non_200_drs_response_should_raise(self, http_status_code, http_status_message): + with patch.object(DiscoveryInfoProvider, "_create_tcp_connection") as mock_method_create_tcp_connection, \ + patch.object(DiscoveryInfoProvider, "_create_ssl_connection") as mock_method_create_ssl_connection: + mock_method_create_tcp_connection.return_value = self.mock_sock + mock_method_create_ssl_connection.return_value = self.mock_ssl_sock + self._create_test_target() + service_error_message = self.service_error_message_dict.get(http_status_code) + if service_error_message is None: + service_error_message = SERVICE_ERROR_MESSAGE_500 + client_exception_type = self.client_exception_dict.get(http_status_code) + if client_exception_type is None: + client_exception_type = DiscoveryFailure + self.mock_ssl_sock.write.return_value = len(FORMAT_REQUEST % DUMMY_GGAD_THING_NAME) + self.mock_ssl_sock.read.side_effect = \ + list((FORMAT_RESPONSE_HEADER % (http_status_code, http_status_message, len(service_error_message)) + + service_error_message).encode("utf-8")) + + with pytest.raises(client_exception_type): + 
self.discovery_info_provider.discover(DUMMY_GGAD_THING_NAME) + + def test_request_time_out_should_raise(self): + with patch.object(DiscoveryInfoProvider, "_create_tcp_connection") as mock_method_create_tcp_connection, \ + patch.object(DiscoveryInfoProvider, "_create_ssl_connection") as mock_method_create_ssl_connection: + mock_method_create_tcp_connection.return_value = self.mock_sock + mock_method_create_ssl_connection.return_value = self.mock_ssl_sock + self._create_test_target() + + # We do not configure any return value and simply let request part time out + with pytest.raises(DiscoveryTimeoutException): + self.discovery_info_provider.discover(DUMMY_GGAD_THING_NAME) + + def test_response_time_out_should_raise(self): + with patch.object(DiscoveryInfoProvider, "_create_tcp_connection") as mock_method_create_tcp_connection, \ + patch.object(DiscoveryInfoProvider, "_create_ssl_connection") as mock_method_create_ssl_connection: + mock_method_create_tcp_connection.return_value = self.mock_sock + mock_method_create_ssl_connection.return_value = self.mock_ssl_sock + self._create_test_target() + + # We configure the request to succeed and let the response part time out + self.mock_ssl_sock.write.return_value = len(FORMAT_REQUEST % DUMMY_GGAD_THING_NAME) + with pytest.raises(DiscoveryTimeoutException): + self.discovery_info_provider.discover(DUMMY_GGAD_THING_NAME) + + def _create_test_target(self): + self.discovery_info_provider = DiscoveryInfoProvider(caPath=DUMMY_CA_PATH, + certPath=DUMMY_CERT_PATH, + keyPath=DUMMY_KEY_PATH, + host=DUMMY_HOST, + timeoutSec=DUMMY_TIME_OUT_SEC) diff --git a/test/core/jobs/test_jobs_client.py b/test/core/jobs/test_jobs_client.py new file mode 100644 index 0000000..c36fc55 --- /dev/null +++ b/test/core/jobs/test_jobs_client.py @@ -0,0 +1,169 @@ +# Test AWSIoTMQTTThingJobsClient behavior + +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTThingJobsClient +from 
AWSIoTPythonSDK.core.jobs.thingJobManager import thingJobManager +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicReplyType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionStatus +import AWSIoTPythonSDK.MQTTLib +import time +import json +from mock import MagicMock + +#asserts based on this documentation: https://docs.aws.amazon.com/iot/latest/developerguide/jobs-api.html +class TestAWSIoTMQTTThingJobsClient: + thingName = 'testThing' + clientTokenValue = 'testClientToken123' + statusDetailsMap = {'testKey':'testVal'} + + def setup_method(self, method): + self.mockAWSIoTMQTTClient = MagicMock(spec=AWSIoTMQTTClient) + self.jobsClient = AWSIoTMQTTThingJobsClient(self.clientTokenValue, self.thingName, QoS=0, awsIoTMQTTClient=self.mockAWSIoTMQTTClient) + self.jobsClient._thingJobManager = MagicMock(spec=thingJobManager) + + def test_unsuccessful_create_subscription(self): + fake_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'UnsuccessfulCreateSubTopic' + self.mockAWSIoTMQTTClient.subscribe.return_value = False + assert False == self.jobsClient.createJobSubscription(fake_callback) + + def test_successful_job_request_create_subscription(self): + fake_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SuccessfulCreateSubRequestTopic' + self.mockAWSIoTMQTTClient.subscribe.return_value = True + assert True == self.jobsClient.createJobSubscription(fake_callback) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_WILDCARD_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, None) + self.mockAWSIoTMQTTClient.subscribe.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_callback) + + def test_successful_job_start_next_create_subscription(self): + fake_callback = MagicMock(); + 
self.jobsClient._thingJobManager.getJobTopic.return_value = 'SuccessfulCreateSubStartNextTopic' + self.mockAWSIoTMQTTClient.subscribe.return_value = True + assert True == self.jobsClient.createJobSubscription(fake_callback, jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, None) + self.mockAWSIoTMQTTClient.subscribe.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_callback) + + def test_successful_job_update_create_subscription(self): + fake_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SuccessfulCreateSubUpdateTopic' + self.mockAWSIoTMQTTClient.subscribe.return_value = True + assert True == self.jobsClient.createJobSubscription(fake_callback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, 'jobUpdateId') + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, 'jobUpdateId') + self.mockAWSIoTMQTTClient.subscribe.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_callback) + + def test_successful_job_update_notify_next_create_subscription(self): + fake_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SuccessfulCreateSubNotifyNextTopic' + self.mockAWSIoTMQTTClient.subscribe.return_value = True + assert True == self.jobsClient.createJobSubscription(fake_callback, jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, None) + 
self.mockAWSIoTMQTTClient.subscribe.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_callback) + + def test_successful_job_request_create_subscription_async(self): + fake_callback = MagicMock(); + fake_ack_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'CreateSubTopic1' + self.mockAWSIoTMQTTClient.subscribeAsync.return_value = 'MessageId1' + assert self.mockAWSIoTMQTTClient.subscribeAsync.return_value == self.jobsClient.createJobSubscriptionAsync(fake_ack_callback, fake_callback) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_WILDCARD_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, None) + self.mockAWSIoTMQTTClient.subscribeAsync.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_ack_callback, fake_callback) + + def test_successful_job_start_next_create_subscription_async(self): + fake_callback = MagicMock(); + fake_ack_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'CreateSubTopic3' + self.mockAWSIoTMQTTClient.subscribeAsync.return_value = 'MessageId3' + assert self.mockAWSIoTMQTTClient.subscribeAsync.return_value == self.jobsClient.createJobSubscriptionAsync(fake_ack_callback, fake_callback, jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, None) + self.mockAWSIoTMQTTClient.subscribeAsync.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_ack_callback, fake_callback) + + def test_successful_job_update_create_subscription_async(self): + fake_callback = MagicMock(); + fake_ack_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'CreateSubTopic4' + 
self.mockAWSIoTMQTTClient.subscribeAsync.return_value = 'MessageId4' + assert self.mockAWSIoTMQTTClient.subscribeAsync.return_value == self.jobsClient.createJobSubscriptionAsync(fake_ack_callback, fake_callback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, 'jobUpdateId3') + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, 'jobUpdateId3') + self.mockAWSIoTMQTTClient.subscribeAsync.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_ack_callback, fake_callback) + + def test_successful_job_notify_next_subscription_async(self): + fake_callback = MagicMock(); + fake_ack_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'CreateSubTopic5' + self.mockAWSIoTMQTTClient.subscribeAsync.return_value = 'MessageId5' + assert self.mockAWSIoTMQTTClient.subscribeAsync.return_value == self.jobsClient.createJobSubscriptionAsync(fake_ack_callback, fake_callback, jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, None) + self.mockAWSIoTMQTTClient.subscribeAsync.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_ack_callback, fake_callback) + + def test_send_jobs_query_get_pending(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendJobsQuery1' + self.jobsClient._thingJobManager.serializeClientTokenPayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = True + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsQuery(jobExecutionTopicType.JOB_GET_PENDING_TOPIC) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_GET_PENDING_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, 
None) + self.jobsClient._thingJobManager.serializeClientTokenPayload.assert_called_with() + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeClientTokenPayload.return_value, 0) + + def test_send_jobs_query_describe(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendJobsQuery2' + self.jobsClient._thingJobManager.serializeClientTokenPayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = True + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsQuery(jobExecutionTopicType.JOB_DESCRIBE_TOPIC, 'jobId2') + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, 'jobId2') + self.jobsClient._thingJobManager.serializeClientTokenPayload.assert_called_with() + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeClientTokenPayload.return_value, 0) + + def test_send_jobs_start_next(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendStartNext1' + self.jobsClient._thingJobManager.serializeStartNextPendingJobExecutionPayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = True + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsStartNext(self.statusDetailsMap, 12) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + self.jobsClient._thingJobManager.serializeStartNextPendingJobExecutionPayload.assert_called_with(self.statusDetailsMap, 12) + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 
self.jobsClient._thingJobManager.serializeStartNextPendingJobExecutionPayload.return_value, 0) + + def test_send_jobs_start_next_no_status_details(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendStartNext2' + self.jobsClient._thingJobManager.serializeStartNextPendingJobExecutionPayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = False + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsStartNext({}) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + self.jobsClient._thingJobManager.serializeStartNextPendingJobExecutionPayload.assert_called_with({}, None) + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeStartNextPendingJobExecutionPayload.return_value, 0) + + def test_send_jobs_update_succeeded(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendJobsUpdate1' + self.jobsClient._thingJobManager.serializeJobExecutionUpdatePayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = True + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsUpdate('jobId1', jobExecutionStatus.JOB_EXECUTION_SUCCEEDED, self.statusDetailsMap, 1, 2, True, False, 12) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, 'jobId1') + self.jobsClient._thingJobManager.serializeJobExecutionUpdatePayload.assert_called_with(jobExecutionStatus.JOB_EXECUTION_SUCCEEDED, self.statusDetailsMap, 1, 2, True, False, 12) + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeJobExecutionUpdatePayload.return_value, 0) + + def 
test_send_jobs_update_failed(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendJobsUpdate2' + self.jobsClient._thingJobManager.serializeJobExecutionUpdatePayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = False + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsUpdate('jobId2', jobExecutionStatus.JOB_EXECUTION_FAILED, {}, 3, 4, False, True, 34) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, 'jobId2') + self.jobsClient._thingJobManager.serializeJobExecutionUpdatePayload.assert_called_with(jobExecutionStatus.JOB_EXECUTION_FAILED, {}, 3, 4, False, True, 34) + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeJobExecutionUpdatePayload.return_value, 0) + + def test_send_jobs_describe(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendJobsDescribe1' + self.jobsClient._thingJobManager.serializeDescribeJobExecutionPayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = True + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsDescribe('jobId1', 2, True) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, 'jobId1') + self.jobsClient._thingJobManager.serializeDescribeJobExecutionPayload.assert_called_with(2, True) + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeDescribeJobExecutionPayload.return_value, 0) + + def test_send_jobs_describe_false_return_val(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendJobsDescribe2' + 
self.jobsClient._thingJobManager.serializeDescribeJobExecutionPayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = False + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsDescribe('jobId2', 1, False) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, 'jobId2') + self.jobsClient._thingJobManager.serializeDescribeJobExecutionPayload.assert_called_with(1, False) + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeDescribeJobExecutionPayload.return_value, 0) diff --git a/test/core/jobs/test_thing_job_manager.py b/test/core/jobs/test_thing_job_manager.py new file mode 100644 index 0000000..c3fa7b1 --- /dev/null +++ b/test/core/jobs/test_thing_job_manager.py @@ -0,0 +1,191 @@ +# Test thingJobManager behavior + +from AWSIoTPythonSDK.core.jobs.thingJobManager import thingJobManager as JobManager +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicReplyType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionStatus +import time +import json +from mock import MagicMock + +#asserts based on this documentation: https://docs.aws.amazon.com/iot/latest/developerguide/jobs-api.html +class TestThingJobManager: + thingName = 'testThing' + clientTokenValue = "testClientToken123" + thingJobManager = JobManager(thingName, clientTokenValue) + noClientTokenJobManager = JobManager(thingName) + jobId = '8192' + statusDetailsMap = {'testKey':'testVal'} + + def test_pending_topics(self): + topicType = jobExecutionTopicType.JOB_GET_PENDING_TOPIC + assert ('$aws/things/' + self.thingName + '/jobs/get') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert ('$aws/things/' + 
self.thingName + '/jobs/get/accepted') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert ('$aws/things/' + self.thingName + '/jobs/get/rejected') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert ('$aws/things/' + self.thingName + '/jobs/get/#') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + + def test_start_next_topics(self): + topicType = jobExecutionTopicType.JOB_START_NEXT_TOPIC + assert ('$aws/things/' + self.thingName + '/jobs/start-next') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert ('$aws/things/' + self.thingName + '/jobs/start-next/accepted') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert ('$aws/things/' + self.thingName + '/jobs/start-next/rejected') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert ('$aws/things/' + self.thingName + '/jobs/start-next/#') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert None == 
self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + + def test_describe_topics(self): + topicType = jobExecutionTopicType.JOB_DESCRIBE_TOPIC + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/get') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, self.jobId) + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/get/accepted') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/get/rejected') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/get/#') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + + def test_update_topics(self): + topicType = jobExecutionTopicType.JOB_UPDATE_TOPIC + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/update') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, self.jobId) + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/update/accepted') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/update/rejected') == 
self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/update/#') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + + def test_notify_topics(self): + topicType = jobExecutionTopicType.JOB_NOTIFY_TOPIC + assert ('$aws/things/' + self.thingName + '/jobs/notify') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + + def test_notify_next_topics(self): + topicType = jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC + assert ('$aws/things/' + self.thingName + '/jobs/notify-next') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert None == 
self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + + def test_wildcard_topics(self): + topicType = jobExecutionTopicType.JOB_WILDCARD_TOPIC + topicString = '$aws/things/' + self.thingName + '/jobs/#' + assert topicString == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert topicString == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert topicString == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert topicString == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + + def test_thingless_topics(self): + thinglessJobManager = JobManager(None) + assert None == thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_GET_PENDING_TOPIC) + assert None == thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_START_NEXT_TOPIC) + assert None == thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_DESCRIBE_TOPIC) + assert None == thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_UPDATE_TOPIC) + assert None == thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_NOTIFY_TOPIC) + assert None == 
thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) + assert None == thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_WILDCARD_TOPIC) + + def test_unrecognized_topics(self): + topicType = jobExecutionTopicType.JOB_UNRECOGNIZED_TOPIC + assert None == self.thingJobManager.getJobTopic(topicType) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + + def test_serialize_client_token(self): + payload = '{"clientToken": "' + self.clientTokenValue + '"}' + assert payload == self.thingJobManager.serializeClientTokenPayload() + assert "{}" == self.noClientTokenJobManager.serializeClientTokenPayload() + + def test_serialize_start_next_pending_job_execution(self): + payload = {'clientToken': self.clientTokenValue} + assert payload == json.loads(self.thingJobManager.serializeStartNextPendingJobExecutionPayload()) + assert {} == json.loads(self.noClientTokenJobManager.serializeStartNextPendingJobExecutionPayload()) + payload.update({'statusDetails': self.statusDetailsMap}) + assert payload == json.loads(self.thingJobManager.serializeStartNextPendingJobExecutionPayload(self.statusDetailsMap)) + assert {'statusDetails': self.statusDetailsMap} == 
json.loads(self.noClientTokenJobManager.serializeStartNextPendingJobExecutionPayload(self.statusDetailsMap)) + + def test_serialize_describe_job_execution(self): + payload = {'includeJobDocument': True} + assert payload == json.loads(self.noClientTokenJobManager.serializeDescribeJobExecutionPayload()) + payload.update({'executionNumber': 1}) + assert payload == json.loads(self.noClientTokenJobManager.serializeDescribeJobExecutionPayload(1)) + payload.update({'includeJobDocument': False}) + assert payload == json.loads(self.noClientTokenJobManager.serializeDescribeJobExecutionPayload(1, False)) + + payload = {'includeJobDocument': True, 'clientToken': self.clientTokenValue} + assert payload == json.loads(self.thingJobManager.serializeDescribeJobExecutionPayload()) + payload.update({'executionNumber': 1}) + assert payload == json.loads(self.thingJobManager.serializeDescribeJobExecutionPayload(1)) + payload.update({'includeJobDocument': False}) + assert payload == json.loads(self.thingJobManager.serializeDescribeJobExecutionPayload(1, False)) + + def test_serialize_job_execution_update(self): + assert None == self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_STATUS_NOT_SET) + assert None == self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_UNKNOWN_STATUS) + assert None == self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_STATUS_NOT_SET) + assert None == self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_UNKNOWN_STATUS) + + payload = {'status':'IN_PROGRESS'} + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_IN_PROGRESS)) + payload.update({'status':'FAILED'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_FAILED)) + 
payload.update({'status':'SUCCEEDED'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_SUCCEEDED)) + payload.update({'status':'CANCELED'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_CANCELED)) + payload.update({'status':'REJECTED'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_REJECTED)) + payload.update({'status':'QUEUED'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED)) + payload.update({'statusDetails': self.statusDetailsMap}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap)) + payload.update({'expectedVersion': '1'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1)) + payload.update({'executionNumber': '1'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1, 1)) + payload.update({'includeJobExecutionState': True}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1, 1, True)) + payload.update({'includeJobDocument': True}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1, 1, True, True)) + + payload = {'status':'IN_PROGRESS', 'clientToken': self.clientTokenValue} + assert payload == 
json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_IN_PROGRESS)) + payload.update({'status':'FAILED'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_FAILED)) + payload.update({'status':'SUCCEEDED'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_SUCCEEDED)) + payload.update({'status':'CANCELED'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_CANCELED)) + payload.update({'status':'REJECTED'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_REJECTED)) + payload.update({'status':'QUEUED'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED)) + payload.update({'statusDetails': self.statusDetailsMap}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap)) + payload.update({'expectedVersion': '1'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1)) + payload.update({'executionNumber': '1'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1, 1)) + payload.update({'includeJobExecutionState': True}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1, 1, True)) + payload.update({'includeJobDocument': True}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1, 1, True, 
True)) diff --git a/test/core/protocol/__init__.py b/test/core/protocol/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/protocol/connection/__init__.py b/test/core/protocol/connection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/protocol/connection/test_alpn.py b/test/core/protocol/connection/test_alpn.py new file mode 100644 index 0000000..e9d2a2b --- /dev/null +++ b/test/core/protocol/connection/test_alpn.py @@ -0,0 +1,123 @@ +import AWSIoTPythonSDK.core.protocol.connection.alpn as alpn +from AWSIoTPythonSDK.core.protocol.connection.alpn import SSLContextBuilder +import sys +import pytest +try: + from mock import patch + from mock import MagicMock +except: + from unittest.mock import patch + from unittest.mock import MagicMock +if sys.version_info >= (3, 4): + from importlib import reload + + +python3_5_above_only = pytest.mark.skipif(sys.version_info >= (3, 0) and sys.version_info < (3, 5), reason="Requires Python 3.5+") +python2_7_10_above_only = pytest.mark.skipif(sys.version_info < (2, 7, 10), reason="Requires Python 2.7.10+") + +PATCH_MODULE_LOCATION = "AWSIoTPythonSDK.core.protocol.connection.alpn" +SSL_MODULE_NAME = "ssl" +SSL_CONTEXT_METHOD_NAME = "create_default_context" + +DUMMY_SSL_PROTOCOL = "DummySSLProtocol" +DUMMY_CERT_REQ = "DummyCertReq" +DUMMY_CIPHERS = "DummyCiphers" +DUMMY_CA_FILE_PATH = "fake/path/to/ca" +DUMMY_CERT_FILE_PATH = "fake/path/to/cert" +DUMMY_KEY_FILE_PATH = "fake/path/to/key" +DUMMY_ALPN_PROTOCOLS = "x-amzn-mqtt-ca" + + +@python2_7_10_above_only +@python3_5_above_only +class TestALPNSSLContextBuilder: + + def test_check_supportability_no_ssl(self): + self._preserve_ssl() + try: + self._none_ssl() + with pytest.raises(RuntimeError): + alpn.SSLContextBuilder().build() + finally: + self._unnone_ssl() + + def _none_ssl(self): + # We always run the unit test with Python versions that have proper ssl support + # We need to mock it out in this test + 
sys.modules[SSL_MODULE_NAME] = None + reload(alpn) + + def _unnone_ssl(self): + sys.modules[SSL_MODULE_NAME] = self._normal_ssl_module + reload(alpn) + + def test_check_supportability_no_ssl_context(self): + self._preserve_ssl() + try: + self._mock_ssl() + del self.ssl_mock.SSLContext + with pytest.raises(NotImplementedError): + SSLContextBuilder() + finally: + self._unmock_ssl() + + def test_check_supportability_no_alpn(self): + self._preserve_ssl() + try: + self._mock_ssl() + del self.ssl_mock.SSLContext.set_alpn_protocols + with pytest.raises(NotImplementedError): + SSLContextBuilder() + finally: + self._unmock_ssl() + + def _preserve_ssl(self): + self._normal_ssl_module = sys.modules[SSL_MODULE_NAME] + + def _mock_ssl(self): + self.ssl_mock = MagicMock() + alpn.ssl = self.ssl_mock + + def _unmock_ssl(self): + alpn.ssl = self._normal_ssl_module + + def test_with_ca_certs(self): + self._use_mock_ssl_context() + SSLContextBuilder().with_ca_certs(DUMMY_CA_FILE_PATH).build() + self.mock_ssl_context.load_verify_locations.assert_called_once_with(DUMMY_CA_FILE_PATH) + + def test_with_cert_key_pair(self): + self._use_mock_ssl_context() + SSLContextBuilder().with_cert_key_pair(DUMMY_CERT_FILE_PATH, DUMMY_KEY_FILE_PATH).build() + self.mock_ssl_context.load_cert_chain.assert_called_once_with(DUMMY_CERT_FILE_PATH, DUMMY_KEY_FILE_PATH) + + def test_with_cert_reqs(self): + self._use_mock_ssl_context() + SSLContextBuilder().with_cert_reqs(DUMMY_CERT_REQ).build() + assert self.mock_ssl_context.verify_mode == DUMMY_CERT_REQ + + def test_with_check_hostname(self): + self._use_mock_ssl_context() + SSLContextBuilder().with_check_hostname(True).build() + assert self.mock_ssl_context.check_hostname == True + + def test_with_ciphers(self): + self._use_mock_ssl_context() + SSLContextBuilder().with_ciphers(DUMMY_CIPHERS).build() + self.mock_ssl_context.set_ciphers.assert_called_once_with(DUMMY_CIPHERS) + + def test_with_none_ciphers(self): + self._use_mock_ssl_context() + 
SSLContextBuilder().with_ciphers(None).build()
+        assert not self.mock_ssl_context.set_ciphers.called
+
+    def test_with_alpn_protocols(self):
+        self._use_mock_ssl_context()
+        SSLContextBuilder().with_alpn_protocols(DUMMY_ALPN_PROTOCOLS)
+        self.mock_ssl_context.set_alpn_protocols.assert_called_once_with(DUMMY_ALPN_PROTOCOLS)
+
+    def _use_mock_ssl_context(self):
+        self.mock_ssl_context = MagicMock()
+        self.ssl_create_default_context_patcher = patch("%s.%s.%s" % (PATCH_MODULE_LOCATION, SSL_MODULE_NAME, SSL_CONTEXT_METHOD_NAME))
+        self.mock_ssl_create_default_context = self.ssl_create_default_context_patcher.start()
+        self.mock_ssl_create_default_context.return_value = self.mock_ssl_context
diff --git a/test/core/protocol/connection/test_progressive_back_off_core.py b/test/core/protocol/connection/test_progressive_back_off_core.py
new file mode 100755
index 0000000..1cfcd40
--- /dev/null
+++ b/test/core/protocol/connection/test_progressive_back_off_core.py
@@ -0,0 +1,74 @@
+import time
+import AWSIoTPythonSDK.core.protocol.connection.cores as backoff
+import pytest
+
+
+class TestProgressiveBackOffCore():
+    def setup_method(self, method):
+        self._dummyBackOffCore = backoff.ProgressiveBackOffCore()
+
+    def teardown_method(self, method):
+        self._dummyBackOffCore = None
+
+    # Check that the current backoff time is one second when this is the first time to backoff
+    def test_BackoffForTheFirstTime(self):
+        assert self._dummyBackOffCore._currentBackoffTimeSecond == 1
+
+    # Check that valid input values for the backoff configuration are properly configured
+    def test_CustomConfig_ValidInput(self):
+        self._dummyBackOffCore.configTime(2, 128, 30)
+        assert self._dummyBackOffCore._baseReconnectTimeSecond == 2
+        assert self._dummyBackOffCore._maximumReconnectTimeSecond == 128
+        assert self._dummyBackOffCore._minimumConnectTimeSecond == 30
+
+    # Check that negative input values trigger an exception
+    def test_CustomConfig_NegativeInput(self):
+        with pytest.raises(ValueError) as e:
+            # _baseReconnectTimeSecond should be greater than zero, otherwise raise exception
+            self._dummyBackOffCore.configTime(-10, 128, 30)
+        with pytest.raises(ValueError) as e:
+            # _maximumReconnectTimeSecond should be greater than zero, otherwise raise exception
+            self._dummyBackOffCore.configTime(2, -11, 30)
+        with pytest.raises(ValueError) as e:
+            # _minimumConnectTimeSecond should be greater than zero, otherwise raise exception
+            self._dummyBackOffCore.configTime(2, 128, -12)
+
+    # Check that invalid input values trigger an exception
+    def test_CustomConfig_InvalidInput(self):
+        with pytest.raises(ValueError) as e:
+            # _baseReconnectTimeSecond is larger than _minimumConnectTimeSecond,
+            # which is not allowed...
+            self._dummyBackOffCore.configTime(200, 128, 30)
+
+    # Check that _currentBackoffTimeSecond increases to twice the original value after the 2nd backoff
+    def test_backOffUpdatesCurrentBackoffTime(self):
+        self._dummyBackOffCore.configTime(1, 32, 20)
+        self._dummyBackOffCore.backOff() # This is the first backoff, block for 0 seconds
+        assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond * 2
+        self._dummyBackOffCore.backOff() # Now progressive backoff calc starts
+        assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond * 2 * 2
+
+    # Check that backoff time is reset when connection is stable enough
+    def test_backOffResetWhenConnectionIsStable(self):
+        self._dummyBackOffCore.configTime(1, 32, 5)
+        self._dummyBackOffCore.backOff() # This is the first backoff, block for 0 seconds
+        assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond * 2
+        self._dummyBackOffCore.backOff() # Now progressive backoff calc starts
+        assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond * 2 * 2
+        # Now simulate a stable connection that exceeds _minimumConnectTimeSecond
+        
self._dummyBackOffCore.startStableConnectionTimer() # Called when CONNACK arrives + time.sleep(self._dummyBackOffCore._minimumConnectTimeSecond + 1) + # Timer expires, currentBackoffTimeSecond should be reset + assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond + + # Check that backoff resetting timer is properly cancelled when a disconnect happens immediately + def test_resetTimerProperlyCancelledOnUnstableConnection(self): + self._dummyBackOffCore.configTime(1, 32, 5) + self._dummyBackOffCore.backOff() # This is the first backoff, block for 0 seconds + assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond * 2 + # Now simulate an unstable connection that is within _minimumConnectTimeSecond + self._dummyBackOffCore.startStableConnectionTimer() # Called when CONNACK arrives + time.sleep(self._dummyBackOffCore._minimumConnectTimeSecond - 1) + # Now "disconnect" + self._dummyBackOffCore.backOff() + assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond * 2 * 2 diff --git a/test/core/protocol/connection/test_sigv4_core.py b/test/core/protocol/connection/test_sigv4_core.py new file mode 100644 index 0000000..576efb1 --- /dev/null +++ b/test/core/protocol/connection/test_sigv4_core.py @@ -0,0 +1,148 @@ +from AWSIoTPythonSDK.core.protocol.connection.cores import SigV4Core +from AWSIoTPythonSDK.exception.AWSIoTExceptions import wssNoKeyInEnvironmentError +import os +from datetime import datetime +try: + from mock import patch + from mock import MagicMock +except: + from unittest.mock import patch + from unittest.mock import MagicMock +import pytest +try: + from configparser import ConfigParser # Python 3+ + from configparser import NoOptionError +except ImportError: + from ConfigParser import ConfigParser + from ConfigParser import NoOptionError + + +CREDS_NOT_FOUND_MODE_NO_KEYS = "NoKeys" 
+CREDS_NOT_FOUND_MODE_EMPTY_VALUES = "EmptyValues" + +PATCH_MODULE_LOCATION = "AWSIoTPythonSDK.core.protocol.connection.cores." +DUMMY_ACCESS_KEY_ID = "TRUSTMETHISIDISFAKE0" +DUMMY_SECRET_ACCESS_KEY = "trustMeThisSecretKeyIsSoFakeAaBbCc00Dd11" +DUMMY_SESSION_TOKEN = "FQoDYXdzEGcaDNSwicOypVyhiHj4JSLUAXTsOXu1YGT/Oaltz" \ + "XujI+cwvEA3zPoUdebHOkaUmRBO3o34J/3r2/+hBqZZNSpyzK" \ + "sBge1MXPwbM2G5ojz3aY4Qj+zD3hEMu9nxk3rhKkmTQWLoB4Z" \ + "rPRG6GJGkoLMAL1sSEh9kqbHN6XIt3F2E+Wn2BhDoGA7ZsXSg" \ + "+pgIntkSZcLT7pCX8pTEaEtRBhJQVc5GTYhG9y9mgjpeVRsbE" \ + "j8yDJzSWDpLGgR7APSvCFX2H+DwsKM564Z4IzjpbntIlLXdQw" \ + "Oytd65dgTlWZkmmYpTwVh+KMq+0MoF" +DUMMY_UTC_NOW_STRFTIME_RESULT = "20170628T204845Z" + +EXPECTED_WSS_URL_WITH_TOKEN = "wss://data.iot.us-east-1.amazonaws.com:44" \ + "3/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X" \ + "-Amz-Credential=TRUSTMETHISIDISFAKE0%2F20" \ + "170628%2Fus-east-1%2Fiotdata%2Faws4_reque" \ + "st&X-Amz-Date=20170628T204845Z&X-Amz-Expi" \ + "res=86400&X-Amz-SignedHeaders=host&X-Amz-" \ + "Signature=b79a4d7e31ccbf96b22d93cce1b500b" \ + "9ee611ec966159547e140ae32e4dcebed&X-Amz-S" \ + "ecurity-Token=FQoDYXdzEGcaDNSwicOypVyhiHj" \ + "4JSLUAXTsOXu1YGT/OaltzXujI%2BcwvEA3zPoUde" \ + "bHOkaUmRBO3o34J/3r2/%2BhBqZZNSpyzKsBge1MX" \ + "PwbM2G5ojz3aY4Qj%2BzD3hEMu9nxk3rhKkmTQWLo" \ + "B4ZrPRG6GJGkoLMAL1sSEh9kqbHN6XIt3F2E%2BWn" \ + "2BhDoGA7ZsXSg%2BpgIntkSZcLT7pCX8pTEaEtRBh" \ + "JQVc5GTYhG9y9mgjpeVRsbEj8yDJzSWDpLGgR7APS" \ + "vCFX2H%2BDwsKM564Z4IzjpbntIlLXdQwOytd65dg" \ + "TlWZkmmYpTwVh%2BKMq%2B0MoF" +EXPECTED_WSS_URL_WITHOUT_TOKEN = "wss://data.iot.us-east-1.amazonaws.com" \ + ":443/mqtt?X-Amz-Algorithm=AWS4-HMAC-SH" \ + "A256&X-Amz-Credential=TRUSTMETHISIDISF" \ + "AKE0%2F20170628%2Fus-east-1%2Fiotdata%" \ + "2Faws4_request&X-Amz-Date=20170628T204" \ + "845Z&X-Amz-Expires=86400&X-Amz-SignedH" \ + "eaders=host&X-Amz-Signature=b79a4d7e31" \ + "ccbf96b22d93cce1b500b9ee611ec966159547" \ + "e140ae32e4dcebed" + + +class TestSigV4Core: + + def setup_method(self, test_method): + 
self._use_mock_datetime() + self.mock_utc_now_result.strftime.return_value = DUMMY_UTC_NOW_STRFTIME_RESULT + self.sigv4_core = SigV4Core() + + def _use_mock_datetime(self): + self.datetime_patcher = patch(PATCH_MODULE_LOCATION + "datetime", spec=datetime) + self.mock_datetime_constructor = self.datetime_patcher.start() + self.mock_utc_now_result = MagicMock(spec=datetime) + self.mock_datetime_constructor.utcnow.return_value = self.mock_utc_now_result + + def teardown_method(self, test_method): + self.datetime_patcher.stop() + + def test_generate_url_with_env_credentials(self): + self._use_mock_os_environ({ + "AWS_ACCESS_KEY_ID" : DUMMY_ACCESS_KEY_ID, + "AWS_SECRET_ACCESS_KEY" : DUMMY_SECRET_ACCESS_KEY + }) + assert self._invoke_create_wss_endpoint_api() == EXPECTED_WSS_URL_WITHOUT_TOKEN + self.python_os_environ_patcher.stop() + + def test_generate_url_with_env_credentials_token(self): + self._use_mock_os_environ({ + "AWS_ACCESS_KEY_ID" : DUMMY_ACCESS_KEY_ID, + "AWS_SECRET_ACCESS_KEY" : DUMMY_SECRET_ACCESS_KEY, + "AWS_SESSION_TOKEN" : DUMMY_SESSION_TOKEN + }) + assert self._invoke_create_wss_endpoint_api() == EXPECTED_WSS_URL_WITH_TOKEN + self.python_os_environ_patcher.stop() + + def _use_mock_os_environ(self, os_environ_map): + self.python_os_environ_patcher = patch.dict(os.environ, os_environ_map) + self.python_os_environ_patcher.start() + + def _use_mock_configparser(self): + self.configparser_patcher = patch(PATCH_MODULE_LOCATION + "ConfigParser", spec=ConfigParser) + self.mock_configparser_constructor = self.configparser_patcher.start() + self.mock_configparser = MagicMock(spec=ConfigParser) + self.mock_configparser_constructor.return_value = self.mock_configparser + + def test_generate_url_with_input_credentials(self): + self._configure_mocks_credentials_not_found_in_env_config() + self.sigv4_core.setIAMCredentials(DUMMY_ACCESS_KEY_ID, DUMMY_SECRET_ACCESS_KEY, "") + + assert self._invoke_create_wss_endpoint_api() == EXPECTED_WSS_URL_WITHOUT_TOKEN + + 
self._recover_mocks_for_env_config() + + def test_generate_url_with_input_credentials_token(self): + self._configure_mocks_credentials_not_found_in_env_config() + self.sigv4_core.setIAMCredentials(DUMMY_ACCESS_KEY_ID, DUMMY_SECRET_ACCESS_KEY, DUMMY_SESSION_TOKEN) + + assert self._invoke_create_wss_endpoint_api() == EXPECTED_WSS_URL_WITH_TOKEN + + self._recover_mocks_for_env_config() + + def _recover_mocks_for_env_config(self): + self.python_os_environ_patcher.stop() + self.configparser_patcher.stop() + + def test_generate_url_failure_when_credential_configured_with_none_values(self): + self._use_mock_os_environ({}) + self._use_mock_configparser() + self.mock_configparser.get.side_effect = NoOptionError("option", "section") + self.sigv4_core.setIAMCredentials(None, None, None) + + with pytest.raises(wssNoKeyInEnvironmentError): + self._invoke_create_wss_endpoint_api() + + def _configure_mocks_credentials_not_found_in_env_config(self, mode=CREDS_NOT_FOUND_MODE_NO_KEYS): + if mode == CREDS_NOT_FOUND_MODE_NO_KEYS: + self._use_mock_os_environ({}) + elif mode == CREDS_NOT_FOUND_MODE_EMPTY_VALUES: + self._use_mock_os_environ({ + "AWS_ACCESS_KEY_ID" : "", + "AWS_SECRET_ACCESS_KEY" : "" + }) + self._use_mock_configparser() + self.mock_configparser.get.side_effect = NoOptionError("option", "section") + + def _invoke_create_wss_endpoint_api(self): + return self.sigv4_core.createWebsocketEndpoint("data.iot.us-east-1.amazonaws.com", 443, "us-east-1", + "GET", "iotdata", "/mqtt") diff --git a/test/core/protocol/connection/test_wss_core.py b/test/core/protocol/connection/test_wss_core.py new file mode 100755 index 0000000..bbe7244 --- /dev/null +++ b/test/core/protocol/connection/test_wss_core.py @@ -0,0 +1,249 @@ +from test.sdk_mock.mockSecuredWebsocketCore import mockSecuredWebsocketCoreNoRealHandshake +from test.sdk_mock.mockSecuredWebsocketCore import MockSecuredWebSocketCoreNoSocketIO +from test.sdk_mock.mockSecuredWebsocketCore import 
MockSecuredWebSocketCoreWithRealHandshake +from test.sdk_mock.mockSSLSocket import mockSSLSocket +import struct +import socket +import pytest +try: + from configparser import ConfigParser # Python 3+ +except ImportError: + from ConfigParser import ConfigParser + + +class TestWssCore: + + # Websocket Constants + _OP_CONTINUATION = 0x0 + _OP_TEXT = 0x1 + _OP_BINARY = 0x2 + _OP_CONNECTION_CLOSE = 0x8 + _OP_PING = 0x9 + _OP_PONG = 0xa + + def _generateStringOfAs(self, length): + ret = "" + for i in range(0, length): + ret += 'a' + return ret + + def _printByteArray(self, src): + for i in range(0, len(src)): + print(hex(src[i])) + print("") + + def _encodeFrame(self, rawPayload, opCode, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1): + ret = bytearray() + # FIN+RSV1+RSV2+RSV3 + F = (FIN & 0x01) << 3 + R1 = (RSV1 & 0x01) << 2 + R2 = (RSV2 & 0x01) << 1 + R3 = (RSV3 & 0x01) + FRRR = (F | R1 | R2 | R3) << 4 + # Op byte + opByte = FRRR | opCode + ret.append(opByte) + # Payload Length bytes + maskBit = masked + payloadLength = len(rawPayload) + if payloadLength <= 125: + ret.append((maskBit << 7) | payloadLength) + elif payloadLength <= 0xffff: # 16-bit unsigned int + ret.append((maskBit << 7) | 126) + ret.extend(struct.pack("!H", payloadLength)) + elif payloadLength <= 0x7fffffffffffffff: # 64-bit unsigned int (most significant bit must be 0) + ret.append((maskBit << 7) | 127) + ret.extend(struct.pack("!Q", payloadLength)) + else: # Overflow + raise ValueError("Exceeds the maximum number of bytes for a single websocket frame.") + if maskBit == 1: + # Mask key bytes + maskKey = bytearray(b"1234") + ret.extend(maskKey) + # Mask the payload + payloadBytes = bytearray(rawPayload) + if maskBit == 1: + for i in range(0, payloadLength): + payloadBytes[i] ^= maskKey[i % 4] + ret.extend(payloadBytes) + # Return the assembled wss frame + return ret + + def setup_method(self, method): + self._dummySSLSocket = mockSSLSocket() + + # Wss Handshake + def test_WssHandshakeTimeout(self): + 
self._dummySSLSocket.refreshReadBuffer(bytearray()) # Empty bytes to read from socket
+        with pytest.raises(socket.error):
+            self._dummySecuredWebsocket = \
+                MockSecuredWebSocketCoreNoSocketIO(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234)
+
+    # Constructor
+    def test_InvalidEndpointPattern(self):
+        with pytest.raises(ValueError):
+            self._dummySecuredWebsocket = MockSecuredWebSocketCoreWithRealHandshake(None, "ThisIsNotAValidIoTEndpoint!", 1234)
+
+    def test_BJSEndpointPattern(self):
+        bjsStyleEndpoint = "blablabla.iot.cn-north-1.amazonaws.com.cn"
+        unexpectedExceptionMessage = "Invalid endpoint pattern for wss: %s" % bjsStyleEndpoint
+        # Garbage wss handshake response to ensure the test code gets past endpoint pattern validation
+        self._dummySSLSocket.refreshReadBuffer(b"GarbageWssHanshakeResponse")
+        try:
+            self._dummySecuredWebsocket = MockSecuredWebSocketCoreWithRealHandshake(self._dummySSLSocket, bjsStyleEndpoint, 1234)
+        except ValueError as e:
+            if str(e) == unexpectedExceptionMessage:
+                raise AssertionError("Encountered unexpected exception when initializing wss core with BJS style endpoint", e)
+
+    # Wss I/O
+    def test_WssReadComplete(self):
+        # Config mockSSLSocket to contain a Wss frame
+        rawPayload = b"If you can see me, this is good."
+        # This frame is deliberately unmasked (masked=0), as server-to-client frames are
+        # securedWebsocketCore should be able to decode it and get the raw payload back
+        coolFrame = self._encodeFrame(rawPayload, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=0)
+        # self._printByteArray(coolFrame)
+        self._dummySSLSocket.refreshReadBuffer(coolFrame)
+        # Init securedWebsocket with this mockSSLSocket
+        self._dummySecuredWebsocket = \
+            mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234)
+        # Read it back:
+        readItBack = self._dummySecuredWebsocket.read(len(rawPayload)) # Basically read everything
+        assert rawPayload == readItBack
+
+    def test_WssReadFragmented(self):
+        rawPayloadFragmented = b"I am designed to be fragmented..."
+        # This frame is deliberately unmasked (masked=0), as server-to-client frames are
+        # securedWebsocketCore should be able to decode it and get the raw payload back
+        stop1 = 4
+        stop2 = 9
+        coolFrame = self._encodeFrame(rawPayloadFragmented, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=0)
+        # self._printByteArray(coolFrame)
+        coolFramePart1 = coolFrame[0:stop1]
+        coolFramePart2 = coolFrame[stop1:stop2]
+        coolFramePart3 = coolFrame[stop2:len(coolFrame)]
+        # Config mockSSLSocket to contain a fragmented Wss frame
+        self._dummySSLSocket.setReadFragmented()
+        self._dummySSLSocket.addReadBufferFragment(coolFramePart1)
+        self._dummySSLSocket.addReadBufferFragment(coolFramePart2)
+        self._dummySSLSocket.addReadBufferFragment(coolFramePart3)
+        self._dummySSLSocket.loadFirstFragmented()
+        # In this way, reading from SSLSocket will result in 3 sslError, simulating the situation where data is not ready
+        # Init securedWebsocket with this mockSSLSocket
+        self._dummySecuredWebsocket = \
+            mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234)
+        # Read it back:
+        readItBack = bytearray()
+        while len(readItBack) != len(rawPayloadFragmented):
+            try:
+                # Will be
interrupted due to faked socket I/O Error + # Should be able to read back the complete + readItBack += self._dummySecuredWebsocket.read(len(rawPayloadFragmented)) # Basically read everything + except: + pass + assert rawPayloadFragmented == readItBack + + def test_WssReadlongFrame(self): + # Config mockSSLSocket to contain a Wss frame + rawPayloadLong = bytearray(self._generateStringOfAs(300), 'utf-8') # 300 bytes of raw payload, will use extended payload length bytes in encoding + # The payload of this frame will be masked by a randomly-generated mask key + # securedWebsocketCore should be able to decode it and get the raw payload back + coolFrame = self._encodeFrame(rawPayloadLong, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=0) + # self._printByteArray(coolFrame) + self._dummySSLSocket.refreshReadBuffer(coolFrame) + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Read it back: + readItBack = self._dummySecuredWebsocket.read(len(rawPayloadLong)) # Basically read everything + assert rawPayloadLong == readItBack + + def test_WssReadReallylongFrame(self): + # Config mockSSLSocket to contain a Wss frame + # Maximum allowed length of a wss payload is greater than maximum allowed payload length of a MQTT payload + rawPayloadLong = bytearray(self._generateStringOfAs(0xffff + 3), 'utf-8') # 0xffff + 3 bytes of raw payload, will use extended payload length bytes in encoding + # The payload of this frame will be masked by a randomly-generated mask key + # securedWebsocketCore should be able to decode it and get the raw payload back + coolFrame = self._encodeFrame(rawPayloadLong, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=0) + # self._printByteArray(coolFrame) + self._dummySSLSocket.refreshReadBuffer(coolFrame) + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + 
mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Read it back: + readItBack = self._dummySecuredWebsocket.read(len(rawPayloadLong)) # Basically read everything + assert rawPayloadLong == readItBack + + def test_WssWriteComplete(self): + ToBeWritten = b"Write me to the cloud." + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Fire the write op + self._dummySecuredWebsocket.write(ToBeWritten) + ans = self._encodeFrame(ToBeWritten, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1) + # self._printByteArray(ans) + assert ans == self._dummySSLSocket.getWriteBuffer() + + def test_WssWriteFragmented(self): + ToBeWritten = b"Write me to the cloud again." + # Configure SSLSocket to perform interrupted write op + self._dummySSLSocket.setFlipWriteError() + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Fire the write op + with pytest.raises(socket.error) as e: + self._dummySecuredWebsocket.write(ToBeWritten) + assert "Not ready for write op" == e.value.strerror + lengthWritten = self._dummySecuredWebsocket.write(ToBeWritten) + ans = self._encodeFrame(ToBeWritten, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1) + assert lengthWritten == len(ToBeWritten) + assert ans == self._dummySSLSocket.getWriteBuffer() + + # Wss Client Behavior + def test_ClientClosesConnectionIfServerResponseIsMasked(self): + ToBeWritten = b"I am designed to be masked." 
+ maskedFrame = self._encodeFrame(ToBeWritten, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1) + self._dummySSLSocket.refreshReadBuffer(maskedFrame) + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Now read it back + with pytest.raises(socket.error) as e: + self._dummySecuredWebsocket.read(len(ToBeWritten)) + assert "Server response masked, closing connection and try again." == e.value.strerror + # Verify that a closing frame from the client is on its way + closingFrame = self._encodeFrame(b"", self._OP_CONNECTION_CLOSE, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1) + assert closingFrame == self._dummySSLSocket.getWriteBuffer() + + def test_ClientClosesConnectionIfServerResponseHasReserveBitsSet(self): + ToBeWritten = b"I am designed to be masked." + maskedFrame = self._encodeFrame(ToBeWritten, self._OP_BINARY, FIN=1, RSV1=1, RSV2=0, RSV3=0, masked=1) + self._dummySSLSocket.refreshReadBuffer(maskedFrame) + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Now read it back + with pytest.raises(socket.error) as e: + self._dummySecuredWebsocket.read(len(ToBeWritten)) + assert "RSV bits set with NO negotiated extensions." 
== e.value.strerror + # Verify that a closing frame from the client is on its way + closingFrame = self._encodeFrame(b"", self._OP_CONNECTION_CLOSE, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1) + assert closingFrame == self._dummySSLSocket.getWriteBuffer() + + def test_ClientSendsPONGIfReceivedPING(self): + PINGFrame = self._encodeFrame(b"", self._OP_PING, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=0) + self._dummySSLSocket.refreshReadBuffer(PINGFrame) + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Now read it back, this must be in the next round of paho MQTT packet reading + # Should fail since we only have a PING to read, it never contains a valid MQTT payload + with pytest.raises(socket.error) as e: + self._dummySecuredWebsocket.read(5) + assert "Not a complete MQTT packet payload within this wss frame." == e.value.strerror + # Verify that PONG frame from the client is on its way + PONGFrame = self._encodeFrame(b"", self._OP_PONG, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1) + assert PONGFrame == self._dummySSLSocket.getWriteBuffer() + diff --git a/test/core/protocol/internal/__init__.py b/test/core/protocol/internal/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/protocol/internal/test_clients_client_status.py b/test/core/protocol/internal/test_clients_client_status.py new file mode 100644 index 0000000..b84a0d6 --- /dev/null +++ b/test/core/protocol/internal/test_clients_client_status.py @@ -0,0 +1,31 @@ +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatusContainer + + +class TestClientsClientStatus: + + def setup_method(self, test_method): + self.client_status = ClientStatusContainer() + + def test_set_client_status(self): + assert self.client_status.get_status() == ClientStatus.IDLE # Client status should 
start with IDLE + self._set_client_status_and_verify(ClientStatus.ABNORMAL_DISCONNECT) + self._set_client_status_and_verify(ClientStatus.CONNECT) + self._set_client_status_and_verify(ClientStatus.RESUBSCRIBE) + self._set_client_status_and_verify(ClientStatus.DRAINING) + self._set_client_status_and_verify(ClientStatus.STABLE) + + def test_client_status_does_not_change_unless_user_connect_after_user_disconnect(self): + self.client_status.set_status(ClientStatus.USER_DISCONNECT) + self._set_client_status_and_verify(ClientStatus.ABNORMAL_DISCONNECT, ClientStatus.USER_DISCONNECT) + self._set_client_status_and_verify(ClientStatus.RESUBSCRIBE, ClientStatus.USER_DISCONNECT) + self._set_client_status_and_verify(ClientStatus.DRAINING, ClientStatus.USER_DISCONNECT) + self._set_client_status_and_verify(ClientStatus.STABLE, ClientStatus.USER_DISCONNECT) + self._set_client_status_and_verify(ClientStatus.CONNECT) + + def _set_client_status_and_verify(self, set_client_status_type, verify_client_status_type=None): + self.client_status.set_status(set_client_status_type) + if verify_client_status_type: + assert self.client_status.get_status() == verify_client_status_type + else: + assert self.client_status.get_status() == set_client_status_type diff --git a/test/core/protocol/internal/test_clients_internal_async_client.py b/test/core/protocol/internal/test_clients_internal_async_client.py new file mode 100644 index 0000000..2d0e3cf --- /dev/null +++ b/test/core/protocol/internal/test_clients_internal_async_client.py @@ -0,0 +1,388 @@ +from AWSIoTPythonSDK.core.protocol.internal.clients import InternalAsyncMqttClient +from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids +from AWSIoTPythonSDK.core.util.providers import CertificateCredentialsProvider +from AWSIoTPythonSDK.core.util.providers import IAMCredentialsProvider +from AWSIoTPythonSDK.core.util.providers import EndpointProvider +from AWSIoTPythonSDK.core.util.providers import CiphersProvider +from 
AWSIoTPythonSDK.core.protocol.paho.client import MQTTv311 +from AWSIoTPythonSDK.core.protocol.paho.client import Client +from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS +from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_ERRNO +try: + from mock import patch + from mock import MagicMock + from mock import NonCallableMagicMock +except: + from unittest.mock import patch + from unittest.mock import MagicMock + from unittest.mock import NonCallableMagicMock +import ssl +import pytest + + +DUMMY_CLIENT_ID = "CoolClientId" +FAKE_PATH = "/fake/path/" +DUMMY_CA_PATH = FAKE_PATH + "ca.crt" +DUMMY_CERT_PATH = FAKE_PATH + "cert.pem" +DUMMY_KEY_PATH = FAKE_PATH + "key.pem" +DUMMY_ACCESS_KEY_ID = "AccessKeyId" +DUMMY_SECRET_ACCESS_KEY = "SecretAccessKey" +DUMMY_SESSION_TOKEN = "SessionToken" +DUMMY_TOPIC = "topic/test" +DUMMY_PAYLOAD = "TestPayload" +DUMMY_QOS = 1 +DUMMY_BASE_RECONNECT_QUIET_SEC = 1 +DUMMY_MAX_RECONNECT_QUIET_SEC = 32 +DUMMY_STABLE_CONNECTION_SEC = 20 +DUMMY_ENDPOINT = "dummy.endpoint.com" +DUMMY_PORT = 8888 +DUMMY_SUCCESS_RC = MQTT_ERR_SUCCESS +DUMMY_FAILURE_RC = MQTT_ERR_ERRNO +DUMMY_KEEP_ALIVE_SEC = 60 +DUMMY_REQUEST_MID = 89757 +DUMMY_USERNAME = "DummyUsername" +DUMMY_PASSWORD = "DummyPassword" +DUMMY_ALPN_PROTOCOLS = ["DummyALPNProtocol"] + +KEY_GET_CA_PATH_CALL_COUNT = "get_ca_path_call_count" +KEY_GET_CERT_PATH_CALL_COUNT = "get_cert_path_call_count" +KEY_GET_KEY_PATH_CALL_COUNT = "get_key_path_call_count" + +class TestClientsInternalAsyncClient: + + def setup_method(self, test_method): + # We init a cert based client by default + self.internal_async_client = InternalAsyncMqttClient(DUMMY_CLIENT_ID, False, MQTTv311, False) + self._mock_internal_members() + + def _mock_internal_members(self): + self.mock_paho_client = MagicMock(spec=Client) + # TODO: See if we can replace the following with patch.object + self.internal_async_client._paho_client = self.mock_paho_client + + def test_set_cert_credentials_provider_x509(self): 
+ mock_cert_credentials_provider = self._mock_cert_credentials_provider() + cipher_provider = CiphersProvider() + self.internal_async_client.set_cert_credentials_provider(mock_cert_credentials_provider, cipher_provider) + + expected_call_count = { + KEY_GET_CA_PATH_CALL_COUNT : 1, + KEY_GET_CERT_PATH_CALL_COUNT : 1, + KEY_GET_KEY_PATH_CALL_COUNT : 1 + } + self._verify_cert_credentials_provider(mock_cert_credentials_provider, expected_call_count) + self.mock_paho_client.tls_set.assert_called_once_with(ca_certs=DUMMY_CA_PATH, + certfile=DUMMY_CERT_PATH, + keyfile=DUMMY_KEY_PATH, + cert_reqs=ssl.CERT_REQUIRED, + tls_version=ssl.PROTOCOL_SSLv23, + ciphers=cipher_provider.get_ciphers()) + + def test_set_cert_credentials_provider_wss(self): + self.internal_async_client = InternalAsyncMqttClient(DUMMY_CLIENT_ID, False, MQTTv311, True) + self._mock_internal_members() + mock_cert_credentials_provider = self._mock_cert_credentials_provider() + cipher_provider = CiphersProvider() + + self.internal_async_client.set_cert_credentials_provider(mock_cert_credentials_provider, cipher_provider) + + expected_call_count = { + KEY_GET_CA_PATH_CALL_COUNT : 1 + } + self._verify_cert_credentials_provider(mock_cert_credentials_provider, expected_call_count) + self.mock_paho_client.tls_set.assert_called_once_with(ca_certs=DUMMY_CA_PATH, + cert_reqs=ssl.CERT_REQUIRED, + tls_version=ssl.PROTOCOL_SSLv23, + ciphers=cipher_provider.get_ciphers()) + + def _mock_cert_credentials_provider(self): + mock_cert_credentials_provider = MagicMock(spec=CertificateCredentialsProvider) + mock_cert_credentials_provider.get_ca_path.return_value = DUMMY_CA_PATH + mock_cert_credentials_provider.get_cert_path.return_value = DUMMY_CERT_PATH + mock_cert_credentials_provider.get_key_path.return_value = DUMMY_KEY_PATH + return mock_cert_credentials_provider + + def _verify_cert_credentials_provider(self, mock_cert_credentials_provider, expected_values): + expected_get_ca_path_call_count = 
expected_values.get(KEY_GET_CA_PATH_CALL_COUNT) + expected_get_cert_path_call_count = expected_values.get(KEY_GET_CERT_PATH_CALL_COUNT) + expected_get_key_path_call_count = expected_values.get(KEY_GET_KEY_PATH_CALL_COUNT) + + if expected_get_ca_path_call_count is not None: + assert mock_cert_credentials_provider.get_ca_path.call_count == expected_get_ca_path_call_count + if expected_get_cert_path_call_count is not None: + assert mock_cert_credentials_provider.get_cert_path.call_count == expected_get_cert_path_call_count + if expected_get_key_path_call_count is not None: + assert mock_cert_credentials_provider.get_key_path.call_count == expected_get_key_path_call_count + + def test_set_iam_credentials_provider(self): + self.internal_async_client = InternalAsyncMqttClient(DUMMY_CLIENT_ID, False, MQTTv311, True) + self._mock_internal_members() + mock_iam_credentials_provider = self._mock_iam_credentials_provider() + + self.internal_async_client.set_iam_credentials_provider(mock_iam_credentials_provider) + + self._verify_iam_credentials_provider(mock_iam_credentials_provider) + + def _mock_iam_credentials_provider(self): + mock_iam_credentials_provider = MagicMock(spec=IAMCredentialsProvider) + mock_iam_credentials_provider.get_ca_path.return_value = DUMMY_CA_PATH + mock_iam_credentials_provider.get_access_key_id.return_value = DUMMY_ACCESS_KEY_ID + mock_iam_credentials_provider.get_secret_access_key.return_value = DUMMY_SECRET_ACCESS_KEY + mock_iam_credentials_provider.get_session_token.return_value = DUMMY_SESSION_TOKEN + return mock_iam_credentials_provider + + def _verify_iam_credentials_provider(self, mock_iam_credentials_provider): + assert mock_iam_credentials_provider.get_access_key_id.call_count == 1 + assert mock_iam_credentials_provider.get_secret_access_key.call_count == 1 + assert mock_iam_credentials_provider.get_session_token.call_count == 1 + + def test_configure_last_will(self): + self.internal_async_client.configure_last_will(DUMMY_TOPIC, 
DUMMY_PAYLOAD, DUMMY_QOS) + self.mock_paho_client.will_set.assert_called_once_with(DUMMY_TOPIC, + DUMMY_PAYLOAD, + DUMMY_QOS, + False) + + def test_clear_last_will(self): + self.internal_async_client.clear_last_will() + assert self.mock_paho_client.will_clear.call_count == 1 + + def test_set_username_password(self): + self.internal_async_client.set_username_password(DUMMY_USERNAME, DUMMY_PASSWORD) + self.mock_paho_client.username_pw_set.assert_called_once_with(DUMMY_USERNAME, DUMMY_PASSWORD) + + def test_configure_reconnect_back_off(self): + self.internal_async_client.configure_reconnect_back_off(DUMMY_BASE_RECONNECT_QUIET_SEC, + DUMMY_MAX_RECONNECT_QUIET_SEC, + DUMMY_STABLE_CONNECTION_SEC) + self.mock_paho_client.setBackoffTiming.assert_called_once_with(DUMMY_BASE_RECONNECT_QUIET_SEC, + DUMMY_MAX_RECONNECT_QUIET_SEC, + DUMMY_STABLE_CONNECTION_SEC) + def test_configure_alpn_protocols(self): + self.internal_async_client.configure_alpn_protocols(DUMMY_ALPN_PROTOCOLS) + self.mock_paho_client.config_alpn_protocols.assert_called_once_with(DUMMY_ALPN_PROTOCOLS) + + def test_connect_success_rc(self): + self._internal_test_connect_with_rc(DUMMY_SUCCESS_RC) + + def test_connect_failure_rc(self): + self._internal_test_connect_with_rc(DUMMY_FAILURE_RC) + + def _internal_test_connect_with_rc(self, expected_connect_rc): + mock_endpoint_provider = self._mock_endpoint_provider() + self.mock_paho_client.connect.return_value = expected_connect_rc + self.internal_async_client.set_endpoint_provider(mock_endpoint_provider) + + actual_rc = self.internal_async_client.connect(DUMMY_KEEP_ALIVE_SEC) + + assert mock_endpoint_provider.get_host.call_count == 1 + assert mock_endpoint_provider.get_port.call_count == 1 + event_callback_map = self.internal_async_client.get_event_callback_map() + assert len(event_callback_map) == 3 + assert event_callback_map[FixedEventMids.CONNACK_MID] is not None + assert event_callback_map[FixedEventMids.DISCONNECT_MID] is not None + assert 
event_callback_map[FixedEventMids.MESSAGE_MID] is not None + assert self.mock_paho_client.connect.call_count == 1 + if expected_connect_rc == MQTT_ERR_SUCCESS: + assert self.mock_paho_client.loop_start.call_count == 1 + else: + assert self.mock_paho_client.loop_start.call_count == 0 + assert actual_rc == expected_connect_rc + + def _mock_endpoint_provider(self): + mock_endpoint_provider = MagicMock(spec=EndpointProvider) + mock_endpoint_provider.get_host.return_value = DUMMY_ENDPOINT + mock_endpoint_provider.get_port.return_value = DUMMY_PORT + return mock_endpoint_provider + + def test_start_background_network_io(self): + self.internal_async_client.start_background_network_io() + assert self.mock_paho_client.loop_start.call_count == 1 + + def test_stop_background_network_io(self): + self.internal_async_client.stop_background_network_io() + assert self.mock_paho_client.loop_stop.call_count == 1 + + def test_disconnect_success_rc(self): + self._internal_test_disconnect_with_rc(DUMMY_SUCCESS_RC) + + def test_disconnect_failure_rc(self): + self._internal_test_disconnect_with_rc(DUMMY_FAILURE_RC) + + def _internal_test_disconnect_with_rc(self, expected_disconnect_rc): + self.mock_paho_client.disconnect.return_value = expected_disconnect_rc + + actual_rc = self.internal_async_client.disconnect() + + event_callback_map = self.internal_async_client.get_event_callback_map() + assert self.mock_paho_client.disconnect.call_count == 1 + if expected_disconnect_rc == MQTT_ERR_SUCCESS: + # Since we only call disconnect, there should be only one registered callback + assert len(event_callback_map) == 1 + assert event_callback_map[FixedEventMids.DISCONNECT_MID] is not None + else: + assert len(event_callback_map) == 0 + assert actual_rc == expected_disconnect_rc + + def test_publish_qos0_success_rc(self): + self._internal_test_publish_with(0, DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC) + self._internal_test_publish_with(0, DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC, NonCallableMagicMock()) + + 
def test_publish_qos0_failure_rc(self): + self._internal_test_publish_with(0, DUMMY_REQUEST_MID, DUMMY_FAILURE_RC) + self._internal_test_publish_with(0, DUMMY_REQUEST_MID, DUMMY_FAILURE_RC, NonCallableMagicMock()) + + def test_publish_qos1_success_rc(self): + self._internal_test_publish_with(1, DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC) + self._internal_test_publish_with(1, DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC, NonCallableMagicMock()) + + def test_publish_qos1_failure_rc(self): + self._internal_test_publish_with(1, DUMMY_REQUEST_MID, DUMMY_FAILURE_RC) + self._internal_test_publish_with(1, DUMMY_REQUEST_MID, DUMMY_FAILURE_RC, NonCallableMagicMock()) + + def _internal_test_publish_with(self, qos, expected_mid, expected_rc, expected_callback=None): + self.mock_paho_client.publish.return_value = expected_rc, expected_mid + + actual_rc, actual_mid = self.internal_async_client.publish(DUMMY_TOPIC, + DUMMY_PAYLOAD, + qos, + retain=False, + ack_callback=expected_callback) + + self._verify_event_callback_map_for_pub_sub_unsub(expected_rc, expected_mid, qos, expected_callback) + assert actual_rc == expected_rc + assert actual_mid == expected_mid + + def test_subscribe_success_rc(self): + self._internal_test_subscribe_with(DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC) + self._internal_test_subscribe_with(DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC, NonCallableMagicMock()) + + def test_subscribe_failure_rc(self): + self._internal_test_subscribe_with(DUMMY_REQUEST_MID, DUMMY_FAILURE_RC) + self._internal_test_subscribe_with(DUMMY_REQUEST_MID, DUMMY_FAILURE_RC, NonCallableMagicMock()) + + def _internal_test_subscribe_with(self, expected_mid, expected_rc, expected_callback=None): + self.mock_paho_client.subscribe.return_value = expected_rc, expected_mid + + actual_rc, actual_mid = self.internal_async_client.subscribe(DUMMY_TOPIC, DUMMY_QOS, expected_callback) + + self._verify_event_callback_map_for_pub_sub_unsub(expected_rc, expected_mid, qos=None, callback=expected_callback) + assert actual_rc == 
expected_rc + assert actual_mid == expected_mid + + def test_unsubscribe_success_rc(self): + self._internal_test_unsubscribe_with(DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC) + self._internal_test_unsubscribe_with(DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC, NonCallableMagicMock()) + + def test_unsubscribe_failure_rc(self): + self._internal_test_unsubscribe_with(DUMMY_REQUEST_MID, DUMMY_FAILURE_RC) + self._internal_test_unsubscribe_with(DUMMY_REQUEST_MID, DUMMY_FAILURE_RC, NonCallableMagicMock()) + + def _internal_test_unsubscribe_with(self, expected_mid, expected_rc, expected_callback=None): + self.mock_paho_client.unsubscribe.return_value = expected_rc, expected_mid + + actual_rc, actual_mid = self.internal_async_client.unsubscribe(DUMMY_TOPIC, expected_callback) + + self._verify_event_callback_map_for_pub_sub_unsub(expected_rc, expected_mid, qos=None, callback=expected_callback) + assert actual_rc == expected_rc + assert actual_mid == expected_mid + + def _verify_event_callback_map_for_pub_sub_unsub(self, expected_rc, expected_mid, qos=None, callback=None): + event_callback_map = self.internal_async_client.get_event_callback_map() + should_have_callback_in_map = expected_rc == DUMMY_SUCCESS_RC and callback + if qos is not None: + should_have_callback_in_map = should_have_callback_in_map and qos > 0 + + if should_have_callback_in_map: + # Since we only perform this request, there should be only one registered callback + assert len(event_callback_map) == 1 + assert event_callback_map[expected_mid] == callback + else: + assert len(event_callback_map) == 0 + + def test_register_internal_event_callbacks(self): + expected_callback = NonCallableMagicMock() + self.internal_async_client.register_internal_event_callbacks(expected_callback, + expected_callback, + expected_callback, + expected_callback, + expected_callback, + expected_callback) + self._verify_internal_event_callbacks(expected_callback) + + def test_unregister_internal_event_callbacks(self): + 
self.internal_async_client.unregister_internal_event_callbacks() + self._verify_internal_event_callbacks(None) + + def _verify_internal_event_callbacks(self, expected_callback): + assert self.mock_paho_client.on_connect == expected_callback + assert self.mock_paho_client.on_disconnect == expected_callback + assert self.mock_paho_client.on_publish == expected_callback + assert self.mock_paho_client.on_subscribe == expected_callback + assert self.mock_paho_client.on_unsubscribe == expected_callback + assert self.mock_paho_client.on_message == expected_callback + + def test_invoke_event_callback_fixed_request(self): + # We use disconnect as an example for fixed request to "register" and event callback + self.mock_paho_client.disconnect.return_value = DUMMY_SUCCESS_RC + event_callback = MagicMock() + rc = self.internal_async_client.disconnect(event_callback) + self.internal_async_client.invoke_event_callback(FixedEventMids.DISCONNECT_MID, rc) + + event_callback.assert_called_once_with(FixedEventMids.DISCONNECT_MID, rc) + event_callback_map = self.internal_async_client.get_event_callback_map() + assert len(event_callback_map) == 1 # Fixed request event callback never gets removed + assert event_callback_map[FixedEventMids.DISCONNECT_MID] is not None + + def test_invoke_event_callback_non_fixed_request(self): + # We use unsubscribe as an example for non-fixed request to "register" an event callback + self.mock_paho_client.unsubscribe.return_value = DUMMY_SUCCESS_RC, DUMMY_REQUEST_MID + event_callback = MagicMock() + rc, mid = self.internal_async_client.unsubscribe(DUMMY_TOPIC, event_callback) + self.internal_async_client.invoke_event_callback(mid) + + event_callback.assert_called_once_with(mid=mid) + event_callback_map = self.internal_async_client.get_event_callback_map() + assert len(event_callback_map) == 0 # Non-fixed request event callback gets removed after successfully invoked + + @pytest.mark.timeout(3) + def 
test_invoke_event_callback_that_has_client_api_call(self): + # We use subscribe and publish on SUBACK as an example of having client API call within event callbacks + self.mock_paho_client.subscribe.return_value = DUMMY_SUCCESS_RC, DUMMY_REQUEST_MID + self.mock_paho_client.publish.return_value = DUMMY_SUCCESS_RC, DUMMY_REQUEST_MID + 1 + rc, mid = self.internal_async_client.subscribe(DUMMY_TOPIC, DUMMY_QOS, ack_callback=self._publish_on_suback) + + self.internal_async_client.invoke_event_callback(mid, (DUMMY_QOS,)) + + event_callback_map = self.internal_async_client.get_event_callback_map() + assert len(event_callback_map) == 0 + + def _publish_on_suback(self, mid, data): + self.internal_async_client.publish(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS) + + def test_remove_event_callback(self): + # We use unsubscribe as an example for non-fixed request to "register" an event callback + self.mock_paho_client.unsubscribe.return_value = DUMMY_SUCCESS_RC, DUMMY_REQUEST_MID + event_callback = MagicMock() + rc, mid = self.internal_async_client.unsubscribe(DUMMY_TOPIC, event_callback) + + event_callback_map = self.internal_async_client.get_event_callback_map() + assert len(event_callback_map) == 1 + + self.internal_async_client.remove_event_callback(mid) + assert len(event_callback_map) == 0 + + def test_clean_up_event_callbacks(self): + # We use unsubscribe as an example for non-fixed request to "register" an event callback + self.mock_paho_client.unsubscribe.return_value = DUMMY_SUCCESS_RC, DUMMY_REQUEST_MID + # We use disconnect as an example for fixed request to "register" an event callback + self.mock_paho_client.disconnect.return_value = DUMMY_SUCCESS_RC + event_callback = MagicMock() + self.internal_async_client.unsubscribe(DUMMY_TOPIC, event_callback) + self.internal_async_client.disconnect(event_callback) + + event_callback_map = self.internal_async_client.get_event_callback_map() + assert len(event_callback_map) == 2 + + 
self.internal_async_client.clean_up_event_callbacks() + assert len(event_callback_map) == 0 diff --git a/test/core/protocol/internal/test_offline_request_queue.py b/test/core/protocol/internal/test_offline_request_queue.py new file mode 100755 index 0000000..f666bb9 --- /dev/null +++ b/test/core/protocol/internal/test_offline_request_queue.py @@ -0,0 +1,67 @@ +import AWSIoTPythonSDK.core.protocol.internal.queues as Q +from AWSIoTPythonSDK.core.protocol.internal.queues import AppendResults +import pytest + + +class TestOfflineRequestQueue(): + + # Check that invalid input types are filtered out on initialization + def test_InvalidTypeInit(self): + with pytest.raises(TypeError): + Q.OfflineRequestQueue(1.7, 0) + with pytest.raises(TypeError): + Q.OfflineRequestQueue(0, 1.7) + + # Check that elements can be append to a normal finite queue + def test_NormalAppend(self): + coolQueue = Q.OfflineRequestQueue(20, 1) + numberOfMessages = 5 + answer = list(range(0, numberOfMessages)) + for i in range(0, numberOfMessages): + coolQueue.append(i) + assert answer == coolQueue + + # Check that new elements are dropped for DROPNEWEST configuration + def test_DropNewest(self): + coolQueue = Q.OfflineRequestQueue(3, 1) # Queueing section: 3, Response section: 1, DropNewest + numberOfMessages = 10 + answer = [0, 1, 2] # '0', '1' and '2' are stored, others are dropped. + fullCount = 0 + for i in range(0, numberOfMessages): + if coolQueue.append(i) == AppendResults.APPEND_FAILURE_QUEUE_FULL: + fullCount += 1 + assert answer == coolQueue + assert 7 == fullCount + + # Check that old elements are dropped for DROPOLDEST configuration + def test_DropOldest(self): + coolQueue = Q.OfflineRequestQueue(3, 0) + numberOfMessages = 10 + answer = [7, 8, 9] # '7', '8' and '9' are stored, others (older ones) are dropped. 
+ fullCount = 0 + for i in range(0, numberOfMessages): + if coolQueue.append(i) == AppendResults.APPEND_FAILURE_QUEUE_FULL: + fullCount += 1 + assert answer == coolQueue + assert 7 == fullCount + + # Check infinite queue + def test_Infinite(self): + coolQueue = Q.OfflineRequestQueue(-100, 1) + numberOfMessages = 10000 + answer = list(range(0, numberOfMessages)) + for i in range(0, numberOfMessages): + coolQueue.append(i) + assert answer == coolQueue # Nothing should be dropped since response section is infinite + + # Check disabled queue + def test_Disabled(self): + coolQueue = Q.OfflineRequestQueue(0, 1) + numberOfMessages = 10 + answer = list() + disableFailureCount = 0 + for i in range(0, numberOfMessages): + if coolQueue.append(i) == AppendResults.APPEND_FAILURE_QUEUE_DISABLED: + disableFailureCount += 1 + assert answer == coolQueue # Nothing should be appended since the queue is disabled + assert numberOfMessages == disableFailureCount diff --git a/test/core/protocol/internal/test_workers_event_consumer.py b/test/core/protocol/internal/test_workers_event_consumer.py new file mode 100644 index 0000000..4edfb6c --- /dev/null +++ b/test/core/protocol/internal/test_workers_event_consumer.py @@ -0,0 +1,273 @@ +from AWSIoTPythonSDK.core.protocol.internal.workers import EventConsumer +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatusContainer +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus +from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids +from AWSIoTPythonSDK.core.protocol.internal.events import EventTypes +from AWSIoTPythonSDK.core.protocol.paho.client import MQTTMessage +from AWSIoTPythonSDK.core.protocol.internal.workers import SubscriptionManager +from AWSIoTPythonSDK.core.protocol.internal.workers import OfflineRequestsManager +from AWSIoTPythonSDK.core.protocol.internal.clients import InternalAsyncMqttClient +from AWSIoTPythonSDK.core.protocol.internal.requests import QueueableRequest +from 
AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes +from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_DRAINING_INTERNAL_SEC +try: + from mock import patch + from mock import MagicMock + from mock import call +except: + from unittest.mock import patch + from unittest.mock import MagicMock + from unittest.mock import call +from threading import Condition +import time +import sys +if sys.version_info[0] < 3: + from Queue import Queue +else: + from queue import Queue + + +DUMMY_TOPIC = "dummy/topic" +DUMMY_MESSAGE = "dummy_message" +DUMMY_QOS = 1 +DUMMY_SUCCESS_RC = 0 +DUMMY_PUBACK_MID = 89757 +DUMMY_SUBACK_MID = 89758 +DUMMY_UNSUBACK_MID = 89579 + +KEY_CLIENT_STATUS_AFTER = "status_after" +KEY_STOP_BG_NW_IO_CALL_COUNT = "stop_background_network_io_call_count" +KEY_CLEAN_UP_EVENT_CBS_CALL_COUNT = "clean_up_event_callbacks_call_count" +KEY_IS_EVENT_Q_EMPTY = "is_event_queue_empty" +KEY_IS_EVENT_CONSUMER_UP = "is_event_consumer_running" + +class TestWorkersEventConsumer: + + def setup_method(self, test_method): + self.cv = Condition() + self.event_queue = Queue() + self.client_status = ClientStatusContainer() + self.internal_async_client = MagicMock(spec=InternalAsyncMqttClient) + self.subscription_manager = MagicMock(spec=SubscriptionManager) + self.offline_requests_manager = MagicMock(spec=OfflineRequestsManager) + self.message_callback = MagicMock() + self.subscribe_callback = MagicMock() + self.unsubscribe_callback = MagicMock() + self.event_consumer = None + + def teardown_method(self, test_method): + if self.event_consumer and self.event_consumer.is_running(): + self.event_consumer.stop() + self.event_consumer.wait_until_it_stops(2) # Make sure the event consumer stops gracefully + + def test_update_draining_interval_sec(self): + EXPECTED_DRAINING_INTERVAL_SEC = 0.5 + self.load_mocks_into_test_target() + self.event_consumer.update_draining_interval_sec(EXPECTED_DRAINING_INTERVAL_SEC) + assert 
self.event_consumer.get_draining_interval_sec() == EXPECTED_DRAINING_INTERVAL_SEC + + def test_dispatch_message_event(self): + expected_message_event = self._configure_mocks_message_event() + self._start_consumer() + self._verify_message_event_dispatch(expected_message_event) + + def _configure_mocks_message_event(self): + message_event = self._create_message_event(DUMMY_TOPIC, DUMMY_MESSAGE, DUMMY_QOS) + self._fill_in_fake_events([message_event]) + self.subscription_manager.list_records.return_value = [(DUMMY_TOPIC, (DUMMY_QOS, self.message_callback, self.subscribe_callback))] + self.load_mocks_into_test_target() + return message_event + + def _create_message_event(self, topic, payload, qos): + mqtt_message = MQTTMessage() + mqtt_message.topic = topic + mqtt_message.payload = payload + mqtt_message.qos = qos + return FixedEventMids.MESSAGE_MID, EventTypes.MESSAGE, mqtt_message + + def _verify_message_event_dispatch(self, expected_message_event): + expected_message = expected_message_event[2] + self.message_callback.assert_called_once_with(None, None, expected_message) + self.internal_async_client.invoke_event_callback.assert_called_once_with(FixedEventMids.MESSAGE_MID, data=expected_message) + assert self.event_consumer.is_running() is True + + def test_dispatch_disconnect_event_user_disconnect(self): + self._configure_mocks_disconnect_event(ClientStatus.USER_DISCONNECT) + self._start_consumer() + expected_values = { + KEY_CLIENT_STATUS_AFTER : ClientStatus.USER_DISCONNECT, + KEY_STOP_BG_NW_IO_CALL_COUNT : 1, + KEY_CLEAN_UP_EVENT_CBS_CALL_COUNT : 1, + KEY_IS_EVENT_Q_EMPTY : True, + KEY_IS_EVENT_CONSUMER_UP : False + } + self._verify_disconnect_event_dispatch(expected_values) + assert self.event_consumer.is_fully_stopped() is True + + def test_dispatch_disconnect_event_connect_failure(self): + self._configure_mocks_disconnect_event(ClientStatus.CONNECT) + self._start_consumer() + expected_values = { + KEY_CLIENT_STATUS_AFTER : ClientStatus.CONNECT, + 
KEY_STOP_BG_NW_IO_CALL_COUNT : 1, + KEY_CLEAN_UP_EVENT_CBS_CALL_COUNT : 1, + KEY_IS_EVENT_Q_EMPTY : True, + KEY_IS_EVENT_CONSUMER_UP : False + } + self._verify_disconnect_event_dispatch(expected_values) + assert self.event_consumer.is_fully_stopped() is True + + def test_dispatch_disconnect_event_abnormal_disconnect(self): + self._configure_mocks_disconnect_event(ClientStatus.STABLE) + self._start_consumer() + expected_values = { + KEY_CLIENT_STATUS_AFTER : ClientStatus.ABNORMAL_DISCONNECT, + KEY_STOP_BG_NW_IO_CALL_COUNT : 0, + KEY_CLEAN_UP_EVENT_CBS_CALL_COUNT : 0, + KEY_IS_EVENT_CONSUMER_UP : True + } + self._verify_disconnect_event_dispatch(expected_values) + assert self.event_consumer.is_fully_stopped() is False + + def _configure_mocks_disconnect_event(self, start_client_status): + self.client_status.set_status(start_client_status) + self._fill_in_fake_events([self._create_disconnect_event()]) + self.load_mocks_into_test_target() + + def _create_disconnect_event(self): + return FixedEventMids.DISCONNECT_MID, EventTypes.DISCONNECT, DUMMY_SUCCESS_RC + + def _verify_disconnect_event_dispatch(self, expected_values): + client_status_after = expected_values.get(KEY_CLIENT_STATUS_AFTER) + stop_background_network_io_call_count = expected_values.get(KEY_STOP_BG_NW_IO_CALL_COUNT) + clean_up_event_callbacks_call_count = expected_values.get(KEY_CLEAN_UP_EVENT_CBS_CALL_COUNT) + is_event_queue_empty = expected_values.get(KEY_IS_EVENT_Q_EMPTY) + is_event_consumer_running = expected_values.get(KEY_IS_EVENT_CONSUMER_UP) + + if client_status_after is not None: + assert self.client_status.get_status() == client_status_after + if stop_background_network_io_call_count is not None: + assert self.internal_async_client.stop_background_network_io.call_count == stop_background_network_io_call_count + if clean_up_event_callbacks_call_count is not None: + assert self.internal_async_client.clean_up_event_callbacks.call_count == clean_up_event_callbacks_call_count + if is_event_queue_empty 
is not None: + assert self.event_queue.empty() == is_event_queue_empty + if is_event_consumer_running is not None: + assert self.event_consumer.is_running() == is_event_consumer_running + + self.internal_async_client.invoke_event_callback.assert_called_once_with(FixedEventMids.DISCONNECT_MID, data=DUMMY_SUCCESS_RC) + + def test_dispatch_connack_event_no_recovery(self): + self._configure_mocks_connack_event() + self._start_consumer() + self._verify_connack_event_dispatch() + + def test_dispatch_connack_event_need_resubscribe(self): + resub_records = [ + (DUMMY_TOPIC + "1", (DUMMY_QOS, self.message_callback, self.subscribe_callback)), + (DUMMY_TOPIC + "2", (DUMMY_QOS, self.message_callback, self.subscribe_callback)), + (DUMMY_TOPIC + "3", (DUMMY_QOS, self.message_callback, self.subscribe_callback)) + ] + self._configure_mocks_connack_event(resubscribe_records=resub_records) + self._start_consumer() + self._verify_connack_event_dispatch(resubscribe_records=resub_records) + + def test_dispatch_connack_event_need_draining(self): + self._configure_mocks_connack_event(need_draining=True) + self._start_consumer() + self._verify_connack_event_dispatch(need_draining=True) + + def test_dispatch_connack_event_need_resubscribe_draining(self): + resub_records = [ + (DUMMY_TOPIC + "1", (DUMMY_QOS, self.message_callback, self.subscribe_callback)), + (DUMMY_TOPIC + "2", (DUMMY_QOS, self.message_callback, self.subscribe_callback)), + (DUMMY_TOPIC + "3", (DUMMY_QOS, self.message_callback, self.subscribe_callback)) + ] + self._configure_mocks_connack_event(resubscribe_records=resub_records, need_draining=True) + self._start_consumer() + self._verify_connack_event_dispatch(resubscribe_records=resub_records, need_draining=True) + + def _configure_mocks_connack_event(self, resubscribe_records=list(), need_draining=False): + self.client_status.set_status(ClientStatus.CONNECT) + self._fill_in_fake_events([self._create_connack_event()]) + self.subscription_manager.list_records.return_value 
= resubscribe_records + if need_draining: # We pack publish, subscribe and unsubscribe requests into the offline queue + if resubscribe_records: + has_more_side_effect_list = 4 * [True] + else: + has_more_side_effect_list = 5 * [True] + has_more_side_effect_list += [False] + self.offline_requests_manager.has_more.side_effect = has_more_side_effect_list + self.offline_requests_manager.get_next.side_effect = [ + QueueableRequest(RequestTypes.PUBLISH, (DUMMY_TOPIC, DUMMY_MESSAGE, DUMMY_QOS, False)), + QueueableRequest(RequestTypes.SUBSCRIBE, (DUMMY_TOPIC, DUMMY_QOS, self.message_callback, self.subscribe_callback)), + QueueableRequest(RequestTypes.UNSUBSCRIBE, (DUMMY_TOPIC, self.unsubscribe_callback)) + ] + else: + self.offline_requests_manager.has_more.return_value = False + self.load_mocks_into_test_target() + + def _create_connack_event(self): + return FixedEventMids.CONNACK_MID, EventTypes.CONNACK, DUMMY_SUCCESS_RC + + def _verify_connack_event_dispatch(self, resubscribe_records=list(), need_draining=False): + time.sleep(3 * DEFAULT_DRAINING_INTERNAL_SEC) # Make sure resubscribe/draining finishes + assert self.event_consumer.is_running() is True + self.internal_async_client.invoke_event_callback.assert_called_once_with(FixedEventMids.CONNACK_MID, data=DUMMY_SUCCESS_RC) + if resubscribe_records: + resub_call_sequence = [] + for topic, (qos, message_callback, subscribe_callback) in resubscribe_records: + resub_call_sequence.append(call(topic, qos, subscribe_callback)) + self.internal_async_client.subscribe.assert_has_calls(resub_call_sequence) + if need_draining: + assert self.internal_async_client.publish.call_count == 1 + assert self.internal_async_client.unsubscribe.call_count == 1 + assert self.internal_async_client.subscribe.call_count == len(resubscribe_records) + 1 + assert self.event_consumer.is_fully_stopped() is False + + def test_dispatch_puback_suback_unsuback_events(self): + self._configure_mocks_puback_suback_unsuback_events() + self._start_consumer() + 
self._verify_puback_suback_unsuback_events_dispatch() + + def _configure_mocks_puback_suback_unsuback_events(self): + self.client_status.set_status(ClientStatus.STABLE) + self._fill_in_fake_events([ + self._create_puback_event(DUMMY_PUBACK_MID), + self._create_suback_event(DUMMY_SUBACK_MID), + self._create_unsuback_event(DUMMY_UNSUBACK_MID)]) + self.load_mocks_into_test_target() + + def _verify_puback_suback_unsuback_events_dispatch(self): + assert self.event_consumer.is_running() is True + call_sequence = [ + call(DUMMY_PUBACK_MID, data=None), + call(DUMMY_SUBACK_MID, data=DUMMY_QOS), + call(DUMMY_UNSUBACK_MID, data=None)] + self.internal_async_client.invoke_event_callback.assert_has_calls(call_sequence) + assert self.event_consumer.is_fully_stopped() is False + + def _fill_in_fake_events(self, events): + for event in events: + self.event_queue.put(event) + + def _start_consumer(self): + self.event_consumer.start() + time.sleep(1) # Make sure the event gets picked up by the consumer + + def load_mocks_into_test_target(self): + self.event_consumer = EventConsumer(self.cv, + self.event_queue, + self.internal_async_client, + self.subscription_manager, + self.offline_requests_manager, + self.client_status) + + def _create_puback_event(self, mid): + return mid, EventTypes.PUBACK, None + + def _create_suback_event(self, mid): + return mid, EventTypes.SUBACK, DUMMY_QOS + + def _create_unsuback_event(self, mid): + return mid, EventTypes.UNSUBACK, None diff --git a/test/core/protocol/internal/test_workers_event_producer.py b/test/core/protocol/internal/test_workers_event_producer.py new file mode 100644 index 0000000..fbb97b2 --- /dev/null +++ b/test/core/protocol/internal/test_workers_event_producer.py @@ -0,0 +1,65 @@ +import pytest +from threading import Condition +from AWSIoTPythonSDK.core.protocol.internal.workers import EventProducer +from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids +from AWSIoTPythonSDK.core.protocol.internal.events import 
EventTypes +from AWSIoTPythonSDK.core.protocol.paho.client import MQTTMessage +import sys +if sys.version_info[0] < 3: + from Queue import Queue +else: + from queue import Queue + +DUMMY_PAHO_CLIENT = None +DUMMY_USER_DATA = None +DUMMY_FLAGS = None +DUMMY_GRANTED_QOS = 1 +DUMMY_MID = 89757 +SUCCESS_RC = 0 + +MAX_CV_WAIT_TIME_SEC = 5 + +class TestWorkersEventProducer: + + def setup_method(self, test_method): + self._generate_test_targets() + + def test_produce_on_connect_event(self): + self.event_producer.on_connect(DUMMY_PAHO_CLIENT, DUMMY_USER_DATA, DUMMY_FLAGS, SUCCESS_RC) + self._verify_queued_event(self.event_queue, (FixedEventMids.CONNACK_MID, EventTypes.CONNACK, SUCCESS_RC)) + + def test_produce_on_disconnect_event(self): + self.event_producer.on_disconnect(DUMMY_PAHO_CLIENT, DUMMY_USER_DATA, SUCCESS_RC) + self._verify_queued_event(self.event_queue, (FixedEventMids.DISCONNECT_MID, EventTypes.DISCONNECT, SUCCESS_RC)) + + def test_produce_on_publish_event(self): + self.event_producer.on_publish(DUMMY_PAHO_CLIENT, DUMMY_USER_DATA, DUMMY_MID) + self._verify_queued_event(self.event_queue, (DUMMY_MID, EventTypes.PUBACK, None)) + + def test_produce_on_subscribe_event(self): + self.event_producer.on_subscribe(DUMMY_PAHO_CLIENT, DUMMY_USER_DATA, DUMMY_MID, DUMMY_GRANTED_QOS) + self._verify_queued_event(self.event_queue, (DUMMY_MID, EventTypes.SUBACK, DUMMY_GRANTED_QOS)) + + def test_produce_on_unsubscribe_event(self): + self.event_producer.on_unsubscribe(DUMMY_PAHO_CLIENT, DUMMY_USER_DATA, DUMMY_MID) + self._verify_queued_event(self.event_queue, (DUMMY_MID, EventTypes.UNSUBACK, None)) + + def test_produce_on_message_event(self): + dummy_message = MQTTMessage() + dummy_message.topic = "test/topic" + dummy_message.qos = 1 + dummy_message.payload = "test_payload" + self.event_producer.on_message(DUMMY_PAHO_CLIENT, DUMMY_USER_DATA, dummy_message) + self._verify_queued_event(self.event_queue, (FixedEventMids.MESSAGE_MID, EventTypes.MESSAGE, dummy_message)) + + def 
_generate_test_targets(self): + self.cv = Condition() + self.event_queue = Queue() + self.event_producer = EventProducer(self.cv, self.event_queue) + + def _verify_queued_event(self, queue, expected_results): + expected_mid, expected_event_type, expected_data = expected_results + actual_mid, actual_event_type, actual_data = queue.get() + assert actual_mid == expected_mid + assert actual_event_type == expected_event_type + assert actual_data == expected_data diff --git a/test/core/protocol/internal/test_workers_offline_requests_manager.py b/test/core/protocol/internal/test_workers_offline_requests_manager.py new file mode 100644 index 0000000..8193718 --- /dev/null +++ b/test/core/protocol/internal/test_workers_offline_requests_manager.py @@ -0,0 +1,69 @@ +import pytest +from AWSIoTPythonSDK.core.protocol.internal.workers import OfflineRequestsManager +from AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes +from AWSIoTPythonSDK.core.protocol.internal.queues import AppendResults + +DEFAULT_QUEUE_SIZE = 3 +FAKE_REQUEST_PREFIX = "Fake Request " + +def test_has_more(): + offline_requests_manager = OfflineRequestsManager(DEFAULT_QUEUE_SIZE, DropBehaviorTypes.DROP_NEWEST) + + assert not offline_requests_manager.has_more() + + offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + "0") + assert offline_requests_manager.has_more() + + +def test_add_more_normal(): + offline_requests_manager = OfflineRequestsManager(DEFAULT_QUEUE_SIZE, DropBehaviorTypes.DROP_NEWEST) + append_result = offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + "0") + + assert append_result == AppendResults.APPEND_SUCCESS + + +def test_add_more_full_drop_newest(): + offline_requests_manager = OfflineRequestsManager(DEFAULT_QUEUE_SIZE, DropBehaviorTypes.DROP_NEWEST) + _overflow_the_queue(offline_requests_manager) + append_result = offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + "A") + + assert append_result == AppendResults.APPEND_FAILURE_QUEUE_FULL + + next_request = 
offline_requests_manager.get_next() + assert next_request == FAKE_REQUEST_PREFIX + "0" + + +def test_add_more_full_drop_oldest(): + offline_requests_manager = OfflineRequestsManager(DEFAULT_QUEUE_SIZE, DropBehaviorTypes.DROP_OLDEST) + _overflow_the_queue(offline_requests_manager) + append_result = offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + "A") + + assert append_result == AppendResults.APPEND_FAILURE_QUEUE_FULL + + next_request = offline_requests_manager.get_next() + assert next_request == FAKE_REQUEST_PREFIX + "1" + + +def test_add_more_disabled(): + offline_requests_manager = OfflineRequestsManager(0, DropBehaviorTypes.DROP_NEWEST) + append_result = offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + "0") + + assert append_result == AppendResults.APPEND_FAILURE_QUEUE_DISABLED + + +def _overflow_the_queue(offline_requests_manager): + for i in range(0, DEFAULT_QUEUE_SIZE): + offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + str(i)) + + +def test_get_next_normal(): + offline_requests_manager = OfflineRequestsManager(DEFAULT_QUEUE_SIZE, DropBehaviorTypes.DROP_NEWEST) + append_result = offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + "0") + + assert append_result == AppendResults.APPEND_SUCCESS + assert offline_requests_manager.get_next() is not None + + +def test_get_next_empty(): + offline_requests_manager = OfflineRequestsManager(DEFAULT_QUEUE_SIZE, DropBehaviorTypes.DROP_NEWEST) + assert offline_requests_manager.get_next() is None diff --git a/test/core/protocol/internal/test_workers_subscription_manager.py b/test/core/protocol/internal/test_workers_subscription_manager.py new file mode 100644 index 0000000..6e436b7 --- /dev/null +++ b/test/core/protocol/internal/test_workers_subscription_manager.py @@ -0,0 +1,41 @@ +import pytest +from AWSIoTPythonSDK.core.protocol.internal.workers import SubscriptionManager + +DUMMY_TOPIC1 = "topic1" +DUMMY_TOPIC2 = "topic2" + + +def _dummy_callback(client, user_data, message): + pass + + +def 
test_add_record(): + subscription_manager = SubscriptionManager() + subscription_manager.add_record(DUMMY_TOPIC1, 1, _dummy_callback, _dummy_callback) + + record_list = subscription_manager.list_records() + + assert len(record_list) == 1 + + topic, (qos, message_callback, ack_callback) = record_list[0] + assert topic == DUMMY_TOPIC1 + assert qos == 1 + assert message_callback == _dummy_callback + assert ack_callback == _dummy_callback + + +def test_remove_record(): + subscription_manager = SubscriptionManager() + subscription_manager.add_record(DUMMY_TOPIC1, 1, _dummy_callback, _dummy_callback) + subscription_manager.add_record(DUMMY_TOPIC2, 0, _dummy_callback, _dummy_callback) + subscription_manager.remove_record(DUMMY_TOPIC1) + + record_list = subscription_manager.list_records() + + assert len(record_list) == 1 + + topic, (qos, message_callback, ack_callback) = record_list[0] + assert topic == DUMMY_TOPIC2 + assert qos == 0 + assert message_callback == _dummy_callback + assert ack_callback == _dummy_callback diff --git a/test/core/protocol/test_mqtt_core.py b/test/core/protocol/test_mqtt_core.py new file mode 100644 index 0000000..8469ea6 --- /dev/null +++ b/test/core/protocol/test_mqtt_core.py @@ -0,0 +1,585 @@ +import AWSIoTPythonSDK +from AWSIoTPythonSDK.core.protocol.mqtt_core import MqttCore +from AWSIoTPythonSDK.core.protocol.internal.clients import InternalAsyncMqttClient +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatusContainer +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus +from AWSIoTPythonSDK.core.protocol.internal.workers import EventProducer +from AWSIoTPythonSDK.core.protocol.internal.workers import EventConsumer +from AWSIoTPythonSDK.core.protocol.internal.workers import SubscriptionManager +from AWSIoTPythonSDK.core.protocol.internal.workers import OfflineRequestsManager +from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids +from AWSIoTPythonSDK.core.protocol.internal.queues 
import AppendResults +from AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes +from AWSIoTPythonSDK.core.protocol.internal.defaults import METRICS_PREFIX +from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectError +from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectTimeoutException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectError +from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectTimeoutException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError +from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishTimeoutException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueFullException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueDisabledException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError +from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueFullException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueDisabledException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeError +from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeTimeoutException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueFullException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueDisabledException +from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS +from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_ERRNO +from AWSIoTPythonSDK.core.protocol.paho.client import MQTTv311 +from AWSIoTPythonSDK.core.protocol.internal.defaults import ALPN_PROTCOLS +try: + from mock import patch + from mock import MagicMock + from mock import NonCallableMagicMock + from mock import call +except: + from unittest.mock import patch + from unittest.mock import MagicMock + from unittest.mock import NonCallableMagicMock + from 
unittest.mock import call +from threading import Event +import pytest + + +PATCH_MODULE_LOCATION = "AWSIoTPythonSDK.core.protocol.mqtt_core." +DUMMY_SUCCESS_RC = MQTT_ERR_SUCCESS +DUMMY_FAILURE_RC = MQTT_ERR_ERRNO +DUMMY_REQUEST_MID = 89757 +DUMMY_CLIENT_ID = "CoolClientId" +DUMMY_KEEP_ALIVE_SEC = 60 +DUMMY_TOPIC = "topic/cool" +DUMMY_PAYLOAD = "CoolPayload" +DUMMY_QOS = 1 +DUMMY_USERNAME = "DummyUsername" +DUMMY_PASSWORD = "DummyPassword" + +KEY_EXPECTED_REQUEST_RC = "ExpectedRequestRc" +KEY_EXPECTED_QUEUE_APPEND_RESULT = "ExpectedQueueAppendResult" +KEY_EXPECTED_REQUEST_MID_OVERRIDE = "ExpectedRequestMidOverride" +KEY_EXPECTED_REQUEST_TIMEOUT = "ExpectedRequestTimeout" +SUCCESS_RC_EXPECTED_VALUES = { + KEY_EXPECTED_REQUEST_RC : DUMMY_SUCCESS_RC +} +FAILURE_RC_EXPECTED_VALUES = { + KEY_EXPECTED_REQUEST_RC : DUMMY_FAILURE_RC +} +TIMEOUT_EXPECTED_VALUES = { + KEY_EXPECTED_REQUEST_TIMEOUT : True +} +NO_TIMEOUT_EXPECTED_VALUES = { + KEY_EXPECTED_REQUEST_TIMEOUT : False +} +QUEUED_EXPECTED_VALUES = { + KEY_EXPECTED_QUEUE_APPEND_RESULT : AppendResults.APPEND_SUCCESS +} +QUEUE_FULL_EXPECTED_VALUES = { + KEY_EXPECTED_QUEUE_APPEND_RESULT : AppendResults.APPEND_FAILURE_QUEUE_FULL +} +QUEUE_DISABLED_EXPECTED_VALUES = { + KEY_EXPECTED_QUEUE_APPEND_RESULT : AppendResults.APPEND_FAILURE_QUEUE_DISABLED +} + +class TestMqttCore: + + def setup_class(cls): + cls.configure_internal_async_client = { + RequestTypes.CONNECT : cls._configure_internal_async_client_connect, + RequestTypes.DISCONNECT : cls._configure_internal_async_client_disconnect, + RequestTypes.PUBLISH : cls._configure_internal_async_client_publish, + RequestTypes.SUBSCRIBE : cls._configure_internal_async_client_subscribe, + RequestTypes.UNSUBSCRIBE : cls._configure_internal_async_client_unsubscribe + } + cls.invoke_mqtt_core_async_api = { + RequestTypes.CONNECT : cls._invoke_mqtt_core_connect_async, + RequestTypes.DISCONNECT : cls._invoke_mqtt_core_disconnect_async, + RequestTypes.PUBLISH : 
cls._invoke_mqtt_core_publish_async, + RequestTypes.SUBSCRIBE : cls._invoke_mqtt_core_subscribe_async, + RequestTypes.UNSUBSCRIBE : cls._invoke_mqtt_core_unsubscribe_async + } + cls.invoke_mqtt_core_sync_api = { + RequestTypes.CONNECT : cls._invoke_mqtt_core_connect, + RequestTypes.DISCONNECT : cls._invoke_mqtt_core_disconnect, + RequestTypes.PUBLISH : cls._invoke_mqtt_core_publish, + RequestTypes.SUBSCRIBE : cls._invoke_mqtt_core_subscribe, + RequestTypes.UNSUBSCRIBE : cls._invoke_mqtt_core_unsubscribe + } + cls.verify_mqtt_core_async_api = { + RequestTypes.CONNECT : cls._verify_mqtt_core_connect_async, + RequestTypes.DISCONNECT : cls._verify_mqtt_core_disconnect_async, + RequestTypes.PUBLISH : cls._verify_mqtt_core_publish_async, + RequestTypes.SUBSCRIBE : cls._verify_mqtt_core_subscribe_async, + RequestTypes.UNSUBSCRIBE : cls._verify_mqtt_core_unsubscribe_async + } + cls.request_error = { + RequestTypes.CONNECT : connectError, + RequestTypes.DISCONNECT : disconnectError, + RequestTypes.PUBLISH : publishError, + RequestTypes.SUBSCRIBE: subscribeError, + RequestTypes.UNSUBSCRIBE: unsubscribeError + } + cls.request_queue_full = { + RequestTypes.PUBLISH : publishQueueFullException, + RequestTypes.SUBSCRIBE: subscribeQueueFullException, + RequestTypes.UNSUBSCRIBE: unsubscribeQueueFullException + } + cls.request_queue_disable = { + RequestTypes.PUBLISH : publishQueueDisabledException, + RequestTypes.SUBSCRIBE : subscribeQueueDisabledException, + RequestTypes.UNSUBSCRIBE : unsubscribeQueueDisabledException + } + cls.request_timeout = { + RequestTypes.CONNECT : connectTimeoutException, + RequestTypes.DISCONNECT : disconnectTimeoutException, + RequestTypes.PUBLISH : publishTimeoutException, + RequestTypes.SUBSCRIBE : subscribeTimeoutException, + RequestTypes.UNSUBSCRIBE : unsubscribeTimeoutException + } + + def _configure_internal_async_client_connect(self, expected_rc, expected_mid=None): + self.internal_async_client_mock.connect.return_value = expected_rc + + def 
_configure_internal_async_client_disconnect(self, expected_rc, expeected_mid=None): + self.internal_async_client_mock.disconnect.return_value = expected_rc + + def _configure_internal_async_client_publish(self, expected_rc, expected_mid): + self.internal_async_client_mock.publish.return_value = expected_rc, expected_mid + + def _configure_internal_async_client_subscribe(self, expected_rc, expected_mid): + self.internal_async_client_mock.subscribe.return_value = expected_rc, expected_mid + + def _configure_internal_async_client_unsubscribe(self, expected_rc, expected_mid): + self.internal_async_client_mock.unsubscribe.return_value = expected_rc, expected_mid + + def _invoke_mqtt_core_connect_async(self, ack_callback, message_callback): + return self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC, ack_callback) + + def _invoke_mqtt_core_disconnect_async(self, ack_callback, message_callback): + return self.mqtt_core.disconnect_async(ack_callback) + + def _invoke_mqtt_core_publish_async(self, ack_callback, message_callback): + return self.mqtt_core.publish_async(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS, False, ack_callback) + + def _invoke_mqtt_core_subscribe_async(self, ack_callback, message_callback): + return self.mqtt_core.subscribe_async(DUMMY_TOPIC, DUMMY_QOS, ack_callback, message_callback) + + def _invoke_mqtt_core_unsubscribe_async(self, ack_callback, message_callback): + return self.mqtt_core.unsubscribe_async(DUMMY_TOPIC, ack_callback) + + def _invoke_mqtt_core_connect(self, message_callback): + return self.mqtt_core.connect(DUMMY_KEEP_ALIVE_SEC) + + def _invoke_mqtt_core_disconnect(self, message_callback): + return self.mqtt_core.disconnect() + + def _invoke_mqtt_core_publish(self, message_callback): + return self.mqtt_core.publish(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS) + + def _invoke_mqtt_core_subscribe(self, message_callback): + return self.mqtt_core.subscribe(DUMMY_TOPIC, DUMMY_QOS, message_callback) + + def _invoke_mqtt_core_unsubscribe(self, 
message_callback): + return self.mqtt_core.unsubscribe(DUMMY_TOPIC) + + def _verify_mqtt_core_connect_async(self, ack_callback, message_callback): + self.internal_async_client_mock.connect.assert_called_once_with(DUMMY_KEEP_ALIVE_SEC, ack_callback) + self.client_status_mock.set_status.assert_called_once_with(ClientStatus.CONNECT) + + def _verify_mqtt_core_disconnect_async(self, ack_callback, message_callback): + self.internal_async_client_mock.disconnect.assert_called_once_with(ack_callback) + self.client_status_mock.set_status.assert_called_once_with(ClientStatus.USER_DISCONNECT) + + def _verify_mqtt_core_publish_async(self, ack_callback, message_callback): + self.internal_async_client_mock.publish.assert_called_once_with(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS, + False, ack_callback) + + def _verify_mqtt_core_subscribe_async(self, ack_callback, message_callback): + self.internal_async_client_mock.subscribe.assert_called_once_with(DUMMY_TOPIC, DUMMY_QOS, ack_callback) + self.subscription_manager_mock.add_record.assert_called_once_with(DUMMY_TOPIC, DUMMY_QOS, message_callback, ack_callback) + + def _verify_mqtt_core_unsubscribe_async(self, ack_callback, message_callback): + self.internal_async_client_mock.unsubscribe.assert_called_once_with(DUMMY_TOPIC, ack_callback) + self.subscription_manager_mock.remove_record.assert_called_once_with(DUMMY_TOPIC) + + def setup_method(self, test_method): + self._use_mock_internal_async_client() + self._use_mock_event_producer() + self._use_mock_event_consumer() + self._use_mock_subscription_manager() + self._use_mock_offline_requests_manager() + self._use_mock_client_status() + self.mqtt_core = MqttCore(DUMMY_CLIENT_ID, True, MQTTv311, False) # We choose x.509 auth type for this test + + def _use_mock_internal_async_client(self): + self.internal_async_client_patcher = patch(PATCH_MODULE_LOCATION + "InternalAsyncMqttClient", + spec=InternalAsyncMqttClient) + self.mock_internal_async_client_constructor = 
self.internal_async_client_patcher.start() + self.internal_async_client_mock = MagicMock() + self.mock_internal_async_client_constructor.return_value = self.internal_async_client_mock + + def _use_mock_event_producer(self): + self.event_producer_patcher = patch(PATCH_MODULE_LOCATION + "EventProducer", spec=EventProducer) + self.mock_event_producer_constructor = self.event_producer_patcher.start() + self.event_producer_mock = MagicMock() + self.mock_event_producer_constructor.return_value = self.event_producer_mock + + def _use_mock_event_consumer(self): + self.event_consumer_patcher = patch(PATCH_MODULE_LOCATION + "EventConsumer", spec=EventConsumer) + self.mock_event_consumer_constructor = self.event_consumer_patcher.start() + self.event_consumer_mock = MagicMock() + self.mock_event_consumer_constructor.return_value = self.event_consumer_mock + + def _use_mock_subscription_manager(self): + self.subscription_manager_patcher = patch(PATCH_MODULE_LOCATION + "SubscriptionManager", + spec=SubscriptionManager) + self.mock_subscription_manager_constructor = self.subscription_manager_patcher.start() + self.subscription_manager_mock = MagicMock() + self.mock_subscription_manager_constructor.return_value = self.subscription_manager_mock + + def _use_mock_offline_requests_manager(self): + self.offline_requests_manager_patcher = patch(PATCH_MODULE_LOCATION + "OfflineRequestsManager", + spec=OfflineRequestsManager) + self.mock_offline_requests_manager_constructor = self.offline_requests_manager_patcher.start() + self.offline_requests_manager_mock = MagicMock() + self.mock_offline_requests_manager_constructor.return_value = self.offline_requests_manager_mock + + def _use_mock_client_status(self): + self.client_status_patcher = patch(PATCH_MODULE_LOCATION + "ClientStatusContainer", spec=ClientStatusContainer) + self.mock_client_status_constructor = self.client_status_patcher.start() + self.client_status_mock = MagicMock() + self.mock_client_status_constructor.return_value = 
self.client_status_mock + + def teardown_method(self, test_method): + self.internal_async_client_patcher.stop() + self.event_producer_patcher.stop() + self.event_consumer_patcher.stop() + self.subscription_manager_patcher.stop() + self.offline_requests_manager_patcher.stop() + self.client_status_patcher.stop() + + # Finally... Tests start + def test_use_wss(self): + self.mqtt_core = MqttCore(DUMMY_CLIENT_ID, True, MQTTv311, True) # use wss + assert self.mqtt_core.use_wss() is True + + def test_configure_alpn_protocols(self): + self.mqtt_core.configure_alpn_protocols() + self.internal_async_client_mock.configure_alpn_protocols.assert_called_once_with([ALPN_PROTCOLS]) + + def test_enable_metrics_collection_with_username_in_connect(self): + self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC) + self._use_mock_python_event() + self.python_event_mock.wait.return_value = True + self.mqtt_core.configure_username_password(DUMMY_USERNAME, DUMMY_PASSWORD) + self.mqtt_core.connect(DUMMY_KEEP_ALIVE_SEC) + self.internal_async_client_mock.set_username_password.assert_called_once_with(DUMMY_USERNAME + + METRICS_PREFIX + + AWSIoTPythonSDK.__version__, + DUMMY_PASSWORD) + self.python_event_patcher.stop() + + def test_enable_metrics_collection_with_username_in_connect_async(self): + self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC) + self.mqtt_core.configure_username_password(DUMMY_USERNAME, DUMMY_PASSWORD) + self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC) + self.internal_async_client_mock.set_username_password.assert_called_once_with(DUMMY_USERNAME + + METRICS_PREFIX + + AWSIoTPythonSDK.__version__, + DUMMY_PASSWORD) + + def test_enable_metrics_collection_without_username_in_connect(self): + self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC) + self._use_mock_python_event() + self.python_event_mock.wait.return_value = True + self.mqtt_core.connect(DUMMY_KEEP_ALIVE_SEC) + 
self.internal_async_client_mock.set_username_password.assert_called_once_with(METRICS_PREFIX + + AWSIoTPythonSDK.__version__, + None) + self.python_event_patcher.stop() + + def test_enable_metrics_collection_without_username_in_connect_async(self): + self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC) + self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC) + self.internal_async_client_mock.set_username_password.assert_called_once_with(METRICS_PREFIX + + AWSIoTPythonSDK.__version__, + None) + + def test_disable_metrics_collection_with_username_in_connect(self): + self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC) + self._use_mock_python_event() + self.python_event_mock.wait.return_value = True + self.mqtt_core.disable_metrics_collection() + self.mqtt_core.configure_username_password(DUMMY_USERNAME, DUMMY_PASSWORD) + self.mqtt_core.connect(DUMMY_KEEP_ALIVE_SEC) + self.internal_async_client_mock.set_username_password.assert_called_once_with(DUMMY_USERNAME, DUMMY_PASSWORD) + self.python_event_patcher.stop() + + def test_disable_metrics_collection_with_username_in_connect_async(self): + self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC) + self.mqtt_core.disable_metrics_collection() + self.mqtt_core.configure_username_password(DUMMY_USERNAME, DUMMY_PASSWORD) + self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC) + self.internal_async_client_mock.set_username_password.assert_called_once_with(DUMMY_USERNAME, DUMMY_PASSWORD) + + def test_disable_metrics_collection_without_username_in_connect(self): + self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC) + self._use_mock_python_event() + self.python_event_mock.wait.return_value = True + self.mqtt_core.disable_metrics_collection() + self.mqtt_core.connect(DUMMY_KEEP_ALIVE_SEC) + self.internal_async_client_mock.set_username_password.assert_called_once_with("", None) + self.python_event_patcher.stop() + + def test_disable_metrics_collection_without_username_in_connect_asyc(self): + 
self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC) + self.mqtt_core.disable_metrics_collection() + self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC) + self.internal_async_client_mock.set_username_password.assert_called_once_with("", None) + + def test_connect_async_success_rc(self): + expected_values = { + KEY_EXPECTED_REQUEST_RC : DUMMY_SUCCESS_RC, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.CONNACK_MID + } + self._internal_test_async_api_with(RequestTypes.CONNECT, expected_values) + + def test_connect_async_failure_rc(self): + expected_values = { + KEY_EXPECTED_REQUEST_RC : DUMMY_FAILURE_RC, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.CONNACK_MID + } + self._internal_test_async_api_with(RequestTypes.CONNECT, expected_values) + + def test_connect_async_when_failure_rc_should_stop_event_consumer(self): + self.internal_async_client_mock.connect.return_value = DUMMY_FAILURE_RC + + with pytest.raises(connectError): + self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC) + + self.event_consumer_mock.start.assert_called_once() + self.event_consumer_mock.stop.assert_called_once() + self.event_consumer_mock.wait_until_it_stops.assert_called_once() + assert self.client_status_mock.set_status.call_count == 2 + assert self.client_status_mock.set_status.call_args_list == [call(ClientStatus.CONNECT), call(ClientStatus.IDLE)] + + def test_connect_async_when_exception_should_stop_event_consumer(self): + self.internal_async_client_mock.connect.side_effect = Exception("Something weird happened") + + with pytest.raises(Exception): + self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC) + + self.event_consumer_mock.start.assert_called_once() + self.event_consumer_mock.stop.assert_called_once() + self.event_consumer_mock.wait_until_it_stops.assert_called_once() + assert self.client_status_mock.set_status.call_count == 2 + assert self.client_status_mock.set_status.call_args_list == [call(ClientStatus.CONNECT), call(ClientStatus.IDLE)] + + def 
test_disconnect_async_success_rc(self): + expected_values = { + KEY_EXPECTED_REQUEST_RC : DUMMY_SUCCESS_RC, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.DISCONNECT_MID + } + self._internal_test_async_api_with(RequestTypes.DISCONNECT, expected_values) + + def test_disconnect_async_failure_rc(self): + expected_values = { + KEY_EXPECTED_REQUEST_RC : DUMMY_FAILURE_RC, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.DISCONNECT_MID + } + self._internal_test_async_api_with(RequestTypes.DISCONNECT, expected_values) + + def test_publish_async_success_rc(self): + self._internal_test_async_api_with(RequestTypes.PUBLISH, SUCCESS_RC_EXPECTED_VALUES) + + def test_publish_async_failure_rc(self): + self._internal_test_async_api_with(RequestTypes.PUBLISH, FAILURE_RC_EXPECTED_VALUES) + + def test_publish_async_queued(self): + self._internal_test_async_api_with(RequestTypes.PUBLISH, QUEUED_EXPECTED_VALUES) + + def test_publish_async_queue_disabled(self): + self._internal_test_async_api_with(RequestTypes.PUBLISH, QUEUE_DISABLED_EXPECTED_VALUES) + + def test_publish_async_queue_full(self): + self._internal_test_async_api_with(RequestTypes.PUBLISH, QUEUE_FULL_EXPECTED_VALUES) + + def test_subscribe_async_success_rc(self): + self._internal_test_async_api_with(RequestTypes.SUBSCRIBE, SUCCESS_RC_EXPECTED_VALUES) + + def test_subscribe_async_failure_rc(self): + self._internal_test_async_api_with(RequestTypes.SUBSCRIBE, FAILURE_RC_EXPECTED_VALUES) + + def test_subscribe_async_queued(self): + self._internal_test_async_api_with(RequestTypes.SUBSCRIBE, QUEUED_EXPECTED_VALUES) + + def test_subscribe_async_queue_full(self): + self._internal_test_async_api_with(RequestTypes.SUBSCRIBE, QUEUE_FULL_EXPECTED_VALUES) + + def test_subscribe_async_queue_disabled(self): + self._internal_test_async_api_with(RequestTypes.SUBSCRIBE, QUEUE_DISABLED_EXPECTED_VALUES) + + def test_unsubscribe_async_success_rc(self): + self._internal_test_async_api_with(RequestTypes.UNSUBSCRIBE, 
SUCCESS_RC_EXPECTED_VALUES) + + def test_unsubscribe_async_failure_rc(self): + self._internal_test_async_api_with(RequestTypes.UNSUBSCRIBE, FAILURE_RC_EXPECTED_VALUES) + + def test_unsubscribe_async_queued(self): + self._internal_test_async_api_with(RequestTypes.UNSUBSCRIBE, QUEUED_EXPECTED_VALUES) + + def test_unsubscribe_async_queue_full(self): + self._internal_test_async_api_with(RequestTypes.UNSUBSCRIBE, QUEUE_FULL_EXPECTED_VALUES) + + def test_unsubscribe_async_queue_disabled(self): + self._internal_test_async_api_with(RequestTypes.UNSUBSCRIBE, QUEUE_DISABLED_EXPECTED_VALUES) + + def _internal_test_async_api_with(self, request_type, expected_values): + expected_rc = expected_values.get(KEY_EXPECTED_REQUEST_RC) + expected_append_result = expected_values.get(KEY_EXPECTED_QUEUE_APPEND_RESULT) + expected_request_mid_override = expected_values.get(KEY_EXPECTED_REQUEST_MID_OVERRIDE) + ack_callback = NonCallableMagicMock() + message_callback = NonCallableMagicMock() + + if expected_rc is not None: + self.configure_internal_async_client[request_type](self, expected_rc, DUMMY_REQUEST_MID) + self.client_status_mock.get_status.return_value = ClientStatus.STABLE + if expected_rc == DUMMY_SUCCESS_RC: + mid = self.invoke_mqtt_core_async_api[request_type](self, ack_callback, message_callback) + self.verify_mqtt_core_async_api[request_type](self, ack_callback, message_callback) + if expected_request_mid_override is not None: + assert mid == expected_request_mid_override + else: + assert mid == DUMMY_REQUEST_MID + else: # FAILURE_RC + with pytest.raises(self.request_error[request_type]): + self.invoke_mqtt_core_async_api[request_type](self, ack_callback, message_callback) + + if expected_append_result is not None: + self.client_status_mock.get_status.return_value = ClientStatus.ABNORMAL_DISCONNECT + self.offline_requests_manager_mock.add_one.return_value = expected_append_result + if expected_append_result == AppendResults.APPEND_SUCCESS: + mid = 
self.invoke_mqtt_core_async_api[request_type](self, ack_callback, message_callback) + assert mid == FixedEventMids.QUEUED_MID + elif expected_append_result == AppendResults.APPEND_FAILURE_QUEUE_FULL: + with pytest.raises(self.request_queue_full[request_type]): + self.invoke_mqtt_core_async_api[request_type](self, ack_callback, message_callback) + else: # AppendResults.APPEND_FAILURE_QUEUE_DISABLED + with pytest.raises(self.request_queue_disable[request_type]): + self.invoke_mqtt_core_async_api[request_type](self, ack_callback, message_callback) + + def test_connect_success(self): + expected_values = { + KEY_EXPECTED_REQUEST_TIMEOUT : False, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.CONNACK_MID + } + self._internal_test_sync_api_with(RequestTypes.CONNECT, expected_values) + + def test_connect_timeout(self): + expected_values = { + KEY_EXPECTED_REQUEST_TIMEOUT : True, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.CONNACK_MID + } + self._internal_test_sync_api_with(RequestTypes.CONNECT, expected_values) + + def test_disconnect_success(self): + expected_values = { + KEY_EXPECTED_REQUEST_TIMEOUT : False, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.DISCONNECT_MID + } + self._internal_test_sync_api_with(RequestTypes.DISCONNECT, expected_values) + + def test_disconnect_timeout(self): + expected_values = { + KEY_EXPECTED_REQUEST_TIMEOUT : True, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.DISCONNECT_MID + } + self._internal_test_sync_api_with(RequestTypes.DISCONNECT, expected_values) + + def test_publish_success(self): + self._internal_test_sync_api_with(RequestTypes.PUBLISH, NO_TIMEOUT_EXPECTED_VALUES) + + def test_publish_timeout(self): + self._internal_test_sync_api_with(RequestTypes.PUBLISH, TIMEOUT_EXPECTED_VALUES) + + def test_publish_queued(self): + self._internal_test_sync_api_with(RequestTypes.PUBLISH, QUEUED_EXPECTED_VALUES) + + def test_publish_queue_full(self): + self._internal_test_sync_api_with(RequestTypes.PUBLISH, 
QUEUE_FULL_EXPECTED_VALUES) + + def test_publish_queue_disabled(self): + self._internal_test_sync_api_with(RequestTypes.PUBLISH, QUEUE_DISABLED_EXPECTED_VALUES) + + def test_subscribe_success(self): + self._internal_test_sync_api_with(RequestTypes.SUBSCRIBE, NO_TIMEOUT_EXPECTED_VALUES) + + def test_subscribe_timeout(self): + self._internal_test_sync_api_with(RequestTypes.SUBSCRIBE, TIMEOUT_EXPECTED_VALUES) + + def test_subscribe_queued(self): + self._internal_test_sync_api_with(RequestTypes.SUBSCRIBE, QUEUED_EXPECTED_VALUES) + + def test_subscribe_queue_full(self): + self._internal_test_sync_api_with(RequestTypes.SUBSCRIBE, QUEUE_FULL_EXPECTED_VALUES) + + def test_subscribe_queue_disabled(self): + self._internal_test_sync_api_with(RequestTypes.SUBSCRIBE, QUEUE_DISABLED_EXPECTED_VALUES) + + def test_unsubscribe_success(self): + self._internal_test_sync_api_with(RequestTypes.UNSUBSCRIBE, NO_TIMEOUT_EXPECTED_VALUES) + + def test_unsubscribe_timeout(self): + self._internal_test_sync_api_with(RequestTypes.UNSUBSCRIBE, TIMEOUT_EXPECTED_VALUES) + + def test_unsubscribe_queued(self): + self._internal_test_sync_api_with(RequestTypes.UNSUBSCRIBE, QUEUED_EXPECTED_VALUES) + + def test_unsubscribe_queue_full(self): + self._internal_test_sync_api_with(RequestTypes.UNSUBSCRIBE, QUEUE_FULL_EXPECTED_VALUES) + + def test_unsubscribe_queue_disabled(self): + self._internal_test_sync_api_with(RequestTypes.UNSUBSCRIBE, QUEUE_DISABLED_EXPECTED_VALUES) + + def _internal_test_sync_api_with(self, request_type, expected_values): + expected_request_mid = expected_values.get(KEY_EXPECTED_REQUEST_MID_OVERRIDE) + expected_timeout = expected_values.get(KEY_EXPECTED_REQUEST_TIMEOUT) + expected_append_result = expected_values.get(KEY_EXPECTED_QUEUE_APPEND_RESULT) + + if expected_request_mid is None: + expected_request_mid = DUMMY_REQUEST_MID + message_callback = NonCallableMagicMock() + self.configure_internal_async_client[request_type](self, DUMMY_SUCCESS_RC, expected_request_mid) + 
self._use_mock_python_event() + + if expected_timeout is not None: + self.client_status_mock.get_status.return_value = ClientStatus.STABLE + if expected_timeout: + self.python_event_mock.wait.return_value = False + with pytest.raises(self.request_timeout[request_type]): + self.invoke_mqtt_core_sync_api[request_type](self, message_callback) + else: + self.python_event_mock.wait.return_value = True + assert self.invoke_mqtt_core_sync_api[request_type](self, message_callback) is True + + if expected_append_result is not None: + self.client_status_mock.get_status.return_value = ClientStatus.ABNORMAL_DISCONNECT + self.offline_requests_manager_mock.add_one.return_value = expected_append_result + if expected_append_result == AppendResults.APPEND_SUCCESS: + assert self.invoke_mqtt_core_sync_api[request_type](self, message_callback) is False + elif expected_append_result == AppendResults.APPEND_FAILURE_QUEUE_FULL: + with pytest.raises(self.request_queue_full[request_type]): + self.invoke_mqtt_core_sync_api[request_type](self, message_callback) + else: + with pytest.raises(self.request_queue_disable[request_type]): + self.invoke_mqtt_core_sync_api[request_type](self, message_callback) + + self.python_event_patcher.stop() + + def _use_mock_python_event(self): + self.python_event_patcher = patch(PATCH_MODULE_LOCATION + "Event", spec=Event) + self.python_event_constructor = self.python_event_patcher.start() + self.python_event_mock = MagicMock() + self.python_event_constructor.return_value = self.python_event_mock diff --git a/test/core/shadow/__init__.py b/test/core/shadow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/shadow/test_device_shadow.py b/test/core/shadow/test_device_shadow.py new file mode 100755 index 0000000..3b4ec61 --- /dev/null +++ b/test/core/shadow/test_device_shadow.py @@ -0,0 +1,297 @@ +# Test shadow behavior for a single device shadow + +from AWSIoTPythonSDK.core.shadow.deviceShadow import deviceShadow +from 
AWSIoTPythonSDK.core.shadow.shadowManager import shadowManager
from AWSIoTPythonSDK.core.protocol.paho.client import MQTTMessage
import time
import json
try:
    from mock import MagicMock
except:  # NOTE(review): bare except; ImportError would be more precise — confirm before changing
    from unittest.mock import MagicMock


# Thing name and per-operation timeout shared by all shadow tests below.
DUMMY_THING_NAME = "CoolThing"
DUMMY_SHADOW_OP_TIME_OUT_SEC = 3

# Shadow operation types and the response statuses the callback can report.
SHADOW_OP_TYPE_GET = "get"
SHADOW_OP_TYPE_DELETE = "delete"
SHADOW_OP_TYPE_UPDATE = "update"
SHADOW_OP_RESPONSE_STATUS_ACCEPTED = "accepted"
SHADOW_OP_RESPONSE_STATUS_REJECTED = "rejected"
SHADOW_OP_RESPONSE_STATUS_TIMEOUT = "timeout"
SHADOW_OP_RESPONSE_STATUS_DELTA = "delta"

# AWS IoT shadow topics for DUMMY_THING_NAME, built from the reserved $aws prefix.
SHADOW_TOPIC_PREFIX = "$aws/things/"
SHADOW_TOPIC_GET_ACCEPTED = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/get/accepted"
SHADOW_TOPIC_GET_REJECTED = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/get/rejected"
SHADOW_TOPIC_DELETE_ACCEPTED = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/delete/accepted"
SHADOW_TOPIC_DELETE_REJECTED = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/delete/rejected"
SHADOW_TOPIC_UPDATE_ACCEPTED = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/update/accepted"
SHADOW_TOPIC_UPDATE_REJECTED = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/update/rejected"
SHADOW_TOPIC_UPDATE_DELTA = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/update/delta"
SHADOW_RESPONSE_PAYLOAD_TIMEOUT = "REQUEST TIME OUT"

# Keys for per-test value-override dicts (force specific inbound/outbound payloads).
VALUE_OVERRIDE_KEY_INBOUND_PAYLOAD = "InBoundPayload"
VALUE_OVERRIDE_KEY_OUTBOUND_PAYLOAD = "OutBoundPayload"

# Deliberately non-JSON payload used to exercise the garbage-response paths.
GARBAGE_PAYLOAD = b"ThisIsGarbagePayload!"
+ +VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD = { + VALUE_OVERRIDE_KEY_INBOUND_PAYLOAD : GARBAGE_PAYLOAD +} + + +class TestDeviceShadow: + + def setup_class(cls): + cls.invoke_shadow_operation = { + SHADOW_OP_TYPE_GET : cls._invoke_shadow_get, + SHADOW_OP_TYPE_DELETE : cls._invoke_shadow_delete, + SHADOW_OP_TYPE_UPDATE : cls._invoke_shadow_update + } + cls._get_topics = { + SHADOW_OP_RESPONSE_STATUS_ACCEPTED : SHADOW_TOPIC_GET_ACCEPTED, + SHADOW_OP_RESPONSE_STATUS_REJECTED : SHADOW_TOPIC_GET_REJECTED, + } + cls._delete_topics = { + SHADOW_OP_RESPONSE_STATUS_ACCEPTED : SHADOW_TOPIC_DELETE_ACCEPTED, + SHADOW_OP_RESPONSE_STATUS_REJECTED : SHADOW_TOPIC_DELETE_REJECTED + } + cls._update_topics = { + SHADOW_OP_RESPONSE_STATUS_ACCEPTED : SHADOW_TOPIC_UPDATE_ACCEPTED, + SHADOW_OP_RESPONSE_STATUS_REJECTED : SHADOW_TOPIC_UPDATE_REJECTED, + SHADOW_OP_RESPONSE_STATUS_DELTA : SHADOW_TOPIC_UPDATE_DELTA + } + cls.shadow_topics = { + SHADOW_OP_TYPE_GET : cls._get_topics, + SHADOW_OP_TYPE_DELETE : cls._delete_topics, + SHADOW_OP_TYPE_UPDATE : cls._update_topics + } + + def _invoke_shadow_get(self): + return self.device_shadow_handler.shadowGet(self.shadow_callback, DUMMY_SHADOW_OP_TIME_OUT_SEC) + + def _invoke_shadow_delete(self): + return self.device_shadow_handler.shadowDelete(self.shadow_callback, DUMMY_SHADOW_OP_TIME_OUT_SEC) + + def _invoke_shadow_update(self): + return self.device_shadow_handler.shadowUpdate("{}", self.shadow_callback, DUMMY_SHADOW_OP_TIME_OUT_SEC) + + def setup_method(self, method): + self.shadow_manager_mock = MagicMock(spec=shadowManager) + self.shadow_callback = MagicMock() + self._create_device_shadow_handler() # Create device shadow handler with persistent subscribe by default + + def _create_device_shadow_handler(self, is_persistent_subscribe=True): + self.device_shadow_handler = deviceShadow(DUMMY_THING_NAME, is_persistent_subscribe, self.shadow_manager_mock) + + # Shadow delta + def test_register_delta_callback_older_version_should_not_invoke(self): + 
self.device_shadow_handler.shadowRegisterDeltaCallback(self.shadow_callback) + self._fake_incoming_delta_message_with(version=3) + + # Make next delta message with an old version + self._fake_incoming_delta_message_with(version=1) + + assert self.shadow_callback.call_count == 1 # Once time from the previous delta message + + def test_unregister_delta_callback_should_not_invoke_after(self): + self.device_shadow_handler.shadowRegisterDeltaCallback(self.shadow_callback) + fake_delta_message = self._fake_incoming_delta_message_with(version=3) + self.shadow_callback.assert_called_once_with(fake_delta_message.payload.decode("utf-8"), + SHADOW_OP_RESPONSE_STATUS_DELTA + "/" + DUMMY_THING_NAME, + None) + + # Now we unregister + self.device_shadow_handler.shadowUnregisterDeltaCallback() + self._fake_incoming_delta_message_with(version=5) + assert self.shadow_callback.call_count == 1 # One time from the previous delta message + + def test_register_delta_callback_newer_version_should_invoke(self): + self.device_shadow_handler.shadowRegisterDeltaCallback(self.shadow_callback) + fake_delta_message = self._fake_incoming_delta_message_with(version=300) + + self.shadow_callback.assert_called_once_with(fake_delta_message.payload.decode("utf-8"), + SHADOW_OP_RESPONSE_STATUS_DELTA + "/" + DUMMY_THING_NAME, + None) + + def test_register_delta_callback_no_version_should_not_invoke(self): + self.device_shadow_handler.shadowRegisterDeltaCallback(self.shadow_callback) + self._fake_incoming_delta_message_with(version=None) + + assert self.shadow_callback.call_count == 0 + + def _fake_incoming_delta_message_with(self, version): + fake_delta_message = self._create_fake_shadow_response(SHADOW_TOPIC_UPDATE_DELTA, + self._create_simple_payload(token=None, version=version)) + self.device_shadow_handler.generalCallback(None, None, fake_delta_message) + time.sleep(1) # Callback executed in another thread, wait to make sure the artifacts are generated + return fake_delta_message + + # Shadow get + 
def test_persistent_shadow_get_accepted(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_GET, SHADOW_OP_RESPONSE_STATUS_ACCEPTED) + + def test_persistent_shadow_get_rejected(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_GET, SHADOW_OP_RESPONSE_STATUS_REJECTED) + + def test_persistent_shadow_get_time_out(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_GET, SHADOW_OP_RESPONSE_STATUS_TIMEOUT) + + def test_persistent_shadow_get_garbage_response_should_time_out(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_GET, + SHADOW_OP_RESPONSE_STATUS_ACCEPTED, + value_override=VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD) + + def test_non_persistent_shadow_get_accepted(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_GET, SHADOW_OP_RESPONSE_STATUS_ACCEPTED) + + def test_non_persistent_shadow_get_rejected(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_GET, SHADOW_OP_RESPONSE_STATUS_REJECTED) + + def test_non_persistent_shadow_get_time_out(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_GET, SHADOW_OP_RESPONSE_STATUS_TIMEOUT) + + def test_non_persistent_shadow_get_garbage_response_should_time_out(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_GET, + SHADOW_OP_RESPONSE_STATUS_ACCEPTED, + value_override=VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD) + + # Shadow delete + def test_persistent_shadow_delete_accepted(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, SHADOW_OP_RESPONSE_STATUS_ACCEPTED) + + def test_persistent_shadow_delete_rejected(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, SHADOW_OP_RESPONSE_STATUS_REJECTED) + + def test_persistent_shadow_delete_time_out(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, SHADOW_OP_RESPONSE_STATUS_TIMEOUT) + + def 
test_persistent_shadow_delete_garbage_response_should_time_out(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, + SHADOW_OP_RESPONSE_STATUS_ACCEPTED, + value_override=VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD) + + def test_non_persistent_shadow_delete_accepted(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, SHADOW_OP_RESPONSE_STATUS_ACCEPTED) + + def test_non_persistent_shadow_delete_rejected(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, SHADOW_OP_RESPONSE_STATUS_REJECTED) + + def test_non_persistent_shadow_delete_time_out(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, SHADOW_OP_RESPONSE_STATUS_TIMEOUT) + + def test_non_persistent_shadow_delete_garbage_response_should_time_out(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, + SHADOW_OP_RESPONSE_STATUS_ACCEPTED, + value_override=VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD) + + # Shadow update + def test_persistent_shadow_update_accepted(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, SHADOW_OP_RESPONSE_STATUS_ACCEPTED) + + def test_persistent_shadow_update_rejected(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, SHADOW_OP_RESPONSE_STATUS_REJECTED) + + def test_persistent_shadow_update_time_out(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, SHADOW_OP_RESPONSE_STATUS_TIMEOUT) + + def test_persistent_shadow_update_garbage_response_should_time_out(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, + SHADOW_OP_RESPONSE_STATUS_ACCEPTED, + value_override=VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD) + + def test_non_persistent_shadow_update_accepted(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, SHADOW_OP_RESPONSE_STATUS_ACCEPTED) + + def test_non_persistent_shadow_update_rejected(self): + 
self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, SHADOW_OP_RESPONSE_STATUS_REJECTED) + + def test_non_persistent_shadow_update_time_out(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, SHADOW_OP_RESPONSE_STATUS_TIMEOUT) + + def test_non_persistent_shadow_update_garbage_response_should_time_out(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, + SHADOW_OP_RESPONSE_STATUS_ACCEPTED, + value_override=VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD) + + def _internal_test_non_persistent_shadow_operation(self, operation_type, operation_response_type, value_override=None): + self._create_device_shadow_handler(is_persistent_subscribe=False) + token = self.invoke_shadow_operation[operation_type](self) + inbound_payload, wait_time_sec, expected_response_type, expected_shadow_response_payload = \ + self._prepare_test_values(token, operation_response_type, value_override) + expected_shadow_response_payload = \ + self._invoke_shadow_general_callback_on_demand(operation_type, operation_response_type, + (inbound_payload, wait_time_sec, expected_shadow_response_payload)) + + self._assert_first_call_correct(operation_type, (token, expected_response_type, expected_shadow_response_payload)) + + def _internal_test_persistent_shadow_operation(self, operation_type, operation_response_type, value_override=None): + token = self.invoke_shadow_operation[operation_type](self) + inbound_payload, wait_time_sec, expected_response_type, expected_shadow_response_payload = \ + self._prepare_test_values(token, operation_response_type, value_override) + expected_shadow_response_payload = \ + self._invoke_shadow_general_callback_on_demand(operation_type, operation_response_type, + (inbound_payload, wait_time_sec, expected_shadow_response_payload)) + + self._assert_first_call_correct(operation_type, + (token, expected_response_type, expected_shadow_response_payload), + is_persistent=True) + + def 
_prepare_test_values(self, token, operation_response_type, value_override): + inbound_payload = None + if value_override: + inbound_payload = value_override.get(VALUE_OVERRIDE_KEY_INBOUND_PAYLOAD) + if inbound_payload is None: + inbound_payload = self._create_simple_payload(token, version=3) # Should be bytes in Py3 + if inbound_payload == GARBAGE_PAYLOAD: + expected_shadow_response_payload = SHADOW_RESPONSE_PAYLOAD_TIMEOUT + wait_time_sec = DUMMY_SHADOW_OP_TIME_OUT_SEC + 1 + expected_response_type = SHADOW_OP_RESPONSE_STATUS_TIMEOUT + else: + expected_shadow_response_payload = inbound_payload.decode("utf-8") # Should always be str in Py2/3 + wait_time_sec = 1 + expected_response_type = operation_response_type + + return inbound_payload, wait_time_sec, expected_response_type, expected_shadow_response_payload + + def _invoke_shadow_general_callback_on_demand(self, operation_type, operation_response_type, data): + inbound_payload, wait_time_sec, expected_shadow_response_payload = data + + if operation_response_type == SHADOW_OP_RESPONSE_STATUS_TIMEOUT: + time.sleep(DUMMY_SHADOW_OP_TIME_OUT_SEC + 1) # Make it time out for sure + return SHADOW_RESPONSE_PAYLOAD_TIMEOUT + else: + fake_shadow_response = self._create_fake_shadow_response(self.shadow_topics[operation_type][operation_response_type], + inbound_payload) + self.device_shadow_handler.generalCallback(None, None, fake_shadow_response) + time.sleep(wait_time_sec) # Callback executed in another thread, wait to make sure the artifacts are generated + return expected_shadow_response_payload + + def _assert_first_call_correct(self, operation_type, expected_data, is_persistent=False): + token, expected_response_type, expected_shadow_response_payload = expected_data + + self.shadow_manager_mock.basicShadowSubscribe.assert_called_once_with(DUMMY_THING_NAME, operation_type, + self.device_shadow_handler.generalCallback) + self.shadow_manager_mock.basicShadowPublish.\ + assert_called_once_with(DUMMY_THING_NAME, + 
operation_type, + self._create_simple_payload(token, version=None).decode("utf-8")) + self.shadow_callback.assert_called_once_with(expected_shadow_response_payload, expected_response_type, token) + if not is_persistent: + self.shadow_manager_mock.basicShadowUnsubscribe.assert_called_once_with(DUMMY_THING_NAME, operation_type) + + def _create_fake_shadow_response(self, topic, payload): + response = MQTTMessage() + response.topic = topic + response.payload = payload + return response + + def _create_simple_payload(self, token, version): + payload_object = dict() + if token is not None: + payload_object["clientToken"] = token + if version is not None: + payload_object["version"] = version + return json.dumps(payload_object).encode("utf-8") diff --git a/test/core/shadow/test_shadow_manager.py b/test/core/shadow/test_shadow_manager.py new file mode 100644 index 0000000..f99bdb5 --- /dev/null +++ b/test/core/shadow/test_shadow_manager.py @@ -0,0 +1,83 @@ +from AWSIoTPythonSDK.core.protocol.mqtt_core import MqttCore +from AWSIoTPythonSDK.core.shadow.shadowManager import shadowManager +try: + from mock import MagicMock +except: + from unittest.mock import MagicMock +try: + from mock import NonCallableMagicMock +except: + from unittest.mock import NonCallableMagicMock +try: + from mock import call +except: + from unittest.mock import call +import pytest + + +DUMMY_SHADOW_NAME = "CoolShadow" +DUMMY_PAYLOAD = "{}" + +OP_SHADOW_GET = "get" +OP_SHADOW_UPDATE = "update" +OP_SHADOW_DELETE = "delete" +OP_SHADOW_DELTA = "delta" +OP_SHADOW_TROUBLE_MAKER = "not_a_valid_shadow_aciton_name" + +DUMMY_SHADOW_TOPIC_PREFIX = "$aws/things/" + DUMMY_SHADOW_NAME + "/shadow/" +DUMMY_SHADOW_TOPIC_GET = DUMMY_SHADOW_TOPIC_PREFIX + "get" +DUMMY_SHADOW_TOPIC_GET_ACCEPTED = DUMMY_SHADOW_TOPIC_GET + "/accepted" +DUMMY_SHADOW_TOPIC_GET_REJECTED = DUMMY_SHADOW_TOPIC_GET + "/rejected" +DUMMY_SHADOW_TOPIC_UPDATE = DUMMY_SHADOW_TOPIC_PREFIX + "update" +DUMMY_SHADOW_TOPIC_UPDATE_ACCEPTED = 
DUMMY_SHADOW_TOPIC_UPDATE + "/accepted" +DUMMY_SHADOW_TOPIC_UPDATE_REJECTED = DUMMY_SHADOW_TOPIC_UPDATE + "/rejected" +DUMMY_SHADOW_TOPIC_UPDATE_DELTA = DUMMY_SHADOW_TOPIC_UPDATE + "/delta" +DUMMY_SHADOW_TOPIC_DELETE = DUMMY_SHADOW_TOPIC_PREFIX + "delete" +DUMMY_SHADOW_TOPIC_DELETE_ACCEPTED = DUMMY_SHADOW_TOPIC_DELETE + "/accepted" +DUMMY_SHADOW_TOPIC_DELETE_REJECTED = DUMMY_SHADOW_TOPIC_DELETE + "/rejected" + + +class TestShadowManager: + + def setup_method(self, test_method): + self.mock_mqtt_core = MagicMock(spec=MqttCore) + self.shadow_manager = shadowManager(self.mock_mqtt_core) + + def test_basic_shadow_publish(self): + self.shadow_manager.basicShadowPublish(DUMMY_SHADOW_NAME, OP_SHADOW_GET, DUMMY_PAYLOAD) + self.shadow_manager.basicShadowPublish(DUMMY_SHADOW_NAME, OP_SHADOW_UPDATE, DUMMY_PAYLOAD) + self.shadow_manager.basicShadowPublish(DUMMY_SHADOW_NAME, OP_SHADOW_DELETE, DUMMY_PAYLOAD) + self.mock_mqtt_core.publish.assert_has_calls([call(DUMMY_SHADOW_TOPIC_GET, DUMMY_PAYLOAD, 0, False), + call(DUMMY_SHADOW_TOPIC_UPDATE, DUMMY_PAYLOAD, 0, False), + call(DUMMY_SHADOW_TOPIC_DELETE, DUMMY_PAYLOAD, 0, False)]) + + def test_basic_shadow_subscribe(self): + callback = NonCallableMagicMock() + self.shadow_manager.basicShadowSubscribe(DUMMY_SHADOW_NAME, OP_SHADOW_GET, callback) + self.shadow_manager.basicShadowSubscribe(DUMMY_SHADOW_NAME, OP_SHADOW_UPDATE, callback) + self.shadow_manager.basicShadowSubscribe(DUMMY_SHADOW_NAME, OP_SHADOW_DELETE, callback) + self.shadow_manager.basicShadowSubscribe(DUMMY_SHADOW_NAME, OP_SHADOW_DELTA, callback) + self.mock_mqtt_core.subscribe.assert_has_calls([call(DUMMY_SHADOW_TOPIC_GET_ACCEPTED, 0, callback), + call(DUMMY_SHADOW_TOPIC_GET_REJECTED, 0, callback), + call(DUMMY_SHADOW_TOPIC_UPDATE_ACCEPTED, 0, callback), + call(DUMMY_SHADOW_TOPIC_UPDATE_REJECTED, 0, callback), + call(DUMMY_SHADOW_TOPIC_DELETE_ACCEPTED, 0, callback), + call(DUMMY_SHADOW_TOPIC_DELETE_REJECTED, 0, callback), + call(DUMMY_SHADOW_TOPIC_UPDATE_DELTA, 0, 
callback)]) + + def test_basic_shadow_unsubscribe(self): + self.shadow_manager.basicShadowUnsubscribe(DUMMY_SHADOW_NAME, OP_SHADOW_GET) + self.shadow_manager.basicShadowUnsubscribe(DUMMY_SHADOW_NAME, OP_SHADOW_UPDATE) + self.shadow_manager.basicShadowUnsubscribe(DUMMY_SHADOW_NAME, OP_SHADOW_DELETE) + self.shadow_manager.basicShadowUnsubscribe(DUMMY_SHADOW_NAME, OP_SHADOW_DELTA) + self.mock_mqtt_core.unsubscribe.assert_has_calls([call(DUMMY_SHADOW_TOPIC_GET_ACCEPTED), + call(DUMMY_SHADOW_TOPIC_GET_REJECTED), + call(DUMMY_SHADOW_TOPIC_UPDATE_ACCEPTED), + call(DUMMY_SHADOW_TOPIC_UPDATE_REJECTED), + call(DUMMY_SHADOW_TOPIC_DELETE_ACCEPTED), + call(DUMMY_SHADOW_TOPIC_DELETE_REJECTED), + call(DUMMY_SHADOW_TOPIC_UPDATE_DELTA)]) + + def test_unsupported_shadow_action_name(self): + with pytest.raises(TypeError): + self.shadow_manager.basicShadowUnsubscribe(DUMMY_SHADOW_NAME, OP_SHADOW_TROUBLE_MAKER) diff --git a/test/core/util/__init__.py b/test/core/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/util/test_providers.py b/test/core/util/test_providers.py new file mode 100644 index 0000000..0515790 --- /dev/null +++ b/test/core/util/test_providers.py @@ -0,0 +1,46 @@ +from AWSIoTPythonSDK.core.util.providers import CertificateCredentialsProvider +from AWSIoTPythonSDK.core.util.providers import IAMCredentialsProvider +from AWSIoTPythonSDK.core.util.providers import EndpointProvider + + +DUMMY_PATH = "/dummy/path/" +DUMMY_CERT_PATH = DUMMY_PATH + "cert.pem" +DUMMY_CA_PATH = DUMMY_PATH + "ca.crt" +DUMMY_KEY_PATH = DUMMY_PATH + "key.pem" +DUMMY_ACCESS_KEY_ID = "AccessKey" +DUMMY_SECRET_KEY = "SecretKey" +DUMMY_SESSION_TOKEN = "SessionToken" +DUMMY_HOST = "dummy.host.com" +DUMMY_PORT = 8888 + + +class TestProviders: + + def setup_method(self, test_method): + self.certificate_credentials_provider = CertificateCredentialsProvider() + self.iam_credentials_provider = IAMCredentialsProvider() + self.endpoint_provider = EndpointProvider() + + def 
test_certificate_credentials_provider(self): + self.certificate_credentials_provider.set_ca_path(DUMMY_CA_PATH) + self.certificate_credentials_provider.set_cert_path(DUMMY_CERT_PATH) + self.certificate_credentials_provider.set_key_path(DUMMY_KEY_PATH) + assert self.certificate_credentials_provider.get_ca_path() == DUMMY_CA_PATH + assert self.certificate_credentials_provider.get_cert_path() == DUMMY_CERT_PATH + assert self.certificate_credentials_provider.get_key_path() == DUMMY_KEY_PATH + + def test_iam_credentials_provider(self): + self.iam_credentials_provider.set_ca_path(DUMMY_CA_PATH) + self.iam_credentials_provider.set_access_key_id(DUMMY_ACCESS_KEY_ID) + self.iam_credentials_provider.set_secret_access_key(DUMMY_SECRET_KEY) + self.iam_credentials_provider.set_session_token(DUMMY_SESSION_TOKEN) + assert self.iam_credentials_provider.get_ca_path() == DUMMY_CA_PATH + assert self.iam_credentials_provider.get_access_key_id() == DUMMY_ACCESS_KEY_ID + assert self.iam_credentials_provider.get_secret_access_key() == DUMMY_SECRET_KEY + assert self.iam_credentials_provider.get_session_token() == DUMMY_SESSION_TOKEN + + def test_endpoint_provider(self): + self.endpoint_provider.set_host(DUMMY_HOST) + self.endpoint_provider.set_port(DUMMY_PORT) + assert self.endpoint_provider.get_host() == DUMMY_HOST + assert self.endpoint_provider.get_port() == DUMMY_PORT diff --git a/test/sdk_mock/__init__.py b/test/sdk_mock/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/test/sdk_mock/mockAWSIoTPythonSDK.py b/test/sdk_mock/mockAWSIoTPythonSDK.py new file mode 100755 index 0000000..b362570 --- /dev/null +++ b/test/sdk_mock/mockAWSIoTPythonSDK.py @@ -0,0 +1,34 @@ +import sys +import mockMQTTCore +import mockMQTTCoreQuiet +from AWSIoTPythonSDK import MQTTLib +import AWSIoTPythonSDK.core.shadow.shadowManager as shadowManager + +class mockAWSIoTMQTTClient(MQTTLib.AWSIoTMQTTClient): + def __init__(self, clientID, protocolType, useWebsocket=False, cleanSession=True): + 
self._mqttCore = mockMQTTCore.mockMQTTCore(clientID, cleanSession, protocolType, useWebsocket) + +class mockAWSIoTMQTTClientWithSubRecords(MQTTLib.AWSIoTMQTTClient): + def __init__(self, clientID, protocolType, useWebsocket=False, cleanSession=True): + self._mqttCore = mockMQTTCore.mockMQTTCoreWithSubRecords(clientID, cleanSession, protocolType, useWebsocket) + +class mockAWSIoTMQTTClientQuiet(MQTTLib.AWSIoTMQTTClient): + def __init__(self, clientID, protocolType, useWebsocket=False, cleanSession=True): + self._mqttCore = mockMQTTCoreQuiet.mockMQTTCoreQuiet(clientID, cleanSession, protocolType, useWebsocket) + +class mockAWSIoTMQTTClientQuietWithSubRecords(MQTTLib.AWSIoTMQTTClient): + def __init__(self, clientID, protocolType, useWebsocket=False, cleanSession=True): + self._mqttCore = mockMQTTCoreQuiet.mockMQTTCoreQuietWithSubRecords(clientID, cleanSession, protocolType, useWebsocket) + +class mockAWSIoTMQTTShadowClient(MQTTLib.AWSIoTMQTTShadowClient): + def __init__(self, clientID, protocolType, useWebsocket=False, cleanSession=True): + # AWSIOTMQTTClient instance + self._AWSIoTMQTTClient = mockAWSIoTMQTTClientQuiet(clientID, protocolType, useWebsocket, cleanSession) + # Configure it to disable offline Publish Queueing + self._AWSIoTMQTTClient.configureOfflinePublishQueueing(0) + self._AWSIoTMQTTClient.configureDrainingFrequency(10) + # Now retrieve the configured mqttCore and init a shadowManager instance + self._shadowManager = shadowManager.shadowManager(self._AWSIoTMQTTClient._mqttCore) + + + diff --git a/test/sdk_mock/mockMQTTCore.py b/test/sdk_mock/mockMQTTCore.py new file mode 100755 index 0000000..e8c61b0 --- /dev/null +++ b/test/sdk_mock/mockMQTTCore.py @@ -0,0 +1,17 @@ +import sys +import mockPahoClient +import AWSIoTPythonSDK.core.protocol.mqttCore as mqttCore + +class mockMQTTCore(mqttCore.mqttCore): + def createPahoClient(self, clientID, cleanSession, userdata, protocol, useWebsocket): + return mockPahoClient.mockPahoClient(clientID, cleanSession, 
userdata, protocol, useWebsocket) + + def setReturnTupleForPahoClient(self, srcReturnTuple): + self._pahoClient.setReturnTuple(srcReturnTuple) + +class mockMQTTCoreWithSubRecords(mockMQTTCore): + def reinitSubscribePool(self): + self._subscribePoolRecords = dict() + + def subscribe(self, topic, qos, callback): + self._subscribePoolRecords[topic] = qos diff --git a/test/sdk_mock/mockMQTTCoreQuiet.py b/test/sdk_mock/mockMQTTCoreQuiet.py new file mode 100755 index 0000000..bdb6faa --- /dev/null +++ b/test/sdk_mock/mockMQTTCoreQuiet.py @@ -0,0 +1,34 @@ +import sys +import mockPahoClient +import AWSIoTPythonSDK.core.protocol.mqttCore as mqttCore + +class mockMQTTCoreQuiet(mqttCore.mqttCore): + def createPahoClient(self, clientID, cleanSession, userdata, protocol, useWebsocket): + return mockPahoClient.mockPahoClient(clientID, cleanSession, userdata, protocol, useWebsocket) + + def setReturnTupleForPahoClient(self, srcReturnTuple): + self._pahoClient.setReturnTuple(srcReturnTuple) + + def connect(self, keepAliveInterval): + pass + + def disconnect(self): + pass + + def publish(self, topic, payload, qos, retain): + pass + + def subscribe(self, topic, qos, callback): + pass + + def unsubscribe(self, topic): + pass + +class mockMQTTCoreQuietWithSubRecords(mockMQTTCoreQuiet): + + def reinitSubscribePool(self): + self._subscribePoolRecords = dict() + + def subscribe(self, topic, qos, callback): + self._subscribePoolRecords[topic] = qos + diff --git a/test/sdk_mock/mockMessage.py b/test/sdk_mock/mockMessage.py new file mode 100755 index 0000000..61c733a --- /dev/null +++ b/test/sdk_mock/mockMessage.py @@ -0,0 +1,7 @@ +class mockMessage: + topic = None + payload = None + + def __init__(self, srcTopic, srcPayload): + self.topic = srcTopic + self.payload = srcPayload diff --git a/test/sdk_mock/mockPahoClient.py b/test/sdk_mock/mockPahoClient.py new file mode 100755 index 0000000..8bcfda6 --- /dev/null +++ b/test/sdk_mock/mockPahoClient.py @@ -0,0 +1,49 @@ +import sys +import 
AWSIoTPythonSDK.core.protocol.paho.client as mqtt +import logging + +class mockPahoClient(mqtt.Client): + _log = logging.getLogger(__name__) + _returnTuple = (-1, -1) + # Callback handlers + on_connect = None + on_disconnect = None + on_message = None + on_publish = None + on_subsribe = None + on_unsubscribe = None + + def setReturnTuple(self, srcTuple): + self._returnTuple = srcTuple + + # Tool function + def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tls_version=None): + self._log.debug("tls_set called.") + + def loop_start(self): + self._log.debug("Socket thread started.") + + def loop_stop(self): + self._log.debug("Socket thread stopped.") + + def message_callback_add(self, sub, callback): + self._log.debug("Add a user callback. Topic: " + str(sub)) + + # MQTT API + def connect(self, host, port, keepalive): + self._log.debug("Connect called.") + + def disconnect(self): + self._log.debug("Disconnect called.") + + def publish(self, topic, payload, qos, retain): + self._log.debug("Publish called.") + return self._returnTuple + + def subscribe(self, topic, qos): + self._log.debug("Subscribe called.") + return self._returnTuple + + def unsubscribe(self, topic): + self._log.debug("Unsubscribe called.") + return self._returnTuple diff --git a/test/sdk_mock/mockSSLSocket.py b/test/sdk_mock/mockSSLSocket.py new file mode 100755 index 0000000..6bf953e --- /dev/null +++ b/test/sdk_mock/mockSSLSocket.py @@ -0,0 +1,104 @@ +import socket +import ssl + +class mockSSLSocket: + def __init__(self): + self._readBuffer = bytearray() + self._writeBuffer = bytearray() + self._isClosed = False + self._isFragmented = False + self._fragmentDoneThrowError = False + self._currentFragments = bytearray() + self._fragments = list() + self._flipWriteError = False + self._flipWriteErrorCount = 0 + + # TestHelper APIs + def refreshReadBuffer(self, bytesToLoad): + self._readBuffer = bytesToLoad + + def reInit(self): + self._readBuffer = bytearray() + 
self._writeBuffer = bytearray() + self._isClosed = False + self._isFragmented = False + self._fragmentDoneThrowError = False + self._currentFragments = bytearray() + self._fragments = list() + self._flipWriteError = False + self._flipWriteErrorCount = 0 + + def getReaderBuffer(self): + return self._readBuffer + + def getWriteBuffer(self): + return self._writeBuffer + + def addReadBufferFragment(self, fragmentElement): + self._fragments.append(fragmentElement) + + def setReadFragmented(self): + self._isFragmented = True + + def setFlipWriteError(self): + self._flipWriteError = True + self._flipWriteErrorCount = 0 + + def loadFirstFragmented(self): + self._currentFragments = self._fragments.pop(0) + + # Public APIs + # Should return bytes, not string + def read(self, numberOfBytes): + if not self._isFragmented: # Read a lot, then nothing + if len(self._readBuffer) == 0: + raise socket.error(ssl.SSL_ERROR_WANT_READ, "End of read buffer") + # If we have enough data for the requested amount, give them out + if numberOfBytes <= len(self._readBuffer): + ret = self._readBuffer[0:numberOfBytes] + self._readBuffer = self._readBuffer[numberOfBytes:] + else: + ret = self._readBuffer + self._readBuffer = self._readBuffer[len(self._readBuffer):] # Empty + return ret + else: # Read 1 fragment until it is empty, then throw error, then load in next + if self._fragmentDoneThrowError and len(self._fragments) > 0: + self._currentFragments = self._fragments.pop(0) # Load in next fragment + self._fragmentDoneThrowError = False # Reset ThrowError flag + raise socket.error(ssl.SSL_ERROR_WANT_READ, "Not ready for read op") + # If we have enough data for the requested amount in the current fragment, give them out + ret = bytearray() + if numberOfBytes <= len(self._currentFragments): + ret = self._currentFragments[0:numberOfBytes] + self._currentFragments = self._currentFragments[numberOfBytes:] + if len(self._currentFragments) == 0: + self._fragmentDoneThrowError = True # Will throw error 
next time + else: + ret = self._currentFragments + self._currentFragments = self._currentFragments[len(self._currentFragments):] # Empty + self._fragmentDoneThrowError = True + return ret + + # Should write bytes, not string + def write(self, bytesToWrite): + if self._flipWriteError: + if self._flipWriteErrorCount % 2 == 1: + self._writeBuffer += bytesToWrite # bytesToWrite should always be in 'bytes' type + self._flipWriteErrorCount += 1 + return len(bytesToWrite) + else: + self._flipWriteErrorCount += 1 + raise socket.error(ssl.SSL_ERROR_WANT_WRITE, "Not ready for write op") + else: + self._writeBuffer += bytesToWrite # bytesToWrite should always be in 'bytes' type + return len(bytesToWrite) + + def close(self): + self._isClosed = True + + + + + + + diff --git a/test/sdk_mock/mockSecuredWebsocketCore.py b/test/sdk_mock/mockSecuredWebsocketCore.py new file mode 100755 index 0000000..4f2efe9 --- /dev/null +++ b/test/sdk_mock/mockSecuredWebsocketCore.py @@ -0,0 +1,35 @@ +from test.sdk_mock.mockSigV4Core import mockSigV4Core +from AWSIoTPythonSDK.core.protocol.connection.cores import SecuredWebSocketCore + + +class mockSecuredWebsocketCoreNoRealHandshake(SecuredWebSocketCore): + def _createSigV4Core(self): + ret = mockSigV4Core() + ret.setNoEnvVar(False) # Always has Env Var + return ret + + def _handShake(self, hostAddress, portNumber): # Override to pass handshake + pass + + def _generateMaskKey(self): + return bytearray(str("1234"), 'utf-8') # Arbitrary mask key for testing + + +class MockSecuredWebSocketCoreNoSocketIO(SecuredWebSocketCore): + def _createSigV4Core(self): + ret = mockSigV4Core() + ret.setNoEnvVar(False) # Always has Env Var + return ret + + def _generateMaskKey(self): + return bytearray(str("1234"), 'utf-8') # Arbitrary mask key for testing + + def _getTimeoutSec(self): + return 3 # 3 sec to time out from waiting for handshake response for testing + + +class MockSecuredWebSocketCoreWithRealHandshake(SecuredWebSocketCore): + def 
_createSigV4Core(self): + ret = mockSigV4Core() + ret.setNoEnvVar(False) # Always has Env Var + return ret diff --git a/test/sdk_mock/mockSigV4Core.py b/test/sdk_mock/mockSigV4Core.py new file mode 100755 index 0000000..142b27f --- /dev/null +++ b/test/sdk_mock/mockSigV4Core.py @@ -0,0 +1,17 @@ +from AWSIoTPythonSDK.core.protocol.connection.cores import SigV4Core + + +class mockSigV4Core(SigV4Core): + _forceNoEnvVar = False + + def setNoEnvVar(self, srcVal): + self._forceNoEnvVar = srcVal + + def _checkKeyInEnv(self): # Simulate no Env Var + if self._forceNoEnvVar: + return dict() # Return empty list + else: + ret = dict() + ret["aws_access_key_id"] = "blablablaID" + ret["aws_secret_access_key"] = "blablablaSecret" + return ret diff --git a/test/test_mqtt_lib.py b/test/test_mqtt_lib.py new file mode 100644 index 0000000..b74375b --- /dev/null +++ b/test/test_mqtt_lib.py @@ -0,0 +1,304 @@ +from AWSIoTPythonSDK.core.protocol.mqtt_core import MqttCore +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTShadowClient +from AWSIoTPythonSDK.MQTTLib import DROP_NEWEST +try: + from mock import patch + from mock import MagicMock +except: + from unittest.mock import patch + from unittest.mock import MagicMock + + +PATCH_MODULE_LOCATION = "AWSIoTPythonSDK.MQTTLib." 
+CLIENT_ID = "DefaultClientId" +SHADOW_CLIENT_ID = "DefaultShadowClientId" +DUMMY_HOST = "dummy.host" +PORT_443 = 443 +PORT_8883 = 8883 +DEFAULT_KEEPALIVE_SEC = 600 +DUMMY_TOPIC = "dummy/topic" +DUMMY_PAYLOAD = "dummy/payload" +DUMMY_QOS = 1 +DUMMY_AWS_ACCESS_KEY_ID = "DummyKeyId" +DUMMY_AWS_SECRET_KEY = "SecretKey" +DUMMY_AWS_TOKEN = "Token" +DUMMY_CA_PATH = "path/to/ca" +DUMMY_CERT_PATH = "path/to/cert" +DUMMY_KEY_PATH = "path/to/key" +DUMMY_BASE_RECONNECT_BACKOFF_SEC = 1 +DUMMY_MAX_RECONNECT_BACKOFF_SEC = 32 +DUMMY_STABLE_CONNECTION_SEC = 16 +DUMMY_QUEUE_SIZE = 100 +DUMMY_DRAINING_FREQUENCY = 2 +DUMMY_TIMEOUT_SEC = 10 +DUMMY_USER_NAME = "UserName" +DUMMY_PASSWORD = "Password" + + +class TestMqttLibShadowClient: + + def setup_method(self, test_method): + self._use_mock_mqtt_core() + + def _use_mock_mqtt_core(self): + self.mqtt_core_patcher = patch(PATCH_MODULE_LOCATION + "MqttCore", spec=MqttCore) + self.mock_mqtt_core_constructor = self.mqtt_core_patcher.start() + self.mqtt_core_mock = MagicMock() + self.mock_mqtt_core_constructor.return_value = self.mqtt_core_mock + self.iot_mqtt_shadow_client = AWSIoTMQTTShadowClient(SHADOW_CLIENT_ID) + + def teardown_method(self, test_method): + self.mqtt_core_patcher.stop() + + def test_iot_mqtt_shadow_client_with_provided_mqtt_client(self): + mock_iot_mqtt_client = MagicMock() + iot_mqtt_shadow_client_with_provided_mqtt_client = AWSIoTMQTTShadowClient(SHADOW_CLIENT_ID, awsIoTMQTTClient=mock_iot_mqtt_client) + assert mock_iot_mqtt_client.configureOfflinePublishQueueing.called is False + + def test_iot_mqtt_shadow_client_connect_default_keepalive(self): + self.iot_mqtt_shadow_client.connect() + self.mqtt_core_mock.connect.assert_called_once_with(DEFAULT_KEEPALIVE_SEC) + + def test_iot_mqtt_shadow_client_auto_enable_when_use_cert_over_443(self): + self.mqtt_core_mock.use_wss.return_value = False + self.iot_mqtt_shadow_client.configureEndpoint(hostName=DUMMY_HOST, portNumber=PORT_443) + 
self.mqtt_core_mock.configure_alpn_protocols.assert_called_once() + + def test_iot_mqtt_shadow_client_alpn_auto_disable_when_use_wss(self): + self.mqtt_core_mock.use_wss.return_value = True + self.iot_mqtt_shadow_client.configureEndpoint(hostName=DUMMY_HOST, portNumber=PORT_443) + assert self.mqtt_core_mock.configure_alpn_protocols.called is False + + def test_iot_mqtt_shadow_client_alpn_auto_disable_when_use_cert_over_8883(self): + self.mqtt_core_mock.use_wss.return_value = False + self.iot_mqtt_shadow_client.configureEndpoint(hostName=DUMMY_HOST, portNumber=PORT_8883) + assert self.mqtt_core_mock.configure_alpn_protocols.called is False + + def test_iot_mqtt_shadow_client_clear_last_will(self): + self.iot_mqtt_shadow_client.clearLastWill() + self.mqtt_core_mock.clear_last_will.assert_called_once() + + def test_iot_mqtt_shadow_client_configure_endpoint(self): + self.iot_mqtt_shadow_client.configureEndpoint(DUMMY_HOST, PORT_8883) + self.mqtt_core_mock.configure_endpoint.assert_called_once() + + def test_iot_mqtt_shadow_client_configure_iam_credentials(self): + self.iot_mqtt_shadow_client.configureIAMCredentials(DUMMY_AWS_ACCESS_KEY_ID, DUMMY_AWS_SECRET_KEY, DUMMY_AWS_TOKEN) + self.mqtt_core_mock.configure_iam_credentials.assert_called_once() + + def test_iot_mqtt_shadowclient_configure_credentials(self): + self.iot_mqtt_shadow_client.configureCredentials(DUMMY_CA_PATH, DUMMY_KEY_PATH, DUMMY_CERT_PATH) + self.mqtt_core_mock.configure_cert_credentials.assert_called_once() + + def test_iot_mqtt_shadow_client_configure_auto_reconnect_backoff(self): + self.iot_mqtt_shadow_client.configureAutoReconnectBackoffTime(DUMMY_BASE_RECONNECT_BACKOFF_SEC, + DUMMY_MAX_RECONNECT_BACKOFF_SEC, + DUMMY_STABLE_CONNECTION_SEC) + self.mqtt_core_mock.configure_reconnect_back_off.assert_called_once_with(DUMMY_BASE_RECONNECT_BACKOFF_SEC, + DUMMY_MAX_RECONNECT_BACKOFF_SEC, + DUMMY_STABLE_CONNECTION_SEC) + + def test_iot_mqtt_shadow_client_configure_offline_publish_queueing(self): + # This 
configurable is done at object initialization. We do not allow customers to configure this. + self.mqtt_core_mock.configure_offline_requests_queue.assert_called_once_with(0, DROP_NEWEST) # Disabled + + def test_iot_mqtt_client_configure_draining_frequency(self): + # This configurable is done at object initialization. We do not allow customers to configure this. + # Since queuing is disabled, draining interval configuration is meaningless. + # "10" is just a placeholder value in the internal implementation. + self.mqtt_core_mock.configure_draining_interval_sec.assert_called_once_with(1/float(10)) + + def test_iot_mqtt_client_configure_connect_disconnect_timeout(self): + self.iot_mqtt_shadow_client.configureConnectDisconnectTimeout(DUMMY_TIMEOUT_SEC) + self.mqtt_core_mock.configure_connect_disconnect_timeout_sec.assert_called_once_with(DUMMY_TIMEOUT_SEC) + + def test_iot_mqtt_client_configure_mqtt_operation_timeout(self): + self.iot_mqtt_shadow_client.configureMQTTOperationTimeout(DUMMY_TIMEOUT_SEC) + self.mqtt_core_mock.configure_operation_timeout_sec.assert_called_once_with(DUMMY_TIMEOUT_SEC) + + def test_iot_mqtt_client_configure_user_name_password(self): + self.iot_mqtt_shadow_client.configureUsernamePassword(DUMMY_USER_NAME, DUMMY_PASSWORD) + self.mqtt_core_mock.configure_username_password.assert_called_once_with(DUMMY_USER_NAME, DUMMY_PASSWORD) + + def test_iot_mqtt_client_enable_metrics_collection(self): + self.iot_mqtt_shadow_client.enableMetricsCollection() + self.mqtt_core_mock.enable_metrics_collection.assert_called_once() + + def test_iot_mqtt_client_disable_metrics_collection(self): + self.iot_mqtt_shadow_client.disableMetricsCollection() + self.mqtt_core_mock.disable_metrics_collection.assert_called_once() + + def test_iot_mqtt_client_callback_registration_upon_connect(self): + fake_on_online_callback = MagicMock() + fake_on_offline_callback = MagicMock() + + self.iot_mqtt_shadow_client.onOnline = fake_on_online_callback + 
self.iot_mqtt_shadow_client.onOffline = fake_on_offline_callback + # `onMessage` is used internally by the SDK. We do not expose this callback configurable to the customer + + self.iot_mqtt_shadow_client.connect() + + assert self.mqtt_core_mock.on_online == fake_on_online_callback + assert self.mqtt_core_mock.on_offline == fake_on_offline_callback + self.mqtt_core_mock.connect.assert_called_once() + + def test_iot_mqtt_client_disconnect(self): + self.iot_mqtt_shadow_client.disconnect() + self.mqtt_core_mock.disconnect.assert_called_once() + + +class TestMqttLibMqttClient: + + def setup_method(self, test_method): + self._use_mock_mqtt_core() + + def _use_mock_mqtt_core(self): + self.mqtt_core_patcher = patch(PATCH_MODULE_LOCATION + "MqttCore", spec=MqttCore) + self.mock_mqtt_core_constructor = self.mqtt_core_patcher.start() + self.mqtt_core_mock = MagicMock() + self.mock_mqtt_core_constructor.return_value = self.mqtt_core_mock + self.iot_mqtt_client = AWSIoTMQTTClient(CLIENT_ID) + + def teardown_method(self, test_method): + self.mqtt_core_patcher.stop() + + def test_iot_mqtt_client_connect_default_keepalive(self): + self.iot_mqtt_client.connect() + self.mqtt_core_mock.connect.assert_called_once_with(DEFAULT_KEEPALIVE_SEC) + + def test_iot_mqtt_client_connect_async_default_keepalive(self): + self.iot_mqtt_client.connectAsync() + self.mqtt_core_mock.connect_async.assert_called_once_with(DEFAULT_KEEPALIVE_SEC, None) + + def test_iot_mqtt_client_alpn_auto_enable_when_use_cert_over_443(self): + self.mqtt_core_mock.use_wss.return_value = False + self.iot_mqtt_client.configureEndpoint(hostName=DUMMY_HOST, portNumber=PORT_443) + self.mqtt_core_mock.configure_alpn_protocols.assert_called_once() + + def test_iot_mqtt_client_alpn_auto_disable_when_use_wss(self): + self.mqtt_core_mock.use_wss.return_value = True + self.iot_mqtt_client.configureEndpoint(hostName=DUMMY_HOST, portNumber=PORT_443) + assert self.mqtt_core_mock.configure_alpn_protocols.called is False + + def 
test_iot_mqtt_client_alpn_auto_disable_when_use_cert_over_8883(self): + self.mqtt_core_mock.use_wss.return_value = False + self.iot_mqtt_client.configureEndpoint(hostName=DUMMY_HOST, portNumber=PORT_8883) + assert self.mqtt_core_mock.configure_alpn_protocols.called is False + + def test_iot_mqtt_client_configure_last_will(self): + self.iot_mqtt_client.configureLastWill(topic=DUMMY_TOPIC, payload=DUMMY_PAYLOAD, QoS=DUMMY_QOS) + self.mqtt_core_mock.configure_last_will.assert_called_once_with(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS, False) + + def test_iot_mqtt_client_clear_last_will(self): + self.iot_mqtt_client.clearLastWill() + self.mqtt_core_mock.clear_last_will.assert_called_once() + + def test_iot_mqtt_client_configure_endpoint(self): + self.iot_mqtt_client.configureEndpoint(DUMMY_HOST, PORT_8883) + self.mqtt_core_mock.configure_endpoint.assert_called_once() + + def test_iot_mqtt_client_configure_iam_credentials(self): + self.iot_mqtt_client.configureIAMCredentials(DUMMY_AWS_ACCESS_KEY_ID, DUMMY_AWS_SECRET_KEY, DUMMY_AWS_TOKEN) + self.mqtt_core_mock.configure_iam_credentials.assert_called_once() + + def test_iot_mqtt_client_configure_credentials(self): + self.iot_mqtt_client.configureCredentials(DUMMY_CA_PATH, DUMMY_KEY_PATH, DUMMY_CERT_PATH) + self.mqtt_core_mock.configure_cert_credentials.assert_called_once() + + def test_iot_mqtt_client_configure_auto_reconnect_backoff(self): + self.iot_mqtt_client.configureAutoReconnectBackoffTime(DUMMY_BASE_RECONNECT_BACKOFF_SEC, + DUMMY_MAX_RECONNECT_BACKOFF_SEC, + DUMMY_STABLE_CONNECTION_SEC) + self.mqtt_core_mock.configure_reconnect_back_off.assert_called_once_with(DUMMY_BASE_RECONNECT_BACKOFF_SEC, + DUMMY_MAX_RECONNECT_BACKOFF_SEC, + DUMMY_STABLE_CONNECTION_SEC) + + def test_iot_mqtt_client_configure_offline_publish_queueing(self): + self.iot_mqtt_client.configureOfflinePublishQueueing(DUMMY_QUEUE_SIZE) + self.mqtt_core_mock.configure_offline_requests_queue.assert_called_once_with(DUMMY_QUEUE_SIZE, DROP_NEWEST) + + def 
test_iot_mqtt_client_configure_draining_frequency(self): + self.iot_mqtt_client.configureDrainingFrequency(DUMMY_DRAINING_FREQUENCY) + self.mqtt_core_mock.configure_draining_interval_sec.assert_called_once_with(1/float(DUMMY_DRAINING_FREQUENCY)) + + def test_iot_mqtt_client_configure_connect_disconnect_timeout(self): + self.iot_mqtt_client.configureConnectDisconnectTimeout(DUMMY_TIMEOUT_SEC) + self.mqtt_core_mock.configure_connect_disconnect_timeout_sec.assert_called_once_with(DUMMY_TIMEOUT_SEC) + + def test_iot_mqtt_client_configure_mqtt_operation_timeout(self): + self.iot_mqtt_client.configureMQTTOperationTimeout(DUMMY_TIMEOUT_SEC) + self.mqtt_core_mock.configure_operation_timeout_sec.assert_called_once_with(DUMMY_TIMEOUT_SEC) + + def test_iot_mqtt_client_configure_user_name_password(self): + self.iot_mqtt_client.configureUsernamePassword(DUMMY_USER_NAME, DUMMY_PASSWORD) + self.mqtt_core_mock.configure_username_password.assert_called_once_with(DUMMY_USER_NAME, DUMMY_PASSWORD) + + def test_iot_mqtt_client_enable_metrics_collection(self): + self.iot_mqtt_client.enableMetricsCollection() + self.mqtt_core_mock.enable_metrics_collection.assert_called_once() + + def test_iot_mqtt_client_disable_metrics_collection(self): + self.iot_mqtt_client.disableMetricsCollection() + self.mqtt_core_mock.disable_metrics_collection.assert_called_once() + + def test_iot_mqtt_client_callback_registration_upon_connect(self): + fake_on_online_callback = MagicMock() + fake_on_offline_callback = MagicMock() + fake_on_message_callback = MagicMock() + + self.iot_mqtt_client.onOnline = fake_on_online_callback + self.iot_mqtt_client.onOffline = fake_on_offline_callback + self.iot_mqtt_client.onMessage = fake_on_message_callback + + self.iot_mqtt_client.connect() + + assert self.mqtt_core_mock.on_online == fake_on_online_callback + assert self.mqtt_core_mock.on_offline == fake_on_offline_callback + assert self.mqtt_core_mock.on_message == fake_on_message_callback + 
self.mqtt_core_mock.connect.assert_called_once() + + def test_iot_mqtt_client_disconnect(self): + self.iot_mqtt_client.disconnect() + self.mqtt_core_mock.disconnect.assert_called_once() + + def test_iot_mqtt_client_publish(self): + self.iot_mqtt_client.publish(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS) + self.mqtt_core_mock.publish.assert_called_once_with(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS, False) + + def test_iot_mqtt_client_subscribe(self): + message_callback = MagicMock() + self.iot_mqtt_client.subscribe(DUMMY_TOPIC, DUMMY_QOS, message_callback) + self.mqtt_core_mock.subscribe.assert_called_once_with(DUMMY_TOPIC, DUMMY_QOS, message_callback) + + def test_iot_mqtt_client_unsubscribe(self): + self.iot_mqtt_client.unsubscribe(DUMMY_TOPIC) + self.mqtt_core_mock.unsubscribe.assert_called_once_with(DUMMY_TOPIC) + + def test_iot_mqtt_client_connect_async(self): + connack_callback = MagicMock() + self.iot_mqtt_client.connectAsync(ackCallback=connack_callback) + self.mqtt_core_mock.connect_async.assert_called_once_with(DEFAULT_KEEPALIVE_SEC, connack_callback) + + def test_iot_mqtt_client_disconnect_async(self): + disconnect_callback = MagicMock() + self.iot_mqtt_client.disconnectAsync(ackCallback=disconnect_callback) + self.mqtt_core_mock.disconnect_async.assert_called_once_with(disconnect_callback) + + def test_iot_mqtt_client_publish_async(self): + puback_callback = MagicMock() + self.iot_mqtt_client.publishAsync(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS, puback_callback) + self.mqtt_core_mock.publish_async.assert_called_once_with(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS, + False, puback_callback) + + def test_iot_mqtt_client_subscribe_async(self): + suback_callback = MagicMock() + message_callback = MagicMock() + self.iot_mqtt_client.subscribeAsync(DUMMY_TOPIC, DUMMY_QOS, suback_callback, message_callback) + self.mqtt_core_mock.subscribe_async.assert_called_once_with(DUMMY_TOPIC, DUMMY_QOS, + suback_callback, message_callback) + + def 
test_iot_mqtt_client_unsubscribe_async(self): + unsuback_callback = MagicMock() + self.iot_mqtt_client.unsubscribeAsync(DUMMY_TOPIC, unsuback_callback) + self.mqtt_core_mock.unsubscribe_async.assert_called_once_with(DUMMY_TOPIC, unsuback_callback)