From 025ce62dbd45c6826b6b48dba3ab852bf900ae01 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 25 May 2023 01:35:49 +0900 Subject: [PATCH 1/4] Add `collation` option and set_character_set() to Connection. --- pymysql/connections.py | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index d161e789..285e99e9 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -112,7 +112,8 @@ class Connection: (default: None - no timeout) :param write_timeout: The timeout for writing to the connection in seconds. (default: None - no timeout) - :param charset: Charset to use. + :param str charset: Charset to use. + :param str collation: Collation name to use. :param sql_mode: Default SQL_MODE to use. :param read_default_file: Specifies my.cnf file to read these parameters from under the [client] section. @@ -174,6 +175,7 @@ def __init__( unix_socket=None, port=0, charset="", + collation=None, sql_mode=None, read_default_file=None, conv=None, @@ -308,6 +310,7 @@ def _config(key, arg): self._write_timeout = write_timeout self.charset = charset or DEFAULT_CHARSET + self.collation = collation self.use_unicode = use_unicode self.encoding = charset_by_name(self.charset).encoding @@ -593,10 +596,22 @@ def ping(self, reconnect=True): raise def set_charset(self, charset): + """Deprecated. Use set_character_set() instead.""" + # This function has been implemented in old PyMySQL. + # But this name is different from MySQLdb. + # So we keep this function for compatibility and add + # new set_character_set() function. + self.set_character_set(charset) + + def set_character_set(self, charset, collation=None): # Make sure charset is supported. 
encoding = charset_by_name(charset).encoding - self._execute_command(COMMAND.COM_QUERY, "SET NAMES %s" % self.escape(charset)) + if collation: + query = f"SET NAMES {charset} COLLATE {collation}" + else: + query = f"SET NAMES {charset}" + self._execute_command(COMMAND.COM_QUERY, query) self._read_packet() self.charset = charset self.encoding = encoding @@ -641,15 +656,30 @@ def connect(self, sock=None): self._get_server_information() self._request_authentication() + # Send "SET NAMES" query on init for: + # - Ensure charset (and collation) is set to the server. + # - collation_id in handshake packet may be ignored. + # - If collation is not specified, we don't know what is server's + # default collation for the charset. For example, default collation + # of utf8mb4 is: + # - MySQL 5.7, MariaDB 10.x: utf8mb4_general_ci + # - MySQL 8.0: utf8mb4_0900_ai_ci + # + # Reference: + # - https://github.com/PyMySQL/PyMySQL/issues/1092 + # - https://github.com/wagtail/wagtail/issues/9477 + # - https://zenn.dev/methane/articles/2023-mysql-collation (Japanese) + self.set_character_set(self.charset, self.collation) + if self.sql_mode is not None: c = self.cursor() c.execute("SET sql_mode=%s", (self.sql_mode,)) + c.close() if self.init_command is not None: c = self.cursor() c.execute(self.init_command) c.close() - self.commit() if self.autocommit_mode is not None: self.autocommit(self.autocommit_mode) From 1d81ded4b9f267c92e9de99a1193d625220d709f Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 25 May 2023 01:43:51 +0900 Subject: [PATCH 2/4] Add docstring --- pymysql/connections.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pymysql/connections.py b/pymysql/connections.py index 285e99e9..f4782939 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -604,6 +604,12 @@ def set_charset(self, charset): self.set_character_set(charset) def set_character_set(self, charset, collation=None): + """ + Set charset (and collation) + + Send "SET NAMES charset 
[COLLATE collation]" query. + Update Connection.encoding based on charset. + """ # Make sure charset is supported. encoding = charset_by_name(charset).encoding @@ -615,6 +621,7 @@ def set_character_set(self, charset, collation=None): self._read_packet() self.charset = charset self.encoding = encoding + self.collation = collation def connect(self, sock=None): self._closed = False From 00060db2a393d2f63100a0f4fc1cf1f7b7f8fc71 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 25 May 2023 01:54:42 +0900 Subject: [PATCH 3/4] Add test --- pymysql/tests/test_connection.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index 869ff0f8..bb2931b3 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -444,6 +444,20 @@ def test_utf8mb4(self): arg["charset"] = "utf8mb4" pymysql.connect(**arg) + def test_set_character_set(self): + con = self.connect() + cur = con.cursor() + + con.set_character_set("latin1") + cur.execute("SELECT @@character_set_connection") + self.assertEqual(cur.fetchone(), ("latin1",)) + self.assertEqual(con.encoding, "cp1252") + + con.set_character_set("utf8mb3", "utf8mb3_general_ci") + cur.execute("SELECT @@character_set_connection, @@collation_connection") + self.assertEqual(cur.fetchone(), ("utf8mb3", "utf8mb3_general_ci")) + self.assertEqual(con.encoding, "utf8") + def test_largedata(self): """Large query and response (>=16MB)""" cur = self.connect().cursor() From 9c38074351673ecb83a4f24894925411c2187f0c Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 25 May 2023 01:57:00 +0900 Subject: [PATCH 4/4] fix test --- pymysql/tests/test_connection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index bb2931b3..0803efc9 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -453,9 +453,9 @@ def 
test_set_character_set(self): self.assertEqual(cur.fetchone(), ("latin1",)) self.assertEqual(con.encoding, "cp1252") - con.set_character_set("utf8mb3", "utf8mb3_general_ci") + con.set_character_set("utf8mb4", "utf8mb4_general_ci") cur.execute("SELECT @@character_set_connection, @@collation_connection") - self.assertEqual(cur.fetchone(), ("utf8mb3", "utf8mb3_general_ci")) + self.assertEqual(cur.fetchone(), ("utf8mb4", "utf8mb4_general_ci")) self.assertEqual(con.encoding, "utf8") def test_largedata(self):