irctest/irctest/cases.py

337 lines
12 KiB
Python
Raw Normal View History

2015-12-20 12:12:54 +00:00
import time
2015-12-19 00:11:57 +00:00
import socket
import unittest
2015-12-20 14:11:56 +00:00
import functools
2015-12-19 22:09:06 +00:00
import collections
2015-12-19 00:11:57 +00:00
import supybot.utils
2015-12-21 19:13:16 +00:00
from . import client_mock
2015-12-19 16:52:38 +00:00
from . import authentication
from . import runner
from .irc_utils import message_parser
2015-12-20 14:11:56 +00:00
from .irc_utils import capabilities
2015-12-19 00:11:57 +00:00
class _IrcTestCase(unittest.TestCase):
2015-12-20 12:47:30 +00:00
"""Base class for test cases."""
2015-12-19 00:11:57 +00:00
controllerClass = None # Will be set by __main__.py
def description(self):
method_doc = self._testMethodDoc
if not method_doc:
return ''
return '\t'+supybot.utils.str.normalizeWhitespace(
method_doc,
removeNewline=False,
).strip().replace('\n ', '\n\t')
2015-12-20 00:48:56 +00:00
def setUp(self):
super().setUp()
self.controller = self.controllerClass()
self.inbuffer = []
2015-12-20 00:48:56 +00:00
if self.show_io:
print('---- new test ----')
2015-12-19 23:47:06 +00:00
def assertMessageEqual(self, msg, subcommand=None, subparams=None,
target=None, nick=None, fail_msg=None, **kwargs):
2015-12-20 12:47:30 +00:00
"""Helper for partially comparing a message.
Takes the message as first arguments, and comparisons to be made
as keyword arguments.
Deals with subcommands (eg. `CAP`) if any of `subcommand`,
`subparams`, and `target` are given."""
2015-12-22 04:06:51 +00:00
fail_msg = fail_msg or '{msg}'
2015-12-19 23:47:06 +00:00
for (key, value) in kwargs.items():
self.assertEqual(getattr(msg, key), value, msg, fail_msg)
if nick:
self.assertNotEqual(msg.prefix, None, msg, fail_msg)
self.assertEqual(msg.prefix.split('!')[0], nick, msg, fail_msg)
2015-12-19 23:47:06 +00:00
if subcommand is not None or subparams is not None:
self.assertGreater(len(msg.params), 2, fail_msg)
2015-12-19 23:47:06 +00:00
msg_target = msg.params[0]
msg_subcommand = msg.params[1]
msg_subparams = msg.params[2:]
if subcommand:
with self.subTest(key='subcommand'):
self.assertEqual(msg_subcommand, subcommand, msg, fail_msg)
2015-12-19 23:47:06 +00:00
if subparams is not None:
with self.subTest(key='subparams'):
self.assertEqual(msg_subparams, subparams, msg, fail_msg)
def assertIn(self, item, list_, msg=None, fail_msg=None, extra_format=()):
if fail_msg:
fail_msg = fail_msg.format(*extra_format,
item=item, list=list_, msg=msg)
super().assertIn(item, list_, fail_msg)
def assertEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()):
if fail_msg:
fail_msg = fail_msg.format(*extra_format,
got=got, expects=expects, msg=msg)
super().assertEqual(got, expects, fail_msg)
def assertNotEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()):
if fail_msg:
fail_msg = fail_msg.format(*extra_format,
got=got, expects=expects, msg=msg)
super().assertNotEqual(got, expects, fail_msg)
class BaseClientTestCase(_IrcTestCase):
    """Basic class for client tests. Handles spawning a client and exchanging
    messages with it."""
    nick = None
    user = None

    def setUp(self):
        super().setUp()
        self.conn = None
        self._setUpServer()

    def tearDown(self):
        if self.conn:
            try:
                # IRC lines must be CRLF-terminated, otherwise the client
                # never receives a complete QUIT line.
                self.conn.sendall(b'QUIT :end of test.\r\n')
            except OSError:
                # Best effort: the client may already have hung up; still
                # kill the controller and close our sockets below.
                pass
        self.controller.kill()
        if self.conn:
            self.conn_file.close()
            self.conn.close()
        self.server.close()

    def _setUpServer(self):
        """Creates the server and make it listen."""
        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server.bind(('', 0)) # Bind any free port
        self.server.listen(1)

    def acceptClient(self):
        """Make the server accept a client connection. Blocking."""
        self.conn, _ = self.server.accept()
        self.conn_file = self.conn.makefile(newline='\r\n',
                encoding='utf8')

    def getLine(self):
        """Reads one line from the client ('' once the connection is
        closed), printing it when show_io is set."""
        line = self.conn_file.readline()
        if self.show_io:
            print('{:.3f} C: {}'.format(time.time(), line.strip()))
        return line

    def getMessages(self, *args):
        """Returns an iterator of parsed messages from getLines()."""
        # NOTE(review): getLines is not defined on this class or on
        # _IrcTestCase; this raises AttributeError unless a subclass
        # provides it — confirm intended behavior.
        lines = self.getLines(*args)
        return map(message_parser.parse_message, lines)

    def getMessage(self, *args, filter_pred=None):
        """Gets a message and returns it. If a filter predicate is given,
        skips messages until the predicate returns a true value on one,
        and returns that message."""
        # NOTE(review): getLine() takes no arguments, so any *args given
        # here raise TypeError.
        while True:
            line = self.getLine(*args)
            msg = message_parser.parse_message(line)
            if not filter_pred or filter_pred(msg):
                return msg

    def sendLine(self, line):
        """Sends a line to the client, appending CRLF if it is missing."""
        data = line if line.endswith('\r\n') else line + '\r\n'
        self.conn.sendall(data.encode())
        if self.show_io:
            print('{:.3f} S: {}'.format(time.time(), line.strip()))


class ClientNegociationHelper:
    """Mixin for client tests that drive capability negotiation.

    The host class must supply `server`, `controller`, `acceptClient`,
    `getMessage`, `sendLine`, `assertEqual` and `assertGreater`, as
    BaseClientTestCase does."""

    def readCapLs(self, auth=None):
        """Launches the client against our listening socket, accepts it,
        and stores in `self.protocol_version` what its first CAP message
        asked for: 301 for `CAP LS`, 302 for `CAP LS 302`, None for
        `CAP END`. Any other first message fails the test."""
        hostname, port = self.server.getsockname()
        self.controller.run(hostname=hostname, port=port, auth=auth)
        self.acceptClient()
        first = self.getMessage()
        self.assertEqual(first.command, 'CAP',
                'First message is not CAP LS.')
        versions = {
            ('LS',): 301,
            ('LS', '302'): 302,
            ('END',): None,
        }
        key = tuple(first.params)
        if key not in versions:
            raise AssertionError('Unknown CAP params: {}'
                    .format(first.params))
        self.protocol_version = versions[key]

    def userNickPredicate(self, msg):
        """getMessage filter that records NICK (into `self.nick`) and USER
        (into `self.user`) and hides them; every other message passes."""
        if msg.command == 'NICK':
            self.assertEqual(len(msg.params), 1, msg)
            self.nick = msg.params[0]
        elif msg.command == 'USER':
            self.assertEqual(len(msg.params), 4, msg)
            self.user = msg.params
        else:
            return True
        return False

    def negotiateCapabilities(self, caps, cap_ls=True, auth=None):
        """Runs a capability negotiation without ending it, so the caller
        can carry on.

        With `cap_ls`, first reads the client's CAP LS and answers it with
        `caps`; returns None right away if the client sent CAP END
        instead. Then ACKs each CAP REQ that only asks for offered
        capabilities (recording them in `self.acked_capabilities`) and
        NAKs the others. Returns the first message that is not a CAP REQ."""
        if cap_ls:
            self.readCapLs(auth)
            if not self.protocol_version:
                # The client does not negotiate at all.
                return
            self.sendLine('CAP * LS :{}'.format(' '.join(caps)))
        offered = frozenset(capabilities.cap_list_to_dict(caps))
        self.acked_capabilities = set()
        while True:
            reply = self.getMessage(filter_pred=self.userNickPredicate)
            if reply.command != 'CAP':
                return reply
            self.assertGreater(len(reply.params), 0, reply)
            if reply.params[0] != 'REQ':
                return reply
            self.assertEqual(len(reply.params), 2, reply)
            wanted = frozenset(reply.params[1].split())
            target = self.nick or '*'
            if wanted.issubset(offered):
                self.sendLine('CAP {} ACK :{}'.format(
                    target,
                    reply.params[1]))
                self.acked_capabilities.update(wanted)
            else:
                self.sendLine('CAP {} NAK :{}'.format(
                    target,
                    reply.params[1][0:100]))


class BaseServerTestCase(_IrcTestCase):
    """Basic class for server tests. Handles spawning a server and exchanging
    messages with it."""
    password = None

    def setUp(self):
        super().setUp()
        self.find_hostname_and_port()
        self.controller.run(self.hostname, self.port, password=self.password)
        self.clients = {}

    def tearDown(self):
        self.controller.kill()
        for client in list(self.clients):
            self.removeClient(client)

    def find_hostname_and_port(self):
        """Find available hostname/port to listen on."""
        # Bind port 0 so the OS picks a free port, then release it for the
        # server under test. Another process could grab it in between.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            (self.hostname, self.port) = s.getsockname()

    def addClient(self, name=None, show_io=None):
        """Connects a client to the server and adds it to the dict.
        If 'name' is not given, uses one more than the largest integer
        name currently in use (1 if there is none); non-integer names are
        ignored when picking it. Returns the client's name."""
        if not name:
            int_names = [n for n in self.clients if isinstance(n, int)]
            name = max(int_names, default=0) + 1
        show_io = show_io if show_io is not None else self.show_io
        self.clients[name] = client_mock.ClientMock(name=name,
                show_io=show_io)
        self.clients[name].connect(self.hostname, self.port)
        return name

    def removeClient(self, name):
        """Disconnects the client, without QUIT."""
        assert name in self.clients
        self.clients[name].disconnect()
        del self.clients[name]

    def getMessages(self, client, **kwargs):
        return self.clients[client].getMessages(**kwargs)

    def getMessage(self, client, **kwargs):
        return self.clients[client].getMessage(**kwargs)

    def getRegistrationMessage(self, client):
        """Filter notices, do not send pings."""
        return self.getMessage(client, synchronize=False,
                filter_pred=lambda m:m.command != 'NOTICE')

    def sendLine(self, client, line):
        return self.clients[client].sendLine(line)

    def getCapLs(self, client, as_list=False):
        """Waits for a CAP LS block, parses all CAP LS messages, and return
        the dict capabilities, with their values.

        If as_list is given, returns the raw list (ie. key/value not split)
        in case the order matters (but it shouldn't)."""
        caps = []
        while True:
            m = self.getRegistrationMessage(client)
            self.assertMessageEqual(m, command='CAP', subcommand='LS')
            if m.params[2] == '*':
                # CAP LS 302 continuation line: more caps follow.
                caps.extend(m.params[3].split())
            else:
                caps.extend(m.params[2].split())
                if not as_list:
                    caps = capabilities.cap_list_to_dict(caps)
                return caps

    def assertDisconnected(self, client):
        """Asserts that the server closed this client's connection, and if
        so forgets the client.

        Drains pending messages, sends a PING, then reads until a socket
        error; receiving the PONG means the client is still connected."""
        # NOTE(review): relies on ClientMock raising socket.error once the
        # connection is closed — confirm; otherwise the loop never ends.
        try:
            self.getMessages(client)
            self.sendLine(client, 'PING foo')
            while True:
                m = self.getMessage(client, synchronize=False)
                self.assertNotEqual(m.command, 'PONG',
                        fail_msg='Client not disconnected.')
        except socket.error:
            del self.clients[client]

    def connectClient(self, nick, name=None):
        """Connects and registers a client with the given nick, returning
        its name once the server has answered a PING after 001."""
        name = self.addClient(name)
        self.sendLine(name, 'NICK {}'.format(nick))
        self.sendLine(name, 'USER username * * :Realname')
        # Skip to the point where we are registered
        # https://tools.ietf.org/html/rfc2812#section-3.1
        while True:
            m = self.getMessage(name, synchronize=False)
            if m.command == '001':
                break
        self.sendLine(name, 'PING foo')
        # Skip all that happy welcoming stuff
        while True:
            m = self.getMessage(name)
            if m.command == 'PONG':
                break
        return name

    def joinClient(self, client, channel):
        """Sends JOIN and asserts the server replied with 366 (end of
        NAMES) among the received messages."""
        self.sendLine(client, 'JOIN {}'.format(channel))
        received = {m.command for m in self.getMessages(client)}
        self.assertIn('366', received,
                fail_msg='Join to {} failed, {item} is not in the set of '
                         'received responses: {list}',
                extra_format=(channel,))


class OptionalityHelper:
    """Mixin whose checks abort a test, via runner exceptions, when the
    controller lacks SASL or a given SASL mechanism. The decorators take no
    `self` and are meant to be applied through the class, e.g.
    ``@OptionalityHelper.skipUnlessHasMechanism('PLAIN')``."""

    def checkSaslSupport(self):
        """Raises runner.NotImplementedByController('SASL') when the
        controller lists no SASL mechanisms."""
        if not self.controller.supported_sasl_mechanisms:
            raise runner.NotImplementedByController('SASL')

    def checkMechanismSupport(self, mechanism):
        """Raises runner.OptionalSaslMechanismNotSupported(mechanism) when
        the controller does not list that SASL mechanism."""
        if mechanism not in self.controller.supported_sasl_mechanisms:
            raise runner.OptionalSaslMechanismNotSupported(mechanism)

    def skipUnlessHasMechanism(mech):
        """Decorator factory: the wrapped test first calls
        checkMechanismSupport(mech)."""
        def wrap(test_method):
            @functools.wraps(test_method)
            def guarded(self):
                self.checkMechanismSupport(mech)
                return test_method(self)
            return guarded
        return wrap

    def skipUnlessHasSasl(f):
        """Decorator: the wrapped test first calls checkSaslSupport()."""
        @functools.wraps(f)
        def guarded(self):
            self.checkSaslSupport()
            return f(self)
        return guarded