diff --git a/irctest/basecontrollers.py b/irctest/basecontrollers.py index b50353c..5163d41 100644 --- a/irctest/basecontrollers.py +++ b/irctest/basecontrollers.py @@ -60,7 +60,7 @@ class BaseServerController(_BaseController): """Base controller for IRC server.""" def run(self, hostname, port, start_wait): raise NotImplementedError() - def registerUser(self, case, username): + def registerUser(self, case, username, password=None): raise NotImplementedByController('registration') def wait_for_port(self, proc, port): port_open = False diff --git a/irctest/cases.py b/irctest/cases.py index ec099a5..f775de6 100644 --- a/irctest/cases.py +++ b/irctest/cases.py @@ -1,10 +1,13 @@ import time import socket import unittest +import functools import collections from . import authentication +from . import optional_extensions from .irc_utils import message_parser +from .irc_utils import capabilities class _IrcTestCase(unittest.TestCase): """Base class for test cases.""" @@ -59,12 +62,15 @@ class BaseClientTestCase(_IrcTestCase): user = None def setUp(self): super().setUp() + self.conn = None self._setUpServer() def tearDown(self): - self.conn.sendall(b'QUIT :end of test.') + if self.conn: + self.conn.sendall(b'QUIT :end of test.') self.controller.kill() - self.conn_file.close() - self.conn.close() + if self.conn: + self.conn_file.close() + self.conn.close() self.server.close() def _setUpServer(self): @@ -129,7 +135,7 @@ class ClientNegociationHelper: else: return True - def negotiateCapabilities(self, capabilities, cap_ls=True, auth=None): + def negotiateCapabilities(self, caps, cap_ls=True, auth=None): """Performes a complete capability negociation process, without ending it, so the caller can continue the negociation.""" if cap_ls: @@ -137,8 +143,8 @@ class ClientNegociationHelper: if not self.protocol_version: # No negotiation. return - self.sendLine('CAP * LS :{}'.format(' '.join(capabilities))) - capability_names = {x.split('=')[0] for x in capabilities} + self.sendLine('CAP * LS :{}'.format(' '.join(caps))) + capability_names = frozenset(capabilities.cap_list_to_dict(caps)) self.acked_capabilities = set() while True: m = self.getMessage(filter_pred=self.userNickPredicate) @@ -191,11 +197,15 @@ class BaseServerTestCase(_IrcTestCase): conn.connect((self.hostname, self.port)) conn_file = conn.makefile(newline='\r\n', encoding='utf8') self.clients[name] = Client(conn=conn, conn_file=conn_file) + if self.show_io: + print('{}: connects to server.'.format(name)) return name def removeClient(self, name): """Disconnects the client, without QUIT.""" assert name in self.clients + if self.show_io: + print('{}: disconnects from server.'.format(name)) self.clients[name].conn.close() del self.clients[name] @@ -209,7 +219,7 @@ class BaseServerTestCase(_IrcTestCase): data += conn.recv(4096) except BlockingIOError: for line in data.decode().split('\r\n'): - if line: + if line and self.show_io: print('S -> {}: {}'.format(client, line.strip())) yield line + '\r\n' finally: @@ -229,16 +239,37 @@ class BaseServerTestCase(_IrcTestCase): if self.show_io: print('{} -> S: {}'.format(client, line.strip())) - def getCapLs(self, client): + def getCapLs(self, client, as_list=False): """Waits for a CAP LS block, parses all CAP LS messages, and return - the list of capabilities.""" - capabilities = [] + the dict capabilities, with their values. + + If as_list is given, returns the raw list (ie. key/value not split) + in case the order matters (but it shouldn't).""" + caps = [] while True: m = self.getMessage(client, filter_pred=lambda m:m.command != 'NOTICE') self.assertMessageEqual(m, command='CAP', subcommand='LS') if m.params[2] == '*': - capabilities.extend(m.params[3].split()) + caps.extend(m.params[3].split()) else: - capabilities.extend(m.params[2].split()) - return capabilities + caps.extend(m.params[2].split()) + if not as_list: + caps = capabilities.cap_list_to_dict(caps) + return caps + +class OptionalityHelper: + def checkMechanismSupport(self, mechanism): + if mechanism in self.controller.supported_sasl_mechanisms: + return + raise optional_extensions.OptionalSaslMechanismNotSupported(mechanism) + + def skipUnlessHasMechanism(mech): + def decorator(f): + @functools.wraps(f) + def newf(self): + self.checkMechanismSupport(mech) + return f(self) + return newf + return decorator + diff --git a/irctest/client_tests/test_sasl.py b/irctest/client_tests/test_sasl.py index 126406f..3885ecb 100644 --- a/irctest/client_tests/test_sasl.py +++ b/irctest/client_tests/test_sasl.py @@ -2,7 +2,6 @@ import ecdsa import base64 from irctest import cases from irctest import authentication -from irctest import optional_extensions from irctest.irc_utils.message_parser import Message ECDSA_KEY = """ @@ -16,14 +15,9 @@ IRX9cyi2wdYg9mUUYyh9GKdBCYHGUJAiCA== -----END EC PRIVATE KEY----- """ -class SaslMechanismCheck: - def checkMechanismSupport(self, mechanism): - if mechanism in self.controller.supported_sasl_mechanisms: - return - raise optional_extensions.OptionalSaslMechanismNotSupported(mechanism) - class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, - SaslMechanismCheck): + cases.OptionalityHelper): + @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') def testPlain(self): """Test PLAIN authentication.""" auth = authentication.Authentication( @@ -32,7 +26,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, password='sesame', ) m = self.negotiateCapabilities(['sasl'], auth=auth) - self.checkMechanismSupport('PLAIN') self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN'])) self.sendLine('AUTHENTICATE +') m = self.getMessage() @@ -43,6 +36,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, m = self.negotiateCapabilities(['sasl'], False) self.assertEqual(m, Message([], None, 'CAP', ['END'])) + @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') def testPlainNotAvailable(self): """Test the client handles gracefully servers that don't provide a mechanism it could use.""" @@ -52,7 +46,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, password='sesame', ) m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth) - self.checkMechanismSupport('PLAIN') self.assertEqual(self.acked_capabilities, {'sasl'}) if m == Message([], None, 'CAP', ['END']): # IRCv3.2-style @@ -63,6 +56,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, self.assertMessageEqual(m, command='CAP') + @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') def testPlainLarge(self): """Test the client splits large AUTHENTICATE messages whose payload is not a multiple of 400.""" @@ -75,7 +69,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, authstring = base64.b64encode(b'\x00'.join( [b'foo', b'foo', b'bar'*200])).decode() m = self.negotiateCapabilities(['sasl'], auth=auth) - self.checkMechanismSupport('PLAIN') self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN'])) self.sendLine('AUTHENTICATE +') m = self.getMessage() @@ -92,6 +85,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, m = self.negotiateCapabilities(['sasl'], False) self.assertEqual(m, Message([], None, 'CAP', ['END'])) + @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') def testPlainLargeMultiple(self): """Test the client splits large AUTHENTICATE messages whose payload is a multiple of 400.""" @@ -104,7 +98,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, authstring = base64.b64encode(b'\x00'.join( [b'foo', b'foo', b'quux'*148])).decode() m = self.negotiateCapabilities(['sasl'], auth=auth) - self.checkMechanismSupport('PLAIN') self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN'])) self.sendLine('AUTHENTICATE +') m = self.getMessage() @@ -121,6 +114,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, m = self.negotiateCapabilities(['sasl'], False) self.assertEqual(m, Message([], None, 'CAP', ['END'])) + @cases.OptionalityHelper.skipUnlessHasMechanism('ECDSA-NIST256P-CHALLENGE') def testEcdsa(self): """Test ECDSA authentication.""" auth = authentication.Authentication( @@ -129,7 +123,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, ecdsa_key=ECDSA_KEY, ) m = self.negotiateCapabilities(['sasl'], auth=auth) - self.checkMechanismSupport('ECDSA-NIST256P-CHALLENGE') self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['ECDSA-NIST256P-CHALLENGE'])) self.sendLine('AUTHENTICATE +') m = self.getMessage() @@ -151,7 +144,8 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, self.assertEqual(m, Message([], None, 'CAP', ['END'])) class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, - SaslMechanismCheck): + cases.OptionalityHelper): + @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') def testPlainNotAvailable(self): """Test the client does not try to authenticate using a mechanism the server does not advertise.""" @@ -161,6 +155,5 @@ class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper password='sesame', ) m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth) - self.checkMechanismSupport('PLAIN') self.assertEqual(self.acked_capabilities, {'sasl'}) self.assertEqual(m, Message([], None, 'CAP', ['END'])) diff --git a/irctest/controllers/inspircd.py b/irctest/controllers/inspircd.py index fa22f7e..719309c 100644 --- a/irctest/controllers/inspircd.py +++ b/irctest/controllers/inspircd.py @@ -17,6 +17,7 @@ TEMPLATE_CONFIG = """ """ class InspircdController(BaseServerController, DirectoryBasedController): + supported_sasl_mechanisms = {} def create_config(self): super().create_config() with self.open_file('server.conf'): diff --git a/irctest/controllers/mammon.py b/irctest/controllers/mammon.py index 4574fed..b9d3d0e 100644 --- a/irctest/controllers/mammon.py +++ b/irctest/controllers/mammon.py @@ -55,6 +55,9 @@ server: """ class MammonController(BaseServerController, DirectoryBasedController): + supported_sasl_mechanisms = { + 'PLAIN', 'ECDSA-NIST256P-CHALLENGE', + } def create_config(self): super().create_config() with self.open_file('server.conf'): @@ -77,7 +80,7 @@ class MammonController(BaseServerController, DirectoryBasedController): '--config', os.path.join(self.directory, 'server.yml')]) self.wait_for_port(self.proc, port) - def registerUser(self, case, username): + def registerUser(self, case, username, password=None): # XXX: Move this somewhere else when # https://github.com/ircv3/ircv3-specifications/pull/152 becomes # part of the specification @@ -87,7 +90,8 @@ class MammonController(BaseServerController, DirectoryBasedController): case.sendLine(client, 'USER r e g :user') case.sendLine(client, 'CAP END') list(case.getLines(client)) - case.sendLine(client, 'REG CREATE {} passphrase temporarypassword'.format(username)) + case.sendLine(client, 'REG CREATE {} passphrase {}'.format( + username, password)) msg = case.getMessage(client) assert msg.command == '920' list(case.getLines(client)) diff --git a/irctest/server_tests/test_registration.py b/irctest/server_tests/test_registration.py index 35cfad2..5725855 100644 --- a/irctest/server_tests/test_registration.py +++ b/irctest/server_tests/test_registration.py @@ -3,4 +3,25 @@ from irctest.irc_utils.message_parser import Message class RegistrationTestCase(cases.BaseServerTestCase): def testRegistration(self): - self.controller.registerUser(self, 'testuser') + self.controller.registerUser(self, 'testuser', 'mypassword') + +class SaslTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): + @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') + def testPlain(self): + """Test PLAIN authentication.""" + self.controller.registerUser(self, 'foo', 'sesame') + self.controller.registerUser(self, 'jilles', 'sesame') + self.controller.registerUser(self, 'bar', 'sesame') + self.addClient() + self.sendLine(1, 'CAP LS 302') + capabilities = self.getCapLs(1) + self.assertIn('sasl', capabilities) + if capabilities['sasl'] is not None: + self.assertIn('PLAIN', capabilities['sasl']) + self.sendLine(1, 'AUTHENTICATE PLAIN') + m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE') + self.assertMessageEqual(m, command='AUTHENTICATE', params=['+']) + self.sendLine(1, 'AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=') + m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE') + self.assertMessageEqual(m, command='900') + self.assertEqual(m.params[2], 'jilles', m)