Add PLAIN test for servers.

This commit is contained in:
Valentin Lorentz
2015-12-20 15:11:56 +01:00
parent d705c4ffbe
commit 38569f013f
6 changed files with 82 additions and 32 deletions

View File

@ -60,7 +60,7 @@ class BaseServerController(_BaseController):
"""Base controller for IRC server."""
def run(self, hostname, port, start_wait):
raise NotImplementedError()
def registerUser(self, case, username):
def registerUser(self, case, username, password=None):
raise NotImplementedByController('registration')
def wait_for_port(self, proc, port):
port_open = False

View File

@ -1,10 +1,13 @@
import time
import socket
import unittest
import functools
import collections
from . import authentication
from . import optional_extensions
from .irc_utils import message_parser
from .irc_utils import capabilities
class _IrcTestCase(unittest.TestCase):
"""Base class for test cases."""
@ -59,10 +62,13 @@ class BaseClientTestCase(_IrcTestCase):
user = None
def setUp(self):
super().setUp()
self.conn = None
self._setUpServer()
def tearDown(self):
if self.conn:
self.conn.sendall(b'QUIT :end of test.')
self.controller.kill()
if self.conn:
self.conn_file.close()
self.conn.close()
self.server.close()
@ -129,7 +135,7 @@ class ClientNegociationHelper:
else:
return True
def negotiateCapabilities(self, capabilities, cap_ls=True, auth=None):
def negotiateCapabilities(self, caps, cap_ls=True, auth=None):
"""Performes a complete capability negociation process, without
ending it, so the caller can continue the negociation."""
if cap_ls:
@ -137,8 +143,8 @@ class ClientNegociationHelper:
if not self.protocol_version:
# No negotiation.
return
self.sendLine('CAP * LS :{}'.format(' '.join(capabilities)))
capability_names = {x.split('=')[0] for x in capabilities}
self.sendLine('CAP * LS :{}'.format(' '.join(caps)))
capability_names = frozenset(capabilities.cap_list_to_dict(caps))
self.acked_capabilities = set()
while True:
m = self.getMessage(filter_pred=self.userNickPredicate)
@ -191,11 +197,15 @@ class BaseServerTestCase(_IrcTestCase):
conn.connect((self.hostname, self.port))
conn_file = conn.makefile(newline='\r\n', encoding='utf8')
self.clients[name] = Client(conn=conn, conn_file=conn_file)
if self.show_io:
print('{}: connects to server.'.format(name))
return name
def removeClient(self, name):
"""Disconnects the client, without QUIT."""
assert name in self.clients
if self.show_io:
print('{}: disconnects from server.'.format(name))
self.clients[name].conn.close()
del self.clients[name]
@ -209,7 +219,7 @@ class BaseServerTestCase(_IrcTestCase):
data += conn.recv(4096)
except BlockingIOError:
for line in data.decode().split('\r\n'):
if line:
if line and self.show_io:
print('S -> {}: {}'.format(client, line.strip()))
yield line + '\r\n'
finally:
@ -229,16 +239,37 @@ class BaseServerTestCase(_IrcTestCase):
if self.show_io:
print('{} -> S: {}'.format(client, line.strip()))
def getCapLs(self, client):
def getCapLs(self, client, as_list=False):
"""Waits for a CAP LS block, parses all CAP LS messages, and return
the list of capabilities."""
capabilities = []
the dict capabilities, with their values.
If as_list is given, returns the raw list (ie. key/value not split)
in case the order matters (but it shouldn't)."""
caps = []
while True:
m = self.getMessage(client,
filter_pred=lambda m:m.command != 'NOTICE')
self.assertMessageEqual(m, command='CAP', subcommand='LS')
if m.params[2] == '*':
capabilities.extend(m.params[3].split())
caps.extend(m.params[3].split())
else:
capabilities.extend(m.params[2].split())
return capabilities
caps.extend(m.params[2].split())
if not as_list:
caps = capabilities.cap_list_to_dict(caps)
return caps
class OptionalityHelper:
def checkMechanismSupport(self, mechanism):
if mechanism in self.controller.supported_sasl_mechanisms:
return
raise optional_extensions.OptionalSaslMechanismNotSupported(mechanism)
def skipUnlessHasMechanism(mech):
def decorator(f):
@functools.wraps(f)
def newf(self):
self.checkMechanismSupport(mech)
return f(self)
return newf
return decorator

View File

@ -2,7 +2,6 @@ import ecdsa
import base64
from irctest import cases
from irctest import authentication
from irctest import optional_extensions
from irctest.irc_utils.message_parser import Message
ECDSA_KEY = """
@ -16,14 +15,9 @@ IRX9cyi2wdYg9mUUYyh9GKdBCYHGUJAiCA==
-----END EC PRIVATE KEY-----
"""
class SaslMechanismCheck:
def checkMechanismSupport(self, mechanism):
if mechanism in self.controller.supported_sasl_mechanisms:
return
raise optional_extensions.OptionalSaslMechanismNotSupported(mechanism)
class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
SaslMechanismCheck):
cases.OptionalityHelper):
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlain(self):
"""Test PLAIN authentication."""
auth = authentication.Authentication(
@ -32,7 +26,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
password='sesame',
)
m = self.negotiateCapabilities(['sasl'], auth=auth)
self.checkMechanismSupport('PLAIN')
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
self.sendLine('AUTHENTICATE +')
m = self.getMessage()
@ -43,6 +36,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
m = self.negotiateCapabilities(['sasl'], False)
self.assertEqual(m, Message([], None, 'CAP', ['END']))
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlainNotAvailable(self):
"""Test the client handles gracefully servers that don't provide a
mechanism it could use."""
@ -52,7 +46,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
password='sesame',
)
m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth)
self.checkMechanismSupport('PLAIN')
self.assertEqual(self.acked_capabilities, {'sasl'})
if m == Message([], None, 'CAP', ['END']):
# IRCv3.2-style
@ -63,6 +56,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
self.assertMessageEqual(m, command='CAP')
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlainLarge(self):
"""Test the client splits large AUTHENTICATE messages whose payload
is not a multiple of 400."""
@ -75,7 +69,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
authstring = base64.b64encode(b'\x00'.join(
[b'foo', b'foo', b'bar'*200])).decode()
m = self.negotiateCapabilities(['sasl'], auth=auth)
self.checkMechanismSupport('PLAIN')
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
self.sendLine('AUTHENTICATE +')
m = self.getMessage()
@ -92,6 +85,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
m = self.negotiateCapabilities(['sasl'], False)
self.assertEqual(m, Message([], None, 'CAP', ['END']))
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlainLargeMultiple(self):
"""Test the client splits large AUTHENTICATE messages whose payload
is a multiple of 400."""
@ -104,7 +98,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
authstring = base64.b64encode(b'\x00'.join(
[b'foo', b'foo', b'quux'*148])).decode()
m = self.negotiateCapabilities(['sasl'], auth=auth)
self.checkMechanismSupport('PLAIN')
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
self.sendLine('AUTHENTICATE +')
m = self.getMessage()
@ -121,6 +114,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
m = self.negotiateCapabilities(['sasl'], False)
self.assertEqual(m, Message([], None, 'CAP', ['END']))
@cases.OptionalityHelper.skipUnlessHasMechanism('ECDSA-NIST256P-CHALLENGE')
def testEcdsa(self):
"""Test ECDSA authentication."""
auth = authentication.Authentication(
@ -129,7 +123,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
ecdsa_key=ECDSA_KEY,
)
m = self.negotiateCapabilities(['sasl'], auth=auth)
self.checkMechanismSupport('ECDSA-NIST256P-CHALLENGE')
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['ECDSA-NIST256P-CHALLENGE']))
self.sendLine('AUTHENTICATE +')
m = self.getMessage()
@ -151,7 +144,8 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
self.assertEqual(m, Message([], None, 'CAP', ['END']))
class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
SaslMechanismCheck):
cases.OptionalityHelper):
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlainNotAvailable(self):
"""Test the client does not try to authenticate using a mechanism the
server does not advertise."""
@ -161,6 +155,5 @@ class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper
password='sesame',
)
m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth)
self.checkMechanismSupport('PLAIN')
self.assertEqual(self.acked_capabilities, {'sasl'})
self.assertEqual(m, Message([], None, 'CAP', ['END']))

View File

@ -17,6 +17,7 @@ TEMPLATE_CONFIG = """
"""
class InspircdController(BaseServerController, DirectoryBasedController):
supported_sasl_mechanisms = {}
def create_config(self):
super().create_config()
with self.open_file('server.conf'):

View File

@ -55,6 +55,9 @@ server:
"""
class MammonController(BaseServerController, DirectoryBasedController):
supported_sasl_mechanisms = {
'PLAIN', 'ECDSA-NIST256P-CHALLENGE',
}
def create_config(self):
super().create_config()
with self.open_file('server.conf'):
@ -77,7 +80,7 @@ class MammonController(BaseServerController, DirectoryBasedController):
'--config', os.path.join(self.directory, 'server.yml')])
self.wait_for_port(self.proc, port)
def registerUser(self, case, username):
def registerUser(self, case, username, password=None):
# XXX: Move this somewhere else when
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes
# part of the specification
@ -87,7 +90,8 @@ class MammonController(BaseServerController, DirectoryBasedController):
case.sendLine(client, 'USER r e g :user')
case.sendLine(client, 'CAP END')
list(case.getLines(client))
case.sendLine(client, 'REG CREATE {} passphrase temporarypassword'.format(username))
case.sendLine(client, 'REG CREATE {} passphrase {}'.format(
username, password))
msg = case.getMessage(client)
assert msg.command == '920'
list(case.getLines(client))

View File

@ -3,4 +3,25 @@ from irctest.irc_utils.message_parser import Message
class RegistrationTestCase(cases.BaseServerTestCase):
def testRegistration(self):
self.controller.registerUser(self, 'testuser')
self.controller.registerUser(self, 'testuser', 'mypassword')
class SaslTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlain(self):
"""Test PLAIN authentication."""
self.controller.registerUser(self, 'foo', 'sesame')
self.controller.registerUser(self, 'jilles', 'sesame')
self.controller.registerUser(self, 'bar', 'sesame')
self.addClient()
self.sendLine(1, 'CAP LS 302')
capabilities = self.getCapLs(1)
self.assertIn('sasl', capabilities)
if capabilities['sasl'] is not None:
self.assertIn('PLAIN', capabilities['sasl'])
self.sendLine(1, 'AUTHENTICATE PLAIN')
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'])
self.sendLine(1, 'AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=')
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
self.assertMessageEqual(m, command='900')
self.assertEqual(m.params[2], 'jilles', m)