fix: static method

This commit is contained in:
Laurent Riviere 2022-10-27 13:06:48 +00:00
parent 10c53d3d98
commit 0ea463671e
2 changed files with 28 additions and 19 deletions

View File

@ -49,7 +49,7 @@ from .resolver import get_default_resolver
from .scalars import ID, Boolean, Float, Int, Scalar, String from .scalars import ID, Boolean, Float, Int, Scalar, String
from .structures import List, NonNull from .structures import List, NonNull
from .union import Union from .union import Union
from .utils import get_field_as from .utils import get_field_as, get_type_name
introspection_query = get_introspection_query() introspection_query = get_introspection_query()
IntrospectionSchema = introspection_types["__Schema"] IntrospectionSchema = introspection_types["__Schema"]
@ -141,9 +141,9 @@ class TypeMap(dict):
elif issubclass(graphene_type, Interface): elif issubclass(graphene_type, Interface):
graphql_type = self.create_interface(graphene_type) graphql_type = self.create_interface(graphene_type)
elif issubclass(graphene_type, Scalar): elif issubclass(graphene_type, Scalar):
graphql_type = self.create_scalar(graphene_type) graphql_type = self.create_scalar(graphene_type, self.type_name_prefix)
elif issubclass(graphene_type, Enum): elif issubclass(graphene_type, Enum):
graphql_type = self.create_enum(graphene_type) graphql_type = self.create_enum(graphene_type, self.type_name_prefix)
elif issubclass(graphene_type, Union): elif issubclass(graphene_type, Union):
graphql_type = self.construct_union(graphene_type) graphql_type = self.construct_union(graphene_type)
else: else:
@ -151,7 +151,11 @@ class TypeMap(dict):
self[name] = graphql_type self[name] = graphql_type
return graphql_type return graphql_type
def create_scalar(self, graphene_type): @staticmethod
def create_scalar(
graphene_type,
type_name_prefix=None,
):
# We have a mapping to the original GraphQL types # We have a mapping to the original GraphQL types
# so there are no collisions. # so there are no collisions.
_scalars = { _scalars = {
@ -166,14 +170,15 @@ class TypeMap(dict):
return GrapheneScalarType( return GrapheneScalarType(
graphene_type=graphene_type, graphene_type=graphene_type,
name=self.get_type_name(graphene_type), name=get_type_name(graphene_type, type_name_prefix),
description=graphene_type._meta.description, description=graphene_type._meta.description,
serialize=getattr(graphene_type, "serialize", None), serialize=getattr(graphene_type, "serialize", None),
parse_value=getattr(graphene_type, "parse_value", None), parse_value=getattr(graphene_type, "parse_value", None),
parse_literal=getattr(graphene_type, "parse_literal", None), parse_literal=getattr(graphene_type, "parse_literal", None),
) )
def create_enum(self, graphene_type): @staticmethod
def create_enum(graphene_type, type_name_prefix=None):
values = {} values = {}
for name, value in graphene_type._meta.enum.__members__.items(): for name, value in graphene_type._meta.enum.__members__.items():
description = getattr(value, "description", None) description = getattr(value, "description", None)
@ -201,7 +206,7 @@ class TypeMap(dict):
return GrapheneEnumType( return GrapheneEnumType(
graphene_type=graphene_type, graphene_type=graphene_type,
values=values, values=values,
name=self.get_type_name(graphene_type), name=get_type_name(graphene_type, type_name_prefix),
description=type_description, description=type_description,
) )
@ -408,18 +413,7 @@ class TypeMap(dict):
return default_type_resolver(root, info, return_type) return default_type_resolver(root, info, return_type)
def get_type_name(self, graphene_type): def get_type_name(self, graphene_type):
type_name_prefix = ( return get_type_name(graphene_type, self.type_name_prefix)
graphene_type._meta.type_name_prefix
if graphene_type._meta.type_name_prefix is not None
else self.type_name_prefix
)
if type_name_prefix:
return (
type_name_prefix[0].upper()
+ type_name_prefix[1:]
+ graphene_type._meta.name
)
return graphene_type._meta.name
def get_field_name(self, graphene_type, name): def get_field_name(self, graphene_type, name):
if graphene_type._meta.name in self.root_type_names: if graphene_type._meta.name in self.root_type_names:

View File

@ -48,3 +48,18 @@ def get_underlying_type(_type):
while hasattr(_type, "of_type"): while hasattr(_type, "of_type"):
_type = _type.of_type _type = _type.of_type
return _type return _type
def get_type_name(graphene_type, type_name_prefix):
type_name_prefix = (
graphene_type._meta.type_name_prefix
if graphene_type._meta.type_name_prefix is not None
else type_name_prefix
)
if type_name_prefix:
return (
type_name_prefix[0].upper()
+ type_name_prefix[1:]
+ graphene_type._meta.name
)
return graphene_type._meta.name