Mirror of https://github.com/graphql-python/graphene-django.git (synced 2025-11-04 09:57:53 +03:00)
			
		
		
		
	
		
			
				
	
	
		
			170 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			170 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import inspect
 | 
						|
from asyncio import get_running_loop
 | 
						|
 | 
						|
import django
 | 
						|
from django.db import connection, models, transaction
 | 
						|
from django.db.models.manager import Manager
 | 
						|
from django.utils.encoding import force_str
 | 
						|
from django.utils.functional import Promise
 | 
						|
 | 
						|
from graphene.utils.str_converters import to_camel_case
 | 
						|
 | 
						|
# Feature flag: True only when the optional ``django_filters`` package can be
# imported in this environment.
try:
    import django_filters  # noqa
except ImportError:
    DJANGO_FILTER_INSTALLED = False
else:
    DJANGO_FILTER_INSTALLED = True
 | 
						|
 | 
						|
 | 
						|
def isiterable(value):
    """
    Report whether ``iter()`` accepts ``value``.

    Returns False when ``iter(value)`` raises ``TypeError`` and True otherwise.
    Strings count as iterable.
    """
    can_iterate = True
    try:
        iter(value)
    except TypeError:
        can_iterate = False
    return can_iterate
 | 
						|
 | 
						|
 | 
						|
def _camelize_django_str(s):
    """
    Camel-case a single key, resolving Django lazy strings first.

    A ``Promise`` (lazy translation object) is forced to ``str`` via
    ``force_str``. Any resulting ``str`` is passed through ``to_camel_case``.
    Every other value is returned unchanged.
    """
    text = force_str(s) if isinstance(s, Promise) else s
    if not isinstance(text, str):
        return text
    return to_camel_case(text)
 | 
						|
 | 
						|
 | 
						|
def camelize(data):
    """
    Recursively camel-case the keys of every dict found in ``data``.

    - Dicts are rebuilt with keys passed through ``_camelize_django_str`` and
      values camelized recursively.
    - Other iterables are turned into a list of camelized items. Plain
      strings and lazy strings (``Promise``) are excluded from this rule.
    - Anything else is returned unchanged.
    """
    if isinstance(data, dict):
        camelized = {}
        for key, value in data.items():
            camelized[_camelize_django_str(key)] = camelize(value)
        return camelized
    # Strings are iterable but must be treated as scalars here.
    if not isiterable(data) or isinstance(data, (str, Promise)):
        return data
    return list(map(camelize, data))
 | 
						|
 | 
						|
 | 
						|
def _get_model_ancestry(model):
    """
    Return ``model`` followed by its direct bases that are Django models.

    Only ``model.__bases__`` is inspected, so grandparents are not included.
    A base qualifies when ``is_valid_django_model`` accepts it and it has a
    truthy ``_meta`` attribute.
    """
    django_bases = [
        parent
        for parent in model.__bases__
        if is_valid_django_model(parent) and getattr(parent, "_meta", False)
    ]
    return [model, *django_bases]
 | 
						|
 | 
						|
 | 
						|
def get_reverse_fields(model, local_field_names):
    """
    Search the model's ancestry for reverse relationships pointing at the models.

    Scans the class ``__dict__`` of ``model`` and of each of its direct
    Django-model bases. Attribute names listed in ``local_field_names`` are
    skipped so local fields are not duplicated.

    Yields ``(attribute_name, rel)`` tuples. ``rel`` is either a
    ``models.ManyToOneRel`` or a non-symmetrical ``models.ManyToManyRel``.
    """
    for ancestor in _get_model_ancestry(model):
        for attr_name, descriptor in ancestor.__dict__.items():
            if attr_name in local_field_names:
                continue

            # FK and M2M descriptors expose the relation as "rel";
            # one-to-one descriptors expose it as "related".
            rel = getattr(descriptor, "rel", None) or getattr(
                descriptor, "related", None
            )
            is_reverse_fk = isinstance(rel, models.ManyToOneRel)
            is_asymmetric_m2m = (
                isinstance(rel, models.ManyToManyRel) and not rel.symmetrical
            )
            if is_reverse_fk or is_asymmetric_m2m:
                yield (attr_name, rel)
 | 
						|
 | 
						|
 | 
						|
def get_local_fields(model):
    """
    Collect the fields declared on the model and its direct Django-model bases.

    For each model in the ancestry (``model`` first, then its direct
    Django-model bases), ``_meta.fields`` and ``_meta.local_many_to_many``
    are merged and sorted using the fields' own ordering. When more than one
    model exposes a field with the same name, the first one seen wins, so
    ``model``'s own fields take precedence.

    Returns a list of ``(field.name, field)`` tuples in first-seen order.
    """
    local_fields_dict = {}
    for _model in _get_model_ancestry(model):
        model_fields = [*_model._meta.fields, *_model._meta.local_many_to_many]
        for field in sorted(model_fields):
            # Keep only the first field seen for each name.
            local_fields_dict.setdefault(field.name, field)

    return list(local_fields_dict.items())
 | 
						|
 | 
						|
 | 
						|
def maybe_queryset(value):
    """
    Return ``value.get_queryset()`` if ``value`` is a ``Manager``.

    Any other value is returned unchanged.
    """
    return value.get_queryset() if isinstance(value, Manager) else value
 | 
						|
 | 
						|
 | 
						|
def get_model_fields(model):
    """
    Gets all the fields and relationships on the Django model and its ancestry.

    The result lists the local fields from ``get_local_fields`` first,
    followed by the reverse relationships from ``get_reverse_fields``. A
    reverse relationship is skipped when a local field already uses its
    name, so local fields and relationships take priority.

    Returns a list of ``(name, field)`` tuples.
    """
    local_fields = get_local_fields(model)
    local_field_names = {name for name, _field in local_fields}
    reverse_fields = get_reverse_fields(model, local_field_names)
    return local_fields + list(reverse_fields)
 | 
						|
 | 
						|
 | 
						|
def is_valid_django_model(model):
    """
    Return True if ``model`` is a class deriving from ``models.Model``.

    Returns False for non-class values without calling ``issubclass``.
    """
    if not inspect.isclass(model):
        return False
    return issubclass(model, models.Model)
 | 
						|
 | 
						|
 | 
						|
def import_single_dispatch():
    """
    Return a ``singledispatch`` decorator.

    ``functools.singledispatch`` is preferred. If it cannot be imported, the
    standalone ``singledispatch`` package is tried. If neither is available,
    a generic ``Exception`` is raised that points at the backport package.
    """
    dispatcher = None
    try:
        from functools import singledispatch as dispatcher
    except ImportError:
        try:
            from singledispatch import singledispatch as dispatcher
        except ImportError:
            pass

    if dispatcher is None:
        raise Exception(
            "It seems your python version does not include "
            "functools.singledispatch. Please install the 'singledispatch' "
            "package. More information here: "
            "https://pypi.python.org/pypi/singledispatch"
        )

    return dispatcher
 | 
						|
 | 
						|
 | 
						|
def set_rollback():
    """
    Mark the current transaction for rollback.

    This only happens when ``ATOMIC_REQUESTS`` is truthy in
    ``connection.settings_dict`` and the connection is currently inside an
    atomic block. Otherwise the function does nothing.
    """
    if not connection.settings_dict.get("ATOMIC_REQUESTS", False):
        return
    if connection.in_atomic_block:
        transaction.set_rollback(True)
 | 
						|
 | 
						|
 | 
						|
def is_running_async():
    """
    Return True if an asyncio event loop is running in the current thread.

    ``get_running_loop`` raises ``RuntimeError`` when no loop is running, and
    that case maps to False.
    """
    try:
        loop = get_running_loop()
    except RuntimeError:
        loop = None
    return loop is not None
 | 
						|
 | 
						|
 | 
						|
def is_sync_function(func):
    """
    Return True unless ``func`` is a coroutine function or an async generator function.
    """
    is_async = inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func)
    return not is_async
 | 
						|
 | 
						|
 | 
						|
def bypass_get_queryset(resolver):
    """
    Decorator that sets ``_bypass_get_queryset = True`` on ``resolver``.

    The flag tells the resolver to skip any custom ``get_queryset`` method
    of the DjangoObjectType. The same resolver object is returned.
    """
    setattr(resolver, "_bypass_get_queryset", True)
    return resolver
 | 
						|
 | 
						|
 | 
						|
# True when the installed Django's (major, minor) version is 4.2 or newer.
_DJANGO_VERSION_AT_LEAST_4_2 = tuple(django.VERSION[:2]) >= (4, 2)
 |