From 31321e1e413df9c151571a74630ca8f563d7f70f Mon Sep 17 00:00:00 2001 From: Thomas Leonard Date: Mon, 19 Oct 2020 19:17:28 +0200 Subject: [PATCH] Add support for custom global ID (Issue #1276) --- graphene/__init__.py | 8 + graphene/relay/__init__.py | 5 + graphene/relay/id_type.py | 74 +++++++ graphene/relay/node.py | 34 +-- graphene/relay/tests/test_custom_global_id.py | 194 ++++++++++++++++++ 5 files changed, 299 insertions(+), 16 deletions(-) create mode 100644 graphene/relay/id_type.py create mode 100644 graphene/relay/tests/test_custom_global_id.py diff --git a/graphene/__init__.py b/graphene/__init__.py index 9cbbc38f..6e2bf67b 100644 --- a/graphene/__init__.py +++ b/graphene/__init__.py @@ -38,6 +38,10 @@ from .relay import ( Connection, ConnectionField, PageInfo, + BaseGlobalIDType, + DefaultGlobalIDType, + SimpleGlobalIDType, + UUIDGlobalIDType, ) from .utils.resolve_only_args import resolve_only_args from .utils.module_loading import lazy_import @@ -85,6 +89,10 @@ __all__ = [ "lazy_import", "Context", "ResolveInfo", + "BaseGlobalIDType", + "DefaultGlobalIDType", + "SimpleGlobalIDType", + "UUIDGlobalIDType", # Deprecated "AbstractType", ] diff --git a/graphene/relay/__init__.py b/graphene/relay/__init__.py index 7238fa72..a26c72c2 100644 --- a/graphene/relay/__init__.py +++ b/graphene/relay/__init__.py @@ -1,6 +1,7 @@ from .node import Node, is_node, GlobalID from .mutation import ClientIDMutation from .connection import Connection, ConnectionField, PageInfo +from .id_type import BaseGlobalIDType, DefaultGlobalIDType, SimpleGlobalIDType, UUIDGlobalIDType __all__ = [ "Node", @@ -10,4 +11,8 @@ __all__ = [ "Connection", "ConnectionField", "PageInfo", + "BaseGlobalIDType", + "DefaultGlobalIDType", + "SimpleGlobalIDType", + "UUIDGlobalIDType", ] diff --git a/graphene/relay/id_type.py b/graphene/relay/id_type.py new file mode 100644 index 00000000..29aa16a4 --- /dev/null +++ b/graphene/relay/id_type.py @@ -0,0 +1,74 @@ +from graphql_relay import from_global_id, to_global_id + +from ..types import ID, UUID + + +class BaseGlobalIDType: + """ + Base class that define the required attributes/method for a type. + """ + + graphene_type = None + + @classmethod + def resolve_global_id(cls, info, global_id): + # return _type, _id + raise NotImplementedError + + @classmethod + def to_global_id(cls, _type, _id): + # return _id + raise NotImplementedError + + +class DefaultGlobalIDType(BaseGlobalIDType): + """ + Default global ID type: base64 encoded version of ": ". + """ + + graphene_type = ID + + @classmethod + def resolve_global_id(cls, info, global_id): + return from_global_id(global_id) + + @classmethod + def to_global_id(cls, _type, _id): + return to_global_id(_type, _id) + + +class SimpleGlobalIDType(BaseGlobalIDType): + """ + Simple global ID type: simply the id of the object. + To be used carefully as the user is responsible for ensuring that the IDs are indeed global + (otherwise it could cause request caching issues). + """ + + graphene_type = ID + + @classmethod + def resolve_global_id(cls, info, global_id): + _type = info.return_type.graphene_type._meta.name + return _type, global_id + + @classmethod + def to_global_id(cls, _type, _id): + return _id + + +class UUIDGlobalIDType(BaseGlobalIDType): + """ + UUID global ID type. + By definition UUID are global so they are used as they are. + """ + + graphene_type = UUID + + @classmethod + def resolve_global_id(cls, info, global_id): + _type = info.return_type.graphene_type._meta.name + return _type, global_id + + @classmethod + def to_global_id(cls, _type, _id): + return _id diff --git a/graphene/relay/node.py b/graphene/relay/node.py index d9c4c0f6..532d6baa 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -2,11 +2,10 @@ from collections import OrderedDict from functools import partial from inspect import isclass -from graphql_relay import from_global_id, to_global_id - -from ..types import ID, Field, Interface, ObjectType +from ..types import Field, Interface, ObjectType from ..types.interface import InterfaceOptions from ..types.utils import get_type +from .id_type import BaseGlobalIDType, DefaultGlobalIDType def is_node(objecttype): @@ -27,8 +26,8 @@ def is_node(objecttype): class GlobalID(Field): - def __init__(self, node=None, parent_type=None, required=True, *args, **kwargs): - super(GlobalID, self).__init__(ID, required=required, *args, **kwargs) + def __init__(self, node=None, parent_type=None, required=True, global_id_type=DefaultGlobalIDType, *args, **kwargs): + super(GlobalID, self).__init__(global_id_type.graphene_type, required=required, *args, **kwargs) self.node = node or Node self.parent_type_name = parent_type._meta.name if parent_type else None @@ -52,13 +51,13 @@ class NodeField(Field): assert issubclass(node, Node), "NodeField can only operate in Nodes" self.node_type = node self.field_type = type + global_id_type = node._meta.global_id_type super(NodeField, self).__init__( - # If we don's specify a type, the field type will be the node - # interface + # If we don't specify a type, the field type will be the node interface type or node, description="The ID of the object", - id=ID(required=True), + id=global_id_type.graphene_type(required=True), ) def get_resolver(self, parent_resolver): @@ -70,13 +69,20 @@ class AbstractNode(Interface): abstract = True @classmethod - def __init_subclass_with_meta__(cls, **options): + def __init_subclass_with_meta__(cls, global_id_type=DefaultGlobalIDType, **options): + assert issubclass(global_id_type, BaseGlobalIDType), \ + "Custom ID type need to be implemented as a subclass of BaseGlobalIDType." _meta = InterfaceOptions(cls) + _meta.global_id_type = global_id_type _meta.fields = OrderedDict( - id=GlobalID(cls, description="The ID of the object.") + id=GlobalID(cls, global_id_type=global_id_type, description="The ID of the object.") ) super(AbstractNode, cls).__init_subclass_with_meta__(_meta=_meta, **options) + @classmethod + def resolve_global_id(cls, info, global_id): + return cls._meta.global_id_type.resolve_global_id(info, global_id) + class Node(AbstractNode): """An object with an ID""" @@ -92,7 +98,7 @@ class Node(AbstractNode): @classmethod def get_node_from_global_id(cls, info, global_id, only_type=None): try: - _type, _id = cls.from_global_id(global_id) + _type, _id = cls.resolve_global_id(info, global_id) graphene_type = info.schema.get_type(_type).graphene_type except Exception: return None @@ -110,10 +116,6 @@ class Node(AbstractNode): if get_node: return get_node(info, _id) - @classmethod - def from_global_id(cls, global_id): - return from_global_id(global_id) - @classmethod def to_global_id(cls, type, id): - return to_global_id(type, id) + return cls._meta.global_id_type.to_global_id(type, id) diff --git a/graphene/relay/tests/test_custom_global_id.py b/graphene/relay/tests/test_custom_global_id.py new file mode 100644 index 00000000..1e58022f --- /dev/null +++ b/graphene/relay/tests/test_custom_global_id.py @@ -0,0 +1,194 @@ +import re +from uuid import uuid4 + +from graphql import graphql + +from ..connection import Connection, ConnectionField +from ..id_type import BaseGlobalIDType, SimpleGlobalIDType, UUIDGlobalIDType +from ..node import Node +from ...types import Int, ObjectType, Schema, String + + +class TestUUIDGlobalID: + def setup(self): + self.user_list = [ + {"id": uuid4(), "name": "First"}, + {"id": uuid4(), "name": "Second"}, + {"id": uuid4(), "name": "Third"}, + {"id": uuid4(), "name": "Fourth"}, + ] + self.users = {user["id"]: user for user in self.user_list} + + class CustomNode(Node): + class Meta: + global_id_type = UUIDGlobalIDType + + class User(ObjectType): + class Meta: + interfaces = [CustomNode] + + name = String() + + @classmethod + def get_node(cls, _type, _id): + return self.users[_id] + + class RootQuery(ObjectType): + user = CustomNode.Field(User) + + self.schema = Schema(query=RootQuery, types=[User]) + + def test_str_schema_correct(self): + """ + Check that the schema has the expected and custom node interface and user type and that they both use UUIDs + """ + parsed = re.findall(r"(.+) \{\n\s*([\w\W]*?)\n\}", str(self.schema)) + types = [t for t, f in parsed] + fields = [f for t, f in parsed] + custom_node_interface = "interface CustomNode" + assert custom_node_interface in types + assert "id: UUID!" == fields[types.index(custom_node_interface)] + user_type = "type User implements CustomNode" + assert user_type in types + assert "id: UUID!\n name: String" == fields[types.index(user_type)] + + def test_get_by_id(self): + query = """query userById($id: UUID!) { + user(id: $id) { + id + name + } + }""" + # UUID need to be converted to string for serialization + result = graphql(self.schema, query, variable_values={"id": str(self.user_list[0]["id"])}) + assert not result.errors + assert result.data["user"]["id"] == str(self.user_list[0]["id"]) + assert result.data["user"]["name"] == self.user_list[0]["name"] + + +class TestSimpleGlobalID: + def setup(self): + self.user_list = [ + {"id": "my global primary key in clear 1", "name": "First"}, + {"id": "my global primary key in clear 2", "name": "Second"}, + {"id": "my global primary key in clear 3", "name": "Third"}, + {"id": "my global primary key in clear 4", "name": "Fourth"}, + ] + self.users = {user["id"]: user for user in self.user_list} + + class CustomNode(Node): + class Meta: + global_id_type = SimpleGlobalIDType + + class User(ObjectType): + class Meta: + interfaces = [CustomNode] + + name = String() + + @classmethod + def get_node(cls, _type, _id): + return self.users[_id] + + class RootQuery(ObjectType): + user = CustomNode.Field(User) + + self.schema = Schema(query=RootQuery, types=[User]) + + def test_str_schema_correct(self): + """ + Check that the schema has the expected and custom node interface and user type and that they both use UUIDs + """ + parsed = re.findall(r"(.+) \{\n\s*([\w\W]*?)\n\}", str(self.schema)) + types = [t for t, f in parsed] + fields = [f for t, f in parsed] + custom_node_interface = "interface CustomNode" + assert custom_node_interface in types + assert "id: ID!" == fields[types.index(custom_node_interface)] + user_type = "type User implements CustomNode" + assert user_type in types + assert "id: ID!\n name: String" == fields[types.index(user_type)] + + def test_get_by_id(self): + query = """query { + user(id: "my global primary key in clear 3") { + id + name + } + }""" + result = graphql(self.schema, query) + assert not result.errors + assert result.data["user"]["id"] == self.user_list[2]["id"] + assert result.data["user"]["name"] == self.user_list[2]["name"] + + +class TestCustomGlobalID: + def setup(self): + self.user_list = [ + {"id": 1, "name": "First"}, + {"id": 2, "name": "Second"}, + {"id": 3, "name": "Third"}, + {"id": 4, "name": "Fourth"}, + ] + self.users = {user["id"]: user for user in self.user_list} + + class CustomGlobalIDType(BaseGlobalIDType): + """ + Global id that is simply and integer in clear. + """ + + graphene_type = Int + + @classmethod + def resolve_global_id(cls, info, global_id): + _type = info.return_type.graphene_type._meta.name + return _type, global_id + + @classmethod + def to_global_id(cls, _type, _id): + return _id + + class CustomNode(Node): + class Meta: + global_id_type = CustomGlobalIDType + + class User(ObjectType): + class Meta: + interfaces = [CustomNode] + + name = String() + + @classmethod + def get_node(cls, _type, _id): + return self.users[_id] + + class RootQuery(ObjectType): + user = CustomNode.Field(User) + + self.schema = Schema(query=RootQuery, types=[User]) + + def test_str_schema_correct(self): + """ + Check that the schema has the expected and custom node interface and user type and that they both use UUIDs + """ + parsed = re.findall(r"(.+) \{\n\s*([\w\W]*?)\n\}", str(self.schema)) + types = [t for t, f in parsed] + fields = [f for t, f in parsed] + custom_node_interface = "interface CustomNode" + assert custom_node_interface in types + assert "id: Int!" == fields[types.index(custom_node_interface)] + user_type = "type User implements CustomNode" + assert user_type in types + assert "id: Int!\n name: String" == fields[types.index(user_type)] + + def test_get_by_id(self): + query = """query { + user(id: 2) { + id + name + } + }""" + result = graphql(self.schema, query) + assert not result.errors + assert result.data["user"]["id"] == self.user_list[1]["id"] + assert result.data["user"]["name"] == self.user_list[1]["name"]