mirror of
https://github.com/graphql-python/graphene-django.git
synced 2024-11-22 17:47:12 +03:00
e8f36b018d
* Fix test Client headers for Django 4.2 * Lazy import pkg_resources since it could be quite heavy * Remove use of pkg_resources altogether
154 lines
4.3 KiB
Python
154 lines
4.3 KiB
Python
import inspect
|
|
|
|
import django
|
|
from django.db import connection, models, transaction
|
|
from django.db.models.manager import Manager
|
|
from django.utils.encoding import force_str
|
|
from django.utils.functional import Promise
|
|
|
|
from graphene.utils.str_converters import to_camel_case
|
|
|
|
# Feature flag: True when the optional ``django_filters`` package can be
# imported, False otherwise.
try:
    import django_filters  # noqa
except ImportError:
    DJANGO_FILTER_INSTALLED = False
else:
    DJANGO_FILTER_INSTALLED = True
|
|
|
|
|
|
def isiterable(value):
    """Return True if ``iter(value)`` succeeds, False if it raises TypeError.

    Strings count as iterable here; callers that want to treat text as a
    scalar must check for it separately.
    """
    try:
        iter(value)
    except TypeError:
        return False
    else:
        return True
|
|
|
|
|
|
def _camelize_django_str(s):
    """Camel-case ``s`` when it is text; return any other value unchanged.

    Lazy translation objects (``Promise``) are first forced to ``str`` so
    they can be camel-cased like ordinary strings.
    """
    if isinstance(s, Promise):
        s = force_str(s)
    if not isinstance(s, str):
        return s
    return to_camel_case(s)
|
|
|
|
|
|
def camelize(data):
    """Recursively camel-case the keys of ``data``.

    - dict: returns a new dict whose keys go through ``_camelize_django_str``
      and whose values are camelized recursively.
    - other iterables that are not ``str``/``Promise``: returns a list of the
      recursively camelized items.
    - anything else (including strings): returned unchanged.
    """
    if isinstance(data, dict):
        camelized = {}
        for key, value in data.items():
            new_key = _camelize_django_str(key)
            camelized[new_key] = camelize(value)
        return camelized

    # Text is iterable but must be treated as a scalar value.
    if not isiterable(data) or isinstance(data, (str, Promise)):
        return data

    return [camelize(item) for item in data]
|
|
|
|
|
|
def _get_model_ancestry(model):
    """Return ``[model]`` followed by its direct bases that are Django models.

    Only ``model.__bases__`` is inspected (one level, not the full MRO). A
    base must be a ``models.Model`` subclass and carry a truthy ``_meta``;
    the ``_meta`` check keeps out classes such as ``models.Model`` itself.
    """
    django_parents = [
        base
        for base in model.__bases__
        if is_valid_django_model(base) and getattr(base, "_meta", False)
    ]
    return [model] + django_parents
|
|
|
|
|
|
def get_reverse_fields(model, local_field_names):
    """
    Yield the reverse relationships declared on ``model`` and its direct Django
    model bases.

    Class attributes (from each class's own ``__dict__``) whose names appear in
    ``local_field_names`` are skipped, so local fields win over reverse
    relations of the same name.

    Yields ``(attribute_name, relation)`` tuples, where ``relation`` is a
    ``models.ManyToOneRel`` instance, or a non-symmetrical
    ``models.ManyToManyRel`` instance.
    """
    for klass in _get_model_ancestry(model):
        for attr_name, descriptor in klass.__dict__.items():
            if attr_name in local_field_names:
                continue

            # The relation lives on "rel" for FK and M2M relations and on
            # "related" for O2O relations.
            relation = getattr(descriptor, "rel", None) or getattr(
                descriptor, "related", None
            )

            if isinstance(relation, models.ManyToOneRel) or (
                isinstance(relation, models.ManyToManyRel)
                and not relation.symmetrical
            ):
                yield (attr_name, relation)
|
|
|
|
|
|
def get_local_fields(model):
    """
    Collect the forward fields of ``model`` and of its direct Django model bases.

    The classes from ``_get_model_ancestry`` are visited in order, with
    ``model`` first. For each class, ``_meta.fields`` and
    ``_meta.local_many_to_many`` are combined and sorted. The first field seen
    for a given name is kept, so a field on ``model`` takes priority over a
    same-named field on one of its bases.

    Returns a list of ``(field.name, field)`` tuples in first-seen order.
    It is NOT a dict: callers concatenate the result with other lists.
    """
    fields_by_name = {}
    for klass in _get_model_ancestry(model):
        # Merge concrete fields with this class's own many-to-many fields
        # before sorting, so both kinds follow the fields' own ordering.
        candidates = sorted(
            [*klass._meta.fields, *klass._meta.local_many_to_many]
        )
        for field in candidates:
            # First occurrence wins: ``model`` is visited before its bases.
            if field.name not in fields_by_name:
                fields_by_name[field.name] = field

    return list(fields_by_name.items())
|
|
|
|
|
|
def maybe_queryset(value):
    """Return ``value.get_queryset()`` if ``value`` is a ``Manager``, else ``value`` unchanged."""
    return value.get_queryset() if isinstance(value, Manager) else value
|
|
|
|
|
|
def get_model_fields(model):
    """
    Gather every field and relationship on a Django model and its direct
    Django model bases.

    Local (forward) fields come first. Reverse relationships follow, except
    any whose name is already taken by a local field, so local definitions
    take priority.

    Returns a list of ``(name, field_or_relation)`` tuples.
    """
    local_fields = get_local_fields(model)
    taken_names = {entry[0] for entry in local_fields}
    reverse_fields = list(get_reverse_fields(model, taken_names))
    return local_fields + reverse_fields
|
|
|
|
|
|
def is_valid_django_model(model):
    """Return True only when ``model`` is a class that subclasses ``models.Model``."""
    # Guard first: issubclass() raises TypeError for non-class arguments.
    if not inspect.isclass(model):
        return False
    return issubclass(model, models.Model)
|
|
|
|
|
|
def import_single_dispatch():
    """Return a ``singledispatch`` decorator.

    ``functools.singledispatch`` is preferred. If it cannot be imported, the
    third-party ``singledispatch`` backport is tried. If neither is
    available, a generic ``Exception`` with installation guidance is raised.
    """
    dispatcher = None
    try:
        from functools import singledispatch as dispatcher
    except ImportError:
        try:
            from singledispatch import singledispatch as dispatcher
        except ImportError:
            pass

    if not dispatcher:
        raise Exception(
            "It seems your python version does not include "
            "functools.singledispatch. Please install the 'singledispatch' "
            "package. More information here: "
            "https://pypi.python.org/pypi/singledispatch"
        )

    return dispatcher
|
|
|
|
|
|
def set_rollback():
    """Flag the current transaction for rollback under ATOMIC_REQUESTS.

    Calls ``transaction.set_rollback(True)`` only when the
    ``django.db.connection`` settings enable ``ATOMIC_REQUESTS`` and that
    connection is currently inside an atomic block; otherwise does nothing.
    """
    should_roll_back = (
        connection.settings_dict.get("ATOMIC_REQUESTS", False)
        and connection.in_atomic_block
    )
    if should_roll_back:
        transaction.set_rollback(True)
|
|
|
|
|
|
def bypass_get_queryset(resolver):
    """
    Decorator that marks ``resolver`` so the DjangoObjectType's custom
    ``get_queryset`` is bypassed.

    It sets ``resolver._bypass_get_queryset = True`` and returns the same
    object. No wrapper is created.
    """
    setattr(resolver, "_bypass_get_queryset", True)
    return resolver
|
|
|
|
|
|
# True when the installed Django is version 4.2 or newer.
# Compares the (major, minor) prefix of django.VERSION lexicographically.
_DJANGO_VERSION_AT_LEAST_4_2 = tuple(django.VERSION[:2]) >= (4, 2)
|