From 8016e01daf8a04df76faf6eee0aa493b8d955151 Mon Sep 17 00:00:00 2001 From: Valentin Lorentz Date: Mon, 22 Feb 2021 19:02:13 +0100 Subject: [PATCH] Use Black code style --- .pre-commit-config.yaml | 6 + CONTRIBUTING.md | 16 + conftest.py | 21 +- irctest/authentication.py | 18 +- irctest/basecontrollers.py | 79 +- irctest/cases.py | 316 +++-- irctest/client_mock.py | 79 +- irctest/client_tests/test_cap.py | 7 +- irctest/client_tests/test_sasl.py | 281 ++-- irctest/client_tests/test_tls.py | 96 +- irctest/controllers/charybdis.py | 64 +- irctest/controllers/girc.py | 18 +- irctest/controllers/hybrid.py | 61 +- irctest/controllers/inspircd.py | 64 +- irctest/controllers/limnoria.py | 59 +- irctest/controllers/mammon.py | 76 +- irctest/controllers/oragono.py | 300 ++-- irctest/controllers/sopel.py | 43 +- irctest/exceptions.py | 2 +- irctest/irc_utils/ambiguities.py | 3 +- irctest/irc_utils/capabilities.py | 4 +- irctest/irc_utils/junkdrawer.py | 41 +- irctest/irc_utils/message_parser.py | 58 +- irctest/irc_utils/sasl.py | 15 +- irctest/numerics.py | 374 ++--- irctest/runner.py | 40 +- irctest/server_tests/test_account_tag.py | 80 +- irctest/server_tests/test_away_notify.py | 36 +- irctest/server_tests/test_bouncer.py | 164 ++- irctest/server_tests/test_cap.py | 159 ++- irctest/server_tests/test_channel_forward.py | 56 +- .../server_tests/test_channel_operations.py | 1209 ++++++++++------- irctest/server_tests/test_channel_rename.py | 66 +- irctest/server_tests/test_chathistory.py | 494 +++++-- irctest/server_tests/test_confusables.py | 21 +- .../test_connection_registration.py | 148 +- irctest/server_tests/test_echo_message.py | 148 +- irctest/server_tests/test_extended_join.py | 84 +- .../server_tests/test_labeled_responses.py | 592 ++++++-- irctest/server_tests/test_message_tags.py | 181 +-- irctest/server_tests/test_messages.py | 68 +- irctest/server_tests/test_metadata.py | 272 ++-- irctest/server_tests/test_monitor.py | 346 +++-- irctest/server_tests/test_multi_prefix.py | 39 +- irctest/server_tests/test_multiline.py | 133 +- irctest/server_tests/test_register_verify.py | 121 +- irctest/server_tests/test_regressions.py | 132 +- irctest/server_tests/test_relaymsg.py | 147 +- irctest/server_tests/test_resume.py | 230 ++-- irctest/server_tests/test_roleplay.py | 71 +- irctest/server_tests/test_sasl.py | 266 ++-- irctest/server_tests/test_statusmsg.py | 36 +- irctest/server_tests/test_user_commands.py | 190 +-- irctest/server_tests/test_utf8.py | 30 +- irctest/server_tests/test_znc_playback.py | 198 ++- irctest/specifications.py | 19 +- irctest/tls.py | 3 +- pyproject.toml | 2 + setup.cfg | 6 + 59 files changed, 4855 insertions(+), 3033 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 CONTRIBUTING.md create mode 100644 pyproject.toml create mode 100644 setup.cfg diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a4e6356 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: + - repo: https://github.com/psf/black + rev: 20.8b1 + hooks: + - id: black + language_version: python3.7 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..07c1419 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,16 @@ +# Contributing + +## Code style + +Any color you like as long as it's [Black](https://github.com/psf/black). +In short: + +* 88 columns +* double quotes +* avoid backslashes at line breaks (use parentheses) +* closing brackets/parentheses/... go on the same indent level as the line + that opened them + +You can use [pre-commit](https://pre-commit.com/) to automatically run it +for you when you create a git commit. +Alternatively, run `pre-commit run -a` diff --git a/conftest.py b/conftest.py index d1d9b2a..8aea757 100644 --- a/conftest.py +++ b/conftest.py @@ -8,11 +8,15 @@ import _pytest.unittest from irctest.cases import _IrcTestCase, BaseClientTestCase, BaseServerTestCase from irctest.basecontrollers import BaseClientController, BaseServerController + def pytest_addoption(parser): """Called by pytest, registers CLI options passed to the pytest command.""" - parser.addoption("--controller", help="Which module to use to run the tested software.") - parser.addoption('--openssl-bin', type=str, default='openssl', - help='The openssl binary to use') + parser.addoption( + "--controller", help="Which module to use to run the tested software." + ) + parser.addoption( + "--openssl-bin", type=str, default="openssl", help="The openssl binary to use" + ) def pytest_configure(config): @@ -25,7 +29,7 @@ def pytest_configure(config): try: module = importlib.import_module(module_name) except ImportError: - pytest.exit('Cannot import module {}'.format(module_name), 1) + pytest.exit("Cannot import module {}".format(module_name), 1) controller_class = module.get_irctest_controller_class() if issubclass(controller_class, BaseClientController): @@ -34,10 +38,11 @@ def pytest_configure(config): from irctest import server_tests as module else: pytest.exit( - r'{}.Controller should be a subclass of ' - r'irctest.basecontroller.Base{{Client,Server}}Controller' - .format(module_name), - 1 + r"{}.Controller should be a subclass of " + r"irctest.basecontroller.Base{{Client,Server}}Controller".format( + module_name + ), + 1, ) _IrcTestCase.controllerClass = controller_class _IrcTestCase.controllerClass.openssl_bin = config.getoption("openssl_bin") diff --git a/irctest/authentication.py b/irctest/authentication.py index 1ae39ea..e6c2dde 100644 --- a/irctest/authentication.py +++ b/irctest/authentication.py @@ -1,19 +1,25 @@ import enum import collections + @enum.unique class Mechanisms(enum.Enum): """Enumeration for representing possible mechanisms.""" + @classmethod def as_string(cls, mech): - return {cls.plain: 'PLAIN', - cls.ecdsa_nist256p_challenge: 'ECDSA-NIST256P-CHALLENGE', - cls.scram_sha_256: 'SCRAM-SHA-256', - }[mech] + return { + cls.plain: "PLAIN", + cls.ecdsa_nist256p_challenge: "ECDSA-NIST256P-CHALLENGE", + cls.scram_sha_256: "SCRAM-SHA-256", + }[mech] + plain = 1 ecdsa_nist256p_challenge = 2 scram_sha_256 = 3 -Authentication = collections.namedtuple('Authentication', - 'mechanisms username password ecdsa_key') + +Authentication = collections.namedtuple( + "Authentication", "mechanisms username password ecdsa_key" +) Authentication.__new__.__defaults__ = ([Mechanisms.plain], None, None, None) diff --git a/irctest/basecontrollers.py b/irctest/basecontrollers.py index 691e655..211c9db 100644 --- a/irctest/basecontrollers.py +++ b/irctest/basecontrollers.py @@ -7,18 +7,22 @@ import subprocess from .runner import NotImplementedByController + class _BaseController: """Base class for software controllers. A software controller is an object that handles configuring and running a process (eg. a server or a client), as well as sending it instructions that are not part of the IRC specification.""" + def __init__(self, test_config): self.test_config = test_config + class DirectoryBasedController(_BaseController): """Helper for controllers whose software configuration is based on an arbitrary directory.""" + def __init__(self, test_config): super().__init__(test_config) self.directory = None @@ -33,18 +37,21 @@ class DirectoryBasedController(_BaseController): except subprocess.TimeoutExpired: self.proc.kill() self.proc = None + def kill(self): """Calls `kill_proc` and cleans the configuration.""" if self.proc: self.kill_proc() if self.directory: shutil.rmtree(self.directory) + def terminate(self): """Stops the process gracefully, and does not clean its config.""" self.proc.terminate() self.proc.wait() self.proc = None - def open_file(self, name, mode='a'): + + def open_file(self, name, mode="a"): """Open a file in the configuration directory.""" assert self.directory if os.sep in name: @@ -53,6 +60,7 @@ class DirectoryBasedController(_BaseController): os.makedirs(dir_) assert os.path.isdir(dir_) return open(os.path.join(self.directory, name), mode) + def create_config(self): """If there is no config dir, creates it and returns True. Else returns False.""" @@ -63,41 +71,70 @@ class DirectoryBasedController(_BaseController): return True def gen_ssl(self): - self.csr_path = os.path.join(self.directory, 'ssl.csr') - self.key_path = os.path.join(self.directory, 'ssl.key') - self.pem_path = os.path.join(self.directory, 'ssl.pem') - self.dh_path = os.path.join(self.directory, 'dh.pem') - subprocess.check_output([self.openssl_bin, 'req', '-new', '-newkey', 'rsa', - '-nodes', '-out', self.csr_path, '-keyout', self.key_path, - '-batch'], - stderr=subprocess.DEVNULL) - subprocess.check_output([self.openssl_bin, 'x509', '-req', - '-in', self.csr_path, '-signkey', self.key_path, - '-out', self.pem_path], - stderr=subprocess.DEVNULL) - subprocess.check_output([self.openssl_bin, 'dhparam', - '-out', self.dh_path, '128'], - stderr=subprocess.DEVNULL) + self.csr_path = os.path.join(self.directory, "ssl.csr") + self.key_path = os.path.join(self.directory, "ssl.key") + self.pem_path = os.path.join(self.directory, "ssl.pem") + self.dh_path = os.path.join(self.directory, "dh.pem") + subprocess.check_output( + [ + self.openssl_bin, + "req", + "-new", + "-newkey", + "rsa", + "-nodes", + "-out", + self.csr_path, + "-keyout", + self.key_path, + "-batch", + ], + stderr=subprocess.DEVNULL, + ) + subprocess.check_output( + [ + self.openssl_bin, + "x509", + "-req", + "-in", + self.csr_path, + "-signkey", + self.key_path, + "-out", + self.pem_path, + ], + stderr=subprocess.DEVNULL, + ) + subprocess.check_output( + [self.openssl_bin, "dhparam", "-out", self.dh_path, "128"], + stderr=subprocess.DEVNULL, + ) + class BaseClientController(_BaseController): """Base controller for IRC clients.""" + def run(self, hostname, port, auth): raise NotImplementedError() + class BaseServerController(_BaseController): """Base controller for IRC server.""" - _port_wait_interval = .1 + + _port_wait_interval = 0.1 port_open = False - def run(self, hostname, port, password, - valid_metadata_keys, invalid_metadata_keys): + + def run(self, hostname, port, password, valid_metadata_keys, invalid_metadata_keys): raise NotImplementedError() + def registerUser(self, case, username, password=None): - raise NotImplementedByController('account registration') + raise NotImplementedByController("account registration") + def wait_for_port(self): while not self.port_open: time.sleep(self._port_wait_interval) try: - c = socket.create_connection(('localhost', self.port), timeout=1.0) + c = socket.create_connection(("localhost", self.port), timeout=1.0) c.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) # Make sure the server properly processes the disconnect. diff --git a/irctest/cases.py b/irctest/cases.py index ae03967..782afad 100644 --- a/irctest/cases.py +++ b/irctest/cases.py @@ -15,19 +15,38 @@ from .irc_utils.junkdrawer import normalizeWhitespace, random_name from .irc_utils.sasl import sasl_plain_blob from .exceptions import ConnectionClosed from .specifications import Specifications -from .numerics import ERR_NOSUCHCHANNEL, ERR_TOOMANYCHANNELS, ERR_BADCHANNELKEY, ERR_INVITEONLYCHAN, ERR_BANNEDFROMCHAN, ERR_NEEDREGGEDNICK +from .numerics import ( + ERR_NOSUCHCHANNEL, + ERR_TOOMANYCHANNELS, + ERR_BADCHANNELKEY, + ERR_INVITEONLYCHAN, + ERR_BANNEDFROMCHAN, + ERR_NEEDREGGEDNICK, +) + +CHANNEL_JOIN_FAIL_NUMERICS = frozenset( + [ + ERR_NOSUCHCHANNEL, + ERR_TOOMANYCHANNELS, + ERR_BADCHANNELKEY, + ERR_INVITEONLYCHAN, + ERR_BANNEDFROMCHAN, + ERR_NEEDREGGEDNICK, + ] +) -CHANNEL_JOIN_FAIL_NUMERICS = frozenset([ERR_NOSUCHCHANNEL, ERR_TOOMANYCHANNELS, ERR_BADCHANNELKEY, ERR_INVITEONLYCHAN, ERR_BANNEDFROMCHAN, ERR_NEEDREGGEDNICK]) class ChannelJoinException(Exception): def __init__(self, code, params): - super().__init__(f'Failed to join channel ({code}): {params}') + super().__init__(f"Failed to join channel ({code}): {params}") self.code = code self.params = params + class _IrcTestCase(unittest.TestCase): """Base class for test cases.""" - controllerClass = None # Will be set by __main__.py + + controllerClass = None # Will be set by __main__.py @staticmethod def config(): @@ -40,21 +59,36 @@ class _IrcTestCase(unittest.TestCase): def description(self): method_doc = self._testMethodDoc if not method_doc: - return '' - return '\t'+normalizeWhitespace( + return "" + return ( + "\t" + + normalizeWhitespace( method_doc, removeNewline=False, - ).strip().replace('\n ', '\n\t') + ) + .strip() + .replace("\n ", "\n\t") + ) def setUp(self): super().setUp() self.controller = self.controllerClass(self.config()) self.inbuffer = [] if self.show_io: - print('---- new test ----') - def assertMessageEqual(self, msg, subcommand=None, subparams=None, - target=None, nick=None, fail_msg=None, extra_format=(), - strip_first_param=False, **kwargs): + print("---- new test ----") + + def assertMessageEqual( + self, + msg, + subcommand=None, + subparams=None, + target=None, + nick=None, + fail_msg=None, + extra_format=(), + strip_first_param=False, + **kwargs, + ): """Helper for partially comparing a message. Takes the message as first arguments, and comparisons to be made @@ -62,65 +96,71 @@ class _IrcTestCase(unittest.TestCase): Deals with subcommands (eg. `CAP`) if any of `subcommand`, `subparams`, and `target` are given.""" - fail_msg = fail_msg or '{msg}' + fail_msg = fail_msg or "{msg}" for (key, value) in kwargs.items(): - if strip_first_param and key == 'params': + if strip_first_param and key == "params": value = value[1:] - self.assertEqual(getattr(msg, key), value, msg, fail_msg, - extra_format=extra_format) + self.assertEqual( + getattr(msg, key), value, msg, fail_msg, extra_format=extra_format + ) if nick: self.assertNotEqual(msg.prefix, None, msg, fail_msg) - self.assertEqual(msg.prefix.split('!')[0], nick, msg, fail_msg) + self.assertEqual(msg.prefix.split("!")[0], nick, msg, fail_msg) if subcommand is not None or subparams is not None: self.assertGreater(len(msg.params), 2, fail_msg) - #msg_target = msg.params[0] + # msg_target = msg.params[0] msg_subcommand = msg.params[1] msg_subparams = msg.params[2:] if subcommand: - self.assertEqual(msg_subcommand, subcommand, msg, fail_msg, - extra_format=extra_format) + self.assertEqual( + msg_subcommand, subcommand, msg, fail_msg, extra_format=extra_format + ) if subparams is not None: - self.assertEqual(msg_subparams, subparams, msg, fail_msg, - extra_format=extra_format) + self.assertEqual( + msg_subparams, subparams, msg, fail_msg, extra_format=extra_format + ) def assertIn(self, item, list_, msg=None, fail_msg=None, extra_format=()): if fail_msg: - fail_msg = fail_msg.format(*extra_format, - item=item, list=list_, msg=msg) + fail_msg = fail_msg.format(*extra_format, item=item, list=list_, msg=msg) super().assertIn(item, list_, fail_msg) + def assertNotIn(self, item, list_, msg=None, fail_msg=None, extra_format=()): if fail_msg: - fail_msg = fail_msg.format(*extra_format, - item=item, list=list_, msg=msg) + fail_msg = fail_msg.format(*extra_format, item=item, list=list_, msg=msg) super().assertNotIn(item, list_, fail_msg) + def assertEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()): if fail_msg: - fail_msg = fail_msg.format(*extra_format, - got=got, expects=expects, msg=msg) + fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) super().assertEqual(got, expects, fail_msg) + def assertNotEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()): if fail_msg: - fail_msg = fail_msg.format(*extra_format, - got=got, expects=expects, msg=msg) + fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) super().assertNotEqual(got, expects, fail_msg) + class BaseClientTestCase(_IrcTestCase): """Basic class for client tests. Handles spawning a client and exchanging messages with it.""" + nick = None user = None + def setUp(self): super().setUp() self.conn = None self._setUpServer() + def tearDown(self): if self.conn: try: - self.conn.sendall(b'QUIT :end of test.') + self.conn.sendall(b"QUIT :end of test.") except BrokenPipeError: - pass # client already disconnected + pass # client already disconnected except OSError: - pass # the conn was already closed by the test, or something + pass # the conn was already closed by the test, or something self.controller.kill() if self.conn: self.conn_file.close() @@ -130,8 +170,9 @@ class BaseClientTestCase(_IrcTestCase): def _setUpServer(self): """Creates the server and make it listen.""" self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.server.bind(('', 0)) # Bind any free port + self.server.bind(("", 0)) # Bind any free port self.server.listen(1) + def acceptClient(self, tls_cert=None, tls_key=None, server=None): """Make the server accept a client connection. Blocking.""" server = server or self.server @@ -139,10 +180,12 @@ class BaseClientTestCase(_IrcTestCase): if tls_cert is None and tls_key is None: pass else: - assert tls_cert and tls_key, \ - 'tls_cert must be provided if and only if tls_key is.' - with tempfile.NamedTemporaryFile('at') as certfile, \ - tempfile.NamedTemporaryFile('at') as keyfile: + assert ( + tls_cert and tls_key + ), "tls_cert must be provided if and only if tls_key is." + with tempfile.NamedTemporaryFile( + "at" + ) as certfile, tempfile.NamedTemporaryFile("at") as keyfile: certfile.write(tls_cert) certfile.seek(0) keyfile.write(tls_key) @@ -150,17 +193,18 @@ class BaseClientTestCase(_IrcTestCase): context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context.load_cert_chain(certfile=certfile.name, keyfile=keyfile.name) self.conn = context.wrap_socket(self.conn, server_side=True) - self.conn_file = self.conn.makefile(newline='\r\n', - encoding='utf8') + self.conn_file = self.conn.makefile(newline="\r\n", encoding="utf8") def getLine(self): line = self.conn_file.readline() if self.show_io: - print('{:.3f} C: {}'.format(time.time(), line.strip())) + print("{:.3f} C: {}".format(time.time(), line.strip())) return line + def getMessages(self, *args): lines = self.getLines(*args) return map(message_parser.parse_message, lines) + def getMessage(self, *args, filter_pred=None): """Gets a message and returns it. If a filter predicate is given, fetches messages until the predicate returns a False on a message, @@ -172,45 +216,46 @@ class BaseClientTestCase(_IrcTestCase): msg = message_parser.parse_message(line) if not filter_pred or filter_pred(msg): return msg + def sendLine(self, line): self.conn.sendall(line.encode()) - if not line.endswith('\r\n'): - self.conn.sendall(b'\r\n') + if not line.endswith("\r\n"): + self.conn.sendall(b"\r\n") if self.show_io: - print('{:.3f} S: {}'.format(time.time(), line.strip())) + print("{:.3f} S: {}".format(time.time(), line.strip())) + class ClientNegociationHelper: """Helper class for tests handling capabilities negociation.""" + def readCapLs(self, auth=None, tls_config=None): (hostname, port) = self.server.getsockname() self.controller.run( - hostname=hostname, - port=port, - auth=auth, - tls_config=tls_config, - ) + hostname=hostname, + port=port, + auth=auth, + tls_config=tls_config, + ) self.acceptClient() m = self.getMessage() - self.assertEqual(m.command, 'CAP', - 'First message is not CAP LS.') - if m.params == ['LS']: + self.assertEqual(m.command, "CAP", "First message is not CAP LS.") + if m.params == ["LS"]: self.protocol_version = 301 - elif m.params == ['LS', '302']: + elif m.params == ["LS", "302"]: self.protocol_version = 302 - elif m.params == ['END']: + elif m.params == ["END"]: self.protocol_version = None else: - raise AssertionError('Unknown CAP params: {}' - .format(m.params)) + raise AssertionError("Unknown CAP params: {}".format(m.params)) def userNickPredicate(self, msg): """Predicate to be used with getMessage to handle NICK/USER transparently.""" - if msg.command == 'NICK': + if msg.command == "NICK": self.assertEqual(len(msg.params), 1, msg) self.nick = msg.params[0] return False - elif msg.command == 'USER': + elif msg.command == "USER": self.assertEqual(len(msg.params), 4, msg) self.user = msg.params return False @@ -225,25 +270,25 @@ class ClientNegociationHelper: if not self.protocol_version: # No negotiation. return - self.sendLine('CAP * LS :{}'.format(' '.join(caps))) + self.sendLine("CAP * LS :{}".format(" ".join(caps))) capability_names = frozenset(capabilities.cap_list_to_dict(caps)) self.acked_capabilities = set() while True: m = self.getMessage(filter_pred=self.userNickPredicate) - if m.command != 'CAP': + if m.command != "CAP": return m self.assertGreater(len(m.params), 0, m) - if m.params[0] == 'REQ': + if m.params[0] == "REQ": self.assertEqual(len(m.params), 2, m) requested = frozenset(m.params[1].split()) if not requested.issubset(capability_names): - self.sendLine('CAP {} NAK :{}'.format( - self.nick or '*', - m.params[1][0:100])) + self.sendLine( + "CAP {} NAK :{}".format(self.nick or "*", m.params[1][0:100]) + ) else: - self.sendLine('CAP {} ACK :{}'.format( - self.nick or '*', - m.params[1])) + self.sendLine( + "CAP {} ACK :{}".format(self.nick or "*", m.params[1]) + ) self.acked_capabilities.update(requested) else: return m @@ -252,27 +297,35 @@ class ClientNegociationHelper: class BaseServerTestCase(_IrcTestCase): """Basic class for server tests. Handles spawning a server and exchanging messages with it.""" + password = None ssl = False valid_metadata_keys = frozenset() invalid_metadata_keys = frozenset() + def setUp(self): super().setUp() self.server_support = {} self.find_hostname_and_port() - self.controller.run(self.hostname, self.port, password=self.password, - valid_metadata_keys=self.valid_metadata_keys, - invalid_metadata_keys=self.invalid_metadata_keys, - ssl=self.ssl) + self.controller.run( + self.hostname, + self.port, + password=self.password, + valid_metadata_keys=self.valid_metadata_keys, + invalid_metadata_keys=self.invalid_metadata_keys, + ssl=self.ssl, + ) self.clients = {} + def tearDown(self): self.controller.kill() for client in list(self.clients): self.removeClient(client) + def find_hostname_and_port(self): """Find available hostname/port to listen on.""" s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("",0)) + s.bind(("", 0)) (self.hostname, self.port) = s.getsockname() s.close() @@ -281,14 +334,12 @@ class BaseServerTestCase(_IrcTestCase): If 'name' is not given, uses the lowest unused non-negative integer.""" self.controller.wait_for_port() if not name: - name = max(map(int, list(self.clients)+[0]))+1 + name = max(map(int, list(self.clients) + [0])) + 1 show_io = show_io if show_io is not None else self.show_io - self.clients[name] = client_mock.ClientMock(name=name, - show_io=show_io) + self.clients[name] = client_mock.ClientMock(name=name, show_io=show_io) self.clients[name].connect(self.hostname, self.port) return name - def removeClient(self, name): """Disconnects the client, without QUIT.""" assert name in self.clients @@ -297,12 +348,16 @@ class BaseServerTestCase(_IrcTestCase): def getMessages(self, client, **kwargs): return self.clients[client].getMessages(**kwargs) + def getMessage(self, client, **kwargs): return self.clients[client].getMessage(**kwargs) + def getRegistrationMessage(self, client): """Filter notices, do not send pings.""" - return self.getMessage(client, synchronize=False, - filter_pred=lambda m:m.command != 'NOTICE') + return self.getMessage( + client, synchronize=False, filter_pred=lambda m: m.command != "NOTICE" + ) + def sendLine(self, client, line): return self.clients[client].sendLine(line) @@ -315,8 +370,8 @@ class BaseServerTestCase(_IrcTestCase): caps = [] while True: m = self.getRegistrationMessage(client) - self.assertMessageEqual(m, command='CAP', subcommand='LS') - if m.params[2] == '*': + self.assertMessageEqual(m, command="CAP", subcommand="LS") + if m.params[2] == "*": caps.extend(m.params[3].split()) else: caps.extend(m.params[2].split()) @@ -332,8 +387,7 @@ class BaseServerTestCase(_IrcTestCase): del self.clients[client] return else: - raise AssertionError('Client not disconnected.') - + raise AssertionError("Client not disconnected.") def skipToWelcome(self, client): """Skip to the point where we are registered @@ -343,45 +397,54 @@ class BaseServerTestCase(_IrcTestCase): while True: m = self.getMessage(client, synchronize=False) result.append(m) - if m.command == '001': + if m.command == "001": return result - def connectClient(self, nick, name=None, capabilities=None, - skip_if_cap_nak=False, show_io=None, password=None, ident='username'): + def connectClient( + self, + nick, + name=None, + capabilities=None, + skip_if_cap_nak=False, + show_io=None, + password=None, + ident="username", + ): client = self.addClient(name, show_io=show_io) if capabilities is not None and 0 < len(capabilities): - self.sendLine(client, 'CAP REQ :{}'.format(' '.join(capabilities))) + self.sendLine(client, "CAP REQ :{}".format(" ".join(capabilities))) m = self.getRegistrationMessage(client) try: - self.assertMessageEqual(m, command='CAP', - fail_msg='Expected CAP ACK, got: {msg}') - self.assertEqual(m.params[1], 'ACK', m, - fail_msg='Expected CAP ACK, got: {msg}') + self.assertMessageEqual( + m, command="CAP", fail_msg="Expected CAP ACK, got: {msg}" + ) + self.assertEqual( + m.params[1], "ACK", m, fail_msg="Expected CAP ACK, got: {msg}" + ) except AssertionError: if skip_if_cap_nak: - raise runner.NotImplementedByController( - ', '.join(capabilities)) + raise runner.NotImplementedByController(", ".join(capabilities)) else: raise - self.sendLine(client, 'CAP END') + self.sendLine(client, "CAP END") if password is not None: - self.sendLine(client, 'AUTHENTICATE PLAIN') + self.sendLine(client, "AUTHENTICATE PLAIN") self.sendLine(client, sasl_plain_blob(nick, password)) - self.sendLine(client, 'NICK {}'.format(nick)) - self.sendLine(client, 'USER %s * * :Realname' % (ident,)) + self.sendLine(client, "NICK {}".format(nick)) + self.sendLine(client, "USER %s * * :Realname" % (ident,)) welcome = self.skipToWelcome(client) - self.sendLine(client, 'PING foo') + self.sendLine(client, "PING foo") # Skip all that happy welcoming stuff while True: m = self.getMessage(client) - if m.command == 'PONG': + if m.command == "PONG": break - elif m.command == '005': + elif m.command == "005": for param in m.params[1:-1]: - if '=' in param: - (key, value) = param.split('=') + if "=" in param: + (key, value) = param.split("=") else: (key, value) = (param, None) self.server_support[key] = value @@ -390,49 +453,57 @@ class BaseServerTestCase(_IrcTestCase): return welcome def joinClient(self, client, channel): - self.sendLine(client, 'JOIN {}'.format(channel)) + self.sendLine(client, "JOIN {}".format(channel)) received = {m.command for m in self.getMessages(client)} - self.assertIn('366', received, - fail_msg='Join to {} failed, {item} is not in the set of ' - 'received responses: {list}', - extra_format=(channel,)) + self.assertIn( + "366", + received, + fail_msg="Join to {} failed, {item} is not in the set of " + "received responses: {list}", + extra_format=(channel,), + ) def joinChannel(self, client, channel): - self.sendLine(client, 'JOIN {}'.format(channel)) + self.sendLine(client, "JOIN {}".format(channel)) # wait until we see them join the channel joined = False while not joined: for msg in self.getMessages(client): - if msg.command == 'JOIN' and 0 < len(msg.params) and msg.params[0].lower() == channel.lower(): + if ( + msg.command == "JOIN" + and 0 < len(msg.params) + and msg.params[0].lower() == channel.lower() + ): joined = True break elif msg.command in CHANNEL_JOIN_FAIL_NUMERICS: raise ChannelJoinException(msg.command, msg.params) def getISupport(self): - cn = random_name('bar') + cn = random_name("bar") self.addClient(name=cn) - self.sendLine(cn, 'NICK %s' % (cn,)) - self.sendLine(cn, 'USER u s e r') + self.sendLine(cn, "NICK %s" % (cn,)) + self.sendLine(cn, "USER u s e r") messages = self.getMessages(cn) isupport = {} for message in messages: - if message.command != '005': + if message.command != "005": continue # 005 nick :are supported by this server tokens = message.params[1:-1] for token in tokens: - name, _, value = token.partition('=') + name, _, value = token.partition("=") isupport[name] = value - self.sendLine(cn, 'QUIT') + self.sendLine(cn, "QUIT") self.assertDisconnected(cn) return isupport + class OptionalityHelper: def checkSaslSupport(self): if self.controller.supported_sasl_mechanisms: return - raise runner.NotImplementedByController('SASL') + raise runner.NotImplementedByController("SASL") def checkMechanismSupport(self, mechanism): if mechanism in self.controller.supported_sasl_mechanisms: @@ -445,7 +516,9 @@ class OptionalityHelper: def newf(self): self.checkMechanismSupport(mech) return f(self) + return newf + return decorator def skipUnlessHasSasl(f): @@ -453,6 +526,7 @@ class OptionalityHelper: def newf(self): self.checkSaslSupport() return f(self) + return newf def checkCapabilitySupport(self, cap): @@ -466,22 +540,26 @@ class OptionalityHelper: def newf(self): self.checkCapabilitySupport(cap) return f(self) + return newf + return decorator -class SpecificationSelector: +class SpecificationSelector: def requiredBySpecification(*specifications, strict=False): specifications = frozenset( - Specifications.of_name(s) if isinstance(s, str) else s - for s in specifications) + Specifications.of_name(s) if isinstance(s, str) else s + for s in specifications + ) if None in specifications: - raise ValueError('Invalid set of specifications: {}' - .format(specifications)) + raise ValueError("Invalid set of specifications: {}".format(specifications)) + def decorator(f): for specification in specifications: f = getattr(pytest.mark, specification.value)(f) if strict: f = pytest.mark.strict(f) return f + return decorator diff --git a/irctest/client_mock.py b/irctest/client_mock.py index d85b1bc..1c679c4 100644 --- a/irctest/client_mock.py +++ b/irctest/client_mock.py @@ -5,32 +5,37 @@ import socket from .irc_utils import message_parser from .exceptions import NoMessageException, ConnectionClosed + class ClientMock: def __init__(self, name, show_io): self.name = name self.show_io = show_io self.inbuffer = [] self.ssl = False + def connect(self, hostname, port): self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.conn.settimeout(1) # TODO: configurable + self.conn.settimeout(1) # TODO: configurable self.conn.connect((hostname, port)) if self.show_io: - print('{:.3f} {}: connects to server.'.format(time.time(), self.name)) + print("{:.3f} {}: connects to server.".format(time.time(), self.name)) + def disconnect(self): if self.show_io: - print('{:.3f} {}: disconnects from server.'.format(time.time(), self.name)) + print("{:.3f} {}: disconnects from server.".format(time.time(), self.name)) self.conn.close() + def starttls(self): - assert not self.ssl, 'SSL already active.' + assert not self.ssl, "SSL already active." self.conn = ssl.wrap_socket(self.conn) self.ssl = True + def getMessages(self, synchronize=True, assert_get_one=False, raw=False): if synchronize: - token = 'synchronize{}'.format(time.monotonic()) - self.sendLine('PING {}'.format(token)) + token = "synchronize{}".format(time.monotonic()) + self.sendLine("PING {}".format(token)) got_pong = False - data = b'' + data = b"" (self.inbuffer, messages) = ([], self.inbuffer) conn = self.conn try: @@ -38,11 +43,11 @@ class ClientMock: try: new_data = conn.recv(4096) except socket.timeout: - if not assert_get_one and not synchronize and data == b'': + if not assert_get_one and not synchronize and data == b"": # Received nothing return [] if self.show_io: - print('{:.3f} {}: waiting…'.format(time.time(), self.name)) + print("{:.3f} {}: waiting…".format(time.time(), self.name)) time.sleep(0.1) continue except ConnectionResetError: @@ -52,29 +57,31 @@ class ClientMock: # Connection closed raise ConnectionClosed() data += new_data - if not new_data.endswith(b'\r\n'): + if not new_data.endswith(b"\r\n"): time.sleep(0.1) continue if not synchronize: got_pong = True - for line in data.decode().split('\r\n'): + for line in data.decode().split("\r\n"): if line: if self.show_io: - print('{time:.3f}{ssl} S -> {client}: {line}'.format( - time=time.time(), - ssl=' (ssl)' if self.ssl else '', - client=self.name, - line=line)) + print( + "{time:.3f}{ssl} S -> {client}: {line}".format( + time=time.time(), + ssl=" (ssl)" if self.ssl else "", + client=self.name, + line=line, + ) + ) message = message_parser.parse_message(line) - if message.command == 'PONG' and \ - token in message.params: + if message.command == "PONG" and token in message.params: got_pong = True else: if raw: messages.append(line) else: messages.append(message) - data = b'' + data = b"" except ConnectionClosed: if messages: return messages @@ -82,16 +89,19 @@ class ClientMock: raise else: return messages + def getMessage(self, filter_pred=None, synchronize=True, raw=False): while True: if not self.inbuffer: self.inbuffer = self.getMessages( - synchronize=synchronize, assert_get_one=True, raw=raw) + synchronize=synchronize, assert_get_one=True, raw=raw + ) if not self.inbuffer: raise NoMessageException() - message = self.inbuffer.pop(0) # TODO: use dequeue + message = self.inbuffer.pop(0) # TODO: use dequeue if not filter_pred or filter_pred(message): return message + def sendLine(self, line): if isinstance(line, str): encoded_line = line.encode() @@ -99,26 +109,31 @@ class ClientMock: encoded_line = line else: raise ValueError(line) - if not encoded_line.endswith(b'\r\n'): - encoded_line += b'\r\n' + if not encoded_line.endswith(b"\r\n"): + encoded_line += b"\r\n" try: ret = self.conn.sendall(encoded_line) except BrokenPipeError: raise ConnectionClosed() - if sys.version_info <= (3, 6) and self.ssl: # https://bugs.python.org/issue25951 + if ( + sys.version_info <= (3, 6) and self.ssl + ): # https://bugs.python.org/issue25951 assert ret == len(encoded_line), (ret, repr(encoded_line)) else: assert ret is None, ret if self.show_io: if isinstance(line, str): escaped_line = line - escaped = '' + escaped = "" else: escaped_line = repr(line) - escaped = ' (escaped)' - print('{time:.3f}{escaped}{ssl} {client} -> S: {line}'.format( - time=time.time(), - escaped=escaped, - ssl=' (ssl)' if self.ssl else '', - client=self.name, - line=escaped_line.strip('\r\n'))) + escaped = " (escaped)" + print( + "{time:.3f}{escaped}{ssl} {client} -> S: {line}".format( + time=time.time(), + escaped=escaped, + ssl=" (ssl)" if self.ssl else "", + client=self.name, + line=escaped_line.strip("\r\n"), + ) + ) diff --git a/irctest/client_tests/test_cap.py b/irctest/client_tests/test_cap.py index 06b736c..6032ec1 100644 --- a/irctest/client_tests/test_cap.py +++ b/irctest/client_tests/test_cap.py @@ -1,14 +1,15 @@ from irctest import cases from irctest.irc_utils.message_parser import Message + class CapTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper): - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1', 'IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1", "IRCv3.2") def testSendCap(self): """Send CAP LS 302 and read the result.""" self.readCapLs() - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1', 'IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1", "IRCv3.2") def testEmptyCapLs(self): """Empty result to CAP LS. Client should send CAP END.""" m = self.negotiateCapabilities([]) - self.assertEqual(m, Message({}, None, 'CAP', ['END'])) + self.assertEqual(m, Message({}, None, "CAP", ["END"])) diff --git a/irctest/client_tests/test_sasl.py b/irctest/client_tests/test_sasl.py index d0b25d3..e8a4f45 100644 --- a/irctest/client_tests/test_sasl.py +++ b/irctest/client_tests/test_sasl.py @@ -27,6 +27,7 @@ IRX9cyi2wdYg9mUUYyh9GKdBCYHGUJAiCA== CHALLENGE = bytes(range(32)) assert len(CHALLENGE) == 32 + class IdentityHash: def __init__(self, data): self._data = data @@ -34,28 +35,31 @@ class IdentityHash: def digest(self): return self._data -class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, - cases.OptionalityHelper): - @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') + +class SaslTestCase( + cases.BaseClientTestCase, cases.ClientNegociationHelper, cases.OptionalityHelper +): + @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testPlain(self): """Test PLAIN authentication with correct username/password.""" auth = authentication.Authentication( - mechanisms=[authentication.Mechanisms.plain], - username='jilles', - password='sesame', - ) - m = self.negotiateCapabilities(['sasl'], auth=auth) - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['PLAIN'])) - self.sendLine('AUTHENTICATE +') + mechanisms=[authentication.Mechanisms.plain], + username="jilles", + password="sesame", + ) + m = self.negotiateCapabilities(["sasl"], auth=auth) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["PLAIN"])) + self.sendLine("AUTHENTICATE +") m = self.getMessage() - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', - ['amlsbGVzAGppbGxlcwBzZXNhbWU='])) - self.sendLine('900 * * jilles :You are now logged in.') - self.sendLine('903 * :SASL authentication successful') - m = self.negotiateCapabilities(['sasl'], False) - self.assertEqual(m, Message({}, None, 'CAP', ['END'])) + self.assertEqual( + m, Message({}, None, "AUTHENTICATE", ["amlsbGVzAGppbGxlcwBzZXNhbWU="]) + ) + self.sendLine("900 * * jilles :You are now logged in.") + self.sendLine("903 * :SASL authentication successful") + m = self.negotiateCapabilities(["sasl"], False) + self.assertEqual(m, Message({}, None, "CAP", ["END"])) - @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') + @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testPlainNotAvailable(self): """`sasl=EXTERNAL` is advertized, whereas the client is configured to use PLAIN. @@ -65,27 +69,26 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, a 904. """ auth = authentication.Authentication( - mechanisms=[authentication.Mechanisms.plain], - username='jilles', - password='sesame', - ) - m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth) - self.assertEqual(self.acked_capabilities, {'sasl'}) - if m == Message({}, None, 'CAP', ['END']): + mechanisms=[authentication.Mechanisms.plain], + username="jilles", + password="sesame", + ) + m = self.negotiateCapabilities(["sasl=EXTERNAL"], auth=auth) + self.assertEqual(self.acked_capabilities, {"sasl"}) + if m == Message({}, None, "CAP", ["END"]): # IRCv3.2-style, for clients that skip authentication # when unavailable (eg. Limnoria) return - elif m.command == 'QUIT': + elif m.command == "QUIT": # IRCv3.2-style, for clients that quit when unavailable # (eg. Sopel) return - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['PLAIN'])) - self.sendLine('904 {} :SASL auth failed'.format(self.nick)) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["PLAIN"])) + self.sendLine("904 {} :SASL auth failed".format(self.nick)) m = self.getMessage() - self.assertMessageEqual(m, command='CAP') + self.assertMessageEqual(m, command="CAP") - - @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') + @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testPlainLarge(self): """Test the client splits large AUTHENTICATE messages whose payload is not a multiple of 400. @@ -93,30 +96,28 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, """ # TODO: authzid is optional auth = authentication.Authentication( - mechanisms=[authentication.Mechanisms.plain], - username='foo', - password='bar'*200, - ) - authstring = base64.b64encode(b'\x00'.join( - [b'foo', b'foo', b'bar'*200])).decode() - m = self.negotiateCapabilities(['sasl'], auth=auth) - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['PLAIN'])) - self.sendLine('AUTHENTICATE +') + mechanisms=[authentication.Mechanisms.plain], + username="foo", + password="bar" * 200, + ) + authstring = base64.b64encode( + b"\x00".join([b"foo", b"foo", b"bar" * 200]) + ).decode() + m = self.negotiateCapabilities(["sasl"], auth=auth) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["PLAIN"])) + self.sendLine("AUTHENTICATE +") m = self.getMessage() - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', - [authstring[0:400]]), m) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", [authstring[0:400]]), m) m = self.getMessage() - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', - [authstring[400:800]])) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", [authstring[400:800]])) m = self.getMessage() - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', - [authstring[800:]])) - self.sendLine('900 * * {} :You are now logged in.'.format('foo')) - self.sendLine('903 * :SASL authentication successful') - m = self.negotiateCapabilities(['sasl'], False) - self.assertEqual(m, Message({}, None, 'CAP', ['END'])) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", [authstring[800:]])) + self.sendLine("900 * * {} :You are now logged in.".format("foo")) + self.sendLine("903 * :SASL authentication successful") + m = self.negotiateCapabilities(["sasl"], False) + self.assertEqual(m, Message({}, None, "CAP", ["END"])) - @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') + @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testPlainLargeMultiple(self): """Test the client splits large AUTHENTICATE messages whose payload is a multiple of 400. @@ -124,149 +125,157 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, """ # TODO: authzid is optional auth = authentication.Authentication( - mechanisms=[authentication.Mechanisms.plain], - username='foo', - password='quux'*148, - ) - authstring = base64.b64encode(b'\x00'.join( - [b'foo', b'foo', b'quux'*148])).decode() - m = self.negotiateCapabilities(['sasl'], auth=auth) - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['PLAIN'])) - self.sendLine('AUTHENTICATE +') + mechanisms=[authentication.Mechanisms.plain], + username="foo", + password="quux" * 148, + ) + authstring = base64.b64encode( + b"\x00".join([b"foo", b"foo", b"quux" * 148]) + ).decode() + m = self.negotiateCapabilities(["sasl"], auth=auth) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["PLAIN"])) + self.sendLine("AUTHENTICATE +") m = self.getMessage() - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', - [authstring[0:400]]), m) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", [authstring[0:400]]), m) m = self.getMessage() - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', - [authstring[400:800]])) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", [authstring[400:800]])) m = self.getMessage() - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', - ['+'])) - self.sendLine('900 * * {} :You are now logged in.'.format('foo')) - self.sendLine('903 * :SASL authentication successful') - m = self.negotiateCapabilities(['sasl'], False) - self.assertEqual(m, Message({}, None, 'CAP', ['END'])) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["+"])) + self.sendLine("900 * * {} :You are now logged in.".format("foo")) + self.sendLine("903 * :SASL authentication successful") + m = self.negotiateCapabilities(["sasl"], False) + self.assertEqual(m, Message({}, None, "CAP", ["END"])) - @cases.OptionalityHelper.skipUnlessHasMechanism('ECDSA-NIST256P-CHALLENGE') + @cases.OptionalityHelper.skipUnlessHasMechanism("ECDSA-NIST256P-CHALLENGE") def testEcdsa(self): - """Test ECDSA authentication. - """ + """Test ECDSA authentication.""" auth = authentication.Authentication( - mechanisms=[authentication.Mechanisms.ecdsa_nist256p_challenge], - username='jilles', - ecdsa_key=ECDSA_KEY, - ) - m = self.negotiateCapabilities(['sasl'], auth=auth) - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['ECDSA-NIST256P-CHALLENGE'])) - self.sendLine('AUTHENTICATE +') + mechanisms=[authentication.Mechanisms.ecdsa_nist256p_challenge], + username="jilles", + ecdsa_key=ECDSA_KEY, + ) + m = self.negotiateCapabilities(["sasl"], auth=auth) + self.assertEqual( + m, Message({}, None, "AUTHENTICATE", ["ECDSA-NIST256P-CHALLENGE"]) + ) + self.sendLine("AUTHENTICATE +") m = self.getMessage() - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', - ['amlsbGVz'])) # jilles - self.sendLine('AUTHENTICATE {}'.format(base64.b64encode(CHALLENGE).decode('ascii'))) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["amlsbGVz"])) # jilles + self.sendLine( + "AUTHENTICATE {}".format(base64.b64encode(CHALLENGE).decode("ascii")) + ) m = self.getMessage() - self.assertMessageEqual(m, command='AUTHENTICATE') + self.assertMessageEqual(m, command="AUTHENTICATE") sk = ecdsa.SigningKey.from_pem(ECDSA_KEY) vk = sk.get_verifying_key() signature = base64.b64decode(m.params[0]) try: - vk.verify(signature, CHALLENGE, hashfunc=IdentityHash, sigdecode=sigdecode_der) + vk.verify( + signature, CHALLENGE, hashfunc=IdentityHash, sigdecode=sigdecode_der + ) except ecdsa.BadSignatureError: - raise AssertionError('Bad signature') - self.sendLine('900 * * foo :You are now logged in.') - self.sendLine('903 * :SASL authentication successful') - m = self.negotiateCapabilities(['sasl'], False) - self.assertEqual(m, Message({}, None, 'CAP', ['END'])) + raise AssertionError("Bad signature") + self.sendLine("900 * * foo :You are now logged in.") + self.sendLine("903 * :SASL authentication successful") + m = self.negotiateCapabilities(["sasl"], False) + self.assertEqual(m, Message({}, None, "CAP", ["END"])) - @cases.OptionalityHelper.skipUnlessHasMechanism('SCRAM-SHA-256') + @cases.OptionalityHelper.skipUnlessHasMechanism("SCRAM-SHA-256") def testScram(self): - """Test SCRAM-SHA-256 authentication. - """ + """Test SCRAM-SHA-256 authentication.""" auth = authentication.Authentication( - mechanisms=[authentication.Mechanisms.scram_sha_256], - username='jilles', - password='sesame', - ) + mechanisms=[authentication.Mechanisms.scram_sha_256], + username="jilles", + password="sesame", + ) + class PasswdDb: def get_password(self, *args): - return ('sesame', 'plain') - authenticator = scram.SCRAMServerAuthenticator('SHA-256', - channel_binding=False, password_database=PasswdDb()) + return ("sesame", "plain") - m = self.negotiateCapabilities(['sasl'], auth=auth) - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['SCRAM-SHA-256'])) - self.sendLine('AUTHENTICATE +') + authenticator = scram.SCRAMServerAuthenticator( + "SHA-256", channel_binding=False, password_database=PasswdDb() + ) + + m = self.negotiateCapabilities(["sasl"], auth=auth) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["SCRAM-SHA-256"])) + self.sendLine("AUTHENTICATE +") m = self.getMessage() - self.assertEqual(m.command, 'AUTHENTICATE', m) + self.assertEqual(m.command, "AUTHENTICATE", m) client_first = base64.b64decode(m.params[0]) response = authenticator.start(properties={}, initial_response=client_first) assert isinstance(response, bytes), response - self.sendLine('AUTHENTICATE :' + base64.b64encode(response).decode()) + self.sendLine("AUTHENTICATE :" + base64.b64encode(response).decode()) m = self.getMessage() - self.assertEqual(m.command, 'AUTHENTICATE', m) + self.assertEqual(m.command, "AUTHENTICATE", m) msg = base64.b64decode(m.params[0]) r = authenticator.response(msg) assert isinstance(r, tuple), r assert len(r) == 2, r (properties, response) = r - self.sendLine('AUTHENTICATE :' + base64.b64encode(response).decode()) - self.assertEqual(properties, {'authzid': None, 'username': 'jilles'}) + self.sendLine("AUTHENTICATE :" + base64.b64encode(response).decode()) + self.assertEqual(properties, {"authzid": None, "username": "jilles"}) m = self.getMessage() - self.assertEqual(m.command, 'AUTHENTICATE', m) - self.assertEqual(m.params, ['+'], m) + self.assertEqual(m.command, "AUTHENTICATE", m) + self.assertEqual(m.params, ["+"], m) - @cases.OptionalityHelper.skipUnlessHasMechanism('SCRAM-SHA-256') + @cases.OptionalityHelper.skipUnlessHasMechanism("SCRAM-SHA-256") def testScramBadPassword(self): - """Test SCRAM-SHA-256 authentication with a bad password. - """ + """Test SCRAM-SHA-256 authentication with a bad password.""" auth = authentication.Authentication( - mechanisms=[authentication.Mechanisms.scram_sha_256], - username='jilles', - password='sesame', - ) + mechanisms=[authentication.Mechanisms.scram_sha_256], + username="jilles", + password="sesame", + ) + class PasswdDb: def get_password(self, *args): - return ('notsesame', 'plain') - authenticator = scram.SCRAMServerAuthenticator('SHA-256', - channel_binding=False, password_database=PasswdDb()) + return ("notsesame", "plain") - m = self.negotiateCapabilities(['sasl'], auth=auth) - self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['SCRAM-SHA-256'])) - self.sendLine('AUTHENTICATE +') + authenticator = scram.SCRAMServerAuthenticator( + "SHA-256", channel_binding=False, password_database=PasswdDb() + ) + + m = self.negotiateCapabilities(["sasl"], auth=auth) + self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["SCRAM-SHA-256"])) + self.sendLine("AUTHENTICATE +") m = self.getMessage() - self.assertEqual(m.command, 'AUTHENTICATE', m) + self.assertEqual(m.command, "AUTHENTICATE", m) client_first = base64.b64decode(m.params[0]) response = authenticator.start(properties={}, initial_response=client_first) assert isinstance(response, bytes), response - self.sendLine('AUTHENTICATE :' + base64.b64encode(response).decode()) + self.sendLine("AUTHENTICATE :" + base64.b64encode(response).decode()) m = self.getMessage() - self.assertEqual(m.command, 'AUTHENTICATE', m) + self.assertEqual(m.command, "AUTHENTICATE", m) msg = base64.b64decode(m.params[0]) with self.assertRaises(scram.NotAuthorizedException): authenticator.response(msg) -class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper, - cases.OptionalityHelper): - @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') + +class Irc302SaslTestCase( + cases.BaseClientTestCase, cases.ClientNegociationHelper, cases.OptionalityHelper +): + @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testPlainNotAvailable(self): """Test the client does not try to authenticate using a mechanism the server does not advertise. Actually, this is optional.""" auth = authentication.Authentication( - mechanisms=[authentication.Mechanisms.plain], - username='jilles', - password='sesame', - ) - m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth) - self.assertEqual(self.acked_capabilities, {'sasl'}) + mechanisms=[authentication.Mechanisms.plain], + username="jilles", + password="sesame", + ) + m = self.negotiateCapabilities(["sasl=EXTERNAL"], auth=auth) + self.assertEqual(self.acked_capabilities, {"sasl"}) - if m.command == 'QUIT': + if m.command == "QUIT": # Some clients quit when it can't authenticate (eg. Sopel) pass else: # Others will just skip authentication (eg. Limnoria) - self.assertEqual(m, Message({}, None, 'CAP', ['END'])) + self.assertEqual(m, Message({}, None, "CAP", ["END"])) diff --git a/irctest/client_tests/test_tls.py b/irctest/client_tests/test_tls.py index 0765ec1..c13832c 100644 --- a/irctest/client_tests/test_tls.py +++ b/irctest/client_tests/test_tls.py @@ -60,7 +60,7 @@ h4WuPDAI4yh24GjaCZYGR5xcqPCy5CNjMLxdA7HsP+Gcr3eY5XS7noBrbC6IaA0j -----END PRIVATE KEY----- """ -GOOD_FINGERPRINT = 'E1EE6DE2DBC0D43E3B60407B5EE389AEC9D2C53178E0FB14CD51C3DFD544AA2B' +GOOD_FINGERPRINT = "E1EE6DE2DBC0D43E3B60407B5EE389AEC9D2C53178E0FB14CD51C3DFD544AA2B" GOOD_CERT = """ -----BEGIN CERTIFICATE----- MIIDXTCCAkWgAwIBAgIJAKtD9XMC1R0vMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV @@ -115,32 +115,29 @@ El9iqRlAhgqaXc4Iz/Zxxhs= -----END PRIVATE KEY----- """ + class TlsTestCase(cases.BaseClientTestCase): - def testTrustedCertificate(self): - tls_config = tls.TlsConfig( - enable=True, - trusted_fingerprints=[GOOD_FINGERPRINT]) + def testTrustedCertificate(self): + tls_config = tls.TlsConfig(enable=True, trusted_fingerprints=[GOOD_FINGERPRINT]) (hostname, port) = self.server.getsockname() self.controller.run( - hostname=hostname, - port=port, - auth=None, - tls_config=tls_config, - ) + hostname=hostname, + port=port, + auth=None, + tls_config=tls_config, + ) self.acceptClient(tls_cert=GOOD_CERT, tls_key=GOOD_KEY) m = self.getMessage() - def testUntrustedCertificate(self): - tls_config = tls.TlsConfig( - enable=True, - trusted_fingerprints=[GOOD_FINGERPRINT]) + def testUntrustedCertificate(self): + tls_config = tls.TlsConfig(enable=True, trusted_fingerprints=[GOOD_FINGERPRINT]) (hostname, port) = self.server.getsockname() self.controller.run( - hostname=hostname, - port=port, - auth=None, - tls_config=tls_config, - ) + hostname=hostname, + port=port, + auth=None, + tls_config=tls_config, + ) self.acceptClient(tls_cert=BAD_CERT, tls_key=BAD_KEY) with self.assertRaises((ConnectionClosed, ConnectionResetError)): m = self.getMessage() @@ -150,36 +147,34 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper): def setUp(self): super().setUp() self.insecure_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.insecure_server.bind(('', 0)) # Bind any free port + self.insecure_server.bind(("", 0)) # Bind any free port self.insecure_server.listen(1) def tearDown(self): self.insecure_server.close() super().tearDown() - @cases.OptionalityHelper.skipUnlessSupportsCapability('sts') + @cases.OptionalityHelper.skipUnlessSupportsCapability("sts") def testSts(self): tls_config = tls.TlsConfig( - enable=False, - trusted_fingerprints=[GOOD_FINGERPRINT]) + enable=False, trusted_fingerprints=[GOOD_FINGERPRINT] + ) # Connect client to insecure server (hostname, port) = self.insecure_server.getsockname() self.controller.run( - hostname=hostname, - port=port, - auth=None, - tls_config=tls_config, - ) + hostname=hostname, + port=port, + auth=None, + tls_config=tls_config, + ) self.acceptClient(server=self.insecure_server) # Send STS policy to client m = self.getMessage() - self.assertEqual(m.command, 'CAP', - 'First message is not CAP LS.') - self.assertEqual(m.params[0], 'LS', - 'First message is not CAP LS.') - self.sendLine('CAP * LS :sts=port={}'.format(self.server.getsockname()[1])) + self.assertEqual(m.command, "CAP", "First message is not CAP LS.") + self.assertEqual(m.params[0], "LS", "First message is not CAP LS.") + self.sendLine("CAP * LS :sts=port={}".format(self.server.getsockname()[1])) # "If the client is not already connected securely to the server # at the requested hostname, it MUST close the insecure connection @@ -187,11 +182,12 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper): self.acceptClient(tls_cert=GOOD_CERT, tls_key=GOOD_KEY) # Send the STS policy, over secure connection this time - self.sendLine('CAP * LS :sts=duration=10,port={}'.format( - self.server.getsockname()[1])) + self.sendLine( + "CAP * LS :sts=duration=10,port={}".format(self.server.getsockname()[1]) + ) # Make the client reconnect. It should reconnect to the secure server. - self.sendLine('ERROR :closing link') + self.sendLine("ERROR :closing link") self.acceptClient() # Kill the client @@ -199,34 +195,32 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper): # Run the client, still configured to connect to the insecure server self.controller.run( - hostname=hostname, - port=port, - auth=None, - tls_config=tls_config, - ) + hostname=hostname, + port=port, + auth=None, + tls_config=tls_config, + ) # The client should remember the STS policy and connect to the secure # server self.acceptClient() - @cases.OptionalityHelper.skipUnlessSupportsCapability('sts') + @cases.OptionalityHelper.skipUnlessSupportsCapability("sts") def testStsInvalidCertificate(self): # Connect client to insecure server (hostname, port) = self.insecure_server.getsockname() self.controller.run( - hostname=hostname, - port=port, - auth=None, - ) + hostname=hostname, + port=port, + auth=None, + ) self.acceptClient(server=self.insecure_server) # Send STS policy to client m = self.getMessage() - self.assertEqual(m.command, 'CAP', - 'First message is not CAP LS.') - self.assertEqual(m.params[0], 'LS', - 'First message is not CAP LS.') - self.sendLine('CAP * LS :sts=port={}'.format(self.server.getsockname()[1])) + self.assertEqual(m.command, "CAP", "First message is not CAP LS.") + self.assertEqual(m.params[0], "LS", "First message is not CAP LS.") + self.sendLine("CAP * LS :sts=port={}".format(self.server.getsockname()[1])) # The client will reconnect to the TLS port. Unfortunately, it does # not trust its fingerprint. diff --git a/irctest/controllers/charybdis.py b/irctest/controllers/charybdis.py index 4124517..498c989 100644 --- a/irctest/controllers/charybdis.py +++ b/irctest/controllers/charybdis.py @@ -43,45 +43,61 @@ TEMPLATE_SSL_CONFIG = """ class CharybdisController(BaseServerController, DirectoryBasedController): - software_name = 'Charybdis' + software_name = "Charybdis" supported_sasl_mechanisms = set() supported_capabilities = set() # Not exhaustive + def create_config(self): super().create_config() - with self.open_file('server.conf'): + with self.open_file("server.conf"): pass - def run(self, hostname, port, password=None, ssl=False, - valid_metadata_keys=None, invalid_metadata_keys=None): + def run( + self, + hostname, + port, + password=None, + ssl=False, + valid_metadata_keys=None, + invalid_metadata_keys=None, + ): if valid_metadata_keys or invalid_metadata_keys: raise NotImplementedByController( - 'Defining valid and invalid METADATA keys.') + "Defining valid and invalid METADATA keys." + ) assert self.proc is None self.create_config() self.port = port - password_field = 'password = "{}";'.format(password) if password else '' + password_field = 'password = "{}";'.format(password) if password else "" if ssl: self.gen_ssl() ssl_config = TEMPLATE_SSL_CONFIG.format( - key_path=self.key_path, - pem_path=self.pem_path, - dh_path=self.dh_path, - ) - else: - ssl_config = '' - with self.open_file('server.conf') as fd: - fd.write(TEMPLATE_CONFIG.format( - hostname=hostname, - port=port, - password_field=password_field, - ssl_config=ssl_config, - )) - self.proc = subprocess.Popen(['charybdis', '-foreground', - '-configfile', os.path.join(self.directory, 'server.conf'), - '-pidfile', os.path.join(self.directory, 'server.pid'), - ], - stderr=subprocess.DEVNULL + key_path=self.key_path, + pem_path=self.pem_path, + dh_path=self.dh_path, ) + else: + ssl_config = "" + with self.open_file("server.conf") as fd: + fd.write( + TEMPLATE_CONFIG.format( + hostname=hostname, + port=port, + password_field=password_field, + ssl_config=ssl_config, + ) + ) + self.proc = subprocess.Popen( + [ + "charybdis", + "-foreground", + "-configfile", + os.path.join(self.directory, "server.conf"), + "-pidfile", + os.path.join(self.directory, "server.pid"), + ], + stderr=subprocess.DEVNULL, + ) def get_irctest_controller_class(): diff --git a/irctest/controllers/girc.py b/irctest/controllers/girc.py index 43a43f8..5c52aba 100644 --- a/irctest/controllers/girc.py +++ b/irctest/controllers/girc.py @@ -2,9 +2,10 @@ import subprocess from irctest.basecontrollers import BaseClientController, NotImplementedByController + class GircController(BaseClientController): - software_name = 'gIRC' - supported_sasl_mechanisms = ['PLAIN'] + software_name = "gIRC" + supported_sasl_mechanisms = ["PLAIN"] supported_capabilities = set() # Not exhaustive def __init__(self): @@ -30,16 +31,17 @@ class GircController(BaseClientController): def run(self, hostname, port, auth, tls_config): if tls_config: print(tls_config) - raise NotImplementedByController('TLS options') - args = ['--host', hostname, '--port', str(port), '--quiet'] + raise NotImplementedByController("TLS options") + args = ["--host", hostname, "--port", str(port), "--quiet"] if auth and auth.username and auth.password: - args += ['--sasl-name', auth.username] - args += ['--sasl-pass', auth.password] - args += ['--sasl-fail-is-ok'] + args += ["--sasl-name", auth.username] + args += ["--sasl-pass", auth.password] + args += ["--sasl-fail-is-ok"] # Runs a client with the config given as arguments - self.proc = subprocess.Popen(['girc_test', 'connect'] + args) + self.proc = subprocess.Popen(["girc_test", "connect"] + args) + def get_irctest_controller_class(): return GircController diff --git a/irctest/controllers/hybrid.py b/irctest/controllers/hybrid.py index 8067199..3b2fc28 100644 --- a/irctest/controllers/hybrid.py +++ b/irctest/controllers/hybrid.py @@ -41,47 +41,62 @@ TEMPLATE_SSL_CONFIG = """ class HybridController(BaseServerController, DirectoryBasedController): - software_name = 'Hybrid' + software_name = "Hybrid" supported_sasl_mechanisms = set() supported_capabilities = set() # Not exhaustive def create_config(self): super().create_config() - with self.open_file('server.conf'): + with self.open_file("server.conf"): pass - def run(self, hostname, port, password=None, ssl=False, - valid_metadata_keys=None, invalid_metadata_keys=None): + def run( + self, + hostname, + port, + password=None, + ssl=False, + valid_metadata_keys=None, + invalid_metadata_keys=None, + ): if valid_metadata_keys or invalid_metadata_keys: raise NotImplementedByController( - 'Defining valid and invalid METADATA keys.') + "Defining valid and invalid METADATA keys." + ) assert self.proc is None self.create_config() self.port = port - password_field = 'password = "{}";'.format(password) if password else '' + password_field = 'password = "{}";'.format(password) if password else "" if ssl: self.gen_ssl() ssl_config = TEMPLATE_SSL_CONFIG.format( - key_path=self.key_path, - pem_path=self.pem_path, - dh_path=self.dh_path, - ) + key_path=self.key_path, + pem_path=self.pem_path, + dh_path=self.dh_path, + ) else: - ssl_config = '' - with self.open_file('server.conf') as fd: - fd.write(TEMPLATE_CONFIG.format( - hostname=hostname, - port=port, - password_field=password_field, - ssl_config=ssl_config, - )) - self.proc = subprocess.Popen(['ircd', '-foreground', - '-configfile', os.path.join(self.directory, 'server.conf'), - '-pidfile', os.path.join(self.directory, 'server.pid'), + ssl_config = "" + with self.open_file("server.conf") as fd: + fd.write( + TEMPLATE_CONFIG.format( + hostname=hostname, + port=port, + password_field=password_field, + ssl_config=ssl_config, + ) + ) + self.proc = subprocess.Popen( + [ + "ircd", + "-foreground", + "-configfile", + os.path.join(self.directory, "server.conf"), + "-pidfile", + os.path.join(self.directory, "server.pid"), ], stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL - ) + stderr=subprocess.DEVNULL, + ) def get_irctest_controller_class(): diff --git a/irctest/controllers/inspircd.py b/irctest/controllers/inspircd.py index cb54d63..0ea8502 100644 --- a/irctest/controllers/inspircd.py +++ b/irctest/controllers/inspircd.py @@ -27,47 +27,63 @@ TEMPLATE_SSL_CONFIG = """ """ + class InspircdController(BaseServerController, DirectoryBasedController): - software_name = 'InspIRCd' + software_name = "InspIRCd" supported_sasl_mechanisms = set() supported_capabilities = set() # Not exhaustive def create_config(self): super().create_config() - with self.open_file('server.conf'): + with self.open_file("server.conf"): pass - def run(self, hostname, port, password=None, ssl=False, - restricted_metadata_keys=None, - valid_metadata_keys=None, invalid_metadata_keys=None): + def run( + self, + hostname, + port, + password=None, + ssl=False, + restricted_metadata_keys=None, + valid_metadata_keys=None, + invalid_metadata_keys=None, + ): if valid_metadata_keys or invalid_metadata_keys: raise NotImplementedByController( - 'Defining valid and invalid METADATA keys.') + "Defining valid and invalid METADATA keys." + ) assert self.proc is None self.port = port self.create_config() - password_field = 'password="{}"'.format(password) if password else '' + password_field = 'password="{}"'.format(password) if password else "" if ssl: self.gen_ssl() ssl_config = TEMPLATE_SSL_CONFIG.format( - key_path=self.key_path, - pem_path=self.pem_path, - dh_path=self.dh_path, - ) - else: - ssl_config = '' - with self.open_file('server.conf') as fd: - fd.write(TEMPLATE_CONFIG.format( - hostname=hostname, - port=port, - password_field=password_field, - ssl_config=ssl_config - )) - self.proc = subprocess.Popen(['inspircd', '--nofork', '--config', - os.path.join(self.directory, 'server.conf')], - stdout=subprocess.DEVNULL + key_path=self.key_path, + pem_path=self.pem_path, + dh_path=self.dh_path, ) + else: + ssl_config = "" + with self.open_file("server.conf") as fd: + fd.write( + TEMPLATE_CONFIG.format( + hostname=hostname, + port=port, + password_field=password_field, + ssl_config=ssl_config, + ) + ) + self.proc = subprocess.Popen( + [ + "inspircd", + "--nofork", + "--config", + os.path.join(self.directory, "server.conf"), + ], + stdout=subprocess.DEVNULL, + ) + def get_irctest_controller_class(): return InspircdController - diff --git a/irctest/controllers/limnoria.py b/irctest/controllers/limnoria.py index fdf274d..ae00f63 100644 --- a/irctest/controllers/limnoria.py +++ b/irctest/controllers/limnoria.py @@ -26,19 +26,23 @@ supybot.networks.testnet.sasl.ecdsa_key: {directory}/ecdsa_key.pem supybot.networks.testnet.sasl.mechanisms: {mechanisms} """ + class LimnoriaController(BaseClientController, DirectoryBasedController): - software_name = 'Limnoria' + software_name = "Limnoria" supported_sasl_mechanisms = { - 'PLAIN', 'ECDSA-NIST256P-CHALLENGE', 'SCRAM-SHA-256', 'EXTERNAL', - } - supported_capabilities = set(['sts']) # Not exhaustive + "PLAIN", + "ECDSA-NIST256P-CHALLENGE", + "SCRAM-SHA-256", + "EXTERNAL", + } + supported_capabilities = set(["sts"]) # Not exhaustive def create_config(self): create_config = super().create_config() if create_config: - with self.open_file('bot.conf'): + with self.open_file("bot.conf"): pass - with self.open_file('conf/users.conf'): + with self.open_file("conf/users.conf"): pass def run(self, hostname, port, auth, tls_config=None): @@ -48,27 +52,34 @@ class LimnoriaController(BaseClientController, DirectoryBasedController): assert self.proc is None self.create_config() if auth: - mechanisms = ' '.join(map(authentication.Mechanisms.as_string, - auth.mechanisms)) + mechanisms = " ".join( + map(authentication.Mechanisms.as_string, auth.mechanisms) + ) if auth.ecdsa_key: - with self.open_file('ecdsa_key.pem') as fd: + with self.open_file("ecdsa_key.pem") as fd: fd.write(auth.ecdsa_key) else: - mechanisms = '' - with self.open_file('bot.conf') as fd: - fd.write(TEMPLATE_CONFIG.format( - directory=self.directory, - loglevel='CRITICAL', - hostname=hostname, - port=port, - username=auth.username if auth else '', - password=auth.password if auth else '', - mechanisms=mechanisms.lower(), - enable_tls=tls_config.enable if tls_config else 'False', - trusted_fingerprints=' '.join(tls_config.trusted_fingerprints) if tls_config else '', - )) - self.proc = subprocess.Popen(['supybot', - os.path.join(self.directory, 'bot.conf')]) + mechanisms = "" + with self.open_file("bot.conf") as fd: + fd.write( + TEMPLATE_CONFIG.format( + directory=self.directory, + loglevel="CRITICAL", + hostname=hostname, + port=port, + username=auth.username if auth else "", + password=auth.password if auth else "", + mechanisms=mechanisms.lower(), + enable_tls=tls_config.enable if tls_config else "False", + trusted_fingerprints=" ".join(tls_config.trusted_fingerprints) + if tls_config + else "", + ) + ) + self.proc = subprocess.Popen( + ["supybot", os.path.join(self.directory, "bot.conf")] + ) + def get_irctest_controller_class(): return LimnoriaController diff --git a/irctest/controllers/mammon.py b/irctest/controllers/mammon.py index 054e5ea..de786f2 100644 --- a/irctest/controllers/mammon.py +++ b/irctest/controllers/mammon.py @@ -58,66 +58,84 @@ server: recvq_len: 20 """ + def make_list(l): - return '\n'.join(map(' - {}'.format, l)) + return "\n".join(map(" - {}".format, l)) + class MammonController(BaseServerController, DirectoryBasedController): - software_name = 'Mammon' + software_name = "Mammon" supported_sasl_mechanisms = { - 'PLAIN', 'ECDSA-NIST256P-CHALLENGE', - } + "PLAIN", + "ECDSA-NIST256P-CHALLENGE", + } supported_capabilities = set() # Not exhaustive def create_config(self): super().create_config() - with self.open_file('server.conf'): + with self.open_file("server.conf"): pass def kill_proc(self): # Mammon does not seem to handle SIGTERM very well self.proc.kill() - def run(self, hostname, port, password=None, ssl=False, - restricted_metadata_keys=(), - valid_metadata_keys=(), invalid_metadata_keys=()): + def run( + self, + hostname, + port, + password=None, + ssl=False, + restricted_metadata_keys=(), + valid_metadata_keys=(), + invalid_metadata_keys=(), + ): if password is not None: - raise NotImplementedByController('PASS command') + raise NotImplementedByController("PASS command") if ssl: - raise NotImplementedByController('SSL') + raise NotImplementedByController("SSL") assert self.proc is None self.port = port self.create_config() - with self.open_file('server.yml') as fd: - fd.write(TEMPLATE_CONFIG.format( - directory=self.directory, - hostname=hostname, - port=port, - authorized_keys=make_list(valid_metadata_keys), - restricted_keys=make_list(restricted_metadata_keys), - )) - #with self.open_file('server.yml', 'r') as fd: + with self.open_file("server.yml") as fd: + fd.write( + TEMPLATE_CONFIG.format( + directory=self.directory, + hostname=hostname, + port=port, + authorized_keys=make_list(valid_metadata_keys), + restricted_keys=make_list(restricted_metadata_keys), + ) + ) + # with self.open_file('server.yml', 'r') as fd: # print(fd.read()) - self.proc = subprocess.Popen(['mammond', '--nofork', #'--debug', - '--config', os.path.join(self.directory, 'server.yml')]) + self.proc = subprocess.Popen( + [ + "mammond", + "--nofork", #'--debug', + "--config", + os.path.join(self.directory, "server.yml"), + ] + ) def registerUser(self, case, username, password=None): # XXX: Move this somewhere else when # https://github.com/ircv3/ircv3-specifications/pull/152 becomes # part of the specification client = case.addClient(show_io=False) - case.sendLine(client, 'CAP LS 302') - case.sendLine(client, 'NICK registration_user') - case.sendLine(client, 'USER r e g :user') - case.sendLine(client, 'CAP END') - while case.getRegistrationMessage(client).command != '001': + case.sendLine(client, "CAP LS 302") + case.sendLine(client, "NICK registration_user") + case.sendLine(client, "USER r e g :user") + case.sendLine(client, "CAP END") + while case.getRegistrationMessage(client).command != "001": pass list(case.getMessages(client)) - case.sendLine(client, 'REG CREATE {} passphrase {}'.format( - username, password)) + case.sendLine(client, "REG CREATE {} passphrase {}".format(username, password)) msg = case.getMessage(client) - assert msg.command == '920', msg + assert msg.command == "920", msg list(case.getMessages(client)) case.removeClient(client) + def get_irctest_controller_class(): return MammonController diff --git a/irctest/controllers/oragono.py b/irctest/controllers/oragono.py index 1fde6bf..85db19c 100644 --- a/irctest/controllers/oragono.py +++ b/irctest/controllers/oragono.py @@ -6,13 +6,12 @@ import subprocess from irctest.basecontrollers import NotImplementedByController from irctest.basecontrollers import BaseServerController, DirectoryBasedController -OPER_PWD = 'frenchfries' +OPER_PWD = "frenchfries" BASE_CONFIG = { "network": { "name": "OragonoTest", }, - "server": { "name": "oragono.test", "listeners": {}, @@ -35,140 +34,153 @@ BASE_CONFIG = { "ban-message": "Try again later", "exempted": ["localhost"], }, - 'enforce-utf8': True, - 'relaymsg': { - 'enabled': True, - 'separators': '/', - 'available-to-chanops': True, + "enforce-utf8": True, + "relaymsg": { + "enabled": True, + "separators": "/", + "available-to-chanops": True, }, }, - - 'accounts': { - 'authentication-enabled': True, - 'multiclient': { - 'allowed-by-default': True, - 'enabled': True, - 'always-on': 'disabled', + "accounts": { + "authentication-enabled": True, + "multiclient": { + "allowed-by-default": True, + "enabled": True, + "always-on": "disabled", }, - 'registration': { - 'bcrypt-cost': 4, - 'enabled': True, - 'enabled-callbacks': ['none'], - 'verify-timeout': '120h', - }, - 'nick-reservation': { - 'enabled': True, - 'additional-nick-limit': 2, - 'method': 'strict', + "registration": { + "bcrypt-cost": 4, + "enabled": True, + "enabled-callbacks": ["none"], + "verify-timeout": "120h", + }, + "nick-reservation": { + "enabled": True, + "additional-nick-limit": 2, + "method": "strict", }, }, - - "channels": { - "registration": {"enabled": True,}, - }, - - "datastore": { - "path": None, - }, - - 'limits': { - 'awaylen': 200, - 'chan-list-modes': 60, - 'channellen': 64, - 'kicklen': 390, - 'linelen': {'rest': 2048,}, - 'monitor-entries': 100, - 'nicklen': 32, - 'topiclen': 390, - 'whowas-entries': 100, - 'multiline': {'max-bytes': 4096, 'max-lines': 32,}, - }, - - "history": { - "enabled": True, - "channel-length": 128, - "client-length": 128, - "chathistory-maxmessages": 100, - "tagmsg-storage": { - "default": False, - "whitelist": ["+draft/persist", "+persist"], - }, - }, - - 'oper-classes': { - 'server-admin': { - 'title': 'Server Admin', - 'capabilities': [ - "oper:local_kill", - "oper:local_ban", - "oper:local_unban", - "nofakelag", - "oper:remote_kill", - "oper:remote_ban", - "oper:remote_unban", - "oper:rehash", - "oper:die", - "accreg", - "sajoin", - "samode", - "vhosts", - "chanreg", - "relaymsg", + "channels": { + "registration": { + "enabled": True, + }, + }, + "datastore": { + "path": None, + }, + "limits": { + "awaylen": 200, + "chan-list-modes": 60, + "channellen": 64, + "kicklen": 390, + "linelen": { + "rest": 2048, + }, + "monitor-entries": 100, + "nicklen": 32, + "topiclen": 390, + "whowas-entries": 100, + "multiline": { + "max-bytes": 4096, + "max-lines": 32, + }, + }, + "history": { + "enabled": True, + "channel-length": 128, + "client-length": 128, + "chathistory-maxmessages": 100, + "tagmsg-storage": { + "default": False, + "whitelist": ["+draft/persist", "+persist"], + }, + }, + "oper-classes": { + "server-admin": { + "title": "Server Admin", + "capabilities": [ + "oper:local_kill", + "oper:local_ban", + "oper:local_unban", + "nofakelag", + "oper:remote_kill", + "oper:remote_ban", + "oper:remote_unban", + "oper:rehash", + "oper:die", + "accreg", + "sajoin", + "samode", + "vhosts", + "chanreg", + "relaymsg", ], }, }, - - 'opers': { - 'root': { - 'class': 'server-admin', - 'whois-line': 'is a server admin', + "opers": { + "root": { + "class": "server-admin", + "whois-line": "is a server admin", # OPER_PWD - 'password': '$2a$04$3GzUZB5JapaAbwn7sogpOu9NSiLOgnozVllm2e96LiNPrm61ZsZSq', + "password": "$2a$04$3GzUZB5JapaAbwn7sogpOu9NSiLOgnozVllm2e96LiNPrm61ZsZSq", }, }, } LOGGING_CONFIG = { "logging": [ - { - "method": "stderr", - "level": "debug", - "type": "*", - }, + { + "method": "stderr", + "level": "debug", + "type": "*", + }, ] } + def hash_password(password): if isinstance(password, str): - password = password.encode('utf-8') + password = password.encode("utf-8") # simulate entry of password and confirmation: - input_ = password + b'\n' + password + b'\n' - p = subprocess.Popen(['oragono', 'genpasswd'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) + input_ = password + b"\n" + password + b"\n" + p = subprocess.Popen( + ["oragono", "genpasswd"], stdin=subprocess.PIPE, stdout=subprocess.PIPE + ) out, _ = p.communicate(input_) - return out.decode('utf-8') + return out.decode("utf-8") + class OragonoController(BaseServerController, DirectoryBasedController): - software_name = 'Oragono' + software_name = "Oragono" supported_sasl_mechanisms = { - 'PLAIN', + "PLAIN", } - _port_wait_interval = .01 + _port_wait_interval = 0.01 supported_capabilities = set() # Not exhaustive def create_config(self): super().create_config() - with self.open_file('ircd.yaml'): + with self.open_file("ircd.yaml"): pass def kill_proc(self): self.proc.kill() - def run(self, hostname, port, password=None, ssl=False, - restricted_metadata_keys=None, - valid_metadata_keys=None, invalid_metadata_keys=None, config=None): + def run( + self, + hostname, + port, + password=None, + ssl=False, + restricted_metadata_keys=None, + valid_metadata_keys=None, + invalid_metadata_keys=None, + config=None, + ): if valid_metadata_keys or invalid_metadata_keys: raise NotImplementedByController( - 'Defining valid and invalid METADATA keys.') + "Defining valid and invalid METADATA keys." + ) self.create_config() if config is None: @@ -180,59 +192,60 @@ class OragonoController(BaseServerController, DirectoryBasedController): config = self.addMysqlToConfig(config) if enable_roleplay: - config['roleplay'] = { - 'enabled': True, + config["roleplay"] = { + "enabled": True, } - if 'oragono_config' in self.test_config: - self.test_config['oragono_config'](config) + if "oragono_config" in self.test_config: + self.test_config["oragono_config"](config) self.port = port bind_address = "127.0.0.1:%s" % (port,) - listener_conf = None # plaintext + listener_conf = None # plaintext if ssl: - self.key_path = os.path.join(self.directory, 'ssl.key') - self.pem_path = os.path.join(self.directory, 'ssl.pem') - listener_conf = {"tls": {"cert": self.pem_path, "key": self.key_path},} - config['server']['listeners'][bind_address] = listener_conf + self.key_path = os.path.join(self.directory, "ssl.key") + self.pem_path = os.path.join(self.directory, "ssl.pem") + listener_conf = { + "tls": {"cert": self.pem_path, "key": self.key_path}, + } + config["server"]["listeners"][bind_address] = listener_conf - config['datastore']['path'] = os.path.join(self.directory, 'ircd.db') + config["datastore"]["path"] = os.path.join(self.directory, "ircd.db") if password is not None: - config['server']['password'] = hash_password(password) + config["server"]["password"] = hash_password(password) assert self.proc is None - self._config_path = os.path.join(self.directory, 'server.yml') + self._config_path = os.path.join(self.directory, "server.yml") self._config = config self._write_config() - subprocess.call(['oragono', 'initdb', - '--conf', self._config_path, '--quiet']) - subprocess.call(['oragono', 'mkcerts', - '--conf', self._config_path, '--quiet']) - self.proc = subprocess.Popen(['oragono', 'run', - '--conf', self._config_path, '--quiet']) + subprocess.call(["oragono", "initdb", "--conf", self._config_path, "--quiet"]) + subprocess.call(["oragono", "mkcerts", "--conf", self._config_path, "--quiet"]) + self.proc = subprocess.Popen( + ["oragono", "run", "--conf", self._config_path, "--quiet"] + ) def registerUser(self, case, username, password=None): # XXX: Move this somewhere else when # https://github.com/ircv3/ircv3-specifications/pull/152 becomes # part of the specification client = case.addClient(show_io=False) - case.sendLine(client, 'CAP LS 302') - case.sendLine(client, 'NICK ' + username) - case.sendLine(client, 'USER r e g :user') - case.sendLine(client, 'CAP END') - while case.getRegistrationMessage(client).command != '001': + case.sendLine(client, "CAP LS 302") + case.sendLine(client, "NICK " + username) + case.sendLine(client, "USER r e g :user") + case.sendLine(client, "CAP END") + while case.getRegistrationMessage(client).command != "001": pass case.getMessages(client) - case.sendLine(client, 'NS REGISTER ' + password) + case.sendLine(client, "NS REGISTER " + password) msg = case.getMessage(client) - assert msg.params == [username, 'Account created'] - case.sendLine(client, 'QUIT') + assert msg.params == [username, "Account created"] + case.sendLine(client, "QUIT") case.assertDisconnected(client) def _write_config(self): - with open(self._config_path, 'w') as fd: + with open(self._config_path, "w") as fd: json.dump(self._config, fd) def baseConfig(self): @@ -248,25 +261,25 @@ class OragonoController(BaseServerController, DirectoryBasedController): return config def addMysqlToConfig(self, config=None): - mysql_password = os.getenv('MYSQL_PASSWORD') + mysql_password = os.getenv("MYSQL_PASSWORD") if not mysql_password: return config if config is None: config = self.baseConfig() - config['datastore']['mysql'] = { - "enabled": True, - "host": "localhost", - "user": "oragono", - "password": mysql_password, - "history-database": "oragono_history", - "timeout": "3s", + config["datastore"]["mysql"] = { + "enabled": True, + "host": "localhost", + "user": "oragono", + "password": mysql_password, + "history-database": "oragono_history", + "timeout": "3s", } - config['accounts']['multiclient'] = { - 'enabled': True, - 'allowed-by-default': True, - 'always-on': 'disabled', + config["accounts"]["multiclient"] = { + "enabled": True, + "allowed-by-default": True, + "always-on": "disabled", } - config['history']['persistent'] = { + config["history"]["persistent"] = { "enabled": True, "unregistered-channels": True, "registered-channels": "opt-out", @@ -277,12 +290,12 @@ class OragonoController(BaseServerController, DirectoryBasedController): def rehash(self, case, config): self._config = config self._write_config() - client = 'operator_for_rehash' + client = "operator_for_rehash" case.connectClient(nick=client, name=client) - case.sendLine(client, 'OPER root %s' % (OPER_PWD,)) - case.sendLine(client, 'REHASH') + case.sendLine(client, "OPER root %s" % (OPER_PWD,)) + case.sendLine(client, "REHASH") case.getMessages(client) - case.sendLine(client, 'QUIT') + case.sendLine(client, "QUIT") case.assertDisconnected(client) def enable_debug_logging(self, case): @@ -290,5 +303,6 @@ class OragonoController(BaseServerController, DirectoryBasedController): config.update(LOGGING_CONFIG) self.rehash(case, config) + def get_irctest_controller_class(): return OragonoController diff --git a/irctest/controllers/sopel.py b/irctest/controllers/sopel.py index a90db10..26d6a20 100644 --- a/irctest/controllers/sopel.py +++ b/irctest/controllers/sopel.py @@ -19,30 +19,30 @@ auth_password = {password} {auth_method} """ + class SopelController(BaseClientController): - software_name = 'Sopel' + software_name = "Sopel" supported_sasl_mechanisms = { - 'PLAIN', - } + "PLAIN", + } supported_capabilities = set() # Not exhaustive def __init__(self, test_config): super().__init__(test_config) - self.filename = next(tempfile._get_candidate_names()) + '.cfg' + self.filename = next(tempfile._get_candidate_names()) + ".cfg" self.proc = None + def kill(self): if self.proc: self.proc.kill() if self.filename: try: - os.unlink(os.path.join(os.path.expanduser('~/.sopel/'), - self.filename)) - except OSError: # File does not exist + os.unlink(os.path.join(os.path.expanduser("~/.sopel/"), self.filename)) + except OSError: #  File does not exist pass - def open_file(self, filename, mode='a'): - return open(os.path.join(os.path.expanduser('~/.sopel/'), filename), - mode) + def open_file(self, filename, mode="a"): + return open(os.path.join(os.path.expanduser("~/.sopel/"), filename), mode) def create_config(self): with self.open_file(self.filename) as fd: @@ -51,20 +51,21 @@ class SopelController(BaseClientController): def run(self, hostname, port, auth, tls_config): # Runs a client with the config given as arguments if tls_config is not None: - raise NotImplementedByController( - 'TLS configuration') + raise NotImplementedByController("TLS configuration") assert self.proc is None self.create_config() with self.open_file(self.filename) as fd: - fd.write(TEMPLATE_CONFIG.format( - hostname=hostname, - port=port, - username=auth.username if auth else '', - password=auth.password if auth else '', - auth_method='auth_method = sasl' if auth else '', - )) - self.proc = subprocess.Popen(['sopel', '--quiet', '-c', self.filename]) + fd.write( + TEMPLATE_CONFIG.format( + hostname=hostname, + port=port, + username=auth.username if auth else "", + password=auth.password if auth else "", + auth_method="auth_method = sasl" if auth else "", + ) + ) + self.proc = subprocess.Popen(["sopel", "--quiet", "-c", self.filename]) + def get_irctest_controller_class(): return SopelController - diff --git a/irctest/exceptions.py b/irctest/exceptions.py index 8b1ca49..c91e8a2 100644 --- a/irctest/exceptions.py +++ b/irctest/exceptions.py @@ -1,6 +1,6 @@ class NoMessageException(AssertionError): pass + class ConnectionClosed(Exception): pass - diff --git a/irctest/irc_utils/ambiguities.py b/irctest/irc_utils/ambiguities.py index f1be15a..42b9e2b 100644 --- a/irctest/irc_utils/ambiguities.py +++ b/irctest/irc_utils/ambiguities.py @@ -2,6 +2,7 @@ Handles ambiguities of RFCs. """ + def normalize_namreply_params(params): # So… RFC 2812 says: # "( "=" / "*" / "@" ) @@ -12,7 +13,7 @@ def normalize_namreply_params(params): # So let's normalize this to “with space”, and strip spaces at the # end of the nick list. if len(params) == 3: - assert params[1][0] in '=*@', params + assert params[1][0] in "=*@", params params.insert(1), params[1][0] params[2] = params[2][1:] params[3] = params[3].rstrip() diff --git a/irctest/irc_utils/capabilities.py b/irctest/irc_utils/capabilities.py index 78c0d17..fbbb569 100644 --- a/irctest/irc_utils/capabilities.py +++ b/irctest/irc_utils/capabilities.py @@ -1,8 +1,8 @@ def cap_list_to_dict(l): d = {} for cap in l: - if '=' in cap: - (key, value) = cap.split('=', 1) + if "=" in cap: + (key, value) = cap.split("=", 1) else: key = cap value = None diff --git a/irctest/irc_utils/junkdrawer.py b/irctest/irc_utils/junkdrawer.py index 3428dc6..3bd7bb3 100644 --- a/irctest/irc_utils/junkdrawer.py +++ b/irctest/irc_utils/junkdrawer.py @@ -3,24 +3,35 @@ import re import secrets from collections import namedtuple -HistoryMessage = namedtuple('HistoryMessage', ['time', 'msgid', 'target', 'text']) +HistoryMessage = namedtuple("HistoryMessage", ["time", "msgid", "target", "text"]) + def to_history_message(msg): - return HistoryMessage(time=msg.tags.get('time'), msgid=msg.tags.get('msgid'), target=msg.params[0], text=msg.params[1]) + return HistoryMessage( + time=msg.tags.get("time"), + msgid=msg.tags.get("msgid"), + target=msg.params[0], + text=msg.params[1], + ) + # thanks jess! IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z" + def ircv3_timestamp_to_unixtime(timestamp): return datetime.datetime.strptime(timestamp, IRCV3_FORMAT_STRFTIME).timestamp() + def random_name(base): - return base + '-' + secrets.token_hex(8) + return base + "-" + secrets.token_hex(8) + """ Stolen from supybot: """ + class MultipleReplacer: """Return a callable that replaces all dict keys by the associated value. More efficient than multiple .replace().""" @@ -30,24 +41,26 @@ class MultipleReplacer: # it to a class in Python 3. def __init__(self, dict_): self._dict = dict_ - dict_ = dict([(re.escape(key), val) for key,val in dict_.items()]) - self._matcher = re.compile('|'.join(dict_.keys())) + dict_ = dict([(re.escape(key), val) for key, val in dict_.items()]) + self._matcher = re.compile("|".join(dict_.keys())) + def __call__(self, s): return self._matcher.sub(lambda m: self._dict[m.group(0)], s) + def normalizeWhitespace(s, removeNewline=True): r"""Normalizes the whitespace in a string; \s+ becomes one space.""" if not s: - return str(s) # not the same reference - starts_with_space = (s[0] in ' \n\t\r') - ends_with_space = (s[-1] in ' \n\t\r') + return str(s) # not the same reference + starts_with_space = s[0] in " \n\t\r" + ends_with_space = s[-1] in " \n\t\r" if removeNewline: - newline_re = re.compile('[\r\n]+') - s = ' '.join(filter(bool, newline_re.split(s))) - s = ' '.join(filter(bool, s.split('\t'))) - s = ' '.join(filter(bool, s.split(' '))) + newline_re = re.compile("[\r\n]+") + s = " ".join(filter(bool, newline_re.split(s))) + s = " ".join(filter(bool, s.split("\t"))) + s = " ".join(filter(bool, s.split(" "))) if starts_with_space: - s = ' ' + s + s = " " + s if ends_with_space: - s += ' ' + s += " " return s diff --git a/irctest/irc_utils/message_parser.py b/irctest/irc_utils/message_parser.py index 3005448..d6303b4 100644 --- a/irctest/irc_utils/message_parser.py +++ b/irctest/irc_utils/message_parser.py @@ -5,58 +5,58 @@ from .junkdrawer import MultipleReplacer # http://ircv3.net/specs/core/message-tags-3.2.html#escaping-values TAG_ESCAPE = [ - ('\\', '\\\\'), # \ -> \\ - (' ', r'\s'), - (';', r'\:'), - ('\r', r'\r'), - ('\n', r'\n'), - ] -unescape_tag_value = MultipleReplacer( - dict(map(lambda x:(x[1],x[0]), TAG_ESCAPE))) + ("\\", "\\\\"), # \ -> \\ + (" ", r"\s"), + (";", r"\:"), + ("\r", r"\r"), + ("\n", r"\n"), +] +unescape_tag_value = MultipleReplacer(dict(map(lambda x: (x[1], x[0]), TAG_ESCAPE))) # TODO: validate host -tag_key_validator = re.compile(r'\+?(\S+/)?[a-zA-Z0-9-]+') +tag_key_validator = re.compile(r"\+?(\S+/)?[a-zA-Z0-9-]+") + def parse_tags(s): tags = {} - for tag in s.split(';'): - if '=' not in tag: + for tag in s.split(";"): + if "=" not in tag: tags[tag] = None else: - (key, value) = tag.split('=', 1) - assert tag_key_validator.match(key), \ - 'Invalid tag key: {}'.format(key) + (key, value) = tag.split("=", 1) + assert tag_key_validator.match(key), "Invalid tag key: {}".format(key) tags[key] = unescape_tag_value(value) return tags -Message = collections.namedtuple('Message', - 'tags prefix command params') + +Message = collections.namedtuple("Message", "tags prefix command params") + def parse_message(s): """Parse a message according to http://tools.ietf.org/html/rfc1459#section-2.3.1 and http://ircv3.net/specs/core/message-tags-3.2.html""" - s = s.rstrip('\r\n') - if s.startswith('@'): - (tags, s) = s.split(' ', 1) + s = s.rstrip("\r\n") + if s.startswith("@"): + (tags, s) = s.split(" ", 1) tags = parse_tags(tags[1:]) else: tags = {} - if ' :' in s: - (other_tokens, trailing_param) = s.split(' :', 1) - tokens = list(filter(bool, other_tokens.split(' '))) + [trailing_param] + if " :" in s: + (other_tokens, trailing_param) = s.split(" :", 1) + tokens = list(filter(bool, other_tokens.split(" "))) + [trailing_param] else: - tokens = list(filter(bool, s.split(' '))) - if tokens[0].startswith(':'): + tokens = list(filter(bool, s.split(" "))) + if tokens[0].startswith(":"): prefix = tokens.pop(0)[1:] else: prefix = None command = tokens.pop(0) params = tokens return Message( - tags=tags, - prefix=prefix, - command=command, - params=params, - ) + tags=tags, + prefix=prefix, + command=command, + params=params, + ) diff --git a/irctest/irc_utils/sasl.py b/irctest/irc_utils/sasl.py index 8bf075c..4957af2 100644 --- a/irctest/irc_utils/sasl.py +++ b/irctest/irc_utils/sasl.py @@ -1,6 +1,15 @@ import base64 + def sasl_plain_blob(username, passphrase): - blob = base64.b64encode(b'\x00'.join((username.encode('utf-8'), username.encode('utf-8'), passphrase.encode('utf-8')))) - blobstr = blob.decode('ascii') - return f'AUTHENTICATE {blobstr}' + blob = base64.b64encode( + b"\x00".join( + ( + username.encode("utf-8"), + username.encode("utf-8"), + passphrase.encode("utf-8"), + ) + ) + ) + blobstr = blob.decode("ascii") + return f"AUTHENTICATE {blobstr}" diff --git a/irctest/numerics.py b/irctest/numerics.py index dbfa4b1..776e804 100644 --- a/irctest/numerics.py +++ b/irctest/numerics.py @@ -9,191 +9,191 @@ # They're intended to represent a relatively-standard cross-section of the IRC # server ecosystem out there. Custom numerics will be marked as such. -RPL_WELCOME = "001" -RPL_YOURHOST = "002" -RPL_CREATED = "003" -RPL_MYINFO = "004" -RPL_ISUPPORT = "005" -RPL_SNOMASKIS = "008" -RPL_BOUNCE = "010" -RPL_TRACELINK = "200" -RPL_TRACECONNECTING = "201" -RPL_TRACEHANDSHAKE = "202" -RPL_TRACEUNKNOWN = "203" -RPL_TRACEOPERATOR = "204" -RPL_TRACEUSER = "205" -RPL_TRACESERVER = "206" -RPL_TRACESERVICE = "207" -RPL_TRACENEWTYPE = "208" -RPL_TRACECLASS = "209" -RPL_TRACERECONNECT = "210" -RPL_STATSLINKINFO = "211" -RPL_STATSCOMMANDS = "212" -RPL_ENDOFSTATS = "219" -RPL_UMODEIS = "221" -RPL_SERVLIST = "234" -RPL_SERVLISTEND = "235" -RPL_STATSUPTIME = "242" -RPL_STATSOLINE = "243" -RPL_LUSERCLIENT = "251" -RPL_LUSEROP = "252" -RPL_LUSERUNKNOWN = "253" -RPL_LUSERCHANNELS = "254" -RPL_LUSERME = "255" -RPL_ADMINME = "256" -RPL_ADMINLOC1 = "257" -RPL_ADMINLOC2 = "258" -RPL_ADMINEMAIL = "259" -RPL_TRACELOG = "261" -RPL_TRACEEND = "262" -RPL_TRYAGAIN = "263" -RPL_LOCALUSERS = "265" -RPL_GLOBALUSERS = "266" -RPL_WHOISCERTFP = "276" -RPL_AWAY = "301" -RPL_USERHOST = "302" -RPL_ISON = "303" -RPL_UNAWAY = "305" -RPL_NOWAWAY = "306" -RPL_WHOISUSER = "311" -RPL_WHOISSERVER = "312" -RPL_WHOISOPERATOR = "313" -RPL_WHOWASUSER = "314" -RPL_ENDOFWHO = "315" -RPL_WHOISIDLE = "317" -RPL_ENDOFWHOIS = "318" -RPL_WHOISCHANNELS = "319" -RPL_LIST = "322" -RPL_LISTEND = "323" -RPL_CHANNELMODEIS = "324" -RPL_UNIQOPIS = "325" -RPL_CHANNELCREATED = "329" -RPL_WHOISACCOUNT = "330" -RPL_NOTOPIC = "331" -RPL_TOPIC = "332" -RPL_TOPICTIME = "333" -RPL_WHOISBOT = "335" -RPL_WHOISACTUALLY = "338" -RPL_INVITING = "341" -RPL_SUMMONING = "342" -RPL_INVITELIST = "346" -RPL_ENDOFINVITELIST = "347" -RPL_EXCEPTLIST = "348" -RPL_ENDOFEXCEPTLIST = "349" -RPL_VERSION = "351" -RPL_WHOREPLY = "352" -RPL_NAMREPLY = "353" -RPL_LINKS = "364" -RPL_ENDOFLINKS = "365" -RPL_ENDOFNAMES = "366" -RPL_BANLIST = "367" -RPL_ENDOFBANLIST = "368" -RPL_ENDOFWHOWAS = "369" -RPL_INFO = "371" -RPL_MOTD = "372" -RPL_ENDOFINFO = "374" -RPL_MOTDSTART = "375" -RPL_ENDOFMOTD = "376" -RPL_YOUREOPER = "381" -RPL_REHASHING = "382" -RPL_YOURESERVICE = "383" -RPL_TIME = "391" -RPL_USERSSTART = "392" -RPL_USERS = "393" -RPL_ENDOFUSERS = "394" -RPL_NOUSERS = "395" -ERR_UNKNOWNERROR = "400" -ERR_NOSUCHNICK = "401" -ERR_NOSUCHSERVER = "402" -ERR_NOSUCHCHANNEL = "403" -ERR_CANNOTSENDTOCHAN = "404" -ERR_TOOMANYCHANNELS = "405" -ERR_WASNOSUCHNICK = "406" -ERR_TOOMANYTARGETS = "407" -ERR_NOSUCHSERVICE = "408" -ERR_NOORIGIN = "409" -ERR_INVALIDCAPCMD = "410" -ERR_NORECIPIENT = "411" -ERR_NOTEXTTOSEND = "412" -ERR_NOTOPLEVEL = "413" -ERR_WILDTOPLEVEL = "414" -ERR_BADMASK = "415" -ERR_INPUTTOOLONG = "417" -ERR_UNKNOWNCOMMAND = "421" -ERR_NOMOTD = "422" -ERR_NOADMININFO = "423" -ERR_FILEERROR = "424" -ERR_NONICKNAMEGIVEN = "431" -ERR_ERRONEUSNICKNAME = "432" -ERR_NICKNAMEINUSE = "433" -ERR_NICKCOLLISION = "436" -ERR_UNAVAILRESOURCE = "437" -ERR_REG_UNAVAILABLE = "440" -ERR_USERNOTINCHANNEL = "441" -ERR_NOTONCHANNEL = "442" -ERR_USERONCHANNEL = "443" -ERR_NOLOGIN = "444" -ERR_SUMMONDISABLED = "445" -ERR_USERSDISABLED = "446" -ERR_NOTREGISTERED = "451" -ERR_NEEDMOREPARAMS = "461" -ERR_ALREADYREGISTRED = "462" -ERR_NOPERMFORHOST = "463" -ERR_PASSWDMISMATCH = "464" -ERR_YOUREBANNEDCREEP = "465" -ERR_YOUWILLBEBANNED = "466" -ERR_KEYSET = "467" -ERR_INVALIDUSERNAME = "468" -ERR_LINKCHANNEL = "470" -ERR_CHANNELISFULL = "471" -ERR_UNKNOWNMODE = "472" -ERR_INVITEONLYCHAN = "473" -ERR_BANNEDFROMCHAN = "474" -ERR_BADCHANNELKEY = "475" -ERR_BADCHANMASK = "476" -ERR_NOCHANMODES = "477" -ERR_NEEDREGGEDNICK = "477" -ERR_BANLISTFULL = "478" -ERR_NOPRIVILEGES = "481" -ERR_CHANOPRIVSNEEDED = "482" -ERR_CANTKILLSERVER = "483" -ERR_RESTRICTED = "484" -ERR_UNIQOPPRIVSNEEDED = "485" -ERR_NOOPERHOST = "491" -ERR_UMODEUNKNOWNFLAG = "501" -ERR_USERSDONTMATCH = "502" -ERR_HELPNOTFOUND = "524" -ERR_CANNOTSENDRP = "573" -RPL_WHOISSECURE = "671" -RPL_YOURLANGUAGESARE = "687" -RPL_WHOISLANGUAGE = "690" -ERR_INVALIDMODEPARAM = "696" -RPL_HELPSTART = "704" -RPL_HELPTXT = "705" -RPL_ENDOFHELP = "706" -ERR_NOPRIVS = "723" -RPL_MONONLINE = "730" -RPL_MONOFFLINE = "731" -RPL_MONLIST = "732" -RPL_ENDOFMONLIST = "733" -ERR_MONLISTFULL = "734" -RPL_LOGGEDIN = "900" -RPL_LOGGEDOUT = "901" -ERR_NICKLOCKED = "902" -RPL_SASLSUCCESS = "903" -ERR_SASLFAIL = "904" -ERR_SASLTOOLONG = "905" -ERR_SASLABORTED = "906" -ERR_SASLALREADY = "907" -RPL_SASLMECHS = "908" -RPL_REGISTRATION_SUCCESS = "920" -ERR_ACCOUNT_ALREADY_EXISTS = "921" -ERR_REG_UNSPECIFIED_ERROR = "922" -RPL_VERIFYSUCCESS = "923" -ERR_ACCOUNT_ALREADY_VERIFIED = "924" +RPL_WELCOME = "001" +RPL_YOURHOST = "002" +RPL_CREATED = "003" +RPL_MYINFO = "004" +RPL_ISUPPORT = "005" +RPL_SNOMASKIS = "008" +RPL_BOUNCE = "010" +RPL_TRACELINK = "200" +RPL_TRACECONNECTING = "201" +RPL_TRACEHANDSHAKE = "202" +RPL_TRACEUNKNOWN = "203" +RPL_TRACEOPERATOR = "204" +RPL_TRACEUSER = "205" +RPL_TRACESERVER = "206" +RPL_TRACESERVICE = "207" +RPL_TRACENEWTYPE = "208" +RPL_TRACECLASS = "209" +RPL_TRACERECONNECT = "210" +RPL_STATSLINKINFO = "211" +RPL_STATSCOMMANDS = "212" +RPL_ENDOFSTATS = "219" +RPL_UMODEIS = "221" +RPL_SERVLIST = "234" +RPL_SERVLISTEND = "235" +RPL_STATSUPTIME = "242" +RPL_STATSOLINE = "243" +RPL_LUSERCLIENT = "251" +RPL_LUSEROP = "252" +RPL_LUSERUNKNOWN = "253" +RPL_LUSERCHANNELS = "254" +RPL_LUSERME = "255" +RPL_ADMINME = "256" +RPL_ADMINLOC1 = "257" +RPL_ADMINLOC2 = "258" +RPL_ADMINEMAIL = "259" +RPL_TRACELOG = "261" +RPL_TRACEEND = "262" +RPL_TRYAGAIN = "263" +RPL_LOCALUSERS = "265" +RPL_GLOBALUSERS = "266" +RPL_WHOISCERTFP = "276" +RPL_AWAY = "301" +RPL_USERHOST = "302" +RPL_ISON = "303" +RPL_UNAWAY = "305" +RPL_NOWAWAY = "306" +RPL_WHOISUSER = "311" +RPL_WHOISSERVER = "312" +RPL_WHOISOPERATOR = "313" +RPL_WHOWASUSER = "314" +RPL_ENDOFWHO = "315" +RPL_WHOISIDLE = "317" +RPL_ENDOFWHOIS = "318" +RPL_WHOISCHANNELS = "319" +RPL_LIST = "322" +RPL_LISTEND = "323" +RPL_CHANNELMODEIS = "324" +RPL_UNIQOPIS = "325" +RPL_CHANNELCREATED = "329" +RPL_WHOISACCOUNT = "330" +RPL_NOTOPIC = "331" +RPL_TOPIC = "332" +RPL_TOPICTIME = "333" +RPL_WHOISBOT = "335" +RPL_WHOISACTUALLY = "338" +RPL_INVITING = "341" +RPL_SUMMONING = "342" +RPL_INVITELIST = "346" +RPL_ENDOFINVITELIST = "347" +RPL_EXCEPTLIST = "348" +RPL_ENDOFEXCEPTLIST = "349" +RPL_VERSION = "351" +RPL_WHOREPLY = "352" +RPL_NAMREPLY = "353" +RPL_LINKS = "364" +RPL_ENDOFLINKS = "365" +RPL_ENDOFNAMES = "366" +RPL_BANLIST = "367" +RPL_ENDOFBANLIST = "368" +RPL_ENDOFWHOWAS = "369" +RPL_INFO = "371" +RPL_MOTD = "372" +RPL_ENDOFINFO = "374" +RPL_MOTDSTART = "375" +RPL_ENDOFMOTD = "376" +RPL_YOUREOPER = "381" +RPL_REHASHING = "382" +RPL_YOURESERVICE = "383" +RPL_TIME = "391" +RPL_USERSSTART = "392" +RPL_USERS = "393" +RPL_ENDOFUSERS = "394" +RPL_NOUSERS = "395" +ERR_UNKNOWNERROR = "400" +ERR_NOSUCHNICK = "401" +ERR_NOSUCHSERVER = "402" +ERR_NOSUCHCHANNEL = "403" +ERR_CANNOTSENDTOCHAN = "404" +ERR_TOOMANYCHANNELS = "405" +ERR_WASNOSUCHNICK = "406" +ERR_TOOMANYTARGETS = "407" +ERR_NOSUCHSERVICE = "408" +ERR_NOORIGIN = "409" +ERR_INVALIDCAPCMD = "410" +ERR_NORECIPIENT = "411" +ERR_NOTEXTTOSEND = "412" +ERR_NOTOPLEVEL = "413" +ERR_WILDTOPLEVEL = "414" +ERR_BADMASK = "415" +ERR_INPUTTOOLONG = "417" +ERR_UNKNOWNCOMMAND = "421" +ERR_NOMOTD = "422" +ERR_NOADMININFO = "423" +ERR_FILEERROR = "424" +ERR_NONICKNAMEGIVEN = "431" +ERR_ERRONEUSNICKNAME = "432" +ERR_NICKNAMEINUSE = "433" +ERR_NICKCOLLISION = "436" +ERR_UNAVAILRESOURCE = "437" +ERR_REG_UNAVAILABLE = "440" +ERR_USERNOTINCHANNEL = "441" +ERR_NOTONCHANNEL = "442" +ERR_USERONCHANNEL = "443" +ERR_NOLOGIN = "444" +ERR_SUMMONDISABLED = "445" +ERR_USERSDISABLED = "446" +ERR_NOTREGISTERED = "451" +ERR_NEEDMOREPARAMS = "461" +ERR_ALREADYREGISTRED = "462" +ERR_NOPERMFORHOST = "463" +ERR_PASSWDMISMATCH = "464" +ERR_YOUREBANNEDCREEP = "465" +ERR_YOUWILLBEBANNED = "466" +ERR_KEYSET = "467" +ERR_INVALIDUSERNAME = "468" +ERR_LINKCHANNEL = "470" +ERR_CHANNELISFULL = "471" +ERR_UNKNOWNMODE = "472" +ERR_INVITEONLYCHAN = "473" +ERR_BANNEDFROMCHAN = "474" +ERR_BADCHANNELKEY = "475" +ERR_BADCHANMASK = "476" +ERR_NOCHANMODES = "477" +ERR_NEEDREGGEDNICK = "477" +ERR_BANLISTFULL = "478" +ERR_NOPRIVILEGES = "481" +ERR_CHANOPRIVSNEEDED = "482" +ERR_CANTKILLSERVER = "483" +ERR_RESTRICTED = "484" +ERR_UNIQOPPRIVSNEEDED = "485" +ERR_NOOPERHOST = "491" +ERR_UMODEUNKNOWNFLAG = "501" +ERR_USERSDONTMATCH = "502" +ERR_HELPNOTFOUND = "524" +ERR_CANNOTSENDRP = "573" +RPL_WHOISSECURE = "671" +RPL_YOURLANGUAGESARE = "687" +RPL_WHOISLANGUAGE = "690" +ERR_INVALIDMODEPARAM = "696" +RPL_HELPSTART = "704" +RPL_HELPTXT = "705" +RPL_ENDOFHELP = "706" +ERR_NOPRIVS = "723" +RPL_MONONLINE = "730" +RPL_MONOFFLINE = "731" +RPL_MONLIST = "732" +RPL_ENDOFMONLIST = "733" +ERR_MONLISTFULL = "734" +RPL_LOGGEDIN = "900" +RPL_LOGGEDOUT = "901" +ERR_NICKLOCKED = "902" +RPL_SASLSUCCESS = "903" +ERR_SASLFAIL = "904" +ERR_SASLTOOLONG = "905" +ERR_SASLABORTED = "906" +ERR_SASLALREADY = "907" +RPL_SASLMECHS = "908" +RPL_REGISTRATION_SUCCESS = "920" +ERR_ACCOUNT_ALREADY_EXISTS = "921" +ERR_REG_UNSPECIFIED_ERROR = "922" +RPL_VERIFYSUCCESS = "923" +ERR_ACCOUNT_ALREADY_VERIFIED = "924" ERR_ACCOUNT_INVALID_VERIFY_CODE = "925" -RPL_REG_VERIFICATION_REQUIRED = "927" -ERR_REG_INVALID_CRED_TYPE = "928" -ERR_REG_INVALID_CALLBACK = "929" -ERR_TOOMANYLANGUAGES = "981" -ERR_NOLANGUAGE = "982" +RPL_REG_VERIFICATION_REQUIRED = "927" +ERR_REG_INVALID_CRED_TYPE = "928" +ERR_REG_INVALID_CALLBACK = "929" +ERR_TOOMANYLANGUAGES = "981" +ERR_NOLANGUAGE = "982" diff --git a/irctest/runner.py b/irctest/runner.py index 2754b90..3b2e9a9 100644 --- a/irctest/runner.py +++ b/irctest/runner.py @@ -2,47 +2,59 @@ import unittest import operator import collections + class NotImplementedByController(unittest.SkipTest, NotImplementedError): def __str__(self): - return 'Not implemented by controller: {}'.format(self.args[0]) + return "Not implemented by controller: {}".format(self.args[0]) + class ImplementationChoice(unittest.SkipTest): def __str__(self): - return 'Choice in the implementation makes it impossible to ' \ - 'perform a test: {}'.format(self.args[0]) + return ( + "Choice in the implementation makes it impossible to " + "perform a test: {}".format(self.args[0]) + ) + class OptionalExtensionNotSupported(unittest.SkipTest): def __str__(self): - return 'Unsupported extension: {}'.format(self.args[0]) + return "Unsupported extension: {}".format(self.args[0]) + class OptionalSaslMechanismNotSupported(unittest.SkipTest): def __str__(self): - return 'Unsupported SASL mechanism: {}'.format(self.args[0]) + return "Unsupported SASL mechanism: {}".format(self.args[0]) + class CapabilityNotSupported(unittest.SkipTest): def __str__(self): - return 'Unsupported capability: {}'.format(self.args[0]) + return "Unsupported capability: {}".format(self.args[0]) + class NotRequiredBySpecifications(unittest.SkipTest): def __str__(self): - return 'Tests not required by the set of tested specification(s).' + return "Tests not required by the set of tested specification(s)." + class SkipStrictTest(unittest.SkipTest): def __str__(self): - return 'Tests not required because strict tests are disabled.' + return "Tests not required because strict tests are disabled." + class TextTestResult(unittest.TextTestResult): def getDescription(self, test): - if hasattr(test, 'description'): + if hasattr(test, "description"): doc_first_lines = test.description() else: doc_first_lines = test.shortDescription() - return '\n'.join((str(test), doc_first_lines or '')) + return "\n".join((str(test), doc_first_lines or "")) + class TextTestRunner(unittest.TextTestRunner): """Small wrapper around unittest.TextTestRunner that reports the number of tests that were skipped because the software does not support an optional feature.""" + resultclass = TextTestResult def run(self, test): @@ -50,11 +62,13 @@ class TextTestRunner(unittest.TextTestRunner): assert self.resultclass is TextTestResult if result.skipped: print() - print('Some tests were skipped because the following optional ' - 'specifications/mechanisms are not supported:') + print( + "Some tests were skipped because the following optional " + "specifications/mechanisms are not supported:" + ) msg_to_count = collections.defaultdict(lambda: 0) for (test, msg) in result.skipped: msg_to_count[msg] += 1 for (msg, count) in sorted(msg_to_count.items()): - print('\t{} ({} test(s))'.format(msg, count)) + print("\t{} ({} test(s))".format(msg, count)) return result diff --git a/irctest/server_tests/test_account_tag.py b/irctest/server_tests/test_account_tag.py index 844672b..49acb0e 100644 --- a/irctest/server_tests/test_account_tag.py +++ b/irctest/server_tests/test_account_tag.py @@ -4,44 +4,62 @@ from irctest import cases + class AccountTagTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): def connectRegisteredClient(self, nick): self.addClient() - self.sendLine(2, 'CAP LS 302') + self.sendLine(2, "CAP LS 302") capabilities = self.getCapLs(2) - assert 'sasl' in capabilities - self.sendLine(2, 'AUTHENTICATE PLAIN') - m = self.getMessage(2, filter_pred=lambda m:m.command != 'NOTICE') - self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'], - fail_msg='Sent “AUTHENTICATE PLAIN”, server should have ' - 'replied with “AUTHENTICATE +”, but instead sent: {msg}') - self.sendLine(2, 'AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=') - m = self.getMessage(2, filter_pred=lambda m:m.command != 'NOTICE') - self.assertMessageEqual(m, command='900', - fail_msg='Did not send 900 after correct SASL authentication.') - self.sendLine(2, 'USER f * * :Realname') - self.sendLine(2, 'NICK {}'.format(nick)) - self.sendLine(2, 'CAP END') + assert "sasl" in capabilities + self.sendLine(2, "AUTHENTICATE PLAIN") + m = self.getMessage(2, filter_pred=lambda m: m.command != "NOTICE") + self.assertMessageEqual( + m, + command="AUTHENTICATE", + params=["+"], + fail_msg="Sent “AUTHENTICATE PLAIN”, server should have " + "replied with “AUTHENTICATE +”, but instead sent: {msg}", + ) + self.sendLine(2, "AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=") + m = self.getMessage(2, filter_pred=lambda m: m.command != "NOTICE") + self.assertMessageEqual( + m, + command="900", + fail_msg="Did not send 900 after correct SASL authentication.", + ) + self.sendLine(2, "USER f * * :Realname") + self.sendLine(2, "NICK {}".format(nick)) + self.sendLine(2, "CAP END") self.skipToWelcome(2) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') - @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") + @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testPrivmsg(self): - self.connectClient('foo', capabilities=['account-tag'], - skip_if_cap_nak=True) + self.connectClient("foo", capabilities=["account-tag"], skip_if_cap_nak=True) self.getMessages(1) - self.controller.registerUser(self, 'jilles', 'sesame') - self.connectRegisteredClient('bar') - self.sendLine(2, 'PRIVMSG foo :hi') + self.controller.registerUser(self, "jilles", "sesame") + self.connectRegisteredClient("bar") + self.sendLine(2, "PRIVMSG foo :hi") self.getMessages(2) m = self.getMessage(1) - self.assertMessageEqual(m, command='PRIVMSG', # RPL_MONONLINE - fail_msg='Sent non-730 (RPL_MONONLINE) message after ' - '“bar” sent a PRIVMSG: {msg}') - self.assertIn('account', m.tags, m, - fail_msg='PRIVMSG by logged in nick ' - 'does not contain an account tag: {msg}') - self.assertEqual(m.tags['account'], 'jilles', m, - fail_msg='PRIVMSG by logged in nick ' - 'does not contain the correct account tag (should be ' - '“jilles”): {msg}') + self.assertMessageEqual( + m, + command="PRIVMSG", # RPL_MONONLINE + fail_msg="Sent non-730 (RPL_MONONLINE) message after " + "“bar” sent a PRIVMSG: {msg}", + ) + self.assertIn( + "account", + m.tags, + m, + fail_msg="PRIVMSG by logged in nick " + "does not contain an account tag: {msg}", + ) + self.assertEqual( + m.tags["account"], + "jilles", + m, + fail_msg="PRIVMSG by logged in nick " + "does not contain the correct account tag (should be " + "“jilles”): {msg}", + ) diff --git a/irctest/server_tests/test_away_notify.py b/irctest/server_tests/test_away_notify.py index d1d2bf0..a494f30 100644 --- a/irctest/server_tests/test_away_notify.py +++ b/irctest/server_tests/test_away_notify.py @@ -4,49 +4,55 @@ from irctest import cases -class AwayNotifyTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') +class AwayNotifyTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") def testAwayNotify(self): """Basic away-notify test.""" - self.connectClient('foo', capabilities=['away-notify'], skip_if_cap_nak=True) + self.connectClient("foo", capabilities=["away-notify"], skip_if_cap_nak=True) self.getMessages(1) - self.joinChannel(1, '#chan') + self.joinChannel(1, "#chan") - self.connectClient('bar') + self.connectClient("bar") self.getMessages(2) - self.joinChannel(2, '#chan') + self.joinChannel(2, "#chan") self.getMessages(2) self.getMessages(1) self.sendLine(2, "AWAY :i'm going away") self.getMessages(2) - messages = [msg for msg in self.getMessages(1) if msg.command == 'AWAY'] + messages = [msg for msg in self.getMessages(1) if msg.command == "AWAY"] self.assertEqual(len(messages), 1) awayNotify = messages[0] - self.assertTrue(awayNotify.prefix.startswith('bar!'), 'Unexpected away-notify source: %s' % (awayNotify.prefix,)) + self.assertTrue( + awayNotify.prefix.startswith("bar!"), + "Unexpected away-notify source: %s" % (awayNotify.prefix,), + ) self.assertEqual(awayNotify.params, ["i'm going away"]) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testAwayNotifyOnJoin(self): """The away-notify specification states: "Clients will be sent an AWAY message [...] when a user joins and has an away message set." """ - self.connectClient('foo', capabilities=['away-notify'], skip_if_cap_nak=True) + self.connectClient("foo", capabilities=["away-notify"], skip_if_cap_nak=True) self.getMessages(1) - self.joinChannel(1, '#chan') + self.joinChannel(1, "#chan") - self.connectClient('bar') + self.connectClient("bar") self.getMessages(2) self.sendLine(2, "AWAY :i'm already away") self.getMessages(2) - self.joinChannel(2, '#chan') + self.joinChannel(2, "#chan") self.getMessages(2) - messages = [msg for msg in self.getMessages(1) if msg.command == 'AWAY'] + messages = [msg for msg in self.getMessages(1) if msg.command == "AWAY"] self.assertEqual(len(messages), 1) awayNotify = messages[0] - self.assertTrue(awayNotify.prefix.startswith('bar!'), 'Unexpected away-notify source: %s' % (awayNotify.prefix,)) + self.assertTrue( + awayNotify.prefix.startswith("bar!"), + "Unexpected away-notify source: %s" % (awayNotify.prefix,), + ) self.assertEqual(awayNotify.params, ["i'm already away"]) diff --git a/irctest/server_tests/test_bouncer.py b/irctest/server_tests/test_bouncer.py index 8a6d28b..2087647 100644 --- a/irctest/server_tests/test_bouncer.py +++ b/irctest/server_tests/test_bouncer.py @@ -4,139 +4,155 @@ from irctest.irc_utils.sasl import sasl_plain_blob from irctest.numerics import RPL_WELCOME from irctest.numerics import ERR_NICKNAMEINUSE -class Bouncer(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('Oragono') +class Bouncer(cases.BaseServerTestCase): + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testBouncer(self): """Test basic bouncer functionality.""" - self.controller.registerUser(self, 'observer', 'observerpassword') - self.controller.registerUser(self, 'testuser', 'mypassword') + self.controller.registerUser(self, "observer", "observerpassword") + self.controller.registerUser(self, "testuser", "mypassword") - self.connectClient('observer', password='observerpassword') - self.joinChannel(1, '#chan') - self.sendLine(1, 'CAP REQ :message-tags server-time') + self.connectClient("observer", password="observerpassword") + self.joinChannel(1, "#chan") + self.sendLine(1, "CAP REQ :message-tags server-time") self.getMessages(1) self.addClient() - self.sendLine(2, 'CAP LS 302') - self.sendLine(2, 'AUTHENTICATE PLAIN') - self.sendLine(2, sasl_plain_blob('testuser', 'mypassword')) - self.sendLine(2, 'NICK testnick') - self.sendLine(2, 'USER a 0 * a') - self.sendLine(2, 'CAP REQ :server-time message-tags') - self.sendLine(2, 'CAP END') + self.sendLine(2, "CAP LS 302") + self.sendLine(2, "AUTHENTICATE PLAIN") + self.sendLine(2, sasl_plain_blob("testuser", "mypassword")) + self.sendLine(2, "NICK testnick") + self.sendLine(2, "USER a 0 * a") + self.sendLine(2, "CAP REQ :server-time message-tags") + self.sendLine(2, "CAP END") messages = self.getMessages(2) welcomes = [message for message in messages if message.command == RPL_WELCOME] self.assertEqual(len(welcomes), 1) # should see a regburst for testnick - self.assertEqual(welcomes[0].params[0], 'testnick') - self.joinChannel(2, '#chan') + self.assertEqual(welcomes[0].params[0], "testnick") + self.joinChannel(2, "#chan") self.addClient() - self.sendLine(3, 'CAP LS 302') - self.sendLine(3, 'AUTHENTICATE PLAIN') - self.sendLine(3, sasl_plain_blob('testuser', 'mypassword')) - self.sendLine(3, 'NICK testnick') - self.sendLine(3, 'USER a 0 * a') - self.sendLine(3, 'CAP REQ :server-time message-tags account-tag') - self.sendLine(3, 'CAP END') + self.sendLine(3, "CAP LS 302") + self.sendLine(3, "AUTHENTICATE PLAIN") + self.sendLine(3, sasl_plain_blob("testuser", "mypassword")) + self.sendLine(3, "NICK testnick") + self.sendLine(3, "USER a 0 * a") + self.sendLine(3, "CAP REQ :server-time message-tags account-tag") + self.sendLine(3, "CAP END") messages = self.getMessages(3) welcomes = [message for message in messages if message.command == RPL_WELCOME] self.assertEqual(len(welcomes), 1) # should see the *same* regburst for testnick - self.assertEqual(welcomes[0].params[0], 'testnick') - joins = [message for message in messages if message.command == 'JOIN'] + self.assertEqual(welcomes[0].params[0], "testnick") + joins = [message for message in messages if message.command == "JOIN"] # we should be automatically joined to #chan - self.assertEqual(joins[0].params[0], '#chan') + self.assertEqual(joins[0].params[0], "#chan") # disable multiclient in nickserv - self.sendLine(3, 'NS SET MULTICLIENT OFF') + self.sendLine(3, "NS SET MULTICLIENT OFF") self.getMessages(3) self.addClient() - self.sendLine(4, 'CAP LS 302') - self.sendLine(4, 'AUTHENTICATE PLAIN') - self.sendLine(4, sasl_plain_blob('testuser', 'mypassword')) - self.sendLine(4, 'NICK testnick') - self.sendLine(4, 'USER a 0 * a') - self.sendLine(4, 'CAP REQ :server-time message-tags') - self.sendLine(4, 'CAP END') + self.sendLine(4, "CAP LS 302") + self.sendLine(4, "AUTHENTICATE PLAIN") + self.sendLine(4, sasl_plain_blob("testuser", "mypassword")) + self.sendLine(4, "NICK testnick") + self.sendLine(4, "USER a 0 * a") + self.sendLine(4, "CAP REQ :server-time message-tags") + self.sendLine(4, "CAP END") # with multiclient disabled, we should not be able to attach to the nick messages = self.getMessages(4) welcomes = [message for message in messages if message.command == RPL_WELCOME] self.assertEqual(len(welcomes), 0) - errors = [message for message in messages if message.command == ERR_NICKNAMEINUSE] + errors = [ + message for message in messages if message.command == ERR_NICKNAMEINUSE + ] self.assertEqual(len(errors), 1) - self.sendLine(3, 'NS SET MULTICLIENT ON') + self.sendLine(3, "NS SET MULTICLIENT ON") self.getMessages(3) self.addClient() - self.sendLine(5, 'CAP LS 302') - self.sendLine(5, 'AUTHENTICATE PLAIN') - self.sendLine(5, sasl_plain_blob('testuser', 'mypassword')) - self.sendLine(5, 'NICK testnick') - self.sendLine(5, 'USER a 0 * a') - self.sendLine(5, 'CAP REQ server-time') - self.sendLine(5, 'CAP END') + self.sendLine(5, "CAP LS 302") + self.sendLine(5, "AUTHENTICATE PLAIN") + self.sendLine(5, sasl_plain_blob("testuser", "mypassword")) + self.sendLine(5, "NICK testnick") + self.sendLine(5, "USER a 0 * a") + self.sendLine(5, "CAP REQ server-time") + self.sendLine(5, "CAP END") messages = self.getMessages(5) welcomes = [message for message in messages if message.command == RPL_WELCOME] self.assertEqual(len(welcomes), 1) - self.sendLine(1, '@+clientOnlyTag=Value PRIVMSG #chan :hey') + self.sendLine(1, "@+clientOnlyTag=Value PRIVMSG #chan :hey") self.getMessages(1) - messagesfortwo = [msg for msg in self.getMessages(2) if msg.command == 'PRIVMSG'] - messagesforthree = [msg for msg in self.getMessages(3) if msg.command == 'PRIVMSG'] + messagesfortwo = [ + msg for msg in self.getMessages(2) if msg.command == "PRIVMSG" + ] + messagesforthree = [ + msg for msg in self.getMessages(3) if msg.command == "PRIVMSG" + ] self.assertEqual(len(messagesfortwo), 1) self.assertEqual(len(messagesforthree), 1) messagefortwo = messagesfortwo[0] messageforthree = messagesforthree[0] messageforfive = self.getMessage(5) - self.assertEqual(messagefortwo.params, ['#chan', 'hey']) - self.assertEqual(messageforthree.params, ['#chan', 'hey']) - self.assertEqual(messageforfive.params, ['#chan', 'hey']) - self.assertIn('time', messagefortwo.tags) - self.assertIn('time', messageforthree.tags) - self.assertIn('time', messageforfive.tags) + self.assertEqual(messagefortwo.params, ["#chan", "hey"]) + self.assertEqual(messageforthree.params, ["#chan", "hey"]) + self.assertEqual(messageforfive.params, ["#chan", "hey"]) + self.assertIn("time", messagefortwo.tags) + self.assertIn("time", messageforthree.tags) + self.assertIn("time", messageforfive.tags) # 3 has account-tag - self.assertIn('account', messageforthree.tags) + self.assertIn("account", messageforthree.tags) # should get same msgid - self.assertEqual(messagefortwo.tags['msgid'], messageforthree.tags['msgid']) + self.assertEqual(messagefortwo.tags["msgid"], messageforthree.tags["msgid"]) # 5 only has server-time, shouldn't get account or msgid tags - self.assertNotIn('account', messageforfive.tags) - self.assertNotIn('msgid', messageforfive.tags) + self.assertNotIn("account", messageforfive.tags) + self.assertNotIn("msgid", messageforfive.tags) # test that copies of sent messages go out to other sessions - self.sendLine(2, 'PRIVMSG observer :this is a direct message') + self.sendLine(2, "PRIVMSG observer :this is a direct message") self.getMessages(2) - messageForRecipient = [msg for msg in self.getMessages(1) if msg.command == 'PRIVMSG'][0] - copyForOtherSession = [msg for msg in self.getMessages(3) if msg.command == 'PRIVMSG'][0] + messageForRecipient = [ + msg for msg in self.getMessages(1) if msg.command == "PRIVMSG" + ][0] + copyForOtherSession = [ + msg for msg in self.getMessages(3) if msg.command == "PRIVMSG" + ][0] self.assertEqual(messageForRecipient.params, copyForOtherSession.params) - self.assertEqual(messageForRecipient.tags['msgid'], copyForOtherSession.tags['msgid']) + self.assertEqual( + messageForRecipient.tags["msgid"], copyForOtherSession.tags["msgid"] + ) - self.sendLine(2, 'QUIT :two out') - quitLines = [msg for msg in self.getMessages(2) if msg.command == 'QUIT'] + self.sendLine(2, "QUIT :two out") + quitLines = [msg for msg in self.getMessages(2) if msg.command == "QUIT"] self.assertEqual(len(quitLines), 1) - self.assertIn('two out', quitLines[0].params[0]) + self.assertIn("two out", quitLines[0].params[0]) # neither the observer nor the other attached session should see a quit here - quitLines = [msg for msg in self.getMessages(1) if msg.command == 'QUIT'] + quitLines = [msg for msg in self.getMessages(1) if msg.command == "QUIT"] self.assertEqual(quitLines, []) - quitLines = [msg for msg in self.getMessages(3) if msg.command == 'QUIT'] + quitLines = [msg for msg in self.getMessages(3) if msg.command == "QUIT"] self.assertEqual(quitLines, []) # session 3 should be untouched at this point - self.sendLine(1, '@+clientOnlyTag=Value PRIVMSG #chan :hey again') + self.sendLine(1, "@+clientOnlyTag=Value PRIVMSG #chan :hey again") self.getMessages(1) - messagesforthree = [msg for msg in self.getMessages(3) if msg.command == 'PRIVMSG'] + messagesforthree = [ + msg for msg in self.getMessages(3) if msg.command == "PRIVMSG" + ] self.assertEqual(len(messagesforthree), 1) - self.assertMessageEqual(messagesforthree[0], command='PRIVMSG', params=['#chan', 'hey again']) + self.assertMessageEqual( + messagesforthree[0], command="PRIVMSG", params=["#chan", "hey again"] + ) - self.sendLine(5, 'QUIT :five out') + self.sendLine(5, "QUIT :five out") self.getMessages(5) - self.sendLine(3, 'QUIT :three out') - quitLines = [msg for msg in self.getMessages(3) if msg.command == 'QUIT'] + self.sendLine(3, "QUIT :three out") + quitLines = [msg for msg in self.getMessages(3) if msg.command == "QUIT"] self.assertEqual(len(quitLines), 1) - self.assertIn('three out', quitLines[0].params[0]) + self.assertIn("three out", quitLines[0].params[0]) # observer should see *this* quit - quitLines = [msg for msg in self.getMessages(1) if msg.command == 'QUIT'] + quitLines = [msg for msg in self.getMessages(1) if msg.command == "QUIT"] self.assertEqual(len(quitLines), 1) - self.assertIn('three out', quitLines[0].params[0]) + self.assertIn("three out", quitLines[0].params[0]) diff --git a/irctest/server_tests/test_cap.py b/irctest/server_tests/test_cap.py index 2bb7b98..d90592a 100644 --- a/irctest/server_tests/test_cap.py +++ b/irctest/server_tests/test_cap.py @@ -1,7 +1,8 @@ from irctest import cases + class CapTestCase(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") def testNoReq(self): """Test the server handles gracefully clients which do not send REQs. @@ -11,38 +12,44 @@ class CapTestCase(cases.BaseServerTestCase): -- """ self.addClient(1) - self.sendLine(1, 'CAP LS 302') + self.sendLine(1, "CAP LS 302") self.getCapLs(1) - self.sendLine(1, 'USER foo foo foo :foo') - self.sendLine(1, 'NICK foo') - self.sendLine(1, 'CAP END') + self.sendLine(1, "USER foo foo foo :foo") + self.sendLine(1, "NICK foo") + self.sendLine(1, "CAP END") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='001', - fail_msg='Expected 001 after sending CAP END, got {msg}.') + self.assertMessageEqual( + m, command="001", fail_msg="Expected 001 after sending CAP END, got {msg}." + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") def testReqUnavailable(self): """Test the server handles gracefully clients which request capabilities that are not available. """ self.addClient(1) - self.sendLine(1, 'CAP LS 302') + self.sendLine(1, "CAP LS 302") self.getCapLs(1) - self.sendLine(1, 'USER foo foo foo :foo') - self.sendLine(1, 'NICK foo') - self.sendLine(1, 'CAP REQ :foo') + self.sendLine(1, "USER foo foo foo :foo") + self.sendLine(1, "NICK foo") + self.sendLine(1, "CAP REQ :foo") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='CAP', - subcommand='NAK', subparams=['foo'], - fail_msg='Expected CAP NAK after requesting non-existing ' - 'capability, got {msg}.') - self.sendLine(1, 'CAP END') + self.assertMessageEqual( + m, + command="CAP", + subcommand="NAK", + subparams=["foo"], + fail_msg="Expected CAP NAK after requesting non-existing " + "capability, got {msg}.", + ) + self.sendLine(1, "CAP END") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='001', - fail_msg='Expected 001 after sending CAP END, got {msg}.') + self.assertMessageEqual( + m, command="001", fail_msg="Expected 001 after sending CAP END, got {msg}." + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") def testNakExactString(self): """“The argument of the NAK subcommand MUST consist of at least the first 100 characters of the capability list in the REQ subcommand which @@ -50,78 +57,100 @@ class CapTestCase(cases.BaseServerTestCase): -- """ self.addClient(1) - self.sendLine(1, 'CAP LS 302') + self.sendLine(1, "CAP LS 302") self.getCapLs(1) # Five should be enough to check there is no reordering, even # alphabetical - self.sendLine(1, 'CAP REQ :foo qux bar baz qux quux') + self.sendLine(1, "CAP REQ :foo qux bar baz qux quux") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='CAP', - subcommand='NAK', subparams=['foo qux bar baz qux quux'], - fail_msg='Expected “CAP NAK :foo qux bar baz qux quux” after ' - 'sending “CAP REQ :foo qux bar baz qux quux”, but got {msg}.') + self.assertMessageEqual( + m, + command="CAP", + subcommand="NAK", + subparams=["foo qux bar baz qux quux"], + fail_msg="Expected “CAP NAK :foo qux bar baz qux quux” after " + "sending “CAP REQ :foo qux bar baz qux quux”, but got {msg}.", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") def testNakWhole(self): """“The capability identifier set must be accepted as a whole, or rejected entirely.” -- """ self.addClient(1) - self.sendLine(1, 'CAP LS 302') - self.assertIn('multi-prefix', self.getCapLs(1)) - self.sendLine(1, 'CAP REQ :foo multi-prefix bar') + self.sendLine(1, "CAP LS 302") + self.assertIn("multi-prefix", self.getCapLs(1)) + self.sendLine(1, "CAP REQ :foo multi-prefix bar") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='CAP', - subcommand='NAK', subparams=['foo multi-prefix bar'], - fail_msg='Expected “CAP NAK :foo multi-prefix bar” after ' - 'sending “CAP REQ :foo multi-prefix bar”, but got {msg}.') - self.sendLine(1, 'CAP REQ :multi-prefix bar') + self.assertMessageEqual( + m, + command="CAP", + subcommand="NAK", + subparams=["foo multi-prefix bar"], + fail_msg="Expected “CAP NAK :foo multi-prefix bar” after " + "sending “CAP REQ :foo multi-prefix bar”, but got {msg}.", + ) + self.sendLine(1, "CAP REQ :multi-prefix bar") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='CAP', - subcommand='NAK', subparams=['multi-prefix bar'], - fail_msg='Expected “CAP NAK :multi-prefix bar” after ' - 'sending “CAP REQ :multi-prefix bar”, but got {msg}.') - self.sendLine(1, 'CAP REQ :foo multi-prefix') + self.assertMessageEqual( + m, + command="CAP", + subcommand="NAK", + subparams=["multi-prefix bar"], + fail_msg="Expected “CAP NAK :multi-prefix bar” after " + "sending “CAP REQ :multi-prefix bar”, but got {msg}.", + ) + self.sendLine(1, "CAP REQ :foo multi-prefix") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='CAP', - subcommand='NAK', subparams=['foo multi-prefix'], - fail_msg='Expected “CAP NAK :foo multi-prefix” after ' - 'sending “CAP REQ :foo multi-prefix”, but got {msg}.') + self.assertMessageEqual( + m, + command="CAP", + subcommand="NAK", + subparams=["foo multi-prefix"], + fail_msg="Expected “CAP NAK :foo multi-prefix” after " + "sending “CAP REQ :foo multi-prefix”, but got {msg}.", + ) # TODO: make sure multi-prefix is not enabled at this point - self.sendLine(1, 'CAP REQ :multi-prefix') + self.sendLine(1, "CAP REQ :multi-prefix") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='CAP', - subcommand='ACK', subparams=['multi-prefix'], - fail_msg='Expected “CAP ACK :multi-prefix” after ' - 'sending “CAP REQ :multi-prefix”, but got {msg}.') + self.assertMessageEqual( + m, + command="CAP", + subcommand="ACK", + subparams=["multi-prefix"], + fail_msg="Expected “CAP ACK :multi-prefix” after " + "sending “CAP REQ :multi-prefix”, but got {msg}.", + ) - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testCapRemovalByClient(self): """Test CAP LIST and removal of caps via CAP REQ :-tagname.""" self.addClient(1) - self.sendLine(1, 'CAP LS 302') - self.assertIn('multi-prefix', self.getCapLs(1)) - self.sendLine(1, 'CAP REQ :echo-message server-time') - self.sendLine(1, 'nick bar') - self.sendLine(1, 'user user 0 * realname') - self.sendLine(1, 'CAP END') + self.sendLine(1, "CAP LS 302") + self.assertIn("multi-prefix", self.getCapLs(1)) + self.sendLine(1, "CAP REQ :echo-message server-time") + self.sendLine(1, "nick bar") + self.sendLine(1, "user user 0 * realname") + self.sendLine(1, "CAP END") self.skipToWelcome(1) self.getMessages(1) - self.sendLine(1, 'CAP LIST') + self.sendLine(1, "CAP LIST") messages = self.getMessages(1) - cap_list = [m for m in messages if m.command == 'CAP'][0] - self.assertEqual(set(cap_list.params[2].split()), {'echo-message', 'server-time'}) - self.assertIn('time', cap_list.tags) + cap_list = [m for m in messages if m.command == "CAP"][0] + self.assertEqual( + set(cap_list.params[2].split()), {"echo-message", "server-time"} + ) + self.assertIn("time", cap_list.tags) # remove the server-time cap - self.sendLine(1, 'CAP REQ :-server-time') + self.sendLine(1, "CAP REQ :-server-time") self.getMessages(1) # server-time should be disabled - self.sendLine(1, 'CAP LIST') + self.sendLine(1, "CAP LIST") messages = self.getMessages(1) - cap_list = [m for m in messages if m.command == 'CAP'][0] - self.assertEqual(set(cap_list.params[2].split()), {'echo-message'}) - self.assertNotIn('time', cap_list.tags) + cap_list = [m for m in messages if m.command == "CAP"][0] + self.assertEqual(set(cap_list.params[2].split()), {"echo-message"}) + self.assertNotIn("time", cap_list.tags) diff --git a/irctest/server_tests/test_channel_forward.py b/irctest/server_tests/test_channel_forward.py index 5f56afd..3670e40 100644 --- a/irctest/server_tests/test_channel_forward.py +++ b/irctest/server_tests/test_channel_forward.py @@ -1,44 +1,52 @@ from irctest import cases from irctest.numerics import ERR_CHANOPRIVSNEEDED, ERR_INVALIDMODEPARAM, ERR_LINKCHANNEL -MODERN_CAPS = ['server-time', 'message-tags', 'batch', 'labeled-response', 'echo-message', 'account-tag'] +MODERN_CAPS = [ + "server-time", + "message-tags", + "batch", + "labeled-response", + "echo-message", + "account-tag", +] + class ChannelForwarding(cases.BaseServerTestCase): """Test the +f channel forwarding mode.""" - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testChannelForwarding(self): - self.connectClient('bar', name='bar', capabilities=MODERN_CAPS) - self.connectClient('baz', name='baz', capabilities=MODERN_CAPS) - self.joinChannel('bar', '#bar') - self.joinChannel('bar', '#bar_two') - self.joinChannel('baz', '#baz') + self.connectClient("bar", name="bar", capabilities=MODERN_CAPS) + self.connectClient("baz", name="baz", capabilities=MODERN_CAPS) + self.joinChannel("bar", "#bar") + self.joinChannel("bar", "#bar_two") + self.joinChannel("baz", "#baz") - self.sendLine('bar', 'MODE #bar +f #nonexistent') - msg = self.getMessage('bar') + self.sendLine("bar", "MODE #bar +f #nonexistent") + msg = self.getMessage("bar") self.assertMessageEqual(msg, command=ERR_INVALIDMODEPARAM) # need chanops in the target channel as well - self.sendLine('bar', 'MODE #bar +f #baz') - responses = set(msg.command for msg in self.getMessages('bar')) + self.sendLine("bar", "MODE #bar +f #baz") + responses = set(msg.command for msg in self.getMessages("bar")) self.assertIn(ERR_CHANOPRIVSNEEDED, responses) - self.sendLine('bar', 'MODE #bar +f #bar_two') - msg = self.getMessage('bar') - self.assertMessageEqual(msg, command='MODE', params=['#bar', '+f', '#bar_two']) + self.sendLine("bar", "MODE #bar +f #bar_two") + msg = self.getMessage("bar") + self.assertMessageEqual(msg, command="MODE", params=["#bar", "+f", "#bar_two"]) # can still join the channel fine - self.joinChannel('baz', '#bar') - self.sendLine('baz', 'PART #bar') - self.getMessages('baz') + self.joinChannel("baz", "#bar") + self.sendLine("baz", "PART #bar") + self.getMessages("baz") # now make it invite-only, which should cause forwarding - self.sendLine('bar', 'MODE #bar +i') - self.getMessages('bar') + self.sendLine("bar", "MODE #bar +i") + self.getMessages("bar") - self.sendLine('baz', 'JOIN #bar') - msgs = self.getMessages('baz') + self.sendLine("baz", "JOIN #bar") + msgs = self.getMessages("baz") forward = [msg for msg in msgs if msg.command == ERR_LINKCHANNEL] - self.assertEqual(forward[0].params[:3], ['baz', '#bar', '#bar_two']) - join = [msg for msg in msgs if msg.command == 'JOIN'] - self.assertMessageEqual(join[0], params=['#bar_two']) + self.assertEqual(forward[0].params[:3], ["baz", "#bar", "#bar_two"]) + join = [msg for msg in msgs if msg.command == "JOIN"] + self.assertMessageEqual(join[0], params=["#bar_two"]) diff --git a/irctest/server_tests/test_channel_operations.py b/irctest/server_tests/test_channel_operations.py index eb372b2..c48f6f8 100644 --- a/irctest/server_tests/test_channel_operations.py +++ b/irctest/server_tests/test_channel_operations.py @@ -7,14 +7,39 @@ from irctest import cases from irctest import client_mock from irctest import runner from irctest.irc_utils import ambiguities -from irctest.numerics import RPL_TOPIC, RPL_TOPICTIME, RPL_NOTOPIC, RPL_NAMREPLY, RPL_INVITING -from irctest.numerics import ERR_NOSUCHCHANNEL, ERR_NOTONCHANNEL, ERR_CHANOPRIVSNEEDED, ERR_NOSUCHNICK, ERR_INVITEONLYCHAN, ERR_CANNOTSENDTOCHAN, ERR_BADCHANNELKEY, ERR_INVALIDMODEPARAM, ERR_UNKNOWNERROR +from irctest.numerics import ( + RPL_TOPIC, + RPL_TOPICTIME, + RPL_NOTOPIC, + RPL_NAMREPLY, + RPL_INVITING, +) +from irctest.numerics import ( + ERR_NOSUCHCHANNEL, + ERR_NOTONCHANNEL, + ERR_CHANOPRIVSNEEDED, + ERR_NOSUCHNICK, + ERR_INVITEONLYCHAN, + ERR_CANNOTSENDTOCHAN, + ERR_BADCHANNELKEY, + ERR_INVALIDMODEPARAM, + ERR_UNKNOWNERROR, +) + +MODERN_CAPS = [ + "server-time", + "message-tags", + "batch", + "labeled-response", + "echo-message", + "account-tag", +] -MODERN_CAPS = ['server-time', 'message-tags', 'batch', 'labeled-response', 'echo-message', 'account-tag'] class JoinTestCase(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812', - strict=True) + @cases.SpecificationSelector.requiredBySpecification( + "RFC1459", "RFC2812", strict=True + ) def testJoinAllMessages(self): """“If a JOIN is successful, the user receives a JOIN message as confirmation and is then sent the channel's topic (using RPL_TOPIC) and @@ -27,18 +52,21 @@ class JoinTestCase(cases.BaseServerTestCase): RPL_NAMREPLY), which must include the user joining.” -- """ - self.connectClient('foo') - self.sendLine(1, 'JOIN #chan') + self.connectClient("foo") + self.sendLine(1, "JOIN #chan") received_commands = {m.command for m in self.getMessages(1)} expected_commands = { - '353', # RPL_NAMREPLY - '366', # RPL_ENDOFNAMES - } - self.assertTrue(expected_commands.issubset(received_commands), - 'Server sent {} commands, but at least {} were expected.' - .format(received_commands, expected_commands)) + "353", # RPL_NAMREPLY + "366", # RPL_ENDOFNAMES + } + self.assertTrue( + expected_commands.issubset(received_commands), + "Server sent {} commands, but at least {} were expected.".format( + received_commands, expected_commands + ), + ) - @cases.SpecificationSelector.requiredBySpecification('RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testJoinNamreply(self): """“353 RPL_NAMREPLY "( "=" / "*" / "@" ) @@ -47,27 +75,41 @@ class JoinTestCase(cases.BaseServerTestCase): This test makes a user join and check what is sent to them. """ - self.connectClient('foo') - self.sendLine(1, 'JOIN #chan') + self.connectClient("foo") + self.sendLine(1, "JOIN #chan") for m in self.getMessages(1): - if m.command == '353': - self.assertIn(len(m.params), (3, 4), m, - fail_msg='RPL_NAM_REPLY with number of arguments ' - '<3 or >4: {msg}') + if m.command == "353": + self.assertIn( + len(m.params), + (3, 4), + m, + fail_msg="RPL_NAM_REPLY with number of arguments " + "<3 or >4: {msg}", + ) params = ambiguities.normalize_namreply_params(m.params) - self.assertIn(params[1], '=*@', m, - fail_msg='Bad channel prefix: {item} not in {list}: {msg}') - self.assertEqual(params[2], '#chan', m, - fail_msg='Bad channel name: {got} instead of ' - '{expects}: {msg}') - self.assertIn(params[3], {'foo', '@foo', '+foo'}, m, - fail_msg='Bad user list: should contain only user ' - '"foo" with an optional "+" or "@" prefix, but got: ' - '{msg}') + self.assertIn( + params[1], + "=*@", + m, + fail_msg="Bad channel prefix: {item} not in {list}: {msg}", + ) + self.assertEqual( + params[2], + "#chan", + m, + fail_msg="Bad channel name: {got} instead of " "{expects}: {msg}", + ) + self.assertIn( + params[3], + {"foo", "@foo", "+foo"}, + m, + fail_msg="Bad user list: should contain only user " + '"foo" with an optional "+" or "@" prefix, but got: ' + "{msg}", + ) - - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testPartNotInEmptyChannel(self): """“442 ERR_NOTONCHANNEL " :You're not on that channel" @@ -91,80 +133,102 @@ class JoinTestCase(cases.BaseServerTestCase): """ - self.connectClient('foo') - self.sendLine(1, 'PART #chan') + self.connectClient("foo") + self.sendLine(1, "PART #chan") m = self.getMessage(1) - self.assertIn(m.command, {'442', '403'}, m, # ERR_NOTONCHANNEL, ERR_NOSUCHCHANNEL - fail_msg='Expected ERR_NOTONCHANNEL (442) or ' - 'ERR_NOSUCHCHANNEL (403) after PARTing an empty channel ' - 'one is not on, but got: {msg}') + self.assertIn( + m.command, + {"442", "403"}, + m, # ERR_NOTONCHANNEL, ERR_NOSUCHCHANNEL + fail_msg="Expected ERR_NOTONCHANNEL (442) or " + "ERR_NOSUCHCHANNEL (403) after PARTing an empty channel " + "one is not on, but got: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testPartNotInNonEmptyChannel(self): - self.connectClient('foo') - self.connectClient('bar') - self.sendLine(1, 'JOIN #chan') - self.getMessages(1) # Synchronize - self.sendLine(2, 'PART #chan') + self.connectClient("foo") + self.connectClient("bar") + self.sendLine(1, "JOIN #chan") + self.getMessages(1) # Synchronize + self.sendLine(2, "PART #chan") m = self.getMessage(2) - self.assertMessageEqual(m, command='442', # ERR_NOTONCHANNEL - fail_msg='Expected ERR_NOTONCHANNEL (442) ' - 'after PARTing a non-empty channel ' - 'one is not on, but got: {msg}') + self.assertMessageEqual( + m, + command="442", # ERR_NOTONCHANNEL + fail_msg="Expected ERR_NOTONCHANNEL (442) " + "after PARTing a non-empty channel " + "one is not on, but got: {msg}", + ) self.assertEqual(self.getMessages(2), []) + testPartNotInNonEmptyChannel.__doc__ = testPartNotInEmptyChannel.__doc__ def testJoinTwice(self): - self.connectClient('foo') - self.sendLine(1, 'JOIN #chan') + self.connectClient("foo") + self.sendLine(1, "JOIN #chan") m = self.getMessage(1) - self.assertMessageEqual(m, command='JOIN', params=['#chan']) + self.assertMessageEqual(m, command="JOIN", params=["#chan"]) self.getMessages(1) - self.sendLine(1, 'JOIN #chan') + self.sendLine(1, "JOIN #chan") # Note that there may be no message. Both RFCs require replies only # if the join is successful, or has an error among the given set. for m in self.getMessages(1): - if m.command == '353': - self.assertIn(len(m.params), (3, 4), m, - fail_msg='RPL_NAM_REPLY with number of arguments ' - '<3 or >4: {msg}') + if m.command == "353": + self.assertIn( + len(m.params), + (3, 4), + m, + fail_msg="RPL_NAM_REPLY with number of arguments " + "<3 or >4: {msg}", + ) params = ambiguities.normalize_namreply_params(m.params) - self.assertIn(params[1], '=*@', m, - fail_msg='Bad channel prefix: {item} not in {list}: {msg}') - self.assertEqual(params[2], '#chan', m, - fail_msg='Bad channel name: {got} instead of ' - '{expects}: {msg}') - self.assertIn(params[3], {'foo', '@foo', '+foo'}, m, - fail_msg='Bad user list after user "foo" joined twice ' - 'the same channel: should contain only user ' - '"foo" with an optional "+" or "@" prefix, but got: ' - '{msg}') + self.assertIn( + params[1], + "=*@", + m, + fail_msg="Bad channel prefix: {item} not in {list}: {msg}", + ) + self.assertEqual( + params[2], + "#chan", + m, + fail_msg="Bad channel name: {got} instead of " "{expects}: {msg}", + ) + self.assertIn( + params[3], + {"foo", "@foo", "+foo"}, + m, + fail_msg='Bad user list after user "foo" joined twice ' + "the same channel: should contain only user " + '"foo" with an optional "+" or "@" prefix, but got: ' + "{msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testNormalPart(self): - self.connectClient('bar') - self.sendLine(1, 'JOIN #chan') + self.connectClient("bar") + self.sendLine(1, "JOIN #chan") m = self.getMessage(1) - self.assertMessageEqual(m, command='JOIN', params=['#chan']) + self.assertMessageEqual(m, command="JOIN", params=["#chan"]) - self.connectClient('baz') - self.sendLine(2, 'JOIN #chan') + self.connectClient("baz") + self.sendLine(2, "JOIN #chan") m = self.getMessage(2) - self.assertMessageEqual(m, command='JOIN', params=['#chan']) + self.assertMessageEqual(m, command="JOIN", params=["#chan"]) # skip the rest of the JOIN burst: self.getMessages(1) self.getMessages(2) - self.sendLine(1, 'PART #chan :bye everyone') + self.sendLine(1, "PART #chan :bye everyone") # both the PART'ing client and the other channel member should receive a PART line: m = self.getMessage(1) - self.assertMessageEqual(m, command='PART', params=['#chan', 'bye everyone']) + self.assertMessageEqual(m, command="PART", params=["#chan", "bye everyone"]) m = self.getMessage(2) - self.assertMessageEqual(m, command='PART', params=['#chan', 'bye everyone']) + self.assertMessageEqual(m, command="PART", params=["#chan", "bye everyone"]) - - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testTopic(self): """“Once a user has joined a channel, he receives information about all commands his server receives affecting the channel. This @@ -172,11 +236,11 @@ class JoinTestCase(cases.BaseServerTestCase): -- and """ - self.connectClient('foo') - self.joinChannel(1, '#chan') + self.connectClient("foo") + self.joinChannel(1, "#chan") - self.connectClient('bar') - self.joinChannel(2, '#chan') + self.connectClient("bar") + self.joinChannel(2, "#chan") # clear waiting msgs about cli 2 joining the channel self.getMessages(1) @@ -184,22 +248,23 @@ class JoinTestCase(cases.BaseServerTestCase): # TODO: check foo is opped OR +t is unset - self.sendLine(1, 'TOPIC #chan :T0P1C') + self.sendLine(1, "TOPIC #chan :T0P1C") try: m = self.getMessage(1) - if m.command == '482': + if m.command == "482": raise runner.ImplementationChoice( - 'Channel creators are not opped by default, and ' - 'channel modes to no allow regular users to change ' - 'topic.') - self.assertMessageEqual(m, command='TOPIC') + "Channel creators are not opped by default, and " + "channel modes to no allow regular users to change " + "topic." + ) + self.assertMessageEqual(m, command="TOPIC") except client_mock.NoMessageException: # The RFCs do not say TOPIC must be echoed pass m = self.getMessage(2) - self.assertMessageEqual(m, command='TOPIC', params=['#chan', 'T0P1C']) + self.assertMessageEqual(m, command="TOPIC", params=["#chan", "T0P1C"]) - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testTopicMode(self): """“Once a user has joined a channel, he receives information about all commands his server receives affecting the channel. This @@ -207,41 +272,44 @@ class JoinTestCase(cases.BaseServerTestCase): -- and """ - self.connectClient('foo') - self.joinChannel(1, '#chan') + self.connectClient("foo") + self.joinChannel(1, "#chan") - self.connectClient('bar') - self.joinChannel(2, '#chan') + self.connectClient("bar") + self.joinChannel(2, "#chan") self.getMessages(1) self.getMessages(2) # TODO: check foo is opped - self.sendLine(1, 'MODE #chan +t') + self.sendLine(1, "MODE #chan +t") self.getMessages(1) - self.sendLine(2, 'TOPIC #chan :T0P1C') + self.sendLine(2, "TOPIC #chan :T0P1C") m = self.getMessage(2) - self.assertMessageEqual(m, command='482', - fail_msg='Non-op user was not refused use of TOPIC: {msg}') + self.assertMessageEqual( + m, command="482", fail_msg="Non-op user was not refused use of TOPIC: {msg}" + ) self.assertEqual(self.getMessages(1), []) - self.sendLine(1, 'MODE #chan -t') + self.sendLine(1, "MODE #chan -t") self.getMessages(1) - self.sendLine(2, 'TOPIC #chan :T0P1C') + self.sendLine(2, "TOPIC #chan :T0P1C") try: m = self.getMessage(2) - self.assertNotEqual(m.command, '482', - msg='User was refused TOPIC whereas +t was not ' - 'set: {}'.format(m)) + self.assertNotEqual( + m.command, + "482", + msg="User was refused TOPIC whereas +t was not " "set: {}".format(m), + ) except client_mock.NoMessageException: # The RFCs do not say TOPIC must be echoed pass m = self.getMessage(1) - self.assertMessageEqual(m, command='TOPIC', params=['#chan', 'T0P1C']) + self.assertMessageEqual(m, command="TOPIC", params=["#chan", "T0P1C"]) - @cases.SpecificationSelector.requiredBySpecification('RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testTopicNonexistentChannel(self): """RFC2812 specifies ERR_NOTONCHANNEL as the correct response to TOPIC on a nonexistent channel. The modern spec prefers ERR_NOSUCHCHANNEL. @@ -249,98 +317,117 @@ class JoinTestCase(cases.BaseServerTestCase): """ - self.connectClient('foo') - self.sendLine(1, 'TOPIC #chan') + self.connectClient("foo") + self.sendLine(1, "TOPIC #chan") m = self.getMessage(1) # either 403 ERR_NOSUCHCHANNEL or 443 ERR_NOTONCHANNEL - self.assertIn(m.command, ('403', '443')) + self.assertIn(m.command, ("403", "443")) - @cases.SpecificationSelector.requiredBySpecification('RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testUnsetTopicResponses(self): """Test various cases related to RPL_NOTOPIC with set and unset topics.""" - self.connectClient('bar') - self.sendLine(1, 'JOIN #test') + self.connectClient("bar") + self.sendLine(1, "JOIN #test") messages = self.getMessages(1) # shouldn't send RPL_NOTOPIC for a new channel self.assertNotIn(RPL_NOTOPIC, [m.command for m in messages]) - self.connectClient('baz') - self.sendLine(2, 'JOIN #test') + self.connectClient("baz") + self.sendLine(2, "JOIN #test") messages = self.getMessages(2) # topic is still unset, shouldn't send RPL_NOTOPIC on initial join self.assertNotIn(RPL_NOTOPIC, [m.command for m in messages]) - self.sendLine(2, 'TOPIC #test') + self.sendLine(2, "TOPIC #test") messages = self.getMessages(2) # explicit TOPIC should receive RPL_NOTOPIC self.assertIn(RPL_NOTOPIC, [m.command for m in messages]) - self.sendLine(1, 'TOPIC #test :new topic') + self.sendLine(1, "TOPIC #test :new topic") self.getMessages(1) # client 2 should get the new TOPIC line - messages = [message for message in self.getMessages(2) if message.command == 'TOPIC'] + messages = [ + message for message in self.getMessages(2) if message.command == "TOPIC" + ] self.assertEqual(len(messages), 1) - self.assertEqual(messages[0].params, ['#test', 'new topic']) + self.assertEqual(messages[0].params, ["#test", "new topic"]) # unset the topic: - self.sendLine(1, 'TOPIC #test :') + self.sendLine(1, "TOPIC #test :") self.getMessages(1) - self.connectClient('qux') - self.sendLine(3, 'join #test') + self.connectClient("qux") + self.sendLine(3, "join #test") messages = self.getMessages(3) # topic is once again unset, shouldn't send RPL_NOTOPIC on initial join self.assertNotIn(RPL_NOTOPIC, [m.command for m in messages]) - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testListEmpty(self): """ """ - self.connectClient('foo') - self.connectClient('bar') + self.connectClient("foo") + self.connectClient("bar") self.getMessages(1) - self.sendLine(2, 'LIST') + self.sendLine(2, "LIST") m = self.getMessage(2) - if m.command == '321': + if m.command == "321": # skip RPL_LISTSTART m = self.getMessage(2) - self.assertNotEqual(m.command, '322', # RPL_LIST - 'LIST response gives (at least) one channel, whereas there ' - 'is none.') - self.assertMessageEqual(m, command='323', # RPL_LISTEND - fail_msg='Second reply to LIST is not 322 (RPL_LIST) ' - 'or 323 (RPL_LISTEND), or but: {msg}') + self.assertNotEqual( + m.command, + "322", # RPL_LIST + "LIST response gives (at least) one channel, whereas there " "is none.", + ) + self.assertMessageEqual( + m, + command="323", # RPL_LISTEND + fail_msg="Second reply to LIST is not 322 (RPL_LIST) " + "or 323 (RPL_LISTEND), or but: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testListOne(self): """When a channel exists, LIST should get it in a reply. """ - self.connectClient('foo') - self.connectClient('bar') - self.sendLine(1, 'JOIN #chan') + self.connectClient("foo") + self.connectClient("bar") + self.sendLine(1, "JOIN #chan") self.getMessages(1) - self.sendLine(2, 'LIST') + self.sendLine(2, "LIST") m = self.getMessage(2) - if m.command == '321': + if m.command == "321": # skip RPL_LISTSTART m = self.getMessage(2) - self.assertNotEqual(m.command, '323', # RPL_LISTEND - fail_msg='LIST response ended (ie. 323, aka RPL_LISTEND) ' - 'without listing any channel, whereas there is one.') - self.assertMessageEqual(m, command='322', # RPL_LIST - fail_msg='Second reply to LIST is not 322 (RPL_LIST), ' - 'nor 323 (RPL_LISTEND) but: {msg}') + self.assertNotEqual( + m.command, + "323", # RPL_LISTEND + fail_msg="LIST response ended (ie. 323, aka RPL_LISTEND) " + "without listing any channel, whereas there is one.", + ) + self.assertMessageEqual( + m, + command="322", # RPL_LIST + fail_msg="Second reply to LIST is not 322 (RPL_LIST), " + "nor 323 (RPL_LISTEND) but: {msg}", + ) m = self.getMessage(2) - self.assertNotEqual(m.command, '322', # RPL_LIST - fail_msg='LIST response gives (at least) two channels, ' - 'whereas there is only one.') - self.assertMessageEqual(m, command='323', # RPL_LISTEND - fail_msg='Third reply to LIST is not 322 (RPL_LIST) ' - 'or 323 (RPL_LISTEND), or but: {msg}') + self.assertNotEqual( + m.command, + "322", # RPL_LIST + fail_msg="LIST response gives (at least) two channels, " + "whereas there is only one.", + ) + self.assertMessageEqual( + m, + command="323", # RPL_LISTEND + fail_msg="Third reply to LIST is not 322 (RPL_LIST) " + "or 323 (RPL_LISTEND), or but: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testKickSendsMessages(self): """“Once a user has joined a channel, he receives information about all commands his server receives affecting the channel. This @@ -348,46 +435,45 @@ class JoinTestCase(cases.BaseServerTestCase): -- and """ - self.connectClient('foo') - self.joinChannel(1, '#chan') + self.connectClient("foo") + self.joinChannel(1, "#chan") - self.connectClient('bar') - self.joinChannel(2, '#chan') + self.connectClient("bar") + self.joinChannel(2, "#chan") - self.connectClient('baz') - self.joinChannel(3, '#chan') + self.connectClient("baz") + self.joinChannel(3, "#chan") # TODO: check foo is an operator self.getMessages(1) self.getMessages(2) self.getMessages(3) - self.sendLine(1, 'KICK #chan bar :bye') + self.sendLine(1, "KICK #chan bar :bye") try: m = self.getMessage(1) - if m.command == '482': + if m.command == "482": raise runner.ImplementationChoice( - 'Channel creators are not opped by default.') - self.assertMessageEqual(m, command='KICK') + "Channel creators are not opped by default." + ) + self.assertMessageEqual(m, command="KICK") except client_mock.NoMessageException: # The RFCs do not say KICK must be echoed pass m = self.getMessage(2) - self.assertMessageEqual(m, command='KICK', - params=['#chan', 'bar', 'bye']) + self.assertMessageEqual(m, command="KICK", params=["#chan", "bar", "bye"]) m = self.getMessage(3) - self.assertMessageEqual(m, command='KICK', - params=['#chan', 'bar', 'bye']) + self.assertMessageEqual(m, command="KICK", params=["#chan", "bar", "bye"]) - @cases.SpecificationSelector.requiredBySpecification('RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testKickPrivileges(self): """Test who has the ability to kick / what error codes are sent for invalid kicks.""" - self.connectClient('foo') - self.sendLine(1, 'JOIN #chan') + self.connectClient("foo") + self.sendLine(1, "JOIN #chan") self.getMessages(1) - self.connectClient('bar') - self.sendLine(2, 'JOIN #chan') + self.connectClient("bar") + self.sendLine(2, "JOIN #chan") messages = self.getMessages(2) names = set() @@ -395,63 +481,72 @@ class JoinTestCase(cases.BaseServerTestCase): if message.command == RPL_NAMREPLY: names.update(set(message.params[-1].split())) # assert foo is opped - self.assertIn('@foo', names, f'unexpected names: {names}') + self.assertIn("@foo", names, f"unexpected names: {names}") - self.connectClient('baz') + self.connectClient("baz") - self.sendLine(3, 'KICK #chan bar') + self.sendLine(3, "KICK #chan bar") replies = set(m.command for m in self.getMessages(3)) self.assertTrue( - ERR_NOTONCHANNEL in replies or ERR_CHANOPRIVSNEEDED in replies or ERR_NOSUCHCHANNEL in replies, - f'did not receive acceptable error code for kick from outside channel: {replies}') + ERR_NOTONCHANNEL in replies + or ERR_CHANOPRIVSNEEDED in replies + or ERR_NOSUCHCHANNEL in replies, + f"did not receive acceptable error code for kick from outside channel: {replies}", + ) - self.joinChannel(3, '#chan') + self.joinChannel(3, "#chan") self.getMessages(3) - self.sendLine(3, 'KICK #chan bar') + self.sendLine(3, "KICK #chan bar") replies = set(m.command for m in self.getMessages(3)) # now we're a channel member so we should receive ERR_CHANOPRIVSNEEDED self.assertIn(ERR_CHANOPRIVSNEEDED, replies) - self.sendLine(1, 'MODE #chan +o baz') + self.sendLine(1, "MODE #chan +o baz") self.getMessages(1) # should be able to kick an unprivileged user: - self.sendLine(3, 'KICK #chan bar') + self.sendLine(3, "KICK #chan bar") # should be able to kick an operator: - self.sendLine(3, 'KICK #chan foo') + self.sendLine(3, "KICK #chan foo") baz_replies = set(m.command for m in self.getMessages(3)) self.assertNotIn(ERR_CHANOPRIVSNEEDED, baz_replies) - kick_targets = [m.params[1] for m in self.getMessages(1) if m.command == 'KICK'] + kick_targets = [m.params[1] for m in self.getMessages(1) if m.command == "KICK"] # foo should see bar and foo being kicked - self.assertTrue(any(target.startswith('foo') for target in kick_targets), f'unexpected kick targets: {kick_targets}') - self.assertTrue(any(target.startswith('bar') for target in kick_targets), f'unexpected kick targets: {kick_targets}') + self.assertTrue( + any(target.startswith("foo") for target in kick_targets), + f"unexpected kick targets: {kick_targets}", + ) + self.assertTrue( + any(target.startswith("bar") for target in kick_targets), + f"unexpected kick targets: {kick_targets}", + ) - @cases.SpecificationSelector.requiredBySpecification('RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testKickNonexistentChannel(self): """“Kick command [...] Numeric replies: [...] ERR_NOSUCHCHANNEL.""" - self.connectClient('foo') - self.sendLine(1, 'KICK #chan nick') + self.connectClient("foo") + self.sendLine(1, "KICK #chan nick") m = self.getMessage(1) # should return ERR_NOSUCHCHANNEL - self.assertMessageEqual(m, command='403') + self.assertMessageEqual(m, command="403") - @cases.SpecificationSelector.requiredBySpecification('RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testDoubleKickMessages(self): """“The server MUST NOT send KICK messages with multiple channels or users to clients. This is necessarily to maintain backward compatibility with old client software.” -- https://tools.ietf.org/html/rfc2812#section-3.2.8 """ - self.connectClient('foo') - self.joinChannel(1, '#chan') + self.connectClient("foo") + self.joinChannel(1, "#chan") - self.connectClient('bar') - self.joinChannel(2, '#chan') + self.connectClient("bar") + self.joinChannel(2, "#chan") - self.connectClient('baz') - self.joinChannel(3, '#chan') + self.connectClient("baz") + self.joinChannel(3, "#chan") - self.connectClient('qux') - self.joinChannel(4, '#chan') + self.connectClient("qux") + self.joinChannel(4, "#chan") # TODO: check foo is an operator @@ -461,15 +556,15 @@ class JoinTestCase(cases.BaseServerTestCase): self.getMessages(3) self.getMessages(4) - self.sendLine(1, 'KICK #chan,#chan bar,baz :bye') + self.sendLine(1, "KICK #chan,#chan bar,baz :bye") try: m = self.getMessage(1) - if m.command == '482': + if m.command == "482": raise runner.OptionalExtensionNotSupported( - 'Channel creators are not opped by default.') - if m.command in {'401', '403'}: - raise runner.NotImplementedByController( - 'Multi-target KICK') + "Channel creators are not opped by default." + ) + if m.command in {"401", "403"}: + raise runner.NotImplementedByController("Multi-target KICK") except client_mock.NoMessageException: # The RFCs do not say KICK must be echoed pass @@ -479,18 +574,24 @@ class JoinTestCase(cases.BaseServerTestCase): m1, m2 = mgroup[:2] for m in m1, m2: - self.assertEqual(m.command, 'KICK') + self.assertEqual(m.command, "KICK") self.assertEqual(len(m.params), 3) - self.assertEqual(m.params[0], '#chan') - self.assertEqual(m.params[2], 'bye') - - if (m1.params[1] == 'bar' and m2.params[1] == 'baz') or (m1.params[1] == 'baz' and m2.params[1] == 'bar'): + self.assertEqual(m.params[0], "#chan") + self.assertEqual(m.params[2], "bye") + + if (m1.params[1] == "bar" and m2.params[1] == "baz") or ( + m1.params[1] == "baz" and m2.params[1] == "bar" + ): ... # success else: - raise AssertionError('Middle params [{}, {}] are not correct.'.format(m1.params[1], m2.params[1])) + raise AssertionError( + "Middle params [{}, {}] are not correct.".format( + m1.params[1], m2.params[1] + ) + ) - @cases.SpecificationSelector.requiredBySpecification('RFC-deprecated') + @cases.SpecificationSelector.requiredBySpecification("RFC-deprecated") def testInviteNonExistingChannelTransmitted(self): """“There is no requirement that the channel the target user is being invited to must exist or be a valid channel.” @@ -501,23 +602,29 @@ class JoinTestCase(cases.BaseServerTestCase): notification of the invitation.” -- """ - self.connectClient('foo') - self.connectClient('bar') + self.connectClient("foo") + self.connectClient("bar") self.getMessages(1) self.getMessages(2) - self.sendLine(1, 'INVITE #chan bar') + self.sendLine(1, "INVITE #chan bar") self.getMessages(1) l = self.getMessages(2) - self.assertNotEqual(l, [], - fail_msg='After using “INVITE #chan bar” while #chan does ' - 'not exist, “bar” received nothing.') - self.assertMessageEqual(l[0], command='INVITE', - params=['#chan', 'bar'], - fail_msg='After “foo” invited “bar” do non-existing channel ' - '#chan, “bar” should have received “INVITE #chan bar” but ' - 'got this instead: {msg}') + self.assertNotEqual( + l, + [], + fail_msg="After using “INVITE #chan bar” while #chan does " + "not exist, “bar” received nothing.", + ) + self.assertMessageEqual( + l[0], + command="INVITE", + params=["#chan", "bar"], + fail_msg="After “foo” invited “bar” do non-existing channel " + "#chan, “bar” should have received “INVITE #chan bar” but " + "got this instead: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('RFC-deprecated') + @cases.SpecificationSelector.requiredBySpecification("RFC-deprecated") def testInviteNonExistingChannelEchoed(self): """“There is no requirement that the channel the target user is being invited to must exist or be a valid channel.” @@ -528,50 +635,64 @@ class JoinTestCase(cases.BaseServerTestCase): notification of the invitation.” -- """ - self.connectClient('foo') - self.connectClient('bar') + self.connectClient("foo") + self.connectClient("bar") self.getMessages(1) self.getMessages(2) - self.sendLine(1, 'INVITE #chan bar') + self.sendLine(1, "INVITE #chan bar") l = self.getMessages(1) - self.assertNotEqual(l, [], - fail_msg='After using “INVITE #chan bar” while #chan does ' - 'not exist, the author received nothing.') - self.assertMessageEqual(l[0], command='INVITE', - params=['#chan', 'bar'], - fail_msg='After “foo” invited “bar” do non-existing channel ' - '#chan, “foo” should have received “INVITE #chan bar” but ' - 'got this instead: {msg}') + self.assertNotEqual( + l, + [], + fail_msg="After using “INVITE #chan bar” while #chan does " + "not exist, the author received nothing.", + ) + self.assertMessageEqual( + l[0], + command="INVITE", + params=["#chan", "bar"], + fail_msg="After “foo” invited “bar” do non-existing channel " + "#chan, “foo” should have received “INVITE #chan bar” but " + "got this instead: {msg}", + ) + class testChannelCaseSensitivity(cases.BaseServerTestCase): def _testChannelsEquivalent(casemapping, name1, name2): - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812', - strict=True) + @cases.SpecificationSelector.requiredBySpecification( + "RFC1459", "RFC2812", strict=True + ) def f(self): - self.connectClient('foo') - self.connectClient('bar') - if self.server_support['CASEMAPPING'] != casemapping: - raise runner.NotImplementedByController('Casemapping {} not implemented'.format(casemapping)) + self.connectClient("foo") + self.connectClient("bar") + if self.server_support["CASEMAPPING"] != casemapping: + raise runner.NotImplementedByController( + "Casemapping {} not implemented".format(casemapping) + ) self.joinClient(1, name1) self.joinClient(2, name2) try: m = self.getMessage(1) - self.assertMessageEqual(m, command='JOIN', - nick='bar') + self.assertMessageEqual(m, command="JOIN", nick="bar") except client_mock.NoMessageException: raise AssertionError( - 'Channel names {} and {} are not equivalent.' - .format(name1, name2)) - f.__name__ = 'testEquivalence__{}__{}'.format(name1, name2) + "Channel names {} and {} are not equivalent.".format(name1, name2) + ) + + f.__name__ = "testEquivalence__{}__{}".format(name1, name2) return f + def _testChannelsNotEquivalent(casemapping, name1, name2): - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812', - strict=True) + @cases.SpecificationSelector.requiredBySpecification( + "RFC1459", "RFC2812", strict=True + ) def f(self): - self.connectClient('foo') - self.connectClient('bar') - if self.server_support['CASEMAPPING'] != casemapping: - raise runner.NotImplementedByController('Casemapping {} not implemented'.format(casemapping)) + self.connectClient("foo") + self.connectClient("bar") + if self.server_support["CASEMAPPING"] != casemapping: + raise runner.NotImplementedByController( + "Casemapping {} not implemented".format(casemapping) + ) self.joinClient(1, name1) self.joinClient(2, name2) try: @@ -579,465 +700,525 @@ class testChannelCaseSensitivity(cases.BaseServerTestCase): except client_mock.NoMessageException: pass else: - self.assertMessageEqual(m, command='JOIN', - nick='bar') # This should always be true + self.assertMessageEqual( + m, command="JOIN", nick="bar" + ) # This should always be true raise AssertionError( - 'Channel names {} and {} are equivalent.' - .format(name1, name2)) - f.__name__ = 'testEquivalence__{}__{}'.format(name1, name2) + "Channel names {} and {} are equivalent.".format(name1, name2) + ) + + f.__name__ = "testEquivalence__{}__{}".format(name1, name2) return f - testAsciiSimpleEquivalent = _testChannelsEquivalent('ascii', '#Foo', '#foo') - testAsciiSimpleNotEquivalent = _testChannelsNotEquivalent('ascii', '#Foo', '#fooa') + testAsciiSimpleEquivalent = _testChannelsEquivalent("ascii", "#Foo", "#foo") + testAsciiSimpleNotEquivalent = _testChannelsNotEquivalent("ascii", "#Foo", "#fooa") - testRfcSimpleEquivalent = _testChannelsEquivalent('rfc1459', '#Foo', '#foo') - testRfcSimpleNotEquivalent = _testChannelsNotEquivalent('rfc1459', '#Foo', '#fooa') - testRfcFancyEquivalent = _testChannelsEquivalent('rfc1459', '#F]|oo{', '#f}\\oo[') - testRfcFancyNotEquivalent = _testChannelsEquivalent('rfc1459', '#F}o\\o[', '#f]o|o{') + testRfcSimpleEquivalent = _testChannelsEquivalent("rfc1459", "#Foo", "#foo") + testRfcSimpleNotEquivalent = _testChannelsNotEquivalent("rfc1459", "#Foo", "#fooa") + testRfcFancyEquivalent = _testChannelsEquivalent("rfc1459", "#F]|oo{", "#f}\\oo[") + testRfcFancyNotEquivalent = _testChannelsEquivalent( + "rfc1459", "#F}o\\o[", "#f]o|o{" + ) class InviteTestCase(cases.BaseServerTestCase): - - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testInvites(self): """Test some basic functionality related to INVITE and the +i mode.""" - self.connectClient('foo') - self.joinChannel(1, '#chan') - self.sendLine(1, 'MODE #chan +i') + self.connectClient("foo") + self.joinChannel(1, "#chan") + self.sendLine(1, "MODE #chan +i") self.getMessages(1) - self.sendLine(1, 'INVITE bar #chan') + self.sendLine(1, "INVITE bar #chan") m = self.getMessage(1) self.assertEqual(m.command, ERR_NOSUCHNICK) - self.connectClient('bar') - self.sendLine(2, 'JOIN #chan') + self.connectClient("bar") + self.sendLine(2, "JOIN #chan") m = self.getMessage(2) self.assertEqual(m.command, ERR_INVITEONLYCHAN) - self.sendLine(1, 'INVITE bar #chan') + self.sendLine(1, "INVITE bar #chan") m = self.getMessage(1) self.assertEqual(m.command, RPL_INVITING) # modern/ircv3 param order: inviter, invitee, channel - self.assertEqual(m.params, ['foo', 'bar', '#chan']) + self.assertEqual(m.params, ["foo", "bar", "#chan"]) m = self.getMessage(2) - self.assertEqual(m.command, 'INVITE') - self.assertTrue(m.prefix.startswith("foo")) # nickmask of inviter - self.assertEqual(m.params, ['bar', '#chan']) + self.assertEqual(m.command, "INVITE") + self.assertTrue(m.prefix.startswith("foo")) # nickmask of inviter + self.assertEqual(m.params, ["bar", "#chan"]) # we were invited, so join should succeed now - self.joinChannel(2, '#chan') + self.joinChannel(2, "#chan") class ChannelQuitTestCase(cases.BaseServerTestCase): - - @cases.SpecificationSelector.requiredBySpecification('RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testQuit(self): """“Once a user has joined a channel, he receives information about all commands his server receives affecting the channel. This includes [...] QUIT” """ - self.connectClient('bar') - self.joinChannel(1, '#chan') - self.connectClient('qux') - self.sendLine(2, 'JOIN #chan') + self.connectClient("bar") + self.joinChannel(1, "#chan") + self.connectClient("qux") + self.sendLine(2, "JOIN #chan") self.getMessages(2) self.getMessages(1) - self.sendLine(2, 'QUIT :qux out') + self.sendLine(2, "QUIT :qux out") self.getMessages(2) m = self.getMessage(1) - self.assertEqual(m.command, 'QUIT') - self.assertTrue(m.prefix.startswith('qux')) # nickmask of quitter - self.assertIn('qux out', m.params[0]) + self.assertEqual(m.command, "QUIT") + self.assertTrue(m.prefix.startswith("qux")) # nickmask of quitter + self.assertIn("qux out", m.params[0]) class NoCTCPTestCase(cases.BaseServerTestCase): - - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testQuit(self): - self.connectClient('bar') - self.joinChannel(1, '#chan') - self.sendLine(1, 'MODE #chan +C') + self.connectClient("bar") + self.joinChannel(1, "#chan") + self.sendLine(1, "MODE #chan +C") self.getMessages(1) - self.connectClient('qux') - self.joinChannel(2, '#chan') + self.connectClient("qux") + self.joinChannel(2, "#chan") self.getMessages(2) - self.sendLine(1, 'PRIVMSG #chan :\x01ACTION hi\x01') + self.sendLine(1, "PRIVMSG #chan :\x01ACTION hi\x01") self.getMessages(1) ms = self.getMessages(2) self.assertEqual(len(ms), 1) - self.assertMessageEqual(ms[0], command='PRIVMSG', params=['#chan', '\x01ACTION hi\x01']) + self.assertMessageEqual( + ms[0], command="PRIVMSG", params=["#chan", "\x01ACTION hi\x01"] + ) - self.sendLine(1, 'PRIVMSG #chan :\x01PING 1473523796 918320\x01') + self.sendLine(1, "PRIVMSG #chan :\x01PING 1473523796 918320\x01") ms = self.getMessages(1) self.assertEqual(len(ms), 1) self.assertMessageEqual(ms[0], command=ERR_CANNOTSENDTOCHAN) ms = self.getMessages(2) self.assertEqual(ms, []) -class KeyTestCase(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('RFC2812') +class KeyTestCase(cases.BaseServerTestCase): + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testKeyNormal(self): - self.connectClient('bar') - self.joinChannel(1, '#chan') - self.sendLine(1, 'MODE #chan +k beer') + self.connectClient("bar") + self.joinChannel(1, "#chan") + self.sendLine(1, "MODE #chan +k beer") self.getMessages(1) - self.connectClient('qux') + self.connectClient("qux") self.getMessages(2) - self.sendLine(2, 'JOIN #chan') + self.sendLine(2, "JOIN #chan") reply = self.getMessages(2) - self.assertNotIn('JOIN', {msg.command for msg in reply}) + self.assertNotIn("JOIN", {msg.command for msg in reply}) self.assertIn(ERR_BADCHANNELKEY, {msg.command for msg in reply}) - self.sendLine(2, 'JOIN #chan beer') + self.sendLine(2, "JOIN #chan beer") reply = self.getMessages(2) - self.assertMessageEqual(reply[0], command='JOIN', params=['#chan']) + self.assertMessageEqual(reply[0], command="JOIN", params=["#chan"]) - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testKeyValidation(self): # oragono issue #1021 - self.connectClient('bar') - self.joinChannel(1, '#chan') - self.sendLine(1, 'MODE #chan +k :invalid channel passphrase') + self.connectClient("bar") + self.joinChannel(1, "#chan") + self.sendLine(1, "MODE #chan +k :invalid channel passphrase") reply = self.getMessages(1) self.assertNotIn(ERR_UNKNOWNERROR, {msg.command for msg in reply}) self.assertIn(ERR_INVALIDMODEPARAM, {msg.command for msg in reply}) class AuditoriumTestCase(cases.BaseServerTestCase): - - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testAuditorium(self): - self.connectClient('bar', name='bar', capabilities=MODERN_CAPS) - self.joinChannel('bar', '#auditorium') - self.getMessages('bar') - self.sendLine('bar', 'MODE #auditorium +u') - modelines = [msg for msg in self.getMessages('bar') if msg.command == 'MODE'] + self.connectClient("bar", name="bar", capabilities=MODERN_CAPS) + self.joinChannel("bar", "#auditorium") + self.getMessages("bar") + self.sendLine("bar", "MODE #auditorium +u") + modelines = [msg for msg in self.getMessages("bar") if msg.command == "MODE"] self.assertEqual(len(modelines), 1) - self.assertMessageEqual(modelines[0], params=['#auditorium', '+u']) + self.assertMessageEqual(modelines[0], params=["#auditorium", "+u"]) - self.connectClient('guest1', name='guest1', capabilities=MODERN_CAPS) - self.joinChannel('guest1', '#auditorium') - self.getMessages('guest1') + self.connectClient("guest1", name="guest1", capabilities=MODERN_CAPS) + self.joinChannel("guest1", "#auditorium") + self.getMessages("guest1") # chanop should get a JOIN message - join_msgs = [msg for msg in self.getMessages('bar') if msg.command == 'JOIN'] + join_msgs = [msg for msg in self.getMessages("bar") if msg.command == "JOIN"] self.assertEqual(len(join_msgs), 1) - self.assertMessageEqual(join_msgs[0], nick='guest1', params=['#auditorium']) + self.assertMessageEqual(join_msgs[0], nick="guest1", params=["#auditorium"]) - self.connectClient('guest2', name='guest2', capabilities=MODERN_CAPS) - self.joinChannel('guest2', '#auditorium') - self.getMessages('guest2') + self.connectClient("guest2", name="guest2", capabilities=MODERN_CAPS) + self.joinChannel("guest2", "#auditorium") + self.getMessages("guest2") # chanop should get a JOIN message - join_msgs = [msg for msg in self.getMessages('bar') if msg.command == 'JOIN'] + join_msgs = [msg for msg in self.getMessages("bar") if msg.command == "JOIN"] self.assertEqual(len(join_msgs), 1) - self.assertMessageEqual(join_msgs[0], nick='guest2', params=['#auditorium']) + self.assertMessageEqual(join_msgs[0], nick="guest2", params=["#auditorium"]) # fellow unvoiced participant should not - unvoiced_join_msgs = [msg for msg in self.getMessages('guest1') if msg.command == 'JOIN'] + unvoiced_join_msgs = [ + msg for msg in self.getMessages("guest1") if msg.command == "JOIN" + ] self.assertEqual(len(unvoiced_join_msgs), 0) - self.connectClient('guest3', name='guest3', capabilities=MODERN_CAPS) - self.joinChannel('guest3', '#auditorium') - self.getMessages('guest3') + self.connectClient("guest3", name="guest3", capabilities=MODERN_CAPS) + self.joinChannel("guest3", "#auditorium") + self.getMessages("guest3") - self.sendLine('bar', 'PRIVMSG #auditorium hi') - echo_message = [msg for msg in self.getMessages('bar') if msg.command == 'PRIVMSG'][0] - self.assertEqual(echo_message, self.getMessages('guest1')[0]) - self.assertEqual(echo_message, self.getMessages('guest2')[0]) - self.assertEqual(echo_message, self.getMessages('guest3')[0]) + self.sendLine("bar", "PRIVMSG #auditorium hi") + echo_message = [ + msg for msg in self.getMessages("bar") if msg.command == "PRIVMSG" + ][0] + self.assertEqual(echo_message, self.getMessages("guest1")[0]) + self.assertEqual(echo_message, self.getMessages("guest2")[0]) + self.assertEqual(echo_message, self.getMessages("guest3")[0]) # unvoiced users can speak - self.sendLine('guest1', 'PRIVMSG #auditorium :hi you') - echo_message = [msg for msg in self.getMessages('guest1') if msg.command == 'PRIVMSG'][0] - self.assertEqual(self.getMessages('bar'), [echo_message]) - self.assertEqual(self.getMessages('guest2'), [echo_message]) - self.assertEqual(self.getMessages('guest3'), [echo_message]) + self.sendLine("guest1", "PRIVMSG #auditorium :hi you") + echo_message = [ + msg for msg in self.getMessages("guest1") if msg.command == "PRIVMSG" + ][0] + self.assertEqual(self.getMessages("bar"), [echo_message]) + self.assertEqual(self.getMessages("guest2"), [echo_message]) + self.assertEqual(self.getMessages("guest3"), [echo_message]) def names(client): - self.sendLine(client, 'NAMES #auditorium') + self.sendLine(client, "NAMES #auditorium") result = set() for msg in self.getMessages(client): if msg.command == RPL_NAMREPLY: result.update(msg.params[-1].split()) return result - self.assertEqual(names('bar'), {'@bar', 'guest1', 'guest2', 'guest3'}) - self.assertEqual(names('guest1'), {'@bar',}) - self.assertEqual(names('guest2'), {'@bar',}) - self.assertEqual(names('guest3'), {'@bar',}) + self.assertEqual(names("bar"), {"@bar", "guest1", "guest2", "guest3"}) + self.assertEqual( + names("guest1"), + { + "@bar", + }, + ) + self.assertEqual( + names("guest2"), + { + "@bar", + }, + ) + self.assertEqual( + names("guest3"), + { + "@bar", + }, + ) - self.sendLine('bar', 'MODE #auditorium +v guest1') - modeLine = [msg for msg in self.getMessages('bar') if msg.command == 'MODE'][0] - self.assertEqual(self.getMessages('guest1'), [modeLine]) - self.assertEqual(self.getMessages('guest2'), [modeLine]) - self.assertEqual(self.getMessages('guest3'), [modeLine]) - self.assertEqual(names('bar'), {'@bar', '+guest1', 'guest2', 'guest3'}) - self.assertEqual(names('guest2'), {'@bar', '+guest1'}) - self.assertEqual(names('guest3'), {'@bar', '+guest1'}) + self.sendLine("bar", "MODE #auditorium +v guest1") + modeLine = [msg for msg in self.getMessages("bar") if msg.command == "MODE"][0] + self.assertEqual(self.getMessages("guest1"), [modeLine]) + self.assertEqual(self.getMessages("guest2"), [modeLine]) + self.assertEqual(self.getMessages("guest3"), [modeLine]) + self.assertEqual(names("bar"), {"@bar", "+guest1", "guest2", "guest3"}) + self.assertEqual(names("guest2"), {"@bar", "+guest1"}) + self.assertEqual(names("guest3"), {"@bar", "+guest1"}) - self.sendLine('guest1', 'PART #auditorium') - part = [msg for msg in self.getMessages('guest1') if msg.command == 'PART'][0] + self.sendLine("guest1", "PART #auditorium") + part = [msg for msg in self.getMessages("guest1") if msg.command == "PART"][0] # everyone should see voiced PART - self.assertEqual(self.getMessages('bar')[0], part) - self.assertEqual(self.getMessages('guest2')[0], part) - self.assertEqual(self.getMessages('guest3')[0], part) + self.assertEqual(self.getMessages("bar")[0], part) + self.assertEqual(self.getMessages("guest2")[0], part) + self.assertEqual(self.getMessages("guest3")[0], part) - self.joinChannel('guest1', '#auditorium') - self.getMessages('guest1') - self.getMessages('bar') + self.joinChannel("guest1", "#auditorium") + self.getMessages("guest1") + self.getMessages("bar") - self.sendLine('guest2', 'PART #auditorium') - part = [msg for msg in self.getMessages('guest2') if msg.command == 'PART'][0] - self.assertEqual(self.getMessages('bar'), [part]) + self.sendLine("guest2", "PART #auditorium") + part = [msg for msg in self.getMessages("guest2") if msg.command == "PART"][0] + self.assertEqual(self.getMessages("bar"), [part]) # part should be hidden from unvoiced participants - self.assertEqual(self.getMessages('guest1'), []) - self.assertEqual(self.getMessages('guest3'), []) + self.assertEqual(self.getMessages("guest1"), []) + self.assertEqual(self.getMessages("guest3"), []) - self.sendLine('guest3', 'QUIT') - self.assertDisconnected('guest3') + self.sendLine("guest3", "QUIT") + self.assertDisconnected("guest3") # quit should be hidden from unvoiced participants - self.assertEqual(len([msg for msg in self.getMessages('bar') if msg.command =='QUIT']), 1) - self.assertEqual(len([msg for msg in self.getMessages('guest1') if msg.command =='QUIT']), 0) + self.assertEqual( + len([msg for msg in self.getMessages("bar") if msg.command == "QUIT"]), 1 + ) + self.assertEqual( + len([msg for msg in self.getMessages("guest1") if msg.command == "QUIT"]), 0 + ) class TopicPrivileges(cases.BaseServerTestCase): - - @cases.SpecificationSelector.requiredBySpecification('RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testTopicPrivileges(self): # test the +t channel mode, which prevents unprivileged users from changing the topic - self.connectClient('bar', name='bar') - self.joinChannel('bar', '#chan') - self.getMessages('bar') - self.sendLine('bar', 'MODE #chan +t') - replies = {msg.command for msg in self.getMessages('bar')} + self.connectClient("bar", name="bar") + self.joinChannel("bar", "#chan") + self.getMessages("bar") + self.sendLine("bar", "MODE #chan +t") + replies = {msg.command for msg in self.getMessages("bar")} # success response is undefined, may be MODE or may be 324 RPL_CHANNELMODEIS, # depending on whether this was a no-op self.assertNotIn(ERR_CHANOPRIVSNEEDED, replies) - self.sendLine('bar', 'TOPIC #chan :new topic') - replies = {msg.command for msg in self.getMessages('bar')} - self.assertIn('TOPIC', replies) + self.sendLine("bar", "TOPIC #chan :new topic") + replies = {msg.command for msg in self.getMessages("bar")} + self.assertIn("TOPIC", replies) self.assertNotIn(ERR_CHANOPRIVSNEEDED, replies) - self.connectClient('qux', name='qux') - self.joinChannel('qux', '#chan') - self.getMessages('qux') - self.sendLine('qux', 'TOPIC #chan :new topic') - replies = {msg.command for msg in self.getMessages('qux')} + self.connectClient("qux", name="qux") + self.joinChannel("qux", "#chan") + self.getMessages("qux") + self.sendLine("qux", "TOPIC #chan :new topic") + replies = {msg.command for msg in self.getMessages("qux")} self.assertIn(ERR_CHANOPRIVSNEEDED, replies) - self.assertNotIn('TOPIC', replies) + self.assertNotIn("TOPIC", replies) - self.sendLine('bar', 'MODE #chan +v qux') - replies = {msg.command for msg in self.getMessages('bar')} - self.assertIn('MODE', replies) + self.sendLine("bar", "MODE #chan +v qux") + replies = {msg.command for msg in self.getMessages("bar")} + self.assertIn("MODE", replies) self.assertNotIn(ERR_CHANOPRIVSNEEDED, replies) # regression test: +v cannot change the topic of a +t channel - self.sendLine('qux', 'TOPIC #chan :new topic') - replies = {msg.command for msg in self.getMessages('qux')} + self.sendLine("qux", "TOPIC #chan :new topic") + replies = {msg.command for msg in self.getMessages("qux")} self.assertIn(ERR_CHANOPRIVSNEEDED, replies) - self.assertNotIn('TOPIC', replies) + self.assertNotIn("TOPIC", replies) # test that RPL_TOPIC and RPL_TOPICTIME are sent on join - self.connectClient('buzz', name='buzz') - self.sendLine('buzz', 'JOIN #chan') - replies = self.getMessages('buzz') + self.connectClient("buzz", name="buzz") + self.sendLine("buzz", "JOIN #chan") + replies = self.getMessages("buzz") rpl_topic = [msg for msg in replies if msg.command == RPL_TOPIC][0] - self.assertMessageEqual(rpl_topic, command=RPL_TOPIC, params=['buzz', '#chan', 'new topic']) - self.assertEqual(len([msg for msg in replies if msg.command == RPL_TOPICTIME]), 1) + self.assertMessageEqual( + rpl_topic, command=RPL_TOPIC, params=["buzz", "#chan", "new topic"] + ) + self.assertEqual( + len([msg for msg in replies if msg.command == RPL_TOPICTIME]), 1 + ) class ModeratedMode(cases.BaseServerTestCase): - - @cases.SpecificationSelector.requiredBySpecification('RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testModeratedMode(self): # test the +m channel mode - self.connectClient('chanop', name='chanop') - self.joinChannel('chanop', '#chan') - self.getMessages('chanop') - self.sendLine('chanop', 'MODE #chan +m') - replies = self.getMessages('chanop') - modeLines = [line for line in replies if line.command == 'MODE'] - self.assertMessageEqual(modeLines[0], command='MODE', params=['#chan', '+m']) + self.connectClient("chanop", name="chanop") + self.joinChannel("chanop", "#chan") + self.getMessages("chanop") + self.sendLine("chanop", "MODE #chan +m") + replies = self.getMessages("chanop") + modeLines = [line for line in replies if line.command == "MODE"] + self.assertMessageEqual(modeLines[0], command="MODE", params=["#chan", "+m"]) - self.connectClient('baz', name='baz') - self.joinChannel('baz', '#chan') - self.getMessages('chanop') + self.connectClient("baz", name="baz") + self.joinChannel("baz", "#chan") + self.getMessages("chanop") # this message should be suppressed completely by +m - self.sendLine('baz', 'PRIVMSG #chan :hi from baz') - replies = self.getMessages('baz') + self.sendLine("baz", "PRIVMSG #chan :hi from baz") + replies = self.getMessages("baz") reply_cmds = {reply.command for reply in replies} self.assertIn(ERR_CANNOTSENDTOCHAN, reply_cmds) - self.assertEqual(self.getMessages('chanop'), []) + self.assertEqual(self.getMessages("chanop"), []) # grant +v, user should be able to send messages - self.sendLine('chanop', 'MODE #chan +v baz') - self.getMessages('chanop') - self.getMessages('baz') - self.sendLine('baz', 'PRIVMSG #chan :hi again from baz') - self.getMessages('baz') - relays = self.getMessages('chanop') + self.sendLine("chanop", "MODE #chan +v baz") + self.getMessages("chanop") + self.getMessages("baz") + self.sendLine("baz", "PRIVMSG #chan :hi again from baz") + self.getMessages("baz") + relays = self.getMessages("chanop") relay = relays[0] - self.assertMessageEqual(relay, command='PRIVMSG', params=['#chan', 'hi again from baz']) + self.assertMessageEqual( + relay, command="PRIVMSG", params=["#chan", "hi again from baz"] + ) class OpModerated(cases.BaseServerTestCase): - - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testOpModerated(self): # test the +U channel mode - self.connectClient('chanop', name='chanop', capabilities=MODERN_CAPS) - self.joinChannel('chanop', '#chan') - self.getMessages('chanop') - self.sendLine('chanop', 'MODE #chan +U') - replies = {msg.command for msg in self.getMessages('chanop')} - self.assertIn('MODE', replies) + self.connectClient("chanop", name="chanop", capabilities=MODERN_CAPS) + self.joinChannel("chanop", "#chan") + self.getMessages("chanop") + self.sendLine("chanop", "MODE #chan +U") + replies = {msg.command for msg in self.getMessages("chanop")} + self.assertIn("MODE", replies) self.assertNotIn(ERR_CHANOPRIVSNEEDED, replies) - self.connectClient('baz', name='baz', capabilities=MODERN_CAPS) - self.joinChannel('baz', '#chan') - self.sendLine('baz', 'PRIVMSG #chan :hi from baz') - echo = self.getMessages('baz')[0] - self.assertMessageEqual(echo, command='PRIVMSG', params=['#chan', 'hi from baz']) - self.assertEqual([msg for msg in self.getMessages('chanop') if msg.command == 'PRIVMSG'], [echo]) + self.connectClient("baz", name="baz", capabilities=MODERN_CAPS) + self.joinChannel("baz", "#chan") + self.sendLine("baz", "PRIVMSG #chan :hi from baz") + echo = self.getMessages("baz")[0] + self.assertMessageEqual( + echo, command="PRIVMSG", params=["#chan", "hi from baz"] + ) + self.assertEqual( + [msg for msg in self.getMessages("chanop") if msg.command == "PRIVMSG"], + [echo], + ) - self.connectClient('qux', name='qux', capabilities=MODERN_CAPS) - self.joinChannel('qux', '#chan') - self.sendLine('qux', 'PRIVMSG #chan :hi from qux') - echo = self.getMessages('qux')[0] - self.assertMessageEqual(echo, command='PRIVMSG', params=['#chan', 'hi from qux']) + self.connectClient("qux", name="qux", capabilities=MODERN_CAPS) + self.joinChannel("qux", "#chan") + self.sendLine("qux", "PRIVMSG #chan :hi from qux") + echo = self.getMessages("qux")[0] + self.assertMessageEqual( + echo, command="PRIVMSG", params=["#chan", "hi from qux"] + ) # message is relayed to chanop but not to unprivileged - self.assertEqual([msg for msg in self.getMessages('chanop') if msg.command == 'PRIVMSG'], [echo]) - self.assertEqual([msg for msg in self.getMessages('baz') if msg.command == 'PRIVMSG'], []) + self.assertEqual( + [msg for msg in self.getMessages("chanop") if msg.command == "PRIVMSG"], + [echo], + ) + self.assertEqual( + [msg for msg in self.getMessages("baz") if msg.command == "PRIVMSG"], [] + ) - self.sendLine('chanop', 'MODE #chan +v qux') - self.getMessages('chanop') - self.sendLine('qux', 'PRIVMSG #chan :hi again from qux') - echo = [msg for msg in self.getMessages('qux') if msg.command == 'PRIVMSG'][0] - self.assertMessageEqual(echo, command='PRIVMSG', params=['#chan', 'hi again from qux']) - self.assertEqual([msg for msg in self.getMessages('chanop') if msg.command == 'PRIVMSG'], [echo]) - self.assertEqual([msg for msg in self.getMessages('baz') if msg.command == 'PRIVMSG'], [echo]) + self.sendLine("chanop", "MODE #chan +v qux") + self.getMessages("chanop") + self.sendLine("qux", "PRIVMSG #chan :hi again from qux") + echo = [msg for msg in self.getMessages("qux") if msg.command == "PRIVMSG"][0] + self.assertMessageEqual( + echo, command="PRIVMSG", params=["#chan", "hi again from qux"] + ) + self.assertEqual( + [msg for msg in self.getMessages("chanop") if msg.command == "PRIVMSG"], + [echo], + ) + self.assertEqual( + [msg for msg in self.getMessages("baz") if msg.command == "PRIVMSG"], [echo] + ) class MuteExtban(cases.BaseServerTestCase): - - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testISupport(self): isupport = self.getISupport() - token = isupport['EXTBAN'] - prefix, comma, types = token.partition(',') - self.assertEqual(prefix, '') - self.assertEqual(comma, ',') - self.assertIn('m', types) + token = isupport["EXTBAN"] + prefix, comma, types = token.partition(",") + self.assertEqual(prefix, "") + self.assertEqual(comma, ",") + self.assertIn("m", types) - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testMuteExtban(self): - clients = ('chanop', 'bar', 'qux') + clients = ("chanop", "bar", "qux") - self.connectClient('chanop', name='chanop', capabilities=MODERN_CAPS) - self.joinChannel('chanop', '#chan') - self.getMessages('chanop') - self.sendLine('chanop', 'MODE #chan +b m:bar!*@*') - self.sendLine('chanop', 'MODE #chan +b m:qux!*@*') - replies = {msg.command for msg in self.getMessages('chanop')} - self.assertIn('MODE', replies) + self.connectClient("chanop", name="chanop", capabilities=MODERN_CAPS) + self.joinChannel("chanop", "#chan") + self.getMessages("chanop") + self.sendLine("chanop", "MODE #chan +b m:bar!*@*") + self.sendLine("chanop", "MODE #chan +b m:qux!*@*") + replies = {msg.command for msg in self.getMessages("chanop")} + self.assertIn("MODE", replies) self.assertNotIn(ERR_CHANOPRIVSNEEDED, replies) - self.connectClient('bar', name='bar', capabilities=MODERN_CAPS) - self.joinChannel('bar', '#chan') - self.connectClient('qux', name='qux', capabilities=MODERN_CAPS, ident='evan') - self.joinChannel('qux', '#chan') + self.connectClient("bar", name="bar", capabilities=MODERN_CAPS) + self.joinChannel("bar", "#chan") + self.connectClient("qux", name="qux", capabilities=MODERN_CAPS, ident="evan") + self.joinChannel("qux", "#chan") for client in clients: self.getMessages(client) - self.sendLine('bar', 'PRIVMSG #chan :hi from bar') - replies = self.getMessages('bar') + self.sendLine("bar", "PRIVMSG #chan :hi from bar") + replies = self.getMessages("bar") replies_cmds = {msg.command for msg in replies} - self.assertNotIn('PRIVMSG', replies_cmds) + self.assertNotIn("PRIVMSG", replies_cmds) self.assertIn(ERR_CANNOTSENDTOCHAN, replies_cmds) - self.assertEqual(self.getMessages('chanop'), []) + self.assertEqual(self.getMessages("chanop"), []) - self.sendLine('qux', 'PRIVMSG #chan :hi from qux') - replies = self.getMessages('qux') + self.sendLine("qux", "PRIVMSG #chan :hi from qux") + replies = self.getMessages("qux") replies_cmds = {msg.command for msg in replies} - self.assertNotIn('PRIVMSG', replies_cmds) + self.assertNotIn("PRIVMSG", replies_cmds) self.assertIn(ERR_CANNOTSENDTOCHAN, replies_cmds) - self.assertEqual(self.getMessages('chanop'), []) + self.assertEqual(self.getMessages("chanop"), []) # remove mute with -b - self.sendLine('chanop', 'MODE #chan -b m:bar!*@*') - self.getMessages('chanop') - self.sendLine('bar', 'PRIVMSG #chan :hi again from bar') - replies = self.getMessages('bar') + self.sendLine("chanop", "MODE #chan -b m:bar!*@*") + self.getMessages("chanop") + self.sendLine("bar", "PRIVMSG #chan :hi again from bar") + replies = self.getMessages("bar") replies_cmds = {msg.command for msg in replies} - self.assertIn('PRIVMSG', replies_cmds) + self.assertIn("PRIVMSG", replies_cmds) self.assertNotIn(ERR_CANNOTSENDTOCHAN, replies_cmds) - self.assertEqual(self.getMessages('chanop'), [msg for msg in replies if msg.command == 'PRIVMSG']) + self.assertEqual( + self.getMessages("chanop"), + [msg for msg in replies if msg.command == "PRIVMSG"], + ) for client in clients: self.getMessages(client) # +v grants an exemption to +b - self.sendLine('chanop', 'MODE #chan +v qux') - self.getMessages('chanop') - self.sendLine('qux', 'PRIVMSG #chan :hi again from qux') - replies = self.getMessages('qux') + self.sendLine("chanop", "MODE #chan +v qux") + self.getMessages("chanop") + self.sendLine("qux", "PRIVMSG #chan :hi again from qux") + replies = self.getMessages("qux") replies_cmds = {msg.command for msg in replies} - self.assertIn('PRIVMSG', replies_cmds) + self.assertIn("PRIVMSG", replies_cmds) self.assertNotIn(ERR_CANNOTSENDTOCHAN, replies_cmds) - self.assertEqual(self.getMessages('chanop'), [msg for msg in replies if msg.command == 'PRIVMSG']) + self.assertEqual( + self.getMessages("chanop"), + [msg for msg in replies if msg.command == "PRIVMSG"], + ) - self.sendLine('qux', 'PART #chan') - self.sendLine('qux', 'JOIN #chan') - self.getMessages('qux') - self.sendLine('chanop', 'MODE #chan +e m:*!~evan@*') - self.getMessages('chanop') + self.sendLine("qux", "PART #chan") + self.sendLine("qux", "JOIN #chan") + self.getMessages("qux") + self.sendLine("chanop", "MODE #chan +e m:*!~evan@*") + self.getMessages("chanop") # +e grants an exemption to +b - self.sendLine('qux', 'PRIVMSG #chan :thanks for mute-excepting me') - replies = self.getMessages('qux') + self.sendLine("qux", "PRIVMSG #chan :thanks for mute-excepting me") + replies = self.getMessages("qux") replies_cmds = {msg.command for msg in replies} - self.assertIn('PRIVMSG', replies_cmds) + self.assertIn("PRIVMSG", replies_cmds) self.assertNotIn(ERR_CANNOTSENDTOCHAN, replies_cmds) - self.assertEqual(self.getMessages('chanop'), [msg for msg in replies if msg.command == 'PRIVMSG']) + self.assertEqual( + self.getMessages("chanop"), + [msg for msg in replies if msg.command == "PRIVMSG"], + ) - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testIssue1370(self): # regression test for oragono #1370: mutes not correctly enforced against # users with capital letters in their NUH - clients = ('chanop', 'bar') + clients = ("chanop", "bar") - self.connectClient('chanop', name='chanop', capabilities=MODERN_CAPS) - self.joinChannel('chanop', '#chan') - self.getMessages('chanop') - self.sendLine('chanop', 'MODE #chan +b m:BAR!*@*') - replies = {msg.command for msg in self.getMessages('chanop')} - self.assertIn('MODE', replies) + self.connectClient("chanop", name="chanop", capabilities=MODERN_CAPS) + self.joinChannel("chanop", "#chan") + self.getMessages("chanop") + self.sendLine("chanop", "MODE #chan +b m:BAR!*@*") + replies = {msg.command for msg in self.getMessages("chanop")} + self.assertIn("MODE", replies) self.assertNotIn(ERR_CHANOPRIVSNEEDED, replies) - self.connectClient('Bar', name='bar', capabilities=MODERN_CAPS) - self.joinChannel('bar', '#chan') + self.connectClient("Bar", name="bar", capabilities=MODERN_CAPS) + self.joinChannel("bar", "#chan") for client in clients: self.getMessages(client) - self.sendLine('bar', 'PRIVMSG #chan :hi from bar') - replies = self.getMessages('bar') + self.sendLine("bar", "PRIVMSG #chan :hi from bar") + replies = self.getMessages("bar") replies_cmds = {msg.command for msg in replies} - self.assertNotIn('PRIVMSG', replies_cmds) + self.assertNotIn("PRIVMSG", replies_cmds) self.assertIn(ERR_CANNOTSENDTOCHAN, replies_cmds) - self.assertEqual(self.getMessages('chanop'), []) + self.assertEqual(self.getMessages("chanop"), []) # remove mute with -b - self.sendLine('chanop', 'MODE #chan -b m:bar!*@*') - self.getMessages('chanop') - self.sendLine('bar', 'PRIVMSG #chan :hi again from bar') - replies = self.getMessages('bar') + self.sendLine("chanop", "MODE #chan -b m:bar!*@*") + self.getMessages("chanop") + self.sendLine("bar", "PRIVMSG #chan :hi again from bar") + replies = self.getMessages("bar") replies_cmds = {msg.command for msg in replies} - self.assertIn('PRIVMSG', replies_cmds) + self.assertIn("PRIVMSG", replies_cmds) self.assertNotIn(ERR_CANNOTSENDTOCHAN, replies_cmds) - self.assertEqual(self.getMessages('chanop'), [msg for msg in replies if msg.command == 'PRIVMSG']) + self.assertEqual( + self.getMessages("chanop"), + [msg for msg in replies if msg.command == "PRIVMSG"], + ) diff --git a/irctest/server_tests/test_channel_rename.py b/irctest/server_tests/test_channel_rename.py index 724d869..fd673ab 100644 --- a/irctest/server_tests/test_channel_rename.py +++ b/irctest/server_tests/test_channel_rename.py @@ -1,27 +1,59 @@ from irctest import cases from irctest.numerics import ERR_CHANOPRIVSNEEDED -MODERN_CAPS = ['server-time', 'message-tags', 'batch', 'labeled-response', 'echo-message', 'account-tag'] -RENAME_CAP = 'draft/channel-rename' +MODERN_CAPS = [ + "server-time", + "message-tags", + "batch", + "labeled-response", + "echo-message", + "account-tag", +] +RENAME_CAP = "draft/channel-rename" + class ChannelRename(cases.BaseServerTestCase): """Basic tests for channel-rename.""" - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testChannelRename(self): - self.connectClient('bar', name='bar', capabilities=MODERN_CAPS+[RENAME_CAP]) - self.connectClient('baz', name='baz', capabilities=MODERN_CAPS) - self.joinChannel('bar', '#bar') - self.joinChannel('baz', '#bar') - self.getMessages('bar') - self.getMessages('baz') + self.connectClient("bar", name="bar", capabilities=MODERN_CAPS + [RENAME_CAP]) + self.connectClient("baz", name="baz", capabilities=MODERN_CAPS) + self.joinChannel("bar", "#bar") + self.joinChannel("baz", "#bar") + self.getMessages("bar") + self.getMessages("baz") - self.sendLine('bar', 'RENAME #bar #qux :no reason') - self.assertMessageEqual(self.getMessage('bar'), command='RENAME', params=['#bar', '#qux', 'no reason']) - legacy_responses = self.getMessages('baz') - self.assertEqual(1, len([msg for msg in legacy_responses if msg.command == 'PART' and msg.params[0] == '#bar'])) - self.assertEqual(1, len([msg for msg in legacy_responses if msg.command == 'JOIN' and msg.params == ['#qux']])) + self.sendLine("bar", "RENAME #bar #qux :no reason") + self.assertMessageEqual( + self.getMessage("bar"), + command="RENAME", + params=["#bar", "#qux", "no reason"], + ) + legacy_responses = self.getMessages("baz") + self.assertEqual( + 1, + len( + [ + msg + for msg in legacy_responses + if msg.command == "PART" and msg.params[0] == "#bar" + ] + ), + ) + self.assertEqual( + 1, + len( + [ + msg + for msg in legacy_responses + if msg.command == "JOIN" and msg.params == ["#qux"] + ] + ), + ) - self.joinChannel('baz', '#bar') - self.sendLine('baz', 'MODE #bar +k beer') - self.assertNotIn(ERR_CHANOPRIVSNEEDED, [msg.command for msg in self.getMessages('baz')]) + self.joinChannel("baz", "#bar") + self.sendLine("baz", "MODE #bar +k beer") + self.assertNotIn( + ERR_CHANOPRIVSNEEDED, [msg.command for msg in self.getMessages("baz")] + ) diff --git a/irctest/server_tests/test_chathistory.py b/irctest/server_tests/test_chathistory.py index f1e5c23..61b7599 100644 --- a/irctest/server_tests/test_chathistory.py +++ b/irctest/server_tests/test_chathistory.py @@ -4,12 +4,13 @@ import time from irctest import cases from irctest.irc_utils.junkdrawer import to_history_message, random_name -CHATHISTORY_CAP = 'draft/chathistory' -EVENT_PLAYBACK_CAP = 'draft/event-playback' +CHATHISTORY_CAP = "draft/chathistory" +EVENT_PLAYBACK_CAP = "draft/event-playback" MYSQL_PASSWORD = "" + def validate_chathistory_batch(msgs): batch_tag = None closed_batch_tag = None @@ -17,91 +18,120 @@ def validate_chathistory_batch(msgs): for msg in msgs: if msg.command == "BATCH": batch_param = msg.params[0] - if batch_tag is None and batch_param[0] == '+': + if batch_tag is None and batch_param[0] == "+": batch_tag = batch_param[1:] - elif batch_param[0] == '-': + elif batch_param[0] == "-": closed_batch_tag = batch_param[1:] - elif msg.command == "PRIVMSG" and batch_tag is not None and msg.tags.get("batch") == batch_tag: + elif ( + msg.command == "PRIVMSG" + and batch_tag is not None + and msg.tags.get("batch") == batch_tag + ): result.append(to_history_message(msg)) assert batch_tag == closed_batch_tag return result + class ChathistoryTestCase(cases.BaseServerTestCase): @staticmethod def config(): return { - "chathistory": True, - } + "chathistory": True, + } - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testInvalidTargets(self): - bar, pw = random_name('bar'), random_name('pw') + bar, pw = random_name("bar"), random_name("pw") self.controller.registerUser(self, bar, pw) - self.connectClient(bar, name=bar, capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP], password=pw) + self.connectClient( + bar, + name=bar, + capabilities=[ + "batch", + "labeled-response", + "message-tags", + "server-time", + CHATHISTORY_CAP, + EVENT_PLAYBACK_CAP, + ], + password=pw, + ) self.getMessages(bar) - qux = random_name('qux') - real_chname = random_name('#real_channel') + qux = random_name("qux") + real_chname = random_name("#real_channel") self.connectClient(qux, name=qux) self.joinChannel(qux, real_chname) self.getMessages(qux) # test a nonexistent channel - self.sendLine(bar, 'CHATHISTORY LATEST #nonexistent_channel * 10') + self.sendLine(bar, "CHATHISTORY LATEST #nonexistent_channel * 10") msgs = self.getMessages(bar) - self.assertEqual(msgs[0].command, 'FAIL') - self.assertEqual(msgs[0].params[:2], ['CHATHISTORY', 'INVALID_TARGET']) + self.assertEqual(msgs[0].command, "FAIL") + self.assertEqual(msgs[0].params[:2], ["CHATHISTORY", "INVALID_TARGET"]) # as should a real channel to which one is not joined: - self.sendLine(bar, 'CHATHISTORY LATEST %s * 10' % (real_chname,)) + self.sendLine(bar, "CHATHISTORY LATEST %s * 10" % (real_chname,)) msgs = self.getMessages(bar) - self.assertEqual(msgs[0].command, 'FAIL') - self.assertEqual(msgs[0].params[:2], ['CHATHISTORY', 'INVALID_TARGET']) + self.assertEqual(msgs[0].command, "FAIL") + self.assertEqual(msgs[0].params[:2], ["CHATHISTORY", "INVALID_TARGET"]) - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testMessagesToSelf(self): - bar, pw = random_name('bar'), random_name('pw') + bar, pw = random_name("bar"), random_name("pw") self.controller.registerUser(self, bar, pw) - self.connectClient(bar, name=bar, capabilities=['batch', 'labeled-response', 'message-tags', 'server-time'], password=pw) + self.connectClient( + bar, + name=bar, + capabilities=["batch", "labeled-response", "message-tags", "server-time"], + password=pw, + ) self.getMessages(bar) messages = [] - self.sendLine(bar, 'PRIVMSG %s :this is a privmsg sent to myself' % (bar,)) - replies = [msg for msg in self.getMessages(bar) if msg.command == 'PRIVMSG'] + self.sendLine(bar, "PRIVMSG %s :this is a privmsg sent to myself" % (bar,)) + replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"] self.assertEqual(len(replies), 1) msg = replies[0] - self.assertEqual(msg.params, [bar, 'this is a privmsg sent to myself']) + self.assertEqual(msg.params, [bar, "this is a privmsg sent to myself"]) messages.append(to_history_message(msg)) - self.sendLine(bar, 'CAP REQ echo-message') + self.sendLine(bar, "CAP REQ echo-message") self.getMessages(bar) - self.sendLine(bar, 'PRIVMSG %s :this is a second privmsg sent to myself' % (bar,)) - replies = [msg for msg in self.getMessages(bar) if msg.command == 'PRIVMSG'] + self.sendLine( + bar, "PRIVMSG %s :this is a second privmsg sent to myself" % (bar,) + ) + replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"] # two messages, the echo and the delivery self.assertEqual(len(replies), 2) - self.assertEqual(replies[0].params, [bar, 'this is a second privmsg sent to myself']) + self.assertEqual( + replies[0].params, [bar, "this is a second privmsg sent to myself"] + ) messages.append(to_history_message(replies[0])) # messages should be otherwise identical self.assertEqual(to_history_message(replies[0]), to_history_message(replies[1])) - self.sendLine(bar, '@label=xyz PRIVMSG %s :this is a third privmsg sent to myself' % (bar,)) - replies = [msg for msg in self.getMessages(bar) if msg.command == 'PRIVMSG'] + self.sendLine( + bar, + "@label=xyz PRIVMSG %s :this is a third privmsg sent to myself" % (bar,), + ) + replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"] self.assertEqual(len(replies), 2) # exactly one of the replies MUST be labeled - echo = [msg for msg in replies if msg.tags.get('label') == 'xyz'][0] - delivery = [msg for msg in replies if msg.tags.get('label') is None][0] - self.assertEqual(echo.params, [bar, 'this is a third privmsg sent to myself']) + echo = [msg for msg in replies if msg.tags.get("label") == "xyz"][0] + delivery = [msg for msg in replies if msg.tags.get("label") is None][0] + self.assertEqual(echo.params, [bar, "this is a third privmsg sent to myself"]) messages.append(to_history_message(echo)) self.assertEqual(to_history_message(echo), to_history_message(delivery)) # should receive exactly 3 messages in the correct order, no duplicates - self.sendLine(bar, 'CHATHISTORY LATEST * * 10') - replies = [msg for msg in self.getMessages(bar) if msg.command == 'PRIVMSG'] + self.sendLine(bar, "CHATHISTORY LATEST * * 10") + replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"] self.assertEqual([to_history_message(msg) for msg in replies], messages) - self.sendLine(bar, 'CHATHISTORY LATEST %s * 10' % (bar,)) - replies = [msg for msg in self.getMessages(bar) if msg.command == 'PRIVMSG'] + self.sendLine(bar, "CHATHISTORY LATEST %s * 10" % (bar,)) + replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"] self.assertEqual([to_history_message(msg) for msg in replies], messages) def validate_echo_messages(self, num_messages, echo_messages): @@ -111,31 +141,66 @@ class ChathistoryTestCase(cases.BaseServerTestCase): self.assertEqual(len(set(msg.msgid for msg in echo_messages)), num_messages) self.assertEqual(len(set(msg.time for msg in echo_messages)), num_messages) - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testChathistory(self): - self.connectClient('bar', capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP]) - chname = '#' + secrets.token_hex(12) + self.connectClient( + "bar", + capabilities=[ + "message-tags", + "server-time", + "echo-message", + "batch", + "labeled-response", + CHATHISTORY_CAP, + EVENT_PLAYBACK_CAP, + ], + ) + chname = "#" + secrets.token_hex(12) self.joinChannel(1, chname) self.getMessages(1) NUM_MESSAGES = 10 echo_messages = [] for i in range(NUM_MESSAGES): - self.sendLine(1, 'PRIVMSG %s :this is message %d' % (chname, i)) + self.sendLine(1, "PRIVMSG %s :this is message %d" % (chname, i)) echo_messages.extend(to_history_message(msg) for msg in self.getMessages(1)) time.sleep(0.002) self.validate_echo_messages(NUM_MESSAGES, echo_messages) self.validate_chathistory(echo_messages, 1, chname) - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testChathistoryDMs(self): c1 = secrets.token_hex(12) c2 = secrets.token_hex(12) - self.controller.registerUser(self, c1, 'sesame1') - self.controller.registerUser(self, c2, 'sesame2') - self.connectClient(c1, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP], password='sesame1') - self.connectClient(c2, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP], password='sesame2') + self.controller.registerUser(self, c1, "sesame1") + self.controller.registerUser(self, c2, "sesame2") + self.connectClient( + c1, + capabilities=[ + "message-tags", + "server-time", + "echo-message", + "batch", + "labeled-response", + CHATHISTORY_CAP, + EVENT_PLAYBACK_CAP, + ], + password="sesame1", + ) + self.connectClient( + c2, + capabilities=[ + "message-tags", + "server-time", + "echo-message", + "batch", + "labeled-response", + CHATHISTORY_CAP, + EVENT_PLAYBACK_CAP, + ], + password="sesame2", + ) self.getMessages(1) self.getMessages(2) @@ -148,29 +213,60 @@ class ChathistoryTestCase(cases.BaseServerTestCase): else: target = c1 self.getMessages(user) - self.sendLine(user, 'PRIVMSG %s :this is message %d' % (target, i)) - echo_messages.extend(to_history_message(msg) for msg in self.getMessages(user)) + self.sendLine(user, "PRIVMSG %s :this is message %d" % (target, i)) + echo_messages.extend( + to_history_message(msg) for msg in self.getMessages(user) + ) time.sleep(0.002) self.validate_echo_messages(NUM_MESSAGES, echo_messages) self.validate_chathistory(echo_messages, 1, c2) - self.validate_chathistory(echo_messages, 1, '*') + self.validate_chathistory(echo_messages, 1, "*") self.validate_chathistory(echo_messages, 2, c1) - self.validate_chathistory(echo_messages, 2, '*') + self.validate_chathistory(echo_messages, 2, "*") c3 = secrets.token_hex(12) - self.connectClient(c3, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP]) - self.sendLine(1, 'PRIVMSG %s :this is a message in a separate conversation' % (c3,)) + self.connectClient( + c3, + capabilities=[ + "message-tags", + "server-time", + "echo-message", + "batch", + "labeled-response", + CHATHISTORY_CAP, + EVENT_PLAYBACK_CAP, + ], + ) + self.sendLine( + 1, "PRIVMSG %s :this is a message in a separate conversation" % (c3,) + ) self.getMessages(1) - self.sendLine(3, 'PRIVMSG %s :i agree that this is a separate conversation' % (c1,)) + self.sendLine( + 3, "PRIVMSG %s :i agree that this is a separate conversation" % (c1,) + ) # 3 received the first message as a delivery and the second as an echo - new_convo = [to_history_message(msg) for msg in self.getMessages(3) if msg.command == 'PRIVMSG'] - self.assertEqual([msg.text for msg in new_convo], ['this is a message in a separate conversation', 'i agree that this is a separate conversation']) + new_convo = [ + to_history_message(msg) + for msg in self.getMessages(3) + if msg.command == "PRIVMSG" + ] + self.assertEqual( + [msg.text for msg in new_convo], + [ + "this is a message in a separate conversation", + "i agree that this is a separate conversation", + ], + ) # messages should be stored and retrievable by c1, even though c3 is not registered self.getMessages(1) - self.sendLine(1, 'CHATHISTORY LATEST %s * 10' % (c3,)) - results = [to_history_message(msg) for msg in self.getMessages(1) if msg.command == 'PRIVMSG'] + self.sendLine(1, "CHATHISTORY LATEST %s * 10" % (c3,)) + results = [ + to_history_message(msg) + for msg in self.getMessages(1) + if msg.command == "PRIVMSG" + ] self.assertEqual(results, new_convo) # additional messages with c3 should not show up in the c1-c2 history: @@ -179,14 +275,31 @@ class ChathistoryTestCase(cases.BaseServerTestCase): self.validate_chathistory(echo_messages, 2, c1.upper()) # regression test for #833 - self.sendLine(3, 'QUIT') + self.sendLine(3, "QUIT") self.assertDisconnected(3) # register c3 as an account, then attempt to retrieve the conversation history with c1 - self.controller.registerUser(self, c3, 'sesame3') - self.connectClient(c3, name=c3, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP], password='sesame3') + self.controller.registerUser(self, c3, "sesame3") + self.connectClient( + c3, + name=c3, + capabilities=[ + "message-tags", + "server-time", + "echo-message", + "batch", + "labeled-response", + CHATHISTORY_CAP, + EVENT_PLAYBACK_CAP, + ], + password="sesame3", + ) self.getMessages(c3) - self.sendLine(c3, 'CHATHISTORY LATEST %s * 10' % (c1,)) - results = [to_history_message(msg) for msg in self.getMessages(c3) if msg.command == 'PRIVMSG'] + self.sendLine(c3, "CHATHISTORY LATEST %s * 10" % (c1,)) + results = [ + to_history_message(msg) + for msg in self.getMessages(c3) + if msg.command == "PRIVMSG" + ] # should get nothing self.assertEqual(results, []) @@ -205,105 +318,213 @@ class ChathistoryTestCase(cases.BaseServerTestCase): result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[-1:], result) - self.sendLine(user, "CHATHISTORY LATEST %s msgid=%s %d" % (chname, echo_messages[4].msgid, INCLUSIVE_LIMIT)) + self.sendLine( + user, + "CHATHISTORY LATEST %s msgid=%s %d" + % (chname, echo_messages[4].msgid, INCLUSIVE_LIMIT), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[5:], result) - self.sendLine(user, "CHATHISTORY LATEST %s timestamp=%s %d" % (chname, echo_messages[4].time, INCLUSIVE_LIMIT)) + self.sendLine( + user, + "CHATHISTORY LATEST %s timestamp=%s %d" + % (chname, echo_messages[4].time, INCLUSIVE_LIMIT), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[5:], result) - self.sendLine(user, "CHATHISTORY BEFORE %s msgid=%s %d" % (chname, echo_messages[6].msgid, INCLUSIVE_LIMIT)) + self.sendLine( + user, + "CHATHISTORY BEFORE %s msgid=%s %d" + % (chname, echo_messages[6].msgid, INCLUSIVE_LIMIT), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[:6], result) - self.sendLine(user, "CHATHISTORY BEFORE %s timestamp=%s %d" % (chname, echo_messages[6].time, INCLUSIVE_LIMIT)) + self.sendLine( + user, + "CHATHISTORY BEFORE %s timestamp=%s %d" + % (chname, echo_messages[6].time, INCLUSIVE_LIMIT), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[:6], result) - self.sendLine(user, "CHATHISTORY BEFORE %s timestamp=%s %d" % (chname, echo_messages[6].time, 2)) + self.sendLine( + user, + "CHATHISTORY BEFORE %s timestamp=%s %d" + % (chname, echo_messages[6].time, 2), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[4:6], result) - self.sendLine(user, "CHATHISTORY AFTER %s msgid=%s %d" % (chname, echo_messages[3].msgid, INCLUSIVE_LIMIT)) + self.sendLine( + user, + "CHATHISTORY AFTER %s msgid=%s %d" + % (chname, echo_messages[3].msgid, INCLUSIVE_LIMIT), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[4:], result) - self.sendLine(user, "CHATHISTORY AFTER %s timestamp=%s %d" % (chname, echo_messages[3].time, INCLUSIVE_LIMIT)) + self.sendLine( + user, + "CHATHISTORY AFTER %s timestamp=%s %d" + % (chname, echo_messages[3].time, INCLUSIVE_LIMIT), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[4:], result) - self.sendLine(user, "CHATHISTORY AFTER %s timestamp=%s %d" % (chname, echo_messages[3].time, 3)) + self.sendLine( + user, + "CHATHISTORY AFTER %s timestamp=%s %d" % (chname, echo_messages[3].time, 3), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[4:7], result) # BETWEEN forwards and backwards - self.sendLine(user, "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" % (chname, echo_messages[0].msgid, echo_messages[-1].msgid, INCLUSIVE_LIMIT)) + self.sendLine( + user, + "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" + % ( + chname, + echo_messages[0].msgid, + echo_messages[-1].msgid, + INCLUSIVE_LIMIT, + ), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[1:-1], result) - self.sendLine(user, "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" % (chname, echo_messages[-1].msgid, echo_messages[0].msgid, INCLUSIVE_LIMIT)) + self.sendLine( + user, + "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" + % ( + chname, + echo_messages[-1].msgid, + echo_messages[0].msgid, + INCLUSIVE_LIMIT, + ), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[1:-1], result) # BETWEEN forwards and backwards with a limit, should get different results this time - self.sendLine(user, "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" % (chname, echo_messages[0].msgid, echo_messages[-1].msgid, 3)) + self.sendLine( + user, + "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" + % (chname, echo_messages[0].msgid, echo_messages[-1].msgid, 3), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[1:4], result) - self.sendLine(user, "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" % (chname, echo_messages[-1].msgid, echo_messages[0].msgid, 3)) + self.sendLine( + user, + "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" + % (chname, echo_messages[-1].msgid, echo_messages[0].msgid, 3), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[-4:-1], result) # same stuff again but with timestamps - self.sendLine(user, "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" % (chname, echo_messages[0].time, echo_messages[-1].time, INCLUSIVE_LIMIT)) + self.sendLine( + user, + "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" + % (chname, echo_messages[0].time, echo_messages[-1].time, INCLUSIVE_LIMIT), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[1:-1], result) - self.sendLine(user, "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" % (chname, echo_messages[-1].time, echo_messages[0].time, INCLUSIVE_LIMIT)) + self.sendLine( + user, + "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" + % (chname, echo_messages[-1].time, echo_messages[0].time, INCLUSIVE_LIMIT), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[1:-1], result) - self.sendLine(user, "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" % (chname, echo_messages[0].time, echo_messages[-1].time, 3)) + self.sendLine( + user, + "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" + % (chname, echo_messages[0].time, echo_messages[-1].time, 3), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[1:4], result) - self.sendLine(user, "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" % (chname, echo_messages[-1].time, echo_messages[0].time, 3)) + self.sendLine( + user, + "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" + % (chname, echo_messages[-1].time, echo_messages[0].time, 3), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[-4:-1], result) # AROUND - self.sendLine(user, "CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 1)) + self.sendLine( + user, + "CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 1), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual([echo_messages[7]], result) - self.sendLine(user, "CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 3)) + self.sendLine( + user, + "CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 3), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertEqual(echo_messages[6:9], result) - self.sendLine(user, "CHATHISTORY AROUND %s timestamp=%s %d" % (chname, echo_messages[7].time, 3)) + self.sendLine( + user, + "CHATHISTORY AROUND %s timestamp=%s %d" + % (chname, echo_messages[7].time, 3), + ) result = validate_chathistory_batch(self.getMessages(user)) self.assertIn(echo_messages[7], result) - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testChathistoryTagmsg(self): c1 = secrets.token_hex(12) c2 = secrets.token_hex(12) - chname = '#' + secrets.token_hex(12) - self.controller.registerUser(self, c1, 'sesame1') - self.controller.registerUser(self, c2, 'sesame2') - self.connectClient(c1, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP], password='sesame1') - self.connectClient(c2, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP,], password='sesame2') + chname = "#" + secrets.token_hex(12) + self.controller.registerUser(self, c1, "sesame1") + self.controller.registerUser(self, c2, "sesame2") + self.connectClient( + c1, + capabilities=[ + "message-tags", + "server-time", + "echo-message", + "batch", + "labeled-response", + CHATHISTORY_CAP, + EVENT_PLAYBACK_CAP, + ], + password="sesame1", + ) + self.connectClient( + c2, + capabilities=[ + "message-tags", + "server-time", + "echo-message", + "batch", + "labeled-response", + CHATHISTORY_CAP, + ], + password="sesame2", + ) self.joinChannel(1, chname) self.joinChannel(2, chname) self.getMessages(1) self.getMessages(2) - self.sendLine(1, '@+client-only-tag-test=success;+draft/persist TAGMSG %s' % (chname,)) + self.sendLine( + 1, "@+client-only-tag-test=success;+draft/persist TAGMSG %s" % (chname,) + ) echo = self.getMessages(1)[0] - msgid = echo.tags['msgid'] + msgid = echo.tags["msgid"] def validate_tagmsg(msg, target, msgid): - self.assertEqual(msg.command, 'TAGMSG') - self.assertEqual(msg.tags['+client-only-tag-test'], 'success') - self.assertEqual(msg.tags['msgid'], msgid) + self.assertEqual(msg.command, "TAGMSG") + self.assertEqual(msg.tags["+client-only-tag-test"], "success") + self.assertEqual(msg.tags["msgid"], msgid) self.assertEqual(msg.params, [target]) validate_tagmsg(echo, chname, msgid) @@ -312,69 +533,104 @@ class ChathistoryTestCase(cases.BaseServerTestCase): self.assertEqual(len(relay), 1) validate_tagmsg(relay[0], chname, msgid) - self.sendLine(1, 'CHATHISTORY LATEST %s * 10' % (chname,)) - history_tagmsgs = [msg for msg in self.getMessages(1) if msg.command == 'TAGMSG'] + self.sendLine(1, "CHATHISTORY LATEST %s * 10" % (chname,)) + history_tagmsgs = [ + msg for msg in self.getMessages(1) if msg.command == "TAGMSG" + ] self.assertEqual(len(history_tagmsgs), 1) validate_tagmsg(history_tagmsgs[0], chname, msgid) # c2 doesn't have event-playback and MUST NOT receive replayed tagmsg - self.sendLine(2, 'CHATHISTORY LATEST %s * 10' % (chname,)) - history_tagmsgs = [msg for msg in self.getMessages(2) if msg.command == 'TAGMSG'] + self.sendLine(2, "CHATHISTORY LATEST %s * 10" % (chname,)) + history_tagmsgs = [ + msg for msg in self.getMessages(2) if msg.command == "TAGMSG" + ] self.assertEqual(len(history_tagmsgs), 0) # now try a DM - self.sendLine(1, '@+client-only-tag-test=success;+draft/persist TAGMSG %s' % (c2,)) + self.sendLine( + 1, "@+client-only-tag-test=success;+draft/persist TAGMSG %s" % (c2,) + ) echo = self.getMessages(1)[0] - msgid = echo.tags['msgid'] + msgid = echo.tags["msgid"] validate_tagmsg(echo, c2, msgid) relay = self.getMessages(2) self.assertEqual(len(relay), 1) validate_tagmsg(relay[0], c2, msgid) - self.sendLine(1, 'CHATHISTORY LATEST %s * 10' % (c2,)) - history_tagmsgs = [msg for msg in self.getMessages(1) if msg.command == 'TAGMSG'] + self.sendLine(1, "CHATHISTORY LATEST %s * 10" % (c2,)) + history_tagmsgs = [ + msg for msg in self.getMessages(1) if msg.command == "TAGMSG" + ] self.assertEqual(len(history_tagmsgs), 1) validate_tagmsg(history_tagmsgs[0], c2, msgid) # c2 doesn't have event-playback and MUST NOT receive replayed tagmsg - self.sendLine(2, 'CHATHISTORY LATEST %s * 10' % (c1,)) - history_tagmsgs = [msg for msg in self.getMessages(2) if msg.command == 'TAGMSG'] + self.sendLine(2, "CHATHISTORY LATEST %s * 10" % (c1,)) + history_tagmsgs = [ + msg for msg in self.getMessages(2) if msg.command == "TAGMSG" + ] self.assertEqual(len(history_tagmsgs), 0) - - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testChathistoryDMClientOnlyTags(self): # regression test for Oragono #1411 c1 = secrets.token_hex(12) c2 = secrets.token_hex(12) - self.controller.registerUser(self, c1, 'sesame1') - self.controller.registerUser(self, c2, 'sesame2') - self.connectClient(c1, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP], password='sesame1') - self.connectClient(c2, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP,], password='sesame2') + self.controller.registerUser(self, c1, "sesame1") + self.controller.registerUser(self, c2, "sesame2") + self.connectClient( + c1, + capabilities=[ + "message-tags", + "server-time", + "echo-message", + "batch", + "labeled-response", + CHATHISTORY_CAP, + EVENT_PLAYBACK_CAP, + ], + password="sesame1", + ) + self.connectClient( + c2, + capabilities=[ + "message-tags", + "server-time", + "echo-message", + "batch", + "labeled-response", + CHATHISTORY_CAP, + ], + password="sesame2", + ) self.getMessages(1) self.getMessages(2) echo_msgid = None - def validate_msg(msg): - self.assertEqual(msg.command, 'PRIVMSG') - self.assertEqual(msg.tags['+client-only-tag-test'], 'success') - self.assertEqual(msg.tags['msgid'], echo_msgid) - self.assertEqual(msg.params, [c2, 'hi']) - self.sendLine(1, '@+client-only-tag-test=success;+draft/persist PRIVMSG %s hi' % (c2,)) + def validate_msg(msg): + self.assertEqual(msg.command, "PRIVMSG") + self.assertEqual(msg.tags["+client-only-tag-test"], "success") + self.assertEqual(msg.tags["msgid"], echo_msgid) + self.assertEqual(msg.params, [c2, "hi"]) + + self.sendLine( + 1, "@+client-only-tag-test=success;+draft/persist PRIVMSG %s hi" % (c2,) + ) echo = self.getMessage(1) - echo_msgid = echo.tags['msgid'] + echo_msgid = echo.tags["msgid"] validate_msg(echo) relay = self.getMessage(2) validate_msg(relay) - self.sendLine(1, 'CHATHISTORY LATEST * * 10') - hist = [msg for msg in self.getMessages(1) if msg.command == 'PRIVMSG'] + self.sendLine(1, "CHATHISTORY LATEST * * 10") + hist = [msg for msg in self.getMessages(1) if msg.command == "PRIVMSG"] self.assertEqual(len(hist), 1) validate_msg(hist[0]) - self.sendLine(2, 'CHATHISTORY LATEST * * 10') - hist = [msg for msg in self.getMessages(2) if msg.command == 'PRIVMSG'] + self.sendLine(2, "CHATHISTORY LATEST * * 10") + hist = [msg for msg in self.getMessages(2) if msg.command == "PRIVMSG"] self.assertEqual(len(hist), 1) validate_msg(hist[0]) diff --git a/irctest/server_tests/test_confusables.py b/irctest/server_tests/test_confusables.py index 735ace1..0f25bf6 100644 --- a/irctest/server_tests/test_confusables.py +++ b/irctest/server_tests/test_confusables.py @@ -1,31 +1,32 @@ from irctest import cases from irctest.numerics import RPL_WELCOME, ERR_NICKNAMEINUSE + class ConfusablesTestCase(cases.BaseServerTestCase): @staticmethod def config(): return { - "oragono_config": lambda config: config['accounts'].update( - {'nick-reservation': {'enabled': True, 'method': 'strict'}} + "oragono_config": lambda config: config["accounts"].update( + {"nick-reservation": {"enabled": True, "method": "strict"}} ) } - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testConfusableNicks(self): - self.controller.registerUser(self, 'evan', 'sesame') + self.controller.registerUser(self, "evan", "sesame") self.addClient(1) # U+0435 in place of e: - self.sendLine(1, 'NICK еvan') - self.sendLine(1, 'USER a 0 * a') + self.sendLine(1, "NICK еvan") + self.sendLine(1, "USER a 0 * a") messages = self.getMessages(1) commands = set(msg.command for msg in messages) self.assertNotIn(RPL_WELCOME, commands) self.assertIn(ERR_NICKNAMEINUSE, commands) - self.connectClient('evan', name='evan', password='sesame') + self.connectClient("evan", name="evan", password="sesame") # should be able to switch to the confusable nick - self.sendLine('evan', 'NICK еvan') - messages = self.getMessages('evan') + self.sendLine("evan", "NICK еvan") + messages = self.getMessages("evan") commands = set(msg.command for msg in messages) - self.assertIn('NICK', commands) + self.assertIn("NICK", commands) diff --git a/irctest/server_tests/test_connection_registration.py b/irctest/server_tests/test_connection_registration.py index 1536102..40aa67d 100644 --- a/irctest/server_tests/test_connection_registration.py +++ b/irctest/server_tests/test_connection_registration.py @@ -6,39 +6,48 @@ Tests section 4.1 of RFC 1459. from irctest import cases from irctest.client_mock import ConnectionClosed + class PasswordedConnectionRegistrationTestCase(cases.BaseServerTestCase): - password = 'testpassword' - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + password = "testpassword" + + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testPassBeforeNickuser(self): self.addClient() - self.sendLine(1, 'PASS {}'.format(self.password)) - self.sendLine(1, 'NICK foo') - self.sendLine(1, 'USER username * * :Realname') + self.sendLine(1, "PASS {}".format(self.password)) + self.sendLine(1, "NICK foo") + self.sendLine(1, "USER username * * :Realname") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='001', - fail_msg='Did not get 001 after correct PASS+NICK+USER: {msg}') + self.assertMessageEqual( + m, + command="001", + fail_msg="Did not get 001 after correct PASS+NICK+USER: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testNoPassword(self): self.addClient() - self.sendLine(1, 'NICK foo') - self.sendLine(1, 'USER username * * :Realname') + self.sendLine(1, "NICK foo") + self.sendLine(1, "USER username * * :Realname") m = self.getRegistrationMessage(1) - self.assertNotEqual(m.command, '001', - msg='Got 001 after NICK+USER but missing PASS') + self.assertNotEqual( + m.command, "001", msg="Got 001 after NICK+USER but missing PASS" + ) - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testWrongPassword(self): self.addClient() - self.sendLine(1, 'PASS {}'.format(self.password + "garbage")) - self.sendLine(1, 'NICK foo') - self.sendLine(1, 'USER username * * :Realname') + self.sendLine(1, "PASS {}".format(self.password + "garbage")) + self.sendLine(1, "NICK foo") + self.sendLine(1, "USER username * * :Realname") m = self.getRegistrationMessage(1) - self.assertNotEqual(m.command, '001', - msg='Got 001 after NICK+USER but incorrect PASS') + self.assertNotEqual( + m.command, "001", msg="Got 001 after NICK+USER but incorrect PASS" + ) - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812', strict=True) + @cases.SpecificationSelector.requiredBySpecification( + "RFC1459", "RFC2812", strict=True + ) def testPassAfterNickuser(self): """“The password can and must be set before any attempt to register the connection is made.” @@ -51,72 +60,77 @@ class PasswordedConnectionRegistrationTestCase(cases.BaseServerTestCase): -- """ self.addClient() - self.sendLine(1, 'NICK foo') - self.sendLine(1, 'USER username * * :Realname') - self.sendLine(1, 'PASS {}'.format(self.password)) + self.sendLine(1, "NICK foo") + self.sendLine(1, "USER username * * :Realname") + self.sendLine(1, "PASS {}".format(self.password)) m = self.getRegistrationMessage(1) - self.assertNotEqual(m.command, '001', - 'Got 001 after PASS sent after NICK+USER') + self.assertNotEqual(m.command, "001", "Got 001 after PASS sent after NICK+USER") + class ConnectionRegistrationTestCase(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('RFC1459') + @cases.SpecificationSelector.requiredBySpecification("RFC1459") def testQuitDisconnects(self): """“The server must close the connection to a client which sends a QUIT message.” -- """ - self.connectClient('foo') + self.connectClient("foo") self.getMessages(1) - self.sendLine(1, 'QUIT') + self.sendLine(1, "QUIT") with self.assertRaises(ConnectionClosed): - self.getMessages(1) # Fetch remaining messages + self.getMessages(1) # Fetch remaining messages self.getMessages(1) - @cases.SpecificationSelector.requiredBySpecification('RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testQuitErrors(self): """“A client session is terminated with a quit message. The server acknowledges this by sending an ERROR message to the client.” -- """ - self.connectClient('foo') + self.connectClient("foo") self.getMessages(1) - self.sendLine(1, 'QUIT') + self.sendLine(1, "QUIT") try: commands = {m.command for m in self.getMessages(1)} except ConnectionClosed: - assert False, 'Connection closed without ERROR.' - self.assertIn('ERROR', commands, - fail_msg='Did not receive ERROR as a reply to QUIT.') - + assert False, "Connection closed without ERROR." + self.assertIn( + "ERROR", commands, fail_msg="Did not receive ERROR as a reply to QUIT." + ) def testNickCollision(self): """A user connects and requests the same nickname as an already registered user. """ - self.connectClient('foo') + self.connectClient("foo") self.addClient() - self.sendLine(2, 'NICK foo') - self.sendLine(2, 'USER username * * :Realname') + self.sendLine(2, "NICK foo") + self.sendLine(2, "USER username * * :Realname") m = self.getRegistrationMessage(2) - self.assertNotEqual(m.command, '001', - 'Received 001 after registering with the nick of a ' - 'registered user.') + self.assertNotEqual( + m.command, + "001", + "Received 001 after registering with the nick of a " "registered user.", + ) def testEarlyNickCollision(self): """Two users register simultaneously with the same nick.""" self.addClient() self.addClient() - self.sendLine(1, 'NICK foo') - self.sendLine(2, 'NICK foo') - self.sendLine(1, 'USER username * * :Realname') - self.sendLine(2, 'USER username * * :Realname') + self.sendLine(1, "NICK foo") + self.sendLine(2, "NICK foo") + self.sendLine(1, "USER username * * :Realname") + self.sendLine(2, "USER username * * :Realname") m1 = self.getRegistrationMessage(1) m2 = self.getRegistrationMessage(2) - self.assertNotEqual((m1.command, m2.command), ('001', '001'), - 'Two concurrently registering requesting the same nickname ' - 'both got 001.') + self.assertNotEqual( + (m1.command, m2.command), + ("001", "001"), + "Two concurrently registering requesting the same nickname " + "both got 001.", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1', 'IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1", "IRCv3.2") def testIrc301CapLs(self): """IRCv3.1: “The LS subcommand is used to list the capabilities supported by the server. The client should send an LS subcommand with @@ -128,24 +142,34 @@ class ConnectionRegistrationTestCase(cases.BaseServerTestCase): -- """ self.addClient() - self.sendLine(1, 'CAP LS') + self.sendLine(1, "CAP LS") m = self.getRegistrationMessage(1) - self.assertNotEqual(m.params[2], '*', m, - fail_msg='Server replied with multi-line CAP LS to a ' - '“CAP LS” (ie. IRCv3.1) request: {msg}') - self.assertFalse(any('=' in cap for cap in m.params[2].split()), - 'Server replied with a name-value capability in ' - 'CAP LS reply as a response to “CAP LS” (ie. IRCv3.1) ' - 'request: {}'.format(m)) + self.assertNotEqual( + m.params[2], + "*", + m, + fail_msg="Server replied with multi-line CAP LS to a " + "“CAP LS” (ie. IRCv3.1) request: {msg}", + ) + self.assertFalse( + any("=" in cap for cap in m.params[2].split()), + "Server replied with a name-value capability in " + "CAP LS reply as a response to “CAP LS” (ie. IRCv3.1) " + "request: {}".format(m), + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") def testEmptyCapList(self): """“If no capabilities are active, an empty parameter must be sent.” -- """ self.addClient() - self.sendLine(1, 'CAP LIST') + self.sendLine(1, "CAP LIST") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='CAP', params=['*', 'LIST', ''], - fail_msg='Sending “CAP LIST” as first message got a reply ' - 'that is not “CAP * LIST :”: {msg}') + self.assertMessageEqual( + m, + command="CAP", + params=["*", "LIST", ""], + fail_msg="Sending “CAP LIST” as first message got a reply " + "that is not “CAP * LIST :”: {msg}", + ) diff --git a/irctest/server_tests/test_echo_message.py b/irctest/server_tests/test_echo_message.py index 8de1b82..9d6442f 100644 --- a/irctest/server_tests/test_echo_message.py +++ b/irctest/server_tests/test_echo_message.py @@ -6,65 +6,93 @@ from irctest import cases from irctest.basecontrollers import NotImplementedByController from irctest.irc_utils.junkdrawer import random_name -class DMEchoMessageTestCase(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('Oragono') +class DMEchoMessageTestCase(cases.BaseServerTestCase): + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testDirectMessageEcho(self): - bar = random_name('bar') - self.connectClient(bar, name=bar, capabilities=['batch', 'labeled-response', 'echo-message', 'message-tags', 'server-time']) + bar = random_name("bar") + self.connectClient( + bar, + name=bar, + capabilities=[ + "batch", + "labeled-response", + "echo-message", + "message-tags", + "server-time", + ], + ) self.getMessages(bar) - qux = random_name('qux') - self.connectClient(qux, name=qux, capabilities=['batch', 'labeled-response', 'echo-message', 'message-tags', 'server-time']) + qux = random_name("qux") + self.connectClient( + qux, + name=qux, + capabilities=[ + "batch", + "labeled-response", + "echo-message", + "message-tags", + "server-time", + ], + ) self.getMessages(qux) - self.sendLine(bar, '@label=xyz;+example-client-tag=example-value PRIVMSG %s :hi there' % (qux,)) + self.sendLine( + bar, + "@label=xyz;+example-client-tag=example-value PRIVMSG %s :hi there" + % (qux,), + ) echo = self.getMessages(bar)[0] delivery = self.getMessages(qux)[0] - self.assertEqual(delivery.params, [qux, 'hi there']) + self.assertEqual(delivery.params, [qux, "hi there"]) self.assertEqual(delivery.params, echo.params) - self.assertEqual(delivery.tags['msgid'], echo.tags['msgid']) - self.assertEqual(echo.tags['label'], 'xyz') - self.assertEqual(delivery.tags['+example-client-tag'], 'example-value') - self.assertEqual(delivery.tags['+example-client-tag'], echo.tags['+example-client-tag']) + self.assertEqual(delivery.tags["msgid"], echo.tags["msgid"]) + self.assertEqual(echo.tags["label"], "xyz") + self.assertEqual(delivery.tags["+example-client-tag"], "example-value") + self.assertEqual( + delivery.tags["+example-client-tag"], echo.tags["+example-client-tag"] + ) + class EchoMessageTestCase(cases.BaseServerTestCase): def _testEchoMessage(command, solo, server_time): - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def f(self): - """ - """ + """""" self.addClient() - self.sendLine(1, 'CAP LS 302') + self.sendLine(1, "CAP LS 302") capabilities = self.getCapLs(1) - if 'echo-message' not in capabilities: - raise NotImplementedByController('echo-message') - if server_time and 'server-time' not in capabilities: - raise NotImplementedByController('server-time') + if "echo-message" not in capabilities: + raise NotImplementedByController("echo-message") + if server_time and "server-time" not in capabilities: + raise NotImplementedByController("server-time") # TODO: check also without this - self.sendLine(1, 'CAP REQ :echo-message{}'.format( - ' server-time' if server_time else '')) + self.sendLine( + 1, + "CAP REQ :echo-message{}".format(" server-time" if server_time else ""), + ) self.getRegistrationMessage(1) # TODO: Remove this one the trailing space issue is fixed in Charybdis # and Mammon: - #self.assertMessageEqual(m, command='CAP', + # self.assertMessageEqual(m, command='CAP', # params=['*', 'ACK', 'echo-message'] + # (['server-time'] if server_time else []), # fail_msg='Did not ACK advertised capabilities: {msg}') - self.sendLine(1, 'USER f * * :foo') - self.sendLine(1, 'NICK baz') - self.sendLine(1, 'CAP END') + self.sendLine(1, "USER f * * :foo") + self.sendLine(1, "NICK baz") + self.sendLine(1, "CAP END") self.skipToWelcome(1) self.getMessages(1) - self.sendLine(1, 'JOIN #chan') + self.sendLine(1, "JOIN #chan") if not solo: - capabilities = ['server-time'] if server_time else None - self.connectClient('qux', capabilities=capabilities) - self.sendLine(2, 'JOIN #chan') + capabilities = ["server-time"] if server_time else None + self.connectClient("qux", capabilities=capabilities) + self.sendLine(2, "JOIN #chan") # Synchronize and clean self.getMessages(1) @@ -72,30 +100,50 @@ class EchoMessageTestCase(cases.BaseServerTestCase): self.getMessages(2) self.getMessages(1) - self.sendLine(1, '{} #chan :hello everyone'.format(command)) + self.sendLine(1, "{} #chan :hello everyone".format(command)) m1 = self.getMessage(1) - self.assertMessageEqual(m1, command=command, - params=['#chan', 'hello everyone'], - fail_msg='Did not echo “{} #chan :hello everyone”: {msg}', - extra_format=(command,)) + self.assertMessageEqual( + m1, + command=command, + params=["#chan", "hello everyone"], + fail_msg="Did not echo “{} #chan :hello everyone”: {msg}", + extra_format=(command,), + ) if not solo: m2 = self.getMessage(2) - self.assertMessageEqual(m2, command=command, - params=['#chan', 'hello everyone'], - fail_msg='Did not propagate “{} #chan :hello everyone”: ' - 'after echoing it to the author: {msg}', - extra_format=(command,)) - self.assertEqual(m1.params, m2.params, - fail_msg='Parameters of forwarded and echoed ' - 'messages differ: {} {}', - extra_format=(m1, m2)) + self.assertMessageEqual( + m2, + command=command, + params=["#chan", "hello everyone"], + fail_msg="Did not propagate “{} #chan :hello everyone”: " + "after echoing it to the author: {msg}", + extra_format=(command,), + ) + self.assertEqual( + m1.params, + m2.params, + fail_msg="Parameters of forwarded and echoed " + "messages differ: {} {}", + extra_format=(m1, m2), + ) if server_time: - self.assertIn('time', m1.tags, fail_msg='Echoed message is missing server time: {}', extra_format=(m1,)) - self.assertIn('time', m2.tags, fail_msg='Forwarded message is missing server time: {}', extra_format=(m2,)) + self.assertIn( + "time", + m1.tags, + fail_msg="Echoed message is missing server time: {}", + extra_format=(m1,), + ) + self.assertIn( + "time", + m2.tags, + fail_msg="Forwarded message is missing server time: {}", + extra_format=(m2,), + ) + return f - testEchoMessagePrivmsgNoServerTime = _testEchoMessage('PRIVMSG', False, False) - testEchoMessagePrivmsgSolo = _testEchoMessage('PRIVMSG', True, True) - testEchoMessagePrivmsg = _testEchoMessage('PRIVMSG', False, True) - testEchoMessageNotice = _testEchoMessage('NOTICE', False, True) + testEchoMessagePrivmsgNoServerTime = _testEchoMessage("PRIVMSG", False, False) + testEchoMessagePrivmsgSolo = _testEchoMessage("PRIVMSG", True, True) + testEchoMessagePrivmsg = _testEchoMessage("PRIVMSG", False, True) + testEchoMessageNotice = _testEchoMessage("NOTICE", False, True) diff --git a/irctest/server_tests/test_extended_join.py b/irctest/server_tests/test_extended_join.py index 2f5c520..b8a660d 100644 --- a/irctest/server_tests/test_extended_join.py +++ b/irctest/server_tests/test_extended_join.py @@ -4,52 +4,64 @@ from irctest import cases + class MetadataTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): def connectRegisteredClient(self, nick): self.addClient() - self.sendLine(2, 'CAP LS 302') + self.sendLine(2, "CAP LS 302") capabilities = self.getCapLs(2) - assert 'sasl' in capabilities - self.sendLine(2, 'AUTHENTICATE PLAIN') - m = self.getMessage(2, filter_pred=lambda m:m.command != 'NOTICE') - self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'], - fail_msg='Sent “AUTHENTICATE PLAIN”, server should have ' - 'replied with “AUTHENTICATE +”, but instead sent: {msg}') - self.sendLine(2, 'AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=') - m = self.getMessage(2, filter_pred=lambda m:m.command != 'NOTICE') - self.assertMessageEqual(m, command='900', - fail_msg='Did not send 900 after correct SASL authentication.') - self.sendLine(2, 'USER f * * :Realname') - self.sendLine(2, 'NICK {}'.format(nick)) - self.sendLine(2, 'CAP END') + assert "sasl" in capabilities + self.sendLine(2, "AUTHENTICATE PLAIN") + m = self.getMessage(2, filter_pred=lambda m: m.command != "NOTICE") + self.assertMessageEqual( + m, + command="AUTHENTICATE", + params=["+"], + fail_msg="Sent “AUTHENTICATE PLAIN”, server should have " + "replied with “AUTHENTICATE +”, but instead sent: {msg}", + ) + self.sendLine(2, "AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=") + m = self.getMessage(2, filter_pred=lambda m: m.command != "NOTICE") + self.assertMessageEqual( + m, + command="900", + fail_msg="Did not send 900 after correct SASL authentication.", + ) + self.sendLine(2, "USER f * * :Realname") + self.sendLine(2, "NICK {}".format(nick)) + self.sendLine(2, "CAP END") self.skipToWelcome(2) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") def testNotLoggedIn(self): - self.connectClient('foo', capabilities=['extended-join'], - skip_if_cap_nak=True) - self.joinChannel(1, '#chan') - self.connectClient('bar') - self.joinChannel(2, '#chan') + self.connectClient("foo", capabilities=["extended-join"], skip_if_cap_nak=True) + self.joinChannel(1, "#chan") + self.connectClient("bar") + self.joinChannel(2, "#chan") m = self.getMessage(1) - self.assertMessageEqual(m, command='JOIN', - params=['#chan', '*', 'Realname'], - fail_msg='Expected “JOIN #chan * :Realname” after ' - 'unregistered user joined, got: {msg}') + self.assertMessageEqual( + m, + command="JOIN", + params=["#chan", "*", "Realname"], + fail_msg="Expected “JOIN #chan * :Realname” after " + "unregistered user joined, got: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') - @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") + @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testLoggedIn(self): - self.connectClient('foo', capabilities=['extended-join'], - skip_if_cap_nak=True) - self.joinChannel(1, '#chan') + self.connectClient("foo", capabilities=["extended-join"], skip_if_cap_nak=True) + self.joinChannel(1, "#chan") - self.controller.registerUser(self, 'jilles', 'sesame') - self.connectRegisteredClient('bar') - self.joinChannel(2, '#chan') + self.controller.registerUser(self, "jilles", "sesame") + self.connectRegisteredClient("bar") + self.joinChannel(2, "#chan") m = self.getMessage(1) - self.assertMessageEqual(m, command='JOIN', - params=['#chan', 'jilles', 'Realname'], - fail_msg='Expected “JOIN #chan * :Realname” after ' - 'nick “bar” logged in as “jilles” joined, got: {msg}') + self.assertMessageEqual( + m, + command="JOIN", + params=["#chan", "jilles", "Realname"], + fail_msg="Expected “JOIN #chan * :Realname” after " + "nick “bar” logged in as “jilles” joined, got: {msg}", + ) diff --git a/irctest/server_tests/test_labeled_responses.py b/irctest/server_tests/test_labeled_responses.py index 36e3b76..d40a9e5 100644 --- a/irctest/server_tests/test_labeled_responses.py +++ b/irctest/server_tests/test_labeled_responses.py @@ -6,240 +6,570 @@ import re from irctest import cases + class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testLabeledPrivmsgResponsesToMultipleClients(self): - self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "foo", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(1) - self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "bar", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(2) - self.connectClient('carl', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "carl", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(3) - self.connectClient('alice', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "alice", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(4) - self.sendLine(1, '@label=12345 PRIVMSG bar,carl,alice :hi') + self.sendLine(1, "@label=12345 PRIVMSG bar,carl,alice :hi") m = self.getMessage(1) m2 = self.getMessage(2) m3 = self.getMessage(3) m4 = self.getMessage(4) # ensure the label isn't sent to recipients - self.assertMessageEqual(m2, command='PRIVMSG', fail_msg='No PRIVMSG received by target 1 after sending one out') - self.assertNotIn('label', m2.tags, m2, fail_msg="When sending a PRIVMSG with a label, the target users shouldn't receive the label (only the sending user should): {msg}") - self.assertMessageEqual(m3, command='PRIVMSG', fail_msg='No PRIVMSG received by target 1 after sending one out') - self.assertNotIn('label', m3.tags, m3, fail_msg="When sending a PRIVMSG with a label, the target users shouldn't receive the label (only the sending user should): {msg}") - self.assertMessageEqual(m4, command='PRIVMSG', fail_msg='No PRIVMSG received by target 1 after sending one out') - self.assertNotIn('label', m4.tags, m4, fail_msg="When sending a PRIVMSG with a label, the target users shouldn't receive the label (only the sending user should): {msg}") + self.assertMessageEqual( + m2, + command="PRIVMSG", + fail_msg="No PRIVMSG received by target 1 after sending one out", + ) + self.assertNotIn( + "label", + m2.tags, + m2, + fail_msg="When sending a PRIVMSG with a label, the target users shouldn't receive the label (only the sending user should): {msg}", + ) + self.assertMessageEqual( + m3, + command="PRIVMSG", + fail_msg="No PRIVMSG received by target 1 after sending one out", + ) + self.assertNotIn( + "label", + m3.tags, + m3, + fail_msg="When sending a PRIVMSG with a label, the target users shouldn't receive the label (only the sending user should): {msg}", + ) + self.assertMessageEqual( + m4, + command="PRIVMSG", + fail_msg="No PRIVMSG received by target 1 after sending one out", + ) + self.assertNotIn( + "label", + m4.tags, + m4, + fail_msg="When sending a PRIVMSG with a label, the target users shouldn't receive the label (only the sending user should): {msg}", + ) - self.assertMessageEqual(m, command='BATCH', fail_msg='No BATCH echo received after sending one out') + self.assertMessageEqual( + m, command="BATCH", fail_msg="No BATCH echo received after sending one out" + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testLabeledPrivmsgResponsesToClient(self): - self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "foo", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(1) - self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "bar", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(2) - self.sendLine(1, '@label=12345 PRIVMSG bar :hi') + self.sendLine(1, "@label=12345 PRIVMSG bar :hi") m = self.getMessage(1) m2 = self.getMessage(2) # ensure the label isn't sent to recipient - self.assertMessageEqual(m2, command='PRIVMSG', fail_msg='No PRIVMSG received by the target after sending one out') - self.assertNotIn('label', m2.tags, m2, fail_msg="When sending a PRIVMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}") + self.assertMessageEqual( + m2, + command="PRIVMSG", + fail_msg="No PRIVMSG received by the target after sending one out", + ) + self.assertNotIn( + "label", + m2.tags, + m2, + fail_msg="When sending a PRIVMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}", + ) - self.assertMessageEqual(m, command='PRIVMSG', fail_msg='No PRIVMSG echo received after sending one out') - self.assertIn('label', m.tags, m, fail_msg="When sending a PRIVMSG with a label, the echo'd message didn't contain the label at all: {msg}") - self.assertEqual(m.tags['label'], '12345', m, fail_msg="Echo'd PRIVMSG to a client did not contain the same label we sent it with(should be '12345'): {msg}") + self.assertMessageEqual( + m, + command="PRIVMSG", + fail_msg="No PRIVMSG echo received after sending one out", + ) + self.assertIn( + "label", + m.tags, + m, + fail_msg="When sending a PRIVMSG with a label, the echo'd message didn't contain the label at all: {msg}", + ) + self.assertEqual( + m.tags["label"], + "12345", + m, + fail_msg="Echo'd PRIVMSG to a client did not contain the same label we sent it with(should be '12345'): {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testLabeledPrivmsgResponsesToChannel(self): - self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "foo", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(1) - self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "bar", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(2) # join channels - self.sendLine(1, 'JOIN #test') + self.sendLine(1, "JOIN #test") self.getMessages(1) - self.sendLine(2, 'JOIN #test') + self.sendLine(2, "JOIN #test") self.getMessages(2) self.getMessages(1) - self.sendLine(1, '@label=12345;+draft/reply=123;+draft/react=l😃l PRIVMSG #test :hi') + self.sendLine( + 1, "@label=12345;+draft/reply=123;+draft/react=l😃l PRIVMSG #test :hi" + ) ms = self.getMessage(1) mt = self.getMessage(2) # ensure the label isn't sent to recipient - self.assertMessageEqual(mt, command='PRIVMSG', fail_msg='No PRIVMSG received by the target after sending one out') - self.assertNotIn('label', mt.tags, mt, fail_msg="When sending a PRIVMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}") + self.assertMessageEqual( + mt, + command="PRIVMSG", + fail_msg="No PRIVMSG received by the target after sending one out", + ) + self.assertNotIn( + "label", + mt.tags, + mt, + fail_msg="When sending a PRIVMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}", + ) # ensure sender correctly receives msg - self.assertMessageEqual(ms, command='PRIVMSG', fail_msg="Got a message back that wasn't a PRIVMSG") - self.assertIn('label', ms.tags, ms, fail_msg="When sending a PRIVMSG with a label, the source user should receive the label but didn't: {msg}") - self.assertEqual(ms.tags['label'], '12345', ms, fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}") + self.assertMessageEqual( + ms, command="PRIVMSG", fail_msg="Got a message back that wasn't a PRIVMSG" + ) + self.assertIn( + "label", + ms.tags, + ms, + fail_msg="When sending a PRIVMSG with a label, the source user should receive the label but didn't: {msg}", + ) + self.assertEqual( + ms.tags["label"], + "12345", + ms, + fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testLabeledPrivmsgResponsesToSelf(self): - self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "foo", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(1) - self.sendLine(1, '@label=12345 PRIVMSG foo :hi') + self.sendLine(1, "@label=12345 PRIVMSG foo :hi") m1 = self.getMessage(1) m2 = self.getMessage(1) number_of_labels = 0 for m in [m1, m2]: - self.assertMessageEqual(m, command='PRIVMSG', fail_msg="Got a message back that wasn't a PRIVMSG") - if 'label' in m.tags: + self.assertMessageEqual( + m, + command="PRIVMSG", + fail_msg="Got a message back that wasn't a PRIVMSG", + ) + if "label" in m.tags: number_of_labels += 1 - self.assertEqual(m.tags['label'], '12345', m, fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}") - - self.assertEqual(number_of_labels, 1, m1, fail_msg="When sending a PRIVMSG to self with echo-message, we only expect one message to contain the label. Instead, {} messages had the label".format(number_of_labels)) + self.assertEqual( + m.tags["label"], + "12345", + m, + fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + self.assertEqual( + number_of_labels, + 1, + m1, + fail_msg="When sending a PRIVMSG to self with echo-message, we only expect one message to contain the label. Instead, {} messages had the label".format( + number_of_labels + ), + ) + + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testLabeledNoticeResponsesToClient(self): - self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "foo", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(1) - self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "bar", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(2) - self.sendLine(1, '@label=12345 NOTICE bar :hi') + self.sendLine(1, "@label=12345 NOTICE bar :hi") m = self.getMessage(1) m2 = self.getMessage(2) # ensure the label isn't sent to recipient - self.assertMessageEqual(m2, command='NOTICE', fail_msg='No NOTICE received by the target after sending one out') - self.assertNotIn('label', m2.tags, m2, fail_msg="When sending a NOTICE with a label, the target user shouldn't receive the label (only the sending user should): {msg}") + self.assertMessageEqual( + m2, + command="NOTICE", + fail_msg="No NOTICE received by the target after sending one out", + ) + self.assertNotIn( + "label", + m2.tags, + m2, + fail_msg="When sending a NOTICE with a label, the target user shouldn't receive the label (only the sending user should): {msg}", + ) - self.assertMessageEqual(m, command='NOTICE', fail_msg='No NOTICE echo received after sending one out') - self.assertIn('label', m.tags, m, fail_msg="When sending a NOTICE with a label, the echo'd message didn't contain the label at all: {msg}") - self.assertEqual(m.tags['label'], '12345', m, fail_msg="Echo'd NOTICE to a client did not contain the same label we sent it with(should be '12345'): {msg}") + self.assertMessageEqual( + m, + command="NOTICE", + fail_msg="No NOTICE echo received after sending one out", + ) + self.assertIn( + "label", + m.tags, + m, + fail_msg="When sending a NOTICE with a label, the echo'd message didn't contain the label at all: {msg}", + ) + self.assertEqual( + m.tags["label"], + "12345", + m, + fail_msg="Echo'd NOTICE to a client did not contain the same label we sent it with(should be '12345'): {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testLabeledNoticeResponsesToChannel(self): - self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "foo", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(1) - self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "bar", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(2) # join channels - self.sendLine(1, 'JOIN #test') + self.sendLine(1, "JOIN #test") self.getMessages(1) - self.sendLine(2, 'JOIN #test') + self.sendLine(2, "JOIN #test") self.getMessages(2) self.getMessages(1) - self.sendLine(1, '@label=12345;+draft/reply=123;+draft/react=l😃l NOTICE #test :hi') + self.sendLine( + 1, "@label=12345;+draft/reply=123;+draft/react=l😃l NOTICE #test :hi" + ) ms = self.getMessage(1) mt = self.getMessage(2) # ensure the label isn't sent to recipient - self.assertMessageEqual(mt, command='NOTICE', fail_msg='No NOTICE received by the target after sending one out') - self.assertNotIn('label', mt.tags, mt, fail_msg="When sending a NOTICE with a label, the target user shouldn't receive the label (only the sending user should): {msg}") + self.assertMessageEqual( + mt, + command="NOTICE", + fail_msg="No NOTICE received by the target after sending one out", + ) + self.assertNotIn( + "label", + mt.tags, + mt, + fail_msg="When sending a NOTICE with a label, the target user shouldn't receive the label (only the sending user should): {msg}", + ) # ensure sender correctly receives msg - self.assertMessageEqual(ms, command='NOTICE', fail_msg="Got a message back that wasn't a NOTICE") - self.assertIn('label', ms.tags, ms, fail_msg="When sending a NOTICE with a label, the source user should receive the label but didn't: {msg}") - self.assertEqual(ms.tags['label'], '12345', ms, fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}") + self.assertMessageEqual( + ms, command="NOTICE", fail_msg="Got a message back that wasn't a NOTICE" + ) + self.assertIn( + "label", + ms.tags, + ms, + fail_msg="When sending a NOTICE with a label, the source user should receive the label but didn't: {msg}", + ) + self.assertEqual( + ms.tags["label"], + "12345", + ms, + fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testLabeledNoticeResponsesToSelf(self): - self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) + self.connectClient( + "foo", + capabilities=["batch", "echo-message", "labeled-response"], + skip_if_cap_nak=True, + ) self.getMessages(1) - self.sendLine(1, '@label=12345 NOTICE foo :hi') + self.sendLine(1, "@label=12345 NOTICE foo :hi") m1 = self.getMessage(1) m2 = self.getMessage(1) number_of_labels = 0 for m in [m1, m2]: - self.assertMessageEqual(m, command='NOTICE', fail_msg="Got a message back that wasn't a NOTICE") - if 'label' in m.tags: + self.assertMessageEqual( + m, command="NOTICE", fail_msg="Got a message back that wasn't a NOTICE" + ) + if "label" in m.tags: number_of_labels += 1 - self.assertEqual(m.tags['label'], '12345', m, fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}") - - self.assertEqual(number_of_labels, 1, m1, fail_msg="When sending a NOTICE to self with echo-message, we only expect one message to contain the label. Instead, {} messages had the label".format(number_of_labels)) + self.assertEqual( + m.tags["label"], + "12345", + m, + fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + self.assertEqual( + number_of_labels, + 1, + m1, + fail_msg="When sending a NOTICE to self with echo-message, we only expect one message to contain the label. Instead, {} messages had the label".format( + number_of_labels + ), + ) + + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testLabeledTagMsgResponsesToClient(self): - self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response', 'message-tags'], skip_if_cap_nak=True) + self.connectClient( + "foo", + capabilities=["batch", "echo-message", "labeled-response", "message-tags"], + skip_if_cap_nak=True, + ) self.getMessages(1) - self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response', 'message-tags'], skip_if_cap_nak=True) + self.connectClient( + "bar", + capabilities=["batch", "echo-message", "labeled-response", "message-tags"], + skip_if_cap_nak=True, + ) self.getMessages(2) - self.sendLine(1, '@label=12345;+draft/reply=123;+draft/react=l😃l TAGMSG bar') + self.sendLine(1, "@label=12345;+draft/reply=123;+draft/react=l😃l TAGMSG bar") m = self.getMessage(1) m2 = self.getMessage(2) # ensure the label isn't sent to recipient - self.assertMessageEqual(m2, command='TAGMSG', fail_msg='No TAGMSG received by the target after sending one out') - self.assertNotIn('label', m2.tags, m2, fail_msg="When sending a TAGMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}") - self.assertIn('+draft/reply', m2.tags, m2, fail_msg="Reply tag wasn't present on the target user's TAGMSG: {msg}") - self.assertEqual(m2.tags['+draft/reply'], '123', m2, fail_msg="Reply tag wasn't the same on the target user's TAGMSG: {msg}") - self.assertIn('+draft/react', m2.tags, m2, fail_msg="React tag wasn't present on the target user's TAGMSG: {msg}") - self.assertEqual(m2.tags['+draft/react'], 'l😃l', m2, fail_msg="React tag wasn't the same on the target user's TAGMSG: {msg}") + self.assertMessageEqual( + m2, + command="TAGMSG", + fail_msg="No TAGMSG received by the target after sending one out", + ) + self.assertNotIn( + "label", + m2.tags, + m2, + fail_msg="When sending a TAGMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}", + ) + self.assertIn( + "+draft/reply", + m2.tags, + m2, + fail_msg="Reply tag wasn't present on the target user's TAGMSG: {msg}", + ) + self.assertEqual( + m2.tags["+draft/reply"], + "123", + m2, + fail_msg="Reply tag wasn't the same on the target user's TAGMSG: {msg}", + ) + self.assertIn( + "+draft/react", + m2.tags, + m2, + fail_msg="React tag wasn't present on the target user's TAGMSG: {msg}", + ) + self.assertEqual( + m2.tags["+draft/react"], + "l😃l", + m2, + fail_msg="React tag wasn't the same on the target user's TAGMSG: {msg}", + ) - self.assertMessageEqual(m, command='TAGMSG', fail_msg='No TAGMSG echo received after sending one out') - self.assertIn('label', m.tags, m, fail_msg="When sending a TAGMSG with a label, the echo'd message didn't contain the label at all: {msg}") - self.assertEqual(m.tags['label'], '12345', m, fail_msg="Echo'd TAGMSG to a client did not contain the same label we sent it with(should be '12345'): {msg}") - self.assertIn('+draft/reply', m.tags, m, fail_msg="Reply tag wasn't present on the source user's TAGMSG: {msg}") - self.assertEqual(m2.tags['+draft/reply'], '123', m, fail_msg="Reply tag wasn't the same on the source user's TAGMSG: {msg}") - self.assertIn('+draft/react', m.tags, m, fail_msg="React tag wasn't present on the source user's TAGMSG: {msg}") - self.assertEqual(m2.tags['+draft/react'], 'l😃l', m, fail_msg="React tag wasn't the same on the source user's TAGMSG: {msg}") + self.assertMessageEqual( + m, + command="TAGMSG", + fail_msg="No TAGMSG echo received after sending one out", + ) + self.assertIn( + "label", + m.tags, + m, + fail_msg="When sending a TAGMSG with a label, the echo'd message didn't contain the label at all: {msg}", + ) + self.assertEqual( + m.tags["label"], + "12345", + m, + fail_msg="Echo'd TAGMSG to a client did not contain the same label we sent it with(should be '12345'): {msg}", + ) + self.assertIn( + "+draft/reply", + m.tags, + m, + fail_msg="Reply tag wasn't present on the source user's TAGMSG: {msg}", + ) + self.assertEqual( + m2.tags["+draft/reply"], + "123", + m, + fail_msg="Reply tag wasn't the same on the source user's TAGMSG: {msg}", + ) + self.assertIn( + "+draft/react", + m.tags, + m, + fail_msg="React tag wasn't present on the source user's TAGMSG: {msg}", + ) + self.assertEqual( + m2.tags["+draft/react"], + "l😃l", + m, + fail_msg="React tag wasn't the same on the source user's TAGMSG: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testLabeledTagMsgResponsesToChannel(self): - self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response', 'message-tags'], skip_if_cap_nak=True) + self.connectClient( + "foo", + capabilities=["batch", "echo-message", "labeled-response", "message-tags"], + skip_if_cap_nak=True, + ) self.getMessages(1) - self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response', 'message-tags'], skip_if_cap_nak=True) + self.connectClient( + "bar", + capabilities=["batch", "echo-message", "labeled-response", "message-tags"], + skip_if_cap_nak=True, + ) self.getMessages(2) # join channels - self.sendLine(1, 'JOIN #test') + self.sendLine(1, "JOIN #test") self.getMessages(1) - self.sendLine(2, 'JOIN #test') + self.sendLine(2, "JOIN #test") self.getMessages(2) self.getMessages(1) - self.sendLine(1, '@label=12345;+draft/reply=123;+draft/react=l😃l TAGMSG #test') + self.sendLine(1, "@label=12345;+draft/reply=123;+draft/react=l😃l TAGMSG #test") ms = self.getMessage(1) mt = self.getMessage(2) # ensure the label isn't sent to recipient - self.assertMessageEqual(mt, command='TAGMSG', fail_msg='No TAGMSG received by the target after sending one out') - self.assertNotIn('label', mt.tags, mt, fail_msg="When sending a TAGMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}") + self.assertMessageEqual( + mt, + command="TAGMSG", + fail_msg="No TAGMSG received by the target after sending one out", + ) + self.assertNotIn( + "label", + mt.tags, + mt, + fail_msg="When sending a TAGMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}", + ) # ensure sender correctly receives msg - self.assertMessageEqual(ms, command='TAGMSG', fail_msg="Got a message back that wasn't a TAGMSG") - self.assertIn('label', ms.tags, ms, fail_msg="When sending a TAGMSG with a label, the source user should receive the label but didn't: {msg}") - self.assertEqual(ms.tags['label'], '12345', ms, fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}") + self.assertMessageEqual( + ms, command="TAGMSG", fail_msg="Got a message back that wasn't a TAGMSG" + ) + self.assertIn( + "label", + ms.tags, + ms, + fail_msg="When sending a TAGMSG with a label, the source user should receive the label but didn't: {msg}", + ) + self.assertEqual( + ms.tags["label"], + "12345", + ms, + fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testLabeledTagMsgResponsesToSelf(self): - self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response', 'message-tags'], skip_if_cap_nak=True) + self.connectClient( + "foo", + capabilities=["batch", "echo-message", "labeled-response", "message-tags"], + skip_if_cap_nak=True, + ) self.getMessages(1) - self.sendLine(1, '@label=12345;+draft/reply=123;+draft/react=l😃l TAGMSG foo') + self.sendLine(1, "@label=12345;+draft/reply=123;+draft/react=l😃l TAGMSG foo") m1 = self.getMessage(1) m2 = self.getMessage(1) number_of_labels = 0 for m in [m1, m2]: - self.assertMessageEqual(m, command='TAGMSG', fail_msg="Got a message back that wasn't a TAGMSG") - if 'label' in m.tags: + self.assertMessageEqual( + m, command="TAGMSG", fail_msg="Got a message back that wasn't a TAGMSG" + ) + if "label" in m.tags: number_of_labels += 1 - self.assertEqual(m.tags['label'], '12345', m, fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}") + self.assertEqual( + m.tags["label"], + "12345", + m, + fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}", + ) - self.assertEqual(number_of_labels, 1, m1, fail_msg="When sending a TAGMSG to self with echo-message, we only expect one message to contain the label. Instead, {} messages had the label".format(number_of_labels)) + self.assertEqual( + number_of_labels, + 1, + m1, + fail_msg="When sending a TAGMSG to self with echo-message, we only expect one message to contain the label. Instead, {} messages had the label".format( + number_of_labels + ), + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testBatchedJoinMessages(self): - self.connectClient('bar', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time'], skip_if_cap_nak=True) + self.connectClient( + "bar", + capabilities=["batch", "labeled-response", "message-tags", "server-time"], + skip_if_cap_nak=True, + ) self.getMessages(1) - self.sendLine(1, '@label=12345 JOIN #xyz') + self.sendLine(1, "@label=12345 JOIN #xyz") m = self.getMessages(1) # we expect at least join and names lines, which must be batched @@ -247,45 +577,57 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper # valid BATCH start line: batch_start = m[0] - self.assertMessageEqual(batch_start, command='BATCH') + self.assertMessageEqual(batch_start, command="BATCH") self.assertEqual(len(batch_start.params), 2) - self.assertTrue(batch_start.params[0].startswith('+'), 'batch start param must begin with +, got %s' % (batch_start.params[0],)) + self.assertTrue( + batch_start.params[0].startswith("+"), + "batch start param must begin with +, got %s" % (batch_start.params[0],), + ) batch_id = batch_start.params[0][1:] # batch id MUST be alphanumerics and hyphens - self.assertTrue(re.match(r'^[A-Za-z0-9\-]+$', batch_id) is not None, 'batch id must be alphanumerics and hyphens, got %r' % (batch_id,)) - self.assertEqual(batch_start.params[1], 'labeled-response') - self.assertEqual(batch_start.tags.get('label'), '12345') + self.assertTrue( + re.match(r"^[A-Za-z0-9\-]+$", batch_id) is not None, + "batch id must be alphanumerics and hyphens, got %r" % (batch_id,), + ) + self.assertEqual(batch_start.params[1], "labeled-response") + self.assertEqual(batch_start.tags.get("label"), "12345") # valid BATCH end line batch_end = m[-1] - self.assertMessageEqual(batch_end, command='BATCH', params=['-' + batch_id]) + self.assertMessageEqual(batch_end, command="BATCH", params=["-" + batch_id]) # messages must have the BATCH tag for message in m[1:-1]: - self.assertEqual(message.tags.get('batch'), batch_id) + self.assertEqual(message.tags.get("batch"), batch_id) - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testNoBatchForSingleMessage(self): - self.connectClient('bar', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time']) + self.connectClient( + "bar", + capabilities=["batch", "labeled-response", "message-tags", "server-time"], + ) self.getMessages(1) - self.sendLine(1, '@label=98765 PING adhoctestline') + self.sendLine(1, "@label=98765 PING adhoctestline") # no BATCH should be initiated for a one-line response, it should just be labeled ms = self.getMessages(1) self.assertEqual(len(ms), 1) m = ms[0] - self.assertEqual(m.command, 'PONG') - self.assertEqual(m.params[-1], 'adhoctestline') + self.assertEqual(m.command, "PONG") + self.assertEqual(m.params[-1], "adhoctestline") # check the label - self.assertEqual(m.tags.get('label'), '98765') + self.assertEqual(m.tags.get("label"), "98765") - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testEmptyBatchForNoResponse(self): - self.connectClient('bar', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time']) + self.connectClient( + "bar", + capabilities=["batch", "labeled-response", "message-tags", "server-time"], + ) self.getMessages(1) # PONG never receives a response - self.sendLine(1, '@label=98765 PONG adhoctestline') + self.sendLine(1, "@label=98765 PONG adhoctestline") # labeled-response: "Servers MUST respond with a labeled # `ACK` message when a client sends a labeled command that normally @@ -294,5 +636,5 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper self.assertEqual(len(ms), 1) ack = ms[0] - self.assertEqual(ack.command, 'ACK') - self.assertEqual(ack.tags.get('label'), '98765') + self.assertEqual(ack.command, "ACK") + self.assertEqual(ack.tags.get("label"), "98765") diff --git a/irctest/server_tests/test_message_tags.py b/irctest/server_tests/test_message_tags.py index 6c0f76b..0dc0bea 100644 --- a/irctest/server_tests/test_message_tags.py +++ b/irctest/server_tests/test_message_tags.py @@ -6,142 +6,143 @@ from irctest import cases from irctest.irc_utils.message_parser import parse_message from irctest.numerics import ERR_INPUTTOOLONG -class MessageTagsTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): - @cases.SpecificationSelector.requiredBySpecification('message-tags') +class MessageTagsTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): + @cases.SpecificationSelector.requiredBySpecification("message-tags") def testBasic(self): def getAllMessages(): - for name in ['alice', 'bob', 'carol', 'dave']: + for name in ["alice", "bob", "carol", "dave"]: self.getMessages(name) def assertNoTags(line): # tags start with '@', without tags we start with the prefix, # which begins with ':' - self.assertEqual(line[0], ':') + self.assertEqual(line[0], ":") msg = parse_message(line) self.assertEqual(msg.tags, {}) return msg self.connectClient( - 'alice', - name='alice', - capabilities=['message-tags'], - skip_if_cap_nak=True + "alice", name="alice", capabilities=["message-tags"], skip_if_cap_nak=True ) - self.joinChannel('alice', '#test') - self.connectClient('bob', name='bob', capabilities=['message-tags', 'echo-message']) - self.joinChannel('bob', '#test') - self.connectClient('carol', name='carol') - self.joinChannel('carol', '#test') - self.connectClient('dave', name='dave', capabilities=['server-time']) - self.joinChannel('dave', '#test') + self.joinChannel("alice", "#test") + self.connectClient( + "bob", name="bob", capabilities=["message-tags", "echo-message"] + ) + self.joinChannel("bob", "#test") + self.connectClient("carol", name="carol") + self.joinChannel("carol", "#test") + self.connectClient("dave", name="dave", capabilities=["server-time"]) + self.joinChannel("dave", "#test") getAllMessages() - self.sendLine('alice', '@+baz=bat;fizz=buzz PRIVMSG #test hi') - self.getMessages('alice') - bob_msg = self.getMessage('bob') - carol_line = self.getMessage('carol', raw=True) - self.assertMessageEqual(bob_msg, command='PRIVMSG', params=['#test', 'hi']) - self.assertEqual(bob_msg.tags['+baz'], "bat") - self.assertIn('msgid', bob_msg.tags) + self.sendLine("alice", "@+baz=bat;fizz=buzz PRIVMSG #test hi") + self.getMessages("alice") + bob_msg = self.getMessage("bob") + carol_line = self.getMessage("carol", raw=True) + self.assertMessageEqual(bob_msg, command="PRIVMSG", params=["#test", "hi"]) + self.assertEqual(bob_msg.tags["+baz"], "bat") + self.assertIn("msgid", bob_msg.tags) # should not relay a non-client-only tag - self.assertNotIn('fizz', bob_msg.tags) + self.assertNotIn("fizz", bob_msg.tags) # carol MUST NOT receive tags carol_msg = assertNoTags(carol_line) - self.assertMessageEqual(carol_msg, command='PRIVMSG', params=['#test', 'hi']) + self.assertMessageEqual(carol_msg, command="PRIVMSG", params=["#test", "hi"]) # dave SHOULD receive server-time tag - dave_msg = self.getMessage('dave') - self.assertIn('time', dave_msg.tags) + dave_msg = self.getMessage("dave") + self.assertIn("time", dave_msg.tags) # dave MUST NOT receive client-only tags - self.assertNotIn('+baz', dave_msg.tags) + self.assertNotIn("+baz", dave_msg.tags) getAllMessages() - self.sendLine('bob', '@+bat=baz;+fizz=buzz PRIVMSG #test :hi yourself') - bob_msg = self.getMessage('bob') # bob has echo-message - alice_msg = self.getMessage('alice') - carol_line = self.getMessage('carol', raw=True) + self.sendLine("bob", "@+bat=baz;+fizz=buzz PRIVMSG #test :hi yourself") + bob_msg = self.getMessage("bob") # bob has echo-message + alice_msg = self.getMessage("alice") + carol_line = self.getMessage("carol", raw=True) carol_msg = assertNoTags(carol_line) for msg in [alice_msg, bob_msg, carol_msg]: - self.assertMessageEqual(msg, command='PRIVMSG', params=['#test', 'hi yourself']) + self.assertMessageEqual( + msg, command="PRIVMSG", params=["#test", "hi yourself"] + ) for msg in [alice_msg, bob_msg]: - self.assertEqual(msg.tags['+bat'], 'baz') - self.assertEqual(msg.tags['+fizz'], 'buzz') - self.assertTrue(alice_msg.tags['msgid']) - self.assertEqual(alice_msg.tags['msgid'], bob_msg.tags['msgid']) + self.assertEqual(msg.tags["+bat"], "baz") + self.assertEqual(msg.tags["+fizz"], "buzz") + self.assertTrue(alice_msg.tags["msgid"]) + self.assertEqual(alice_msg.tags["msgid"], bob_msg.tags["msgid"]) getAllMessages() # test TAGMSG and basic escaping - self.sendLine('bob', '@+buzz=fizz\:buzz;cat=dog;+steel=wootz TAGMSG #test') - bob_msg = self.getMessage('bob') # bob has echo-message - alice_msg = self.getMessage('alice') + self.sendLine("bob", "@+buzz=fizz\:buzz;cat=dog;+steel=wootz TAGMSG #test") + bob_msg = self.getMessage("bob") # bob has echo-message + alice_msg = self.getMessage("alice") # carol MUST NOT receive TAGMSG at all - self.assertEqual(self.getMessages('carol'), []) + self.assertEqual(self.getMessages("carol"), []) # dave MUST NOT receive TAGMSG either, despite having server-time - self.assertEqual(self.getMessages('dave'), []) + self.assertEqual(self.getMessages("dave"), []) for msg in [alice_msg, bob_msg]: - self.assertMessageEqual(alice_msg, command='TAGMSG', params=['#test']) - self.assertEqual(msg.tags['+buzz'], 'fizz;buzz') - self.assertEqual(msg.tags['+steel'], 'wootz') - self.assertNotIn('cat', msg.tags) - self.assertTrue(alice_msg.tags['msgid']) - self.assertEqual(alice_msg.tags['msgid'], bob_msg.tags['msgid']) + self.assertMessageEqual(alice_msg, command="TAGMSG", params=["#test"]) + self.assertEqual(msg.tags["+buzz"], "fizz;buzz") + self.assertEqual(msg.tags["+steel"], "wootz") + self.assertNotIn("cat", msg.tags) + self.assertTrue(alice_msg.tags["msgid"]) + self.assertEqual(alice_msg.tags["msgid"], bob_msg.tags["msgid"]) - @cases.SpecificationSelector.requiredBySpecification('message-tags') + @cases.SpecificationSelector.requiredBySpecification("message-tags") def testLengthLimits(self): self.connectClient( - 'alice', - name='alice', - capabilities=['message-tags', 'echo-message'], - skip_if_cap_nak=True + "alice", + name="alice", + capabilities=["message-tags", "echo-message"], + skip_if_cap_nak=True, ) - self.joinChannel('alice', '#test') - self.connectClient('bob', name='bob', capabilities=['message-tags']) - self.joinChannel('bob', '#test') - self.getMessages('alice') - self.getMessages('bob') + self.joinChannel("alice", "#test") + self.connectClient("bob", name="bob", capabilities=["message-tags"]) + self.joinChannel("bob", "#test") + self.getMessages("alice") + self.getMessages("bob") # this is right at the limit of 4094 bytes of tag data, # 4096 bytes of tag section (including the starting '@' and the final ' ') - max_tagmsg = '@foo=bar;+baz=%s TAGMSG #test' % ('a' * 4081,) - self.assertEqual(max_tagmsg.index('TAGMSG'), 4096) - self.sendLine('alice', max_tagmsg) - echo = self.getMessage('alice') - relay = self.getMessage('bob') - self.assertMessageEqual(echo, command='TAGMSG', params=['#test']) - self.assertMessageEqual(relay, command='TAGMSG', params=['#test']) - self.assertNotEqual(echo.tags['msgid'], '') - self.assertEqual(echo.tags['msgid'], relay.tags['msgid']) - self.assertEqual(echo.tags['+baz'], 'a' * 4081) - self.assertEqual(relay.tags['+baz'], echo.tags['+baz']) + max_tagmsg = "@foo=bar;+baz=%s TAGMSG #test" % ("a" * 4081,) + self.assertEqual(max_tagmsg.index("TAGMSG"), 4096) + self.sendLine("alice", max_tagmsg) + echo = self.getMessage("alice") + relay = self.getMessage("bob") + self.assertMessageEqual(echo, command="TAGMSG", params=["#test"]) + self.assertMessageEqual(relay, command="TAGMSG", params=["#test"]) + self.assertNotEqual(echo.tags["msgid"], "") + self.assertEqual(echo.tags["msgid"], relay.tags["msgid"]) + self.assertEqual(echo.tags["+baz"], "a" * 4081) + self.assertEqual(relay.tags["+baz"], echo.tags["+baz"]) - excess_tagmsg = '@foo=bar;+baz=%s TAGMSG #test' % ('a' * 4082,) - self.assertEqual(excess_tagmsg.index('TAGMSG'), 4097) - self.sendLine('alice', excess_tagmsg) - reply = self.getMessage('alice') + excess_tagmsg = "@foo=bar;+baz=%s TAGMSG #test" % ("a" * 4082,) + self.assertEqual(excess_tagmsg.index("TAGMSG"), 4097) + self.sendLine("alice", excess_tagmsg) + reply = self.getMessage("alice") self.assertEqual(reply.command, ERR_INPUTTOOLONG) - self.assertEqual(self.getMessages('bob'), []) + self.assertEqual(self.getMessages("bob"), []) - max_privmsg = '@foo=bar;+baz=%s PRIVMSG #test %s' % ('a' * 4081, 'b' * 496) + max_privmsg = "@foo=bar;+baz=%s PRIVMSG #test %s" % ("a" * 4081, "b" * 496) # irctest adds the '\r\n' for us, this is right at the limit self.assertEqual(len(max_privmsg), 4096 + (512 - 2)) - self.sendLine('alice', max_privmsg) - echo = self.getMessage('alice') - relay = self.getMessage('bob') - self.assertNotEqual(echo.tags['msgid'], '') - self.assertEqual(echo.tags['msgid'], relay.tags['msgid']) - self.assertEqual(echo.tags['+baz'], 'a' * 4081) - self.assertEqual(relay.tags['+baz'], echo.tags['+baz']) + self.sendLine("alice", max_privmsg) + echo = self.getMessage("alice") + relay = self.getMessage("bob") + self.assertNotEqual(echo.tags["msgid"], "") + self.assertEqual(echo.tags["msgid"], relay.tags["msgid"]) + self.assertEqual(echo.tags["+baz"], "a" * 4081) + self.assertEqual(relay.tags["+baz"], echo.tags["+baz"]) # message may have been truncated - self.assertIn('b' * 400, echo.params[1]) - self.assertEqual(echo.params[1].rstrip('b'), '') - self.assertIn('b' * 400, relay.params[1]) - self.assertEqual(relay.params[1].rstrip('b'), '') + self.assertIn("b" * 400, echo.params[1]) + self.assertEqual(echo.params[1].rstrip("b"), "") + self.assertIn("b" * 400, relay.params[1]) + self.assertEqual(relay.params[1].rstrip("b"), "") - excess_privmsg = '@foo=bar;+baz=%s PRIVMSG #test %s' % ('a' * 4082, 'b' * 495) + excess_privmsg = "@foo=bar;+baz=%s PRIVMSG #test %s" % ("a" * 4082, "b" * 495) # TAGMSG data is over the limit, but we're within the overall limit for a line - self.assertEqual(excess_privmsg.index('PRIVMSG'), 4097) + self.assertEqual(excess_privmsg.index("PRIVMSG"), 4097) self.assertEqual(len(excess_privmsg), 4096 + (512 - 2)) - self.sendLine('alice', excess_privmsg) - reply = self.getMessage('alice') + self.sendLine("alice", excess_privmsg) + reply = self.getMessage("alice") self.assertEqual(reply.command, ERR_INPUTTOOLONG) - self.assertEqual(self.getMessages('bob'), []) + self.assertEqual(self.getMessages("bob"), []) diff --git a/irctest/server_tests/test_messages.py b/irctest/server_tests/test_messages.py index dcf730e..5a9c4a7 100644 --- a/irctest/server_tests/test_messages.py +++ b/irctest/server_tests/test_messages.py @@ -6,54 +6,52 @@ Section 3.2 of RFC 2812 from irctest import cases from irctest.numerics import ERR_INPUTTOOLONG + class PrivmsgTestCase(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testPrivmsg(self): """""" - self.connectClient('foo') - self.sendLine(1, 'JOIN #chan') - self.connectClient('bar') - self.sendLine(2, 'JOIN #chan') - self.getMessages(2) # synchronize - self.sendLine(1, 'PRIVMSG #chan :hello there') - self.getMessages(1) # synchronize - pms = [msg for msg in self.getMessages(2) if msg.command == 'PRIVMSG'] + self.connectClient("foo") + self.sendLine(1, "JOIN #chan") + self.connectClient("bar") + self.sendLine(2, "JOIN #chan") + self.getMessages(2) # synchronize + self.sendLine(1, "PRIVMSG #chan :hello there") + self.getMessages(1) # synchronize + pms = [msg for msg in self.getMessages(2) if msg.command == "PRIVMSG"] self.assertEqual(len(pms), 1) self.assertMessageEqual( - pms[0], - command='PRIVMSG', - params=['#chan', 'hello there'] + pms[0], command="PRIVMSG", params=["#chan", "hello there"] ) - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testPrivmsgNonexistentChannel(self): """""" - self.connectClient('foo') - self.sendLine(1, 'PRIVMSG #nonexistent :hello there') + self.connectClient("foo") + self.sendLine(1, "PRIVMSG #nonexistent :hello there") msg = self.getMessage(1) # ERR_NOSUCHNICK, ERR_NOSUCHCHANNEL, or ERR_CANNOTSENDTOCHAN - self.assertIn(msg.command, ('401', '403', '404')) + self.assertIn(msg.command, ("401", "403", "404")) + class NoticeTestCase(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testNotice(self): """""" - self.connectClient('foo') - self.sendLine(1, 'JOIN #chan') - self.connectClient('bar') - self.sendLine(2, 'JOIN #chan') - self.getMessages(2) # synchronize - self.sendLine(1, 'NOTICE #chan :hello there') - self.getMessages(1) # synchronize - notices = [msg for msg in self.getMessages(2) if msg.command == 'NOTICE'] + self.connectClient("foo") + self.sendLine(1, "JOIN #chan") + self.connectClient("bar") + self.sendLine(2, "JOIN #chan") + self.getMessages(2) # synchronize + self.sendLine(1, "NOTICE #chan :hello there") + self.getMessages(1) # synchronize + notices = [msg for msg in self.getMessages(2) if msg.command == "NOTICE"] self.assertEqual(len(notices), 1) self.assertMessageEqual( - notices[0], - command='NOTICE', - params=['#chan', 'hello there'] + notices[0], command="NOTICE", params=["#chan", "hello there"] ) - @cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812") def testNoticeNonexistentChannel(self): """ 'automatic replies MUST NEVER be sent in response to a NOTICE message. @@ -61,17 +59,17 @@ class NoticeTestCase(cases.BaseServerTestCase): back to the client on receipt of a notice.' https://tools.ietf.org/html/rfc2812#section-3.3.2> """ - self.connectClient('foo') - self.sendLine(1, 'NOTICE #nonexistent :hello there') + self.connectClient("foo") + self.sendLine(1, "NOTICE #nonexistent :hello there") self.assertEqual(self.getMessages(1), []) class TagsTestCase(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testLineTooLong(self): - self.connectClient('bar') - self.joinChannel(1, '#xyz') - monsterMessage = '@+clientOnlyTagExample=' + 'a'*4096 + ' PRIVMSG #xyz hi!' + self.connectClient("bar") + self.joinChannel(1, "#xyz") + monsterMessage = "@+clientOnlyTagExample=" + "a" * 4096 + " PRIVMSG #xyz hi!" self.sendLine(1, monsterMessage) replies = self.getMessages(1) self.assertIn(ERR_INPUTTOOLONG, set(reply.command for reply in replies)) diff --git a/irctest/server_tests/test_metadata.py b/irctest/server_tests/test_metadata.py index 776119b..5a436b9 100644 --- a/irctest/server_tests/test_metadata.py +++ b/irctest/server_tests/test_metadata.py @@ -5,174 +5,244 @@ Tests METADATA features. from irctest import cases + class MetadataTestCase(cases.BaseServerTestCase): - valid_metadata_keys = {'valid_key1', 'valid_key2'} - invalid_metadata_keys = {'invalid_key1', 'invalid_key2'} - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') + valid_metadata_keys = {"valid_key1", "valid_key2"} + invalid_metadata_keys = {"invalid_key1", "invalid_key2"} + + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated") def testInIsupport(self): """“If METADATA is supported, it MUST be specified in RPL_ISUPPORT using the METADATA key.” -- """ self.addClient() - self.sendLine(1, 'CAP LS 302') + self.sendLine(1, "CAP LS 302") self.getCapLs(1) - self.sendLine(1, 'USER foo foo foo :foo') - self.sendLine(1, 'NICK foo') - self.sendLine(1, 'CAP END') + self.sendLine(1, "USER foo foo foo :foo") + self.sendLine(1, "NICK foo") + self.sendLine(1, "CAP END") self.skipToWelcome(1) m = self.getMessage(1) - while m.command != '005': # RPL_ISUPPORT + while m.command != "005": # RPL_ISUPPORT m = self.getMessage(1) - self.assertIn('METADATA', {x.split('=')[0] for x in m.params[1:-1]}, - fail_msg='{item} missing from RPL_ISUPPORT') + self.assertIn( + "METADATA", + {x.split("=")[0] for x in m.params[1:-1]}, + fail_msg="{item} missing from RPL_ISUPPORT", + ) self.getMessages(1) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated") def testGetOneUnsetValid(self): - """ - """ - self.connectClient('foo') - self.sendLine(1, 'METADATA * GET valid_key1') + """""" + self.connectClient("foo") + self.sendLine(1, "METADATA * GET valid_key1") m = self.getMessage(1) - self.assertMessageEqual(m, command='766', # ERR_NOMATCHINGKEY - fail_msg='Did not reply with 766 (ERR_NOMATCHINGKEY) to a ' - 'request to an unset valid METADATA key.') + self.assertMessageEqual( + m, + command="766", # ERR_NOMATCHINGKEY + fail_msg="Did not reply with 766 (ERR_NOMATCHINGKEY) to a " + "request to an unset valid METADATA key.", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated") def testGetTwoUnsetValid(self): """“Multiple keys may be given. The response will be either RPL_KEYVALUE, ERR_KEYINVALID or ERR_NOMATCHINGKEY for every key in order.” -- """ - self.connectClient('foo') - self.sendLine(1, 'METADATA * GET valid_key1 valid_key2') + self.connectClient("foo") + self.sendLine(1, "METADATA * GET valid_key1 valid_key2") m = self.getMessage(1) - self.assertMessageEqual(m, command='766', # ERR_NOMATCHINGKEY - fail_msg='Did not reply with 766 (ERR_NOMATCHINGKEY) to a ' - 'request to two unset valid METADATA key: {msg}') - self.assertEqual(m.params[1], 'valid_key1', m, - fail_msg='Response to “METADATA * GET valid_key1 valid_key2” ' - 'did not respond to valid_key1 first: {msg}') + self.assertMessageEqual( + m, + command="766", # ERR_NOMATCHINGKEY + fail_msg="Did not reply with 766 (ERR_NOMATCHINGKEY) to a " + "request to two unset valid METADATA key: {msg}", + ) + self.assertEqual( + m.params[1], + "valid_key1", + m, + fail_msg="Response to “METADATA * GET valid_key1 valid_key2” " + "did not respond to valid_key1 first: {msg}", + ) m = self.getMessage(1) - self.assertMessageEqual(m, command='766', # ERR_NOMATCHINGKEY - fail_msg='Did not reply with two 766 (ERR_NOMATCHINGKEY) to a ' - 'request to two unset valid METADATA key: {msg}') - self.assertEqual(m.params[1], 'valid_key2', m, - fail_msg='Response to “METADATA * GET valid_key1 valid_key2” ' - 'did not respond to valid_key2 as second response: {msg}') + self.assertMessageEqual( + m, + command="766", # ERR_NOMATCHINGKEY + fail_msg="Did not reply with two 766 (ERR_NOMATCHINGKEY) to a " + "request to two unset valid METADATA key: {msg}", + ) + self.assertEqual( + m.params[1], + "valid_key2", + m, + fail_msg="Response to “METADATA * GET valid_key1 valid_key2” " + "did not respond to valid_key2 as second response: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated") def testListNoSet(self): """“This subcommand MUST list all currently-set metadata keys along with their values. The response will be zero or more RPL_KEYVALUE events, following by RPL_METADATAEND event.” -- """ - self.connectClient('foo') - self.sendLine(1, 'METADATA * LIST') + self.connectClient("foo") + self.sendLine(1, "METADATA * LIST") m = self.getMessage(1) - self.assertMessageEqual(m, command='762', # RPL_METADATAEND - fail_msg='Response to “METADATA * LIST” was not ' - '762 (RPL_METADATAEND) but: {msg}') + self.assertMessageEqual( + m, + command="762", # RPL_METADATAEND + fail_msg="Response to “METADATA * LIST” was not " + "762 (RPL_METADATAEND) but: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated") def testListInvalidTarget(self): """“In case of invalid target RPL_METADATAEND MUST NOT be sent.” -- """ - self.connectClient('foo') - self.sendLine(1, 'METADATA foobar LIST') + self.connectClient("foo") + self.sendLine(1, "METADATA foobar LIST") m = self.getMessage(1) - self.assertMessageEqual(m, command='765', # ERR_TARGETINVALID - fail_msg='Response to “METADATA LIST” was ' - 'not 765 (ERR_TARGETINVALID) but: {msg}') + self.assertMessageEqual( + m, + command="765", # ERR_TARGETINVALID + fail_msg="Response to “METADATA LIST” was " + "not 765 (ERR_TARGETINVALID) but: {msg}", + ) commands = {m.command for m in self.getMessages(1)} - self.assertNotIn('762', commands, - fail_msg='Sent “METADATA LIST”, got 765 ' - '(ERR_TARGETINVALID), and then 762 (RPL_METADATAEND)') + self.assertNotIn( + "762", + commands, + fail_msg="Sent “METADATA LIST”, got 765 " + "(ERR_TARGETINVALID), and then 762 (RPL_METADATAEND)", + ) def assertSetValue(self, target, key, value, displayable_value=None): if displayable_value is None: displayable_value = value - self.sendLine(1, 'METADATA {} SET {} :{}'.format(target, key, value)) + self.sendLine(1, "METADATA {} SET {} :{}".format(target, key, value)) m = self.getMessage(1) - self.assertMessageEqual(m, command='761', # RPL_KEYVALUE - fail_msg='Did not reply with 761 (RPL_KEYVALUE) to a valid ' - '“METADATA * SET {} :{}”: {msg}', - extra_format=(key, displayable_value,)) - self.assertEqual(m.params[1], 'valid_key1', m, - fail_msg='Second param of 761 after setting “{expects}” to ' - '“{}” is not “{expects}”: {msg}.', - extra_format=(displayable_value,)) - self.assertEqual(m.params[3], value, m, - fail_msg='Fourth param of 761 after setting “{0}” to ' - '“{1}” is not “{1}”: {msg}.', - extra_format=(key, displayable_value)) + self.assertMessageEqual( + m, + command="761", # RPL_KEYVALUE + fail_msg="Did not reply with 761 (RPL_KEYVALUE) to a valid " + "“METADATA * SET {} :{}”: {msg}", + extra_format=( + key, + displayable_value, + ), + ) + self.assertEqual( + m.params[1], + "valid_key1", + m, + fail_msg="Second param of 761 after setting “{expects}” to " + "“{}” is not “{expects}”: {msg}.", + extra_format=(displayable_value,), + ) + self.assertEqual( + m.params[3], + value, + m, + fail_msg="Fourth param of 761 after setting “{0}” to " + "“{1}” is not “{1}”: {msg}.", + extra_format=(key, displayable_value), + ) m = self.getMessage(1) - self.assertMessageEqual(m, command='762', # RPL_METADATAEND - fail_msg='Did not send RPL_METADATAEND after setting ' - 'a valid METADATA key.') + self.assertMessageEqual( + m, + command="762", # RPL_METADATAEND + fail_msg="Did not send RPL_METADATAEND after setting " + "a valid METADATA key.", + ) + def assertGetValue(self, target, key, value, displayable_value=None): - self.sendLine(1, 'METADATA * GET {}'.format(key)) + self.sendLine(1, "METADATA * GET {}".format(key)) m = self.getMessage(1) - self.assertMessageEqual(m, command='761', # RPL_KEYVALUE - fail_msg='Did not reply with 761 (RPL_KEYVALUE) to a valid ' - '“METADATA * GET” when the key is set is set: {msg}') - self.assertEqual(m.params[1], key, m, - fail_msg='Second param of 761 after getting “{expects}” ' - '(which is set) is not “{expects}”: {msg}.') - self.assertEqual(m.params[3], value, m, - fail_msg='Fourth param of 761 after getting “{0}” ' - '(which is set to “{1}”) is not ”{1}”: {msg}.', - extra_format=(key, displayable_value)) + self.assertMessageEqual( + m, + command="761", # RPL_KEYVALUE + fail_msg="Did not reply with 761 (RPL_KEYVALUE) to a valid " + "“METADATA * GET” when the key is set is set: {msg}", + ) + self.assertEqual( + m.params[1], + key, + m, + fail_msg="Second param of 761 after getting “{expects}” " + "(which is set) is not “{expects}”: {msg}.", + ) + self.assertEqual( + m.params[3], + value, + m, + fail_msg="Fourth param of 761 after getting “{0}” " + "(which is set to “{1}”) is not ”{1}”: {msg}.", + extra_format=(key, displayable_value), + ) + def assertSetGetValue(self, target, key, value, displayable_value=None): self.assertSetValue(target, key, value, displayable_value) self.assertGetValue(target, key, value, displayable_value) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated") def testSetGetValid(self): - """ - """ - self.connectClient('foo') - self.assertSetGetValue('*', 'valid_key1', 'myvalue') + """""" + self.connectClient("foo") + self.assertSetGetValue("*", "valid_key1", "myvalue") - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated") def testSetGetZeroCharInValue(self): """“Values are unrestricted, except that they MUST be UTF-8.” -- """ - self.connectClient('foo') - self.assertSetGetValue('*', 'valid_key1', 'zero->\0<-zero', - 'zero->\\0<-zero') + self.connectClient("foo") + self.assertSetGetValue("*", "valid_key1", "zero->\0<-zero", "zero->\\0<-zero") - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated") def testSetGetHeartInValue(self): """“Values are unrestricted, except that they MUST be UTF-8.” -- """ - heart = b'\xf0\x9f\x92\x9c'.decode() - self.connectClient('foo') - self.assertSetGetValue('*', 'valid_key1', '->{}<-'.format(heart), - 'zero->{}<-zero'.format(heart.encode())) + heart = b"\xf0\x9f\x92\x9c".decode() + self.connectClient("foo") + self.assertSetGetValue( + "*", + "valid_key1", + "->{}<-".format(heart), + "zero->{}<-zero".format(heart.encode()), + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated") def testSetInvalidUtf8(self): """“Values are unrestricted, except that they MUST be UTF-8.” -- """ - self.connectClient('foo') + self.connectClient("foo") # Sending directly because it is not valid UTF-8 so Python would # not like it - self.clients[1].conn.sendall(b'METADATA * SET valid_key1 ' - b':invalid UTF-8 ->\xc3<-\r\n') + self.clients[1].conn.sendall( + b"METADATA * SET valid_key1 " b":invalid UTF-8 ->\xc3<-\r\n" + ) commands = {m.command for m in self.getMessages(1)} - self.assertNotIn('761', commands, # RPL_KEYVALUE - fail_msg='Setting METADATA key to a value containing invalid ' - 'UTF-8 was answered with 761 (RPL_KEYVALUE)') - self.clients[1].conn.sendall(b'METADATA * SET valid_key1 ' - b':invalid UTF-8: \xc3\r\n') + self.assertNotIn( + "761", + commands, # RPL_KEYVALUE + fail_msg="Setting METADATA key to a value containing invalid " + "UTF-8 was answered with 761 (RPL_KEYVALUE)", + ) + self.clients[1].conn.sendall( + b"METADATA * SET valid_key1 " b":invalid UTF-8: \xc3\r\n" + ) commands = {m.command for m in self.getMessages(1)} - self.assertNotIn('761', commands, # RPL_KEYVALUE - fail_msg='Setting METADATA key to a value containing invalid ' - 'UTF-8 was answered with 761 (RPL_KEYVALUE)') + self.assertNotIn( + "761", + commands, # RPL_KEYVALUE + fail_msg="Setting METADATA key to a value containing invalid " + "UTF-8 was answered with 761 (RPL_KEYVALUE)", + ) diff --git a/irctest/server_tests/test_monitor.py b/irctest/server_tests/test_monitor.py index f51ee87..9cd9927 100644 --- a/irctest/server_tests/test_monitor.py +++ b/irctest/server_tests/test_monitor.py @@ -5,106 +5,132 @@ from irctest import cases from irctest.client_mock import NoMessageException from irctest.basecontrollers import NotImplementedByController -from irctest.numerics import RPL_MONLIST, RPL_ENDOFMONLIST, RPL_MONONLINE, RPL_MONOFFLINE +from irctest.numerics import ( + RPL_MONLIST, + RPL_ENDOFMONLIST, + RPL_MONONLINE, + RPL_MONOFFLINE, +) + class MonitorTestCase(cases.BaseServerTestCase): def check_server_support(self): - if 'MONITOR' not in self.server_support: - raise NotImplementedByController('MONITOR') + if "MONITOR" not in self.server_support: + raise NotImplementedByController("MONITOR") def assertMononline(self, client, nick, m=None): if not m: m = self.getMessage(client) - self.assertMessageEqual(m, command='730', # RPL_MONONLINE - fail_msg='Sent non-730 (RPL_MONONLINE) message after ' - 'monitored nick “{}” connected: {msg}', - extra_format=(nick,)) - self.assertEqual(len(m.params), 2, m, - fail_msg='Invalid number of params of RPL_MONONLINE: {msg}') - self.assertEqual(m.params[1].split('!')[0], 'bar', - fail_msg='730 (RPL_MONONLINE) with bad target after “{}” ' - 'connects: {msg}', - extra_format=(nick,)) + self.assertMessageEqual( + m, + command="730", # RPL_MONONLINE + fail_msg="Sent non-730 (RPL_MONONLINE) message after " + "monitored nick “{}” connected: {msg}", + extra_format=(nick,), + ) + self.assertEqual( + len(m.params), + 2, + m, + fail_msg="Invalid number of params of RPL_MONONLINE: {msg}", + ) + self.assertEqual( + m.params[1].split("!")[0], + "bar", + fail_msg="730 (RPL_MONONLINE) with bad target after “{}” " + "connects: {msg}", + extra_format=(nick,), + ) def assertMonoffline(self, client, nick, m=None): if not m: m = self.getMessage(client) - self.assertMessageEqual(m, command='731', # RPL_MONOFFLINE - fail_msg='Did not reply with 731 (RPL_MONOFFLINE) to ' - '“MONITOR + {}”, while “{}” is offline: {msg}', - extra_format=(nick, nick)) - self.assertEqual(len(m.params), 2, m, - fail_msg='Invalid number of params of RPL_MONOFFLINE: {msg}') - self.assertEqual(m.params[1].split('!')[0], 'bar', - fail_msg='731 (RPL_MONOFFLINE) reply to “MONITOR + {}” ' - 'with bad target: {msg}', - extra_format=(nick,)) + self.assertMessageEqual( + m, + command="731", # RPL_MONOFFLINE + fail_msg="Did not reply with 731 (RPL_MONOFFLINE) to " + "“MONITOR + {}”, while “{}” is offline: {msg}", + extra_format=(nick, nick), + ) + self.assertEqual( + len(m.params), + 2, + m, + fail_msg="Invalid number of params of RPL_MONOFFLINE: {msg}", + ) + self.assertEqual( + m.params[1].split("!")[0], + "bar", + fail_msg="731 (RPL_MONOFFLINE) reply to “MONITOR + {}” " + "with bad target: {msg}", + extra_format=(nick,), + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testMonitorOneDisconnected(self): """“If any of the targets being added are online, the server will generate RPL_MONONLINE numerics listing those targets that are online.” -- """ - self.connectClient('foo') + self.connectClient("foo") self.check_server_support() - self.sendLine(1, 'MONITOR + bar') - self.assertMonoffline(1, 'bar') - self.connectClient('bar') - self.assertMononline(1, 'bar') - self.sendLine(2, 'QUIT :bye') + self.sendLine(1, "MONITOR + bar") + self.assertMonoffline(1, "bar") + self.connectClient("bar") + self.assertMononline(1, "bar") + self.sendLine(2, "QUIT :bye") try: self.getMessages(2) except ConnectionResetError: pass - self.assertMonoffline(1, 'bar') + self.assertMonoffline(1, "bar") - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testMonitorOneConnection(self): - self.connectClient('foo') + self.connectClient("foo") self.check_server_support() - self.sendLine(1, 'MONITOR + bar') + self.sendLine(1, "MONITOR + bar") self.getMessages(1) - self.connectClient('bar') - self.assertMononline(1, 'bar') + self.connectClient("bar") + self.assertMononline(1, "bar") - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testMonitorOneConnected(self): """“If any of the targets being added are offline, the server will generate RPL_MONOFFLINE numerics listing those targets that are online.” -- """ - self.connectClient('foo') + self.connectClient("foo") self.check_server_support() - self.connectClient('bar') - self.sendLine(1, 'MONITOR + bar') - self.assertMononline(1, 'bar') - self.sendLine(2, 'QUIT :bye') + self.connectClient("bar") + self.sendLine(1, "MONITOR + bar") + self.assertMononline(1, "bar") + self.sendLine(2, "QUIT :bye") try: self.getMessages(2) except ConnectionResetError: pass - self.assertMonoffline(1, 'bar') + self.assertMonoffline(1, "bar") - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testMonitorOneConnectionWithQuit(self): - self.connectClient('foo') + self.connectClient("foo") self.check_server_support() - self.connectClient('bar') - self.sendLine(1, 'MONITOR + bar') - self.assertMononline(1, 'bar') - self.sendLine(2, 'QUIT :bye') + self.connectClient("bar") + self.sendLine(1, "MONITOR + bar") + self.assertMononline(1, "bar") + self.sendLine(2, "QUIT :bye") try: self.getMessages(2) except ConnectionResetError: pass - self.assertMonoffline(1, 'bar') - self.connectClient('bar') - self.assertMononline(1, 'bar') + self.assertMonoffline(1, "bar") + self.connectClient("bar") + self.assertMononline(1, "bar") - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testMonitorConnectedAndDisconnected(self): """“If any of the targets being added are online, the server will generate RPL_MONONLINE numerics listing those targets that are @@ -115,52 +141,76 @@ class MonitorTestCase(cases.BaseServerTestCase): online.” -- """ - self.connectClient('foo') + self.connectClient("foo") self.check_server_support() - self.connectClient('bar') - self.sendLine(1, 'MONITOR + bar,baz') + self.connectClient("bar") + self.sendLine(1, "MONITOR + bar,baz") m1 = self.getMessage(1) m2 = self.getMessage(1) commands = {m1.command, m2.command} - self.assertEqual(commands, {'730', '731'}, - fail_msg='Did not send one 730 (RPL_MONONLINE) and one ' - '731 (RPL_MONOFFLINE) after “MONITOR + bar,baz” when “bar” ' - 'is online and “baz” is offline. Sent this instead: {}', - extra_format=((m1, m2))) - if m1.command == '731': + self.assertEqual( + commands, + {"730", "731"}, + fail_msg="Did not send one 730 (RPL_MONONLINE) and one " + "731 (RPL_MONOFFLINE) after “MONITOR + bar,baz” when “bar” " + "is online and “baz” is offline. Sent this instead: {}", + extra_format=((m1, m2)), + ) + if m1.command == "731": (m1, m2) = (m2, m1) - self.assertEqual(len(m1.params), 2, m1, - fail_msg='Invalid number of params of RPL_MONONLINE: {msg}') - self.assertEqual(len(m2.params), 2, m2, - fail_msg='Invalid number of params of RPL_MONONLINE: {msg}') - self.assertEqual(m1.params[1].split('!')[0], 'bar', m1, - fail_msg='730 (RPL_MONONLINE) with bad target after ' - '“MONITOR + bar,baz” and “bar” is connected: {msg}') - self.assertEqual(m2.params[1].split('!')[0], 'baz', m2, - fail_msg='731 (RPL_MONOFFLINE) with bad target after ' - '“MONITOR + bar,baz” and “baz” is disconnected: {msg}') + self.assertEqual( + len(m1.params), + 2, + m1, + fail_msg="Invalid number of params of RPL_MONONLINE: {msg}", + ) + self.assertEqual( + len(m2.params), + 2, + m2, + fail_msg="Invalid number of params of RPL_MONONLINE: {msg}", + ) + self.assertEqual( + m1.params[1].split("!")[0], + "bar", + m1, + fail_msg="730 (RPL_MONONLINE) with bad target after " + "“MONITOR + bar,baz” and “bar” is connected: {msg}", + ) + self.assertEqual( + m2.params[1].split("!")[0], + "baz", + m2, + fail_msg="731 (RPL_MONOFFLINE) with bad target after " + "“MONITOR + bar,baz” and “baz” is disconnected: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testUnmonitor(self): - self.connectClient('foo') + self.connectClient("foo") self.check_server_support() - self.sendLine(1, 'MONITOR + bar') + self.sendLine(1, "MONITOR + bar") self.getMessages(1) - self.connectClient('bar') - self.assertMononline(1, 'bar') - self.sendLine(1, 'MONITOR - bar') - self.assertEqual(self.getMessages(1), [], - fail_msg='Got messages after “MONITOR - bar”: {got}') - self.sendLine(2, 'QUIT :bye') + self.connectClient("bar") + self.assertMononline(1, "bar") + self.sendLine(1, "MONITOR - bar") + self.assertEqual( + self.getMessages(1), + [], + fail_msg="Got messages after “MONITOR - bar”: {got}", + ) + self.sendLine(2, "QUIT :bye") try: self.getMessages(2) except ConnectionResetError: pass - self.assertEqual(self.getMessages(1), [], - fail_msg='Got messages after disconnection of unmonitored ' - 'nick: {got}') + self.assertEqual( + self.getMessages(1), + [], + fail_msg="Got messages after disconnection of unmonitored " "nick: {got}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testMonitorForbidsMasks(self): """“The MONITOR implementation also enhances user privacy by disallowing subscription to hostmasks, allowing users to avoid @@ -171,27 +221,33 @@ class MonitorTestCase(cases.BaseServerTestCase): by the IRC daemon.” -- """ - self.connectClient('foo') + self.connectClient("foo") self.check_server_support() - self.sendLine(1, 'MONITOR + *!username@localhost') - self.sendLine(1, 'MONITOR + *!username@127.0.0.1') + self.sendLine(1, "MONITOR + *!username@localhost") + self.sendLine(1, "MONITOR + *!username@127.0.0.1") try: m = self.getMessage(1) - self.assertNotEqual(m.command, '731', m, - fail_msg='Got 731 (RPL_MONOFFLINE) after adding a monitor ' - 'on a mask: {msg}') + self.assertNotEqual( + m.command, + "731", + m, + fail_msg="Got 731 (RPL_MONOFFLINE) after adding a monitor " + "on a mask: {msg}", + ) except NoMessageException: pass - self.connectClient('bar') + self.connectClient("bar") try: m = self.getMessage(1) except NoMessageException: pass else: - raise AssertionError('Got message after client whose MONITORing ' - 'was requested via hostmask connected: {}'.format(m)) + raise AssertionError( + "Got message after client whose MONITORing " + "was requested via hostmask connected: {}".format(m) + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testTwoMonitoringOneRemove(self): """Tests the following scenario: * foo MONITORs qux @@ -199,30 +255,30 @@ class MonitorTestCase(cases.BaseServerTestCase): * bar unMONITORs qux * qux connects. """ - self.connectClient('foo') + self.connectClient("foo") self.check_server_support() - self.connectClient('bar') - self.sendLine(1, 'MONITOR + qux') - self.sendLine(2, 'MONITOR + qux') + self.connectClient("bar") + self.sendLine(1, "MONITOR + qux") + self.sendLine(2, "MONITOR + qux") self.getMessages(1) self.getMessages(2) - self.sendLine(2, 'MONITOR - qux') + self.sendLine(2, "MONITOR - qux") l = self.getMessages(2) - self.assertEqual(l, [], - fail_msg='Got response to “MONITOR -”: {}', - extra_format=(l,)) - self.connectClient('qux') + self.assertEqual( + l, [], fail_msg="Got response to “MONITOR -”: {}", extra_format=(l,) + ) + self.connectClient("qux") self.getMessages(3) l = self.getMessages(1) - self.assertNotEqual(l, [], - fail_msg='Received no message after MONITORed client ' - 'connects.') + self.assertNotEqual( + l, [], fail_msg="Received no message after MONITORed client " "connects." + ) l = self.getMessages(2) - self.assertEqual(l, [], - fail_msg='Got response to unmonitored client: {}', - extra_format=(l,)) + self.assertEqual( + l, [], fail_msg="Got response to unmonitored client: {}", extra_format=(l,) + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testMonitorList(self): def checkMonitorSubjects(messages, client_nick, expected_targets): # collect all the RPL_MONLIST nicks into a set: @@ -230,63 +286,81 @@ class MonitorTestCase(cases.BaseServerTestCase): for message in messages: if message.command == RPL_MONLIST: self.assertEqual(message.params[0], client_nick) - result.update(message.params[1].split(',')) + result.update(message.params[1].split(",")) # finally, RPL_ENDOFMONLIST should be sent self.assertEqual(messages[-1].command, RPL_ENDOFMONLIST) self.assertEqual(messages[-1].params[0], client_nick) self.assertEqual(result, expected_targets) - self.connectClient('bar') + self.connectClient("bar") self.check_server_support() - self.sendLine(1, 'MONITOR L') - checkMonitorSubjects(self.getMessages(1), 'bar', set()) + self.sendLine(1, "MONITOR L") + checkMonitorSubjects(self.getMessages(1), "bar", set()) - self.sendLine(1, 'MONITOR + qux') + self.sendLine(1, "MONITOR + qux") self.getMessages(1) - self.sendLine(1, 'MONITOR L') - checkMonitorSubjects(self.getMessages(1), 'bar', {'qux',}) + self.sendLine(1, "MONITOR L") + checkMonitorSubjects( + self.getMessages(1), + "bar", + { + "qux", + }, + ) - self.sendLine(1, 'MONITOR + bazbat') + self.sendLine(1, "MONITOR + bazbat") self.getMessages(1) - self.sendLine(1, 'MONITOR L') - checkMonitorSubjects(self.getMessages(1), 'bar', {'qux', 'bazbat',}) + self.sendLine(1, "MONITOR L") + checkMonitorSubjects( + self.getMessages(1), + "bar", + { + "qux", + "bazbat", + }, + ) - self.sendLine(1, 'MONITOR - qux') + self.sendLine(1, "MONITOR - qux") self.getMessages(1) - self.sendLine(1, 'MONITOR L') - checkMonitorSubjects(self.getMessages(1), 'bar', {'bazbat',}) + self.sendLine(1, "MONITOR L") + checkMonitorSubjects( + self.getMessages(1), + "bar", + { + "bazbat", + }, + ) - - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testNickChange(self): # see oragono issue #1076: nickname changes must trigger RPL_MONOFFLINE - self.connectClient('bar') + self.connectClient("bar") self.check_server_support() - self.sendLine(1, 'MONITOR + qux') + self.sendLine(1, "MONITOR + qux") self.getMessages(1) - self.connectClient('baz') + self.connectClient("baz") self.getMessages(2) self.assertEqual(self.getMessages(1), []) - self.sendLine(2, 'NICK qux') + self.sendLine(2, "NICK qux") self.getMessages(2) mononline = self.getMessages(1)[0] self.assertEqual(mononline.command, RPL_MONONLINE) self.assertEqual(len(mononline.params), 2, mononline.params) - self.assertIn(mononline.params[0], ('bar', '*')) - self.assertEqual(mononline.params[1].split('!')[0], 'qux') + self.assertIn(mononline.params[0], ("bar", "*")) + self.assertEqual(mononline.params[1].split("!")[0], "qux") # no numerics for a case change - self.sendLine(2, 'NICK QUX') + self.sendLine(2, "NICK QUX") self.getMessages(2) self.assertEqual(self.getMessages(1), []) - self.sendLine(2, 'NICK bazbat') + self.sendLine(2, "NICK bazbat") self.getMessages(2) monoffline = self.getMessages(1)[0] # should get RPL_MONOFFLINE with the current unfolded nick self.assertEqual(monoffline.command, RPL_MONOFFLINE) self.assertEqual(len(monoffline.params), 2, monoffline.params) - self.assertIn(monoffline.params[0], ('bar', '*')) - self.assertEqual(monoffline.params[1].split('!')[0], 'QUX') + self.assertIn(monoffline.params[0], ("bar", "*")) + self.assertEqual(monoffline.params[1].split("!")[0], "QUX") diff --git a/irctest/server_tests/test_multi_prefix.py b/irctest/server_tests/test_multi_prefix.py index ef9a1f2..807ea4d 100644 --- a/irctest/server_tests/test_multi_prefix.py +++ b/irctest/server_tests/test_multi_prefix.py @@ -5,8 +5,9 @@ Tests multi-prefix. from irctest import cases + class MultiPrefixTestCase(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") def testMultiPrefix(self): """“When requested, the multi-prefix client capability will cause the IRC server to send all possible prefixes which apply to a user in NAMES @@ -14,19 +15,35 @@ class MultiPrefixTestCase(cases.BaseServerTestCase): These prefixes MUST be in order of ‘rank’, from highest to lowest. """ - self.connectClient('foo', capabilities=['multi-prefix']) - self.joinChannel(1, '#chan') - self.sendLine(1, 'MODE #chan +v foo') + self.connectClient("foo", capabilities=["multi-prefix"]) + self.joinChannel(1, "#chan") + self.sendLine(1, "MODE #chan +v foo") self.getMessages(1) - #TODO(dan): Make sure +v is voice + # TODO(dan): Make sure +v is voice - self.sendLine(1, 'NAMES #chan') - self.assertMessageEqual(self.getMessage(1), command='353', params=['foo', '=', '#chan', '@+foo'], fail_msg='Expected NAMES response (353) with @+foo, got: {msg}') + self.sendLine(1, "NAMES #chan") + self.assertMessageEqual( + self.getMessage(1), + command="353", + params=["foo", "=", "#chan", "@+foo"], + fail_msg="Expected NAMES response (353) with @+foo, got: {msg}", + ) self.getMessages(1) - self.sendLine(1, 'WHO #chan') + self.sendLine(1, "WHO #chan") msg = self.getMessage(1) - self.assertEqual(msg.command, '352', msg, fail_msg='Expected WHO response (352), got: {msg}') - self.assertGreaterEqual(len(msg.params), 8, 'Expected WHO response (352) with 8 params, got: {msg}'.format(msg=msg)) - self.assertTrue('@+' in msg.params[6], 'Expected WHO response (352) with "@+" in param 7, got: {msg}'.format(msg=msg)) + self.assertEqual( + msg.command, "352", msg, fail_msg="Expected WHO response (352), got: {msg}" + ) + self.assertGreaterEqual( + len(msg.params), + 8, + "Expected WHO response (352) with 8 params, got: {msg}".format(msg=msg), + ) + self.assertTrue( + "@+" in msg.params[6], + 'Expected WHO response (352) with "@+" in param 7, got: {msg}'.format( + msg=msg + ), + ) diff --git a/irctest/server_tests/test_multiline.py b/irctest/server_tests/test_multiline.py index abdcc85..a1e803b 100644 --- a/irctest/server_tests/test_multiline.py +++ b/irctest/server_tests/test_multiline.py @@ -4,118 +4,127 @@ draft/multiline from irctest import cases -CAP_NAME = 'draft/multiline' -BATCH_TYPE = 'draft/multiline' -CONCAT_TAG = 'draft/multiline-concat' +CAP_NAME = "draft/multiline" +BATCH_TYPE = "draft/multiline" +CONCAT_TAG = "draft/multiline-concat" + +base_caps = ["message-tags", "batch", "echo-message", "server-time", "labeled-response"] -base_caps = ['message-tags', 'batch', 'echo-message', 'server-time', 'labeled-response'] class MultilineTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): - - @cases.SpecificationSelector.requiredBySpecification('multiline') + @cases.SpecificationSelector.requiredBySpecification("multiline") def testBasic(self): self.connectClient( - 'alice', capabilities=(base_caps + [CAP_NAME]), skip_if_cap_nak=True + "alice", capabilities=(base_caps + [CAP_NAME]), skip_if_cap_nak=True ) - self.joinChannel(1, '#test') - self.connectClient('bob', capabilities=(base_caps + [CAP_NAME])) - self.joinChannel(2, '#test') - self.connectClient('charlie', capabilities=base_caps) - self.joinChannel(3, '#test') + self.joinChannel(1, "#test") + self.connectClient("bob", capabilities=(base_caps + [CAP_NAME])) + self.joinChannel(2, "#test") + self.connectClient("charlie", capabilities=base_caps) + self.joinChannel(3, "#test") self.getMessages(1) self.getMessages(2) self.getMessages(3) - self.sendLine(1, '@label=xyz BATCH +123 %s #test' % (BATCH_TYPE,)) - self.sendLine(1, '@batch=123 PRIVMSG #test hello') - self.sendLine(1, '@batch=123 PRIVMSG #test :#how is ') - self.sendLine(1, '@batch=123;%s PRIVMSG #test :everyone?' % (CONCAT_TAG,)) - self.sendLine(1, 'BATCH -123') + self.sendLine(1, "@label=xyz BATCH +123 %s #test" % (BATCH_TYPE,)) + self.sendLine(1, "@batch=123 PRIVMSG #test hello") + self.sendLine(1, "@batch=123 PRIVMSG #test :#how is ") + self.sendLine(1, "@batch=123;%s PRIVMSG #test :everyone?" % (CONCAT_TAG,)) + self.sendLine(1, "BATCH -123") echo = self.getMessages(1) batchStart, batchEnd = echo[0], echo[-1] - self.assertEqual(batchStart.command, 'BATCH') - self.assertEqual(batchStart.tags.get('label'), 'xyz') + self.assertEqual(batchStart.command, "BATCH") + self.assertEqual(batchStart.tags.get("label"), "xyz") self.assertEqual(len(batchStart.params), 3) self.assertEqual(batchStart.params[1], CAP_NAME) self.assertEqual(batchStart.params[2], "#test") - self.assertEqual(batchEnd.command, 'BATCH') + self.assertEqual(batchEnd.command, "BATCH") self.assertEqual(batchStart.params[0][1:], batchEnd.params[0][1:]) - msgid = batchStart.tags.get('msgid') - time = batchStart.tags.get('time') + msgid = batchStart.tags.get("msgid") + time = batchStart.tags.get("time") assert msgid assert time privmsgs = echo[1:-1] for msg in privmsgs: - self.assertMessageEqual(msg, command='PRIVMSG') - self.assertNotIn('msgid', msg.tags) - self.assertNotIn('time', msg.tags) + self.assertMessageEqual(msg, command="PRIVMSG") + self.assertNotIn("msgid", msg.tags) + self.assertNotIn("time", msg.tags) self.assertIn(CONCAT_TAG, echo[3].tags) relay = self.getMessages(2) batchStart, batchEnd = relay[0], relay[-1] - self.assertEqual(batchStart.command, 'BATCH') - self.assertEqual(batchEnd.command, 'BATCH') + self.assertEqual(batchStart.command, "BATCH") + self.assertEqual(batchEnd.command, "BATCH") batchTag = batchStart.params[0][1:] - self.assertEqual(batchStart.params[0], '+'+batchTag) - self.assertEqual(batchEnd.params[0], '-'+batchTag) - self.assertEqual(batchStart.tags.get('msgid'), msgid) - self.assertEqual(batchStart.tags.get('time'), time) + self.assertEqual(batchStart.params[0], "+" + batchTag) + self.assertEqual(batchEnd.params[0], "-" + batchTag) + self.assertEqual(batchStart.tags.get("msgid"), msgid) + self.assertEqual(batchStart.tags.get("time"), time) privmsgs = relay[1:-1] for msg in privmsgs: - self.assertMessageEqual(msg, command='PRIVMSG') - self.assertNotIn('msgid', msg.tags) - self.assertNotIn('time', msg.tags) - self.assertEqual(msg.tags.get('batch'), batchTag) + self.assertMessageEqual(msg, command="PRIVMSG") + self.assertNotIn("msgid", msg.tags) + self.assertNotIn("time", msg.tags) + self.assertEqual(msg.tags.get("batch"), batchTag) self.assertIn(CONCAT_TAG, relay[3].tags) fallback_relay = self.getMessages(3) relayed_fmsgids = [] for msg in fallback_relay: - self.assertMessageEqual(msg, command='PRIVMSG') - relayed_fmsgids.append(msg.tags.get('msgid')) - self.assertEqual(msg.tags.get('time'), time) + self.assertMessageEqual(msg, command="PRIVMSG") + relayed_fmsgids.append(msg.tags.get("msgid")) + self.assertEqual(msg.tags.get("time"), time) self.assertNotIn(CONCAT_TAG, msg.tags) - self.assertEqual(relayed_fmsgids, [msgid] + [None]*(len(fallback_relay)-1)) + self.assertEqual(relayed_fmsgids, [msgid] + [None] * (len(fallback_relay) - 1)) - - @cases.SpecificationSelector.requiredBySpecification('multiline') + @cases.SpecificationSelector.requiredBySpecification("multiline") def testBlankLines(self): self.connectClient( - 'alice', capabilities=(base_caps + [CAP_NAME]), skip_if_cap_nak=True + "alice", capabilities=(base_caps + [CAP_NAME]), skip_if_cap_nak=True ) - self.joinChannel(1, '#test') - self.connectClient('bob', capabilities=(base_caps + [CAP_NAME])) - self.joinChannel(2, '#test') - self.connectClient('charlie', capabilities=base_caps) - self.joinChannel(3, '#test') + self.joinChannel(1, "#test") + self.connectClient("bob", capabilities=(base_caps + [CAP_NAME])) + self.joinChannel(2, "#test") + self.connectClient("charlie", capabilities=base_caps) + self.joinChannel(3, "#test") self.getMessages(1) self.getMessages(2) self.getMessages(3) - self.sendLine(1, '@label=xyz;+client-only-tag BATCH +123 %s #test' % (BATCH_TYPE,)) - self.sendLine(1, '@batch=123 PRIVMSG #test :') - self.sendLine(1, '@batch=123 PRIVMSG #test :#how is ') - self.sendLine(1, '@batch=123;%s PRIVMSG #test :everyone?' % (CONCAT_TAG,)) - self.sendLine(1, 'BATCH -123') + self.sendLine( + 1, "@label=xyz;+client-only-tag BATCH +123 %s #test" % (BATCH_TYPE,) + ) + self.sendLine(1, "@batch=123 PRIVMSG #test :") + self.sendLine(1, "@batch=123 PRIVMSG #test :#how is ") + self.sendLine(1, "@batch=123;%s PRIVMSG #test :everyone?" % (CONCAT_TAG,)) + self.sendLine(1, "BATCH -123") self.getMessages(1) relay = self.getMessages(2) batch_start = relay[0] privmsgs = relay[1:-1] self.assertEqual(len(privmsgs), 3) - self.assertMessageEqual(privmsgs[0], command='PRIVMSG', params=['#test', '']) - self.assertMessageEqual(privmsgs[1], command='PRIVMSG', params=['#test', '#how is ']) - self.assertMessageEqual(privmsgs[2], command='PRIVMSG', params=['#test', 'everyone?']) - self.assertIn('+client-only-tag', batch_start.tags) - msgid = batch_start.tags['msgid'] + self.assertMessageEqual(privmsgs[0], command="PRIVMSG", params=["#test", ""]) + self.assertMessageEqual( + privmsgs[1], command="PRIVMSG", params=["#test", "#how is "] + ) + self.assertMessageEqual( + privmsgs[2], command="PRIVMSG", params=["#test", "everyone?"] + ) + self.assertIn("+client-only-tag", batch_start.tags) + msgid = batch_start.tags["msgid"] fallback_relay = self.getMessages(3) self.assertEqual(len(fallback_relay), 2) - self.assertMessageEqual(fallback_relay[0], command='PRIVMSG', params=['#test', '#how is ']) - self.assertMessageEqual(fallback_relay[1], command='PRIVMSG', params=['#test', 'everyone?']) - self.assertIn('+client-only-tag', fallback_relay[0].tags) - self.assertIn('+client-only-tag', fallback_relay[1].tags) - self.assertEqual(fallback_relay[0].tags['msgid'], msgid) + self.assertMessageEqual( + fallback_relay[0], command="PRIVMSG", params=["#test", "#how is "] + ) + self.assertMessageEqual( + fallback_relay[1], command="PRIVMSG", params=["#test", "everyone?"] + ) + self.assertIn("+client-only-tag", fallback_relay[0].tags) + self.assertIn("+client-only-tag", fallback_relay[1].tags) + self.assertEqual(fallback_relay[0].tags["msgid"], msgid) diff --git a/irctest/server_tests/test_register_verify.py b/irctest/server_tests/test_register_verify.py index 8affa42..4b7766f 100644 --- a/irctest/server_tests/test_register_verify.py +++ b/irctest/server_tests/test_register_verify.py @@ -1,107 +1,114 @@ from irctest import cases -REGISTER_CAP_NAME = 'draft/register' +REGISTER_CAP_NAME = "draft/register" + class TestRegisterBeforeConnect(cases.BaseServerTestCase): @staticmethod def config(): return { - "oragono_config": lambda config: config['accounts']['registration'].update( - {'allow-before-connect': True} + "oragono_config": lambda config: config["accounts"]["registration"].update( + {"allow-before-connect": True} ) } - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testBeforeConnect(self): - self.addClient('bar') - self.sendLine('bar', 'CAP LS 302') - caps = self.getCapLs('bar') + self.addClient("bar") + self.sendLine("bar", "CAP LS 302") + caps = self.getCapLs("bar") self.assertIn(REGISTER_CAP_NAME, caps) - self.assertIn('before-connect', caps[REGISTER_CAP_NAME]) - self.sendLine('bar', 'NICK bar') - self.sendLine('bar', 'REGISTER * shivarampassphrase') - msgs = self.getMessages('bar') - register_response = [msg for msg in msgs if msg.command == 'REGISTER'][0] - self.assertEqual(register_response.params[0], 'SUCCESS') + self.assertIn("before-connect", caps[REGISTER_CAP_NAME]) + self.sendLine("bar", "NICK bar") + self.sendLine("bar", "REGISTER * shivarampassphrase") + msgs = self.getMessages("bar") + register_response = [msg for msg in msgs if msg.command == "REGISTER"][0] + self.assertEqual(register_response.params[0], "SUCCESS") + class TestRegisterBeforeConnectDisallowed(cases.BaseServerTestCase): @staticmethod def config(): return { - "oragono_config": lambda config: config['accounts']['registration'].update( - {'allow-before-connect': False} + "oragono_config": lambda config: config["accounts"]["registration"].update( + {"allow-before-connect": False} ) } - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testBeforeConnect(self): - self.addClient('bar') - self.sendLine('bar', 'CAP LS 302') - caps = self.getCapLs('bar') + self.addClient("bar") + self.sendLine("bar", "CAP LS 302") + caps = self.getCapLs("bar") self.assertIn(REGISTER_CAP_NAME, caps) self.assertEqual(caps[REGISTER_CAP_NAME], None) - self.sendLine('bar', 'NICK bar') - self.sendLine('bar', 'REGISTER * shivarampassphrase') - msgs = self.getMessages('bar') - fail_response = [msg for msg in msgs if msg.command == 'FAIL'][0] - self.assertEqual(fail_response.params[:2], ['REGISTER', 'DISALLOWED']) + self.sendLine("bar", "NICK bar") + self.sendLine("bar", "REGISTER * shivarampassphrase") + msgs = self.getMessages("bar") + fail_response = [msg for msg in msgs if msg.command == "FAIL"][0] + self.assertEqual(fail_response.params[:2], ["REGISTER", "DISALLOWED"]) + class TestRegisterEmailVerified(cases.BaseServerTestCase): @staticmethod def config(): return { - "oragono_config": lambda config: config['accounts']['registration'].update( + "oragono_config": lambda config: config["accounts"]["registration"].update( { - 'email-verification': { - 'enabled': True, - 'sender': 'test@example.com', - 'require-tls': True, - 'helo-domain': 'example.com', + "email-verification": { + "enabled": True, + "sender": "test@example.com", + "require-tls": True, + "helo-domain": "example.com", }, - 'allow-before-connect': True, + "allow-before-connect": True, } ) } - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testBeforeConnect(self): - self.addClient('bar') - self.sendLine('bar', 'CAP LS 302') - caps = self.getCapLs('bar') + self.addClient("bar") + self.sendLine("bar", "CAP LS 302") + caps = self.getCapLs("bar") self.assertIn(REGISTER_CAP_NAME, caps) - self.assertEqual(set(caps[REGISTER_CAP_NAME].split(',')), {'before-connect', 'email-required'}) - self.sendLine('bar', 'NICK bar') - self.sendLine('bar', 'REGISTER * shivarampassphrase') - msgs = self.getMessages('bar') - fail_response = [msg for msg in msgs if msg.command == 'FAIL'][0] - self.assertEqual(fail_response.params[:2], ['REGISTER', 'INVALID_EMAIL']) + self.assertEqual( + set(caps[REGISTER_CAP_NAME].split(",")), + {"before-connect", "email-required"}, + ) + self.sendLine("bar", "NICK bar") + self.sendLine("bar", "REGISTER * shivarampassphrase") + msgs = self.getMessages("bar") + fail_response = [msg for msg in msgs if msg.command == "FAIL"][0] + self.assertEqual(fail_response.params[:2], ["REGISTER", "INVALID_EMAIL"]) - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testAfterConnect(self): - self.connectClient('bar', name='bar') - self.sendLine('bar', 'REGISTER * shivarampassphrase') - msgs = self.getMessages('bar') - fail_response = [msg for msg in msgs if msg.command == 'FAIL'][0] - self.assertEqual(fail_response.params[:2], ['REGISTER', 'INVALID_EMAIL']) + self.connectClient("bar", name="bar") + self.sendLine("bar", "REGISTER * shivarampassphrase") + msgs = self.getMessages("bar") + fail_response = [msg for msg in msgs if msg.command == "FAIL"][0] + self.assertEqual(fail_response.params[:2], ["REGISTER", "INVALID_EMAIL"]) + class TestRegisterNoLandGrabs(cases.BaseServerTestCase): @staticmethod def config(): return { - "oragono_config": lambda config: config['accounts']['registration'].update( - {'allow-before-connect': True} + "oragono_config": lambda config: config["accounts"]["registration"].update( + {"allow-before-connect": True} ) } - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testBeforeConnect(self): # have an anonymous client take the 'root' username: - self.connectClient('root', name='root') + self.connectClient("root", name="root") # cannot register it out from under the anonymous nick holder: - self.addClient('bar') - self.sendLine('bar', 'NICK root') - self.sendLine('bar', 'REGISTER * shivarampassphrase') - msgs = self.getMessages('bar') - fail_response = [msg for msg in msgs if msg.command == 'FAIL'][0] - self.assertEqual(fail_response.params[:2], ['REGISTER', 'USERNAME_EXISTS']) + self.addClient("bar") + self.sendLine("bar", "NICK root") + self.sendLine("bar", "REGISTER * shivarampassphrase") + msgs = self.getMessages("bar") + fail_response = [msg for msg in msgs if msg.command == "FAIL"][0] + self.assertEqual(fail_response.params[:2], ["REGISTER", "USERNAME_EXISTS"]) diff --git a/irctest/server_tests/test_regressions.py b/irctest/server_tests/test_regressions.py index cc36233..bfd5c2a 100644 --- a/irctest/server_tests/test_regressions.py +++ b/irctest/server_tests/test_regressions.py @@ -6,164 +6,172 @@ from irctest import cases from irctest.numerics import ERR_ERRONEUSNICKNAME, ERR_NICKNAMEINUSE, RPL_WELCOME -class RegressionsTestCase(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('RFC1459') +class RegressionsTestCase(cases.BaseServerTestCase): + @cases.SpecificationSelector.requiredBySpecification("RFC1459") def testFailedNickChange(self): # see oragono commit d0ded906d4ac8f - self.connectClient('alice') - self.connectClient('bob') + self.connectClient("alice") + self.connectClient("bob") # bob tries to change to an in-use nickname; this MUST fail - self.sendLine(2, 'NICK alice') + self.sendLine(2, "NICK alice") ms = self.getMessages(2) self.assertEqual(len(ms), 1) self.assertMessageEqual(ms[0], command=ERR_NICKNAMEINUSE) # bob MUST still own the bob nick, and be able to receive PRIVMSG as bob - self.sendLine(1, 'PRIVMSG bob hi') + self.sendLine(1, "PRIVMSG bob hi") ms = self.getMessages(1) self.assertEqual(len(ms), 0) ms = self.getMessages(2) self.assertEqual(len(ms), 1) - self.assertMessageEqual(ms[0], command='PRIVMSG', params=['bob', 'hi']) + self.assertMessageEqual(ms[0], command="PRIVMSG", params=["bob", "hi"]) - @cases.SpecificationSelector.requiredBySpecification('RFC1459') + @cases.SpecificationSelector.requiredBySpecification("RFC1459") def testCaseChanges(self): - self.connectClient('alice') - self.joinChannel(1, '#test') - self.connectClient('bob') - self.joinChannel(2, '#test') + self.connectClient("alice") + self.joinChannel(1, "#test") + self.connectClient("bob") + self.joinChannel(2, "#test") self.getMessages(1) self.getMessages(2) # case change: both alice and bob should get a successful nick line - self.sendLine(1, 'NICK Alice') + self.sendLine(1, "NICK Alice") ms = self.getMessages(1) self.assertEqual(len(ms), 1) - self.assertMessageEqual(ms[0], command='NICK', params=['Alice']) + self.assertMessageEqual(ms[0], command="NICK", params=["Alice"]) ms = self.getMessages(2) self.assertEqual(len(ms), 1) - self.assertMessageEqual(ms[0], command='NICK', params=['Alice']) + self.assertMessageEqual(ms[0], command="NICK", params=["Alice"]) # no responses, either to the user or to friends, from a no-op nick change - self.sendLine(1, 'NICK Alice') + self.sendLine(1, "NICK Alice") ms = self.getMessages(1) self.assertEqual(ms, []) ms = self.getMessages(2) self.assertEqual(ms, []) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.2') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.2") def testTagCap(self): # regression test for oragono #754 self.connectClient( - 'alice', - capabilities=['message-tags', 'batch', 'echo-message', 'server-time'], - skip_if_cap_nak=True + "alice", + capabilities=["message-tags", "batch", "echo-message", "server-time"], + skip_if_cap_nak=True, ) - self.connectClient('bob') + self.connectClient("bob") self.getMessages(1) self.getMessages(2) - self.sendLine(1, '@+draft/reply=ct95w3xemz8qj9du2h74wp8pee PRIVMSG bob :hey yourself') + self.sendLine( + 1, "@+draft/reply=ct95w3xemz8qj9du2h74wp8pee PRIVMSG bob :hey yourself" + ) ms = self.getMessages(1) self.assertEqual(len(ms), 1) - self.assertMessageEqual(ms[0], command='PRIVMSG', params=['bob', 'hey yourself']) - self.assertEqual(ms[0].tags.get('+draft/reply'), 'ct95w3xemz8qj9du2h74wp8pee') + self.assertMessageEqual( + ms[0], command="PRIVMSG", params=["bob", "hey yourself"] + ) + self.assertEqual(ms[0].tags.get("+draft/reply"), "ct95w3xemz8qj9du2h74wp8pee") ms = self.getMessages(2) self.assertEqual(len(ms), 1) - self.assertMessageEqual(ms[0], command='PRIVMSG', params=['bob', 'hey yourself']) + self.assertMessageEqual( + ms[0], command="PRIVMSG", params=["bob", "hey yourself"] + ) self.assertEqual(ms[0].tags, {}) - self.sendLine(2, 'CAP REQ :message-tags server-time') + self.sendLine(2, "CAP REQ :message-tags server-time") self.getMessages(2) - self.sendLine(1, '@+draft/reply=tbxqauh9nykrtpa3n6icd9whan PRIVMSG bob :hey again') + self.sendLine( + 1, "@+draft/reply=tbxqauh9nykrtpa3n6icd9whan PRIVMSG bob :hey again" + ) self.getMessages(1) ms = self.getMessages(2) # now bob has the tags cap, so he should receive the tags self.assertEqual(len(ms), 1) - self.assertMessageEqual(ms[0], command='PRIVMSG', params=['bob', 'hey again']) - self.assertEqual(ms[0].tags.get('+draft/reply'), 'tbxqauh9nykrtpa3n6icd9whan') + self.assertMessageEqual(ms[0], command="PRIVMSG", params=["bob", "hey again"]) + self.assertEqual(ms[0].tags.get("+draft/reply"), "tbxqauh9nykrtpa3n6icd9whan") - @cases.SpecificationSelector.requiredBySpecification('RFC1459') + @cases.SpecificationSelector.requiredBySpecification("RFC1459") def testStarNick(self): self.addClient(1) - self.sendLine(1, 'NICK *') - self.sendLine(1, 'USER u s e r') - replies = {'NOTICE'} - while replies == {'NOTICE'}: + self.sendLine(1, "NICK *") + self.sendLine(1, "USER u s e r") + replies = {"NOTICE"} + while replies == {"NOTICE"}: replies = set(msg.command for msg in self.getMessages(1, synchronize=False)) self.assertIn(ERR_ERRONEUSNICKNAME, replies) self.assertNotIn(RPL_WELCOME, replies) - self.sendLine(1, 'NICK valid') - replies = {'NOTICE'} - while replies <= {'NOTICE'}: + self.sendLine(1, "NICK valid") + replies = {"NOTICE"} + while replies <= {"NOTICE"}: replies = set(msg.command for msg in self.getMessages(1, synchronize=False)) self.assertNotIn(ERR_ERRONEUSNICKNAME, replies) self.assertIn(RPL_WELCOME, replies) - @cases.SpecificationSelector.requiredBySpecification('RFC1459') + @cases.SpecificationSelector.requiredBySpecification("RFC1459") def testEmptyNick(self): self.addClient(1) - self.sendLine(1, 'NICK :') - self.sendLine(1, 'USER u s e r') - replies = {'NOTICE'} - while replies == {'NOTICE'}: + self.sendLine(1, "NICK :") + self.sendLine(1, "USER u s e r") + replies = {"NOTICE"} + while replies == {"NOTICE"}: replies = set(msg.command for msg in self.getMessages(1, synchronize=False)) self.assertNotIn(RPL_WELCOME, replies) - @cases.SpecificationSelector.requiredBySpecification('RFC1459') + @cases.SpecificationSelector.requiredBySpecification("RFC1459") def testNickRelease(self): # regression test for oragono #1252 - self.connectClient('alice') + self.connectClient("alice") self.getMessages(1) - self.sendLine(1, 'NICK malice') - nick_msgs = [msg for msg in self.getMessages(1) if msg.command == 'NICK'] + self.sendLine(1, "NICK malice") + nick_msgs = [msg for msg in self.getMessages(1) if msg.command == "NICK"] self.assertEqual(len(nick_msgs), 1) - self.assertMessageEqual(nick_msgs[0], command='NICK', params=['malice']) + self.assertMessageEqual(nick_msgs[0], command="NICK", params=["malice"]) self.addClient(2) - self.sendLine(2, 'NICK alice') - self.sendLine(2, 'USER u s e r') + self.sendLine(2, "NICK alice") + self.sendLine(2, "USER u s e r") replies = set(msg.command for msg in self.getMessages(2)) self.assertNotIn(ERR_NICKNAMEINUSE, replies) self.assertIn(RPL_WELCOME, replies) - @cases.SpecificationSelector.requiredBySpecification('RFC1459') + @cases.SpecificationSelector.requiredBySpecification("RFC1459") def testNickReleaseQuit(self): - self.connectClient('alice') + self.connectClient("alice") self.getMessages(1) - self.sendLine(1, 'QUIT') + self.sendLine(1, "QUIT") self.assertDisconnected(1) self.addClient(2) - self.sendLine(2, 'NICK alice') - self.sendLine(2, 'USER u s e r') + self.sendLine(2, "NICK alice") + self.sendLine(2, "USER u s e r") replies = set(msg.command for msg in self.getMessages(2)) self.assertNotIn(ERR_NICKNAMEINUSE, replies) self.assertIn(RPL_WELCOME, replies) - self.sendLine(2, 'QUIT') + self.sendLine(2, "QUIT") self.assertDisconnected(2) self.addClient(3) - self.sendLine(3, 'NICK ALICE') - self.sendLine(3, 'USER u s e r') + self.sendLine(3, "NICK ALICE") + self.sendLine(3, "USER u s e r") replies = set(msg.command for msg in self.getMessages(3)) self.assertNotIn(ERR_NICKNAMEINUSE, replies) self.assertIn(RPL_WELCOME, replies) - @cases.SpecificationSelector.requiredBySpecification('RFC1459') + @cases.SpecificationSelector.requiredBySpecification("RFC1459") def testNickReleaseUnregistered(self): self.addClient(1) - self.sendLine(1, 'NICK alice') - self.sendLine(1, 'QUIT') + self.sendLine(1, "NICK alice") + self.sendLine(1, "QUIT") self.assertDisconnected(1) self.addClient(2) - self.sendLine(2, 'NICK alice') - self.sendLine(2, 'USER u s e r') + self.sendLine(2, "NICK alice") + self.sendLine(2, "USER u s e r") replies = set(msg.command for msg in self.getMessages(2)) self.assertNotIn(ERR_NICKNAMEINUSE, replies) self.assertIn(RPL_WELCOME, replies) diff --git a/irctest/server_tests/test_relaymsg.py b/irctest/server_tests/test_relaymsg.py index 786d3b2..bdfbee6 100644 --- a/irctest/server_tests/test_relaymsg.py +++ b/irctest/server_tests/test_relaymsg.py @@ -3,8 +3,9 @@ from irctest.irc_utils.junkdrawer import random_name from irctest.server_tests.test_chathistory import CHATHISTORY_CAP, EVENT_PLAYBACK_CAP -RELAYMSG_CAP = 'draft/relaymsg' -RELAYMSG_TAG_NAME = 'draft/relaymsg' +RELAYMSG_CAP = "draft/relaymsg" +RELAYMSG_TAG_NAME = "draft/relaymsg" + class RelaymsgTestCase(cases.BaseServerTestCase): @staticmethod @@ -13,60 +14,112 @@ class RelaymsgTestCase(cases.BaseServerTestCase): "chathistory": True, } - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testRelaymsg(self): - self.connectClient('baz', name='baz', capabilities=['server-time', 'message-tags', 'batch', 'labeled-response', 'echo-message', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP]) - self.connectClient('qux', name='qux', capabilities=['server-time', 'message-tags', 'batch', 'labeled-response', 'echo-message', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP]) - chname = random_name('#relaymsg') - self.joinChannel('baz', chname) - self.joinChannel('qux', chname) - self.getMessages('baz') - self.getMessages('qux') + self.connectClient( + "baz", + name="baz", + capabilities=[ + "server-time", + "message-tags", + "batch", + "labeled-response", + "echo-message", + CHATHISTORY_CAP, + EVENT_PLAYBACK_CAP, + ], + ) + self.connectClient( + "qux", + name="qux", + capabilities=[ + "server-time", + "message-tags", + "batch", + "labeled-response", + "echo-message", + CHATHISTORY_CAP, + EVENT_PLAYBACK_CAP, + ], + ) + chname = random_name("#relaymsg") + self.joinChannel("baz", chname) + self.joinChannel("qux", chname) + self.getMessages("baz") + self.getMessages("qux") - self.sendLine('baz', 'RELAYMSG %s invalid!nick/discord hi' % (chname,)) - response = self.getMessages('baz')[0] - self.assertEqual(response.command, 'FAIL') - self.assertEqual(response.params[:2], ['RELAYMSG', 'INVALID_NICK']) + self.sendLine("baz", "RELAYMSG %s invalid!nick/discord hi" % (chname,)) + response = self.getMessages("baz")[0] + self.assertEqual(response.command, "FAIL") + self.assertEqual(response.params[:2], ["RELAYMSG", "INVALID_NICK"]) - self.sendLine('baz', 'RELAYMSG %s regular_nick hi' % (chname,)) - response = self.getMessages('baz')[0] - self.assertEqual(response.command, 'FAIL') - self.assertEqual(response.params[:2], ['RELAYMSG', 'INVALID_NICK']) + self.sendLine("baz", "RELAYMSG %s regular_nick hi" % (chname,)) + response = self.getMessages("baz")[0] + self.assertEqual(response.command, "FAIL") + self.assertEqual(response.params[:2], ["RELAYMSG", "INVALID_NICK"]) - self.sendLine('baz', 'RELAYMSG %s smt/discord hi' % (chname,)) - response = self.getMessages('baz')[0] - self.assertMessageEqual(response, nick='smt/discord', command='PRIVMSG', params=[chname, 'hi']) - relayed_msg = self.getMessages('qux')[0] - self.assertMessageEqual(relayed_msg, nick='smt/discord', command='PRIVMSG', params=[chname, 'hi']) + self.sendLine("baz", "RELAYMSG %s smt/discord hi" % (chname,)) + response = self.getMessages("baz")[0] + self.assertMessageEqual( + response, nick="smt/discord", command="PRIVMSG", params=[chname, "hi"] + ) + relayed_msg = self.getMessages("qux")[0] + self.assertMessageEqual( + relayed_msg, nick="smt/discord", command="PRIVMSG", params=[chname, "hi"] + ) # labeled-response - self.sendLine('baz', '@label=x RELAYMSG %s smt/discord :hi again' % (chname,)) - response = self.getMessages('baz')[0] - self.assertMessageEqual(response, nick='smt/discord', command='PRIVMSG', params=[chname, 'hi again']) - self.assertEqual(response.tags.get('label'), 'x') - relayed_msg = self.getMessages('qux')[0] - self.assertMessageEqual(relayed_msg, nick='smt/discord', command='PRIVMSG', params=[chname, 'hi again']) + self.sendLine("baz", "@label=x RELAYMSG %s smt/discord :hi again" % (chname,)) + response = self.getMessages("baz")[0] + self.assertMessageEqual( + response, nick="smt/discord", command="PRIVMSG", params=[chname, "hi again"] + ) + self.assertEqual(response.tags.get("label"), "x") + relayed_msg = self.getMessages("qux")[0] + self.assertMessageEqual( + relayed_msg, + nick="smt/discord", + command="PRIVMSG", + params=[chname, "hi again"], + ) - self.sendLine('qux', 'RELAYMSG %s smt/discord :hi a third time' % (chname,)) - response = self.getMessages('qux')[0] - self.assertEqual(response.command, 'FAIL') - self.assertEqual(response.params[:2], ['RELAYMSG', 'PRIVS_NEEDED']) + self.sendLine("qux", "RELAYMSG %s smt/discord :hi a third time" % (chname,)) + response = self.getMessages("qux")[0] + self.assertEqual(response.command, "FAIL") + self.assertEqual(response.params[:2], ["RELAYMSG", "PRIVS_NEEDED"]) # grant qux chanop, allowing relaymsg - self.sendLine('baz', 'MODE %s +o qux' % (chname,)) - self.getMessages('baz') - self.getMessages('qux') + self.sendLine("baz", "MODE %s +o qux" % (chname,)) + self.getMessages("baz") + self.getMessages("qux") # give baz the relaymsg cap - self.sendLine('baz', 'CAP REQ %s' % (RELAYMSG_CAP)) - self.assertMessageEqual(self.getMessages('baz')[0], command='CAP', params=['baz', 'ACK', RELAYMSG_CAP]) + self.sendLine("baz", "CAP REQ %s" % (RELAYMSG_CAP)) + self.assertMessageEqual( + self.getMessages("baz")[0], + command="CAP", + params=["baz", "ACK", RELAYMSG_CAP], + ) - self.sendLine('qux', 'RELAYMSG %s smt/discord :hi a third time' % (chname,)) - response = self.getMessages('qux')[0] - self.assertMessageEqual(response, nick='smt/discord', command='PRIVMSG', params=[chname, 'hi a third time']) - relayed_msg = self.getMessages('baz')[0] - self.assertMessageEqual(relayed_msg, nick='smt/discord', command='PRIVMSG', params=[chname, 'hi a third time']) - self.assertEqual(relayed_msg.tags.get(RELAYMSG_TAG_NAME), 'qux') + self.sendLine("qux", "RELAYMSG %s smt/discord :hi a third time" % (chname,)) + response = self.getMessages("qux")[0] + self.assertMessageEqual( + response, + nick="smt/discord", + command="PRIVMSG", + params=[chname, "hi a third time"], + ) + relayed_msg = self.getMessages("baz")[0] + self.assertMessageEqual( + relayed_msg, + nick="smt/discord", + command="PRIVMSG", + params=[chname, "hi a third time"], + ) + self.assertEqual(relayed_msg.tags.get(RELAYMSG_TAG_NAME), "qux") - self.sendLine('baz', 'CHATHISTORY LATEST %s * 10' % (chname,)) - messages = self.getMessages('baz') - self.assertEqual([msg.params[-1] for msg in messages if msg.command == 'PRIVMSG'], ['hi', 'hi again', 'hi a third time']) + self.sendLine("baz", "CHATHISTORY LATEST %s * 10" % (chname,)) + messages = self.getMessages("baz") + self.assertEqual( + [msg.params[-1] for msg in messages if msg.command == "PRIVMSG"], + ["hi", "hi again", "hi a third time"], + ) diff --git a/irctest/server_tests/test_resume.py b/irctest/server_tests/test_resume.py index 61d0275..aa3bfb9 100644 --- a/irctest/server_tests/test_resume.py +++ b/irctest/server_tests/test_resume.py @@ -8,143 +8,209 @@ from irctest import cases from irctest.numerics import RPL_AWAY -ANCIENT_TIMESTAMP = '2006-01-02T15:04:05.999Z' +ANCIENT_TIMESTAMP = "2006-01-02T15:04:05.999Z" + class ResumeTestCase(cases.BaseServerTestCase): - - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testNoResumeByDefault(self): - self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response']) + self.connectClient( + "bar", capabilities=["batch", "echo-message", "labeled-response"] + ) ms = self.getMessages(1) - resume_messages = [m for m in ms if m.command == 'RESUME'] - self.assertEqual(resume_messages, [], 'should not see RESUME messages unless explicitly negotiated') + resume_messages = [m for m in ms if m.command == "RESUME"] + self.assertEqual( + resume_messages, + [], + "should not see RESUME messages unless explicitly negotiated", + ) - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testResume(self): - chname = '#' + secrets.token_hex(12) - self.connectClient('bar', capabilities=['batch', 'labeled-response', 'server-time']) + chname = "#" + secrets.token_hex(12) + self.connectClient( + "bar", capabilities=["batch", "labeled-response", "server-time"] + ) ms = self.getMessages(1) - welcome = self.connectClient('baz', capabilities=['batch', 'labeled-response', 'server-time', 'draft/resume-0.5']) - resume_messages = [m for m in welcome if m.command == 'RESUME'] + welcome = self.connectClient( + "baz", + capabilities=[ + "batch", + "labeled-response", + "server-time", + "draft/resume-0.5", + ], + ) + resume_messages = [m for m in welcome if m.command == "RESUME"] self.assertEqual(len(resume_messages), 1) - self.assertEqual(resume_messages[0].params[0], 'TOKEN') + self.assertEqual(resume_messages[0].params[0], "TOKEN") token = resume_messages[0].params[1] self.joinChannel(1, chname) self.joinChannel(2, chname) - self.sendLine(1, 'PRIVMSG %s :hello friends' % (chname,)) - self.sendLine(1, 'PRIVMSG baz :hello friend singular') + self.sendLine(1, "PRIVMSG %s :hello friends" % (chname,)) + self.sendLine(1, "PRIVMSG baz :hello friend singular") self.getMessages(1) # should receive these messages - privmsgs = [m for m in self.getMessages(2) if m.command == 'PRIVMSG'] + privmsgs = [m for m in self.getMessages(2) if m.command == "PRIVMSG"] self.assertEqual(len(privmsgs), 2) privmsgs.sort(key=lambda m: m.params[0]) - self.assertMessageEqual(privmsgs[0], command='PRIVMSG', params=[chname, 'hello friends']) - self.assertMessageEqual(privmsgs[1], command='PRIVMSG', params=['baz', 'hello friend singular']) - channelMsgTime = privmsgs[0].tags.get('time') + self.assertMessageEqual( + privmsgs[0], command="PRIVMSG", params=[chname, "hello friends"] + ) + self.assertMessageEqual( + privmsgs[1], command="PRIVMSG", params=["baz", "hello friend singular"] + ) + channelMsgTime = privmsgs[0].tags.get("time") # tokens MUST be cryptographically secure; therefore, this token should be invalid # with probability at least 1 - 1/(2**128) - bad_token = 'a' * len(token) + bad_token = "a" * len(token) self.addClient() - self.sendLine(3, 'CAP LS') - self.sendLine(3, 'CAP REQ :batch labeled-response server-time draft/resume-0.5') - self.sendLine(3, 'NICK tempnick') - self.sendLine(3, 'USER tempuser 0 * tempuser') - self.sendLine(3, ' '.join(('RESUME', bad_token, ANCIENT_TIMESTAMP))) + self.sendLine(3, "CAP LS") + self.sendLine(3, "CAP REQ :batch labeled-response server-time draft/resume-0.5") + self.sendLine(3, "NICK tempnick") + self.sendLine(3, "USER tempuser 0 * tempuser") + self.sendLine(3, " ".join(("RESUME", bad_token, ANCIENT_TIMESTAMP))) # resume with a bad token MUST fail ms = self.getMessages(3) - resume_err_messages = [m for m in ms if m.command == 'FAIL' and m.params[:2] == ['RESUME', 'INVALID_TOKEN']] + resume_err_messages = [ + m + for m in ms + if m.command == "FAIL" and m.params[:2] == ["RESUME", "INVALID_TOKEN"] + ] self.assertEqual(len(resume_err_messages), 1) # however, registration should proceed with the alternative nick - self.sendLine(3, 'CAP END') - welcome_msgs = [m for m in self.getMessages(3) if m.command == '001'] # RPL_WELCOME - self.assertEqual(welcome_msgs[0].params[0], 'tempnick') + self.sendLine(3, "CAP END") + welcome_msgs = [ + m for m in self.getMessages(3) if m.command == "001" + ] # RPL_WELCOME + self.assertEqual(welcome_msgs[0].params[0], "tempnick") self.addClient() - self.sendLine(4, 'CAP LS') - self.sendLine(4, 'CAP REQ :batch labeled-response server-time draft/resume-0.5') - self.sendLine(4, 'NICK tempnick_') - self.sendLine(4, 'USER tempuser 0 * tempuser') + self.sendLine(4, "CAP LS") + self.sendLine(4, "CAP REQ :batch labeled-response server-time draft/resume-0.5") + self.sendLine(4, "NICK tempnick_") + self.sendLine(4, "USER tempuser 0 * tempuser") # resume with a timestamp in the distant past - self.sendLine(4, ' '.join(('RESUME', token, ANCIENT_TIMESTAMP))) + self.sendLine(4, " ".join(("RESUME", token, ANCIENT_TIMESTAMP))) # successful resume does not require CAP END: # https://github.com/ircv3/ircv3-specifications/pull/306/files#r255318883 ms = self.getMessages(4) # now, do a valid resume with the correct token - resume_messages = [m for m in ms if m.command == 'RESUME'] + resume_messages = [m for m in ms if m.command == "RESUME"] self.assertEqual(len(resume_messages), 2) - self.assertEqual(resume_messages[0].params[0], 'TOKEN') + self.assertEqual(resume_messages[0].params[0], "TOKEN") new_token = resume_messages[0].params[1] - self.assertNotEqual(token, new_token, 'should receive a new, strong resume token; instead got ' + new_token) + self.assertNotEqual( + token, + new_token, + "should receive a new, strong resume token; instead got " + new_token, + ) # success message - self.assertMessageEqual(resume_messages[1], command='RESUME', params=['SUCCESS', 'baz']) + self.assertMessageEqual( + resume_messages[1], command="RESUME", params=["SUCCESS", "baz"] + ) # test replay of messages - privmsgs = [m for m in ms if m.command == 'PRIVMSG' and m.prefix.startswith('bar')] + privmsgs = [ + m for m in ms if m.command == "PRIVMSG" and m.prefix.startswith("bar") + ] self.assertEqual(len(privmsgs), 2) privmsgs.sort(key=lambda m: m.params[0]) - self.assertMessageEqual(privmsgs[0], command='PRIVMSG', params=[chname, 'hello friends']) - self.assertMessageEqual(privmsgs[1], command='PRIVMSG', params=['baz', 'hello friend singular']) + self.assertMessageEqual( + privmsgs[0], command="PRIVMSG", params=[chname, "hello friends"] + ) + self.assertMessageEqual( + privmsgs[1], command="PRIVMSG", params=["baz", "hello friend singular"] + ) # should replay with the original server-time # TODO this probably isn't testing anything because the timestamp only has second resolution, # hence will typically match by accident - self.assertEqual(privmsgs[0].tags.get('time'), channelMsgTime) + self.assertEqual(privmsgs[0].tags.get("time"), channelMsgTime) # legacy client should receive a QUIT and a JOIN - quit, join = [m for m in self.getMessages(1) if m.command in ('QUIT', 'JOIN')] - self.assertEqual(quit.command, 'QUIT') - self.assertTrue(quit.prefix.startswith('baz')) - self.assertMessageEqual(join, command='JOIN', params=[chname]) - self.assertTrue(join.prefix.startswith('baz')) + quit, join = [m for m in self.getMessages(1) if m.command in ("QUIT", "JOIN")] + self.assertEqual(quit.command, "QUIT") + self.assertTrue(quit.prefix.startswith("baz")) + self.assertMessageEqual(join, command="JOIN", params=[chname]) + self.assertTrue(join.prefix.startswith("baz")) # original client should have been disconnected self.assertDisconnected(2) # new client should be receiving PRIVMSG sent to baz - self.sendLine(1, 'PRIVMSG baz :hello again') + self.sendLine(1, "PRIVMSG baz :hello again") self.getMessages(1) - self.assertMessageEqual(self.getMessage(4), command='PRIVMSG', params=['baz', 'hello again']) + self.assertMessageEqual( + self.getMessage(4), command="PRIVMSG", params=["baz", "hello again"] + ) # test chain-resuming (resuming the resumed connection, using the new token) self.addClient() - self.sendLine(5, 'CAP LS') - self.sendLine(5, 'CAP REQ :batch labeled-response server-time draft/resume-0.5') - self.sendLine(5, 'NICK tempnick_') - self.sendLine(5, 'USER tempuser 0 * tempuser') - self.sendLine(5, 'RESUME ' + new_token) + self.sendLine(5, "CAP LS") + self.sendLine(5, "CAP REQ :batch labeled-response server-time draft/resume-0.5") + self.sendLine(5, "NICK tempnick_") + self.sendLine(5, "USER tempuser 0 * tempuser") + self.sendLine(5, "RESUME " + new_token) ms = self.getMessages(5) - resume_messages = [m for m in ms if m.command == 'RESUME'] + resume_messages = [m for m in ms if m.command == "RESUME"] self.assertEqual(len(resume_messages), 2) - self.assertEqual(resume_messages[0].params[0], 'TOKEN') + self.assertEqual(resume_messages[0].params[0], "TOKEN") new_new_token = resume_messages[0].params[1] - self.assertNotEqual(token, new_new_token, 'should receive a new, strong resume token; instead got ' + new_new_token) - self.assertNotEqual(new_token, new_new_token, 'should receive a new, strong resume token; instead got ' + new_new_token) + self.assertNotEqual( + token, + new_new_token, + "should receive a new, strong resume token; instead got " + new_new_token, + ) + self.assertNotEqual( + new_token, + new_new_token, + "should receive a new, strong resume token; instead got " + new_new_token, + ) # success message - self.assertMessageEqual(resume_messages[1], command='RESUME', params=['SUCCESS', 'baz']) + self.assertMessageEqual( + resume_messages[1], command="RESUME", params=["SUCCESS", "baz"] + ) - - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testBRB(self): - chname = '#' + secrets.token_hex(12) - self.connectClient('bar', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'draft/resume-0.5']) + chname = "#" + secrets.token_hex(12) + self.connectClient( + "bar", + capabilities=[ + "batch", + "labeled-response", + "message-tags", + "server-time", + "draft/resume-0.5", + ], + ) ms = self.getMessages(1) self.joinChannel(1, chname) - welcome = self.connectClient('baz', capabilities=['batch', 'labeled-response', 'server-time', 'draft/resume-0.5']) - resume_messages = [m for m in welcome if m.command == 'RESUME'] + welcome = self.connectClient( + "baz", + capabilities=[ + "batch", + "labeled-response", + "server-time", + "draft/resume-0.5", + ], + ) + resume_messages = [m for m in welcome if m.command == "RESUME"] self.assertEqual(len(resume_messages), 1) - self.assertEqual(resume_messages[0].params[0], 'TOKEN') + self.assertEqual(resume_messages[0].params[0], "TOKEN") token = resume_messages[0].params[1] self.joinChannel(2, chname) self.getMessages(1) - self.sendLine(2, 'BRB :software upgrade') + self.sendLine(2, "BRB :software upgrade") # should receive, e.g., `BRB 210` (number of seconds) - ms = [m for m in self.getMessages(2) if m.command == 'BRB'] + ms = [m for m in self.getMessages(2) if m.command == "BRB"] self.assertEqual(len(ms), 1) self.assertGreater(int(ms[0].params[0]), 1) # BRB disconnects you @@ -152,25 +218,33 @@ class ResumeTestCase(cases.BaseServerTestCase): # without sending a QUIT line to friends self.assertEqual(self.getMessages(1), []) - self.sendLine(1, 'PRIVMSG baz :hey there') + self.sendLine(1, "PRIVMSG baz :hey there") # BRB message should be sent as an away message - self.assertMessageEqual(self.getMessage(1), command=RPL_AWAY, params=['bar', 'baz', 'software upgrade']) + self.assertMessageEqual( + self.getMessage(1), + command=RPL_AWAY, + params=["bar", "baz", "software upgrade"], + ) self.addClient(3) - self.sendLine(3, 'CAP REQ :batch account-tag message-tags draft/resume-0.5') - self.sendLine(3, ' '.join(('RESUME', token, ANCIENT_TIMESTAMP))) + self.sendLine(3, "CAP REQ :batch account-tag message-tags draft/resume-0.5") + self.sendLine(3, " ".join(("RESUME", token, ANCIENT_TIMESTAMP))) ms = self.getMessages(3) - resume_messages = [m for m in ms if m.command == 'RESUME'] + resume_messages = [m for m in ms if m.command == "RESUME"] self.assertEqual(len(resume_messages), 2) - self.assertEqual(resume_messages[0].params[0], 'TOKEN') - self.assertMessageEqual(resume_messages[1], command='RESUME', params=['SUCCESS', 'baz']) + self.assertEqual(resume_messages[0].params[0], "TOKEN") + self.assertMessageEqual( + resume_messages[1], command="RESUME", params=["SUCCESS", "baz"] + ) - privmsgs = [m for m in ms if m.command == 'PRIVMSG' and m.prefix.startswith('bar')] + privmsgs = [ + m for m in ms if m.command == "PRIVMSG" and m.prefix.startswith("bar") + ] self.assertEqual(len(privmsgs), 1) - self.assertMessageEqual(privmsgs[0], params=['baz', 'hey there']) + self.assertMessageEqual(privmsgs[0], params=["baz", "hey there"]) # friend with the resume cap should receive a RESUMED message - resumed_messages = [m for m in self.getMessages(1) if m.command == 'RESUMED'] + resumed_messages = [m for m in self.getMessages(1) if m.command == "RESUMED"] self.assertEqual(len(resumed_messages), 1) - self.assertTrue(resumed_messages[0].prefix.startswith('baz')) + self.assertTrue(resumed_messages[0].prefix.startswith("baz")) diff --git a/irctest/server_tests/test_roleplay.py b/irctest/server_tests/test_roleplay.py index 85df055..1689a94 100644 --- a/irctest/server_tests/test_roleplay.py +++ b/irctest/server_tests/test_roleplay.py @@ -2,6 +2,7 @@ from irctest import cases from irctest.numerics import ERR_CANNOTSENDRP from irctest.irc_utils.junkdrawer import random_name + class RoleplayTestCase(cases.BaseServerTestCase): @staticmethod def config(): @@ -9,58 +10,70 @@ class RoleplayTestCase(cases.BaseServerTestCase): "oragono_roleplay": True, } - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testRoleplay(self): - bar = random_name('bar') - qux = random_name('qux') - chan = random_name('#chan') - self.connectClient(bar, name=bar, capabilities=['batch', 'labeled-response', 'message-tags', 'server-time']) - self.connectClient(qux, name=qux, capabilities=['batch', 'labeled-response', 'message-tags', 'server-time']) + bar = random_name("bar") + qux = random_name("qux") + chan = random_name("#chan") + self.connectClient( + bar, + name=bar, + capabilities=["batch", "labeled-response", "message-tags", "server-time"], + ) + self.connectClient( + qux, + name=qux, + capabilities=["batch", "labeled-response", "message-tags", "server-time"], + ) self.joinChannel(bar, chan) self.joinChannel(qux, chan) self.getMessages(bar) # roleplay should be forbidden because we aren't +E yet - self.sendLine(bar, 'NPC %s bilbo too much bread' % (chan,)) + self.sendLine(bar, "NPC %s bilbo too much bread" % (chan,)) reply = self.getMessages(bar)[0] self.assertEqual(reply.command, ERR_CANNOTSENDRP) - self.sendLine(bar, 'MODE %s +E' % (chan,)) + self.sendLine(bar, "MODE %s +E" % (chan,)) reply = self.getMessages(bar)[0] - self.assertEqual(reply.command, 'MODE') - self.assertMessageEqual(reply, command='MODE', params=[chan, '+E']) + self.assertEqual(reply.command, "MODE") + self.assertMessageEqual(reply, command="MODE", params=[chan, "+E"]) self.getMessages(qux) - self.sendLine(bar, 'NPC %s bilbo too much bread' % (chan,)) + self.sendLine(bar, "NPC %s bilbo too much bread" % (chan,)) reply = self.getMessages(bar)[0] - self.assertEqual(reply.command, 'PRIVMSG') + self.assertEqual(reply.command, "PRIVMSG") self.assertEqual(reply.params[0], chan) - self.assertTrue(reply.prefix.startswith('*bilbo*!')) - self.assertIn('too much bread', reply.params[1]) + self.assertTrue(reply.prefix.startswith("*bilbo*!")) + self.assertIn("too much bread", reply.params[1]) reply = self.getMessages(qux)[0] - self.assertEqual(reply.command, 'PRIVMSG') + self.assertEqual(reply.command, "PRIVMSG") self.assertEqual(reply.params[0], chan) - self.assertTrue(reply.prefix.startswith('*bilbo*!')) - self.assertIn('too much bread', reply.params[1]) + self.assertTrue(reply.prefix.startswith("*bilbo*!")) + self.assertIn("too much bread", reply.params[1]) - self.sendLine(bar, 'SCENE %s dark and stormy night' % (chan,)) + self.sendLine(bar, "SCENE %s dark and stormy night" % (chan,)) reply = self.getMessages(bar)[0] - self.assertEqual(reply.command, 'PRIVMSG') + self.assertEqual(reply.command, "PRIVMSG") self.assertEqual(reply.params[0], chan) - self.assertTrue(reply.prefix.startswith('=Scene=!')) - self.assertIn('dark and stormy night', reply.params[1]) + self.assertTrue(reply.prefix.startswith("=Scene=!")) + self.assertIn("dark and stormy night", reply.params[1]) reply = self.getMessages(qux)[0] - self.assertEqual(reply.command, 'PRIVMSG') + self.assertEqual(reply.command, "PRIVMSG") self.assertEqual(reply.params[0], chan) - self.assertTrue(reply.prefix.startswith('=Scene=!')) - self.assertIn('dark and stormy night', reply.params[1]) + self.assertTrue(reply.prefix.startswith("=Scene=!")) + self.assertIn("dark and stormy night", reply.params[1]) # test history storage - self.sendLine(qux, 'CHATHISTORY LATEST %s * 10' % (chan,)) - reply = [msg for msg in self.getMessages(qux) if msg.command == 'PRIVMSG' and 'bilbo' in msg.prefix][0] - self.assertEqual(reply.command, 'PRIVMSG') + self.sendLine(qux, "CHATHISTORY LATEST %s * 10" % (chan,)) + reply = [ + msg + for msg in self.getMessages(qux) + if msg.command == "PRIVMSG" and "bilbo" in msg.prefix + ][0] + self.assertEqual(reply.command, "PRIVMSG") self.assertEqual(reply.params[0], chan) - self.assertTrue(reply.prefix.startswith('*bilbo*!')) - self.assertIn('too much bread', reply.params[1]) + self.assertTrue(reply.prefix.startswith("*bilbo*!")) + self.assertIn("too much bread", reply.params[1]) diff --git a/irctest/server_tests/test_sasl.py b/irctest/server_tests/test_sasl.py index a1ac3dc..eb31ff9 100644 --- a/irctest/server_tests/test_sasl.py +++ b/irctest/server_tests/test_sasl.py @@ -2,42 +2,60 @@ import base64 from irctest import cases + class RegistrationTestCase(cases.BaseServerTestCase): def testRegistration(self): - self.controller.registerUser(self, 'testuser', 'mypassword') + self.controller.registerUser(self, "testuser", "mypassword") + class SaslTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') - @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") + @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testPlain(self): """PLAIN authentication with correct username/password.""" - self.controller.registerUser(self, 'foo', 'sesame') - self.controller.registerUser(self, 'jilles', 'sesame') - self.controller.registerUser(self, 'bar', 'sesame') + self.controller.registerUser(self, "foo", "sesame") + self.controller.registerUser(self, "jilles", "sesame") + self.controller.registerUser(self, "bar", "sesame") self.addClient() - self.sendLine(1, 'CAP LS 302') + self.sendLine(1, "CAP LS 302") capabilities = self.getCapLs(1) - self.assertIn('sasl', capabilities, - fail_msg='Does not have SASL as the controller claims.') - if capabilities['sasl'] is not None: - self.assertIn('PLAIN', capabilities['sasl'], - fail_msg='Does not have PLAIN mechanism as the controller ' - 'claims') - self.sendLine(1, 'AUTHENTICATE PLAIN') - m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE') - self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'], - fail_msg='Sent “AUTHENTICATE PLAIN”, server should have ' - 'replied with “AUTHENTICATE +”, but instead sent: {msg}') - self.sendLine(1, 'AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=') - m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE') - self.assertMessageEqual(m, command='900', - fail_msg='Did not send 900 after correct SASL authentication.') - self.assertEqual(m.params[2], 'jilles', m, - fail_msg='900 should contain the account name as 3rd argument ' - '({expects}), not {got}: {msg}') + self.assertIn( + "sasl", + capabilities, + fail_msg="Does not have SASL as the controller claims.", + ) + if capabilities["sasl"] is not None: + self.assertIn( + "PLAIN", + capabilities["sasl"], + fail_msg="Does not have PLAIN mechanism as the controller " "claims", + ) + self.sendLine(1, "AUTHENTICATE PLAIN") + m = self.getMessage(1, filter_pred=lambda m: m.command != "NOTICE") + self.assertMessageEqual( + m, + command="AUTHENTICATE", + params=["+"], + fail_msg="Sent “AUTHENTICATE PLAIN”, server should have " + "replied with “AUTHENTICATE +”, but instead sent: {msg}", + ) + self.sendLine(1, "AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=") + m = self.getMessage(1, filter_pred=lambda m: m.command != "NOTICE") + self.assertMessageEqual( + m, + command="900", + fail_msg="Did not send 900 after correct SASL authentication.", + ) + self.assertEqual( + m.params[2], + "jilles", + m, + fail_msg="900 should contain the account name as 3rd argument " + "({expects}), not {got}: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') - @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") + @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testPlainNoAuthzid(self): """“message = [authzid] UTF8NUL authcid UTF8NUL passwd @@ -60,73 +78,105 @@ class SaslTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): identity produces the same authorization identity.” -- """ - self.controller.registerUser(self, 'foo', 'sesame') - self.controller.registerUser(self, 'jilles', 'sesame') - self.controller.registerUser(self, 'bar', 'sesame') + self.controller.registerUser(self, "foo", "sesame") + self.controller.registerUser(self, "jilles", "sesame") + self.controller.registerUser(self, "bar", "sesame") self.addClient() - self.sendLine(1, 'CAP LS 302') + self.sendLine(1, "CAP LS 302") capabilities = self.getCapLs(1) - self.assertIn('sasl', capabilities, - fail_msg='Does not have SASL as the controller claims.') - if capabilities['sasl'] is not None: - self.assertIn('PLAIN', capabilities['sasl'], - fail_msg='Does not have PLAIN mechanism as the controller ' - 'claims') - self.sendLine(1, 'AUTHENTICATE PLAIN') - m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE') - self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'], - fail_msg='Sent “AUTHENTICATE PLAIN”, server should have ' - 'replied with “AUTHENTICATE +”, but instead sent: {msg}') - self.sendLine(1, 'AUTHENTICATE AGppbGxlcwBzZXNhbWU=') - m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE') - self.assertMessageEqual(m, command='900', - fail_msg='Did not send 900 after correct SASL authentication.') - self.assertEqual(m.params[2], 'jilles', m, - fail_msg='900 should contain the account name as 3rd argument ' - '({expects}), not {got}: {msg}') + self.assertIn( + "sasl", + capabilities, + fail_msg="Does not have SASL as the controller claims.", + ) + if capabilities["sasl"] is not None: + self.assertIn( + "PLAIN", + capabilities["sasl"], + fail_msg="Does not have PLAIN mechanism as the controller " "claims", + ) + self.sendLine(1, "AUTHENTICATE PLAIN") + m = self.getMessage(1, filter_pred=lambda m: m.command != "NOTICE") + self.assertMessageEqual( + m, + command="AUTHENTICATE", + params=["+"], + fail_msg="Sent “AUTHENTICATE PLAIN”, server should have " + "replied with “AUTHENTICATE +”, but instead sent: {msg}", + ) + self.sendLine(1, "AUTHENTICATE AGppbGxlcwBzZXNhbWU=") + m = self.getMessage(1, filter_pred=lambda m: m.command != "NOTICE") + self.assertMessageEqual( + m, + command="900", + fail_msg="Did not send 900 after correct SASL authentication.", + ) + self.assertEqual( + m.params[2], + "jilles", + m, + fail_msg="900 should contain the account name as 3rd argument " + "({expects}), not {got}: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") def testMechanismNotAvailable(self): """“If authentication fails, a 904 or 905 numeric will be sent” -- """ - self.controller.registerUser(self, 'jilles', 'sesame') + self.controller.registerUser(self, "jilles", "sesame") self.addClient() - self.sendLine(1, 'CAP LS 302') + self.sendLine(1, "CAP LS 302") capabilities = self.getCapLs(1) - self.assertIn('sasl', capabilities, - fail_msg='Does not have SASL as the controller claims.') - self.sendLine(1, 'AUTHENTICATE FOO') + self.assertIn( + "sasl", + capabilities, + fail_msg="Does not have SASL as the controller claims.", + ) + self.sendLine(1, "AUTHENTICATE FOO") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='904', - fail_msg='Did not reply with 904 to “AUTHENTICATE FOO”: {msg}') + self.assertMessageEqual( + m, + command="904", + fail_msg="Did not reply with 904 to “AUTHENTICATE FOO”: {msg}", + ) - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') - @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") + @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testPlainLarge(self): """Test the client splits large AUTHENTICATE messages whose payload is not a multiple of 400. """ - self.controller.registerUser(self, 'foo', 'bar'*100) - authstring = base64.b64encode(b'\x00'.join( - [b'foo', b'foo', b'bar'*100])).decode() + self.controller.registerUser(self, "foo", "bar" * 100) + authstring = base64.b64encode( + b"\x00".join([b"foo", b"foo", b"bar" * 100]) + ).decode() self.addClient() - self.sendLine(1, 'CAP LS 302') + self.sendLine(1, "CAP LS 302") capabilities = self.getCapLs(1) - self.assertIn('sasl', capabilities, - fail_msg='Does not have SASL as the controller claims.') - if capabilities['sasl'] is not None: - self.assertIn('PLAIN', capabilities['sasl'], - fail_msg='Does not have PLAIN mechanism as the controller ' - 'claims') - self.sendLine(1, 'AUTHENTICATE PLAIN') + self.assertIn( + "sasl", + capabilities, + fail_msg="Does not have SASL as the controller claims.", + ) + if capabilities["sasl"] is not None: + self.assertIn( + "PLAIN", + capabilities["sasl"], + fail_msg="Does not have PLAIN mechanism as the controller " "claims", + ) + self.sendLine(1, "AUTHENTICATE PLAIN") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'], - fail_msg='Sent “AUTHENTICATE PLAIN”, expected ' - '“AUTHENTICATE +” as a response, but got: {msg}') - self.sendLine(1, 'AUTHENTICATE {}'.format(authstring[0:400])) - self.sendLine(1, 'AUTHENTICATE {}'.format(authstring[400:])) + self.assertMessageEqual( + m, + command="AUTHENTICATE", + params=["+"], + fail_msg="Sent “AUTHENTICATE PLAIN”, expected " + "“AUTHENTICATE +” as a response, but got: {msg}", + ) + self.sendLine(1, "AUTHENTICATE {}".format(authstring[0:400])) + self.sendLine(1, "AUTHENTICATE {}".format(authstring[400:])) self.confirmSuccessfulAuth() @@ -134,45 +184,61 @@ class SaslTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): # TODO: check username/etc in this as well, so we can apply it to other tests # TODO: may be in the other order m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='900', - fail_msg='Expected 900 (RPL_LOGGEDIN) after successful ' - 'login, but got: {msg}') + self.assertMessageEqual( + m, + command="900", + fail_msg="Expected 900 (RPL_LOGGEDIN) after successful " + "login, but got: {msg}", + ) m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='903', - fail_msg='Expected 903 (RPL_SASLSUCCESS) after successful ' - 'login, but got: {msg}') + self.assertMessageEqual( + m, + command="903", + fail_msg="Expected 903 (RPL_SASLSUCCESS) after successful " + "login, but got: {msg}", + ) # TODO: add a test for when the length of the authstring is greater than 800. # I don't know how to do it, because it would make the registration # message's length too big for it to be valid. - @cases.SpecificationSelector.requiredBySpecification('IRCv3.1') - @cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') + @cases.SpecificationSelector.requiredBySpecification("IRCv3.1") + @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testPlainLargeEquals400(self): """Test the client splits large AUTHENTICATE messages whose payload is not a multiple of 400. """ - self.controller.registerUser(self, 'foo', 'bar'*97) - authstring = base64.b64encode(b'\x00'.join( - [b'foo', b'foo', b'bar'*97])).decode() - assert len(authstring) == 400, 'Bad test' + self.controller.registerUser(self, "foo", "bar" * 97) + authstring = base64.b64encode( + b"\x00".join([b"foo", b"foo", b"bar" * 97]) + ).decode() + assert len(authstring) == 400, "Bad test" self.addClient() - self.sendLine(1, 'CAP LS 302') + self.sendLine(1, "CAP LS 302") capabilities = self.getCapLs(1) - self.assertIn('sasl', capabilities, - fail_msg='Does not have SASL as the controller claims.') - if capabilities['sasl'] is not None: - self.assertIn('PLAIN', capabilities['sasl'], - fail_msg='Does not have PLAIN mechanism as the controller ' - 'claims') - self.sendLine(1, 'AUTHENTICATE PLAIN') + self.assertIn( + "sasl", + capabilities, + fail_msg="Does not have SASL as the controller claims.", + ) + if capabilities["sasl"] is not None: + self.assertIn( + "PLAIN", + capabilities["sasl"], + fail_msg="Does not have PLAIN mechanism as the controller " "claims", + ) + self.sendLine(1, "AUTHENTICATE PLAIN") m = self.getRegistrationMessage(1) - self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'], - fail_msg='Sent “AUTHENTICATE PLAIN”, expected ' - '“AUTHENTICATE +” as a response, but got: {msg}') - self.sendLine(1, 'AUTHENTICATE {}'.format(authstring)) - self.sendLine(1, 'AUTHENTICATE +') + self.assertMessageEqual( + m, + command="AUTHENTICATE", + params=["+"], + fail_msg="Sent “AUTHENTICATE PLAIN”, expected " + "“AUTHENTICATE +” as a response, but got: {msg}", + ) + self.sendLine(1, "AUTHENTICATE {}".format(authstring)) + self.sendLine(1, "AUTHENTICATE +") self.confirmSuccessfulAuth() diff --git a/irctest/server_tests/test_statusmsg.py b/irctest/server_tests/test_statusmsg.py index 7414908..849af4e 100644 --- a/irctest/server_tests/test_statusmsg.py +++ b/irctest/server_tests/test_statusmsg.py @@ -1,41 +1,47 @@ from irctest import cases from irctest.numerics import RPL_NAMREPLY -class StatusmsgTestCase(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('Oragono') +class StatusmsgTestCase(cases.BaseServerTestCase): + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testInIsupport(self): """Check that the expected STATUSMSG parameter appears in our isupport list.""" isupport = self.getISupport() - self.assertEqual(isupport['STATUSMSG'], '~&@%+') + self.assertEqual(isupport["STATUSMSG"], "~&@%+") - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testStatusmsg(self): """Test that STATUSMSG are sent to the intended recipients, with the intended prefixes.""" - self.connectClient('chanop') - self.joinChannel(1, '#chan') + self.connectClient("chanop") + self.joinChannel(1, "#chan") self.getMessages(1) - self.connectClient('joe') - self.joinChannel(2, '#chan') + self.connectClient("joe") + self.joinChannel(2, "#chan") self.getMessages(2) - self.connectClient('schmoe') - self.sendLine(3, 'join #chan') + self.connectClient("schmoe") + self.sendLine(3, "join #chan") messages = self.getMessages(3) names = set() for message in messages: if message.command == RPL_NAMREPLY: names.update(set(message.params[-1].split())) # chanop should be opped - self.assertEqual(names, {'@chanop', 'joe', 'schmoe'}, f'unexpected names: {names}') + self.assertEqual( + names, {"@chanop", "joe", "schmoe"}, f"unexpected names: {names}" + ) - self.sendLine(3, 'privmsg @#chan :this message is for operators') + self.sendLine(3, "privmsg @#chan :this message is for operators") self.getMessages(3) # check the operator's messages - statusMsg = self.getMessage(1, filter_pred=lambda m:m.command == 'PRIVMSG') - self.assertMessageEqual(statusMsg, params=['@#chan', 'this message is for operators']) + statusMsg = self.getMessage(1, filter_pred=lambda m: m.command == "PRIVMSG") + self.assertMessageEqual( + statusMsg, params=["@#chan", "this message is for operators"] + ) # check the non-operator's messages - unprivilegedMessages = [msg for msg in self.getMessages(2) if msg.command == 'PRIVMSG'] + unprivilegedMessages = [ + msg for msg in self.getMessages(2) if msg.command == "PRIVMSG" + ] self.assertEqual(len(unprivilegedMessages), 0) diff --git a/irctest/server_tests/test_user_commands.py b/irctest/server_tests/test_user_commands.py index 983ea24..e3fbede 100644 --- a/irctest/server_tests/test_user_commands.py +++ b/irctest/server_tests/test_user_commands.py @@ -4,126 +4,157 @@ User commands as specified in Section 3.6 of RFC 2812: """ from irctest import cases -from irctest.numerics import RPL_WHOISUSER, RPL_WHOISCHANNELS, RPL_AWAY, RPL_NOWAWAY, RPL_UNAWAY +from irctest.numerics import ( + RPL_WHOISUSER, + RPL_WHOISCHANNELS, + RPL_AWAY, + RPL_NOWAWAY, + RPL_UNAWAY, +) + class WhoisTestCase(cases.BaseServerTestCase): - - @cases.SpecificationSelector.requiredBySpecification('RFC2812') + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testWhoisUser(self): """Test basic WHOIS behavior""" - nick = 'myCoolNickname' - username = 'myUsernam' # may be truncated if longer than this - realname = 'My Real Name' + nick = "myCoolNickname" + username = "myUsernam" # may be truncated if longer than this + realname = "My Real Name" self.addClient() - self.sendLine(1, f'NICK {nick}') - self.sendLine(1, f'USER {username} 0 * :{realname}') + self.sendLine(1, f"NICK {nick}") + self.sendLine(1, f"USER {username} 0 * :{realname}") self.skipToWelcome(1) - self.connectClient('otherNickname') + self.connectClient("otherNickname") self.getMessages(2) - self.sendLine(2, 'WHOIS mycoolnickname') + self.sendLine(2, "WHOIS mycoolnickname") messages = self.getMessages(2) whois_user = messages[0] self.assertEqual(whois_user.command, RPL_WHOISUSER) # " * :" self.assertEqual(whois_user.params[1], nick) - self.assertIn(whois_user.params[2], ('~' + username, username)) + self.assertIn(whois_user.params[2], ("~" + username, username)) # dumb regression test for oragono/oragono#355: - self.assertNotIn(whois_user.params[3], [nick, username, '~' + username, realname]) + self.assertNotIn( + whois_user.params[3], [nick, username, "~" + username, realname] + ) self.assertEqual(whois_user.params[5], realname) class InvisibleTestCase(cases.BaseServerTestCase): - - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testInvisibleWhois(self): """Test interaction between MODE +i and RPL_WHOISCHANNELS.""" - self.connectClient('userOne') - self.joinChannel(1, '#xyz') + self.connectClient("userOne") + self.joinChannel(1, "#xyz") - self.connectClient('userTwo') + self.connectClient("userTwo") self.getMessages(2) - self.sendLine(2, 'WHOIS userOne') + self.sendLine(2, "WHOIS userOne") commands = {m.command for m in self.getMessages(2)} - self.assertIn(RPL_WHOISCHANNELS, commands, - 'RPL_WHOISCHANNELS should be sent for a non-invisible nick') + self.assertIn( + RPL_WHOISCHANNELS, + commands, + "RPL_WHOISCHANNELS should be sent for a non-invisible nick", + ) self.getMessages(1) - self.sendLine(1, 'MODE userOne +i') + self.sendLine(1, "MODE userOne +i") message = self.getMessage(1) - self.assertEqual(message.command, 'MODE', - 'Expected MODE reply, but received {}'.format(message.command)) - self.assertEqual(message.params, ['userOne', '+i'], - 'Expected user set +i, but received {}'.format(message.params)) + self.assertEqual( + message.command, + "MODE", + "Expected MODE reply, but received {}".format(message.command), + ) + self.assertEqual( + message.params, + ["userOne", "+i"], + "Expected user set +i, but received {}".format(message.params), + ) self.getMessages(2) - self.sendLine(2, 'WHOIS userOne') + self.sendLine(2, "WHOIS userOne") commands = {m.command for m in self.getMessages(2)} - self.assertNotIn(RPL_WHOISCHANNELS, commands, - 'RPL_WHOISCHANNELS should not be sent for an invisible nick' - 'unless the user is also a member of the channel') + self.assertNotIn( + RPL_WHOISCHANNELS, + commands, + "RPL_WHOISCHANNELS should not be sent for an invisible nick" + "unless the user is also a member of the channel", + ) - self.sendLine(2, 'JOIN #xyz') - self.sendLine(2, 'WHOIS userOne') + self.sendLine(2, "JOIN #xyz") + self.sendLine(2, "WHOIS userOne") commands = {m.command for m in self.getMessages(2)} - self.assertIn(RPL_WHOISCHANNELS, commands, - 'RPL_WHOISCHANNELS should be sent for an invisible nick' - 'if the user is also a member of the channel') + self.assertIn( + RPL_WHOISCHANNELS, + commands, + "RPL_WHOISCHANNELS should be sent for an invisible nick" + "if the user is also a member of the channel", + ) - self.sendLine(2, 'PART #xyz') + self.sendLine(2, "PART #xyz") self.getMessages(2) self.getMessages(1) - self.sendLine(1, 'MODE userOne -i') + self.sendLine(1, "MODE userOne -i") message = self.getMessage(1) - self.assertEqual(message.command, 'MODE', - 'Expected MODE reply, but received {}'.format(message.command)) - self.assertEqual(message.params, ['userOne', '-i'], - 'Expected user set -i, but received {}'.format(message.params)) + self.assertEqual( + message.command, + "MODE", + "Expected MODE reply, but received {}".format(message.command), + ) + self.assertEqual( + message.params, + ["userOne", "-i"], + "Expected user set -i, but received {}".format(message.params), + ) - self.sendLine(2, 'WHOIS userOne') + self.sendLine(2, "WHOIS userOne") commands = {m.command for m in self.getMessages(2)} - self.assertIn(RPL_WHOISCHANNELS, commands, - 'RPL_WHOISCHANNELS should be sent for a non-invisible nick') + self.assertIn( + RPL_WHOISCHANNELS, + commands, + "RPL_WHOISCHANNELS should be sent for a non-invisible nick", + ) - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testWhoisAccount(self): """Test numeric 330, RPL_WHOISACCOUNT.""" - self.controller.registerUser(self, 'shivaram', 'sesame') - self.connectClient('netcat') - self.sendLine(1, 'NS IDENTIFY shivaram sesame') + self.controller.registerUser(self, "shivaram", "sesame") + self.connectClient("netcat") + self.sendLine(1, "NS IDENTIFY shivaram sesame") self.getMessages(1) - self.connectClient('curious') - self.sendLine(2, 'WHOIS netcat') + self.connectClient("curious") + self.sendLine(2, "WHOIS netcat") messages = self.getMessages(2) # 330 RPL_WHOISACCOUNT - whoisaccount = [message for message in messages if message.command == '330'] + whoisaccount = [message for message in messages if message.command == "330"] self.assertEqual(len(whoisaccount), 1) params = whoisaccount[0].params # : self.assertEqual(len(params), 4) - self.assertEqual(params[:3], ['curious', 'netcat', 'shivaram']) + self.assertEqual(params[:3], ["curious", "netcat", "shivaram"]) - self.sendLine(1, 'WHOIS curious') + self.sendLine(1, "WHOIS curious") messages = self.getMessages(2) - whoisaccount = [message for message in messages if message.command == '330'] + whoisaccount = [message for message in messages if message.command == "330"] self.assertEqual(len(whoisaccount), 0) -class AwayTestCase(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('RFC2812') +class AwayTestCase(cases.BaseServerTestCase): + @cases.SpecificationSelector.requiredBySpecification("RFC2812") def testAway(self): - self.connectClient('bar') + self.connectClient("bar") self.sendLine(1, "AWAY :I'm not here right now") replies = self.getMessages(1) self.assertIn(RPL_NOWAWAY, [msg.command for msg in replies]) - self.connectClient('qux') + self.connectClient("qux") self.sendLine(2, "PRIVMSG bar :what's up") replies = self.getMessages(2) self.assertEqual(len(replies), 1) self.assertEqual(replies[0].command, RPL_AWAY) - self.assertEqual(replies[0].params, ['qux', 'bar', "I'm not here right now"]) + self.assertEqual(replies[0].params, ["qux", "bar", "I'm not here right now"]) self.sendLine(1, "AWAY") replies = self.getMessages(1) @@ -133,31 +164,36 @@ class AwayTestCase(cases.BaseServerTestCase): replies = self.getMessages(2) self.assertEqual(len(replies), 0) -class TestNoCTCPMode(cases.BaseServerTestCase): - @cases.SpecificationSelector.requiredBySpecification('Oragono') +class TestNoCTCPMode(cases.BaseServerTestCase): + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testNoCTCPMode(self): - self.connectClient('bar', 'bar') - self.connectClient('qux', 'qux') + self.connectClient("bar", "bar") + self.connectClient("qux", "qux") # CTCP is not blocked by default: - self.sendLine('qux', 'PRIVMSG bar :\x01VERSION\x01') - self.getMessages('qux') - relay = [msg for msg in self.getMessages('bar') if msg.command == 'PRIVMSG'][0] - self.assertEqual(relay.params[-1], '\x01VERSION\x01') + self.sendLine("qux", "PRIVMSG bar :\x01VERSION\x01") + self.getMessages("qux") + relay = [msg for msg in self.getMessages("bar") if msg.command == "PRIVMSG"][0] + self.assertEqual(relay.params[-1], "\x01VERSION\x01") # set the no-CTCP user mode on bar: - self.sendLine('bar', 'MODE bar +T') - replies = self.getMessages('bar') - umode_line = [msg for msg in replies if msg.command == 'MODE'][0] - self.assertMessageEqual(umode_line, command='MODE', params=['bar', '+T']) + self.sendLine("bar", "MODE bar +T") + replies = self.getMessages("bar") + umode_line = [msg for msg in replies if msg.command == "MODE"][0] + self.assertMessageEqual(umode_line, command="MODE", params=["bar", "+T"]) # CTCP is now blocked: - self.sendLine('qux', 'PRIVMSG bar :\x01VERSION\x01') - self.getMessages('qux') - self.assertEqual(self.getMessages('bar'), []) + self.sendLine("qux", "PRIVMSG bar :\x01VERSION\x01") + self.getMessages("qux") + self.assertEqual(self.getMessages("bar"), []) # normal PRIVMSG go through: - self.sendLine('qux', 'PRIVMSG bar :please just tell me your client version') - self.getMessages('qux') - relay = self.getMessages('bar')[0] - self.assertMessageEqual(relay, command='PRIVMSG', nick='qux', params=['bar', 'please just tell me your client version']) + self.sendLine("qux", "PRIVMSG bar :please just tell me your client version") + self.getMessages("qux") + relay = self.getMessages("bar")[0] + self.assertMessageEqual( + relay, + command="PRIVMSG", + nick="qux", + params=["bar", "please just tell me your client version"], + ) diff --git a/irctest/server_tests/test_utf8.py b/irctest/server_tests/test_utf8.py index 0c051d8..0fc3af8 100644 --- a/irctest/server_tests/test_utf8.py +++ b/irctest/server_tests/test_utf8.py @@ -1,23 +1,29 @@ from irctest import cases + class Utf8TestCase(cases.BaseServerTestCase, cases.OptionalityHelper): - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testUtf8Validation(self): - self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response', 'message-tags']) - self.joinChannel(1, '#qux') - self.sendLine(1, 'PRIVMSG #qux hi') + self.connectClient( + "bar", + capabilities=["batch", "echo-message", "labeled-response", "message-tags"], + ) + self.joinChannel(1, "#qux") + self.sendLine(1, "PRIVMSG #qux hi") ms = self.getMessages(1) - self.assertMessageEqual([m for m in ms if m.command == 'PRIVMSG'][0], params=['#qux', 'hi']) + self.assertMessageEqual( + [m for m in ms if m.command == "PRIVMSG"][0], params=["#qux", "hi"] + ) - self.sendLine(1, b'PRIVMSG #qux hi\xaa') + self.sendLine(1, b"PRIVMSG #qux hi\xaa") ms = self.getMessages(1) self.assertEqual(len(ms), 1) - self.assertEqual(ms[0].command, 'FAIL') - self.assertEqual(ms[0].params[:2], ['PRIVMSG', 'INVALID_UTF8']) + self.assertEqual(ms[0].command, "FAIL") + self.assertEqual(ms[0].params[:2], ["PRIVMSG", "INVALID_UTF8"]) - self.sendLine(1, b'@label=xyz PRIVMSG #qux hi\xaa') + self.sendLine(1, b"@label=xyz PRIVMSG #qux hi\xaa") ms = self.getMessages(1) self.assertEqual(len(ms), 1) - self.assertEqual(ms[0].command, 'FAIL') - self.assertEqual(ms[0].params[:2], ['PRIVMSG', 'INVALID_UTF8']) - self.assertEqual(ms[0].tags.get('label'), 'xyz') + self.assertEqual(ms[0].command, "FAIL") + self.assertEqual(ms[0].params[:2], ["PRIVMSG", "INVALID_UTF8"]) + self.assertEqual(ms[0].tags.get("label"), "xyz") diff --git a/irctest/server_tests/test_znc_playback.py b/irctest/server_tests/test_znc_playback.py index 3c30c32..a8bb277 100644 --- a/irctest/server_tests/test_znc_playback.py +++ b/irctest/server_tests/test_znc_playback.py @@ -10,7 +10,7 @@ def extract_playback_privmsgs(messages): # convert the output of a playback command, drop the echo message result = [] for msg in messages: - if msg.command == 'PRIVMSG' and msg.params[0].lower() != '*playback': + if msg.command == "PRIVMSG" and msg.params[0].lower() != "*playback": result.append(to_history_message(msg)) return result @@ -22,91 +22,197 @@ class ZncPlaybackTestCase(cases.BaseServerTestCase): "chathistory": True, } - @cases.SpecificationSelector.requiredBySpecification('Oragono') + @cases.SpecificationSelector.requiredBySpecification("Oragono") def testZncPlayback(self): early_time = int(time.time() - 60) - chname = random_name('#znc_channel') - bar, pw = random_name('bar'), random_name('pass') + chname = random_name("#znc_channel") + bar, pw = random_name("bar"), random_name("pass") self.controller.registerUser(self, bar, pw) - self.connectClient(bar, name=bar, capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message'], password=pw) + self.connectClient( + bar, + name=bar, + capabilities=[ + "batch", + "labeled-response", + "message-tags", + "server-time", + "echo-message", + ], + password=pw, + ) self.joinChannel(bar, chname) - qux = random_name('qux') - self.connectClient(qux, name=qux, capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message']) + qux = random_name("qux") + self.connectClient( + qux, + name=qux, + capabilities=[ + "batch", + "labeled-response", + "message-tags", + "server-time", + "echo-message", + ], + ) self.joinChannel(qux, chname) - self.sendLine(qux, 'PRIVMSG %s :hi there' % (bar,)) - dm = to_history_message([msg for msg in self.getMessages(qux) if msg.command == 'PRIVMSG'][0]) - self.assertEqual(dm.text, 'hi there') + self.sendLine(qux, "PRIVMSG %s :hi there" % (bar,)) + dm = to_history_message( + [msg for msg in self.getMessages(qux) if msg.command == "PRIVMSG"][0] + ) + self.assertEqual(dm.text, "hi there") NUM_MESSAGES = 10 echo_messages = [] for i in range(NUM_MESSAGES): - self.sendLine(qux, 'PRIVMSG %s :this is message %d' % (chname, i)) - echo_messages.extend(to_history_message(msg) for msg in self.getMessages(qux) if msg.command == 'PRIVMSG') + self.sendLine(qux, "PRIVMSG %s :this is message %d" % (chname, i)) + echo_messages.extend( + to_history_message(msg) + for msg in self.getMessages(qux) + if msg.command == "PRIVMSG" + ) time.sleep(0.003) self.assertEqual(len(echo_messages), NUM_MESSAGES) self.getMessages(bar) # reattach to 'bar' - self.connectClient(bar, name='viewer', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message'], password=pw) - self.sendLine('viewer', 'PRIVMSG *playback :play * %d' % (early_time,)) - messages = extract_playback_privmsgs(self.getMessages('viewer')) + self.connectClient( + bar, + name="viewer", + capabilities=[ + "batch", + "labeled-response", + "message-tags", + "server-time", + "echo-message", + ], + password=pw, + ) + self.sendLine("viewer", "PRIVMSG *playback :play * %d" % (early_time,)) + messages = extract_playback_privmsgs(self.getMessages("viewer")) self.assertEqual(set(messages), set([dm] + echo_messages)) - self.sendLine('viewer', 'QUIT') - self.assertDisconnected('viewer') + self.sendLine("viewer", "QUIT") + self.assertDisconnected("viewer") # reattach to 'bar', play back selectively - self.connectClient(bar, name='viewer', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message'], password=pw) + self.connectClient( + bar, + name="viewer", + capabilities=[ + "batch", + "labeled-response", + "message-tags", + "server-time", + "echo-message", + ], + password=pw, + ) mid_timestamp = ircv3_timestamp_to_unixtime(echo_messages[5].time) # exclude message 5 itself (oragono's CHATHISTORY implementation corrects for this, but znc.in/playback does not because whatever) - mid_timestamp += .001 - self.sendLine('viewer', 'PRIVMSG *playback :play * %s' % (mid_timestamp,)) - messages = extract_playback_privmsgs(self.getMessages('viewer')) + mid_timestamp += 0.001 + self.sendLine("viewer", "PRIVMSG *playback :play * %s" % (mid_timestamp,)) + messages = extract_playback_privmsgs(self.getMessages("viewer")) self.assertEqual(messages, echo_messages[6:]) - self.sendLine('viewer', 'QUIT') - self.assertDisconnected('viewer') + self.sendLine("viewer", "QUIT") + self.assertDisconnected("viewer") # reattach to 'bar', play back selectively (pass a parameter and 2 timestamps) - self.connectClient(bar, name='viewer', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message'], password=pw) + self.connectClient( + bar, + name="viewer", + capabilities=[ + "batch", + "labeled-response", + "message-tags", + "server-time", + "echo-message", + ], + password=pw, + ) start_timestamp = ircv3_timestamp_to_unixtime(echo_messages[2].time) - start_timestamp += .001 + start_timestamp += 0.001 end_timestamp = ircv3_timestamp_to_unixtime(echo_messages[7].time) - self.sendLine('viewer', 'PRIVMSG *playback :play %s %s %s' % (chname, start_timestamp, end_timestamp,)) - messages = extract_playback_privmsgs(self.getMessages('viewer')) + self.sendLine( + "viewer", + "PRIVMSG *playback :play %s %s %s" + % ( + chname, + start_timestamp, + end_timestamp, + ), + ) + messages = extract_playback_privmsgs(self.getMessages("viewer")) self.assertEqual(messages, echo_messages[3:7]) # test nicknames as targets - self.sendLine('viewer', 'PRIVMSG *playback :play %s %d' % (qux, early_time,)) - messages = extract_playback_privmsgs(self.getMessages('viewer')) + self.sendLine( + "viewer", + "PRIVMSG *playback :play %s %d" + % ( + qux, + early_time, + ), + ) + messages = extract_playback_privmsgs(self.getMessages("viewer")) self.assertEqual(messages, [dm]) - self.sendLine('viewer', 'PRIVMSG *playback :play %s %d' % (qux.upper(), early_time,)) - messages = extract_playback_privmsgs(self.getMessages('viewer')) + self.sendLine( + "viewer", + "PRIVMSG *playback :play %s %d" + % ( + qux.upper(), + early_time, + ), + ) + messages = extract_playback_privmsgs(self.getMessages("viewer")) self.assertEqual(messages, [dm]) - self.sendLine('viewer', 'QUIT') - self.assertDisconnected('viewer') + self.sendLine("viewer", "QUIT") + self.assertDisconnected("viewer") # test 2-argument form - self.connectClient(bar, name='viewer', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message'], password=pw) - self.sendLine('viewer', 'PRIVMSG *playback :play %s' % (chname,)) - messages = extract_playback_privmsgs(self.getMessages('viewer')) + self.connectClient( + bar, + name="viewer", + capabilities=[ + "batch", + "labeled-response", + "message-tags", + "server-time", + "echo-message", + ], + password=pw, + ) + self.sendLine("viewer", "PRIVMSG *playback :play %s" % (chname,)) + messages = extract_playback_privmsgs(self.getMessages("viewer")) self.assertEqual(messages, echo_messages) - self.sendLine('viewer', 'PRIVMSG *playback :play *self') - messages = extract_playback_privmsgs(self.getMessages('viewer')) + self.sendLine("viewer", "PRIVMSG *playback :play *self") + messages = extract_playback_privmsgs(self.getMessages("viewer")) self.assertEqual(messages, [dm]) - self.sendLine('viewer', 'PRIVMSG *playback :play *') - messages = extract_playback_privmsgs(self.getMessages('viewer')) + self.sendLine("viewer", "PRIVMSG *playback :play *") + messages = extract_playback_privmsgs(self.getMessages("viewer")) self.assertEqual(set(messages), set([dm] + echo_messages)) - self.sendLine('viewer', 'QUIT') - self.assertDisconnected('viewer') + self.sendLine("viewer", "QUIT") + self.assertDisconnected("viewer") # test limiting behavior config = self.controller.getConfig() - config['history']['znc-maxmessages'] = 5 + config["history"]["znc-maxmessages"] = 5 self.controller.rehash(self, config) - self.connectClient(bar, name='viewer', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message'], password=pw) - self.sendLine('viewer', 'PRIVMSG *playback :play %s %d' % (chname, int(time.time() - 60))) - messages = extract_playback_privmsgs(self.getMessages('viewer')) + self.connectClient( + bar, + name="viewer", + capabilities=[ + "batch", + "labeled-response", + "message-tags", + "server-time", + "echo-message", + ], + password=pw, + ) + self.sendLine( + "viewer", "PRIVMSG *playback :play %s %d" % (chname, int(time.time() - 60)) + ) + messages = extract_playback_privmsgs(self.getMessages("viewer")) # should receive the latest 5 messages self.assertEqual(messages, echo_messages[5:]) diff --git a/irctest/specifications.py b/irctest/specifications.py index 5609f09..2f3c72a 100644 --- a/irctest/specifications.py +++ b/irctest/specifications.py @@ -1,16 +1,17 @@ import enum + @enum.unique class Specifications(enum.Enum): - RFC1459 = 'RFC1459' - RFC2812 = 'RFC2812' - RFCDeprecated = 'RFC-deprecated' - IRC301 = 'IRCv3.1' - IRC302 = 'IRCv3.2' - IRC302Deprecated = 'IRCv3.2-deprecated' - Oragono = 'Oragono' - Multiline = 'multiline' - MessageTags = 'message-tags' + RFC1459 = "RFC1459" + RFC2812 = "RFC2812" + RFCDeprecated = "RFC-deprecated" + IRC301 = "IRCv3.1" + IRC302 = "IRCv3.2" + IRC302Deprecated = "IRCv3.2-deprecated" + Oragono = "Oragono" + Multiline = "multiline" + MessageTags = "message-tags" @classmethod def of_name(cls, name): diff --git a/irctest/tls.py b/irctest/tls.py index acb9b74..89ff454 100644 --- a/irctest/tls.py +++ b/irctest/tls.py @@ -1,4 +1,3 @@ import collections -TlsConfig = collections.namedtuple('TlsConfig', - 'enable trusted_fingerprints') +TlsConfig = collections.namedtuple("TlsConfig", "enable trusted_fingerprints") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b5413f6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +target-version = ['py37'] diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..8d79b7e --- /dev/null +++ b/setup.cfg @@ -0,0 +1,6 @@ +[flake8] +# E203: whitespaces before ':' +# E231: missing whitespace after ',' +# W503: line break before binary operator +ignore = E203,E231,W503 +max-line-length = 88