diff --git a/examples/flask_sqlalchemy/app.py b/examples/flask_sqlalchemy/app.py index 0bf2700f..afec4911 100644 --- a/examples/flask_sqlalchemy/app.py +++ b/examples/flask_sqlalchemy/app.py @@ -12,9 +12,14 @@ default_query = ''' allEmployees { edges { node { - id - name + id, + name, department { + id, + name + }, + role { + id, name } } diff --git a/examples/flask_sqlalchemy/database.py b/examples/flask_sqlalchemy/database.py index b2a51789..ca4d4122 100644 --- a/examples/flask_sqlalchemy/database.py +++ b/examples/flask_sqlalchemy/database.py @@ -14,7 +14,7 @@ def init_db(): # import all modules here that might define models so that # they will be registered properly on the metadata. Otherwise # you will have to import them first before calling init_db() - from models import Department, Employee + from models import Department, Employee, Role Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) @@ -24,10 +24,15 @@ def init_db(): hr = Department(name='Human Resources') db_session.add(hr) - peter = Employee(name='Peter', department=engineering) + manager = Role(name='manager') + db_session.add(manager) + engineer = Role(name='engineer') + db_session.add(engineer) + + peter = Employee(name='Peter', department=engineering, role=engineer) db_session.add(peter) - roy = Employee(name='Roy', department=engineering) + roy = Employee(name='Roy', department=engineering, role=engineer) db_session.add(roy) - tracy = Employee(name='Tracy', department=hr) + tracy = Employee(name='Tracy', department=hr, role=manager) db_session.add(tracy) db_session.commit() diff --git a/examples/flask_sqlalchemy/models.py b/examples/flask_sqlalchemy/models.py index 0fffb51d..119aca02 100644 --- a/examples/flask_sqlalchemy/models.py +++ b/examples/flask_sqlalchemy/models.py @@ -10,6 +10,12 @@ class Department(Base): name = Column(String) +class Role(Base): + __tablename__ = 'roles' + role_id = Column(Integer, primary_key=True) + name = Column(String) + + class Employee(Base): __tablename__ = 'employee' id = Column(Integer, primary_key=True) @@ -19,9 +25,15 @@ class Employee(Base): # Employee record was created hired_on = Column(DateTime, default=func.now()) department_id = Column(Integer, ForeignKey('department.id')) + role_id = Column(Integer, ForeignKey('roles.role_id')) # Use cascade='delete,all' to propagate the deletion of a Department onto its Employees department = relationship( Department, backref=backref('employees', uselist=True, cascade='delete,all')) + role = relationship( + Role, + backref=backref('roles', + uselist=True, + cascade='delete,all')) diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index d0de90f6..e880cd59 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -4,6 +4,7 @@ from graphene.contrib.sqlalchemy import (SQLAlchemyConnectionField, SQLAlchemyNode) from models import Department as DepartmentModel from models import Employee as EmployeeModel +from models import Role as RoleModel schema = graphene.Schema() @@ -22,8 +23,18 @@ class Employee(SQLAlchemyNode): model = EmployeeModel +@schema.register +class Role(SQLAlchemyNode): + + class Meta: + model = RoleModel + identifier = 'role_id' + + class Query(graphene.ObjectType): - node = relay.NodeField() + node = relay.NodeField(Employee) all_employees = SQLAlchemyConnectionField(Employee) + all_roles = SQLAlchemyConnectionField(Role) + role = relay.NodeField(Role) schema.query = Query diff --git a/graphene/contrib/sqlalchemy/options.py b/graphene/contrib/sqlalchemy/options.py index 1d4b2a4f..44886287 100644 --- a/graphene/contrib/sqlalchemy/options.py +++ b/graphene/contrib/sqlalchemy/options.py @@ -2,7 +2,7 @@ from ...core.classtypes.objecttype import ObjectTypeOptions from ...relay.types import Node from ...relay.utils import is_node -VALID_ATTRS = ('model', 'only_fields', 'exclude_fields') +VALID_ATTRS = ('model', 'only_fields', 'exclude_fields', 'identifier') class SQLAlchemyOptions(ObjectTypeOptions): @@ -10,6 +10,7 @@ class SQLAlchemyOptions(ObjectTypeOptions): def __init__(self, *args, **kwargs): super(SQLAlchemyOptions, self).__init__(*args, **kwargs) self.model = None + self.identifier = "id" self.valid_attrs += VALID_ATTRS self.only_fields = None self.exclude_fields = [] diff --git a/graphene/contrib/sqlalchemy/tests/models.py b/graphene/contrib/sqlalchemy/tests/models.py index ee021054..40f95e59 100644 --- a/graphene/contrib/sqlalchemy/tests/models.py +++ b/graphene/contrib/sqlalchemy/tests/models.py @@ -11,6 +11,12 @@ association_table = Table('association', Base.metadata, Column('reporter_id', Integer, ForeignKey('reporters.id'))) +class Editor(Base): + __tablename__ = 'editors' + editor_id = Column(Integer(), primary_key=True) + name = Column(String(100)) + + class Pet(Base): __tablename__ = 'pets' id = Column(Integer(), primary_key=True) diff --git a/graphene/contrib/sqlalchemy/tests/test_query.py b/graphene/contrib/sqlalchemy/tests/test_query.py index 5f970488..ce66d196 100644 --- a/graphene/contrib/sqlalchemy/tests/test_query.py +++ b/graphene/contrib/sqlalchemy/tests/test_query.py @@ -7,7 +7,7 @@ from graphene import relay from graphene.contrib.sqlalchemy import (SQLAlchemyConnectionField, SQLAlchemyNode, SQLAlchemyObjectType) -from .models import Article, Base, Reporter +from .models import Article, Base, Reporter, Editor db = create_engine('sqlite:///test_sqlalchemy.sqlite3') @@ -37,6 +37,8 @@ def setup_fixtures(session): session.add(reporter2) article = Article(headline='Hi!') session.add(article) + editor = Editor(name="John") + session.add(editor) session.commit() @@ -187,3 +189,51 @@ def test_should_node(session): result = schema.execute(query) assert not result.errors assert result.data == expected + + +def test_should_custom_identifier(session): + setup_fixtures(session) + + class EditorNode(SQLAlchemyNode): + + class Meta: + model = Editor + identifier = "editor_id" + + class Query(graphene.ObjectType): + node = relay.NodeField(EditorNode) + all_editors = SQLAlchemyConnectionField(EditorNode) + + query = ''' + query EditorQuery { + allEditors { + edges { + node { + id, + name + } + } + }, + node(id: "RWRpdG9yTm9kZTox") { + name + } + } + ''' + expected = { + 'allEditors': { + 'edges': [{ + 'node': { + 'id': 'RWRpdG9yTm9kZTox', + 'name': 'John' + } + }] + }, + 'node': { + 'name': 'John' + } + } + + schema = graphene.Schema(query=Query, session=session) + result = schema.execute(query) + assert not result.errors + assert result.data == expected diff --git a/graphene/contrib/sqlalchemy/types.py b/graphene/contrib/sqlalchemy/types.py index 64b09afd..f466a1af 100644 --- a/graphene/contrib/sqlalchemy/types.py +++ b/graphene/contrib/sqlalchemy/types.py @@ -109,12 +109,17 @@ class SQLAlchemyNode(six.with_metaclass( class Meta: abstract = True + def to_global_id(self): + id_ = getattr(self.instance, self._meta.identifier) + return self.global_id(id_) + @classmethod def get_node(cls, id, info=None): try: model = cls._meta.model + identifier = cls._meta.identifier query = get_query(model, info) - instance = query.filter(model.id == id).one() + instance = query.filter(getattr(model, identifier) == id).one() return cls(instance) except NoResultFound: return None