Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ def execute_update(
transaction = self._make_txn_selector()
api = database.spanner_api

seqno, self._execute_sql_count = (
self._execute_sql_count,
self._execute_sql_count + 1,
)

# Query-level options have higher precedence than client-level and
# environment-level options
default_query_options = database._instance._client._query_options
Expand All @@ -214,11 +219,9 @@ def execute_update(
param_types=param_types,
query_mode=query_mode,
query_options=query_options,
seqno=self._execute_sql_count,
seqno=seqno,
metadata=metadata,
)

self._execute_sql_count += 1
return response.stats.row_count_exact

def batch_update(self, statements):
Expand Down Expand Up @@ -259,15 +262,18 @@ def batch_update(self, statements):
transaction = self._make_txn_selector()
api = database.spanner_api

seqno, self._execute_sql_count = (
self._execute_sql_count,
self._execute_sql_count + 1,
)

response = api.execute_batch_dml(
session=self._session.name,
transaction=transaction,
statements=parsed,
seqno=self._execute_sql_count,
seqno=seqno,
metadata=metadata,
)

self._execute_sql_count += 1
row_counts = [
result_set.stats.row_count_exact for result_set in response.result_sets
]
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ def test_execute_sql_other_error(self):
with self.assertRaises(RuntimeError):
list(derived.execute_sql(SQL_QUERY))

self.assertEqual(derived._execute_sql_count, 1)

def test_execute_sql_w_params_wo_param_types(self):
database = _Database()
session = _Session(database)
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,19 @@ def test_execute_update_new_transaction(self):
def test_execute_update_w_count(self):
self._execute_update_helper(count=1)

def test_execute_update_error(self):
database = _Database()
database.spanner_api = self._make_spanner_api()
database.spanner_api.execute_sql.side_effect = RuntimeError()
session = _Session(database)
transaction = self._make_one(session)
transaction._transaction_id = self.TRANSACTION_ID

with self.assertRaises(RuntimeError):
transaction.execute_update(DML_QUERY)

self.assertEqual(transaction._execute_sql_count, 1)

def test_execute_update_w_query_options(self):
from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest

Expand Down Expand Up @@ -513,6 +526,31 @@ def test_batch_update_wo_errors(self):
def test_batch_update_w_errors(self):
self._batch_update_helper(error_after=2, count=1)

def test_batch_update_error(self):
database = _Database()
api = database.spanner_api = self._make_spanner_api()
api.execute_batch_dml.side_effect = RuntimeError()
session = _Session(database)
transaction = self._make_one(session)
transaction._transaction_id = self.TRANSACTION_ID

insert_dml = "INSERT INTO table(pkey, desc) VALUES (%pkey, %desc)"
insert_params = {"pkey": 12345, "desc": "DESCRIPTION"}
insert_param_types = {"pkey": "INT64", "desc": "STRING"}
update_dml = 'UPDATE table SET desc = desc + "-amended"'
delete_dml = "DELETE FROM table WHERE desc IS NULL"

dml_statements = [
(insert_dml, insert_params, insert_param_types),
update_dml,
delete_dml,
]

with self.assertRaises(RuntimeError):
transaction.batch_update(dml_statements)

self.assertEqual(transaction._execute_sql_count, 1)

def test_context_mgr_success(self):
import datetime
from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse
Expand Down