Add PLAIN test for servers.

This commit is contained in:
Valentin Lorentz
2015-12-20 15:11:56 +01:00
parent d705c4ffbe
commit 38569f013f
6 changed files with 82 additions and 32 deletions

View File

@ -60,7 +60,7 @@ class BaseServerController(_BaseController):
"""Base controller for IRC server.""" """Base controller for IRC server."""
def run(self, hostname, port, start_wait): def run(self, hostname, port, start_wait):
raise NotImplementedError() raise NotImplementedError()
def registerUser(self, case, username): def registerUser(self, case, username, password=None):
raise NotImplementedByController('registration') raise NotImplementedByController('registration')
def wait_for_port(self, proc, port): def wait_for_port(self, proc, port):
port_open = False port_open = False

View File

@ -1,10 +1,13 @@
import time import time
import socket import socket
import unittest import unittest
import functools
import collections import collections
from . import authentication from . import authentication
from . import optional_extensions
from .irc_utils import message_parser from .irc_utils import message_parser
from .irc_utils import capabilities
class _IrcTestCase(unittest.TestCase): class _IrcTestCase(unittest.TestCase):
"""Base class for test cases.""" """Base class for test cases."""
@ -59,10 +62,13 @@ class BaseClientTestCase(_IrcTestCase):
user = None user = None
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.conn = None
self._setUpServer() self._setUpServer()
def tearDown(self): def tearDown(self):
if self.conn:
self.conn.sendall(b'QUIT :end of test.') self.conn.sendall(b'QUIT :end of test.')
self.controller.kill() self.controller.kill()
if self.conn:
self.conn_file.close() self.conn_file.close()
self.conn.close() self.conn.close()
self.server.close() self.server.close()
@ -129,7 +135,7 @@ class ClientNegociationHelper:
else: else:
return True return True
def negotiateCapabilities(self, capabilities, cap_ls=True, auth=None): def negotiateCapabilities(self, caps, cap_ls=True, auth=None):
"""Performes a complete capability negociation process, without """Performes a complete capability negociation process, without
ending it, so the caller can continue the negociation.""" ending it, so the caller can continue the negociation."""
if cap_ls: if cap_ls:
@ -137,8 +143,8 @@ class ClientNegociationHelper:
if not self.protocol_version: if not self.protocol_version:
# No negotiation. # No negotiation.
return return
self.sendLine('CAP * LS :{}'.format(' '.join(capabilities))) self.sendLine('CAP * LS :{}'.format(' '.join(caps)))
capability_names = {x.split('=')[0] for x in capabilities} capability_names = frozenset(capabilities.cap_list_to_dict(caps))
self.acked_capabilities = set() self.acked_capabilities = set()
while True: while True:
m = self.getMessage(filter_pred=self.userNickPredicate) m = self.getMessage(filter_pred=self.userNickPredicate)
@ -191,11 +197,15 @@ class BaseServerTestCase(_IrcTestCase):
conn.connect((self.hostname, self.port)) conn.connect((self.hostname, self.port))
conn_file = conn.makefile(newline='\r\n', encoding='utf8') conn_file = conn.makefile(newline='\r\n', encoding='utf8')
self.clients[name] = Client(conn=conn, conn_file=conn_file) self.clients[name] = Client(conn=conn, conn_file=conn_file)
if self.show_io:
print('{}: connects to server.'.format(name))
return name return name
def removeClient(self, name): def removeClient(self, name):
"""Disconnects the client, without QUIT.""" """Disconnects the client, without QUIT."""
assert name in self.clients assert name in self.clients
if self.show_io:
print('{}: disconnects from server.'.format(name))
self.clients[name].conn.close() self.clients[name].conn.close()
del self.clients[name] del self.clients[name]
@ -209,7 +219,7 @@ class BaseServerTestCase(_IrcTestCase):
data += conn.recv(4096) data += conn.recv(4096)
except BlockingIOError: except BlockingIOError:
for line in data.decode().split('\r\n'): for line in data.decode().split('\r\n'):
if line: if line and self.show_io:
print('S -> {}: {}'.format(client, line.strip())) print('S -> {}: {}'.format(client, line.strip()))
yield line + '\r\n' yield line + '\r\n'
finally: finally:
@ -229,16 +239,37 @@ class BaseServerTestCase(_IrcTestCase):
if self.show_io: if self.show_io:
print('{} -> S: {}'.format(client, line.strip())) print('{} -> S: {}'.format(client, line.strip()))
def getCapLs(self, client): def getCapLs(self, client, as_list=False):
"""Waits for a CAP LS block, parses all CAP LS messages, and return """Waits for a CAP LS block, parses all CAP LS messages, and return
the list of capabilities.""" the dict capabilities, with their values.
capabilities = []
If as_list is given, returns the raw list (ie. key/value not split)
in case the order matters (but it shouldn't)."""
caps = []
while True: while True:
m = self.getMessage(client, m = self.getMessage(client,
filter_pred=lambda m:m.command != 'NOTICE') filter_pred=lambda m:m.command != 'NOTICE')
self.assertMessageEqual(m, command='CAP', subcommand='LS') self.assertMessageEqual(m, command='CAP', subcommand='LS')
if m.params[2] == '*': if m.params[2] == '*':
capabilities.extend(m.params[3].split()) caps.extend(m.params[3].split())
else: else:
capabilities.extend(m.params[2].split()) caps.extend(m.params[2].split())
return capabilities if not as_list:
caps = capabilities.cap_list_to_dict(caps)
return caps
class OptionalityHelper:
def checkMechanismSupport(self, mechanism):
if mechanism in self.controller.supported_sasl_mechanisms:
return
raise optional_extensions.OptionalSaslMechanismNotSupported(mechanism)
def skipUnlessHasMechanism(mech):
def decorator(f):
@functools.wraps(f)
def newf(self):
self.checkMechanismSupport(mech)
return f(self)
return newf
return decorator

View File

@ -2,7 +2,6 @@ import ecdsa
import base64 import base64
from irctest import cases from irctest import cases
from irctest import authentication from irctest import authentication
from irctest import optional_extensions
from irctest.irc_utils.message_parser import Message from irctest.irc_utils.message_parser import Message
ECDSA_KEY = """ ECDSA_KEY = """
@ -16,14 +15,9 @@ IRX9cyi2wdYg9mUUYyh9GKdBCYHGUJAiCA==
-----END EC PRIVATE KEY----- -----END EC PRIVATE KEY-----
""" """
class SaslMechanismCheck:
def checkMechanismSupport(self, mechanism):
if mechanism in self.controller.supported_sasl_mechanisms:
return
raise optional_extensions.OptionalSaslMechanismNotSupported(mechanism)
class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
SaslMechanismCheck): cases.OptionalityHelper):
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlain(self): def testPlain(self):
"""Test PLAIN authentication.""" """Test PLAIN authentication."""
auth = authentication.Authentication( auth = authentication.Authentication(
@ -32,7 +26,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
password='sesame', password='sesame',
) )
m = self.negotiateCapabilities(['sasl'], auth=auth) m = self.negotiateCapabilities(['sasl'], auth=auth)
self.checkMechanismSupport('PLAIN')
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN'])) self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
self.sendLine('AUTHENTICATE +') self.sendLine('AUTHENTICATE +')
m = self.getMessage() m = self.getMessage()
@ -43,6 +36,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
m = self.negotiateCapabilities(['sasl'], False) m = self.negotiateCapabilities(['sasl'], False)
self.assertEqual(m, Message([], None, 'CAP', ['END'])) self.assertEqual(m, Message([], None, 'CAP', ['END']))
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlainNotAvailable(self): def testPlainNotAvailable(self):
"""Test the client handles gracefully servers that don't provide a """Test the client handles gracefully servers that don't provide a
mechanism it could use.""" mechanism it could use."""
@ -52,7 +46,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
password='sesame', password='sesame',
) )
m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth) m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth)
self.checkMechanismSupport('PLAIN')
self.assertEqual(self.acked_capabilities, {'sasl'}) self.assertEqual(self.acked_capabilities, {'sasl'})
if m == Message([], None, 'CAP', ['END']): if m == Message([], None, 'CAP', ['END']):
# IRCv3.2-style # IRCv3.2-style
@ -63,6 +56,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
self.assertMessageEqual(m, command='CAP') self.assertMessageEqual(m, command='CAP')
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlainLarge(self): def testPlainLarge(self):
"""Test the client splits large AUTHENTICATE messages whose payload """Test the client splits large AUTHENTICATE messages whose payload
is not a multiple of 400.""" is not a multiple of 400."""
@ -75,7 +69,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
authstring = base64.b64encode(b'\x00'.join( authstring = base64.b64encode(b'\x00'.join(
[b'foo', b'foo', b'bar'*200])).decode() [b'foo', b'foo', b'bar'*200])).decode()
m = self.negotiateCapabilities(['sasl'], auth=auth) m = self.negotiateCapabilities(['sasl'], auth=auth)
self.checkMechanismSupport('PLAIN')
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN'])) self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
self.sendLine('AUTHENTICATE +') self.sendLine('AUTHENTICATE +')
m = self.getMessage() m = self.getMessage()
@ -92,6 +85,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
m = self.negotiateCapabilities(['sasl'], False) m = self.negotiateCapabilities(['sasl'], False)
self.assertEqual(m, Message([], None, 'CAP', ['END'])) self.assertEqual(m, Message([], None, 'CAP', ['END']))
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlainLargeMultiple(self): def testPlainLargeMultiple(self):
"""Test the client splits large AUTHENTICATE messages whose payload """Test the client splits large AUTHENTICATE messages whose payload
is a multiple of 400.""" is a multiple of 400."""
@ -104,7 +98,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
authstring = base64.b64encode(b'\x00'.join( authstring = base64.b64encode(b'\x00'.join(
[b'foo', b'foo', b'quux'*148])).decode() [b'foo', b'foo', b'quux'*148])).decode()
m = self.negotiateCapabilities(['sasl'], auth=auth) m = self.negotiateCapabilities(['sasl'], auth=auth)
self.checkMechanismSupport('PLAIN')
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN'])) self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
self.sendLine('AUTHENTICATE +') self.sendLine('AUTHENTICATE +')
m = self.getMessage() m = self.getMessage()
@ -121,6 +114,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
m = self.negotiateCapabilities(['sasl'], False) m = self.negotiateCapabilities(['sasl'], False)
self.assertEqual(m, Message([], None, 'CAP', ['END'])) self.assertEqual(m, Message([], None, 'CAP', ['END']))
@cases.OptionalityHelper.skipUnlessHasMechanism('ECDSA-NIST256P-CHALLENGE')
def testEcdsa(self): def testEcdsa(self):
"""Test ECDSA authentication.""" """Test ECDSA authentication."""
auth = authentication.Authentication( auth = authentication.Authentication(
@ -129,7 +123,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
ecdsa_key=ECDSA_KEY, ecdsa_key=ECDSA_KEY,
) )
m = self.negotiateCapabilities(['sasl'], auth=auth) m = self.negotiateCapabilities(['sasl'], auth=auth)
self.checkMechanismSupport('ECDSA-NIST256P-CHALLENGE')
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['ECDSA-NIST256P-CHALLENGE'])) self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['ECDSA-NIST256P-CHALLENGE']))
self.sendLine('AUTHENTICATE +') self.sendLine('AUTHENTICATE +')
m = self.getMessage() m = self.getMessage()
@ -151,7 +144,8 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
self.assertEqual(m, Message([], None, 'CAP', ['END'])) self.assertEqual(m, Message([], None, 'CAP', ['END']))
class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
SaslMechanismCheck): cases.OptionalityHelper):
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlainNotAvailable(self): def testPlainNotAvailable(self):
"""Test the client does not try to authenticate using a mechanism the """Test the client does not try to authenticate using a mechanism the
server does not advertise.""" server does not advertise."""
@ -161,6 +155,5 @@ class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper
password='sesame', password='sesame',
) )
m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth) m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth)
self.checkMechanismSupport('PLAIN')
self.assertEqual(self.acked_capabilities, {'sasl'}) self.assertEqual(self.acked_capabilities, {'sasl'})
self.assertEqual(m, Message([], None, 'CAP', ['END'])) self.assertEqual(m, Message([], None, 'CAP', ['END']))

View File

@ -17,6 +17,7 @@ TEMPLATE_CONFIG = """
""" """
class InspircdController(BaseServerController, DirectoryBasedController): class InspircdController(BaseServerController, DirectoryBasedController):
supported_sasl_mechanisms = {}
def create_config(self): def create_config(self):
super().create_config() super().create_config()
with self.open_file('server.conf'): with self.open_file('server.conf'):

View File

@ -55,6 +55,9 @@ server:
""" """
class MammonController(BaseServerController, DirectoryBasedController): class MammonController(BaseServerController, DirectoryBasedController):
supported_sasl_mechanisms = {
'PLAIN', 'ECDSA-NIST256P-CHALLENGE',
}
def create_config(self): def create_config(self):
super().create_config() super().create_config()
with self.open_file('server.conf'): with self.open_file('server.conf'):
@ -77,7 +80,7 @@ class MammonController(BaseServerController, DirectoryBasedController):
'--config', os.path.join(self.directory, 'server.yml')]) '--config', os.path.join(self.directory, 'server.yml')])
self.wait_for_port(self.proc, port) self.wait_for_port(self.proc, port)
def registerUser(self, case, username): def registerUser(self, case, username, password=None):
# XXX: Move this somewhere else when # XXX: Move this somewhere else when
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes # https://github.com/ircv3/ircv3-specifications/pull/152 becomes
# part of the specification # part of the specification
@ -87,7 +90,8 @@ class MammonController(BaseServerController, DirectoryBasedController):
case.sendLine(client, 'USER r e g :user') case.sendLine(client, 'USER r e g :user')
case.sendLine(client, 'CAP END') case.sendLine(client, 'CAP END')
list(case.getLines(client)) list(case.getLines(client))
case.sendLine(client, 'REG CREATE {} passphrase temporarypassword'.format(username)) case.sendLine(client, 'REG CREATE {} passphrase {}'.format(
username, password))
msg = case.getMessage(client) msg = case.getMessage(client)
assert msg.command == '920' assert msg.command == '920'
list(case.getLines(client)) list(case.getLines(client))

View File

@ -3,4 +3,25 @@ from irctest.irc_utils.message_parser import Message
class RegistrationTestCase(cases.BaseServerTestCase): class RegistrationTestCase(cases.BaseServerTestCase):
def testRegistration(self): def testRegistration(self):
self.controller.registerUser(self, 'testuser') self.controller.registerUser(self, 'testuser', 'mypassword')
class SaslTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlain(self):
"""Test PLAIN authentication."""
self.controller.registerUser(self, 'foo', 'sesame')
self.controller.registerUser(self, 'jilles', 'sesame')
self.controller.registerUser(self, 'bar', 'sesame')
self.addClient()
self.sendLine(1, 'CAP LS 302')
capabilities = self.getCapLs(1)
self.assertIn('sasl', capabilities)
if capabilities['sasl'] is not None:
self.assertIn('PLAIN', capabilities['sasl'])
self.sendLine(1, 'AUTHENTICATE PLAIN')
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'])
self.sendLine(1, 'AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=')
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
self.assertMessageEqual(m, command='900')
self.assertEqual(m.params[2], 'jilles', m)