diff --git a/irctest/cases.py b/irctest/cases.py index 03201c6..6d00d0e 100644 --- a/irctest/cases.py +++ b/irctest/cases.py @@ -6,7 +6,9 @@ from .irc_utils import message_parser class _IrcTestCase(unittest.TestCase): controllerClass = None # Will be set by __main__.py -class ClientTestCase(_IrcTestCase): +class BaseClientTestCase(_IrcTestCase): + """Basic class for client tests. Handles spawning a client and getting + messages from it.""" def setUp(self): self.controller = self.controllerClass() self._setUpServer() @@ -30,3 +32,26 @@ class ClientTestCase(_IrcTestCase): return self.conn_file.readline().strip() def getMessage(self): return message_parser.parse_message(self.conn_file.readline()) + +class NegociationHelper: + """Helper class for tests handling capabilities negociation.""" + def readCapLs(self): + (hostname, port) = self.server.getsockname() + self.controller.run( + hostname=hostname, + port=port, + authentication=None, + ) + self.acceptClient() + m = self.getMessage() + self.assertEqual(m.command, 'CAP', + 'First message is not CAP LS.') + self.assertEqual(m.subcommand, 'LS', + 'First message is not CAP LS.') + if m.params == []: + self.protocol_version = 301 + elif m.params == ['302']: + self.protocol_version = 302 + else: + raise AssertionError('Unknown protocol version {}' + .format(m.params)) diff --git a/irctest/clienttests/test_cap.py b/irctest/clienttests/test_cap.py index 23c24dd..5577ce5 100644 --- a/irctest/clienttests/test_cap.py +++ b/irctest/clienttests/test_cap.py @@ -1,18 +1,6 @@ -from irctest.cases import ClientTestCase +from irctest import cases from irctest.irc_utils.message_parser import Message -class CapTestCase(ClientTestCase): +class CapTestCase(cases.BaseClientTestCase, cases.NegociationHelper): def testSendCap(self): - (hostname, port) = self.server.getsockname() - self.controller.run( - hostname=hostname, - port=port, - authentication=None, - ) - self.acceptClient() - m = self.getMessage() - self.assertEqual(m.command, 'CAP', - 'First message is not CAP LS.') - self.assertEqual(m.subcommand, 'LS', - 'First message is not CAP LS.') - self.assertIn(m.params, ([], ['302'])) # IRCv3.1 or IRVv3.2 + self.readCapLs()