diff --git a/spanner_dbapi/cursor.py b/spanner_dbapi/cursor.py index 10e5184ed2..4a0a923236 100644 --- a/spanner_dbapi/cursor.py +++ b/spanner_dbapi/cursor.py @@ -64,7 +64,7 @@ class Cursor: def __init__(self, connection): self._itr = None self._res = None - self._row_count = _UNSET_COUNT + self._rowcount = _UNSET_COUNT self._connection = connection self._is_closed = False @@ -78,13 +78,11 @@ def execute(self, sql, args=None): sql: A SQL statement *args: variadic argument list **kwargs: key worded arguments - Returns: - None """ self._raise_if_closed() if not self._connection: - raise ProgrammingError("Cursor is not connected to the database") + raise ProgrammingError("Cursor is not connected to a database") self._res = None @@ -97,14 +95,16 @@ def execute(self, sql, args=None): # For every other operation, we've got to ensure that # any prior DDL statements were run. - self._run_prior_DDL_statements() + self._run_prior_ddl_statements() if classification == STMT_NON_UPDATING: - self.__handle_DQL(sql, args or None) + self._handle_dql(sql, args or None) elif classification == STMT_INSERT: - self.__handle_insert(sql, args or None) + self._handle_insert(sql, args or None) else: - self.__handle_update(sql, args or None) + self._handle_update(sql, args or None) + except (AlreadyExists, FailedPrecondition) as e: + raise IntegrityError(e.details if hasattr(e, "details") else e) except (AlreadyExists, FailedPrecondition) as e: raise IntegrityError(e.details if hasattr(e, "details") else e) except InvalidArgument as e: @@ -112,11 +112,11 @@ def execute(self, sql, args=None): except InternalServerError as e: raise OperationalError(e.details if hasattr(e, "details") else e) - def __handle_update(self, sql, params): + def _handle_update(self, sql, params): + ensure_where_clause(sql) self._connection.in_transaction(self.__do_execute_update, sql, params) def __do_execute_update(self, transaction, sql, params, param_types=None): - sql = ensure_where_clause(sql) sql, params = sql_pyformat_args_to_spanner(sql, params) res = transaction.execute_update( @@ -124,11 +124,11 @@ def __do_execute_update(self, transaction, sql, params, param_types=None): ) self._itr = None if type(res) == int: - self._row_count = res + self._rowcount = res return res - def __handle_insert(self, sql, params): + def _handle_insert(self, sql, params): parts = parse_insert(sql, params) # The split between the two styles exists because: @@ -176,7 +176,7 @@ def __do_execute_insert_homogenous(self, transaction, parts): values = parts.get("values") return transaction.insert(table, columns, values) - def __handle_DQL(self, sql, params): + def _handle_dql(self, sql, params): with self._connection.read_snapshot() as snapshot: # Reference # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql @@ -185,7 +185,7 @@ def __handle_DQL(self, sql, params): sql, params=params, param_types=get_param_types(params) ) if type(res) == int: - self._row_count = res + self._rowcount = res self._itr = None else: # Immediately using: @@ -202,7 +202,7 @@ def __handle_DQL(self, sql, params): self._itr = PeekIterator(self._res) # Unfortunately, Spanner doesn't seem to send back # information about the number of rows available. - self._row_count = _UNSET_COUNT + self._rowcount = _UNSET_COUNT def __enter__(self): return self @@ -235,7 +235,7 @@ def description(self): @property def rowcount(self): - return self._row_count + return self._rowcount @property def is_closed(self): @@ -334,7 +334,7 @@ def setinputsizes(sizes): def setoutputsize(size, column=None): raise ProgrammingError("Unimplemented") - def _run_prior_DDL_statements(self): + def _run_prior_ddl_statements(self): return self._connection.run_prior_DDL_statements() def list_tables(self): diff --git a/spanner_dbapi/parse_utils.py b/spanner_dbapi/parse_utils.py index edf3d3e64f..5b79116348 100644 --- a/spanner_dbapi/parse_utils.py +++ b/spanner_dbapi/parse_utils.py @@ -402,16 +402,25 @@ def get_param_types(params): def ensure_where_clause(sql): + """Check if the given sql query includes WHERE clause. + + Cloud Spanner requires a WHERE clause with UPDATE + and DELETE statements to avoid accidental delete + or update of all the table rows. + + :type sql: :class:`str` + :param sql: SQL query to check for WHERE clause presence. + + :raises: :class:`ProgrammingError` if there is + no WHERE clause in the given query. """ - Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. - Add a dummy WHERE clause if necessary. - """ - if any( - isinstance(token, sqlparse.sql.Where) - for token in sqlparse.parse(sql)[0] - ): - return sql - return sql + " WHERE 1=1" + for token in sqlparse.parse(sql)[0]: + if isinstance(token, sqlparse.sql.Where): + return + + raise ProgrammingError( + "Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query" + ) SPANNER_RESERVED_KEYWORDS = { diff --git a/tests/spanner_dbapi/test_parse_utils.py b/tests/spanner_dbapi/test_parse_utils.py index 91d795ad0f..6cd7834805 100644 --- a/tests/spanner_dbapi/test_parse_utils.py +++ b/tests/spanner_dbapi/test_parse_utils.py @@ -448,34 +448,23 @@ def test_get_param_types(self): self.assertEqual(got_param_types, want_param_types) def test_ensure_where_clause(self): - cases = [ - ( - "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", - "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", - ), - ( - "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", - "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1", - ), - ( - "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", - "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", - ), - ( - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - ), - ( - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - ), - ("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"), - ] + cases = ( + "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", + "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + ) + err_cases = ( + "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", + "DELETE * FROM TABLE", + ) + for sql in cases: + with self.subTest(sql=sql): + ensure_where_clause(sql) - for sql, want in cases: + for sql in err_cases: with self.subTest(sql=sql): - got = ensure_where_clause(sql) - self.assertEqual(got, want) + with self.assertRaises(ProgrammingError): + ensure_where_clause(sql) def test_escape_name(self): cases = [