From 3873c6c09f133dd7793d8fcd1547341e22f1705b Mon Sep 17 00:00:00 2001
From: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Date: Mon, 6 Feb 2017 18:43:39 +0000
Subject: [PATCH] Deal consistently with E'' quotes in tests

---
 tests/test_cursor.py       | 12 ++++++------
 tests/test_quote.py        |  4 ++--
 tests/test_sql.py          | 12 +++++-------
 tests/test_types_extras.py | 39 ++++++++++++++------------------------
 tests/testutils.py         | 15 ++++++++++++++-
 5 files changed, 41 insertions(+), 41 deletions(-)

diff --git a/tests/test_cursor.py b/tests/test_cursor.py
index a8fedccb..42e12c49 100755
--- a/tests/test_cursor.py
+++ b/tests/test_cursor.py
@@ -79,20 +79,20 @@ class CursorTests(ConnectingTestCase):
         # unicode query with non-ascii data
         cur.execute(u"SELECT '%s';" % snowman)
         self.assertEqual(snowman.encode('utf8'), b(cur.fetchone()[0]))
-        self.assertEqual(("SELECT '%s';" % snowman).encode('utf8'),
-            cur.mogrify(u"SELECT '%s';" % snowman).replace(b"E'", b"'"))
+        self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'),
+            cur.mogrify(u"SELECT '%s';" % snowman))
 
         # unicode args
         cur.execute("SELECT %s;", (snowman,))
         self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0]))
-        self.assertEqual(("SELECT '%s';" % snowman).encode('utf8'),
-            cur.mogrify("SELECT %s;", (snowman,)).replace(b"E'", b"'"))
+        self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'),
+            cur.mogrify("SELECT %s;", (snowman,)))
 
         # unicode query and args
         cur.execute(u"SELECT %s;", (snowman,))
         self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0]))
-        self.assertEqual(("SELECT '%s';" % snowman).encode('utf8'),
-            cur.mogrify(u"SELECT %s;", (snowman,)).replace(b"E'", b"'"))
+        self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'),
+            cur.mogrify(u"SELECT %s;", (snowman,)))
 
     def test_mogrify_decimal_explodes(self):
         # issue #7: explodes on windows with python 2.5 and psycopg 2.2.2
diff --git a/tests/test_quote.py b/tests/test_quote.py
index 72c9c1e4..22d896aa 100755
--- a/tests/test_quote.py
+++ b/tests/test_quote.py
@@ -236,7 +236,7 @@ class TestStringAdapter(ConnectingTestCase):
         a.prepare(self.conn)
 
         self.assertEqual(a.encoding, 'utf_8')
-        self.assertEqual(a.getquoted(), b"'\xe2\x98\x83'")
+        self.assertQuotedEqual(a.getquoted(), b"'\xe2\x98\x83'")
 
     @testutils.skip_before_python(3)
     def test_adapt_bytes(self):
@@ -244,7 +244,7 @@ class TestStringAdapter(ConnectingTestCase):
         self.conn.set_client_encoding('utf8')
         a = psycopg2.extensions.QuotedString(snowman.encode('utf8'))
         a.prepare(self.conn)
-        self.assertEqual(a.getquoted(), b"'\xe2\x98\x83'")
+        self.assertQuotedEqual(a.getquoted(), b"'\xe2\x98\x83'")
 
 
 def test_suite():
diff --git a/tests/test_sql.py b/tests/test_sql.py
index ffb4f1fb..16c4937d 100755
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -200,9 +200,7 @@ class LiteralTests(ConnectingTestCase):
     def test_repr(self):
         self.assertEqual(repr(sql.Literal("foo")), "Literal('foo')")
         self.assertEqual(str(sql.Literal("foo")), "Literal('foo')")
-        self.assertEqual(
-            sql.Literal("foo").as_string(self.conn).replace("E'", "'"),
-            "'foo'")
+        self.assertQuotedEqual(sql.Literal("foo").as_string(self.conn), "'foo'")
         self.assertEqual(sql.Literal(42).as_string(self.conn), "42")
         self.assertEqual(
             sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn),
@@ -302,24 +300,24 @@ class ComposedTest(ConnectingTestCase):
         obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
         obj = obj.join(", ")
         self.assert_(isinstance(obj, sql.Composed))
-        self.assertEqual(obj.as_string(self.conn), "'foo', \"b'ar\"")
+        self.assertQuotedEqual(obj.as_string(self.conn), "'foo', \"b'ar\"")
 
     def test_sum(self):
         obj = sql.Composed([sql.SQL("foo ")])
         obj = obj + sql.Literal("bar")
         self.assert_(isinstance(obj, sql.Composed))
-        self.assertEqual(obj.as_string(self.conn), "foo 'bar'")
+        self.assertQuotedEqual(obj.as_string(self.conn), "foo 'bar'")
 
     def test_sum_inplace(self):
         obj = sql.Composed([sql.SQL("foo ")])
         obj += sql.Literal("bar")
         self.assert_(isinstance(obj, sql.Composed))
-        self.assertEqual(obj.as_string(self.conn), "foo 'bar'")
+        self.assertQuotedEqual(obj.as_string(self.conn), "foo 'bar'")
 
         obj = sql.Composed([sql.SQL("foo ")])
         obj += sql.Composed([sql.Literal("bar")])
         self.assert_(isinstance(obj, sql.Composed))
-        self.assertEqual(obj.as_string(self.conn), "foo 'bar'")
+        self.assertQuotedEqual(obj.as_string(self.conn), "foo 'bar'")
 
     def test_iter(self):
         obj = sql.Composed([sql.SQL("foo"), sql.SQL('bar')])
diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py
index 264fca89..3e4771a7 100755
--- a/tests/test_types_extras.py
+++ b/tests/test_types_extras.py
@@ -31,13 +31,6 @@ import psycopg2.extras
 import psycopg2.extensions as ext
 
 
-def filter_scs(conn, s):
-    if conn.get_parameter_status("standard_conforming_strings") == 'off':
-        return s
-    else:
-        return s.replace(b"E'", b"'")
-
-
 class TypesExtrasTests(ConnectingTestCase):
     """Test that all type conversions are working."""
 
@@ -105,17 +98,13 @@ class TypesExtrasTests(ConnectingTestCase):
         i = Inet("192.168.1.0/24")
         a = psycopg2.extensions.adapt(i)
         a.prepare(self.conn)
-        self.assertEqual(
-            filter_scs(self.conn, b"E'192.168.1.0/24'::inet"),
-            a.getquoted())
+        self.assertQuotedEqual(a.getquoted(), b"'192.168.1.0/24'::inet")
 
         # adapts ok with unicode too
         i = Inet(u"192.168.1.0/24")
         a = psycopg2.extensions.adapt(i)
         a.prepare(self.conn)
-        self.assertEqual(
-            filter_scs(self.conn, b"E'192.168.1.0/24'::inet"),
-            a.getquoted())
+        self.assertQuotedEqual(a.getquoted(), b"'192.168.1.0/24'::inet")
 
     def test_adapt_fail(self):
         class Foo(object):
@@ -160,13 +149,12 @@ class HstoreTestCase(ConnectingTestCase):
         ii.sort()
 
         self.assertEqual(len(ii), len(o))
-        self.assertEqual(ii[0], filter_scs(self.conn, b"(E'a' => E'1')"))
-        self.assertEqual(ii[1], filter_scs(self.conn, b"(E'b' => E'''')"))
-        self.assertEqual(ii[2], filter_scs(self.conn, b"(E'c' => NULL)"))
+        self.assertQuotedEqual(ii[0], b"('a' => '1')")
+        self.assertQuotedEqual(ii[1], b"('b' => '''')")
+        self.assertQuotedEqual(ii[2], b"('c' => NULL)")
         if 'd' in o:
             encc = u'\xe0'.encode(psycopg2.extensions.encodings[self.conn.encoding])
-            self.assertEqual(ii[3],
-                filter_scs(self.conn, b"(E'd' => E'" + encc + b"')"))
+            self.assertQuotedEqual(ii[3], b"('d' => '" + encc + b"')")
 
     def test_adapt_9(self):
         if self.conn.server_version < 90000:
@@ -190,16 +178,17 @@ class HstoreTestCase(ConnectingTestCase):
         ii = zip(kk, vv)
         ii.sort()
 
-        def f(*args):
-            return tuple([filter_scs(self.conn, s) for s in args])
-
         self.assertEqual(len(ii), len(o))
-        self.assertEqual(ii[0], f(b"E'a'", b"E'1'"))
-        self.assertEqual(ii[1], f(b"E'b'", b"E''''"))
-        self.assertEqual(ii[2], f(b"E'c'", b"NULL"))
+        self.assertQuotedEqual(ii[0][0], b"'a'")
+        self.assertQuotedEqual(ii[0][1], b"'1'")
+        self.assertQuotedEqual(ii[1][0], b"'b'")
+        self.assertQuotedEqual(ii[1][1], b"''''")
+        self.assertQuotedEqual(ii[2][0], b"'c'")
+        self.assertQuotedEqual(ii[2][1], b"NULL")
         if 'd' in o:
             encc = u'\xe0'.encode(psycopg2.extensions.encodings[self.conn.encoding])
-            self.assertEqual(ii[3], f(b"E'd'", b"E'" + encc + b"'"))
+            self.assertQuotedEqual(ii[3][0], b"'d'")
+            self.assertQuotedEqual(ii[3][1], b"'" + encc + b"'")
 
     def test_parse(self):
         from psycopg2.extras import HstoreAdapter
diff --git a/tests/testutils.py b/tests/testutils.py
index b32f6a83..179f4df3 100644
--- a/tests/testutils.py
+++ b/tests/testutils.py
@@ -24,10 +24,11 @@
 
 # Use unittest2 if available. Otherwise mock a skip facility with warnings.
 
+import re
 import os
-import platform
 import sys
 import select
+import platform
 from functools import wraps
 from testconfig import dsn, repl_dsn
 
@@ -107,6 +108,18 @@ class ConnectingTestCase(unittest.TestCase):
             if not conn.closed:
                 conn.close()
 
+    def assertQuotedEqual(self, first, second, msg=None):
+        """Compare two quoted strings disregarding eventual E'' quotes"""
+        def f(s):
+            if isinstance(s, unicode):
+                return re.sub(r"\bE'", "'", s)
+            elif isinstance(first, bytes):
+                return re.sub(br"\bE'", b"'", s)
+            else:
+                return s
+
+        return self.assertEqual(f(first), f(second), msg)
+
     def connect(self, **kwargs):
         try:
             self._conns