from django.db import connections

from .exception.formating import wrap_exception
from .sql.tracking import unwrap_cursor, wrap_cursor
from .types import DjangoDebug


class DjangoDebugContext:
    """State holder used by ``DjangoDebugMiddleware`` while a query executes.

    It owns the ``DjangoDebug`` payload (``self.object``) and switches cursor
    instrumentation on for every database connection when it is created.
    """

    def __init__(self):
        """Create an empty payload and start instrumenting connections."""
        self.results, self.debug_result = [], None
        self.object = DjangoDebug(sql=[], exceptions=[])
        self.enable_instrumentation()

    def get_debug_result(self):
        """Return the debug payload, draining any queued results first.

        When no batch is currently held in ``debug_result``, the queued
        ``results`` become the current batch and the queue is reset.
        """
        if not self.debug_result:
            # Promote the queue to the current batch and start a new queue.
            self.debug_result, self.results = self.results, []
        return self.on_resolve_all_results()

    def on_resolve_error(self, value):
        """Record ``value`` in the payload's exceptions and return it unchanged."""
        if not hasattr(self, "object"):
            return value
        collected = self.object.exceptions
        collected.append(wrap_exception(value))
        return value

    def on_resolve_all_results(self):
        """Finish collection, or start another pass if results were queued.

        With an empty queue, instrumentation is switched off and the payload
        is returned. Otherwise the current batch is dropped and
        ``get_debug_result`` runs again to pick up the queued results.
        """
        if not self.results:
            self.disable_instrumentation()
            return self.object
        self.debug_result = None
        return self.get_debug_result()

    def add_result(self, result):
        """Queue ``result``, but only while ``debug_result`` is truthy."""
        if not self.debug_result:
            return
        self.results.append(result)

    def enable_instrumentation(self):
        """Wrap the cursor of every configured database connection."""
        # This is thread-safe because database connections are thread-local.
        for conn in connections.all():
            wrap_cursor(conn, self)

    def disable_instrumentation(self):
        """Unwrap the cursor of every configured database connection."""
        for conn in connections.all():
            unwrap_cursor(conn)


class DjangoDebugMiddleware:
    """Graphene middleware that collects debug data while fields resolve.

    On the first resolved field it stores a ``DjangoDebugContext`` on
    ``info.context`` (as ``django_debug``) and afterwards routes every
    resolver result or error through that object.
    """

    def resolve(self, next, root, info, **args):
        """Resolve one field and record its outcome on the debug context.

        Args:
            next: The next resolver in the chain; called as
                ``next(root, info, **args)``.
            root: The parent value, passed through to ``next`` unchanged.
            info: The resolve info; ``context``, ``schema`` and
                ``return_type`` are read from it.
            **args: Field arguments, passed through to ``next``.

        Returns:
            The result of ``get_debug_result()`` when the field's return type
            is the schema's ``DjangoDebug`` type; otherwise the value produced
            by ``next``. If ``next`` raises, the exception is handed to
            ``on_resolve_error`` and that call's return value is returned
            instead of re-raising.

        Raises:
            Exception: If ``info.context`` is ``None``, or if the debug
                context cannot be stored on it.
        """
        context = info.context
        if not getattr(context, "django_debug", None):
            if context is None:
                raise Exception("DjangoDebug cannot be executed in None contexts")
            # Build the debug context outside the ``try`` so that a failure
            # while instrumenting connections is not misreported as a
            # non-writable context.
            debug_context = DjangoDebugContext()
            try:
                context.django_debug = debug_context
            except Exception as err:
                # The constructor already enabled instrumentation; undo it,
                # because this debug context is about to be discarded.
                debug_context.disable_instrumentation()
                raise Exception(
                    "DjangoDebug need the context to be writable, context received: {}.".format(
                        context.__class__.__name__
                    )
                ) from err
        if info.schema.get_type("DjangoDebug") == info.return_type:
            # This field is the debug field itself: hand back the payload.
            return context.django_debug.get_debug_result()
        try:
            result = next(root, info, **args)
        except Exception as e:
            # NOTE(review): the error is returned, not re-raised; presumably
            # the executor treats a returned exception as a field error.
            return context.django_debug.on_resolve_error(e)
        context.django_debug.add_result(result)
        return result