mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-04-13 05:34:20 +03:00
fix: fk resolver permissions leak
This commit is contained in:
parent
72a3700856
commit
edb9690dbc
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -11,6 +11,7 @@ __pycache__/
|
|||
# Distribution / packaging
|
||||
.Python
|
||||
env/
|
||||
.env/
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import inspect
|
||||
from collections import OrderedDict
|
||||
from functools import singledispatch, wraps
|
||||
from functools import partial, singledispatch, wraps
|
||||
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_str
|
||||
|
@ -25,6 +26,7 @@ from graphene import (
|
|||
)
|
||||
from graphene.types.json import JSONString
|
||||
from graphene.types.scalars import BigInt
|
||||
from graphene.types.resolver import get_default_resolver
|
||||
from graphene.utils.str_converters import to_camel_case
|
||||
from graphql import GraphQLError
|
||||
|
||||
|
@ -258,6 +260,9 @@ def convert_time_to_string(field, registry=None):
|
|||
|
||||
@convert_django_field.register(models.OneToOneRel)
|
||||
def convert_onetoone_field_to_djangomodel(field, registry=None):
|
||||
from graphene.utils.str_converters import to_snake_case
|
||||
from .types import DjangoObjectType
|
||||
|
||||
model = field.related_model
|
||||
|
||||
def dynamic_type():
|
||||
|
@ -265,7 +270,69 @@ def convert_onetoone_field_to_djangomodel(field, registry=None):
|
|||
if not _type:
|
||||
return
|
||||
|
||||
return Field(_type, required=not field.null)
|
||||
class CustomField(Field):
|
||||
def wrap_resolve(self, parent_resolver):
|
||||
"""
|
||||
Implements a custom resolver which go through the `get_node` method to insure that
|
||||
it goes through the `get_queryset` method of the DjangoObjectType.
|
||||
"""
|
||||
resolver = super().wrap_resolve(parent_resolver)
|
||||
|
||||
# If `get_queryset` was not overridden in the DjangoObjectType,
|
||||
# we can just return the default resolver.
|
||||
if (
|
||||
_type.get_queryset.__func__
|
||||
is DjangoObjectType.get_queryset.__func__
|
||||
):
|
||||
return resolver
|
||||
|
||||
def custom_resolver(root, info, **args):
|
||||
# Note: this function is used to resolve 1:1 relation fields
|
||||
# it does not differentiate between custom-resolved fields
|
||||
# and default resolved fields.
|
||||
|
||||
is_resolver_awaitable = inspect.iscoroutinefunction(resolver)
|
||||
|
||||
if is_resolver_awaitable:
|
||||
fk_obj = resolver(root, info, **args)
|
||||
# In case the resolver is a custom awaitable resolver that overwrites
|
||||
# the default Django resolver
|
||||
return fk_obj
|
||||
|
||||
object_pk = resolver(root, info, **args).pk
|
||||
instance_from_get_node = _type.get_node(info, object_pk)
|
||||
|
||||
if instance_from_get_node is None:
|
||||
# no instance to return
|
||||
return
|
||||
elif (
|
||||
isinstance(resolver, partial)
|
||||
and resolver.func is get_default_resolver()
|
||||
):
|
||||
return instance_from_get_node
|
||||
elif resolver is not get_default_resolver():
|
||||
# Default resolver is overridden
|
||||
# For optimization, add the instance to the resolver
|
||||
field_name = to_snake_case(info.field_name)
|
||||
setattr(root, field_name, instance_from_get_node)
|
||||
# Explanation:
|
||||
# previously, _type.get_node` is called which results in at least one hit to the database.
|
||||
# But, if we did not pass the instance to the root, calling the resolver will result in
|
||||
# another call to get the instance which results in at least two database queries in total
|
||||
# to resolve this node only.
|
||||
# That's why the value of the object is set in the root so when the object is accessed
|
||||
# in the resolver (root.field_name) it does not access the database unless queried explicitly.
|
||||
fk_obj = resolver(root, info, **args)
|
||||
return fk_obj
|
||||
else:
|
||||
return instance_from_get_node
|
||||
|
||||
return custom_resolver
|
||||
|
||||
return CustomField(
|
||||
_type,
|
||||
required=not field.null,
|
||||
)
|
||||
|
||||
return Dynamic(dynamic_type)
|
||||
|
||||
|
@ -313,6 +380,9 @@ def convert_field_to_list_or_connection(field, registry=None):
|
|||
@convert_django_field.register(models.OneToOneField)
|
||||
@convert_django_field.register(models.ForeignKey)
|
||||
def convert_field_to_djangomodel(field, registry=None):
|
||||
from graphene.utils.str_converters import to_snake_case
|
||||
from .types import DjangoObjectType
|
||||
|
||||
model = field.related_model
|
||||
|
||||
def dynamic_type():
|
||||
|
@ -320,7 +390,77 @@ def convert_field_to_djangomodel(field, registry=None):
|
|||
if not _type:
|
||||
return
|
||||
|
||||
return Field(
|
||||
class CustomField(Field):
|
||||
def wrap_resolve(self, parent_resolver):
|
||||
"""
|
||||
Implements a custom resolver which go through the `get_node` method to ensure that
|
||||
it goes through the `get_queryset` method of the DjangoObjectType.
|
||||
"""
|
||||
resolver = super().wrap_resolve(parent_resolver)
|
||||
|
||||
# If `get_queryset` was not overridden in the DjangoObjectType,
|
||||
# we can just return the default resolver.
|
||||
if (
|
||||
_type.get_queryset.__func__
|
||||
is DjangoObjectType.get_queryset.__func__
|
||||
):
|
||||
return resolver
|
||||
|
||||
def custom_resolver(root, info, **args):
|
||||
# Note: this function is used to resolve FK or 1:1 fields
|
||||
# it does not differentiate between custom-resolved fields
|
||||
# and default resolved fields.
|
||||
|
||||
# because this is a django foreign key or one-to-one field, the primary-key for
|
||||
# this node can be accessed from the root node.
|
||||
# ex: article.reporter_id
|
||||
|
||||
# get the name of the id field from the root's model
|
||||
field_name = to_snake_case(info.field_name)
|
||||
db_field_key = root.__class__._meta.get_field(field_name).attname
|
||||
if hasattr(root, db_field_key):
|
||||
# get the object's primary-key from root
|
||||
object_pk = getattr(root, db_field_key)
|
||||
else:
|
||||
return None
|
||||
|
||||
is_resolver_awaitable = inspect.iscoroutinefunction(resolver)
|
||||
|
||||
if is_resolver_awaitable:
|
||||
fk_obj = resolver(root, info, **args)
|
||||
# In case the resolver is a custom awaitable resolver that overwrites
|
||||
# the default Django resolver
|
||||
return fk_obj
|
||||
|
||||
instance_from_get_node = _type.get_node(info, object_pk)
|
||||
|
||||
if instance_from_get_node is None:
|
||||
# no instance to return
|
||||
return
|
||||
elif (
|
||||
isinstance(resolver, partial)
|
||||
and resolver.func is get_default_resolver()
|
||||
):
|
||||
return instance_from_get_node
|
||||
elif resolver is not get_default_resolver():
|
||||
# Default resolver is overridden
|
||||
# For optimization, add the instance to the resolver
|
||||
setattr(root, field_name, instance_from_get_node)
|
||||
# Explanation:
|
||||
# previously, _type.get_node` is called which results in at least one hit to the database.
|
||||
# But, if we did not pass the instance to the root, calling the resolver will result in
|
||||
# another call to get the instance which results in at least two database queries in total
|
||||
# to resolve this node only.
|
||||
# That's why the value of the object is set in the root so when the object is accessed
|
||||
# in the resolver (root.field_name) it does not access the database unless queried explicitly.
|
||||
fk_obj = resolver(root, info, **args)
|
||||
return fk_obj
|
||||
else:
|
||||
return instance_from_get_node
|
||||
|
||||
return custom_resolver
|
||||
|
||||
return CustomField(
|
||||
_type,
|
||||
description=get_django_field_description(field),
|
||||
required=not field.null,
|
||||
|
|
|
@ -8,7 +8,7 @@ from graphql_relay import to_global_id
|
|||
from ..fields import DjangoConnectionField
|
||||
from ..types import DjangoObjectType
|
||||
|
||||
from .models import Article, Reporter
|
||||
from .models import Article, Reporter, FilmDetails, Film
|
||||
|
||||
|
||||
class TestShouldCallGetQuerySetOnForeignKey:
|
||||
|
@ -127,6 +127,69 @@ class TestShouldCallGetQuerySetOnForeignKey:
|
|||
assert not result.errors
|
||||
assert result.data == {"reporter": {"firstName": "Jane"}}
|
||||
|
||||
def test_get_queryset_called_on_foreignkey(self):
|
||||
# If a user tries to access a reporter through an article they should get our authorization error
|
||||
query = """
|
||||
query getArticle($id: ID!) {
|
||||
article(id: $id) {
|
||||
headline
|
||||
reporter {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = self.schema.execute(query, variables={"id": self.articles[0].id})
|
||||
assert len(result.errors) == 1
|
||||
assert result.errors[0].message == "Not authorized to access reporters."
|
||||
|
||||
# An admin user should be able to get reporters through an article
|
||||
query = """
|
||||
query getArticle($id: ID!) {
|
||||
article(id: $id) {
|
||||
headline
|
||||
reporter {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = self.schema.execute(
|
||||
query,
|
||||
variables={"id": self.articles[0].id},
|
||||
context_value={"admin": True},
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data["article"] == {
|
||||
"headline": "A fantastic article",
|
||||
"reporter": {"firstName": "Jane"},
|
||||
}
|
||||
|
||||
# An admin user should not be able to access draft article through a reporter
|
||||
query = """
|
||||
query getReporter($id: ID!) {
|
||||
reporter(id: $id) {
|
||||
firstName
|
||||
articles {
|
||||
headline
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = self.schema.execute(
|
||||
query,
|
||||
variables={"id": self.reporter.id},
|
||||
context_value={"admin": True},
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data["reporter"] == {
|
||||
"firstName": "Jane",
|
||||
"articles": [{"headline": "A fantastic article"}],
|
||||
}
|
||||
|
||||
|
||||
class TestShouldCallGetQuerySetOnForeignKeyNode:
|
||||
"""
|
||||
|
@ -233,3 +296,268 @@ class TestShouldCallGetQuerySetOnForeignKeyNode:
|
|||
)
|
||||
assert not result.errors
|
||||
assert result.data == {"reporter": {"firstName": "Jane"}}
|
||||
|
||||
def test_get_queryset_called_on_foreignkey(self):
|
||||
# If a user tries to access a reporter through an article they should get our authorization error
|
||||
query = """
|
||||
query getArticle($id: ID!) {
|
||||
article(id: $id) {
|
||||
headline
|
||||
reporter {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = self.schema.execute(
|
||||
query, variables={"id": to_global_id("ArticleType", self.articles[0].id)}
|
||||
)
|
||||
assert len(result.errors) == 1
|
||||
assert result.errors[0].message == "Not authorized to access reporters."
|
||||
|
||||
# An admin user should be able to get reporters through an article
|
||||
query = """
|
||||
query getArticle($id: ID!) {
|
||||
article(id: $id) {
|
||||
headline
|
||||
reporter {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = self.schema.execute(
|
||||
query,
|
||||
variables={"id": to_global_id("ArticleType", self.articles[0].id)},
|
||||
context_value={"admin": True},
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data["article"] == {
|
||||
"headline": "A fantastic article",
|
||||
"reporter": {"firstName": "Jane"},
|
||||
}
|
||||
|
||||
# An admin user should not be able to access draft article through a reporter
|
||||
query = """
|
||||
query getReporter($id: ID!) {
|
||||
reporter(id: $id) {
|
||||
firstName
|
||||
articles {
|
||||
edges {
|
||||
node {
|
||||
headline
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = self.schema.execute(
|
||||
query,
|
||||
variables={"id": to_global_id("ReporterType", self.reporter.id)},
|
||||
context_value={"admin": True},
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data["reporter"] == {
|
||||
"firstName": "Jane",
|
||||
"articles": {"edges": [{"node": {"headline": "A fantastic article"}}]},
|
||||
}
|
||||
|
||||
|
||||
class TestShouldCallGetQuerySetOnOneToOne:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_schema(self):
|
||||
class FilmDetailsType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = FilmDetails
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls, queryset, info):
|
||||
if info.context and info.context.get("permission_get_film_details"):
|
||||
return queryset
|
||||
raise Exception("Not authorized to access film details.")
|
||||
|
||||
class FilmType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Film
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls, queryset, info):
|
||||
if info.context and info.context.get("permission_get_film"):
|
||||
return queryset
|
||||
raise Exception("Not authorized to access film.")
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
film_details = graphene.Field(
|
||||
FilmDetailsType, id=graphene.ID(required=True)
|
||||
)
|
||||
film = graphene.Field(FilmType, id=graphene.ID(required=True))
|
||||
|
||||
def resolve_film_details(self, info, id):
|
||||
return (
|
||||
FilmDetailsType.get_queryset(FilmDetails.objects, info)
|
||||
.filter(id=id)
|
||||
.last()
|
||||
)
|
||||
|
||||
def resolve_film(self, info, id):
|
||||
return FilmType.get_queryset(Film.objects, info).filter(id=id).last()
|
||||
|
||||
self.schema = graphene.Schema(query=Query)
|
||||
|
||||
self.films = [
|
||||
Film.objects.create(
|
||||
genre="do",
|
||||
),
|
||||
Film.objects.create(
|
||||
genre="ac",
|
||||
),
|
||||
]
|
||||
|
||||
self.film_details = [
|
||||
FilmDetails.objects.create(
|
||||
film=self.films[0],
|
||||
),
|
||||
FilmDetails.objects.create(
|
||||
film=self.films[1],
|
||||
),
|
||||
]
|
||||
|
||||
def test_get_queryset_called_on_field(self):
|
||||
# A user tries to access a film
|
||||
query = """
|
||||
query getFilm($id: ID!) {
|
||||
film(id: $id) {
|
||||
genre
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# With `permission_get_film`
|
||||
result = self.schema.execute(
|
||||
query,
|
||||
variables={"id": self.films[0].id},
|
||||
context_value={"permission_get_film": True},
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data["film"] == {
|
||||
"genre": "DO",
|
||||
}
|
||||
|
||||
# Without `permission_get_film`
|
||||
result = self.schema.execute(
|
||||
query,
|
||||
variables={"id": self.films[1].id},
|
||||
context_value={"permission_get_film": False},
|
||||
)
|
||||
assert len(result.errors) == 1
|
||||
assert result.errors[0].message == "Not authorized to access film."
|
||||
|
||||
# A user tries to access a film details
|
||||
query = """
|
||||
query getFilmDetails($id: ID!) {
|
||||
filmDetails(id: $id) {
|
||||
location
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# With `permission_get_film`
|
||||
result = self.schema.execute(
|
||||
query,
|
||||
variables={"id": self.film_details[0].id},
|
||||
context_value={"permission_get_film_details": True},
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data == {"filmDetails": {"location": ""}}
|
||||
|
||||
# Without `permission_get_film`
|
||||
result = self.schema.execute(
|
||||
query,
|
||||
variables={"id": self.film_details[0].id},
|
||||
context_value={"permission_get_film_details": False},
|
||||
)
|
||||
assert len(result.errors) == 1
|
||||
assert result.errors[0].message == "Not authorized to access film details."
|
||||
|
||||
def test_get_queryset_called_on_foreignkey(self):
|
||||
# A user tries to access a film details through a film
|
||||
query = """
|
||||
query getFilm($id: ID!) {
|
||||
film(id: $id) {
|
||||
genre
|
||||
details {
|
||||
location
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# With `permission_get_film_details`
|
||||
result = self.schema.execute(
|
||||
query,
|
||||
variables={"id": self.films[0].id},
|
||||
context_value={
|
||||
"permission_get_film": True,
|
||||
"permission_get_film_details": True,
|
||||
},
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data["film"] == {
|
||||
"genre": "DO",
|
||||
"details": {"location": ""},
|
||||
}
|
||||
|
||||
# Without `permission_get_film_details`
|
||||
result = self.schema.execute(
|
||||
query,
|
||||
variables={"id": self.films[0].id},
|
||||
context_value={
|
||||
"permission_get_film": True,
|
||||
"permission_get_film_details": False,
|
||||
},
|
||||
)
|
||||
assert len(result.errors) == 1
|
||||
assert result.errors[0].message == "Not authorized to access film details."
|
||||
|
||||
# A user tries to access a film through a film details
|
||||
query = """
|
||||
query getFilmDetails($id: ID!) {
|
||||
filmDetails(id: $id) {
|
||||
location
|
||||
film {
|
||||
genre
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# With `permission_get_film`
|
||||
result = self.schema.execute(
|
||||
query,
|
||||
variables={"id": self.film_details[0].id},
|
||||
context_value={
|
||||
"permission_get_film": True,
|
||||
"permission_get_film_details": True,
|
||||
},
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data["filmDetails"] == {
|
||||
"location": "",
|
||||
"film": {"genre": "DO"},
|
||||
}
|
||||
|
||||
# Without `permission_get_film`
|
||||
result = self.schema.execute(
|
||||
query,
|
||||
variables={"id": self.film_details[1].id},
|
||||
context_value={
|
||||
"permission_get_film": False,
|
||||
"permission_get_film_details": True,
|
||||
},
|
||||
)
|
||||
assert len(result.errors) == 1
|
||||
assert result.errors[0].message == "Not authorized to access film."
|
||||
|
|
Loading…
Reference in New Issue
Block a user