Improved TypeMap and Dynamic Field to optionally include the schema

This commit is contained in:
Syrus Akbary 2017-02-20 02:35:30 -08:00
parent ecb1edd5c2
commit 2f87698a0b
3 changed files with 56 additions and 51 deletions

View File

@ -9,10 +9,13 @@ class Dynamic(MountedType):
the schema. So we can have lazy fields. the schema. So we can have lazy fields.
''' '''
def __init__(self, type, _creation_counter=None): def __init__(self, type, with_schema=False, _creation_counter=None):
super(Dynamic, self).__init__(_creation_counter=_creation_counter) super(Dynamic, self).__init__(_creation_counter=_creation_counter)
assert inspect.isfunction(type) assert inspect.isfunction(type)
self.type = type self.type = type
self.with_schema = with_schema
def get_type(self): def get_type(self, schema=None):
if schema and self.with_schema:
return self.type(schema=schema)
return self.type() return self.type()

View File

@ -94,4 +94,4 @@ class Schema(GraphQLSchema):
] ]
if self.types: if self.types:
initial_types += self.types initial_types += self.types
self._type_map = TypeMap(initial_types, auto_camelcase=self.auto_camelcase) self._type_map = TypeMap(initial_types, auto_camelcase=self.auto_camelcase, schema=self)

View File

@ -51,8 +51,9 @@ def resolve_type(resolve_type_func, map, type_name, root, context, info):
class TypeMap(GraphQLTypeMap): class TypeMap(GraphQLTypeMap):
def __init__(self, types, auto_camelcase=True): def __init__(self, types, auto_camelcase=True, schema=None):
self.auto_camelcase = auto_camelcase self.auto_camelcase = auto_camelcase
self.schema = schema
super(TypeMap, self).__init__(types) super(TypeMap, self).__init__(types)
def reducer(self, map, type): def reducer(self, map, type):
@ -72,21 +73,25 @@ class TypeMap(GraphQLTypeMap):
if isinstance(_type, GrapheneGraphQLType): if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type assert _type.graphene_type == type
return map return map
if issubclass(type, ObjectType): if issubclass(type, ObjectType):
return self.construct_objecttype(map, type) internal_type = self.construct_objecttype(map, type)
if issubclass(type, InputObjectType): if issubclass(type, InputObjectType):
return self.construct_inputobjecttype(map, type) internal_type = self.construct_inputobjecttype(map, type)
if issubclass(type, Interface): if issubclass(type, Interface):
return self.construct_interface(map, type) internal_type = self.construct_interface(map, type)
if issubclass(type, Scalar): if issubclass(type, Scalar):
return self.construct_scalar(map, type) internal_type = self.construct_scalar(map, type)
if issubclass(type, Enum): if issubclass(type, Enum):
return self.construct_enum(map, type) internal_type = self.construct_enum(map, type)
if issubclass(type, Union): if issubclass(type, Union):
return self.construct_union(map, type) internal_type = self.construct_union(map, type)
return map
return GraphQLTypeMap.reducer(map, internal_type)
def construct_scalar(self, map, type): def construct_scalar(self, map, type):
# We have a mapping to the original GraphQL types
# so there are no collisions.
_scalars = { _scalars = {
String: GraphQLString, String: GraphQLString,
Int: GraphQLInt, Int: GraphQLInt,
@ -95,9 +100,9 @@ class TypeMap(GraphQLTypeMap):
ID: GraphQLID ID: GraphQLID
} }
if type in _scalars: if type in _scalars:
map[type._meta.name] = _scalars[type] return _scalars[type]
else:
map[type._meta.name] = GrapheneScalarType( return GrapheneScalarType(
graphene_type=type, graphene_type=type,
name=type._meta.name, name=type._meta.name,
description=type._meta.description, description=type._meta.description,
@ -106,7 +111,6 @@ class TypeMap(GraphQLTypeMap):
parse_value=getattr(type, 'parse_value', None), parse_value=getattr(type, 'parse_value', None),
parse_literal=getattr(type, 'parse_literal', None), parse_literal=getattr(type, 'parse_literal', None),
) )
return map
def construct_enum(self, map, type): def construct_enum(self, map, type):
values = OrderedDict() values = OrderedDict()
@ -117,61 +121,61 @@ class TypeMap(GraphQLTypeMap):
description=getattr(value, 'description', None), description=getattr(value, 'description', None),
deprecation_reason=getattr(value, 'deprecation_reason', None) deprecation_reason=getattr(value, 'deprecation_reason', None)
) )
map[type._meta.name] = GrapheneEnumType( return GrapheneEnumType(
graphene_type=type, graphene_type=type,
values=values, values=values,
name=type._meta.name, name=type._meta.name,
description=type._meta.description, description=type._meta.description,
) )
return map
def construct_objecttype(self, map, type): def construct_objecttype(self, map, type):
if type._meta.name in map: if type._meta.name in map:
_type = map[type._meta.name] _type = map[type._meta.name]
if isinstance(_type, GrapheneGraphQLType): if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type assert _type.graphene_type == type
return map return _type
map[type._meta.name] = GrapheneObjectType(
def interfaces():
interfaces = []
for interface in type._meta.interfaces:
i = self.construct_interface(map, interface)
interfaces.append(i)
return interfaces
return GrapheneObjectType(
graphene_type=type, graphene_type=type,
name=type._meta.name, name=type._meta.name,
description=type._meta.description, description=type._meta.description,
fields=None, fields=partial(self.construct_fields_for_type, map, type),
is_type_of=type.is_type_of, is_type_of=type.is_type_of,
interfaces=None interfaces=interfaces
) )
interfaces = []
for i in type._meta.interfaces:
map = self.reducer(map, i)
interfaces.append(map[i._meta.name])
map[type._meta.name]._provided_interfaces = interfaces
map[type._meta.name]._fields = self.construct_fields_for_type(map, type)
# self.reducer(map, map[type._meta.name])
return map
def construct_interface(self, map, type): def construct_interface(self, map, type):
if type._meta.name in map:
_type = map[type._meta.name]
if isinstance(_type, GrapheneInterfaceType):
assert _type.graphene_type == type
return _type
_resolve_type = None _resolve_type = None
if type.resolve_type: if type.resolve_type:
_resolve_type = partial(resolve_type, type.resolve_type, map, type._meta.name) _resolve_type = partial(resolve_type, type.resolve_type, map, type._meta.name)
map[type._meta.name] = GrapheneInterfaceType( return GrapheneInterfaceType(
graphene_type=type, graphene_type=type,
name=type._meta.name, name=type._meta.name,
description=type._meta.description, description=type._meta.description,
fields=None, fields=partial(self.construct_fields_for_type, map, type),
resolve_type=_resolve_type, resolve_type=_resolve_type,
) )
map[type._meta.name]._fields = self.construct_fields_for_type(map, type)
# self.reducer(map, map[type._meta.name])
return map
def construct_inputobjecttype(self, map, type): def construct_inputobjecttype(self, map, type):
map[type._meta.name] = GrapheneInputObjectType( return GrapheneInputObjectType(
graphene_type=type, graphene_type=type,
name=type._meta.name, name=type._meta.name,
description=type._meta.description, description=type._meta.description,
fields=None, fields=partial(self.construct_fields_for_type, map, type, is_input_type=True),
) )
map[type._meta.name]._fields = self.construct_fields_for_type(map, type, is_input_type=True)
return map
def construct_union(self, map, type): def construct_union(self, map, type):
_resolve_type = None _resolve_type = None
@ -179,16 +183,14 @@ class TypeMap(GraphQLTypeMap):
_resolve_type = partial(resolve_type, type.resolve_type, map, type._meta.name) _resolve_type = partial(resolve_type, type.resolve_type, map, type._meta.name)
types = [] types = []
for i in type._meta.types: for i in type._meta.types:
map = self.construct_objecttype(map, i) internal_type = self.construct_objecttype(map, i)
types.append(map[i._meta.name]) types.append(internal_type)
map[type._meta.name] = GrapheneUnionType( return GrapheneUnionType(
graphene_type=type, graphene_type=type,
name=type._meta.name, name=type._meta.name,
types=types, types=types,
resolve_type=_resolve_type, resolve_type=_resolve_type,
) )
map[type._meta.name].types = types
return map
def get_name(self, name): def get_name(self, name):
if self.auto_camelcase: if self.auto_camelcase:
@ -202,7 +204,7 @@ class TypeMap(GraphQLTypeMap):
fields = OrderedDict() fields = OrderedDict()
for name, field in type._meta.fields.items(): for name, field in type._meta.fields.items():
if isinstance(field, Dynamic): if isinstance(field, Dynamic):
field = get_field_as(field.get_type(), _as=Field) field = get_field_as(field.get_type(self.schema), _as=Field)
if not field: if not field:
continue continue
map = self.reducer(map, field.type) map = self.reducer(map, field.type)