irctest/irctest/cases.py

839 lines
28 KiB
Python
Raw Normal View History

import contextlib
2021-02-22 18:04:23 +00:00
import functools
2015-12-19 00:11:57 +00:00
import socket
2021-02-22 18:04:23 +00:00
import ssl
import tempfile
2021-02-22 18:04:23 +00:00
import time
from typing import (
Any,
Callable,
Container,
Dict,
Generic,
Hashable,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
2015-12-19 00:11:57 +00:00
import pytest
from . import basecontrollers, client_mock, patma, runner, tls
from .authentication import Authentication
from .basecontrollers import TestCaseControllerConfig
2021-02-22 18:04:23 +00:00
from .exceptions import ConnectionClosed
from .irc_utils import capabilities, message_parser
from .irc_utils.message_parser import Message
2020-02-17 09:05:21 +00:00
from .irc_utils.sasl import sasl_plain_blob
2021-02-22 18:02:13 +00:00
from .numerics import (
ERR_BADCHANNELKEY,
ERR_BANNEDFROMCHAN,
2021-02-22 18:04:23 +00:00
ERR_INVITEONLYCHAN,
2021-02-22 18:02:13 +00:00
ERR_NEEDREGGEDNICK,
2021-02-22 18:04:23 +00:00
ERR_NOSUCHCHANNEL,
ERR_TOOMANYCHANNELS,
RPL_HELLO,
2021-02-22 18:02:13 +00:00
)
from .specifications import Capabilities, IsupportTokens, Specifications
2021-02-22 18:02:13 +00:00
__tracebackhide__ = True  # Hide from pytest tracebacks on test failure.

# Numeric replies that mean a JOIN was refused; joinChannel() turns any of
# them into a ChannelJoinException.
CHANNEL_JOIN_FAIL_NUMERICS = frozenset(
    {
        ERR_NOSUCHCHANNEL,
        ERR_TOOMANYCHANNELS,
        ERR_BADCHANNELKEY,
        ERR_INVITEONLYCHAN,
        ERR_BANNEDFROMCHAN,
        ERR_NEEDREGGEDNICK,
    }
)

# --- Type variables -------------------------------------------------------
# Used by the decorators at the bottom of this module.
TCallable = TypeVar("TCallable", bound=Callable)
TClass = TypeVar("TClass", bound=Type)
# Name of a mock client in server tests (usually int or str).
TClientName = TypeVar("TClientName", bound=Union[Hashable, int])
# Controller type a test case is parameterized over.
TController = TypeVar("TController", bound=basecontrollers._BaseController)
# Anything else.
T = TypeVar("T")
class ChannelJoinException(Exception):
    """Raised by joinChannel() when the server answers a JOIN with one of
    the numerics in CHANNEL_JOIN_FAIL_NUMERICS.

    Attributes:
        code: the numeric sent by the server.
        params: the parameters of that numeric reply.
    """

    def __init__(self, code: str, params: List[str]):
        description = "Failed to join channel ({}): {}".format(code, params)
        super().__init__(description)
        self.code = code
        self.params = params
class _IrcTestCase(Generic[TController]):
    """Base class for test cases.

    It implements various `assert*` methods that look like unittest's,
    but are actually based on the `assert` statement, so derived classes are
    pytest-style rather than unittest-style.

    It also calls setUp() and tearDown() like unittest would."""

    # Will be set by __main__.py
    controllerClass: Type[TController]
    show_io: bool
    controller: TController

    __new__ = object.__new__  # pytest won't collect Generic subclasses otherwise

    @staticmethod
    def config() -> TestCaseControllerConfig:
        """Some configuration to pass to the controllers.
        For example, Oragono only enables its MySQL support if
        config()["chathistory"]=True.
        """
        return TestCaseControllerConfig()

    def setUp(self) -> None:
        """Instantiates `controllerClass` (when set) with `config()`."""
        if self.controllerClass is not None:
            self.controller = self.controllerClass(self.config())
        if self.show_io:
            print("---- new test ----")

    def tearDown(self) -> None:
        """Per-test cleanup hook; does nothing here, subclasses override it."""
        pass

    def setup_method(self, method: Callable) -> None:
        """pytest hook; forwards to the unittest-style `setUp()`."""
        self.setUp()

    def teardown_method(self, method: Callable) -> None:
        """pytest hook; forwards to the unittest-style `tearDown()`."""
        self.tearDown()

    def assertMessageMatch(self, msg: Message, **kwargs: Any) -> None:
        """Helper for partially comparing a message.

        Takes the message as first argument, and comparisons to be made
        as keyword arguments (see `messageDiffers` for the accepted ones).
        Uses patma.match_list on the params argument.

        Raises:
            AssertionError: carrying the text built by `messageDiffers`,
                if any comparison fails.
        """
        error = self.messageDiffers(msg, **kwargs)
        if error:
            raise AssertionError(error)

    def messageEqual(self, msg: Message, **kwargs: Any) -> bool:
        """Boolean negation of `messageDiffers` (returns a boolean,
        not an optional string)."""
        return not self.messageDiffers(msg, **kwargs)

    def messageDiffers(
        self,
        msg: Message,
        params: Optional[List[Union[str, None, patma.Operator]]] = None,
        target: Optional[str] = None,
        tags: Optional[
            Dict[Union[str, patma.Operator], Union[str, patma.Operator, None]]
        ] = None,
        nick: Optional[str] = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
        prefix: Union[None, str, patma.Operator] = None,
        **kwargs: Any,
    ) -> Optional[str]:
        """Returns an error message if the message doesn't match the given
        arguments, or None if it matches.

        Checks run in this order and the first failure is reported:

        * every extra keyword argument (e.g. ``command="PRIVMSG"``) is
          compared for equality with the attribute of the same name on msg;
        * prefix is matched with patma.match_string (skipped if falsy);
        * params is matched with patma.match_list (skipped if None or empty);
        * tags is matched with patma.match_dict (skipped if falsy);
        * nick is compared with the part of msg.prefix before "!".

        `target` is accepted but not compared by this method.
        NOTE(review): confirm whether callers expect it to be checked.

        fail_msg, when given, replaces the default error text. It is
        formatted with extra_format as positional arguments and with the
        keywords got, expects, msg (and param for kwargs/nick checks).
        """
        for (key, value) in kwargs.items():
            if getattr(msg, key) != value:
                fail_msg = (
                    fail_msg or "expected {param} to be {expects}, got {got}: {msg}"
                )
                return fail_msg.format(
                    *extra_format,
                    got=getattr(msg, key),
                    expects=value,
                    param=key,
                    msg=msg,
                )

        if prefix and not patma.match_string(msg.prefix, prefix):
            fail_msg = (
                fail_msg or "expected prefix to match {expects}, got {got}: {msg}"
            )
            return fail_msg.format(
                *extra_format, got=msg.prefix, expects=prefix, msg=msg
            )

        if params and not patma.match_list(list(msg.params), params):
            fail_msg = (
                fail_msg or "expected params to match {expects}, got {got}: {msg}"
            )
            return fail_msg.format(
                *extra_format, got=msg.params, expects=params, msg=msg
            )

        if tags and not patma.match_dict(msg.tags, tags):
            fail_msg = fail_msg or "expected tags to match {expects}, got {got}: {msg}"
            return fail_msg.format(*extra_format, got=msg.tags, expects=tags, msg=msg)

        if nick:
            got_nick = msg.prefix.split("!")[0] if msg.prefix else None
            if nick != got_nick:
                fail_msg = (
                    fail_msg
                    or "expected nick to be {expects}, got {got} instead: {msg}"
                )
                # `param` used to be `key`, which is only bound by the kwargs
                # loop above and raised UnboundLocalError without kwargs.
                return fail_msg.format(
                    *extra_format, got=got_nick, expects=nick, param="nick", msg=msg
                )

        return None

    def assertIn(
        self,
        member: Any,
        container: Union[Iterable[Any], Container[Any]],
        msg: Optional[str] = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """Asserts ``member in container``.

        If fail_msg is given, it is formatted (extra_format positionally;
        item, list and msg as keywords) and used as the assertion message."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, item=member, list=container, msg=msg)
        assert member in container, msg  # type: ignore

    def assertNotIn(
        self,
        member: Any,
        container: Union[Iterable[Any], Container[Any]],
        msg: Optional[str] = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """Asserts ``member not in container``; fail_msg as in assertIn."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, item=member, list=container, msg=msg)
        assert member not in container, msg  # type: ignore

    def assertEqual(
        self,
        got: T,
        expects: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """Asserts ``got == expects``.

        If fail_msg is given, it is formatted (extra_format positionally;
        got, expects and msg as keywords) and used as the assertion message."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
        assert got == expects, msg

    def assertNotEqual(
        self,
        got: T,
        expects: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """Asserts ``got != expects``; fail_msg as in assertEqual."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
        assert got != expects, msg

    def assertGreater(
        self,
        got: T,
        expects: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """Asserts ``got > expects`` (strictly); fail_msg as in assertEqual."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
        # Was `>=`, which made this identical to assertGreaterEqual.
        assert got > expects, msg  # type: ignore

    def assertGreaterEqual(
        self,
        got: T,
        expects: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """Asserts ``got >= expects``; fail_msg as in assertEqual."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
        assert got >= expects, msg  # type: ignore

    def assertLess(
        self,
        got: T,
        expects: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """Asserts ``got < expects``; fail_msg as in assertEqual."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
        assert got < expects, msg  # type: ignore

    def assertLessEqual(
        self,
        got: T,
        expects: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """Asserts ``got <= expects``; fail_msg as in assertEqual."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
        assert got <= expects, msg  # type: ignore

    def assertTrue(
        self,
        got: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """Asserts got is truthy; fail_msg gets got and msg as keywords."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, msg=msg)
        assert got, msg

    def assertFalse(
        self,
        got: T,
        msg: Any = None,
        fail_msg: Optional[str] = None,
        extra_format: Tuple = (),
    ) -> None:
        """Asserts got is falsy; fail_msg gets got and msg as keywords."""
        if fail_msg:
            msg = fail_msg.format(*extra_format, got=got, msg=msg)
        assert not got, msg

    @contextlib.contextmanager
    def assertRaises(self, exception: Type[Exception]) -> Iterator[None]:
        """Context manager asserting its body raises `exception`
        (delegates to pytest.raises)."""
        with pytest.raises(exception):
            yield
class BaseClientTestCase(_IrcTestCase[basecontrollers.BaseClientController]):
    """Basic class for client tests. Handles spawning a client and exchanging
    messages with it."""

    conn: Optional[socket.socket]
    nick: Optional[str] = None
    user: Optional[List[str]] = None
    server: socket.socket
    # Set by readCapLs(): 301 (plain CAP LS), 302 (CAP LS 302), or None when
    # the client sent CAP END. These used to be *assignments* of typing
    # objects (`= Optional[str]`) instead of annotations.
    protocol_version: Optional[int] = None
    # Set by negotiateCapabilities() to the capabilities ACKed to the client.
    acked_capabilities: Optional[Set[str]] = None

    __new__ = object.__new__  # pytest won't collect Generic[] subclasses otherwise

    def setUp(self) -> None:
        """Sets up the controller, then opens the listening server socket."""
        super().setUp()
        self.conn = None
        self._setUpServer()

    def tearDown(self) -> None:
        """Sends a best-effort QUIT to the client, kills it, and closes all
        sockets."""
        if self.conn:
            try:
                # IRC lines must be CRLF-terminated to be processed.
                self.conn.sendall(b"QUIT :end of test.\r\n")
            except OSError:
                # Includes BrokenPipeError: the client already disconnected,
                # or the conn was already closed by the test.
                pass
        self.controller.kill()
        if self.conn:
            # conn_file is missing if acceptClient failed after accept()
            # (e.g. during the TLS handshake).
            conn_file = getattr(self, "conn_file", None)
            if conn_file is not None:
                conn_file.close()
            self.conn.close()
        self.server.close()

    def _setUpServer(self) -> None:
        """Creates the server and makes it listen."""
        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server.bind(("", 0))  # Bind any free port
        self.server.listen(1)

        # Used to check if the client is alive from time to time
        self.server.settimeout(1)

    def acceptClient(
        self,
        tls_cert: Optional[str] = None,
        tls_key: Optional[str] = None,
        server: Optional[socket.socket] = None,
    ) -> None:
        """Make the server accept a client connection. Blocking.

        While waiting, checks on every accept() timeout that the client
        process is still alive. If tls_cert and tls_key (PEM text) are given,
        the accepted socket is wrapped server-side in TLS.
        """
        server = server or self.server
        assert server
        # Wait for the client to connect
        while True:
            try:
                (self.conn, _addr) = server.accept()
            except socket.timeout:
                self.controller.check_is_alive()
            else:
                break

        if tls_cert is not None or tls_key is not None:
            assert (
                tls_cert and tls_key
            ), "tls_cert must be provided if and only if tls_key is."
            with tempfile.NamedTemporaryFile(
                "at"
            ) as certfile, tempfile.NamedTemporaryFile("at") as keyfile:
                # load_cert_chain reads the files by name, so the buffered
                # text must reach the disk first.
                certfile.write(tls_cert)
                certfile.flush()
                keyfile.write(tls_key)
                keyfile.flush()
                context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
                context.load_cert_chain(certfile=certfile.name, keyfile=keyfile.name)
                self.conn = context.wrap_socket(self.conn, server_side=True)

        self.conn_file = self.conn.makefile(newline="\r\n", encoding="utf8")

    def getLine(self) -> str:
        """Reads one line from the client ("" once the connection is closed)."""
        line = self.conn_file.readline()
        if self.show_io:
            print("{:.3f} C: {}".format(time.time(), line.strip()))
        return line

    def getMessage(
        self, *args: Any, filter_pred: Optional[Callable[[Message], bool]] = None
    ) -> Message:
        """Gets a message and returns it. If a filter predicate is given,
        skips messages until one for which the predicate returns True,
        and returns that message.

        Raises:
            ConnectionClosed: if the client closed the connection.
        """
        while True:
            line = self.getLine(*args)
            if not line:
                raise ConnectionClosed()
            msg = message_parser.parse_message(line)
            if not filter_pred or filter_pred(msg):
                return msg

    def sendLine(self, line: str) -> None:
        """Sends a line to the client, appending CRLF if it is missing."""
        assert self.conn
        data = line if line.endswith("\r\n") else line + "\r\n"
        self.conn.sendall(data.encode())
        if self.show_io:
            print("{:.3f} S: {}".format(time.time(), line.strip()))

    def readCapLs(
        self,
        auth: Optional[Authentication] = None,
        tls_config: Optional[tls.TlsConfig] = None,
    ) -> None:
        """Starts the client, accepts its connection, and reads its first
        message, which must be CAP LS, CAP LS 302 or CAP END.

        Sets self.protocol_version to 301, 302 or None accordingly.

        Raises:
            AssertionError: on any other first message.
        """
        (hostname, port) = self.server.getsockname()
        self.controller.run(
            hostname=hostname, port=port, auth=auth, tls_config=tls_config
        )
        self.acceptClient()
        m = self.getMessage()
        self.assertEqual(m.command, "CAP", "First message is not CAP LS.")
        if m.params == ["LS"]:
            self.protocol_version = 301
        elif m.params == ["LS", "302"]:
            self.protocol_version = 302
        elif m.params == ["END"]:
            self.protocol_version = None
        else:
            raise AssertionError("Unknown CAP params: {}".format(m.params))

    def userNickPredicate(self, msg: Message) -> bool:
        """Predicate to be used with getMessage to handle NICK/USER
        transparently: records them in self.nick / self.user and filters
        them out."""
        if msg.command == "NICK":
            self.assertEqual(len(msg.params), 1, msg=msg)
            self.nick = msg.params[0]
            return False
        elif msg.command == "USER":
            self.assertEqual(len(msg.params), 4, msg=msg)
            self.user = msg.params
            return False
        else:
            return True

    def negotiateCapabilities(
        self,
        caps: List[str],
        cap_ls: bool = True,
        auth: Optional[Authentication] = None,
    ) -> Optional[Message]:
        """Performs a complete capability negotiation process, without
        ending it, so the caller can continue the negotiation.

        Advertises `caps`, ACKs every CAP REQ whose capabilities are all
        advertised (NAKs the others), and records ACKed capabilities in
        self.acked_capabilities.

        Returns:
            The first message that is not a CAP REQ (NICK/USER excluded), or
            None if cap_ls is true and the client did not negotiate.
        """
        if cap_ls:
            self.readCapLs(auth)
            if not self.protocol_version:
                # No negotiation.
                return None
        self.sendLine("CAP * LS :{}".format(" ".join(caps)))
        capability_names = frozenset(capabilities.cap_list_to_dict(caps))
        self.acked_capabilities = set()
        while True:
            m = self.getMessage(filter_pred=self.userNickPredicate)
            if m.command != "CAP":
                return m
            self.assertGreater(len(m.params), 0, m)
            if m.params[0] == "REQ":
                self.assertEqual(len(m.params), 2, m)
                requested = frozenset(m.params[1].split())
                if not requested.issubset(capability_names):
                    self.sendLine(
                        "CAP {} NAK :{}".format(self.nick or "*", m.params[1][0:100])
                    )
                else:
                    self.sendLine(
                        "CAP {} ACK :{}".format(self.nick or "*", m.params[1])
                    )
                    self.acked_capabilities.update(requested)  # type: ignore
            else:
                return m
class BaseServerTestCase(
    _IrcTestCase[basecontrollers.BaseServerController], Generic[TClientName]
):
    """Basic class for server tests. Handles spawning a server and exchanging
    messages with it."""

    show_io: bool  # set by conftest.py
    password: Optional[str] = None
    ssl = False
    valid_metadata_keys: Set[str] = set()
    invalid_metadata_keys: Set[str] = set()
    server_support: Optional[Dict[str, Optional[str]]]
    run_services = False

    __new__ = object.__new__  # pytest won't collect Generic[] subclasses otherwise

    def setUp(self) -> None:
        """Starts the server under test and initializes the client registry."""
        super().setUp()
        self.server_support = None
        (self.hostname, self.port) = self.controller.get_hostname_and_port()
        self.controller.run(
            self.hostname,
            self.port,
            password=self.password,
            valid_metadata_keys=self.valid_metadata_keys,
            invalid_metadata_keys=self.invalid_metadata_keys,
            ssl=self.ssl,
            run_services=self.run_services,
        )
        self.clients: Dict[TClientName, client_mock.ClientMock] = {}

    def tearDown(self) -> None:
        """Kills the server, then disconnects every remaining client."""
        self.controller.kill()
        for client in list(self.clients):
            self.removeClient(client)

    def addClient(
        self, name: Optional[TClientName] = None, show_io: Optional[bool] = None
    ) -> TClientName:
        """Connects a client to the server and adds it to the dict.

        If 'name' is None, uses one more than the largest integer-like name
        already in use (ints, or strings of decimal digits), so the first
        auto-generated name is 1. Other names (e.g. "bar") are ignored for
        numbering.
        """
        self.controller.wait_for_port()
        if self.run_services:
            self.controller.wait_for_services()
        # `is None` rather than `not name`, so an explicit name of 0 is kept.
        if name is None:
            used_numbers = [
                int(existing)
                for existing in self.clients
                # int("bar") would raise ValueError, so only numeric-looking
                # names take part in the numbering.
                if isinstance(existing, int)
                or (isinstance(existing, str) and existing.isdecimal())
            ]
            new_name: int = max(used_numbers + [0]) + 1
            name = cast(TClientName, new_name)
        show_io = show_io if show_io is not None else self.show_io
        self.clients[name] = client_mock.ClientMock(name=name, show_io=show_io)
        self.clients[name].connect(self.hostname, self.port)
        return name

    def removeClient(self, name: TClientName) -> None:
        """Disconnects the client, without QUIT."""
        assert name in self.clients
        self.clients[name].disconnect()
        del self.clients[name]

    def getMessages(self, client: TClientName, **kwargs: Any) -> List[Message]:
        """Forwards to the client mock's getMessages()."""
        return self.clients[client].getMessages(**kwargs)

    def getMessage(self, client: TClientName, **kwargs: Any) -> Message:
        """Forwards to the client mock's getMessage()."""
        return self.clients[client].getMessage(**kwargs)

    def getRegistrationMessage(self, client: TClientName) -> Message:
        """Filter notices, do not send pings.

        Skips NOTICE and RPL_HELLO messages and answers PINGs, returning the
        first other message."""
        while True:
            msg = self.getMessage(
                client,
                synchronize=False,
                filter_pred=lambda m: m.command not in ("NOTICE", RPL_HELLO),
            )
            if msg.command == "PING":
                # Hi Unreal
                self.sendLine(client, "PONG :" + msg.params[0])
            else:
                return msg

    def sendLine(self, client: TClientName, line: Union[str, bytes]) -> None:
        """Forwards to the client mock's sendLine()."""
        return self.clients[client].sendLine(line)

    def getCapLs(
        self, client: TClientName, as_list: bool = False
    ) -> Union[List[str], Dict[str, Optional[str]]]:
        """Waits for a CAP LS block, parses all CAP LS messages, and returns
        the capabilities as a dict, with their values.

        If as_list is given, returns the raw list (ie. key/value not split)
        in case the order matters (but it shouldn't)."""
        caps: List[str] = []
        while True:
            m = self.getRegistrationMessage(client)
            self.assertMessageMatch(m, command="CAP")
            self.assertEqual(m.params[1], "LS", fail_msg="Expected CAP * LS, got {got}")
            if m.params[2] == "*":
                # Continuation line of a multi-line CAP LS 302 reply.
                caps.extend(m.params[3].split())
            else:
                caps.extend(m.params[2].split())
                if not as_list:
                    return capabilities.cap_list_to_dict(caps)
                return caps

    def assertDisconnected(self, client: TClientName) -> None:
        """Asserts the server closed the client's connection, and forgets
        the client.

        Raises:
            AssertionError: if reading from the client raised neither a
                socket error nor ConnectionClosed.
        """
        try:
            self.getMessages(client)
            self.getMessages(client)
        except (socket.error, ConnectionClosed):
            del self.clients[client]
            return
        else:
            raise AssertionError("Client not disconnected.")

    def skipToWelcome(self, client: TClientName) -> List[Message]:
        """Skip to the point where we are registered
        <https://tools.ietf.org/html/rfc2812#section-3.1>

        Answers PINGs along the way; returns every message received, up to
        and including the 001.
        """
        result = []
        while True:
            m = self.getMessage(client, synchronize=False)
            result.append(m)
            if m.command == "001":
                return result
            elif m.command == "PING":
                # Hi, Unreal
                self.sendLine(client, "PONG :" + m.params[0])

    def requestCapabilities(
        self,
        client: TClientName,
        capabilities: List[str],
        skip_if_cap_nak: bool = False,
    ) -> None:
        """Sends CAP REQ for `capabilities` and asserts the server ACKs it.

        Raises:
            runner.CapabilityNotSupported: if the reply is not an ACK and
                skip_if_cap_nak is true.
            AssertionError: if the reply is not an ACK otherwise.
        """
        self.sendLine(client, "CAP REQ :{}".format(" ".join(capabilities)))
        m = self.getRegistrationMessage(client)
        try:
            self.assertMessageMatch(
                m, command="CAP", fail_msg="Expected CAP ACK, got: {msg}"
            )
            self.assertEqual(
                m.params[1], "ACK", m, fail_msg="Expected CAP ACK, got: {msg}"
            )
        except AssertionError:
            if skip_if_cap_nak:
                raise runner.CapabilityNotSupported(" or ".join(capabilities))
            else:
                raise

    def connectClient(
        self,
        nick: str,
        name: Optional[TClientName] = None,
        capabilities: Optional[List[str]] = None,
        skip_if_cap_nak: bool = False,
        show_io: Optional[bool] = None,
        account: Optional[str] = None,
        password: Optional[str] = None,
        ident: str = "username",
    ) -> List[Message]:
        """Connects a new client, does the cap negotiation
        and connection registration, and skips to the end of the MOTD.

        Also fills self.server_support from the 005 (ISUPPORT) replies.

        Returns the list of all messages received after registration,
        just like `skipToWelcome`.

        Raises:
            ValueError: if password is given without the "sasl" capability.
        """
        client = self.addClient(name, show_io=show_io)
        if capabilities:
            self.sendLine(client, "CAP LS 302")
            # Consume the whole (possibly multi-line) CAP LS reply; reading
            # only one message would leave continuation lines to be mistaken
            # for the CAP REQ answer.
            self.getCapLs(client)
            self.requestCapabilities(client, capabilities, skip_if_cap_nak)
        if password is not None:
            if "sasl" not in (capabilities or ()):
                raise ValueError("Used 'password' option without sasl capbilitiy")
            self.sendLine(client, "AUTHENTICATE PLAIN")
            m = self.getRegistrationMessage(client)
            self.assertMessageMatch(m, command="AUTHENTICATE", params=["+"])
            self.sendLine(client, sasl_plain_blob(account or nick, password))
            m = self.getRegistrationMessage(client)
            self.assertIn(m.command, ["900", "903"], str(m))

        self.sendLine(client, "NICK {}".format(nick))
        self.sendLine(client, "USER %s * * :Realname" % (ident,))
        if capabilities:
            self.sendLine(client, "CAP END")

        welcome = self.skipToWelcome(client)
        self.sendLine(client, "PING foo")

        # Skip all that happy welcoming stuff
        self.server_support = {}
        while True:
            m = self.getMessage(client)
            if m.command == "PONG":
                break
            elif m.command == "005":
                # First param is our nick, last one is the trailing text.
                for param in m.params[1:-1]:
                    if "=" in param:
                        # Split once: token values may themselves contain "=".
                        (key, value) = param.split("=", 1)
                        self.server_support[key] = value
                    else:
                        self.server_support[param] = None
            welcome.append(m)

        return welcome

    def joinClient(self, client: TClientName, channel: str) -> None:
        """Sends JOIN and asserts the pending replies include 366
        (end of NAMES)."""
        self.sendLine(client, "JOIN {}".format(channel))
        received = {m.command for m in self.getMessages(client)}
        self.assertIn(
            "366",
            received,
            fail_msg="Join to {} failed, {item} is not in the set of "
            "received responses: {list}",
            extra_format=(channel,),
        )

    def joinChannel(self, client: TClientName, channel: str) -> None:
        """Sends JOIN and waits for the client's own JOIN echo.

        Raises:
            ChannelJoinException: if a numeric from
                CHANNEL_JOIN_FAIL_NUMERICS arrives first.
        """
        self.sendLine(client, "JOIN {}".format(channel))
        # wait until we see them join the channel
        joined = False
        while not joined:
            for msg in self.getMessages(client):
                if (
                    msg.command == "JOIN"
                    and 0 < len(msg.params)
                    and msg.params[0].lower() == channel.lower()
                ):
                    joined = True
                    break
                elif msg.command in CHANNEL_JOIN_FAIL_NUMERICS:
                    raise ChannelJoinException(msg.command, msg.params)
_TSelf = TypeVar("_TSelf", bound="OptionalityHelper")
_TReturn = TypeVar("_TReturn")
class OptionalityHelper(Generic[TController]):
    """Mixin with checks (and decorators running them) that raise runner
    exceptions when the controller lacks SASL or a given SASL mechanism.

    NOTE(review): presumably the runner turns these exceptions into skips,
    as the decorator names suggest; confirm in runner.py."""

    controller: TController

    def checkSaslSupport(self) -> None:
        """Raises runner.NotImplementedByController("SASL") if the controller
        lists no supported SASL mechanism."""
        if not self.controller.supported_sasl_mechanisms:
            raise runner.NotImplementedByController("SASL")

    def checkMechanismSupport(self, mechanism: str) -> None:
        """Raises runner.OptionalSaslMechanismNotSupported(mechanism) unless
        `mechanism` is among the controller's supported SASL mechanisms."""
        if mechanism not in self.controller.supported_sasl_mechanisms:
            raise runner.OptionalSaslMechanismNotSupported(mechanism)

    @staticmethod
    def skipUnlessHasMechanism(
        mech: str,
    ) -> Callable[[Callable[..., _TReturn]], Callable[..., _TReturn]]:
        """Decorator factory: the wrapped test method first calls
        self.checkMechanismSupport(mech).

        In arrow notation: str -> ((TSelf -> TReturn) -> (TSelf -> TReturn))
        """

        def decorator(test: Callable[..., _TReturn]) -> Callable[..., _TReturn]:
            @functools.wraps(test)
            def wrapper(self: _TSelf, *args: Any, **kwargs: Any) -> _TReturn:
                self.checkMechanismSupport(mech)
                return test(self, *args, **kwargs)

            return wrapper

        return decorator

    @staticmethod
    def skipUnlessHasSasl(f: Callable[..., _TReturn]) -> Callable[..., _TReturn]:
        """Decorator: the wrapped test method first calls
        self.checkSaslSupport()."""

        @functools.wraps(f)
        def wrapper(self: _TSelf, *args: Any, **kwargs: Any) -> _TReturn:
            self.checkSaslSupport()
            return f(self, *args, **kwargs)

        return wrapper
def mark_services(cls: TClass) -> TClass:
    """Class decorator: sets `run_services` on the test class (so setUp and
    addClient start/wait for services) and applies the `services` pytest
    mark."""
    setattr(cls, "run_services", True)
    services_mark = pytest.mark.services
    return services_mark(cls)  # type: ignore
def mark_specifications(
    *specifications_str: str, deprecated: bool = False, strict: bool = False
) -> Callable[[TCallable], TCallable]:
    """Returns a decorator applying one pytest mark per specification
    (given by name or as Specifications members), then the `strict` mark if
    strict is true, then the `deprecated` mark if deprecated is true.

    Raises:
        ValueError: if any entry resolves to None via Specifications.from_name.
    """
    specifications = frozenset(
        spec if not isinstance(spec, str) else Specifications.from_name(spec)
        for spec in specifications_str
    )
    if None in specifications:
        raise ValueError("Invalid set of specifications: {}".format(specifications))

    def decorator(f: TCallable) -> TCallable:
        marks = [getattr(pytest.mark, spec.value) for spec in specifications]
        if strict:
            marks.append(pytest.mark.strict)
        if deprecated:
            marks.append(pytest.mark.deprecated)
        for mark in marks:
            f = mark(f)
        return f

    return decorator
def mark_capabilities(
    *capabilities_str: str, deprecated: bool = False, strict: bool = False
) -> Callable[[TCallable], TCallable]:
    """Returns a decorator applying one pytest mark per capability (given by
    name or as Capabilities members), plus the `IRCv3` mark.

    Consistently with mark_specifications, strict=True adds the `strict`
    mark and deprecated=True adds the `deprecated` mark.

    Raises:
        ValueError: if any entry resolves to None via Capabilities.from_name.
    """
    capabilities = frozenset(
        Capabilities.from_name(c) if isinstance(c, str) else c for c in capabilities_str
    )
    if None in capabilities:
        raise ValueError("Invalid set of capabilities: {}".format(capabilities))

    def decorator(f: TCallable) -> TCallable:
        for capability in capabilities:
            f = getattr(pytest.mark, capability.value)(f)
        # Support for any capability implies IRCv3
        f = pytest.mark.IRCv3(f)
        # These flags used to be accepted but silently ignored.
        if strict:
            f = pytest.mark.strict(f)
        if deprecated:
            f = pytest.mark.deprecated(f)
        return f

    return decorator
def mark_isupport(
    *tokens_str: str, deprecated: bool = False, strict: bool = False
) -> Callable[[TCallable], TCallable]:
    """Returns a decorator applying one pytest mark per ISUPPORT token (given
    by name or as IsupportTokens members).

    Consistently with mark_specifications, strict=True adds the `strict`
    mark and deprecated=True adds the `deprecated` mark.

    Raises:
        ValueError: if any entry resolves to None via IsupportTokens.from_name.
    """
    tokens = frozenset(
        IsupportTokens.from_name(c) if isinstance(c, str) else c for c in tokens_str
    )
    if None in tokens:
        raise ValueError("Invalid set of isupport tokens: {}".format(tokens))

    def decorator(f: TCallable) -> TCallable:
        for token in tokens:
            f = getattr(pytest.mark, token.value)(f)
        # These flags used to be accepted but silently ignored.
        if strict:
            f = pytest.mark.strict(f)
        if deprecated:
            f = pytest.mark.deprecated(f)
        return f

    return decorator