diff --git a/graphene/types/tests/test_typemap.py b/graphene/types/tests/test_typemap.py index e5dc3995..4d427a0a 100644 --- a/graphene/types/tests/test_typemap.py +++ b/graphene/types/tests/test_typemap.py @@ -141,3 +141,45 @@ def test_inputobject(): foo_field = fields['fooBar'] assert isinstance(foo_field, GraphQLInputObjectField) assert foo_field.description == 'Field description' + + +def test_objecttype_camelcase(): + class MyObjectType(ObjectType): + '''Description''' + foo_bar = String(bar_foo=String()) + + typemap = TypeMap([MyObjectType]) + assert 'MyObjectType' in typemap + graphql_type = typemap['MyObjectType'] + assert isinstance(graphql_type, GraphQLObjectType) + assert graphql_type.name == 'MyObjectType' + assert graphql_type.description == 'Description' + + fields = graphql_type.fields + assert list(fields.keys()) == ['fooBar'] + foo_field = fields['fooBar'] + assert isinstance(foo_field, GraphQLField) + assert foo_field.args == { + 'barFoo': GraphQLArgument(GraphQLString, out_name='bar_foo') + } + + +def test_objecttype_camelcase_disabled(): + class MyObjectType(ObjectType): + '''Description''' + foo_bar = String(bar_foo=String()) + + typemap = TypeMap([MyObjectType], auto_camelcase=False) + assert 'MyObjectType' in typemap + graphql_type = typemap['MyObjectType'] + assert isinstance(graphql_type, GraphQLObjectType) + assert graphql_type.name == 'MyObjectType' + assert graphql_type.description == 'Description' + + fields = graphql_type.fields + assert list(fields.keys()) == ['foo_bar'] + foo_field = fields['foo_bar'] + assert isinstance(foo_field, GraphQLField) + assert foo_field.args == { + 'bar_foo': GraphQLArgument(GraphQLString, out_name='bar_foo') + } diff --git a/graphene/types/typemap.py b/graphene/types/typemap.py index 457bba6e..360939f2 100644 --- a/graphene/types/typemap.py +++ b/graphene/types/typemap.py @@ -41,45 +41,41 @@ def resolve_type(resolve_type_func, map, root, args, info): class TypeMap(GraphQLTypeMap): def __init__(self, types, auto_camelcase=True): - if not auto_camelcase: - raise Exception("Disabling auto_camelcase is not *yet* supported, but will be soon!") + self.auto_camelcase = auto_camelcase super(TypeMap, self).__init__(types) - @classmethod - def reducer(cls, map, type): + def reducer(self, map, type): if not type: return map if inspect.isfunction(type): type = type() if is_graphene_type(type): - return cls.graphene_reducer(map, type) - return super(TypeMap, cls).reducer(map, type) + return self.graphene_reducer(map, type) + return GraphQLTypeMap.reducer(map, type) - @classmethod - def graphene_reducer(cls, map, type): + def graphene_reducer(self, map, type): if isinstance(type, (List, NonNull)): - return cls.reducer(map, type.of_type) + return self.reducer(map, type.of_type) if type._meta.name in map: _type = map[type._meta.name] if is_graphene_type(_type): assert _type.graphene_type == type return map if issubclass(type, ObjectType): - return cls.construct_objecttype(map, type) + return self.construct_objecttype(map, type) if issubclass(type, InputObjectType): - return cls.construct_inputobjecttype(map, type) + return self.construct_inputobjecttype(map, type) if issubclass(type, Interface): - return cls.construct_interface(map, type) + return self.construct_interface(map, type) if issubclass(type, Scalar): - return cls.construct_scalar(map, type) + return self.construct_scalar(map, type) if issubclass(type, Enum): - return cls.construct_enum(map, type) + return self.construct_enum(map, type) if issubclass(type, Union): - return cls.construct_union(map, type) + return self.construct_union(map, type) return map - @classmethod - def construct_scalar(cls, map, type): + def construct_scalar(self, map, type): from .definitions import GrapheneScalarType _scalars = { String: GraphQLString, @@ -102,8 +98,7 @@ class TypeMap(GraphQLTypeMap): ) return map - @classmethod - def construct_enum(cls, map, type): + def construct_enum(self, map, type): from .definitions import GrapheneEnumType values = OrderedDict() for name, value in type._meta.enum.__members__.items(): @@ -121,8 +116,7 @@ class TypeMap(GraphQLTypeMap): ) return map - @classmethod - def construct_objecttype(cls, map, type): + def construct_objecttype(self, map, type): from .definitions import GrapheneObjectType map[type._meta.name] = GrapheneObjectType( graphene_type=type, @@ -134,15 +128,14 @@ class TypeMap(GraphQLTypeMap): ) interfaces = [] for i in type._meta.interfaces: - map = cls.reducer(map, i) + map = self.reducer(map, i) interfaces.append(map[i._meta.name]) map[type._meta.name]._provided_interfaces = interfaces - map[type._meta.name]._fields = cls.construct_fields_for_type(map, type) - # cls.reducer(map, map[type._meta.name]) + map[type._meta.name]._fields = self.construct_fields_for_type(map, type) + # self.reducer(map, map[type._meta.name]) return map - @classmethod - def construct_interface(cls, map, type): + def construct_interface(self, map, type): from .definitions import GrapheneInterfaceType _resolve_type = None if type.resolve_type: @@ -154,12 +147,11 @@ class TypeMap(GraphQLTypeMap): fields=None, resolve_type=_resolve_type, ) - map[type._meta.name]._fields = cls.construct_fields_for_type(map, type) - # cls.reducer(map, map[type._meta.name]) + map[type._meta.name]._fields = self.construct_fields_for_type(map, type) + # self.reducer(map, map[type._meta.name]) return map - @classmethod - def construct_inputobjecttype(cls, map, type): + def construct_inputobjecttype(self, map, type): from .definitions import GrapheneInputObjectType map[type._meta.name] = GrapheneInputObjectType( graphene_type=type, @@ -167,18 +159,17 @@ class TypeMap(GraphQLTypeMap): description=type._meta.description, fields=None, ) - map[type._meta.name]._fields = cls.construct_fields_for_type(map, type, is_input_type=True) + map[type._meta.name]._fields = self.construct_fields_for_type(map, type, is_input_type=True) return map - @classmethod - def construct_union(cls, map, type): + def construct_union(self, map, type): from .definitions import GrapheneUnionType _resolve_type = None if type.resolve_type: _resolve_type = partial(resolve_type, type.resolve_type, map) types = [] for i in type._meta.types: - map = cls.construct_objecttype(map, i) + map = self.construct_objecttype(map, i) types.append(map[i._meta.name]) map[type._meta.name] = GrapheneUnionType( graphene_type=type, @@ -189,24 +180,23 @@ class TypeMap(GraphQLTypeMap): map[type._meta.name].types = types return map - @classmethod - def process_field_name(cls, name): - return to_camel_case(name) + def get_name(self, name): + if self.auto_camelcase: + return to_camel_case(name) + return name - @classmethod - def default_resolver(cls, attname, root, *_): + def default_resolver(self, attname, root, *_): return getattr(root, attname, None) - @classmethod - def construct_fields_for_type(cls, map, type, is_input_type=False): + def construct_fields_for_type(self, map, type, is_input_type=False): fields = OrderedDict() for name, field in type._meta.fields.items(): if isinstance(field, Dynamic): field = field.get_type() if not field: continue - map = cls.reducer(map, field.type) - field_type = cls.get_field_type(map, field.type) + map = self.reducer(map, field.type) + field_type = self.get_field_type(map, field.type) if is_input_type: _field = GraphQLInputObjectField( field_type, @@ -217,9 +207,9 @@ class TypeMap(GraphQLTypeMap): else: args = OrderedDict() for arg_name, arg in field.args.items(): - map = cls.reducer(map, arg.type) - arg_type = cls.get_field_type(map, arg.type) - processed_arg_name = arg.name or cls.process_field_name(arg_name) + map = self.reducer(map, arg.type) + arg_type = self.get_field_type(map, arg.type) + processed_arg_name = arg.name or self.get_name(arg_name) args[processed_arg_name] = GraphQLArgument( arg_type, out_name=arg.name or arg_name, @@ -229,16 +219,15 @@ class TypeMap(GraphQLTypeMap): _field = GraphQLField( field_type, args=args, - resolver=field.get_resolver(cls.get_resolver_for_type(type, name)), + resolver=field.get_resolver(self.get_resolver_for_type(type, name)), deprecation_reason=field.deprecation_reason, description=field.description ) - field_name = field.name or cls.process_field_name(name) + field_name = field.name or self.get_name(name) fields[field_name] = _field return fields - @classmethod - def get_resolver_for_type(cls, type, name): + def get_resolver_for_type(self, type, name): if not issubclass(type, ObjectType): return resolver = getattr(type, 'resolve_{}'.format(name), None) @@ -259,14 +248,13 @@ class TypeMap(GraphQLTypeMap): return resolver.__func__ return resolver - return partial(cls.default_resolver, name) + return partial(self.default_resolver, name) - @classmethod - def get_field_type(cls, map, type): + def get_field_type(self, map, type): if isinstance(type, List): - return GraphQLList(cls.get_field_type(map, type.of_type)) + return GraphQLList(self.get_field_type(map, type.of_type)) if isinstance(type, NonNull): - return GraphQLNonNull(cls.get_field_type(map, type.of_type)) + return GraphQLNonNull(self.get_field_type(map, type.of_type)) if inspect.isfunction(type): type = type() return map.get(type._meta.name)