Use Black code style

This commit is contained in:
Valentin Lorentz 2021-02-22 19:02:13 +01:00 committed by Valentin Lorentz
parent 34ed62fd85
commit 8016e01daf
59 changed files with 4855 additions and 3033 deletions

6
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 20.8b1
hooks:
- id: black
language_version: python3.7

16
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,16 @@
# Contributing
## Code style
Any color you like as long as it's [Black](https://github.com/psf/black).
In short:
* 88 columns
* double quotes
* avoid backslashes at line breaks (use parentheses)
* closing brackets/parentheses/... go on the same indent level as the line
that opened them
You can use [pre-commit](https://pre-commit.com/) to automatically run it
for you when you create a git commit.
Alternatively, run `pre-commit run -a`

View File

@ -8,11 +8,15 @@ import _pytest.unittest
from irctest.cases import _IrcTestCase, BaseClientTestCase, BaseServerTestCase from irctest.cases import _IrcTestCase, BaseClientTestCase, BaseServerTestCase
from irctest.basecontrollers import BaseClientController, BaseServerController from irctest.basecontrollers import BaseClientController, BaseServerController
def pytest_addoption(parser): def pytest_addoption(parser):
"""Called by pytest, registers CLI options passed to the pytest command.""" """Called by pytest, registers CLI options passed to the pytest command."""
parser.addoption("--controller", help="Which module to use to run the tested software.") parser.addoption(
parser.addoption('--openssl-bin', type=str, default='openssl', "--controller", help="Which module to use to run the tested software."
help='The openssl binary to use') )
parser.addoption(
"--openssl-bin", type=str, default="openssl", help="The openssl binary to use"
)
def pytest_configure(config): def pytest_configure(config):
@ -25,7 +29,7 @@ def pytest_configure(config):
try: try:
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
except ImportError: except ImportError:
pytest.exit('Cannot import module {}'.format(module_name), 1) pytest.exit("Cannot import module {}".format(module_name), 1)
controller_class = module.get_irctest_controller_class() controller_class = module.get_irctest_controller_class()
if issubclass(controller_class, BaseClientController): if issubclass(controller_class, BaseClientController):
@ -34,10 +38,11 @@ def pytest_configure(config):
from irctest import server_tests as module from irctest import server_tests as module
else: else:
pytest.exit( pytest.exit(
r'{}.Controller should be a subclass of ' r"{}.Controller should be a subclass of "
r'irctest.basecontroller.Base{{Client,Server}}Controller' r"irctest.basecontroller.Base{{Client,Server}}Controller".format(
.format(module_name), module_name
1 ),
1,
) )
_IrcTestCase.controllerClass = controller_class _IrcTestCase.controllerClass = controller_class
_IrcTestCase.controllerClass.openssl_bin = config.getoption("openssl_bin") _IrcTestCase.controllerClass.openssl_bin = config.getoption("openssl_bin")

View File

@ -1,19 +1,25 @@
import enum import enum
import collections import collections
@enum.unique @enum.unique
class Mechanisms(enum.Enum): class Mechanisms(enum.Enum):
"""Enumeration for representing possible mechanisms.""" """Enumeration for representing possible mechanisms."""
@classmethod @classmethod
def as_string(cls, mech): def as_string(cls, mech):
return {cls.plain: 'PLAIN', return {
cls.ecdsa_nist256p_challenge: 'ECDSA-NIST256P-CHALLENGE', cls.plain: "PLAIN",
cls.scram_sha_256: 'SCRAM-SHA-256', cls.ecdsa_nist256p_challenge: "ECDSA-NIST256P-CHALLENGE",
}[mech] cls.scram_sha_256: "SCRAM-SHA-256",
}[mech]
plain = 1 plain = 1
ecdsa_nist256p_challenge = 2 ecdsa_nist256p_challenge = 2
scram_sha_256 = 3 scram_sha_256 = 3
Authentication = collections.namedtuple('Authentication',
'mechanisms username password ecdsa_key') Authentication = collections.namedtuple(
"Authentication", "mechanisms username password ecdsa_key"
)
Authentication.__new__.__defaults__ = ([Mechanisms.plain], None, None, None) Authentication.__new__.__defaults__ = ([Mechanisms.plain], None, None, None)

View File

@ -7,18 +7,22 @@ import subprocess
from .runner import NotImplementedByController from .runner import NotImplementedByController
class _BaseController: class _BaseController:
"""Base class for software controllers. """Base class for software controllers.
A software controller is an object that handles configuring and running A software controller is an object that handles configuring and running
a process (eg. a server or a client), as well as sending it instructions a process (eg. a server or a client), as well as sending it instructions
that are not part of the IRC specification.""" that are not part of the IRC specification."""
def __init__(self, test_config): def __init__(self, test_config):
self.test_config = test_config self.test_config = test_config
class DirectoryBasedController(_BaseController): class DirectoryBasedController(_BaseController):
"""Helper for controllers whose software configuration is based on an """Helper for controllers whose software configuration is based on an
arbitrary directory.""" arbitrary directory."""
def __init__(self, test_config): def __init__(self, test_config):
super().__init__(test_config) super().__init__(test_config)
self.directory = None self.directory = None
@ -33,18 +37,21 @@ class DirectoryBasedController(_BaseController):
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
self.proc.kill() self.proc.kill()
self.proc = None self.proc = None
def kill(self): def kill(self):
"""Calls `kill_proc` and cleans the configuration.""" """Calls `kill_proc` and cleans the configuration."""
if self.proc: if self.proc:
self.kill_proc() self.kill_proc()
if self.directory: if self.directory:
shutil.rmtree(self.directory) shutil.rmtree(self.directory)
def terminate(self): def terminate(self):
"""Stops the process gracefully, and does not clean its config.""" """Stops the process gracefully, and does not clean its config."""
self.proc.terminate() self.proc.terminate()
self.proc.wait() self.proc.wait()
self.proc = None self.proc = None
def open_file(self, name, mode='a'):
def open_file(self, name, mode="a"):
"""Open a file in the configuration directory.""" """Open a file in the configuration directory."""
assert self.directory assert self.directory
if os.sep in name: if os.sep in name:
@ -53,6 +60,7 @@ class DirectoryBasedController(_BaseController):
os.makedirs(dir_) os.makedirs(dir_)
assert os.path.isdir(dir_) assert os.path.isdir(dir_)
return open(os.path.join(self.directory, name), mode) return open(os.path.join(self.directory, name), mode)
def create_config(self): def create_config(self):
"""If there is no config dir, creates it and returns True. """If there is no config dir, creates it and returns True.
Else returns False.""" Else returns False."""
@ -63,41 +71,70 @@ class DirectoryBasedController(_BaseController):
return True return True
def gen_ssl(self): def gen_ssl(self):
self.csr_path = os.path.join(self.directory, 'ssl.csr') self.csr_path = os.path.join(self.directory, "ssl.csr")
self.key_path = os.path.join(self.directory, 'ssl.key') self.key_path = os.path.join(self.directory, "ssl.key")
self.pem_path = os.path.join(self.directory, 'ssl.pem') self.pem_path = os.path.join(self.directory, "ssl.pem")
self.dh_path = os.path.join(self.directory, 'dh.pem') self.dh_path = os.path.join(self.directory, "dh.pem")
subprocess.check_output([self.openssl_bin, 'req', '-new', '-newkey', 'rsa', subprocess.check_output(
'-nodes', '-out', self.csr_path, '-keyout', self.key_path, [
'-batch'], self.openssl_bin,
stderr=subprocess.DEVNULL) "req",
subprocess.check_output([self.openssl_bin, 'x509', '-req', "-new",
'-in', self.csr_path, '-signkey', self.key_path, "-newkey",
'-out', self.pem_path], "rsa",
stderr=subprocess.DEVNULL) "-nodes",
subprocess.check_output([self.openssl_bin, 'dhparam', "-out",
'-out', self.dh_path, '128'], self.csr_path,
stderr=subprocess.DEVNULL) "-keyout",
self.key_path,
"-batch",
],
stderr=subprocess.DEVNULL,
)
subprocess.check_output(
[
self.openssl_bin,
"x509",
"-req",
"-in",
self.csr_path,
"-signkey",
self.key_path,
"-out",
self.pem_path,
],
stderr=subprocess.DEVNULL,
)
subprocess.check_output(
[self.openssl_bin, "dhparam", "-out", self.dh_path, "128"],
stderr=subprocess.DEVNULL,
)
class BaseClientController(_BaseController): class BaseClientController(_BaseController):
"""Base controller for IRC clients.""" """Base controller for IRC clients."""
def run(self, hostname, port, auth): def run(self, hostname, port, auth):
raise NotImplementedError() raise NotImplementedError()
class BaseServerController(_BaseController): class BaseServerController(_BaseController):
"""Base controller for IRC server.""" """Base controller for IRC server."""
_port_wait_interval = .1
_port_wait_interval = 0.1
port_open = False port_open = False
def run(self, hostname, port, password,
valid_metadata_keys, invalid_metadata_keys): def run(self, hostname, port, password, valid_metadata_keys, invalid_metadata_keys):
raise NotImplementedError() raise NotImplementedError()
def registerUser(self, case, username, password=None): def registerUser(self, case, username, password=None):
raise NotImplementedByController('account registration') raise NotImplementedByController("account registration")
def wait_for_port(self): def wait_for_port(self):
while not self.port_open: while not self.port_open:
time.sleep(self._port_wait_interval) time.sleep(self._port_wait_interval)
try: try:
c = socket.create_connection(('localhost', self.port), timeout=1.0) c = socket.create_connection(("localhost", self.port), timeout=1.0)
c.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) c.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
# Make sure the server properly processes the disconnect. # Make sure the server properly processes the disconnect.

View File

@ -15,19 +15,38 @@ from .irc_utils.junkdrawer import normalizeWhitespace, random_name
from .irc_utils.sasl import sasl_plain_blob from .irc_utils.sasl import sasl_plain_blob
from .exceptions import ConnectionClosed from .exceptions import ConnectionClosed
from .specifications import Specifications from .specifications import Specifications
from .numerics import ERR_NOSUCHCHANNEL, ERR_TOOMANYCHANNELS, ERR_BADCHANNELKEY, ERR_INVITEONLYCHAN, ERR_BANNEDFROMCHAN, ERR_NEEDREGGEDNICK from .numerics import (
ERR_NOSUCHCHANNEL,
ERR_TOOMANYCHANNELS,
ERR_BADCHANNELKEY,
ERR_INVITEONLYCHAN,
ERR_BANNEDFROMCHAN,
ERR_NEEDREGGEDNICK,
)
CHANNEL_JOIN_FAIL_NUMERICS = frozenset(
[
ERR_NOSUCHCHANNEL,
ERR_TOOMANYCHANNELS,
ERR_BADCHANNELKEY,
ERR_INVITEONLYCHAN,
ERR_BANNEDFROMCHAN,
ERR_NEEDREGGEDNICK,
]
)
CHANNEL_JOIN_FAIL_NUMERICS = frozenset([ERR_NOSUCHCHANNEL, ERR_TOOMANYCHANNELS, ERR_BADCHANNELKEY, ERR_INVITEONLYCHAN, ERR_BANNEDFROMCHAN, ERR_NEEDREGGEDNICK])
class ChannelJoinException(Exception): class ChannelJoinException(Exception):
def __init__(self, code, params): def __init__(self, code, params):
super().__init__(f'Failed to join channel ({code}): {params}') super().__init__(f"Failed to join channel ({code}): {params}")
self.code = code self.code = code
self.params = params self.params = params
class _IrcTestCase(unittest.TestCase): class _IrcTestCase(unittest.TestCase):
"""Base class for test cases.""" """Base class for test cases."""
controllerClass = None # Will be set by __main__.py
controllerClass = None # Will be set by __main__.py
@staticmethod @staticmethod
def config(): def config():
@ -40,21 +59,36 @@ class _IrcTestCase(unittest.TestCase):
def description(self): def description(self):
method_doc = self._testMethodDoc method_doc = self._testMethodDoc
if not method_doc: if not method_doc:
return '' return ""
return '\t'+normalizeWhitespace( return (
"\t"
+ normalizeWhitespace(
method_doc, method_doc,
removeNewline=False, removeNewline=False,
).strip().replace('\n ', '\n\t') )
.strip()
.replace("\n ", "\n\t")
)
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.controller = self.controllerClass(self.config()) self.controller = self.controllerClass(self.config())
self.inbuffer = [] self.inbuffer = []
if self.show_io: if self.show_io:
print('---- new test ----') print("---- new test ----")
def assertMessageEqual(self, msg, subcommand=None, subparams=None,
target=None, nick=None, fail_msg=None, extra_format=(), def assertMessageEqual(
strip_first_param=False, **kwargs): self,
msg,
subcommand=None,
subparams=None,
target=None,
nick=None,
fail_msg=None,
extra_format=(),
strip_first_param=False,
**kwargs,
):
"""Helper for partially comparing a message. """Helper for partially comparing a message.
Takes the message as first arguments, and comparisons to be made Takes the message as first arguments, and comparisons to be made
@ -62,65 +96,71 @@ class _IrcTestCase(unittest.TestCase):
Deals with subcommands (eg. `CAP`) if any of `subcommand`, Deals with subcommands (eg. `CAP`) if any of `subcommand`,
`subparams`, and `target` are given.""" `subparams`, and `target` are given."""
fail_msg = fail_msg or '{msg}' fail_msg = fail_msg or "{msg}"
for (key, value) in kwargs.items(): for (key, value) in kwargs.items():
if strip_first_param and key == 'params': if strip_first_param and key == "params":
value = value[1:] value = value[1:]
self.assertEqual(getattr(msg, key), value, msg, fail_msg, self.assertEqual(
extra_format=extra_format) getattr(msg, key), value, msg, fail_msg, extra_format=extra_format
)
if nick: if nick:
self.assertNotEqual(msg.prefix, None, msg, fail_msg) self.assertNotEqual(msg.prefix, None, msg, fail_msg)
self.assertEqual(msg.prefix.split('!')[0], nick, msg, fail_msg) self.assertEqual(msg.prefix.split("!")[0], nick, msg, fail_msg)
if subcommand is not None or subparams is not None: if subcommand is not None or subparams is not None:
self.assertGreater(len(msg.params), 2, fail_msg) self.assertGreater(len(msg.params), 2, fail_msg)
#msg_target = msg.params[0] # msg_target = msg.params[0]
msg_subcommand = msg.params[1] msg_subcommand = msg.params[1]
msg_subparams = msg.params[2:] msg_subparams = msg.params[2:]
if subcommand: if subcommand:
self.assertEqual(msg_subcommand, subcommand, msg, fail_msg, self.assertEqual(
extra_format=extra_format) msg_subcommand, subcommand, msg, fail_msg, extra_format=extra_format
)
if subparams is not None: if subparams is not None:
self.assertEqual(msg_subparams, subparams, msg, fail_msg, self.assertEqual(
extra_format=extra_format) msg_subparams, subparams, msg, fail_msg, extra_format=extra_format
)
def assertIn(self, item, list_, msg=None, fail_msg=None, extra_format=()): def assertIn(self, item, list_, msg=None, fail_msg=None, extra_format=()):
if fail_msg: if fail_msg:
fail_msg = fail_msg.format(*extra_format, fail_msg = fail_msg.format(*extra_format, item=item, list=list_, msg=msg)
item=item, list=list_, msg=msg)
super().assertIn(item, list_, fail_msg) super().assertIn(item, list_, fail_msg)
def assertNotIn(self, item, list_, msg=None, fail_msg=None, extra_format=()): def assertNotIn(self, item, list_, msg=None, fail_msg=None, extra_format=()):
if fail_msg: if fail_msg:
fail_msg = fail_msg.format(*extra_format, fail_msg = fail_msg.format(*extra_format, item=item, list=list_, msg=msg)
item=item, list=list_, msg=msg)
super().assertNotIn(item, list_, fail_msg) super().assertNotIn(item, list_, fail_msg)
def assertEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()): def assertEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()):
if fail_msg: if fail_msg:
fail_msg = fail_msg.format(*extra_format, fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
got=got, expects=expects, msg=msg)
super().assertEqual(got, expects, fail_msg) super().assertEqual(got, expects, fail_msg)
def assertNotEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()): def assertNotEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()):
if fail_msg: if fail_msg:
fail_msg = fail_msg.format(*extra_format, fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
got=got, expects=expects, msg=msg)
super().assertNotEqual(got, expects, fail_msg) super().assertNotEqual(got, expects, fail_msg)
class BaseClientTestCase(_IrcTestCase): class BaseClientTestCase(_IrcTestCase):
"""Basic class for client tests. Handles spawning a client and exchanging """Basic class for client tests. Handles spawning a client and exchanging
messages with it.""" messages with it."""
nick = None nick = None
user = None user = None
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.conn = None self.conn = None
self._setUpServer() self._setUpServer()
def tearDown(self): def tearDown(self):
if self.conn: if self.conn:
try: try:
self.conn.sendall(b'QUIT :end of test.') self.conn.sendall(b"QUIT :end of test.")
except BrokenPipeError: except BrokenPipeError:
pass # client already disconnected pass # client already disconnected
except OSError: except OSError:
pass # the conn was already closed by the test, or something pass # the conn was already closed by the test, or something
self.controller.kill() self.controller.kill()
if self.conn: if self.conn:
self.conn_file.close() self.conn_file.close()
@ -130,8 +170,9 @@ class BaseClientTestCase(_IrcTestCase):
def _setUpServer(self): def _setUpServer(self):
"""Creates the server and make it listen.""" """Creates the server and make it listen."""
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server.bind(('', 0)) # Bind any free port self.server.bind(("", 0)) # Bind any free port
self.server.listen(1) self.server.listen(1)
def acceptClient(self, tls_cert=None, tls_key=None, server=None): def acceptClient(self, tls_cert=None, tls_key=None, server=None):
"""Make the server accept a client connection. Blocking.""" """Make the server accept a client connection. Blocking."""
server = server or self.server server = server or self.server
@ -139,10 +180,12 @@ class BaseClientTestCase(_IrcTestCase):
if tls_cert is None and tls_key is None: if tls_cert is None and tls_key is None:
pass pass
else: else:
assert tls_cert and tls_key, \ assert (
'tls_cert must be provided if and only if tls_key is.' tls_cert and tls_key
with tempfile.NamedTemporaryFile('at') as certfile, \ ), "tls_cert must be provided if and only if tls_key is."
tempfile.NamedTemporaryFile('at') as keyfile: with tempfile.NamedTemporaryFile(
"at"
) as certfile, tempfile.NamedTemporaryFile("at") as keyfile:
certfile.write(tls_cert) certfile.write(tls_cert)
certfile.seek(0) certfile.seek(0)
keyfile.write(tls_key) keyfile.write(tls_key)
@ -150,17 +193,18 @@ class BaseClientTestCase(_IrcTestCase):
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(certfile=certfile.name, keyfile=keyfile.name) context.load_cert_chain(certfile=certfile.name, keyfile=keyfile.name)
self.conn = context.wrap_socket(self.conn, server_side=True) self.conn = context.wrap_socket(self.conn, server_side=True)
self.conn_file = self.conn.makefile(newline='\r\n', self.conn_file = self.conn.makefile(newline="\r\n", encoding="utf8")
encoding='utf8')
def getLine(self): def getLine(self):
line = self.conn_file.readline() line = self.conn_file.readline()
if self.show_io: if self.show_io:
print('{:.3f} C: {}'.format(time.time(), line.strip())) print("{:.3f} C: {}".format(time.time(), line.strip()))
return line return line
def getMessages(self, *args): def getMessages(self, *args):
lines = self.getLines(*args) lines = self.getLines(*args)
return map(message_parser.parse_message, lines) return map(message_parser.parse_message, lines)
def getMessage(self, *args, filter_pred=None): def getMessage(self, *args, filter_pred=None):
"""Gets a message and returns it. If a filter predicate is given, """Gets a message and returns it. If a filter predicate is given,
fetches messages until the predicate returns a False on a message, fetches messages until the predicate returns a False on a message,
@ -172,45 +216,46 @@ class BaseClientTestCase(_IrcTestCase):
msg = message_parser.parse_message(line) msg = message_parser.parse_message(line)
if not filter_pred or filter_pred(msg): if not filter_pred or filter_pred(msg):
return msg return msg
def sendLine(self, line): def sendLine(self, line):
self.conn.sendall(line.encode()) self.conn.sendall(line.encode())
if not line.endswith('\r\n'): if not line.endswith("\r\n"):
self.conn.sendall(b'\r\n') self.conn.sendall(b"\r\n")
if self.show_io: if self.show_io:
print('{:.3f} S: {}'.format(time.time(), line.strip())) print("{:.3f} S: {}".format(time.time(), line.strip()))
class ClientNegociationHelper: class ClientNegociationHelper:
"""Helper class for tests handling capabilities negociation.""" """Helper class for tests handling capabilities negociation."""
def readCapLs(self, auth=None, tls_config=None): def readCapLs(self, auth=None, tls_config=None):
(hostname, port) = self.server.getsockname() (hostname, port) = self.server.getsockname()
self.controller.run( self.controller.run(
hostname=hostname, hostname=hostname,
port=port, port=port,
auth=auth, auth=auth,
tls_config=tls_config, tls_config=tls_config,
) )
self.acceptClient() self.acceptClient()
m = self.getMessage() m = self.getMessage()
self.assertEqual(m.command, 'CAP', self.assertEqual(m.command, "CAP", "First message is not CAP LS.")
'First message is not CAP LS.') if m.params == ["LS"]:
if m.params == ['LS']:
self.protocol_version = 301 self.protocol_version = 301
elif m.params == ['LS', '302']: elif m.params == ["LS", "302"]:
self.protocol_version = 302 self.protocol_version = 302
elif m.params == ['END']: elif m.params == ["END"]:
self.protocol_version = None self.protocol_version = None
else: else:
raise AssertionError('Unknown CAP params: {}' raise AssertionError("Unknown CAP params: {}".format(m.params))
.format(m.params))
def userNickPredicate(self, msg): def userNickPredicate(self, msg):
"""Predicate to be used with getMessage to handle NICK/USER """Predicate to be used with getMessage to handle NICK/USER
transparently.""" transparently."""
if msg.command == 'NICK': if msg.command == "NICK":
self.assertEqual(len(msg.params), 1, msg) self.assertEqual(len(msg.params), 1, msg)
self.nick = msg.params[0] self.nick = msg.params[0]
return False return False
elif msg.command == 'USER': elif msg.command == "USER":
self.assertEqual(len(msg.params), 4, msg) self.assertEqual(len(msg.params), 4, msg)
self.user = msg.params self.user = msg.params
return False return False
@ -225,25 +270,25 @@ class ClientNegociationHelper:
if not self.protocol_version: if not self.protocol_version:
# No negotiation. # No negotiation.
return return
self.sendLine('CAP * LS :{}'.format(' '.join(caps))) self.sendLine("CAP * LS :{}".format(" ".join(caps)))
capability_names = frozenset(capabilities.cap_list_to_dict(caps)) capability_names = frozenset(capabilities.cap_list_to_dict(caps))
self.acked_capabilities = set() self.acked_capabilities = set()
while True: while True:
m = self.getMessage(filter_pred=self.userNickPredicate) m = self.getMessage(filter_pred=self.userNickPredicate)
if m.command != 'CAP': if m.command != "CAP":
return m return m
self.assertGreater(len(m.params), 0, m) self.assertGreater(len(m.params), 0, m)
if m.params[0] == 'REQ': if m.params[0] == "REQ":
self.assertEqual(len(m.params), 2, m) self.assertEqual(len(m.params), 2, m)
requested = frozenset(m.params[1].split()) requested = frozenset(m.params[1].split())
if not requested.issubset(capability_names): if not requested.issubset(capability_names):
self.sendLine('CAP {} NAK :{}'.format( self.sendLine(
self.nick or '*', "CAP {} NAK :{}".format(self.nick or "*", m.params[1][0:100])
m.params[1][0:100])) )
else: else:
self.sendLine('CAP {} ACK :{}'.format( self.sendLine(
self.nick or '*', "CAP {} ACK :{}".format(self.nick or "*", m.params[1])
m.params[1])) )
self.acked_capabilities.update(requested) self.acked_capabilities.update(requested)
else: else:
return m return m
@ -252,27 +297,35 @@ class ClientNegociationHelper:
class BaseServerTestCase(_IrcTestCase): class BaseServerTestCase(_IrcTestCase):
"""Basic class for server tests. Handles spawning a server and exchanging """Basic class for server tests. Handles spawning a server and exchanging
messages with it.""" messages with it."""
password = None password = None
ssl = False ssl = False
valid_metadata_keys = frozenset() valid_metadata_keys = frozenset()
invalid_metadata_keys = frozenset() invalid_metadata_keys = frozenset()
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.server_support = {} self.server_support = {}
self.find_hostname_and_port() self.find_hostname_and_port()
self.controller.run(self.hostname, self.port, password=self.password, self.controller.run(
valid_metadata_keys=self.valid_metadata_keys, self.hostname,
invalid_metadata_keys=self.invalid_metadata_keys, self.port,
ssl=self.ssl) password=self.password,
valid_metadata_keys=self.valid_metadata_keys,
invalid_metadata_keys=self.invalid_metadata_keys,
ssl=self.ssl,
)
self.clients = {} self.clients = {}
def tearDown(self): def tearDown(self):
self.controller.kill() self.controller.kill()
for client in list(self.clients): for client in list(self.clients):
self.removeClient(client) self.removeClient(client)
def find_hostname_and_port(self): def find_hostname_and_port(self):
"""Find available hostname/port to listen on.""" """Find available hostname/port to listen on."""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("",0)) s.bind(("", 0))
(self.hostname, self.port) = s.getsockname() (self.hostname, self.port) = s.getsockname()
s.close() s.close()
@ -281,14 +334,12 @@ class BaseServerTestCase(_IrcTestCase):
If 'name' is not given, uses the lowest unused non-negative integer.""" If 'name' is not given, uses the lowest unused non-negative integer."""
self.controller.wait_for_port() self.controller.wait_for_port()
if not name: if not name:
name = max(map(int, list(self.clients)+[0]))+1 name = max(map(int, list(self.clients) + [0])) + 1
show_io = show_io if show_io is not None else self.show_io show_io = show_io if show_io is not None else self.show_io
self.clients[name] = client_mock.ClientMock(name=name, self.clients[name] = client_mock.ClientMock(name=name, show_io=show_io)
show_io=show_io)
self.clients[name].connect(self.hostname, self.port) self.clients[name].connect(self.hostname, self.port)
return name return name
def removeClient(self, name): def removeClient(self, name):
"""Disconnects the client, without QUIT.""" """Disconnects the client, without QUIT."""
assert name in self.clients assert name in self.clients
@ -297,12 +348,16 @@ class BaseServerTestCase(_IrcTestCase):
def getMessages(self, client, **kwargs): def getMessages(self, client, **kwargs):
return self.clients[client].getMessages(**kwargs) return self.clients[client].getMessages(**kwargs)
def getMessage(self, client, **kwargs): def getMessage(self, client, **kwargs):
return self.clients[client].getMessage(**kwargs) return self.clients[client].getMessage(**kwargs)
def getRegistrationMessage(self, client): def getRegistrationMessage(self, client):
"""Filter notices, do not send pings.""" """Filter notices, do not send pings."""
return self.getMessage(client, synchronize=False, return self.getMessage(
filter_pred=lambda m:m.command != 'NOTICE') client, synchronize=False, filter_pred=lambda m: m.command != "NOTICE"
)
def sendLine(self, client, line): def sendLine(self, client, line):
return self.clients[client].sendLine(line) return self.clients[client].sendLine(line)
@ -315,8 +370,8 @@ class BaseServerTestCase(_IrcTestCase):
caps = [] caps = []
while True: while True:
m = self.getRegistrationMessage(client) m = self.getRegistrationMessage(client)
self.assertMessageEqual(m, command='CAP', subcommand='LS') self.assertMessageEqual(m, command="CAP", subcommand="LS")
if m.params[2] == '*': if m.params[2] == "*":
caps.extend(m.params[3].split()) caps.extend(m.params[3].split())
else: else:
caps.extend(m.params[2].split()) caps.extend(m.params[2].split())
@ -332,8 +387,7 @@ class BaseServerTestCase(_IrcTestCase):
del self.clients[client] del self.clients[client]
return return
else: else:
raise AssertionError('Client not disconnected.') raise AssertionError("Client not disconnected.")
def skipToWelcome(self, client): def skipToWelcome(self, client):
"""Skip to the point where we are registered """Skip to the point where we are registered
@ -343,45 +397,54 @@ class BaseServerTestCase(_IrcTestCase):
while True: while True:
m = self.getMessage(client, synchronize=False) m = self.getMessage(client, synchronize=False)
result.append(m) result.append(m)
if m.command == '001': if m.command == "001":
return result return result
def connectClient(self, nick, name=None, capabilities=None, def connectClient(
skip_if_cap_nak=False, show_io=None, password=None, ident='username'): self,
nick,
name=None,
capabilities=None,
skip_if_cap_nak=False,
show_io=None,
password=None,
ident="username",
):
client = self.addClient(name, show_io=show_io) client = self.addClient(name, show_io=show_io)
if capabilities is not None and 0 < len(capabilities): if capabilities is not None and 0 < len(capabilities):
self.sendLine(client, 'CAP REQ :{}'.format(' '.join(capabilities))) self.sendLine(client, "CAP REQ :{}".format(" ".join(capabilities)))
m = self.getRegistrationMessage(client) m = self.getRegistrationMessage(client)
try: try:
self.assertMessageEqual(m, command='CAP', self.assertMessageEqual(
fail_msg='Expected CAP ACK, got: {msg}') m, command="CAP", fail_msg="Expected CAP ACK, got: {msg}"
self.assertEqual(m.params[1], 'ACK', m, )
fail_msg='Expected CAP ACK, got: {msg}') self.assertEqual(
m.params[1], "ACK", m, fail_msg="Expected CAP ACK, got: {msg}"
)
except AssertionError: except AssertionError:
if skip_if_cap_nak: if skip_if_cap_nak:
raise runner.NotImplementedByController( raise runner.NotImplementedByController(", ".join(capabilities))
', '.join(capabilities))
else: else:
raise raise
self.sendLine(client, 'CAP END') self.sendLine(client, "CAP END")
if password is not None: if password is not None:
self.sendLine(client, 'AUTHENTICATE PLAIN') self.sendLine(client, "AUTHENTICATE PLAIN")
self.sendLine(client, sasl_plain_blob(nick, password)) self.sendLine(client, sasl_plain_blob(nick, password))
self.sendLine(client, 'NICK {}'.format(nick)) self.sendLine(client, "NICK {}".format(nick))
self.sendLine(client, 'USER %s * * :Realname' % (ident,)) self.sendLine(client, "USER %s * * :Realname" % (ident,))
welcome = self.skipToWelcome(client) welcome = self.skipToWelcome(client)
self.sendLine(client, 'PING foo') self.sendLine(client, "PING foo")
# Skip all that happy welcoming stuff # Skip all that happy welcoming stuff
while True: while True:
m = self.getMessage(client) m = self.getMessage(client)
if m.command == 'PONG': if m.command == "PONG":
break break
elif m.command == '005': elif m.command == "005":
for param in m.params[1:-1]: for param in m.params[1:-1]:
if '=' in param: if "=" in param:
(key, value) = param.split('=') (key, value) = param.split("=")
else: else:
(key, value) = (param, None) (key, value) = (param, None)
self.server_support[key] = value self.server_support[key] = value
@ -390,49 +453,57 @@ class BaseServerTestCase(_IrcTestCase):
return welcome return welcome
def joinClient(self, client, channel): def joinClient(self, client, channel):
self.sendLine(client, 'JOIN {}'.format(channel)) self.sendLine(client, "JOIN {}".format(channel))
received = {m.command for m in self.getMessages(client)} received = {m.command for m in self.getMessages(client)}
self.assertIn('366', received, self.assertIn(
fail_msg='Join to {} failed, {item} is not in the set of ' "366",
'received responses: {list}', received,
extra_format=(channel,)) fail_msg="Join to {} failed, {item} is not in the set of "
"received responses: {list}",
extra_format=(channel,),
)
def joinChannel(self, client, channel): def joinChannel(self, client, channel):
self.sendLine(client, 'JOIN {}'.format(channel)) self.sendLine(client, "JOIN {}".format(channel))
# wait until we see them join the channel # wait until we see them join the channel
joined = False joined = False
while not joined: while not joined:
for msg in self.getMessages(client): for msg in self.getMessages(client):
if msg.command == 'JOIN' and 0 < len(msg.params) and msg.params[0].lower() == channel.lower(): if (
msg.command == "JOIN"
and 0 < len(msg.params)
and msg.params[0].lower() == channel.lower()
):
joined = True joined = True
break break
elif msg.command in CHANNEL_JOIN_FAIL_NUMERICS: elif msg.command in CHANNEL_JOIN_FAIL_NUMERICS:
raise ChannelJoinException(msg.command, msg.params) raise ChannelJoinException(msg.command, msg.params)
def getISupport(self): def getISupport(self):
cn = random_name('bar') cn = random_name("bar")
self.addClient(name=cn) self.addClient(name=cn)
self.sendLine(cn, 'NICK %s' % (cn,)) self.sendLine(cn, "NICK %s" % (cn,))
self.sendLine(cn, 'USER u s e r') self.sendLine(cn, "USER u s e r")
messages = self.getMessages(cn) messages = self.getMessages(cn)
isupport = {} isupport = {}
for message in messages: for message in messages:
if message.command != '005': if message.command != "005":
continue continue
# 005 nick <tokens...> :are supported by this server # 005 nick <tokens...> :are supported by this server
tokens = message.params[1:-1] tokens = message.params[1:-1]
for token in tokens: for token in tokens:
name, _, value = token.partition('=') name, _, value = token.partition("=")
isupport[name] = value isupport[name] = value
self.sendLine(cn, 'QUIT') self.sendLine(cn, "QUIT")
self.assertDisconnected(cn) self.assertDisconnected(cn)
return isupport return isupport
class OptionalityHelper: class OptionalityHelper:
def checkSaslSupport(self): def checkSaslSupport(self):
if self.controller.supported_sasl_mechanisms: if self.controller.supported_sasl_mechanisms:
return return
raise runner.NotImplementedByController('SASL') raise runner.NotImplementedByController("SASL")
def checkMechanismSupport(self, mechanism): def checkMechanismSupport(self, mechanism):
if mechanism in self.controller.supported_sasl_mechanisms: if mechanism in self.controller.supported_sasl_mechanisms:
@ -445,7 +516,9 @@ class OptionalityHelper:
def newf(self): def newf(self):
self.checkMechanismSupport(mech) self.checkMechanismSupport(mech)
return f(self) return f(self)
return newf return newf
return decorator return decorator
def skipUnlessHasSasl(f): def skipUnlessHasSasl(f):
@ -453,6 +526,7 @@ class OptionalityHelper:
def newf(self): def newf(self):
self.checkSaslSupport() self.checkSaslSupport()
return f(self) return f(self)
return newf return newf
def checkCapabilitySupport(self, cap): def checkCapabilitySupport(self, cap):
@ -466,22 +540,26 @@ class OptionalityHelper:
def newf(self): def newf(self):
self.checkCapabilitySupport(cap) self.checkCapabilitySupport(cap)
return f(self) return f(self)
return newf return newf
return decorator return decorator
class SpecificationSelector:
class SpecificationSelector:
def requiredBySpecification(*specifications, strict=False): def requiredBySpecification(*specifications, strict=False):
specifications = frozenset( specifications = frozenset(
Specifications.of_name(s) if isinstance(s, str) else s Specifications.of_name(s) if isinstance(s, str) else s
for s in specifications) for s in specifications
)
if None in specifications: if None in specifications:
raise ValueError('Invalid set of specifications: {}' raise ValueError("Invalid set of specifications: {}".format(specifications))
.format(specifications))
def decorator(f): def decorator(f):
for specification in specifications: for specification in specifications:
f = getattr(pytest.mark, specification.value)(f) f = getattr(pytest.mark, specification.value)(f)
if strict: if strict:
f = pytest.mark.strict(f) f = pytest.mark.strict(f)
return f return f
return decorator return decorator

View File

@ -5,32 +5,37 @@ import socket
from .irc_utils import message_parser from .irc_utils import message_parser
from .exceptions import NoMessageException, ConnectionClosed from .exceptions import NoMessageException, ConnectionClosed
class ClientMock: class ClientMock:
def __init__(self, name, show_io): def __init__(self, name, show_io):
self.name = name self.name = name
self.show_io = show_io self.show_io = show_io
self.inbuffer = [] self.inbuffer = []
self.ssl = False self.ssl = False
def connect(self, hostname, port): def connect(self, hostname, port):
self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.conn.settimeout(1) # TODO: configurable self.conn.settimeout(1) # TODO: configurable
self.conn.connect((hostname, port)) self.conn.connect((hostname, port))
if self.show_io: if self.show_io:
print('{:.3f} {}: connects to server.'.format(time.time(), self.name)) print("{:.3f} {}: connects to server.".format(time.time(), self.name))
def disconnect(self): def disconnect(self):
if self.show_io: if self.show_io:
print('{:.3f} {}: disconnects from server.'.format(time.time(), self.name)) print("{:.3f} {}: disconnects from server.".format(time.time(), self.name))
self.conn.close() self.conn.close()
def starttls(self): def starttls(self):
assert not self.ssl, 'SSL already active.' assert not self.ssl, "SSL already active."
self.conn = ssl.wrap_socket(self.conn) self.conn = ssl.wrap_socket(self.conn)
self.ssl = True self.ssl = True
def getMessages(self, synchronize=True, assert_get_one=False, raw=False): def getMessages(self, synchronize=True, assert_get_one=False, raw=False):
if synchronize: if synchronize:
token = 'synchronize{}'.format(time.monotonic()) token = "synchronize{}".format(time.monotonic())
self.sendLine('PING {}'.format(token)) self.sendLine("PING {}".format(token))
got_pong = False got_pong = False
data = b'' data = b""
(self.inbuffer, messages) = ([], self.inbuffer) (self.inbuffer, messages) = ([], self.inbuffer)
conn = self.conn conn = self.conn
try: try:
@ -38,11 +43,11 @@ class ClientMock:
try: try:
new_data = conn.recv(4096) new_data = conn.recv(4096)
except socket.timeout: except socket.timeout:
if not assert_get_one and not synchronize and data == b'': if not assert_get_one and not synchronize and data == b"":
# Received nothing # Received nothing
return [] return []
if self.show_io: if self.show_io:
print('{:.3f} {}: waiting…'.format(time.time(), self.name)) print("{:.3f} {}: waiting…".format(time.time(), self.name))
time.sleep(0.1) time.sleep(0.1)
continue continue
except ConnectionResetError: except ConnectionResetError:
@ -52,29 +57,31 @@ class ClientMock:
# Connection closed # Connection closed
raise ConnectionClosed() raise ConnectionClosed()
data += new_data data += new_data
if not new_data.endswith(b'\r\n'): if not new_data.endswith(b"\r\n"):
time.sleep(0.1) time.sleep(0.1)
continue continue
if not synchronize: if not synchronize:
got_pong = True got_pong = True
for line in data.decode().split('\r\n'): for line in data.decode().split("\r\n"):
if line: if line:
if self.show_io: if self.show_io:
print('{time:.3f}{ssl} S -> {client}: {line}'.format( print(
time=time.time(), "{time:.3f}{ssl} S -> {client}: {line}".format(
ssl=' (ssl)' if self.ssl else '', time=time.time(),
client=self.name, ssl=" (ssl)" if self.ssl else "",
line=line)) client=self.name,
line=line,
)
)
message = message_parser.parse_message(line) message = message_parser.parse_message(line)
if message.command == 'PONG' and \ if message.command == "PONG" and token in message.params:
token in message.params:
got_pong = True got_pong = True
else: else:
if raw: if raw:
messages.append(line) messages.append(line)
else: else:
messages.append(message) messages.append(message)
data = b'' data = b""
except ConnectionClosed: except ConnectionClosed:
if messages: if messages:
return messages return messages
@ -82,16 +89,19 @@ class ClientMock:
raise raise
else: else:
return messages return messages
def getMessage(self, filter_pred=None, synchronize=True, raw=False): def getMessage(self, filter_pred=None, synchronize=True, raw=False):
while True: while True:
if not self.inbuffer: if not self.inbuffer:
self.inbuffer = self.getMessages( self.inbuffer = self.getMessages(
synchronize=synchronize, assert_get_one=True, raw=raw) synchronize=synchronize, assert_get_one=True, raw=raw
)
if not self.inbuffer: if not self.inbuffer:
raise NoMessageException() raise NoMessageException()
message = self.inbuffer.pop(0) # TODO: use dequeue message = self.inbuffer.pop(0) # TODO: use dequeue
if not filter_pred or filter_pred(message): if not filter_pred or filter_pred(message):
return message return message
def sendLine(self, line): def sendLine(self, line):
if isinstance(line, str): if isinstance(line, str):
encoded_line = line.encode() encoded_line = line.encode()
@ -99,26 +109,31 @@ class ClientMock:
encoded_line = line encoded_line = line
else: else:
raise ValueError(line) raise ValueError(line)
if not encoded_line.endswith(b'\r\n'): if not encoded_line.endswith(b"\r\n"):
encoded_line += b'\r\n' encoded_line += b"\r\n"
try: try:
ret = self.conn.sendall(encoded_line) ret = self.conn.sendall(encoded_line)
except BrokenPipeError: except BrokenPipeError:
raise ConnectionClosed() raise ConnectionClosed()
if sys.version_info <= (3, 6) and self.ssl: # https://bugs.python.org/issue25951 if (
sys.version_info <= (3, 6) and self.ssl
): # https://bugs.python.org/issue25951
assert ret == len(encoded_line), (ret, repr(encoded_line)) assert ret == len(encoded_line), (ret, repr(encoded_line))
else: else:
assert ret is None, ret assert ret is None, ret
if self.show_io: if self.show_io:
if isinstance(line, str): if isinstance(line, str):
escaped_line = line escaped_line = line
escaped = '' escaped = ""
else: else:
escaped_line = repr(line) escaped_line = repr(line)
escaped = ' (escaped)' escaped = " (escaped)"
print('{time:.3f}{escaped}{ssl} {client} -> S: {line}'.format( print(
time=time.time(), "{time:.3f}{escaped}{ssl} {client} -> S: {line}".format(
escaped=escaped, time=time.time(),
ssl=' (ssl)' if self.ssl else '', escaped=escaped,
client=self.name, ssl=" (ssl)" if self.ssl else "",
line=escaped_line.strip('\r\n'))) client=self.name,
line=escaped_line.strip("\r\n"),
)
)

View File

@ -1,14 +1,15 @@
from irctest import cases from irctest import cases
from irctest.irc_utils.message_parser import Message from irctest.irc_utils.message_parser import Message
class CapTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper): class CapTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper):
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1', 'IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1", "IRCv3.2")
def testSendCap(self): def testSendCap(self):
"""Send CAP LS 302 and read the result.""" """Send CAP LS 302 and read the result."""
self.readCapLs() self.readCapLs()
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1', 'IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1", "IRCv3.2")
def testEmptyCapLs(self): def testEmptyCapLs(self):
"""Empty result to CAP LS. Client should send CAP END.""" """Empty result to CAP LS. Client should send CAP END."""
m = self.negotiateCapabilities([]) m = self.negotiateCapabilities([])
self.assertEqual(m, Message({}, None, 'CAP', ['END'])) self.assertEqual(m, Message({}, None, "CAP", ["END"]))

View File

@ -27,6 +27,7 @@ IRX9cyi2wdYg9mUUYyh9GKdBCYHGUJAiCA==
CHALLENGE = bytes(range(32)) CHALLENGE = bytes(range(32))
assert len(CHALLENGE) == 32 assert len(CHALLENGE) == 32
class IdentityHash: class IdentityHash:
def __init__(self, data): def __init__(self, data):
self._data = data self._data = data
@ -34,28 +35,31 @@ class IdentityHash:
def digest(self): def digest(self):
return self._data return self._data
class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
cases.OptionalityHelper): class SaslTestCase(
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') cases.BaseClientTestCase, cases.ClientNegociationHelper, cases.OptionalityHelper
):
@cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPlain(self): def testPlain(self):
"""Test PLAIN authentication with correct username/password.""" """Test PLAIN authentication with correct username/password."""
auth = authentication.Authentication( auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.plain], mechanisms=[authentication.Mechanisms.plain],
username='jilles', username="jilles",
password='sesame', password="sesame",
) )
m = self.negotiateCapabilities(['sasl'], auth=auth) m = self.negotiateCapabilities(["sasl"], auth=auth)
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['PLAIN'])) self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["PLAIN"]))
self.sendLine('AUTHENTICATE +') self.sendLine("AUTHENTICATE +")
m = self.getMessage() m = self.getMessage()
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', self.assertEqual(
['amlsbGVzAGppbGxlcwBzZXNhbWU='])) m, Message({}, None, "AUTHENTICATE", ["amlsbGVzAGppbGxlcwBzZXNhbWU="])
self.sendLine('900 * * jilles :You are now logged in.') )
self.sendLine('903 * :SASL authentication successful') self.sendLine("900 * * jilles :You are now logged in.")
m = self.negotiateCapabilities(['sasl'], False) self.sendLine("903 * :SASL authentication successful")
self.assertEqual(m, Message({}, None, 'CAP', ['END'])) m = self.negotiateCapabilities(["sasl"], False)
self.assertEqual(m, Message({}, None, "CAP", ["END"]))
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPlainNotAvailable(self): def testPlainNotAvailable(self):
"""`sasl=EXTERNAL` is advertized, whereas the client is configured """`sasl=EXTERNAL` is advertized, whereas the client is configured
to use PLAIN. to use PLAIN.
@ -65,27 +69,26 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
a 904. a 904.
""" """
auth = authentication.Authentication( auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.plain], mechanisms=[authentication.Mechanisms.plain],
username='jilles', username="jilles",
password='sesame', password="sesame",
) )
m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth) m = self.negotiateCapabilities(["sasl=EXTERNAL"], auth=auth)
self.assertEqual(self.acked_capabilities, {'sasl'}) self.assertEqual(self.acked_capabilities, {"sasl"})
if m == Message({}, None, 'CAP', ['END']): if m == Message({}, None, "CAP", ["END"]):
# IRCv3.2-style, for clients that skip authentication # IRCv3.2-style, for clients that skip authentication
# when unavailable (eg. Limnoria) # when unavailable (eg. Limnoria)
return return
elif m.command == 'QUIT': elif m.command == "QUIT":
# IRCv3.2-style, for clients that quit when unavailable # IRCv3.2-style, for clients that quit when unavailable
# (eg. Sopel) # (eg. Sopel)
return return
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['PLAIN'])) self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["PLAIN"]))
self.sendLine('904 {} :SASL auth failed'.format(self.nick)) self.sendLine("904 {} :SASL auth failed".format(self.nick))
m = self.getMessage() m = self.getMessage()
self.assertMessageEqual(m, command='CAP') self.assertMessageEqual(m, command="CAP")
@cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN')
def testPlainLarge(self): def testPlainLarge(self):
"""Test the client splits large AUTHENTICATE messages whose payload """Test the client splits large AUTHENTICATE messages whose payload
is not a multiple of 400. is not a multiple of 400.
@ -93,30 +96,28 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
""" """
# TODO: authzid is optional # TODO: authzid is optional
auth = authentication.Authentication( auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.plain], mechanisms=[authentication.Mechanisms.plain],
username='foo', username="foo",
password='bar'*200, password="bar" * 200,
) )
authstring = base64.b64encode(b'\x00'.join( authstring = base64.b64encode(
[b'foo', b'foo', b'bar'*200])).decode() b"\x00".join([b"foo", b"foo", b"bar" * 200])
m = self.negotiateCapabilities(['sasl'], auth=auth) ).decode()
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['PLAIN'])) m = self.negotiateCapabilities(["sasl"], auth=auth)
self.sendLine('AUTHENTICATE +') self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["PLAIN"]))
self.sendLine("AUTHENTICATE +")
m = self.getMessage() m = self.getMessage()
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', self.assertEqual(m, Message({}, None, "AUTHENTICATE", [authstring[0:400]]), m)
[authstring[0:400]]), m)
m = self.getMessage() m = self.getMessage()
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', self.assertEqual(m, Message({}, None, "AUTHENTICATE", [authstring[400:800]]))
[authstring[400:800]]))
m = self.getMessage() m = self.getMessage()
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', self.assertEqual(m, Message({}, None, "AUTHENTICATE", [authstring[800:]]))
[authstring[800:]])) self.sendLine("900 * * {} :You are now logged in.".format("foo"))
self.sendLine('900 * * {} :You are now logged in.'.format('foo')) self.sendLine("903 * :SASL authentication successful")
self.sendLine('903 * :SASL authentication successful') m = self.negotiateCapabilities(["sasl"], False)
m = self.negotiateCapabilities(['sasl'], False) self.assertEqual(m, Message({}, None, "CAP", ["END"]))
self.assertEqual(m, Message({}, None, 'CAP', ['END']))
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPlainLargeMultiple(self): def testPlainLargeMultiple(self):
"""Test the client splits large AUTHENTICATE messages whose payload """Test the client splits large AUTHENTICATE messages whose payload
is a multiple of 400. is a multiple of 400.
@ -124,149 +125,157 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
""" """
# TODO: authzid is optional # TODO: authzid is optional
auth = authentication.Authentication( auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.plain], mechanisms=[authentication.Mechanisms.plain],
username='foo', username="foo",
password='quux'*148, password="quux" * 148,
) )
authstring = base64.b64encode(b'\x00'.join( authstring = base64.b64encode(
[b'foo', b'foo', b'quux'*148])).decode() b"\x00".join([b"foo", b"foo", b"quux" * 148])
m = self.negotiateCapabilities(['sasl'], auth=auth) ).decode()
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['PLAIN'])) m = self.negotiateCapabilities(["sasl"], auth=auth)
self.sendLine('AUTHENTICATE +') self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["PLAIN"]))
self.sendLine("AUTHENTICATE +")
m = self.getMessage() m = self.getMessage()
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', self.assertEqual(m, Message({}, None, "AUTHENTICATE", [authstring[0:400]]), m)
[authstring[0:400]]), m)
m = self.getMessage() m = self.getMessage()
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', self.assertEqual(m, Message({}, None, "AUTHENTICATE", [authstring[400:800]]))
[authstring[400:800]]))
m = self.getMessage() m = self.getMessage()
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["+"]))
['+'])) self.sendLine("900 * * {} :You are now logged in.".format("foo"))
self.sendLine('900 * * {} :You are now logged in.'.format('foo')) self.sendLine("903 * :SASL authentication successful")
self.sendLine('903 * :SASL authentication successful') m = self.negotiateCapabilities(["sasl"], False)
m = self.negotiateCapabilities(['sasl'], False) self.assertEqual(m, Message({}, None, "CAP", ["END"]))
self.assertEqual(m, Message({}, None, 'CAP', ['END']))
@cases.OptionalityHelper.skipUnlessHasMechanism('ECDSA-NIST256P-CHALLENGE') @cases.OptionalityHelper.skipUnlessHasMechanism("ECDSA-NIST256P-CHALLENGE")
def testEcdsa(self): def testEcdsa(self):
"""Test ECDSA authentication. """Test ECDSA authentication."""
"""
auth = authentication.Authentication( auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.ecdsa_nist256p_challenge], mechanisms=[authentication.Mechanisms.ecdsa_nist256p_challenge],
username='jilles', username="jilles",
ecdsa_key=ECDSA_KEY, ecdsa_key=ECDSA_KEY,
) )
m = self.negotiateCapabilities(['sasl'], auth=auth) m = self.negotiateCapabilities(["sasl"], auth=auth)
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['ECDSA-NIST256P-CHALLENGE'])) self.assertEqual(
self.sendLine('AUTHENTICATE +') m, Message({}, None, "AUTHENTICATE", ["ECDSA-NIST256P-CHALLENGE"])
)
self.sendLine("AUTHENTICATE +")
m = self.getMessage() m = self.getMessage()
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["amlsbGVz"])) # jilles
['amlsbGVz'])) # jilles self.sendLine(
self.sendLine('AUTHENTICATE {}'.format(base64.b64encode(CHALLENGE).decode('ascii'))) "AUTHENTICATE {}".format(base64.b64encode(CHALLENGE).decode("ascii"))
)
m = self.getMessage() m = self.getMessage()
self.assertMessageEqual(m, command='AUTHENTICATE') self.assertMessageEqual(m, command="AUTHENTICATE")
sk = ecdsa.SigningKey.from_pem(ECDSA_KEY) sk = ecdsa.SigningKey.from_pem(ECDSA_KEY)
vk = sk.get_verifying_key() vk = sk.get_verifying_key()
signature = base64.b64decode(m.params[0]) signature = base64.b64decode(m.params[0])
try: try:
vk.verify(signature, CHALLENGE, hashfunc=IdentityHash, sigdecode=sigdecode_der) vk.verify(
signature, CHALLENGE, hashfunc=IdentityHash, sigdecode=sigdecode_der
)
except ecdsa.BadSignatureError: except ecdsa.BadSignatureError:
raise AssertionError('Bad signature') raise AssertionError("Bad signature")
self.sendLine('900 * * foo :You are now logged in.') self.sendLine("900 * * foo :You are now logged in.")
self.sendLine('903 * :SASL authentication successful') self.sendLine("903 * :SASL authentication successful")
m = self.negotiateCapabilities(['sasl'], False) m = self.negotiateCapabilities(["sasl"], False)
self.assertEqual(m, Message({}, None, 'CAP', ['END'])) self.assertEqual(m, Message({}, None, "CAP", ["END"]))
@cases.OptionalityHelper.skipUnlessHasMechanism('SCRAM-SHA-256') @cases.OptionalityHelper.skipUnlessHasMechanism("SCRAM-SHA-256")
def testScram(self): def testScram(self):
"""Test SCRAM-SHA-256 authentication. """Test SCRAM-SHA-256 authentication."""
"""
auth = authentication.Authentication( auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.scram_sha_256], mechanisms=[authentication.Mechanisms.scram_sha_256],
username='jilles', username="jilles",
password='sesame', password="sesame",
) )
class PasswdDb: class PasswdDb:
def get_password(self, *args): def get_password(self, *args):
return ('sesame', 'plain') return ("sesame", "plain")
authenticator = scram.SCRAMServerAuthenticator('SHA-256',
channel_binding=False, password_database=PasswdDb())
m = self.negotiateCapabilities(['sasl'], auth=auth) authenticator = scram.SCRAMServerAuthenticator(
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['SCRAM-SHA-256'])) "SHA-256", channel_binding=False, password_database=PasswdDb()
self.sendLine('AUTHENTICATE +') )
m = self.negotiateCapabilities(["sasl"], auth=auth)
self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["SCRAM-SHA-256"]))
self.sendLine("AUTHENTICATE +")
m = self.getMessage() m = self.getMessage()
self.assertEqual(m.command, 'AUTHENTICATE', m) self.assertEqual(m.command, "AUTHENTICATE", m)
client_first = base64.b64decode(m.params[0]) client_first = base64.b64decode(m.params[0])
response = authenticator.start(properties={}, initial_response=client_first) response = authenticator.start(properties={}, initial_response=client_first)
assert isinstance(response, bytes), response assert isinstance(response, bytes), response
self.sendLine('AUTHENTICATE :' + base64.b64encode(response).decode()) self.sendLine("AUTHENTICATE :" + base64.b64encode(response).decode())
m = self.getMessage() m = self.getMessage()
self.assertEqual(m.command, 'AUTHENTICATE', m) self.assertEqual(m.command, "AUTHENTICATE", m)
msg = base64.b64decode(m.params[0]) msg = base64.b64decode(m.params[0])
r = authenticator.response(msg) r = authenticator.response(msg)
assert isinstance(r, tuple), r assert isinstance(r, tuple), r
assert len(r) == 2, r assert len(r) == 2, r
(properties, response) = r (properties, response) = r
self.sendLine('AUTHENTICATE :' + base64.b64encode(response).decode()) self.sendLine("AUTHENTICATE :" + base64.b64encode(response).decode())
self.assertEqual(properties, {'authzid': None, 'username': 'jilles'}) self.assertEqual(properties, {"authzid": None, "username": "jilles"})
m = self.getMessage() m = self.getMessage()
self.assertEqual(m.command, 'AUTHENTICATE', m) self.assertEqual(m.command, "AUTHENTICATE", m)
self.assertEqual(m.params, ['+'], m) self.assertEqual(m.params, ["+"], m)
@cases.OptionalityHelper.skipUnlessHasMechanism('SCRAM-SHA-256') @cases.OptionalityHelper.skipUnlessHasMechanism("SCRAM-SHA-256")
def testScramBadPassword(self): def testScramBadPassword(self):
"""Test SCRAM-SHA-256 authentication with a bad password. """Test SCRAM-SHA-256 authentication with a bad password."""
"""
auth = authentication.Authentication( auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.scram_sha_256], mechanisms=[authentication.Mechanisms.scram_sha_256],
username='jilles', username="jilles",
password='sesame', password="sesame",
) )
class PasswdDb: class PasswdDb:
def get_password(self, *args): def get_password(self, *args):
return ('notsesame', 'plain') return ("notsesame", "plain")
authenticator = scram.SCRAMServerAuthenticator('SHA-256',
channel_binding=False, password_database=PasswdDb())
m = self.negotiateCapabilities(['sasl'], auth=auth) authenticator = scram.SCRAMServerAuthenticator(
self.assertEqual(m, Message({}, None, 'AUTHENTICATE', ['SCRAM-SHA-256'])) "SHA-256", channel_binding=False, password_database=PasswdDb()
self.sendLine('AUTHENTICATE +') )
m = self.negotiateCapabilities(["sasl"], auth=auth)
self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["SCRAM-SHA-256"]))
self.sendLine("AUTHENTICATE +")
m = self.getMessage() m = self.getMessage()
self.assertEqual(m.command, 'AUTHENTICATE', m) self.assertEqual(m.command, "AUTHENTICATE", m)
client_first = base64.b64decode(m.params[0]) client_first = base64.b64decode(m.params[0])
response = authenticator.start(properties={}, initial_response=client_first) response = authenticator.start(properties={}, initial_response=client_first)
assert isinstance(response, bytes), response assert isinstance(response, bytes), response
self.sendLine('AUTHENTICATE :' + base64.b64encode(response).decode()) self.sendLine("AUTHENTICATE :" + base64.b64encode(response).decode())
m = self.getMessage() m = self.getMessage()
self.assertEqual(m.command, 'AUTHENTICATE', m) self.assertEqual(m.command, "AUTHENTICATE", m)
msg = base64.b64decode(m.params[0]) msg = base64.b64decode(m.params[0])
with self.assertRaises(scram.NotAuthorizedException): with self.assertRaises(scram.NotAuthorizedException):
authenticator.response(msg) authenticator.response(msg)
class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
cases.OptionalityHelper): class Irc302SaslTestCase(
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') cases.BaseClientTestCase, cases.ClientNegociationHelper, cases.OptionalityHelper
):
@cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPlainNotAvailable(self): def testPlainNotAvailable(self):
"""Test the client does not try to authenticate using a mechanism the """Test the client does not try to authenticate using a mechanism the
server does not advertise. server does not advertise.
Actually, this is optional.""" Actually, this is optional."""
auth = authentication.Authentication( auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.plain], mechanisms=[authentication.Mechanisms.plain],
username='jilles', username="jilles",
password='sesame', password="sesame",
) )
m = self.negotiateCapabilities(['sasl=EXTERNAL'], auth=auth) m = self.negotiateCapabilities(["sasl=EXTERNAL"], auth=auth)
self.assertEqual(self.acked_capabilities, {'sasl'}) self.assertEqual(self.acked_capabilities, {"sasl"})
if m.command == 'QUIT': if m.command == "QUIT":
# Some clients quit when it can't authenticate (eg. Sopel) # Some clients quit when it can't authenticate (eg. Sopel)
pass pass
else: else:
# Others will just skip authentication (eg. Limnoria) # Others will just skip authentication (eg. Limnoria)
self.assertEqual(m, Message({}, None, 'CAP', ['END'])) self.assertEqual(m, Message({}, None, "CAP", ["END"]))

View File

@ -60,7 +60,7 @@ h4WuPDAI4yh24GjaCZYGR5xcqPCy5CNjMLxdA7HsP+Gcr3eY5XS7noBrbC6IaA0j
-----END PRIVATE KEY----- -----END PRIVATE KEY-----
""" """
GOOD_FINGERPRINT = 'E1EE6DE2DBC0D43E3B60407B5EE389AEC9D2C53178E0FB14CD51C3DFD544AA2B' GOOD_FINGERPRINT = "E1EE6DE2DBC0D43E3B60407B5EE389AEC9D2C53178E0FB14CD51C3DFD544AA2B"
GOOD_CERT = """ GOOD_CERT = """
-----BEGIN CERTIFICATE----- -----BEGIN CERTIFICATE-----
MIIDXTCCAkWgAwIBAgIJAKtD9XMC1R0vMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV MIIDXTCCAkWgAwIBAgIJAKtD9XMC1R0vMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV
@ -115,32 +115,29 @@ El9iqRlAhgqaXc4Iz/Zxxhs=
-----END PRIVATE KEY----- -----END PRIVATE KEY-----
""" """
class TlsTestCase(cases.BaseClientTestCase): class TlsTestCase(cases.BaseClientTestCase):
def testTrustedCertificate(self): def testTrustedCertificate(self):
tls_config = tls.TlsConfig( tls_config = tls.TlsConfig(enable=True, trusted_fingerprints=[GOOD_FINGERPRINT])
enable=True,
trusted_fingerprints=[GOOD_FINGERPRINT])
(hostname, port) = self.server.getsockname() (hostname, port) = self.server.getsockname()
self.controller.run( self.controller.run(
hostname=hostname, hostname=hostname,
port=port, port=port,
auth=None, auth=None,
tls_config=tls_config, tls_config=tls_config,
) )
self.acceptClient(tls_cert=GOOD_CERT, tls_key=GOOD_KEY) self.acceptClient(tls_cert=GOOD_CERT, tls_key=GOOD_KEY)
m = self.getMessage() m = self.getMessage()
def testUntrustedCertificate(self): def testUntrustedCertificate(self):
tls_config = tls.TlsConfig( tls_config = tls.TlsConfig(enable=True, trusted_fingerprints=[GOOD_FINGERPRINT])
enable=True,
trusted_fingerprints=[GOOD_FINGERPRINT])
(hostname, port) = self.server.getsockname() (hostname, port) = self.server.getsockname()
self.controller.run( self.controller.run(
hostname=hostname, hostname=hostname,
port=port, port=port,
auth=None, auth=None,
tls_config=tls_config, tls_config=tls_config,
) )
self.acceptClient(tls_cert=BAD_CERT, tls_key=BAD_KEY) self.acceptClient(tls_cert=BAD_CERT, tls_key=BAD_KEY)
with self.assertRaises((ConnectionClosed, ConnectionResetError)): with self.assertRaises((ConnectionClosed, ConnectionResetError)):
m = self.getMessage() m = self.getMessage()
@ -150,36 +147,34 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.insecure_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.insecure_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.insecure_server.bind(('', 0)) # Bind any free port self.insecure_server.bind(("", 0)) # Bind any free port
self.insecure_server.listen(1) self.insecure_server.listen(1)
def tearDown(self): def tearDown(self):
self.insecure_server.close() self.insecure_server.close()
super().tearDown() super().tearDown()
@cases.OptionalityHelper.skipUnlessSupportsCapability('sts') @cases.OptionalityHelper.skipUnlessSupportsCapability("sts")
def testSts(self): def testSts(self):
tls_config = tls.TlsConfig( tls_config = tls.TlsConfig(
enable=False, enable=False, trusted_fingerprints=[GOOD_FINGERPRINT]
trusted_fingerprints=[GOOD_FINGERPRINT]) )
# Connect client to insecure server # Connect client to insecure server
(hostname, port) = self.insecure_server.getsockname() (hostname, port) = self.insecure_server.getsockname()
self.controller.run( self.controller.run(
hostname=hostname, hostname=hostname,
port=port, port=port,
auth=None, auth=None,
tls_config=tls_config, tls_config=tls_config,
) )
self.acceptClient(server=self.insecure_server) self.acceptClient(server=self.insecure_server)
# Send STS policy to client # Send STS policy to client
m = self.getMessage() m = self.getMessage()
self.assertEqual(m.command, 'CAP', self.assertEqual(m.command, "CAP", "First message is not CAP LS.")
'First message is not CAP LS.') self.assertEqual(m.params[0], "LS", "First message is not CAP LS.")
self.assertEqual(m.params[0], 'LS', self.sendLine("CAP * LS :sts=port={}".format(self.server.getsockname()[1]))
'First message is not CAP LS.')
self.sendLine('CAP * LS :sts=port={}'.format(self.server.getsockname()[1]))
# "If the client is not already connected securely to the server # "If the client is not already connected securely to the server
# at the requested hostname, it MUST close the insecure connection # at the requested hostname, it MUST close the insecure connection
@ -187,11 +182,12 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
self.acceptClient(tls_cert=GOOD_CERT, tls_key=GOOD_KEY) self.acceptClient(tls_cert=GOOD_CERT, tls_key=GOOD_KEY)
# Send the STS policy, over secure connection this time # Send the STS policy, over secure connection this time
self.sendLine('CAP * LS :sts=duration=10,port={}'.format( self.sendLine(
self.server.getsockname()[1])) "CAP * LS :sts=duration=10,port={}".format(self.server.getsockname()[1])
)
# Make the client reconnect. It should reconnect to the secure server. # Make the client reconnect. It should reconnect to the secure server.
self.sendLine('ERROR :closing link') self.sendLine("ERROR :closing link")
self.acceptClient() self.acceptClient()
# Kill the client # Kill the client
@ -199,34 +195,32 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
# Run the client, still configured to connect to the insecure server # Run the client, still configured to connect to the insecure server
self.controller.run( self.controller.run(
hostname=hostname, hostname=hostname,
port=port, port=port,
auth=None, auth=None,
tls_config=tls_config, tls_config=tls_config,
) )
# The client should remember the STS policy and connect to the secure # The client should remember the STS policy and connect to the secure
# server # server
self.acceptClient() self.acceptClient()
@cases.OptionalityHelper.skipUnlessSupportsCapability('sts') @cases.OptionalityHelper.skipUnlessSupportsCapability("sts")
def testStsInvalidCertificate(self): def testStsInvalidCertificate(self):
# Connect client to insecure server # Connect client to insecure server
(hostname, port) = self.insecure_server.getsockname() (hostname, port) = self.insecure_server.getsockname()
self.controller.run( self.controller.run(
hostname=hostname, hostname=hostname,
port=port, port=port,
auth=None, auth=None,
) )
self.acceptClient(server=self.insecure_server) self.acceptClient(server=self.insecure_server)
# Send STS policy to client # Send STS policy to client
m = self.getMessage() m = self.getMessage()
self.assertEqual(m.command, 'CAP', self.assertEqual(m.command, "CAP", "First message is not CAP LS.")
'First message is not CAP LS.') self.assertEqual(m.params[0], "LS", "First message is not CAP LS.")
self.assertEqual(m.params[0], 'LS', self.sendLine("CAP * LS :sts=port={}".format(self.server.getsockname()[1]))
'First message is not CAP LS.')
self.sendLine('CAP * LS :sts=port={}'.format(self.server.getsockname()[1]))
# The client will reconnect to the TLS port. Unfortunately, it does # The client will reconnect to the TLS port. Unfortunately, it does
# not trust its fingerprint. # not trust its fingerprint.

View File

@ -43,45 +43,61 @@ TEMPLATE_SSL_CONFIG = """
class CharybdisController(BaseServerController, DirectoryBasedController): class CharybdisController(BaseServerController, DirectoryBasedController):
software_name = 'Charybdis' software_name = "Charybdis"
supported_sasl_mechanisms = set() supported_sasl_mechanisms = set()
supported_capabilities = set() # Not exhaustive supported_capabilities = set() # Not exhaustive
def create_config(self): def create_config(self):
super().create_config() super().create_config()
with self.open_file('server.conf'): with self.open_file("server.conf"):
pass pass
def run(self, hostname, port, password=None, ssl=False, def run(
valid_metadata_keys=None, invalid_metadata_keys=None): self,
hostname,
port,
password=None,
ssl=False,
valid_metadata_keys=None,
invalid_metadata_keys=None,
):
if valid_metadata_keys or invalid_metadata_keys: if valid_metadata_keys or invalid_metadata_keys:
raise NotImplementedByController( raise NotImplementedByController(
'Defining valid and invalid METADATA keys.') "Defining valid and invalid METADATA keys."
)
assert self.proc is None assert self.proc is None
self.create_config() self.create_config()
self.port = port self.port = port
password_field = 'password = "{}";'.format(password) if password else '' password_field = 'password = "{}";'.format(password) if password else ""
if ssl: if ssl:
self.gen_ssl() self.gen_ssl()
ssl_config = TEMPLATE_SSL_CONFIG.format( ssl_config = TEMPLATE_SSL_CONFIG.format(
key_path=self.key_path, key_path=self.key_path,
pem_path=self.pem_path, pem_path=self.pem_path,
dh_path=self.dh_path, dh_path=self.dh_path,
)
else:
ssl_config = ''
with self.open_file('server.conf') as fd:
fd.write(TEMPLATE_CONFIG.format(
hostname=hostname,
port=port,
password_field=password_field,
ssl_config=ssl_config,
))
self.proc = subprocess.Popen(['charybdis', '-foreground',
'-configfile', os.path.join(self.directory, 'server.conf'),
'-pidfile', os.path.join(self.directory, 'server.pid'),
],
stderr=subprocess.DEVNULL
) )
else:
ssl_config = ""
with self.open_file("server.conf") as fd:
fd.write(
TEMPLATE_CONFIG.format(
hostname=hostname,
port=port,
password_field=password_field,
ssl_config=ssl_config,
)
)
self.proc = subprocess.Popen(
[
"charybdis",
"-foreground",
"-configfile",
os.path.join(self.directory, "server.conf"),
"-pidfile",
os.path.join(self.directory, "server.pid"),
],
stderr=subprocess.DEVNULL,
)
def get_irctest_controller_class(): def get_irctest_controller_class():

View File

@ -2,9 +2,10 @@ import subprocess
from irctest.basecontrollers import BaseClientController, NotImplementedByController from irctest.basecontrollers import BaseClientController, NotImplementedByController
class GircController(BaseClientController): class GircController(BaseClientController):
software_name = 'gIRC' software_name = "gIRC"
supported_sasl_mechanisms = ['PLAIN'] supported_sasl_mechanisms = ["PLAIN"]
supported_capabilities = set() # Not exhaustive supported_capabilities = set() # Not exhaustive
def __init__(self): def __init__(self):
@ -30,16 +31,17 @@ class GircController(BaseClientController):
def run(self, hostname, port, auth, tls_config): def run(self, hostname, port, auth, tls_config):
if tls_config: if tls_config:
print(tls_config) print(tls_config)
raise NotImplementedByController('TLS options') raise NotImplementedByController("TLS options")
args = ['--host', hostname, '--port', str(port), '--quiet'] args = ["--host", hostname, "--port", str(port), "--quiet"]
if auth and auth.username and auth.password: if auth and auth.username and auth.password:
args += ['--sasl-name', auth.username] args += ["--sasl-name", auth.username]
args += ['--sasl-pass', auth.password] args += ["--sasl-pass", auth.password]
args += ['--sasl-fail-is-ok'] args += ["--sasl-fail-is-ok"]
# Runs a client with the config given as arguments # Runs a client with the config given as arguments
self.proc = subprocess.Popen(['girc_test', 'connect'] + args) self.proc = subprocess.Popen(["girc_test", "connect"] + args)
def get_irctest_controller_class(): def get_irctest_controller_class():
return GircController return GircController

View File

@ -41,47 +41,62 @@ TEMPLATE_SSL_CONFIG = """
class HybridController(BaseServerController, DirectoryBasedController): class HybridController(BaseServerController, DirectoryBasedController):
software_name = 'Hybrid' software_name = "Hybrid"
supported_sasl_mechanisms = set() supported_sasl_mechanisms = set()
supported_capabilities = set() # Not exhaustive supported_capabilities = set() # Not exhaustive
def create_config(self): def create_config(self):
super().create_config() super().create_config()
with self.open_file('server.conf'): with self.open_file("server.conf"):
pass pass
def run(self, hostname, port, password=None, ssl=False, def run(
valid_metadata_keys=None, invalid_metadata_keys=None): self,
hostname,
port,
password=None,
ssl=False,
valid_metadata_keys=None,
invalid_metadata_keys=None,
):
if valid_metadata_keys or invalid_metadata_keys: if valid_metadata_keys or invalid_metadata_keys:
raise NotImplementedByController( raise NotImplementedByController(
'Defining valid and invalid METADATA keys.') "Defining valid and invalid METADATA keys."
)
assert self.proc is None assert self.proc is None
self.create_config() self.create_config()
self.port = port self.port = port
password_field = 'password = "{}";'.format(password) if password else '' password_field = 'password = "{}";'.format(password) if password else ""
if ssl: if ssl:
self.gen_ssl() self.gen_ssl()
ssl_config = TEMPLATE_SSL_CONFIG.format( ssl_config = TEMPLATE_SSL_CONFIG.format(
key_path=self.key_path, key_path=self.key_path,
pem_path=self.pem_path, pem_path=self.pem_path,
dh_path=self.dh_path, dh_path=self.dh_path,
) )
else: else:
ssl_config = '' ssl_config = ""
with self.open_file('server.conf') as fd: with self.open_file("server.conf") as fd:
fd.write(TEMPLATE_CONFIG.format( fd.write(
hostname=hostname, TEMPLATE_CONFIG.format(
port=port, hostname=hostname,
password_field=password_field, port=port,
ssl_config=ssl_config, password_field=password_field,
)) ssl_config=ssl_config,
self.proc = subprocess.Popen(['ircd', '-foreground', )
'-configfile', os.path.join(self.directory, 'server.conf'), )
'-pidfile', os.path.join(self.directory, 'server.pid'), self.proc = subprocess.Popen(
[
"ircd",
"-foreground",
"-configfile",
os.path.join(self.directory, "server.conf"),
"-pidfile",
os.path.join(self.directory, "server.pid"),
], ],
stdout=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL stderr=subprocess.DEVNULL,
) )
def get_irctest_controller_class(): def get_irctest_controller_class():

View File

@ -27,47 +27,63 @@ TEMPLATE_SSL_CONFIG = """
<openssl certfile="{pem_path}" keyfile="{key_path}" dhfile="{dh_path}" hash="sha1"> <openssl certfile="{pem_path}" keyfile="{key_path}" dhfile="{dh_path}" hash="sha1">
""" """
class InspircdController(BaseServerController, DirectoryBasedController): class InspircdController(BaseServerController, DirectoryBasedController):
software_name = 'InspIRCd' software_name = "InspIRCd"
supported_sasl_mechanisms = set() supported_sasl_mechanisms = set()
supported_capabilities = set() # Not exhaustive supported_capabilities = set() # Not exhaustive
def create_config(self): def create_config(self):
super().create_config() super().create_config()
with self.open_file('server.conf'): with self.open_file("server.conf"):
pass pass
def run(self, hostname, port, password=None, ssl=False, def run(
restricted_metadata_keys=None, self,
valid_metadata_keys=None, invalid_metadata_keys=None): hostname,
port,
password=None,
ssl=False,
restricted_metadata_keys=None,
valid_metadata_keys=None,
invalid_metadata_keys=None,
):
if valid_metadata_keys or invalid_metadata_keys: if valid_metadata_keys or invalid_metadata_keys:
raise NotImplementedByController( raise NotImplementedByController(
'Defining valid and invalid METADATA keys.') "Defining valid and invalid METADATA keys."
)
assert self.proc is None assert self.proc is None
self.port = port self.port = port
self.create_config() self.create_config()
password_field = 'password="{}"'.format(password) if password else '' password_field = 'password="{}"'.format(password) if password else ""
if ssl: if ssl:
self.gen_ssl() self.gen_ssl()
ssl_config = TEMPLATE_SSL_CONFIG.format( ssl_config = TEMPLATE_SSL_CONFIG.format(
key_path=self.key_path, key_path=self.key_path,
pem_path=self.pem_path, pem_path=self.pem_path,
dh_path=self.dh_path, dh_path=self.dh_path,
)
else:
ssl_config = ''
with self.open_file('server.conf') as fd:
fd.write(TEMPLATE_CONFIG.format(
hostname=hostname,
port=port,
password_field=password_field,
ssl_config=ssl_config
))
self.proc = subprocess.Popen(['inspircd', '--nofork', '--config',
os.path.join(self.directory, 'server.conf')],
stdout=subprocess.DEVNULL
) )
else:
ssl_config = ""
with self.open_file("server.conf") as fd:
fd.write(
TEMPLATE_CONFIG.format(
hostname=hostname,
port=port,
password_field=password_field,
ssl_config=ssl_config,
)
)
self.proc = subprocess.Popen(
[
"inspircd",
"--nofork",
"--config",
os.path.join(self.directory, "server.conf"),
],
stdout=subprocess.DEVNULL,
)
def get_irctest_controller_class(): def get_irctest_controller_class():
return InspircdController return InspircdController

View File

@ -26,19 +26,23 @@ supybot.networks.testnet.sasl.ecdsa_key: {directory}/ecdsa_key.pem
supybot.networks.testnet.sasl.mechanisms: {mechanisms} supybot.networks.testnet.sasl.mechanisms: {mechanisms}
""" """
class LimnoriaController(BaseClientController, DirectoryBasedController): class LimnoriaController(BaseClientController, DirectoryBasedController):
software_name = 'Limnoria' software_name = "Limnoria"
supported_sasl_mechanisms = { supported_sasl_mechanisms = {
'PLAIN', 'ECDSA-NIST256P-CHALLENGE', 'SCRAM-SHA-256', 'EXTERNAL', "PLAIN",
} "ECDSA-NIST256P-CHALLENGE",
supported_capabilities = set(['sts']) # Not exhaustive "SCRAM-SHA-256",
"EXTERNAL",
}
supported_capabilities = set(["sts"]) # Not exhaustive
def create_config(self): def create_config(self):
create_config = super().create_config() create_config = super().create_config()
if create_config: if create_config:
with self.open_file('bot.conf'): with self.open_file("bot.conf"):
pass pass
with self.open_file('conf/users.conf'): with self.open_file("conf/users.conf"):
pass pass
def run(self, hostname, port, auth, tls_config=None): def run(self, hostname, port, auth, tls_config=None):
@ -48,27 +52,34 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
assert self.proc is None assert self.proc is None
self.create_config() self.create_config()
if auth: if auth:
mechanisms = ' '.join(map(authentication.Mechanisms.as_string, mechanisms = " ".join(
auth.mechanisms)) map(authentication.Mechanisms.as_string, auth.mechanisms)
)
if auth.ecdsa_key: if auth.ecdsa_key:
with self.open_file('ecdsa_key.pem') as fd: with self.open_file("ecdsa_key.pem") as fd:
fd.write(auth.ecdsa_key) fd.write(auth.ecdsa_key)
else: else:
mechanisms = '' mechanisms = ""
with self.open_file('bot.conf') as fd: with self.open_file("bot.conf") as fd:
fd.write(TEMPLATE_CONFIG.format( fd.write(
directory=self.directory, TEMPLATE_CONFIG.format(
loglevel='CRITICAL', directory=self.directory,
hostname=hostname, loglevel="CRITICAL",
port=port, hostname=hostname,
username=auth.username if auth else '', port=port,
password=auth.password if auth else '', username=auth.username if auth else "",
mechanisms=mechanisms.lower(), password=auth.password if auth else "",
enable_tls=tls_config.enable if tls_config else 'False', mechanisms=mechanisms.lower(),
trusted_fingerprints=' '.join(tls_config.trusted_fingerprints) if tls_config else '', enable_tls=tls_config.enable if tls_config else "False",
)) trusted_fingerprints=" ".join(tls_config.trusted_fingerprints)
self.proc = subprocess.Popen(['supybot', if tls_config
os.path.join(self.directory, 'bot.conf')]) else "",
)
)
self.proc = subprocess.Popen(
["supybot", os.path.join(self.directory, "bot.conf")]
)
def get_irctest_controller_class(): def get_irctest_controller_class():
return LimnoriaController return LimnoriaController

View File

@ -58,66 +58,84 @@ server:
recvq_len: 20 recvq_len: 20
""" """
def make_list(l): def make_list(l):
return '\n'.join(map(' - {}'.format, l)) return "\n".join(map(" - {}".format, l))
class MammonController(BaseServerController, DirectoryBasedController): class MammonController(BaseServerController, DirectoryBasedController):
software_name = 'Mammon' software_name = "Mammon"
supported_sasl_mechanisms = { supported_sasl_mechanisms = {
'PLAIN', 'ECDSA-NIST256P-CHALLENGE', "PLAIN",
} "ECDSA-NIST256P-CHALLENGE",
}
supported_capabilities = set() # Not exhaustive supported_capabilities = set() # Not exhaustive
def create_config(self): def create_config(self):
super().create_config() super().create_config()
with self.open_file('server.conf'): with self.open_file("server.conf"):
pass pass
def kill_proc(self): def kill_proc(self):
# Mammon does not seem to handle SIGTERM very well # Mammon does not seem to handle SIGTERM very well
self.proc.kill() self.proc.kill()
def run(self, hostname, port, password=None, ssl=False, def run(
restricted_metadata_keys=(), self,
valid_metadata_keys=(), invalid_metadata_keys=()): hostname,
port,
password=None,
ssl=False,
restricted_metadata_keys=(),
valid_metadata_keys=(),
invalid_metadata_keys=(),
):
if password is not None: if password is not None:
raise NotImplementedByController('PASS command') raise NotImplementedByController("PASS command")
if ssl: if ssl:
raise NotImplementedByController('SSL') raise NotImplementedByController("SSL")
assert self.proc is None assert self.proc is None
self.port = port self.port = port
self.create_config() self.create_config()
with self.open_file('server.yml') as fd: with self.open_file("server.yml") as fd:
fd.write(TEMPLATE_CONFIG.format( fd.write(
directory=self.directory, TEMPLATE_CONFIG.format(
hostname=hostname, directory=self.directory,
port=port, hostname=hostname,
authorized_keys=make_list(valid_metadata_keys), port=port,
restricted_keys=make_list(restricted_metadata_keys), authorized_keys=make_list(valid_metadata_keys),
)) restricted_keys=make_list(restricted_metadata_keys),
#with self.open_file('server.yml', 'r') as fd: )
)
# with self.open_file('server.yml', 'r') as fd:
# print(fd.read()) # print(fd.read())
self.proc = subprocess.Popen(['mammond', '--nofork', #'--debug', self.proc = subprocess.Popen(
'--config', os.path.join(self.directory, 'server.yml')]) [
"mammond",
"--nofork", #'--debug',
"--config",
os.path.join(self.directory, "server.yml"),
]
)
def registerUser(self, case, username, password=None): def registerUser(self, case, username, password=None):
# XXX: Move this somewhere else when # XXX: Move this somewhere else when
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes # https://github.com/ircv3/ircv3-specifications/pull/152 becomes
# part of the specification # part of the specification
client = case.addClient(show_io=False) client = case.addClient(show_io=False)
case.sendLine(client, 'CAP LS 302') case.sendLine(client, "CAP LS 302")
case.sendLine(client, 'NICK registration_user') case.sendLine(client, "NICK registration_user")
case.sendLine(client, 'USER r e g :user') case.sendLine(client, "USER r e g :user")
case.sendLine(client, 'CAP END') case.sendLine(client, "CAP END")
while case.getRegistrationMessage(client).command != '001': while case.getRegistrationMessage(client).command != "001":
pass pass
list(case.getMessages(client)) list(case.getMessages(client))
case.sendLine(client, 'REG CREATE {} passphrase {}'.format( case.sendLine(client, "REG CREATE {} passphrase {}".format(username, password))
username, password))
msg = case.getMessage(client) msg = case.getMessage(client)
assert msg.command == '920', msg assert msg.command == "920", msg
list(case.getMessages(client)) list(case.getMessages(client))
case.removeClient(client) case.removeClient(client)
def get_irctest_controller_class(): def get_irctest_controller_class():
return MammonController return MammonController

View File

@ -6,13 +6,12 @@ import subprocess
from irctest.basecontrollers import NotImplementedByController from irctest.basecontrollers import NotImplementedByController
from irctest.basecontrollers import BaseServerController, DirectoryBasedController from irctest.basecontrollers import BaseServerController, DirectoryBasedController
OPER_PWD = 'frenchfries' OPER_PWD = "frenchfries"
BASE_CONFIG = { BASE_CONFIG = {
"network": { "network": {
"name": "OragonoTest", "name": "OragonoTest",
}, },
"server": { "server": {
"name": "oragono.test", "name": "oragono.test",
"listeners": {}, "listeners": {},
@ -35,140 +34,153 @@ BASE_CONFIG = {
"ban-message": "Try again later", "ban-message": "Try again later",
"exempted": ["localhost"], "exempted": ["localhost"],
}, },
'enforce-utf8': True, "enforce-utf8": True,
'relaymsg': { "relaymsg": {
'enabled': True, "enabled": True,
'separators': '/', "separators": "/",
'available-to-chanops': True, "available-to-chanops": True,
}, },
}, },
"accounts": {
'accounts': { "authentication-enabled": True,
'authentication-enabled': True, "multiclient": {
'multiclient': { "allowed-by-default": True,
'allowed-by-default': True, "enabled": True,
'enabled': True, "always-on": "disabled",
'always-on': 'disabled',
}, },
'registration': { "registration": {
'bcrypt-cost': 4, "bcrypt-cost": 4,
'enabled': True, "enabled": True,
'enabled-callbacks': ['none'], "enabled-callbacks": ["none"],
'verify-timeout': '120h', "verify-timeout": "120h",
}, },
'nick-reservation': { "nick-reservation": {
'enabled': True, "enabled": True,
'additional-nick-limit': 2, "additional-nick-limit": 2,
'method': 'strict', "method": "strict",
}, },
}, },
"channels": {
"channels": { "registration": {
"registration": {"enabled": True,}, "enabled": True,
}, },
},
"datastore": { "datastore": {
"path": None, "path": None,
}, },
"limits": {
'limits': { "awaylen": 200,
'awaylen': 200, "chan-list-modes": 60,
'chan-list-modes': 60, "channellen": 64,
'channellen': 64, "kicklen": 390,
'kicklen': 390, "linelen": {
'linelen': {'rest': 2048,}, "rest": 2048,
'monitor-entries': 100, },
'nicklen': 32, "monitor-entries": 100,
'topiclen': 390, "nicklen": 32,
'whowas-entries': 100, "topiclen": 390,
'multiline': {'max-bytes': 4096, 'max-lines': 32,}, "whowas-entries": 100,
}, "multiline": {
"max-bytes": 4096,
"history": { "max-lines": 32,
"enabled": True, },
"channel-length": 128, },
"client-length": 128, "history": {
"chathistory-maxmessages": 100, "enabled": True,
"tagmsg-storage": { "channel-length": 128,
"default": False, "client-length": 128,
"whitelist": ["+draft/persist", "+persist"], "chathistory-maxmessages": 100,
}, "tagmsg-storage": {
}, "default": False,
"whitelist": ["+draft/persist", "+persist"],
'oper-classes': { },
'server-admin': { },
'title': 'Server Admin', "oper-classes": {
'capabilities': [ "server-admin": {
"oper:local_kill", "title": "Server Admin",
"oper:local_ban", "capabilities": [
"oper:local_unban", "oper:local_kill",
"nofakelag", "oper:local_ban",
"oper:remote_kill", "oper:local_unban",
"oper:remote_ban", "nofakelag",
"oper:remote_unban", "oper:remote_kill",
"oper:rehash", "oper:remote_ban",
"oper:die", "oper:remote_unban",
"accreg", "oper:rehash",
"sajoin", "oper:die",
"samode", "accreg",
"vhosts", "sajoin",
"chanreg", "samode",
"relaymsg", "vhosts",
"chanreg",
"relaymsg",
], ],
}, },
}, },
"opers": {
'opers': { "root": {
'root': { "class": "server-admin",
'class': 'server-admin', "whois-line": "is a server admin",
'whois-line': 'is a server admin',
# OPER_PWD # OPER_PWD
'password': '$2a$04$3GzUZB5JapaAbwn7sogpOu9NSiLOgnozVllm2e96LiNPrm61ZsZSq', "password": "$2a$04$3GzUZB5JapaAbwn7sogpOu9NSiLOgnozVllm2e96LiNPrm61ZsZSq",
}, },
}, },
} }
LOGGING_CONFIG = { LOGGING_CONFIG = {
"logging": [ "logging": [
{ {
"method": "stderr", "method": "stderr",
"level": "debug", "level": "debug",
"type": "*", "type": "*",
}, },
] ]
} }
def hash_password(password): def hash_password(password):
if isinstance(password, str): if isinstance(password, str):
password = password.encode('utf-8') password = password.encode("utf-8")
# simulate entry of password and confirmation: # simulate entry of password and confirmation:
input_ = password + b'\n' + password + b'\n' input_ = password + b"\n" + password + b"\n"
p = subprocess.Popen(['oragono', 'genpasswd'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) p = subprocess.Popen(
["oragono", "genpasswd"], stdin=subprocess.PIPE, stdout=subprocess.PIPE
)
out, _ = p.communicate(input_) out, _ = p.communicate(input_)
return out.decode('utf-8') return out.decode("utf-8")
class OragonoController(BaseServerController, DirectoryBasedController): class OragonoController(BaseServerController, DirectoryBasedController):
software_name = 'Oragono' software_name = "Oragono"
supported_sasl_mechanisms = { supported_sasl_mechanisms = {
'PLAIN', "PLAIN",
} }
_port_wait_interval = .01 _port_wait_interval = 0.01
supported_capabilities = set() # Not exhaustive supported_capabilities = set() # Not exhaustive
def create_config(self): def create_config(self):
super().create_config() super().create_config()
with self.open_file('ircd.yaml'): with self.open_file("ircd.yaml"):
pass pass
def kill_proc(self): def kill_proc(self):
self.proc.kill() self.proc.kill()
def run(self, hostname, port, password=None, ssl=False, def run(
restricted_metadata_keys=None, self,
valid_metadata_keys=None, invalid_metadata_keys=None, config=None): hostname,
port,
password=None,
ssl=False,
restricted_metadata_keys=None,
valid_metadata_keys=None,
invalid_metadata_keys=None,
config=None,
):
if valid_metadata_keys or invalid_metadata_keys: if valid_metadata_keys or invalid_metadata_keys:
raise NotImplementedByController( raise NotImplementedByController(
'Defining valid and invalid METADATA keys.') "Defining valid and invalid METADATA keys."
)
self.create_config() self.create_config()
if config is None: if config is None:
@ -180,59 +192,60 @@ class OragonoController(BaseServerController, DirectoryBasedController):
config = self.addMysqlToConfig(config) config = self.addMysqlToConfig(config)
if enable_roleplay: if enable_roleplay:
config['roleplay'] = { config["roleplay"] = {
'enabled': True, "enabled": True,
} }
if 'oragono_config' in self.test_config: if "oragono_config" in self.test_config:
self.test_config['oragono_config'](config) self.test_config["oragono_config"](config)
self.port = port self.port = port
bind_address = "127.0.0.1:%s" % (port,) bind_address = "127.0.0.1:%s" % (port,)
listener_conf = None # plaintext listener_conf = None # plaintext
if ssl: if ssl:
self.key_path = os.path.join(self.directory, 'ssl.key') self.key_path = os.path.join(self.directory, "ssl.key")
self.pem_path = os.path.join(self.directory, 'ssl.pem') self.pem_path = os.path.join(self.directory, "ssl.pem")
listener_conf = {"tls": {"cert": self.pem_path, "key": self.key_path},} listener_conf = {
config['server']['listeners'][bind_address] = listener_conf "tls": {"cert": self.pem_path, "key": self.key_path},
}
config["server"]["listeners"][bind_address] = listener_conf
config['datastore']['path'] = os.path.join(self.directory, 'ircd.db') config["datastore"]["path"] = os.path.join(self.directory, "ircd.db")
if password is not None: if password is not None:
config['server']['password'] = hash_password(password) config["server"]["password"] = hash_password(password)
assert self.proc is None assert self.proc is None
self._config_path = os.path.join(self.directory, 'server.yml') self._config_path = os.path.join(self.directory, "server.yml")
self._config = config self._config = config
self._write_config() self._write_config()
subprocess.call(['oragono', 'initdb', subprocess.call(["oragono", "initdb", "--conf", self._config_path, "--quiet"])
'--conf', self._config_path, '--quiet']) subprocess.call(["oragono", "mkcerts", "--conf", self._config_path, "--quiet"])
subprocess.call(['oragono', 'mkcerts', self.proc = subprocess.Popen(
'--conf', self._config_path, '--quiet']) ["oragono", "run", "--conf", self._config_path, "--quiet"]
self.proc = subprocess.Popen(['oragono', 'run', )
'--conf', self._config_path, '--quiet'])
def registerUser(self, case, username, password=None): def registerUser(self, case, username, password=None):
# XXX: Move this somewhere else when # XXX: Move this somewhere else when
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes # https://github.com/ircv3/ircv3-specifications/pull/152 becomes
# part of the specification # part of the specification
client = case.addClient(show_io=False) client = case.addClient(show_io=False)
case.sendLine(client, 'CAP LS 302') case.sendLine(client, "CAP LS 302")
case.sendLine(client, 'NICK ' + username) case.sendLine(client, "NICK " + username)
case.sendLine(client, 'USER r e g :user') case.sendLine(client, "USER r e g :user")
case.sendLine(client, 'CAP END') case.sendLine(client, "CAP END")
while case.getRegistrationMessage(client).command != '001': while case.getRegistrationMessage(client).command != "001":
pass pass
case.getMessages(client) case.getMessages(client)
case.sendLine(client, 'NS REGISTER ' + password) case.sendLine(client, "NS REGISTER " + password)
msg = case.getMessage(client) msg = case.getMessage(client)
assert msg.params == [username, 'Account created'] assert msg.params == [username, "Account created"]
case.sendLine(client, 'QUIT') case.sendLine(client, "QUIT")
case.assertDisconnected(client) case.assertDisconnected(client)
def _write_config(self): def _write_config(self):
with open(self._config_path, 'w') as fd: with open(self._config_path, "w") as fd:
json.dump(self._config, fd) json.dump(self._config, fd)
def baseConfig(self): def baseConfig(self):
@ -248,25 +261,25 @@ class OragonoController(BaseServerController, DirectoryBasedController):
return config return config
def addMysqlToConfig(self, config=None): def addMysqlToConfig(self, config=None):
mysql_password = os.getenv('MYSQL_PASSWORD') mysql_password = os.getenv("MYSQL_PASSWORD")
if not mysql_password: if not mysql_password:
return config return config
if config is None: if config is None:
config = self.baseConfig() config = self.baseConfig()
config['datastore']['mysql'] = { config["datastore"]["mysql"] = {
"enabled": True, "enabled": True,
"host": "localhost", "host": "localhost",
"user": "oragono", "user": "oragono",
"password": mysql_password, "password": mysql_password,
"history-database": "oragono_history", "history-database": "oragono_history",
"timeout": "3s", "timeout": "3s",
} }
config['accounts']['multiclient'] = { config["accounts"]["multiclient"] = {
'enabled': True, "enabled": True,
'allowed-by-default': True, "allowed-by-default": True,
'always-on': 'disabled', "always-on": "disabled",
} }
config['history']['persistent'] = { config["history"]["persistent"] = {
"enabled": True, "enabled": True,
"unregistered-channels": True, "unregistered-channels": True,
"registered-channels": "opt-out", "registered-channels": "opt-out",
@ -277,12 +290,12 @@ class OragonoController(BaseServerController, DirectoryBasedController):
def rehash(self, case, config): def rehash(self, case, config):
self._config = config self._config = config
self._write_config() self._write_config()
client = 'operator_for_rehash' client = "operator_for_rehash"
case.connectClient(nick=client, name=client) case.connectClient(nick=client, name=client)
case.sendLine(client, 'OPER root %s' % (OPER_PWD,)) case.sendLine(client, "OPER root %s" % (OPER_PWD,))
case.sendLine(client, 'REHASH') case.sendLine(client, "REHASH")
case.getMessages(client) case.getMessages(client)
case.sendLine(client, 'QUIT') case.sendLine(client, "QUIT")
case.assertDisconnected(client) case.assertDisconnected(client)
def enable_debug_logging(self, case): def enable_debug_logging(self, case):
@ -290,5 +303,6 @@ class OragonoController(BaseServerController, DirectoryBasedController):
config.update(LOGGING_CONFIG) config.update(LOGGING_CONFIG)
self.rehash(case, config) self.rehash(case, config)
def get_irctest_controller_class(): def get_irctest_controller_class():
return OragonoController return OragonoController

View File

@ -19,30 +19,30 @@ auth_password = {password}
{auth_method} {auth_method}
""" """
class SopelController(BaseClientController): class SopelController(BaseClientController):
software_name = 'Sopel' software_name = "Sopel"
supported_sasl_mechanisms = { supported_sasl_mechanisms = {
'PLAIN', "PLAIN",
} }
supported_capabilities = set() # Not exhaustive supported_capabilities = set() # Not exhaustive
def __init__(self, test_config): def __init__(self, test_config):
super().__init__(test_config) super().__init__(test_config)
self.filename = next(tempfile._get_candidate_names()) + '.cfg' self.filename = next(tempfile._get_candidate_names()) + ".cfg"
self.proc = None self.proc = None
def kill(self): def kill(self):
if self.proc: if self.proc:
self.proc.kill() self.proc.kill()
if self.filename: if self.filename:
try: try:
os.unlink(os.path.join(os.path.expanduser('~/.sopel/'), os.unlink(os.path.join(os.path.expanduser("~/.sopel/"), self.filename))
self.filename)) except OSError: #  File does not exist
except OSError: # File does not exist
pass pass
def open_file(self, filename, mode='a'): def open_file(self, filename, mode="a"):
return open(os.path.join(os.path.expanduser('~/.sopel/'), filename), return open(os.path.join(os.path.expanduser("~/.sopel/"), filename), mode)
mode)
def create_config(self): def create_config(self):
with self.open_file(self.filename) as fd: with self.open_file(self.filename) as fd:
@ -51,20 +51,21 @@ class SopelController(BaseClientController):
def run(self, hostname, port, auth, tls_config): def run(self, hostname, port, auth, tls_config):
# Runs a client with the config given as arguments # Runs a client with the config given as arguments
if tls_config is not None: if tls_config is not None:
raise NotImplementedByController( raise NotImplementedByController("TLS configuration")
'TLS configuration')
assert self.proc is None assert self.proc is None
self.create_config() self.create_config()
with self.open_file(self.filename) as fd: with self.open_file(self.filename) as fd:
fd.write(TEMPLATE_CONFIG.format( fd.write(
hostname=hostname, TEMPLATE_CONFIG.format(
port=port, hostname=hostname,
username=auth.username if auth else '', port=port,
password=auth.password if auth else '', username=auth.username if auth else "",
auth_method='auth_method = sasl' if auth else '', password=auth.password if auth else "",
)) auth_method="auth_method = sasl" if auth else "",
self.proc = subprocess.Popen(['sopel', '--quiet', '-c', self.filename]) )
)
self.proc = subprocess.Popen(["sopel", "--quiet", "-c", self.filename])
def get_irctest_controller_class(): def get_irctest_controller_class():
return SopelController return SopelController

View File

@ -1,6 +1,6 @@
class NoMessageException(AssertionError): class NoMessageException(AssertionError):
pass pass
class ConnectionClosed(Exception): class ConnectionClosed(Exception):
pass pass

View File

@ -2,6 +2,7 @@
Handles ambiguities of RFCs. Handles ambiguities of RFCs.
""" """
def normalize_namreply_params(params): def normalize_namreply_params(params):
# So… RFC 2812 says: # So… RFC 2812 says:
# "( "=" / "*" / "@" ) <channel> # "( "=" / "*" / "@" ) <channel>
@ -12,7 +13,7 @@ def normalize_namreply_params(params):
# So let's normalize this to “with space”, and strip spaces at the # So let's normalize this to “with space”, and strip spaces at the
# end of the nick list. # end of the nick list.
if len(params) == 3: if len(params) == 3:
assert params[1][0] in '=*@', params assert params[1][0] in "=*@", params
params.insert(1), params[1][0] params.insert(1), params[1][0]
params[2] = params[2][1:] params[2] = params[2][1:]
params[3] = params[3].rstrip() params[3] = params[3].rstrip()

View File

@ -1,8 +1,8 @@
def cap_list_to_dict(l): def cap_list_to_dict(l):
d = {} d = {}
for cap in l: for cap in l:
if '=' in cap: if "=" in cap:
(key, value) = cap.split('=', 1) (key, value) = cap.split("=", 1)
else: else:
key = cap key = cap
value = None value = None

View File

@ -3,24 +3,35 @@ import re
import secrets import secrets
from collections import namedtuple from collections import namedtuple
HistoryMessage = namedtuple('HistoryMessage', ['time', 'msgid', 'target', 'text']) HistoryMessage = namedtuple("HistoryMessage", ["time", "msgid", "target", "text"])
def to_history_message(msg): def to_history_message(msg):
return HistoryMessage(time=msg.tags.get('time'), msgid=msg.tags.get('msgid'), target=msg.params[0], text=msg.params[1]) return HistoryMessage(
time=msg.tags.get("time"),
msgid=msg.tags.get("msgid"),
target=msg.params[0],
text=msg.params[1],
)
# thanks jess! # thanks jess!
IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z" IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z"
def ircv3_timestamp_to_unixtime(timestamp): def ircv3_timestamp_to_unixtime(timestamp):
return datetime.datetime.strptime(timestamp, IRCV3_FORMAT_STRFTIME).timestamp() return datetime.datetime.strptime(timestamp, IRCV3_FORMAT_STRFTIME).timestamp()
def random_name(base): def random_name(base):
return base + '-' + secrets.token_hex(8) return base + "-" + secrets.token_hex(8)
""" """
Stolen from supybot: Stolen from supybot:
""" """
class MultipleReplacer: class MultipleReplacer:
"""Return a callable that replaces all dict keys by the associated """Return a callable that replaces all dict keys by the associated
value. More efficient than multiple .replace().""" value. More efficient than multiple .replace()."""
@ -30,24 +41,26 @@ class MultipleReplacer:
# it to a class in Python 3. # it to a class in Python 3.
def __init__(self, dict_): def __init__(self, dict_):
self._dict = dict_ self._dict = dict_
dict_ = dict([(re.escape(key), val) for key,val in dict_.items()]) dict_ = dict([(re.escape(key), val) for key, val in dict_.items()])
self._matcher = re.compile('|'.join(dict_.keys())) self._matcher = re.compile("|".join(dict_.keys()))
def __call__(self, s): def __call__(self, s):
return self._matcher.sub(lambda m: self._dict[m.group(0)], s) return self._matcher.sub(lambda m: self._dict[m.group(0)], s)
def normalizeWhitespace(s, removeNewline=True): def normalizeWhitespace(s, removeNewline=True):
r"""Normalizes the whitespace in a string; \s+ becomes one space.""" r"""Normalizes the whitespace in a string; \s+ becomes one space."""
if not s: if not s:
return str(s) # not the same reference return str(s) # not the same reference
starts_with_space = (s[0] in ' \n\t\r') starts_with_space = s[0] in " \n\t\r"
ends_with_space = (s[-1] in ' \n\t\r') ends_with_space = s[-1] in " \n\t\r"
if removeNewline: if removeNewline:
newline_re = re.compile('[\r\n]+') newline_re = re.compile("[\r\n]+")
s = ' '.join(filter(bool, newline_re.split(s))) s = " ".join(filter(bool, newline_re.split(s)))
s = ' '.join(filter(bool, s.split('\t'))) s = " ".join(filter(bool, s.split("\t")))
s = ' '.join(filter(bool, s.split(' '))) s = " ".join(filter(bool, s.split(" ")))
if starts_with_space: if starts_with_space:
s = ' ' + s s = " " + s
if ends_with_space: if ends_with_space:
s += ' ' s += " "
return s return s

View File

@ -5,58 +5,58 @@ from .junkdrawer import MultipleReplacer
# http://ircv3.net/specs/core/message-tags-3.2.html#escaping-values # http://ircv3.net/specs/core/message-tags-3.2.html#escaping-values
TAG_ESCAPE = [ TAG_ESCAPE = [
('\\', '\\\\'), # \ -> \\ ("\\", "\\\\"), # \ -> \\
(' ', r'\s'), (" ", r"\s"),
(';', r'\:'), (";", r"\:"),
('\r', r'\r'), ("\r", r"\r"),
('\n', r'\n'), ("\n", r"\n"),
] ]
unescape_tag_value = MultipleReplacer( unescape_tag_value = MultipleReplacer(dict(map(lambda x: (x[1], x[0]), TAG_ESCAPE)))
dict(map(lambda x:(x[1],x[0]), TAG_ESCAPE)))
# TODO: validate host # TODO: validate host
tag_key_validator = re.compile(r'\+?(\S+/)?[a-zA-Z0-9-]+') tag_key_validator = re.compile(r"\+?(\S+/)?[a-zA-Z0-9-]+")
def parse_tags(s): def parse_tags(s):
tags = {} tags = {}
for tag in s.split(';'): for tag in s.split(";"):
if '=' not in tag: if "=" not in tag:
tags[tag] = None tags[tag] = None
else: else:
(key, value) = tag.split('=', 1) (key, value) = tag.split("=", 1)
assert tag_key_validator.match(key), \ assert tag_key_validator.match(key), "Invalid tag key: {}".format(key)
'Invalid tag key: {}'.format(key)
tags[key] = unescape_tag_value(value) tags[key] = unescape_tag_value(value)
return tags return tags
Message = collections.namedtuple('Message',
'tags prefix command params') Message = collections.namedtuple("Message", "tags prefix command params")
def parse_message(s): def parse_message(s):
"""Parse a message according to """Parse a message according to
http://tools.ietf.org/html/rfc1459#section-2.3.1 http://tools.ietf.org/html/rfc1459#section-2.3.1
and and
http://ircv3.net/specs/core/message-tags-3.2.html""" http://ircv3.net/specs/core/message-tags-3.2.html"""
s = s.rstrip('\r\n') s = s.rstrip("\r\n")
if s.startswith('@'): if s.startswith("@"):
(tags, s) = s.split(' ', 1) (tags, s) = s.split(" ", 1)
tags = parse_tags(tags[1:]) tags = parse_tags(tags[1:])
else: else:
tags = {} tags = {}
if ' :' in s: if " :" in s:
(other_tokens, trailing_param) = s.split(' :', 1) (other_tokens, trailing_param) = s.split(" :", 1)
tokens = list(filter(bool, other_tokens.split(' '))) + [trailing_param] tokens = list(filter(bool, other_tokens.split(" "))) + [trailing_param]
else: else:
tokens = list(filter(bool, s.split(' '))) tokens = list(filter(bool, s.split(" ")))
if tokens[0].startswith(':'): if tokens[0].startswith(":"):
prefix = tokens.pop(0)[1:] prefix = tokens.pop(0)[1:]
else: else:
prefix = None prefix = None
command = tokens.pop(0) command = tokens.pop(0)
params = tokens params = tokens
return Message( return Message(
tags=tags, tags=tags,
prefix=prefix, prefix=prefix,
command=command, command=command,
params=params, params=params,
) )

View File

@ -1,6 +1,15 @@
import base64 import base64
def sasl_plain_blob(username, passphrase): def sasl_plain_blob(username, passphrase):
blob = base64.b64encode(b'\x00'.join((username.encode('utf-8'), username.encode('utf-8'), passphrase.encode('utf-8')))) blob = base64.b64encode(
blobstr = blob.decode('ascii') b"\x00".join(
return f'AUTHENTICATE {blobstr}' (
username.encode("utf-8"),
username.encode("utf-8"),
passphrase.encode("utf-8"),
)
)
)
blobstr = blob.decode("ascii")
return f"AUTHENTICATE {blobstr}"

View File

@ -9,191 +9,191 @@
# They're intended to represent a relatively-standard cross-section of the IRC # They're intended to represent a relatively-standard cross-section of the IRC
# server ecosystem out there. Custom numerics will be marked as such. # server ecosystem out there. Custom numerics will be marked as such.
RPL_WELCOME = "001" RPL_WELCOME = "001"
RPL_YOURHOST = "002" RPL_YOURHOST = "002"
RPL_CREATED = "003" RPL_CREATED = "003"
RPL_MYINFO = "004" RPL_MYINFO = "004"
RPL_ISUPPORT = "005" RPL_ISUPPORT = "005"
RPL_SNOMASKIS = "008" RPL_SNOMASKIS = "008"
RPL_BOUNCE = "010" RPL_BOUNCE = "010"
RPL_TRACELINK = "200" RPL_TRACELINK = "200"
RPL_TRACECONNECTING = "201" RPL_TRACECONNECTING = "201"
RPL_TRACEHANDSHAKE = "202" RPL_TRACEHANDSHAKE = "202"
RPL_TRACEUNKNOWN = "203" RPL_TRACEUNKNOWN = "203"
RPL_TRACEOPERATOR = "204" RPL_TRACEOPERATOR = "204"
RPL_TRACEUSER = "205" RPL_TRACEUSER = "205"
RPL_TRACESERVER = "206" RPL_TRACESERVER = "206"
RPL_TRACESERVICE = "207" RPL_TRACESERVICE = "207"
RPL_TRACENEWTYPE = "208" RPL_TRACENEWTYPE = "208"
RPL_TRACECLASS = "209" RPL_TRACECLASS = "209"
RPL_TRACERECONNECT = "210" RPL_TRACERECONNECT = "210"
RPL_STATSLINKINFO = "211" RPL_STATSLINKINFO = "211"
RPL_STATSCOMMANDS = "212" RPL_STATSCOMMANDS = "212"
RPL_ENDOFSTATS = "219" RPL_ENDOFSTATS = "219"
RPL_UMODEIS = "221" RPL_UMODEIS = "221"
RPL_SERVLIST = "234" RPL_SERVLIST = "234"
RPL_SERVLISTEND = "235" RPL_SERVLISTEND = "235"
RPL_STATSUPTIME = "242" RPL_STATSUPTIME = "242"
RPL_STATSOLINE = "243" RPL_STATSOLINE = "243"
RPL_LUSERCLIENT = "251" RPL_LUSERCLIENT = "251"
RPL_LUSEROP = "252" RPL_LUSEROP = "252"
RPL_LUSERUNKNOWN = "253" RPL_LUSERUNKNOWN = "253"
RPL_LUSERCHANNELS = "254" RPL_LUSERCHANNELS = "254"
RPL_LUSERME = "255" RPL_LUSERME = "255"
RPL_ADMINME = "256" RPL_ADMINME = "256"
RPL_ADMINLOC1 = "257" RPL_ADMINLOC1 = "257"
RPL_ADMINLOC2 = "258" RPL_ADMINLOC2 = "258"
RPL_ADMINEMAIL = "259" RPL_ADMINEMAIL = "259"
RPL_TRACELOG = "261" RPL_TRACELOG = "261"
RPL_TRACEEND = "262" RPL_TRACEEND = "262"
RPL_TRYAGAIN = "263" RPL_TRYAGAIN = "263"
RPL_LOCALUSERS = "265" RPL_LOCALUSERS = "265"
RPL_GLOBALUSERS = "266" RPL_GLOBALUSERS = "266"
RPL_WHOISCERTFP = "276" RPL_WHOISCERTFP = "276"
RPL_AWAY = "301" RPL_AWAY = "301"
RPL_USERHOST = "302" RPL_USERHOST = "302"
RPL_ISON = "303" RPL_ISON = "303"
RPL_UNAWAY = "305" RPL_UNAWAY = "305"
RPL_NOWAWAY = "306" RPL_NOWAWAY = "306"
RPL_WHOISUSER = "311" RPL_WHOISUSER = "311"
RPL_WHOISSERVER = "312" RPL_WHOISSERVER = "312"
RPL_WHOISOPERATOR = "313" RPL_WHOISOPERATOR = "313"
RPL_WHOWASUSER = "314" RPL_WHOWASUSER = "314"
RPL_ENDOFWHO = "315" RPL_ENDOFWHO = "315"
RPL_WHOISIDLE = "317" RPL_WHOISIDLE = "317"
RPL_ENDOFWHOIS = "318" RPL_ENDOFWHOIS = "318"
RPL_WHOISCHANNELS = "319" RPL_WHOISCHANNELS = "319"
RPL_LIST = "322" RPL_LIST = "322"
RPL_LISTEND = "323" RPL_LISTEND = "323"
RPL_CHANNELMODEIS = "324" RPL_CHANNELMODEIS = "324"
RPL_UNIQOPIS = "325" RPL_UNIQOPIS = "325"
RPL_CHANNELCREATED = "329" RPL_CHANNELCREATED = "329"
RPL_WHOISACCOUNT = "330" RPL_WHOISACCOUNT = "330"
RPL_NOTOPIC = "331" RPL_NOTOPIC = "331"
RPL_TOPIC = "332" RPL_TOPIC = "332"
RPL_TOPICTIME = "333" RPL_TOPICTIME = "333"
RPL_WHOISBOT = "335" RPL_WHOISBOT = "335"
RPL_WHOISACTUALLY = "338" RPL_WHOISACTUALLY = "338"
RPL_INVITING = "341" RPL_INVITING = "341"
RPL_SUMMONING = "342" RPL_SUMMONING = "342"
RPL_INVITELIST = "346" RPL_INVITELIST = "346"
RPL_ENDOFINVITELIST = "347" RPL_ENDOFINVITELIST = "347"
RPL_EXCEPTLIST = "348" RPL_EXCEPTLIST = "348"
RPL_ENDOFEXCEPTLIST = "349" RPL_ENDOFEXCEPTLIST = "349"
RPL_VERSION = "351" RPL_VERSION = "351"
RPL_WHOREPLY = "352" RPL_WHOREPLY = "352"
RPL_NAMREPLY = "353" RPL_NAMREPLY = "353"
RPL_LINKS = "364" RPL_LINKS = "364"
RPL_ENDOFLINKS = "365" RPL_ENDOFLINKS = "365"
RPL_ENDOFNAMES = "366" RPL_ENDOFNAMES = "366"
RPL_BANLIST = "367" RPL_BANLIST = "367"
RPL_ENDOFBANLIST = "368" RPL_ENDOFBANLIST = "368"
RPL_ENDOFWHOWAS = "369" RPL_ENDOFWHOWAS = "369"
RPL_INFO = "371" RPL_INFO = "371"
RPL_MOTD = "372" RPL_MOTD = "372"
RPL_ENDOFINFO = "374" RPL_ENDOFINFO = "374"
RPL_MOTDSTART = "375" RPL_MOTDSTART = "375"
RPL_ENDOFMOTD = "376" RPL_ENDOFMOTD = "376"
RPL_YOUREOPER = "381" RPL_YOUREOPER = "381"
RPL_REHASHING = "382" RPL_REHASHING = "382"
RPL_YOURESERVICE = "383" RPL_YOURESERVICE = "383"
RPL_TIME = "391" RPL_TIME = "391"
RPL_USERSSTART = "392" RPL_USERSSTART = "392"
RPL_USERS = "393" RPL_USERS = "393"
RPL_ENDOFUSERS = "394" RPL_ENDOFUSERS = "394"
RPL_NOUSERS = "395" RPL_NOUSERS = "395"
ERR_UNKNOWNERROR = "400" ERR_UNKNOWNERROR = "400"
ERR_NOSUCHNICK = "401" ERR_NOSUCHNICK = "401"
ERR_NOSUCHSERVER = "402" ERR_NOSUCHSERVER = "402"
ERR_NOSUCHCHANNEL = "403" ERR_NOSUCHCHANNEL = "403"
ERR_CANNOTSENDTOCHAN = "404" ERR_CANNOTSENDTOCHAN = "404"
ERR_TOOMANYCHANNELS = "405" ERR_TOOMANYCHANNELS = "405"
ERR_WASNOSUCHNICK = "406" ERR_WASNOSUCHNICK = "406"
ERR_TOOMANYTARGETS = "407" ERR_TOOMANYTARGETS = "407"
ERR_NOSUCHSERVICE = "408" ERR_NOSUCHSERVICE = "408"
ERR_NOORIGIN = "409" ERR_NOORIGIN = "409"
ERR_INVALIDCAPCMD = "410" ERR_INVALIDCAPCMD = "410"
ERR_NORECIPIENT = "411" ERR_NORECIPIENT = "411"
ERR_NOTEXTTOSEND = "412" ERR_NOTEXTTOSEND = "412"
ERR_NOTOPLEVEL = "413" ERR_NOTOPLEVEL = "413"
ERR_WILDTOPLEVEL = "414" ERR_WILDTOPLEVEL = "414"
ERR_BADMASK = "415" ERR_BADMASK = "415"
ERR_INPUTTOOLONG = "417" ERR_INPUTTOOLONG = "417"
ERR_UNKNOWNCOMMAND = "421" ERR_UNKNOWNCOMMAND = "421"
ERR_NOMOTD = "422" ERR_NOMOTD = "422"
ERR_NOADMININFO = "423" ERR_NOADMININFO = "423"
ERR_FILEERROR = "424" ERR_FILEERROR = "424"
ERR_NONICKNAMEGIVEN = "431" ERR_NONICKNAMEGIVEN = "431"
ERR_ERRONEUSNICKNAME = "432" ERR_ERRONEUSNICKNAME = "432"
ERR_NICKNAMEINUSE = "433" ERR_NICKNAMEINUSE = "433"
ERR_NICKCOLLISION = "436" ERR_NICKCOLLISION = "436"
ERR_UNAVAILRESOURCE = "437" ERR_UNAVAILRESOURCE = "437"
ERR_REG_UNAVAILABLE = "440" ERR_REG_UNAVAILABLE = "440"
ERR_USERNOTINCHANNEL = "441" ERR_USERNOTINCHANNEL = "441"
ERR_NOTONCHANNEL = "442" ERR_NOTONCHANNEL = "442"
ERR_USERONCHANNEL = "443" ERR_USERONCHANNEL = "443"
ERR_NOLOGIN = "444" ERR_NOLOGIN = "444"
ERR_SUMMONDISABLED = "445" ERR_SUMMONDISABLED = "445"
ERR_USERSDISABLED = "446" ERR_USERSDISABLED = "446"
ERR_NOTREGISTERED = "451" ERR_NOTREGISTERED = "451"
ERR_NEEDMOREPARAMS = "461" ERR_NEEDMOREPARAMS = "461"
ERR_ALREADYREGISTRED = "462" ERR_ALREADYREGISTRED = "462"
ERR_NOPERMFORHOST = "463" ERR_NOPERMFORHOST = "463"
ERR_PASSWDMISMATCH = "464" ERR_PASSWDMISMATCH = "464"
ERR_YOUREBANNEDCREEP = "465" ERR_YOUREBANNEDCREEP = "465"
ERR_YOUWILLBEBANNED = "466" ERR_YOUWILLBEBANNED = "466"
ERR_KEYSET = "467" ERR_KEYSET = "467"
ERR_INVALIDUSERNAME = "468" ERR_INVALIDUSERNAME = "468"
ERR_LINKCHANNEL = "470" ERR_LINKCHANNEL = "470"
ERR_CHANNELISFULL = "471" ERR_CHANNELISFULL = "471"
ERR_UNKNOWNMODE = "472" ERR_UNKNOWNMODE = "472"
ERR_INVITEONLYCHAN = "473" ERR_INVITEONLYCHAN = "473"
ERR_BANNEDFROMCHAN = "474" ERR_BANNEDFROMCHAN = "474"
ERR_BADCHANNELKEY = "475" ERR_BADCHANNELKEY = "475"
ERR_BADCHANMASK = "476" ERR_BADCHANMASK = "476"
ERR_NOCHANMODES = "477" ERR_NOCHANMODES = "477"
ERR_NEEDREGGEDNICK = "477" ERR_NEEDREGGEDNICK = "477"
ERR_BANLISTFULL = "478" ERR_BANLISTFULL = "478"
ERR_NOPRIVILEGES = "481" ERR_NOPRIVILEGES = "481"
ERR_CHANOPRIVSNEEDED = "482" ERR_CHANOPRIVSNEEDED = "482"
ERR_CANTKILLSERVER = "483" ERR_CANTKILLSERVER = "483"
ERR_RESTRICTED = "484" ERR_RESTRICTED = "484"
ERR_UNIQOPPRIVSNEEDED = "485" ERR_UNIQOPPRIVSNEEDED = "485"
ERR_NOOPERHOST = "491" ERR_NOOPERHOST = "491"
ERR_UMODEUNKNOWNFLAG = "501" ERR_UMODEUNKNOWNFLAG = "501"
ERR_USERSDONTMATCH = "502" ERR_USERSDONTMATCH = "502"
ERR_HELPNOTFOUND = "524" ERR_HELPNOTFOUND = "524"
ERR_CANNOTSENDRP = "573" ERR_CANNOTSENDRP = "573"
RPL_WHOISSECURE = "671" RPL_WHOISSECURE = "671"
RPL_YOURLANGUAGESARE = "687" RPL_YOURLANGUAGESARE = "687"
RPL_WHOISLANGUAGE = "690" RPL_WHOISLANGUAGE = "690"
ERR_INVALIDMODEPARAM = "696" ERR_INVALIDMODEPARAM = "696"
RPL_HELPSTART = "704" RPL_HELPSTART = "704"
RPL_HELPTXT = "705" RPL_HELPTXT = "705"
RPL_ENDOFHELP = "706" RPL_ENDOFHELP = "706"
ERR_NOPRIVS = "723" ERR_NOPRIVS = "723"
RPL_MONONLINE = "730" RPL_MONONLINE = "730"
RPL_MONOFFLINE = "731" RPL_MONOFFLINE = "731"
RPL_MONLIST = "732" RPL_MONLIST = "732"
RPL_ENDOFMONLIST = "733" RPL_ENDOFMONLIST = "733"
ERR_MONLISTFULL = "734" ERR_MONLISTFULL = "734"
RPL_LOGGEDIN = "900" RPL_LOGGEDIN = "900"
RPL_LOGGEDOUT = "901" RPL_LOGGEDOUT = "901"
ERR_NICKLOCKED = "902" ERR_NICKLOCKED = "902"
RPL_SASLSUCCESS = "903" RPL_SASLSUCCESS = "903"
ERR_SASLFAIL = "904" ERR_SASLFAIL = "904"
ERR_SASLTOOLONG = "905" ERR_SASLTOOLONG = "905"
ERR_SASLABORTED = "906" ERR_SASLABORTED = "906"
ERR_SASLALREADY = "907" ERR_SASLALREADY = "907"
RPL_SASLMECHS = "908" RPL_SASLMECHS = "908"
RPL_REGISTRATION_SUCCESS = "920" RPL_REGISTRATION_SUCCESS = "920"
ERR_ACCOUNT_ALREADY_EXISTS = "921" ERR_ACCOUNT_ALREADY_EXISTS = "921"
ERR_REG_UNSPECIFIED_ERROR = "922" ERR_REG_UNSPECIFIED_ERROR = "922"
RPL_VERIFYSUCCESS = "923" RPL_VERIFYSUCCESS = "923"
ERR_ACCOUNT_ALREADY_VERIFIED = "924" ERR_ACCOUNT_ALREADY_VERIFIED = "924"
ERR_ACCOUNT_INVALID_VERIFY_CODE = "925" ERR_ACCOUNT_INVALID_VERIFY_CODE = "925"
RPL_REG_VERIFICATION_REQUIRED = "927" RPL_REG_VERIFICATION_REQUIRED = "927"
ERR_REG_INVALID_CRED_TYPE = "928" ERR_REG_INVALID_CRED_TYPE = "928"
ERR_REG_INVALID_CALLBACK = "929" ERR_REG_INVALID_CALLBACK = "929"
ERR_TOOMANYLANGUAGES = "981" ERR_TOOMANYLANGUAGES = "981"
ERR_NOLANGUAGE = "982" ERR_NOLANGUAGE = "982"

View File

@ -2,47 +2,59 @@ import unittest
import operator import operator
import collections import collections
class NotImplementedByController(unittest.SkipTest, NotImplementedError): class NotImplementedByController(unittest.SkipTest, NotImplementedError):
def __str__(self): def __str__(self):
return 'Not implemented by controller: {}'.format(self.args[0]) return "Not implemented by controller: {}".format(self.args[0])
class ImplementationChoice(unittest.SkipTest): class ImplementationChoice(unittest.SkipTest):
def __str__(self): def __str__(self):
return 'Choice in the implementation makes it impossible to ' \ return (
'perform a test: {}'.format(self.args[0]) "Choice in the implementation makes it impossible to "
"perform a test: {}".format(self.args[0])
)
class OptionalExtensionNotSupported(unittest.SkipTest): class OptionalExtensionNotSupported(unittest.SkipTest):
def __str__(self): def __str__(self):
return 'Unsupported extension: {}'.format(self.args[0]) return "Unsupported extension: {}".format(self.args[0])
class OptionalSaslMechanismNotSupported(unittest.SkipTest): class OptionalSaslMechanismNotSupported(unittest.SkipTest):
def __str__(self): def __str__(self):
return 'Unsupported SASL mechanism: {}'.format(self.args[0]) return "Unsupported SASL mechanism: {}".format(self.args[0])
class CapabilityNotSupported(unittest.SkipTest): class CapabilityNotSupported(unittest.SkipTest):
def __str__(self): def __str__(self):
return 'Unsupported capability: {}'.format(self.args[0]) return "Unsupported capability: {}".format(self.args[0])
class NotRequiredBySpecifications(unittest.SkipTest): class NotRequiredBySpecifications(unittest.SkipTest):
def __str__(self): def __str__(self):
return 'Tests not required by the set of tested specification(s).' return "Tests not required by the set of tested specification(s)."
class SkipStrictTest(unittest.SkipTest): class SkipStrictTest(unittest.SkipTest):
def __str__(self): def __str__(self):
return 'Tests not required because strict tests are disabled.' return "Tests not required because strict tests are disabled."
class TextTestResult(unittest.TextTestResult): class TextTestResult(unittest.TextTestResult):
def getDescription(self, test): def getDescription(self, test):
if hasattr(test, 'description'): if hasattr(test, "description"):
doc_first_lines = test.description() doc_first_lines = test.description()
else: else:
doc_first_lines = test.shortDescription() doc_first_lines = test.shortDescription()
return '\n'.join((str(test), doc_first_lines or '')) return "\n".join((str(test), doc_first_lines or ""))
class TextTestRunner(unittest.TextTestRunner): class TextTestRunner(unittest.TextTestRunner):
"""Small wrapper around unittest.TextTestRunner that reports the """Small wrapper around unittest.TextTestRunner that reports the
number of tests that were skipped because the software does not support number of tests that were skipped because the software does not support
an optional feature.""" an optional feature."""
resultclass = TextTestResult resultclass = TextTestResult
def run(self, test): def run(self, test):
@ -50,11 +62,13 @@ class TextTestRunner(unittest.TextTestRunner):
assert self.resultclass is TextTestResult assert self.resultclass is TextTestResult
if result.skipped: if result.skipped:
print() print()
print('Some tests were skipped because the following optional ' print(
'specifications/mechanisms are not supported:') "Some tests were skipped because the following optional "
"specifications/mechanisms are not supported:"
)
msg_to_count = collections.defaultdict(lambda: 0) msg_to_count = collections.defaultdict(lambda: 0)
for (test, msg) in result.skipped: for (test, msg) in result.skipped:
msg_to_count[msg] += 1 msg_to_count[msg] += 1
for (msg, count) in sorted(msg_to_count.items()): for (msg, count) in sorted(msg_to_count.items()):
print('\t{} ({} test(s))'.format(msg, count)) print("\t{} ({} test(s))".format(msg, count))
return result return result

View File

@ -4,44 +4,62 @@
from irctest import cases from irctest import cases
class AccountTagTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): class AccountTagTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
def connectRegisteredClient(self, nick): def connectRegisteredClient(self, nick):
self.addClient() self.addClient()
self.sendLine(2, 'CAP LS 302') self.sendLine(2, "CAP LS 302")
capabilities = self.getCapLs(2) capabilities = self.getCapLs(2)
assert 'sasl' in capabilities assert "sasl" in capabilities
self.sendLine(2, 'AUTHENTICATE PLAIN') self.sendLine(2, "AUTHENTICATE PLAIN")
m = self.getMessage(2, filter_pred=lambda m:m.command != 'NOTICE') m = self.getMessage(2, filter_pred=lambda m: m.command != "NOTICE")
self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'], self.assertMessageEqual(
fail_msg='Sent “AUTHENTICATE PLAIN”, server should have ' m,
'replied with “AUTHENTICATE +”, but instead sent: {msg}') command="AUTHENTICATE",
self.sendLine(2, 'AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=') params=["+"],
m = self.getMessage(2, filter_pred=lambda m:m.command != 'NOTICE') fail_msg="Sent “AUTHENTICATE PLAIN”, server should have "
self.assertMessageEqual(m, command='900', "replied with “AUTHENTICATE +”, but instead sent: {msg}",
fail_msg='Did not send 900 after correct SASL authentication.') )
self.sendLine(2, 'USER f * * :Realname') self.sendLine(2, "AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=")
self.sendLine(2, 'NICK {}'.format(nick)) m = self.getMessage(2, filter_pred=lambda m: m.command != "NOTICE")
self.sendLine(2, 'CAP END') self.assertMessageEqual(
m,
command="900",
fail_msg="Did not send 900 after correct SASL authentication.",
)
self.sendLine(2, "USER f * * :Realname")
self.sendLine(2, "NICK {}".format(nick))
self.sendLine(2, "CAP END")
self.skipToWelcome(2) self.skipToWelcome(2)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPrivmsg(self): def testPrivmsg(self):
self.connectClient('foo', capabilities=['account-tag'], self.connectClient("foo", capabilities=["account-tag"], skip_if_cap_nak=True)
skip_if_cap_nak=True)
self.getMessages(1) self.getMessages(1)
self.controller.registerUser(self, 'jilles', 'sesame') self.controller.registerUser(self, "jilles", "sesame")
self.connectRegisteredClient('bar') self.connectRegisteredClient("bar")
self.sendLine(2, 'PRIVMSG foo :hi') self.sendLine(2, "PRIVMSG foo :hi")
self.getMessages(2) self.getMessages(2)
m = self.getMessage(1) m = self.getMessage(1)
self.assertMessageEqual(m, command='PRIVMSG', # RPL_MONONLINE self.assertMessageEqual(
fail_msg='Sent non-730 (RPL_MONONLINE) message after ' m,
'“bar” sent a PRIVMSG: {msg}') command="PRIVMSG", # RPL_MONONLINE
self.assertIn('account', m.tags, m, fail_msg="Sent non-730 (RPL_MONONLINE) message after "
fail_msg='PRIVMSG by logged in nick ' "“bar” sent a PRIVMSG: {msg}",
'does not contain an account tag: {msg}') )
self.assertEqual(m.tags['account'], 'jilles', m, self.assertIn(
fail_msg='PRIVMSG by logged in nick ' "account",
'does not contain the correct account tag (should be ' m.tags,
'“jilles”): {msg}') m,
fail_msg="PRIVMSG by logged in nick "
"does not contain an account tag: {msg}",
)
self.assertEqual(
m.tags["account"],
"jilles",
m,
fail_msg="PRIVMSG by logged in nick "
"does not contain the correct account tag (should be "
"“jilles”): {msg}",
)

View File

@ -4,49 +4,55 @@
from irctest import cases from irctest import cases
class AwayNotifyTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') class AwayNotifyTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
def testAwayNotify(self): def testAwayNotify(self):
"""Basic away-notify test.""" """Basic away-notify test."""
self.connectClient('foo', capabilities=['away-notify'], skip_if_cap_nak=True) self.connectClient("foo", capabilities=["away-notify"], skip_if_cap_nak=True)
self.getMessages(1) self.getMessages(1)
self.joinChannel(1, '#chan') self.joinChannel(1, "#chan")
self.connectClient('bar') self.connectClient("bar")
self.getMessages(2) self.getMessages(2)
self.joinChannel(2, '#chan') self.joinChannel(2, "#chan")
self.getMessages(2) self.getMessages(2)
self.getMessages(1) self.getMessages(1)
self.sendLine(2, "AWAY :i'm going away") self.sendLine(2, "AWAY :i'm going away")
self.getMessages(2) self.getMessages(2)
messages = [msg for msg in self.getMessages(1) if msg.command == 'AWAY'] messages = [msg for msg in self.getMessages(1) if msg.command == "AWAY"]
self.assertEqual(len(messages), 1) self.assertEqual(len(messages), 1)
awayNotify = messages[0] awayNotify = messages[0]
self.assertTrue(awayNotify.prefix.startswith('bar!'), 'Unexpected away-notify source: %s' % (awayNotify.prefix,)) self.assertTrue(
awayNotify.prefix.startswith("bar!"),
"Unexpected away-notify source: %s" % (awayNotify.prefix,),
)
self.assertEqual(awayNotify.params, ["i'm going away"]) self.assertEqual(awayNotify.params, ["i'm going away"])
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testAwayNotifyOnJoin(self): def testAwayNotifyOnJoin(self):
"""The away-notify specification states: """The away-notify specification states:
"Clients will be sent an AWAY message [...] when a user joins and has an away message set." "Clients will be sent an AWAY message [...] when a user joins and has an away message set."
""" """
self.connectClient('foo', capabilities=['away-notify'], skip_if_cap_nak=True) self.connectClient("foo", capabilities=["away-notify"], skip_if_cap_nak=True)
self.getMessages(1) self.getMessages(1)
self.joinChannel(1, '#chan') self.joinChannel(1, "#chan")
self.connectClient('bar') self.connectClient("bar")
self.getMessages(2) self.getMessages(2)
self.sendLine(2, "AWAY :i'm already away") self.sendLine(2, "AWAY :i'm already away")
self.getMessages(2) self.getMessages(2)
self.joinChannel(2, '#chan') self.joinChannel(2, "#chan")
self.getMessages(2) self.getMessages(2)
messages = [msg for msg in self.getMessages(1) if msg.command == 'AWAY'] messages = [msg for msg in self.getMessages(1) if msg.command == "AWAY"]
self.assertEqual(len(messages), 1) self.assertEqual(len(messages), 1)
awayNotify = messages[0] awayNotify = messages[0]
self.assertTrue(awayNotify.prefix.startswith('bar!'), 'Unexpected away-notify source: %s' % (awayNotify.prefix,)) self.assertTrue(
awayNotify.prefix.startswith("bar!"),
"Unexpected away-notify source: %s" % (awayNotify.prefix,),
)
self.assertEqual(awayNotify.params, ["i'm already away"]) self.assertEqual(awayNotify.params, ["i'm already away"])

View File

@ -4,139 +4,155 @@ from irctest.irc_utils.sasl import sasl_plain_blob
from irctest.numerics import RPL_WELCOME from irctest.numerics import RPL_WELCOME
from irctest.numerics import ERR_NICKNAMEINUSE from irctest.numerics import ERR_NICKNAMEINUSE
class Bouncer(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification('Oragono') class Bouncer(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification("Oragono")
def testBouncer(self): def testBouncer(self):
"""Test basic bouncer functionality.""" """Test basic bouncer functionality."""
self.controller.registerUser(self, 'observer', 'observerpassword') self.controller.registerUser(self, "observer", "observerpassword")
self.controller.registerUser(self, 'testuser', 'mypassword') self.controller.registerUser(self, "testuser", "mypassword")
self.connectClient('observer', password='observerpassword') self.connectClient("observer", password="observerpassword")
self.joinChannel(1, '#chan') self.joinChannel(1, "#chan")
self.sendLine(1, 'CAP REQ :message-tags server-time') self.sendLine(1, "CAP REQ :message-tags server-time")
self.getMessages(1) self.getMessages(1)
self.addClient() self.addClient()
self.sendLine(2, 'CAP LS 302') self.sendLine(2, "CAP LS 302")
self.sendLine(2, 'AUTHENTICATE PLAIN') self.sendLine(2, "AUTHENTICATE PLAIN")
self.sendLine(2, sasl_plain_blob('testuser', 'mypassword')) self.sendLine(2, sasl_plain_blob("testuser", "mypassword"))
self.sendLine(2, 'NICK testnick') self.sendLine(2, "NICK testnick")
self.sendLine(2, 'USER a 0 * a') self.sendLine(2, "USER a 0 * a")
self.sendLine(2, 'CAP REQ :server-time message-tags') self.sendLine(2, "CAP REQ :server-time message-tags")
self.sendLine(2, 'CAP END') self.sendLine(2, "CAP END")
messages = self.getMessages(2) messages = self.getMessages(2)
welcomes = [message for message in messages if message.command == RPL_WELCOME] welcomes = [message for message in messages if message.command == RPL_WELCOME]
self.assertEqual(len(welcomes), 1) self.assertEqual(len(welcomes), 1)
# should see a regburst for testnick # should see a regburst for testnick
self.assertEqual(welcomes[0].params[0], 'testnick') self.assertEqual(welcomes[0].params[0], "testnick")
self.joinChannel(2, '#chan') self.joinChannel(2, "#chan")
self.addClient() self.addClient()
self.sendLine(3, 'CAP LS 302') self.sendLine(3, "CAP LS 302")
self.sendLine(3, 'AUTHENTICATE PLAIN') self.sendLine(3, "AUTHENTICATE PLAIN")
self.sendLine(3, sasl_plain_blob('testuser', 'mypassword')) self.sendLine(3, sasl_plain_blob("testuser", "mypassword"))
self.sendLine(3, 'NICK testnick') self.sendLine(3, "NICK testnick")
self.sendLine(3, 'USER a 0 * a') self.sendLine(3, "USER a 0 * a")
self.sendLine(3, 'CAP REQ :server-time message-tags account-tag') self.sendLine(3, "CAP REQ :server-time message-tags account-tag")
self.sendLine(3, 'CAP END') self.sendLine(3, "CAP END")
messages = self.getMessages(3) messages = self.getMessages(3)
welcomes = [message for message in messages if message.command == RPL_WELCOME] welcomes = [message for message in messages if message.command == RPL_WELCOME]
self.assertEqual(len(welcomes), 1) self.assertEqual(len(welcomes), 1)
# should see the *same* regburst for testnick # should see the *same* regburst for testnick
self.assertEqual(welcomes[0].params[0], 'testnick') self.assertEqual(welcomes[0].params[0], "testnick")
joins = [message for message in messages if message.command == 'JOIN'] joins = [message for message in messages if message.command == "JOIN"]
# we should be automatically joined to #chan # we should be automatically joined to #chan
self.assertEqual(joins[0].params[0], '#chan') self.assertEqual(joins[0].params[0], "#chan")
# disable multiclient in nickserv # disable multiclient in nickserv
self.sendLine(3, 'NS SET MULTICLIENT OFF') self.sendLine(3, "NS SET MULTICLIENT OFF")
self.getMessages(3) self.getMessages(3)
self.addClient() self.addClient()
self.sendLine(4, 'CAP LS 302') self.sendLine(4, "CAP LS 302")
self.sendLine(4, 'AUTHENTICATE PLAIN') self.sendLine(4, "AUTHENTICATE PLAIN")
self.sendLine(4, sasl_plain_blob('testuser', 'mypassword')) self.sendLine(4, sasl_plain_blob("testuser", "mypassword"))
self.sendLine(4, 'NICK testnick') self.sendLine(4, "NICK testnick")
self.sendLine(4, 'USER a 0 * a') self.sendLine(4, "USER a 0 * a")
self.sendLine(4, 'CAP REQ :server-time message-tags') self.sendLine(4, "CAP REQ :server-time message-tags")
self.sendLine(4, 'CAP END') self.sendLine(4, "CAP END")
# with multiclient disabled, we should not be able to attach to the nick # with multiclient disabled, we should not be able to attach to the nick
messages = self.getMessages(4) messages = self.getMessages(4)
welcomes = [message for message in messages if message.command == RPL_WELCOME] welcomes = [message for message in messages if message.command == RPL_WELCOME]
self.assertEqual(len(welcomes), 0) self.assertEqual(len(welcomes), 0)
errors = [message for message in messages if message.command == ERR_NICKNAMEINUSE] errors = [
message for message in messages if message.command == ERR_NICKNAMEINUSE
]
self.assertEqual(len(errors), 1) self.assertEqual(len(errors), 1)
self.sendLine(3, 'NS SET MULTICLIENT ON') self.sendLine(3, "NS SET MULTICLIENT ON")
self.getMessages(3) self.getMessages(3)
self.addClient() self.addClient()
self.sendLine(5, 'CAP LS 302') self.sendLine(5, "CAP LS 302")
self.sendLine(5, 'AUTHENTICATE PLAIN') self.sendLine(5, "AUTHENTICATE PLAIN")
self.sendLine(5, sasl_plain_blob('testuser', 'mypassword')) self.sendLine(5, sasl_plain_blob("testuser", "mypassword"))
self.sendLine(5, 'NICK testnick') self.sendLine(5, "NICK testnick")
self.sendLine(5, 'USER a 0 * a') self.sendLine(5, "USER a 0 * a")
self.sendLine(5, 'CAP REQ server-time') self.sendLine(5, "CAP REQ server-time")
self.sendLine(5, 'CAP END') self.sendLine(5, "CAP END")
messages = self.getMessages(5) messages = self.getMessages(5)
welcomes = [message for message in messages if message.command == RPL_WELCOME] welcomes = [message for message in messages if message.command == RPL_WELCOME]
self.assertEqual(len(welcomes), 1) self.assertEqual(len(welcomes), 1)
self.sendLine(1, '@+clientOnlyTag=Value PRIVMSG #chan :hey') self.sendLine(1, "@+clientOnlyTag=Value PRIVMSG #chan :hey")
self.getMessages(1) self.getMessages(1)
messagesfortwo = [msg for msg in self.getMessages(2) if msg.command == 'PRIVMSG'] messagesfortwo = [
messagesforthree = [msg for msg in self.getMessages(3) if msg.command == 'PRIVMSG'] msg for msg in self.getMessages(2) if msg.command == "PRIVMSG"
]
messagesforthree = [
msg for msg in self.getMessages(3) if msg.command == "PRIVMSG"
]
self.assertEqual(len(messagesfortwo), 1) self.assertEqual(len(messagesfortwo), 1)
self.assertEqual(len(messagesforthree), 1) self.assertEqual(len(messagesforthree), 1)
messagefortwo = messagesfortwo[0] messagefortwo = messagesfortwo[0]
messageforthree = messagesforthree[0] messageforthree = messagesforthree[0]
messageforfive = self.getMessage(5) messageforfive = self.getMessage(5)
self.assertEqual(messagefortwo.params, ['#chan', 'hey']) self.assertEqual(messagefortwo.params, ["#chan", "hey"])
self.assertEqual(messageforthree.params, ['#chan', 'hey']) self.assertEqual(messageforthree.params, ["#chan", "hey"])
self.assertEqual(messageforfive.params, ['#chan', 'hey']) self.assertEqual(messageforfive.params, ["#chan", "hey"])
self.assertIn('time', messagefortwo.tags) self.assertIn("time", messagefortwo.tags)
self.assertIn('time', messageforthree.tags) self.assertIn("time", messageforthree.tags)
self.assertIn('time', messageforfive.tags) self.assertIn("time", messageforfive.tags)
# 3 has account-tag # 3 has account-tag
self.assertIn('account', messageforthree.tags) self.assertIn("account", messageforthree.tags)
# should get same msgid # should get same msgid
self.assertEqual(messagefortwo.tags['msgid'], messageforthree.tags['msgid']) self.assertEqual(messagefortwo.tags["msgid"], messageforthree.tags["msgid"])
# 5 only has server-time, shouldn't get account or msgid tags # 5 only has server-time, shouldn't get account or msgid tags
self.assertNotIn('account', messageforfive.tags) self.assertNotIn("account", messageforfive.tags)
self.assertNotIn('msgid', messageforfive.tags) self.assertNotIn("msgid", messageforfive.tags)
# test that copies of sent messages go out to other sessions # test that copies of sent messages go out to other sessions
self.sendLine(2, 'PRIVMSG observer :this is a direct message') self.sendLine(2, "PRIVMSG observer :this is a direct message")
self.getMessages(2) self.getMessages(2)
messageForRecipient = [msg for msg in self.getMessages(1) if msg.command == 'PRIVMSG'][0] messageForRecipient = [
copyForOtherSession = [msg for msg in self.getMessages(3) if msg.command == 'PRIVMSG'][0] msg for msg in self.getMessages(1) if msg.command == "PRIVMSG"
][0]
copyForOtherSession = [
msg for msg in self.getMessages(3) if msg.command == "PRIVMSG"
][0]
self.assertEqual(messageForRecipient.params, copyForOtherSession.params) self.assertEqual(messageForRecipient.params, copyForOtherSession.params)
self.assertEqual(messageForRecipient.tags['msgid'], copyForOtherSession.tags['msgid']) self.assertEqual(
messageForRecipient.tags["msgid"], copyForOtherSession.tags["msgid"]
)
self.sendLine(2, 'QUIT :two out') self.sendLine(2, "QUIT :two out")
quitLines = [msg for msg in self.getMessages(2) if msg.command == 'QUIT'] quitLines = [msg for msg in self.getMessages(2) if msg.command == "QUIT"]
self.assertEqual(len(quitLines), 1) self.assertEqual(len(quitLines), 1)
self.assertIn('two out', quitLines[0].params[0]) self.assertIn("two out", quitLines[0].params[0])
# neither the observer nor the other attached session should see a quit here # neither the observer nor the other attached session should see a quit here
quitLines = [msg for msg in self.getMessages(1) if msg.command == 'QUIT'] quitLines = [msg for msg in self.getMessages(1) if msg.command == "QUIT"]
self.assertEqual(quitLines, []) self.assertEqual(quitLines, [])
quitLines = [msg for msg in self.getMessages(3) if msg.command == 'QUIT'] quitLines = [msg for msg in self.getMessages(3) if msg.command == "QUIT"]
self.assertEqual(quitLines, []) self.assertEqual(quitLines, [])
# session 3 should be untouched at this point # session 3 should be untouched at this point
self.sendLine(1, '@+clientOnlyTag=Value PRIVMSG #chan :hey again') self.sendLine(1, "@+clientOnlyTag=Value PRIVMSG #chan :hey again")
self.getMessages(1) self.getMessages(1)
messagesforthree = [msg for msg in self.getMessages(3) if msg.command == 'PRIVMSG'] messagesforthree = [
msg for msg in self.getMessages(3) if msg.command == "PRIVMSG"
]
self.assertEqual(len(messagesforthree), 1) self.assertEqual(len(messagesforthree), 1)
self.assertMessageEqual(messagesforthree[0], command='PRIVMSG', params=['#chan', 'hey again']) self.assertMessageEqual(
messagesforthree[0], command="PRIVMSG", params=["#chan", "hey again"]
)
self.sendLine(5, 'QUIT :five out') self.sendLine(5, "QUIT :five out")
self.getMessages(5) self.getMessages(5)
self.sendLine(3, 'QUIT :three out') self.sendLine(3, "QUIT :three out")
quitLines = [msg for msg in self.getMessages(3) if msg.command == 'QUIT'] quitLines = [msg for msg in self.getMessages(3) if msg.command == "QUIT"]
self.assertEqual(len(quitLines), 1) self.assertEqual(len(quitLines), 1)
self.assertIn('three out', quitLines[0].params[0]) self.assertIn("three out", quitLines[0].params[0])
# observer should see *this* quit # observer should see *this* quit
quitLines = [msg for msg in self.getMessages(1) if msg.command == 'QUIT'] quitLines = [msg for msg in self.getMessages(1) if msg.command == "QUIT"]
self.assertEqual(len(quitLines), 1) self.assertEqual(len(quitLines), 1)
self.assertIn('three out', quitLines[0].params[0]) self.assertIn("three out", quitLines[0].params[0])

View File

@ -1,7 +1,8 @@
from irctest import cases from irctest import cases
class CapTestCase(cases.BaseServerTestCase): class CapTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
def testNoReq(self): def testNoReq(self):
"""Test the server handles gracefully clients which do not send """Test the server handles gracefully clients which do not send
REQs. REQs.
@ -11,38 +12,44 @@ class CapTestCase(cases.BaseServerTestCase):
-- <http://ircv3.net/specs/core/capability-negotiation-3.1.html#the-cap-end-subcommand> -- <http://ircv3.net/specs/core/capability-negotiation-3.1.html#the-cap-end-subcommand>
""" """
self.addClient(1) self.addClient(1)
self.sendLine(1, 'CAP LS 302') self.sendLine(1, "CAP LS 302")
self.getCapLs(1) self.getCapLs(1)
self.sendLine(1, 'USER foo foo foo :foo') self.sendLine(1, "USER foo foo foo :foo")
self.sendLine(1, 'NICK foo') self.sendLine(1, "NICK foo")
self.sendLine(1, 'CAP END') self.sendLine(1, "CAP END")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='001', self.assertMessageEqual(
fail_msg='Expected 001 after sending CAP END, got {msg}.') m, command="001", fail_msg="Expected 001 after sending CAP END, got {msg}."
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
def testReqUnavailable(self): def testReqUnavailable(self):
"""Test the server handles gracefully clients which request """Test the server handles gracefully clients which request
capabilities that are not available. capabilities that are not available.
<http://ircv3.net/specs/core/capability-negotiation-3.1.html> <http://ircv3.net/specs/core/capability-negotiation-3.1.html>
""" """
self.addClient(1) self.addClient(1)
self.sendLine(1, 'CAP LS 302') self.sendLine(1, "CAP LS 302")
self.getCapLs(1) self.getCapLs(1)
self.sendLine(1, 'USER foo foo foo :foo') self.sendLine(1, "USER foo foo foo :foo")
self.sendLine(1, 'NICK foo') self.sendLine(1, "NICK foo")
self.sendLine(1, 'CAP REQ :foo') self.sendLine(1, "CAP REQ :foo")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP', self.assertMessageEqual(
subcommand='NAK', subparams=['foo'], m,
fail_msg='Expected CAP NAK after requesting non-existing ' command="CAP",
'capability, got {msg}.') subcommand="NAK",
self.sendLine(1, 'CAP END') subparams=["foo"],
fail_msg="Expected CAP NAK after requesting non-existing "
"capability, got {msg}.",
)
self.sendLine(1, "CAP END")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='001', self.assertMessageEqual(
fail_msg='Expected 001 after sending CAP END, got {msg}.') m, command="001", fail_msg="Expected 001 after sending CAP END, got {msg}."
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
def testNakExactString(self): def testNakExactString(self):
"""“The argument of the NAK subcommand MUST consist of at least the """“The argument of the NAK subcommand MUST consist of at least the
first 100 characters of the capability list in the REQ subcommand which first 100 characters of the capability list in the REQ subcommand which
@ -50,78 +57,100 @@ class CapTestCase(cases.BaseServerTestCase):
-- <http://ircv3.net/specs/core/capability-negotiation-3.1.html#the-cap-nak-subcommand> -- <http://ircv3.net/specs/core/capability-negotiation-3.1.html#the-cap-nak-subcommand>
""" """
self.addClient(1) self.addClient(1)
self.sendLine(1, 'CAP LS 302') self.sendLine(1, "CAP LS 302")
self.getCapLs(1) self.getCapLs(1)
# Five should be enough to check there is no reordering, even # Five should be enough to check there is no reordering, even
# alphabetical # alphabetical
self.sendLine(1, 'CAP REQ :foo qux bar baz qux quux') self.sendLine(1, "CAP REQ :foo qux bar baz qux quux")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP', self.assertMessageEqual(
subcommand='NAK', subparams=['foo qux bar baz qux quux'], m,
fail_msg='Expected “CAP NAK :foo qux bar baz qux quux” after ' command="CAP",
'sending “CAP REQ :foo qux bar baz qux quux”, but got {msg}.') subcommand="NAK",
subparams=["foo qux bar baz qux quux"],
fail_msg="Expected “CAP NAK :foo qux bar baz qux quux” after "
"sending “CAP REQ :foo qux bar baz qux quux”, but got {msg}.",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
def testNakWhole(self): def testNakWhole(self):
"""“The capability identifier set must be accepted as a whole, or """“The capability identifier set must be accepted as a whole, or
rejected entirely. rejected entirely.
-- <http://ircv3.net/specs/core/capability-negotiation-3.1.html#the-cap-req-subcommand> -- <http://ircv3.net/specs/core/capability-negotiation-3.1.html#the-cap-req-subcommand>
""" """
self.addClient(1) self.addClient(1)
self.sendLine(1, 'CAP LS 302') self.sendLine(1, "CAP LS 302")
self.assertIn('multi-prefix', self.getCapLs(1)) self.assertIn("multi-prefix", self.getCapLs(1))
self.sendLine(1, 'CAP REQ :foo multi-prefix bar') self.sendLine(1, "CAP REQ :foo multi-prefix bar")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP', self.assertMessageEqual(
subcommand='NAK', subparams=['foo multi-prefix bar'], m,
fail_msg='Expected “CAP NAK :foo multi-prefix bar” after ' command="CAP",
'sending “CAP REQ :foo multi-prefix bar”, but got {msg}.') subcommand="NAK",
self.sendLine(1, 'CAP REQ :multi-prefix bar') subparams=["foo multi-prefix bar"],
fail_msg="Expected “CAP NAK :foo multi-prefix bar” after "
"sending “CAP REQ :foo multi-prefix bar”, but got {msg}.",
)
self.sendLine(1, "CAP REQ :multi-prefix bar")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP', self.assertMessageEqual(
subcommand='NAK', subparams=['multi-prefix bar'], m,
fail_msg='Expected “CAP NAK :multi-prefix bar” after ' command="CAP",
'sending “CAP REQ :multi-prefix bar”, but got {msg}.') subcommand="NAK",
self.sendLine(1, 'CAP REQ :foo multi-prefix') subparams=["multi-prefix bar"],
fail_msg="Expected “CAP NAK :multi-prefix bar” after "
"sending “CAP REQ :multi-prefix bar”, but got {msg}.",
)
self.sendLine(1, "CAP REQ :foo multi-prefix")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP', self.assertMessageEqual(
subcommand='NAK', subparams=['foo multi-prefix'], m,
fail_msg='Expected “CAP NAK :foo multi-prefix” after ' command="CAP",
'sending “CAP REQ :foo multi-prefix”, but got {msg}.') subcommand="NAK",
subparams=["foo multi-prefix"],
fail_msg="Expected “CAP NAK :foo multi-prefix” after "
"sending “CAP REQ :foo multi-prefix”, but got {msg}.",
)
# TODO: make sure multi-prefix is not enabled at this point # TODO: make sure multi-prefix is not enabled at this point
self.sendLine(1, 'CAP REQ :multi-prefix') self.sendLine(1, "CAP REQ :multi-prefix")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP', self.assertMessageEqual(
subcommand='ACK', subparams=['multi-prefix'], m,
fail_msg='Expected “CAP ACK :multi-prefix” after ' command="CAP",
'sending “CAP REQ :multi-prefix”, but got {msg}.') subcommand="ACK",
subparams=["multi-prefix"],
fail_msg="Expected “CAP ACK :multi-prefix” after "
"sending “CAP REQ :multi-prefix”, but got {msg}.",
)
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testCapRemovalByClient(self): def testCapRemovalByClient(self):
"""Test CAP LIST and removal of caps via CAP REQ :-tagname.""" """Test CAP LIST and removal of caps via CAP REQ :-tagname."""
self.addClient(1) self.addClient(1)
self.sendLine(1, 'CAP LS 302') self.sendLine(1, "CAP LS 302")
self.assertIn('multi-prefix', self.getCapLs(1)) self.assertIn("multi-prefix", self.getCapLs(1))
self.sendLine(1, 'CAP REQ :echo-message server-time') self.sendLine(1, "CAP REQ :echo-message server-time")
self.sendLine(1, 'nick bar') self.sendLine(1, "nick bar")
self.sendLine(1, 'user user 0 * realname') self.sendLine(1, "user user 0 * realname")
self.sendLine(1, 'CAP END') self.sendLine(1, "CAP END")
self.skipToWelcome(1) self.skipToWelcome(1)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'CAP LIST') self.sendLine(1, "CAP LIST")
messages = self.getMessages(1) messages = self.getMessages(1)
cap_list = [m for m in messages if m.command == 'CAP'][0] cap_list = [m for m in messages if m.command == "CAP"][0]
self.assertEqual(set(cap_list.params[2].split()), {'echo-message', 'server-time'}) self.assertEqual(
self.assertIn('time', cap_list.tags) set(cap_list.params[2].split()), {"echo-message", "server-time"}
)
self.assertIn("time", cap_list.tags)
# remove the server-time cap # remove the server-time cap
self.sendLine(1, 'CAP REQ :-server-time') self.sendLine(1, "CAP REQ :-server-time")
self.getMessages(1) self.getMessages(1)
# server-time should be disabled # server-time should be disabled
self.sendLine(1, 'CAP LIST') self.sendLine(1, "CAP LIST")
messages = self.getMessages(1) messages = self.getMessages(1)
cap_list = [m for m in messages if m.command == 'CAP'][0] cap_list = [m for m in messages if m.command == "CAP"][0]
self.assertEqual(set(cap_list.params[2].split()), {'echo-message'}) self.assertEqual(set(cap_list.params[2].split()), {"echo-message"})
self.assertNotIn('time', cap_list.tags) self.assertNotIn("time", cap_list.tags)

View File

@ -1,44 +1,52 @@
from irctest import cases from irctest import cases
from irctest.numerics import ERR_CHANOPRIVSNEEDED, ERR_INVALIDMODEPARAM, ERR_LINKCHANNEL from irctest.numerics import ERR_CHANOPRIVSNEEDED, ERR_INVALIDMODEPARAM, ERR_LINKCHANNEL
MODERN_CAPS = ['server-time', 'message-tags', 'batch', 'labeled-response', 'echo-message', 'account-tag'] MODERN_CAPS = [
"server-time",
"message-tags",
"batch",
"labeled-response",
"echo-message",
"account-tag",
]
class ChannelForwarding(cases.BaseServerTestCase): class ChannelForwarding(cases.BaseServerTestCase):
"""Test the +f channel forwarding mode.""" """Test the +f channel forwarding mode."""
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testChannelForwarding(self): def testChannelForwarding(self):
self.connectClient('bar', name='bar', capabilities=MODERN_CAPS) self.connectClient("bar", name="bar", capabilities=MODERN_CAPS)
self.connectClient('baz', name='baz', capabilities=MODERN_CAPS) self.connectClient("baz", name="baz", capabilities=MODERN_CAPS)
self.joinChannel('bar', '#bar') self.joinChannel("bar", "#bar")
self.joinChannel('bar', '#bar_two') self.joinChannel("bar", "#bar_two")
self.joinChannel('baz', '#baz') self.joinChannel("baz", "#baz")
self.sendLine('bar', 'MODE #bar +f #nonexistent') self.sendLine("bar", "MODE #bar +f #nonexistent")
msg = self.getMessage('bar') msg = self.getMessage("bar")
self.assertMessageEqual(msg, command=ERR_INVALIDMODEPARAM) self.assertMessageEqual(msg, command=ERR_INVALIDMODEPARAM)
# need chanops in the target channel as well # need chanops in the target channel as well
self.sendLine('bar', 'MODE #bar +f #baz') self.sendLine("bar", "MODE #bar +f #baz")
responses = set(msg.command for msg in self.getMessages('bar')) responses = set(msg.command for msg in self.getMessages("bar"))
self.assertIn(ERR_CHANOPRIVSNEEDED, responses) self.assertIn(ERR_CHANOPRIVSNEEDED, responses)
self.sendLine('bar', 'MODE #bar +f #bar_two') self.sendLine("bar", "MODE #bar +f #bar_two")
msg = self.getMessage('bar') msg = self.getMessage("bar")
self.assertMessageEqual(msg, command='MODE', params=['#bar', '+f', '#bar_two']) self.assertMessageEqual(msg, command="MODE", params=["#bar", "+f", "#bar_two"])
# can still join the channel fine # can still join the channel fine
self.joinChannel('baz', '#bar') self.joinChannel("baz", "#bar")
self.sendLine('baz', 'PART #bar') self.sendLine("baz", "PART #bar")
self.getMessages('baz') self.getMessages("baz")
# now make it invite-only, which should cause forwarding # now make it invite-only, which should cause forwarding
self.sendLine('bar', 'MODE #bar +i') self.sendLine("bar", "MODE #bar +i")
self.getMessages('bar') self.getMessages("bar")
self.sendLine('baz', 'JOIN #bar') self.sendLine("baz", "JOIN #bar")
msgs = self.getMessages('baz') msgs = self.getMessages("baz")
forward = [msg for msg in msgs if msg.command == ERR_LINKCHANNEL] forward = [msg for msg in msgs if msg.command == ERR_LINKCHANNEL]
self.assertEqual(forward[0].params[:3], ['baz', '#bar', '#bar_two']) self.assertEqual(forward[0].params[:3], ["baz", "#bar", "#bar_two"])
join = [msg for msg in msgs if msg.command == 'JOIN'] join = [msg for msg in msgs if msg.command == "JOIN"]
self.assertMessageEqual(join[0], params=['#bar_two']) self.assertMessageEqual(join[0], params=["#bar_two"])

File diff suppressed because it is too large Load Diff

View File

@ -1,27 +1,59 @@
from irctest import cases from irctest import cases
from irctest.numerics import ERR_CHANOPRIVSNEEDED from irctest.numerics import ERR_CHANOPRIVSNEEDED
MODERN_CAPS = ['server-time', 'message-tags', 'batch', 'labeled-response', 'echo-message', 'account-tag'] MODERN_CAPS = [
RENAME_CAP = 'draft/channel-rename' "server-time",
"message-tags",
"batch",
"labeled-response",
"echo-message",
"account-tag",
]
RENAME_CAP = "draft/channel-rename"
class ChannelRename(cases.BaseServerTestCase): class ChannelRename(cases.BaseServerTestCase):
"""Basic tests for channel-rename.""" """Basic tests for channel-rename."""
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testChannelRename(self): def testChannelRename(self):
self.connectClient('bar', name='bar', capabilities=MODERN_CAPS+[RENAME_CAP]) self.connectClient("bar", name="bar", capabilities=MODERN_CAPS + [RENAME_CAP])
self.connectClient('baz', name='baz', capabilities=MODERN_CAPS) self.connectClient("baz", name="baz", capabilities=MODERN_CAPS)
self.joinChannel('bar', '#bar') self.joinChannel("bar", "#bar")
self.joinChannel('baz', '#bar') self.joinChannel("baz", "#bar")
self.getMessages('bar') self.getMessages("bar")
self.getMessages('baz') self.getMessages("baz")
self.sendLine('bar', 'RENAME #bar #qux :no reason') self.sendLine("bar", "RENAME #bar #qux :no reason")
self.assertMessageEqual(self.getMessage('bar'), command='RENAME', params=['#bar', '#qux', 'no reason']) self.assertMessageEqual(
legacy_responses = self.getMessages('baz') self.getMessage("bar"),
self.assertEqual(1, len([msg for msg in legacy_responses if msg.command == 'PART' and msg.params[0] == '#bar'])) command="RENAME",
self.assertEqual(1, len([msg for msg in legacy_responses if msg.command == 'JOIN' and msg.params == ['#qux']])) params=["#bar", "#qux", "no reason"],
)
legacy_responses = self.getMessages("baz")
self.assertEqual(
1,
len(
[
msg
for msg in legacy_responses
if msg.command == "PART" and msg.params[0] == "#bar"
]
),
)
self.assertEqual(
1,
len(
[
msg
for msg in legacy_responses
if msg.command == "JOIN" and msg.params == ["#qux"]
]
),
)
self.joinChannel('baz', '#bar') self.joinChannel("baz", "#bar")
self.sendLine('baz', 'MODE #bar +k beer') self.sendLine("baz", "MODE #bar +k beer")
self.assertNotIn(ERR_CHANOPRIVSNEEDED, [msg.command for msg in self.getMessages('baz')]) self.assertNotIn(
ERR_CHANOPRIVSNEEDED, [msg.command for msg in self.getMessages("baz")]
)

View File

@ -4,12 +4,13 @@ import time
from irctest import cases from irctest import cases
from irctest.irc_utils.junkdrawer import to_history_message, random_name from irctest.irc_utils.junkdrawer import to_history_message, random_name
CHATHISTORY_CAP = 'draft/chathistory' CHATHISTORY_CAP = "draft/chathistory"
EVENT_PLAYBACK_CAP = 'draft/event-playback' EVENT_PLAYBACK_CAP = "draft/event-playback"
MYSQL_PASSWORD = "" MYSQL_PASSWORD = ""
def validate_chathistory_batch(msgs): def validate_chathistory_batch(msgs):
batch_tag = None batch_tag = None
closed_batch_tag = None closed_batch_tag = None
@ -17,91 +18,120 @@ def validate_chathistory_batch(msgs):
for msg in msgs: for msg in msgs:
if msg.command == "BATCH": if msg.command == "BATCH":
batch_param = msg.params[0] batch_param = msg.params[0]
if batch_tag is None and batch_param[0] == '+': if batch_tag is None and batch_param[0] == "+":
batch_tag = batch_param[1:] batch_tag = batch_param[1:]
elif batch_param[0] == '-': elif batch_param[0] == "-":
closed_batch_tag = batch_param[1:] closed_batch_tag = batch_param[1:]
elif msg.command == "PRIVMSG" and batch_tag is not None and msg.tags.get("batch") == batch_tag: elif (
msg.command == "PRIVMSG"
and batch_tag is not None
and msg.tags.get("batch") == batch_tag
):
result.append(to_history_message(msg)) result.append(to_history_message(msg))
assert batch_tag == closed_batch_tag assert batch_tag == closed_batch_tag
return result return result
class ChathistoryTestCase(cases.BaseServerTestCase): class ChathistoryTestCase(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config():
return { return {
"chathistory": True, "chathistory": True,
} }
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testInvalidTargets(self): def testInvalidTargets(self):
bar, pw = random_name('bar'), random_name('pw') bar, pw = random_name("bar"), random_name("pw")
self.controller.registerUser(self, bar, pw) self.controller.registerUser(self, bar, pw)
self.connectClient(bar, name=bar, capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP], password=pw) self.connectClient(
bar,
name=bar,
capabilities=[
"batch",
"labeled-response",
"message-tags",
"server-time",
CHATHISTORY_CAP,
EVENT_PLAYBACK_CAP,
],
password=pw,
)
self.getMessages(bar) self.getMessages(bar)
qux = random_name('qux') qux = random_name("qux")
real_chname = random_name('#real_channel') real_chname = random_name("#real_channel")
self.connectClient(qux, name=qux) self.connectClient(qux, name=qux)
self.joinChannel(qux, real_chname) self.joinChannel(qux, real_chname)
self.getMessages(qux) self.getMessages(qux)
# test a nonexistent channel # test a nonexistent channel
self.sendLine(bar, 'CHATHISTORY LATEST #nonexistent_channel * 10') self.sendLine(bar, "CHATHISTORY LATEST #nonexistent_channel * 10")
msgs = self.getMessages(bar) msgs = self.getMessages(bar)
self.assertEqual(msgs[0].command, 'FAIL') self.assertEqual(msgs[0].command, "FAIL")
self.assertEqual(msgs[0].params[:2], ['CHATHISTORY', 'INVALID_TARGET']) self.assertEqual(msgs[0].params[:2], ["CHATHISTORY", "INVALID_TARGET"])
# as should a real channel to which one is not joined: # as should a real channel to which one is not joined:
self.sendLine(bar, 'CHATHISTORY LATEST %s * 10' % (real_chname,)) self.sendLine(bar, "CHATHISTORY LATEST %s * 10" % (real_chname,))
msgs = self.getMessages(bar) msgs = self.getMessages(bar)
self.assertEqual(msgs[0].command, 'FAIL') self.assertEqual(msgs[0].command, "FAIL")
self.assertEqual(msgs[0].params[:2], ['CHATHISTORY', 'INVALID_TARGET']) self.assertEqual(msgs[0].params[:2], ["CHATHISTORY", "INVALID_TARGET"])
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testMessagesToSelf(self): def testMessagesToSelf(self):
bar, pw = random_name('bar'), random_name('pw') bar, pw = random_name("bar"), random_name("pw")
self.controller.registerUser(self, bar, pw) self.controller.registerUser(self, bar, pw)
self.connectClient(bar, name=bar, capabilities=['batch', 'labeled-response', 'message-tags', 'server-time'], password=pw) self.connectClient(
bar,
name=bar,
capabilities=["batch", "labeled-response", "message-tags", "server-time"],
password=pw,
)
self.getMessages(bar) self.getMessages(bar)
messages = [] messages = []
self.sendLine(bar, 'PRIVMSG %s :this is a privmsg sent to myself' % (bar,)) self.sendLine(bar, "PRIVMSG %s :this is a privmsg sent to myself" % (bar,))
replies = [msg for msg in self.getMessages(bar) if msg.command == 'PRIVMSG'] replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"]
self.assertEqual(len(replies), 1) self.assertEqual(len(replies), 1)
msg = replies[0] msg = replies[0]
self.assertEqual(msg.params, [bar, 'this is a privmsg sent to myself']) self.assertEqual(msg.params, [bar, "this is a privmsg sent to myself"])
messages.append(to_history_message(msg)) messages.append(to_history_message(msg))
self.sendLine(bar, 'CAP REQ echo-message') self.sendLine(bar, "CAP REQ echo-message")
self.getMessages(bar) self.getMessages(bar)
self.sendLine(bar, 'PRIVMSG %s :this is a second privmsg sent to myself' % (bar,)) self.sendLine(
replies = [msg for msg in self.getMessages(bar) if msg.command == 'PRIVMSG'] bar, "PRIVMSG %s :this is a second privmsg sent to myself" % (bar,)
)
replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"]
# two messages, the echo and the delivery # two messages, the echo and the delivery
self.assertEqual(len(replies), 2) self.assertEqual(len(replies), 2)
self.assertEqual(replies[0].params, [bar, 'this is a second privmsg sent to myself']) self.assertEqual(
replies[0].params, [bar, "this is a second privmsg sent to myself"]
)
messages.append(to_history_message(replies[0])) messages.append(to_history_message(replies[0]))
# messages should be otherwise identical # messages should be otherwise identical
self.assertEqual(to_history_message(replies[0]), to_history_message(replies[1])) self.assertEqual(to_history_message(replies[0]), to_history_message(replies[1]))
self.sendLine(bar, '@label=xyz PRIVMSG %s :this is a third privmsg sent to myself' % (bar,)) self.sendLine(
replies = [msg for msg in self.getMessages(bar) if msg.command == 'PRIVMSG'] bar,
"@label=xyz PRIVMSG %s :this is a third privmsg sent to myself" % (bar,),
)
replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"]
self.assertEqual(len(replies), 2) self.assertEqual(len(replies), 2)
# exactly one of the replies MUST be labeled # exactly one of the replies MUST be labeled
echo = [msg for msg in replies if msg.tags.get('label') == 'xyz'][0] echo = [msg for msg in replies if msg.tags.get("label") == "xyz"][0]
delivery = [msg for msg in replies if msg.tags.get('label') is None][0] delivery = [msg for msg in replies if msg.tags.get("label") is None][0]
self.assertEqual(echo.params, [bar, 'this is a third privmsg sent to myself']) self.assertEqual(echo.params, [bar, "this is a third privmsg sent to myself"])
messages.append(to_history_message(echo)) messages.append(to_history_message(echo))
self.assertEqual(to_history_message(echo), to_history_message(delivery)) self.assertEqual(to_history_message(echo), to_history_message(delivery))
# should receive exactly 3 messages in the correct order, no duplicates # should receive exactly 3 messages in the correct order, no duplicates
self.sendLine(bar, 'CHATHISTORY LATEST * * 10') self.sendLine(bar, "CHATHISTORY LATEST * * 10")
replies = [msg for msg in self.getMessages(bar) if msg.command == 'PRIVMSG'] replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"]
self.assertEqual([to_history_message(msg) for msg in replies], messages) self.assertEqual([to_history_message(msg) for msg in replies], messages)
self.sendLine(bar, 'CHATHISTORY LATEST %s * 10' % (bar,)) self.sendLine(bar, "CHATHISTORY LATEST %s * 10" % (bar,))
replies = [msg for msg in self.getMessages(bar) if msg.command == 'PRIVMSG'] replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"]
self.assertEqual([to_history_message(msg) for msg in replies], messages) self.assertEqual([to_history_message(msg) for msg in replies], messages)
def validate_echo_messages(self, num_messages, echo_messages): def validate_echo_messages(self, num_messages, echo_messages):
@ -111,31 +141,66 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.assertEqual(len(set(msg.msgid for msg in echo_messages)), num_messages) self.assertEqual(len(set(msg.msgid for msg in echo_messages)), num_messages)
self.assertEqual(len(set(msg.time for msg in echo_messages)), num_messages) self.assertEqual(len(set(msg.time for msg in echo_messages)), num_messages)
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testChathistory(self): def testChathistory(self):
self.connectClient('bar', capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP]) self.connectClient(
chname = '#' + secrets.token_hex(12) "bar",
capabilities=[
"message-tags",
"server-time",
"echo-message",
"batch",
"labeled-response",
CHATHISTORY_CAP,
EVENT_PLAYBACK_CAP,
],
)
chname = "#" + secrets.token_hex(12)
self.joinChannel(1, chname) self.joinChannel(1, chname)
self.getMessages(1) self.getMessages(1)
NUM_MESSAGES = 10 NUM_MESSAGES = 10
echo_messages = [] echo_messages = []
for i in range(NUM_MESSAGES): for i in range(NUM_MESSAGES):
self.sendLine(1, 'PRIVMSG %s :this is message %d' % (chname, i)) self.sendLine(1, "PRIVMSG %s :this is message %d" % (chname, i))
echo_messages.extend(to_history_message(msg) for msg in self.getMessages(1)) echo_messages.extend(to_history_message(msg) for msg in self.getMessages(1))
time.sleep(0.002) time.sleep(0.002)
self.validate_echo_messages(NUM_MESSAGES, echo_messages) self.validate_echo_messages(NUM_MESSAGES, echo_messages)
self.validate_chathistory(echo_messages, 1, chname) self.validate_chathistory(echo_messages, 1, chname)
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testChathistoryDMs(self): def testChathistoryDMs(self):
c1 = secrets.token_hex(12) c1 = secrets.token_hex(12)
c2 = secrets.token_hex(12) c2 = secrets.token_hex(12)
self.controller.registerUser(self, c1, 'sesame1') self.controller.registerUser(self, c1, "sesame1")
self.controller.registerUser(self, c2, 'sesame2') self.controller.registerUser(self, c2, "sesame2")
self.connectClient(c1, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP], password='sesame1') self.connectClient(
self.connectClient(c2, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP], password='sesame2') c1,
capabilities=[
"message-tags",
"server-time",
"echo-message",
"batch",
"labeled-response",
CHATHISTORY_CAP,
EVENT_PLAYBACK_CAP,
],
password="sesame1",
)
self.connectClient(
c2,
capabilities=[
"message-tags",
"server-time",
"echo-message",
"batch",
"labeled-response",
CHATHISTORY_CAP,
EVENT_PLAYBACK_CAP,
],
password="sesame2",
)
self.getMessages(1) self.getMessages(1)
self.getMessages(2) self.getMessages(2)
@ -148,29 +213,60 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
else: else:
target = c1 target = c1
self.getMessages(user) self.getMessages(user)
self.sendLine(user, 'PRIVMSG %s :this is message %d' % (target, i)) self.sendLine(user, "PRIVMSG %s :this is message %d" % (target, i))
echo_messages.extend(to_history_message(msg) for msg in self.getMessages(user)) echo_messages.extend(
to_history_message(msg) for msg in self.getMessages(user)
)
time.sleep(0.002) time.sleep(0.002)
self.validate_echo_messages(NUM_MESSAGES, echo_messages) self.validate_echo_messages(NUM_MESSAGES, echo_messages)
self.validate_chathistory(echo_messages, 1, c2) self.validate_chathistory(echo_messages, 1, c2)
self.validate_chathistory(echo_messages, 1, '*') self.validate_chathistory(echo_messages, 1, "*")
self.validate_chathistory(echo_messages, 2, c1) self.validate_chathistory(echo_messages, 2, c1)
self.validate_chathistory(echo_messages, 2, '*') self.validate_chathistory(echo_messages, 2, "*")
c3 = secrets.token_hex(12) c3 = secrets.token_hex(12)
self.connectClient(c3, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP]) self.connectClient(
self.sendLine(1, 'PRIVMSG %s :this is a message in a separate conversation' % (c3,)) c3,
capabilities=[
"message-tags",
"server-time",
"echo-message",
"batch",
"labeled-response",
CHATHISTORY_CAP,
EVENT_PLAYBACK_CAP,
],
)
self.sendLine(
1, "PRIVMSG %s :this is a message in a separate conversation" % (c3,)
)
self.getMessages(1) self.getMessages(1)
self.sendLine(3, 'PRIVMSG %s :i agree that this is a separate conversation' % (c1,)) self.sendLine(
3, "PRIVMSG %s :i agree that this is a separate conversation" % (c1,)
)
# 3 received the first message as a delivery and the second as an echo # 3 received the first message as a delivery and the second as an echo
new_convo = [to_history_message(msg) for msg in self.getMessages(3) if msg.command == 'PRIVMSG'] new_convo = [
self.assertEqual([msg.text for msg in new_convo], ['this is a message in a separate conversation', 'i agree that this is a separate conversation']) to_history_message(msg)
for msg in self.getMessages(3)
if msg.command == "PRIVMSG"
]
self.assertEqual(
[msg.text for msg in new_convo],
[
"this is a message in a separate conversation",
"i agree that this is a separate conversation",
],
)
# messages should be stored and retrievable by c1, even though c3 is not registered # messages should be stored and retrievable by c1, even though c3 is not registered
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'CHATHISTORY LATEST %s * 10' % (c3,)) self.sendLine(1, "CHATHISTORY LATEST %s * 10" % (c3,))
results = [to_history_message(msg) for msg in self.getMessages(1) if msg.command == 'PRIVMSG'] results = [
to_history_message(msg)
for msg in self.getMessages(1)
if msg.command == "PRIVMSG"
]
self.assertEqual(results, new_convo) self.assertEqual(results, new_convo)
# additional messages with c3 should not show up in the c1-c2 history: # additional messages with c3 should not show up in the c1-c2 history:
@ -179,14 +275,31 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.validate_chathistory(echo_messages, 2, c1.upper()) self.validate_chathistory(echo_messages, 2, c1.upper())
# regression test for #833 # regression test for #833
self.sendLine(3, 'QUIT') self.sendLine(3, "QUIT")
self.assertDisconnected(3) self.assertDisconnected(3)
# register c3 as an account, then attempt to retrieve the conversation history with c1 # register c3 as an account, then attempt to retrieve the conversation history with c1
self.controller.registerUser(self, c3, 'sesame3') self.controller.registerUser(self, c3, "sesame3")
self.connectClient(c3, name=c3, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP], password='sesame3') self.connectClient(
c3,
name=c3,
capabilities=[
"message-tags",
"server-time",
"echo-message",
"batch",
"labeled-response",
CHATHISTORY_CAP,
EVENT_PLAYBACK_CAP,
],
password="sesame3",
)
self.getMessages(c3) self.getMessages(c3)
self.sendLine(c3, 'CHATHISTORY LATEST %s * 10' % (c1,)) self.sendLine(c3, "CHATHISTORY LATEST %s * 10" % (c1,))
results = [to_history_message(msg) for msg in self.getMessages(c3) if msg.command == 'PRIVMSG'] results = [
to_history_message(msg)
for msg in self.getMessages(c3)
if msg.command == "PRIVMSG"
]
# should get nothing # should get nothing
self.assertEqual(results, []) self.assertEqual(results, [])
@ -205,105 +318,213 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[-1:], result) self.assertEqual(echo_messages[-1:], result)
self.sendLine(user, "CHATHISTORY LATEST %s msgid=%s %d" % (chname, echo_messages[4].msgid, INCLUSIVE_LIMIT)) self.sendLine(
user,
"CHATHISTORY LATEST %s msgid=%s %d"
% (chname, echo_messages[4].msgid, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[5:], result) self.assertEqual(echo_messages[5:], result)
self.sendLine(user, "CHATHISTORY LATEST %s timestamp=%s %d" % (chname, echo_messages[4].time, INCLUSIVE_LIMIT)) self.sendLine(
user,
"CHATHISTORY LATEST %s timestamp=%s %d"
% (chname, echo_messages[4].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[5:], result) self.assertEqual(echo_messages[5:], result)
self.sendLine(user, "CHATHISTORY BEFORE %s msgid=%s %d" % (chname, echo_messages[6].msgid, INCLUSIVE_LIMIT)) self.sendLine(
user,
"CHATHISTORY BEFORE %s msgid=%s %d"
% (chname, echo_messages[6].msgid, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[:6], result) self.assertEqual(echo_messages[:6], result)
self.sendLine(user, "CHATHISTORY BEFORE %s timestamp=%s %d" % (chname, echo_messages[6].time, INCLUSIVE_LIMIT)) self.sendLine(
user,
"CHATHISTORY BEFORE %s timestamp=%s %d"
% (chname, echo_messages[6].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[:6], result) self.assertEqual(echo_messages[:6], result)
self.sendLine(user, "CHATHISTORY BEFORE %s timestamp=%s %d" % (chname, echo_messages[6].time, 2)) self.sendLine(
user,
"CHATHISTORY BEFORE %s timestamp=%s %d"
% (chname, echo_messages[6].time, 2),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[4:6], result) self.assertEqual(echo_messages[4:6], result)
self.sendLine(user, "CHATHISTORY AFTER %s msgid=%s %d" % (chname, echo_messages[3].msgid, INCLUSIVE_LIMIT)) self.sendLine(
user,
"CHATHISTORY AFTER %s msgid=%s %d"
% (chname, echo_messages[3].msgid, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[4:], result) self.assertEqual(echo_messages[4:], result)
self.sendLine(user, "CHATHISTORY AFTER %s timestamp=%s %d" % (chname, echo_messages[3].time, INCLUSIVE_LIMIT)) self.sendLine(
user,
"CHATHISTORY AFTER %s timestamp=%s %d"
% (chname, echo_messages[3].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[4:], result) self.assertEqual(echo_messages[4:], result)
self.sendLine(user, "CHATHISTORY AFTER %s timestamp=%s %d" % (chname, echo_messages[3].time, 3)) self.sendLine(
user,
"CHATHISTORY AFTER %s timestamp=%s %d" % (chname, echo_messages[3].time, 3),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[4:7], result) self.assertEqual(echo_messages[4:7], result)
# BETWEEN forwards and backwards # BETWEEN forwards and backwards
self.sendLine(user, "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" % (chname, echo_messages[0].msgid, echo_messages[-1].msgid, INCLUSIVE_LIMIT)) self.sendLine(
user,
"CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d"
% (
chname,
echo_messages[0].msgid,
echo_messages[-1].msgid,
INCLUSIVE_LIMIT,
),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[1:-1], result) self.assertEqual(echo_messages[1:-1], result)
self.sendLine(user, "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" % (chname, echo_messages[-1].msgid, echo_messages[0].msgid, INCLUSIVE_LIMIT)) self.sendLine(
user,
"CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d"
% (
chname,
echo_messages[-1].msgid,
echo_messages[0].msgid,
INCLUSIVE_LIMIT,
),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[1:-1], result) self.assertEqual(echo_messages[1:-1], result)
# BETWEEN forwards and backwards with a limit, should get different results this time # BETWEEN forwards and backwards with a limit, should get different results this time
self.sendLine(user, "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" % (chname, echo_messages[0].msgid, echo_messages[-1].msgid, 3)) self.sendLine(
user,
"CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d"
% (chname, echo_messages[0].msgid, echo_messages[-1].msgid, 3),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[1:4], result) self.assertEqual(echo_messages[1:4], result)
self.sendLine(user, "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" % (chname, echo_messages[-1].msgid, echo_messages[0].msgid, 3)) self.sendLine(
user,
"CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d"
% (chname, echo_messages[-1].msgid, echo_messages[0].msgid, 3),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[-4:-1], result) self.assertEqual(echo_messages[-4:-1], result)
# same stuff again but with timestamps # same stuff again but with timestamps
self.sendLine(user, "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" % (chname, echo_messages[0].time, echo_messages[-1].time, INCLUSIVE_LIMIT)) self.sendLine(
user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[0].time, echo_messages[-1].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[1:-1], result) self.assertEqual(echo_messages[1:-1], result)
self.sendLine(user, "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" % (chname, echo_messages[-1].time, echo_messages[0].time, INCLUSIVE_LIMIT)) self.sendLine(
user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[-1].time, echo_messages[0].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[1:-1], result) self.assertEqual(echo_messages[1:-1], result)
self.sendLine(user, "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" % (chname, echo_messages[0].time, echo_messages[-1].time, 3)) self.sendLine(
user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[0].time, echo_messages[-1].time, 3),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[1:4], result) self.assertEqual(echo_messages[1:4], result)
self.sendLine(user, "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" % (chname, echo_messages[-1].time, echo_messages[0].time, 3)) self.sendLine(
user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[-1].time, echo_messages[0].time, 3),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[-4:-1], result) self.assertEqual(echo_messages[-4:-1], result)
# AROUND # AROUND
self.sendLine(user, "CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 1)) self.sendLine(
user,
"CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 1),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual([echo_messages[7]], result) self.assertEqual([echo_messages[7]], result)
self.sendLine(user, "CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 3)) self.sendLine(
user,
"CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 3),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertEqual(echo_messages[6:9], result) self.assertEqual(echo_messages[6:9], result)
self.sendLine(user, "CHATHISTORY AROUND %s timestamp=%s %d" % (chname, echo_messages[7].time, 3)) self.sendLine(
user,
"CHATHISTORY AROUND %s timestamp=%s %d"
% (chname, echo_messages[7].time, 3),
)
result = validate_chathistory_batch(self.getMessages(user)) result = validate_chathistory_batch(self.getMessages(user))
self.assertIn(echo_messages[7], result) self.assertIn(echo_messages[7], result)
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testChathistoryTagmsg(self): def testChathistoryTagmsg(self):
c1 = secrets.token_hex(12) c1 = secrets.token_hex(12)
c2 = secrets.token_hex(12) c2 = secrets.token_hex(12)
chname = '#' + secrets.token_hex(12) chname = "#" + secrets.token_hex(12)
self.controller.registerUser(self, c1, 'sesame1') self.controller.registerUser(self, c1, "sesame1")
self.controller.registerUser(self, c2, 'sesame2') self.controller.registerUser(self, c2, "sesame2")
self.connectClient(c1, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP], password='sesame1') self.connectClient(
self.connectClient(c2, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP,], password='sesame2') c1,
capabilities=[
"message-tags",
"server-time",
"echo-message",
"batch",
"labeled-response",
CHATHISTORY_CAP,
EVENT_PLAYBACK_CAP,
],
password="sesame1",
)
self.connectClient(
c2,
capabilities=[
"message-tags",
"server-time",
"echo-message",
"batch",
"labeled-response",
CHATHISTORY_CAP,
],
password="sesame2",
)
self.joinChannel(1, chname) self.joinChannel(1, chname)
self.joinChannel(2, chname) self.joinChannel(2, chname)
self.getMessages(1) self.getMessages(1)
self.getMessages(2) self.getMessages(2)
self.sendLine(1, '@+client-only-tag-test=success;+draft/persist TAGMSG %s' % (chname,)) self.sendLine(
1, "@+client-only-tag-test=success;+draft/persist TAGMSG %s" % (chname,)
)
echo = self.getMessages(1)[0] echo = self.getMessages(1)[0]
msgid = echo.tags['msgid'] msgid = echo.tags["msgid"]
def validate_tagmsg(msg, target, msgid): def validate_tagmsg(msg, target, msgid):
self.assertEqual(msg.command, 'TAGMSG') self.assertEqual(msg.command, "TAGMSG")
self.assertEqual(msg.tags['+client-only-tag-test'], 'success') self.assertEqual(msg.tags["+client-only-tag-test"], "success")
self.assertEqual(msg.tags['msgid'], msgid) self.assertEqual(msg.tags["msgid"], msgid)
self.assertEqual(msg.params, [target]) self.assertEqual(msg.params, [target])
validate_tagmsg(echo, chname, msgid) validate_tagmsg(echo, chname, msgid)
@ -312,69 +533,104 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.assertEqual(len(relay), 1) self.assertEqual(len(relay), 1)
validate_tagmsg(relay[0], chname, msgid) validate_tagmsg(relay[0], chname, msgid)
self.sendLine(1, 'CHATHISTORY LATEST %s * 10' % (chname,)) self.sendLine(1, "CHATHISTORY LATEST %s * 10" % (chname,))
history_tagmsgs = [msg for msg in self.getMessages(1) if msg.command == 'TAGMSG'] history_tagmsgs = [
msg for msg in self.getMessages(1) if msg.command == "TAGMSG"
]
self.assertEqual(len(history_tagmsgs), 1) self.assertEqual(len(history_tagmsgs), 1)
validate_tagmsg(history_tagmsgs[0], chname, msgid) validate_tagmsg(history_tagmsgs[0], chname, msgid)
# c2 doesn't have event-playback and MUST NOT receive replayed tagmsg # c2 doesn't have event-playback and MUST NOT receive replayed tagmsg
self.sendLine(2, 'CHATHISTORY LATEST %s * 10' % (chname,)) self.sendLine(2, "CHATHISTORY LATEST %s * 10" % (chname,))
history_tagmsgs = [msg for msg in self.getMessages(2) if msg.command == 'TAGMSG'] history_tagmsgs = [
msg for msg in self.getMessages(2) if msg.command == "TAGMSG"
]
self.assertEqual(len(history_tagmsgs), 0) self.assertEqual(len(history_tagmsgs), 0)
# now try a DM # now try a DM
self.sendLine(1, '@+client-only-tag-test=success;+draft/persist TAGMSG %s' % (c2,)) self.sendLine(
1, "@+client-only-tag-test=success;+draft/persist TAGMSG %s" % (c2,)
)
echo = self.getMessages(1)[0] echo = self.getMessages(1)[0]
msgid = echo.tags['msgid'] msgid = echo.tags["msgid"]
validate_tagmsg(echo, c2, msgid) validate_tagmsg(echo, c2, msgid)
relay = self.getMessages(2) relay = self.getMessages(2)
self.assertEqual(len(relay), 1) self.assertEqual(len(relay), 1)
validate_tagmsg(relay[0], c2, msgid) validate_tagmsg(relay[0], c2, msgid)
self.sendLine(1, 'CHATHISTORY LATEST %s * 10' % (c2,)) self.sendLine(1, "CHATHISTORY LATEST %s * 10" % (c2,))
history_tagmsgs = [msg for msg in self.getMessages(1) if msg.command == 'TAGMSG'] history_tagmsgs = [
msg for msg in self.getMessages(1) if msg.command == "TAGMSG"
]
self.assertEqual(len(history_tagmsgs), 1) self.assertEqual(len(history_tagmsgs), 1)
validate_tagmsg(history_tagmsgs[0], c2, msgid) validate_tagmsg(history_tagmsgs[0], c2, msgid)
# c2 doesn't have event-playback and MUST NOT receive replayed tagmsg # c2 doesn't have event-playback and MUST NOT receive replayed tagmsg
self.sendLine(2, 'CHATHISTORY LATEST %s * 10' % (c1,)) self.sendLine(2, "CHATHISTORY LATEST %s * 10" % (c1,))
history_tagmsgs = [msg for msg in self.getMessages(2) if msg.command == 'TAGMSG'] history_tagmsgs = [
msg for msg in self.getMessages(2) if msg.command == "TAGMSG"
]
self.assertEqual(len(history_tagmsgs), 0) self.assertEqual(len(history_tagmsgs), 0)
@cases.SpecificationSelector.requiredBySpecification("Oragono")
@cases.SpecificationSelector.requiredBySpecification('Oragono')
def testChathistoryDMClientOnlyTags(self): def testChathistoryDMClientOnlyTags(self):
# regression test for Oragono #1411 # regression test for Oragono #1411
c1 = secrets.token_hex(12) c1 = secrets.token_hex(12)
c2 = secrets.token_hex(12) c2 = secrets.token_hex(12)
self.controller.registerUser(self, c1, 'sesame1') self.controller.registerUser(self, c1, "sesame1")
self.controller.registerUser(self, c2, 'sesame2') self.controller.registerUser(self, c2, "sesame2")
self.connectClient(c1, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP], password='sesame1') self.connectClient(
self.connectClient(c2, capabilities=['message-tags', 'server-time', 'echo-message', 'batch', 'labeled-response', CHATHISTORY_CAP,], password='sesame2') c1,
capabilities=[
"message-tags",
"server-time",
"echo-message",
"batch",
"labeled-response",
CHATHISTORY_CAP,
EVENT_PLAYBACK_CAP,
],
password="sesame1",
)
self.connectClient(
c2,
capabilities=[
"message-tags",
"server-time",
"echo-message",
"batch",
"labeled-response",
CHATHISTORY_CAP,
],
password="sesame2",
)
self.getMessages(1) self.getMessages(1)
self.getMessages(2) self.getMessages(2)
echo_msgid = None echo_msgid = None
def validate_msg(msg):
self.assertEqual(msg.command, 'PRIVMSG')
self.assertEqual(msg.tags['+client-only-tag-test'], 'success')
self.assertEqual(msg.tags['msgid'], echo_msgid)
self.assertEqual(msg.params, [c2, 'hi'])
self.sendLine(1, '@+client-only-tag-test=success;+draft/persist PRIVMSG %s hi' % (c2,)) def validate_msg(msg):
self.assertEqual(msg.command, "PRIVMSG")
self.assertEqual(msg.tags["+client-only-tag-test"], "success")
self.assertEqual(msg.tags["msgid"], echo_msgid)
self.assertEqual(msg.params, [c2, "hi"])
self.sendLine(
1, "@+client-only-tag-test=success;+draft/persist PRIVMSG %s hi" % (c2,)
)
echo = self.getMessage(1) echo = self.getMessage(1)
echo_msgid = echo.tags['msgid'] echo_msgid = echo.tags["msgid"]
validate_msg(echo) validate_msg(echo)
relay = self.getMessage(2) relay = self.getMessage(2)
validate_msg(relay) validate_msg(relay)
self.sendLine(1, 'CHATHISTORY LATEST * * 10') self.sendLine(1, "CHATHISTORY LATEST * * 10")
hist = [msg for msg in self.getMessages(1) if msg.command == 'PRIVMSG'] hist = [msg for msg in self.getMessages(1) if msg.command == "PRIVMSG"]
self.assertEqual(len(hist), 1) self.assertEqual(len(hist), 1)
validate_msg(hist[0]) validate_msg(hist[0])
self.sendLine(2, 'CHATHISTORY LATEST * * 10') self.sendLine(2, "CHATHISTORY LATEST * * 10")
hist = [msg for msg in self.getMessages(2) if msg.command == 'PRIVMSG'] hist = [msg for msg in self.getMessages(2) if msg.command == "PRIVMSG"]
self.assertEqual(len(hist), 1) self.assertEqual(len(hist), 1)
validate_msg(hist[0]) validate_msg(hist[0])

View File

@ -1,31 +1,32 @@
from irctest import cases from irctest import cases
from irctest.numerics import RPL_WELCOME, ERR_NICKNAMEINUSE from irctest.numerics import RPL_WELCOME, ERR_NICKNAMEINUSE
class ConfusablesTestCase(cases.BaseServerTestCase): class ConfusablesTestCase(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config():
return { return {
"oragono_config": lambda config: config['accounts'].update( "oragono_config": lambda config: config["accounts"].update(
{'nick-reservation': {'enabled': True, 'method': 'strict'}} {"nick-reservation": {"enabled": True, "method": "strict"}}
) )
} }
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testConfusableNicks(self): def testConfusableNicks(self):
self.controller.registerUser(self, 'evan', 'sesame') self.controller.registerUser(self, "evan", "sesame")
self.addClient(1) self.addClient(1)
# U+0435 in place of e: # U+0435 in place of e:
self.sendLine(1, 'NICK еvan') self.sendLine(1, "NICK еvan")
self.sendLine(1, 'USER a 0 * a') self.sendLine(1, "USER a 0 * a")
messages = self.getMessages(1) messages = self.getMessages(1)
commands = set(msg.command for msg in messages) commands = set(msg.command for msg in messages)
self.assertNotIn(RPL_WELCOME, commands) self.assertNotIn(RPL_WELCOME, commands)
self.assertIn(ERR_NICKNAMEINUSE, commands) self.assertIn(ERR_NICKNAMEINUSE, commands)
self.connectClient('evan', name='evan', password='sesame') self.connectClient("evan", name="evan", password="sesame")
# should be able to switch to the confusable nick # should be able to switch to the confusable nick
self.sendLine('evan', 'NICK еvan') self.sendLine("evan", "NICK еvan")
messages = self.getMessages('evan') messages = self.getMessages("evan")
commands = set(msg.command for msg in messages) commands = set(msg.command for msg in messages)
self.assertIn('NICK', commands) self.assertIn("NICK", commands)

View File

@ -6,39 +6,48 @@ Tests section 4.1 of RFC 1459.
from irctest import cases from irctest import cases
from irctest.client_mock import ConnectionClosed from irctest.client_mock import ConnectionClosed
class PasswordedConnectionRegistrationTestCase(cases.BaseServerTestCase): class PasswordedConnectionRegistrationTestCase(cases.BaseServerTestCase):
password = 'testpassword' password = "testpassword"
@cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812')
@cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812")
def testPassBeforeNickuser(self): def testPassBeforeNickuser(self):
self.addClient() self.addClient()
self.sendLine(1, 'PASS {}'.format(self.password)) self.sendLine(1, "PASS {}".format(self.password))
self.sendLine(1, 'NICK foo') self.sendLine(1, "NICK foo")
self.sendLine(1, 'USER username * * :Realname') self.sendLine(1, "USER username * * :Realname")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='001', self.assertMessageEqual(
fail_msg='Did not get 001 after correct PASS+NICK+USER: {msg}') m,
command="001",
fail_msg="Did not get 001 after correct PASS+NICK+USER: {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812")
def testNoPassword(self): def testNoPassword(self):
self.addClient() self.addClient()
self.sendLine(1, 'NICK foo') self.sendLine(1, "NICK foo")
self.sendLine(1, 'USER username * * :Realname') self.sendLine(1, "USER username * * :Realname")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertNotEqual(m.command, '001', self.assertNotEqual(
msg='Got 001 after NICK+USER but missing PASS') m.command, "001", msg="Got 001 after NICK+USER but missing PASS"
)
@cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812")
def testWrongPassword(self): def testWrongPassword(self):
self.addClient() self.addClient()
self.sendLine(1, 'PASS {}'.format(self.password + "garbage")) self.sendLine(1, "PASS {}".format(self.password + "garbage"))
self.sendLine(1, 'NICK foo') self.sendLine(1, "NICK foo")
self.sendLine(1, 'USER username * * :Realname') self.sendLine(1, "USER username * * :Realname")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertNotEqual(m.command, '001', self.assertNotEqual(
msg='Got 001 after NICK+USER but incorrect PASS') m.command, "001", msg="Got 001 after NICK+USER but incorrect PASS"
)
@cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812', strict=True) @cases.SpecificationSelector.requiredBySpecification(
"RFC1459", "RFC2812", strict=True
)
def testPassAfterNickuser(self): def testPassAfterNickuser(self):
"""“The password can and must be set before any attempt to register """“The password can and must be set before any attempt to register
the connection is made. the connection is made.
@ -51,72 +60,77 @@ class PasswordedConnectionRegistrationTestCase(cases.BaseServerTestCase):
-- <https://tools.ietf.org/html/rfc2812#section-3.1.1> -- <https://tools.ietf.org/html/rfc2812#section-3.1.1>
""" """
self.addClient() self.addClient()
self.sendLine(1, 'NICK foo') self.sendLine(1, "NICK foo")
self.sendLine(1, 'USER username * * :Realname') self.sendLine(1, "USER username * * :Realname")
self.sendLine(1, 'PASS {}'.format(self.password)) self.sendLine(1, "PASS {}".format(self.password))
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertNotEqual(m.command, '001', self.assertNotEqual(m.command, "001", "Got 001 after PASS sent after NICK+USER")
'Got 001 after PASS sent after NICK+USER')
class ConnectionRegistrationTestCase(cases.BaseServerTestCase): class ConnectionRegistrationTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification('RFC1459') @cases.SpecificationSelector.requiredBySpecification("RFC1459")
def testQuitDisconnects(self): def testQuitDisconnects(self):
"""“The server must close the connection to a client which sends a """“The server must close the connection to a client which sends a
QUIT message. QUIT message.
-- <https://tools.ietf.org/html/rfc1459#section-4.1.3> -- <https://tools.ietf.org/html/rfc1459#section-4.1.3>
""" """
self.connectClient('foo') self.connectClient("foo")
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'QUIT') self.sendLine(1, "QUIT")
with self.assertRaises(ConnectionClosed): with self.assertRaises(ConnectionClosed):
self.getMessages(1) # Fetch remaining messages self.getMessages(1) # Fetch remaining messages
self.getMessages(1) self.getMessages(1)
@cases.SpecificationSelector.requiredBySpecification('RFC2812') @cases.SpecificationSelector.requiredBySpecification("RFC2812")
def testQuitErrors(self): def testQuitErrors(self):
"""“A client session is terminated with a quit message. The server """“A client session is terminated with a quit message. The server
acknowledges this by sending an ERROR message to the client. acknowledges this by sending an ERROR message to the client.
-- <https://tools.ietf.org/html/rfc2812#section-3.1.7> -- <https://tools.ietf.org/html/rfc2812#section-3.1.7>
""" """
self.connectClient('foo') self.connectClient("foo")
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'QUIT') self.sendLine(1, "QUIT")
try: try:
commands = {m.command for m in self.getMessages(1)} commands = {m.command for m in self.getMessages(1)}
except ConnectionClosed: except ConnectionClosed:
assert False, 'Connection closed without ERROR.' assert False, "Connection closed without ERROR."
self.assertIn('ERROR', commands, self.assertIn(
fail_msg='Did not receive ERROR as a reply to QUIT.') "ERROR", commands, fail_msg="Did not receive ERROR as a reply to QUIT."
)
def testNickCollision(self): def testNickCollision(self):
"""A user connects and requests the same nickname as an already """A user connects and requests the same nickname as an already
registered user. registered user.
""" """
self.connectClient('foo') self.connectClient("foo")
self.addClient() self.addClient()
self.sendLine(2, 'NICK foo') self.sendLine(2, "NICK foo")
self.sendLine(2, 'USER username * * :Realname') self.sendLine(2, "USER username * * :Realname")
m = self.getRegistrationMessage(2) m = self.getRegistrationMessage(2)
self.assertNotEqual(m.command, '001', self.assertNotEqual(
'Received 001 after registering with the nick of a ' m.command,
'registered user.') "001",
"Received 001 after registering with the nick of a " "registered user.",
)
def testEarlyNickCollision(self): def testEarlyNickCollision(self):
"""Two users register simultaneously with the same nick.""" """Two users register simultaneously with the same nick."""
self.addClient() self.addClient()
self.addClient() self.addClient()
self.sendLine(1, 'NICK foo') self.sendLine(1, "NICK foo")
self.sendLine(2, 'NICK foo') self.sendLine(2, "NICK foo")
self.sendLine(1, 'USER username * * :Realname') self.sendLine(1, "USER username * * :Realname")
self.sendLine(2, 'USER username * * :Realname') self.sendLine(2, "USER username * * :Realname")
m1 = self.getRegistrationMessage(1) m1 = self.getRegistrationMessage(1)
m2 = self.getRegistrationMessage(2) m2 = self.getRegistrationMessage(2)
self.assertNotEqual((m1.command, m2.command), ('001', '001'), self.assertNotEqual(
'Two concurrently registering requesting the same nickname ' (m1.command, m2.command),
'both got 001.') ("001", "001"),
"Two concurrently registering requesting the same nickname "
"both got 001.",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1', 'IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1", "IRCv3.2")
def testIrc301CapLs(self): def testIrc301CapLs(self):
"""IRCv3.1: “The LS subcommand is used to list the capabilities """IRCv3.1: “The LS subcommand is used to list the capabilities
supported by the server. The client should send an LS subcommand with supported by the server. The client should send an LS subcommand with
@ -128,24 +142,34 @@ class ConnectionRegistrationTestCase(cases.BaseServerTestCase):
-- <http://ircv3.net/specs/core/capability-negotiation-3.2.html#version-in-cap-ls> -- <http://ircv3.net/specs/core/capability-negotiation-3.2.html#version-in-cap-ls>
""" """
self.addClient() self.addClient()
self.sendLine(1, 'CAP LS') self.sendLine(1, "CAP LS")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertNotEqual(m.params[2], '*', m, self.assertNotEqual(
fail_msg='Server replied with multi-line CAP LS to a ' m.params[2],
'“CAP LS” (ie. IRCv3.1) request: {msg}') "*",
self.assertFalse(any('=' in cap for cap in m.params[2].split()), m,
'Server replied with a name-value capability in ' fail_msg="Server replied with multi-line CAP LS to a "
'CAP LS reply as a response to “CAP LS” (ie. IRCv3.1) ' "“CAP LS” (ie. IRCv3.1) request: {msg}",
'request: {}'.format(m)) )
self.assertFalse(
any("=" in cap for cap in m.params[2].split()),
"Server replied with a name-value capability in "
"CAP LS reply as a response to “CAP LS” (ie. IRCv3.1) "
"request: {}".format(m),
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
def testEmptyCapList(self): def testEmptyCapList(self):
"""“If no capabilities are active, an empty parameter must be sent.” """“If no capabilities are active, an empty parameter must be sent.”
-- <http://ircv3.net/specs/core/capability-negotiation-3.1.html#the-cap-list-subcommand> -- <http://ircv3.net/specs/core/capability-negotiation-3.1.html#the-cap-list-subcommand>
""" """
self.addClient() self.addClient()
self.sendLine(1, 'CAP LIST') self.sendLine(1, "CAP LIST")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='CAP', params=['*', 'LIST', ''], self.assertMessageEqual(
fail_msg='Sending “CAP LIST” as first message got a reply ' m,
'that is not “CAP * LIST :”: {msg}') command="CAP",
params=["*", "LIST", ""],
fail_msg="Sending “CAP LIST” as first message got a reply "
"that is not “CAP * LIST :”: {msg}",
)

View File

@ -6,65 +6,93 @@ from irctest import cases
from irctest.basecontrollers import NotImplementedByController from irctest.basecontrollers import NotImplementedByController
from irctest.irc_utils.junkdrawer import random_name from irctest.irc_utils.junkdrawer import random_name
class DMEchoMessageTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification('Oragono') class DMEchoMessageTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification("Oragono")
def testDirectMessageEcho(self): def testDirectMessageEcho(self):
bar = random_name('bar') bar = random_name("bar")
self.connectClient(bar, name=bar, capabilities=['batch', 'labeled-response', 'echo-message', 'message-tags', 'server-time']) self.connectClient(
bar,
name=bar,
capabilities=[
"batch",
"labeled-response",
"echo-message",
"message-tags",
"server-time",
],
)
self.getMessages(bar) self.getMessages(bar)
qux = random_name('qux') qux = random_name("qux")
self.connectClient(qux, name=qux, capabilities=['batch', 'labeled-response', 'echo-message', 'message-tags', 'server-time']) self.connectClient(
qux,
name=qux,
capabilities=[
"batch",
"labeled-response",
"echo-message",
"message-tags",
"server-time",
],
)
self.getMessages(qux) self.getMessages(qux)
self.sendLine(bar, '@label=xyz;+example-client-tag=example-value PRIVMSG %s :hi there' % (qux,)) self.sendLine(
bar,
"@label=xyz;+example-client-tag=example-value PRIVMSG %s :hi there"
% (qux,),
)
echo = self.getMessages(bar)[0] echo = self.getMessages(bar)[0]
delivery = self.getMessages(qux)[0] delivery = self.getMessages(qux)[0]
self.assertEqual(delivery.params, [qux, 'hi there']) self.assertEqual(delivery.params, [qux, "hi there"])
self.assertEqual(delivery.params, echo.params) self.assertEqual(delivery.params, echo.params)
self.assertEqual(delivery.tags['msgid'], echo.tags['msgid']) self.assertEqual(delivery.tags["msgid"], echo.tags["msgid"])
self.assertEqual(echo.tags['label'], 'xyz') self.assertEqual(echo.tags["label"], "xyz")
self.assertEqual(delivery.tags['+example-client-tag'], 'example-value') self.assertEqual(delivery.tags["+example-client-tag"], "example-value")
self.assertEqual(delivery.tags['+example-client-tag'], echo.tags['+example-client-tag']) self.assertEqual(
delivery.tags["+example-client-tag"], echo.tags["+example-client-tag"]
)
class EchoMessageTestCase(cases.BaseServerTestCase): class EchoMessageTestCase(cases.BaseServerTestCase):
def _testEchoMessage(command, solo, server_time): def _testEchoMessage(command, solo, server_time):
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def f(self): def f(self):
"""<http://ircv3.net/specs/extensions/echo-message-3.2.html> """<http://ircv3.net/specs/extensions/echo-message-3.2.html>"""
"""
self.addClient() self.addClient()
self.sendLine(1, 'CAP LS 302') self.sendLine(1, "CAP LS 302")
capabilities = self.getCapLs(1) capabilities = self.getCapLs(1)
if 'echo-message' not in capabilities: if "echo-message" not in capabilities:
raise NotImplementedByController('echo-message') raise NotImplementedByController("echo-message")
if server_time and 'server-time' not in capabilities: if server_time and "server-time" not in capabilities:
raise NotImplementedByController('server-time') raise NotImplementedByController("server-time")
# TODO: check also without this # TODO: check also without this
self.sendLine(1, 'CAP REQ :echo-message{}'.format( self.sendLine(
' server-time' if server_time else '')) 1,
"CAP REQ :echo-message{}".format(" server-time" if server_time else ""),
)
self.getRegistrationMessage(1) self.getRegistrationMessage(1)
# TODO: Remove this one the trailing space issue is fixed in Charybdis # TODO: Remove this one the trailing space issue is fixed in Charybdis
# and Mammon: # and Mammon:
#self.assertMessageEqual(m, command='CAP', # self.assertMessageEqual(m, command='CAP',
# params=['*', 'ACK', 'echo-message'] + # params=['*', 'ACK', 'echo-message'] +
# (['server-time'] if server_time else []), # (['server-time'] if server_time else []),
# fail_msg='Did not ACK advertised capabilities: {msg}') # fail_msg='Did not ACK advertised capabilities: {msg}')
self.sendLine(1, 'USER f * * :foo') self.sendLine(1, "USER f * * :foo")
self.sendLine(1, 'NICK baz') self.sendLine(1, "NICK baz")
self.sendLine(1, 'CAP END') self.sendLine(1, "CAP END")
self.skipToWelcome(1) self.skipToWelcome(1)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'JOIN #chan') self.sendLine(1, "JOIN #chan")
if not solo: if not solo:
capabilities = ['server-time'] if server_time else None capabilities = ["server-time"] if server_time else None
self.connectClient('qux', capabilities=capabilities) self.connectClient("qux", capabilities=capabilities)
self.sendLine(2, 'JOIN #chan') self.sendLine(2, "JOIN #chan")
# Synchronize and clean # Synchronize and clean
self.getMessages(1) self.getMessages(1)
@ -72,30 +100,50 @@ class EchoMessageTestCase(cases.BaseServerTestCase):
self.getMessages(2) self.getMessages(2)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, '{} #chan :hello everyone'.format(command)) self.sendLine(1, "{} #chan :hello everyone".format(command))
m1 = self.getMessage(1) m1 = self.getMessage(1)
self.assertMessageEqual(m1, command=command, self.assertMessageEqual(
params=['#chan', 'hello everyone'], m1,
fail_msg='Did not echo “{} #chan :hello everyone”: {msg}', command=command,
extra_format=(command,)) params=["#chan", "hello everyone"],
fail_msg="Did not echo “{} #chan :hello everyone”: {msg}",
extra_format=(command,),
)
if not solo: if not solo:
m2 = self.getMessage(2) m2 = self.getMessage(2)
self.assertMessageEqual(m2, command=command, self.assertMessageEqual(
params=['#chan', 'hello everyone'], m2,
fail_msg='Did not propagate “{} #chan :hello everyone”: ' command=command,
'after echoing it to the author: {msg}', params=["#chan", "hello everyone"],
extra_format=(command,)) fail_msg="Did not propagate “{} #chan :hello everyone”: "
self.assertEqual(m1.params, m2.params, "after echoing it to the author: {msg}",
fail_msg='Parameters of forwarded and echoed ' extra_format=(command,),
'messages differ: {} {}', )
extra_format=(m1, m2)) self.assertEqual(
m1.params,
m2.params,
fail_msg="Parameters of forwarded and echoed "
"messages differ: {} {}",
extra_format=(m1, m2),
)
if server_time: if server_time:
self.assertIn('time', m1.tags, fail_msg='Echoed message is missing server time: {}', extra_format=(m1,)) self.assertIn(
self.assertIn('time', m2.tags, fail_msg='Forwarded message is missing server time: {}', extra_format=(m2,)) "time",
m1.tags,
fail_msg="Echoed message is missing server time: {}",
extra_format=(m1,),
)
self.assertIn(
"time",
m2.tags,
fail_msg="Forwarded message is missing server time: {}",
extra_format=(m2,),
)
return f return f
testEchoMessagePrivmsgNoServerTime = _testEchoMessage('PRIVMSG', False, False) testEchoMessagePrivmsgNoServerTime = _testEchoMessage("PRIVMSG", False, False)
testEchoMessagePrivmsgSolo = _testEchoMessage('PRIVMSG', True, True) testEchoMessagePrivmsgSolo = _testEchoMessage("PRIVMSG", True, True)
testEchoMessagePrivmsg = _testEchoMessage('PRIVMSG', False, True) testEchoMessagePrivmsg = _testEchoMessage("PRIVMSG", False, True)
testEchoMessageNotice = _testEchoMessage('NOTICE', False, True) testEchoMessageNotice = _testEchoMessage("NOTICE", False, True)

View File

@ -4,52 +4,64 @@
from irctest import cases from irctest import cases
class MetadataTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): class MetadataTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
def connectRegisteredClient(self, nick): def connectRegisteredClient(self, nick):
self.addClient() self.addClient()
self.sendLine(2, 'CAP LS 302') self.sendLine(2, "CAP LS 302")
capabilities = self.getCapLs(2) capabilities = self.getCapLs(2)
assert 'sasl' in capabilities assert "sasl" in capabilities
self.sendLine(2, 'AUTHENTICATE PLAIN') self.sendLine(2, "AUTHENTICATE PLAIN")
m = self.getMessage(2, filter_pred=lambda m:m.command != 'NOTICE') m = self.getMessage(2, filter_pred=lambda m: m.command != "NOTICE")
self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'], self.assertMessageEqual(
fail_msg='Sent “AUTHENTICATE PLAIN”, server should have ' m,
'replied with “AUTHENTICATE +”, but instead sent: {msg}') command="AUTHENTICATE",
self.sendLine(2, 'AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=') params=["+"],
m = self.getMessage(2, filter_pred=lambda m:m.command != 'NOTICE') fail_msg="Sent “AUTHENTICATE PLAIN”, server should have "
self.assertMessageEqual(m, command='900', "replied with “AUTHENTICATE +”, but instead sent: {msg}",
fail_msg='Did not send 900 after correct SASL authentication.') )
self.sendLine(2, 'USER f * * :Realname') self.sendLine(2, "AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=")
self.sendLine(2, 'NICK {}'.format(nick)) m = self.getMessage(2, filter_pred=lambda m: m.command != "NOTICE")
self.sendLine(2, 'CAP END') self.assertMessageEqual(
m,
command="900",
fail_msg="Did not send 900 after correct SASL authentication.",
)
self.sendLine(2, "USER f * * :Realname")
self.sendLine(2, "NICK {}".format(nick))
self.sendLine(2, "CAP END")
self.skipToWelcome(2) self.skipToWelcome(2)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
def testNotLoggedIn(self): def testNotLoggedIn(self):
self.connectClient('foo', capabilities=['extended-join'], self.connectClient("foo", capabilities=["extended-join"], skip_if_cap_nak=True)
skip_if_cap_nak=True) self.joinChannel(1, "#chan")
self.joinChannel(1, '#chan') self.connectClient("bar")
self.connectClient('bar') self.joinChannel(2, "#chan")
self.joinChannel(2, '#chan')
m = self.getMessage(1) m = self.getMessage(1)
self.assertMessageEqual(m, command='JOIN', self.assertMessageEqual(
params=['#chan', '*', 'Realname'], m,
fail_msg='Expected “JOIN #chan * :Realname” after ' command="JOIN",
'unregistered user joined, got: {msg}') params=["#chan", "*", "Realname"],
fail_msg="Expected “JOIN #chan * :Realname” after "
"unregistered user joined, got: {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testLoggedIn(self): def testLoggedIn(self):
self.connectClient('foo', capabilities=['extended-join'], self.connectClient("foo", capabilities=["extended-join"], skip_if_cap_nak=True)
skip_if_cap_nak=True) self.joinChannel(1, "#chan")
self.joinChannel(1, '#chan')
self.controller.registerUser(self, 'jilles', 'sesame') self.controller.registerUser(self, "jilles", "sesame")
self.connectRegisteredClient('bar') self.connectRegisteredClient("bar")
self.joinChannel(2, '#chan') self.joinChannel(2, "#chan")
m = self.getMessage(1) m = self.getMessage(1)
self.assertMessageEqual(m, command='JOIN', self.assertMessageEqual(
params=['#chan', 'jilles', 'Realname'], m,
fail_msg='Expected “JOIN #chan * :Realname” after ' command="JOIN",
'nick “bar” logged in as “jilles” joined, got: {msg}') params=["#chan", "jilles", "Realname"],
fail_msg="Expected “JOIN #chan * :Realname” after "
"nick “bar” logged in as “jilles” joined, got: {msg}",
)

View File

@ -6,240 +6,570 @@ import re
from irctest import cases from irctest import cases
class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testLabeledPrivmsgResponsesToMultipleClients(self): def testLabeledPrivmsgResponsesToMultipleClients(self):
self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"foo",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(1) self.getMessages(1)
self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"bar",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(2) self.getMessages(2)
self.connectClient('carl', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"carl",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(3) self.getMessages(3)
self.connectClient('alice', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"alice",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(4) self.getMessages(4)
self.sendLine(1, '@label=12345 PRIVMSG bar,carl,alice :hi') self.sendLine(1, "@label=12345 PRIVMSG bar,carl,alice :hi")
m = self.getMessage(1) m = self.getMessage(1)
m2 = self.getMessage(2) m2 = self.getMessage(2)
m3 = self.getMessage(3) m3 = self.getMessage(3)
m4 = self.getMessage(4) m4 = self.getMessage(4)
# ensure the label isn't sent to recipients # ensure the label isn't sent to recipients
self.assertMessageEqual(m2, command='PRIVMSG', fail_msg='No PRIVMSG received by target 1 after sending one out') self.assertMessageEqual(
self.assertNotIn('label', m2.tags, m2, fail_msg="When sending a PRIVMSG with a label, the target users shouldn't receive the label (only the sending user should): {msg}") m2,
self.assertMessageEqual(m3, command='PRIVMSG', fail_msg='No PRIVMSG received by target 1 after sending one out') command="PRIVMSG",
self.assertNotIn('label', m3.tags, m3, fail_msg="When sending a PRIVMSG with a label, the target users shouldn't receive the label (only the sending user should): {msg}") fail_msg="No PRIVMSG received by target 1 after sending one out",
self.assertMessageEqual(m4, command='PRIVMSG', fail_msg='No PRIVMSG received by target 1 after sending one out') )
self.assertNotIn('label', m4.tags, m4, fail_msg="When sending a PRIVMSG with a label, the target users shouldn't receive the label (only the sending user should): {msg}") self.assertNotIn(
"label",
m2.tags,
m2,
fail_msg="When sending a PRIVMSG with a label, the target users shouldn't receive the label (only the sending user should): {msg}",
)
self.assertMessageEqual(
m3,
command="PRIVMSG",
fail_msg="No PRIVMSG received by target 1 after sending one out",
)
self.assertNotIn(
"label",
m3.tags,
m3,
fail_msg="When sending a PRIVMSG with a label, the target users shouldn't receive the label (only the sending user should): {msg}",
)
self.assertMessageEqual(
m4,
command="PRIVMSG",
fail_msg="No PRIVMSG received by target 1 after sending one out",
)
self.assertNotIn(
"label",
m4.tags,
m4,
fail_msg="When sending a PRIVMSG with a label, the target users shouldn't receive the label (only the sending user should): {msg}",
)
self.assertMessageEqual(m, command='BATCH', fail_msg='No BATCH echo received after sending one out') self.assertMessageEqual(
m, command="BATCH", fail_msg="No BATCH echo received after sending one out"
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testLabeledPrivmsgResponsesToClient(self): def testLabeledPrivmsgResponsesToClient(self):
self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"foo",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(1) self.getMessages(1)
self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"bar",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(2) self.getMessages(2)
self.sendLine(1, '@label=12345 PRIVMSG bar :hi') self.sendLine(1, "@label=12345 PRIVMSG bar :hi")
m = self.getMessage(1) m = self.getMessage(1)
m2 = self.getMessage(2) m2 = self.getMessage(2)
# ensure the label isn't sent to recipient # ensure the label isn't sent to recipient
self.assertMessageEqual(m2, command='PRIVMSG', fail_msg='No PRIVMSG received by the target after sending one out') self.assertMessageEqual(
self.assertNotIn('label', m2.tags, m2, fail_msg="When sending a PRIVMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}") m2,
command="PRIVMSG",
fail_msg="No PRIVMSG received by the target after sending one out",
)
self.assertNotIn(
"label",
m2.tags,
m2,
fail_msg="When sending a PRIVMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}",
)
self.assertMessageEqual(m, command='PRIVMSG', fail_msg='No PRIVMSG echo received after sending one out') self.assertMessageEqual(
self.assertIn('label', m.tags, m, fail_msg="When sending a PRIVMSG with a label, the echo'd message didn't contain the label at all: {msg}") m,
self.assertEqual(m.tags['label'], '12345', m, fail_msg="Echo'd PRIVMSG to a client did not contain the same label we sent it with(should be '12345'): {msg}") command="PRIVMSG",
fail_msg="No PRIVMSG echo received after sending one out",
)
self.assertIn(
"label",
m.tags,
m,
fail_msg="When sending a PRIVMSG with a label, the echo'd message didn't contain the label at all: {msg}",
)
self.assertEqual(
m.tags["label"],
"12345",
m,
fail_msg="Echo'd PRIVMSG to a client did not contain the same label we sent it with(should be '12345'): {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testLabeledPrivmsgResponsesToChannel(self): def testLabeledPrivmsgResponsesToChannel(self):
self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"foo",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(1) self.getMessages(1)
self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"bar",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(2) self.getMessages(2)
# join channels # join channels
self.sendLine(1, 'JOIN #test') self.sendLine(1, "JOIN #test")
self.getMessages(1) self.getMessages(1)
self.sendLine(2, 'JOIN #test') self.sendLine(2, "JOIN #test")
self.getMessages(2) self.getMessages(2)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, '@label=12345;+draft/reply=123;+draft/react=l😃l PRIVMSG #test :hi') self.sendLine(
1, "@label=12345;+draft/reply=123;+draft/react=l😃l PRIVMSG #test :hi"
)
ms = self.getMessage(1) ms = self.getMessage(1)
mt = self.getMessage(2) mt = self.getMessage(2)
# ensure the label isn't sent to recipient # ensure the label isn't sent to recipient
self.assertMessageEqual(mt, command='PRIVMSG', fail_msg='No PRIVMSG received by the target after sending one out') self.assertMessageEqual(
self.assertNotIn('label', mt.tags, mt, fail_msg="When sending a PRIVMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}") mt,
command="PRIVMSG",
fail_msg="No PRIVMSG received by the target after sending one out",
)
self.assertNotIn(
"label",
mt.tags,
mt,
fail_msg="When sending a PRIVMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}",
)
# ensure sender correctly receives msg # ensure sender correctly receives msg
self.assertMessageEqual(ms, command='PRIVMSG', fail_msg="Got a message back that wasn't a PRIVMSG") self.assertMessageEqual(
self.assertIn('label', ms.tags, ms, fail_msg="When sending a PRIVMSG with a label, the source user should receive the label but didn't: {msg}") ms, command="PRIVMSG", fail_msg="Got a message back that wasn't a PRIVMSG"
self.assertEqual(ms.tags['label'], '12345', ms, fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}") )
self.assertIn(
"label",
ms.tags,
ms,
fail_msg="When sending a PRIVMSG with a label, the source user should receive the label but didn't: {msg}",
)
self.assertEqual(
ms.tags["label"],
"12345",
ms,
fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testLabeledPrivmsgResponsesToSelf(self): def testLabeledPrivmsgResponsesToSelf(self):
self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"foo",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, '@label=12345 PRIVMSG foo :hi') self.sendLine(1, "@label=12345 PRIVMSG foo :hi")
m1 = self.getMessage(1) m1 = self.getMessage(1)
m2 = self.getMessage(1) m2 = self.getMessage(1)
number_of_labels = 0 number_of_labels = 0
for m in [m1, m2]: for m in [m1, m2]:
self.assertMessageEqual(m, command='PRIVMSG', fail_msg="Got a message back that wasn't a PRIVMSG") self.assertMessageEqual(
if 'label' in m.tags: m,
command="PRIVMSG",
fail_msg="Got a message back that wasn't a PRIVMSG",
)
if "label" in m.tags:
number_of_labels += 1 number_of_labels += 1
self.assertEqual(m.tags['label'], '12345', m, fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}") self.assertEqual(
m.tags["label"],
self.assertEqual(number_of_labels, 1, m1, fail_msg="When sending a PRIVMSG to self with echo-message, we only expect one message to contain the label. Instead, {} messages had the label".format(number_of_labels)) "12345",
m,
fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') self.assertEqual(
number_of_labels,
1,
m1,
fail_msg="When sending a PRIVMSG to self with echo-message, we only expect one message to contain the label. Instead, {} messages had the label".format(
number_of_labels
),
)
@cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testLabeledNoticeResponsesToClient(self): def testLabeledNoticeResponsesToClient(self):
self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"foo",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(1) self.getMessages(1)
self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"bar",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(2) self.getMessages(2)
self.sendLine(1, '@label=12345 NOTICE bar :hi') self.sendLine(1, "@label=12345 NOTICE bar :hi")
m = self.getMessage(1) m = self.getMessage(1)
m2 = self.getMessage(2) m2 = self.getMessage(2)
# ensure the label isn't sent to recipient # ensure the label isn't sent to recipient
self.assertMessageEqual(m2, command='NOTICE', fail_msg='No NOTICE received by the target after sending one out') self.assertMessageEqual(
self.assertNotIn('label', m2.tags, m2, fail_msg="When sending a NOTICE with a label, the target user shouldn't receive the label (only the sending user should): {msg}") m2,
command="NOTICE",
fail_msg="No NOTICE received by the target after sending one out",
)
self.assertNotIn(
"label",
m2.tags,
m2,
fail_msg="When sending a NOTICE with a label, the target user shouldn't receive the label (only the sending user should): {msg}",
)
self.assertMessageEqual(m, command='NOTICE', fail_msg='No NOTICE echo received after sending one out') self.assertMessageEqual(
self.assertIn('label', m.tags, m, fail_msg="When sending a NOTICE with a label, the echo'd message didn't contain the label at all: {msg}") m,
self.assertEqual(m.tags['label'], '12345', m, fail_msg="Echo'd NOTICE to a client did not contain the same label we sent it with(should be '12345'): {msg}") command="NOTICE",
fail_msg="No NOTICE echo received after sending one out",
)
self.assertIn(
"label",
m.tags,
m,
fail_msg="When sending a NOTICE with a label, the echo'd message didn't contain the label at all: {msg}",
)
self.assertEqual(
m.tags["label"],
"12345",
m,
fail_msg="Echo'd NOTICE to a client did not contain the same label we sent it with(should be '12345'): {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testLabeledNoticeResponsesToChannel(self): def testLabeledNoticeResponsesToChannel(self):
self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"foo",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(1) self.getMessages(1)
self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"bar",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(2) self.getMessages(2)
# join channels # join channels
self.sendLine(1, 'JOIN #test') self.sendLine(1, "JOIN #test")
self.getMessages(1) self.getMessages(1)
self.sendLine(2, 'JOIN #test') self.sendLine(2, "JOIN #test")
self.getMessages(2) self.getMessages(2)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, '@label=12345;+draft/reply=123;+draft/react=l😃l NOTICE #test :hi') self.sendLine(
1, "@label=12345;+draft/reply=123;+draft/react=l😃l NOTICE #test :hi"
)
ms = self.getMessage(1) ms = self.getMessage(1)
mt = self.getMessage(2) mt = self.getMessage(2)
# ensure the label isn't sent to recipient # ensure the label isn't sent to recipient
self.assertMessageEqual(mt, command='NOTICE', fail_msg='No NOTICE received by the target after sending one out') self.assertMessageEqual(
self.assertNotIn('label', mt.tags, mt, fail_msg="When sending a NOTICE with a label, the target user shouldn't receive the label (only the sending user should): {msg}") mt,
command="NOTICE",
fail_msg="No NOTICE received by the target after sending one out",
)
self.assertNotIn(
"label",
mt.tags,
mt,
fail_msg="When sending a NOTICE with a label, the target user shouldn't receive the label (only the sending user should): {msg}",
)
# ensure sender correctly receives msg # ensure sender correctly receives msg
self.assertMessageEqual(ms, command='NOTICE', fail_msg="Got a message back that wasn't a NOTICE") self.assertMessageEqual(
self.assertIn('label', ms.tags, ms, fail_msg="When sending a NOTICE with a label, the source user should receive the label but didn't: {msg}") ms, command="NOTICE", fail_msg="Got a message back that wasn't a NOTICE"
self.assertEqual(ms.tags['label'], '12345', ms, fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}") )
self.assertIn(
"label",
ms.tags,
ms,
fail_msg="When sending a NOTICE with a label, the source user should receive the label but didn't: {msg}",
)
self.assertEqual(
ms.tags["label"],
"12345",
ms,
fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testLabeledNoticeResponsesToSelf(self): def testLabeledNoticeResponsesToSelf(self):
self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response'], skip_if_cap_nak=True) self.connectClient(
"foo",
capabilities=["batch", "echo-message", "labeled-response"],
skip_if_cap_nak=True,
)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, '@label=12345 NOTICE foo :hi') self.sendLine(1, "@label=12345 NOTICE foo :hi")
m1 = self.getMessage(1) m1 = self.getMessage(1)
m2 = self.getMessage(1) m2 = self.getMessage(1)
number_of_labels = 0 number_of_labels = 0
for m in [m1, m2]: for m in [m1, m2]:
self.assertMessageEqual(m, command='NOTICE', fail_msg="Got a message back that wasn't a NOTICE") self.assertMessageEqual(
if 'label' in m.tags: m, command="NOTICE", fail_msg="Got a message back that wasn't a NOTICE"
)
if "label" in m.tags:
number_of_labels += 1 number_of_labels += 1
self.assertEqual(m.tags['label'], '12345', m, fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}") self.assertEqual(
m.tags["label"],
self.assertEqual(number_of_labels, 1, m1, fail_msg="When sending a NOTICE to self with echo-message, we only expect one message to contain the label. Instead, {} messages had the label".format(number_of_labels)) "12345",
m,
fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') self.assertEqual(
number_of_labels,
1,
m1,
fail_msg="When sending a NOTICE to self with echo-message, we only expect one message to contain the label. Instead, {} messages had the label".format(
number_of_labels
),
)
@cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testLabeledTagMsgResponsesToClient(self): def testLabeledTagMsgResponsesToClient(self):
self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response', 'message-tags'], skip_if_cap_nak=True) self.connectClient(
"foo",
capabilities=["batch", "echo-message", "labeled-response", "message-tags"],
skip_if_cap_nak=True,
)
self.getMessages(1) self.getMessages(1)
self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response', 'message-tags'], skip_if_cap_nak=True) self.connectClient(
"bar",
capabilities=["batch", "echo-message", "labeled-response", "message-tags"],
skip_if_cap_nak=True,
)
self.getMessages(2) self.getMessages(2)
self.sendLine(1, '@label=12345;+draft/reply=123;+draft/react=l😃l TAGMSG bar') self.sendLine(1, "@label=12345;+draft/reply=123;+draft/react=l😃l TAGMSG bar")
m = self.getMessage(1) m = self.getMessage(1)
m2 = self.getMessage(2) m2 = self.getMessage(2)
# ensure the label isn't sent to recipient # ensure the label isn't sent to recipient
self.assertMessageEqual(m2, command='TAGMSG', fail_msg='No TAGMSG received by the target after sending one out') self.assertMessageEqual(
self.assertNotIn('label', m2.tags, m2, fail_msg="When sending a TAGMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}") m2,
self.assertIn('+draft/reply', m2.tags, m2, fail_msg="Reply tag wasn't present on the target user's TAGMSG: {msg}") command="TAGMSG",
self.assertEqual(m2.tags['+draft/reply'], '123', m2, fail_msg="Reply tag wasn't the same on the target user's TAGMSG: {msg}") fail_msg="No TAGMSG received by the target after sending one out",
self.assertIn('+draft/react', m2.tags, m2, fail_msg="React tag wasn't present on the target user's TAGMSG: {msg}") )
self.assertEqual(m2.tags['+draft/react'], 'l😃l', m2, fail_msg="React tag wasn't the same on the target user's TAGMSG: {msg}") self.assertNotIn(
"label",
m2.tags,
m2,
fail_msg="When sending a TAGMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}",
)
self.assertIn(
"+draft/reply",
m2.tags,
m2,
fail_msg="Reply tag wasn't present on the target user's TAGMSG: {msg}",
)
self.assertEqual(
m2.tags["+draft/reply"],
"123",
m2,
fail_msg="Reply tag wasn't the same on the target user's TAGMSG: {msg}",
)
self.assertIn(
"+draft/react",
m2.tags,
m2,
fail_msg="React tag wasn't present on the target user's TAGMSG: {msg}",
)
self.assertEqual(
m2.tags["+draft/react"],
"l😃l",
m2,
fail_msg="React tag wasn't the same on the target user's TAGMSG: {msg}",
)
self.assertMessageEqual(m, command='TAGMSG', fail_msg='No TAGMSG echo received after sending one out') self.assertMessageEqual(
self.assertIn('label', m.tags, m, fail_msg="When sending a TAGMSG with a label, the echo'd message didn't contain the label at all: {msg}") m,
self.assertEqual(m.tags['label'], '12345', m, fail_msg="Echo'd TAGMSG to a client did not contain the same label we sent it with(should be '12345'): {msg}") command="TAGMSG",
self.assertIn('+draft/reply', m.tags, m, fail_msg="Reply tag wasn't present on the source user's TAGMSG: {msg}") fail_msg="No TAGMSG echo received after sending one out",
self.assertEqual(m2.tags['+draft/reply'], '123', m, fail_msg="Reply tag wasn't the same on the source user's TAGMSG: {msg}") )
self.assertIn('+draft/react', m.tags, m, fail_msg="React tag wasn't present on the source user's TAGMSG: {msg}") self.assertIn(
self.assertEqual(m2.tags['+draft/react'], 'l😃l', m, fail_msg="React tag wasn't the same on the source user's TAGMSG: {msg}") "label",
m.tags,
m,
fail_msg="When sending a TAGMSG with a label, the echo'd message didn't contain the label at all: {msg}",
)
self.assertEqual(
m.tags["label"],
"12345",
m,
fail_msg="Echo'd TAGMSG to a client did not contain the same label we sent it with(should be '12345'): {msg}",
)
self.assertIn(
"+draft/reply",
m.tags,
m,
fail_msg="Reply tag wasn't present on the source user's TAGMSG: {msg}",
)
self.assertEqual(
m2.tags["+draft/reply"],
"123",
m,
fail_msg="Reply tag wasn't the same on the source user's TAGMSG: {msg}",
)
self.assertIn(
"+draft/react",
m.tags,
m,
fail_msg="React tag wasn't present on the source user's TAGMSG: {msg}",
)
self.assertEqual(
m2.tags["+draft/react"],
"l😃l",
m,
fail_msg="React tag wasn't the same on the source user's TAGMSG: {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testLabeledTagMsgResponsesToChannel(self): def testLabeledTagMsgResponsesToChannel(self):
self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response', 'message-tags'], skip_if_cap_nak=True) self.connectClient(
"foo",
capabilities=["batch", "echo-message", "labeled-response", "message-tags"],
skip_if_cap_nak=True,
)
self.getMessages(1) self.getMessages(1)
self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response', 'message-tags'], skip_if_cap_nak=True) self.connectClient(
"bar",
capabilities=["batch", "echo-message", "labeled-response", "message-tags"],
skip_if_cap_nak=True,
)
self.getMessages(2) self.getMessages(2)
# join channels # join channels
self.sendLine(1, 'JOIN #test') self.sendLine(1, "JOIN #test")
self.getMessages(1) self.getMessages(1)
self.sendLine(2, 'JOIN #test') self.sendLine(2, "JOIN #test")
self.getMessages(2) self.getMessages(2)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, '@label=12345;+draft/reply=123;+draft/react=l😃l TAGMSG #test') self.sendLine(1, "@label=12345;+draft/reply=123;+draft/react=l😃l TAGMSG #test")
ms = self.getMessage(1) ms = self.getMessage(1)
mt = self.getMessage(2) mt = self.getMessage(2)
# ensure the label isn't sent to recipient # ensure the label isn't sent to recipient
self.assertMessageEqual(mt, command='TAGMSG', fail_msg='No TAGMSG received by the target after sending one out') self.assertMessageEqual(
self.assertNotIn('label', mt.tags, mt, fail_msg="When sending a TAGMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}") mt,
command="TAGMSG",
fail_msg="No TAGMSG received by the target after sending one out",
)
self.assertNotIn(
"label",
mt.tags,
mt,
fail_msg="When sending a TAGMSG with a label, the target user shouldn't receive the label (only the sending user should): {msg}",
)
# ensure sender correctly receives msg # ensure sender correctly receives msg
self.assertMessageEqual(ms, command='TAGMSG', fail_msg="Got a message back that wasn't a TAGMSG") self.assertMessageEqual(
self.assertIn('label', ms.tags, ms, fail_msg="When sending a TAGMSG with a label, the source user should receive the label but didn't: {msg}") ms, command="TAGMSG", fail_msg="Got a message back that wasn't a TAGMSG"
self.assertEqual(ms.tags['label'], '12345', ms, fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}") )
self.assertIn(
"label",
ms.tags,
ms,
fail_msg="When sending a TAGMSG with a label, the source user should receive the label but didn't: {msg}",
)
self.assertEqual(
ms.tags["label"],
"12345",
ms,
fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testLabeledTagMsgResponsesToSelf(self): def testLabeledTagMsgResponsesToSelf(self):
self.connectClient('foo', capabilities=['batch', 'echo-message', 'labeled-response', 'message-tags'], skip_if_cap_nak=True) self.connectClient(
"foo",
capabilities=["batch", "echo-message", "labeled-response", "message-tags"],
skip_if_cap_nak=True,
)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, '@label=12345;+draft/reply=123;+draft/react=l😃l TAGMSG foo') self.sendLine(1, "@label=12345;+draft/reply=123;+draft/react=l😃l TAGMSG foo")
m1 = self.getMessage(1) m1 = self.getMessage(1)
m2 = self.getMessage(1) m2 = self.getMessage(1)
number_of_labels = 0 number_of_labels = 0
for m in [m1, m2]: for m in [m1, m2]:
self.assertMessageEqual(m, command='TAGMSG', fail_msg="Got a message back that wasn't a TAGMSG") self.assertMessageEqual(
if 'label' in m.tags: m, command="TAGMSG", fail_msg="Got a message back that wasn't a TAGMSG"
)
if "label" in m.tags:
number_of_labels += 1 number_of_labels += 1
self.assertEqual(m.tags['label'], '12345', m, fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}") self.assertEqual(
m.tags["label"],
"12345",
m,
fail_msg="Echo'd label doesn't match the label we sent (should be '12345'): {msg}",
)
self.assertEqual(number_of_labels, 1, m1, fail_msg="When sending a TAGMSG to self with echo-message, we only expect one message to contain the label. Instead, {} messages had the label".format(number_of_labels)) self.assertEqual(
number_of_labels,
1,
m1,
fail_msg="When sending a TAGMSG to self with echo-message, we only expect one message to contain the label. Instead, {} messages had the label".format(
number_of_labels
),
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testBatchedJoinMessages(self): def testBatchedJoinMessages(self):
self.connectClient('bar', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time'], skip_if_cap_nak=True) self.connectClient(
"bar",
capabilities=["batch", "labeled-response", "message-tags", "server-time"],
skip_if_cap_nak=True,
)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, '@label=12345 JOIN #xyz') self.sendLine(1, "@label=12345 JOIN #xyz")
m = self.getMessages(1) m = self.getMessages(1)
# we expect at least join and names lines, which must be batched # we expect at least join and names lines, which must be batched
@ -247,45 +577,57 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper
# valid BATCH start line: # valid BATCH start line:
batch_start = m[0] batch_start = m[0]
self.assertMessageEqual(batch_start, command='BATCH') self.assertMessageEqual(batch_start, command="BATCH")
self.assertEqual(len(batch_start.params), 2) self.assertEqual(len(batch_start.params), 2)
self.assertTrue(batch_start.params[0].startswith('+'), 'batch start param must begin with +, got %s' % (batch_start.params[0],)) self.assertTrue(
batch_start.params[0].startswith("+"),
"batch start param must begin with +, got %s" % (batch_start.params[0],),
)
batch_id = batch_start.params[0][1:] batch_id = batch_start.params[0][1:]
# batch id MUST be alphanumerics and hyphens # batch id MUST be alphanumerics and hyphens
self.assertTrue(re.match(r'^[A-Za-z0-9\-]+$', batch_id) is not None, 'batch id must be alphanumerics and hyphens, got %r' % (batch_id,)) self.assertTrue(
self.assertEqual(batch_start.params[1], 'labeled-response') re.match(r"^[A-Za-z0-9\-]+$", batch_id) is not None,
self.assertEqual(batch_start.tags.get('label'), '12345') "batch id must be alphanumerics and hyphens, got %r" % (batch_id,),
)
self.assertEqual(batch_start.params[1], "labeled-response")
self.assertEqual(batch_start.tags.get("label"), "12345")
# valid BATCH end line # valid BATCH end line
batch_end = m[-1] batch_end = m[-1]
self.assertMessageEqual(batch_end, command='BATCH', params=['-' + batch_id]) self.assertMessageEqual(batch_end, command="BATCH", params=["-" + batch_id])
# messages must have the BATCH tag # messages must have the BATCH tag
for message in m[1:-1]: for message in m[1:-1]:
self.assertEqual(message.tags.get('batch'), batch_id) self.assertEqual(message.tags.get("batch"), batch_id)
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testNoBatchForSingleMessage(self): def testNoBatchForSingleMessage(self):
self.connectClient('bar', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time']) self.connectClient(
"bar",
capabilities=["batch", "labeled-response", "message-tags", "server-time"],
)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, '@label=98765 PING adhoctestline') self.sendLine(1, "@label=98765 PING adhoctestline")
# no BATCH should be initiated for a one-line response, it should just be labeled # no BATCH should be initiated for a one-line response, it should just be labeled
ms = self.getMessages(1) ms = self.getMessages(1)
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
m = ms[0] m = ms[0]
self.assertEqual(m.command, 'PONG') self.assertEqual(m.command, "PONG")
self.assertEqual(m.params[-1], 'adhoctestline') self.assertEqual(m.params[-1], "adhoctestline")
# check the label # check the label
self.assertEqual(m.tags.get('label'), '98765') self.assertEqual(m.tags.get("label"), "98765")
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testEmptyBatchForNoResponse(self): def testEmptyBatchForNoResponse(self):
self.connectClient('bar', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time']) self.connectClient(
"bar",
capabilities=["batch", "labeled-response", "message-tags", "server-time"],
)
self.getMessages(1) self.getMessages(1)
# PONG never receives a response # PONG never receives a response
self.sendLine(1, '@label=98765 PONG adhoctestline') self.sendLine(1, "@label=98765 PONG adhoctestline")
# labeled-response: "Servers MUST respond with a labeled # labeled-response: "Servers MUST respond with a labeled
# `ACK` message when a client sends a labeled command that normally # `ACK` message when a client sends a labeled command that normally
@ -294,5 +636,5 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
ack = ms[0] ack = ms[0]
self.assertEqual(ack.command, 'ACK') self.assertEqual(ack.command, "ACK")
self.assertEqual(ack.tags.get('label'), '98765') self.assertEqual(ack.tags.get("label"), "98765")

View File

@ -6,142 +6,143 @@ from irctest import cases
from irctest.irc_utils.message_parser import parse_message from irctest.irc_utils.message_parser import parse_message
from irctest.numerics import ERR_INPUTTOOLONG from irctest.numerics import ERR_INPUTTOOLONG
class MessageTagsTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@cases.SpecificationSelector.requiredBySpecification('message-tags') class MessageTagsTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@cases.SpecificationSelector.requiredBySpecification("message-tags")
def testBasic(self): def testBasic(self):
def getAllMessages(): def getAllMessages():
for name in ['alice', 'bob', 'carol', 'dave']: for name in ["alice", "bob", "carol", "dave"]:
self.getMessages(name) self.getMessages(name)
def assertNoTags(line): def assertNoTags(line):
# tags start with '@', without tags we start with the prefix, # tags start with '@', without tags we start with the prefix,
# which begins with ':' # which begins with ':'
self.assertEqual(line[0], ':') self.assertEqual(line[0], ":")
msg = parse_message(line) msg = parse_message(line)
self.assertEqual(msg.tags, {}) self.assertEqual(msg.tags, {})
return msg return msg
self.connectClient( self.connectClient(
'alice', "alice", name="alice", capabilities=["message-tags"], skip_if_cap_nak=True
name='alice',
capabilities=['message-tags'],
skip_if_cap_nak=True
) )
self.joinChannel('alice', '#test') self.joinChannel("alice", "#test")
self.connectClient('bob', name='bob', capabilities=['message-tags', 'echo-message']) self.connectClient(
self.joinChannel('bob', '#test') "bob", name="bob", capabilities=["message-tags", "echo-message"]
self.connectClient('carol', name='carol') )
self.joinChannel('carol', '#test') self.joinChannel("bob", "#test")
self.connectClient('dave', name='dave', capabilities=['server-time']) self.connectClient("carol", name="carol")
self.joinChannel('dave', '#test') self.joinChannel("carol", "#test")
self.connectClient("dave", name="dave", capabilities=["server-time"])
self.joinChannel("dave", "#test")
getAllMessages() getAllMessages()
self.sendLine('alice', '@+baz=bat;fizz=buzz PRIVMSG #test hi') self.sendLine("alice", "@+baz=bat;fizz=buzz PRIVMSG #test hi")
self.getMessages('alice') self.getMessages("alice")
bob_msg = self.getMessage('bob') bob_msg = self.getMessage("bob")
carol_line = self.getMessage('carol', raw=True) carol_line = self.getMessage("carol", raw=True)
self.assertMessageEqual(bob_msg, command='PRIVMSG', params=['#test', 'hi']) self.assertMessageEqual(bob_msg, command="PRIVMSG", params=["#test", "hi"])
self.assertEqual(bob_msg.tags['+baz'], "bat") self.assertEqual(bob_msg.tags["+baz"], "bat")
self.assertIn('msgid', bob_msg.tags) self.assertIn("msgid", bob_msg.tags)
# should not relay a non-client-only tag # should not relay a non-client-only tag
self.assertNotIn('fizz', bob_msg.tags) self.assertNotIn("fizz", bob_msg.tags)
# carol MUST NOT receive tags # carol MUST NOT receive tags
carol_msg = assertNoTags(carol_line) carol_msg = assertNoTags(carol_line)
self.assertMessageEqual(carol_msg, command='PRIVMSG', params=['#test', 'hi']) self.assertMessageEqual(carol_msg, command="PRIVMSG", params=["#test", "hi"])
# dave SHOULD receive server-time tag # dave SHOULD receive server-time tag
dave_msg = self.getMessage('dave') dave_msg = self.getMessage("dave")
self.assertIn('time', dave_msg.tags) self.assertIn("time", dave_msg.tags)
# dave MUST NOT receive client-only tags # dave MUST NOT receive client-only tags
self.assertNotIn('+baz', dave_msg.tags) self.assertNotIn("+baz", dave_msg.tags)
getAllMessages() getAllMessages()
self.sendLine('bob', '@+bat=baz;+fizz=buzz PRIVMSG #test :hi yourself') self.sendLine("bob", "@+bat=baz;+fizz=buzz PRIVMSG #test :hi yourself")
bob_msg = self.getMessage('bob') # bob has echo-message bob_msg = self.getMessage("bob") # bob has echo-message
alice_msg = self.getMessage('alice') alice_msg = self.getMessage("alice")
carol_line = self.getMessage('carol', raw=True) carol_line = self.getMessage("carol", raw=True)
carol_msg = assertNoTags(carol_line) carol_msg = assertNoTags(carol_line)
for msg in [alice_msg, bob_msg, carol_msg]: for msg in [alice_msg, bob_msg, carol_msg]:
self.assertMessageEqual(msg, command='PRIVMSG', params=['#test', 'hi yourself']) self.assertMessageEqual(
msg, command="PRIVMSG", params=["#test", "hi yourself"]
)
for msg in [alice_msg, bob_msg]: for msg in [alice_msg, bob_msg]:
self.assertEqual(msg.tags['+bat'], 'baz') self.assertEqual(msg.tags["+bat"], "baz")
self.assertEqual(msg.tags['+fizz'], 'buzz') self.assertEqual(msg.tags["+fizz"], "buzz")
self.assertTrue(alice_msg.tags['msgid']) self.assertTrue(alice_msg.tags["msgid"])
self.assertEqual(alice_msg.tags['msgid'], bob_msg.tags['msgid']) self.assertEqual(alice_msg.tags["msgid"], bob_msg.tags["msgid"])
getAllMessages() getAllMessages()
# test TAGMSG and basic escaping # test TAGMSG and basic escaping
self.sendLine('bob', '@+buzz=fizz\:buzz;cat=dog;+steel=wootz TAGMSG #test') self.sendLine("bob", "@+buzz=fizz\:buzz;cat=dog;+steel=wootz TAGMSG #test")
bob_msg = self.getMessage('bob') # bob has echo-message bob_msg = self.getMessage("bob") # bob has echo-message
alice_msg = self.getMessage('alice') alice_msg = self.getMessage("alice")
# carol MUST NOT receive TAGMSG at all # carol MUST NOT receive TAGMSG at all
self.assertEqual(self.getMessages('carol'), []) self.assertEqual(self.getMessages("carol"), [])
# dave MUST NOT receive TAGMSG either, despite having server-time # dave MUST NOT receive TAGMSG either, despite having server-time
self.assertEqual(self.getMessages('dave'), []) self.assertEqual(self.getMessages("dave"), [])
for msg in [alice_msg, bob_msg]: for msg in [alice_msg, bob_msg]:
self.assertMessageEqual(alice_msg, command='TAGMSG', params=['#test']) self.assertMessageEqual(alice_msg, command="TAGMSG", params=["#test"])
self.assertEqual(msg.tags['+buzz'], 'fizz;buzz') self.assertEqual(msg.tags["+buzz"], "fizz;buzz")
self.assertEqual(msg.tags['+steel'], 'wootz') self.assertEqual(msg.tags["+steel"], "wootz")
self.assertNotIn('cat', msg.tags) self.assertNotIn("cat", msg.tags)
self.assertTrue(alice_msg.tags['msgid']) self.assertTrue(alice_msg.tags["msgid"])
self.assertEqual(alice_msg.tags['msgid'], bob_msg.tags['msgid']) self.assertEqual(alice_msg.tags["msgid"], bob_msg.tags["msgid"])
@cases.SpecificationSelector.requiredBySpecification('message-tags') @cases.SpecificationSelector.requiredBySpecification("message-tags")
def testLengthLimits(self): def testLengthLimits(self):
self.connectClient( self.connectClient(
'alice', "alice",
name='alice', name="alice",
capabilities=['message-tags', 'echo-message'], capabilities=["message-tags", "echo-message"],
skip_if_cap_nak=True skip_if_cap_nak=True,
) )
self.joinChannel('alice', '#test') self.joinChannel("alice", "#test")
self.connectClient('bob', name='bob', capabilities=['message-tags']) self.connectClient("bob", name="bob", capabilities=["message-tags"])
self.joinChannel('bob', '#test') self.joinChannel("bob", "#test")
self.getMessages('alice') self.getMessages("alice")
self.getMessages('bob') self.getMessages("bob")
# this is right at the limit of 4094 bytes of tag data, # this is right at the limit of 4094 bytes of tag data,
# 4096 bytes of tag section (including the starting '@' and the final ' ') # 4096 bytes of tag section (including the starting '@' and the final ' ')
max_tagmsg = '@foo=bar;+baz=%s TAGMSG #test' % ('a' * 4081,) max_tagmsg = "@foo=bar;+baz=%s TAGMSG #test" % ("a" * 4081,)
self.assertEqual(max_tagmsg.index('TAGMSG'), 4096) self.assertEqual(max_tagmsg.index("TAGMSG"), 4096)
self.sendLine('alice', max_tagmsg) self.sendLine("alice", max_tagmsg)
echo = self.getMessage('alice') echo = self.getMessage("alice")
relay = self.getMessage('bob') relay = self.getMessage("bob")
self.assertMessageEqual(echo, command='TAGMSG', params=['#test']) self.assertMessageEqual(echo, command="TAGMSG", params=["#test"])
self.assertMessageEqual(relay, command='TAGMSG', params=['#test']) self.assertMessageEqual(relay, command="TAGMSG", params=["#test"])
self.assertNotEqual(echo.tags['msgid'], '') self.assertNotEqual(echo.tags["msgid"], "")
self.assertEqual(echo.tags['msgid'], relay.tags['msgid']) self.assertEqual(echo.tags["msgid"], relay.tags["msgid"])
self.assertEqual(echo.tags['+baz'], 'a' * 4081) self.assertEqual(echo.tags["+baz"], "a" * 4081)
self.assertEqual(relay.tags['+baz'], echo.tags['+baz']) self.assertEqual(relay.tags["+baz"], echo.tags["+baz"])
excess_tagmsg = '@foo=bar;+baz=%s TAGMSG #test' % ('a' * 4082,) excess_tagmsg = "@foo=bar;+baz=%s TAGMSG #test" % ("a" * 4082,)
self.assertEqual(excess_tagmsg.index('TAGMSG'), 4097) self.assertEqual(excess_tagmsg.index("TAGMSG"), 4097)
self.sendLine('alice', excess_tagmsg) self.sendLine("alice", excess_tagmsg)
reply = self.getMessage('alice') reply = self.getMessage("alice")
self.assertEqual(reply.command, ERR_INPUTTOOLONG) self.assertEqual(reply.command, ERR_INPUTTOOLONG)
self.assertEqual(self.getMessages('bob'), []) self.assertEqual(self.getMessages("bob"), [])
max_privmsg = '@foo=bar;+baz=%s PRIVMSG #test %s' % ('a' * 4081, 'b' * 496) max_privmsg = "@foo=bar;+baz=%s PRIVMSG #test %s" % ("a" * 4081, "b" * 496)
# irctest adds the '\r\n' for us, this is right at the limit # irctest adds the '\r\n' for us, this is right at the limit
self.assertEqual(len(max_privmsg), 4096 + (512 - 2)) self.assertEqual(len(max_privmsg), 4096 + (512 - 2))
self.sendLine('alice', max_privmsg) self.sendLine("alice", max_privmsg)
echo = self.getMessage('alice') echo = self.getMessage("alice")
relay = self.getMessage('bob') relay = self.getMessage("bob")
self.assertNotEqual(echo.tags['msgid'], '') self.assertNotEqual(echo.tags["msgid"], "")
self.assertEqual(echo.tags['msgid'], relay.tags['msgid']) self.assertEqual(echo.tags["msgid"], relay.tags["msgid"])
self.assertEqual(echo.tags['+baz'], 'a' * 4081) self.assertEqual(echo.tags["+baz"], "a" * 4081)
self.assertEqual(relay.tags['+baz'], echo.tags['+baz']) self.assertEqual(relay.tags["+baz"], echo.tags["+baz"])
# message may have been truncated # message may have been truncated
self.assertIn('b' * 400, echo.params[1]) self.assertIn("b" * 400, echo.params[1])
self.assertEqual(echo.params[1].rstrip('b'), '') self.assertEqual(echo.params[1].rstrip("b"), "")
self.assertIn('b' * 400, relay.params[1]) self.assertIn("b" * 400, relay.params[1])
self.assertEqual(relay.params[1].rstrip('b'), '') self.assertEqual(relay.params[1].rstrip("b"), "")
excess_privmsg = '@foo=bar;+baz=%s PRIVMSG #test %s' % ('a' * 4082, 'b' * 495) excess_privmsg = "@foo=bar;+baz=%s PRIVMSG #test %s" % ("a" * 4082, "b" * 495)
# TAGMSG data is over the limit, but we're within the overall limit for a line # TAGMSG data is over the limit, but we're within the overall limit for a line
self.assertEqual(excess_privmsg.index('PRIVMSG'), 4097) self.assertEqual(excess_privmsg.index("PRIVMSG"), 4097)
self.assertEqual(len(excess_privmsg), 4096 + (512 - 2)) self.assertEqual(len(excess_privmsg), 4096 + (512 - 2))
self.sendLine('alice', excess_privmsg) self.sendLine("alice", excess_privmsg)
reply = self.getMessage('alice') reply = self.getMessage("alice")
self.assertEqual(reply.command, ERR_INPUTTOOLONG) self.assertEqual(reply.command, ERR_INPUTTOOLONG)
self.assertEqual(self.getMessages('bob'), []) self.assertEqual(self.getMessages("bob"), [])

View File

@ -6,54 +6,52 @@ Section 3.2 of RFC 2812
from irctest import cases from irctest import cases
from irctest.numerics import ERR_INPUTTOOLONG from irctest.numerics import ERR_INPUTTOOLONG
class PrivmsgTestCase(cases.BaseServerTestCase): class PrivmsgTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812")
def testPrivmsg(self): def testPrivmsg(self):
"""<https://tools.ietf.org/html/rfc2812#section-3.3.1>""" """<https://tools.ietf.org/html/rfc2812#section-3.3.1>"""
self.connectClient('foo') self.connectClient("foo")
self.sendLine(1, 'JOIN #chan') self.sendLine(1, "JOIN #chan")
self.connectClient('bar') self.connectClient("bar")
self.sendLine(2, 'JOIN #chan') self.sendLine(2, "JOIN #chan")
self.getMessages(2) # synchronize self.getMessages(2) # synchronize
self.sendLine(1, 'PRIVMSG #chan :hello there') self.sendLine(1, "PRIVMSG #chan :hello there")
self.getMessages(1) # synchronize self.getMessages(1) # synchronize
pms = [msg for msg in self.getMessages(2) if msg.command == 'PRIVMSG'] pms = [msg for msg in self.getMessages(2) if msg.command == "PRIVMSG"]
self.assertEqual(len(pms), 1) self.assertEqual(len(pms), 1)
self.assertMessageEqual( self.assertMessageEqual(
pms[0], pms[0], command="PRIVMSG", params=["#chan", "hello there"]
command='PRIVMSG',
params=['#chan', 'hello there']
) )
@cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812")
def testPrivmsgNonexistentChannel(self): def testPrivmsgNonexistentChannel(self):
"""<https://tools.ietf.org/html/rfc2812#section-3.3.1>""" """<https://tools.ietf.org/html/rfc2812#section-3.3.1>"""
self.connectClient('foo') self.connectClient("foo")
self.sendLine(1, 'PRIVMSG #nonexistent :hello there') self.sendLine(1, "PRIVMSG #nonexistent :hello there")
msg = self.getMessage(1) msg = self.getMessage(1)
# ERR_NOSUCHNICK, ERR_NOSUCHCHANNEL, or ERR_CANNOTSENDTOCHAN # ERR_NOSUCHNICK, ERR_NOSUCHCHANNEL, or ERR_CANNOTSENDTOCHAN
self.assertIn(msg.command, ('401', '403', '404')) self.assertIn(msg.command, ("401", "403", "404"))
class NoticeTestCase(cases.BaseServerTestCase): class NoticeTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812")
def testNotice(self): def testNotice(self):
"""<https://tools.ietf.org/html/rfc2812#section-3.3.2>""" """<https://tools.ietf.org/html/rfc2812#section-3.3.2>"""
self.connectClient('foo') self.connectClient("foo")
self.sendLine(1, 'JOIN #chan') self.sendLine(1, "JOIN #chan")
self.connectClient('bar') self.connectClient("bar")
self.sendLine(2, 'JOIN #chan') self.sendLine(2, "JOIN #chan")
self.getMessages(2) # synchronize self.getMessages(2) # synchronize
self.sendLine(1, 'NOTICE #chan :hello there') self.sendLine(1, "NOTICE #chan :hello there")
self.getMessages(1) # synchronize self.getMessages(1) # synchronize
notices = [msg for msg in self.getMessages(2) if msg.command == 'NOTICE'] notices = [msg for msg in self.getMessages(2) if msg.command == "NOTICE"]
self.assertEqual(len(notices), 1) self.assertEqual(len(notices), 1)
self.assertMessageEqual( self.assertMessageEqual(
notices[0], notices[0], command="NOTICE", params=["#chan", "hello there"]
command='NOTICE',
params=['#chan', 'hello there']
) )
@cases.SpecificationSelector.requiredBySpecification('RFC1459', 'RFC2812') @cases.SpecificationSelector.requiredBySpecification("RFC1459", "RFC2812")
def testNoticeNonexistentChannel(self): def testNoticeNonexistentChannel(self):
""" """
'automatic replies MUST NEVER be sent in response to a NOTICE message. 'automatic replies MUST NEVER be sent in response to a NOTICE message.
@ -61,17 +59,17 @@ class NoticeTestCase(cases.BaseServerTestCase):
back to the client on receipt of a notice.' back to the client on receipt of a notice.'
https://tools.ietf.org/html/rfc2812#section-3.3.2> https://tools.ietf.org/html/rfc2812#section-3.3.2>
""" """
self.connectClient('foo') self.connectClient("foo")
self.sendLine(1, 'NOTICE #nonexistent :hello there') self.sendLine(1, "NOTICE #nonexistent :hello there")
self.assertEqual(self.getMessages(1), []) self.assertEqual(self.getMessages(1), [])
class TagsTestCase(cases.BaseServerTestCase): class TagsTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testLineTooLong(self): def testLineTooLong(self):
self.connectClient('bar') self.connectClient("bar")
self.joinChannel(1, '#xyz') self.joinChannel(1, "#xyz")
monsterMessage = '@+clientOnlyTagExample=' + 'a'*4096 + ' PRIVMSG #xyz hi!' monsterMessage = "@+clientOnlyTagExample=" + "a" * 4096 + " PRIVMSG #xyz hi!"
self.sendLine(1, monsterMessage) self.sendLine(1, monsterMessage)
replies = self.getMessages(1) replies = self.getMessages(1)
self.assertIn(ERR_INPUTTOOLONG, set(reply.command for reply in replies)) self.assertIn(ERR_INPUTTOOLONG, set(reply.command for reply in replies))

View File

@ -5,174 +5,244 @@ Tests METADATA features.
from irctest import cases from irctest import cases
class MetadataTestCase(cases.BaseServerTestCase): class MetadataTestCase(cases.BaseServerTestCase):
valid_metadata_keys = {'valid_key1', 'valid_key2'} valid_metadata_keys = {"valid_key1", "valid_key2"}
invalid_metadata_keys = {'invalid_key1', 'invalid_key2'} invalid_metadata_keys = {"invalid_key1", "invalid_key2"}
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated')
@cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated")
def testInIsupport(self): def testInIsupport(self):
"""“If METADATA is supported, it MUST be specified in RPL_ISUPPORT """“If METADATA is supported, it MUST be specified in RPL_ISUPPORT
using the METADATA key. using the METADATA key.
-- <http://ircv3.net/specs/core/metadata-3.2.html> -- <http://ircv3.net/specs/core/metadata-3.2.html>
""" """
self.addClient() self.addClient()
self.sendLine(1, 'CAP LS 302') self.sendLine(1, "CAP LS 302")
self.getCapLs(1) self.getCapLs(1)
self.sendLine(1, 'USER foo foo foo :foo') self.sendLine(1, "USER foo foo foo :foo")
self.sendLine(1, 'NICK foo') self.sendLine(1, "NICK foo")
self.sendLine(1, 'CAP END') self.sendLine(1, "CAP END")
self.skipToWelcome(1) self.skipToWelcome(1)
m = self.getMessage(1) m = self.getMessage(1)
while m.command != '005': # RPL_ISUPPORT while m.command != "005": # RPL_ISUPPORT
m = self.getMessage(1) m = self.getMessage(1)
self.assertIn('METADATA', {x.split('=')[0] for x in m.params[1:-1]}, self.assertIn(
fail_msg='{item} missing from RPL_ISUPPORT') "METADATA",
{x.split("=")[0] for x in m.params[1:-1]},
fail_msg="{item} missing from RPL_ISUPPORT",
)
self.getMessages(1) self.getMessages(1)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated")
def testGetOneUnsetValid(self): def testGetOneUnsetValid(self):
"""<http://ircv3.net/specs/core/metadata-3.2.html#metadata-get> """<http://ircv3.net/specs/core/metadata-3.2.html#metadata-get>"""
""" self.connectClient("foo")
self.connectClient('foo') self.sendLine(1, "METADATA * GET valid_key1")
self.sendLine(1, 'METADATA * GET valid_key1')
m = self.getMessage(1) m = self.getMessage(1)
self.assertMessageEqual(m, command='766', # ERR_NOMATCHINGKEY self.assertMessageEqual(
fail_msg='Did not reply with 766 (ERR_NOMATCHINGKEY) to a ' m,
'request to an unset valid METADATA key.') command="766", # ERR_NOMATCHINGKEY
fail_msg="Did not reply with 766 (ERR_NOMATCHINGKEY) to a "
"request to an unset valid METADATA key.",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated")
def testGetTwoUnsetValid(self): def testGetTwoUnsetValid(self):
"""“Multiple keys may be given. The response will be either RPL_KEYVALUE, """“Multiple keys may be given. The response will be either RPL_KEYVALUE,
ERR_KEYINVALID or ERR_NOMATCHINGKEY for every key in order. ERR_KEYINVALID or ERR_NOMATCHINGKEY for every key in order.
-- <http://ircv3.net/specs/core/metadata-3.2.html#metadata-get> -- <http://ircv3.net/specs/core/metadata-3.2.html#metadata-get>
""" """
self.connectClient('foo') self.connectClient("foo")
self.sendLine(1, 'METADATA * GET valid_key1 valid_key2') self.sendLine(1, "METADATA * GET valid_key1 valid_key2")
m = self.getMessage(1) m = self.getMessage(1)
self.assertMessageEqual(m, command='766', # ERR_NOMATCHINGKEY self.assertMessageEqual(
fail_msg='Did not reply with 766 (ERR_NOMATCHINGKEY) to a ' m,
'request to two unset valid METADATA key: {msg}') command="766", # ERR_NOMATCHINGKEY
self.assertEqual(m.params[1], 'valid_key1', m, fail_msg="Did not reply with 766 (ERR_NOMATCHINGKEY) to a "
fail_msg='Response to “METADATA * GET valid_key1 valid_key2” ' "request to two unset valid METADATA key: {msg}",
'did not respond to valid_key1 first: {msg}') )
self.assertEqual(
m.params[1],
"valid_key1",
m,
fail_msg="Response to “METADATA * GET valid_key1 valid_key2” "
"did not respond to valid_key1 first: {msg}",
)
m = self.getMessage(1) m = self.getMessage(1)
self.assertMessageEqual(m, command='766', # ERR_NOMATCHINGKEY self.assertMessageEqual(
fail_msg='Did not reply with two 766 (ERR_NOMATCHINGKEY) to a ' m,
'request to two unset valid METADATA key: {msg}') command="766", # ERR_NOMATCHINGKEY
self.assertEqual(m.params[1], 'valid_key2', m, fail_msg="Did not reply with two 766 (ERR_NOMATCHINGKEY) to a "
fail_msg='Response to “METADATA * GET valid_key1 valid_key2” ' "request to two unset valid METADATA key: {msg}",
'did not respond to valid_key2 as second response: {msg}') )
self.assertEqual(
m.params[1],
"valid_key2",
m,
fail_msg="Response to “METADATA * GET valid_key1 valid_key2” "
"did not respond to valid_key2 as second response: {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated")
def testListNoSet(self): def testListNoSet(self):
"""“This subcommand MUST list all currently-set metadata keys along """“This subcommand MUST list all currently-set metadata keys along
with their values. The response will be zero or more RPL_KEYVALUE with their values. The response will be zero or more RPL_KEYVALUE
events, following by RPL_METADATAEND event. events, following by RPL_METADATAEND event.
-- <http://ircv3.net/specs/core/metadata-3.2.html#metadata-list> -- <http://ircv3.net/specs/core/metadata-3.2.html#metadata-list>
""" """
self.connectClient('foo') self.connectClient("foo")
self.sendLine(1, 'METADATA * LIST') self.sendLine(1, "METADATA * LIST")
m = self.getMessage(1) m = self.getMessage(1)
self.assertMessageEqual(m, command='762', # RPL_METADATAEND self.assertMessageEqual(
fail_msg='Response to “METADATA * LIST” was not ' m,
'762 (RPL_METADATAEND) but: {msg}') command="762", # RPL_METADATAEND
fail_msg="Response to “METADATA * LIST” was not "
"762 (RPL_METADATAEND) but: {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated")
def testListInvalidTarget(self): def testListInvalidTarget(self):
"""“In case of invalid target RPL_METADATAEND MUST NOT be sent.” """“In case of invalid target RPL_METADATAEND MUST NOT be sent.”
-- <http://ircv3.net/specs/core/metadata-3.2.html#metadata-list> -- <http://ircv3.net/specs/core/metadata-3.2.html#metadata-list>
""" """
self.connectClient('foo') self.connectClient("foo")
self.sendLine(1, 'METADATA foobar LIST') self.sendLine(1, "METADATA foobar LIST")
m = self.getMessage(1) m = self.getMessage(1)
self.assertMessageEqual(m, command='765', # ERR_TARGETINVALID self.assertMessageEqual(
fail_msg='Response to “METADATA <invalid target> LIST” was ' m,
'not 765 (ERR_TARGETINVALID) but: {msg}') command="765", # ERR_TARGETINVALID
fail_msg="Response to “METADATA <invalid target> LIST” was "
"not 765 (ERR_TARGETINVALID) but: {msg}",
)
commands = {m.command for m in self.getMessages(1)} commands = {m.command for m in self.getMessages(1)}
self.assertNotIn('762', commands, self.assertNotIn(
fail_msg='Sent “METADATA <invalid target> LIST”, got 765 ' "762",
'(ERR_TARGETINVALID), and then 762 (RPL_METADATAEND)') commands,
fail_msg="Sent “METADATA <invalid target> LIST”, got 765 "
"(ERR_TARGETINVALID), and then 762 (RPL_METADATAEND)",
)
def assertSetValue(self, target, key, value, displayable_value=None): def assertSetValue(self, target, key, value, displayable_value=None):
if displayable_value is None: if displayable_value is None:
displayable_value = value displayable_value = value
self.sendLine(1, 'METADATA {} SET {} :{}'.format(target, key, value)) self.sendLine(1, "METADATA {} SET {} :{}".format(target, key, value))
m = self.getMessage(1) m = self.getMessage(1)
self.assertMessageEqual(m, command='761', # RPL_KEYVALUE self.assertMessageEqual(
fail_msg='Did not reply with 761 (RPL_KEYVALUE) to a valid ' m,
'“METADATA * SET {} :{}”: {msg}', command="761", # RPL_KEYVALUE
extra_format=(key, displayable_value,)) fail_msg="Did not reply with 761 (RPL_KEYVALUE) to a valid "
self.assertEqual(m.params[1], 'valid_key1', m, "“METADATA * SET {} :{}”: {msg}",
fail_msg='Second param of 761 after setting “{expects}” to ' extra_format=(
'{}” is not “{expects}”: {msg}.', key,
extra_format=(displayable_value,)) displayable_value,
self.assertEqual(m.params[3], value, m, ),
fail_msg='Fourth param of 761 after setting “{0}” to ' )
'{1}” is not “{1}”: {msg}.', self.assertEqual(
extra_format=(key, displayable_value)) m.params[1],
"valid_key1",
m,
fail_msg="Second param of 761 after setting “{expects}” to "
"{}” is not “{expects}”: {msg}.",
extra_format=(displayable_value,),
)
self.assertEqual(
m.params[3],
value,
m,
fail_msg="Fourth param of 761 after setting “{0}” to "
"{1}” is not “{1}”: {msg}.",
extra_format=(key, displayable_value),
)
m = self.getMessage(1) m = self.getMessage(1)
self.assertMessageEqual(m, command='762', # RPL_METADATAEND self.assertMessageEqual(
fail_msg='Did not send RPL_METADATAEND after setting ' m,
'a valid METADATA key.') command="762", # RPL_METADATAEND
fail_msg="Did not send RPL_METADATAEND after setting "
"a valid METADATA key.",
)
def assertGetValue(self, target, key, value, displayable_value=None): def assertGetValue(self, target, key, value, displayable_value=None):
self.sendLine(1, 'METADATA * GET {}'.format(key)) self.sendLine(1, "METADATA * GET {}".format(key))
m = self.getMessage(1) m = self.getMessage(1)
self.assertMessageEqual(m, command='761', # RPL_KEYVALUE self.assertMessageEqual(
fail_msg='Did not reply with 761 (RPL_KEYVALUE) to a valid ' m,
'“METADATA * GET” when the key is set is set: {msg}') command="761", # RPL_KEYVALUE
self.assertEqual(m.params[1], key, m, fail_msg="Did not reply with 761 (RPL_KEYVALUE) to a valid "
fail_msg='Second param of 761 after getting “{expects}' "“METADATA * GET” when the key is set is set: {msg}",
'(which is set) is not “{expects}”: {msg}.') )
self.assertEqual(m.params[3], value, m, self.assertEqual(
fail_msg='Fourth param of 761 after getting “{0}' m.params[1],
'(which is set to “{1}”) is not ”{1}”: {msg}.', key,
extra_format=(key, displayable_value)) m,
fail_msg="Second param of 761 after getting “{expects}"
"(which is set) is not “{expects}”: {msg}.",
)
self.assertEqual(
m.params[3],
value,
m,
fail_msg="Fourth param of 761 after getting “{0}"
"(which is set to “{1}”) is not ”{1}”: {msg}.",
extra_format=(key, displayable_value),
)
def assertSetGetValue(self, target, key, value, displayable_value=None): def assertSetGetValue(self, target, key, value, displayable_value=None):
self.assertSetValue(target, key, value, displayable_value) self.assertSetValue(target, key, value, displayable_value)
self.assertGetValue(target, key, value, displayable_value) self.assertGetValue(target, key, value, displayable_value)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated")
def testSetGetValid(self): def testSetGetValid(self):
"""<http://ircv3.net/specs/core/metadata-3.2.html> """<http://ircv3.net/specs/core/metadata-3.2.html>"""
""" self.connectClient("foo")
self.connectClient('foo') self.assertSetGetValue("*", "valid_key1", "myvalue")
self.assertSetGetValue('*', 'valid_key1', 'myvalue')
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated")
def testSetGetZeroCharInValue(self): def testSetGetZeroCharInValue(self):
"""“Values are unrestricted, except that they MUST be UTF-8.” """“Values are unrestricted, except that they MUST be UTF-8.”
-- <http://ircv3.net/specs/core/metadata-3.2.html#metadata-restrictions> -- <http://ircv3.net/specs/core/metadata-3.2.html#metadata-restrictions>
""" """
self.connectClient('foo') self.connectClient("foo")
self.assertSetGetValue('*', 'valid_key1', 'zero->\0<-zero', self.assertSetGetValue("*", "valid_key1", "zero->\0<-zero", "zero->\\0<-zero")
'zero->\\0<-zero')
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated")
def testSetGetHeartInValue(self): def testSetGetHeartInValue(self):
"""“Values are unrestricted, except that they MUST be UTF-8.” """“Values are unrestricted, except that they MUST be UTF-8.”
-- <http://ircv3.net/specs/core/metadata-3.2.html#metadata-restrictions> -- <http://ircv3.net/specs/core/metadata-3.2.html#metadata-restrictions>
""" """
heart = b'\xf0\x9f\x92\x9c'.decode() heart = b"\xf0\x9f\x92\x9c".decode()
self.connectClient('foo') self.connectClient("foo")
self.assertSetGetValue('*', 'valid_key1', '->{}<-'.format(heart), self.assertSetGetValue(
'zero->{}<-zero'.format(heart.encode())) "*",
"valid_key1",
"->{}<-".format(heart),
"zero->{}<-zero".format(heart.encode()),
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2-deprecated') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2-deprecated")
def testSetInvalidUtf8(self): def testSetInvalidUtf8(self):
"""“Values are unrestricted, except that they MUST be UTF-8.” """“Values are unrestricted, except that they MUST be UTF-8.”
-- <http://ircv3.net/specs/core/metadata-3.2.html#metadata-restrictions> -- <http://ircv3.net/specs/core/metadata-3.2.html#metadata-restrictions>
""" """
self.connectClient('foo') self.connectClient("foo")
# Sending directly because it is not valid UTF-8 so Python would # Sending directly because it is not valid UTF-8 so Python would
# not like it # not like it
self.clients[1].conn.sendall(b'METADATA * SET valid_key1 ' self.clients[1].conn.sendall(
b':invalid UTF-8 ->\xc3<-\r\n') b"METADATA * SET valid_key1 " b":invalid UTF-8 ->\xc3<-\r\n"
)
commands = {m.command for m in self.getMessages(1)} commands = {m.command for m in self.getMessages(1)}
self.assertNotIn('761', commands, # RPL_KEYVALUE self.assertNotIn(
fail_msg='Setting METADATA key to a value containing invalid ' "761",
'UTF-8 was answered with 761 (RPL_KEYVALUE)') commands, # RPL_KEYVALUE
self.clients[1].conn.sendall(b'METADATA * SET valid_key1 ' fail_msg="Setting METADATA key to a value containing invalid "
b':invalid UTF-8: \xc3\r\n') "UTF-8 was answered with 761 (RPL_KEYVALUE)",
)
self.clients[1].conn.sendall(
b"METADATA * SET valid_key1 " b":invalid UTF-8: \xc3\r\n"
)
commands = {m.command for m in self.getMessages(1)} commands = {m.command for m in self.getMessages(1)}
self.assertNotIn('761', commands, # RPL_KEYVALUE self.assertNotIn(
fail_msg='Setting METADATA key to a value containing invalid ' "761",
'UTF-8 was answered with 761 (RPL_KEYVALUE)') commands, # RPL_KEYVALUE
fail_msg="Setting METADATA key to a value containing invalid "
"UTF-8 was answered with 761 (RPL_KEYVALUE)",
)

View File

@ -5,106 +5,132 @@
from irctest import cases from irctest import cases
from irctest.client_mock import NoMessageException from irctest.client_mock import NoMessageException
from irctest.basecontrollers import NotImplementedByController from irctest.basecontrollers import NotImplementedByController
from irctest.numerics import RPL_MONLIST, RPL_ENDOFMONLIST, RPL_MONONLINE, RPL_MONOFFLINE from irctest.numerics import (
RPL_MONLIST,
RPL_ENDOFMONLIST,
RPL_MONONLINE,
RPL_MONOFFLINE,
)
class MonitorTestCase(cases.BaseServerTestCase): class MonitorTestCase(cases.BaseServerTestCase):
def check_server_support(self): def check_server_support(self):
if 'MONITOR' not in self.server_support: if "MONITOR" not in self.server_support:
raise NotImplementedByController('MONITOR') raise NotImplementedByController("MONITOR")
def assertMononline(self, client, nick, m=None): def assertMononline(self, client, nick, m=None):
if not m: if not m:
m = self.getMessage(client) m = self.getMessage(client)
self.assertMessageEqual(m, command='730', # RPL_MONONLINE self.assertMessageEqual(
fail_msg='Sent non-730 (RPL_MONONLINE) message after ' m,
'monitored nick “{}” connected: {msg}', command="730", # RPL_MONONLINE
extra_format=(nick,)) fail_msg="Sent non-730 (RPL_MONONLINE) message after "
self.assertEqual(len(m.params), 2, m, "monitored nick “{}” connected: {msg}",
fail_msg='Invalid number of params of RPL_MONONLINE: {msg}') extra_format=(nick,),
self.assertEqual(m.params[1].split('!')[0], 'bar', )
fail_msg='730 (RPL_MONONLINE) with bad target after “{}' self.assertEqual(
'connects: {msg}', len(m.params),
extra_format=(nick,)) 2,
m,
fail_msg="Invalid number of params of RPL_MONONLINE: {msg}",
)
self.assertEqual(
m.params[1].split("!")[0],
"bar",
fail_msg="730 (RPL_MONONLINE) with bad target after “{}"
"connects: {msg}",
extra_format=(nick,),
)
def assertMonoffline(self, client, nick, m=None): def assertMonoffline(self, client, nick, m=None):
if not m: if not m:
m = self.getMessage(client) m = self.getMessage(client)
self.assertMessageEqual(m, command='731', # RPL_MONOFFLINE self.assertMessageEqual(
fail_msg='Did not reply with 731 (RPL_MONOFFLINE) to ' m,
'“MONITOR + {}”, while “{}” is offline: {msg}', command="731", # RPL_MONOFFLINE
extra_format=(nick, nick)) fail_msg="Did not reply with 731 (RPL_MONOFFLINE) to "
self.assertEqual(len(m.params), 2, m, "“MONITOR + {}”, while “{}” is offline: {msg}",
fail_msg='Invalid number of params of RPL_MONOFFLINE: {msg}') extra_format=(nick, nick),
self.assertEqual(m.params[1].split('!')[0], 'bar', )
fail_msg='731 (RPL_MONOFFLINE) reply to “MONITOR + {}' self.assertEqual(
'with bad target: {msg}', len(m.params),
extra_format=(nick,)) 2,
m,
fail_msg="Invalid number of params of RPL_MONOFFLINE: {msg}",
)
self.assertEqual(
m.params[1].split("!")[0],
"bar",
fail_msg="731 (RPL_MONOFFLINE) reply to “MONITOR + {}"
"with bad target: {msg}",
extra_format=(nick,),
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testMonitorOneDisconnected(self): def testMonitorOneDisconnected(self):
"""“If any of the targets being added are online, the server will """“If any of the targets being added are online, the server will
generate RPL_MONONLINE numerics listing those targets that are generate RPL_MONONLINE numerics listing those targets that are
online. online.
-- <http://ircv3.net/specs/core/monitor-3.2.html#monitor--targettarget2> -- <http://ircv3.net/specs/core/monitor-3.2.html#monitor--targettarget2>
""" """
self.connectClient('foo') self.connectClient("foo")
self.check_server_support() self.check_server_support()
self.sendLine(1, 'MONITOR + bar') self.sendLine(1, "MONITOR + bar")
self.assertMonoffline(1, 'bar') self.assertMonoffline(1, "bar")
self.connectClient('bar') self.connectClient("bar")
self.assertMononline(1, 'bar') self.assertMononline(1, "bar")
self.sendLine(2, 'QUIT :bye') self.sendLine(2, "QUIT :bye")
try: try:
self.getMessages(2) self.getMessages(2)
except ConnectionResetError: except ConnectionResetError:
pass pass
self.assertMonoffline(1, 'bar') self.assertMonoffline(1, "bar")
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testMonitorOneConnection(self): def testMonitorOneConnection(self):
self.connectClient('foo') self.connectClient("foo")
self.check_server_support() self.check_server_support()
self.sendLine(1, 'MONITOR + bar') self.sendLine(1, "MONITOR + bar")
self.getMessages(1) self.getMessages(1)
self.connectClient('bar') self.connectClient("bar")
self.assertMononline(1, 'bar') self.assertMononline(1, "bar")
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testMonitorOneConnected(self): def testMonitorOneConnected(self):
"""“If any of the targets being added are offline, the server will """“If any of the targets being added are offline, the server will
generate RPL_MONOFFLINE numerics listing those targets that are generate RPL_MONOFFLINE numerics listing those targets that are
online. online.
-- <http://ircv3.net/specs/core/monitor-3.2.html#monitor--targettarget2> -- <http://ircv3.net/specs/core/monitor-3.2.html#monitor--targettarget2>
""" """
self.connectClient('foo') self.connectClient("foo")
self.check_server_support() self.check_server_support()
self.connectClient('bar') self.connectClient("bar")
self.sendLine(1, 'MONITOR + bar') self.sendLine(1, "MONITOR + bar")
self.assertMononline(1, 'bar') self.assertMononline(1, "bar")
self.sendLine(2, 'QUIT :bye') self.sendLine(2, "QUIT :bye")
try: try:
self.getMessages(2) self.getMessages(2)
except ConnectionResetError: except ConnectionResetError:
pass pass
self.assertMonoffline(1, 'bar') self.assertMonoffline(1, "bar")
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testMonitorOneConnectionWithQuit(self): def testMonitorOneConnectionWithQuit(self):
self.connectClient('foo') self.connectClient("foo")
self.check_server_support() self.check_server_support()
self.connectClient('bar') self.connectClient("bar")
self.sendLine(1, 'MONITOR + bar') self.sendLine(1, "MONITOR + bar")
self.assertMononline(1, 'bar') self.assertMononline(1, "bar")
self.sendLine(2, 'QUIT :bye') self.sendLine(2, "QUIT :bye")
try: try:
self.getMessages(2) self.getMessages(2)
except ConnectionResetError: except ConnectionResetError:
pass pass
self.assertMonoffline(1, 'bar') self.assertMonoffline(1, "bar")
self.connectClient('bar') self.connectClient("bar")
self.assertMononline(1, 'bar') self.assertMononline(1, "bar")
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testMonitorConnectedAndDisconnected(self): def testMonitorConnectedAndDisconnected(self):
"""“If any of the targets being added are online, the server will """“If any of the targets being added are online, the server will
generate RPL_MONONLINE numerics listing those targets that are generate RPL_MONONLINE numerics listing those targets that are
@ -115,52 +141,76 @@ class MonitorTestCase(cases.BaseServerTestCase):
online. online.
-- <http://ircv3.net/specs/core/monitor-3.2.html#monitor--targettarget2> -- <http://ircv3.net/specs/core/monitor-3.2.html#monitor--targettarget2>
""" """
self.connectClient('foo') self.connectClient("foo")
self.check_server_support() self.check_server_support()
self.connectClient('bar') self.connectClient("bar")
self.sendLine(1, 'MONITOR + bar,baz') self.sendLine(1, "MONITOR + bar,baz")
m1 = self.getMessage(1) m1 = self.getMessage(1)
m2 = self.getMessage(1) m2 = self.getMessage(1)
commands = {m1.command, m2.command} commands = {m1.command, m2.command}
self.assertEqual(commands, {'730', '731'}, self.assertEqual(
fail_msg='Did not send one 730 (RPL_MONONLINE) and one ' commands,
'731 (RPL_MONOFFLINE) after “MONITOR + bar,baz” when “bar” ' {"730", "731"},
'is online and “baz” is offline. Sent this instead: {}', fail_msg="Did not send one 730 (RPL_MONONLINE) and one "
extra_format=((m1, m2))) "731 (RPL_MONOFFLINE) after “MONITOR + bar,baz” when “bar” "
if m1.command == '731': "is online and “baz” is offline. Sent this instead: {}",
extra_format=((m1, m2)),
)
if m1.command == "731":
(m1, m2) = (m2, m1) (m1, m2) = (m2, m1)
self.assertEqual(len(m1.params), 2, m1, self.assertEqual(
fail_msg='Invalid number of params of RPL_MONONLINE: {msg}') len(m1.params),
self.assertEqual(len(m2.params), 2, m2, 2,
fail_msg='Invalid number of params of RPL_MONONLINE: {msg}') m1,
self.assertEqual(m1.params[1].split('!')[0], 'bar', m1, fail_msg="Invalid number of params of RPL_MONONLINE: {msg}",
fail_msg='730 (RPL_MONONLINE) with bad target after ' )
'“MONITOR + bar,baz” and “bar” is connected: {msg}') self.assertEqual(
self.assertEqual(m2.params[1].split('!')[0], 'baz', m2, len(m2.params),
fail_msg='731 (RPL_MONOFFLINE) with bad target after ' 2,
'“MONITOR + bar,baz” and “baz” is disconnected: {msg}') m2,
fail_msg="Invalid number of params of RPL_MONONLINE: {msg}",
)
self.assertEqual(
m1.params[1].split("!")[0],
"bar",
m1,
fail_msg="730 (RPL_MONONLINE) with bad target after "
"“MONITOR + bar,baz” and “bar” is connected: {msg}",
)
self.assertEqual(
m2.params[1].split("!")[0],
"baz",
m2,
fail_msg="731 (RPL_MONOFFLINE) with bad target after "
"“MONITOR + bar,baz” and “baz” is disconnected: {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testUnmonitor(self): def testUnmonitor(self):
self.connectClient('foo') self.connectClient("foo")
self.check_server_support() self.check_server_support()
self.sendLine(1, 'MONITOR + bar') self.sendLine(1, "MONITOR + bar")
self.getMessages(1) self.getMessages(1)
self.connectClient('bar') self.connectClient("bar")
self.assertMononline(1, 'bar') self.assertMononline(1, "bar")
self.sendLine(1, 'MONITOR - bar') self.sendLine(1, "MONITOR - bar")
self.assertEqual(self.getMessages(1), [], self.assertEqual(
fail_msg='Got messages after “MONITOR - bar”: {got}') self.getMessages(1),
self.sendLine(2, 'QUIT :bye') [],
fail_msg="Got messages after “MONITOR - bar”: {got}",
)
self.sendLine(2, "QUIT :bye")
try: try:
self.getMessages(2) self.getMessages(2)
except ConnectionResetError: except ConnectionResetError:
pass pass
self.assertEqual(self.getMessages(1), [], self.assertEqual(
fail_msg='Got messages after disconnection of unmonitored ' self.getMessages(1),
'nick: {got}') [],
fail_msg="Got messages after disconnection of unmonitored " "nick: {got}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testMonitorForbidsMasks(self): def testMonitorForbidsMasks(self):
"""“The MONITOR implementation also enhances user privacy by """“The MONITOR implementation also enhances user privacy by
disallowing subscription to hostmasks, allowing users to avoid disallowing subscription to hostmasks, allowing users to avoid
@ -171,27 +221,33 @@ class MonitorTestCase(cases.BaseServerTestCase):
by the IRC daemon. by the IRC daemon.
-- <http://ircv3.net/specs/core/monitor-3.2.html#monitor-command> -- <http://ircv3.net/specs/core/monitor-3.2.html#monitor-command>
""" """
self.connectClient('foo') self.connectClient("foo")
self.check_server_support() self.check_server_support()
self.sendLine(1, 'MONITOR + *!username@localhost') self.sendLine(1, "MONITOR + *!username@localhost")
self.sendLine(1, 'MONITOR + *!username@127.0.0.1') self.sendLine(1, "MONITOR + *!username@127.0.0.1")
try: try:
m = self.getMessage(1) m = self.getMessage(1)
self.assertNotEqual(m.command, '731', m, self.assertNotEqual(
fail_msg='Got 731 (RPL_MONOFFLINE) after adding a monitor ' m.command,
'on a mask: {msg}') "731",
m,
fail_msg="Got 731 (RPL_MONOFFLINE) after adding a monitor "
"on a mask: {msg}",
)
except NoMessageException: except NoMessageException:
pass pass
self.connectClient('bar') self.connectClient("bar")
try: try:
m = self.getMessage(1) m = self.getMessage(1)
except NoMessageException: except NoMessageException:
pass pass
else: else:
raise AssertionError('Got message after client whose MONITORing ' raise AssertionError(
'was requested via hostmask connected: {}'.format(m)) "Got message after client whose MONITORing "
"was requested via hostmask connected: {}".format(m)
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testTwoMonitoringOneRemove(self): def testTwoMonitoringOneRemove(self):
"""Tests the following scenario: """Tests the following scenario:
* foo MONITORs qux * foo MONITORs qux
@ -199,30 +255,30 @@ class MonitorTestCase(cases.BaseServerTestCase):
* bar unMONITORs qux * bar unMONITORs qux
* qux connects. * qux connects.
""" """
self.connectClient('foo') self.connectClient("foo")
self.check_server_support() self.check_server_support()
self.connectClient('bar') self.connectClient("bar")
self.sendLine(1, 'MONITOR + qux') self.sendLine(1, "MONITOR + qux")
self.sendLine(2, 'MONITOR + qux') self.sendLine(2, "MONITOR + qux")
self.getMessages(1) self.getMessages(1)
self.getMessages(2) self.getMessages(2)
self.sendLine(2, 'MONITOR - qux') self.sendLine(2, "MONITOR - qux")
l = self.getMessages(2) l = self.getMessages(2)
self.assertEqual(l, [], self.assertEqual(
fail_msg='Got response to “MONITOR -”: {}', l, [], fail_msg="Got response to “MONITOR -”: {}", extra_format=(l,)
extra_format=(l,)) )
self.connectClient('qux') self.connectClient("qux")
self.getMessages(3) self.getMessages(3)
l = self.getMessages(1) l = self.getMessages(1)
self.assertNotEqual(l, [], self.assertNotEqual(
fail_msg='Received no message after MONITORed client ' l, [], fail_msg="Received no message after MONITORed client " "connects."
'connects.') )
l = self.getMessages(2) l = self.getMessages(2)
self.assertEqual(l, [], self.assertEqual(
fail_msg='Got response to unmonitored client: {}', l, [], fail_msg="Got response to unmonitored client: {}", extra_format=(l,)
extra_format=(l,)) )
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testMonitorList(self): def testMonitorList(self):
def checkMonitorSubjects(messages, client_nick, expected_targets): def checkMonitorSubjects(messages, client_nick, expected_targets):
# collect all the RPL_MONLIST nicks into a set: # collect all the RPL_MONLIST nicks into a set:
@ -230,63 +286,81 @@ class MonitorTestCase(cases.BaseServerTestCase):
for message in messages: for message in messages:
if message.command == RPL_MONLIST: if message.command == RPL_MONLIST:
self.assertEqual(message.params[0], client_nick) self.assertEqual(message.params[0], client_nick)
result.update(message.params[1].split(',')) result.update(message.params[1].split(","))
# finally, RPL_ENDOFMONLIST should be sent # finally, RPL_ENDOFMONLIST should be sent
self.assertEqual(messages[-1].command, RPL_ENDOFMONLIST) self.assertEqual(messages[-1].command, RPL_ENDOFMONLIST)
self.assertEqual(messages[-1].params[0], client_nick) self.assertEqual(messages[-1].params[0], client_nick)
self.assertEqual(result, expected_targets) self.assertEqual(result, expected_targets)
self.connectClient('bar') self.connectClient("bar")
self.check_server_support() self.check_server_support()
self.sendLine(1, 'MONITOR L') self.sendLine(1, "MONITOR L")
checkMonitorSubjects(self.getMessages(1), 'bar', set()) checkMonitorSubjects(self.getMessages(1), "bar", set())
self.sendLine(1, 'MONITOR + qux') self.sendLine(1, "MONITOR + qux")
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'MONITOR L') self.sendLine(1, "MONITOR L")
checkMonitorSubjects(self.getMessages(1), 'bar', {'qux',}) checkMonitorSubjects(
self.getMessages(1),
"bar",
{
"qux",
},
)
self.sendLine(1, 'MONITOR + bazbat') self.sendLine(1, "MONITOR + bazbat")
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'MONITOR L') self.sendLine(1, "MONITOR L")
checkMonitorSubjects(self.getMessages(1), 'bar', {'qux', 'bazbat',}) checkMonitorSubjects(
self.getMessages(1),
"bar",
{
"qux",
"bazbat",
},
)
self.sendLine(1, 'MONITOR - qux') self.sendLine(1, "MONITOR - qux")
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'MONITOR L') self.sendLine(1, "MONITOR L")
checkMonitorSubjects(self.getMessages(1), 'bar', {'bazbat',}) checkMonitorSubjects(
self.getMessages(1),
"bar",
{
"bazbat",
},
)
@cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2')
def testNickChange(self): def testNickChange(self):
# see oragono issue #1076: nickname changes must trigger RPL_MONOFFLINE # see oragono issue #1076: nickname changes must trigger RPL_MONOFFLINE
self.connectClient('bar') self.connectClient("bar")
self.check_server_support() self.check_server_support()
self.sendLine(1, 'MONITOR + qux') self.sendLine(1, "MONITOR + qux")
self.getMessages(1) self.getMessages(1)
self.connectClient('baz') self.connectClient("baz")
self.getMessages(2) self.getMessages(2)
self.assertEqual(self.getMessages(1), []) self.assertEqual(self.getMessages(1), [])
self.sendLine(2, 'NICK qux') self.sendLine(2, "NICK qux")
self.getMessages(2) self.getMessages(2)
mononline = self.getMessages(1)[0] mononline = self.getMessages(1)[0]
self.assertEqual(mononline.command, RPL_MONONLINE) self.assertEqual(mononline.command, RPL_MONONLINE)
self.assertEqual(len(mononline.params), 2, mononline.params) self.assertEqual(len(mononline.params), 2, mononline.params)
self.assertIn(mononline.params[0], ('bar', '*')) self.assertIn(mononline.params[0], ("bar", "*"))
self.assertEqual(mononline.params[1].split('!')[0], 'qux') self.assertEqual(mononline.params[1].split("!")[0], "qux")
# no numerics for a case change # no numerics for a case change
self.sendLine(2, 'NICK QUX') self.sendLine(2, "NICK QUX")
self.getMessages(2) self.getMessages(2)
self.assertEqual(self.getMessages(1), []) self.assertEqual(self.getMessages(1), [])
self.sendLine(2, 'NICK bazbat') self.sendLine(2, "NICK bazbat")
self.getMessages(2) self.getMessages(2)
monoffline = self.getMessages(1)[0] monoffline = self.getMessages(1)[0]
# should get RPL_MONOFFLINE with the current unfolded nick # should get RPL_MONOFFLINE with the current unfolded nick
self.assertEqual(monoffline.command, RPL_MONOFFLINE) self.assertEqual(monoffline.command, RPL_MONOFFLINE)
self.assertEqual(len(monoffline.params), 2, monoffline.params) self.assertEqual(len(monoffline.params), 2, monoffline.params)
self.assertIn(monoffline.params[0], ('bar', '*')) self.assertIn(monoffline.params[0], ("bar", "*"))
self.assertEqual(monoffline.params[1].split('!')[0], 'QUX') self.assertEqual(monoffline.params[1].split("!")[0], "QUX")

View File

@ -5,8 +5,9 @@ Tests multi-prefix.
from irctest import cases from irctest import cases
class MultiPrefixTestCase(cases.BaseServerTestCase): class MultiPrefixTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
def testMultiPrefix(self): def testMultiPrefix(self):
"""“When requested, the multi-prefix client capability will cause the """“When requested, the multi-prefix client capability will cause the
IRC server to send all possible prefixes which apply to a user in NAMES IRC server to send all possible prefixes which apply to a user in NAMES
@ -14,19 +15,35 @@ class MultiPrefixTestCase(cases.BaseServerTestCase):
These prefixes MUST be in order of rank, from highest to lowest. These prefixes MUST be in order of rank, from highest to lowest.
""" """
self.connectClient('foo', capabilities=['multi-prefix']) self.connectClient("foo", capabilities=["multi-prefix"])
self.joinChannel(1, '#chan') self.joinChannel(1, "#chan")
self.sendLine(1, 'MODE #chan +v foo') self.sendLine(1, "MODE #chan +v foo")
self.getMessages(1) self.getMessages(1)
#TODO(dan): Make sure +v is voice # TODO(dan): Make sure +v is voice
self.sendLine(1, 'NAMES #chan') self.sendLine(1, "NAMES #chan")
self.assertMessageEqual(self.getMessage(1), command='353', params=['foo', '=', '#chan', '@+foo'], fail_msg='Expected NAMES response (353) with @+foo, got: {msg}') self.assertMessageEqual(
self.getMessage(1),
command="353",
params=["foo", "=", "#chan", "@+foo"],
fail_msg="Expected NAMES response (353) with @+foo, got: {msg}",
)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'WHO #chan') self.sendLine(1, "WHO #chan")
msg = self.getMessage(1) msg = self.getMessage(1)
self.assertEqual(msg.command, '352', msg, fail_msg='Expected WHO response (352), got: {msg}') self.assertEqual(
self.assertGreaterEqual(len(msg.params), 8, 'Expected WHO response (352) with 8 params, got: {msg}'.format(msg=msg)) msg.command, "352", msg, fail_msg="Expected WHO response (352), got: {msg}"
self.assertTrue('@+' in msg.params[6], 'Expected WHO response (352) with "@+" in param 7, got: {msg}'.format(msg=msg)) )
self.assertGreaterEqual(
len(msg.params),
8,
"Expected WHO response (352) with 8 params, got: {msg}".format(msg=msg),
)
self.assertTrue(
"@+" in msg.params[6],
'Expected WHO response (352) with "@+" in param 7, got: {msg}'.format(
msg=msg
),
)

View File

@ -4,118 +4,127 @@ draft/multiline
from irctest import cases from irctest import cases
CAP_NAME = 'draft/multiline' CAP_NAME = "draft/multiline"
BATCH_TYPE = 'draft/multiline' BATCH_TYPE = "draft/multiline"
CONCAT_TAG = 'draft/multiline-concat' CONCAT_TAG = "draft/multiline-concat"
base_caps = ["message-tags", "batch", "echo-message", "server-time", "labeled-response"]
base_caps = ['message-tags', 'batch', 'echo-message', 'server-time', 'labeled-response']
class MultilineTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): class MultilineTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@cases.SpecificationSelector.requiredBySpecification("multiline")
@cases.SpecificationSelector.requiredBySpecification('multiline')
def testBasic(self): def testBasic(self):
self.connectClient( self.connectClient(
'alice', capabilities=(base_caps + [CAP_NAME]), skip_if_cap_nak=True "alice", capabilities=(base_caps + [CAP_NAME]), skip_if_cap_nak=True
) )
self.joinChannel(1, '#test') self.joinChannel(1, "#test")
self.connectClient('bob', capabilities=(base_caps + [CAP_NAME])) self.connectClient("bob", capabilities=(base_caps + [CAP_NAME]))
self.joinChannel(2, '#test') self.joinChannel(2, "#test")
self.connectClient('charlie', capabilities=base_caps) self.connectClient("charlie", capabilities=base_caps)
self.joinChannel(3, '#test') self.joinChannel(3, "#test")
self.getMessages(1) self.getMessages(1)
self.getMessages(2) self.getMessages(2)
self.getMessages(3) self.getMessages(3)
self.sendLine(1, '@label=xyz BATCH +123 %s #test' % (BATCH_TYPE,)) self.sendLine(1, "@label=xyz BATCH +123 %s #test" % (BATCH_TYPE,))
self.sendLine(1, '@batch=123 PRIVMSG #test hello') self.sendLine(1, "@batch=123 PRIVMSG #test hello")
self.sendLine(1, '@batch=123 PRIVMSG #test :#how is ') self.sendLine(1, "@batch=123 PRIVMSG #test :#how is ")
self.sendLine(1, '@batch=123;%s PRIVMSG #test :everyone?' % (CONCAT_TAG,)) self.sendLine(1, "@batch=123;%s PRIVMSG #test :everyone?" % (CONCAT_TAG,))
self.sendLine(1, 'BATCH -123') self.sendLine(1, "BATCH -123")
echo = self.getMessages(1) echo = self.getMessages(1)
batchStart, batchEnd = echo[0], echo[-1] batchStart, batchEnd = echo[0], echo[-1]
self.assertEqual(batchStart.command, 'BATCH') self.assertEqual(batchStart.command, "BATCH")
self.assertEqual(batchStart.tags.get('label'), 'xyz') self.assertEqual(batchStart.tags.get("label"), "xyz")
self.assertEqual(len(batchStart.params), 3) self.assertEqual(len(batchStart.params), 3)
self.assertEqual(batchStart.params[1], CAP_NAME) self.assertEqual(batchStart.params[1], CAP_NAME)
self.assertEqual(batchStart.params[2], "#test") self.assertEqual(batchStart.params[2], "#test")
self.assertEqual(batchEnd.command, 'BATCH') self.assertEqual(batchEnd.command, "BATCH")
self.assertEqual(batchStart.params[0][1:], batchEnd.params[0][1:]) self.assertEqual(batchStart.params[0][1:], batchEnd.params[0][1:])
msgid = batchStart.tags.get('msgid') msgid = batchStart.tags.get("msgid")
time = batchStart.tags.get('time') time = batchStart.tags.get("time")
assert msgid assert msgid
assert time assert time
privmsgs = echo[1:-1] privmsgs = echo[1:-1]
for msg in privmsgs: for msg in privmsgs:
self.assertMessageEqual(msg, command='PRIVMSG') self.assertMessageEqual(msg, command="PRIVMSG")
self.assertNotIn('msgid', msg.tags) self.assertNotIn("msgid", msg.tags)
self.assertNotIn('time', msg.tags) self.assertNotIn("time", msg.tags)
self.assertIn(CONCAT_TAG, echo[3].tags) self.assertIn(CONCAT_TAG, echo[3].tags)
relay = self.getMessages(2) relay = self.getMessages(2)
batchStart, batchEnd = relay[0], relay[-1] batchStart, batchEnd = relay[0], relay[-1]
self.assertEqual(batchStart.command, 'BATCH') self.assertEqual(batchStart.command, "BATCH")
self.assertEqual(batchEnd.command, 'BATCH') self.assertEqual(batchEnd.command, "BATCH")
batchTag = batchStart.params[0][1:] batchTag = batchStart.params[0][1:]
self.assertEqual(batchStart.params[0], '+'+batchTag) self.assertEqual(batchStart.params[0], "+" + batchTag)
self.assertEqual(batchEnd.params[0], '-'+batchTag) self.assertEqual(batchEnd.params[0], "-" + batchTag)
self.assertEqual(batchStart.tags.get('msgid'), msgid) self.assertEqual(batchStart.tags.get("msgid"), msgid)
self.assertEqual(batchStart.tags.get('time'), time) self.assertEqual(batchStart.tags.get("time"), time)
privmsgs = relay[1:-1] privmsgs = relay[1:-1]
for msg in privmsgs: for msg in privmsgs:
self.assertMessageEqual(msg, command='PRIVMSG') self.assertMessageEqual(msg, command="PRIVMSG")
self.assertNotIn('msgid', msg.tags) self.assertNotIn("msgid", msg.tags)
self.assertNotIn('time', msg.tags) self.assertNotIn("time", msg.tags)
self.assertEqual(msg.tags.get('batch'), batchTag) self.assertEqual(msg.tags.get("batch"), batchTag)
self.assertIn(CONCAT_TAG, relay[3].tags) self.assertIn(CONCAT_TAG, relay[3].tags)
fallback_relay = self.getMessages(3) fallback_relay = self.getMessages(3)
relayed_fmsgids = [] relayed_fmsgids = []
for msg in fallback_relay: for msg in fallback_relay:
self.assertMessageEqual(msg, command='PRIVMSG') self.assertMessageEqual(msg, command="PRIVMSG")
relayed_fmsgids.append(msg.tags.get('msgid')) relayed_fmsgids.append(msg.tags.get("msgid"))
self.assertEqual(msg.tags.get('time'), time) self.assertEqual(msg.tags.get("time"), time)
self.assertNotIn(CONCAT_TAG, msg.tags) self.assertNotIn(CONCAT_TAG, msg.tags)
self.assertEqual(relayed_fmsgids, [msgid] + [None]*(len(fallback_relay)-1)) self.assertEqual(relayed_fmsgids, [msgid] + [None] * (len(fallback_relay) - 1))
@cases.SpecificationSelector.requiredBySpecification("multiline")
@cases.SpecificationSelector.requiredBySpecification('multiline')
def testBlankLines(self): def testBlankLines(self):
self.connectClient( self.connectClient(
'alice', capabilities=(base_caps + [CAP_NAME]), skip_if_cap_nak=True "alice", capabilities=(base_caps + [CAP_NAME]), skip_if_cap_nak=True
) )
self.joinChannel(1, '#test') self.joinChannel(1, "#test")
self.connectClient('bob', capabilities=(base_caps + [CAP_NAME])) self.connectClient("bob", capabilities=(base_caps + [CAP_NAME]))
self.joinChannel(2, '#test') self.joinChannel(2, "#test")
self.connectClient('charlie', capabilities=base_caps) self.connectClient("charlie", capabilities=base_caps)
self.joinChannel(3, '#test') self.joinChannel(3, "#test")
self.getMessages(1) self.getMessages(1)
self.getMessages(2) self.getMessages(2)
self.getMessages(3) self.getMessages(3)
self.sendLine(1, '@label=xyz;+client-only-tag BATCH +123 %s #test' % (BATCH_TYPE,)) self.sendLine(
self.sendLine(1, '@batch=123 PRIVMSG #test :') 1, "@label=xyz;+client-only-tag BATCH +123 %s #test" % (BATCH_TYPE,)
self.sendLine(1, '@batch=123 PRIVMSG #test :#how is ') )
self.sendLine(1, '@batch=123;%s PRIVMSG #test :everyone?' % (CONCAT_TAG,)) self.sendLine(1, "@batch=123 PRIVMSG #test :")
self.sendLine(1, 'BATCH -123') self.sendLine(1, "@batch=123 PRIVMSG #test :#how is ")
self.sendLine(1, "@batch=123;%s PRIVMSG #test :everyone?" % (CONCAT_TAG,))
self.sendLine(1, "BATCH -123")
self.getMessages(1) self.getMessages(1)
relay = self.getMessages(2) relay = self.getMessages(2)
batch_start = relay[0] batch_start = relay[0]
privmsgs = relay[1:-1] privmsgs = relay[1:-1]
self.assertEqual(len(privmsgs), 3) self.assertEqual(len(privmsgs), 3)
self.assertMessageEqual(privmsgs[0], command='PRIVMSG', params=['#test', '']) self.assertMessageEqual(privmsgs[0], command="PRIVMSG", params=["#test", ""])
self.assertMessageEqual(privmsgs[1], command='PRIVMSG', params=['#test', '#how is ']) self.assertMessageEqual(
self.assertMessageEqual(privmsgs[2], command='PRIVMSG', params=['#test', 'everyone?']) privmsgs[1], command="PRIVMSG", params=["#test", "#how is "]
self.assertIn('+client-only-tag', batch_start.tags) )
msgid = batch_start.tags['msgid'] self.assertMessageEqual(
privmsgs[2], command="PRIVMSG", params=["#test", "everyone?"]
)
self.assertIn("+client-only-tag", batch_start.tags)
msgid = batch_start.tags["msgid"]
fallback_relay = self.getMessages(3) fallback_relay = self.getMessages(3)
self.assertEqual(len(fallback_relay), 2) self.assertEqual(len(fallback_relay), 2)
self.assertMessageEqual(fallback_relay[0], command='PRIVMSG', params=['#test', '#how is ']) self.assertMessageEqual(
self.assertMessageEqual(fallback_relay[1], command='PRIVMSG', params=['#test', 'everyone?']) fallback_relay[0], command="PRIVMSG", params=["#test", "#how is "]
self.assertIn('+client-only-tag', fallback_relay[0].tags) )
self.assertIn('+client-only-tag', fallback_relay[1].tags) self.assertMessageEqual(
self.assertEqual(fallback_relay[0].tags['msgid'], msgid) fallback_relay[1], command="PRIVMSG", params=["#test", "everyone?"]
)
self.assertIn("+client-only-tag", fallback_relay[0].tags)
self.assertIn("+client-only-tag", fallback_relay[1].tags)
self.assertEqual(fallback_relay[0].tags["msgid"], msgid)

View File

@ -1,107 +1,114 @@
from irctest import cases from irctest import cases
REGISTER_CAP_NAME = 'draft/register' REGISTER_CAP_NAME = "draft/register"
class TestRegisterBeforeConnect(cases.BaseServerTestCase): class TestRegisterBeforeConnect(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config():
return { return {
"oragono_config": lambda config: config['accounts']['registration'].update( "oragono_config": lambda config: config["accounts"]["registration"].update(
{'allow-before-connect': True} {"allow-before-connect": True}
) )
} }
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testBeforeConnect(self): def testBeforeConnect(self):
self.addClient('bar') self.addClient("bar")
self.sendLine('bar', 'CAP LS 302') self.sendLine("bar", "CAP LS 302")
caps = self.getCapLs('bar') caps = self.getCapLs("bar")
self.assertIn(REGISTER_CAP_NAME, caps) self.assertIn(REGISTER_CAP_NAME, caps)
self.assertIn('before-connect', caps[REGISTER_CAP_NAME]) self.assertIn("before-connect", caps[REGISTER_CAP_NAME])
self.sendLine('bar', 'NICK bar') self.sendLine("bar", "NICK bar")
self.sendLine('bar', 'REGISTER * shivarampassphrase') self.sendLine("bar", "REGISTER * shivarampassphrase")
msgs = self.getMessages('bar') msgs = self.getMessages("bar")
register_response = [msg for msg in msgs if msg.command == 'REGISTER'][0] register_response = [msg for msg in msgs if msg.command == "REGISTER"][0]
self.assertEqual(register_response.params[0], 'SUCCESS') self.assertEqual(register_response.params[0], "SUCCESS")
class TestRegisterBeforeConnectDisallowed(cases.BaseServerTestCase): class TestRegisterBeforeConnectDisallowed(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config():
return { return {
"oragono_config": lambda config: config['accounts']['registration'].update( "oragono_config": lambda config: config["accounts"]["registration"].update(
{'allow-before-connect': False} {"allow-before-connect": False}
) )
} }
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testBeforeConnect(self): def testBeforeConnect(self):
self.addClient('bar') self.addClient("bar")
self.sendLine('bar', 'CAP LS 302') self.sendLine("bar", "CAP LS 302")
caps = self.getCapLs('bar') caps = self.getCapLs("bar")
self.assertIn(REGISTER_CAP_NAME, caps) self.assertIn(REGISTER_CAP_NAME, caps)
self.assertEqual(caps[REGISTER_CAP_NAME], None) self.assertEqual(caps[REGISTER_CAP_NAME], None)
self.sendLine('bar', 'NICK bar') self.sendLine("bar", "NICK bar")
self.sendLine('bar', 'REGISTER * shivarampassphrase') self.sendLine("bar", "REGISTER * shivarampassphrase")
msgs = self.getMessages('bar') msgs = self.getMessages("bar")
fail_response = [msg for msg in msgs if msg.command == 'FAIL'][0] fail_response = [msg for msg in msgs if msg.command == "FAIL"][0]
self.assertEqual(fail_response.params[:2], ['REGISTER', 'DISALLOWED']) self.assertEqual(fail_response.params[:2], ["REGISTER", "DISALLOWED"])
class TestRegisterEmailVerified(cases.BaseServerTestCase): class TestRegisterEmailVerified(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config():
return { return {
"oragono_config": lambda config: config['accounts']['registration'].update( "oragono_config": lambda config: config["accounts"]["registration"].update(
{ {
'email-verification': { "email-verification": {
'enabled': True, "enabled": True,
'sender': 'test@example.com', "sender": "test@example.com",
'require-tls': True, "require-tls": True,
'helo-domain': 'example.com', "helo-domain": "example.com",
}, },
'allow-before-connect': True, "allow-before-connect": True,
} }
) )
} }
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testBeforeConnect(self): def testBeforeConnect(self):
self.addClient('bar') self.addClient("bar")
self.sendLine('bar', 'CAP LS 302') self.sendLine("bar", "CAP LS 302")
caps = self.getCapLs('bar') caps = self.getCapLs("bar")
self.assertIn(REGISTER_CAP_NAME, caps) self.assertIn(REGISTER_CAP_NAME, caps)
self.assertEqual(set(caps[REGISTER_CAP_NAME].split(',')), {'before-connect', 'email-required'}) self.assertEqual(
self.sendLine('bar', 'NICK bar') set(caps[REGISTER_CAP_NAME].split(",")),
self.sendLine('bar', 'REGISTER * shivarampassphrase') {"before-connect", "email-required"},
msgs = self.getMessages('bar') )
fail_response = [msg for msg in msgs if msg.command == 'FAIL'][0] self.sendLine("bar", "NICK bar")
self.assertEqual(fail_response.params[:2], ['REGISTER', 'INVALID_EMAIL']) self.sendLine("bar", "REGISTER * shivarampassphrase")
msgs = self.getMessages("bar")
fail_response = [msg for msg in msgs if msg.command == "FAIL"][0]
self.assertEqual(fail_response.params[:2], ["REGISTER", "INVALID_EMAIL"])
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testAfterConnect(self): def testAfterConnect(self):
self.connectClient('bar', name='bar') self.connectClient("bar", name="bar")
self.sendLine('bar', 'REGISTER * shivarampassphrase') self.sendLine("bar", "REGISTER * shivarampassphrase")
msgs = self.getMessages('bar') msgs = self.getMessages("bar")
fail_response = [msg for msg in msgs if msg.command == 'FAIL'][0] fail_response = [msg for msg in msgs if msg.command == "FAIL"][0]
self.assertEqual(fail_response.params[:2], ['REGISTER', 'INVALID_EMAIL']) self.assertEqual(fail_response.params[:2], ["REGISTER", "INVALID_EMAIL"])
class TestRegisterNoLandGrabs(cases.BaseServerTestCase): class TestRegisterNoLandGrabs(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config():
return { return {
"oragono_config": lambda config: config['accounts']['registration'].update( "oragono_config": lambda config: config["accounts"]["registration"].update(
{'allow-before-connect': True} {"allow-before-connect": True}
) )
} }
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testBeforeConnect(self): def testBeforeConnect(self):
# have an anonymous client take the 'root' username: # have an anonymous client take the 'root' username:
self.connectClient('root', name='root') self.connectClient("root", name="root")
# cannot register it out from under the anonymous nick holder: # cannot register it out from under the anonymous nick holder:
self.addClient('bar') self.addClient("bar")
self.sendLine('bar', 'NICK root') self.sendLine("bar", "NICK root")
self.sendLine('bar', 'REGISTER * shivarampassphrase') self.sendLine("bar", "REGISTER * shivarampassphrase")
msgs = self.getMessages('bar') msgs = self.getMessages("bar")
fail_response = [msg for msg in msgs if msg.command == 'FAIL'][0] fail_response = [msg for msg in msgs if msg.command == "FAIL"][0]
self.assertEqual(fail_response.params[:2], ['REGISTER', 'USERNAME_EXISTS']) self.assertEqual(fail_response.params[:2], ["REGISTER", "USERNAME_EXISTS"])

View File

@ -6,164 +6,172 @@ from irctest import cases
from irctest.numerics import ERR_ERRONEUSNICKNAME, ERR_NICKNAMEINUSE, RPL_WELCOME from irctest.numerics import ERR_ERRONEUSNICKNAME, ERR_NICKNAMEINUSE, RPL_WELCOME
class RegressionsTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification('RFC1459') class RegressionsTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification("RFC1459")
def testFailedNickChange(self): def testFailedNickChange(self):
# see oragono commit d0ded906d4ac8f # see oragono commit d0ded906d4ac8f
self.connectClient('alice') self.connectClient("alice")
self.connectClient('bob') self.connectClient("bob")
# bob tries to change to an in-use nickname; this MUST fail # bob tries to change to an in-use nickname; this MUST fail
self.sendLine(2, 'NICK alice') self.sendLine(2, "NICK alice")
ms = self.getMessages(2) ms = self.getMessages(2)
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
self.assertMessageEqual(ms[0], command=ERR_NICKNAMEINUSE) self.assertMessageEqual(ms[0], command=ERR_NICKNAMEINUSE)
# bob MUST still own the bob nick, and be able to receive PRIVMSG as bob # bob MUST still own the bob nick, and be able to receive PRIVMSG as bob
self.sendLine(1, 'PRIVMSG bob hi') self.sendLine(1, "PRIVMSG bob hi")
ms = self.getMessages(1) ms = self.getMessages(1)
self.assertEqual(len(ms), 0) self.assertEqual(len(ms), 0)
ms = self.getMessages(2) ms = self.getMessages(2)
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
self.assertMessageEqual(ms[0], command='PRIVMSG', params=['bob', 'hi']) self.assertMessageEqual(ms[0], command="PRIVMSG", params=["bob", "hi"])
@cases.SpecificationSelector.requiredBySpecification('RFC1459') @cases.SpecificationSelector.requiredBySpecification("RFC1459")
def testCaseChanges(self): def testCaseChanges(self):
self.connectClient('alice') self.connectClient("alice")
self.joinChannel(1, '#test') self.joinChannel(1, "#test")
self.connectClient('bob') self.connectClient("bob")
self.joinChannel(2, '#test') self.joinChannel(2, "#test")
self.getMessages(1) self.getMessages(1)
self.getMessages(2) self.getMessages(2)
# case change: both alice and bob should get a successful nick line # case change: both alice and bob should get a successful nick line
self.sendLine(1, 'NICK Alice') self.sendLine(1, "NICK Alice")
ms = self.getMessages(1) ms = self.getMessages(1)
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
self.assertMessageEqual(ms[0], command='NICK', params=['Alice']) self.assertMessageEqual(ms[0], command="NICK", params=["Alice"])
ms = self.getMessages(2) ms = self.getMessages(2)
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
self.assertMessageEqual(ms[0], command='NICK', params=['Alice']) self.assertMessageEqual(ms[0], command="NICK", params=["Alice"])
# no responses, either to the user or to friends, from a no-op nick change # no responses, either to the user or to friends, from a no-op nick change
self.sendLine(1, 'NICK Alice') self.sendLine(1, "NICK Alice")
ms = self.getMessages(1) ms = self.getMessages(1)
self.assertEqual(ms, []) self.assertEqual(ms, [])
ms = self.getMessages(2) ms = self.getMessages(2)
self.assertEqual(ms, []) self.assertEqual(ms, [])
@cases.SpecificationSelector.requiredBySpecification('IRCv3.2') @cases.SpecificationSelector.requiredBySpecification("IRCv3.2")
def testTagCap(self): def testTagCap(self):
# regression test for oragono #754 # regression test for oragono #754
self.connectClient( self.connectClient(
'alice', "alice",
capabilities=['message-tags', 'batch', 'echo-message', 'server-time'], capabilities=["message-tags", "batch", "echo-message", "server-time"],
skip_if_cap_nak=True skip_if_cap_nak=True,
) )
self.connectClient('bob') self.connectClient("bob")
self.getMessages(1) self.getMessages(1)
self.getMessages(2) self.getMessages(2)
self.sendLine(1, '@+draft/reply=ct95w3xemz8qj9du2h74wp8pee PRIVMSG bob :hey yourself') self.sendLine(
1, "@+draft/reply=ct95w3xemz8qj9du2h74wp8pee PRIVMSG bob :hey yourself"
)
ms = self.getMessages(1) ms = self.getMessages(1)
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
self.assertMessageEqual(ms[0], command='PRIVMSG', params=['bob', 'hey yourself']) self.assertMessageEqual(
self.assertEqual(ms[0].tags.get('+draft/reply'), 'ct95w3xemz8qj9du2h74wp8pee') ms[0], command="PRIVMSG", params=["bob", "hey yourself"]
)
self.assertEqual(ms[0].tags.get("+draft/reply"), "ct95w3xemz8qj9du2h74wp8pee")
ms = self.getMessages(2) ms = self.getMessages(2)
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
self.assertMessageEqual(ms[0], command='PRIVMSG', params=['bob', 'hey yourself']) self.assertMessageEqual(
ms[0], command="PRIVMSG", params=["bob", "hey yourself"]
)
self.assertEqual(ms[0].tags, {}) self.assertEqual(ms[0].tags, {})
self.sendLine(2, 'CAP REQ :message-tags server-time') self.sendLine(2, "CAP REQ :message-tags server-time")
self.getMessages(2) self.getMessages(2)
self.sendLine(1, '@+draft/reply=tbxqauh9nykrtpa3n6icd9whan PRIVMSG bob :hey again') self.sendLine(
1, "@+draft/reply=tbxqauh9nykrtpa3n6icd9whan PRIVMSG bob :hey again"
)
self.getMessages(1) self.getMessages(1)
ms = self.getMessages(2) ms = self.getMessages(2)
# now bob has the tags cap, so he should receive the tags # now bob has the tags cap, so he should receive the tags
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
self.assertMessageEqual(ms[0], command='PRIVMSG', params=['bob', 'hey again']) self.assertMessageEqual(ms[0], command="PRIVMSG", params=["bob", "hey again"])
self.assertEqual(ms[0].tags.get('+draft/reply'), 'tbxqauh9nykrtpa3n6icd9whan') self.assertEqual(ms[0].tags.get("+draft/reply"), "tbxqauh9nykrtpa3n6icd9whan")
@cases.SpecificationSelector.requiredBySpecification('RFC1459') @cases.SpecificationSelector.requiredBySpecification("RFC1459")
def testStarNick(self): def testStarNick(self):
self.addClient(1) self.addClient(1)
self.sendLine(1, 'NICK *') self.sendLine(1, "NICK *")
self.sendLine(1, 'USER u s e r') self.sendLine(1, "USER u s e r")
replies = {'NOTICE'} replies = {"NOTICE"}
while replies == {'NOTICE'}: while replies == {"NOTICE"}:
replies = set(msg.command for msg in self.getMessages(1, synchronize=False)) replies = set(msg.command for msg in self.getMessages(1, synchronize=False))
self.assertIn(ERR_ERRONEUSNICKNAME, replies) self.assertIn(ERR_ERRONEUSNICKNAME, replies)
self.assertNotIn(RPL_WELCOME, replies) self.assertNotIn(RPL_WELCOME, replies)
self.sendLine(1, 'NICK valid') self.sendLine(1, "NICK valid")
replies = {'NOTICE'} replies = {"NOTICE"}
while replies <= {'NOTICE'}: while replies <= {"NOTICE"}:
replies = set(msg.command for msg in self.getMessages(1, synchronize=False)) replies = set(msg.command for msg in self.getMessages(1, synchronize=False))
self.assertNotIn(ERR_ERRONEUSNICKNAME, replies) self.assertNotIn(ERR_ERRONEUSNICKNAME, replies)
self.assertIn(RPL_WELCOME, replies) self.assertIn(RPL_WELCOME, replies)
@cases.SpecificationSelector.requiredBySpecification('RFC1459') @cases.SpecificationSelector.requiredBySpecification("RFC1459")
def testEmptyNick(self): def testEmptyNick(self):
self.addClient(1) self.addClient(1)
self.sendLine(1, 'NICK :') self.sendLine(1, "NICK :")
self.sendLine(1, 'USER u s e r') self.sendLine(1, "USER u s e r")
replies = {'NOTICE'} replies = {"NOTICE"}
while replies == {'NOTICE'}: while replies == {"NOTICE"}:
replies = set(msg.command for msg in self.getMessages(1, synchronize=False)) replies = set(msg.command for msg in self.getMessages(1, synchronize=False))
self.assertNotIn(RPL_WELCOME, replies) self.assertNotIn(RPL_WELCOME, replies)
@cases.SpecificationSelector.requiredBySpecification('RFC1459') @cases.SpecificationSelector.requiredBySpecification("RFC1459")
def testNickRelease(self): def testNickRelease(self):
# regression test for oragono #1252 # regression test for oragono #1252
self.connectClient('alice') self.connectClient("alice")
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'NICK malice') self.sendLine(1, "NICK malice")
nick_msgs = [msg for msg in self.getMessages(1) if msg.command == 'NICK'] nick_msgs = [msg for msg in self.getMessages(1) if msg.command == "NICK"]
self.assertEqual(len(nick_msgs), 1) self.assertEqual(len(nick_msgs), 1)
self.assertMessageEqual(nick_msgs[0], command='NICK', params=['malice']) self.assertMessageEqual(nick_msgs[0], command="NICK", params=["malice"])
self.addClient(2) self.addClient(2)
self.sendLine(2, 'NICK alice') self.sendLine(2, "NICK alice")
self.sendLine(2, 'USER u s e r') self.sendLine(2, "USER u s e r")
replies = set(msg.command for msg in self.getMessages(2)) replies = set(msg.command for msg in self.getMessages(2))
self.assertNotIn(ERR_NICKNAMEINUSE, replies) self.assertNotIn(ERR_NICKNAMEINUSE, replies)
self.assertIn(RPL_WELCOME, replies) self.assertIn(RPL_WELCOME, replies)
@cases.SpecificationSelector.requiredBySpecification('RFC1459') @cases.SpecificationSelector.requiredBySpecification("RFC1459")
def testNickReleaseQuit(self): def testNickReleaseQuit(self):
self.connectClient('alice') self.connectClient("alice")
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'QUIT') self.sendLine(1, "QUIT")
self.assertDisconnected(1) self.assertDisconnected(1)
self.addClient(2) self.addClient(2)
self.sendLine(2, 'NICK alice') self.sendLine(2, "NICK alice")
self.sendLine(2, 'USER u s e r') self.sendLine(2, "USER u s e r")
replies = set(msg.command for msg in self.getMessages(2)) replies = set(msg.command for msg in self.getMessages(2))
self.assertNotIn(ERR_NICKNAMEINUSE, replies) self.assertNotIn(ERR_NICKNAMEINUSE, replies)
self.assertIn(RPL_WELCOME, replies) self.assertIn(RPL_WELCOME, replies)
self.sendLine(2, 'QUIT') self.sendLine(2, "QUIT")
self.assertDisconnected(2) self.assertDisconnected(2)
self.addClient(3) self.addClient(3)
self.sendLine(3, 'NICK ALICE') self.sendLine(3, "NICK ALICE")
self.sendLine(3, 'USER u s e r') self.sendLine(3, "USER u s e r")
replies = set(msg.command for msg in self.getMessages(3)) replies = set(msg.command for msg in self.getMessages(3))
self.assertNotIn(ERR_NICKNAMEINUSE, replies) self.assertNotIn(ERR_NICKNAMEINUSE, replies)
self.assertIn(RPL_WELCOME, replies) self.assertIn(RPL_WELCOME, replies)
@cases.SpecificationSelector.requiredBySpecification('RFC1459') @cases.SpecificationSelector.requiredBySpecification("RFC1459")
def testNickReleaseUnregistered(self): def testNickReleaseUnregistered(self):
self.addClient(1) self.addClient(1)
self.sendLine(1, 'NICK alice') self.sendLine(1, "NICK alice")
self.sendLine(1, 'QUIT') self.sendLine(1, "QUIT")
self.assertDisconnected(1) self.assertDisconnected(1)
self.addClient(2) self.addClient(2)
self.sendLine(2, 'NICK alice') self.sendLine(2, "NICK alice")
self.sendLine(2, 'USER u s e r') self.sendLine(2, "USER u s e r")
replies = set(msg.command for msg in self.getMessages(2)) replies = set(msg.command for msg in self.getMessages(2))
self.assertNotIn(ERR_NICKNAMEINUSE, replies) self.assertNotIn(ERR_NICKNAMEINUSE, replies)
self.assertIn(RPL_WELCOME, replies) self.assertIn(RPL_WELCOME, replies)

View File

@ -3,8 +3,9 @@ from irctest.irc_utils.junkdrawer import random_name
from irctest.server_tests.test_chathistory import CHATHISTORY_CAP, EVENT_PLAYBACK_CAP from irctest.server_tests.test_chathistory import CHATHISTORY_CAP, EVENT_PLAYBACK_CAP
RELAYMSG_CAP = 'draft/relaymsg' RELAYMSG_CAP = "draft/relaymsg"
RELAYMSG_TAG_NAME = 'draft/relaymsg' RELAYMSG_TAG_NAME = "draft/relaymsg"
class RelaymsgTestCase(cases.BaseServerTestCase): class RelaymsgTestCase(cases.BaseServerTestCase):
@staticmethod @staticmethod
@ -13,60 +14,112 @@ class RelaymsgTestCase(cases.BaseServerTestCase):
"chathistory": True, "chathistory": True,
} }
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testRelaymsg(self): def testRelaymsg(self):
self.connectClient('baz', name='baz', capabilities=['server-time', 'message-tags', 'batch', 'labeled-response', 'echo-message', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP]) self.connectClient(
self.connectClient('qux', name='qux', capabilities=['server-time', 'message-tags', 'batch', 'labeled-response', 'echo-message', CHATHISTORY_CAP, EVENT_PLAYBACK_CAP]) "baz",
chname = random_name('#relaymsg') name="baz",
self.joinChannel('baz', chname) capabilities=[
self.joinChannel('qux', chname) "server-time",
self.getMessages('baz') "message-tags",
self.getMessages('qux') "batch",
"labeled-response",
"echo-message",
CHATHISTORY_CAP,
EVENT_PLAYBACK_CAP,
],
)
self.connectClient(
"qux",
name="qux",
capabilities=[
"server-time",
"message-tags",
"batch",
"labeled-response",
"echo-message",
CHATHISTORY_CAP,
EVENT_PLAYBACK_CAP,
],
)
chname = random_name("#relaymsg")
self.joinChannel("baz", chname)
self.joinChannel("qux", chname)
self.getMessages("baz")
self.getMessages("qux")
self.sendLine('baz', 'RELAYMSG %s invalid!nick/discord hi' % (chname,)) self.sendLine("baz", "RELAYMSG %s invalid!nick/discord hi" % (chname,))
response = self.getMessages('baz')[0] response = self.getMessages("baz")[0]
self.assertEqual(response.command, 'FAIL') self.assertEqual(response.command, "FAIL")
self.assertEqual(response.params[:2], ['RELAYMSG', 'INVALID_NICK']) self.assertEqual(response.params[:2], ["RELAYMSG", "INVALID_NICK"])
self.sendLine('baz', 'RELAYMSG %s regular_nick hi' % (chname,)) self.sendLine("baz", "RELAYMSG %s regular_nick hi" % (chname,))
response = self.getMessages('baz')[0] response = self.getMessages("baz")[0]
self.assertEqual(response.command, 'FAIL') self.assertEqual(response.command, "FAIL")
self.assertEqual(response.params[:2], ['RELAYMSG', 'INVALID_NICK']) self.assertEqual(response.params[:2], ["RELAYMSG", "INVALID_NICK"])
self.sendLine('baz', 'RELAYMSG %s smt/discord hi' % (chname,)) self.sendLine("baz", "RELAYMSG %s smt/discord hi" % (chname,))
response = self.getMessages('baz')[0] response = self.getMessages("baz")[0]
self.assertMessageEqual(response, nick='smt/discord', command='PRIVMSG', params=[chname, 'hi']) self.assertMessageEqual(
relayed_msg = self.getMessages('qux')[0] response, nick="smt/discord", command="PRIVMSG", params=[chname, "hi"]
self.assertMessageEqual(relayed_msg, nick='smt/discord', command='PRIVMSG', params=[chname, 'hi']) )
relayed_msg = self.getMessages("qux")[0]
self.assertMessageEqual(
relayed_msg, nick="smt/discord", command="PRIVMSG", params=[chname, "hi"]
)
# labeled-response # labeled-response
self.sendLine('baz', '@label=x RELAYMSG %s smt/discord :hi again' % (chname,)) self.sendLine("baz", "@label=x RELAYMSG %s smt/discord :hi again" % (chname,))
response = self.getMessages('baz')[0] response = self.getMessages("baz")[0]
self.assertMessageEqual(response, nick='smt/discord', command='PRIVMSG', params=[chname, 'hi again']) self.assertMessageEqual(
self.assertEqual(response.tags.get('label'), 'x') response, nick="smt/discord", command="PRIVMSG", params=[chname, "hi again"]
relayed_msg = self.getMessages('qux')[0] )
self.assertMessageEqual(relayed_msg, nick='smt/discord', command='PRIVMSG', params=[chname, 'hi again']) self.assertEqual(response.tags.get("label"), "x")
relayed_msg = self.getMessages("qux")[0]
self.assertMessageEqual(
relayed_msg,
nick="smt/discord",
command="PRIVMSG",
params=[chname, "hi again"],
)
self.sendLine('qux', 'RELAYMSG %s smt/discord :hi a third time' % (chname,)) self.sendLine("qux", "RELAYMSG %s smt/discord :hi a third time" % (chname,))
response = self.getMessages('qux')[0] response = self.getMessages("qux")[0]
self.assertEqual(response.command, 'FAIL') self.assertEqual(response.command, "FAIL")
self.assertEqual(response.params[:2], ['RELAYMSG', 'PRIVS_NEEDED']) self.assertEqual(response.params[:2], ["RELAYMSG", "PRIVS_NEEDED"])
# grant qux chanop, allowing relaymsg # grant qux chanop, allowing relaymsg
self.sendLine('baz', 'MODE %s +o qux' % (chname,)) self.sendLine("baz", "MODE %s +o qux" % (chname,))
self.getMessages('baz') self.getMessages("baz")
self.getMessages('qux') self.getMessages("qux")
# give baz the relaymsg cap # give baz the relaymsg cap
self.sendLine('baz', 'CAP REQ %s' % (RELAYMSG_CAP)) self.sendLine("baz", "CAP REQ %s" % (RELAYMSG_CAP))
self.assertMessageEqual(self.getMessages('baz')[0], command='CAP', params=['baz', 'ACK', RELAYMSG_CAP]) self.assertMessageEqual(
self.getMessages("baz")[0],
command="CAP",
params=["baz", "ACK", RELAYMSG_CAP],
)
self.sendLine('qux', 'RELAYMSG %s smt/discord :hi a third time' % (chname,)) self.sendLine("qux", "RELAYMSG %s smt/discord :hi a third time" % (chname,))
response = self.getMessages('qux')[0] response = self.getMessages("qux")[0]
self.assertMessageEqual(response, nick='smt/discord', command='PRIVMSG', params=[chname, 'hi a third time']) self.assertMessageEqual(
relayed_msg = self.getMessages('baz')[0] response,
self.assertMessageEqual(relayed_msg, nick='smt/discord', command='PRIVMSG', params=[chname, 'hi a third time']) nick="smt/discord",
self.assertEqual(relayed_msg.tags.get(RELAYMSG_TAG_NAME), 'qux') command="PRIVMSG",
params=[chname, "hi a third time"],
)
relayed_msg = self.getMessages("baz")[0]
self.assertMessageEqual(
relayed_msg,
nick="smt/discord",
command="PRIVMSG",
params=[chname, "hi a third time"],
)
self.assertEqual(relayed_msg.tags.get(RELAYMSG_TAG_NAME), "qux")
self.sendLine('baz', 'CHATHISTORY LATEST %s * 10' % (chname,)) self.sendLine("baz", "CHATHISTORY LATEST %s * 10" % (chname,))
messages = self.getMessages('baz') messages = self.getMessages("baz")
self.assertEqual([msg.params[-1] for msg in messages if msg.command == 'PRIVMSG'], ['hi', 'hi again', 'hi a third time']) self.assertEqual(
[msg.params[-1] for msg in messages if msg.command == "PRIVMSG"],
["hi", "hi again", "hi a third time"],
)

View File

@ -8,143 +8,209 @@ from irctest import cases
from irctest.numerics import RPL_AWAY from irctest.numerics import RPL_AWAY
ANCIENT_TIMESTAMP = '2006-01-02T15:04:05.999Z' ANCIENT_TIMESTAMP = "2006-01-02T15:04:05.999Z"
class ResumeTestCase(cases.BaseServerTestCase): class ResumeTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification("Oragono")
@cases.SpecificationSelector.requiredBySpecification('Oragono')
def testNoResumeByDefault(self): def testNoResumeByDefault(self):
self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response']) self.connectClient(
"bar", capabilities=["batch", "echo-message", "labeled-response"]
)
ms = self.getMessages(1) ms = self.getMessages(1)
resume_messages = [m for m in ms if m.command == 'RESUME'] resume_messages = [m for m in ms if m.command == "RESUME"]
self.assertEqual(resume_messages, [], 'should not see RESUME messages unless explicitly negotiated') self.assertEqual(
resume_messages,
[],
"should not see RESUME messages unless explicitly negotiated",
)
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testResume(self): def testResume(self):
chname = '#' + secrets.token_hex(12) chname = "#" + secrets.token_hex(12)
self.connectClient('bar', capabilities=['batch', 'labeled-response', 'server-time']) self.connectClient(
"bar", capabilities=["batch", "labeled-response", "server-time"]
)
ms = self.getMessages(1) ms = self.getMessages(1)
welcome = self.connectClient('baz', capabilities=['batch', 'labeled-response', 'server-time', 'draft/resume-0.5']) welcome = self.connectClient(
resume_messages = [m for m in welcome if m.command == 'RESUME'] "baz",
capabilities=[
"batch",
"labeled-response",
"server-time",
"draft/resume-0.5",
],
)
resume_messages = [m for m in welcome if m.command == "RESUME"]
self.assertEqual(len(resume_messages), 1) self.assertEqual(len(resume_messages), 1)
self.assertEqual(resume_messages[0].params[0], 'TOKEN') self.assertEqual(resume_messages[0].params[0], "TOKEN")
token = resume_messages[0].params[1] token = resume_messages[0].params[1]
self.joinChannel(1, chname) self.joinChannel(1, chname)
self.joinChannel(2, chname) self.joinChannel(2, chname)
self.sendLine(1, 'PRIVMSG %s :hello friends' % (chname,)) self.sendLine(1, "PRIVMSG %s :hello friends" % (chname,))
self.sendLine(1, 'PRIVMSG baz :hello friend singular') self.sendLine(1, "PRIVMSG baz :hello friend singular")
self.getMessages(1) self.getMessages(1)
# should receive these messages # should receive these messages
privmsgs = [m for m in self.getMessages(2) if m.command == 'PRIVMSG'] privmsgs = [m for m in self.getMessages(2) if m.command == "PRIVMSG"]
self.assertEqual(len(privmsgs), 2) self.assertEqual(len(privmsgs), 2)
privmsgs.sort(key=lambda m: m.params[0]) privmsgs.sort(key=lambda m: m.params[0])
self.assertMessageEqual(privmsgs[0], command='PRIVMSG', params=[chname, 'hello friends']) self.assertMessageEqual(
self.assertMessageEqual(privmsgs[1], command='PRIVMSG', params=['baz', 'hello friend singular']) privmsgs[0], command="PRIVMSG", params=[chname, "hello friends"]
channelMsgTime = privmsgs[0].tags.get('time') )
self.assertMessageEqual(
privmsgs[1], command="PRIVMSG", params=["baz", "hello friend singular"]
)
channelMsgTime = privmsgs[0].tags.get("time")
# tokens MUST be cryptographically secure; therefore, this token should be invalid # tokens MUST be cryptographically secure; therefore, this token should be invalid
# with probability at least 1 - 1/(2**128) # with probability at least 1 - 1/(2**128)
bad_token = 'a' * len(token) bad_token = "a" * len(token)
self.addClient() self.addClient()
self.sendLine(3, 'CAP LS') self.sendLine(3, "CAP LS")
self.sendLine(3, 'CAP REQ :batch labeled-response server-time draft/resume-0.5') self.sendLine(3, "CAP REQ :batch labeled-response server-time draft/resume-0.5")
self.sendLine(3, 'NICK tempnick') self.sendLine(3, "NICK tempnick")
self.sendLine(3, 'USER tempuser 0 * tempuser') self.sendLine(3, "USER tempuser 0 * tempuser")
self.sendLine(3, ' '.join(('RESUME', bad_token, ANCIENT_TIMESTAMP))) self.sendLine(3, " ".join(("RESUME", bad_token, ANCIENT_TIMESTAMP)))
# resume with a bad token MUST fail # resume with a bad token MUST fail
ms = self.getMessages(3) ms = self.getMessages(3)
resume_err_messages = [m for m in ms if m.command == 'FAIL' and m.params[:2] == ['RESUME', 'INVALID_TOKEN']] resume_err_messages = [
m
for m in ms
if m.command == "FAIL" and m.params[:2] == ["RESUME", "INVALID_TOKEN"]
]
self.assertEqual(len(resume_err_messages), 1) self.assertEqual(len(resume_err_messages), 1)
# however, registration should proceed with the alternative nick # however, registration should proceed with the alternative nick
self.sendLine(3, 'CAP END') self.sendLine(3, "CAP END")
welcome_msgs = [m for m in self.getMessages(3) if m.command == '001'] # RPL_WELCOME welcome_msgs = [
self.assertEqual(welcome_msgs[0].params[0], 'tempnick') m for m in self.getMessages(3) if m.command == "001"
] # RPL_WELCOME
self.assertEqual(welcome_msgs[0].params[0], "tempnick")
self.addClient() self.addClient()
self.sendLine(4, 'CAP LS') self.sendLine(4, "CAP LS")
self.sendLine(4, 'CAP REQ :batch labeled-response server-time draft/resume-0.5') self.sendLine(4, "CAP REQ :batch labeled-response server-time draft/resume-0.5")
self.sendLine(4, 'NICK tempnick_') self.sendLine(4, "NICK tempnick_")
self.sendLine(4, 'USER tempuser 0 * tempuser') self.sendLine(4, "USER tempuser 0 * tempuser")
# resume with a timestamp in the distant past # resume with a timestamp in the distant past
self.sendLine(4, ' '.join(('RESUME', token, ANCIENT_TIMESTAMP))) self.sendLine(4, " ".join(("RESUME", token, ANCIENT_TIMESTAMP)))
# successful resume does not require CAP END: # successful resume does not require CAP END:
# https://github.com/ircv3/ircv3-specifications/pull/306/files#r255318883 # https://github.com/ircv3/ircv3-specifications/pull/306/files#r255318883
ms = self.getMessages(4) ms = self.getMessages(4)
# now, do a valid resume with the correct token # now, do a valid resume with the correct token
resume_messages = [m for m in ms if m.command == 'RESUME'] resume_messages = [m for m in ms if m.command == "RESUME"]
self.assertEqual(len(resume_messages), 2) self.assertEqual(len(resume_messages), 2)
self.assertEqual(resume_messages[0].params[0], 'TOKEN') self.assertEqual(resume_messages[0].params[0], "TOKEN")
new_token = resume_messages[0].params[1] new_token = resume_messages[0].params[1]
self.assertNotEqual(token, new_token, 'should receive a new, strong resume token; instead got ' + new_token) self.assertNotEqual(
token,
new_token,
"should receive a new, strong resume token; instead got " + new_token,
)
# success message # success message
self.assertMessageEqual(resume_messages[1], command='RESUME', params=['SUCCESS', 'baz']) self.assertMessageEqual(
resume_messages[1], command="RESUME", params=["SUCCESS", "baz"]
)
# test replay of messages # test replay of messages
privmsgs = [m for m in ms if m.command == 'PRIVMSG' and m.prefix.startswith('bar')] privmsgs = [
m for m in ms if m.command == "PRIVMSG" and m.prefix.startswith("bar")
]
self.assertEqual(len(privmsgs), 2) self.assertEqual(len(privmsgs), 2)
privmsgs.sort(key=lambda m: m.params[0]) privmsgs.sort(key=lambda m: m.params[0])
self.assertMessageEqual(privmsgs[0], command='PRIVMSG', params=[chname, 'hello friends']) self.assertMessageEqual(
self.assertMessageEqual(privmsgs[1], command='PRIVMSG', params=['baz', 'hello friend singular']) privmsgs[0], command="PRIVMSG", params=[chname, "hello friends"]
)
self.assertMessageEqual(
privmsgs[1], command="PRIVMSG", params=["baz", "hello friend singular"]
)
# should replay with the original server-time # should replay with the original server-time
# TODO this probably isn't testing anything because the timestamp only has second resolution, # TODO this probably isn't testing anything because the timestamp only has second resolution,
# hence will typically match by accident # hence will typically match by accident
self.assertEqual(privmsgs[0].tags.get('time'), channelMsgTime) self.assertEqual(privmsgs[0].tags.get("time"), channelMsgTime)
# legacy client should receive a QUIT and a JOIN # legacy client should receive a QUIT and a JOIN
quit, join = [m for m in self.getMessages(1) if m.command in ('QUIT', 'JOIN')] quit, join = [m for m in self.getMessages(1) if m.command in ("QUIT", "JOIN")]
self.assertEqual(quit.command, 'QUIT') self.assertEqual(quit.command, "QUIT")
self.assertTrue(quit.prefix.startswith('baz')) self.assertTrue(quit.prefix.startswith("baz"))
self.assertMessageEqual(join, command='JOIN', params=[chname]) self.assertMessageEqual(join, command="JOIN", params=[chname])
self.assertTrue(join.prefix.startswith('baz')) self.assertTrue(join.prefix.startswith("baz"))
# original client should have been disconnected # original client should have been disconnected
self.assertDisconnected(2) self.assertDisconnected(2)
# new client should be receiving PRIVMSG sent to baz # new client should be receiving PRIVMSG sent to baz
self.sendLine(1, 'PRIVMSG baz :hello again') self.sendLine(1, "PRIVMSG baz :hello again")
self.getMessages(1) self.getMessages(1)
self.assertMessageEqual(self.getMessage(4), command='PRIVMSG', params=['baz', 'hello again']) self.assertMessageEqual(
self.getMessage(4), command="PRIVMSG", params=["baz", "hello again"]
)
# test chain-resuming (resuming the resumed connection, using the new token) # test chain-resuming (resuming the resumed connection, using the new token)
self.addClient() self.addClient()
self.sendLine(5, 'CAP LS') self.sendLine(5, "CAP LS")
self.sendLine(5, 'CAP REQ :batch labeled-response server-time draft/resume-0.5') self.sendLine(5, "CAP REQ :batch labeled-response server-time draft/resume-0.5")
self.sendLine(5, 'NICK tempnick_') self.sendLine(5, "NICK tempnick_")
self.sendLine(5, 'USER tempuser 0 * tempuser') self.sendLine(5, "USER tempuser 0 * tempuser")
self.sendLine(5, 'RESUME ' + new_token) self.sendLine(5, "RESUME " + new_token)
ms = self.getMessages(5) ms = self.getMessages(5)
resume_messages = [m for m in ms if m.command == 'RESUME'] resume_messages = [m for m in ms if m.command == "RESUME"]
self.assertEqual(len(resume_messages), 2) self.assertEqual(len(resume_messages), 2)
self.assertEqual(resume_messages[0].params[0], 'TOKEN') self.assertEqual(resume_messages[0].params[0], "TOKEN")
new_new_token = resume_messages[0].params[1] new_new_token = resume_messages[0].params[1]
self.assertNotEqual(token, new_new_token, 'should receive a new, strong resume token; instead got ' + new_new_token) self.assertNotEqual(
self.assertNotEqual(new_token, new_new_token, 'should receive a new, strong resume token; instead got ' + new_new_token) token,
new_new_token,
"should receive a new, strong resume token; instead got " + new_new_token,
)
self.assertNotEqual(
new_token,
new_new_token,
"should receive a new, strong resume token; instead got " + new_new_token,
)
# success message # success message
self.assertMessageEqual(resume_messages[1], command='RESUME', params=['SUCCESS', 'baz']) self.assertMessageEqual(
resume_messages[1], command="RESUME", params=["SUCCESS", "baz"]
)
@cases.SpecificationSelector.requiredBySpecification("Oragono")
@cases.SpecificationSelector.requiredBySpecification('Oragono')
def testBRB(self): def testBRB(self):
chname = '#' + secrets.token_hex(12) chname = "#" + secrets.token_hex(12)
self.connectClient('bar', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'draft/resume-0.5']) self.connectClient(
"bar",
capabilities=[
"batch",
"labeled-response",
"message-tags",
"server-time",
"draft/resume-0.5",
],
)
ms = self.getMessages(1) ms = self.getMessages(1)
self.joinChannel(1, chname) self.joinChannel(1, chname)
welcome = self.connectClient('baz', capabilities=['batch', 'labeled-response', 'server-time', 'draft/resume-0.5']) welcome = self.connectClient(
resume_messages = [m for m in welcome if m.command == 'RESUME'] "baz",
capabilities=[
"batch",
"labeled-response",
"server-time",
"draft/resume-0.5",
],
)
resume_messages = [m for m in welcome if m.command == "RESUME"]
self.assertEqual(len(resume_messages), 1) self.assertEqual(len(resume_messages), 1)
self.assertEqual(resume_messages[0].params[0], 'TOKEN') self.assertEqual(resume_messages[0].params[0], "TOKEN")
token = resume_messages[0].params[1] token = resume_messages[0].params[1]
self.joinChannel(2, chname) self.joinChannel(2, chname)
self.getMessages(1) self.getMessages(1)
self.sendLine(2, 'BRB :software upgrade') self.sendLine(2, "BRB :software upgrade")
# should receive, e.g., `BRB 210` (number of seconds) # should receive, e.g., `BRB 210` (number of seconds)
ms = [m for m in self.getMessages(2) if m.command == 'BRB'] ms = [m for m in self.getMessages(2) if m.command == "BRB"]
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
self.assertGreater(int(ms[0].params[0]), 1) self.assertGreater(int(ms[0].params[0]), 1)
# BRB disconnects you # BRB disconnects you
@ -152,25 +218,33 @@ class ResumeTestCase(cases.BaseServerTestCase):
# without sending a QUIT line to friends # without sending a QUIT line to friends
self.assertEqual(self.getMessages(1), []) self.assertEqual(self.getMessages(1), [])
self.sendLine(1, 'PRIVMSG baz :hey there') self.sendLine(1, "PRIVMSG baz :hey there")
# BRB message should be sent as an away message # BRB message should be sent as an away message
self.assertMessageEqual(self.getMessage(1), command=RPL_AWAY, params=['bar', 'baz', 'software upgrade']) self.assertMessageEqual(
self.getMessage(1),
command=RPL_AWAY,
params=["bar", "baz", "software upgrade"],
)
self.addClient(3) self.addClient(3)
self.sendLine(3, 'CAP REQ :batch account-tag message-tags draft/resume-0.5') self.sendLine(3, "CAP REQ :batch account-tag message-tags draft/resume-0.5")
self.sendLine(3, ' '.join(('RESUME', token, ANCIENT_TIMESTAMP))) self.sendLine(3, " ".join(("RESUME", token, ANCIENT_TIMESTAMP)))
ms = self.getMessages(3) ms = self.getMessages(3)
resume_messages = [m for m in ms if m.command == 'RESUME'] resume_messages = [m for m in ms if m.command == "RESUME"]
self.assertEqual(len(resume_messages), 2) self.assertEqual(len(resume_messages), 2)
self.assertEqual(resume_messages[0].params[0], 'TOKEN') self.assertEqual(resume_messages[0].params[0], "TOKEN")
self.assertMessageEqual(resume_messages[1], command='RESUME', params=['SUCCESS', 'baz']) self.assertMessageEqual(
resume_messages[1], command="RESUME", params=["SUCCESS", "baz"]
)
privmsgs = [m for m in ms if m.command == 'PRIVMSG' and m.prefix.startswith('bar')] privmsgs = [
m for m in ms if m.command == "PRIVMSG" and m.prefix.startswith("bar")
]
self.assertEqual(len(privmsgs), 1) self.assertEqual(len(privmsgs), 1)
self.assertMessageEqual(privmsgs[0], params=['baz', 'hey there']) self.assertMessageEqual(privmsgs[0], params=["baz", "hey there"])
# friend with the resume cap should receive a RESUMED message # friend with the resume cap should receive a RESUMED message
resumed_messages = [m for m in self.getMessages(1) if m.command == 'RESUMED'] resumed_messages = [m for m in self.getMessages(1) if m.command == "RESUMED"]
self.assertEqual(len(resumed_messages), 1) self.assertEqual(len(resumed_messages), 1)
self.assertTrue(resumed_messages[0].prefix.startswith('baz')) self.assertTrue(resumed_messages[0].prefix.startswith("baz"))

View File

@ -2,6 +2,7 @@ from irctest import cases
from irctest.numerics import ERR_CANNOTSENDRP from irctest.numerics import ERR_CANNOTSENDRP
from irctest.irc_utils.junkdrawer import random_name from irctest.irc_utils.junkdrawer import random_name
class RoleplayTestCase(cases.BaseServerTestCase): class RoleplayTestCase(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config():
@ -9,58 +10,70 @@ class RoleplayTestCase(cases.BaseServerTestCase):
"oragono_roleplay": True, "oragono_roleplay": True,
} }
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testRoleplay(self): def testRoleplay(self):
bar = random_name('bar') bar = random_name("bar")
qux = random_name('qux') qux = random_name("qux")
chan = random_name('#chan') chan = random_name("#chan")
self.connectClient(bar, name=bar, capabilities=['batch', 'labeled-response', 'message-tags', 'server-time']) self.connectClient(
self.connectClient(qux, name=qux, capabilities=['batch', 'labeled-response', 'message-tags', 'server-time']) bar,
name=bar,
capabilities=["batch", "labeled-response", "message-tags", "server-time"],
)
self.connectClient(
qux,
name=qux,
capabilities=["batch", "labeled-response", "message-tags", "server-time"],
)
self.joinChannel(bar, chan) self.joinChannel(bar, chan)
self.joinChannel(qux, chan) self.joinChannel(qux, chan)
self.getMessages(bar) self.getMessages(bar)
# roleplay should be forbidden because we aren't +E yet # roleplay should be forbidden because we aren't +E yet
self.sendLine(bar, 'NPC %s bilbo too much bread' % (chan,)) self.sendLine(bar, "NPC %s bilbo too much bread" % (chan,))
reply = self.getMessages(bar)[0] reply = self.getMessages(bar)[0]
self.assertEqual(reply.command, ERR_CANNOTSENDRP) self.assertEqual(reply.command, ERR_CANNOTSENDRP)
self.sendLine(bar, 'MODE %s +E' % (chan,)) self.sendLine(bar, "MODE %s +E" % (chan,))
reply = self.getMessages(bar)[0] reply = self.getMessages(bar)[0]
self.assertEqual(reply.command, 'MODE') self.assertEqual(reply.command, "MODE")
self.assertMessageEqual(reply, command='MODE', params=[chan, '+E']) self.assertMessageEqual(reply, command="MODE", params=[chan, "+E"])
self.getMessages(qux) self.getMessages(qux)
self.sendLine(bar, 'NPC %s bilbo too much bread' % (chan,)) self.sendLine(bar, "NPC %s bilbo too much bread" % (chan,))
reply = self.getMessages(bar)[0] reply = self.getMessages(bar)[0]
self.assertEqual(reply.command, 'PRIVMSG') self.assertEqual(reply.command, "PRIVMSG")
self.assertEqual(reply.params[0], chan) self.assertEqual(reply.params[0], chan)
self.assertTrue(reply.prefix.startswith('*bilbo*!')) self.assertTrue(reply.prefix.startswith("*bilbo*!"))
self.assertIn('too much bread', reply.params[1]) self.assertIn("too much bread", reply.params[1])
reply = self.getMessages(qux)[0] reply = self.getMessages(qux)[0]
self.assertEqual(reply.command, 'PRIVMSG') self.assertEqual(reply.command, "PRIVMSG")
self.assertEqual(reply.params[0], chan) self.assertEqual(reply.params[0], chan)
self.assertTrue(reply.prefix.startswith('*bilbo*!')) self.assertTrue(reply.prefix.startswith("*bilbo*!"))
self.assertIn('too much bread', reply.params[1]) self.assertIn("too much bread", reply.params[1])
self.sendLine(bar, 'SCENE %s dark and stormy night' % (chan,)) self.sendLine(bar, "SCENE %s dark and stormy night" % (chan,))
reply = self.getMessages(bar)[0] reply = self.getMessages(bar)[0]
self.assertEqual(reply.command, 'PRIVMSG') self.assertEqual(reply.command, "PRIVMSG")
self.assertEqual(reply.params[0], chan) self.assertEqual(reply.params[0], chan)
self.assertTrue(reply.prefix.startswith('=Scene=!')) self.assertTrue(reply.prefix.startswith("=Scene=!"))
self.assertIn('dark and stormy night', reply.params[1]) self.assertIn("dark and stormy night", reply.params[1])
reply = self.getMessages(qux)[0] reply = self.getMessages(qux)[0]
self.assertEqual(reply.command, 'PRIVMSG') self.assertEqual(reply.command, "PRIVMSG")
self.assertEqual(reply.params[0], chan) self.assertEqual(reply.params[0], chan)
self.assertTrue(reply.prefix.startswith('=Scene=!')) self.assertTrue(reply.prefix.startswith("=Scene=!"))
self.assertIn('dark and stormy night', reply.params[1]) self.assertIn("dark and stormy night", reply.params[1])
# test history storage # test history storage
self.sendLine(qux, 'CHATHISTORY LATEST %s * 10' % (chan,)) self.sendLine(qux, "CHATHISTORY LATEST %s * 10" % (chan,))
reply = [msg for msg in self.getMessages(qux) if msg.command == 'PRIVMSG' and 'bilbo' in msg.prefix][0] reply = [
self.assertEqual(reply.command, 'PRIVMSG') msg
for msg in self.getMessages(qux)
if msg.command == "PRIVMSG" and "bilbo" in msg.prefix
][0]
self.assertEqual(reply.command, "PRIVMSG")
self.assertEqual(reply.params[0], chan) self.assertEqual(reply.params[0], chan)
self.assertTrue(reply.prefix.startswith('*bilbo*!')) self.assertTrue(reply.prefix.startswith("*bilbo*!"))
self.assertIn('too much bread', reply.params[1]) self.assertIn("too much bread", reply.params[1])

View File

@ -2,42 +2,60 @@ import base64
from irctest import cases from irctest import cases
class RegistrationTestCase(cases.BaseServerTestCase): class RegistrationTestCase(cases.BaseServerTestCase):
def testRegistration(self): def testRegistration(self):
self.controller.registerUser(self, 'testuser', 'mypassword') self.controller.registerUser(self, "testuser", "mypassword")
class SaslTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): class SaslTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPlain(self): def testPlain(self):
"""PLAIN authentication with correct username/password.""" """PLAIN authentication with correct username/password."""
self.controller.registerUser(self, 'foo', 'sesame') self.controller.registerUser(self, "foo", "sesame")
self.controller.registerUser(self, 'jilles', 'sesame') self.controller.registerUser(self, "jilles", "sesame")
self.controller.registerUser(self, 'bar', 'sesame') self.controller.registerUser(self, "bar", "sesame")
self.addClient() self.addClient()
self.sendLine(1, 'CAP LS 302') self.sendLine(1, "CAP LS 302")
capabilities = self.getCapLs(1) capabilities = self.getCapLs(1)
self.assertIn('sasl', capabilities, self.assertIn(
fail_msg='Does not have SASL as the controller claims.') "sasl",
if capabilities['sasl'] is not None: capabilities,
self.assertIn('PLAIN', capabilities['sasl'], fail_msg="Does not have SASL as the controller claims.",
fail_msg='Does not have PLAIN mechanism as the controller ' )
'claims') if capabilities["sasl"] is not None:
self.sendLine(1, 'AUTHENTICATE PLAIN') self.assertIn(
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE') "PLAIN",
self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'], capabilities["sasl"],
fail_msg='Sent “AUTHENTICATE PLAIN”, server should have ' fail_msg="Does not have PLAIN mechanism as the controller " "claims",
'replied with “AUTHENTICATE +”, but instead sent: {msg}') )
self.sendLine(1, 'AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=') self.sendLine(1, "AUTHENTICATE PLAIN")
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE') m = self.getMessage(1, filter_pred=lambda m: m.command != "NOTICE")
self.assertMessageEqual(m, command='900', self.assertMessageEqual(
fail_msg='Did not send 900 after correct SASL authentication.') m,
self.assertEqual(m.params[2], 'jilles', m, command="AUTHENTICATE",
fail_msg='900 should contain the account name as 3rd argument ' params=["+"],
'({expects}), not {got}: {msg}') fail_msg="Sent “AUTHENTICATE PLAIN”, server should have "
"replied with “AUTHENTICATE +”, but instead sent: {msg}",
)
self.sendLine(1, "AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=")
m = self.getMessage(1, filter_pred=lambda m: m.command != "NOTICE")
self.assertMessageEqual(
m,
command="900",
fail_msg="Did not send 900 after correct SASL authentication.",
)
self.assertEqual(
m.params[2],
"jilles",
m,
fail_msg="900 should contain the account name as 3rd argument "
"({expects}), not {got}: {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPlainNoAuthzid(self): def testPlainNoAuthzid(self):
"""“message = [authzid] UTF8NUL authcid UTF8NUL passwd """“message = [authzid] UTF8NUL authcid UTF8NUL passwd
@ -60,73 +78,105 @@ class SaslTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
identity produces the same authorization identity. identity produces the same authorization identity.
-- <https://tools.ietf.org/html/rfc4616#section-2> -- <https://tools.ietf.org/html/rfc4616#section-2>
""" """
self.controller.registerUser(self, 'foo', 'sesame') self.controller.registerUser(self, "foo", "sesame")
self.controller.registerUser(self, 'jilles', 'sesame') self.controller.registerUser(self, "jilles", "sesame")
self.controller.registerUser(self, 'bar', 'sesame') self.controller.registerUser(self, "bar", "sesame")
self.addClient() self.addClient()
self.sendLine(1, 'CAP LS 302') self.sendLine(1, "CAP LS 302")
capabilities = self.getCapLs(1) capabilities = self.getCapLs(1)
self.assertIn('sasl', capabilities, self.assertIn(
fail_msg='Does not have SASL as the controller claims.') "sasl",
if capabilities['sasl'] is not None: capabilities,
self.assertIn('PLAIN', capabilities['sasl'], fail_msg="Does not have SASL as the controller claims.",
fail_msg='Does not have PLAIN mechanism as the controller ' )
'claims') if capabilities["sasl"] is not None:
self.sendLine(1, 'AUTHENTICATE PLAIN') self.assertIn(
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE') "PLAIN",
self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'], capabilities["sasl"],
fail_msg='Sent “AUTHENTICATE PLAIN”, server should have ' fail_msg="Does not have PLAIN mechanism as the controller " "claims",
'replied with “AUTHENTICATE +”, but instead sent: {msg}') )
self.sendLine(1, 'AUTHENTICATE AGppbGxlcwBzZXNhbWU=') self.sendLine(1, "AUTHENTICATE PLAIN")
m = self.getMessage(1, filter_pred=lambda m:m.command != 'NOTICE') m = self.getMessage(1, filter_pred=lambda m: m.command != "NOTICE")
self.assertMessageEqual(m, command='900', self.assertMessageEqual(
fail_msg='Did not send 900 after correct SASL authentication.') m,
self.assertEqual(m.params[2], 'jilles', m, command="AUTHENTICATE",
fail_msg='900 should contain the account name as 3rd argument ' params=["+"],
'({expects}), not {got}: {msg}') fail_msg="Sent “AUTHENTICATE PLAIN”, server should have "
"replied with “AUTHENTICATE +”, but instead sent: {msg}",
)
self.sendLine(1, "AUTHENTICATE AGppbGxlcwBzZXNhbWU=")
m = self.getMessage(1, filter_pred=lambda m: m.command != "NOTICE")
self.assertMessageEqual(
m,
command="900",
fail_msg="Did not send 900 after correct SASL authentication.",
)
self.assertEqual(
m.params[2],
"jilles",
m,
fail_msg="900 should contain the account name as 3rd argument "
"({expects}), not {got}: {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
def testMechanismNotAvailable(self): def testMechanismNotAvailable(self):
"""“If authentication fails, a 904 or 905 numeric will be sent” """“If authentication fails, a 904 or 905 numeric will be sent”
-- <http://ircv3.net/specs/extensions/sasl-3.1.html#the-authenticate-command> -- <http://ircv3.net/specs/extensions/sasl-3.1.html#the-authenticate-command>
""" """
self.controller.registerUser(self, 'jilles', 'sesame') self.controller.registerUser(self, "jilles", "sesame")
self.addClient() self.addClient()
self.sendLine(1, 'CAP LS 302') self.sendLine(1, "CAP LS 302")
capabilities = self.getCapLs(1) capabilities = self.getCapLs(1)
self.assertIn('sasl', capabilities, self.assertIn(
fail_msg='Does not have SASL as the controller claims.') "sasl",
self.sendLine(1, 'AUTHENTICATE FOO') capabilities,
fail_msg="Does not have SASL as the controller claims.",
)
self.sendLine(1, "AUTHENTICATE FOO")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='904', self.assertMessageEqual(
fail_msg='Did not reply with 904 to “AUTHENTICATE FOO”: {msg}') m,
command="904",
fail_msg="Did not reply with 904 to “AUTHENTICATE FOO”: {msg}",
)
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPlainLarge(self): def testPlainLarge(self):
"""Test the client splits large AUTHENTICATE messages whose payload """Test the client splits large AUTHENTICATE messages whose payload
is not a multiple of 400. is not a multiple of 400.
<http://ircv3.net/specs/extensions/sasl-3.1.html#the-authenticate-command> <http://ircv3.net/specs/extensions/sasl-3.1.html#the-authenticate-command>
""" """
self.controller.registerUser(self, 'foo', 'bar'*100) self.controller.registerUser(self, "foo", "bar" * 100)
authstring = base64.b64encode(b'\x00'.join( authstring = base64.b64encode(
[b'foo', b'foo', b'bar'*100])).decode() b"\x00".join([b"foo", b"foo", b"bar" * 100])
).decode()
self.addClient() self.addClient()
self.sendLine(1, 'CAP LS 302') self.sendLine(1, "CAP LS 302")
capabilities = self.getCapLs(1) capabilities = self.getCapLs(1)
self.assertIn('sasl', capabilities, self.assertIn(
fail_msg='Does not have SASL as the controller claims.') "sasl",
if capabilities['sasl'] is not None: capabilities,
self.assertIn('PLAIN', capabilities['sasl'], fail_msg="Does not have SASL as the controller claims.",
fail_msg='Does not have PLAIN mechanism as the controller ' )
'claims') if capabilities["sasl"] is not None:
self.sendLine(1, 'AUTHENTICATE PLAIN') self.assertIn(
"PLAIN",
capabilities["sasl"],
fail_msg="Does not have PLAIN mechanism as the controller " "claims",
)
self.sendLine(1, "AUTHENTICATE PLAIN")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'], self.assertMessageEqual(
fail_msg='Sent “AUTHENTICATE PLAIN”, expected ' m,
'“AUTHENTICATE +” as a response, but got: {msg}') command="AUTHENTICATE",
self.sendLine(1, 'AUTHENTICATE {}'.format(authstring[0:400])) params=["+"],
self.sendLine(1, 'AUTHENTICATE {}'.format(authstring[400:])) fail_msg="Sent “AUTHENTICATE PLAIN”, expected "
"“AUTHENTICATE +” as a response, but got: {msg}",
)
self.sendLine(1, "AUTHENTICATE {}".format(authstring[0:400]))
self.sendLine(1, "AUTHENTICATE {}".format(authstring[400:]))
self.confirmSuccessfulAuth() self.confirmSuccessfulAuth()
@ -134,45 +184,61 @@ class SaslTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
# TODO: check username/etc in this as well, so we can apply it to other tests # TODO: check username/etc in this as well, so we can apply it to other tests
# TODO: may be in the other order # TODO: may be in the other order
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='900', self.assertMessageEqual(
fail_msg='Expected 900 (RPL_LOGGEDIN) after successful ' m,
'login, but got: {msg}') command="900",
fail_msg="Expected 900 (RPL_LOGGEDIN) after successful "
"login, but got: {msg}",
)
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='903', self.assertMessageEqual(
fail_msg='Expected 903 (RPL_SASLSUCCESS) after successful ' m,
'login, but got: {msg}') command="903",
fail_msg="Expected 903 (RPL_SASLSUCCESS) after successful "
"login, but got: {msg}",
)
# TODO: add a test for when the length of the authstring is greater than 800. # TODO: add a test for when the length of the authstring is greater than 800.
# I don't know how to do it, because it would make the registration # I don't know how to do it, because it would make the registration
# message's length too big for it to be valid. # message's length too big for it to be valid.
@cases.SpecificationSelector.requiredBySpecification('IRCv3.1') @cases.SpecificationSelector.requiredBySpecification("IRCv3.1")
@cases.OptionalityHelper.skipUnlessHasMechanism('PLAIN') @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPlainLargeEquals400(self): def testPlainLargeEquals400(self):
"""Test the client splits large AUTHENTICATE messages whose payload """Test the client splits large AUTHENTICATE messages whose payload
is not a multiple of 400. is not a multiple of 400.
<http://ircv3.net/specs/extensions/sasl-3.1.html#the-authenticate-command> <http://ircv3.net/specs/extensions/sasl-3.1.html#the-authenticate-command>
""" """
self.controller.registerUser(self, 'foo', 'bar'*97) self.controller.registerUser(self, "foo", "bar" * 97)
authstring = base64.b64encode(b'\x00'.join( authstring = base64.b64encode(
[b'foo', b'foo', b'bar'*97])).decode() b"\x00".join([b"foo", b"foo", b"bar" * 97])
assert len(authstring) == 400, 'Bad test' ).decode()
assert len(authstring) == 400, "Bad test"
self.addClient() self.addClient()
self.sendLine(1, 'CAP LS 302') self.sendLine(1, "CAP LS 302")
capabilities = self.getCapLs(1) capabilities = self.getCapLs(1)
self.assertIn('sasl', capabilities, self.assertIn(
fail_msg='Does not have SASL as the controller claims.') "sasl",
if capabilities['sasl'] is not None: capabilities,
self.assertIn('PLAIN', capabilities['sasl'], fail_msg="Does not have SASL as the controller claims.",
fail_msg='Does not have PLAIN mechanism as the controller ' )
'claims') if capabilities["sasl"] is not None:
self.sendLine(1, 'AUTHENTICATE PLAIN') self.assertIn(
"PLAIN",
capabilities["sasl"],
fail_msg="Does not have PLAIN mechanism as the controller " "claims",
)
self.sendLine(1, "AUTHENTICATE PLAIN")
m = self.getRegistrationMessage(1) m = self.getRegistrationMessage(1)
self.assertMessageEqual(m, command='AUTHENTICATE', params=['+'], self.assertMessageEqual(
fail_msg='Sent “AUTHENTICATE PLAIN”, expected ' m,
'“AUTHENTICATE +” as a response, but got: {msg}') command="AUTHENTICATE",
self.sendLine(1, 'AUTHENTICATE {}'.format(authstring)) params=["+"],
self.sendLine(1, 'AUTHENTICATE +') fail_msg="Sent “AUTHENTICATE PLAIN”, expected "
"“AUTHENTICATE +” as a response, but got: {msg}",
)
self.sendLine(1, "AUTHENTICATE {}".format(authstring))
self.sendLine(1, "AUTHENTICATE +")
self.confirmSuccessfulAuth() self.confirmSuccessfulAuth()

View File

@ -1,41 +1,47 @@
from irctest import cases from irctest import cases
from irctest.numerics import RPL_NAMREPLY from irctest.numerics import RPL_NAMREPLY
class StatusmsgTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification('Oragono') class StatusmsgTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification("Oragono")
def testInIsupport(self): def testInIsupport(self):
"""Check that the expected STATUSMSG parameter appears in our isupport list.""" """Check that the expected STATUSMSG parameter appears in our isupport list."""
isupport = self.getISupport() isupport = self.getISupport()
self.assertEqual(isupport['STATUSMSG'], '~&@%+') self.assertEqual(isupport["STATUSMSG"], "~&@%+")
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testStatusmsg(self): def testStatusmsg(self):
"""Test that STATUSMSG are sent to the intended recipients, with the intended prefixes.""" """Test that STATUSMSG are sent to the intended recipients, with the intended prefixes."""
self.connectClient('chanop') self.connectClient("chanop")
self.joinChannel(1, '#chan') self.joinChannel(1, "#chan")
self.getMessages(1) self.getMessages(1)
self.connectClient('joe') self.connectClient("joe")
self.joinChannel(2, '#chan') self.joinChannel(2, "#chan")
self.getMessages(2) self.getMessages(2)
self.connectClient('schmoe') self.connectClient("schmoe")
self.sendLine(3, 'join #chan') self.sendLine(3, "join #chan")
messages = self.getMessages(3) messages = self.getMessages(3)
names = set() names = set()
for message in messages: for message in messages:
if message.command == RPL_NAMREPLY: if message.command == RPL_NAMREPLY:
names.update(set(message.params[-1].split())) names.update(set(message.params[-1].split()))
# chanop should be opped # chanop should be opped
self.assertEqual(names, {'@chanop', 'joe', 'schmoe'}, f'unexpected names: {names}') self.assertEqual(
names, {"@chanop", "joe", "schmoe"}, f"unexpected names: {names}"
)
self.sendLine(3, 'privmsg @#chan :this message is for operators') self.sendLine(3, "privmsg @#chan :this message is for operators")
self.getMessages(3) self.getMessages(3)
# check the operator's messages # check the operator's messages
statusMsg = self.getMessage(1, filter_pred=lambda m:m.command == 'PRIVMSG') statusMsg = self.getMessage(1, filter_pred=lambda m: m.command == "PRIVMSG")
self.assertMessageEqual(statusMsg, params=['@#chan', 'this message is for operators']) self.assertMessageEqual(
statusMsg, params=["@#chan", "this message is for operators"]
)
# check the non-operator's messages # check the non-operator's messages
unprivilegedMessages = [msg for msg in self.getMessages(2) if msg.command == 'PRIVMSG'] unprivilegedMessages = [
msg for msg in self.getMessages(2) if msg.command == "PRIVMSG"
]
self.assertEqual(len(unprivilegedMessages), 0) self.assertEqual(len(unprivilegedMessages), 0)

View File

@ -4,126 +4,157 @@ User commands as specified in Section 3.6 of RFC 2812:
""" """
from irctest import cases from irctest import cases
from irctest.numerics import RPL_WHOISUSER, RPL_WHOISCHANNELS, RPL_AWAY, RPL_NOWAWAY, RPL_UNAWAY from irctest.numerics import (
RPL_WHOISUSER,
RPL_WHOISCHANNELS,
RPL_AWAY,
RPL_NOWAWAY,
RPL_UNAWAY,
)
class WhoisTestCase(cases.BaseServerTestCase): class WhoisTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification("RFC2812")
@cases.SpecificationSelector.requiredBySpecification('RFC2812')
def testWhoisUser(self): def testWhoisUser(self):
"""Test basic WHOIS behavior""" """Test basic WHOIS behavior"""
nick = 'myCoolNickname' nick = "myCoolNickname"
username = 'myUsernam' # may be truncated if longer than this username = "myUsernam" # may be truncated if longer than this
realname = 'My Real Name' realname = "My Real Name"
self.addClient() self.addClient()
self.sendLine(1, f'NICK {nick}') self.sendLine(1, f"NICK {nick}")
self.sendLine(1, f'USER {username} 0 * :{realname}') self.sendLine(1, f"USER {username} 0 * :{realname}")
self.skipToWelcome(1) self.skipToWelcome(1)
self.connectClient('otherNickname') self.connectClient("otherNickname")
self.getMessages(2) self.getMessages(2)
self.sendLine(2, 'WHOIS mycoolnickname') self.sendLine(2, "WHOIS mycoolnickname")
messages = self.getMessages(2) messages = self.getMessages(2)
whois_user = messages[0] whois_user = messages[0]
self.assertEqual(whois_user.command, RPL_WHOISUSER) self.assertEqual(whois_user.command, RPL_WHOISUSER)
# "<client> <nick> <username> <host> * :<realname>" # "<client> <nick> <username> <host> * :<realname>"
self.assertEqual(whois_user.params[1], nick) self.assertEqual(whois_user.params[1], nick)
self.assertIn(whois_user.params[2], ('~' + username, username)) self.assertIn(whois_user.params[2], ("~" + username, username))
# dumb regression test for oragono/oragono#355: # dumb regression test for oragono/oragono#355:
self.assertNotIn(whois_user.params[3], [nick, username, '~' + username, realname]) self.assertNotIn(
whois_user.params[3], [nick, username, "~" + username, realname]
)
self.assertEqual(whois_user.params[5], realname) self.assertEqual(whois_user.params[5], realname)
class InvisibleTestCase(cases.BaseServerTestCase): class InvisibleTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification("Oragono")
@cases.SpecificationSelector.requiredBySpecification('Oragono')
def testInvisibleWhois(self): def testInvisibleWhois(self):
"""Test interaction between MODE +i and RPL_WHOISCHANNELS.""" """Test interaction between MODE +i and RPL_WHOISCHANNELS."""
self.connectClient('userOne') self.connectClient("userOne")
self.joinChannel(1, '#xyz') self.joinChannel(1, "#xyz")
self.connectClient('userTwo') self.connectClient("userTwo")
self.getMessages(2) self.getMessages(2)
self.sendLine(2, 'WHOIS userOne') self.sendLine(2, "WHOIS userOne")
commands = {m.command for m in self.getMessages(2)} commands = {m.command for m in self.getMessages(2)}
self.assertIn(RPL_WHOISCHANNELS, commands, self.assertIn(
'RPL_WHOISCHANNELS should be sent for a non-invisible nick') RPL_WHOISCHANNELS,
commands,
"RPL_WHOISCHANNELS should be sent for a non-invisible nick",
)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'MODE userOne +i') self.sendLine(1, "MODE userOne +i")
message = self.getMessage(1) message = self.getMessage(1)
self.assertEqual(message.command, 'MODE', self.assertEqual(
'Expected MODE reply, but received {}'.format(message.command)) message.command,
self.assertEqual(message.params, ['userOne', '+i'], "MODE",
'Expected user set +i, but received {}'.format(message.params)) "Expected MODE reply, but received {}".format(message.command),
)
self.assertEqual(
message.params,
["userOne", "+i"],
"Expected user set +i, but received {}".format(message.params),
)
self.getMessages(2) self.getMessages(2)
self.sendLine(2, 'WHOIS userOne') self.sendLine(2, "WHOIS userOne")
commands = {m.command for m in self.getMessages(2)} commands = {m.command for m in self.getMessages(2)}
self.assertNotIn(RPL_WHOISCHANNELS, commands, self.assertNotIn(
'RPL_WHOISCHANNELS should not be sent for an invisible nick' RPL_WHOISCHANNELS,
'unless the user is also a member of the channel') commands,
"RPL_WHOISCHANNELS should not be sent for an invisible nick"
"unless the user is also a member of the channel",
)
self.sendLine(2, 'JOIN #xyz') self.sendLine(2, "JOIN #xyz")
self.sendLine(2, 'WHOIS userOne') self.sendLine(2, "WHOIS userOne")
commands = {m.command for m in self.getMessages(2)} commands = {m.command for m in self.getMessages(2)}
self.assertIn(RPL_WHOISCHANNELS, commands, self.assertIn(
'RPL_WHOISCHANNELS should be sent for an invisible nick' RPL_WHOISCHANNELS,
'if the user is also a member of the channel') commands,
"RPL_WHOISCHANNELS should be sent for an invisible nick"
"if the user is also a member of the channel",
)
self.sendLine(2, 'PART #xyz') self.sendLine(2, "PART #xyz")
self.getMessages(2) self.getMessages(2)
self.getMessages(1) self.getMessages(1)
self.sendLine(1, 'MODE userOne -i') self.sendLine(1, "MODE userOne -i")
message = self.getMessage(1) message = self.getMessage(1)
self.assertEqual(message.command, 'MODE', self.assertEqual(
'Expected MODE reply, but received {}'.format(message.command)) message.command,
self.assertEqual(message.params, ['userOne', '-i'], "MODE",
'Expected user set -i, but received {}'.format(message.params)) "Expected MODE reply, but received {}".format(message.command),
)
self.assertEqual(
message.params,
["userOne", "-i"],
"Expected user set -i, but received {}".format(message.params),
)
self.sendLine(2, 'WHOIS userOne') self.sendLine(2, "WHOIS userOne")
commands = {m.command for m in self.getMessages(2)} commands = {m.command for m in self.getMessages(2)}
self.assertIn(RPL_WHOISCHANNELS, commands, self.assertIn(
'RPL_WHOISCHANNELS should be sent for a non-invisible nick') RPL_WHOISCHANNELS,
commands,
"RPL_WHOISCHANNELS should be sent for a non-invisible nick",
)
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testWhoisAccount(self): def testWhoisAccount(self):
"""Test numeric 330, RPL_WHOISACCOUNT.""" """Test numeric 330, RPL_WHOISACCOUNT."""
self.controller.registerUser(self, 'shivaram', 'sesame') self.controller.registerUser(self, "shivaram", "sesame")
self.connectClient('netcat') self.connectClient("netcat")
self.sendLine(1, 'NS IDENTIFY shivaram sesame') self.sendLine(1, "NS IDENTIFY shivaram sesame")
self.getMessages(1) self.getMessages(1)
self.connectClient('curious') self.connectClient("curious")
self.sendLine(2, 'WHOIS netcat') self.sendLine(2, "WHOIS netcat")
messages = self.getMessages(2) messages = self.getMessages(2)
# 330 RPL_WHOISACCOUNT # 330 RPL_WHOISACCOUNT
whoisaccount = [message for message in messages if message.command == '330'] whoisaccount = [message for message in messages if message.command == "330"]
self.assertEqual(len(whoisaccount), 1) self.assertEqual(len(whoisaccount), 1)
params = whoisaccount[0].params params = whoisaccount[0].params
# <client> <nick> <authname> :<info> # <client> <nick> <authname> :<info>
self.assertEqual(len(params), 4) self.assertEqual(len(params), 4)
self.assertEqual(params[:3], ['curious', 'netcat', 'shivaram']) self.assertEqual(params[:3], ["curious", "netcat", "shivaram"])
self.sendLine(1, 'WHOIS curious') self.sendLine(1, "WHOIS curious")
messages = self.getMessages(2) messages = self.getMessages(2)
whoisaccount = [message for message in messages if message.command == '330'] whoisaccount = [message for message in messages if message.command == "330"]
self.assertEqual(len(whoisaccount), 0) self.assertEqual(len(whoisaccount), 0)
class AwayTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification('RFC2812') class AwayTestCase(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification("RFC2812")
def testAway(self): def testAway(self):
self.connectClient('bar') self.connectClient("bar")
self.sendLine(1, "AWAY :I'm not here right now") self.sendLine(1, "AWAY :I'm not here right now")
replies = self.getMessages(1) replies = self.getMessages(1)
self.assertIn(RPL_NOWAWAY, [msg.command for msg in replies]) self.assertIn(RPL_NOWAWAY, [msg.command for msg in replies])
self.connectClient('qux') self.connectClient("qux")
self.sendLine(2, "PRIVMSG bar :what's up") self.sendLine(2, "PRIVMSG bar :what's up")
replies = self.getMessages(2) replies = self.getMessages(2)
self.assertEqual(len(replies), 1) self.assertEqual(len(replies), 1)
self.assertEqual(replies[0].command, RPL_AWAY) self.assertEqual(replies[0].command, RPL_AWAY)
self.assertEqual(replies[0].params, ['qux', 'bar', "I'm not here right now"]) self.assertEqual(replies[0].params, ["qux", "bar", "I'm not here right now"])
self.sendLine(1, "AWAY") self.sendLine(1, "AWAY")
replies = self.getMessages(1) replies = self.getMessages(1)
@ -133,31 +164,36 @@ class AwayTestCase(cases.BaseServerTestCase):
replies = self.getMessages(2) replies = self.getMessages(2)
self.assertEqual(len(replies), 0) self.assertEqual(len(replies), 0)
class TestNoCTCPMode(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification('Oragono') class TestNoCTCPMode(cases.BaseServerTestCase):
@cases.SpecificationSelector.requiredBySpecification("Oragono")
def testNoCTCPMode(self): def testNoCTCPMode(self):
self.connectClient('bar', 'bar') self.connectClient("bar", "bar")
self.connectClient('qux', 'qux') self.connectClient("qux", "qux")
# CTCP is not blocked by default: # CTCP is not blocked by default:
self.sendLine('qux', 'PRIVMSG bar :\x01VERSION\x01') self.sendLine("qux", "PRIVMSG bar :\x01VERSION\x01")
self.getMessages('qux') self.getMessages("qux")
relay = [msg for msg in self.getMessages('bar') if msg.command == 'PRIVMSG'][0] relay = [msg for msg in self.getMessages("bar") if msg.command == "PRIVMSG"][0]
self.assertEqual(relay.params[-1], '\x01VERSION\x01') self.assertEqual(relay.params[-1], "\x01VERSION\x01")
# set the no-CTCP user mode on bar: # set the no-CTCP user mode on bar:
self.sendLine('bar', 'MODE bar +T') self.sendLine("bar", "MODE bar +T")
replies = self.getMessages('bar') replies = self.getMessages("bar")
umode_line = [msg for msg in replies if msg.command == 'MODE'][0] umode_line = [msg for msg in replies if msg.command == "MODE"][0]
self.assertMessageEqual(umode_line, command='MODE', params=['bar', '+T']) self.assertMessageEqual(umode_line, command="MODE", params=["bar", "+T"])
# CTCP is now blocked: # CTCP is now blocked:
self.sendLine('qux', 'PRIVMSG bar :\x01VERSION\x01') self.sendLine("qux", "PRIVMSG bar :\x01VERSION\x01")
self.getMessages('qux') self.getMessages("qux")
self.assertEqual(self.getMessages('bar'), []) self.assertEqual(self.getMessages("bar"), [])
# normal PRIVMSG go through: # normal PRIVMSG go through:
self.sendLine('qux', 'PRIVMSG bar :please just tell me your client version') self.sendLine("qux", "PRIVMSG bar :please just tell me your client version")
self.getMessages('qux') self.getMessages("qux")
relay = self.getMessages('bar')[0] relay = self.getMessages("bar")[0]
self.assertMessageEqual(relay, command='PRIVMSG', nick='qux', params=['bar', 'please just tell me your client version']) self.assertMessageEqual(
relay,
command="PRIVMSG",
nick="qux",
params=["bar", "please just tell me your client version"],
)

View File

@ -1,23 +1,29 @@
from irctest import cases from irctest import cases
class Utf8TestCase(cases.BaseServerTestCase, cases.OptionalityHelper): class Utf8TestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testUtf8Validation(self): def testUtf8Validation(self):
self.connectClient('bar', capabilities=['batch', 'echo-message', 'labeled-response', 'message-tags']) self.connectClient(
self.joinChannel(1, '#qux') "bar",
self.sendLine(1, 'PRIVMSG #qux hi') capabilities=["batch", "echo-message", "labeled-response", "message-tags"],
)
self.joinChannel(1, "#qux")
self.sendLine(1, "PRIVMSG #qux hi")
ms = self.getMessages(1) ms = self.getMessages(1)
self.assertMessageEqual([m for m in ms if m.command == 'PRIVMSG'][0], params=['#qux', 'hi']) self.assertMessageEqual(
[m for m in ms if m.command == "PRIVMSG"][0], params=["#qux", "hi"]
)
self.sendLine(1, b'PRIVMSG #qux hi\xaa') self.sendLine(1, b"PRIVMSG #qux hi\xaa")
ms = self.getMessages(1) ms = self.getMessages(1)
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
self.assertEqual(ms[0].command, 'FAIL') self.assertEqual(ms[0].command, "FAIL")
self.assertEqual(ms[0].params[:2], ['PRIVMSG', 'INVALID_UTF8']) self.assertEqual(ms[0].params[:2], ["PRIVMSG", "INVALID_UTF8"])
self.sendLine(1, b'@label=xyz PRIVMSG #qux hi\xaa') self.sendLine(1, b"@label=xyz PRIVMSG #qux hi\xaa")
ms = self.getMessages(1) ms = self.getMessages(1)
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
self.assertEqual(ms[0].command, 'FAIL') self.assertEqual(ms[0].command, "FAIL")
self.assertEqual(ms[0].params[:2], ['PRIVMSG', 'INVALID_UTF8']) self.assertEqual(ms[0].params[:2], ["PRIVMSG", "INVALID_UTF8"])
self.assertEqual(ms[0].tags.get('label'), 'xyz') self.assertEqual(ms[0].tags.get("label"), "xyz")

View File

@ -10,7 +10,7 @@ def extract_playback_privmsgs(messages):
# convert the output of a playback command, drop the echo message # convert the output of a playback command, drop the echo message
result = [] result = []
for msg in messages: for msg in messages:
if msg.command == 'PRIVMSG' and msg.params[0].lower() != '*playback': if msg.command == "PRIVMSG" and msg.params[0].lower() != "*playback":
result.append(to_history_message(msg)) result.append(to_history_message(msg))
return result return result
@ -22,91 +22,197 @@ class ZncPlaybackTestCase(cases.BaseServerTestCase):
"chathistory": True, "chathistory": True,
} }
@cases.SpecificationSelector.requiredBySpecification('Oragono') @cases.SpecificationSelector.requiredBySpecification("Oragono")
def testZncPlayback(self): def testZncPlayback(self):
early_time = int(time.time() - 60) early_time = int(time.time() - 60)
chname = random_name('#znc_channel') chname = random_name("#znc_channel")
bar, pw = random_name('bar'), random_name('pass') bar, pw = random_name("bar"), random_name("pass")
self.controller.registerUser(self, bar, pw) self.controller.registerUser(self, bar, pw)
self.connectClient(bar, name=bar, capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message'], password=pw) self.connectClient(
bar,
name=bar,
capabilities=[
"batch",
"labeled-response",
"message-tags",
"server-time",
"echo-message",
],
password=pw,
)
self.joinChannel(bar, chname) self.joinChannel(bar, chname)
qux = random_name('qux') qux = random_name("qux")
self.connectClient(qux, name=qux, capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message']) self.connectClient(
qux,
name=qux,
capabilities=[
"batch",
"labeled-response",
"message-tags",
"server-time",
"echo-message",
],
)
self.joinChannel(qux, chname) self.joinChannel(qux, chname)
self.sendLine(qux, 'PRIVMSG %s :hi there' % (bar,)) self.sendLine(qux, "PRIVMSG %s :hi there" % (bar,))
dm = to_history_message([msg for msg in self.getMessages(qux) if msg.command == 'PRIVMSG'][0]) dm = to_history_message(
self.assertEqual(dm.text, 'hi there') [msg for msg in self.getMessages(qux) if msg.command == "PRIVMSG"][0]
)
self.assertEqual(dm.text, "hi there")
NUM_MESSAGES = 10 NUM_MESSAGES = 10
echo_messages = [] echo_messages = []
for i in range(NUM_MESSAGES): for i in range(NUM_MESSAGES):
self.sendLine(qux, 'PRIVMSG %s :this is message %d' % (chname, i)) self.sendLine(qux, "PRIVMSG %s :this is message %d" % (chname, i))
echo_messages.extend(to_history_message(msg) for msg in self.getMessages(qux) if msg.command == 'PRIVMSG') echo_messages.extend(
to_history_message(msg)
for msg in self.getMessages(qux)
if msg.command == "PRIVMSG"
)
time.sleep(0.003) time.sleep(0.003)
self.assertEqual(len(echo_messages), NUM_MESSAGES) self.assertEqual(len(echo_messages), NUM_MESSAGES)
self.getMessages(bar) self.getMessages(bar)
# reattach to 'bar' # reattach to 'bar'
self.connectClient(bar, name='viewer', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message'], password=pw) self.connectClient(
self.sendLine('viewer', 'PRIVMSG *playback :play * %d' % (early_time,)) bar,
messages = extract_playback_privmsgs(self.getMessages('viewer')) name="viewer",
capabilities=[
"batch",
"labeled-response",
"message-tags",
"server-time",
"echo-message",
],
password=pw,
)
self.sendLine("viewer", "PRIVMSG *playback :play * %d" % (early_time,))
messages = extract_playback_privmsgs(self.getMessages("viewer"))
self.assertEqual(set(messages), set([dm] + echo_messages)) self.assertEqual(set(messages), set([dm] + echo_messages))
self.sendLine('viewer', 'QUIT') self.sendLine("viewer", "QUIT")
self.assertDisconnected('viewer') self.assertDisconnected("viewer")
# reattach to 'bar', play back selectively # reattach to 'bar', play back selectively
self.connectClient(bar, name='viewer', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message'], password=pw) self.connectClient(
bar,
name="viewer",
capabilities=[
"batch",
"labeled-response",
"message-tags",
"server-time",
"echo-message",
],
password=pw,
)
mid_timestamp = ircv3_timestamp_to_unixtime(echo_messages[5].time) mid_timestamp = ircv3_timestamp_to_unixtime(echo_messages[5].time)
# exclude message 5 itself (oragono's CHATHISTORY implementation corrects for this, but znc.in/playback does not because whatever) # exclude message 5 itself (oragono's CHATHISTORY implementation corrects for this, but znc.in/playback does not because whatever)
mid_timestamp += .001 mid_timestamp += 0.001
self.sendLine('viewer', 'PRIVMSG *playback :play * %s' % (mid_timestamp,)) self.sendLine("viewer", "PRIVMSG *playback :play * %s" % (mid_timestamp,))
messages = extract_playback_privmsgs(self.getMessages('viewer')) messages = extract_playback_privmsgs(self.getMessages("viewer"))
self.assertEqual(messages, echo_messages[6:]) self.assertEqual(messages, echo_messages[6:])
self.sendLine('viewer', 'QUIT') self.sendLine("viewer", "QUIT")
self.assertDisconnected('viewer') self.assertDisconnected("viewer")
# reattach to 'bar', play back selectively (pass a parameter and 2 timestamps) # reattach to 'bar', play back selectively (pass a parameter and 2 timestamps)
self.connectClient(bar, name='viewer', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message'], password=pw) self.connectClient(
bar,
name="viewer",
capabilities=[
"batch",
"labeled-response",
"message-tags",
"server-time",
"echo-message",
],
password=pw,
)
start_timestamp = ircv3_timestamp_to_unixtime(echo_messages[2].time) start_timestamp = ircv3_timestamp_to_unixtime(echo_messages[2].time)
start_timestamp += .001 start_timestamp += 0.001
end_timestamp = ircv3_timestamp_to_unixtime(echo_messages[7].time) end_timestamp = ircv3_timestamp_to_unixtime(echo_messages[7].time)
self.sendLine('viewer', 'PRIVMSG *playback :play %s %s %s' % (chname, start_timestamp, end_timestamp,)) self.sendLine(
messages = extract_playback_privmsgs(self.getMessages('viewer')) "viewer",
"PRIVMSG *playback :play %s %s %s"
% (
chname,
start_timestamp,
end_timestamp,
),
)
messages = extract_playback_privmsgs(self.getMessages("viewer"))
self.assertEqual(messages, echo_messages[3:7]) self.assertEqual(messages, echo_messages[3:7])
# test nicknames as targets # test nicknames as targets
self.sendLine('viewer', 'PRIVMSG *playback :play %s %d' % (qux, early_time,)) self.sendLine(
messages = extract_playback_privmsgs(self.getMessages('viewer')) "viewer",
"PRIVMSG *playback :play %s %d"
% (
qux,
early_time,
),
)
messages = extract_playback_privmsgs(self.getMessages("viewer"))
self.assertEqual(messages, [dm]) self.assertEqual(messages, [dm])
self.sendLine('viewer', 'PRIVMSG *playback :play %s %d' % (qux.upper(), early_time,)) self.sendLine(
messages = extract_playback_privmsgs(self.getMessages('viewer')) "viewer",
"PRIVMSG *playback :play %s %d"
% (
qux.upper(),
early_time,
),
)
messages = extract_playback_privmsgs(self.getMessages("viewer"))
self.assertEqual(messages, [dm]) self.assertEqual(messages, [dm])
self.sendLine('viewer', 'QUIT') self.sendLine("viewer", "QUIT")
self.assertDisconnected('viewer') self.assertDisconnected("viewer")
# test 2-argument form # test 2-argument form
self.connectClient(bar, name='viewer', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message'], password=pw) self.connectClient(
self.sendLine('viewer', 'PRIVMSG *playback :play %s' % (chname,)) bar,
messages = extract_playback_privmsgs(self.getMessages('viewer')) name="viewer",
capabilities=[
"batch",
"labeled-response",
"message-tags",
"server-time",
"echo-message",
],
password=pw,
)
self.sendLine("viewer", "PRIVMSG *playback :play %s" % (chname,))
messages = extract_playback_privmsgs(self.getMessages("viewer"))
self.assertEqual(messages, echo_messages) self.assertEqual(messages, echo_messages)
self.sendLine('viewer', 'PRIVMSG *playback :play *self') self.sendLine("viewer", "PRIVMSG *playback :play *self")
messages = extract_playback_privmsgs(self.getMessages('viewer')) messages = extract_playback_privmsgs(self.getMessages("viewer"))
self.assertEqual(messages, [dm]) self.assertEqual(messages, [dm])
self.sendLine('viewer', 'PRIVMSG *playback :play *') self.sendLine("viewer", "PRIVMSG *playback :play *")
messages = extract_playback_privmsgs(self.getMessages('viewer')) messages = extract_playback_privmsgs(self.getMessages("viewer"))
self.assertEqual(set(messages), set([dm] + echo_messages)) self.assertEqual(set(messages), set([dm] + echo_messages))
self.sendLine('viewer', 'QUIT') self.sendLine("viewer", "QUIT")
self.assertDisconnected('viewer') self.assertDisconnected("viewer")
# test limiting behavior # test limiting behavior
config = self.controller.getConfig() config = self.controller.getConfig()
config['history']['znc-maxmessages'] = 5 config["history"]["znc-maxmessages"] = 5
self.controller.rehash(self, config) self.controller.rehash(self, config)
self.connectClient(bar, name='viewer', capabilities=['batch', 'labeled-response', 'message-tags', 'server-time', 'echo-message'], password=pw) self.connectClient(
self.sendLine('viewer', 'PRIVMSG *playback :play %s %d' % (chname, int(time.time() - 60))) bar,
messages = extract_playback_privmsgs(self.getMessages('viewer')) name="viewer",
capabilities=[
"batch",
"labeled-response",
"message-tags",
"server-time",
"echo-message",
],
password=pw,
)
self.sendLine(
"viewer", "PRIVMSG *playback :play %s %d" % (chname, int(time.time() - 60))
)
messages = extract_playback_privmsgs(self.getMessages("viewer"))
# should receive the latest 5 messages # should receive the latest 5 messages
self.assertEqual(messages, echo_messages[5:]) self.assertEqual(messages, echo_messages[5:])

View File

@ -1,16 +1,17 @@
import enum import enum
@enum.unique @enum.unique
class Specifications(enum.Enum): class Specifications(enum.Enum):
RFC1459 = 'RFC1459' RFC1459 = "RFC1459"
RFC2812 = 'RFC2812' RFC2812 = "RFC2812"
RFCDeprecated = 'RFC-deprecated' RFCDeprecated = "RFC-deprecated"
IRC301 = 'IRCv3.1' IRC301 = "IRCv3.1"
IRC302 = 'IRCv3.2' IRC302 = "IRCv3.2"
IRC302Deprecated = 'IRCv3.2-deprecated' IRC302Deprecated = "IRCv3.2-deprecated"
Oragono = 'Oragono' Oragono = "Oragono"
Multiline = 'multiline' Multiline = "multiline"
MessageTags = 'message-tags' MessageTags = "message-tags"
@classmethod @classmethod
def of_name(cls, name): def of_name(cls, name):

View File

@ -1,4 +1,3 @@
import collections import collections
TlsConfig = collections.namedtuple('TlsConfig', TlsConfig = collections.namedtuple("TlsConfig", "enable trusted_fingerprints")
'enable trusted_fingerprints')

2
pyproject.toml Normal file
View File

@ -0,0 +1,2 @@
[tool.black]
target-version = ['py37']

6
setup.cfg Normal file
View File

@ -0,0 +1,6 @@
[flake8]
# E203: whitespaces before ':' <https://github.com/psf/black/issues/315>
# E231: missing whitespace after ','
# W503: line break before binary operator <https://github.com/psf/black/issues/52>
ignore = E203,E231,W503
max-line-length = 88