mirror of
https://github.com/graphql-python/graphene-django.git
synced 2024-11-25 02:54:06 +03:00
45c2aa09b5
* Allow a field's choices to be a callable. Starting in Django 5, a field's choices can also be a callable. * Test that a field with callable choices converts into an enum. --------- Co-authored-by: Kien Dang <mail@kien.ai>
514 lines
18 KiB
Python
514 lines
18 KiB
Python
import inspect
|
|
from collections.abc import Callable
|
|
from functools import partial, singledispatch, wraps
|
|
|
|
from django.db import models
|
|
from django.utils.encoding import force_str
|
|
from django.utils.functional import Promise
|
|
from django.utils.module_loading import import_string
|
|
from graphql import GraphQLError
|
|
|
|
from graphene import (
|
|
ID,
|
|
UUID,
|
|
Boolean,
|
|
Date,
|
|
DateTime,
|
|
Decimal,
|
|
Dynamic,
|
|
Enum,
|
|
Field,
|
|
Float,
|
|
Int,
|
|
List,
|
|
NonNull,
|
|
String,
|
|
Time,
|
|
)
|
|
from graphene.types.json import JSONString
|
|
from graphene.types.resolver import get_default_resolver
|
|
from graphene.types.scalars import BigInt
|
|
from graphene.utils.str_converters import to_camel_case
|
|
|
|
try:
|
|
from graphql import assert_name
|
|
except ImportError:
|
|
# Support for older versions of graphql
|
|
from graphql import assert_valid_name as assert_name
|
|
from graphql.pyutils import register_description
|
|
|
|
from .compat import ArrayField, HStoreField, RangeField
|
|
from .fields import DjangoConnectionField, DjangoListField
|
|
from .settings import graphene_settings
|
|
from .utils.str_converters import to_const
|
|
|
|
|
|
class BlankValueField(Field):
    """Graphene ``Field`` whose resolver reports an empty string as ``None``.

    ``convert_django_field_with_choices`` mounts enum-converted choice fields
    with this class, so a blank ``""`` value resolves to ``null``.
    """

    def wrap_resolve(self, parent_resolver):
        """Wrap the field's own resolver (or *parent_resolver*) to map ``""`` to ``None``."""
        base_resolver = self.resolver or parent_resolver

        @wraps(base_resolver)
        def blank_to_none_resolver(*args, **kwargs):
            result = base_resolver(*args, **kwargs)
            return None if result == "" else result

        return blank_to_none_resolver
|
|
|
|
|
|
def convert_choice_name(name):
    """Turn a choice value into a valid GraphQL enum member name.

    The value is coerced to text and normalised with ``to_const``. If the
    result is still rejected by ``assert_name`` (for example because it starts
    with a digit), it is prefixed with ``A_``.
    """
    const_name = to_const(force_str(name))
    try:
        assert_name(const_name)
    except GraphQLError:
        return "A_%s" % const_name
    return const_name
|
|
|
|
|
|
def get_choices(choices):
    """Yield ``(name, value, description)`` triples for every choice.

    ``choices`` may be a callable returning the choices, a dict mapping values
    to labels, or an iterable of ``(value, label)`` pairs. A label that is itself
    a tuple, list or dict is treated as a named group. Its members are
    flattened into the output, and the group label is not emitted.

    Enum member names are de-duplicated across the whole flattened choice set.
    Two values whose converted names collide therefore get distinct names even
    when they live in different groups.
    """
    yield from _iter_choices(choices, set())


def _iter_choices(choices, used_names):
    """Recursive worker for ``get_choices``; *used_names* is shared by all groups."""
    if isinstance(choices, Callable):
        choices = choices()

    # In restframework==3.15.0, choices are not passed
    # as OrderedDict anymore, so it's safer to check
    # for a dict
    if isinstance(choices, dict):
        choices = choices.items()

    for value, help_text in choices:
        if isinstance(help_text, (tuple, list, dict)):
            # Grouped choices: (group_label, [(value, label), ...]) or a nested mapping.
            yield from _iter_choices(help_text, used_names)
            continue

        name = convert_choice_name(value)
        # Append a numeric suffix until the name is unique across *all* groups,
        # not just the current one.
        while name in used_names:
            name += "_" + str(len(used_names))
        used_names.add(name)
        # TODO: translatable description: https://github.com/graphql-python/graphql-core-next/issues/58
        description = str(help_text)
        yield name, value, description
|
|
|
|
|
|
def convert_choices_to_named_enum_with_descriptions(name, choices):
    """Build a graphene ``Enum`` called *name* from Django/DRF-style *choices*.

    Each choice becomes an enum member. The member name comes from
    ``get_choices``, the member value is the original choice value, and the
    member's ``description`` is the choice label.
    """
    members = []
    descriptions = {}
    for member_name, member_value, member_description in get_choices(choices):
        members.append((member_name, member_value))
        descriptions[member_name] = member_description

    class EnumWithDescriptionsType:
        @property
        def description(self):
            return str(descriptions[self.name])

    return Enum(
        name,
        list(members),
        type=EnumWithDescriptionsType,
        description="An enumeration.",  # Temporary fix until https://github.com/graphql-python/graphene/pull/1502 is merged
    )
|
|
|
|
|
|
def generate_enum_name(django_model_meta, field):
    """Return the GraphQL enum type name for a model *field* that has choices.

    Resolution order:

    1. ``DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME`` set: the dotted path is imported
       and the resulting callable is called with *field*.
    2. ``DJANGO_CHOICE_FIELD_ENUM_V2_NAMING is True``: camel-cased
       ``<ObjectName>_<field_name>``.
    3. Otherwise: ``<AppLabel><ObjectName><FieldName>Choices``.
    """
    custom_name_path = graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME
    if custom_name_path:
        # Try and import custom function
        return import_string(custom_name_path)(field)

    if graphene_settings.DJANGO_CHOICE_FIELD_ENUM_V2_NAMING is True:
        return to_camel_case(f"{django_model_meta.object_name}_{field.name}")

    app_label = to_camel_case(django_model_meta.app_label.title())
    field_name = to_camel_case(field.name.title())
    return f"{app_label}{django_model_meta.object_name}{field_name}Choices"
|
|
|
|
|
|
def convert_choice_field_to_enum(field, name=None):
    """Convert a Django model *field* with choices into a graphene ``Enum``.

    When *name* is omitted, it is derived from the field's model via
    ``generate_enum_name``.
    """
    enum_name = generate_enum_name(field.model._meta, field) if name is None else name
    return convert_choices_to_named_enum_with_descriptions(enum_name, field.choices)
|
|
|
|
|
|
def convert_django_field_with_choices(
    field, registry=None, convert_choices_to_enum=None
):
    """Convert a Django *field*, turning fields with choices into a graphene Enum.

    :param field: the Django model field to convert.
    :param registry: optional registry. A conversion already stored for *field*
        is returned as-is; otherwise the new conversion is stored in it.
    :param convert_choices_to_enum: whether a field with ``choices`` becomes an
        Enum; ``None`` falls back to ``DJANGO_CHOICE_FIELD_ENUM_CONVERT``.
    :returns: the mounted graphene field or type for *field*.
    """
    if registry is not None:
        cached = registry.get_converted_field(field)
        if cached:
            return cached

    field_choices = getattr(field, "choices", None)
    if convert_choices_to_enum is None:
        convert_choices_to_enum = bool(
            graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CONVERT
        )

    if not (field_choices and convert_choices_to_enum):
        converted = convert_django_field(field, registry)
    else:
        enum_type = convert_choice_field_to_enum(field)
        # A field that allows blank or NULL values is exposed as non-required.
        is_required = not (field.blank or field.null)
        converted = enum_type(
            description=get_django_field_description(field), required=is_required
        ).mount_as(BlankValueField)

    if registry is not None:
        registry.register_converted_field(field, converted)
    return converted
|
|
|
|
|
|
def get_django_field_description(field):
    """Return ``field.help_text`` coerced to ``str``, or ``None`` when it is falsy."""
    help_text = field.help_text
    if not help_text:
        return None
    return str(help_text)
|
|
|
|
|
|
@singledispatch
def convert_django_field(field, registry=None):
    """Convert a Django model field into a graphene type.

    This is the ``functools.singledispatch`` fallback, used when no converter
    has been registered for ``type(field)``. Concrete converters attach
    themselves with ``@convert_django_field.register(...)``.

    :raises Exception: always, naming the unsupported field and its class.
    """
    field_class = field.__class__
    raise Exception(
        f"Don't know how to convert the Django field {field} ({field_class})"
    )
|
|
|
|
|
|
@convert_django_field.register(models.CharField)
@convert_django_field.register(models.TextField)
@convert_django_field.register(models.EmailField)
@convert_django_field.register(models.SlugField)
@convert_django_field.register(models.URLField)
@convert_django_field.register(models.GenericIPAddressField)
@convert_django_field.register(models.FileField)
@convert_django_field.register(models.FilePathField)
def convert_field_to_string(field, registry=None):
    """Map text-like fields (char, text, e-mail, slug, URL, IP, file, file path) to ``String``."""
    description = get_django_field_description(field)
    return String(description=description, required=not field.null)
|
|
|
|
|
|
@convert_django_field.register(models.BigAutoField)
@convert_django_field.register(models.AutoField)
def convert_field_to_id(field, registry=None):
    """Map auto-incrementing primary key fields to the GraphQL ``ID`` scalar."""
    kwargs = {
        "description": get_django_field_description(field),
        "required": not field.null,
    }
    return ID(**kwargs)
|
|
|
|
|
|
if hasattr(models, "SmallAutoField"):
    # Only register a converter when the installed Django exposes SmallAutoField.

    def convert_field_small_to_id(field, registry=None):
        """Map ``SmallAutoField`` primary keys to ``ID``, like the other auto fields."""
        return convert_field_to_id(field, registry)

    convert_django_field.register(models.SmallAutoField, convert_field_small_to_id)
|
|
|
|
|
|
@convert_django_field.register(models.UUIDField)
def convert_field_to_uuid(field, registry=None):
    """Map ``UUIDField`` to graphene's ``UUID`` scalar."""
    description = get_django_field_description(field)
    required = not field.null
    return UUID(description=description, required=required)
|
|
|
|
|
|
@convert_django_field.register(models.BigIntegerField)
def convert_big_int_field(field, registry=None):
    """Map ``BigIntegerField`` to graphene's ``BigInt`` scalar.

    The description goes through ``get_django_field_description``, as it does
    in every other scalar converter. An empty ``help_text`` therefore yields
    ``None`` instead of ``""``, and lazy help text is coerced to ``str``.
    """
    return BigInt(
        description=get_django_field_description(field), required=not field.null
    )
|
|
|
|
|
|
@convert_django_field.register(models.PositiveIntegerField)
@convert_django_field.register(models.PositiveSmallIntegerField)
@convert_django_field.register(models.SmallIntegerField)
@convert_django_field.register(models.IntegerField)
def convert_field_to_int(field, registry=None):
    """Map Django integer fields (plain, small and positive variants) to ``Int``."""
    description = get_django_field_description(field)
    is_required = not field.null
    return Int(description=description, required=is_required)
|
|
|
|
|
|
@convert_django_field.register(models.NullBooleanField)
@convert_django_field.register(models.BooleanField)
def convert_field_to_boolean(field, registry=None):
    """Map ``BooleanField`` and ``NullBooleanField`` to ``Boolean``."""
    kwargs = dict(
        description=get_django_field_description(field),
        required=not field.null,
    )
    return Boolean(**kwargs)
|
|
|
|
|
|
@convert_django_field.register(models.DecimalField)
def convert_field_to_decimal(field, registry=None):
    """Map ``DecimalField`` to graphene's ``Decimal`` scalar."""
    description = get_django_field_description(field)
    return Decimal(description=description, required=not field.null)
|
|
|
|
|
|
@convert_django_field.register(models.FloatField)
@convert_django_field.register(models.DurationField)
def convert_field_to_float(field, registry=None):
    """Map ``FloatField`` (and, as registered here, ``DurationField``) to ``Float``."""
    description = get_django_field_description(field)
    is_required = not field.null
    return Float(description=description, required=is_required)
|
|
|
|
|
|
@convert_django_field.register(models.DateTimeField)
def convert_datetime_to_string(field, registry=None):
    """Map ``DateTimeField`` to graphene's ``DateTime`` scalar."""
    kwargs = {
        "description": get_django_field_description(field),
        "required": not field.null,
    }
    return DateTime(**kwargs)
|
|
|
|
|
|
@convert_django_field.register(models.DateField)
def convert_date_to_string(field, registry=None):
    """Map ``DateField`` to graphene's ``Date`` scalar."""
    description = get_django_field_description(field)
    return Date(description=description, required=not field.null)
|
|
|
|
|
|
@convert_django_field.register(models.TimeField)
def convert_time_to_string(field, registry=None):
    """Map ``TimeField`` to graphene's ``Time`` scalar."""
    description = get_django_field_description(field)
    is_required = not field.null
    return Time(description=description, required=is_required)
|
|
|
|
|
|
@convert_django_field.register(models.OneToOneRel)
def convert_onetoone_field_to_djangomodel(field, registry=None):
    """Convert the reverse side of a one-to-one relation (``OneToOneRel``).

    Returns a ``Dynamic`` whose target ``DjangoObjectType`` is looked up in
    *registry* lazily. ``dynamic_type`` returns ``None`` when no type is
    registered for the related model.

    When the target type overrides ``get_queryset``, the related object is
    fetched through that override instead of the plain attribute access.
    """
    from .types import DjangoObjectType

    model = field.related_model

    def dynamic_type():
        _type = registry.get_type_for_model(model)
        if not _type:
            return

        class CustomField(Field):
            def wrap_resolve(self, parent_resolver):
                """
                Implements a custom resolver which goes through the `get_queryset`
                method of the DjangoObjectType to fetch the related object.
                """
                resolver = super().wrap_resolve(parent_resolver)

                # If `get_queryset` was not overridden in the DjangoObjectType
                # or if we explicitly bypass the `get_queryset` method,
                # we can just return the default resolver.
                if (
                    _type.get_queryset.__func__
                    is DjangoObjectType.get_queryset.__func__
                    or getattr(resolver, "_bypass_get_queryset", False)
                ):
                    return resolver

                # `resolver` is fixed for this field, so check this once rather
                # than on every resolution.
                is_resolver_awaitable = inspect.iscoroutinefunction(resolver)

                # Name of the forward OneToOneField (on the related model) that
                # points back at `root`. It is taken from the relation this field
                # was built for. Deriving it from `info.field_name` via
                # to_snake_case breaks for names that do not round-trip through
                # camelCase (e.g. "profile_2" -> "profile2").
                reversed_field_name = field.remote_field.name

                def custom_resolver(root, info, **args):
                    # Note: this function is used to resolve 1:1 relation fields
                    if is_resolver_awaitable:
                        # In case the resolver is a custom awaitable resolver that overwrites
                        # the default Django resolver
                        return resolver(root, info, **args)

                    try:
                        return _type.get_queryset(
                            _type._meta.model.objects.filter(
                                **{reversed_field_name: root.pk}
                            ),
                            info,
                        ).get()
                    except _type._meta.model.DoesNotExist:
                        return None

                return custom_resolver

        return CustomField(
            _type,
            required=not field.null,
        )

    return Dynamic(dynamic_type)
|
|
|
|
|
|
@convert_django_field.register(models.ManyToManyField)
@convert_django_field.register(models.ManyToManyRel)
@convert_django_field.register(models.ManyToOneRel)
def convert_field_to_list_or_connection(field, registry=None):
    """Convert a to-many relation into a list or connection field.

    Returns a ``Dynamic`` resolved lazily against *registry*. It yields:

    - ``None`` if the related model has no registered type;
    - ``DjangoFilterConnectionField`` if the type uses connections and declares
      ``filter_fields`` or a ``filterset_class``;
    - ``DjangoConnectionField`` if the type uses connections otherwise;
    - ``DjangoListField`` in every other case.
    """
    related_model = field.related_model

    def dynamic_type():
        node_type = registry.get_type_for_model(related_model)
        if not node_type:
            return

        # Reverse relations (ManyToManyRel / ManyToOneRel) take their help text
        # from the forward field that declares them (`field.field`).
        source_field = (
            field if isinstance(field, models.ManyToManyField) else field.field
        )
        description = get_django_field_description(source_field)

        meta = node_type._meta
        if not meta.connection:
            return DjangoListField(
                node_type,
                required=True,  # A Set is always returned, never None.
                description=description,
            )

        if meta.filter_fields or meta.filterset_class:
            from .filter.fields import DjangoFilterConnectionField

            return DjangoFilterConnectionField(
                node_type, required=True, description=description
            )

        return DjangoConnectionField(
            node_type, required=True, description=description
        )

    return Dynamic(dynamic_type)
|
|
|
|
|
|
@convert_django_field.register(models.OneToOneField)
@convert_django_field.register(models.ForeignKey)
def convert_field_to_djangomodel(field, registry=None):
    """Convert a forward ``ForeignKey`` / ``OneToOneField`` into a graphene field.

    Returns a ``Dynamic`` whose target ``DjangoObjectType`` is looked up in
    *registry* lazily. ``dynamic_type`` returns ``None`` when no type is
    registered for the related model.

    When the target type overrides ``get_queryset``, the related object is
    loaded through ``DjangoObjectType.get_node`` so that the override is
    applied.
    """
    from .types import DjangoObjectType

    model = field.related_model

    def dynamic_type():
        _type = registry.get_type_for_model(model)
        if not _type:
            return

        class CustomField(Field):
            def wrap_resolve(self, parent_resolver):
                """
                Implements a custom resolver which goes through the `get_node` method to ensure that
                it goes through the `get_queryset` method of the DjangoObjectType.
                """
                resolver = super().wrap_resolve(parent_resolver)

                # If `get_queryset` was not overridden in the DjangoObjectType
                # or if we explicitly bypass the `get_queryset` method,
                # we can just return the default resolver.
                if (
                    _type.get_queryset.__func__
                    is DjangoObjectType.get_queryset.__func__
                    or getattr(resolver, "_bypass_get_queryset", False)
                ):
                    return resolver

                # `resolver` is fixed for this field, so check this once rather
                # than on every resolution.
                is_resolver_awaitable = inspect.iscoroutinefunction(resolver)

                # Use the Django field this resolver was built for. Deriving it
                # from `info.field_name` via to_snake_case breaks for model field
                # names that do not round-trip through camelCase
                # (e.g. "reporter_2" -> "reporter2").
                field_name = field.name
                # Attribute holding the related object's primary key on the root,
                # e.g. article.reporter_id
                db_field_key = field.attname

                def custom_resolver(root, info, **args):
                    # Note: this function is used to resolve FK or 1:1 fields
                    # it does not differentiate between custom-resolved fields
                    # and default resolved fields.

                    # because this is a django foreign key or one-to-one field, the primary-key for
                    # this node can be accessed from the root node.
                    if hasattr(root, db_field_key):
                        # get the object's primary-key from root
                        object_pk = getattr(root, db_field_key)
                    else:
                        return None

                    if is_resolver_awaitable:
                        # In case the resolver is a custom awaitable resolver that overwrites
                        # the default Django resolver
                        return resolver(root, info, **args)

                    instance_from_get_node = _type.get_node(info, object_pk)

                    if instance_from_get_node is None:
                        # no instance to return
                        return
                    elif (
                        isinstance(resolver, partial)
                        and resolver.func is get_default_resolver()
                    ):
                        return instance_from_get_node
                    elif resolver is not get_default_resolver():
                        # Default resolver is overridden
                        # For optimization, add the instance to the resolver
                        setattr(root, field_name, instance_from_get_node)
                        # Explanation:
                        # previously, _type.get_node` is called which results in at least one hit to the database.
                        # But, if we did not pass the instance to the root, calling the resolver will result in
                        # another call to get the instance which results in at least two database queries in total
                        # to resolve this node only.
                        # That's why the value of the object is set in the root so when the object is accessed
                        # in the resolver (root.field_name) it does not access the database unless queried explicitly.
                        return resolver(root, info, **args)
                    else:
                        return instance_from_get_node

                return custom_resolver

        return CustomField(
            _type,
            description=get_django_field_description(field),
            required=not field.null,
        )

    return Dynamic(dynamic_type)
|
|
|
|
|
|
@convert_django_field.register(ArrayField)
def convert_postgres_array_to_list(field, registry=None):
    """Convert a PostgreSQL ``ArrayField`` into a graphene ``List``.

    The element type comes from converting ``field.base_field``. A converted
    scalar element is unwrapped to its class and wrapped in ``NonNull`` when it
    was marked ``required``. An element that is already a ``List``/``NonNull``
    is used unchanged.
    """
    element = convert_django_field(field.base_field)
    if isinstance(element, (List, NonNull)):
        of_type = element
    elif element.kwargs["required"]:
        of_type = NonNull(type(element))
    else:
        of_type = type(element)

    return List(
        of_type,
        description=get_django_field_description(field),
        required=not field.null,
    )
|
|
|
|
|
|
@convert_django_field.register(HStoreField)
@convert_django_field.register(models.JSONField)
def convert_json_field_to_string(field, registry=None):
    """Map ``JSONField`` and PostgreSQL ``HStoreField`` to ``JSONString``."""
    description = get_django_field_description(field)
    is_required = not field.null
    return JSONString(description=description, required=is_required)
|
|
|
|
|
|
@convert_django_field.register(RangeField)
def convert_postgres_range_to_string(field, registry=None):
    """Convert a PostgreSQL range field into a graphene ``List`` of its bound type.

    Despite the function name, the result is a ``List``. Its element type
    comes from converting ``field.base_field``, and a scalar element is wrapped
    in ``NonNull`` when the converted base field was marked ``required``.
    """
    bound_type = convert_django_field(field.base_field)
    if not isinstance(bound_type, (List, NonNull)):
        needs_non_null = bound_type.kwargs["required"]
        scalar_cls = type(bound_type)
        bound_type = NonNull(scalar_cls) if needs_non_null else scalar_cls

    description = get_django_field_description(field)
    return List(bound_type, description=description, required=not field.null)
|
|
|
|
|
|
# Allow Django lazy()-wrapped values (such as lazily translated help_text) to
# be accepted by graphql-core as GraphQL descriptions.
# Background: https://github.com/graphql-python/graphql-core-next/issues/58.
register_description(Promise)
|