Add PING-based synchronization for fetching messages from server.

This commit is contained in:
Valentin Lorentz
2015-12-21 12:24:40 +01:00
parent 4335b909e5
commit 5eb014d4ba
4 changed files with 149 additions and 98 deletions

View File

@ -9,6 +9,11 @@ from . import optional_extensions
from .irc_utils import message_parser
from .irc_utils import capabilities
class ConnectionClosed(Exception):
pass
class NoMessageException(AssertionError):
pass
class _IrcTestCase(unittest.TestCase):
"""Base class for test cases."""
controllerClass = None # Will be set by __main__.py
@ -16,22 +21,9 @@ class _IrcTestCase(unittest.TestCase):
def setUp(self):
super().setUp()
self.controller = self.controllerClass()
self.inbuffer = []
if self.show_io:
print('---- new test ----')
def getLine(self):
raise NotImplementedError()
def getMessages(self, *args):
lines = self.getLines(*args)
return map(message_parser.parse_message, lines)
def getMessage(self, *args, filter_pred=None):
"""Gets a message and returns it. If a filter predicate is given,
fetches messages until the predicate returns a False on a message,
and returns this message."""
while True:
line = self.getLine(*args)
msg = message_parser.parse_message(line)
if not filter_pred or filter_pred(msg):
return msg
def assertMessageEqual(self, msg, subcommand=None, subparams=None,
target=None, **kwargs):
"""Helper for partially comparing a message.
@ -42,7 +34,7 @@ class _IrcTestCase(unittest.TestCase):
Deals with subcommands (eg. `CAP`) if any of `subcommand`,
`subparams`, and `target` are given."""
for (key, value) in kwargs.items():
with self.subTest(key=key):
#with self.subTest(key=key):
self.assertEqual(getattr(msg, key), value, msg)
if subcommand is not None or subparams is not None:
self.assertGreater(len(msg.params), 2, msg)
@ -90,6 +82,18 @@ class BaseClientTestCase(_IrcTestCase):
if self.show_io:
print('{:.3f} C: {}'.format(time.time(), line.strip()))
return line
def getMessages(self, *args):
lines = self.getLines(*args)
return map(message_parser.parse_message, lines)
def getMessage(self, *args, filter_pred=None):
"""Gets a message and returns it. If a filter predicate is given,
fetches messages until the predicate returns a False on a message,
and returns this message."""
while True:
line = self.getLine(*args)
msg = message_parser.parse_message(line)
if not filter_pred or filter_pred(msg):
return msg
def sendLine(self, line):
ret = self.conn.sendall(line.encode())
assert ret is None
@ -167,8 +171,80 @@ class ClientNegociationHelper:
else:
return m
Client = collections.namedtuple('Client',
'conn conn_file')
class Client:
def __init__(self, name, show_io):
self.name = name
self.show_io = show_io
self.inbuffer = []
def connect(self, hostname, port):
self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.conn.settimeout(1) # TODO: configurable
self.conn.connect((hostname, port))
if self.show_io:
print('{:.3f} {}: connects to server.'.format(time.time(), self.name))
def disconnect(self):
if self.show_io:
print('{:.3f} {}: disconnects from server.'.format(time.time(), self.name))
self.conn.close()
def getMessages(self, synchronize=True, assert_get_one=False):
if synchronize:
token = 'synchronize{}'.format(time.monotonic())
self.sendLine('PING {}'.format(token))
got_pong = False
data = b''
messages = []
conn = self.conn
while not got_pong:
try:
new_data = conn.recv(4096)
except socket.timeout:
if not assert_get_one and not synchronize and data == b'':
# Received nothing
return []
if self.show_io:
print('{:.3f} waiting…'.format(time.time()))
time.sleep(0.1)
continue
else:
if not new_data:
# Connection closed
raise ConnectionClosed()
data += new_data
if not new_data.endswith(b'\r\n'):
time.sleep(0.1)
continue
if not synchronize:
got_pong = True
for line in data.decode().split('\r\n'):
if line:
if self.show_io:
print('{:.3f} S -> {}: {}'.format(time.time(), self.name, line.strip()))
message = message_parser.parse_message(line + '\r\n')
if message.command == 'PONG' and \
token in message.params:
got_pong = True
else:
messages.append(message)
data = b''
return messages
def getMessage(self, filter_pred=None, synchronize=True):
while True:
if not self.inbuffer:
self.inbuffer = self.getMessages(
synchronize=synchronize, assert_get_one=True)
if not self.inbuffer:
raise NoMessageException()
message = self.inbuffer.pop(0) # TODO: use dequeue
if not filter_pred or filter_pred(message):
return message
def sendLine(self, line):
ret = self.conn.sendall(line.encode())
assert ret is None
if not line.endswith('\r\n'):
ret = self.conn.sendall(b'\r\n')
assert ret is None
if self.show_io:
print('{:.3f} {} -> S: {}'.format(time.time(), self.name, line.strip()))
class BaseServerTestCase(_IrcTestCase):
"""Basic class for server tests. Handles spawning a server and exchanging
@ -189,59 +265,33 @@ class BaseServerTestCase(_IrcTestCase):
(self.hostname, self.port) = s.getsockname()
s.close()
def addClient(self, name=None):
def addClient(self, name=None, show_io=None):
"""Connects a client to the server and adds it to the dict.
If 'name' is not given, uses the lowest unused non-negative integer."""
if not name:
name = max(map(int, list(self.clients)+[0]))+1
conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
conn.connect((self.hostname, self.port))
conn_file = conn.makefile(newline='\r\n', encoding='utf8')
self.clients[name] = Client(conn=conn, conn_file=conn_file)
if self.show_io:
print('{:.3f} {}: connects to server.'.format(time.time(), name))
show_io = show_io if show_io is not None else self.show_io
self.clients[name] = Client(name=name, show_io=show_io)
self.clients[name].connect(self.hostname, self.port)
return name
def removeClient(self, name):
"""Disconnects the client, without QUIT."""
assert name in self.clients
if self.show_io:
print('{:.3f} {}: disconnects from server.'.format(time.time(), name))
self.clients[name].conn.close()
self.clients[name].disconnect()
del self.clients[name]
def getLines(self, client):
data = b''
conn = self.clients[client].conn
try:
conn.setblocking(False)
while True:
time.sleep(0.1) # TODO: do better than this (use ping?)
new_data = conn.recv(4096) # May raise BlockingIOError
if not new_data:
raise BlockingIOError
data += new_data
except BlockingIOError:
for line in data.decode().split('\r\n'):
if line and self.show_io:
print('{:.3f} S -> {}: {}'.format(time.time(), client, line.strip()))
yield line + '\r\n'
finally:
conn.setblocking(True) # required for readline()
def getLine(self, client):
assert client in self.clients
line = self.clients[client].conn_file.readline()
if self.show_io:
print('{:.3f} S -> {}: {}'.format(time.time(), client, line.strip()))
return line
def getMessages(self, client, **kwargs):
return self.clients[client].getMessages(**kwargs)
def getMessage(self, client, **kwargs):
return self.clients[client].getMessage(**kwargs)
def getRegistrationMessage(self, client):
"""Filter notices, do not send pings."""
return self.getMessage(client, synchronize=False,
filter_pred=lambda m:m.command != 'NOTICE')
def sendLine(self, client, line):
ret = self.clients[client].conn.sendall(line.encode())
assert ret is None
if not line.endswith('\r\n'):
ret = self.clients[client].conn.sendall(b'\r\n')
assert ret is None
if self.show_io:
print('{:.3f} {} -> S: {}'.format(time.time(), client, line.strip()))
return self.clients[client].sendLine(line)
def getCapLs(self, client, as_list=False):
"""Waits for a CAP LS block, parses all CAP LS messages, and return
@ -251,8 +301,7 @@ class BaseServerTestCase(_IrcTestCase):
in case the order matters (but it shouldn't)."""
caps = []
while True:
m = self.getMessage(client,
filter_pred=lambda m:m.command != 'NOTICE')
m = self.getRegistrationMessage(client)
self.assertMessageEqual(m, command='CAP', subcommand='LS')
if m.params[2] == '*':
caps.extend(m.params[3].split())
@ -277,6 +326,24 @@ class BaseServerTestCase(_IrcTestCase):
return
else:
raise AssertionError('Client not disconnected.')
def connectClient(self, nick, name=None):
name = self.addClient(name)
self.sendLine(1, 'NICK {}'.format(nick))
self.sendLine(1, 'USER username * * :Realname')
# Skip to the point where we are registered
# https://tools.ietf.org/html/rfc2812#section-3.1
while True:
m = self.getMessage(1, synchronize=False)
if m.command == '001':
break
self.sendLine(1, 'PING foo')
# Skip all that happy welcoming stuff
while True:
m = self.getMessage(1)
if m.command == 'PONG':
break
class OptionalityHelper:
def checkMechanismSupport(self, mechanism):

View File

@ -76,6 +76,8 @@ class MammonController(BaseServerController, DirectoryBasedController):
hostname=hostname,
port=port,
))
#with self.open_file('server.yml', 'r') as fd:
# print(fd.read())
self.proc = subprocess.Popen(['mammond', '--nofork', #'--debug',
'--config', os.path.join(self.directory, 'server.yml')])
self.wait_for_port(self.proc, port)
@ -84,17 +86,19 @@ class MammonController(BaseServerController, DirectoryBasedController):
# XXX: Move this somewhere else when
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes
# part of the specification
client = case.addClient()
client = case.addClient(show_io=False)
case.sendLine(client, 'CAP LS 302')
case.sendLine(client, 'NICK registration_user')
case.sendLine(client, 'USER r e g :user')
case.sendLine(client, 'CAP END')
list(case.getLines(client))
while case.getRegistrationMessage(client).command != '001':
pass
list(case.getMessages(client))
case.sendLine(client, 'REG CREATE {} passphrase {}'.format(
username, password))
msg = case.getMessage(client)
assert msg.command == '920'
list(case.getLines(client))
assert msg.command == '920', msg
list(case.getMessages(client))
case.removeClient(client)
def get_irctest_controller_class():

View File

@ -11,7 +11,7 @@ class CapTestCase(cases.BaseServerTestCase):
self.sendLine(1, 'USER foo foo foo :foo')
self.sendLine(1, 'NICK foo')
self.sendLine(1, 'CAP END')
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='001')
def testReqUnavailable(self):
@ -23,11 +23,11 @@ class CapTestCase(cases.BaseServerTestCase):
self.sendLine(1, 'USER foo foo foo :foo')
self.sendLine(1, 'NICK foo')
self.sendLine(1, 'CAP REQ :foo')
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP',
subcommand='NAK', subparams=['foo'])
self.sendLine(1, 'CAP END')
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
m = self.getRegistrationMessage(1)
self.assertEqual(m.command, '001')
def testNakExactString(self):
@ -39,7 +39,7 @@ class CapTestCase(cases.BaseServerTestCase):
# Five should be enough to check there is no reordering, even
# alphabetical
self.sendLine(1, 'CAP REQ :foo bar baz qux quux')
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP',
subcommand='NAK', subparams=['foo bar baz qux quux'])
@ -49,19 +49,19 @@ class CapTestCase(cases.BaseServerTestCase):
self.sendLine(1, 'CAP LS 302')
self.assertIn('multi-prefix', self.getCapLs(1))
self.sendLine(1, 'CAP REQ :foo multi-prefix bar')
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP',
subcommand='NAK', subparams=['foo multi-prefix bar'])
self.sendLine(1, 'CAP REQ :multi-prefix bar')
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP',
subcommand='NAK', subparams=['multi-prefix bar'])
self.sendLine(1, 'CAP REQ :foo multi-prefix')
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP',
subcommand='NAK', subparams=['foo multi-prefix'])
# TODO: make sure multi-prefix is not enabled at this point
self.sendLine(1, 'CAP REQ :multi-prefix')
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP',
subcommand='ACK', subparams=['multi-prefix'])

View File

@ -8,33 +8,15 @@ from irctest import authentication
from irctest.irc_utils.message_parser import Message
class ConnectionRegistrationTestCase(cases.BaseServerTestCase):
def connectClient(self, nick, name=None):
name = self.addClient(name)
self.sendLine(1, 'NICK {}'.format(nick))
self.sendLine(1, 'USER username * * :Realname')
# Skip to the point where we are registered
# https://tools.ietf.org/html/rfc2812#section-3.1
while True:
m = self.getMessage(1)
if m.command == '001':
break
self.sendLine(1, 'PING foo')
# Skip all that happy welcoming stuff
while True:
m = self.getMessage(1)
if m.command == 'PONG':
break
def testPassBeforeNickuser(self):
"""“Currently this requires that user send a PASS command before
sending the NICK/USER combination.”
<https://tools.ietf.org/html/rfc2812#section-3.1.1>"""
self.connectClient('foo')
self.getMessages(1)
self.getMessages(1, synchronize=False)
self.sendLine(1, 'PASS :foo')
m = self.getMessage(1)
m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='462') # ERR_ALREADYREGISTRED
def testQuitDisconnects(self):
@ -44,16 +26,14 @@ class ConnectionRegistrationTestCase(cases.BaseServerTestCase):
self.connectClient('foo')
self.getMessages(1)
self.sendLine(1, 'QUIT')
m = self.getMessage(1)
self.assertMessageEqual(m, command='ERROR')
self.assertDisconnected(1)
self.assertRaises(cases.ConnectionClosed, self.getMessages, 1)
def testNickCollision(self):
self.connectClient('foo')
self.addClient()
self.sendLine(2, 'NICK foo')
self.sendLine(2, 'USER username * * :Realname')
m = self.getMessage(2, filter_pred=lambda m:m.command != 'NOTICE')
m = self.getRegistrationMessage(2)
self.assertNotEqual(m.command, '001')
def testEarlyNickCollision(self):
@ -63,6 +43,6 @@ class ConnectionRegistrationTestCase(cases.BaseServerTestCase):
self.sendLine(2, 'NICK foo')
self.sendLine(1, 'USER username * * :Realname')
self.sendLine(2, 'USER username * * :Realname')
m1 = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE')
m2 = self.getMessage(2, filter_pred=lambda m:m.command != 'NOTICE')
m1 = self.getRegistrationMessage(1)
m2 = self.getRegistrationMessage(2)
self.assertNotEqual((m1.command, m2.command), ('001', '001'))