removes duplication in tests

This commit is contained in:
Roman Konoval 2024-09-11 17:12:41 +02:00 committed by Daniele Varrazzo
parent 282360dd04
commit cba6d39be0

View File

@ -26,6 +26,7 @@
import os
import unittest
from collections import deque
from functools import partial
import psycopg2
from psycopg2 import extensions
@ -129,69 +130,51 @@ conn.close()
self.assertEqual(pid, self.conn.notifies[0][0])
self.assertEqual('foo', self.conn.notifies[0][1])
@slow
@skip_if_windows
def test_notifies_received_on_commit(self):
def _test_notifies_received_on_operation(self, operation, execute_query=True):
self.listen('foo')
self.conn.commit()
if execute_query:
self.conn.cursor().execute('select 1;')
pid = int(self.notify('foo').communicate()[0])
self.assertEqual(0, len(self.conn.notifies))
self.conn.commit()
operation()
self.assertEqual(1, len(self.conn.notifies))
self.assertEqual(pid, self.conn.notifies[0][0])
self.assertEqual('foo', self.conn.notifies[0][1])
@slow
@skip_if_windows
def test_notifies_received_on_commit(self):
self._test_notifies_received_on_operation(self.conn.commit)
@slow
@skip_if_windows
def test_notifies_received_on_rollback(self):
self.listen('foo')
self.conn.commit()
self.conn.cursor().execute('select 1;')
pid = int(self.notify('foo').communicate()[0])
self.assertEqual(0, len(self.conn.notifies))
self.conn.rollback()
self.assertEqual(1, len(self.conn.notifies))
self.assertEqual(pid, self.conn.notifies[0][0])
self.assertEqual('foo', self.conn.notifies[0][1])
self._test_notifies_received_on_operation(self.conn.rollback)
@slow
@skip_if_windows
def test_notifies_received_on_reset(self):
self.listen('foo')
self.conn.commit()
pid = int(self.notify('foo').communicate()[0])
self.assertEqual(0, len(self.conn.notifies))
self.conn.reset()
self.assertEqual(1, len(self.conn.notifies))
self.assertEqual(pid, self.conn.notifies[0][0])
self.assertEqual('foo', self.conn.notifies[0][1])
self._test_notifies_received_on_operation(self.conn.reset, execute_query=False)
@slow
@skip_if_windows
def test_notifies_received_on_set_session(self):
self.listen('foo')
self.conn.commit()
pid = int(self.notify('foo').communicate()[0])
self.assertEqual(0, len(self.conn.notifies))
self.conn.set_session(autocommit=True, readonly=True)
self.assertEqual(1, len(self.conn.notifies))
self.assertEqual(pid, self.conn.notifies[0][0])
self.assertEqual('foo', self.conn.notifies[0][1])
self._test_notifies_received_on_operation(
partial(self.conn.set_session, autocommit=True, readonly=True),
execute_query=False,
)
@slow
@skip_if_windows
def test_notifies_received_on_set_client_encoding(self):
self.listen('foo')
self.conn.commit()
pid = int(self.notify('foo').communicate()[0])
self.assertEqual(0, len(self.conn.notifies))
self.conn.set_client_encoding(
self._test_notifies_received_on_operation(
partial(
self.conn.set_client_encoding,
'LATIN1' if self.conn.encoding != 'LATIN1' else 'UTF8'
),
execute_query=False,
)
self.assertEqual(1, len(self.conn.notifies))
self.assertEqual(pid, self.conn.notifies[0][0])
self.assertEqual('foo', self.conn.notifies[0][1])
@slow
def test_notify_object(self):