2021-02-22 18:04:23 +00:00
|
|
|
import functools
|
2015-12-19 00:11:57 +00:00
|
|
|
import socket
|
2021-02-22 18:04:23 +00:00
|
|
|
import ssl
|
2016-07-20 09:41:35 +00:00
|
|
|
import tempfile
|
2021-02-22 18:04:23 +00:00
|
|
|
import time
|
2021-02-28 16:08:27 +00:00
|
|
|
from typing import (
|
|
|
|
Any,
|
|
|
|
Callable,
|
|
|
|
Container,
|
|
|
|
Dict,
|
|
|
|
Generic,
|
|
|
|
Hashable,
|
|
|
|
Iterable,
|
|
|
|
List,
|
|
|
|
Optional,
|
|
|
|
Set,
|
|
|
|
Tuple,
|
|
|
|
Type,
|
|
|
|
TypeVar,
|
|
|
|
Union,
|
|
|
|
cast,
|
|
|
|
)
|
2015-12-19 00:11:57 +00:00
|
|
|
import unittest
|
|
|
|
|
2021-02-16 23:51:47 +00:00
|
|
|
import pytest
|
|
|
|
|
2021-02-28 18:09:32 +00:00
|
|
|
from . import basecontrollers, client_mock, patma, runner, tls
|
2021-02-28 16:08:27 +00:00
|
|
|
from .authentication import Authentication
|
2021-02-28 12:40:08 +00:00
|
|
|
from .basecontrollers import TestCaseControllerConfig
|
2021-02-22 18:04:23 +00:00
|
|
|
from .exceptions import ConnectionClosed
|
|
|
|
from .irc_utils import capabilities, message_parser
|
2021-06-26 22:26:41 +00:00
|
|
|
from .irc_utils.junkdrawer import find_hostname_and_port, normalizeWhitespace
|
2021-02-28 16:08:27 +00:00
|
|
|
from .irc_utils.message_parser import Message
|
2020-02-17 09:05:21 +00:00
|
|
|
from .irc_utils.sasl import sasl_plain_blob
|
2021-02-22 18:02:13 +00:00
|
|
|
from .numerics import (
|
|
|
|
ERR_BADCHANNELKEY,
|
|
|
|
ERR_BANNEDFROMCHAN,
|
2021-02-22 18:04:23 +00:00
|
|
|
ERR_INVITEONLYCHAN,
|
2021-02-22 18:02:13 +00:00
|
|
|
ERR_NEEDREGGEDNICK,
|
2021-02-22 18:04:23 +00:00
|
|
|
ERR_NOSUCHCHANNEL,
|
|
|
|
ERR_TOOMANYCHANNELS,
|
2021-02-22 18:02:13 +00:00
|
|
|
)
|
2021-02-27 09:36:30 +00:00
|
|
|
from .specifications import Capabilities, IsupportTokens, Specifications
|
2021-02-22 18:02:13 +00:00
|
|
|
|
2021-06-26 16:37:44 +00:00
|
|
|
# pytest omits frames whose globals/locals set this flag from failure tracebacks.
__tracebackhide__: bool = True
|
|
|
|
|
2021-02-22 18:02:13 +00:00
|
|
|
# Numerics a server may answer a failed JOIN with. joinChannel() raises
# ChannelJoinException as soon as it receives one of these, instead of
# waiting for a JOIN echo that will never come.
CHANNEL_JOIN_FAIL_NUMERICS = frozenset(
    [
        ERR_NOSUCHCHANNEL,
        ERR_TOOMANYCHANNELS,
        ERR_BADCHANNELKEY,
        ERR_INVITEONLYCHAN,
        ERR_BANNEDFROMCHAN,
        ERR_NEEDREGGEDNICK,
        # ERR_CHANNELISFULL is not imported from .numerics in this file, so it
        # is spelled as its literal numeric (commands are compared as strings).
        # Without it, joining a full channel would keep joinChannel() waiting.
        "471",
    ]
)
|
2021-02-18 04:27:48 +00:00
|
|
|
|
2021-02-28 16:08:27 +00:00
|
|
|
# General-purpose type variable.
T = TypeVar("T")

# Type variable for decorators; bound to callables.
TCallable = TypeVar("TCallable", bound=Callable)

# Type variable for the names tests give to their mock clients
# (usually int or str).
TClientName = TypeVar("TClientName", bound=Union[Hashable, int])

# Type variable for the controller type a test case is parametrized with.
TController = TypeVar("TController", bound=basecontrollers._BaseController)
|
|
|
|
|
2021-02-18 04:27:48 +00:00
|
|
|
|
|
|
|
class ChannelJoinException(Exception):
    """Raised when the server replies to a JOIN with a failure numeric.

    Attributes:
        code: the numeric the server replied with.
        params: the parameters of that reply.
    """

    def __init__(self, code: str, params: List[str]):
        self.code = code
        self.params = params
        super().__init__(f"Failed to join channel ({code}): {params}")
|
2015-12-19 07:43:45 +00:00
|
|
|
|
2021-02-22 18:02:13 +00:00
|
|
|
|
2021-02-28 16:08:27 +00:00
|
|
|
class _IrcTestCase(unittest.TestCase, Generic[TController]):
    """Base class for test cases.

    Besides the controller plumbing, several ``assert*`` methods are
    overridden to accept a ``fail_msg`` template. When given, the template
    is formatted with ``extra_format`` as positional arguments plus
    method-specific named fields, and the result is used as the failure
    message.
    """

    # Will be set by __main__.py
    controllerClass: Type[TController]
    show_io: bool

    controller: TController

    @staticmethod
    def config() -> TestCaseControllerConfig:
        """Some configuration to pass to the controllers.
        For example, Oragono only enables its MySQL support if
        config()["chathistory"]=True.
        """
        return TestCaseControllerConfig()

    def description(self) -> str:
        """Return the current test method's docstring, whitespace-normalized
        and tab-indented, or an empty string if it has no docstring."""
        method_doc = self._testMethodDoc
        if not method_doc:
            return ""
        return "\t" + normalizeWhitespace(
            method_doc, removeNewline=False
        ).strip().replace("\n ", "\n\t")

    def setUp(self) -> None:
        """Instantiate the controller (when a controller class is set) with
        this test case's config()."""
        super().setUp()
        if self.controllerClass is not None:
            self.controller = self.controllerClass(self.config())
        if self.show_io:
            print("---- new test ----")

    def assertMessageMatch(self, msg: Message, **kwargs: Any) -> None:
        """Helper for partially comparing a message.

        Takes the message as first arguments, and comparisons to be made
        as keyword arguments (see `messageDiffers`).

        Uses patma.match_list on the params argument, and patma.match_dict
        on the tags argument.

        Raises:
            self.failureException: if the message does not match.
        """
        error = self.messageDiffers(msg, **kwargs)
        if error:
            raise self.failureException(error)

    def messageEqual(self, msg: Message, **kwargs: Any) -> bool:
        """Boolean negation of `messageDiffers` (returns a boolean,
        not an optional string)."""
        return not self.messageDiffers(msg, **kwargs)

    def messageDiffers(
        self,
        msg: Message,
        params: Optional[List[Union[str, None, patma.Operator]]] = None,
        target: Optional[str] = None,
        tags: Optional[
            Dict[Union[str, patma.Operator], Union[str, patma.Operator, None]]
        ] = None,
        nick: Optional[str] = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
        **kwargs: Any,
    ) -> Optional[str]:
        """Returns an error message if the message doesn't match the given arguments,
        or None if it matches.

        Checks run in this order; the first mismatch is reported:

        * each extra keyword argument is compared (``!=``) with the message
          attribute of the same name;
        * ``params``, if non-empty, must match ``msg.params`` according to
          ``patma.match_list``;
        * ``tags``, if non-empty, must match ``msg.tags`` according to
          ``patma.match_dict``;
        * ``nick``, if non-empty, must equal the part of ``msg.prefix``
          before the first ``!``.

        ``target`` is accepted but not currently checked.

        ``fail_msg``, if given, replaces the default error template. It is
        formatted with ``extra_format`` positionally and with the named
        fields ``got``, ``expects``, ``param`` and ``msg``.
        """
        for (key, value) in kwargs.items():
            if getattr(msg, key) != value:
                fail_msg = (
                    fail_msg or "expected {param} to be {expects}, got {got}: {msg}"
                )
                return fail_msg.format(
                    *extra_format,
                    got=getattr(msg, key),
                    expects=value,
                    param=key,
                    msg=msg,
                )

        if params and not patma.match_list(list(msg.params), params):
            fail_msg = (
                fail_msg or "expected params to match {expects}, got {got}: {msg}"
            )
            return fail_msg.format(
                *extra_format, got=msg.params, expects=params, param="params", msg=msg
            )

        if tags and not patma.match_dict(msg.tags, tags):
            fail_msg = fail_msg or "expected tags to match {expects}, got {got}: {msg}"
            return fail_msg.format(
                *extra_format, got=msg.tags, expects=tags, param="tags", msg=msg
            )

        if nick:
            got_nick = msg.prefix.split("!")[0] if msg.prefix else None
            if nick != got_nick:
                fail_msg = (
                    fail_msg
                    or "expected nick to be {expects}, got {got} instead: {msg}"
                )
                # Use a literal field name: the kwargs loop variable `key` is
                # unbound here when no extra keyword arguments were passed.
                return fail_msg.format(
                    *extra_format, got=got_nick, expects=nick, param="nick", msg=msg
                )

        return None

    def assertIn(
        self,
        member: Any,
        container: Union[Iterable[Any], Container[Any]],
        msg: Optional[str] = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """TestCase.assertIn; ``fail_msg`` may use {item}, {list} and {msg}."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, item=member, list=container, msg=msg)
        super().assertIn(member, container, msg)

    def assertNotIn(
        self,
        member: Any,
        container: Union[Iterable[Any], Container[Any]],
        msg: Optional[str] = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """TestCase.assertNotIn; ``fail_msg`` may use {item}, {list} and {msg}."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, item=member, list=container, msg=msg)
        super().assertNotIn(member, container, msg)

    def assertEqual(
        self,
        got: T,
        expects: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """TestCase.assertEqual; ``fail_msg`` may use {got}, {expects}, {msg}."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
        super().assertEqual(got, expects, msg)

    def assertNotEqual(
        self,
        got: T,
        expects: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """TestCase.assertNotEqual; ``fail_msg`` may use {got}, {expects}, {msg}."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
        super().assertNotEqual(got, expects, msg)

    def assertGreater(
        self,
        got: T,
        expects: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """TestCase.assertGreater; ``fail_msg`` may use {got}, {expects}, {msg}."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
        super().assertGreater(got, expects, msg)

    def assertGreaterEqual(
        self,
        got: T,
        expects: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """TestCase.assertGreaterEqual; ``fail_msg`` may use {got}, {expects},
        {msg}."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
        super().assertGreaterEqual(got, expects, msg)

    def assertLess(
        self,
        got: T,
        expects: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """TestCase.assertLess; ``fail_msg`` may use {got}, {expects}, {msg}."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
        super().assertLess(got, expects, msg)

    def assertLessEqual(
        self,
        got: T,
        expects: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """TestCase.assertLessEqual; ``fail_msg`` may use {got}, {expects},
        {msg}."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
        super().assertLessEqual(got, expects, msg)
|
2021-02-26 20:06:17 +00:00
|
|
|
|
2021-02-22 18:02:13 +00:00
|
|
|
|
2021-02-28 16:08:27 +00:00
|
|
|
class BaseClientTestCase(_IrcTestCase[basecontrollers.BaseClientController]):
    """Basic class for client tests. Handles spawning a client and exchanging
    messages with it.

    The test case plays the server: it listens on a local port
    (``self.server``), the client under test connects to it, and lines are
    exchanged over ``self.conn``.
    """

    conn: Optional[socket.socket]
    nick: Optional[str] = None
    user: Optional[List[str]] = None
    server: socket.socket
    # Set by readCapLs(): 301 for "CAP LS", 302 for "CAP LS 302", None when
    # the client sent "CAP END" first.
    protocol_version: Optional[int]
    # Set by negotiateCapabilities(): the capabilities ACKed to the client.
    acked_capabilities: Optional[Set[str]]

    def setUp(self) -> None:
        super().setUp()
        self.conn = None
        self._setUpServer()

    def tearDown(self) -> None:
        """Send a QUIT line to the client, kill it, and close all sockets."""
        if self.conn:
            try:
                # IRC lines must be CRLF-terminated, otherwise the client
                # never sees a complete message.
                self.conn.sendall(b"QUIT :end of test.\r\n")
            except OSError:
                # Includes BrokenPipeError (client already disconnected), and
                # a conn that was already closed by the test.
                pass
        self.controller.kill()
        if self.conn:
            # conn_file is only created at the end of acceptClient(), so it may
            # be missing, e.g. if wrapping the socket in TLS failed.
            conn_file = getattr(self, "conn_file", None)
            if conn_file is not None:
                conn_file.close()
            self.conn.close()
        self.server.close()

    def _setUpServer(self) -> None:
        """Creates the server and make it listen."""
        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server.bind(("", 0))  # Bind any free port
        self.server.listen(1)

        # Used to check if the client is alive from time to time
        self.server.settimeout(1)

    def acceptClient(
        self,
        tls_cert: Optional[str] = None,
        tls_key: Optional[str] = None,
        server: Optional[socket.socket] = None,
    ) -> None:
        """Make the server accept a client connection. Blocking.

        Each time the listening socket times out, checks that the client
        process is still alive. If ``tls_cert`` and ``tls_key`` are given
        (both are required together), they are written to temporary files
        and the connection is wrapped in server-side TLS. Sets ``self.conn``
        and ``self.conn_file``.
        """
        server = server or self.server
        assert server
        # Wait for the client to connect
        while True:
            try:
                (self.conn, _addr) = server.accept()
            except socket.timeout:
                self.controller.check_is_alive()
            else:
                break
        if tls_cert is not None or tls_key is not None:
            assert (
                tls_cert and tls_key
            ), "tls_cert must be provided if and only if tls_key is."
            with tempfile.NamedTemporaryFile(
                "at"
            ) as certfile, tempfile.NamedTemporaryFile("at") as keyfile:
                certfile.write(tls_cert)
                certfile.seek(0)
                keyfile.write(tls_key)
                keyfile.seek(0)
                context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
                context.load_cert_chain(certfile=certfile.name, keyfile=keyfile.name)
                self.conn = context.wrap_socket(self.conn, server_side=True)
        self.conn_file = self.conn.makefile(newline="\r\n", encoding="utf8")

    def getLine(self) -> str:
        """Read one line from the client (empty string on EOF)."""
        line = self.conn_file.readline()
        if self.show_io:
            print("{:.3f} C: {}".format(time.time(), line.strip()))
        return line

    def getMessage(
        self, *args: Any, filter_pred: Optional[Callable[[Message], bool]] = None
    ) -> Message:
        """Gets a message and returns it. If a filter predicate is given,
        fetches messages until the predicate returns a True on a message,
        and returns this message.

        Raises:
            ConnectionClosed: if the client closed the connection.
        """
        while True:
            line = self.getLine(*args)
            if not line:
                raise ConnectionClosed()
            msg = message_parser.parse_message(line)
            if not filter_pred or filter_pred(msg):
                return msg

    def sendLine(self, line: str) -> None:
        """Send ``line`` to the client, appending CRLF if it lacks one."""
        assert self.conn
        data = line.encode()
        if not line.endswith("\r\n"):
            data += b"\r\n"
        # One write, so the line is never split between two sends.
        self.conn.sendall(data)
        if self.show_io:
            print("{:.3f} S: {}".format(time.time(), line.strip()))

    def readCapLs(
        self,
        auth: Optional[Authentication] = None,
        tls_config: Optional[tls.TlsConfig] = None,
    ) -> None:
        """Start the client, accept its connection, and read its first
        message, which must be CAP LS (sets protocol_version to 301 or 302)
        or CAP END (sets protocol_version to None)."""
        (hostname, port) = self.server.getsockname()
        self.controller.run(
            hostname=hostname, port=port, auth=auth, tls_config=tls_config
        )
        self.acceptClient()
        m = self.getMessage()
        self.assertEqual(m.command, "CAP", "First message is not CAP LS.")
        if m.params == ["LS"]:
            self.protocol_version = 301
        elif m.params == ["LS", "302"]:
            self.protocol_version = 302
        elif m.params == ["END"]:
            self.protocol_version = None
        else:
            raise AssertionError("Unknown CAP params: {}".format(m.params))

    def userNickPredicate(self, msg: Message) -> bool:
        """Predicate to be used with getMessage to handle NICK/USER
        transparently: records them in self.nick / self.user and returns
        False for them, True for any other message."""
        if msg.command == "NICK":
            self.assertEqual(len(msg.params), 1, msg=msg)
            self.nick = msg.params[0]
            return False
        elif msg.command == "USER":
            self.assertEqual(len(msg.params), 4, msg=msg)
            self.user = msg.params
            return False
        else:
            return True

    def negotiateCapabilities(
        self,
        caps: List[str],
        cap_ls: bool = True,
        auth: Optional[Authentication] = None,
    ) -> Optional[Message]:
        """Performs a complete capability negotiation process, without
        ending it, so the caller can continue the negotiation.

        Answers each CAP REQ with an ACK if every requested capability is in
        ``caps`` (else a NAK), and returns the first message that is not a
        CAP REQ (NICK/USER are consumed). Returns None if ``cap_ls`` is set
        and the client did not start a negotiation.
        """
        if cap_ls:
            self.readCapLs(auth)
            if not self.protocol_version:
                # No negotiation.
                return None
            self.sendLine("CAP * LS :{}".format(" ".join(caps)))
        capability_names = frozenset(capabilities.cap_list_to_dict(caps))
        acked: Set[str] = set()
        self.acked_capabilities = acked
        while True:
            m = self.getMessage(filter_pred=self.userNickPredicate)
            if m.command != "CAP":
                return m
            self.assertGreater(len(m.params), 0, m)
            if m.params[0] == "REQ":
                self.assertEqual(len(m.params), 2, m)
                requested = frozenset(m.params[1].split())
                if not requested.issubset(capability_names):
                    self.sendLine(
                        "CAP {} NAK :{}".format(self.nick or "*", m.params[1][0:100])
                    )
                else:
                    self.sendLine(
                        "CAP {} ACK :{}".format(self.nick or "*", m.params[1])
                    )
                    acked.update(requested)
            else:
                return m
|
|
|
|
|
2015-12-19 22:09:06 +00:00
|
|
|
|
2021-02-28 16:08:27 +00:00
|
|
|
class BaseServerTestCase(
    _IrcTestCase[basecontrollers.BaseServerController], Generic[TClientName]
):
    """Basic class for server tests. Handles spawning a server and exchanging
    messages with it.

    Mock clients are kept in ``self.clients``, keyed by a test-chosen name.
    """

    show_io: bool  # set by conftest.py

    password: Optional[str] = None
    ssl = False
    # Class-level defaults; only read here (passed to the controller).
    valid_metadata_keys: Set[str] = set()
    invalid_metadata_keys: Set[str] = set()
    # ISUPPORT tokens (from 005 replies) seen by the last connectClient().
    server_support: Optional[Dict[str, Optional[str]]]
    run_services = False

    def setUp(self) -> None:
        super().setUp()
        self.server_support = None
        (self.hostname, self.port) = find_hostname_and_port()
        self.controller.run(
            self.hostname,
            self.port,
            password=self.password,
            valid_metadata_keys=self.valid_metadata_keys,
            invalid_metadata_keys=self.invalid_metadata_keys,
            ssl=self.ssl,
            run_services=self.run_services,
        )
        self.clients: Dict[TClientName, client_mock.ClientMock] = {}

    def tearDown(self) -> None:
        self.controller.kill()
        for client in list(self.clients):
            self.removeClient(client)

    def addClient(
        self, name: Optional[TClientName] = None, show_io: Optional[bool] = None
    ) -> TClientName:
        """Connects a client to the server and adds it to the dict.
        If 'name' is not given, uses one more than the largest integer name
        already in use (integer-like strings count too), starting at 1."""
        self.controller.wait_for_port()
        if self.run_services:
            self.controller.wait_for_services()
        if name is None:
            # Only consider names int() can parse: tests mixing numeric and
            # non-numeric string names (e.g. "bar") must not crash here.
            numeric_names = [
                int(existing)
                for existing in self.clients
                if isinstance(existing, int)
                or (isinstance(existing, str) and existing.isdecimal())
            ]
            new_name: int = max(numeric_names + [0]) + 1
            name = cast(TClientName, new_name)
        show_io = show_io if show_io is not None else self.show_io
        self.clients[name] = client_mock.ClientMock(name=name, show_io=show_io)
        self.clients[name].connect(self.hostname, self.port)
        return name

    def removeClient(self, name: TClientName) -> None:
        """Disconnects the client, without QUIT."""
        assert name in self.clients
        self.clients[name].disconnect()
        del self.clients[name]

    def getMessages(self, client: TClientName, **kwargs: Any) -> List[Message]:
        """Delegate to the named client's getMessages()."""
        return self.clients[client].getMessages(**kwargs)

    def getMessage(self, client: TClientName, **kwargs: Any) -> Message:
        """Delegate to the named client's getMessage()."""
        return self.clients[client].getMessage(**kwargs)

    def getRegistrationMessage(self, client: TClientName) -> Message:
        """Filter notices, do not send pings."""
        return self.getMessage(
            client, synchronize=False, filter_pred=lambda m: m.command != "NOTICE"
        )

    def sendLine(self, client: TClientName, line: Union[str, bytes]) -> None:
        """Send a raw line from the named client."""
        return self.clients[client].sendLine(line)

    def getCapLs(
        self, client: TClientName, as_list: bool = False
    ) -> Union[List[str], Dict[str, Optional[str]]]:
        """Waits for a CAP LS block, parses all CAP LS messages, and return
        the dict capabilities, with their values.

        If as_list is given, returns the raw list (ie. key/value not split)
        in case the order matters (but it shouldn't)."""
        caps = []
        while True:
            m = self.getRegistrationMessage(client)
            self.assertMessageMatch(m, command="CAP")
            self.assertEqual(m.params[1], "LS", fail_msg="Expected CAP * LS, got {got}")
            if m.params[2] == "*":
                # Multi-line reply: more CAP LS lines follow.
                caps.extend(m.params[3].split())
            else:
                caps.extend(m.params[2].split())
                if not as_list:
                    return capabilities.cap_list_to_dict(caps)
                return caps

    def assertDisconnected(self, client: TClientName) -> None:
        """Assert the server closed the client's connection, and forget the
        client if so."""
        try:
            self.getMessages(client)
            self.getMessages(client)
        except (socket.error, ConnectionClosed):
            del self.clients[client]
            return
        else:
            raise AssertionError("Client not disconnected.")

    def skipToWelcome(self, client: TClientName) -> List[Message]:
        """Skip to the point where we are registered
        <https://tools.ietf.org/html/rfc2812#section-3.1>

        Returns every message received up to and including the 001.
        """
        result = []
        while True:
            m = self.getMessage(client, synchronize=False)
            result.append(m)
            if m.command == "001":
                return result

    def requestCapabilities(
        self,
        client: TClientName,
        capabilities: List[str],
        skip_if_cap_nak: bool = False,
    ) -> None:
        """Send CAP REQ for ``capabilities`` and assert the server ACKs it.

        Raises:
            runner.CapabilityNotSupported: instead of the assertion error,
                when ``skip_if_cap_nak`` is set.
        """
        self.sendLine(client, "CAP REQ :{}".format(" ".join(capabilities)))
        m = self.getRegistrationMessage(client)
        try:
            self.assertMessageMatch(
                m, command="CAP", fail_msg="Expected CAP ACK, got: {msg}"
            )
            self.assertEqual(
                m.params[1], "ACK", m, fail_msg="Expected CAP ACK, got: {msg}"
            )
        except AssertionError:
            if skip_if_cap_nak:
                raise runner.CapabilityNotSupported(" or ".join(capabilities))
            else:
                raise

    def connectClient(
        self,
        nick: str,
        name: Optional[TClientName] = None,
        capabilities: Optional[List[str]] = None,
        skip_if_cap_nak: bool = False,
        show_io: Optional[bool] = None,
        account: Optional[str] = None,
        password: Optional[str] = None,
        ident: str = "username",
    ) -> List[Message]:
        """Connections a new client, does the cap negotiation
        and connection registration, and skips to the end of the MOTD.
        Returns the list of all messages received after registration,
        just like `skipToWelcome`.

        Also fills ``self.server_support`` from the 005 replies.

        Raises:
            ValueError: if ``password`` is given without the ``sasl``
                capability.
        """
        client = self.addClient(name, show_io=show_io)
        if capabilities:
            self.sendLine(client, "CAP LS 302")
            # Consume the whole CAP LS reply: with 302 it may span several
            # lines, which must not be mistaken for the CAP ACK below.
            self.getCapLs(client)
            self.requestCapabilities(client, capabilities, skip_if_cap_nak)
        if password is not None:
            if "sasl" not in (capabilities or ()):
                raise ValueError("Used 'password' option without sasl capability")
            self.sendLine(client, "AUTHENTICATE PLAIN")
            m = self.getRegistrationMessage(client)
            self.assertMessageMatch(m, command="AUTHENTICATE", params=["+"])
            self.sendLine(client, sasl_plain_blob(account or nick, password))
            m = self.getRegistrationMessage(client)
            self.assertIn(m.command, ["900", "903"], str(m))

        self.sendLine(client, "NICK {}".format(nick))
        self.sendLine(client, "USER %s * * :Realname" % (ident,))
        if capabilities:
            self.sendLine(client, "CAP END")

        welcome = self.skipToWelcome(client)
        self.sendLine(client, "PING foo")

        # Skip all that happy welcoming stuff
        self.server_support = {}
        while True:
            m = self.getMessage(client)
            if m.command == "PONG":
                break
            elif m.command == "005":
                # params[0] is our nick and the last one is the trailing
                # "are supported by this server" text.
                for param in m.params[1:-1]:
                    if "=" in param:
                        # Split once: some values contain "=" themselves.
                        (key, value) = param.split("=", 1)
                        self.server_support[key] = value
                    else:
                        self.server_support[param] = None
            welcome.append(m)

        return welcome

    def joinClient(self, client: TClientName, channel: str) -> None:
        """Send JOIN and assert an end-of-NAMES (366) reply was received."""
        self.sendLine(client, "JOIN {}".format(channel))
        received = {m.command for m in self.getMessages(client)}
        self.assertIn(
            "366",
            received,
            fail_msg="Join to {} failed, {item} is not in the set of "
            "received responses: {list}",
            extra_format=(channel,),
        )

    def joinChannel(self, client: TClientName, channel: str) -> None:
        """Send JOIN and wait for the JOIN echo for ``channel``.

        Raises:
            ChannelJoinException: if a join-failure numeric arrives first.
        """
        self.sendLine(client, "JOIN {}".format(channel))
        # wait until we see them join the channel
        joined = False
        while not joined:
            for msg in self.getMessages(client):
                if (
                    msg.command == "JOIN"
                    and 0 < len(msg.params)
                    and msg.params[0].lower() == channel.lower()
                ):
                    joined = True
                    break
                elif msg.command in CHANNEL_JOIN_FAIL_NUMERICS:
                    raise ChannelJoinException(msg.command, msg.params)
|
2017-11-01 17:29:45 +00:00
|
|
|
|
2021-02-22 18:02:13 +00:00
|
|
|
|
2021-02-28 16:08:27 +00:00
|
|
|
# Return type of the test methods wrapped by OptionalityHelper's decorators.
_TReturn = TypeVar("_TReturn")
# Type of `self` for those methods (a forward reference to OptionalityHelper).
_TSelf = TypeVar("_TSelf", bound="OptionalityHelper")
|
2021-02-28 11:23:06 +00:00
|
|
|
|
2021-02-28 16:08:27 +00:00
|
|
|
|
|
|
|
class OptionalityHelper(Generic[TController]):
    """Checks and decorators for tests of optional SASL support.

    The checks read ``self.controller.supported_sasl_mechanisms``. They raise
    exceptions from the ``runner`` module when the controller lacks the
    required support.
    """

    # Controller under test. NOTE(review): it is assigned outside this class;
    # where it is set is not visible here.
    controller: TController

    def checkSaslSupport(self) -> None:
        """Raise ``runner.NotImplementedByController("SASL")`` when the
        controller's ``supported_sasl_mechanisms`` is empty/falsy."""
        if not self.controller.supported_sasl_mechanisms:
            raise runner.NotImplementedByController("SASL")

    def checkMechanismSupport(self, mechanism: str) -> None:
        """Raise ``runner.OptionalSaslMechanismNotSupported(mechanism)`` unless
        *mechanism* is in the controller's ``supported_sasl_mechanisms``."""
        if mechanism not in self.controller.supported_sasl_mechanisms:
            raise runner.OptionalSaslMechanismNotSupported(mechanism)

    @staticmethod
    def skipUnlessHasMechanism(
        mech: str,
    ) -> Callable[[Callable[[_TSelf], _TReturn]], Callable[[_TSelf], _TReturn]]:
        """Build a decorator that runs ``checkMechanismSupport(mech)`` first.

        Signature, informally: str -> ((TSelf -> TReturn) -> (TSelf -> TReturn)).
        The wrapped test takes only ``self`` and its return value is passed
        through unchanged.
        """

        def wrap_test(
            test: Callable[[_TSelf], _TReturn]
        ) -> Callable[[_TSelf], _TReturn]:
            @functools.wraps(test)
            def guarded(self: _TSelf) -> _TReturn:
                self.checkMechanismSupport(mech)
                return test(self)

            return guarded

        return wrap_test

    @staticmethod
    def skipUnlessHasSasl(
        f: Callable[[_TSelf], _TReturn]
    ) -> Callable[[_TSelf], _TReturn]:
        """Decorate a test so that it calls ``checkSaslSupport()`` before running.

        The wrapped test takes only ``self`` and its return value is passed
        through unchanged.
        """

        @functools.wraps(f)
        def guarded(self: _TSelf) -> _TReturn:
            self.checkSaslSupport()
            return f(self)

        return guarded
|
|
|
|
|
2015-12-22 18:55:48 +00:00
|
|
|
|
2021-02-28 16:08:27 +00:00
|
|
|
def mark_specifications(
    *specifications_str: str, deprecated: bool = False, strict: bool = False
) -> Callable[[TCallable], TCallable]:
    """Return a decorator that tags a test with pytest marks for specifications.

    String entries of *specifications_str* are resolved with
    ``Specifications.from_name``; non-string entries are used as-is. The
    decorated test receives ``pytest.mark.<spec.value>`` for each
    specification.

    Args:
        deprecated: when true, also apply ``pytest.mark.deprecated``.
        strict: when true, also apply ``pytest.mark.strict``.

    Raises:
        ValueError: if any entry resolves to ``None``.
    """
    resolved = frozenset(
        s if not isinstance(s, str) else Specifications.from_name(s)
        for s in specifications_str
    )
    if None in resolved:
        raise ValueError("Invalid set of specifications: {}".format(resolved))

    def apply_marks(func: TCallable) -> TCallable:
        marked = func
        for spec in resolved:
            marked = getattr(pytest.mark, spec.value)(marked)
        if strict:
            marked = pytest.mark.strict(marked)
        if deprecated:
            marked = pytest.mark.deprecated(marked)
        return marked

    return apply_marks
|
2021-02-27 09:36:30 +00:00
|
|
|
|
|
|
|
|
2021-02-28 16:08:27 +00:00
|
|
|
def mark_capabilities(
    *capabilities_str: str, deprecated: bool = False, strict: bool = False
) -> Callable[[TCallable], TCallable]:
    """Return a decorator that tags a test with pytest marks for capabilities.

    String entries of *capabilities_str* are resolved with
    ``Capabilities.from_name``; non-string entries are used as-is. The
    decorated test receives ``pytest.mark.<capability.value>`` for each
    capability, plus ``pytest.mark.IRCv3``.

    Args:
        deprecated: when true, also apply ``pytest.mark.deprecated``.
        strict: when true, also apply ``pytest.mark.strict``.

    Raises:
        ValueError: if any entry resolves to ``None``.
    """
    capabilities = frozenset(
        Capabilities.from_name(c) if isinstance(c, str) else c for c in capabilities_str
    )
    if None in capabilities:
        raise ValueError("Invalid set of capabilities: {}".format(capabilities))

    def decorator(f: TCallable) -> TCallable:
        for capability in capabilities:
            f = getattr(pytest.mark, capability.value)(f)
        # Support for any capability implies IRCv3
        f = pytest.mark.IRCv3(f)
        # Apply the keyword flags the same way mark_specifications does,
        # instead of accepting them and ignoring them.
        if strict:
            f = pytest.mark.strict(f)
        if deprecated:
            f = pytest.mark.deprecated(f)
        return f

    return decorator
|
|
|
|
|
|
|
|
|
2021-02-28 16:08:27 +00:00
|
|
|
def mark_isupport(
    *tokens_str: str, deprecated: bool = False, strict: bool = False
) -> Callable[[TCallable], TCallable]:
    """Return a decorator that tags a test with pytest marks for ISUPPORT tokens.

    String entries of *tokens_str* are resolved with
    ``IsupportTokens.from_name``; non-string entries are used as-is. The
    decorated test receives ``pytest.mark.<token.value>`` for each token.

    Args:
        deprecated: when true, also apply ``pytest.mark.deprecated``.
        strict: when true, also apply ``pytest.mark.strict``.

    Raises:
        ValueError: if any entry resolves to ``None``.
    """
    tokens = frozenset(
        IsupportTokens.from_name(c) if isinstance(c, str) else c for c in tokens_str
    )
    if None in tokens:
        raise ValueError("Invalid set of isupport tokens: {}".format(tokens))

    def decorator(f: TCallable) -> TCallable:
        for token in tokens:
            f = getattr(pytest.mark, token.value)(f)
        # Apply the keyword flags the same way mark_specifications does,
        # instead of accepting them and ignoring them.
        if strict:
            f = pytest.mark.strict(f)
        if deprecated:
            f = pytest.mark.deprecated(f)
        return f

    return decorator
|