Thanks to visit codestin.com
Credit goes to github.com

Skip to content
34 changes: 17 additions & 17 deletions spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class Cursor:
def __init__(self, connection):
self._itr = None
self._res = None
self._row_count = _UNSET_COUNT
self._rowcount = _UNSET_COUNT
self._connection = connection
self._is_closed = False

Expand All @@ -78,13 +78,11 @@ def execute(self, sql, args=None):
sql: A SQL statement
*args: variadic argument list
**kwargs: key worded arguments
Returns:
None
"""
self._raise_if_closed()

if not self._connection:
raise ProgrammingError("Cursor is not connected to the database")
raise ProgrammingError("Cursor is not connected to a database")

self._res = None

Expand All @@ -97,38 +95,40 @@ def execute(self, sql, args=None):

# For every other operation, we've got to ensure that
# any prior DDL statements were run.
self._run_prior_DDL_statements()
self._run_prior_ddl_statements()

if classification == STMT_NON_UPDATING:
self.__handle_DQL(sql, args or None)
self._handle_dql(sql, args or None)
elif classification == STMT_INSERT:
self.__handle_insert(sql, args or None)
self._handle_insert(sql, args or None)
else:
self.__handle_update(sql, args or None)
self._handle_update(sql, args or None)
except (AlreadyExists, FailedPrecondition) as e:
raise IntegrityError(e.details if hasattr(e, "details") else e)
except (AlreadyExists, FailedPrecondition) as e:
raise IntegrityError(e.details if hasattr(e, "details") else e)
except InvalidArgument as e:
raise ProgrammingError(e.details if hasattr(e, "details") else e)
except InternalServerError as e:
raise OperationalError(e.details if hasattr(e, "details") else e)

def __handle_update(self, sql, params):
def _handle_update(self, sql, params):
ensure_where_clause(sql)
self._connection.in_transaction(self.__do_execute_update, sql, params)

def __do_execute_update(self, transaction, sql, params, param_types=None):
sql = ensure_where_clause(sql)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also realized that it's not very okay to do the check here, as __do_execute_update() will be sent to Database.run_in_transaction(), and then to Session.run_in_transaction() as func arg. This last method will start a transaction (if needed) and only then it'll run __do_execute_update() (and the check). We better run the check before starting anything. The check call moved to line 99

sql, params = sql_pyformat_args_to_spanner(sql, params)

res = transaction.execute_update(
sql, params=params, param_types=get_param_types(params)
)
self._itr = None
if type(res) == int:
self._row_count = res
self._rowcount = res

return res

def __handle_insert(self, sql, params):
def _handle_insert(self, sql, params):
parts = parse_insert(sql, params)

# The split between the two styles exists because:
Expand Down Expand Up @@ -176,7 +176,7 @@ def __do_execute_insert_homogenous(self, transaction, parts):
values = parts.get("values")
return transaction.insert(table, columns, values)

def __handle_DQL(self, sql, params):
def _handle_dql(self, sql, params):
with self._connection.read_snapshot() as snapshot:
# Reference
# https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql
Expand All @@ -185,7 +185,7 @@ def __handle_DQL(self, sql, params):
sql, params=params, param_types=get_param_types(params)
)
if type(res) == int:
self._row_count = res
self._rowcount = res
self._itr = None
else:
# Immediately using:
Expand All @@ -202,7 +202,7 @@ def __handle_DQL(self, sql, params):
self._itr = PeekIterator(self._res)
# Unfortunately, Spanner doesn't seem to send back
# information about the number of rows available.
self._row_count = _UNSET_COUNT
self._rowcount = _UNSET_COUNT

def __enter__(self):
return self
Expand Down Expand Up @@ -235,7 +235,7 @@ def description(self):

@property
def rowcount(self):
return self._row_count
return self._rowcount

@property
def is_closed(self):
Expand Down Expand Up @@ -334,7 +334,7 @@ def setinputsizes(sizes):
def setoutputsize(size, column=None):
raise ProgrammingError("Unimplemented")

def _run_prior_DDL_statements(self):
def _run_prior_ddl_statements(self):
return self._connection.run_prior_DDL_statements()

def list_tables(self):
Expand Down
27 changes: 18 additions & 9 deletions spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,16 +402,25 @@ def get_param_types(params):


def ensure_where_clause(sql):
"""Check if the given sql query includes WHERE clause.

Cloud Spanner requires a WHERE clause with UPDATE
and DELETE statements to avoid accidental delete
or update of all the table rows.

:type sql: :class:`str`
:param sql: SQL query to check for WHERE clause presence.

:raises: :class:`ProgrammingError` if there is
no WHERE clause in the given query.
"""
Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements.
Add a dummy WHERE clause if necessary.
"""
if any(
isinstance(token, sqlparse.sql.Where)
for token in sqlparse.parse(sql)[0]
):
return sql
return sql + " WHERE 1=1"
for token in sqlparse.parse(sql)[0]:
if isinstance(token, sqlparse.sql.Where):
return

raise ProgrammingError(
"Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query"
)


SPANNER_RESERVED_KEYWORDS = {
Expand Down
41 changes: 15 additions & 26 deletions tests/spanner_dbapi/test_parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,34 +448,23 @@ def test_get_param_types(self):
self.assertEqual(got_param_types, want_param_types)

def test_ensure_where_clause(self):
cases = [
(
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
),
(
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1",
),
(
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"),
]
cases = (
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
)
err_cases = (
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"DELETE * FROM TABLE",
)
for sql in cases:
with self.subTest(sql=sql):
ensure_where_clause(sql)

for sql, want in cases:
for sql in err_cases:
with self.subTest(sql=sql):
got = ensure_where_clause(sql)
self.assertEqual(got, want)
with self.assertRaises(ProgrammingError):
ensure_where_clause(sql)

def test_escape_name(self):
cases = [
Expand Down