mirror of
https://github.com/psycopg/psycopg2.git
synced 2024-11-10 19:16:34 +03:00
Added basic sql module implementation
This commit is contained in:
parent
fad5100079
commit
f11e6d82b0
174
lib/sql.py
174
lib/sql.py
|
@ -23,3 +23,177 @@
|
|||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
from psycopg2 import extensions as ext
|
||||
|
||||
|
||||
class Composible(object):
|
||||
"""Base class for objects that can be used to compose an SQL string."""
|
||||
def as_string(self, conn_or_curs):
|
||||
raise NotImplementedError
|
||||
|
||||
def __add__(self, other):
|
||||
if isinstance(other, Composed):
|
||||
return Composed([self]) + other
|
||||
if isinstance(other, Composible):
|
||||
return Composed([self]) + Composed([other])
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class Composed(Composible):
|
||||
def __init__(self, seq):
|
||||
self._seq = []
|
||||
for i in seq:
|
||||
if not isinstance(i, Composible):
|
||||
raise TypeError(
|
||||
"Composed elements must be Composible, got %r instead" % i)
|
||||
self._seq.append(i)
|
||||
|
||||
def __repr__(self):
|
||||
return "sql.Composed(%r)" % (self.seq,)
|
||||
|
||||
def as_string(self, conn_or_curs):
|
||||
rv = []
|
||||
for i in self._seq:
|
||||
rv.append(i.as_string(conn_or_curs))
|
||||
return ''.join(rv)
|
||||
|
||||
def __add__(self, other):
|
||||
if isinstance(other, Composed):
|
||||
return Composed(self._seq + other._seq)
|
||||
if isinstance(other, Composible):
|
||||
return Composed(self._seq + [other])
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __mul__(self, n):
|
||||
return Composed(self._seq * n)
|
||||
|
||||
def join(self, joiner):
|
||||
if isinstance(joiner, basestring):
|
||||
joiner = SQL(joiner)
|
||||
elif not isinstance(joiner, SQL):
|
||||
raise TypeError(
|
||||
"Composed.join() argument must be a string or an SQL")
|
||||
|
||||
if len(self._seq) <= 1:
|
||||
return self
|
||||
|
||||
it = iter(self._seq)
|
||||
rv = [it.next()]
|
||||
for i in it:
|
||||
rv.append(joiner)
|
||||
rv.append(i)
|
||||
|
||||
return Composed(rv)
|
||||
|
||||
|
||||
class SQL(Composible):
|
||||
def __init__(self, wrapped):
|
||||
if not isinstance(wrapped, basestring):
|
||||
raise TypeError("SQL values must be strings")
|
||||
self._wrapped = wrapped
|
||||
|
||||
def __repr__(self):
|
||||
return "sql.SQL(%r)" % (self._wrapped,)
|
||||
|
||||
def as_string(self, conn_or_curs):
|
||||
return self._wrapped
|
||||
|
||||
def __mul__(self, n):
|
||||
return Composed([self] * n)
|
||||
|
||||
def join(self, seq):
|
||||
rv = []
|
||||
it = iter(seq)
|
||||
try:
|
||||
rv.append(it.next())
|
||||
except StopIteration:
|
||||
pass
|
||||
else:
|
||||
for i in it:
|
||||
rv.append(self)
|
||||
rv.append(i)
|
||||
|
||||
return Composed(rv)
|
||||
|
||||
|
||||
class Identifier(Composible):
|
||||
def __init__(self, wrapped):
|
||||
if not isinstance(wrapped, basestring):
|
||||
raise TypeError("SQL identifiers must be strings")
|
||||
|
||||
self._wrapped = wrapped
|
||||
|
||||
@property
|
||||
def wrapped(self):
|
||||
return self._wrapped
|
||||
|
||||
def __repr__(self):
|
||||
return "sql.Identifier(%r)" % (self._wrapped,)
|
||||
|
||||
def as_string(self, conn_or_curs):
|
||||
return ext.quote_ident(self._wrapped, conn_or_curs)
|
||||
|
||||
|
||||
class Literal(Composible):
|
||||
def __init__(self, wrapped):
|
||||
self._wrapped = wrapped
|
||||
|
||||
def __repr__(self):
|
||||
return "sql.Literal(%r)" % (self._wrapped,)
|
||||
|
||||
def as_string(self, conn_or_curs):
|
||||
a = ext.adapt(self._wrapped)
|
||||
if hasattr(a, 'prepare'):
|
||||
# is it a connection or cursor?
|
||||
if isinstance(conn_or_curs, ext.connection):
|
||||
conn = conn_or_curs
|
||||
elif isinstance(conn_or_curs, ext.cursor):
|
||||
conn = conn_or_curs.connection
|
||||
else:
|
||||
raise TypeError("conn_or_curs must be a connection or a cursor")
|
||||
|
||||
a.prepare(conn)
|
||||
|
||||
return a.getquoted()
|
||||
|
||||
def __mul__(self, n):
|
||||
return Composed([self] * n)
|
||||
|
||||
|
||||
class Placeholder(Composible):
|
||||
def __init__(self, name=None):
|
||||
if isinstance(name, basestring):
|
||||
if ')' in name:
|
||||
raise ValueError("invalid name: %r" % name)
|
||||
|
||||
elif name is not None:
|
||||
raise TypeError("expected string or None as name, got %r" % name)
|
||||
|
||||
self._name = name
|
||||
|
||||
def __repr__(self):
|
||||
return "sql.Placeholder(%r)" % (
|
||||
self._name if self._name is not None else '',)
|
||||
|
||||
def __mul__(self, n):
|
||||
return Composed([self] * n)
|
||||
|
||||
def as_string(self, conn_or_curs):
|
||||
if self._name is not None:
|
||||
return "%%(%s)s" % self._name
|
||||
else:
|
||||
return "%s"
|
||||
|
||||
|
||||
def compose(sql, args=()):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Alias
|
||||
PH = Placeholder
|
||||
|
||||
# Literals
|
||||
NULL = SQL("NULL")
|
||||
DEFAULT = SQL("DEFAULT")
|
||||
|
|
|
@ -22,7 +22,228 @@
|
|||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
from testutils import unittest
|
||||
import datetime as dt
|
||||
from testutils import unittest, ConnectingTestCase
|
||||
|
||||
from psycopg2 import sql
|
||||
|
||||
|
||||
class ComposeTests(ConnectingTestCase):
|
||||
def test_pos(self):
|
||||
s = sql.compose("select %s from %s",
|
||||
(sql.Identifier('field'), sql.Identifier('table')))
|
||||
s1 = s.as_string(self.conn)
|
||||
self.assert_(isinstance(s1, str))
|
||||
self.assertEqual(s1, 'select "field" from "table"')
|
||||
|
||||
def test_dict(self):
|
||||
s = sql.compose("select %(f)s from %(t)s",
|
||||
{'f': sql.Identifier('field'), 't': sql.Identifier('table')})
|
||||
s1 = s.as_string(self.conn)
|
||||
self.assert_(isinstance(s1, str))
|
||||
self.assertEqual(s1, 'select "field" from "table"')
|
||||
|
||||
def test_unicode(self):
|
||||
s = sql.compose(u"select %s from %s",
|
||||
(sql.Identifier(u'field'), sql.Identifier('table')))
|
||||
s1 = s.as_string(self.conn)
|
||||
self.assert_(isinstance(s1, unicode))
|
||||
self.assertEqual(s1, u'select "field" from "table"')
|
||||
|
||||
def test_compose_literal(self):
|
||||
s = sql.compose("select %s;", [sql.Literal(dt.date(2016, 12, 31))])
|
||||
s1 = s.as_string(self.conn)
|
||||
self.assertEqual(s1, "select '2016-12-31'::date;")
|
||||
|
||||
def test_must_be_adaptable(self):
|
||||
class Foo(object):
|
||||
pass
|
||||
|
||||
self.assertRaises(TypeError,
|
||||
sql.compose, "select %s;", [Foo()])
|
||||
|
||||
def test_execute(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("""
|
||||
create table test_compose (
|
||||
id serial primary key,
|
||||
foo text, bar text, "ba'z" text)
|
||||
""")
|
||||
cur.execute(
|
||||
sql.compose("insert into %s (id, %s) values (%%s, %s)", [
|
||||
sql.Identifier('test_compose'),
|
||||
sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
|
||||
(sql.PH() * 3).join(', '),
|
||||
]),
|
||||
(10, 'a', 'b', 'c'))
|
||||
|
||||
cur.execute("select * from test_compose")
|
||||
self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')])
|
||||
|
||||
def test_executemany(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("""
|
||||
create table test_compose (
|
||||
id serial primary key,
|
||||
foo text, bar text, "ba'z" text)
|
||||
""")
|
||||
cur.executemany(
|
||||
sql.compose("insert into %s (id, %s) values (%%s, %s)", [
|
||||
sql.Identifier('test_compose'),
|
||||
sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
|
||||
(sql.PH() * 3).join(', '),
|
||||
]),
|
||||
[(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
|
||||
|
||||
cur.execute("select * from test_compose")
|
||||
self.assertEqual(cur.fetchall(),
|
||||
[(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
|
||||
|
||||
|
||||
class IdentifierTests(ConnectingTestCase):
|
||||
def test_class(self):
|
||||
self.assert_(issubclass(sql.Identifier, sql.Composible))
|
||||
|
||||
def test_init(self):
|
||||
self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier))
|
||||
self.assert_(isinstance(sql.Identifier(u'foo'), sql.Identifier))
|
||||
self.assert_(isinstance(sql.Identifier(b'foo'), sql.Identifier))
|
||||
self.assertRaises(TypeError, sql.Identifier, 10)
|
||||
self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31))
|
||||
|
||||
def test_repr(self):
|
||||
obj = sql.Identifier("fo'o")
|
||||
self.assertEqual(repr(obj), 'sql.Identifier("fo\'o")')
|
||||
self.assertEqual(repr(obj), str(obj))
|
||||
|
||||
def test_as_str(self):
|
||||
self.assertEqual(sql.Identifier('foo').as_string(self.conn), '"foo"')
|
||||
self.assertEqual(sql.Identifier("fo'o").as_string(self.conn), '"fo\'o"')
|
||||
|
||||
def test_join(self):
|
||||
self.assert_(not hasattr(sql.Identifier('foo'), 'join'))
|
||||
|
||||
|
||||
class LiteralTests(ConnectingTestCase):
|
||||
def test_class(self):
|
||||
self.assert_(issubclass(sql.Literal, sql.Composible))
|
||||
|
||||
def test_init(self):
|
||||
self.assert_(isinstance(sql.Literal('foo'), sql.Literal))
|
||||
self.assert_(isinstance(sql.Literal(u'foo'), sql.Literal))
|
||||
self.assert_(isinstance(sql.Literal(b'foo'), sql.Literal))
|
||||
self.assert_(isinstance(sql.Literal(42), sql.Literal))
|
||||
self.assert_(isinstance(
|
||||
sql.Literal(dt.date(2016, 12, 31)), sql.Literal))
|
||||
|
||||
def test_repr(self):
|
||||
self.assertEqual(repr(sql.Literal("foo")), "sql.Literal('foo')")
|
||||
self.assertEqual(str(sql.Literal("foo")), "sql.Literal('foo')")
|
||||
self.assertEqual(
|
||||
sql.Literal("foo").as_string(self.conn).replace("E'", "'"),
|
||||
"'foo'")
|
||||
self.assertEqual(sql.Literal(42).as_string(self.conn), "42")
|
||||
self.assertEqual(
|
||||
sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn),
|
||||
"'2017-01-01'::date")
|
||||
|
||||
|
||||
class SQLTests(ConnectingTestCase):
|
||||
def test_class(self):
|
||||
self.assert_(issubclass(sql.SQL, sql.Composible))
|
||||
|
||||
def test_init(self):
|
||||
self.assert_(isinstance(sql.SQL('foo'), sql.SQL))
|
||||
self.assert_(isinstance(sql.SQL(u'foo'), sql.SQL))
|
||||
self.assert_(isinstance(sql.SQL(b'foo'), sql.SQL))
|
||||
self.assertRaises(TypeError, sql.SQL, 10)
|
||||
self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31))
|
||||
|
||||
def test_str(self):
|
||||
self.assertEqual(repr(sql.SQL("foo")), "sql.SQL('foo')")
|
||||
self.assertEqual(str(sql.SQL("foo")), "sql.SQL('foo')")
|
||||
self.assertEqual(sql.SQL("foo").as_string(self.conn), "foo")
|
||||
|
||||
def test_sum(self):
|
||||
obj = sql.SQL("foo") + sql.SQL("bar")
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), "foobar")
|
||||
|
||||
def test_sum_inplace(self):
|
||||
obj = sql.SQL("foo")
|
||||
obj += sql.SQL("bar")
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), "foobar")
|
||||
|
||||
def test_multiply(self):
|
||||
obj = sql.SQL("foo") * 3
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), "foofoofoo")
|
||||
|
||||
def test_join(self):
|
||||
obj = sql.SQL(", ").join(
|
||||
[sql.Identifier('foo'), sql.SQL('bar'), sql.Literal(42)])
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42')
|
||||
|
||||
|
||||
class ComposedTest(ConnectingTestCase):
|
||||
def test_class(self):
|
||||
self.assert_(issubclass(sql.Composed, sql.Composible))
|
||||
|
||||
def test_join(self):
|
||||
obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
|
||||
obj = obj.join(", ")
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), "'foo', \"b'ar\"")
|
||||
|
||||
def test_sum(self):
|
||||
obj = sql.Composed([sql.SQL("foo ")])
|
||||
obj = obj + sql.Literal("bar")
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), "foo 'bar'")
|
||||
|
||||
def test_sum_inplace(self):
|
||||
obj = sql.Composed([sql.SQL("foo ")])
|
||||
obj += sql.Literal("bar")
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), "foo 'bar'")
|
||||
|
||||
obj = sql.Composed([sql.SQL("foo ")])
|
||||
obj += sql.Composed([sql.Literal("bar")])
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), "foo 'bar'")
|
||||
|
||||
|
||||
class PlaceholderTest(ConnectingTestCase):
|
||||
def test_class(self):
|
||||
self.assert_(issubclass(sql.Placeholder, sql.Composible))
|
||||
|
||||
def test_alias(self):
|
||||
self.assert_(sql.Placeholder is sql.PH)
|
||||
|
||||
def test_repr(self):
|
||||
self.assert_(str(sql.Placeholder()), 'sql.Placeholder()')
|
||||
self.assert_(repr(sql.Placeholder()), 'sql.Placeholder()')
|
||||
self.assert_(sql.Placeholder().as_string(self.conn), '%s')
|
||||
|
||||
def test_repr_name(self):
|
||||
self.assert_(str(sql.Placeholder('foo')), "sql.Placeholder('foo')")
|
||||
self.assert_(repr(sql.Placeholder('foo')), "sql.Placeholder('foo')")
|
||||
self.assert_(sql.Placeholder('foo').as_string(self.conn), '%(foo)s')
|
||||
|
||||
def test_bad_name(self):
|
||||
self.assertRaises(ValueError, sql.Placeholder, ')')
|
||||
|
||||
|
||||
class ValuesTest(ConnectingTestCase):
|
||||
def test_null(self):
|
||||
self.assert_(isinstance(sql.NULL, sql.SQL))
|
||||
self.assertEqual(sql.NULL.as_string(self.conn), "NULL")
|
||||
|
||||
def test_default(self):
|
||||
self.assert_(isinstance(sql.DEFAULT, sql.SQL))
|
||||
self.assertEqual(sql.DEFAULT.as_string(self.conn), "DEFAULT")
|
||||
|
||||
|
||||
def test_suite():
|
||||
|
|
Loading…
Reference in New Issue
Block a user