from django.db import connections

from promise import Promise

from .sql.tracking import unwrap_cursor, wrap_cursor
from .types import DjangoDebug


class DjangoDebugContext(object):
    """Per-execution holder for the debug payload and the promises it waits on.

    On construction every connection returned by ``connections.all()`` is
    passed to ``wrap_cursor`` together with this instance, and an empty
    ``DjangoDebug(sql=[])`` object is created as the payload.  The payload is
    handed out by :meth:`get_debug_promise` once all tracked promises have
    settled, at which point the connections are passed to ``unwrap_cursor``.
    """

    def __init__(self):
        # Aggregate promise for the batch currently being awaited, if any.
        self.debug_promise = None
        # Promises registered since the last aggregate was created.
        self.promises = []
        self.enable_instrumentation()
        # Payload returned to the debug field once everything has settled.
        self.object = DjangoDebug(sql=[])

    def get_debug_promise(self):
        """Return a promise that resolves to the debug payload.

        The first call (or the first call after a re-arm) aggregates the
        promises registered so far with ``Promise.all`` and starts a new,
        empty batch so that promises registered afterwards can be detected.
        """
        if not self.debug_promise:
            self.debug_promise = Promise.all(self.promises)
            # Rebind rather than clear: the aggregate above was built from
            # the previous list, and anything appended from now on belongs
            # to the next batch.
            self.promises = []
        return self.debug_promise.then(self.on_resolve_all_promises)

    def on_resolve_all_promises(self, values):
        """Finish collection, or wait again if more promises were registered.

        ``values`` (the results of the aggregated promises) is ignored.
        Returns the debug payload, or a new promise for it when a further
        batch of promises is pending.
        """
        if self.promises:
            # Promises were added while the previous batch was in flight;
            # aggregate them too before tearing instrumentation down.
            self.debug_promise = None
            return self.get_debug_promise()
        self.disable_instrumentation()
        return self.object

    def add_promise(self, promise):
        """Track ``promise`` if the debug payload has already been requested.

        Promises seen before the first :meth:`get_debug_promise` call are not
        tracked.  The aggregate's fulfilled state is deliberately not
        consulted: the first aggregate is built from an empty list, so it is
        expected to be fulfilled at once (NOTE(review): confirm against the
        promise library), and checking it would reject every later promise.
        """
        if self.debug_promise:
            self.promises.append(promise)

    def enable_instrumentation(self):
        """Pass every known connection to ``wrap_cursor`` with this context."""
        # This is thread-safe because database connections are thread-local.
        for connection in connections.all():
            wrap_cursor(connection, self)

    def disable_instrumentation(self):
        """Pass every known connection to ``unwrap_cursor``."""
        for connection in connections.all():
            unwrap_cursor(connection)


class DjangoDebugMiddleware(object):
    """Middleware that attaches a ``DjangoDebugContext`` to the execution context.

    Fields whose return type is the schema's ``DjangoDebug`` type receive the
    context's debug promise; every other field is resolved normally and its
    result is registered with the context.
    """

    def resolve(self, next, root, args, context, info):
        """Resolve one field, routing debug fields to the debug context.

        A ``DjangoDebugContext`` is created and stored on
        ``context.django_debug`` the first time it is needed.

        Returns:
            The debug promise when ``info.return_type`` equals the schema's
            ``'DjangoDebug'`` type, otherwise whatever ``next`` returns.

        Raises:
            Exception: if ``context`` is ``None`` or does not accept the
                ``django_debug`` attribute.
        """
        django_debug = getattr(context, 'django_debug', None)
        if not django_debug:
            if context is None:
                raise Exception('DjangoDebug cannot be executed in None contexts')
            # Built outside the ``try`` so that a failure while setting up
            # instrumentation surfaces as itself instead of being reported
            # as a non-writable context.
            django_debug = DjangoDebugContext()
            try:
                context.django_debug = django_debug
            except Exception:
                # The constructor already enabled instrumentation; undo it,
                # since this context object can never be reached again.
                django_debug.disable_instrumentation()
                raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format(
                    context.__class__.__name__
                ))
        if info.schema.get_type('DjangoDebug') == info.return_type:
            return context.django_debug.get_debug_promise()
        promise = next(root, args, context, info)
        context.django_debug.add_promise(promise)
        return promise