mirror of
https://github.com/progval/irctest.git
synced 2025-04-07 15:59:49 +00:00
Add PLAIN test for servers.
This commit is contained in:
@ -60,7 +60,7 @@ class BaseServerController(_BaseController):
|
||||
"""Base controller for IRC server."""
|
||||
def run(self, hostname, port, start_wait):
|
||||
raise NotImplementedError()
|
||||
def registerUser(self, case, username):
|
||||
def registerUser(self, case, username, password=None):
|
||||
raise NotImplementedByController('registration')
|
||||
def wait_for_port(self, proc, port):
|
||||
port_open = False
|
||||
|
@ -1,10 +1,13 @@
|
||||
import time
|
||||
import socket
|
||||
import unittest
|
||||
import functools
|
||||
import collections
|
||||
|
||||
from . import authentication
|
||||
from . import optional_extensions
|
||||
from .irc_utils import message_parser
|
||||
from .irc_utils import capabilities
|
||||
|
||||
class _IrcTestCase(unittest.TestCase):
|
||||
"""Base class for test cases."""
|
||||
@ -59,10 +62,13 @@ class BaseClientTestCase(_IrcTestCase):
|
||||
user = None
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.conn = None
|
||||
self._setUpServer()
|
||||
def tearDown(self):
|
||||
if self.conn:
|
||||
self.conn.sendall(b'QUIT :end of test.')
|
||||
self.controller.kill()
|
||||
if self.conn:
|
||||
self.conn_file.close()
|
||||
self.conn.close()
|
||||
self.server.close()
|
||||
@ -129,7 +135,7 @@ class ClientNegociationHelper:
|
||||
else:
|
||||
return True
|
||||
|
||||
def negotiateCapabilities(self, capabilities, cap_ls=True, auth=None):
|
||||
def negotiateCapabilities(self, caps, cap_ls=True, auth=None):
|
||||
"""Performes a complete capability negociation process, without
|
||||
ending it, so the caller can continue the negociation."""
|
||||
if cap_ls:
|
||||
@ -137,8 +143,8 @@ class ClientNegociationHelper:
|
||||
if not self.protocol_version:
|
||||
# No negotiation.
|
||||
return
|
||||
self.sendLine('CAP * LS :{}'.format(' '.join(capabilities)))
|
||||
capability_names = {x.split('=')[0] for x in capabilities}
|
||||
self.sendLine('CAP * LS :{}'.format(' '.join(caps)))
|
||||
capability_names = frozenset(capabilities.cap_list_to_dict(caps))
|
||||
self.acked_capabilities = set()
|
||||
while True:
|
||||
m = self.getMessage(filter_pred=self.userNickPredicate)
|
||||
@ -191,11 +197,15 @@ class BaseServerTestCase(_IrcTestCase):
|
||||
conn.connect((self.hostname, self.port))
|
||||
conn_file = conn.makefile(newline='\r\n', encoding='utf8')
|
||||
self.clients[name] = Client(conn=conn, conn_file=conn_file)
|
||||
if self.show_io:
|
||||
print('{}: connects to server.'.format(name))
|
||||
return name
|
||||
|
||||
def removeClient(self, name):
|
||||
"""Disconnects the client, without QUIT."""
|
||||
assert name in self.clients
|
||||
if self.show_io:
|
||||
print('{}: disconnects from server.'.format(name))
|
||||
self.clients[name].conn.close()
|
||||
del self.clients[name]
|
||||
|
||||
@ -209,7 +219,7 @@ class BaseServerTestCase(_IrcTestCase):
|
||||
data += conn.recv(4096)
|
||||
except BlockingIOError:
|
||||
for line in data.decode().split('\r\n'):
|
||||
if line:
|
||||
if line and self.show_io:
|
||||
print('S -> {}: {}'.format(client, line.strip()))
|
||||
yield line + '\r\n'
|
||||
finally:
|
||||
@ -229,16 +239,37 @@ class BaseServerTestCase(_IrcTestCase):
|
||||
if self.show_io:
|
||||
print('{} -> S: {}'.format(client, line.strip()))
|
||||
|
||||
def getCapLs(self, client):
|
||||
def getCapLs(self, client, as_list=False):
|
||||
"""Waits for a CAP LS block, parses all CAP LS messages, and return
|
||||
the list of capabilities."""
|
||||
capabilities = []
|
||||
the dict capabilities, with their values.
|
||||
|
||||
If as_list is given, returns the raw list (ie. key/value not split)
|
||||
in case the order matters (but it shouldn't)."""
|
||||
caps = []
|
||||
while True:
|
||||
m = self.getMessage(client,
|
||||
filter_pred=lambda m:m.command != 'NOTICE')
|
||||
self.assertMessageEqual(m, command='CAP', subcommand='LS')
|
||||
if m.params[2] == '*':
|
||||
capabilities.extend(m.params[3].split())
|
||||
caps.extend(m.params[3].split())
|
||||
else:
|
||||
capabilities.extend(m.params[2].split())
|
||||
return capabilities
|
||||
caps.extend(m.params[2].split())
|
||||
if not as_list:
|
||||
caps = capabilities.cap_list_to_dict(caps)
|
||||
return caps
|
||||
|
||||
class OptionalityHelper:
|
||||
def checkMechanismSupport(self, mechanism):
|
||||
if mechanism in self.controller.supported_sasl_mechanisms:
|
||||
return
|
||||
raise optional_extensions.OptionalSaslMechanismNotSupported(mechanism)
|
||||
|
||||
def skipUnlessHasMechanism(mech):
|
||||
def decorator(f):
|
||||
@functools.wraps(f)
|
||||
def newf(self):
|
||||
self.checkMechanismSupport(mech)
|
||||
return f(self)
|
||||
return newf
|
||||
return decorator
|
||||
|
||||
|
@ -2,7 +2,6 @@ import ecdsa
|
||||
import base64
|
||||
from irctest import cases
|
||||
from irctest import authentication
|
||||
from irctest import optional_extensions
|
||||
from irctest.irc_utils.message_parser import Message
|
||||
|
||||
ECDSA_KEY = """
|
||||
@ -16,14 +15,9 @@ IRX9cyi2wdYg9mUUYyh9GKdBCYHGUJAiCA==
|
||||
-----END EC PRIVATE KEY-----
|
||||
"""
|
||||
|
||||
class SaslMechanismCheck:
|
||||
def checkMechanismSupport(self, mechanism):
|
||||
if mechanism in self.controller.supported_sasl_mechanisms:
|
||||
return
|
||||
raise optional_extensions.OptionalSaslMechanismNotSupported(mechanism)
|
||||
|
||||
class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||
SaslMechanismCheck):
|
||||
cases.OptionalityHelper):
|
||||
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
|
||||
def testPlain(self):
|
||||
"""Test PLAIN authentication."""
|
||||
auth = authentication.Authentication(
|
||||
@ -32,7 +26,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||
password='sesame',
|
||||
)
|
||||
m = self.negotiateCapabilities(['sasl'], auth=auth)
|
||||
self.checkMechanismSupport('PLAIN')
|
||||
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
|
||||
self.sendLine('AUTHENTICATE +')
|
||||
m = self.getMessage()
|
||||
@ -43,6 +36,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||
m = self.negotiateCapabilities(['sasl'], False)
|
||||
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
||||
|
||||
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
|
||||
def testPlainNotAvailable(self):
|
||||
"""Test the client handles gracefully servers that don't provide a
|
||||
mechanism it could use."""
|
||||
@ -52,7 +46,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||
password='sesame',
|
||||
)
|
||||
m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth)
|
||||
self.checkMechanismSupport('PLAIN')
|
||||
self.assertEqual(self.acked_capabilities, {'sasl'})
|
||||
if m == Message([], None, 'CAP', ['END']):
|
||||
# IRCv3.2-style
|
||||
@ -63,6 +56,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||
self.assertMessageEqual(m, command='CAP')
|
||||
|
||||
|
||||
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
|
||||
def testPlainLarge(self):
|
||||
"""Test the client splits large AUTHENTICATE messages whose payload
|
||||
is not a multiple of 400."""
|
||||
@ -75,7 +69,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||
authstring = base64.b64encode(b'\x00'.join(
|
||||
[b'foo', b'foo', b'bar'*200])).decode()
|
||||
m = self.negotiateCapabilities(['sasl'], auth=auth)
|
||||
self.checkMechanismSupport('PLAIN')
|
||||
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
|
||||
self.sendLine('AUTHENTICATE +')
|
||||
m = self.getMessage()
|
||||
@ -92,6 +85,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||
m = self.negotiateCapabilities(['sasl'], False)
|
||||
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
||||
|
||||
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
|
||||
def testPlainLargeMultiple(self):
|
||||
"""Test the client splits large AUTHENTICATE messages whose payload
|
||||
is a multiple of 400."""
|
||||
@ -104,7 +98,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||
authstring = base64.b64encode(b'\x00'.join(
|
||||
[b'foo', b'foo', b'quux'*148])).decode()
|
||||
m = self.negotiateCapabilities(['sasl'], auth=auth)
|
||||
self.checkMechanismSupport('PLAIN')
|
||||
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
|
||||
self.sendLine('AUTHENTICATE +')
|
||||
m = self.getMessage()
|
||||
@ -121,6 +114,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||
m = self.negotiateCapabilities(['sasl'], False)
|
||||
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
||||
|
||||
@cases.OptionalityHelper.skipUnlessHasMechanism('ECDSA-NIST256P-CHALLENGE')
|
||||
def testEcdsa(self):
|
||||
"""Test ECDSA authentication."""
|
||||
auth = authentication.Authentication(
|
||||
@ -129,7 +123,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||
ecdsa_key=ECDSA_KEY,
|
||||
)
|
||||
m = self.negotiateCapabilities(['sasl'], auth=auth)
|
||||
self.checkMechanismSupport('ECDSA-NIST256P-CHALLENGE')
|
||||
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['ECDSA-NIST256P-CHALLENGE']))
|
||||
self.sendLine('AUTHENTICATE +')
|
||||
m = self.getMessage()
|
||||
@ -151,7 +144,8 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
||||
|
||||
class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||
SaslMechanismCheck):
|
||||
cases.OptionalityHelper):
|
||||
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
|
||||
def testPlainNotAvailable(self):
|
||||
"""Test the client does not try to authenticate using a mechanism the
|
||||
server does not advertise."""
|
||||
@ -161,6 +155,5 @@ class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper
|
||||
password='sesame',
|
||||
)
|
||||
m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth)
|
||||
self.checkMechanismSupport('PLAIN')
|
||||
self.assertEqual(self.acked_capabilities, {'sasl'})
|
||||
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
||||
|
@ -17,6 +17,7 @@ TEMPLATE_CONFIG = """
|
||||
"""
|
||||
|
||||
class InspircdController(BaseServerController, DirectoryBasedController):
|
||||
supported_sasl_mechanisms = {}
|
||||
def create_config(self):
|
||||
super().create_config()
|
||||
with self.open_file('server.conf'):
|
||||
|
@ -55,6 +55,9 @@ server:
|
||||
"""
|
||||
|
||||
class MammonController(BaseServerController, DirectoryBasedController):
|
||||
supported_sasl_mechanisms = {
|
||||
'PLAIN', 'ECDSA-NIST256P-CHALLENGE',
|
||||
}
|
||||
def create_config(self):
|
||||
super().create_config()
|
||||
with self.open_file('server.conf'):
|
||||
@ -77,7 +80,7 @@ class MammonController(BaseServerController, DirectoryBasedController):
|
||||
'--config', os.path.join(self.directory, 'server.yml')])
|
||||
self.wait_for_port(self.proc, port)
|
||||
|
||||
def registerUser(self, case, username):
|
||||
def registerUser(self, case, username, password=None):
|
||||
# XXX: Move this somewhere else when
|
||||
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes
|
||||
# part of the specification
|
||||
@ -87,7 +90,8 @@ class MammonController(BaseServerController, DirectoryBasedController):
|
||||
case.sendLine(client, 'USER r e g :user')
|
||||
case.sendLine(client, 'CAP END')
|
||||
list(case.getLines(client))
|
||||
case.sendLine(client, 'REG CREATE {} passphrase temporarypassword'.format(username))
|
||||
case.sendLine(client, 'REG CREATE {} passphrase {}'.format(
|
||||
username, password))
|
||||
msg = case.getMessage(client)
|
||||
assert msg.command == '920'
|
||||
list(case.getLines(client))
|
||||
|
@ -3,4 +3,25 @@ from irctest.irc_utils.message_parser import Message
|
||||
|
||||
class RegistrationTestCase(cases.BaseServerTestCase):
|
||||
def testRegistration(self):
|
||||
self.controller.registerUser(self, 'testuser')
|
||||
self.controller.registerUser(self, 'testuser', 'mypassword')
|
||||
|
||||
class SaslTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
|
||||
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
|
||||
def testPlain(self):
|
||||
"""Test PLAIN authentication."""
|
||||
self.controller.registerUser(self, 'foo', 'sesame')
|
||||
self.controller.registerUser(self, 'jilles', 'sesame')
|
||||
self.controller.registerUser(self, 'bar', 'sesame')
|
||||
self.addClient()
|
||||
self.sendLine(1, 'CAP LS 302')
|
||||
capabilities = self.getCapLs(1)
|
||||
self.assertIn('sasl', capabilities)
|
||||
if capabilities['sasl'] is not None:
|
||||
self.assertIn('PLAIN', capabilities['sasl'])
|
||||
self.sendLine(1, 'AUTHENTICATE PLAIN')
|
||||
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'])
|
||||
self.sendLine(1, 'AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=')
|
||||
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
self.assertMessageEqual(m, command='900')
|
||||
self.assertEqual(m.params[2], 'jilles', m)
|
||||
|
Reference in New Issue
Block a user