diff --git a/lib/errors.py b/lib/errors.py new file mode 100644 index 00000000..1e1d8a84 --- /dev/null +++ b/lib/errors.py @@ -0,0 +1,16 @@ +"""Error classes for PostgreSQL error codes +""" + +from psycopg2._psycopg import ( # noqa + Error, Warning, DataError, DatabaseError, ProgrammingError, IntegrityError, + InterfaceError, InternalError, NotSupportedError, OperationalError, + QueryCanceledError, TransactionRollbackError) + + +_by_sqlstate = {} + + +class UndefinedTable(ProgrammingError): + pass + +_by_sqlstate['42P01'] = UndefinedTable diff --git a/psycopg/pqpath.c b/psycopg/pqpath.c index 204a6b00..7e2f0360 100644 --- a/psycopg/pqpath.c +++ b/psycopg/pqpath.c @@ -76,6 +76,35 @@ strip_severity(const char *msg) return msg; } +/* Return a Python exception from a SQLSTATE from psycopg2.errors */ +BORROWED static PyObject * +exception_from_module(const char *sqlstate) +{ + PyObject *rv = NULL; + PyObject *m = NULL; + PyObject *map = NULL; + + if (!(m = PyImport_ImportModule("psycopg2.errors"))) { goto exit; } + if (!(map = PyObject_GetAttrString(m, "_by_sqlstate"))) { goto exit; } + if (!PyDict_Check(map)) { + Dprintf("'psycopg2.errors._by_sqlstate' is not a dict!"); + goto exit; + } + + /* get the sqlstate class (borrowed reference), or fail trying. */ + rv = PyDict_GetItemString(map, sqlstate); + +exit: + /* We exit with a borrowed object, or a NULL but no error + * If an error did happen in this function, we don't want to clobber the + * database error. So better reporting it, albeit with the wrong class. */ + PyErr_Clear(); + + Py_XDECREF(map); + Py_XDECREF(m); + return rv; +} + /* Returns the Python exception corresponding to an SQLSTATE error code. A list of error codes can be found at: @@ -83,6 +112,17 @@ strip_severity(const char *msg) BORROWED static PyObject * exception_from_sqlstate(const char *sqlstate) { + PyObject *exc; + + /* First look up an exception of the proper class from the Python module */ + exc = exception_from_module(sqlstate); + if (exc) { + return exc; + } + else { + PyErr_Clear(); + } + switch (sqlstate[0]) { case '0': switch (sqlstate[1]) { diff --git a/tests/__init__.py b/tests/__init__.py index e58b6fa7..5c57849e 100755 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -40,6 +40,7 @@ from . import test_copy from . import test_cursor from . import test_dates from . import test_errcodes +from . import test_errors from . import test_extras_dictcursor from . import test_fast_executemany from . import test_green @@ -84,6 +85,7 @@ def test_suite(): suite.addTest(test_cursor.test_suite()) suite.addTest(test_dates.test_suite()) suite.addTest(test_errcodes.test_suite()) + suite.addTest(test_errors.test_suite()) suite.addTest(test_extras_dictcursor.test_suite()) suite.addTest(test_fast_executemany.test_suite()) suite.addTest(test_green.test_suite()) diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100755 index 00000000..5645e182 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python + +# test_errors.py - unit test for psycopg2.errors module +# +# Copyright (C) 2018 Daniele Varrazzo +# +# psycopg2 is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# In addition, as a special exception, the copyright holders give +# permission to link this program with the OpenSSL library (or with +# modified versions of OpenSSL that use the same license as OpenSSL), +# and distribute linked combinations including the two. +# +# You must obey the GNU Lesser General Public License in all respects for +# all of the code used other than OpenSSL. +# +# psycopg2 is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +import unittest +from .testutils import ConnectingTestCase + +import psycopg2 + + +class ErrorsTests(ConnectingTestCase): + def test_exception_class(self): + cur = self.conn.cursor() + try: + cur.execute("select * from nonexist") + except psycopg2.Error as exc: + e = exc + + from psycopg2.errors import UndefinedTable + self.assert_(isinstance(e, UndefinedTable), type(e)) + self.assert_(isinstance(e, self.conn.ProgrammingError)) + + def test_exception_class_fallback(self): + cur = self.conn.cursor() + + from psycopg2 import errors + x = errors._by_sqlstate.pop('42P01') + try: + cur.execute("select * from nonexist") + except psycopg2.Error as exc: + e = exc + finally: + errors._by_sqlstate['42P01'] = x + + self.assertEqual(type(e), self.conn.ProgrammingError) + + +def test_suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + +if __name__ == "__main__": + unittest.main()