From ee1ff975d71f6590eb6933d76d12054c9839774a Mon Sep 17 00:00:00 2001 From: Thomas Leonard <64223923+tcleonard@users.noreply.github.com> Date: Mon, 19 Sep 2022 10:17:31 +0200 Subject: [PATCH] feat: Add support for custom global (Issue #1276) (#1428) Co-authored-by: Thomas Leonard --- .github/workflows/tests.yml | 2 +- Makefile | 1 + graphene/__init__.py | 10 +- graphene/relay/__init__.py | 16 +- graphene/relay/id_type.py | 87 +++++ graphene/relay/node.py | 60 ++-- graphene/relay/tests/test_custom_global_id.py | 325 ++++++++++++++++++ graphene/relay/tests/test_node.py | 1 + 8 files changed, 472 insertions(+), 30 deletions(-) create mode 100644 graphene/relay/id_type.py create mode 100644 graphene/relay/tests/test_custom_global_id.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 51832084..9df18f99 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -58,7 +58,7 @@ jobs: if: ${{ matrix.python == '3.10' }} uses: actions/upload-artifact@v3 with: - name: graphene-sqlalchemy-coverage + name: graphene-coverage path: coverage.xml if-no-files-found: error - name: Upload coverage.xml to codecov diff --git a/Makefile b/Makefile index c78e2b4f..08947707 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,7 @@ help: install-dev: pip install -e ".[dev]" +.PHONY: test ## Run tests test: py.test graphene examples diff --git a/graphene/__init__.py b/graphene/__init__.py index aeb6d6d2..af83f059 100644 --- a/graphene/__init__.py +++ b/graphene/__init__.py @@ -1,11 +1,15 @@ from .pyutils.version import get_version from .relay import ( + BaseGlobalIDType, ClientIDMutation, Connection, ConnectionField, + DefaultGlobalIDType, GlobalID, Node, PageInfo, + SimpleGlobalIDType, + UUIDGlobalIDType, is_node, ) from .types import ( @@ -52,6 +56,7 @@ __all__ = [ "Argument", "Base64", "BigInt", + "BaseGlobalIDType", "Boolean", "ClientIDMutation", "Connection", @@ -60,6 +65,7 @@ __all__ = [ "Date", "DateTime", "Decimal", + "DefaultGlobalIDType", "Dynamic", "Enum", "Field", @@ -80,10 +86,12 @@ __all__ = [ "ResolveInfo", "Scalar", "Schema", + "SimpleGlobalIDType", "String", "Time", - "UUID", "Union", + "UUID", + "UUIDGlobalIDType", "is_node", "lazy_import", "resolve_only_args", diff --git a/graphene/relay/__init__.py b/graphene/relay/__init__.py index 7238fa72..3b842cf5 100644 --- a/graphene/relay/__init__.py +++ b/graphene/relay/__init__.py @@ -1,13 +1,23 @@ from .node import Node, is_node, GlobalID from .mutation import ClientIDMutation from .connection import Connection, ConnectionField, PageInfo +from .id_type import ( + BaseGlobalIDType, + DefaultGlobalIDType, + SimpleGlobalIDType, + UUIDGlobalIDType, +) __all__ = [ - "Node", - "is_node", - "GlobalID", + "BaseGlobalIDType", "ClientIDMutation", "Connection", "ConnectionField", + "DefaultGlobalIDType", + "GlobalID", + "Node", "PageInfo", + "SimpleGlobalIDType", + "UUIDGlobalIDType", + "is_node", ] diff --git a/graphene/relay/id_type.py b/graphene/relay/id_type.py new file mode 100644 index 00000000..fb5c30e7 --- /dev/null +++ b/graphene/relay/id_type.py @@ -0,0 +1,87 @@ +from graphql_relay import from_global_id, to_global_id + +from ..types import ID, UUID +from ..types.base import BaseType + +from typing import Type + + +class BaseGlobalIDType: + """ + Base class that define the required attributes/method for a type. + """ + + graphene_type = ID # type: Type[BaseType] + + @classmethod + def resolve_global_id(cls, info, global_id): + # return _type, _id + raise NotImplementedError + + @classmethod + def to_global_id(cls, _type, _id): + # return _id + raise NotImplementedError + + +class DefaultGlobalIDType(BaseGlobalIDType): + """ + Default global ID type: base64 encoded version of ": ". + """ + + graphene_type = ID + + @classmethod + def resolve_global_id(cls, info, global_id): + try: + _type, _id = from_global_id(global_id) + if not _type: + raise ValueError("Invalid Global ID") + return _type, _id + except Exception as e: + raise Exception( + f'Unable to parse global ID "{global_id}". ' + 'Make sure it is a base64 encoded string in the format: "TypeName:id". ' + f"Exception message: {e}" + ) + + @classmethod + def to_global_id(cls, _type, _id): + return to_global_id(_type, _id) + + +class SimpleGlobalIDType(BaseGlobalIDType): + """ + Simple global ID type: simply the id of the object. + To be used carefully as the user is responsible for ensuring that the IDs are indeed global + (otherwise it could cause request caching issues). + """ + + graphene_type = ID + + @classmethod + def resolve_global_id(cls, info, global_id): + _type = info.return_type.graphene_type._meta.name + return _type, global_id + + @classmethod + def to_global_id(cls, _type, _id): + return _id + + +class UUIDGlobalIDType(BaseGlobalIDType): + """ + UUID global ID type. + By definition UUID are global so they are used as they are. + """ + + graphene_type = UUID + + @classmethod + def resolve_global_id(cls, info, global_id): + _type = info.return_type.graphene_type._meta.name + return _type, global_id + + @classmethod + def to_global_id(cls, _type, _id): + return _id diff --git a/graphene/relay/node.py b/graphene/relay/node.py index dabcff6c..54438281 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -1,11 +1,10 @@ from functools import partial from inspect import isclass -from graphql_relay import from_global_id, to_global_id - -from ..types import ID, Field, Interface, ObjectType +from ..types import Field, Interface, ObjectType from ..types.interface import InterfaceOptions from ..types.utils import get_type +from .id_type import BaseGlobalIDType, DefaultGlobalIDType def is_node(objecttype): @@ -22,8 +21,18 @@ def is_node(objecttype): class GlobalID(Field): - def __init__(self, node=None, parent_type=None, required=True, *args, **kwargs): - super(GlobalID, self).__init__(ID, required=required, *args, **kwargs) + def __init__( + self, + node=None, + parent_type=None, + required=True, + global_id_type=DefaultGlobalIDType, + *args, + **kwargs, + ): + super(GlobalID, self).__init__( + global_id_type.graphene_type, required=required, *args, **kwargs + ) self.node = node or Node self.parent_type_name = parent_type._meta.name if parent_type else None @@ -47,12 +56,14 @@ class NodeField(Field): assert issubclass(node, Node), "NodeField can only operate in Nodes" self.node_type = node self.field_type = type_ + global_id_type = node._meta.global_id_type super(NodeField, self).__init__( - # If we don's specify a type, the field type will be the node - # interface + # If we don't specify a type, the field type will be the node interface type_ or node, - id=ID(required=True, description="The ID of the object"), + id=global_id_type.graphene_type( + required=True, description="The ID of the object" + ), **kwargs, ) @@ -65,11 +76,23 @@ class AbstractNode(Interface): abstract = True @classmethod - def __init_subclass_with_meta__(cls, **options): + def __init_subclass_with_meta__(cls, global_id_type=DefaultGlobalIDType, **options): + assert issubclass( + global_id_type, BaseGlobalIDType + ), "Custom ID type need to be implemented as a subclass of BaseGlobalIDType." _meta = InterfaceOptions(cls) - _meta.fields = {"id": GlobalID(cls, description="The ID of the object")} + _meta.global_id_type = global_id_type + _meta.fields = { + "id": GlobalID( + cls, global_id_type=global_id_type, description="The ID of the object" + ) + } super(AbstractNode, cls).__init_subclass_with_meta__(_meta=_meta, **options) + @classmethod + def resolve_global_id(cls, info, global_id): + return cls._meta.global_id_type.resolve_global_id(info, global_id) + class Node(AbstractNode): """An object with an ID""" @@ -84,16 +107,7 @@ class Node(AbstractNode): @classmethod def get_node_from_global_id(cls, info, global_id, only_type=None): - try: - _type, _id = cls.from_global_id(global_id) - if not _type: - raise ValueError("Invalid Global ID") - except Exception as e: - raise Exception( - f'Unable to parse global ID "{global_id}". ' - 'Make sure it is a base64 encoded string in the format: "TypeName:id". ' - f"Exception message: {e}" - ) + _type, _id = cls.resolve_global_id(info, global_id) graphene_type = info.schema.get_type(_type) if graphene_type is None: @@ -116,10 +130,6 @@ class Node(AbstractNode): if get_node: return get_node(info, _id) - @classmethod - def from_global_id(cls, global_id): - return from_global_id(global_id) - @classmethod def to_global_id(cls, type_, id): - return to_global_id(type_, id) + return cls._meta.global_id_type.to_global_id(type_, id) diff --git a/graphene/relay/tests/test_custom_global_id.py b/graphene/relay/tests/test_custom_global_id.py new file mode 100644 index 00000000..c1bf0fb4 --- /dev/null +++ b/graphene/relay/tests/test_custom_global_id.py @@ -0,0 +1,325 @@ +import re +from uuid import uuid4 + +from graphql import graphql_sync + +from ..id_type import BaseGlobalIDType, SimpleGlobalIDType, UUIDGlobalIDType +from ..node import Node +from ...types import Int, ObjectType, Schema, String + + +class TestUUIDGlobalID: + def setup(self): + self.user_list = [ + {"id": uuid4(), "name": "First"}, + {"id": uuid4(), "name": "Second"}, + {"id": uuid4(), "name": "Third"}, + {"id": uuid4(), "name": "Fourth"}, + ] + self.users = {user["id"]: user for user in self.user_list} + + class CustomNode(Node): + class Meta: + global_id_type = UUIDGlobalIDType + + class User(ObjectType): + class Meta: + interfaces = [CustomNode] + + name = String() + + @classmethod + def get_node(cls, _type, _id): + return self.users[_id] + + class RootQuery(ObjectType): + user = CustomNode.Field(User) + + self.schema = Schema(query=RootQuery, types=[User]) + self.graphql_schema = self.schema.graphql_schema + + def test_str_schema_correct(self): + """ + Check that the schema has the expected and custom node interface and user type and that they both use UUIDs + """ + parsed = re.findall(r"(.+) \{\n\s*([\w\W]*?)\n\}", str(self.schema)) + types = [t for t, f in parsed] + fields = [f for t, f in parsed] + custom_node_interface = "interface CustomNode" + assert custom_node_interface in types + assert ( + '"""The ID of the object"""\n id: UUID!' + == fields[types.index(custom_node_interface)] + ) + user_type = "type User implements CustomNode" + assert user_type in types + assert ( + '"""The ID of the object"""\n id: UUID!\n name: String' + == fields[types.index(user_type)] + ) + + def test_get_by_id(self): + query = """query userById($id: UUID!) { + user(id: $id) { + id + name + } + }""" + # UUID need to be converted to string for serialization + result = graphql_sync( + self.graphql_schema, + query, + variable_values={"id": str(self.user_list[0]["id"])}, + ) + assert not result.errors + assert result.data["user"]["id"] == str(self.user_list[0]["id"]) + assert result.data["user"]["name"] == self.user_list[0]["name"] + + +class TestSimpleGlobalID: + def setup(self): + self.user_list = [ + {"id": "my global primary key in clear 1", "name": "First"}, + {"id": "my global primary key in clear 2", "name": "Second"}, + {"id": "my global primary key in clear 3", "name": "Third"}, + {"id": "my global primary key in clear 4", "name": "Fourth"}, + ] + self.users = {user["id"]: user for user in self.user_list} + + class CustomNode(Node): + class Meta: + global_id_type = SimpleGlobalIDType + + class User(ObjectType): + class Meta: + interfaces = [CustomNode] + + name = String() + + @classmethod + def get_node(cls, _type, _id): + return self.users[_id] + + class RootQuery(ObjectType): + user = CustomNode.Field(User) + + self.schema = Schema(query=RootQuery, types=[User]) + self.graphql_schema = self.schema.graphql_schema + + def test_str_schema_correct(self): + """ + Check that the schema has the expected and custom node interface and user type and that they both use UUIDs + """ + parsed = re.findall(r"(.+) \{\n\s*([\w\W]*?)\n\}", str(self.schema)) + types = [t for t, f in parsed] + fields = [f for t, f in parsed] + custom_node_interface = "interface CustomNode" + assert custom_node_interface in types + assert ( + '"""The ID of the object"""\n id: ID!' + == fields[types.index(custom_node_interface)] + ) + user_type = "type User implements CustomNode" + assert user_type in types + assert ( + '"""The ID of the object"""\n id: ID!\n name: String' + == fields[types.index(user_type)] + ) + + def test_get_by_id(self): + query = """query { + user(id: "my global primary key in clear 3") { + id + name + } + }""" + result = graphql_sync(self.graphql_schema, query) + assert not result.errors + assert result.data["user"]["id"] == self.user_list[2]["id"] + assert result.data["user"]["name"] == self.user_list[2]["name"] + + +class TestCustomGlobalID: + def setup(self): + self.user_list = [ + {"id": 1, "name": "First"}, + {"id": 2, "name": "Second"}, + {"id": 3, "name": "Third"}, + {"id": 4, "name": "Fourth"}, + ] + self.users = {user["id"]: user for user in self.user_list} + + class CustomGlobalIDType(BaseGlobalIDType): + """ + Global id that is simply and integer in clear. + """ + + graphene_type = Int + + @classmethod + def resolve_global_id(cls, info, global_id): + _type = info.return_type.graphene_type._meta.name + return _type, global_id + + @classmethod + def to_global_id(cls, _type, _id): + return _id + + class CustomNode(Node): + class Meta: + global_id_type = CustomGlobalIDType + + class User(ObjectType): + class Meta: + interfaces = [CustomNode] + + name = String() + + @classmethod + def get_node(cls, _type, _id): + return self.users[_id] + + class RootQuery(ObjectType): + user = CustomNode.Field(User) + + self.schema = Schema(query=RootQuery, types=[User]) + self.graphql_schema = self.schema.graphql_schema + + def test_str_schema_correct(self): + """ + Check that the schema has the expected and custom node interface and user type and that they both use UUIDs + """ + parsed = re.findall(r"(.+) \{\n\s*([\w\W]*?)\n\}", str(self.schema)) + types = [t for t, f in parsed] + fields = [f for t, f in parsed] + custom_node_interface = "interface CustomNode" + assert custom_node_interface in types + assert ( + '"""The ID of the object"""\n id: Int!' + == fields[types.index(custom_node_interface)] + ) + user_type = "type User implements CustomNode" + assert user_type in types + assert ( + '"""The ID of the object"""\n id: Int!\n name: String' + == fields[types.index(user_type)] + ) + + def test_get_by_id(self): + query = """query { + user(id: 2) { + id + name + } + }""" + result = graphql_sync(self.graphql_schema, query) + assert not result.errors + assert result.data["user"]["id"] == self.user_list[1]["id"] + assert result.data["user"]["name"] == self.user_list[1]["name"] + + +class TestIncompleteCustomGlobalID: + def setup(self): + self.user_list = [ + {"id": 1, "name": "First"}, + {"id": 2, "name": "Second"}, + {"id": 3, "name": "Third"}, + {"id": 4, "name": "Fourth"}, + ] + self.users = {user["id"]: user for user in self.user_list} + + def test_must_define_to_global_id(self): + """ + Test that if the `to_global_id` method is not defined, we can query the object, but we can't request its ID. + """ + + class CustomGlobalIDType(BaseGlobalIDType): + graphene_type = Int + + @classmethod + def resolve_global_id(cls, info, global_id): + _type = info.return_type.graphene_type._meta.name + return _type, global_id + + class CustomNode(Node): + class Meta: + global_id_type = CustomGlobalIDType + + class User(ObjectType): + class Meta: + interfaces = [CustomNode] + + name = String() + + @classmethod + def get_node(cls, _type, _id): + return self.users[_id] + + class RootQuery(ObjectType): + user = CustomNode.Field(User) + + self.schema = Schema(query=RootQuery, types=[User]) + self.graphql_schema = self.schema.graphql_schema + + query = """query { + user(id: 2) { + name + } + }""" + result = graphql_sync(self.graphql_schema, query) + assert not result.errors + assert result.data["user"]["name"] == self.user_list[1]["name"] + + query = """query { + user(id: 2) { + id + name + } + }""" + result = graphql_sync(self.graphql_schema, query) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].path == ["user", "id"] + + def test_must_define_resolve_global_id(self): + """ + Test that if the `resolve_global_id` method is not defined, we can't query the object by ID. + """ + + class CustomGlobalIDType(BaseGlobalIDType): + graphene_type = Int + + @classmethod + def to_global_id(cls, _type, _id): + return _id + + class CustomNode(Node): + class Meta: + global_id_type = CustomGlobalIDType + + class User(ObjectType): + class Meta: + interfaces = [CustomNode] + + name = String() + + @classmethod + def get_node(cls, _type, _id): + return self.users[_id] + + class RootQuery(ObjectType): + user = CustomNode.Field(User) + + self.schema = Schema(query=RootQuery, types=[User]) + self.graphql_schema = self.schema.graphql_schema + + query = """query { + user(id: 2) { + id + name + } + }""" + result = graphql_sync(self.graphql_schema, query) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].path == ["user"] diff --git a/graphene/relay/tests/test_node.py b/graphene/relay/tests/test_node.py index 6b310fde..e7564566 100644 --- a/graphene/relay/tests/test_node.py +++ b/graphene/relay/tests/test_node.py @@ -55,6 +55,7 @@ def test_node_good(): assert "id" in MyNode._meta.fields assert is_node(MyNode) assert not is_node(object) + assert not is_node("node") def test_node_query():