diff --git a/django_spanner/introspection.py b/django_spanner/introspection.py index 9202c170d8..c752ec303b 100644 --- a/django_spanner/introspection.py +++ b/django_spanner/introspection.py @@ -29,26 +29,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): TypeCode.NUMERIC: "DecimalField", TypeCode.JSON: "JSONField", } - if USE_EMULATOR: - # Emulator does not support table_type yet. - # https://github.com/GoogleCloudPlatform/cloud-spanner-emulator/issues/43 - LIST_TABLE_SQL = """ - SELECT - t.table_name, t.table_name - FROM - information_schema.tables AS t - WHERE - t.table_catalog = '' and t.table_schema = '' - """ - else: - LIST_TABLE_SQL = """ - SELECT - t.table_name, t.table_type - FROM - information_schema.tables AS t - WHERE - t.table_catalog = '' and t.table_schema = '' - """ + LIST_TABLE_SQL = """ + SELECT + t.table_name, t.table_type + FROM + information_schema.tables AS t + WHERE + t.table_catalog = '' and t.table_schema = @schema_name + """ def get_field_type(self, data_type, description): """A hook for a Spanner database to use the cursor description to @@ -76,7 +64,10 @@ def get_table_list(self, cursor): :rtype: list :returns: A list of table and view names in the current database. """ - results = cursor.run_sql_in_snapshot(self.LIST_TABLE_SQL) + schema_name = self._get_schema_name(cursor) + results = cursor.run_sql_in_snapshot( + self.LIST_TABLE_SQL, params={"schema_name": schema_name} + ) tables = [] # The second TableInfo field is 't' for table or 'v' for view. for row in results: @@ -159,8 +150,9 @@ def get_relations(self, cursor, table_name): :rtype: dict :returns: A dictionary representing column relationships to other tables. """ + schema_name = self._get_schema_name(cursor) results = cursor.run_sql_in_snapshot( - ''' + """ SELECT tc.COLUMN_NAME as col, ccu.COLUMN_NAME as ref_col, ccu.TABLE_NAME as ref_table FROM @@ -174,8 +166,9 @@ def get_relations(self, cursor, table_name): ON rc.UNIQUE_CONSTRAINT_NAME = ccu.CONSTRAINT_NAME WHERE - tc.TABLE_NAME="%s"''' - % self.connection.ops.quote_name(table_name) + tc.TABLE_SCHEMA=@schema_name AND tc.TABLE_NAME=@view_name + """, + params={"schema_name": schema_name, "view_name": table_name}, ) return { column: (referred_column, referred_table) @@ -194,6 +187,7 @@ def get_primary_key_column(self, cursor, table_name): :rtype: str :returns: The name of the PK column. """ + schema_name = self._get_schema_name(cursor) results = cursor.run_sql_in_snapshot( """ SELECT @@ -205,9 +199,9 @@ def get_primary_key_column(self, cursor, table_name): AS ccu ON tc.CONSTRAINT_NAME = ccu.CONSTRAINT_NAME WHERE - tc.TABLE_NAME="%s" AND tc.CONSTRAINT_TYPE='PRIMARY KEY' AND tc.TABLE_SCHEMA='' - """ - % self.connection.ops.quote_name(table_name) + tc.TABLE_NAME=@table_name AND tc.CONSTRAINT_TYPE='PRIMARY KEY' AND tc.TABLE_SCHEMA=@schema_name + """, + params={"schema_name": schema_name, "table_name": table_name}, ) return results[0][0] if results else None @@ -224,18 +218,17 @@ def get_constraints(self, cursor, table_name): :returns: A dictionary with constraints. """ constraints = {} - quoted_table_name = self.connection.ops.quote_name(table_name) + schema_name = self._get_schema_name(cursor) # Firstly populate all available constraints and their columns. constraint_columns = cursor.run_sql_in_snapshot( - ''' + """ SELECT CONSTRAINT_NAME, COLUMN_NAME FROM INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE - WHERE TABLE_NAME="{table}"'''.format( - table=quoted_table_name - ) + WHERE TABLE_NAME=@table AND TABLE_SCHEMA=@schema_name""", + params={"table": table_name, "schema_name": schema_name}, ) for constraint, column_name in constraint_columns: if constraint not in constraints: @@ -254,15 +247,14 @@ def get_constraints(self, cursor, table_name): # Add the various constraints by type. constraint_types = cursor.run_sql_in_snapshot( - ''' + """ SELECT CONSTRAINT_NAME, CONSTRAINT_TYPE FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE - TABLE_NAME="{table}"'''.format( - table=quoted_table_name - ) + TABLE_NAME=@table AND TABLE_SCHEMA=@schema_name""", + params={"table": table_name, "schema_name": schema_name}, ) for constraint, constraint_type in constraint_types: already_added = constraint in constraints @@ -303,14 +295,13 @@ def get_constraints(self, cursor, table_name): RIGHT JOIN INFORMATION_SCHEMA.INDEX_COLUMNS AS idx_col ON - idx_col.INDEX_NAME = idx.INDEX_NAME AND idx_col.TABLE_NAME="{table}" + idx_col.INDEX_NAME = idx.INDEX_NAME AND idx_col.TABLE_NAME=@table AND idx_col.TABLE_SCHEMA=idx.TABLE_SCHEMA WHERE - idx.TABLE_NAME="{table}" + idx.TABLE_NAME=@table AND idx.TABLE_SCHEMA=@schema_name ORDER BY idx_col.ORDINAL_POSITION - """.format( - table=quoted_table_name - ) + """, + params={"table": table_name, "schema_name": schema_name}, ) for ( index_name, @@ -350,6 +341,7 @@ def get_key_columns(self, cursor, table_name): for all key columns in the given table. """ key_columns = [] + schema_name = self._get_schema_name(cursor) cursor.execute( """SELECT tc.COLUMN_NAME as column_name, @@ -366,10 +358,12 @@ def get_key_columns(self, cursor, table_name): ON rc.CONSTRAINT_NAME = ccu.CONSTRAINT_NAME WHERE - tc.TABLE_NAME="{table}" - """.format( - table=self.connection.ops.quote_name(table_name) - ) + tc.TABLE_NAME=@table AND tc.TABLE_SCHEMA=@schema_name + """, + params={"table": table_name, "schema_name": schema_name}, ) key_columns.extend(cursor.fetchall()) return key_columns + + def _get_schema_name(self, cursor): + return cursor.connection.current_schema