From 5fd0f6c4eefecb0d6150179c32c43d16c11b173d Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Wed, 16 Dec 2015 12:00:52 +0000 Subject: [PATCH] Fixed race condition on import in errorcodes.lookup Fixes #382. --- NEWS | 1 + lib/errorcodes.py | 10 +++++-- tests/__init__.py | 2 ++ tests/test_errcodes.py | 65 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 2 deletions(-) create mode 100755 tests/test_errcodes.py diff --git a/NEWS b/NEWS index 5200c4dd..c1e4152f 100644 --- a/NEWS +++ b/NEWS @@ -27,6 +27,7 @@ What's new in psycopg 2.6.2 - Raise `!NotSupportedError` on unhandled server response status (:ticket:`#352`). - Fixed `!PersistentConnectionPool` on Python 3 (:ticket:`#348`). +- Fixed `!errorcodes.lookup` initialization thread-safety (:ticket:`#382`). What's new in psycopg 2.6.1 diff --git a/lib/errorcodes.py b/lib/errorcodes.py index 12c300f6..aa5a723c 100644 --- a/lib/errorcodes.py +++ b/lib/errorcodes.py @@ -38,11 +38,17 @@ def lookup(code, _cache={}): return _cache[code] # Generate the lookup map at first usage. + tmp = {} for k, v in globals().iteritems(): if isinstance(v, str) and len(v) in (2, 5): - _cache[v] = k + tmp[v] = k - return lookup(code) + assert tmp + + # Atomic update, to avoid race condition on import (bug #382) + _cache.update(tmp) + + return _cache[code] # autogenerated data: do not edit below this point. diff --git a/tests/__init__.py b/tests/__init__.py index 3e677d85..3e0db779 100755 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -34,6 +34,7 @@ import test_connection import test_copy import test_cursor import test_dates +import test_errcodes import test_extras_dictcursor import test_green import test_lobject @@ -71,6 +72,7 @@ def test_suite(): suite.addTest(test_copy.test_suite()) suite.addTest(test_cursor.test_suite()) suite.addTest(test_dates.test_suite()) + suite.addTest(test_errcodes.test_suite()) suite.addTest(test_extras_dictcursor.test_suite()) suite.addTest(test_green.test_suite()) suite.addTest(test_lobject.test_suite()) diff --git a/tests/test_errcodes.py b/tests/test_errcodes.py new file mode 100755 index 00000000..6cf5ddba --- /dev/null +++ b/tests/test_errcodes.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python + +# test_errcodes.py - unit test for psycopg2.errcodes module +# +# Copyright (C) 2015 Daniele Varrazzo +# +# psycopg2 is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# In addition, as a special exception, the copyright holders give +# permission to link this program with the OpenSSL library (or with +# modified versions of OpenSSL that use the same license as OpenSSL), +# and distribute linked combinations including the two. +# +# You must obey the GNU Lesser General Public License in all respects for +# all of the code used other than OpenSSL. +# +# psycopg2 is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +from testutils import unittest, ConnectingTestCase + +try: + reload +except NameError: + from imp import reload + +from threading import Thread +from psycopg2 import errorcodes + +class ErrocodeTests(ConnectingTestCase): + def test_lookup_threadsafe(self): + + # Increase if it does not fail with KeyError + MAX_CYCLES = 2000 + + errs = [] + def f(pg_code='40001'): + try: + errorcodes.lookup(pg_code) + except Exception, e: + errs.append(e) + + for __ in xrange(MAX_CYCLES): + reload(errorcodes) + (t1, t2) = (Thread(target=f), Thread(target=f)) + (t1.start(), t2.start()) + (t1.join(), t2.join()) + + if errs: + self.fail( + "raised %s errors in %s cycles (first is %s %s)" % ( + len(errs), MAX_CYCLES, + errs[0].__class__.__name__, errs[0])) + + +def test_suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + +if __name__ == "__main__": + unittest.main()