Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ff38c23

Browse files
Tokenize Connection header in WebSocket client handshake (aio-libs#12746)
1 parent c41e3b6 commit ff38c23

5 files changed

Lines changed: 75 additions & 1 deletion

File tree

aiohttp/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,7 @@ async def _ws_connect(
11241124
headers=resp.headers,
11251125
)
11261126

1127-
if resp.headers.get(hdrs.CONNECTION, "").lower() != "upgrade":
1127+
if not resp._upgraded:
11281128
raise WSServerHandshakeError(
11291129
resp.request_info,
11301130
resp.history,

aiohttp/client_reqrep.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ class ClientResponse(HeadersMixin):
197197
_headers: HeadersDictProxy = None # type: ignore[assignment]
198198
_history: tuple["ClientResponse", ...] = ()
199199
_raw_headers: RawHeaders = None # type: ignore[assignment]
200+
_upgraded: bool = False # parser saw a Connection: upgrade token
200201

201202
_connection: "Connection | None" = None # current connection
202203
_cookies: SimpleCookie | None = None
@@ -490,6 +491,7 @@ async def start(self, connection: "Connection") -> "ClientResponse":
490491
# headers
491492
self._headers = message.headers
492493
self._raw_headers = message.raw_headers
494+
self._upgraded = message.upgrade
493495

494496
# payload
495497
self.content = payload

tests/test_client_proto.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,3 +387,43 @@ async def test_abort_without_transport() -> None:
387387
# Should not raise and should still clean up
388388
assert proto._exception is None
389389
mock_drop_timeout.assert_not_called()
390+
391+
392+
@pytest.mark.parametrize(
393+
("connection", "expected"),
394+
[(b"upgrade, keep-alive", True), (b"keep-alive", False)],
395+
)
396+
async def test_response_start_records_upgrade(
397+
connection: bytes, expected: bool
398+
) -> None:
399+
"""ClientResponse.start() preserves the parser's Connection upgrade flag."""
400+
loop = asyncio.get_running_loop()
401+
proto = ResponseHandler(loop=loop)
402+
proto.connection_made(mock.Mock())
403+
conn = mock.Mock(protocol=proto)
404+
proto.set_response_params(read_until_eof=True)
405+
proto.data_received(
406+
b"HTTP/1.1 101 Switching Protocols\r\n"
407+
b"Upgrade: websocket\r\n"
408+
b"Connection: " + connection + b"\r\n\r\n"
409+
)
410+
411+
url = URL("http://ws-upgrade.org")
412+
response = ClientResponse(
413+
"get",
414+
url,
415+
writer=mock.Mock(),
416+
continue100=None,
417+
timer=TimerNoop(),
418+
traces=[],
419+
loop=loop,
420+
session=mock.Mock(),
421+
request_headers=CIMultiDict[str](),
422+
original_url=url,
423+
stream_writer=mock.create_autospec(
424+
AbstractStreamWriter, spec_set=True, instance=True
425+
),
426+
)
427+
await response.start(conn)
428+
assert response._upgraded is expected
429+
response.close()

tests/test_client_ws.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ async def test_ws_connect(ws_key: str, key_data: bytes) -> None:
3030
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
3131
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
3232
}
33+
resp._upgraded = True
3334
resp.connection.protocol.read_timeout = None
3435
with mock.patch("aiohttp.client.os") as m_os:
3536
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
@@ -255,13 +256,16 @@ async def test_ws_connect_err_upgrade(ws_key: str, key_data: bytes) -> None:
255256

256257

257258
async def test_ws_connect_err_conn(ws_key: str, key_data: bytes) -> None:
259+
# The parser did not see a Connection: upgrade token (resp._upgraded is
260+
# False), so the handshake must be rejected.
258261
resp = mock.Mock()
259262
resp.status = 101
260263
resp.headers = {
261264
hdrs.UPGRADE: "websocket",
262265
hdrs.CONNECTION: "close",
263266
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
264267
}
268+
resp._upgraded = False
265269
with mock.patch("aiohttp.client.os") as m_os:
266270
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
267271
m_os.urandom.return_value = key_data

tests/test_http_parser.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,34 @@ def test_upgrade_header_non_ascii(parser: HttpRequestParser) -> None:
847847
assert not upgrade
848848

849849

850+
@pytest.mark.parametrize(
851+
("connection", "expected"),
852+
[
853+
("upgrade", True),
854+
("upgrade, keep-alive", True), # other tokens alongside upgrade
855+
("keep-alive, upgrade", True), # upgrade not first
856+
("Upgrade, Keep-Alive", True), # case-insensitive
857+
("keep-alive", False), # no upgrade token
858+
("keep-alive, notupgrade", False), # substring is not a token
859+
],
860+
)
861+
def test_response_upgrade_token_in_connection_list(
862+
response: HttpResponseParser, connection: str, expected: bool
863+
) -> None:
864+
# RFC 9110 §7.6.1: Connection is a comma-separated token list, so the parser
865+
# must set msg.upgrade for a 101 response whenever "upgrade" appears as a
866+
# token, regardless of position, case, or neighbouring tokens.
867+
text = (
868+
b"HTTP/1.1 101 Switching Protocols\r\n"
869+
b"Upgrade: websocket\r\n"
870+
b"Connection: " + connection.encode() + b"\r\n\r\n"
871+
)
872+
messages, upgrade, tail = response.feed_data(text)
873+
msg = messages[0][0]
874+
assert msg.upgrade == expected
875+
assert upgrade == expected
876+
877+
850878
def test_request_te_chunked_with_content_length(parser: HttpRequestParser) -> None:
851879
text = (
852880
b"GET /test HTTP/1.1\r\n"

0 commit comments

Comments
 (0)