diff --git a/graphene/types/dynamic.py b/graphene/types/dynamic.py index c5aada20..6c4092f0 100644 --- a/graphene/types/dynamic.py +++ b/graphene/types/dynamic.py @@ -9,10 +9,13 @@ class Dynamic(MountedType): the schema. So we can have lazy fields. ''' - def __init__(self, type, _creation_counter=None): + def __init__(self, type, with_schema=False, _creation_counter=None): super(Dynamic, self).__init__(_creation_counter=_creation_counter) assert inspect.isfunction(type) self.type = type + self.with_schema = with_schema - def get_type(self): + def get_type(self, schema=None): + if schema and self.with_schema: + return self.type(schema=schema) return self.type() diff --git a/graphene/types/schema.py b/graphene/types/schema.py index b95490ca..e9600dc9 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -94,4 +94,4 @@ class Schema(GraphQLSchema): ] if self.types: initial_types += self.types - self._type_map = TypeMap(initial_types, auto_camelcase=self.auto_camelcase) + self._type_map = TypeMap(initial_types, auto_camelcase=self.auto_camelcase, schema=self) diff --git a/graphene/types/typemap.py b/graphene/types/typemap.py index 70aa84cc..7c71f35a 100644 --- a/graphene/types/typemap.py +++ b/graphene/types/typemap.py @@ -51,8 +51,9 @@ def resolve_type(resolve_type_func, map, type_name, root, context, info): class TypeMap(GraphQLTypeMap): - def __init__(self, types, auto_camelcase=True): + def __init__(self, types, auto_camelcase=True, schema=None): self.auto_camelcase = auto_camelcase + self.schema = schema super(TypeMap, self).__init__(types) def reducer(self, map, type): @@ -72,21 +73,25 @@ class TypeMap(GraphQLTypeMap): if isinstance(_type, GrapheneGraphQLType): assert _type.graphene_type == type return map + if issubclass(type, ObjectType): - return self.construct_objecttype(map, type) + internal_type = self.construct_objecttype(map, type) if issubclass(type, InputObjectType): - return self.construct_inputobjecttype(map, type) + internal_type = self.construct_inputobjecttype(map, type) if issubclass(type, Interface): - return self.construct_interface(map, type) + internal_type = self.construct_interface(map, type) if issubclass(type, Scalar): - return self.construct_scalar(map, type) + internal_type = self.construct_scalar(map, type) if issubclass(type, Enum): - return self.construct_enum(map, type) + internal_type = self.construct_enum(map, type) if issubclass(type, Union): - return self.construct_union(map, type) - return map + internal_type = self.construct_union(map, type) + + return GraphQLTypeMap.reducer(map, internal_type) def construct_scalar(self, map, type): + # We have a mapping to the original GraphQL types + # so there are no collisions. _scalars = { String: GraphQLString, Int: GraphQLInt, @@ -95,18 +100,17 @@ class TypeMap(GraphQLTypeMap): ID: GraphQLID } if type in _scalars: - map[type._meta.name] = _scalars[type] - else: - map[type._meta.name] = GrapheneScalarType( - graphene_type=type, - name=type._meta.name, - description=type._meta.description, + return _scalars[type] - serialize=getattr(type, 'serialize', None), - parse_value=getattr(type, 'parse_value', None), - parse_literal=getattr(type, 'parse_literal', None), - ) - return map + return GrapheneScalarType( + graphene_type=type, + name=type._meta.name, + description=type._meta.description, + + serialize=getattr(type, 'serialize', None), + parse_value=getattr(type, 'parse_value', None), + parse_literal=getattr(type, 'parse_literal', None), + ) def construct_enum(self, map, type): values = OrderedDict() @@ -117,61 +121,61 @@ class TypeMap(GraphQLTypeMap): description=getattr(value, 'description', None), deprecation_reason=getattr(value, 'deprecation_reason', None) ) - map[type._meta.name] = GrapheneEnumType( + return GrapheneEnumType( graphene_type=type, values=values, name=type._meta.name, description=type._meta.description, ) - return map def construct_objecttype(self, map, type): if type._meta.name in map: _type = map[type._meta.name] if isinstance(_type, GrapheneGraphQLType): assert _type.graphene_type == type - return map - map[type._meta.name] = GrapheneObjectType( + return _type + + def interfaces(): + interfaces = [] + for interface in type._meta.interfaces: + i = self.construct_interface(map, interface) + interfaces.append(i) + return interfaces + + return GrapheneObjectType( graphene_type=type, name=type._meta.name, description=type._meta.description, - fields=None, + fields=partial(self.construct_fields_for_type, map, type), is_type_of=type.is_type_of, - interfaces=None + interfaces=interfaces ) - interfaces = [] - for i in type._meta.interfaces: - map = self.reducer(map, i) - interfaces.append(map[i._meta.name]) - map[type._meta.name]._provided_interfaces = interfaces - map[type._meta.name]._fields = self.construct_fields_for_type(map, type) - # self.reducer(map, map[type._meta.name]) - return map def construct_interface(self, map, type): + if type._meta.name in map: + _type = map[type._meta.name] + if isinstance(_type, GrapheneInterfaceType): + assert _type.graphene_type == type + return _type + _resolve_type = None if type.resolve_type: _resolve_type = partial(resolve_type, type.resolve_type, map, type._meta.name) - map[type._meta.name] = GrapheneInterfaceType( + return GrapheneInterfaceType( graphene_type=type, name=type._meta.name, description=type._meta.description, - fields=None, + fields=partial(self.construct_fields_for_type, map, type), resolve_type=_resolve_type, ) - map[type._meta.name]._fields = self.construct_fields_for_type(map, type) - # self.reducer(map, map[type._meta.name]) - return map def construct_inputobjecttype(self, map, type): - map[type._meta.name] = GrapheneInputObjectType( + return GrapheneInputObjectType( graphene_type=type, name=type._meta.name, description=type._meta.description, - fields=None, + fields=partial(self.construct_fields_for_type, map, type, is_input_type=True), ) - map[type._meta.name]._fields = self.construct_fields_for_type(map, type, is_input_type=True) - return map def construct_union(self, map, type): _resolve_type = None @@ -179,16 +183,14 @@ class TypeMap(GraphQLTypeMap): _resolve_type = partial(resolve_type, type.resolve_type, map, type._meta.name) types = [] for i in type._meta.types: - map = self.construct_objecttype(map, i) - types.append(map[i._meta.name]) - map[type._meta.name] = GrapheneUnionType( + internal_type = self.construct_objecttype(map, i) + types.append(internal_type) + return GrapheneUnionType( graphene_type=type, name=type._meta.name, types=types, resolve_type=_resolve_type, ) - map[type._meta.name].types = types - return map def get_name(self, name): if self.auto_camelcase: @@ -202,7 +204,7 @@ class TypeMap(GraphQLTypeMap): fields = OrderedDict() for name, field in type._meta.fields.items(): if isinstance(field, Dynamic): - field = get_field_as(field.get_type(), _as=Field) + field = get_field_as(field.get_type(self.schema), _as=Field) if not field: continue map = self.reducer(map, field.type)