Thanks to visit codestin.com
Credit goes to github.com

Skip to content
84 changes: 39 additions & 45 deletions django_spanner/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
TypeCode.NUMERIC: "DecimalField",
TypeCode.JSON: "JSONField",
}
if USE_EMULATOR:
# Emulator does not support table_type yet.
# https://github.com/GoogleCloudPlatform/cloud-spanner-emulator/issues/43
LIST_TABLE_SQL = """
SELECT
t.table_name, t.table_name
FROM
information_schema.tables AS t
WHERE
t.table_catalog = '' and t.table_schema = ''
"""
else:
LIST_TABLE_SQL = """
SELECT
t.table_name, t.table_type
FROM
information_schema.tables AS t
WHERE
t.table_catalog = '' and t.table_schema = ''
"""
LIST_TABLE_SQL = """
SELECT
t.table_name, t.table_type
FROM
information_schema.tables AS t
WHERE
t.table_catalog = '' and t.table_schema = @schema_name
"""

def get_field_type(self, data_type, description):
"""A hook for a Spanner database to use the cursor description to
Expand Down Expand Up @@ -76,7 +64,10 @@ def get_table_list(self, cursor):
:rtype: list
:returns: A list of table and view names in the current database.
"""
results = cursor.run_sql_in_snapshot(self.LIST_TABLE_SQL)
schema_name = self._get_schema_name(cursor)
results = cursor.run_sql_in_snapshot(
self.LIST_TABLE_SQL, params={"schema_name": schema_name}
)
tables = []
# The second TableInfo field is 't' for table or 'v' for view.
for row in results:
Expand Down Expand Up @@ -159,8 +150,9 @@ def get_relations(self, cursor, table_name):
:rtype: dict
:returns: A dictionary representing column relationships to other tables.
"""
schema_name = self._get_schema_name(cursor)
results = cursor.run_sql_in_snapshot(
'''
"""
SELECT
tc.COLUMN_NAME as col, ccu.COLUMN_NAME as ref_col, ccu.TABLE_NAME as ref_table
FROM
Expand All @@ -174,8 +166,9 @@ def get_relations(self, cursor, table_name):
ON
rc.UNIQUE_CONSTRAINT_NAME = ccu.CONSTRAINT_NAME
WHERE
tc.TABLE_NAME="%s"'''
% self.connection.ops.quote_name(table_name)
tc.TABLE_SCHEMA=@schema_name AND tc.TABLE_NAME=@view_name
""",
params={"schema_name": schema_name, "view_name": table_name},
)
return {
column: (referred_column, referred_table)
Expand All @@ -194,6 +187,7 @@ def get_primary_key_column(self, cursor, table_name):
:rtype: str
:returns: The name of the PK column.
"""
schema_name = self._get_schema_name(cursor)
results = cursor.run_sql_in_snapshot(
"""
SELECT
Expand All @@ -205,9 +199,9 @@ def get_primary_key_column(self, cursor, table_name):
AS
ccu ON tc.CONSTRAINT_NAME = ccu.CONSTRAINT_NAME
WHERE
tc.TABLE_NAME="%s" AND tc.CONSTRAINT_TYPE='PRIMARY KEY' AND tc.TABLE_SCHEMA=''
"""
% self.connection.ops.quote_name(table_name)
tc.TABLE_NAME=@table_name AND tc.CONSTRAINT_TYPE='PRIMARY KEY' AND tc.TABLE_SCHEMA=@schema_name
""",
params={"schema_name": schema_name, "table_name": table_name},
)
return results[0][0] if results else None

Expand All @@ -224,18 +218,17 @@ def get_constraints(self, cursor, table_name):
:returns: A dictionary with constraints.
"""
constraints = {}
quoted_table_name = self.connection.ops.quote_name(table_name)
schema_name = self._get_schema_name(cursor)

# Firstly populate all available constraints and their columns.
constraint_columns = cursor.run_sql_in_snapshot(
'''
"""
SELECT
CONSTRAINT_NAME, COLUMN_NAME
FROM
INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE
WHERE TABLE_NAME="{table}"'''.format(
table=quoted_table_name
)
WHERE TABLE_NAME=@table AND TABLE_SCHEMA=@schema_name""",
params={"table": table_name, "schema_name": schema_name},
)
for constraint, column_name in constraint_columns:
if constraint not in constraints:
Expand All @@ -254,15 +247,14 @@ def get_constraints(self, cursor, table_name):

# Add the various constraints by type.
constraint_types = cursor.run_sql_in_snapshot(
'''
"""
SELECT
CONSTRAINT_NAME, CONSTRAINT_TYPE
FROM
INFORMATION_SCHEMA.TABLE_CONSTRAINTS
WHERE
TABLE_NAME="{table}"'''.format(
table=quoted_table_name
)
TABLE_NAME=@table AND TABLE_SCHEMA=@schema_name""",
params={"table": table_name, "schema_name": schema_name},
)
for constraint, constraint_type in constraint_types:
already_added = constraint in constraints
Expand Down Expand Up @@ -303,14 +295,13 @@ def get_constraints(self, cursor, table_name):
RIGHT JOIN
INFORMATION_SCHEMA.INDEX_COLUMNS AS idx_col
ON
idx_col.INDEX_NAME = idx.INDEX_NAME AND idx_col.TABLE_NAME="{table}"
idx_col.INDEX_NAME = idx.INDEX_NAME AND idx_col.TABLE_NAME=@table AND idx_col.TABLE_SCHEMA=idx.TABLE_SCHEMA
WHERE
idx.TABLE_NAME="{table}"
idx.TABLE_NAME=@table AND idx.TABLE_SCHEMA=@schema_name
ORDER BY
idx_col.ORDINAL_POSITION
""".format(
table=quoted_table_name
)
""",
params={"table": table_name, "schema_name": schema_name},
)
for (
index_name,
Expand Down Expand Up @@ -350,6 +341,7 @@ def get_key_columns(self, cursor, table_name):
for all key columns in the given table.
"""
key_columns = []
schema_name = self._get_schema_name(cursor)
cursor.execute(
"""SELECT
tc.COLUMN_NAME as column_name,
Expand All @@ -366,10 +358,12 @@ def get_key_columns(self, cursor, table_name):
ON
rc.CONSTRAINT_NAME = ccu.CONSTRAINT_NAME
WHERE
tc.TABLE_NAME="{table}"
""".format(
table=self.connection.ops.quote_name(table_name)
)
tc.TABLE_NAME=@table AND tc.TABLE_SCHEMA=@schema_name
""",
params={"table": table_name, "schema_name": schema_name},
)
key_columns.extend(cursor.fetchall())
return key_columns

def _get_schema_name(self, cursor):
return cursor.connection.current_schema