From 37c7f6529a0877c87f191bc566d2e14f7c96e192 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:16:44 +0200 Subject: [PATCH 01/51] Start version 14.0. --- docs/project/changelog.rst | 7 +++++++ src/websockets/version.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7e4bce9c6..65e26008f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,13 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _14.0: + +14.0 +---- + +*In development* + .. _13.1: 13.1 diff --git a/src/websockets/version.py b/src/websockets/version.py index 00b0a985e..34fc2eaef 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "13.1" +tag = version = commit = "14.0" if not released: # pragma: no cover From 44ccee17c519ea1397ee28a8ac3a7d7685cd0b89 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:30:56 +0200 Subject: [PATCH 02/51] Drop Python 3.8. It is EOL at the end of October. --- .github/workflows/tests.yml | 1 - docs/howto/django.rst | 3 +-- docs/intro/index.rst | 2 +- docs/project/changelog.rst | 14 +++++++++----- pyproject.toml | 3 +-- src/websockets/asyncio/client.py | 5 +++-- src/websockets/asyncio/connection.py | 11 ++--------- src/websockets/asyncio/messages.py | 10 ++-------- src/websockets/asyncio/server.py | 16 ++++------------ src/websockets/client.py | 7 ++++--- src/websockets/datastructures.py | 15 +++------------ src/websockets/extensions/base.py | 2 +- src/websockets/extensions/permessage_deflate.py | 3 ++- src/websockets/frames.py | 3 ++- src/websockets/headers.py | 3 ++- src/websockets/http11.py | 3 ++- src/websockets/imports.py | 3 ++- src/websockets/legacy/auth.py | 6 +++--- src/websockets/legacy/client.py | 10 ++-------- src/websockets/legacy/framing.py | 3 ++- src/websockets/legacy/protocol.py | 13 ++----------- src/websockets/legacy/server.py | 16 +++------------- src/websockets/protocol.py | 7 ++++--- src/websockets/server.py | 5 +++-- src/websockets/streams.py | 2 +- src/websockets/sync/client.py | 3 ++- src/websockets/sync/connection.py | 6 +++--- src/websockets/sync/messages.py | 6 +++--- src/websockets/sync/server.py | 7 ++++--- src/websockets/typing.py | 8 +++----- tests/asyncio/test_client.py | 8 -------- tox.ini | 1 - 32 files changed, 76 insertions(+), 129 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 43193ea50..beaf9d12b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -55,7 +55,6 @@ jobs: strategy: matrix: python: - - "3.8" - "3.9" - "3.10" - "3.11" diff --git a/docs/howto/django.rst b/docs/howto/django.rst index dada9c5e4..4fe2311cb 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -121,8 +121,7 @@ authentication fails, it closes the connection and exits. When we call an API that makes a database query such as ``get_user()``, we wrap the call in :func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't support asynchronous I/O. It would block the event loop if it didn't run in a -separate thread. :func:`~asyncio.to_thread` is available since Python 3.9. In -earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. 
+separate thread. Finally, we start a server with :func:`~websockets.asyncio.server.serve`. diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 095262a20..642e50094 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -6,7 +6,7 @@ Getting started Requirements ------------ -websockets requires Python ≥ 3.8. +websockets requires Python ≥ 3.9. .. admonition:: Use the most recent Python release :class: tip diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 65e26008f..5f07fc09f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,15 @@ notice. *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: websockets 14.0 requires Python ≥ 3.9. + :class: tip + + websockets 13.1 is the last version supporting Python 3.8. + + .. _13.1: 13.1 @@ -106,11 +115,6 @@ Bug fixes Backwards-incompatible changes .............................. -.. admonition:: websockets 13.0 requires Python ≥ 3.8. - :class: tip - - websockets 12.0 is the last version supporting Python 3.7. - .. admonition:: Receiving the request path in the second parameter of connection handlers is deprecated. :class: note diff --git a/pyproject.toml b/pyproject.toml index fde9c3226..6a0ab8d7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "websockets" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -requires-python = ">=3.8" +requires-python = ">=3.9" license = { text = "BSD-3-Clause" } authors = [ { name = "Aymeric Augustin", email = "aymeric.augustin@m4x.org" }, @@ -19,7 +19,6 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index b1beb3e00..23b1a348a 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -4,8 +4,9 @@ import logging import os import urllib.parse +from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import Any, AsyncIterator, Callable, Generator, Sequence +from typing import Any, Callable from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike @@ -492,7 +493,7 @@ async def __aexit__( # async for ... 
in connect(...): async def __aiter__(self) -> AsyncIterator[ClientConnection]: - delays: Generator[float, None, None] | None = None + delays: Generator[float] | None = None while True: try: async with self as protocol: diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 6af61a4a9..702e69995 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -8,16 +8,9 @@ import struct import sys import uuid +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping from types import TracebackType -from typing import ( - Any, - AsyncIterable, - AsyncIterator, - Awaitable, - Iterable, - Mapping, - cast, -) +from typing import Any, cast from ..exceptions import ( ConcurrencyError, diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index c2b4afd67..e3ec5062f 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -3,14 +3,8 @@ import asyncio import codecs import collections -from typing import ( - Any, - AsyncIterator, - Callable, - Generic, - Iterable, - TypeVar, -) +from collections.abc import AsyncIterator, Iterable +from typing import Any, Callable, Generic, TypeVar from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 19dae44b7..e11dd91f1 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -6,17 +6,9 @@ import logging import socket import sys +from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType -from typing import ( - Any, - Awaitable, - Callable, - Generator, - Iterable, - Sequence, - Tuple, - cast, -) +from typing import Any, Callable, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -905,9 +897,9 @@ def basic_auth( if credentials is not None: if is_credentials(credentials): - credentials_list = [cast(Tuple[str, str], credentials)] + credentials_list = [cast(tuple[str, str], credentials)] elif isinstance(credentials, Iterable): - credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + credentials_list = list(cast(Iterable[tuple[str, str]], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: diff --git a/src/websockets/client.py b/src/websockets/client.py index e5f294986..bce82d66b 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -3,7 +3,8 @@ import os import random import warnings -from typing import Any, Generator, Sequence +from collections.abc import Generator, Sequence +from typing import Any from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -313,7 +314,7 @@ def send_request(self, request: Request) -> None: self.writes.append(request.serialize()) - def parse(self) -> Generator[None, None, None]: + def parse(self) -> Generator[None]: if self.state is CONNECTING: try: response = yield from Response.parse( @@ -374,7 +375,7 @@ def backoff( min_delay: float = BACKOFF_MIN_DELAY, max_delay: float = BACKOFF_MAX_DELAY, factor: float = BACKOFF_FACTOR, -) -> Generator[float, None, None]: +) -> Generator[float]: """ Generate a series of backoff delays between reconnection attempts. 
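The switch from ``Generator[float, None, None]`` to ``Generator[float]`` above relies on
:class:`collections.abc.Generator`, whose send and return type parameters default to ``None``
in current type stubs; at runtime the subscription is unchecked, and ``from __future__ import
annotations`` defers evaluation anyway. A minimal sketch of the reconnection-backoff pattern
with the modernized signature; the delay values and the standalone ``backoff`` name here are
illustrative, not the library's exact constants::

    from __future__ import annotations

    import random
    from collections.abc import Generator


    def backoff(
        initial_delay: float = 5.0,
        min_delay: float = 3.0,
        max_delay: float = 90.0,
        factor: float = 1.618,
    ) -> Generator[float]:
        """Yield a partly random, exponentially growing series of delays."""
        # Randomize the first delay to spread out reconnection attempts.
        yield random.random() * initial_delay
        delay = min_delay
        while delay < max_delay:
            yield delay
            delay *= factor
        # Once the ceiling is reached, keep yielding the maximum delay.
        while True:
            yield max_delay


    delays = backoff()
    print([round(next(delays), 1) for _ in range(5)])
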
diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 106d6f393..77b6f86fa 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -1,15 +1,7 @@ from __future__ import annotations -from typing import ( - Any, - Iterable, - Iterator, - Mapping, - MutableMapping, - Protocol, - Tuple, - Union, -) +from collections.abc import Iterable, Iterator, Mapping, MutableMapping +from typing import Any, Protocol, Union __all__ = ["Headers", "HeadersLike", "MultipleValuesError"] @@ -179,8 +171,7 @@ def __getitem__(self, key: str) -> str: ... HeadersLike = Union[ Headers, Mapping[str, str], - # Change to tuple[str, str] when dropping Python < 3.9. - Iterable[Tuple[str, str]], + Iterable[tuple[str, str]], SupportsKeysAndGetItem, ] """ diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 75bae6b77..42dd6c5fa 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from ..frames import Frame from ..typing import ExtensionName, ExtensionParameter diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 25d2c1c45..f962b65fb 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,8 @@ import dataclasses import zlib -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from .. import frames from ..exceptions import ( diff --git a/src/websockets/frames.py b/src/websockets/frames.py index a63bdc3b6..dace2c902 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -6,7 +6,8 @@ import os import secrets import struct -from typing import Callable, Generator, Sequence +from collections.abc import Generator, Sequence +from typing import Callable from .exceptions import PayloadTooBig, ProtocolError diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 9103018a0..e05948a1f 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,7 +4,8 @@ import binascii import ipaddress import re -from typing import Callable, Sequence, TypeVar, cast +from collections.abc import Sequence +from typing import Callable, TypeVar, cast from .exceptions import InvalidHeaderFormat, InvalidHeaderValue from .typing import ( diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 47cef7a9b..af542c77b 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -5,7 +5,8 @@ import re import sys import warnings -from typing import Callable, Generator +from collections.abc import Generator +from typing import Callable from .datastructures import Headers from .exceptions import SecurityError diff --git a/src/websockets/imports.py b/src/websockets/imports.py index bb80e4eac..c63fb212e 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,7 +1,8 @@ from __future__ import annotations import warnings -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any __all__ = ["lazy_import"] diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 4d030e5e2..a262fcd79 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -3,7 +3,8 @@ import functools import hmac import http -from typing import Any, Awaitable, Callable, Iterable, Tuple, cast +from 
collections.abc import Awaitable, Iterable +from typing import Any, Callable, cast from ..datastructures import Headers from ..exceptions import InvalidHeader @@ -13,8 +14,7 @@ __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] -# Change to tuple[str, str] when dropping Python < 3.9. -Credentials = Tuple[str, str] +Credentials = tuple[str, str] def is_credentials(value: Any) -> bool: diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index ec4c2ff64..116445e25 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -7,15 +7,9 @@ import random import urllib.parse import warnings +from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import ( - Any, - AsyncIterator, - Callable, - Generator, - Sequence, - cast, -) +from typing import Any, Callable, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 4c2f8c23f..4ec194ed7 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,7 +1,8 @@ from __future__ import annotations import struct -from typing import Any, Awaitable, Callable, NamedTuple, Sequence +from collections.abc import Awaitable, Sequence +from typing import Any, Callable, NamedTuple from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 998e390d4..cedde6200 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -11,17 +11,8 @@ import time import uuid import warnings -from typing import ( - Any, - AsyncIterable, - AsyncIterator, - Awaitable, - Callable, - Deque, - Iterable, - Mapping, - cast, -) +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping +from typing import Any, Callable, Deque, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 2cb9b1abb..9326b6100 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -8,18 +8,9 @@ import logging import socket import warnings +from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType -from typing import ( - Any, - Awaitable, - Callable, - Generator, - Iterable, - Sequence, - Tuple, - Union, - cast, -) +from typing import Any, Callable, Union, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike, MultipleValuesError @@ -59,8 +50,7 @@ # Change to HeadersLike | ... when dropping Python < 3.10. HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] -# Change to tuple[...] when dropping Python < 3.9. 
-HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] +HTTPResponse = tuple[StatusLike, HeadersLike, bytes] class WebSocketServerProtocol(WebSocketCommonProtocol): diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 8751ebdb4..091b4a23a 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -3,7 +3,8 @@ import enum import logging import uuid -from typing import Generator, Union +from collections.abc import Generator +from typing import Union from .exceptions import ( ConnectionClosed, @@ -529,7 +530,7 @@ def close_expected(self) -> bool: # Private methods for receiving data. - def parse(self) -> Generator[None, None, None]: + def parse(self) -> Generator[None]: """ Parse incoming data into frames. @@ -600,7 +601,7 @@ def parse(self) -> Generator[None, None, None]: yield raise AssertionError("parse() shouldn't step after error") - def discard(self) -> Generator[None, None, None]: + def discard(self) -> Generator[None]: """ Discard incoming data. diff --git a/src/websockets/server.py b/src/websockets/server.py index 006d5bdd5..9fe970619 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,8 @@ import email.utils import http import warnings -from typing import Any, Callable, Generator, Sequence, cast +from collections.abc import Generator, Sequence +from typing import Any, Callable, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -555,7 +556,7 @@ def send_response(self, response: Response) -> None: self.parser = self.discard() next(self.parser) # start coroutine - def parse(self) -> Generator[None, None, None]: + def parse(self) -> Generator[None]: if self.state is CONNECTING: try: request = yield from Request.parse( diff --git a/src/websockets/streams.py b/src/websockets/streams.py index 956f139d4..f52e6193a 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generator +from collections.abc import Generator class StreamReader: diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index d1e20a757..5e1ba6d84 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -4,7 +4,8 @@ import ssl as ssl_module import threading import warnings -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from ..client import ClientProtocol from ..datastructures import HeadersLike diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 97588870e..8c5df9592 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -7,8 +7,9 @@ import struct import threading import uuid +from collections.abc import Iterable, Iterator, Mapping from types import TracebackType -from typing import Any, Iterable, Iterator, Mapping +from typing import Any from ..exceptions import ( ConcurrencyError, @@ -239,8 +240,7 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - for frame in self.recv_messages.get_iter(): - yield frame + yield from self.recv_messages.get_iter() except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. 
self.recv_events_thread.join() diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 8d090538f..b96cd6880 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,8 @@ import codecs import queue import threading -from typing import Iterator, cast +from collections.abc import Iterator +from typing import cast from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -150,8 +151,7 @@ def get_iter(self) -> Iterator[Data]: chunks = self.chunks self.chunks = [] self.chunks_queue = cast( - # Remove quotes around type when dropping Python < 3.9. - "queue.SimpleQueue[Data | None]", + queue.SimpleQueue[Data | None], queue.SimpleQueue(), ) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 1b7cbb4b4..464c4a173 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -10,8 +10,9 @@ import sys import threading import warnings +from collections.abc import Iterable, Sequence from types import TracebackType -from typing import Any, Callable, Iterable, Sequence, Tuple, cast +from typing import Any, Callable, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -663,9 +664,9 @@ def basic_auth( if credentials is not None: if is_credentials(credentials): - credentials_list = [cast(Tuple[str, str], credentials)] + credentials_list = [cast(tuple[str, str], credentials)] elif isinstance(credentials, Iterable): - credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + credentials_list = list(cast(Iterable[tuple[str, str]], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 447fe79da..0a37141c6 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -3,7 +3,7 @@ import http import logging import typing -from typing import Any, List, NewType, Optional, Tuple, Union +from typing import Any, NewType, Optional, Union __all__ = [ @@ -56,16 +56,14 @@ ExtensionName = NewType("ExtensionName", str) """Name of a WebSocket extension.""" -# Change to tuple[str, Optional[str]] when dropping Python < 3.9. # Change to tuple[str, str | None] when dropping Python < 3.10. -ExtensionParameter = Tuple[str, Optional[str]] +ExtensionParameter = tuple[str, Optional[str]] """Parameter of a WebSocket extension.""" # Private types -# Change to tuple[.., list[...]] when dropping Python < 3.9. 
-ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] +ExtensionHeader = tuple[ExtensionName, list[ExtensionParameter]] """Extension in a ``Sec-WebSocket-Extensions`` header.""" diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 999ef1b71..9354a6e0a 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -177,10 +177,6 @@ def process_request(connection, request): self.assertEqual(iterations, 5) self.assertEqual(successful, 2) - @unittest.skipUnless( - hasattr(http.HTTPStatus, "IM_A_TEAPOT"), - "test requires Python 3.9", - ) async def test_reconnect_with_custom_process_exception(self): """Client runs process_exception to tell if errors are retryable or fatal.""" iteration = 0 @@ -214,10 +210,6 @@ def process_exception(exc): "🫖 💔 ☕️", ) - @unittest.skipUnless( - hasattr(http.HTTPStatus, "IM_A_TEAPOT"), - "test requires Python 3.9", - ) async def test_reconnect_with_custom_process_exception_raising_exception(self): """Client supports raising an exception in process_exception.""" diff --git a/tox.ini b/tox.ini index cba9b290b..0bcec5ded 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,5 @@ [tox] env_list = - py38 py39 py310 py311 From e44a1eadf3c287335fd10c381ba5ccb8bd1ff4ce Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 12:00:24 +0200 Subject: [PATCH 03/51] Document why packagers mustn't run the test suite. Refs #1509, #1496, #1427, #1426, #1081, #1026, perhaps others. --- docs/project/contributing.rst | 26 ++++++++++++++++++++++++-- tests/utils.py | 18 +++++++++++++++--- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/docs/project/contributing.rst b/docs/project/contributing.rst index 020ed7ad8..3988c028a 100644 --- a/docs/project/contributing.rst +++ b/docs/project/contributing.rst @@ -17,8 +17,8 @@ apologies. I know I can mess up. I can't expect you to tell me, but if you choose to do so, I'll do my best to handle criticism constructively. -- Aymeric)* -Contributions -------------- +Contributing +------------ Bug reports, patches and suggestions are welcome! @@ -34,6 +34,28 @@ websockets. .. _issue: https://github.com/python-websockets/websockets/issues/new .. _pull request: https://github.com/python-websockets/websockets/compare/ +Packaging +--------- + +Some distributions package websockets so that it can be installed with the +system package manager rather than with pip, possibly in a virtualenv. + +If you're packaging websockets for a distribution, you must use `releases +published on PyPI`_ as input. You may check `SLSA attestations on GitHub`_. + +.. _releases published on PyPI: https://pypi.org/project/websockets/#files +.. _SLSA attestations on GitHub: https://github.com/python-websockets/websockets/attestations + +You mustn't rely on the git repository as input. Specifically, you mustn't +attempt to run the main test suite. It isn't treated as a deliverable of the +project. It doesn't do what you think it does. It's designed for the needs of +developers, not packagers. + +On a typical build farm for a distribution, tests that exercise timeouts will +fail randomly. Indeed, the test suite is optimized for running very fast, with a +tolerable level of flakiness, on a high-end laptop without noisy neighbors. This +isn't your context. 
+ Questions --------- diff --git a/tests/utils.py b/tests/utils.py index 960439135..639fb7fe5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,6 +9,8 @@ import unittest import warnings +from websockets.version import released + # Generate TLS certificate with: # $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ @@ -39,9 +41,19 @@ DATE = email.utils.formatdate(usegmt=True) -# Unit for timeouts. May be increased on slow machines by setting the -# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * float(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) +# Unit for timeouts. May be increased in slow or noisy environments by setting +# the WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. + +# Downstream distributors insist on running the test suite despites my pleas to +# the contrary. They do it on build farms with unstable performance, leading to +# flakiness, and then they file bugs. Make tests 100x slower to avoid flakiness. + +MS = 0.001 * float( + os.environ.get( + "WEBSOCKETS_TESTS_TIMEOUT_FACTOR", + "100" if released else "1", + ) +) # PyPy, asyncio's debug mode, and coverage penalize performance of this # test suite. Increase timeouts to reduce the risk of spurious failures. From 4f4e64442e84763ca634f67ba0062c3fb8c985c6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 12:46:55 +0200 Subject: [PATCH 04/51] Restore compatibility with Python 3.9. It was broken in 44ccee17. --- src/websockets/sync/messages.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index b96cd6880..997fa98df 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -151,7 +151,8 @@ def get_iter(self) -> Iterator[Data]: chunks = self.chunks self.chunks = [] self.chunks_queue = cast( - queue.SimpleQueue[Data | None], + # Remove quotes around type when dropping Python < 3.10. + "queue.SimpleQueue[Data | None]", queue.SimpleQueue(), ) From ddafc6682a3f6d0bed9d88dfbecd15dd246973bf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 20:25:10 +0200 Subject: [PATCH 05/51] Split "getting support" from "contributing". Also add a page about financial contributions. --- docs/project/contributing.rst | 31 ---------------------- docs/project/index.rst | 4 ++- docs/project/sponsoring.rst | 11 ++++++++ docs/project/support.rst | 49 +++++++++++++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 32 deletions(-) create mode 100644 docs/project/sponsoring.rst create mode 100644 docs/project/support.rst diff --git a/docs/project/contributing.rst b/docs/project/contributing.rst index 3988c028a..6ecd175f8 100644 --- a/docs/project/contributing.rst +++ b/docs/project/contributing.rst @@ -55,34 +55,3 @@ On a typical build farm for a distribution, tests that exercise timeouts will fail randomly. Indeed, the test suite is optimized for running very fast, with a tolerable level of flakiness, on a high-end laptop without noisy neighbors. This isn't your context. - -Questions ---------- - -GitHub issues aren't a good medium for handling questions. There are better -places to ask questions, for example Stack Overflow. - -If you want to ask a question anyway, please make sure that: - -- it's a question about websockets and not about :mod:`asyncio`; -- it isn't answered in the documentation; -- it wasn't asked already. - -A good question can be written as a suggestion to improve the documentation. 
- -Cryptocurrency users --------------------- - -websockets appears to be quite popular for interfacing with Bitcoin or other -cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. - -I'm aware of efforts to build proof-of-stake models. I'll care once the total -energy consumption of all cryptocurrencies drops to a non-bullshit level. - -You already negated all of humanity's efforts to develop renewable energy. -Please stop heating the planet where my children will have to live. - -Since websockets is released under an open-source license, you can use it for -any purpose you like. However, I won't spend any of my time to help you. - -I will summarily close issues related to Bitcoin or cryptocurrency in any way. diff --git a/docs/project/index.rst b/docs/project/index.rst index 459146345..56c98196a 100644 --- a/docs/project/index.rst +++ b/docs/project/index.rst @@ -8,5 +8,7 @@ This is about websockets-the-project rather than websockets-the-software. changelog contributing - license + sponsoring For enterprise + support + license diff --git a/docs/project/sponsoring.rst b/docs/project/sponsoring.rst new file mode 100644 index 000000000..77a4fd1d8 --- /dev/null +++ b/docs/project/sponsoring.rst @@ -0,0 +1,11 @@ +Sponsoring +========== + +You may sponsor the development of websockets through: + +* `GitHub Sponsors`_ +* `Open Collective`_ +* :doc:`Tidelift ` + +.. _GitHub Sponsors: https://github.com/sponsors/python-websockets +.. _Open Collective: https://opencollective.com/websockets diff --git a/docs/project/support.rst b/docs/project/support.rst new file mode 100644 index 000000000..21aad6e02 --- /dev/null +++ b/docs/project/support.rst @@ -0,0 +1,49 @@ +Getting support +=============== + +.. admonition:: There are no free support channels. + :class: tip + + websockets is an open-source project. It's primarily maintained by one + person as a hobby. + + For this reason, the focus is on flawless code and self-service + documentation, not support. + +Enterprise +---------- + +websockets is maintained with high standards, making it suitable for enterprise +use cases. Additional guarantees are available via :doc:`Tidelift `. +If you're using it in a professional setting, consider subscribing. + +Questions +--------- + +GitHub issues aren't a good medium for handling questions. There are better +places to ask questions, for example Stack Overflow. + +If you want to ask a question anyway, please make sure that: + +- it's a question about websockets and not about :mod:`asyncio`; +- it isn't answered in the documentation; +- it wasn't asked already. + +A good question can be written as a suggestion to improve the documentation. + +Cryptocurrency users +-------------------- + +websockets appears to be quite popular for interfacing with Bitcoin or other +cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. + +I'm aware of efforts to build proof-of-stake models. I'll care once the total +energy consumption of all cryptocurrencies drops to a non-bullshit level. + +You already negated all of humanity's efforts to develop renewable energy. +Please stop heating the planet where my children will have to live. + +Since websockets is released under an open-source license, you can use it for +any purpose you like. However, I won't spend any of my time to help you. + +I will summarily close issues related to cryptocurrency in any way. 
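Regarding the Python 3.9 compatibility fix above (commit 04): unlike annotations, which
``from __future__ import annotations`` turns into strings, the first argument of
:func:`typing.cast` is evaluated at runtime, and the PEP 604 union ``Data | None`` only
works at runtime on Python ≥ 3.10. A minimal sketch of the difference; ``Data`` is defined
inline here for the example rather than imported from ``websockets.typing``::

    from __future__ import annotations

    import queue
    from typing import Union, cast

    Data = Union[str, bytes]

    # Evaluated at runtime: ``Data | None`` raises TypeError on Python 3.9
    # because typing aliases only support the ``|`` operator since 3.10.
    # chunks_queue = cast(queue.SimpleQueue[Data | None], queue.SimpleQueue())

    # Quoted form: the string is never evaluated; cast() simply returns its
    # second argument, so this works on Python 3.9 and has no runtime cost.
    chunks_queue = cast("queue.SimpleQueue[Data | None]", queue.SimpleQueue())
    chunks_queue.put("hello")
    print(chunks_queue.get())
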
From a942fcc13f47d9a5bdcd93b6f84da21e1d185e63 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 22:05:54 +0200 Subject: [PATCH 06/51] Switch convenience aliases to the new implementation. --- docs/howto/upgrade.rst | 29 ++++++++++++++--------------- docs/project/changelog.rst | 11 +++++++++++ docs/reference/index.rst | 2 +- src/websockets/__init__.py | 38 ++++++++++++++++++++------------------ 4 files changed, 46 insertions(+), 34 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index f3e42591e..5b1b8e4a2 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -3,8 +3,8 @@ Upgrade to the new :mod:`asyncio` implementation .. currentmodule:: websockets -The new :mod:`asyncio` implementation is a rewrite of the original -implementation of websockets. +The new :mod:`asyncio` implementation, which is now the default, is a rewrite of +the original implementation of websockets. It provides a very similar API. However, there are a few differences. @@ -70,9 +70,8 @@ For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. * The original implementation was moved to the ``websockets.legacy`` package. -* The ``websockets`` package provides aliases for convenience. Currently, they - point to the original implementation. They will be updated to point to the new - implementation when it feels mature. +* The ``websockets`` package provides aliases for convenience. They were + switched to the new implementation in version 14.0. * The ``websockets.client`` and ``websockets.server`` packages provide aliases for backwards-compatibility with earlier versions of websockets. They will be deprecated together with the original implementation. 
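After this switch, the convenience aliases in the top-level package resolve to the new
:mod:`asyncio` implementation, which can be verified from the module they are defined in.
A quick check, assuming a websockets ≥ 14.0 install::

    import websockets

    # Reports the new implementation since 14.0; on 13.x it would report
    # "websockets.legacy.client" instead.
    print(websockets.connect.__module__)  # websockets.asyncio.client

    # Importing from the explicit module avoids any ambiguity across versions.
    from websockets.asyncio.client import connect  # noqa: F401
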
@@ -90,12 +89,12 @@ Client APIs +-------------------------------------------------------------------+-----------------------------------------------------+ | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ -| ``websockets.connect()`` |br| | :func:`websockets.asyncio.client.connect` | -| ``websockets.client.connect()`` |br| | | +| ``websockets.connect()`` *(before 14.0)* |br| | ``websockets.connect()`` *(since 14.0)* |br| | +| ``websockets.client.connect()`` |br| | :func:`websockets.asyncio.client.connect` | | :func:`websockets.legacy.client.connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | -| ``websockets.client.unix_connect()`` |br| | | +| ``websockets.unix_connect()`` *(before 14.0)* |br| | ``websockets.unix_connect()`` *(since 14.0)* |br| | +| ``websockets.client.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | | :func:`websockets.legacy.client.unix_connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | @@ -109,12 +108,12 @@ Server APIs +-------------------------------------------------------------------+-----------------------------------------------------+ | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ -| ``websockets.serve()`` |br| | :func:`websockets.asyncio.server.serve` | -| ``websockets.server.serve()`` |br| | | +| ``websockets.serve()`` *(before 14.0)* |br| | ``websockets.serve()`` *(since 14.0)* |br| | +| ``websockets.server.serve()`` |br| | :func:`websockets.asyncio.server.serve` | | :func:`websockets.legacy.server.serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | -| ``websockets.server.unix_serve()`` |br| | | +| ``websockets.unix_serve()`` *(before 14.0)* |br| | ``websockets.unix_serve()`` *(since 14.0)* |br| | +| ``websockets.server.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | | :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.Server` | @@ -125,8 +124,8 @@ Server APIs | ``websockets.server.WebSocketServerProtocol`` |br| | | | :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.broadcast()`` |br| | :func:`websockets.asyncio.server.broadcast` | -| :func:`websockets.legacy.server.broadcast()` | | +| ``websockets.broadcast()`` *(before 14.0)* |br| | ``websockets.broadcast()`` *(since 14.0)* |br| | +| :func:`websockets.legacy.server.broadcast()` | :func:`websockets.asyncio.server.broadcast` | 
+-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | See below :ref:`how to migrate ` to | | ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | :func:`websockets.asyncio.server.basic_auth`. | diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5f07fc09f..45e4ef0cc 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -40,6 +40,17 @@ Backwards-incompatible changes websockets 13.1 is the last version supporting Python 3.8. +.. admonition:: The new :mod:`asyncio` implementation is now the default. + :class: caution + + The following aliases in the ``websockets`` package were switched to the new + :mod:`asyncio` implementation:: + + from websockets import connect, unix_connext + from websockets import broadcast, serve, unix_serve + + If you're using any of them, then you must follow the :doc:`upgrade guide + <../howto/upgrade>` immediately. .. _13.1: diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 77b538b78..ed2341cc6 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -98,5 +98,5 @@ guarantees of behavior or backwards-compatibility for private APIs. Convenience imports ------------------- -For convenience, many public APIs can be imported directly from the +For convenience, some public APIs can be imported directly from the ``websockets`` package. diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 54591e9fd..036e71c23 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -7,6 +7,14 @@ __all__ = [ + # .asyncio.client + "connect", + "unix_connect", + # .asyncio.server + "basic_auth", + "broadcast", + "serve", + "unix_serve", # .client "ClientProtocol", # .datastructures @@ -41,8 +49,6 @@ "basic_auth_protocol_factory", # .legacy.client "WebSocketClientProtocol", - "connect", - "unix_connect", # .legacy.exceptions "AbortHandshake", "InvalidMessage", @@ -53,9 +59,6 @@ # .legacy.server "WebSocketServer", "WebSocketServerProtocol", - "broadcast", - "serve", - "unix_serve", # .server "ServerProtocol", # .typing @@ -70,6 +73,8 @@ # When type checking, import non-deprecated aliases eagerly. Else, import on demand. 
if typing.TYPE_CHECKING: + from .asyncio.client import connect, unix_connect + from .asyncio.server import basic_auth, broadcast, serve, unix_serve from .client import ClientProtocol from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( @@ -100,7 +105,7 @@ BasicAuthWebSocketServerProtocol, basic_auth_protocol_factory, ) - from .legacy.client import WebSocketClientProtocol, connect, unix_connect + from .legacy.client import WebSocketClientProtocol from .legacy.exceptions import ( AbortHandshake, InvalidMessage, @@ -108,13 +113,7 @@ RedirectHandshake, ) from .legacy.protocol import WebSocketCommonProtocol - from .legacy.server import ( - WebSocketServer, - WebSocketServerProtocol, - broadcast, - serve, - unix_serve, - ) + from .legacy.server import WebSocketServer, WebSocketServerProtocol from .server import ServerProtocol from .typing import ( Data, @@ -129,6 +128,14 @@ lazy_import( globals(), aliases={ + # .asyncio.client + "connect": ".asyncio.client", + "unix_connect": ".asyncio.client", + # .asyncio.server + "basic_auth": ".asyncio.server", + "broadcast": ".asyncio.server", + "serve": ".asyncio.server", + "unix_serve": ".asyncio.server", # .client "ClientProtocol": ".client", # .datastructures @@ -163,8 +170,6 @@ "basic_auth_protocol_factory": ".legacy.auth", # .legacy.client "WebSocketClientProtocol": ".legacy.client", - "connect": ".legacy.client", - "unix_connect": ".legacy.client", # .legacy.exceptions "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", @@ -175,9 +180,6 @@ # .legacy.server "WebSocketServer": ".legacy.server", "WebSocketServerProtocol": ".legacy.server", - "broadcast": ".legacy.server", - "serve": ".legacy.server", - "unix_serve": ".legacy.server", # .server "ServerProtocol": ".server", # .typing From 8d055ebf383520f46c1f0b6d40742e3ef8c3d723 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 22:40:09 +0200 Subject: [PATCH 07/51] Deprecate aliases pointing to the legacy implementation. --- docs/howto/upgrade.rst | 3 +- docs/project/changelog.rst | 6 +++ src/websockets/__init__.py | 63 ++++++++---------------------- src/websockets/auth.py | 10 ++++- src/websockets/client.py | 20 ++++++---- src/websockets/exceptions.py | 41 ++++++------------- src/websockets/server.py | 22 +++++++---- tests/legacy/test_auth.py | 2 +- tests/legacy/test_client_server.py | 2 +- tests/test_exports.py | 30 ++++++++++---- 10 files changed, 99 insertions(+), 100 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 5b1b8e4a2..bdaefd768 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -71,7 +71,8 @@ For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. * The original implementation was moved to the ``websockets.legacy`` package. * The ``websockets`` package provides aliases for convenience. They were - switched to the new implementation in version 14.0. + switched to the new implementation in version 14.0 or deprecated when there + isn't an equivalent API. * The ``websockets.client`` and ``websockets.server`` packages provide aliases for backwards-compatibility with earlier versions of websockets. They will be deprecated together with the original implementation. 
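With the legacy aliases now registered as deprecated (see the ``lazy_import`` changes further
down in this patch), accessing them should emit a :exc:`DeprecationWarning`. A small sketch of
how a test suite could assert on that, assuming a websockets ≥ 14.0 install and that the
warning is raised on attribute access as the diff suggests::

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always", DeprecationWarning)
        from websockets import WebSocketServerProtocol  # noqa: F401

    assert any(
        issubclass(warning.category, DeprecationWarning) for warning in caught
    )
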
diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 45e4ef0cc..30e89a69c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -52,6 +52,12 @@ Backwards-incompatible changes If you're using any of them, then you must follow the :doc:`upgrade guide <../howto/upgrade>` immediately. +.. admonition:: The legacy :mod:`asyncio` implementation is now deprecated. + :class: caution + + Aliases for deprecated API were removed from ``__all__``. As a consequence, + they cannot be imported e.g. with ``from websockets import *`` anymore. + .. _13.1: 13.1 diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 036e71c23..531ce49f7 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -43,22 +43,6 @@ "ProtocolError", "SecurityError", "WebSocketException", - "WebSocketProtocolError", - # .legacy.auth - "BasicAuthWebSocketServerProtocol", - "basic_auth_protocol_factory", - # .legacy.client - "WebSocketClientProtocol", - # .legacy.exceptions - "AbortHandshake", - "InvalidMessage", - "InvalidStatusCode", - "RedirectHandshake", - # .legacy.protocol - "WebSocketCommonProtocol", - # .legacy.server - "WebSocketServer", - "WebSocketServerProtocol", # .server "ServerProtocol", # .typing @@ -99,21 +83,7 @@ ProtocolError, SecurityError, WebSocketException, - WebSocketProtocolError, ) - from .legacy.auth import ( - BasicAuthWebSocketServerProtocol, - basic_auth_protocol_factory, - ) - from .legacy.client import WebSocketClientProtocol - from .legacy.exceptions import ( - AbortHandshake, - InvalidMessage, - InvalidStatusCode, - RedirectHandshake, - ) - from .legacy.protocol import WebSocketCommonProtocol - from .legacy.server import WebSocketServer, WebSocketServerProtocol from .server import ServerProtocol from .typing import ( Data, @@ -164,22 +134,6 @@ "ProtocolError": ".exceptions", "SecurityError": ".exceptions", "WebSocketException": ".exceptions", - "WebSocketProtocolError": ".exceptions", - # .legacy.auth - "BasicAuthWebSocketServerProtocol": ".legacy.auth", - "basic_auth_protocol_factory": ".legacy.auth", - # .legacy.client - "WebSocketClientProtocol": ".legacy.client", - # .legacy.exceptions - "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", - "InvalidStatusCode": ".legacy.exceptions", - "RedirectHandshake": ".legacy.exceptions", - # .legacy.protocol - "WebSocketCommonProtocol": ".legacy.protocol", - # .legacy.server - "WebSocketServer": ".legacy.server", - "WebSocketServerProtocol": ".legacy.server", # .server "ServerProtocol": ".server", # .typing @@ -197,5 +151,22 @@ "handshake": ".legacy", "parse_uri": ".uri", "WebSocketURI": ".uri", + # deprecated in 14.0 + # .legacy.auth + "BasicAuthWebSocketServerProtocol": ".legacy.auth", + "basic_auth_protocol_factory": ".legacy.auth", + # .legacy.client + "WebSocketClientProtocol": ".legacy.client", + # .legacy.exceptions + "AbortHandshake": ".legacy.exceptions", + "InvalidMessage": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", + # .legacy.protocol + "WebSocketCommonProtocol": ".legacy.protocol", + # .legacy.server + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", }, ) diff --git a/src/websockets/auth.py b/src/websockets/auth.py index b792e02f5..1e0002cee 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -1,6 +1,12 @@ from __future__ import annotations -# See #940 for why 
lazy_import isn't used here for backwards compatibility. -# See #1400 for why listing compatibility imports in __all__ helps PyCharm. +import warnings + from .legacy.auth import * from .legacy.auth import __all__ # noqa: F401 + + +warnings.warn( # deprecated in 14.0 + "websockets.auth is deprecated", + DeprecationWarning, +) diff --git a/src/websockets/client.py b/src/websockets/client.py index bce82d66b..8b66900a8 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -27,6 +27,7 @@ parse_upgrade, ) from .http11 import Request, Response +from .imports import lazy_import from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State from .typing import ( ConnectionOption, @@ -40,13 +41,7 @@ from .utils import accept_key, generate_key -# See #940 for why lazy_import isn't used here for backwards compatibility. -# See #1400 for why listing compatibility imports in __all__ helps PyCharm. -from .legacy.client import * # isort:skip # noqa: I001 -from .legacy.client import __all__ as legacy__all__ - - -__all__ = ["ClientProtocol"] + legacy__all__ +__all__ = ["ClientProtocol"] class ClientProtocol(Protocol): @@ -392,3 +387,14 @@ def backoff( delay *= factor while True: yield max_delay + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 + "WebSocketClientProtocol": ".legacy.client", + "connect": ".legacy.client", + "unix_connect": ".legacy.client", + }, +) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index d723f2fec..7681736a4 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -31,7 +31,6 @@ from __future__ import annotations -import typing import warnings from .imports import lazy_import @@ -45,9 +44,7 @@ "InvalidURI", "InvalidHandshake", "SecurityError", - "InvalidMessage", "InvalidStatus", - "InvalidStatusCode", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", @@ -57,10 +54,7 @@ "DuplicateParameter", "InvalidParameterName", "InvalidParameterValue", - "AbortHandshake", - "RedirectHandshake", "ProtocolError", - "WebSocketProtocolError", "PayloadTooBig", "InvalidState", "ConcurrencyError", @@ -366,27 +360,18 @@ class ConcurrencyError(WebSocketException, RuntimeError): """ -# When type checking, import non-deprecated aliases eagerly. Else, import on demand. -if typing.TYPE_CHECKING: - from .legacy.exceptions import ( - AbortHandshake, - InvalidMessage, - InvalidStatusCode, - RedirectHandshake, - ) - - WebSocketProtocolError = ProtocolError -else: - lazy_import( - globals(), - aliases={ - "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", - "InvalidStatusCode": ".legacy.exceptions", - "RedirectHandshake": ".legacy.exceptions", - "WebSocketProtocolError": ".legacy.exceptions", - }, - ) - # At the bottom to break import cycles created by type annotations. from . 
import frames, http11 # noqa: E402 + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 + "AbortHandshake": ".legacy.exceptions", + "InvalidMessage": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", + }, +) diff --git a/src/websockets/server.py b/src/websockets/server.py index 9fe970619..527db8990 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -27,6 +27,7 @@ parse_upgrade, ) from .http11 import Request, Response +from .imports import lazy_import from .protocol import CONNECTING, OPEN, SERVER, Protocol, State from .typing import ( ConnectionOption, @@ -40,13 +41,7 @@ from .utils import accept_key -# See #940 for why lazy_import isn't used here for backwards compatibility. -# See #1400 for why listing compatibility imports in __all__ helps PyCharm. -from .legacy.server import * # isort:skip # noqa: I001 -from .legacy.server import __all__ as legacy__all__ - - -__all__ = ["ServerProtocol"] + legacy__all__ +__all__ = ["ServerProtocol"] class ServerProtocol(Protocol): @@ -586,3 +581,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: DeprecationWarning, ) super().__init__(*args, **kwargs) + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", + "broadcast": ".legacy.server", + "serve": ".legacy.server", + "unix_serve": ".legacy.server", + }, +) diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py index 3754bcf3a..dabd4212a 100644 --- a/tests/legacy/test_auth.py +++ b/tests/legacy/test_auth.py @@ -2,10 +2,10 @@ import unittest import urllib.error -from websockets.exceptions import InvalidStatusCode from websockets.headers import build_authorization_basic from websockets.legacy.auth import * from websockets.legacy.auth import is_credentials +from websockets.legacy.exceptions import InvalidStatusCode from .test_client_server import ClientServerTestsMixin, with_client, with_server from .utils import AsyncioTestCase diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 2f3ba9b77..502ab68e7 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -21,7 +21,6 @@ ConnectionClosed, InvalidHandshake, InvalidHeader, - InvalidStatusCode, NegotiationError, ) from websockets.extensions.permessage_deflate import ( @@ -32,6 +31,7 @@ from websockets.frames import CloseCode from websockets.http11 import USER_AGENT from websockets.legacy.client import * +from websockets.legacy.exceptions import InvalidStatusCode from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * diff --git a/tests/test_exports.py b/tests/test_exports.py index 67a1a6f99..93b0684f7 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -1,30 +1,46 @@ import unittest import websockets -import websockets.auth +import websockets.asyncio.client +import websockets.asyncio.server import websockets.client import websockets.datastructures import websockets.exceptions -import websockets.legacy.protocol import websockets.server import websockets.typing import websockets.uri combined_exports = ( - websockets.auth.__all__ + [] + + websockets.asyncio.client.__all__ + + websockets.asyncio.server.__all__ + websockets.client.__all__ + websockets.datastructures.__all__ + websockets.exceptions.__all__ - + 
websockets.legacy.protocol.__all__ + websockets.server.__all__ + websockets.typing.__all__ ) +# These API are intentionally not re-exported by the top-level module. +missing_reexports = [ + # websockets.asyncio.client + "ClientConnection", + # websockets.asyncio.server + "ServerConnection", + "Server", +] + class ExportsTests(unittest.TestCase): - def test_top_level_module_reexports_all_submodule_exports(self): - self.assertEqual(set(combined_exports), set(websockets.__all__)) + def test_top_level_module_reexports_submodule_exports(self): + self.assertEqual( + set(combined_exports), + set(websockets.__all__ + missing_reexports), + ) def test_submodule_exports_are_globally_unique(self): - self.assertEqual(len(set(combined_exports)), len(combined_exports)) + self.assertEqual( + len(set(combined_exports)), + len(combined_exports), + ) From d62e423744b3e20ef6405019f96a63faa21612a2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 12:10:15 +0200 Subject: [PATCH 08/51] Deprecate the legacy asyncio implementation. --- docs/howto/upgrade.rst | 21 ++++++--------- docs/index.rst | 43 ++++++++++++++++--------------- docs/project/changelog.rst | 3 +++ docs/reference/index.rst | 30 +++++++++++---------- docs/reference/legacy/client.rst | 6 +++++ docs/reference/legacy/common.rst | 6 +++++ docs/reference/legacy/server.rst | 6 +++++ docs/topics/design.rst | 2 ++ docs/topics/index.rst | 1 - src/websockets/auth.py | 12 ++++++--- src/websockets/http.py | 7 ++++- src/websockets/legacy/__init__.py | 11 ++++++++ tests/legacy/__init__.py | 9 +++++++ tests/maxi_cov.py | 2 -- tests/test_auth.py | 14 ++++++++++ 15 files changed, 118 insertions(+), 55 deletions(-) create mode 100644 tests/test_auth.py diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index bdaefd768..02d4c6f01 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -27,15 +27,9 @@ respectively. .. admonition:: What will happen to the original implementation? :class: hint - The original implementation is now considered legacy. - - The next steps are: - - 1. Deprecating it once the new implementation is considered sufficiently - robust. - 2. Maintaining it for five years per the :ref:`backwards-compatibility - policy `. - 3. Removing it. This is expected to happen around 2030. + The original implementation is deprecated. It will be maintained for five + years after deprecation according to the :ref:`backwards-compatibility + policy `. Then, by 2030, it will be removed. .. _deprecated APIs: @@ -69,13 +63,14 @@ Import paths For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. -* The original implementation was moved to the ``websockets.legacy`` package. +* The original implementation was moved to the ``websockets.legacy`` package + and deprecated. * The ``websockets`` package provides aliases for convenience. They were switched to the new implementation in version 14.0 or deprecated when there - isn't an equivalent API. + wasn't an equivalent API. * The ``websockets.client`` and ``websockets.server`` packages provide aliases - for backwards-compatibility with earlier versions of websockets. They will - be deprecated together with the original implementation. + for backwards-compatibility with earlier versions of websockets. They were + deprecated. To upgrade to the new :mod:`asyncio` implementation, change import paths as shown in the tables below. 
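As a companion to the upgrade tables referenced above, a minimal client after the upgrade,
using the new :mod:`asyncio` implementation directly; the URI is a placeholder and the server
side is assumed to exist::

    import asyncio

    from websockets.asyncio.client import connect


    async def hello() -> None:
        # Like the legacy API, connect() works as an async context manager.
        async with connect("ws://localhost:8765") as websocket:
            await websocket.send("Hello!")
            print(await websocket.recv())


    if __name__ == "__main__":
        asyncio.run(hello())
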
diff --git a/docs/index.rst b/docs/index.rst index b8cd300e3..f9576f2dc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,30 +28,13 @@ with a focus on correctness, simplicity, robustness, and performance. It supports several network I/O and control flow paradigms. -1. The primary implementation builds upon :mod:`asyncio`, Python's standard - asynchronous I/O framework. It provides an elegant coroutine-based API. It's - ideal for servers that handle many clients concurrently. - - .. admonition:: As of version :ref:`13.0`, there is a new :mod:`asyncio` - implementation. - :class: important - - The historical implementation in ``websockets.legacy`` traces its roots to - early versions of websockets. Although it's stable and robust, it is now - considered legacy. - - The new implementation in ``websockets.asyncio`` is a rewrite on top of - the Sans-I/O implementation. It adds a few features that were impossible - to implement within the original design. - - The new implementation provides all features of the historical - implementation, and a few more. If you're using the historical - implementation, you should :doc:`ugrade to the new implementation - `. It's usually straightforward. +1. The default implementation builds upon :mod:`asyncio`, Python's built-in + asynchronous I/O library. It provides an elegant coroutine-based API. It's + ideal for servers that handle many client connections. 2. The :mod:`threading` implementation is a good alternative for clients, especially if you aren't familiar with :mod:`asyncio`. It may also be used - for servers that don't need to serve many clients. + for servers that handle few client connections. 3. The `Sans-I/O`_ implementation is designed for integrating in third-party libraries, typically application servers, in addition being used internally @@ -59,6 +42,24 @@ It supports several network I/O and control flow paradigms. .. _Sans-I/O: https://sans-io.readthedocs.io/ +Refer to the :doc:`feature support matrices ` for the full +list of features provided by each implementation. + +.. admonition:: The :mod:`asyncio` implementation was rewritten. + :class: tip + + The new implementation in ``websockets.asyncio`` builds upon the Sans-I/O + implementation. It adds features that were impossible to provide in the + original design. It was introduced in version 13.0. + + The historical implementation in ``websockets.legacy`` traces its roots to + early versions of websockets. While it's stable and robust, it was deprecated + in version 14.0 and it will be removed by 2030. + + The new implementation provides the same features as the historical + implementation, and then some. If you're using the historical implementation, + you should :doc:`ugrade to the new implementation `. + Here's an echo server using the :mod:`asyncio` API: .. literalinclude:: ../example/echo.py diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 30e89a69c..c8d854ba4 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -55,6 +55,9 @@ Backwards-incompatible changes .. admonition:: The legacy :mod:`asyncio` implementation is now deprecated. :class: caution + The :doc:`upgrade guide <../howto/upgrade>` provides complete instructions + to migrate your application. + Aliases for deprecated API were removed from ``__all__``. As a consequence, they cannot be imported e.g. with ``from websockets import *`` anymore. 
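The practical effect of trimming ``__all__``: star imports no longer pull in the legacy
names, while the switched convenience aliases remain exported. A quick check against a
14.0 install; the specific names below come from the ``__init__.py`` changes in this patch::

    import websockets

    # Legacy aliases are gone from __all__, so ``from websockets import *``
    # no longer imports them...
    assert "WebSocketServerProtocol" not in websockets.__all__

    # ...but the aliases that were switched to the new asyncio implementation
    # are still exported.
    assert "serve" in websockets.__all__
    assert "connect" in websockets.__all__
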
diff --git a/docs/reference/index.rst b/docs/reference/index.rst index ed2341cc6..c78a3c095 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -13,12 +13,12 @@ Check which implementations support which features and known limitations. features -:mod:`asyncio` (new) --------------------- +:mod:`asyncio` +-------------- It's ideal for servers that handle many clients concurrently. -It's a rewrite of the legacy :mod:`asyncio` implementation. +This is the default implementation. .. toctree:: :titlesonly: @@ -26,17 +26,6 @@ It's a rewrite of the legacy :mod:`asyncio` implementation. asyncio/server asyncio/client -:mod:`asyncio` (legacy) ------------------------ - -This is the historical implementation. - -.. toctree:: - :titlesonly: - - legacy/server - legacy/client - :mod:`threading` ---------------- @@ -62,6 +51,19 @@ application servers. sansio/server sansio/client +:mod:`asyncio` (legacy) +----------------------- + +This is the historical implementation. + +It is deprecated and will be removed. + +.. toctree:: + :titlesonly: + + legacy/server + legacy/client + Extensions ---------- diff --git a/docs/reference/legacy/client.rst b/docs/reference/legacy/client.rst index fca45d218..a798409f0 100644 --- a/docs/reference/legacy/client.rst +++ b/docs/reference/legacy/client.rst @@ -1,6 +1,12 @@ Client (legacy :mod:`asyncio`) ============================== +.. admonition:: The legacy :mod:`asyncio` implementation is deprecated. + :class: caution + + The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions + to migrate your application. + .. automodule:: websockets.legacy.client Opening a connection diff --git a/docs/reference/legacy/common.rst b/docs/reference/legacy/common.rst index aee774479..45c56fccd 100644 --- a/docs/reference/legacy/common.rst +++ b/docs/reference/legacy/common.rst @@ -3,6 +3,12 @@ Both sides (legacy :mod:`asyncio`) ================================== +.. admonition:: The legacy :mod:`asyncio` implementation is deprecated. + :class: caution + + The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions + to migrate your application. + .. automodule:: websockets.legacy.protocol .. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst index b6c383ce7..3c1d19fc6 100644 --- a/docs/reference/legacy/server.rst +++ b/docs/reference/legacy/server.rst @@ -1,6 +1,12 @@ Server (legacy :mod:`asyncio`) ============================== +.. admonition:: The legacy :mod:`asyncio` implementation is deprecated. + :class: caution + + The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions + to migrate your application. + .. automodule:: websockets.legacy.server Starting a server diff --git a/docs/topics/design.rst b/docs/topics/design.rst index d2fd18d0c..b73ace517 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -1,3 +1,5 @@ +:orphan: + Design (legacy :mod:`asyncio`) ============================== diff --git a/docs/topics/index.rst b/docs/topics/index.rst index a2b8ca879..616753c6c 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -12,7 +12,6 @@ Get a deeper understanding of how websockets is built and why. 
broadcast compression keepalive - design memory security performance diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 1e0002cee..98e62af3c 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -2,11 +2,17 @@ import warnings -from .legacy.auth import * -from .legacy.auth import __all__ # noqa: F401 + +with warnings.catch_warnings(): + # Suppress redundant DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + from .legacy.auth import * + from .legacy.auth import __all__ # noqa: F401 warnings.warn( # deprecated in 14.0 - "websockets.auth is deprecated", + "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", DeprecationWarning, ) diff --git a/src/websockets/http.py b/src/websockets/http.py index 0ff5598c7..0d860e537 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -3,7 +3,12 @@ import warnings from .datastructures import Headers, MultipleValuesError # noqa: F401 -from .legacy.http import read_request, read_response # noqa: F401 + + +with warnings.catch_warnings(): + # Suppress redundant DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + from .legacy.http import read_request, read_response # noqa: F401 warnings.warn( # deprecated in 9.0 - 2021-09-01 diff --git a/src/websockets/legacy/__init__.py b/src/websockets/legacy/__init__.py index e69de29bb..84f870f3a 100644 --- a/src/websockets/legacy/__init__.py +++ b/src/websockets/legacy/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import warnings + + +warnings.warn( # deprecated in 14.0 + "websockets.legacy is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", + DeprecationWarning, +) diff --git a/tests/legacy/__init__.py b/tests/legacy/__init__.py index e69de29bb..035834a89 100644 --- a/tests/legacy/__init__.py +++ b/tests/legacy/__init__.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +import warnings + + +with warnings.catch_warnings(): + # Suppress DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + import websockets.legacy # noqa: F401 diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index 83686c3d3..8ccef7d39 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -9,7 +9,6 @@ UNMAPPED_SRC_FILES = [ - "websockets/auth.py", "websockets/typing.py", "websockets/version.py", ] @@ -105,7 +104,6 @@ def get_ignored_files(src_dir="src"): # or websockets (import locations). "*/websockets/asyncio/async_timeout.py", "*/websockets/asyncio/compatibility.py", - "*/websockets/auth.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. 
"*/websockets/legacy/*", diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 000000000..16c00c1b9 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,14 @@ +from .utils import DeprecationTestCase + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_headers_class(self): + with self.assertDeprecationWarning( + "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", + ): + from websockets.auth import ( + BasicAuthWebSocketServerProtocol, # noqa: F401 + basic_auth_protocol_factory, # noqa: F401 + ) From a0b20f081d7ae48409c6b79ed16bbc261d5109f9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 21:16:26 +0200 Subject: [PATCH 09/51] Document that only asyncio supports keepalive. Fix #1508. --- docs/topics/keepalive.rst | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 91f11fb11..458fa3d05 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -1,6 +1,11 @@ Keepalive and latency ===================== +.. admonition:: This guide applies only to the :mod:`asyncio` implementation. + :class: tip + + The :mod:`threading` implementation doesn't provide keepalive yet. + .. currentmodule:: websockets Long-lived connections @@ -31,14 +36,8 @@ based on WebSocket Ping_ and Pong_ frames, which are designed for this purpose. .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 -It loops through these steps: - -1. Wait 20 seconds. -2. Send a Ping frame. -3. Receive a corresponding Pong frame within 20 seconds. - -If the Pong frame isn't received, websockets considers the connection broken and -closes it. +It sends a Ping frame every 20 seconds. It expects a Pong frame in return within +20 seconds. Else, it considers the connection broken and terminates it. This mechanism serves three purposes: From baadc33364131ff7236249c9077ae10b561395b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 23 Sep 2024 23:52:25 +0200 Subject: [PATCH 10/51] Profile and optimize the permessage-deflate extension. dataclasses.replace is surprisingly expensive. zlib functions make up the bulk of the cost now. --- experiments/compression/corpus.py | 2 +- experiments/profiling/compression.py | 45 +++++++++++++++++++ .../extensions/permessage_deflate.py | 29 +++++++++--- 3 files changed, 68 insertions(+), 8 deletions(-) create mode 100644 experiments/profiling/compression.py diff --git a/experiments/compression/corpus.py b/experiments/compression/corpus.py index da5661dfa..56e262114 100644 --- a/experiments/compression/corpus.py +++ b/experiments/compression/corpus.py @@ -47,6 +47,6 @@ def main(corpus): if __name__ == "__main__": if len(sys.argv) < 2: - print(f"Usage: {sys.argv[0]} [directory]") + print(f"Usage: {sys.argv[0]} ") sys.exit(2) main(pathlib.Path(sys.argv[1])) diff --git a/experiments/profiling/compression.py b/experiments/profiling/compression.py new file mode 100644 index 000000000..1ece1f10e --- /dev/null +++ b/experiments/profiling/compression.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python + +""" +Profile the permessage-deflate extension. 
+ +Usage:: + $ pip install line_profiler + $ python experiments/compression/corpus.py experiments/compression/corpus + $ PYTHONPATH=src python -m kernprof \ + --line-by-line \ + --prof-mod src/websockets/extensions/permessage_deflate.py \ + --view \ + experiments/profiling/compression.py experiments/compression/corpus 12 5 6 + +""" + +import pathlib +import sys + +from websockets.extensions.permessage_deflate import PerMessageDeflate +from websockets.frames import OP_TEXT, Frame + + +def compress_and_decompress(corpus, max_window_bits, memory_level, level): + extension = PerMessageDeflate( + remote_no_context_takeover=False, + local_no_context_takeover=False, + remote_max_window_bits=max_window_bits, + local_max_window_bits=max_window_bits, + compress_settings={"memLevel": memory_level, "level": level}, + ) + for data in corpus: + frame = Frame(OP_TEXT, data) + frame = extension.encode(frame) + frame = extension.decode(frame) + + +if __name__ == "__main__": + if len(sys.argv) < 2 or not pathlib.Path(sys.argv[1]).is_dir(): + print(f"Usage: {sys.argv[0]} [] []") + corpus = [file.read_bytes() for file in pathlib.Path(sys.argv[1]).iterdir()] + max_window_bits = int(sys.argv[2]) if len(sys.argv) > 2 else 12 + memory_level = int(sys.argv[3]) if len(sys.argv) > 3 else 5 + level = int(sys.argv[4]) if len(sys.argv) > 4 else 6 + compress_and_decompress(corpus, max_window_bits, memory_level, level) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index f962b65fb..21df804fd 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import zlib from collections.abc import Sequence from typing import Any @@ -120,7 +119,6 @@ def decode( else: if not frame.rsv1: return frame - frame = dataclasses.replace(frame, rsv1=False) if not frame.fin: self.decode_cont_data = True @@ -146,7 +144,15 @@ def decode( if frame.fin and self.remote_no_context_takeover: del self.decoder - return dataclasses.replace(frame, data=data) + return frames.Frame( + frame.opcode, + data, + frame.fin, + # Unset the rsv1 flag on the first frame of a compressed message. + False, + frame.rsv2, + frame.rsv3, + ) def encode(self, frame: frames.Frame) -> frames.Frame: """ @@ -161,8 +167,6 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # data" flag similar to "decode continuation data" at this time. if frame.opcode is not frames.OP_CONT: - # Set the rsv1 flag on the first frame of a compressed message. - frame = dataclasses.replace(frame, rsv1=True) # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( @@ -172,14 +176,25 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # Compress data. data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) - if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK): + if frame.fin and data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK: + # Making a copy is faster than memoryview(a)[:-4] until about 2kB. + # On larger messages, it's slower but profiling shows that it's + # marginal compared to compress() and flush(). Keep it simple. data = data[:-4] # Allow garbage collection of the encoder if it won't be reused. 
if frame.fin and self.local_no_context_takeover: del self.encoder - return dataclasses.replace(frame, data=data) + return frames.Frame( + frame.opcode, + data, + frame.fin, + # Set the rsv1 flag on the first frame of a compressed message. + frame.opcode is not frames.OP_CONT, + frame.rsv2, + frame.rsv3, + ) def _build_parameters( From 524dd4afa8dc2a0d17ab08f20f592af03c165db1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 26 Sep 2024 23:11:28 +0200 Subject: [PATCH 11/51] Avoid making a copy of large frames. This isn't very significant compared to the cost of compression. It can make a real difference for decompression. --- docs/project/changelog.rst | 14 +++++++++ .../extensions/permessage_deflate.py | 30 ++++++++++++------- src/websockets/frames.py | 8 ++--- tests/extensions/test_permessage_deflate.py | 11 +++++++ 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c8d854ba4..f5b4812bd 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -61,6 +61,20 @@ Backwards-incompatible changes Aliases for deprecated API were removed from ``__all__``. As a consequence, they cannot be imported e.g. with ``from websockets import *`` anymore. +.. admonition:: :attr:`Frame.data ` is now a bytes-like object. + :class: note + + In addition to :class:`bytes`, it may be a :class:`bytearray` or a + :class:`memoryview`. + + If you wrote an :class:`extension ` that relies on + methods not provided by these new types, you may need to update your code. + +Improvements +............ + +* Sending or receiving large compressed frames is now faster. + .. _13.1: 13.1 diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 21df804fd..ed16937d8 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -129,16 +129,22 @@ def decode( # Uncompress data. Protect against zip bombs by preventing zlib from # decompressing more than max_length bytes (except when the limit is # disabled with max_size = None). - data = frame.data - if frame.fin: - data += _EMPTY_UNCOMPRESSED_BLOCK + if frame.fin and len(frame.data) < 2044: + # Profiling shows that appending four bytes, which makes a copy, is + # faster than calling decompress() again when data is less than 2kB. + data = bytes(frame.data) + _EMPTY_UNCOMPRESSED_BLOCK + else: + data = frame.data max_length = 0 if max_size is None else max_size try: data = self.decoder.decompress(data, max_length) + if self.decoder.unconsumed_tail: + raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") + if frame.fin and len(frame.data) >= 2044: + # This cannot generate additional data. + self.decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK) except zlib.error as exc: raise ProtocolError("decompression failed") from exc - if self.decoder.unconsumed_tail: - raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: @@ -176,11 +182,15 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # Compress data. data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) - if frame.fin and data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK: - # Making a copy is faster than memoryview(a)[:-4] until about 2kB. - # On larger messages, it's slower but profiling shows that it's - # marginal compared to compress() and flush(). Keep it simple. 
- data = data[:-4] + if frame.fin: + # Sync flush generates between 5 or 6 bytes, ending with the bytes + # 0x00 0x00 0xff 0xff, which must be removed. + assert data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK + # Making a copy is faster than memoryview(a)[:-4] until 2kB. + if len(data) < 2048: + data = data[:-4] + else: + data = memoryview(data)[:-4] # Allow garbage collection of the encoder if it won't be reused. if frame.fin and self.local_no_context_takeover: diff --git a/src/websockets/frames.py b/src/websockets/frames.py index dace2c902..5fadf3c2d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -7,7 +7,7 @@ import secrets import struct from collections.abc import Generator, Sequence -from typing import Callable +from typing import Callable, Union from .exceptions import PayloadTooBig, ProtocolError @@ -139,7 +139,7 @@ class Frame: """ opcode: Opcode - data: bytes + data: Union[bytes, bytearray, memoryview] fin: bool = True rsv1: bool = False rsv2: bool = False @@ -160,7 +160,7 @@ def __str__(self) -> str: if self.opcode is OP_TEXT: # Decoding only the beginning and the end is needlessly hard. # Decode the entire payload then elide later if necessary. - data = repr(self.data.decode()) + data = repr(bytes(self.data).decode()) elif self.opcode is OP_BINARY: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. @@ -178,7 +178,7 @@ def __str__(self) -> str: # binary. If self.data is a memoryview, it has no decode() method, # which raises AttributeError. try: - data = repr(self.data.decode()) + data = repr(bytes(self.data).decode()) coding = "text" except (UnicodeDecodeError, AttributeError): binary = self.data diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index ee09813c4..76cd48623 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -1,4 +1,5 @@ import dataclasses +import os import unittest from websockets.exceptions import ( @@ -167,6 +168,16 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame1, frame1) self.assertEqual(dec_frame2, frame2) + def test_encode_decode_large_frame(self): + # There is a separate code path that avoids copying data + # when frames are larger than 2kB. Test it for coverage. + frame = Frame(OP_BINARY, os.urandom(4096)) + + enc_frame = self.extension.encode(frame) + dec_frame = self.extension.decode(enc_frame) + + self.assertEqual(dec_frame, frame) + def test_no_decode_text_frame(self): frame = Frame(OP_TEXT, "café".encode()) From 07dc56443acd8aaac4d79db6a30e450d0073137d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 27 Sep 2024 21:44:36 +0200 Subject: [PATCH 12/51] Add asyncio & threading examples to the homepage. Fix #1437. --- docs/index.rst | 19 +++++++++++++++---- example/{ => asyncio}/echo.py | 12 +++++++++--- example/asyncio/hello.py | 17 +++++++++++++++++ example/ruff.toml | 2 ++ example/sync/echo.py | 19 +++++++++++++++++++ example/{ => sync}/hello.py | 10 +++++++--- 6 files changed, 69 insertions(+), 10 deletions(-) rename example/{ => asyncio}/echo.py (51%) create mode 100755 example/asyncio/hello.py create mode 100644 example/ruff.toml create mode 100755 example/sync/echo.py rename example/{ => sync}/hello.py (66%) diff --git a/docs/index.rst b/docs/index.rst index f9576f2dc..de14fa2d0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -60,13 +60,24 @@ list of features provided by each implementation. 
implementation, and then some. If you're using the historical implementation, you should :doc:`ugrade to the new implementation `. -Here's an echo server using the :mod:`asyncio` API: +Here's an echo server and corresponding client. -.. literalinclude:: ../example/echo.py +.. tab:: asyncio -Here's a client using the :mod:`threading` API: + .. literalinclude:: ../example/asyncio/echo.py -.. literalinclude:: ../example/hello.py +.. tab:: threading + + .. literalinclude:: ../example/sync/echo.py + +.. tab:: asyncio + :new-set: + + .. literalinclude:: ../example/asyncio/hello.py + +.. tab:: threading + + .. literalinclude:: ../example/sync/hello.py Don't worry about the opening and closing handshakes, pings and pongs, or any other behavior described in the WebSocket specification. websockets takes care diff --git a/example/echo.py b/example/asyncio/echo.py similarity index 51% rename from example/echo.py rename to example/asyncio/echo.py index b952a5cfb..28d877be7 100755 --- a/example/echo.py +++ b/example/asyncio/echo.py @@ -1,14 +1,20 @@ #!/usr/bin/env python +"""Echo server using the asyncio API.""" + import asyncio from websockets.asyncio.server import serve + async def echo(websocket): async for message in websocket: await websocket.send(message) + async def main(): - async with serve(echo, "localhost", 8765): - await asyncio.get_running_loop().create_future() # run forever + async with serve(echo, "localhost", 8765) as server: + await server.serve_forever() + -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/asyncio/hello.py b/example/asyncio/hello.py new file mode 100755 index 000000000..6e4518497 --- /dev/null +++ b/example/asyncio/hello.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +"""Client using the asyncio API.""" + +import asyncio +from websockets.asyncio.client import connect + + +async def hello(): + async with connect("ws://localhost:8765") as websocket: + await websocket.send("Hello world!") + message = await websocket.recv() + print(message) + + +if __name__ == "__main__": + asyncio.run(hello()) diff --git a/example/ruff.toml b/example/ruff.toml new file mode 100644 index 000000000..13ae36c08 --- /dev/null +++ b/example/ruff.toml @@ -0,0 +1,2 @@ +[lint.isort] +no-sections = true diff --git a/example/sync/echo.py b/example/sync/echo.py new file mode 100755 index 000000000..4b47db1ba --- /dev/null +++ b/example/sync/echo.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python + +"""Echo server using the threading API.""" + +from websockets.sync.server import serve + + +def echo(websocket): + for message in websocket: + websocket.send(message) + + +def main(): + with serve(echo, "localhost", 8765) as server: + server.serve_forever() + + +if __name__ == "__main__": + main() diff --git a/example/hello.py b/example/sync/hello.py similarity index 66% rename from example/hello.py rename to example/sync/hello.py index a3ce0699e..bb4cd3ffd 100755 --- a/example/hello.py +++ b/example/sync/hello.py @@ -1,12 +1,16 @@ #!/usr/bin/env python -import asyncio +"""Client using the threading API.""" + from websockets.sync.client import connect + def hello(): with connect("ws://localhost:8765") as websocket: websocket.send("Hello world!") message = websocket.recv() - print(f"Received: {message}") + print(message) + -hello() +if __name__ == "__main__": + hello() From a5c8943a99a8625836b150ca3559923ecf79bcd0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 27 Sep 2024 21:53:49 +0200 Subject: [PATCH 13/51] Add asyncio & threading client & server examples. 
They're convenient to tweak to reproduce issues. --- example/asyncio/client.py | 22 ++++++++++++++++++++++ example/asyncio/server.py | 25 +++++++++++++++++++++++++ example/sync/client.py | 20 ++++++++++++++++++++ example/sync/server.py | 24 ++++++++++++++++++++++++ 4 files changed, 91 insertions(+) create mode 100644 example/asyncio/client.py create mode 100644 example/asyncio/server.py create mode 100644 example/sync/client.py create mode 100644 example/sync/server.py diff --git a/example/asyncio/client.py b/example/asyncio/client.py new file mode 100644 index 000000000..e3562642d --- /dev/null +++ b/example/asyncio/client.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python + +"""Client example using the asyncio API.""" + +import asyncio + +from websockets.asyncio.client import connect + + +async def hello(): + async with connect("ws://localhost:8765") as websocket: + name = input("What's your name? ") + + await websocket.send(name) + print(f">>> {name}") + + greeting = await websocket.recv() + print(f"<<< {greeting}") + + +if __name__ == "__main__": + asyncio.run(hello()) diff --git a/example/asyncio/server.py b/example/asyncio/server.py new file mode 100644 index 000000000..574e053bf --- /dev/null +++ b/example/asyncio/server.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +"""Server example using the asyncio API.""" + +import asyncio +from websockets.asyncio.server import serve + + +async def hello(websocket): + name = await websocket.recv() + print(f"<<< {name}") + + greeting = f"Hello {name}!" + + await websocket.send(greeting) + print(f">>> {greeting}") + + +async def main(): + async with serve(hello, "localhost", 8765) as server: + await server.serve_forever() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/sync/client.py b/example/sync/client.py new file mode 100644 index 000000000..c0d633c7b --- /dev/null +++ b/example/sync/client.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +"""Client example using the threading API.""" + +from websockets.sync.client import connect + + +def hello(): + with connect("ws://localhost:8765") as websocket: + name = input("What's your name? ") + + websocket.send(name) + print(f">>> {name}") + + greeting = websocket.recv() + print(f"<<< {greeting}") + + +if __name__ == "__main__": + hello() diff --git a/example/sync/server.py b/example/sync/server.py new file mode 100644 index 000000000..030049f81 --- /dev/null +++ b/example/sync/server.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python + +"""Server example using the threading API.""" + +from websockets.sync.server import serve + + +def hello(websocket): + name = websocket.recv() + print(f"<<< {name}") + + greeting = f"Hello {name}!" + + websocket.send(greeting) + print(f">>> {greeting}") + + +def main(): + with serve(hello, "localhost", 8765) as server: + server.serve_forever() + + +if __name__ == "__main__": + main() From 21987f96ad93f8c8bbf0b8ea99f3a18a52335730 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 27 Sep 2024 23:06:00 +0200 Subject: [PATCH 14/51] Migrate authentication experiment to new asyncio. 
--- docs/topics/authentication.rst | 118 ++++++-------- experiments/authentication/app.py | 194 ++++++++++-------------- experiments/authentication/script.js | 5 +- experiments/authentication/test.js | 2 - experiments/authentication/user_info.js | 2 +- 5 files changed, 127 insertions(+), 194 deletions(-) diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 86d2e2587..e2de4332e 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -1,13 +1,13 @@ Authentication ============== -The WebSocket protocol was designed for creating web applications that need -bidirectional communication between clients running in browsers and servers. +The WebSocket protocol is designed for creating web applications that require +bidirectional communication between browsers and servers. In most practical use cases, WebSocket servers need to authenticate clients in order to route communications appropriately and securely. -:rfc:`6455` stays elusive when it comes to authentication: +:rfc:`6455` remains elusive when it comes to authentication: This protocol doesn't prescribe any particular way that servers can authenticate clients during the WebSocket handshake. The WebSocket @@ -26,8 +26,8 @@ System design Consider a setup where the WebSocket server is separate from the HTTP server. -Most servers built with websockets to complement a web application adopt this -design because websockets doesn't aim at supporting HTTP. +Most servers built with websockets adopt this design because they're a component +in a web application and websockets doesn't aim at supporting HTTP. The following diagram illustrates the authentication flow. @@ -82,8 +82,8 @@ WebSocket server. credentials would be a session identifier or a serialized, signed session. Unfortunately, when the WebSocket server runs on a different domain from - the web application, this idea bumps into the `Same-Origin Policy`_. For - security reasons, setting a cookie on a different origin is impossible. + the web application, this idea hits the wall of the `Same-Origin Policy`_. + For security reasons, setting a cookie on a different origin is impossible. The proper workaround consists in: @@ -108,13 +108,11 @@ WebSocket server. Letting the browser perform HTTP Basic Auth is a nice idea in theory. - In practice it doesn't work due to poor support in browsers. + In practice it doesn't work due to browser support limitations: - As of May 2021: + * Chrome behaves as expected. - * Chrome 90 behaves as expected. - - * Firefox 88 caches credentials too aggressively. + * Firefox caches credentials too aggressively. When connecting again to the same server with new credentials, it reuses the old credentials, which may be expired, resulting in an HTTP 401. Then @@ -123,7 +121,7 @@ WebSocket server. When tokens are short-lived or single-use, this bug produces an interesting effect: every other WebSocket connection fails. - * Safari 14 ignores credentials entirely. + * Safari behaves as expected. Two other options are off the table: @@ -142,8 +140,10 @@ Two other options are off the table: While this is suggested by the RFC, installing a TLS certificate is too far from the mainstream experience of browser users. This could make sense in - high security contexts. I hope developers working on such projects don't - take security advice from the documentation of random open source projects. + high security contexts. 
+ + I hope that developers working on projects in this category don't take + security advice from the documentation of random open source projects :-) Let's experiment! ----------------- @@ -185,6 +185,8 @@ connection: .. code-block:: python + from websockets.frames import CloseCode + async def first_message_handler(websocket): token = await websocket.recv() user = get_user(token) @@ -212,24 +214,16 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - from websockets.legacy.server import WebSocketServerProtocol - - class QueryParamProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - token = get_query_parameter(path, "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + async def query_param_auth(connection, request): + token = get_query_param(request.path, "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - self.user = user + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") - async def query_param_handler(websocket): - user = websocket.user - - ... + connection.username = user Cookie ...... @@ -260,27 +254,19 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - from websockets.legacy.server import WebSocketServerProtocol - - class CookieProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - # Serve iframe on non-WebSocket requests - ... - - token = get_cookie(headers.get("Cookie", ""), "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + async def cookie_auth(connection, request): + # Serve iframe on non-WebSocket requests + ... - self.user = user + token = get_cookie(request.headers.get("Cookie", ""), "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - async def cookie_handler(websocket): - user = websocket.user + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") - ... + connection.username = user User information ................ @@ -303,24 +289,12 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - from websockets.legacy.auth import BasicAuthWebSocketServerProtocol - - class UserInfoProtocol(BasicAuthWebSocketServerProtocol): - async def check_credentials(self, username, password): - if username != "token": - return False - - user = get_user(password) - if user is None: - return False + from websockets.asyncio.server import basic_auth as websockets_basic_auth - self.user = user - return True + def check_credentials(username, password): + return username == get_user(password) - async def user_info_handler(websocket): - user = websocket.user - - ... + basic_auth = websockets_basic_auth(check_credentials=check_credentials) Machine-to-machine authentication --------------------------------- @@ -334,11 +308,9 @@ To authenticate a websockets client with HTTP Basic Authentication .. 
code-block:: python - from websockets.legacy.client import connect + from websockets.asyncio.client import connect - async with connect( - f"wss://{username}:{password}@example.com" - ) as websocket: + async with connect(f"wss://{username}:{password}@.../") as websocket: ... (You must :func:`~urllib.parse.quote` ``username`` and ``password`` if they @@ -349,10 +321,8 @@ To authenticate a websockets client with HTTP Bearer Authentication .. code-block:: python - from websockets.legacy.client import connect + from websockets.asyncio.client import connect - async with connect( - "wss://example.com", - extra_headers={"Authorization": f"Bearer {token}"} - ) as websocket: + headers = {"Authorization": f"Bearer {token}"} + async with connect("wss://.../", additional_headers=headers) as websocket: ... diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py index e3b2cf1f6..0bdd7fd2f 100644 --- a/experiments/authentication/app.py +++ b/experiments/authentication/app.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import asyncio +import email.utils import http import http.cookies import pathlib @@ -8,9 +9,10 @@ import urllib.parse import uuid +from websockets.asyncio.server import basic_auth as websockets_basic_auth, serve +from websockets.datastructures import Headers from websockets.frames import CloseCode -from websockets.legacy.auth import BasicAuthWebSocketServerProtocol -from websockets.legacy.server import WebSocketServerProtocol, serve +from websockets.http11 import Response # User accounts database @@ -49,7 +51,19 @@ def get_query_param(path, key): return values[0] -# Main HTTP server +# WebSocket handler + + +async def handler(websocket): + try: + user = websocket.username + except AttributeError: + return + + await websocket.send(f"Hello {user}!") + message = await websocket.recv() + assert message == f"Goodbye {user}." + CONTENT_TYPES = { ".css": "text/css", @@ -59,9 +73,10 @@ def get_query_param(path, key): } -async def serve_html(path, request_headers): - user = get_query_param(path, "user") - path = urllib.parse.urlparse(path).path +async def serve_html(connection, request): + """Basic HTTP server implemented as a process_request hook.""" + user = get_query_param(request.path, "user") + path = urllib.parse.urlparse(request.path).path if path == "/": if user is None: page = "index.html" @@ -76,147 +91,96 @@ async def serve_html(path, request_headers): pass else: if template.is_file(): - headers = {"Content-Type": CONTENT_TYPES[template.suffix]} body = template.read_bytes() if user is not None: token = create_token(user) body = body.replace(b"TOKEN", token.encode()) - return http.HTTPStatus.OK, headers, body - - return http.HTTPStatus.NOT_FOUND, {}, b"Not found\n" - + headers = Headers( + { + "Date": email.utils.formatdate(usegmt=True), + "Connection": "close", + "Content-Length": str(len(body)), + "Content-Type": CONTENT_TYPES[template.suffix], + } + ) + return Response(200, "OK", headers, body) -async def noop_handler(websocket): - pass - - -# Send credentials as the first message in the WebSocket connection + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not found\n") async def first_message_handler(websocket): + """Handler that sends credentials in the first WebSocket message.""" token = await websocket.recv() user = get_user(token) if user is None: await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." 
- - -# Add credentials to the WebSocket URI in a query parameter - - -class QueryParamProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - token = get_query_param(path, "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" - - self.user = user - - -async def query_param_handler(websocket): - user = websocket.user - - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." - - -# Set a cookie on the domain of the WebSocket URI - - -class CookieProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - if "Upgrade" not in headers: - template = pathlib.Path(__file__).with_name(path[1:]) - headers = {"Content-Type": CONTENT_TYPES[template.suffix]} - body = template.read_bytes() - return http.HTTPStatus.OK, headers, body - - token = get_cookie(headers.get("Cookie", ""), "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" - - self.user = user + websocket.username = user + await handler(websocket) -async def cookie_handler(websocket): - user = websocket.user +async def query_param_auth(connection, request): + """Authenticate user from token in query parameter.""" + token = get_query_param(request.path, "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." - - -# Adding credentials to the WebSocket URI in user information - - -class UserInfoProtocol(BasicAuthWebSocketServerProtocol): - async def check_credentials(self, username, password): - if username != "token": - return False - - user = get_user(password) - if user is None: - return False + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") + + connection.username = user + + +async def cookie_auth(connection, request): + """Authenticate user from token in cookie.""" + if "Upgrade" not in request.headers: + template = pathlib.Path(__file__).with_name(request.path[1:]) + body = template.read_bytes() + headers = Headers( + { + "Date": email.utils.formatdate(usegmt=True), + "Connection": "close", + "Content-Length": str(len(body)), + "Content-Type": CONTENT_TYPES[template.suffix], + } + ) + return Response(200, "OK", headers, body) + + token = get_cookie(request.headers.get("Cookie", ""), "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - self.user = user - return True + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") + connection.username = user -async def user_info_handler(websocket): - user = websocket.user - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." +def check_credentials(username, password): + """Authenticate user with HTTP Basic Auth.""" + return username == get_user(password) -# Start all five servers +basic_auth = websockets_basic_auth(check_credentials=check_credentials) async def main(): + """Start one HTTP server and four WebSocket servers.""" # Set the stop condition when receiving SIGINT or SIGTERM. 
loop = asyncio.get_running_loop() stop = loop.create_future() loop.add_signal_handler(signal.SIGINT, stop.set_result, None) loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with serve( - noop_handler, - host="", - port=8000, - process_request=serve_html, - ), serve( - first_message_handler, - host="", - port=8001, - ), serve( - query_param_handler, - host="", - port=8002, - create_protocol=QueryParamProtocol, - ), serve( - cookie_handler, - host="", - port=8003, - create_protocol=CookieProtocol, - ), serve( - user_info_handler, - host="", - port=8004, - create_protocol=UserInfoProtocol, + async with ( + serve(handler, host="", port=8000, process_request=serve_html), + serve(first_message_handler, host="", port=8001), + serve(handler, host="", port=8002, process_request=query_param_auth), + serve(handler, host="", port=8003, process_request=cookie_auth), + serve(handler, host="", port=8004, process_request=basic_auth), ): print("Running on http://localhost:8000/") await stop diff --git a/experiments/authentication/script.js b/experiments/authentication/script.js index ec4e5e670..01dd5b168 100644 --- a/experiments/authentication/script.js +++ b/experiments/authentication/script.js @@ -1,4 +1,5 @@ -var token = window.parent.token; +var token = window.parent.token, + user = window.parent.user; function getExpectedEvents() { return [ @@ -7,7 +8,7 @@ function getExpectedEvents() { }, { type: "message", - data: `Hello ${window.parent.user}!`, + data: `Hello ${user}!`, }, { type: "close", diff --git a/experiments/authentication/test.js b/experiments/authentication/test.js index 428830ff3..e05ca697e 100644 --- a/experiments/authentication/test.js +++ b/experiments/authentication/test.js @@ -1,6 +1,4 @@ -// for connecting to WebSocket servers var token = document.body.dataset.token; -// for test assertions only const params = new URLSearchParams(window.location.search); var user = params.get("user"); diff --git a/experiments/authentication/user_info.js b/experiments/authentication/user_info.js index 1dab2ce4c..bc9a3f148 100644 --- a/experiments/authentication/user_info.js +++ b/experiments/authentication/user_info.js @@ -1,5 +1,5 @@ window.addEventListener("DOMContentLoaded", () => { - const uri = `ws://token:${token}@localhost:8004/`; + const uri = `ws://${user}:${token}@localhost:8004/`; const websocket = new WebSocket(uri); websocket.onmessage = ({ data }) => { From bc4b8f2776cd4da9aee2e67f66764bf26a5ad09e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 30 Sep 2024 21:07:28 +0200 Subject: [PATCH 15/51] Add option to force sending text or binary frames. Fix #1515. --- src/websockets/asyncio/connection.py | 124 +++++++++++++++------------ tests/asyncio/test_connection.py | 72 ++++++++++++++-- 2 files changed, 134 insertions(+), 62 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 702e69995..12871e4b3 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -251,12 +251,13 @@ async def recv(self, decode: bool | None = None) -> Data: You may override this behavior with the ``decode`` argument: - * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames - and return a bytestring (:class:`bytes`). This may be useful to - optimize performance when decoding isn't needed. + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). 
This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames - and return a string (:class:`str`). This is useful for servers - that send binary frames instead of text frames. + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. Raises: ConnectionClosed: When the connection is closed. @@ -333,7 +334,11 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data "is already running recv or recv_streaming" ) from None - async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None: + async def send( + self, + message: Data | Iterable[Data] | AsyncIterable[Data], + text: bool | None = None, + ) -> None: """ Send a message. @@ -344,6 +349,17 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send a bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. @@ -393,12 +409,20 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No # strings and bytes-like objects are iterable. if isinstance(message, str): - async with self.send_context(): - self.protocol.send_text(message.encode()) + if text is False: + async with self.send_context(): + self.protocol.send_binary(message.encode()) + else: + async with self.send_context(): + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): - async with self.send_context(): - self.protocol.send_binary(message) + if text is True: + async with self.send_context(): + self.protocol.send_text(message) + else: + async with self.send_context(): + self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -419,36 +443,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No try: # First fragment. 
if isinstance(chunk, str): - text = True - async with self.send_context(): - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False - async with self.send_context(): - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("iterable must contain bytes or str") # Other fragments for chunk in chunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: async with self.send_context(): - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: async with self.send_context(): - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("iterable must contain uniform types") @@ -481,36 +501,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No try: # First fragment. if isinstance(chunk, str): - text = True - async with self.send_context(): - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False - async with self.send_context(): - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("async iterable must contain bytes or str") # Other fragments async for chunk in achunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: async with self.send_context(): - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: async with self.send_context(): - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("async iterable must contain uniform types") diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 70d9dad63..563cf2b17 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -190,13 +190,13 @@ async def test_recv_binary(self): await self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") - async def test_recv_encoded_text(self): - """recv receives an UTF-8 encoded text message.""" + async def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" await self.remote_connection.send("😀") self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) - async def test_recv_decoded_binary(self): - 
"""recv receives an UTF-8 decoded binary message.""" + async def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" await self.remote_connection.send("😀".encode()) self.assertEqual(await self.connection.recv(decode=True), "😀") @@ -304,16 +304,16 @@ async def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) - async def test_recv_streaming_encoded_text(self): - """recv_streaming receives an UTF-8 encoded text message.""" + async def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" await self.remote_connection.send("😀") self.assertEqual( await alist(self.connection.recv_streaming(decode=False)), ["😀".encode()], ) - async def test_recv_streaming_decoded_binary(self): - """recv_streaming receives a UTF-8 decoded binary message.""" + async def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" await self.remote_connection.send("😀".encode()) self.assertEqual( await alist(self.connection.recv_streaming(decode=True)), @@ -438,6 +438,16 @@ async def test_send_binary(self): await self.connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + async def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + await self.connection.send("😀", text=False) + self.assertEqual(await self.remote_connection.recv(), "😀".encode()) + + async def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + await self.connection.send("😀".encode(), text=True) + self.assertEqual(await self.remote_connection.recv(), "😀") + async def test_send_fragmented_text(self): """send sends a fragmented text message.""" await self.connection.send(["😀", "😀"]) @@ -456,6 +466,24 @@ async def test_send_fragmented_binary(self): [b"\x01\x02", b"\xfe\xff", b""], ) + async def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + await self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + await self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + async def test_send_async_fragmented_text(self): """send sends a fragmented text message asynchronously.""" @@ -484,6 +512,34 @@ async def fragments(): [b"\x01\x02", b"\xfe\xff", b""], ) + async def test_send_async_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments(), text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. 
+ self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_async_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes asynchronously.""" + + async def fragments(): + yield "😀".encode() + yield "😀".encode() + + await self.connection.send(fragments(), text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + async def test_send_connection_closed_ok(self): """send raises ConnectionClosedOK after a normal closure.""" await self.remote_connection.close() From 7fdd932c6b29a9ef4db46b2141371f42207b4f00 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2024 14:21:00 +0200 Subject: [PATCH 16/51] Review & update gitignore. --- .gitignore | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index d8e6697a8..291bf1fb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,16 @@ *.pyc *.so .coverage -.direnv +.direnv/ .envrc .idea/ -.mypy_cache -.tox +.mypy_cache/ +.tox/ +.vscode/ build/ compliance/reports/ -experiments/compression/corpus/ dist/ docs/_build/ +experiments/compression/corpus/ htmlcov/ -MANIFEST -websockets.egg-info/ +src/websockets.egg-info/ From c5985d5c4192390b2da58ae97015e3ab1ba41cd2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Oct 2024 16:21:16 +0200 Subject: [PATCH 17/51] Update compliance test suite. --- Makefile | 4 +- compliance/README.rst | 74 ++++++++++++++++++---------- compliance/asyncio/client.py | 52 +++++++++++++++++++ compliance/asyncio/server.py | 32 ++++++++++++ compliance/config/fuzzingclient.json | 11 +++++ compliance/config/fuzzingserver.json | 7 +++ compliance/fuzzingclient.json | 11 ----- compliance/fuzzingserver.json | 12 ----- compliance/sync/client.py | 51 +++++++++++++++++++ compliance/sync/server.py | 31 ++++++++++++ compliance/test_client.py | 48 ------------------ compliance/test_server.py | 29 ----------- 12 files changed, 234 insertions(+), 128 deletions(-) create mode 100644 compliance/asyncio/client.py create mode 100644 compliance/asyncio/server.py create mode 100644 compliance/config/fuzzingclient.json create mode 100644 compliance/config/fuzzingserver.json delete mode 100644 compliance/fuzzingclient.json delete mode 100644 compliance/fuzzingserver.json create mode 100644 compliance/sync/client.py create mode 100644 compliance/sync/server.py delete mode 100644 compliance/test_client.py delete mode 100644 compliance/test_server.py diff --git a/Makefile b/Makefile index fd36d0367..06bfe9edc 100644 --- a/Makefile +++ b/Makefile @@ -8,8 +8,8 @@ build: python setup.py build_ext --inplace style: - ruff format src tests - ruff check --fix src tests + ruff format compliance src tests + ruff check --fix compliance src tests types: mypy --strict src diff --git a/compliance/README.rst b/compliance/README.rst index 8570f9176..c7c7c93b4 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -4,47 +4,69 @@ Autobahn Testsuite General information and installation instructions are available at https://github.com/crossbario/autobahn-testsuite. -To improve performance, you should compile the C extension first:: +Running the test suite +---------------------- + +All commands below must be run from the root directory of the repository. + +To get acceptable performance, compile the C extension first: + +.. 
code-block:: console $ python setup.py build_ext --inplace -Running the test suite ----------------------- +Run each command in a different shell. Testing takes several minutes to complete +— wstest is the bottleneck. When clients finish, stop servers with Ctrl-C. + +You can exclude slow tests by modifying the configuration files as follows:: + + "exclude-cases": ["9.*", "12.*", "13.*"] -All commands below must be run from the directory containing this file. +The test server and client applications shouldn't display any exceptions. -To test the server:: +To test the servers: - $ PYTHONPATH=.. python test_server.py - $ wstest -m fuzzingclient +.. code-block:: console -To test the client:: + $ PYTHONPATH=src python compliance/asyncio/server.py + $ PYTHONPATH=src python compliance/sync/server.py - $ wstest -m fuzzingserver - $ PYTHONPATH=.. python test_client.py + $ docker run --interactive --tty --rm \ + --volume "${PWD}/compliance/config:/config" \ + --volume "${PWD}/compliance/reports:/reports" \ + --name fuzzingclient \ + crossbario/autobahn-testsuite \ + wstest --mode fuzzingclient --spec /config/fuzzingclient.json -Run the first command in a shell. Run the second command in another shell. -It should take about ten minutes to complete — wstest is the bottleneck. -Then kill the first one with Ctrl-C. + $ open reports/servers/index.html -The test client or server shouldn't display any exceptions. The results are -stored in reports/clients/index.html. +To test the clients: -Note that the Autobahn software only supports Python 2, while ``websockets`` -only supports Python 3; you need two different environments. +.. code-block:: console + $ docker run --interactive --tty --rm \ + --volume "${PWD}/compliance/config:/config" \ + --volume "${PWD}/compliance/reports:/reports" \ + --publish 9001:9001 \ + --name fuzzingserver \ + crossbario/autobahn-testsuite \ + wstest --mode fuzzingserver --spec /config/fuzzingserver.json + + $ PYTHONPATH=src python compliance/asyncio/client.py + $ PYTHONPATH=src python compliance/sync/client.py + + $ open reports/clients/index.html Conformance notes ----------------- Some test cases are more strict than the RFC. Given the implementation of the -library and the test echo client or server, ``websockets`` gets a "Non-Strict" -in these cases. - -In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 ``websockets`` notices the -protocol error and closes the connection before it has had a chance to echo -the previous frame. +library and the test client and server applications, websockets passes with a +"Non-Strict" result in these cases. -In 6.4.3 and 6.4.4, even though it uses an incremental decoder, ``websockets`` -doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These -tests are more strict than the RFC. +In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 websockets notices the +protocol error and closes the connection at the library level before the +application gets a chance to echo the previous frame. +In 6.4.3 and 6.4.4, even though it uses an incremental decoder, websockets +doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These tests +are more strict than the RFC. 
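
For a quick sanity check before launching a full run, you can exercise a single
case by hand against the fuzzing server. The snippet below is an illustrative
sketch rather than part of the suite: it assumes the fuzzingserver container
above is listening on port 9001 and picks case 1 arbitrarily.

.. code-block:: python

    from websockets.sync.client import connect

    SERVER = "ws://127.0.0.1:9001"

    # Echo one case, then ask wstest to regenerate the reports.
    with connect(f"{SERVER}/runCase?case=1", max_size=2**25) as ws:
        for msg in ws:
            ws.send(msg)

    with connect(f"{SERVER}/updateReports"):
        pass

The client applications described above do exactly this in a loop over the
number of cases reported by ``getCaseCount``.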
diff --git a/compliance/asyncio/client.py b/compliance/asyncio/client.py
new file mode 100644
index 000000000..5b0bfb3ae
--- /dev/null
+++ b/compliance/asyncio/client.py
@@ -0,0 +1,52 @@
+import asyncio
+import json
+import logging
+
+from websockets.asyncio.client import connect
+from websockets.exceptions import WebSocketException
+
+
+logging.basicConfig(level=logging.WARNING)
+
+SERVER = "ws://127.0.0.1:9001"
+
+
+async def get_case_count():
+    async with connect(f"{SERVER}/getCaseCount") as ws:
+        return json.loads(await ws.recv())
+
+
+async def run_case(case):
+    async with connect(
+        f"{SERVER}/runCase?case={case}",
+        user_agent_header="websockets.asyncio",
+        max_size=2**25,
+    ) as ws:
+        async for msg in ws:
+            await ws.send(msg)
+
+
+async def update_reports():
+    async with connect(f"{SERVER}/updateReports", open_timeout=60):
+        pass
+
+
+async def main():
+    cases = await get_case_count()
+    for case in range(1, cases + 1):
+        print(f"Running test case {case:03d} / {cases}... ", end="\t")
+        try:
+            await run_case(case)
+        except WebSocketException as exc:
+            print(f"ERROR: {type(exc).__name__}: {exc}")
+        except Exception as exc:
+            print(f"FAIL: {type(exc).__name__}: {exc}")
+        else:
+            print("OK")
+    print(f"Ran {cases} test cases")
+    await update_reports()
+    print("Updated reports")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/compliance/asyncio/server.py b/compliance/asyncio/server.py
new file mode 100644
index 000000000..cff2728c9
--- /dev/null
+++ b/compliance/asyncio/server.py
@@ -0,0 +1,32 @@
+import asyncio
+import logging
+
+from websockets.asyncio.server import serve
+
+
+logging.basicConfig(level=logging.WARNING)
+
+HOST, PORT = "0.0.0.0", 9002
+
+
+async def echo(ws):
+    async for msg in ws:
+        await ws.send(msg)
+
+
+async def main():
+    async with serve(
+        echo,
+        HOST,
+        PORT,
+        server_header="websockets.asyncio",
+        max_size=2**25,
+    ) as server:
+        try:
+            await server.serve_forever()
+        except KeyboardInterrupt:
+            pass
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/compliance/config/fuzzingclient.json b/compliance/config/fuzzingclient.json
new file mode 100644
index 000000000..37f8f3a35
--- /dev/null
+++ b/compliance/config/fuzzingclient.json
@@ -0,0 +1,11 @@
+
+{
+    "servers": [{
+        "url": "ws://host.docker.internal:9002"
+    }, {
+        "url": "ws://host.docker.internal:9003"
+    }],
+    "outdir": "./reports/servers",
+    "cases": ["*"],
+    "exclude-cases": []
+}
diff --git a/compliance/config/fuzzingserver.json b/compliance/config/fuzzingserver.json
new file mode 100644
index 000000000..1bcb5959d
--- /dev/null
+++ b/compliance/config/fuzzingserver.json
@@ -0,0 +1,7 @@
+
+{
+    "url": "ws://localhost:9001",
+    "outdir": "./reports/clients",
+    "cases": ["*"],
+    "exclude-cases": []
+}
diff --git a/compliance/fuzzingclient.json b/compliance/fuzzingclient.json
deleted file mode 100644
index 202ff49a0..000000000
--- a/compliance/fuzzingclient.json
+++ /dev/null
@@ -1,11 +0,0 @@
-
-{
-    "options": {"failByDrop": false},
-    "outdir": "./reports/servers",
-
-    "servers": [{"agent": "websockets", "url": "ws://localhost:8642", "options": {"version": 18}}],
-
-    "cases": ["*"],
-    "exclude-cases": [],
-    "exclude-agent-cases": {}
-}
diff --git a/compliance/fuzzingserver.json b/compliance/fuzzingserver.json
deleted file mode 100644
index 1bdb42723..000000000
--- a/compliance/fuzzingserver.json
+++ /dev/null
@@ -1,12 +0,0 @@
-
-{
-    "url": "ws://localhost:8642",
-
-    "options": {"failByDrop": false},
-    "outdir": "./reports/clients",
-    "webport": 8080,
-
-    "cases": ["*"],
-    "exclude-cases": [],
-    "exclude-agent-cases": {}
-}
diff --git a/compliance/sync/client.py b/compliance/sync/client.py
new file mode 100644
index 000000000..e585496f3
--- /dev/null
+++ b/compliance/sync/client.py
@@ -0,0 +1,51 @@
+import json
+import logging
+
+from websockets.exceptions import WebSocketException
+from websockets.sync.client import connect
+
+
+logging.basicConfig(level=logging.WARNING)
+
+SERVER = "ws://127.0.0.1:9001"
+
+
+def get_case_count():
+    with connect(f"{SERVER}/getCaseCount") as ws:
+        return json.loads(ws.recv())
+
+
+def run_case(case):
+    with connect(
+        f"{SERVER}/runCase?case={case}",
+        user_agent_header="websockets.sync",
+        max_size=2**25,
+    ) as ws:
+        for msg in ws:
+            ws.send(msg)
+
+
+def update_reports():
+    with connect(f"{SERVER}/updateReports", open_timeout=60):
+        pass
+
+
+def main():
+    cases = get_case_count()
+    for case in range(1, cases + 1):
+        print(f"Running test case {case:03d} / {cases}... ", end="\t")
+        try:
+            run_case(case)
+        except WebSocketException as exc:
+            print(f"ERROR: {type(exc).__name__}: {exc}")
+        except Exception as exc:
+            print(f"FAIL: {type(exc).__name__}: {exc}")
+        else:
+            print("OK")
+    print(f"Ran {cases} test cases")
+    update_reports()
+    print("Updated reports")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/compliance/sync/server.py b/compliance/sync/server.py
new file mode 100644
index 000000000..c3cb4d989
--- /dev/null
+++ b/compliance/sync/server.py
@@ -0,0 +1,31 @@
+import logging
+
+from websockets.sync.server import serve
+
+
+logging.basicConfig(level=logging.WARNING)
+
+HOST, PORT = "0.0.0.0", 9003
+
+
+def echo(ws):
+    for msg in ws:
+        ws.send(msg)
+
+
+def main():
+    with serve(
+        echo,
+        HOST,
+        PORT,
+        server_header="websockets.sync",
+        max_size=2**25,
+    ) as server:
+        try:
+            server.serve_forever()
+        except KeyboardInterrupt:
+            pass
+
+
+if __name__ == "__main__":
+    main()
diff --git a/compliance/test_client.py b/compliance/test_client.py
deleted file mode 100644
index 8e22569fd..000000000
--- a/compliance/test_client.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import asyncio
-import json
-import logging
-import urllib.parse
-
-from websockets.asyncio.client import connect
-
-
-logging.basicConfig(level=logging.WARNING)
-
-# Uncomment this line to make only websockets more verbose.
-# logging.getLogger('websockets').setLevel(logging.DEBUG)
-
-
-SERVER = "ws://127.0.0.1:8642"
-AGENT = "websockets"
-
-
-async def get_case_count(server):
-    uri = f"{server}/getCaseCount"
-    async with connect(uri) as ws:
-        msg = ws.recv()
-        return json.loads(msg)
-
-
-async def run_case(server, case, agent):
-    uri = f"{server}/runCase?case={case}&agent={agent}"
-    async with connect(uri, max_size=2 ** 25, max_queue=1) as ws:
-        async for msg in ws:
-            await ws.send(msg)
-
-
-async def update_reports(server, agent):
-    uri = f"{server}/updateReports?agent={agent}"
-    async with connect(uri):
-        pass
-
-
-async def run_tests(server, agent):
-    cases = await get_case_count(server)
-    for case in range(1, cases + 1):
-        print(f"Running test case {case} out of {cases}", end="\r")
-        await run_case(server, case, agent)
-    print(f"Ran {cases} test cases ")
-    await update_reports(server, agent)
-
-
-asyncio.run(run_tests(SERVER, urllib.parse.quote(AGENT)))
diff --git a/compliance/test_server.py b/compliance/test_server.py
deleted file mode 100644
index 39176e902..000000000
--- a/compliance/test_server.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import asyncio
-import logging
-
-from websockets.asyncio.server import serve
-
-
-logging.basicConfig(level=logging.WARNING)
-
-# Uncomment this line to make only websockets more verbose.
-# logging.getLogger('websockets').setLevel(logging.DEBUG)
-
-
-HOST, PORT = "127.0.0.1", 8642
-
-
-async def echo(ws):
-    async for msg in ws:
-        await ws.send(msg)
-
-
-async def main():
-    with serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1):
-        try:
-            await asyncio.get_running_loop().create_future()  # run forever
-        except KeyboardInterrupt:
-            pass
-
-
-asyncio.run(main())

From b2f0a7647f1402c84a8dabb391c3ca7371975eb3 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin
Date: Sun, 13 Oct 2024 21:32:30 +0200
Subject: [PATCH 18/51] Add option to force sending text or binary frames.

This adds the same functionality to the threading implementation as
bc4b8f2 did to the asyncio implementation.

Refs #1515.
---
 src/websockets/asyncio/connection.py | 43 +++++++++++---------
 src/websockets/sync/connection.py    | 61 +++++++++++++++++-----------
 tests/sync/test_connection.py        | 28 +++++++++++++
 3 files changed, 90 insertions(+), 42 deletions(-)

diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py
index 12871e4b3..3b81e386b 100644
--- a/src/websockets/asyncio/connection.py
+++ b/src/websockets/asyncio/connection.py
@@ -409,19 +409,17 @@ async def send(
         # strings and bytes-like objects are iterable.

         if isinstance(message, str):
-            if text is False:
-                async with self.send_context():
+            async with self.send_context():
+                if text is False:
                     self.protocol.send_binary(message.encode())
-            else:
-                async with self.send_context():
+                else:
                     self.protocol.send_text(message.encode())

         elif isinstance(message, BytesLike):
-            if text is True:
-                async with self.send_context():
+            async with self.send_context():
+                if text is True:
                     self.protocol.send_text(message)
-            else:
-                async with self.send_context():
+                else:
                     self.protocol.send_binary(message)

         # Catch a common mistake -- passing a dict to send().
@@ -443,19 +441,17 @@ async def send(
             try:
                 # First fragment.
if isinstance(chunk, str): - if text is False: - async with self.send_context(): + async with self.send_context(): + if text is False: self.protocol.send_binary(chunk.encode(), fin=False) - else: - async with self.send_context(): + else: self.protocol.send_text(chunk.encode(), fin=False) encode = True elif isinstance(chunk, BytesLike): - if text is True: - async with self.send_context(): + async with self.send_context(): + if text is True: self.protocol.send_text(chunk, fin=False) - else: - async with self.send_context(): + else: self.protocol.send_binary(chunk, fin=False) encode = False else: @@ -480,7 +476,10 @@ async def send( # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. async with self.send_context(): - self.protocol.fail(1011, "error in fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) raise finally: @@ -538,7 +537,10 @@ async def send( # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. async with self.send_context(): - self.protocol.fail(1011, "error in fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) raise finally: @@ -568,7 +570,10 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # to terminate after calling a method that sends a close frame. async with self.send_context(): if self.fragmented_send_waiter is not None: - self.protocol.fail(1011, "close during fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) else: self.protocol.send_close(code, reason) except ConnectionClosed: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 8c5df9592..3f4cac09f 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -251,7 +251,11 @@ def recv_streaming(self) -> Iterator[Data]: "is already running recv or recv_streaming" ) from None - def send(self, message: Data | Iterable[Data]) -> None: + def send( + self, + message: Data | Iterable[Data], + text: bool | None = None, + ) -> None: """ Send a message. @@ -262,6 +266,17 @@ def send(self, message: Data | Iterable[Data]) -> None: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send a bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + :meth:`send` also accepts an iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. 
All items must be of the @@ -300,7 +315,10 @@ def send(self, message: Data | Iterable[Data]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_text(message.encode()) + if text is False: + self.protocol.send_binary(message.encode()) + else: + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): with self.send_context(): @@ -309,7 +327,10 @@ def send(self, message: Data | Iterable[Data]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_binary(message) + if text is True: + self.protocol.send_text(message) + else: + self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -328,7 +349,6 @@ def send(self, message: Data | Iterable[Data]) -> None: try: # First fragment. if isinstance(chunk, str): - text = True with self.send_context(): if self.send_in_progress: raise ConcurrencyError( @@ -336,12 +356,12 @@ def send(self, message: Data | Iterable[Data]) -> None: "is already running send" ) self.send_in_progress = True - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + self.protocol.send_binary(chunk.encode(), fin=False) + else: + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False with self.send_context(): if self.send_in_progress: raise ConcurrencyError( @@ -349,29 +369,24 @@ def send(self, message: Data | Iterable[Data]) -> None: "is already running send" ) self.send_in_progress = True - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + self.protocol.send_text(chunk, fin=False) + else: + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("data iterable must contain bytes or str") # Other fragments for chunk in chunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: with self.send_context(): assert self.send_in_progress - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: with self.send_context(): assert self.send_in_progress - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("data iterable must contain uniform types") diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 16f92e164..87333fd35 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -308,6 +308,16 @@ def test_send_binary(self): self.connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.remote_connection.recv(), b"\x01\x02\xfe\xff") + def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + self.connection.send("😀", text=False) + self.assertEqual(self.remote_connection.recv(), "😀".encode()) + + def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + self.connection.send("😀".encode(), text=True) + self.assertEqual(self.remote_connection.recv(), "😀") + def test_send_fragmented_text(self): """send sends a fragmented text message.""" self.connection.send(["😀", "😀"]) @@ -326,6 +336,24 @@ def test_send_fragmented_binary(self): [b"\x01\x02", b"\xfe\xff", b""], ) + def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + self.connection.send(["😀", "😀"], text=False) + # websockets 
sends a trailing empty fragment. That's an implementation detail.
+        self.assertEqual(
+            list(self.remote_connection.recv_streaming()),
+            ["😀".encode(), "😀".encode(), b""],
+        )
+
+    def test_send_fragmented_text_from_bytes(self):
+        """send sends a fragmented text message from bytes."""
+        self.connection.send(["😀".encode(), "😀".encode()], text=True)
+        # websockets sends a trailing empty fragment. That's an implementation detail.
+        self.assertEqual(
+            list(self.remote_connection.recv_streaming()),
+            ["😀", "😀", ""],
+        )
+
     def test_send_connection_closed_ok(self):
         """send raises ConnectionClosedOK after a normal closure."""
         self.remote_connection.close()

From e5182c95a3332535a034c409d59463afbd760f0c Mon Sep 17 00:00:00 2001
From: Aymeric Augustin
Date: Sun, 13 Oct 2024 21:45:17 +0200
Subject: [PATCH 19/51] Blind fix for coverage failing in GitHub Actions.

It doesn't fail locally.
---
 pyproject.toml | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 6a0ab8d7c..4e26c757e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -60,15 +60,19 @@ source = [

 [tool.coverage.report]
 exclude_lines = [
+    "pragma: no cover",
     "except ImportError:",
     "if self.debug:",
     "if sys.platform != \"win32\":",
     "if typing.TYPE_CHECKING:",
-    "pragma: no cover",
     "raise AssertionError",
     "self.fail\\(\".*\"\\)",
     "@unittest.skip",
 ]
+partial_branches = [
+    "pragma: no branch",
+    "with self.assertRaises\\(.*\\)",
+]

 [tool.ruff]
 target-version = "py312"

From 1387c976833956ea4d44ed0bd541fe648a065ed7 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin
Date: Fri, 25 Oct 2024 13:26:52 +0200
Subject: [PATCH 20/51] Rewrite sync Assembler to improve performance.

Previously, a latch was used to synchronize the user thread reading
messages and the background thread reading from the network. This
required two thread switches per message.

Now, the background thread writes messages to a queue, from which the
user thread reads. This allows passing several frames at each thread
switch, reducing the overhead.

With this server code:

    async def test(websocket):
        for i in range(int(await websocket.recv())):
            await websocket.send(f"{{\"iteration\": {i}}}")

    async with serve(test, "localhost", 8765) as server:
        await server.serve_forever()

and this client code:

    with connect("ws://localhost:8765", compression=None) as websocket:
        websocket.send("1_000_000")
        for message in websocket:
            pass

an unscientific benchmark (running it on my laptop) shows a 2.5x
speedup, going from 11 seconds to 4.4 seconds.

Setting a very large recv_bufsize and max_size doesn't yield
significant further improvement.

Flow control was tested by inserting debug logs in maybe_pause/resume()
and by measuring the wait for the recv_flow_control lock. It showed the
expected behavior of pausing and unpausing coupled with some wait time.

The new implementation mirrors the asyncio implementation and gains the
option to prevent or force decoding of frames.

Fix #1376 for the threading implementation.
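
To make the flow control described in this commit message concrete, here is a
deliberately simplified sketch of the pattern: a bounded buffer whose high and
low water marks drive pause() and resume() callbacks. The BoundedBuffer name is
made up for this illustration; the Assembler introduced by this patch also
handles read locking, end of stream, timeouts, and UTF-8 decoding.

    import queue
    import threading

    class BoundedBuffer:
        """Illustration only: high/low water marks driving pause()/resume()."""

        def __init__(self, high=16, low=4, pause=lambda: None, resume=lambda: None):
            self.frames = queue.SimpleQueue()
            self.mutex = threading.Lock()
            self.high, self.low = high, low
            self.pause, self.resume = pause, resume
            self.paused = False

        def put(self, frame):
            # Called by the thread reading from the network.
            with self.mutex:
                self.frames.put(frame)
                if self.frames.qsize() > self.high and not self.paused:
                    self.paused = True
                    self.pause()  # stop reading from the socket

        def get(self):
            # Called by the thread consuming messages.
            frame = self.frames.get()
            with self.mutex:
                if self.frames.qsize() <= self.low and self.paused:
                    self.paused = False
                    self.resume()  # resume reading from the socket
            return frame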
--- docs/project/changelog.rst | 16 +- src/websockets/asyncio/client.py | 2 +- src/websockets/asyncio/messages.py | 59 +++-- src/websockets/asyncio/server.py | 2 +- src/websockets/sync/client.py | 12 +- src/websockets/sync/connection.py | 79 ++++-- src/websockets/sync/messages.py | 333 ++++++++++++------------ src/websockets/sync/server.py | 12 +- tests/asyncio/test_connection.py | 19 +- tests/asyncio/test_messages.py | 12 +- tests/sync/test_connection.py | 106 +++++--- tests/sync/test_messages.py | 400 ++++++++++++++--------------- 12 files changed, 585 insertions(+), 467 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f5b4812bd..410671239 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -70,10 +70,21 @@ Backwards-incompatible changes If you wrote an :class:`extension ` that relies on methods not provided by these new types, you may need to update your code. +New features +............ + +* Added an option to receive text frames as :class:`bytes`, without decoding, + in the :mod:`threading` implementation; also binary frames as :class:`str`. + +* Added an option to send :class:`bytes` as a text frame in the :mod:`asyncio` + and :mod:`threading` implementations, as well as :class:`str` a binary frame. + Improvements ............ -* Sending or receiving large compressed frames is now faster. +* The :mod:`threading` implementation receives messages faster. + +* Sending or receiving large compressed messages is now faster. .. _13.1: @@ -198,6 +209,9 @@ New features * Validated compatibility with Python 3.12 and 3.13. +* Added an option to receive text frames as :class:`bytes`, without decoding, + in the :mod:`asyncio` implementation; also binary frames as :class:`str`. + * Added :doc:`environment variables <../reference/variables>` to configure debug logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 23b1a348a..0c8bedc5d 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -45,7 +45,7 @@ class ClientConnection(Connection): closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, - and ``write_limit`` arguments the same meaning as in :func:`connect`. + and ``write_limit`` arguments have the same meaning as in :func:`connect`. Args: protocol: Sans-I/O connection. diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index e3ec5062f..09be22ba2 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -60,6 +60,7 @@ def reset(self, items: Iterable[T]) -> None: self.queue.extend(items) def abort(self) -> None: + """Close the queue, raising EOFError in get() if necessary.""" if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_exception(EOFError("stream of frames ended")) # Clear the queue to avoid storing unnecessary data in memory. @@ -89,7 +90,7 @@ def __init__( # pragma: no cover pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, ) -> None: - # Queue of incoming messages. Each item is a queue of frames. + # Queue of incoming frames. 
self.frames: SimpleQueue[Frame] = SimpleQueue() # We cannot put a hard limit on the size of the queue because a single @@ -140,36 +141,35 @@ async def get(self, decode: bool | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # First frame + # Locking with get_in_progress prevents concurrent execution until + # get() fetches a complete message or is cancelled. + try: + # First frame frame = await self.frames.get() - except asyncio.CancelledError: - self.get_in_progress = False - raise - self.maybe_resume() - assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY - if decode is None: - decode = frame.opcode is OP_TEXT - frames = [frame] - - # Following frames, for fragmented messages - while not frame.fin: - try: - frame = await self.frames.get() - except asyncio.CancelledError: - # Put frames already received back into the queue - # so that future calls to get() can return them. - self.frames.reset(frames) - self.get_in_progress = False - raise self.maybe_resume() - assert frame.opcode is OP_CONT - frames.append(frame) - - self.get_in_progress = False + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = await self.frames.get() + except asyncio.CancelledError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.frames.reset(frames) + raise + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + finally: + self.get_in_progress = False data = b"".join(frame.data for frame in frames) if decode: @@ -207,9 +207,14 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True + # Locking with get_in_progress prevents concurrent execution until + # get_iter() fetches a complete message or is cancelled. + + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. + # First frame try: frame = await self.frames.get() diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index e11dd91f1..a6ae5996d 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -55,7 +55,7 @@ class ServerConnection(Connection): closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, - and ``write_limit`` arguments the same meaning as in :func:`serve`. + and ``write_limit`` arguments have the same meaning as in :func:`serve`. Args: protocol: Sans-I/O connection. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 5e1ba6d84..42daa32ea 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -40,10 +40,12 @@ class ClientConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``close_timeout`` and ``max_queue`` arguments have the same meaning as + in :func:`connect`. + Args: socket: Socket connected to a WebSocket server. protocol: Sans-I/O connection. 
- close_timeout: Timeout for closing the connection in seconds. """ @@ -53,6 +55,7 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -60,6 +63,7 @@ def __init__( socket, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) def handshake( @@ -135,6 +139,7 @@ def connect( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -183,6 +188,10 @@ def connect( :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -287,6 +296,7 @@ def connect( sock, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) except Exception: if sock is not None: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 3f4cac09f..3ab9f4937 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -49,10 +49,14 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol self.close_timeout = close_timeout + if isinstance(max_queue, int): + max_queue = (max_queue, None) + self.max_queue = max_queue # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( @@ -76,8 +80,15 @@ def __init__( # Mutex serializing interactions with the protocol. self.protocol_mutex = threading.Lock() + # Lock stopping reads when the assembler buffer is full. + self.recv_flow_control = threading.Lock() + # Assembler turning frames into messages and serializing reads. - self.recv_messages = Assembler() + self.recv_messages = Assembler( + *self.max_queue, + pause=self.recv_flow_control.acquire, + resume=self.recv_flow_control.release, + ) # Whether we are busy sending a fragmented message. self.send_in_progress = False @@ -88,6 +99,10 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: dict[bytes, threading.Event] = {} + # Exception raised in recv_events, to be chained to ConnectionClosed + # in the user thread in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + # Receiving events from the socket. This thread is marked as daemon to # allow creating a connection in a non-daemon thread and using it in a # daemon thread. This mustn't prevent the interpreter from exiting. @@ -97,10 +112,6 @@ def __init__( ) self.recv_events_thread.start() - # Exception raised in recv_events, to be chained to ConnectionClosed - # in the user thread in order to show why the TCP connection dropped. 
- self.recv_exc: BaseException | None = None - # Public attributes @property @@ -172,7 +183,7 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return - def recv(self, timeout: float | None = None) -> Data: + def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Receive the next message. @@ -191,6 +202,11 @@ def recv(self, timeout: float | None = None) -> Data: If the message is fragmented, wait until all fragments are received, reassemble them, and return the whole message. + Args: + timeout: Timeout for receiving a message in seconds. + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. @@ -198,6 +214,16 @@ def recv(self, timeout: float | None = None) -> Data: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or @@ -205,7 +231,7 @@ def recv(self, timeout: float | None = None) -> Data: """ try: - return self.recv_messages.get(timeout) + return self.recv_messages.get(timeout, decode) except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() @@ -216,16 +242,23 @@ def recv(self, timeout: float | None = None) -> Data: "is already running recv or recv_streaming" ) from None - def recv_streaming(self) -> Iterator[Data]: + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ Receive the next message frame by frame. - If the message is fragmented, yield each fragment as it is received. - The iterator must be fully consumed, or else the connection will become + This method is designed for receiving fragmented messages. It returns an + iterator that yields each fragment as it is received. This iterator must + be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection unusable. :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. @@ -233,6 +266,15 @@ def recv_streaming(self) -> Iterator[Data]: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. 
+ * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or @@ -240,7 +282,7 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - yield from self.recv_messages.get_iter() + yield from self.recv_messages.get_iter(decode) except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() @@ -571,8 +613,9 @@ def recv_events(self) -> None: try: while True: try: - if self.close_deadline is not None: - self.socket.settimeout(self.close_deadline.timeout()) + with self.recv_flow_control: + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: @@ -622,13 +665,9 @@ def recv_events(self) -> None: # Given that automatic responses write small amounts of data, # this should be uncommon, so we don't handle the edge case. - try: - for event in events: - # This may raise EOFError if the closing handshake - # times out while a message is waiting to be read. - self.process_event(event) - except EOFError: - break + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 997fa98df..983b114dc 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,12 +3,12 @@ import codecs import queue import threading -from collections.abc import Iterator -from typing import cast +from typing import Any, Callable, Iterable, Iterator from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data +from .utils import Deadline __all__ = ["Assembler"] @@ -20,47 +20,83 @@ class Assembler: """ Assemble messages from frames. + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + """ - def __init__(self) -> None: + def __init__( + self, + high: int = 16, + low: int | None = None, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: # Serialize reads and writes -- except for reads via synchronization # primitives provided by the threading and queue modules. self.mutex = threading.Lock() - # We create a latch with two events to synchronize the production of - # frames and the consumption of messages (or frames) without a buffer. - # This design requires a switch between the library thread and the user - # thread for each message; that shouldn't be a performance bottleneck. - - # put() sets this event to tell get() that a message can be fetched. - self.message_complete = threading.Event() - # get() sets this event to let put() that the message was fetched. - self.message_fetched = threading.Event() + # Queue of incoming frames. 
+        self.frames: queue.SimpleQueue[Frame | None] = queue.SimpleQueue()
+
+        # We cannot put a hard limit on the size of the queue because a single
+        # call to Protocol.data_received() could produce thousands of frames,
+        # which must be buffered. Instead, we pause reading when the buffer goes
+        # above the high limit and we resume when it goes under the low limit.
+        if low is None:
+            low = high // 4
+        if low < 0:
+            raise ValueError("low must be positive or equal to zero")
+        if high < low:
+            raise ValueError("high must be greater than or equal to low")
+        self.high, self.low = high, low
+        self.pause = pause
+        self.resume = resume
+        self.paused = False

         # This flag prevents concurrent calls to get() by user code.
         self.get_in_progress = False
-        # This flag prevents concurrent calls to put() by library code.
-        self.put_in_progress = False
-
-        # Decoder for text frames, None for binary frames.
-        self.decoder: codecs.IncrementalDecoder | None = None
-
-        # Buffer of frames belonging to the same message.
-        self.chunks: list[Data] = []
-
-        # When switching from "buffering" to "streaming", we use a thread-safe
-        # queue for transferring frames from the writing thread (library code)
-        # to the reading thread (user code). We're buffering when chunks_queue
-        # is None and streaming when it's a SimpleQueue. None is a sentinel
-        # value marking the end of the message, superseding message_complete.
-
-        # Stream data from frames belonging to the same message.
-        self.chunks_queue: queue.SimpleQueue[Data | None] | None = None

         # This flag marks the end of the connection.
         self.closed = False

-    def get(self, timeout: float | None = None) -> Data:
+    def get_next_frame(self, timeout: float | None = None) -> Frame:
+        # Helper to factor out the logic for getting the next frame from the
+        # queue, while handling timeouts and reaching the end of the stream.
+        try:
+            frame = self.frames.get(timeout=timeout)
+        except queue.Empty:
+            raise TimeoutError(f"timed out in {timeout:.1f}s") from None
+        if frame is None:
+            raise EOFError("stream of frames ended")
+        return frame
+
+    def reset_queue(self, frames: Iterable[Frame]) -> None:
+        # Helper to put frames back into the queue after they were fetched.
+        # This happens only when the queue is empty. However, by the time
+        # we acquire self.mutex, put() may have added items in the queue.
+        # Therefore, we must handle the case where the queue is not empty.
+        frame: Frame | None
+        with self.mutex:
+            queued = []
+            try:
+                while True:
+                    queued.append(self.frames.get_nowait())
+            except queue.Empty:
+                pass
+            for frame in frames:
+                self.frames.put(frame)
+            # This loop runs only when a race condition occurs.
+            for frame in queued:  # pragma: no cover
+                self.frames.put(frame)
+
+    def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
         """
         Read the next message.

@@ -73,11 +109,14 @@ def get(self, timeout: float | None = None) -> Data:
         Args:
             timeout: If a timeout is provided and elapses before a complete
                 message is received, :meth:`get` raises :exc:`TimeoutError`.
+            decode: :obj:`False` disables UTF-8 decoding of text frames and
+                returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
+                binary frames and returns :class:`str`.

         Raises:
             EOFError: If the stream of frames has ended.
-            ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter`
-                concurrently.
+            ConcurrencyError: If two threads run :meth:`get` or
+                :meth:`get_iter` concurrently.
             TimeoutError: If a timeout is provided and elapses before a complete
                 message is received.
@@ -89,40 +128,45 @@ def get(self, timeout: float | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") + # Locking with get_in_progress ensures only one thread can get here. self.get_in_progress = True - # If the message_complete event isn't set yet, release the lock to - # allow put() to run and eventually set it. - # Locking with get_in_progress ensures only one thread can get here. - completed = self.message_complete.wait(timeout) + try: + deadline = Deadline(timeout) + + # First frame + frame = self.get_next_frame(deadline.timeout()) + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = self.get_next_frame(deadline.timeout()) + except TimeoutError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.reset_queue(frames) + raise + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) - with self.mutex: + finally: self.get_in_progress = False - # Waiting for a complete message timed out. - if not completed: - raise TimeoutError(f"timed out in {timeout:.1f}s") - - # get() was unblocked by close() rather than put(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_complete.is_set() - self.message_complete.clear() - - joiner: Data = b"" if self.decoder is None else "" - # mypy cannot figure out that chunks have the proper type. - message: Data = joiner.join(self.chunks) # type: ignore + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data - self.chunks = [] - assert self.chunks_queue is None - - assert not self.message_fetched.is_set() - self.message_fetched.set() - - return message - - def get_iter(self) -> Iterator[Data]: + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ Stream the next message. @@ -135,10 +179,15 @@ def get_iter(self) -> Iterator[Data]: This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. """ with self.mutex: @@ -148,116 +197,81 @@ def get_iter(self) -> Iterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - chunks = self.chunks - self.chunks = [] - self.chunks_queue = cast( - # Remove quotes around type when dropping Python < 3.10. - "queue.SimpleQueue[Data | None]", - queue.SimpleQueue(), - ) - - # Sending None in chunk_queue supersedes setting message_complete - # when switching to "streaming". If message is already complete - # when the switch happens, put() didn't send None, so we have to. - if self.message_complete.is_set(): - self.chunks_queue.put(None) - + # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # Locking with get_in_progress ensures only one thread can get here. 
- chunk: Data | None - for chunk in chunks: - yield chunk - while (chunk := self.chunks_queue.get()) is not None: - yield chunk + # Locking with get_in_progress prevents concurrent execution until + # get_iter() fetches a complete message or is cancelled. - with self.mutex: - self.get_in_progress = False + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. - # get_iter() was unblocked by close() rather than put(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_complete.is_set() - self.message_complete.clear() - - assert self.chunks == [] - self.chunks_queue = None + # First frame + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + # Following frames, for fragmented messages + while not frame.fin: + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data - assert not self.message_fetched.is_set() - self.message_fetched.set() + self.get_in_progress = False def put(self, frame: Frame) -> None: """ Add ``frame`` to the next message. - When ``frame`` is the final frame in a message, :meth:`put` waits until - the message is fetched, which can be achieved by calling :meth:`get` or - by fully consuming the return value of :meth:`get_iter`. - - :meth:`put` assumes that the stream of frames respects the protocol. If - it doesn't, the behavior is undefined. - Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`put` concurrently. """ with self.mutex: if self.closed: raise EOFError("stream of frames ended") - if self.put_in_progress: - raise ConcurrencyError("put is already running") - - if frame.opcode is OP_TEXT: - self.decoder = UTF8Decoder(errors="strict") - elif frame.opcode is OP_BINARY: - self.decoder = None - else: - assert frame.opcode is OP_CONT - - data: Data - if self.decoder is not None: - data = self.decoder.decode(frame.data, frame.fin) - else: - data = frame.data - - if self.chunks_queue is None: - self.chunks.append(data) - else: - self.chunks_queue.put(data) - - if not frame.fin: - return - - # Message is complete. Wait until it's fetched to return. - - assert not self.message_complete.is_set() - self.message_complete.set() - - if self.chunks_queue is not None: - self.chunks_queue.put(None) - - assert not self.message_fetched.is_set() - - self.put_in_progress = True - - # Release the lock to allow get() to run and eventually set the event. - # Locking with put_in_progress ensures only one coroutine can get here. - self.message_fetched.wait() - - with self.mutex: - self.put_in_progress = False - - # put() was unblocked by close() rather than get() or get_iter(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_fetched.is_set() - self.message_fetched.clear() - - self.decoder = None + self.frames.put(frame) + self.maybe_pause() + + # put() and get/get_iter() call maybe_pause() and maybe_resume() while + # holding self.mutex. This guarantees that the calls interleave properly. 
+ # Specifically, it prevents a race condition where maybe_resume() would + # run before maybe_pause(), leaving the connection incorrectly paused. + + # A race condition is possible when get/get_iter() call self.frames.get() + # without holding self.mutex. However, it's harmless — and even beneficial! + # It can only result in popping an item from the queue before maybe_resume() + # runs and skipping a pause() - resume() cycle that would otherwise occur. + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + assert self.mutex.locked() + # Check for "> high" to support high = 0 + if self.frames.qsize() > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + assert self.mutex.locked() + # Check for "<= low" to support low = 0 + if self.frames.qsize() <= self.low and self.paused: + self.paused = False + self.resume() def close(self) -> None: """ @@ -273,12 +287,5 @@ def close(self) -> None: self.closed = True - # Unblock get or get_iter. - if self.get_in_progress: - self.message_complete.set() - if self.chunks_queue is not None: - self.chunks_queue.put(None) - - # Unblock put(). - if self.put_in_progress: - self.message_fetched.set() + # Unblock get() or get_iter(). + self.frames.put(None) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 464c4a173..94f76b658 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -51,10 +51,12 @@ class ServerConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``close_timeout`` and ``max_queue`` arguments have the same meaning as + in :func:`serve`. + Args: socket: Socket connected to a WebSocket client. protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. """ @@ -64,6 +66,7 @@ def __init__( protocol: ServerProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -71,6 +74,7 @@ def __init__( socket, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) self.username: str # see basic_auth() @@ -349,6 +353,7 @@ def serve( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -427,6 +432,10 @@ def handler(websocket): :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. 
@@ -548,6 +557,7 @@ def protocol_select_subprotocol( sock, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) except Exception: sock.close() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 563cf2b17..12e2bd5fa 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -793,8 +793,8 @@ async def test_close_timeout_waiting_for_connection_closed(self): self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) async def test_close_does_not_wait_for_recv(self): - # The asyncio implementation has a buffer for incoming messages. Closing - # the connection discards buffered messages. This is allowed by the RFC: + # Closing the connection discards messages buffered in the assembler. + # This is allowed by the RFC: # > However, there is no guarantee that the endpoint that has already # > sent a Close frame will continue to process data. await self.remote_connection.send("😀") @@ -1075,7 +1075,10 @@ async def test_max_queue(self): async def test_max_queue_tuple(self): """max_queue parameter configures high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=(4, 2)) + connection = Connection( + Protocol(self.LOCAL), + max_queue=(4, 2), + ) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) @@ -1083,14 +1086,20 @@ async def test_max_queue_tuple(self): async def test_write_limit(self): """write_limit parameter configures high-water mark of write buffer.""" - connection = Connection(Protocol(self.LOCAL), write_limit=4096) + connection = Connection( + Protocol(self.LOCAL), + write_limit=4096, + ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, None) async def test_write_limits(self): """write_limit parameter configures high and low-water marks of write buffer.""" - connection = Connection(Protocol(self.LOCAL), write_limit=(4096, 2048)) + connection = Connection( + Protocol(self.LOCAL), + write_limit=(4096, 2048), + ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index d2cf25c9c..2ff929d3a 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -350,7 +350,7 @@ async def test_cancel_get_iter_before_first_frame(self): self.assertEqual(fragments, ["café"]) async def test_cancel_get_iter_after_first_frame(self): - """get cannot be canceled after reading the first frame.""" + """get_iter cannot be canceled after reading the first frame.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) getter_task = asyncio.create_task(alist(self.assembler.get_iter())) @@ -429,7 +429,7 @@ async def test_get_fails_when_get_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" @@ -437,7 +437,7 @@ async def test_get_fails_when_get_iter_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_is_running(self): 
"""get_iter cannot be called concurrently with get.""" @@ -445,7 +445,7 @@ async def test_get_iter_fails_when_get_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_iter_is_running(self): """get_iter cannot be called concurrently.""" @@ -453,7 +453,7 @@ async def test_get_iter_fails_when_get_iter_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate # Test setting limits @@ -463,7 +463,7 @@ async def test_set_high_water_mark(self): self.assertEqual(assembler.high, 10) async def test_set_high_and_low_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water mark and low-water mark.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 87333fd35..db1cc8e93 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -154,6 +154,16 @@ def test_recv_binary(self): self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") + def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" + self.remote_connection.send("😀") + self.assertEqual(self.connection.recv(decode=False), "😀".encode()) + + def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" + self.remote_connection.send("😀".encode()) + self.assertEqual(self.connection.recv(decode=True), "😀") + def test_recv_fragmented_text(self): """recv receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) @@ -228,6 +238,22 @@ def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) + def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" + self.remote_connection.send("😀") + self.assertEqual( + list(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" + self.remote_connection.send("😀".encode()) + self.assertEqual( + list(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + def test_recv_streaming_fragmented_text(self): """recv_streaming receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) @@ -499,28 +525,17 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - def test_close_waits_for_recv(self): - # The sync implementation doesn't have a buffer for incoming messsages. - # It requires reading incoming frames until the close frame is reached. - # This behavior — close() blocks until recv() is called — is less than - # ideal and inconsistent with the asyncio implementation. + def test_close_does_not_wait_for_recv(self): + # Closing the connection discards messages buffered in the assembler. + # This is allowed by the RFC: + # > However, there is no guarantee that the endpoint that has already + # > sent a Close frame will continue to process data. 
self.remote_connection.send("😀") + self.connection.close() close_thread = threading.Thread(target=self.connection.close) close_thread.start() - # Let close() initiate the closing handshake and send a close frame. - time.sleep(MS) - self.assertTrue(close_thread.is_alive()) - - # Connection isn't closed yet. - self.connection.recv() - - # Let close() receive a close frame and finish the closing handshake. - time.sleep(MS) - self.assertFalse(close_thread.is_alive()) - - # Connection is closed now. with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -528,24 +543,6 @@ def test_close_waits_for_recv(self): self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) - def test_close_timeout_waiting_for_recv(self): - self.remote_connection.send("😀") - - close_thread = threading.Thread(target=self.connection.close) - close_thread.start() - - # Let close() time out during the closing handshake. - time.sleep(3 * MS) - self.assertFalse(close_thread.is_alive()) - - # Connection is closed now. - with self.assertRaises(ConnectionClosedError) as raised: - self.connection.recv() - - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") - self.assertIsInstance(exc.__cause__, TimeoutError) - def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() @@ -724,6 +721,45 @@ def test_pong_unsupported_type(self): with self.assertRaises(TypeError): self.connection.pong([]) + # Test parameters. + + def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + close_timeout=42 * MS, + ) + self.assertEqual(connection.close_timeout, 42 * MS) + + def test_max_queue(self): + """max_queue parameter configures high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=4, + ) + self.assertEqual(connection.recv_messages.high, 4) + + def test_max_queue_tuple(self): + """max_queue parameter configures high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=(4, 2), + ) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + # Test attributes. def test_id(self): diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index d44b39b88..02513894a 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,4 +1,6 @@ import time +import unittest +import unittest.mock from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -9,66 +11,23 @@ class AssemblerTests(ThreadTestCase): - """ - Tests in this class interact a lot with hidden synchronization mechanisms: - - - get() / get_iter() and put() must run in separate threads when a final - frame is set because put() waits for get() / get_iter() to fetch the - message before returning. - - - run_in_thread() lets its target run before yielding back control on entry, - which guarantees the intended execution order of test cases. 
- - - run_in_thread() waits for its target to finish running before yielding - back control on exit, which allows making assertions immediately. - - - When the main thread performs actions that let another thread progress, it - must wait before making assertions, to avoid depending on scheduling. - - """ - def setUp(self): - self.assembler = Assembler() - - def tearDown(self): - """ - Check that the assembler goes back to its default state after each test. - - This removes the need for testing various sequences. - - """ - self.assertFalse(self.assembler.mutex.locked()) - self.assertFalse(self.assembler.get_in_progress) - self.assertFalse(self.assembler.put_in_progress) - if not self.assembler.closed: - self.assertFalse(self.assembler.message_complete.is_set()) - self.assertFalse(self.assembler.message_fetched.is_set()) - self.assertIsNone(self.assembler.decoder) - self.assertEqual(self.assembler.chunks, []) - self.assertIsNone(self.assembler.chunks_queue) + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) # Test get def test_get_text_message_already_received(self): """get returns a text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get() self.assertEqual(message, "café") def test_get_binary_message_already_received(self): """get returns a binary message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"tea")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get() self.assertEqual(message, b"tea") def test_get_text_message_not_received_yet(self): @@ -99,112 +58,145 @@ def getter(): def test_get_fragmented_text_message_already_received(self): """get reassembles a fragmented a text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = self.assembler.get() self.assertEqual(message, "café") def test_get_fragmented_binary_message_already_received(self): """get reassembles a fragmented binary message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = self.assembler.get() self.assertEqual(message, b"tea") - def test_get_fragmented_text_message_being_received(self): - """get reassembles a fragmented text message that is partially received.""" + def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() - 
self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") - def test_get_fragmented_binary_message_being_received(self): - """get reassembles a fragmented binary message that is partially received.""" + def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") - def test_get_fragmented_text_message_not_received_yet(self): - """get reassembles a fragmented text message when it is received.""" + def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") - def test_get_fragmented_binary_message_not_received_yet(self): - """get reassembles a fragmented binary message when it is received.""" + def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") - # Test get_iter + def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + def test_get_resumes_reading(self): + """get resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + self.assembler.get() + self.resume.assert_called_once_with() + + def test_get_timeout_before_first_frame(self): + """get times out before reading the first frame.""" + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=MS) - def test_get_iter_text_message_already_received(self): - """get_iter yields a text message that is already received.""" + 
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get() + self.assertEqual(message, "café") - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + def test_get_timeout_after_first_frame(self): + """get times out after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assertEqual(fragments, ["café"]) + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=MS) - def test_get_iter_binary_message_already_received(self): - """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) - def putter(): - self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get() + self.assertEqual(message, "café") - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + # Test get_iter + def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, [b"tea"]) def test_get_iter_text_message_not_received_yet(self): @@ -212,6 +204,7 @@ def test_get_iter_text_message_not_received_yet(self): fragments = [] def getter(): + nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) @@ -225,6 +218,7 @@ def test_get_iter_binary_message_not_received_yet(self): fragments = [] def getter(): + nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) @@ -235,121 +229,112 @@ def getter(): def test_get_iter_fragmented_text_message_already_received(self): """get_iter yields a fragmented text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) - + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, ["ca", "f", "é"]) def test_get_iter_fragmented_binary_message_already_received(self): """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = self.assembler.get_iter() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + 
self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") - self.assertEqual(fragments, [b"t", b"e", b"a"]) + def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = self.assembler.get_iter() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - with self.run_in_thread(getter): - self.assertEqual(fragments, ["ca"]) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca", "f"]) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - self.assertEqual(fragments, ["ca", "f", "é"]) + iterator = self.assembler.get_iter() + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - with self.run_in_thread(getter): - self.assertEqual(fragments, [b"t"]) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t", b"e"]) - self.assembler.put(Frame(OP_CONT, b"a")) - - self.assertEqual(fragments, [b"t", b"e", b"a"]) - - def test_get_iter_fragmented_text_message_not_received_yet(self): - """get_iter yields a fragmented text message when it is received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + iterator = self.assembler.get_iter() + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") + + def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = list(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) - with self.run_in_thread(getter): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca"]) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca", "f"]) - self.assembler.put(Frame(OP_CONT, b"\xa9")) + def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + 
self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = list(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) - self.assertEqual(fragments, ["ca", "f", "é"]) + def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) - def test_get_iter_fragmented_binary_message_not_received_yet(self): - """get_iter yields a fragmented binary message when it is received.""" - fragments = [] + iterator = self.assembler.get_iter() - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + # queue is above the low-water mark + next(iterator) + self.resume.assert_not_called() - with self.run_in_thread(getter): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t"]) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t", b"e"]) - self.assembler.put(Frame(OP_CONT, b"a")) + # queue is at the low-water mark + next(iterator) + self.resume.assert_called_once_with() - self.assertEqual(fragments, [b"t", b"e", b"a"]) - - # Test timeouts + # queue is below the low-water mark + next(iterator) + self.resume.assert_called_once_with() - def test_get_with_timeout_completes(self): - """get returns a message when it is received before the timeout.""" + # Test put - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get(MS) + def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() - self.assertEqual(message, "café") + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() - def test_get_with_timeout_times_out(self): - """get raises TimeoutError when no message is received before the timeout.""" - with self.assertRaises(TimeoutError): - self.assembler.get(MS) + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() # Test termination @@ -373,18 +358,8 @@ def closer(): with self.run_in_thread(closer): with self.assertRaises(EOFError): - list(self.assembler.get_iter()) - - def test_put_fails_when_interrupted_by_close(self): - """put raises EOFError when close is called.""" - - def closer(): - time.sleep(2 * MS) - self.assembler.close() - - with self.run_in_thread(closer): - with self.assertRaises(EOFError): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + for _ in self.assembler.get_iter(): + self.fail("no fragment expected") def test_get_fails_after_close(self): """get raises EOFError after close is called.""" @@ -396,7 +371,8 @@ def test_get_iter_fails_after_close(self): """get_iter raises EOFError after close is called.""" self.assembler.close() with self.assertRaises(EOFError): - list(self.assembler.get_iter()) + for _ in self.assembler.get_iter(): + self.fail("no fragment expected") def test_put_fails_after_close(self): """put raises EOFError after close is called.""" @@ -439,13 +415,25 @@ def test_get_iter_fails_when_get_iter_is_running(self): 
list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - def test_put_fails_when_put_is_running(self): - """put cannot be called concurrently.""" + # Test setting limits - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + def test_set_high_water_mark(self): + """high sets the high-water mark.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) - with self.run_in_thread(putter): - with self.assertRaises(ConcurrencyError): - self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assembler.get() # unblock other thread + def test_set_high_and_low_water_mark(self): + """high sets the high-water mark and low-water mark.""" + assembler = Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + def test_set_invalid_low_water_mark(self): + """low must be higher than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5) From 6c9e3f48ff48dd4ad025f307e68d1be3c0687b4e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Oct 2024 14:16:06 +0200 Subject: [PATCH 21/51] Update feature list for encode/decode params. --- docs/reference/features.rst | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 8b04034eb..576ea1025 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -43,12 +43,18 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a fragmented message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Receive a fragmented message frame | ✅ | ✅ | ✅ | ❌ | + | Receive a fragmented message frame | ✅ | ✅ | — | ❌ | | by frame | | | | | +------------------------------------+--------+--------+--------+--------+ | Receive a fragmented message after | ✅ | ✅ | — | ✅ | | reassembly | | | | | +------------------------------------+--------+--------+--------+--------+ + | Force sending a message as Text or | ✅ | ✅ | — | ❌ | + | Binary | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Force receiving a message as | ✅ | ✅ | — | ❌ | + | :class:`bytes` or :class:`str` | | | | | + +------------------------------------+--------+--------+--------+--------+ | Send a ping | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Respond to pings automatically | ✅ | ✅ | ✅ | ✅ | From 8315d3cbd356fb2fbbe3fd03e189e3750a4d0399 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Oct 2024 16:18:30 +0200 Subject: [PATCH 22/51] Close the connection with code 1007 on invalid UTF-8. Fix #1523. 
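
Seen from application code, a text frame whose payload isn't valid UTF-8 now
fails the connection with close code 1007 (invalid frame payload data), and
recv() raises ConnectionClosedError, as the new tests check. A minimal sketch
of the resulting behavior; the coroutine name and URI are placeholders::

    from websockets.asyncio.client import connect
    from websockets.exceptions import ConnectionClosedError

    async def consume(uri):  # placeholder name and URI
        async with connect(uri) as websocket:
            try:
                async for message in websocket:
                    ...  # handle each message
            except ConnectionClosedError as exc:
                # exc.sent is the close frame that this side sent when it
                # failed the connection; its code is 1007 after receiving a
                # text frame that contains invalid UTF-8.
                print(exc.sent)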
--- docs/reference/features.rst | 5 +++++ src/websockets/asyncio/connection.py | 33 +++++++++++++++++++++++----- src/websockets/asyncio/messages.py | 2 ++ src/websockets/frames.py | 1 - src/websockets/sync/connection.py | 33 +++++++++++++++++++++++----- src/websockets/sync/messages.py | 2 ++ tests/asyncio/test_connection.py | 18 +++++++++++++++ tests/sync/test_connection.py | 18 +++++++++++++++ 8 files changed, 99 insertions(+), 13 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 576ea1025..9187fa505 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -194,3 +194,8 @@ connection to a given IP address in a CONNECTING state. This behavior is mandated by :rfc:`6455`, section 4.1. However, :func:`~asyncio.client.connect()` isn't the right layer for enforcing this constraint. It's the caller's responsibility. + +It is possible to send or receive a text message containing invalid UTF-8 with +``send(not_utf8_bytes, text=True)`` and ``not_utf8_bytes = recv(decode=False)`` +respectively. As a side effect of disabling UTF-8 encoding and decoding, these +options also disable UTF-8 validation. diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 3b81e386b..2568249c7 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -268,14 +268,24 @@ async def recv(self, decode: bool | None = None) -> Data: try: return await self.recv_messages.get(decode) except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - await asyncio.shield(self.connection_lost_waiter) - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv while another coroutine " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from self.recv_exc async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: """ @@ -324,15 +334,26 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data try: async for frame in self.recv_messages.get_iter(decode): yield frame + return except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - await asyncio.shield(self.connection_lost_waiter) - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv_streaming while another coroutine " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. 
+ await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from self.recv_exc async def send( self, diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 09be22ba2..b57c0ca4e 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -131,6 +131,7 @@ async def get(self, decode: bool | None = None) -> Data: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. @@ -197,6 +198,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 5fadf3c2d..0ff9f4d71 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -222,7 +222,6 @@ def parse( Raises: EOFError: If the connection is closed without a full WebSocket frame. - UnicodeDecodeError: If the frame contains invalid UTF-8. PayloadTooBig: If the frame's payload size exceeds ``max_size``. ProtocolError: If the frame contains incorrect values. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 3ab9f4937..823b44f74 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -233,14 +233,24 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data try: return self.recv_messages.get(timeout, decode) except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - self.recv_events_thread.join() - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv while another thread " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() + raise self.protocol.close_exc from self.recv_exc def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ @@ -283,15 +293,26 @@ def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ try: yield from self.recv_messages.get_iter(decode) + return except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - self.recv_events_thread.join() - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv_streaming while another thread " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. 
+ self.recv_events_thread.join() + raise self.protocol.close_exc from self.recv_exc def send( self, diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 983b114dc..17f8dce7e 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -115,6 +115,7 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. TimeoutError: If a timeout is provided and elapses before a @@ -186,6 +187,7 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 12e2bd5fa..a3b65e956 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -222,6 +222,15 @@ async def test_recv_connection_closed_error(self): with self.assertRaises(ConnectionClosedError): await self.connection.recv() + async def test_recv_non_utf8_text(self): + """recv receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + async def test_recv_during_recv(self): """recv raises ConcurrencyError when called concurrently.""" recv_task = asyncio.create_task(self.connection.recv()) @@ -352,6 +361,15 @@ async def test_recv_streaming_connection_closed_error(self): async for _ in self.connection.recv_streaming(): self.fail("did not raise") + async def test_recv_streaming_non_utf8_text(self): + """recv_streaming receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await alist(self.connection.recv_streaming()) + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + async def test_recv_streaming_during_recv(self): """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_task = asyncio.create_task(self.connection.recv()) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index db1cc8e93..abdfd3f78 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -186,6 +186,15 @@ def test_recv_connection_closed_error(self): with self.assertRaises(ConnectionClosedError): self.connection.recv() + def test_recv_non_utf8_text(self): + """recv receives a non-UTF-8 text message.""" + self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + self.connection.recv() + self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + def test_recv_during_recv(self): """recv raises ConcurrencyError when called concurrently.""" recv_thread = threading.Thread(target=self.connection.recv) @@ -286,6 +295,15 @@ def test_recv_streaming_connection_closed_error(self): for _ in self.connection.recv_streaming(): self.fail("did not raise") + def test_recv_streaming_non_utf8_text(self): + """recv_streaming receives a non-UTF-8 text message.""" + 
self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + list(self.connection.recv_streaming()) + self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + def test_recv_streaming_during_recv(self): """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_thread = threading.Thread(target=self.connection.recv) From 0d2e246f4ee44ece6f74066cfe608e0df906e312 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Oct 2024 16:36:51 +0200 Subject: [PATCH 23/51] Various fixes for the compliance test suite. --- compliance/README.rst | 17 ++++++++++++----- compliance/asyncio/client.py | 21 ++++++++++++++------- compliance/asyncio/server.py | 8 ++++++-- compliance/config/fuzzingclient.json | 2 +- compliance/config/fuzzingserver.json | 2 +- compliance/sync/client.py | 21 ++++++++++++++------- compliance/sync/server.py | 8 ++++++-- 7 files changed, 54 insertions(+), 25 deletions(-) diff --git a/compliance/README.rst b/compliance/README.rst index c7c7c93b4..ee491310f 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -38,7 +38,7 @@ To test the servers: crossbario/autobahn-testsuite \ wstest --mode fuzzingclient --spec /config/fuzzingclient.json - $ open reports/servers/index.html + $ open compliance/reports/servers/index.html To test the clients: @@ -54,7 +54,7 @@ To test the clients: $ PYTHONPATH=src python compliance/asyncio/client.py $ PYTHONPATH=src python compliance/sync/client.py - $ open reports/clients/index.html + $ open compliance/reports/clients/index.html Conformance notes ----------------- @@ -67,6 +67,13 @@ In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 websockets notices the protocol error and closes the connection at the library level before the application gets a chance to echo the previous frame. -In 6.4.3 and 6.4.4, even though it uses an incremental decoder, websockets -doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These tests -are more strict than the RFC. +In 6.4.1, 6.4.2, 6.4.3, and 6.4.4, even though it uses an incremental decoder, +websockets doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. +These tests are more strict than the RFC. + +Test case 7.1.5 fails because websockets treats closing the connection in the +middle of a fragmented message as a protocol error. As a consequence, it sends +a close frame with code 1002. The test suite expects a close frame with code +1000, echoing the close code that it sent. This isn't required. RFC 6455 states +that "the endpoint typically echos the status code it received", which leaves +the possibility to send a close frame with a different status code. 
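
The ``rcvd`` and ``sent`` attributes of ``ConnectionClosed`` record which close
frames were actually exchanged, which helps when checking this kind of behavior
by hand. A small sketch, separate from the compliance suite; the URI is a
placeholder::

    from websockets.exceptions import ConnectionClosed
    from websockets.sync.client import connect

    with connect("ws://localhost:9001/") as ws:  # placeholder URI
        try:
            for message in ws:
                ws.send(message)
        except ConnectionClosed as exc:
            # exc.rcvd and exc.sent hold the close frames that were received
            # and sent, or None. In the test case 7.1.5 scenario above, the
            # close frame sent by websockets carries code 1002.
            print("rcvd:", exc.rcvd, "sent:", exc.sent)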
diff --git a/compliance/asyncio/client.py b/compliance/asyncio/client.py index 5b0bfb3ae..044ed6043 100644 --- a/compliance/asyncio/client.py +++ b/compliance/asyncio/client.py @@ -8,7 +8,9 @@ logging.basicConfig(level=logging.WARNING) -SERVER = "ws://127.0.0.1:9001" +SERVER = "ws://localhost:9001" + +AGENT = "websockets.asyncio" async def get_case_count(): @@ -18,16 +20,21 @@ async def get_case_count(): async def run_case(case): async with connect( - f"{SERVER}/runCase?case={case}", - user_agent_header="websockets.asyncio", + f"{SERVER}/runCase?case={case}&agent={AGENT}", max_size=2**25, ) as ws: - async for msg in ws: - await ws.send(msg) + try: + async for msg in ws: + await ws.send(msg) + except WebSocketException: + pass async def update_reports(): - async with connect(f"{SERVER}/updateReports", open_timeout=60): + async with connect( + f"{SERVER}/updateReports?agent={AGENT}", + open_timeout=60, + ): pass @@ -43,7 +50,7 @@ async def main(): print(f"FAIL: {type(exc).__name__}: {exc}") else: print("OK") - print("Ran {cases} test cases") + print(f"Ran {cases} test cases") await update_reports() print("Updated reports") diff --git a/compliance/asyncio/server.py b/compliance/asyncio/server.py index cff2728c9..84deb9727 100644 --- a/compliance/asyncio/server.py +++ b/compliance/asyncio/server.py @@ -2,6 +2,7 @@ import logging from websockets.asyncio.server import serve +from websockets.exceptions import WebSocketException logging.basicConfig(level=logging.WARNING) @@ -10,8 +11,11 @@ async def echo(ws): - async for msg in ws: - await ws.send(msg) + try: + async for msg in ws: + await ws.send(msg) + except WebSocketException: + pass async def main(): diff --git a/compliance/config/fuzzingclient.json b/compliance/config/fuzzingclient.json index 37f8f3a35..756ad03b6 100644 --- a/compliance/config/fuzzingclient.json +++ b/compliance/config/fuzzingclient.json @@ -5,7 +5,7 @@ }, { "url": "ws://host.docker.internal:9003" }], - "outdir": "./reports/servers", + "outdir": "/reports/servers", "cases": ["*"], "exclude-cases": [] } diff --git a/compliance/config/fuzzingserver.json b/compliance/config/fuzzingserver.json index 1bcb5959d..384caf0a2 100644 --- a/compliance/config/fuzzingserver.json +++ b/compliance/config/fuzzingserver.json @@ -1,7 +1,7 @@ { "url": "ws://localhost:9001", - "outdir": "./reports/clients", + "outdir": "/reports/clients", "cases": ["*"], "exclude-cases": [] } diff --git a/compliance/sync/client.py b/compliance/sync/client.py index e585496f3..c810e1beb 100644 --- a/compliance/sync/client.py +++ b/compliance/sync/client.py @@ -7,7 +7,9 @@ logging.basicConfig(level=logging.WARNING) -SERVER = "ws://127.0.0.1:9001" +SERVER = "ws://localhost:9001" + +AGENT = "websockets.sync" def get_case_count(): @@ -17,16 +19,21 @@ def get_case_count(): def run_case(case): with connect( - f"{SERVER}/runCase?case={case}", - user_agent_header="websockets.sync", + f"{SERVER}/runCase?case={case}&agent={AGENT}", max_size=2**25, ) as ws: - for msg in ws: - ws.send(msg) + try: + for msg in ws: + ws.send(msg) + except WebSocketException: + pass def update_reports(): - with connect(f"{SERVER}/updateReports", open_timeout=60): + with connect( + f"{SERVER}/updateReports?agent={AGENT}", + open_timeout=60, + ): pass @@ -42,7 +49,7 @@ def main(): print(f"FAIL: {type(exc).__name__}: {exc}") else: print("OK") - print("Ran {cases} test cases") + print(f"Ran {cases} test cases") update_reports() print("Updated reports") diff --git a/compliance/sync/server.py b/compliance/sync/server.py index c3cb4d989..494f56a44 
100644 --- a/compliance/sync/server.py +++ b/compliance/sync/server.py @@ -1,5 +1,6 @@ import logging +from websockets.exceptions import WebSocketException from websockets.sync.server import serve @@ -9,8 +10,11 @@ def echo(ws): - for msg in ws: - ws.send(msg) + try: + for msg in ws: + ws.send(msg) + except WebSocketException: + pass def main(): From 6cea05e51d50455d66e90a1888aba9be8e8809db Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Oct 2024 08:06:21 +0200 Subject: [PATCH 24/51] Support HTTP response without Content-Length. Fix #1531. --- src/websockets/asyncio/connection.py | 13 +++++++++-- src/websockets/legacy/exceptions.py | 2 +- src/websockets/sync/connection.py | 13 +++++++++-- tests/asyncio/test_client.py | 30 +++++++++++++++++++++++++ tests/sync/test_client.py | 33 +++++++++++++++++++++++++++- 5 files changed, 85 insertions(+), 6 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 2568249c7..5545632d6 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -1060,13 +1060,22 @@ def eof_received(self) -> None: # Feed the end of the data stream to the connection. self.protocol.receive_eof() - # This isn't expected to generate events. - assert not self.protocol.events_received() + # This isn't expected to raise an exception. + events = self.protocol.events_received() # There is no error handling because send_data() can only write # the end of the data stream here and it shouldn't raise errors. self.send_data() + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + # The WebSocket protocol has its own closing handshake: endpoints close # the TCP or TLS connection after sending and receiving a close frame. # As a consequence, they never need to write after receiving EOF, so diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index 9ca9b7aff..e2279c825 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -50,7 +50,7 @@ def __init__( headers: datastructures.HeadersLike, body: bytes = b"", ) -> None: - # If a user passes an int instead of a HTTPStatus, fix it automatically. + # If a user passes an int instead of an HTTPStatus, fix it automatically. self.status = http.HTTPStatus(status) self.headers = datastructures.Headers(headers) self.body = body diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 823b44f74..8d1dbcf58 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -696,13 +696,22 @@ def recv_events(self) -> None: # Feed the end of the data stream to the protocol. self.protocol.receive_eof() - # This isn't expected to generate events. - assert not self.protocol.events_received() + # This isn't expected to raise an exception. + events = self.protocol.events_received() # There is no error handling because send_data() can only write # the end of the data stream here and it handles errors itself. self.send_data() + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. 
This is the only case where + # reading until EOF generates an event; all other events have + # a known length. Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + except Exception as exc: # This branch should never run. It's a safety net in case of bugs. self.logger.error("unexpected internal error", exc_info=True) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 9354a6e0a..1b89977ea 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -401,6 +401,36 @@ def close_connection(self, request): "connection closed while reading HTTP status line", ) + async def test_http_response(self): + """Client reads HTTP response.""" + + def http_response(connection, request): + return connection.respond(http.HTTPStatus.OK, "👌") + + async with serve(*args, process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + async def test_http_response_without_content_length(self): + """Client reads HTTP response without a Content-Length header.""" + + def http_response(connection, request): + response = connection.respond(http.HTTPStatus.OK, "👌") + del response.headers["Content-Length"] + return response + + async with serve(*args, process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + async def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index e63d774b7..e9b0f63ad 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,3 +1,4 @@ +import http import logging import socket import socketserver @@ -6,7 +7,7 @@ import time import unittest -from websockets.exceptions import InvalidHandshake, InvalidURI +from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * @@ -156,6 +157,36 @@ def close_connection(self, request): "connection closed while reading HTTP status line", ) + def test_http_response(self): + """Client reads HTTP response.""" + + def http_response(connection, request): + return connection.respond(http.HTTPStatus.OK, "👌") + + with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + def test_http_response_without_content_length(self): + """Client reads HTTP response without a Content-Length header.""" + + def http_response(connection, request): + response = connection.respond(http.HTTPStatus.OK, "👌") + del response.headers["Content-Length"] + return response + + with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server)): + 
self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" From c75b1df159d83e46a2ae29069ae92789690ed22f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Oct 2024 09:22:17 +0200 Subject: [PATCH 25/51] Mention the option to keep the legacy implementation. --- docs/project/changelog.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 410671239..1b3b0073c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -52,6 +52,12 @@ Backwards-incompatible changes If you're using any of them, then you must follow the :doc:`upgrade guide <../howto/upgrade>` immediately. + Alternatively, you may stick to the legacy :mod:`asyncio` implementation for + now by importing it explicitly:: + + from websockets.legacy.client import connect, unix_connect + from websockets.legacy.server import broadcast, serve, unix_serve + .. admonition:: The legacy :mod:`asyncio` implementation is now deprecated. :class: caution From d248160b6b509c5701fd749cd2ef103d244c7631 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Oct 2024 09:22:52 +0200 Subject: [PATCH 26/51] Raise ValueError for required or unacceptable arguments. This improves consistency with asyncio and within websockets. --- docs/project/changelog.rst | 13 +++++++++++++ src/websockets/asyncio/client.py | 4 ++-- src/websockets/asyncio/server.py | 4 +++- src/websockets/sync/client.py | 6 +++--- src/websockets/sync/server.py | 8 +++++--- tests/asyncio/test_client.py | 4 ++-- tests/asyncio/test_server.py | 4 ++-- tests/sync/test_client.py | 6 +++--- tests/sync/test_server.py | 8 ++++---- 9 files changed, 37 insertions(+), 20 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 1b3b0073c..576b7252e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -76,6 +76,19 @@ Backwards-incompatible changes If you wrote an :class:`extension ` that relies on methods not provided by these new types, you may need to update your code. +.. admonition:: Several API raise :exc:`ValueError` instead of :exc:`TypeError` + on invalid arguments. + :class: note + + :func:`~asyncio.client.connect`, :func:`~asyncio.client.unix_connect`, and + :func:`~asyncio.server.basic_auth` in the :mod:`asyncio` implementation as + well as :func:`~sync.client.connect`, :func:`~sync.client.unix_connect`, + :func:`~sync.server.serve`, :func:`~sync.server.unix_serve`, and + :func:`~sync.server.basic_auth` in the :mod:`threading` implementation now + raise :exc:`ValueError` when a required argument isn't provided or an + argument that is incompatible with others is provided. + + New features ............ 
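
Both incompatible and missing arguments are rejected during argument
validation, before any I/O happens, so the new behavior can be checked without
a server. A small sketch based on the code and tests in this patch::

    import ssl

    from websockets.sync.client import connect
    from websockets.sync.server import basic_auth

    try:
        connect("ws://localhost/", ssl=ssl.create_default_context())
    except ValueError as exc:
        print(exc)  # ssl argument is incompatible with a ws:// URI

    try:
        basic_auth()
    except ValueError as exc:
        print(exc)  # provide either credentials or check_credentials

The :mod:`asyncio` implementation raises the same exceptions for the same
arguments.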
diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 0c8bedc5d..ff7916d39 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -352,10 +352,10 @@ def factory() -> ClientConnection: kwargs.setdefault("ssl", True) kwargs.setdefault("server_hostname", wsuri.host) if kwargs.get("ssl") is None: - raise TypeError("ssl=None is incompatible with a wss:// URI") + raise ValueError("ssl=None is incompatible with a wss:// URI") else: if kwargs.get("ssl") is not None: - raise TypeError("ssl argument is incompatible with a ws:// URI") + raise ValueError("ssl argument is incompatible with a ws:// URI") if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index a6ae5996d..180d3a5a9 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -890,10 +890,12 @@ def basic_auth( whether they're valid. Raises: TypeError: If ``credentials`` or ``check_credentials`` is wrong. + ValueError: If ``credentials`` and ``check_credentials`` are both + provided or both not provided. """ if (credentials is None) == (check_credentials is None): - raise TypeError("provide either credentials or check_credentials") + raise ValueError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 42daa32ea..0aada658e 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -221,7 +221,7 @@ def connect( wsuri = parse_uri(uri) if not wsuri.secure and ssl is not None: - raise TypeError("ssl argument is incompatible with a ws:// URI") + raise ValueError("ssl argument is incompatible with a ws:// URI") # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) @@ -229,9 +229,9 @@ def connect( if unix: if path is None and sock is None: - raise TypeError("missing path argument") + raise ValueError("missing path argument") elif path is not None and sock is not None: - raise TypeError("path and sock arguments are incompatible") + raise ValueError("path and sock arguments are incompatible") if subprotocols is not None: validate_subprotocols(subprotocols) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 94f76b658..44dbd7290 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -477,14 +477,14 @@ def handler(websocket): if sock is None: if unix: if path is None: - raise TypeError("missing path argument") + raise ValueError("missing path argument") kwargs.setdefault("family", socket.AF_UNIX) sock = socket.create_server(path, **kwargs) else: sock = socket.create_server((host, port), **kwargs) else: if path is not None: - raise TypeError("path and sock arguments are incompatible") + raise ValueError("path and sock arguments are incompatible") # Initialize TLS wrapper @@ -667,10 +667,12 @@ def basic_auth( whether they're valid. Raises: TypeError: If ``credentials`` or ``check_credentials`` is wrong. + ValueError: If ``credentials`` and ``check_credentials`` are both + provided or both not provided. 
""" if (credentials is None) == (check_credentials is None): - raise TypeError("provide either credentials or check_credentials") + raise ValueError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 1b89977ea..231d6b8ca 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -607,7 +607,7 @@ async def test_set_server_hostname(self): class ClientUsageErrorsTests(unittest.IsolatedAsyncioTestCase): async def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: await connect("ws://localhost/", ssl=CLIENT_CONTEXT) self.assertEqual( str(raised.exception), @@ -616,7 +616,7 @@ async def test_ssl_without_secure_uri(self): async def test_secure_uri_without_ssl(self): """Client rejects no ssl when URI is secure.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: await connect("wss://localhost/", ssl=None) self.assertEqual( str(raised.exception), diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 47e0148a6..1dcb8c7b7 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -717,7 +717,7 @@ async def check_credentials(username, password): async def test_without_credentials_or_check_credentials(self): """basic_auth requires either credentials or check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth() self.assertEqual( str(raised.exception), @@ -726,7 +726,7 @@ async def test_without_credentials_or_check_credentials(self): async def test_with_credentials_and_check_credentials(self): """basic_auth requires only one of credentials and check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth( credentials=("hello", "iloveyou"), check_credentials=lambda: False, # pragma: no cover diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index e9b0f63ad..9d457a912 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -311,7 +311,7 @@ def test_set_server_hostname(self): class ClientUsageErrorsTests(unittest.TestCase): def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: connect("ws://localhost/", ssl=CLIENT_CONTEXT) self.assertEqual( str(raised.exception), @@ -320,7 +320,7 @@ def test_ssl_without_secure_uri(self): def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_connect() self.assertEqual( str(raised.exception), @@ -331,7 +331,7 @@ def test_unix_with_path_and_sock(self): """Unix client rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_connect(path="/", sock=sock) self.assertEqual( str(raised.exception), diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 3bc6f76cd..54e49bf16 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -361,7 +361,7 @@ def test_connection(self): 
class ServerUsageErrorsTests(unittest.TestCase): def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_serve(handler) self.assertEqual( str(raised.exception), @@ -372,7 +372,7 @@ def test_unix_with_path_and_sock(self): """Unix server rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_serve(handler, path="/", sock=sock) self.assertEqual( str(raised.exception), @@ -504,7 +504,7 @@ def check_credentials(username, password): def test_without_credentials_or_check_credentials(self): """basic_auth requires either credentials or check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth() self.assertEqual( str(raised.exception), @@ -513,7 +513,7 @@ def test_without_credentials_or_check_credentials(self): def test_with_credentials_and_check_credentials(self): """basic_auth requires only one of credentials and check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth( credentials=("hello", "iloveyou"), check_credentials=lambda: False, # pragma: no cover From 3b2c5223f7213b130d1469535fccbf9c3d08c4a9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Oct 2024 07:59:16 +0100 Subject: [PATCH 27/51] Rewrite tips for Sans-I/O integration. --- docs/howto/sansio.rst | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index d41519ff0..ca530e6a1 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -302,21 +302,24 @@ Tips Serialize operations .................... -The Sans-I/O layer expects to run sequentially. If your interact with it from -multiple threads or coroutines, you must ensure correct serialization. This -should happen automatically in a cooperative multitasking environment. +The Sans-I/O layer is designed to run sequentially. If you interact with it from +multiple threads or coroutines, you must ensure correct serialization. -However, you still have to make sure you don't break this property by -accident. For example, serialize writes to the network -when :meth:`~protocol.Protocol.data_to_send` returns multiple values to -prevent concurrent writes from interleaving incorrectly. +Usually, this comes for free in a cooperative multitasking environment. In a +preemptive multitasking environment, it requires mutual exclusion. -Avoid buffers -............. +Furthermore, you must serialize writes to the network. When +:meth:`~protocol.Protocol.data_to_send` returns several values, you must write +them all before starting the next write. -The Sans-I/O layer doesn't do any buffering. It makes events available in +Minimize buffers +................ + +The Sans-I/O layer doesn't perform any buffering. It makes events available in :meth:`~protocol.Protocol.events_received` as soon as they're received. -You should make incoming messages available to the application immediately and -stop further processing until the application fetches them. This will usually -result in the best performance. +You should make incoming messages available to the application immediately. 
+ +A small buffer of incoming messages will usually result in the best performance. +It will reduce context switching between the library and the application while +ensuring that backpressure is propagated. From 018d2e5cf56ff03690bcbf271d188e76d59c62f3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Oct 2024 07:59:35 +0100 Subject: [PATCH 28/51] Explain application-level keepalives. Fix #1514. --- docs/topics/keepalive.rst | 40 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 458fa3d05..003087fad 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -90,6 +90,46 @@ application layer. Read this `blog post `_ for a complete walk-through of this issue. +Application-level keepalive +--------------------------- + +Some servers require clients to send a keepalive message with a specific content +at regular intervals. Usually they expect Text_ frames rather than Ping_ frames, +meaning that you must send them with :attr:`~asyncio.connection.Connection.send` +rather than :attr:`~asyncio.connection.Connection.ping`. + +.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + +In websockets, such keepalive mechanisms are considered as application-level +because they rely on data frames. That's unlike the protocol-level keepalive +based on control frames. Therefore, it's your responsibility to implement the +required behavior. + +You can run a task in the background to send keepalive messages: + +.. code-block:: python + + import itertools + import json + + from websockets import ConnectionClosed + + async def keepalive(websocket, ping_interval=30): + for ping in itertools.count(): + await asyncio.sleep(ping_interval) + try: + await websocket.send(json.dumps({"ping": ping})) + except ConnectionClosed: + break + + async def main(): + async with connect(...) as websocket: + keepalive_task = asyncio.create_task(keepalive(websocket)) + try: + ... # your application logic goes here + finally: + keepalive_task.cancel() + Latency issues -------------- From e44d79559d16661df4709c1b54150d735f85ae54 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 1 Nov 2024 10:52:16 +0000 Subject: [PATCH 29/51] Bump pypa/cibuildwheel from 2.20.0 to 2.21.3 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.20.0 to 2.21.3. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.20.0...v2.21.3) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 863d88aa9..cc26502ca 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -53,7 +53,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.20.0 + uses: pypa/cibuildwheel@v2.21.3 env: BUILD_EXTENSION: yes - name: Save wheels From 810bdeb7943e4cc0dcb662ef27ea764e31740e05 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 1 Nov 2024 12:35:38 +0100 Subject: [PATCH 30/51] Report size correctly in PayloadTooBig. Previously, it was reported incorrectly for fragmented messages. 
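For illustration (a sketch based on the new signature and the test
expectations added below, not part of the change itself), the exception now
reports the frame size, the bytes already read from previous fragments, and
the effective limit:

    from websockets.exceptions import PayloadTooBig

    print(PayloadTooBig(8, 4))      # frame with 8 bytes exceeds limit of 4 bytes
    print(PayloadTooBig(8, 4, 12))  # frame with 8 bytes after reading 12 bytes
                                    # exceeds limit of 16 bytes (4 + 12)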
Fix #1522. --- docs/project/changelog.rst | 31 +++++--- src/websockets/exceptions.py | 41 +++++++++++ .../extensions/permessage_deflate.py | 3 +- src/websockets/frames.py | 2 +- src/websockets/legacy/framing.py | 2 +- src/websockets/protocol.py | 5 +- tests/test_exceptions.py | 20 +++++- tests/test_protocol.py | 72 ++++++++++++++----- 8 files changed, 143 insertions(+), 33 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 576b7252e..71b2a6960 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -67,15 +67,6 @@ Backwards-incompatible changes Aliases for deprecated API were removed from ``__all__``. As a consequence, they cannot be imported e.g. with ``from websockets import *`` anymore. -.. admonition:: :attr:`Frame.data ` is now a bytes-like object. - :class: note - - In addition to :class:`bytes`, it may be a :class:`bytearray` or a - :class:`memoryview`. - - If you wrote an :class:`extension ` that relies on - methods not provided by these new types, you may need to update your code. - .. admonition:: Several API raise :exc:`ValueError` instead of :exc:`TypeError` on invalid arguments. :class: note @@ -88,6 +79,26 @@ Backwards-incompatible changes raise :exc:`ValueError` when a required argument isn't provided or an argument that is incompatible with others is provided. +.. admonition:: :attr:`Frame.data ` is now a bytes-like object. + :class: note + + In addition to :class:`bytes`, it may be a :class:`bytearray` or a + :class:`memoryview`. + + If you wrote an :class:`extension ` that relies on + methods not provided by these new types, you may need to update your code. + +.. admonition:: The signature of :exc:`~exceptions.PayloadTooBig` changed. + :class: note + + If you wrote an extension that raises :exc:`~exceptions.PayloadTooBig` in + :meth:`~extensions.Extension.decode`, for example, you must replace:: + + PayloadTooBig(f"over size limit ({size} > {max_size} bytes)") + + with:: + + PayloadTooBig(size, max_size) New features ............ @@ -105,6 +116,8 @@ Improvements * Sending or receiving large compressed messages is now faster. +* Errors when a fragmented message is too large are clearer. + .. 
_13.1: 13.1 diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 7681736a4..be3d1ca5f 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -334,6 +334,47 @@ class PayloadTooBig(WebSocketException): """ + def __init__( + self, + size_or_message: int | None | str, + max_size: int | None = None, + cur_size: int | None = None, + ) -> None: + if isinstance(size_or_message, str): + assert max_size is None + assert cur_size is None + warnings.warn( # deprecated in 14.0 + "PayloadTooBig(message) is deprecated; " + "change to PayloadTooBig(size, max_size)", + DeprecationWarning, + ) + self.message: str | None = size_or_message + else: + self.message = None + self.size: int | None = size_or_message + assert max_size is not None + self.max_size: int = max_size + self.cur_size: int | None = None + self.set_current_size(cur_size) + + def __str__(self) -> str: + if self.message is not None: + return self.message + else: + message = "frame " + if self.size is not None: + message += f"with {self.size} bytes " + if self.cur_size is not None: + message += f"after reading {self.cur_size} bytes " + message += f"exceeds limit of {self.max_size} bytes" + return message + + def set_current_size(self, cur_size: int | None) -> None: + assert self.cur_size is None + if cur_size is not None: + self.max_size += cur_size + self.cur_size = cur_size + class InvalidState(WebSocketException, AssertionError): """ diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index ed16937d8..cefad4f56 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -139,7 +139,8 @@ def decode( try: data = self.decoder.decompress(data, max_length) if self.decoder.unconsumed_tail: - raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") + assert max_size is not None # help mypy + raise PayloadTooBig(None, max_size) if frame.fin and len(frame.data) >= 2044: # This cannot generate additional data. 
self.decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 0ff9f4d71..7898c8a5d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -252,7 +252,7 @@ def parse( data = yield from read_exact(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") + raise PayloadTooBig(length, max_size) if mask: mask_bytes = yield from read_exact(4) diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 4ec194ed7..add0c6e0e 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -93,7 +93,7 @@ async def read( data = await reader(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") + raise PayloadTooBig(length, max_size) if mask: mask_bits = await reader(4) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 091b4a23a..19b813526 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -587,6 +587,7 @@ def parse(self) -> Generator[None]: self.parser_exc = exc except PayloadTooBig as exc: + exc.set_current_size(self.cur_size) self.fail(CloseCode.MESSAGE_TOO_BIG, str(exc)) self.parser_exc = exc @@ -639,9 +640,7 @@ def recv_frame(self, frame: Frame) -> None: if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: if self.cur_size is not None: raise ProtocolError("expected a continuation frame") - if frame.fin: - self.cur_size = None - else: + if not frame.fin: self.cur_size = len(frame.data) elif frame.opcode is OP_CONT: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 8d41bf915..fef41d136 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -160,8 +160,16 @@ def test_str(self): "invalid opcode: 7", ), ( - PayloadTooBig("payload length exceeds limit: 2 > 1 bytes"), - "payload length exceeds limit: 2 > 1 bytes", + PayloadTooBig(None, 4), + "frame exceeds limit of 4 bytes", + ), + ( + PayloadTooBig(8, 4), + "frame with 8 bytes exceeds limit of 4 bytes", + ), + ( + PayloadTooBig(8, 4, 12), + "frame with 8 bytes after reading 12 bytes exceeds limit of 16 bytes", ), ( InvalidState("WebSocket connection isn't established yet"), @@ -202,3 +210,11 @@ def test_connection_closed_attributes_deprecation_defaults(self): "use Protocol.close_reason or ConnectionClosed.rcvd.reason" ): self.assertEqual(exception.reason, "") + + def test_payload_too_big_with_message(self): + with self.assertDeprecationWarning( + "PayloadTooBig(message) is deprecated; " + "change to PayloadTooBig(size, max_size)", + ): + exc = PayloadTooBig("payload length exceeds limit: 2 > 1 bytes") + self.assertEqual(str(exc), "payload length exceeds limit: 2 > 1 bytes") diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 7f1276bb2..0ae804bb3 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -265,18 +265,28 @@ def test_client_receives_text_over_size_limit(self): client = Protocol(CLIENT, max_size=3) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + client, + 
CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_server_receives_text_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_client_receives_text_without_size_limit(self): @@ -363,9 +373,14 @@ def test_client_receives_fragmented_text_over_size_limit(self): ) client.receive_data(b"\x80\x02\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_server_receives_fragmented_text_over_size_limit(self): @@ -377,9 +392,14 @@ def test_server_receives_fragmented_text_over_size_limit(self): ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_client_receives_fragmented_text_without_size_limit(self): @@ -533,18 +553,28 @@ def test_client_receives_binary_over_size_limit(self): client = Protocol(CLIENT, max_size=3) client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_server_receives_binary_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_client_sends_fragmented_binary(self): @@ -615,9 +645,14 @@ def test_client_receives_fragmented_binary_over_size_limit(self): ) client.receive_data(b"\x80\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + 
str(client.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_server_receives_fragmented_binary_over_size_limit(self): @@ -629,9 +664,14 @@ def test_server_receives_fragmented_binary_over_size_limit(self): ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_client_sends_unexpected_binary(self): From cdeb882865145399ee0fb7d0e7623418916d6b78 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Nov 2024 08:48:53 +0100 Subject: [PATCH 31/51] Don't log an error when process_request returns a response. Fix #1513. --- src/websockets/asyncio/client.py | 6 +- src/websockets/asyncio/server.py | 17 ++-- src/websockets/protocol.py | 29 ++++-- src/websockets/server.py | 6 -- src/websockets/sync/client.py | 6 +- src/websockets/sync/server.py | 11 ++- tests/asyncio/test_connection.py | 2 +- tests/asyncio/test_server.py | 146 ++++++++++++++++++++----------- tests/test_protocol.py | 20 ++++- 9 files changed, 163 insertions(+), 80 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index ff7916d39..d276ac171 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -95,9 +95,9 @@ async def handshake( return_when=asyncio.FIRST_COMPLETED, ) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a response, when the response cannot be parsed, or - # when the response fails the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 180d3a5a9..15c9ba13e 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -192,10 +192,13 @@ async def handshake( self.protocol.send_response(self.response) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, when - # the handshake encounters an error, or when process_request or - # process_response sends an HTTP response that rejects the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. 
if self.protocol.handshake_exc is not None:
            raise self.protocol.handshake_exc
 
@@ -360,7 +363,11 @@ async def conn_handler(self, connection: ServerConnection) -> None:
                 connection.close_transport()
                 return
 
-            assert connection.protocol.state is OPEN
+            if connection.protocol.state is not OPEN:
+                # process_request or process_response rejected the handshake.
+                connection.close_transport()
+                return
+
             try:
                 connection.start_keepalive()
                 await self.handler(connection)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index 19b813526..0f6fea250 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -518,15 +518,34 @@ def close_expected(self) -> bool:
         Whether the TCP connection is expected to close soon.
 
         """
-        # We expect a TCP close if and only if we sent a close frame:
+        # During the opening handshake, when our state is CONNECTING, we expect
+        # a TCP close if and only if the handshake fails. When it does, we start
+        # the TCP closing handshake by sending EOF with send_eof().
+
+        # Once the opening handshake completes successfully, we expect a TCP
+        # close if and only if we sent a close frame, meaning that our state
+        # progressed to CLOSING:
+
         # * Normal closure: once we send a close frame, we expect a TCP close:
         #   server waits for client to complete the TCP closing handshake;
         #   client waits for server to initiate the TCP closing handshake.
+
         # * Abnormal closure: we always send a close frame and the same logic
         #   applies, except on EOFError where we don't send a close frame
         #   because we already received the TCP close, so we don't expect it.
-        # We already got a TCP Close if and only if the state is CLOSED.
-        return self.state is CLOSING or self.handshake_exc is not None
+
+        # If our state is CLOSED, we already received a TCP close so we don't
+        # expect it anymore.
+
+        # Micro-optimization: put the most common case first
+        if self.state is OPEN:
+            return False
+        if self.state is CLOSING:
+            return True
+        if self.state is CLOSED:
+            return False
+        assert self.state is CONNECTING
+        return self.eof_sent
 
     # Private methods for receiving data.
 
@@ -616,14 +635,14 @@ def discard(self) -> Generator[None]:
         # connection in the same circumstances where discard() replaces parse().
         # The client closes it when it receives EOF from the server or times
         # out. (The latter case cannot be handled in this Sans-I/O layer.)
-        assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent)
+        assert (self.side is SERVER or self.state is CONNECTING) == (self.eof_sent)
         while not (yield from self.reader.at_eof()):
             self.reader.discard()
         if self.debug:
             self.logger.debug("< EOF")
         # A server closes the TCP connection immediately, while a client
         # waits for the server to close the TCP connection.
-        if self.state != CONNECTING and self.side is CLIENT:
+        if self.side is CLIENT and self.state is not CONNECTING:
             self.send_eof()
         self.state = CLOSED
         # If discard() completes normally, execution ends here.
diff --git a/src/websockets/server.py b/src/websockets/server.py
index 527db8990..e3fdcc646 100644
--- a/src/websockets/server.py
+++ b/src/websockets/server.py
@@ -14,7 +14,6 @@
     InvalidHeader,
     InvalidHeaderValue,
     InvalidOrigin,
-    InvalidStatus,
     InvalidUpgrade,
     NegotiationError,
 )
@@ -536,11 +535,6 @@ def send_response(self, response: Response) -> None:
             self.logger.info("connection open")
 
         else:
-            # handshake_exc may be already set if accept() encountered an error.
- # If the connection isn't open, set handshake_exc to guarantee that - # handshake_exc is None if and only if opening handshake succeeded. - if self.handshake_exc is None: - self.handshake_exc = InvalidStatus(response) self.logger.info( "connection rejected (%d %s)", response.status_code, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 0aada658e..54d0aef68 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -87,9 +87,9 @@ def handshake( if not self.response_rcvd.wait(timeout): raise TimeoutError("timed out during handshake") - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a response, when the response cannot be parsed, or - # when the response fails the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 44dbd7290..8601ccef9 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -170,10 +170,13 @@ def handshake( self.protocol.send_response(self.response) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, when - # the handshake encounters an error, or when process_request or - # process_response sends an HTTP response that rejects the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index a3b65e956..c98765d80 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -51,7 +51,7 @@ async def asyncTearDown(self): if sys.version_info[:2] < (3, 10): # pragma: no cover @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): + def assertNoLogs(self, logger=None, level=None): """ No message is logged on the given logger with at least the given level. 
diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 1dcb8c7b7..c817f5ef6 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -148,14 +148,17 @@ def process_request(ws, request): async def handler(ws): self.fail("handler must not run") - async with serve(handler, *args[1:], process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 403", - ) + with self.assertNoLogs("websockets", logging.ERROR): + async with serve( + handler, *args[1:], process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) async def test_async_process_request_returns_response(self): """Server aborts handshake if async process_request returns a response.""" @@ -166,44 +169,65 @@ async def process_request(ws, request): async def handler(ws): self.fail("handler must not run") - async with serve(handler, *args[1:], process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 403", - ) + with self.assertNoLogs("websockets", logging.ERROR): + async with serve( + handler, *args[1:], process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) async def test_process_request_raises_exception(self): """Server returns an error if process_request raises an exception.""" def process_request(ws, request): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_async_process_request_raises_exception(self): """Server returns an error if async process_request raises an exception.""" async def process_request(ws, request): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, 
process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_process_response_returns_none(self): """Server runs process_response but keeps the handshake response.""" @@ -277,31 +301,49 @@ async def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" def process_response(ws, request, response): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_response=process_response) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_async_process_response_raises_exception(self): """Server returns an error if async process_response raises an exception.""" async def process_response(ws, request, response): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_response=process_response) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_override_server(self): """Server can override Server header with server_header.""" diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 0ae804bb3..1c092459d 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -20,7 +20,7 @@ Frame, ) from websockets.protocol import * -from websockets.protocol import CLIENT, CLOSED, CLOSING, SERVER +from websockets.protocol import CLIENT, CLOSED, CLOSING, CONNECTING, SERVER from .extensions.utils import Rsv2Extension from .test_frames import FramesTestCase @@ -1696,6 +1696,24 @@ def test_server_fails_connection(self): server.fail(CloseCode.PROTOCOL_ERROR) self.assertTrue(server.close_expected()) + def test_client_is_connecting(self): + client = Protocol(CLIENT, state=CONNECTING) + 
self.assertFalse(client.close_expected()) + + def test_server_is_connecting(self): + server = Protocol(SERVER, state=CONNECTING) + self.assertFalse(server.close_expected()) + + def test_client_failed_connecting(self): + client = Protocol(CLIENT, state=CONNECTING) + client.send_eof() + self.assertTrue(client.close_expected()) + + def test_server_failed_connecting(self): + server = Protocol(SERVER, state=CONNECTING) + server.send_eof() + self.assertTrue(server.close_expected()) + class ConnectionClosedTests(ProtocolTestCase): """ From 76f6f573e2ecb279230c2bf56c07bf4d4f717147 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Nov 2024 09:11:51 +0100 Subject: [PATCH 32/51] Factor out backport of assertNoLogs. Fix previous commit on Python 3.9. --- tests/asyncio/test_connection.py | 25 ++++--------------------- tests/asyncio/test_server.py | 3 ++- tests/legacy/test_protocol.py | 12 ++++++------ tests/legacy/utils.py | 23 +++-------------------- tests/utils.py | 26 ++++++++++++++++++++++++++ 5 files changed, 41 insertions(+), 48 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index c98765d80..d61798afb 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -19,7 +19,7 @@ from websockets.protocol import CLIENT, SERVER, Protocol, State from ..protocol import RecordingProtocol -from ..utils import MS +from ..utils import MS, AssertNoLogsMixin from .connection import InterceptingConnection from .utils import alist @@ -28,7 +28,7 @@ # All tests run on the client side and the server side to validate this. -class ClientConnectionTests(unittest.IsolatedAsyncioTestCase): +class ClientConnectionTests(AssertNoLogsMixin, unittest.IsolatedAsyncioTestCase): LOCAL = CLIENT REMOTE = SERVER @@ -48,23 +48,6 @@ async def asyncTearDown(self): await self.remote_connection.close() await self.connection.close() - if sys.version_info[:2] < (3, 10): # pragma: no cover - - @contextlib.contextmanager - def assertNoLogs(self, logger=None, level=None): - """ - No message is logged on the given logger with at least the given level. - - """ - with self.assertLogs(logger, level) as logs: - # We want to test that no log message is emitted - # but assertLogs expects at least one log message. - logging.getLogger(logger).log(level, "dummy") - yield - - level_name = logging.getLevelName(level) - self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) - # Test helpers built upon RecordingProtocol and InterceptingConnection. 
async def assertFrameSent(self, frame): @@ -1277,7 +1260,7 @@ async def test_broadcast_skips_closed_connection(self): await self.connection.close() await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.WARNING): broadcast([self.connection], "😀") await self.assertNoFrameSent() @@ -1288,7 +1271,7 @@ async def test_broadcast_skips_closing_connection(self): await asyncio.sleep(0) await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.WARNING): broadcast([self.connection], "😀") await self.assertNoFrameSent() diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index c817f5ef6..3e289e592 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -21,6 +21,7 @@ CLIENT_CONTEXT, MS, SERVER_CONTEXT, + AssertNoLogsMixin, temp_unix_socket_path, ) from .server import ( @@ -32,7 +33,7 @@ ) -class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): +class ServerTests(EvalShellMixin, AssertNoLogsMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives connection from client and the handshake succeeds.""" async with serve(*args) as server: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index de2a320b5..be2910a8f 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -938,7 +938,7 @@ def test_answer_ping_does_not_crash_if_connection_closing(self): self.receive_frame(Frame(True, OP_PING, b"test")) self.run_loop_once() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): self.loop.run_until_complete(self.protocol.close()) self.loop.run_until_complete(close_task) # cleanup @@ -951,7 +951,7 @@ def test_answer_ping_does_not_crash_if_connection_closed(self): self.receive_eof() self.run_loop_once() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): self.loop.run_until_complete(self.protocol.close()) def test_ignore_pong(self): @@ -1028,7 +1028,7 @@ def test_acknowledge_aborted_ping(self): pong_waiter.result() # transfer_data doesn't crash, which would be logged. - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): # Unclog incoming queue. 
self.loop.run_until_complete(self.protocol.recv()) self.loop.run_until_complete(self.protocol.recv()) @@ -1375,7 +1375,7 @@ def test_remote_close_and_connection_lost(self): self.receive_eof() self.run_loop_once() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") @@ -1500,14 +1500,14 @@ def test_broadcast_two_clients(self): def test_broadcast_skips_closed_connection(self): self.close_connection() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): broadcast([self.protocol], "café") self.assertNoFrameSent() def test_broadcast_skips_closing_connection(self): close_task = self.half_close_connection_local() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): broadcast([self.protocol], "café") self.assertNoFrameSent() diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 5b56050d5..1f79bb600 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -1,12 +1,12 @@ import asyncio -import contextlib import functools -import logging import sys import unittest +from ..utils import AssertNoLogsMixin -class AsyncioTestCase(unittest.TestCase): + +class AsyncioTestCase(AssertNoLogsMixin, unittest.TestCase): """ Base class for tests that sets up an isolated event loop for each test. @@ -56,23 +56,6 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() - if sys.version_info[:2] < (3, 10): # pragma: no cover - - @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): - """ - No message is logged on the given logger with at least the given level. - - """ - with self.assertLogs(logger, level) as logs: - # We want to test that no log message is emitted - # but assertLogs expects at least one log message. - logging.getLogger(logger).log(level, "dummy") - yield - - level_name = logging.getLevelName(level) - self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) - def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): """ Check recorded deprecation warnings match a list of expected messages. diff --git a/tests/utils.py b/tests/utils.py index 639fb7fe5..77d020726 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,9 +1,11 @@ import contextlib import email.utils +import logging import os import pathlib import platform import ssl +import sys import tempfile import time import unittest @@ -113,6 +115,30 @@ def assertDeprecationWarning(self, message): self.assertEqual(str(warning.message), message) +class AssertNoLogsMixin: + """ + Backport of assertNoLogs for Python 3.9. + + """ + + if sys.version_info[:2] < (3, 10): # pragma: no cover + + @contextlib.contextmanager + def assertNoLogs(self, logger=None, level=None): + """ + No message is logged on the given logger with at least the given level. + + """ + with self.assertLogs(logger, level) as logs: + # We want to test that no log message is emitted + # but assertLogs expects at least one log message. 
+ logging.getLogger(logger).log(level, "dummy") + yield + + level_name = logging.getLevelName(level) + self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + + @contextlib.contextmanager def temp_unix_socket_path(): with tempfile.TemporaryDirectory() as temp_dir: From 0a5a79c224c9be97f79909f7218fd1da7b2acabb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Nov 2024 09:47:51 +0100 Subject: [PATCH 33/51] Clean up prefixes of debug log messages. All debug messages and only debug messages should have them. --- docs/topics/logging.rst | 8 ++++++-- src/websockets/asyncio/client.py | 2 +- src/websockets/asyncio/connection.py | 6 +++--- src/websockets/legacy/client.py | 4 ++-- src/websockets/legacy/protocol.py | 8 ++++---- src/websockets/sync/connection.py | 15 ++++++++++++--- tests/legacy/test_client_server.py | 4 ++-- 7 files changed, 30 insertions(+), 17 deletions(-) diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index be5678455..ae71be265 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -218,7 +218,10 @@ Here's what websockets logs at each level. ``ERROR`` ......... -* Exceptions raised by connection handler coroutines in servers +* Exceptions raised by your code in servers + * connection handler coroutines + * ``select_subprotocol`` callbacks + * ``process_request`` and ``process_response`` callbacks * Exceptions resulting from bugs in websockets ``WARNING`` @@ -250,4 +253,5 @@ Debug messages have cute prefixes that make logs easier to scan: * ``=`` - set connection state * ``x`` - shut down connection * ``%`` - manage pings and pongs -* ``!`` - handle errors and timeouts +* ``-`` - timeout +* ``!`` - error, with a traceback diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index d276ac171..302d0b94d 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -521,7 +521,7 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: delays = backoff() delay = next(delays) self.logger.info( - "! connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds", delay, exc_info=True, ) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 5545632d6..c4961884c 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -785,7 +785,7 @@ async def keepalive(self) -> None: self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: - self.logger.debug("! timed out waiting for keepalive pong") + self.logger.debug("- timed out waiting for keepalive pong") async with self.send_context(): self.protocol.fail( CloseCode.INTERNAL_ERROR, @@ -866,7 +866,7 @@ async def send_context( await self.drain() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug("! error while sending data", exc_info=True) # While the only expected exception here is OSError, # other exceptions would be treated identically. wait_for_close = False @@ -1042,7 +1042,7 @@ def data_received(self, data: bytes) -> None: self.send_data() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug("! 
error while sending data", exc_info=True) self.set_recv_exc(exc) if self.protocol.close_expected(): diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 116445e25..a2dc0250f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -603,14 +603,14 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: if backoff_delay == self.BACKOFF_MIN: initial_delay = random.random() * self.BACKOFF_INITIAL self.logger.info( - "! connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds", initial_delay, exc_info=True, ) await asyncio.sleep(initial_delay) else: self.logger.info( - "! connect failed again; retrying in %d seconds", + "connect failed again; retrying in %d seconds", int(backoff_delay), exc_info=True, ) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index cedde6200..bd998dfd1 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1246,7 +1246,7 @@ async def keepalive_ping(self) -> None: self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: - self.logger.debug("! timed out waiting for keepalive pong") + self.logger.debug("- timed out waiting for keepalive pong") self.fail_connection( CloseCode.INTERNAL_ERROR, "keepalive ping timeout", @@ -1288,7 +1288,7 @@ async def close_connection(self) -> None: if await self.wait_for_connection_lost(): return if self.debug: - self.logger.debug("! timed out waiting for TCP close") + self.logger.debug("- timed out waiting for TCP close") # Half-close the TCP connection if possible (when there's no TLS). if self.transport.can_write_eof(): @@ -1306,7 +1306,7 @@ async def close_connection(self) -> None: if await self.wait_for_connection_lost(): return if self.debug: - self.logger.debug("! timed out waiting for TCP close") + self.logger.debug("- timed out waiting for TCP close") finally: # The try/finally ensures that the transport never remains open, @@ -1332,7 +1332,7 @@ async def close_transport(self) -> None: if await self.wait_for_connection_lost(): return if self.debug: - self.logger.debug("! timed out waiting for TCP close") + self.logger.debug("- timed out waiting for TCP close") # Abort the TCP connection. Buffers are discarded. if self.debug: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 8d1dbcf58..77f803c9b 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -640,7 +640,10 @@ def recv_events(self) -> None: data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: - self.logger.debug("error while receiving data", exc_info=True) + self.logger.debug( + "! error while receiving data", + exc_info=True, + ) # When the closing handshake is initiated by our side, # recv() may block until send_context() closes the socket. # In that case, send_context() already set recv_exc. @@ -665,7 +668,10 @@ def recv_events(self) -> None: self.send_data() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug( + "! error while sending data", + exc_info=True, + ) # Similarly to the above, avoid overriding an exception # set by send_context(), in case of a race condition # i.e. 
send_context() closes the socket after recv() @@ -783,7 +789,10 @@ def send_context( self.send_data() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug( + "! error while sending data", + exc_info=True, + ) # While the only expected exception here is OSError, # other exceptions would be treated identically. wait_for_close = False diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 502ab68e7..375d47e29 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1615,12 +1615,12 @@ async def run_client(): [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed; reconnecting in X seconds", + "connect failed; reconnecting in X seconds", ] + [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed again; retrying in X seconds", + "connect failed again; retrying in X seconds", ] * ((len(logs.records) - 8) // 3) + [ From 9b3595d8d4d0573e00209f9e920f6a6fab981fa9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 4 Nov 2024 22:58:18 +0100 Subject: [PATCH 34/51] Remove stack traces from INFO and WARNING logs. Fix #1501. --- src/websockets/asyncio/client.py | 6 ++++-- src/websockets/asyncio/connection.py | 9 +++++++-- src/websockets/legacy/client.py | 13 ++++++++----- src/websockets/legacy/protocol.py | 9 +++++++-- tests/asyncio/test_connection.py | 5 ++++- tests/legacy/test_client_server.py | 8 ++++++-- tests/legacy/test_protocol.py | 4 ++-- 7 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 302d0b94d..b3d50c12e 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -3,6 +3,7 @@ import asyncio import logging import os +import traceback import urllib.parse from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType @@ -521,9 +522,10 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: delays = backoff() delay = next(delays) self.logger.info( - "connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds: %s", delay, - exc_info=True, + # Remove first argument when dropping Python 3.9. + traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(delay) continue diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index c4961884c..186846ef3 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -7,6 +7,7 @@ import random import struct import sys +import traceback import uuid from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping from types import TracebackType @@ -1180,8 +1181,12 @@ def broadcast( exceptions.append(exception) else: connection.logger.warning( - "skipped broadcast: failed to write message", - exc_info=True, + "skipped broadcast: failed to write message: %s", + traceback.format_exception_only( + # Remove first argument when dropping Python 3.9. 
+ type(write_exception), + write_exception, + )[0].strip(), ) if raise_exceptions and exceptions: diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index a2dc0250f..555069e8c 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -5,6 +5,7 @@ import logging import os import random +import traceback import urllib.parse import warnings from collections.abc import AsyncIterator, Generator, Sequence @@ -597,22 +598,24 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: try: async with self as protocol: yield protocol - except Exception: + except Exception as exc: # Add a random initial delay between 0 and 5 seconds. # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. if backoff_delay == self.BACKOFF_MIN: initial_delay = random.random() * self.BACKOFF_INITIAL self.logger.info( - "connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds: %s", initial_delay, - exc_info=True, + # Remove first argument when dropping Python 3.9. + traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(initial_delay) else: self.logger.info( - "connect failed again; retrying in %d seconds", + "connect failed again; retrying in %d seconds: %s", int(backoff_delay), - exc_info=True, + # Remove first argument when dropping Python 3.9. + traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(int(backoff_delay)) # Increase delay with truncated exponential backoff. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index bd998dfd1..db126c01e 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -9,6 +9,7 @@ import struct import sys import time +import traceback import uuid import warnings from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping @@ -1624,8 +1625,12 @@ def broadcast( exceptions.append(exception) else: websocket.logger.warning( - "skipped broadcast: failed to write message", - exc_info=True, + "skipped broadcast: failed to write message: %s", + traceback.format_exception_only( + # Remove first argument when dropping Python 3.9. 
+ type(write_exception), + write_exception, + )[0].strip(), ) if raise_exceptions and exceptions: diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index d61798afb..902b3b847 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1337,7 +1337,10 @@ async def test_broadcast_skips_connection_failing_to_send(self): self.assertEqual( [record.getMessage() for record in logs.records], - ["skipped broadcast: failed to write message"], + [ + "skipped broadcast: failed to write message: " + "RuntimeError: Cannot call write() after write_eof()" + ], ) @unittest.skipIf( diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 375d47e29..c13c6c92e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1607,6 +1607,10 @@ async def run_client(): ], ) # Iteration 3 + exc = ( + "websockets.legacy.exceptions.InvalidStatusCode: " + "server rejected WebSocket connection: HTTP 503" + ) self.assertEqual( [ re.sub(r"[0-9\.]+ seconds", "X seconds", record.getMessage()) @@ -1615,12 +1619,12 @@ async def run_client(): [ "connection rejected (503 Service Unavailable)", "connection closed", - "connect failed; reconnecting in X seconds", + f"connect failed; reconnecting in X seconds: {exc}", ] + [ "connection rejected (503 Service Unavailable)", "connection closed", - "connect failed again; retrying in X seconds", + f"connect failed again; retrying in X seconds: {exc}", ] * ((len(logs.records) - 8) // 3) + [ diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index be2910a8f..d30198934 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1547,14 +1547,14 @@ def test_broadcast_reports_connection_sending_fragmented_text(self): def test_broadcast_skips_connection_failing_to_send(self): # Configure mock to raise an exception when writing to the network. - self.protocol.transport.write.side_effect = RuntimeError + self.protocol.transport.write.side_effect = RuntimeError("BOOM") with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.protocol], "café") self.assertEqual( [record.getMessage() for record in logs.records], - ["skipped broadcast: failed to write message"], + ["skipped broadcast: failed to write message: RuntimeError: BOOM"], ) @unittest.skipIf( From 5f34e2741e94ac21ef99e3dc212aa152d77d1a37 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 08:55:26 +0100 Subject: [PATCH 35/51] Fix remaining instances of shortcut imports. 
--- docs/topics/broadcast.rst | 2 +- docs/topics/keepalive.rst | 2 +- experiments/broadcast/server.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index c9699feb2..66b0819b2 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -83,7 +83,7 @@ to:: Here's a coroutine that broadcasts a message to all clients:: - from websockets import ConnectionClosed + from websockets.exceptions import ConnectionClosed async def broadcast(message): for websocket in CLIENTS.copy(): diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 003087fad..4897de2ba 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -112,7 +112,7 @@ You can run a task in the background to send keepalive messages: import itertools import json - from websockets import ConnectionClosed + from websockets.exceptions import ConnectionClosed async def keepalive(websocket, ping_interval=30): for ping in itertools.count(): diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index d5b50bd71..eca55357e 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -6,8 +6,8 @@ import sys import time -from websockets import ConnectionClosed from websockets.asyncio.server import broadcast, serve +from websockets.exceptions import ConnectionClosed CLIENTS = set() From c57bcb743fe1128f6d28e2eaebbdd202eb3ee2eb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 09:36:42 +0100 Subject: [PATCH 36/51] Standardize links to RFC. There were 70 links to https://datatracker.ietf.org/doc/html/ vs. 15 links https://www.rfc-editor.org/rfc/. Also :rfc:`....` links to https://datatracker.ietf.org/doc/html/ by default. While https://www.ietf.org/process/rfcs/#introduction says: > The RFC Editor website is the authoritative site for RFCs. the IETF Datatracker looks a bit better and has more information. --- docs/faq/common.rst | 4 ++-- docs/howto/extensions.rst | 2 +- docs/project/changelog.rst | 2 +- docs/reference/extensions.rst | 2 +- docs/topics/design.rst | 10 +++++----- docs/topics/keepalive.rst | 6 +++--- docs/topics/logging.rst | 4 ++-- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 0dc4a3aeb..ba7a95932 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -131,8 +131,8 @@ How do I respond to pings? If you are referring to Ping_ and Pong_ frames defined in the WebSocket protocol, don't bother, because websockets handles them for you. -.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 -.. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 +.. _Ping: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.2 +.. _Pong: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.3 If you are connecting to a server that defines its own heartbeat at the application level, then you need to build that logic into your application. diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst index 3c8a7d72a..c4e9da626 100644 --- a/docs/howto/extensions.rst +++ b/docs/howto/extensions.rst @@ -7,7 +7,7 @@ During the opening handshake, WebSocket clients and servers negotiate which extensions_ will be used with which parameters. Then each frame is processed by extensions before being sent or after being received. -.. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 +.. 
_extensions: https://datatracker.ietf.org/doc/html/rfc6455.html#section-9 As a consequence, writing an extension requires implementing several classes: diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 71b2a6960..1056bc980 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -1487,7 +1487,7 @@ New features * Added support for providing and checking Origin_. -.. _Origin: https://www.rfc-editor.org/rfc/rfc6455.html#section-10.2 +.. _Origin: https://datatracker.ietf.org/doc/html/rfc6455.html#section-10.2 .. _2.0: diff --git a/docs/reference/extensions.rst b/docs/reference/extensions.rst index a70f1b1e5..f3da464a5 100644 --- a/docs/reference/extensions.rst +++ b/docs/reference/extensions.rst @@ -8,7 +8,7 @@ The WebSocket protocol supports extensions_. At the time of writing, there's only one `registered extension`_ with a public specification, WebSocket Per-Message Deflate. -.. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 +.. _extensions: https://datatracker.ietf.org/doc/html/rfc6455.html#section-9 .. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name Per-Message Deflate diff --git a/docs/topics/design.rst b/docs/topics/design.rst index b73ace517..bc14bd332 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -173,16 +173,16 @@ differences between a server and a client: - `closing the TCP connection`_: the server closes the connection immediately; the client waits for the server to do it. -.. _client-to-server masking: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.3 -.. _closing the TCP connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.1 +.. _client-to-server masking: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.3 +.. _closing the TCP connection: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.1 These differences are so minor that all the logic for `data framing`_, for `sending and receiving data`_ and for `closing the connection`_ is implemented in the same class, :class:`~protocol.WebSocketCommonProtocol`. -.. _data framing: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 -.. _sending and receiving data: https://www.rfc-editor.org/rfc/rfc6455.html#section-6 -.. _closing the connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-7 +.. _data framing: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5 +.. _sending and receiving data: https://datatracker.ietf.org/doc/html/rfc6455.html#section-6 +.. _closing the connection: https://datatracker.ietf.org/doc/html/rfc6455.html#section-7 The :attr:`~protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 4897de2ba..a0467ced2 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -33,8 +33,8 @@ Keepalive in websockets To avoid these problems, websockets runs a keepalive and heartbeat mechanism based on WebSocket Ping_ and Pong_ frames, which are designed for this purpose. -.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 -.. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 +.. _Ping: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.2 +.. _Pong: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.3 It sends a Ping frame every 20 seconds. It expects a Pong frame in return within 20 seconds. 
Else, it considers the connection broken and terminates it. @@ -98,7 +98,7 @@ at regular intervals. Usually they expect Text_ frames rather than Ping_ frames, meaning that you must send them with :attr:`~asyncio.connection.Connection.send` rather than :attr:`~asyncio.connection.Connection.ping`. -.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 +.. _Text: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.6 In websockets, such keepalive mechanisms are considered as application-level because they rely on data frames. That's unlike the protocol-level keepalive diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index ae71be265..fff33a024 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -35,8 +35,8 @@ Instead, when running as a server, websockets logs one event when a `connection is established`_ and another event when a `connection is closed`_. -.. _connection is established: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 -.. _connection is closed: https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.4 +.. _connection is established: https://datatracker.ietf.org/doc/html/rfc6455.html#section-4 +.. _connection is closed: https://datatracker.ietf.org/doc/html/rfc6455.html#section-7.1.4 By default, websockets doesn't log an event for every message. That would be excessive for many applications exchanging small messages at a fast rate. If From 178c88447c2754d2d2b0c02472868cbaef7cec52 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 13:56:49 +0100 Subject: [PATCH 37/51] Complete and polish changelog for 14.0. --- docs/project/changelog.rst | 41 +++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 1056bc980..161f71c7f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,7 +41,7 @@ Backwards-incompatible changes websockets 13.1 is the last version supporting Python 3.8. .. admonition:: The new :mod:`asyncio` implementation is now the default. - :class: caution + :class: danger The following aliases in the ``websockets`` package were switched to the new :mod:`asyncio` implementation:: @@ -64,8 +64,8 @@ Backwards-incompatible changes The :doc:`upgrade guide <../howto/upgrade>` provides complete instructions to migrate your application. - Aliases for deprecated API were removed from ``__all__``. As a consequence, - they cannot be imported e.g. with ``from websockets import *`` anymore. + Aliases for deprecated API were removed from ``websockets.__all__``, meaning + that they cannot be imported with ``from websockets import *`` anymore. .. admonition:: Several API raise :exc:`ValueError` instead of :exc:`TypeError` on invalid arguments. @@ -83,22 +83,16 @@ Backwards-incompatible changes :class: note In addition to :class:`bytes`, it may be a :class:`bytearray` or a - :class:`memoryview`. - - If you wrote an :class:`extension ` that relies on - methods not provided by these new types, you may need to update your code. + :class:`memoryview`. If you wrote an :class:`~extensions.Extension` that + relies on methods not provided by these types, you must update your code. .. admonition:: The signature of :exc:`~exceptions.PayloadTooBig` changed. 
:class: note If you wrote an extension that raises :exc:`~exceptions.PayloadTooBig` in - :meth:`~extensions.Extension.decode`, for example, you must replace:: - - PayloadTooBig(f"over size limit ({size} > {max_size} bytes)") - - with:: - - PayloadTooBig(size, max_size) + :meth:`~extensions.Extension.decode`, for example, you must replace + ``PayloadTooBig(f"over size limit ({size} > {max_size} bytes)")`` with + ``PayloadTooBig(size, max_size)``. New features ............ @@ -106,8 +100,8 @@ New features * Added an option to receive text frames as :class:`bytes`, without decoding, in the :mod:`threading` implementation; also binary frames as :class:`str`. -* Added an option to send :class:`bytes` as a text frame in the :mod:`asyncio` - and :mod:`threading` implementations, as well as :class:`str` a binary frame. +* Added an option to send :class:`bytes` in a text frame in the :mod:`asyncio` + and :mod:`threading` implementations; also :class:`str` in a binary frame. Improvements ............ @@ -118,6 +112,21 @@ Improvements * Errors when a fragmented message is too large are clearer. +* Log messages at the :data:`~logging.WARNING` and :data:`~logging.INFO` levels + no longer include stack traces. + +Bug fixes +......... + +* Clients no longer crash when the server rejects the opening handshake and the + HTTP response doesn't Include a ``Content-Length`` header. + +* Returning an HTTP response in ``process_request`` or ``process_response`` + doesn't generate a log message at the :data:`~logging.ERROR` level anymore. + +* Connections are closed with code 1007 (invalid data) when receiving invalid + UTF-8 in a text frame. + .. _13.1: 13.1 From f0d20aafab027e9b99460b193dcb709872b219a5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 13:58:10 +0100 Subject: [PATCH 38/51] Release version 14.0. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 161f71c7f..5aa58a09b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,7 +30,7 @@ notice. 14.0 ---- -*In development* +*November 9, 2024* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 34fc2eaef..7c64f566a 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "14.0" From b9d74504eaeedd35cb7dc2651a18420f10e3828d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 14:01:07 +0100 Subject: [PATCH 39/51] Start version 14.1. --- docs/project/changelog.rst | 7 +++++++ src/websockets/version.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5aa58a09b..9e6a9d113 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,13 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _14.1: + +14.1 +---- + +*In development* + .. _14.0: 14.0 diff --git a/src/websockets/version.py b/src/websockets/version.py index 7c64f566a..48d2edaea 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. 
# After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "14.0" +tag = version = commit = "14.1" if not released: # pragma: no cover From e9fc77da927793d05072163d61e137dd35f97e4d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 14:01:22 +0100 Subject: [PATCH 40/51] Add dates for deprecations. --- src/websockets/__init__.py | 2 +- src/websockets/auth.py | 2 +- src/websockets/client.py | 2 +- src/websockets/exceptions.py | 8 ++++---- src/websockets/legacy/__init__.py | 2 +- src/websockets/server.py | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 531ce49f7..0c7e9b4c6 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -151,7 +151,7 @@ "handshake": ".legacy", "parse_uri": ".uri", "WebSocketURI": ".uri", - # deprecated in 14.0 + # deprecated in 14.0 - 2024-11-09 # .legacy.auth "BasicAuthWebSocketServerProtocol": ".legacy.auth", "basic_auth_protocol_factory": ".legacy.auth", diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 98e62af3c..15b70a372 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -10,7 +10,7 @@ from .legacy.auth import __all__ # noqa: F401 -warnings.warn( # deprecated in 14.0 +warnings.warn( # deprecated in 14.0 - 2024-11-09 "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " "for upgrade instructions", diff --git a/src/websockets/client.py b/src/websockets/client.py index 8b66900a8..f6cbc9f65 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -392,7 +392,7 @@ def backoff( lazy_import( globals(), deprecated_aliases={ - # deprecated in 14.0 + # deprecated in 14.0 - 2024-11-09 "WebSocketClientProtocol": ".legacy.client", "connect": ".legacy.client", "unix_connect": ".legacy.client", diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index be3d1ca5f..f3e751971 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -113,7 +113,7 @@ def __str__(self) -> str: @property def code(self) -> int: - warnings.warn( # deprecated in 13.1 + warnings.warn( # deprecated in 13.1 - 2024-09-21 "ConnectionClosed.code is deprecated; " "use Protocol.close_code or ConnectionClosed.rcvd.code", DeprecationWarning, @@ -124,7 +124,7 @@ def code(self) -> int: @property def reason(self) -> str: - warnings.warn( # deprecated in 13.1 + warnings.warn( # deprecated in 13.1 - 2024-09-21 "ConnectionClosed.reason is deprecated; " "use Protocol.close_reason or ConnectionClosed.rcvd.reason", DeprecationWarning, @@ -343,7 +343,7 @@ def __init__( if isinstance(size_or_message, str): assert max_size is None assert cur_size is None - warnings.warn( # deprecated in 14.0 + warnings.warn( # deprecated in 14.0 - 2024-11-09 "PayloadTooBig(message) is deprecated; " "change to PayloadTooBig(size, max_size)", DeprecationWarning, @@ -408,7 +408,7 @@ class ConcurrencyError(WebSocketException, RuntimeError): lazy_import( globals(), deprecated_aliases={ - # deprecated in 14.0 + # deprecated in 14.0 - 2024-11-09 "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", diff --git a/src/websockets/legacy/__init__.py b/src/websockets/legacy/__init__.py index 84f870f3a..ad9aa2506 100644 --- a/src/websockets/legacy/__init__.py +++ b/src/websockets/legacy/__init__.py @@ -3,7 +3,7 @@ import 
warnings -warnings.warn( # deprecated in 14.0 +warnings.warn( # deprecated in 14.0 - 2024-11-09 "websockets.legacy is deprecated; " "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " "for upgrade instructions", diff --git a/src/websockets/server.py b/src/websockets/server.py index e3fdcc646..607cc306e 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -580,7 +580,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: lazy_import( globals(), deprecated_aliases={ - # deprecated in 14.0 + # deprecated in 14.0 - 2024-11-09 "WebSocketServer": ".legacy.server", "WebSocketServerProtocol": ".legacy.server", "broadcast": ".legacy.server", From 083bcacd485a12dc1f9c6b98123bb92fafa4a9cf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 18:50:43 +0100 Subject: [PATCH 41/51] Import ConnectionClosed from websockets.exceptions. Fix #1539. --- docs/faq/client.rst | 3 ++- docs/faq/server.rst | 2 +- docs/intro/tutorial1.rst | 4 +++- src/websockets/asyncio/client.py | 2 +- src/websockets/legacy/client.py | 2 +- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index 0a7aab6e2..cc9856a8b 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -84,11 +84,12 @@ How do I reconnect when the connection drops? Use :func:`~websockets.asyncio.client.connect` as an asynchronous iterator:: from websockets.asyncio.client import connect + from websockets.exceptions import ConnectionClosed async for websocket in connect(...): try: ... - except websockets.ConnectionClosed: + except ConnectionClosed: continue Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 63eb5ffc6..ce7e1962d 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -147,7 +147,7 @@ Then, call :meth:`~ServerConnection.send`:: async def message_user(user_id, message): websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected - await websocket.send(message) # may raise websockets.ConnectionClosed + await websocket.send(message) # may raise websockets.exceptions.ConnectionClosed Add error handling according to the behavior you want if the user disconnected before the message could be sent. diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index 6e91867c8..87074caee 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -271,11 +271,13 @@ spot real errors when you add functionality to the server. Catch it in the .. code-block:: python + from websockets.exceptions import ConnectionClosedOK + async def handler(websocket): while True: try: message = await websocket.recv() - except websockets.ConnectionClosedOK: + except ConnectionClosedOK: break print(message) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index b3d50c12e..74ae70f0d 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -182,7 +182,7 @@ class connect: async for websocket in connect(...): try: ... - except websockets.ConnectionClosed: + except websockets.exceptions.ConnectionClosed: continue If the connection fails with a transient error, it is retried with diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 555069e8c..a3856b470 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -349,7 +349,7 @@ class Connect: async for websocket in connect(...): try: ... 
- except websockets.ConnectionClosed: + except websockets.exceptions.ConnectionClosed: continue The connection is closed automatically after each iteration of the loop. From 86bf0c5afc4c88429231ac7ba5f857fd32a02d0e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 17:40:55 +0100 Subject: [PATCH 42/51] Remove unnecessary branch. --- src/websockets/asyncio/messages.py | 3 +-- tests/asyncio/test_messages.py | 8 -------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index b57c0ca4e..69636e3d0 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -43,8 +43,7 @@ def put(self, item: T) -> None: async def get(self) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: - if self.get_waiter is not None: - raise ConcurrencyError("get is already running") + assert self.get_waiter is None, "cannot call get() concurrently" self.get_waiter = self.loop.create_future() try: await self.get_waiter diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 2ff929d3a..5c9ac9445 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -37,14 +37,6 @@ async def test_get_then_put(self): item = await getter_task self.assertEqual(item, 42) - async def test_get_concurrently(self): - """get cannot be called concurrently.""" - getter_task = asyncio.create_task(self.queue.get()) - await asyncio.sleep(0) # let the task start - with self.assertRaises(ConcurrencyError): - await self.queue.get() - getter_task.cancel() - async def test_reset(self): """reset sets the content of the queue.""" self.queue.reset([42]) From f17c11ab6b46cfe6817cbab5d92ba2626d3e87d5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 17:42:16 +0100 Subject: [PATCH 43/51] Keep queued messages after abort. --- src/websockets/asyncio/messages.py | 2 -- tests/asyncio/test_messages.py | 7 ------- 2 files changed, 9 deletions(-) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 69636e3d0..678a3c14e 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -62,8 +62,6 @@ def abort(self) -> None: """Close the queue, raising EOFError in get() if necessary.""" if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_exception(EOFError("stream of frames ended")) - # Clear the queue to avoid storing unnecessary data in memory. - self.queue.clear() class Assembler: diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 5c9ac9445..181ffd376 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -51,13 +51,6 @@ async def test_abort(self): with self.assertRaises(EOFError): await getter_task - async def test_abort_clears_queue(self): - """abort clears buffered data from the queue.""" - self.queue.put(42) - self.assertEqual(len(self.queue), 1) - self.queue.abort() - self.assertEqual(len(self.queue), 0) - class AssemblerTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): From bdfc8cf90301b528eebfbd0ef8a31e5a2fc7d5f7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 18:11:02 +0100 Subject: [PATCH 44/51] Uniformize comments. 
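The comments being aligned all describe the same ``get_in_progress`` guard in
the assemblers. As a rough sketch of the pattern they document (simplified and
hypothetical, not the actual ``Assembler`` code; frame handling is elided)::

    from websockets.exceptions import ConcurrencyError

    class SketchAssembler:
        def __init__(self) -> None:
            self.get_in_progress = False

        async def get(self):
            if self.get_in_progress:
                raise ConcurrencyError("get() or get_iter() is already running")
            self.get_in_progress = True
            # Locking with get_in_progress prevents concurrent execution
            # until get() fetches a complete message or is cancelled.
            try:
                ...  # read frames until a complete message is assembled
            finally:
                self.get_in_progress = False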
--- src/websockets/asyncio/messages.py | 8 ++++---- src/websockets/sync/messages.py | 11 +++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 678a3c14e..814a3c03c 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -141,8 +141,8 @@ async def get(self, decode: bool | None = None) -> Data: self.get_in_progress = True - # Locking with get_in_progress prevents concurrent execution until - # get() fetches a complete message or is cancelled. + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or is cancelled. try: # First frame @@ -208,8 +208,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: self.get_in_progress = True - # Locking with get_in_progress prevents concurrent execution until - # get_iter() fetches a complete message or is cancelled. + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or is cancelled. # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 17f8dce7e..ce08172b2 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -128,10 +128,11 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - - # Locking with get_in_progress ensures only one thread can get here. self.get_in_progress = True + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or times out. + try: deadline = Deadline(timeout) @@ -198,12 +199,10 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # Locking with get_in_progress prevents concurrent execution until - # get_iter() fetches a complete message or is cancelled. + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or times out. # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. From 303483412dc5b420d09c1421792b3f8b99c323e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 18:30:36 +0100 Subject: [PATCH 45/51] Support recv() after the connection is closed. Fix #1538. --- docs/project/changelog.rst | 7 ++++ src/websockets/asyncio/messages.py | 20 ++++-------- src/websockets/sync/messages.py | 27 ++++++++-------- tests/asyncio/test_connection.py | 8 ++--- tests/asyncio/test_messages.py | 52 ++++++++++++++++++++++++++++++ tests/sync/test_connection.py | 19 ++++------- tests/sync/test_messages.py | 52 ++++++++++++++++++++++++++++++ 7 files changed, 142 insertions(+), 43 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 9e6a9d113..074a81c85 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,6 +34,13 @@ notice. .. _14.0: +Bug fixes +......... 
+ +* Once the connection is closed, messages previously received and buffered can + be read in the :mod:`asyncio` and :mod:`threading` implementations, just like + in the legacy implementation. + 14.0 ---- diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 814a3c03c..14ea7bf90 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -40,9 +40,11 @@ def put(self, item: T) -> None: if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_result(None) - async def get(self) -> T: + async def get(self, block: bool = True) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: + if not block: + raise EOFError("stream of frames ended") assert self.get_waiter is None, "cannot call get() concurrently" self.get_waiter = self.loop.create_future() try: @@ -133,12 +135,8 @@ async def get(self, decode: bool | None = None) -> Data: :meth:`get_iter` concurrently. """ - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution @@ -146,7 +144,7 @@ async def get(self, decode: bool | None = None) -> Data: try: # First frame - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY if decode is None: @@ -156,7 +154,7 @@ async def get(self, decode: bool | None = None) -> Data: # Following frames, for fragmented messages while not frame.fin: try: - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) except asyncio.CancelledError: # Put frames already received back into the queue # so that future calls to get() can return them. @@ -200,12 +198,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: :meth:`get_iter` concurrently. """ - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution @@ -216,7 +210,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # First frame try: - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) except asyncio.CancelledError: self.get_in_progress = False raise @@ -236,7 +230,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # previous fragments — we're streaming them. Canceling get_iter() # here will leave the assembler in a stuck state. Future calls to # get() or get_iter() will raise ConcurrencyError. - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_CONT if decode: diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index ce08172b2..af8635f16 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -69,10 +69,16 @@ def __init__( def get_next_frame(self, timeout: float | None = None) -> Frame: # Helper to factor out the logic for getting the next frame from the # queue, while handling timeouts and reaching the end of the stream. 
- try: - frame = self.frames.get(timeout=timeout) - except queue.Empty: - raise TimeoutError(f"timed out in {timeout:.1f}s") from None + if self.closed: + try: + frame = self.frames.get(block=False) + except queue.Empty: + raise EOFError("stream of frames ended") from None + else: + try: + frame = self.frames.get(block=True, timeout=timeout) + except queue.Empty: + raise TimeoutError(f"timed out in {timeout:.1f}s") from None if frame is None: raise EOFError("stream of frames ended") return frame @@ -87,7 +93,7 @@ def reset_queue(self, frames: Iterable[Frame]) -> None: queued = [] try: while True: - queued.append(self.frames.get_nowait()) + queued.append(self.frames.get(block=False)) except queue.Empty: pass for frame in frames: @@ -123,9 +129,6 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ with self.mutex: - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -194,9 +197,6 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ with self.mutex: - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -288,5 +288,6 @@ def close(self) -> None: self.closed = True - # Unblock get() or get_iter(). - self.frames.put(None) + if self.get_in_progress: + # Unblock get() or get_iter(). + self.frames.put(None) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 902b3b847..b1c57c8ca 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -793,14 +793,12 @@ async def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - async def test_close_does_not_wait_for_recv(self): - # Closing the connection discards messages buffered in the assembler. - # This is allowed by the RFC: - # > However, there is no guarantee that the endpoint that has already - # > sent a Close frame will continue to process data. 
+ async def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" await self.remote_connection.send("😀") await self.connection.close() + self.assertEqual(await self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 181ffd376..566f71cea 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -395,6 +395,58 @@ async def test_get_iter_fails_after_close(self): async for _ in self.assembler.get_iter(): self.fail("no fragment expected") + async def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOF on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + async for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + async def test_put_fails_after_close(self): """put raises EOFError after close is called.""" self.assembler.close() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index abdfd3f78..408b9697a 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -543,17 +543,12 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - def test_close_does_not_wait_for_recv(self): - # Closing the connection discards messages buffered in the assembler. 
- # This is allowed by the RFC: - # > However, there is no guarantee that the endpoint that has already - # > sent a Close frame will continue to process data. + def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" self.remote_connection.send("😀") self.connection.close() - close_thread = threading.Thread(target=self.connection.close) - close_thread.start() - + self.assertEqual(self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -576,10 +571,10 @@ def test_close_idempotency(self): def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" - self.connection.close_timeout = 5 * MS + self.connection.close_timeout = 6 * MS def closer(): - with self.delay_frames_rcvd(3 * MS): + with self.delay_frames_rcvd(4 * MS): self.connection.close() close_thread = threading.Thread(target=closer) @@ -591,14 +586,14 @@ def closer(): # Connection isn't closed yet. with self.assertRaises(TimeoutError): - self.connection.recv(timeout=0) + self.connection.recv(timeout=MS) self.connection.close() self.assertNoFrameSent() # Connection is closed now. with self.assertRaises(ConnectionClosedOK): - self.connection.recv(timeout=0) + self.connection.recv(timeout=MS) close_thread.join() diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 02513894a..9ebe45088 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -374,6 +374,58 @@ def test_get_iter_fails_after_close(self): for _ in self.assembler.get_iter(): self.fail("no fragment expected") + def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = self.assembler.get() + self.assertEqual(message, "café") + + def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = self.assembler.get() + self.assertEqual(message, b"tea") + + def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOF on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.get() + + def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + 
self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + def test_put_fails_after_close(self): """put raises EOFError after close is called.""" self.assembler.close() From 9a2f39fc66fd2427f904e551b1cc5f3995b02217 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 21:13:39 +0100 Subject: [PATCH 46/51] Support max_queue=None like the legacy implementation. Fix #1540. --- docs/project/changelog.rst | 7 ++++ src/websockets/asyncio/client.py | 7 ++-- src/websockets/asyncio/connection.py | 4 +- src/websockets/asyncio/messages.py | 23 ++++++++--- src/websockets/asyncio/server.py | 7 ++-- src/websockets/sync/client.py | 7 ++-- src/websockets/sync/connection.py | 4 +- src/websockets/sync/messages.py | 25 +++++++++--- src/websockets/sync/server.py | 7 ++-- tests/asyncio/test_connection.py | 12 +++++- tests/asyncio/test_messages.py | 61 +++++++++++++++++++++++++--- tests/sync/test_connection.py | 17 +++++++- tests/sync/test_messages.py | 61 +++++++++++++++++++++++++--- 13 files changed, 200 insertions(+), 42 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 074a81c85..8e1ad81f0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,6 +34,13 @@ notice. .. _14.0: +Improvements +............ + +* Supported ``max_queue=None`` in the :mod:`asyncio` and :mod:`threading` + implementations for consistency with the legacy implementation, even though + this is never a good idea. + Bug fixes ......... diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 74ae70f0d..cdd9bfac6 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -60,7 +60,7 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ClientProtocol @@ -222,7 +222,8 @@ class connect: max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. write_limit: High-water mark of write buffer in bytes. It is passed to :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults to 32 KiB. 
You may pass a ``(high, low)`` tuple to set the @@ -283,7 +284,7 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 186846ef3..f1dcbada6 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -56,14 +56,14 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol = protocol self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout - if isinstance(max_queue, int): + if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) self.max_queue = max_queue if isinstance(write_limit, int): diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 14ea7bf90..e6d1d31cc 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -84,7 +84,7 @@ class Assembler: # coverage reports incorrectly: "line NN didn't jump to the function exit" def __init__( # pragma: no cover self, - high: int = 16, + high: int | None = None, low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, @@ -96,12 +96,15 @@ def __init__( # pragma: no cover # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. 
- if low is None: + if high is not None and low is None: low = high // 4 - if low < 0: - raise ValueError("low must be positive or equal to zero") - if high < low: - raise ValueError("high must be greater than or equal to low") + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") self.high, self.low = high, low self.pause = pause self.resume = resume @@ -256,6 +259,10 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled + if self.high is None: + return + # Check for "> high" to support high = 0 if len(self.frames) > self.high and not self.paused: self.paused = True @@ -263,6 +270,10 @@ def maybe_pause(self) -> None: def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled + if self.low is None: + return + # Check for "<= low" to support low = 0 if len(self.frames) <= self.low and self.paused: self.paused = False diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 15c9ba13e..fdb928004 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -71,7 +71,7 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ServerProtocol @@ -643,7 +643,8 @@ def handler(websocket): max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. write_limit: High-water mark of write buffer in bytes. It is passed to :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults to 32 KiB. You may pass a ``(high, low)`` tuple to set the @@ -713,7 +714,7 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 54d0aef68..9e6da7caf 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -55,7 +55,7 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -139,7 +139,7 @@ def connect( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -191,7 +191,8 @@ def connect( max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. 
The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 77f803c9b..be3381c8a 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -49,12 +49,12 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol self.close_timeout = close_timeout - if isinstance(max_queue, int): + if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) self.max_queue = max_queue diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index af8635f16..98490797f 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -33,7 +33,7 @@ class Assembler: def __init__( self, - high: int = 16, + high: int | None = None, low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, @@ -49,12 +49,15 @@ def __init__( # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. - if low is None: + if high is not None and low is None: low = high // 4 - if low < 0: - raise ValueError("low must be positive or equal to zero") - if high < low: - raise ValueError("high must be greater than or equal to low") + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") self.high, self.low = high, low self.pause = pause self.resume = resume @@ -260,7 +263,12 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled + if self.high is None: + return + assert self.mutex.locked() + # Check for "> high" to support high = 0 if self.frames.qsize() > self.high and not self.paused: self.paused = True @@ -268,7 +276,12 @@ def maybe_pause(self) -> None: def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled + if self.low is None: + return + assert self.mutex.locked() + # Check for "<= low" to support low = 0 if self.frames.qsize() <= self.low and self.paused: self.paused = False diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 8601ccef9..9506d6830 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -66,7 +66,7 @@ def __init__( protocol: ServerProtocol, *, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -356,7 +356,7 @@ def serve( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - 
max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -438,7 +438,8 @@ def handler(websocket): max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index b1c57c8ca..8dd0a0335 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1066,14 +1066,22 @@ async def test_close_timeout(self): self.assertEqual(connection.close_timeout, 42 * MS) async def test_max_queue(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water mark of frames buffer.""" connection = Connection(Protocol(self.LOCAL), max_queue=4) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) + async def test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + connection = Connection(Protocol(self.LOCAL), max_queue=None) + transport = Mock() + connection.connection_made(transport) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.low, None) + async def test_max_queue_tuple(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water and low-water marks of frames buffer.""" connection = Connection( Protocol(self.LOCAL), max_queue=(4, 2), diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 566f71cea..a90788d02 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -153,7 +153,7 @@ async def test_get_decoded_binary_message(self): self.assertEqual(message, "tea") async def test_get_resumes_reading(self): - """get resumes reading when queue goes below the high-water mark.""" + """get resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) @@ -170,6 +170,19 @@ async def test_get_resumes_reading(self): await self.assembler.get() self.resume.assert_called_once_with() + async def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + await self.assembler.get() + await self.assembler.get() + await self.assembler.get() + + self.resume.assert_not_called() + async def test_cancel_get_before_first_frame(self): """get can be canceled safely before reading the first frame.""" getter_task = asyncio.create_task(self.assembler.get()) @@ -302,7 +315,7 @@ async def test_get_iter_decoded_binary_message(self): self.assertEqual(fragments, ["t", "e", "a"]) async def test_get_iter_resumes_reading(self): - """get_iter resumes reading when queue goes 
below the high-water mark.""" + """get_iter resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) @@ -321,6 +334,20 @@ async def test_get_iter_resumes_reading(self): await anext(iterator) self.resume.assert_called_once_with() + async def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = aiter(self.assembler.get_iter()) + await anext(iterator) + await anext(iterator) + await anext(iterator) + + self.resume.assert_not_called() + async def test_cancel_get_iter_before_first_frame(self): """get_iter can be canceled safely before reading the first frame.""" getter_task = asyncio.create_task(alist(self.assembler.get_iter())) @@ -367,6 +394,17 @@ async def test_put_pauses_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() + async def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.pause.assert_not_called() + # Test termination async def test_get_fails_when_interrupted_by_close(self): @@ -495,16 +533,29 @@ async def test_get_iter_fails_when_get_iter_is_running(self): # Test setting limits async def test_set_high_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water and low-water marks.""" assembler = Assembler(high=10) self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + async def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) - async def test_set_high_and_low_water_mark(self): - """high sets the high-water mark and low-water mark.""" + async def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) + async def test_unset_high_and_low_water_marks(self): + """High-water and low-water marks are unset.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + async def test_set_invalid_high_water_mark(self): """high must be a non-negative integer.""" with self.assertRaises(ValueError): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 408b9697a..6be490a5d 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -749,7 +749,7 @@ def test_close_timeout(self): self.assertEqual(connection.close_timeout, 42 * MS) def test_max_queue(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water mark of frames buffer.""" socket_, remote_socket = socket.socketpair() self.addCleanup(socket_.close) self.addCleanup(remote_socket.close) @@ -760,8 +760,21 @@ def test_max_queue(self): ) self.assertEqual(connection.recv_messages.high, 4) + def 
test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=None, + ) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.low, None) + def test_max_queue_tuple(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water and low-water marks of frames buffer.""" socket_, remote_socket = socket.socketpair() self.addCleanup(socket_.close) self.addCleanup(remote_socket.close) diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 9ebe45088..d22693102 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -145,7 +145,7 @@ def test_get_decoded_binary_message(self): self.assertEqual(message, "tea") def test_get_resumes_reading(self): - """get resumes reading when queue goes below the high-water mark.""" + """get resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) @@ -162,6 +162,19 @@ def test_get_resumes_reading(self): self.assembler.get() self.resume.assert_called_once_with() + def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + self.assembler.get() + self.assembler.get() + self.assembler.get() + + self.resume.assert_not_called() + def test_get_timeout_before_first_frame(self): """get times out before reading the first frame.""" with self.assertRaises(TimeoutError): @@ -300,7 +313,7 @@ def test_get_iter_decoded_binary_message(self): self.assertEqual(fragments, ["t", "e", "a"]) def test_get_iter_resumes_reading(self): - """get_iter resumes reading when queue goes below the high-water mark.""" + """get_iter resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) @@ -319,6 +332,20 @@ def test_get_iter_resumes_reading(self): next(iterator) self.resume.assert_called_once_with() + def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = self.assembler.get_iter() + next(iterator) + next(iterator) + next(iterator) + + self.resume.assert_not_called() + # Test put def test_put_pauses_reading(self): @@ -336,6 +363,17 @@ def test_put_pauses_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() + def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.pause.assert_not_called() + # Test termination 
def test_get_fails_when_interrupted_by_close(self): @@ -470,16 +508,29 @@ def test_get_iter_fails_when_get_iter_is_running(self): # Test setting limits def test_set_high_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water and low-water marks.""" assembler = Assembler(high=10) self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) - def test_set_high_and_low_water_mark(self): - """high sets the high-water mark and low-water mark.""" + def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) + def test_unset_high_and_low_water_marks(self): + """High-water and low-water marks are unset.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + def test_set_invalid_high_water_mark(self): """high must be a non-negative integer.""" with self.assertRaises(ValueError): From de2e7fb8b7eaca56d633ae7d2ffffdbb212048a1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 21:40:26 +0100 Subject: [PATCH 47/51] Add close_code and close_reason to new implementations. Also add state to threading implementation. Fix #1537. --- docs/reference/asyncio/client.rst | 7 ++++++ docs/reference/asyncio/common.rst | 7 ++++++ docs/reference/asyncio/server.rst | 7 ++++++ docs/reference/sync/client.rst | 9 +++++++ docs/reference/sync/common.rst | 9 +++++++ docs/reference/sync/server.rst | 9 +++++++ src/websockets/asyncio/connection.py | 24 ++++++++++++++++++ src/websockets/protocol.py | 21 ++++++++++------ src/websockets/sync/connection.py | 37 ++++++++++++++++++++++++++++ tests/asyncio/test_connection.py | 10 +++++++- tests/sync/test_connection.py | 14 ++++++++++- 11 files changed, 145 insertions(+), 9 deletions(-) diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index e2b0ff550..ea7b21506 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -57,3 +57,10 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index a58325fb9..325f20450 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -45,3 +45,10 @@ Both sides (new :mod:`asyncio`) .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 2fcaeb414..49bd6f072 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -79,6 +79,13 @@ Using a connection .. autoproperty:: subprotocol + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. 
autoproperty:: close_reason + Broadcast --------- diff --git a/docs/reference/sync/client.rst b/docs/reference/sync/client.rst index af1132412..2aa491f6a 100644 --- a/docs/reference/sync/client.rst +++ b/docs/reference/sync/client.rst @@ -39,6 +39,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: @@ -47,3 +49,10 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/sync/common.rst b/docs/reference/sync/common.rst index 3dc6d4a50..3c03b25b6 100644 --- a/docs/reference/sync/common.rst +++ b/docs/reference/sync/common.rst @@ -31,6 +31,8 @@ Both sides (:mod:`threading`) .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: @@ -39,3 +41,10 @@ Both sides (:mod:`threading`) .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 80e9c17bb..1d80450f9 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -52,6 +52,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: @@ -61,6 +63,13 @@ Using a connection .. autoproperty:: subprotocol + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + HTTP Basic Authentication ------------------------- diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index f1dcbada6..e5c350fe2 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -185,6 +185,30 @@ def subprotocol(self) -> Subprotocol | None: """ return self.protocol.subprotocol + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + # Public methods async def __aenter__(self) -> Connection: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 0f6fea250..bc64a216a 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -159,7 +159,12 @@ def state(self) -> State: """ State of the WebSocket connection. - Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. 
+ Defined in 4.1_, 4.2_, 7.1.3_, and 7.1.4_ of :rfc:`6455`. + + .. _4.1: https://datatracker.ietf.org/doc/html/rfc6455#section-4.1 + .. _4.2: https://datatracker.ietf.org/doc/html/rfc6455#section-4.2 + .. _7.1.3: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.3 + .. _7.1.4: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.4 """ return self._state @@ -173,10 +178,11 @@ def state(self, state: State) -> None: @property def close_code(self) -> int | None: """ - `WebSocket close code`_. + WebSocket close code received from the remote endpoint. + + Defined in 7.1.5_ of :rfc:`6455`. - .. _WebSocket close code: - https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 + .. _7.1.5: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 :obj:`None` if the connection isn't closed yet. @@ -191,10 +197,11 @@ def close_code(self) -> int | None: @property def close_reason(self) -> str | None: """ - `WebSocket close reason`_. + WebSocket close reason received from the remote endpoint. + + Defined in 7.1.6_ of :rfc:`6455`. - .. _WebSocket close reason: - https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 + .. _7.1.6: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 :obj:`None` if the connection isn't closed yet. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index be3381c8a..d8dbf140e 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -140,6 +140,19 @@ def remote_address(self) -> Any: """ return self.socket.getpeername() + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. + + """ + return self.protocol.state + @property def subprotocol(self) -> Subprotocol | None: """ @@ -150,6 +163,30 @@ def subprotocol(self) -> Subprotocol | None: """ return self.protocol.subprotocol + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. 
+ + """ + return self.protocol.close_reason + # Public methods def __enter__(self) -> Connection: diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 8dd0a0335..5a0b61bf7 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1139,7 +1139,7 @@ async def test_remote_address(self, get_extra_info): async def test_state(self): """Connection has a state attribute.""" - self.assertEqual(self.connection.state, State.OPEN) + self.assertIs(self.connection.state, State.OPEN) async def test_request(self): """Connection has a request attribute.""" @@ -1153,6 +1153,14 @@ async def test_subprotocol(self): """Connection has a subprotocol attribute.""" self.assertIsNone(self.connection.subprotocol) + async def test_close_code(self): + """Connection has a close_code attribute.""" + self.assertIsNone(self.connection.close_code) + + async def test_close_reason(self): + """Connection has a close_reason attribute.""" + self.assertIsNone(self.connection.close_reason) + # Test reporting of network errors. async def test_writing_in_data_received_fails(self): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 6be490a5d..4884bf13f 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -15,7 +15,7 @@ ConnectionClosedOK, ) from websockets.frames import CloseCode, Frame, Opcode -from websockets.protocol import CLIENT, SERVER, Protocol +from websockets.protocol import CLIENT, SERVER, Protocol, State from websockets.sync.connection import * from ..protocol import RecordingProtocol @@ -808,6 +808,10 @@ def test_remote_address(self, getpeername): self.assertEqual(self.connection.remote_address, ("peer", 1234)) getpeername.assert_called_with() + def test_state(self): + """Connection has a state attribute.""" + self.assertIs(self.connection.state, State.OPEN) + def test_request(self): """Connection has a request attribute.""" self.assertIsNone(self.connection.request) @@ -820,6 +824,14 @@ def test_subprotocol(self): """Connection has a subprotocol attribute.""" self.assertIsNone(self.connection.subprotocol) + def test_close_code(self): + """Connection has a close_code attribute.""" + self.assertIsNone(self.connection.close_code) + + def test_close_reason(self): + """Connection has a close_reason attribute.""" + self.assertIsNone(self.connection.close_reason) + # Test reporting of network errors. @unittest.skipUnless(sys.platform == "darwin", "works only on BSD") From 1f19487b0f55aac3f67a5f1b35209e0e3b294063 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 22:01:58 +0100 Subject: [PATCH 48/51] Add changelog for previous commit. --- docs/project/changelog.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8e1ad81f0..792f8ede4 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,6 +41,10 @@ Improvements implementations for consistency with the legacy implementation, even though this is never a good idea. +* Added ``close_code`` and ``close_reason`` attributes in the :mod:`asyncio` and + :mod:`threading` implementations for consistency with the legacy + implementation. + Bug fixes ......... From 7438b8ebee3f6ef59c15b86d627d32f98f49df8f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 12 Nov 2024 08:40:24 +0100 Subject: [PATCH 49/51] Stop testing on PyPy 3.9. PyPy v7.3.17 no longer provides PyPy 3.9. Also the test suite was flaky under PyPy 3.9. 
--- .github/workflows/tests.yml | 3 --- tests/sync/test_connection.py | 5 ----- 2 files changed, 8 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index beaf9d12b..5ab9c4c72 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -60,13 +60,10 @@ jobs: - "3.11" - "3.12" - "3.13" - - "pypy-3.9" - "pypy-3.10" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} exclude: - - python: "pypy-3.9" - is_main: false - python: "pypy-3.10" is_main: false steps: diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 4884bf13f..e21e310a2 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -1,6 +1,5 @@ import contextlib import logging -import platform import socket import sys import threading @@ -564,10 +563,6 @@ def test_close_idempotency(self): self.connection.close() self.assertNoFrameSent() - @unittest.skipIf( - platform.python_implementation() == "PyPy", - "this test fails randomly due to a bug in PyPy", # see #1314 for details - ) def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" From d0015c93f49511eb8dd073caa6a3a338979f741b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Nov 2024 07:51:11 +0100 Subject: [PATCH 50/51] Fix refactoring error in a78b5546. Fix #1546. --- example/faq/health_check_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/faq/health_check_server.py b/example/faq/health_check_server.py index c0fa4327f..30623a4bb 100755 --- a/example/faq/health_check_server.py +++ b/example/faq/health_check_server.py @@ -6,7 +6,7 @@ def health_check(connection, request): if request.path == "/healthz": - return connection.respond(HTTPStatus.OK, b"OK\n") + return connection.respond(HTTPStatus.OK, "OK\n") async def echo(websocket): async for message in websocket: From 0403823185b272ae389da1c9182773932b4df950 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Nov 2024 07:53:45 +0100 Subject: [PATCH 51/51] Release version 14.1. --- docs/project/changelog.rst | 6 +++--- src/websockets/version.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 792f8ede4..ca6769199 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,9 +30,7 @@ notice. 14.1 ---- -*In development* - -.. _14.0: +*November 13, 2024* Improvements ............ @@ -52,6 +50,8 @@ Bug fixes be read in the :mod:`asyncio` and :mod:`threading` implementations, just like in the legacy implementation. +.. _14.0: + 14.0 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index 48d2edaea..f2defeff0 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "14.1"
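As a rough usage sketch of the flow-control settings documented earlier in this series (the ``max_queue`` parameter of the new asyncio server): the ``echo`` handler, host, port, and the ``(32, 8)`` marks below are illustrative values, not taken from the patches.

import asyncio

from websockets.asyncio.server import serve


async def echo(websocket):
    async for message in websocket:
        await websocket.send(message)


async def main():
    # Defaults: high-water mark of 16 frames, low-water mark of 16 // 4 = 4.
    # A (high, low) tuple sets both marks explicitly; max_queue=None disables
    # flow control entirely, which the docstring advises against.
    async with serve(echo, "localhost", 8765, max_queue=(32, 8)) as server:
        await server.serve_forever()


asyncio.run(main())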
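A minimal sketch of the ``state``, ``close_code``, and ``close_reason`` attributes added to the threading implementation above; it assumes an echo server is already listening at the made-up ``ws://localhost:8765`` endpoint.

from websockets.protocol import State
from websockets.sync.client import connect

with connect("ws://localhost:8765") as websocket:
    websocket.send("Hello!")
    print(websocket.recv())
    assert websocket.state is State.OPEN
    # The connection isn't closed yet, so both attributes are still None.
    assert websocket.close_code is None
    assert websocket.close_reason is None

# Once the context manager exits, the closing handshake has completed.
print(websocket.state)         # State.CLOSED
print(websocket.close_code)    # 1000 after a clean close
print(websocket.close_reason)  # typically an empty string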