from __future__ import unicode_literals
from collections import defaultdict
from django.db.backends.postgresql_psycopg2 import schema
from django.conf import settings
import sqlparse
from sqlparse.tokens import DDL, DML, Keyword
from ...schema import is_shared_table
from ...schema import get_schema_model, _schema_table_exists
from ...schema import deactivate_schema, activate_template_schema
def get_constraints(cursor, table_name):
    """
    Retrieves any constraints or keys (unique, pk, fk, check, index) across one or more columns.

    This is copied (almost) verbatim from django, but instead of looking only
    in "public" it looks in both ``settings.PUBLIC_SCHEMA`` and "__template__".
    We assume that this will find the relevant constraint, and rely on our
    operations keeping the others in sync.

    :param cursor: database cursor used to run the introspection queries.
    :param table_name: unqualified name of the table to inspect.
    :returns: dict mapping each constraint/index name to a dict with the keys
        ``columns`` (list of column names), ``primary_key``, ``unique``,
        ``check``, ``index`` (booleans) and ``foreign_key`` (a
        ``(table, column)`` tuple taken from the referenced column for
        foreign keys, otherwise ``None``).
    """
    schemata = [settings.PUBLIC_SCHEMA, "__template__"]
    constraints = {}

    def add_column(name, column):
        # If the table is visible in both schemata, the same constraint is
        # reported once per schema. A column cannot appear twice in a single
        # constraint, so skip repeats; otherwise ``columns`` would never equal
        # the plain column list callers compare it against.
        columns = constraints[name]['columns']
        if column not in columns:
            columns.append(column)

    # Loop over the key table, collecting things as constraints.
    # This will get PKs, FKs, and uniques, but not CHECK.
    cursor.execute("""
SELECT
kc.constraint_name,
kc.column_name,
c.constraint_type,
array(SELECT table_name::text || '.' || column_name::text
FROM information_schema.constraint_column_usage
WHERE constraint_name = kc.constraint_name)
FROM information_schema.key_column_usage AS kc
JOIN information_schema.table_constraints AS c ON
kc.table_schema = c.table_schema AND
kc.table_name = c.table_name AND
kc.constraint_name = c.constraint_name
WHERE
kc.table_schema IN (%s, %s) AND
kc.table_name = %s
ORDER BY kc.ordinal_position ASC
""", schemata + [table_name])
    for constraint, column, kind, used_cols in cursor.fetchall():
        kind = kind.lower()
        # If we're the first column, make the record.
        if constraint not in constraints:
            constraints[constraint] = {
                "columns": [],
                "primary_key": kind == "primary key",
                "unique": kind in ("primary key", "unique"),
                # used_cols holds "table.column" strings for the referenced side.
                "foreign_key": tuple(used_cols[0].split(".", 1)) if kind == "foreign key" else None,
                "check": False,
                "index": False,
            }
        add_column(constraint, column)

    # Now get CHECK constraint columns.
    cursor.execute("""
SELECT kc.constraint_name, kc.column_name
FROM information_schema.constraint_column_usage AS kc
JOIN information_schema.table_constraints AS c ON
kc.table_schema = c.table_schema AND
kc.table_name = c.table_name AND
kc.constraint_name = c.constraint_name
WHERE
c.constraint_type = 'CHECK' AND
kc.table_schema IN (%s, %s) AND
kc.table_name = %s
""", schemata + [table_name])
    for constraint, column in cursor.fetchall():
        # If we're the first column, make the record.
        if constraint not in constraints:
            constraints[constraint] = {
                "columns": [],
                "primary_key": False,
                "unique": False,
                "foreign_key": None,
                "check": True,
                "index": False,
            }
        add_column(constraint, column)

    # Now get indexes. Constraints already found above keep their record;
    # only indexes not backing one of them are added.
    cursor.execute("""
SELECT
c2.relname,
ARRAY(
SELECT (SELECT attname FROM pg_catalog.pg_attribute WHERE attnum = i AND attrelid = c.oid)
FROM unnest(idx.indkey) i
),
idx.indisunique,
idx.indisprimary
FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
pg_catalog.pg_index idx, pg_catalog.pg_namespace n
WHERE c.oid = idx.indrelid
AND idx.indexrelid = c2.oid
AND n.oid = c.relnamespace
AND n.nspname IN (%s, %s)
AND c.relname = %s
""", schemata + [table_name])
    for index, columns, unique, primary in cursor.fetchall():
        if index not in constraints:
            constraints[index] = {
                "columns": list(columns),
                "primary_key": primary,
                "unique": unique,
                "foreign_key": None,
                "check": False,
                "index": True,
            }
    return constraints
def get_index_data(cursor, index_name):
    """
    Return the names of the tables that own an index called ``index_name``.

    Only the ``settings.PUBLIC_SCHEMA`` and "__template__" schemata are
    searched. One table name is returned per matching index row, so the list
    is empty when no such index exists in either schema.

    :param cursor: database cursor used to run the lookup query.
    :param index_name: unqualified name of the index to look for.
    """
    query = '''SELECT
c.relname AS table_name,
n.nspname AS schema_name
FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
pg_catalog.pg_index idx, pg_catalog.pg_namespace n
WHERE c.oid = idx.indrelid
AND idx.indexrelid = c2.oid
AND n.oid = c.relnamespace
AND n.nspname IN (%s, %s)
AND c2.relname = %s
'''
    cursor.execute(query, [settings.PUBLIC_SCHEMA, '__template__', index_name])
    # Each row is (table_name, schema_name); only the table name is wanted.
    return [row[0] for row in cursor.fetchall()]
def get_table_and_schema(sql, cursor):
    """
    Work out which table (and explicit schema, if any) a SQL statement targets.

    Only the first statement in ``sql`` is inspected. These are recognised:

    * CREATE / CREATE OR REPLACE / ALTER / DROP statements on a TABLE or VIEW
      (first identifier), a TRIGGER or INDEX (second identifier);
    * DROP INDEX, whose table is looked up in the database via
      ``get_index_data``;
    * INSERT / UPDATE / DELETE statements (first identifier), because they
      should apply to every known schema when run as part of a migration.

    :param sql: the SQL text about to be executed.
    :param cursor: database cursor, only used to resolve DROP INDEX.
    :returns: ``(table_name, schema_name)``. ``schema_name`` is ``None`` when
        the table is not schema-qualified. Both are ``None`` when the
        statement is not recognised or its table cannot be determined.
    """
    statements = sqlparse.parse(sql)
    if not statements:
        return None, None

    grouped = defaultdict(list)
    identifiers = []
    for token in statements[0].tokens:
        if token.ttype:
            # Normalise case and inner whitespace so that hand-written SQL
            # such as "create  table" matches like "CREATE TABLE".
            grouped[token.ttype].append(' '.join(token.value.upper().split()))
        elif token.get_name():
            identifiers.append(token)

    ddl = grouped[DDL]
    if ddl and ddl[0] in ('CREATE', 'DROP', 'ALTER', 'CREATE OR REPLACE'):
        # We may care about this.
        keywords = grouped[Keyword]
        # DROP INDEX does not have a table associated with it, so ask the
        # database which table(s) have an index with that name.
        if 'INDEX' in keywords and ddl[0] == 'DROP':
            if identifiers:
                tables = get_index_data(cursor, identifiers[0].get_name())
                if tables:
                    return tables[0], None
            return None, None
        if 'VIEW' in keywords or 'TABLE' in keywords:
            # We care about identifier 0.
            if identifiers:
                return identifiers[0].get_name(), identifiers[0].get_parent_name()
        elif 'TRIGGER' in keywords or 'INDEX' in keywords:
            # We care about identifier 1.
            if len(identifiers) > 1:
                return identifiers[1].get_name(), identifiers[1].get_parent_name()

    # We also care about plain DML statements, as the implication is that they
    # should apply to every known schema, if we are updating as part of a
    # migration. sqlparse emits INSERT/DELETE as the DML token and INTO/FROM as
    # separate keywords; the combined spellings are kept for compatibility.
    # DDL statements are excluded so that e.g. "CREATE RULE ... ON INSERT"
    # does not report the rule name as a table.
    dml = grouped[DML]
    if not ddl and dml and dml[0] in ('INSERT', 'INSERT INTO', 'UPDATE', 'DELETE', 'DELETE FROM'):
        if identifiers:
            return identifiers[0].get_name(), identifiers[0].get_parent_name()
    return None, None
class DatabaseSchemaEditor(schema.DatabaseSchemaEditor):
    """
    PostgreSQL schema editor that replays schema-specific SQL in every schema.

    SQL that targets a non-shared table without an explicit schema is run
    once per schema returned by ``get_schema_model().objects.all()`` and once
    in the template schema. Constraint lookups search the public and
    "__template__" schemata via ``get_constraints``.
    """

    def __exit__(self, exc_type, exc_value, traceback):
        # It seems that actions that add stuff to the deferred sql
        # will fire per-schema, so we can end up with multiples.
        # Reduce that to a unique list; a set would lose the ordering.
        unique_sql = []
        for sql in self.deferred_sql:
            if sql not in unique_sql:
                unique_sql.append(sql)
        self.deferred_sql = unique_sql
        return super(DatabaseSchemaEditor, self).__exit__(exc_type, exc_value, traceback)

    # If we manage to rewrite the SQL so it injects schema clauses, then we can remove this override.
    def execute(self, sql, params=None):
        """
        Execute ``sql``, once per schema when it targets a per-schema table.

        If ``get_table_and_schema`` finds a table that is not schema-qualified
        and not shared, the statement is run after activating each schema
        (when the schema table exists) and then the template schema. After
        that the active schema is deactivated. Otherwise it is run once.
        """
        execute = super(DatabaseSchemaEditor, self).execute
        # Close the introspection cursor as soon as the statement is analysed.
        with self.connection.cursor() as cursor:
            table_name, schema_name = get_table_and_schema(sql, cursor)
        # TODO: try to get the apps from current project_state, not global apps.
        if table_name and not schema_name and not is_shared_table(table_name):
            if _schema_table_exists():
                for each in get_schema_model().objects.all():
                    each.activate()
                    execute(sql, params)
            activate_template_schema()
            execute(sql, params)
            deactivate_schema()
        else:
            execute(sql, params)

    def _constraint_names(self, model, column_names=None, unique=None,
                          primary_key=None, index=None, foreign_key=None,
                          check=None):
        """
        Returns all constraint names matching the columns and conditions.

        Any condition left as ``None`` is not filtered on. ``foreign_key``
        only requires that the constraint references something.
        """
        column_names = list(column_names) if column_names else None
        with self.connection.cursor() as cursor:
            constraints = get_constraints(cursor, model._meta.db_table)
        result = []
        for name, infodict in constraints.items():
            if column_names is None or column_names == infodict['columns']:
                if unique is not None and infodict['unique'] != unique:
                    continue
                if primary_key is not None and infodict['primary_key'] != primary_key:
                    continue
                if index is not None and infodict['index'] != index:
                    continue
                if check is not None and infodict['check'] != check:
                    continue
                if foreign_key is not None and not infodict['foreign_key']:
                    continue
                result.append(name)
        return result