diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 863d88aa9..cc26502ca 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -53,7 +53,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.20.0 + uses: pypa/cibuildwheel@v2.21.3 env: BUILD_EXTENSION: yes - name: Save wheels diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 43193ea50..5ab9c4c72 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -55,19 +55,15 @@ jobs: strategy: matrix: python: - - "3.8" - "3.9" - "3.10" - "3.11" - "3.12" - "3.13" - - "pypy-3.9" - "pypy-3.10" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} exclude: - - python: "pypy-3.9" - is_main: false - python: "pypy-3.10" is_main: false steps: diff --git a/.gitignore b/.gitignore index d8e6697a8..291bf1fb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,16 @@ *.pyc *.so .coverage -.direnv +.direnv/ .envrc .idea/ -.mypy_cache -.tox +.mypy_cache/ +.tox/ +.vscode/ build/ compliance/reports/ -experiments/compression/corpus/ dist/ docs/_build/ +experiments/compression/corpus/ htmlcov/ -MANIFEST -websockets.egg-info/ +src/websockets.egg-info/ diff --git a/Makefile b/Makefile index fd36d0367..06bfe9edc 100644 --- a/Makefile +++ b/Makefile @@ -8,8 +8,8 @@ build: python setup.py build_ext --inplace style: - ruff format src tests - ruff check --fix src tests + ruff format compliance src tests + ruff check --fix compliance src tests types: mypy --strict src diff --git a/compliance/README.rst b/compliance/README.rst index 8570f9176..ee491310f 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -4,47 +4,76 @@ Autobahn Testsuite General information and installation instructions are available at https://github.com/crossbario/autobahn-testsuite. 
-To improve performance, you should compile the C extension first:: +Running the test suite +---------------------- + +All commands below must be run from the root directory of the repository. + +To get acceptable performance, compile the C extension first: + +.. code-block:: console $ python setup.py build_ext --inplace -Running the test suite ----------------------- +Run each command in a different shell. Testing takes several minutes to complete +— wstest is the bottleneck. When clients finish, stop servers with Ctrl-C. -All commands below must be run from the directory containing this file. +You can exclude slow tests by modifying the configuration files as follows:: -To test the server:: + "exclude-cases": ["9.*", "12.*", "13.*"] - $ PYTHONPATH=.. python test_server.py - $ wstest -m fuzzingclient +The test server and client applications shouldn't display any exceptions. -To test the client:: +To test the servers: - $ wstest -m fuzzingserver - $ PYTHONPATH=.. python test_client.py +.. code-block:: console -Run the first command in a shell. Run the second command in another shell. -It should take about ten minutes to complete — wstest is the bottleneck. -Then kill the first one with Ctrl-C. + $ PYTHONPATH=src python compliance/asyncio/server.py + $ PYTHONPATH=src python compliance/sync/server.py -The test client or server shouldn't display any exceptions. The results are -stored in reports/clients/index.html. + $ docker run --interactive --tty --rm \ + --volume "${PWD}/compliance/config:/config" \ + --volume "${PWD}/compliance/reports:/reports" \ + --name fuzzingclient \ + crossbario/autobahn-testsuite \ + wstest --mode fuzzingclient --spec /config/fuzzingclient.json -Note that the Autobahn software only supports Python 2, while ``websockets`` -only supports Python 3; you need two different environments. + $ open compliance/reports/servers/index.html -Conformance notes ------------------ +To test the clients: -Some test cases are more strict than the RFC. 
Given the implementation of the -library and the test echo client or server, ``websockets`` gets a "Non-Strict" -in these cases. +.. code-block:: console + $ docker run --interactive --tty --rm \ + --volume "${PWD}/compliance/config:/config" \ + --volume "${PWD}/compliance/reports:/reports" \ + --publish 9001:9001 \ + --name fuzzingserver \ + crossbario/autobahn-testsuite \ + wstest --mode fuzzingserver --spec /config/fuzzingserver.json -In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 ``websockets`` notices the -protocol error and closes the connection before it has had a chance to echo -the previous frame. + $ PYTHONPATH=src python compliance/asyncio/client.py + $ PYTHONPATH=src python compliance/sync/client.py -In 6.4.3 and 6.4.4, even though it uses an incremental decoder, ``websockets`` -doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These -tests are more strict than the RFC. + $ open compliance/reports/clients/index.html +Conformance notes +----------------- + +Some test cases are more strict than the RFC. Given the implementation of the +library and the test client and server applications, websockets passes with a +"Non-Strict" result in these cases. + +In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 websockets notices the +protocol error and closes the connection at the library level before the +application gets a chance to echo the previous frame. + +In 6.4.1, 6.4.2, 6.4.3, and 6.4.4, even though it uses an incremental decoder, +websockets doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. +These tests are more strict than the RFC. + +Test case 7.1.5 fails because websockets treats closing the connection in the +middle of a fragmented message as a protocol error. As a consequence, it sends +a close frame with code 1002. The test suite expects a close frame with code +1000, echoing the close code that it sent. This isn't required. 
RFC 6455 states +that "the endpoint typically echos the status code it received", which leaves +the possibility to send a close frame with a different status code. diff --git a/compliance/asyncio/client.py b/compliance/asyncio/client.py new file mode 100644 index 000000000..044ed6043 --- /dev/null +++ b/compliance/asyncio/client.py @@ -0,0 +1,59 @@ +import asyncio +import json +import logging + +from websockets.asyncio.client import connect +from websockets.exceptions import WebSocketException + + +logging.basicConfig(level=logging.WARNING) + +SERVER = "ws://localhost:9001" + +AGENT = "websockets.asyncio" + + +async def get_case_count(): + async with connect(f"{SERVER}/getCaseCount") as ws: + return json.loads(await ws.recv()) + + +async def run_case(case): + async with connect( + f"{SERVER}/runCase?case={case}&agent={AGENT}", + max_size=2**25, + ) as ws: + try: + async for msg in ws: + await ws.send(msg) + except WebSocketException: + pass + + +async def update_reports(): + async with connect( + f"{SERVER}/updateReports?agent={AGENT}", + open_timeout=60, + ): + pass + + +async def main(): + cases = await get_case_count() + for case in range(1, cases + 1): + print(f"Running test case {case:03d} / {cases}... 
", end="\t") + try: + await run_case(case) + except WebSocketException as exc: + print(f"ERROR: {type(exc).__name__}: {exc}") + except Exception as exc: + print(f"FAIL: {type(exc).__name__}: {exc}") + else: + print("OK") + print(f"Ran {cases} test cases") + await update_reports() + print("Updated reports") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/compliance/asyncio/server.py b/compliance/asyncio/server.py new file mode 100644 index 000000000..84deb9727 --- /dev/null +++ b/compliance/asyncio/server.py @@ -0,0 +1,36 @@ +import asyncio +import logging + +from websockets.asyncio.server import serve +from websockets.exceptions import WebSocketException + + +logging.basicConfig(level=logging.WARNING) + +HOST, PORT = "0.0.0.0", 9002 + + +async def echo(ws): + try: + async for msg in ws: + await ws.send(msg) + except WebSocketException: + pass + + +async def main(): + async with serve( + echo, + HOST, + PORT, + server_header="websockets.sync", + max_size=2**25, + ) as server: + try: + await server.serve_forever() + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/compliance/config/fuzzingclient.json b/compliance/config/fuzzingclient.json new file mode 100644 index 000000000..756ad03b6 --- /dev/null +++ b/compliance/config/fuzzingclient.json @@ -0,0 +1,11 @@ + +{ + "servers": [{ + "url": "ws://host.docker.internal:9002" + }, { + "url": "ws://host.docker.internal:9003" + }], + "outdir": "/reports/servers", + "cases": ["*"], + "exclude-cases": [] +} diff --git a/compliance/config/fuzzingserver.json b/compliance/config/fuzzingserver.json new file mode 100644 index 000000000..384caf0a2 --- /dev/null +++ b/compliance/config/fuzzingserver.json @@ -0,0 +1,7 @@ + +{ + "url": "ws://localhost:9001", + "outdir": "/reports/clients", + "cases": ["*"], + "exclude-cases": [] +} diff --git a/compliance/fuzzingclient.json b/compliance/fuzzingclient.json deleted file mode 100644 index 202ff49a0..000000000 --- 
a/compliance/fuzzingclient.json +++ /dev/null @@ -1,11 +0,0 @@ - -{ - "options": {"failByDrop": false}, - "outdir": "./reports/servers", - - "servers": [{"agent": "websockets", "url": "ws://localhost:8642", "options": {"version": 18}}], - - "cases": ["*"], - "exclude-cases": [], - "exclude-agent-cases": {} -} diff --git a/compliance/fuzzingserver.json b/compliance/fuzzingserver.json deleted file mode 100644 index 1bdb42723..000000000 --- a/compliance/fuzzingserver.json +++ /dev/null @@ -1,12 +0,0 @@ - -{ - "url": "ws://localhost:8642", - - "options": {"failByDrop": false}, - "outdir": "./reports/clients", - "webport": 8080, - - "cases": ["*"], - "exclude-cases": [], - "exclude-agent-cases": {} -} diff --git a/compliance/sync/client.py b/compliance/sync/client.py new file mode 100644 index 000000000..c810e1beb --- /dev/null +++ b/compliance/sync/client.py @@ -0,0 +1,58 @@ +import json +import logging + +from websockets.exceptions import WebSocketException +from websockets.sync.client import connect + + +logging.basicConfig(level=logging.WARNING) + +SERVER = "ws://localhost:9001" + +AGENT = "websockets.sync" + + +def get_case_count(): + with connect(f"{SERVER}/getCaseCount") as ws: + return json.loads(ws.recv()) + + +def run_case(case): + with connect( + f"{SERVER}/runCase?case={case}&agent={AGENT}", + max_size=2**25, + ) as ws: + try: + for msg in ws: + ws.send(msg) + except WebSocketException: + pass + + +def update_reports(): + with connect( + f"{SERVER}/updateReports?agent={AGENT}", + open_timeout=60, + ): + pass + + +def main(): + cases = get_case_count() + for case in range(1, cases + 1): + print(f"Running test case {case:03d} / {cases}... 
", end="\t") + try: + run_case(case) + except WebSocketException as exc: + print(f"ERROR: {type(exc).__name__}: {exc}") + except Exception as exc: + print(f"FAIL: {type(exc).__name__}: {exc}") + else: + print("OK") + print(f"Ran {cases} test cases") + update_reports() + print("Updated reports") + + +if __name__ == "__main__": + main() diff --git a/compliance/sync/server.py b/compliance/sync/server.py new file mode 100644 index 000000000..494f56a44 --- /dev/null +++ b/compliance/sync/server.py @@ -0,0 +1,35 @@ +import logging + +from websockets.exceptions import WebSocketException +from websockets.sync.server import serve + + +logging.basicConfig(level=logging.WARNING) + +HOST, PORT = "0.0.0.0", 9003 + + +def echo(ws): + try: + for msg in ws: + ws.send(msg) + except WebSocketException: + pass + + +def main(): + with serve( + echo, + HOST, + PORT, + server_header="websockets.asyncio", + max_size=2**25, + ) as server: + try: + server.serve_forever() + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + main() diff --git a/compliance/test_client.py b/compliance/test_client.py deleted file mode 100644 index 8e22569fd..000000000 --- a/compliance/test_client.py +++ /dev/null @@ -1,48 +0,0 @@ -import asyncio -import json -import logging -import urllib.parse - -from websockets.asyncio.client import connect - - -logging.basicConfig(level=logging.WARNING) - -# Uncomment this line to make only websockets more verbose. 
-# logging.getLogger('websockets').setLevel(logging.DEBUG) - - -SERVER = "ws://127.0.0.1:8642" -AGENT = "websockets" - - -async def get_case_count(server): - uri = f"{server}/getCaseCount" - async with connect(uri) as ws: - msg = ws.recv() - return json.loads(msg) - - -async def run_case(server, case, agent): - uri = f"{server}/runCase?case={case}&agent={agent}" - async with connect(uri, max_size=2 ** 25, max_queue=1) as ws: - async for msg in ws: - await ws.send(msg) - - -async def update_reports(server, agent): - uri = f"{server}/updateReports?agent={agent}" - async with connect(uri): - pass - - -async def run_tests(server, agent): - cases = await get_case_count(server) - for case in range(1, cases + 1): - print(f"Running test case {case} out of {cases}", end="\r") - await run_case(server, case, agent) - print(f"Ran {cases} test cases ") - await update_reports(server, agent) - - -asyncio.run(run_tests(SERVER, urllib.parse.quote(AGENT))) diff --git a/compliance/test_server.py b/compliance/test_server.py deleted file mode 100644 index 39176e902..000000000 --- a/compliance/test_server.py +++ /dev/null @@ -1,29 +0,0 @@ -import asyncio -import logging - -from websockets.asyncio.server import serve - - -logging.basicConfig(level=logging.WARNING) - -# Uncomment this line to make only websockets more verbose. -# logging.getLogger('websockets').setLevel(logging.DEBUG) - - -HOST, PORT = "127.0.0.1", 8642 - - -async def echo(ws): - async for msg in ws: - await ws.send(msg) - - -async def main(): - with serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): - try: - await asyncio.get_running_loop().create_future() # run forever - except KeyboardInterrupt: - pass - - -asyncio.run(main()) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index 0a7aab6e2..cc9856a8b 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -84,11 +84,12 @@ How do I reconnect when the connection drops? 
Use :func:`~websockets.asyncio.client.connect` as an asynchronous iterator:: from websockets.asyncio.client import connect + from websockets.exceptions import ConnectionClosed async for websocket in connect(...): try: ... - except websockets.ConnectionClosed: + except ConnectionClosed: continue Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 0dc4a3aeb..ba7a95932 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -131,8 +131,8 @@ How do I respond to pings? If you are referring to Ping_ and Pong_ frames defined in the WebSocket protocol, don't bother, because websockets handles them for you. -.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 -.. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 +.. _Ping: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.2 +.. _Pong: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.3 If you are connecting to a server that defines its own heartbeat at the application level, then you need to build that logic into your application. diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 63eb5ffc6..ce7e1962d 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -147,7 +147,7 @@ Then, call :meth:`~ServerConnection.send`:: async def message_user(user_id, message): websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected - await websocket.send(message) # may raise websockets.ConnectionClosed + await websocket.send(message) # may raise websockets.exceptions.ConnectionClosed Add error handling according to the behavior you want if the user disconnected before the message could be sent. diff --git a/docs/howto/django.rst b/docs/howto/django.rst index dada9c5e4..4fe2311cb 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -121,8 +121,7 @@ authentication fails, it closes the connection and exits. 
When we call an API that makes a database query such as ``get_user()``, we wrap the call in :func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't support asynchronous I/O. It would block the event loop if it didn't run in a -separate thread. :func:`~asyncio.to_thread` is available since Python 3.9. In -earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. +separate thread. Finally, we start a server with :func:`~websockets.asyncio.server.serve`. diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst index 3c8a7d72a..c4e9da626 100644 --- a/docs/howto/extensions.rst +++ b/docs/howto/extensions.rst @@ -7,7 +7,7 @@ During the opening handshake, WebSocket clients and servers negotiate which extensions_ will be used with which parameters. Then each frame is processed by extensions before being sent or after being received. -.. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 +.. _extensions: https://datatracker.ietf.org/doc/html/rfc6455.html#section-9 As a consequence, writing an extension requires implementing several classes: diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index d41519ff0..ca530e6a1 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -302,21 +302,24 @@ Tips Serialize operations .................... -The Sans-I/O layer expects to run sequentially. If your interact with it from -multiple threads or coroutines, you must ensure correct serialization. This -should happen automatically in a cooperative multitasking environment. +The Sans-I/O layer is designed to run sequentially. If you interact with it from +multiple threads or coroutines, you must ensure correct serialization. -However, you still have to make sure you don't break this property by -accident. For example, serialize writes to the network -when :meth:`~protocol.Protocol.data_to_send` returns multiple values to -prevent concurrent writes from interleaving incorrectly. 
+Usually, this comes for free in a cooperative multitasking environment. In a +preemptive multitasking environment, it requires mutual exclusion. -Avoid buffers -............. +Furthermore, you must serialize writes to the network. When +:meth:`~protocol.Protocol.data_to_send` returns several values, you must write +them all before starting the next write. -The Sans-I/O layer doesn't do any buffering. It makes events available in +Minimize buffers +................ + +The Sans-I/O layer doesn't perform any buffering. It makes events available in :meth:`~protocol.Protocol.events_received` as soon as they're received. -You should make incoming messages available to the application immediately and -stop further processing until the application fetches them. This will usually -result in the best performance. +You should make incoming messages available to the application immediately. + +A small buffer of incoming messages will usually result in the best performance. +It will reduce context switching between the library and the application while +ensuring that backpressure is propagated. diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index f3e42591e..02d4c6f01 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -3,8 +3,8 @@ Upgrade to the new :mod:`asyncio` implementation .. currentmodule:: websockets -The new :mod:`asyncio` implementation is a rewrite of the original -implementation of websockets. +The new :mod:`asyncio` implementation, which is now the default, is a rewrite of +the original implementation of websockets. It provides a very similar API. However, there are a few differences. @@ -27,15 +27,9 @@ respectively. .. admonition:: What will happen to the original implementation? :class: hint - The original implementation is now considered legacy. - - The next steps are: - - 1. Deprecating it once the new implementation is considered sufficiently - robust. - 2. 
Maintaining it for five years per the :ref:`backwards-compatibility - policy `. - 3. Removing it. This is expected to happen around 2030. + The original implementation is deprecated. It will be maintained for five + years after deprecation according to the :ref:`backwards-compatibility + policy `. Then, by 2030, it will be removed. .. _deprecated APIs: @@ -69,13 +63,14 @@ Import paths For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. -* The original implementation was moved to the ``websockets.legacy`` package. -* The ``websockets`` package provides aliases for convenience. Currently, they - point to the original implementation. They will be updated to point to the new - implementation when it feels mature. +* The original implementation was moved to the ``websockets.legacy`` package + and deprecated. +* The ``websockets`` package provides aliases for convenience. They were + switched to the new implementation in version 14.0 or deprecated when there + wasn't an equivalent API. * The ``websockets.client`` and ``websockets.server`` packages provide aliases - for backwards-compatibility with earlier versions of websockets. They will - be deprecated together with the original implementation. + for backwards-compatibility with earlier versions of websockets. They were + deprecated. To upgrade to the new :mod:`asyncio` implementation, change import paths as shown in the tables below. 
@@ -90,12 +85,12 @@ Client APIs +-------------------------------------------------------------------+-----------------------------------------------------+ | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ -| ``websockets.connect()`` |br| | :func:`websockets.asyncio.client.connect` | -| ``websockets.client.connect()`` |br| | | +| ``websockets.connect()`` *(before 14.0)* |br| | ``websockets.connect()`` *(since 14.0)* |br| | +| ``websockets.client.connect()`` |br| | :func:`websockets.asyncio.client.connect` | | :func:`websockets.legacy.client.connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | -| ``websockets.client.unix_connect()`` |br| | | +| ``websockets.unix_connect()`` *(before 14.0)* |br| | ``websockets.unix_connect()`` *(since 14.0)* |br| | +| ``websockets.client.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | | :func:`websockets.legacy.client.unix_connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | @@ -109,12 +104,12 @@ Server APIs +-------------------------------------------------------------------+-----------------------------------------------------+ | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ -| ``websockets.serve()`` |br| | :func:`websockets.asyncio.server.serve` | -| ``websockets.server.serve()`` |br| | | +| ``websockets.serve()`` *(before 14.0)* |br| | 
``websockets.serve()`` *(since 14.0)* |br| | +| ``websockets.server.serve()`` |br| | :func:`websockets.asyncio.server.serve` | | :func:`websockets.legacy.server.serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | -| ``websockets.server.unix_serve()`` |br| | | +| ``websockets.unix_serve()`` *(before 14.0)* |br| | ``websockets.unix_serve()`` *(since 14.0)* |br| | +| ``websockets.server.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | | :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.Server` | @@ -125,8 +120,8 @@ Server APIs | ``websockets.server.WebSocketServerProtocol`` |br| | | | :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.broadcast()`` |br| | :func:`websockets.asyncio.server.broadcast` | -| :func:`websockets.legacy.server.broadcast()` | | +| ``websockets.broadcast()`` *(before 14.0)* |br| | ``websockets.broadcast()`` *(since 14.0)* |br| | +| :func:`websockets.legacy.server.broadcast()` | :func:`websockets.asyncio.server.broadcast` | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | See below :ref:`how to migrate ` to | | ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | :func:`websockets.asyncio.server.basic_auth`. 
| diff --git a/docs/index.rst b/docs/index.rst index b8cd300e3..de14fa2d0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,30 +28,13 @@ with a focus on correctness, simplicity, robustness, and performance. It supports several network I/O and control flow paradigms. -1. The primary implementation builds upon :mod:`asyncio`, Python's standard - asynchronous I/O framework. It provides an elegant coroutine-based API. It's - ideal for servers that handle many clients concurrently. - - .. admonition:: As of version :ref:`13.0`, there is a new :mod:`asyncio` - implementation. - :class: important - - The historical implementation in ``websockets.legacy`` traces its roots to - early versions of websockets. Although it's stable and robust, it is now - considered legacy. - - The new implementation in ``websockets.asyncio`` is a rewrite on top of - the Sans-I/O implementation. It adds a few features that were impossible - to implement within the original design. - - The new implementation provides all features of the historical - implementation, and a few more. If you're using the historical - implementation, you should :doc:`ugrade to the new implementation - `. It's usually straightforward. +1. The default implementation builds upon :mod:`asyncio`, Python's built-in + asynchronous I/O library. It provides an elegant coroutine-based API. It's + ideal for servers that handle many client connections. 2. The :mod:`threading` implementation is a good alternative for clients, especially if you aren't familiar with :mod:`asyncio`. It may also be used - for servers that don't need to serve many clients. + for servers that handle few client connections. 3. The `Sans-I/O`_ implementation is designed for integrating in third-party libraries, typically application servers, in addition being used internally @@ -59,13 +42,42 @@ It supports several network I/O and control flow paradigms. .. 
_Sans-I/O: https://sans-io.readthedocs.io/ -Here's an echo server using the :mod:`asyncio` API: +Refer to the :doc:`feature support matrices ` for the full +list of features provided by each implementation. + +.. admonition:: The :mod:`asyncio` implementation was rewritten. + :class: tip + + The new implementation in ``websockets.asyncio`` builds upon the Sans-I/O + implementation. It adds features that were impossible to provide in the + original design. It was introduced in version 13.0. + + The historical implementation in ``websockets.legacy`` traces its roots to + early versions of websockets. While it's stable and robust, it was deprecated + in version 14.0 and it will be removed by 2030. + + The new implementation provides the same features as the historical + implementation, and then some. If you're using the historical implementation, + you should :doc:`ugrade to the new implementation `. + +Here's an echo server and corresponding client. + +.. tab:: asyncio + + .. literalinclude:: ../example/asyncio/echo.py + +.. tab:: threading + + .. literalinclude:: ../example/sync/echo.py + +.. tab:: asyncio + :new-set: -.. literalinclude:: ../example/echo.py + .. literalinclude:: ../example/asyncio/hello.py -Here's a client using the :mod:`threading` API: +.. tab:: threading -.. literalinclude:: ../example/hello.py + .. literalinclude:: ../example/sync/hello.py Don't worry about the opening and closing handshakes, pings and pongs, or any other behavior described in the WebSocket specification. websockets takes care diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 095262a20..642e50094 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -6,7 +6,7 @@ Getting started Requirements ------------ -websockets requires Python ≥ 3.8. +websockets requires Python ≥ 3.9. .. 
admonition:: Use the most recent Python release :class: tip diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index 6e91867c8..87074caee 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -271,11 +271,13 @@ spot real errors when you add functionality to the server. Catch it in the .. code-block:: python + from websockets.exceptions import ConnectionClosedOK + async def handler(websocket): while True: try: message = await websocket.recv() - except websockets.ConnectionClosedOK: + except ConnectionClosedOK: break print(message) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7e4bce9c6..ca6769199 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,133 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _14.1: + +14.1 +---- + +*November 13, 2024* + +Improvements +............ + +* Supported ``max_queue=None`` in the :mod:`asyncio` and :mod:`threading` + implementations for consistency with the legacy implementation, even though + this is never a good idea. + +* Added ``close_code`` and ``close_reason`` attributes in the :mod:`asyncio` and + :mod:`threading` implementations for consistency with the legacy + implementation. + +Bug fixes +......... + +* Once the connection is closed, messages previously received and buffered can + be read in the :mod:`asyncio` and :mod:`threading` implementations, just like + in the legacy implementation. + +.. _14.0: + +14.0 +---- + +*November 9, 2024* + +Backwards-incompatible changes +.............................. + +.. admonition:: websockets 14.0 requires Python ≥ 3.9. + :class: tip + + websockets 13.1 is the last version supporting Python 3.8. + +.. admonition:: The new :mod:`asyncio` implementation is now the default. 
+    :class: danger
+
+    The following aliases in the ``websockets`` package were switched to the new
+    :mod:`asyncio` implementation::
+
+        from websockets import connect, unix_connect
+        from websockets import broadcast, serve, unix_serve
+
+    If you're using any of them, then you must follow the :doc:`upgrade guide
+    <../howto/upgrade>` immediately.
+
+    Alternatively, you may stick to the legacy :mod:`asyncio` implementation for
+    now by importing it explicitly::
+
+        from websockets.legacy.client import connect, unix_connect
+        from websockets.legacy.server import broadcast, serve, unix_serve
+
+.. admonition:: The legacy :mod:`asyncio` implementation is now deprecated.
+    :class: caution
+
+    The :doc:`upgrade guide <../howto/upgrade>` provides complete instructions
+    to migrate your application.
+
+    Aliases for deprecated API were removed from ``websockets.__all__``, meaning
+    that they cannot be imported with ``from websockets import *`` anymore.
+
+.. admonition:: Several API raise :exc:`ValueError` instead of :exc:`TypeError`
+    on invalid arguments.
+    :class: note
+
+    :func:`~asyncio.client.connect`, :func:`~asyncio.client.unix_connect`, and
+    :func:`~asyncio.server.basic_auth` in the :mod:`asyncio` implementation as
+    well as :func:`~sync.client.connect`, :func:`~sync.client.unix_connect`,
+    :func:`~sync.server.serve`, :func:`~sync.server.unix_serve`, and
+    :func:`~sync.server.basic_auth` in the :mod:`threading` implementation now
+    raise :exc:`ValueError` when a required argument isn't provided or an
+    argument that is incompatible with others is provided.
+
+.. admonition:: :attr:`Frame.data <frames.Frame.data>` is now a bytes-like object.
+    :class: note
+
+    In addition to :class:`bytes`, it may be a :class:`bytearray` or a
+    :class:`memoryview`. If you wrote an :class:`~extensions.Extension` that
+    relies on methods not provided by these types, you must update your code.
+
+.. admonition:: The signature of :exc:`~exceptions.PayloadTooBig` changed.
+ :class: note + + If you wrote an extension that raises :exc:`~exceptions.PayloadTooBig` in + :meth:`~extensions.Extension.decode`, for example, you must replace + ``PayloadTooBig(f"over size limit ({size} > {max_size} bytes)")`` with + ``PayloadTooBig(size, max_size)``. + +New features +............ + +* Added an option to receive text frames as :class:`bytes`, without decoding, + in the :mod:`threading` implementation; also binary frames as :class:`str`. + +* Added an option to send :class:`bytes` in a text frame in the :mod:`asyncio` + and :mod:`threading` implementations; also :class:`str` in a binary frame. + +Improvements +............ + +* The :mod:`threading` implementation receives messages faster. + +* Sending or receiving large compressed messages is now faster. + +* Errors when a fragmented message is too large are clearer. + +* Log messages at the :data:`~logging.WARNING` and :data:`~logging.INFO` levels + no longer include stack traces. + +Bug fixes +......... + +* Clients no longer crash when the server rejects the opening handshake and the + HTTP response doesn't include a ``Content-Length`` header. + +* Returning an HTTP response in ``process_request`` or ``process_response`` + doesn't generate a log message at the :data:`~logging.ERROR` level anymore. + +* Connections are closed with code 1007 (invalid data) when receiving invalid + UTF-8 in a text frame. + .. _13.1: 13.1 @@ -99,11 +226,6 @@ Bug fixes Backwards-incompatible changes .............................. -.. admonition:: websockets 13.0 requires Python ≥ 3.8. - :class: tip - - websockets 12.0 is the last version supporting Python 3.7. - .. admonition:: Receiving the request path in the second parameter of connection handlers is deprecated. :class: note @@ -153,6 +275,9 @@ New features * Validated compatibility with Python 3.12 and 3.13. +* Added an option to receive text frames as :class:`bytes`, without decoding, + in the :mod:`asyncio` implementation; also binary frames as :class:`str`. 
+ * Added :doc:`environment variables <../reference/variables>` to configure debug logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. @@ -1396,7 +1521,7 @@ New features * Added support for providing and checking Origin_. -.. _Origin: https://www.rfc-editor.org/rfc/rfc6455.html#section-10.2 +.. _Origin: https://datatracker.ietf.org/doc/html/rfc6455.html#section-10.2 .. _2.0: diff --git a/docs/project/contributing.rst b/docs/project/contributing.rst index 020ed7ad8..6ecd175f8 100644 --- a/docs/project/contributing.rst +++ b/docs/project/contributing.rst @@ -17,8 +17,8 @@ apologies. I know I can mess up. I can't expect you to tell me, but if you choose to do so, I'll do my best to handle criticism constructively. -- Aymeric)* -Contributions -------------- +Contributing +------------ Bug reports, patches and suggestions are welcome! @@ -34,33 +34,24 @@ websockets. .. _issue: https://github.com/python-websockets/websockets/issues/new .. _pull request: https://github.com/python-websockets/websockets/compare/ -Questions +Packaging --------- -GitHub issues aren't a good medium for handling questions. There are better -places to ask questions, for example Stack Overflow. - -If you want to ask a question anyway, please make sure that: - -- it's a question about websockets and not about :mod:`asyncio`; -- it isn't answered in the documentation; -- it wasn't asked already. - -A good question can be written as a suggestion to improve the documentation. - -Cryptocurrency users --------------------- - -websockets appears to be quite popular for interfacing with Bitcoin or other -cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. +Some distributions package websockets so that it can be installed with the +system package manager rather than with pip, possibly in a virtualenv. -I'm aware of efforts to build proof-of-stake models. I'll care once the total -energy consumption of all cryptocurrencies drops to a non-bullshit level. 
+If you're packaging websockets for a distribution, you must use `releases +published on PyPI`_ as input. You may check `SLSA attestations on GitHub`_. -You already negated all of humanity's efforts to develop renewable energy. -Please stop heating the planet where my children will have to live. +.. _releases published on PyPI: https://pypi.org/project/websockets/#files +.. _SLSA attestations on GitHub: https://github.com/python-websockets/websockets/attestations -Since websockets is released under an open-source license, you can use it for -any purpose you like. However, I won't spend any of my time to help you. +You mustn't rely on the git repository as input. Specifically, you mustn't +attempt to run the main test suite. It isn't treated as a deliverable of the +project. It doesn't do what you think it does. It's designed for the needs of +developers, not packagers. -I will summarily close issues related to Bitcoin or cryptocurrency in any way. +On a typical build farm for a distribution, tests that exercise timeouts will +fail randomly. Indeed, the test suite is optimized for running very fast, with a +tolerable level of flakiness, on a high-end laptop without noisy neighbors. This +isn't your context. diff --git a/docs/project/index.rst b/docs/project/index.rst index 459146345..56c98196a 100644 --- a/docs/project/index.rst +++ b/docs/project/index.rst @@ -8,5 +8,7 @@ This is about websockets-the-project rather than websockets-the-software. changelog contributing - license + sponsoring For enterprise + support + license diff --git a/docs/project/sponsoring.rst b/docs/project/sponsoring.rst new file mode 100644 index 000000000..77a4fd1d8 --- /dev/null +++ b/docs/project/sponsoring.rst @@ -0,0 +1,11 @@ +Sponsoring +========== + +You may sponsor the development of websockets through: + +* `GitHub Sponsors`_ +* `Open Collective`_ +* :doc:`Tidelift ` + +.. _GitHub Sponsors: https://github.com/sponsors/python-websockets +.. 
_Open Collective: https://opencollective.com/websockets diff --git a/docs/project/support.rst b/docs/project/support.rst new file mode 100644 index 000000000..21aad6e02 --- /dev/null +++ b/docs/project/support.rst @@ -0,0 +1,49 @@ +Getting support +=============== + +.. admonition:: There are no free support channels. + :class: tip + + websockets is an open-source project. It's primarily maintained by one + person as a hobby. + + For this reason, the focus is on flawless code and self-service + documentation, not support. + +Enterprise +---------- + +websockets is maintained with high standards, making it suitable for enterprise +use cases. Additional guarantees are available via :doc:`Tidelift `. +If you're using it in a professional setting, consider subscribing. + +Questions +--------- + +GitHub issues aren't a good medium for handling questions. There are better +places to ask questions, for example Stack Overflow. + +If you want to ask a question anyway, please make sure that: + +- it's a question about websockets and not about :mod:`asyncio`; +- it isn't answered in the documentation; +- it wasn't asked already. + +A good question can be written as a suggestion to improve the documentation. + +Cryptocurrency users +-------------------- + +websockets appears to be quite popular for interfacing with Bitcoin or other +cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. + +I'm aware of efforts to build proof-of-stake models. I'll care once the total +energy consumption of all cryptocurrencies drops to a non-bullshit level. + +You already negated all of humanity's efforts to develop renewable energy. +Please stop heating the planet where my children will have to live. + +Since websockets is released under an open-source license, you can use it for +any purpose you like. However, I won't spend any of my time to help you. + +I will summarily close issues related to cryptocurrency in any way. 
diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index e2b0ff550..ea7b21506 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -57,3 +57,10 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index a58325fb9..325f20450 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -45,3 +45,10 @@ Both sides (new :mod:`asyncio`) .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 2fcaeb414..49bd6f072 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -79,6 +79,13 @@ Using a connection .. autoproperty:: subprotocol + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + Broadcast --------- diff --git a/docs/reference/extensions.rst b/docs/reference/extensions.rst index a70f1b1e5..f3da464a5 100644 --- a/docs/reference/extensions.rst +++ b/docs/reference/extensions.rst @@ -8,7 +8,7 @@ The WebSocket protocol supports extensions_. At the time of writing, there's only one `registered extension`_ with a public specification, WebSocket Per-Message Deflate. -.. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 +.. _extensions: https://datatracker.ietf.org/doc/html/rfc6455.html#section-9 .. 
_registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name Per-Message Deflate diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 8b04034eb..9187fa505 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -43,12 +43,18 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a fragmented message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Receive a fragmented message frame | ✅ | ✅ | ✅ | ❌ | + | Receive a fragmented message frame | ✅ | ✅ | — | ❌ | | by frame | | | | | +------------------------------------+--------+--------+--------+--------+ | Receive a fragmented message after | ✅ | ✅ | — | ✅ | | reassembly | | | | | +------------------------------------+--------+--------+--------+--------+ + | Force sending a message as Text or | ✅ | ✅ | — | ❌ | + | Binary | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Force receiving a message as | ✅ | ✅ | — | ❌ | + | :class:`bytes` or :class:`str` | | | | | + +------------------------------------+--------+--------+--------+--------+ | Send a ping | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Respond to pings automatically | ✅ | ✅ | ✅ | ✅ | @@ -188,3 +194,8 @@ connection to a given IP address in a CONNECTING state. This behavior is mandated by :rfc:`6455`, section 4.1. However, :func:`~asyncio.client.connect()` isn't the right layer for enforcing this constraint. It's the caller's responsibility. + +It is possible to send or receive a text message containing invalid UTF-8 with +``send(not_utf8_bytes, text=True)`` and ``not_utf8_bytes = recv(decode=False)`` +respectively. As a side effect of disabling UTF-8 encoding and decoding, these +options also disable UTF-8 validation. 
diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 77b538b78..c78a3c095 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -13,12 +13,12 @@ Check which implementations support which features and known limitations. features -:mod:`asyncio` (new) --------------------- +:mod:`asyncio` +-------------- It's ideal for servers that handle many clients concurrently. -It's a rewrite of the legacy :mod:`asyncio` implementation. +This is the default implementation. .. toctree:: :titlesonly: @@ -26,17 +26,6 @@ It's a rewrite of the legacy :mod:`asyncio` implementation. asyncio/server asyncio/client -:mod:`asyncio` (legacy) ------------------------ - -This is the historical implementation. - -.. toctree:: - :titlesonly: - - legacy/server - legacy/client - :mod:`threading` ---------------- @@ -62,6 +51,19 @@ application servers. sansio/server sansio/client +:mod:`asyncio` (legacy) +----------------------- + +This is the historical implementation. + +It is deprecated and will be removed. + +.. toctree:: + :titlesonly: + + legacy/server + legacy/client + Extensions ---------- @@ -98,5 +100,5 @@ guarantees of behavior or backwards-compatibility for private APIs. Convenience imports ------------------- -For convenience, many public APIs can be imported directly from the +For convenience, some public APIs can be imported directly from the ``websockets`` package. diff --git a/docs/reference/legacy/client.rst b/docs/reference/legacy/client.rst index fca45d218..a798409f0 100644 --- a/docs/reference/legacy/client.rst +++ b/docs/reference/legacy/client.rst @@ -1,6 +1,12 @@ Client (legacy :mod:`asyncio`) ============================== +.. admonition:: The legacy :mod:`asyncio` implementation is deprecated. + :class: caution + + The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions + to migrate your application. + .. 
automodule:: websockets.legacy.client Opening a connection diff --git a/docs/reference/legacy/common.rst b/docs/reference/legacy/common.rst index aee774479..45c56fccd 100644 --- a/docs/reference/legacy/common.rst +++ b/docs/reference/legacy/common.rst @@ -3,6 +3,12 @@ Both sides (legacy :mod:`asyncio`) ================================== +.. admonition:: The legacy :mod:`asyncio` implementation is deprecated. + :class: caution + + The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions + to migrate your application. + .. automodule:: websockets.legacy.protocol .. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst index b6c383ce7..3c1d19fc6 100644 --- a/docs/reference/legacy/server.rst +++ b/docs/reference/legacy/server.rst @@ -1,6 +1,12 @@ Server (legacy :mod:`asyncio`) ============================== +.. admonition:: The legacy :mod:`asyncio` implementation is deprecated. + :class: caution + + The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions + to migrate your application. + .. automodule:: websockets.legacy.server Starting a server diff --git a/docs/reference/sync/client.rst b/docs/reference/sync/client.rst index af1132412..2aa491f6a 100644 --- a/docs/reference/sync/client.rst +++ b/docs/reference/sync/client.rst @@ -39,6 +39,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: @@ -47,3 +49,10 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. 
autoproperty:: close_reason diff --git a/docs/reference/sync/common.rst b/docs/reference/sync/common.rst index 3dc6d4a50..3c03b25b6 100644 --- a/docs/reference/sync/common.rst +++ b/docs/reference/sync/common.rst @@ -31,6 +31,8 @@ Both sides (:mod:`threading`) .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: @@ -39,3 +41,10 @@ Both sides (:mod:`threading`) .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 80e9c17bb..1d80450f9 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -52,6 +52,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: @@ -61,6 +63,13 @@ Using a connection .. autoproperty:: subprotocol + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + HTTP Basic Authentication ------------------------- diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 86d2e2587..e2de4332e 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -1,13 +1,13 @@ Authentication ============== -The WebSocket protocol was designed for creating web applications that need -bidirectional communication between clients running in browsers and servers. +The WebSocket protocol is designed for creating web applications that require +bidirectional communication between browsers and servers. 
In most practical use cases, WebSocket servers need to authenticate clients in order to route communications appropriately and securely. -:rfc:`6455` stays elusive when it comes to authentication: +:rfc:`6455` remains elusive when it comes to authentication: This protocol doesn't prescribe any particular way that servers can authenticate clients during the WebSocket handshake. The WebSocket @@ -26,8 +26,8 @@ System design Consider a setup where the WebSocket server is separate from the HTTP server. -Most servers built with websockets to complement a web application adopt this -design because websockets doesn't aim at supporting HTTP. +Most servers built with websockets adopt this design because they're a component +in a web application and websockets doesn't aim at supporting HTTP. The following diagram illustrates the authentication flow. @@ -82,8 +82,8 @@ WebSocket server. credentials would be a session identifier or a serialized, signed session. Unfortunately, when the WebSocket server runs on a different domain from - the web application, this idea bumps into the `Same-Origin Policy`_. For - security reasons, setting a cookie on a different origin is impossible. + the web application, this idea hits the wall of the `Same-Origin Policy`_. + For security reasons, setting a cookie on a different origin is impossible. The proper workaround consists in: @@ -108,13 +108,11 @@ WebSocket server. Letting the browser perform HTTP Basic Auth is a nice idea in theory. - In practice it doesn't work due to poor support in browsers. + In practice it doesn't work due to browser support limitations: - As of May 2021: + * Chrome behaves as expected. - * Chrome 90 behaves as expected. - - * Firefox 88 caches credentials too aggressively. + * Firefox caches credentials too aggressively. When connecting again to the same server with new credentials, it reuses the old credentials, which may be expired, resulting in an HTTP 401. Then @@ -123,7 +121,7 @@ WebSocket server. 
When tokens are short-lived or single-use, this bug produces an interesting effect: every other WebSocket connection fails. - * Safari 14 ignores credentials entirely. + * Safari behaves as expected. Two other options are off the table: @@ -142,8 +140,10 @@ Two other options are off the table: While this is suggested by the RFC, installing a TLS certificate is too far from the mainstream experience of browser users. This could make sense in - high security contexts. I hope developers working on such projects don't - take security advice from the documentation of random open source projects. + high security contexts. + + I hope that developers working on projects in this category don't take + security advice from the documentation of random open source projects :-) Let's experiment! ----------------- @@ -185,6 +185,8 @@ connection: .. code-block:: python + from websockets.frames import CloseCode + async def first_message_handler(websocket): token = await websocket.recv() user = get_user(token) @@ -212,24 +214,16 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - from websockets.legacy.server import WebSocketServerProtocol - - class QueryParamProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - token = get_query_parameter(path, "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + async def query_param_auth(connection, request): + token = get_query_param(request.path, "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - self.user = user + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") - async def query_param_handler(websocket): - user = websocket.user - - ... + connection.username = user Cookie ...... 
@@ -260,27 +254,19 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - from websockets.legacy.server import WebSocketServerProtocol - - class CookieProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - # Serve iframe on non-WebSocket requests - ... - - token = get_cookie(headers.get("Cookie", ""), "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + async def cookie_auth(connection, request): + # Serve iframe on non-WebSocket requests + ... - self.user = user + token = get_cookie(request.headers.get("Cookie", ""), "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - async def cookie_handler(websocket): - user = websocket.user + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") - ... + connection.username = user User information ................ @@ -303,24 +289,12 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - from websockets.legacy.auth import BasicAuthWebSocketServerProtocol - - class UserInfoProtocol(BasicAuthWebSocketServerProtocol): - async def check_credentials(self, username, password): - if username != "token": - return False - - user = get_user(password) - if user is None: - return False + from websockets.asyncio.server import basic_auth as websockets_basic_auth - self.user = user - return True + def check_credentials(username, password): + return username == get_user(password) - async def user_info_handler(websocket): - user = websocket.user - - ... + basic_auth = websockets_basic_auth(check_credentials=check_credentials) Machine-to-machine authentication --------------------------------- @@ -334,11 +308,9 @@ To authenticate a websockets client with HTTP Basic Authentication .. 
code-block:: python - from websockets.legacy.client import connect + from websockets.asyncio.client import connect - async with connect( - f"wss://{username}:{password}@example.com" - ) as websocket: + async with connect(f"wss://{username}:{password}@.../") as websocket: ... (You must :func:`~urllib.parse.quote` ``username`` and ``password`` if they @@ -349,10 +321,8 @@ To authenticate a websockets client with HTTP Bearer Authentication .. code-block:: python - from websockets.legacy.client import connect + from websockets.asyncio.client import connect - async with connect( - "wss://example.com", - extra_headers={"Authorization": f"Bearer {token}"} - ) as websocket: + headers = {"Authorization": f"Bearer {token}"} + async with connect("wss://.../", additional_headers=headers) as websocket: ... diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index c9699feb2..66b0819b2 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -83,7 +83,7 @@ to:: Here's a coroutine that broadcasts a message to all clients:: - from websockets import ConnectionClosed + from websockets.exceptions import ConnectionClosed async def broadcast(message): for websocket in CLIENTS.copy(): diff --git a/docs/topics/design.rst b/docs/topics/design.rst index d2fd18d0c..bc14bd332 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -1,3 +1,5 @@ +:orphan: + Design (legacy :mod:`asyncio`) ============================== @@ -171,16 +173,16 @@ differences between a server and a client: - `closing the TCP connection`_: the server closes the connection immediately; the client waits for the server to do it. -.. _client-to-server masking: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.3 -.. _closing the TCP connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.1 +.. _client-to-server masking: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.3 +.. 
_closing the TCP connection: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.1 These differences are so minor that all the logic for `data framing`_, for `sending and receiving data`_ and for `closing the connection`_ is implemented in the same class, :class:`~protocol.WebSocketCommonProtocol`. -.. _data framing: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 -.. _sending and receiving data: https://www.rfc-editor.org/rfc/rfc6455.html#section-6 -.. _closing the connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-7 +.. _data framing: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5 +.. _sending and receiving data: https://datatracker.ietf.org/doc/html/rfc6455.html#section-6 +.. _closing the connection: https://datatracker.ietf.org/doc/html/rfc6455.html#section-7 The :attr:`~protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the diff --git a/docs/topics/index.rst b/docs/topics/index.rst index a2b8ca879..616753c6c 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -12,7 +12,6 @@ Get a deeper understanding of how websockets is built and why. broadcast compression keepalive - design memory security performance diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 91f11fb11..a0467ced2 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -1,6 +1,11 @@ Keepalive and latency ===================== +.. admonition:: This guide applies only to the :mod:`asyncio` implementation. + :class: tip + + The :mod:`threading` implementation doesn't provide keepalive yet. + .. currentmodule:: websockets Long-lived connections @@ -28,17 +33,11 @@ Keepalive in websockets To avoid these problems, websockets runs a keepalive and heartbeat mechanism based on WebSocket Ping_ and Pong_ frames, which are designed for this purpose. -.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 -.. 
_Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 - -It loops through these steps: - -1. Wait 20 seconds. -2. Send a Ping frame. -3. Receive a corresponding Pong frame within 20 seconds. +.. _Ping: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.2 +.. _Pong: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.3 -If the Pong frame isn't received, websockets considers the connection broken and -closes it. +It sends a Ping frame every 20 seconds. It expects a Pong frame in return within +20 seconds. Else, it considers the connection broken and terminates it. This mechanism serves three purposes: @@ -91,6 +90,46 @@ application layer. Read this `blog post `_ for a complete walk-through of this issue. +Application-level keepalive +--------------------------- + +Some servers require clients to send a keepalive message with a specific content +at regular intervals. Usually they expect Text_ frames rather than Ping_ frames, +meaning that you must send them with :attr:`~asyncio.connection.Connection.send` +rather than :attr:`~asyncio.connection.Connection.ping`. + +.. _Text: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.6 + +In websockets, such keepalive mechanisms are considered as application-level +because they rely on data frames. That's unlike the protocol-level keepalive +based on control frames. Therefore, it's your responsibility to implement the +required behavior. + +You can run a task in the background to send keepalive messages: + +.. code-block:: python + + import itertools + import json + + from websockets.exceptions import ConnectionClosed + + async def keepalive(websocket, ping_interval=30): + for ping in itertools.count(): + await asyncio.sleep(ping_interval) + try: + await websocket.send(json.dumps({"ping": ping})) + except ConnectionClosed: + break + + async def main(): + async with connect(...) as websocket: + keepalive_task = asyncio.create_task(keepalive(websocket)) + try: + ... 
# your application logic goes here + finally: + keepalive_task.cancel() + Latency issues -------------- diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index be5678455..fff33a024 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -35,8 +35,8 @@ Instead, when running as a server, websockets logs one event when a `connection is established`_ and another event when a `connection is closed`_. -.. _connection is established: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 -.. _connection is closed: https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.4 +.. _connection is established: https://datatracker.ietf.org/doc/html/rfc6455.html#section-4 +.. _connection is closed: https://datatracker.ietf.org/doc/html/rfc6455.html#section-7.1.4 By default, websockets doesn't log an event for every message. That would be excessive for many applications exchanging small messages at a fast rate. If @@ -218,7 +218,10 @@ Here's what websockets logs at each level. ``ERROR`` ......... 
-* Exceptions raised by connection handler coroutines in servers +* Exceptions raised by your code in servers + * connection handler coroutines + * ``select_subprotocol`` callbacks + * ``process_request`` and ``process_response`` callbacks * Exceptions resulting from bugs in websockets ``WARNING`` @@ -250,4 +253,5 @@ Debug messages have cute prefixes that make logs easier to scan: * ``=`` - set connection state * ``x`` - shut down connection * ``%`` - manage pings and pongs -* ``!`` - handle errors and timeouts +* ``-`` - timeout +* ``!`` - error, with a traceback diff --git a/example/asyncio/client.py b/example/asyncio/client.py new file mode 100644 index 000000000..e3562642d --- /dev/null +++ b/example/asyncio/client.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python + +"""Client example using the asyncio API.""" + +import asyncio + +from websockets.asyncio.client import connect + + +async def hello(): + async with connect("ws://localhost:8765") as websocket: + name = input("What's your name? 
") + + await websocket.send(name) + print(f">>> {name}") + + greeting = await websocket.recv() + print(f"<<< {greeting}") + + +if __name__ == "__main__": + asyncio.run(hello()) diff --git a/example/echo.py b/example/asyncio/echo.py similarity index 51% rename from example/echo.py rename to example/asyncio/echo.py index b952a5cfb..28d877be7 100755 --- a/example/echo.py +++ b/example/asyncio/echo.py @@ -1,14 +1,20 @@ #!/usr/bin/env python +"""Echo server using the asyncio API.""" + import asyncio from websockets.asyncio.server import serve + async def echo(websocket): async for message in websocket: await websocket.send(message) + async def main(): - async with serve(echo, "localhost", 8765): - await asyncio.get_running_loop().create_future() # run forever + async with serve(echo, "localhost", 8765) as server: + await server.serve_forever() + -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/asyncio/hello.py b/example/asyncio/hello.py new file mode 100755 index 000000000..6e4518497 --- /dev/null +++ b/example/asyncio/hello.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +"""Client using the asyncio API.""" + +import asyncio +from websockets.asyncio.client import connect + + +async def hello(): + async with connect("ws://localhost:8765") as websocket: + await websocket.send("Hello world!") + message = await websocket.recv() + print(message) + + +if __name__ == "__main__": + asyncio.run(hello()) diff --git a/example/asyncio/server.py b/example/asyncio/server.py new file mode 100644 index 000000000..574e053bf --- /dev/null +++ b/example/asyncio/server.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +"""Server example using the asyncio API.""" + +import asyncio +from websockets.asyncio.server import serve + + +async def hello(websocket): + name = await websocket.recv() + print(f"<<< {name}") + + greeting = f"Hello {name}!" 
+ + await websocket.send(greeting) + print(f">>> {greeting}") + + +async def main(): + async with serve(hello, "localhost", 8765) as server: + await server.serve_forever() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/faq/health_check_server.py b/example/faq/health_check_server.py index c0fa4327f..30623a4bb 100755 --- a/example/faq/health_check_server.py +++ b/example/faq/health_check_server.py @@ -6,7 +6,7 @@ def health_check(connection, request): if request.path == "/healthz": - return connection.respond(HTTPStatus.OK, b"OK\n") + return connection.respond(HTTPStatus.OK, "OK\n") async def echo(websocket): async for message in websocket: diff --git a/example/ruff.toml b/example/ruff.toml new file mode 100644 index 000000000..13ae36c08 --- /dev/null +++ b/example/ruff.toml @@ -0,0 +1,2 @@ +[lint.isort] +no-sections = true diff --git a/example/sync/client.py b/example/sync/client.py new file mode 100644 index 000000000..c0d633c7b --- /dev/null +++ b/example/sync/client.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +"""Client example using the threading API.""" + +from websockets.sync.client import connect + + +def hello(): + with connect("ws://localhost:8765") as websocket: + name = input("What's your name? 
") + + websocket.send(name) + print(f">>> {name}") + + greeting = websocket.recv() + print(f"<<< {greeting}") + + +if __name__ == "__main__": + hello() diff --git a/example/sync/echo.py b/example/sync/echo.py new file mode 100755 index 000000000..4b47db1ba --- /dev/null +++ b/example/sync/echo.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python + +"""Echo server using the threading API.""" + +from websockets.sync.server import serve + + +def echo(websocket): + for message in websocket: + websocket.send(message) + + +def main(): + with serve(echo, "localhost", 8765) as server: + server.serve_forever() + + +if __name__ == "__main__": + main() diff --git a/example/hello.py b/example/sync/hello.py similarity index 66% rename from example/hello.py rename to example/sync/hello.py index a3ce0699e..bb4cd3ffd 100755 --- a/example/hello.py +++ b/example/sync/hello.py @@ -1,12 +1,16 @@ #!/usr/bin/env python -import asyncio +"""Client using the threading API.""" + from websockets.sync.client import connect + def hello(): with connect("ws://localhost:8765") as websocket: websocket.send("Hello world!") message = websocket.recv() - print(f"Received: {message}") + print(message) + -hello() +if __name__ == "__main__": + hello() diff --git a/example/sync/server.py b/example/sync/server.py new file mode 100644 index 000000000..030049f81 --- /dev/null +++ b/example/sync/server.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python + +"""Server example using the threading API.""" + +from websockets.sync.server import serve + + +def hello(websocket): + name = websocket.recv() + print(f"<<< {name}") + + greeting = f"Hello {name}!" 
+ + websocket.send(greeting) + print(f">>> {greeting}") + + +def main(): + with serve(hello, "localhost", 8765) as server: + server.serve_forever() + + +if __name__ == "__main__": + main() diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py index e3b2cf1f6..0bdd7fd2f 100644 --- a/experiments/authentication/app.py +++ b/experiments/authentication/app.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import asyncio +import email.utils import http import http.cookies import pathlib @@ -8,9 +9,10 @@ import urllib.parse import uuid +from websockets.asyncio.server import basic_auth as websockets_basic_auth, serve +from websockets.datastructures import Headers from websockets.frames import CloseCode -from websockets.legacy.auth import BasicAuthWebSocketServerProtocol -from websockets.legacy.server import WebSocketServerProtocol, serve +from websockets.http11 import Response # User accounts database @@ -49,7 +51,19 @@ def get_query_param(path, key): return values[0] -# Main HTTP server +# WebSocket handler + + +async def handler(websocket): + try: + user = websocket.username + except AttributeError: + return + + await websocket.send(f"Hello {user}!") + message = await websocket.recv() + assert message == f"Goodbye {user}." 
+ CONTENT_TYPES = { ".css": "text/css", @@ -59,9 +73,10 @@ def get_query_param(path, key): } -async def serve_html(path, request_headers): - user = get_query_param(path, "user") - path = urllib.parse.urlparse(path).path +async def serve_html(connection, request): + """Basic HTTP server implemented as a process_request hook.""" + user = get_query_param(request.path, "user") + path = urllib.parse.urlparse(request.path).path if path == "/": if user is None: page = "index.html" @@ -76,147 +91,96 @@ async def serve_html(path, request_headers): pass else: if template.is_file(): - headers = {"Content-Type": CONTENT_TYPES[template.suffix]} body = template.read_bytes() if user is not None: token = create_token(user) body = body.replace(b"TOKEN", token.encode()) - return http.HTTPStatus.OK, headers, body - - return http.HTTPStatus.NOT_FOUND, {}, b"Not found\n" - + headers = Headers( + { + "Date": email.utils.formatdate(usegmt=True), + "Connection": "close", + "Content-Length": str(len(body)), + "Content-Type": CONTENT_TYPES[template.suffix], + } + ) + return Response(200, "OK", headers, body) -async def noop_handler(websocket): - pass - - -# Send credentials as the first message in the WebSocket connection + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not found\n") async def first_message_handler(websocket): + """Handler that sends credentials in the first WebSocket message.""" token = await websocket.recv() user = get_user(token) if user is None: await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." 
- - -# Add credentials to the WebSocket URI in a query parameter - - -class QueryParamProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - token = get_query_param(path, "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" - - self.user = user - - -async def query_param_handler(websocket): - user = websocket.user - - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." - - -# Set a cookie on the domain of the WebSocket URI - - -class CookieProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - if "Upgrade" not in headers: - template = pathlib.Path(__file__).with_name(path[1:]) - headers = {"Content-Type": CONTENT_TYPES[template.suffix]} - body = template.read_bytes() - return http.HTTPStatus.OK, headers, body - - token = get_cookie(headers.get("Cookie", ""), "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" - - self.user = user + websocket.username = user + await handler(websocket) -async def cookie_handler(websocket): - user = websocket.user +async def query_param_auth(connection, request): + """Authenticate user from token in query parameter.""" + token = get_query_param(request.path, "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." 
- - -# Adding credentials to the WebSocket URI in user information - - -class UserInfoProtocol(BasicAuthWebSocketServerProtocol): - async def check_credentials(self, username, password): - if username != "token": - return False - - user = get_user(password) - if user is None: - return False + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") + + connection.username = user + + +async def cookie_auth(connection, request): + """Authenticate user from token in cookie.""" + if "Upgrade" not in request.headers: + template = pathlib.Path(__file__).with_name(request.path[1:]) + body = template.read_bytes() + headers = Headers( + { + "Date": email.utils.formatdate(usegmt=True), + "Connection": "close", + "Content-Length": str(len(body)), + "Content-Type": CONTENT_TYPES[template.suffix], + } + ) + return Response(200, "OK", headers, body) + + token = get_cookie(request.headers.get("Cookie", ""), "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - self.user = user - return True + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") + connection.username = user -async def user_info_handler(websocket): - user = websocket.user - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." +def check_credentials(username, password): + """Authenticate user with HTTP Basic Auth.""" + return username == get_user(password) -# Start all five servers +basic_auth = websockets_basic_auth(check_credentials=check_credentials) async def main(): + """Start one HTTP server and four WebSocket servers.""" # Set the stop condition when receiving SIGINT or SIGTERM. 
loop = asyncio.get_running_loop() stop = loop.create_future() loop.add_signal_handler(signal.SIGINT, stop.set_result, None) loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with serve( - noop_handler, - host="", - port=8000, - process_request=serve_html, - ), serve( - first_message_handler, - host="", - port=8001, - ), serve( - query_param_handler, - host="", - port=8002, - create_protocol=QueryParamProtocol, - ), serve( - cookie_handler, - host="", - port=8003, - create_protocol=CookieProtocol, - ), serve( - user_info_handler, - host="", - port=8004, - create_protocol=UserInfoProtocol, + async with ( + serve(handler, host="", port=8000, process_request=serve_html), + serve(first_message_handler, host="", port=8001), + serve(handler, host="", port=8002, process_request=query_param_auth), + serve(handler, host="", port=8003, process_request=cookie_auth), + serve(handler, host="", port=8004, process_request=basic_auth), ): print("Running on http://localhost:8000/") await stop diff --git a/experiments/authentication/script.js b/experiments/authentication/script.js index ec4e5e670..01dd5b168 100644 --- a/experiments/authentication/script.js +++ b/experiments/authentication/script.js @@ -1,4 +1,5 @@ -var token = window.parent.token; +var token = window.parent.token, + user = window.parent.user; function getExpectedEvents() { return [ @@ -7,7 +8,7 @@ function getExpectedEvents() { }, { type: "message", - data: `Hello ${window.parent.user}!`, + data: `Hello ${user}!`, }, { type: "close", diff --git a/experiments/authentication/test.js b/experiments/authentication/test.js index 428830ff3..e05ca697e 100644 --- a/experiments/authentication/test.js +++ b/experiments/authentication/test.js @@ -1,6 +1,4 @@ -// for connecting to WebSocket servers var token = document.body.dataset.token; -// for test assertions only const params = new URLSearchParams(window.location.search); var user = params.get("user"); diff --git a/experiments/authentication/user_info.js 
b/experiments/authentication/user_info.js index 1dab2ce4c..bc9a3f148 100644 --- a/experiments/authentication/user_info.js +++ b/experiments/authentication/user_info.js @@ -1,5 +1,5 @@ window.addEventListener("DOMContentLoaded", () => { - const uri = `ws://token:${token}@localhost:8004/`; + const uri = `ws://${user}:${token}@localhost:8004/`; const websocket = new WebSocket(uri); websocket.onmessage = ({ data }) => { diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index d5b50bd71..eca55357e 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -6,8 +6,8 @@ import sys import time -from websockets import ConnectionClosed from websockets.asyncio.server import broadcast, serve +from websockets.exceptions import ConnectionClosed CLIENTS = set() diff --git a/experiments/compression/corpus.py b/experiments/compression/corpus.py index da5661dfa..56e262114 100644 --- a/experiments/compression/corpus.py +++ b/experiments/compression/corpus.py @@ -47,6 +47,6 @@ def main(corpus): if __name__ == "__main__": if len(sys.argv) < 2: - print(f"Usage: {sys.argv[0]} [directory]") + print(f"Usage: {sys.argv[0]} ") sys.exit(2) main(pathlib.Path(sys.argv[1])) diff --git a/experiments/profiling/compression.py b/experiments/profiling/compression.py new file mode 100644 index 000000000..1ece1f10e --- /dev/null +++ b/experiments/profiling/compression.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python + +""" +Profile the permessage-deflate extension. 
+ +Usage:: + $ pip install line_profiler + $ python experiments/compression/corpus.py experiments/compression/corpus + $ PYTHONPATH=src python -m kernprof \ + --line-by-line \ + --prof-mod src/websockets/extensions/permessage_deflate.py \ + --view \ + experiments/profiling/compression.py experiments/compression/corpus 12 5 6 + +""" + +import pathlib +import sys + +from websockets.extensions.permessage_deflate import PerMessageDeflate +from websockets.frames import OP_TEXT, Frame + + +def compress_and_decompress(corpus, max_window_bits, memory_level, level): + extension = PerMessageDeflate( + remote_no_context_takeover=False, + local_no_context_takeover=False, + remote_max_window_bits=max_window_bits, + local_max_window_bits=max_window_bits, + compress_settings={"memLevel": memory_level, "level": level}, + ) + for data in corpus: + frame = Frame(OP_TEXT, data) + frame = extension.encode(frame) + frame = extension.decode(frame) + + +if __name__ == "__main__": + if len(sys.argv) < 2 or not pathlib.Path(sys.argv[1]).is_dir(): + print(f"Usage: {sys.argv[0]} [] []") + corpus = [file.read_bytes() for file in pathlib.Path(sys.argv[1]).iterdir()] + max_window_bits = int(sys.argv[2]) if len(sys.argv) > 2 else 12 + memory_level = int(sys.argv[3]) if len(sys.argv) > 3 else 5 + level = int(sys.argv[4]) if len(sys.argv) > 4 else 6 + compress_and_decompress(corpus, max_window_bits, memory_level, level) diff --git a/pyproject.toml b/pyproject.toml index fde9c3226..4e26c757e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "websockets" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -requires-python = ">=3.8" +requires-python = ">=3.9" license = { text = "BSD-3-Clause" } authors = [ { name = "Aymeric Augustin", email = "aymeric.augustin@m4x.org" }, @@ -19,7 +19,6 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: 
Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -61,15 +60,19 @@ source = [ [tool.coverage.report] exclude_lines = [ + "pragma: no cover", "except ImportError:", "if self.debug:", "if sys.platform != \"win32\":", "if typing.TYPE_CHECKING:", - "pragma: no cover", "raise AssertionError", "self.fail\\(\".*\"\\)", "@unittest.skip", ] +partial_branches = [ + "pragma: no branch", + "with self.assertRaises\\(.*\\)", +] [tool.ruff] target-version = "py312" diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 54591e9fd..0c7e9b4c6 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -7,6 +7,14 @@ __all__ = [ + # .asyncio.client + "connect", + "unix_connect", + # .asyncio.server + "basic_auth", + "broadcast", + "serve", + "unix_serve", # .client "ClientProtocol", # .datastructures @@ -35,27 +43,6 @@ "ProtocolError", "SecurityError", "WebSocketException", - "WebSocketProtocolError", - # .legacy.auth - "BasicAuthWebSocketServerProtocol", - "basic_auth_protocol_factory", - # .legacy.client - "WebSocketClientProtocol", - "connect", - "unix_connect", - # .legacy.exceptions - "AbortHandshake", - "InvalidMessage", - "InvalidStatusCode", - "RedirectHandshake", - # .legacy.protocol - "WebSocketCommonProtocol", - # .legacy.server - "WebSocketServer", - "WebSocketServerProtocol", - "broadcast", - "serve", - "unix_serve", # .server "ServerProtocol", # .typing @@ -70,6 +57,8 @@ # When type checking, import non-deprecated aliases eagerly. Else, import on demand. 
if typing.TYPE_CHECKING: + from .asyncio.client import connect, unix_connect + from .asyncio.server import basic_auth, broadcast, serve, unix_serve from .client import ClientProtocol from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( @@ -94,26 +83,6 @@ ProtocolError, SecurityError, WebSocketException, - WebSocketProtocolError, - ) - from .legacy.auth import ( - BasicAuthWebSocketServerProtocol, - basic_auth_protocol_factory, - ) - from .legacy.client import WebSocketClientProtocol, connect, unix_connect - from .legacy.exceptions import ( - AbortHandshake, - InvalidMessage, - InvalidStatusCode, - RedirectHandshake, - ) - from .legacy.protocol import WebSocketCommonProtocol - from .legacy.server import ( - WebSocketServer, - WebSocketServerProtocol, - broadcast, - serve, - unix_serve, ) from .server import ServerProtocol from .typing import ( @@ -129,6 +98,14 @@ lazy_import( globals(), aliases={ + # .asyncio.client + "connect": ".asyncio.client", + "unix_connect": ".asyncio.client", + # .asyncio.server + "basic_auth": ".asyncio.server", + "broadcast": ".asyncio.server", + "serve": ".asyncio.server", + "unix_serve": ".asyncio.server", # .client "ClientProtocol": ".client", # .datastructures @@ -157,27 +134,6 @@ "ProtocolError": ".exceptions", "SecurityError": ".exceptions", "WebSocketException": ".exceptions", - "WebSocketProtocolError": ".exceptions", - # .legacy.auth - "BasicAuthWebSocketServerProtocol": ".legacy.auth", - "basic_auth_protocol_factory": ".legacy.auth", - # .legacy.client - "WebSocketClientProtocol": ".legacy.client", - "connect": ".legacy.client", - "unix_connect": ".legacy.client", - # .legacy.exceptions - "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", - "InvalidStatusCode": ".legacy.exceptions", - "RedirectHandshake": ".legacy.exceptions", - # .legacy.protocol - "WebSocketCommonProtocol": ".legacy.protocol", - # .legacy.server - "WebSocketServer": ".legacy.server", - 
"WebSocketServerProtocol": ".legacy.server", - "broadcast": ".legacy.server", - "serve": ".legacy.server", - "unix_serve": ".legacy.server", # .server "ServerProtocol": ".server", # .typing @@ -195,5 +151,22 @@ "handshake": ".legacy", "parse_uri": ".uri", "WebSocketURI": ".uri", + # deprecated in 14.0 - 2024-11-09 + # .legacy.auth + "BasicAuthWebSocketServerProtocol": ".legacy.auth", + "basic_auth_protocol_factory": ".legacy.auth", + # .legacy.client + "WebSocketClientProtocol": ".legacy.client", + # .legacy.exceptions + "AbortHandshake": ".legacy.exceptions", + "InvalidMessage": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", + # .legacy.protocol + "WebSocketCommonProtocol": ".legacy.protocol", + # .legacy.server + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", }, ) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index b1beb3e00..cdd9bfac6 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -3,9 +3,11 @@ import asyncio import logging import os +import traceback import urllib.parse +from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import Any, AsyncIterator, Callable, Generator, Sequence +from typing import Any, Callable from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike @@ -44,7 +46,7 @@ class ClientConnection(Connection): closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, - and ``write_limit`` arguments the same meaning as in :func:`connect`. + and ``write_limit`` arguments have the same meaning as in :func:`connect`. Args: protocol: Sans-I/O connection. 
@@ -58,7 +60,7 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ClientProtocol @@ -94,9 +96,9 @@ async def handshake( return_when=asyncio.FIRST_COMPLETED, ) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a response, when the response cannot be parsed, or - # when the response fails the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -180,7 +182,7 @@ class connect: async for websocket in connect(...): try: ... - except websockets.ConnectionClosed: + except websockets.exceptions.ConnectionClosed: continue If the connection fails with a transient error, it is retried with @@ -220,7 +222,8 @@ class connect: max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. write_limit: High-water mark of write buffer in bytes. It is passed to :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults to 32 KiB. 
You may pass a ``(high, low)`` tuple to set the @@ -281,7 +284,7 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, @@ -351,10 +354,10 @@ def factory() -> ClientConnection: kwargs.setdefault("ssl", True) kwargs.setdefault("server_hostname", wsuri.host) if kwargs.get("ssl") is None: - raise TypeError("ssl=None is incompatible with a wss:// URI") + raise ValueError("ssl=None is incompatible with a wss:// URI") else: if kwargs.get("ssl") is not None: - raise TypeError("ssl argument is incompatible with a ws:// URI") + raise ValueError("ssl argument is incompatible with a ws:// URI") if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) @@ -492,7 +495,7 @@ async def __aexit__( # async for ... in connect(...): async def __aiter__(self) -> AsyncIterator[ClientConnection]: - delays: Generator[float, None, None] | None = None + delays: Generator[float] | None = None while True: try: async with self as protocol: @@ -520,9 +523,10 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: delays = backoff() delay = next(delays) self.logger.info( - "! connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds: %s", delay, - exc_info=True, + # Remove first argument when dropping Python 3.9. 
+ traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(delay) continue diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 6af61a4a9..e5c350fe2 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -7,17 +7,11 @@ import random import struct import sys +import traceback import uuid +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping from types import TracebackType -from typing import ( - Any, - AsyncIterable, - AsyncIterator, - Awaitable, - Iterable, - Mapping, - cast, -) +from typing import Any, cast from ..exceptions import ( ConcurrencyError, @@ -62,14 +56,14 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol = protocol self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout - if isinstance(max_queue, int): + if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) self.max_queue = max_queue if isinstance(write_limit, int): @@ -191,6 +185,30 @@ def subprotocol(self) -> Subprotocol | None: """ return self.protocol.subprotocol + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. 
Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + # Public methods async def __aenter__(self) -> Connection: @@ -258,12 +276,13 @@ async def recv(self, decode: bool | None = None) -> Data: You may override this behavior with the ``decode`` argument: - * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames - and return a bytestring (:class:`bytes`). This may be useful to - optimize performance when decoding isn't needed. + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames - and return a string (:class:`str`). This is useful for servers - that send binary frames instead of text frames. + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. Raises: ConnectionClosed: When the connection is closed. @@ -274,14 +293,24 @@ async def recv(self, decode: bool | None = None) -> Data: try: return await self.recv_messages.get(decode) except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - await asyncio.shield(self.connection_lost_waiter) - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv while another coroutine " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. 
+ await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from self.recv_exc async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: """ @@ -330,17 +359,32 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data try: async for frame in self.recv_messages.get_iter(decode): yield frame + return except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - await asyncio.shield(self.connection_lost_waiter) - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv_streaming while another coroutine " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from self.recv_exc - async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None: + async def send( + self, + message: Data | Iterable[Data] | AsyncIterable[Data], + text: bool | None = None, + ) -> None: """ Send a message. @@ -351,6 +395,17 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send a bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. 
+ * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. @@ -401,11 +456,17 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No if isinstance(message, str): async with self.send_context(): - self.protocol.send_text(message.encode()) + if text is False: + self.protocol.send_binary(message.encode()) + else: + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): async with self.send_context(): - self.protocol.send_binary(message) + if text is True: + self.protocol.send_text(message) + else: + self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -426,36 +487,30 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No try: # First fragment. 
if isinstance(chunk, str): - text = True async with self.send_context(): - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + self.protocol.send_binary(chunk.encode(), fin=False) + else: + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False async with self.send_context(): - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + self.protocol.send_text(chunk, fin=False) + else: + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("iterable must contain bytes or str") # Other fragments for chunk in chunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: async with self.send_context(): - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: async with self.send_context(): - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("iterable must contain uniform types") @@ -467,7 +522,10 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. async with self.send_context(): - self.protocol.fail(1011, "error in fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) raise finally: @@ -488,36 +546,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No try: # First fragment. 
if isinstance(chunk, str): - text = True - async with self.send_context(): - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False - async with self.send_context(): - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("async iterable must contain bytes or str") # Other fragments async for chunk in achunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: async with self.send_context(): - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: async with self.send_context(): - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("async iterable must contain uniform types") @@ -529,7 +583,10 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. async with self.send_context(): - self.protocol.fail(1011, "error in fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) raise finally: @@ -559,7 +616,10 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # to terminate after calling a method that sends a close frame. 
async with self.send_context(): if self.fragmented_send_waiter is not None: - self.protocol.fail(1011, "close during fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) else: self.protocol.send_close(code, reason) except ConnectionClosed: @@ -750,7 +810,7 @@ async def keepalive(self) -> None: self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: - self.logger.debug("! timed out waiting for keepalive pong") + self.logger.debug("- timed out waiting for keepalive pong") async with self.send_context(): self.protocol.fail( CloseCode.INTERNAL_ERROR, @@ -831,7 +891,7 @@ async def send_context( await self.drain() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug("! error while sending data", exc_info=True) # While the only expected exception here is OSError, # other exceptions would be treated identically. wait_for_close = False @@ -1007,7 +1067,7 @@ def data_received(self, data: bytes) -> None: self.send_data() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug("! error while sending data", exc_info=True) self.set_recv_exc(exc) if self.protocol.close_expected(): @@ -1025,13 +1085,22 @@ def eof_received(self) -> None: # Feed the end of the data stream to the connection. self.protocol.receive_eof() - # This isn't expected to generate events. - assert not self.protocol.events_received() + # This isn't expected to raise an exception. + events = self.protocol.events_received() # There is no error handling because send_data() can only write # the end of the data stream here and it shouldn't raise errors. self.send_data() + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. 
Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + # The WebSocket protocol has its own closing handshake: endpoints close # the TCP or TLS connection after sending and receiving a close frame. # As a consequence, they never need to write after receiving EOF, so @@ -1136,8 +1205,12 @@ def broadcast( exceptions.append(exception) else: connection.logger.warning( - "skipped broadcast: failed to write message", - exc_info=True, + "skipped broadcast: failed to write message: %s", + traceback.format_exception_only( + # Remove first argument when dropping Python 3.9. + type(write_exception), + write_exception, + )[0].strip(), ) if raise_exceptions and exceptions: diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index c2b4afd67..e6d1d31cc 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -3,14 +3,8 @@ import asyncio import codecs import collections -from typing import ( - Any, - AsyncIterator, - Callable, - Generic, - Iterable, - TypeVar, -) +from collections.abc import AsyncIterator, Iterable +from typing import Any, Callable, Generic, TypeVar from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -46,11 +40,12 @@ def put(self, item: T) -> None: if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_result(None) - async def get(self) -> T: + async def get(self, block: bool = True) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: - if self.get_waiter is not None: - raise ConcurrencyError("get is already running") + if not block: + raise EOFError("stream of frames ended") + assert self.get_waiter is None, "cannot call get() concurrently" self.get_waiter = self.loop.create_future() try: await self.get_waiter @@ 
-66,10 +61,9 @@ def reset(self, items: Iterable[T]) -> None: self.queue.extend(items) def abort(self) -> None: + """Close the queue, raising EOFError in get() if necessary.""" if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_exception(EOFError("stream of frames ended")) - # Clear the queue to avoid storing unnecessary data in memory. - self.queue.clear() class Assembler: @@ -90,24 +84,27 @@ class Assembler: # coverage reports incorrectly: "line NN didn't jump to the function exit" def __init__( # pragma: no cover self, - high: int = 16, + high: int | None = None, low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, ) -> None: - # Queue of incoming messages. Each item is a queue of frames. + # Queue of incoming frames. self.frames: SimpleQueue[Frame] = SimpleQueue() # We cannot put a hard limit on the size of the queue because a single # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. - if low is None: + if high is not None and low is None: low = high // 4 - if low < 0: - raise ValueError("low must be positive or equal to zero") - if high < low: - raise ValueError("high must be greater than or equal to low") + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") self.high, self.low = high, low self.pause = pause self.resume = resume @@ -136,46 +133,42 @@ async def get(self, decode: bool | None = None) -> Data: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. 
""" - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # First frame - try: - frame = await self.frames.get() - except asyncio.CancelledError: - self.get_in_progress = False - raise - self.maybe_resume() - assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY - if decode is None: - decode = frame.opcode is OP_TEXT - frames = [frame] + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or is cancelled. - # Following frames, for fragmented messages - while not frame.fin: - try: - frame = await self.frames.get() - except asyncio.CancelledError: - # Put frames already received back into the queue - # so that future calls to get() can return them. - self.frames.reset(frames) - self.get_in_progress = False - raise + try: + # First frame + frame = await self.frames.get(not self.closed) self.maybe_resume() - assert frame.opcode is OP_CONT - frames.append(frame) - - self.get_in_progress = False + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = await self.frames.get(not self.closed) + except asyncio.CancelledError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.frames.reset(frames) + raise + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + finally: + self.get_in_progress = False data = b"".join(frame.data for frame in frames) if decode: @@ -203,22 +196,24 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. 
ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. """ - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or is cancelled. + + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. + # First frame try: - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) except asyncio.CancelledError: self.get_in_progress = False raise @@ -238,7 +233,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # previous fragments — we're streaming them. Canceling get_iter() # here will leave the assembler in a stuck state. Future calls to # get() or get_iter() will raise ConcurrencyError. 
- frame = await self.frames.get() + frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_CONT if decode: @@ -264,6 +259,10 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled + if self.high is None: + return + # Check for "> high" to support high = 0 if len(self.frames) > self.high and not self.paused: self.paused = True @@ -271,6 +270,10 @@ def maybe_pause(self) -> None: def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled + if self.low is None: + return + # Check for "<= low" to support low = 0 if len(self.frames) <= self.low and self.paused: self.paused = False diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 19dae44b7..fdb928004 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -6,17 +6,9 @@ import logging import socket import sys +from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType -from typing import ( - Any, - Awaitable, - Callable, - Generator, - Iterable, - Sequence, - Tuple, - cast, -) +from typing import Any, Callable, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -63,7 +55,7 @@ class ServerConnection(Connection): closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, - and ``write_limit`` arguments the same meaning as in :func:`serve`. + and ``write_limit`` arguments have the same meaning as in :func:`serve`. Args: protocol: Sans-I/O connection. 
@@ -79,7 +71,7 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ServerProtocol @@ -200,10 +192,13 @@ async def handshake( self.protocol.send_response(self.response) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, when - # the handshake encounters an error, or when process_request or - # process_response sends an HTTP response that rejects the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -368,7 +363,11 @@ async def conn_handler(self, connection: ServerConnection) -> None: connection.close_transport() return - assert connection.protocol.state is OPEN + if connection.protocol.state is not OPEN: + # process_request or process_response rejected the handshake. + connection.close_transport() + return + try: connection.start_keepalive() await self.handler(connection) @@ -644,7 +643,8 @@ def handler(websocket): max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. write_limit: High-water mark of write buffer in bytes. 
It is passed to :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults to 32 KiB. You may pass a ``(high, low)`` tuple to set the @@ -714,7 +714,7 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, @@ -898,16 +898,18 @@ def basic_auth( whether they're valid. Raises: TypeError: If ``credentials`` or ``check_credentials`` is wrong. + ValueError: If ``credentials`` and ``check_credentials`` are both + provided or both not provided. """ if (credentials is None) == (check_credentials is None): - raise TypeError("provide either credentials or check_credentials") + raise ValueError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): - credentials_list = [cast(Tuple[str, str], credentials)] + credentials_list = [cast(tuple[str, str], credentials)] elif isinstance(credentials, Iterable): - credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + credentials_list = list(cast(Iterable[tuple[str, str]], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: diff --git a/src/websockets/auth.py b/src/websockets/auth.py index b792e02f5..15b70a372 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -1,6 +1,18 @@ from __future__ import annotations -# See #940 for why lazy_import isn't used here for backwards compatibility. -# See #1400 for why listing compatibility imports in __all__ helps PyCharm. -from .legacy.auth import * -from .legacy.auth import __all__ # noqa: F401 +import warnings + + +with warnings.catch_warnings(): + # Suppress redundant DeprecationWarning raised by websockets.legacy. 
+ warnings.filterwarnings("ignore", category=DeprecationWarning) + from .legacy.auth import * + from .legacy.auth import __all__ # noqa: F401 + + +warnings.warn( # deprecated in 14.0 - 2024-11-09 + "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", + DeprecationWarning, +) diff --git a/src/websockets/client.py b/src/websockets/client.py index e5f294986..f6cbc9f65 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -3,7 +3,8 @@ import os import random import warnings -from typing import Any, Generator, Sequence +from collections.abc import Generator, Sequence +from typing import Any from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -26,6 +27,7 @@ parse_upgrade, ) from .http11 import Request, Response +from .imports import lazy_import from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State from .typing import ( ConnectionOption, @@ -39,13 +41,7 @@ from .utils import accept_key, generate_key -# See #940 for why lazy_import isn't used here for backwards compatibility. -# See #1400 for why listing compatibility imports in __all__ helps PyCharm. 
-from .legacy.client import * # isort:skip # noqa: I001 -from .legacy.client import __all__ as legacy__all__ - - -__all__ = ["ClientProtocol"] + legacy__all__ +__all__ = ["ClientProtocol"] class ClientProtocol(Protocol): @@ -313,7 +309,7 @@ def send_request(self, request: Request) -> None: self.writes.append(request.serialize()) - def parse(self) -> Generator[None, None, None]: + def parse(self) -> Generator[None]: if self.state is CONNECTING: try: response = yield from Response.parse( @@ -374,7 +370,7 @@ def backoff( min_delay: float = BACKOFF_MIN_DELAY, max_delay: float = BACKOFF_MAX_DELAY, factor: float = BACKOFF_FACTOR, -) -> Generator[float, None, None]: +) -> Generator[float]: """ Generate a series of backoff delays between reconnection attempts. @@ -391,3 +387,14 @@ def backoff( delay *= factor while True: yield max_delay + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 - 2024-11-09 + "WebSocketClientProtocol": ".legacy.client", + "connect": ".legacy.client", + "unix_connect": ".legacy.client", + }, +) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 106d6f393..77b6f86fa 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -1,15 +1,7 @@ from __future__ import annotations -from typing import ( - Any, - Iterable, - Iterator, - Mapping, - MutableMapping, - Protocol, - Tuple, - Union, -) +from collections.abc import Iterable, Iterator, Mapping, MutableMapping +from typing import Any, Protocol, Union __all__ = ["Headers", "HeadersLike", "MultipleValuesError"] @@ -179,8 +171,7 @@ def __getitem__(self, key: str) -> str: ... HeadersLike = Union[ Headers, Mapping[str, str], - # Change to tuple[str, str] when dropping Python < 3.9. 
- Iterable[Tuple[str, str]], + Iterable[tuple[str, str]], SupportsKeysAndGetItem, ] """ diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index d723f2fec..f3e751971 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -31,7 +31,6 @@ from __future__ import annotations -import typing import warnings from .imports import lazy_import @@ -45,9 +44,7 @@ "InvalidURI", "InvalidHandshake", "SecurityError", - "InvalidMessage", "InvalidStatus", - "InvalidStatusCode", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", @@ -57,10 +54,7 @@ "DuplicateParameter", "InvalidParameterName", "InvalidParameterValue", - "AbortHandshake", - "RedirectHandshake", "ProtocolError", - "WebSocketProtocolError", "PayloadTooBig", "InvalidState", "ConcurrencyError", @@ -119,7 +113,7 @@ def __str__(self) -> str: @property def code(self) -> int: - warnings.warn( # deprecated in 13.1 + warnings.warn( # deprecated in 13.1 - 2024-09-21 "ConnectionClosed.code is deprecated; " "use Protocol.close_code or ConnectionClosed.rcvd.code", DeprecationWarning, @@ -130,7 +124,7 @@ def code(self) -> int: @property def reason(self) -> str: - warnings.warn( # deprecated in 13.1 + warnings.warn( # deprecated in 13.1 - 2024-09-21 "ConnectionClosed.reason is deprecated; " "use Protocol.close_reason or ConnectionClosed.rcvd.reason", DeprecationWarning, @@ -340,6 +334,47 @@ class PayloadTooBig(WebSocketException): """ + def __init__( + self, + size_or_message: int | None | str, + max_size: int | None = None, + cur_size: int | None = None, + ) -> None: + if isinstance(size_or_message, str): + assert max_size is None + assert cur_size is None + warnings.warn( # deprecated in 14.0 - 2024-11-09 + "PayloadTooBig(message) is deprecated; " + "change to PayloadTooBig(size, max_size)", + DeprecationWarning, + ) + self.message: str | None = size_or_message + else: + self.message = None + self.size: int | None = size_or_message + assert max_size is not None + 
self.max_size: int = max_size + self.cur_size: int | None = None + self.set_current_size(cur_size) + + def __str__(self) -> str: + if self.message is not None: + return self.message + else: + message = "frame " + if self.size is not None: + message += f"with {self.size} bytes " + if self.cur_size is not None: + message += f"after reading {self.cur_size} bytes " + message += f"exceeds limit of {self.max_size} bytes" + return message + + def set_current_size(self, cur_size: int | None) -> None: + assert self.cur_size is None + if cur_size is not None: + self.max_size += cur_size + self.cur_size = cur_size + class InvalidState(WebSocketException, AssertionError): """ @@ -366,27 +401,18 @@ class ConcurrencyError(WebSocketException, RuntimeError): """ -# When type checking, import non-deprecated aliases eagerly. Else, import on demand. -if typing.TYPE_CHECKING: - from .legacy.exceptions import ( - AbortHandshake, - InvalidMessage, - InvalidStatusCode, - RedirectHandshake, - ) - - WebSocketProtocolError = ProtocolError -else: - lazy_import( - globals(), - aliases={ - "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", - "InvalidStatusCode": ".legacy.exceptions", - "RedirectHandshake": ".legacy.exceptions", - "WebSocketProtocolError": ".legacy.exceptions", - }, - ) - # At the bottom to break import cycles created by type annotations. from . 
import frames, http11 # noqa: E402 + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 - 2024-11-09 + "AbortHandshake": ".legacy.exceptions", + "InvalidMessage": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", + }, +) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 75bae6b77..42dd6c5fa 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from ..frames import Frame from ..typing import ExtensionName, ExtensionParameter diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 25d2c1c45..cefad4f56 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -1,8 +1,8 @@ from __future__ import annotations -import dataclasses import zlib -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from .. import frames from ..exceptions import ( @@ -119,7 +119,6 @@ def decode( else: if not frame.rsv1: return frame - frame = dataclasses.replace(frame, rsv1=False) if not frame.fin: self.decode_cont_data = True @@ -130,22 +129,37 @@ def decode( # Uncompress data. Protect against zip bombs by preventing zlib from # decompressing more than max_length bytes (except when the limit is # disabled with max_size = None). - data = frame.data - if frame.fin: - data += _EMPTY_UNCOMPRESSED_BLOCK + if frame.fin and len(frame.data) < 2044: + # Profiling shows that appending four bytes, which makes a copy, is + # faster than calling decompress() again when data is less than 2kB. 
+ data = bytes(frame.data) + _EMPTY_UNCOMPRESSED_BLOCK + else: + data = frame.data max_length = 0 if max_size is None else max_size try: data = self.decoder.decompress(data, max_length) + if self.decoder.unconsumed_tail: + assert max_size is not None # help mypy + raise PayloadTooBig(None, max_size) + if frame.fin and len(frame.data) >= 2044: + # This cannot generate additional data. + self.decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK) except zlib.error as exc: raise ProtocolError("decompression failed") from exc - if self.decoder.unconsumed_tail: - raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: del self.decoder - return dataclasses.replace(frame, data=data) + return frames.Frame( + frame.opcode, + data, + frame.fin, + # Unset the rsv1 flag on the first frame of a compressed message. + False, + frame.rsv2, + frame.rsv3, + ) def encode(self, frame: frames.Frame) -> frames.Frame: """ @@ -160,8 +174,6 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # data" flag similar to "decode continuation data" at this time. if frame.opcode is not frames.OP_CONT: - # Set the rsv1 flag on the first frame of a compressed message. - frame = dataclasses.replace(frame, rsv1=True) # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( @@ -171,14 +183,29 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # Compress data. data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) - if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK): - data = data[:-4] + if frame.fin: + # Sync flush generates between 5 or 6 bytes, ending with the bytes + # 0x00 0x00 0xff 0xff, which must be removed. + assert data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK + # Making a copy is faster than memoryview(a)[:-4] until 2kB. 
+ if len(data) < 2048: + data = data[:-4] + else: + data = memoryview(data)[:-4] # Allow garbage collection of the encoder if it won't be reused. if frame.fin and self.local_no_context_takeover: del self.encoder - return dataclasses.replace(frame, data=data) + return frames.Frame( + frame.opcode, + data, + frame.fin, + # Set the rsv1 flag on the first frame of a compressed message. + frame.opcode is not frames.OP_CONT, + frame.rsv2, + frame.rsv3, + ) def _build_parameters( diff --git a/src/websockets/frames.py b/src/websockets/frames.py index a63bdc3b6..7898c8a5d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -6,7 +6,8 @@ import os import secrets import struct -from typing import Callable, Generator, Sequence +from collections.abc import Generator, Sequence +from typing import Callable, Union from .exceptions import PayloadTooBig, ProtocolError @@ -138,7 +139,7 @@ class Frame: """ opcode: Opcode - data: bytes + data: Union[bytes, bytearray, memoryview] fin: bool = True rsv1: bool = False rsv2: bool = False @@ -159,7 +160,7 @@ def __str__(self) -> str: if self.opcode is OP_TEXT: # Decoding only the beginning and the end is needlessly hard. # Decode the entire payload then elide later if necessary. - data = repr(self.data.decode()) + data = repr(bytes(self.data).decode()) elif self.opcode is OP_BINARY: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. @@ -177,7 +178,7 @@ def __str__(self) -> str: # binary. If self.data is a memoryview, it has no decode() method, # which raises AttributeError. try: - data = repr(self.data.decode()) + data = repr(bytes(self.data).decode()) coding = "text" except (UnicodeDecodeError, AttributeError): binary = self.data @@ -221,7 +222,6 @@ def parse( Raises: EOFError: If the connection is closed without a full WebSocket frame. - UnicodeDecodeError: If the frame contains invalid UTF-8. 
PayloadTooBig: If the frame's payload size exceeds ``max_size``. ProtocolError: If the frame contains incorrect values. @@ -252,7 +252,7 @@ def parse( data = yield from read_exact(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") + raise PayloadTooBig(length, max_size) if mask: mask_bytes = yield from read_exact(4) diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 9103018a0..e05948a1f 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,7 +4,8 @@ import binascii import ipaddress import re -from typing import Callable, Sequence, TypeVar, cast +from collections.abc import Sequence +from typing import Callable, TypeVar, cast from .exceptions import InvalidHeaderFormat, InvalidHeaderValue from .typing import ( diff --git a/src/websockets/http.py b/src/websockets/http.py index 0ff5598c7..0d860e537 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -3,7 +3,12 @@ import warnings from .datastructures import Headers, MultipleValuesError # noqa: F401 -from .legacy.http import read_request, read_response # noqa: F401 + + +with warnings.catch_warnings(): + # Suppress redundant DeprecationWarning raised by websockets.legacy. 
+ warnings.filterwarnings("ignore", category=DeprecationWarning) + from .legacy.http import read_request, read_response # noqa: F401 warnings.warn( # deprecated in 9.0 - 2021-09-01 diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 47cef7a9b..af542c77b 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -5,7 +5,8 @@ import re import sys import warnings -from typing import Callable, Generator +from collections.abc import Generator +from typing import Callable from .datastructures import Headers from .exceptions import SecurityError diff --git a/src/websockets/imports.py b/src/websockets/imports.py index bb80e4eac..c63fb212e 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,7 +1,8 @@ from __future__ import annotations import warnings -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any __all__ = ["lazy_import"] diff --git a/src/websockets/legacy/__init__.py b/src/websockets/legacy/__init__.py index e69de29bb..ad9aa2506 100644 --- a/src/websockets/legacy/__init__.py +++ b/src/websockets/legacy/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import warnings + + +warnings.warn( # deprecated in 14.0 - 2024-11-09 + "websockets.legacy is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", + DeprecationWarning, +) diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 4d030e5e2..a262fcd79 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -3,7 +3,8 @@ import functools import hmac import http -from typing import Any, Awaitable, Callable, Iterable, Tuple, cast +from collections.abc import Awaitable, Iterable +from typing import Any, Callable, cast from ..datastructures import Headers from ..exceptions import InvalidHeader @@ -13,8 +14,7 @@ __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] -# Change to 
tuple[str, str] when dropping Python < 3.9. -Credentials = Tuple[str, str] +Credentials = tuple[str, str] def is_credentials(value: Any) -> bool: diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index ec4c2ff64..a3856b470 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -5,17 +5,12 @@ import logging import os import random +import traceback import urllib.parse import warnings +from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import ( - Any, - AsyncIterator, - Callable, - Generator, - Sequence, - cast, -) +from typing import Any, Callable, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike @@ -354,7 +349,7 @@ class Connect: async for websocket in connect(...): try: ... - except websockets.ConnectionClosed: + except websockets.exceptions.ConnectionClosed: continue The connection is closed automatically after each iteration of the loop. @@ -603,22 +598,24 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: try: async with self as protocol: yield protocol - except Exception: + except Exception as exc: # Add a random initial delay between 0 and 5 seconds. # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. if backoff_delay == self.BACKOFF_MIN: initial_delay = random.random() * self.BACKOFF_INITIAL self.logger.info( - "! connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds: %s", initial_delay, - exc_info=True, + # Remove first argument when dropping Python 3.9. + traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(initial_delay) else: self.logger.info( - "! connect failed again; retrying in %d seconds", + "connect failed again; retrying in %d seconds: %s", int(backoff_delay), - exc_info=True, + # Remove first argument when dropping Python 3.9. 
+ traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(int(backoff_delay)) # Increase delay with truncated exponential backoff. diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index 9ca9b7aff..e2279c825 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -50,7 +50,7 @@ def __init__( headers: datastructures.HeadersLike, body: bytes = b"", ) -> None: - # If a user passes an int instead of a HTTPStatus, fix it automatically. + # If a user passes an int instead of an HTTPStatus, fix it automatically. self.status = http.HTTPStatus(status) self.headers = datastructures.Headers(headers) self.body = body diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 4c2f8c23f..add0c6e0e 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,7 +1,8 @@ from __future__ import annotations import struct -from typing import Any, Awaitable, Callable, NamedTuple, Sequence +from collections.abc import Awaitable, Sequence +from typing import Any, Callable, NamedTuple from .. 
import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError @@ -92,7 +93,7 @@ async def read( data = await reader(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") + raise PayloadTooBig(length, max_size) if mask: mask_bits = await reader(4) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 998e390d4..db126c01e 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -9,19 +9,11 @@ import struct import sys import time +import traceback import uuid import warnings -from typing import ( - Any, - AsyncIterable, - AsyncIterator, - Awaitable, - Callable, - Deque, - Iterable, - Mapping, - cast, -) +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping +from typing import Any, Callable, Deque, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers @@ -1255,7 +1247,7 @@ async def keepalive_ping(self) -> None: self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: - self.logger.debug("! timed out waiting for keepalive pong") + self.logger.debug("- timed out waiting for keepalive pong") self.fail_connection( CloseCode.INTERNAL_ERROR, "keepalive ping timeout", @@ -1297,7 +1289,7 @@ async def close_connection(self) -> None: if await self.wait_for_connection_lost(): return if self.debug: - self.logger.debug("! timed out waiting for TCP close") + self.logger.debug("- timed out waiting for TCP close") # Half-close the TCP connection if possible (when there's no TLS). if self.transport.can_write_eof(): @@ -1315,7 +1307,7 @@ async def close_connection(self) -> None: if await self.wait_for_connection_lost(): return if self.debug: - self.logger.debug("! 
timed out waiting for TCP close") + self.logger.debug("- timed out waiting for TCP close") finally: # The try/finally ensures that the transport never remains open, @@ -1341,7 +1333,7 @@ async def close_transport(self) -> None: if await self.wait_for_connection_lost(): return if self.debug: - self.logger.debug("! timed out waiting for TCP close") + self.logger.debug("- timed out waiting for TCP close") # Abort the TCP connection. Buffers are discarded. if self.debug: @@ -1633,8 +1625,12 @@ def broadcast( exceptions.append(exception) else: websocket.logger.warning( - "skipped broadcast: failed to write message", - exc_info=True, + "skipped broadcast: failed to write message: %s", + traceback.format_exception_only( + # Remove first argument when dropping Python 3.9. + type(write_exception), + write_exception, + )[0].strip(), ) if raise_exceptions and exceptions: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 2cb9b1abb..9326b6100 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -8,18 +8,9 @@ import logging import socket import warnings +from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType -from typing import ( - Any, - Awaitable, - Callable, - Generator, - Iterable, - Sequence, - Tuple, - Union, - cast, -) +from typing import Any, Callable, Union, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike, MultipleValuesError @@ -59,8 +50,7 @@ # Change to HeadersLike | ... when dropping Python < 3.10. HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] -# Change to tuple[...] when dropping Python < 3.9. 
-HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] +HTTPResponse = tuple[StatusLike, HeadersLike, bytes] class WebSocketServerProtocol(WebSocketCommonProtocol): diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 8751ebdb4..bc64a216a 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -3,7 +3,8 @@ import enum import logging import uuid -from typing import Generator, Union +from collections.abc import Generator +from typing import Union from .exceptions import ( ConnectionClosed, @@ -158,7 +159,12 @@ def state(self) -> State: """ State of the WebSocket connection. - Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. + Defined in 4.1_, 4.2_, 7.1.3_, and 7.1.4_ of :rfc:`6455`. + + .. _4.1: https://datatracker.ietf.org/doc/html/rfc6455#section-4.1 + .. _4.2: https://datatracker.ietf.org/doc/html/rfc6455#section-4.2 + .. _7.1.3: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.3 + .. _7.1.4: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.4 """ return self._state @@ -172,10 +178,11 @@ def state(self, state: State) -> None: @property def close_code(self) -> int | None: """ - `WebSocket close code`_. + WebSocket close code received from the remote endpoint. + + Defined in 7.1.5_ of :rfc:`6455`. - .. _WebSocket close code: - https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 + .. _7.1.5: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 :obj:`None` if the connection isn't closed yet. @@ -190,10 +197,11 @@ def close_code(self) -> int | None: @property def close_reason(self) -> str | None: """ - `WebSocket close reason`_. + WebSocket close reason received from the remote endpoint. + + Defined in 7.1.6_ of :rfc:`6455`. - .. _WebSocket close reason: - https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 + .. _7.1.6: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 :obj:`None` if the connection isn't closed yet. 
@@ -517,19 +525,38 @@ def close_expected(self) -> bool:
         Whether the TCP connection is expected to close soon.
 
         """
-        # We expect a TCP close if and only if we sent a close frame:
+        # During the opening handshake, when our state is CONNECTING, we expect
+        # a TCP close if and only if the handshake fails. When it does, we start
+        # the TCP closing handshake by sending EOF with send_eof().
+
+        # Once the opening handshake completes successfully, we expect a TCP
+        # close if and only if we sent a close frame, meaning that our state
+        # progressed to CLOSING:
+
         # * Normal closure: once we send a close frame, we expect a TCP close:
         #   server waits for client to complete the TCP closing handshake;
         #   client waits for server to initiate the TCP closing handshake.
+
         # * Abnormal closure: we always send a close frame and the same logic
         #   applies, except on EOFError where we don't send a close frame
         #   because we already received the TCP close, so we don't expect it.
-        # We already got a TCP Close if and only if the state is CLOSED.
-        return self.state is CLOSING or self.handshake_exc is not None
+
+        # If our state is CLOSED, we already received a TCP close so we don't
+        # expect it anymore.
+
+        # Micro-optimization: put the most common case first
+        if self.state is OPEN:
+            return False
+        if self.state is CLOSING:
+            return True
+        if self.state is CLOSED:
+            return False
+        assert self.state is CONNECTING
+        return self.eof_sent
 
     # Private methods for receiving data.
 
-    def parse(self) -> Generator[None, None, None]:
+    def parse(self) -> Generator[None]:
         """
         Parse incoming data into frames.
 
@@ -586,6 +613,7 @@ def parse(self) -> Generator[None, None, None]: self.parser_exc = exc except PayloadTooBig as exc: + exc.set_current_size(self.cur_size) self.fail(CloseCode.MESSAGE_TOO_BIG, str(exc)) self.parser_exc = exc @@ -600,7 +628,7 @@ def parse(self) -> Generator[None, None, None]: yield raise AssertionError("parse() shouldn't step after error") - def discard(self) -> Generator[None, None, None]: + def discard(self) -> Generator[None]: """ Discard incoming data. @@ -614,14 +642,14 @@ def discard(self) -> Generator[None, None, None]: # connection in the same circumstances where discard() replaces parse(). # The client closes it when it receives EOF from the server or times # out. (The latter case cannot be handled in this Sans-I/O layer.) - assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent) + assert (self.side is SERVER or self.state is CONNECTING) == (self.eof_sent) while not (yield from self.reader.at_eof()): self.reader.discard() if self.debug: self.logger.debug("< EOF") # A server closes the TCP connection immediately, while a client # waits for the server to close the TCP connection. - if self.state != CONNECTING and self.side is CLIENT: + if self.side is CLIENT and self.state is not CONNECTING: self.send_eof() self.state = CLOSED # If discard() completes normally, execution ends here. 
@@ -638,9 +666,7 @@ def recv_frame(self, frame: Frame) -> None: if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: if self.cur_size is not None: raise ProtocolError("expected a continuation frame") - if frame.fin: - self.cur_size = None - else: + if not frame.fin: self.cur_size = len(frame.data) elif frame.opcode is OP_CONT: diff --git a/src/websockets/server.py b/src/websockets/server.py index 006d5bdd5..607cc306e 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,8 @@ import email.utils import http import warnings -from typing import Any, Callable, Generator, Sequence, cast +from collections.abc import Generator, Sequence +from typing import Any, Callable, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -13,7 +14,6 @@ InvalidHeader, InvalidHeaderValue, InvalidOrigin, - InvalidStatus, InvalidUpgrade, NegotiationError, ) @@ -26,6 +26,7 @@ parse_upgrade, ) from .http11 import Request, Response +from .imports import lazy_import from .protocol import CONNECTING, OPEN, SERVER, Protocol, State from .typing import ( ConnectionOption, @@ -39,13 +40,7 @@ from .utils import accept_key -# See #940 for why lazy_import isn't used here for backwards compatibility. -# See #1400 for why listing compatibility imports in __all__ helps PyCharm. -from .legacy.server import * # isort:skip # noqa: I001 -from .legacy.server import __all__ as legacy__all__ - - -__all__ = ["ServerProtocol"] + legacy__all__ +__all__ = ["ServerProtocol"] class ServerProtocol(Protocol): @@ -540,11 +535,6 @@ def send_response(self, response: Response) -> None: self.logger.info("connection open") else: - # handshake_exc may be already set if accept() encountered an error. - # If the connection isn't open, set handshake_exc to guarantee that - # handshake_exc is None if and only if opening handshake succeeded. 
- if self.handshake_exc is None: - self.handshake_exc = InvalidStatus(response) self.logger.info( "connection rejected (%d %s)", response.status_code, @@ -555,7 +545,7 @@ def send_response(self, response: Response) -> None: self.parser = self.discard() next(self.parser) # start coroutine - def parse(self) -> Generator[None, None, None]: + def parse(self) -> Generator[None]: if self.state is CONNECTING: try: request = yield from Request.parse( @@ -585,3 +575,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: DeprecationWarning, ) super().__init__(*args, **kwargs) + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 - 2024-11-09 + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", + "broadcast": ".legacy.server", + "serve": ".legacy.server", + "unix_serve": ".legacy.server", + }, +) diff --git a/src/websockets/streams.py b/src/websockets/streams.py index 956f139d4..f52e6193a 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generator +from collections.abc import Generator class StreamReader: diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index d1e20a757..9e6da7caf 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -4,7 +4,8 @@ import ssl as ssl_module import threading import warnings -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from ..client import ClientProtocol from ..datastructures import HeadersLike @@ -39,10 +40,12 @@ class ClientConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``close_timeout`` and ``max_queue`` arguments have the same meaning as + in :func:`connect`. + Args: socket: Socket connected to a WebSocket server. protocol: Sans-I/O connection. 
- close_timeout: Timeout for closing the connection in seconds. """ @@ -52,6 +55,7 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -59,6 +63,7 @@ def __init__( socket, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) def handshake( @@ -82,9 +87,9 @@ def handshake( if not self.response_rcvd.wait(timeout): raise TimeoutError("timed out during handshake") - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a response, when the response cannot be parsed, or - # when the response fails the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -134,6 +139,7 @@ def connect( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | None | tuple[int | None, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -182,6 +188,11 @@ def connect( :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. 
@@ -211,7 +222,7 @@ def connect( wsuri = parse_uri(uri) if not wsuri.secure and ssl is not None: - raise TypeError("ssl argument is incompatible with a ws:// URI") + raise ValueError("ssl argument is incompatible with a ws:// URI") # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) @@ -219,9 +230,9 @@ def connect( if unix: if path is None and sock is None: - raise TypeError("missing path argument") + raise ValueError("missing path argument") elif path is not None and sock is not None: - raise TypeError("path and sock arguments are incompatible") + raise ValueError("path and sock arguments are incompatible") if subprotocols is not None: validate_subprotocols(subprotocols) @@ -286,6 +297,7 @@ def connect( sock, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) except Exception: if sock is not None: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 97588870e..d8dbf140e 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -7,8 +7,9 @@ import struct import threading import uuid +from collections.abc import Iterable, Iterator, Mapping from types import TracebackType -from typing import Any, Iterable, Iterator, Mapping +from typing import Any from ..exceptions import ( ConcurrencyError, @@ -48,10 +49,14 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol self.close_timeout = close_timeout + if isinstance(max_queue, int) or max_queue is None: + max_queue = (max_queue, None) + self.max_queue = max_queue # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( @@ -75,8 +80,15 @@ def __init__( # Mutex serializing interactions with the protocol. self.protocol_mutex = threading.Lock() + # Lock stopping reads when the assembler buffer is full. 
+ self.recv_flow_control = threading.Lock() + # Assembler turning frames into messages and serializing reads. - self.recv_messages = Assembler() + self.recv_messages = Assembler( + *self.max_queue, + pause=self.recv_flow_control.acquire, + resume=self.recv_flow_control.release, + ) # Whether we are busy sending a fragmented message. self.send_in_progress = False @@ -87,6 +99,10 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: dict[bytes, threading.Event] = {} + # Exception raised in recv_events, to be chained to ConnectionClosed + # in the user thread in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + # Receiving events from the socket. This thread is marked as daemon to # allow creating a connection in a non-daemon thread and using it in a # daemon thread. This mustn't prevent the interpreter from exiting. @@ -96,10 +112,6 @@ def __init__( ) self.recv_events_thread.start() - # Exception raised in recv_events, to be chained to ConnectionClosed - # in the user thread in order to show why the TCP connection dropped. - self.recv_exc: BaseException | None = None - # Public attributes @property @@ -128,6 +140,19 @@ def remote_address(self) -> Any: """ return self.socket.getpeername() + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. + + """ + return self.protocol.state + @property def subprotocol(self) -> Subprotocol | None: """ @@ -138,6 +163,30 @@ def subprotocol(self) -> Subprotocol | None: """ return self.protocol.subprotocol + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. 
Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + # Public methods def __enter__(self) -> Connection: @@ -171,7 +220,7 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return - def recv(self, timeout: float | None = None) -> Data: + def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Receive the next message. @@ -190,6 +239,11 @@ def recv(self, timeout: float | None = None) -> Data: If the message is fragmented, wait until all fragments are received, reassemble them, and return the whole message. + Args: + timeout: Timeout for receiving a message in seconds. + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. @@ -197,6 +251,16 @@ def recv(self, timeout: float | None = None) -> Data: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. 
+ * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or @@ -204,27 +268,44 @@ def recv(self, timeout: float | None = None) -> Data: """ try: - return self.recv_messages.get(timeout) + return self.recv_messages.get(timeout, decode) except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - self.recv_events_thread.join() - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv while another thread " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() + raise self.protocol.close_exc from self.recv_exc - def recv_streaming(self) -> Iterator[Data]: + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ Receive the next message frame by frame. - If the message is fragmented, yield each fragment as it is received. - The iterator must be fully consumed, or else the connection will become + This method is designed for receiving fragmented messages. It returns an + iterator that yields each fragment as it is received. This iterator must + be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection unusable. :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. 
+ Returns: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. @@ -232,6 +313,15 @@ def recv_streaming(self) -> Iterator[Data]: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or @@ -239,19 +329,33 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - for frame in self.recv_messages.get_iter(): - yield frame + yield from self.recv_messages.get_iter(decode) + return except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - self.recv_events_thread.join() - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv_streaming while another thread " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() + raise self.protocol.close_exc from self.recv_exc - def send(self, message: Data | Iterable[Data]) -> None: + def send( + self, + message: Data | Iterable[Data], + text: bool | None = None, + ) -> None: """ Send a message. @@ -262,6 +366,17 @@ def send(self, message: Data | Iterable[Data]) -> None: .. 
_Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send a bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + :meth:`send` also accepts an iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. All items must be of the @@ -300,7 +415,10 @@ def send(self, message: Data | Iterable[Data]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_text(message.encode()) + if text is False: + self.protocol.send_binary(message.encode()) + else: + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): with self.send_context(): @@ -309,7 +427,10 @@ def send(self, message: Data | Iterable[Data]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_binary(message) + if text is True: + self.protocol.send_text(message) + else: + self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -328,7 +449,6 @@ def send(self, message: Data | Iterable[Data]) -> None: try: # First fragment. 
if isinstance(chunk, str): - text = True with self.send_context(): if self.send_in_progress: raise ConcurrencyError( @@ -336,12 +456,12 @@ def send(self, message: Data | Iterable[Data]) -> None: "is already running send" ) self.send_in_progress = True - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + self.protocol.send_binary(chunk.encode(), fin=False) + else: + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False with self.send_context(): if self.send_in_progress: raise ConcurrencyError( @@ -349,29 +469,24 @@ def send(self, message: Data | Iterable[Data]) -> None: "is already running send" ) self.send_in_progress = True - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + self.protocol.send_text(chunk, fin=False) + else: + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("data iterable must contain bytes or str") # Other fragments for chunk in chunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: with self.send_context(): assert self.send_in_progress - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: with self.send_context(): assert self.send_in_progress - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("data iterable must contain uniform types") @@ -556,12 +671,16 @@ def recv_events(self) -> None: try: while True: try: - if self.close_deadline is not None: - self.socket.settimeout(self.close_deadline.timeout()) + with self.recv_flow_control: + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: - 
self.logger.debug("error while receiving data", exc_info=True) + self.logger.debug( + "! error while receiving data", + exc_info=True, + ) # When the closing handshake is initiated by our side, # recv() may block until send_context() closes the socket. # In that case, send_context() already set recv_exc. @@ -586,7 +705,10 @@ def recv_events(self) -> None: self.send_data() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug( + "! error while sending data", + exc_info=True, + ) # Similarly to the above, avoid overriding an exception # set by send_context(), in case of a race condition # i.e. send_context() closes the socket after recv() @@ -607,13 +729,9 @@ def recv_events(self) -> None: # Given that automatic responses write small amounts of data, # this should be uncommon, so we don't handle the edge case. - try: - for event in events: - # This may raise EOFError if the closing handshake - # times out while a message is waiting to be read. - self.process_event(event) - except EOFError: - break + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. @@ -621,13 +739,22 @@ def recv_events(self) -> None: # Feed the end of the data stream to the protocol. self.protocol.receive_eof() - # This isn't expected to generate events. - assert not self.protocol.events_received() + # This isn't expected to raise an exception. + events = self.protocol.events_received() # There is no error handling because send_data() can only write # the end of the data stream here and it handles errors itself. self.send_data() + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. 
Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + except Exception as exc: # This branch should never run. It's a safety net in case of bugs. self.logger.error("unexpected internal error", exc_info=True) @@ -699,7 +826,10 @@ def send_context( self.send_data() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug( + "! error while sending data", + exc_info=True, + ) # While the only expected exception here is OSError, # other exceptions would be treated identically. wait_for_close = False diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 8d090538f..98490797f 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,11 +3,12 @@ import codecs import queue import threading -from typing import Iterator, cast +from typing import Any, Callable, Iterable, Iterator from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data +from .utils import Deadline __all__ = ["Assembler"] @@ -19,47 +20,92 @@ class Assembler: """ Assemble messages from frames. + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. 
+ """ - def __init__(self) -> None: + def __init__( + self, + high: int | None = None, + low: int | None = None, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: # Serialize reads and writes -- except for reads via synchronization # primitives provided by the threading and queue modules. self.mutex = threading.Lock() - # We create a latch with two events to synchronize the production of - # frames and the consumption of messages (or frames) without a buffer. - # This design requires a switch between the library thread and the user - # thread for each message; that shouldn't be a performance bottleneck. - - # put() sets this event to tell get() that a message can be fetched. - self.message_complete = threading.Event() - # get() sets this event to let put() that the message was fetched. - self.message_fetched = threading.Event() + # Queue of incoming frames. + self.frames: queue.SimpleQueue[Frame | None] = queue.SimpleQueue() + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + if high is not None and low is None: + low = high // 4 + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low + self.pause = pause + self.resume = resume + self.paused = False # This flag prevents concurrent calls to get() by user code. self.get_in_progress = False - # This flag prevents concurrent calls to put() by library code. - self.put_in_progress = False - - # Decoder for text frames, None for binary frames. 
- self.decoder: codecs.IncrementalDecoder | None = None - - # Buffer of frames belonging to the same message. - self.chunks: list[Data] = [] - - # When switching from "buffering" to "streaming", we use a thread-safe - # queue for transferring frames from the writing thread (library code) - # to the reading thread (user code). We're buffering when chunks_queue - # is None and streaming when it's a SimpleQueue. None is a sentinel - # value marking the end of the message, superseding message_complete. - - # Stream data from frames belonging to the same message. - self.chunks_queue: queue.SimpleQueue[Data | None] | None = None # This flag marks the end of the connection. self.closed = False - def get(self, timeout: float | None = None) -> Data: + def get_next_frame(self, timeout: float | None = None) -> Frame: + # Helper to factor out the logic for getting the next frame from the + # queue, while handling timeouts and reaching the end of the stream. + if self.closed: + try: + frame = self.frames.get(block=False) + except queue.Empty: + raise EOFError("stream of frames ended") from None + else: + try: + frame = self.frames.get(block=True, timeout=timeout) + except queue.Empty: + raise TimeoutError(f"timed out in {timeout:.1f}s") from None + if frame is None: + raise EOFError("stream of frames ended") + return frame + + def reset_queue(self, frames: Iterable[Frame]) -> None: + # Helper to put frames back into the queue after they were fetched. + # This happens only when the queue is empty. However, by the time + # we acquire self.mutex, put() may have added items in the queue. + # Therefore, we must handle the case where the queue is not empty. + frame: Frame | None + with self.mutex: + queued = [] + try: + while True: + queued.append(self.frames.get(block=False)) + except queue.Empty: + pass + for frame in frames: + self.frames.put(frame) + # This loop runs only when a race condition occurs. 
+ for frame in queued: # pragma: no cover + self.frames.put(frame) + + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Read the next message. @@ -72,56 +118,63 @@ def get(self, timeout: float | None = None) -> Data: Args: timeout: If a timeout is provided and elapses before a complete message is received, :meth:`get` raises :exc:`TimeoutError`. + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` - concurrently. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. TimeoutError: If a timeout is provided and elapses before a complete message is received. """ with self.mutex: - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - self.get_in_progress = True - # If the message_complete event isn't set yet, release the lock to - # allow put() to run and eventually set it. - # Locking with get_in_progress ensures only one thread can get here. - completed = self.message_complete.wait(timeout) + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or times out. 
+ + try: + deadline = Deadline(timeout) + + # First frame + frame = self.get_next_frame(deadline.timeout()) + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = self.get_next_frame(deadline.timeout()) + except TimeoutError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.reset_queue(frames) + raise + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) - with self.mutex: + finally: self.get_in_progress = False - # Waiting for a complete message timed out. - if not completed: - raise TimeoutError(f"timed out in {timeout:.1f}s") + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data - # get() was unblocked by close() rather than put(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_complete.is_set() - self.message_complete.clear() - - joiner: Data = b"" if self.decoder is None else "" - # mypy cannot figure out that chunks have the proper type. - message: Data = joiner.join(self.chunks) # type: ignore - - self.chunks = [] - assert self.chunks_queue is None - - assert not self.message_fetched.is_set() - self.message_fetched.set() - - return message - - def get_iter(self) -> Iterator[Data]: + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ Stream the next message. @@ -134,129 +187,105 @@ def get_iter(self) -> Iterator[Data]: This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. 
+ Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` - concurrently. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. """ with self.mutex: - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - - chunks = self.chunks - self.chunks = [] - self.chunks_queue = cast( - # Remove quotes around type when dropping Python < 3.9. - "queue.SimpleQueue[Data | None]", - queue.SimpleQueue(), - ) - - # Sending None in chunk_queue supersedes setting message_complete - # when switching to "streaming". If message is already complete - # when the switch happens, put() didn't send None, so we have to. - if self.message_complete.is_set(): - self.chunks_queue.put(None) - self.get_in_progress = True - # Locking with get_in_progress ensures only one thread can get here. - chunk: Data | None - for chunk in chunks: - yield chunk - while (chunk := self.chunks_queue.get()) is not None: - yield chunk - - with self.mutex: - self.get_in_progress = False - - # get_iter() was unblocked by close() rather than put(). - if self.closed: - raise EOFError("stream of frames ended") + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or times out. - assert self.message_complete.is_set() - self.message_complete.clear() + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. 
- assert self.chunks == [] - self.chunks_queue = None + # First frame + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + # Following frames, for fragmented messages + while not frame.fin: + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data - assert not self.message_fetched.is_set() - self.message_fetched.set() + self.get_in_progress = False def put(self, frame: Frame) -> None: """ Add ``frame`` to the next message. - When ``frame`` is the final frame in a message, :meth:`put` waits until - the message is fetched, which can be achieved by calling :meth:`get` or - by fully consuming the return value of :meth:`get_iter`. - - :meth:`put` assumes that the stream of frames respects the protocol. If - it doesn't, the behavior is undefined. - Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`put` concurrently. """ with self.mutex: if self.closed: raise EOFError("stream of frames ended") - if self.put_in_progress: - raise ConcurrencyError("put is already running") - - if frame.opcode is OP_TEXT: - self.decoder = UTF8Decoder(errors="strict") - elif frame.opcode is OP_BINARY: - self.decoder = None - else: - assert frame.opcode is OP_CONT - - data: Data - if self.decoder is not None: - data = self.decoder.decode(frame.data, frame.fin) - else: - data = frame.data - - if self.chunks_queue is None: - self.chunks.append(data) - else: - self.chunks_queue.put(data) - - if not frame.fin: - return + self.frames.put(frame) + self.maybe_pause() - # Message is complete. Wait until it's fetched to return. 
+ # put() and get/get_iter() call maybe_pause() and maybe_resume() while + # holding self.mutex. This guarantees that the calls interleave properly. + # Specifically, it prevents a race condition where maybe_resume() would + # run before maybe_pause(), leaving the connection incorrectly paused. - assert not self.message_complete.is_set() - self.message_complete.set() + # A race condition is possible when get/get_iter() call self.frames.get() + # without holding self.mutex. However, it's harmless — and even beneficial! + # It can only result in popping an item from the queue before maybe_resume() + # runs and skipping a pause() - resume() cycle that would otherwise occur. - if self.chunks_queue is not None: - self.chunks_queue.put(None) + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled + if self.high is None: + return - assert not self.message_fetched.is_set() + assert self.mutex.locked() - self.put_in_progress = True + # Check for "> high" to support high = 0 + if self.frames.qsize() > self.high and not self.paused: + self.paused = True + self.pause() - # Release the lock to allow get() to run and eventually set the event. - # Locking with put_in_progress ensures only one coroutine can get here. - self.message_fetched.wait() + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled + if self.low is None: + return - with self.mutex: - self.put_in_progress = False + assert self.mutex.locked() - # put() was unblocked by close() rather than get() or get_iter(). 
- if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_fetched.is_set() - self.message_fetched.clear() - - self.decoder = None + # Check for "<= low" to support low = 0 + if self.frames.qsize() <= self.low and self.paused: + self.paused = False + self.resume() def close(self) -> None: """ @@ -272,12 +301,6 @@ def close(self) -> None: self.closed = True - # Unblock get or get_iter. if self.get_in_progress: - self.message_complete.set() - if self.chunks_queue is not None: - self.chunks_queue.put(None) - - # Unblock put(). - if self.put_in_progress: - self.message_fetched.set() + # Unblock get() or get_iter(). + self.frames.put(None) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 1b7cbb4b4..9506d6830 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -10,8 +10,9 @@ import sys import threading import warnings +from collections.abc import Iterable, Sequence from types import TracebackType -from typing import Any, Callable, Iterable, Sequence, Tuple, cast +from typing import Any, Callable, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -50,10 +51,12 @@ class ServerConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``close_timeout`` and ``max_queue`` arguments have the same meaning as + in :func:`serve`. + Args: socket: Socket connected to a WebSocket client. protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. 
""" @@ -63,6 +66,7 @@ def __init__( protocol: ServerProtocol, *, close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -70,6 +74,7 @@ def __init__( socket, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) self.username: str # see basic_auth() @@ -165,10 +170,13 @@ def handshake( self.protocol.send_response(self.response) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, when - # the handshake encounters an error, or when process_request or - # process_response sends an HTTP response that rejects the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -348,6 +356,7 @@ def serve( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | None | tuple[int | None, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -426,6 +435,11 @@ def handler(websocket): :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. logger: Logger for this server. 
It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -467,14 +481,14 @@ def handler(websocket): if sock is None: if unix: if path is None: - raise TypeError("missing path argument") + raise ValueError("missing path argument") kwargs.setdefault("family", socket.AF_UNIX) sock = socket.create_server(path, **kwargs) else: sock = socket.create_server((host, port), **kwargs) else: if path is not None: - raise TypeError("path and sock arguments are incompatible") + raise ValueError("path and sock arguments are incompatible") # Initialize TLS wrapper @@ -547,6 +561,7 @@ def protocol_select_subprotocol( sock, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) except Exception: sock.close() @@ -656,16 +671,18 @@ def basic_auth( whether they're valid. Raises: TypeError: If ``credentials`` or ``check_credentials`` is wrong. + ValueError: If ``credentials`` and ``check_credentials`` are both + provided or both not provided. 
""" if (credentials is None) == (check_credentials is None): - raise TypeError("provide either credentials or check_credentials") + raise ValueError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): - credentials_list = [cast(Tuple[str, str], credentials)] + credentials_list = [cast(tuple[str, str], credentials)] elif isinstance(credentials, Iterable): - credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + credentials_list = list(cast(Iterable[tuple[str, str]], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 447fe79da..0a37141c6 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -3,7 +3,7 @@ import http import logging import typing -from typing import Any, List, NewType, Optional, Tuple, Union +from typing import Any, NewType, Optional, Union __all__ = [ @@ -56,16 +56,14 @@ ExtensionName = NewType("ExtensionName", str) """Name of a WebSocket extension.""" -# Change to tuple[str, Optional[str]] when dropping Python < 3.9. # Change to tuple[str, str | None] when dropping Python < 3.10. -ExtensionParameter = Tuple[str, Optional[str]] +ExtensionParameter = tuple[str, Optional[str]] """Parameter of a WebSocket extension.""" # Private types -# Change to tuple[.., list[...]] when dropping Python < 3.9. 
-ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] +ExtensionHeader = tuple[ExtensionName, list[ExtensionParameter]] """Extension in a ``Sec-WebSocket-Extensions`` header.""" diff --git a/src/websockets/version.py b/src/websockets/version.py index 00b0a985e..f2defeff0 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -20,7 +20,7 @@ released = True -tag = version = commit = "13.1" +tag = version = commit = "14.1" if not released: # pragma: no cover diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 999ef1b71..231d6b8ca 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -177,10 +177,6 @@ def process_request(connection, request): self.assertEqual(iterations, 5) self.assertEqual(successful, 2) - @unittest.skipUnless( - hasattr(http.HTTPStatus, "IM_A_TEAPOT"), - "test requires Python 3.9", - ) async def test_reconnect_with_custom_process_exception(self): """Client runs process_exception to tell if errors are retryable or fatal.""" iteration = 0 @@ -214,10 +210,6 @@ def process_exception(exc): "🫖 💔 ☕️", ) - @unittest.skipUnless( - hasattr(http.HTTPStatus, "IM_A_TEAPOT"), - "test requires Python 3.9", - ) async def test_reconnect_with_custom_process_exception_raising_exception(self): """Client supports raising an exception in process_exception.""" @@ -409,6 +401,36 @@ def close_connection(self, request): "connection closed while reading HTTP status line", ) + async def test_http_response(self): + """Client reads HTTP response.""" + + def http_response(connection, request): + return connection.respond(http.HTTPStatus.OK, "👌") + + async with serve(*args, process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + async def 
test_http_response_without_content_length(self): + """Client reads HTTP response without a Content-Length header.""" + + def http_response(connection, request): + response = connection.respond(http.HTTPStatus.OK, "👌") + del response.headers["Content-Length"] + return response + + async with serve(*args, process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + async def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" @@ -585,7 +607,7 @@ async def test_set_server_hostname(self): class ClientUsageErrorsTests(unittest.IsolatedAsyncioTestCase): async def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: await connect("ws://localhost/", ssl=CLIENT_CONTEXT) self.assertEqual( str(raised.exception), @@ -594,7 +616,7 @@ async def test_ssl_without_secure_uri(self): async def test_secure_uri_without_ssl(self): """Client rejects no ssl when URI is secure.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: await connect("wss://localhost/", ssl=None) self.assertEqual( str(raised.exception), diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 70d9dad63..5a0b61bf7 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -19,7 +19,7 @@ from websockets.protocol import CLIENT, SERVER, Protocol, State from ..protocol import RecordingProtocol -from ..utils import MS +from ..utils import MS, AssertNoLogsMixin from .connection import InterceptingConnection from .utils import alist @@ -28,7 +28,7 @@ # All tests run on the client side and the server side to 
validate this. -class ClientConnectionTests(unittest.IsolatedAsyncioTestCase): +class ClientConnectionTests(AssertNoLogsMixin, unittest.IsolatedAsyncioTestCase): LOCAL = CLIENT REMOTE = SERVER @@ -48,23 +48,6 @@ async def asyncTearDown(self): await self.remote_connection.close() await self.connection.close() - if sys.version_info[:2] < (3, 10): # pragma: no cover - - @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): - """ - No message is logged on the given logger with at least the given level. - - """ - with self.assertLogs(logger, level) as logs: - # We want to test that no log message is emitted - # but assertLogs expects at least one log message. - logging.getLogger(logger).log(level, "dummy") - yield - - level_name = logging.getLevelName(level) - self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) - # Test helpers built upon RecordingProtocol and InterceptingConnection. async def assertFrameSent(self, frame): @@ -190,13 +173,13 @@ async def test_recv_binary(self): await self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") - async def test_recv_encoded_text(self): - """recv receives an UTF-8 encoded text message.""" + async def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" await self.remote_connection.send("😀") self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) - async def test_recv_decoded_binary(self): - """recv receives an UTF-8 decoded binary message.""" + async def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" await self.remote_connection.send("😀".encode()) self.assertEqual(await self.connection.recv(decode=True), "😀") @@ -222,6 +205,15 @@ async def test_recv_connection_closed_error(self): with self.assertRaises(ConnectionClosedError): await self.connection.recv() + async def test_recv_non_utf8_text(self): + """recv receives a non-UTF-8 text 
message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + async def test_recv_during_recv(self): """recv raises ConcurrencyError when called concurrently.""" recv_task = asyncio.create_task(self.connection.recv()) @@ -304,16 +296,16 @@ async def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) - async def test_recv_streaming_encoded_text(self): - """recv_streaming receives an UTF-8 encoded text message.""" + async def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" await self.remote_connection.send("😀") self.assertEqual( await alist(self.connection.recv_streaming(decode=False)), ["😀".encode()], ) - async def test_recv_streaming_decoded_binary(self): - """recv_streaming receives a UTF-8 decoded binary message.""" + async def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" await self.remote_connection.send("😀".encode()) self.assertEqual( await alist(self.connection.recv_streaming(decode=True)), @@ -352,6 +344,15 @@ async def test_recv_streaming_connection_closed_error(self): async for _ in self.connection.recv_streaming(): self.fail("did not raise") + async def test_recv_streaming_non_utf8_text(self): + """recv_streaming receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await alist(self.connection.recv_streaming()) + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + async def test_recv_streaming_during_recv(self): """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_task = asyncio.create_task(self.connection.recv()) @@ -438,6 +439,16 @@ async def 
test_send_binary(self): await self.connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + async def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + await self.connection.send("😀", text=False) + self.assertEqual(await self.remote_connection.recv(), "😀".encode()) + + async def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + await self.connection.send("😀".encode(), text=True) + self.assertEqual(await self.remote_connection.recv(), "😀") + async def test_send_fragmented_text(self): """send sends a fragmented text message.""" await self.connection.send(["😀", "😀"]) @@ -456,6 +467,24 @@ async def test_send_fragmented_binary(self): [b"\x01\x02", b"\xfe\xff", b""], ) + async def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + await self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + await self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. 
+ self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + async def test_send_async_fragmented_text(self): """send sends a fragmented text message asynchronously.""" @@ -484,6 +513,34 @@ async def fragments(): [b"\x01\x02", b"\xfe\xff", b""], ) + async def test_send_async_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments(), text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_async_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes asynchronously.""" + + async def fragments(): + yield "😀".encode() + yield "😀".encode() + + await self.connection.send(fragments(), text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + async def test_send_connection_closed_ok(self): """send raises ConnectionClosedOK after a normal closure.""" await self.remote_connection.close() @@ -736,14 +793,12 @@ async def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - async def test_close_does_not_wait_for_recv(self): - # The asyncio implementation has a buffer for incoming messages. Closing - # the connection discards buffered messages. This is allowed by the RFC: - # > However, there is no guarantee that the endpoint that has already - # > sent a Close frame will continue to process data. 
+ async def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" await self.remote_connection.send("😀") await self.connection.close() + self.assertEqual(await self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -1011,15 +1066,26 @@ async def test_close_timeout(self): self.assertEqual(connection.close_timeout, 42 * MS) async def test_max_queue(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water mark of frames buffer.""" connection = Connection(Protocol(self.LOCAL), max_queue=4) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) + async def test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + connection = Connection(Protocol(self.LOCAL), max_queue=None) + transport = Mock() + connection.connection_made(transport) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.low, None) + async def test_max_queue_tuple(self): - """max_queue parameter configures high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=(4, 2)) + """max_queue configures high-water and low-water marks of frames buffer.""" + connection = Connection( + Protocol(self.LOCAL), + max_queue=(4, 2), + ) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) @@ -1027,14 +1093,20 @@ async def test_max_queue_tuple(self): async def test_write_limit(self): """write_limit parameter configures high-water mark of write buffer.""" - connection = Connection(Protocol(self.LOCAL), write_limit=4096) + connection = Connection( + Protocol(self.LOCAL), + write_limit=4096, + ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, None) async def 
test_write_limits(self): """write_limit parameter configures high and low-water marks of write buffer.""" - connection = Connection(Protocol(self.LOCAL), write_limit=(4096, 2048)) + connection = Connection( + Protocol(self.LOCAL), + write_limit=(4096, 2048), + ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) @@ -1067,7 +1139,7 @@ async def test_remote_address(self, get_extra_info): async def test_state(self): """Connection has a state attribute.""" - self.assertEqual(self.connection.state, State.OPEN) + self.assertIs(self.connection.state, State.OPEN) async def test_request(self): """Connection has a request attribute.""" @@ -1081,6 +1153,14 @@ async def test_subprotocol(self): """Connection has a subprotocol attribute.""" self.assertIsNone(self.connection.subprotocol) + async def test_close_code(self): + """Connection has a close_code attribute.""" + self.assertIsNone(self.connection.close_code) + + async def test_close_reason(self): + """Connection has a close_reason attribute.""" + self.assertIsNone(self.connection.close_reason) + # Test reporting of network errors. 
async def test_writing_in_data_received_fails(self): @@ -1194,7 +1274,7 @@ async def test_broadcast_skips_closed_connection(self): await self.connection.close() await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.WARNING): broadcast([self.connection], "😀") await self.assertNoFrameSent() @@ -1205,7 +1285,7 @@ async def test_broadcast_skips_closing_connection(self): await asyncio.sleep(0) await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.WARNING): broadcast([self.connection], "😀") await self.assertNoFrameSent() @@ -1271,7 +1351,10 @@ async def test_broadcast_skips_connection_failing_to_send(self): self.assertEqual( [record.getMessage() for record in logs.records], - ["skipped broadcast: failed to write message"], + [ + "skipped broadcast: failed to write message: " + "RuntimeError: Cannot call write() after write_eof()" + ], ) @unittest.skipIf( diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index d2cf25c9c..a90788d02 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -37,14 +37,6 @@ async def test_get_then_put(self): item = await getter_task self.assertEqual(item, 42) - async def test_get_concurrently(self): - """get cannot be called concurrently.""" - getter_task = asyncio.create_task(self.queue.get()) - await asyncio.sleep(0) # let the task start - with self.assertRaises(ConcurrencyError): - await self.queue.get() - getter_task.cancel() - async def test_reset(self): """reset sets the content of the queue.""" self.queue.reset([42]) @@ -59,13 +51,6 @@ async def test_abort(self): with self.assertRaises(EOFError): await getter_task - async def test_abort_clears_queue(self): - """abort clears buffered data from the queue.""" - self.queue.put(42) - self.assertEqual(len(self.queue), 1) - self.queue.abort() - self.assertEqual(len(self.queue), 
0) - class AssemblerTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): @@ -168,7 +153,7 @@ async def test_get_decoded_binary_message(self): self.assertEqual(message, "tea") async def test_get_resumes_reading(self): - """get resumes reading when queue goes below the high-water mark.""" + """get resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) @@ -185,6 +170,19 @@ async def test_get_resumes_reading(self): await self.assembler.get() self.resume.assert_called_once_with() + async def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + await self.assembler.get() + await self.assembler.get() + await self.assembler.get() + + self.resume.assert_not_called() + async def test_cancel_get_before_first_frame(self): """get can be canceled safely before reading the first frame.""" getter_task = asyncio.create_task(self.assembler.get()) @@ -317,7 +315,7 @@ async def test_get_iter_decoded_binary_message(self): self.assertEqual(fragments, ["t", "e", "a"]) async def test_get_iter_resumes_reading(self): - """get_iter resumes reading when queue goes below the high-water mark.""" + """get_iter resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) @@ -336,6 +334,20 @@ async def test_get_iter_resumes_reading(self): await anext(iterator) self.resume.assert_called_once_with() + async def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + 
self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = aiter(self.assembler.get_iter()) + await anext(iterator) + await anext(iterator) + await anext(iterator) + + self.resume.assert_not_called() + async def test_cancel_get_iter_before_first_frame(self): """get_iter can be canceled safely before reading the first frame.""" getter_task = asyncio.create_task(alist(self.assembler.get_iter())) @@ -350,7 +362,7 @@ async def test_cancel_get_iter_before_first_frame(self): self.assertEqual(fragments, ["café"]) async def test_cancel_get_iter_after_first_frame(self): - """get cannot be canceled after reading the first frame.""" + """get_iter cannot be canceled after reading the first frame.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) getter_task = asyncio.create_task(alist(self.assembler.get_iter())) @@ -382,6 +394,17 @@ async def test_put_pauses_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() + async def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.pause.assert_not_called() + # Test termination async def test_get_fails_when_interrupted_by_close(self): @@ -410,6 +433,58 @@ async def test_get_iter_fails_after_close(self): async for _ in self.assembler.get_iter(): self.fail("no fragment expected") + async def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def 
test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOF on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + async for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + async def test_put_fails_after_close(self): """put raises 
EOFError after close is called.""" self.assembler.close() @@ -429,7 +504,7 @@ async def test_get_fails_when_get_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" @@ -437,7 +512,7 @@ async def test_get_fails_when_get_iter_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" @@ -445,7 +520,7 @@ async def test_get_iter_fails_when_get_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_iter_is_running(self): """get_iter cannot be called concurrently.""" @@ -453,21 +528,34 @@ async def test_get_iter_fails_when_get_iter_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate # Test setting limits async def test_set_high_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water and low-water marks.""" assembler = Assembler(high=10) self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + async def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) - 
async def test_set_high_and_low_water_mark(self): - """high sets the high-water mark.""" + async def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) + async def test_unset_high_and_low_water_marks(self): + """High-water and low-water marks are unset.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + async def test_set_invalid_high_water_mark(self): """high must be a non-negative integer.""" with self.assertRaises(ValueError): diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 47e0148a6..3e289e592 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -21,6 +21,7 @@ CLIENT_CONTEXT, MS, SERVER_CONTEXT, + AssertNoLogsMixin, temp_unix_socket_path, ) from .server import ( @@ -32,7 +33,7 @@ ) -class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): +class ServerTests(EvalShellMixin, AssertNoLogsMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives connection from client and the handshake succeeds.""" async with serve(*args) as server: @@ -148,14 +149,17 @@ def process_request(ws, request): async def handler(ws): self.fail("handler must not run") - async with serve(handler, *args[1:], process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 403", - ) + with self.assertNoLogs("websockets", logging.ERROR): + async with serve( + handler, *args[1:], process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), 
+ "server rejected WebSocket connection: HTTP 403", + ) async def test_async_process_request_returns_response(self): """Server aborts handshake if async process_request returns a response.""" @@ -166,44 +170,65 @@ async def process_request(ws, request): async def handler(ws): self.fail("handler must not run") - async with serve(handler, *args[1:], process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 403", - ) + with self.assertNoLogs("websockets", logging.ERROR): + async with serve( + handler, *args[1:], process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) async def test_process_request_raises_exception(self): """Server returns an error if process_request raises an exception.""" def process_request(ws, request): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in 
logs.records], + ["BOOM"], + ) async def test_async_process_request_raises_exception(self): """Server returns an error if async process_request raises an exception.""" async def process_request(ws, request): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_process_response_returns_none(self): """Server runs process_response but keeps the handshake response.""" @@ -277,31 +302,49 @@ async def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" def process_response(ws, request, response): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_response=process_response) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not 
raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_async_process_response_raises_exception(self): """Server returns an error if async process_response raises an exception.""" async def process_response(ws, request, response): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_response=process_response) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_override_server(self): """Server can override Server header with server_header.""" @@ -717,7 +760,7 @@ async def check_credentials(username, password): async def test_without_credentials_or_check_credentials(self): """basic_auth requires either credentials or check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth() self.assertEqual( str(raised.exception), @@ -726,7 +769,7 @@ async def test_without_credentials_or_check_credentials(self): async def test_with_credentials_and_check_credentials(self): """basic_auth 
requires only one of credentials and check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth( credentials=("hello", "iloveyou"), check_credentials=lambda: False, # pragma: no cover diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index ee09813c4..76cd48623 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -1,4 +1,5 @@ import dataclasses +import os import unittest from websockets.exceptions import ( @@ -167,6 +168,16 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame1, frame1) self.assertEqual(dec_frame2, frame2) + def test_encode_decode_large_frame(self): + # There is a separate code path that avoids copying data + # when frames are larger than 2kB. Test it for coverage. + frame = Frame(OP_BINARY, os.urandom(4096)) + + enc_frame = self.extension.encode(frame) + dec_frame = self.extension.decode(enc_frame) + + self.assertEqual(dec_frame, frame) + def test_no_decode_text_frame(self): frame = Frame(OP_TEXT, "café".encode()) diff --git a/tests/legacy/__init__.py b/tests/legacy/__init__.py index e69de29bb..035834a89 100644 --- a/tests/legacy/__init__.py +++ b/tests/legacy/__init__.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +import warnings + + +with warnings.catch_warnings(): + # Suppress DeprecationWarning raised by websockets.legacy. 
+ warnings.filterwarnings("ignore", category=DeprecationWarning) + import websockets.legacy # noqa: F401 diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py index 3754bcf3a..dabd4212a 100644 --- a/tests/legacy/test_auth.py +++ b/tests/legacy/test_auth.py @@ -2,10 +2,10 @@ import unittest import urllib.error -from websockets.exceptions import InvalidStatusCode from websockets.headers import build_authorization_basic from websockets.legacy.auth import * from websockets.legacy.auth import is_credentials +from websockets.legacy.exceptions import InvalidStatusCode from .test_client_server import ClientServerTestsMixin, with_client, with_server from .utils import AsyncioTestCase diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 2f3ba9b77..c13c6c92e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -21,7 +21,6 @@ ConnectionClosed, InvalidHandshake, InvalidHeader, - InvalidStatusCode, NegotiationError, ) from websockets.extensions.permessage_deflate import ( @@ -32,6 +31,7 @@ from websockets.frames import CloseCode from websockets.http11 import USER_AGENT from websockets.legacy.client import * +from websockets.legacy.exceptions import InvalidStatusCode from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * @@ -1607,6 +1607,10 @@ async def run_client(): ], ) # Iteration 3 + exc = ( + "websockets.legacy.exceptions.InvalidStatusCode: " + "server rejected WebSocket connection: HTTP 503" + ) self.assertEqual( [ re.sub(r"[0-9\.]+ seconds", "X seconds", record.getMessage()) @@ -1615,12 +1619,12 @@ async def run_client(): [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed; reconnecting in X seconds", + f"connect failed; reconnecting in X seconds: {exc}", ] + [ "connection rejected (503 Service Unavailable)", "connection closed", - "! 
connect failed again; retrying in X seconds", + f"connect failed again; retrying in X seconds: {exc}", ] * ((len(logs.records) - 8) // 3) + [ diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index de2a320b5..d30198934 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -938,7 +938,7 @@ def test_answer_ping_does_not_crash_if_connection_closing(self): self.receive_frame(Frame(True, OP_PING, b"test")) self.run_loop_once() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): self.loop.run_until_complete(self.protocol.close()) self.loop.run_until_complete(close_task) # cleanup @@ -951,7 +951,7 @@ def test_answer_ping_does_not_crash_if_connection_closed(self): self.receive_eof() self.run_loop_once() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): self.loop.run_until_complete(self.protocol.close()) def test_ignore_pong(self): @@ -1028,7 +1028,7 @@ def test_acknowledge_aborted_ping(self): pong_waiter.result() # transfer_data doesn't crash, which would be logged. - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): # Unclog incoming queue. 
self.loop.run_until_complete(self.protocol.recv()) self.loop.run_until_complete(self.protocol.recv()) @@ -1375,7 +1375,7 @@ def test_remote_close_and_connection_lost(self): self.receive_eof() self.run_loop_once() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") @@ -1500,14 +1500,14 @@ def test_broadcast_two_clients(self): def test_broadcast_skips_closed_connection(self): self.close_connection() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): broadcast([self.protocol], "café") self.assertNoFrameSent() def test_broadcast_skips_closing_connection(self): close_task = self.half_close_connection_local() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): broadcast([self.protocol], "café") self.assertNoFrameSent() @@ -1547,14 +1547,14 @@ def test_broadcast_reports_connection_sending_fragmented_text(self): def test_broadcast_skips_connection_failing_to_send(self): # Configure mock to raise an exception when writing to the network. 
- self.protocol.transport.write.side_effect = RuntimeError + self.protocol.transport.write.side_effect = RuntimeError("BOOM") with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.protocol], "café") self.assertEqual( [record.getMessage() for record in logs.records], - ["skipped broadcast: failed to write message"], + ["skipped broadcast: failed to write message: RuntimeError: BOOM"], ) @unittest.skipIf( diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 5b56050d5..1f79bb600 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -1,12 +1,12 @@ import asyncio -import contextlib import functools -import logging import sys import unittest +from ..utils import AssertNoLogsMixin -class AsyncioTestCase(unittest.TestCase): + +class AsyncioTestCase(AssertNoLogsMixin, unittest.TestCase): """ Base class for tests that sets up an isolated event loop for each test. @@ -56,23 +56,6 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() - if sys.version_info[:2] < (3, 10): # pragma: no cover - - @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): - """ - No message is logged on the given logger with at least the given level. - - """ - with self.assertLogs(logger, level) as logs: - # We want to test that no log message is emitted - # but assertLogs expects at least one log message. - logging.getLogger(logger).log(level, "dummy") - yield - - level_name = logging.getLevelName(level) - self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) - def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): """ Check recorded deprecation warnings match a list of expected messages. 
diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index 83686c3d3..8ccef7d39 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -9,7 +9,6 @@ UNMAPPED_SRC_FILES = [ - "websockets/auth.py", "websockets/typing.py", "websockets/version.py", ] @@ -105,7 +104,6 @@ def get_ignored_files(src_dir="src"): # or websockets (import locations). "*/websockets/asyncio/async_timeout.py", "*/websockets/asyncio/compatibility.py", - "*/websockets/auth.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. "*/websockets/legacy/*", diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index e63d774b7..9d457a912 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,3 +1,4 @@ +import http import logging import socket import socketserver @@ -6,7 +7,7 @@ import time import unittest -from websockets.exceptions import InvalidHandshake, InvalidURI +from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * @@ -156,6 +157,36 @@ def close_connection(self, request): "connection closed while reading HTTP status line", ) + def test_http_response(self): + """Client reads HTTP response.""" + + def http_response(connection, request): + return connection.respond(http.HTTPStatus.OK, "👌") + + with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + def test_http_response_without_content_length(self): + """Client reads HTTP response without a Content-Length header.""" + + def http_response(connection, request): + response = connection.respond(http.HTTPStatus.OK, "👌") + del response.headers["Content-Length"] + 
return response + + with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" @@ -280,7 +311,7 @@ def test_set_server_hostname(self): class ClientUsageErrorsTests(unittest.TestCase): def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: connect("ws://localhost/", ssl=CLIENT_CONTEXT) self.assertEqual( str(raised.exception), @@ -289,7 +320,7 @@ def test_ssl_without_secure_uri(self): def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_connect() self.assertEqual( str(raised.exception), @@ -300,7 +331,7 @@ def test_unix_with_path_and_sock(self): """Unix client rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_connect(path="/", sock=sock) self.assertEqual( str(raised.exception), diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 16f92e164..e21e310a2 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -1,6 +1,5 @@ import contextlib import logging -import platform import socket import sys import threading @@ -15,7 +14,7 @@ ConnectionClosedOK, ) from websockets.frames import CloseCode, Frame, Opcode -from websockets.protocol import CLIENT, SERVER, Protocol +from websockets.protocol import CLIENT, SERVER, Protocol, State 
from websockets.sync.connection import * from ..protocol import RecordingProtocol @@ -154,6 +153,16 @@ def test_recv_binary(self): self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") + def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" + self.remote_connection.send("😀") + self.assertEqual(self.connection.recv(decode=False), "😀".encode()) + + def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" + self.remote_connection.send("😀".encode()) + self.assertEqual(self.connection.recv(decode=True), "😀") + def test_recv_fragmented_text(self): """recv receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) @@ -176,6 +185,15 @@ def test_recv_connection_closed_error(self): with self.assertRaises(ConnectionClosedError): self.connection.recv() + def test_recv_non_utf8_text(self): + """recv receives a non-UTF-8 text message.""" + self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + self.connection.recv() + self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + def test_recv_during_recv(self): """recv raises ConcurrencyError when called concurrently.""" recv_thread = threading.Thread(target=self.connection.recv) @@ -228,6 +246,22 @@ def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) + def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" + self.remote_connection.send("😀") + self.assertEqual( + list(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" + self.remote_connection.send("😀".encode()) + self.assertEqual( + list(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + def test_recv_streaming_fragmented_text(self): """recv_streaming receives a 
fragmented text message.""" self.remote_connection.send(["😀", "😀"]) @@ -260,6 +294,15 @@ def test_recv_streaming_connection_closed_error(self): for _ in self.connection.recv_streaming(): self.fail("did not raise") + def test_recv_streaming_non_utf8_text(self): + """recv_streaming receives a non-UTF-8 text message.""" + self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + list(self.connection.recv_streaming()) + self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + def test_recv_streaming_during_recv(self): """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_thread = threading.Thread(target=self.connection.recv) @@ -308,6 +351,16 @@ def test_send_binary(self): self.connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.remote_connection.recv(), b"\x01\x02\xfe\xff") + def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + self.connection.send("😀", text=False) + self.assertEqual(self.remote_connection.recv(), "😀".encode()) + + def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + self.connection.send("😀".encode(), text=True) + self.assertEqual(self.remote_connection.recv(), "😀") + def test_send_fragmented_text(self): """send sends a fragmented text message.""" self.connection.send(["😀", "😀"]) @@ -326,6 +379,24 @@ def test_send_fragmented_binary(self): [b"\x01\x02", b"\xfe\xff", b""], ) + def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. 
+ self.assertEqual( + list(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + def test_send_connection_closed_ok(self): """send raises ConnectionClosedOK after a normal closure.""" self.remote_connection.close() @@ -471,28 +542,12 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - def test_close_waits_for_recv(self): - # The sync implementation doesn't have a buffer for incoming messsages. - # It requires reading incoming frames until the close frame is reached. - # This behavior — close() blocks until recv() is called — is less than - # ideal and inconsistent with the asyncio implementation. + def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" self.remote_connection.send("😀") + self.connection.close() - close_thread = threading.Thread(target=self.connection.close) - close_thread.start() - - # Let close() initiate the closing handshake and send a close frame. - time.sleep(MS) - self.assertTrue(close_thread.is_alive()) - - # Connection isn't closed yet. - self.connection.recv() - - # Let close() receive a close frame and finish the closing handshake. - time.sleep(MS) - self.assertFalse(close_thread.is_alive()) - - # Connection is closed now. 
+ self.assertEqual(self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -500,24 +555,6 @@ def test_close_waits_for_recv(self): self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) - def test_close_timeout_waiting_for_recv(self): - self.remote_connection.send("😀") - - close_thread = threading.Thread(target=self.connection.close) - close_thread.start() - - # Let close() time out during the closing handshake. - time.sleep(3 * MS) - self.assertFalse(close_thread.is_alive()) - - # Connection is closed now. - with self.assertRaises(ConnectionClosedError) as raised: - self.connection.recv() - - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") - self.assertIsInstance(exc.__cause__, TimeoutError) - def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() @@ -526,17 +563,13 @@ def test_close_idempotency(self): self.connection.close() self.assertNoFrameSent() - @unittest.skipIf( - platform.python_implementation() == "PyPy", - "this test fails randomly due to a bug in PyPy", # see #1314 for details - ) def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" - self.connection.close_timeout = 5 * MS + self.connection.close_timeout = 6 * MS def closer(): - with self.delay_frames_rcvd(3 * MS): + with self.delay_frames_rcvd(4 * MS): self.connection.close() close_thread = threading.Thread(target=closer) @@ -548,14 +581,14 @@ def closer(): # Connection isn't closed yet. with self.assertRaises(TimeoutError): - self.connection.recv(timeout=0) + self.connection.recv(timeout=MS) self.connection.close() self.assertNoFrameSent() # Connection is closed now. 
with self.assertRaises(ConnectionClosedOK): - self.connection.recv(timeout=0) + self.connection.recv(timeout=MS) close_thread.join() @@ -696,6 +729,58 @@ def test_pong_unsupported_type(self): with self.assertRaises(TypeError): self.connection.pong([]) + # Test parameters. + + def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + close_timeout=42 * MS, + ) + self.assertEqual(connection.close_timeout, 42 * MS) + + def test_max_queue(self): + """max_queue configures high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=4, + ) + self.assertEqual(connection.recv_messages.high, 4) + + def test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=None, + ) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.high, None) + + def test_max_queue_tuple(self): + """max_queue configures high-water and low-water marks of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=(4, 2), + ) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + # Test attributes. 
def test_id(self): @@ -718,6 +803,10 @@ def test_remote_address(self, getpeername): self.assertEqual(self.connection.remote_address, ("peer", 1234)) getpeername.assert_called_with() + def test_state(self): + """Connection has a state attribute.""" + self.assertIs(self.connection.state, State.OPEN) + def test_request(self): """Connection has a request attribute.""" self.assertIsNone(self.connection.request) @@ -730,6 +819,14 @@ def test_subprotocol(self): """Connection has a subprotocol attribute.""" self.assertIsNone(self.connection.subprotocol) + def test_close_code(self): + """Connection has a close_code attribute.""" + self.assertIsNone(self.connection.close_code) + + def test_close_reason(self): + """Connection has a close_reason attribute.""" + self.assertIsNone(self.connection.close_reason) + # Test reporting of network errors. @unittest.skipUnless(sys.platform == "darwin", "works only on BSD") diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index d44b39b88..d22693102 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,4 +1,6 @@ import time +import unittest +import unittest.mock from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -9,66 +11,23 @@ class AssemblerTests(ThreadTestCase): - """ - Tests in this class interact a lot with hidden synchronization mechanisms: - - - get() / get_iter() and put() must run in separate threads when a final - frame is set because put() waits for get() / get_iter() to fetch the - message before returning. - - - run_in_thread() lets its target run before yielding back control on entry, - which guarantees the intended execution order of test cases. - - - run_in_thread() waits for its target to finish running before yielding - back control on exit, which allows making assertions immediately. 
- - - When the main thread performs actions that let another thread progress, it - must wait before making assertions, to avoid depending on scheduling. - - """ - def setUp(self): - self.assembler = Assembler() - - def tearDown(self): - """ - Check that the assembler goes back to its default state after each test. - - This removes the need for testing various sequences. - - """ - self.assertFalse(self.assembler.mutex.locked()) - self.assertFalse(self.assembler.get_in_progress) - self.assertFalse(self.assembler.put_in_progress) - if not self.assembler.closed: - self.assertFalse(self.assembler.message_complete.is_set()) - self.assertFalse(self.assembler.message_fetched.is_set()) - self.assertIsNone(self.assembler.decoder) - self.assertEqual(self.assembler.chunks, []) - self.assertIsNone(self.assembler.chunks_queue) + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) # Test get def test_get_text_message_already_received(self): """get returns a text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get() self.assertEqual(message, "café") def test_get_binary_message_already_received(self): """get returns a binary message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"tea")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get() self.assertEqual(message, b"tea") def test_get_text_message_not_received_yet(self): @@ -99,112 +58,158 @@ def getter(): def test_get_fragmented_text_message_already_received(self): """get reassembles a fragmented a text message that is already received.""" - - def putter(): - 
self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = self.assembler.get() self.assertEqual(message, "café") def test_get_fragmented_binary_message_already_received(self): """get reassembles a fragmented binary message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = self.assembler.get() self.assertEqual(message, b"tea") - def test_get_fragmented_text_message_being_received(self): - """get reassembles a fragmented text message that is partially received.""" + def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") - def test_get_fragmented_binary_message_being_received(self): - """get reassembles a fragmented binary message that is partially received.""" + def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" message = None def getter(): 
nonlocal message message = self.assembler.get() - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") - def test_get_fragmented_text_message_not_received_yet(self): - """get reassembles a fragmented text message when it is received.""" + def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") - def test_get_fragmented_binary_message_not_received_yet(self): - """get reassembles a fragmented binary message when it is received.""" + def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") - # Test get_iter + def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + 
message = self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + def test_get_resumes_reading(self): + """get resumes reading when queue goes below the low-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + self.assembler.get() + self.resume.assert_called_once_with() + + def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + self.assembler.get() + self.assembler.get() + self.assembler.get() + + self.resume.assert_not_called() + + def test_get_timeout_before_first_frame(self): + """get times out before reading the first frame.""" + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=MS) - def test_get_iter_text_message_already_received(self): - """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get() + self.assertEqual(message, "café") - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + def test_get_timeout_after_first_frame(self): + """get times out after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assertEqual(fragments, ["café"]) + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=MS) - def test_get_iter_binary_message_already_received(self): - """get_iter yields a binary message that 
is already received.""" + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) - def putter(): - self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get() + self.assertEqual(message, "café") - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + # Test get_iter + + def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, [b"tea"]) def test_get_iter_text_message_not_received_yet(self): @@ -212,6 +217,7 @@ def test_get_iter_text_message_not_received_yet(self): fragments = [] def getter(): + nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) @@ -225,6 +231,7 @@ def test_get_iter_binary_message_not_received_yet(self): fragments = [] def getter(): + nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) @@ -235,121 +242,137 @@ def getter(): def test_get_iter_fragmented_text_message_already_received(self): """get_iter yields a fragmented text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) - + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, 
["ca", "f", "é"]) def test_get_iter_fragmented_binary_message_already_received(self): """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = self.assembler.get_iter() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") - self.assertEqual(fragments, [b"t", b"e", b"a"]) + def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = self.assembler.get_iter() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - with self.run_in_thread(getter): - self.assertEqual(fragments, 
["ca"]) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca", "f"]) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - self.assertEqual(fragments, ["ca", "f", "é"]) + iterator = self.assembler.get_iter() + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" - fragments = [] + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + iterator = self.assembler.get_iter() + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") + + def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = list(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = list(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) + def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - with 
self.run_in_thread(getter): - self.assertEqual(fragments, [b"t"]) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t", b"e"]) - self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(fragments, [b"t", b"e", b"a"]) + iterator = self.assembler.get_iter() - def test_get_iter_fragmented_text_message_not_received_yet(self): - """get_iter yields a fragmented text message when it is received.""" - fragments = [] + # queue is above the low-water mark + next(iterator) + self.resume.assert_not_called() - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + # queue is at the low-water mark + next(iterator) + self.resume.assert_called_once_with() - with self.run_in_thread(getter): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca"]) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca", "f"]) - self.assembler.put(Frame(OP_CONT, b"\xa9")) + # queue is below the low-water mark + next(iterator) + self.resume.assert_called_once_with() - self.assertEqual(fragments, ["ca", "f", "é"]) + def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None - def test_get_iter_fragmented_binary_message_not_received_yet(self): - """get_iter yields a fragmented binary message when it is received.""" - fragments = [] + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = self.assembler.get_iter() + next(iterator) + next(iterator) + next(iterator) - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + self.resume.assert_not_called() - with 
self.run_in_thread(getter): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t"]) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t", b"e"]) - self.assembler.put(Frame(OP_CONT, b"a")) + # Test put - self.assertEqual(fragments, [b"t", b"e", b"a"]) - - # Test timeouts + def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() - def test_get_with_timeout_completes(self): - """get returns a message when it is received before the timeout.""" + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() - with self.run_in_thread(putter): - message = self.assembler.get(MS) + def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None - self.assertEqual(message, "café") + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) - def test_get_with_timeout_times_out(self): - """get raises TimeoutError when no message is received before the timeout.""" - with self.assertRaises(TimeoutError): - self.assembler.get(MS) + self.pause.assert_not_called() # Test termination @@ -373,18 +396,8 @@ def closer(): with self.run_in_thread(closer): with self.assertRaises(EOFError): - list(self.assembler.get_iter()) - - def test_put_fails_when_interrupted_by_close(self): - 
"""put raises EOFError when close is called.""" - - def closer(): - time.sleep(2 * MS) - self.assembler.close() - - with self.run_in_thread(closer): - with self.assertRaises(EOFError): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + for _ in self.assembler.get_iter(): + self.fail("no fragment expected") def test_get_fails_after_close(self): """get raises EOFError after close is called.""" @@ -396,7 +409,60 @@ def test_get_iter_fails_after_close(self): """get_iter raises EOFError after close is called.""" self.assembler.close() with self.assertRaises(EOFError): - list(self.assembler.get_iter()) + for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = self.assembler.get() + self.assertEqual(message, "café") + + def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = self.assembler.get() + self.assertEqual(message, b"tea") + + def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = list(self.assembler.get_iter()) + 
self.assertEqual(fragments, [b"t", b"e", b"a"]) + + def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOF on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.get() + + def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) def test_put_fails_after_close(self): """put raises EOFError after close is called.""" @@ -439,13 +505,38 @@ def test_get_iter_fails_when_get_iter_is_running(self): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - def test_put_fails_when_put_is_running(self): - """put cannot be called concurrently.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - with self.assertRaises(ConcurrencyError): - self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assembler.get() # unblock other thread + # Test setting limits + + def test_set_high_water_mark(self): + """high sets the high-water and low-water marks.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) + + def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" + assembler = 
Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + def test_unset_high_and_low_water_marks(self): + """High-water and low-water marks are unset.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + + def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + def test_set_invalid_low_water_mark(self): + """low must be higher than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 3bc6f76cd..54e49bf16 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -361,7 +361,7 @@ def test_connection(self): class ServerUsageErrorsTests(unittest.TestCase): def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_serve(handler) self.assertEqual( str(raised.exception), @@ -372,7 +372,7 @@ def test_unix_with_path_and_sock(self): """Unix server rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_serve(handler, path="/", sock=sock) self.assertEqual( str(raised.exception), @@ -504,7 +504,7 @@ def check_credentials(username, password): def test_without_credentials_or_check_credentials(self): """basic_auth requires either credentials or check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth() self.assertEqual( str(raised.exception), @@ -513,7 +513,7 @@ def test_without_credentials_or_check_credentials(self): def test_with_credentials_and_check_credentials(self): """basic_auth requires 
only one of credentials and check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth( credentials=("hello", "iloveyou"), check_credentials=lambda: False, # pragma: no cover diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 000000000..16c00c1b9 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,14 @@ +from .utils import DeprecationTestCase + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_headers_class(self): + with self.assertDeprecationWarning( + "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", + ): + from websockets.auth import ( + BasicAuthWebSocketServerProtocol, # noqa: F401 + basic_auth_protocol_factory, # noqa: F401 + ) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 8d41bf915..fef41d136 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -160,8 +160,16 @@ def test_str(self): "invalid opcode: 7", ), ( - PayloadTooBig("payload length exceeds limit: 2 > 1 bytes"), - "payload length exceeds limit: 2 > 1 bytes", + PayloadTooBig(None, 4), + "frame exceeds limit of 4 bytes", + ), + ( + PayloadTooBig(8, 4), + "frame with 8 bytes exceeds limit of 4 bytes", + ), + ( + PayloadTooBig(8, 4, 12), + "frame with 8 bytes after reading 12 bytes exceeds limit of 16 bytes", ), ( InvalidState("WebSocket connection isn't established yet"), @@ -202,3 +210,11 @@ def test_connection_closed_attributes_deprecation_defaults(self): "use Protocol.close_reason or ConnectionClosed.rcvd.reason" ): self.assertEqual(exception.reason, "") + + def test_payload_too_big_with_message(self): + with self.assertDeprecationWarning( + "PayloadTooBig(message) is deprecated; " + "change to PayloadTooBig(size, max_size)", + ): + exc = PayloadTooBig("payload length exceeds limit: 2 > 1 bytes") + 
self.assertEqual(str(exc), "payload length exceeds limit: 2 > 1 bytes") diff --git a/tests/test_exports.py b/tests/test_exports.py index 67a1a6f99..93b0684f7 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -1,30 +1,46 @@ import unittest import websockets -import websockets.auth +import websockets.asyncio.client +import websockets.asyncio.server import websockets.client import websockets.datastructures import websockets.exceptions -import websockets.legacy.protocol import websockets.server import websockets.typing import websockets.uri combined_exports = ( - websockets.auth.__all__ + [] + + websockets.asyncio.client.__all__ + + websockets.asyncio.server.__all__ + websockets.client.__all__ + websockets.datastructures.__all__ + websockets.exceptions.__all__ - + websockets.legacy.protocol.__all__ + websockets.server.__all__ + websockets.typing.__all__ ) +# These API are intentionally not re-exported by the top-level module. +missing_reexports = [ + # websockets.asyncio.client + "ClientConnection", + # websockets.asyncio.server + "ServerConnection", + "Server", +] + class ExportsTests(unittest.TestCase): - def test_top_level_module_reexports_all_submodule_exports(self): - self.assertEqual(set(combined_exports), set(websockets.__all__)) + def test_top_level_module_reexports_submodule_exports(self): + self.assertEqual( + set(combined_exports), + set(websockets.__all__ + missing_reexports), + ) def test_submodule_exports_are_globally_unique(self): - self.assertEqual(len(set(combined_exports)), len(combined_exports)) + self.assertEqual( + len(set(combined_exports)), + len(combined_exports), + ) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 7f1276bb2..1c092459d 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -20,7 +20,7 @@ Frame, ) from websockets.protocol import * -from websockets.protocol import CLIENT, CLOSED, CLOSING, SERVER +from websockets.protocol import CLIENT, CLOSED, CLOSING, CONNECTING, SERVER from 
.extensions.utils import Rsv2Extension from .test_frames import FramesTestCase @@ -265,18 +265,28 @@ def test_client_receives_text_over_size_limit(self): client = Protocol(CLIENT, max_size=3) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_server_receives_text_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_client_receives_text_without_size_limit(self): @@ -363,9 +373,14 @@ def test_client_receives_fragmented_text_over_size_limit(self): ) client.receive_data(b"\x80\x02\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_server_receives_fragmented_text_over_size_limit(self): @@ -377,9 +392,14 @@ def 
test_server_receives_fragmented_text_over_size_limit(self): ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_client_receives_fragmented_text_without_size_limit(self): @@ -533,18 +553,28 @@ def test_client_receives_binary_over_size_limit(self): client = Protocol(CLIENT, max_size=3) client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_server_receives_binary_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_client_sends_fragmented_binary(self): @@ -615,9 +645,14 @@ def test_client_receives_fragmented_binary_over_size_limit(self): ) client.receive_data(b"\x80\x02\xfe\xff") 
self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_server_receives_fragmented_binary_over_size_limit(self): @@ -629,9 +664,14 @@ def test_server_receives_fragmented_binary_over_size_limit(self): ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_client_sends_unexpected_binary(self): @@ -1656,6 +1696,24 @@ def test_server_fails_connection(self): server.fail(CloseCode.PROTOCOL_ERROR) self.assertTrue(server.close_expected()) + def test_client_is_connecting(self): + client = Protocol(CLIENT, state=CONNECTING) + self.assertFalse(client.close_expected()) + + def test_server_is_connecting(self): + server = Protocol(SERVER, state=CONNECTING) + self.assertFalse(server.close_expected()) + + def test_client_failed_connecting(self): + client = Protocol(CLIENT, state=CONNECTING) + client.send_eof() + self.assertTrue(client.close_expected()) + + def test_server_failed_connecting(self): + server = Protocol(SERVER, state=CONNECTING) + server.send_eof() + self.assertTrue(server.close_expected()) + class ConnectionClosedTests(ProtocolTestCase): """ diff --git a/tests/utils.py b/tests/utils.py 
index 960439135..77d020726 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,14 +1,18 @@ import contextlib import email.utils +import logging import os import pathlib import platform import ssl +import sys import tempfile import time import unittest import warnings +from websockets.version import released + # Generate TLS certificate with: # $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ @@ -39,9 +43,19 @@ DATE = email.utils.formatdate(usegmt=True) -# Unit for timeouts. May be increased on slow machines by setting the -# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * float(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) +# Unit for timeouts. May be increased in slow or noisy environments by setting +# the WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. + +# Downstream distributors insist on running the test suite despites my pleas to +# the contrary. They do it on build farms with unstable performance, leading to +# flakiness, and then they file bugs. Make tests 100x slower to avoid flakiness. + +MS = 0.001 * float( + os.environ.get( + "WEBSOCKETS_TESTS_TIMEOUT_FACTOR", + "100" if released else "1", + ) +) # PyPy, asyncio's debug mode, and coverage penalize performance of this # test suite. Increase timeouts to reduce the risk of spurious failures. @@ -101,6 +115,30 @@ def assertDeprecationWarning(self, message): self.assertEqual(str(warning.message), message) +class AssertNoLogsMixin: + """ + Backport of assertNoLogs for Python 3.9. + + """ + + if sys.version_info[:2] < (3, 10): # pragma: no cover + + @contextlib.contextmanager + def assertNoLogs(self, logger=None, level=None): + """ + No message is logged on the given logger with at least the given level. + + """ + with self.assertLogs(logger, level) as logs: + # We want to test that no log message is emitted + # but assertLogs expects at least one log message. 
+ logging.getLogger(logger).log(level, "dummy") + yield + + level_name = logging.getLevelName(level) + self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + + @contextlib.contextmanager def temp_unix_socket_path(): with tempfile.TemporaryDirectory() as temp_dir: diff --git a/tox.ini b/tox.ini index cba9b290b..0bcec5ded 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,5 @@ [tox] env_list = - py38 py39 py310 py311