diff --git a/graphene/types/geo.py b/graphene/types/geo.py new file mode 100644 index 00000000..29bb3109 --- /dev/null +++ b/graphene/types/geo.py @@ -0,0 +1,60 @@ +from __future__ import absolute_import +from collections import Iterable + +from graphql.language import ast + +from .scalars import Scalar + +try: + import shapely + from shapely import geometry + from shapely.wkt import loads +except: + raise ImportError( + "shapely package is required for Graphene geo.\n" + "You can install it using: pip install shapely." + ) + + +def get_coords(geom): + if isinstance(geom, Iterable): + return geom + return geom.coords + + +class Point(Scalar): + ''' + The `Point` scalar type represents a Point + value as specified by + [iso8601](https://en.wikipedia.org/wiki/ISO_8601). + ''' + + @staticmethod + def serialize(point): + coords = get_coords(point) + assert coords is not None, ( + 'Received not compatible Point "{}"'.format(repr(point)) + ) + if coords: + if isinstance(coords[0], Iterable): + return list(coords[0]) + return coords + return [] + + @classmethod + def parse_literal(cls, node): + if isinstance(node, ast.StringValue): + loaded = loads(node.value) + if isinstance(loaded, geometry.Point): + return loaded + + if isinstance(node, ast.ListValue): + inner_values = [float(v.value) for v in node.values] + return geometry.Point(*inner_values) + + @staticmethod + def parse_value(coords): + if not isinstance(coords, Iterable): + raise Exception("Received incompatible value for Point") + + return geometry.Point(*coords) diff --git a/graphene/types/tests/test_geo.py b/graphene/types/tests/test_geo.py new file mode 100644 index 00000000..c61ccf8a --- /dev/null +++ b/graphene/types/tests/test_geo.py @@ -0,0 +1,41 @@ +import pytest + +from ..geo import Point +from ..objecttype import ObjectType +from ..schema import Schema + + +class Query(ObjectType): + point = Point(input=Point()) + point_list = Point() + + def resolve_point(self, args, context, info): + input = args.get('input') + return input + + def resolve_point_list(self, args, context, info): + return [1, 2] + +schema = Schema(query=Query) + + +@pytest.mark.parametrize("input,expected", [ + ("[1,2]", [1,2]), + ("[1,2,3]", [1,2,3]), + ("[]", []), + (""" "POINT (1 2)" """, [1,2]), +]) +def test_point_query(input, expected): + result = schema.execute('''{ point(input: %s) }'''%(input)) + assert not result.errors + assert result.data == { + 'point': expected + } + + +def test_point_list_query(): + result = schema.execute('''{ pointList }''') + assert not result.errors + assert result.data == { + 'pointList': [1, 2] + }