mirror of
https://github.com/progval/irctest.git
synced 2025-04-07 15:59:49 +00:00
Add STARTTLS tests.
This commit is contained in:
@ -51,6 +51,23 @@ class DirectoryBasedController(_BaseController):
|
|||||||
def create_config(self):
|
def create_config(self):
|
||||||
self.directory = tempfile.mkdtemp()
|
self.directory = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
def gen_ssl(self):
|
||||||
|
self.csr_path = os.path.join(self.directory, 'ssl.csr')
|
||||||
|
self.key_path = os.path.join(self.directory, 'ssl.key')
|
||||||
|
self.pem_path = os.path.join(self.directory, 'ssl.pem')
|
||||||
|
self.dh_path = os.path.join(self.directory, 'dh.pem')
|
||||||
|
subprocess.check_output(['openssl', 'req', '-new', '-newkey', 'rsa',
|
||||||
|
'-nodes', '-out', self.csr_path, '-keyout', self.key_path,
|
||||||
|
'-batch'],
|
||||||
|
stderr=subprocess.DEVNULL)
|
||||||
|
subprocess.check_output(['openssl', 'x509', '-req',
|
||||||
|
'-in', self.csr_path, '-signkey', self.key_path,
|
||||||
|
'-out', self.pem_path],
|
||||||
|
stderr=subprocess.DEVNULL)
|
||||||
|
subprocess.check_output(['openssl', 'dhparam',
|
||||||
|
'-out', self.dh_path, '128'],
|
||||||
|
stderr=subprocess.DEVNULL)
|
||||||
|
|
||||||
class BaseClientController(_BaseController):
|
class BaseClientController(_BaseController):
|
||||||
"""Base controller for IRC clients."""
|
"""Base controller for IRC clients."""
|
||||||
def run(self, hostname, port, auth):
|
def run(self, hostname, port, auth):
|
||||||
|
@ -210,6 +210,7 @@ class BaseServerTestCase(_IrcTestCase):
|
|||||||
"""Basic class for server tests. Handles spawning a server and exchanging
|
"""Basic class for server tests. Handles spawning a server and exchanging
|
||||||
messages with it."""
|
messages with it."""
|
||||||
password = None
|
password = None
|
||||||
|
ssl = False
|
||||||
valid_metadata_keys = frozenset()
|
valid_metadata_keys = frozenset()
|
||||||
invalid_metadata_keys = frozenset()
|
invalid_metadata_keys = frozenset()
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -217,7 +218,8 @@ class BaseServerTestCase(_IrcTestCase):
|
|||||||
self.find_hostname_and_port()
|
self.find_hostname_and_port()
|
||||||
self.controller.run(self.hostname, self.port, password=self.password,
|
self.controller.run(self.hostname, self.port, password=self.password,
|
||||||
valid_metadata_keys=self.valid_metadata_keys,
|
valid_metadata_keys=self.valid_metadata_keys,
|
||||||
invalid_metadata_keys=self.invalid_metadata_keys)
|
invalid_metadata_keys=self.invalid_metadata_keys,
|
||||||
|
ssl=self.ssl)
|
||||||
self.clients = {}
|
self.clients = {}
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.controller.kill()
|
self.controller.kill()
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import ssl
|
||||||
import time
|
import time
|
||||||
import socket
|
import socket
|
||||||
from .irc_utils import message_parser
|
from .irc_utils import message_parser
|
||||||
@ -14,6 +15,7 @@ class ClientMock:
|
|||||||
self.name = name
|
self.name = name
|
||||||
self.show_io = show_io
|
self.show_io = show_io
|
||||||
self.inbuffer = []
|
self.inbuffer = []
|
||||||
|
self.ssl = False
|
||||||
def connect(self, hostname, port):
|
def connect(self, hostname, port):
|
||||||
self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
self.conn.settimeout(1) # TODO: configurable
|
self.conn.settimeout(1) # TODO: configurable
|
||||||
@ -24,6 +26,10 @@ class ClientMock:
|
|||||||
if self.show_io:
|
if self.show_io:
|
||||||
print('{:.3f} {}: disconnects from server.'.format(time.time(), self.name))
|
print('{:.3f} {}: disconnects from server.'.format(time.time(), self.name))
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
|
def starttls(self):
|
||||||
|
assert not self.ssl, 'SSL already active.'
|
||||||
|
self.conn = ssl.wrap_socket(self.conn)
|
||||||
|
self.ssl = True
|
||||||
def getMessages(self, synchronize=True, assert_get_one=False):
|
def getMessages(self, synchronize=True, assert_get_one=False):
|
||||||
if synchronize:
|
if synchronize:
|
||||||
token = 'synchronize{}'.format(time.monotonic())
|
token = 'synchronize{}'.format(time.monotonic())
|
||||||
@ -57,7 +63,11 @@ class ClientMock:
|
|||||||
for line in data.decode().split('\r\n'):
|
for line in data.decode().split('\r\n'):
|
||||||
if line:
|
if line:
|
||||||
if self.show_io:
|
if self.show_io:
|
||||||
print('{:.3f} S -> {}: {}'.format(time.time(), self.name, line))
|
print('{time:.3f}{ssl} S -> {client}: {line}'.format(
|
||||||
|
time=time.time(),
|
||||||
|
ssl=' (ssl)' if self.ssl else '',
|
||||||
|
client=self.name,
|
||||||
|
line=line))
|
||||||
message = message_parser.parse_message(line + '\r\n')
|
message = message_parser.parse_message(line + '\r\n')
|
||||||
if message.command == 'PONG' and \
|
if message.command == 'PONG' and \
|
||||||
token in message.params:
|
token in message.params:
|
||||||
@ -83,10 +93,17 @@ class ClientMock:
|
|||||||
if not filter_pred or filter_pred(message):
|
if not filter_pred or filter_pred(message):
|
||||||
return message
|
return message
|
||||||
def sendLine(self, line):
|
def sendLine(self, line):
|
||||||
ret = self.conn.sendall(line.encode())
|
|
||||||
assert ret is None
|
|
||||||
if not line.endswith('\r\n'):
|
if not line.endswith('\r\n'):
|
||||||
ret = self.conn.sendall(b'\r\n')
|
line += '\r\n'
|
||||||
assert ret is None
|
encoded_line = line.encode()
|
||||||
|
ret = self.conn.sendall(encoded_line)
|
||||||
|
if self.ssl:
|
||||||
|
assert ret == len(encoded_line), (ret, repr(encoded_line))
|
||||||
|
else:
|
||||||
|
assert ret is None, ret
|
||||||
if self.show_io:
|
if self.show_io:
|
||||||
print('{:.3f} {} -> S: {}'.format(time.time(), self.name, line.strip('\r\n')))
|
print('{time:.3f}{ssl} {client} -> S: {line}'.format(
|
||||||
|
time=time.time(),
|
||||||
|
ssl=' (ssl)' if self.ssl else '',
|
||||||
|
client=self.name,
|
||||||
|
line=line.strip('\r\n')))
|
||||||
|
@ -14,6 +14,7 @@ serverinfo {{
|
|||||||
name = "My.Little.Server";
|
name = "My.Little.Server";
|
||||||
sid = "42X";
|
sid = "42X";
|
||||||
description = "test server";
|
description = "test server";
|
||||||
|
{ssl_config}
|
||||||
}};
|
}};
|
||||||
listen {{
|
listen {{
|
||||||
defer_accept = yes;
|
defer_accept = yes;
|
||||||
@ -32,6 +33,14 @@ channel {{
|
|||||||
no_join_on_split = no;
|
no_join_on_split = no;
|
||||||
}};
|
}};
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
TEMPLATE_SSL_CONFIG = """
|
||||||
|
ssl_private_key = "{key_path}";
|
||||||
|
ssl_cert = "{pem_path}";
|
||||||
|
ssl_dh_params = "{dh_path}";
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class CharybdisController(BaseServerController, DirectoryBasedController):
|
class CharybdisController(BaseServerController, DirectoryBasedController):
|
||||||
software_name = 'Charybdis'
|
software_name = 'Charybdis'
|
||||||
supported_sasl_mechanisms = set()
|
supported_sasl_mechanisms = set()
|
||||||
@ -40,7 +49,7 @@ class CharybdisController(BaseServerController, DirectoryBasedController):
|
|||||||
with self.open_file('server.conf'):
|
with self.open_file('server.conf'):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def run(self, hostname, port, password=None,
|
def run(self, hostname, port, password=None, ssl=False,
|
||||||
valid_metadata_keys=None, invalid_metadata_keys=None):
|
valid_metadata_keys=None, invalid_metadata_keys=None):
|
||||||
if valid_metadata_keys or invalid_metadata_keys:
|
if valid_metadata_keys or invalid_metadata_keys:
|
||||||
raise NotImplementedByController(
|
raise NotImplementedByController(
|
||||||
@ -49,11 +58,21 @@ class CharybdisController(BaseServerController, DirectoryBasedController):
|
|||||||
self.create_config()
|
self.create_config()
|
||||||
self.port = port
|
self.port = port
|
||||||
password_field = 'password = "{}";'.format(password) if password else ''
|
password_field = 'password = "{}";'.format(password) if password else ''
|
||||||
|
if ssl:
|
||||||
|
self.gen_ssl()
|
||||||
|
ssl_config = TEMPLATE_SSL_CONFIG.format(
|
||||||
|
key_path=self.key_path,
|
||||||
|
pem_path=self.pem_path,
|
||||||
|
dh_path=self.dh_path,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ssl_config = ''
|
||||||
with self.open_file('server.conf') as fd:
|
with self.open_file('server.conf') as fd:
|
||||||
fd.write(TEMPLATE_CONFIG.format(
|
fd.write(TEMPLATE_CONFIG.format(
|
||||||
hostname=hostname,
|
hostname=hostname,
|
||||||
port=port,
|
port=port,
|
||||||
password_field=password_field
|
password_field=password_field,
|
||||||
|
ssl_config=ssl_config,
|
||||||
))
|
))
|
||||||
self.proc = subprocess.Popen(['ircd', '-foreground',
|
self.proc = subprocess.Popen(['ircd', '-foreground',
|
||||||
'-configfile', os.path.join(self.directory, 'server.conf'),
|
'-configfile', os.path.join(self.directory, 'server.conf'),
|
||||||
|
@ -10,6 +10,7 @@ from irctest.basecontrollers import BaseServerController, DirectoryBasedControll
|
|||||||
|
|
||||||
TEMPLATE_CONFIG = """
|
TEMPLATE_CONFIG = """
|
||||||
<bind address="{hostname}" port="{port}" type="clients">
|
<bind address="{hostname}" port="{port}" type="clients">
|
||||||
|
{ssl_config}
|
||||||
<module name="cap">
|
<module name="cap">
|
||||||
<module name="ircv3">
|
<module name="ircv3">
|
||||||
<module name="ircv3_capnotify">
|
<module name="ircv3_capnotify">
|
||||||
@ -21,6 +22,11 @@ TEMPLATE_CONFIG = """
|
|||||||
<log method="file" type="*" level="debug" target="/tmp/ircd-{port}.log">
|
<log method="file" type="*" level="debug" target="/tmp/ircd-{port}.log">
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
TEMPLATE_SSL_CONFIG = """
|
||||||
|
<module name="ssl_openssl">
|
||||||
|
<openssl certfile="{pem_path}" keyfile="{key_path}" dhfile="{dh_path}" hash="sha1">
|
||||||
|
"""
|
||||||
|
|
||||||
class InspircdController(BaseServerController, DirectoryBasedController):
|
class InspircdController(BaseServerController, DirectoryBasedController):
|
||||||
software_name = 'InspIRCd'
|
software_name = 'InspIRCd'
|
||||||
supported_sasl_mechanisms = set()
|
supported_sasl_mechanisms = set()
|
||||||
@ -29,7 +35,8 @@ class InspircdController(BaseServerController, DirectoryBasedController):
|
|||||||
with self.open_file('server.conf'):
|
with self.open_file('server.conf'):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def run(self, hostname, port, password=None, restricted_metadata_keys=None,
|
def run(self, hostname, port, password=None, ssl=False,
|
||||||
|
restricted_metadata_keys=None,
|
||||||
valid_metadata_keys=None, invalid_metadata_keys=None):
|
valid_metadata_keys=None, invalid_metadata_keys=None):
|
||||||
if valid_metadata_keys or invalid_metadata_keys:
|
if valid_metadata_keys or invalid_metadata_keys:
|
||||||
raise NotImplementedByController(
|
raise NotImplementedByController(
|
||||||
@ -38,11 +45,21 @@ class InspircdController(BaseServerController, DirectoryBasedController):
|
|||||||
self.port = port
|
self.port = port
|
||||||
self.create_config()
|
self.create_config()
|
||||||
password_field = 'password="{}"'.format(password) if password else ''
|
password_field = 'password="{}"'.format(password) if password else ''
|
||||||
|
if ssl:
|
||||||
|
self.gen_ssl()
|
||||||
|
ssl_config = TEMPLATE_SSL_CONFIG.format(
|
||||||
|
key_path=self.key_path,
|
||||||
|
pem_path=self.pem_path,
|
||||||
|
dh_path=self.dh_path,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ssl_config = ''
|
||||||
with self.open_file('server.conf') as fd:
|
with self.open_file('server.conf') as fd:
|
||||||
fd.write(TEMPLATE_CONFIG.format(
|
fd.write(TEMPLATE_CONFIG.format(
|
||||||
hostname=hostname,
|
hostname=hostname,
|
||||||
port=port,
|
port=port,
|
||||||
password_field=password_field
|
password_field=password_field,
|
||||||
|
ssl_config=ssl_config
|
||||||
))
|
))
|
||||||
self.proc = subprocess.Popen(['inspircd', '--nofork', '--config',
|
self.proc = subprocess.Popen(['inspircd', '--nofork', '--config',
|
||||||
os.path.join(self.directory, 'server.conf')],
|
os.path.join(self.directory, 'server.conf')],
|
||||||
|
@ -75,10 +75,13 @@ class MammonController(BaseServerController, DirectoryBasedController):
|
|||||||
# Mammon does not seem to handle SIGTERM very well
|
# Mammon does not seem to handle SIGTERM very well
|
||||||
self.proc.kill()
|
self.proc.kill()
|
||||||
|
|
||||||
def run(self, hostname, port, password=None, restricted_metadata_keys=(),
|
def run(self, hostname, port, password=None, ssl=False,
|
||||||
|
restricted_metadata_keys=(),
|
||||||
valid_metadata_keys=(), invalid_metadata_keys=()):
|
valid_metadata_keys=(), invalid_metadata_keys=()):
|
||||||
if password is not None:
|
if password is not None:
|
||||||
raise NotImplementedByController('PASS command')
|
raise NotImplementedByController('PASS command')
|
||||||
|
if ssl:
|
||||||
|
raise NotImplementedByController('SSL')
|
||||||
assert self.proc is None
|
assert self.proc is None
|
||||||
self.port = port
|
self.port = port
|
||||||
self.create_config()
|
self.create_config()
|
||||||
|
62
irctest/server_tests/test_starttls.py
Normal file
62
irctest/server_tests/test_starttls.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
"""
|
||||||
|
<http://ircv3.net/specs/extensions/tls-3.1.html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
from irctest import cases
|
||||||
|
from irctest.basecontrollers import NotImplementedByController
|
||||||
|
|
||||||
|
class StarttlsFailTestCase(cases.BaseServerTestCase):
|
||||||
|
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1')
|
||||||
|
def testStarttlsRequestTlsFail(self):
|
||||||
|
"""<http://ircv3.net/specs/extensions/tls-3.1.html>
|
||||||
|
"""
|
||||||
|
self.addClient()
|
||||||
|
|
||||||
|
# TODO: check also without this
|
||||||
|
self.sendLine(1, 'CAP LS')
|
||||||
|
capabilities = self.getCapLs(1)
|
||||||
|
if 'tls' not in capabilities:
|
||||||
|
raise NotImplementedByController('tls')
|
||||||
|
|
||||||
|
# TODO: check also without this
|
||||||
|
self.sendLine(1, 'CAP REQ :tls')
|
||||||
|
m = self.getRegistrationMessage(1)
|
||||||
|
# TODO: Remove this one the trailing space issue is fixed in Charybdis
|
||||||
|
# and Mammon:
|
||||||
|
#self.assertMessageEqual(m, command='CAP', params=['*', 'ACK', 'tls'],
|
||||||
|
# fail_msg='Did not ACK capability `tls`: {msg}')
|
||||||
|
self.sendLine(1, 'STARTTLS')
|
||||||
|
m = self.getRegistrationMessage(1)
|
||||||
|
self.assertMessageEqual(m, command='691',
|
||||||
|
fail_msg='Did not respond to STARTTLS with 691 whereas '
|
||||||
|
'SSL is not configured: {msg}.')
|
||||||
|
|
||||||
|
class StarttlsTestCase(cases.BaseServerTestCase):
|
||||||
|
ssl = True
|
||||||
|
def testStarttlsRequestTls(self):
|
||||||
|
"""<http://ircv3.net/specs/extensions/tls-3.1.html>
|
||||||
|
"""
|
||||||
|
self.addClient()
|
||||||
|
|
||||||
|
# TODO: check also without this
|
||||||
|
self.sendLine(1, 'CAP LS')
|
||||||
|
capabilities = self.getCapLs(1)
|
||||||
|
if 'tls' not in capabilities:
|
||||||
|
raise NotImplementedByController('tls')
|
||||||
|
|
||||||
|
# TODO: check also without this
|
||||||
|
self.sendLine(1, 'CAP REQ :tls')
|
||||||
|
m = self.getRegistrationMessage(1)
|
||||||
|
# TODO: Remove this one the trailing space issue is fixed in Charybdis
|
||||||
|
# and Mammon:
|
||||||
|
#self.assertMessageEqual(m, command='CAP', params=['*', 'ACK', 'tls'],
|
||||||
|
# fail_msg='Did not ACK capability `tls`: {msg}')
|
||||||
|
self.sendLine(1, 'STARTTLS')
|
||||||
|
m = self.getRegistrationMessage(1)
|
||||||
|
self.assertMessageEqual(m, command='670',
|
||||||
|
fail_msg='Did not respond to STARTTLS with 670: {msg}.')
|
||||||
|
self.clients[1].starttls()
|
||||||
|
self.sendLine(1, 'USER f * * :foo')
|
||||||
|
self.sendLine(1, 'NICK foo')
|
||||||
|
self.sendLine(1, 'CAP END')
|
||||||
|
self.getMessages(1)
|
Reference in New Issue
Block a user