From cba6d39be0f3db0b32db7f64ceb70aab972016fe Mon Sep 17 00:00:00 2001 From: Roman Konoval Date: Wed, 11 Sep 2024 17:12:41 +0200 Subject: [PATCH] removes duplication in tests --- tests/test_notify.py | 61 ++++++++++++++++---------------------------- 1 file changed, 22 insertions(+), 39 deletions(-) diff --git a/tests/test_notify.py b/tests/test_notify.py index e3bbccd0..873a419b 100755 --- a/tests/test_notify.py +++ b/tests/test_notify.py @@ -26,6 +26,7 @@ import os import unittest from collections import deque +from functools import partial import psycopg2 from psycopg2 import extensions @@ -129,69 +130,51 @@ conn.close() self.assertEqual(pid, self.conn.notifies[0][0]) self.assertEqual('foo', self.conn.notifies[0][1]) - @slow - @skip_if_windows - def test_notifies_received_on_commit(self): + def _test_notifies_received_on_operation(self, operation, execute_query=True): self.listen('foo') self.conn.commit() - self.conn.cursor().execute('select 1;') + if execute_query: + self.conn.cursor().execute('select 1;') pid = int(self.notify('foo').communicate()[0]) self.assertEqual(0, len(self.conn.notifies)) - self.conn.commit() + operation() self.assertEqual(1, len(self.conn.notifies)) self.assertEqual(pid, self.conn.notifies[0][0]) self.assertEqual('foo', self.conn.notifies[0][1]) + @slow + @skip_if_windows + def test_notifies_received_on_commit(self): + self._test_notifies_received_on_operation(self.conn.commit) + @slow @skip_if_windows def test_notifies_received_on_rollback(self): - self.listen('foo') - self.conn.commit() - self.conn.cursor().execute('select 1;') - pid = int(self.notify('foo').communicate()[0]) - self.assertEqual(0, len(self.conn.notifies)) - self.conn.rollback() - self.assertEqual(1, len(self.conn.notifies)) - self.assertEqual(pid, self.conn.notifies[0][0]) - self.assertEqual('foo', self.conn.notifies[0][1]) + self._test_notifies_received_on_operation(self.conn.rollback) @slow @skip_if_windows def test_notifies_received_on_reset(self): - self.listen('foo') - self.conn.commit() - pid = int(self.notify('foo').communicate()[0]) - self.assertEqual(0, len(self.conn.notifies)) - self.conn.reset() - self.assertEqual(1, len(self.conn.notifies)) - self.assertEqual(pid, self.conn.notifies[0][0]) - self.assertEqual('foo', self.conn.notifies[0][1]) + self._test_notifies_received_on_operation(self.conn.reset, execute_query=False) @slow @skip_if_windows def test_notifies_received_on_set_session(self): - self.listen('foo') - self.conn.commit() - pid = int(self.notify('foo').communicate()[0]) - self.assertEqual(0, len(self.conn.notifies)) - self.conn.set_session(autocommit=True, readonly=True) - self.assertEqual(1, len(self.conn.notifies)) - self.assertEqual(pid, self.conn.notifies[0][0]) - self.assertEqual('foo', self.conn.notifies[0][1]) + self._test_notifies_received_on_operation( + partial(self.conn.set_session, autocommit=True, readonly=True), + execute_query=False, + ) @slow @skip_if_windows def test_notifies_received_on_set_client_encoding(self): - self.listen('foo') - self.conn.commit() - pid = int(self.notify('foo').communicate()[0]) - self.assertEqual(0, len(self.conn.notifies)) - self.conn.set_client_encoding( - 'LATIN1' if self.conn.encoding != 'LATIN1' else 'UTF8' + self._test_notifies_received_on_operation( + partial( + self.conn.set_client_encoding, + 'LATIN1' if self.conn.encoding != 'LATIN1' else 'UTF8' + ), + execute_query=False, ) - self.assertEqual(1, len(self.conn.notifies)) - self.assertEqual(pid, self.conn.notifies[0][0]) - self.assertEqual('foo', self.conn.notifies[0][1]) @slow def test_notify_object(self):