mirror of
				https://github.com/psycopg/psycopg2.git
				synced 2025-11-04 09:47:30 +03:00 
			
		
		
		
	Use {} instead of %s placeholders in SQL composition
This commit is contained in:
		
							parent
							
								
									49461c2c39
								
							
						
					
					
						commit
						a76e665567
					
				| 
						 | 
					@ -51,7 +51,8 @@ from the query parameters::
 | 
				
			||||||
    from psycopg2 import sql
 | 
					    from psycopg2 import sql
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cur.execute(
 | 
					    cur.execute(
 | 
				
			||||||
        sql.SQL("insert into %s values (%%s, %%s)") % [sql.Identifier('my_table')],
 | 
					        sql.SQL("insert into {} values (%s, %s)")
 | 
				
			||||||
 | 
					            .format(sql.Identifier('my_table')),
 | 
				
			||||||
        [10, 20])
 | 
					        [10, 20])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										219
									
								
								lib/sql.py
									
									
									
									
									
								
							
							
						
						
									
										219
									
								
								lib/sql.py
									
									
									
									
									
								
							| 
						 | 
					@ -23,24 +23,27 @@
 | 
				
			||||||
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
 | 
					# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
 | 
				
			||||||
# License for more details.
 | 
					# License for more details.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import re
 | 
					 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import collections
 | 
					import string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from psycopg2 import extensions as ext
 | 
					from psycopg2 import extensions as ext
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_formatter = string.Formatter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Composable(object):
 | 
					class Composable(object):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Abstract base class for objects that can be used to compose an SQL string.
 | 
					    Abstract base class for objects that can be used to compose an SQL string.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Composables can be passed directly to `~cursor.execute()` and
 | 
					    `!Composable` objects can be passed directly to `~cursor.execute()` and
 | 
				
			||||||
    `~cursor.executemany()`.
 | 
					    `~cursor.executemany()`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Composables can be joined using the ``+`` operator: the result will be
 | 
					    `!Composable` objects can be joined using the ``+`` operator: the result
 | 
				
			||||||
    a `Composed` instance containing the objects joined. The operator ``*`` is
 | 
					    will be a `Composed` instance containing the objects joined. The operator
 | 
				
			||||||
    also supported with an integer argument: the result is a `!Composed`
 | 
					    ``*`` is also supported with an integer argument: the result is a
 | 
				
			||||||
    instance containing the left argument repeated as many times as requested.
 | 
					    `!Composed` instance containing the left argument repeated as many times as
 | 
				
			||||||
 | 
					    requested.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    .. automethod:: as_string
 | 
					    .. automethod:: as_string
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
| 
						 | 
					@ -144,21 +147,22 @@ class Composed(Composable):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SQL(Composable):
 | 
					class SQL(Composable):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    A `Composable` representing a snippet of SQL string to be included verbatim.
 | 
					    A `Composable` representing a snippet of SQL statement.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    `!SQL` supports the ``%`` operator to incorporate variable parts of a query
 | 
					    `!SQL` exposes `join()` and `format()` methods useful to create a template
 | 
				
			||||||
    into a template: the operator takes a sequence or mapping of `Composable`
 | 
					    where to merge variable parts of a query (for instance field or table
 | 
				
			||||||
    (according to the style of the placeholders in the *string*) and returning
 | 
					    names).
 | 
				
			||||||
    a `Composed` object.
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Example::
 | 
					    Example::
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        >>> query = sql.SQL("select %s from %s") % [
 | 
					        >>> query = sql.SQL("select {} from {}").format(
 | 
				
			||||||
        ...    sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
 | 
					        ...    sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
 | 
				
			||||||
        ...    sql.Identifier('table')]
 | 
					        ...    sql.Identifier('table'))
 | 
				
			||||||
        >>> print(query.as_string(conn))
 | 
					        >>> print(query.as_string(conn))
 | 
				
			||||||
        select "foo", "bar" from "table"
 | 
					        select "foo", "bar" from "table"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    .. automethod:: format
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    .. automethod:: join
 | 
					    .. automethod:: join
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def __init__(self, string):
 | 
					    def __init__(self, string):
 | 
				
			||||||
| 
						 | 
					@ -169,12 +173,73 @@ class SQL(Composable):
 | 
				
			||||||
    def __repr__(self):
 | 
					    def __repr__(self):
 | 
				
			||||||
        return "sql.SQL(%r)" % (self._wrapped,)
 | 
					        return "sql.SQL(%r)" % (self._wrapped,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __mod__(self, args):
 | 
					 | 
				
			||||||
        return _compose(self._wrapped, args)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def as_string(self, conn_or_curs):
 | 
					    def as_string(self, conn_or_curs):
 | 
				
			||||||
        return self._wrapped
 | 
					        return self._wrapped
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def format(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Merge `Composable` objects into a template.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param `Composable` args: parameters to replace to numbered
 | 
				
			||||||
 | 
					            (``{0}``, ``{1}``) or auto-numbered (``{}``) placeholders
 | 
				
			||||||
 | 
					        :param `Composable` kwargs: parameters to replace to named (``{name}``)
 | 
				
			||||||
 | 
					            placeholders
 | 
				
			||||||
 | 
					        :return: the union of the `!SQL` string with placeholders replaced
 | 
				
			||||||
 | 
					        :rtype: `Composed`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        The method is similar to the Python `str.format()` method: the string
 | 
				
			||||||
 | 
					        template supports auto-numbered (``{}``), numbered (``{0}``,
 | 
				
			||||||
 | 
					        ``{1}``...), and named placeholders (``{name}``), with positional
 | 
				
			||||||
 | 
					        arguments replacing the numbered placeholders and keywords replacing
 | 
				
			||||||
 | 
					        the named ones. However placeholder modifiers (``{{0!r}}``,
 | 
				
			||||||
 | 
					        ``{{0:<10}}``) are not supported. Only `!Composable` objects can be
 | 
				
			||||||
 | 
					        passed to the template.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Example::
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            >>> print(sql.SQL("select * from {} where {} = %s")
 | 
				
			||||||
 | 
					            ...     .format(sql.Identifier('people'), sql.Identifier('id'))
 | 
				
			||||||
 | 
					            ...     .as_string(conn))
 | 
				
			||||||
 | 
					            select * from "people" where "id" = %s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            >>> print(sql.SQL("select * from {tbl} where {pkey} = %s")
 | 
				
			||||||
 | 
					            ...     .format(tbl=sql.Identifier('people'), pkey=sql.Identifier('id'))
 | 
				
			||||||
 | 
					            ...     .as_string(conn))
 | 
				
			||||||
 | 
					            select * from "people" where "id" = %s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        rv = []
 | 
				
			||||||
 | 
					        autonum = 0
 | 
				
			||||||
 | 
					        for pre, name, spec, conv in _formatter.parse(self._wrapped):
 | 
				
			||||||
 | 
					            if spec:
 | 
				
			||||||
 | 
					                raise ValueError("no format specification supported by SQL")
 | 
				
			||||||
 | 
					            if conv:
 | 
				
			||||||
 | 
					                raise ValueError("no format conversion supported by SQL")
 | 
				
			||||||
 | 
					            if pre:
 | 
				
			||||||
 | 
					                rv.append(SQL(pre))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if name is None:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if name.isdigit():
 | 
				
			||||||
 | 
					                if autonum:
 | 
				
			||||||
 | 
					                    raise ValueError(
 | 
				
			||||||
 | 
					                        "cannot switch from automatic field numbering to manual")
 | 
				
			||||||
 | 
					                rv.append(args[int(name)])
 | 
				
			||||||
 | 
					                autonum = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            elif not name:
 | 
				
			||||||
 | 
					                if autonum is None:
 | 
				
			||||||
 | 
					                    raise ValueError(
 | 
				
			||||||
 | 
					                        "cannot switch from manual field numbering to automatic")
 | 
				
			||||||
 | 
					                rv.append(args[autonum])
 | 
				
			||||||
 | 
					                autonum += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                rv.append(kwargs[name])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return Composed(rv)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def join(self, seq):
 | 
					    def join(self, seq):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Join a sequence of `Composable` or a `Composed` and return a `!Composed`.
 | 
					        Join a sequence of `Composable` or a `Composed` and return a `!Composed`.
 | 
				
			||||||
| 
						 | 
					@ -183,7 +248,8 @@ class SQL(Composable):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Example::
 | 
					        Example::
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            >>> snip - sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', 'baz']))
 | 
					            >>> snip = sql.SQL(', ').join(
 | 
				
			||||||
 | 
					            ...     sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
 | 
				
			||||||
            >>> print(snip.as_string(conn))
 | 
					            >>> print(snip.as_string(conn))
 | 
				
			||||||
            "foo", "bar", "baz"
 | 
					            "foo", "bar", "baz"
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -331,123 +397,6 @@ class Placeholder(Composable):
 | 
				
			||||||
            return "%s"
 | 
					            return "%s"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
re_compose = re.compile("""
 | 
					 | 
				
			||||||
    %                           # percent sign
 | 
					 | 
				
			||||||
    (?:
 | 
					 | 
				
			||||||
        ([%s])                  # either % or s
 | 
					 | 
				
			||||||
        | \( ([^\)]+) \) s      # or a (named)s placeholder (named captured)
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    """, re.VERBOSE)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def _compose(sql, args=None):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Merge an SQL string with some variable parts.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The *sql* string can contain placeholders such as `%s` or `%(name)s`.
 | 
					 | 
				
			||||||
    If the string must contain a literal ``%`` symbol use ``%%``. Note that,
 | 
					 | 
				
			||||||
    unlike `~cursor.execute()`, the replacement ``%%`` |=>| ``%`` is *always*
 | 
					 | 
				
			||||||
    performed, even if there is no argument.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    .. |=>| unicode:: 0x21D2 .. double right arrow
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    *args* must be a sequence or mapping (according to the placeholder style)
 | 
					 | 
				
			||||||
    of `Composable` instances.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The value returned is a `Composed` instance obtained replacing the
 | 
					 | 
				
			||||||
    arguments to the query placeholders.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if args is None:
 | 
					 | 
				
			||||||
        args = ()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    phs = list(re_compose.finditer(sql))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # check placeholders consistent
 | 
					 | 
				
			||||||
    counts = {'%': 0, 's': 0, None: 0}
 | 
					 | 
				
			||||||
    for ph in phs:
 | 
					 | 
				
			||||||
        counts[ph.group(1)] += 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    npos = counts['s']
 | 
					 | 
				
			||||||
    nnamed = counts[None]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if npos and nnamed:
 | 
					 | 
				
			||||||
        raise ValueError(
 | 
					 | 
				
			||||||
            "the sql string contains both named and positional placeholders")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    elif npos:
 | 
					 | 
				
			||||||
        if not isinstance(args, collections.Sequence):
 | 
					 | 
				
			||||||
            raise TypeError(
 | 
					 | 
				
			||||||
                "the sql string expects values in a sequence, got %s instead"
 | 
					 | 
				
			||||||
                % type(args).__name__)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if len(args) != npos:
 | 
					 | 
				
			||||||
            raise ValueError(
 | 
					 | 
				
			||||||
                "the sql string expects %s values, got %s" % (npos, len(args)))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return _compose_seq(sql, phs, args)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    elif nnamed:
 | 
					 | 
				
			||||||
        if not isinstance(args, collections.Mapping):
 | 
					 | 
				
			||||||
            raise TypeError(
 | 
					 | 
				
			||||||
                "the sql string expects values in a mapping, got %s instead"
 | 
					 | 
				
			||||||
                % type(args))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return _compose_map(sql, phs, args)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        if isinstance(args, collections.Sequence) and args:
 | 
					 | 
				
			||||||
            raise ValueError(
 | 
					 | 
				
			||||||
                "the sql string expects no value, got %s instead" % len(args))
 | 
					 | 
				
			||||||
        # If args are a mapping, no placeholder is an acceptable case
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Convert %% into %
 | 
					 | 
				
			||||||
        return _compose_seq(sql, phs, ())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def _compose_seq(sql, phs, args):
 | 
					 | 
				
			||||||
    rv = []
 | 
					 | 
				
			||||||
    j = 0
 | 
					 | 
				
			||||||
    for i, ph in enumerate(phs):
 | 
					 | 
				
			||||||
        if i:
 | 
					 | 
				
			||||||
            rv.append(SQL(sql[phs[i - 1].end():ph.start()]))
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            rv.append(SQL(sql[0:ph.start()]))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if ph.group(1) == 's':
 | 
					 | 
				
			||||||
            rv.append(args[j])
 | 
					 | 
				
			||||||
            j += 1
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            rv.append(SQL('%'))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if phs:
 | 
					 | 
				
			||||||
        rv.append(SQL(sql[phs[-1].end():]))
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        rv.append(SQL(sql))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return Composed(rv)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def _compose_map(sql, phs, args):
 | 
					 | 
				
			||||||
    rv = []
 | 
					 | 
				
			||||||
    for i, ph in enumerate(phs):
 | 
					 | 
				
			||||||
        if i:
 | 
					 | 
				
			||||||
            rv.append(SQL(sql[phs[i - 1].end():ph.start()]))
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            rv.append(SQL(sql[0:ph.start()]))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if ph.group(2):
 | 
					 | 
				
			||||||
            rv.append(args[ph.group(2)])
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            rv.append(SQL('%'))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if phs:
 | 
					 | 
				
			||||||
        rv.append(SQL(sql[phs[-1].end():]))
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        rv.append(sql)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return Composed(rv)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Alias
 | 
					# Alias
 | 
				
			||||||
PH = Placeholder
 | 
					PH = Placeholder
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -25,65 +25,89 @@
 | 
				
			||||||
import datetime as dt
 | 
					import datetime as dt
 | 
				
			||||||
from testutils import unittest, ConnectingTestCase
 | 
					from testutils import unittest, ConnectingTestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import psycopg2
 | 
				
			||||||
from psycopg2 import sql
 | 
					from psycopg2 import sql
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ComposeTests(ConnectingTestCase):
 | 
					class SqlFormatTests(ConnectingTestCase):
 | 
				
			||||||
    def test_pos(self):
 | 
					    def test_pos(self):
 | 
				
			||||||
        s = sql.SQL("select %s from %s") \
 | 
					        s = sql.SQL("select {} from {}").format(
 | 
				
			||||||
            % (sql.Identifier('field'), sql.Identifier('table'))
 | 
					            sql.Identifier('field'), sql.Identifier('table'))
 | 
				
			||||||
 | 
					        s1 = s.as_string(self.conn)
 | 
				
			||||||
 | 
					        self.assert_(isinstance(s1, str))
 | 
				
			||||||
 | 
					        self.assertEqual(s1, 'select "field" from "table"')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_pos_spec(self):
 | 
				
			||||||
 | 
					        s = sql.SQL("select {0} from {1}").format(
 | 
				
			||||||
 | 
					            sql.Identifier('field'), sql.Identifier('table'))
 | 
				
			||||||
 | 
					        s1 = s.as_string(self.conn)
 | 
				
			||||||
 | 
					        self.assert_(isinstance(s1, str))
 | 
				
			||||||
 | 
					        self.assertEqual(s1, 'select "field" from "table"')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        s = sql.SQL("select {1} from {0}").format(
 | 
				
			||||||
 | 
					            sql.Identifier('table'), sql.Identifier('field'))
 | 
				
			||||||
        s1 = s.as_string(self.conn)
 | 
					        s1 = s.as_string(self.conn)
 | 
				
			||||||
        self.assert_(isinstance(s1, str))
 | 
					        self.assert_(isinstance(s1, str))
 | 
				
			||||||
        self.assertEqual(s1, 'select "field" from "table"')
 | 
					        self.assertEqual(s1, 'select "field" from "table"')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_dict(self):
 | 
					    def test_dict(self):
 | 
				
			||||||
        s = sql.SQL("select %(f)s from %(t)s") \
 | 
					        s = sql.SQL("select {f} from {t}").format(
 | 
				
			||||||
            % {'f': sql.Identifier('field'), 't': sql.Identifier('table')}
 | 
					            f=sql.Identifier('field'), t=sql.Identifier('table'))
 | 
				
			||||||
        s1 = s.as_string(self.conn)
 | 
					        s1 = s.as_string(self.conn)
 | 
				
			||||||
        self.assert_(isinstance(s1, str))
 | 
					        self.assert_(isinstance(s1, str))
 | 
				
			||||||
        self.assertEqual(s1, 'select "field" from "table"')
 | 
					        self.assertEqual(s1, 'select "field" from "table"')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_unicode(self):
 | 
					    def test_unicode(self):
 | 
				
			||||||
        s = sql.SQL(u"select %s from %s") \
 | 
					        s = sql.SQL(u"select {} from {}").format(
 | 
				
			||||||
            % (sql.Identifier(u'field'), sql.Identifier('table'))
 | 
					            sql.Identifier(u'field'), sql.Identifier('table'))
 | 
				
			||||||
        s1 = s.as_string(self.conn)
 | 
					        s1 = s.as_string(self.conn)
 | 
				
			||||||
        self.assert_(isinstance(s1, unicode))
 | 
					        self.assert_(isinstance(s1, unicode))
 | 
				
			||||||
        self.assertEqual(s1, u'select "field" from "table"')
 | 
					        self.assertEqual(s1, u'select "field" from "table"')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_compose_literal(self):
 | 
					    def test_compose_literal(self):
 | 
				
			||||||
        s = sql.SQL("select %s;") % [sql.Literal(dt.date(2016, 12, 31))]
 | 
					        s = sql.SQL("select {};").format(sql.Literal(dt.date(2016, 12, 31)))
 | 
				
			||||||
        s1 = s.as_string(self.conn)
 | 
					        s1 = s.as_string(self.conn)
 | 
				
			||||||
        self.assertEqual(s1, "select '2016-12-31'::date;")
 | 
					        self.assertEqual(s1, "select '2016-12-31'::date;")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_compose_empty(self):
 | 
					    def test_compose_empty(self):
 | 
				
			||||||
        s = sql.SQL("select foo;") % ()
 | 
					        s = sql.SQL("select foo;").format()
 | 
				
			||||||
        s1 = s.as_string(self.conn)
 | 
					        s1 = s.as_string(self.conn)
 | 
				
			||||||
        self.assertEqual(s1, "select foo;")
 | 
					        self.assertEqual(s1, "select foo;")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_percent_escape(self):
 | 
					    def test_percent_escape(self):
 | 
				
			||||||
        s = sql.SQL("42 %% %s") % [sql.Literal(7)]
 | 
					        s = sql.SQL("42 % {}").format(sql.Literal(7))
 | 
				
			||||||
        s1 = s.as_string(self.conn)
 | 
					        s1 = s.as_string(self.conn)
 | 
				
			||||||
        self.assertEqual(s1, "42 % 7")
 | 
					        self.assertEqual(s1, "42 % 7")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        s = sql.SQL("42 %% 7") % []
 | 
					    def test_braces_escape(self):
 | 
				
			||||||
        s1 = s.as_string(self.conn)
 | 
					        s = sql.SQL("{{{}}}").format(sql.Literal(7))
 | 
				
			||||||
        self.assertEqual(s1, "42 % 7")
 | 
					        self.assertEqual(s.as_string(self.conn), "{7}")
 | 
				
			||||||
 | 
					        s = sql.SQL("{{1,{}}}").format(sql.Literal(7))
 | 
				
			||||||
 | 
					        self.assertEqual(s.as_string(self.conn), "{1,7}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_compose_badnargs(self):
 | 
					    def test_compose_badnargs(self):
 | 
				
			||||||
        self.assertRaises(ValueError, sql.SQL("select foo;").__mod__, [10])
 | 
					        self.assertRaises(IndexError, sql.SQL("select {};").format)
 | 
				
			||||||
        self.assertRaises(ValueError, sql.SQL("select %s;").__mod__, [])
 | 
					        self.assertRaises(ValueError, sql.SQL("select {} {1};").format, 10, 20)
 | 
				
			||||||
        self.assertRaises(ValueError, sql.SQL("select %s;").__mod__, [10, 20])
 | 
					        self.assertRaises(ValueError, sql.SQL("select {0} {};").format, 10, 20)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_compose_bad_args_type(self):
 | 
					    def test_compose_bad_args_type(self):
 | 
				
			||||||
        self.assertRaises(TypeError, sql.SQL("select %s;").__mod__, {'a': 10})
 | 
					        self.assertRaises(IndexError, sql.SQL("select {};").format, a=10)
 | 
				
			||||||
        self.assertRaises(TypeError, sql.SQL("select %(x)s;").__mod__, [10])
 | 
					        self.assertRaises(KeyError, sql.SQL("select {x};").format, 10)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_must_be_composable(self):
 | 
				
			||||||
 | 
					        self.assertRaises(TypeError, sql.SQL("select {};").format, 'foo')
 | 
				
			||||||
 | 
					        self.assertRaises(TypeError, sql.SQL("select {};").format, 10)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_no_modifiers(self):
 | 
				
			||||||
 | 
					        self.assertRaises(ValueError, sql.SQL("select {a!r};").format, a=10)
 | 
				
			||||||
 | 
					        self.assertRaises(ValueError, sql.SQL("select {a:<};").format, a=10)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_must_be_adaptable(self):
 | 
					    def test_must_be_adaptable(self):
 | 
				
			||||||
        class Foo(object):
 | 
					        class Foo(object):
 | 
				
			||||||
            pass
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertRaises(TypeError,
 | 
					        self.assertRaises(psycopg2.ProgrammingError,
 | 
				
			||||||
            sql.SQL("select %s;").__mod__, [Foo()])
 | 
					            sql.SQL("select {};").format(sql.Literal(Foo())).as_string, self.conn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_execute(self):
 | 
					    def test_execute(self):
 | 
				
			||||||
        cur = self.conn.cursor()
 | 
					        cur = self.conn.cursor()
 | 
				
			||||||
| 
						 | 
					@ -93,11 +117,10 @@ class ComposeTests(ConnectingTestCase):
 | 
				
			||||||
                foo text, bar text, "ba'z" text)
 | 
					                foo text, bar text, "ba'z" text)
 | 
				
			||||||
            """)
 | 
					            """)
 | 
				
			||||||
        cur.execute(
 | 
					        cur.execute(
 | 
				
			||||||
            sql.SQL("insert into %s (id, %s) values (%%s, %s)") % [
 | 
					            sql.SQL("insert into {} (id, {}) values (%s, {})").format(
 | 
				
			||||||
                sql.Identifier('test_compose'),
 | 
					                sql.Identifier('test_compose'),
 | 
				
			||||||
                sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
 | 
					                sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
 | 
				
			||||||
                (sql.PH() * 3).join(', '),
 | 
					                (sql.PH() * 3).join(', ')),
 | 
				
			||||||
            ],
 | 
					 | 
				
			||||||
            (10, 'a', 'b', 'c'))
 | 
					            (10, 'a', 'b', 'c'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        cur.execute("select * from test_compose")
 | 
					        cur.execute("select * from test_compose")
 | 
				
			||||||
| 
						 | 
					@ -111,11 +134,10 @@ class ComposeTests(ConnectingTestCase):
 | 
				
			||||||
                foo text, bar text, "ba'z" text)
 | 
					                foo text, bar text, "ba'z" text)
 | 
				
			||||||
            """)
 | 
					            """)
 | 
				
			||||||
        cur.executemany(
 | 
					        cur.executemany(
 | 
				
			||||||
            sql.SQL("insert into %s (id, %s) values (%%s, %s)") % [
 | 
					            sql.SQL("insert into {} (id, {}) values (%s, {})").format(
 | 
				
			||||||
                sql.Identifier('test_compose'),
 | 
					                sql.Identifier('test_compose'),
 | 
				
			||||||
                sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
 | 
					                sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
 | 
				
			||||||
                (sql.PH() * 3).join(', '),
 | 
					                (sql.PH() * 3).join(', ')),
 | 
				
			||||||
            ],
 | 
					 | 
				
			||||||
            [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
 | 
					            [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        cur.execute("select * from test_compose")
 | 
					        cur.execute("select * from test_compose")
 | 
				
			||||||
| 
						 | 
					@ -169,6 +191,13 @@ class LiteralTests(ConnectingTestCase):
 | 
				
			||||||
            sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn),
 | 
					            sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn),
 | 
				
			||||||
            "'2017-01-01'::date")
 | 
					            "'2017-01-01'::date")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_must_be_adaptable(self):
 | 
				
			||||||
 | 
					        class Foo(object):
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertRaises(psycopg2.ProgrammingError,
 | 
				
			||||||
 | 
					            sql.Literal(Foo()).as_string, self.conn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SQLTests(ConnectingTestCase):
 | 
					class SQLTests(ConnectingTestCase):
 | 
				
			||||||
    def test_class(self):
 | 
					    def test_class(self):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user