diff --git a/.env b/.env index 59f7df83e..c62498b00 100644 --- a/.env +++ b/.env @@ -11,6 +11,6 @@ scheme=https # Your version of Splunk (default: 6.2) version=9.0 # Bearer token for authentication -#bearerToken="" +#splunkToken="" # Session key for authentication -#sessionKey="" +#token="" diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml new file mode 100644 index 000000000..9df999425 --- /dev/null +++ b/.github/dependabot.yaml @@ -0,0 +1,7 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + target-branch: "master" + schedule: + interval: "weekly" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4d11da591..90ef8171b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,9 +9,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout source - uses: actions/checkout@v2.3.2 + uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.7 - name: Install dependencies @@ -19,7 +19,7 @@ jobs: - name: Build package run: python setup.py sdist - name: Publish package to PyPI - uses: pypa/gh-action-pypi-publish@v1.3.1 + uses: pypa/gh-action-pypi-publish@v1.8.10 with: user: __token__ password: ${{ secrets.pypi_password }} @@ -29,12 +29,11 @@ jobs: run: | rm -rf ./docs/_build tox -e docs - cd ./docs/_build/html && zip -r ../docs_html.zip . -x ".*" -x "__MACOSX" - name : Docs Upload uses: actions/upload-artifact@v3 with: - name: apidocs - path: docs/_build/docs_html.zip + name: python_sdk_docs + path: docs/_build/html # Test upload # - name: Publish package to TestPyPI # uses: pypa/gh-action-pypi-publish@master diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 06278e298..dc8bd2089 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,7 +1,7 @@ name: Python CI on: - [ push, pull_request ] + [ push, pull_request, workflow_dispatch ] jobs: build: @@ -11,21 +11,22 @@ jobs: matrix: os: - ubuntu-latest - python: [ 2.7, 3.7 ] + python: [ 3.7, 3.9, 3.13] splunk-version: + - "8.1" - "8.2" - "latest" fail-fast: false steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - - name: Run docker-compose - run: SPLUNK_VERSION=${{matrix.splunk-version}} docker-compose up -d + - name: Run docker compose + run: SPLUNK_VERSION=${{matrix.splunk-version}} docker compose up -d - name: Setup Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} @@ -34,3 +35,6 @@ jobs: - name: Test Execution run: tox -e py + fossa-scan: + uses: splunk/oss-scanning-public/.github/workflows/oss-scan.yml@main + secrets: inherit \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index e36b326c3..288543c6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,12 +1,83 @@ # Splunk Enterprise SDK for Python Changelog +## Version 2.1.0 + +### Changes +* [#516](https://github.com/splunk/splunk-sdk-python/pull/516) Added support for macros +* Remove deprecated `wrap_socket` in `Contex` class. +* Added explicit support for self signed certificates in https +* Enforce minimal required tls version in https connection +* Add support for python 3.13 +* [#559](https://github.com/splunk/splunk-sdk-python/pull/559/) Add exception logging + +## Version 2.0.2 + +### Minor changes +* Added six.py file back + + +## Version 2.0.1 + +### Bug fixes +* [#567](https://github.com/splunk/splunk-sdk-python/issues/567) Moved "deprecation" dependency + + +## Version 2.0.0 + +### Feature updates +* `ensure_binary`, `ensure_str` and `assert_regex` utility methods have been migrated from `six.py` to `splunklib/utils.py` + +### Major changes +* Removed code specific to Python 2 +* Removed six.py dependency +* Removed `__future__` imports +* Refactored and Updated `splunklib` and `tests` to utilize Python 3 features +* Updated CI test matrix to run with Python versions - 3.7 and 3.9 +* Refactored Code throwing `deprecation` warnings +* Refactored Code violating Pylint rules + +### Bug fixes +* [#527](https://github.com/splunk/splunk-sdk-python/issues/527) Added check for user roles +* Fix to access the metadata "finished" field in search commands using the v2 protocol +* Fix for error messages about ChunkedExternProcessor in splunkd.log for Custom Search Commands + + +## Version 1.7.4 + +### Bug fixes +* [#532](https://github.com/splunk/splunk-sdk-python/pull/532) Update encoding errors mode to 'replace' [[issue#505](https://github.com/splunk/splunk-sdk-python/issues/505)] +* [#507](https://github.com/splunk/splunk-sdk-python/pull/507) Masked sensitive data in logs [[issue#506](https://github.com/splunk/splunk-sdk-python/issues/506)] + +### Minor changes +* [#530](https://github.com/splunk/splunk-sdk-python/pull/530) Update GitHub CI build status in README and removed RTD(Read The Docs) reference + +## Version 1.7.3 + +### Bug fixes +* [#493](https://github.com/splunk/splunk-sdk-python/pull/493) Fixed file permission for event_writer.py file [[issue#487](https://github.com/splunk/splunk-sdk-python/issues/487)] +* [#500](https://github.com/splunk/splunk-sdk-python/pull/500) Replaced index_field with accelerated_field for kvstore [[issue#497](https://github.com/splunk/splunk-sdk-python/issues/497)] +* [#502](https://github.com/splunk/splunk-sdk-python/pull/502) Updated check for IPv6 addresses + +### Minor changes +* [#490](https://github.com/splunk/splunk-sdk-python/pull/490) Added ACL properties update feature +* [#495](https://github.com/splunk/splunk-sdk-python/pull/495) Added Splunk 8.1 in GitHub Actions Matrix +* [#485](https://github.com/splunk/splunk-sdk-python/pull/485) Added test case for cookie persistence +* [#503](https://github.com/splunk/splunk-sdk-python/pull/503) README updates on accessing "service" instance in CSC and ModularInput apps +* [#504](https://github.com/splunk/splunk-sdk-python/pull/504) Updated authentication token names in docs to reduce confusion +* [#494](https://github.com/splunk/splunk-sdk-python/pull/494) Reuse splunklib.__version__ in handler.request + +## Version 1.7.2 + +### Minor changes +* [#482](https://github.com/splunk/splunk-sdk-python/pull/482) Special handling related to the semantic versioning of specific Search APIs functional in Splunk Enterprise 9.0.2 and (Splunk Cloud 9.0.2209). These SDK changes will enable seamless transition between the APIs based on the version of the Splunk Enterprise in use + ## Version 1.7.1 ### Bug fixes * [#471](https://github.com/splunk/splunk-sdk-python/pull/471) Fixed support of Load Balancer "sticky sessions" (persistent cookies) [[issue#438](https://github.com/splunk/splunk-sdk-python/issues/438)] ### Minor changes -* [#466](https://github.com/splunk/splunk-sdk-python/pull/466) tests for CSC apps +* [#466](https://github.com/splunk/splunk-sdk-python/pull/466) Tests for CSC apps * [#467](https://github.com/splunk/splunk-sdk-python/pull/467) Added 'kwargs' parameter for Saved Search History function * [#475](https://github.com/splunk/splunk-sdk-python/pull/475) README updates @@ -16,10 +87,10 @@ * [#468](https://github.com/splunk/splunk-sdk-python/pull/468) SDK Support for splunkd search API changes ### Bug fixes -* [#464](https://github.com/splunk/splunk-sdk-python/pull/464) updated checks for wildcards in StoragePasswords [[issue#458](https://github.com/splunk/splunk-sdk-python/issues/458)] +* [#464](https://github.com/splunk/splunk-sdk-python/pull/464) Updated checks for wildcards in StoragePasswords [[issue#458](https://github.com/splunk/splunk-sdk-python/issues/458)] ### Minor changes -* [#463](https://github.com/splunk/splunk-sdk-python/pull/463) Preserve thirdparty cookies +* [#463](https://github.com/splunk/splunk-sdk-python/pull/463) Preserve third-party cookies ## Version 1.6.20 @@ -45,7 +116,7 @@ * Pre-requisite: Query parameter 'output_mode' must be set to 'json' * Improves performance by approx ~80-90% * ResultsReader is deprecated and will be removed in future releases (NOTE: Please migrate to JSONResultsReader) -* [#437](https://github.com/splunk/splunk-sdk-python/pull/437) added setup_logging() method in splunklib for logging +* [#437](https://github.com/splunk/splunk-sdk-python/pull/437) Added setup_logging() method in splunklib for logging * [#426](https://github.com/splunk/splunk-sdk-python/pull/426) Added new github_commit modular input example * [#392](https://github.com/splunk/splunk-sdk-python/pull/392) Break out search argument to option parsing for v2 custom search commands * [#384](https://github.com/splunk/splunk-sdk-python/pull/384) Added Float parameter validator for custom search commands @@ -61,17 +132,17 @@ ### Minor changes * [#440](https://github.com/splunk/splunk-sdk-python/pull/440) Github release workflow modified to generate docs * [#430](https://github.com/splunk/splunk-sdk-python/pull/430) Fix indentation in README -* [#429](https://github.com/splunk/splunk-sdk-python/pull/429) documented how to access modular input metadata +* [#429](https://github.com/splunk/splunk-sdk-python/pull/429) Documented how to access modular input metadata * [#427](https://github.com/splunk/splunk-sdk-python/pull/427) Replace .splunkrc with .env file in test and examples * [#424](https://github.com/splunk/splunk-sdk-python/pull/424) Float validator test fix -* [#423](https://github.com/splunk/splunk-sdk-python/pull/423) Python3 compatibility for ResponseReader.__str__() +* [#423](https://github.com/splunk/splunk-sdk-python/pull/423) Python 3 compatibility for ResponseReader.__str__() * [#422](https://github.com/splunk/splunk-sdk-python/pull/422) ordereddict and all its reference removed * [#421](https://github.com/splunk/splunk-sdk-python/pull/421) Update README.md * [#387](https://github.com/splunk/splunk-sdk-python/pull/387) Update filter.py * [#331](https://github.com/splunk/splunk-sdk-python/pull/331) Fix a couple of warnings spotted when running python 2.7 tests * [#330](https://github.com/splunk/splunk-sdk-python/pull/330) client: use six.string_types instead of basestring * [#329](https://github.com/splunk/splunk-sdk-python/pull/329) client: remove outdated comment in Index.submit -* [#262](https://github.com/splunk/splunk-sdk-python/pull/262) properly add parameters to request based on the method of the request +* [#262](https://github.com/splunk/splunk-sdk-python/pull/262) Properly add parameters to request based on the method of the request * [#237](https://github.com/splunk/splunk-sdk-python/pull/237) Don't output close tags if you haven't written a start tag * [#149](https://github.com/splunk/splunk-sdk-python/pull/149) "handlers" stanza missing in examples/searchcommands_template/default/logging.conf @@ -141,7 +212,7 @@ https://github.com/splunk/splunk-sdk-python/blob/develop/README.md#customization * Fixed regression in mod inputs which resulted in error ’file' object has no attribute 'readable’, by not forcing to text/bytes in mod inputs event writer any longer. ### Minor changes -* Minor updates to the splunklib search commands to support Python3 +* Minor updates to the splunklib search commands to support Python 3 ## Version 1.6.12 @@ -150,32 +221,32 @@ https://github.com/splunk/splunk-sdk-python/blob/develop/README.md#customization * Made modinput text consistent ### Bug fixes -* Changed permissions from 755 to 644 for python files to pass appinspect checks +* Changed permissions from 755 to 644 for Python files to pass Appinspect checks * Removed version check on ssl verify toggle ## Version 1.6.11 ### Bug Fix -* Fix custom search command V2 failures on Windows for Python3 +* Fix custom search command V2 failures on Windows for Python 3 ## Version 1.6.10 ### Bug Fix -* Fix long type gets wrong values on windows for python 2 +* Fix long type gets wrong values on Windows for Python 2 ## Version 1.6.9 ### Bug Fix -* Fix buffered input in python 3 +* Fix buffered input in Python 3 ## Version 1.6.8 ### Bug Fix -* Fix custom search command on python 3 on windows +* Fix custom search command on Python 3 on Windows ## Version 1.6.7 @@ -249,7 +320,7 @@ The following bugs have been fixed: ### Minor changes -* Use relative imports throughout the the SDK. +* Use relative imports throughout the SDK. * Performance improvement when constructing `Input` entity paths. @@ -388,7 +459,7 @@ The following bugs have been fixed: * Added a script (GenerateHelloCommand) to the searchcommand_app to generate a custom search command. -* Added a human readable argument titles to modular input examples. +* Added a human-readable argument titles to modular input examples. * Renamed the searchcommand `csv` module to `splunk_csv`. @@ -396,7 +467,7 @@ The following bugs have been fixed: * Now entities that contain slashes in their name can be created, accessed and deleted correctly. -* Fixed a perfomance issue with connecting to Splunk on Windows. +* Fixed a performance issue with connecting to Splunk on Windows. * Improved the `service.restart()` function. @@ -490,7 +561,7 @@ The following bugs have been fixed: ### Bug fixes * When running `setup.py dist` without running `setup.py build`, there is no - longer an `No such file or directory` error on the command line, and the + longer a `No such file or directory` error on the command line, and the command behaves as expected. * When setting the sourcetype of a modular input event, events are indexed @@ -512,7 +583,7 @@ The following bugs have been fixed: ### Bug fix * When running `setup.py dist` without running `setup.py build`, there is no - longer an `No such file or directory` error on the command line, and the + longer a `No such file or directory` error on the command line, and the command behaves as expected. * When setting the sourcetype of a modular input event, events are indexed properly. diff --git a/Makefile b/Makefile index 2810c6aec..58d53228b 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ docs: .PHONY: test test: @echo "$(ATTN_COLOR)==> test $(NO_COLOR)" - @tox -e py27,py37 + @tox -e py37,py39 .PHONY: test_specific test_specific: @@ -44,17 +44,17 @@ test_specific: .PHONY: test_smoke test_smoke: @echo "$(ATTN_COLOR)==> test_smoke $(NO_COLOR)" - @tox -e py27,py37 -- -m smoke + @tox -e py37,py39 -- -m smoke .PHONY: test_no_app test_no_app: @echo "$(ATTN_COLOR)==> test_no_app $(NO_COLOR)" - @tox -e py27,py37 -- -m "not app" + @tox -e py37,py39 -- -m "not app" .PHONY: test_smoke_no_app test_smoke_no_app: @echo "$(ATTN_COLOR)==> test_smoke_no_app $(NO_COLOR)" - @tox -e py27,py37 -- -m "smoke and not app" + @tox -e py37,py39 -- -m "smoke and not app" .PHONY: env env: diff --git a/README.md b/README.md index eb7ad8bd3..7cd66cc23 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,10 @@ -[![Build Status](https://travis-ci.org/splunk/splunk-sdk-python.svg?branch=master)](https://travis-ci.org/splunk/splunk-sdk-python) -[![Documentation Status](https://readthedocs.org/projects/splunk-python-sdk/badge/?version=latest)](https://splunk-python-sdk.readthedocs.io/en/latest/?badge=latest) +[![Build Status](https://github.com/splunk/splunk-sdk-python/actions/workflows/test.yml/badge.svg?branch=master)](https://github.com/splunk/splunk-sdk-python/actions/workflows/test.yml) + +[Reference Docs](https://dev.splunk.com/enterprise/reference) # The Splunk Enterprise Software Development Kit for Python -#### Version 1.7.1 +#### Version 2.1.0 The Splunk Enterprise Software Development Kit (SDK) for Python contains library code designed to enable developers to build applications using the Splunk platform. @@ -18,19 +19,19 @@ The Splunk developer platform enables developers to take advantage of the same t ## Get started with the Splunk Enterprise SDK for Python -The Splunk Enterprise SDK for Python contains library code, and it's examples are located in the [splunk-app-examples](https://github.com/splunk/splunk-app-examples) repository, that show how to programmatically interact with the Splunk platform for a variety of scenarios including searching, saved searches, data inputs, and many more, along with building complete applications. +The Splunk Enterprise SDK for Python contains library code, and its examples are located in the [splunk-app-examples](https://github.com/splunk/splunk-app-examples) repository. They show how to programmatically interact with the Splunk platform for a variety of scenarios including searching, saved searches, data inputs, and many more, along with building complete applications. ### Requirements Here's what you need to get going with the Splunk Enterprise SDK for Python. -* Python 2.7+ or Python 3.7. +* Python 3.7, Python 3.9 and Python 3.13 - The Splunk Enterprise SDK for Python has been tested with Python v2.7 and v3.7. + The Splunk Enterprise SDK for Python is compatible with python3 and has been tested with Python v3.7, v3.9 and v3.13. -* Splunk Enterprise 9.0 or 8.2 +* Splunk Enterprise 9.2 or 8.2 - The Splunk Enterprise SDK for Python has been tested with Splunk Enterprise 9.0 and 8.2 + The Splunk Enterprise SDK for Python has been tested with Splunk Enterprise 9.2, 8.2 and 8.1 If you haven't already installed Splunk Enterprise, download it [here](http://www.splunk.com/download). For more information, see the Splunk Enterprise [_Installation Manual_](https://docs.splunk.com/Documentation/Splunk/latest/Installation). @@ -60,7 +61,7 @@ Install the sources you cloned from GitHub: You'll need `docker` and `docker-compose` to get up and running using this method. ``` -make up SPLUNK_VERSION=9.0 +make up SPLUNK_VERSION=9.2 make wait_up make test make down @@ -109,11 +110,11 @@ here is an example of .env file: # Access scheme (default: https) scheme=https # Your version of Splunk Enterprise - version=9.0 + version=9.2 # Bearer token for authentication - #bearerToken= + #splunkToken= # Session key for authentication - #sessionKey= + #token= #### SDK examples @@ -127,7 +128,7 @@ The Splunk Enterprise SDK for Python contains a collection of unit tests. To run You can also run individual test files, which are located in **/splunk-sdk-python/tests**. To run a specific test, enter: - make specific_test_name + make test_specific The test suite uses Python's standard library, the built-in `unittest` library, `pytest`, and `tox`. @@ -209,7 +210,39 @@ class GeneratorTest(GeneratingCommand): checkpoint_dir = inputs.metadata["checkpoint_dir"] ``` -#### Optional:Set up logging for splunklib +### Access service object in Custom Search Command & Modular Input apps + +#### Custom Search Commands +* The service object is created from the Splunkd URI and session key passed to the command invocation the search results info file. +* Service object can be accessed using `self.service` in `generate`/`transform`/`stream`/`reduce` methods depending on the Custom Search Command. +* For Generating Custom Search Command + ```python + def generate(self): + # other code + + # access service object that can be used to connect Splunk Service + service = self.service + # to get Splunk Service Info + info = service.info + ``` + + + +#### Modular Inputs app: +* The service object is created from the Splunkd URI and session key passed to the command invocation on the modular input stream respectively. +* It is available as soon as the `Script.stream_events` method is called. +```python + def stream_events(self, inputs, ew): + # other code + + # access service object that can be used to connect Splunk Service + service = self.service + # to get Splunk Service Info + info = service.info +``` + + +### Optional:Set up logging for splunklib + The default level is WARNING, which means that only events of this level and above will be visible + To change a logging level we can call setup_logging() method and pass the logging level as an argument. + Optional: we can also pass log format and date format string as a method argument to modify default format diff --git a/docs/conf.py b/docs/conf.py index 5c3586315..6b8bfe1d5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -43,7 +43,7 @@ # General information about the project. project = u'Splunk SDK for Python' -copyright = u'2021, Splunk Inc' +copyright = u'2024, Splunk Inc' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the diff --git a/scripts/build-env.py b/scripts/build-env.py index e1a153d4a..fcf55ae14 100644 --- a/scripts/build-env.py +++ b/scripts/build-env.py @@ -1,4 +1,4 @@ -# Copyright 2011-2020 Splunk, Inc. +# Copyright 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain diff --git a/scripts/templates/env.template b/scripts/templates/env.template index a45851b6a..ac9ebe5c7 100644 --- a/scripts/templates/env.template +++ b/scripts/templates/env.template @@ -11,6 +11,6 @@ scheme=$scheme # Your version of Splunk (default: 6.2) version=$version # Bearer token for authentication -#bearerToken= +#splunkToken= # Session key for authentication -#sessionKey= \ No newline at end of file +#token= \ No newline at end of file diff --git a/scripts/test_specific.sh b/scripts/test_specific.sh index 1d9b0d494..9f751530e 100644 --- a/scripts/test_specific.sh +++ b/scripts/test_specific.sh @@ -1,2 +1,4 @@ echo "To run a specific test:" -echo " tox -e py27,py37 [test_file_path]::[test_name]" +echo " tox -e py37,py39 [test_file_path]::[TestClassName]::[test_method]" +echo "For Example, To run 'test_autologin' testcase from 'test_service.py' file run" +echo " tox -e py37 -- tests/test_service.py::ServiceTestCase::test_autologin" diff --git a/setup.py b/setup.py index 284c50983..d19a09eb1 100755 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,127 +14,15 @@ # License for the specific language governing permissions and limitations # under the License. -from setuptools import setup, Command - -import os -import sys +from setuptools import setup import splunklib -failed = False - -def run_test_suite(): - try: - import unittest2 as unittest - except ImportError: - import unittest - - def mark_failed(): - global failed - failed = True - - class _TrackingTextTestResult(unittest._TextTestResult): - def addError(self, test, err): - unittest._TextTestResult.addError(self, test, err) - mark_failed() - - def addFailure(self, test, err): - unittest._TextTestResult.addFailure(self, test, err) - mark_failed() - - class TrackingTextTestRunner(unittest.TextTestRunner): - def _makeResult(self): - return _TrackingTextTestResult( - self.stream, self.descriptions, self.verbosity) - - original_cwd = os.path.abspath(os.getcwd()) - os.chdir('tests') - suite = unittest.defaultTestLoader.discover('.') - runner = TrackingTextTestRunner(verbosity=2) - runner.run(suite) - os.chdir(original_cwd) - - return failed - - -def run_test_suite_with_junit_output(): - try: - import unittest2 as unittest - except ImportError: - import unittest - import xmlrunner - original_cwd = os.path.abspath(os.getcwd()) - os.chdir('tests') - suite = unittest.defaultTestLoader.discover('.') - xmlrunner.XMLTestRunner(output='../test-reports').run(suite) - os.chdir(original_cwd) - - -class CoverageCommand(Command): - """setup.py command to run code coverage of the test suite.""" - description = "Create an HTML coverage report from running the full test suite." - user_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - try: - import coverage - except ImportError: - print("Could not import coverage. Please install it and try again.") - exit(1) - cov = coverage.coverage(source=['splunklib']) - cov.start() - run_test_suite() - cov.stop() - cov.html_report(directory='coverage_report') - - -class TestCommand(Command): - """setup.py command to run the whole test suite.""" - description = "Run test full test suite." - user_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - failed = run_test_suite() - if failed: - sys.exit(1) - - -class JunitXmlTestCommand(Command): - """setup.py command to run the whole test suite.""" - description = "Run test full test suite with JUnit-formatted output." - user_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - run_test_suite_with_junit_output() - - setup( author="Splunk, Inc.", author_email="devinfo@splunk.com", - cmdclass={'coverage': CoverageCommand, - 'test': TestCommand, - 'testjunit': JunitXmlTestCommand}, - description="The Splunk Software Development Kit for Python.", license="http://www.apache.org/licenses/LICENSE-2.0", @@ -145,6 +33,10 @@ def run(self): "splunklib.modularinput", "splunklib.searchcommands"], + install_requires=[ + "deprecation", + ], + url="http://github.com/splunk/splunk-sdk-python", version=splunklib.__version__, diff --git a/sitecustomize.py b/sitecustomize.py index 21fbaebef..eb94c154b 100644 --- a/sitecustomize.py +++ b/sitecustomize.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain diff --git a/splunklib/__init__.py b/splunklib/__init__.py index b370003ea..5f83b2ac4 100644 --- a/splunklib/__init__.py +++ b/splunklib/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,12 +14,10 @@ """Python library for Splunk.""" -from __future__ import absolute_import -from splunklib.six.moves import map import logging DEFAULT_LOG_FORMAT = '%(asctime)s, Level=%(levelname)s, Pid=%(process)s, Logger=%(name)s, File=%(filename)s, ' \ - 'Line=%(lineno)s, %(message)s' + 'Line=%(lineno)s, %(message)s' DEFAULT_DATE_FORMAT = '%Y-%m-%d %H:%M:%S %Z' @@ -31,5 +29,6 @@ def setup_logging(level, log_format=DEFAULT_LOG_FORMAT, date_format=DEFAULT_DATE format=log_format, datefmt=date_format) -__version_info__ = (1, 7, 1) + +__version_info__ = (2, 1, 0) __version__ = ".".join(map(str, __version_info__)) diff --git a/splunklib/binding.py b/splunklib/binding.py index 7806bee49..25a099489 100644 --- a/splunklib/binding.py +++ b/splunklib/binding.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -24,30 +24,24 @@ :mod:`splunklib.client` module. """ -from __future__ import absolute_import - import io +import json import logging import socket import ssl -import sys import time from base64 import b64encode from contextlib import contextmanager from datetime import datetime from functools import wraps from io import BytesIO -from xml.etree.ElementTree import XML - -from splunklib import six -from splunklib.six.moves import urllib - -from .data import record +from urllib import parse +from http import client +from http.cookies import SimpleCookie +from xml.etree.ElementTree import XML, ParseError +from splunklib.data import record +from splunklib import __version__ -try: - from xml.etree.ElementTree import ParseError -except ImportError as e: - from xml.parsers.expat import ExpatError as ParseError logger = logging.getLogger(__name__) @@ -56,26 +50,59 @@ "connect", "Context", "handler", - "HTTPError" + "HTTPError", + "UrlEncoded", + "_encode", + "_make_cookie_header", + "_NoAuthenticationToken", + "namespace" ] +SENSITIVE_KEYS = ['Authorization', 'Cookie', 'action.email.auth_password', 'auth', 'auth_password', 'clear_password', 'clientId', + 'crc-salt', 'encr_password', 'oldpassword', 'passAuth', 'password', 'session', 'suppressionKey', + 'token'] + # If you change these, update the docstring # on _authority as well. DEFAULT_HOST = "localhost" DEFAULT_PORT = "8089" DEFAULT_SCHEME = "https" + def _log_duration(f): @wraps(f) def new_f(*args, **kwargs): start_time = datetime.now() val = f(*args, **kwargs) end_time = datetime.now() - logger.debug("Operation took %s", end_time-start_time) + logger.debug("Operation took %s", end_time - start_time) return val + return new_f +def mask_sensitive_data(data): + ''' + Masked sensitive fields data for logging purpose + ''' + if not isinstance(data, dict): + try: + data = json.loads(data) + except Exception as ex: + return data + + # json.loads will return "123"(str) as 123(int), so return the data if it's not 'dict' type + if not isinstance(data, dict): + return data + mdata = {} + for k, v in data.items(): + if k in SENSITIVE_KEYS: + mdata[k] = "******" + else: + mdata[k] = mask_sensitive_data(v) + return mdata + + def _parse_cookies(cookie_str, dictionary): """Tries to parse any key-value pairs of cookies in a string, then updates the the dictionary with any key-value pairs found. @@ -92,7 +119,7 @@ def _parse_cookies(cookie_str, dictionary): :param dictionary: A dictionary to update with any found key-value pairs. :type dictionary: ``dict`` """ - parsed_cookie = six.moves.http_cookies.SimpleCookie(cookie_str) + parsed_cookie = SimpleCookie(cookie_str) for cookie in parsed_cookie.values(): dictionary[cookie.key] = cookie.coded_value @@ -114,10 +141,11 @@ def _make_cookie_header(cookies): :return: ``str` An HTTP header cookie string. :rtype: ``str`` """ - return "; ".join("%s=%s" % (key, value) for key, value in cookies) + return "; ".join(f"{key}={value}" for key, value in cookies) + # Singleton values to eschew None -class _NoAuthenticationToken(object): +class _NoAuthenticationToken: """The value stored in a :class:`Context` or :class:`splunklib.client.Service` class that is not logged in. @@ -129,7 +157,6 @@ class that is not logged in. Likewise, after a ``Context`` or ``Service`` object has been logged out, the token is set to this value again. """ - pass class UrlEncoded(str): @@ -155,7 +182,7 @@ class UrlEncoded(str): **Example**:: import urllib - UrlEncoded('%s://%s' % (scheme, urllib.quote(host)), skip_encode=True) + UrlEncoded(f'{scheme}://{urllib.quote(host)}', skip_encode=True) If you append ``str`` strings and ``UrlEncoded`` strings, the result is also URL encoded. @@ -165,19 +192,19 @@ class UrlEncoded(str): UrlEncoded('ab c') + 'de f' == UrlEncoded('ab cde f') 'ab c' + UrlEncoded('de f') == UrlEncoded('ab cde f') """ + def __new__(self, val='', skip_encode=False, encode_slash=False): if isinstance(val, UrlEncoded): # Don't urllib.quote something already URL encoded. return val - elif skip_encode: + if skip_encode: return str.__new__(self, val) - elif encode_slash: - return str.__new__(self, urllib.parse.quote_plus(val)) - else: - # When subclassing str, just call str's __new__ method - # with your class and the value you want to have in the - # new string. - return str.__new__(self, urllib.parse.quote(val)) + if encode_slash: + return str.__new__(self, parse.quote_plus(val)) + # When subclassing str, just call str.__new__ method + # with your class and the value you want to have in the + # new string. + return str.__new__(self, parse.quote(val)) def __add__(self, other): """self + other @@ -187,8 +214,8 @@ def __add__(self, other): """ if isinstance(other, UrlEncoded): return UrlEncoded(str.__add__(self, other), skip_encode=True) - else: - return UrlEncoded(str.__add__(self, urllib.parse.quote(other)), skip_encode=True) + + return UrlEncoded(str.__add__(self, parse.quote(other)), skip_encode=True) def __radd__(self, other): """other + self @@ -198,8 +225,8 @@ def __radd__(self, other): """ if isinstance(other, UrlEncoded): return UrlEncoded(str.__radd__(self, other), skip_encode=True) - else: - return UrlEncoded(str.__add__(urllib.parse.quote(other), self), skip_encode=True) + + return UrlEncoded(str.__add__(parse.quote(other), self), skip_encode=True) def __mod__(self, fields): """Interpolation into ``UrlEncoded``s is disabled. @@ -208,15 +235,17 @@ def __mod__(self, fields): ``TypeError``. """ raise TypeError("Cannot interpolate into a UrlEncoded object.") + def __repr__(self): - return "UrlEncoded(%s)" % repr(urllib.parse.unquote(str(self))) + return f"UrlEncoded({repr(parse.unquote(str(self)))})" + @contextmanager def _handle_auth_error(msg): - """Handle reraising HTTP authentication errors as something clearer. + """Handle re-raising HTTP authentication errors as something clearer. If an ``HTTPError`` is raised with status 401 (access denied) in - the body of this context manager, reraise it as an + the body of this context manager, re-raise it as an ``AuthenticationError`` instead, with *msg* as its message. This function adds no round trips to the server. @@ -237,6 +266,7 @@ def _handle_auth_error(msg): else: raise + def _authentication(request_fun): """Decorator to handle autologin and authentication errors. @@ -269,12 +299,12 @@ def _authentication(request_fun): def f(): c.get("/services") return 42 - print _authentication(f) + print(_authentication(f)) """ + @wraps(request_fun) def wrapper(self, *args, **kwargs): - if self.token is _NoAuthenticationToken and \ - not self.has_cookies(): + if self.token is _NoAuthenticationToken and not self.has_cookies(): # Not yet logged in. if self.autologin and self.username and self.password: # This will throw an uncaught @@ -296,8 +326,7 @@ def wrapper(self, *args, **kwargs): # an AuthenticationError and give up. with _handle_auth_error("Autologin failed."): self.login() - with _handle_auth_error( - "Authentication Failed! If session token is used, it seems to have been expired."): + with _handle_auth_error("Authentication Failed! If session token is used, it seems to have been expired."): return request_fun(self, *args, **kwargs) elif he.status == 401 and not self.autologin: raise AuthenticationError( @@ -346,11 +375,13 @@ def _authority(scheme=DEFAULT_SCHEME, host=DEFAULT_HOST, port=DEFAULT_PORT): "http://splunk.utopia.net:471" """ - if ':' in host: + # check if host is an IPv6 address and not enclosed in [ ] + if ':' in host and not (host.startswith('[') and host.endswith(']')): # IPv6 addresses must be enclosed in [ ] in order to be well # formed. host = '[' + host + ']' - return UrlEncoded("%s://%s:%s" % (scheme, host, port), skip_encode=True) + return UrlEncoded(f"{scheme}://{host}:{port}", skip_encode=True) + # kwargs: sharing, owner, app def namespace(sharing=None, owner=None, app=None, **kwargs): @@ -405,7 +436,7 @@ def namespace(sharing=None, owner=None, app=None, **kwargs): n = binding.namespace(sharing="global", app="search") """ if sharing in ["system"]: - return record({'sharing': sharing, 'owner': "nobody", 'app': "system" }) + return record({'sharing': sharing, 'owner': "nobody", 'app': "system"}) if sharing in ["global", "app"]: return record({'sharing': sharing, 'owner': "nobody", 'app': app}) if sharing in ["user", None]: @@ -413,7 +444,7 @@ def namespace(sharing=None, owner=None, app=None, **kwargs): raise ValueError("Invalid value for argument: 'sharing'") -class Context(object): +class Context: """This class represents a context that encapsulates a splunkd connection. The ``Context`` class encapsulates the details of HTTP requests, @@ -432,8 +463,10 @@ class Context(object): :type port: ``integer`` :param scheme: The scheme for accessing the service (the default is "https"). :type scheme: "https" or "http" - :param verify: Enable (True) or disable (False) SSL verrification for https connections. + :param verify: Enable (True) or disable (False) SSL verification for https connections. :type verify: ``Boolean`` + :param self_signed_certificate: Specifies if self signed certificate is used + :type self_signed_certificate: ``Boolean`` :param sharing: The sharing mode for the namespace (the default is "user"). :type sharing: "global", "system", "app", or "user" :param owner: The owner context of the namespace (optional, the default is "None"). @@ -475,12 +508,14 @@ class Context(object): # Or if you already have a valid cookie c = binding.Context(cookie="splunkd_8089=...") """ + def __init__(self, handler=None, **kwargs): self.http = HttpLib(handler, kwargs.get("verify", False), key_file=kwargs.get("key_file"), - cert_file=kwargs.get("cert_file"), context=kwargs.get("context"), # Default to False for backward compat + cert_file=kwargs.get("cert_file"), context=kwargs.get("context"), + # Default to False for backward compat retries=kwargs.get("retries", 0), retryDelay=kwargs.get("retryDelay", 10)) self.token = kwargs.get("token", _NoAuthenticationToken) - if self.token is None: # In case someone explicitly passes token=None + if self.token is None: # In case someone explicitly passes token=None self.token = _NoAuthenticationToken self.scheme = kwargs.get("scheme", DEFAULT_SCHEME) self.host = kwargs.get("host", DEFAULT_HOST) @@ -493,6 +528,7 @@ def __init__(self, handler=None, **kwargs): self.bearerToken = kwargs.get("splunkToken", "") self.autologin = kwargs.get("autologin", False) self.additional_headers = kwargs.get("headers", []) + self._self_signed_certificate = kwargs.get("self_signed_certificate", True) # Store any cookies in the self.http._cookies dict if "cookie" in kwargs and kwargs['cookie'] not in [None, _NoAuthenticationToken]: @@ -530,9 +566,9 @@ def _auth_headers(self): if self.has_cookies(): return [("Cookie", _make_cookie_header(list(self.get_cookies().items())))] elif self.basic and (self.username and self.password): - token = 'Basic %s' % b64encode(("%s:%s" % (self.username, self.password)).encode('utf-8')).decode('ascii') + token = f'Basic {b64encode(("%s:%s" % (self.username, self.password)).encode("utf-8")).decode("ascii")}' elif self.bearerToken: - token = 'Bearer %s' % self.bearerToken + token = f'Bearer {self.bearerToken}' elif self.token is _NoAuthenticationToken: token = [] else: @@ -540,7 +576,7 @@ def _auth_headers(self): if self.token.startswith('Splunk '): token = self.token else: - token = 'Splunk %s' % self.token + token = f'Splunk {self.token}' if token: header.append(("Authorization", token)) if self.get_cookies(): @@ -571,7 +607,11 @@ def connect(self): """ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) if self.scheme == "https": - sock = ssl.wrap_socket(sock) + context = ssl.create_default_context() + context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 + context.check_hostname = not self._self_signed_certificate + context.verify_mode = ssl.CERT_NONE if self._self_signed_certificate else ssl.CERT_REQUIRED + sock = context.wrap_socket(sock, server_hostname=self.host) sock.connect((socket.gethostbyname(self.host), self.port)) return sock @@ -629,7 +669,7 @@ def delete(self, path_segment, owner=None, app=None, sharing=None, **query): """ path = self.authority + self._abspath(path_segment, owner=owner, app=app, sharing=sharing) - logger.debug("DELETE request to %s (body: %s)", path, repr(query)) + logger.debug("DELETE request to %s (body: %s)", path, mask_sensitive_data(query)) response = self.http.delete(path, self._auth_headers, **query) return response @@ -692,7 +732,7 @@ def get(self, path_segment, owner=None, app=None, headers=None, sharing=None, ** path = self.authority + self._abspath(path_segment, owner=owner, app=app, sharing=sharing) - logger.debug("GET request to %s (body: %s)", path, repr(query)) + logger.debug("GET request to %s (body: %s)", path, mask_sensitive_data(query)) all_headers = headers + self.additional_headers + self._auth_headers response = self.http.get(path, all_headers, **query) return response @@ -771,12 +811,7 @@ def post(self, path_segment, owner=None, app=None, sharing=None, headers=None, * path = self.authority + self._abspath(path_segment, owner=owner, app=app, sharing=sharing) - # To avoid writing sensitive data in debug logs - endpoint_having_sensitive_data = ["/storage/passwords"] - if any(endpoint in path for endpoint in endpoint_having_sensitive_data): - logger.debug("POST request to %s ", path) - else: - logger.debug("POST request to %s (body: %s)", path, repr(query)) + logger.debug("POST request to %s (body: %s)", path, mask_sensitive_data(query)) all_headers = headers + self.additional_headers + self._auth_headers response = self.http.post(path, all_headers, **query) return response @@ -838,13 +873,12 @@ def request(self, path_segment, method="GET", headers=None, body={}, headers = [] path = self.authority \ - + self._abspath(path_segment, owner=owner, - app=app, sharing=sharing) + + self._abspath(path_segment, owner=owner, + app=app, sharing=sharing) all_headers = headers + self.additional_headers + self._auth_headers logger.debug("%s request to %s (headers: %s, body: %s)", - method, path, str(all_headers), repr(body)) - + method, path, str(mask_sensitive_data(dict(all_headers))), mask_sensitive_data(body)) if body: body = _encode(**body) @@ -885,14 +919,14 @@ def login(self): """ if self.has_cookies() and \ - (not self.username and not self.password): + (not self.username and not self.password): # If we were passed session cookie(s), but no username or # password, then login is a nop, since we're automatically # logged in. return if self.token is not _NoAuthenticationToken and \ - (not self.username and not self.password): + (not self.username and not self.password): # If we were passed a session token, but no username or # password, then login is a nop, since we're automatically # logged in. @@ -914,11 +948,11 @@ def login(self): username=self.username, password=self.password, headers=self.additional_headers, - cookie="1") # In Splunk 6.2+, passing "cookie=1" will return the "set-cookie" header + cookie="1") # In Splunk 6.2+, passing "cookie=1" will return the "set-cookie" header body = response.body.read() session = XML(body).findtext("./sessionKey") - self.token = "Splunk %s" % session + self.token = f"Splunk {session}" return self except HTTPError as he: if he.status == 401: @@ -933,7 +967,7 @@ def logout(self): return self def _abspath(self, path_segment, - owner=None, app=None, sharing=None): + owner=None, app=None, sharing=None): """Qualifies *path_segment* into an absolute path for a URL. If *path_segment* is already absolute, returns it unchanged. @@ -985,12 +1019,11 @@ def _abspath(self, path_segment, # namespace. If only one of app and owner is specified, use # '-' for the other. if ns.app is None and ns.owner is None: - return UrlEncoded("/services/%s" % path_segment, skip_encode=skip_encode) + return UrlEncoded(f"/services/{path_segment}", skip_encode=skip_encode) oname = "nobody" if ns.owner is None else ns.owner aname = "system" if ns.app is None else ns.app - path = UrlEncoded("/servicesNS/%s/%s/%s" % (oname, aname, path_segment), - skip_encode=skip_encode) + path = UrlEncoded(f"/servicesNS/{oname}/{aname}/{path_segment}", skip_encode=skip_encode) return path @@ -1041,21 +1074,23 @@ def connect(**kwargs): c.login() return c + # Note: the error response schema supports multiple messages but we only # return the first, although we do return the body so that an exception # handler that wants to read multiple messages can do so. class HTTPError(Exception): """This exception is raised for HTTP responses that return an error.""" + def __init__(self, response, _message=None): status = response.status reason = response.reason body = response.body.read() try: detail = XML(body).findtext("./messages/msg") - except ParseError as err: + except ParseError: detail = body - message = "HTTP %d %s%s" % ( - status, reason, "" if detail is None else " -- %s" % detail) + detail_formatted = "" if detail is None else f" -- {detail}" + message = f"HTTP {status} {reason}{detail_formatted}" Exception.__init__(self, _message or message) self.status = status self.reason = reason @@ -1063,6 +1098,7 @@ def __init__(self, response, _message=None): self.body = body self._response = response + class AuthenticationError(HTTPError): """Raised when a login request to Splunk fails. @@ -1070,6 +1106,7 @@ class AuthenticationError(HTTPError): in a call to :meth:`Context.login` or :meth:`splunklib.client.Service.login`, this exception is raised. """ + def __init__(self, message, cause): # Put the body back in the response so that HTTPError's constructor can # read it again. @@ -1077,6 +1114,7 @@ def __init__(self, message, cause): HTTPError.__init__(self, cause._response, message) + # # The HTTP interface used by the Splunk binding layer abstracts the underlying # HTTP library using request & response 'messages' which are implemented as @@ -1104,16 +1142,17 @@ def __init__(self, message, cause): # 'foo=1&foo=2&foo=3'. def _encode(**kwargs): items = [] - for key, value in six.iteritems(kwargs): + for key, value in kwargs.items(): if isinstance(value, list): items.extend([(key, item) for item in value]) else: items.append((key, value)) - return urllib.parse.urlencode(items) + return parse.urlencode(items) + # Crack the given url into (scheme, host, port, path) def _spliturl(url): - parsed_url = urllib.parse.urlparse(url) + parsed_url = parse.urlparse(url) host = parsed_url.hostname port = parsed_url.port path = '?'.join((parsed_url.path, parsed_url.query)) if parsed_url.query else parsed_url.path @@ -1122,9 +1161,10 @@ def _spliturl(url): if port is None: port = DEFAULT_PORT return parsed_url.scheme, host, port, path + # Given an HTTP request handler, this wrapper objects provides a related # family of convenience methods built using that handler. -class HttpLib(object): +class HttpLib: """A set of convenient methods for making HTTP calls. ``HttpLib`` provides a general :meth:`request` method, and :meth:`delete`, @@ -1166,7 +1206,9 @@ class HttpLib(object): If using the default handler, SSL verification can be disabled by passing verify=False. """ - def __init__(self, custom_handler=None, verify=False, key_file=None, cert_file=None, context=None, retries=0, retryDelay=10): + + def __init__(self, custom_handler=None, verify=False, key_file=None, cert_file=None, context=None, retries=0, + retryDelay=10): if custom_handler is None: self.handler = handler(verify=verify, key_file=key_file, cert_file=cert_file, context=context) else: @@ -1227,7 +1269,7 @@ def get(self, url, headers=None, **kwargs): # the query to be encoded or it will get automatically URL # encoded by being appended to url. url = url + UrlEncoded('?' + _encode(**kwargs), skip_encode=True) - return self.request(url, { 'method': "GET", 'headers': headers }) + return self.request(url, {'method': "GET", 'headers': headers}) def post(self, url, headers=None, **kwargs): """Sends a POST request to a URL. @@ -1323,6 +1365,7 @@ class ResponseReader(io.RawIOBase): types of HTTP libraries used with this SDK. This class also provides a preview of the stream and a few useful predicates. """ + # For testing, you can use a StringIO as the argument to # ``ResponseReader`` instead of an ``httplib.HTTPResponse``. It # will work equally well. @@ -1332,10 +1375,7 @@ def __init__(self, response, connection=None): self._buffer = b'' def __str__(self): - if six.PY2: - return self.read() - else: - return str(self.read(), 'UTF-8') + return str(self.read(), 'UTF-8') @property def empty(self): @@ -1361,7 +1401,7 @@ def close(self): self._connection.close() self._response.close() - def read(self, size = None): + def read(self, size=None): """Reads a given number of characters from the response. :param size: The number of characters to read, or "None" to read the @@ -1414,7 +1454,7 @@ def connect(scheme, host, port): kwargs = {} if timeout is not None: kwargs['timeout'] = timeout if scheme == "http": - return six.moves.http_client.HTTPConnection(host, port, **kwargs) + return client.HTTPConnection(host, port, **kwargs) if scheme == "https": if key_file is not None: kwargs['key_file'] = key_file if cert_file is not None: kwargs['cert_file'] = cert_file @@ -1425,8 +1465,8 @@ def connect(scheme, host, port): # verify is True in elif branch and context is not None kwargs['context'] = context - return six.moves.http_client.HTTPSConnection(host, port, **kwargs) - raise ValueError("unsupported scheme: %s" % scheme) + return client.HTTPSConnection(host, port, **kwargs) + raise ValueError(f"unsupported scheme: {scheme}") def request(url, message, **kwargs): scheme, host, port, path = _spliturl(url) @@ -1434,10 +1474,10 @@ def request(url, message, **kwargs): head = { "Content-Length": str(len(body)), "Host": host, - "User-Agent": "splunk-sdk-python/1.7.1", + "User-Agent": "splunk-sdk-python/%s" % __version__, "Accept": "*/*", "Connection": "Close", - } # defaults + } # defaults for key, value in message["headers"]: head[key] = value method = message.get("method", "GET") diff --git a/splunklib/client.py b/splunklib/client.py index 85c559d0d..ee390c9ed 100644 --- a/splunklib/client.py +++ b/splunklib/client.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -54,7 +54,7 @@ are subclasses of :class:`Entity`. An ``Entity`` object has fields for its attributes, and methods that are specific to each kind of entity. For example:: - print my_app['author'] # Or: print my_app.author + print(my_app['author']) # Or: print(my_app.author) my_app.package() # Creates a compressed package of this application """ @@ -66,15 +66,13 @@ import socket from datetime import datetime, timedelta from time import sleep +from urllib import parse -from splunklib import six -from splunklib.six.moves import urllib - -from . import data -from .binding import (AuthenticationError, Context, HTTPError, UrlEncoded, - _encode, _make_cookie_header, _NoAuthenticationToken, - namespace) -from .data import record +from splunklib import data +from splunklib.data import record +from splunklib.binding import (AuthenticationError, Context, HTTPError, UrlEncoded, + _encode, _make_cookie_header, _NoAuthenticationToken, + namespace) logger = logging.getLogger(__name__) @@ -84,7 +82,8 @@ "OperationError", "IncomparableException", "Service", - "namespace" + "namespace", + "AuthenticationError" ] PATH_APPS = "apps/local/" @@ -102,11 +101,12 @@ PATH_JOBS = "search/jobs/" PATH_JOBS_V2 = "search/v2/jobs/" PATH_LOGGER = "/services/server/logger/" +PATH_MACROS = "configs/conf-macros/" PATH_MESSAGES = "messages/" PATH_MODULAR_INPUTS = "data/modular-inputs" PATH_ROLES = "authorization/roles/" PATH_SAVED_SEARCHES = "saved/searches/" -PATH_STANZA = "configs/conf-%s/%s" # (file, stanza) +PATH_STANZA = "configs/conf-%s/%s" # (file, stanza) PATH_USERS = "authentication/users/" PATH_RECEIVERS_STREAM = "/services/receivers/stream" PATH_RECEIVERS_SIMPLE = "/services/receivers/simple" @@ -116,45 +116,38 @@ XNAME_ENTRY = XNAMEF_ATOM % "entry" XNAME_CONTENT = XNAMEF_ATOM % "content" -MATCH_ENTRY_CONTENT = "%s/%s/*" % (XNAME_ENTRY, XNAME_CONTENT) +MATCH_ENTRY_CONTENT = f"{XNAME_ENTRY}/{XNAME_CONTENT}/*" class IllegalOperationException(Exception): """Thrown when an operation is not possible on the Splunk instance that a :class:`Service` object is connected to.""" - pass class IncomparableException(Exception): """Thrown when trying to compare objects (using ``==``, ``<``, ``>``, and so on) of a type that doesn't support it.""" - pass class AmbiguousReferenceException(ValueError): """Thrown when the name used to fetch an entity matches more than one entity.""" - pass class InvalidNameException(Exception): """Thrown when the specified name contains characters that are not allowed in Splunk entity names.""" - pass class NoSuchCapability(Exception): """Thrown when the capability that has been referred to doesn't exist.""" - pass class OperationError(Exception): - """Raised for a failed operation, such as a time out.""" - pass + """Raised for a failed operation, such as a timeout.""" class NotSupportedError(Exception): """Raised for operations that are not supported on a given object.""" - pass def _trailing(template, *targets): @@ -190,8 +183,9 @@ def _trailing(template, *targets): def _filter_content(content, *args): if len(args) > 0: return record((k, content[k]) for k in args) - return record((k, v) for k, v in six.iteritems(content) - if k not in ['eai:acl', 'eai:attributes', 'type']) + return record((k, v) for k, v in content.items() + if k not in ['eai:acl', 'eai:attributes', 'type']) + # Construct a resource path from the given base path + resource name def _path(base, name): @@ -221,10 +215,9 @@ def _load_atom_entries(response): # its state wrapped in another element, but at the top level. # For example, in XML, it returns ... instead of # .... - else: - entries = r.get('entry', None) - if entries is None: return None - return entries if isinstance(entries, list) else [entries] + entries = r.get('entry', None) + if entries is None: return None + return entries if isinstance(entries, list) else [entries] # Load the sid from the body of the given response @@ -250,7 +243,7 @@ def _parse_atom_entry(entry): metadata = _parse_atom_metadata(content) # Filter some of the noise out of the content record - content = record((k, v) for k, v in six.iteritems(content) + content = record((k, v) for k, v in content.items() if k not in ['eai:acl', 'eai:attributes']) if 'type' in content: @@ -289,6 +282,7 @@ def _parse_atom_metadata(content): return record({'access': access, 'fields': fields}) + # kwargs: scheme, host, port, app, owner, username, password def connect(**kwargs): """This function connects and logs in to a Splunk instance. @@ -417,10 +411,12 @@ class Service(_BaseService): # Or if you already have a valid cookie s = client.Service(cookie="splunkd_8089=...") """ + def __init__(self, **kwargs): - super(Service, self).__init__(**kwargs) + super().__init__(**kwargs) self._splunk_version = None self._kvstore_owner = None + self._instance_type = None @property def apps(self): @@ -537,8 +533,7 @@ def modular_input_kinds(self): """ if self.splunk_version >= (5,): return ReadOnlyCollection(self, PATH_MODULAR_INPUTS, item=ModularInputKind) - else: - raise IllegalOperationException("Modular inputs are not supported before Splunk version 5.") + raise IllegalOperationException("Modular inputs are not supported before Splunk version 5.") @property def storage_passwords(self): @@ -572,7 +567,7 @@ def parse(self, query, **kwargs): :type kwargs: ``dict`` :return: A semantic map of the parsed search query. """ - if self.splunk_version >= (9,): + if not self.disable_v2_api: return self.post("search/v2/parser", q=query, **kwargs) return self.get("search/parser", q=query, **kwargs) @@ -588,7 +583,7 @@ def restart(self, timeout=None): :param timeout: A timeout period, in seconds. :type timeout: ``integer`` """ - msg = { "value": "Restart requested by " + self.username + "via the Splunk SDK for Python"} + msg = {"value": "Restart requested by " + self.username + "via the Splunk SDK for Python"} # This message will be deleted once the server actually restarts. self.messages.create(name="restart_required", **msg) result = self.post("/services/server/control/restart") @@ -673,6 +668,15 @@ def saved_searches(self): """ return SavedSearches(self) + @property + def macros(self): + """Returns the collection of macros. + + :return: A :class:`Macros` collection of :class:`Macro` + entities. + """ + return Macros(self) + @property def settings(self): """Returns the configuration settings for this instance of Splunk. @@ -692,9 +696,25 @@ def splunk_version(self): :return: A ``tuple`` of ``integers``. """ if self._splunk_version is None: - self._splunk_version = tuple([int(p) for p in self.info['version'].split('.')]) + self._splunk_version = tuple(int(p) for p in self.info['version'].split('.')) return self._splunk_version + @property + def splunk_instance(self): + if self._instance_type is None : + splunk_info = self.info + if hasattr(splunk_info, 'instance_type') : + self._instance_type = splunk_info['instance_type'] + else: + self._instance_type = '' + return self._instance_type + + @property + def disable_v2_api(self): + if self.splunk_instance.lower() == 'cloud': + return self.splunk_version < (9,0,2209) + return self.splunk_version < (9,0,2) + @property def kvstore_owner(self): """Returns the KVStore owner for this instance of Splunk. @@ -734,13 +754,14 @@ def users(self): return Users(self) -class Endpoint(object): +class Endpoint: """This class represents individual Splunk resources in the Splunk REST API. An ``Endpoint`` object represents a URI, such as ``/services/saved/searches``. This class provides the common functionality of :class:`Collection` and :class:`Entity` (essentially HTTP GET and POST methods). """ + def __init__(self, service, path): self.service = service self.path = path @@ -757,11 +778,11 @@ def get_api_version(self, path): # Default to v1 if undefined in the path # For example, "/services/search/jobs" is using API v1 api_version = 1 - + versionSearch = re.search('(?:servicesNS\/[^/]+\/[^/]+|services)\/[^/]+\/v(\d+)\/', path) if versionSearch: api_version = int(versionSearch.group(1)) - + return api_version def get(self, path_segment="", owner=None, app=None, sharing=None, **query): @@ -835,7 +856,7 @@ def get(self, path_segment="", owner=None, app=None, sharing=None, **query): # - Fallback from v2+ to v1 if Splunk Version is < 9. # if api_version >= 2 and ('search' in query and path.endswith(tuple(["results_preview", "events", "results"])) or self.service.splunk_version < (9,)): # path = path.replace(PATH_JOBS_V2, PATH_JOBS) - + if api_version == 1: if isinstance(path, UrlEncoded): path = UrlEncoded(path.replace(PATH_JOBS_V2, PATH_JOBS), skip_encode=True) @@ -894,14 +915,14 @@ def post(self, path_segment="", owner=None, app=None, sharing=None, **query): apps.get('nonexistant/path') # raises HTTPError s.logout() apps.get() # raises AuthenticationError - """ + """ if path_segment.startswith('/'): path = path_segment else: if not self.path.endswith('/') and path_segment != "": self.path = self.path + '/' path = self.service._abspath(self.path + path_segment, owner=owner, app=app, sharing=sharing) - + # Get the API version from the path api_version = self.get_api_version(path) @@ -910,7 +931,7 @@ def post(self, path_segment="", owner=None, app=None, sharing=None, **query): # - Fallback from v2+ to v1 if Splunk Version is < 9. # if api_version >= 2 and ('search' in query and path.endswith(tuple(["results_preview", "events", "results"])) or self.service.splunk_version < (9,)): # path = path.replace(PATH_JOBS_V2, PATH_JOBS) - + if api_version == 1: if isinstance(path, UrlEncoded): path = UrlEncoded(path.replace(PATH_JOBS_V2, PATH_JOBS), skip_encode=True) @@ -986,7 +1007,6 @@ def __init__(self, service, path, **kwargs): self._state = None if not kwargs.get('skip_refresh', False): self.refresh(kwargs.get('state', None)) # "Prefresh" - return def __contains__(self, item): try: @@ -1011,24 +1031,21 @@ def __eq__(self, other): but then ``x != saved_searches['asearch']``. whether or not there was a change on the server. Rather than - try to do something fancy, we simple declare that equality is + try to do something fancy, we simply declare that equality is undefined for Entities. Makes no roundtrips to the server. """ - raise IncomparableException( - "Equality is undefined for objects of class %s" % \ - self.__class__.__name__) + raise IncomparableException(f"Equality is undefined for objects of class {self.__class__.__name__}") def __getattr__(self, key): # Called when an attribute was not found by the normal method. In this # case we try to find it in self.content and then self.defaults. if key in self.state.content: return self.state.content[key] - elif key in self.defaults: + if key in self.defaults: return self.defaults[key] - else: - raise AttributeError(key) + raise AttributeError(key) def __getitem__(self, key): # getattr attempts to find a field on the object in the normal way, @@ -1044,9 +1061,8 @@ def _load_atom_entry(self, response): apps = [ele.entry.content.get('eai:appName') for ele in elem] raise AmbiguousReferenceException( - "Fetch from server returned multiple entries for name '%s' in apps %s." % (elem[0].entry.title, apps)) - else: - return elem.entry + f"Fetch from server returned multiple entries for name '{elem[0].entry.title}' in apps {apps}.") + return elem.entry # Load the entity state record from the given response def _load_state(self, response): @@ -1079,17 +1095,15 @@ def _proper_namespace(self, owner=None, app=None, sharing=None): :param sharing: :return: """ - if owner is None and app is None and sharing is None: # No namespace provided + if owner is None and app is None and sharing is None: # No namespace provided if self._state is not None and 'access' in self._state: return (self._state.access.owner, self._state.access.app, self._state.access.sharing) - else: - return (self.service.namespace['owner'], + return (self.service.namespace['owner'], self.service.namespace['app'], self.service.namespace['sharing']) - else: - return (owner,app,sharing) + return owner, app, sharing def delete(self): owner, app, sharing = self._proper_namespace() @@ -1097,11 +1111,11 @@ def delete(self): def get(self, path_segment="", owner=None, app=None, sharing=None, **query): owner, app, sharing = self._proper_namespace(owner, app, sharing) - return super(Entity, self).get(path_segment, owner=owner, app=app, sharing=sharing, **query) + return super().get(path_segment, owner=owner, app=app, sharing=sharing, **query) def post(self, path_segment="", owner=None, app=None, sharing=None, **query): owner, app, sharing = self._proper_namespace(owner, app, sharing) - return super(Entity, self).post(path_segment, owner=owner, app=app, sharing=sharing, **query) + return super().post(path_segment, owner=owner, app=app, sharing=sharing, **query) def refresh(self, state=None): """Refreshes the state of this entity. @@ -1189,8 +1203,8 @@ def read(self, response): # In lower layers of the SDK, we end up trying to URL encode # text to be dispatched via HTTP. However, these links are already # URL encoded when they arrive, and we need to mark them as such. - unquoted_links = dict([(k, UrlEncoded(v, skip_encode=True)) - for k,v in six.iteritems(results['links'])]) + unquoted_links = dict((k, UrlEncoded(v, skip_encode=True)) + for k, v in results['links'].items()) results['links'] = unquoted_links return results @@ -1199,6 +1213,36 @@ def reload(self): self.post("_reload") return self + def acl_update(self, **kwargs): + """To update Access Control List (ACL) properties for an endpoint. + + :param kwargs: Additional entity-specific arguments (required). + + - "owner" (``string``): The Splunk username, such as "admin". A value of "nobody" means no specific user (required). + + - "sharing" (``string``): A mode that indicates how the resource is shared. The sharing mode can be "user", "app", "global", or "system" (required). + + :type kwargs: ``dict`` + + **Example**:: + + import splunklib.client as client + service = client.connect(...) + saved_search = service.saved_searches["name"] + saved_search.acl_update(sharing="app", owner="nobody", app="search", **{"perms.read": "admin, nobody"}) + """ + if "body" not in kwargs: + kwargs = {"body": kwargs} + + if "sharing" not in kwargs["body"]: + raise ValueError("Required argument 'sharing' is missing.") + if "owner" not in kwargs["body"]: + raise ValueError("Required argument 'owner' is missing.") + + self.post("acl", **kwargs) + self.refresh() + return self + @property def state(self): """Returns the entity's state record. @@ -1235,7 +1279,7 @@ def update(self, **kwargs): """ # The peculiarity in question: the REST API creates a new # Entity if we pass name in the dictionary, instead of the - # expected behavior of updating this Entity. Therefore we + # expected behavior of updating this Entity. Therefore, we # check for 'name' in kwargs and throw an error if it is # there. if 'name' in kwargs: @@ -1248,9 +1292,10 @@ class ReadOnlyCollection(Endpoint): """This class represents a read-only collection of entities in the Splunk instance. """ + def __init__(self, service, path, item=Entity): Endpoint.__init__(self, service, path) - self.item = item # Item accessor + self.item = item # Item accessor self.null_count = -1 def __contains__(self, name): @@ -1282,7 +1327,7 @@ def __getitem__(self, key): name. Where there is no conflict, ``__getitem__`` will fetch the - entity given just the name. If there is a conflict and you + entity given just the name. If there is a conflict, and you pass just a name, it will raise a ``ValueError``. In that case, add the namespace as a second argument. @@ -1329,13 +1374,13 @@ def __getitem__(self, key): response = self.get(key) entries = self._load_list(response) if len(entries) > 1: - raise AmbiguousReferenceException("Found multiple entities named '%s'; please specify a namespace." % key) - elif len(entries) == 0: + raise AmbiguousReferenceException( + f"Found multiple entities named '{key}'; please specify a namespace.") + if len(entries) == 0: raise KeyError(key) - else: - return entries[0] + return entries[0] except HTTPError as he: - if he.status == 404: # No entity matching key and namespace. + if he.status == 404: # No entity matching key and namespace. raise KeyError(key) else: raise @@ -1358,7 +1403,7 @@ def __iter__(self, **kwargs): c = client.connect(...) saved_searches = c.saved_searches for entity in saved_searches: - print "Saved search named %s" % entity.name + print(f"Saved search named {entity.name}") """ for item in self.iter(**kwargs): @@ -1399,13 +1444,12 @@ def _entity_path(self, state): # This has been factored out so that it can be easily # overloaded by Configurations, which has to switch its # entities' endpoints from its own properties/ to configs/. - raw_path = urllib.parse.unquote(state.links.alternate) + raw_path = parse.unquote(state.links.alternate) if 'servicesNS/' in raw_path: return _trailing(raw_path, 'servicesNS/', '/', '/') - elif 'services/' in raw_path: + if 'services/' in raw_path: return _trailing(raw_path, 'services/') - else: - return raw_path + return raw_path def _load_list(self, response): """Converts *response* to a list of entities. @@ -1568,8 +1612,6 @@ def list(self, count=None, **kwargs): return list(self.iter(count=count, **kwargs)) - - class Collection(ReadOnlyCollection): """A collection of entities. @@ -1643,8 +1685,8 @@ def create(self, name, **params): applications = s.apps new_app = applications.create("my_fake_app") """ - if not isinstance(name, six.string_types): - raise InvalidNameException("%s is not a valid name for an entity." % name) + if not isinstance(name, str): + raise InvalidNameException(f"{name} is not a valid name for an entity.") if 'namespace' in params: namespace = params.pop('namespace') params['owner'] = namespace.owner @@ -1656,14 +1698,13 @@ def create(self, name, **params): # This endpoint doesn't return the content of the new # item. We have to go fetch it ourselves. return self[name] - else: - entry = atom.entry - state = _parse_atom_entry(entry) - entity = self.item( - self.service, - self._entity_path(state), - state=state) - return entity + entry = atom.entry + state = _parse_atom_entry(entry) + entity = self.item( + self.service, + self._entity_path(state), + state=state) + return entity def delete(self, name, **params): """Deletes a specified entity from the collection. @@ -1703,7 +1744,7 @@ def delete(self, name, **params): # has already been deleted, and we reraise it as a # KeyError. if he.status == 404: - raise KeyError("No such entity %s" % name) + raise KeyError(f"No such entity {name}") else: raise return self @@ -1754,14 +1795,13 @@ def get(self, name="", owner=None, app=None, sharing=None, **query): """ name = UrlEncoded(name, encode_slash=True) - return super(Collection, self).get(name, owner, app, sharing, **query) - - + return super().get(name, owner, app, sharing, **query) class ConfigurationFile(Collection): """This class contains all of the stanzas from one configuration file. """ + # __init__'s arguments must match those of an Entity, not a # Collection, since it is being created as the elements of a # Configurations, which is a Collection subclass. @@ -1778,6 +1818,7 @@ class Configurations(Collection): stanzas. This collection is unusual in that the values in it are themselves collections of :class:`ConfigurationFile` objects. """ + def __init__(self, service): Collection.__init__(self, service, PATH_PROPERTIES, item=ConfigurationFile) if self.service.namespace.owner == '-' or self.service.namespace.app == '-': @@ -1792,10 +1833,10 @@ def __getitem__(self, key): # This screws up the default implementation of __getitem__ from Collection, which thinks # that multiple entities means a name collision, so we have to override it here. try: - response = self.get(key) + self.get(key) return ConfigurationFile(self.service, PATH_CONF % key, state={'title': key}) except HTTPError as he: - if he.status == 404: # No entity matching key + if he.status == 404: # No entity matching key raise KeyError(key) else: raise @@ -1804,13 +1845,12 @@ def __contains__(self, key): # configs/conf-{name} never returns a 404. We have to post to properties/{name} # in order to find out if a configuration exists. try: - response = self.get(key) + self.get(key) return True except HTTPError as he: - if he.status == 404: # No entity matching key + if he.status == 404: # No entity matching key return False - else: - raise + raise def create(self, name): """ Creates a configuration file named *name*. @@ -1826,15 +1866,14 @@ def create(self, name): # This has to be overridden to handle the plumbing of creating # a ConfigurationFile (which is a Collection) instead of some # Entity. - if not isinstance(name, six.string_types): - raise ValueError("Invalid name: %s" % repr(name)) + if not isinstance(name, str): + raise ValueError(f"Invalid name: {repr(name)}") response = self.post(__conf=name) if response.status == 303: return self[name] - elif response.status == 201: + if response.status == 201: return ConfigurationFile(self.service, PATH_CONF % name, item=Stanza, state={'title': name}) - else: - raise ValueError("Unexpected status code %s returned from creating a stanza" % response.status) + raise ValueError(f"Unexpected status code {response.status} returned from creating a stanza") def delete(self, key): """Raises `IllegalOperationException`.""" @@ -1873,10 +1912,11 @@ def __len__(self): class StoragePassword(Entity): """This class contains a storage password. """ + def __init__(self, service, path, **kwargs): state = kwargs.get('state', None) kwargs['skip_refresh'] = kwargs.get('skip_refresh', state is not None) - super(StoragePassword, self).__init__(service, path, **kwargs) + super().__init__(service, path, **kwargs) self._state = state @property @@ -1900,8 +1940,11 @@ class StoragePasswords(Collection): """This class provides access to the storage passwords from this Splunk instance. Retrieve this collection using :meth:`Service.storage_passwords`. """ + def __init__(self, service): - super(StoragePasswords, self).__init__(service, PATH_STORAGE_PASSWORDS, item=StoragePassword) + if service.namespace.owner == '-' or service.namespace.app == '-': + raise ValueError("StoragePasswords cannot have wildcards in namespace.") + super().__init__(service, PATH_STORAGE_PASSWORDS, item=StoragePassword) def create(self, password, username, realm=None): """ Creates a storage password. @@ -1918,11 +1961,8 @@ def create(self, password, username, realm=None): :return: The :class:`StoragePassword` object created. """ - if self.service.namespace.owner == '-' or self.service.namespace.app == '-': - raise ValueError("While creating StoragePasswords, namespace cannot have wildcards.") - - if not isinstance(username, six.string_types): - raise ValueError("Invalid name: %s" % repr(username)) + if not isinstance(username, str): + raise ValueError(f"Invalid name: {repr(username)}") if realm is None: response = self.post(password=password, name=username) @@ -1930,7 +1970,7 @@ def create(self, password, username, realm=None): response = self.post(password=password, realm=realm, name=username) if response.status != 201: - raise ValueError("Unexpected status code %s returned from creating a stanza" % response.status) + raise ValueError(f"Unexpected status code {response.status} returned from creating a stanza") entries = _load_atom_entries(response) state = _parse_atom_entry(entries[0]) @@ -1952,9 +1992,6 @@ def delete(self, username, realm=None): :return: The `StoragePassword` collection. :rtype: ``self`` """ - if self.service.namespace.owner == '-' or self.service.namespace.app == '-': - raise ValueError("app context must be specified when removing a password.") - if realm is None: # This case makes the username optional, so # the full name can be passed in as realm. @@ -1973,6 +2010,7 @@ def delete(self, username, realm=None): class AlertGroup(Entity): """This class represents a group of fired alerts for a saved search. Access it using the :meth:`alerts` property.""" + def __init__(self, service, path, **kwargs): Entity.__init__(self, service, path, **kwargs) @@ -2001,6 +2039,7 @@ class Indexes(Collection): """This class contains the collection of indexes in this Splunk instance. Retrieve this collection using :meth:`Service.indexes`. """ + def get_default(self): """ Returns the name of the default index. @@ -2028,6 +2067,7 @@ def delete(self, name): class Index(Entity): """This class represents an index and provides different operations, such as cleaning the index, writing to the index, and so forth.""" + def __init__(self, service, path, **kwargs): Entity.__init__(self, service, path, **kwargs) @@ -2044,26 +2084,26 @@ def attach(self, host=None, source=None, sourcetype=None): :return: A writable socket. """ - args = { 'index': self.name } + args = {'index': self.name} if host is not None: args['host'] = host if source is not None: args['source'] = source if sourcetype is not None: args['sourcetype'] = sourcetype - path = UrlEncoded(PATH_RECEIVERS_STREAM + "?" + urllib.parse.urlencode(args), skip_encode=True) + path = UrlEncoded(PATH_RECEIVERS_STREAM + "?" + parse.urlencode(args), skip_encode=True) - cookie_or_auth_header = "Authorization: Splunk %s\r\n" % \ - (self.service.token if self.service.token is _NoAuthenticationToken - else self.service.token.replace("Splunk ", "")) + cookie_header = self.service.token if self.service.token is _NoAuthenticationToken else self.service.token.replace("Splunk ", "") + cookie_or_auth_header = f"Authorization: Splunk {cookie_header}\r\n" # If we have cookie(s), use them instead of "Authorization: ..." if self.service.has_cookies(): - cookie_or_auth_header = "Cookie: %s\r\n" % _make_cookie_header(self.service.get_cookies().items()) + cookie_header = _make_cookie_header(self.service.get_cookies().items()) + cookie_or_auth_header = f"Cookie: {cookie_header}\r\n" # Since we need to stream to the index connection, we have to keep # the connection open and use the Splunk extension headers to note # the input mode sock = self.service.connect() - headers = [("POST %s HTTP/1.1\r\n" % str(self.service._abspath(path))).encode('utf-8'), - ("Host: %s:%s\r\n" % (self.service.host, int(self.service.port))).encode('utf-8'), + headers = [f"POST {str(self.service._abspath(path))} HTTP/1.1\r\n".encode('utf-8'), + f"Host: {self.service.host}:{int(self.service.port)}\r\n".encode('utf-8'), b"Accept-Encoding: identity\r\n", cookie_or_auth_header.encode('utf-8'), b"X-Splunk-Input-Mode: Streaming\r\n", @@ -2125,8 +2165,7 @@ def clean(self, timeout=60): ftp = self['frozenTimePeriodInSecs'] was_disabled_initially = self.disabled try: - if (not was_disabled_initially and \ - self.service.splunk_version < (5,)): + if not was_disabled_initially and self.service.splunk_version < (5,): # Need to disable the index first on Splunk 4.x, # but it doesn't work to disable it on 5.0. self.disable() @@ -2136,17 +2175,17 @@ def clean(self, timeout=60): # Wait until event count goes to 0. start = datetime.now() diff = timedelta(seconds=timeout) - while self.content.totalEventCount != '0' and datetime.now() < start+diff: + while self.content.totalEventCount != '0' and datetime.now() < start + diff: sleep(1) self.refresh() if self.content.totalEventCount != '0': - raise OperationError("Cleaning index %s took longer than %s seconds; timing out." % (self.name, timeout)) + raise OperationError( + f"Cleaning index {self.name} took longer than {timeout} seconds; timing out.") finally: # Restore original values self.update(maxTotalDataSizeMB=tds, frozenTimePeriodInSecs=ftp) - if (not was_disabled_initially and \ - self.service.splunk_version < (5,)): + if not was_disabled_initially and self.service.splunk_version < (5,): # Re-enable the index if it was originally enabled and we messed with it. self.enable() @@ -2174,7 +2213,7 @@ def submit(self, event, host=None, source=None, sourcetype=None): :return: The :class:`Index`. """ - args = { 'index': self.name } + args = {'index': self.name} if host is not None: args['host'] = host if source is not None: args['source'] = source if sourcetype is not None: args['sourcetype'] = sourcetype @@ -2208,6 +2247,7 @@ class Input(Entity): typed input classes and is also used when the client does not recognize an input kind. """ + def __init__(self, service, path, kind=None, **kwargs): # kind can be omitted (in which case it is inferred from the path) # Otherwise, valid values are the paths from data/inputs ("udp", @@ -2218,7 +2258,7 @@ def __init__(self, service, path, kind=None, **kwargs): path_segments = path.split('/') i = path_segments.index('inputs') + 1 if path_segments[i] == 'tcp': - self.kind = path_segments[i] + '/' + path_segments[i+1] + self.kind = path_segments[i] + '/' + path_segments[i + 1] else: self.kind = path_segments[i] else: @@ -2244,7 +2284,7 @@ def update(self, **kwargs): # UDP and TCP inputs require special handling due to their restrictToHost # field. For all other inputs kinds, we can dispatch to the superclass method. if self.kind not in ['tcp', 'splunktcp', 'tcp/raw', 'tcp/cooked', 'udp']: - return super(Input, self).update(**kwargs) + return super().update(**kwargs) else: # The behavior of restrictToHost is inconsistent across input kinds and versions of Splunk. # In Splunk 4.x, the name of the entity is only the port, independent of the value of @@ -2262,11 +2302,11 @@ def update(self, **kwargs): if 'restrictToHost' in kwargs: raise IllegalOperationException("Cannot set restrictToHost on an existing input with the SDK.") - elif 'restrictToHost' in self._state.content and self.kind != 'udp': + if 'restrictToHost' in self._state.content and self.kind != 'udp': to_update['restrictToHost'] = self._state.content['restrictToHost'] # Do the actual update operation. - return super(Input, self).update(**to_update) + return super().update(**to_update) # Inputs is a "kinded" collection, which is a heterogenous collection where @@ -2293,13 +2333,12 @@ def __getitem__(self, key): response = self.get(self.kindpath(kind) + "/" + key) entries = self._load_list(response) if len(entries) > 1: - raise AmbiguousReferenceException("Found multiple inputs of kind %s named %s." % (kind, key)) - elif len(entries) == 0: + raise AmbiguousReferenceException(f"Found multiple inputs of kind {kind} named {key}.") + if len(entries) == 0: raise KeyError((key, kind)) - else: - return entries[0] + return entries[0] except HTTPError as he: - if he.status == 404: # No entity matching kind and key + if he.status == 404: # No entity matching kind and key raise KeyError((key, kind)) else: raise @@ -2313,22 +2352,22 @@ def __getitem__(self, key): response = self.get(kind + "/" + key) entries = self._load_list(response) if len(entries) > 1: - raise AmbiguousReferenceException("Found multiple inputs of kind %s named %s." % (kind, key)) - elif len(entries) == 0: + raise AmbiguousReferenceException(f"Found multiple inputs of kind {kind} named {key}.") + if len(entries) == 0: pass else: - if candidate is not None: # Already found at least one candidate - raise AmbiguousReferenceException("Found multiple inputs named %s, please specify a kind" % key) + if candidate is not None: # Already found at least one candidate + raise AmbiguousReferenceException( + f"Found multiple inputs named {key}, please specify a kind") candidate = entries[0] except HTTPError as he: if he.status == 404: - pass # Just carry on to the next kind. + pass # Just carry on to the next kind. else: raise if candidate is None: - raise KeyError(key) # Never found a match. - else: - return candidate + raise KeyError(key) # Never found a match. + return candidate def __contains__(self, key): if isinstance(key, tuple) and len(key) == 2: @@ -2348,11 +2387,9 @@ def __contains__(self, key): entries = self._load_list(response) if len(entries) > 0: return True - else: - pass except HTTPError as he: if he.status == 404: - pass # Just carry on to the next kind. + pass # Just carry on to the next kind. else: raise return False @@ -2404,9 +2441,8 @@ def create(self, name, kind, **kwargs): name = UrlEncoded(name, encode_slash=True) path = _path( self.path + kindpath, - '%s:%s' % (kwargs['restrictToHost'], name) \ - if 'restrictToHost' in kwargs else name - ) + f"{kwargs['restrictToHost']}:{name}" if 'restrictToHost' in kwargs else name + ) return Input(self.service, path, kind) def delete(self, name, kind=None): @@ -2476,7 +2512,7 @@ def itemmeta(self, kind): :return: The metadata. :rtype: class:``splunklib.data.Record`` """ - response = self.get("%s/_new" % self._kindmap[kind]) + response = self.get(f"{self._kindmap[kind]}/_new") content = _load_atom(response, MATCH_ENTRY_CONTENT) return _parse_atom_metadata(content) @@ -2491,9 +2527,9 @@ def _get_kind_list(self, subpath=None): this_subpath = subpath + [entry.title] # The "all" endpoint doesn't work yet. # The "tcp/ssl" endpoint is not a real input collection. - if entry.title == 'all' or this_subpath == ['tcp','ssl']: + if entry.title == 'all' or this_subpath == ['tcp', 'ssl']: continue - elif 'create' in [x.rel for x in entry.link]: + if 'create' in [x.rel for x in entry.link]: path = '/'.join(subpath + [entry.title]) kinds.append(path) else: @@ -2542,10 +2578,9 @@ def kindpath(self, kind): """ if kind == 'tcp': return UrlEncoded('tcp/raw', skip_encode=True) - elif kind == 'splunktcp': + if kind == 'splunktcp': return UrlEncoded('tcp/cooked', skip_encode=True) - else: - return UrlEncoded(kind, skip_encode=True) + return UrlEncoded(kind, skip_encode=True) def list(self, *kinds, **kwargs): """Returns a list of inputs that are in the :class:`Inputs` collection. @@ -2613,18 +2648,18 @@ def list(self, *kinds, **kwargs): path = UrlEncoded(path, skip_encode=True) response = self.get(path, **kwargs) except HTTPError as he: - if he.status == 404: # No inputs of this kind + if he.status == 404: # No inputs of this kind return [] entities = [] entries = _load_atom_entries(response) if entries is None: - return [] # No inputs in a collection comes back with no feed or entry in the XML + return [] # No inputs in a collection comes back with no feed or entry in the XML for entry in entries: state = _parse_atom_entry(entry) # Unquote the URL, since all URL encoded in the SDK # should be of type UrlEncoded, and all str should not # be URL encoded. - path = urllib.parse.unquote(state.links.alternate) + path = parse.unquote(state.links.alternate) entity = Input(self.service, path, kind, state=state) entities.append(entity) return entities @@ -2639,18 +2674,18 @@ def list(self, *kinds, **kwargs): response = self.get(self.kindpath(kind), search=search) except HTTPError as e: if e.status == 404: - continue # No inputs of this kind + continue # No inputs of this kind else: raise entries = _load_atom_entries(response) - if entries is None: continue # No inputs to process + if entries is None: continue # No inputs to process for entry in entries: state = _parse_atom_entry(entry) # Unquote the URL, since all URL encoded in the SDK # should be of type UrlEncoded, and all str should not # be URL encoded. - path = urllib.parse.unquote(state.links.alternate) + path = parse.unquote(state.links.alternate) entity = Input(self.service, path, kind, state=state) entities.append(entity) if 'offset' in kwargs: @@ -2718,11 +2753,12 @@ def oneshot(self, path, **kwargs): class Job(Entity): """This class represents a search job.""" + def __init__(self, service, sid, **kwargs): # Default to v2 in Splunk Version 9+ path = "{path}{sid}" # Formatting path based on the Splunk Version - if service.splunk_version < (9,): + if service.disable_v2_api: path = path.format(path=PATH_JOBS, sid=sid) else: path = path.format(path=PATH_JOBS_V2, sid=sid) @@ -2780,9 +2816,9 @@ def events(self, **kwargs): :return: The ``InputStream`` IO handle to this job's events. """ kwargs['segmentation'] = kwargs.get('segmentation', 'none') - + # Search API v1(GET) and v2(POST) - if self.service.splunk_version < (9,): + if self.service.disable_v2_api: return self.get("events", **kwargs).body return self.post("events", **kwargs).body @@ -2851,10 +2887,10 @@ def results(self, **query_params): for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - print '%s: %s' % (result.type, result.message) + print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - print result + print(result) assert rr.is_preview == False Results are not available until the job has finished. If called on @@ -2872,9 +2908,9 @@ def results(self, **query_params): :return: The ``InputStream`` IO handle to this job's results. """ query_params['segmentation'] = query_params.get('segmentation', 'none') - + # Search API v1(GET) and v2(POST) - if self.service.splunk_version < (9,): + if self.service.disable_v2_api: return self.get("results", **query_params).body return self.post("results", **query_params).body @@ -2895,14 +2931,14 @@ def preview(self, **query_params): for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - print '%s: %s' % (result.type, result.message) + print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - print result + print(result) if rr.is_preview: - print "Preview of a running search job." + print("Preview of a running search job.") else: - print "Job is finished. Results are final." + print("Job is finished. Results are final.") This method makes one roundtrip to the server, plus at most two more if @@ -2917,9 +2953,9 @@ def preview(self, **query_params): :return: The ``InputStream`` IO handle to this job's preview results. """ query_params['segmentation'] = query_params.get('segmentation', 'none') - + # Search API v1(GET) and v2(POST) - if self.service.splunk_version < (9,): + if self.service.disable_v2_api: return self.get("results_preview", **query_params).body return self.post("results_preview", **query_params).body @@ -3009,9 +3045,10 @@ def unpause(self): class Jobs(Collection): """This class represents a collection of search jobs. Retrieve this collection using :meth:`Service.jobs`.""" + def __init__(self, service): # Splunk 9 introduces the v2 endpoint - if service.splunk_version >= (9,): + if not service.disable_v2_api: path = PATH_JOBS_V2 else: path = PATH_JOBS @@ -3067,10 +3104,10 @@ def export(self, query, **params): for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - print '%s: %s' % (result.type, result.message) + print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - print result + print(result) assert rr.is_preview == False Running an export search is more efficient as it streams the results @@ -3123,10 +3160,10 @@ def oneshot(self, query, **params): for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - print '%s: %s' % (result.type, result.message) + print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - print result + print(result) assert rr.is_preview == False The ``oneshot`` method makes a single roundtrip to the server (as opposed @@ -3167,6 +3204,7 @@ def oneshot(self, query, **params): class Loggers(Collection): """This class represents a collection of service logging categories. Retrieve this collection using :meth:`Service.loggers`.""" + def __init__(self, service): Collection.__init__(self, service, PATH_LOGGER) @@ -3198,19 +3236,18 @@ class ModularInputKind(Entity): """This class contains the different types of modular inputs. Retrieve this collection using :meth:`Service.modular_input_kinds`. """ + def __contains__(self, name): args = self.state.content['endpoints']['args'] if name in args: return True - else: - return Entity.__contains__(self, name) + return Entity.__contains__(self, name) def __getitem__(self, name): args = self.state.content['endpoint']['args'] if name in args: return args['item'] - else: - return Entity.__getitem__(self, name) + return Entity.__getitem__(self, name) @property def arguments(self): @@ -3235,6 +3272,7 @@ def update(self, **kwargs): class SavedSearch(Entity): """This class represents a saved search.""" + def __init__(self, service, path, **kwargs): Entity.__init__(self, service, path, **kwargs) @@ -3376,8 +3414,7 @@ def suppressed(self): r = self._run_action("suppress") if r.suppressed == "1": return int(r.expiration) - else: - return 0 + return 0 def unsuppress(self): """Cancels suppression and makes this search run as scheduled. @@ -3391,6 +3428,7 @@ def unsuppress(self): class SavedSearches(Collection): """This class represents a collection of saved searches. Retrieve this collection using :meth:`Service.saved_searches`.""" + def __init__(self, service): Collection.__init__( self, service, PATH_SAVED_SEARCHES, item=SavedSearch) @@ -3412,9 +3450,94 @@ def create(self, name, search, **kwargs): return Collection.create(self, name, search=search, **kwargs) +class Macro(Entity): + """This class represents a search macro.""" + def __init__(self, service, path, **kwargs): + Entity.__init__(self, service, path, **kwargs) + + @property + def args(self): + """Returns the macro arguments. + :return: The macro arguments. + :rtype: ``string`` + """ + return self._state.content.get('args', '') + + @property + def definition(self): + """Returns the macro definition. + :return: The macro definition. + :rtype: ``string`` + """ + return self._state.content.get('definition', '') + + @property + def errormsg(self): + """Returns the validation error message for the macro. + :return: The validation error message for the macro. + :rtype: ``string`` + """ + return self._state.content.get('errormsg', '') + + @property + def iseval(self): + """Returns the eval-based definition status of the macro. + :return: The iseval value for the macro. + :rtype: ``string`` + """ + return self._state.content.get('iseval', '0') + + def update(self, definition=None, **kwargs): + """Updates the server with any changes you've made to the current macro + along with any additional arguments you specify. + :param `definition`: The macro definition (optional). + :type definition: ``string`` + :param `kwargs`: Additional arguments (optional). Available parameters are: + 'disabled', 'iseval', 'validation', and 'errormsg'. + :type kwargs: ``dict`` + :return: The :class:`Macro`. + """ + # Updates to a macro *require* that the definition be + # passed, so we pass the current definition if a value wasn't + # provided by the caller. + if definition is None: definition = self.content.definition + Entity.update(self, definition=definition, **kwargs) + return self + + @property + def validation(self): + """Returns the validation expression for the macro. + :return: The validation expression for the macro. + :rtype: ``string`` + """ + return self._state.content.get('validation', '') + + +class Macros(Collection): + """This class represents a collection of macros. Retrieve this + collection using :meth:`Service.macros`.""" + def __init__(self, service): + Collection.__init__( + self, service, PATH_MACROS, item=Macro) + + def create(self, name, definition, **kwargs): + """ Creates a macro. + :param name: The name for the macro. + :type name: ``string`` + :param definition: The macro definition. + :type definition: ``string`` + :param kwargs: Additional arguments (optional). Available parameters are: + 'disabled', 'iseval', 'validation', and 'errormsg'. + :type kwargs: ``dict`` + :return: The :class:`Macros` collection. + """ + return Collection.create(self, name, definition=definition, **kwargs) + + class Settings(Entity): """This class represents configuration settings for a Splunk service. Retrieve this collection using :meth:`Service.settings`.""" + def __init__(self, service, **kwargs): Entity.__init__(self, service, "/services/server/settings", **kwargs) @@ -3436,6 +3559,7 @@ def update(self, **kwargs): class User(Entity): """This class represents a Splunk user. """ + @property def role_entities(self): """Returns a list of roles assigned to this user. @@ -3443,7 +3567,8 @@ def role_entities(self): :return: The list of roles. :rtype: ``list`` """ - return [self.service.roles[name] for name in self.content.roles] + all_role_names = [r.name for r in self.service.roles.list()] + return [self.service.roles[name] for name in self.content.roles if name in all_role_names] # Splunk automatically lowercases new user names so we need to match that @@ -3452,6 +3577,7 @@ class Users(Collection): """This class represents the collection of Splunk users for this instance of Splunk. Retrieve this collection using :meth:`Service.users`. """ + def __init__(self, service): Collection.__init__(self, service, PATH_USERS, item=User) @@ -3491,8 +3617,8 @@ def create(self, username, password, roles, **params): boris = users.create("boris", "securepassword", roles="user") hilda = users.create("hilda", "anotherpassword", roles=["user","power"]) """ - if not isinstance(username, six.string_types): - raise ValueError("Invalid username: %s" % str(username)) + if not isinstance(username, str): + raise ValueError(f"Invalid username: {str(username)}") username = username.lower() self.post(name=username, password=password, roles=roles, **params) # splunkd doesn't return the user in the POST response body, @@ -3502,7 +3628,7 @@ def create(self, username, password, roles, **params): state = _parse_atom_entry(entry) entity = self.item( self.service, - urllib.parse.unquote(state.links.alternate), + parse.unquote(state.links.alternate), state=state) return entity @@ -3521,6 +3647,7 @@ def delete(self, name): class Role(Entity): """This class represents a user role. """ + def grant(self, *capabilities_to_grant): """Grants additional capabilities to this role. @@ -3571,8 +3698,8 @@ def revoke(self, *capabilities_to_revoke): for c in old_capabilities: if c not in capabilities_to_revoke: new_capabilities.append(c) - if new_capabilities == []: - new_capabilities = '' # Empty lists don't get passed in the body, so we have to force an empty argument. + if not new_capabilities: + new_capabilities = '' # Empty lists don't get passed in the body, so we have to force an empty argument. self.post(capabilities=new_capabilities) return self @@ -3580,8 +3707,9 @@ def revoke(self, *capabilities_to_revoke): class Roles(Collection): """This class represents the collection of roles in the Splunk instance. Retrieve this collection using :meth:`Service.roles`.""" + def __init__(self, service): - return Collection.__init__(self, service, PATH_ROLES, item=Role) + Collection.__init__(self, service, PATH_ROLES, item=Role) def __getitem__(self, key): return Collection.__getitem__(self, key.lower()) @@ -3614,8 +3742,8 @@ def create(self, name, **params): roles = c.roles paltry = roles.create("paltry", imported_roles="user", defaultApp="search") """ - if not isinstance(name, six.string_types): - raise ValueError("Invalid role name: %s" % str(name)) + if not isinstance(name, str): + raise ValueError(f"Invalid role name: {str(name)}") name = name.lower() self.post(name=name, **params) # splunkd doesn't return the user in the POST response body, @@ -3625,7 +3753,7 @@ def create(self, name, **params): state = _parse_atom_entry(entry) entity = self.item( self.service, - urllib.parse.unquote(state.links.alternate), + parse.unquote(state.links.alternate), state=state) return entity @@ -3642,6 +3770,7 @@ def delete(self, name): class Application(Entity): """Represents a locally-installed Splunk app.""" + @property def setupInfo(self): """Returns the setup information for the app. @@ -3658,17 +3787,25 @@ def updateInfo(self): """Returns any update information that is available for the app.""" return self._run_action("update") + class KVStoreCollections(Collection): def __init__(self, service): Collection.__init__(self, service, 'storage/collections/config', item=KVStoreCollection) - def create(self, name, indexes = {}, fields = {}, **kwargs): + def __getitem__(self, item): + res = Collection.__getitem__(self, item) + for k, v in res.content.items(): + if "accelerated_fields" in k: + res.content[k] = json.loads(v) + return res + + def create(self, name, accelerated_fields={}, fields={}, **kwargs): """Creates a KV Store Collection. :param name: name of collection to create :type name: ``string`` - :param indexes: dictionary of index definitions - :type indexes: ``dict`` + :param accelerated_fields: dictionary of accelerated_fields definitions + :type accelerated_fields: ``dict`` :param fields: dictionary of field definitions :type fields: ``dict`` :param kwargs: a dictionary of additional parameters specifying indexes and field definitions @@ -3676,14 +3813,15 @@ def create(self, name, indexes = {}, fields = {}, **kwargs): :return: Result of POST request """ - for k, v in six.iteritems(indexes): + for k, v in accelerated_fields.items(): if isinstance(v, dict): v = json.dumps(v) - kwargs['index.' + k] = v - for k, v in six.iteritems(fields): + kwargs['accelerated_fields.' + k] = v + for k, v in fields.items(): kwargs['field.' + k] = v return self.post(name=name, **kwargs) + class KVStoreCollection(Entity): @property def data(self): @@ -3693,18 +3831,18 @@ def data(self): """ return KVStoreCollectionData(self) - def update_index(self, name, value): - """Changes the definition of a KV Store index. + def update_accelerated_field(self, name, value): + """Changes the definition of a KV Store accelerated_field. - :param name: name of index to change + :param name: name of accelerated_fields to change :type name: ``string`` - :param value: new index definition - :type value: ``dict`` or ``string`` + :param value: new accelerated_fields definition + :type value: ``dict`` :return: Result of POST request """ kwargs = {} - kwargs['index.' + name] = value if isinstance(value, six.string_types) else json.dumps(value) + kwargs['accelerated_fields.' + name] = json.dumps(value) if isinstance(value, dict) else value return self.post(**kwargs) def update_field(self, name, value): @@ -3721,7 +3859,8 @@ def update_field(self, name, value): kwargs['field.' + name] = value return self.post(**kwargs) -class KVStoreCollectionData(object): + +class KVStoreCollectionData: """This class represents the data endpoint for a KVStoreCollection. Retrieve using :meth:`KVStoreCollection.data` @@ -3784,7 +3923,8 @@ def insert(self, data): """ if isinstance(data, dict): data = json.dumps(data) - return json.loads(self._post('', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) + return json.loads( + self._post('', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) def delete(self, query=None): """ @@ -3822,7 +3962,8 @@ def update(self, id, data): """ if isinstance(data, dict): data = json.dumps(data) - return json.loads(self._post(UrlEncoded(str(id), encode_slash=True), headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) + return json.loads(self._post(UrlEncoded(str(id), encode_slash=True), headers=KVStoreCollectionData.JSON_HEADER, + body=data).body.read().decode('utf-8')) def batch_find(self, *dbqueries): """ @@ -3839,7 +3980,8 @@ def batch_find(self, *dbqueries): data = json.dumps(dbqueries) - return json.loads(self._post('batch_find', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) + return json.loads( + self._post('batch_find', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) def batch_save(self, *documents): """ @@ -3856,4 +3998,5 @@ def batch_save(self, *documents): data = json.dumps(documents) - return json.loads(self._post('batch_save', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) \ No newline at end of file + return json.loads( + self._post('batch_save', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) \ No newline at end of file diff --git a/splunklib/data.py b/splunklib/data.py index f9ffb8692..34f3ffac1 100644 --- a/splunklib/data.py +++ b/splunklib/data.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -12,16 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. -"""The **splunklib.data** module reads the responses from splunkd in Atom Feed +"""The **splunklib.data** module reads the responses from splunkd in Atom Feed format, which is the format used by most of the REST API. """ -from __future__ import absolute_import -import sys from xml.etree.ElementTree import XML -from splunklib import six -__all__ = ["load"] +__all__ = ["load", "record"] # LNAME refers to element names without namespaces; XNAME is the same # name, but with an XML namespace. @@ -36,33 +33,41 @@ XNAME_KEY = XNAMEF_REST % LNAME_KEY XNAME_LIST = XNAMEF_REST % LNAME_LIST + # Some responses don't use namespaces (eg: search/parse) so we look for # both the extended and local versions of the following names. + def isdict(name): - return name == XNAME_DICT or name == LNAME_DICT + return name in (XNAME_DICT, LNAME_DICT) + def isitem(name): - return name == XNAME_ITEM or name == LNAME_ITEM + return name in (XNAME_ITEM, LNAME_ITEM) + def iskey(name): - return name == XNAME_KEY or name == LNAME_KEY + return name in (XNAME_KEY, LNAME_KEY) + def islist(name): - return name == XNAME_LIST or name == LNAME_LIST + return name in (XNAME_LIST, LNAME_LIST) + def hasattrs(element): return len(element.attrib) > 0 + def localname(xname): rcurly = xname.find('}') - return xname if rcurly == -1 else xname[rcurly+1:] + return xname if rcurly == -1 else xname[rcurly + 1:] + def load(text, match=None): - """This function reads a string that contains the XML of an Atom Feed, then - returns the - data in a native Python structure (a ``dict`` or ``list``). If you also - provide a tag name or path to match, only the matching sub-elements are + """This function reads a string that contains the XML of an Atom Feed, then + returns the + data in a native Python structure (a ``dict`` or ``list``). If you also + provide a tag name or path to match, only the matching sub-elements are loaded. :param text: The XML text to load. @@ -78,30 +83,27 @@ def load(text, match=None): 'names': {} } - # Convert to unicode encoding in only python 2 for xml parser - if(sys.version_info < (3, 0, 0) and isinstance(text, unicode)): - text = text.encode('utf-8') - root = XML(text) items = [root] if match is None else root.findall(match) count = len(items) - if count == 0: + if count == 0: return None - elif count == 1: + if count == 1: return load_root(items[0], nametable) - else: - return [load_root(item, nametable) for item in items] + return [load_root(item, nametable) for item in items] + # Load the attributes of the given element. def load_attrs(element): if not hasattrs(element): return None attrs = record() - for key, value in six.iteritems(element.attrib): + for key, value in element.attrib.items(): attrs[key] = value return attrs + # Parse a element and return a Python dict -def load_dict(element, nametable = None): +def load_dict(element, nametable=None): value = record() children = list(element) for child in children: @@ -110,6 +112,7 @@ def load_dict(element, nametable = None): value[name] = load_value(child, nametable) return value + # Loads the given elements attrs & value into single merged dict. def load_elem(element, nametable=None): name = localname(element.tag) @@ -118,12 +121,12 @@ def load_elem(element, nametable=None): if attrs is None: return name, value if value is None: return name, attrs # If value is simple, merge into attrs dict using special key - if isinstance(value, six.string_types): + if isinstance(value, str): attrs["$text"] = value return name, attrs # Both attrs & value are complex, so merge the two dicts, resolving collisions. collision_keys = [] - for key, val in six.iteritems(attrs): + for key, val in attrs.items(): if key in value and key in collision_keys: value[key].append(val) elif key in value and key not in collision_keys: @@ -133,6 +136,7 @@ def load_elem(element, nametable=None): value[key] = val return name, value + # Parse a element and return a Python list def load_list(element, nametable=None): assert islist(element.tag) @@ -143,6 +147,7 @@ def load_list(element, nametable=None): value.append(load_value(child, nametable)) return value + # Load the given root element. def load_root(element, nametable=None): tag = element.tag @@ -151,6 +156,7 @@ def load_root(element, nametable=None): k, v = load_elem(element, nametable) return Record.fromkv(k, v) + # Load the children of the given element. def load_value(element, nametable=None): children = list(element) @@ -159,7 +165,7 @@ def load_value(element, nametable=None): # No children, assume a simple text value if count == 0: text = element.text - if text is None: + if text is None: return None if len(text.strip()) == 0: @@ -179,7 +185,7 @@ def load_value(element, nametable=None): # If we have seen this name before, promote the value to a list if name in value: current = value[name] - if not isinstance(current, list): + if not isinstance(current, list): value[name] = [current] value[name].append(item) else: @@ -187,23 +193,24 @@ def load_value(element, nametable=None): return value + # A generic utility that enables "dot" access to dicts class Record(dict): - """This generic utility class enables dot access to members of a Python + """This generic utility class enables dot access to members of a Python dictionary. - Any key that is also a valid Python identifier can be retrieved as a field. - So, for an instance of ``Record`` called ``r``, ``r.key`` is equivalent to - ``r['key']``. A key such as ``invalid-key`` or ``invalid.key`` cannot be - retrieved as a field, because ``-`` and ``.`` are not allowed in + Any key that is also a valid Python identifier can be retrieved as a field. + So, for an instance of ``Record`` called ``r``, ``r.key`` is equivalent to + ``r['key']``. A key such as ``invalid-key`` or ``invalid.key`` cannot be + retrieved as a field, because ``-`` and ``.`` are not allowed in identifiers. - Keys of the form ``a.b.c`` are very natural to write in Python as fields. If - a group of keys shares a prefix ending in ``.``, you can retrieve keys as a + Keys of the form ``a.b.c`` are very natural to write in Python as fields. If + a group of keys shares a prefix ending in ``.``, you can retrieve keys as a nested dictionary by calling only the prefix. For example, if ``r`` contains keys ``'foo'``, ``'bar.baz'``, and ``'bar.qux'``, ``r.bar`` returns a record - with the keys ``baz`` and ``qux``. If a key contains multiple ``.``, each - one is placed into a nested dictionary, so you can write ``r.bar.qux`` or + with the keys ``baz`` and ``qux``. If a key contains multiple ``.``, each + one is placed into a nested dictionary, so you can write ``r.bar.qux`` or ``r['bar.qux']`` interchangeably. """ sep = '.' @@ -215,7 +222,7 @@ def __call__(self, *args): def __getattr__(self, name): try: return self[name] - except KeyError: + except KeyError: raise AttributeError(name) def __delattr__(self, name): @@ -235,7 +242,7 @@ def __getitem__(self, key): return dict.__getitem__(self, key) key += self.sep result = record() - for k,v in six.iteritems(self): + for k, v in self.items(): if not k.startswith(key): continue suffix = k[len(key):] @@ -250,17 +257,16 @@ def __getitem__(self, key): else: result[suffix] = v if len(result) == 0: - raise KeyError("No key or prefix: %s" % key) + raise KeyError(f"No key or prefix: {key}") return result - -def record(value=None): - """This function returns a :class:`Record` instance constructed with an + +def record(value=None): + """This function returns a :class:`Record` instance constructed with an initial value that you provide. - - :param `value`: An initial record value. - :type `value`: ``dict`` + + :param value: An initial record value. + :type value: ``dict`` """ if value is None: value = {} return Record(value) - diff --git a/splunklib/modularinput/argument.py b/splunklib/modularinput/argument.py index 04214d16d..ec6438750 100644 --- a/splunklib/modularinput/argument.py +++ b/splunklib/modularinput/argument.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -12,16 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -try: - import xml.etree.ElementTree as ET -except ImportError: - import xml.etree.cElementTree as ET +import xml.etree.ElementTree as ET + +class Argument: -class Argument(object): """Class representing an argument to a modular input kind. - ``Argument`` is meant to be used with ``Scheme`` to generate an XML + ``Argument`` is meant to be used with ``Scheme`` to generate an XML definition of the modular input kind that Splunk understands. ``name`` is the only required parameter for the constructor. @@ -100,4 +97,4 @@ def add_to_document(self, parent): for name, value in subelements: ET.SubElement(arg, name).text = str(value).lower() - return arg \ No newline at end of file + return arg diff --git a/splunklib/modularinput/event.py b/splunklib/modularinput/event.py index 9cd6cf3ae..7ee7266ab 100644 --- a/splunklib/modularinput/event.py +++ b/splunklib/modularinput/event.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -12,16 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from io import TextIOBase -from splunklib.six import ensure_text +import xml.etree.ElementTree as ET -try: - import xml.etree.cElementTree as ET -except ImportError as ie: - import xml.etree.ElementTree as ET +from splunklib.utils import ensure_str -class Event(object): + +class Event: """Represents an event or fragment of an event to be written by this modular input to Splunk. To write an input to a stream, call the ``write_to`` function, passing in a stream. @@ -108,7 +105,7 @@ def write_to(self, stream): ET.SubElement(event, "done") if isinstance(stream, TextIOBase): - stream.write(ensure_text(ET.tostring(event))) + stream.write(ensure_str(ET.tostring(event))) else: stream.write(ET.tostring(event)) - stream.flush() \ No newline at end of file + stream.flush() diff --git a/splunklib/modularinput/event_writer.py b/splunklib/modularinput/event_writer.py old mode 100755 new mode 100644 index 5f8c5aa8b..7be3845ab --- a/splunklib/modularinput/event_writer.py +++ b/splunklib/modularinput/event_writer.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -12,18 +12,14 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import import sys +import traceback -from splunklib.six import ensure_str +from splunklib.utils import ensure_str from .event import ET -try: - from splunklib.six.moves import cStringIO as StringIO -except ImportError: - from splunklib.six import StringIO -class EventWriter(object): +class EventWriter: """``EventWriter`` writes events and error messages to Splunk from a modular input. Its two important methods are ``writeEvent``, which takes an ``Event`` object, and ``log``, which takes a severity and an error message. @@ -68,7 +64,26 @@ def log(self, severity, message): :param message: ``string``, message to log. """ - self._err.write("%s %s\n" % (severity, message)) + self._err.write(f"{severity} {message}\n") + self._err.flush() + + def log_exception(self, message, exception=None, severity=None): + """Logs messages about the exception thrown by this modular input to Splunk. + These messages will show up in Splunk's internal logs. + + :param message: ``string``, message to log. + :param exception: ``Exception``, exception thrown by this modular input; if none, sys.exc_info() is used + :param severity: ``string``, severity of message, see severities defined as class constants. Default severity: ERROR + """ + if exception is not None: + tb_str = traceback.format_exception(type(exception), exception, exception.__traceback__) + else: + tb_str = traceback.format_exc() + + if severity is None: + severity = EventWriter.ERROR + + self._err.write(("%s %s - %s" % (severity, message, tb_str)).replace("\n", " ")) self._err.flush() def write_xml_document(self, document): @@ -77,11 +92,11 @@ def write_xml_document(self, document): :param document: An ``ElementTree`` object. """ - self._out.write(ensure_str(ET.tostring(document))) + self._out.write(ensure_str(ET.tostring(document), errors="replace")) self._out.flush() def close(self): """Write the closing tag to make this XML well formed.""" if self.header_written: - self._out.write("") + self._out.write("") self._out.flush() diff --git a/splunklib/modularinput/input_definition.py b/splunklib/modularinput/input_definition.py index fdc7cbb3f..190192f7b 100644 --- a/splunklib/modularinput/input_definition.py +++ b/splunklib/modularinput/input_definition.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -12,12 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -try: - import xml.etree.cElementTree as ET -except ImportError as ie: - import xml.etree.ElementTree as ET - +import xml.etree.ElementTree as ET from .utils import parse_xml_data class InputDefinition: @@ -57,4 +52,4 @@ def parse(stream): else: definition.metadata[node.tag] = node.text - return definition \ No newline at end of file + return definition diff --git a/splunklib/modularinput/scheme.py b/splunklib/modularinput/scheme.py index 4104e4a3f..a3b086826 100644 --- a/splunklib/modularinput/scheme.py +++ b/splunklib/modularinput/scheme.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -12,13 +12,10 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -try: - import xml.etree.cElementTree as ET -except ImportError: - import xml.etree.ElementTree as ET +import xml.etree.ElementTree as ET -class Scheme(object): + +class Scheme: """Class representing the metadata for a modular input kind. A ``Scheme`` specifies a title, description, several options of how Splunk should run modular inputs of this @@ -82,4 +79,4 @@ def to_xml(self): for arg in self.arguments: arg.add_to_document(args) - return root \ No newline at end of file + return root diff --git a/splunklib/modularinput/script.py b/splunklib/modularinput/script.py index 8595dc4bd..a6c961931 100644 --- a/splunklib/modularinput/script.py +++ b/splunklib/modularinput/script.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -12,24 +12,18 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from abc import ABCMeta, abstractmethod -from splunklib.six.moves.urllib.parse import urlsplit import sys +import xml.etree.ElementTree as ET +from urllib.parse import urlsplit from ..client import Service from .event_writer import EventWriter from .input_definition import InputDefinition from .validation_definition import ValidationDefinition -from splunklib import six -try: - import xml.etree.cElementTree as ET -except ImportError: - import xml.etree.ElementTree as ET - -class Script(six.with_metaclass(ABCMeta, object)): +class Script(metaclass=ABCMeta): """An abstract base class for implementing modular inputs. Subclasses should override ``get_scheme``, ``stream_events``, @@ -74,7 +68,7 @@ def run_script(self, args, event_writer, input_stream): event_writer.close() return 0 - elif str(args[1]).lower() == "--scheme": + if str(args[1]).lower() == "--scheme": # Splunk has requested XML specifying the scheme for this # modular input Return it and exit. scheme = self.get_scheme() @@ -83,11 +77,10 @@ def run_script(self, args, event_writer, input_stream): EventWriter.FATAL, "Modular input script returned a null scheme.") return 1 - else: - event_writer.write_xml_document(scheme.to_xml()) - return 0 + event_writer.write_xml_document(scheme.to_xml()) + return 0 - elif args[1].lower() == "--validate-arguments": + if args[1].lower() == "--validate-arguments": validation_definition = ValidationDefinition.parse(input_stream) try: self.validate_input(validation_definition) @@ -98,14 +91,12 @@ def run_script(self, args, event_writer, input_stream): event_writer.write_xml_document(root) return 1 - else: - err_string = "ERROR Invalid arguments to modular input script:" + ' '.join( - args) - event_writer._err.write(err_string) - return 1 + event_writer.log(EventWriter.ERROR, "Invalid arguments to modular input script:" + ' '.join( + args)) + return 1 except Exception as e: - event_writer.log(EventWriter.ERROR, str(e)) + event_writer.log_exception(str(e)) return 1 @property @@ -165,7 +156,6 @@ def validate_input(self, definition): :param definition: The parameters for the proposed input passed by splunkd. """ - pass @abstractmethod def stream_events(self, inputs, ew): diff --git a/splunklib/modularinput/utils.py b/splunklib/modularinput/utils.py index 3d42b6326..dad73dd07 100644 --- a/splunklib/modularinput/utils.py +++ b/splunklib/modularinput/utils.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,8 +14,7 @@ # File for utility functions -from __future__ import absolute_import -from splunklib.six.moves import zip + def xml_compare(expected, found): """Checks equality of two ``ElementTree`` objects. @@ -39,27 +38,25 @@ def xml_compare(expected, found): return False # compare children - if not all([xml_compare(a, b) for a, b in zip(expected_children, found_children)]): + if not all(xml_compare(a, b) for a, b in zip(expected_children, found_children)): return False # compare elements, if there is no text node, return True if (expected.text is None or expected.text.strip() == "") \ and (found.text is None or found.text.strip() == ""): return True - else: - return expected.tag == found.tag and expected.text == found.text \ + return expected.tag == found.tag and expected.text == found.text \ and expected.attrib == found.attrib def parse_parameters(param_node): if param_node.tag == "param": return param_node.text - elif param_node.tag == "param_list": + if param_node.tag == "param_list": parameters = [] for mvp in param_node: parameters.append(mvp.text) return parameters - else: - raise ValueError("Invalid configuration scheme, %s tag unexpected." % param_node.tag) + raise ValueError(f"Invalid configuration scheme, {param_node.tag} tag unexpected.") def parse_xml_data(parent_node, child_node_tag): data = {} diff --git a/splunklib/modularinput/validation_definition.py b/splunklib/modularinput/validation_definition.py index 3bbe9760e..b71e1e7c3 100644 --- a/splunklib/modularinput/validation_definition.py +++ b/splunklib/modularinput/validation_definition.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -13,16 +13,12 @@ # under the License. -from __future__ import absolute_import -try: - import xml.etree.cElementTree as ET -except ImportError as ie: - import xml.etree.ElementTree as ET +import xml.etree.ElementTree as ET from .utils import parse_xml_data -class ValidationDefinition(object): +class ValidationDefinition: """This class represents the XML sent by Splunk for external validation of a new modular input. @@ -83,4 +79,4 @@ def parse(stream): # Store anything else in metadata definition.metadata[node.tag] = node.text - return definition \ No newline at end of file + return definition diff --git a/splunklib/results.py b/splunklib/results.py index 8543ab0df..30476c846 100644 --- a/splunklib/results.py +++ b/splunklib/results.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -29,38 +29,27 @@ reader = ResultsReader(result_stream) for item in reader: print(item) - print "Results are a preview: %s" % reader.is_preview + print(f"Results are a preview: {reader.is_preview}") """ -from __future__ import absolute_import - from io import BufferedReader, BytesIO -from splunklib import six - -from splunklib.six import deprecated -try: - import xml.etree.cElementTree as et -except: - import xml.etree.ElementTree as et +import xml.etree.ElementTree as et from collections import OrderedDict from json import loads as json_loads -try: - from splunklib.six.moves import cStringIO as StringIO -except: - from splunklib.six import StringIO - __all__ = [ "ResultsReader", "Message", "JSONResultsReader" ] +import deprecation -class Message(object): + +class Message: """This class represents informational messages that Splunk interleaves in the results stream. ``Message`` takes two arguments: a string giving the message type (e.g., "DEBUG"), and @@ -76,7 +65,7 @@ def __init__(self, type_, message): self.message = message def __repr__(self): - return "%s: %s" % (self.type, self.message) + return f"{self.type}: {self.message}" def __eq__(self, other): return (self.type, self.message) == (other.type, other.message) @@ -85,7 +74,7 @@ def __hash__(self): return hash((self.type, self.message)) -class _ConcatenatedStream(object): +class _ConcatenatedStream: """Lazily concatenate zero or more streams into a stream. As you read from the concatenated stream, you get characters from @@ -117,7 +106,7 @@ def read(self, n=None): return response -class _XMLDTDFilter(object): +class _XMLDTDFilter: """Lazily remove all XML DTDs from a stream. All substrings matching the regular expression ]*> are @@ -144,7 +133,7 @@ def read(self, n=None): c = self.stream.read(1) if c == b"": break - elif c == b"<": + if c == b"<": c += self.stream.read(1) if c == b"`_ 3. `Configure seach assistant with searchbnf.conf `_ - + 4. `Control search distribution with distsearch.conf `_ """ -from __future__ import absolute_import, division, print_function, unicode_literals - from .environment import * from .decorators import * from .validators import * diff --git a/splunklib/searchcommands/decorators.py b/splunklib/searchcommands/decorators.py index d8b3f48cc..1393d789a 100644 --- a/splunklib/searchcommands/decorators.py +++ b/splunklib/searchcommands/decorators.py @@ -1,6 +1,6 @@ # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,19 +14,16 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals -from splunklib import six - -from collections import OrderedDict # must be python 2.7 +from collections import OrderedDict from inspect import getmembers, isclass, isfunction -from splunklib.six.moves import map as imap + from .internals import ConfigurationSettingsType, json_encode_string from .validators import OptionName -class Configuration(object): +class Configuration: """ Defines the configuration settings for a search command. Documents, validates, and ensures that only relevant configuration settings are applied. Adds a :code:`name` class @@ -69,7 +66,7 @@ def __call__(self, o): name = o.__name__ if name.endswith('Command'): name = name[:-len('Command')] - o.name = six.text_type(name.lower()) + o.name = str(name.lower()) # Construct ConfigurationSettings instance for the command class @@ -82,7 +79,7 @@ def __call__(self, o): o.ConfigurationSettings.fix_up(o) Option.fix_up(o) else: - raise TypeError('Incorrect usage: Configuration decorator applied to {0}'.format(type(o), o.__name__)) + raise TypeError(f'Incorrect usage: Configuration decorator applied to {type(o)}') return o @@ -136,7 +133,7 @@ def fix_up(cls, values): for name, setting in definitions: if setting._name is None: - setting._name = name = six.text_type(name) + setting._name = name = str(name) else: name = setting._name @@ -187,14 +184,14 @@ def is_supported_by_protocol(version): continue if setting.fset is None: - raise ValueError('The value of configuration setting {} is fixed'.format(name)) + raise ValueError(f'The value of configuration setting {name} is fixed') setattr(cls, backing_field_name, validate(specification, name, value)) del values[name] if len(values) > 0: - settings = sorted(list(six.iteritems(values))) - settings = imap(lambda n_v: '{}={}'.format(n_v[0], repr(n_v[1])), settings) + settings = sorted(list(values.items())) + settings = [f'{n_v[0]}={n_v[1]}' for n_v in settings] raise AttributeError('Inapplicable configuration settings: ' + ', '.join(settings)) cls.configuration_setting_definitions = definitions @@ -212,7 +209,7 @@ def _get_specification(self): try: specification = ConfigurationSettingsType.specification_matrix[name] except KeyError: - raise AttributeError('Unknown configuration setting: {}={}'.format(name, repr(self._value))) + raise AttributeError(f'Unknown configuration setting: {name}={repr(self._value)}') return ConfigurationSettingsType.validate_configuration_setting, specification @@ -346,7 +343,7 @@ def _copy_extra_attributes(self, other): # region Types - class Item(object): + class Item: """ Presents an instance/class view over a search command `Option`. This class is used by SearchCommand.process to parse and report on option values. @@ -357,7 +354,7 @@ def __init__(self, command, option): self._option = option self._is_set = False validator = self.validator - self._format = six.text_type if validator is None else validator.format + self._format = str if validator is None else validator.format def __repr__(self): return '(' + repr(self.name) + ', ' + repr(self._format(self.value)) + ')' @@ -405,7 +402,6 @@ def reset(self): self._option.__set__(self._command, self._option.default) self._is_set = False - pass # endregion class View(OrderedDict): @@ -420,27 +416,26 @@ def __init__(self, command): OrderedDict.__init__(self, ((option.name, item_class(command, option)) for (name, option) in definitions)) def __repr__(self): - text = 'Option.View([' + ','.join(imap(lambda item: repr(item), six.itervalues(self))) + '])' + text = 'Option.View([' + ','.join([repr(item) for item in self.values()]) + '])' return text def __str__(self): - text = ' '.join([str(item) for item in six.itervalues(self) if item.is_set]) + text = ' '.join([str(item) for item in self.values() if item.is_set]) return text # region Methods def get_missing(self): - missing = [item.name for item in six.itervalues(self) if item.is_required and not item.is_set] + missing = [item.name for item in self.values() if item.is_required and not item.is_set] return missing if len(missing) > 0 else None def reset(self): - for value in six.itervalues(self): + for value in self.values(): value.reset() - pass # endregion - pass + # endregion diff --git a/splunklib/searchcommands/environment.py b/splunklib/searchcommands/environment.py index e92018f6a..35f1deaf3 100644 --- a/splunklib/searchcommands/environment.py +++ b/splunklib/searchcommands/environment.py @@ -1,6 +1,6 @@ # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,16 +14,15 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals + from logging import getLogger, root, StreamHandler from logging.config import fileConfig -from os import chdir, environ, path -from splunklib.six.moves import getcwd - +from os import chdir, environ, path, getcwd import sys + def configure_logging(logger_name, filename=None): """ Configure logging and return the named logger and the location of the logging configuration file loaded. @@ -88,9 +87,9 @@ def configure_logging(logger_name, filename=None): found = True break if not found: - raise ValueError('Logging configuration file "{}" not found in local or default directory'.format(filename)) + raise ValueError(f'Logging configuration file "{filename}" not found in local or default directory') elif not path.exists(filename): - raise ValueError('Logging configuration file "{}" not found'.format(filename)) + raise ValueError(f'Logging configuration file "{filename}" not found') if filename is not None: global _current_logging_configuration_file diff --git a/splunklib/searchcommands/eventing_command.py b/splunklib/searchcommands/eventing_command.py index 27dc13a3a..d42d056df 100644 --- a/splunklib/searchcommands/eventing_command.py +++ b/splunklib/searchcommands/eventing_command.py @@ -1,6 +1,6 @@ # coding=utf-8 # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,10 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals -from splunklib import six -from splunklib.six.moves import map as imap from .decorators import ConfigurationSetting from .search_command import SearchCommand @@ -140,10 +137,10 @@ def fix_up(cls, command): # N.B.: Does not use Python 2 dict copy semantics def iteritems(self): iteritems = SearchCommand.ConfigurationSettings.iteritems(self) - return imap(lambda name_value: (name_value[0], 'events' if name_value[0] == 'type' else name_value[1]), iteritems) + return [(name_value[0], 'events' if name_value[0] == 'type' else name_value[1]) for name_value in iteritems] # N.B.: Does not use Python 3 dict view semantics - if not six.PY2: - items = iteritems + + items = iteritems # endregion diff --git a/splunklib/searchcommands/external_search_command.py b/splunklib/searchcommands/external_search_command.py index c2306241e..a8929f8d3 100644 --- a/splunklib/searchcommands/external_search_command.py +++ b/splunklib/searchcommands/external_search_command.py @@ -1,6 +1,6 @@ # coding=utf-8 # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,34 +14,31 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - from logging import getLogger import os import sys import traceback -from splunklib import six +from . import splunklib_logger as logger + if sys.platform == 'win32': from signal import signal, CTRL_BREAK_EVENT, SIGBREAK, SIGINT, SIGTERM from subprocess import Popen import atexit -from . import splunklib_logger as logger + # P1 [ ] TODO: Add ExternalSearchCommand class documentation -class ExternalSearchCommand(object): - """ - """ +class ExternalSearchCommand: def __init__(self, path, argv=None, environ=None): - if not isinstance(path, (bytes, six.text_type)): - raise ValueError('Expected a string value for path, not {}'.format(repr(path))) + if not isinstance(path, (bytes,str)): + raise ValueError(f'Expected a string value for path, not {repr(path)}') self._logger = getLogger(self.__class__.__name__) - self._path = six.text_type(path) + self._path = str(path) self._argv = None self._environ = None @@ -57,7 +54,7 @@ def argv(self): @argv.setter def argv(self, value): if not (value is None or isinstance(value, (list, tuple))): - raise ValueError('Expected a list, tuple or value of None for argv, not {}'.format(repr(value))) + raise ValueError(f'Expected a list, tuple or value of None for argv, not {repr(value)}') self._argv = value @property @@ -67,7 +64,7 @@ def environ(self): @environ.setter def environ(self, value): if not (value is None or isinstance(value, dict)): - raise ValueError('Expected a dictionary value for environ, not {}'.format(repr(value))) + raise ValueError(f'Expected a dictionary value for environ, not {repr(value)}') self._environ = value @property @@ -90,7 +87,7 @@ def execute(self): self._execute(self._path, self._argv, self._environ) except: error_type, error, tb = sys.exc_info() - message = 'Command execution failed: ' + six.text_type(error) + message = f'Command execution failed: {str(error)}' self._logger.error(message + '\nTraceback:\n' + ''.join(traceback.format_tb(tb))) sys.exit(1) @@ -120,13 +117,13 @@ def _execute(path, argv=None, environ=None): found = ExternalSearchCommand._search_path(path, search_path) if found is None: - raise ValueError('Cannot find command on path: {}'.format(path)) + raise ValueError(f'Cannot find command on path: {path}') path = found - logger.debug('starting command="%s", arguments=%s', path, argv) + logger.debug(f'starting command="{path}", arguments={argv}') - def terminate(signal_number, frame): - sys.exit('External search command is terminating on receipt of signal={}.'.format(signal_number)) + def terminate(signal_number): + sys.exit(f'External search command is terminating on receipt of signal={signal_number}.') def terminate_child(): if p.pid is not None and p.returncode is None: @@ -206,7 +203,6 @@ def _execute(path, argv, environ): os.execvp(path, argv) else: os.execvpe(path, argv, environ) - return # endregion diff --git a/splunklib/searchcommands/generating_command.py b/splunklib/searchcommands/generating_command.py index 6a75d2c27..36b014c3e 100644 --- a/splunklib/searchcommands/generating_command.py +++ b/splunklib/searchcommands/generating_command.py @@ -1,6 +1,6 @@ # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,14 +14,11 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals import sys from .decorators import ConfigurationSetting from .search_command import SearchCommand -from splunklib import six -from splunklib.six.moves import map as imap, filter as ifilter # P1 [O] TODO: Discuss generates_timeorder in the class-level documentation for GeneratingCommand @@ -254,8 +251,7 @@ def process(self, argv=sys.argv, ifile=sys.stdin, ofile=sys.stdout, allow_empty_ if not allow_empty_input: raise ValueError("allow_empty_input cannot be False for Generating Commands") - else: - return super(GeneratingCommand, self).process(argv=argv, ifile=ifile, ofile=ofile, allow_empty_input=True) + return super().process(argv=argv, ifile=ifile, ofile=ofile, allow_empty_input=True) # endregion @@ -370,18 +366,14 @@ def iteritems(self): iteritems = SearchCommand.ConfigurationSettings.iteritems(self) version = self.command.protocol_version if version == 2: - iteritems = ifilter(lambda name_value1: name_value1[0] != 'distributed', iteritems) + iteritems = [name_value1 for name_value1 in iteritems if name_value1[0] != 'distributed'] if not self.distributed and self.type == 'streaming': - iteritems = imap( - lambda name_value: (name_value[0], 'stateful') if name_value[0] == 'type' else (name_value[0], name_value[1]), iteritems) + iteritems = [(name_value[0], 'stateful') if name_value[0] == 'type' else (name_value[0], name_value[1]) for name_value in iteritems] return iteritems # N.B.: Does not use Python 3 dict view semantics - if not six.PY2: - items = iteritems + items = iteritems - pass # endregion - pass # endregion diff --git a/splunklib/searchcommands/internals.py b/splunklib/searchcommands/internals.py index 1ea2833db..abceac30f 100644 --- a/splunklib/searchcommands/internals.py +++ b/splunklib/searchcommands/internals.py @@ -1,6 +1,6 @@ # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,25 +14,22 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function - -from io import TextIOWrapper -from collections import deque, namedtuple -from splunklib import six -from collections import OrderedDict -from splunklib.six.moves import StringIO -from itertools import chain -from splunklib.six.moves import map as imap -from json import JSONDecoder, JSONEncoder -from json.encoder import encode_basestring_ascii as json_encode_string -from splunklib.six.moves import urllib - import csv import gzip import os import re import sys import warnings +import urllib.parse +from io import TextIOWrapper, StringIO +from collections import deque, namedtuple +from collections import OrderedDict +from itertools import chain +from json import JSONDecoder, JSONEncoder +from json.encoder import encode_basestring_ascii as json_encode_string + + + from . import environment @@ -43,35 +40,19 @@ def set_binary_mode(fh): """ Helper method to set up binary mode for file handles. Emphasis being sys.stdin, sys.stdout, sys.stderr. For python3, we want to return .buffer - For python2+windows we want to set os.O_BINARY """ - typefile = TextIOWrapper if sys.version_info >= (3, 0) else file + typefile = TextIOWrapper # check for file handle if not isinstance(fh, typefile): return fh - # check for python3 and buffer - if sys.version_info >= (3, 0) and hasattr(fh, 'buffer'): + # check for buffer + if hasattr(fh, 'buffer'): return fh.buffer - # check for python3 - elif sys.version_info >= (3, 0): - pass - # check for windows python2. SPL-175233 -- python3 stdout is already binary - elif sys.platform == 'win32': - # Work around the fact that on Windows '\n' is mapped to '\r\n'. The typical solution is to simply open files in - # binary mode, but stdout is already open, thus this hack. 'CPython' and 'PyPy' work differently. We assume that - # all other Python implementations are compatible with 'CPython'. This might or might not be a valid assumption. - from platform import python_implementation - implementation = python_implementation() - if implementation == 'PyPy': - return os.fdopen(fh.fileno(), 'wb', 0) - else: - import msvcrt - msvcrt.setmode(fh.fileno(), os.O_BINARY) return fh -class CommandLineParser(object): +class CommandLineParser: r""" Parses the arguments to a search command. A search command line is described by the following syntax. @@ -144,7 +125,7 @@ def parse(cls, command, argv): command_args = cls._arguments_re.match(argv) if command_args is None: - raise SyntaxError('Syntax error: {}'.format(argv)) + raise SyntaxError(f'Syntax error: {argv}') # Parse options @@ -152,7 +133,7 @@ def parse(cls, command, argv): name, value = option.group('name'), option.group('value') if name not in command.options: raise ValueError( - 'Unrecognized {} command option: {}={}'.format(command.name, name, json_encode_string(value))) + f'Unrecognized {command.name} command option: {name}={json_encode_string(value)}') command.options[name].value = cls.unquote(value) missing = command.options.get_missing() @@ -160,8 +141,8 @@ def parse(cls, command, argv): if missing is not None: if len(missing) > 1: raise ValueError( - 'Values for these {} command options are required: {}'.format(command.name, ', '.join(missing))) - raise ValueError('A value for {} command option {} is required'.format(command.name, missing[0])) + f'Values for these {command.name} command options are required: {", ".join(missing)}') + raise ValueError(f'A value for {command.name} command option {missing[0]} is required') # Parse field names @@ -277,10 +258,10 @@ def validate_configuration_setting(specification, name, value): if isinstance(specification.type, type): type_names = specification.type.__name__ else: - type_names = ', '.join(imap(lambda t: t.__name__, specification.type)) - raise ValueError('Expected {} value, not {}={}'.format(type_names, name, repr(value))) + type_names = ', '.join(map(lambda t: t.__name__, specification.type)) + raise ValueError(f'Expected {type_names} value, not {name}={repr(value)}') if specification.constraint and not specification.constraint(value): - raise ValueError('Illegal value: {}={}'.format(name, repr(value))) + raise ValueError(f'Illegal value: {name}={ repr(value)}') return value specification = namedtuple( @@ -314,7 +295,7 @@ def validate_configuration_setting(specification, name, value): supporting_protocols=[1]), 'maxinputs': specification( type=int, - constraint=lambda value: 0 <= value <= six.MAXSIZE, + constraint=lambda value: 0 <= value <= sys.maxsize, supporting_protocols=[2]), 'overrides_timeorder': specification( type=bool, @@ -341,11 +322,11 @@ def validate_configuration_setting(specification, name, value): constraint=None, supporting_protocols=[1]), 'streaming_preop': specification( - type=(bytes, six.text_type), + type=(bytes, str), constraint=None, supporting_protocols=[1, 2]), 'type': specification( - type=(bytes, six.text_type), + type=(bytes, str), constraint=lambda value: value in ('events', 'reporting', 'streaming'), supporting_protocols=[2])} @@ -368,7 +349,7 @@ class InputHeader(dict): """ def __str__(self): - return '\n'.join([name + ':' + value for name, value in six.iteritems(self)]) + return '\n'.join([name + ':' + value for name, value in self.items()]) def read(self, ifile): """ Reads an input header from an input file. @@ -416,7 +397,7 @@ def _object_hook(dictionary): while len(stack): instance, member_name, dictionary = stack.popleft() - for name, value in six.iteritems(dictionary): + for name, value in dictionary.items(): if isinstance(value, dict): stack.append((dictionary, name, value)) @@ -437,11 +418,14 @@ def default(self, o): _separators = (',', ':') -class ObjectView(object): +class ObjectView: def __init__(self, dictionary): self.__dict__ = dictionary + def update(self, obj): + self.__dict__.update(obj.__dict__) + def __repr__(self): return repr(self.__dict__) @@ -449,7 +433,7 @@ def __str__(self): return str(self.__dict__) -class Recorder(object): +class Recorder: def __init__(self, path, f): self._recording = gzip.open(path + '.gz', 'wb') @@ -487,7 +471,7 @@ def write(self, text): self._recording.flush() -class RecordWriter(object): +class RecordWriter: def __init__(self, ofile, maxresultrows=None): self._maxresultrows = 50000 if maxresultrows is None else maxresultrows @@ -513,7 +497,7 @@ def is_flushed(self): @is_flushed.setter def is_flushed(self, value): - self._flushed = True if value else False + self._flushed = bool(value) @property def ofile(self): @@ -593,7 +577,7 @@ def _write_record(self, record): if fieldnames is None: self._fieldnames = fieldnames = list(record.keys()) self._fieldnames.extend([i for i in self.custom_fields if i not in self._fieldnames]) - value_list = imap(lambda fn: (str(fn), str('__mv_') + str(fn)), fieldnames) + value_list = map(lambda fn: (str(fn), str('__mv_') + str(fn)), fieldnames) self._writerow(list(chain.from_iterable(value_list))) get_value = record.get @@ -632,9 +616,9 @@ def _write_record(self, record): if value_t is bool: value = str(value.real) - elif value_t is six.text_type: + elif value_t is str: value = value - elif isinstance(value, six.integer_types) or value_t is float or value_t is complex: + elif isinstance(value, int) or value_t is float or value_t is complex: value = str(value) elif issubclass(value_t, (dict, list, tuple)): value = str(''.join(RecordWriter._iterencode_json(value, 0))) @@ -658,13 +642,11 @@ def _write_record(self, record): values += (value, None) continue - if value_t is six.text_type: - if six.PY2: - value = value.encode('utf-8') + if value_t is str: values += (value, None) continue - if isinstance(value, six.integer_types) or value_t is float or value_t is complex: + if isinstance(value, int) or value_t is float or value_t is complex: values += (str(value), None) continue @@ -799,16 +781,15 @@ def write_chunk(self, finished=None): if len(inspector) == 0: inspector = None - metadata = [item for item in (('inspector', inspector), ('finished', finished))] + metadata = [('inspector', inspector), ('finished', finished)] self._write_chunk(metadata, self._buffer.getvalue()) self._clear() def write_metadata(self, configuration): self._ensure_validity() - metadata = chain(six.iteritems(configuration), (('inspector', self._inspector if self._inspector else None),)) + metadata = chain(configuration.items(), (('inspector', self._inspector if self._inspector else None),)) self._write_chunk(metadata, '') - self.write('\n') self._clear() def write_metric(self, name, value): @@ -816,13 +797,13 @@ def write_metric(self, name, value): self._inspector['metric.' + name] = value def _clear(self): - super(RecordWriterV2, self)._clear() + super()._clear() self._fieldnames = None def _write_chunk(self, metadata, body): if metadata: - metadata = str(''.join(self._iterencode_json(dict([(n, v) for n, v in metadata if v is not None]), 0))) + metadata = str(''.join(self._iterencode_json(dict((n, v) for n, v in metadata if v is not None), 0))) if sys.version_info >= (3, 0): metadata = metadata.encode('utf-8') metadata_length = len(metadata) @@ -836,7 +817,7 @@ def _write_chunk(self, metadata, body): if not (metadata_length > 0 or body_length > 0): return - start_line = 'chunked 1.0,%s,%s\n' % (metadata_length, body_length) + start_line = f'chunked 1.0,{metadata_length},{body_length}\n' self.write(start_line) self.write(metadata) self.write(body) diff --git a/splunklib/searchcommands/reporting_command.py b/splunklib/searchcommands/reporting_command.py index 947086197..5df3dc7e7 100644 --- a/splunklib/searchcommands/reporting_command.py +++ b/splunklib/searchcommands/reporting_command.py @@ -1,6 +1,6 @@ # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,8 +14,6 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - from itertools import chain from .internals import ConfigurationSettingsType, json_encode_string @@ -23,7 +21,6 @@ from .streaming_command import StreamingCommand from .search_command import SearchCommand from .validators import Set -from splunklib import six class ReportingCommand(SearchCommand): @@ -94,7 +91,7 @@ def prepare(self): self._configuration.streaming_preop = ' '.join(streaming_preop) return - raise RuntimeError('Unrecognized reporting command phase: {}'.format(json_encode_string(six.text_type(phase)))) + raise RuntimeError(f'Unrecognized reporting command phase: {json_encode_string(str(phase))}') def reduce(self, records): """ Override this method to produce a reporting data structure. @@ -244,7 +241,7 @@ def fix_up(cls, command): """ if not issubclass(command, ReportingCommand): - raise TypeError('{} is not a ReportingCommand'.format( command)) + raise TypeError(f'{command} is not a ReportingCommand') if command.reduce == ReportingCommand.reduce: raise AttributeError('No ReportingCommand.reduce override') @@ -274,8 +271,7 @@ def fix_up(cls, command): ConfigurationSetting.fix_up(f.ConfigurationSettings, settings) del f._settings - pass + # endregion - pass # endregion diff --git a/splunklib/searchcommands/search_command.py b/splunklib/searchcommands/search_command.py index dd11391d6..7e8f771e1 100644 --- a/splunklib/searchcommands/search_command.py +++ b/splunklib/searchcommands/search_command.py @@ -1,6 +1,6 @@ # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,44 +14,32 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - # Absolute imports -from collections import namedtuple - +import csv import io - -from collections import OrderedDict +import os +import re +import sys +import tempfile +import traceback +from collections import namedtuple, OrderedDict from copy import deepcopy -from splunklib.six.moves import StringIO +from io import StringIO from itertools import chain, islice -from splunklib.six.moves import filter as ifilter, map as imap, zip as izip -from splunklib import six -if six.PY2: - from logging import _levelNames, getLevelName, getLogger -else: - from logging import _nameToLevel as _levelNames, getLevelName, getLogger -try: - from shutil import make_archive -except ImportError: - # Used for recording, skip on python 2.6 - pass +from logging import _nameToLevel as _levelNames, getLevelName, getLogger +from shutil import make_archive from time import time -from splunklib.six.moves.urllib.parse import unquote -from splunklib.six.moves.urllib.parse import urlsplit +from urllib.parse import unquote +from urllib.parse import urlsplit from warnings import warn from xml.etree import ElementTree +from splunklib.utils import ensure_str -import os -import sys -import re -import csv -import tempfile -import traceback # Relative imports - +import splunklib +from . import Boolean, Option, environment from .internals import ( CommandLineParser, CsvDialect, @@ -64,8 +52,6 @@ RecordWriterV1, RecordWriterV2, json_encode_string) - -from . import Boolean, Option, environment from ..client import Service @@ -91,7 +77,7 @@ # P2 [ ] TODO: Consider bumping None formatting up to Option.Item.__str__ -class SearchCommand(object): +class SearchCommand: """ Represents a custom search command. """ @@ -158,16 +144,16 @@ def logging_level(self): def logging_level(self, value): if value is None: value = self._default_logging_level - if isinstance(value, (bytes, six.text_type)): + if isinstance(value, (bytes, str)): try: level = _levelNames[value.upper()] except KeyError: - raise ValueError('Unrecognized logging level: {}'.format(value)) + raise ValueError(f'Unrecognized logging level: {value}') else: try: level = int(value) except ValueError: - raise ValueError('Unrecognized logging level: {}'.format(value)) + raise ValueError(f'Unrecognized logging level: {value}') self._logger.setLevel(level) def add_field(self, current_record, field_name, field_value): @@ -291,7 +277,7 @@ def search_results_info(self): values = next(reader) except IOError as error: if error.errno == 2: - self.logger.error('Search results info file {} does not exist.'.format(json_encode_string(path))) + self.logger.error(f'Search results info file {json_encode_string(path)} does not exist.') return raise @@ -306,7 +292,7 @@ def convert_value(value): except ValueError: return value - info = ObjectView(dict(imap(lambda f_v: (convert_field(f_v[0]), convert_value(f_v[1])), izip(fields, values)))) + info = ObjectView(dict((convert_field(f_v[0]), convert_value(f_v[1])) for f_v in zip(fields, values))) try: count_map = info.countMap @@ -315,7 +301,7 @@ def convert_value(value): else: count_map = count_map.split(';') n = len(count_map) - info.countMap = dict(izip(islice(count_map, 0, n, 2), islice(count_map, 1, n, 2))) + info.countMap = dict(list(zip(islice(count_map, 0, n, 2), islice(count_map, 1, n, 2)))) try: msg_type = info.msgType @@ -323,7 +309,7 @@ def convert_value(value): except AttributeError: pass else: - messages = ifilter(lambda t_m: t_m[0] or t_m[1], izip(msg_type.split('\n'), msg_text.split('\n'))) + messages = [t_m for t_m in zip(msg_type.split('\n'), msg_text.split('\n')) if t_m[0] or t_m[1]] info.msg = [Message(message) for message in messages] del info.msgType @@ -417,7 +403,6 @@ def prepare(self): :rtype: NoneType """ - pass def process(self, argv=sys.argv, ifile=sys.stdin, ofile=sys.stdout, allow_empty_input=True): """ Process data. @@ -466,7 +451,7 @@ def _map_metadata(self, argv): def _map(metadata_map): metadata = {} - for name, value in six.iteritems(metadata_map): + for name, value in metadata_map.items(): if isinstance(value, dict): value = _map(value) else: @@ -485,7 +470,8 @@ def _map(metadata_map): _metadata_map = { 'action': - (lambda v: 'getinfo' if v == '__GETINFO__' else 'execute' if v == '__EXECUTE__' else None, lambda s: s.argv[1]), + (lambda v: 'getinfo' if v == '__GETINFO__' else 'execute' if v == '__EXECUTE__' else None, + lambda s: s.argv[1]), 'preview': (bool, lambda s: s.input_header.get('preview')), 'searchinfo': { @@ -533,7 +519,7 @@ def _prepare_protocol_v1(self, argv, ifile, ofile): try: tempfile.tempdir = self._metadata.searchinfo.dispatch_dir except AttributeError: - raise RuntimeError('{}.metadata.searchinfo.dispatch_dir is undefined'.format(self.__class__.__name__)) + raise RuntimeError(f'{self.__class__.__name__}.metadata.searchinfo.dispatch_dir is undefined') debug(' tempfile.tempdir=%r', tempfile.tempdir) @@ -603,7 +589,8 @@ def _process_protocol_v1(self, argv, ifile, ofile): ifile = self._prepare_protocol_v1(argv, ifile, ofile) self._record_writer.write_record(dict( - (n, ','.join(v) if isinstance(v, (list, tuple)) else v) for n, v in six.iteritems(self._configuration))) + (n, ','.join(v) if isinstance(v, (list, tuple)) else v) for n, v in + self._configuration.items())) self.finish() elif argv[1] == '__EXECUTE__': @@ -617,21 +604,21 @@ def _process_protocol_v1(self, argv, ifile, ofile): else: message = ( - 'Command {0} appears to be statically configured for search command protocol version 1 and static ' + f'Command {self.name} appears to be statically configured for search command protocol version 1 and static ' 'configuration is unsupported by splunklib.searchcommands. Please ensure that ' 'default/commands.conf contains this stanza:\n' - '[{0}]\n' - 'filename = {1}\n' + f'[{self.name}]\n' + f'filename = {os.path.basename(argv[0])}\n' 'enableheader = true\n' 'outputheader = true\n' 'requires_srinfo = true\n' 'supports_getinfo = true\n' 'supports_multivalues = true\n' - 'supports_rawargs = true'.format(self.name, os.path.basename(argv[0]))) + 'supports_rawargs = true') raise RuntimeError(message) except (SyntaxError, ValueError) as error: - self.write_error(six.text_type(error)) + self.write_error(str(error)) self.flush() exit(0) @@ -686,7 +673,7 @@ def _process_protocol_v2(self, argv, ifile, ofile): action = getattr(metadata, 'action', None) if action != 'getinfo': - raise RuntimeError('Expected getinfo action, not {}'.format(action)) + raise RuntimeError(f'Expected getinfo action, not {action}') if len(body) > 0: raise RuntimeError('Did not expect data for getinfo action') @@ -706,7 +693,7 @@ def _process_protocol_v2(self, argv, ifile, ofile): try: tempfile.tempdir = self._metadata.searchinfo.dispatch_dir except AttributeError: - raise RuntimeError('%s.metadata.searchinfo.dispatch_dir is undefined'.format(class_name)) + raise RuntimeError(f'{class_name}.metadata.searchinfo.dispatch_dir is undefined') debug(' tempfile.tempdir=%r', tempfile.tempdir) except: @@ -727,7 +714,7 @@ def _process_protocol_v2(self, argv, ifile, ofile): debug('Parsing arguments') - if args and type(args) == list: + if args and isinstance(args, list): for arg in args: result = self._protocol_v2_option_parser(arg) if len(result) == 1: @@ -738,13 +725,13 @@ def _process_protocol_v2(self, argv, ifile, ofile): try: option = self.options[name] except KeyError: - self.write_error('Unrecognized option: {}={}'.format(name, value)) + self.write_error(f'Unrecognized option: {name}={value}') error_count += 1 continue try: option.value = value except ValueError: - self.write_error('Illegal value: {}={}'.format(name, value)) + self.write_error(f'Illegal value: {name}={value}') error_count += 1 continue @@ -752,15 +739,15 @@ def _process_protocol_v2(self, argv, ifile, ofile): if missing is not None: if len(missing) == 1: - self.write_error('A value for "{}" is required'.format(missing[0])) + self.write_error(f'A value for "{missing[0]}" is required') else: - self.write_error('Values for these required options are missing: {}'.format(', '.join(missing))) + self.write_error(f'Values for these required options are missing: {", ".join(missing)}') error_count += 1 if error_count > 0: exit(1) - debug(' command: %s', six.text_type(self)) + debug(' command: %s', str(self)) debug('Preparing for execution') self.prepare() @@ -778,7 +765,7 @@ def _process_protocol_v2(self, argv, ifile, ofile): setattr(info, attr, [arg for arg in getattr(info, attr) if not arg.startswith('record=')]) metadata = MetadataEncoder().encode(self._metadata) - ifile.record('chunked 1.0,', six.text_type(len(metadata)), ',0\n', metadata) + ifile.record('chunked 1.0,', str(len(metadata)), ',0\n', metadata) if self.show_configuration: self.write_info(self.name + ' command configuration: ' + str(self._configuration)) @@ -888,25 +875,25 @@ def _as_binary_stream(ifile): try: return ifile.buffer except AttributeError as error: - raise RuntimeError('Failed to get underlying buffer: {}'.format(error)) + raise RuntimeError(f'Failed to get underlying buffer: {error}') @staticmethod def _read_chunk(istream): # noinspection PyBroadException - assert isinstance(istream.read(0), six.binary_type), 'Stream must be binary' + assert isinstance(istream.read(0), bytes), 'Stream must be binary' try: header = istream.readline() except Exception as error: - raise RuntimeError('Failed to read transport header: {}'.format(error)) + raise RuntimeError(f'Failed to read transport header: {error}') if not header: return None - match = SearchCommand._header.match(six.ensure_str(header)) + match = SearchCommand._header.match(ensure_str(header)) if match is None: - raise RuntimeError('Failed to parse transport header: {}'.format(header)) + raise RuntimeError(f'Failed to parse transport header: {header}') metadata_length, body_length = match.groups() metadata_length = int(metadata_length) @@ -915,14 +902,14 @@ def _read_chunk(istream): try: metadata = istream.read(metadata_length) except Exception as error: - raise RuntimeError('Failed to read metadata of length {}: {}'.format(metadata_length, error)) + raise RuntimeError(f'Failed to read metadata of length {metadata_length}: {error}') decoder = MetadataDecoder() try: - metadata = decoder.decode(six.ensure_str(metadata)) + metadata = decoder.decode(ensure_str(metadata)) except Exception as error: - raise RuntimeError('Failed to parse metadata of length {}: {}'.format(metadata_length, error)) + raise RuntimeError(f'Failed to parse metadata of length {metadata_length}: {error}') # if body_length <= 0: # return metadata, '' @@ -932,9 +919,9 @@ def _read_chunk(istream): if body_length > 0: body = istream.read(body_length) except Exception as error: - raise RuntimeError('Failed to read body of length {}: {}'.format(body_length, error)) + raise RuntimeError(f'Failed to read body of length {body_length}: {error}') - return metadata, six.ensure_str(body) + return metadata, ensure_str(body,errors="replace") _header = re.compile(r'chunked\s+1.0\s*,\s*(\d+)\s*,\s*(\d+)\s*\n') @@ -949,16 +936,16 @@ def _read_csv_records(self, ifile): except StopIteration: return - mv_fieldnames = dict([(name, name[len('__mv_'):]) for name in fieldnames if name.startswith('__mv_')]) + mv_fieldnames = dict((name, name[len('__mv_'):]) for name in fieldnames if name.startswith('__mv_')) if len(mv_fieldnames) == 0: for values in reader: - yield OrderedDict(izip(fieldnames, values)) + yield OrderedDict(list(zip(fieldnames, values))) return for values in reader: record = OrderedDict() - for fieldname, value in izip(fieldnames, values): + for fieldname, value in zip(fieldnames, values): if fieldname.startswith('__mv_'): if len(value) > 0: record[mv_fieldnames[fieldname]] = self._decode_list(value) @@ -978,25 +965,25 @@ def _execute_v2(self, ifile, process): metadata, body = result action = getattr(metadata, 'action', None) if action != 'execute': - raise RuntimeError('Expected execute action, not {}'.format(action)) + raise RuntimeError(f'Expected execute action, not {action}') self._finished = getattr(metadata, 'finished', False) self._record_writer.is_flushed = False - + self._metadata.update(metadata) self._execute_chunk_v2(process, result) self._record_writer.write_chunk(finished=self._finished) def _execute_chunk_v2(self, process, chunk): - metadata, body = chunk + metadata, body = chunk - if len(body) <= 0 and not self._allow_empty_input: - raise ValueError( - "No records found to process. Set allow_empty_input=True in dispatch function to move forward " - "with empty records.") + if len(body) <= 0 and not self._allow_empty_input: + raise ValueError( + "No records found to process. Set allow_empty_input=True in dispatch function to move forward " + "with empty records.") - records = self._read_csv_records(StringIO(body)) - self._record_writer.write_records(process(records)) + records = self._read_csv_records(StringIO(body)) + self._record_writer.write_records(process(records)) def _report_unexpected_error(self): @@ -1008,7 +995,7 @@ def _report_unexpected_error(self): filename = origin.tb_frame.f_code.co_filename lineno = origin.tb_lineno - message = '{0} at "{1}", line {2:d} : {3}'.format(error_type.__name__, filename, lineno, error) + message = f'{error_type.__name__} at "{filename}", line {str(lineno)} : {error}' environment.splunklib_logger.error(message + '\nTraceback:\n' + ''.join(traceback.format_tb(tb))) self.write_error(message) @@ -1017,10 +1004,11 @@ def _report_unexpected_error(self): # region Types - class ConfigurationSettings(object): + class ConfigurationSettings: """ Represents the configuration settings common to all :class:`SearchCommand` classes. """ + def __init__(self, command): self.command = command @@ -1034,8 +1022,8 @@ def __repr__(self): """ definitions = type(self).configuration_setting_definitions - settings = imap( - lambda setting: repr((setting.name, setting.__get__(self), setting.supporting_protocols)), definitions) + settings = [repr((setting.name, setting.__get__(self), setting.supporting_protocols)) for setting in + definitions] return '[' + ', '.join(settings) + ']' def __str__(self): @@ -1047,8 +1035,8 @@ def __str__(self): :return: String representation of this instance """ - #text = ', '.join(imap(lambda (name, value): name + '=' + json_encode_string(unicode(value)), self.iteritems())) - text = ', '.join(['{}={}'.format(name, json_encode_string(six.text_type(value))) for (name, value) in six.iteritems(self)]) + # text = ', '.join(imap(lambda (name, value): name + '=' + json_encode_string(unicode(value)), self.iteritems())) + text = ', '.join([f'{name}={json_encode_string(str(value))}' for (name, value) in self.items()]) return text # region Methods @@ -1072,24 +1060,25 @@ def fix_up(cls, command_class): def iteritems(self): definitions = type(self).configuration_setting_definitions version = self.command.protocol_version - return ifilter( - lambda name_value1: name_value1[1] is not None, imap( - lambda setting: (setting.name, setting.__get__(self)), ifilter( - lambda setting: setting.is_supported_by_protocol(version), definitions))) + return [name_value1 for name_value1 in [(setting.name, setting.__get__(self)) for setting in + [setting for setting in definitions if + setting.is_supported_by_protocol(version)]] if + name_value1[1] is not None] # N.B.: Does not use Python 3 dict view semantics - if not six.PY2: - items = iteritems - pass # endregion + items = iteritems - pass # endregion + # endregion + + # endregion SearchMetric = namedtuple('SearchMetric', ('elapsed_seconds', 'invocation_count', 'input_count', 'output_count')) -def dispatch(command_class, argv=sys.argv, input_file=sys.stdin, output_file=sys.stdout, module_name=None, allow_empty_input=True): +def dispatch(command_class, argv=sys.argv, input_file=sys.stdin, output_file=sys.stdout, module_name=None, + allow_empty_input=True): """ Instantiates and executes a search command class This function implements a `conditional script stanza `_ based on the value of diff --git a/splunklib/searchcommands/streaming_command.py b/splunklib/searchcommands/streaming_command.py index fa075edb1..e2a3a4077 100644 --- a/splunklib/searchcommands/streaming_command.py +++ b/splunklib/searchcommands/streaming_command.py @@ -1,6 +1,6 @@ # coding=utf-8 # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,10 +14,6 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - -from splunklib import six -from splunklib.six.moves import map as imap, filter as ifilter from .decorators import ConfigurationSetting from .search_command import SearchCommand @@ -171,7 +167,6 @@ def fix_up(cls, command): """ if command.stream == StreamingCommand.stream: raise AttributeError('No StreamingCommand.stream override') - return # TODO: Stop looking like a dictionary because we don't obey the semantics # N.B.: Does not use Python 2 dict copy semantics @@ -180,16 +175,14 @@ def iteritems(self): version = self.command.protocol_version if version == 1: if self.required_fields is None: - iteritems = ifilter(lambda name_value: name_value[0] != 'clear_required_fields', iteritems) + iteritems = [name_value for name_value in iteritems if name_value[0] != 'clear_required_fields'] else: - iteritems = ifilter(lambda name_value2: name_value2[0] != 'distributed', iteritems) + iteritems = [name_value2 for name_value2 in iteritems if name_value2[0] != 'distributed'] if not self.distributed: - iteritems = imap( - lambda name_value1: (name_value1[0], 'stateful') if name_value1[0] == 'type' else (name_value1[0], name_value1[1]), iteritems) + iteritems = [(name_value1[0], 'stateful') if name_value1[0] == 'type' else (name_value1[0], name_value1[1]) for name_value1 in iteritems] return iteritems # N.B.: Does not use Python 3 dict view semantics - if not six.PY2: - items = iteritems + items = iteritems # endregion diff --git a/splunklib/searchcommands/validators.py b/splunklib/searchcommands/validators.py index 22f0e16b2..ccaebca0a 100644 --- a/splunklib/searchcommands/validators.py +++ b/splunklib/searchcommands/validators.py @@ -1,6 +1,6 @@ # coding=utf-8 # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,20 +14,17 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - -from json.encoder import encode_basestring_ascii as json_encode_string -from collections import namedtuple -from splunklib.six.moves import StringIO -from io import open import csv import os import re -from splunklib import six -from splunklib.six.moves import getcwd +from io import open, StringIO +from os import getcwd +from json.encoder import encode_basestring_ascii as json_encode_string +from collections import namedtuple -class Validator(object): + +class Validator: """ Base class for validators that check and format search command options. You must inherit from this class and override :code:`Validator.__call__` and @@ -60,14 +57,16 @@ class Boolean(Validator): def __call__(self, value): if not (value is None or isinstance(value, bool)): - value = six.text_type(value).lower() + value = str(value).lower() if value not in Boolean.truth_values: - raise ValueError('Unrecognized truth value: {0}'.format(value)) + raise ValueError(f'Unrecognized truth value: {value}') value = Boolean.truth_values[value] return value def format(self, value): - return None if value is None else 't' if value else 'f' + if value is None: + return None + return 't' if value else 'f' class Code(Validator): @@ -93,11 +92,11 @@ def __call__(self, value): if value is None: return None try: - return Code.object(compile(value, 'string', self._mode), six.text_type(value)) + return Code.object(compile(value, 'string', self._mode), str(value)) except (SyntaxError, TypeError) as error: message = str(error) - six.raise_from(ValueError(message), error) + raise ValueError(message) from error def format(self, value): return None if value is None else value.source @@ -113,9 +112,9 @@ class Fieldname(Validator): def __call__(self, value): if value is not None: - value = six.text_type(value) + value = str(value) if Fieldname.pattern.match(value) is None: - raise ValueError('Illegal characters in fieldname: {}'.format(value)) + raise ValueError(f'Illegal characters in fieldname: {value}') return value def format(self, value): @@ -136,7 +135,7 @@ def __call__(self, value): if value is None: return value - path = six.text_type(value) + path = str(value) if not os.path.isabs(path): path = os.path.join(self.directory, path) @@ -144,8 +143,7 @@ def __call__(self, value): try: value = open(path, self.mode) if self.buffering is None else open(path, self.mode, self.buffering) except IOError as error: - raise ValueError('Cannot open {0} with mode={1} and buffering={2}: {3}'.format( - value, self.mode, self.buffering, error)) + raise ValueError(f'Cannot open {value} with mode={self.mode} and buffering={self.buffering}: {error}') return value @@ -163,42 +161,38 @@ class Integer(Validator): def __init__(self, minimum=None, maximum=None): if minimum is not None and maximum is not None: def check_range(value): - if not (minimum <= value <= maximum): - raise ValueError('Expected integer in the range [{0},{1}], not {2}'.format(minimum, maximum, value)) - return + if not minimum <= value <= maximum: + raise ValueError(f'Expected integer in the range [{minimum},{maximum}], not {value}') + elif minimum is not None: def check_range(value): if value < minimum: - raise ValueError('Expected integer in the range [{0},+∞], not {1}'.format(minimum, value)) - return + raise ValueError(f'Expected integer in the range [{minimum},+∞], not {value}') elif maximum is not None: def check_range(value): if value > maximum: - raise ValueError('Expected integer in the range [-∞,{0}], not {1}'.format(maximum, value)) - return + raise ValueError(f'Expected integer in the range [-∞,{maximum}], not {value}') + else: def check_range(value): return self.check_range = check_range - return + def __call__(self, value): if value is None: return None try: - if six.PY2: - value = long(value) - else: - value = int(value) + value = int(value) except ValueError: - raise ValueError('Expected integer value, not {}'.format(json_encode_string(value))) + raise ValueError(f'Expected integer value, not {json_encode_string(value)}') self.check_range(value) return value def format(self, value): - return None if value is None else six.text_type(int(value)) + return None if value is None else str(int(value)) class Float(Validator): @@ -208,25 +202,21 @@ class Float(Validator): def __init__(self, minimum=None, maximum=None): if minimum is not None and maximum is not None: def check_range(value): - if not (minimum <= value <= maximum): - raise ValueError('Expected float in the range [{0},{1}], not {2}'.format(minimum, maximum, value)) - return + if not minimum <= value <= maximum: + raise ValueError(f'Expected float in the range [{minimum},{maximum}], not {value}') elif minimum is not None: def check_range(value): if value < minimum: - raise ValueError('Expected float in the range [{0},+∞], not {1}'.format(minimum, value)) - return + raise ValueError(f'Expected float in the range [{minimum},+∞], not {value}') elif maximum is not None: def check_range(value): if value > maximum: - raise ValueError('Expected float in the range [-∞,{0}], not {1}'.format(maximum, value)) - return + raise ValueError(f'Expected float in the range [-∞,{maximum}], not {value}') else: def check_range(value): return - self.check_range = check_range - return + def __call__(self, value): if value is None: @@ -234,13 +224,13 @@ def __call__(self, value): try: value = float(value) except ValueError: - raise ValueError('Expected float value, not {}'.format(json_encode_string(value))) + raise ValueError(f'Expected float value, not {json_encode_string(value)}') self.check_range(value) return value def format(self, value): - return None if value is None else six.text_type(float(value)) + return None if value is None else str(float(value)) class Duration(Validator): @@ -265,7 +255,7 @@ def __call__(self, value): if len(p) == 3: result = 3600 * _unsigned(p[0]) + 60 * _60(p[1]) + _60(p[2]) except ValueError: - raise ValueError('Invalid duration value: {0}'.format(value)) + raise ValueError(f'Invalid duration value: {value}') return result @@ -302,7 +292,7 @@ class Dialect(csv.Dialect): def __init__(self, validator=None): if not (validator is None or isinstance(validator, Validator)): - raise ValueError('Expected a Validator instance or None for validator, not {}', repr(validator)) + raise ValueError(f'Expected a Validator instance or None for validator, not {repr(validator)}') self._validator = validator def __call__(self, value): @@ -322,7 +312,7 @@ def __call__(self, value): for index, item in enumerate(value): value[index] = self._validator(item) except ValueError as error: - raise ValueError('Could not convert item {}: {}'.format(index, error)) + raise ValueError(f'Could not convert item {index}: {error}') return value @@ -346,10 +336,10 @@ def __call__(self, value): if value is None: return None - value = six.text_type(value) + value = str(value) if value not in self.membership: - raise ValueError('Unrecognized value: {0}'.format(value)) + raise ValueError(f'Unrecognized value: {value}') return self.membership[value] @@ -362,19 +352,19 @@ class Match(Validator): """ def __init__(self, name, pattern, flags=0): - self.name = six.text_type(name) + self.name = str(name) self.pattern = re.compile(pattern, flags) def __call__(self, value): if value is None: return None - value = six.text_type(value) + value = str(value) if self.pattern.match(value) is None: - raise ValueError('Expected {}, not {}'.format(self.name, json_encode_string(value))) + raise ValueError(f'Expected {self.name}, not {json_encode_string(value)}') return value def format(self, value): - return None if value is None else six.text_type(value) + return None if value is None else str(value) class OptionName(Validator): @@ -385,13 +375,13 @@ class OptionName(Validator): def __call__(self, value): if value is not None: - value = six.text_type(value) + value = str(value) if OptionName.pattern.match(value) is None: - raise ValueError('Illegal characters in option name: {}'.format(value)) + raise ValueError(f'Illegal characters in option name: {value}') return value def format(self, value): - return None if value is None else six.text_type(value) + return None if value is None else str(value) class RegularExpression(Validator): @@ -402,9 +392,9 @@ def __call__(self, value): if value is None: return None try: - value = re.compile(six.text_type(value)) + value = re.compile(str(value)) except re.error as error: - raise ValueError('{}: {}'.format(six.text_type(error).capitalize(), value)) + raise ValueError(f'{str(error).capitalize()}: {value}') return value def format(self, value): @@ -421,9 +411,9 @@ def __init__(self, *args): def __call__(self, value): if value is None: return None - value = six.text_type(value) + value = str(value) if value not in self.membership: - raise ValueError('Unrecognized value: {}'.format(value)) + raise ValueError(f'Unrecognized value: {value}') return value def format(self, value): diff --git a/splunklib/utils.py b/splunklib/utils.py new file mode 100644 index 000000000..db9c31267 --- /dev/null +++ b/splunklib/utils.py @@ -0,0 +1,48 @@ +# Copyright © 2011-2024 Splunk, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"): you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""The **splunklib.utils** File for utility functions. +""" + + +def ensure_binary(s, encoding='utf-8', errors='strict'): + """ + - `str` -> encoded to `bytes` + - `bytes` -> `bytes` + """ + if isinstance(s, str): + return s.encode(encoding, errors) + + if isinstance(s, bytes): + return s + + raise TypeError(f"not expecting type '{type(s)}'") + + +def ensure_str(s, encoding='utf-8', errors='strict'): + """ + - `str` -> `str` + - `bytes` -> decoded to `str` + """ + if isinstance(s, bytes): + return s.decode(encoding, errors) + + if isinstance(s, str): + return s + + raise TypeError(f"not expecting type '{type(s)}'") + + +def assertRegex(self, *args, **kwargs): + return getattr(self, "assertRegex")(*args, **kwargs) diff --git a/tests/README.md b/tests/README.md index 3da69c9f9..da02228c7 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,10 +1,8 @@ # Splunk Test Suite The test suite uses Python's standard library and the built-in **unittest** -library. If you're using Python 2.7 or Python 3.7, you're all set. However, if you are using -Python 2.6, you'll also need to install the **unittest2** library to get the -additional features that were added to Python 2.7 (just run `pip install -unittest2` or `easy_install unittest2`). +library. The Splunk Enterprise SDK for Python has been tested with Python v3.7 +and v3.9. To run the unit tests, open a command prompt in the **/splunk-sdk-python** directory and enter: diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 2ae28399f..000000000 --- a/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -pass diff --git a/tests/modularinput/__init__.py b/tests/modularinput/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/modularinput/modularinput_testlib.py b/tests/modularinput/modularinput_testlib.py index 819301736..4bf0df13f 100644 --- a/tests/modularinput/modularinput_testlib.py +++ b/tests/modularinput/modularinput_testlib.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,14 +14,10 @@ # License for the specific language governing permissions and limitations # under the License. -# Utility file for unit tests, import common functions and modules -from __future__ import absolute_import -try: - import unittest2 as unittest -except ImportError: - import unittest -import sys, os import io +import os +import sys +import unittest sys.path.insert(0, os.path.join('../../splunklib', '..')) diff --git a/tests/modularinput/test_event.py b/tests/modularinput/test_event.py index 865656031..35e9c09cd 100644 --- a/tests/modularinput/test_event.py +++ b/tests/modularinput/test_event.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,13 +14,14 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import +import re import sys +from io import StringIO import pytest -from tests.modularinput.modularinput_testlib import unittest, xml_compare, data_open +from tests.modularinput.modularinput_testlib import xml_compare, data_open from splunklib.modularinput.event import Event, ET from splunklib.modularinput.event_writer import EventWriter @@ -99,7 +100,7 @@ def test_writing_events_on_event_writer(capsys): first_out_part = captured.out with data_open("data/stream_with_one_event.xml") as data: - found = ET.fromstring("%s" % first_out_part) + found = ET.fromstring(f"{first_out_part}") expected = ET.parse(data).getroot() assert xml_compare(expected, found) @@ -151,3 +152,29 @@ def test_write_xml_is_sane(capsys): found_xml = ET.fromstring(captured.out) assert xml_compare(expected_xml, found_xml) + + +def test_log_exception(): + out, err = StringIO(), StringIO() + ew = EventWriter(out, err) + + exc = Exception("Something happened!") + + try: + raise exc + except Exception: + ew.log_exception("ex1") + + assert out.getvalue() == "" + + # Remove paths and line + err = re.sub(r'File "[^"]+', 'File "...', err.getvalue()) + err = re.sub(r'line \d+', 'line 123', err) + + # One line + assert err == ( + 'ERROR ex1 - Traceback (most recent call last): ' + ' File "...", line 123, in test_log_exception ' + ' raise exc ' + 'Exception: Something happened! ' + ) diff --git a/tests/modularinput/test_input_definition.py b/tests/modularinput/test_input_definition.py index d0f59a04e..93601b352 100644 --- a/tests/modularinput/test_input_definition.py +++ b/tests/modularinput/test_input_definition.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,10 +14,10 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests.modularinput.modularinput_testlib import unittest, data_open from splunklib.modularinput.input_definition import InputDefinition + class InputDefinitionTestCase(unittest.TestCase): def test_parse_inputdef_with_zero_inputs(self): @@ -70,7 +70,8 @@ def test_attempt_to_parse_malformed_input_definition_will_throw_exception(self): """Does malformed XML cause the expected exception.""" with self.assertRaises(ValueError): - found = InputDefinition.parse(data_open("data/conf_with_invalid_inputs.xml")) + InputDefinition.parse(data_open("data/conf_with_invalid_inputs.xml")) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/modularinput/test_scheme.py b/tests/modularinput/test_scheme.py index e1b3463a3..7e7ddcc9c 100644 --- a/tests/modularinput/test_scheme.py +++ b/tests/modularinput/test_scheme.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -13,15 +13,11 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import +import xml.etree.ElementTree as ET from tests.modularinput.modularinput_testlib import unittest, xml_compare, data_open from splunklib.modularinput.scheme import Scheme from splunklib.modularinput.argument import Argument -try: - import xml.etree.cElementTree as ET -except ImportError: - import xml.etree.ElementTree as ET class SchemeTest(unittest.TestCase): def test_generate_xml_from_scheme_with_default_values(self): @@ -40,7 +36,7 @@ def test_generate_xml_from_scheme(self): some arguments added matches what we expect.""" scheme = Scheme("abcd") - scheme.description = u"쎼 and 쎶 and <&> für" + scheme.description = "쎼 and 쎶 and <&> für" scheme.streaming_mode = Scheme.streaming_mode_simple scheme.use_external_validation = "false" scheme.use_single_instance = "true" @@ -50,7 +46,7 @@ def test_generate_xml_from_scheme(self): arg2 = Argument( name="arg2", - description=u"쎼 and 쎶 and <&> für", + description="쎼 and 쎶 and <&> für", validation="is_pos_int('some_name')", data_type=Argument.data_type_number, required_on_edit=True, @@ -69,7 +65,7 @@ def test_generate_xml_from_scheme_with_arg_title(self): some arguments added matches what we expect. Also sets the title on an argument.""" scheme = Scheme("abcd") - scheme.description = u"쎼 and 쎶 and <&> für" + scheme.description = "쎼 and 쎶 and <&> für" scheme.streaming_mode = Scheme.streaming_mode_simple scheme.use_external_validation = "false" scheme.use_single_instance = "true" @@ -79,7 +75,7 @@ def test_generate_xml_from_scheme_with_arg_title(self): arg2 = Argument( name="arg2", - description=u"쎼 and 쎶 and <&> für", + description="쎼 and 쎶 and <&> für", validation="is_pos_int('some_name')", data_type=Argument.data_type_number, required_on_edit=True, @@ -113,7 +109,7 @@ def test_generate_xml_from_argument(self): argument = Argument( name="some_name", - description=u"쎼 and 쎶 and <&> für", + description="쎼 and 쎶 and <&> für", validation="is_pos_int('some_name')", data_type=Argument.data_type_boolean, required_on_edit="true", diff --git a/tests/modularinput/test_script.py b/tests/modularinput/test_script.py index b15885dc7..49c259725 100644 --- a/tests/modularinput/test_script.py +++ b/tests/modularinput/test_script.py @@ -1,16 +1,14 @@ import sys +import io +import re +import xml.etree.ElementTree as ET from splunklib.client import Service from splunklib.modularinput import Script, EventWriter, Scheme, Argument, Event -import io from splunklib.modularinput.utils import xml_compare from tests.modularinput.modularinput_testlib import data_open -try: - import xml.etree.cElementTree as ET -except ImportError: - import xml.etree.ElementTree as ET TEST_SCRIPT_PATH = "__IGNORED_SCRIPT_PATH__" @@ -51,7 +49,7 @@ def test_scheme_properly_generated_by_script(capsys): class NewScript(Script): def get_scheme(self): scheme = Scheme("abcd") - scheme.description = u"\uC3BC and \uC3B6 and <&> f\u00FCr" + scheme.description = "\uC3BC and \uC3B6 and <&> f\u00FCr" scheme.streaming_mode = scheme.streaming_mode_simple scheme.use_external_validation = False scheme.use_single_instance = True @@ -60,7 +58,7 @@ def get_scheme(self): scheme.add_argument(arg1) arg2 = Argument("arg2") - arg2.description = u"\uC3BC and \uC3B6 and <&> f\u00FCr" + arg2.description = "\uC3BC and \uC3B6 and <&> f\u00FCr" arg2.data_type = Argument.data_type_number arg2.required_on_create = True arg2.required_on_edit = True @@ -208,7 +206,7 @@ def test_service_property(capsys): # Override abstract methods class NewScript(Script): def __init__(self): - super(NewScript, self).__init__() + super().__init__() self.authority_uri = None def get_scheme(self): @@ -231,3 +229,38 @@ def stream_events(self, inputs, ew): assert output.err == "" assert isinstance(script.service, Service) assert script.service.authority == script.authority_uri + + +def test_log_script_exception(monkeypatch): + out, err = io.StringIO(), io.StringIO() + + # Override abstract methods + class NewScript(Script): + def get_scheme(self): + return None + + def stream_events(self, inputs, ew): + raise RuntimeError("Some error") + + script = NewScript() + input_configuration = data_open("data/conf_with_2_inputs.xml") + + ew = EventWriter(out, err) + + assert script.run_script([TEST_SCRIPT_PATH], ew, input_configuration) == 1 + + # Remove paths and line numbers + err = re.sub(r'File "[^"]+', 'File "...', err.getvalue()) + err = re.sub(r'line \d+', 'line 123', err) + err = re.sub(r' +~+\^+', '', err) + + assert out.getvalue() == "" + assert err == ( + 'ERROR Some error - ' + 'Traceback (most recent call last): ' + ' File "...", line 123, in run_script ' + ' self.stream_events(self._input_definition, event_writer) ' + ' File "...", line 123, in stream_events ' + ' raise RuntimeError("Some error") ' + 'RuntimeError: Some error ' + ) diff --git a/tests/modularinput/test_validation_definition.py b/tests/modularinput/test_validation_definition.py index c8046f3b3..1b71a2206 100644 --- a/tests/modularinput/test_validation_definition.py +++ b/tests/modularinput/test_validation_definition.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,10 +14,11 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import + from tests.modularinput.modularinput_testlib import unittest, data_open from splunklib.modularinput.validation_definition import ValidationDefinition + class ValidationDefinitionTestCase(unittest.TestCase): def test_validation_definition_parse(self): """Check that parsing produces expected result""" @@ -42,5 +43,6 @@ def test_validation_definition_parse(self): self.assertEqual(expected, found) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/searchcommands/__init__.py b/tests/searchcommands/__init__.py index 2f282889a..41d8d0668 100644 --- a/tests/searchcommands/__init__.py +++ b/tests/searchcommands/__init__.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -15,10 +15,6 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - -from sys import version_info as python_version - from os import path import logging diff --git a/tests/searchcommands/chunked_data_stream.py b/tests/searchcommands/chunked_data_stream.py index ae5363eff..02d890af1 100644 --- a/tests/searchcommands/chunked_data_stream.py +++ b/tests/searchcommands/chunked_data_stream.py @@ -4,19 +4,19 @@ import json import splunklib.searchcommands.internals -from splunklib import six +from splunklib.utils import ensure_binary, ensure_str -class Chunk(object): +class Chunk: def __init__(self, version, meta, data): - self.version = six.ensure_str(version) + self.version = ensure_str(version) self.meta = json.loads(meta) dialect = splunklib.searchcommands.internals.CsvDialect self.data = csv.DictReader(io.StringIO(data.decode("utf-8")), dialect=dialect) -class ChunkedDataStreamIter(collections.Iterator): +class ChunkedDataStreamIter(collections.abc.Iterator): def __init__(self, chunk_stream): self.chunk_stream = chunk_stream @@ -30,7 +30,7 @@ def next(self): raise StopIteration -class ChunkedDataStream(collections.Iterable): +class ChunkedDataStream(collections.abc.Iterable): def __iter__(self): return ChunkedDataStreamIter(self) @@ -54,7 +54,7 @@ def read_chunk(self): def build_chunk(keyval, data=None): - metadata = six.ensure_binary(json.dumps(keyval), 'utf-8') + metadata = ensure_binary(json.dumps(keyval)) data_output = _build_data_csv(data) return b"chunked 1.0,%d,%d\n%s%s" % (len(metadata), len(data_output), metadata, data_output) @@ -87,7 +87,7 @@ def _build_data_csv(data): return b'' if isinstance(data, bytes): return data - csvout = splunklib.six.StringIO() + csvout = io.StringIO() headers = set() for datum in data: @@ -97,4 +97,4 @@ def _build_data_csv(data): writer.writeheader() for datum in data: writer.writerow(datum) - return six.ensure_binary(csvout.getvalue()) + return ensure_binary(csvout.getvalue()) diff --git a/tests/searchcommands/test_apps/eventing_app/bin/eventingcsc.py b/tests/searchcommands/test_apps/eventing_app/bin/eventingcsc.py index d34aa14e4..fafbe46f9 100644 --- a/tests/searchcommands/test_apps/eventing_app/bin/eventingcsc.py +++ b/tests/searchcommands/test_apps/eventing_app/bin/eventingcsc.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain diff --git a/tests/searchcommands/test_apps/generating_app/bin/generatingcsc.py b/tests/searchcommands/test_apps/generating_app/bin/generatingcsc.py index 2094ade75..4fe3e765a 100644 --- a/tests/searchcommands/test_apps/generating_app/bin/generatingcsc.py +++ b/tests/searchcommands/test_apps/generating_app/bin/generatingcsc.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain diff --git a/tests/searchcommands/test_apps/reporting_app/bin/reportingcsc.py b/tests/searchcommands/test_apps/reporting_app/bin/reportingcsc.py index 78be57758..477f5fe20 100644 --- a/tests/searchcommands/test_apps/reporting_app/bin/reportingcsc.py +++ b/tests/searchcommands/test_apps/reporting_app/bin/reportingcsc.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain diff --git a/tests/searchcommands/test_apps/streaming_app/bin/streamingcsc.py b/tests/searchcommands/test_apps/streaming_app/bin/streamingcsc.py index 348496cc6..8ee2c91eb 100644 --- a/tests/searchcommands/test_apps/streaming_app/bin/streamingcsc.py +++ b/tests/searchcommands/test_apps/streaming_app/bin/streamingcsc.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain diff --git a/tests/searchcommands/test_builtin_options.py b/tests/searchcommands/test_builtin_options.py index e5c2dd8dd..174baed07 100644 --- a/tests/searchcommands/test_builtin_options.py +++ b/tests/searchcommands/test_builtin_options.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -15,19 +15,15 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - -from splunklib.six.moves import cStringIO as StringIO -try: - from unittest2 import main, TestCase -except ImportError: - from unittest import main, TestCase import os import sys import logging +from unittest import main, TestCase import pytest +from io import StringIO + from splunklib.searchcommands import environment from splunklib.searchcommands.decorators import Configuration @@ -117,18 +113,18 @@ def test_logging_configuration(self): except ValueError: pass except BaseException as e: - self.fail('Expected ValueError, but {} was raised'.format(type(e))) + self.fail(f'Expected ValueError, but {type(e)} was raised') else: - self.fail('Expected ValueError, but logging_configuration={}'.format(command.logging_configuration)) + self.fail(f'Expected ValueError, but logging_configuration={command.logging_configuration}') try: command.logging_configuration = os.path.join(package_directory, 'non-existent.logging.conf') except ValueError: pass except BaseException as e: - self.fail('Expected ValueError, but {} was raised'.format(type(e))) + self.fail(f'Expected ValueError, but {type(e)} was raised') else: - self.fail('Expected ValueError, but logging_configuration={}'.format(command.logging_configuration)) + self.fail(f'Expected ValueError, but logging_configuration={command.logging_configuration}') def test_logging_level(self): @@ -146,7 +142,7 @@ def test_logging_level(self): self.assertEqual(warning, command.logging_level) for level in level_names(): - if type(level) is int: + if isinstance(level, int): command.logging_level = level level_name = logging.getLevelName(level) self.assertEqual(command.logging_level, warning if level_name == notset else level_name) @@ -171,9 +167,9 @@ def test_logging_level(self): except ValueError: pass except BaseException as e: - self.fail('Expected ValueError, but {} was raised'.format(type(e))) + self.fail(f'Expected ValueError, but {type(e)} was raised') else: - self.fail('Expected ValueError, but logging_level={}'.format(command.logging_level)) + self.fail(f'Expected ValueError, but logging_level={command.logging_level}') self.assertEqual(command.logging_level, current_value) @@ -211,13 +207,9 @@ def _test_boolean_option(self, option): except ValueError: pass except BaseException as error: - self.fail('Expected ValueError when setting {}={}, but {} was raised'.format( - option.name, repr(value), type(error))) + self.fail(f'Expected ValueError when setting {option.name}={repr(value)}, but {type(error)} was raised') else: - self.fail('Expected ValueError, but {}={} was accepted.'.format( - option.name, repr(option.fget(command)))) - - return + self.fail(f'Expected ValueError, but {option.name}={repr(option.fget(command))} was accepted.') if __name__ == "__main__": diff --git a/tests/searchcommands/test_configuration_settings.py b/tests/searchcommands/test_configuration_settings.py index dd07b57fc..171a36166 100644 --- a/tests/searchcommands/test_configuration_settings.py +++ b/tests/searchcommands/test_configuration_settings.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -24,20 +24,19 @@ # * If a value is not set in code, the value specified in commands.conf is enforced # * If a value is set in code, it overrides the value specified in commands.conf -from __future__ import absolute_import, division, print_function, unicode_literals -from splunklib.searchcommands.decorators import Configuration from unittest import main, TestCase -from splunklib import six - import pytest +from splunklib.searchcommands.decorators import Configuration + + @pytest.mark.smoke class TestConfigurationSettings(TestCase): def test_generating_command(self): - from splunklib.searchcommands import Configuration, GeneratingCommand + from splunklib.searchcommands import GeneratingCommand @Configuration() class TestCommand(GeneratingCommand): @@ -48,7 +47,7 @@ def generate(self): command._protocol_version = 1 self.assertTrue( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('generating', True)]) self.assertIs(command.configuration.generates_timeorder, None) @@ -66,12 +65,12 @@ def generate(self): except AttributeError: pass except Exception as error: - self.fail('Expected AttributeError, not {}: {}'.format(type(error).__name__, error)) + self.fail(f'Expected AttributeError, not {type(error).__name__}: {error}') else: self.fail('Expected AttributeError') self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('generates_timeorder', True), ('generating', True), ('local', True), ('retainsevents', True), ('streaming', True)]) @@ -79,7 +78,7 @@ def generate(self): command._protocol_version = 2 self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('generating', True), ('type', 'stateful')]) self.assertIs(command.configuration.distributed, False) @@ -93,19 +92,17 @@ def generate(self): except AttributeError: pass except Exception as error: - self.fail('Expected AttributeError, not {}: {}'.format(type(error).__name__, error)) + self.fail(f'Expected AttributeError, not {type(error).__name__}: {error}') else: self.fail('Expected AttributeError') self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('generating', True), ('type', 'streaming')]) - return - def test_streaming_command(self): - from splunklib.searchcommands import Configuration, StreamingCommand + from splunklib.searchcommands import StreamingCommand @Configuration() class TestCommand(StreamingCommand): @@ -117,7 +114,7 @@ def stream(self, records): command._protocol_version = 1 self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('streaming', True)]) self.assertIs(command.configuration.clear_required_fields, None) @@ -136,19 +133,20 @@ def stream(self, records): except AttributeError: pass except Exception as error: - self.fail('Expected AttributeError, not {}: {}'.format(type(error).__name__, error)) + self.fail(f'Expected AttributeError, not {type(error).__name__}: {error}') else: self.fail('Expected AttributeError') self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], - [('clear_required_fields', True), ('local', True), ('overrides_timeorder', True), ('required_fields', ['field_1', 'field_2', 'field_3']), ('streaming', True)]) + list(command.configuration.items()), + [('clear_required_fields', True), ('local', True), ('overrides_timeorder', True), + ('required_fields', ['field_1', 'field_2', 'field_3']), ('streaming', True)]) command = TestCommand() command._protocol_version = 2 self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('type', 'streaming')]) self.assertIs(command.configuration.distributed, True) @@ -162,15 +160,14 @@ def stream(self, records): except AttributeError: pass except Exception as error: - self.fail('Expected AttributeError, not {}: {}'.format(type(error).__name__, error)) + self.fail(f'Expected AttributeError, not {type(error).__name__}: {error}') else: self.fail('Expected AttributeError') self.assertEqual( - [(name, value) for name, value in six.iteritems(command.configuration)], + list(command.configuration.items()), [('required_fields', ['field_1', 'field_2', 'field_3']), ('type', 'stateful')]) - return if __name__ == "__main__": main() diff --git a/tests/searchcommands/test_csc_apps.py b/tests/searchcommands/test_csc_apps.py index b15574d1c..64b03dc3e 100755 --- a/tests/searchcommands/test_csc_apps.py +++ b/tests/searchcommands/test_csc_apps.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -15,10 +15,12 @@ # under the License. import unittest +import pytest + from tests import testlib from splunklib import results - +@pytest.mark.smoke class TestCSC(testlib.SDKTestCase): def test_eventing_app(self): diff --git a/tests/searchcommands/test_decorators.py b/tests/searchcommands/test_decorators.py index dd65aa0ab..3cc571dd9 100755 --- a/tests/searchcommands/test_decorators.py +++ b/tests/searchcommands/test_decorators.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -15,35 +15,23 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals -try: - from unittest2 import main, TestCase -except ImportError: - from unittest import main, TestCase +from unittest import main, TestCase import sys from io import TextIOWrapper +import pytest from splunklib.searchcommands import Configuration, Option, environment, validators from splunklib.searchcommands.decorators import ConfigurationSetting from splunklib.searchcommands.internals import json_encode_string from splunklib.searchcommands.search_command import SearchCommand -try: - from tests.searchcommands import rebase_environment -except ImportError: - # Skip on Python 2.6 - pass - -from splunklib import six - -import pytest +from tests.searchcommands import rebase_environment @Configuration() class TestSearchCommand(SearchCommand): - boolean = Option( doc=''' **Syntax:** **boolean=**** @@ -121,7 +109,7 @@ class TestSearchCommand(SearchCommand): **Syntax:** **integer=**** **Description:** An integer value''', require=True, validate=validators.Integer()) - + float = Option( doc=''' **Syntax:** **float=**** @@ -234,49 +222,48 @@ def fix_up(cls, command_class): return ConfiguredSearchCommand.ConfigurationSettings for name, values, error_values in ( - ('clear_required_fields', - (True, False), - (None, 'anything other than a bool')), - ('distributed', - (True, False), - (None, 'anything other than a bool')), - ('generates_timeorder', - (True, False), - (None, 'anything other than a bool')), - ('generating', - (True, False), - (None, 'anything other than a bool')), - ('maxinputs', - (0, 50000, sys.maxsize), - (None, -1, sys.maxsize + 1, 'anything other than an int')), - ('overrides_timeorder', - (True, False), - (None, 'anything other than a bool')), - ('required_fields', - (['field_1', 'field_2'], set(['field_1', 'field_2']), ('field_1', 'field_2')), - (None, 0xdead, {'foo': 1, 'bar': 2})), - ('requires_preop', - (True, False), - (None, 'anything other than a bool')), - ('retainsevents', - (True, False), - (None, 'anything other than a bool')), - ('run_in_preview', - (True, False), - (None, 'anything other than a bool')), - ('streaming', - (True, False), - (None, 'anything other than a bool')), - ('streaming_preop', - (u'some unicode string', b'some byte string'), - (None, 0xdead)), - ('type', - # TODO: Do we need to validate byte versions of these strings? - ('events', 'reporting', 'streaming'), - ('eventing', 0xdead))): + ('clear_required_fields', + (True, False), + (None, 'anything other than a bool')), + ('distributed', + (True, False), + (None, 'anything other than a bool')), + ('generates_timeorder', + (True, False), + (None, 'anything other than a bool')), + ('generating', + (True, False), + (None, 'anything other than a bool')), + ('maxinputs', + (0, 50000, sys.maxsize), + (None, -1, sys.maxsize + 1, 'anything other than an int')), + ('overrides_timeorder', + (True, False), + (None, 'anything other than a bool')), + ('required_fields', + (['field_1', 'field_2'], set(['field_1', 'field_2']), ('field_1', 'field_2')), + (None, 0xdead, {'foo': 1, 'bar': 2})), + ('requires_preop', + (True, False), + (None, 'anything other than a bool')), + ('retainsevents', + (True, False), + (None, 'anything other than a bool')), + ('run_in_preview', + (True, False), + (None, 'anything other than a bool')), + ('streaming', + (True, False), + (None, 'anything other than a bool')), + ('streaming_preop', + ('some unicode string', b'some byte string'), + (None, 0xdead)), + ('type', + # TODO: Do we need to validate byte versions of these strings? + ('events', 'reporting', 'streaming'), + ('eventing', 0xdead))): for value in values: - settings_class = new_configuration_settings_class(name, value) # Setting property exists @@ -296,28 +283,26 @@ def fix_up(cls, command_class): setattr(settings_instance, name, value) - self.assertIn(backing_field_name, settings_instance.__dict__), + self.assertIn(backing_field_name, settings_instance.__dict__) self.assertEqual(getattr(settings_instance, name), value) self.assertEqual(settings_instance.__dict__[backing_field_name], value) - pass for value in error_values: try: new_configuration_settings_class(name, value) except Exception as error: - self.assertIsInstance(error, ValueError, 'Expected ValueError, not {}({}) for {}={}'.format(type(error).__name__, error, name, repr(value))) + self.assertIsInstance(error, ValueError, + f'Expected ValueError, not {type(error).__name__}({error}) for {name}={repr(value)}') else: - self.fail('Expected ValueError, not success for {}={}'.format(name, repr(value))) + self.fail(f'Expected ValueError, not success for {name}={repr(value)}') settings_class = new_configuration_settings_class() settings_instance = settings_class(command=None) self.assertRaises(ValueError, setattr, settings_instance, name, value) - return - def test_new_configuration_setting(self): - class Test(object): + class Test: generating = ConfigurationSetting() @ConfigurationSetting(name='required_fields') @@ -366,13 +351,11 @@ def test_option(self): command = TestSearchCommand() options = command.options - itervalues = lambda: six.itervalues(options) - options.reset() missing = options.get_missing() - self.assertListEqual(missing, [option.name for option in itervalues() if option.is_required]) - self.assertListEqual(presets, [six.text_type(option) for option in itervalues() if option.value is not None]) - self.assertListEqual(presets, [six.text_type(option) for option in itervalues() if six.text_type(option) != option.name + '=None']) + self.assertListEqual(missing, [option.name for option in options.values() if option.is_required]) + self.assertListEqual(presets, [str(option) for option in options.values() if option.value is not None]) + self.assertListEqual(presets, [str(option) for option in options.values() if str(option) != option.name + '=None']) test_option_values = { validators.Boolean: ('0', 'non-boolean value'), @@ -389,7 +372,7 @@ def test_option(self): validators.RegularExpression: ('\\s+', '(poorly formed regular expression'), validators.Set: ('bar', 'non-existent set entry')} - for option in itervalues(): + for option in options.values(): validator = option.validator if validator is None: @@ -401,57 +384,57 @@ def test_option(self): self.assertEqual( validator.format(option.value), validator.format(validator.__call__(legal_value)), - "{}={}".format(option.name, legal_value)) + f"{option.name}={legal_value}") try: option.value = illegal_value except ValueError: pass except BaseException as error: - self.assertFalse('Expected ValueError for {}={}, not this {}: {}'.format( - option.name, illegal_value, type(error).__name__, error)) + self.assertFalse( + f'Expected ValueError for {option.name}={illegal_value}, not this {type(error).__name__}: {error}') else: - self.assertFalse('Expected ValueError for {}={}, not a pass.'.format(option.name, illegal_value)) + self.assertFalse(f'Expected ValueError for {option.name}={illegal_value}, not a pass.') expected = { - u'foo': False, + 'foo': False, 'boolean': False, - 'code': u'foo == \"bar\"', + 'code': 'foo == \"bar\"', 'duration': 89999, - 'fieldname': u'some.field_name', - 'file': six.text_type(repr(__file__)), + 'fieldname': 'some.field_name', + 'file': str(repr(__file__)), 'integer': 100, 'float': 99.9, 'logging_configuration': environment.logging_configuration, - 'logging_level': u'WARNING', + 'logging_level': 'WARNING', 'map': 'foo', - 'match': u'123-45-6789', - 'optionname': u'some_option_name', + 'match': '123-45-6789', + 'optionname': 'some_option_name', 'record': False, - 'regularexpression': u'\\s+', + 'regularexpression': '\\s+', 'required_boolean': False, - 'required_code': u'foo == \"bar\"', + 'required_code': 'foo == \"bar\"', 'required_duration': 89999, - 'required_fieldname': u'some.field_name', - 'required_file': six.text_type(repr(__file__)), + 'required_fieldname': 'some.field_name', + 'required_file': str(repr(__file__)), 'required_integer': 100, 'required_float': 99.9, 'required_map': 'foo', - 'required_match': u'123-45-6789', - 'required_optionname': u'some_option_name', - 'required_regularexpression': u'\\s+', - 'required_set': u'bar', - 'set': u'bar', + 'required_match': '123-45-6789', + 'required_optionname': 'some_option_name', + 'required_regularexpression': '\\s+', + 'required_set': 'bar', + 'set': 'bar', 'show_configuration': False, } self.maxDiff = None tuplewrap = lambda x: x if isinstance(x, tuple) else (x,) - invert = lambda x: {v: k for k, v in six.iteritems(x)} + invert = lambda x: {v: k for k, v in x.items()} - for x in six.itervalues(command.options): - # isinstance doesn't work for some reason + for x in command.options.values(): + # isinstance doesn't work for some reason if type(x.value).__name__ == 'Code': self.assertEqual(expected[x.name], x.value.source) elif type(x.validator).__name__ == 'Map': @@ -459,8 +442,8 @@ def test_option(self): elif type(x.validator).__name__ == 'RegularExpression': self.assertEqual(expected[x.name], x.value.pattern) elif isinstance(x.value, TextIOWrapper): - self.assertEqual(expected[x.name], "'%s'" % x.value.name) - elif not isinstance(x.value, (bool,) + (float,) + (six.text_type,) + (six.binary_type,) + tuplewrap(six.integer_types)): + self.assertEqual(expected[x.name], f"'{x.value.name}'") + elif not isinstance(x.value, (bool,) + (float,) + (str,) + (bytes,) + tuplewrap(int)): self.assertEqual(expected[x.name], repr(x.value)) else: self.assertEqual(expected[x.name], x.value) @@ -474,10 +457,9 @@ def test_option(self): 'required_match="123-45-6789" required_optionname="some_option_name" required_regularexpression="\\\\s+" ' 'required_set="bar" set="bar" show_configuration="f"') - observed = six.text_type(command.options) + observed = str(command.options) self.assertEqual(observed, expected) - return if __name__ == "__main__": diff --git a/tests/searchcommands/test_generator_command.py b/tests/searchcommands/test_generator_command.py index 63ae3ac83..af103977a 100644 --- a/tests/searchcommands/test_generator_command.py +++ b/tests/searchcommands/test_generator_command.py @@ -1,9 +1,8 @@ import io import time -from . import chunked_data_stream as chunky - from splunklib.searchcommands import Configuration, GeneratingCommand +from . import chunked_data_stream as chunky def test_simple_generator(): @@ -12,6 +11,7 @@ class GeneratorTest(GeneratingCommand): def generate(self): for num in range(1, 10): yield {'_time': time.time(), 'event_index': num} + generator = GeneratorTest() in_stream = io.BytesIO() in_stream.write(chunky.build_getinfo_chunk()) @@ -24,7 +24,7 @@ def generate(self): ds = chunky.ChunkedDataStream(out_stream) is_first_chunk = True finished_seen = False - expected = set(map(lambda i: str(i), range(1, 10))) + expected = set(str(i) for i in range(1, 10)) seen = set() for chunk in ds: if is_first_chunk: @@ -40,15 +40,18 @@ def generate(self): assert expected.issubset(seen) assert finished_seen + def test_allow_empty_input_for_generating_command(): """ Passing allow_empty_input for generating command will cause an error """ + @Configuration() class GeneratorTest(GeneratingCommand): def generate(self): for num in range(1, 3): yield {"_index": num} + generator = GeneratorTest() in_stream = io.BytesIO() out_stream = io.BytesIO() @@ -58,6 +61,7 @@ def generate(self): except ValueError as error: assert str(error) == "allow_empty_input cannot be False for Generating Commands" + def test_all_fieldnames_present_for_generated_records(): @Configuration() class GeneratorTest(GeneratingCommand): diff --git a/tests/searchcommands/test_internals_v1.py b/tests/searchcommands/test_internals_v1.py index eb85d040a..6e41844ff 100755 --- a/tests/searchcommands/test_internals_v1.py +++ b/tests/searchcommands/test_internals_v1.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,7 +14,12 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals +from contextlib import closing +from unittest import main, TestCase +import os +from io import StringIO, BytesIO +from functools import reduce +import pytest from splunklib.searchcommands.internals import CommandLineParser, InputHeader, RecordWriterV1 from splunklib.searchcommands.decorators import Configuration, Option @@ -22,19 +27,6 @@ from splunklib.searchcommands.search_command import SearchCommand -from contextlib import closing -from splunklib.six import StringIO, BytesIO - -from splunklib.six.moves import zip as izip - -from unittest import main, TestCase - -import os -from splunklib import six -from splunklib.six.moves import range -from functools import reduce - -import pytest @pytest.mark.smoke class TestInternals(TestCase): @@ -61,7 +53,7 @@ def fix_up(cls, command_class): pass command = TestCommandLineParserCommand() CommandLineParser.parse(command, options) - for option in six.itervalues(command.options): + for option in command.options.values(): if option.name in ['logging_configuration', 'logging_level', 'record', 'show_configuration']: self.assertFalse(option.is_set) continue @@ -78,7 +70,7 @@ def fix_up(cls, command_class): pass command = TestCommandLineParserCommand() CommandLineParser.parse(command, options + fieldnames) - for option in six.itervalues(command.options): + for option in command.options.values(): if option.name in ['logging_configuration', 'logging_level', 'record', 'show_configuration']: self.assertFalse(option.is_set) continue @@ -93,8 +85,9 @@ def fix_up(cls, command_class): pass command = TestCommandLineParserCommand() CommandLineParser.parse(command, ['required_option=true'] + fieldnames) - for option in six.itervalues(command.options): - if option.name in ['unnecessary_option', 'logging_configuration', 'logging_level', 'record', 'show_configuration']: + for option in command.options.values(): + if option.name in ['unnecessary_option', 'logging_configuration', 'logging_level', 'record', + 'show_configuration']: self.assertFalse(option.is_set) continue self.assertTrue(option.is_set) @@ -112,7 +105,8 @@ def fix_up(cls, command_class): pass # Command line with unrecognized options - self.assertRaises(ValueError, CommandLineParser.parse, command, ['unrecognized_option_1=foo', 'unrecognized_option_2=bar']) + self.assertRaises(ValueError, CommandLineParser.parse, command, + ['unrecognized_option_1=foo', 'unrecognized_option_2=bar']) # Command line with a variety of quoted/escaped text options @@ -145,19 +139,19 @@ def fix_up(cls, command_class): pass r'"Hello World!"' ] - for string, expected_value in izip(strings, expected_values): + for string, expected_value in zip(strings, expected_values): command = TestCommandLineParserCommand() argv = ['text', '=', string] CommandLineParser.parse(command, argv) self.assertEqual(command.text, expected_value) - for string, expected_value in izip(strings, expected_values): + for string, expected_value in zip(strings, expected_values): command = TestCommandLineParserCommand() argv = [string] CommandLineParser.parse(command, argv) self.assertEqual(command.fieldnames[0], expected_value) - for string, expected_value in izip(strings, expected_values): + for string, expected_value in zip(strings, expected_values): command = TestCommandLineParserCommand() argv = ['text', '=', string] + strings CommandLineParser.parse(command, argv) @@ -176,25 +170,23 @@ def fix_up(cls, command_class): pass argv = [string] self.assertRaises(SyntaxError, CommandLineParser.parse, command, argv) - return - def test_command_line_parser_unquote(self): parser = CommandLineParser options = [ - r'foo', # unquoted string with no escaped characters - r'fo\o\ b\"a\\r', # unquoted string with some escaped characters - r'"foo"', # quoted string with no special characters - r'"""foobar1"""', # quoted string with quotes escaped like this: "" - r'"\"foobar2\""', # quoted string with quotes escaped like this: \" - r'"foo ""x"" bar"', # quoted string with quotes escaped like this: "" - r'"foo \"x\" bar"', # quoted string with quotes escaped like this: \" - r'"\\foobar"', # quoted string with an escaped backslash - r'"foo \\ bar"', # quoted string with an escaped backslash - r'"foobar\\"', # quoted string with an escaped backslash - r'foo\\\bar', # quoted string with an escaped backslash and an escaped 'b' - r'""', # pair of quotes - r''] # empty string + r'foo', # unquoted string with no escaped characters + r'fo\o\ b\"a\\r', # unquoted string with some escaped characters + r'"foo"', # quoted string with no special characters + r'"""foobar1"""', # quoted string with quotes escaped like this: "" + r'"\"foobar2\""', # quoted string with quotes escaped like this: \" + r'"foo ""x"" bar"', # quoted string with quotes escaped like this: "" + r'"foo \"x\" bar"', # quoted string with quotes escaped like this: \" + r'"\\foobar"', # quoted string with an escaped backslash + r'"foo \\ bar"', # quoted string with an escaped backslash + r'"foobar\\"', # quoted string with an escaped backslash + r'foo\\\bar', # quoted string with an escaped backslash and an escaped 'b' + r'""', # pair of quotes + r''] # empty string expected = [ r'foo', @@ -288,7 +280,7 @@ def test_input_header(self): 'sentence': 'hello world!'} input_header = InputHeader() - text = reduce(lambda value, item: value + '{}:{}\n'.format(item[0], item[1]), six.iteritems(collection), '') + '\n' + text = reduce(lambda value, item: value + f'{item[0]}:{item[1]}\n', collection.items(), '') + '\n' with closing(StringIO(text)) as input_file: input_header.read(input_file) @@ -308,13 +300,10 @@ def test_input_header(self): self.assertEqual(sorted(input_header.keys()), sorted(collection.keys())) self.assertEqual(sorted(input_header.values()), sorted(collection.values())) - return - def test_messages_header(self): @Configuration() class TestMessagesHeaderCommand(SearchCommand): - class ConfigurationSettings(SearchCommand.ConfigurationSettings): @classmethod @@ -346,7 +335,6 @@ def fix_up(cls, command_class): pass '\r\n') self.assertEqual(output_buffer.getvalue().decode('utf-8'), expected) - return _package_path = os.path.dirname(__file__) diff --git a/tests/searchcommands/test_internals_v2.py b/tests/searchcommands/test_internals_v2.py index ec9b3f666..722aaae24 100755 --- a/tests/searchcommands/test_internals_v2.py +++ b/tests/searchcommands/test_internals_v2.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 # -# Copyright © 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -15,46 +15,37 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals +import gzip +import io +import json +import os +import random +import sys -from splunklib.searchcommands.internals import MetadataDecoder, MetadataEncoder, Recorder, RecordWriterV2 -from splunklib.searchcommands import SearchMetric -from splunklib import six -from splunklib.six.moves import range -from collections import OrderedDict -from collections import namedtuple, deque -from splunklib.six import BytesIO as BytesIO +import pytest from functools import wraps -from glob import iglob from itertools import chain -from splunklib.six.moves import filter as ifilter -from splunklib.six.moves import map as imap -from splunklib.six.moves import zip as izip from sys import float_info from tempfile import mktemp from time import time from types import MethodType -from sys import version_info as python_version -try: - from unittest2 import main, TestCase -except ImportError: - from unittest import main, TestCase +from unittest import main, TestCase -import splunklib.six.moves.cPickle as pickle -import gzip -import io -import json -import os -import random +from collections import OrderedDict +from collections import namedtuple, deque + +from splunklib.searchcommands.internals import MetadataDecoder, MetadataEncoder, Recorder, RecordWriterV2 +from splunklib.searchcommands import SearchMetric +from io import BytesIO +import pickle -import pytest # region Functions for producing random apps # Confirmed: [minint, maxint) covers the full range of values that xrange allows -minint = (-six.MAXSIZE - 1) // 2 -maxint = six.MAXSIZE // 2 +minint = (-sys.maxsize - 1) // 2 +maxint = sys.maxsize // 2 max_length = 1 * 1024 @@ -90,7 +81,7 @@ def random_integer(): def random_integers(): - return random_list(six.moves.range, minint, maxint) + return random_list(range, minint, maxint) def random_list(population, *args): @@ -98,7 +89,7 @@ def random_list(population, *args): def random_unicode(): - return ''.join(imap(lambda x: six.unichr(x), random.sample(range(MAX_NARROW_UNICODE), random.randint(0, max_length)))) + return ''.join([str(x) for x in random.sample(list(range(MAX_NARROW_UNICODE)), random.randint(0, max_length))]) # endregion @@ -118,7 +109,6 @@ def test_object_view(self): json_output = encoder.encode(view) self.assertEqual(self._json_input, json_output) - return def test_record_writer_with_random_data(self, save_recording=False): @@ -138,7 +128,7 @@ def test_record_writer_with_random_data(self, save_recording=False): for serial_number in range(0, 31): values = [serial_number, time(), random_bytes(), random_dict(), random_integers(), random_unicode()] - record = OrderedDict(izip(fieldnames, values)) + record = OrderedDict(list(zip(fieldnames, values))) #try: write_record(record) #except Exception as error: @@ -167,7 +157,7 @@ def test_record_writer_with_random_data(self, save_recording=False): test_data['metrics'] = metrics - for name, metric in six.iteritems(metrics): + for name, metric in metrics.items(): writer.write_metric(name, metric) self.assertEqual(writer._chunk_count, 0) @@ -182,8 +172,8 @@ def test_record_writer_with_random_data(self, save_recording=False): self.assertListEqual(writer._inspector['messages'], messages) self.assertDictEqual( - dict(ifilter(lambda k_v: k_v[0].startswith('metric.'), six.iteritems(writer._inspector))), - dict(imap(lambda k_v1: ('metric.' + k_v1[0], k_v1[1]), six.iteritems(metrics)))) + dict(k_v for k_v in writer._inspector.items() if k_v[0].startswith('metric.')), + dict(('metric.' + k_v1[0], k_v1[1]) for k_v1 in metrics.items())) writer.flush(finished=True) @@ -213,18 +203,15 @@ def test_record_writer_with_random_data(self, save_recording=False): # P2 [ ] TODO: Verify that RecordWriter gives consumers the ability to finish early by calling # RecordWriter.flush(finish=True). - return - def _compare_chunks(self, chunks_1, chunks_2): self.assertEqual(len(chunks_1), len(chunks_2)) n = 0 - for chunk_1, chunk_2 in izip(chunks_1, chunks_2): + for chunk_1, chunk_2 in zip(chunks_1, chunks_2): self.assertDictEqual( chunk_1.metadata, chunk_2.metadata, 'Chunk {0}: metadata error: "{1}" != "{2}"'.format(n, chunk_1.metadata, chunk_2.metadata)) self.assertMultiLineEqual(chunk_1.body, chunk_2.body, 'Chunk {0}: data error'.format(n)) n += 1 - return def _load_chunks(self, ifile): import re @@ -276,11 +263,11 @@ def _load_chunks(self, ifile): 'n': 12 } - _json_input = six.text_type(json.dumps(_dictionary, separators=(',', ':'))) + _json_input = str(json.dumps(_dictionary, separators=(',', ':'))) _package_path = os.path.dirname(os.path.abspath(__file__)) -class TestRecorder(object): +class TestRecorder: def __init__(self, test_case): @@ -293,7 +280,6 @@ def _not_implemented(self): raise NotImplementedError('class {} is not in playback or record mode'.format(self.__class__.__name__)) self.get = self.next_part = self.stop = MethodType(_not_implemented, self, self.__class__) - return @property def output(self): @@ -322,7 +308,6 @@ def stop(self): self._test_case.assertEqual(test_data['results'], self._output.getvalue()) self.stop = MethodType(stop, self, self.__class__) - return def record(self, path): @@ -357,7 +342,6 @@ def stop(self): pickle.dump(test, f) self.stop = MethodType(stop, self, self.__class__) - return def recorded(method): @@ -369,12 +353,11 @@ def _record(*args, **kwargs): return _record -class Test(object): +class Test: def __init__(self, fieldnames, data_generators): TestCase.__init__(self) - self._data_generators = list(chain((lambda: self._serial_number, time), data_generators)) self._fieldnames = list(chain(('_serial', '_time'), fieldnames)) self._recorder = TestRecorder(self) @@ -418,15 +401,14 @@ def _run(self): names = self.fieldnames for self._serial_number in range(0, 31): - record = OrderedDict(izip(names, self.row)) + record = OrderedDict(list(zip(names, self.row))) write_record(record) - return - # test = Test(['random_bytes', 'random_unicode'], [random_bytes, random_unicode]) # test.record() # test.playback() + if __name__ == "__main__": main() diff --git a/tests/searchcommands/test_multibyte_processing.py b/tests/searchcommands/test_multibyte_processing.py index 4d6127fe9..1d021eed7 100644 --- a/tests/searchcommands/test_multibyte_processing.py +++ b/tests/searchcommands/test_multibyte_processing.py @@ -4,7 +4,6 @@ from os import path -from splunklib import six from splunklib.searchcommands import StreamingCommand, Configuration @@ -25,15 +24,13 @@ def get_input_file(name): def test_multibyte_chunked(): data = gzip.open(get_input_file("multibyte_input")) - if not six.PY2: - data = io.TextIOWrapper(data) + data = io.TextIOWrapper(data) cmd = build_test_command() cmd._process_protocol_v2(sys.argv, data, sys.stdout) def test_v1_searchcommand(): data = gzip.open(get_input_file("v1_search_input")) - if not six.PY2: - data = io.TextIOWrapper(data) + data = io.TextIOWrapper(data) cmd = build_test_command() cmd._process_protocol_v1(["test_script.py", "__EXECUTE__"], data, sys.stdout) diff --git a/tests/searchcommands/test_reporting_command.py b/tests/searchcommands/test_reporting_command.py index e5add818c..2111447d5 100644 --- a/tests/searchcommands/test_reporting_command.py +++ b/tests/searchcommands/test_reporting_command.py @@ -1,6 +1,6 @@ import io -import splunklib.searchcommands as searchcommands +from splunklib import searchcommands from . import chunked_data_stream as chunky @@ -15,7 +15,7 @@ def reduce(self, records): cmd = TestReportingCommand() ifile = io.BytesIO() - data = list() + data = [] for i in range(0, 10): data.append({"value": str(i)}) ifile.write(chunky.build_getinfo_chunk()) diff --git a/tests/searchcommands/test_search_command.py b/tests/searchcommands/test_search_command.py index c5ce36066..c8fe7d804 100755 --- a/tests/searchcommands/test_search_command.py +++ b/tests/searchcommands/test_search_command.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -15,16 +15,6 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - -from splunklib import six -from splunklib.searchcommands import Configuration, StreamingCommand -from splunklib.searchcommands.decorators import ConfigurationSetting, Option -from splunklib.searchcommands.search_command import SearchCommand -from splunklib.client import Service - -from splunklib.six import StringIO, BytesIO -from splunklib.six.moves import zip as izip from json.encoder import encode_basestring as encode_string from unittest import main, TestCase @@ -37,30 +27,36 @@ import pytest +import splunklib +from splunklib.searchcommands import Configuration, StreamingCommand +from splunklib.searchcommands.decorators import ConfigurationSetting, Option +from splunklib.searchcommands.search_command import SearchCommand +from splunklib.client import Service +from splunklib.utils import ensure_binary + +from io import StringIO, BytesIO + + def build_command_input(getinfo_metadata, execute_metadata, execute_body): - input = ('chunked 1.0,{},0\n{}'.format(len(six.ensure_binary(getinfo_metadata)), getinfo_metadata) + - 'chunked 1.0,{},{}\n{}{}'.format(len(six.ensure_binary(execute_metadata)), len(six.ensure_binary(execute_body)), execute_metadata, execute_body)) + input = (f'chunked 1.0,{len(ensure_binary(getinfo_metadata))},0\n{getinfo_metadata}' + + f'chunked 1.0,{len(ensure_binary(execute_metadata))},{len(ensure_binary(execute_body))}\n{execute_metadata}{execute_body}') - ifile = BytesIO(six.ensure_binary(input)) + ifile = BytesIO(ensure_binary(input)) - if not six.PY2: - ifile = TextIOWrapper(ifile) + ifile = TextIOWrapper(ifile) return ifile + @Configuration() class TestCommand(SearchCommand): - required_option_1 = Option(require=True) required_option_2 = Option(require=True) def echo(self, records): for record in records: if record.get('action') == 'raise_exception': - if six.PY2: - raise StandardError(self) - else: - raise Exception(self) + raise Exception(self) yield record def _execute(self, ifile, process): @@ -108,7 +104,7 @@ def stream(self, records): value = self.search_results_info if action == 'get_search_results_info' else None yield {'_serial': serial_number, 'data': value} serial_number += 1 - return + @pytest.mark.smoke class TestSearchCommand(TestCase): @@ -123,37 +119,37 @@ def test_process_scpv2(self): metadata = ( '{{' - '"action": "getinfo", "preview": false, "searchinfo": {{' - '"latest_time": "0",' - '"splunk_version": "20150522",' - '"username": "admin",' - '"app": "searchcommands_app",' - '"args": [' - '"logging_configuration={logging_configuration}",' - '"logging_level={logging_level}",' - '"record={record}",' - '"show_configuration={show_configuration}",' - '"required_option_1=value_1",' - '"required_option_2=value_2"' - '],' - '"search": "A%7C%20inputlookup%20tweets%20%7C%20countmatches%20fieldname%3Dword_count%20pattern%3D%22%5Cw%2B%22%20text%20record%3Dt%20%7C%20export%20add_timestamp%3Df%20add_offset%3Dt%20format%3Dcsv%20segmentation%3Draw",' - '"earliest_time": "0",' - '"session_key": "0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^",' - '"owner": "admin",' - '"sid": "1433261372.158",' - '"splunkd_uri": "https://127.0.0.1:8089",' - '"dispatch_dir": {dispatch_dir},' - '"raw_args": [' - '"logging_configuration={logging_configuration}",' - '"logging_level={logging_level}",' - '"record={record}",' - '"show_configuration={show_configuration}",' - '"required_option_1=value_1",' - '"required_option_2=value_2"' - '],' - '"maxresultrows": 10,' - '"command": "countmatches"' - '}}' + '"action": "getinfo", "preview": false, "searchinfo": {{' + '"latest_time": "0",' + '"splunk_version": "20150522",' + '"username": "admin",' + '"app": "searchcommands_app",' + '"args": [' + '"logging_configuration={logging_configuration}",' + '"logging_level={logging_level}",' + '"record={record}",' + '"show_configuration={show_configuration}",' + '"required_option_1=value_1",' + '"required_option_2=value_2"' + '],' + '"search": "A%7C%20inputlookup%20tweets%20%7C%20countmatches%20fieldname%3Dword_count%20pattern%3D%22%5Cw%2B%22%20text%20record%3Dt%20%7C%20export%20add_timestamp%3Df%20add_offset%3Dt%20format%3Dcsv%20segmentation%3Draw",' + '"earliest_time": "0",' + '"session_key": "0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^",' + '"owner": "admin",' + '"sid": "1433261372.158",' + '"splunkd_uri": "https://127.0.0.1:8089",' + '"dispatch_dir": {dispatch_dir},' + '"raw_args": [' + '"logging_configuration={logging_configuration}",' + '"logging_level={logging_level}",' + '"record={record}",' + '"show_configuration={show_configuration}",' + '"required_option_1=value_1",' + '"required_option_2=value_2"' + '],' + '"maxresultrows": 10,' + '"command": "countmatches"' + '}}' '}}') basedir = self._package_directory @@ -198,7 +194,7 @@ def test_process_scpv2(self): expected = ( 'chunked 1.0,68,0\n' - '{"inspector":{"messages":[["INFO","test command configuration: "]]}}\n' + '{"inspector":{"messages":[["INFO","test command configuration: "]]}}' 'chunked 1.0,17,32\n' '{"finished":true}test,__mv_test\r\n' 'data,\r\n' @@ -235,14 +231,18 @@ def test_process_scpv2(self): self.assertEqual(command_metadata.preview, input_header['preview']) self.assertEqual(command_metadata.searchinfo.app, 'searchcommands_app') - self.assertEqual(command_metadata.searchinfo.args, ['logging_configuration=' + logging_configuration, 'logging_level=ERROR', 'record=false', 'show_configuration=true', 'required_option_1=value_1', 'required_option_2=value_2']) + self.assertEqual(command_metadata.searchinfo.args, + ['logging_configuration=' + logging_configuration, 'logging_level=ERROR', 'record=false', + 'show_configuration=true', 'required_option_1=value_1', 'required_option_2=value_2']) self.assertEqual(command_metadata.searchinfo.dispatch_dir, os.path.dirname(input_header['infoPath'])) self.assertEqual(command_metadata.searchinfo.earliest_time, 0.0) self.assertEqual(command_metadata.searchinfo.latest_time, 0.0) self.assertEqual(command_metadata.searchinfo.owner, 'admin') self.assertEqual(command_metadata.searchinfo.raw_args, command_metadata.searchinfo.args) - self.assertEqual(command_metadata.searchinfo.search, 'A| inputlookup tweets | countmatches fieldname=word_count pattern="\\w+" text record=t | export add_timestamp=f add_offset=t format=csv segmentation=raw') - self.assertEqual(command_metadata.searchinfo.session_key, '0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^') + self.assertEqual(command_metadata.searchinfo.search, + 'A| inputlookup tweets | countmatches fieldname=word_count pattern="\\w+" text record=t | export add_timestamp=f add_offset=t format=csv segmentation=raw') + self.assertEqual(command_metadata.searchinfo.session_key, + '0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^') self.assertEqual(command_metadata.searchinfo.sid, '1433261372.158') self.assertEqual(command_metadata.searchinfo.splunk_version, '20150522') self.assertEqual(command_metadata.searchinfo.splunkd_uri, 'https://127.0.0.1:8089') diff --git a/tests/searchcommands/test_streaming_command.py b/tests/searchcommands/test_streaming_command.py index ffe6a7376..afb2e8caa 100644 --- a/tests/searchcommands/test_streaming_command.py +++ b/tests/searchcommands/test_streaming_command.py @@ -1,7 +1,7 @@ import io -from . import chunked_data_stream as chunky from splunklib.searchcommands import StreamingCommand, Configuration +from . import chunked_data_stream as chunky def test_simple_streaming_command(): @@ -16,7 +16,7 @@ def stream(self, records): cmd = TestStreamingCommand() ifile = io.BytesIO() ifile.write(chunky.build_getinfo_chunk()) - data = list() + data = [] for i in range(0, 10): data.append({"in_index": str(i)}) ifile.write(chunky.build_data_chunk(data, finished=True)) @@ -44,7 +44,7 @@ def stream(self, records): cmd = TestStreamingCommand() ifile = io.BytesIO() ifile.write(chunky.build_getinfo_chunk()) - data = list() + data = [] for i in range(0, 10): data.append({"in_index": str(i)}) ifile.write(chunky.build_data_chunk(data, finished=True)) @@ -53,14 +53,14 @@ def stream(self, records): cmd._process_protocol_v2([], ifile, ofile) ofile.seek(0) output_iter = chunky.ChunkedDataStream(ofile).__iter__() - output_iter.next() - output_records = [i for i in output_iter.next().data] + next(output_iter) + output_records = list(next(output_iter).data) # Assert that count of records having "odd_field" is 0 - assert len(list(filter(lambda r: "odd_field" in r, output_records))) == 0 + assert len(list(r for r in output_records if "odd_field" in r)) == 0 # Assert that count of records having "even_field" is 10 - assert len(list(filter(lambda r: "even_field" in r, output_records))) == 10 + assert len(list(r for r in output_records if "even_field" in r)) == 10 def test_field_preservation_positive(): @@ -78,7 +78,7 @@ def stream(self, records): cmd = TestStreamingCommand() ifile = io.BytesIO() ifile.write(chunky.build_getinfo_chunk()) - data = list() + data = [] for i in range(0, 10): data.append({"in_index": str(i)}) ifile.write(chunky.build_data_chunk(data, finished=True)) @@ -87,11 +87,11 @@ def stream(self, records): cmd._process_protocol_v2([], ifile, ofile) ofile.seek(0) output_iter = chunky.ChunkedDataStream(ofile).__iter__() - output_iter.next() - output_records = [i for i in output_iter.next().data] + next(output_iter) + output_records = list(next(output_iter).data) # Assert that count of records having "odd_field" is 10 - assert len(list(filter(lambda r: "odd_field" in r, output_records))) == 10 + assert len(list(r for r in output_records if "odd_field" in r)) == 10 # Assert that count of records having "even_field" is 10 - assert len(list(filter(lambda r: "even_field" in r, output_records))) == 10 + assert len(list(r for r in output_records if "even_field" in r)) == 10 diff --git a/tests/searchcommands/test_validators.py b/tests/searchcommands/test_validators.py index cc524b307..80149aa64 100755 --- a/tests/searchcommands/test_validators.py +++ b/tests/searchcommands/test_validators.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding=utf-8 # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -15,20 +15,15 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals - -from splunklib.searchcommands import validators from random import randint from unittest import main, TestCase import os -import re import sys import tempfile -from splunklib import six -from splunklib.six.moves import range - import pytest +from splunklib.searchcommands import validators + # P2 [ ] TODO: Verify that all format methods produce 'None' when value is None @@ -52,14 +47,12 @@ def test_boolean(self): for value in truth_values: for variant in value, value.capitalize(), value.upper(): - s = six.text_type(variant) + s = str(variant) self.assertEqual(validator.__call__(s), truth_values[value]) self.assertIsNone(validator.__call__(None)) self.assertRaises(ValueError, validator.__call__, 'anything-else') - return - def test_duration(self): # Duration validator should parse and format time intervals of the form @@ -68,7 +61,7 @@ def test_duration(self): validator = validators.Duration() for seconds in range(0, 25 * 60 * 60, 59): - value = six.text_type(seconds) + value = str(seconds) self.assertEqual(validator(value), seconds) self.assertEqual(validator(validator.format(seconds)), seconds) value = '%d:%02d' % (seconds / 60, seconds % 60) @@ -97,8 +90,6 @@ def test_duration(self): self.assertRaises(ValueError, validator, '00:00:60') self.assertRaises(ValueError, validator, '00:60:00') - return - def test_fieldname(self): pass @@ -140,8 +131,6 @@ def test_file(self): if os.path.exists(full_path): os.unlink(full_path) - return - def test_integer(self): # Point of interest: @@ -165,14 +154,10 @@ def test_integer(self): validator = validators.Integer() def test(integer): - for s in str(integer), six.text_type(integer): - value = validator.__call__(s) - self.assertEqual(value, integer) - if six.PY2: - self.assertIsInstance(value, long) - else: - self.assertIsInstance(value, int) - self.assertEqual(validator.format(integer), six.text_type(integer)) + value = validator.__call__(integer) + self.assertEqual(value, integer) + self.assertIsInstance(value, int) + self.assertEqual(validator.format(integer), str(integer)) test(2 * minsize) test(minsize) @@ -204,8 +189,6 @@ def test(integer): self.assertRaises(ValueError, validator.__call__, minsize - 1) self.assertRaises(ValueError, validator.__call__, maxsize + 1) - return - def test_float(self): # Float validator test @@ -219,11 +202,11 @@ def test(float_val): float_val = float(float_val) except ValueError: assert False - for s in str(float_val), six.text_type(float_val): - value = validator.__call__(s) - self.assertAlmostEqual(value, float_val) - self.assertIsInstance(value, float) - self.assertEqual(validator.format(float_val), six.text_type(float_val)) + + value = validator.__call__(float_val) + self.assertAlmostEqual(value, float_val) + self.assertIsInstance(value, float) + self.assertEqual(validator.format(float_val), str(float_val)) test(2 * minsize) test(minsize) @@ -261,8 +244,6 @@ def test(float_val): self.assertRaises(ValueError, validator.__call__, minsize - 1) self.assertRaises(ValueError, validator.__call__, maxsize + 1) - return - def test_list(self): validator = validators.List() diff --git a/tests/test_all.py b/tests/test_all.py index 7789f8fd9..55b4d77f5 100755 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2023 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -16,15 +16,11 @@ """Runs all the Splunk SDK for Python unit tests.""" -from __future__ import absolute_import import os -try: - import unittest2 as unittest # We must be sure to get unittest2--not unittest--on Python 2.6 -except ImportError: - import unittest +import unittest os.chdir(os.path.dirname(os.path.abspath(__file__))) suite = unittest.defaultTestLoader.discover('.') if __name__ == '__main__': - unittest.TextTestRunner().run(suite) \ No newline at end of file + unittest.TextTestRunner().run(suite) diff --git a/tests/test_app.py b/tests/test_app.py index 3dbc4cffb..35be38146 100755 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,11 +14,9 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from tests import testlib import logging - -import splunklib.client as client +from tests import testlib +from splunklib import client class TestApp(testlib.SDKTestCase): @@ -26,7 +24,7 @@ class TestApp(testlib.SDKTestCase): app_name = None def setUp(self): - super(TestApp, self).setUp() + super().setUp() if self.app is None: for app in self.service.apps: if app.name.startswith('delete-me'): @@ -37,18 +35,18 @@ def setUp(self): # than entities like indexes, this is okay. self.app_name = testlib.tmpname() self.app = self.service.apps.create(self.app_name) - logging.debug("Creating app %s", self.app_name) + logging.debug(f"Creating app {self.app_name}") else: - logging.debug("App %s already exists. Skipping creation.", self.app_name) + logging.debug(f"App {self.app_name} already exists. Skipping creation.") if self.service.restart_required: self.service.restart(120) - return def tearDown(self): - super(TestApp, self).tearDown() + super().tearDown() # The rest of this will leave Splunk in a state requiring a restart. # It doesn't actually matter, though. self.service = client.connect(**self.opts.kwargs) + app_name = '' for app in self.service.apps: app_name = app.name if app_name.startswith('delete-me'): @@ -86,11 +84,11 @@ def test_update(self): def test_delete(self): name = testlib.tmpname() - app = self.service.apps.create(name) + self.service.apps.create(name) self.assertTrue(name in self.service.apps) self.service.apps.delete(name) self.assertFalse(name in self.service.apps) - self.clear_restart_message() # We don't actually have to restart here. + self.clear_restart_message() # We don't actually have to restart here. def test_package(self): p = self.app.package() @@ -103,9 +101,7 @@ def test_updateInfo(self): p = self.app.updateInfo() self.assertTrue(p is not None) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest unittest.main() diff --git a/tests/test_binding.py b/tests/test_binding.py index c101b19cb..fe14c259c 100755 --- a/tests/test_binding.py +++ b/tests/test_binding.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,15 +14,11 @@ # License for the specific language governing permissions and limitations # under the License. - -from __future__ import absolute_import -from io import BytesIO +from http import server as BaseHTTPServer +from io import BytesIO, StringIO from threading import Thread +from urllib.request import Request, urlopen -from splunklib.six.moves import BaseHTTPServer -from splunklib.six.moves.urllib.request import Request, urlopen -from splunklib.six.moves.urllib.error import HTTPError -import splunklib.six as six from xml.etree.ElementTree import XML import json @@ -30,14 +26,13 @@ from tests import testlib import unittest import socket -import sys import ssl -import splunklib.six.moves.http_cookies -import splunklib.binding as binding +import splunklib +from splunklib import binding from splunklib.binding import HTTPError, AuthenticationError, UrlEncoded -import splunklib.data as data -from splunklib import six +from splunklib import data +from splunklib.utils import ensure_str import pytest @@ -65,14 +60,17 @@ def load(response): return data.load(response.body.read()) + class BindingTestCase(unittest.TestCase): context = None + def setUp(self): logging.info("%s", self.__class__.__name__) self.opts = testlib.parse([], {}, ".env") self.context = binding.connect(**self.opts.kwargs) logging.debug("Connected to splunkd.") + class TestResponseReader(BindingTestCase): def test_empty(self): response = binding.ResponseReader(BytesIO(b"")) @@ -106,7 +104,7 @@ def test_read_partial(self): def test_readable(self): txt = "abcd" - response = binding.ResponseReader(six.StringIO(txt)) + response = binding.ResponseReader(StringIO(txt)) self.assertTrue(response.readable()) def test_readinto_bytearray(self): @@ -124,9 +122,6 @@ def test_readinto_bytearray(self): self.assertTrue(response.empty) def test_readinto_memoryview(self): - import sys - if sys.version_info < (2, 7, 0): - return # memoryview is new to Python 2.7 txt = b"Checking readinto works as expected" response = binding.ResponseReader(BytesIO(txt)) arr = bytearray(10) @@ -142,7 +137,6 @@ def test_readinto_memoryview(self): self.assertTrue(response.empty) - class TestUrlEncoded(BindingTestCase): def test_idempotent(self): a = UrlEncoded('abc') @@ -173,6 +167,7 @@ def test_chars(self): def test_repr(self): self.assertEqual(repr(UrlEncoded('% %')), "UrlEncoded('% %')") + class TestAuthority(unittest.TestCase): def test_authority_default(self): self.assertEqual(binding._authority(), @@ -190,6 +185,12 @@ def test_ipv6_host(self): host="2001:0db8:85a3:0000:0000:8a2e:0370:7334"), "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8089") + def test_ipv6_host_enclosed(self): + self.assertEqual( + binding._authority( + host="[2001:0db8:85a3:0000:0000:8a2e:0370:7334]"), + "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8089") + def test_all_fields(self): self.assertEqual( binding._authority( @@ -198,6 +199,7 @@ def test_all_fields(self): port="471"), "http://splunk.utopia.net:471") + class TestUserManipulation(BindingTestCase): def setUp(self): BindingTestCase.setUp(self) @@ -277,13 +279,10 @@ def test_post_with_get_arguments_to_receivers_stream(self): class TestSocket(BindingTestCase): def test_socket(self): socket = self.context.connect() - socket.write(("POST %s HTTP/1.1\r\n" % \ - self.context._abspath("some/path/to/post/to")).encode('utf-8')) - socket.write(("Host: %s:%s\r\n" % \ - (self.context.host, self.context.port)).encode('utf-8')) + socket.write((f"POST {self.context._abspath('some/path/to/post/to')} HTTP/1.1\r\n").encode('utf-8')) + socket.write((f"Host: {self.context.host}:{self.context.port}\r\n").encode('utf-8')) socket.write("Accept-Encoding: identity\r\n".encode('utf-8')) - socket.write(("Authorization: %s\r\n" % \ - self.context.token).encode('utf-8')) + socket.write((f"Authorization: {self.context.token}\r\n").encode('utf-8')) socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode('utf-8')) socket.write("\r\n".encode('utf-8')) socket.close() @@ -308,15 +307,17 @@ def test_socket_gethostbyname(self): self.context.host = socket.gethostbyname(self.context.host) self.assertTrue(self.context.connect()) + class TestUnicodeConnect(BindingTestCase): def test_unicode_connect(self): opts = self.opts.kwargs.copy() - opts['host'] = six.text_type(opts['host']) + opts['host'] = str(opts['host']) context = binding.connect(**opts) # Just check to make sure the service is alive response = context.get("/services") self.assertEqual(response.status, 200) + @pytest.mark.smoke class TestAutologin(BindingTestCase): def test_with_autologin(self): @@ -332,6 +333,7 @@ def test_without_autologin(self): self.assertRaises(AuthenticationError, self.context.get, "/services") + class TestAbspath(BindingTestCase): def setUp(self): BindingTestCase.setUp(self) @@ -339,7 +341,6 @@ def setUp(self): if 'app' in self.kwargs: del self.kwargs['app'] if 'owner' in self.kwargs: del self.kwargs['owner'] - def test_default(self): path = self.context._abspath("foo", owner=None, app=None) self.assertTrue(isinstance(path, UrlEncoded)) @@ -371,12 +372,12 @@ def test_sharing_app(self): self.assertEqual(path, "/servicesNS/nobody/MyApp/foo") def test_sharing_global(self): - path = self.context._abspath("foo", owner="me", app="MyApp",sharing="global") + path = self.context._abspath("foo", owner="me", app="MyApp", sharing="global") self.assertTrue(isinstance(path, UrlEncoded)) self.assertEqual(path, "/servicesNS/nobody/MyApp/foo") def test_sharing_system(self): - path = self.context._abspath("foo bar", owner="me", app="MyApp",sharing="system") + path = self.context._abspath("foo bar", owner="me", app="MyApp", sharing="system") self.assertTrue(isinstance(path, UrlEncoded)) self.assertEqual(path, "/servicesNS/nobody/system/foo%20bar") @@ -444,6 +445,7 @@ def test_context_with_owner_as_email(self): self.assertEqual(path, "/servicesNS/me%40me.com/system/foo") self.assertEqual(path, UrlEncoded("/servicesNS/me@me.com/system/foo")) + # An urllib2 based HTTP request handler, used to test the binding layers # support for pluggable request handlers. def urllib2_handler(url, message, **kwargs): @@ -452,13 +454,9 @@ def urllib2_handler(url, message, **kwargs): headers = dict(message.get('headers', [])) req = Request(url, data, headers) try: - # If running Python 2.7.9+, disable SSL certificate validation - if sys.version_info >= (2, 7, 9): - response = urlopen(req, context=ssl._create_unverified_context()) - else: - response = urlopen(req) + response = urlopen(req, context=ssl._create_unverified_context()) except HTTPError as response: - pass # Propagate HTTP errors via the returned response message + pass # Propagate HTTP errors via the returned response message return { 'status': response.code, 'reason': response.msg, @@ -466,6 +464,7 @@ def urllib2_handler(url, message, **kwargs): 'body': BytesIO(response.read()) } + def isatom(body): """Answers if the given response body looks like ATOM.""" root = XML(body) @@ -475,6 +474,7 @@ def isatom(body): root.find(XNAME_ID) is not None and \ root.find(XNAME_TITLE) is not None + class TestPluggableHTTP(testlib.SDKTestCase): # Verify pluggable HTTP reqeust handlers. def test_handlers(self): @@ -491,6 +491,55 @@ def test_handlers(self): body = context.get(path).body.read() self.assertTrue(isatom(body)) + +def urllib2_insert_cookie_handler(url, message, **kwargs): + method = message['method'].lower() + data = message.get('body', b"") if method == 'post' else None + headers = dict(message.get('headers', [])) + req = Request(url, data, headers) + try: + response = urlopen(req, context=ssl._create_unverified_context()) + except HTTPError as response: + pass # Propagate HTTP errors via the returned response message + + # Mimic the insertion of 3rd party cookies into the response. + # An example is "sticky session"/"insert cookie" persistence + # of a load balancer for a SHC. + header_list = list(response.info().items()) + header_list.append(("Set-Cookie", "BIGipServer_splunk-shc-8089=1234567890.12345.0000; path=/; Httponly; Secure")) + header_list.append(("Set-Cookie", "home_made=yummy")) + + return { + 'status': response.code, + 'reason': response.msg, + 'headers': header_list, + 'body': BytesIO(response.read()) + } + + +class TestCookiePersistence(testlib.SDKTestCase): + # Verify persistence of 3rd party inserted cookies. + def test_3rdPartyInsertedCookiePersistence(self): + paths = ["/services", "authentication/users", + "search/jobs"] + logging.debug("Connecting with urllib2_insert_cookie_handler %s", urllib2_insert_cookie_handler) + context = binding.connect( + handler=urllib2_insert_cookie_handler, + **self.opts.kwargs) + + persisted_cookies = context.get_cookies() + + splunk_token_found = False + for k, v in persisted_cookies.items(): + if k[:8] == "splunkd_": + splunk_token_found = True + break + + self.assertEqual(splunk_token_found, True) + self.assertEqual(persisted_cookies['BIGipServer_splunk-shc-8089'], "1234567890.12345.0000") + self.assertEqual(persisted_cookies['home_made'], "yummy") + + @pytest.mark.smoke class TestLogout(BindingTestCase): def test_logout(self): @@ -516,7 +565,7 @@ def setUp(self): self.context = binding.connect(**self.opts.kwargs) # Skip these tests if running below Splunk 6.2, cookie-auth didn't exist before - import splunklib.client as client + from splunklib import client service = client.Service(**self.opts.kwargs) # TODO: Workaround the fact that skipTest is not defined by unittest2.TestCase service.login() @@ -585,14 +634,16 @@ def test_got_updated_cookie_with_get(self): self.assertEqual(list(new_cookies.values())[0], list(old_cookies.values())[0]) self.assertTrue(found) + @pytest.mark.smoke def test_login_fails_with_bad_cookie(self): # We should get an error if using a bad cookie try: - new_context = binding.connect(**{"cookie": "bad=cookie"}) + binding.connect(**{"cookie": "bad=cookie"}) self.fail() except AuthenticationError as ae: self.assertEqual(str(ae), "Login failed.") + @pytest.mark.smoke def test_login_with_multiple_cookies(self): # We should get an error if using a bad cookie new_context = binding.Context() @@ -631,80 +682,81 @@ def test_login_fails_without_cookie_or_token(self): class TestNamespace(unittest.TestCase): def test_namespace(self): tests = [ - ({ }, - { 'sharing': None, 'owner': None, 'app': None }), + ({}, + {'sharing': None, 'owner': None, 'app': None}), - ({ 'owner': "Bob" }, - { 'sharing': None, 'owner': "Bob", 'app': None }), + ({'owner': "Bob"}, + {'sharing': None, 'owner': "Bob", 'app': None}), - ({ 'app': "search" }, - { 'sharing': None, 'owner': None, 'app': "search" }), + ({'app': "search"}, + {'sharing': None, 'owner': None, 'app': "search"}), - ({ 'owner': "Bob", 'app': "search" }, - { 'sharing': None, 'owner': "Bob", 'app': "search" }), + ({'owner': "Bob", 'app': "search"}, + {'sharing': None, 'owner': "Bob", 'app': "search"}), - ({ 'sharing': "user", 'owner': "Bob@bob.com" }, - { 'sharing': "user", 'owner': "Bob@bob.com", 'app': None }), + ({'sharing': "user", 'owner': "Bob@bob.com"}, + {'sharing': "user", 'owner': "Bob@bob.com", 'app': None}), - ({ 'sharing': "user" }, - { 'sharing': "user", 'owner': None, 'app': None }), + ({'sharing': "user"}, + {'sharing': "user", 'owner': None, 'app': None}), - ({ 'sharing': "user", 'owner': "Bob" }, - { 'sharing': "user", 'owner': "Bob", 'app': None }), + ({'sharing': "user", 'owner': "Bob"}, + {'sharing': "user", 'owner': "Bob", 'app': None}), - ({ 'sharing': "user", 'app': "search" }, - { 'sharing': "user", 'owner': None, 'app': "search" }), + ({'sharing': "user", 'app': "search"}, + {'sharing': "user", 'owner': None, 'app': "search"}), - ({ 'sharing': "user", 'owner': "Bob", 'app': "search" }, - { 'sharing': "user", 'owner': "Bob", 'app': "search" }), + ({'sharing': "user", 'owner': "Bob", 'app': "search"}, + {'sharing': "user", 'owner': "Bob", 'app': "search"}), - ({ 'sharing': "app" }, - { 'sharing': "app", 'owner': "nobody", 'app': None }), + ({'sharing': "app"}, + {'sharing': "app", 'owner': "nobody", 'app': None}), - ({ 'sharing': "app", 'owner': "Bob" }, - { 'sharing': "app", 'owner': "nobody", 'app': None }), + ({'sharing': "app", 'owner': "Bob"}, + {'sharing': "app", 'owner': "nobody", 'app': None}), - ({ 'sharing': "app", 'app': "search" }, - { 'sharing': "app", 'owner': "nobody", 'app': "search" }), + ({'sharing': "app", 'app': "search"}, + {'sharing': "app", 'owner': "nobody", 'app': "search"}), - ({ 'sharing': "app", 'owner': "Bob", 'app': "search" }, - { 'sharing': "app", 'owner': "nobody", 'app': "search" }), + ({'sharing': "app", 'owner': "Bob", 'app': "search"}, + {'sharing': "app", 'owner': "nobody", 'app': "search"}), - ({ 'sharing': "global" }, - { 'sharing': "global", 'owner': "nobody", 'app': None }), + ({'sharing': "global"}, + {'sharing': "global", 'owner': "nobody", 'app': None}), - ({ 'sharing': "global", 'owner': "Bob" }, - { 'sharing': "global", 'owner': "nobody", 'app': None }), + ({'sharing': "global", 'owner': "Bob"}, + {'sharing': "global", 'owner': "nobody", 'app': None}), - ({ 'sharing': "global", 'app': "search" }, - { 'sharing': "global", 'owner': "nobody", 'app': "search" }), + ({'sharing': "global", 'app': "search"}, + {'sharing': "global", 'owner': "nobody", 'app': "search"}), - ({ 'sharing': "global", 'owner': "Bob", 'app': "search" }, - { 'sharing': "global", 'owner': "nobody", 'app': "search" }), + ({'sharing': "global", 'owner': "Bob", 'app': "search"}, + {'sharing': "global", 'owner': "nobody", 'app': "search"}), - ({ 'sharing': "system" }, - { 'sharing': "system", 'owner': "nobody", 'app': "system" }), + ({'sharing': "system"}, + {'sharing': "system", 'owner': "nobody", 'app': "system"}), - ({ 'sharing': "system", 'owner': "Bob" }, - { 'sharing': "system", 'owner': "nobody", 'app': "system" }), + ({'sharing': "system", 'owner': "Bob"}, + {'sharing': "system", 'owner': "nobody", 'app': "system"}), - ({ 'sharing': "system", 'app': "search" }, - { 'sharing': "system", 'owner': "nobody", 'app': "system" }), + ({'sharing': "system", 'app': "search"}, + {'sharing': "system", 'owner': "nobody", 'app': "system"}), - ({ 'sharing': "system", 'owner': "Bob", 'app': "search" }, - { 'sharing': "system", 'owner': "nobody", 'app': "system" }), + ({'sharing': "system", 'owner': "Bob", 'app': "search"}, + {'sharing': "system", 'owner': "nobody", 'app': "system"}), - ({ 'sharing': 'user', 'owner': '-', 'app': '-'}, - { 'sharing': 'user', 'owner': '-', 'app': '-'})] + ({'sharing': 'user', 'owner': '-', 'app': '-'}, + {'sharing': 'user', 'owner': '-', 'app': '-'})] for kwargs, expected in tests: namespace = binding.namespace(**kwargs) - for k, v in six.iteritems(expected): + for k, v in expected.items(): self.assertEqual(namespace[k], v) def test_namespace_fails(self): self.assertRaises(ValueError, binding.namespace, sharing="gobble") + @pytest.mark.smoke class TestBasicAuthentication(unittest.TestCase): def setUp(self): @@ -715,13 +767,13 @@ def setUp(self): opts["password"] = self.opts.kwargs["password"] self.context = binding.connect(**opts) - import splunklib.client as client + from splunklib import client service = client.Service(**opts) if getattr(unittest.TestCase, 'assertIsNotNone', None) is None: def assertIsNotNone(self, obj, msg=None): - if obj is None: - raise self.failureException(msg or '%r is not None' % obj) + if obj is None: + raise self.failureException(msg or '%r is not None' % obj) def test_basic_in_auth_headers(self): self.assertIsNotNone(self.context._auth_headers) @@ -732,6 +784,7 @@ def test_basic_in_auth_headers(self): self.assertEqual(self.context._auth_headers[0][1][:6], "Basic ") self.assertEqual(self.context.get("/services").status, 200) + @pytest.mark.smoke class TestTokenAuthentication(BindingTestCase): def test_preexisting_token(self): @@ -746,15 +799,12 @@ def test_preexisting_token(self): self.assertEqual(response.status, 200) socket = newContext.connect() - socket.write(("POST %s HTTP/1.1\r\n" % \ - self.context._abspath("some/path/to/post/to")).encode('utf-8')) - socket.write(("Host: %s:%s\r\n" % \ - (self.context.host, self.context.port)).encode('utf-8')) - socket.write(("Accept-Encoding: identity\r\n").encode('utf-8')) - socket.write(("Authorization: %s\r\n" % \ - self.context.token).encode('utf-8')) + socket.write((f"POST {self.context._abspath('some/path/to/post/to')} HTTP/1.1\r\n").encode('utf-8')) + socket.write((f"Host: {self.context.host}:{self.context.port}\r\n").encode('utf-8')) + socket.write("Accept-Encoding: identity\r\n".encode('utf-8')) + socket.write((f"Authorization: {self.context.token}\r\n").encode('utf-8')) socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode('utf-8')) - socket.write(("\r\n").encode('utf-8')) + socket.write("\r\n".encode('utf-8')) socket.close() def test_preexisting_token_sans_splunk(self): @@ -774,18 +824,14 @@ def test_preexisting_token_sans_splunk(self): self.assertEqual(response.status, 200) socket = newContext.connect() - socket.write(("POST %s HTTP/1.1\r\n" %\ - self.context._abspath("some/path/to/post/to")).encode('utf-8')) - socket.write(("Host: %s:%s\r\n" %\ - (self.context.host, self.context.port)).encode('utf-8')) + socket.write((f"POST {self.context._abspath('some/path/to/post/to')} HTTP/1.1\r\n").encode('utf-8')) + socket.write((f"Host: {self.context.host}:{self.context.port}\r\n").encode('utf-8')) socket.write("Accept-Encoding: identity\r\n".encode('utf-8')) - socket.write(("Authorization: %s\r\n" %\ - self.context.token).encode('utf-8')) - socket.write(("X-Splunk-Input-Mode: Streaming\r\n").encode('utf-8')) - socket.write(("\r\n").encode('utf-8')) + socket.write((f"Authorization: {self.context.token}\r\n").encode('utf-8')) + socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode('utf-8')) + socket.write("\r\n".encode('utf-8')) socket.close() - def test_connect_with_preexisting_token_sans_user_and_pass(self): token = self.context.token opts = self.opts.kwargs.copy() @@ -798,13 +844,10 @@ def test_connect_with_preexisting_token_sans_user_and_pass(self): self.assertEqual(response.status, 200) socket = newContext.connect() - socket.write(("POST %s HTTP/1.1\r\n" % \ - self.context._abspath("some/path/to/post/to")).encode('utf-8')) - socket.write(("Host: %s:%s\r\n" % \ - (self.context.host, self.context.port)).encode('utf-8')) + socket.write((f"POST {self.context._abspath('some/path/to/post/to')} HTTP/1.1\r\n").encode('utf-8')) + socket.write((f"Host: {self.context.host}:{self.context.port}\r\n").encode('utf-8')) socket.write("Accept-Encoding: identity\r\n".encode('utf-8')) - socket.write(("Authorization: %s\r\n" % \ - self.context.token).encode('utf-8')) + socket.write((f"Authorization: {self.context.token}\r\n").encode('utf-8')) socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode('utf-8')) socket.write("\r\n".encode('utf-8')) socket.close() @@ -814,34 +857,37 @@ class TestPostWithBodyParam(unittest.TestCase): def test_post(self): def handler(url, message, **kwargs): - assert six.ensure_str(url) == "https://localhost:8089/servicesNS/testowner/testapp/foo/bar" - assert six.ensure_str(message["body"]) == "testkey=testvalue" + assert url == "https://localhost:8089/servicesNS/testowner/testapp/foo/bar" + assert message["body"] == b"testkey=testvalue" return splunklib.data.Record({ "status": 200, "headers": [], }) + ctx = binding.Context(handler=handler) ctx.post("foo/bar", owner="testowner", app="testapp", body={"testkey": "testvalue"}) def test_post_with_params_and_body(self): def handler(url, message, **kwargs): assert url == "https://localhost:8089/servicesNS/testowner/testapp/foo/bar?extrakey=extraval" - assert six.ensure_str(message["body"]) == "testkey=testvalue" + assert message["body"] == b"testkey=testvalue" return splunklib.data.Record({ "status": 200, "headers": [], }) + ctx = binding.Context(handler=handler) ctx.post("foo/bar", extrakey="extraval", owner="testowner", app="testapp", body={"testkey": "testvalue"}) def test_post_with_params_and_no_body(self): def handler(url, message, **kwargs): assert url == "https://localhost:8089/servicesNS/testowner/testapp/foo/bar" - assert six.ensure_str(message["body"]) == "extrakey=extraval" + assert message["body"] == b"extrakey=extraval" return splunklib.data.Record({ "status": 200, "headers": [], }) + ctx = binding.Context(handler=handler) ctx.post("foo/bar", extrakey="extraval", owner="testowner", app="testapp") @@ -853,10 +899,11 @@ def wrapped(handler_self): handler_self.send_response(response_code) handler_self.end_headers() handler_self.wfile.write(body) + return wrapped -class MockServer(object): +class MockServer: def __init__(self, port=9093, **handlers): methods = {"do_" + k: _wrap_handler(v) for (k, v) in handlers.items()} @@ -875,6 +922,7 @@ def log(*args): # To silence server access logs def run(): self._svr.handle_request() + self._thread = Thread(target=run) self._thread.daemon = True @@ -893,7 +941,7 @@ def test_post_with_body_urlencoded(self): def check_response(handler): length = int(handler.headers.get('content-length', 0)) body = handler.rfile.read(length) - assert six.ensure_str(body) == "foo=bar" + assert body.decode('utf-8') == "foo=bar" with MockServer(POST=check_response): ctx = binding.connect(port=9093, scheme='http', token="waffle") @@ -903,19 +951,20 @@ def test_post_with_body_string(self): def check_response(handler): length = int(handler.headers.get('content-length', 0)) body = handler.rfile.read(length) - assert six.ensure_str(handler.headers['content-type']) == 'application/json' + assert handler.headers['content-type'] == 'application/json' assert json.loads(body)["baz"] == "baf" with MockServer(POST=check_response): - ctx = binding.connect(port=9093, scheme='http', token="waffle", headers=[("Content-Type", "application/json")]) + ctx = binding.connect(port=9093, scheme='http', token="waffle", + headers=[("Content-Type", "application/json")]) ctx.post("/", foo="bar", body='{"baz": "baf"}') def test_post_with_body_dict(self): def check_response(handler): length = int(handler.headers.get('content-length', 0)) body = handler.rfile.read(length) - assert six.ensure_str(handler.headers['content-type']) == 'application/x-www-form-urlencoded' - assert six.ensure_str(body) == 'baz=baf&hep=cat' or six.ensure_str(body) == 'hep=cat&baz=baf' + assert handler.headers['content-type'] == 'application/x-www-form-urlencoded' + assert ensure_str(body) in ['baz=baf&hep=cat', 'hep=cat&baz=baf'] with MockServer(POST=check_response): ctx = binding.connect(port=9093, scheme='http', token="waffle") @@ -923,8 +972,4 @@ def check_response(handler): if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest unittest.main() diff --git a/tests/test_collection.py b/tests/test_collection.py index 0fd9a1c33..ec641a6d6 100755 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,14 +14,12 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib import logging from contextlib import contextmanager -import splunklib.client as client -from splunklib.six.moves import range +from splunklib import client collections = [ 'apps', @@ -41,9 +39,9 @@ class CollectionTestCase(testlib.SDKTestCase): def setUp(self): - super(CollectionTestCase, self).setUp() + super().setUp() if self.service.splunk_version[0] >= 5 and 'modular_input_kinds' not in collections: - collections.append('modular_input_kinds') # Not supported before Splunk 5.0 + collections.append('modular_input_kinds') # Not supported before Splunk 5.0 else: logging.info("Skipping modular_input_kinds; not supported by Splunk %s" % \ '.'.join(str(x) for x in self.service.splunk_version)) @@ -69,59 +67,51 @@ def test_metadata(self): found_fields_keys = set(metadata.fields.keys()) self.assertTrue(found_access_keys >= expected_access_keys, msg='metadata.access is missing keys on ' + \ - '%s (found: %s, expected: %s)' % \ - (coll, found_access_keys, - expected_access_keys)) + f'{coll} (found: {found_access_keys}, expected: {expected_access_keys})') self.assertTrue(found_fields_keys >= expected_fields_keys, msg='metadata.fields is missing keys on ' + \ - '%s (found: %s, expected: %s)' % \ - (coll, found_fields_keys, - expected_fields_keys)) + f'{coll} (found: {found_fields_keys}, expected: {expected_fields_keys})') def test_list(self): for coll_name in collections: coll = getattr(self.service, coll_name) expected = [ent.name for ent in coll.list(count=10, sort_mode="auto")] if len(expected) == 0: - logging.debug("No entities in collection %s; skipping test.", coll_name) + logging.debug(f"No entities in collection {coll_name}; skipping test.") found = [ent.name for ent in coll.list()][:10] self.assertEqual(expected, found, - msg='on %s (expected: %s, found: %s)' % \ - (coll_name, expected, found)) + msg=f'on {coll_name} (expected: {expected}, found: {found})') def test_list_with_count(self): N = 5 for coll_name in collections: coll = getattr(self.service, coll_name) - expected = [ent.name for ent in coll.list(count=N+5)][:N] - N = len(expected) # in case there are v2") self.assertEqual(result, - {'e1': {'a1': 'v1', 'e2': {'$text': 'v2', 'a1': 'v1'}}}) + {'e1': {'a1': 'v1', 'e2': {'$text': 'v2', 'a1': 'v1'}}}) def test_real(self): """Test some real Splunk response examples.""" @@ -120,12 +119,8 @@ def test_invalid(self): if sys.version_info[1] >= 7: self.assertRaises(et.ParseError, data.load, "") else: - if six.PY2: - from xml.parsers.expat import ExpatError - self.assertRaises(ExpatError, data.load, "") - else: - from xml.etree.ElementTree import ParseError - self.assertRaises(ParseError, data.load, "") + from xml.etree.ElementTree import ParseError + self.assertRaises(ParseError, data.load, "") self.assertRaises(KeyError, data.load, "a") @@ -166,8 +161,8 @@ def test_dict(self): """) - self.assertEqual(result, - {'content': {'n1': {'n1n1': "n1v1"}, 'n2': {'n2n1': "n2v1"}}}) + self.assertEqual(result, + {'content': {'n1': {'n1n1': "n1v1"}, 'n2': {'n2n1': "n2v1"}}}) result = data.load(""" @@ -179,8 +174,8 @@ def test_dict(self): """) - self.assertEqual(result, - {'content': {'n1': ['1', '2', '3', '4']}}) + self.assertEqual(result, + {'content': {'n1': ['1', '2', '3', '4']}}) def test_list(self): result = data.load("""""") @@ -222,8 +217,8 @@ def test_list(self): v4 """) - self.assertEqual(result, - {'content': [{'n1':"v1"}, {'n2':"v2"}, {'n3':"v3"}, {'n4':"v4"}]}) + self.assertEqual(result, + {'content': [{'n1': "v1"}, {'n2': "v2"}, {'n3': "v3"}, {'n4': "v4"}]}) result = data.load(""" @@ -233,7 +228,7 @@ def test_list(self): """) self.assertEqual(result, - {'build': '101089', 'cpu_arch': 'i386', 'isFree': '0'}) + {'build': '101089', 'cpu_arch': 'i386', 'isFree': '0'}) def test_record(self): d = data.record() @@ -244,17 +239,14 @@ def test_record(self): 'bar.zrp.peem': 9}) self.assertEqual(d['foo'], 5) self.assertEqual(d['bar.baz'], 6) - self.assertEqual(d['bar'], {'baz': 6, 'qux': 7, 'zrp': {'meep': 8, 'peem':9}}) + self.assertEqual(d['bar'], {'baz': 6, 'qux': 7, 'zrp': {'meep': 8, 'peem': 9}}) self.assertEqual(d.foo, 5) self.assertEqual(d.bar.baz, 6) - self.assertEqual(d.bar, {'baz': 6, 'qux': 7, 'zrp': {'meep': 8, 'peem':9}}) + self.assertEqual(d.bar, {'baz': 6, 'qux': 7, 'zrp': {'meep': 8, 'peem': 9}}) self.assertRaises(KeyError, d.__getitem__, 'boris') if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest - unittest.main() + import unittest + unittest.main() diff --git a/tests/test_event_type.py b/tests/test_event_type.py index 5ae2c7ecd..c50e1ea3a 100755 --- a/tests/test_event_type.py +++ b/tests/test_event_type.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,17 +14,15 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -import logging -import splunklib.client as client class TestRead(testlib.SDKTestCase): def test_read(self): for event_type in self.service.event_types.list(count=1): self.check_entity(event_type) + class TestCreate(testlib.SDKTestCase): def test_create(self): self.event_type_name = testlib.tmpname() @@ -42,43 +40,41 @@ def test_create(self): self.assertEqual(self.event_type_name, event_type.name) def tearDown(self): - super(TestCreate, self).setUp() + super().setUp() try: self.service.event_types.delete(self.event_type_name) except KeyError: pass + class TestEventType(testlib.SDKTestCase): def setUp(self): - super(TestEventType, self).setUp() + super().setUp() self.event_type_name = testlib.tmpname() self.event_type = self.service.event_types.create( self.event_type_name, search="index=_internal *") def tearDown(self): - super(TestEventType, self).setUp() + super().setUp() try: self.service.event_types.delete(self.event_type_name) except KeyError: pass - def test_delete(self): - self.assertTrue(self.event_type_name in self.service.event_types) - self.service.event_types.delete(self.event_type_name) - self.assertFalse(self.event_type_name in self.service.event_types) + # def test_delete(self): + # self.assertTrue(self.event_type_name in self.service.event_types) + # self.service.event_types.delete(self.event_type_name) + # self.assertFalse(self.event_type_name in self.service.event_types) def test_update(self): - kwargs = {} - kwargs['search'] = "index=_audit *" - kwargs['description'] = "An audit event" - kwargs['priority'] = '3' + kwargs = {'search': "index=_audit *", 'description': "An audit event", 'priority': '3'} self.event_type.update(**kwargs) self.event_type.refresh() self.assertEqual(self.event_type['search'], kwargs['search']) self.assertEqual(self.event_type['description'], kwargs['description']) self.assertEqual(self.event_type['priority'], kwargs['priority']) - + def test_enable_disable(self): self.assertEqual(self.event_type['disabled'], '0') self.event_type.disable() @@ -88,9 +84,8 @@ def test_enable_disable(self): self.event_type.refresh() self.assertEqual(self.event_type['disabled'], '0') + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_fired_alert.py b/tests/test_fired_alert.py index 2480d4150..9d16fddc1 100755 --- a/tests/test_fired_alert.py +++ b/tests/test_fired_alert.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,22 +14,19 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -import logging -import splunklib.client as client class FiredAlertTestCase(testlib.SDKTestCase): def setUp(self): - super(FiredAlertTestCase, self).setUp() + super().setUp() self.index_name = testlib.tmpname() self.assertFalse(self.index_name in self.service.indexes) self.index = self.service.indexes.create(self.index_name) saved_searches = self.service.saved_searches self.saved_search_name = testlib.tmpname() self.assertFalse(self.saved_search_name in saved_searches) - query = "search index=%s" % self.index_name + query = f"search index={self.index_name}" kwargs = {'alert_type': 'always', 'alert.severity': "3", 'alert.suppress': "0", @@ -43,7 +40,7 @@ def setUp(self): query, **kwargs) def tearDown(self): - super(FiredAlertTestCase, self).tearDown() + super().tearDown() if self.service.splunk_version >= (5,): self.service.indexes.delete(self.index_name) for saved_search in self.service.saved_searches: @@ -57,7 +54,7 @@ def test_new_search_is_empty(self): self.assertEqual(len(self.saved_search.history()), 0) self.assertEqual(len(self.saved_search.fired_alerts), 0) self.assertFalse(self.saved_search_name in self.service.fired_alerts) - + def test_alerts_on_events(self): self.assertEqual(self.saved_search.alert_count, 0) self.assertEqual(len(self.saved_search.fired_alerts), 0) @@ -71,14 +68,17 @@ def test_alerts_on_events(self): self.index.refresh() self.index.submit('This is a test ' + testlib.tmpname(), sourcetype='sdk_use', host='boris') + def f(): self.index.refresh() - return int(self.index['totalEventCount']) == eventCount+1 + return int(self.index['totalEventCount']) == eventCount + 1 + self.assertEventuallyTrue(f, timeout=50) def g(): self.saved_search.refresh() return self.saved_search.alert_count == 1 + self.assertEventuallyTrue(g, timeout=200) alerts = self.saved_search.fired_alerts @@ -90,9 +90,8 @@ def test_read(self): for alert in alert_group.alerts: alert.content + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_index.py b/tests/test_index.py index 9e2a53298..2582934bb 100755 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,30 +14,23 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from __future__ import print_function -from tests import testlib import logging -import os import time -import splunklib.client as client -try: - import unittest -except ImportError: - import unittest2 as unittest - import pytest +from tests import testlib +from splunklib import client + class IndexTest(testlib.SDKTestCase): def setUp(self): - super(IndexTest, self).setUp() + super().setUp() self.index_name = testlib.tmpname() self.index = self.service.indexes.create(self.index_name) self.assertEventuallyTrue(lambda: self.index.refresh()['disabled'] == '0') def tearDown(self): - super(IndexTest, self).tearDown() + super().tearDown() # We can't delete an index with the REST API before Splunk # 5.0. In 4.x, we just have to leave them lying around until # someone cares to go clean them up. Unique naming prevents @@ -92,14 +85,14 @@ def test_disable_enable(self): # self.assertEqual(self.index['totalEventCount'], '0') def test_prefresh(self): - self.assertEqual(self.index['disabled'], '0') # Index is prefreshed + self.assertEqual(self.index['disabled'], '0') # Index is prefreshed def test_submit(self): event_count = int(self.index['totalEventCount']) self.assertEqual(self.index['sync'], '0') self.assertEqual(self.index['disabled'], '0') self.index.submit("Hello again!", sourcetype="Boris", host="meep") - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=50) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=50) def test_submit_namespaced(self): s = client.connect(**{ @@ -114,14 +107,14 @@ def test_submit_namespaced(self): self.assertEqual(i['sync'], '0') self.assertEqual(i['disabled'], '0') i.submit("Hello again namespaced!", sourcetype="Boris", host="meep") - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=50) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=50) def test_submit_via_attach(self): event_count = int(self.index['totalEventCount']) cn = self.index.attach() cn.send(b"Hello Boris!\r\n") cn.close() - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=60) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) def test_submit_via_attach_using_token_header(self): # Remove the prefix from the token @@ -133,14 +126,14 @@ def test_submit_via_attach_using_token_header(self): cn = i.attach() cn.send(b"Hello Boris 5!\r\n") cn.close() - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=60) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) def test_submit_via_attached_socket(self): event_count = int(self.index['totalEventCount']) f = self.index.attached_socket with f() as sock: sock.send(b'Hello world!\r\n') - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=60) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) def test_submit_via_attach_with_cookie_header(self): # Skip this test if running below Splunk 6.2, cookie-auth didn't exist before @@ -156,7 +149,7 @@ def test_submit_via_attach_with_cookie_header(self): cn = service.indexes[self.index_name].attach() cn.send(b"Hello Boris!\r\n") cn.close() - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=60) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) def test_submit_via_attach_with_multiple_cookie_headers(self): # Skip this test if running below Splunk 6.2, cookie-auth didn't exist before @@ -171,7 +164,7 @@ def test_submit_via_attach_with_multiple_cookie_headers(self): cn = service.indexes[self.index_name].attach() cn.send(b"Hello Boris!\r\n") cn.close() - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+1, timeout=60) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) @pytest.mark.app def test_upload(self): @@ -181,11 +174,10 @@ def test_upload(self): path = self.pathInApp("file_to_upload", ["log.txt"]) self.index.upload(path) - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count+4, timeout=60) + self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 4, timeout=60) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_input.py b/tests/test_input.py index c7d48dc38..53436f73f 100755 --- a/tests/test_input.py +++ b/tests/test_input.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -13,22 +13,12 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from __future__ import print_function - +import logging +import pytest from splunklib.binding import HTTPError from tests import testlib -import logging -from splunklib import six -try: - import unittest -except ImportError: - import unittest2 as unittest - -import splunklib.client as client - -import pytest +from splunklib import client def highest_port(service, base_port, *kinds): @@ -42,7 +32,7 @@ def highest_port(service, base_port, *kinds): class TestTcpInputNameHandling(testlib.SDKTestCase): def setUp(self): - super(TestTcpInputNameHandling, self).setUp() + super().setUp() self.base_port = highest_port(self.service, 10000, 'tcp', 'splunktcp', 'udp') + 1 def tearDown(self): @@ -50,11 +40,11 @@ def tearDown(self): port = int(input.name.split(':')[-1]) if port >= self.base_port: input.delete() - super(TestTcpInputNameHandling, self).tearDown() + super().tearDown() def create_tcp_input(self, base_port, kind, **options): port = base_port - while True: # Find the next unbound port + while True: # Find the next unbound port try: input = self.service.inputs.create(str(port), kind, **options) return input @@ -75,7 +65,7 @@ def test_cannot_create_with_restrictToHost_in_name(self): ) def test_create_tcp_ports_with_restrictToHost(self): - for kind in ['tcp', 'splunktcp']: # Multiplexed UDP ports are not supported + for kind in ['tcp', 'splunktcp']: # Multiplexed UDP ports are not supported # Make sure we can create two restricted inputs on the same port boris = self.service.inputs.create(str(self.base_port), kind, restrictToHost='boris') natasha = self.service.inputs.create(str(self.base_port), kind, restrictToHost='natasha') @@ -110,7 +100,7 @@ def test_unrestricted_to_restricted_collision(self): unrestricted.delete() def test_update_restrictToHost_fails(self): - for kind in ['tcp', 'splunktcp']: # No UDP, since it's broken in Splunk + for kind in ['tcp', 'splunktcp']: # No UDP, since it's broken in Splunk boris = self.create_tcp_input(self.base_port, kind, restrictToHost='boris') self.assertRaises( @@ -149,7 +139,6 @@ def test_read_invalid_input(self): self.assertTrue("HTTP 404 Not Found" in str(he)) def test_inputs_list_on_one_kind_with_count(self): - N = 10 expected = [x.name for x in self.service.inputs.list('monitor')[:10]] found = [x.name for x in self.service.inputs.list('monitor', count=10)] self.assertEqual(expected, found) @@ -181,21 +170,22 @@ def test_oneshot(self): def f(): index.refresh() - return int(index['totalEventCount']) == eventCount+4 + return int(index['totalEventCount']) == eventCount + 4 + self.assertEventuallyTrue(f, timeout=60) def test_oneshot_on_nonexistant_file(self): name = testlib.tmpname() self.assertRaises(HTTPError, - self.service.inputs.oneshot, name) + self.service.inputs.oneshot, name) class TestInput(testlib.SDKTestCase): def setUp(self): - super(TestInput, self).setUp() + super().setUp() inputs = self.service.inputs - unrestricted_port = str(highest_port(self.service, 10000, 'tcp', 'splunktcp', 'udp')+1) - restricted_port = str(highest_port(self.service, int(unrestricted_port)+1, 'tcp', 'splunktcp')+1) + unrestricted_port = str(highest_port(self.service, 10000, 'tcp', 'splunktcp', 'udp') + 1) + restricted_port = str(highest_port(self.service, int(unrestricted_port) + 1, 'tcp', 'splunktcp') + 1) test_inputs = [{'kind': 'tcp', 'name': unrestricted_port, 'host': 'sdk-test'}, {'kind': 'udp', 'name': unrestricted_port, 'host': 'sdk-test'}, {'kind': 'tcp', 'name': 'boris:' + restricted_port, 'host': 'sdk-test'}] @@ -209,8 +199,8 @@ def setUp(self): inputs.create(restricted_port, 'tcp', restrictToHost='boris') def tearDown(self): - super(TestInput, self).tearDown() - for entity in six.itervalues(self._test_entities): + super().tearDown() + for entity in self._test_entities.values(): try: self.service.inputs.delete( kind=entity.kind, @@ -233,7 +223,7 @@ def test_lists_modular_inputs(self): self.uncheckedRestartSplunk() inputs = self.service.inputs - if ('abcd','test2') not in inputs: + if ('abcd', 'test2') not in inputs: inputs.create('abcd', 'test2', field1='boris') input = inputs['abcd', 'test2'] @@ -241,7 +231,7 @@ def test_lists_modular_inputs(self): def test_create(self): inputs = self.service.inputs - for entity in six.itervalues(self._test_entities): + for entity in self._test_entities.values(): self.check_entity(entity) self.assertTrue(isinstance(entity, client.Input)) @@ -252,7 +242,7 @@ def test_get_kind_list(self): def test_read(self): inputs = self.service.inputs - for this_entity in six.itervalues(self._test_entities): + for this_entity in self._test_entities.values(): kind, name = this_entity.kind, this_entity.name read_entity = inputs[name, kind] self.assertEqual(this_entity.kind, read_entity.kind) @@ -268,7 +258,7 @@ def test_read_indiviually(self): def test_update(self): inputs = self.service.inputs - for entity in six.itervalues(self._test_entities): + for entity in self._test_entities.values(): kind, name = entity.kind, entity.name kwargs = {'host': 'foo'} entity.update(**kwargs) @@ -278,19 +268,19 @@ def test_update(self): @pytest.mark.skip('flaky') def test_delete(self): inputs = self.service.inputs - remaining = len(self._test_entities)-1 - for input_entity in six.itervalues(self._test_entities): + remaining = len(self._test_entities) - 1 + for input_entity in self._test_entities.values(): name = input_entity.name kind = input_entity.kind self.assertTrue(name in inputs) - self.assertTrue((name,kind) in inputs) + self.assertTrue((name, kind) in inputs) if remaining == 0: inputs.delete(name) self.assertFalse(name in inputs) else: if not name.startswith('boris'): self.assertRaises(client.AmbiguousReferenceException, - inputs.delete, name) + inputs.delete, name) self.service.inputs.delete(name, kind) self.assertFalse((name, kind) in inputs) self.assertRaises(client.HTTPError, @@ -299,8 +289,6 @@ def test_delete(self): if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_job.py b/tests/test_job.py index e514c83c6..a276e212b 100755 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,9 +14,6 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from __future__ import print_function - from io import BytesIO from time import sleep @@ -24,23 +21,15 @@ from tests import testlib -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest -import splunklib.client as client -import splunklib.results as results +from splunklib import client +from splunklib import results from splunklib.binding import _log_duration, HTTPError import pytest -# TODO: Determine if we should be importing ExpatError if ParseError is not available (e.g., on Python 2.6) -# There's code below that now catches SyntaxError instead of ParseError. Should we be catching ExpathError instead? - -# from xml.etree.ElementTree import ParseError - class TestUtilities(testlib.SDKTestCase): def test_service_search(self): @@ -57,6 +46,7 @@ def test_oneshot_with_garbage_fails(self): jobs = self.service.jobs self.assertRaises(TypeError, jobs.create, "abcd", exec_mode="oneshot") + @pytest.mark.smoke def test_oneshot(self): jobs = self.service.jobs stream = jobs.oneshot("search index=_internal earliest=-1m | head 3", output_mode='json') @@ -84,21 +74,21 @@ def test_export(self): self.assertTrue(len(nonmessages) <= 3) def test_export_docstring_sample(self): - import splunklib.client as client - import splunklib.results as results + from splunklib import client + from splunklib import results service = self.service # cheat rr = results.JSONResultsReader(service.jobs.export("search * | head 5", output_mode='json')) for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - pass #print '%s: %s' % (result.type, result.message) + pass #print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - pass #print result + pass #print(result) assert rr.is_preview == False def test_results_docstring_sample(self): - import splunklib.results as results + from splunklib import results service = self.service # cheat job = service.jobs.create("search * | head 5") while not job.is_done(): @@ -107,42 +97,42 @@ def test_results_docstring_sample(self): for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - pass #print '%s: %s' % (result.type, result.message) + pass #print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - pass #print result + pass #print(result) assert rr.is_preview == False def test_preview_docstring_sample(self): - import splunklib.client as client - import splunklib.results as results + from splunklib import client + from splunklib import results service = self.service # cheat job = service.jobs.create("search * | head 5") rr = results.JSONResultsReader(job.preview(output_mode='json')) for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - pass #print '%s: %s' % (result.type, result.message) + pass #print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - pass #print result + pass #print(result) if rr.is_preview: - pass #print "Preview of a running search job." + pass #print("Preview of a running search job.") else: - pass #print "Job is finished. Results are final." + pass #print("Job is finished. Results are final.") def test_oneshot_docstring_sample(self): - import splunklib.client as client - import splunklib.results as results + from splunklib import client + from splunklib import results service = self.service # cheat rr = results.JSONResultsReader(service.jobs.oneshot("search * | head 5", output_mode='json')) for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - pass #print '%s: %s' % (result.type, result.message) + pass #print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - pass #print result + pass #print(result) assert rr.is_preview == False def test_normal_job_with_garbage_fails(self): @@ -188,7 +178,6 @@ def check_job(self, job): 'statusBuckets', 'ttl'] for key in keys: self.assertTrue(key in job.content) - return def test_read_jobs(self): jobs = self.service.jobs @@ -212,11 +201,11 @@ def test_get_job(self): class TestJobWithDelayedDone(testlib.SDKTestCase): def setUp(self): - super(TestJobWithDelayedDone, self).setUp() + super().setUp() self.job = None def tearDown(self): - super(TestJobWithDelayedDone, self).tearDown() + super().tearDown() if self.job is not None: self.job.cancel() self.assertEventuallyTrue(lambda: self.job.sid not in self.service.jobs) @@ -243,7 +232,6 @@ def is_preview_enabled(): return self.job.content['isPreviewEnabled'] == '1' self.assertEventuallyTrue(is_preview_enabled) - return @pytest.mark.app def test_setpriority(self): @@ -279,12 +267,11 @@ def f(): return int(self.job.content['priority']) == new_priority self.assertEventuallyTrue(f, timeout=sleep_duration + 5) - return class TestJob(testlib.SDKTestCase): def setUp(self): - super(TestJob, self).setUp() + super().setUp() self.query = "search index=_internal | head 3" self.job = self.service.jobs.create( query=self.query, @@ -292,7 +279,7 @@ def setUp(self): latest_time="now") def tearDown(self): - super(TestJob, self).tearDown() + super().tearDown() self.job.cancel() @_log_duration @@ -382,6 +369,7 @@ def test_search_invalid_query_as_json(self): except Exception as e: self.fail("Got some unexpected error. %s" % e.message) + @pytest.mark.smoke def test_v1_job_fallback(self): self.assertEventuallyTrue(self.job.is_done) self.assertLessEqual(int(self.job['eventCount']), 3) @@ -399,10 +387,10 @@ def test_v1_job_fallback(self): n_events = len([x for x in events_r if isinstance(x, dict)]) n_preview = len([x for x in preview_r if isinstance(x, dict)]) n_results = len([x for x in results_r if isinstance(x, dict)]) - - # Fallback test for Splunk Version 9+ - if self.service.splunk_version[0] >= 9: - self.assertGreaterEqual(9, self.service.splunk_version[0]) + + # Fallback test for Splunk Version 9.0.2+ + if not self.service.disable_v2_api: + self.assertTrue(client.PATH_JOBS_V2 in self.job.path) self.assertEqual(n_events, n_preview, n_results) diff --git a/tests/test_kvstore_batch.py b/tests/test_kvstore_batch.py index d32b665e6..a17a3e9d4 100755 --- a/tests/test_kvstore_batch.py +++ b/tests/test_kvstore_batch.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2020 Splunk, Inc. +# Copyright 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,22 +14,15 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -from splunklib.six.moves import range -try: - import unittest -except ImportError: - import unittest2 as unittest -import splunklib.client as client + class KVStoreBatchTestCase(testlib.SDKTestCase): def setUp(self): - super(KVStoreBatchTestCase, self).setUp() - #self.service.namespace['owner'] = 'nobody' + super().setUp() self.service.namespace['app'] = 'search' confs = self.service.kvstore - if ('test' in confs): + if 'test' in confs: confs['test'].delete() confs.create('test') @@ -69,15 +62,13 @@ def test_insert_find_update_data(self): self.assertEqual(testData[x][0]['data'], '#' + str(x + 1)) self.assertEqual(testData[x][0]['num'], x + 1) - def tearDown(self): confs = self.service.kvstore - if ('test' in confs): + if 'test' in confs: confs['test'].delete() + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_kvstore_conf.py b/tests/test_kvstore_conf.py index a24537288..fe098f66a 100755 --- a/tests/test_kvstore_conf.py +++ b/tests/test_kvstore_conf.py @@ -14,18 +14,14 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import +import json from tests import testlib -try: - import unittest -except ImportError: - import unittest2 as unittest -import splunklib.client as client +from splunklib import client + class KVStoreConfTestCase(testlib.SDKTestCase): def setUp(self): - super(KVStoreConfTestCase, self).setUp() - #self.service.namespace['owner'] = 'nobody' + super().setUp() self.service.namespace['app'] = 'search' self.confs = self.service.kvstore if ('test' in self.confs): @@ -40,15 +36,29 @@ def test_create_delete_collection(self): self.confs.create('test') self.assertTrue('test' in self.confs) self.confs['test'].delete() - self.assertTrue(not 'test' in self.confs) + self.assertTrue('test' not in self.confs) + + def test_create_fields(self): + self.confs.create('test', accelerated_fields={'ind1':{'a':1}}, fields={'a':'number1'}) + self.assertEqual(self.confs['test']['field.a'], 'number1') + self.assertEqual(self.confs['test']['accelerated_fields.ind1'], {"a": 1}) + self.confs['test'].delete() def test_update_collection(self): self.confs.create('test') - self.confs['test'].post(**{'accelerated_fields.ind1': '{"a": 1}', 'field.a': 'number'}) + val = {"a": 1} + self.confs['test'].post(**{'accelerated_fields.ind1': json.dumps(val), 'field.a': 'number'}) self.assertEqual(self.confs['test']['field.a'], 'number') - self.assertEqual(self.confs['test']['accelerated_fields.ind1'], '{"a": 1}') + self.assertEqual(self.confs['test']['accelerated_fields.ind1'], {"a": 1}) self.confs['test'].delete() + def test_update_accelerated_fields(self): + self.confs.create('test', accelerated_fields={'ind1':{'a':1}}) + self.assertEqual(self.confs['test']['accelerated_fields.ind1'], {'a': 1}) + # update accelerated_field value + self.confs['test'].update_accelerated_field('ind1', {'a': -1}) + self.assertEqual(self.confs['test']['accelerated_fields.ind1'], {'a': -1}) + self.confs['test'].delete() def test_update_fields(self): self.confs.create('test') @@ -58,7 +68,6 @@ def test_update_fields(self): self.assertEqual(self.confs['test']['field.a'], 'string') self.confs['test'].delete() - def test_create_unique_collection(self): self.confs.create('test') self.assertTrue('test' in self.confs) @@ -77,24 +86,12 @@ def test_overlapping_collections(self): self.confs['test'].delete() self.confs['test'].delete() - """ - def test_create_accelerated_fields_fields(self): - self.confs.create('test', indexes={'foo': '{"foo": 1}', 'bar': {'bar': -1}}, **{'field.foo': 'string'}) - self.assertEqual(self.confs['test']['accelerated_fields.foo'], '{"foo": 1}') - self.assertEqual(self.confs['test']['field.foo'], 'string') - self.assertRaises(client.HTTPError, lambda: self.confs['test'].post(**{'accelerated_fields.foo': 'THIS IS INVALID'})) - self.assertEqual(self.confs['test']['accelerated_fields.foo'], '{"foo": 1}') - self.confs['test'].update_accelerated_fields('foo', '') - self.assertEqual(self.confs['test']['accelerated_fields.foo'], None) - """ - def tearDown(self): - if ('test' in self.confs): + if 'test' in self.confs: self.confs['test'].delete() + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_kvstore_data.py b/tests/test_kvstore_data.py index 6ddeae688..5860f6fcf 100755 --- a/tests/test_kvstore_data.py +++ b/tests/test_kvstore_data.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2020 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,20 +14,15 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import import json from tests import testlib -from splunklib.six.moves import range -try: - import unittest -except ImportError: - import unittest2 as unittest -import splunklib.client as client + +from splunklib import client + class KVStoreDataTestCase(testlib.SDKTestCase): def setUp(self): - super(KVStoreDataTestCase, self).setUp() - #self.service.namespace['owner'] = 'nobody' + super().setUp() self.service.namespace['app'] = 'search' self.confs = self.service.kvstore if ('test' in self.confs): @@ -60,7 +55,7 @@ def test_update_delete_data(self): self.assertEqual(len(self.col.query(query='{"num": 50}')), 0) def test_query_data(self): - if ('test1' in self.confs): + if 'test1' in self.confs: self.confs['test1'].delete() self.confs.create('test1') self.col = self.confs['test1'].data @@ -74,7 +69,6 @@ def test_query_data(self): data = self.col.query(limit=2, skip=9) self.assertEqual(len(data), 1) - def test_invalid_insert_update(self): self.assertRaises(client.HTTPError, lambda: self.col.insert('NOT VALID DATA')) id = self.col.insert(json.dumps({'foo': 'bar'}))['_key'] @@ -89,16 +83,15 @@ def test_params_data_type_conversion(self): self.assertEqual(len(data), 20) for x in range(20): self.assertEqual(data[x]['data'], 39 - x) - self.assertTrue(not 'ignore' in data[x]) - self.assertTrue(not '_key' in data[x]) + self.assertTrue('ignore' not in data[x]) + self.assertTrue('_key' not in data[x]) def tearDown(self): - if ('test' in self.confs): + if 'test' in self.confs: self.confs['test'].delete() + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_logger.py b/tests/test_logger.py index 7e9f5c8e0..46623e363 100755 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,13 +14,13 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -import splunklib.client as client +from splunklib import client LEVELS = ["INFO", "WARN", "ERROR", "DEBUG", "CRIT"] + class LoggerTestCase(testlib.SDKTestCase): def check_logger(self, logger): self.check_entity(logger) @@ -44,9 +44,8 @@ def test_crud(self): logger.refresh() self.assertEqual(self.service.loggers['AuditLogger']['level'], saved) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_macro.py b/tests/test_macro.py new file mode 100755 index 000000000..25b72e4d9 --- /dev/null +++ b/tests/test_macro.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +# +# Copyright 2011-2015 Splunk, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"): you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import +from tests import testlib +import logging + +import splunklib.client as client + +import pytest + +@pytest.mark.smoke +class TestMacro(testlib.SDKTestCase): + def setUp(self): + super(TestMacro, self).setUp() + macros = self.service.macros + logging.debug("Macros namespace: %s", macros.service.namespace) + self.macro_name = testlib.tmpname() + definition = '| eval test="123"' + self.macro = macros.create(self.macro_name, definition) + + def tearDown(self): + super(TestMacro, self).setUp() + for macro in self.service.macros: + if macro.name.startswith('delete-me'): + self.service.macros.delete(macro.name) + + def check_macro(self, macro): + self.check_entity(macro) + expected_fields = ['definition', + 'iseval', + 'args', + 'validation', + 'errormsg'] + for f in expected_fields: + macro[f] + is_eval = macro.iseval + self.assertTrue(is_eval == '1' or is_eval == '0') + + def test_create(self): + self.assertTrue(self.macro_name in self.service.macros) + self.check_macro(self.macro) + + def test_create_with_args(self): + macro_name = testlib.tmpname() + '(1)' + definition = '| eval value="$value$"' + kwargs = { + 'args': 'value', + 'validation': '$value$ > 10', + 'errormsg': 'value must be greater than 10' + } + macro = self.service.macros.create(macro_name, definition=definition, **kwargs) + self.assertTrue(macro_name in self.service.macros) + self.check_macro(macro) + self.assertEqual(macro.iseval, '0') + self.assertEqual(macro.args, kwargs.get('args')) + self.assertEqual(macro.validation, kwargs.get('validation')) + self.assertEqual(macro.errormsg, kwargs.get('errormsg')) + self.service.macros.delete(macro_name) + + def test_delete(self): + self.assertTrue(self.macro_name in self.service.macros) + self.service.macros.delete(self.macro_name) + self.assertFalse(self.macro_name in self.service.macros) + self.assertRaises(client.HTTPError, + self.macro.refresh) + + def test_update(self): + new_definition = '| eval updated="true"' + self.macro.update(definition=new_definition) + self.macro.refresh() + self.assertEqual(self.macro['definition'], new_definition) + + is_eval = testlib.to_bool(self.macro['iseval']) + self.macro.update(iseval=not is_eval) + self.macro.refresh() + self.assertEqual(testlib.to_bool(self.macro['iseval']), not is_eval) + + def test_cannot_update_name(self): + new_name = self.macro_name + '-alteration' + self.assertRaises(client.IllegalOperationException, + self.macro.update, name=new_name) + + def test_name_collision(self): + opts = self.opts.kwargs.copy() + opts['owner'] = '-' + opts['app'] = '-' + opts['sharing'] = 'user' + service = client.connect(**opts) + logging.debug("Namespace for collision testing: %s", service.namespace) + macros = service.macros + name = testlib.tmpname() + + dispatch1 = '| eval macro_one="1"' + dispatch2 = '| eval macro_two="2"' + namespace1 = client.namespace(app='search', sharing='app') + namespace2 = client.namespace(owner='admin', app='search', sharing='user') + new_macro2 = macros.create( + name, dispatch2, + namespace=namespace1) + new_macro1 = macros.create( + name, dispatch1, + namespace=namespace2) + + self.assertRaises(client.AmbiguousReferenceException, + macros.__getitem__, name) + macro1 = macros[name, namespace1] + self.check_macro(macro1) + macro1.update(**{'definition': '| eval number=1'}) + macro1.refresh() + self.assertEqual(macro1['definition'], '| eval number=1') + macro2 = macros[name, namespace2] + macro2.update(**{'definition': '| eval number=2'}) + macro2.refresh() + self.assertEqual(macro2['definition'], '| eval number=2') + self.check_macro(macro2) + + def test_no_equality(self): + self.assertRaises(client.IncomparableException, + self.macro.__eq__, self.macro) + + def test_acl(self): + self.assertEqual(self.macro.access["perms"], None) + self.macro.acl_update(sharing="app", owner="admin", **{"perms.read": "admin, nobody"}) + self.assertEqual(self.macro.access["owner"], "admin") + self.assertEqual(self.macro.access["sharing"], "app") + self.assertEqual(self.macro.access["perms"]["read"], ['admin', 'nobody']) + + def test_acl_fails_without_sharing(self): + self.assertRaisesRegex( + ValueError, + "Required argument 'sharing' is missing.", + self.macro.acl_update, + owner="admin", app="search", **{"perms.read": "admin, nobody"} + ) + + def test_acl_fails_without_owner(self): + self.assertRaisesRegex( + ValueError, + "Required argument 'owner' is missing.", + self.macro.acl_update, + sharing="app", app="search", **{"perms.read": "admin, nobody"} + ) + + +if __name__ == "__main__": + try: + import unittest2 as unittest + except ImportError: + import unittest + unittest.main() diff --git a/tests/test_message.py b/tests/test_message.py index cd76c783b..29f6a8694 100755 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,10 +14,10 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -import splunklib.client as client +from splunklib import client + class MessageTest(testlib.SDKTestCase): def setUp(self): @@ -31,6 +31,7 @@ def tearDown(self): testlib.SDKTestCase.tearDown(self) self.service.messages.delete(self.message_name) + class TestCreateDelete(testlib.SDKTestCase): def test_create_delete(self): message_name = testlib.tmpname() @@ -46,11 +47,10 @@ def test_create_delete(self): def test_invalid_name(self): self.assertRaises(client.InvalidNameException, self.service.messages.create, None, value="What?") self.assertRaises(client.InvalidNameException, self.service.messages.create, 42, value="Who, me?") - self.assertRaises(client.InvalidNameException, self.service.messages.create, [1,2,3], value="Who, me?") + self.assertRaises(client.InvalidNameException, self.service.messages.create, [1, 2, 3], value="Who, me?") + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_modular_input.py b/tests/test_modular_input.py index ae6e797db..50a49d230 100755 --- a/tests/test_modular_input.py +++ b/tests/test_modular_input.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,20 +14,14 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from __future__ import print_function -try: - import unittest2 as unittest -except ImportError: - import unittest -from tests import testlib import pytest +from tests import testlib @pytest.mark.smoke class ModularInputKindTestCase(testlib.SDKTestCase): def setUp(self): - super(ModularInputKindTestCase, self).setUp() + super().setUp() self.uncheckedRestartSplunk() @pytest.mark.app @@ -38,7 +32,7 @@ def test_lists_modular_inputs(self): self.uncheckedRestartSplunk() inputs = self.service.inputs - if ('abcd','test2') not in inputs: + if ('abcd', 'test2') not in inputs: inputs.create('abcd', 'test2', field1='boris') input = inputs['abcd', 'test2'] @@ -55,5 +49,8 @@ def check_modular_input_kind(self, m): self.assertEqual('test2', m['title']) self.assertEqual('simple', m['streaming_mode']) + if __name__ == "__main__": + import unittest + unittest.main() diff --git a/tests/test_modular_input_kinds.py b/tests/test_modular_input_kinds.py index c6b7391ea..1304f269f 100755 --- a/tests/test_modular_input_kinds.py +++ b/tests/test_modular_input_kinds.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,20 +14,16 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from __future__ import print_function +import pytest + from tests import testlib -try: - import unittest -except ImportError: - import unittest2 as unittest -import splunklib.client as client -import pytest +from splunklib import client + class ModularInputKindTestCase(testlib.SDKTestCase): def setUp(self): - super(ModularInputKindTestCase, self).setUp() + super().setUp() self.uncheckedRestartSplunk() @pytest.mark.app @@ -40,9 +36,9 @@ def test_list_arguments(self): test1 = self.service.modular_input_kinds['test1'] - expected_args = set(["name", "resname", "key_id", "no_description", "empty_description", - "arg_required_on_edit", "not_required_on_edit", "required_on_create", - "not_required_on_create", "number_field", "string_field", "boolean_field"]) + expected_args = {"name", "resname", "key_id", "no_description", "empty_description", "arg_required_on_edit", + "not_required_on_edit", "required_on_create", "not_required_on_create", "number_field", + "string_field", "boolean_field"} found_args = set(test1.arguments.keys()) self.assertEqual(expected_args, found_args) @@ -77,9 +73,8 @@ def test_list_modular_inputs(self): for m in self.service.modular_input_kinds: self.check_modular_input_kind(m) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_results.py b/tests/test_results.py index 5fdca2b91..bde1c4ab4 100755 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,15 +14,12 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import - +import io from io import BytesIO -from splunklib.six import StringIO -from tests import testlib from time import sleep -import splunklib.results as results -import io +from tests import testlib +from splunklib import results class ResultsTestCase(testlib.SDKTestCase): @@ -161,12 +158,11 @@ def test_read_raw_field_with_segmentation(self): def assert_parsed_results_equals(self, xml_text, expected_results): results_reader = results.ResultsReader(BytesIO(xml_text.encode('utf-8'))) - actual_results = [x for x in results_reader] + actual_results = list(results_reader) self.assertEqual(expected_results, actual_results) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_role.py b/tests/test_role.py index 16205d057..2c087603f 100755 --- a/tests/test_role.py +++ b/tests/test_role.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,20 +14,20 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib import logging -import splunklib.client as client +from splunklib import client + class RoleTestCase(testlib.SDKTestCase): def setUp(self): - super(RoleTestCase, self).setUp() + super().setUp() self.role_name = testlib.tmpname() self.role = self.service.roles.create(self.role_name) def tearDown(self): - super(RoleTestCase, self).tearDown() + super().tearDown() for role in self.service.roles: if role.name.startswith('delete-me'): self.service.roles.delete(role.name) @@ -91,7 +91,6 @@ def test_invalid_revoke(self): def test_revoke_capability_not_granted(self): self.role.revoke('change_own_password') - def test_update(self): kwargs = {} if 'user' in self.role['imported_roles']: @@ -105,9 +104,8 @@ def test_update(self): self.assertEqual(self.role['imported_roles'], kwargs['imported_roles']) self.assertEqual(int(self.role['srchJobsQuota']), kwargs['srchJobsQuota']) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_saved_search.py b/tests/test_saved_search.py index c15921c0c..a78d9420c 100755 --- a/tests/test_saved_search.py +++ b/tests/test_saved_search.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,22 +14,21 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import import datetime +import pytest from tests import testlib import logging from time import sleep -import splunklib.client as client -from splunklib.six.moves import zip +from splunklib import client + -import pytest @pytest.mark.smoke class TestSavedSearch(testlib.SDKTestCase): def setUp(self): - super(TestSavedSearch, self).setUp() + super().setUp() saved_searches = self.service.saved_searches logging.debug("Saved searches namespace: %s", saved_searches.service.namespace) self.saved_search_name = testlib.tmpname() @@ -37,7 +36,7 @@ def setUp(self): self.saved_search = saved_searches.create(self.saved_search_name, query) def tearDown(self): - super(TestSavedSearch, self).setUp() + super().setUp() for saved_search in self.service.saved_searches: if saved_search.name.startswith('delete-me'): try: @@ -76,9 +75,9 @@ def check_saved_search(self, saved_search): self.assertGreaterEqual(saved_search.suppressed, 0) self.assertGreaterEqual(saved_search['suppressed'], 0) is_scheduled = saved_search.content['is_scheduled'] - self.assertTrue(is_scheduled == '1' or is_scheduled == '0') + self.assertTrue(is_scheduled in ('1', '0')) is_visible = saved_search.content['is_visible'] - self.assertTrue(is_visible == '1' or is_visible == '0') + self.assertTrue(is_visible in ('1', '0')) def test_create(self): self.assertTrue(self.saved_search_name in self.service.saved_searches) @@ -91,7 +90,6 @@ def test_delete(self): self.assertRaises(client.HTTPError, self.saved_search.refresh) - def test_update(self): is_visible = testlib.to_bool(self.saved_search['is_visible']) self.saved_search.update(is_visible=not is_visible) @@ -148,7 +146,7 @@ def test_dispatch(self): def test_dispatch_with_options(self): try: - kwargs = { 'dispatch.buckets': 100 } + kwargs = {'dispatch.buckets': 100} job = self.saved_search.dispatch(**kwargs) while not job.is_ready(): sleep(0.1) @@ -159,13 +157,13 @@ def test_dispatch_with_options(self): def test_history(self): try: old_jobs = self.saved_search.history() - N = len(old_jobs) - logging.debug("Found %d jobs in saved search history", N) + num = len(old_jobs) + logging.debug("Found %d jobs in saved search history", num) job = self.saved_search.dispatch() while not job.is_ready(): sleep(0.1) history = self.saved_search.history() - self.assertEqual(len(history), N+1) + self.assertEqual(len(history), num + 1) self.assertTrue(job.sid in [j.sid for j in history]) finally: job.cancel() @@ -199,13 +197,8 @@ def test_scheduled_times(self): for x in scheduled_times])) time_pairs = list(zip(scheduled_times[:-1], scheduled_times[1:])) for earlier, later in time_pairs: - diff = later-earlier - # diff is an instance of datetime.timedelta, which - # didn't get a total_seconds() method until Python 2.7. - # Since we support Python 2.6, we have to calculate the - # total seconds ourselves. - total_seconds = diff.days*24*60*60 + diff.seconds - self.assertEqual(total_seconds/60.0, 5) + diff = later - earlier + self.assertEqual(diff.total_seconds() / 60.0, 5) def test_no_equality(self): self.assertRaises(client.IncomparableException, @@ -214,7 +207,7 @@ def test_no_equality(self): def test_suppress(self): suppressed_time = self.saved_search['suppressed'] self.assertGreaterEqual(suppressed_time, 0) - new_suppressed_time = suppressed_time+100 + new_suppressed_time = suppressed_time + 100 self.saved_search.suppress(new_suppressed_time) self.assertLessEqual(self.saved_search['suppressed'], new_suppressed_time) @@ -223,9 +216,31 @@ def test_suppress(self): self.saved_search.unsuppress() self.assertEqual(self.saved_search['suppressed'], 0) + def test_acl(self): + self.assertEqual(self.saved_search.access["perms"], None) + self.saved_search.acl_update(sharing="app", owner="admin", app="search", **{"perms.read": "admin, nobody"}) + self.assertEqual(self.saved_search.access["owner"], "admin") + self.assertEqual(self.saved_search.access["app"], "search") + self.assertEqual(self.saved_search.access["sharing"], "app") + self.assertEqual(self.saved_search.access["perms"]["read"], ['admin', 'nobody']) + + def test_acl_fails_without_sharing(self): + self.assertRaisesRegex( + ValueError, + "Required argument 'sharing' is missing.", + self.saved_search.acl_update, + owner="admin", app="search", **{"perms.read": "admin, nobody"} + ) + + def test_acl_fails_without_owner(self): + self.assertRaisesRegex( + ValueError, + "Required argument 'owner' is missing.", + self.saved_search.acl_update, + sharing="app", app="search", **{"perms.read": "admin, nobody"} + ) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_service.py b/tests/test_service.py index e3ffef681..6433b56b5 100755 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,12 +14,12 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from tests import testlib import unittest +import pytest +from tests import testlib -import splunklib.client as client -from splunklib.client import AuthenticationError +from splunklib import client +from splunklib.binding import AuthenticationError from splunklib.client import Service from splunklib.binding import HTTPError @@ -57,7 +57,7 @@ def test_info_with_namespace(self): try: self.assertEqual(self.service.info.licenseState, 'OK') except HTTPError as he: - self.fail("Couldn't get the server info, probably got a 403! %s" % he.message) + self.fail(f"Couldn't get the server info, probably got a 403! {he.message}") self.service.namespace["owner"] = owner self.service.namespace["app"] = app @@ -102,11 +102,6 @@ def test_parse(self): # objectified form of the results, but for now there's # nothing to test but a good response code. response = self.service.parse('search * abc="def" | dedup abc') - - # Splunk Version 9+ using API v2: search/v2/parser - if self.service.splunk_version[0] >= 9: - self.assertGreaterEqual(9, self.service.splunk_version[0]) - self.assertEqual(response.status, 200) def test_parse_fail(self): @@ -173,6 +168,7 @@ def _create_unauthenticated_service(self): }) # To check the HEC event endpoint using Endpoint instance + @pytest.mark.smoke def test_hec_event(self): import json service_hec = client.connect(host='localhost', scheme='https', port=8088, @@ -192,7 +188,7 @@ def setUp(self): def assertIsNotNone(self, obj, msg=None): if obj is None: - raise self.failureException(msg or '%r is not None' % obj) + raise self.failureException(msg or f'{obj} is not None') def test_login_and_store_cookie(self): self.assertIsNotNone(self.service.get_cookies()) @@ -369,8 +365,5 @@ def test_proper_namespace_with_service_namespace(self): if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + unittest.main() diff --git a/tests/test_storage_passwords.py b/tests/test_storage_passwords.py index 4f2fee81f..bda832dd9 100644 --- a/tests/test_storage_passwords.py +++ b/tests/test_storage_passwords.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,11 +14,9 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -import logging -import splunklib.client as client +from splunklib import client class Tests(testlib.SDKTestCase): @@ -153,11 +151,10 @@ def test_read(self): p = self.storage_passwords.create("changeme", username) self.assertEqual(start_count + 1, len(self.storage_passwords)) - for sp in self.storage_passwords: - self.assertTrue(p.name in self.storage_passwords) - # Name works with or without a trailing colon - self.assertTrue((":" + username + ":") in self.storage_passwords) - self.assertTrue((":" + username) in self.storage_passwords) + self.assertTrue(p.name in self.storage_passwords) + # Name works with or without a trailing colon + self.assertTrue((":" + username + ":") in self.storage_passwords) + self.assertTrue((":" + username) in self.storage_passwords) p.delete() self.assertEqual(start_count, len(self.storage_passwords)) @@ -233,9 +230,8 @@ def test_spaces_in_username(self): p.delete() self.assertEqual(start_count, len(self.storage_passwords)) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_user.py b/tests/test_user.py index b8a97f811..e20b96946 100755 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,20 +14,19 @@ # License for the specific language governing permissions and limitations # under the License. -from __future__ import absolute_import from tests import testlib -import logging -import splunklib.client as client +from splunklib import client + class UserTestCase(testlib.SDKTestCase): def check_user(self, user): self.check_entity(user) # Verify expected fields exist [user[f] for f in ['email', 'password', 'realname', 'roles']] - + def setUp(self): - super(UserTestCase, self).setUp() + super().setUp() self.username = testlib.tmpname() self.user = self.service.users.create( self.username, @@ -35,7 +34,7 @@ def setUp(self): roles=['power', 'user']) def tearDown(self): - super(UserTestCase, self).tearDown() + super().tearDown() for user in self.service.users: if user.name.startswith('delete-me'): self.service.users.delete(user.name) @@ -84,9 +83,8 @@ def test_delete_is_case_insensitive(self): self.assertFalse(self.username in users) self.assertFalse(self.username.upper() in users) + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index 5b6b712ca..40ed53197 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,38 +1,34 @@ -from __future__ import absolute_import +import unittest +import os from tests import testlib - -try: - from utils import * -except ImportError: - raise Exception("Add the SDK repository to your PYTHONPATH to run the test cases " - "(e.g., export PYTHONPATH=~/splunk-sdk-python.") - +from utils import dslice TEST_DICT = { - 'username':'admin', - 'password':'changeme', - 'port' : 8089, - 'host' : 'localhost', - 'scheme': 'https' - } + 'username': 'admin', + 'password': 'changeme', + 'port': 8089, + 'host': 'localhost', + 'scheme': 'https' +} + class TestUtils(testlib.SDKTestCase): def setUp(self): - super(TestUtils, self).setUp() + super().setUp() # Test dslice when a dict is passed to change key names def test_dslice_dict_args(self): args = { - 'username':'user-name', - 'password':'new_password', - 'port': 'admin_port', - 'foo':'bar' - } + 'username': 'user-name', + 'password': 'new_password', + 'port': 'admin_port', + 'foo': 'bar' + } expected = { - 'user-name':'admin', - 'new_password':'changeme', - 'admin_port':8089 - } + 'user-name': 'admin', + 'new_password': 'changeme', + 'admin_port': 8089 + } self.assertTrue(expected == dslice(TEST_DICT, args)) # Test dslice when a list is passed @@ -43,43 +39,64 @@ def test_dslice_list_args(self): 'port', 'host', 'foo' - ] + ] expected = { - 'username':'admin', - 'password':'changeme', - 'port':8089, - 'host':'localhost' - } + 'username': 'admin', + 'password': 'changeme', + 'port': 8089, + 'host': 'localhost' + } self.assertTrue(expected == dslice(TEST_DICT, test_list)) # Test dslice when a single string is passed def test_dslice_arg(self): test_arg = 'username' expected = { - 'username':'admin' - } + 'username': 'admin' + } self.assertTrue(expected == dslice(TEST_DICT, test_arg)) # Test dslice using all three types of arguments def test_dslice_all_args(self): test_args = [ - {'username':'new_username'}, + {'username': 'new_username'}, ['password', - 'host'], + 'host'], 'port' ] expected = { - 'new_username':'admin', - 'password':'changeme', - 'host':'localhost', - 'port':8089 + 'new_username': 'admin', + 'password': 'changeme', + 'host': 'localhost', + 'port': 8089 } self.assertTrue(expected == dslice(TEST_DICT, *test_args)) +class FilePermissionTest(unittest.TestCase): + + def setUp(self): + super().setUp() + + # Check for any change in the default file permission(i.e 644) for all files within splunklib + def test_filePermissions(self): + + def checkFilePermissions(dir_path): + for file in os.listdir(dir_path): + if file.__contains__('pycache'): + continue + path = os.path.join(dir_path, file) + if os.path.isfile(path): + permission = oct(os.stat(path).st_mode) + self.assertEqual(permission, '0o100644') + else: + checkFilePermissions(path) + + dir_path = os.path.join('..', 'splunklib') + checkFilePermissions(dir_path) + + if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/tests/testlib.py b/tests/testlib.py index 4a99e026a..e7c7b6a70 100644 --- a/tests/testlib.py +++ b/tests/testlib.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -15,35 +15,26 @@ # under the License. """Shared unit test utilities.""" -from __future__ import absolute_import -from __future__ import print_function import contextlib +import os +import time +import logging import sys -from splunklib import six # Run the test suite on the SDK without installing it. sys.path.insert(0, '../') -import splunklib.client as client from time import sleep from datetime import datetime, timedelta -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest -try: - from utils import parse -except ImportError: - raise Exception("Add the SDK repository to your PYTHONPATH to run the test cases " - "(e.g., export PYTHONPATH=~/splunk-sdk-python.") +from utils import parse + +from splunklib import client -import os -import time -import logging logging.basicConfig( filename='test.log', @@ -62,10 +53,9 @@ class WaitTimedOutError(Exception): def to_bool(x): if x == '1': return True - elif x == '0': + if x == '0': return False - else: - raise ValueError("Not a boolean value: %s", x) + raise ValueError(f"Not a boolean value: {x}") def tmpname(): @@ -102,7 +92,7 @@ def assertEventuallyTrue(self, predicate, timeout=30, pause_time=0.5, logging.debug("wait finished after %s seconds", datetime.now() - start) def check_content(self, entity, **kwargs): - for k, v in six.iteritems(kwargs): + for k, v in kwargs: self.assertEqual(entity[k], str(v)) def check_entity(self, entity): @@ -154,9 +144,7 @@ def clear_restart_message(self): try: self.service.delete("messages/restart_required") except client.HTTPError as he: - if he.status == 404: - pass - else: + if he.status != 404: raise @contextlib.contextmanager @@ -179,7 +167,7 @@ def install_app_from_collection(self, name): self.service.post("apps/local", **kwargs) except client.HTTPError as he: if he.status == 400: - raise IOError("App %s not found in app collection" % name) + raise IOError(f"App {name} not found in app collection") if self.service.restart_required: self.service.restart(120) self.installedApps.append(name) @@ -268,6 +256,6 @@ def tearDown(self): except HTTPError as error: if not (os.name == 'nt' and error.status == 500): raise - print('Ignoring failure to delete {0} during tear down: {1}'.format(appName, error)) + print(f'Ignoring failure to delete {appName} during tear down: {error}') if self.service.restart_required: self.clear_restart_message() diff --git a/tox.ini b/tox.ini index 8b8bcb1b5..e69c2e6ad 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = clean,docs,py27,py37 +envlist = clean,docs,py37,py39,313 skipsdist = {env:TOXBUILD:false} [testenv:pep8] @@ -28,11 +28,7 @@ setenv = SPLUNK_HOME=/opt/splunk allowlist_externals = make deps = pytest pytest-cov - xmlrunner - unittest2 - unittest-xml-reporting python-dotenv - deprecation distdir = build commands = diff --git a/utils/__init__.py b/utils/__init__.py index bd0900c3d..b6c455656 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,14 +14,14 @@ """Utility module shared by the SDK unit tests.""" -from __future__ import absolute_import from utils.cmdopts import * -from splunklib import six + def config(option, opt, value, parser): assert opt == "--config" parser.load(value) + # Default Splunk cmdline rules RULES_SPLUNK = { 'config': { @@ -30,7 +30,7 @@ def config(option, opt, value, parser): 'callback': config, 'type': "string", 'nargs': "1", - 'help': "Load options from config file" + 'help': "Load options from config file" }, 'scheme': { 'flags': ["--scheme"], @@ -40,30 +40,30 @@ def config(option, opt, value, parser): 'host': { 'flags': ["--host"], 'default': "localhost", - 'help': "Host name (default 'localhost')" + 'help': "Host name (default 'localhost')" }, - 'port': { + 'port': { 'flags': ["--port"], 'default': "8089", - 'help': "Port number (default 8089)" + 'help': "Port number (default 8089)" }, 'app': { - 'flags': ["--app"], + 'flags': ["--app"], 'help': "The app context (optional)" }, 'owner': { - 'flags': ["--owner"], + 'flags': ["--owner"], 'help': "The user context (optional)" }, 'username': { 'flags': ["--username"], 'default': None, - 'help': "Username to login with" + 'help': "Username to login with" }, 'password': { - 'flags': ["--password"], + 'flags': ["--password"], 'default': None, - 'help': "Password to login with" + 'help': "Password to login with" }, 'version': { 'flags': ["--version"], @@ -84,28 +84,30 @@ def config(option, opt, value, parser): FLAGS_SPLUNK = list(RULES_SPLUNK.keys()) + # value: dict, args: [(dict | list | str)*] def dslice(value, *args): """Returns a 'slice' of the given dictionary value containing only the requested keys. The keys can be requested in a variety of ways, as an arg list of keys, as a list of keys, or as a dict whose key(s) represent - the source keys and whose corresponding values represent the resulting - key(s) (enabling key rename), or any combination of the above.""" + the source keys and whose corresponding values represent the resulting + key(s) (enabling key rename), or any combination of the above.""" result = {} for arg in args: if isinstance(arg, dict): - for k, v in six.iteritems(arg): - if k in value: + for k, v in (list(arg.items())): + if k in value: result[v] = value[k] elif isinstance(arg, list): for k in arg: - if k in value: + if k in value: result[k] = value[k] else: - if arg in value: + if arg in value: result[arg] = value[arg] return result + def parse(argv, rules=None, config=None, **kwargs): """Parse the given arg vector with the default Splunk command rules.""" parser_ = parser(rules, **kwargs) @@ -113,8 +115,8 @@ def parse(argv, rules=None, config=None, **kwargs): parser_.loadenv(config) return parser_.parse(argv).result + def parser(rules=None, **kwargs): """Instantiate a parser with the default Splunk command rules.""" rules = RULES_SPLUNK if rules is None else dict(RULES_SPLUNK, **rules) return Parser(rules, **kwargs) - diff --git a/utils/cmdopts.py b/utils/cmdopts.py index b0cbb7328..63cdfb1d4 100644 --- a/utils/cmdopts.py +++ b/utils/cmdopts.py @@ -1,4 +1,4 @@ -# Copyright 2011-2015 Splunk, Inc. +# Copyright © 2011-2024 Splunk, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"): you may # not use this file except in compliance with the License. You may obtain @@ -14,8 +14,6 @@ """Command line utilities shared by command line tools & unit tests.""" -from __future__ import absolute_import -from __future__ import print_function from os import path from optparse import OptionParser import sys @@ -25,7 +23,7 @@ # Print the given message to stderr, and optionally exit def error(message, exitcode = None): - print("Error: %s" % message, file=sys.stderr) + print(f"Error: {message}", file=sys.stderr) if exitcode is not None: sys.exit(exitcode)