type-annotate all functions outside the tests themselves.

This commit is contained in:
2021-02-28 17:08:27 +01:00
committed by Valentin Lorentz
parent ac2a37362c
commit 62a87b5957
22 changed files with 545 additions and 313 deletions

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import dataclasses import dataclasses
import os import os
import shutil import shutil
@ -5,8 +7,11 @@ import socket
import subprocess import subprocess
import tempfile import tempfile
import time import time
from typing import Any, Callable, Dict, Optional, Set from typing import IO, Any, Callable, Dict, Optional, Set
import irctest
from . import authentication, tls
from .runner import NotImplementedByController from .runner import NotImplementedByController
@ -41,27 +46,27 @@ class _BaseController:
a process (eg. a server or a client), as well as sending it instructions a process (eg. a server or a client), as well as sending it instructions
that are not part of the IRC specification.""" that are not part of the IRC specification."""
# set by conftest.py
openssl_bin: str
supports_sts: bool
supported_sasl_mechanisms: Set[str]
proc: Optional[subprocess.Popen]
def __init__(self, test_config: TestCaseControllerConfig): def __init__(self, test_config: TestCaseControllerConfig):
self.test_config = test_config self.test_config = test_config
self.proc = None self.proc = None
def check_is_alive(self): def check_is_alive(self) -> None:
assert self.proc
self.proc.poll() self.proc.poll()
if self.proc.returncode is not None: if self.proc.returncode is not None:
raise ProcessStopped() raise ProcessStopped()
def kill_proc(self) -> None:
class DirectoryBasedController(_BaseController):
"""Helper for controllers whose software configuration is based on an
arbitrary directory."""
def __init__(self, test_config: TestCaseControllerConfig):
super().__init__(test_config)
self.directory = None
def kill_proc(self):
"""Terminates the controlled process, waits for it to exit, and """Terminates the controlled process, waits for it to exit, and
eventually kills it.""" eventually kills it."""
assert self.proc
self.proc.terminate() self.proc.terminate()
try: try:
self.proc.wait(5) self.proc.wait(5)
@ -69,20 +74,36 @@ class DirectoryBasedController(_BaseController):
self.proc.kill() self.proc.kill()
self.proc = None self.proc = None
def kill(self): def kill(self) -> None:
"""Calls `kill_proc` and cleans the configuration.""" """Calls `kill_proc` and cleans the configuration."""
if self.proc: if self.proc:
self.kill_proc() self.kill_proc()
class DirectoryBasedController(_BaseController):
"""Helper for controllers whose software configuration is based on an
arbitrary directory."""
directory: Optional[str]
def __init__(self, test_config: TestCaseControllerConfig):
super().__init__(test_config)
self.directory = None
def kill(self) -> None:
"""Calls `kill_proc` and cleans the configuration."""
super().kill()
if self.directory: if self.directory:
shutil.rmtree(self.directory) shutil.rmtree(self.directory)
def terminate(self): def terminate(self) -> None:
"""Stops the process gracefully, and does not clean its config.""" """Stops the process gracefully, and does not clean its config."""
assert self.proc
self.proc.terminate() self.proc.terminate()
self.proc.wait() self.proc.wait()
self.proc = None self.proc = None
def open_file(self, name, mode="a"): def open_file(self, name: str, mode: str = "a") -> IO:
"""Open a file in the configuration directory.""" """Open a file in the configuration directory."""
assert self.directory assert self.directory
if os.sep in name: if os.sep in name:
@ -92,16 +113,12 @@ class DirectoryBasedController(_BaseController):
assert os.path.isdir(dir_) assert os.path.isdir(dir_)
return open(os.path.join(self.directory, name), mode) return open(os.path.join(self.directory, name), mode)
def create_config(self): def create_config(self) -> None:
"""If there is no config dir, creates it and returns True. if not self.directory:
Else returns False."""
if self.directory:
return False
else:
self.directory = tempfile.mkdtemp() self.directory = tempfile.mkdtemp()
return True
def gen_ssl(self): def gen_ssl(self) -> None:
assert self.directory
self.csr_path = os.path.join(self.directory, "ssl.csr") self.csr_path = os.path.join(self.directory, "ssl.csr")
self.key_path = os.path.join(self.directory, "ssl.key") self.key_path = os.path.join(self.directory, "ssl.key")
self.pem_path = os.path.join(self.directory, "ssl.pem") self.pem_path = os.path.join(self.directory, "ssl.pem")
@ -145,7 +162,13 @@ class DirectoryBasedController(_BaseController):
class BaseClientController(_BaseController): class BaseClientController(_BaseController):
"""Base controller for IRC clients.""" """Base controller for IRC clients."""
def run(self, hostname, port, auth): def run(
self,
hostname: str,
port: int,
auth: Optional[authentication.Authentication],
tls_config: Optional[tls.TlsConfig] = None,
) -> None:
raise NotImplementedError() raise NotImplementedError()
@ -154,17 +177,29 @@ class BaseServerController(_BaseController):
_port_wait_interval = 0.1 _port_wait_interval = 0.1
port_open = False port_open = False
port: int
supports_sts: bool def run(
supported_sasl_mechanisms: Set[str] self,
hostname: str,
def run(self, hostname, port, password, valid_metadata_keys, invalid_metadata_keys): port: int,
*,
password: Optional[str],
ssl: bool,
valid_metadata_keys: Optional[Set[str]],
invalid_metadata_keys: Optional[Set[str]],
) -> None:
raise NotImplementedError() raise NotImplementedError()
def registerUser(self, case, username, password=None): def registerUser(
self,
case: irctest.cases.BaseServerTestCase, # type: ignore
username: str,
password: Optional[str] = None,
) -> None:
raise NotImplementedByController("account registration") raise NotImplementedByController("account registration")
def wait_for_port(self): def wait_for_port(self) -> None:
while not self.port_open: while not self.port_open:
self.check_is_alive() self.check_is_alive()
time.sleep(self._port_wait_interval) time.sleep(self._port_wait_interval)

View File

@ -3,16 +3,34 @@ import socket
import ssl import ssl
import tempfile import tempfile
import time import time
from typing import Optional, Set from typing import (
Any,
Callable,
Container,
Dict,
Generic,
Hashable,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
import unittest import unittest
import pytest import pytest
from . import basecontrollers, client_mock, runner from . import basecontrollers, client_mock, runner, tls
from .authentication import Authentication
from .basecontrollers import TestCaseControllerConfig from .basecontrollers import TestCaseControllerConfig
from .exceptions import ConnectionClosed from .exceptions import ConnectionClosed
from .irc_utils import capabilities, message_parser from .irc_utils import capabilities, message_parser
from .irc_utils.junkdrawer import normalizeWhitespace from .irc_utils.junkdrawer import normalizeWhitespace
from .irc_utils.message_parser import Message
from .irc_utils.sasl import sasl_plain_blob from .irc_utils.sasl import sasl_plain_blob
from .numerics import ( from .numerics import (
ERR_BADCHANNELKEY, ERR_BADCHANNELKEY,
@ -35,18 +53,33 @@ CHANNEL_JOIN_FAIL_NUMERICS = frozenset(
] ]
) )
# typevar for decorators
TCallable = TypeVar("TCallable", bound=Callable)
# typevar for the client name used by tests (usually int or str)
TClientName = TypeVar("TClientName", bound=Union[Hashable, int])
TController = TypeVar("TController", bound=basecontrollers._BaseController)
# general-purpose typevar
T = TypeVar("T")
class ChannelJoinException(Exception): class ChannelJoinException(Exception):
def __init__(self, code, params): def __init__(self, code: str, params: List[str]):
super().__init__(f"Failed to join channel ({code}): {params}") super().__init__(f"Failed to join channel ({code}): {params}")
self.code = code self.code = code
self.params = params self.params = params
class _IrcTestCase(unittest.TestCase): class _IrcTestCase(unittest.TestCase, Generic[TController]):
"""Base class for test cases.""" """Base class for test cases."""
controllerClass = None # Will be set by __main__.py # Will be set by __main__.py
controllerClass: Type[TController]
show_io: bool
controller: TController
@staticmethod @staticmethod
def config() -> TestCaseControllerConfig: def config() -> TestCaseControllerConfig:
@ -56,7 +89,7 @@ class _IrcTestCase(unittest.TestCase):
""" """
return TestCaseControllerConfig() return TestCaseControllerConfig()
def description(self): def description(self) -> str:
method_doc = self._testMethodDoc method_doc = self._testMethodDoc
if not method_doc: if not method_doc:
return "" return ""
@ -64,14 +97,13 @@ class _IrcTestCase(unittest.TestCase):
method_doc, removeNewline=False method_doc, removeNewline=False
).strip().replace("\n ", "\n\t") ).strip().replace("\n ", "\n\t")
def setUp(self): def setUp(self) -> None:
super().setUp() super().setUp()
self.controller = self.controllerClass(self.config()) self.controller = self.controllerClass(self.config())
self.inbuffer = []
if self.show_io: if self.show_io:
print("---- new test ----") print("---- new test ----")
def assertMessageEqual(self, msg, **kwargs): def assertMessageEqual(self, msg: Message, **kwargs: Any) -> None:
"""Helper for partially comparing a message. """Helper for partially comparing a message.
Takes the message as first arguments, and comparisons to be made Takes the message as first arguments, and comparisons to be made
@ -83,21 +115,21 @@ class _IrcTestCase(unittest.TestCase):
if error: if error:
raise self.failureException(error) raise self.failureException(error)
def messageEqual(self, msg, **kwargs): def messageEqual(self, msg: Message, **kwargs: Any) -> bool:
"""Boolean negation of `messageDiffers` (returns a boolean, """Boolean negation of `messageDiffers` (returns a boolean,
not an optional string).""" not an optional string)."""
return not self.messageDiffers(msg, **kwargs) return not self.messageDiffers(msg, **kwargs)
def messageDiffers( def messageDiffers(
self, self,
msg, msg: Message,
params=None, params: Optional[List[Any]] = None,
target=None, target: Optional[str] = None,
nick=None, nick: Optional[str] = None,
fail_msg=None, fail_msg: Optional[str] = None,
extra_format=(), extra_format: Tuple = (),
**kwargs, **kwargs: Any,
): ) -> Optional[str]:
"""Returns an error message if the message doesn't match the given arguments, """Returns an error message if the message doesn't match the given arguments,
or None if it matches.""" or None if it matches."""
for (key, value) in kwargs.items(): for (key, value) in kwargs.items():
@ -120,7 +152,7 @@ class _IrcTestCase(unittest.TestCase):
) )
if nick: if nick:
got_nick = msg.prefix.split("!")[0] got_nick = msg.prefix.split("!")[0] if msg.prefix else None
if msg.prefix is None: if msg.prefix is None:
fail_msg = ( fail_msg = (
fail_msg or "expected nick to be {expects}, got {got} prefix: {msg}" fail_msg or "expected nick to be {expects}, got {got} prefix: {msg}"
@ -131,7 +163,7 @@ class _IrcTestCase(unittest.TestCase):
return None return None
def listMatch(self, got, expected): def listMatch(self, got: List[str], expected: List[Any]) -> bool:
"""Returns True iff the list are equal. """Returns True iff the list are equal.
The ellipsis (aka. "..." aka triple dots) can be used on the 'expected' The ellipsis (aka. "..." aka triple dots) can be used on the 'expected'
side as a wildcard, matching any *single* value.""" side as a wildcard, matching any *single* value."""
@ -145,62 +177,124 @@ class _IrcTestCase(unittest.TestCase):
return False return False
return True return True
def assertIn(self, item, list_, msg=None, fail_msg=None, extra_format=()): def assertIn(
self,
member: Any,
container: Union[Iterable[Any], Container[Any]],
msg: Optional[str] = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg: if fail_msg:
fail_msg = fail_msg.format(*extra_format, item=item, list=list_, msg=msg) fail_msg = fail_msg.format(
super().assertIn(item, list_, fail_msg) *extra_format, item=member, list=container, msg=msg
)
super().assertIn(member, container, fail_msg)
def assertNotIn(self, item, list_, msg=None, fail_msg=None, extra_format=()): def assertNotIn(
self,
member: Any,
container: Union[Iterable[Any], Container[Any]],
msg: Optional[str] = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg: if fail_msg:
fail_msg = fail_msg.format(*extra_format, item=item, list=list_, msg=msg) fail_msg = fail_msg.format(
super().assertNotIn(item, list_, fail_msg) *extra_format, item=member, list=container, msg=msg
)
super().assertNotIn(member, container, fail_msg)
def assertEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()): def assertEqual(
self,
got: T,
expects: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg: if fail_msg:
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertEqual(got, expects, fail_msg) super().assertEqual(got, expects, fail_msg)
def assertNotEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()): def assertNotEqual(
self,
got: T,
expects: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg: if fail_msg:
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertNotEqual(got, expects, fail_msg) super().assertNotEqual(got, expects, fail_msg)
def assertGreater(self, got, expects, msg=None, fail_msg=None, extra_format=()): def assertGreater(
self,
got: T,
expects: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg: if fail_msg:
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertGreater(got, expects, fail_msg) super().assertGreater(got, expects, fail_msg)
def assertGreaterEqual( def assertGreaterEqual(
self, got, expects, msg=None, fail_msg=None, extra_format=() self,
): got: T,
expects: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg: if fail_msg:
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertGreaterEqual(got, expects, fail_msg) super().assertGreaterEqual(got, expects, fail_msg)
def assertLess(self, got, expects, msg=None, fail_msg=None, extra_format=()): def assertLess(
self,
got: T,
expects: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg: if fail_msg:
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertLess(got, expects, fail_msg) super().assertLess(got, expects, fail_msg)
def assertLessEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()): def assertLessEqual(
self,
got: T,
expects: T,
msg: Any = None,
fail_msg: Optional[str] = None,
extra_format: Tuple = (),
) -> None:
if fail_msg: if fail_msg:
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
super().assertLessEqual(got, expects, fail_msg) super().assertLessEqual(got, expects, fail_msg)
class BaseClientTestCase(_IrcTestCase): class BaseClientTestCase(_IrcTestCase[basecontrollers.BaseClientController]):
"""Basic class for client tests. Handles spawning a client and exchanging """Basic class for client tests. Handles spawning a client and exchanging
messages with it.""" messages with it."""
nick = None conn: Optional[socket.socket]
user = None nick: Optional[str] = None
user: Optional[List[str]] = None
server: socket.socket
protocol_version = Optional[str]
acked_capabilities = Optional[Set[str]]
def setUp(self): def setUp(self) -> None:
super().setUp() super().setUp()
self.conn = None self.conn = None
self._setUpServer() self._setUpServer()
def tearDown(self): def tearDown(self) -> None:
if self.conn: if self.conn:
try: try:
self.conn.sendall(b"QUIT :end of test.") self.conn.sendall(b"QUIT :end of test.")
@ -214,7 +308,7 @@ class BaseClientTestCase(_IrcTestCase):
self.conn.close() self.conn.close()
self.server.close() self.server.close()
def _setUpServer(self): def _setUpServer(self) -> None:
"""Creates the server and make it listen.""" """Creates the server and make it listen."""
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server.bind(("", 0)) # Bind any free port self.server.bind(("", 0)) # Bind any free port
@ -223,9 +317,15 @@ class BaseClientTestCase(_IrcTestCase):
# Used to check if the client is alive from time to time # Used to check if the client is alive from time to time
self.server.settimeout(1) self.server.settimeout(1)
def acceptClient(self, tls_cert=None, tls_key=None, server=None): def acceptClient(
self,
tls_cert: Optional[str] = None,
tls_key: Optional[str] = None,
server: Optional[socket.socket] = None,
) -> None:
"""Make the server accept a client connection. Blocking.""" """Make the server accept a client connection. Blocking."""
server = server or self.server server = server or self.server
assert server
# Wait for the client to connect # Wait for the client to connect
while True: while True:
try: try:
@ -252,17 +352,17 @@ class BaseClientTestCase(_IrcTestCase):
self.conn = context.wrap_socket(self.conn, server_side=True) self.conn = context.wrap_socket(self.conn, server_side=True)
self.conn_file = self.conn.makefile(newline="\r\n", encoding="utf8") self.conn_file = self.conn.makefile(newline="\r\n", encoding="utf8")
def getLine(self): def getLine(self) -> str:
line = self.conn_file.readline() line = self.conn_file.readline()
if self.show_io: if self.show_io:
print("{:.3f} C: {}".format(time.time(), line.strip())) print("{:.3f} C: {}".format(time.time(), line.strip()))
return line return line
def getMessages(self, *args): def getMessage(
lines = self.getLines(*args) self,
return map(message_parser.parse_message, lines) *args: Any,
filter_pred: Optional[Callable[[Message], bool]] = None,
def getMessage(self, *args, filter_pred=None): ) -> Message:
"""Gets a message and returns it. If a filter predicate is given, """Gets a message and returns it. If a filter predicate is given,
fetches messages until the predicate returns a False on a message, fetches messages until the predicate returns a False on a message,
and returns this message.""" and returns this message."""
@ -274,18 +374,17 @@ class BaseClientTestCase(_IrcTestCase):
if not filter_pred or filter_pred(msg): if not filter_pred or filter_pred(msg):
return msg return msg
def sendLine(self, line): def sendLine(self, line: str) -> None:
assert self.conn
self.conn.sendall(line.encode()) self.conn.sendall(line.encode())
if not line.endswith("\r\n"): if not line.endswith("\r\n"):
self.conn.sendall(b"\r\n") self.conn.sendall(b"\r\n")
if self.show_io: if self.show_io:
print("{:.3f} S: {}".format(time.time(), line.strip())) print("{:.3f} S: {}".format(time.time(), line.strip()))
def readCapLs(
class ClientNegociationHelper: self, auth: Optional[Authentication] = None, tls_config: tls.TlsConfig = None
"""Helper class for tests handling capabilities negociation.""" ) -> None:
def readCapLs(self, auth=None, tls_config=None):
(hostname, port) = self.server.getsockname() (hostname, port) = self.server.getsockname()
self.controller.run( self.controller.run(
hostname=hostname, port=port, auth=auth, tls_config=tls_config hostname=hostname, port=port, auth=auth, tls_config=tls_config
@ -302,28 +401,33 @@ class ClientNegociationHelper:
else: else:
raise AssertionError("Unknown CAP params: {}".format(m.params)) raise AssertionError("Unknown CAP params: {}".format(m.params))
def userNickPredicate(self, msg): def userNickPredicate(self, msg: Message) -> bool:
"""Predicate to be used with getMessage to handle NICK/USER """Predicate to be used with getMessage to handle NICK/USER
transparently.""" transparently."""
if msg.command == "NICK": if msg.command == "NICK":
self.assertEqual(len(msg.params), 1, msg) self.assertEqual(len(msg.params), 1, msg=msg)
self.nick = msg.params[0] self.nick = msg.params[0]
return False return False
elif msg.command == "USER": elif msg.command == "USER":
self.assertEqual(len(msg.params), 4, msg) self.assertEqual(len(msg.params), 4, msg=msg)
self.user = msg.params self.user = msg.params
return False return False
else: else:
return True return True
def negotiateCapabilities(self, caps, cap_ls=True, auth=None): def negotiateCapabilities(
self,
caps: List[str],
cap_ls: bool = True,
auth: Optional[Authentication] = None,
) -> Optional[Message]:
"""Performes a complete capability negociation process, without """Performes a complete capability negociation process, without
ending it, so the caller can continue the negociation.""" ending it, so the caller can continue the negociation."""
if cap_ls: if cap_ls:
self.readCapLs(auth) self.readCapLs(auth)
if not self.protocol_version: if not self.protocol_version:
# No negotiation. # No negotiation.
return return None
self.sendLine("CAP * LS :{}".format(" ".join(caps))) self.sendLine("CAP * LS :{}".format(" ".join(caps)))
capability_names = frozenset(capabilities.cap_list_to_dict(caps)) capability_names = frozenset(capabilities.cap_list_to_dict(caps))
self.acked_capabilities = set() self.acked_capabilities = set()
@ -343,21 +447,25 @@ class ClientNegociationHelper:
self.sendLine( self.sendLine(
"CAP {} ACK :{}".format(self.nick or "*", m.params[1]) "CAP {} ACK :{}".format(self.nick or "*", m.params[1])
) )
self.acked_capabilities.update(requested) self.acked_capabilities.update(requested) # type: ignore
else: else:
return m return m
class BaseServerTestCase(_IrcTestCase): class BaseServerTestCase(
_IrcTestCase[basecontrollers.BaseServerController], Generic[TClientName]
):
"""Basic class for server tests. Handles spawning a server and exchanging """Basic class for server tests. Handles spawning a server and exchanging
messages with it.""" messages with it."""
show_io: bool # set by conftest.py
password: Optional[str] = None password: Optional[str] = None
ssl = False ssl = False
valid_metadata_keys: Set[str] = set() valid_metadata_keys: Set[str] = set()
invalid_metadata_keys: Set[str] = set() invalid_metadata_keys: Set[str] = set()
def setUp(self): def setUp(self) -> None:
super().setUp() super().setUp()
self.server_support = None self.server_support = None
self.find_hostname_and_port() self.find_hostname_and_port()
@ -369,53 +477,64 @@ class BaseServerTestCase(_IrcTestCase):
invalid_metadata_keys=self.invalid_metadata_keys, invalid_metadata_keys=self.invalid_metadata_keys,
ssl=self.ssl, ssl=self.ssl,
) )
self.clients = {} self.clients: Dict[TClientName, client_mock.ClientMock] = {}
def tearDown(self): def tearDown(self) -> None:
self.controller.kill() self.controller.kill()
for client in list(self.clients): for client in list(self.clients):
self.removeClient(client) self.removeClient(client)
def find_hostname_and_port(self): def find_hostname_and_port(self) -> None:
"""Find available hostname/port to listen on.""" """Find available hostname/port to listen on."""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0)) s.bind(("", 0))
(self.hostname, self.port) = s.getsockname() (self.hostname, self.port) = s.getsockname()
s.close() s.close()
def addClient(self, name=None, show_io=None): def addClient(
self, name: Optional[TClientName] = None, show_io: Optional[bool] = None
) -> TClientName:
"""Connects a client to the server and adds it to the dict. """Connects a client to the server and adds it to the dict.
If 'name' is not given, uses the lowest unused non-negative integer.""" If 'name' is not given, uses the lowest unused non-negative integer."""
self.controller.wait_for_port() self.controller.wait_for_port()
if not name: if not name:
name = max(map(int, list(self.clients) + [0])) + 1 new_name: int = (
max(
[int(name) for name in self.clients if isinstance(name, (int, str))]
+ [0]
)
+ 1
)
name = cast(TClientName, new_name)
show_io = show_io if show_io is not None else self.show_io show_io = show_io if show_io is not None else self.show_io
self.clients[name] = client_mock.ClientMock(name=name, show_io=show_io) self.clients[name] = client_mock.ClientMock(name=name, show_io=show_io)
self.clients[name].connect(self.hostname, self.port) self.clients[name].connect(self.hostname, self.port)
return name return name
def removeClient(self, name): def removeClient(self, name: TClientName) -> None:
"""Disconnects the client, without QUIT.""" """Disconnects the client, without QUIT."""
assert name in self.clients assert name in self.clients
self.clients[name].disconnect() self.clients[name].disconnect()
del self.clients[name] del self.clients[name]
def getMessages(self, client, **kwargs): def getMessages(self, client: TClientName, **kwargs: Any) -> List[Message]:
return self.clients[client].getMessages(**kwargs) return self.clients[client].getMessages(**kwargs)
def getMessage(self, client, **kwargs): def getMessage(self, client: TClientName, **kwargs: Any) -> Message:
return self.clients[client].getMessage(**kwargs) return self.clients[client].getMessage(**kwargs)
def getRegistrationMessage(self, client): def getRegistrationMessage(self, client: TClientName) -> Message:
"""Filter notices, do not send pings.""" """Filter notices, do not send pings."""
return self.getMessage( return self.getMessage(
client, synchronize=False, filter_pred=lambda m: m.command != "NOTICE" client, synchronize=False, filter_pred=lambda m: m.command != "NOTICE"
) )
def sendLine(self, client, line): def sendLine(self, client: TClientName, line: Union[str, bytes]) -> None:
return self.clients[client].sendLine(line) return self.clients[client].sendLine(line)
def getCapLs(self, client, as_list=False): def getCapLs(
self, client: TClientName, as_list: bool = False
) -> Union[List[str], Dict[str, Optional[str]]]:
"""Waits for a CAP LS block, parses all CAP LS messages, and return """Waits for a CAP LS block, parses all CAP LS messages, and return
the dict capabilities, with their values. the dict capabilities, with their values.
@ -431,10 +550,10 @@ class BaseServerTestCase(_IrcTestCase):
else: else:
caps.extend(m.params[2].split()) caps.extend(m.params[2].split())
if not as_list: if not as_list:
caps = capabilities.cap_list_to_dict(caps) return capabilities.cap_list_to_dict(caps)
return caps return caps
def assertDisconnected(self, client): def assertDisconnected(self, client: TClientName) -> None:
try: try:
self.getMessages(client) self.getMessages(client)
self.getMessages(client) self.getMessages(client)
@ -444,7 +563,7 @@ class BaseServerTestCase(_IrcTestCase):
else: else:
raise AssertionError("Client not disconnected.") raise AssertionError("Client not disconnected.")
def skipToWelcome(self, client): def skipToWelcome(self, client: TClientName) -> List[Message]:
"""Skip to the point where we are registered """Skip to the point where we are registered
<https://tools.ietf.org/html/rfc2812#section-3.1> <https://tools.ietf.org/html/rfc2812#section-3.1>
""" """
@ -457,15 +576,19 @@ class BaseServerTestCase(_IrcTestCase):
def connectClient( def connectClient(
self, self,
nick, nick: str,
name=None, name: TClientName = None,
capabilities=None, capabilities: Optional[List[str]] = None,
skip_if_cap_nak=False, skip_if_cap_nak: bool = False,
show_io=None, show_io: Optional[bool] = None,
account=None, account: Optional[str] = None,
password=None, password: Optional[str] = None,
ident="username", ident: str = "username",
): ) -> List[Message]:
"""Connections a new client, does the cap negotiation
and connection registration, and skips to the end of the MOTD.
Returns the list of all messages received after registration,
just like `skipToWelcome`."""
client = self.addClient(name, show_io=show_io) client = self.addClient(name, show_io=show_io)
if capabilities is not None and 0 < len(capabilities): if capabilities is not None and 0 < len(capabilities):
self.sendLine(client, "CAP REQ :{}".format(" ".join(capabilities))) self.sendLine(client, "CAP REQ :{}".format(" ".join(capabilities)))
@ -502,14 +625,14 @@ class BaseServerTestCase(_IrcTestCase):
for param in m.params[1:-1]: for param in m.params[1:-1]:
if "=" in param: if "=" in param:
(key, value) = param.split("=") (key, value) = param.split("=")
self.server_support[key] = value
else: else:
(key, value) = (param, None) self.server_support[param] = None
self.server_support[key] = value
welcome.append(m) welcome.append(m)
return welcome return welcome
def joinClient(self, client, channel): def joinClient(self, client: TClientName, channel: str) -> None:
self.sendLine(client, "JOIN {}".format(channel)) self.sendLine(client, "JOIN {}".format(channel))
received = {m.command for m in self.getMessages(client)} received = {m.command for m in self.getMessages(client)}
self.assertIn( self.assertIn(
@ -520,7 +643,7 @@ class BaseServerTestCase(_IrcTestCase):
extra_format=(channel,), extra_format=(channel,),
) )
def joinChannel(self, client, channel): def joinChannel(self, client: TClientName, channel: str) -> None:
self.sendLine(client, "JOIN {}".format(channel)) self.sendLine(client, "JOIN {}".format(channel))
# wait until we see them join the channel # wait until we see them join the channel
joined = False joined = False
@ -537,24 +660,34 @@ class BaseServerTestCase(_IrcTestCase):
raise ChannelJoinException(msg.command, msg.params) raise ChannelJoinException(msg.command, msg.params)
class OptionalityHelper: _TSelf = TypeVar("_TSelf", bound="OptionalityHelper")
controller: basecontrollers.BaseServerController _TReturn = TypeVar("_TReturn")
def checkSaslSupport(self):
class OptionalityHelper(Generic[TController]):
controller: TController
def checkSaslSupport(self) -> None:
if self.controller.supported_sasl_mechanisms: if self.controller.supported_sasl_mechanisms:
return return
raise runner.NotImplementedByController("SASL") raise runner.NotImplementedByController("SASL")
def checkMechanismSupport(self, mechanism): def checkMechanismSupport(self, mechanism: str) -> None:
if mechanism in self.controller.supported_sasl_mechanisms: if mechanism in self.controller.supported_sasl_mechanisms:
return return
raise runner.OptionalSaslMechanismNotSupported(mechanism) raise runner.OptionalSaslMechanismNotSupported(mechanism)
@staticmethod @staticmethod
def skipUnlessHasMechanism(mech): def skipUnlessHasMechanism(
def decorator(f): mech: str,
) -> Callable[[Callable[[_TSelf], _TReturn]], Callable[[_TSelf], _TReturn]]:
# Just a function returning a function that takes functions and
# returns functions, nothing to see here.
# If Python didn't have such an awful syntax for callables, it would be:
# str -> ((TSelf -> TReturn) -> (TSelf -> TReturn))
def decorator(f: Callable[[_TSelf], _TReturn]) -> Callable[[_TSelf], _TReturn]:
@functools.wraps(f) @functools.wraps(f)
def newf(self): def newf(self: _TSelf) -> _TReturn:
self.checkMechanismSupport(mech) self.checkMechanismSupport(mech)
return f(self) return f(self)
@ -562,23 +695,29 @@ class OptionalityHelper:
return decorator return decorator
def skipUnlessHasSasl(f): @staticmethod
def skipUnlessHasSasl(
f: Callable[[_TSelf], _TReturn]
) -> Callable[[_TSelf], _TReturn]:
@functools.wraps(f) @functools.wraps(f)
def newf(self): def newf(self: _TSelf) -> _TReturn:
self.checkSaslSupport() self.checkSaslSupport()
return f(self) return f(self)
return newf return newf
def mark_specifications(*specifications, deprecated=False, strict=False): def mark_specifications(
*specifications_str: str, deprecated: bool = False, strict: bool = False
) -> Callable[[TCallable], TCallable]:
specifications = frozenset( specifications = frozenset(
Specifications.from_name(s) if isinstance(s, str) else s for s in specifications Specifications.from_name(s) if isinstance(s, str) else s
for s in specifications_str
) )
if None in specifications: if None in specifications:
raise ValueError("Invalid set of specifications: {}".format(specifications)) raise ValueError("Invalid set of specifications: {}".format(specifications))
def decorator(f): def decorator(f: TCallable) -> TCallable:
for specification in specifications: for specification in specifications:
f = getattr(pytest.mark, specification.value)(f) f = getattr(pytest.mark, specification.value)(f)
if strict: if strict:
@ -590,14 +729,16 @@ def mark_specifications(*specifications, deprecated=False, strict=False):
return decorator return decorator
def mark_capabilities(*capabilities, deprecated=False, strict=False): def mark_capabilities(
*capabilities_str: str, deprecated: bool = False, strict: bool = False
) -> Callable[[TCallable], TCallable]:
capabilities = frozenset( capabilities = frozenset(
Capabilities.from_name(c) if isinstance(c, str) else c for c in capabilities Capabilities.from_name(c) if isinstance(c, str) else c for c in capabilities_str
) )
if None in capabilities: if None in capabilities:
raise ValueError("Invalid set of capabilities: {}".format(capabilities)) raise ValueError("Invalid set of capabilities: {}".format(capabilities))
def decorator(f): def decorator(f: TCallable) -> TCallable:
for capability in capabilities: for capability in capabilities:
f = getattr(pytest.mark, capability.value)(f) f = getattr(pytest.mark, capability.value)(f)
# Support for any capability implies IRCv3 # Support for any capability implies IRCv3
@ -607,14 +748,16 @@ def mark_capabilities(*capabilities, deprecated=False, strict=False):
return decorator return decorator
def mark_isupport(*tokens, deprecated=False, strict=False): def mark_isupport(
*tokens_str: str, deprecated: bool = False, strict: bool = False
) -> Callable[[TCallable], TCallable]:
tokens = frozenset( tokens = frozenset(
IsupportTokens.from_name(c) if isinstance(c, str) else c for c in tokens IsupportTokens.from_name(c) if isinstance(c, str) else c for c in tokens_str
) )
if None in tokens: if None in tokens:
raise ValueError("Invalid set of isupport tokens: {}".format(tokens)) raise ValueError("Invalid set of isupport tokens: {}".format(tokens))
def decorator(f): def decorator(f: TCallable) -> TCallable:
for token in tokens: for token in tokens:
f = getattr(pytest.mark, token.value)(f) f = getattr(pytest.mark, token.value)(f)
return f return f

View File

@ -2,36 +2,41 @@ import socket
import ssl import ssl
import sys import sys
import time import time
from typing import Any, Callable, List, Optional, Union
from .exceptions import ConnectionClosed, NoMessageException from .exceptions import ConnectionClosed, NoMessageException
from .irc_utils import message_parser from .irc_utils import message_parser
class ClientMock: class ClientMock:
def __init__(self, name, show_io): def __init__(self, name: Any, show_io: bool):
self.name = name self.name = name
self.show_io = show_io self.show_io = show_io
self.inbuffer = [] self.inbuffer: List[message_parser.Message] = []
self.ssl = False self.ssl = False
def connect(self, hostname, port): def connect(self, hostname: str, port: int) -> None:
self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.conn.settimeout(1) # TODO: configurable self.conn.settimeout(1) # TODO: configurable
self.conn.connect((hostname, port)) self.conn.connect((hostname, port))
if self.show_io: if self.show_io:
print("{:.3f} {}: connects to server.".format(time.time(), self.name)) print("{:.3f} {}: connects to server.".format(time.time(), self.name))
def disconnect(self): def disconnect(self) -> None:
if self.show_io: if self.show_io:
print("{:.3f} {}: disconnects from server.".format(time.time(), self.name)) print("{:.3f} {}: disconnects from server.".format(time.time(), self.name))
self.conn.close() self.conn.close()
def starttls(self): def starttls(self) -> None:
assert not self.ssl, "SSL already active." assert not self.ssl, "SSL already active."
self.conn = ssl.wrap_socket(self.conn) self.conn = ssl.wrap_socket(self.conn)
self.ssl = True self.ssl = True
def getMessages(self, synchronize=True, assert_get_one=False, raw=False): def getMessages(
self, synchronize: bool = True, assert_get_one: bool = False, raw: bool = False
) -> List[message_parser.Message]:
"""actually returns List[str] in the rare case where raw=True."""
token: Optional[str]
if synchronize: if synchronize:
token = "synchronize{}".format(time.monotonic()) token = "synchronize{}".format(time.monotonic())
self.sendLine("PING {}".format(token)) self.sendLine("PING {}".format(token))
@ -79,7 +84,7 @@ class ClientMock:
got_pong = True got_pong = True
else: else:
if raw: if raw:
messages.append(line) messages.append(line) # type: ignore
else: else:
messages.append(message) messages.append(message)
data = b"" data = b""
@ -91,7 +96,13 @@ class ClientMock:
else: else:
return messages return messages
def getMessage(self, filter_pred=None, synchronize=True, raw=False): def getMessage(
self,
filter_pred: Optional[Callable[[message_parser.Message], bool]] = None,
synchronize: bool = True,
raw: bool = False,
) -> message_parser.Message:
"""Returns str in the rare case where raw=True"""
while True: while True:
if not self.inbuffer: if not self.inbuffer:
self.inbuffer = self.getMessages( self.inbuffer = self.getMessages(
@ -103,7 +114,7 @@ class ClientMock:
if not filter_pred or filter_pred(message): if not filter_pred or filter_pred(message):
return message return message
def sendLine(self, line): def sendLine(self, line: Union[str, bytes]) -> None:
if isinstance(line, str): if isinstance(line, str):
encoded_line = line.encode() encoded_line = line.encode()
elif isinstance(line, bytes): elif isinstance(line, bytes):
@ -113,7 +124,7 @@ class ClientMock:
if not encoded_line.endswith(b"\r\n"): if not encoded_line.endswith(b"\r\n"):
encoded_line += b"\r\n" encoded_line += b"\r\n"
try: try:
ret = self.conn.sendall(encoded_line) ret = self.conn.sendall(encoded_line) # type: ignore
except BrokenPipeError: except BrokenPipeError:
raise ConnectionClosed() raise ConnectionClosed()
if ( if (

View File

@ -2,7 +2,7 @@ from irctest import cases
from irctest.irc_utils.message_parser import Message from irctest.irc_utils.message_parser import Message
class CapTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper): class CapTestCase(cases.BaseClientTestCase):
@cases.mark_specifications("IRCv3") @cases.mark_specifications("IRCv3")
def testSendCap(self): def testSendCap(self):
"""Send CAP LS 302 and read the result.""" """Send CAP LS 302 and read the result."""

View File

@ -39,9 +39,7 @@ class IdentityHash:
return self._data return self._data
class SaslTestCase( class SaslTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
cases.BaseClientTestCase, cases.ClientNegociationHelper, cases.OptionalityHelper
):
@cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPlain(self): def testPlain(self):
"""Test PLAIN authentication with correct username/password.""" """Test PLAIN authentication with correct username/password."""
@ -263,9 +261,7 @@ class SaslTestCase(
authenticator.response(msg) authenticator.response(msg)
class Irc302SaslTestCase( class Irc302SaslTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
cases.BaseClientTestCase, cases.ClientNegociationHelper, cases.OptionalityHelper
):
@cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
def testPlainNotAvailable(self): def testPlainNotAvailable(self):
"""Test the client does not try to authenticate using a mechanism the """Test the client does not try to authenticate using a mechanism the

View File

@ -1,6 +1,6 @@
import os import os
import subprocess import subprocess
from typing import Set from typing import Optional, Set, Type
from irctest.basecontrollers import ( from irctest.basecontrollers import (
BaseServerController, BaseServerController,
@ -47,20 +47,21 @@ class CharybdisController(BaseServerController, DirectoryBasedController):
supported_sasl_mechanisms: Set[str] = set() supported_sasl_mechanisms: Set[str] = set()
supports_sts = False supports_sts = False
def create_config(self): def create_config(self) -> None:
super().create_config() super().create_config()
with self.open_file("server.conf"): with self.open_file("server.conf"):
pass pass
def run( def run(
self, self,
hostname, hostname: str,
port, port: int,
password=None, *,
ssl=False, password: Optional[str],
valid_metadata_keys=None, ssl: bool,
invalid_metadata_keys=None, valid_metadata_keys: Optional[Set[str]] = None,
): invalid_metadata_keys: Optional[Set[str]] = None,
) -> None:
if valid_metadata_keys or invalid_metadata_keys: if valid_metadata_keys or invalid_metadata_keys:
raise NotImplementedByController( raise NotImplementedByController(
"Defining valid and invalid METADATA keys." "Defining valid and invalid METADATA keys."
@ -85,6 +86,7 @@ class CharybdisController(BaseServerController, DirectoryBasedController):
ssl_config=ssl_config, ssl_config=ssl_config,
) )
) )
assert self.directory
self.proc = subprocess.Popen( self.proc = subprocess.Popen(
[ [
self.binary_name, self.binary_name,
@ -98,5 +100,5 @@ class CharybdisController(BaseServerController, DirectoryBasedController):
) )
def get_irctest_controller_class(): def get_irctest_controller_class() -> Type[CharybdisController]:
return CharybdisController return CharybdisController

View File

@ -1,33 +1,25 @@
import subprocess import subprocess
from typing import Optional, Type
from irctest.basecontrollers import BaseClientController, NotImplementedByController from irctest import authentication, tls
from irctest.basecontrollers import (
BaseClientController,
DirectoryBasedController,
NotImplementedByController,
)
class GircController(BaseClientController): class GircController(BaseClientController, DirectoryBasedController):
software_name = "gIRC" software_name = "gIRC"
supported_sasl_mechanisms = ["PLAIN"] supported_sasl_mechanisms = {"PLAIN"}
def __init__(self): def run(
super().__init__() self,
self.directory = None hostname: str,
self.proc = None port: int,
auth: Optional[authentication.Authentication],
def kill(self): tls_config: Optional[tls.TlsConfig] = None,
if self.proc: ) -> None:
self.proc.terminate()
try:
self.proc.wait(5)
except subprocess.TimeoutExpired:
self.proc.kill()
self.proc = None
def __del__(self):
if self.proc:
self.proc.kill()
if self.directory:
self.directory.cleanup()
def run(self, hostname, port, auth, tls_config):
if tls_config: if tls_config:
print(tls_config) print(tls_config)
raise NotImplementedByController("TLS options") raise NotImplementedByController("TLS options")
@ -42,5 +34,5 @@ class GircController(BaseClientController):
self.proc = subprocess.Popen(["girc_test", "connect"] + args) self.proc = subprocess.Popen(["girc_test", "connect"] + args)
def get_irctest_controller_class(): def get_irctest_controller_class() -> Type[GircController]:
return GircController return GircController

View File

@ -1,6 +1,6 @@
import os import os
import subprocess import subprocess
from typing import Set from typing import Optional, Set, Type
from irctest.basecontrollers import ( from irctest.basecontrollers import (
BaseServerController, BaseServerController,
@ -44,20 +44,21 @@ class HybridController(BaseServerController, DirectoryBasedController):
supports_sts = False supports_sts = False
supported_sasl_mechanisms: Set[str] = set() supported_sasl_mechanisms: Set[str] = set()
def create_config(self): def create_config(self) -> None:
super().create_config() super().create_config()
with self.open_file("server.conf"): with self.open_file("server.conf"):
pass pass
def run( def run(
self, self,
hostname, hostname: str,
port, port: int,
password=None, *,
ssl=False, password: Optional[str],
valid_metadata_keys=None, ssl: bool,
invalid_metadata_keys=None, valid_metadata_keys: Optional[Set[str]] = None,
): invalid_metadata_keys: Optional[Set[str]] = None,
) -> None:
if valid_metadata_keys or invalid_metadata_keys: if valid_metadata_keys or invalid_metadata_keys:
raise NotImplementedByController( raise NotImplementedByController(
"Defining valid and invalid METADATA keys." "Defining valid and invalid METADATA keys."
@ -82,6 +83,7 @@ class HybridController(BaseServerController, DirectoryBasedController):
ssl_config=ssl_config, ssl_config=ssl_config,
) )
) )
assert self.directory
self.proc = subprocess.Popen( self.proc = subprocess.Popen(
[ [
"ircd", "ircd",
@ -96,5 +98,5 @@ class HybridController(BaseServerController, DirectoryBasedController):
) )
def get_irctest_controller_class(): def get_irctest_controller_class() -> Type[HybridController]:
return HybridController return HybridController

View File

@ -1,6 +1,6 @@
import os import os
import subprocess import subprocess
from typing import Set from typing import Optional, Set, Type
from irctest.basecontrollers import ( from irctest.basecontrollers import (
BaseServerController, BaseServerController,
@ -42,21 +42,22 @@ class InspircdController(BaseServerController, DirectoryBasedController):
supported_sasl_mechanisms: Set[str] = set() supported_sasl_mechanisms: Set[str] = set()
supports_str = False supports_str = False
def create_config(self): def create_config(self) -> None:
super().create_config() super().create_config()
with self.open_file("server.conf"): with self.open_file("server.conf"):
pass pass
def run( def run(
self, self,
hostname, hostname: str,
port, port: int,
password=None, *,
ssl=False, password: Optional[str],
restricted_metadata_keys=None, ssl: bool,
valid_metadata_keys=None, valid_metadata_keys: Optional[Set[str]] = None,
invalid_metadata_keys=None, invalid_metadata_keys: Optional[Set[str]] = None,
): restricted_metadata_keys: Optional[Set[str]] = None,
) -> None:
if valid_metadata_keys or invalid_metadata_keys: if valid_metadata_keys or invalid_metadata_keys:
raise NotImplementedByController( raise NotImplementedByController(
"Defining valid and invalid METADATA keys." "Defining valid and invalid METADATA keys."
@ -81,6 +82,7 @@ class InspircdController(BaseServerController, DirectoryBasedController):
ssl_config=ssl_config, ssl_config=ssl_config,
) )
) )
assert self.directory
self.proc = subprocess.Popen( self.proc = subprocess.Popen(
[ [
"inspircd", "inspircd",
@ -92,5 +94,5 @@ class InspircdController(BaseServerController, DirectoryBasedController):
) )
def get_irctest_controller_class(): def get_irctest_controller_class() -> Type[InspircdController]:
return InspircdController return InspircdController

View File

@ -1,3 +1,5 @@
from typing import Type
from .charybdis import CharybdisController from .charybdis import CharybdisController
@ -6,5 +8,5 @@ class IrcdSevenController(CharybdisController):
binary_name = "ircd-seven" binary_name = "ircd-seven"
def get_irctest_controller_class(): def get_irctest_controller_class() -> Type[IrcdSevenController]:
return IrcdSevenController return IrcdSevenController

View File

@ -1,7 +1,8 @@
import os import os
import subprocess import subprocess
from typing import Optional, Type
from irctest import tls from irctest import authentication, tls
from irctest.basecontrollers import BaseClientController, DirectoryBasedController from irctest.basecontrollers import BaseClientController, DirectoryBasedController
TEMPLATE_CONFIG = """ TEMPLATE_CONFIG = """
@ -35,15 +36,20 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
} }
supports_sts = True supports_sts = True
def create_config(self): def create_config(self) -> None:
create_config = super().create_config() super().create_config()
if create_config: with self.open_file("bot.conf"):
with self.open_file("bot.conf"): pass
pass with self.open_file("conf/users.conf"):
with self.open_file("conf/users.conf"): pass
pass
def run(self, hostname, port, auth, tls_config=None): def run(
self,
hostname: str,
port: int,
auth: Optional[authentication.Authentication],
tls_config: Optional[tls.TlsConfig] = None,
) -> None:
if tls_config is None: if tls_config is None:
tls_config = tls.TlsConfig(enable=False, trusted_fingerprints=[]) tls_config = tls.TlsConfig(enable=False, trusted_fingerprints=[])
# Runs a client with the config given as arguments # Runs a client with the config given as arguments
@ -72,10 +78,11 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
else "", else "",
) )
) )
assert self.directory
self.proc = subprocess.Popen( self.proc = subprocess.Popen(
["supybot", os.path.join(self.directory, "bot.conf")] ["supybot", os.path.join(self.directory, "bot.conf")]
) )
def get_irctest_controller_class(): def get_irctest_controller_class() -> Type[LimnoriaController]:
return LimnoriaController return LimnoriaController

View File

@ -1,11 +1,13 @@
import os import os
import subprocess import subprocess
from typing import Optional, Set, Type
from irctest.basecontrollers import ( from irctest.basecontrollers import (
BaseServerController, BaseServerController,
DirectoryBasedController, DirectoryBasedController,
NotImplementedByController, NotImplementedByController,
) )
from irctest.cases import BaseServerTestCase
TEMPLATE_CONFIG = """ TEMPLATE_CONFIG = """
clients: clients:
@ -61,7 +63,7 @@ server:
""" """
def make_list(list_): def make_list(list_: Set[str]) -> str:
return "\n".join(map(" - {}".format, list_)) return "\n".join(map(" - {}".format, list_))
@ -69,25 +71,27 @@ class MammonController(BaseServerController, DirectoryBasedController):
software_name = "Mammon" software_name = "Mammon"
supported_sasl_mechanisms = {"PLAIN", "ECDSA-NIST256P-CHALLENGE"} supported_sasl_mechanisms = {"PLAIN", "ECDSA-NIST256P-CHALLENGE"}
def create_config(self): def create_config(self) -> None:
super().create_config() super().create_config()
with self.open_file("server.conf"): with self.open_file("server.conf"):
pass pass
def kill_proc(self): def kill_proc(self) -> None:
# Mammon does not seem to handle SIGTERM very well # Mammon does not seem to handle SIGTERM very well
assert self.proc
self.proc.kill() self.proc.kill()
def run( def run(
self, self,
hostname, hostname: str,
port, port: int,
password=None, *,
ssl=False, password: Optional[str],
restricted_metadata_keys=(), ssl: bool,
valid_metadata_keys=(), valid_metadata_keys: Optional[Set[str]] = None,
invalid_metadata_keys=(), invalid_metadata_keys: Optional[Set[str]] = None,
): restricted_metadata_keys: Optional[Set[str]] = None,
) -> None:
if password is not None: if password is not None:
raise NotImplementedByController("PASS command") raise NotImplementedByController("PASS command")
if ssl: if ssl:
@ -101,12 +105,13 @@ class MammonController(BaseServerController, DirectoryBasedController):
directory=self.directory, directory=self.directory,
hostname=hostname, hostname=hostname,
port=port, port=port,
authorized_keys=make_list(valid_metadata_keys), authorized_keys=make_list(valid_metadata_keys or set()),
restricted_keys=make_list(restricted_metadata_keys), restricted_keys=make_list(restricted_metadata_keys or set()),
) )
) )
# with self.open_file('server.yml', 'r') as fd: # with self.open_file('server.yml', 'r') as fd:
# print(fd.read()) # print(fd.read())
assert self.directory
self.proc = subprocess.Popen( self.proc = subprocess.Popen(
[ [
"mammond", "mammond",
@ -116,7 +121,12 @@ class MammonController(BaseServerController, DirectoryBasedController):
] ]
) )
def registerUser(self, case, username, password=None): def registerUser(
self,
case: BaseServerTestCase,
username: str,
password: Optional[str] = None,
) -> None:
# XXX: Move this somewhere else when # XXX: Move this somewhere else when
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes # https://github.com/ircv3/ircv3-specifications/pull/152 becomes
# part of the specification # part of the specification
@ -135,5 +145,5 @@ class MammonController(BaseServerController, DirectoryBasedController):
case.removeClient(client) case.removeClient(client)
def get_irctest_controller_class(): def get_irctest_controller_class() -> Type[MammonController]:
return MammonController return MammonController

View File

@ -2,12 +2,14 @@ import copy
import json import json
import os import os
import subprocess import subprocess
from typing import Any, Dict, Optional, Set, Type, Union
from irctest.basecontrollers import ( from irctest.basecontrollers import (
BaseServerController, BaseServerController,
DirectoryBasedController, DirectoryBasedController,
NotImplementedByController, NotImplementedByController,
) )
from irctest.cases import BaseServerTestCase
OPER_PWD = "frenchfries" OPER_PWD = "frenchfries"
@ -116,7 +118,7 @@ BASE_CONFIG = {
LOGGING_CONFIG = {"logging": [{"method": "stderr", "level": "debug", "type": "*"}]} LOGGING_CONFIG = {"logging": [{"method": "stderr", "level": "debug", "type": "*"}]}
def hash_password(password): def hash_password(password: Union[str, bytes]) -> str:
if isinstance(password, str): if isinstance(password, str):
password = password.encode("utf-8") password = password.encode("utf-8")
# simulate entry of password and confirmation: # simulate entry of password and confirmation:
@ -134,25 +136,23 @@ class OragonoController(BaseServerController, DirectoryBasedController):
supported_sasl_mechanisms = {"PLAIN"} supported_sasl_mechanisms = {"PLAIN"}
supports_sts = True supports_sts = True
def create_config(self): def create_config(self) -> None:
super().create_config() super().create_config()
with self.open_file("ircd.yaml"): with self.open_file("ircd.yaml"):
pass pass
def kill_proc(self):
self.proc.kill()
def run( def run(
self, self,
hostname, hostname: str,
port, port: int,
password=None, *,
ssl=False, password: Optional[str],
restricted_metadata_keys=None, ssl: bool,
valid_metadata_keys=None, valid_metadata_keys: Optional[Set[str]] = None,
invalid_metadata_keys=None, invalid_metadata_keys: Optional[Set[str]] = None,
config=None, restricted_metadata_keys: Optional[Set[str]] = None,
): config: Optional[Any] = None,
) -> None:
if valid_metadata_keys or invalid_metadata_keys: if valid_metadata_keys or invalid_metadata_keys:
raise NotImplementedByController( raise NotImplementedByController(
"Defining valid and invalid METADATA keys." "Defining valid and invalid METADATA keys."
@ -162,6 +162,8 @@ class OragonoController(BaseServerController, DirectoryBasedController):
if config is None: if config is None:
config = copy.deepcopy(BASE_CONFIG) config = copy.deepcopy(BASE_CONFIG)
assert self.directory
enable_chathistory = self.test_config.chathistory enable_chathistory = self.test_config.chathistory
enable_roleplay = self.test_config.oragono_roleplay enable_roleplay = self.test_config.oragono_roleplay
if enable_chathistory or enable_roleplay: if enable_chathistory or enable_roleplay:
@ -180,12 +182,14 @@ class OragonoController(BaseServerController, DirectoryBasedController):
self.key_path = os.path.join(self.directory, "ssl.key") self.key_path = os.path.join(self.directory, "ssl.key")
self.pem_path = os.path.join(self.directory, "ssl.pem") self.pem_path = os.path.join(self.directory, "ssl.pem")
listener_conf = {"tls": {"cert": self.pem_path, "key": self.key_path}} listener_conf = {"tls": {"cert": self.pem_path, "key": self.key_path}}
config["server"]["listeners"][bind_address] = listener_conf config["server"]["listeners"][bind_address] = listener_conf # type: ignore
config["datastore"]["path"] = os.path.join(self.directory, "ircd.db") config["datastore"]["path"] = os.path.join( # type: ignore
self.directory, "ircd.db"
)
if password is not None: if password is not None:
config["server"]["password"] = hash_password(password) config["server"]["password"] = hash_password(password) # type: ignore
assert self.proc is None assert self.proc is None
@ -198,7 +202,12 @@ class OragonoController(BaseServerController, DirectoryBasedController):
["oragono", "run", "--conf", self._config_path, "--quiet"] ["oragono", "run", "--conf", self._config_path, "--quiet"]
) )
def registerUser(self, case, username, password=None): def registerUser(
self,
case: BaseServerTestCase,
username: str,
password: Optional[str] = None,
) -> None:
# XXX: Move this somewhere else when # XXX: Move this somewhere else when
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes # https://github.com/ircv3/ircv3-specifications/pull/152 becomes
# part of the specification # part of the specification
@ -210,34 +219,35 @@ class OragonoController(BaseServerController, DirectoryBasedController):
while case.getRegistrationMessage(client).command != "001": while case.getRegistrationMessage(client).command != "001":
pass pass
case.getMessages(client) case.getMessages(client)
assert password
case.sendLine(client, "NS REGISTER " + password) case.sendLine(client, "NS REGISTER " + password)
msg = case.getMessage(client) msg = case.getMessage(client)
assert msg.params == [username, "Account created"] assert msg.params == [username, "Account created"]
case.sendLine(client, "QUIT") case.sendLine(client, "QUIT")
case.assertDisconnected(client) case.assertDisconnected(client)
def _write_config(self): def _write_config(self) -> None:
with open(self._config_path, "w") as fd: with open(self._config_path, "w") as fd:
json.dump(self._config, fd) json.dump(self._config, fd)
def baseConfig(self): def baseConfig(self) -> Dict:
return copy.deepcopy(BASE_CONFIG) return copy.deepcopy(BASE_CONFIG)
def getConfig(self): def getConfig(self) -> Dict:
return copy.deepcopy(self._config) return copy.deepcopy(self._config)
def addLoggingToConfig(self, config=None): def addLoggingToConfig(self, config: Optional[Dict] = None) -> Dict:
if config is None: if config is None:
config = self.baseConfig() config = self.baseConfig()
config.update(LOGGING_CONFIG) config.update(LOGGING_CONFIG)
return config return config
def addMysqlToConfig(self, config=None): def addMysqlToConfig(self, config: Optional[Dict] = None) -> Dict:
mysql_password = os.getenv("MYSQL_PASSWORD") mysql_password = os.getenv("MYSQL_PASSWORD")
if not mysql_password:
return config
if config is None: if config is None:
config = self.baseConfig() config = self.baseConfig()
if not mysql_password:
return config
config["datastore"]["mysql"] = { config["datastore"]["mysql"] = {
"enabled": True, "enabled": True,
"host": "localhost", "host": "localhost",
@ -259,7 +269,7 @@ class OragonoController(BaseServerController, DirectoryBasedController):
} }
return config return config
def rehash(self, case, config): def rehash(self, case: BaseServerTestCase, config: Dict) -> None:
self._config = config self._config = config
self._write_config() self._write_config()
client = "operator_for_rehash" client = "operator_for_rehash"
@ -270,11 +280,11 @@ class OragonoController(BaseServerController, DirectoryBasedController):
case.sendLine(client, "QUIT") case.sendLine(client, "QUIT")
case.assertDisconnected(client) case.assertDisconnected(client)
def enable_debug_logging(self, case): def enable_debug_logging(self, case: BaseServerTestCase) -> None:
config = self.getConfig() config = self.getConfig()
config.update(LOGGING_CONFIG) config.update(LOGGING_CONFIG)
self.rehash(case, config) self.rehash(case, config)
def get_irctest_controller_class(): def get_irctest_controller_class() -> Type[OragonoController]:
return OragonoController return OragonoController

View File

@ -1,3 +1,5 @@
from typing import Type
from .charybdis import CharybdisController from .charybdis import CharybdisController
@ -6,5 +8,5 @@ class SolanumController(CharybdisController):
binary_name = "solanum" binary_name = "solanum"
def get_irctest_controller_class(): def get_irctest_controller_class() -> Type[SolanumController]:
return SolanumController return SolanumController

View File

@ -1,8 +1,14 @@
import os import os
import subprocess import subprocess
import tempfile import tempfile
from typing import Optional, TextIO, Type, cast
from irctest.basecontrollers import BaseClientController, NotImplementedByController from irctest import authentication, tls
from irctest.basecontrollers import (
BaseClientController,
NotImplementedByController,
TestCaseControllerConfig,
)
TEMPLATE_CONFIG = """ TEMPLATE_CONFIG = """
[core] [core]
@ -24,30 +30,34 @@ class SopelController(BaseClientController):
supported_sasl_mechanisms = {"PLAIN"} supported_sasl_mechanisms = {"PLAIN"}
supports_sts = False supports_sts = False
def __init__(self, test_config): def __init__(self, test_config: TestCaseControllerConfig):
super().__init__(test_config) super().__init__(test_config)
self.filename = next(tempfile._get_candidate_names()) + ".cfg" self.filename = next(tempfile._get_candidate_names()) + ".cfg" # type: ignore
self.proc = None
def kill(self): def kill(self) -> None:
if self.proc: super().kill()
self.proc.kill()
if self.filename: if self.filename:
try: try:
os.unlink(os.path.join(os.path.expanduser("~/.sopel/"), self.filename)) os.unlink(os.path.join(os.path.expanduser("~/.sopel/"), self.filename))
except OSError: #  File does not exist except OSError: #  File does not exist
pass pass
def open_file(self, filename, mode="a"): def open_file(self, filename: str, mode: str = "a") -> TextIO:
dir_path = os.path.expanduser("~/.sopel/") dir_path = os.path.expanduser("~/.sopel/")
os.makedirs(dir_path, exist_ok=True) os.makedirs(dir_path, exist_ok=True)
return open(os.path.join(dir_path, filename), mode) return cast(TextIO, open(os.path.join(dir_path, filename), mode))
def create_config(self): def create_config(self) -> None:
with self.open_file(self.filename): with self.open_file(self.filename):
pass pass
def run(self, hostname, port, auth, tls_config): def run(
self,
hostname: str,
port: int,
auth: Optional[authentication.Authentication],
tls_config: Optional[tls.TlsConfig] = None,
) -> None:
# Runs a client with the config given as arguments # Runs a client with the config given as arguments
if tls_config is not None: if tls_config is not None:
raise NotImplementedByController("TLS configuration") raise NotImplementedByController("TLS configuration")
@ -66,5 +76,5 @@ class SopelController(BaseClientController):
self.proc = subprocess.Popen(["sopel", "--quiet", "-c", self.filename]) self.proc = subprocess.Popen(["sopel", "--quiet", "-c", self.filename])
def get_irctest_controller_class(): def get_irctest_controller_class() -> Type[SopelController]:
return SopelController return SopelController

View File

@ -2,8 +2,10 @@
Handles ambiguities of RFCs. Handles ambiguities of RFCs.
""" """
from typing import List
def normalize_namreply_params(params):
def normalize_namreply_params(params: List[str]) -> List[str]:
# So… RFC 2812 says: # So… RFC 2812 says:
# "( "=" / "*" / "@" ) <channel> # "( "=" / "*" / "@" ) <channel>
# :[ "@" / "+" ] <nick> *( " " [ "@" / "+" ] <nick> ) # :[ "@" / "+" ] <nick> *( " " [ "@" / "+" ] <nick> )
@ -12,6 +14,7 @@ def normalize_namreply_params(params):
# prefix. # prefix.
# So let's normalize this to “with space”, and strip spaces at the # So let's normalize this to “with space”, and strip spaces at the
# end of the nick list. # end of the nick list.
params = list(params) # copy the list
if len(params) == 3: if len(params) == 3:
assert params[1][0] in "=*@", params assert params[1][0] in "=*@", params
params.insert(1, params[1][0]) params.insert(1, params[1][0])

View File

@ -1,10 +1,12 @@
def cap_list_to_dict(caps): from typing import Dict, List, Optional
d = {}
def cap_list_to_dict(caps: List[str]) -> Dict[str, Optional[str]]:
d: Dict[str, Optional[str]] = {}
for cap in caps: for cap in caps:
if "=" in cap: if "=" in cap:
(key, value) = cap.split("=", 1) (key, value) = cap.split("=", 1)
d[key] = value
else: else:
key = cap d[cap] = None
value = None
d[key] = value
return d return d

View File

@ -1,16 +1,17 @@
import datetime import datetime
import re import re
import secrets import secrets
from typing import Dict
# thanks jess! # thanks jess!
IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z" IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z"
def ircv3_timestamp_to_unixtime(timestamp): def ircv3_timestamp_to_unixtime(timestamp: str) -> float:
return datetime.datetime.strptime(timestamp, IRCV3_FORMAT_STRFTIME).timestamp() return datetime.datetime.strptime(timestamp, IRCV3_FORMAT_STRFTIME).timestamp()
def random_name(base): def random_name(base: str) -> str:
return base + "-" + secrets.token_hex(8) return base + "-" + secrets.token_hex(8)
@ -26,16 +27,16 @@ class MultipleReplacer:
# We use an object instead of a lambda function because it avoids the # We use an object instead of a lambda function because it avoids the
# need for using the staticmethod() on the lambda function if assigning # need for using the staticmethod() on the lambda function if assigning
# it to a class in Python 3. # it to a class in Python 3.
def __init__(self, dict_): def __init__(self, dict_: Dict[str, str]):
self._dict = dict_ self._dict = dict_
dict_ = dict([(re.escape(key), val) for key, val in dict_.items()]) dict_ = dict([(re.escape(key), val) for key, val in dict_.items()])
self._matcher = re.compile("|".join(dict_.keys())) self._matcher = re.compile("|".join(dict_.keys()))
def __call__(self, s): def __call__(self, s: str) -> str:
return self._matcher.sub(lambda m: self._dict[m.group(0)], s) return self._matcher.sub(lambda m: self._dict[m.group(0)], s)
def normalizeWhitespace(s, removeNewline=True): def normalizeWhitespace(s: str, removeNewline: bool = True) -> str:
r"""Normalizes the whitespace in a string; \s+ becomes one space.""" r"""Normalizes the whitespace in a string; \s+ becomes one space."""
if not s: if not s:
return str(s) # not the same reference return str(s) # not the same reference

View File

@ -18,8 +18,8 @@ unescape_tag_value = MultipleReplacer(dict(map(lambda x: (x[1], x[0]), TAG_ESCAP
tag_key_validator = re.compile(r"\+?(\S+/)?[a-zA-Z0-9-]+") tag_key_validator = re.compile(r"\+?(\S+/)?[a-zA-Z0-9-]+")
def parse_tags(s): def parse_tags(s: str) -> Dict[str, Optional[str]]:
tags = {} tags: Dict[str, Optional[str]] = {}
for tag in s.split(";"): for tag in s.split(";"):
if "=" not in tag: if "=" not in tag:
tags[tag] = None tags[tag] = None
@ -54,15 +54,15 @@ class Message:
) )
def parse_message(s): def parse_message(s: str) -> Message:
"""Parse a message according to """Parse a message according to
http://tools.ietf.org/html/rfc1459#section-2.3.1 http://tools.ietf.org/html/rfc1459#section-2.3.1
and and
http://ircv3.net/specs/core/message-tags-3.2.html""" http://ircv3.net/specs/core/message-tags-3.2.html"""
s = s.rstrip("\r\n") s = s.rstrip("\r\n")
if s.startswith("@"): if s.startswith("@"):
(tags, s) = s.split(" ", 1) (tags_str, s) = s.split(" ", 1)
tags = parse_tags(tags[1:]) tags = parse_tags(tags_str[1:])
else: else:
tags = {} tags = {}
if " :" in s: if " :" in s:
@ -70,10 +70,7 @@ def parse_message(s):
tokens = list(filter(bool, other_tokens.split(" "))) + [trailing_param] tokens = list(filter(bool, other_tokens.split(" "))) + [trailing_param]
else: else:
tokens = list(filter(bool, s.split(" "))) tokens = list(filter(bool, s.split(" ")))
if tokens[0].startswith(":"): prefix = prefix = tokens.pop(0)[1:] if tokens[0].startswith(":") else None
prefix = tokens.pop(0)[1:]
else:
prefix = None
command = tokens.pop(0) command = tokens.pop(0)
params = tokens params = tokens
return Message(tags=tags, prefix=prefix, command=command, params=params) return Message(tags=tags, prefix=prefix, command=command, params=params)

View File

@ -1,7 +1,7 @@
import base64 import base64
def sasl_plain_blob(username, passphrase): def sasl_plain_blob(username: str, passphrase: str) -> str:
blob = base64.b64encode( blob = base64.b64encode(
b"\x00".join( b"\x00".join(
( (

View File

@ -1,14 +1,15 @@
import collections import collections
from typing import Dict, Union
import unittest import unittest
class NotImplementedByController(unittest.SkipTest, NotImplementedError): class NotImplementedByController(unittest.SkipTest, NotImplementedError):
def __str__(self): def __str__(self) -> str:
return "Not implemented by controller: {}".format(self.args[0]) return "Not implemented by controller: {}".format(self.args[0])
class ImplementationChoice(unittest.SkipTest): class ImplementationChoice(unittest.SkipTest):
def __str__(self): def __str__(self) -> str:
return ( return (
"Choice in the implementation makes it impossible to " "Choice in the implementation makes it impossible to "
"perform a test: {}".format(self.args[0]) "perform a test: {}".format(self.args[0])
@ -16,49 +17,49 @@ class ImplementationChoice(unittest.SkipTest):
class OptionalExtensionNotSupported(unittest.SkipTest): class OptionalExtensionNotSupported(unittest.SkipTest):
def __str__(self): def __str__(self) -> str:
return "Unsupported extension: {}".format(self.args[0]) return "Unsupported extension: {}".format(self.args[0])
class OptionalSaslMechanismNotSupported(unittest.SkipTest): class OptionalSaslMechanismNotSupported(unittest.SkipTest):
def __str__(self): def __str__(self) -> str:
return "Unsupported SASL mechanism: {}".format(self.args[0]) return "Unsupported SASL mechanism: {}".format(self.args[0])
class CapabilityNotSupported(unittest.SkipTest): class CapabilityNotSupported(unittest.SkipTest):
def __str__(self): def __str__(self) -> str:
return "Unsupported capability: {}".format(self.args[0]) return "Unsupported capability: {}".format(self.args[0])
class IsupportTokenNotSupported(unittest.SkipTest): class IsupportTokenNotSupported(unittest.SkipTest):
def __str__(self): def __str__(self) -> str:
return "Unsupported ISUPPORT token: {}".format(self.args[0]) return "Unsupported ISUPPORT token: {}".format(self.args[0])
class ChannelModeNotSupported(unittest.SkipTest): class ChannelModeNotSupported(unittest.SkipTest):
def __str__(self): def __str__(self) -> str:
return "Unsupported channel mode: {} ({})".format(self.args[0], self.args[1]) return "Unsupported channel mode: {} ({})".format(self.args[0], self.args[1])
class ExtbanNotSupported(unittest.SkipTest): class ExtbanNotSupported(unittest.SkipTest):
def __str__(self): def __str__(self) -> str:
return "Unsupported extban: {} ({})".format(self.args[0], self.args[1]) return "Unsupported extban: {} ({})".format(self.args[0], self.args[1])
class NotRequiredBySpecifications(unittest.SkipTest): class NotRequiredBySpecifications(unittest.SkipTest):
def __str__(self): def __str__(self) -> str:
return "Tests not required by the set of tested specification(s)." return "Tests not required by the set of tested specification(s)."
class SkipStrictTest(unittest.SkipTest): class SkipStrictTest(unittest.SkipTest):
def __str__(self): def __str__(self) -> str:
return "Tests not required because strict tests are disabled." return "Tests not required because strict tests are disabled."
class TextTestResult(unittest.TextTestResult): class TextTestResult(unittest.TextTestResult):
def getDescription(self, test): def getDescription(self, test: unittest.TestCase) -> str:
if hasattr(test, "description"): if hasattr(test, "description"):
doc_first_lines = test.description() doc_first_lines = test.description() # type: ignore
else: else:
doc_first_lines = test.shortDescription() doc_first_lines = test.shortDescription()
return "\n".join((str(test), doc_first_lines or "")) return "\n".join((str(test), doc_first_lines or ""))
@ -71,7 +72,9 @@ class TextTestRunner(unittest.TextTestRunner):
resultclass = TextTestResult resultclass = TextTestResult
def run(self, test): def run(
self, test: Union[unittest.TestSuite, unittest.TestCase]
) -> unittest.TestResult:
result = super().run(test) result = super().run(test)
assert self.resultclass is TextTestResult assert self.resultclass is TextTestResult
if result.skipped: if result.skipped:
@ -80,7 +83,7 @@ class TextTestRunner(unittest.TextTestRunner):
"Some tests were skipped because the following optional " "Some tests were skipped because the following optional "
"specifications/mechanisms are not supported:" "specifications/mechanisms are not supported:"
) )
msg_to_count = collections.defaultdict(lambda: 0) msg_to_count: Dict[str, int] = collections.defaultdict(lambda: 0)
for (test, msg) in result.skipped: for (test, msg) in result.skipped:
msg_to_count[msg] += 1 msg_to_count[msg] += 1
for (msg, count) in sorted(msg_to_count.items()): for (msg, count) in sorted(msg_to_count.items()):

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import enum import enum
@ -15,7 +17,7 @@ class Specifications(enum.Enum):
Modern = "modern" Modern = "modern"
@classmethod @classmethod
def from_name(cls, name): def from_name(cls, name: str) -> Specifications:
name = name.upper() name = name.upper()
for spec in cls: for spec in cls:
if spec.value.upper() == name: if spec.value.upper() == name:
@ -37,7 +39,7 @@ class Capabilities(enum.Enum):
STS = "sts" STS = "sts"
@classmethod @classmethod
def from_name(cls, name): def from_name(cls, name: str) -> Capabilities:
try: try:
return cls(name.lower()) return cls(name.lower())
except ValueError: except ValueError:
@ -50,7 +52,7 @@ class IsupportTokens(enum.Enum):
STATUSMSG = "STATUSMSG" STATUSMSG = "STATUSMSG"
@classmethod @classmethod
def from_name(cls, name): def from_name(cls, name: str) -> IsupportTokens:
try: try:
return cls(name.upper()) return cls(name.upper())
except ValueError: except ValueError: