2015-12-20 12:12:54 +00:00
|
|
|
import time
|
2015-12-19 00:11:57 +00:00
|
|
|
import socket
|
|
|
|
import unittest
|
2015-12-20 14:11:56 +00:00
|
|
|
import functools
|
2015-12-19 22:09:06 +00:00
|
|
|
import collections
|
2015-12-19 00:11:57 +00:00
|
|
|
|
2015-12-19 16:52:38 +00:00
|
|
|
from . import authentication
|
2015-12-20 14:11:56 +00:00
|
|
|
from . import optional_extensions
|
2015-12-19 07:43:45 +00:00
|
|
|
from .irc_utils import message_parser
|
2015-12-20 14:11:56 +00:00
|
|
|
from .irc_utils import capabilities
|
2015-12-19 07:43:45 +00:00
|
|
|
|
2015-12-19 00:11:57 +00:00
|
|
|
class _IrcTestCase(unittest.TestCase):
    """Base class for test cases."""

    controllerClass = None  # Will be set by __main__.py
    # When true, every exchanged line is echoed to stdout. Defaulted here so
    # the attribute always exists; presumably overridden by the runner
    # alongside controllerClass (TODO confirm).
    show_io = False

    def setUp(self):
        super().setUp()
        self.controller = self.controllerClass()
        if self.show_io:
            print('---- new test ----')

    def getLine(self):
        """Read one raw line from the peer. Implemented by subclasses."""
        raise NotImplementedError()

    def getLines(self, *args):
        """Read all currently available raw lines. Implemented by
        subclasses."""
        raise NotImplementedError()

    def getMessages(self, *args):
        """Return a lazy iterator of parsed messages, one per line returned
        by getLines(*args)."""
        lines = self.getLines(*args)
        return map(message_parser.parse_message, lines)

    def getMessage(self, *args, filter_pred=None):
        """Gets a message and returns it.

        If a filter predicate is given, fetches messages until the predicate
        returns a true value on a message, and returns that message; messages
        for which the predicate returns a false value are discarded."""
        while True:
            msg = message_parser.parse_message(self.getLine(*args))
            if not filter_pred or filter_pred(msg):
                return msg

    def assertMessageEqual(self, msg, subcommand=None, subparams=None,
            target=None, **kwargs):
        """Helper for partially comparing a message.

        Takes the message as first argument, and comparisons to be made
        as keyword arguments (each keyword is compared against the message
        attribute of the same name).

        Deals with subcommands (eg. `CAP`) if any of `subcommand`,
        `subparams`, and `target` are given. Such messages are laid out as
        `<target> <subcommand> <subparams...>` in `msg.params`."""
        for (key, value) in kwargs.items():
            with self.subTest(key=key):
                self.assertEqual(getattr(msg, key), value, msg)
        if subcommand is None and subparams is None and target is None:
            return
        self.assertGreater(len(msg.params), 2, msg)
        (msg_target, msg_subcommand) = msg.params[0:2]
        msg_subparams = msg.params[2:]
        if target is not None:
            with self.subTest(key='target'):
                self.assertEqual(msg_target, target, msg)
        if subcommand is not None:
            with self.subTest(key='subcommand'):
                self.assertEqual(msg_subcommand, subcommand, msg)
        if subparams is not None:
            with self.subTest(key='subparams'):
                self.assertEqual(msg_subparams, subparams, msg)
|
2015-12-19 08:30:50 +00:00
|
|
|
|
2015-12-19 08:03:08 +00:00
|
|
|
class BaseClientTestCase(_IrcTestCase):
    """Basic class for client tests. Handles spawning a client and exchanging
    messages with it.

    The test plays the server: it listens on a local socket, and the client
    under test (started through self.controller) connects to it."""
    nick = None
    user = None

    def setUp(self):
        super().setUp()
        self.conn = None
        self._setUpServer()

    def tearDown(self):
        """Ask the client to quit, stop it, and release every socket.

        Cleanup runs even if an earlier step fails, so a client that already
        hung up cannot leak the controller or the sockets."""
        try:
            if self.conn:
                try:
                    # IRC lines must be CRLF-terminated to be processed.
                    self.conn.sendall(b'QUIT :end of test.\r\n')
                except OSError:
                    # Best effort: the client may already have disconnected.
                    pass
            self.controller.kill()
        finally:
            if self.conn:
                self.conn_file.close()
                self.conn.close()
            self.server.close()

    def _setUpServer(self):
        """Creates the server and make it listen."""
        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server.bind(('', 0))  # Bind any free port
        self.server.listen(1)

    def acceptClient(self):
        """Make the server accept a client connection. Blocking."""
        (self.conn, addr) = self.server.accept()
        self.conn_file = self.conn.makefile(newline='\r\n',
                encoding='utf8')

    def getLine(self):
        """Read one CRLF-delimited line sent by the client (blocking)."""
        line = self.conn_file.readline()
        if self.show_io:
            print('C: {}'.format(line.strip()))
        return line

    def sendLine(self, line):
        """Send `line` to the client, appending CRLF if it is missing."""
        if not line.endswith('\r\n'):
            line_to_send = line + '\r\n'
        else:
            line_to_send = line
        self.conn.sendall(line_to_send.encode())
        if self.show_io:
            print('S: {}'.format(line.strip()))
|
2015-12-19 08:03:08 +00:00
|
|
|
|
2015-12-19 08:30:50 +00:00
|
|
|
class ClientNegociationHelper:
    """Helper class for tests handling capabilities negotiation.

    Mixin: relies on the host test case for `server`, `controller`,
    `acceptClient`, `getMessage`, `sendLine`, `nick` and the unittest
    assertion methods."""

    def readCapLs(self, auth=None):
        """Start the client against our listening socket, accept it, and read
        its first message, which must be a CAP command.

        Sets `self.protocol_version` to 301 for `CAP LS`, 302 for
        `CAP LS 302`, or None for `CAP END`. Raises AssertionError for any
        other first message."""
        (hostname, port) = self.server.getsockname()
        self.controller.run(
                hostname=hostname,
                port=port,
                auth=auth,
                )
        self.acceptClient()
        m = self.getMessage()
        self.assertEqual(m.command, 'CAP',
                'First message is not CAP LS.')
        if m.params == ['LS']:
            self.protocol_version = 301
        elif m.params == ['LS', '302']:
            self.protocol_version = 302
        elif m.params == ['END']:
            self.protocol_version = None
        else:
            raise AssertionError('Unknown CAP params: {}'
                    .format(m.params))

    def userNickPredicate(self, msg):
        """Predicate to be used with getMessage to handle NICK/USER
        transparently.

        Records the NICK parameter in `self.nick` and the USER parameters in
        `self.user`, returning False so getMessage skips them; returns True
        for any other message."""
        if msg.command == 'NICK':
            self.assertEqual(len(msg.params), 1, msg)
            self.nick = msg.params[0]
            return False
        elif msg.command == 'USER':
            self.assertEqual(len(msg.params), 4, msg)
            self.user = msg.params
            return False
        else:
            return True

    def negotiateCapabilities(self, caps, cap_ls=True, auth=None):
        """Performs a complete capability negotiation process, without
        ending it, so the caller can continue the negotiation.

        :param caps: capabilities advertised to the client (presumably
            `name` or `name=value` strings, as parsed by
            capabilities.cap_list_to_dict).
        :param cap_ls: if true, start the client and answer its CAP LS
            first; otherwise assume that already happened.
        :param auth: forwarded to readCapLs.
        :returns: the first message that is neither NICK/USER nor a CAP REQ,
            or None if the client opened with CAP END instead of CAP LS.

        `self.acked_capabilities` is always reset here, even when the client
        skips negotiation, and collects every capability ACKed during this
        call."""
        self.acked_capabilities = set()
        if cap_ls:
            self.readCapLs(auth)
            if not self.protocol_version:
                # No negotiation.
                return None
            self.sendLine('CAP * LS :{}'.format(' '.join(caps)))
        capability_names = frozenset(capabilities.cap_list_to_dict(caps))
        while True:
            m = self.getMessage(filter_pred=self.userNickPredicate)
            if m.command != 'CAP':
                return m
            self.assertGreater(len(m.params), 0, m)
            if m.params[0] == 'REQ':
                self.assertEqual(len(m.params), 2, m)
                requested = frozenset(m.params[1].split())
                if not requested.issubset(capability_names):
                    # A REQ is accepted or rejected as a whole. The NAK only
                    # needs to echo (at least) the first 100 characters of
                    # the request.
                    self.sendLine('CAP {} NAK :{}'.format(
                        self.nick or '*',
                        m.params[1][0:100]))
                else:
                    self.sendLine('CAP {} ACK :{}'.format(
                        self.nick or '*',
                        m.params[1]))
                    self.acked_capabilities.update(requested)
            else:
                return m
|
|
|
|
|
2015-12-19 22:09:06 +00:00
|
|
|
# One connection opened by BaseServerTestCase.addClient: `conn` is the raw
# socket, `conn_file` the text-mode file object (from conn.makefile) used
# to read CRLF-delimited lines.
Client = collections.namedtuple('Client', ['conn', 'conn_file'])
|
|
|
|
|
|
|
|
class BaseServerTestCase(_IrcTestCase):
    """Basic class for server tests. Handles spawning a server and exchanging
    messages with it."""

    def setUp(self):
        super().setUp()
        self.find_hostname_and_port()
        self.controller.run(self.hostname, self.port)
        self.clients = {}

    def tearDown(self):
        """Stop the server, then disconnect every client, even if stopping
        the server failed."""
        try:
            self.controller.kill()
        finally:
            for client in list(self.clients):
                self.removeClient(client)

    def find_hostname_and_port(self):
        """Find available hostname/port to listen on.

        Binds a throwaway socket to port 0 so the OS picks a free port, then
        releases it for the server. Another process could take the port in
        between; that race is accepted."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            (self.hostname, self.port) = s.getsockname()

    def addClient(self, name=None):
        """Connects a client to the server and adds it to the dict.

        If 'name' is not given (or is falsy), uses one more than the largest
        integer-like name already in use, starting at 1. Returns the name."""
        if not name:
            numeric_names = []
            for existing in self.clients:
                try:
                    numeric_names.append(int(existing))
                except (TypeError, ValueError):
                    # Non-numeric names cannot collide with generated ones.
                    pass
            name = max(numeric_names, default=0) + 1
        conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            conn.connect((self.hostname, self.port))
            conn_file = conn.makefile(newline='\r\n', encoding='utf8')
        except BaseException:
            conn.close()
            raise
        self.clients[name] = Client(conn=conn, conn_file=conn_file)
        if self.show_io:
            print('{}: connects to server.'.format(name))
        return name

    def removeClient(self, name):
        """Disconnects the client, without QUIT."""
        assert name in self.clients
        if self.show_io:
            print('{}: disconnects from server.'.format(name))
        client = self.clients.pop(name)
        # The makefile() object holds a reference on the socket: the file
        # descriptor is only released once both are closed.
        client.conn_file.close()
        client.conn.close()

    def getLines(self, client):
        """Yield every line the server has sent to `client` so far.

        Reads without blocking (sleeping 0.1s before each read to give the
        server time to answer) until no data is pending or the server closes
        the connection. Yielded lines end with CRLF; empty lines are
        skipped. Reads the socket directly, so data already buffered by
        conn_file (getLine) is not seen here."""
        conn = self.clients[client].conn
        chunks = []
        conn.setblocking(False)
        try:
            while True:
                time.sleep(0.1)  # TODO: do better than this (use ping?)
                try:
                    chunk = conn.recv(4096)
                except BlockingIOError:
                    break
                if not chunk:
                    # Connection closed by the server: nothing more to read.
                    break
                chunks.append(chunk)
        finally:
            # Restored before yielding so a partially consumed generator
            # cannot leave the socket non-blocking.
            conn.setblocking(True)  # required for readline()
        for line in b''.join(chunks).decode().split('\r\n'):
            if not line:
                continue
            if self.show_io:
                print('S -> {}: {}'.format(client, line.strip()))
            yield line + '\r\n'

    def getLine(self, client):
        """Read one CRLF-delimited line sent to `client` (blocking)."""
        assert client in self.clients
        line = self.clients[client].conn_file.readline()
        if self.show_io:
            print('S -> {}: {}'.format(client, line.strip()))
        return line

    def sendLine(self, client, line):
        """Send `line` from `client`, appending CRLF if it is missing."""
        if not line.endswith('\r\n'):
            line_to_send = line + '\r\n'
        else:
            line_to_send = line
        self.clients[client].conn.sendall(line_to_send.encode())
        if self.show_io:
            print('{} -> S: {}'.format(client, line.strip()))

    def getCapLs(self, client, as_list=False):
        """Waits for a CAP LS block, parses all CAP LS messages, and return
        the dict capabilities, with their values.

        NOTICEs received meanwhile are skipped. A `*` before the capability
        list marks a continuation line (more CAP LS lines follow).

        If as_list is given, returns the raw list (ie. key/value not split)
        in case the order matters (but it shouldn't)."""
        caps = []
        while True:
            m = self.getMessage(client,
                    filter_pred=lambda m: m.command != 'NOTICE')
            self.assertMessageEqual(m, command='CAP', subcommand='LS')
            if m.params[2] == '*':
                self.assertGreater(len(m.params), 3, m)
                caps.extend(m.params[3].split())
            else:
                caps.extend(m.params[2].split())
                if not as_list:
                    caps = capabilities.cap_list_to_dict(caps)
                return caps
|
|
|
|
|
|
|
|
class OptionalityHelper:
|
|
|
|
def checkMechanismSupport(self, mechanism):
|
|
|
|
if mechanism in self.controller.supported_sasl_mechanisms:
|
|
|
|
return
|
|
|
|
raise optional_extensions.OptionalSaslMechanismNotSupported(mechanism)
|
|
|
|
|
|
|
|
def skipUnlessHasMechanism(mech):
|
|
|
|
def decorator(f):
|
|
|
|
@functools.wraps(f)
|
|
|
|
def newf(self):
|
|
|
|
self.checkMechanismSupport(mech)
|
|
|
|
return f(self)
|
|
|
|
return newf
|
|
|
|
return decorator
|
|
|
|
|