diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 0000000..946fd2f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,81 @@ +--- +name: "🐛 Bug Report" +description: Report a bug +title: "(short issue description)" +labels: [bug, needs-triage] +assignees: [] +body: + - type: textarea + id: description + attributes: + label: Describe the bug + description: What is the problem? A clear and concise description of the bug. + validations: + required: true + - type: checkboxes + id: regression + attributes: + label: Regression Issue + description: What is a regression? If it worked in a previous version but doesn't in the latest version, it's considered a regression. In this case, please provide specific version number in the report. + options: + - label: Select this option if this issue appears to be a regression. + required: false + - type: textarea + id: expected + attributes: + label: Expected Behavior + description: | + What did you expect to happen? + validations: + required: true + - type: textarea + id: current + attributes: + label: Current Behavior + description: | + What actually happened? + + Please include full errors, uncaught exceptions, stack traces, and relevant logs. + If service responses are relevant, please include wire logs. + validations: + required: true + - type: textarea + id: reproduction + attributes: + label: Reproduction Steps + description: | + Provide a self-contained, concise snippet of code that can be used to reproduce the issue. + For more complex issues provide a repo with the smallest sample that reproduces the bug. + + Avoid including business logic or unrelated code, it makes diagnosis more difficult. + The code sample should be an SSCCE. See http://sscce.org/ for details. In short, please provide a code sample that we can copy/paste, run and reproduce. 
+ validations: + required: true + - type: textarea + id: solution + attributes: + label: Possible Solution + description: | + Suggest a fix/reason for the bug + validations: + required: false + - type: textarea + id: context + attributes: + label: Additional Information/Context + description: | + Anything else that might be relevant for troubleshooting this bug. Providing context helps us come up with a solution that is most useful in the real world. + validations: + required: false + - type: input + id: sdk-version + attributes: + label: SDK version used + validations: + required: true + - type: input + id: environment + attributes: + label: Environment details (OS name and version, etc.) + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..fe0acce --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,6 @@ +--- +blank_issues_enabled: false +contact_links: + - name: 💬 General Question + url: https://github.com/aws/aws-iot-device-sdk-python/discussions/categories/q-a + about: Please ask and answer questions as a discussion thread diff --git a/.github/ISSUE_TEMPLATE/documentation.yml b/.github/ISSUE_TEMPLATE/documentation.yml new file mode 100644 index 0000000..7d73869 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation.yml @@ -0,0 +1,23 @@ +--- +name: "📕 Documentation Issue" +description: Report an issue in the API Reference documentation or Developer Guide +title: "(short issue description)" +labels: [documentation, needs-triage] +assignees: [] +body: + - type: textarea + id: description + attributes: + label: Describe the issue + description: A clear and concise description of the issue. + validations: + required: true + + - type: textarea + id: links + attributes: + label: Links + description: | + Include links to affected documentation page(s). 
+ validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml new file mode 100644 index 0000000..60d2431 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -0,0 +1,59 @@ +--- +name: 🚀 Feature Request +description: Suggest an idea for this project +title: "(short issue description)" +labels: [feature-request, needs-triage] +assignees: [] +body: + - type: textarea + id: description + attributes: + label: Describe the feature + description: A clear and concise description of the feature you are proposing. + validations: + required: true + - type: textarea + id: use-case + attributes: + label: Use Case + description: | + Why do you need this feature? For example: "I'm always frustrated when..." + validations: + required: true + - type: textarea + id: solution + attributes: + label: Proposed Solution + description: | + Suggest how to implement the addition or change. Please include prototype/workaround/sketch/reference implementation. + validations: + required: false + - type: textarea + id: other + attributes: + label: Other Information + description: | + Any alternative solutions or features you considered, a more detailed explanation, stack traces, related issues, links for context, etc. + validations: + required: false + - type: checkboxes + id: ack + attributes: + label: Acknowledgements + options: + - label: I may be able to implement this feature request + required: false + - label: This feature might incur a breaking change + required: false + - type: input + id: sdk-version + attributes: + label: SDK version used + validations: + required: true + - type: input + id: environment + attributes: + label: Environment details (OS name and version, etc.) 
+ validations: + required: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..195bf2d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,59 @@ +name: CI + +on: + push: + branches: + - '*' + - '!main' + +env: + RUN: ${{ github.run_id }}-${{ github.run_number }} + AWS_DEFAULT_REGION: us-east-1 + CI_SDK_V1_ROLE: arn:aws:iam::180635532705:role/CI_SDK_V1_ROLE + PACKAGE_NAME: aws-iot-device-sdk-python + AWS_EC2_METADATA_DISABLED: true + +jobs: + unit-tests: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.8' + - name: Unit tests + run: | + python3 setup.py install + pip install pytest + pip install mock + python3 -m pytest test + + integration-tests: + runs-on: ubuntu-latest + permissions: + id-token: write # This is required for requesting the JWT + contents: read # This is required for actions/checkout + strategy: + fail-fast: false + matrix: + test-type: [ MutualAuth, Websocket, ALPN ] + python-version: [ '3.8', '3.13' ] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - uses: aws-actions/configure-aws-credentials@v2 + with: + role-to-assume: ${{ env.CI_SDK_V1_ROLE }} + aws-region: ${{ env.AWS_DEFAULT_REGION }} + - name: Integration tests + run: | + pip install pytest + pip install mock + pip install boto3 + python --version + ./test-integration/run/run.sh ${{ matrix.test-type }} 1000 100 7 diff --git a/.github/workflows/closed-issue-message.yml b/.github/workflows/closed-issue-message.yml new file mode 100644 index 0000000..22bf2a7 --- /dev/null +++ b/.github/workflows/closed-issue-message.yml @@ -0,0 +1,19 @@ +name: Closed Issue Message +on: + issues: + types: [closed] +jobs: + auto_comment: + runs-on: ubuntu-latest + permissions: + issues: write + steps: + - uses: aws-actions/closed-issue-message@v1 + with: + # These 
inputs are both required + repo-token: "${{ secrets.GITHUB_TOKEN }}" + message: | + ### ⚠️COMMENT VISIBILITY WARNING⚠️ + Comments on closed issues are hard for our team to see. + If you need more assistance, please either tag a team member or open a new issue that references this one. + If you wish to keep having a conversation with other community members under this issue feel free to do so. diff --git a/.github/workflows/handle-stale-discussions.yml b/.github/workflows/handle-stale-discussions.yml new file mode 100644 index 0000000..4fbcd70 --- /dev/null +++ b/.github/workflows/handle-stale-discussions.yml @@ -0,0 +1,19 @@ +name: HandleStaleDiscussions +on: + schedule: + - cron: '0 */4 * * *' + discussion_comment: + types: [created] + +jobs: + handle-stale-discussions: + name: Handle stale discussions + runs-on: ubuntu-latest + permissions: + discussions: write + steps: + - name: Stale discussions action + uses: aws-github-ops/handle-stale-discussions@v1 + env: + GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} + \ No newline at end of file diff --git a/.github/workflows/issue-regression-labeler.yml b/.github/workflows/issue-regression-labeler.yml new file mode 100644 index 0000000..bd00071 --- /dev/null +++ b/.github/workflows/issue-regression-labeler.yml @@ -0,0 +1,32 @@ +# Apply potential regression label on issues +name: issue-regression-label +on: + issues: + types: [opened, edited] +jobs: + add-regression-label: + runs-on: ubuntu-latest + permissions: + issues: write + steps: + - name: Fetch template body + id: check_regression + uses: actions/github-script@v7 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TEMPLATE_BODY: ${{ github.event.issue.body }} + with: + script: | + const regressionPattern = /\[x\] Select this option if this issue appears to be a regression\./i; + const template = `${process.env.TEMPLATE_BODY}` + const match = regressionPattern.test(template); + core.setOutput('is_regression', match); + - name: Manage regression label + env: + 
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + if [ "${{ steps.check_regression.outputs.is_regression }}" == "true" ]; then + gh issue edit ${{ github.event.issue.number }} --add-label "potential-regression" -R ${{ github.repository }} + else + gh issue edit ${{ github.event.issue.number }} --remove-label "potential-regression" -R ${{ github.repository }} + fi diff --git a/.github/workflows/stale_issue.yml b/.github/workflows/stale_issue.yml new file mode 100644 index 0000000..cbcc8b4 --- /dev/null +++ b/.github/workflows/stale_issue.yml @@ -0,0 +1,49 @@ +name: "Close stale issues" + +# Controls when the action will run. +on: + schedule: + - cron: "0 0 * * *" + +jobs: + cleanup: + runs-on: ubuntu-latest + name: Stale issue job + permissions: + issues: write + pull-requests: write + steps: + - uses: aws-actions/stale-issue-cleanup@v3 + with: + # Setting messages to an empty string will cause the automation to skip + # that category + ancient-issue-message: Greetings! Sorry to say but this is a very old issue that is probably not getting as much attention as it deserves. We encourage you to try V2 and if you find that this is still a problem, please feel free to open a new issue there. + stale-issue-message: Greetings! It looks like this issue hasn’t been active in longer than a week. Because it has been longer than a week since the last update on this, and in the absence of more information, we will be closing this issue soon. If you find that this is still a problem, please feel free to provide a comment or add an upvote to prevent automatic closure, or if the issue is already closed, please feel free to open a new one, also please try V2 as this might be solved there too. + stale-pr-message: Greetings! It looks like this PR hasn’t been active in longer than a week, add a comment or an upvote to prevent automatic closure, or if the issue is already closed, please feel free to open a new one. 
+ + # These labels are required + stale-issue-label: closing-soon + exempt-issue-label: automation-exempt + stale-pr-label: closing-soon + exempt-pr-label: pr/needs-review + response-requested-label: response-requested + + # Don't set closed-for-staleness label to skip closing very old issues + # regardless of label + closed-for-staleness-label: closed-for-staleness + + # Issue timing + days-before-stale: 7 + days-before-close: 4 + days-before-ancient: 36500 + + # If you don't want to mark a issue as being ancient based on a + # threshold of "upvotes", you can set this here. An "upvote" is + # the total number of +1, heart, hooray, and rocket reactions + # on an issue. + minimum-upvotes-to-exempt: 1 + + repo-token: ${{ secrets.GITHUB_TOKEN }} + loglevel: DEBUG + # Set dry-run to true to not perform label or close actions. + dry-run: false diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b1e0672 --- /dev/null +++ b/.gitignore @@ -0,0 +1,534 @@ + +# Created by https://www.gitignore.io/api/git,c++,cmake,python,visualstudio,visualstudiocode + +### C++ ### +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +### CMake ### +CMakeCache.txt +CMakeFiles +CMakeScripts +Testing +Makefile +cmake_install.cmake +install_manifest.txt +compile_commands.json +CTestTestfile.cmake + +### Git ### +# Created by git for backups. 
To disable backups in Git: +# $ git config --global mergetool.keepBackup false +*.orig + +# Created by git when using merge tools for conflicts +*.BACKUP.* +*.BASE.* +*.LOCAL.* +*.REMOTE.* +*_BACKUP_*.txt +*_BASE_*.txt +*_LOCAL_*.txt +*_REMOTE_*.txt + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions + +# Distribution / packaging +.Python +build/ +deps_build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheelhouse/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +### Python Patch ### +.venv/ + +### Python.VirtualEnv Stack ### +# Virtualenv +# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ +[Ii]nclude +[Ll]ib +[Ll]ib64 +[Ll]ocal +[Ss]cripts +pyvenv.cfg +pip-selfcheck.json + +### VisualStudioCode ### +.vscode/* + +### VisualStudio ### +## Ignore Visual Studio temporary 
files, build results, and +## files generated by popular Visual Studio add-ons. +## +## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore + +# User-specific files +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +bld/ +[Oo]bj/ +[Ll]og/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUNIT +*.VisualState.xml +TestResult.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.iobj +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# JustCode is a .NET coding add-in +.JustCode + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# 
MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. 
+!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
+*.vbw + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# JetBrains Rider +.idea/ +*.sln.iml + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +*.pyc + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + + +# End of https://www.gitignore.io/api/git,c++,cmake,python,visualstudio,visualstudiocode + +# credentials +.key +*.pem +.crt + +# deps from build-deps.sh +deps/ diff --git a/AWSIoTPythonSDK/MQTTLib.py b/AWSIoTPythonSDK/MQTTLib.py index cb01bc0..6b9f20c 100755 --- a/AWSIoTPythonSDK/MQTTLib.py +++ b/AWSIoTPythonSDK/MQTTLib.py @@ -15,6 +15,7 @@ # */ from AWSIoTPythonSDK.core.util.providers import CertificateCredentialsProvider +from AWSIoTPythonSDK.core.util.providers import CiphersProvider from AWSIoTPythonSDK.core.util.providers import IAMCredentialsProvider from AWSIoTPythonSDK.core.util.providers import EndpointProvider from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicType @@ -207,7 +208,7 @@ def configureIAMCredentials(self, AWSAccessKeyID, AWSSecretAccessKey, AWSSession iam_credentials_provider.set_session_token(AWSSessionToken) self._mqtt_core.configure_iam_credentials(iam_credentials_provider) - def 
configureCredentials(self, CAFilePath, KeyPath="", CertificatePath=""): # Should be good for MutualAuth certs config and Websocket rootCA config + def configureCredentials(self, CAFilePath, KeyPath="", CertificatePath="", Ciphers=None): # Should be good for MutualAuth certs config and Websocket rootCA config """ **Description** @@ -227,6 +228,8 @@ def configureCredentials(self, CAFilePath, KeyPath="", CertificatePath=""): # S *CertificatePath* - Path to read the certificate. Required for X.509 certificate based connection. + *Ciphers* - String of colon split SSL ciphers to use. If not passed, default ciphers will be used. + **Returns** None @@ -236,7 +239,11 @@ def configureCredentials(self, CAFilePath, KeyPath="", CertificatePath=""): # S cert_credentials_provider.set_ca_path(CAFilePath) cert_credentials_provider.set_key_path(KeyPath) cert_credentials_provider.set_cert_path(CertificatePath) - self._mqtt_core.configure_cert_credentials(cert_credentials_provider) + + cipher_provider = CiphersProvider() + cipher_provider.set_ciphers(Ciphers) + + self._mqtt_core.configure_cert_credentials(cert_credentials_provider, cipher_provider) def configureAutoReconnectBackoffTime(self, baseReconnectQuietTimeSecond, maxReconnectQuietTimeSecond, stableConnectionTimeSecond): """ @@ -408,6 +415,33 @@ def configureUsernamePassword(self, username, password=None): """ self._mqtt_core.configure_username_password(username, password) + def configureSocketFactory(self, socket_factory): + """ + **Description** + + Configure a socket factory to custom configure a different socket type for + mqtt connection. Creating a custom socket allows for configuration of a proxy + + **Syntax** + + .. 
code:: python + + # Configure socket factory + custom_args = {"arg1": "val1", "arg2": "val2"} + socket_factory = lambda: custom.create_connection((host, port), **custom_args) + myAWSIoTMQTTClient.configureSocketFactory(socket_factory) + + **Parameters** + + *socket_factory* - Anonymous function which creates a custom socket to spec. + + **Returns** + + None + + """ + self._mqtt_core.configure_socket_factory(socket_factory) + def enableMetricsCollection(self): """ **Description** @@ -473,7 +507,8 @@ def connect(self, keepAliveIntervalSecond=600): **Parameters** - *keepAliveIntervalSecond* - Time in seconds for interval of sending MQTT ping request. + *keepAliveIntervalSecond* - Time in seconds for interval of sending MQTT ping request. + A shorter keep-alive interval allows the client to detect disconnects more quickly. Default set to 600 seconds. **Returns** @@ -1133,6 +1168,33 @@ def configureUsernamePassword(self, username, password=None): """ self._AWSIoTMQTTClient.configureUsernamePassword(username, password) + def configureSocketFactory(self, socket_factory): + """ + **Description** + + Configure a socket factory to custom configure a different socket type for + mqtt connection. Creating a custom socket allows for configuration of a proxy + + **Syntax** + + .. code:: python + + # Configure socket factory + custom_args = {"arg1": "val1", "arg2": "val2"} + socket_factory = lambda: custom.create_connection((host, port), **custom_args) + myAWSIoTMQTTClient.configureSocketFactory(socket_factory) + + **Parameters** + + *socket_factory* - Anonymous function which creates a custom socket to spec. + + **Returns** + + None + + """ + self._AWSIoTMQTTClient.configureSocketFactory(socket_factory) + def enableMetricsCollection(self): """ **Description** @@ -1498,13 +1560,14 @@ def createJobSubscription(self, callback, jobExecutionType=jobExecutionTopicType **Syntax** .. 
code:: python + #Subscribe to notify-next topic to monitor change in job referred to by $next myAWSIoTMQTTJobsClient.createJobSubscription(callback, jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) #Subscribe to notify topic to monitor changes to jobs in pending list myAWSIoTMQTTJobsClient.createJobSubscription(callback, jobExecutionTopicType.JOB_NOTIFY_TOPIC) - #Subscribe to recieve messages for job execution updates + #Subscribe to receive messages for job execution updates myAWSIoTMQTTJobsClient.createJobSubscription(callback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) - #Subscribe to recieve messages for describing a job execution + #Subscribe to receive messages for describing a job execution myAWSIoTMQTTJobsClient.createJobSubscription(callback, jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, jobId) **Parameters** @@ -1541,13 +1604,14 @@ def createJobSubscriptionAsync(self, ackCallback, callback, jobExecutionType=job **Syntax** .. 
code:: python + #Subscribe to notify-next topic to monitor change in job referred to by $next myAWSIoTMQTTJobsClient.createJobSubscriptionAsync(callback, jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) #Subscribe to notify topic to monitor changes to jobs in pending list myAWSIoTMQTTJobsClient.createJobSubscriptionAsync(callback, jobExecutionTopicType.JOB_NOTIFY_TOPIC) - #Subscribe to recieve messages for job execution updates + #Subscribe to receive messages for job execution updates myAWSIoTMQTTJobsClient.createJobSubscriptionAsync(callback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) - #Subscribe to recieve messages for describing a job execution + #Subscribe to receive messages for describing a job execution myAWSIoTMQTTJobsClient.createJobSubscriptionAsync(callback, jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, jobId) **Parameters** @@ -1588,6 +1652,7 @@ def sendJobsQuery(self, jobExecTopicType, jobId=None): **Syntax** .. code:: python + #send a request to describe the next job myAWSIoTMQTTJobsClient.sendJobsQuery(jobExecutionTopicType.JOB_DESCRIBE_TOPIC, '$next') #send a request to get list of pending jobs @@ -1609,7 +1674,7 @@ def sendJobsQuery(self, jobExecTopicType, jobId=None): payload = self._thingJobManager.serializeClientTokenPayload() return self._AWSIoTMQTTClient.publish(topic, payload, self._QoS) - def sendJobsStartNext(self, statusDetails=None): + def sendJobsStartNext(self, statusDetails=None, stepTimeoutInMinutes=None): """ **Description** @@ -1619,6 +1684,7 @@ def sendJobsStartNext(self, statusDetails=None): **Syntax** .. 
code:: python + #Start next job (set status to IN_PROGRESS) and update with optional statusDetails myAWSIoTMQTTJobsClient.sendJobsStartNext({'StartedBy': 'myClientId'}) @@ -1626,16 +1692,18 @@ *statusDetails* - Dictionary containing the key value pairs to use for the status details of the job execution + *stepTimeoutInMinutes* - Specifies the amount of time this device has to finish execution of this job. + **Returns** True if the publish request has been sent to paho. False if the request did not reach paho. """ topic = self._thingJobManager.getJobTopic(jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) - payload = self._thingJobManager.serializeStartNextPendingJobExecutionPayload(statusDetails) + payload = self._thingJobManager.serializeStartNextPendingJobExecutionPayload(statusDetails, stepTimeoutInMinutes) return self._AWSIoTMQTTClient.publish(topic, payload, self._QoS) - def sendJobsUpdate(self, jobId, status, statusDetails=None, expectedVersion=0, executionNumber=0, includeJobExecutionState=False, includeJobDocument=False): + def sendJobsUpdate(self, jobId, status, statusDetails=None, expectedVersion=0, executionNumber=0, includeJobExecutionState=False, includeJobDocument=False, stepTimeoutInMinutes=None): """ **Description** @@ -1645,6 +1713,7 @@ def sendJobsUpdate(self, jobId, status, statusDetails=None, expectedVersion=0, e **Syntax** .. code:: python + #Update job with id 'jobId123' to succeeded state, specifying new status details, with expectedVersion=1, executionNumber=2. 
#For the response, include job execution state and not the job document myAWSIoTMQTTJobsClient.sendJobsUpdate('jobId123', jobExecutionStatus.JOB_EXECUTION_SUCCEEDED, statusDetailsMap, 1, 2, True, False) @@ -1670,7 +1739,9 @@ def sendJobsUpdate(self, jobId, status, statusDetails=None, expectedVersion=0, e *includeJobExecutionState* - When included and set to True, the response contains the JobExecutionState field. The default is False. - *includeJobDocument - When included and set to True, the response contains the JobDocument. The default is False. + *includeJobDocument* - When included and set to True, the response contains the JobDocument. The default is False. + + *stepTimeoutInMinutes* - Specifies the amount of time this device has to finish execution of this job. **Returns** @@ -1678,7 +1749,7 @@ def sendJobsUpdate(self, jobId, status, statusDetails=None, expectedVersion=0, e """ topic = self._thingJobManager.getJobTopic(jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, jobId) - payload = self._thingJobManager.serializeJobExecutionUpdatePayload(status, statusDetails, expectedVersion, executionNumber, includeJobExecutionState, includeJobDocument) + payload = self._thingJobManager.serializeJobExecutionUpdatePayload(status, statusDetails, expectedVersion, executionNumber, includeJobExecutionState, includeJobDocument, stepTimeoutInMinutes) return self._AWSIoTMQTTClient.publish(topic, payload, self._QoS) def sendJobsDescribe(self, jobId, executionNumber=0, includeJobDocument=True): @@ -1690,6 +1761,7 @@ def sendJobsDescribe(self, jobId, executionNumber=0, includeJobDocument=True): **Syntax** .. 
code:: python + #Describe job with id 'jobId1' of any executionNumber, job document will be included in response myAWSIoTMQTTJobsClient.sendJobsDescribe('jobId1') diff --git a/AWSIoTPythonSDK/__init__.py b/AWSIoTPythonSDK/__init__.py index 6b8b97e..3a384fb 100755 --- a/AWSIoTPythonSDK/__init__.py +++ b/AWSIoTPythonSDK/__init__.py @@ -1,3 +1 @@ -__version__ = "1.4.0" - - +__version__ = "1.5.4" diff --git a/AWSIoTPythonSDK/core/greengrass/discovery/providers.py b/AWSIoTPythonSDK/core/greengrass/discovery/providers.py index 0842ba1..192f71a 100644 --- a/AWSIoTPythonSDK/core/greengrass/discovery/providers.py +++ b/AWSIoTPythonSDK/core/greengrass/discovery/providers.py @@ -21,6 +21,7 @@ from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryTimeoutException from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryFailure from AWSIoTPythonSDK.core.greengrass.discovery.models import DiscoveryInfo +from AWSIoTPythonSDK.core.protocol.connection.alpn import SSLContextBuilder import re import sys import ssl @@ -202,7 +203,7 @@ def discover(self, thingName): """ **Description** - + Perform the discovery request for the given Greengrass aware device thing name. 
**Syntax** @@ -245,16 +246,43 @@ def _create_tcp_connection(self): def _create_ssl_connection(self, sock): self._logger.debug("Creating ssl connection...") - ssl_sock = ssl.wrap_socket(sock, - certfile=self._cert_path, - keyfile=self._key_path, - ca_certs=self._ca_path, - cert_reqs=ssl.CERT_REQUIRED, - ssl_version=ssl.PROTOCOL_SSLv23) + + ssl_protocol_version = ssl.PROTOCOL_SSLv23 + + if self._port == 443: + ssl_context = SSLContextBuilder()\ + .with_ca_certs(self._ca_path)\ + .with_cert_key_pair(self._cert_path, self._key_path)\ + .with_cert_reqs(ssl.CERT_REQUIRED)\ + .with_check_hostname(True)\ + .with_ciphers(None)\ + .with_alpn_protocols(['x-amzn-http-ca'])\ + .build() + ssl_sock = ssl_context.wrap_socket(sock, server_hostname=self._host, do_handshake_on_connect=False) + ssl_sock.do_handshake() + else: + # To keep the SSL Context update minimal, only apply forced ssl context to python3.12+ + force_ssl_context = sys.version_info[0] > 3 or (sys.version_info[0] == 3 and sys.version_info[1] >= 12) + if force_ssl_context: + ssl_context = ssl.SSLContext(ssl_protocol_version) + ssl_context.load_cert_chain(self._cert_path, self._key_path) + ssl_context.load_verify_locations(self._ca_path) + ssl_context.verify_mode = ssl.CERT_REQUIRED + + ssl_sock = ssl_context.wrap_socket(sock) + else: + ssl_sock = ssl.wrap_socket(sock, + certfile=self._cert_path, + keyfile=self._key_path, + ca_certs=self._ca_path, + cert_reqs=ssl.CERT_REQUIRED, + ssl_version=ssl_protocol_version) + self._logger.debug("Matching host name...") if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 2): self._tls_match_hostname(ssl_sock) - else: + elif sys.version_info[0] == 3 and sys.version_info[1] < 7: + # host name verification is handled internally in Python3.7+ ssl.match_hostname(ssl_sock.getpeercert(), self._host) return ssl_sock @@ -349,9 +377,14 @@ def _receive_until(self, ssl_sock, criteria_function, extra_data=None): start_time = time.time() response = bytearray() 
number_bytes_read = 0 + ssl_sock_tmp = None while True: # Python does not have do-while try: - response.append(self._convert_to_int_py3(ssl_sock.read(1))) + ssl_sock_tmp = self._convert_to_int_py3(ssl_sock.read(1)) + if isinstance(ssl_sock_tmp, list): + response.extend(ssl_sock_tmp) + else: + response.append(ssl_sock_tmp) number_bytes_read += 1 except socket.error as err: if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE: diff --git a/AWSIoTPythonSDK/core/jobs/thingJobManager.py b/AWSIoTPythonSDK/core/jobs/thingJobManager.py index 0dd7290..d2396b2 100755 --- a/AWSIoTPythonSDK/core/jobs/thingJobManager.py +++ b/AWSIoTPythonSDK/core/jobs/thingJobManager.py @@ -37,6 +37,7 @@ _INCLUDE_JOB_EXECUTION_STATE_KEY = 'includeJobExecutionState' _INCLUDE_JOB_DOCUMENT_KEY = 'includeJobDocument' _CLIENT_TOKEN_KEY = 'clientToken' +_STEP_TIMEOUT_IN_MINUTES_KEY = 'stepTimeoutInMinutes' #The type of job topic. class jobExecutionTopicType(object): @@ -112,7 +113,7 @@ def getJobTopic(self, srcJobExecTopicType, srcJobExecTopicReplyType=jobExecution else: return '{0}{1}/jobs/{2}{3}'.format(_BASE_THINGS_TOPIC, self._thingName, srcJobExecTopicType[_JOB_OPERATION_INDEX], srcJobExecTopicReplyType[_JOB_SUFFIX_INDEX]) - def serializeJobExecutionUpdatePayload(self, status, statusDetails=None, expectedVersion=0, executionNumber=0, includeJobExecutionState=False, includeJobDocument=False): + def serializeJobExecutionUpdatePayload(self, status, statusDetails=None, expectedVersion=0, executionNumber=0, includeJobExecutionState=False, includeJobDocument=False, stepTimeoutInMinutes=None): executionStatus = _getExecutionStatus(status) if executionStatus is None: return None @@ -129,6 +130,8 @@ def serializeJobExecutionUpdatePayload(self, status, statusDetails=None, expecte payload[_INCLUDE_JOB_DOCUMENT_KEY] = True if self._clientToken is not None: payload[_CLIENT_TOKEN_KEY] = self._clientToken + if stepTimeoutInMinutes is not None: + payload[_STEP_TIMEOUT_IN_MINUTES_KEY] 
= stepTimeoutInMinutes return json.dumps(payload) def serializeDescribeJobExecutionPayload(self, executionNumber=0, includeJobDocument=True): @@ -139,12 +142,14 @@ def serializeDescribeJobExecutionPayload(self, executionNumber=0, includeJobDocu payload[_CLIENT_TOKEN_KEY] = self._clientToken return json.dumps(payload) - def serializeStartNextPendingJobExecutionPayload(self, statusDetails=None): + def serializeStartNextPendingJobExecutionPayload(self, statusDetails=None, stepTimeoutInMinutes=None): payload = {} if self._clientToken is not None: payload[_CLIENT_TOKEN_KEY] = self._clientToken if statusDetails is not None: payload[_STATUS_DETAILS_KEY] = statusDetails + if stepTimeoutInMinutes is not None: + payload[_STEP_TIMEOUT_IN_MINUTES_KEY] = stepTimeoutInMinutes return json.dumps(payload) def serializeClientTokenPayload(self): diff --git a/AWSIoTPythonSDK/core/protocol/connection/alpn.py b/AWSIoTPythonSDK/core/protocol/connection/alpn.py index b7d5137..8da98dd 100644 --- a/AWSIoTPythonSDK/core/protocol/connection/alpn.py +++ b/AWSIoTPythonSDK/core/protocol/connection/alpn.py @@ -34,10 +34,6 @@ def check_supportability(self): if not hasattr(ssl.SSLContext, "set_alpn_protocols"): raise NotImplementedError("This platform does not support ALPN as TLS extensions. 
Python 2.7.10+/3.5+ is required.") - def with_protocol(self, protocol): - self._ssl_context.protocol = protocol - return self - def with_ca_certs(self, ca_certs): self._ssl_context.load_verify_locations(ca_certs) return self diff --git a/AWSIoTPythonSDK/core/protocol/connection/cores.py b/AWSIoTPythonSDK/core/protocol/connection/cores.py index a431d24..df12470 100644 --- a/AWSIoTPythonSDK/core/protocol/connection/cores.py +++ b/AWSIoTPythonSDK/core/protocol/connection/cores.py @@ -482,11 +482,11 @@ def _verifyWSSAcceptKey(self, srcAcceptKey, clientKey): def _handShake(self, hostAddress, portNumber): CRLF = "\r\n" - IOT_ENDPOINT_PATTERN = r"^[0-9a-zA-Z]+\.iot\.(.*)\.amazonaws\..*" - matched = re.compile(IOT_ENDPOINT_PATTERN).match(hostAddress) + IOT_ENDPOINT_PATTERN = r"^[0-9a-zA-Z]+(\.ats|-ats)?\.iot\.(.*)\.amazonaws\..*" + matched = re.compile(IOT_ENDPOINT_PATTERN, re.IGNORECASE).match(hostAddress) if not matched: raise ClientError("Invalid endpoint pattern for wss: %s" % hostAddress) - region = matched.group(1) + region = matched.group(2) signedURL = self._sigV4Handler.createWebsocketEndpoint(hostAddress, portNumber, region, "GET", "iotdata", "/mqtt") # Now we got a signedURL path = signedURL[signedURL.index("/mqtt"):] diff --git a/AWSIoTPythonSDK/core/protocol/internal/clients.py b/AWSIoTPythonSDK/core/protocol/internal/clients.py index 7a6552b..90f48b7 100644 --- a/AWSIoTPythonSDK/core/protocol/internal/clients.py +++ b/AWSIoTPythonSDK/core/protocol/internal/clients.py @@ -64,7 +64,7 @@ def _create_paho_client(self, client_id, clean_session, user_data, protocol, use return mqtt.Client(client_id, clean_session, user_data, protocol, use_wss) # TODO: Merge credentials providers configuration into one - def set_cert_credentials_provider(self, cert_credentials_provider): + def set_cert_credentials_provider(self, cert_credentials_provider, ciphers_provider): # History issue from Yun SDK where AR9331 embedded Linux only have Python 2.7.3 # pre-installed. 
In this version, TLSv1_2 is not even an option. # SSLv23 is a work-around which selects the highest TLS version between the client @@ -75,13 +75,16 @@ def set_cert_credentials_provider(self, cert_credentials_provider): # See also: https://docs.python.org/2/library/ssl.html#ssl.PROTOCOL_SSLv23 if self._use_wss: ca_path = cert_credentials_provider.get_ca_path() - self._paho_client.tls_set(ca_certs=ca_path, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23) + ciphers = ciphers_provider.get_ciphers() + self._paho_client.tls_set(ca_certs=ca_path, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23, + ciphers=ciphers) else: ca_path = cert_credentials_provider.get_ca_path() cert_path = cert_credentials_provider.get_cert_path() key_path = cert_credentials_provider.get_key_path() + ciphers = ciphers_provider.get_ciphers() self._paho_client.tls_set(ca_certs=ca_path,certfile=cert_path, keyfile=key_path, - cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23) + cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23, ciphers=ciphers) def set_iam_credentials_provider(self, iam_credentials_provider): self._paho_client.configIAMCredentials(iam_credentials_provider.get_access_key_id(), @@ -103,6 +106,9 @@ def clear_last_will(self): def set_username_password(self, username, password=None): self._paho_client.username_pw_set(username, password) + def set_socket_factory(self, socket_factory): + self._paho_client.socket_factory_set(socket_factory) + def configure_reconnect_back_off(self, base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec): self._paho_client.setBackoffTiming(base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec) diff --git a/AWSIoTPythonSDK/core/protocol/internal/workers.py b/AWSIoTPythonSDK/core/protocol/internal/workers.py index 604b201..e52db3f 100644 --- a/AWSIoTPythonSDK/core/protocol/internal/workers.py +++ b/AWSIoTPythonSDK/core/protocol/internal/workers.py @@ -187,11 +187,11 @@ def 
_handle_resubscribe(self): if subscriptions and not self._has_user_disconnect_request(): self._logger.debug("Start resubscribing") self._client_status.set_status(ClientStatus.RESUBSCRIBE) - for topic, (qos, message_callback) in subscriptions: + for topic, (qos, message_callback, ack_callback) in subscriptions: if self._has_user_disconnect_request(): self._logger.debug("User disconnect detected") break - self._internal_async_client.subscribe(topic, qos) + self._internal_async_client.subscribe(topic, qos, ack_callback) def _handle_draining(self): if self._offline_requests_manager.has_more() and not self._has_user_disconnect_request(): @@ -232,7 +232,7 @@ def _dispatch_message(self, mid, message): self._logger.debug("Dispatching [message] event") subscriptions = self._subscription_manager.list_records() if subscriptions: - for topic, (qos, message_callback) in subscriptions: + for topic, (qos, message_callback, _) in subscriptions: if topic_matches_sub(topic, message.topic) and message_callback: message_callback(None, None, message) # message_callback(client, userdata, message) @@ -242,15 +242,15 @@ def _handle_offline_publish(self, request): self._logger.debug("Processed offline publish request") def _handle_offline_subscribe(self, request): - topic, qos, message_callback = request.data - self._subscription_manager.add_record(topic, qos, message_callback) - self._internal_async_client.subscribe(topic, qos) + topic, qos, message_callback, ack_callback = request.data + self._subscription_manager.add_record(topic, qos, message_callback, ack_callback) + self._internal_async_client.subscribe(topic, qos, ack_callback) self._logger.debug("Processed offline subscribe request") def _handle_offline_unsubscribe(self, request): - topic = request.data + topic, ack_callback = request.data self._subscription_manager.remove_record(topic) - self._internal_async_client.unsubscribe(topic) + self._internal_async_client.unsubscribe(topic, ack_callback) self._logger.debug("Processed 
offline unsubscribe request") @@ -261,9 +261,9 @@ class SubscriptionManager(object): def __init__(self): self._subscription_map = dict() - def add_record(self, topic, qos, message_callback): + def add_record(self, topic, qos, message_callback, ack_callback): self._logger.debug("Adding a new subscription record: %s qos: %d", topic, qos) - self._subscription_map[topic] = qos, message_callback # message_callback could be None + self._subscription_map[topic] = qos, message_callback, ack_callback # message_callback and/or ack_callback could be None def remove_record(self, topic): self._logger.debug("Removing subscription record: %s", topic) diff --git a/AWSIoTPythonSDK/core/protocol/mqtt_core.py b/AWSIoTPythonSDK/core/protocol/mqtt_core.py index 215cfc7..fbdd6bf 100644 --- a/AWSIoTPythonSDK/core/protocol/mqtt_core.py +++ b/AWSIoTPythonSDK/core/protocol/mqtt_core.py @@ -91,14 +91,14 @@ def __init__(self, client_id, clean_session, protocol, use_wss): def _init_offline_request_exceptions(self): self._offline_request_queue_disabled_exceptions = { - RequestTypes.PUBLISH : publishQueueDisabledException(), - RequestTypes.SUBSCRIBE : subscribeQueueDisabledException(), - RequestTypes.UNSUBSCRIBE : unsubscribeQueueDisabledException() + RequestTypes.PUBLISH : publishQueueDisabledException, + RequestTypes.SUBSCRIBE : subscribeQueueDisabledException, + RequestTypes.UNSUBSCRIBE : unsubscribeQueueDisabledException } self._offline_request_queue_full_exceptions = { - RequestTypes.PUBLISH : publishQueueFullException(), - RequestTypes.SUBSCRIBE : subscribeQueueFullException(), - RequestTypes.UNSUBSCRIBE : unsubscribeQueueFullException() + RequestTypes.PUBLISH : publishQueueFullException, + RequestTypes.SUBSCRIBE : subscribeQueueFullException, + RequestTypes.UNSUBSCRIBE : unsubscribeQueueFullException } def _init_workers(self): @@ -127,9 +127,9 @@ def on_online(self): def on_offline(self): pass - def configure_cert_credentials(self, cert_credentials_provider): - 
self._logger.info("Configuring certificates...") - self._internal_async_client.set_cert_credentials_provider(cert_credentials_provider) + def configure_cert_credentials(self, cert_credentials_provider, ciphers_provider): + self._logger.info("Configuring certificates and ciphers...") + self._internal_async_client.set_cert_credentials_provider(cert_credentials_provider, ciphers_provider) def configure_iam_credentials(self, iam_credentials_provider): self._logger.info("Configuring custom IAM credentials...") @@ -171,6 +171,10 @@ def configure_username_password(self, username, password=None): self._username = username self._password = password + def configure_socket_factory(self, socket_factory): + self._logger.info("Configuring socket factory...") + self._internal_async_client.set_socket_factory(socket_factory) + def enable_metrics_collection(self): self._enable_metrics_collection = True @@ -292,7 +296,7 @@ def subscribe(self, topic, qos, message_callback=None): self._logger.info("Performing sync subscribe...") ret = False if ClientStatus.STABLE != self._client_status.get_status(): - self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback)) + self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback, None)) else: event = Event() rc, mid = self._subscribe_async(topic, qos, self._create_blocking_ack_callback(event), message_callback) @@ -306,14 +310,14 @@ def subscribe(self, topic, qos, message_callback=None): def subscribe_async(self, topic, qos, ack_callback=None, message_callback=None): self._logger.info("Performing async subscribe...") if ClientStatus.STABLE != self._client_status.get_status(): - self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback)) + self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback, ack_callback)) return FixedEventMids.QUEUED_MID else: rc, mid = self._subscribe_async(topic, qos, ack_callback, message_callback) return mid def 
_subscribe_async(self, topic, qos, ack_callback=None, message_callback=None): - self._subscription_manager.add_record(topic, qos, message_callback) + self._subscription_manager.add_record(topic, qos, message_callback, ack_callback) rc, mid = self._internal_async_client.subscribe(topic, qos, ack_callback) if MQTT_ERR_SUCCESS != rc: self._logger.error("Subscribe error: %d", rc) @@ -324,7 +328,7 @@ def unsubscribe(self, topic): self._logger.info("Performing sync unsubscribe...") ret = False if ClientStatus.STABLE != self._client_status.get_status(): - self._handle_offline_request(RequestTypes.UNSUBSCRIBE, topic) + self._handle_offline_request(RequestTypes.UNSUBSCRIBE, (topic, None)) else: event = Event() rc, mid = self._unsubscribe_async(topic, self._create_blocking_ack_callback(event)) @@ -338,7 +342,7 @@ def unsubscribe(self, topic): def unsubscribe_async(self, topic, ack_callback=None): self._logger.info("Performing async unsubscribe...") if ClientStatus.STABLE != self._client_status.get_status(): - self._handle_offline_request(RequestTypes.UNSUBSCRIBE, topic) + self._handle_offline_request(RequestTypes.UNSUBSCRIBE, (topic, ack_callback)) return FixedEventMids.QUEUED_MID else: rc, mid = self._unsubscribe_async(topic, ack_callback) @@ -363,7 +367,7 @@ def _handle_offline_request(self, type, data): append_result = self._offline_requests_manager.add_one(offline_request) if AppendResults.APPEND_FAILURE_QUEUE_DISABLED == append_result: self._logger.error("Offline request queue has been disabled") - raise self._offline_request_queue_disabled_exceptions[type] + raise self._offline_request_queue_disabled_exceptions[type]() if AppendResults.APPEND_FAILURE_QUEUE_FULL == append_result: self._logger.error("Offline request queue is full") - raise self._offline_request_queue_full_exceptions[type] + raise self._offline_request_queue_full_exceptions[type]() diff --git a/AWSIoTPythonSDK/core/protocol/paho/client.py b/AWSIoTPythonSDK/core/protocol/paho/client.py index 
614e4bf..0b637c5 100755 --- a/AWSIoTPythonSDK/core/protocol/paho/client.py +++ b/AWSIoTPythonSDK/core/protocol/paho/client.py @@ -483,6 +483,7 @@ def __init__(self, client_id="", clean_session=True, userdata=None, protocol=MQT self._host = "" self._port = 1883 self._bind_address = "" + self._socket_factory = None self._in_callback = False self._strict_protocol = False self._callback_mutex = threading.Lock() @@ -780,7 +781,9 @@ def reconnect(self): self._messages_reconnect_reset() try: - if (sys.version_info[0] == 2 and sys.version_info[1] < 7) or (sys.version_info[0] == 3 and sys.version_info[1] < 2): + if self._socket_factory: + sock = self._socket_factory() + elif (sys.version_info[0] == 2 and sys.version_info[1] < 7) or (sys.version_info[0] == 3 and sys.version_info[1] < 2): sock = socket.create_connection((self._host, self._port)) else: sock = socket.create_connection((self._host, self._port), source_address=(self._bind_address, 0)) @@ -790,11 +793,22 @@ def reconnect(self): verify_hostname = self._tls_insecure is False # Decide whether we need to verify hostname + # To keep the SSL Context update minimal, only apply forced ssl context to python3.12+ + force_ssl_context = sys.version_info[0] > 3 or (sys.version_info[0] == 3 and sys.version_info[1] >= 12) + if self._tls_ca_certs is not None: if self._useSecuredWebsocket: # Never assign to ._ssl before wss handshake is finished # Non-None value for ._ssl will allow ops before wss-MQTT connection is established - rawSSL = ssl.wrap_socket(sock, ca_certs=self._tls_ca_certs, cert_reqs=ssl.CERT_REQUIRED) # Add server certificate verification + if force_ssl_context: + ssl_context = ssl.SSLContext() + ssl_context.load_verify_locations(self._tls_ca_certs) + ssl_context.verify_mode = ssl.CERT_REQUIRED + + rawSSL = ssl_context.wrap_socket(sock) + else: + rawSSL = ssl.wrap_socket(sock, ca_certs=self._tls_ca_certs, cert_reqs=ssl.CERT_REQUIRED) # Add server certificate verification + rawSSL.setblocking(0) # Non-blocking 
socket self._ssl = SecuredWebSocketCore(rawSSL, self._host, self._port, self._AWSAccessKeyIDCustomConfig, self._AWSSecretAccessKeyCustomConfig, self._AWSSessionTokenCustomConfig) # Override the _ssl socket # self._ssl.enableDebug() @@ -802,7 +816,6 @@ def reconnect(self): # SSLContext is required to enable ALPN support # Assuming Python 2.7.10+/3.5+ till the end of this elif branch ssl_context = SSLContextBuilder()\ - .with_protocol(self._tls_version)\ .with_ca_certs(self._tls_ca_certs)\ .with_cert_key_pair(self._tls_certfile, self._tls_keyfile)\ .with_cert_reqs(self._tls_cert_reqs)\ @@ -814,19 +827,30 @@ def reconnect(self): verify_hostname = False # Since check_hostname in SSLContext is already set to True, no need to verify it again self._ssl.do_handshake() else: - self._ssl = ssl.wrap_socket( - sock, - certfile=self._tls_certfile, - keyfile=self._tls_keyfile, - ca_certs=self._tls_ca_certs, - cert_reqs=self._tls_cert_reqs, - ssl_version=self._tls_version, - ciphers=self._tls_ciphers) + if force_ssl_context: + ssl_context = ssl.SSLContext(self._tls_version) + ssl_context.load_cert_chain(self._tls_certfile, self._tls_keyfile) + ssl_context.load_verify_locations(self._tls_ca_certs) + ssl_context.verify_mode = self._tls_cert_reqs + if self._tls_ciphers is not None: + ssl_context.set_ciphers(self._tls_ciphers) + + self._ssl = ssl_context.wrap_socket(sock) + else: + self._ssl = ssl.wrap_socket( + sock, + certfile=self._tls_certfile, + keyfile=self._tls_keyfile, + ca_certs=self._tls_ca_certs, + cert_reqs=self._tls_cert_reqs, + ssl_version=self._tls_version, + ciphers=self._tls_ciphers) if verify_hostname: if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 5): # No IP host match before 3.5.x self._tls_match_hostname() - else: + elif sys.version_info[0] == 3 and sys.version_info[1] < 7: + # host name verification is handled internally in Python3.7+ ssl.match_hostname(self._ssl.getpeercert(), self._host) self._sock = sock @@ -875,6 +899,15 
@@ def loop(self, timeout=1.0, max_packets=1): self._out_packet_mutex.release() self._current_out_packet_mutex.release() + # used to check if there are any bytes left in the ssl socket + pending_bytes = 0 + if self._ssl: + pending_bytes = self.socket().pending() + + # if bytes are pending do not wait in select + if pending_bytes > 0: + timeout = 0.0 + # sockpairR is used to break out of select() before the timeout, on a # call to publish() etc. rlist = [self.socket(), self._sockpairR] @@ -890,7 +923,7 @@ def loop(self, timeout=1.0, max_packets=1): except: return MQTT_ERR_UNKNOWN - if self.socket() in socklist[0]: + if self.socket() in socklist[0] or pending_bytes > 0: rc = self.loop_read(max_packets) if rc or (self._ssl is None and self._sock is None): return rc @@ -1015,6 +1048,14 @@ def username_pw_set(self, username, password=None): self._username = username.encode('utf-8') self._password = password + def socket_factory_set(self, socket_factory): + """Set a socket factory to custom configure a different socket type for + mqtt connection. + Must be called before connect() to have any effect. 
+ socket_factory: create_connection function which creates a socket to user's specification + """ + self._socket_factory = socket_factory + def disconnect(self): """Disconnect a connected client from the broker.""" self._state_mutex.acquire() @@ -2413,7 +2454,7 @@ def _tls_match_hostname(self): return if key == 'IP Address': have_san_dns = True - if value.lower() == self._host.lower(): + if value.lower().strip() == self._host.lower().strip(): return if have_san_dns: diff --git a/AWSIoTPythonSDK/core/shadow/deviceShadow.py b/AWSIoTPythonSDK/core/shadow/deviceShadow.py index f58240a..bb5d667 100755 --- a/AWSIoTPythonSDK/core/shadow/deviceShadow.py +++ b/AWSIoTPythonSDK/core/shadow/deviceShadow.py @@ -27,6 +27,14 @@ def getNextToken(self): return uuid.uuid4().urn[self.URN_PREFIX_LENGTH:] # We only need the uuid digits, not the urn prefix +def _validateJSON(jsonString): + try: + json.loads(jsonString) + except ValueError: + return False + return True + + class _basicJSONParser: def setString(self, srcString): @@ -246,7 +254,9 @@ def shadowGet(self, srcCallback, srcTimeout): # One publish self._shadowManagerHandler.basicShadowPublish(self._shadowName, "get", currentPayload) # Start the timer - self._tokenPool[currentToken].start() + with self._dataStructureLock: + if currentToken in self._tokenPool: + self._tokenPool[currentToken].start() return currentToken def shadowDelete(self, srcCallback, srcTimeout): @@ -301,7 +311,9 @@ def shadowDelete(self, srcCallback, srcTimeout): # One publish self._shadowManagerHandler.basicShadowPublish(self._shadowName, "delete", currentPayload) # Start the timer - self._tokenPool[currentToken].start() + with self._dataStructureLock: + if currentToken in self._tokenPool: + self._tokenPool[currentToken].start() return currentToken def shadowUpdate(self, srcJSONPayload, srcCallback, srcTimeout): @@ -339,9 +351,10 @@ def shadowUpdate(self, srcJSONPayload, srcCallback, srcTimeout): """ # Validate JSON - 
self._basicJSONParserHandler.setString(srcJSONPayload) - if self._basicJSONParserHandler.validateJSON(): + if _validateJSON(srcJSONPayload): with self._dataStructureLock: + self._basicJSONParserHandler.setString(srcJSONPayload) + self._basicJSONParserHandler.validateJSON() # clientToken currentToken = self._tokenHandler.getNextToken() self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["update", currentToken]) @@ -359,7 +372,9 @@ def shadowUpdate(self, srcJSONPayload, srcCallback, srcTimeout): # One publish self._shadowManagerHandler.basicShadowPublish(self._shadowName, "update", JSONPayloadWithToken) # Start the timer - self._tokenPool[currentToken].start() + with self._dataStructureLock: + if currentToken in self._tokenPool: + self._tokenPool[currentToken].start() else: raise ValueError("Invalid JSON file.") return currentToken diff --git a/AWSIoTPythonSDK/core/util/providers.py b/AWSIoTPythonSDK/core/util/providers.py index d90789a..d09f8a0 100644 --- a/AWSIoTPythonSDK/core/util/providers.py +++ b/AWSIoTPythonSDK/core/util/providers.py @@ -90,3 +90,13 @@ def get_host(self): def get_port(self): return self._port + +class CiphersProvider(object): + def __init__(self): + self._ciphers = None + + def set_ciphers(self, ciphers=None): + self._ciphers = ciphers + + def get_ciphers(self): + return self._ciphers diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 7b316a5..765c557 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,50 @@ CHANGELOG ========= +1.5.5 +===== +* chore: Update minimum Python version based on current supportable levels + +1.5.4 +===== +* chore: CD pipeline flushing try #1 + +1.5.3 +===== +* improvement: Support Python3.12+ by conditionally removing deprecated API usage + +1.4.9 +===== +* bugfix: Fixing possible race condition with timer in deviceShadow. 
+ +1.4.8 +===== +* improvement: Added support for subscription acknowledgement callbacks while offline or resubscribing + +1.4.7 +===== +* improvement: Added connection establishment control through client socket factory option + +1.4.6 +===== +* bugfix: Use non-deprecated ssl API to specify ALPN when doing Greengrass discovery + +1.4.5 +===== +* improvement: Added validation to mTLS arguments in basicDiscovery + +1.4.3 +===== +* bugfix: `Issue #150 <https://github.com/aws/aws-iot-device-sdk-python/issues/150>`__ Fix for ALPN in Python 3.7 + +1.4.2 +===== +* bugfix: Websocket handshake supports Amazon Trust Store (ats) endpoints +* bugfix: Remove default port number in samples, which prevented WebSocket mode from using 443 +* bugfix: jobsSample print statements compatible with Python 3.x +* improvement: Small fixes to IoT Jobs documentation + + 1.4.0 ===== * bugfix:Issue `#136 ` diff --git a/README.rst b/README.rst index 6b65ef3..ba88218 100755 --- a/README.rst +++ b/README.rst @@ -1,3 +1,9 @@ +New Version Available +============================= +A new AWS IoT Device SDK is `now available <https://github.com/awslabs/aws-iot-device-sdk-python-v2>`__. It is a complete rework, built to improve reliability, performance, and security. We invite your feedback! + +This SDK will no longer receive feature updates, but will receive security updates. 
+ AWS IoT Device SDK for Python ============================= @@ -64,22 +70,10 @@ also allows the use of the same connection for shadow operations and non-shadow, Installation ~~~~~~~~~~~~ -Minimum Requirements +Requirements ____________________ -- Python 2.7+ or Python 3.3+ for X.509 certificate-based mutual authentication via port 8883 - and MQTT over WebSocket protocol with AWS Signature Version 4 authentication -- Python 2.7.10+ or Python 3.5+ for X.509 certificate-based mutual authentication via port 443 -- OpenSSL version 1.0.1+ (TLS version 1.2) compiled with the Python executable for - X.509 certificate-based mutual authentication - - To check your version of OpenSSL, use the following command in a Python interpreter: - - .. code-block:: python - - >>> import ssl - >>> ssl.OPENSSL_VERSION - +- Python3.8+. The SDK has worked for older Python versions in the past, but they are no longer formally supported. Over time, expect the minimum Python version to loosely track the minimum non-end-of-life version. Install from pip ________________ @@ -142,7 +136,7 @@ types: For the certificate-based mutual authentication connection type. Download the `AWS IoT root - CA `__. + CA `__. Use the AWS IoT console to create and download the certificate and private key. You must specify the location of these files when you initialize the client. @@ -151,8 +145,8 @@ types: For the Websocket with Signature Version 4 authentication type. You will need IAM credentials: an access key ID, a secret access key, and an optional session token. You must also download the `AWS IoT root - CA `__. - You can specify the IAM credentails by: + CA `__. + You can specify the IAM credentials by: - Passing method parameters @@ -629,6 +623,18 @@ accepted/rejected topics. In all SDK examples, PersistentSubscription is used in consideration of its better performance. 
+SSL Ciphers Setup +______________________________________ +If custom SSL Ciphers are required for the client, they can be set when configuring the client before +starting the connection. + +To setup specific SSL Ciphers: + +.. code-block:: python + + myAWSIoTMQTTClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath, Ciphers="AES128-SHA256") + + .. _Examples: Examples diff --git a/continuous-delivery/pip-install-with-retry.py b/continuous-delivery/pip-install-with-retry.py new file mode 100644 index 0000000..347e0dc --- /dev/null +++ b/continuous-delivery/pip-install-with-retry.py @@ -0,0 +1,39 @@ +import time +import sys +import subprocess + +DOCS = """Given cmdline args, executes: python3 -m pip install [args...] +Keeps retrying until the new version becomes available in pypi (or we time out)""" +if len(sys.argv) < 2: + sys.exit(DOCS) + +RETRY_INTERVAL_SECS = 10 +GIVE_UP_AFTER_SECS = 60 * 15 + +pip_install_args = [sys.executable, '-m', 'pip', 'install'] + sys.argv[1:] + +start_time = time.time() +while True: + print(subprocess.list2cmdline(pip_install_args)) + result = subprocess.run(pip_install_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + stdout = result.stdout.decode().strip() + if stdout: + print(stdout) + + if result.returncode == 0: + # success + sys.exit(0) + + if "could not find a version" in stdout.lower(): + elapsed_secs = time.time() - start_time + if elapsed_secs < GIVE_UP_AFTER_SECS: + # try again + print("Retrying in", RETRY_INTERVAL_SECS, "secs...") + time.sleep(RETRY_INTERVAL_SECS) + continue + else: + print("Giving up on retries after", int(elapsed_secs), "total secs.") + + # fail + sys.exit(result.returncode) diff --git a/continuous-delivery/publish_to_prod_pypi.yml b/continuous-delivery/publish_to_prod_pypi.yml new file mode 100644 index 0000000..905d849 --- /dev/null +++ b/continuous-delivery/publish_to_prod_pypi.yml @@ -0,0 +1,25 @@ +version: 0.2 +# this image assumes Ubuntu 14.04 base image +phases: + install: 
+ commands: + - sudo apt-get update -y + - sudo apt-get install python3 python3-pip -y + - export PATH=$PATH:$HOME/.local/bin + - python3 -m pip install --user --upgrade pip + - python3 -m pip install --user --upgrade twine setuptools wheel awscli PyOpenSSL six + pre_build: + commands: + - cd aws-iot-device-sdk-python + - pypirc=$(aws secretsmanager get-secret-value --secret-id "prod/aws-sdk-python-v1/.pypirc" --query "SecretString" | cut -f2 -d\") && echo "$pypirc" > ~/.pypirc + - export PKG_VERSION=$(git describe --tags | cut -f2 -dv) + - echo "Updating package version to ${PKG_VERSION}" + - sed --in-place -E "s/__version__ = \".+\"/__version__ = \"${PKG_VERSION}\"/" AWSIoTPythonSDK/__init__.py + build: + commands: + - echo Build started on `date` + - python3 setup.py sdist bdist_wheel --universal + - python3 -m twine upload -r pypi dist/* + post_build: + commands: + - echo Build completed on `date` diff --git a/continuous-delivery/publish_to_test_pypi.yml b/continuous-delivery/publish_to_test_pypi.yml new file mode 100644 index 0000000..c435e5e --- /dev/null +++ b/continuous-delivery/publish_to_test_pypi.yml @@ -0,0 +1,25 @@ +version: 0.2 +# this image assumes Ubuntu 14.04 base image +phases: + install: + commands: + - sudo apt-get update -y + - sudo apt-get install python3 python3-pip -y + - export PATH=$PATH:$HOME/.local/bin + - python3 -m pip install --user --upgrade pip + - python3 -m pip install --user --upgrade twine setuptools wheel awscli PyOpenSSL six + pre_build: + commands: + - pypirc=$(aws secretsmanager get-secret-value --secret-id "alpha/aws-sdk-python-v1/.pypirc" --query "SecretString" | cut -f2 -d\") && echo "$pypirc" > ~/.pypirc + - cd aws-iot-device-sdk-python + - export PKG_VERSION=$(git describe --tags | cut -f2 -dv) + - echo "Updating package version to ${PKG_VERSION}" + - sed --in-place -E "s/__version__ = \".+\"/__version__ = \"${PKG_VERSION}\"/" AWSIoTPythonSDK/__init__.py + build: + commands: + - echo Build started on `date` + - python3 
setup_test.py sdist bdist_wheel --universal + - python3 -m twine upload -r testpypi dist/* --verbose + post_build: + commands: + - echo Build completed on `date` diff --git a/continuous-delivery/test_prod_pypi.yml b/continuous-delivery/test_prod_pypi.yml new file mode 100644 index 0000000..4575306 --- /dev/null +++ b/continuous-delivery/test_prod_pypi.yml @@ -0,0 +1,28 @@ +version: 0.2 +# this image assumes Ubuntu 14.04 base image +phases: + install: + commands: + - sudo apt-get update -y + - sudo apt-get install python3 python3-pip -y + - python3 -m pip install --upgrade pip + - python3 -m pip install --upgrade setuptools + + pre_build: + commands: + - curl https://www.amazontrust.com/repository/AmazonRootCA1.pem --output /tmp/AmazonRootCA1.pem + - cert=$(aws secretsmanager get-secret-value --secret-id "unit-test/certificate" --query "SecretString" | cut -f2 -d":" | cut -f2 -d\") && echo "$cert" > /tmp/certificate.pem + - key=$(aws secretsmanager get-secret-value --secret-id "unit-test/privatekey" --query "SecretString" | cut -f2 -d":" | cut -f2 -d\") && echo "$key" > /tmp/privatekey.pem + - ENDPOINT=$(aws secretsmanager get-secret-value --secret-id "unit-test/endpoint" --query "SecretString" | cut -f2 -d":" | sed -e 's/[\\\"\}]//g') + build: + commands: + - echo Build started on `date` + - cd aws-iot-device-sdk-python + - CURRENT_TAG_VERSION=$(git describe --tags | cut -f2 -dv) + - python3 continuous-delivery/pip-install-with-retry.py --no-cache-dir --user AWSIoTPythonSDK==$CURRENT_TAG_VERSION + - python3 samples/greengrass/basicDiscovery.py -e ${ENDPOINT} -c /tmp/certificate.pem -k /tmp/privatekey.pem -r /tmp/AmazonRootCA1.pem --print_discover_resp_only + + post_build: + commands: + - echo Build completed on `date` + diff --git a/continuous-delivery/test_test_pypi.yml b/continuous-delivery/test_test_pypi.yml new file mode 100644 index 0000000..c3aa47d --- /dev/null +++ b/continuous-delivery/test_test_pypi.yml @@ -0,0 +1,30 @@ +version: 0.2 +# this image assumes 
Ubuntu 14.04 base image +phases: + install: + commands: + - sudo apt-get update -y + - sudo apt-get install python3 python3-pip -y + - python3 -m pip install --upgrade pip + - python3 -m pip install --upgrade setuptools + + pre_build: + commands: + - curl https://www.amazontrust.com/repository/AmazonRootCA1.pem --output /tmp/AmazonRootCA1.pem + - cert=$(aws secretsmanager get-secret-value --secret-id "unit-test/certificate" --query "SecretString" | cut -f2 -d":" | cut -f2 -d\") && echo "$cert" > /tmp/certificate.pem + - key=$(aws secretsmanager get-secret-value --secret-id "unit-test/privatekey" --query "SecretString" | cut -f2 -d":" | cut -f2 -d\") && echo "$key" > /tmp/privatekey.pem + - ENDPOINT=$(aws secretsmanager get-secret-value --secret-id "unit-test/endpoint" --query "SecretString" | cut -f2 -d":" | sed -e 's/[\\\"\}]//g') + build: + commands: + - echo Build started on `date` + - cd aws-iot-device-sdk-python + - CURRENT_TAG_VERSION=$(git describe --tags | cut -f2 -dv) + # this is here because typing isn't in testpypi, so pull it from prod instead + - python3 -m pip install typing + - python3 continuous-delivery/pip-install-with-retry.py -i https://testpypi.python.org/simple --user AWSIoTPythonSDK-V1==$CURRENT_TAG_VERSION + - python3 samples/greengrass/basicDiscovery.py -e ${ENDPOINT} -c /tmp/certificate.pem -k /tmp/privatekey.pem -r /tmp/AmazonRootCA1.pem --print_discover_resp_only + + post_build: + commands: + - echo Build completed on `date` + diff --git a/continuous-delivery/test_version_exists b/continuous-delivery/test_version_exists new file mode 100644 index 0000000..3579dbc --- /dev/null +++ b/continuous-delivery/test_version_exists @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -e +set -x +# force a failure if there's no tag +git describe --tags +# now get the tag +CURRENT_TAG=$(git describe --tags | cut -f2 -dv) +# convert v0.2.12-2-g50254a9 to 0.2.12 +CURRENT_TAG_VERSION=$(git describe --tags | cut -f1 -d'-' | cut -f2 -dv) +# if there's a hash on 
the tag, then this is not a release tagged commit +if [ "$CURRENT_TAG" != "$CURRENT_TAG_VERSION" ]; then + echo "Current tag version is not a release tag, cut a new release if you want to publish." + exit 1 +fi + +if python3 -m pip install --no-cache-dir -vvv AWSIoTPythonSDK==$CURRENT_TAG_VERSION; then + echo "$CURRENT_TAG_VERSION is already in pypi, cut a new tag if you want to upload another version." + exit 1 +fi + +echo "$CURRENT_TAG_VERSION currently does not exist in pypi, allowing pipeline to continue." +exit 0 diff --git a/continuous-delivery/test_version_exists.yml b/continuous-delivery/test_version_exists.yml new file mode 100644 index 0000000..2704ba7 --- /dev/null +++ b/continuous-delivery/test_version_exists.yml @@ -0,0 +1,21 @@ +version: 0.2 +#this build spec assumes the ubuntu 14.04 trusty image +#this build run simply verifies we haven't published something at this tag yet. +#if we have we fail the build and stop the pipeline, if we haven't we allow the pipeline to run. +phases: + install: + commands: + - sudo apt-get update -y + - sudo apt-get install python3 python3-pip -y + - pip3 install --upgrade setuptools + pre_build: + commands: + - echo Build start on `date` + build: + commands: + - cd aws-iot-device-sdk-python + - bash ./continuous-delivery/test_version_exists + post_build: + commands: + - echo Build completed on `date` + diff --git a/samples/basicPubSub/basicPubSubProxy.py b/samples/basicPubSub/basicPubSubProxy.py new file mode 100644 index 0000000..929ef42 --- /dev/null +++ b/samples/basicPubSub/basicPubSubProxy.py @@ -0,0 +1,136 @@ +''' +/* + * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. 
This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + ''' + +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient +import logging +import time +import argparse +import json + +AllowedActions = ['both', 'publish', 'subscribe'] + +# Custom MQTT message callback +def customCallback(client, userdata, message): + print("Received a new message: ") + print(message.payload) + print("from topic: ") + print(message.topic) + print("--------------\n\n") + + +# Read in command-line parameters +parser = argparse.ArgumentParser() +parser.add_argument("-e", "--endpoint", action="store", required=True, dest="host", help="Your AWS IoT custom endpoint") +parser.add_argument("-r", "--rootCA", action="store", required=True, dest="rootCAPath", help="Root CA file path") +parser.add_argument("-c", "--cert", action="store", dest="certificatePath", help="Certificate file path") +parser.add_argument("-k", "--key", action="store", dest="privateKeyPath", help="Private key file path") +parser.add_argument("-p", "--port", action="store", dest="port", type=int, help="Port number override") +parser.add_argument("-w", "--websocket", action="store_true", dest="useWebsocket", default=False, + help="Use MQTT over WebSocket") +parser.add_argument("-id", "--clientId", action="store", dest="clientId", default="basicPubSub", + help="Targeted client id") +parser.add_argument("-t", "--topic", action="store", dest="topic", default="sdk/test/Python", help="Targeted topic") +parser.add_argument("-m", "--mode", action="store", dest="mode", default="both", + help="Operation modes: %s"%str(AllowedActions)) +parser.add_argument("-M", "--message", action="store", dest="message", default="Hello World!", + help="Message to publish") + +args = parser.parse_args() +host = args.host +rootCAPath = args.rootCAPath +certificatePath = 
args.certificatePath +privateKeyPath = args.privateKeyPath +port = args.port +useWebsocket = args.useWebsocket +clientId = args.clientId +topic = args.topic + +if args.mode not in AllowedActions: + parser.error("Unknown --mode option %s. Must be one of %s" % (args.mode, str(AllowedActions))) + exit(2) + +if args.useWebsocket and args.certificatePath and args.privateKeyPath: + parser.error("X.509 cert authentication and WebSocket are mutual exclusive. Please pick one.") + exit(2) + +if not args.useWebsocket and (not args.certificatePath or not args.privateKeyPath): + parser.error("Missing credentials for authentication.") + exit(2) + +# Port defaults +if args.useWebsocket and not args.port: # When no port override for WebSocket, default to 443 + port = 443 +if not args.useWebsocket and not args.port: # When no port override for non-WebSocket, default to 8883 + port = 8883 + +# Configure logging +logger = logging.getLogger("AWSIoTPythonSDK.core") +logger.setLevel(logging.DEBUG) +streamHandler = logging.StreamHandler() +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +streamHandler.setFormatter(formatter) +logger.addHandler(streamHandler) + +# Init AWSIoTMQTTClient +myAWSIoTMQTTClient = None +if useWebsocket: + myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId, useWebsocket=True) + myAWSIoTMQTTClient.configureEndpoint(host, port) + myAWSIoTMQTTClient.configureCredentials(rootCAPath) +else: + myAWSIoTMQTTClient = AWSIoTMQTTClient(clientId) + myAWSIoTMQTTClient.configureEndpoint(host, port) + myAWSIoTMQTTClient.configureCredentials(rootCAPath, privateKeyPath, certificatePath) + +# AWSIoTMQTTClient connection configuration +myAWSIoTMQTTClient.configureAutoReconnectBackoffTime(1, 32, 20) +myAWSIoTMQTTClient.configureOfflinePublishQueueing(-1) # Infinite offline Publish queueing +myAWSIoTMQTTClient.configureDrainingFrequency(2) # Draining: 2 Hz +myAWSIoTMQTTClient.configureConnectDisconnectTimeout(10) # 10 sec 
+myAWSIoTMQTTClient.configureMQTTOperationTimeout(5) # 5 sec + +# AWSIoTMQTTClient socket configuration +# import pysocks to help us build a socket that supports a proxy configuration +import socks + +# set proxy arguments (for SOCKS5 proxy: proxy_type=2, for HTTP proxy: proxy_type=3) +proxy_config = {"proxy_addr":, "proxy_port":, "proxy_type":} + +# create anonymous function to handle socket creation +socket_factory = lambda: socks.create_connection((host, port), **proxy_config) +myAWSIoTMQTTClient.configureSocketFactory(socket_factory) + +# Connect and subscribe to AWS IoT +myAWSIoTMQTTClient.connect() +if args.mode == 'both' or args.mode == 'subscribe': + myAWSIoTMQTTClient.subscribe(topic, 1, customCallback) +time.sleep(2) + +# Publish to the same topic in a loop forever +loopCount = 0 +while True: + if args.mode == 'both' or args.mode == 'publish': + message = {} + message['message'] = args.message + message['sequence'] = loopCount + messageJson = json.dumps(message) + myAWSIoTMQTTClient.publish(topic, messageJson, 1) + if args.mode == 'publish': + print('Published topic %s: %s\n' % (topic, messageJson)) + loopCount += 1 + time.sleep(1) + diff --git a/samples/greengrass/basicDiscovery.py b/samples/greengrass/basicDiscovery.py index 73acb4d..a6fcd61 100644 --- a/samples/greengrass/basicDiscovery.py +++ b/samples/greengrass/basicDiscovery.py @@ -47,6 +47,8 @@ def customOnMessage(message): help="Operation modes: %s"%str(AllowedActions)) parser.add_argument("-M", "--message", action="store", dest="message", default="Hello World!", help="Message to publish") +#--print_discover_resp_only used for delopyment testing. The test run will return 0 as long as the SDK installed correctly. 
+parser.add_argument("-p", "--print_discover_resp_only", action="store_true", dest="print_only", default=False) args = parser.parse_args() host = args.host @@ -56,15 +58,28 @@ def customOnMessage(message): clientId = args.thingName thingName = args.thingName topic = args.topic +print_only = args.print_only if args.mode not in AllowedActions: parser.error("Unknown --mode option %s. Must be one of %s" % (args.mode, str(AllowedActions))) exit(2) if not args.certificatePath or not args.privateKeyPath: - parser.error("Missing credentials for authentication.") + parser.error("Missing credentials for authentication, you must specify --cert and --key args.") exit(2) +if not os.path.isfile(rootCAPath): + parser.error("Root CA path does not exist {}".format(rootCAPath)) + exit(3) + +if not os.path.isfile(certificatePath): + parser.error("No certificate found at {}".format(certificatePath)) + exit(3) + +if not os.path.isfile(privateKeyPath): + parser.error("No private key found at {}".format(privateKeyPath)) + exit(3) + # Configure logging logger = logging.getLogger("AWSIoTPythonSDK.core") logger.setLevel(logging.DEBUG) @@ -82,7 +97,7 @@ def customOnMessage(message): discoveryInfoProvider.configureCredentials(rootCAPath, certificatePath, privateKeyPath) discoveryInfoProvider.configureTimeout(10) # 10 sec -retryCount = MAX_DISCOVERY_RETRIES +retryCount = MAX_DISCOVERY_RETRIES if not print_only else 1 discovered = False groupCA = None coreInfo = None @@ -111,19 +126,22 @@ def customOnMessage(message): except DiscoveryInvalidRequestException as e: print("Invalid discovery request detected!") print("Type: %s" % str(type(e))) - print("Error message: %s" % e.message) + print("Error message: %s" % str(e)) print("Stopping...") break except BaseException as e: print("Error in discovery!") print("Type: %s" % str(type(e))) - print("Error message: %s" % e.message) + print("Error message: %s" % str(e)) retryCount -= 1 print("\n%d/%d retries left\n" % (retryCount, MAX_DISCOVERY_RETRIES)) 
print("Backing off...\n") backOffCore.backOff() if not discovered: + # With print_discover_resp_only flag, we only woud like to check if the API get called correctly. + if print_only: + sys.exit(0) print("Discovery failed after %d retries. Exiting...\n" % (MAX_DISCOVERY_RETRIES)) sys.exit(-1) @@ -145,7 +163,7 @@ def customOnMessage(message): except BaseException as e: print("Error in connect!") print("Type: %s" % str(type(e))) - print("Error message: %s" % e.message) + print("Error message: %s" % str(e)) if not connected: print("Cannot connect to core %s. Exiting..." % coreInfo.coreThingArn) diff --git a/setup.py b/setup.py index 3846bae..0ca4cfa 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import AWSIoTPythonSDK currentVersion = AWSIoTPythonSDK.__version__ -from distutils.core import setup +from setuptools import setup setup( name = 'AWSIoTPythonSDK', packages=['AWSIoTPythonSDK', 'AWSIoTPythonSDK.core', @@ -20,15 +20,11 @@ download_url = 'https://s3.amazonaws.com/aws-iot-device-sdk-python/aws-iot-device-sdk-python-latest.zip', keywords = ['aws', 'iot', 'mqtt'], classifiers = [ - "Development Status :: 5 - Production/Stable", + "Development Status :: 6 - Mature", "Intended Audience :: Developers", "Natural Language :: English", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", - "Programming Language :: Python :: 2.7", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.3", - "Programming Language :: Python :: 3.4", - "Programming Language :: Python :: 3.5" + "Programming Language :: Python :: 3" ] ) diff --git a/setup_test.py b/setup_test.py new file mode 100644 index 0000000..2a4e78e --- /dev/null +++ b/setup_test.py @@ -0,0 +1,34 @@ +# For test deployment with package AWSIoTPythonSDK. The package name has already taken. Therefore we used an +# alternative name for test pypi. 
+# prod_pypi : AWSIoTPythonSDK +# test_pypi : AWSIoTPythonSDK-V1 +import sys +sys.path.insert(0, 'AWSIoTPythonSDK') +import AWSIoTPythonSDK +currentVersion = AWSIoTPythonSDK.__version__ + +from distutils.core import setup +setup( + name = 'AWSIoTPythonSDK-V1', + packages=['AWSIoTPythonSDK', 'AWSIoTPythonSDK.core', + 'AWSIoTPythonSDK.core.util', 'AWSIoTPythonSDK.core.shadow', 'AWSIoTPythonSDK.core.protocol', + 'AWSIoTPythonSDK.core.jobs', + 'AWSIoTPythonSDK.core.protocol.paho', 'AWSIoTPythonSDK.core.protocol.internal', + 'AWSIoTPythonSDK.core.protocol.connection', 'AWSIoTPythonSDK.core.greengrass', + 'AWSIoTPythonSDK.core.greengrass.discovery', 'AWSIoTPythonSDK.exception'], + version = currentVersion, + description = 'SDK for connecting to AWS IoT using Python.', + author = 'Amazon Web Service', + author_email = '', + url = 'https://github.com/aws/aws-iot-device-sdk-python.git', + download_url = 'https://s3.amazonaws.com/aws-iot-device-sdk-python/aws-iot-device-sdk-python-latest.zip', + keywords = ['aws', 'iot', 'mqtt'], + classifiers = [ + "Development Status :: 6 - Mature", + "Intended Audience :: Developers", + "Natural Language :: English", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + ] +) diff --git a/test-integration/Credentials/.gitignore b/test-integration/Credentials/.gitignore new file mode 100644 index 0000000..94548af --- /dev/null +++ b/test-integration/Credentials/.gitignore @@ -0,0 +1,3 @@ +* +*/ +!.gitignore diff --git a/test-integration/IntegrationTests/IntegrationTestAsyncAPIGeneralNotificationCallbacks.py b/test-integration/IntegrationTests/IntegrationTestAsyncAPIGeneralNotificationCallbacks.py new file mode 100644 index 0000000..577c5fa --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestAsyncAPIGeneralNotificationCallbacks.py @@ -0,0 +1,159 @@ +# This integration test verifies the functionality of asynchronous API for plain MQTT operations, 
as well as general +# notification callbacks. There are 2 phases for this test: +# a) Testing async APIs + onMessage general notification callback +# b) Testing onOnline, onOffline notification callbacks +# To achieve test goal a) and b), the client will follow the routine described below: +# 1. Client does async connect to AWS IoT and captures the CONNACK event and onOnline callback event in the record +# 2. Client does async subscribe to a topic and captures the SUBACK event in the record +# 3. Client does several async publish (QoS1) to the same topic and captures the PUBACK event in the record +# 4. Since client subscribes and publishes to the same topic, onMessage callback should be triggered. We capture these +# events as well in the record. +# 5. Client does async disconnect. This would trigger the offline callback and disconnect event callback. We capture +# them in the record. +# We should be able to receive all ACKs for all operations and corresponding general notification callback triggering +# events. 
+ + +import random +import string +import time +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from TestToolLibrary.checkInManager import checkInManager +from TestToolLibrary.MQTTClientManager import MQTTClientManager +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + + +TOPIC = "topic/test/async_cb/" +MESSAGE_PREFIX = "MagicMessage-" +NUMBER_OF_PUBLISHES = 3 +ROOT_CA = "./test-integration/Credentials/rootCA.crt" +CERT = "./test-integration/Credentials/certificate.pem.crt" +KEY = "./test-integration/Credentials/privateKey.pem.key" +CLIENT_ID = "PySdkIntegTest_AsyncAPI_Callbacks" + +KEY_ON_ONLINE = "OnOnline" +KEY_ON_OFFLINE = "OnOffline" +KEY_ON_MESSAGE = "OnMessage" +KEY_CONNACK = "Connack" +KEY_DISCONNECT = "Disconnect" +KEY_PUBACK = "Puback" +KEY_SUBACK = "Suback" +KEY_UNSUBACK = "Unsuback" + + +class CallbackManager(object): + + def __init__(self): + self.callback_invocation_record = { + KEY_ON_ONLINE : 0, + KEY_ON_OFFLINE : 0, + KEY_ON_MESSAGE : 0, + KEY_CONNACK : 0, + KEY_DISCONNECT : 0, + KEY_PUBACK : 0, + KEY_SUBACK : 0, + KEY_UNSUBACK : 0 + } + + def on_online(self): + print("OMG, I am online!") + self.callback_invocation_record[KEY_ON_ONLINE] += 1 + + def on_offline(self): + print("OMG, I am offline!") + self.callback_invocation_record[KEY_ON_OFFLINE] += 1 + + def on_message(self, message): + print("OMG, I got a message!") + self.callback_invocation_record[KEY_ON_MESSAGE] += 1 + + def connack(self, mid, data): + print("OMG, I got a connack!") + self.callback_invocation_record[KEY_CONNACK] += 1 + + def disconnect(self, mid, data): + print("OMG, I got a disconnect!") + self.callback_invocation_record[KEY_DISCONNECT] += 1 + + def puback(self, mid): + print("OMG, I got a 
puback!") + self.callback_invocation_record[KEY_PUBACK] += 1 + + def suback(self, mid, data): + print("OMG, I got a suback!") + self.callback_invocation_record[KEY_SUBACK] += 1 + + def unsuback(self, mid): + print("OMG, I got an unsuback!") + self.callback_invocation_record[KEY_UNSUBACK] += 1 + + +def get_random_string(length): + return "".join(random.choice(string.ascii_lowercase) for i in range(length)) + + +############################################################################ +# Main # +# Check inputs +my_check_in_manager = checkInManager(2) +my_check_in_manager.verify(sys.argv) +mode = my_check_in_manager.mode +host = my_check_in_manager.host + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." % (mode, sys.version_info[:3])) + +# Performing +############ +print("Connecting...") +callback_manager = CallbackManager() +sdk_mqtt_client = MQTTClientManager()\ + .create_nonconnected_mqtt_client(mode, CLIENT_ID, host, (ROOT_CA, CERT, KEY), callback_manager) +sdk_mqtt_client.connectAsync(keepAliveIntervalSecond=1, ackCallback=callback_manager.connack) # Add callback +print("Wait some time to make sure we are connected...") +time.sleep(10) # 10 sec + +topic = TOPIC + get_random_string(4) +print("Subscribing to topic: " + topic) +sdk_mqtt_client.subscribeAsync(topic, 1, ackCallback=callback_manager.suback, messageCallback=None) +print("Wait some time to make sure we are subscribed...") +time.sleep(3) # 3 sec + +print("Publishing...") +for i in range(NUMBER_OF_PUBLISHES): + sdk_mqtt_client.publishAsync(topic, MESSAGE_PREFIX + str(i), 1, ackCallback=callback_manager.puback) + time.sleep(1) +print("Wait sometime to make sure we finished with publishing...") +time.sleep(2) + +print("Unsubscribing...") +sdk_mqtt_client.unsubscribeAsync(topic, ackCallback=callback_manager.unsuback) +print("Wait sometime to make sure we 
finished with unsubscribing...") +time.sleep(2) + +print("Disconnecting...") +sdk_mqtt_client.disconnectAsync(ackCallback=callback_manager.disconnect) + +print("Wait sometime to let the test result sync...") +time.sleep(3) + +print("Verifying...") +try: + assert callback_manager.callback_invocation_record[KEY_ON_ONLINE] == 1 + assert callback_manager.callback_invocation_record[KEY_CONNACK] == 1 + assert callback_manager.callback_invocation_record[KEY_SUBACK] == 1 + assert callback_manager.callback_invocation_record[KEY_PUBACK] == NUMBER_OF_PUBLISHES + assert callback_manager.callback_invocation_record[KEY_ON_MESSAGE] == NUMBER_OF_PUBLISHES + assert callback_manager.callback_invocation_record[KEY_UNSUBACK] == 1 + assert callback_manager.callback_invocation_record[KEY_DISCONNECT] == 1 + assert callback_manager.callback_invocation_record[KEY_ON_OFFLINE] == 1 +except BaseException as e: + print("Failed! %s" % e.message) +print("Pass!") diff --git a/test-integration/IntegrationTests/IntegrationTestAutoReconnectResubscribe.py b/test-integration/IntegrationTests/IntegrationTestAutoReconnectResubscribe.py new file mode 100644 index 0000000..e6c1bee --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestAutoReconnectResubscribe.py @@ -0,0 +1,202 @@ +# This integration test verifies the functionality in the Python core of Yun/Python SDK +# for auto-reconnect and auto-resubscribe. +# It starts two threads using two different connections to AWS IoT: +# Thread A publishes 10 messages to topicB first, then quiet for a while, and finally +# publishes another 10 messages to topicB. +# Thread B subscribes to topicB and waits to receive messages. Once it receives the first +# 10 messages. It simulates a network error, disconnecting from the broker. In a short time, +# it should automatically reconnect and resubscribe to the previous topic and be able to +# receive the next 10 messages from thread A. 
+# Because of auto-reconnect/resubscribe, thread B should be able to receive all of the +# messages from topicB published by thread A without calling subscribe again in user code +# explicitly. + + +import random +import string +import sys +import time +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +import TestToolLibrary.checkInManager as checkInManager +import TestToolLibrary.MQTTClientManager as MQTTClientManager +from TestToolLibrary import simpleThreadManager +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + +CLIENT_ID_PUB = "integrationTestMQTT_ClientPub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) +CLIENT_ID_SUB = "integrationTestMQTT_ClientSub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) + +# Callback unit +class callbackUnit: + def __init__(self): + self._internalSet = set() + + # Callback fro clientSub + def messageCallback(self, client, userdata, message): + print("Received a new message: " + str(message.payload)) + self._internalSet.add(message.payload.decode('utf-8')) + + def getInternalSet(self): + return self._internalSet + + +# Simulate a network error +def manualNetworkError(srcPyMQTTCore): + # Ensure we close the socket + if srcPyMQTTCore._internal_async_client._paho_client._sock: + srcPyMQTTCore._internal_async_client._paho_client._sock.close() + srcPyMQTTCore._internal_async_client._paho_client._sock = None + if 
srcPyMQTTCore._internal_async_client._paho_client._ssl: + srcPyMQTTCore._internal_async_client._paho_client._ssl.close() + srcPyMQTTCore._internal_async_client._paho_client._ssl = None + # Fake that we have detected the disconnection + srcPyMQTTCore._internal_async_client._paho_client.on_disconnect(None, None, 0) + + +# runFunctionUnit +class runFunctionUnit(): + def __init__(self): + self._messagesPublished = set() + self._topicB = "topicB/" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) + + # ThreadA runtime function: + # 1. Publish 10 messages to topicB. + # 2. Take a nap: 20 sec + # 3. Publish another 10 messages to topicB. + def threadARuntime(self, pyCoreClient): + time.sleep(3) # Ensure a valid subscription + messageCount = 0 + # First 10 messages + while messageCount < 10: + try: + pyCoreClient.publish(self._topicB, str(messageCount), 1, False) + self._messagesPublished.add(str(messageCount)) + except publishError: + print("Publish error!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + messageCount += 1 + time.sleep(0.5) # TPS = 2 + # Take a nap + time.sleep(20) + # Second 10 messages + while messageCount < 20: + try: + pyCoreClient.publish(self._topicB, str(messageCount), 1, False) + self._messagesPublished.add(str(messageCount)) + except publishError: + print("Publish Error!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + messageCount += 1 + time.sleep(0.5) + print("Publish thread terminated.") + + # ThreadB runtime function: + # 1. Subscribe to topicB + # 2. Wait for a while + # 3. Create network blocking, triggering auto-reconnect and auto-resubscribe + # 4. 
On connect, wait for another while + def threadBRuntime(self, pyCoreClient, callback): + try: + # Subscribe to topicB + pyCoreClient.subscribe(self._topicB, 1, callback) + except subscribeTimeoutException: + print("Subscribe timeout!") + except subscribeError: + print("Subscribe error!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + # Wait to get the first 10 messages from thread A + time.sleep(10) + # Block the network for 3 sec + print("Block the network for 3 sec...") + blockingTimeTenMs = 300 + while blockingTimeTenMs != 0: + manualNetworkError(pyCoreClient) + blockingTimeTenMs -= 1 + time.sleep(0.01) + print("Leave it to the main thread to keep waiting...") + + +############################################################################ +# Main # +# Check inputs +myCheckInManager = checkInManager.checkInManager(2) +myCheckInManager.verify(sys.argv) + +rootCA = "./test-integration/Credentials/rootCA.crt" +certificate = "./test-integration/Credentials/certificate.pem.crt" +privateKey = "./test-integration/Credentials/privateKey.pem.key" +mode = myCheckInManager.mode +host = myCheckInManager.host + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." 
% (mode, sys.version_info[:3])) + +# Init Python core and connect +myMQTTClientManager = MQTTClientManager.MQTTClientManager() +clientPub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_PUB, host, rootCA, + certificate, privateKey, mode=mode) +clientSub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_SUB, host, rootCA, + certificate, privateKey, mode=mode) + +if clientPub is None or clientSub is None: + print("Clients not init!") + exit(4) + +print("Two clients are connected!") + +# Configurations +################ +# Callback unit +subCallbackUnit = callbackUnit() +# Threads +mySimpleThreadManager = simpleThreadManager.simpleThreadManager() +myRunFunctionUnit = runFunctionUnit() +publishThreadID = mySimpleThreadManager.createOneTimeThread(myRunFunctionUnit.threadARuntime, [clientPub]) +subscribeThreadID = mySimpleThreadManager.createOneTimeThread(myRunFunctionUnit.threadBRuntime, + [clientSub, subCallbackUnit.messageCallback]) + +# Performing +############ +mySimpleThreadManager.startThreadWithID(subscribeThreadID) +mySimpleThreadManager.startThreadWithID(publishThreadID) +mySimpleThreadManager.joinOneTimeThreadWithID(subscribeThreadID) +mySimpleThreadManager.joinOneTimeThreadWithID(publishThreadID) +time.sleep(3) # Just in case messages arrive slowly + +# Verifying +########### +# Length +print("Check if the length of the two sets are equal...") +print("Received from subscription: " + str(len(subCallbackUnit.getInternalSet()))) +print("Sent through publishes: " + str(len(myRunFunctionUnit._messagesPublished))) +if len(myRunFunctionUnit._messagesPublished) != len(subCallbackUnit.getInternalSet()): + print("Number of messages not equal!") + exit(4) +# Content +print("Check if the content if the two sets are equivalent...") +if myRunFunctionUnit._messagesPublished != subCallbackUnit.getInternalSet(): + print("Sent through publishes:") + print(myRunFunctionUnit._messagesPublished) + print("Received from subscription:") + 
print(subCallbackUnit.getInternalSet()) + print("Set content not equal!") + exit(4) +else: + print("Yes!") diff --git a/test-integration/IntegrationTests/IntegrationTestClientReusability.py b/test-integration/IntegrationTests/IntegrationTestClientReusability.py new file mode 100644 index 0000000..56e77b8 --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestClientReusability.py @@ -0,0 +1,128 @@ +# This integration test verifies the re-usability of SDK MQTT client. +# By saying re-usability, we mean that users should be able to reuse +# the same SDK MQTT client object to connect and invoke other APIs +# after a disconnect API call has been invoked on that client object. +# This test contains 2 clients living 2 separate threads: +# 1. Thread publish: In this thread, a MQTT client will do the following +# in a loop: +# a. Connect to AWS IoT +# b. Publish several messages to a dedicated topic +# c. Disconnect from AWS IoT +# d. Sleep for a while +# 2. Thread subscribe: In this thread, a MQTT client will do nothing +# other than subscribing to a dedicated topic and counting the incoming +# messages. +# Assuming the client is reusable, the subscriber should be able to +# receive all the messages published by the publisher from the same +# client object in different connect sessions. 
import uuid
import time
import sys
sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary")
sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage")

from threading import Event
from TestToolLibrary.checkInManager import checkInManager
from TestToolLibrary.simpleThreadManager import simpleThreadManager
from TestToolLibrary.MQTTClientManager import MQTTClientManager
from TestToolLibrary.skip import skip_when_match
from TestToolLibrary.skip import ModeIsALPN
from TestToolLibrary.skip import Python2VersionLowerThan
from TestToolLibrary.skip import Python3VersionLowerThan


# Unique-per-run identifiers so parallel executions do not collide.
TOPIC = "topic/" + str(uuid.uuid1())
CLIENT_ID_PUB = "publisher" + str(uuid.uuid1())
CLIENT_ID_SUB = "subscriber" + str(uuid.uuid1())
MESSAGE_PREFIX = "Message-"
NUMBER_OF_MESSAGES_PER_LOOP = 3
NUMBER_OF_LOOPS = 3
SUB_WAIT_TIME_OUT_SEC = 20
ROOT_CA = "./test-integration/Credentials/rootCA.crt"
CERT = "./test-integration/Credentials/certificate.pem.crt"
KEY = "./test-integration/Credentials/privateKey.pem.key"


class ClientTwins(object):
    """Publisher/subscriber client pair plus the bookkeeping needed to verify
    that every message published across reconnect sessions was received."""

    def __init__(self, client_pub, client_sub):
        self._client_pub = client_pub
        self._client_sub = client_sub
        self._message_publish_set = set()
        self._message_receive_set = set()
        # Signals the subscriber that the publisher has finished all loops.
        self._publish_done = Event()

    def run_publisher(self, *params):
        """Run NUMBER_OF_LOOPS connect/publish/disconnect sessions, then signal done."""
        self._publish_done.clear()
        time.sleep(3)
        for i in range(NUMBER_OF_LOOPS):
            self._single_publish_loop(i)
            time.sleep(2)
        self._publish_done.set()

    def _single_publish_loop(self, iteration_count):
        # One full connect -> publish burst -> disconnect session; this is the
        # client re-usability under test.
        print("In loop %d: " % iteration_count)
        self._client_pub.connect()
        print("Publisher connected!")
        for i in range(NUMBER_OF_MESSAGES_PER_LOOP):
            message = MESSAGE_PREFIX + str(iteration_count) + "_" + str(i)
            self._client_pub.publish(TOPIC, message, 1)
            print("Publisher published %s to topic %s" % (message, TOPIC))
            self._message_publish_set.add(message.encode("utf-8"))
            time.sleep(1)
        self._client_pub.disconnect()
        print("Publisher disconnected!\n\n")

    def run_subscriber(self, *params):
        """Subscribe once and collect messages until the publisher signals done."""
        self._client_sub.connect()
        self._client_sub.subscribe(TOPIC, 1, self._callback)
        # CONSISTENCY FIX: use the named constant instead of a magic 20
        # (same value, previously defined but unused).
        self._publish_done.wait(SUB_WAIT_TIME_OUT_SEC)
        self._client_sub.disconnect()

    def _callback(self, client, user_data, message):
        self._message_receive_set.add(message.payload)
        print("Subscriber received %s from topic %s" % (message.payload, message.topic))

    def verify(self):
        """Fail unless both sets are non-empty and exactly equal."""
        assert len(self._message_receive_set) != 0
        assert len(self._message_publish_set) != 0
        assert self._message_publish_set == self._message_receive_set


############################################################################
#                                   Main                                   #
my_check_in_manager = checkInManager(2)
my_check_in_manager.verify(sys.argv)
mode = my_check_in_manager.mode
host = my_check_in_manager.host

skip_when_match(ModeIsALPN(mode).And(
    Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0)))
), "This test is not applicable for mode %s and Python version %s. Skipping..." % (mode, sys.version_info[:3]))

simple_thread_manager = simpleThreadManager()

client_pub = MQTTClientManager().create_nonconnected_mqtt_client(mode, CLIENT_ID_PUB, host, (ROOT_CA, CERT, KEY))
print("Client publisher initialized.")
client_sub = MQTTClientManager().create_nonconnected_mqtt_client(mode, CLIENT_ID_SUB, host, (ROOT_CA, CERT, KEY))
print("Client subscriber initialized.")
client_twins = ClientTwins(client_pub, client_sub)
print("Client twins initialized.")

publisher_thread_id = simple_thread_manager.createOneTimeThread(client_twins.run_publisher, [])
subscriber_thread_id = simple_thread_manager.createOneTimeThread(client_twins.run_subscriber, [])
simple_thread_manager.startThreadWithID(subscriber_thread_id)
print("Started subscriber thread.")
simple_thread_manager.startThreadWithID(publisher_thread_id)
print("Started publisher thread.")

print("Main thread starts waiting.")
simple_thread_manager.joinOneTimeThreadWithID(publisher_thread_id)
simple_thread_manager.joinOneTimeThreadWithID(subscriber_thread_id)
print("Main thread waiting is done!")

print("Verifying...")
client_twins.verify()
print("Pass!")


# === (original patch continues with a new file:
# ===  test-integration/IntegrationTests/IntegrationTestConfigurablePublishMessageQueueing.py)
# This integration test verifies the functionality in the Python core of Yun SDK
# for configurable offline publish message queueing.
# For each offline publish queue to be tested, it starts two threads using
# different connections to AWS IoT:
# Thread A subscribes to TopicOnly and wait to receive messages published to
# TopicOnly from ThreadB.
# Thread B publishes to TopicOnly with a manual network error which triggers the
# offline publish message queueing.
# (continuation of the module header comment)
# According to different configurations, the
# internal queue should keep as many publish requests as configured and then
# republish them once the connection is back.
# * After the network is down but before the client gets the notification of being
# * disconnected, QoS0 messages in between this "blind-window" will be lost. However,
# * once the client gets the notification, it should start queueing messages up to
# * its queue size limit.
# * Therefore, all published messages are QoS0, we are verifying the total amount.
# * Configuration to be tested:
# 1. Limited queueing section, limited response (in-flight) section, drop oldest
# 2. Limited queueing section, limited response (in-flight) section, drop newest


import threading
import sys
import time
import random
import string
sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary")
sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage")

import TestToolLibrary.checkInManager as checkInManager
import TestToolLibrary.MQTTClientManager as MQTTClientManager
from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes
from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError
from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueFullException
from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError
from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException
from TestToolLibrary.skip import skip_when_match
from TestToolLibrary.skip import ModeIsALPN
from TestToolLibrary.skip import Python2VersionLowerThan
from TestToolLibrary.skip import Python3VersionLowerThan

CLIENT_ID_PUB = "integrationTestMQTT_ClientPub" + "".join(random.choice(string.ascii_lowercase) for i in range(4))
CLIENT_ID_SUB = "integrationTestMQTT_ClientSub" + "".join(random.choice(string.ascii_lowercase) for i in range(4))


# Class that implements the publishing thread: Thread A, with network failure
# This thread will publish 3 messages first, and then keep publishing
# with a network failure, and then publish another set of 3 messages
# once the connection is resumed.
# * TPS = 1
class threadPub:
    def __init__(self, pyCoreClient, numberOfOfflinePublish, srcTopic):
        self._publishMessagePool = list()
        self._pyCoreClient = pyCoreClient
        self._numberOfOfflinePublish = numberOfOfflinePublish
        self._topic = srcTopic

    # Simulate a network error
    def _manualNetworkError(self):
        # Ensure we close the socket
        if self._pyCoreClient._internal_async_client._paho_client._sock:
            self._pyCoreClient._internal_async_client._paho_client._sock.close()
            self._pyCoreClient._internal_async_client._paho_client._sock = None
        if self._pyCoreClient._internal_async_client._paho_client._ssl:
            self._pyCoreClient._internal_async_client._paho_client._ssl.close()
            self._pyCoreClient._internal_async_client._paho_client._ssl = None
        # Fake that we have detected the disconnection
        self._pyCoreClient._internal_async_client._paho_client.on_disconnect(None, None, 0)

    def _runtime(self):
        messageCount = 0
        # Publish 3 messages
        print("Thread A: Publish 3 messages.")
        step1PublishCount = 3
        while step1PublishCount != 0:
            currentMessage = str(messageCount)
            self._publishMessagePool.append(int(currentMessage))
            try:
                self._pyCoreClient.publish(self._topic, currentMessage, 0, False)
                print("Thread A: Published a message: " + str(currentMessage))
                step1PublishCount -= 1
                messageCount += 1
            except publishError:
                print("Publish Error!")
            except publishQueueFullException:
                print("Internal Publish Queue is FULL!")
            except Exception as e:
                print("Unknown exception!")
                print("Type: " + str(type(e)))
                # str(e), not e.message: exceptions have no .message on Python 3
                print("Message: " + str(e))
            time.sleep(1)
        # Network Failure, publish #numberOfOfflinePublish# messages
        # Scanning rate = 100 TPS
        print(
            "Thread A: Simulate an network error. Keep publishing for " + str(self._numberOfOfflinePublish) + " messages.")
        step2LoopCount = self._numberOfOfflinePublish * 100
        while step2LoopCount != 0:
            self._manualNetworkError()
            if step2LoopCount % 100 == 0:
                currentMessage = str(messageCount)
                self._publishMessagePool.append(int(currentMessage))
                try:
                    self._pyCoreClient.publish(self._topic, currentMessage, 0, False)
                    print("Thread A: Published a message: " + str(currentMessage))
                except publishError:
                    print("Publish Error!")
                except Exception as e:
                    print("Unknown exception!")
                    print("Type: " + str(type(e)))
                    print("Message: " + str(e))
                messageCount += 1
            step2LoopCount -= 1
            time.sleep(0.01)
        # Reconnecting
        reconnectTiming = 0  # count per 0.01 seconds
        while reconnectTiming <= 1000:
            if reconnectTiming % 100 == 0:
                print("Thread A: Counting reconnect time: " + str(reconnectTiming / 100) + "seconds.")
            reconnectTiming += 1
            time.sleep(0.01)
        print("Thread A: Reconnected!")
        # Publish another set of 3 messages
        print("Thread A: Publish 3 messages again.")
        step3PublishCount = 3
        while step3PublishCount != 0:
            currentMessage = str(messageCount)
            self._publishMessagePool.append(int(currentMessage))
            try:
                self._pyCoreClient.publish(self._topic, currentMessage, 0, False)
                print("Thread A: Published a message: " + str(currentMessage))
                step3PublishCount -= 1
                messageCount += 1
            except publishError:
                print("Publish Error!")
            except Exception as e:
                print("Unknown exception!")
                print("Type: " + str(type(e)))
                print("Message: " + str(e))
            time.sleep(1)
        # Wrap up: Sleep for extra 5 seconds
        time.sleep(5)

    def startThreadAndGo(self):
        threadHandler = threading.Thread(target=self._runtime)
        threadHandler.start()
        return threadHandler

    def getPublishMessagePool(self):
        return self._publishMessagePool


# Class that implements the subscribing thread: Thread B.
# Basically this thread does nothing but subscribes to TopicOnly and keeps receiving messages.
class threadSub:
    def __init__(self, pyCoreClient, srcTopic):
        self._keepRunning = True
        self._pyCoreClient = pyCoreClient
        self._subscribeMessagePool = list()
        self._topic = srcTopic

    def _messageCallback(self, client, userdata, message):
        print("Thread B: Received a new message from topic: " + str(message.topic))
        print("Thread B: Payload is: " + str(message.payload))
        self._subscribeMessagePool.append(int(message.payload))

    def _runtime(self):
        # Subscribe to self._topic
        try:
            self._pyCoreClient.subscribe(self._topic, 1, self._messageCallback)
        except subscribeTimeoutException:
            print("Subscribe timeout!")
        except subscribeError:
            print("Subscribe error!")
        except Exception as e:
            print("Unknown exception!")
            print("Type: " + str(type(e)))
            print("Message: " + str(e))
        time.sleep(2.2)
        print("Thread B: Subscribed to " + self._topic)
        print("Thread B: Now wait for Thread A.")
        # Scanning rate is 100 TPS
        while self._keepRunning:
            time.sleep(0.01)

    def startThreadAndGo(self):
        threadHandler = threading.Thread(target=self._runtime)
        threadHandler.start()
        return threadHandler

    def stopRunning(self):
        self._keepRunning = False

    def getSubscribeMessagePool(self):
        return self._subscribeMessagePool


# Generate answer for this integration test using queue configuration
def generateAnswer(data, queueingSize, srcMode):
    """Given the pool of published message numbers, compute the set that should
    survive an offline queue of `queueingSize` under drop mode `srcMode`
    (0 = DROP_OLDEST, 1 = DROP_NEWEST). The first/last 3 messages are sent
    while online and always survive. Raises ValueError for unknown modes."""
    dataInWork = sorted(data)
    dataHead = dataInWork[:3]
    dataTail = dataInWork[-3:]
    dataRet = dataHead
    dataInWork = dataInWork[3:]
    dataInWork = dataInWork[:-3]
    if srcMode == 0:  # DROP_OLDEST
        dataInWork = dataInWork[(-1 * queueingSize):]
        dataRet.extend(dataInWork)
        dataRet.extend(dataTail)
        return sorted(dataRet)
    elif srcMode == 1:  # DROP_NEWEST
        dataInWork = dataInWork[:queueingSize]
        dataRet.extend(dataInWork)
        dataRet.extend(dataTail)
        return sorted(dataRet)
    else:
        print("Unsupported drop behavior!")
        raise ValueError


# Create thread object, load in pyCoreClient and perform the set of integration tests
def performConfigurableOfflinePublishQueueTest(clientPub, clientSub):
    print("Test DROP_NEWEST....")
    clientPub[0].configure_offline_requests_queue(10, DropBehaviorTypes.DROP_NEWEST)  # dropNewest
    clientSub[0].configure_offline_requests_queue(10, DropBehaviorTypes.DROP_NEWEST)  # dropNewest
    # Create Topics
    TopicOnly = "TopicOnly/" + "".join(random.choice(string.ascii_lowercase) for i in range(4))
    # Create thread object
    threadPubObject = threadPub(clientPub[0], 15, TopicOnly)  # Configure to publish 15 messages during network outage
    threadSubObject = threadSub(clientSub[0], TopicOnly)
    threadSubHandler = threadSubObject.startThreadAndGo()
    time.sleep(3)
    threadPubHandler = threadPubObject.startThreadAndGo()
    threadPubHandler.join()
    threadSubObject.stopRunning()
    threadSubHandler.join()
    # Verify result
    print("Verify DROP_NEWEST:")
    answer = generateAnswer(threadPubObject.getPublishMessagePool(), 10, 1)
    print("ANSWER:")
    print(answer)
    print("ACTUAL:")
    print(threadSubObject.getSubscribeMessagePool())
    # We are doing QoS0 publish. We cannot guarantee when the drop will happen since we cannot guarantee a fixed time out
    # of disconnect detection. However, once offline requests queue starts involving, it should queue up to its limit,
    # thus the total number of received messages after draining should always match.
    if len(threadSubObject.getSubscribeMessagePool()) == len(answer):
        print("Passed.")
    else:
        print("Verify DROP_NEWEST failed!!!")
        return False
    time.sleep(5)
    print("Test DROP_OLDEST....")
    clientPub[0].configure_offline_requests_queue(10, DropBehaviorTypes.DROP_OLDEST)  # dropOldest
    clientSub[0].configure_offline_requests_queue(10, DropBehaviorTypes.DROP_OLDEST)  # dropOldest
    # Create thread object
    threadPubObject = threadPub(clientPub[0], 15, TopicOnly)  # Configure to publish 15 messages during network outage
    threadSubObject = threadSub(clientSub[0], TopicOnly)
    threadSubHandler = threadSubObject.startThreadAndGo()
    time.sleep(3)
    threadPubHandler = threadPubObject.startThreadAndGo()
    threadPubHandler.join()
    threadSubObject.stopRunning()
    threadSubHandler.join()
    # Verify result
    print("Verify DROP_OLDEST:")
    answer = generateAnswer(threadPubObject.getPublishMessagePool(), 10, 0)
    print(answer)
    print("ACTUAL:")
    print(threadSubObject.getSubscribeMessagePool())
    if len(threadSubObject.getSubscribeMessagePool()) == len(answer):
        print("Passed.")
    else:
        print("Verify DROP_OLDEST failed!!!")
        return False
    return True


# Check inputs
myCheckInManager = checkInManager.checkInManager(2)
myCheckInManager.verify(sys.argv)

host = myCheckInManager.host
rootCA = "./test-integration/Credentials/rootCA.crt"
certificate = "./test-integration/Credentials/certificate.pem.crt"
privateKey = "./test-integration/Credentials/privateKey.pem.key"
mode = myCheckInManager.mode

skip_when_match(ModeIsALPN(mode).And(
    Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0)))
), "This test is not applicable for mode %s and Python version %s. Skipping..." % (mode, sys.version_info[:3]))

# Init Python core and connect
myMQTTClientManager = MQTTClientManager.MQTTClientManager()
clientPub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_PUB, host, rootCA,
                                                           certificate, privateKey, mode=mode)
clientSub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_SUB, host, rootCA,
                                                           certificate, privateKey, mode=mode)

if clientPub is None or clientSub is None:
    exit(4)

print("Two clients are connected!")

# Functionality test
if not performConfigurableOfflinePublishQueueTest([clientPub], [clientSub]):
    print("The above Drop behavior broken!")
    exit(4)


# === (original patch continues with a new file:
# ===  test-integration/IntegrationTests/IntegrationTestDiscovery.py)
import sys
sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary")
sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage")

from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.greengrass.discovery.providers import DiscoveryInfoProvider
from TestToolLibrary.checkInManager import checkInManager
from TestToolLibrary.skip import skip_when_match
from TestToolLibrary.skip import ModeIsWebSocket


PORT = 8443
CA = "./test-integration/Credentials/rootCA.crt"
CERT = "./test-integration/Credentials/certificate_drs.pem.crt"
KEY = "./test-integration/Credentials/privateKey_drs.pem.key"
TIME_OUT_SEC = 30
# This is a pre-generated test data from DRS integration tests
# The test resources point to account # 003261610643
ID_PREFIX = "Id-"
GGC_ARN = "arn:aws:iot:us-east-1:003261610643:thing/DRS_GGC_0kegiNGA_0"
GGC_PORT_NUMBER_BASE = 8080
GGC_HOST_ADDRESS_PREFIX = "192.168.101."
METADATA_PREFIX = "Description-"
GROUP_ID = "627bf63d-ae64-4f58-a18c-80a44fcf4088"
THING_NAME = "DRS_GGAD_0kegiNGA_0"
EXPECTED_CA_CONTENT = "-----BEGIN CERTIFICATE-----\n" \
                      "MIIEFTCCAv2gAwIBAgIVAPZfc4GMLZPmXbnoaZm6jRDqDs4+MA0GCSqGSIb3DQEB\n" \
                      "CwUAMIGoMQswCQYDVQQGEwJVUzEYMBYGA1UECgwPQW1hem9uLmNvbSBJbmMuMRww\n" \
                      "GgYDVQQLDBNBbWF6b24gV2ViIFNlcnZpY2VzMRMwEQYDVQQIDApXYXNoaW5ndG9u\n" \
                      "MRAwDgYDVQQHDAdTZWF0dGxlMTowOAYDVQQDDDEwMDMyNjE2MTA2NDM6NjI3YmY2\n" \
                      "M2QtYWU2NC00ZjU4LWExOGMtODBhNDRmY2Y0MDg4MCAXDTE3MDUyNTE4NDI1OVoY\n" \
                      "DzIwOTcwNTI1MTg0MjU4WjCBqDELMAkGA1UEBhMCVVMxGDAWBgNVBAoMD0FtYXpv\n" \
                      "bi5jb20gSW5jLjEcMBoGA1UECwwTQW1hem9uIFdlYiBTZXJ2aWNlczETMBEGA1UE\n" \
                      "CAwKV2FzaGluZ3RvbjEQMA4GA1UEBwwHU2VhdHRsZTE6MDgGA1UEAwwxMDAzMjYx\n" \
                      "NjEwNjQzOjYyN2JmNjNkLWFlNjQtNGY1OC1hMThjLTgwYTQ0ZmNmNDA4ODCCASIw\n" \
                      "DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKEWtZtKyJUg2VUwZkbkVtltrfam\n" \
                      "s9LMIdKNA3Wz4zSLhZjKHiTSkQmpZwKle5ziYs6Q5hfeT8WC0FNAVv1JhnwsuGfT\n" \
                      "sG0UO5dSn7wqXOJigKC1CaSGqeFpKB0/a3wR1L6pCGVbLZ86/sPCEPHHJDieQ+Ps\n" \
                      "RnOcUGb4CuIBnI2N+lafWNa4F4KRSVJCEeZ6u4iWVVdIEcDLKlakY45jtVvQqwnz\n" \
                      "3leFsN7PTLEkVq5u1PXSbT5DWv6p+5NoDnGAT7j7Wbr2yJw7DtpBOL6oWkAdbFAQ\n" \
                      "2097e8mIxNYE9xAzRlb5wEr6jZl/8K60v9P83OapMeuOg4JS8FGulHXbDg0CAwEA\n" \
                      "AaMyMDAwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU21ELaPCH9Oh001OS0JMv\n" \
                      "n8hU8dYwDQYJKoZIhvcNAQELBQADggEBABW66eH/+/v9Nq5jtJzflfrqAfBOzWLj\n" \
                      "UTEv6szkYzV5Crr8vnu2P5OlyA0NdiKGiAm0AgoDkf+n9HU3Hc0zm3G/QaAO2UmN\n" \
                      "9MwtIp29BSRf+gd1bX/WZTtl5I5xl290BDfr5o08I6TOf0A4P8IAkGwku5j0IQjM\n" \
                      "ns2HH5UVki155dtmWDEGX6q35KABbsmv3tO1+geJVYnd1QkHzR5IXA12gxlMw9GJ\n" \
                      "+cOw+rwJJ2ZcXo3HFoXBcsPqPOa1SO3vTl3XWQ+jX3vyDsxh/VGoJ4epsjwmJ+dW\n" \
                      "sHJoqsa3ZPDW0LcEuYgdzYWRhumGwH9fJJUx0GS4Tdg4ud+6jpuyflU=\n" \
                      "-----END CERTIFICATE-----\n"
# The expected response from DRS should be:
'''
{
  "GGGroups": [
    {
      "GGGroupId": "627bf63d-ae64-4f58-a18c-80a44fcf4088",
      "Cores": [
        {
          "thingArn": "arn:aws:iot:us-east-1:003261610643:thing\/DRS_GGC_0kegiNGA_0",
          "Connectivity": [
            {
              "Id": "Id-0",
              "HostAddress": "192.168.101.0",
              "PortNumber": 8080,
              "Metadata": "Description-0"
            },
            {
              "Id": "Id-1",
              "HostAddress": "192.168.101.1",
              "PortNumber": 8081,
              "Metadata": "Description-1"
            },
            {
              "Id": "Id-2",
              "HostAddress": "192.168.101.2",
              "PortNumber": 8082,
              "Metadata": "Description-2"
            }
          ]
        }
      ],
      "CAs": [
        "<EXPECTED_CA_CONTENT as above>"
      ]
    }
  ]
}
'''

my_check_in_manager = checkInManager(2)
my_check_in_manager.verify(sys.argv)
mode = my_check_in_manager.mode
host = my_check_in_manager.host


def create_discovery_info_provider():
    """Build a DiscoveryInfoProvider wired to the test endpoint and DRS credentials."""
    discovery_info_provider = DiscoveryInfoProvider()
    discovery_info_provider.configureEndpoint(host, PORT)
    discovery_info_provider.configureCredentials(CA, CERT, KEY)
    discovery_info_provider.configureTimeout(TIME_OUT_SEC)
    return discovery_info_provider


def perform_integ_test_discovery():
    """Run GGC discovery for the pre-provisioned test thing and return the result."""
    discovery_info_provider = create_discovery_info_provider()
    return discovery_info_provider.discover(THING_NAME)


def _verify_connectivity_info(actual_connectivity_info):
    # The test data encodes the sequence number as the last character of the id,
    # and host/port/metadata are all derived from that same number.
    info_id = actual_connectivity_info.id
    sequence_number_string = info_id[-1:]
    assert actual_connectivity_info.host == GGC_HOST_ADDRESS_PREFIX + sequence_number_string
    assert actual_connectivity_info.port == GGC_PORT_NUMBER_BASE + int(sequence_number_string)
    assert actual_connectivity_info.metadata == METADATA_PREFIX + sequence_number_string


def _verify_connectivity_info_list(actual_connectivity_info_list):
    for actual_connectivity_info in actual_connectivity_info_list:
        _verify_connectivity_info(actual_connectivity_info)


def _verify_ggc_info(actual_ggc_info):
    assert actual_ggc_info.coreThingArn == GGC_ARN
    assert actual_ggc_info.groupId == GROUP_ID
    _verify_connectivity_info_list(actual_ggc_info.connectivityInfoList)


def _verify_ca_list(ca_list):
    assert len(ca_list) == 1
    try:
        # Newer payload shape: a (group_id, ca) pair.
        group_id, ca = ca_list[0]
        assert group_id == GROUP_ID
        assert ca == EXPECTED_CA_CONTENT
    except Exception:
        # Fallback shape: a bare CA string. (Narrowed from a bare `except:`,
        # which also swallowed SystemExit/KeyboardInterrupt.)
        assert ca_list[0] == EXPECTED_CA_CONTENT


def verify_all_cores(discovery_info):
    print("Verifying \"getAllCores\"...")
    ggc_info_list = discovery_info.getAllCores()
    assert len(ggc_info_list) == 1
    _verify_ggc_info(ggc_info_list[0])
    print("Pass!")


def verify_all_cas(discovery_info):
    print("Verifying \"getAllCas\"...")
    ca_list = discovery_info.getAllCas()
    _verify_ca_list(ca_list)
    print("Pass!")


def verify_all_groups(discovery_info):
    print("Verifying \"getAllGroups\"...")
    group_list = discovery_info.getAllGroups()
    assert len(group_list) == 1
    group_info = group_list[0]
    _verify_ca_list(group_info.caList)
    _verify_ggc_info(group_info.coreConnectivityInfoList[0])
    print("Pass!")


def verify_group_object(discovery_info):
    print("Verifying \"toObjectAtGroupLevel\"...")
    group_info_object = discovery_info.toObjectAtGroupLevel()
    _verify_connectivity_info(group_info_object
                              .get(GROUP_ID)
                              .getCoreConnectivityInfo(GGC_ARN)
                              .getConnectivityInfo(ID_PREFIX + "0"))
    _verify_connectivity_info(group_info_object
                              .get(GROUP_ID)
                              .getCoreConnectivityInfo(GGC_ARN)
                              .getConnectivityInfo(ID_PREFIX + "1"))
    _verify_connectivity_info(group_info_object
                              .get(GROUP_ID)
                              .getCoreConnectivityInfo(GGC_ARN)
                              .getConnectivityInfo(ID_PREFIX + "2"))
    print("Pass!")


############################################################################
#                                   Main                                   #

skip_when_match(ModeIsWebSocket(mode), "This test is not applicable for mode: %s. Skipping..." % mode)

# GG Discovery only applies mutual auth with cert
try:
    discovery_info = perform_integ_test_discovery()

    verify_all_cores(discovery_info)
    verify_all_cas(discovery_info)
    verify_all_groups(discovery_info)
    verify_group_object(discovery_info)
except BaseException as e:
    # str(e), not e.message: BaseException has no .message on Python 3, so the
    # old handler itself raised AttributeError instead of reporting the failure.
    print("Failed! " + str(e))
    exit(4)


# === (original patch continues with a new file:
# ===  test-integration/IntegrationTests/IntegrationTestJobsClient.py)
# This integration test verifies the jobs client functionality in the
# Python SDK.
# It performs a number of basic operations without expecting an actual job or
# job execution to be present.
The callbacks associated with these actions +# are written to accept and pass server responses given when no jobs or job +# executions exist. +# Finally, the tester pumps through all jobs queued for the given thing +# doing a basic echo of the job document and updating the job execution +# to SUCCEEDED + +import random +import string +import time +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from TestToolLibrary import simpleThreadManager +import TestToolLibrary.checkInManager as checkInManager +import TestToolLibrary.MQTTClientManager as MQTTClientManager +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTThingJobsClient +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicReplyType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionStatus + +import threading +import datetime +import argparse +import json + +IOT_JOBS_MQTT_RESPONSE_WAIT_SECONDS = 5 +CLIENT_ID = "integrationTestMQTT_Client" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) + +class JobsMessageProcessor(object): + def __init__(self, awsIoTMQTTThingJobsClient, clientToken): + #keep track of this to correlate request/responses + self.clientToken = clientToken + self.awsIoTMQTTThingJobsClient = awsIoTMQTTThingJobsClient + + def 
_setupCallbacks(self): + print('Creating test subscriptions...') + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.getPendingJobAcceptedCallback, jobExecutionTopicType.JOB_GET_PENDING_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.getPendingJobRejectedCallback, jobExecutionTopicType.JOB_GET_PENDING_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.describeJobExecAcceptedCallback, jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, '+') + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.describeJobExecRejectedCallback, jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, '+') + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.startNextPendingJobAcceptedCallback, jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.startNextPendingJobRejectedCallback, jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.updateJobAcceptedCallback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, '+') + assert True == self.awsIoTMQTTThingJobsClient.createJobSubscription(self.updateJobRejectedCallback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, '+') + + def getPendingJobAcceptedCallback(self, client, userdata, message): + self.testResult = (True, 'GetPending accepted callback invoked!') + self.waitEvent.set() + + def getPendingJobRejectedCallback(self, client, userdata, message): + self.testResult = (False, 'GetPending rejection callback 
invoked!') + self.waitEvent.set() + + def describeJobExecAcceptedCallback(self, client, userdata, message): + self.testResult = (True, 'DescribeJobExecution accepted callback invoked!') + self.waitEvent.set() + + def describeJobExecRejectedCallback(self, client, userdata, message): + self.testResult = (False, 'DescribeJobExecution rejected callback invoked!') + self.waitEvent.set() + + def startNextPendingJobAcceptedCallback(self, client, userdata, message): + self.testResult = (True, 'StartNextPendingJob accepted callback invoked!') + payload = json.loads(message.payload.decode('utf-8')) + if 'execution' not in payload: + self.done = True + else: + print('Found job! Document: ' + payload['execution']['jobDocument']) + threading.Thread(target=self.awsIoTMQTTThingJobsClient.sendJobsUpdate(payload['execution']['jobId'], jobExecutionStatus.JOB_EXECUTION_SUCCEEDED)).start() + self.waitEvent.set() + + def startNextPendingJobRejectedCallback(self, client, userdata, message): + self.testResult = (False, 'StartNextPendingJob rejected callback invoked!') + self.waitEvent.set() + + def updateJobAcceptedCallback(self, client, userdata, message): + self.testResult = (True, 'UpdateJob accepted callback invoked!') + self.waitEvent.set() + + def updateJobRejectedCallback(self, client, userdata, message): + #rejection is still a successful test because job IDs may or may not exist, and could exist in unknown state + self.testResult = (True, 'UpdateJob rejected callback invoked!') + self.waitEvent.set() + + def executeJob(self, execution): + print('Executing job ID, version, number: {}, {}, {}'.format(execution['jobId'], execution['versionNumber'], execution['executionNumber'])) + print('With jobDocument: ' + json.dumps(execution['jobDocument'])) + + def runTests(self): + print('Running jobs tests...') + ##create subscriptions + self._setupCallbacks() + + #make publish calls + self._init_test_wait() + 
self._test_send_response_confirm(self.awsIoTMQTTThingJobsClient.sendJobsDescribe('$next')) + + self._init_test_wait() + self._test_send_response_confirm(self.awsIoTMQTTThingJobsClient.sendJobsUpdate('junkJobId', jobExecutionStatus.JOB_EXECUTION_SUCCEEDED)) + + self._init_test_wait() + self._test_send_response_confirm(self.awsIoTMQTTThingJobsClient.sendJobsQuery(jobExecutionTopicType.JOB_GET_PENDING_TOPIC)) + + self._init_test_wait() + self._test_send_response_confirm(self.awsIoTMQTTThingJobsClient.sendJobsStartNext()) + + self.processAllJobs() + + def processAllJobs(self): + #process all enqueued jobs + print('Processing all jobs found in queue for thing...') + self.done = False + while not self.done: + self._attemptStartNextJob() + time.sleep(5) + + def _attemptStartNextJob(self): + statusDetails = {'StartedBy': 'ClientToken: {} on {}'.format(self.clientToken, datetime.datetime.now().isoformat())} + threading.Thread(target=self.awsIoTMQTTThingJobsClient.sendJobsStartNext, kwargs = {'statusDetails': statusDetails}).start() + + def _init_test_wait(self): + self.testResult = (False, 'Callback not invoked') + self.waitEvent = threading.Event() + + def _test_send_response_confirm(self, sendResult): + if not sendResult: + print('Failed to send jobs message') + exit(4) + else: + #wait 25 seconds for expected callback response to happen + if not self.waitEvent.wait(IOT_JOBS_MQTT_RESPONSE_WAIT_SECONDS): + print('Did not receive expected callback within %d second timeout' % IOT_JOBS_MQTT_RESPONSE_WAIT_SECONDS) + exit(4) + elif not self.testResult[0]: + print('Callback result has failed the test with message: ' + self.testResult[1]) + exit(4) + else: + print('Recieved expected result: ' + self.testResult[1]) + + +############################################################################ +# Main # +# Check inputs +myCheckInManager = checkInManager.checkInManager(2) +myCheckInManager.verify(sys.argv) + +host = myCheckInManager.host +rootCA = 
"./test-integration/Credentials/rootCA.crt" +certificate = "./test-integration/Credentials/certificate.pem.crt" +privateKey = "./test-integration/Credentials/privateKey.pem.key" +mode = myCheckInManager.mode + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." % (mode, sys.version_info[:3])) + +# Init Python core and connect +myMQTTClientManager = MQTTClientManager.MQTTClientManager() +client = myMQTTClientManager.create_connected_mqtt_client(mode, CLIENT_ID, host, (rootCA, certificate, privateKey)) + +clientId = 'AWSPythonkSDKTestThingClient' +thingName = 'AWSPythonkSDKTestThing' +jobsClient = AWSIoTMQTTThingJobsClient(clientId, thingName, QoS=1, awsIoTMQTTClient=client) + +print('Connecting to MQTT server and setting up callbacks...') +jobsMsgProc = JobsMessageProcessor(jobsClient, clientId) +print('Starting jobs tests...') +jobsMsgProc.runTests() +print('Done running jobs tests') + +#can call this on the jobsClient, or myAWSIoTMQTTClient directly +jobsClient.disconnect() diff --git a/test-integration/IntegrationTests/IntegrationTestMQTTConnection.py b/test-integration/IntegrationTests/IntegrationTestMQTTConnection.py new file mode 100644 index 0000000..9adc38c --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestMQTTConnection.py @@ -0,0 +1,177 @@ +# This integration test verifies the functionality in the Python core of IoT Yun/Python SDK +# for basic MQTT connection. +# It starts two threads using two different connections to AWS IoT: +# Thread A: publish to "deviceSDK/PyIntegrationTest/Topic", X messages, QoS1, TPS=50 +# Thread B: subscribe to "deviceSDK/PyIntegrationTest/Topic", QoS1 +# Thread B will be started first with extra delay to ensure a valid subscription +# Then thread A will be started. 
+# Verify send/receive messages are equivalent + + +import random +import string +import time +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from TestToolLibrary import simpleThreadManager +import TestToolLibrary.checkInManager as checkInManager +import TestToolLibrary.MQTTClientManager as MQTTClientManager +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + + +API_TYPE_SYNC = "sync" +API_TYPE_ASYNC = "async" + +CLIENT_ID_PUB = "integrationTestMQTT_ClientPub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) +CLIENT_ID_SUB = "integrationTestMQTT_ClientSub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) + + + +# Callback unit for subscribe +class callbackUnit: + def __init__(self, srcSet, apiType): + self._internalSet = srcSet + self._apiType = apiType + + # Callback for clientSub + def messageCallback(self, client, userdata, message): + print(self._apiType + ": Received a new message: " + str(message.payload)) + self._internalSet.add(message.payload.decode('utf-8')) + + def getInternalSet(self): + return self._internalSet + + +# Run function unit +class runFunctionUnit: + def __init__(self, apiType): + self._messagesPublished = set() + self._apiType = apiType + + # Run function for publish thread (one time) + def threadPublish(self, pyCoreClient, numberOfTotalMessages, topic, TPS): + # One time thread + time.sleep(3) # Extra waiting time for valid 
subscription + messagesLeftToBePublished = numberOfTotalMessages + while messagesLeftToBePublished != 0: + try: + currentMessage = str(messagesLeftToBePublished) + self._performPublish(pyCoreClient, topic, 1, currentMessage) + self._messagesPublished.add(currentMessage) + except publishError: + print("Publish Error for message: " + currentMessage) + except Exception as e: + print("Unknown exception: " + str(type(e)) + " " + str(e.message)) + messagesLeftToBePublished -= 1 + time.sleep(1 / float(TPS)) + print("End of publish thread.") + + def _performPublish(self, pyCoreClient, topic, qos, payload): + if self._apiType == API_TYPE_SYNC: + pyCoreClient.publish(topic, payload, qos, False) + if self._apiType == API_TYPE_ASYNC: + pyCoreClient.publish_async(topic, payload, qos, False, None) # TODO: See if we can also check PUBACKs + + +############################################################################ +# Main # +# Check inputs +myCheckInManager = checkInManager.checkInManager(3) +myCheckInManager.verify(sys.argv) + +host = myCheckInManager.host +rootCA = "./test-integration/Credentials/rootCA.crt" +certificate = "./test-integration/Credentials/certificate.pem.crt" +privateKey = "./test-integration/Credentials/privateKey.pem.key" +mode = myCheckInManager.mode + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." 
% (mode, sys.version_info[:3])) + +# Init Python core and connect +myMQTTClientManager = MQTTClientManager.MQTTClientManager() +clientPub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_PUB, host, rootCA, + certificate, privateKey, mode=mode) +clientSub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_SUB, host, rootCA, + certificate, privateKey, mode=mode) + +if clientPub is None or clientSub is None: + exit(4) + +print("Two clients are connected!") + +# Configurations +################ +# Data/Data pool +TPS = 20 +numberOfTotalMessagesAsync = myCheckInManager.customParameter +numberOfTotalMessagesSync = numberOfTotalMessagesAsync / 10 +subSetAsync = set() +subSetSync = set() +subCallbackUnitAsync = callbackUnit(subSetAsync, API_TYPE_ASYNC) +subCallbackUnitSync = callbackUnit(subSetSync, API_TYPE_SYNC) +syncTopic = "YunSDK/PyIntegrationTest/Topic/sync" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) +print(syncTopic) +asyncTopic = "YunSDK/PyIntegrationTest/Topic/async" + "".join(random.choice(string.ascii_lowercase) for j in range(4)) +# clientSub +try: + clientSub.subscribe(asyncTopic, 1, subCallbackUnitAsync.messageCallback) + clientSub.subscribe(syncTopic, 1, subCallbackUnitSync.messageCallback) + time.sleep(3) +except subscribeTimeoutException: + print("Subscribe timeout!") +except subscribeError: + print("Subscribe error!") +except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) +# Threads +mySimpleThreadManager = simpleThreadManager.simpleThreadManager() +myRunFunctionUnitSyncPub = runFunctionUnit(API_TYPE_SYNC) +myRunFunctionUnitAsyncPub = runFunctionUnit(API_TYPE_ASYNC) +publishSyncThreadID = mySimpleThreadManager.createOneTimeThread(myRunFunctionUnitSyncPub.threadPublish, + [clientPub, numberOfTotalMessagesSync, syncTopic, TPS]) +publishAsyncThreadID = mySimpleThreadManager.createOneTimeThread(myRunFunctionUnitAsyncPub.threadPublish, + [clientPub, 
numberOfTotalMessagesAsync, asyncTopic, TPS]) + +# Performing +############ +mySimpleThreadManager.startThreadWithID(publishSyncThreadID) +mySimpleThreadManager.startThreadWithID(publishAsyncThreadID) +mySimpleThreadManager.joinOneTimeThreadWithID(publishSyncThreadID) +mySimpleThreadManager.joinOneTimeThreadWithID(publishAsyncThreadID) +time.sleep(numberOfTotalMessagesAsync / float(TPS) * 0.5) + +# Verifying +########### +# Length +print("Check if the length of the two sets are equal...") +print("Received from subscription (sync pub): " + str(len(subCallbackUnitSync.getInternalSet()))) +print("Received from subscription (async pub): " + str(len(subCallbackUnitAsync.getInternalSet()))) +print("Sent through sync publishes: " + str(len(myRunFunctionUnitSyncPub._messagesPublished))) +print("Sent through async publishes: " + str(len(myRunFunctionUnitAsyncPub._messagesPublished))) +if len(myRunFunctionUnitSyncPub._messagesPublished) != len(subCallbackUnitSync.getInternalSet()): + print("[Sync pub] Number of messages not equal!") + exit(4) +if len(myRunFunctionUnitAsyncPub._messagesPublished) != len(subCallbackUnitAsync.getInternalSet()): + print("[Asyn pub] Number of messages not equal!") + exit(4) +# Content +print("Check if the content if the two sets are equivalent...") +if myRunFunctionUnitSyncPub._messagesPublished != subCallbackUnitSync.getInternalSet(): + print("[Sync pub] Set content not equal!") + exit(4) +elif myRunFunctionUnitAsyncPub._messagesPublished != subCallbackUnitAsync.getInternalSet(): + print("[Async pub] Set content not equal!") +else: + print("Yes!") diff --git a/test-integration/IntegrationTests/IntegrationTestOfflineQueueingForSubscribeUnsubscribe.py b/test-integration/IntegrationTests/IntegrationTestOfflineQueueingForSubscribeUnsubscribe.py new file mode 100644 index 0000000..37c1862 --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestOfflineQueueingForSubscribeUnsubscribe.py @@ -0,0 +1,210 @@ +# This integration test verifies 
the functionality of queueing up subscribe/unsubscribe requests submitted by the
# client when it is offline, and draining them out when the client is reconnected. The test contains 2 clients running in
# 2 different threads:
#
# In thread A, client_sub_unsub follows the below workflow:
# 1. Client connects to AWS IoT.
# 2. Client subscribes to "topic_A".
# 3. Experience a simulated network error which brings the client offline.
# 4. While offline, client subscribes to "topic_B" and unsubscribes from "topic_A".
# 5. Client reconnects, comes back online and drains out all offline queued requests.
# 6. Client stays and receives messages published in another thread.
#
# In thread B, client_pub follows the below workflow:
# 1. Client in thread B connects to AWS IoT.
# 2. After client in thread A connects and subscribes to "topic_A", client in thread B publishes messages to "topic_A".
# 3. Client in thread B keeps sleeping until client in thread A goes back online and reaches a stable state (draining done).
# 4. Client in thread B then publishes messages to "topic_A" and "topic_B".
#
# Since client in thread A does an unsubscribe from "topic_A", it should never receive messages published to "topic_A" after
# it reconnects and gets stable. It should have the messages from "topic_A" published at the very beginning.
# Since client in thread A does a subscribe to "topic_B", it should receive messages published to "topic_B" after it
# reconnects and gets stable.
+ + +import random +import string +import time +from threading import Event +from threading import Thread +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus +from TestToolLibrary.checkInManager import checkInManager +from TestToolLibrary.MQTTClientManager import MQTTClientManager +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + + +def get_random_string(length): + return "".join(random.choice(string.ascii_lowercase) for i in range(length)) + +TOPIC_A = "topic/test/offline_sub_unsub/a" + get_random_string(4) +TOPIC_B = "topic/test/offline_sub_unsub/b" + get_random_string(4) +MESSAGE_PREFIX = "MagicMessage-" +NUMBER_OF_PUBLISHES = 3 +ROOT_CA = "./test-integration/Credentials/rootCA.crt" +CERT = "./test-integration/Credentials/certificate.pem.crt" +KEY = "./test-integration/Credentials/privateKey.pem.key" +CLIENT_PUB_ID = "PySdkIntegTest_OfflineSubUnsub_pub" + get_random_string(4) +CLIENT_SUB_UNSUB_ID = "PySdkIntegTest_OfflineSubUnsub_subunsub" + get_random_string(4) +KEEP_ALIVE_SEC = 1 +EVENT_WAIT_TIME_OUT_SEC = 5 + + +class DualClientRunner(object): + + def __init__(self, mode): + self._publish_end_flag = Event() + self._stable_flag = Event() + self._received_messages_topic_a = list() + self._received_messages_topic_b = list() + self.__mode = mode + self._client_pub = self._create_connected_client(CLIENT_PUB_ID) + print("Created connected client pub.") + self._client_sub_unsub = self._create_connected_client(CLIENT_SUB_UNSUB_ID) + print("Created connected client sub/unsub.") + self._client_sub_unsub.subscribe(TOPIC_A, 1, self._collect_sub_messages) + print("Client sub/unsub subscribed 
to topic: %s" % TOPIC_A) + time.sleep(2) # Make sure the subscription is valid + + def _create_connected_client(self, id_prefix): + return MQTTClientManager().create_connected_mqtt_client(self.__mode, id_prefix, host, (ROOT_CA, CERT, KEY)) + + def start(self): + thread_client_sub_unsub = Thread(target=self._thread_client_sub_unsub_runtime) + thread_client_pub = Thread(target=self._thread_client_pub_runtime) + thread_client_sub_unsub.start() + thread_client_pub.start() + thread_client_pub.join() + thread_client_sub_unsub.join() + + def _thread_client_sub_unsub_runtime(self): + print("Start client sub/unsub runtime thread...") + print("Client sub/unsub waits on the 1st round of publishes to end...") + if not self._publish_end_flag.wait(EVENT_WAIT_TIME_OUT_SEC): + raise RuntimeError("Timed out in waiting for the publishes to topic: %s" % TOPIC_A) + print("Client sub/unsub gets notified.") + self._publish_end_flag.clear() + + print("Client sub/unsub now goes offline...") + self._go_offline_and_send_requests() + + # Wait until the connection is stable and then notify + print("Client sub/unsub waits on a stable connection...") + self._wait_until_stable_connection() + + print("Client sub/unsub waits on the 2nd round of publishes to end...") + if not self._publish_end_flag.wait(EVENT_WAIT_TIME_OUT_SEC): + raise RuntimeError("Timed out in waiting for the publishes to topic: %s" % TOPIC_B) + print("Client sub/unsub gets notified.") + self._publish_end_flag.clear() + + print("Client sub/unsub runtime thread ends.") + + def _wait_until_stable_connection(self): + reconnect_timing = 0 + while self._client_sub_unsub._mqtt_core._client_status.get_status() != ClientStatus.STABLE: + time.sleep(0.01) + reconnect_timing += 1 + if reconnect_timing % 100 == 0: + print("Client sub/unsub: Counting reconnect time: " + str(reconnect_timing / 100) + " seconds.") + print("Client sub/unsub: Counting reconnect time result: " + str(float(reconnect_timing) / 100) + " seconds.") + 
        self._stable_flag.set()

    def _collect_sub_messages(self, client, userdata, message):
        # paho-style message callback shared by both subscriptions: routes the
        # raw payload into the per-topic list so verify() can count arrivals.
        if message.topic == TOPIC_A:
            print("Client sub/unsub: Got a message from %s" % TOPIC_A)
            self._received_messages_topic_a.append(message.payload)
        if message.topic == TOPIC_B:
            print("Client sub/unsub: Got a message from %s" % TOPIC_B)
            self._received_messages_topic_b.append(message.payload)

    def _go_offline_and_send_requests(self):
        # Keep forcing the connection down for EVENT_WAIT_TIME_OUT_SEC seconds
        # (100 iterations of 10 ms each). Halfway through the offline window,
        # submit one subscribe (TOPIC_B) and one unsubscribe (TOPIC_A) exactly
        # once; both requests should queue up offline and be drained out when
        # the client reconnects.
        do_once = True
        loop_count = EVENT_WAIT_TIME_OUT_SEC * 100
        while loop_count != 0:
            self._manual_network_error()
            if loop_count % 100 == 0:
                # NOTE(review): under Python 3, loop_count / 100 is a float;
                # %d truncates it for display, which appears intentional here.
                print("Client sub/unsub: Offline time down count: %d sec" % (loop_count / 100))
            if do_once and (loop_count / 100) <= (EVENT_WAIT_TIME_OUT_SEC / 2):
                print("Client sub/unsub: Performing offline sub/unsub...")
                self._client_sub_unsub.subscribe(TOPIC_B, 1, self._collect_sub_messages)
                self._client_sub_unsub.unsubscribe(TOPIC_A)
                print("Client sub/unsub: Done with offline sub/unsub.")
                do_once = False
            loop_count -= 1
            time.sleep(0.01)

    def _manual_network_error(self):
        # Simulate a network failure by reaching into the SDK internals and
        # tearing down the underlying paho transport objects.
        # Ensure we close the socket
        if self._client_sub_unsub._mqtt_core._internal_async_client._paho_client._sock:
            self._client_sub_unsub._mqtt_core._internal_async_client._paho_client._sock.close()
            self._client_sub_unsub._mqtt_core._internal_async_client._paho_client._sock = None
        if self._client_sub_unsub._mqtt_core._internal_async_client._paho_client._ssl:
            self._client_sub_unsub._mqtt_core._internal_async_client._paho_client._ssl.close()
            self._client_sub_unsub._mqtt_core._internal_async_client._paho_client._ssl = None
        # Fake that we have detected the disconnection
        self._client_sub_unsub._mqtt_core._internal_async_client._paho_client.on_disconnect(None, None, 0)

    def _thread_client_pub_runtime(self):
        # Publisher thread: first round of QoS1 publishes to TOPIC_A, then a
        # second round (after the sub/unsub client stabilizes) to both topics.
        print("Start client pub runtime thread...")
        print("Client pub: 1st round of publishes")
        for i in range(NUMBER_OF_PUBLISHES):
            self._client_pub.publish(TOPIC_A, MESSAGE_PREFIX + str(i), 1)
            print("Client 
pub: Published a message") + time.sleep(0.5) + time.sleep(1) + print("Client pub: Publishes done. Notifying...") + self._publish_end_flag.set() + + print("Client pub waits on client sub/unsub to be stable...") + time.sleep(1) + if not self._stable_flag.wait(EVENT_WAIT_TIME_OUT_SEC * 3): # We wait longer for the reconnect/stabilization + raise RuntimeError("Timed out in waiting for client_sub_unsub to be stable") + self._stable_flag.clear() + + print("Client pub: 2nd round of publishes") + for j in range(NUMBER_OF_PUBLISHES): + self._client_pub.publish(TOPIC_B, MESSAGE_PREFIX + str(j), 1) + print("Client pub: Published a message to %s" % TOPIC_B) + self._client_pub.publish(TOPIC_A, MESSAGE_PREFIX + str(j) + "-dup", 1) + print("Client pub: Published a message to %s" % TOPIC_A) + time.sleep(0.5) + time.sleep(1) + print("Client pub: Publishes done. Notifying...") + self._publish_end_flag.set() + + print("Client pub runtime thread ends.") + + def verify(self): + print("Verifying...") + assert len(self._received_messages_topic_a) == NUMBER_OF_PUBLISHES # We should only receive the first round + assert len(self._received_messages_topic_b) == NUMBER_OF_PUBLISHES # We should only receive the second round + print("Pass!") + + +############################################################################ +# Main # +# Check inputs +my_check_in_manager = checkInManager(2) +my_check_in_manager.verify(sys.argv) +mode = my_check_in_manager.mode +host = my_check_in_manager.host + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." 
% (mode, sys.version_info[:3])) + +# Performing +############ +dual_client_runner = DualClientRunner(mode) +dual_client_runner.start() + +# Verifying +########### +dual_client_runner.verify() diff --git a/test-integration/IntegrationTests/IntegrationTestProgressiveBackoff.py b/test-integration/IntegrationTests/IntegrationTestProgressiveBackoff.py new file mode 100644 index 0000000..fc937ef --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestProgressiveBackoff.py @@ -0,0 +1,289 @@ +# This integration test verifies the functionality in the Python core of Yun/Python SDK +# for progressive backoff logic in auto-reconnect. +# It starts two threads using two different connections to AWS IoT: +# Thread B subscribes to "coolTopic" and waits for incoming messages. Network +# failure will happen occasionally in thread B with a variant interval (connected +# period), simulating stable/unstable connection so as to test the reset logic of +# backoff timing. Once thread B is back online, an internal flag will be set to +# notify the other thread, Thread A, to start publishing to the same topic. +# Thread A will publish a set of messages (a fixed number of messages) to "coolTopic" +# using QoS1 and does nothing in the rest of the time. It will only start publishing +# when it gets ready notification from thread B. No network failure will happen in +# thread A. +# Because thread A is always online and only publishes when thread B is back online, +# all messages published to "coolTopic" should be received by thread B. In meantime, +# thread B should have an increasing amount of backoff waiting period until the +# connected period reaches the length of time for a stable connection. After that, +# the backoff waiting period should be reset. +# The following things will be verified to pass the test: +# 1. All messages are received. +# 2. Backoff waiting period increases as configured before the thread reaches to a +# stable connection. +# 3. 
Backoff waiting period does not exceed the maximum allowed time. +# 4. Backoff waiting period is reset after the thread reaches to a stable connection. + + +import string +import random +import time +import threading +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from TestToolLibrary import simpleThreadManager +import TestToolLibrary.checkInManager as checkInManager +import TestToolLibrary.MQTTClientManager as MQTTClientManager +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + +CLIENT_ID_PUB = "integrationTestMQTT_ClientPub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) +CLIENT_ID_SUB = "integrationTestMQTT_ClientSub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) + +# Class that implements all the related threads in the test in a controllable manner +class threadPool: + def __init__(self, srcTotalNumberOfNetworkFailure, clientPub, clientSub): + self._threadBReadyFlag = 0 # 0-Not connected, 1-Connected+Subscribed, -1-ShouldExit + self._threadBReadyFlagMutex = threading.Lock() + self._targetedTopic = "coolTopic" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) + self._publishMessagePool = set() + self._receiveMessagePool = set() + self._roundOfNetworkFailure = 1 + self._totalNumberOfNetworkFailure = 
srcTotalNumberOfNetworkFailure + self._clientPub = clientPub + self._clientSub = clientSub + self._pubCount = 0 + self._reconnectTimeRecord = list() + self._connectedTimeRecord = list() + + # Message callback for collecting incoming messages from the subscribed topic + def _messageCallback(self, client, userdata, message): + print("Thread B: Received new message: " + str(message.payload)) + self._receiveMessagePool.add(message.payload.decode('utf-8')) + + # The one that publishes + def threadARuntime(self): + exitNow = False + while not exitNow: + self._threadBReadyFlagMutex.acquire() + # Thread A is still reconnecting, WAIT! + if self._threadBReadyFlag == 0: + pass + # Thread A is connected and subscribed, PUBLISH! + elif self._threadBReadyFlag == 1: + self._publish3Messages() + self._threadBReadyFlag = 0 # Reset the readyFlag + # Thread A has finished all rounds of network failure/reconnect, EXIT! + else: + exitNow = True + self._threadBReadyFlagMutex.release() + time.sleep(0.01) # 0.01 sec scanning + + # Publish a set of messages: 3 + def _publish3Messages(self): + loopCount = 3 + while loopCount != 0: + try: + currentMessage = "Message" + str(self._pubCount) + print("Test publish to topic : " + self._targetedTopic) + self._clientPub.publish(self._targetedTopic, currentMessage, 1, False) + print("Thread A: Published new message: " + str(currentMessage)) + self._publishMessagePool.add(currentMessage) + self._pubCount += 1 + loopCount -= 1 + except publishError: + print("Publish error!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + time.sleep(0.5) + + # The one that subscribes and has network failures + def threadBRuntime(self): + # Subscribe to the topic + try: + print("Test subscribe to topic : " + self._targetedTopic) + self._clientSub.subscribe(self._targetedTopic, 1, self._messageCallback) + except subscribeTimeoutException: + print("Subscribe timeout!") + except 
subscribeError: + print("Subscribe error!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + print("Thread B: Subscribe request sent. Staring waiting for subscription processing...") + time.sleep(3) + print("Thread B: Done waiting.") + self._threadBReadyFlagMutex.acquire() + self._threadBReadyFlag = 1 + self._threadBReadyFlagMutex.release() + # Start looping with network failure + connectedPeriodSecond = 3 + while self._roundOfNetworkFailure <= self._totalNumberOfNetworkFailure: + self._connectedTimeRecord.append(connectedPeriodSecond) + # Wait for connectedPeriodSecond + print("Thread B: Connected time: " + str(connectedPeriodSecond) + " seconds.") + print("Thread B: Stable time: 60 seconds.") + time.sleep(connectedPeriodSecond) + print("Thread B: Network failure. Round: " + str(self._roundOfNetworkFailure) + ". 0.5 seconds.") + print("Thread B: Backoff time for this round should be: " + str( + self._clientSub._internal_async_client._paho_client._backoffCore._currentBackoffTimeSecond) + " second(s).") + # Set the readyFlag + self._threadBReadyFlagMutex.acquire() + self._threadBReadyFlag = 0 + self._threadBReadyFlagMutex.release() + # Now lose connection for 0.5 seconds, preventing multiple reconnect attempts + loseConnectionLoopCount = 50 + while loseConnectionLoopCount != 0: + self._manualNetworkError() + loseConnectionLoopCount -= 1 + time.sleep(0.01) + # Wait until the connection/subscription is recovered + reconnectTiming = 0 + while self._clientSub._client_status.get_status() != ClientStatus.STABLE: + time.sleep(0.01) + reconnectTiming += 1 + if reconnectTiming % 100 == 0: + print("Thread B: Counting reconnect time: " + str(reconnectTiming / 100) + " seconds.") + print("Thread B: Counting reconnect time result: " + str(float(reconnectTiming) / 100) + " seconds.") + self._reconnectTimeRecord.append(reconnectTiming / 100) + + time.sleep(3) # For valid subscription + + # Update thread 
B status + self._threadBReadyFlagMutex.acquire() + self._threadBReadyFlag = 1 + self._threadBReadyFlagMutex.release() + + # Update connectedPeriodSecond + connectedPeriodSecond += (2 ** (self._roundOfNetworkFailure - 1)) + # Update roundOfNetworkFailure + self._roundOfNetworkFailure += 1 + + # Notify thread A shouldExit + self._threadBReadyFlagMutex.acquire() + self._threadBReadyFlag = -1 + self._threadBReadyFlagMutex.release() + + # Simulate a network error + def _manualNetworkError(self): + # Only the subscriber needs the network error + if self._clientSub._internal_async_client._paho_client._sock: + self._clientSub._internal_async_client._paho_client._sock.close() + self._clientSub._internal_async_client._paho_client._sock = None + if self._clientSub._internal_async_client._paho_client._ssl: + self._clientSub._internal_async_client._paho_client._ssl.close() + self._clientSub._internal_async_client._paho_client._ssl = None + # Fake that we have detected the disconnection + self._clientSub._internal_async_client._paho_client.on_disconnect(None, None, 0) + + def getReconnectTimeRecord(self): + return self._reconnectTimeRecord + + def getConnectedTimeRecord(self): + return self._connectedTimeRecord + + +# Generate the correct backoff timing to compare the test result with +def generateCorrectAnswer(baseTime, maximumTime, stableTime, connectedTimeRecord): + answer = list() + currentTime = baseTime + nextTime = baseTime + for i in range(0, len(connectedTimeRecord)): + if connectedTimeRecord[i] >= stableTime or i == 0: + currentTime = baseTime + else: + currentTime = min(currentTime * 2, maximumTime) + answer.append(currentTime) + return answer + + +# Verify backoff time +# Corresponding element should have no diff or a bias greater than 1.5 +def verifyBackoffTime(answerList, resultList): + result = True + for i in range(0, len(answerList)): + if abs(answerList[i] - resultList[i]) > 1.5: + result = False + break + return result + + 
+############################################################################ +# Main # +# Check inputs +myCheckInManager = checkInManager.checkInManager(3) +myCheckInManager.verify(sys.argv) + +#host via describe-endpoint on this OdinMS: com.amazonaws.iot.device.sdk.credentials.testing.websocket +host = myCheckInManager.host +rootCA = "./test-integration/Credentials/rootCA.crt" +certificate = "./test-integration/Credentials/certificate.pem.crt" +privateKey = "./test-integration/Credentials/privateKey.pem.key" +mode = myCheckInManager.mode + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." % (mode, sys.version_info[:3])) + +# Init Python core and connect +myMQTTClientManager = MQTTClientManager.MQTTClientManager() +clientPub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_PUB, host, rootCA, + certificate, privateKey, mode=mode) +clientSub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_SUB, host, rootCA, + certificate, privateKey, mode=mode) + +if clientPub is None or clientSub is None: + exit(4) + +# Extra configuration for clients +clientPub.configure_reconnect_back_off(1, 16, 60) +clientSub.configure_reconnect_back_off(1, 16, 60) + +print("Two clients are connected!") + +# Configurations +################ +# Custom parameters +NumberOfNetworkFailure = myCheckInManager.customParameter +# ThreadPool object +threadPoolObject = threadPool(NumberOfNetworkFailure, clientPub, clientSub) +# Threads +mySimpleThreadManager = simpleThreadManager.simpleThreadManager() +threadAID = mySimpleThreadManager.createOneTimeThread(threadPoolObject.threadARuntime, []) +threadBID = mySimpleThreadManager.createOneTimeThread(threadPoolObject.threadBRuntime, []) + +# Performing +############ +mySimpleThreadManager.startThreadWithID(threadBID) +mySimpleThreadManager.startThreadWithID(threadAID) 
+mySimpleThreadManager.joinOneTimeThreadWithID(threadBID) +mySimpleThreadManager.joinOneTimeThreadWithID(threadAID) + +# Verifying +########### +print("Verify that all messages are received...") +if threadPoolObject._publishMessagePool == threadPoolObject._receiveMessagePool: + print("Passed. Recv/Pub: " + str(len(threadPoolObject._receiveMessagePool)) + "/" + str( + len(threadPoolObject._publishMessagePool))) +else: + print("Not all messages are received!") + exit(4) +print("Verify reconnect backoff time record...") +print("ConnectedTimeRecord: " + str(threadPoolObject.getConnectedTimeRecord())) +print("ReconnectTimeRecord: " + str(threadPoolObject.getReconnectTimeRecord())) +print("Answer: " + str(generateCorrectAnswer(1, 16, 60, threadPoolObject.getConnectedTimeRecord()))) +if verifyBackoffTime(generateCorrectAnswer(1, 16, 60, threadPoolObject.getConnectedTimeRecord()), + threadPoolObject.getReconnectTimeRecord()): + print("Passed.") +else: + print("Backoff time does not match theoretical value!") + exit(4) diff --git a/test-integration/IntegrationTests/IntegrationTestShadow.py b/test-integration/IntegrationTests/IntegrationTestShadow.py new file mode 100644 index 0000000..9b2d85a --- /dev/null +++ b/test-integration/IntegrationTests/IntegrationTestShadow.py @@ -0,0 +1,248 @@ +# This integration test verifies the functionality in the Python core of Yun/Python SDK +# for IoT shadow operations: shadowUpdate and delta. +# 1. The test generates a X-byte-long random sting and breaks it into a random +# number of chunks, with a fixed length variation from 1 byte to 10 bytes. +# 2. Two threads are created to do shadowUpdate and delta on the same device +# shadow. The update thread updates the desired state with an increasing sequence +# number and a chunk. It is terminated when there are no more chunks to be sent. +# 3. The delta thread listens on delta topic and receives the changes in device +# shadow JSON document. 
It parses out the sequence number and the chunk, then pack +# them into a dictionary with sequence number as the key and the chunk as the value. +# 4. To verify the result of the test, the random string is re-assembled for both +# the update thread and the delta thread to see if they are equal. +# 5. Since shadow operations are all QoS0 (Pub/Sub), it is still a valid case when +# the re-assembled strings are not equal. Then we need to make sure that the number +# of the missing chunks does not exceed 10% of the total number of chunk transmission +# that succeeds. + +import time +import random +import string +import json +import sys +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary") +sys.path.insert(0, "./test-integration/IntegrationTests/TestToolLibrary/SDKPackage") + +from TestToolLibrary import simpleThreadManager +import TestToolLibrary.checkInManager as checkInManager +import TestToolLibrary.MQTTClientManager as MQTTClientManager +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.shadow.deviceShadow import deviceShadow +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.shadow.shadowManager import shadowManager +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException +from TestToolLibrary.skip import skip_when_match +from TestToolLibrary.skip import ModeIsALPN +from TestToolLibrary.skip import Python2VersionLowerThan +from TestToolLibrary.skip import Python3VersionLowerThan + + +# Global configuration +TPS = 1 # Update speed, Spectre does not tolerate high TPS shadow operations... 
+CLIENT_ID_PUB = "integrationTestMQTT_ClientPub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) +CLIENT_ID_SUB = "integrationTestMQTT_ClientSub" + "".join(random.choice(string.ascii_lowercase) for i in range(4)) + + +# Class that manages the generation and chopping of the random string +class GibberishBox: + def __init__(self, length): + self._content = self._generateGibberish(length) + + def getGibberish(self): + return self._content + + # Random string generator: lower/upper case letter + digits + def _generateGibberish(self, length): + s = string.ascii_lowercase + string.digits + string.ascii_uppercase + return ''.join(random.sample(s * length, length)) + + # Spit out the gibberish chunk by chunk (1-10 bytes) + def gibberishSpitter(self): + randomLength = random.randrange(1, 11) + ret = None + if self._content is not None: + ret = self._content[0:randomLength] + self._content = self._content[randomLength:] + return ret + + +# Class that manages the callback function and record of chunks for re-assembling +class callbackContainer: + def __init__(self): + self._internalDictionary = dict() + + def getInternalDictionary(self): + return self._internalDictionary + + def testCallback(self, payload, type, token): + print("Type: " + type) + print(payload) + print("&&&&&&&&&&&&&&&&&&&&") + # This is the shadow delta callback, so the token should be None + if type == "accepted": + JsonDict = json.loads(payload) + try: + sequenceNumber = int(JsonDict['state']['desired']['sequenceNumber']) + gibberishChunk = JsonDict['state']['desired']['gibberishChunk'] + self._internalDictionary[sequenceNumber] = gibberishChunk + except KeyError as e: + print(e.message) + print("No such key!") + else: + JsonDict = json.loads(payload) + try: + sequenceNumber = int(JsonDict['state']['sequenceNumber']) + gibberishChunk = JsonDict['state']['gibberishChunk'] + self._internalDictionary[sequenceNumber] = gibberishChunk + except KeyError as e: + print(e.message) + print("No 
such key!") + + +# Thread runtime function +def threadShadowUpdate(deviceShadow, callback, TPS, gibberishBox, maxNumMessage): + time.sleep(2) + chunkSequence = 0 + while True: + currentChunk = gibberishBox.gibberishSpitter() + if currentChunk != "": + outboundJsonDict = dict() + outboundJsonDict["state"] = dict() + outboundJsonDict["state"]["desired"] = dict() + outboundJsonDict["state"]["desired"]["sequenceNumber"] = chunkSequence + outboundJsonDict["state"]["desired"]["gibberishChunk"] = currentChunk + outboundJSON = json.dumps(outboundJsonDict) + chunkSequence += 1 + try: + deviceShadow.shadowUpdate(outboundJSON, callback, 5) + except publishError: + print("Publish error!") + except subscribeTimeoutException: + print("Subscribe timeout!") + except subscribeError: + print("Subscribe error!") + except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + time.sleep(1 / TPS) + else: + break + print("Update thread completed.") + + +# Re-assemble gibberish +def reAssembleGibberish(srcDict, maxNumMessage): + ret = "" + for i in range(0, maxNumMessage): + try: + ret += srcDict[i] + except KeyError: + pass + return ret + + +# RandomShadowNameSuffix +def randomString(lengthOfString): + return "".join(random.choice(string.ascii_lowercase) for i in range(lengthOfString)) + + +############################################################################ +# Main # +# Check inputs +myCheckInManager = checkInManager.checkInManager(3) +myCheckInManager.verify(sys.argv) + +host = myCheckInManager.host +rootCA = "./test-integration/Credentials/rootCA.crt" +certificate = "./test-integration/Credentials/certificate.pem.crt" +privateKey = "./test-integration/Credentials/privateKey.pem.key" +mode = myCheckInManager.mode + +skip_when_match(ModeIsALPN(mode).And( + Python2VersionLowerThan((2, 7, 10)).Or(Python3VersionLowerThan((3, 5, 0))) +), "This test is not applicable for mode %s and Python verison %s. Skipping..." 
% (mode, sys.version_info[:3])) + +# Init Python core and connect +myMQTTClientManager = MQTTClientManager.MQTTClientManager() +clientPub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_PUB, host, rootCA, + certificate, privateKey, mode=mode) +clientSub = myMQTTClientManager.create_connected_mqtt_core(CLIENT_ID_SUB, host, rootCA, + certificate, privateKey, mode=mode) + +if clientPub is None or clientSub is None: + exit(4) + +print("Two clients are connected!") + +# Configurations +################ +# Data +gibberishLength = myCheckInManager.customParameter +# Init device shadow instance +shadowManager1 = shadowManager(clientPub) +shadowManager2 = shadowManager(clientSub) +shadowName = "GibberChunk" + randomString(5) +deviceShadow1 = deviceShadow(shadowName, True, shadowManager1) +deviceShadow2 = deviceShadow(shadowName, True, shadowManager2) +print("Two device shadow instances are created!") + +# Callbacks +callbackHome_Update = callbackContainer() +callbackHome_Delta = callbackContainer() + +# Listen on delta topic +try: + deviceShadow2.shadowRegisterDeltaCallback(callbackHome_Delta.testCallback) +except subscribeError: + print("Subscribe error!") +except subscribeTimeoutException: + print("Subscribe timeout!") +except Exception as e: + print("Unknown exception!") + print("Type: " + str(type(e))) + print("Message: " + str(e.message)) + +# Init gibberishBox +cipher = GibberishBox(gibberishLength) +gibberish = cipher.getGibberish() +print("Random string: " + gibberish) + +# Threads +mySimpleThreadManager = simpleThreadManager.simpleThreadManager() +updateThreadID = mySimpleThreadManager.createOneTimeThread(threadShadowUpdate, + [deviceShadow1, callbackHome_Update.testCallback, TPS, + cipher, gibberishLength]) + +# Performing +############ +# Functionality test +mySimpleThreadManager.startThreadWithID(updateThreadID) +mySimpleThreadManager.joinOneTimeThreadWithID(updateThreadID) +time.sleep(10) # Just in case + +# Now check the gibberish 
+gibberishUpdateResult = reAssembleGibberish(callbackHome_Update.getInternalDictionary(), gibberishLength) +gibberishDeltaResult = reAssembleGibberish(callbackHome_Delta.getInternalDictionary(), gibberishLength) +print("Update:") +print(gibberishUpdateResult) +print("Delta:") +print(gibberishDeltaResult) +print("Origin:") +print(gibberish) + +if gibberishUpdateResult != gibberishDeltaResult: + # Since shadow operations are on QoS0 (Pub/Sub), there could be a chance + # where incoming messages are missing on the subscribed side + # A ratio of 95% must be guaranteed to pass this test + dictUpdate = callbackHome_Update.getInternalDictionary() + dictDelta = callbackHome_Delta.getInternalDictionary() + maxBaseNumber = max(len(dictUpdate), len(dictDelta)) + diff = float(abs(len(dictUpdate) - len(dictDelta))) / maxBaseNumber + print("Update/Delta string not equal, missing rate is: " + str(diff * 100) + "%.") + # Largest chunk is 10 bytes, total length is X bytes. + # Minimum number of chunks is X/10 + # Maximum missing rate = 10% + if diff > 0.1: + print("Missing rate too high!") + exit(4) diff --git a/test-integration/IntegrationTests/TestToolLibrary/MQTTClientManager.py b/test-integration/IntegrationTests/TestToolLibrary/MQTTClientManager.py new file mode 100644 index 0000000..05c671a --- /dev/null +++ b/test-integration/IntegrationTests/TestToolLibrary/MQTTClientManager.py @@ -0,0 +1,145 @@ +import random +import string +import traceback +from ssl import SSLError + +import TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.protocol.paho.client as paho +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.util.providers import CertificateCredentialsProvider +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.util.providers import CiphersProvider +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes +from 
TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.util.providers import EndpointProvider +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.core.protocol.mqtt_core import MqttCore +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import connectError +from TestToolLibrary.SDKPackage.AWSIoTPythonSDK.exception.AWSIoTExceptions import connectTimeoutException + + +CERT_MUTUAL_AUTH = "MutualAuth" +WEBSOCKET = 'Websocket' +CERT_ALPN = "ALPN" + + +# Class that manages the creation, configuration and connection of MQTT Client +class MQTTClientManager: + + def create_connected_mqtt_client(self, mode, client_id, host, credentials_data, callbacks=None): + client = self.create_nonconnected_mqtt_client(mode, client_id, host, credentials_data, callbacks) + return self._connect_client(client) + + def create_nonconnected_mqtt_client(self, mode, client_id, host, credentials_data, callbacks=None): + if mode == CERT_MUTUAL_AUTH: + sdk_mqtt_client = self._create_nonconnected_mqtt_client_with_cert(client_id, host, 8883, credentials_data) + elif mode == WEBSOCKET: + root_ca, certificate, private_key = credentials_data + sdk_mqtt_client = AWSIoTMQTTClient(clientID=client_id + "_" + self._random_string(3), useWebsocket=True) + sdk_mqtt_client.configureEndpoint(host, 443) + sdk_mqtt_client.configureCredentials(CAFilePath=root_ca) + elif mode == CERT_ALPN: + sdk_mqtt_client = self._create_nonconnected_mqtt_client_with_cert(client_id, host, 443, credentials_data) + else: + raise RuntimeError("Test mode: " + str(mode) + " not supported!") + + sdk_mqtt_client.configureConnectDisconnectTimeout(10) + sdk_mqtt_client.configureMQTTOperationTimeout(5) + + if callbacks is not None: + sdk_mqtt_client.onOnline = callbacks.on_online + sdk_mqtt_client.onOffline = callbacks.on_offline + sdk_mqtt_client.onMessage = callbacks.on_message + + return sdk_mqtt_client + + def _create_nonconnected_mqtt_client_with_cert(self, client_id, host, port, credentials_data): + root_ca, certificate, 
private_key = credentials_data + sdk_mqtt_client = AWSIoTMQTTClient(clientID=client_id + "_" + self._random_string(3)) + sdk_mqtt_client.configureEndpoint(host, port) + sdk_mqtt_client.configureCredentials(CAFilePath=root_ca, KeyPath=private_key, CertificatePath=certificate) + + return sdk_mqtt_client + + def create_connected_mqtt_core(self, client_id, host, root_ca, certificate, private_key, mode): + client = self.create_nonconnected_mqtt_core(client_id, host, root_ca, certificate, private_key, mode) + return self._connect_client(client) + + def create_nonconnected_mqtt_core(self, client_id, host, root_ca, certificate, private_key, mode): + client = None + protocol = None + port = None + is_websocket = False + is_alpn = False + + if mode == CERT_MUTUAL_AUTH: + protocol = paho.MQTTv311 + port = 8883 + elif mode == WEBSOCKET: + protocol = paho.MQTTv31 + port = 443 + is_websocket = True + elif mode == CERT_ALPN: + protocol = paho.MQTTv311 + port = 443 + is_alpn = True + else: + print("Error in creating the client") + + if protocol is None or port is None: + print("Not enough input parameters") + return client # client is None is the necessary params are not there + + try: + client = MqttCore(client_id + "_" + self._random_string(3), True, protocol, is_websocket) + + endpoint_provider = EndpointProvider() + endpoint_provider.set_host(host) + endpoint_provider.set_port(port) + + # Once is_websocket is True, certificate_credentials_provider will NOT be used + # by the client even if it is configured + certificate_credentials_provider = CertificateCredentialsProvider() + certificate_credentials_provider.set_ca_path(root_ca) + certificate_credentials_provider.set_cert_path(certificate) + certificate_credentials_provider.set_key_path(private_key) + + cipher_provider = CiphersProvider() + cipher_provider.set_ciphers(None) + + client.configure_endpoint(endpoint_provider) + client.configure_cert_credentials(certificate_credentials_provider, cipher_provider) + 
client.configure_connect_disconnect_timeout_sec(10) + client.configure_operation_timeout_sec(5) + client.configure_offline_requests_queue(10, DropBehaviorTypes.DROP_NEWEST) + + if is_alpn: + client.configure_alpn_protocols() + except Exception as e: + print("Unknown exception in creating the client: " + str(e)) + finally: + return client + + def _random_string(self, length): + return "".join(random.choice(string.ascii_lowercase) for i in range(length)) + + def _connect_client(self, client): + if client is None: + return client + + try: + client.connect(1) + except connectTimeoutException as e: + print("Connect timeout: " + str(e)) + return None + except connectError as e: + print("Connect error:" + str(e)) + return None + except SSLError as e: + print("Connect SSL error: " + str(e)) + return None + except IOError as e: + print("Credentials not found: " + str(e)) + return None + except Exception as e: + print("Unknown exception in connect: ") + traceback.print_exc() + return None + + return client diff --git a/test-integration/IntegrationTests/TestToolLibrary/SDKPackage/.gitignore b/test-integration/IntegrationTests/TestToolLibrary/SDKPackage/.gitignore new file mode 100644 index 0000000..151aa74 --- /dev/null +++ b/test-integration/IntegrationTests/TestToolLibrary/SDKPackage/.gitignore @@ -0,0 +1,3 @@ +*.* +!.gitignore +!__init__.py \ No newline at end of file diff --git a/test-integration/IntegrationTests/TestToolLibrary/SDKPackage/__init__.py b/test-integration/IntegrationTests/TestToolLibrary/SDKPackage/__init__.py new file mode 100644 index 0000000..1ad354e --- /dev/null +++ b/test-integration/IntegrationTests/TestToolLibrary/SDKPackage/__init__.py @@ -0,0 +1 @@ +__version__ = "1.4.9" diff --git a/test-integration/IntegrationTests/TestToolLibrary/__init__.py b/test-integration/IntegrationTests/TestToolLibrary/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test-integration/IntegrationTests/TestToolLibrary/checkInManager.py 
b/test-integration/IntegrationTests/TestToolLibrary/checkInManager.py new file mode 100644 index 0000000..aeeedd9 --- /dev/null +++ b/test-integration/IntegrationTests/TestToolLibrary/checkInManager.py @@ -0,0 +1,20 @@ +# Class that is responsible for input/dependency verification +import sys + + +class checkInManager: + + def __init__(self, numberOfInputParameters): + self._numberOfInputParameters = numberOfInputParameters + self.mode = None + self.host = None + self.customParameter = None + + def verify(self, args): + # Check if we got the correct command line params + if len(args) != self._numberOfInputParameters + 1: + exit(4) + self.mode = str(args[1]) + self.host = str(args[2]) + if self._numberOfInputParameters + 1 > 3: + self.customParameter = int(args[3]) diff --git a/test-integration/IntegrationTests/TestToolLibrary/simpleThreadManager.py b/test-integration/IntegrationTests/TestToolLibrary/simpleThreadManager.py new file mode 100644 index 0000000..023f8e7 --- /dev/null +++ b/test-integration/IntegrationTests/TestToolLibrary/simpleThreadManager.py @@ -0,0 +1,110 @@ +# Library for controllable threads. Should be able to: +# 1. Create different threads +# 2. Terminate certain thread +# 3. 
Join certain thread + + +import time +import threading + + +# Describe the type, property and control flag for a certain thread +class _threadControlUnit: + # Constants for different thread type + _THREAD_TYPE_ONETIME = 0 + _THREAD_TYPE_LOOP = 1 + + def __init__(self, threadID, threadType, runFunction, runParameters, scanningSpeedSecond=0.01): + if threadID is None or threadType is None or runFunction is None or runParameters is None: + raise ValueError("None input detected.") + if threadType != self._THREAD_TYPE_ONETIME and threadType != self._THREAD_TYPE_LOOP: + raise ValueError("Thread type not supported.") + self.threadID = threadID + self.threadType = threadType + self.runFunction = runFunction + self.runParameters = runParameters + self.threadObject = None # Holds the real thread object + # Now configure control flag, only meaning for loop thread type + if self.threadType == self._THREAD_TYPE_LOOP: + self.stopSign = False # Enabled infinite loop by default + self.scanningSpeedSecond = scanningSpeedSecond + else: + self.stopSign = None # No control flag for one time thread + self.scanningSpeedSecond = -1 + + def _oneTimeRunFunction(self): + self.runFunction(*self.runParameters) + + def _loopRunFunction(self): + while not self.stopSign: + self.runFunction(*self.runParameters) # There should be no manual delay in this function + time.sleep(self.scanningSpeedSecond) + + def _stopMe(self): + self.stopSign = True + + def _setThreadObject(self, threadObject): + self.threadObject = threadObject + + def _getThreadObject(self): + return self.threadObject + + +# Class that manages all threadControlUnit +# Used in a single thread +class simpleThreadManager: + def __init__(self): + self._internalCount = 0 + self._controlCenter = dict() + + def createOneTimeThread(self, runFunction, runParameters): + returnID = self._internalCount + self._controlCenter[self._internalCount] = _threadControlUnit(self._internalCount, + _threadControlUnit._THREAD_TYPE_ONETIME, + runFunction, 
runParameters) + self._internalCount += 1 + return returnID + + def createLoopThread(self, runFunction, runParameters, scanningSpeedSecond): + returnID = self._internalCount + self._controlCenter[self._internalCount] = _threadControlUnit(self._internalCount, + _threadControlUnit._THREAD_TYPE_LOOP, runFunction, + runParameters, scanningSpeedSecond) + self._internalCount += 1 + return returnID + + def stopLoopThreadWithID(self, threadID): + threadToStop = self._controlCenter.get(threadID) + if threadToStop is None: + raise ValueError("No such threadID.") + else: + if threadToStop.threadType == _threadControlUnit._THREAD_TYPE_LOOP: + threadToStop._stopMe() + time.sleep(3 * threadToStop.scanningSpeedSecond) + else: + raise TypeError("Error! Try to stop a one time thread.") + + def startThreadWithID(self, threadID): + threadToStart = self._controlCenter.get(threadID) + if threadToStart is None: + raise ValueError("No such threadID.") + else: + currentThreadType = threadToStart.threadType + newThreadObject = None + if currentThreadType == _threadControlUnit._THREAD_TYPE_LOOP: + newThreadObject = threading.Thread(target=threadToStart._loopRunFunction) + else: # One time thread + newThreadObject = threading.Thread(target=threadToStart._oneTimeRunFunction) + newThreadObject.start() + threadToStart._setThreadObject(newThreadObject) + + def joinOneTimeThreadWithID(self, threadID): + threadToJoin = self._controlCenter.get(threadID) + if threadToJoin is None: + raise ValueError("No such threadID.") + else: + if threadToJoin.threadType == _threadControlUnit._THREAD_TYPE_ONETIME: + currentThreadObject = threadToJoin._getThreadObject() + currentThreadObject.join() + else: + raise TypeError("Error! 
Try to join a loop thread.") diff --git a/test-integration/IntegrationTests/TestToolLibrary/skip.py b/test-integration/IntegrationTests/TestToolLibrary/skip.py new file mode 100644 index 0000000..4d6e5ca --- /dev/null +++ b/test-integration/IntegrationTests/TestToolLibrary/skip.py @@ -0,0 +1,110 @@ +import sys +from TestToolLibrary.MQTTClientManager import CERT_ALPN +from TestToolLibrary.MQTTClientManager import WEBSOCKET + +# This module manages the skip policy validation for each test + + +def skip_when_match(policy, message): + if policy.validate(): + print(message) + exit(0) # Exit the Python interpreter + + +class Policy(object): + + AND = "and" + OR = "or" + + def __init__(self): + self._relations = [] + + # Use caps to avoid collision with Python built-in and/or keywords + def And(self, policy): + self._relations.append((self.AND, policy)) + return self + + def Or(self, policy): + self._relations.append((self.OR, policy)) + return self + + def validate(self): + result = self.validate_impl() + + for element in self._relations: + operand, policy = element + if operand == self.AND: + result = result and policy.validate() + elif operand == self.OR: + result = result or policy.validate() + else: + raise RuntimeError("Unrecognized operand: " + str(operand)) + + return result + + def validate_impl(self): + raise RuntimeError("Not implemented") + + +class PythonVersion(Policy): + + HIGHER = "higher" + LOWER = "lower" + EQUALS = "equals" + + def __init__(self, actual_version, expected_version, operand): + Policy.__init__(self) + self._actual_version = actual_version + self._expected_version = expected_version + self._operand = operand + + def validate_impl(self): + if self._operand == self.LOWER: + return self._actual_version < self._expected_version + elif self._operand == self.HIGHER: + return self._actual_version > self._expected_version + elif self._operand == self.EQUALS: + return self._actual_version == self._expected_version + else: + raise 
RuntimeError("Unsupported operand: " + self._operand) + + +class Python2VersionLowerThan(PythonVersion): + + def __init__(self, version): + PythonVersion.__init__(self, sys.version_info[:3], version, PythonVersion.LOWER) + + def validate_impl(self): + return sys.version_info[0] == 2 and PythonVersion.validate_impl(self) + + +class Python3VersionLowerThan(PythonVersion): + + def __init__(self, version): + PythonVersion.__init__(self, sys.version_info[:3], version, PythonVersion.LOWER) + + def validate_impl(self): + return sys.version_info[0] == 3 and PythonVersion.validate_impl(self) + + +class ModeIs(Policy): + + def __init__(self, actual_mode, expected_mode): + Policy.__init__(self) + self._actual_mode = actual_mode + self._expected_mode = expected_mode + + def validate_impl(self): + return self._actual_mode == self._expected_mode + + +class ModeIsALPN(ModeIs): + + def __init__(self, actual_mode): + ModeIs.__init__(self, actual_mode=actual_mode, expected_mode=CERT_ALPN) + + +class ModeIsWebSocket(ModeIs): + + def __init__(self, actual_mode): + ModeIs.__init__(self, actual_mode=actual_mode, expected_mode=WEBSOCKET) diff --git a/test-integration/Tools/retrieve-key.py b/test-integration/Tools/retrieve-key.py new file mode 100644 index 0000000..3884d7f --- /dev/null +++ b/test-integration/Tools/retrieve-key.py @@ -0,0 +1,59 @@ + +import boto3 +import base64 +import sys +from botocore.exceptions import ClientError + +def main(): + secret_name = sys.argv[1] + region_name = "us-east-1" + + # Create a Secrets Manager client + session = boto3.session.Session() + client = session.client( + service_name='secretsmanager', + region_name=region_name + ) + # In this sample we only handle the specific exceptions for the 'GetSecretValue' API. + # See https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html + # We rethrow the exception by default. 
+ + try: + get_secret_value_response = client.get_secret_value( + SecretId=secret_name + ) + + except ClientError as e: + if e.response['Error']['Code'] == 'DecryptionFailureException': + # Secrets Manager can't decrypt the protected secret text using the provided KMS key. + # Deal with the exception here, and/or rethrow at your discretion. + raise e + elif e.response['Error']['Code'] == 'InternalServiceErrorException': + # An error occurred on the server side. + # Deal with the exception here, and/or rethrow at your discretion. + raise e + elif e.response['Error']['Code'] == 'InvalidParameterException': + # You provided an invalid value for a parameter. + # Deal with the exception here, and/or rethrow at your discretion. + raise e + elif e.response['Error']['Code'] == 'InvalidRequestException': + # You provided a parameter value that is not valid for the current state of the resource. + # Deal with the exception here, and/or rethrow at your discretion. + raise e + elif e.response['Error']['Code'] == 'ResourceNotFoundException': + # We can't find the resource that you asked for. + # Deal with the exception here, and/or rethrow at your discretion. + raise e + print(e) + else: + # Decrypts secret using the associated KMS key. + # Depending on whether the secret is a string or binary, one of these fields will be populated. + if 'SecretString' in get_secret_value_response: + secret = get_secret_value_response['SecretString'] + else: + secret = base64.b64decode(get_secret_value_response['SecretBinary']) + print(secret) + + +if __name__ == '__main__': + sys.exit(main()) # next section explains the use of sys.exit diff --git a/test-integration/run/run.sh b/test-integration/run/run.sh new file mode 100755 index 0000000..8e23c91 --- /dev/null +++ b/test-integration/run/run.sh @@ -0,0 +1,155 @@ +#!/bin/bash +# +# This script manages the start of integration +# tests for Python core in AWS IoT Arduino Yun +# SDK. 
The tests should be able to run both in
+# Brazil and ToD Worker environment.
+# The script will perform the following tasks:
+# 1. Retrieve credentials as needed from AWS
+# 2. Obtain ZIP package and unzip it locally
+# 3. Start the integration tests and check results
+# 4. Report any status returned.
+# To start the tests as TodWorker:
+# > run.sh MutualAuth 1000 100 7
+# or
+# > run.sh Websocket 1000 100 7
+# or
+# > run.sh ALPN 1000 100 7
+#
+# To start the tests from desktop:
+# > run.sh MutualAuthT 1000 100 7
+# or
+# > run.sh WebsocketT 1000 100 7
+# or
+# > run.sh ALPNT 1000 100 7
+#
+# 1000 MQTT messages, 100 bytes of random string
+# in length and 7 rounds of network failure for
+# progressive backoff.
+# Test mode (MutualAuth/Websocket) must be
+# specified.
+# Scale number must also be specified (see usage)
+
+# Define const
+# Fixed: the usage string was truncated ("usage: run.sh ") and told the
+# caller nothing about the four required arguments checked below ($# -ne 4).
+USAGE="usage: run.sh <MutualAuth|Websocket|ALPN> <mqtt-msg-count> <shadow-msg-count> <backoff-rounds>"
+
+UnitTestHostArn="arn:aws:secretsmanager:us-east-1:180635532705:secret:unit-test/endpoint-HSpeEu"
+GreenGrassHostArn="arn:aws:secretsmanager:us-east-1:180635532705:secret:ci/greengrassv1/endpoint-DgM00X"
+
+AWSMutualAuth_TodWorker_private_key="arn:aws:secretsmanager:us-east-1:180635532705:secret:ci/mqtt5/us/Mqtt5Prod/key-kqgyvf"
+AWSMutualAuth_TodWorker_certificate="arn:aws:secretsmanager:us-east-1:180635532705:secret:ci/mqtt5/us/Mqtt5Prod/cert-VDI1Gd"
+
+AWSGGDiscovery_TodWorker_private_key="arn:aws:secretsmanager:us-east-1:180635532705:secret:V1IotSdkIntegrationTestGGDiscoveryPrivateKey-BsLvNP"
+AWSGGDiscovery_TodWorker_certificate="arn:aws:secretsmanager:us-east-1:180635532705:secret:V1IotSdkIntegrationTestGGDiscoveryCertificate-DSwdhA"
+
+
+SDKLocation="./AWSIoTPythonSDK"
+RetrieveAWSKeys="./test-integration/Tools/retrieve-key.py"
+CREDENTIAL_DIR="./test-integration/Credentials/"
+TEST_DIR="./test-integration/IntegrationTests/"
+CA_CERT_URL="https://www.amazontrust.com/repository/AmazonRootCA1.pem"
+CA_CERT_PATH=${CREDENTIAL_DIR}rootCA.crt
+TestHost=$(python ${RetrieveAWSKeys} ${UnitTestHostArn})
+GreengrassHost=$(python ${RetrieveAWSKeys} ${GreenGrassHostArn}) + + + + +# If input args not correct, echo usage +if [ $# -ne 4 ]; then + echo ${USAGE} +else +# Description + echo "[STEP] Start run.sh" + echo "***************************************************" + echo "About to start integration tests for IoTPySDK..." + echo "Test Mode: $1" +# Determine the Python versions need to test for this SDK + pythonExecutableArray=() + pythonExecutableArray[0]="3" +# Retrieve credentials as needed from AWS + TestMode="" + echo "[STEP] Retrieve credentials from AWS" + echo "***************************************************" + if [ "$1"x == "MutualAuth"x ]; then + AWSSetName_privatekey=${AWSMutualAuth_TodWorker_private_key} + AWSSetName_certificate=${AWSMutualAuth_TodWorker_certificate} + AWSDRSName_privatekey=${AWSGGDiscovery_TodWorker_private_key} + AWSDRSName_certificate=${AWSGGDiscovery_TodWorker_certificate} + TestMode="MutualAuth" + python ${RetrieveAWSKeys} ${AWSSetName_certificate} > ${CREDENTIAL_DIR}certificate.pem.crt + python ${RetrieveAWSKeys} ${AWSSetName_privatekey} > ${CREDENTIAL_DIR}privateKey.pem.key + curl -s "${CA_CERT_URL}" > ${CA_CERT_PATH} + echo -e "URL retrieved certificate data\n" + python ${RetrieveAWSKeys} ${AWSDRSName_certificate} > ${CREDENTIAL_DIR}certificate_drs.pem.crt + python ${RetrieveAWSKeys} ${AWSDRSName_privatekey} > ${CREDENTIAL_DIR}privateKey_drs.pem.key + elif [ "$1"x == "Websocket"x ]; then + TestMode="Websocket" + curl -s "${CA_CERT_URL}" > ${CA_CERT_PATH} + echo -e "URL retrieved certificate data\n" + elif [ "$1"x == "ALPN"x ]; then + AWSSetName_privatekey=${AWSMutualAuth_TodWorker_private_key} + AWSSetName_certificate=${AWSMutualAuth_TodWorker_certificate} + AWSDRSName_privatekey=${AWSGGDiscovery_TodWorker_private_key} + AWSDRSName_certificate=${AWSGGDiscovery_TodWorker_certificate} + TestMode="ALPN" + python ${RetrieveAWSKeys} ${AWSSetName_certificate} > ${CREDENTIAL_DIR}certificate.pem.crt + python ${RetrieveAWSKeys} 
${AWSSetName_privatekey} > ${CREDENTIAL_DIR}privateKey.pem.key
+        curl -s "${CA_CERT_URL}" > ${CA_CERT_PATH}
+        echo -e "URL retrieved certificate data\n"
+        python ${RetrieveAWSKeys} ${AWSDRSName_certificate} > ${CREDENTIAL_DIR}certificate_drs.pem.crt
+        python ${RetrieveAWSKeys} ${AWSDRSName_privatekey} > ${CREDENTIAL_DIR}privateKey_drs.pem.key
+    else
+        echo "Mode not supported"
+        exit 1
+    fi
+# Obtain ZIP package and unzip it locally
+    echo ${TestMode}
+    echo "[STEP] Obtain ZIP package"
+    echo "***************************************************"
+    cp -R ${SDKLocation} ./test-integration/IntegrationTests/TestToolLibrary/SDKPackage/
+# Obtain Python executable
+
+    echo "***************************************************"
+    for file in `ls ${TEST_DIR}`
+    do
+        if [ ${file##*.}x == "py"x ]; then
+            echo "[SUB] Running test: ${file}..."
+
+            Scale=10
+            # Default every test to the standard IoT endpoint.
+            # Fixed: the original wrote `Host=TestHost`, assigning the
+            # literal string "TestHost" instead of the retrieved endpoint.
+            Host=${TestHost}
+            case "$file" in
+                "IntegrationTestMQTTConnection.py") Scale=$2
+                ;;
+                "IntegrationTestShadow.py") Scale=$3
+                ;;
+                "IntegrationTestAutoReconnectResubscribe.py") Scale=""
+                ;;
+                "IntegrationTestProgressiveBackoff.py") Scale=$4
+                ;;
+                "IntegrationTestConfigurablePublishMessageQueueing.py") Scale=""
+                ;;
+                "IntegrationTestDiscovery.py") Scale=""
+                                               Host=${GreengrassHost}
+                ;;
+                "IntegrationTestAsyncAPIGeneralNotificationCallbacks.py") Scale=""
+                ;;
+                "IntegrationTestOfflineQueueingForSubscribeUnsubscribe.py") Scale=""
+                ;;
+                "IntegrationTestClientReusability.py") Scale=""
+                ;;
+                "IntegrationTestJobsClient.py") Scale=""
+            esac
+
+            # Fixed: pass ${Host}, which honours the per-test override above;
+            # the original passed ${TestHost}, so the Discovery test never
+            # received the Greengrass endpoint and ${Host} was dead.
+            python ${TEST_DIR}${file} ${TestMode} ${Host} ${Scale}
+            currentTestStatus=$?
+            echo "[SUB] Test: ${file} completed. Exiting with status: ${currentTestStatus}"
+            if [ ${currentTestStatus} -ne 0 ]; then
+                echo "!!!!!!!!!!!!!Test: ${file} failed.!!!!!!!!!!!!!"
+ exit ${currentTestStatus} + fi + echo "" + fi + done + echo "All integration tests passed" +fi diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/__init__.py b/test/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/greengrass/__init__.py b/test/core/greengrass/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/greengrass/discovery/__init__.py b/test/core/greengrass/discovery/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/greengrass/discovery/test_discovery_info_parsing.py b/test/core/greengrass/discovery/test_discovery_info_parsing.py new file mode 100644 index 0000000..318ede3 --- /dev/null +++ b/test/core/greengrass/discovery/test_discovery_info_parsing.py @@ -0,0 +1,127 @@ +from AWSIoTPythonSDK.core.greengrass.discovery.models import DiscoveryInfo + + +DRS_INFO_JSON = "{\"GGGroups\":[{\"GGGroupId\":\"627bf63d-ae64-4f58-a18c-80a44fcf4088\"," \ + "\"Cores\":[{\"thingArn\":\"arn:aws:iot:us-east-1:003261610643:thing/DRS_GGC_0kegiNGA_0\"," \ + "\"Connectivity\":[{\"Id\":\"Id-0\",\"HostAddress\":\"192.168.101.0\",\"PortNumber\":8080," \ + "\"Metadata\":\"Description-0\"}," \ + "{\"Id\":\"Id-1\",\"HostAddress\":\"192.168.101.1\",\"PortNumber\":8081,\"Metadata\":\"Description-1\"}," \ + "{\"Id\":\"Id-2\",\"HostAddress\":\"192.168.101.2\",\"PortNumber\":8082,\"Metadata\":\"Description-2\"}]}]," \ + "\"CAs\":[\"-----BEGIN CERTIFICATE-----\\n" \ + "MIIEFTCCAv2gAwIBAgIVAPZfc4GMLZPmXbnoaZm6jRDqDs4+MA0GCSqGSIb3DQEB\\n" \ + "CwUAMIGoMQswCQYDVQQGEwJVUzEYMBYGA1UECgwPQW1hem9uLmNvbSBJbmMuMRww\\n" \ + "GgYDVQQLDBNBbWF6b24gV2ViIFNlcnZpY2VzMRMwEQYDVQQIDApXYXNoaW5ndG9u\\n" \ + "MRAwDgYDVQQHDAdTZWF0dGxlMTowOAYDVQQDDDEwMDMyNjE2MTA2NDM6NjI3YmY2\\n" \ + "M2QtYWU2NC00ZjU4LWExOGMtODBhNDRmY2Y0MDg4MCAXDTE3MDUyNTE4NDI1OVoY\\n" \ + "DzIwOTcwNTI1MTg0MjU4WjCBqDELMAkGA1UEBhMCVVMxGDAWBgNVBAoMD0FtYXpv\\n" \ + 
"bi5jb20gSW5jLjEcMBoGA1UECwwTQW1hem9uIFdlYiBTZXJ2aWNlczETMBEGA1UE\\n" \ + "CAwKV2FzaGluZ3RvbjEQMA4GA1UEBwwHU2VhdHRsZTE6MDgGA1UEAwwxMDAzMjYx\\n" \ + "NjEwNjQzOjYyN2JmNjNkLWFlNjQtNGY1OC1hMThjLTgwYTQ0ZmNmNDA4ODCCASIw\\n" \ + "DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKEWtZtKyJUg2VUwZkbkVtltrfam\\n" \ + "s9LMIdKNA3Wz4zSLhZjKHiTSkQmpZwKle5ziYs6Q5hfeT8WC0FNAVv1JhnwsuGfT\\n" \ + "sG0UO5dSn7wqXOJigKC1CaSGqeFpKB0/a3wR1L6pCGVbLZ86/sPCEPHHJDieQ+Ps\\n" \ + "RnOcUGb4CuIBnI2N+lafWNa4F4KRSVJCEeZ6u4iWVVdIEcDLKlakY45jtVvQqwnz\\n" \ + "3leFsN7PTLEkVq5u1PXSbT5DWv6p+5NoDnGAT7j7Wbr2yJw7DtpBOL6oWkAdbFAQ\\n" \ + "2097e8mIxNYE9xAzRlb5wEr6jZl/8K60v9P83OapMeuOg4JS8FGulHXbDg0CAwEA\\n" \ + "AaMyMDAwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU21ELaPCH9Oh001OS0JMv\\n" \ + "n8hU8dYwDQYJKoZIhvcNAQELBQADggEBABW66eH/+/v9Nq5jtJzflfrqAfBOzWLj\\n" \ + "UTEv6szkYzV5Crr8vnu2P5OlyA0NdiKGiAm0AgoDkf+n9HU3Hc0zm3G/QaAO2UmN\\n" \ + "9MwtIp29BSRf+gd1bX/WZTtl5I5xl290BDfr5o08I6TOf0A4P8IAkGwku5j0IQjM\\n" \ + "ns2HH5UVki155dtmWDEGX6q35KABbsmv3tO1+geJVYnd1QkHzR5IXA12gxlMw9GJ\\n" \ + "+cOw+rwJJ2ZcXo3HFoXBcsPqPOa1SO3vTl3XWQ+jX3vyDsxh/VGoJ4epsjwmJ+dW\\n" \ + "sHJoqsa3ZPDW0LcEuYgdzYWRhumGwH9fJJUx0GS4Tdg4ud+6jpuyflU=\\n" \ + "-----END CERTIFICATE-----\\n\"]}]}" + +EXPECTED_CORE_THING_ARN = "arn:aws:iot:us-east-1:003261610643:thing/DRS_GGC_0kegiNGA_0" +EXPECTED_GROUP_ID = "627bf63d-ae64-4f58-a18c-80a44fcf4088" +EXPECTED_CONNECTIVITY_INFO_ID_0 = "Id-0" +EXPECTED_CONNECTIVITY_INFO_ID_1 = "Id-1" +EXPECTED_CONNECTIVITY_INFO_ID_2 = "Id-2" +EXPECTED_CA = "-----BEGIN CERTIFICATE-----\n" \ + "MIIEFTCCAv2gAwIBAgIVAPZfc4GMLZPmXbnoaZm6jRDqDs4+MA0GCSqGSIb3DQEB\n" \ + "CwUAMIGoMQswCQYDVQQGEwJVUzEYMBYGA1UECgwPQW1hem9uLmNvbSBJbmMuMRww\n" \ + "GgYDVQQLDBNBbWF6b24gV2ViIFNlcnZpY2VzMRMwEQYDVQQIDApXYXNoaW5ndG9u\n" \ + "MRAwDgYDVQQHDAdTZWF0dGxlMTowOAYDVQQDDDEwMDMyNjE2MTA2NDM6NjI3YmY2\n" \ + "M2QtYWU2NC00ZjU4LWExOGMtODBhNDRmY2Y0MDg4MCAXDTE3MDUyNTE4NDI1OVoY\n" \ + "DzIwOTcwNTI1MTg0MjU4WjCBqDELMAkGA1UEBhMCVVMxGDAWBgNVBAoMD0FtYXpv\n" \ + 
"bi5jb20gSW5jLjEcMBoGA1UECwwTQW1hem9uIFdlYiBTZXJ2aWNlczETMBEGA1UE\n" \ + "CAwKV2FzaGluZ3RvbjEQMA4GA1UEBwwHU2VhdHRsZTE6MDgGA1UEAwwxMDAzMjYx\n" \ + "NjEwNjQzOjYyN2JmNjNkLWFlNjQtNGY1OC1hMThjLTgwYTQ0ZmNmNDA4ODCCASIw\n" \ + "DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKEWtZtKyJUg2VUwZkbkVtltrfam\n" \ + "s9LMIdKNA3Wz4zSLhZjKHiTSkQmpZwKle5ziYs6Q5hfeT8WC0FNAVv1JhnwsuGfT\n" \ + "sG0UO5dSn7wqXOJigKC1CaSGqeFpKB0/a3wR1L6pCGVbLZ86/sPCEPHHJDieQ+Ps\n" \ + "RnOcUGb4CuIBnI2N+lafWNa4F4KRSVJCEeZ6u4iWVVdIEcDLKlakY45jtVvQqwnz\n" \ + "3leFsN7PTLEkVq5u1PXSbT5DWv6p+5NoDnGAT7j7Wbr2yJw7DtpBOL6oWkAdbFAQ\n" \ + "2097e8mIxNYE9xAzRlb5wEr6jZl/8K60v9P83OapMeuOg4JS8FGulHXbDg0CAwEA\n" \ + "AaMyMDAwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU21ELaPCH9Oh001OS0JMv\n" \ + "n8hU8dYwDQYJKoZIhvcNAQELBQADggEBABW66eH/+/v9Nq5jtJzflfrqAfBOzWLj\n" \ + "UTEv6szkYzV5Crr8vnu2P5OlyA0NdiKGiAm0AgoDkf+n9HU3Hc0zm3G/QaAO2UmN\n" \ + "9MwtIp29BSRf+gd1bX/WZTtl5I5xl290BDfr5o08I6TOf0A4P8IAkGwku5j0IQjM\n" \ + "ns2HH5UVki155dtmWDEGX6q35KABbsmv3tO1+geJVYnd1QkHzR5IXA12gxlMw9GJ\n" \ + "+cOw+rwJJ2ZcXo3HFoXBcsPqPOa1SO3vTl3XWQ+jX3vyDsxh/VGoJ4epsjwmJ+dW\n" \ + "sHJoqsa3ZPDW0LcEuYgdzYWRhumGwH9fJJUx0GS4Tdg4ud+6jpuyflU=\n" \ + "-----END CERTIFICATE-----\n" + + +class TestDiscoveryInfoParsing: + + def setup_method(self, test_method): + self.discovery_info = DiscoveryInfo(DRS_INFO_JSON) + + def test_parsing_ggc_list_ca_list(self): + ggc_list = self.discovery_info.getAllCores() + ca_list = self.discovery_info.getAllCas() + + self._verify_core_connectivity_info_list(ggc_list) + self._verify_ca_list(ca_list) + + def test_parsing_group_object(self): + group_object = self.discovery_info.toObjectAtGroupLevel() + self._verify_connectivity_info(group_object + .get(EXPECTED_GROUP_ID) + .getCoreConnectivityInfo(EXPECTED_CORE_THING_ARN) + .getConnectivityInfo(EXPECTED_CONNECTIVITY_INFO_ID_0)) + self._verify_connectivity_info(group_object + .get(EXPECTED_GROUP_ID) + .getCoreConnectivityInfo(EXPECTED_CORE_THING_ARN) + 
.getConnectivityInfo(EXPECTED_CONNECTIVITY_INFO_ID_1)) + self._verify_connectivity_info(group_object + .get(EXPECTED_GROUP_ID) + .getCoreConnectivityInfo(EXPECTED_CORE_THING_ARN) + .getConnectivityInfo(EXPECTED_CONNECTIVITY_INFO_ID_2)) + + def test_parsing_group_list(self): + group_list = self.discovery_info.getAllGroups() + + assert len(group_list) == 1 + group_info = group_list[0] + assert group_info.groupId == EXPECTED_GROUP_ID + self._verify_ca_list(group_info.caList) + self._verify_core_connectivity_info_list(group_info.coreConnectivityInfoList) + + def _verify_ca_list(self, actual_ca_list): + assert len(actual_ca_list) == 1 + try: + actual_group_id, actual_ca = actual_ca_list[0] + assert actual_group_id == EXPECTED_GROUP_ID + assert actual_ca == EXPECTED_CA + except: + assert actual_ca_list[0] == EXPECTED_CA + + def _verify_core_connectivity_info_list(self, actual_core_connectivity_info_list): + assert len(actual_core_connectivity_info_list) == 1 + actual_core_connectivity_info = actual_core_connectivity_info_list[0] + assert actual_core_connectivity_info.coreThingArn == EXPECTED_CORE_THING_ARN + assert actual_core_connectivity_info.groupId == EXPECTED_GROUP_ID + self._verify_connectivity_info_list(actual_core_connectivity_info.connectivityInfoList) + + def _verify_connectivity_info_list(self, actual_connectivity_info_list): + for actual_connectivity_info in actual_connectivity_info_list: + self._verify_connectivity_info(actual_connectivity_info) + + def _verify_connectivity_info(self, actual_connectivity_info): + info_id = actual_connectivity_info.id + sequence_number_string = info_id[-1:] + assert actual_connectivity_info.host == "192.168.101." 
+ sequence_number_string + assert actual_connectivity_info.port == int("808" + sequence_number_string) + assert actual_connectivity_info.metadata == "Description-" + sequence_number_string diff --git a/test/core/greengrass/discovery/test_discovery_info_provider.py b/test/core/greengrass/discovery/test_discovery_info_provider.py new file mode 100644 index 0000000..2f11d20 --- /dev/null +++ b/test/core/greengrass/discovery/test_discovery_info_provider.py @@ -0,0 +1,169 @@ +from AWSIoTPythonSDK.core.greengrass.discovery.providers import DiscoveryInfoProvider +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryInvalidRequestException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryUnauthorizedException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryDataNotFoundException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryThrottlingException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryTimeoutException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryFailure +import pytest +try: + from mock import patch + from mock import MagicMock +except: + from unittest.mock import patch + from unittest.mock import MagicMock + + +DUMMY_CA_PATH = "dummy/ca/path" +DUMMY_CERT_PATH = "dummy/cert/path" +DUMMY_KEY_PATH = "dummy/key/path" +DUMMY_HOST = "dummy.host.amazonaws.com" +DUMMY_PORT = "8443" +DUMMY_TIME_OUT_SEC = 3 +DUMMY_GGAD_THING_NAME = "CoolGGAD" +FORMAT_REQUEST = "GET /greengrass/discover/thing/%s HTTP/1.1\r\nHost: " + DUMMY_HOST + ":" + DUMMY_PORT + "\r\n\r\n" +FORMAT_RESPONSE_HEADER = "HTTP/1.1 %s %s\r\n" \ + "content-type: application/json\r\n" \ + "content-length: %d\r\n" \ + "date: Wed, 05 Jul 2017 22:17:19 GMT\r\n" \ + "x-amzn-RequestId: 97408dd9-06a0-73bb-8e00-c4fc6845d555\r\n" \ + "connection: Keep-Alive\r\n\r\n" + +SERVICE_ERROR_MESSAGE_FORMAT = "{\"errorMessage\":\"%s\"}" +SERVICE_ERROR_MESSAGE_400 = SERVICE_ERROR_MESSAGE_FORMAT % "Invalid input detected for this request" 
+SERVICE_ERROR_MESSAGE_401 = SERVICE_ERROR_MESSAGE_FORMAT % "Unauthorized request" +SERVICE_ERROR_MESSAGE_404 = SERVICE_ERROR_MESSAGE_FORMAT % "Resource not found" +SERVICE_ERROR_MESSAGE_429 = SERVICE_ERROR_MESSAGE_FORMAT % "Too many requests" +SERVICE_ERROR_MESSAGE_500 = SERVICE_ERROR_MESSAGE_FORMAT % "Internal server error" +PAYLOAD_200 = "{\"GGGroups\":[{\"GGGroupId\":\"627bf63d-ae64-4f58-a18c-80a44fcf4088\"," \ + "\"Cores\":[{\"thingArn\":\"arn:aws:iot:us-east-1:003261610643:thing/DRS_GGC_0kegiNGA_0\"," \ + "\"Connectivity\":[{\"Id\":\"Id-0\",\"HostAddress\":\"192.168.101.0\",\"PortNumber\":8080," \ + "\"Metadata\":\"Description-0\"}," \ + "{\"Id\":\"Id-1\",\"HostAddress\":\"192.168.101.1\",\"PortNumber\":8081,\"Metadata\":\"Description-1\"}," \ + "{\"Id\":\"Id-2\",\"HostAddress\":\"192.168.101.2\",\"PortNumber\":8082,\"Metadata\":\"Description-2\"}]}]," \ + "\"CAs\":[\"-----BEGIN CERTIFICATE-----\\n" \ + "MIIEFTCCAv2gAwIBAgIVAPZfc4GMLZPmXbnoaZm6jRDqDs4+MA0GCSqGSIb3DQEB\\n" \ + "CwUAMIGoMQswCQYDVQQGEwJVUzEYMBYGA1UECgwPQW1hem9uLmNvbSBJbmMuMRww\\n" \ + "GgYDVQQLDBNBbWF6b24gV2ViIFNlcnZpY2VzMRMwEQYDVQQIDApXYXNoaW5ndG9u\\n" \ + "MRAwDgYDVQQHDAdTZWF0dGxlMTowOAYDVQQDDDEwMDMyNjE2MTA2NDM6NjI3YmY2\\n" \ + "M2QtYWU2NC00ZjU4LWExOGMtODBhNDRmY2Y0MDg4MCAXDTE3MDUyNTE4NDI1OVoY\\n" \ + "DzIwOTcwNTI1MTg0MjU4WjCBqDELMAkGA1UEBhMCVVMxGDAWBgNVBAoMD0FtYXpv\\n" \ + "bi5jb20gSW5jLjEcMBoGA1UECwwTQW1hem9uIFdlYiBTZXJ2aWNlczETMBEGA1UE\\n" \ + "CAwKV2FzaGluZ3RvbjEQMA4GA1UEBwwHU2VhdHRsZTE6MDgGA1UEAwwxMDAzMjYx\\n" \ + "NjEwNjQzOjYyN2JmNjNkLWFlNjQtNGY1OC1hMThjLTgwYTQ0ZmNmNDA4ODCCASIw\\n" \ + "DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKEWtZtKyJUg2VUwZkbkVtltrfam\\n" \ + "s9LMIdKNA3Wz4zSLhZjKHiTSkQmpZwKle5ziYs6Q5hfeT8WC0FNAVv1JhnwsuGfT\\n" \ + "sG0UO5dSn7wqXOJigKC1CaSGqeFpKB0/a3wR1L6pCGVbLZ86/sPCEPHHJDieQ+Ps\\n" \ + "RnOcUGb4CuIBnI2N+lafWNa4F4KRSVJCEeZ6u4iWVVdIEcDLKlakY45jtVvQqwnz\\n" \ + "3leFsN7PTLEkVq5u1PXSbT5DWv6p+5NoDnGAT7j7Wbr2yJw7DtpBOL6oWkAdbFAQ\\n" \ + 
"2097e8mIxNYE9xAzRlb5wEr6jZl/8K60v9P83OapMeuOg4JS8FGulHXbDg0CAwEA\\n" \ + "AaMyMDAwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU21ELaPCH9Oh001OS0JMv\\n" \ + "n8hU8dYwDQYJKoZIhvcNAQELBQADggEBABW66eH/+/v9Nq5jtJzflfrqAfBOzWLj\\n" \ + "UTEv6szkYzV5Crr8vnu2P5OlyA0NdiKGiAm0AgoDkf+n9HU3Hc0zm3G/QaAO2UmN\\n" \ + "9MwtIp29BSRf+gd1bX/WZTtl5I5xl290BDfr5o08I6TOf0A4P8IAkGwku5j0IQjM\\n" \ + "ns2HH5UVki155dtmWDEGX6q35KABbsmv3tO1+geJVYnd1QkHzR5IXA12gxlMw9GJ\\n" \ + "+cOw+rwJJ2ZcXo3HFoXBcsPqPOa1SO3vTl3XWQ+jX3vyDsxh/VGoJ4epsjwmJ+dW\\n" \ + "sHJoqsa3ZPDW0LcEuYgdzYWRhumGwH9fJJUx0GS4Tdg4ud+6jpuyflU=\\n" \ + "-----END CERTIFICATE-----\\n\"]}]}" + + +class TestDiscoveryInfoProvider: + + def setup_class(cls): + cls.service_error_message_dict = { + "400" : SERVICE_ERROR_MESSAGE_400, + "401" : SERVICE_ERROR_MESSAGE_401, + "404" : SERVICE_ERROR_MESSAGE_404, + "429" : SERVICE_ERROR_MESSAGE_429 + } + cls.client_exception_dict = { + "400" : DiscoveryInvalidRequestException, + "401" : DiscoveryUnauthorizedException, + "404" : DiscoveryDataNotFoundException, + "429" : DiscoveryThrottlingException + } + + def setup_method(self, test_method): + self.mock_sock = MagicMock() + self.mock_ssl_sock = MagicMock() + + def test_200_drs_response_should_succeed(self): + with patch.object(DiscoveryInfoProvider, "_create_tcp_connection") as mock_method_create_tcp_connection, \ + patch.object(DiscoveryInfoProvider, "_create_ssl_connection") as mock_method_create_ssl_connection: + mock_method_create_tcp_connection.return_value = self.mock_sock + mock_method_create_ssl_connection.return_value = self.mock_ssl_sock + raw_outbound_request = FORMAT_REQUEST % DUMMY_GGAD_THING_NAME + self._create_test_target() + self.mock_ssl_sock.write.return_value = len(raw_outbound_request) + self.mock_ssl_sock.read.side_effect = \ + list((FORMAT_RESPONSE_HEADER % ("200", "OK", len(PAYLOAD_200)) + PAYLOAD_200).encode("utf-8")) + + discovery_info = self.discovery_info_provider.discover(DUMMY_GGAD_THING_NAME) + 
self.mock_ssl_sock.write.assert_called_with(raw_outbound_request.encode("utf-8")) + assert discovery_info.rawJson == PAYLOAD_200 + + def test_400_drs_response_should_raise(self): + self._internal_test_non_200_drs_response_should_raise("400", "Bad request") + + def test_401_drs_response_should_raise(self): + self._internal_test_non_200_drs_response_should_raise("401", "Unauthorized") + + def test_404_drs_response_should_raise(self): + self._internal_test_non_200_drs_response_should_raise("404", "Not found") + + def test_429_drs_response_should_raise(self): + self._internal_test_non_200_drs_response_should_raise("429", "Throttled") + + def test_unexpected_drs_response_should_raise(self): + self._internal_test_non_200_drs_response_should_raise("500", "Internal server error") + self._internal_test_non_200_drs_response_should_raise("1234", "Gibberish") + + def _internal_test_non_200_drs_response_should_raise(self, http_status_code, http_status_message): + with patch.object(DiscoveryInfoProvider, "_create_tcp_connection") as mock_method_create_tcp_connection, \ + patch.object(DiscoveryInfoProvider, "_create_ssl_connection") as mock_method_create_ssl_connection: + mock_method_create_tcp_connection.return_value = self.mock_sock + mock_method_create_ssl_connection.return_value = self.mock_ssl_sock + self._create_test_target() + service_error_message = self.service_error_message_dict.get(http_status_code) + if service_error_message is None: + service_error_message = SERVICE_ERROR_MESSAGE_500 + client_exception_type = self.client_exception_dict.get(http_status_code) + if client_exception_type is None: + client_exception_type = DiscoveryFailure + self.mock_ssl_sock.write.return_value = len(FORMAT_REQUEST % DUMMY_GGAD_THING_NAME) + self.mock_ssl_sock.read.side_effect = \ + list((FORMAT_RESPONSE_HEADER % (http_status_code, http_status_message, len(service_error_message)) + + service_error_message).encode("utf-8")) + + with pytest.raises(client_exception_type): + 
self.discovery_info_provider.discover(DUMMY_GGAD_THING_NAME) + + def test_request_time_out_should_raise(self): + with patch.object(DiscoveryInfoProvider, "_create_tcp_connection") as mock_method_create_tcp_connection, \ + patch.object(DiscoveryInfoProvider, "_create_ssl_connection") as mock_method_create_ssl_connection: + mock_method_create_tcp_connection.return_value = self.mock_sock + mock_method_create_ssl_connection.return_value = self.mock_ssl_sock + self._create_test_target() + + # We do not configure any return value and simply let request part time out + with pytest.raises(DiscoveryTimeoutException): + self.discovery_info_provider.discover(DUMMY_GGAD_THING_NAME) + + def test_response_time_out_should_raise(self): + with patch.object(DiscoveryInfoProvider, "_create_tcp_connection") as mock_method_create_tcp_connection, \ + patch.object(DiscoveryInfoProvider, "_create_ssl_connection") as mock_method_create_ssl_connection: + mock_method_create_tcp_connection.return_value = self.mock_sock + mock_method_create_ssl_connection.return_value = self.mock_ssl_sock + self._create_test_target() + + # We configure the request to succeed and let the response part time out + self.mock_ssl_sock.write.return_value = len(FORMAT_REQUEST % DUMMY_GGAD_THING_NAME) + with pytest.raises(DiscoveryTimeoutException): + self.discovery_info_provider.discover(DUMMY_GGAD_THING_NAME) + + def _create_test_target(self): + self.discovery_info_provider = DiscoveryInfoProvider(caPath=DUMMY_CA_PATH, + certPath=DUMMY_CERT_PATH, + keyPath=DUMMY_KEY_PATH, + host=DUMMY_HOST, + timeoutSec=DUMMY_TIME_OUT_SEC) diff --git a/test/core/jobs/test_jobs_client.py b/test/core/jobs/test_jobs_client.py new file mode 100644 index 0000000..c36fc55 --- /dev/null +++ b/test/core/jobs/test_jobs_client.py @@ -0,0 +1,169 @@ +# Test AWSIoTMQTTThingJobsClient behavior + +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTThingJobsClient +from 
AWSIoTPythonSDK.core.jobs.thingJobManager import thingJobManager +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicReplyType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionStatus +import AWSIoTPythonSDK.MQTTLib +import time +import json +from mock import MagicMock + +#asserts based on this documentation: https://docs.aws.amazon.com/iot/latest/developerguide/jobs-api.html +class TestAWSIoTMQTTThingJobsClient: + thingName = 'testThing' + clientTokenValue = 'testClientToken123' + statusDetailsMap = {'testKey':'testVal'} + + def setup_method(self, method): + self.mockAWSIoTMQTTClient = MagicMock(spec=AWSIoTMQTTClient) + self.jobsClient = AWSIoTMQTTThingJobsClient(self.clientTokenValue, self.thingName, QoS=0, awsIoTMQTTClient=self.mockAWSIoTMQTTClient) + self.jobsClient._thingJobManager = MagicMock(spec=thingJobManager) + + def test_unsuccessful_create_subscription(self): + fake_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'UnsuccessfulCreateSubTopic' + self.mockAWSIoTMQTTClient.subscribe.return_value = False + assert False == self.jobsClient.createJobSubscription(fake_callback) + + def test_successful_job_request_create_subscription(self): + fake_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SuccessfulCreateSubRequestTopic' + self.mockAWSIoTMQTTClient.subscribe.return_value = True + assert True == self.jobsClient.createJobSubscription(fake_callback) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_WILDCARD_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, None) + self.mockAWSIoTMQTTClient.subscribe.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_callback) + + def test_successful_job_start_next_create_subscription(self): + fake_callback = MagicMock(); + 
self.jobsClient._thingJobManager.getJobTopic.return_value = 'SuccessfulCreateSubStartNextTopic' + self.mockAWSIoTMQTTClient.subscribe.return_value = True + assert True == self.jobsClient.createJobSubscription(fake_callback, jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, None) + self.mockAWSIoTMQTTClient.subscribe.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_callback) + + def test_successful_job_update_create_subscription(self): + fake_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SuccessfulCreateSubUpdateTopic' + self.mockAWSIoTMQTTClient.subscribe.return_value = True + assert True == self.jobsClient.createJobSubscription(fake_callback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, 'jobUpdateId') + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, 'jobUpdateId') + self.mockAWSIoTMQTTClient.subscribe.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_callback) + + def test_successful_job_update_notify_next_create_subscription(self): + fake_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SuccessfulCreateSubNotifyNextTopic' + self.mockAWSIoTMQTTClient.subscribe.return_value = True + assert True == self.jobsClient.createJobSubscription(fake_callback, jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, None) + 
self.mockAWSIoTMQTTClient.subscribe.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_callback) + + def test_successful_job_request_create_subscription_async(self): + fake_callback = MagicMock(); + fake_ack_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'CreateSubTopic1' + self.mockAWSIoTMQTTClient.subscribeAsync.return_value = 'MessageId1' + assert self.mockAWSIoTMQTTClient.subscribeAsync.return_value == self.jobsClient.createJobSubscriptionAsync(fake_ack_callback, fake_callback) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_WILDCARD_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, None) + self.mockAWSIoTMQTTClient.subscribeAsync.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_ack_callback, fake_callback) + + def test_successful_job_start_next_create_subscription_async(self): + fake_callback = MagicMock(); + fake_ack_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'CreateSubTopic3' + self.mockAWSIoTMQTTClient.subscribeAsync.return_value = 'MessageId3' + assert self.mockAWSIoTMQTTClient.subscribeAsync.return_value == self.jobsClient.createJobSubscriptionAsync(fake_ack_callback, fake_callback, jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, None) + self.mockAWSIoTMQTTClient.subscribeAsync.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_ack_callback, fake_callback) + + def test_successful_job_update_create_subscription_async(self): + fake_callback = MagicMock(); + fake_ack_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'CreateSubTopic4' + 
self.mockAWSIoTMQTTClient.subscribeAsync.return_value = 'MessageId4' + assert self.mockAWSIoTMQTTClient.subscribeAsync.return_value == self.jobsClient.createJobSubscriptionAsync(fake_ack_callback, fake_callback, jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, 'jobUpdateId3') + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, 'jobUpdateId3') + self.mockAWSIoTMQTTClient.subscribeAsync.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_ack_callback, fake_callback) + + def test_successful_job_notify_next_subscription_async(self): + fake_callback = MagicMock(); + fake_ack_callback = MagicMock(); + self.jobsClient._thingJobManager.getJobTopic.return_value = 'CreateSubTopic5' + self.mockAWSIoTMQTTClient.subscribeAsync.return_value = 'MessageId5' + assert self.mockAWSIoTMQTTClient.subscribeAsync.return_value == self.jobsClient.createJobSubscriptionAsync(fake_ack_callback, fake_callback, jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, None) + self.mockAWSIoTMQTTClient.subscribeAsync.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 0, fake_ack_callback, fake_callback) + + def test_send_jobs_query_get_pending(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendJobsQuery1' + self.jobsClient._thingJobManager.serializeClientTokenPayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = True + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsQuery(jobExecutionTopicType.JOB_GET_PENDING_TOPIC) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_GET_PENDING_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, 
None) + self.jobsClient._thingJobManager.serializeClientTokenPayload.assert_called_with() + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeClientTokenPayload.return_value, 0) + + def test_send_jobs_query_describe(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendJobsQuery2' + self.jobsClient._thingJobManager.serializeClientTokenPayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = True + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsQuery(jobExecutionTopicType.JOB_DESCRIBE_TOPIC, 'jobId2') + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, 'jobId2') + self.jobsClient._thingJobManager.serializeClientTokenPayload.assert_called_with() + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeClientTokenPayload.return_value, 0) + + def test_send_jobs_start_next(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendStartNext1' + self.jobsClient._thingJobManager.serializeStartNextPendingJobExecutionPayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = True + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsStartNext(self.statusDetailsMap, 12) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + self.jobsClient._thingJobManager.serializeStartNextPendingJobExecutionPayload.assert_called_with(self.statusDetailsMap, 12) + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, 
self.jobsClient._thingJobManager.serializeStartNextPendingJobExecutionPayload.return_value, 0) + + def test_send_jobs_start_next_no_status_details(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendStartNext2' + self.jobsClient._thingJobManager.serializeStartNextPendingJobExecutionPayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = False + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsStartNext({}) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_START_NEXT_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + self.jobsClient._thingJobManager.serializeStartNextPendingJobExecutionPayload.assert_called_with({}, None) + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeStartNextPendingJobExecutionPayload.return_value, 0) + + def test_send_jobs_update_succeeded(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendJobsUpdate1' + self.jobsClient._thingJobManager.serializeJobExecutionUpdatePayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = True + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsUpdate('jobId1', jobExecutionStatus.JOB_EXECUTION_SUCCEEDED, self.statusDetailsMap, 1, 2, True, False, 12) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, 'jobId1') + self.jobsClient._thingJobManager.serializeJobExecutionUpdatePayload.assert_called_with(jobExecutionStatus.JOB_EXECUTION_SUCCEEDED, self.statusDetailsMap, 1, 2, True, False, 12) + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeJobExecutionUpdatePayload.return_value, 0) + + def 
test_send_jobs_update_failed(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendJobsUpdate2' + self.jobsClient._thingJobManager.serializeJobExecutionUpdatePayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = False + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsUpdate('jobId2', jobExecutionStatus.JOB_EXECUTION_FAILED, {}, 3, 4, False, True, 34) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_UPDATE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, 'jobId2') + self.jobsClient._thingJobManager.serializeJobExecutionUpdatePayload.assert_called_with(jobExecutionStatus.JOB_EXECUTION_FAILED, {}, 3, 4, False, True, 34) + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeJobExecutionUpdatePayload.return_value, 0) + + def test_send_jobs_describe(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendJobsDescribe1' + self.jobsClient._thingJobManager.serializeDescribeJobExecutionPayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = True + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsDescribe('jobId1', 2, True) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, 'jobId1') + self.jobsClient._thingJobManager.serializeDescribeJobExecutionPayload.assert_called_with(2, True) + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeDescribeJobExecutionPayload.return_value, 0) + + def test_send_jobs_describe_false_return_val(self): + self.jobsClient._thingJobManager.getJobTopic.return_value = 'SendJobsDescribe2' + 
self.jobsClient._thingJobManager.serializeDescribeJobExecutionPayload.return_value = {} + self.mockAWSIoTMQTTClient.publish.return_value = False + assert self.mockAWSIoTMQTTClient.publish.return_value == self.jobsClient.sendJobsDescribe('jobId2', 1, False) + self.jobsClient._thingJobManager.getJobTopic.assert_called_with(jobExecutionTopicType.JOB_DESCRIBE_TOPIC, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, 'jobId2') + self.jobsClient._thingJobManager.serializeDescribeJobExecutionPayload.assert_called_with(1, False) + self.mockAWSIoTMQTTClient.publish.assert_called_with(self.jobsClient._thingJobManager.getJobTopic.return_value, self.jobsClient._thingJobManager.serializeDescribeJobExecutionPayload.return_value, 0) diff --git a/test/core/jobs/test_thing_job_manager.py b/test/core/jobs/test_thing_job_manager.py new file mode 100644 index 0000000..c3fa7b1 --- /dev/null +++ b/test/core/jobs/test_thing_job_manager.py @@ -0,0 +1,191 @@ +# Test thingJobManager behavior + +from AWSIoTPythonSDK.core.jobs.thingJobManager import thingJobManager as JobManager +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionTopicReplyType +from AWSIoTPythonSDK.core.jobs.thingJobManager import jobExecutionStatus +import time +import json +from mock import MagicMock + +#asserts based on this documentation: https://docs.aws.amazon.com/iot/latest/developerguide/jobs-api.html +class TestThingJobManager: + thingName = 'testThing' + clientTokenValue = "testClientToken123" + thingJobManager = JobManager(thingName, clientTokenValue) + noClientTokenJobManager = JobManager(thingName) + jobId = '8192' + statusDetailsMap = {'testKey':'testVal'} + + def test_pending_topics(self): + topicType = jobExecutionTopicType.JOB_GET_PENDING_TOPIC + assert ('$aws/things/' + self.thingName + '/jobs/get') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert ('$aws/things/' + 
self.thingName + '/jobs/get/accepted') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert ('$aws/things/' + self.thingName + '/jobs/get/rejected') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert ('$aws/things/' + self.thingName + '/jobs/get/#') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + + def test_start_next_topics(self): + topicType = jobExecutionTopicType.JOB_START_NEXT_TOPIC + assert ('$aws/things/' + self.thingName + '/jobs/start-next') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert ('$aws/things/' + self.thingName + '/jobs/start-next/accepted') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert ('$aws/things/' + self.thingName + '/jobs/start-next/rejected') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert ('$aws/things/' + self.thingName + '/jobs/start-next/#') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert None == 
self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + + def test_describe_topics(self): + topicType = jobExecutionTopicType.JOB_DESCRIBE_TOPIC + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/get') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, self.jobId) + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/get/accepted') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/get/rejected') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/get/#') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + + def test_update_topics(self): + topicType = jobExecutionTopicType.JOB_UPDATE_TOPIC + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/update') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, self.jobId) + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/update/accepted') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/update/rejected') == 
self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert ('$aws/things/' + self.thingName + '/jobs/' + str(self.jobId) + '/update/#') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + + def test_notify_topics(self): + topicType = jobExecutionTopicType.JOB_NOTIFY_TOPIC + assert ('$aws/things/' + self.thingName + '/jobs/notify') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + + def test_notify_next_topics(self): + topicType = jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC + assert ('$aws/things/' + self.thingName + '/jobs/notify-next') == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert None == 
self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + + def test_wildcard_topics(self): + topicType = jobExecutionTopicType.JOB_WILDCARD_TOPIC + topicString = '$aws/things/' + self.thingName + '/jobs/#' + assert topicString == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE) + assert topicString == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert topicString == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert topicString == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + + def test_thingless_topics(self): + thinglessJobManager = JobManager(None) + assert None == thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_GET_PENDING_TOPIC) + assert None == thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_START_NEXT_TOPIC) + assert None == thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_DESCRIBE_TOPIC) + assert None == thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_UPDATE_TOPIC) + assert None == thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_NOTIFY_TOPIC) + assert None == 
thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) + assert None == thinglessJobManager.getJobTopic(jobExecutionTopicType.JOB_WILDCARD_TOPIC) + + def test_unrecognized_topics(self): + topicType = jobExecutionTopicType.JOB_UNRECOGNIZED_TOPIC + assert None == self.thingJobManager.getJobTopic(topicType) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REQUEST_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_ACCEPTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_REJECTED_REPLY_TYPE, self.jobId) + assert None == self.thingJobManager.getJobTopic(topicType, jobExecutionTopicReplyType.JOB_WILDCARD_REPLY_TYPE, self.jobId) + + def test_serialize_client_token(self): + payload = '{"clientToken": "' + self.clientTokenValue + '"}' + assert payload == self.thingJobManager.serializeClientTokenPayload() + assert "{}" == self.noClientTokenJobManager.serializeClientTokenPayload() + + def test_serialize_start_next_pending_job_execution(self): + payload = {'clientToken': self.clientTokenValue} + assert payload == json.loads(self.thingJobManager.serializeStartNextPendingJobExecutionPayload()) + assert {} == json.loads(self.noClientTokenJobManager.serializeStartNextPendingJobExecutionPayload()) + payload.update({'statusDetails': self.statusDetailsMap}) + assert payload == json.loads(self.thingJobManager.serializeStartNextPendingJobExecutionPayload(self.statusDetailsMap)) + assert {'statusDetails': self.statusDetailsMap} == 
json.loads(self.noClientTokenJobManager.serializeStartNextPendingJobExecutionPayload(self.statusDetailsMap)) + + def test_serialize_describe_job_execution(self): + payload = {'includeJobDocument': True} + assert payload == json.loads(self.noClientTokenJobManager.serializeDescribeJobExecutionPayload()) + payload.update({'executionNumber': 1}) + assert payload == json.loads(self.noClientTokenJobManager.serializeDescribeJobExecutionPayload(1)) + payload.update({'includeJobDocument': False}) + assert payload == json.loads(self.noClientTokenJobManager.serializeDescribeJobExecutionPayload(1, False)) + + payload = {'includeJobDocument': True, 'clientToken': self.clientTokenValue} + assert payload == json.loads(self.thingJobManager.serializeDescribeJobExecutionPayload()) + payload.update({'executionNumber': 1}) + assert payload == json.loads(self.thingJobManager.serializeDescribeJobExecutionPayload(1)) + payload.update({'includeJobDocument': False}) + assert payload == json.loads(self.thingJobManager.serializeDescribeJobExecutionPayload(1, False)) + + def test_serialize_job_execution_update(self): + assert None == self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_STATUS_NOT_SET) + assert None == self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_UNKNOWN_STATUS) + assert None == self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_STATUS_NOT_SET) + assert None == self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_UNKNOWN_STATUS) + + payload = {'status':'IN_PROGRESS'} + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_IN_PROGRESS)) + payload.update({'status':'FAILED'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_FAILED)) + 
payload.update({'status':'SUCCEEDED'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_SUCCEEDED)) + payload.update({'status':'CANCELED'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_CANCELED)) + payload.update({'status':'REJECTED'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_REJECTED)) + payload.update({'status':'QUEUED'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED)) + payload.update({'statusDetails': self.statusDetailsMap}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap)) + payload.update({'expectedVersion': '1'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1)) + payload.update({'executionNumber': '1'}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1, 1)) + payload.update({'includeJobExecutionState': True}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1, 1, True)) + payload.update({'includeJobDocument': True}) + assert payload == json.loads(self.noClientTokenJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1, 1, True, True)) + + payload = {'status':'IN_PROGRESS', 'clientToken': self.clientTokenValue} + assert payload == 
json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_IN_PROGRESS)) + payload.update({'status':'FAILED'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_FAILED)) + payload.update({'status':'SUCCEEDED'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_SUCCEEDED)) + payload.update({'status':'CANCELED'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_CANCELED)) + payload.update({'status':'REJECTED'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_REJECTED)) + payload.update({'status':'QUEUED'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED)) + payload.update({'statusDetails': self.statusDetailsMap}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap)) + payload.update({'expectedVersion': '1'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1)) + payload.update({'executionNumber': '1'}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1, 1)) + payload.update({'includeJobExecutionState': True}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1, 1, True)) + payload.update({'includeJobDocument': True}) + assert payload == json.loads(self.thingJobManager.serializeJobExecutionUpdatePayload(jobExecutionStatus.JOB_EXECUTION_QUEUED, self.statusDetailsMap, 1, 1, True, 
True)) diff --git a/test/core/protocol/__init__.py b/test/core/protocol/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/protocol/connection/__init__.py b/test/core/protocol/connection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/protocol/connection/test_alpn.py b/test/core/protocol/connection/test_alpn.py new file mode 100644 index 0000000..e9d2a2b --- /dev/null +++ b/test/core/protocol/connection/test_alpn.py @@ -0,0 +1,123 @@ +import AWSIoTPythonSDK.core.protocol.connection.alpn as alpn +from AWSIoTPythonSDK.core.protocol.connection.alpn import SSLContextBuilder +import sys +import pytest +try: + from mock import patch + from mock import MagicMock +except: + from unittest.mock import patch + from unittest.mock import MagicMock +if sys.version_info >= (3, 4): + from importlib import reload + + +python3_5_above_only = pytest.mark.skipif(sys.version_info >= (3, 0) and sys.version_info < (3, 5), reason="Requires Python 3.5+") +python2_7_10_above_only = pytest.mark.skipif(sys.version_info < (2, 7, 10), reason="Requires Python 2.7.10+") + +PATCH_MODULE_LOCATION = "AWSIoTPythonSDK.core.protocol.connection.alpn" +SSL_MODULE_NAME = "ssl" +SSL_CONTEXT_METHOD_NAME = "create_default_context" + +DUMMY_SSL_PROTOCOL = "DummySSLProtocol" +DUMMY_CERT_REQ = "DummyCertReq" +DUMMY_CIPHERS = "DummyCiphers" +DUMMY_CA_FILE_PATH = "fake/path/to/ca" +DUMMY_CERT_FILE_PATH = "fake/path/to/cert" +DUMMY_KEY_FILE_PATH = "fake/path/to/key" +DUMMY_ALPN_PROTOCOLS = "x-amzn-mqtt-ca" + + +@python2_7_10_above_only +@python3_5_above_only +class TestALPNSSLContextBuilder: + + def test_check_supportability_no_ssl(self): + self._preserve_ssl() + try: + self._none_ssl() + with pytest.raises(RuntimeError): + alpn.SSLContextBuilder().build() + finally: + self._unnone_ssl() + + def _none_ssl(self): + # We always run the unit test with Python versions that have proper ssl support + # We need to mock it out in this test + 
sys.modules[SSL_MODULE_NAME] = None + reload(alpn) + + def _unnone_ssl(self): + sys.modules[SSL_MODULE_NAME] = self._normal_ssl_module + reload(alpn) + + def test_check_supportability_no_ssl_context(self): + self._preserve_ssl() + try: + self._mock_ssl() + del self.ssl_mock.SSLContext + with pytest.raises(NotImplementedError): + SSLContextBuilder() + finally: + self._unmock_ssl() + + def test_check_supportability_no_alpn(self): + self._preserve_ssl() + try: + self._mock_ssl() + del self.ssl_mock.SSLContext.set_alpn_protocols + with pytest.raises(NotImplementedError): + SSLContextBuilder() + finally: + self._unmock_ssl() + + def _preserve_ssl(self): + self._normal_ssl_module = sys.modules[SSL_MODULE_NAME] + + def _mock_ssl(self): + self.ssl_mock = MagicMock() + alpn.ssl = self.ssl_mock + + def _unmock_ssl(self): + alpn.ssl = self._normal_ssl_module + + def test_with_ca_certs(self): + self._use_mock_ssl_context() + SSLContextBuilder().with_ca_certs(DUMMY_CA_FILE_PATH).build() + self.mock_ssl_context.load_verify_locations.assert_called_once_with(DUMMY_CA_FILE_PATH) + + def test_with_cert_key_pair(self): + self._use_mock_ssl_context() + SSLContextBuilder().with_cert_key_pair(DUMMY_CERT_FILE_PATH, DUMMY_KEY_FILE_PATH).build() + self.mock_ssl_context.load_cert_chain.assert_called_once_with(DUMMY_CERT_FILE_PATH, DUMMY_KEY_FILE_PATH) + + def test_with_cert_reqs(self): + self._use_mock_ssl_context() + SSLContextBuilder().with_cert_reqs(DUMMY_CERT_REQ).build() + assert self.mock_ssl_context.verify_mode == DUMMY_CERT_REQ + + def test_with_check_hostname(self): + self._use_mock_ssl_context() + SSLContextBuilder().with_check_hostname(True).build() + assert self.mock_ssl_context.check_hostname == True + + def test_with_ciphers(self): + self._use_mock_ssl_context() + SSLContextBuilder().with_ciphers(DUMMY_CIPHERS).build() + self.mock_ssl_context.set_ciphers.assert_called_once_with(DUMMY_CIPHERS) + + def test_with_none_ciphers(self): + self._use_mock_ssl_context() + 
SSLContextBuilder().with_ciphers(None).build() + assert not self.mock_ssl_context.set_ciphers.called + + def test_with_alpn_protocols(self): + self._use_mock_ssl_context() + SSLContextBuilder().with_alpn_protocols(DUMMY_ALPN_PROTOCOLS) + self.mock_ssl_context.set_alpn_protocols.assert_called_once_with(DUMMY_ALPN_PROTOCOLS) + + def _use_mock_ssl_context(self): + self.mock_ssl_context = MagicMock() + self.ssl_create_default_context_patcher = patch("%s.%s.%s" % (PATCH_MODULE_LOCATION, SSL_MODULE_NAME, SSL_CONTEXT_METHOD_NAME)) + self.mock_ssl_create_default_context = self.ssl_create_default_context_patcher.start() + self.mock_ssl_create_default_context.return_value = self.mock_ssl_context diff --git a/test/core/protocol/connection/test_progressive_back_off_core.py b/test/core/protocol/connection/test_progressive_back_off_core.py new file mode 100755 index 0000000..1cfcd40 --- /dev/null +++ b/test/core/protocol/connection/test_progressive_back_off_core.py @@ -0,0 +1,74 @@ +import time +import AWSIoTPythonSDK.core.protocol.connection.cores as backoff +import pytest + + +class TestProgressiveBackOffCore(): + def setup_method(self, method): + self._dummyBackOffCore = backoff.ProgressiveBackOffCore() + + def teardown_method(self, method): + self._dummyBackOffCore = None + + # Check that current backoff time is one seconds when this is the first time to backoff + def test_BackoffForTheFirstTime(self): + assert self._dummyBackOffCore._currentBackoffTimeSecond == 1 + + # Check that valid input values for backoff configuration is properly configued + def test_CustomConfig_ValidInput(self): + self._dummyBackOffCore.configTime(2, 128, 30) + assert self._dummyBackOffCore._baseReconnectTimeSecond == 2 + assert self._dummyBackOffCore._maximumReconnectTimeSecond == 128 + assert self._dummyBackOffCore._minimumConnectTimeSecond == 30 + + # Check the negative input values will trigger exception + def test_CustomConfig_NegativeInput(self): + with pytest.raises(ValueError) as e: + # 
_baseReconnectTimeSecond should be greater than zero, otherwise raise exception + self._dummyBackOffCore.configTime(-10, 128, 30) + with pytest.raises(ValueError) as e: + # _maximumReconnectTimeSecond should be greater than zero, otherwise raise exception + self._dummyBackOffCore.configTime(2, -11, 30) + with pytest.raises(ValueError) as e: + # _minimumConnectTimeSecond should be greater than zero, otherwise raise exception + self._dummyBackOffCore.configTime(2, 128, -12) + + # Check the invalid input values will trigger exception + def test_CustomConfig_InvalidInput(self): + with pytest.raises(ValueError) as e: + # _baseReconnectTimeSecond is larger than _minimumConnectTimeSecond, + # which is not allowed... + self._dummyBackOffCore.configTime(200, 128, 30) + + # Check the _currentBackoffTimeSecond increases to twice of the origin after 2nd backoff + def test_backOffUpdatesCurrentBackoffTime(self): + self._dummyBackOffCore.configTime(1, 32, 20) + self._dummyBackOffCore.backOff() # This is the first backoff, block for 0 seconds + assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond * 2 + self._dummyBackOffCore.backOff() # Now progressive backoff calc starts + assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond * 2 * 2 + + # Check that backoff time is reset when connection is stable enough + def test_backOffResetWhenConnectionIsStable(self): + self._dummyBackOffCore.configTime(1, 32, 5) + self._dummyBackOffCore.backOff() # This is the first backoff, block for 0 seconds + assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond * 2 + self._dummyBackOffCore.backOff() # Now progressive backoff calc starts + assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond * 2 * 2 + # Now simulate a stable connection that exceeds _minimumConnectTimeSecond + 
self._dummyBackOffCore.startStableConnectionTimer() # Called when CONNACK arrives + time.sleep(self._dummyBackOffCore._minimumConnectTimeSecond + 1) + # Timer expires, currentBackoffTimeSecond should be reset + assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond + + # Check that backoff resetting timer is properly cancelled when a disconnect happens immediately + def test_resetTimerProperlyCancelledOnUnstableConnection(self): + self._dummyBackOffCore.configTime(1, 32, 5) + self._dummyBackOffCore.backOff() # This is the first backoff, block for 0 seconds + assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond * 2 + # Now simulate an unstable connection that is within _minimumConnectTimeSecond + self._dummyBackOffCore.startStableConnectionTimer() # Called when CONNACK arrives + time.sleep(self._dummyBackOffCore._minimumConnectTimeSecond - 1) + # Now "disconnect" + self._dummyBackOffCore.backOff() + assert self._dummyBackOffCore._currentBackoffTimeSecond == self._dummyBackOffCore._baseReconnectTimeSecond * 2 * 2 diff --git a/test/core/protocol/connection/test_sigv4_core.py b/test/core/protocol/connection/test_sigv4_core.py new file mode 100644 index 0000000..576efb1 --- /dev/null +++ b/test/core/protocol/connection/test_sigv4_core.py @@ -0,0 +1,148 @@ +from AWSIoTPythonSDK.core.protocol.connection.cores import SigV4Core +from AWSIoTPythonSDK.exception.AWSIoTExceptions import wssNoKeyInEnvironmentError +import os +from datetime import datetime +try: + from mock import patch + from mock import MagicMock +except: + from unittest.mock import patch + from unittest.mock import MagicMock +import pytest +try: + from configparser import ConfigParser # Python 3+ + from configparser import NoOptionError +except ImportError: + from ConfigParser import ConfigParser + from ConfigParser import NoOptionError + + +CREDS_NOT_FOUND_MODE_NO_KEYS = "NoKeys" 
+CREDS_NOT_FOUND_MODE_EMPTY_VALUES = "EmptyValues" + +PATCH_MODULE_LOCATION = "AWSIoTPythonSDK.core.protocol.connection.cores." +DUMMY_ACCESS_KEY_ID = "TRUSTMETHISIDISFAKE0" +DUMMY_SECRET_ACCESS_KEY = "trustMeThisSecretKeyIsSoFakeAaBbCc00Dd11" +DUMMY_SESSION_TOKEN = "FQoDYXdzEGcaDNSwicOypVyhiHj4JSLUAXTsOXu1YGT/Oaltz" \ + "XujI+cwvEA3zPoUdebHOkaUmRBO3o34J/3r2/+hBqZZNSpyzK" \ + "sBge1MXPwbM2G5ojz3aY4Qj+zD3hEMu9nxk3rhKkmTQWLoB4Z" \ + "rPRG6GJGkoLMAL1sSEh9kqbHN6XIt3F2E+Wn2BhDoGA7ZsXSg" \ + "+pgIntkSZcLT7pCX8pTEaEtRBhJQVc5GTYhG9y9mgjpeVRsbE" \ + "j8yDJzSWDpLGgR7APSvCFX2H+DwsKM564Z4IzjpbntIlLXdQw" \ + "Oytd65dgTlWZkmmYpTwVh+KMq+0MoF" +DUMMY_UTC_NOW_STRFTIME_RESULT = "20170628T204845Z" + +EXPECTED_WSS_URL_WITH_TOKEN = "wss://data.iot.us-east-1.amazonaws.com:44" \ + "3/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X" \ + "-Amz-Credential=TRUSTMETHISIDISFAKE0%2F20" \ + "170628%2Fus-east-1%2Fiotdata%2Faws4_reque" \ + "st&X-Amz-Date=20170628T204845Z&X-Amz-Expi" \ + "res=86400&X-Amz-SignedHeaders=host&X-Amz-" \ + "Signature=b79a4d7e31ccbf96b22d93cce1b500b" \ + "9ee611ec966159547e140ae32e4dcebed&X-Amz-S" \ + "ecurity-Token=FQoDYXdzEGcaDNSwicOypVyhiHj" \ + "4JSLUAXTsOXu1YGT/OaltzXujI%2BcwvEA3zPoUde" \ + "bHOkaUmRBO3o34J/3r2/%2BhBqZZNSpyzKsBge1MX" \ + "PwbM2G5ojz3aY4Qj%2BzD3hEMu9nxk3rhKkmTQWLo" \ + "B4ZrPRG6GJGkoLMAL1sSEh9kqbHN6XIt3F2E%2BWn" \ + "2BhDoGA7ZsXSg%2BpgIntkSZcLT7pCX8pTEaEtRBh" \ + "JQVc5GTYhG9y9mgjpeVRsbEj8yDJzSWDpLGgR7APS" \ + "vCFX2H%2BDwsKM564Z4IzjpbntIlLXdQwOytd65dg" \ + "TlWZkmmYpTwVh%2BKMq%2B0MoF" +EXPECTED_WSS_URL_WITHOUT_TOKEN = "wss://data.iot.us-east-1.amazonaws.com" \ + ":443/mqtt?X-Amz-Algorithm=AWS4-HMAC-SH" \ + "A256&X-Amz-Credential=TRUSTMETHISIDISF" \ + "AKE0%2F20170628%2Fus-east-1%2Fiotdata%" \ + "2Faws4_request&X-Amz-Date=20170628T204" \ + "845Z&X-Amz-Expires=86400&X-Amz-SignedH" \ + "eaders=host&X-Amz-Signature=b79a4d7e31" \ + "ccbf96b22d93cce1b500b9ee611ec966159547" \ + "e140ae32e4dcebed" + + +class TestSigV4Core: + + def setup_method(self, test_method): + 
self._use_mock_datetime() + self.mock_utc_now_result.strftime.return_value = DUMMY_UTC_NOW_STRFTIME_RESULT + self.sigv4_core = SigV4Core() + + def _use_mock_datetime(self): + self.datetime_patcher = patch(PATCH_MODULE_LOCATION + "datetime", spec=datetime) + self.mock_datetime_constructor = self.datetime_patcher.start() + self.mock_utc_now_result = MagicMock(spec=datetime) + self.mock_datetime_constructor.utcnow.return_value = self.mock_utc_now_result + + def teardown_method(self, test_method): + self.datetime_patcher.stop() + + def test_generate_url_with_env_credentials(self): + self._use_mock_os_environ({ + "AWS_ACCESS_KEY_ID" : DUMMY_ACCESS_KEY_ID, + "AWS_SECRET_ACCESS_KEY" : DUMMY_SECRET_ACCESS_KEY + }) + assert self._invoke_create_wss_endpoint_api() == EXPECTED_WSS_URL_WITHOUT_TOKEN + self.python_os_environ_patcher.stop() + + def test_generate_url_with_env_credentials_token(self): + self._use_mock_os_environ({ + "AWS_ACCESS_KEY_ID" : DUMMY_ACCESS_KEY_ID, + "AWS_SECRET_ACCESS_KEY" : DUMMY_SECRET_ACCESS_KEY, + "AWS_SESSION_TOKEN" : DUMMY_SESSION_TOKEN + }) + assert self._invoke_create_wss_endpoint_api() == EXPECTED_WSS_URL_WITH_TOKEN + self.python_os_environ_patcher.stop() + + def _use_mock_os_environ(self, os_environ_map): + self.python_os_environ_patcher = patch.dict(os.environ, os_environ_map) + self.python_os_environ_patcher.start() + + def _use_mock_configparser(self): + self.configparser_patcher = patch(PATCH_MODULE_LOCATION + "ConfigParser", spec=ConfigParser) + self.mock_configparser_constructor = self.configparser_patcher.start() + self.mock_configparser = MagicMock(spec=ConfigParser) + self.mock_configparser_constructor.return_value = self.mock_configparser + + def test_generate_url_with_input_credentials(self): + self._configure_mocks_credentials_not_found_in_env_config() + self.sigv4_core.setIAMCredentials(DUMMY_ACCESS_KEY_ID, DUMMY_SECRET_ACCESS_KEY, "") + + assert self._invoke_create_wss_endpoint_api() == EXPECTED_WSS_URL_WITHOUT_TOKEN + + 
self._recover_mocks_for_env_config() + + def test_generate_url_with_input_credentials_token(self): + self._configure_mocks_credentials_not_found_in_env_config() + self.sigv4_core.setIAMCredentials(DUMMY_ACCESS_KEY_ID, DUMMY_SECRET_ACCESS_KEY, DUMMY_SESSION_TOKEN) + + assert self._invoke_create_wss_endpoint_api() == EXPECTED_WSS_URL_WITH_TOKEN + + self._recover_mocks_for_env_config() + + def _recover_mocks_for_env_config(self): + self.python_os_environ_patcher.stop() + self.configparser_patcher.stop() + + def test_generate_url_failure_when_credential_configured_with_none_values(self): + self._use_mock_os_environ({}) + self._use_mock_configparser() + self.mock_configparser.get.side_effect = NoOptionError("option", "section") + self.sigv4_core.setIAMCredentials(None, None, None) + + with pytest.raises(wssNoKeyInEnvironmentError): + self._invoke_create_wss_endpoint_api() + + def _configure_mocks_credentials_not_found_in_env_config(self, mode=CREDS_NOT_FOUND_MODE_NO_KEYS): + if mode == CREDS_NOT_FOUND_MODE_NO_KEYS: + self._use_mock_os_environ({}) + elif mode == CREDS_NOT_FOUND_MODE_EMPTY_VALUES: + self._use_mock_os_environ({ + "AWS_ACCESS_KEY_ID" : "", + "AWS_SECRET_ACCESS_KEY" : "" + }) + self._use_mock_configparser() + self.mock_configparser.get.side_effect = NoOptionError("option", "section") + + def _invoke_create_wss_endpoint_api(self): + return self.sigv4_core.createWebsocketEndpoint("data.iot.us-east-1.amazonaws.com", 443, "us-east-1", + "GET", "iotdata", "/mqtt") diff --git a/test/core/protocol/connection/test_wss_core.py b/test/core/protocol/connection/test_wss_core.py new file mode 100755 index 0000000..bbe7244 --- /dev/null +++ b/test/core/protocol/connection/test_wss_core.py @@ -0,0 +1,249 @@ +from test.sdk_mock.mockSecuredWebsocketCore import mockSecuredWebsocketCoreNoRealHandshake +from test.sdk_mock.mockSecuredWebsocketCore import MockSecuredWebSocketCoreNoSocketIO +from test.sdk_mock.mockSecuredWebsocketCore import 
MockSecuredWebSocketCoreWithRealHandshake +from test.sdk_mock.mockSSLSocket import mockSSLSocket +import struct +import socket +import pytest +try: + from configparser import ConfigParser # Python 3+ +except ImportError: + from ConfigParser import ConfigParser + + +class TestWssCore: + + # Websocket Constants + _OP_CONTINUATION = 0x0 + _OP_TEXT = 0x1 + _OP_BINARY = 0x2 + _OP_CONNECTION_CLOSE = 0x8 + _OP_PING = 0x9 + _OP_PONG = 0xa + + def _generateStringOfAs(self, length): + ret = "" + for i in range(0, length): + ret += 'a' + return ret + + def _printByteArray(self, src): + for i in range(0, len(src)): + print(hex(src[i])) + print("") + + def _encodeFrame(self, rawPayload, opCode, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1): + ret = bytearray() + # FIN+RSV1+RSV2+RSV3 + F = (FIN & 0x01) << 3 + R1 = (RSV1 & 0x01) << 2 + R2 = (RSV2 & 0x01) << 1 + R3 = (RSV3 & 0x01) + FRRR = (F | R1 | R2 | R3) << 4 + # Op byte + opByte = FRRR | opCode + ret.append(opByte) + # Payload Length bytes + maskBit = masked + payloadLength = len(rawPayload) + if payloadLength <= 125: + ret.append((maskBit << 7) | payloadLength) + elif payloadLength <= 0xffff: # 16-bit unsigned int + ret.append((maskBit << 7) | 126) + ret.extend(struct.pack("!H", payloadLength)) + elif payloadLength <= 0x7fffffffffffffff: # 64-bit unsigned int (most significant bit must be 0) + ret.append((maskBit << 7) | 127) + ret.extend(struct.pack("!Q", payloadLength)) + else: # Overflow + raise ValueError("Exceeds the maximum number of bytes for a single websocket frame.") + if maskBit == 1: + # Mask key bytes + maskKey = bytearray(b"1234") + ret.extend(maskKey) + # Mask the payload + payloadBytes = bytearray(rawPayload) + if maskBit == 1: + for i in range(0, payloadLength): + payloadBytes[i] ^= maskKey[i % 4] + ret.extend(payloadBytes) + # Return the assembled wss frame + return ret + + def setup_method(self, method): + self._dummySSLSocket = mockSSLSocket() + + # Wss Handshake + def test_WssHandshakeTimeout(self): + 
self._dummySSLSocket.refreshReadBuffer(bytearray()) # Empty bytes to read from socket + with pytest.raises(socket.error): + self._dummySecuredWebsocket = \ + MockSecuredWebSocketCoreNoSocketIO(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + + # Constructor + def test_InvalidEndpointPattern(self): + with pytest.raises(ValueError): + self._dummySecuredWebsocket = MockSecuredWebSocketCoreWithRealHandshake(None, "ThisIsNotAValidIoTEndpoint!", 1234) + + def test_BJSEndpointPattern(self): + bjsStyleEndpoint = "blablabla.iot.cn-north-1.amazonaws.com.cn" + unexpectedExceptionMessage = "Invalid endpoint pattern for wss: %s" % bjsStyleEndpoint + # Garbage wss handshake response to ensure the test code gets passed endpoint pattern validation + self._dummySSLSocket.refreshReadBuffer(b"GarbageWssHanshakeResponse") + try: + self._dummySecuredWebsocket = MockSecuredWebSocketCoreWithRealHandshake(self._dummySSLSocket, bjsStyleEndpoint, 1234) + except ValueError as e: + if str(e) == unexpectedExceptionMessage: + raise AssertionError("Encountered unexpected exception when initializing wss core with BJS style endpoint", e) + + # Wss I/O + def test_WssReadComplete(self): + # Config mockSSLSocket to contain a Wss frame + rawPayload = b"If you can see me, this is good." 
+ # The payload of this frame will be masked by a randomly-generated mask key + # securedWebsocketCore should be able to decode it and get the raw payload back + coolFrame = self._encodeFrame(rawPayload, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=0) + # self._printByteArray(coolFrame) + self._dummySSLSocket.refreshReadBuffer(coolFrame) + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Read it back: + readItBack = self._dummySecuredWebsocket.read(len(rawPayload)) # Basically read everything + assert rawPayload == readItBack + + def test_WssReadFragmented(self): + rawPayloadFragmented = b"I am designed to be fragmented..." + # The payload of this frame will be masked by a randomly-generated mask key + # securedWebsocketCore should be able to decode it and get the raw payload back + stop1 = 4 + stop2 = 9 + coolFrame = self._encodeFrame(rawPayloadFragmented, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=0) + # self._printByteArray(coolFrame) + coolFramePart1 = coolFrame[0:stop1] + coolFramePart2 = coolFrame[stop1:stop2] + coolFramePart3 = coolFrame[stop2:len(coolFrame)] + # Config mockSSLSocket to contain a fragmented Wss frame + self._dummySSLSocket.setReadFragmented() + self._dummySSLSocket.addReadBufferFragment(coolFramePart1) + self._dummySSLSocket.addReadBufferFragment(coolFramePart2) + self._dummySSLSocket.addReadBufferFragment(coolFramePart3) + self._dummySSLSocket.loadFirstFragmented() + # In this way, reading from SSLSocket will result in 3 sslError, simulating the situation where data is not ready + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Read it back: + readItBack = bytearray() + while len(readItBack) != len(rawPayloadFragmented): + try: + # Will be 
interrupted due to faked socket I/O Error + # Should be able to read back the complete + readItBack += self._dummySecuredWebsocket.read(len(rawPayloadFragmented)) # Basically read everything + except: + pass + assert rawPayloadFragmented == readItBack + + def test_WssReadlongFrame(self): + # Config mockSSLSocket to contain a Wss frame + rawPayloadLong = bytearray(self._generateStringOfAs(300), 'utf-8') # 300 bytes of raw payload, will use extended payload length bytes in encoding + # The payload of this frame will be masked by a randomly-generated mask key + # securedWebsocketCore should be able to decode it and get the raw payload back + coolFrame = self._encodeFrame(rawPayloadLong, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=0) + # self._printByteArray(coolFrame) + self._dummySSLSocket.refreshReadBuffer(coolFrame) + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Read it back: + readItBack = self._dummySecuredWebsocket.read(len(rawPayloadLong)) # Basically read everything + assert rawPayloadLong == readItBack + + def test_WssReadReallylongFrame(self): + # Config mockSSLSocket to contain a Wss frame + # Maximum allowed length of a wss payload is greater than maximum allowed payload length of a MQTT payload + rawPayloadLong = bytearray(self._generateStringOfAs(0xffff + 3), 'utf-8') # 0xffff + 3 bytes of raw payload, will use extended payload length bytes in encoding + # The payload of this frame will be masked by a randomly-generated mask key + # securedWebsocketCore should be able to decode it and get the raw payload back + coolFrame = self._encodeFrame(rawPayloadLong, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=0) + # self._printByteArray(coolFrame) + self._dummySSLSocket.refreshReadBuffer(coolFrame) + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + 
mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Read it back: + readItBack = self._dummySecuredWebsocket.read(len(rawPayloadLong)) # Basically read everything + assert rawPayloadLong == readItBack + + def test_WssWriteComplete(self): + ToBeWritten = b"Write me to the cloud." + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Fire the write op + self._dummySecuredWebsocket.write(ToBeWritten) + ans = self._encodeFrame(ToBeWritten, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1) + # self._printByteArray(ans) + assert ans == self._dummySSLSocket.getWriteBuffer() + + def test_WssWriteFragmented(self): + ToBeWritten = b"Write me to the cloud again." + # Configure SSLSocket to perform interrupted write op + self._dummySSLSocket.setFlipWriteError() + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Fire the write op + with pytest.raises(socket.error) as e: + self._dummySecuredWebsocket.write(ToBeWritten) + assert "Not ready for write op" == e.value.strerror + lengthWritten = self._dummySecuredWebsocket.write(ToBeWritten) + ans = self._encodeFrame(ToBeWritten, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1) + assert lengthWritten == len(ToBeWritten) + assert ans == self._dummySSLSocket.getWriteBuffer() + + # Wss Client Behavior + def test_ClientClosesConnectionIfServerResponseIsMasked(self): + ToBeWritten = b"I am designed to be masked." 
+ maskedFrame = self._encodeFrame(ToBeWritten, self._OP_BINARY, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1) + self._dummySSLSocket.refreshReadBuffer(maskedFrame) + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Now read it back + with pytest.raises(socket.error) as e: + self._dummySecuredWebsocket.read(len(ToBeWritten)) + assert "Server response masked, closing connection and try again." == e.value.strerror + # Verify that a closing frame from the client is on its way + closingFrame = self._encodeFrame(b"", self._OP_CONNECTION_CLOSE, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1) + assert closingFrame == self._dummySSLSocket.getWriteBuffer() + + def test_ClientClosesConnectionIfServerResponseHasReserveBitsSet(self): + ToBeWritten = b"I am designed to be masked." + maskedFrame = self._encodeFrame(ToBeWritten, self._OP_BINARY, FIN=1, RSV1=1, RSV2=0, RSV3=0, masked=1) + self._dummySSLSocket.refreshReadBuffer(maskedFrame) + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Now read it back + with pytest.raises(socket.error) as e: + self._dummySecuredWebsocket.read(len(ToBeWritten)) + assert "RSV bits set with NO negotiated extensions." 
== e.value.strerror + # Verify that a closing frame from the client is on its way + closingFrame = self._encodeFrame(b"", self._OP_CONNECTION_CLOSE, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1) + assert closingFrame == self._dummySSLSocket.getWriteBuffer() + + def test_ClientSendsPONGIfReceivedPING(self): + PINGFrame = self._encodeFrame(b"", self._OP_PING, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=0) + self._dummySSLSocket.refreshReadBuffer(PINGFrame) + # Init securedWebsocket with this mockSSLSocket + self._dummySecuredWebsocket = \ + mockSecuredWebsocketCoreNoRealHandshake(self._dummySSLSocket, "data.iot.region.amazonaws.com", 1234) + # Now read it back, this must be in the next round of paho MQTT packet reading + # Should fail since we only have a PING to read, it never contains a valid MQTT payload + with pytest.raises(socket.error) as e: + self._dummySecuredWebsocket.read(5) + assert "Not a complete MQTT packet payload within this wss frame." == e.value.strerror + # Verify that PONG frame from the client is on its way + PONGFrame = self._encodeFrame(b"", self._OP_PONG, FIN=1, RSV1=0, RSV2=0, RSV3=0, masked=1) + assert PONGFrame == self._dummySSLSocket.getWriteBuffer() + diff --git a/test/core/protocol/internal/__init__.py b/test/core/protocol/internal/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/protocol/internal/test_clients_client_status.py b/test/core/protocol/internal/test_clients_client_status.py new file mode 100644 index 0000000..b84a0d6 --- /dev/null +++ b/test/core/protocol/internal/test_clients_client_status.py @@ -0,0 +1,31 @@ +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatusContainer + + +class TestClientsClientStatus: + + def setup_method(self, test_method): + self.client_status = ClientStatusContainer() + + def test_set_client_status(self): + assert self.client_status.get_status() == ClientStatus.IDLE # Client status should 
start with IDLE + self._set_client_status_and_verify(ClientStatus.ABNORMAL_DISCONNECT) + self._set_client_status_and_verify(ClientStatus.CONNECT) + self._set_client_status_and_verify(ClientStatus.RESUBSCRIBE) + self._set_client_status_and_verify(ClientStatus.DRAINING) + self._set_client_status_and_verify(ClientStatus.STABLE) + + def test_client_status_does_not_change_unless_user_connect_after_user_disconnect(self): + self.client_status.set_status(ClientStatus.USER_DISCONNECT) + self._set_client_status_and_verify(ClientStatus.ABNORMAL_DISCONNECT, ClientStatus.USER_DISCONNECT) + self._set_client_status_and_verify(ClientStatus.RESUBSCRIBE, ClientStatus.USER_DISCONNECT) + self._set_client_status_and_verify(ClientStatus.DRAINING, ClientStatus.USER_DISCONNECT) + self._set_client_status_and_verify(ClientStatus.STABLE, ClientStatus.USER_DISCONNECT) + self._set_client_status_and_verify(ClientStatus.CONNECT) + + def _set_client_status_and_verify(self, set_client_status_type, verify_client_status_type=None): + self.client_status.set_status(set_client_status_type) + if verify_client_status_type: + assert self.client_status.get_status() == verify_client_status_type + else: + assert self.client_status.get_status() == set_client_status_type diff --git a/test/core/protocol/internal/test_clients_internal_async_client.py b/test/core/protocol/internal/test_clients_internal_async_client.py new file mode 100644 index 0000000..2d0e3cf --- /dev/null +++ b/test/core/protocol/internal/test_clients_internal_async_client.py @@ -0,0 +1,388 @@ +from AWSIoTPythonSDK.core.protocol.internal.clients import InternalAsyncMqttClient +from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids +from AWSIoTPythonSDK.core.util.providers import CertificateCredentialsProvider +from AWSIoTPythonSDK.core.util.providers import IAMCredentialsProvider +from AWSIoTPythonSDK.core.util.providers import EndpointProvider +from AWSIoTPythonSDK.core.util.providers import CiphersProvider +from 
AWSIoTPythonSDK.core.protocol.paho.client import MQTTv311 +from AWSIoTPythonSDK.core.protocol.paho.client import Client +from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS +from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_ERRNO +try: + from mock import patch + from mock import MagicMock + from mock import NonCallableMagicMock +except: + from unittest.mock import patch + from unittest.mock import MagicMock + from unittest.mock import NonCallableMagicMock +import ssl +import pytest + + +DUMMY_CLIENT_ID = "CoolClientId" +FAKE_PATH = "/fake/path/" +DUMMY_CA_PATH = FAKE_PATH + "ca.crt" +DUMMY_CERT_PATH = FAKE_PATH + "cert.pem" +DUMMY_KEY_PATH = FAKE_PATH + "key.pem" +DUMMY_ACCESS_KEY_ID = "AccessKeyId" +DUMMY_SECRET_ACCESS_KEY = "SecretAccessKey" +DUMMY_SESSION_TOKEN = "SessionToken" +DUMMY_TOPIC = "topic/test" +DUMMY_PAYLOAD = "TestPayload" +DUMMY_QOS = 1 +DUMMY_BASE_RECONNECT_QUIET_SEC = 1 +DUMMY_MAX_RECONNECT_QUIET_SEC = 32 +DUMMY_STABLE_CONNECTION_SEC = 20 +DUMMY_ENDPOINT = "dummy.endpoint.com" +DUMMY_PORT = 8888 +DUMMY_SUCCESS_RC = MQTT_ERR_SUCCESS +DUMMY_FAILURE_RC = MQTT_ERR_ERRNO +DUMMY_KEEP_ALIVE_SEC = 60 +DUMMY_REQUEST_MID = 89757 +DUMMY_USERNAME = "DummyUsername" +DUMMY_PASSWORD = "DummyPassword" +DUMMY_ALPN_PROTOCOLS = ["DummyALPNProtocol"] + +KEY_GET_CA_PATH_CALL_COUNT = "get_ca_path_call_count" +KEY_GET_CERT_PATH_CALL_COUNT = "get_cert_path_call_count" +KEY_GET_KEY_PATH_CALL_COUNT = "get_key_path_call_count" + +class TestClientsInternalAsyncClient: + + def setup_method(self, test_method): + # We init a cert based client by default + self.internal_async_client = InternalAsyncMqttClient(DUMMY_CLIENT_ID, False, MQTTv311, False) + self._mock_internal_members() + + def _mock_internal_members(self): + self.mock_paho_client = MagicMock(spec=Client) + # TODO: See if we can replace the following with patch.object + self.internal_async_client._paho_client = self.mock_paho_client + + def test_set_cert_credentials_provider_x509(self): 
+ mock_cert_credentials_provider = self._mock_cert_credentials_provider() + cipher_provider = CiphersProvider() + self.internal_async_client.set_cert_credentials_provider(mock_cert_credentials_provider, cipher_provider) + + expected_call_count = { + KEY_GET_CA_PATH_CALL_COUNT : 1, + KEY_GET_CERT_PATH_CALL_COUNT : 1, + KEY_GET_KEY_PATH_CALL_COUNT : 1 + } + self._verify_cert_credentials_provider(mock_cert_credentials_provider, expected_call_count) + self.mock_paho_client.tls_set.assert_called_once_with(ca_certs=DUMMY_CA_PATH, + certfile=DUMMY_CERT_PATH, + keyfile=DUMMY_KEY_PATH, + cert_reqs=ssl.CERT_REQUIRED, + tls_version=ssl.PROTOCOL_SSLv23, + ciphers=cipher_provider.get_ciphers()) + + def test_set_cert_credentials_provider_wss(self): + self.internal_async_client = InternalAsyncMqttClient(DUMMY_CLIENT_ID, False, MQTTv311, True) + self._mock_internal_members() + mock_cert_credentials_provider = self._mock_cert_credentials_provider() + cipher_provider = CiphersProvider() + + self.internal_async_client.set_cert_credentials_provider(mock_cert_credentials_provider, cipher_provider) + + expected_call_count = { + KEY_GET_CA_PATH_CALL_COUNT : 1 + } + self._verify_cert_credentials_provider(mock_cert_credentials_provider, expected_call_count) + self.mock_paho_client.tls_set.assert_called_once_with(ca_certs=DUMMY_CA_PATH, + cert_reqs=ssl.CERT_REQUIRED, + tls_version=ssl.PROTOCOL_SSLv23, + ciphers=cipher_provider.get_ciphers()) + + def _mock_cert_credentials_provider(self): + mock_cert_credentials_provider = MagicMock(spec=CertificateCredentialsProvider) + mock_cert_credentials_provider.get_ca_path.return_value = DUMMY_CA_PATH + mock_cert_credentials_provider.get_cert_path.return_value = DUMMY_CERT_PATH + mock_cert_credentials_provider.get_key_path.return_value = DUMMY_KEY_PATH + return mock_cert_credentials_provider + + def _verify_cert_credentials_provider(self, mock_cert_credentials_provider, expected_values): + expected_get_ca_path_call_count = 
expected_values.get(KEY_GET_CA_PATH_CALL_COUNT) + expected_get_cert_path_call_count = expected_values.get(KEY_GET_CERT_PATH_CALL_COUNT) + expected_get_key_path_call_count = expected_values.get(KEY_GET_KEY_PATH_CALL_COUNT) + + if expected_get_ca_path_call_count is not None: + assert mock_cert_credentials_provider.get_ca_path.call_count == expected_get_ca_path_call_count + if expected_get_cert_path_call_count is not None: + assert mock_cert_credentials_provider.get_cert_path.call_count == expected_get_cert_path_call_count + if expected_get_key_path_call_count is not None: + assert mock_cert_credentials_provider.get_key_path.call_count == expected_get_key_path_call_count + + def test_set_iam_credentials_provider(self): + self.internal_async_client = InternalAsyncMqttClient(DUMMY_CLIENT_ID, False, MQTTv311, True) + self._mock_internal_members() + mock_iam_credentials_provider = self._mock_iam_credentials_provider() + + self.internal_async_client.set_iam_credentials_provider(mock_iam_credentials_provider) + + self._verify_iam_credentials_provider(mock_iam_credentials_provider) + + def _mock_iam_credentials_provider(self): + mock_iam_credentials_provider = MagicMock(spec=IAMCredentialsProvider) + mock_iam_credentials_provider.get_ca_path.return_value = DUMMY_CA_PATH + mock_iam_credentials_provider.get_access_key_id.return_value = DUMMY_ACCESS_KEY_ID + mock_iam_credentials_provider.get_secret_access_key.return_value = DUMMY_SECRET_ACCESS_KEY + mock_iam_credentials_provider.get_session_token.return_value = DUMMY_SESSION_TOKEN + return mock_iam_credentials_provider + + def _verify_iam_credentials_provider(self, mock_iam_credentials_provider): + assert mock_iam_credentials_provider.get_access_key_id.call_count == 1 + assert mock_iam_credentials_provider.get_secret_access_key.call_count == 1 + assert mock_iam_credentials_provider.get_session_token.call_count == 1 + + def test_configure_last_will(self): + self.internal_async_client.configure_last_will(DUMMY_TOPIC, 
DUMMY_PAYLOAD, DUMMY_QOS) + self.mock_paho_client.will_set.assert_called_once_with(DUMMY_TOPIC, + DUMMY_PAYLOAD, + DUMMY_QOS, + False) + + def test_clear_last_will(self): + self.internal_async_client.clear_last_will() + assert self.mock_paho_client.will_clear.call_count == 1 + + def test_set_username_password(self): + self.internal_async_client.set_username_password(DUMMY_USERNAME, DUMMY_PASSWORD) + self.mock_paho_client.username_pw_set.assert_called_once_with(DUMMY_USERNAME, DUMMY_PASSWORD) + + def test_configure_reconnect_back_off(self): + self.internal_async_client.configure_reconnect_back_off(DUMMY_BASE_RECONNECT_QUIET_SEC, + DUMMY_MAX_RECONNECT_QUIET_SEC, + DUMMY_STABLE_CONNECTION_SEC) + self.mock_paho_client.setBackoffTiming.assert_called_once_with(DUMMY_BASE_RECONNECT_QUIET_SEC, + DUMMY_MAX_RECONNECT_QUIET_SEC, + DUMMY_STABLE_CONNECTION_SEC) + def test_configure_alpn_protocols(self): + self.internal_async_client.configure_alpn_protocols(DUMMY_ALPN_PROTOCOLS) + self.mock_paho_client.config_alpn_protocols.assert_called_once_with(DUMMY_ALPN_PROTOCOLS) + + def test_connect_success_rc(self): + self._internal_test_connect_with_rc(DUMMY_SUCCESS_RC) + + def test_connect_failure_rc(self): + self._internal_test_connect_with_rc(DUMMY_FAILURE_RC) + + def _internal_test_connect_with_rc(self, expected_connect_rc): + mock_endpoint_provider = self._mock_endpoint_provider() + self.mock_paho_client.connect.return_value = expected_connect_rc + self.internal_async_client.set_endpoint_provider(mock_endpoint_provider) + + actual_rc = self.internal_async_client.connect(DUMMY_KEEP_ALIVE_SEC) + + assert mock_endpoint_provider.get_host.call_count == 1 + assert mock_endpoint_provider.get_port.call_count == 1 + event_callback_map = self.internal_async_client.get_event_callback_map() + assert len(event_callback_map) == 3 + assert event_callback_map[FixedEventMids.CONNACK_MID] is not None + assert event_callback_map[FixedEventMids.DISCONNECT_MID] is not None + assert 
event_callback_map[FixedEventMids.MESSAGE_MID] is not None + assert self.mock_paho_client.connect.call_count == 1 + if expected_connect_rc == MQTT_ERR_SUCCESS: + assert self.mock_paho_client.loop_start.call_count == 1 + else: + assert self.mock_paho_client.loop_start.call_count == 0 + assert actual_rc == expected_connect_rc + + def _mock_endpoint_provider(self): + mock_endpoint_provider = MagicMock(spec=EndpointProvider) + mock_endpoint_provider.get_host.return_value = DUMMY_ENDPOINT + mock_endpoint_provider.get_port.return_value = DUMMY_PORT + return mock_endpoint_provider + + def test_start_background_network_io(self): + self.internal_async_client.start_background_network_io() + assert self.mock_paho_client.loop_start.call_count == 1 + + def test_stop_background_network_io(self): + self.internal_async_client.stop_background_network_io() + assert self.mock_paho_client.loop_stop.call_count == 1 + + def test_disconnect_success_rc(self): + self._internal_test_disconnect_with_rc(DUMMY_SUCCESS_RC) + + def test_disconnect_failure_rc(self): + self._internal_test_disconnect_with_rc(DUMMY_FAILURE_RC) + + def _internal_test_disconnect_with_rc(self, expected_disconnect_rc): + self.mock_paho_client.disconnect.return_value = expected_disconnect_rc + + actual_rc = self.internal_async_client.disconnect() + + event_callback_map = self.internal_async_client.get_event_callback_map() + assert self.mock_paho_client.disconnect.call_count == 1 + if expected_disconnect_rc == MQTT_ERR_SUCCESS: + # Since we only call disconnect, there should be only one registered callback + assert len(event_callback_map) == 1 + assert event_callback_map[FixedEventMids.DISCONNECT_MID] is not None + else: + assert len(event_callback_map) == 0 + assert actual_rc == expected_disconnect_rc + + def test_publish_qos0_success_rc(self): + self._internal_test_publish_with(0, DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC) + self._internal_test_publish_with(0, DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC, NonCallableMagicMock()) + + 
def test_publish_qos0_failure_rc(self): + self._internal_test_publish_with(0, DUMMY_REQUEST_MID, DUMMY_FAILURE_RC) + self._internal_test_publish_with(0, DUMMY_REQUEST_MID, DUMMY_FAILURE_RC, NonCallableMagicMock()) + + def test_publish_qos1_success_rc(self): + self._internal_test_publish_with(1, DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC) + self._internal_test_publish_with(1, DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC, NonCallableMagicMock()) + + def test_publish_qos1_failure_rc(self): + self._internal_test_publish_with(1, DUMMY_REQUEST_MID, DUMMY_FAILURE_RC) + self._internal_test_publish_with(1, DUMMY_REQUEST_MID, DUMMY_FAILURE_RC, NonCallableMagicMock()) + + def _internal_test_publish_with(self, qos, expected_mid, expected_rc, expected_callback=None): + self.mock_paho_client.publish.return_value = expected_rc, expected_mid + + actual_rc, actual_mid = self.internal_async_client.publish(DUMMY_TOPIC, + DUMMY_PAYLOAD, + qos, + retain=False, + ack_callback=expected_callback) + + self._verify_event_callback_map_for_pub_sub_unsub(expected_rc, expected_mid, qos, expected_callback) + assert actual_rc == expected_rc + assert actual_mid == expected_mid + + def test_subscribe_success_rc(self): + self._internal_test_subscribe_with(DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC) + self._internal_test_subscribe_with(DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC, NonCallableMagicMock()) + + def test_subscribe_failure_rc(self): + self._internal_test_subscribe_with(DUMMY_REQUEST_MID, DUMMY_FAILURE_RC) + self._internal_test_subscribe_with(DUMMY_REQUEST_MID, DUMMY_FAILURE_RC, NonCallableMagicMock()) + + def _internal_test_subscribe_with(self, expected_mid, expected_rc, expected_callback=None): + self.mock_paho_client.subscribe.return_value = expected_rc, expected_mid + + actual_rc, actual_mid = self.internal_async_client.subscribe(DUMMY_TOPIC, DUMMY_QOS, expected_callback) + + self._verify_event_callback_map_for_pub_sub_unsub(expected_rc, expected_mid, qos=None, callback=expected_callback) + assert actual_rc == 
expected_rc + assert actual_mid == expected_mid + + def test_unsubscribe_success_rc(self): + self._internal_test_unsubscribe_with(DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC) + self._internal_test_unsubscribe_with(DUMMY_REQUEST_MID, DUMMY_SUCCESS_RC, NonCallableMagicMock()) + + def test_unsubscribe_failure_rc(self): + self._internal_test_unsubscribe_with(DUMMY_REQUEST_MID, DUMMY_FAILURE_RC) + self._internal_test_unsubscribe_with(DUMMY_REQUEST_MID, DUMMY_FAILURE_RC, NonCallableMagicMock()) + + def _internal_test_unsubscribe_with(self, expected_mid, expected_rc, expected_callback=None): + self.mock_paho_client.unsubscribe.return_value = expected_rc, expected_mid + + actual_rc, actual_mid = self.internal_async_client.unsubscribe(DUMMY_TOPIC, expected_callback) + + self._verify_event_callback_map_for_pub_sub_unsub(expected_rc, expected_mid, qos=None, callback=expected_callback) + assert actual_rc == expected_rc + assert actual_mid == expected_mid + + def _verify_event_callback_map_for_pub_sub_unsub(self, expected_rc, expected_mid, qos=None, callback=None): + event_callback_map = self.internal_async_client.get_event_callback_map() + should_have_callback_in_map = expected_rc == DUMMY_SUCCESS_RC and callback + if qos is not None: + should_have_callback_in_map = should_have_callback_in_map and qos > 0 + + if should_have_callback_in_map: + # Since we only perform this request, there should be only one registered callback + assert len(event_callback_map) == 1 + assert event_callback_map[expected_mid] == callback + else: + assert len(event_callback_map) == 0 + + def test_register_internal_event_callbacks(self): + expected_callback = NonCallableMagicMock() + self.internal_async_client.register_internal_event_callbacks(expected_callback, + expected_callback, + expected_callback, + expected_callback, + expected_callback, + expected_callback) + self._verify_internal_event_callbacks(expected_callback) + + def test_unregister_internal_event_callbacks(self): + 
self.internal_async_client.unregister_internal_event_callbacks() + self._verify_internal_event_callbacks(None) + + def _verify_internal_event_callbacks(self, expected_callback): + assert self.mock_paho_client.on_connect == expected_callback + assert self.mock_paho_client.on_disconnect == expected_callback + assert self.mock_paho_client.on_publish == expected_callback + assert self.mock_paho_client.on_subscribe == expected_callback + assert self.mock_paho_client.on_unsubscribe == expected_callback + assert self.mock_paho_client.on_message == expected_callback + + def test_invoke_event_callback_fixed_request(self): + # We use disconnect as an example for fixed request to "register" an event callback + self.mock_paho_client.disconnect.return_value = DUMMY_SUCCESS_RC + event_callback = MagicMock() + rc = self.internal_async_client.disconnect(event_callback) + self.internal_async_client.invoke_event_callback(FixedEventMids.DISCONNECT_MID, rc) + + event_callback.assert_called_once_with(FixedEventMids.DISCONNECT_MID, rc) + event_callback_map = self.internal_async_client.get_event_callback_map() + assert len(event_callback_map) == 1 # Fixed request event callback never gets removed + assert event_callback_map[FixedEventMids.DISCONNECT_MID] is not None + + def test_invoke_event_callback_non_fixed_request(self): + # We use unsubscribe as an example for non-fixed request to "register" an event callback + self.mock_paho_client.unsubscribe.return_value = DUMMY_SUCCESS_RC, DUMMY_REQUEST_MID + event_callback = MagicMock() + rc, mid = self.internal_async_client.unsubscribe(DUMMY_TOPIC, event_callback) + self.internal_async_client.invoke_event_callback(mid) + + event_callback.assert_called_once_with(mid=mid) + event_callback_map = self.internal_async_client.get_event_callback_map() + assert len(event_callback_map) == 0 # Non-fixed request event callback gets removed after successfully invoked + + @pytest.mark.timeout(3) + def 
test_invoke_event_callback_that_has_client_api_call(self): + # We use subscribe and publish on SUBACK as an example of having client API call within event callbacks + self.mock_paho_client.subscribe.return_value = DUMMY_SUCCESS_RC, DUMMY_REQUEST_MID + self.mock_paho_client.publish.return_value = DUMMY_SUCCESS_RC, DUMMY_REQUEST_MID + 1 + rc, mid = self.internal_async_client.subscribe(DUMMY_TOPIC, DUMMY_QOS, ack_callback=self._publish_on_suback) + + self.internal_async_client.invoke_event_callback(mid, (DUMMY_QOS,)) + + event_callback_map = self.internal_async_client.get_event_callback_map() + assert len(event_callback_map) == 0 + + def _publish_on_suback(self, mid, data): + self.internal_async_client.publish(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS) + + def test_remove_event_callback(self): + # We use unsubscribe as an example for non-fixed request to "register" an event callback + self.mock_paho_client.unsubscribe.return_value = DUMMY_SUCCESS_RC, DUMMY_REQUEST_MID + event_callback = MagicMock() + rc, mid = self.internal_async_client.unsubscribe(DUMMY_TOPIC, event_callback) + + event_callback_map = self.internal_async_client.get_event_callback_map() + assert len(event_callback_map) == 1 + + self.internal_async_client.remove_event_callback(mid) + assert len(event_callback_map) == 0 + + def test_clean_up_event_callbacks(self): + # We use unsubscribe as an example for non-fixed request to "register" an event callback + self.mock_paho_client.unsubscribe.return_value = DUMMY_SUCCESS_RC, DUMMY_REQUEST_MID + # We use disconnect as an example for fixed request to "register" an event callback + self.mock_paho_client.disconnect.return_value = DUMMY_SUCCESS_RC + event_callback = MagicMock() + self.internal_async_client.unsubscribe(DUMMY_TOPIC, event_callback) + self.internal_async_client.disconnect(event_callback) + + event_callback_map = self.internal_async_client.get_event_callback_map() + assert len(event_callback_map) == 2 + + 
self.internal_async_client.clean_up_event_callbacks() + assert len(event_callback_map) == 0 diff --git a/test/core/protocol/internal/test_offline_request_queue.py b/test/core/protocol/internal/test_offline_request_queue.py new file mode 100755 index 0000000..f666bb9 --- /dev/null +++ b/test/core/protocol/internal/test_offline_request_queue.py @@ -0,0 +1,67 @@ +import AWSIoTPythonSDK.core.protocol.internal.queues as Q +from AWSIoTPythonSDK.core.protocol.internal.queues import AppendResults +import pytest + + +class TestOfflineRequestQueue(): + + # Check that invalid input types are filtered out on initialization + def test_InvalidTypeInit(self): + with pytest.raises(TypeError): + Q.OfflineRequestQueue(1.7, 0) + with pytest.raises(TypeError): + Q.OfflineRequestQueue(0, 1.7) + + # Check that elements can be append to a normal finite queue + def test_NormalAppend(self): + coolQueue = Q.OfflineRequestQueue(20, 1) + numberOfMessages = 5 + answer = list(range(0, numberOfMessages)) + for i in range(0, numberOfMessages): + coolQueue.append(i) + assert answer == coolQueue + + # Check that new elements are dropped for DROPNEWEST configuration + def test_DropNewest(self): + coolQueue = Q.OfflineRequestQueue(3, 1) # Queueing section: 3, Response section: 1, DropNewest + numberOfMessages = 10 + answer = [0, 1, 2] # '0', '1' and '2' are stored, others are dropped. + fullCount = 0 + for i in range(0, numberOfMessages): + if coolQueue.append(i) == AppendResults.APPEND_FAILURE_QUEUE_FULL: + fullCount += 1 + assert answer == coolQueue + assert 7 == fullCount + + # Check that old elements are dropped for DROPOLDEST configuration + def test_DropOldest(self): + coolQueue = Q.OfflineRequestQueue(3, 0) + numberOfMessages = 10 + answer = [7, 8, 9] # '7', '8' and '9' are stored, others (older ones) are dropped. 
+ fullCount = 0 + for i in range(0, numberOfMessages): + if coolQueue.append(i) == AppendResults.APPEND_FAILURE_QUEUE_FULL: + fullCount += 1 + assert answer == coolQueue + assert 7 == fullCount + + # Check infinite queue + def test_Infinite(self): + coolQueue = Q.OfflineRequestQueue(-100, 1) + numberOfMessages = 10000 + answer = list(range(0, numberOfMessages)) + for i in range(0, numberOfMessages): + coolQueue.append(i) + assert answer == coolQueue # Nothing should be dropped since response section is infinite + + # Check disabled queue + def test_Disabled(self): + coolQueue = Q.OfflineRequestQueue(0, 1) + numberOfMessages = 10 + answer = list() + disableFailureCount = 0 + for i in range(0, numberOfMessages): + if coolQueue.append(i) == AppendResults.APPEND_FAILURE_QUEUE_DISABLED: + disableFailureCount += 1 + assert answer == coolQueue # Nothing should be appended since the queue is disabled + assert numberOfMessages == disableFailureCount diff --git a/test/core/protocol/internal/test_workers_event_consumer.py b/test/core/protocol/internal/test_workers_event_consumer.py new file mode 100644 index 0000000..4edfb6c --- /dev/null +++ b/test/core/protocol/internal/test_workers_event_consumer.py @@ -0,0 +1,273 @@ +from AWSIoTPythonSDK.core.protocol.internal.workers import EventConsumer +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatusContainer +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus +from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids +from AWSIoTPythonSDK.core.protocol.internal.events import EventTypes +from AWSIoTPythonSDK.core.protocol.paho.client import MQTTMessage +from AWSIoTPythonSDK.core.protocol.internal.workers import SubscriptionManager +from AWSIoTPythonSDK.core.protocol.internal.workers import OfflineRequestsManager +from AWSIoTPythonSDK.core.protocol.internal.clients import InternalAsyncMqttClient +from AWSIoTPythonSDK.core.protocol.internal.requests import QueueableRequest +from 
AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes +from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_DRAINING_INTERNAL_SEC +try: + from mock import patch + from mock import MagicMock + from mock import call +except: + from unittest.mock import patch + from unittest.mock import MagicMock + from unittest.mock import call +from threading import Condition +import time +import sys +if sys.version_info[0] < 3: + from Queue import Queue +else: + from queue import Queue + + +DUMMY_TOPIC = "dummy/topic" +DUMMY_MESSAGE = "dummy_message" +DUMMY_QOS = 1 +DUMMY_SUCCESS_RC = 0 +DUMMY_PUBACK_MID = 89757 +DUMMY_SUBACK_MID = 89758 +DUMMY_UNSUBACK_MID = 89579 + +KEY_CLIENT_STATUS_AFTER = "status_after" +KEY_STOP_BG_NW_IO_CALL_COUNT = "stop_background_network_io_call_count" +KEY_CLEAN_UP_EVENT_CBS_CALL_COUNT = "clean_up_event_callbacks_call_count" +KEY_IS_EVENT_Q_EMPTY = "is_event_queue_empty" +KEY_IS_EVENT_CONSUMER_UP = "is_event_consumer_running" + +class TestWorkersEventConsumer: + + def setup_method(self, test_method): + self.cv = Condition() + self.event_queue = Queue() + self.client_status = ClientStatusContainer() + self.internal_async_client = MagicMock(spec=InternalAsyncMqttClient) + self.subscription_manager = MagicMock(spec=SubscriptionManager) + self.offline_requests_manager = MagicMock(spec=OfflineRequestsManager) + self.message_callback = MagicMock() + self.subscribe_callback = MagicMock() + self.unsubscribe_callback = MagicMock() + self.event_consumer = None + + def teardown_method(self, test_method): + if self.event_consumer and self.event_consumer.is_running(): + self.event_consumer.stop() + self.event_consumer.wait_until_it_stops(2) # Make sure the event consumer stops gracefully + + def test_update_draining_interval_sec(self): + EXPECTED_DRAINING_INTERVAL_SEC = 0.5 + self.load_mocks_into_test_target() + self.event_consumer.update_draining_interval_sec(EXPECTED_DRAINING_INTERVAL_SEC) + assert 
self.event_consumer.get_draining_interval_sec() == EXPECTED_DRAINING_INTERVAL_SEC + + def test_dispatch_message_event(self): + expected_message_event = self._configure_mocks_message_event() + self._start_consumer() + self._verify_message_event_dispatch(expected_message_event) + + def _configure_mocks_message_event(self): + message_event = self._create_message_event(DUMMY_TOPIC, DUMMY_MESSAGE, DUMMY_QOS) + self._fill_in_fake_events([message_event]) + self.subscription_manager.list_records.return_value = [(DUMMY_TOPIC, (DUMMY_QOS, self.message_callback, self.subscribe_callback))] + self.load_mocks_into_test_target() + return message_event + + def _create_message_event(self, topic, payload, qos): + mqtt_message = MQTTMessage() + mqtt_message.topic = topic + mqtt_message.payload = payload + mqtt_message.qos = qos + return FixedEventMids.MESSAGE_MID, EventTypes.MESSAGE, mqtt_message + + def _verify_message_event_dispatch(self, expected_message_event): + expected_message = expected_message_event[2] + self.message_callback.assert_called_once_with(None, None, expected_message) + self.internal_async_client.invoke_event_callback.assert_called_once_with(FixedEventMids.MESSAGE_MID, data=expected_message) + assert self.event_consumer.is_running() is True + + def test_dispatch_disconnect_event_user_disconnect(self): + self._configure_mocks_disconnect_event(ClientStatus.USER_DISCONNECT) + self._start_consumer() + expected_values = { + KEY_CLIENT_STATUS_AFTER : ClientStatus.USER_DISCONNECT, + KEY_STOP_BG_NW_IO_CALL_COUNT : 1, + KEY_CLEAN_UP_EVENT_CBS_CALL_COUNT : 1, + KEY_IS_EVENT_Q_EMPTY : True, + KEY_IS_EVENT_CONSUMER_UP : False + } + self._verify_disconnect_event_dispatch(expected_values) + assert self.event_consumer.is_fully_stopped() is True + + def test_dispatch_disconnect_event_connect_failure(self): + self._configure_mocks_disconnect_event(ClientStatus.CONNECT) + self._start_consumer() + expected_values = { + KEY_CLIENT_STATUS_AFTER : ClientStatus.CONNECT, + 
KEY_STOP_BG_NW_IO_CALL_COUNT : 1, + KEY_CLEAN_UP_EVENT_CBS_CALL_COUNT : 1, + KEY_IS_EVENT_Q_EMPTY : True, + KEY_IS_EVENT_CONSUMER_UP : False + } + self._verify_disconnect_event_dispatch(expected_values) + assert self.event_consumer.is_fully_stopped() is True + + def test_dispatch_disconnect_event_abnormal_disconnect(self): + self._configure_mocks_disconnect_event(ClientStatus.STABLE) + self._start_consumer() + expected_values = { + KEY_CLIENT_STATUS_AFTER : ClientStatus.ABNORMAL_DISCONNECT, + KEY_STOP_BG_NW_IO_CALL_COUNT : 0, + KEY_CLEAN_UP_EVENT_CBS_CALL_COUNT : 0, + KEY_IS_EVENT_CONSUMER_UP : True + } + self._verify_disconnect_event_dispatch(expected_values) + assert self.event_consumer.is_fully_stopped() is False + + def _configure_mocks_disconnect_event(self, start_client_status): + self.client_status.set_status(start_client_status) + self._fill_in_fake_events([self._create_disconnect_event()]) + self.load_mocks_into_test_target() + + def _create_disconnect_event(self): + return FixedEventMids.DISCONNECT_MID, EventTypes.DISCONNECT, DUMMY_SUCCESS_RC + + def _verify_disconnect_event_dispatch(self, expected_values): + client_status_after = expected_values.get(KEY_CLIENT_STATUS_AFTER) + stop_background_network_io_call_count = expected_values.get(KEY_STOP_BG_NW_IO_CALL_COUNT) + clean_up_event_callbacks_call_count = expected_values.get(KEY_CLEAN_UP_EVENT_CBS_CALL_COUNT) + is_event_queue_empty = expected_values.get(KEY_IS_EVENT_Q_EMPTY) + is_event_consumer_running = expected_values.get(KEY_IS_EVENT_CONSUMER_UP) + + if client_status_after is not None: + assert self.client_status.get_status() == client_status_after + if stop_background_network_io_call_count is not None: + assert self.internal_async_client.stop_background_network_io.call_count == stop_background_network_io_call_count + if clean_up_event_callbacks_call_count is not None: + assert self.internal_async_client.clean_up_event_callbacks.call_count == clean_up_event_callbacks_call_count + if is_event_queue_empty 
is not None: + assert self.event_queue.empty() == is_event_queue_empty + if is_event_consumer_running is not None: + assert self.event_consumer.is_running() == is_event_consumer_running + + self.internal_async_client.invoke_event_callback.assert_called_once_with(FixedEventMids.DISCONNECT_MID, data=DUMMY_SUCCESS_RC) + + def test_dispatch_connack_event_no_recovery(self): + self._configure_mocks_connack_event() + self._start_consumer() + self._verify_connack_event_dispatch() + + def test_dispatch_connack_event_need_resubscribe(self): + resub_records = [ + (DUMMY_TOPIC + "1", (DUMMY_QOS, self.message_callback, self.subscribe_callback)), + (DUMMY_TOPIC + "2", (DUMMY_QOS, self.message_callback, self.subscribe_callback)), + (DUMMY_TOPIC + "3", (DUMMY_QOS, self.message_callback, self.subscribe_callback)) + ] + self._configure_mocks_connack_event(resubscribe_records=resub_records) + self._start_consumer() + self._verify_connack_event_dispatch(resubscribe_records=resub_records) + + def test_dispatch_connack_event_need_draining(self): + self._configure_mocks_connack_event(need_draining=True) + self._start_consumer() + self._verify_connack_event_dispatch(need_draining=True) + + def test_dispatch_connack_event_need_resubscribe_draining(self): + resub_records = [ + (DUMMY_TOPIC + "1", (DUMMY_QOS, self.message_callback, self.subscribe_callback)), + (DUMMY_TOPIC + "2", (DUMMY_QOS, self.message_callback, self.subscribe_callback)), + (DUMMY_TOPIC + "3", (DUMMY_QOS, self.message_callback, self.subscribe_callback)) + ] + self._configure_mocks_connack_event(resubscribe_records=resub_records, need_draining=True) + self._start_consumer() + self._verify_connack_event_dispatch(resubscribe_records=resub_records, need_draining=True) + + def _configure_mocks_connack_event(self, resubscribe_records=list(), need_draining=False): + self.client_status.set_status(ClientStatus.CONNECT) + self._fill_in_fake_events([self._create_connack_event()]) + self.subscription_manager.list_records.return_value 
= resubscribe_records + if need_draining: # We pack publish, subscribe and unsubscribe requests into the offline queue + if resubscribe_records: + has_more_side_effect_list = 4 * [True] + else: + has_more_side_effect_list = 5 * [True] + has_more_side_effect_list += [False] + self.offline_requests_manager.has_more.side_effect = has_more_side_effect_list + self.offline_requests_manager.get_next.side_effect = [ + QueueableRequest(RequestTypes.PUBLISH, (DUMMY_TOPIC, DUMMY_MESSAGE, DUMMY_QOS, False)), + QueueableRequest(RequestTypes.SUBSCRIBE, (DUMMY_TOPIC, DUMMY_QOS, self.message_callback, self.subscribe_callback)), + QueueableRequest(RequestTypes.UNSUBSCRIBE, (DUMMY_TOPIC, self.unsubscribe_callback)) + ] + else: + self.offline_requests_manager.has_more.return_value = False + self.load_mocks_into_test_target() + + def _create_connack_event(self): + return FixedEventMids.CONNACK_MID, EventTypes.CONNACK, DUMMY_SUCCESS_RC + + def _verify_connack_event_dispatch(self, resubscribe_records=list(), need_draining=False): + time.sleep(3 * DEFAULT_DRAINING_INTERNAL_SEC) # Make sure resubscribe/draining finishes + assert self.event_consumer.is_running() is True + self.internal_async_client.invoke_event_callback.assert_called_once_with(FixedEventMids.CONNACK_MID, data=DUMMY_SUCCESS_RC) + if resubscribe_records: + resub_call_sequence = [] + for topic, (qos, message_callback, subscribe_callback) in resubscribe_records: + resub_call_sequence.append(call(topic, qos, subscribe_callback)) + self.internal_async_client.subscribe.assert_has_calls(resub_call_sequence) + if need_draining: + assert self.internal_async_client.publish.call_count == 1 + assert self.internal_async_client.unsubscribe.call_count == 1 + assert self.internal_async_client.subscribe.call_count == len(resubscribe_records) + 1 + assert self.event_consumer.is_fully_stopped() is False + + def test_dispatch_puback_suback_unsuback_events(self): + self._configure_mocks_puback_suback_unsuback_events() + self._start_consumer() + 
self._verify_puback_suback_unsuback_events_dispatch() + + def _configure_mocks_puback_suback_unsuback_events(self): + self.client_status.set_status(ClientStatus.STABLE) + self._fill_in_fake_events([ + self._create_puback_event(DUMMY_PUBACK_MID), + self._create_suback_event(DUMMY_SUBACK_MID), + self._create_unsuback_event(DUMMY_UNSUBACK_MID)]) + self.load_mocks_into_test_target() + + def _verify_puback_suback_unsuback_events_dispatch(self): + assert self.event_consumer.is_running() is True + call_sequence = [ + call(DUMMY_PUBACK_MID, data=None), + call(DUMMY_SUBACK_MID, data=DUMMY_QOS), + call(DUMMY_UNSUBACK_MID, data=None)] + self.internal_async_client.invoke_event_callback.assert_has_calls(call_sequence) + assert self.event_consumer.is_fully_stopped() is False + + def _fill_in_fake_events(self, events): + for event in events: + self.event_queue.put(event) + + def _start_consumer(self): + self.event_consumer.start() + time.sleep(1) # Make sure the event gets picked up by the consumer + + def load_mocks_into_test_target(self): + self.event_consumer = EventConsumer(self.cv, + self.event_queue, + self.internal_async_client, + self.subscription_manager, + self.offline_requests_manager, + self.client_status) + + def _create_puback_event(self, mid): + return mid, EventTypes.PUBACK, None + + def _create_suback_event(self, mid): + return mid, EventTypes.SUBACK, DUMMY_QOS + + def _create_unsuback_event(self, mid): + return mid, EventTypes.UNSUBACK, None diff --git a/test/core/protocol/internal/test_workers_event_producer.py b/test/core/protocol/internal/test_workers_event_producer.py new file mode 100644 index 0000000..fbb97b2 --- /dev/null +++ b/test/core/protocol/internal/test_workers_event_producer.py @@ -0,0 +1,65 @@ +import pytest +from threading import Condition +from AWSIoTPythonSDK.core.protocol.internal.workers import EventProducer +from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids +from AWSIoTPythonSDK.core.protocol.internal.events import 
EventTypes +from AWSIoTPythonSDK.core.protocol.paho.client import MQTTMessage +import sys +if sys.version_info[0] < 3: + from Queue import Queue +else: + from queue import Queue + +DUMMY_PAHO_CLIENT = None +DUMMY_USER_DATA = None +DUMMY_FLAGS = None +DUMMY_GRANTED_QOS = 1 +DUMMY_MID = 89757 +SUCCESS_RC = 0 + +MAX_CV_WAIT_TIME_SEC = 5 + +class TestWorkersEventProducer: + + def setup_method(self, test_method): + self._generate_test_targets() + + def test_produce_on_connect_event(self): + self.event_producer.on_connect(DUMMY_PAHO_CLIENT, DUMMY_USER_DATA, DUMMY_FLAGS, SUCCESS_RC) + self._verify_queued_event(self.event_queue, (FixedEventMids.CONNACK_MID, EventTypes.CONNACK, SUCCESS_RC)) + + def test_produce_on_disconnect_event(self): + self.event_producer.on_disconnect(DUMMY_PAHO_CLIENT, DUMMY_USER_DATA, SUCCESS_RC) + self._verify_queued_event(self.event_queue, (FixedEventMids.DISCONNECT_MID, EventTypes.DISCONNECT, SUCCESS_RC)) + + def test_produce_on_publish_event(self): + self.event_producer.on_publish(DUMMY_PAHO_CLIENT, DUMMY_USER_DATA, DUMMY_MID) + self._verify_queued_event(self.event_queue, (DUMMY_MID, EventTypes.PUBACK, None)) + + def test_produce_on_subscribe_event(self): + self.event_producer.on_subscribe(DUMMY_PAHO_CLIENT, DUMMY_USER_DATA, DUMMY_MID, DUMMY_GRANTED_QOS) + self._verify_queued_event(self.event_queue, (DUMMY_MID, EventTypes.SUBACK, DUMMY_GRANTED_QOS)) + + def test_produce_on_unsubscribe_event(self): + self.event_producer.on_unsubscribe(DUMMY_PAHO_CLIENT, DUMMY_USER_DATA, DUMMY_MID) + self._verify_queued_event(self.event_queue, (DUMMY_MID, EventTypes.UNSUBACK, None)) + + def test_produce_on_message_event(self): + dummy_message = MQTTMessage() + dummy_message.topic = "test/topic" + dummy_message.qos = 1 + dummy_message.payload = "test_payload" + self.event_producer.on_message(DUMMY_PAHO_CLIENT, DUMMY_USER_DATA, dummy_message) + self._verify_queued_event(self.event_queue, (FixedEventMids.MESSAGE_MID, EventTypes.MESSAGE, dummy_message)) + + def 
_generate_test_targets(self): + self.cv = Condition() + self.event_queue = Queue() + self.event_producer = EventProducer(self.cv, self.event_queue) + + def _verify_queued_event(self, queue, expected_results): + expected_mid, expected_event_type, expected_data = expected_results + actual_mid, actual_event_type, actual_data = queue.get() + assert actual_mid == expected_mid + assert actual_event_type == expected_event_type + assert actual_data == expected_data diff --git a/test/core/protocol/internal/test_workers_offline_requests_manager.py b/test/core/protocol/internal/test_workers_offline_requests_manager.py new file mode 100644 index 0000000..8193718 --- /dev/null +++ b/test/core/protocol/internal/test_workers_offline_requests_manager.py @@ -0,0 +1,69 @@ +import pytest +from AWSIoTPythonSDK.core.protocol.internal.workers import OfflineRequestsManager +from AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes +from AWSIoTPythonSDK.core.protocol.internal.queues import AppendResults + +DEFAULT_QUEUE_SIZE = 3 +FAKE_REQUEST_PREFIX = "Fake Request " + +def test_has_more(): + offline_requests_manager = OfflineRequestsManager(DEFAULT_QUEUE_SIZE, DropBehaviorTypes.DROP_NEWEST) + + assert not offline_requests_manager.has_more() + + offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + "0") + assert offline_requests_manager.has_more() + + +def test_add_more_normal(): + offline_requests_manager = OfflineRequestsManager(DEFAULT_QUEUE_SIZE, DropBehaviorTypes.DROP_NEWEST) + append_result = offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + "0") + + assert append_result == AppendResults.APPEND_SUCCESS + + +def test_add_more_full_drop_newest(): + offline_requests_manager = OfflineRequestsManager(DEFAULT_QUEUE_SIZE, DropBehaviorTypes.DROP_NEWEST) + _overflow_the_queue(offline_requests_manager) + append_result = offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + "A") + + assert append_result == AppendResults.APPEND_FAILURE_QUEUE_FULL + + next_request = 
offline_requests_manager.get_next() + assert next_request == FAKE_REQUEST_PREFIX + "0" + + +def test_add_more_full_drop_oldest(): + offline_requests_manager = OfflineRequestsManager(DEFAULT_QUEUE_SIZE, DropBehaviorTypes.DROP_OLDEST) + _overflow_the_queue(offline_requests_manager) + append_result = offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + "A") + + assert append_result == AppendResults.APPEND_FAILURE_QUEUE_FULL + + next_request = offline_requests_manager.get_next() + assert next_request == FAKE_REQUEST_PREFIX + "1" + + +def test_add_more_disabled(): + offline_requests_manager = OfflineRequestsManager(0, DropBehaviorTypes.DROP_NEWEST) + append_result = offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + "0") + + assert append_result == AppendResults.APPEND_FAILURE_QUEUE_DISABLED + + +def _overflow_the_queue(offline_requests_manager): + for i in range(0, DEFAULT_QUEUE_SIZE): + offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + str(i)) + + +def test_get_next_normal(): + offline_requests_manager = OfflineRequestsManager(DEFAULT_QUEUE_SIZE, DropBehaviorTypes.DROP_NEWEST) + append_result = offline_requests_manager.add_one(FAKE_REQUEST_PREFIX + "0") + + assert append_result == AppendResults.APPEND_SUCCESS + assert offline_requests_manager.get_next() is not None + + +def test_get_next_empty(): + offline_requests_manager = OfflineRequestsManager(DEFAULT_QUEUE_SIZE, DropBehaviorTypes.DROP_NEWEST) + assert offline_requests_manager.get_next() is None diff --git a/test/core/protocol/internal/test_workers_subscription_manager.py b/test/core/protocol/internal/test_workers_subscription_manager.py new file mode 100644 index 0000000..6e436b7 --- /dev/null +++ b/test/core/protocol/internal/test_workers_subscription_manager.py @@ -0,0 +1,41 @@ +import pytest +from AWSIoTPythonSDK.core.protocol.internal.workers import SubscriptionManager + +DUMMY_TOPIC1 = "topic1" +DUMMY_TOPIC2 = "topic2" + + +def _dummy_callback(client, user_data, message): + pass + + +def 
test_add_record(): + subscription_manager = SubscriptionManager() + subscription_manager.add_record(DUMMY_TOPIC1, 1, _dummy_callback, _dummy_callback) + + record_list = subscription_manager.list_records() + + assert len(record_list) == 1 + + topic, (qos, message_callback, ack_callback) = record_list[0] + assert topic == DUMMY_TOPIC1 + assert qos == 1 + assert message_callback == _dummy_callback + assert ack_callback == _dummy_callback + + +def test_remove_record(): + subscription_manager = SubscriptionManager() + subscription_manager.add_record(DUMMY_TOPIC1, 1, _dummy_callback, _dummy_callback) + subscription_manager.add_record(DUMMY_TOPIC2, 0, _dummy_callback, _dummy_callback) + subscription_manager.remove_record(DUMMY_TOPIC1) + + record_list = subscription_manager.list_records() + + assert len(record_list) == 1 + + topic, (qos, message_callback, ack_callback) = record_list[0] + assert topic == DUMMY_TOPIC2 + assert qos == 0 + assert message_callback == _dummy_callback + assert ack_callback == _dummy_callback diff --git a/test/core/protocol/test_mqtt_core.py b/test/core/protocol/test_mqtt_core.py new file mode 100644 index 0000000..8469ea6 --- /dev/null +++ b/test/core/protocol/test_mqtt_core.py @@ -0,0 +1,585 @@ +import AWSIoTPythonSDK +from AWSIoTPythonSDK.core.protocol.mqtt_core import MqttCore +from AWSIoTPythonSDK.core.protocol.internal.clients import InternalAsyncMqttClient +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatusContainer +from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus +from AWSIoTPythonSDK.core.protocol.internal.workers import EventProducer +from AWSIoTPythonSDK.core.protocol.internal.workers import EventConsumer +from AWSIoTPythonSDK.core.protocol.internal.workers import SubscriptionManager +from AWSIoTPythonSDK.core.protocol.internal.workers import OfflineRequestsManager +from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids +from AWSIoTPythonSDK.core.protocol.internal.queues 
import AppendResults +from AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes +from AWSIoTPythonSDK.core.protocol.internal.defaults import METRICS_PREFIX +from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectError +from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectTimeoutException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectError +from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectTimeoutException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError +from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishTimeoutException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueFullException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueDisabledException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError +from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueFullException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueDisabledException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeError +from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeTimeoutException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueFullException +from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueDisabledException +from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS +from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_ERRNO +from AWSIoTPythonSDK.core.protocol.paho.client import MQTTv311 +from AWSIoTPythonSDK.core.protocol.internal.defaults import ALPN_PROTCOLS +try: + from mock import patch + from mock import MagicMock + from mock import NonCallableMagicMock + from mock import call +except: + from unittest.mock import patch + from unittest.mock import MagicMock + from unittest.mock import NonCallableMagicMock + from 
unittest.mock import call +from threading import Event +import pytest + + +PATCH_MODULE_LOCATION = "AWSIoTPythonSDK.core.protocol.mqtt_core." +DUMMY_SUCCESS_RC = MQTT_ERR_SUCCESS +DUMMY_FAILURE_RC = MQTT_ERR_ERRNO +DUMMY_REQUEST_MID = 89757 +DUMMY_CLIENT_ID = "CoolClientId" +DUMMY_KEEP_ALIVE_SEC = 60 +DUMMY_TOPIC = "topic/cool" +DUMMY_PAYLOAD = "CoolPayload" +DUMMY_QOS = 1 +DUMMY_USERNAME = "DummyUsername" +DUMMY_PASSWORD = "DummyPassword" + +KEY_EXPECTED_REQUEST_RC = "ExpectedRequestRc" +KEY_EXPECTED_QUEUE_APPEND_RESULT = "ExpectedQueueAppendResult" +KEY_EXPECTED_REQUEST_MID_OVERRIDE = "ExpectedRequestMidOverride" +KEY_EXPECTED_REQUEST_TIMEOUT = "ExpectedRequestTimeout" +SUCCESS_RC_EXPECTED_VALUES = { + KEY_EXPECTED_REQUEST_RC : DUMMY_SUCCESS_RC +} +FAILURE_RC_EXPECTED_VALUES = { + KEY_EXPECTED_REQUEST_RC : DUMMY_FAILURE_RC +} +TIMEOUT_EXPECTED_VALUES = { + KEY_EXPECTED_REQUEST_TIMEOUT : True +} +NO_TIMEOUT_EXPECTED_VALUES = { + KEY_EXPECTED_REQUEST_TIMEOUT : False +} +QUEUED_EXPECTED_VALUES = { + KEY_EXPECTED_QUEUE_APPEND_RESULT : AppendResults.APPEND_SUCCESS +} +QUEUE_FULL_EXPECTED_VALUES = { + KEY_EXPECTED_QUEUE_APPEND_RESULT : AppendResults.APPEND_FAILURE_QUEUE_FULL +} +QUEUE_DISABLED_EXPECTED_VALUES = { + KEY_EXPECTED_QUEUE_APPEND_RESULT : AppendResults.APPEND_FAILURE_QUEUE_DISABLED +} + +class TestMqttCore: + + def setup_class(cls): + cls.configure_internal_async_client = { + RequestTypes.CONNECT : cls._configure_internal_async_client_connect, + RequestTypes.DISCONNECT : cls._configure_internal_async_client_disconnect, + RequestTypes.PUBLISH : cls._configure_internal_async_client_publish, + RequestTypes.SUBSCRIBE : cls._configure_internal_async_client_subscribe, + RequestTypes.UNSUBSCRIBE : cls._configure_internal_async_client_unsubscribe + } + cls.invoke_mqtt_core_async_api = { + RequestTypes.CONNECT : cls._invoke_mqtt_core_connect_async, + RequestTypes.DISCONNECT : cls._invoke_mqtt_core_disconnect_async, + RequestTypes.PUBLISH : 
cls._invoke_mqtt_core_publish_async, + RequestTypes.SUBSCRIBE : cls._invoke_mqtt_core_subscribe_async, + RequestTypes.UNSUBSCRIBE : cls._invoke_mqtt_core_unsubscribe_async + } + cls.invoke_mqtt_core_sync_api = { + RequestTypes.CONNECT : cls._invoke_mqtt_core_connect, + RequestTypes.DISCONNECT : cls._invoke_mqtt_core_disconnect, + RequestTypes.PUBLISH : cls._invoke_mqtt_core_publish, + RequestTypes.SUBSCRIBE : cls._invoke_mqtt_core_subscribe, + RequestTypes.UNSUBSCRIBE : cls._invoke_mqtt_core_unsubscribe + } + cls.verify_mqtt_core_async_api = { + RequestTypes.CONNECT : cls._verify_mqtt_core_connect_async, + RequestTypes.DISCONNECT : cls._verify_mqtt_core_disconnect_async, + RequestTypes.PUBLISH : cls._verify_mqtt_core_publish_async, + RequestTypes.SUBSCRIBE : cls._verify_mqtt_core_subscribe_async, + RequestTypes.UNSUBSCRIBE : cls._verify_mqtt_core_unsubscribe_async + } + cls.request_error = { + RequestTypes.CONNECT : connectError, + RequestTypes.DISCONNECT : disconnectError, + RequestTypes.PUBLISH : publishError, + RequestTypes.SUBSCRIBE: subscribeError, + RequestTypes.UNSUBSCRIBE: unsubscribeError + } + cls.request_queue_full = { + RequestTypes.PUBLISH : publishQueueFullException, + RequestTypes.SUBSCRIBE: subscribeQueueFullException, + RequestTypes.UNSUBSCRIBE: unsubscribeQueueFullException + } + cls.request_queue_disable = { + RequestTypes.PUBLISH : publishQueueDisabledException, + RequestTypes.SUBSCRIBE : subscribeQueueDisabledException, + RequestTypes.UNSUBSCRIBE : unsubscribeQueueDisabledException + } + cls.request_timeout = { + RequestTypes.CONNECT : connectTimeoutException, + RequestTypes.DISCONNECT : disconnectTimeoutException, + RequestTypes.PUBLISH : publishTimeoutException, + RequestTypes.SUBSCRIBE : subscribeTimeoutException, + RequestTypes.UNSUBSCRIBE : unsubscribeTimeoutException + } + + def _configure_internal_async_client_connect(self, expected_rc, expected_mid=None): + self.internal_async_client_mock.connect.return_value = expected_rc + + def 
_configure_internal_async_client_disconnect(self, expected_rc, expected_mid=None):
+        self.internal_async_client_mock.disconnect.return_value = expected_rc
+
+    def _configure_internal_async_client_publish(self, expected_rc, expected_mid):
+        self.internal_async_client_mock.publish.return_value = expected_rc, expected_mid
+
+    def _configure_internal_async_client_subscribe(self, expected_rc, expected_mid):
+        self.internal_async_client_mock.subscribe.return_value = expected_rc, expected_mid
+
+    def _configure_internal_async_client_unsubscribe(self, expected_rc, expected_mid):
+        self.internal_async_client_mock.unsubscribe.return_value = expected_rc, expected_mid
+
+    def _invoke_mqtt_core_connect_async(self, ack_callback, message_callback):
+        return self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC, ack_callback)
+
+    def _invoke_mqtt_core_disconnect_async(self, ack_callback, message_callback):
+        return self.mqtt_core.disconnect_async(ack_callback)
+
+    def _invoke_mqtt_core_publish_async(self, ack_callback, message_callback):
+        return self.mqtt_core.publish_async(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS, False, ack_callback)
+
+    def _invoke_mqtt_core_subscribe_async(self, ack_callback, message_callback):
+        return self.mqtt_core.subscribe_async(DUMMY_TOPIC, DUMMY_QOS, ack_callback, message_callback)
+
+    def _invoke_mqtt_core_unsubscribe_async(self, ack_callback, message_callback):
+        return self.mqtt_core.unsubscribe_async(DUMMY_TOPIC, ack_callback)
+
+    def _invoke_mqtt_core_connect(self, message_callback):
+        return self.mqtt_core.connect(DUMMY_KEEP_ALIVE_SEC)
+
+    def _invoke_mqtt_core_disconnect(self, message_callback):
+        return self.mqtt_core.disconnect()
+
+    def _invoke_mqtt_core_publish(self, message_callback):
+        return self.mqtt_core.publish(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS)
+
+    def _invoke_mqtt_core_subscribe(self, message_callback):
+        return self.mqtt_core.subscribe(DUMMY_TOPIC, DUMMY_QOS, message_callback)
+
+    def _invoke_mqtt_core_unsubscribe(self, 
message_callback): + return self.mqtt_core.unsubscribe(DUMMY_TOPIC) + + def _verify_mqtt_core_connect_async(self, ack_callback, message_callback): + self.internal_async_client_mock.connect.assert_called_once_with(DUMMY_KEEP_ALIVE_SEC, ack_callback) + self.client_status_mock.set_status.assert_called_once_with(ClientStatus.CONNECT) + + def _verify_mqtt_core_disconnect_async(self, ack_callback, message_callback): + self.internal_async_client_mock.disconnect.assert_called_once_with(ack_callback) + self.client_status_mock.set_status.assert_called_once_with(ClientStatus.USER_DISCONNECT) + + def _verify_mqtt_core_publish_async(self, ack_callback, message_callback): + self.internal_async_client_mock.publish.assert_called_once_with(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS, + False, ack_callback) + + def _verify_mqtt_core_subscribe_async(self, ack_callback, message_callback): + self.internal_async_client_mock.subscribe.assert_called_once_with(DUMMY_TOPIC, DUMMY_QOS, ack_callback) + self.subscription_manager_mock.add_record.assert_called_once_with(DUMMY_TOPIC, DUMMY_QOS, message_callback, ack_callback) + + def _verify_mqtt_core_unsubscribe_async(self, ack_callback, message_callback): + self.internal_async_client_mock.unsubscribe.assert_called_once_with(DUMMY_TOPIC, ack_callback) + self.subscription_manager_mock.remove_record.assert_called_once_with(DUMMY_TOPIC) + + def setup_method(self, test_method): + self._use_mock_internal_async_client() + self._use_mock_event_producer() + self._use_mock_event_consumer() + self._use_mock_subscription_manager() + self._use_mock_offline_requests_manager() + self._use_mock_client_status() + self.mqtt_core = MqttCore(DUMMY_CLIENT_ID, True, MQTTv311, False) # We choose x.509 auth type for this test + + def _use_mock_internal_async_client(self): + self.internal_async_client_patcher = patch(PATCH_MODULE_LOCATION + "InternalAsyncMqttClient", + spec=InternalAsyncMqttClient) + self.mock_internal_async_client_constructor = 
self.internal_async_client_patcher.start() + self.internal_async_client_mock = MagicMock() + self.mock_internal_async_client_constructor.return_value = self.internal_async_client_mock + + def _use_mock_event_producer(self): + self.event_producer_patcher = patch(PATCH_MODULE_LOCATION + "EventProducer", spec=EventProducer) + self.mock_event_producer_constructor = self.event_producer_patcher.start() + self.event_producer_mock = MagicMock() + self.mock_event_producer_constructor.return_value = self.event_producer_mock + + def _use_mock_event_consumer(self): + self.event_consumer_patcher = patch(PATCH_MODULE_LOCATION + "EventConsumer", spec=EventConsumer) + self.mock_event_consumer_constructor = self.event_consumer_patcher.start() + self.event_consumer_mock = MagicMock() + self.mock_event_consumer_constructor.return_value = self.event_consumer_mock + + def _use_mock_subscription_manager(self): + self.subscription_manager_patcher = patch(PATCH_MODULE_LOCATION + "SubscriptionManager", + spec=SubscriptionManager) + self.mock_subscription_manager_constructor = self.subscription_manager_patcher.start() + self.subscription_manager_mock = MagicMock() + self.mock_subscription_manager_constructor.return_value = self.subscription_manager_mock + + def _use_mock_offline_requests_manager(self): + self.offline_requests_manager_patcher = patch(PATCH_MODULE_LOCATION + "OfflineRequestsManager", + spec=OfflineRequestsManager) + self.mock_offline_requests_manager_constructor = self.offline_requests_manager_patcher.start() + self.offline_requests_manager_mock = MagicMock() + self.mock_offline_requests_manager_constructor.return_value = self.offline_requests_manager_mock + + def _use_mock_client_status(self): + self.client_status_patcher = patch(PATCH_MODULE_LOCATION + "ClientStatusContainer", spec=ClientStatusContainer) + self.mock_client_status_constructor = self.client_status_patcher.start() + self.client_status_mock = MagicMock() + self.mock_client_status_constructor.return_value = 
self.client_status_mock + + def teardown_method(self, test_method): + self.internal_async_client_patcher.stop() + self.event_producer_patcher.stop() + self.event_consumer_patcher.stop() + self.subscription_manager_patcher.stop() + self.offline_requests_manager_patcher.stop() + self.client_status_patcher.stop() + + # Finally... Tests start + def test_use_wss(self): + self.mqtt_core = MqttCore(DUMMY_CLIENT_ID, True, MQTTv311, True) # use wss + assert self.mqtt_core.use_wss() is True + + def test_configure_alpn_protocols(self): + self.mqtt_core.configure_alpn_protocols() + self.internal_async_client_mock.configure_alpn_protocols.assert_called_once_with([ALPN_PROTCOLS]) + + def test_enable_metrics_collection_with_username_in_connect(self): + self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC) + self._use_mock_python_event() + self.python_event_mock.wait.return_value = True + self.mqtt_core.configure_username_password(DUMMY_USERNAME, DUMMY_PASSWORD) + self.mqtt_core.connect(DUMMY_KEEP_ALIVE_SEC) + self.internal_async_client_mock.set_username_password.assert_called_once_with(DUMMY_USERNAME + + METRICS_PREFIX + + AWSIoTPythonSDK.__version__, + DUMMY_PASSWORD) + self.python_event_patcher.stop() + + def test_enable_metrics_collection_with_username_in_connect_async(self): + self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC) + self.mqtt_core.configure_username_password(DUMMY_USERNAME, DUMMY_PASSWORD) + self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC) + self.internal_async_client_mock.set_username_password.assert_called_once_with(DUMMY_USERNAME + + METRICS_PREFIX + + AWSIoTPythonSDK.__version__, + DUMMY_PASSWORD) + + def test_enable_metrics_collection_without_username_in_connect(self): + self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC) + self._use_mock_python_event() + self.python_event_mock.wait.return_value = True + self.mqtt_core.connect(DUMMY_KEEP_ALIVE_SEC) + 
self.internal_async_client_mock.set_username_password.assert_called_once_with(METRICS_PREFIX +
+                                                                                      AWSIoTPythonSDK.__version__,
+                                                                                      None)
+        self.python_event_patcher.stop()
+
+    def test_enable_metrics_collection_without_username_in_connect_async(self):
+        self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC)
+        self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC)
+        self.internal_async_client_mock.set_username_password.assert_called_once_with(METRICS_PREFIX +
+                                                                                      AWSIoTPythonSDK.__version__,
+                                                                                      None)
+
+    def test_disable_metrics_collection_with_username_in_connect(self):
+        self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC)
+        self._use_mock_python_event()
+        self.python_event_mock.wait.return_value = True
+        self.mqtt_core.disable_metrics_collection()
+        self.mqtt_core.configure_username_password(DUMMY_USERNAME, DUMMY_PASSWORD)
+        self.mqtt_core.connect(DUMMY_KEEP_ALIVE_SEC)
+        self.internal_async_client_mock.set_username_password.assert_called_once_with(DUMMY_USERNAME, DUMMY_PASSWORD)
+        self.python_event_patcher.stop()
+
+    def test_disable_metrics_collection_with_username_in_connect_async(self):
+        self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC)
+        self.mqtt_core.disable_metrics_collection()
+        self.mqtt_core.configure_username_password(DUMMY_USERNAME, DUMMY_PASSWORD)
+        self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC)
+        self.internal_async_client_mock.set_username_password.assert_called_once_with(DUMMY_USERNAME, DUMMY_PASSWORD)
+
+    def test_disable_metrics_collection_without_username_in_connect(self):
+        self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC)
+        self._use_mock_python_event()
+        self.python_event_mock.wait.return_value = True
+        self.mqtt_core.disable_metrics_collection()
+        self.mqtt_core.connect(DUMMY_KEEP_ALIVE_SEC)
+        self.internal_async_client_mock.set_username_password.assert_called_once_with("", None)
+        self.python_event_patcher.stop()
+
+    def test_disable_metrics_collection_without_username_in_connect_async(self):
+        
self._configure_internal_async_client_connect(DUMMY_SUCCESS_RC) + self.mqtt_core.disable_metrics_collection() + self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC) + self.internal_async_client_mock.set_username_password.assert_called_once_with("", None) + + def test_connect_async_success_rc(self): + expected_values = { + KEY_EXPECTED_REQUEST_RC : DUMMY_SUCCESS_RC, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.CONNACK_MID + } + self._internal_test_async_api_with(RequestTypes.CONNECT, expected_values) + + def test_connect_async_failure_rc(self): + expected_values = { + KEY_EXPECTED_REQUEST_RC : DUMMY_FAILURE_RC, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.CONNACK_MID + } + self._internal_test_async_api_with(RequestTypes.CONNECT, expected_values) + + def test_connect_async_when_failure_rc_should_stop_event_consumer(self): + self.internal_async_client_mock.connect.return_value = DUMMY_FAILURE_RC + + with pytest.raises(connectError): + self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC) + + self.event_consumer_mock.start.assert_called_once() + self.event_consumer_mock.stop.assert_called_once() + self.event_consumer_mock.wait_until_it_stops.assert_called_once() + assert self.client_status_mock.set_status.call_count == 2 + assert self.client_status_mock.set_status.call_args_list == [call(ClientStatus.CONNECT), call(ClientStatus.IDLE)] + + def test_connect_async_when_exception_should_stop_event_consumer(self): + self.internal_async_client_mock.connect.side_effect = Exception("Something weird happened") + + with pytest.raises(Exception): + self.mqtt_core.connect_async(DUMMY_KEEP_ALIVE_SEC) + + self.event_consumer_mock.start.assert_called_once() + self.event_consumer_mock.stop.assert_called_once() + self.event_consumer_mock.wait_until_it_stops.assert_called_once() + assert self.client_status_mock.set_status.call_count == 2 + assert self.client_status_mock.set_status.call_args_list == [call(ClientStatus.CONNECT), call(ClientStatus.IDLE)] + + def 
test_disconnect_async_success_rc(self): + expected_values = { + KEY_EXPECTED_REQUEST_RC : DUMMY_SUCCESS_RC, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.DISCONNECT_MID + } + self._internal_test_async_api_with(RequestTypes.DISCONNECT, expected_values) + + def test_disconnect_async_failure_rc(self): + expected_values = { + KEY_EXPECTED_REQUEST_RC : DUMMY_FAILURE_RC, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.DISCONNECT_MID + } + self._internal_test_async_api_with(RequestTypes.DISCONNECT, expected_values) + + def test_publish_async_success_rc(self): + self._internal_test_async_api_with(RequestTypes.PUBLISH, SUCCESS_RC_EXPECTED_VALUES) + + def test_publish_async_failure_rc(self): + self._internal_test_async_api_with(RequestTypes.PUBLISH, FAILURE_RC_EXPECTED_VALUES) + + def test_publish_async_queued(self): + self._internal_test_async_api_with(RequestTypes.PUBLISH, QUEUED_EXPECTED_VALUES) + + def test_publish_async_queue_disabled(self): + self._internal_test_async_api_with(RequestTypes.PUBLISH, QUEUE_DISABLED_EXPECTED_VALUES) + + def test_publish_async_queue_full(self): + self._internal_test_async_api_with(RequestTypes.PUBLISH, QUEUE_FULL_EXPECTED_VALUES) + + def test_subscribe_async_success_rc(self): + self._internal_test_async_api_with(RequestTypes.SUBSCRIBE, SUCCESS_RC_EXPECTED_VALUES) + + def test_subscribe_async_failure_rc(self): + self._internal_test_async_api_with(RequestTypes.SUBSCRIBE, FAILURE_RC_EXPECTED_VALUES) + + def test_subscribe_async_queued(self): + self._internal_test_async_api_with(RequestTypes.SUBSCRIBE, QUEUED_EXPECTED_VALUES) + + def test_subscribe_async_queue_full(self): + self._internal_test_async_api_with(RequestTypes.SUBSCRIBE, QUEUE_FULL_EXPECTED_VALUES) + + def test_subscribe_async_queue_disabled(self): + self._internal_test_async_api_with(RequestTypes.SUBSCRIBE, QUEUE_DISABLED_EXPECTED_VALUES) + + def test_unsubscribe_async_success_rc(self): + self._internal_test_async_api_with(RequestTypes.UNSUBSCRIBE, 
SUCCESS_RC_EXPECTED_VALUES) + + def test_unsubscribe_async_failure_rc(self): + self._internal_test_async_api_with(RequestTypes.UNSUBSCRIBE, FAILURE_RC_EXPECTED_VALUES) + + def test_unsubscribe_async_queued(self): + self._internal_test_async_api_with(RequestTypes.UNSUBSCRIBE, QUEUED_EXPECTED_VALUES) + + def test_unsubscribe_async_queue_full(self): + self._internal_test_async_api_with(RequestTypes.UNSUBSCRIBE, QUEUE_FULL_EXPECTED_VALUES) + + def test_unsubscribe_async_queue_disabled(self): + self._internal_test_async_api_with(RequestTypes.UNSUBSCRIBE, QUEUE_DISABLED_EXPECTED_VALUES) + + def _internal_test_async_api_with(self, request_type, expected_values): + expected_rc = expected_values.get(KEY_EXPECTED_REQUEST_RC) + expected_append_result = expected_values.get(KEY_EXPECTED_QUEUE_APPEND_RESULT) + expected_request_mid_override = expected_values.get(KEY_EXPECTED_REQUEST_MID_OVERRIDE) + ack_callback = NonCallableMagicMock() + message_callback = NonCallableMagicMock() + + if expected_rc is not None: + self.configure_internal_async_client[request_type](self, expected_rc, DUMMY_REQUEST_MID) + self.client_status_mock.get_status.return_value = ClientStatus.STABLE + if expected_rc == DUMMY_SUCCESS_RC: + mid = self.invoke_mqtt_core_async_api[request_type](self, ack_callback, message_callback) + self.verify_mqtt_core_async_api[request_type](self, ack_callback, message_callback) + if expected_request_mid_override is not None: + assert mid == expected_request_mid_override + else: + assert mid == DUMMY_REQUEST_MID + else: # FAILURE_RC + with pytest.raises(self.request_error[request_type]): + self.invoke_mqtt_core_async_api[request_type](self, ack_callback, message_callback) + + if expected_append_result is not None: + self.client_status_mock.get_status.return_value = ClientStatus.ABNORMAL_DISCONNECT + self.offline_requests_manager_mock.add_one.return_value = expected_append_result + if expected_append_result == AppendResults.APPEND_SUCCESS: + mid = 
self.invoke_mqtt_core_async_api[request_type](self, ack_callback, message_callback) + assert mid == FixedEventMids.QUEUED_MID + elif expected_append_result == AppendResults.APPEND_FAILURE_QUEUE_FULL: + with pytest.raises(self.request_queue_full[request_type]): + self.invoke_mqtt_core_async_api[request_type](self, ack_callback, message_callback) + else: # AppendResults.APPEND_FAILURE_QUEUE_DISABLED + with pytest.raises(self.request_queue_disable[request_type]): + self.invoke_mqtt_core_async_api[request_type](self, ack_callback, message_callback) + + def test_connect_success(self): + expected_values = { + KEY_EXPECTED_REQUEST_TIMEOUT : False, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.CONNACK_MID + } + self._internal_test_sync_api_with(RequestTypes.CONNECT, expected_values) + + def test_connect_timeout(self): + expected_values = { + KEY_EXPECTED_REQUEST_TIMEOUT : True, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.CONNACK_MID + } + self._internal_test_sync_api_with(RequestTypes.CONNECT, expected_values) + + def test_disconnect_success(self): + expected_values = { + KEY_EXPECTED_REQUEST_TIMEOUT : False, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.DISCONNECT_MID + } + self._internal_test_sync_api_with(RequestTypes.DISCONNECT, expected_values) + + def test_disconnect_timeout(self): + expected_values = { + KEY_EXPECTED_REQUEST_TIMEOUT : True, + KEY_EXPECTED_REQUEST_MID_OVERRIDE : FixedEventMids.DISCONNECT_MID + } + self._internal_test_sync_api_with(RequestTypes.DISCONNECT, expected_values) + + def test_publish_success(self): + self._internal_test_sync_api_with(RequestTypes.PUBLISH, NO_TIMEOUT_EXPECTED_VALUES) + + def test_publish_timeout(self): + self._internal_test_sync_api_with(RequestTypes.PUBLISH, TIMEOUT_EXPECTED_VALUES) + + def test_publish_queued(self): + self._internal_test_sync_api_with(RequestTypes.PUBLISH, QUEUED_EXPECTED_VALUES) + + def test_publish_queue_full(self): + self._internal_test_sync_api_with(RequestTypes.PUBLISH, 
QUEUE_FULL_EXPECTED_VALUES) + + def test_publish_queue_disabled(self): + self._internal_test_sync_api_with(RequestTypes.PUBLISH, QUEUE_DISABLED_EXPECTED_VALUES) + + def test_subscribe_success(self): + self._internal_test_sync_api_with(RequestTypes.SUBSCRIBE, NO_TIMEOUT_EXPECTED_VALUES) + + def test_subscribe_timeout(self): + self._internal_test_sync_api_with(RequestTypes.SUBSCRIBE, TIMEOUT_EXPECTED_VALUES) + + def test_subscribe_queued(self): + self._internal_test_sync_api_with(RequestTypes.SUBSCRIBE, QUEUED_EXPECTED_VALUES) + + def test_subscribe_queue_full(self): + self._internal_test_sync_api_with(RequestTypes.SUBSCRIBE, QUEUE_FULL_EXPECTED_VALUES) + + def test_subscribe_queue_disabled(self): + self._internal_test_sync_api_with(RequestTypes.SUBSCRIBE, QUEUE_DISABLED_EXPECTED_VALUES) + + def test_unsubscribe_success(self): + self._internal_test_sync_api_with(RequestTypes.UNSUBSCRIBE, NO_TIMEOUT_EXPECTED_VALUES) + + def test_unsubscribe_timeout(self): + self._internal_test_sync_api_with(RequestTypes.UNSUBSCRIBE, TIMEOUT_EXPECTED_VALUES) + + def test_unsubscribe_queued(self): + self._internal_test_sync_api_with(RequestTypes.UNSUBSCRIBE, QUEUED_EXPECTED_VALUES) + + def test_unsubscribe_queue_full(self): + self._internal_test_sync_api_with(RequestTypes.UNSUBSCRIBE, QUEUE_FULL_EXPECTED_VALUES) + + def test_unsubscribe_queue_disabled(self): + self._internal_test_sync_api_with(RequestTypes.UNSUBSCRIBE, QUEUE_DISABLED_EXPECTED_VALUES) + + def _internal_test_sync_api_with(self, request_type, expected_values): + expected_request_mid = expected_values.get(KEY_EXPECTED_REQUEST_MID_OVERRIDE) + expected_timeout = expected_values.get(KEY_EXPECTED_REQUEST_TIMEOUT) + expected_append_result = expected_values.get(KEY_EXPECTED_QUEUE_APPEND_RESULT) + + if expected_request_mid is None: + expected_request_mid = DUMMY_REQUEST_MID + message_callback = NonCallableMagicMock() + self.configure_internal_async_client[request_type](self, DUMMY_SUCCESS_RC, expected_request_mid) + 
self._use_mock_python_event() + + if expected_timeout is not None: + self.client_status_mock.get_status.return_value = ClientStatus.STABLE + if expected_timeout: + self.python_event_mock.wait.return_value = False + with pytest.raises(self.request_timeout[request_type]): + self.invoke_mqtt_core_sync_api[request_type](self, message_callback) + else: + self.python_event_mock.wait.return_value = True + assert self.invoke_mqtt_core_sync_api[request_type](self, message_callback) is True + + if expected_append_result is not None: + self.client_status_mock.get_status.return_value = ClientStatus.ABNORMAL_DISCONNECT + self.offline_requests_manager_mock.add_one.return_value = expected_append_result + if expected_append_result == AppendResults.APPEND_SUCCESS: + assert self.invoke_mqtt_core_sync_api[request_type](self, message_callback) is False + elif expected_append_result == AppendResults.APPEND_FAILURE_QUEUE_FULL: + with pytest.raises(self.request_queue_full[request_type]): + self.invoke_mqtt_core_sync_api[request_type](self, message_callback) + else: + with pytest.raises(self.request_queue_disable[request_type]): + self.invoke_mqtt_core_sync_api[request_type](self, message_callback) + + self.python_event_patcher.stop() + + def _use_mock_python_event(self): + self.python_event_patcher = patch(PATCH_MODULE_LOCATION + "Event", spec=Event) + self.python_event_constructor = self.python_event_patcher.start() + self.python_event_mock = MagicMock() + self.python_event_constructor.return_value = self.python_event_mock diff --git a/test/core/shadow/__init__.py b/test/core/shadow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/shadow/test_device_shadow.py b/test/core/shadow/test_device_shadow.py new file mode 100755 index 0000000..3b4ec61 --- /dev/null +++ b/test/core/shadow/test_device_shadow.py @@ -0,0 +1,297 @@ +# Test shadow behavior for a single device shadow + +from AWSIoTPythonSDK.core.shadow.deviceShadow import deviceShadow +from 
AWSIoTPythonSDK.core.shadow.shadowManager import shadowManager +from AWSIoTPythonSDK.core.protocol.paho.client import MQTTMessage +import time +import json +try: + from mock import MagicMock +except: + from unittest.mock import MagicMock + + +DUMMY_THING_NAME = "CoolThing" +DUMMY_SHADOW_OP_TIME_OUT_SEC = 3 + +SHADOW_OP_TYPE_GET = "get" +SHADOW_OP_TYPE_DELETE = "delete" +SHADOW_OP_TYPE_UPDATE = "update" +SHADOW_OP_RESPONSE_STATUS_ACCEPTED = "accepted" +SHADOW_OP_RESPONSE_STATUS_REJECTED = "rejected" +SHADOW_OP_RESPONSE_STATUS_TIMEOUT = "timeout" +SHADOW_OP_RESPONSE_STATUS_DELTA = "delta" + +SHADOW_TOPIC_PREFIX = "$aws/things/" +SHADOW_TOPIC_GET_ACCEPTED = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/get/accepted" +SHADOW_TOPIC_GET_REJECTED = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/get/rejected" +SHADOW_TOPIC_DELETE_ACCEPTED = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/delete/accepted" +SHADOW_TOPIC_DELETE_REJECTED = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/delete/rejected" +SHADOW_TOPIC_UPDATE_ACCEPTED = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/update/accepted" +SHADOW_TOPIC_UPDATE_REJECTED = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/update/rejected" +SHADOW_TOPIC_UPDATE_DELTA = SHADOW_TOPIC_PREFIX + DUMMY_THING_NAME + "/shadow/update/delta" +SHADOW_RESPONSE_PAYLOAD_TIMEOUT = "REQUEST TIME OUT" + +VALUE_OVERRIDE_KEY_INBOUND_PAYLOAD = "InBoundPayload" +VALUE_OVERRIDE_KEY_OUTBOUND_PAYLOAD = "OutBoundPayload" + +GARBAGE_PAYLOAD = b"ThisIsGarbagePayload!" 
+ +VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD = { + VALUE_OVERRIDE_KEY_INBOUND_PAYLOAD : GARBAGE_PAYLOAD +} + + +class TestDeviceShadow: + + def setup_class(cls): + cls.invoke_shadow_operation = { + SHADOW_OP_TYPE_GET : cls._invoke_shadow_get, + SHADOW_OP_TYPE_DELETE : cls._invoke_shadow_delete, + SHADOW_OP_TYPE_UPDATE : cls._invoke_shadow_update + } + cls._get_topics = { + SHADOW_OP_RESPONSE_STATUS_ACCEPTED : SHADOW_TOPIC_GET_ACCEPTED, + SHADOW_OP_RESPONSE_STATUS_REJECTED : SHADOW_TOPIC_GET_REJECTED, + } + cls._delete_topics = { + SHADOW_OP_RESPONSE_STATUS_ACCEPTED : SHADOW_TOPIC_DELETE_ACCEPTED, + SHADOW_OP_RESPONSE_STATUS_REJECTED : SHADOW_TOPIC_DELETE_REJECTED + } + cls._update_topics = { + SHADOW_OP_RESPONSE_STATUS_ACCEPTED : SHADOW_TOPIC_UPDATE_ACCEPTED, + SHADOW_OP_RESPONSE_STATUS_REJECTED : SHADOW_TOPIC_UPDATE_REJECTED, + SHADOW_OP_RESPONSE_STATUS_DELTA : SHADOW_TOPIC_UPDATE_DELTA + } + cls.shadow_topics = { + SHADOW_OP_TYPE_GET : cls._get_topics, + SHADOW_OP_TYPE_DELETE : cls._delete_topics, + SHADOW_OP_TYPE_UPDATE : cls._update_topics + } + + def _invoke_shadow_get(self): + return self.device_shadow_handler.shadowGet(self.shadow_callback, DUMMY_SHADOW_OP_TIME_OUT_SEC) + + def _invoke_shadow_delete(self): + return self.device_shadow_handler.shadowDelete(self.shadow_callback, DUMMY_SHADOW_OP_TIME_OUT_SEC) + + def _invoke_shadow_update(self): + return self.device_shadow_handler.shadowUpdate("{}", self.shadow_callback, DUMMY_SHADOW_OP_TIME_OUT_SEC) + + def setup_method(self, method): + self.shadow_manager_mock = MagicMock(spec=shadowManager) + self.shadow_callback = MagicMock() + self._create_device_shadow_handler() # Create device shadow handler with persistent subscribe by default + + def _create_device_shadow_handler(self, is_persistent_subscribe=True): + self.device_shadow_handler = deviceShadow(DUMMY_THING_NAME, is_persistent_subscribe, self.shadow_manager_mock) + + # Shadow delta + def test_register_delta_callback_older_version_should_not_invoke(self): + 
self.device_shadow_handler.shadowRegisterDeltaCallback(self.shadow_callback)
+        self._fake_incoming_delta_message_with(version=3)
+
+        # Make next delta message with an old version
+        self._fake_incoming_delta_message_with(version=1)
+
+        assert self.shadow_callback.call_count == 1 # One time from the previous delta message
+
+    def test_unregister_delta_callback_should_not_invoke_after(self):
+        self.device_shadow_handler.shadowRegisterDeltaCallback(self.shadow_callback)
+        fake_delta_message = self._fake_incoming_delta_message_with(version=3)
+        self.shadow_callback.assert_called_once_with(fake_delta_message.payload.decode("utf-8"),
+                                                     SHADOW_OP_RESPONSE_STATUS_DELTA + "/" + DUMMY_THING_NAME,
+                                                     None)
+
+        # Now we unregister
+        self.device_shadow_handler.shadowUnregisterDeltaCallback()
+        self._fake_incoming_delta_message_with(version=5)
+        assert self.shadow_callback.call_count == 1 # One time from the previous delta message
+
+    def test_register_delta_callback_newer_version_should_invoke(self):
+        self.device_shadow_handler.shadowRegisterDeltaCallback(self.shadow_callback)
+        fake_delta_message = self._fake_incoming_delta_message_with(version=300)
+
+        self.shadow_callback.assert_called_once_with(fake_delta_message.payload.decode("utf-8"),
+                                                     SHADOW_OP_RESPONSE_STATUS_DELTA + "/" + DUMMY_THING_NAME,
+                                                     None)
+
+    def test_register_delta_callback_no_version_should_not_invoke(self):
+        self.device_shadow_handler.shadowRegisterDeltaCallback(self.shadow_callback)
+        self._fake_incoming_delta_message_with(version=None)
+
+        assert self.shadow_callback.call_count == 0
+
+    def _fake_incoming_delta_message_with(self, version):
+        fake_delta_message = self._create_fake_shadow_response(SHADOW_TOPIC_UPDATE_DELTA,
+                                                               self._create_simple_payload(token=None, version=version))
+        self.device_shadow_handler.generalCallback(None, None, fake_delta_message)
+        time.sleep(1) # Callback executed in another thread, wait to make sure the artifacts are generated
+        return fake_delta_message
+
+    # Shadow get
+    
def test_persistent_shadow_get_accepted(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_GET, SHADOW_OP_RESPONSE_STATUS_ACCEPTED) + + def test_persistent_shadow_get_rejected(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_GET, SHADOW_OP_RESPONSE_STATUS_REJECTED) + + def test_persistent_shadow_get_time_out(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_GET, SHADOW_OP_RESPONSE_STATUS_TIMEOUT) + + def test_persistent_shadow_get_garbage_response_should_time_out(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_GET, + SHADOW_OP_RESPONSE_STATUS_ACCEPTED, + value_override=VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD) + + def test_non_persistent_shadow_get_accepted(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_GET, SHADOW_OP_RESPONSE_STATUS_ACCEPTED) + + def test_non_persistent_shadow_get_rejected(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_GET, SHADOW_OP_RESPONSE_STATUS_REJECTED) + + def test_non_persistent_shadow_get_time_out(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_GET, SHADOW_OP_RESPONSE_STATUS_TIMEOUT) + + def test_non_persistent_shadow_get_garbage_response_should_time_out(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_GET, + SHADOW_OP_RESPONSE_STATUS_ACCEPTED, + value_override=VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD) + + # Shadow delete + def test_persistent_shadow_delete_accepted(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, SHADOW_OP_RESPONSE_STATUS_ACCEPTED) + + def test_persistent_shadow_delete_rejected(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, SHADOW_OP_RESPONSE_STATUS_REJECTED) + + def test_persistent_shadow_delete_time_out(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, SHADOW_OP_RESPONSE_STATUS_TIMEOUT) + + def 
test_persistent_shadow_delete_garbage_response_should_time_out(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, + SHADOW_OP_RESPONSE_STATUS_ACCEPTED, + value_override=VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD) + + def test_non_persistent_shadow_delete_accepted(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, SHADOW_OP_RESPONSE_STATUS_ACCEPTED) + + def test_non_persistent_shadow_delete_rejected(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, SHADOW_OP_RESPONSE_STATUS_REJECTED) + + def test_non_persistent_shadow_delete_time_out(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, SHADOW_OP_RESPONSE_STATUS_TIMEOUT) + + def test_non_persistent_shadow_delete_garbage_response_should_time_out(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_DELETE, + SHADOW_OP_RESPONSE_STATUS_ACCEPTED, + value_override=VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD) + + # Shadow update + def test_persistent_shadow_update_accepted(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, SHADOW_OP_RESPONSE_STATUS_ACCEPTED) + + def test_persistent_shadow_update_rejected(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, SHADOW_OP_RESPONSE_STATUS_REJECTED) + + def test_persistent_shadow_update_time_out(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, SHADOW_OP_RESPONSE_STATUS_TIMEOUT) + + def test_persistent_shadow_update_garbage_response_should_time_out(self): + self._internal_test_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, + SHADOW_OP_RESPONSE_STATUS_ACCEPTED, + value_override=VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD) + + def test_non_persistent_shadow_update_accepted(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, SHADOW_OP_RESPONSE_STATUS_ACCEPTED) + + def test_non_persistent_shadow_update_rejected(self): + 
self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, SHADOW_OP_RESPONSE_STATUS_REJECTED) + + def test_non_persistent_shadow_update_time_out(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, SHADOW_OP_RESPONSE_STATUS_TIMEOUT) + + def test_non_persistent_shadow_update_garbage_response_should_time_out(self): + self._internal_test_non_persistent_shadow_operation(SHADOW_OP_TYPE_UPDATE, + SHADOW_OP_RESPONSE_STATUS_ACCEPTED, + value_override=VALUE_OVERRIDE_GARBAGE_INBOUND_PAYLOAD) + + def _internal_test_non_persistent_shadow_operation(self, operation_type, operation_response_type, value_override=None): + self._create_device_shadow_handler(is_persistent_subscribe=False) + token = self.invoke_shadow_operation[operation_type](self) + inbound_payload, wait_time_sec, expected_response_type, expected_shadow_response_payload = \ + self._prepare_test_values(token, operation_response_type, value_override) + expected_shadow_response_payload = \ + self._invoke_shadow_general_callback_on_demand(operation_type, operation_response_type, + (inbound_payload, wait_time_sec, expected_shadow_response_payload)) + + self._assert_first_call_correct(operation_type, (token, expected_response_type, expected_shadow_response_payload)) + + def _internal_test_persistent_shadow_operation(self, operation_type, operation_response_type, value_override=None): + token = self.invoke_shadow_operation[operation_type](self) + inbound_payload, wait_time_sec, expected_response_type, expected_shadow_response_payload = \ + self._prepare_test_values(token, operation_response_type, value_override) + expected_shadow_response_payload = \ + self._invoke_shadow_general_callback_on_demand(operation_type, operation_response_type, + (inbound_payload, wait_time_sec, expected_shadow_response_payload)) + + self._assert_first_call_correct(operation_type, + (token, expected_response_type, expected_shadow_response_payload), + is_persistent=True) + + def 
_prepare_test_values(self, token, operation_response_type, value_override): + inbound_payload = None + if value_override: + inbound_payload = value_override.get(VALUE_OVERRIDE_KEY_INBOUND_PAYLOAD) + if inbound_payload is None: + inbound_payload = self._create_simple_payload(token, version=3) # Should be bytes in Py3 + if inbound_payload == GARBAGE_PAYLOAD: + expected_shadow_response_payload = SHADOW_RESPONSE_PAYLOAD_TIMEOUT + wait_time_sec = DUMMY_SHADOW_OP_TIME_OUT_SEC + 1 + expected_response_type = SHADOW_OP_RESPONSE_STATUS_TIMEOUT + else: + expected_shadow_response_payload = inbound_payload.decode("utf-8") # Should always be str in Py2/3 + wait_time_sec = 1 + expected_response_type = operation_response_type + + return inbound_payload, wait_time_sec, expected_response_type, expected_shadow_response_payload + + def _invoke_shadow_general_callback_on_demand(self, operation_type, operation_response_type, data): + inbound_payload, wait_time_sec, expected_shadow_response_payload = data + + if operation_response_type == SHADOW_OP_RESPONSE_STATUS_TIMEOUT: + time.sleep(DUMMY_SHADOW_OP_TIME_OUT_SEC + 1) # Make it time out for sure + return SHADOW_RESPONSE_PAYLOAD_TIMEOUT + else: + fake_shadow_response = self._create_fake_shadow_response(self.shadow_topics[operation_type][operation_response_type], + inbound_payload) + self.device_shadow_handler.generalCallback(None, None, fake_shadow_response) + time.sleep(wait_time_sec) # Callback executed in another thread, wait to make sure the artifacts are generated + return expected_shadow_response_payload + + def _assert_first_call_correct(self, operation_type, expected_data, is_persistent=False): + token, expected_response_type, expected_shadow_response_payload = expected_data + + self.shadow_manager_mock.basicShadowSubscribe.assert_called_once_with(DUMMY_THING_NAME, operation_type, + self.device_shadow_handler.generalCallback) + self.shadow_manager_mock.basicShadowPublish.\ + assert_called_once_with(DUMMY_THING_NAME, + 
operation_type, + self._create_simple_payload(token, version=None).decode("utf-8")) + self.shadow_callback.assert_called_once_with(expected_shadow_response_payload, expected_response_type, token) + if not is_persistent: + self.shadow_manager_mock.basicShadowUnsubscribe.assert_called_once_with(DUMMY_THING_NAME, operation_type) + + def _create_fake_shadow_response(self, topic, payload): + response = MQTTMessage() + response.topic = topic + response.payload = payload + return response + + def _create_simple_payload(self, token, version): + payload_object = dict() + if token is not None: + payload_object["clientToken"] = token + if version is not None: + payload_object["version"] = version + return json.dumps(payload_object).encode("utf-8") diff --git a/test/core/shadow/test_shadow_manager.py b/test/core/shadow/test_shadow_manager.py new file mode 100644 index 0000000..f99bdb5 --- /dev/null +++ b/test/core/shadow/test_shadow_manager.py @@ -0,0 +1,83 @@ +from AWSIoTPythonSDK.core.protocol.mqtt_core import MqttCore +from AWSIoTPythonSDK.core.shadow.shadowManager import shadowManager +try: + from mock import MagicMock +except: + from unittest.mock import MagicMock +try: + from mock import NonCallableMagicMock +except: + from unittest.mock import NonCallableMagicMock +try: + from mock import call +except: + from unittest.mock import call +import pytest + + +DUMMY_SHADOW_NAME = "CoolShadow" +DUMMY_PAYLOAD = "{}" + +OP_SHADOW_GET = "get" +OP_SHADOW_UPDATE = "update" +OP_SHADOW_DELETE = "delete" +OP_SHADOW_DELTA = "delta" +OP_SHADOW_TROUBLE_MAKER = "not_a_valid_shadow_aciton_name" + +DUMMY_SHADOW_TOPIC_PREFIX = "$aws/things/" + DUMMY_SHADOW_NAME + "/shadow/" +DUMMY_SHADOW_TOPIC_GET = DUMMY_SHADOW_TOPIC_PREFIX + "get" +DUMMY_SHADOW_TOPIC_GET_ACCEPTED = DUMMY_SHADOW_TOPIC_GET + "/accepted" +DUMMY_SHADOW_TOPIC_GET_REJECTED = DUMMY_SHADOW_TOPIC_GET + "/rejected" +DUMMY_SHADOW_TOPIC_UPDATE = DUMMY_SHADOW_TOPIC_PREFIX + "update" +DUMMY_SHADOW_TOPIC_UPDATE_ACCEPTED = 
DUMMY_SHADOW_TOPIC_UPDATE_ACCEPTED = DUMMY_SHADOW_TOPIC_UPDATE + "/accepted"
DUMMY_SHADOW_TOPIC_UPDATE_REJECTED = DUMMY_SHADOW_TOPIC_UPDATE + "/rejected"
DUMMY_SHADOW_TOPIC_UPDATE_DELTA = DUMMY_SHADOW_TOPIC_UPDATE + "/delta"
DUMMY_SHADOW_TOPIC_DELETE = DUMMY_SHADOW_TOPIC_PREFIX + "delete"
DUMMY_SHADOW_TOPIC_DELETE_ACCEPTED = DUMMY_SHADOW_TOPIC_DELETE + "/accepted"
DUMMY_SHADOW_TOPIC_DELETE_REJECTED = DUMMY_SHADOW_TOPIC_DELETE + "/rejected"


class TestShadowManager:
    """Verify shadowManager fans shadow actions out to the right MQTT topics."""

    def setup_method(self, test_method):
        self.mock_mqtt_core = MagicMock(spec=MqttCore)
        self.shadow_manager = shadowManager(self.mock_mqtt_core)

    def test_basic_shadow_publish(self):
        for action in (OP_SHADOW_GET, OP_SHADOW_UPDATE, OP_SHADOW_DELETE):
            self.shadow_manager.basicShadowPublish(DUMMY_SHADOW_NAME, action, DUMMY_PAYLOAD)
        expected_publishes = [call(DUMMY_SHADOW_TOPIC_GET, DUMMY_PAYLOAD, 0, False),
                              call(DUMMY_SHADOW_TOPIC_UPDATE, DUMMY_PAYLOAD, 0, False),
                              call(DUMMY_SHADOW_TOPIC_DELETE, DUMMY_PAYLOAD, 0, False)]
        self.mock_mqtt_core.publish.assert_has_calls(expected_publishes)

    def test_basic_shadow_subscribe(self):
        callback = NonCallableMagicMock()
        for action in (OP_SHADOW_GET, OP_SHADOW_UPDATE, OP_SHADOW_DELETE, OP_SHADOW_DELTA):
            self.shadow_manager.basicShadowSubscribe(DUMMY_SHADOW_NAME, action, callback)
        # get/update/delete subscribe to accepted+rejected; delta has one topic
        expected_subscribes = [call(DUMMY_SHADOW_TOPIC_GET_ACCEPTED, 0, callback),
                               call(DUMMY_SHADOW_TOPIC_GET_REJECTED, 0, callback),
                               call(DUMMY_SHADOW_TOPIC_UPDATE_ACCEPTED, 0, callback),
                               call(DUMMY_SHADOW_TOPIC_UPDATE_REJECTED, 0, callback),
                               call(DUMMY_SHADOW_TOPIC_DELETE_ACCEPTED, 0, callback),
                               call(DUMMY_SHADOW_TOPIC_DELETE_REJECTED, 0, callback),
                               call(DUMMY_SHADOW_TOPIC_UPDATE_DELTA, 0, callback)]
        self.mock_mqtt_core.subscribe.assert_has_calls(expected_subscribes)

    def test_basic_shadow_unsubscribe(self):
        for action in (OP_SHADOW_GET, OP_SHADOW_UPDATE, OP_SHADOW_DELETE, OP_SHADOW_DELTA):
            self.shadow_manager.basicShadowUnsubscribe(DUMMY_SHADOW_NAME, action)
        expected_unsubscribes = [call(DUMMY_SHADOW_TOPIC_GET_ACCEPTED),
                                 call(DUMMY_SHADOW_TOPIC_GET_REJECTED),
                                 call(DUMMY_SHADOW_TOPIC_UPDATE_ACCEPTED),
                                 call(DUMMY_SHADOW_TOPIC_UPDATE_REJECTED),
                                 call(DUMMY_SHADOW_TOPIC_DELETE_ACCEPTED),
                                 call(DUMMY_SHADOW_TOPIC_DELETE_REJECTED),
                                 call(DUMMY_SHADOW_TOPIC_UPDATE_DELTA)]
        self.mock_mqtt_core.unsubscribe.assert_has_calls(expected_unsubscribes)

    def test_unsupported_shadow_action_name(self):
        # An unknown shadow action name must be rejected loudly
        with pytest.raises(TypeError):
            self.shadow_manager.basicShadowUnsubscribe(DUMMY_SHADOW_NAME, OP_SHADOW_TROUBLE_MAKER)


# ---------------------------------------------------------------------------
# test/core/util/test_providers.py
# ---------------------------------------------------------------------------
from AWSIoTPythonSDK.core.util.providers import CertificateCredentialsProvider
from AWSIoTPythonSDK.core.util.providers import IAMCredentialsProvider
from AWSIoTPythonSDK.core.util.providers import EndpointProvider


DUMMY_PATH = "/dummy/path/"
DUMMY_CERT_PATH = DUMMY_PATH + "cert.pem"
DUMMY_CA_PATH = DUMMY_PATH + "ca.crt"
DUMMY_KEY_PATH = DUMMY_PATH + "key.pem"
DUMMY_ACCESS_KEY_ID = "AccessKey"
DUMMY_SECRET_KEY = "SecretKey"
DUMMY_SESSION_TOKEN = "SessionToken"
DUMMY_HOST = "dummy.host.com"
DUMMY_PORT = 8888


class TestProviders:
    """Round-trip tests for the credential and endpoint provider value objects."""

    def setup_method(self, test_method):
        self.certificate_credentials_provider = CertificateCredentialsProvider()
        self.iam_credentials_provider = IAMCredentialsProvider()
        self.endpoint_provider = EndpointProvider()
test_certificate_credentials_provider(self): + self.certificate_credentials_provider.set_ca_path(DUMMY_CA_PATH) + self.certificate_credentials_provider.set_cert_path(DUMMY_CERT_PATH) + self.certificate_credentials_provider.set_key_path(DUMMY_KEY_PATH) + assert self.certificate_credentials_provider.get_ca_path() == DUMMY_CA_PATH + assert self.certificate_credentials_provider.get_cert_path() == DUMMY_CERT_PATH + assert self.certificate_credentials_provider.get_key_path() == DUMMY_KEY_PATH + + def test_iam_credentials_provider(self): + self.iam_credentials_provider.set_ca_path(DUMMY_CA_PATH) + self.iam_credentials_provider.set_access_key_id(DUMMY_ACCESS_KEY_ID) + self.iam_credentials_provider.set_secret_access_key(DUMMY_SECRET_KEY) + self.iam_credentials_provider.set_session_token(DUMMY_SESSION_TOKEN) + assert self.iam_credentials_provider.get_ca_path() == DUMMY_CA_PATH + assert self.iam_credentials_provider.get_access_key_id() == DUMMY_ACCESS_KEY_ID + assert self.iam_credentials_provider.get_secret_access_key() == DUMMY_SECRET_KEY + assert self.iam_credentials_provider.get_session_token() == DUMMY_SESSION_TOKEN + + def test_endpoint_provider(self): + self.endpoint_provider.set_host(DUMMY_HOST) + self.endpoint_provider.set_port(DUMMY_PORT) + assert self.endpoint_provider.get_host() == DUMMY_HOST + assert self.endpoint_provider.get_port() == DUMMY_PORT diff --git a/test/sdk_mock/__init__.py b/test/sdk_mock/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/test/sdk_mock/mockAWSIoTPythonSDK.py b/test/sdk_mock/mockAWSIoTPythonSDK.py new file mode 100755 index 0000000..b362570 --- /dev/null +++ b/test/sdk_mock/mockAWSIoTPythonSDK.py @@ -0,0 +1,34 @@ +import sys +import mockMQTTCore +import mockMQTTCoreQuiet +from AWSIoTPythonSDK import MQTTLib +import AWSIoTPythonSDK.core.shadow.shadowManager as shadowManager + +class mockAWSIoTMQTTClient(MQTTLib.AWSIoTMQTTClient): + def __init__(self, clientID, protocolType, useWebsocket=False, cleanSession=True): + 
self._mqttCore = mockMQTTCore.mockMQTTCore(clientID, cleanSession, protocolType, useWebsocket) + +class mockAWSIoTMQTTClientWithSubRecords(MQTTLib.AWSIoTMQTTClient): + def __init__(self, clientID, protocolType, useWebsocket=False, cleanSession=True): + self._mqttCore = mockMQTTCore.mockMQTTCoreWithSubRecords(clientID, cleanSession, protocolType, useWebsocket) + +class mockAWSIoTMQTTClientQuiet(MQTTLib.AWSIoTMQTTClient): + def __init__(self, clientID, protocolType, useWebsocket=False, cleanSession=True): + self._mqttCore = mockMQTTCoreQuiet.mockMQTTCoreQuiet(clientID, cleanSession, protocolType, useWebsocket) + +class mockAWSIoTMQTTClientQuietWithSubRecords(MQTTLib.AWSIoTMQTTClient): + def __init__(self, clientID, protocolType, useWebsocket=False, cleanSession=True): + self._mqttCore = mockMQTTCoreQuiet.mockMQTTCoreQuietWithSubRecords(clientID, cleanSession, protocolType, useWebsocket) + +class mockAWSIoTMQTTShadowClient(MQTTLib.AWSIoTMQTTShadowClient): + def __init__(self, clientID, protocolType, useWebsocket=False, cleanSession=True): + # AWSIOTMQTTClient instance + self._AWSIoTMQTTClient = mockAWSIoTMQTTClientQuiet(clientID, protocolType, useWebsocket, cleanSession) + # Configure it to disable offline Publish Queueing + self._AWSIoTMQTTClient.configureOfflinePublishQueueing(0) + self._AWSIoTMQTTClient.configureDrainingFrequency(10) + # Now retrieve the configured mqttCore and init a shadowManager instance + self._shadowManager = shadowManager.shadowManager(self._AWSIoTMQTTClient._mqttCore) + + + diff --git a/test/sdk_mock/mockMQTTCore.py b/test/sdk_mock/mockMQTTCore.py new file mode 100755 index 0000000..e8c61b0 --- /dev/null +++ b/test/sdk_mock/mockMQTTCore.py @@ -0,0 +1,17 @@ +import sys +import mockPahoClient +import AWSIoTPythonSDK.core.protocol.mqttCore as mqttCore + +class mockMQTTCore(mqttCore.mqttCore): + def createPahoClient(self, clientID, cleanSession, userdata, protocol, useWebsocket): + return mockPahoClient.mockPahoClient(clientID, cleanSession, 
userdata, protocol, useWebsocket) + + def setReturnTupleForPahoClient(self, srcReturnTuple): + self._pahoClient.setReturnTuple(srcReturnTuple) + +class mockMQTTCoreWithSubRecords(mockMQTTCore): + def reinitSubscribePool(self): + self._subscribePoolRecords = dict() + + def subscribe(self, topic, qos, callback): + self._subscribePoolRecords[topic] = qos diff --git a/test/sdk_mock/mockMQTTCoreQuiet.py b/test/sdk_mock/mockMQTTCoreQuiet.py new file mode 100755 index 0000000..bdb6faa --- /dev/null +++ b/test/sdk_mock/mockMQTTCoreQuiet.py @@ -0,0 +1,34 @@ +import sys +import mockPahoClient +import AWSIoTPythonSDK.core.protocol.mqttCore as mqttCore + +class mockMQTTCoreQuiet(mqttCore.mqttCore): + def createPahoClient(self, clientID, cleanSession, userdata, protocol, useWebsocket): + return mockPahoClient.mockPahoClient(clientID, cleanSession, userdata, protocol, useWebsocket) + + def setReturnTupleForPahoClient(self, srcReturnTuple): + self._pahoClient.setReturnTuple(srcReturnTuple) + + def connect(self, keepAliveInterval): + pass + + def disconnect(self): + pass + + def publish(self, topic, payload, qos, retain): + pass + + def subscribe(self, topic, qos, callback): + pass + + def unsubscribe(self, topic): + pass + +class mockMQTTCoreQuietWithSubRecords(mockMQTTCoreQuiet): + + def reinitSubscribePool(self): + self._subscribePoolRecords = dict() + + def subscribe(self, topic, qos, callback): + self._subscribePoolRecords[topic] = qos + diff --git a/test/sdk_mock/mockMessage.py b/test/sdk_mock/mockMessage.py new file mode 100755 index 0000000..61c733a --- /dev/null +++ b/test/sdk_mock/mockMessage.py @@ -0,0 +1,7 @@ +class mockMessage: + topic = None + payload = None + + def __init__(self, srcTopic, srcPayload): + self.topic = srcTopic + self.payload = srcPayload diff --git a/test/sdk_mock/mockPahoClient.py b/test/sdk_mock/mockPahoClient.py new file mode 100755 index 0000000..8bcfda6 --- /dev/null +++ b/test/sdk_mock/mockPahoClient.py @@ -0,0 +1,49 @@ +import sys +import 
import AWSIoTPythonSDK.core.protocol.paho.client as mqtt
import logging


class mockPahoClient(mqtt.Client):
    """Paho client stub: logs every call and returns a scriptable (rc, mid) tuple."""

    _log = logging.getLogger(__name__)
    _returnTuple = (-1, -1)
    # Callback handler placeholders (mirroring the paho callback names)
    on_connect = None
    on_disconnect = None
    on_message = None
    on_publish = None
    on_subscribe = None  # fix: was misspelled "on_subsribe"
    on_unsubscribe = None

    def setReturnTuple(self, srcTuple):
        self._returnTuple = srcTuple

    # Tool function
    def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tls_version=None):
        self._log.debug("tls_set called.")

    def loop_start(self):
        self._log.debug("Socket thread started.")

    def loop_stop(self):
        self._log.debug("Socket thread stopped.")

    def message_callback_add(self, sub, callback):
        self._log.debug("Add a user callback. Topic: " + str(sub))

    # MQTT API
    def connect(self, host, port, keepalive):
        self._log.debug("Connect called.")

    def disconnect(self):
        self._log.debug("Disconnect called.")

    def publish(self, topic, payload, qos, retain):
        self._log.debug("Publish called.")
        return self._returnTuple

    def subscribe(self, topic, qos):
        self._log.debug("Subscribe called.")
        return self._returnTuple

    def unsubscribe(self, topic):
        self._log.debug("Unsubscribe called.")
        return self._returnTuple


# ---------------------------------------------------------------------------
# test/sdk_mock/mockSSLSocket.py
# ---------------------------------------------------------------------------
import socket
import ssl

class mockSSLSocket:
    """In-memory stand-in for an SSL socket with scriptable read/write behavior."""

    def __init__(self):
        self._readBuffer = bytearray()
        self._writeBuffer = bytearray()
        self._isClosed = False
        self._isFragmented = False
        self._fragmentDoneThrowError = False
        self._currentFragments = bytearray()
        self._fragments = list()
        self._flipWriteError = False
        self._flipWriteErrorCount = 0

    # TestHelper APIs
    def refreshReadBuffer(self, bytesToLoad):
        self._readBuffer = bytesToLoad
self._writeBuffer = bytearray() + self._isClosed = False + self._isFragmented = False + self._fragmentDoneThrowError = False + self._currentFragments = bytearray() + self._fragments = list() + self._flipWriteError = False + self._flipWriteErrorCount = 0 + + def getReaderBuffer(self): + return self._readBuffer + + def getWriteBuffer(self): + return self._writeBuffer + + def addReadBufferFragment(self, fragmentElement): + self._fragments.append(fragmentElement) + + def setReadFragmented(self): + self._isFragmented = True + + def setFlipWriteError(self): + self._flipWriteError = True + self._flipWriteErrorCount = 0 + + def loadFirstFragmented(self): + self._currentFragments = self._fragments.pop(0) + + # Public APIs + # Should return bytes, not string + def read(self, numberOfBytes): + if not self._isFragmented: # Read a lot, then nothing + if len(self._readBuffer) == 0: + raise socket.error(ssl.SSL_ERROR_WANT_READ, "End of read buffer") + # If we have enough data for the requested amount, give them out + if numberOfBytes <= len(self._readBuffer): + ret = self._readBuffer[0:numberOfBytes] + self._readBuffer = self._readBuffer[numberOfBytes:] + else: + ret = self._readBuffer + self._readBuffer = self._readBuffer[len(self._readBuffer):] # Empty + return ret + else: # Read 1 fragement util it is empty, then throw error, then load in next + if self._fragmentDoneThrowError and len(self._fragments) > 0: + self._currentFragments = self._fragments.pop(0) # Load in next fragment + self._fragmentDoneThrowError = False # Reset ThrowError flag + raise socket.error(ssl.SSL_ERROR_WANT_READ, "Not ready for read op") + # If we have enough data for the requested amount in the current fragment, give them out + ret = bytearray() + if numberOfBytes <= len(self._currentFragments): + ret = self._currentFragments[0:numberOfBytes] + self._currentFragments = self._currentFragments[numberOfBytes:] + if len(self._currentFragments) == 0: + self._fragmentDoneThrowError = True # Will throw error 
next time + else: + ret = self._currentFragments + self._currentFragments = self._currentFragments[len(self._currentFragments):] # Empty + self._fragmentDoneThrowError = True + return ret + + # Should write bytes, not string + def write(self, bytesToWrite): + if self._flipWriteError: + if self._flipWriteErrorCount % 2 == 1: + self._writeBuffer += bytesToWrite # bytesToWrite should always be in 'bytes' type + self._flipWriteErrorCount += 1 + return len(bytesToWrite) + else: + self._flipWriteErrorCount += 1 + raise socket.error(ssl.SSL_ERROR_WANT_WRITE, "Not ready for write op") + else: + self._writeBuffer += bytesToWrite # bytesToWrite should always be in 'bytes' type + return len(bytesToWrite) + + def close(self): + self._isClosed = True + + + + + + + diff --git a/test/sdk_mock/mockSecuredWebsocketCore.py b/test/sdk_mock/mockSecuredWebsocketCore.py new file mode 100755 index 0000000..4f2efe9 --- /dev/null +++ b/test/sdk_mock/mockSecuredWebsocketCore.py @@ -0,0 +1,35 @@ +from test.sdk_mock.mockSigV4Core import mockSigV4Core +from AWSIoTPythonSDK.core.protocol.connection.cores import SecuredWebSocketCore + + +class mockSecuredWebsocketCoreNoRealHandshake(SecuredWebSocketCore): + def _createSigV4Core(self): + ret = mockSigV4Core() + ret.setNoEnvVar(False) # Always has Env Var + return ret + + def _handShake(self, hostAddress, portNumber): # Override to pass handshake + pass + + def _generateMaskKey(self): + return bytearray(str("1234"), 'utf-8') # Arbitrary mask key for testing + + +class MockSecuredWebSocketCoreNoSocketIO(SecuredWebSocketCore): + def _createSigV4Core(self): + ret = mockSigV4Core() + ret.setNoEnvVar(False) # Always has Env Var + return ret + + def _generateMaskKey(self): + return bytearray(str("1234"), 'utf-8') # Arbitrary mask key for testing + + def _getTimeoutSec(self): + return 3 # 3 sec to time out from waiting for handshake response for testing + + +class MockSecuredWebSocketCoreWithRealHandshake(SecuredWebSocketCore): + def 
_createSigV4Core(self): + ret = mockSigV4Core() + ret.setNoEnvVar(False) # Always has Env Var + return ret diff --git a/test/sdk_mock/mockSigV4Core.py b/test/sdk_mock/mockSigV4Core.py new file mode 100755 index 0000000..142b27f --- /dev/null +++ b/test/sdk_mock/mockSigV4Core.py @@ -0,0 +1,17 @@ +from AWSIoTPythonSDK.core.protocol.connection.cores import SigV4Core + + +class mockSigV4Core(SigV4Core): + _forceNoEnvVar = False + + def setNoEnvVar(self, srcVal): + self._forceNoEnvVar = srcVal + + def _checkKeyInEnv(self): # Simulate no Env Var + if self._forceNoEnvVar: + return dict() # Return empty list + else: + ret = dict() + ret["aws_access_key_id"] = "blablablaID" + ret["aws_secret_access_key"] = "blablablaSecret" + return ret diff --git a/test/test_mqtt_lib.py b/test/test_mqtt_lib.py new file mode 100644 index 0000000..b74375b --- /dev/null +++ b/test/test_mqtt_lib.py @@ -0,0 +1,304 @@ +from AWSIoTPythonSDK.core.protocol.mqtt_core import MqttCore +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTClient +from AWSIoTPythonSDK.MQTTLib import AWSIoTMQTTShadowClient +from AWSIoTPythonSDK.MQTTLib import DROP_NEWEST +try: + from mock import patch + from mock import MagicMock +except: + from unittest.mock import patch + from unittest.mock import MagicMock + + +PATCH_MODULE_LOCATION = "AWSIoTPythonSDK.MQTTLib." 
CLIENT_ID = "DefaultClientId"
SHADOW_CLIENT_ID = "DefaultShadowClientId"
DUMMY_HOST = "dummy.host"
PORT_443 = 443
PORT_8883 = 8883
DEFAULT_KEEPALIVE_SEC = 600
DUMMY_TOPIC = "dummy/topic"
DUMMY_PAYLOAD = "dummy/payload"
DUMMY_QOS = 1
DUMMY_AWS_ACCESS_KEY_ID = "DummyKeyId"
DUMMY_AWS_SECRET_KEY = "SecretKey"
DUMMY_AWS_TOKEN = "Token"
DUMMY_CA_PATH = "path/to/ca"
DUMMY_CERT_PATH = "path/to/cert"
DUMMY_KEY_PATH = "path/to/key"
DUMMY_BASE_RECONNECT_BACKOFF_SEC = 1
DUMMY_MAX_RECONNECT_BACKOFF_SEC = 32
DUMMY_STABLE_CONNECTION_SEC = 16
DUMMY_QUEUE_SIZE = 100
DUMMY_DRAINING_FREQUENCY = 2
DUMMY_TIMEOUT_SEC = 10
DUMMY_USER_NAME = "UserName"
DUMMY_PASSWORD = "Password"


class TestMqttLibShadowClient:
    """AWSIoTMQTTShadowClient API surface, exercised against a mocked MqttCore."""

    def setup_method(self, test_method):
        self._use_mock_mqtt_core()

    def _use_mock_mqtt_core(self):
        # Patch the MqttCore symbol inside MQTTLib so the client under test
        # wires itself to our mock instead of a real MQTT stack.
        self.mqtt_core_patcher = patch(PATCH_MODULE_LOCATION + "MqttCore", spec=MqttCore)
        self.mock_mqtt_core_constructor = self.mqtt_core_patcher.start()
        self.mqtt_core_mock = MagicMock()
        self.mock_mqtt_core_constructor.return_value = self.mqtt_core_mock
        self.iot_mqtt_shadow_client = AWSIoTMQTTShadowClient(SHADOW_CLIENT_ID)

    def teardown_method(self, test_method):
        self.mqtt_core_patcher.stop()

    def test_iot_mqtt_shadow_client_with_provided_mqtt_client(self):
        # A caller-supplied client must be used as-is, not reconfigured
        mock_iot_mqtt_client = MagicMock()
        iot_mqtt_shadow_client_with_provided_mqtt_client = AWSIoTMQTTShadowClient(
            SHADOW_CLIENT_ID, awsIoTMQTTClient=mock_iot_mqtt_client)
        assert mock_iot_mqtt_client.configureOfflinePublishQueueing.called is False

    def test_iot_mqtt_shadow_client_connect_default_keepalive(self):
        self.iot_mqtt_shadow_client.connect()
        self.mqtt_core_mock.connect.assert_called_once_with(DEFAULT_KEEPALIVE_SEC)

    def test_iot_mqtt_shadow_client_auto_enable_when_use_cert_over_443(self):
        # Cert-based connection over 443 should switch on ALPN automatically
        self.mqtt_core_mock.use_wss.return_value = False
        self.iot_mqtt_shadow_client.configureEndpoint(hostName=DUMMY_HOST, portNumber=PORT_443)
        self.mqtt_core_mock.configure_alpn_protocols.assert_called_once()

    def test_iot_mqtt_shadow_client_alpn_auto_disable_when_use_wss(self):
        self.mqtt_core_mock.use_wss.return_value = True
        self.iot_mqtt_shadow_client.configureEndpoint(hostName=DUMMY_HOST, portNumber=PORT_443)
        assert self.mqtt_core_mock.configure_alpn_protocols.called is False

    def test_iot_mqtt_shadow_client_alpn_auto_disable_when_use_cert_over_8883(self):
        self.mqtt_core_mock.use_wss.return_value = False
        self.iot_mqtt_shadow_client.configureEndpoint(hostName=DUMMY_HOST, portNumber=PORT_8883)
        assert self.mqtt_core_mock.configure_alpn_protocols.called is False

    def test_iot_mqtt_shadow_client_clear_last_will(self):
        self.iot_mqtt_shadow_client.clearLastWill()
        self.mqtt_core_mock.clear_last_will.assert_called_once()

    def test_iot_mqtt_shadow_client_configure_endpoint(self):
        self.iot_mqtt_shadow_client.configureEndpoint(DUMMY_HOST, PORT_8883)
        self.mqtt_core_mock.configure_endpoint.assert_called_once()

    def test_iot_mqtt_shadow_client_configure_iam_credentials(self):
        self.iot_mqtt_shadow_client.configureIAMCredentials(
            DUMMY_AWS_ACCESS_KEY_ID, DUMMY_AWS_SECRET_KEY, DUMMY_AWS_TOKEN)
        self.mqtt_core_mock.configure_iam_credentials.assert_called_once()

    def test_iot_mqtt_shadowclient_configure_credentials(self):
        self.iot_mqtt_shadow_client.configureCredentials(
            DUMMY_CA_PATH, DUMMY_KEY_PATH, DUMMY_CERT_PATH)
        self.mqtt_core_mock.configure_cert_credentials.assert_called_once()

    def test_iot_mqtt_shadow_client_configure_auto_reconnect_backoff(self):
        self.iot_mqtt_shadow_client.configureAutoReconnectBackoffTime(
            DUMMY_BASE_RECONNECT_BACKOFF_SEC,
            DUMMY_MAX_RECONNECT_BACKOFF_SEC,
            DUMMY_STABLE_CONNECTION_SEC)
        self.mqtt_core_mock.configure_reconnect_back_off.assert_called_once_with(
            DUMMY_BASE_RECONNECT_BACKOFF_SEC,
            DUMMY_MAX_RECONNECT_BACKOFF_SEC,
            DUMMY_STABLE_CONNECTION_SEC)

    def test_iot_mqtt_shadow_client_configure_offline_publish_queueing(self):
        # This configurable is fixed at object initialization; customers
        # cannot change it, and the queue is disabled (size 0).
        self.mqtt_core_mock.configure_offline_requests_queue.assert_called_once_with(0, DROP_NEWEST)

    def test_iot_mqtt_client_configure_draining_frequency(self):
        # Also fixed at object initialization. Since queueing is disabled the
        # draining interval is meaningless; "10" is just the internal placeholder.
        self.mqtt_core_mock.configure_draining_interval_sec.assert_called_once_with(1 / float(10))

    def test_iot_mqtt_client_configure_connect_disconnect_timeout(self):
        self.iot_mqtt_shadow_client.configureConnectDisconnectTimeout(DUMMY_TIMEOUT_SEC)
        self.mqtt_core_mock.configure_connect_disconnect_timeout_sec.assert_called_once_with(DUMMY_TIMEOUT_SEC)

    def test_iot_mqtt_client_configure_mqtt_operation_timeout(self):
        self.iot_mqtt_shadow_client.configureMQTTOperationTimeout(DUMMY_TIMEOUT_SEC)
        self.mqtt_core_mock.configure_operation_timeout_sec.assert_called_once_with(DUMMY_TIMEOUT_SEC)

    def test_iot_mqtt_client_configure_user_name_password(self):
        self.iot_mqtt_shadow_client.configureUsernamePassword(DUMMY_USER_NAME, DUMMY_PASSWORD)
        self.mqtt_core_mock.configure_username_password.assert_called_once_with(DUMMY_USER_NAME, DUMMY_PASSWORD)

    def test_iot_mqtt_client_enable_metrics_collection(self):
        self.iot_mqtt_shadow_client.enableMetricsCollection()
        self.mqtt_core_mock.enable_metrics_collection.assert_called_once()

    def test_iot_mqtt_client_disable_metrics_collection(self):
        self.iot_mqtt_shadow_client.disableMetricsCollection()
        self.mqtt_core_mock.disable_metrics_collection.assert_called_once()

    def test_iot_mqtt_client_callback_registration_upon_connect(self):
        fake_on_online_callback = MagicMock()
        fake_on_offline_callback = MagicMock()

        self.iot_mqtt_shadow_client.onOnline = fake_on_online_callback
        self.iot_mqtt_shadow_client.onOffline = fake_on_offline_callback
        # `onMessage` is used internally by the SDK and is not customer-facing

        self.iot_mqtt_shadow_client.connect()

        assert self.mqtt_core_mock.on_online == fake_on_online_callback
        assert self.mqtt_core_mock.on_offline == fake_on_offline_callback
        self.mqtt_core_mock.connect.assert_called_once()

    def test_iot_mqtt_client_disconnect(self):
        self.iot_mqtt_shadow_client.disconnect()
        self.mqtt_core_mock.disconnect.assert_called_once()


class TestMqttLibMqttClient:
    """AWSIoTMQTTClient API surface, exercised against a mocked MqttCore."""

    def setup_method(self, test_method):
        self._use_mock_mqtt_core()

    def _use_mock_mqtt_core(self):
        self.mqtt_core_patcher = patch(PATCH_MODULE_LOCATION + "MqttCore", spec=MqttCore)
        self.mock_mqtt_core_constructor = self.mqtt_core_patcher.start()
        self.mqtt_core_mock = MagicMock()
        self.mock_mqtt_core_constructor.return_value = self.mqtt_core_mock
        self.iot_mqtt_client = AWSIoTMQTTClient(CLIENT_ID)

    def teardown_method(self, test_method):
        self.mqtt_core_patcher.stop()

    def test_iot_mqtt_client_connect_default_keepalive(self):
        self.iot_mqtt_client.connect()
        self.mqtt_core_mock.connect.assert_called_once_with(DEFAULT_KEEPALIVE_SEC)

    def test_iot_mqtt_client_connect_async_default_keepalive(self):
        self.iot_mqtt_client.connectAsync()
        self.mqtt_core_mock.connect_async.assert_called_once_with(DEFAULT_KEEPALIVE_SEC, None)

    def test_iot_mqtt_client_alpn_auto_enable_when_use_cert_over_443(self):
        self.mqtt_core_mock.use_wss.return_value = False
        self.iot_mqtt_client.configureEndpoint(hostName=DUMMY_HOST, portNumber=PORT_443)
        self.mqtt_core_mock.configure_alpn_protocols.assert_called_once()

    def test_iot_mqtt_client_alpn_auto_disable_when_use_wss(self):
        self.mqtt_core_mock.use_wss.return_value = True
        self.iot_mqtt_client.configureEndpoint(hostName=DUMMY_HOST, portNumber=PORT_443)
        assert self.mqtt_core_mock.configure_alpn_protocols.called is False
test_iot_mqtt_client_alpn_auto_disable_when_use_cert_over_8883(self): + self.mqtt_core_mock.use_wss.return_value = False + self.iot_mqtt_client.configureEndpoint(hostName=DUMMY_HOST, portNumber=PORT_8883) + assert self.mqtt_core_mock.configure_alpn_protocols.called is False + + def test_iot_mqtt_client_configure_last_will(self): + self.iot_mqtt_client.configureLastWill(topic=DUMMY_TOPIC, payload=DUMMY_PAYLOAD, QoS=DUMMY_QOS) + self.mqtt_core_mock.configure_last_will.assert_called_once_with(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS, False) + + def test_iot_mqtt_client_clear_last_will(self): + self.iot_mqtt_client.clearLastWill() + self.mqtt_core_mock.clear_last_will.assert_called_once() + + def test_iot_mqtt_client_configure_endpoint(self): + self.iot_mqtt_client.configureEndpoint(DUMMY_HOST, PORT_8883) + self.mqtt_core_mock.configure_endpoint.assert_called_once() + + def test_iot_mqtt_client_configure_iam_credentials(self): + self.iot_mqtt_client.configureIAMCredentials(DUMMY_AWS_ACCESS_KEY_ID, DUMMY_AWS_SECRET_KEY, DUMMY_AWS_TOKEN) + self.mqtt_core_mock.configure_iam_credentials.assert_called_once() + + def test_iot_mqtt_client_configure_credentials(self): + self.iot_mqtt_client.configureCredentials(DUMMY_CA_PATH, DUMMY_KEY_PATH, DUMMY_CERT_PATH) + self.mqtt_core_mock.configure_cert_credentials.assert_called_once() + + def test_iot_mqtt_client_configure_auto_reconnect_backoff(self): + self.iot_mqtt_client.configureAutoReconnectBackoffTime(DUMMY_BASE_RECONNECT_BACKOFF_SEC, + DUMMY_MAX_RECONNECT_BACKOFF_SEC, + DUMMY_STABLE_CONNECTION_SEC) + self.mqtt_core_mock.configure_reconnect_back_off.assert_called_once_with(DUMMY_BASE_RECONNECT_BACKOFF_SEC, + DUMMY_MAX_RECONNECT_BACKOFF_SEC, + DUMMY_STABLE_CONNECTION_SEC) + + def test_iot_mqtt_client_configure_offline_publish_queueing(self): + self.iot_mqtt_client.configureOfflinePublishQueueing(DUMMY_QUEUE_SIZE) + self.mqtt_core_mock.configure_offline_requests_queue.assert_called_once_with(DUMMY_QUEUE_SIZE, DROP_NEWEST) + + def 
test_iot_mqtt_client_configure_draining_frequency(self): + self.iot_mqtt_client.configureDrainingFrequency(DUMMY_DRAINING_FREQUENCY) + self.mqtt_core_mock.configure_draining_interval_sec.assert_called_once_with(1/float(DUMMY_DRAINING_FREQUENCY)) + + def test_iot_mqtt_client_configure_connect_disconnect_timeout(self): + self.iot_mqtt_client.configureConnectDisconnectTimeout(DUMMY_TIMEOUT_SEC) + self.mqtt_core_mock.configure_connect_disconnect_timeout_sec.assert_called_once_with(DUMMY_TIMEOUT_SEC) + + def test_iot_mqtt_client_configure_mqtt_operation_timeout(self): + self.iot_mqtt_client.configureMQTTOperationTimeout(DUMMY_TIMEOUT_SEC) + self.mqtt_core_mock.configure_operation_timeout_sec.assert_called_once_with(DUMMY_TIMEOUT_SEC) + + def test_iot_mqtt_client_configure_user_name_password(self): + self.iot_mqtt_client.configureUsernamePassword(DUMMY_USER_NAME, DUMMY_PASSWORD) + self.mqtt_core_mock.configure_username_password.assert_called_once_with(DUMMY_USER_NAME, DUMMY_PASSWORD) + + def test_iot_mqtt_client_enable_metrics_collection(self): + self.iot_mqtt_client.enableMetricsCollection() + self.mqtt_core_mock.enable_metrics_collection.assert_called_once() + + def test_iot_mqtt_client_disable_metrics_collection(self): + self.iot_mqtt_client.disableMetricsCollection() + self.mqtt_core_mock.disable_metrics_collection.assert_called_once() + + def test_iot_mqtt_client_callback_registration_upon_connect(self): + fake_on_online_callback = MagicMock() + fake_on_offline_callback = MagicMock() + fake_on_message_callback = MagicMock() + + self.iot_mqtt_client.onOnline = fake_on_online_callback + self.iot_mqtt_client.onOffline = fake_on_offline_callback + self.iot_mqtt_client.onMessage = fake_on_message_callback + + self.iot_mqtt_client.connect() + + assert self.mqtt_core_mock.on_online == fake_on_online_callback + assert self.mqtt_core_mock.on_offline == fake_on_offline_callback + assert self.mqtt_core_mock.on_message == fake_on_message_callback + 
self.mqtt_core_mock.connect.assert_called_once() + + def test_iot_mqtt_client_disconnect(self): + self.iot_mqtt_client.disconnect() + self.mqtt_core_mock.disconnect.assert_called_once() + + def test_iot_mqtt_client_publish(self): + self.iot_mqtt_client.publish(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS) + self.mqtt_core_mock.publish.assert_called_once_with(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS, False) + + def test_iot_mqtt_client_subscribe(self): + message_callback = MagicMock() + self.iot_mqtt_client.subscribe(DUMMY_TOPIC, DUMMY_QOS, message_callback) + self.mqtt_core_mock.subscribe.assert_called_once_with(DUMMY_TOPIC, DUMMY_QOS, message_callback) + + def test_iot_mqtt_client_unsubscribe(self): + self.iot_mqtt_client.unsubscribe(DUMMY_TOPIC) + self.mqtt_core_mock.unsubscribe.assert_called_once_with(DUMMY_TOPIC) + + def test_iot_mqtt_client_connect_async(self): + connack_callback = MagicMock() + self.iot_mqtt_client.connectAsync(ackCallback=connack_callback) + self.mqtt_core_mock.connect_async.assert_called_once_with(DEFAULT_KEEPALIVE_SEC, connack_callback) + + def test_iot_mqtt_client_disconnect_async(self): + disconnect_callback = MagicMock() + self.iot_mqtt_client.disconnectAsync(ackCallback=disconnect_callback) + self.mqtt_core_mock.disconnect_async.assert_called_once_with(disconnect_callback) + + def test_iot_mqtt_client_publish_async(self): + puback_callback = MagicMock() + self.iot_mqtt_client.publishAsync(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS, puback_callback) + self.mqtt_core_mock.publish_async.assert_called_once_with(DUMMY_TOPIC, DUMMY_PAYLOAD, DUMMY_QOS, + False, puback_callback) + + def test_iot_mqtt_client_subscribe_async(self): + suback_callback = MagicMock() + message_callback = MagicMock() + self.iot_mqtt_client.subscribeAsync(DUMMY_TOPIC, DUMMY_QOS, suback_callback, message_callback) + self.mqtt_core_mock.subscribe_async.assert_called_once_with(DUMMY_TOPIC, DUMMY_QOS, + suback_callback, message_callback) + + def 
test_iot_mqtt_client_unsubscribe_async(self): + unsuback_callback = MagicMock() + self.iot_mqtt_client.unsubscribeAsync(DUMMY_TOPIC, unsuback_callback) + self.mqtt_core_mock.unsubscribe_async.assert_called_once_with(DUMMY_TOPIC, unsuback_callback)