diff --git a/tests/test_query_parse.py b/tests/test_query_parse.py new file mode 100644 index 0000000..7dc5cad --- /dev/null +++ b/tests/test_query_parse.py @@ -0,0 +1,12 @@ +from websocket_server import WebSocketHandler + + +def test_websocket_handler_query_parse(): + case1 = WebSocketHandler.parse_query("GET /?a=hello HTTP/1.1") + case2 = WebSocketHandler.parse_query("GET / HTTP/1.1") + case3 = WebSocketHandler.parse_query("GET /?a=hello&b=world HTTP/1.1") + case4 = WebSocketHandler.parse_query("GET /?a=hello&a=world HTTP/1.1") + assert case1 == {'a': ['hello']} + assert case2 == {} + assert case3 == {'a': ['hello'], 'b': ['world']} + assert case4 == {'a': ['hello', 'world']} diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index 083ee17..fda9934 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -11,6 +11,7 @@ import errno import threading from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler +from urllib.parse import urlparse, parse_qs from websocket_server.thread import WebsocketServerThread @@ -261,6 +262,8 @@ class WebSocketHandler(StreamRequestHandler): def __init__(self, socket, addr, server): self.server = server + self.headers = {} + self.query_params = {} assert not hasattr(self, "_send_lock"), "_send_lock already exists" self._send_lock = threading.Lock() if server.key and server.cert: @@ -412,6 +415,16 @@ def send_text(self, message, opcode=OPCODE_TEXT): with self._send_lock: self.request.send(header + payload) + @staticmethod + def parse_query(http_get): + """ + Parses the query parameters from the first line. + Example: "GET /?q=hello HTTP/1.1" will be parsed to {'q': ['hello']} + """ + query = http_get.split(" ")[1] # example: http_get = "GET /?q=hello HTTP/1.1" + parsed_url = urlparse(query) + return parse_qs(parsed_url.query) + def read_http_headers(self): headers = {} # first line should be HTTP GET @@ -424,6 +437,8 @@ def read_http_headers(self): break head, value = header.split(':', 1) headers[head.lower().strip()] = value.strip() + self.headers = headers + self.query_params = WebSocketHandler.parse_query(http_get) return headers def handshake(self):