Switch from unittest-style to pytest-style test collection

I was to use parametrization in a future test, but pytest doesn't
support it on unittest-style tests.
This commit is contained in:
Valentin Lorentz 2021-07-04 16:18:33 +02:00 committed by Val Lorentz
parent ed2b75534e
commit 0177c369dd
4 changed files with 92 additions and 41 deletions

View File

@ -1,5 +1,4 @@
import importlib import importlib
import unittest
import _pytest.unittest import _pytest.unittest
import pytest import pytest
@ -77,14 +76,13 @@ def pytest_collection_modifyitems(session, config, items):
# Iterate over each of the test functions (they are pytest "Nodes") # Iterate over each of the test functions (they are pytest "Nodes")
for item in items: for item in items:
# we only use unittest-style test function here assert isinstance(item, _pytest.python.Function)
assert isinstance(item, _pytest.unittest.TestCaseFunction)
# unittest-style test functions have the node of UnitTest class as parent # unittest-style test functions have the node of UnitTest class as parent
assert isinstance(item.parent, _pytest.unittest.UnitTestCase) assert isinstance(item.parent, _pytest.python.Instance)
# and that node references the UnitTest class # and that node references the UnitTest class
assert issubclass(item.parent.cls, unittest.TestCase) assert issubclass(item.parent.cls, _IrcTestCase)
# and in this project, TestCase classes all inherit either from # and in this project, TestCase classes all inherit either from
# BaseClientController or BaseServerController. # BaseClientController or BaseServerController.

View File

@ -1,3 +1,4 @@
import contextlib
import functools import functools
import socket import socket
import ssl import ssl
@ -11,6 +12,7 @@ from typing import (
Generic, Generic,
Hashable, Hashable,
Iterable, Iterable,
Iterator,
List, List,
Optional, Optional,
Set, Set,
@ -20,7 +22,6 @@ from typing import (
Union, Union,
cast, cast,
) )
import unittest
import pytest import pytest
@ -29,7 +30,7 @@ from .authentication import Authentication
from .basecontrollers import TestCaseControllerConfig from .basecontrollers import TestCaseControllerConfig
from .exceptions import ConnectionClosed from .exceptions import ConnectionClosed
from .irc_utils import capabilities, message_parser from .irc_utils import capabilities, message_parser
from .irc_utils.junkdrawer import find_hostname_and_port, normalizeWhitespace from .irc_utils.junkdrawer import find_hostname_and_port
from .irc_utils.message_parser import Message from .irc_utils.message_parser import Message
from .irc_utils.sasl import sasl_plain_blob from .irc_utils.sasl import sasl_plain_blob
from .numerics import ( from .numerics import (
@ -75,8 +76,14 @@ class ChannelJoinException(Exception):
self.params = params self.params = params
class _IrcTestCase(unittest.TestCase, Generic[TController]): class _IrcTestCase(Generic[TController]):
"""Base class for test cases.""" """Base class for test cases.
It implements various `assert*` method that look like unittest's,
but is actually based on the `assert` statement so derived classes are
pytest-style rather than unittest-style.
It also calls setUp() and tearDown() like unittest would."""
# Will be set by __main__.py # Will be set by __main__.py
controllerClass: Type[TController] controllerClass: Type[TController]
@ -84,6 +91,8 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
controller: TController controller: TController
__new__ = object.__new__ # pytest won't collect Generic subclasses otherwise
@staticmethod @staticmethod
def config() -> TestCaseControllerConfig: def config() -> TestCaseControllerConfig:
"""Some configuration to pass to the controllers. """Some configuration to pass to the controllers.
@ -92,21 +101,21 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
""" """
return TestCaseControllerConfig() return TestCaseControllerConfig()
def description(self) -> str:
method_doc = self._testMethodDoc
if not method_doc:
return ""
return "\t" + normalizeWhitespace(
method_doc, removeNewline=False
).strip().replace("\n ", "\n\t")
def setUp(self) -> None: def setUp(self) -> None:
super().setUp()
if self.controllerClass is not None: if self.controllerClass is not None:
self.controller = self.controllerClass(self.config()) self.controller = self.controllerClass(self.config())
if self.show_io: if self.show_io:
print("---- new test ----") print("---- new test ----")
def tearDown(self) -> None:
pass
def setup_method(self, method: Callable) -> None:
self.setUp()
def teardown_method(self, method: Callable) -> None:
self.tearDown()
def assertMessageMatch(self, msg: Message, **kwargs: Any) -> None: def assertMessageMatch(self, msg: Message, **kwargs: Any) -> None:
"""Helper for partially comparing a message. """Helper for partially comparing a message.
@ -117,7 +126,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
""" """
error = self.messageDiffers(msg, **kwargs) error = self.messageDiffers(msg, **kwargs)
if error: if error:
raise self.failureException(error) raise AssertionError(error)
def messageEqual(self, msg: Message, **kwargs: Any) -> bool: def messageEqual(self, msg: Message, **kwargs: Any) -> bool:
"""Boolean negation of `messageDiffers` (returns a boolean, """Boolean negation of `messageDiffers` (returns a boolean,
@ -187,7 +196,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
) -> None: ) -> None:
if fail_msg: if fail_msg:
msg = fail_msg.format(*extra_format, item=member, list=container, msg=msg) msg = fail_msg.format(*extra_format, item=member, list=container, msg=msg)
super().assertIn(member, container, msg) assert member in container, msg # type: ignore
def assertNotIn( def assertNotIn(
self, self,
@ -199,7 +208,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
) -> None: ) -> None:
if fail_msg: if fail_msg:
msg = fail_msg.format(*extra_format, item=member, list=container, msg=msg) msg = fail_msg.format(*extra_format, item=member, list=container, msg=msg)
super().assertNotIn(member, container, msg) assert member not in container, msg # type: ignore
def assertEqual( def assertEqual(
self, self,
@ -211,7 +220,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
) -> None: ) -> None:
if fail_msg: if fail_msg:
msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertEqual(got, expects, msg) assert got == expects, msg
def assertNotEqual( def assertNotEqual(
self, self,
@ -223,7 +232,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
) -> None: ) -> None:
if fail_msg: if fail_msg:
msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertNotEqual(got, expects, msg) assert got != expects, msg
def assertGreater( def assertGreater(
self, self,
@ -235,7 +244,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
) -> None: ) -> None:
if fail_msg: if fail_msg:
msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertGreater(got, expects, msg) assert got >= expects, msg # type: ignore
def assertGreaterEqual( def assertGreaterEqual(
self, self,
@ -247,7 +256,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
) -> None: ) -> None:
if fail_msg: if fail_msg:
msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertGreaterEqual(got, expects, msg) assert got >= expects, msg # type: ignore
def assertLess( def assertLess(
self, self,
@ -259,7 +268,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
) -> None: ) -> None:
if fail_msg: if fail_msg:
msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertLess(got, expects, msg) assert got < expects, msg # type: ignore
def assertLessEqual( def assertLessEqual(
self, self,
@ -271,7 +280,34 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
) -> None: ) -> None:
if fail_msg: if fail_msg:
msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertLessEqual(got, expects, msg) assert got <= expects, msg # type: ignore
def assertTrue(
self,
got: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg:
msg = fail_msg.format(*extra_format, got=got, msg=msg)
assert got, msg
def assertFalse(
self,
got: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg:
msg = fail_msg.format(*extra_format, got=got, msg=msg)
assert not got, msg
@contextlib.contextmanager
def assertRaises(self, exception: Type[Exception]) -> Iterator[None]:
with pytest.raises(exception):
yield
class BaseClientTestCase(_IrcTestCase[basecontrollers.BaseClientController]): class BaseClientTestCase(_IrcTestCase[basecontrollers.BaseClientController]):
@ -285,6 +321,8 @@ class BaseClientTestCase(_IrcTestCase[basecontrollers.BaseClientController]):
protocol_version = Optional[str] protocol_version = Optional[str]
acked_capabilities = Optional[Set[str]] acked_capabilities = Optional[Set[str]]
__new__ = object.__new__ # pytest won't collect Generic[] subclasses otherwise
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
self.conn = None self.conn = None
@ -461,6 +499,8 @@ class BaseServerTestCase(
server_support: Optional[Dict[str, Optional[str]]] server_support: Optional[Dict[str, Optional[str]]]
run_services = False run_services = False
__new__ = object.__new__ # pytest won't collect Generic[] subclasses otherwise
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
self.server_support = None self.server_support = None

View File

@ -157,17 +157,28 @@ MESSAGE_SPECS: List[Tuple[Dict, List[str], List[str]]] = [
class IrcTestCaseTestCase(cases._IrcTestCase): class IrcTestCaseTestCase(cases._IrcTestCase):
def test_message_matching(self): @pytest.mark.parametrize(
for (spec, positive_matches, negative_matches) in MESSAGE_SPECS: "spec,msg",
with self.subTest(spec): [
for msg in positive_matches: pytest.param(spec, msg, id=f"{spec}-{msg}")
with self.subTest(msg): for (spec, positive_matches, _) in MESSAGE_SPECS
for msg in positive_matches
],
)
def test_message_matching_positive(self, spec, msg):
assert not self.messageDiffers(parse_message(msg), **spec), msg assert not self.messageDiffers(parse_message(msg), **spec), msg
assert self.messageEqual(parse_message(msg), **spec), msg assert self.messageEqual(parse_message(msg), **spec), msg
self.assertMessageMatch(parse_message(msg), **spec), msg self.assertMessageMatch(parse_message(msg), **spec), msg
for msg in negative_matches: @pytest.mark.parametrize(
with self.subTest(msg): "spec,msg",
[
pytest.param(spec, msg, id=f"{spec}-{msg}")
for (spec, _, negative_matches) in MESSAGE_SPECS
for msg in negative_matches
],
)
def test_message_matching_negative(self, spec, msg):
assert self.messageDiffers(parse_message(msg), **spec), msg assert self.messageDiffers(parse_message(msg), **spec), msg
assert not self.messageEqual(parse_message(msg), **spec), msg assert not self.messageEqual(parse_message(msg), **spec), msg
with pytest.raises(AssertionError): with pytest.raises(AssertionError):

View File

@ -33,3 +33,5 @@ markers =
BOT BOT
MONITOR MONITOR
STATUSMSG STATUSMSG
python_classes = *TestCase Test*