diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3e2abca..ada8b4d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,3 +14,8 @@ repos: rev: 3.8.3 hooks: - id: flake8 + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.812 + hooks: + - id: mypy diff --git a/irctest/authentication.py b/irctest/authentication.py index eec2bec..712fef6 100644 --- a/irctest/authentication.py +++ b/irctest/authentication.py @@ -1,5 +1,6 @@ -import collections +import dataclasses import enum +from typing import Optional, Tuple @enum.unique @@ -19,7 +20,9 @@ class Mechanisms(enum.Enum): scram_sha_256 = 3 -Authentication = collections.namedtuple( - "Authentication", "mechanisms username password ecdsa_key" -) -Authentication.__new__.__defaults__ = ([Mechanisms.plain], None, None, None) +@dataclasses.dataclass +class Authentication: + mechanisms: Tuple[Mechanisms] = (Mechanisms.plain,) + username: Optional[str] = None + password: Optional[str] = None + ecdsa_key: Optional[str] = None diff --git a/irctest/basecontrollers.py b/irctest/basecontrollers.py index 5418733..002f4e6 100644 --- a/irctest/basecontrollers.py +++ b/irctest/basecontrollers.py @@ -4,6 +4,7 @@ import socket import subprocess import tempfile import time +from typing import Set from .runner import NotImplementedByController @@ -135,6 +136,9 @@ class BaseServerController(_BaseController): _port_wait_interval = 0.1 port_open = False + supports_sts: bool + supported_sasl_mechanisms: Set[str] + def run(self, hostname, port, password, valid_metadata_keys, invalid_metadata_keys): raise NotImplementedError() diff --git a/irctest/cases.py b/irctest/cases.py index aef4393..3d5b11b 100644 --- a/irctest/cases.py +++ b/irctest/cases.py @@ -3,11 +3,12 @@ import socket import ssl import tempfile import time +from typing import Optional, Set import unittest import pytest -from . import client_mock, runner +from . import basecontrollers, client_mock, runner from .exceptions import ConnectionClosed from .irc_utils import capabilities, message_parser from .irc_utils.junkdrawer import normalizeWhitespace @@ -350,10 +351,10 @@ class BaseServerTestCase(_IrcTestCase): """Basic class for server tests. Handles spawning a server and exchanging messages with it.""" - password = None + password: Optional[str] = None ssl = False - valid_metadata_keys = frozenset() - invalid_metadata_keys = frozenset() + valid_metadata_keys: Set[str] = set() + invalid_metadata_keys: Set[str] = set() def setUp(self): super().setUp() @@ -536,6 +537,8 @@ class BaseServerTestCase(_IrcTestCase): class OptionalityHelper: + controller: basecontrollers.BaseServerController + def checkSaslSupport(self): if self.controller.supported_sasl_mechanisms: return @@ -546,6 +549,7 @@ class OptionalityHelper: return raise runner.OptionalSaslMechanismNotSupported(mechanism) + @staticmethod def skipUnlessHasMechanism(mech): def decorator(f): @functools.wraps(f) @@ -565,22 +569,6 @@ class OptionalityHelper: return newf - def checkCapabilitySupport(self, cap): - if cap in self.controller.supported_capabilities: - return - raise runner.CapabilityNotSupported(cap) - - def skipUnlessSupportsCapability(cap): - def decorator(f): - @functools.wraps(f) - def newf(self): - self.checkCapabilitySupport(cap) - return f(self) - - return newf - - return decorator - def mark_specifications(*specifications, deprecated=False, strict=False): specifications = frozenset( diff --git a/irctest/client_tests/test_tls.py b/irctest/client_tests/test_tls.py index a317cd1..4f826cc 100644 --- a/irctest/client_tests/test_tls.py +++ b/irctest/client_tests/test_tls.py @@ -1,7 +1,7 @@ import socket import ssl -from irctest import cases, tls +from irctest import cases, runner, tls from irctest.exceptions import ConnectionClosed BAD_CERT = """ @@ -146,8 +146,10 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper): self.insecure_server.close() super().tearDown() - @cases.OptionalityHelper.skipUnlessSupportsCapability("sts") + @cases.mark_capabilities("sts") def testSts(self): + if not self.controller.supports_sts: + raise runner.CapabilityNotSupported("sts") tls_config = tls.TlsConfig( enable=False, trusted_fingerprints=[GOOD_FINGERPRINT] ) @@ -191,8 +193,11 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper): # server self.acceptClient() - @cases.OptionalityHelper.skipUnlessSupportsCapability("sts") + @cases.mark_capabilities("sts") def testStsInvalidCertificate(self): + if not self.controller.supports_sts: + raise runner.CapabilityNotSupported("sts") + # Connect client to insecure server (hostname, port) = self.insecure_server.getsockname() self.controller.run(hostname=hostname, port=port, auth=None) diff --git a/irctest/controllers/charybdis.py b/irctest/controllers/charybdis.py index 1dabb04..05a3d06 100644 --- a/irctest/controllers/charybdis.py +++ b/irctest/controllers/charybdis.py @@ -1,5 +1,6 @@ import os import subprocess +from typing import Set from irctest.basecontrollers import ( BaseServerController, @@ -43,8 +44,8 @@ TEMPLATE_SSL_CONFIG = """ class CharybdisController(BaseServerController, DirectoryBasedController): software_name = "Charybdis" binary_name = "charybdis" - supported_sasl_mechanisms = set() - supported_capabilities = set() # Not exhaustive + supported_sasl_mechanisms: Set[str] = set() + supports_sts = False def create_config(self): super().create_config() diff --git a/irctest/controllers/girc.py b/irctest/controllers/girc.py index 5c52aba..e4a64a7 100644 --- a/irctest/controllers/girc.py +++ b/irctest/controllers/girc.py @@ -6,7 +6,6 @@ from irctest.basecontrollers import BaseClientController, NotImplementedByContro class GircController(BaseClientController): software_name = "gIRC" supported_sasl_mechanisms = ["PLAIN"] - supported_capabilities = set() # Not exhaustive def __init__(self): super().__init__() diff --git a/irctest/controllers/hybrid.py b/irctest/controllers/hybrid.py index 7eed360..eb0ed13 100644 --- a/irctest/controllers/hybrid.py +++ b/irctest/controllers/hybrid.py @@ -1,5 +1,6 @@ import os import subprocess +from typing import Set from irctest.basecontrollers import ( BaseServerController, @@ -40,8 +41,8 @@ TEMPLATE_SSL_CONFIG = """ class HybridController(BaseServerController, DirectoryBasedController): software_name = "Hybrid" - supported_sasl_mechanisms = set() - supported_capabilities = set() # Not exhaustive + supports_sts = False + supported_sasl_mechanisms: Set[str] = set() def create_config(self): super().create_config() diff --git a/irctest/controllers/inspircd.py b/irctest/controllers/inspircd.py index 0432eb6..133b244 100644 --- a/irctest/controllers/inspircd.py +++ b/irctest/controllers/inspircd.py @@ -1,5 +1,6 @@ import os import subprocess +from typing import Set from irctest.basecontrollers import ( BaseServerController, @@ -38,8 +39,8 @@ TEMPLATE_SSL_CONFIG = """ class InspircdController(BaseServerController, DirectoryBasedController): software_name = "InspIRCd" - supported_sasl_mechanisms = set() - supported_capabilities = set() # Not exhaustive + supported_sasl_mechanisms: Set[str] = set() + supports_str = False def create_config(self): super().create_config() diff --git a/irctest/controllers/limnoria.py b/irctest/controllers/limnoria.py index 309226e..aad91eb 100644 --- a/irctest/controllers/limnoria.py +++ b/irctest/controllers/limnoria.py @@ -33,7 +33,7 @@ class LimnoriaController(BaseClientController, DirectoryBasedController): "SCRAM-SHA-256", "EXTERNAL", } - supported_capabilities = set(["sts"]) # Not exhaustive + supports_sts = True def create_config(self): create_config = super().create_config() diff --git a/irctest/controllers/mammon.py b/irctest/controllers/mammon.py index 4601664..bbeb4ef 100644 --- a/irctest/controllers/mammon.py +++ b/irctest/controllers/mammon.py @@ -68,7 +68,6 @@ def make_list(list_): class MammonController(BaseServerController, DirectoryBasedController): software_name = "Mammon" supported_sasl_mechanisms = {"PLAIN", "ECDSA-NIST256P-CHALLENGE"} - supported_capabilities = set() # Not exhaustive def create_config(self): super().create_config() diff --git a/irctest/controllers/oragono.py b/irctest/controllers/oragono.py index 9c40830..1a22a89 100644 --- a/irctest/controllers/oragono.py +++ b/irctest/controllers/oragono.py @@ -130,9 +130,9 @@ def hash_password(password): class OragonoController(BaseServerController, DirectoryBasedController): software_name = "Oragono" - supported_sasl_mechanisms = {"PLAIN"} _port_wait_interval = 0.01 - supported_capabilities = set() # Not exhaustive + supported_sasl_mechanisms = {"PLAIN"} + supports_sts = True def create_config(self): super().create_config() diff --git a/irctest/controllers/sopel.py b/irctest/controllers/sopel.py index 951e770..37dce97 100644 --- a/irctest/controllers/sopel.py +++ b/irctest/controllers/sopel.py @@ -22,7 +22,7 @@ auth_password = {password} class SopelController(BaseClientController): software_name = "Sopel" supported_sasl_mechanisms = {"PLAIN"} - supported_capabilities = set() # Not exhaustive + supports_sts = False def __init__(self, test_config): super().__init__(test_config) diff --git a/irctest/server_tests/test_channel_operations.py b/irctest/server_tests/test_channel_operations.py index 733c9e0..a16826c 100644 --- a/irctest/server_tests/test_channel_operations.py +++ b/irctest/server_tests/test_channel_operations.py @@ -709,55 +709,61 @@ class JoinTestCase(cases.BaseServerTestCase): ) +def _testChannelsEquivalent(casemapping, name1, name2): + """Generates test functions""" + + @cases.mark_specifications("RFC1459", "RFC2812", strict=True) + def f(self): + self.connectClient("foo") + self.connectClient("bar") + if self.server_support["CASEMAPPING"] != casemapping: + raise runner.NotImplementedByController( + "Casemapping {} not implemented".format(casemapping) + ) + self.joinClient(1, name1) + self.joinClient(2, name2) + try: + m = self.getMessage(1) + self.assertMessageEqual(m, command="JOIN", nick="bar") + except client_mock.NoMessageException: + raise AssertionError( + "Channel names {} and {} are not equivalent.".format(name1, name2) + ) + + f.__name__ = "testEquivalence__{}__{}".format(name1, name2) + return f + + +def _testChannelsNotEquivalent(casemapping, name1, name2): + """Generates test functions""" + + @cases.mark_specifications("RFC1459", "RFC2812", strict=True) + def f(self): + self.connectClient("foo") + self.connectClient("bar") + if self.server_support["CASEMAPPING"] != casemapping: + raise runner.NotImplementedByController( + "Casemapping {} not implemented".format(casemapping) + ) + self.joinClient(1, name1) + self.joinClient(2, name2) + try: + m = self.getMessage(1) + except client_mock.NoMessageException: + pass + else: + self.assertMessageEqual( + m, command="JOIN", nick="bar" + ) # This should always be true + raise AssertionError( + "Channel names {} and {} are equivalent.".format(name1, name2) + ) + + f.__name__ = "testEquivalence__{}__{}".format(name1, name2) + return f + + class testChannelCaseSensitivity(cases.BaseServerTestCase): - def _testChannelsEquivalent(casemapping, name1, name2): - @cases.mark_specifications("RFC1459", "RFC2812", strict=True) - def f(self): - self.connectClient("foo") - self.connectClient("bar") - if self.server_support["CASEMAPPING"] != casemapping: - raise runner.NotImplementedByController( - "Casemapping {} not implemented".format(casemapping) - ) - self.joinClient(1, name1) - self.joinClient(2, name2) - try: - m = self.getMessage(1) - self.assertMessageEqual(m, command="JOIN", nick="bar") - except client_mock.NoMessageException: - raise AssertionError( - "Channel names {} and {} are not equivalent.".format(name1, name2) - ) - - f.__name__ = "testEquivalence__{}__{}".format(name1, name2) - return f - - def _testChannelsNotEquivalent(casemapping, name1, name2): - @cases.mark_specifications("RFC1459", "RFC2812", strict=True) - def f(self): - self.connectClient("foo") - self.connectClient("bar") - if self.server_support["CASEMAPPING"] != casemapping: - raise runner.NotImplementedByController( - "Casemapping {} not implemented".format(casemapping) - ) - self.joinClient(1, name1) - self.joinClient(2, name2) - try: - m = self.getMessage(1) - except client_mock.NoMessageException: - pass - else: - self.assertMessageEqual( - m, command="JOIN", nick="bar" - ) # This should always be true - raise AssertionError( - "Channel names {} and {} are equivalent.".format(name1, name2) - ) - - f.__name__ = "testEquivalence__{}__{}".format(name1, name2) - return f - testAsciiSimpleEquivalent = _testChannelsEquivalent("ascii", "#Foo", "#foo") testAsciiSimpleNotEquivalent = _testChannelsNotEquivalent("ascii", "#Foo", "#fooa") diff --git a/irctest/server_tests/test_echo_message.py b/irctest/server_tests/test_echo_message.py index ad2057d..e94eb19 100644 --- a/irctest/server_tests/test_echo_message.py +++ b/irctest/server_tests/test_echo_message.py @@ -7,6 +7,94 @@ from irctest.basecontrollers import NotImplementedByController from irctest.irc_utils.junkdrawer import random_name +def _testEchoMessage(command, solo, server_time): + """Generates test functions""" + + @cases.mark_capabilities("echo-message") + def f(self): + """""" + self.addClient() + self.sendLine(1, "CAP LS 302") + capabilities = self.getCapLs(1) + if "echo-message" not in capabilities: + raise NotImplementedByController("echo-message") + if server_time and "server-time" not in capabilities: + raise NotImplementedByController("server-time") + + # TODO: check also without this + self.sendLine( + 1, + "CAP REQ :echo-message{}".format(" server-time" if server_time else ""), + ) + self.getRegistrationMessage(1) + # TODO: Remove this one the trailing space issue is fixed in Charybdis + # and Mammon: + # self.assertMessageEqual(m, command='CAP', + # params=['*', 'ACK', 'echo-message'] + + # (['server-time'] if server_time else []), + # fail_msg='Did not ACK advertised capabilities: {msg}') + self.sendLine(1, "USER f * * :foo") + self.sendLine(1, "NICK baz") + self.sendLine(1, "CAP END") + self.skipToWelcome(1) + self.getMessages(1) + + self.sendLine(1, "JOIN #chan") + + if not solo: + capabilities = ["server-time"] if server_time else None + self.connectClient("qux", capabilities=capabilities) + self.sendLine(2, "JOIN #chan") + + # Synchronize and clean + self.getMessages(1) + if not solo: + self.getMessages(2) + self.getMessages(1) + + self.sendLine(1, "{} #chan :hello everyone".format(command)) + m1 = self.getMessage(1) + self.assertMessageEqual( + m1, + command=command, + params=["#chan", "hello everyone"], + fail_msg="Did not echo “{} #chan :hello everyone”: {msg}", + extra_format=(command,), + ) + + if not solo: + m2 = self.getMessage(2) + self.assertMessageEqual( + m2, + command=command, + params=["#chan", "hello everyone"], + fail_msg="Did not propagate “{} #chan :hello everyone”: " + "after echoing it to the author: {msg}", + extra_format=(command,), + ) + self.assertEqual( + m1.params, + m2.params, + fail_msg="Parameters of forwarded and echoed " "messages differ: {} {}", + extra_format=(m1, m2), + ) + if server_time: + self.assertIn( + "time", + m1.tags, + fail_msg="Echoed message is missing server time: {}", + extra_format=(m1,), + ) + self.assertIn( + "time", + m2.tags, + fail_msg="Forwarded message is missing server time: {}", + extra_format=(m2,), + ) + + return f + + class EchoMessageTestCase(cases.BaseServerTestCase): @cases.mark_capabilities("labeled-response", "echo-message", "message-tags") def testDirectMessageEcho(self): @@ -51,92 +139,6 @@ class EchoMessageTestCase(cases.BaseServerTestCase): delivery.tags["+example-client-tag"], echo.tags["+example-client-tag"] ) - def _testEchoMessage(command, solo, server_time): - @cases.mark_capabilities("echo-message") - def f(self): - """""" - self.addClient() - self.sendLine(1, "CAP LS 302") - capabilities = self.getCapLs(1) - if "echo-message" not in capabilities: - raise NotImplementedByController("echo-message") - if server_time and "server-time" not in capabilities: - raise NotImplementedByController("server-time") - - # TODO: check also without this - self.sendLine( - 1, - "CAP REQ :echo-message{}".format(" server-time" if server_time else ""), - ) - self.getRegistrationMessage(1) - # TODO: Remove this one the trailing space issue is fixed in Charybdis - # and Mammon: - # self.assertMessageEqual(m, command='CAP', - # params=['*', 'ACK', 'echo-message'] + - # (['server-time'] if server_time else []), - # fail_msg='Did not ACK advertised capabilities: {msg}') - self.sendLine(1, "USER f * * :foo") - self.sendLine(1, "NICK baz") - self.sendLine(1, "CAP END") - self.skipToWelcome(1) - self.getMessages(1) - - self.sendLine(1, "JOIN #chan") - - if not solo: - capabilities = ["server-time"] if server_time else None - self.connectClient("qux", capabilities=capabilities) - self.sendLine(2, "JOIN #chan") - - # Synchronize and clean - self.getMessages(1) - if not solo: - self.getMessages(2) - self.getMessages(1) - - self.sendLine(1, "{} #chan :hello everyone".format(command)) - m1 = self.getMessage(1) - self.assertMessageEqual( - m1, - command=command, - params=["#chan", "hello everyone"], - fail_msg="Did not echo “{} #chan :hello everyone”: {msg}", - extra_format=(command,), - ) - - if not solo: - m2 = self.getMessage(2) - self.assertMessageEqual( - m2, - command=command, - params=["#chan", "hello everyone"], - fail_msg="Did not propagate “{} #chan :hello everyone”: " - "after echoing it to the author: {msg}", - extra_format=(command,), - ) - self.assertEqual( - m1.params, - m2.params, - fail_msg="Parameters of forwarded and echoed " - "messages differ: {} {}", - extra_format=(m1, m2), - ) - if server_time: - self.assertIn( - "time", - m1.tags, - fail_msg="Echoed message is missing server time: {}", - extra_format=(m1,), - ) - self.assertIn( - "time", - m2.tags, - fail_msg="Forwarded message is missing server time: {}", - extra_format=(m2,), - ) - - return f - testEchoMessagePrivmsgNoServerTime = _testEchoMessage("PRIVMSG", False, False) testEchoMessagePrivmsgSolo = _testEchoMessage("PRIVMSG", True, True) testEchoMessagePrivmsg = _testEchoMessage("PRIVMSG", False, True) diff --git a/irctest/server_tests/test_lusers.py b/irctest/server_tests/test_lusers.py index dfdccba..2f233e6 100644 --- a/irctest/server_tests/test_lusers.py +++ b/irctest/server_tests/test_lusers.py @@ -23,12 +23,12 @@ LUSERME_REGEX = re.compile(r"^.*( [-0-9]* ).*( [-0-9]* ).*$") @dataclass class LusersResult: - GlobalVisible: int = None - GlobalInvisible: int = None - Servers: int = None - Opers: int = None + GlobalVisible: Optional[int] = None + GlobalInvisible: Optional[int] = None + Servers: Optional[int] = None + Opers: Optional[int] = None Unregistered: Optional[int] = None - Channels: int = None + Channels: Optional[int] = None LocalTotal: Optional[int] = None LocalMax: Optional[int] = None GlobalTotal: Optional[int] = None diff --git a/irctest/specifications.py b/irctest/specifications.py index 1b5b8a7..0ae9af0 100644 --- a/irctest/specifications.py +++ b/irctest/specifications.py @@ -34,6 +34,7 @@ class Capabilities(enum.Enum): MULTILINE = "draft/multiline" MULTI_PREFIX = "multi-prefix" SERVER_TIME = "server-time" + STS = "sts" @classmethod def from_name(cls, name): diff --git a/pytest.ini b/pytest.ini index 96feae7..2e299c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -22,6 +22,7 @@ markers = draft/multiline multi-prefix server-time + sts # isupport tokens MONITOR