pep8 Tests/tester.py

This commit is contained in:
hugovk 2014-04-08 17:17:10 +03:00
parent ed00b2e6d2
commit 5afdd6cb54

View File

@ -11,10 +11,8 @@ except NameError:
# we expect a NameError on py2.x, since it doesn't have ResourceWarnings. # we expect a NameError on py2.x, since it doesn't have ResourceWarnings.
pass pass
import sys import sys
py3 = (sys.version_info >= (3,0)) py3 = (sys.version_info >= (3, 0))
# some test helpers # some test helpers
@ -22,6 +20,7 @@ _target = None
_tempfiles = [] _tempfiles = []
_logfile = None _logfile = None
def success(): def success():
import sys import sys
success.count += 1 success.count += 1
@ -29,8 +28,10 @@ def success():
print(sys.argv[0], success.count, failure.count, file=_logfile) print(sys.argv[0], success.count, failure.count, file=_logfile)
return True return True
def failure(msg=None, frame=None): def failure(msg=None, frame=None):
import sys, linecache import sys
import linecache
failure.count += 1 failure.count += 1
if _target: if _target:
if frame is None: if frame is None:
@ -49,6 +50,7 @@ def failure(msg=None, frame=None):
success.count = failure.count = 0 success.count = failure.count = 0
# predicates # predicates
def assert_true(v, msg=None): def assert_true(v, msg=None):
@ -57,35 +59,39 @@ def assert_true(v, msg=None):
else: else:
failure(msg or "got %r, expected true value" % v) failure(msg or "got %r, expected true value" % v)
def assert_false(v, msg=None): def assert_false(v, msg=None):
if v: if v:
failure(msg or "got %r, expected false value" % v) failure(msg or "got %r, expected false value" % v)
else: else:
success() success()
def assert_equal(a, b, msg=None): def assert_equal(a, b, msg=None):
if a == b: if a == b:
success() success()
else: else:
failure(msg or "got %r, expected %r" % (a, b)) failure(msg or "got %r, expected %r" % (a, b))
def assert_almost_equal(a, b, msg=None, eps=1e-6): def assert_almost_equal(a, b, msg=None, eps=1e-6):
if abs(a-b) < eps: if abs(a-b) < eps:
success() success()
else: else:
failure(msg or "got %r, expected %r" % (a, b)) failure(msg or "got %r, expected %r" % (a, b))
def assert_deep_equal(a, b, msg=None): def assert_deep_equal(a, b, msg=None):
try: try:
if len(a) == len(b): if len(a) == len(b):
if all([x==y for x,y in zip(a,b)]): if all([x == y for x, y in zip(a, b)]):
success() success()
else: else:
failure(msg or "got %s, expected %s" % (a,b)) failure(msg or "got %s, expected %s" % (a, b))
else: else:
failure(msg or "got length %s, expected %s" % (len(a), len(b))) failure(msg or "got length %s, expected %s" % (len(a), len(b)))
except: except:
assert_equal(a,b,msg) assert_equal(a, b, msg)
def assert_match(v, pattern, msg=None): def assert_match(v, pattern, msg=None):
@ -95,8 +101,10 @@ def assert_match(v, pattern, msg=None):
else: else:
failure(msg or "got %r, doesn't match pattern %r" % (v, pattern)) failure(msg or "got %r, doesn't match pattern %r" % (v, pattern))
def assert_exception(exc_class, func): def assert_exception(exc_class, func):
import sys, traceback import sys
import traceback
try: try:
func() func()
except exc_class: except exc_class:
@ -108,8 +116,10 @@ def assert_exception(exc_class, func):
else: else:
failure("expected %r exception, got no exception" % exc_class.__name__) failure("expected %r exception, got no exception" % exc_class.__name__)
def assert_no_exception(func): def assert_no_exception(func):
import sys, traceback import sys
import traceback
try: try:
func() func()
except: except:
@ -118,11 +128,14 @@ def assert_no_exception(func):
else: else:
success() success()
def assert_warning(warn_class, func): def assert_warning(warn_class, func):
# note: this assert calls func three times! # note: this assert calls func three times!
import warnings import warnings
def warn_error(message, category=UserWarning, **options): def warn_error(message, category=UserWarning, **options):
raise category(message) raise category(message)
def warn_ignore(message, category=UserWarning, **options): def warn_ignore(message, category=UserWarning, **options):
pass pass
warn = warnings.warn warn = warnings.warn
@ -141,15 +154,18 @@ def assert_warning(warn_class, func):
from io import BytesIO from io import BytesIO
def fromstring(data): def fromstring(data):
from PIL import Image from PIL import Image
return Image.open(BytesIO(data)) return Image.open(BytesIO(data))
def tostring(im, format, **options): def tostring(im, format, **options):
out = BytesIO() out = BytesIO()
im.save(out, format, **options) im.save(out, format, **options)
return out.getvalue() return out.getvalue()
def lena(mode="RGB", cache={}): def lena(mode="RGB", cache={}):
from PIL import Image from PIL import Image
im = cache.get(mode) im = cache.get(mode)
@ -165,6 +181,7 @@ def lena(mode="RGB", cache={}):
cache[mode] = im cache[mode] = im
return im return im
def assert_image(im, mode, size, msg=None): def assert_image(im, mode, size, msg=None):
if mode is not None and im.mode != mode: if mode is not None and im.mode != mode:
failure(msg or "got mode %r, expected %r" % (im.mode, mode)) failure(msg or "got mode %r, expected %r" % (im.mode, mode))
@ -173,6 +190,7 @@ def assert_image(im, mode, size, msg=None):
else: else:
success() success()
def assert_image_equal(a, b, msg=None): def assert_image_equal(a, b, msg=None):
if a.mode != b.mode: if a.mode != b.mode:
failure(msg or "got mode %r, expected %r" % (a.mode, b.mode)) failure(msg or "got mode %r, expected %r" % (a.mode, b.mode))
@ -184,6 +202,7 @@ def assert_image_equal(a, b, msg=None):
else: else:
success() success()
def assert_image_similar(a, b, epsilon, msg=None): def assert_image_similar(a, b, epsilon, msg=None):
epsilon = float(epsilon) epsilon = float(epsilon)
if a.mode != b.mode: if a.mode != b.mode:
@ -193,19 +212,25 @@ def assert_image_similar(a, b, epsilon, msg=None):
diff = 0 diff = 0
try: try:
ord(b'0') ord(b'0')
for abyte,bbyte in zip(a.tobytes(),b.tobytes()): for abyte, bbyte in zip(a.tobytes(), b.tobytes()):
diff += abs(ord(abyte)-ord(bbyte)) diff += abs(ord(abyte)-ord(bbyte))
except: except:
for abyte,bbyte in zip(a.tobytes(),b.tobytes()): for abyte, bbyte in zip(a.tobytes(), b.tobytes()):
diff += abs(abyte-bbyte) diff += abs(abyte-bbyte)
ave_diff = float(diff)/(a.size[0]*a.size[1]) ave_diff = float(diff)/(a.size[0]*a.size[1])
if epsilon < ave_diff: if epsilon < ave_diff:
return failure(msg or "average pixel value difference %.4f > epsilon %.4f" %(ave_diff, epsilon)) return failure(
msg or "average pixel value difference %.4f > epsilon %.4f" % (
ave_diff, epsilon))
else: else:
return success() return success()
def tempfile(template, *extra): def tempfile(template, *extra):
import os, os.path, sys, tempfile import os
import os.path
import sys
import tempfile
files = [] files = []
root = os.path.join(tempfile.gettempdir(), 'pillow-tests') root = os.path.join(tempfile.gettempdir(), 'pillow-tests')
try: try:
@ -222,11 +247,13 @@ def tempfile(template, *extra):
_tempfiles.extend(files) _tempfiles.extend(files)
return files[0] return files[0]
# test runner # test runner
def run(): def run():
global _target, _tests, run global _target, _tests, run
import sys, traceback import sys
import traceback
_target = sys.modules["__main__"] _target = sys.modules["__main__"]
run = None # no need to run twice run = None # no need to run twice
tests = [] tests = []
@ -251,29 +278,36 @@ def run():
sys.argv[0], lineno, v)) sys.argv[0], lineno, v))
failure.count += 1 failure.count += 1
def yield_test(function, *args): def yield_test(function, *args):
# collect delayed/generated tests # collect delayed/generated tests
_tests.append((function, args)) _tests.append((function, args))
def skip(msg=None): def skip(msg=None):
import os import os
print("skip") print("skip")
os._exit(0) # don't run exit handlers os._exit(0) # don't run exit handlers
def ignore(pattern): def ignore(pattern):
"""Tells the driver to ignore messages matching the pattern, for the """Tells the driver to ignore messages matching the pattern, for the
duration of the current test.""" duration of the current test."""
print('ignore: %s' % pattern) print('ignore: %s' % pattern)
def _setup(): def _setup():
global _logfile global _logfile
def report(): def report():
if run: if run:
run() run()
if success.count and not failure.count: if success.count and not failure.count:
print("ok") print("ok")
# only clean out tempfiles if test passed # only clean out tempfiles if test passed
import os, os.path, tempfile import os
import os.path
import tempfile
for file in _tempfiles: for file in _tempfiles:
try: try:
os.remove(file) os.remove(file)
@ -285,7 +319,8 @@ def _setup():
except OSError: except OSError:
pass pass
import atexit, sys import atexit
import sys
atexit.register(report) atexit.register(report)
if "--coverage" in sys.argv: if "--coverage" in sys.argv:
import coverage import coverage