Improved querying

Conflicts:
	graphene/core/schema.py
This commit is contained in:
Syrus Akbary 2015-12-09 15:33:55 -08:00
parent dd5b26e6ed
commit 0aa251fd94
7 changed files with 178 additions and 21 deletions

View File

@ -2,7 +2,7 @@ from django.db import models
from singledispatch import singledispatch from singledispatch import singledispatch
from ...core.types.scalars import ID, Boolean, Float, Int, String from ...core.types.scalars import ID, Boolean, Float, Int, String
from .fields import ConnectionOrListField, DjangoModelField from .fields import DjangoField, ConnectionOrListField, DjangoModelField
try: try:
UUIDField = models.UUIDField UUIDField = models.UUIDField
@ -19,6 +19,14 @@ def convert_django_field(field):
(field, field.__class__)) (field, field.__class__))
def fetch_field(f):
def wrapped(field):
_type = f(field)
kwargs = dict(_type.kwargs, _field=field)
return DjangoField(_type, *_type.args, **kwargs)
return wrapped
@convert_django_field.register(models.DateField) @convert_django_field.register(models.DateField)
@convert_django_field.register(models.CharField) @convert_django_field.register(models.CharField)
@convert_django_field.register(models.TextField) @convert_django_field.register(models.TextField)
@ -26,11 +34,13 @@ def convert_django_field(field):
@convert_django_field.register(models.SlugField) @convert_django_field.register(models.SlugField)
@convert_django_field.register(models.URLField) @convert_django_field.register(models.URLField)
@convert_django_field.register(UUIDField) @convert_django_field.register(UUIDField)
@fetch_field
def convert_field_to_string(field): def convert_field_to_string(field):
return String(description=field.help_text) return String(description=field.help_text)
@convert_django_field.register(models.AutoField) @convert_django_field.register(models.AutoField)
@fetch_field
def convert_field_to_id(field): def convert_field_to_id(field):
return ID(description=field.help_text) return ID(description=field.help_text)
@ -40,22 +50,26 @@ def convert_field_to_id(field):
@convert_django_field.register(models.SmallIntegerField) @convert_django_field.register(models.SmallIntegerField)
@convert_django_field.register(models.BigIntegerField) @convert_django_field.register(models.BigIntegerField)
@convert_django_field.register(models.IntegerField) @convert_django_field.register(models.IntegerField)
@fetch_field
def convert_field_to_int(field): def convert_field_to_int(field):
return Int(description=field.help_text) return Int(description=field.help_text)
@convert_django_field.register(models.BooleanField) @convert_django_field.register(models.BooleanField)
@fetch_field
def convert_field_to_boolean(field): def convert_field_to_boolean(field):
return Boolean(description=field.help_text, required=True) return Boolean(description=field.help_text, required=True)
@convert_django_field.register(models.NullBooleanField) @convert_django_field.register(models.NullBooleanField)
@fetch_field
def convert_field_to_nullboolean(field): def convert_field_to_nullboolean(field):
return Boolean(description=field.help_text) return Boolean(description=field.help_text)
@convert_django_field.register(models.DecimalField) @convert_django_field.register(models.DecimalField)
@convert_django_field.register(models.FloatField) @convert_django_field.register(models.FloatField)
@fetch_field
def convert_field_to_float(field): def convert_field_to_float(field):
return Float(description=field.help_text) return Float(description=field.help_text)
@ -64,10 +78,10 @@ def convert_field_to_float(field):
@convert_django_field.register(models.ManyToOneRel) @convert_django_field.register(models.ManyToOneRel)
def convert_field_to_list_or_connection(field): def convert_field_to_list_or_connection(field):
model_field = DjangoModelField(field.related_model) model_field = DjangoModelField(field.related_model)
return ConnectionOrListField(model_field) return ConnectionOrListField(model_field, _field=field)
@convert_django_field.register(models.OneToOneField) @convert_django_field.register(models.OneToOneField)
@convert_django_field.register(models.ForeignKey) @convert_django_field.register(models.ForeignKey)
def convert_field_to_djangomodel(field): def convert_field_to_djangomodel(field):
return DjangoModelField(field.related_model, description=field.help_text) return DjangoField(DjangoModelField(field.related_model), description=field.help_text, _field=field)

View File

@ -9,7 +9,18 @@ from ...relay.utils import is_node
from .utils import get_type_for_model from .utils import get_type_for_model
class DjangoConnectionField(ConnectionField): class DjangoField(Field):
def decorate_resolver(self, resolver):
f = super(DjangoField, self).decorate_resolver(resolver)
setattr(f, 'django_fetch_field', self.field.name)
return f
def __init__(self, *args, **kwargs):
self.field = kwargs.pop('_field')
return super(DjangoField, self).__init__(*args, **kwargs)
class DjangoConnectionField(DjangoField, ConnectionField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
cls = self.__class__ cls = self.__class__
@ -19,7 +30,7 @@ class DjangoConnectionField(ConnectionField):
return super(DjangoConnectionField, self).__init__(*args, **kwargs) return super(DjangoConnectionField, self).__init__(*args, **kwargs)
class ConnectionOrListField(Field): class ConnectionOrListField(DjangoField):
def internal_type(self, schema): def internal_type(self, schema):
model_field = self.type model_field = self.type
@ -27,9 +38,9 @@ class ConnectionOrListField(Field):
if not field_object_type: if not field_object_type:
raise SkipField() raise SkipField()
if is_node(field_object_type): if is_node(field_object_type):
field = ConnectionField(field_object_type) field = DjangoConnectionField(field_object_type, _field=self.field)
else: else:
field = Field(List(field_object_type)) field = DjangoField(List(field_object_type), _field=self.field)
field.contribute_to_class(self.object_type, self.attname) field.contribute_to_class(self.object_type, self.attname)
return schema.T(field) return schema.T(field)

View File

@ -11,9 +11,9 @@ from .models import Article, Reporter
def assert_conversion(django_field, graphene_field, *args): def assert_conversion(django_field, graphene_field, *args):
field = django_field(*args, help_text='Custom Help Text') field = django_field(*args, help_text='Custom Help Text')
graphene_type = convert_django_field(field) field = convert_django_field(field)
graphene_type = field.type
assert isinstance(graphene_type, graphene_field) assert isinstance(graphene_type, graphene_field)
field = graphene_type.as_field()
assert field.description == 'Custom Help Text' assert field.description == 'Custom Help Text'
return field return field

View File

@ -0,0 +1,79 @@
from graphql.core.utils.get_field_def import get_field_def
import pytest
import graphene
from graphene.contrib.django import DjangoObjectType
from ..tests.models import Reporter
from ..debug.plugin import DjangoDebugPlugin
# from examples.starwars_django.models import Character
pytestmark = pytest.mark.django_db
def test_should_query_well():
r1 = Reporter(last_name='ABA')
r1.save()
r2 = Reporter(last_name='Griffin')
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
all_reporters = ReporterType.List()
def resolve_all_reporters(self, args, info):
queryset = Reporter.objects.all()
# from graphql.core.execution.base import collect_fields
# print info.field_asts[0], info.parent_type, info.return_type.of_type
# field_asts = collect_fields(info.context, info.parent_type, info.field_asts[0], {}, set())
# field_asts = info.field_asts
field_asts = info.field_asts[0].selection_set.selections
only_args = []
for field in field_asts:
field_def = get_field_def(info.schema, info.return_type.of_type, field)
f = field_def.resolver
fetch_field = getattr(f, 'django_fetch_field')
only_args.append(fetch_field)
queryset = queryset.only(*only_args)
return queryset
def resolve_reporter(self, *args, **kwargs):
return Reporter.objects.first()
query = '''
query ReporterQuery {
allReporters {
lastName
email
}
__debug {
sql {
rawSql
}
}
}
'''
expected = {
'allReporters': [{
'lastName': 'ABA',
'email': '',
}, {
'lastName': 'Griffin',
'email': '',
}],
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.all().only('last_name', 'email').query)
}]
}
}
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -29,7 +29,15 @@ class Human(DjangoNode):
def get_node(self, id): def get_node(self, id):
pass pass
schema = Schema(query=Human)
class Query(graphene.ObjectType):
human = graphene.Field(Human)
def resolve_human(self, args, info):
return Human()
schema = Schema(query=Query)
urlpatterns = [ urlpatterns = [

View File

@ -7,11 +7,13 @@ def format_response(response):
def test_client_get_good_query(settings, client): def test_client_get_good_query(settings, client):
settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls'
response = client.get('/graphql', {'query': '{ headline }'}) response = client.get('/graphql', {'query': '{ human { headline } }'})
json_response = format_response(response) json_response = format_response(response)
expected_json = { expected_json = {
'data': { 'data': {
'headline': None 'human': {
'headline': None
}
} }
} }
assert json_response == expected_json assert json_response == expected_json
@ -19,20 +21,22 @@ def test_client_get_good_query(settings, client):
def test_client_get_good_query_with_raise(settings, client): def test_client_get_good_query_with_raise(settings, client):
settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls'
response = client.get('/graphql', {'query': '{ raises }'}) response = client.get('/graphql', {'query': '{ human { raises } }'})
json_response = format_response(response) json_response = format_response(response)
assert json_response['errors'][0]['message'] == 'This field should raise exception' assert json_response['errors'][0]['message'] == 'This field should raise exception'
assert json_response['data']['raises'] is None assert json_response['data']['human']['raises'] is None
def test_client_post_good_query_json(settings, client): def test_client_post_good_query_json(settings, client):
settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls'
response = client.post( response = client.post(
'/graphql', json.dumps({'query': '{ headline }'}), 'application/json') '/graphql', json.dumps({'query': '{ human { headline } }'}), 'application/json')
json_response = format_response(response) json_response = format_response(response)
expected_json = { expected_json = {
'data': { 'data': {
'headline': None 'human': {
'headline': None
}
} }
} }
assert json_response == expected_json assert json_response == expected_json
@ -41,11 +45,13 @@ def test_client_post_good_query_json(settings, client):
def test_client_post_good_query_graphql(settings, client): def test_client_post_good_query_graphql(settings, client):
settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls' settings.ROOT_URLCONF = 'graphene.contrib.django.tests.test_urls'
response = client.post( response = client.post(
'/graphql', '{ headline }', 'application/graphql') '/graphql', '{ human { headline } }', 'application/graphql')
json_response = format_response(response) json_response = format_response(response)
expected_json = { expected_json = {
'data': { 'data': {
'headline': None 'human': {
'headline': None
}
} }
} }
assert json_response == expected_json assert json_response == expected_json

View File

@ -1,4 +1,10 @@
from graphql_django_view import GraphQLView as BaseGraphQLView from django.http import HttpResponseNotAllowed
from django.http.response import HttpResponseBadRequest
from graphql.core import Source, parse, validate
from graphql.core.execution import ExecutionResult
from graphql.core.utils.get_operation_ast import get_operation_ast
from graphql_django_view import GraphQLView as BaseGraphQLView, HttpError
class GraphQLView(BaseGraphQLView): class GraphQLView(BaseGraphQLView):
@ -12,5 +18,38 @@ class GraphQLView(BaseGraphQLView):
**kwargs **kwargs
) )
def get_root_value(self, request): def execute_graphql_request(self, request):
return self.graphene_schema.query(super(GraphQLView, self).get_root_value(request)) query, variables, operation_name = self.get_graphql_params(request, self.parse_body(request))
if not query:
raise HttpError(HttpResponseBadRequest('Must provide query string.'))
source = Source(query, name='GraphQL request')
try:
document_ast = parse(source)
except Exception as e:
return ExecutionResult(errors=[e], invalid=True)
validation_errors = validate(self.schema, document_ast)
if validation_errors:
return ExecutionResult(invalid=True, errors=validation_errors)
if request.method.lower() == 'get':
operation_ast = get_operation_ast(document_ast, operation_name)
if operation_ast and operation_ast.operation != 'query':
raise HttpError(HttpResponseNotAllowed(
['POST'], 'Can only perform a {} operation from a POST request.'.format(operation_ast.operation)
))
try:
return self.graphene_schema.execute(
document_ast,
self.get_root_value(request),
variables,
operation_name=operation_name,
validate_ast=False,
request_context=request
)
except Exception as e:
return ExecutionResult(errors=[e], invalid=True)