Add tests for channel names sensitivity.

This commit is contained in:
Valentin Lorentz 2015-12-22 12:14:55 +01:00
parent 63671afcf4
commit 95db1d4dfd
3 changed files with 68 additions and 11 deletions

View File

@ -32,7 +32,7 @@ class _IrcTestCase(unittest.TestCase):
if self.show_io:
print('---- new test ----')
def assertMessageEqual(self, msg, subcommand=None, subparams=None,
target=None, fail_msg=None, **kwargs):
target=None, nick=None, fail_msg=None, **kwargs):
"""Helper for partially comparing a message.
Takes the message as first arguments, and comparisons to be made
@ -43,6 +43,9 @@ class _IrcTestCase(unittest.TestCase):
fail_msg = fail_msg or '{msg}'
for (key, value) in kwargs.items():
self.assertEqual(getattr(msg, key), value, msg, fail_msg)
if nick:
self.assertNotEqual(msg.prefix, None, msg, fail_msg)
self.assertEqual(msg.prefix.split('!')[0], nick, msg, fail_msg)
if subcommand is not None or subparams is not None:
self.assertGreater(len(msg.params), 2, fail_msg)
msg_target = msg.params[0]
@ -55,14 +58,21 @@ class _IrcTestCase(unittest.TestCase):
with self.subTest(key='subparams'):
self.assertEqual(msg_subparams, subparams, msg, fail_msg)
def assertIn(self, got, expects, msg=None, fail_msg=None):
def assertIn(self, item, list_, msg=None, fail_msg=None, extra_format=()):
if fail_msg:
fail_msg = fail_msg.format(got=got, expects=expects, msg=msg)
super().assertIn(got, expects, fail_msg)
def assertEqual(self, got, expects, msg=None, fail_msg=None):
fail_msg = fail_msg.format(*extra_format,
item=item, list=list_, msg=msg)
super().assertIn(item, list_, fail_msg)
def assertEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()):
if fail_msg:
fail_msg = fail_msg.format(got=got, expects=expects, msg=msg)
fail_msg = fail_msg.format(*extra_format,
got=got, expects=expects, msg=msg)
super().assertEqual(got, expects, fail_msg)
def assertNotEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()):
if fail_msg:
fail_msg = fail_msg.format(*extra_format,
got=got, expects=expects, msg=msg)
super().assertNotEqual(got, expects, fail_msg)
class BaseClientTestCase(_IrcTestCase):
"""Basic class for client tests. Handles spawning a client and exchanging
@ -289,6 +299,14 @@ class BaseServerTestCase(_IrcTestCase):
if m.command == 'PONG':
break
def joinClient(self, client, channel):
self.sendLine(client, 'JOIN {}'.format(channel))
received = {m.command for m in self.getMessages(client)}
self.assertIn('366', received,
fail_msg='Join to {} failed, {item} is not in the set of '
'received responses: {list}',
extra_format=(channel,))
class OptionalityHelper:
def checkSaslSupport(self):
if self.controller.supported_sasl_mechanisms:

View File

@ -48,7 +48,7 @@ def parse_message(s):
else:
tags = []
if tokens[0].startswith(':'):
prefix = tokens.pop(0)
prefix = tokens.pop(0)[1:]
else:
prefix = None
command = tokens.pop(0)

View File

@ -53,7 +53,7 @@ class JoinTestCase(cases.BaseServerTestCase):
'<3 or >4: {msg}')
params = ambiguities.normalize_namreply_params(m.params)
self.assertIn(params[1], '=*@', m,
fail_msg='Bad channel prefix: {got} not in {expects}: {msg}')
fail_msg='Bad channel prefix: {item} not in {list}: {msg}')
self.assertEqual(params[2], '#chan', m,
fail_msg='Bad channel name: {got} instead of '
'{expects}: {msg}')
@ -124,7 +124,7 @@ class JoinTestCase(cases.BaseServerTestCase):
'<3 or >4: {msg}')
params = ambiguities.normalize_namreply_params(m.params)
self.assertIn(params[1], '=*@', m,
fail_msg='Bad channel prefix: {got} not in {expects}: {msg}')
fail_msg='Bad channel prefix: {item} not in {list}: {msg}')
self.assertEqual(params[2], '#chan', m,
fail_msg='Bad channel name: {got} instead of '
'{expects}: {msg}')
@ -155,7 +155,6 @@ class JoinTestCase(cases.BaseServerTestCase):
try:
m = self.getMessage(1)
if m.command == '482':
print(m)
raise optionality.ImplementationChoice(
'Channel creators are not opped by default, and '
'channel modes to no allow regular users to change '
@ -188,7 +187,6 @@ class JoinTestCase(cases.BaseServerTestCase):
try:
m = self.getMessage(1)
if m.command == '482':
print(m)
raise optionality.ImplementationChoice(
'Channel creators are not opped by default.')
self.assertMessageEqual(m, command='TOPIC')
@ -346,3 +344,44 @@ class JoinTestCase(cases.BaseServerTestCase):
m = self.getMessage(4)
self.assertMessageEqual(m, command='KICK',
params=['#chan', 'baz', 'bye'])
class testChannelCaseSensitivity(cases.BaseServerTestCase):
def _testChannelsEquivalent(name1, name2):
def f(self):
self.connectClient('foo')
self.connectClient('bar')
self.joinClient(1, name1)
self.joinClient(2, name2)
try:
m = self.getMessage(1)
self.assertMessageEqual(m, command='JOIN',
nick='bar')
except client_mock.NoMessageException:
raise AssertionError(
'Channel names {} and {} are not equivalent.'
.format(name1, name2))
f.__name__ = 'testEquivalence__{}__{}'.format(name1, name2)
return f
def _testChannelsNotEquivalent(name1, name2):
def f(self):
self.connectClient('foo')
self.connectClient('bar')
self.joinClient(1, name1)
self.joinClient(2, name2)
try:
m = self.getMessage(1)
except client_mock.NoMessageException:
pass
else:
self.assertMessageEqual(m, command='JOIN',
nick='bar') # This should always be true
raise AssertionError(
'Channel names {} and {} are equivalent.'
.format(name1, name2))
f.__name__ = 'testEquivalence__{}__{}'.format(name1, name2)
return f
testSimpleEquivalent = _testChannelsEquivalent('#Foo', '#foo')
testSimpleNotEquivalent = _testChannelsNotEquivalent('#Foo', '#fooa')
testFancyEquivalent = _testChannelsEquivalent('#F]|oo{', '#f}\\oo[')
testFancyNotEquivalent = _testChannelsEquivalent('#F}o\\o[', '#f]o|o{')