diff --git a/irctest/basecontrollers.py b/irctest/basecontrollers.py index 73b4611..d18dab8 100644 --- a/irctest/basecontrollers.py +++ b/irctest/basecontrollers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import os import shutil @@ -5,8 +7,11 @@ import socket import subprocess import tempfile import time -from typing import Any, Callable, Dict, Optional, Set +from typing import IO, Any, Callable, Dict, Optional, Set +import irctest + +from . import authentication, tls from .runner import NotImplementedByController @@ -41,27 +46,27 @@ class _BaseController: a process (eg. a server or a client), as well as sending it instructions that are not part of the IRC specification.""" + # set by conftest.py + openssl_bin: str + + supports_sts: bool + supported_sasl_mechanisms: Set[str] + proc: Optional[subprocess.Popen] + def __init__(self, test_config: TestCaseControllerConfig): self.test_config = test_config self.proc = None - def check_is_alive(self): + def check_is_alive(self) -> None: + assert self.proc self.proc.poll() if self.proc.returncode is not None: raise ProcessStopped() - -class DirectoryBasedController(_BaseController): - """Helper for controllers whose software configuration is based on an - arbitrary directory.""" - - def __init__(self, test_config: TestCaseControllerConfig): - super().__init__(test_config) - self.directory = None - - def kill_proc(self): + def kill_proc(self) -> None: """Terminates the controlled process, waits for it to exit, and eventually kills it.""" + assert self.proc self.proc.terminate() try: self.proc.wait(5) @@ -69,20 +74,36 @@ class DirectoryBasedController(_BaseController): self.proc.kill() self.proc = None - def kill(self): + def kill(self) -> None: """Calls `kill_proc` and cleans the configuration.""" if self.proc: self.kill_proc() + + +class DirectoryBasedController(_BaseController): + """Helper for controllers whose software configuration is based on an + arbitrary directory.""" + + directory: Optional[str] + + def __init__(self, test_config: TestCaseControllerConfig): + super().__init__(test_config) + self.directory = None + + def kill(self) -> None: + """Calls `kill_proc` and cleans the configuration.""" + super().kill() if self.directory: shutil.rmtree(self.directory) - def terminate(self): + def terminate(self) -> None: """Stops the process gracefully, and does not clean its config.""" + assert self.proc self.proc.terminate() self.proc.wait() self.proc = None - def open_file(self, name, mode="a"): + def open_file(self, name: str, mode: str = "a") -> IO: """Open a file in the configuration directory.""" assert self.directory if os.sep in name: @@ -92,16 +113,12 @@ class DirectoryBasedController(_BaseController): assert os.path.isdir(dir_) return open(os.path.join(self.directory, name), mode) - def create_config(self): - """If there is no config dir, creates it and returns True. - Else returns False.""" - if self.directory: - return False - else: + def create_config(self) -> None: + if not self.directory: self.directory = tempfile.mkdtemp() - return True - def gen_ssl(self): + def gen_ssl(self) -> None: + assert self.directory self.csr_path = os.path.join(self.directory, "ssl.csr") self.key_path = os.path.join(self.directory, "ssl.key") self.pem_path = os.path.join(self.directory, "ssl.pem") @@ -145,7 +162,13 @@ class DirectoryBasedController(_BaseController): class BaseClientController(_BaseController): """Base controller for IRC clients.""" - def run(self, hostname, port, auth): + def run( + self, + hostname: str, + port: int, + auth: Optional[authentication.Authentication], + tls_config: Optional[tls.TlsConfig] = None, + ) -> None: raise NotImplementedError() @@ -154,17 +177,29 @@ class BaseServerController(_BaseController): _port_wait_interval = 0.1 port_open = False + port: int - supports_sts: bool - supported_sasl_mechanisms: Set[str] - - def run(self, hostname, port, password, valid_metadata_keys, invalid_metadata_keys): + def run( + self, + hostname: str, + port: int, + *, + password: Optional[str], + ssl: bool, + valid_metadata_keys: Optional[Set[str]], + invalid_metadata_keys: Optional[Set[str]], + ) -> None: raise NotImplementedError() - def registerUser(self, case, username, password=None): + def registerUser( + self, + case: irctest.cases.BaseServerTestCase, # type: ignore + username: str, + password: Optional[str] = None, + ) -> None: raise NotImplementedByController("account registration") - def wait_for_port(self): + def wait_for_port(self) -> None: while not self.port_open: self.check_is_alive() time.sleep(self._port_wait_interval) diff --git a/irctest/cases.py b/irctest/cases.py index aa62f01..75d6f52 100644 --- a/irctest/cases.py +++ b/irctest/cases.py @@ -3,16 +3,34 @@ import socket import ssl import tempfile import time -from typing import Optional, Set +from typing import ( + Any, + Callable, + Container, + Dict, + Generic, + Hashable, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) import unittest import pytest -from . import basecontrollers, client_mock, runner +from . import basecontrollers, client_mock, runner, tls +from .authentication import Authentication from .basecontrollers import TestCaseControllerConfig from .exceptions import ConnectionClosed from .irc_utils import capabilities, message_parser from .irc_utils.junkdrawer import normalizeWhitespace +from .irc_utils.message_parser import Message from .irc_utils.sasl import sasl_plain_blob from .numerics import ( ERR_BADCHANNELKEY, @@ -35,18 +53,33 @@ CHANNEL_JOIN_FAIL_NUMERICS = frozenset( ] ) +# typevar for decorators +TCallable = TypeVar("TCallable", bound=Callable) + +# typevar for the client name used by tests (usually int or str) +TClientName = TypeVar("TClientName", bound=Union[Hashable, int]) + +TController = TypeVar("TController", bound=basecontrollers._BaseController) + +# general-purpose typevar +T = TypeVar("T") + class ChannelJoinException(Exception): - def __init__(self, code, params): + def __init__(self, code: str, params: List[str]): super().__init__(f"Failed to join channel ({code}): {params}") self.code = code self.params = params -class _IrcTestCase(unittest.TestCase): +class _IrcTestCase(unittest.TestCase, Generic[TController]): """Base class for test cases.""" - controllerClass = None # Will be set by __main__.py + # Will be set by __main__.py + controllerClass: Type[TController] + show_io: bool + + controller: TController @staticmethod def config() -> TestCaseControllerConfig: @@ -56,7 +89,7 @@ class _IrcTestCase(unittest.TestCase): """ return TestCaseControllerConfig() - def description(self): + def description(self) -> str: method_doc = self._testMethodDoc if not method_doc: return "" @@ -64,14 +97,13 @@ class _IrcTestCase(unittest.TestCase): method_doc, removeNewline=False ).strip().replace("\n ", "\n\t") - def setUp(self): + def setUp(self) -> None: super().setUp() self.controller = self.controllerClass(self.config()) - self.inbuffer = [] if self.show_io: print("---- new test ----") - def assertMessageEqual(self, msg, **kwargs): + def assertMessageEqual(self, msg: Message, **kwargs: Any) -> None: """Helper for partially comparing a message. Takes the message as first arguments, and comparisons to be made @@ -83,21 +115,21 @@ class _IrcTestCase(unittest.TestCase): if error: raise self.failureException(error) - def messageEqual(self, msg, **kwargs): + def messageEqual(self, msg: Message, **kwargs: Any) -> bool: """Boolean negation of `messageDiffers` (returns a boolean, not an optional string).""" return not self.messageDiffers(msg, **kwargs) def messageDiffers( self, - msg, - params=None, - target=None, - nick=None, - fail_msg=None, - extra_format=(), - **kwargs, - ): + msg: Message, + params: Optional[List[Any]] = None, + target: Optional[str] = None, + nick: Optional[str] = None, + fail_msg: Optional[str] = None, + extra_format: Tuple = (), + **kwargs: Any, + ) -> Optional[str]: """Returns an error message if the message doesn't match the given arguments, or None if it matches.""" for (key, value) in kwargs.items(): @@ -120,7 +152,7 @@ class _IrcTestCase(unittest.TestCase): ) if nick: - got_nick = msg.prefix.split("!")[0] + got_nick = msg.prefix.split("!")[0] if msg.prefix else None if msg.prefix is None: fail_msg = ( fail_msg or "expected nick to be {expects}, got {got} prefix: {msg}" @@ -131,7 +163,7 @@ class _IrcTestCase(unittest.TestCase): return None - def listMatch(self, got, expected): + def listMatch(self, got: List[str], expected: List[Any]) -> bool: """Returns True iff the list are equal. The ellipsis (aka. "..." aka triple dots) can be used on the 'expected' side as a wildcard, matching any *single* value.""" @@ -145,62 +177,124 @@ class _IrcTestCase(unittest.TestCase): return False return True - def assertIn(self, item, list_, msg=None, fail_msg=None, extra_format=()): + def assertIn( + self, + member: Any, + container: Union[Iterable[Any], Container[Any]], + msg: Optional[str] = None, + fail_msg: Optional[str] = None, + extra_format: Tuple = (), + ) -> None: if fail_msg: - fail_msg = fail_msg.format(*extra_format, item=item, list=list_, msg=msg) - super().assertIn(item, list_, fail_msg) + fail_msg = fail_msg.format( + *extra_format, item=member, list=container, msg=msg + ) + super().assertIn(member, container, fail_msg) - def assertNotIn(self, item, list_, msg=None, fail_msg=None, extra_format=()): + def assertNotIn( + self, + member: Any, + container: Union[Iterable[Any], Container[Any]], + msg: Optional[str] = None, + fail_msg: Optional[str] = None, + extra_format: Tuple = (), + ) -> None: if fail_msg: - fail_msg = fail_msg.format(*extra_format, item=item, list=list_, msg=msg) - super().assertNotIn(item, list_, fail_msg) + fail_msg = fail_msg.format( + *extra_format, item=member, list=container, msg=msg + ) + super().assertNotIn(member, container, fail_msg) - def assertEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()): + def assertEqual( + self, + got: T, + expects: T, + msg: Any = None, + fail_msg: Optional[str] = None, + extra_format: Tuple = (), + ) -> None: if fail_msg: fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) super().assertEqual(got, expects, fail_msg) - def assertNotEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()): + def assertNotEqual( + self, + got: T, + expects: T, + msg: Any = None, + fail_msg: Optional[str] = None, + extra_format: Tuple = (), + ) -> None: if fail_msg: fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) super().assertNotEqual(got, expects, fail_msg) - def assertGreater(self, got, expects, msg=None, fail_msg=None, extra_format=()): + def assertGreater( + self, + got: T, + expects: T, + msg: Any = None, + fail_msg: Optional[str] = None, + extra_format: Tuple = (), + ) -> None: if fail_msg: fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) super().assertGreater(got, expects, fail_msg) def assertGreaterEqual( - self, got, expects, msg=None, fail_msg=None, extra_format=() - ): + self, + got: T, + expects: T, + msg: Any = None, + fail_msg: Optional[str] = None, + extra_format: Tuple = (), + ) -> None: if fail_msg: fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) super().assertGreaterEqual(got, expects, fail_msg) - def assertLess(self, got, expects, msg=None, fail_msg=None, extra_format=()): + def assertLess( + self, + got: T, + expects: T, + msg: Any = None, + fail_msg: Optional[str] = None, + extra_format: Tuple = (), + ) -> None: if fail_msg: fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) super().assertLess(got, expects, fail_msg) - def assertLessEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()): + def assertLessEqual( + self, + got: T, + expects: T, + msg: Any = None, + fail_msg: Optional[str] = None, + extra_format: Tuple = (), + ) -> None: if fail_msg: fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg) super().assertLessEqual(got, expects, fail_msg) -class BaseClientTestCase(_IrcTestCase): +class BaseClientTestCase(_IrcTestCase[basecontrollers.BaseClientController]): """Basic class for client tests. Handles spawning a client and exchanging messages with it.""" - nick = None - user = None + conn: Optional[socket.socket] + nick: Optional[str] = None + user: Optional[List[str]] = None + server: socket.socket + protocol_version = Optional[str] + acked_capabilities = Optional[Set[str]] - def setUp(self): + def setUp(self) -> None: super().setUp() self.conn = None self._setUpServer() - def tearDown(self): + def tearDown(self) -> None: if self.conn: try: self.conn.sendall(b"QUIT :end of test.") @@ -214,7 +308,7 @@ class BaseClientTestCase(_IrcTestCase): self.conn.close() self.server.close() - def _setUpServer(self): + def _setUpServer(self) -> None: """Creates the server and make it listen.""" self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.server.bind(("", 0)) # Bind any free port @@ -223,9 +317,15 @@ class BaseClientTestCase(_IrcTestCase): # Used to check if the client is alive from time to time self.server.settimeout(1) - def acceptClient(self, tls_cert=None, tls_key=None, server=None): + def acceptClient( + self, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + server: Optional[socket.socket] = None, + ) -> None: """Make the server accept a client connection. Blocking.""" server = server or self.server + assert server # Wait for the client to connect while True: try: @@ -252,17 +352,17 @@ class BaseClientTestCase(_IrcTestCase): self.conn = context.wrap_socket(self.conn, server_side=True) self.conn_file = self.conn.makefile(newline="\r\n", encoding="utf8") - def getLine(self): + def getLine(self) -> str: line = self.conn_file.readline() if self.show_io: print("{:.3f} C: {}".format(time.time(), line.strip())) return line - def getMessages(self, *args): - lines = self.getLines(*args) - return map(message_parser.parse_message, lines) - - def getMessage(self, *args, filter_pred=None): + def getMessage( + self, + *args: Any, + filter_pred: Optional[Callable[[Message], bool]] = None, + ) -> Message: """Gets a message and returns it. If a filter predicate is given, fetches messages until the predicate returns a False on a message, and returns this message.""" @@ -274,18 +374,17 @@ class BaseClientTestCase(_IrcTestCase): if not filter_pred or filter_pred(msg): return msg - def sendLine(self, line): + def sendLine(self, line: str) -> None: + assert self.conn self.conn.sendall(line.encode()) if not line.endswith("\r\n"): self.conn.sendall(b"\r\n") if self.show_io: print("{:.3f} S: {}".format(time.time(), line.strip())) - -class ClientNegociationHelper: - """Helper class for tests handling capabilities negociation.""" - - def readCapLs(self, auth=None, tls_config=None): + def readCapLs( + self, auth: Optional[Authentication] = None, tls_config: tls.TlsConfig = None + ) -> None: (hostname, port) = self.server.getsockname() self.controller.run( hostname=hostname, port=port, auth=auth, tls_config=tls_config @@ -302,28 +401,33 @@ class ClientNegociationHelper: else: raise AssertionError("Unknown CAP params: {}".format(m.params)) - def userNickPredicate(self, msg): + def userNickPredicate(self, msg: Message) -> bool: """Predicate to be used with getMessage to handle NICK/USER transparently.""" if msg.command == "NICK": - self.assertEqual(len(msg.params), 1, msg) + self.assertEqual(len(msg.params), 1, msg=msg) self.nick = msg.params[0] return False elif msg.command == "USER": - self.assertEqual(len(msg.params), 4, msg) + self.assertEqual(len(msg.params), 4, msg=msg) self.user = msg.params return False else: return True - def negotiateCapabilities(self, caps, cap_ls=True, auth=None): + def negotiateCapabilities( + self, + caps: List[str], + cap_ls: bool = True, + auth: Optional[Authentication] = None, + ) -> Optional[Message]: """Performes a complete capability negociation process, without ending it, so the caller can continue the negociation.""" if cap_ls: self.readCapLs(auth) if not self.protocol_version: # No negotiation. - return + return None self.sendLine("CAP * LS :{}".format(" ".join(caps))) capability_names = frozenset(capabilities.cap_list_to_dict(caps)) self.acked_capabilities = set() @@ -343,21 +447,25 @@ class ClientNegociationHelper: self.sendLine( "CAP {} ACK :{}".format(self.nick or "*", m.params[1]) ) - self.acked_capabilities.update(requested) + self.acked_capabilities.update(requested) # type: ignore else: return m -class BaseServerTestCase(_IrcTestCase): +class BaseServerTestCase( + _IrcTestCase[basecontrollers.BaseServerController], Generic[TClientName] +): """Basic class for server tests. Handles spawning a server and exchanging messages with it.""" + show_io: bool # set by conftest.py + password: Optional[str] = None ssl = False valid_metadata_keys: Set[str] = set() invalid_metadata_keys: Set[str] = set() - def setUp(self): + def setUp(self) -> None: super().setUp() self.server_support = None self.find_hostname_and_port() @@ -369,53 +477,64 @@ class BaseServerTestCase(_IrcTestCase): invalid_metadata_keys=self.invalid_metadata_keys, ssl=self.ssl, ) - self.clients = {} + self.clients: Dict[TClientName, client_mock.ClientMock] = {} - def tearDown(self): + def tearDown(self) -> None: self.controller.kill() for client in list(self.clients): self.removeClient(client) - def find_hostname_and_port(self): + def find_hostname_and_port(self) -> None: """Find available hostname/port to listen on.""" s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("", 0)) (self.hostname, self.port) = s.getsockname() s.close() - def addClient(self, name=None, show_io=None): + def addClient( + self, name: Optional[TClientName] = None, show_io: Optional[bool] = None + ) -> TClientName: """Connects a client to the server and adds it to the dict. If 'name' is not given, uses the lowest unused non-negative integer.""" self.controller.wait_for_port() if not name: - name = max(map(int, list(self.clients) + [0])) + 1 + new_name: int = ( + max( + [int(name) for name in self.clients if isinstance(name, (int, str))] + + [0] + ) + + 1 + ) + name = cast(TClientName, new_name) show_io = show_io if show_io is not None else self.show_io self.clients[name] = client_mock.ClientMock(name=name, show_io=show_io) self.clients[name].connect(self.hostname, self.port) return name - def removeClient(self, name): + def removeClient(self, name: TClientName) -> None: """Disconnects the client, without QUIT.""" assert name in self.clients self.clients[name].disconnect() del self.clients[name] - def getMessages(self, client, **kwargs): + def getMessages(self, client: TClientName, **kwargs: Any) -> List[Message]: return self.clients[client].getMessages(**kwargs) - def getMessage(self, client, **kwargs): + def getMessage(self, client: TClientName, **kwargs: Any) -> Message: return self.clients[client].getMessage(**kwargs) - def getRegistrationMessage(self, client): + def getRegistrationMessage(self, client: TClientName) -> Message: """Filter notices, do not send pings.""" return self.getMessage( client, synchronize=False, filter_pred=lambda m: m.command != "NOTICE" ) - def sendLine(self, client, line): + def sendLine(self, client: TClientName, line: Union[str, bytes]) -> None: return self.clients[client].sendLine(line) - def getCapLs(self, client, as_list=False): + def getCapLs( + self, client: TClientName, as_list: bool = False + ) -> Union[List[str], Dict[str, Optional[str]]]: """Waits for a CAP LS block, parses all CAP LS messages, and return the dict capabilities, with their values. @@ -431,10 +550,10 @@ class BaseServerTestCase(_IrcTestCase): else: caps.extend(m.params[2].split()) if not as_list: - caps = capabilities.cap_list_to_dict(caps) + return capabilities.cap_list_to_dict(caps) return caps - def assertDisconnected(self, client): + def assertDisconnected(self, client: TClientName) -> None: try: self.getMessages(client) self.getMessages(client) @@ -444,7 +563,7 @@ class BaseServerTestCase(_IrcTestCase): else: raise AssertionError("Client not disconnected.") - def skipToWelcome(self, client): + def skipToWelcome(self, client: TClientName) -> List[Message]: """Skip to the point where we are registered """ @@ -457,15 +576,19 @@ class BaseServerTestCase(_IrcTestCase): def connectClient( self, - nick, - name=None, - capabilities=None, - skip_if_cap_nak=False, - show_io=None, - account=None, - password=None, - ident="username", - ): + nick: str, + name: TClientName = None, + capabilities: Optional[List[str]] = None, + skip_if_cap_nak: bool = False, + show_io: Optional[bool] = None, + account: Optional[str] = None, + password: Optional[str] = None, + ident: str = "username", + ) -> List[Message]: + """Connections a new client, does the cap negotiation + and connection registration, and skips to the end of the MOTD. + Returns the list of all messages received after registration, + just like `skipToWelcome`.""" client = self.addClient(name, show_io=show_io) if capabilities is not None and 0 < len(capabilities): self.sendLine(client, "CAP REQ :{}".format(" ".join(capabilities))) @@ -502,14 +625,14 @@ class BaseServerTestCase(_IrcTestCase): for param in m.params[1:-1]: if "=" in param: (key, value) = param.split("=") + self.server_support[key] = value else: - (key, value) = (param, None) - self.server_support[key] = value + self.server_support[param] = None welcome.append(m) return welcome - def joinClient(self, client, channel): + def joinClient(self, client: TClientName, channel: str) -> None: self.sendLine(client, "JOIN {}".format(channel)) received = {m.command for m in self.getMessages(client)} self.assertIn( @@ -520,7 +643,7 @@ class BaseServerTestCase(_IrcTestCase): extra_format=(channel,), ) - def joinChannel(self, client, channel): + def joinChannel(self, client: TClientName, channel: str) -> None: self.sendLine(client, "JOIN {}".format(channel)) # wait until we see them join the channel joined = False @@ -537,24 +660,34 @@ class BaseServerTestCase(_IrcTestCase): raise ChannelJoinException(msg.command, msg.params) -class OptionalityHelper: - controller: basecontrollers.BaseServerController +_TSelf = TypeVar("_TSelf", bound="OptionalityHelper") +_TReturn = TypeVar("_TReturn") - def checkSaslSupport(self): + +class OptionalityHelper(Generic[TController]): + controller: TController + + def checkSaslSupport(self) -> None: if self.controller.supported_sasl_mechanisms: return raise runner.NotImplementedByController("SASL") - def checkMechanismSupport(self, mechanism): + def checkMechanismSupport(self, mechanism: str) -> None: if mechanism in self.controller.supported_sasl_mechanisms: return raise runner.OptionalSaslMechanismNotSupported(mechanism) @staticmethod - def skipUnlessHasMechanism(mech): - def decorator(f): + def skipUnlessHasMechanism( + mech: str, + ) -> Callable[[Callable[[_TSelf], _TReturn]], Callable[[_TSelf], _TReturn]]: + # Just a function returning a function that takes functions and + # returns functions, nothing to see here. + # If Python didn't have such an awful syntax for callables, it would be: + # str -> ((TSelf -> TReturn) -> (TSelf -> TReturn)) + def decorator(f: Callable[[_TSelf], _TReturn]) -> Callable[[_TSelf], _TReturn]: @functools.wraps(f) - def newf(self): + def newf(self: _TSelf) -> _TReturn: self.checkMechanismSupport(mech) return f(self) @@ -562,23 +695,29 @@ class OptionalityHelper: return decorator - def skipUnlessHasSasl(f): + @staticmethod + def skipUnlessHasSasl( + f: Callable[[_TSelf], _TReturn] + ) -> Callable[[_TSelf], _TReturn]: @functools.wraps(f) - def newf(self): + def newf(self: _TSelf) -> _TReturn: self.checkSaslSupport() return f(self) return newf -def mark_specifications(*specifications, deprecated=False, strict=False): +def mark_specifications( + *specifications_str: str, deprecated: bool = False, strict: bool = False +) -> Callable[[TCallable], TCallable]: specifications = frozenset( - Specifications.from_name(s) if isinstance(s, str) else s for s in specifications + Specifications.from_name(s) if isinstance(s, str) else s + for s in specifications_str ) if None in specifications: raise ValueError("Invalid set of specifications: {}".format(specifications)) - def decorator(f): + def decorator(f: TCallable) -> TCallable: for specification in specifications: f = getattr(pytest.mark, specification.value)(f) if strict: @@ -590,14 +729,16 @@ def mark_specifications(*specifications, deprecated=False, strict=False): return decorator -def mark_capabilities(*capabilities, deprecated=False, strict=False): +def mark_capabilities( + *capabilities_str: str, deprecated: bool = False, strict: bool = False +) -> Callable[[TCallable], TCallable]: capabilities = frozenset( - Capabilities.from_name(c) if isinstance(c, str) else c for c in capabilities + Capabilities.from_name(c) if isinstance(c, str) else c for c in capabilities_str ) if None in capabilities: raise ValueError("Invalid set of capabilities: {}".format(capabilities)) - def decorator(f): + def decorator(f: TCallable) -> TCallable: for capability in capabilities: f = getattr(pytest.mark, capability.value)(f) # Support for any capability implies IRCv3 @@ -607,14 +748,16 @@ def mark_capabilities(*capabilities, deprecated=False, strict=False): return decorator -def mark_isupport(*tokens, deprecated=False, strict=False): +def mark_isupport( + *tokens_str: str, deprecated: bool = False, strict: bool = False +) -> Callable[[TCallable], TCallable]: tokens = frozenset( - IsupportTokens.from_name(c) if isinstance(c, str) else c for c in tokens + IsupportTokens.from_name(c) if isinstance(c, str) else c for c in tokens_str ) if None in tokens: raise ValueError("Invalid set of isupport tokens: {}".format(tokens)) - def decorator(f): + def decorator(f: TCallable) -> TCallable: for token in tokens: f = getattr(pytest.mark, token.value)(f) return f diff --git a/irctest/client_mock.py b/irctest/client_mock.py index 6eb22bf..8ce1583 100644 --- a/irctest/client_mock.py +++ b/irctest/client_mock.py @@ -2,36 +2,41 @@ import socket import ssl import sys import time +from typing import Any, Callable, List, Optional, Union from .exceptions import ConnectionClosed, NoMessageException from .irc_utils import message_parser class ClientMock: - def __init__(self, name, show_io): + def __init__(self, name: Any, show_io: bool): self.name = name self.show_io = show_io - self.inbuffer = [] + self.inbuffer: List[message_parser.Message] = [] self.ssl = False - def connect(self, hostname, port): + def connect(self, hostname: str, port: int) -> None: self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.conn.settimeout(1) # TODO: configurable self.conn.connect((hostname, port)) if self.show_io: print("{:.3f} {}: connects to server.".format(time.time(), self.name)) - def disconnect(self): + def disconnect(self) -> None: if self.show_io: print("{:.3f} {}: disconnects from server.".format(time.time(), self.name)) self.conn.close() - def starttls(self): + def starttls(self) -> None: assert not self.ssl, "SSL already active." self.conn = ssl.wrap_socket(self.conn) self.ssl = True - def getMessages(self, synchronize=True, assert_get_one=False, raw=False): + def getMessages( + self, synchronize: bool = True, assert_get_one: bool = False, raw: bool = False + ) -> List[message_parser.Message]: + """actually returns List[str] in the rare case where raw=True.""" + token: Optional[str] if synchronize: token = "synchronize{}".format(time.monotonic()) self.sendLine("PING {}".format(token)) @@ -79,7 +84,7 @@ class ClientMock: got_pong = True else: if raw: - messages.append(line) + messages.append(line) # type: ignore else: messages.append(message) data = b"" @@ -91,7 +96,13 @@ class ClientMock: else: return messages - def getMessage(self, filter_pred=None, synchronize=True, raw=False): + def getMessage( + self, + filter_pred: Optional[Callable[[message_parser.Message], bool]] = None, + synchronize: bool = True, + raw: bool = False, + ) -> message_parser.Message: + """Returns str in the rare case where raw=True""" while True: if not self.inbuffer: self.inbuffer = self.getMessages( @@ -103,7 +114,7 @@ class ClientMock: if not filter_pred or filter_pred(message): return message - def sendLine(self, line): + def sendLine(self, line: Union[str, bytes]) -> None: if isinstance(line, str): encoded_line = line.encode() elif isinstance(line, bytes): @@ -113,7 +124,7 @@ class ClientMock: if not encoded_line.endswith(b"\r\n"): encoded_line += b"\r\n" try: - ret = self.conn.sendall(encoded_line) + ret = self.conn.sendall(encoded_line) # type: ignore except BrokenPipeError: raise ConnectionClosed() if ( diff --git a/irctest/client_tests/test_cap.py b/irctest/client_tests/test_cap.py index 1914ff3..5361945 100644 --- a/irctest/client_tests/test_cap.py +++ b/irctest/client_tests/test_cap.py @@ -2,7 +2,7 @@ from irctest import cases from irctest.irc_utils.message_parser import Message -class CapTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper): +class CapTestCase(cases.BaseClientTestCase): @cases.mark_specifications("IRCv3") def testSendCap(self): """Send CAP LS 302 and read the result.""" diff --git a/irctest/client_tests/test_sasl.py b/irctest/client_tests/test_sasl.py index fdea482..dddeb35 100644 --- a/irctest/client_tests/test_sasl.py +++ b/irctest/client_tests/test_sasl.py @@ -39,9 +39,7 @@ class IdentityHash: return self._data -class SaslTestCase( - cases.BaseClientTestCase, cases.ClientNegociationHelper, cases.OptionalityHelper -): +class SaslTestCase(cases.BaseClientTestCase, cases.OptionalityHelper): @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testPlain(self): """Test PLAIN authentication with correct username/password.""" @@ -263,9 +261,7 @@ class SaslTestCase( authenticator.response(msg) -class Irc302SaslTestCase( - cases.BaseClientTestCase, cases.ClientNegociationHelper, cases.OptionalityHelper -): +class Irc302SaslTestCase(cases.BaseClientTestCase, cases.OptionalityHelper): @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN") def testPlainNotAvailable(self): """Test the client does not try to authenticate using a mechanism the diff --git a/irctest/controllers/charybdis.py b/irctest/controllers/charybdis.py index 05a3d06..cb078b4 100644 --- a/irctest/controllers/charybdis.py +++ b/irctest/controllers/charybdis.py @@ -1,6 +1,6 @@ import os import subprocess -from typing import Set +from typing import Optional, Set, Type from irctest.basecontrollers import ( BaseServerController, @@ -47,20 +47,21 @@ class CharybdisController(BaseServerController, DirectoryBasedController): supported_sasl_mechanisms: Set[str] = set() supports_sts = False - def create_config(self): + def create_config(self) -> None: super().create_config() with self.open_file("server.conf"): pass def run( self, - hostname, - port, - password=None, - ssl=False, - valid_metadata_keys=None, - invalid_metadata_keys=None, - ): + hostname: str, + port: int, + *, + password: Optional[str], + ssl: bool, + valid_metadata_keys: Optional[Set[str]] = None, + invalid_metadata_keys: Optional[Set[str]] = None, + ) -> None: if valid_metadata_keys or invalid_metadata_keys: raise NotImplementedByController( "Defining valid and invalid METADATA keys." @@ -85,6 +86,7 @@ class CharybdisController(BaseServerController, DirectoryBasedController): ssl_config=ssl_config, ) ) + assert self.directory self.proc = subprocess.Popen( [ self.binary_name, @@ -98,5 +100,5 @@ class CharybdisController(BaseServerController, DirectoryBasedController): ) -def get_irctest_controller_class(): +def get_irctest_controller_class() -> Type[CharybdisController]: return CharybdisController diff --git a/irctest/controllers/girc.py b/irctest/controllers/girc.py index e4a64a7..21fd9c4 100644 --- a/irctest/controllers/girc.py +++ b/irctest/controllers/girc.py @@ -1,33 +1,25 @@ import subprocess +from typing import Optional, Type -from irctest.basecontrollers import BaseClientController, NotImplementedByController +from irctest import authentication, tls +from irctest.basecontrollers import ( + BaseClientController, + DirectoryBasedController, + NotImplementedByController, +) -class GircController(BaseClientController): +class GircController(BaseClientController, DirectoryBasedController): software_name = "gIRC" - supported_sasl_mechanisms = ["PLAIN"] + supported_sasl_mechanisms = {"PLAIN"} - def __init__(self): - super().__init__() - self.directory = None - self.proc = None - - def kill(self): - if self.proc: - self.proc.terminate() - try: - self.proc.wait(5) - except subprocess.TimeoutExpired: - self.proc.kill() - self.proc = None - - def __del__(self): - if self.proc: - self.proc.kill() - if self.directory: - self.directory.cleanup() - - def run(self, hostname, port, auth, tls_config): + def run( + self, + hostname: str, + port: int, + auth: Optional[authentication.Authentication], + tls_config: Optional[tls.TlsConfig] = None, + ) -> None: if tls_config: print(tls_config) raise NotImplementedByController("TLS options") @@ -42,5 +34,5 @@ class GircController(BaseClientController): self.proc = subprocess.Popen(["girc_test", "connect"] + args) -def get_irctest_controller_class(): +def get_irctest_controller_class() -> Type[GircController]: return GircController diff --git a/irctest/controllers/hybrid.py b/irctest/controllers/hybrid.py index eb0ed13..49f3375 100644 --- a/irctest/controllers/hybrid.py +++ b/irctest/controllers/hybrid.py @@ -1,6 +1,6 @@ import os import subprocess -from typing import Set +from typing import Optional, Set, Type from irctest.basecontrollers import ( BaseServerController, @@ -44,20 +44,21 @@ class HybridController(BaseServerController, DirectoryBasedController): supports_sts = False supported_sasl_mechanisms: Set[str] = set() - def create_config(self): + def create_config(self) -> None: super().create_config() with self.open_file("server.conf"): pass def run( self, - hostname, - port, - password=None, - ssl=False, - valid_metadata_keys=None, - invalid_metadata_keys=None, - ): + hostname: str, + port: int, + *, + password: Optional[str], + ssl: bool, + valid_metadata_keys: Optional[Set[str]] = None, + invalid_metadata_keys: Optional[Set[str]] = None, + ) -> None: if valid_metadata_keys or invalid_metadata_keys: raise NotImplementedByController( "Defining valid and invalid METADATA keys." @@ -82,6 +83,7 @@ class HybridController(BaseServerController, DirectoryBasedController): ssl_config=ssl_config, ) ) + assert self.directory self.proc = subprocess.Popen( [ "ircd", @@ -96,5 +98,5 @@ class HybridController(BaseServerController, DirectoryBasedController): ) -def get_irctest_controller_class(): +def get_irctest_controller_class() -> Type[HybridController]: return HybridController diff --git a/irctest/controllers/inspircd.py b/irctest/controllers/inspircd.py index 133b244..5b5d49f 100644 --- a/irctest/controllers/inspircd.py +++ b/irctest/controllers/inspircd.py @@ -1,6 +1,6 @@ import os import subprocess -from typing import Set +from typing import Optional, Set, Type from irctest.basecontrollers import ( BaseServerController, @@ -42,21 +42,22 @@ class InspircdController(BaseServerController, DirectoryBasedController): supported_sasl_mechanisms: Set[str] = set() supports_str = False - def create_config(self): + def create_config(self) -> None: super().create_config() with self.open_file("server.conf"): pass def run( self, - hostname, - port, - password=None, - ssl=False, - restricted_metadata_keys=None, - valid_metadata_keys=None, - invalid_metadata_keys=None, - ): + hostname: str, + port: int, + *, + password: Optional[str], + ssl: bool, + valid_metadata_keys: Optional[Set[str]] = None, + invalid_metadata_keys: Optional[Set[str]] = None, + restricted_metadata_keys: Optional[Set[str]] = None, + ) -> None: if valid_metadata_keys or invalid_metadata_keys: raise NotImplementedByController( "Defining valid and invalid METADATA keys." @@ -81,6 +82,7 @@ class InspircdController(BaseServerController, DirectoryBasedController): ssl_config=ssl_config, ) ) + assert self.directory self.proc = subprocess.Popen( [ "inspircd", @@ -92,5 +94,5 @@ class InspircdController(BaseServerController, DirectoryBasedController): ) -def get_irctest_controller_class(): +def get_irctest_controller_class() -> Type[InspircdController]: return InspircdController diff --git a/irctest/controllers/ircd_seven.py b/irctest/controllers/ircd_seven.py index d998221..c6bd3f1 100644 --- a/irctest/controllers/ircd_seven.py +++ b/irctest/controllers/ircd_seven.py @@ -1,3 +1,5 @@ +from typing import Type + from .charybdis import CharybdisController @@ -6,5 +8,5 @@ class IrcdSevenController(CharybdisController): binary_name = "ircd-seven" -def get_irctest_controller_class(): +def get_irctest_controller_class() -> Type[IrcdSevenController]: return IrcdSevenController diff --git a/irctest/controllers/limnoria.py b/irctest/controllers/limnoria.py index 4a38d5d..78b005f 100644 --- a/irctest/controllers/limnoria.py +++ b/irctest/controllers/limnoria.py @@ -1,7 +1,8 @@ import os import subprocess +from typing import Optional, Type -from irctest import tls +from irctest import authentication, tls from irctest.basecontrollers import BaseClientController, DirectoryBasedController TEMPLATE_CONFIG = """ @@ -35,15 +36,20 @@ class LimnoriaController(BaseClientController, DirectoryBasedController): } supports_sts = True - def create_config(self): - create_config = super().create_config() - if create_config: - with self.open_file("bot.conf"): - pass - with self.open_file("conf/users.conf"): - pass + def create_config(self) -> None: + super().create_config() + with self.open_file("bot.conf"): + pass + with self.open_file("conf/users.conf"): + pass - def run(self, hostname, port, auth, tls_config=None): + def run( + self, + hostname: str, + port: int, + auth: Optional[authentication.Authentication], + tls_config: Optional[tls.TlsConfig] = None, + ) -> None: if tls_config is None: tls_config = tls.TlsConfig(enable=False, trusted_fingerprints=[]) # Runs a client with the config given as arguments @@ -72,10 +78,11 @@ class LimnoriaController(BaseClientController, DirectoryBasedController): else "", ) ) + assert self.directory self.proc = subprocess.Popen( ["supybot", os.path.join(self.directory, "bot.conf")] ) -def get_irctest_controller_class(): +def get_irctest_controller_class() -> Type[LimnoriaController]: return LimnoriaController diff --git a/irctest/controllers/mammon.py b/irctest/controllers/mammon.py index bbeb4ef..57cdc35 100644 --- a/irctest/controllers/mammon.py +++ b/irctest/controllers/mammon.py @@ -1,11 +1,13 @@ import os import subprocess +from typing import Optional, Set, Type from irctest.basecontrollers import ( BaseServerController, DirectoryBasedController, NotImplementedByController, ) +from irctest.cases import BaseServerTestCase TEMPLATE_CONFIG = """ clients: @@ -61,7 +63,7 @@ server: """ -def make_list(list_): +def make_list(list_: Set[str]) -> str: return "\n".join(map(" - {}".format, list_)) @@ -69,25 +71,27 @@ class MammonController(BaseServerController, DirectoryBasedController): software_name = "Mammon" supported_sasl_mechanisms = {"PLAIN", "ECDSA-NIST256P-CHALLENGE"} - def create_config(self): + def create_config(self) -> None: super().create_config() with self.open_file("server.conf"): pass - def kill_proc(self): + def kill_proc(self) -> None: # Mammon does not seem to handle SIGTERM very well + assert self.proc self.proc.kill() def run( self, - hostname, - port, - password=None, - ssl=False, - restricted_metadata_keys=(), - valid_metadata_keys=(), - invalid_metadata_keys=(), - ): + hostname: str, + port: int, + *, + password: Optional[str], + ssl: bool, + valid_metadata_keys: Optional[Set[str]] = None, + invalid_metadata_keys: Optional[Set[str]] = None, + restricted_metadata_keys: Optional[Set[str]] = None, + ) -> None: if password is not None: raise NotImplementedByController("PASS command") if ssl: @@ -101,12 +105,13 @@ class MammonController(BaseServerController, DirectoryBasedController): directory=self.directory, hostname=hostname, port=port, - authorized_keys=make_list(valid_metadata_keys), - restricted_keys=make_list(restricted_metadata_keys), + authorized_keys=make_list(valid_metadata_keys or set()), + restricted_keys=make_list(restricted_metadata_keys or set()), ) ) # with self.open_file('server.yml', 'r') as fd: # print(fd.read()) + assert self.directory self.proc = subprocess.Popen( [ "mammond", @@ -116,7 +121,12 @@ class MammonController(BaseServerController, DirectoryBasedController): ] ) - def registerUser(self, case, username, password=None): + def registerUser( + self, + case: BaseServerTestCase, + username: str, + password: Optional[str] = None, + ) -> None: # XXX: Move this somewhere else when # https://github.com/ircv3/ircv3-specifications/pull/152 becomes # part of the specification @@ -135,5 +145,5 @@ class MammonController(BaseServerController, DirectoryBasedController): case.removeClient(client) -def get_irctest_controller_class(): +def get_irctest_controller_class() -> Type[MammonController]: return MammonController diff --git a/irctest/controllers/oragono.py b/irctest/controllers/oragono.py index b2082e2..55a2280 100644 --- a/irctest/controllers/oragono.py +++ b/irctest/controllers/oragono.py @@ -2,12 +2,14 @@ import copy import json import os import subprocess +from typing import Any, Dict, Optional, Set, Type, Union from irctest.basecontrollers import ( BaseServerController, DirectoryBasedController, NotImplementedByController, ) +from irctest.cases import BaseServerTestCase OPER_PWD = "frenchfries" @@ -116,7 +118,7 @@ BASE_CONFIG = { LOGGING_CONFIG = {"logging": [{"method": "stderr", "level": "debug", "type": "*"}]} -def hash_password(password): +def hash_password(password: Union[str, bytes]) -> str: if isinstance(password, str): password = password.encode("utf-8") # simulate entry of password and confirmation: @@ -134,25 +136,23 @@ class OragonoController(BaseServerController, DirectoryBasedController): supported_sasl_mechanisms = {"PLAIN"} supports_sts = True - def create_config(self): + def create_config(self) -> None: super().create_config() with self.open_file("ircd.yaml"): pass - def kill_proc(self): - self.proc.kill() - def run( self, - hostname, - port, - password=None, - ssl=False, - restricted_metadata_keys=None, - valid_metadata_keys=None, - invalid_metadata_keys=None, - config=None, - ): + hostname: str, + port: int, + *, + password: Optional[str], + ssl: bool, + valid_metadata_keys: Optional[Set[str]] = None, + invalid_metadata_keys: Optional[Set[str]] = None, + restricted_metadata_keys: Optional[Set[str]] = None, + config: Optional[Any] = None, + ) -> None: if valid_metadata_keys or invalid_metadata_keys: raise NotImplementedByController( "Defining valid and invalid METADATA keys." @@ -162,6 +162,8 @@ class OragonoController(BaseServerController, DirectoryBasedController): if config is None: config = copy.deepcopy(BASE_CONFIG) + assert self.directory + enable_chathistory = self.test_config.chathistory enable_roleplay = self.test_config.oragono_roleplay if enable_chathistory or enable_roleplay: @@ -180,12 +182,14 @@ class OragonoController(BaseServerController, DirectoryBasedController): self.key_path = os.path.join(self.directory, "ssl.key") self.pem_path = os.path.join(self.directory, "ssl.pem") listener_conf = {"tls": {"cert": self.pem_path, "key": self.key_path}} - config["server"]["listeners"][bind_address] = listener_conf + config["server"]["listeners"][bind_address] = listener_conf # type: ignore - config["datastore"]["path"] = os.path.join(self.directory, "ircd.db") + config["datastore"]["path"] = os.path.join( # type: ignore + self.directory, "ircd.db" + ) if password is not None: - config["server"]["password"] = hash_password(password) + config["server"]["password"] = hash_password(password) # type: ignore assert self.proc is None @@ -198,7 +202,12 @@ class OragonoController(BaseServerController, DirectoryBasedController): ["oragono", "run", "--conf", self._config_path, "--quiet"] ) - def registerUser(self, case, username, password=None): + def registerUser( + self, + case: BaseServerTestCase, + username: str, + password: Optional[str] = None, + ) -> None: # XXX: Move this somewhere else when # https://github.com/ircv3/ircv3-specifications/pull/152 becomes # part of the specification @@ -210,34 +219,35 @@ class OragonoController(BaseServerController, DirectoryBasedController): while case.getRegistrationMessage(client).command != "001": pass case.getMessages(client) + assert password case.sendLine(client, "NS REGISTER " + password) msg = case.getMessage(client) assert msg.params == [username, "Account created"] case.sendLine(client, "QUIT") case.assertDisconnected(client) - def _write_config(self): + def _write_config(self) -> None: with open(self._config_path, "w") as fd: json.dump(self._config, fd) - def baseConfig(self): + def baseConfig(self) -> Dict: return copy.deepcopy(BASE_CONFIG) - def getConfig(self): + def getConfig(self) -> Dict: return copy.deepcopy(self._config) - def addLoggingToConfig(self, config=None): + def addLoggingToConfig(self, config: Optional[Dict] = None) -> Dict: if config is None: config = self.baseConfig() config.update(LOGGING_CONFIG) return config - def addMysqlToConfig(self, config=None): + def addMysqlToConfig(self, config: Optional[Dict] = None) -> Dict: mysql_password = os.getenv("MYSQL_PASSWORD") - if not mysql_password: - return config if config is None: config = self.baseConfig() + if not mysql_password: + return config config["datastore"]["mysql"] = { "enabled": True, "host": "localhost", @@ -259,7 +269,7 @@ class OragonoController(BaseServerController, DirectoryBasedController): } return config - def rehash(self, case, config): + def rehash(self, case: BaseServerTestCase, config: Dict) -> None: self._config = config self._write_config() client = "operator_for_rehash" @@ -270,11 +280,11 @@ class OragonoController(BaseServerController, DirectoryBasedController): case.sendLine(client, "QUIT") case.assertDisconnected(client) - def enable_debug_logging(self, case): + def enable_debug_logging(self, case: BaseServerTestCase) -> None: config = self.getConfig() config.update(LOGGING_CONFIG) self.rehash(case, config) -def get_irctest_controller_class(): +def get_irctest_controller_class() -> Type[OragonoController]: return OragonoController diff --git a/irctest/controllers/solanum.py b/irctest/controllers/solanum.py index 063e03e..9314a54 100644 --- a/irctest/controllers/solanum.py +++ b/irctest/controllers/solanum.py @@ -1,3 +1,5 @@ +from typing import Type + from .charybdis import CharybdisController @@ -6,5 +8,5 @@ class SolanumController(CharybdisController): binary_name = "solanum" -def get_irctest_controller_class(): +def get_irctest_controller_class() -> Type[SolanumController]: return SolanumController diff --git a/irctest/controllers/sopel.py b/irctest/controllers/sopel.py index 37dce97..7304315 100644 --- a/irctest/controllers/sopel.py +++ b/irctest/controllers/sopel.py @@ -1,8 +1,14 @@ import os import subprocess import tempfile +from typing import Optional, TextIO, Type, cast -from irctest.basecontrollers import BaseClientController, NotImplementedByController +from irctest import authentication, tls +from irctest.basecontrollers import ( + BaseClientController, + NotImplementedByController, + TestCaseControllerConfig, +) TEMPLATE_CONFIG = """ [core] @@ -24,30 +30,34 @@ class SopelController(BaseClientController): supported_sasl_mechanisms = {"PLAIN"} supports_sts = False - def __init__(self, test_config): + def __init__(self, test_config: TestCaseControllerConfig): super().__init__(test_config) - self.filename = next(tempfile._get_candidate_names()) + ".cfg" - self.proc = None + self.filename = next(tempfile._get_candidate_names()) + ".cfg" # type: ignore - def kill(self): - if self.proc: - self.proc.kill() + def kill(self) -> None: + super().kill() if self.filename: try: os.unlink(os.path.join(os.path.expanduser("~/.sopel/"), self.filename)) except OSError: #  File does not exist pass - def open_file(self, filename, mode="a"): + def open_file(self, filename: str, mode: str = "a") -> TextIO: dir_path = os.path.expanduser("~/.sopel/") os.makedirs(dir_path, exist_ok=True) - return open(os.path.join(dir_path, filename), mode) + return cast(TextIO, open(os.path.join(dir_path, filename), mode)) - def create_config(self): + def create_config(self) -> None: with self.open_file(self.filename): pass - def run(self, hostname, port, auth, tls_config): + def run( + self, + hostname: str, + port: int, + auth: Optional[authentication.Authentication], + tls_config: Optional[tls.TlsConfig] = None, + ) -> None: # Runs a client with the config given as arguments if tls_config is not None: raise NotImplementedByController("TLS configuration") @@ -66,5 +76,5 @@ class SopelController(BaseClientController): self.proc = subprocess.Popen(["sopel", "--quiet", "-c", self.filename]) -def get_irctest_controller_class(): +def get_irctest_controller_class() -> Type[SopelController]: return SopelController diff --git a/irctest/irc_utils/ambiguities.py b/irctest/irc_utils/ambiguities.py index 1b7b943..4fac710 100644 --- a/irctest/irc_utils/ambiguities.py +++ b/irctest/irc_utils/ambiguities.py @@ -2,8 +2,10 @@ Handles ambiguities of RFCs. """ +from typing import List -def normalize_namreply_params(params): + +def normalize_namreply_params(params: List[str]) -> List[str]: # So… RFC 2812 says: # "( "=" / "*" / "@" ) # :[ "@" / "+" ] *( " " [ "@" / "+" ] ) @@ -12,6 +14,7 @@ def normalize_namreply_params(params): # prefix. # So let's normalize this to “with space”, and strip spaces at the # end of the nick list. + params = list(params) # copy the list if len(params) == 3: assert params[1][0] in "=*@", params params.insert(1, params[1][0]) diff --git a/irctest/irc_utils/capabilities.py b/irctest/irc_utils/capabilities.py index c38fb9f..e075bea 100644 --- a/irctest/irc_utils/capabilities.py +++ b/irctest/irc_utils/capabilities.py @@ -1,10 +1,12 @@ -def cap_list_to_dict(caps): - d = {} +from typing import Dict, List, Optional + + +def cap_list_to_dict(caps: List[str]) -> Dict[str, Optional[str]]: + d: Dict[str, Optional[str]] = {} for cap in caps: if "=" in cap: (key, value) = cap.split("=", 1) + d[key] = value else: - key = cap - value = None - d[key] = value + d[cap] = None return d diff --git a/irctest/irc_utils/junkdrawer.py b/irctest/irc_utils/junkdrawer.py index dc39738..249d023 100644 --- a/irctest/irc_utils/junkdrawer.py +++ b/irctest/irc_utils/junkdrawer.py @@ -1,16 +1,17 @@ import datetime import re import secrets +from typing import Dict # thanks jess! IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z" -def ircv3_timestamp_to_unixtime(timestamp): +def ircv3_timestamp_to_unixtime(timestamp: str) -> float: return datetime.datetime.strptime(timestamp, IRCV3_FORMAT_STRFTIME).timestamp() -def random_name(base): +def random_name(base: str) -> str: return base + "-" + secrets.token_hex(8) @@ -26,16 +27,16 @@ class MultipleReplacer: # We use an object instead of a lambda function because it avoids the # need for using the staticmethod() on the lambda function if assigning # it to a class in Python 3. - def __init__(self, dict_): + def __init__(self, dict_: Dict[str, str]): self._dict = dict_ dict_ = dict([(re.escape(key), val) for key, val in dict_.items()]) self._matcher = re.compile("|".join(dict_.keys())) - def __call__(self, s): + def __call__(self, s: str) -> str: return self._matcher.sub(lambda m: self._dict[m.group(0)], s) -def normalizeWhitespace(s, removeNewline=True): +def normalizeWhitespace(s: str, removeNewline: bool = True) -> str: r"""Normalizes the whitespace in a string; \s+ becomes one space.""" if not s: return str(s) # not the same reference diff --git a/irctest/irc_utils/message_parser.py b/irctest/irc_utils/message_parser.py index 22f2b99..f24a549 100644 --- a/irctest/irc_utils/message_parser.py +++ b/irctest/irc_utils/message_parser.py @@ -18,8 +18,8 @@ unescape_tag_value = MultipleReplacer(dict(map(lambda x: (x[1], x[0]), TAG_ESCAP tag_key_validator = re.compile(r"\+?(\S+/)?[a-zA-Z0-9-]+") -def parse_tags(s): - tags = {} +def parse_tags(s: str) -> Dict[str, Optional[str]]: + tags: Dict[str, Optional[str]] = {} for tag in s.split(";"): if "=" not in tag: tags[tag] = None @@ -54,15 +54,15 @@ class Message: ) -def parse_message(s): +def parse_message(s: str) -> Message: """Parse a message according to http://tools.ietf.org/html/rfc1459#section-2.3.1 and http://ircv3.net/specs/core/message-tags-3.2.html""" s = s.rstrip("\r\n") if s.startswith("@"): - (tags, s) = s.split(" ", 1) - tags = parse_tags(tags[1:]) + (tags_str, s) = s.split(" ", 1) + tags = parse_tags(tags_str[1:]) else: tags = {} if " :" in s: @@ -70,10 +70,7 @@ def parse_message(s): tokens = list(filter(bool, other_tokens.split(" "))) + [trailing_param] else: tokens = list(filter(bool, s.split(" "))) - if tokens[0].startswith(":"): - prefix = tokens.pop(0)[1:] - else: - prefix = None + prefix = prefix = tokens.pop(0)[1:] if tokens[0].startswith(":") else None command = tokens.pop(0) params = tokens return Message(tags=tags, prefix=prefix, command=command, params=params) diff --git a/irctest/irc_utils/sasl.py b/irctest/irc_utils/sasl.py index 4957af2..120ea7b 100644 --- a/irctest/irc_utils/sasl.py +++ b/irctest/irc_utils/sasl.py @@ -1,7 +1,7 @@ import base64 -def sasl_plain_blob(username, passphrase): +def sasl_plain_blob(username: str, passphrase: str) -> str: blob = base64.b64encode( b"\x00".join( ( diff --git a/irctest/runner.py b/irctest/runner.py index 64bd4c8..f493e1e 100644 --- a/irctest/runner.py +++ b/irctest/runner.py @@ -1,14 +1,15 @@ import collections +from typing import Dict, Union import unittest class NotImplementedByController(unittest.SkipTest, NotImplementedError): - def __str__(self): + def __str__(self) -> str: return "Not implemented by controller: {}".format(self.args[0]) class ImplementationChoice(unittest.SkipTest): - def __str__(self): + def __str__(self) -> str: return ( "Choice in the implementation makes it impossible to " "perform a test: {}".format(self.args[0]) @@ -16,49 +17,49 @@ class ImplementationChoice(unittest.SkipTest): class OptionalExtensionNotSupported(unittest.SkipTest): - def __str__(self): + def __str__(self) -> str: return "Unsupported extension: {}".format(self.args[0]) class OptionalSaslMechanismNotSupported(unittest.SkipTest): - def __str__(self): + def __str__(self) -> str: return "Unsupported SASL mechanism: {}".format(self.args[0]) class CapabilityNotSupported(unittest.SkipTest): - def __str__(self): + def __str__(self) -> str: return "Unsupported capability: {}".format(self.args[0]) class IsupportTokenNotSupported(unittest.SkipTest): - def __str__(self): + def __str__(self) -> str: return "Unsupported ISUPPORT token: {}".format(self.args[0]) class ChannelModeNotSupported(unittest.SkipTest): - def __str__(self): + def __str__(self) -> str: return "Unsupported channel mode: {} ({})".format(self.args[0], self.args[1]) class ExtbanNotSupported(unittest.SkipTest): - def __str__(self): + def __str__(self) -> str: return "Unsupported extban: {} ({})".format(self.args[0], self.args[1]) class NotRequiredBySpecifications(unittest.SkipTest): - def __str__(self): + def __str__(self) -> str: return "Tests not required by the set of tested specification(s)." class SkipStrictTest(unittest.SkipTest): - def __str__(self): + def __str__(self) -> str: return "Tests not required because strict tests are disabled." class TextTestResult(unittest.TextTestResult): - def getDescription(self, test): + def getDescription(self, test: unittest.TestCase) -> str: if hasattr(test, "description"): - doc_first_lines = test.description() + doc_first_lines = test.description() # type: ignore else: doc_first_lines = test.shortDescription() return "\n".join((str(test), doc_first_lines or "")) @@ -71,7 +72,9 @@ class TextTestRunner(unittest.TextTestRunner): resultclass = TextTestResult - def run(self, test): + def run( + self, test: Union[unittest.TestSuite, unittest.TestCase] + ) -> unittest.TestResult: result = super().run(test) assert self.resultclass is TextTestResult if result.skipped: @@ -80,7 +83,7 @@ class TextTestRunner(unittest.TextTestRunner): "Some tests were skipped because the following optional " "specifications/mechanisms are not supported:" ) - msg_to_count = collections.defaultdict(lambda: 0) + msg_to_count: Dict[str, int] = collections.defaultdict(lambda: 0) for (test, msg) in result.skipped: msg_to_count[msg] += 1 for (msg, count) in sorted(msg_to_count.items()): diff --git a/irctest/specifications.py b/irctest/specifications.py index 0ae9af0..928c150 100644 --- a/irctest/specifications.py +++ b/irctest/specifications.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum @@ -15,7 +17,7 @@ class Specifications(enum.Enum): Modern = "modern" @classmethod - def from_name(cls, name): + def from_name(cls, name: str) -> Specifications: name = name.upper() for spec in cls: if spec.value.upper() == name: @@ -37,7 +39,7 @@ class Capabilities(enum.Enum): STS = "sts" @classmethod - def from_name(cls, name): + def from_name(cls, name: str) -> Capabilities: try: return cls(name.lower()) except ValueError: @@ -50,7 +52,7 @@ class IsupportTokens(enum.Enum): STATUSMSG = "STATUSMSG" @classmethod - def from_name(cls, name): + def from_name(cls, name: str) -> IsupportTokens: try: return cls(name.upper()) except ValueError: