mirror of
https://github.com/progval/irctest.git
synced 2025-04-06 23:39:46 +00:00
Add PING-based synchronization for fetching messages from server.
This commit is contained in:
187
irctest/cases.py
187
irctest/cases.py
@ -9,6 +9,11 @@ from . import optional_extensions
|
||||
from .irc_utils import message_parser
|
||||
from .irc_utils import capabilities
|
||||
|
||||
class ConnectionClosed(Exception):
|
||||
pass
|
||||
class NoMessageException(AssertionError):
|
||||
pass
|
||||
|
||||
class _IrcTestCase(unittest.TestCase):
|
||||
"""Base class for test cases."""
|
||||
controllerClass = None # Will be set by __main__.py
|
||||
@ -16,22 +21,9 @@ class _IrcTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.controller = self.controllerClass()
|
||||
self.inbuffer = []
|
||||
if self.show_io:
|
||||
print('---- new test ----')
|
||||
def getLine(self):
|
||||
raise NotImplementedError()
|
||||
def getMessages(self, *args):
|
||||
lines = self.getLines(*args)
|
||||
return map(message_parser.parse_message, lines)
|
||||
def getMessage(self, *args, filter_pred=None):
|
||||
"""Gets a message and returns it. If a filter predicate is given,
|
||||
fetches messages until the predicate returns a False on a message,
|
||||
and returns this message."""
|
||||
while True:
|
||||
line = self.getLine(*args)
|
||||
msg = message_parser.parse_message(line)
|
||||
if not filter_pred or filter_pred(msg):
|
||||
return msg
|
||||
def assertMessageEqual(self, msg, subcommand=None, subparams=None,
|
||||
target=None, **kwargs):
|
||||
"""Helper for partially comparing a message.
|
||||
@ -42,7 +34,7 @@ class _IrcTestCase(unittest.TestCase):
|
||||
Deals with subcommands (eg. `CAP`) if any of `subcommand`,
|
||||
`subparams`, and `target` are given."""
|
||||
for (key, value) in kwargs.items():
|
||||
with self.subTest(key=key):
|
||||
#with self.subTest(key=key):
|
||||
self.assertEqual(getattr(msg, key), value, msg)
|
||||
if subcommand is not None or subparams is not None:
|
||||
self.assertGreater(len(msg.params), 2, msg)
|
||||
@ -90,6 +82,18 @@ class BaseClientTestCase(_IrcTestCase):
|
||||
if self.show_io:
|
||||
print('{:.3f} C: {}'.format(time.time(), line.strip()))
|
||||
return line
|
||||
def getMessages(self, *args):
|
||||
lines = self.getLines(*args)
|
||||
return map(message_parser.parse_message, lines)
|
||||
def getMessage(self, *args, filter_pred=None):
|
||||
"""Gets a message and returns it. If a filter predicate is given,
|
||||
fetches messages until the predicate returns a False on a message,
|
||||
and returns this message."""
|
||||
while True:
|
||||
line = self.getLine(*args)
|
||||
msg = message_parser.parse_message(line)
|
||||
if not filter_pred or filter_pred(msg):
|
||||
return msg
|
||||
def sendLine(self, line):
|
||||
ret = self.conn.sendall(line.encode())
|
||||
assert ret is None
|
||||
@ -167,8 +171,80 @@ class ClientNegociationHelper:
|
||||
else:
|
||||
return m
|
||||
|
||||
Client = collections.namedtuple('Client',
|
||||
'conn conn_file')
|
||||
class Client:
|
||||
def __init__(self, name, show_io):
|
||||
self.name = name
|
||||
self.show_io = show_io
|
||||
self.inbuffer = []
|
||||
def connect(self, hostname, port):
|
||||
self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.conn.settimeout(1) # TODO: configurable
|
||||
self.conn.connect((hostname, port))
|
||||
if self.show_io:
|
||||
print('{:.3f} {}: connects to server.'.format(time.time(), self.name))
|
||||
def disconnect(self):
|
||||
if self.show_io:
|
||||
print('{:.3f} {}: disconnects from server.'.format(time.time(), self.name))
|
||||
self.conn.close()
|
||||
def getMessages(self, synchronize=True, assert_get_one=False):
|
||||
if synchronize:
|
||||
token = 'synchronize{}'.format(time.monotonic())
|
||||
self.sendLine('PING {}'.format(token))
|
||||
got_pong = False
|
||||
data = b''
|
||||
messages = []
|
||||
conn = self.conn
|
||||
while not got_pong:
|
||||
try:
|
||||
new_data = conn.recv(4096)
|
||||
except socket.timeout:
|
||||
if not assert_get_one and not synchronize and data == b'':
|
||||
# Received nothing
|
||||
return []
|
||||
if self.show_io:
|
||||
print('{:.3f} waiting…'.format(time.time()))
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
else:
|
||||
if not new_data:
|
||||
# Connection closed
|
||||
raise ConnectionClosed()
|
||||
data += new_data
|
||||
if not new_data.endswith(b'\r\n'):
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
if not synchronize:
|
||||
got_pong = True
|
||||
for line in data.decode().split('\r\n'):
|
||||
if line:
|
||||
if self.show_io:
|
||||
print('{:.3f} S -> {}: {}'.format(time.time(), self.name, line.strip()))
|
||||
message = message_parser.parse_message(line + '\r\n')
|
||||
if message.command == 'PONG' and \
|
||||
token in message.params:
|
||||
got_pong = True
|
||||
else:
|
||||
messages.append(message)
|
||||
data = b''
|
||||
return messages
|
||||
def getMessage(self, filter_pred=None, synchronize=True):
|
||||
while True:
|
||||
if not self.inbuffer:
|
||||
self.inbuffer = self.getMessages(
|
||||
synchronize=synchronize, assert_get_one=True)
|
||||
if not self.inbuffer:
|
||||
raise NoMessageException()
|
||||
message = self.inbuffer.pop(0) # TODO: use dequeue
|
||||
if not filter_pred or filter_pred(message):
|
||||
return message
|
||||
def sendLine(self, line):
|
||||
ret = self.conn.sendall(line.encode())
|
||||
assert ret is None
|
||||
if not line.endswith('\r\n'):
|
||||
ret = self.conn.sendall(b'\r\n')
|
||||
assert ret is None
|
||||
if self.show_io:
|
||||
print('{:.3f} {} -> S: {}'.format(time.time(), self.name, line.strip()))
|
||||
|
||||
class BaseServerTestCase(_IrcTestCase):
|
||||
"""Basic class for server tests. Handles spawning a server and exchanging
|
||||
@ -189,59 +265,33 @@ class BaseServerTestCase(_IrcTestCase):
|
||||
(self.hostname, self.port) = s.getsockname()
|
||||
s.close()
|
||||
|
||||
def addClient(self, name=None):
|
||||
def addClient(self, name=None, show_io=None):
|
||||
"""Connects a client to the server and adds it to the dict.
|
||||
If 'name' is not given, uses the lowest unused non-negative integer."""
|
||||
if not name:
|
||||
name = max(map(int, list(self.clients)+[0]))+1
|
||||
conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
conn.connect((self.hostname, self.port))
|
||||
conn_file = conn.makefile(newline='\r\n', encoding='utf8')
|
||||
self.clients[name] = Client(conn=conn, conn_file=conn_file)
|
||||
if self.show_io:
|
||||
print('{:.3f} {}: connects to server.'.format(time.time(), name))
|
||||
show_io = show_io if show_io is not None else self.show_io
|
||||
self.clients[name] = Client(name=name, show_io=show_io)
|
||||
self.clients[name].connect(self.hostname, self.port)
|
||||
return name
|
||||
|
||||
|
||||
def removeClient(self, name):
|
||||
"""Disconnects the client, without QUIT."""
|
||||
assert name in self.clients
|
||||
if self.show_io:
|
||||
print('{:.3f} {}: disconnects from server.'.format(time.time(), name))
|
||||
self.clients[name].conn.close()
|
||||
self.clients[name].disconnect()
|
||||
del self.clients[name]
|
||||
|
||||
def getLines(self, client):
|
||||
data = b''
|
||||
conn = self.clients[client].conn
|
||||
try:
|
||||
conn.setblocking(False)
|
||||
while True:
|
||||
time.sleep(0.1) # TODO: do better than this (use ping?)
|
||||
new_data = conn.recv(4096) # May raise BlockingIOError
|
||||
if not new_data:
|
||||
raise BlockingIOError
|
||||
data += new_data
|
||||
except BlockingIOError:
|
||||
for line in data.decode().split('\r\n'):
|
||||
if line and self.show_io:
|
||||
print('{:.3f} S -> {}: {}'.format(time.time(), client, line.strip()))
|
||||
yield line + '\r\n'
|
||||
finally:
|
||||
conn.setblocking(True) # required for readline()
|
||||
def getLine(self, client):
|
||||
assert client in self.clients
|
||||
line = self.clients[client].conn_file.readline()
|
||||
if self.show_io:
|
||||
print('{:.3f} S -> {}: {}'.format(time.time(), client, line.strip()))
|
||||
return line
|
||||
def getMessages(self, client, **kwargs):
|
||||
return self.clients[client].getMessages(**kwargs)
|
||||
def getMessage(self, client, **kwargs):
|
||||
return self.clients[client].getMessage(**kwargs)
|
||||
def getRegistrationMessage(self, client):
|
||||
"""Filter notices, do not send pings."""
|
||||
return self.getMessage(client, synchronize=False,
|
||||
filter_pred=lambda m:m.command != 'NOTICE')
|
||||
def sendLine(self, client, line):
|
||||
ret = self.clients[client].conn.sendall(line.encode())
|
||||
assert ret is None
|
||||
if not line.endswith('\r\n'):
|
||||
ret = self.clients[client].conn.sendall(b'\r\n')
|
||||
assert ret is None
|
||||
if self.show_io:
|
||||
print('{:.3f} {} -> S: {}'.format(time.time(), client, line.strip()))
|
||||
return self.clients[client].sendLine(line)
|
||||
|
||||
def getCapLs(self, client, as_list=False):
|
||||
"""Waits for a CAP LS block, parses all CAP LS messages, and return
|
||||
@ -251,8 +301,7 @@ class BaseServerTestCase(_IrcTestCase):
|
||||
in case the order matters (but it shouldn't)."""
|
||||
caps = []
|
||||
while True:
|
||||
m = self.getMessage(client,
|
||||
filter_pred=lambda m:m.command != 'NOTICE')
|
||||
m = self.getRegistrationMessage(client)
|
||||
self.assertMessageEqual(m, command='CAP', subcommand='LS')
|
||||
if m.params[2] == '*':
|
||||
caps.extend(m.params[3].split())
|
||||
@ -277,6 +326,24 @@ class BaseServerTestCase(_IrcTestCase):
|
||||
return
|
||||
else:
|
||||
raise AssertionError('Client not disconnected.')
|
||||
def connectClient(self, nick, name=None):
|
||||
name = self.addClient(name)
|
||||
self.sendLine(1, 'NICK {}'.format(nick))
|
||||
self.sendLine(1, 'USER username * * :Realname')
|
||||
|
||||
# Skip to the point where we are registered
|
||||
# https://tools.ietf.org/html/rfc2812#section-3.1
|
||||
while True:
|
||||
m = self.getMessage(1, synchronize=False)
|
||||
if m.command == '001':
|
||||
break
|
||||
self.sendLine(1, 'PING foo')
|
||||
|
||||
# Skip all that happy welcoming stuff
|
||||
while True:
|
||||
m = self.getMessage(1)
|
||||
if m.command == 'PONG':
|
||||
break
|
||||
|
||||
class OptionalityHelper:
|
||||
def checkMechanismSupport(self, mechanism):
|
||||
|
@ -76,6 +76,8 @@ class MammonController(BaseServerController, DirectoryBasedController):
|
||||
hostname=hostname,
|
||||
port=port,
|
||||
))
|
||||
#with self.open_file('server.yml', 'r') as fd:
|
||||
# print(fd.read())
|
||||
self.proc = subprocess.Popen(['mammond', '--nofork', #'--debug',
|
||||
'--config', os.path.join(self.directory, 'server.yml')])
|
||||
self.wait_for_port(self.proc, port)
|
||||
@ -84,17 +86,19 @@ class MammonController(BaseServerController, DirectoryBasedController):
|
||||
# XXX: Move this somewhere else when
|
||||
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes
|
||||
# part of the specification
|
||||
client = case.addClient()
|
||||
client = case.addClient(show_io=False)
|
||||
case.sendLine(client, 'CAP LS 302')
|
||||
case.sendLine(client, 'NICK registration_user')
|
||||
case.sendLine(client, 'USER r e g :user')
|
||||
case.sendLine(client, 'CAP END')
|
||||
list(case.getLines(client))
|
||||
while case.getRegistrationMessage(client).command != '001':
|
||||
pass
|
||||
list(case.getMessages(client))
|
||||
case.sendLine(client, 'REG CREATE {} passphrase {}'.format(
|
||||
username, password))
|
||||
msg = case.getMessage(client)
|
||||
assert msg.command == '920'
|
||||
list(case.getLines(client))
|
||||
assert msg.command == '920', msg
|
||||
list(case.getMessages(client))
|
||||
case.removeClient(client)
|
||||
|
||||
def get_irctest_controller_class():
|
||||
|
@ -11,7 +11,7 @@ class CapTestCase(cases.BaseServerTestCase):
|
||||
self.sendLine(1, 'USER foo foo foo :foo')
|
||||
self.sendLine(1, 'NICK foo')
|
||||
self.sendLine(1, 'CAP END')
|
||||
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
m = self.getRegistrationMessage(1)
|
||||
self.assertMessageEqual(m, command='001')
|
||||
|
||||
def testReqUnavailable(self):
|
||||
@ -23,11 +23,11 @@ class CapTestCase(cases.BaseServerTestCase):
|
||||
self.sendLine(1, 'USER foo foo foo :foo')
|
||||
self.sendLine(1, 'NICK foo')
|
||||
self.sendLine(1, 'CAP REQ :foo')
|
||||
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
m = self.getRegistrationMessage(1)
|
||||
self.assertMessageEqual(m, command='CAP',
|
||||
subcommand='NAK', subparams=['foo'])
|
||||
self.sendLine(1, 'CAP END')
|
||||
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
m = self.getRegistrationMessage(1)
|
||||
self.assertEqual(m.command, '001')
|
||||
|
||||
def testNakExactString(self):
|
||||
@ -39,7 +39,7 @@ class CapTestCase(cases.BaseServerTestCase):
|
||||
# Five should be enough to check there is no reordering, even
|
||||
# alphabetical
|
||||
self.sendLine(1, 'CAP REQ :foo bar baz qux quux')
|
||||
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
m = self.getRegistrationMessage(1)
|
||||
self.assertMessageEqual(m, command='CAP',
|
||||
subcommand='NAK', subparams=['foo bar baz qux quux'])
|
||||
|
||||
@ -49,19 +49,19 @@ class CapTestCase(cases.BaseServerTestCase):
|
||||
self.sendLine(1, 'CAP LS 302')
|
||||
self.assertIn('multi-prefix', self.getCapLs(1))
|
||||
self.sendLine(1, 'CAP REQ :foo multi-prefix bar')
|
||||
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
m = self.getRegistrationMessage(1)
|
||||
self.assertMessageEqual(m, command='CAP',
|
||||
subcommand='NAK', subparams=['foo multi-prefix bar'])
|
||||
self.sendLine(1, 'CAP REQ :multi-prefix bar')
|
||||
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
m = self.getRegistrationMessage(1)
|
||||
self.assertMessageEqual(m, command='CAP',
|
||||
subcommand='NAK', subparams=['multi-prefix bar'])
|
||||
self.sendLine(1, 'CAP REQ :foo multi-prefix')
|
||||
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
m = self.getRegistrationMessage(1)
|
||||
self.assertMessageEqual(m, command='CAP',
|
||||
subcommand='NAK', subparams=['foo multi-prefix'])
|
||||
# TODO: make sure multi-prefix is not enabled at this point
|
||||
self.sendLine(1, 'CAP REQ :multi-prefix')
|
||||
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
m = self.getRegistrationMessage(1)
|
||||
self.assertMessageEqual(m, command='CAP',
|
||||
subcommand='ACK', subparams=['multi-prefix'])
|
||||
|
@ -8,33 +8,15 @@ from irctest import authentication
|
||||
from irctest.irc_utils.message_parser import Message
|
||||
|
||||
class ConnectionRegistrationTestCase(cases.BaseServerTestCase):
|
||||
def connectClient(self, nick, name=None):
|
||||
name = self.addClient(name)
|
||||
self.sendLine(1, 'NICK {}'.format(nick))
|
||||
self.sendLine(1, 'USER username * * :Realname')
|
||||
|
||||
# Skip to the point where we are registered
|
||||
# https://tools.ietf.org/html/rfc2812#section-3.1
|
||||
while True:
|
||||
m = self.getMessage(1)
|
||||
if m.command == '001':
|
||||
break
|
||||
self.sendLine(1, 'PING foo')
|
||||
|
||||
# Skip all that happy welcoming stuff
|
||||
while True:
|
||||
m = self.getMessage(1)
|
||||
if m.command == 'PONG':
|
||||
break
|
||||
|
||||
def testPassBeforeNickuser(self):
|
||||
"""“Currently this requires that user send a PASS command before
|
||||
sending the NICK/USER combination.”
|
||||
<https://tools.ietf.org/html/rfc2812#section-3.1.1>"""
|
||||
self.connectClient('foo')
|
||||
self.getMessages(1)
|
||||
self.getMessages(1, synchronize=False)
|
||||
self.sendLine(1, 'PASS :foo')
|
||||
m = self.getMessage(1)
|
||||
m = self.getRegistrationMessage(1)
|
||||
self.assertMessageEqual(m, command='462') # ERR_ALREADYREGISTRED
|
||||
|
||||
def testQuitDisconnects(self):
|
||||
@ -44,16 +26,14 @@ class ConnectionRegistrationTestCase(cases.BaseServerTestCase):
|
||||
self.connectClient('foo')
|
||||
self.getMessages(1)
|
||||
self.sendLine(1, 'QUIT')
|
||||
m = self.getMessage(1)
|
||||
self.assertMessageEqual(m, command='ERROR')
|
||||
self.assertDisconnected(1)
|
||||
self.assertRaises(cases.ConnectionClosed, self.getMessages, 1)
|
||||
|
||||
def testNickCollision(self):
|
||||
self.connectClient('foo')
|
||||
self.addClient()
|
||||
self.sendLine(2, 'NICK foo')
|
||||
self.sendLine(2, 'USER username * * :Realname')
|
||||
m = self.getMessage(2, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
m = self.getRegistrationMessage(2)
|
||||
self.assertNotEqual(m.command, '001')
|
||||
|
||||
def testEarlyNickCollision(self):
|
||||
@ -63,6 +43,6 @@ class ConnectionRegistrationTestCase(cases.BaseServerTestCase):
|
||||
self.sendLine(2, 'NICK foo')
|
||||
self.sendLine(1, 'USER username * * :Realname')
|
||||
self.sendLine(2, 'USER username * * :Realname')
|
||||
m1 = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
m2 = self.getMessage(2, filter_pred=lambda m:m.command != 'NOTICE')
|
||||
m1 = self.getRegistrationMessage(1)
|
||||
m2 = self.getRegistrationMessage(2)
|
||||
self.assertNotEqual((m1.command, m2.command), ('001', '001'))
|
||||
|
Reference in New Issue
Block a user