# Source code for boardinghouse.schema

import logging
import inspect
import threading

from django.apps import apps
from django.conf import settings
from django.core.cache import cache
from django.db import connection
from django.db.migrations.operations.base import Operation
from django.utils.translation import lazy

# Per-thread storage; ``_thread_locals.schema`` caches the name of the
# schema most recently activated (or detected) on this thread.
_thread_locals = threading.local()

# Module logger. A NullHandler is attached so that importing this module
# never causes "no handler" warnings when the application has not
# configured logging.
LOGGER = logging.getLogger(__name__)
LOGGER.addHandler(logging.NullHandler())


class Forbidden(Exception):
    """
    Raised when code tries to activate a schema that is not a valid one.
    """
class TemplateSchemaActivation(Forbidden):
    """
    Raised when someone tries to activate the ``__template__`` schema.

    The message is fixed; any extra positional/keyword arguments are
    forwarded to the base exception after it.
    """

    def __init__(self, *args, **kwargs):
        message = 'Activating template schema forbidden.'
        super(TemplateSchemaActivation, self).__init__(message, *args, **kwargs)
def get_schema_model():
    """
    Look up and return the model class named by
    ``settings.BOARDINGHOUSE_SCHEMA_MODEL``.
    """
    model_label = settings.BOARDINGHOUSE_SCHEMA_MODEL
    return apps.get_model(model_label)
def _get_search_path():
    """
    Return the connection's current schema, split on commas, as a list.

    ``current_schema()`` reports only the first existing schema on the
    search path, and can be NULL when no schema on the path exists; NULL is
    normalised to ``''`` so callers that take ``[0]`` get a falsy name
    instead of an AttributeError.
    """
    cursor = connection.cursor()
    try:
        cursor.execute('SELECT current_schema()')
        current = cursor.fetchone()[0]
    finally:
        # Always release the cursor, even if the query fails.
        cursor.close()
    return (current or '').split(',')


def _set_search_path(search_path):
    """
    Set the connection's search path to ``search_path`` followed by
    ``settings.PUBLIC_SCHEMA``.

    The schema name is passed as a query parameter; the public schema name
    comes from settings and is formatted directly into the statement.
    """
    cursor = connection.cursor()
    try:
        cursor.execute(
            'SET search_path TO %s,{}'.format(settings.PUBLIC_SCHEMA),
            [search_path],
        )
    finally:
        cursor.close()


def _schema_exists(schema_name, cursor=None):
    """
    Return True if a schema called ``schema_name`` exists in the database.

    Uses ``cursor`` when given (leaving it open for the caller); otherwise
    opens a cursor of its own and closes it afterwards.
    """
    if cursor is not None:
        cursor.execute(
            '''SELECT schema_name FROM information_schema.schemata WHERE schema_name = %s''',
            [schema_name],
        )
        return bool(cursor.fetchone())

    own_cursor = connection.cursor()
    try:
        return _schema_exists(schema_name, own_cursor)
    finally:
        own_cursor.close()
def get_active_schema_name():
    """
    Return the name of the currently active schema, or ``None``.

    The name is cached per-thread on ``_thread_locals.schema``. Only when
    nothing truthy is cached is the database asked for its current schema;
    that reported name is kept only if it matches an active schema, and the
    outcome (possibly ``None``) is written back to the per-thread cache.
    """
    cached_name = getattr(_thread_locals, 'schema', None)
    if cached_name:
        return cached_name

    reported_schema = _get_search_path()[0]
    resolved_name = reported_schema if _get_schema(reported_schema) else None
    _thread_locals.schema = resolved_name
    return resolved_name
def get_active_schema():
    """
    Return the schema object for the currently active schema.

    Resolves the name from :func:`get_active_schema_name` against the
    (cached) list of active schemata. Returns ``None`` when there is no
    active schema name or no active schema matches it.
    """
    return _get_schema(get_active_schema_name())
def get_active_schemata():
    """
    Return the active schemata, reading through the ``'active-schemata'``
    cache entry and populating it from the schema model's ``active()``
    manager method on a miss.
    """
    cached = cache.get('active-schemata')
    if cached is not None:
        return cached

    fresh = get_schema_model().objects.active()
    cache.set('active-schemata', fresh)
    return fresh
def get_all_schemata():
    """
    Return every schema, reading through the ``'all-schemata'`` cache entry
    and populating it from the schema model's default manager on a miss.
    """
    cached = cache.get('all-schemata')
    if cached is not None:
        return cached

    fresh = get_schema_model().objects.all()
    cache.set('all-schemata', fresh)
    return fresh
def _get_schema(schema_name):
    """
    Find the active schema matching ``schema_name``.

    A schema matches when ``schema_name`` equals its ``.schema`` attribute or
    compares equal to the schema object itself. Returns the first match, or
    ``None`` for a falsy name or when nothing matches.
    """
    if not schema_name:
        return None
    return next(
        (
            candidate
            for candidate in get_active_schemata()
            if schema_name == candidate.schema or schema_name == candidate
        ),
        None,
    )
def activate_schema(schema_name):
    """
    Activate the schema whose internal name is ``schema_name``.

    This runs, on the database connection, something like:

    .. code:: sql

        SET search_path TO "foo",public;

    ``schema_pre_activate`` is sent before the search path changes and
    ``schema_post_activate`` after; the name is then cached per-thread.

    :raises TemplateSchemaActivation: if asked to activate ``__template__``.
    """
    from .signals import schema_pre_activate, schema_post_activate

    if schema_name == '__template__':
        raise TemplateSchemaActivation()

    signal_kwargs = {'sender': None, 'schema_name': schema_name}
    schema_pre_activate.send(**signal_kwargs)
    _set_search_path(schema_name)
    schema_post_activate.send(**signal_kwargs)

    _thread_locals.schema = schema_name
def activate_template_schema():
    """
    Activate the ``__template__`` schema.

    You probably don't want to do this; sometimes you do (for instance, to
    apply migrations). The per-thread cached schema name is cleared first,
    then ``schema_pre_activate`` is sent, the search path is set, and
    ``schema_post_activate`` is sent, all with ``schema_name='__template__'``.
    """
    from .signals import schema_pre_activate, schema_post_activate

    _thread_locals.schema = None

    template_name = '__template__'
    schema_pre_activate.send(sender=None, schema_name=template_name)
    _set_search_path(template_name)
    schema_post_activate.send(sender=None, schema_name=template_name)
def get_template_schema():
    """
    Build an instance of the schema model for ``__template__``.

    ``'__template__'`` is passed as the first positional field value; the
    instance is constructed in memory, not fetched from the database.
    """
    schema_model = get_schema_model()
    return schema_model('__template__')
def deactivate_schema(schema=None):
    """
    Deactivate the provided (or current) schema.

    Resets the search path to ``"$user",<settings.PUBLIC_SCHEMA>``, sending
    ``schema_pre_activate`` before and ``schema_post_activate`` after (both
    with ``schema_name=None``), and clears the per-thread cached name.

    :param schema: accepted for API compatibility; not used.
    """
    from .signals import schema_pre_activate, schema_post_activate

    cursor = connection.cursor()
    try:
        schema_pre_activate.send(sender=None, schema_name=None)
        cursor.execute('SET search_path TO "$user",{}'.format(settings.PUBLIC_SCHEMA))
        schema_post_activate.send(sender=None, schema_name=None)
        _thread_locals.schema = None
    finally:
        # Release the cursor even when a signal handler or the SET fails.
        cursor.close()
#: These models are required to be shared by the system.
#: Entries are lower-case ``app_label.model_name`` labels, matching the
#: lower-cased labels that ``is_shared_model`` compares against.
REQUIRED_SHARED_MODELS = [
    'auth.user',
    'auth.permission',
    'auth.group',
    'boardinghouse.schema',
    'sites.site',
    'sessions.session',
    'contenttypes.contenttype',
    'admin.logentry',
    'migrations.migration',
    # Configurable models. ``lazy(func, str)`` returns a factory; calling it
    # yields a lazy string proxy, so settings are only read when the entry is
    # compared. (Without the resultclass and the call, the list held plain
    # functions, which never compare equal to a label.) Values are
    # lower-cased because settings such as 'auth.User' use mixed case.
    lazy(lambda: settings.BOARDINGHOUSE_SCHEMA_MODEL.lower(), str)(),
    lazy(lambda: settings.AUTH_USER_MODEL.lower(), str)(),
]

#: Database tables that are always shared, regardless of any model lookup.
REQUIRED_SHARED_TABLES = [
    'django_migrations',
]


def _is_join_model(model):
    """
    Return True if ``model`` looks like a join model.

    We define a model as a join model if it has more than one field and
    every field is either its primary key or a related field (truthy
    ``.rel``).

    I am not 100% happy with this definition.
    """
    fields = model._meta.fields
    return len(fields) > 1 and all(
        field.primary_key or field.rel for field in fields
    )
def is_shared_model(model, _seen=None):
    """
    Is the model (or instance of a model) one that should be in the
    public/shared schema?

    Checked in order:

    1. ``model._is_shared_model`` is truthy -> shared.
    2. The lower-cased ``app_label.model_name`` is in
       ``REQUIRED_SHARED_MODELS`` or ``settings.SHARED_MODELS`` -> shared.
    3. The label is in ``settings.PRIVATE_MODELS`` -> not shared (lets a
       join table be private).
    4. A join model (see ``_is_join_model``) is shared only if every model
       it relates to is shared.
    5. Otherwise -> not shared.

    Settings labels are compared case-insensitively.

    :param _seen: internal; labels of join models already being evaluated
        on the current recursion path, used to stop cycles (e.g. a model
        whose only relation is to itself).
    """
    if model._is_shared_model:
        return True

    app_model = '{m.app_label}.{m.model_name}'.format(m=model._meta).lower()

    if app_model in REQUIRED_SHARED_MODELS:
        return True

    # app_model is lower-cased, so lower-case the configured labels as well.
    if app_model in {label.lower() for label in settings.SHARED_MODELS}:
        return True

    # Sometimes, we want a join table to be private.
    if app_model in {label.lower() for label in settings.PRIVATE_MODELS}:
        return False

    # If all fields are auto or fk, then we are a join model, and if all
    # related objects are shared, then we must also be shared, unless we
    # were explicitly marked as private above.
    if _is_join_model(model):
        if _seen and app_model in _seen:
            # Relations lead back to a model already on this path: that is
            # no evidence of sharing, and recursing further would not end.
            return False
        ancestors = frozenset(_seen or ()) | {app_model}
        return all(
            is_shared_model(field.rel.get_related_field().model, _seen=ancestors)
            for field in model._meta.fields
            if field.rel
        )

    return False
def is_shared_table(table, apps=apps):
    """
    Is the model from the provided database table name shared?

    Tables in ``REQUIRED_SHARED_TABLES`` are always shared. Otherwise the
    table name is mapped to a (non-proxy) model. The models come from the
    running migration operation's states when called from within one, and
    from ``apps.get_models()`` otherwise. The answer is delegated to
    ``is_shared_model``. A table that does not belong to a model directly
    may be an auto-created join table, named ``<model table>_<field name>``.
    """
    if table in REQUIRED_SHARED_TABLES:
        return True

    table_map = {
        model._meta.db_table: model
        for model in _models_for_table_lookup(apps)
        if not model._meta.proxy
    }

    # If we have a match, see if that one is shared.
    if table in table_map:
        return is_shared_model(table_map[table])

    # It may be a join table.
    return _is_shared_join_table(table, table_map)


def _models_for_table_lookup(apps):
    """Return the models of the enclosing migration operation, else ``apps.get_models()``."""
    # If we are in a migration operation, we need to look in that for
    # models: find a frame running an Operation's database_(forwards|backwards)
    # with its state arguments. context=0 skips reading source lines per frame.
    stack = inspect.stack(0)
    try:
        for frame_info in stack:
            if frame_info[3] not in ('database_forwards', 'database_backwards'):
                continue
            frame_locals = frame_info[0].f_locals
            if not all(
                name in frame_locals
                for name in ('from_state', 'to_state', 'schema_editor', 'self')
            ):
                continue
            if not isinstance(frame_locals['self'], Operation):
                continue
            # Use models from both states, as a table may exist in either.
            models = set()
            for state in (frame_locals['to_state'], frame_locals['from_state']):
                if state.apps:
                    models.update(state.apps.get_models())
            return models
    finally:
        # Break the reference cycle between this frame and the stack list.
        del stack
    return apps.get_models()


def _is_shared_join_table(table, table_map):
    """Return whether ``table`` is a join table between two shared models."""
    from django.core.exceptions import FieldDoesNotExist

    # Try the longest (most specific) matching model table first, so that
    # 'app_foo_bar_baz' prefers model table 'app_foo_bar' over 'app_foo'.
    candidates = sorted(
        (
            (db_table, model)
            for db_table, model in table_map.items()
            if table.startswith(db_table)
        ),
        key=lambda item: len(item[0]),
        reverse=True,
    )
    for db_table, model in candidates:
        # Strip only the leading prefix, not every occurrence of it.
        field_name = table[len(db_table):].lstrip('_')
        try:
            field = model._meta.get_field(field_name)
        except FieldDoesNotExist:
            continue
        rel = getattr(field, 'rel', None)
        if not rel:
            continue
        rel_model = rel.get_related_field().model
        return is_shared_model(model) and is_shared_model(rel_model)

    # No matching models found.
    # Assume this is not a shared table...
    return False
# Internal helper functions.
def _schema_table_exists():
    """
    Return True if a table named like the schema model's ``db_table`` exists.

    The lookup in ``information_schema.tables`` filters on table name only,
    so a table of that name in any schema counts.
    """
    table_name = get_schema_model()._meta.db_table
    cursor = connection.cursor()
    try:
        cursor.execute(
            "SELECT * FROM information_schema.tables WHERE table_name = %s",
            [table_name],
        )
        return bool(cursor.fetchone())
    finally:
        cursor.close()


def _wrap_command(command):
    """
    Wrap ``command`` so it runs with search_path ``<PUBLIC_SCHEMA>,__template__``.

    ``command`` is called as ``command(self, *args, **kwargs)``; its return
    value is passed through. The schema is always deactivated afterwards,
    even if the command raises.
    """
    import functools

    @functools.wraps(command)
    def inner(self, *args, **kwargs):
        cursor = connection.cursor()
        try:
            # In the case of create table statements, we want to make sure
            # they go to the public schema, but want reads to come from
            # __template__.
            cursor.execute('SET search_path TO {},__template__'.format(settings.PUBLIC_SCHEMA))
        finally:
            # SET applies to the connection, so the cursor can go now.
            cursor.close()
        try:
            return command(self, *args, **kwargs)
        finally:
            deactivate_schema()

    return inner