diff --git a/graphene/contrib/django/converter.py b/graphene/contrib/django/converter.py index 663cd11c..b253de15 100644 --- a/graphene/contrib/django/converter.py +++ b/graphene/contrib/django/converter.py @@ -9,7 +9,7 @@ from graphene.core.fields import ( FloatField, ListField ) -from graphene.contrib.django.fields import DjangoModelField +from graphene.contrib.django.fields import ConnectionOrListField, DjangoModelField @singledispatch def convert_django_field(field, cls): @@ -48,9 +48,7 @@ def _(field, cls): def _(field, cls): schema = cls._meta.schema model_field = DjangoModelField(field.related_model) - if issubclass(cls, schema.relay.Node): - return schema.relay.ConnectionField(model_field) - return ListField(model_field) + return ConnectionOrListField(model_field) @convert_django_field.register(models.ForeignKey) diff --git a/graphene/contrib/django/fields.py b/graphene/contrib/django/fields.py index 992614c1..80925e4c 100644 --- a/graphene/contrib/django/fields.py +++ b/graphene/contrib/django/fields.py @@ -1,6 +1,9 @@ -from graphene.core.fields import Field -from graphene.utils import cached_property +from graphene.core.fields import ( + ListField +) +from graphene.core.fields import Field, LazyField +from graphene.utils import cached_property from graphene.env import get_global_schema @@ -13,6 +16,19 @@ def get_type_for_model(schema, model): return _type +class ConnectionOrListField(LazyField): + def get_field(self): + schema = self.schema + model_field = self.field_type + field_object_type = model_field.get_object_type() + if field_object_type and issubclass(field_object_type, schema.relay.Node): + field = schema.relay.ConnectionField(model_field) + else: + field = ListField(model_field) + field.contribute_to_class(self.object_type, self.field_name) + return field + + class DjangoModelField(Field): def __init__(self, model): super(DjangoModelField, self).__init__(None) diff --git a/graphene/core/fields.py b/graphene/core/fields.py index bd27aaaf..1c008422 100644 --- a/graphene/core/fields.py +++ b/graphene/core/fields.py @@ -125,6 +125,19 @@ class NativeField(Field): self.field = field or getattr(self, 'field') +class LazyField(Field): + @cached_property + def inner_field(self): + return self.get_field() + + @cached_property + def type(self): + return self.inner_field.type + + @cached_property + def field(self): + return self.inner_field.field + class TypeField(Field): def __init__(self, *args, **kwargs): super(TypeField, self).__init__(self.field_type, *args, **kwargs)