mirror of
https://github.com/progval/irctest.git
synced 2025-04-08 00:09:46 +00:00
Add PLAIN test for servers.
This commit is contained in:
@ -60,7 +60,7 @@ class BaseServerController(_BaseController):
|
|||||||
"""Base controller for IRC server."""
|
"""Base controller for IRC server."""
|
||||||
def run(self, hostname, port, start_wait):
|
def run(self, hostname, port, start_wait):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
def registerUser(self, case, username):
|
def registerUser(self, case, username, password=None):
|
||||||
raise NotImplementedByController('registration')
|
raise NotImplementedByController('registration')
|
||||||
def wait_for_port(self, proc, port):
|
def wait_for_port(self, proc, port):
|
||||||
port_open = False
|
port_open = False
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
import time
|
import time
|
||||||
import socket
|
import socket
|
||||||
import unittest
|
import unittest
|
||||||
|
import functools
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
from . import authentication
|
from . import authentication
|
||||||
|
from . import optional_extensions
|
||||||
from .irc_utils import message_parser
|
from .irc_utils import message_parser
|
||||||
|
from .irc_utils import capabilities
|
||||||
|
|
||||||
class _IrcTestCase(unittest.TestCase):
|
class _IrcTestCase(unittest.TestCase):
|
||||||
"""Base class for test cases."""
|
"""Base class for test cases."""
|
||||||
@ -59,10 +62,13 @@ class BaseClientTestCase(_IrcTestCase):
|
|||||||
user = None
|
user = None
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
self.conn = None
|
||||||
self._setUpServer()
|
self._setUpServer()
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
if self.conn:
|
||||||
self.conn.sendall(b'QUIT :end of test.')
|
self.conn.sendall(b'QUIT :end of test.')
|
||||||
self.controller.kill()
|
self.controller.kill()
|
||||||
|
if self.conn:
|
||||||
self.conn_file.close()
|
self.conn_file.close()
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
self.server.close()
|
self.server.close()
|
||||||
@ -129,7 +135,7 @@ class ClientNegociationHelper:
|
|||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def negotiateCapabilities(self, capabilities, cap_ls=True, auth=None):
|
def negotiateCapabilities(self, caps, cap_ls=True, auth=None):
|
||||||
"""Performes a complete capability negociation process, without
|
"""Performes a complete capability negociation process, without
|
||||||
ending it, so the caller can continue the negociation."""
|
ending it, so the caller can continue the negociation."""
|
||||||
if cap_ls:
|
if cap_ls:
|
||||||
@ -137,8 +143,8 @@ class ClientNegociationHelper:
|
|||||||
if not self.protocol_version:
|
if not self.protocol_version:
|
||||||
# No negotiation.
|
# No negotiation.
|
||||||
return
|
return
|
||||||
self.sendLine('CAP * LS :{}'.format(' '.join(capabilities)))
|
self.sendLine('CAP * LS :{}'.format(' '.join(caps)))
|
||||||
capability_names = {x.split('=')[0] for x in capabilities}
|
capability_names = frozenset(capabilities.cap_list_to_dict(caps))
|
||||||
self.acked_capabilities = set()
|
self.acked_capabilities = set()
|
||||||
while True:
|
while True:
|
||||||
m = self.getMessage(filter_pred=self.userNickPredicate)
|
m = self.getMessage(filter_pred=self.userNickPredicate)
|
||||||
@ -191,11 +197,15 @@ class BaseServerTestCase(_IrcTestCase):
|
|||||||
conn.connect((self.hostname, self.port))
|
conn.connect((self.hostname, self.port))
|
||||||
conn_file = conn.makefile(newline='\r\n', encoding='utf8')
|
conn_file = conn.makefile(newline='\r\n', encoding='utf8')
|
||||||
self.clients[name] = Client(conn=conn, conn_file=conn_file)
|
self.clients[name] = Client(conn=conn, conn_file=conn_file)
|
||||||
|
if self.show_io:
|
||||||
|
print('{}: connects to server.'.format(name))
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def removeClient(self, name):
|
def removeClient(self, name):
|
||||||
"""Disconnects the client, without QUIT."""
|
"""Disconnects the client, without QUIT."""
|
||||||
assert name in self.clients
|
assert name in self.clients
|
||||||
|
if self.show_io:
|
||||||
|
print('{}: disconnects from server.'.format(name))
|
||||||
self.clients[name].conn.close()
|
self.clients[name].conn.close()
|
||||||
del self.clients[name]
|
del self.clients[name]
|
||||||
|
|
||||||
@ -209,7 +219,7 @@ class BaseServerTestCase(_IrcTestCase):
|
|||||||
data += conn.recv(4096)
|
data += conn.recv(4096)
|
||||||
except BlockingIOError:
|
except BlockingIOError:
|
||||||
for line in data.decode().split('\r\n'):
|
for line in data.decode().split('\r\n'):
|
||||||
if line:
|
if line and self.show_io:
|
||||||
print('S -> {}: {}'.format(client, line.strip()))
|
print('S -> {}: {}'.format(client, line.strip()))
|
||||||
yield line + '\r\n'
|
yield line + '\r\n'
|
||||||
finally:
|
finally:
|
||||||
@ -229,16 +239,37 @@ class BaseServerTestCase(_IrcTestCase):
|
|||||||
if self.show_io:
|
if self.show_io:
|
||||||
print('{} -> S: {}'.format(client, line.strip()))
|
print('{} -> S: {}'.format(client, line.strip()))
|
||||||
|
|
||||||
def getCapLs(self, client):
|
def getCapLs(self, client, as_list=False):
|
||||||
"""Waits for a CAP LS block, parses all CAP LS messages, and return
|
"""Waits for a CAP LS block, parses all CAP LS messages, and return
|
||||||
the list of capabilities."""
|
the dict capabilities, with their values.
|
||||||
capabilities = []
|
|
||||||
|
If as_list is given, returns the raw list (ie. key/value not split)
|
||||||
|
in case the order matters (but it shouldn't)."""
|
||||||
|
caps = []
|
||||||
while True:
|
while True:
|
||||||
m = self.getMessage(client,
|
m = self.getMessage(client,
|
||||||
filter_pred=lambda m:m.command != 'NOTICE')
|
filter_pred=lambda m:m.command != 'NOTICE')
|
||||||
self.assertMessageEqual(m, command='CAP', subcommand='LS')
|
self.assertMessageEqual(m, command='CAP', subcommand='LS')
|
||||||
if m.params[2] == '*':
|
if m.params[2] == '*':
|
||||||
capabilities.extend(m.params[3].split())
|
caps.extend(m.params[3].split())
|
||||||
else:
|
else:
|
||||||
capabilities.extend(m.params[2].split())
|
caps.extend(m.params[2].split())
|
||||||
return capabilities
|
if not as_list:
|
||||||
|
caps = capabilities.cap_list_to_dict(caps)
|
||||||
|
return caps
|
||||||
|
|
||||||
|
class OptionalityHelper:
|
||||||
|
def checkMechanismSupport(self, mechanism):
|
||||||
|
if mechanism in self.controller.supported_sasl_mechanisms:
|
||||||
|
return
|
||||||
|
raise optional_extensions.OptionalSaslMechanismNotSupported(mechanism)
|
||||||
|
|
||||||
|
def skipUnlessHasMechanism(mech):
|
||||||
|
def decorator(f):
|
||||||
|
@functools.wraps(f)
|
||||||
|
def newf(self):
|
||||||
|
self.checkMechanismSupport(mech)
|
||||||
|
return f(self)
|
||||||
|
return newf
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@ import ecdsa
|
|||||||
import base64
|
import base64
|
||||||
from irctest import cases
|
from irctest import cases
|
||||||
from irctest import authentication
|
from irctest import authentication
|
||||||
from irctest import optional_extensions
|
|
||||||
from irctest.irc_utils.message_parser import Message
|
from irctest.irc_utils.message_parser import Message
|
||||||
|
|
||||||
ECDSA_KEY = """
|
ECDSA_KEY = """
|
||||||
@ -16,14 +15,9 @@ IRX9cyi2wdYg9mUUYyh9GKdBCYHGUJAiCA==
|
|||||||
-----END EC PRIVATE KEY-----
|
-----END EC PRIVATE KEY-----
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class SaslMechanismCheck:
|
|
||||||
def checkMechanismSupport(self, mechanism):
|
|
||||||
if mechanism in self.controller.supported_sasl_mechanisms:
|
|
||||||
return
|
|
||||||
raise optional_extensions.OptionalSaslMechanismNotSupported(mechanism)
|
|
||||||
|
|
||||||
class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||||
SaslMechanismCheck):
|
cases.OptionalityHelper):
|
||||||
|
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
|
||||||
def testPlain(self):
|
def testPlain(self):
|
||||||
"""Test PLAIN authentication."""
|
"""Test PLAIN authentication."""
|
||||||
auth = authentication.Authentication(
|
auth = authentication.Authentication(
|
||||||
@ -32,7 +26,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
|||||||
password='sesame',
|
password='sesame',
|
||||||
)
|
)
|
||||||
m = self.negotiateCapabilities(['sasl'], auth=auth)
|
m = self.negotiateCapabilities(['sasl'], auth=auth)
|
||||||
self.checkMechanismSupport('PLAIN')
|
|
||||||
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
|
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
|
||||||
self.sendLine('AUTHENTICATE +')
|
self.sendLine('AUTHENTICATE +')
|
||||||
m = self.getMessage()
|
m = self.getMessage()
|
||||||
@ -43,6 +36,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
|||||||
m = self.negotiateCapabilities(['sasl'], False)
|
m = self.negotiateCapabilities(['sasl'], False)
|
||||||
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
||||||
|
|
||||||
|
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
|
||||||
def testPlainNotAvailable(self):
|
def testPlainNotAvailable(self):
|
||||||
"""Test the client handles gracefully servers that don't provide a
|
"""Test the client handles gracefully servers that don't provide a
|
||||||
mechanism it could use."""
|
mechanism it could use."""
|
||||||
@ -52,7 +46,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
|||||||
password='sesame',
|
password='sesame',
|
||||||
)
|
)
|
||||||
m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth)
|
m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth)
|
||||||
self.checkMechanismSupport('PLAIN')
|
|
||||||
self.assertEqual(self.acked_capabilities, {'sasl'})
|
self.assertEqual(self.acked_capabilities, {'sasl'})
|
||||||
if m == Message([], None, 'CAP', ['END']):
|
if m == Message([], None, 'CAP', ['END']):
|
||||||
# IRCv3.2-style
|
# IRCv3.2-style
|
||||||
@ -63,6 +56,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
|||||||
self.assertMessageEqual(m, command='CAP')
|
self.assertMessageEqual(m, command='CAP')
|
||||||
|
|
||||||
|
|
||||||
|
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
|
||||||
def testPlainLarge(self):
|
def testPlainLarge(self):
|
||||||
"""Test the client splits large AUTHENTICATE messages whose payload
|
"""Test the client splits large AUTHENTICATE messages whose payload
|
||||||
is not a multiple of 400."""
|
is not a multiple of 400."""
|
||||||
@ -75,7 +69,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
|||||||
authstring = base64.b64encode(b'\x00'.join(
|
authstring = base64.b64encode(b'\x00'.join(
|
||||||
[b'foo', b'foo', b'bar'*200])).decode()
|
[b'foo', b'foo', b'bar'*200])).decode()
|
||||||
m = self.negotiateCapabilities(['sasl'], auth=auth)
|
m = self.negotiateCapabilities(['sasl'], auth=auth)
|
||||||
self.checkMechanismSupport('PLAIN')
|
|
||||||
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
|
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
|
||||||
self.sendLine('AUTHENTICATE +')
|
self.sendLine('AUTHENTICATE +')
|
||||||
m = self.getMessage()
|
m = self.getMessage()
|
||||||
@ -92,6 +85,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
|||||||
m = self.negotiateCapabilities(['sasl'], False)
|
m = self.negotiateCapabilities(['sasl'], False)
|
||||||
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
||||||
|
|
||||||
|
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
|
||||||
def testPlainLargeMultiple(self):
|
def testPlainLargeMultiple(self):
|
||||||
"""Test the client splits large AUTHENTICATE messages whose payload
|
"""Test the client splits large AUTHENTICATE messages whose payload
|
||||||
is a multiple of 400."""
|
is a multiple of 400."""
|
||||||
@ -104,7 +98,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
|||||||
authstring = base64.b64encode(b'\x00'.join(
|
authstring = base64.b64encode(b'\x00'.join(
|
||||||
[b'foo', b'foo', b'quux'*148])).decode()
|
[b'foo', b'foo', b'quux'*148])).decode()
|
||||||
m = self.negotiateCapabilities(['sasl'], auth=auth)
|
m = self.negotiateCapabilities(['sasl'], auth=auth)
|
||||||
self.checkMechanismSupport('PLAIN')
|
|
||||||
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
|
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['PLAIN']))
|
||||||
self.sendLine('AUTHENTICATE +')
|
self.sendLine('AUTHENTICATE +')
|
||||||
m = self.getMessage()
|
m = self.getMessage()
|
||||||
@ -121,6 +114,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
|||||||
m = self.negotiateCapabilities(['sasl'], False)
|
m = self.negotiateCapabilities(['sasl'], False)
|
||||||
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
||||||
|
|
||||||
|
@cases.OptionalityHelper.skipUnlessHasMechanism('ECDSA-NIST256P-CHALLENGE')
|
||||||
def testEcdsa(self):
|
def testEcdsa(self):
|
||||||
"""Test ECDSA authentication."""
|
"""Test ECDSA authentication."""
|
||||||
auth = authentication.Authentication(
|
auth = authentication.Authentication(
|
||||||
@ -129,7 +123,6 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
|||||||
ecdsa_key=ECDSA_KEY,
|
ecdsa_key=ECDSA_KEY,
|
||||||
)
|
)
|
||||||
m = self.negotiateCapabilities(['sasl'], auth=auth)
|
m = self.negotiateCapabilities(['sasl'], auth=auth)
|
||||||
self.checkMechanismSupport('ECDSA-NIST256P-CHALLENGE')
|
|
||||||
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['ECDSA-NIST256P-CHALLENGE']))
|
self.assertEqual(m, Message([], None, 'AUTHENTICATE', ['ECDSA-NIST256P-CHALLENGE']))
|
||||||
self.sendLine('AUTHENTICATE +')
|
self.sendLine('AUTHENTICATE +')
|
||||||
m = self.getMessage()
|
m = self.getMessage()
|
||||||
@ -151,7 +144,8 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
|||||||
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
||||||
|
|
||||||
class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
|
||||||
SaslMechanismCheck):
|
cases.OptionalityHelper):
|
||||||
|
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
|
||||||
def testPlainNotAvailable(self):
|
def testPlainNotAvailable(self):
|
||||||
"""Test the client does not try to authenticate using a mechanism the
|
"""Test the client does not try to authenticate using a mechanism the
|
||||||
server does not advertise."""
|
server does not advertise."""
|
||||||
@ -161,6 +155,5 @@ class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper
|
|||||||
password='sesame',
|
password='sesame',
|
||||||
)
|
)
|
||||||
m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth)
|
m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth)
|
||||||
self.checkMechanismSupport('PLAIN')
|
|
||||||
self.assertEqual(self.acked_capabilities, {'sasl'})
|
self.assertEqual(self.acked_capabilities, {'sasl'})
|
||||||
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
self.assertEqual(m, Message([], None, 'CAP', ['END']))
|
||||||
|
@ -17,6 +17,7 @@ TEMPLATE_CONFIG = """
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
class InspircdController(BaseServerController, DirectoryBasedController):
|
class InspircdController(BaseServerController, DirectoryBasedController):
|
||||||
|
supported_sasl_mechanisms = {}
|
||||||
def create_config(self):
|
def create_config(self):
|
||||||
super().create_config()
|
super().create_config()
|
||||||
with self.open_file('server.conf'):
|
with self.open_file('server.conf'):
|
||||||
|
@ -55,6 +55,9 @@ server:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
class MammonController(BaseServerController, DirectoryBasedController):
|
class MammonController(BaseServerController, DirectoryBasedController):
|
||||||
|
supported_sasl_mechanisms = {
|
||||||
|
'PLAIN', 'ECDSA-NIST256P-CHALLENGE',
|
||||||
|
}
|
||||||
def create_config(self):
|
def create_config(self):
|
||||||
super().create_config()
|
super().create_config()
|
||||||
with self.open_file('server.conf'):
|
with self.open_file('server.conf'):
|
||||||
@ -77,7 +80,7 @@ class MammonController(BaseServerController, DirectoryBasedController):
|
|||||||
'--config', os.path.join(self.directory, 'server.yml')])
|
'--config', os.path.join(self.directory, 'server.yml')])
|
||||||
self.wait_for_port(self.proc, port)
|
self.wait_for_port(self.proc, port)
|
||||||
|
|
||||||
def registerUser(self, case, username):
|
def registerUser(self, case, username, password=None):
|
||||||
# XXX: Move this somewhere else when
|
# XXX: Move this somewhere else when
|
||||||
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes
|
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes
|
||||||
# part of the specification
|
# part of the specification
|
||||||
@ -87,7 +90,8 @@ class MammonController(BaseServerController, DirectoryBasedController):
|
|||||||
case.sendLine(client, 'USER r e g :user')
|
case.sendLine(client, 'USER r e g :user')
|
||||||
case.sendLine(client, 'CAP END')
|
case.sendLine(client, 'CAP END')
|
||||||
list(case.getLines(client))
|
list(case.getLines(client))
|
||||||
case.sendLine(client, 'REG CREATE {} passphrase temporarypassword'.format(username))
|
case.sendLine(client, 'REG CREATE {} passphrase {}'.format(
|
||||||
|
username, password))
|
||||||
msg = case.getMessage(client)
|
msg = case.getMessage(client)
|
||||||
assert msg.command == '920'
|
assert msg.command == '920'
|
||||||
list(case.getLines(client))
|
list(case.getLines(client))
|
||||||
|
@ -3,4 +3,25 @@ from irctest.irc_utils.message_parser import Message
|
|||||||
|
|
||||||
class RegistrationTestCase(cases.BaseServerTestCase):
|
class RegistrationTestCase(cases.BaseServerTestCase):
|
||||||
def testRegistration(self):
|
def testRegistration(self):
|
||||||
self.controller.registerUser(self, 'testuser')
|
self.controller.registerUser(self, 'testuser', 'mypassword')
|
||||||
|
|
||||||
|
class SaslTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
|
||||||
|
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
|
||||||
|
def testPlain(self):
|
||||||
|
"""Test PLAIN authentication."""
|
||||||
|
self.controller.registerUser(self, 'foo', 'sesame')
|
||||||
|
self.controller.registerUser(self, 'jilles', 'sesame')
|
||||||
|
self.controller.registerUser(self, 'bar', 'sesame')
|
||||||
|
self.addClient()
|
||||||
|
self.sendLine(1, 'CAP LS 302')
|
||||||
|
capabilities = self.getCapLs(1)
|
||||||
|
self.assertIn('sasl', capabilities)
|
||||||
|
if capabilities['sasl'] is not None:
|
||||||
|
self.assertIn('PLAIN', capabilities['sasl'])
|
||||||
|
self.sendLine(1, 'AUTHENTICATE PLAIN')
|
||||||
|
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||||
|
self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'])
|
||||||
|
self.sendLine(1, 'AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=')
|
||||||
|
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||||
|
self.assertMessageEqual(m, command='900')
|
||||||
|
self.assertEqual(m.params[2], 'jilles', m)
|
||||||
|
Reference in New Issue
Block a user