type-annotate all functions outside the tests themselves.

This commit is contained in:
Valentin Lorentz 2021-02-28 17:08:27 +01:00 committed by Valentin Lorentz
parent ac2a37362c
commit 62a87b5957
22 changed files with 545 additions and 313 deletions

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import dataclasses
import os
import shutil
@ -5,8 +7,11 @@ import socket
import subprocess
import tempfile
import time
from typing import Any, Callable, Dict, Optional, Set
from typing import IO, Any, Callable, Dict, Optional, Set
import irctest
from . import authentication, tls
from .runner import NotImplementedByController
@ -41,27 +46,27 @@ class _BaseController:
a process (eg. a server or a client), as well as sending it instructions
that are not part of the IRC specification."""
# set by conftest.py
openssl_bin: str
supports_sts: bool
supported_sasl_mechanisms: Set[str]
proc: Optional[subprocess.Popen]
def __init__(self, test_config: TestCaseControllerConfig):
self.test_config = test_config
self.proc = None
def check_is_alive(self):
def check_is_alive(self) -> None:
assert self.proc
self.proc.poll()
if self.proc.returncode is not None:
raise ProcessStopped()
class DirectoryBasedController(_BaseController):
"""Helper for controllers whose software configuration is based on an
arbitrary directory."""
def __init__(self, test_config: TestCaseControllerConfig):
super().__init__(test_config)
self.directory = None
def kill_proc(self):
def kill_proc(self) -> None:
"""Terminates the controlled process, waits for it to exit, and
eventually kills it."""
assert self.proc
self.proc.terminate()
try:
self.proc.wait(5)
@ -69,20 +74,36 @@ class DirectoryBasedController(_BaseController):
self.proc.kill()
self.proc = None
def kill(self):
def kill(self) -> None:
"""Calls `kill_proc` and cleans the configuration."""
if self.proc:
self.kill_proc()
class DirectoryBasedController(_BaseController):
"""Helper for controllers whose software configuration is based on an
arbitrary directory."""
directory: Optional[str]
def __init__(self, test_config: TestCaseControllerConfig):
super().__init__(test_config)
self.directory = None
def kill(self) -> None:
"""Calls `kill_proc` and cleans the configuration."""
super().kill()
if self.directory:
shutil.rmtree(self.directory)
def terminate(self):
def terminate(self) -> None:
"""Stops the process gracefully, and does not clean its config."""
assert self.proc
self.proc.terminate()
self.proc.wait()
self.proc = None
def open_file(self, name, mode="a"):
def open_file(self, name: str, mode: str = "a") -> IO:
"""Open a file in the configuration directory."""
assert self.directory
if os.sep in name:
@ -92,16 +113,12 @@ class DirectoryBasedController(_BaseController):
assert os.path.isdir(dir_)
return open(os.path.join(self.directory, name), mode)
def create_config(self):
"""If there is no config dir, creates it and returns True.
Else returns False."""
if self.directory:
return False
else:
def create_config(self) -> None:
if not self.directory:
self.directory = tempfile.mkdtemp()
return True
def gen_ssl(self):
def gen_ssl(self) -> None:
assert self.directory
self.csr_path = os.path.join(self.directory, "ssl.csr")
self.key_path = os.path.join(self.directory, "ssl.key")
self.pem_path = os.path.join(self.directory, "ssl.pem")
@ -145,7 +162,13 @@ class DirectoryBasedController(_BaseController):
class BaseClientController(_BaseController):
"""Base controller for IRC clients."""
def run(self, hostname, port, auth):
def run(
self,
hostname: str,
port: int,
auth: Optional[authentication.Authentication],
tls_config: Optional[tls.TlsConfig] = None,
) -> None:
raise NotImplementedError()
@ -154,17 +177,29 @@ class BaseServerController(_BaseController):
_port_wait_interval = 0.1
port_open = False
port: int
supports_sts: bool
supported_sasl_mechanisms: Set[str]
def run(self, hostname, port, password, valid_metadata_keys, invalid_metadata_keys):
def run(
self,
hostname: str,
port: int,
*,
password: Optional[str],
ssl: bool,
valid_metadata_keys: Optional[Set[str]],
invalid_metadata_keys: Optional[Set[str]],
) -> None:
raise NotImplementedError()
def registerUser(self, case, username, password=None):
def registerUser(
self,
case: irctest.cases.BaseServerTestCase, # type: ignore
username: str,
password: Optional[str] = None,
) -> None:
raise NotImplementedByController("account registration")
def wait_for_port(self):
def wait_for_port(self) -> None:
while not self.port_open:
self.check_is_alive()
time.sleep(self._port_wait_interval)

View File

@ -3,16 +3,34 @@ import socket
import ssl
import tempfile
import time
from typing import Optional, Set
from typing import (
Any,
Callable,
Container,
Dict,
Generic,
Hashable,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
import unittest
import pytest
from . import basecontrollers, client_mock, runner
from . import basecontrollers, client_mock, runner, tls
from .authentication import Authentication
from .basecontrollers import TestCaseControllerConfig
from .exceptions import ConnectionClosed
from .irc_utils import capabilities, message_parser
from .irc_utils.junkdrawer import normalizeWhitespace
from .irc_utils.message_parser import Message
from .irc_utils.sasl import sasl_plain_blob
from .numerics import (
ERR_BADCHANNELKEY,
@ -35,18 +53,33 @@ CHANNEL_JOIN_FAIL_NUMERICS = frozenset(
]
)
# typevar for decorators
TCallable = TypeVar("TCallable", bound=Callable)
# typevar for the client name used by tests (usually int or str)
TClientName = TypeVar("TClientName", bound=Union[Hashable, int])
TController = TypeVar("TController", bound=basecontrollers._BaseController)
# general-purpose typevar
T = TypeVar("T")
class ChannelJoinException(Exception):
def __init__(self, code, params):
def __init__(self, code: str, params: List[str]):
super().__init__(f"Failed to join channel ({code}): {params}")
self.code = code
self.params = params
class _IrcTestCase(unittest.TestCase):
class _IrcTestCase(unittest.TestCase, Generic[TController]):
"""Base class for test cases."""
controllerClass = None # Will be set by __main__.py
# Will be set by __main__.py
controllerClass: Type[TController]
show_io: bool
controller: TController
@staticmethod
def config() -> TestCaseControllerConfig:
@ -56,7 +89,7 @@ class _IrcTestCase(unittest.TestCase):
"""
return TestCaseControllerConfig()
def description(self):
def description(self) -> str:
method_doc = self._testMethodDoc
if not method_doc:
return ""
@ -64,14 +97,13 @@ class _IrcTestCase(unittest.TestCase):
method_doc, removeNewline=False
).strip().replace("\n ", "\n\t")
def setUp(self):
def setUp(self) -> None:
super().setUp()
self.controller = self.controllerClass(self.config())
self.inbuffer = []
if self.show_io:
print("---- new test ----")
def assertMessageEqual(self, msg, **kwargs):
def assertMessageEqual(self, msg: Message, **kwargs: Any) -> None:
"""Helper for partially comparing a message.
Takes the message as first arguments, and comparisons to be made
@ -83,21 +115,21 @@ class _IrcTestCase(unittest.TestCase):
if error:
raise self.failureException(error)
def messageEqual(self, msg, **kwargs):
def messageEqual(self, msg: Message, **kwargs: Any) -> bool:
"""Boolean negation of `messageDiffers` (returns a boolean,
not an optional string)."""
return not self.messageDiffers(msg, **kwargs)
def messageDiffers(
self,
msg,
params=None,
target=None,
nick=None,
fail_msg=None,
extra_format=(),
**kwargs,
):
msg: Message,
params: Optional[List[Any]] = None,
target: Optional[str] = None,
nick: Optional[str] = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
**kwargs: Any,
) -> Optional[str]:
"""Returns an error message if the message doesn't match the given arguments,
or None if it matches."""
for (key, value) in kwargs.items():
@ -120,7 +152,7 @@ class _IrcTestCase(unittest.TestCase):
)
if nick:
got_nick = msg.prefix.split("!")[0]
got_nick = msg.prefix.split("!")[0] if msg.prefix else None
if msg.prefix is None:
fail_msg = (
fail_msg or "expected nick to be {expects}, got {got} prefix: {msg}"
@ -131,7 +163,7 @@ class _IrcTestCase(unittest.TestCase):
return None
def listMatch(self, got, expected):
def listMatch(self, got: List[str], expected: List[Any]) -> bool:
"""Returns True iff the list are equal.
The ellipsis (aka. "..." aka triple dots) can be used on the 'expected'
side as a wildcard, matching any *single* value."""
@ -145,62 +177,124 @@ class _IrcTestCase(unittest.TestCase):
return False
return True
def assertIn(self, item, list_, msg=None, fail_msg=None, extra_format=()):
def assertIn(
self,
member: Any,
container: Union[Iterable[Any], Container[Any]],
msg: Optional[str] = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg:
fail_msg = fail_msg.format(*extra_format, item=item, list=list_, msg=msg)
super().assertIn(item, list_, fail_msg)
fail_msg = fail_msg.format(
*extra_format, item=member, list=container, msg=msg
)
super().assertIn(member, container, fail_msg)
def assertNotIn(self, item, list_, msg=None, fail_msg=None, extra_format=()):
def assertNotIn(
self,
member: Any,
container: Union[Iterable[Any], Container[Any]],
msg: Optional[str] = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg:
fail_msg = fail_msg.format(*extra_format, item=item, list=list_, msg=msg)
super().assertNotIn(item, list_, fail_msg)
fail_msg = fail_msg.format(
*extra_format, item=member, list=container, msg=msg
)
super().assertNotIn(member, container, fail_msg)
def assertEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()):
def assertEqual(
self,
got: T,
expects: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg:
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertEqual(got, expects, fail_msg)
def assertNotEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()):
def assertNotEqual(
self,
got: T,
expects: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg:
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertNotEqual(got, expects, fail_msg)
def assertGreater(self, got, expects, msg=None, fail_msg=None, extra_format=()):
def assertGreater(
self,
got: T,
expects: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg:
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertGreater(got, expects, fail_msg)
def assertGreaterEqual(
self, got, expects, msg=None, fail_msg=None, extra_format=()
):
self,
got: T,
expects: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg:
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertGreaterEqual(got, expects, fail_msg)
def assertLess(self, got, expects, msg=None, fail_msg=None, extra_format=()):
def assertLess(
self,
got: T,
expects: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg:
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertLess(got, expects, fail_msg)
def assertLessEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()):
def assertLessEqual(
self,
got: T,
expects: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg:
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertLessEqual(got, expects, fail_msg)
class BaseClientTestCase(_IrcTestCase):
class BaseClientTestCase(_IrcTestCase[basecontrollers.BaseClientController]):
"""Basic class for client tests. Handles spawning a client and exchanging
messages with it."""
nick = None
user = None
conn: Optional[socket.socket]
nick: Optional[str] = None
user: Optional[List[str]] = None
server: socket.socket
protocol_version = Optional[str]
acked_capabilities = Optional[Set[str]]
def setUp(self):
def setUp(self) -> None:
super().setUp()
self.conn = None
self._setUpServer()
def tearDown(self):
def tearDown(self) -> None:
if self.conn:
try:
self.conn.sendall(b"QUIT :end of test.")
@ -214,7 +308,7 @@ class BaseClientTestCase(_IrcTestCase):
self.conn.close()
self.server.close()
def _setUpServer(self):
def _setUpServer(self) -> None:
"""Creates the server and make it listen."""
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server.bind(("", 0)) # Bind any free port
@ -223,9 +317,15 @@ class BaseClientTestCase(_IrcTestCase):
# Used to check if the client is alive from time to time
self.server.settimeout(1)
def acceptClient(self, tls_cert=None, tls_key=None, server=None):
def acceptClient(
self,
tls_cert: Optional[str] = None,
tls_key: Optional[str] = None,
server: Optional[socket.socket] = None,
) -> None:
"""Make the server accept a client connection. Blocking."""
server = server or self.server
assert server
# Wait for the client to connect
while True:
try:
@ -252,17 +352,17 @@ class BaseClientTestCase(_IrcTestCase):
self.conn = context.wrap_socket(self.conn, server_side=True)
self.conn_file = self.conn.makefile(newline="\r\n", encoding="utf8")
def getLine(self):
def getLine(self) -> str:
line = self.conn_file.readline()
if self.show_io:
print("{:.3f} C: {}".format(time.time(), line.strip()))
return line
def getMessages(self, *args):
lines = self.getLines(*args)
return map(message_parser.parse_message, lines)
def getMessage(self, *args, filter_pred=None):
def getMessage(
self,
*args: Any,
filter_pred: Optional[Callable[[Message], bool]] = None,
) -> Message:
"""Gets a message and returns it. If a filter predicate is given,
fetches messages until the predicate returns a False on a message,
and returns this message."""
@ -274,18 +374,17 @@ class BaseClientTestCase(_IrcTestCase):
if not filter_pred or filter_pred(msg):
return msg
def sendLine(self, line):
def sendLine(self, line: str) -> None:
assert self.conn
self.conn.sendall(line.encode())
if not line.endswith("\r\n"):
self.conn.sendall(b"\r\n")
if self.show_io:
print("{:.3f} S: {}".format(time.time(), line.strip()))
class ClientNegociationHelper:
"""Helper class for tests handling capabilities negociation."""
def readCapLs(self, auth=None, tls_config=None):
def readCapLs(
self, auth: Optional[Authentication] = None, tls_config: tls.TlsConfig = None
) -> None:
(hostname, port) = self.server.getsockname()
self.controller.run(
hostname=hostname, port=port, auth=auth, tls_config=tls_config
@ -302,28 +401,33 @@ class ClientNegociationHelper:
else:
raise AssertionError("Unknown CAP params: {}".format(m.params))
def userNickPredicate(self, msg):
def userNickPredicate(self, msg: Message) -> bool:
"""Predicate to be used with getMessage to handle NICK/USER
transparently."""
if msg.command == "NICK":
self.assertEqual(len(msg.params), 1, msg)
self.assertEqual(len(msg.params), 1, msg=msg)
self.nick = msg.params[0]
return False
elif msg.command == "USER":
self.assertEqual(len(msg.params), 4, msg)
self.assertEqual(len(msg.params), 4, msg=msg)
self.user = msg.params
return False
else:
return True
def negotiateCapabilities(self, caps, cap_ls=True, auth=None):
def negotiateCapabilities(
self,
caps: List[str],
cap_ls: bool = True,
auth: Optional[Authentication] = None,
) -> Optional[Message]:
"""Performes a complete capability negociation process, without
ending it, so the caller can continue the negociation."""
if cap_ls:
self.readCapLs(auth)
if not self.protocol_version:
# No negotiation.
return
return None
self.sendLine("CAP * LS :{}".format(" ".join(caps)))
capability_names = frozenset(capabilities.cap_list_to_dict(caps))
self.acked_capabilities = set()
@ -343,21 +447,25 @@ class ClientNegociationHelper:
self.sendLine(
"CAP {} ACK :{}".format(self.nick or "*", m.params[1])
)
self.acked_capabilities.update(requested)
self.acked_capabilities.update(requested) # type: ignore
else:
return m
class BaseServerTestCase(_IrcTestCase):
class BaseServerTestCase(
_IrcTestCase[basecontrollers.BaseServerController], Generic[TClientName]
):
"""Basic class for server tests. Handles spawning a server and exchanging
messages with it."""
show_io: bool # set by conftest.py
password: Optional[str] = None
ssl = False
valid_metadata_keys: Set[str] = set()
invalid_metadata_keys: Set[str] = set()
def setUp(self):
def setUp(self) -> None:
super().setUp()
self.server_support = None
self.find_hostname_and_port()
@ -369,53 +477,64 @@ class BaseServerTestCase(_IrcTestCase):
invalid_metadata_keys=self.invalid_metadata_keys,
ssl=self.ssl,
)
self.clients = {}
self.clients: Dict[TClientName, client_mock.ClientMock] = {}
def tearDown(self):
def tearDown(self) -> None:
self.controller.kill()
for client in list(self.clients):
self.removeClient(client)
def find_hostname_and_port(self):
def find_hostname_and_port(self) -> None:
"""Find available hostname/port to listen on."""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
(self.hostname, self.port) = s.getsockname()
s.close()
def addClient(self, name=None, show_io=None):
def addClient(
self, name: Optional[TClientName] = None, show_io: Optional[bool] = None
) -> TClientName:
"""Connects a client to the server and adds it to the dict.
If 'name' is not given, uses the lowest unused non-negative integer."""
self.controller.wait_for_port()
if not name:
name = max(map(int, list(self.clients) + [0])) + 1
new_name: int = (
max(
[int(name) for name in self.clients if isinstance(name, (int, str))]
+ [0]
)
+ 1
)
name = cast(TClientName, new_name)
show_io = show_io if show_io is not None else self.show_io
self.clients[name] = client_mock.ClientMock(name=name, show_io=show_io)
self.clients[name].connect(self.hostname, self.port)
return name
def removeClient(self, name):
def removeClient(self, name: TClientName) -> None:
"""Disconnects the client, without QUIT."""
assert name in self.clients
self.clients[name].disconnect()
del self.clients[name]
def getMessages(self, client, **kwargs):
def getMessages(self, client: TClientName, **kwargs: Any) -> List[Message]:
return self.clients[client].getMessages(**kwargs)
def getMessage(self, client, **kwargs):
def getMessage(self, client: TClientName, **kwargs: Any) -> Message:
return self.clients[client].getMessage(**kwargs)
def getRegistrationMessage(self, client):
def getRegistrationMessage(self, client: TClientName) -> Message:
"""Filter notices, do not send pings."""
return self.getMessage(
client, synchronize=False, filter_pred=lambda m: m.command != "NOTICE"
)
def sendLine(self, client, line):
def sendLine(self, client: TClientName, line: Union[str, bytes]) -> None:
return self.clients[client].sendLine(line)
def getCapLs(self, client, as_list=False):
def getCapLs(
self, client: TClientName, as_list: bool = False
) -> Union[List[str], Dict[str, Optional[str]]]:
"""Waits for a CAP LS block, parses all CAP LS messages, and return
the dict capabilities, with their values.
@ -431,10 +550,10 @@ class BaseServerTestCase(_IrcTestCase):
else:
caps.extend(m.params[2].split())
if not as_list:
caps = capabilities.cap_list_to_dict(caps)
return capabilities.cap_list_to_dict(caps)
return caps
def assertDisconnected(self, client):
def assertDisconnected(self, client: TClientName) -> None:
try:
self.getMessages(client)
self.getMessages(client)
@ -444,7 +563,7 @@ class BaseServerTestCase(_IrcTestCase):
else:
raise AssertionError("Client not disconnected.")
def skipToWelcome(self, client):
def skipToWelcome(self, client: TClientName) -> List[Message]:
"""Skip to the point where we are registered
<https://tools.ietf.org/html/rfc2812#section-3.1>
"""
@ -457,15 +576,19 @@ class BaseServerTestCase(_IrcTestCase):
def connectClient(
self,
nick,
name=None,
capabilities=None,
skip_if_cap_nak=False,
show_io=None,
account=None,
password=None,
ident="username",
):
nick: str,
name: TClientName = None,
capabilities: Optional[List[str]] = None,
skip_if_cap_nak: bool = False,
show_io: Optional[bool] = None,
account: Optional[str] = None,
password: Optional[str] = None,
ident: str = "username",
) -> List[Message]:
"""Connections a new client, does the cap negotiation
and connection registration, and skips to the end of the MOTD.
Returns the list of all messages received after registration,
just like `skipToWelcome`."""
client = self.addClient(name, show_io=show_io)
if capabilities is not None and 0 < len(capabilities):
self.sendLine(client, "CAP REQ :{}".format(" ".join(capabilities)))
@ -502,14 +625,14 @@ class BaseServerTestCase(_IrcTestCase):
for param in m.params[1:-1]:
if "=" in param:
(key, value) = param.split("=")
self.server_support[key] = value
else:
(key, value) = (param, None)
self.server_support[key] = value
self.server_support[param] = None
welcome.append(m)
return welcome
def joinClient(self, client, channel):
def joinClient(self, client: TClientName, channel: str) -> None:
self.sendLine(client, "JOIN {}".format(channel))
received = {m.command for m in self.getMessages(client)}
self.assertIn(
@ -520,7 +643,7 @@ class BaseServerTestCase(_IrcTestCase):
extra_format=(channel,),
)
def joinChannel(self, client, channel):
def joinChannel(self, client: TClientName, channel: str) -> None:
self.sendLine(client, "JOIN {}".format(channel))
# wait until we see them join the channel
joined = False
@ -537,24 +660,34 @@ class BaseServerTestCase(_IrcTestCase):
raise ChannelJoinException(msg.command, msg.params)
class OptionalityHelper:
controller: basecontrollers.BaseServerController
_TSelf = TypeVar("_TSelf", bound="OptionalityHelper")
_TReturn = TypeVar("_TReturn")
def checkSaslSupport(self):
class OptionalityHelper(Generic[TController]):
controller: TController
def checkSaslSupport(self) -> None:
if self.controller.supported_sasl_mechanisms:
return
raise runner.NotImplementedByController("SASL")
def checkMechanismSupport(self, mechanism):
def checkMechanismSupport(self, mechanism: str) -> None:
if mechanism in self.controller.supported_sasl_mechanisms:
return
raise runner.OptionalSaslMechanismNotSupported(mechanism)
@staticmethod
def skipUnlessHasMechanism(mech):
def decorator(f):
def skipUnlessHasMechanism(
mech: str,
) -> Callable[[Callable[[_TSelf], _TReturn]], Callable[[_TSelf], _TReturn]]:
# Just a function returning a function that takes functions and
# returns functions, nothing to see here.
# If Python didn't have such an awful syntax for callables, it would be:
# str -> ((TSelf -> TReturn) -> (TSelf -> TReturn))
def decorator(f: Callable[[_TSelf], _TReturn]) -> Callable[[_TSelf], _TReturn]:
@functools.wraps(f)
def newf(self):
def newf(self: _TSelf) -> _TReturn:
self.checkMechanismSupport(mech)
return f(self)
@ -562,23 +695,29 @@ class OptionalityHelper:
return decorator
def skipUnlessHasSasl(f):
@staticmethod
def skipUnlessHasSasl(
f: Callable[[_TSelf], _TReturn]
) -> Callable[[_TSelf], _TReturn]:
@functools.wraps(f)
def newf(self):
def newf(self: _TSelf) -> _TReturn:
self.checkSaslSupport()
return f(self)
return newf
def mark_specifications(*specifications, deprecated=False, strict=False):
def mark_specifications(
*specifications_str: str, deprecated: bool = False, strict: bool = False
) -> Callable[[TCallable], TCallable]:
specifications = frozenset(
Specifications.from_name(s) if isinstance(s, str) else s for s in specifications
Specifications.from_name(s) if isinstance(s, str) else s
for s in specifications_str
)
if None in specifications:
raise ValueError("Invalid set of specifications: {}".format(specifications))
def decorator(f):
def decorator(f: TCallable) -> TCallable:
for specification in specifications:
f = getattr(pytest.mark, specification.value)(f)
if strict:
@ -590,14 +729,16 @@ def mark_specifications(*specifications, deprecated=False, strict=False):
return decorator
def mark_capabilities(*capabilities, deprecated=False, strict=False):
def mark_capabilities(
*capabilities_str: str, deprecated: bool = False, strict: bool = False
) -> Callable[[TCallable], TCallable]:
capabilities = frozenset(
Capabilities.from_name(c) if isinstance(c, str) else c for c in capabilities
Capabilities.from_name(c) if isinstance(c, str) else c for c in capabilities_str
)
if None in capabilities:
raise ValueError("Invalid set of capabilities: {}".format(capabilities))
def decorator(f):
def decorator(f: TCallable) -> TCallable:
for capability in capabilities:
f = getattr(pytest.mark, capability.value)(f)
# Support for any capability implies IRCv3
@ -607,14 +748,16 @@ def mark_capabilities(*capabilities, deprecated=False, strict=False):
return decorator
def mark_isupport(*tokens, deprecated=False, strict=False):
def mark_isupport(
*tokens_str: str, deprecated: bool = False, strict: bool = False
) -> Callable[[TCallable], TCallable]:
tokens = frozenset(
IsupportTokens.from_name(c) if isinstance(c, str) else c for c in tokens
IsupportTokens.from_name(c) if isinstance(c, str) else c for c in tokens_str
)
if None in tokens:
raise ValueError("Invalid set of isupport tokens: {}".format(tokens))
def decorator(f):
def decorator(f: TCallable) -> TCallable:
for token in tokens:
f = getattr(pytest.mark, token.value)(f)
return f

View File

@ -2,36 +2,41 @@ import socket
import ssl
import sys
import time
from typing import Any, Callable, List, Optional, Union
from .exceptions import ConnectionClosed, NoMessageException
from .irc_utils import message_parser
class ClientMock:
def __init__(self, name, show_io):
def __init__(self, name: Any, show_io: bool):
self.name = name
self.show_io = show_io
self.inbuffer = []
self.inbuffer: List[message_parser.Message] = []
self.ssl = False
def connect(self, hostname, port):
def connect(self, hostname: str, port: int) -> None:
self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.conn.settimeout(1) # TODO: configurable
self.conn.connect((hostname, port))
if self.show_io:
print("{:.3f} {}: connects to server.".format(time.time(), self.name))
def disconnect(self):
def disconnect(self) -> None:
if self.show_io:
print("{:.3f} {}: disconnects from server.".format(time.time(), self.name))
self.conn.close()
def starttls(self):
def starttls(self) -> None:
assert not self.ssl, "SSL already active."
self.conn = ssl.wrap_socket(self.conn)
self.ssl = True
def getMessages(self, synchronize=True, assert_get_one=False, raw=False):
def getMessages(
self, synchronize: bool = True, assert_get_one: bool = False, raw: bool = False
) -> List[message_parser.Message]:
"""actually returns List[str] in the rare case where raw=True."""
token: Optional[str]
if synchronize:
token = "synchronize{}".format(time.monotonic())
self.sendLine("PING {}".format(token))
@ -79,7 +84,7 @@ class ClientMock:
got_pong = True
else:
if raw:
messages.append(line)
messages.append(line) # type: ignore
else:
messages.append(message)
data = b""
@ -91,7 +96,13 @@ class ClientMock:
else:
return messages
def getMessage(self, filter_pred=None, synchronize=True, raw=False):
def getMessage(
self,
filter_pred: Optional[Callable[[message_parser.Message], bool]] = None,
synchronize: bool = True,
raw: bool = False,
) -> message_parser.Message:
"""Returns str in the rare case where raw=True"""
while True:
if not self.inbuffer:
self.inbuffer = self.getMessages(
@ -103,7 +114,7 @@ class ClientMock:
if not filter_pred or filter_pred(message):
return message
def sendLine(self, line):
def sendLine(self, line: Union[str, bytes]) -> None:
if isinstance(line, str):
encoded_line = line.encode()
elif isinstance(line, bytes):
@ -113,7 +124,7 @@ class ClientMock:
if not encoded_line.endswith(b"\r\n"):
encoded_line += b"\r\n"
try:
ret = self.conn.sendall(encoded_line)
ret = self.conn.sendall(encoded_line) # type: ignore
except BrokenPipeError:
raise ConnectionClosed()
if (

View File

@ -2,7 +2,7 @@ from irctest import cases
from irctest.irc_utils.message_parser import Message
class CapTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper):
class CapTestCase(cases.BaseClientTestCase):
@cases.mark_specifications("IRCv3")
def testSendCap(self):
"""Send CAP LS 302 and read the result."""

View File

@ -39,9 +39,7 @@ class IdentityHash:
return self._data
class SaslTestCase(
cases.BaseClientTestCase, cases.ClientNegociationHelper, cases.OptionalityHelper
):
class SaslTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
@cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPlain(self):
"""Test PLAIN authentication with correct username/password."""
@ -263,9 +261,7 @@ class SaslTestCase(
authenticator.response(msg)
class Irc302SaslTestCase(
cases.BaseClientTestCase, cases.ClientNegociationHelper, cases.OptionalityHelper
):
class Irc302SaslTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
@cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPlainNotAvailable(self):
"""Test the client does not try to authenticate using a mechanism the

View File

@ -1,6 +1,6 @@
import os
import subprocess
from typing import Set
from typing import Optional, Set, Type
from irctest.basecontrollers import (
BaseServerController,
@ -47,20 +47,21 @@ class CharybdisController(BaseServerController, DirectoryBasedController):
supported_sasl_mechanisms: Set[str] = set()
supports_sts = False
def create_config(self):
def create_config(self) -> None:
super().create_config()
with self.open_file("server.conf"):
pass
def run(
self,
hostname,
port,
password=None,
ssl=False,
valid_metadata_keys=None,
invalid_metadata_keys=None,
):
hostname: str,
port: int,
*,
password: Optional[str],
ssl: bool,
valid_metadata_keys: Optional[Set[str]] = None,
invalid_metadata_keys: Optional[Set[str]] = None,
) -> None:
if valid_metadata_keys or invalid_metadata_keys:
raise NotImplementedByController(
"Defining valid and invalid METADATA keys."
@ -85,6 +86,7 @@ class CharybdisController(BaseServerController, DirectoryBasedController):
ssl_config=ssl_config,
)
)
assert self.directory
self.proc = subprocess.Popen(
[
self.binary_name,
@ -98,5 +100,5 @@ class CharybdisController(BaseServerController, DirectoryBasedController):
)
def get_irctest_controller_class():
def get_irctest_controller_class() -> Type[CharybdisController]:
return CharybdisController

View File

@ -1,33 +1,25 @@
import subprocess
from typing import Optional, Type
from irctest.basecontrollers import BaseClientController, NotImplementedByController
from irctest import authentication, tls
from irctest.basecontrollers import (
BaseClientController,
DirectoryBasedController,
NotImplementedByController,
)
class GircController(BaseClientController):
class GircController(BaseClientController, DirectoryBasedController):
software_name = "gIRC"
supported_sasl_mechanisms = ["PLAIN"]
supported_sasl_mechanisms = {"PLAIN"}
def __init__(self):
super().__init__()
self.directory = None
self.proc = None
def kill(self):
if self.proc:
self.proc.terminate()
try:
self.proc.wait(5)
except subprocess.TimeoutExpired:
self.proc.kill()
self.proc = None
def __del__(self):
if self.proc:
self.proc.kill()
if self.directory:
self.directory.cleanup()
def run(self, hostname, port, auth, tls_config):
def run(
self,
hostname: str,
port: int,
auth: Optional[authentication.Authentication],
tls_config: Optional[tls.TlsConfig] = None,
) -> None:
if tls_config:
print(tls_config)
raise NotImplementedByController("TLS options")
@ -42,5 +34,5 @@ class GircController(BaseClientController):
self.proc = subprocess.Popen(["girc_test", "connect"] + args)
def get_irctest_controller_class():
def get_irctest_controller_class() -> Type[GircController]:
return GircController

View File

@ -1,6 +1,6 @@
import os
import subprocess
from typing import Set
from typing import Optional, Set, Type
from irctest.basecontrollers import (
BaseServerController,
@ -44,20 +44,21 @@ class HybridController(BaseServerController, DirectoryBasedController):
supports_sts = False
supported_sasl_mechanisms: Set[str] = set()
def create_config(self):
def create_config(self) -> None:
super().create_config()
with self.open_file("server.conf"):
pass
def run(
self,
hostname,
port,
password=None,
ssl=False,
valid_metadata_keys=None,
invalid_metadata_keys=None,
):
hostname: str,
port: int,
*,
password: Optional[str],
ssl: bool,
valid_metadata_keys: Optional[Set[str]] = None,
invalid_metadata_keys: Optional[Set[str]] = None,
) -> None:
if valid_metadata_keys or invalid_metadata_keys:
raise NotImplementedByController(
"Defining valid and invalid METADATA keys."
@ -82,6 +83,7 @@ class HybridController(BaseServerController, DirectoryBasedController):
ssl_config=ssl_config,
)
)
assert self.directory
self.proc = subprocess.Popen(
[
"ircd",
@ -96,5 +98,5 @@ class HybridController(BaseServerController, DirectoryBasedController):
)
def get_irctest_controller_class():
def get_irctest_controller_class() -> Type[HybridController]:
return HybridController

View File

@ -1,6 +1,6 @@
import os
import subprocess
from typing import Set
from typing import Optional, Set, Type
from irctest.basecontrollers import (
BaseServerController,
@ -42,21 +42,22 @@ class InspircdController(BaseServerController, DirectoryBasedController):
supported_sasl_mechanisms: Set[str] = set()
supports_str = False
def create_config(self):
def create_config(self) -> None:
super().create_config()
with self.open_file("server.conf"):
pass
def run(
self,
hostname,
port,
password=None,
ssl=False,
restricted_metadata_keys=None,
valid_metadata_keys=None,
invalid_metadata_keys=None,
):
hostname: str,
port: int,
*,
password: Optional[str],
ssl: bool,
valid_metadata_keys: Optional[Set[str]] = None,
invalid_metadata_keys: Optional[Set[str]] = None,
restricted_metadata_keys: Optional[Set[str]] = None,
) -> None:
if valid_metadata_keys or invalid_metadata_keys:
raise NotImplementedByController(
"Defining valid and invalid METADATA keys."
@ -81,6 +82,7 @@ class InspircdController(BaseServerController, DirectoryBasedController):
ssl_config=ssl_config,
)
)
assert self.directory
self.proc = subprocess.Popen(
[
"inspircd",
@ -92,5 +94,5 @@ class InspircdController(BaseServerController, DirectoryBasedController):
)
def get_irctest_controller_class():
def get_irctest_controller_class() -> Type[InspircdController]:
return InspircdController

View File

@ -1,3 +1,5 @@
from typing import Type
from .charybdis import CharybdisController
@ -6,5 +8,5 @@ class IrcdSevenController(CharybdisController):
binary_name = "ircd-seven"
def get_irctest_controller_class():
def get_irctest_controller_class() -> Type[IrcdSevenController]:
return IrcdSevenController

View File

@ -1,7 +1,8 @@
import os
import subprocess
from typing import Optional, Type
from irctest import tls
from irctest import authentication, tls
from irctest.basecontrollers import BaseClientController, DirectoryBasedController
TEMPLATE_CONFIG = """
@ -35,15 +36,20 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
}
supports_sts = True
def create_config(self):
create_config = super().create_config()
if create_config:
with self.open_file("bot.conf"):
pass
with self.open_file("conf/users.conf"):
pass
def create_config(self) -> None:
super().create_config()
with self.open_file("bot.conf"):
pass
with self.open_file("conf/users.conf"):
pass
def run(self, hostname, port, auth, tls_config=None):
def run(
self,
hostname: str,
port: int,
auth: Optional[authentication.Authentication],
tls_config: Optional[tls.TlsConfig] = None,
) -> None:
if tls_config is None:
tls_config = tls.TlsConfig(enable=False, trusted_fingerprints=[])
# Runs a client with the config given as arguments
@ -72,10 +78,11 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
else "",
)
)
assert self.directory
self.proc = subprocess.Popen(
["supybot", os.path.join(self.directory, "bot.conf")]
)
def get_irctest_controller_class():
def get_irctest_controller_class() -> Type[LimnoriaController]:
return LimnoriaController

View File

@ -1,11 +1,13 @@
import os
import subprocess
from typing import Optional, Set, Type
from irctest.basecontrollers import (
BaseServerController,
DirectoryBasedController,
NotImplementedByController,
)
from irctest.cases import BaseServerTestCase
TEMPLATE_CONFIG = """
clients:
@ -61,7 +63,7 @@ server:
"""
def make_list(list_):
def make_list(list_: Set[str]) -> str:
return "\n".join(map(" - {}".format, list_))
@ -69,25 +71,27 @@ class MammonController(BaseServerController, DirectoryBasedController):
software_name = "Mammon"
supported_sasl_mechanisms = {"PLAIN", "ECDSA-NIST256P-CHALLENGE"}
def create_config(self):
def create_config(self) -> None:
super().create_config()
with self.open_file("server.conf"):
pass
def kill_proc(self):
def kill_proc(self) -> None:
# Mammon does not seem to handle SIGTERM very well
assert self.proc
self.proc.kill()
def run(
self,
hostname,
port,
password=None,
ssl=False,
restricted_metadata_keys=(),
valid_metadata_keys=(),
invalid_metadata_keys=(),
):
hostname: str,
port: int,
*,
password: Optional[str],
ssl: bool,
valid_metadata_keys: Optional[Set[str]] = None,
invalid_metadata_keys: Optional[Set[str]] = None,
restricted_metadata_keys: Optional[Set[str]] = None,
) -> None:
if password is not None:
raise NotImplementedByController("PASS command")
if ssl:
@ -101,12 +105,13 @@ class MammonController(BaseServerController, DirectoryBasedController):
directory=self.directory,
hostname=hostname,
port=port,
authorized_keys=make_list(valid_metadata_keys),
restricted_keys=make_list(restricted_metadata_keys),
authorized_keys=make_list(valid_metadata_keys or set()),
restricted_keys=make_list(restricted_metadata_keys or set()),
)
)
# with self.open_file('server.yml', 'r') as fd:
# print(fd.read())
assert self.directory
self.proc = subprocess.Popen(
[
"mammond",
@ -116,7 +121,12 @@ class MammonController(BaseServerController, DirectoryBasedController):
]
)
def registerUser(self, case, username, password=None):
def registerUser(
self,
case: BaseServerTestCase,
username: str,
password: Optional[str] = None,
) -> None:
# XXX: Move this somewhere else when
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes
# part of the specification
@ -135,5 +145,5 @@ class MammonController(BaseServerController, DirectoryBasedController):
case.removeClient(client)
def get_irctest_controller_class():
def get_irctest_controller_class() -> Type[MammonController]:
return MammonController

View File

@ -2,12 +2,14 @@ import copy
import json
import os
import subprocess
from typing import Any, Dict, Optional, Set, Type, Union
from irctest.basecontrollers import (
BaseServerController,
DirectoryBasedController,
NotImplementedByController,
)
from irctest.cases import BaseServerTestCase
OPER_PWD = "frenchfries"
@ -116,7 +118,7 @@ BASE_CONFIG = {
LOGGING_CONFIG = {"logging": [{"method": "stderr", "level": "debug", "type": "*"}]}
def hash_password(password):
def hash_password(password: Union[str, bytes]) -> str:
if isinstance(password, str):
password = password.encode("utf-8")
# simulate entry of password and confirmation:
@ -134,25 +136,23 @@ class OragonoController(BaseServerController, DirectoryBasedController):
supported_sasl_mechanisms = {"PLAIN"}
supports_sts = True
def create_config(self):
def create_config(self) -> None:
super().create_config()
with self.open_file("ircd.yaml"):
pass
def kill_proc(self):
self.proc.kill()
def run(
self,
hostname,
port,
password=None,
ssl=False,
restricted_metadata_keys=None,
valid_metadata_keys=None,
invalid_metadata_keys=None,
config=None,
):
hostname: str,
port: int,
*,
password: Optional[str],
ssl: bool,
valid_metadata_keys: Optional[Set[str]] = None,
invalid_metadata_keys: Optional[Set[str]] = None,
restricted_metadata_keys: Optional[Set[str]] = None,
config: Optional[Any] = None,
) -> None:
if valid_metadata_keys or invalid_metadata_keys:
raise NotImplementedByController(
"Defining valid and invalid METADATA keys."
@ -162,6 +162,8 @@ class OragonoController(BaseServerController, DirectoryBasedController):
if config is None:
config = copy.deepcopy(BASE_CONFIG)
assert self.directory
enable_chathistory = self.test_config.chathistory
enable_roleplay = self.test_config.oragono_roleplay
if enable_chathistory or enable_roleplay:
@ -180,12 +182,14 @@ class OragonoController(BaseServerController, DirectoryBasedController):
self.key_path = os.path.join(self.directory, "ssl.key")
self.pem_path = os.path.join(self.directory, "ssl.pem")
listener_conf = {"tls": {"cert": self.pem_path, "key": self.key_path}}
config["server"]["listeners"][bind_address] = listener_conf
config["server"]["listeners"][bind_address] = listener_conf # type: ignore
config["datastore"]["path"] = os.path.join(self.directory, "ircd.db")
config["datastore"]["path"] = os.path.join( # type: ignore
self.directory, "ircd.db"
)
if password is not None:
config["server"]["password"] = hash_password(password)
config["server"]["password"] = hash_password(password) # type: ignore
assert self.proc is None
@ -198,7 +202,12 @@ class OragonoController(BaseServerController, DirectoryBasedController):
["oragono", "run", "--conf", self._config_path, "--quiet"]
)
def registerUser(self, case, username, password=None):
def registerUser(
self,
case: BaseServerTestCase,
username: str,
password: Optional[str] = None,
) -> None:
# XXX: Move this somewhere else when
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes
# part of the specification
@ -210,34 +219,35 @@ class OragonoController(BaseServerController, DirectoryBasedController):
while case.getRegistrationMessage(client).command != "001":
pass
case.getMessages(client)
assert password
case.sendLine(client, "NS REGISTER " + password)
msg = case.getMessage(client)
assert msg.params == [username, "Account created"]
case.sendLine(client, "QUIT")
case.assertDisconnected(client)
def _write_config(self):
def _write_config(self) -> None:
with open(self._config_path, "w") as fd:
json.dump(self._config, fd)
def baseConfig(self):
def baseConfig(self) -> Dict:
return copy.deepcopy(BASE_CONFIG)
def getConfig(self):
def getConfig(self) -> Dict:
return copy.deepcopy(self._config)
def addLoggingToConfig(self, config=None):
def addLoggingToConfig(self, config: Optional[Dict] = None) -> Dict:
if config is None:
config = self.baseConfig()
config.update(LOGGING_CONFIG)
return config
def addMysqlToConfig(self, config=None):
def addMysqlToConfig(self, config: Optional[Dict] = None) -> Dict:
mysql_password = os.getenv("MYSQL_PASSWORD")
if not mysql_password:
return config
if config is None:
config = self.baseConfig()
if not mysql_password:
return config
config["datastore"]["mysql"] = {
"enabled": True,
"host": "localhost",
@ -259,7 +269,7 @@ class OragonoController(BaseServerController, DirectoryBasedController):
}
return config
def rehash(self, case, config):
def rehash(self, case: BaseServerTestCase, config: Dict) -> None:
self._config = config
self._write_config()
client = "operator_for_rehash"
@ -270,11 +280,11 @@ class OragonoController(BaseServerController, DirectoryBasedController):
case.sendLine(client, "QUIT")
case.assertDisconnected(client)
def enable_debug_logging(self, case):
def enable_debug_logging(self, case: BaseServerTestCase) -> None:
config = self.getConfig()
config.update(LOGGING_CONFIG)
self.rehash(case, config)
def get_irctest_controller_class():
def get_irctest_controller_class() -> Type[OragonoController]:
return OragonoController

View File

@ -1,3 +1,5 @@
from typing import Type
from .charybdis import CharybdisController
@ -6,5 +8,5 @@ class SolanumController(CharybdisController):
binary_name = "solanum"
def get_irctest_controller_class():
def get_irctest_controller_class() -> Type[SolanumController]:
return SolanumController

View File

@ -1,8 +1,14 @@
import os
import subprocess
import tempfile
from typing import Optional, TextIO, Type, cast
from irctest.basecontrollers import BaseClientController, NotImplementedByController
from irctest import authentication, tls
from irctest.basecontrollers import (
BaseClientController,
NotImplementedByController,
TestCaseControllerConfig,
)
TEMPLATE_CONFIG = """
[core]
@ -24,30 +30,34 @@ class SopelController(BaseClientController):
supported_sasl_mechanisms = {"PLAIN"}
supports_sts = False
def __init__(self, test_config):
def __init__(self, test_config: TestCaseControllerConfig):
super().__init__(test_config)
self.filename = next(tempfile._get_candidate_names()) + ".cfg"
self.proc = None
self.filename = next(tempfile._get_candidate_names()) + ".cfg" # type: ignore
def kill(self):
if self.proc:
self.proc.kill()
def kill(self) -> None:
super().kill()
if self.filename:
try:
os.unlink(os.path.join(os.path.expanduser("~/.sopel/"), self.filename))
except OSError: #  File does not exist
pass
def open_file(self, filename, mode="a"):
def open_file(self, filename: str, mode: str = "a") -> TextIO:
dir_path = os.path.expanduser("~/.sopel/")
os.makedirs(dir_path, exist_ok=True)
return open(os.path.join(dir_path, filename), mode)
return cast(TextIO, open(os.path.join(dir_path, filename), mode))
def create_config(self):
def create_config(self) -> None:
with self.open_file(self.filename):
pass
def run(self, hostname, port, auth, tls_config):
def run(
self,
hostname: str,
port: int,
auth: Optional[authentication.Authentication],
tls_config: Optional[tls.TlsConfig] = None,
) -> None:
# Runs a client with the config given as arguments
if tls_config is not None:
raise NotImplementedByController("TLS configuration")
@ -66,5 +76,5 @@ class SopelController(BaseClientController):
self.proc = subprocess.Popen(["sopel", "--quiet", "-c", self.filename])
def get_irctest_controller_class():
def get_irctest_controller_class() -> Type[SopelController]:
return SopelController

View File

@ -2,8 +2,10 @@
Handles ambiguities of RFCs.
"""
from typing import List
def normalize_namreply_params(params):
def normalize_namreply_params(params: List[str]) -> List[str]:
# So… RFC 2812 says:
# "( "=" / "*" / "@" ) <channel>
# :[ "@" / "+" ] <nick> *( " " [ "@" / "+" ] <nick> )
@ -12,6 +14,7 @@ def normalize_namreply_params(params):
# prefix.
# So let's normalize this to “with space”, and strip spaces at the
# end of the nick list.
params = list(params) # copy the list
if len(params) == 3:
assert params[1][0] in "=*@", params
params.insert(1, params[1][0])

View File

@ -1,10 +1,12 @@
def cap_list_to_dict(caps):
d = {}
from typing import Dict, List, Optional
def cap_list_to_dict(caps: List[str]) -> Dict[str, Optional[str]]:
d: Dict[str, Optional[str]] = {}
for cap in caps:
if "=" in cap:
(key, value) = cap.split("=", 1)
d[key] = value
else:
key = cap
value = None
d[key] = value
d[cap] = None
return d

View File

@ -1,16 +1,17 @@
import datetime
import re
import secrets
from typing import Dict
# thanks jess!
IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z"
def ircv3_timestamp_to_unixtime(timestamp):
def ircv3_timestamp_to_unixtime(timestamp: str) -> float:
return datetime.datetime.strptime(timestamp, IRCV3_FORMAT_STRFTIME).timestamp()
def random_name(base):
def random_name(base: str) -> str:
return base + "-" + secrets.token_hex(8)
@ -26,16 +27,16 @@ class MultipleReplacer:
# We use an object instead of a lambda function because it avoids the
# need for using the staticmethod() on the lambda function if assigning
# it to a class in Python 3.
def __init__(self, dict_):
def __init__(self, dict_: Dict[str, str]):
self._dict = dict_
dict_ = dict([(re.escape(key), val) for key, val in dict_.items()])
self._matcher = re.compile("|".join(dict_.keys()))
def __call__(self, s):
def __call__(self, s: str) -> str:
return self._matcher.sub(lambda m: self._dict[m.group(0)], s)
def normalizeWhitespace(s, removeNewline=True):
def normalizeWhitespace(s: str, removeNewline: bool = True) -> str:
r"""Normalizes the whitespace in a string; \s+ becomes one space."""
if not s:
return str(s) # not the same reference

View File

@ -18,8 +18,8 @@ unescape_tag_value = MultipleReplacer(dict(map(lambda x: (x[1], x[0]), TAG_ESCAP
tag_key_validator = re.compile(r"\+?(\S+/)?[a-zA-Z0-9-]+")
def parse_tags(s):
tags = {}
def parse_tags(s: str) -> Dict[str, Optional[str]]:
tags: Dict[str, Optional[str]] = {}
for tag in s.split(";"):
if "=" not in tag:
tags[tag] = None
@ -54,15 +54,15 @@ class Message:
)
def parse_message(s):
def parse_message(s: str) -> Message:
"""Parse a message according to
http://tools.ietf.org/html/rfc1459#section-2.3.1
and
http://ircv3.net/specs/core/message-tags-3.2.html"""
s = s.rstrip("\r\n")
if s.startswith("@"):
(tags, s) = s.split(" ", 1)
tags = parse_tags(tags[1:])
(tags_str, s) = s.split(" ", 1)
tags = parse_tags(tags_str[1:])
else:
tags = {}
if " :" in s:
@ -70,10 +70,7 @@ def parse_message(s):
tokens = list(filter(bool, other_tokens.split(" "))) + [trailing_param]
else:
tokens = list(filter(bool, s.split(" ")))
if tokens[0].startswith(":"):
prefix = tokens.pop(0)[1:]
else:
prefix = None
prefix = prefix = tokens.pop(0)[1:] if tokens[0].startswith(":") else None
command = tokens.pop(0)
params = tokens
return Message(tags=tags, prefix=prefix, command=command, params=params)

View File

@ -1,7 +1,7 @@
import base64
def sasl_plain_blob(username, passphrase):
def sasl_plain_blob(username: str, passphrase: str) -> str:
blob = base64.b64encode(
b"\x00".join(
(

View File

@ -1,14 +1,15 @@
import collections
from typing import Dict, Union
import unittest
class NotImplementedByController(unittest.SkipTest, NotImplementedError):
def __str__(self):
def __str__(self) -> str:
return "Not implemented by controller: {}".format(self.args[0])
class ImplementationChoice(unittest.SkipTest):
def __str__(self):
def __str__(self) -> str:
return (
"Choice in the implementation makes it impossible to "
"perform a test: {}".format(self.args[0])
@ -16,49 +17,49 @@ class ImplementationChoice(unittest.SkipTest):
class OptionalExtensionNotSupported(unittest.SkipTest):
def __str__(self):
def __str__(self) -> str:
return "Unsupported extension: {}".format(self.args[0])
class OptionalSaslMechanismNotSupported(unittest.SkipTest):
def __str__(self):
def __str__(self) -> str:
return "Unsupported SASL mechanism: {}".format(self.args[0])
class CapabilityNotSupported(unittest.SkipTest):
def __str__(self):
def __str__(self) -> str:
return "Unsupported capability: {}".format(self.args[0])
class IsupportTokenNotSupported(unittest.SkipTest):
def __str__(self):
def __str__(self) -> str:
return "Unsupported ISUPPORT token: {}".format(self.args[0])
class ChannelModeNotSupported(unittest.SkipTest):
def __str__(self):
def __str__(self) -> str:
return "Unsupported channel mode: {} ({})".format(self.args[0], self.args[1])
class ExtbanNotSupported(unittest.SkipTest):
def __str__(self):
def __str__(self) -> str:
return "Unsupported extban: {} ({})".format(self.args[0], self.args[1])
class NotRequiredBySpecifications(unittest.SkipTest):
def __str__(self):
def __str__(self) -> str:
return "Tests not required by the set of tested specification(s)."
class SkipStrictTest(unittest.SkipTest):
def __str__(self):
def __str__(self) -> str:
return "Tests not required because strict tests are disabled."
class TextTestResult(unittest.TextTestResult):
def getDescription(self, test):
def getDescription(self, test: unittest.TestCase) -> str:
if hasattr(test, "description"):
doc_first_lines = test.description()
doc_first_lines = test.description() # type: ignore
else:
doc_first_lines = test.shortDescription()
return "\n".join((str(test), doc_first_lines or ""))
@ -71,7 +72,9 @@ class TextTestRunner(unittest.TextTestRunner):
resultclass = TextTestResult
def run(self, test):
def run(
self, test: Union[unittest.TestSuite, unittest.TestCase]
) -> unittest.TestResult:
result = super().run(test)
assert self.resultclass is TextTestResult
if result.skipped:
@ -80,7 +83,7 @@ class TextTestRunner(unittest.TextTestRunner):
"Some tests were skipped because the following optional "
"specifications/mechanisms are not supported:"
)
msg_to_count = collections.defaultdict(lambda: 0)
msg_to_count: Dict[str, int] = collections.defaultdict(lambda: 0)
for (test, msg) in result.skipped:
msg_to_count[msg] += 1
for (msg, count) in sorted(msg_to_count.items()):

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import enum
@ -15,7 +17,7 @@ class Specifications(enum.Enum):
Modern = "modern"
@classmethod
def from_name(cls, name):
def from_name(cls, name: str) -> Specifications:
name = name.upper()
for spec in cls:
if spec.value.upper() == name:
@ -37,7 +39,7 @@ class Capabilities(enum.Enum):
STS = "sts"
@classmethod
def from_name(cls, name):
def from_name(cls, name: str) -> Capabilities:
try:
return cls(name.lower())
except ValueError:
@ -50,7 +52,7 @@ class IsupportTokens(enum.Enum):
STATUSMSG = "STATUSMSG"
@classmethod
def from_name(cls, name):
def from_name(cls, name: str) -> IsupportTokens:
try:
return cls(name.upper())
except ValueError: