This commit is contained in:
Andy Neff 2017-09-20 14:51:41 +00:00 committed by GitHub
commit 5f0e7cfd06
2 changed files with 88 additions and 1 deletions

View File

@ -7,6 +7,8 @@ import datetime
import decimal import decimal
import json import json
import uuid import uuid
from json.encoder import (INFINITY, _make_iterencode,
encode_basestring, encode_basestring_ascii)
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.utils import six, timezone from django.utils import six, timezone
@ -15,12 +17,68 @@ from django.utils.functional import Promise
from rest_framework.compat import coreapi, total_seconds from rest_framework.compat import coreapi, total_seconds
try:
from json.encoder import FLOAT_REPR
except:
FLOAT_REPR = float.__repr__
class JSONEncoder(json.JSONEncoder): class JSONEncoder(json.JSONEncoder):
""" """
JSONEncoder subclass that knows how to encode date/time/timedelta, JSONEncoder subclass that knows how to encode date/time/timedelta,
decimal types, generators and other basic python objects. decimal types, generators and other basic python objects.
""" """
def iterencode(self, o, _one_shot=False):
"""Encode the given object and yield each string
representation as available.
For example::
for chunk in JSONEncoder().iterencode(bigobject):
mysocket.write(chunk)
"""
if self.check_circular:
markers = {}
else:
markers = None
if self.ensure_ascii:
_encoder = encode_basestring_ascii
else:
_encoder = encode_basestring
if six.PY2:
if self.encoding != 'utf-8':
def _encoder(o, _orig_encoder=_encoder,
_encoding=self.encoding):
if isinstance(o, str):
o = o.decode(_encoding)
return _orig_encoder(o)
def floatstr(o, allow_nan=self.allow_nan,
_repr=FLOAT_REPR, _inf=INFINITY, _neginf=-INFINITY):
# Check for specials. Note that this type of test is processor
# and/or platform-specific, so do tests which don't depend on the
# internals.
if o != o:
text = '"NaN"'
elif o == _inf:
text = '"Infinity"'
elif o == _neginf:
text = '"-Infinity"'
else:
return _repr(o)
if not allow_nan:
raise ValueError(
"Out of range float values are not JSON compliant: " +
repr(o))
return text
return _make_iterencode(markers, self.default, _encoder, self.indent,
floatstr, self.key_separator,
self.item_separator, self.sort_keys,
self.skipkeys, _one_shot)(o, 0)
def default(self, obj): def default(self, obj):
# For Date Time string spec, see ECMA 262 # For Date Time string spec, see ECMA 262
# http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15 # http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15
@ -55,7 +113,8 @@ class JSONEncoder(json.JSONEncoder):
elif hasattr(obj, 'tolist'): elif hasattr(obj, 'tolist'):
# Numpy arrays and array scalars. # Numpy arrays and array scalars.
return obj.tolist() return obj.tolist()
elif (coreapi is not None) and isinstance(obj, (coreapi.Document, coreapi.Error)): elif (coreapi is not None) and isinstance(obj, (coreapi.Document,
coreapi.Error)):
raise RuntimeError( raise RuntimeError(
'Cannot return a coreapi object from a JSON view. ' 'Cannot return a coreapi object from a JSON view. '
'You should be using a schema renderer instead for this view.' 'You should be using a schema renderer instead for this view.'

View File

@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
from decimal import Decimal from decimal import Decimal
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from django.test import TestCase from django.test import TestCase
from django.utils import six
from django.utils.timezone import utc from django.utils.timezone import utc
from rest_framework.compat import coreapi from rest_framework.compat import coreapi
@ -92,3 +95,28 @@ class JSONEncoderTests(TestCase):
""" """
foo = MockList() foo = MockList()
assert self.encoder.default(foo) == [1, 2, 3] assert self.encoder.default(foo) == [1, 2, 3]
def test_encode_float(self):
"""
Tests encoding floats with special values
"""
f = [3.141592653, float('inf'), float('-inf'), float('nan')]
assert self.encoder.encode(f) == '[3.141592653, "Infinity", "-Infinity", "NaN"]'
encoder = JSONEncoder(allow_nan=False)
try:
encoder.encode(f)
except ValueError:
pass
else:
assert False
def test_encode_string(self):
"""
Tests encoding string
"""
if six.PY2:
encoder2 = JSONEncoder(encoding='latin_1', check_circular=False)
assert encoder2.encode(['foo☺']) == '["foo\\u00e2\\u0098\\u00ba"]'