This commit is contained in:
lilac-supernova-2 2023-10-25 11:33:15 +03:00 committed by GitHub
commit d0b8c97cdc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 198 additions and 115 deletions

View File

@ -12,6 +12,10 @@ dev-setup:
tests: tests:
PYTHONPATH=. pytest graphene_django --cov=graphene_django -vv PYTHONPATH=. pytest graphene_django --cov=graphene_django -vv
.PHONY: tests-repeat ## Run unit tests 100 times to possibly identify flaky unit tests (and run them in parallel)
tests-repeat:
PYTHONPATH=. pytest graphene_django --cov=graphene_django -vv --count 100 -n logical
.PHONY: format ## Format code .PHONY: format ## Format code
format: format:
ruff format graphene_django examples setup.py ruff format graphene_django examples setup.py

View File

@ -31,14 +31,14 @@ def initialize():
falcon = Ship(id="4", name="Millennium Falcon", faction=rebels) falcon = Ship(id="4", name="Millennium Falcon", faction=rebels)
falcon.save() falcon.save()
homeOne = Ship(id="5", name="Home One", faction=rebels) home_one = Ship(id="5", name="Home One", faction=rebels)
homeOne.save() home_one.save()
tieFighter = Ship(id="6", name="TIE Fighter", faction=empire) tie_fighter = Ship(id="6", name="TIE Fighter", faction=empire)
tieFighter.save() tie_fighter.save()
tieInterceptor = Ship(id="7", name="TIE Interceptor", faction=empire) tie_interceptor = Ship(id="7", name="TIE Interceptor", faction=empire)
tieInterceptor.save() tie_interceptor.save()
executor = Ship(id="8", name="Executor", faction=empire) executor = Ship(id="8", name="Executor", faction=empire)
executor.save() executor.save()

View File

@ -1,6 +1,6 @@
from django.db import connections from django.db import connections
from .exception.formating import wrap_exception from .exception.formatting import wrap_exception
from .sql.tracking import unwrap_cursor, wrap_cursor from .sql.tracking import unwrap_cursor, wrap_cursor
from .types import DjangoDebug from .types import DjangoDebug

View File

@ -8,19 +8,102 @@ from .filters import ListFilter, RangeFilter, TypedFilter
from .filterset import custom_filterset_factory, setup_filterset from .filterset import custom_filterset_factory, setup_filterset
def get_field_type(registry, model, field_name): def get_field_type_from_registry(registry, model, field_name):
""" """
Try to get a model field corresponding Graphql type from the DjangoObjectType. Try to get a model field corresponding GraphQL type from the DjangoObjectType.
""" """
object_type = registry.get_type_for_model(model) object_type = registry.get_type_for_model(model)
if object_type: if not object_type:
object_type_field = object_type._meta.fields.get(field_name) return None
if object_type_field:
field_type = object_type_field.type object_type_field = object_type._meta.fields.get(field_name)
if isinstance(field_type, graphene.NonNull): if not object_type_field:
field_type = field_type.of_type return None
return field_type
return None field_type = object_type_field.type
if isinstance(field_type, graphene.NonNull):
field_type = field_type.of_type
return field_type
def get_field_type_from_model_field(model_field, form_field, registry):
"""
Get the field type from the model field.
If the model field is a foreign key, then we need to get the type from the related model.
"""
if (
isinstance(form_field, forms.ModelChoiceField)
or isinstance(form_field, forms.ModelMultipleChoiceField)
or isinstance(form_field, GlobalIDMultipleChoiceField)
or isinstance(form_field, GlobalIDFormField)
):
# Foreign key have dynamic types and filtering on a foreign key actually means filtering on its ID.
return get_field_type_from_registry(registry, model_field.related_model, "id")
return get_field_type_from_registry(registry, model_field.model, model_field.name)
def get_form_field(model_field, filter_field, required):
"""
Retrieve the form field to use for the filter.
Get the form field either from:
# 1. the formfield corresponding to the model field
# 2. the field defined on filter
Returns None if no form field can be found.
"""
form_field = None
if hasattr(model_field, "formfield"):
form_field = model_field.formfield(required=required)
if not form_field:
form_field = filter_field.field
return form_field
def get_field_type_and_form_field_for_implicit_filter(
model, filter_type, filter_field, registry, required
):
"""
Get the filter type for filters that are not explicitly declared.
Returns a tuple of (field_type, form_field) where:
- field_type is the type of the filter argument
- form_field is the form field to use to validate the input value
"""
if filter_type == "isnull":
# Filter type is boolean, no form field.
return (graphene.Boolean, None)
model_field = get_model_field(model, filter_field.field_name)
form_field = get_form_field(model_field, filter_field, required)
# First try to get the matching field type from the GraphQL DjangoObjectType
if model_field:
field_type = get_field_type_from_model_field(model_field, form_field, registry)
return (field_type, form_field)
return (None, None)
def get_field_type_for_explicit_filter(filter_field, form_field):
"""
Fallback on converting the form field either because:
- it's an explicitly declared filters
- we did not manage to get the type from the model type
"""
from ..forms.converter import convert_form_field
form_field = form_field or filter_field.field
return convert_form_field(form_field).get_type()
def is_filter_list_or_range(filter_field):
"""
Determine if the filter is a ListFilter or RangeFilter.
"""
return isinstance(filter_field, ListFilter) or isinstance(filter_field, RangeFilter)
def get_filtering_args_from_filterset(filterset_class, type): def get_filtering_args_from_filterset(filterset_class, type):
@ -28,7 +111,6 @@ def get_filtering_args_from_filterset(filterset_class, type):
Inspect a FilterSet and produce the arguments to pass to a Graphene Field. Inspect a FilterSet and produce the arguments to pass to a Graphene Field.
These arguments will be available to filter against in the GraphQL API. These arguments will be available to filter against in the GraphQL API.
""" """
from ..forms.converter import convert_form_field
args = {} args = {}
model = filterset_class._meta.model model = filterset_class._meta.model
@ -49,49 +131,21 @@ def get_filtering_args_from_filterset(filterset_class, type):
if name not in filterset_class.declared_filters or isinstance( if name not in filterset_class.declared_filters or isinstance(
filter_field, TypedFilter filter_field, TypedFilter
): ):
# Get the filter field for filters that are no explicitly declared. (
if filter_type == "isnull": field_type,
field_type = graphene.Boolean form_field,
else: ) = get_field_type_and_form_field_for_implicit_filter(
model_field = get_model_field(model, filter_field.field_name) model, filter_type, filter_field, registry, required
)
# Get the form field either from:
# 1. the formfield corresponding to the model field
# 2. the field defined on filter
if hasattr(model_field, "formfield"):
form_field = model_field.formfield(required=required)
if not form_field:
form_field = filter_field.field
# First try to get the matching field type from the GraphQL DjangoObjectType
if model_field:
if (
isinstance(form_field, forms.ModelChoiceField)
or isinstance(form_field, forms.ModelMultipleChoiceField)
or isinstance(form_field, GlobalIDMultipleChoiceField)
or isinstance(form_field, GlobalIDFormField)
):
# Foreign key have dynamic types and filtering on a foreign key actually means filtering on its ID.
field_type = get_field_type(
registry, model_field.related_model, "id"
)
else:
field_type = get_field_type(
registry, model_field.model, model_field.name
)
if not field_type: if not field_type:
# Fallback on converting the form field either because: field_type = get_field_type_for_explicit_filter(
# - it's an explicitly declared filters filter_field, form_field
# - we did not manage to get the type from the model type )
form_field = form_field or filter_field.field
field_type = convert_form_field(form_field).get_type()
if isinstance(filter_field, ListFilter) or isinstance( # Replace InFilter/RangeFilter filters (`in`, `range`) argument type to be a list of
filter_field, RangeFilter # the same type as the field. See comments in `replace_csv_filters` method for more details.
): if is_filter_list_or_range(filter_field):
# Replace InFilter/RangeFilter filters (`in`, `range`) argument type to be a list of
# the same type as the field. See comments in `replace_csv_filters` method for more details.
field_type = graphene.List(field_type) field_type = graphene.List(field_type)
args[name] = graphene.Argument( args[name] = graphene.Argument(

View File

@ -30,7 +30,7 @@ class ArticleConnection(Connection):
test = String() test = String()
def resolve_test(): def resolve_test(self):
return "test" return "test"
class Meta: class Meta:

View File

@ -62,10 +62,7 @@ def construct_fields(
return fields return fields
def validate_fields(type_, model, fields, only_fields, exclude_fields): def validate_only_fields(only_fields, all_field_names, model, type_):
# Validate the given fields against the model's fields and custom fields
all_field_names = set(fields.keys())
only_fields = only_fields if only_fields is not ALL_FIELDS else ()
for name in only_fields or (): for name in only_fields or ():
if name in all_field_names: if name in all_field_names:
continue continue
@ -83,20 +80,22 @@ def validate_fields(type_, model, fields, only_fields, exclude_fields):
type_=type_, type_=type_,
) )
) )
continue
else: warnings.warn(
warnings.warn( (
( 'Field name "{field_name}" doesn\'t exist on Django model "{app_label}.{object_name}". '
'Field name "{field_name}" doesn\'t exist on Django model "{app_label}.{object_name}". ' 'Consider removing the field from the "fields" list of DjangoObjectType "{type_}" because it has no effect.'
'Consider removing the field from the "fields" list of DjangoObjectType "{type_}" because it has no effect.' ).format(
).format( field_name=name,
field_name=name, app_label=model._meta.app_label,
app_label=model._meta.app_label, object_name=model._meta.object_name,
object_name=model._meta.object_name, type_=type_,
type_=type_,
)
) )
)
def validate_exclude_fields(exclude_fields, all_field_names, model, type_):
# Validate exclude fields # Validate exclude fields
for name in exclude_fields or (): for name in exclude_fields or ():
if name in all_field_names: if name in all_field_names:
@ -105,19 +104,29 @@ def validate_fields(type_, model, fields, only_fields, exclude_fields):
f'Excluding the custom field "{name}" on DjangoObjectType "{type_}" has no effect. ' f'Excluding the custom field "{name}" on DjangoObjectType "{type_}" has no effect. '
'Either remove the custom field or remove the field from the "exclude" list.' 'Either remove the custom field or remove the field from the "exclude" list.'
) )
else: continue
if not hasattr(model, name):
warnings.warn( if not hasattr(model, name):
( warnings.warn(
'Django model "{app_label}.{object_name}" does not have a field or attribute named "{field_name}". ' (
'Consider removing the field from the "exclude" list of DjangoObjectType "{type_}" because it has no effect' 'Django model "{app_label}.{object_name}" does not have a field or attribute named "{field_name}". '
).format( 'Consider removing the field from the "exclude" list of DjangoObjectType "{type_}" because it has no effect'
field_name=name, ).format(
app_label=model._meta.app_label, field_name=name,
object_name=model._meta.object_name, app_label=model._meta.app_label,
type_=type_, object_name=model._meta.object_name,
) type_=type_,
) )
)
def validate_fields(type_, model, fields, only_fields, exclude_fields):
# Validate the given fields against the model's fields and custom fields
all_field_names = set(fields.keys())
only_fields = only_fields if only_fields is not ALL_FIELDS else ()
validate_only_fields(only_fields, all_field_names, model, type_)
validate_exclude_fields(exclude_fields, all_field_names, model, type_)
class DjangoObjectTypeOptions(ObjectTypeOptions): class DjangoObjectTypeOptions(ObjectTypeOptions):

View File

@ -135,7 +135,7 @@ class GraphQLTestMixin:
) )
self.client = client self.client = client
def assertResponseNoErrors(self, resp, msg=None): def assert_response_no_errors(self, resp, msg=None):
""" """
Assert that the call went through correctly. 200 means the syntax is ok, if there are no `errors`, Assert that the call went through correctly. 200 means the syntax is ok, if there are no `errors`,
the call was fine. the call was fine.
@ -145,7 +145,7 @@ class GraphQLTestMixin:
self.assertEqual(resp.status_code, 200, msg or content) self.assertEqual(resp.status_code, 200, msg or content)
self.assertNotIn("errors", list(content.keys()), msg or content) self.assertNotIn("errors", list(content.keys()), msg or content)
def assertResponseHasErrors(self, resp, msg=None): def assert_response_has_errors(self, resp, msg=None):
""" """
Assert that the call was failing. Take care: Even with errors, GraphQL returns status 200! Assert that the call was failing. Take care: Even with errors, GraphQL returns status 200!
:resp HttpResponse: Response :resp HttpResponse: Response

View File

@ -103,16 +103,15 @@ class GraphQLView(View):
): ):
if not schema: if not schema:
schema = graphene_settings.SCHEMA schema = graphene_settings.SCHEMA
self.schema = schema or self.schema
if middleware is None: if middleware is None:
middleware = graphene_settings.MIDDLEWARE middleware = graphene_settings.MIDDLEWARE
if isinstance(middleware, MiddlewareManager):
self.middleware = middleware
else:
self.middleware = list(instantiate_middleware(middleware))
self.schema = schema or self.schema
if middleware is not None:
if isinstance(middleware, MiddlewareManager):
self.middleware = middleware
else:
self.middleware = list(instantiate_middleware(middleware))
self.root_value = root_value self.root_value = root_value
self.pretty = pretty or self.pretty self.pretty = pretty or self.pretty
self.graphiql = graphiql or self.graphiql self.graphiql = graphiql or self.graphiql
@ -287,6 +286,25 @@ class GraphQLView(View):
return {} return {}
def validate_query_request_type(
self, request, document, operation_name, show_graphiql
):
if request.method.lower() == "get":
operation_ast = get_operation_ast(document, operation_name)
if (
operation_ast
and operation_ast.operation != OperationType.QUERY
and not show_graphiql
):
raise HttpError(
HttpResponseNotAllowed(
["POST"],
"Can only perform a {} operation from a POST request.".format(
operation_ast.operation.value
),
)
)
def execute_graphql_request( def execute_graphql_request(
self, request, data, query, variables, operation_name, show_graphiql=False self, request, data, query, variables, operation_name, show_graphiql=False
): ):
@ -300,20 +318,12 @@ class GraphQLView(View):
except Exception as e: except Exception as e:
return ExecutionResult(errors=[e]) return ExecutionResult(errors=[e])
if request.method.lower() == "get": self.validate_query_request_type(
operation_ast = get_operation_ast(document, operation_name) request, document, operation_name, show_graphiql
if operation_ast and operation_ast.operation != OperationType.QUERY: )
if show_graphiql: if show_graphiql:
return None return None
raise HttpError(
HttpResponseNotAllowed(
["POST"],
"Can only perform a {} operation from a POST request.".format(
operation_ast.operation.value
),
)
)
try: try:
extra_options = {} extra_options = {}
if self.execution_context_class: if self.execution_context_class:
@ -330,14 +340,7 @@ class GraphQLView(View):
options.update(extra_options) options.update(extra_options)
operation_ast = get_operation_ast(document, operation_name) operation_ast = get_operation_ast(document, operation_name)
if ( if self.is_atomic_mutation_enabled(operation_ast, connection):
operation_ast
and operation_ast.operation == OperationType.MUTATION
and (
graphene_settings.ATOMIC_MUTATIONS is True
or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True
)
):
with transaction.atomic(): with transaction.atomic():
result = self.schema.execute(**options) result = self.schema.execute(**options)
if getattr(request, MUTATION_ERRORS_FLAG, False) is True: if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
@ -402,3 +405,14 @@ class GraphQLView(View):
meta = request.META meta = request.META
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", "")) content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
return content_type.split(";", 1)[0].lower() return content_type.split(";", 1)[0].lower()
@staticmethod
def is_atomic_mutation_enabled(operation_ast, connection):
return (
operation_ast
and operation_ast.operation == OperationType.MUTATION
and (
graphene_settings.ATOMIC_MUTATIONS is True
or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True
)
)

View File

@ -22,6 +22,8 @@ tests_require = [
"pytz", "pytz",
"django-filter>=22.1", "django-filter>=22.1",
"pytest-django>=4.5.2", "pytest-django>=4.5.2",
"pytest-repeat>=0.9.1",
"pytest-xdist>=3.3.1",
] + rest_framework_require ] + rest_framework_require