diff --git a/conftest.py b/conftest.py index af3e5dc..0a65b3f 100644 --- a/conftest.py +++ b/conftest.py @@ -1,5 +1,4 @@ import importlib -import unittest import _pytest.unittest import pytest @@ -77,14 +76,13 @@ def pytest_collection_modifyitems(session, config, items): # Iterate over each of the test functions (they are pytest "Nodes") for item in items: - # we only use unittest-style test function here - assert isinstance(item, _pytest.unittest.TestCaseFunction) + assert isinstance(item, _pytest.python.Function) # unittest-style test functions have the node of UnitTest class as parent - assert isinstance(item.parent, _pytest.unittest.UnitTestCase) + assert isinstance(item.parent, _pytest.python.Instance) # and that node references the UnitTest class - assert issubclass(item.parent.cls, unittest.TestCase) + assert issubclass(item.parent.cls, _IrcTestCase) # and in this project, TestCase classes all inherit either from # BaseClientController or BaseServerController. diff --git a/irctest/cases.py b/irctest/cases.py index 9e3fddc..3070069 100644 --- a/irctest/cases.py +++ b/irctest/cases.py @@ -1,3 +1,4 @@ +import contextlib import functools import socket import ssl @@ -11,6 +12,7 @@ from typing import ( Generic, Hashable, Iterable, + Iterator, List, Optional, Set, @@ -20,7 +22,6 @@ from typing import ( Union, cast, ) -import unittest import pytest @@ -29,7 +30,7 @@ from .authentication import Authentication from .basecontrollers import TestCaseControllerConfig from .exceptions import ConnectionClosed from .irc_utils import capabilities, message_parser -from .irc_utils.junkdrawer import find_hostname_and_port, normalizeWhitespace +from .irc_utils.junkdrawer import find_hostname_and_port from .irc_utils.message_parser import Message from .irc_utils.sasl import sasl_plain_blob from .numerics import ( @@ -75,8 +76,14 @@ class ChannelJoinException(Exception): self.params = params -class _IrcTestCase(unittest.TestCase, Generic[TController]): - """Base class for test cases.""" +class _IrcTestCase(Generic[TController]): + """Base class for test cases. + + It implements various `assert*` method that look like unittest's, + but is actually based on the `assert` statement so derived classes are + pytest-style rather than unittest-style. + + It also calls setUp() and tearDown() like unittest would.""" # Will be set by __main__.py controllerClass: Type[TController] @@ -84,6 +91,8 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]): controller: TController + __new__ = object.__new__ # pytest won't collect Generic subclasses otherwise + @staticmethod def config() -> TestCaseControllerConfig: """Some configuration to pass to the controllers. @@ -92,21 +101,21 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]): """ return TestCaseControllerConfig() - def description(self) -> str: - method_doc = self._testMethodDoc - if not method_doc: - return "" - return "\t" + normalizeWhitespace( - method_doc, removeNewline=False - ).strip().replace("\n ", "\n\t") - def setUp(self) -> None: - super().setUp() if self.controllerClass is not None: self.controller = self.controllerClass(self.config()) if self.show_io: print("---- new test ----") + def tearDown(self) -> None: + pass + + def setup_method(self, method: Callable) -> None: + self.setUp() + + def teardown_method(self, method: Callable) -> None: + self.tearDown() + def assertMessageMatch(self, msg: Message, **kwargs: Any) -> None: """Helper for partially comparing a message. @@ -117,7 +126,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]): """ error = self.messageDiffers(msg, **kwargs) if error: - raise self.failureException(error) + raise AssertionError(error) def messageEqual(self, msg: Message, **kwargs: Any) -> bool: """Boolean negation of `messageDiffers` (returns a boolean, @@ -187,7 +196,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]): ) -> None: if fail_msg: msg = fail_msg.format(*extra_format, item=member, list=container, msg=msg) - super().assertIn(member, container, msg) + assert member in container, msg # type: ignore def assertNotIn( self, @@ -199,7 +208,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]): ) -> None: if fail_msg: msg = fail_msg.format(*extra_format, item=member, list=container, msg=msg) - super().assertNotIn(member, container, msg) + assert member not in container, msg # type: ignore def assertEqual( self, @@ -211,7 +220,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]): ) -> None: if fail_msg: msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) - super().assertEqual(got, expects, msg) + assert got == expects, msg def assertNotEqual( self, @@ -223,7 +232,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]): ) -> None: if fail_msg: msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) - super().assertNotEqual(got, expects, msg) + assert got != expects, msg def assertGreater( self, @@ -235,7 +244,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]): ) -> None: if fail_msg: msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) - super().assertGreater(got, expects, msg) + assert got >= expects, msg # type: ignore def assertGreaterEqual( self, @@ -247,7 +256,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]): ) -> None: if fail_msg: msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) - super().assertGreaterEqual(got, expects, msg) + assert got >= expects, msg # type: ignore def assertLess( self, @@ -259,7 +268,7 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]): ) -> None: if fail_msg: msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) - super().assertLess(got, expects, msg) + assert got < expects, msg # type: ignore def assertLessEqual( self, @@ -271,7 +280,34 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]): ) -> None: if fail_msg: msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) - super().assertLessEqual(got, expects, msg) + assert got <= expects, msg # type: ignore + + def assertTrue( + self, + got: T, + msg: Any = None, + fail_msg: Optional[str] = None, + extra_format: Tuple = (), + ) -> None: + if fail_msg: + msg = fail_msg.format(*extra_format, got=got, msg=msg) + assert got, msg + + def assertFalse( + self, + got: T, + msg: Any = None, + fail_msg: Optional[str] = None, + extra_format: Tuple = (), + ) -> None: + if fail_msg: + msg = fail_msg.format(*extra_format, got=got, msg=msg) + assert not got, msg + + @contextlib.contextmanager + def assertRaises(self, exception: Type[Exception]) -> Iterator[None]: + with pytest.raises(exception): + yield class BaseClientTestCase(_IrcTestCase[basecontrollers.BaseClientController]): @@ -285,6 +321,8 @@ class BaseClientTestCase(_IrcTestCase[basecontrollers.BaseClientController]): protocol_version = Optional[str] acked_capabilities = Optional[Set[str]] + __new__ = object.__new__ # pytest won't collect Generic[] subclasses otherwise + def setUp(self) -> None: super().setUp() self.conn = None @@ -461,6 +499,8 @@ class BaseServerTestCase( server_support: Optional[Dict[str, Optional[str]]] run_services = False + __new__ = object.__new__ # pytest won't collect Generic[] subclasses otherwise + def setUp(self) -> None: super().setUp() self.server_support = None diff --git a/irctest/self_tests/test_cases.py b/irctest/self_tests/test_cases.py index e31ea51..2681987 100644 --- a/irctest/self_tests/test_cases.py +++ b/irctest/self_tests/test_cases.py @@ -157,18 +157,29 @@ MESSAGE_SPECS: List[Tuple[Dict, List[str], List[str]]] = [ class IrcTestCaseTestCase(cases._IrcTestCase): - def test_message_matching(self): - for (spec, positive_matches, negative_matches) in MESSAGE_SPECS: - with self.subTest(spec): - for msg in positive_matches: - with self.subTest(msg): - assert not self.messageDiffers(parse_message(msg), **spec), msg - assert self.messageEqual(parse_message(msg), **spec), msg - self.assertMessageMatch(parse_message(msg), **spec), msg + @pytest.mark.parametrize( + "spec,msg", + [ + pytest.param(spec, msg, id=f"{spec}-{msg}") + for (spec, positive_matches, _) in MESSAGE_SPECS + for msg in positive_matches + ], + ) + def test_message_matching_positive(self, spec, msg): + assert not self.messageDiffers(parse_message(msg), **spec), msg + assert self.messageEqual(parse_message(msg), **spec), msg + self.assertMessageMatch(parse_message(msg), **spec), msg - for msg in negative_matches: - with self.subTest(msg): - assert self.messageDiffers(parse_message(msg), **spec), msg - assert not self.messageEqual(parse_message(msg), **spec), msg - with pytest.raises(AssertionError): - self.assertMessageMatch(parse_message(msg), **spec), msg + @pytest.mark.parametrize( + "spec,msg", + [ + pytest.param(spec, msg, id=f"{spec}-{msg}") + for (spec, _, negative_matches) in MESSAGE_SPECS + for msg in negative_matches + ], + ) + def test_message_matching_negative(self, spec, msg): + assert self.messageDiffers(parse_message(msg), **spec), msg + assert not self.messageEqual(parse_message(msg), **spec), msg + with pytest.raises(AssertionError): + self.assertMessageMatch(parse_message(msg), **spec), msg diff --git a/pytest.ini b/pytest.ini index 6a525fb..a715143 100644 --- a/pytest.ini +++ b/pytest.ini @@ -33,3 +33,5 @@ markers = BOT MONITOR STATUSMSG + +python_classes = *TestCase Test*