Enable mypy, and do the minimal changes to make it pass

This commit is contained in:
Valentin Lorentz 2021-02-28 12:23:06 +01:00 committed by Valentin Lorentz
parent 1c1b8214a0
commit 12da7e1e3b
18 changed files with 195 additions and 179 deletions

View File

@ -14,3 +14,8 @@ repos:
rev: 3.8.3
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
hooks:
- id: mypy

View File

@ -1,5 +1,6 @@
import collections
import dataclasses
import enum
from typing import Optional, Tuple
@enum.unique
@ -19,7 +20,9 @@ class Mechanisms(enum.Enum):
scram_sha_256 = 3
Authentication = collections.namedtuple(
"Authentication", "mechanisms username password ecdsa_key"
)
Authentication.__new__.__defaults__ = ([Mechanisms.plain], None, None, None)
@dataclasses.dataclass
class Authentication:
mechanisms: Tuple[Mechanisms] = (Mechanisms.plain,)
username: Optional[str] = None
password: Optional[str] = None
ecdsa_key: Optional[str] = None

View File

@ -4,6 +4,7 @@ import socket
import subprocess
import tempfile
import time
from typing import Set
from .runner import NotImplementedByController
@ -135,6 +136,9 @@ class BaseServerController(_BaseController):
_port_wait_interval = 0.1
port_open = False
supports_sts: bool
supported_sasl_mechanisms: Set[str]
def run(self, hostname, port, password, valid_metadata_keys, invalid_metadata_keys):
raise NotImplementedError()

View File

@ -3,11 +3,12 @@ import socket
import ssl
import tempfile
import time
from typing import Optional, Set
import unittest
import pytest
from . import client_mock, runner
from . import basecontrollers, client_mock, runner
from .exceptions import ConnectionClosed
from .irc_utils import capabilities, message_parser
from .irc_utils.junkdrawer import normalizeWhitespace
@ -350,10 +351,10 @@ class BaseServerTestCase(_IrcTestCase):
"""Basic class for server tests. Handles spawning a server and exchanging
messages with it."""
password = None
password: Optional[str] = None
ssl = False
valid_metadata_keys = frozenset()
invalid_metadata_keys = frozenset()
valid_metadata_keys: Set[str] = set()
invalid_metadata_keys: Set[str] = set()
def setUp(self):
super().setUp()
@ -536,6 +537,8 @@ class BaseServerTestCase(_IrcTestCase):
class OptionalityHelper:
controller: basecontrollers.BaseServerController
def checkSaslSupport(self):
if self.controller.supported_sasl_mechanisms:
return
@ -546,6 +549,7 @@ class OptionalityHelper:
return
raise runner.OptionalSaslMechanismNotSupported(mechanism)
@staticmethod
def skipUnlessHasMechanism(mech):
def decorator(f):
@functools.wraps(f)
@ -565,22 +569,6 @@ class OptionalityHelper:
return newf
def checkCapabilitySupport(self, cap):
if cap in self.controller.supported_capabilities:
return
raise runner.CapabilityNotSupported(cap)
def skipUnlessSupportsCapability(cap):
def decorator(f):
@functools.wraps(f)
def newf(self):
self.checkCapabilitySupport(cap)
return f(self)
return newf
return decorator
def mark_specifications(*specifications, deprecated=False, strict=False):
specifications = frozenset(

View File

@ -1,7 +1,7 @@
import socket
import ssl
from irctest import cases, tls
from irctest import cases, runner, tls
from irctest.exceptions import ConnectionClosed
BAD_CERT = """
@ -146,8 +146,10 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
self.insecure_server.close()
super().tearDown()
@cases.OptionalityHelper.skipUnlessSupportsCapability("sts")
@cases.mark_capabilities("sts")
def testSts(self):
if not self.controller.supports_sts:
raise runner.CapabilityNotSupported("sts")
tls_config = tls.TlsConfig(
enable=False, trusted_fingerprints=[GOOD_FINGERPRINT]
)
@ -191,8 +193,11 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
# server
self.acceptClient()
@cases.OptionalityHelper.skipUnlessSupportsCapability("sts")
@cases.mark_capabilities("sts")
def testStsInvalidCertificate(self):
if not self.controller.supports_sts:
raise runner.CapabilityNotSupported("sts")
# Connect client to insecure server
(hostname, port) = self.insecure_server.getsockname()
self.controller.run(hostname=hostname, port=port, auth=None)

View File

@ -1,5 +1,6 @@
import os
import subprocess
from typing import Set
from irctest.basecontrollers import (
BaseServerController,
@ -43,8 +44,8 @@ TEMPLATE_SSL_CONFIG = """
class CharybdisController(BaseServerController, DirectoryBasedController):
software_name = "Charybdis"
binary_name = "charybdis"
supported_sasl_mechanisms = set()
supported_capabilities = set() # Not exhaustive
supported_sasl_mechanisms: Set[str] = set()
supports_sts = False
def create_config(self):
super().create_config()

View File

@ -6,7 +6,6 @@ from irctest.basecontrollers import BaseClientController, NotImplementedByContro
class GircController(BaseClientController):
software_name = "gIRC"
supported_sasl_mechanisms = ["PLAIN"]
supported_capabilities = set() # Not exhaustive
def __init__(self):
super().__init__()

View File

@ -1,5 +1,6 @@
import os
import subprocess
from typing import Set
from irctest.basecontrollers import (
BaseServerController,
@ -40,8 +41,8 @@ TEMPLATE_SSL_CONFIG = """
class HybridController(BaseServerController, DirectoryBasedController):
software_name = "Hybrid"
supported_sasl_mechanisms = set()
supported_capabilities = set() # Not exhaustive
supports_sts = False
supported_sasl_mechanisms: Set[str] = set()
def create_config(self):
super().create_config()

View File

@ -1,5 +1,6 @@
import os
import subprocess
from typing import Set
from irctest.basecontrollers import (
BaseServerController,
@ -38,8 +39,8 @@ TEMPLATE_SSL_CONFIG = """
class InspircdController(BaseServerController, DirectoryBasedController):
software_name = "InspIRCd"
supported_sasl_mechanisms = set()
supported_capabilities = set() # Not exhaustive
supported_sasl_mechanisms: Set[str] = set()
supports_str = False
def create_config(self):
super().create_config()

View File

@ -33,7 +33,7 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
"SCRAM-SHA-256",
"EXTERNAL",
}
supported_capabilities = set(["sts"]) # Not exhaustive
supports_sts = True
def create_config(self):
create_config = super().create_config()

View File

@ -68,7 +68,6 @@ def make_list(list_):
class MammonController(BaseServerController, DirectoryBasedController):
software_name = "Mammon"
supported_sasl_mechanisms = {"PLAIN", "ECDSA-NIST256P-CHALLENGE"}
supported_capabilities = set() # Not exhaustive
def create_config(self):
super().create_config()

View File

@ -130,9 +130,9 @@ def hash_password(password):
class OragonoController(BaseServerController, DirectoryBasedController):
software_name = "Oragono"
supported_sasl_mechanisms = {"PLAIN"}
_port_wait_interval = 0.01
supported_capabilities = set() # Not exhaustive
supported_sasl_mechanisms = {"PLAIN"}
supports_sts = True
def create_config(self):
super().create_config()

View File

@ -22,7 +22,7 @@ auth_password = {password}
class SopelController(BaseClientController):
software_name = "Sopel"
supported_sasl_mechanisms = {"PLAIN"}
supported_capabilities = set() # Not exhaustive
supports_sts = False
def __init__(self, test_config):
super().__init__(test_config)

View File

@ -709,55 +709,61 @@ class JoinTestCase(cases.BaseServerTestCase):
)
def _testChannelsEquivalent(casemapping, name1, name2):
"""Generates test functions"""
@cases.mark_specifications("RFC1459", "RFC2812", strict=True)
def f(self):
self.connectClient("foo")
self.connectClient("bar")
if self.server_support["CASEMAPPING"] != casemapping:
raise runner.NotImplementedByController(
"Casemapping {} not implemented".format(casemapping)
)
self.joinClient(1, name1)
self.joinClient(2, name2)
try:
m = self.getMessage(1)
self.assertMessageEqual(m, command="JOIN", nick="bar")
except client_mock.NoMessageException:
raise AssertionError(
"Channel names {} and {} are not equivalent.".format(name1, name2)
)
f.__name__ = "testEquivalence__{}__{}".format(name1, name2)
return f
def _testChannelsNotEquivalent(casemapping, name1, name2):
"""Generates test functions"""
@cases.mark_specifications("RFC1459", "RFC2812", strict=True)
def f(self):
self.connectClient("foo")
self.connectClient("bar")
if self.server_support["CASEMAPPING"] != casemapping:
raise runner.NotImplementedByController(
"Casemapping {} not implemented".format(casemapping)
)
self.joinClient(1, name1)
self.joinClient(2, name2)
try:
m = self.getMessage(1)
except client_mock.NoMessageException:
pass
else:
self.assertMessageEqual(
m, command="JOIN", nick="bar"
) # This should always be true
raise AssertionError(
"Channel names {} and {} are equivalent.".format(name1, name2)
)
f.__name__ = "testEquivalence__{}__{}".format(name1, name2)
return f
class testChannelCaseSensitivity(cases.BaseServerTestCase):
def _testChannelsEquivalent(casemapping, name1, name2):
@cases.mark_specifications("RFC1459", "RFC2812", strict=True)
def f(self):
self.connectClient("foo")
self.connectClient("bar")
if self.server_support["CASEMAPPING"] != casemapping:
raise runner.NotImplementedByController(
"Casemapping {} not implemented".format(casemapping)
)
self.joinClient(1, name1)
self.joinClient(2, name2)
try:
m = self.getMessage(1)
self.assertMessageEqual(m, command="JOIN", nick="bar")
except client_mock.NoMessageException:
raise AssertionError(
"Channel names {} and {} are not equivalent.".format(name1, name2)
)
f.__name__ = "testEquivalence__{}__{}".format(name1, name2)
return f
def _testChannelsNotEquivalent(casemapping, name1, name2):
@cases.mark_specifications("RFC1459", "RFC2812", strict=True)
def f(self):
self.connectClient("foo")
self.connectClient("bar")
if self.server_support["CASEMAPPING"] != casemapping:
raise runner.NotImplementedByController(
"Casemapping {} not implemented".format(casemapping)
)
self.joinClient(1, name1)
self.joinClient(2, name2)
try:
m = self.getMessage(1)
except client_mock.NoMessageException:
pass
else:
self.assertMessageEqual(
m, command="JOIN", nick="bar"
) # This should always be true
raise AssertionError(
"Channel names {} and {} are equivalent.".format(name1, name2)
)
f.__name__ = "testEquivalence__{}__{}".format(name1, name2)
return f
testAsciiSimpleEquivalent = _testChannelsEquivalent("ascii", "#Foo", "#foo")
testAsciiSimpleNotEquivalent = _testChannelsNotEquivalent("ascii", "#Foo", "#fooa")

View File

@ -7,6 +7,94 @@ from irctest.basecontrollers import NotImplementedByController
from irctest.irc_utils.junkdrawer import random_name
def _testEchoMessage(command, solo, server_time):
"""Generates test functions"""
@cases.mark_capabilities("echo-message")
def f(self):
"""<http://ircv3.net/specs/extensions/echo-message-3.2.html>"""
self.addClient()
self.sendLine(1, "CAP LS 302")
capabilities = self.getCapLs(1)
if "echo-message" not in capabilities:
raise NotImplementedByController("echo-message")
if server_time and "server-time" not in capabilities:
raise NotImplementedByController("server-time")
# TODO: check also without this
self.sendLine(
1,
"CAP REQ :echo-message{}".format(" server-time" if server_time else ""),
)
self.getRegistrationMessage(1)
# TODO: Remove this one the trailing space issue is fixed in Charybdis
# and Mammon:
# self.assertMessageEqual(m, command='CAP',
# params=['*', 'ACK', 'echo-message'] +
# (['server-time'] if server_time else []),
# fail_msg='Did not ACK advertised capabilities: {msg}')
self.sendLine(1, "USER f * * :foo")
self.sendLine(1, "NICK baz")
self.sendLine(1, "CAP END")
self.skipToWelcome(1)
self.getMessages(1)
self.sendLine(1, "JOIN #chan")
if not solo:
capabilities = ["server-time"] if server_time else None
self.connectClient("qux", capabilities=capabilities)
self.sendLine(2, "JOIN #chan")
# Synchronize and clean
self.getMessages(1)
if not solo:
self.getMessages(2)
self.getMessages(1)
self.sendLine(1, "{} #chan :hello everyone".format(command))
m1 = self.getMessage(1)
self.assertMessageEqual(
m1,
command=command,
params=["#chan", "hello everyone"],
fail_msg="Did not echo “{} #chan :hello everyone”: {msg}",
extra_format=(command,),
)
if not solo:
m2 = self.getMessage(2)
self.assertMessageEqual(
m2,
command=command,
params=["#chan", "hello everyone"],
fail_msg="Did not propagate “{} #chan :hello everyone”: "
"after echoing it to the author: {msg}",
extra_format=(command,),
)
self.assertEqual(
m1.params,
m2.params,
fail_msg="Parameters of forwarded and echoed " "messages differ: {} {}",
extra_format=(m1, m2),
)
if server_time:
self.assertIn(
"time",
m1.tags,
fail_msg="Echoed message is missing server time: {}",
extra_format=(m1,),
)
self.assertIn(
"time",
m2.tags,
fail_msg="Forwarded message is missing server time: {}",
extra_format=(m2,),
)
return f
class EchoMessageTestCase(cases.BaseServerTestCase):
@cases.mark_capabilities("labeled-response", "echo-message", "message-tags")
def testDirectMessageEcho(self):
@ -51,92 +139,6 @@ class EchoMessageTestCase(cases.BaseServerTestCase):
delivery.tags["+example-client-tag"], echo.tags["+example-client-tag"]
)
def _testEchoMessage(command, solo, server_time):
@cases.mark_capabilities("echo-message")
def f(self):
"""<http://ircv3.net/specs/extensions/echo-message-3.2.html>"""
self.addClient()
self.sendLine(1, "CAP LS 302")
capabilities = self.getCapLs(1)
if "echo-message" not in capabilities:
raise NotImplementedByController("echo-message")
if server_time and "server-time" not in capabilities:
raise NotImplementedByController("server-time")
# TODO: check also without this
self.sendLine(
1,
"CAP REQ :echo-message{}".format(" server-time" if server_time else ""),
)
self.getRegistrationMessage(1)
# TODO: Remove this one the trailing space issue is fixed in Charybdis
# and Mammon:
# self.assertMessageEqual(m, command='CAP',
# params=['*', 'ACK', 'echo-message'] +
# (['server-time'] if server_time else []),
# fail_msg='Did not ACK advertised capabilities: {msg}')
self.sendLine(1, "USER f * * :foo")
self.sendLine(1, "NICK baz")
self.sendLine(1, "CAP END")
self.skipToWelcome(1)
self.getMessages(1)
self.sendLine(1, "JOIN #chan")
if not solo:
capabilities = ["server-time"] if server_time else None
self.connectClient("qux", capabilities=capabilities)
self.sendLine(2, "JOIN #chan")
# Synchronize and clean
self.getMessages(1)
if not solo:
self.getMessages(2)
self.getMessages(1)
self.sendLine(1, "{} #chan :hello everyone".format(command))
m1 = self.getMessage(1)
self.assertMessageEqual(
m1,
command=command,
params=["#chan", "hello everyone"],
fail_msg="Did not echo “{} #chan :hello everyone”: {msg}",
extra_format=(command,),
)
if not solo:
m2 = self.getMessage(2)
self.assertMessageEqual(
m2,
command=command,
params=["#chan", "hello everyone"],
fail_msg="Did not propagate “{} #chan :hello everyone”: "
"after echoing it to the author: {msg}",
extra_format=(command,),
)
self.assertEqual(
m1.params,
m2.params,
fail_msg="Parameters of forwarded and echoed "
"messages differ: {} {}",
extra_format=(m1, m2),
)
if server_time:
self.assertIn(
"time",
m1.tags,
fail_msg="Echoed message is missing server time: {}",
extra_format=(m1,),
)
self.assertIn(
"time",
m2.tags,
fail_msg="Forwarded message is missing server time: {}",
extra_format=(m2,),
)
return f
testEchoMessagePrivmsgNoServerTime = _testEchoMessage("PRIVMSG", False, False)
testEchoMessagePrivmsgSolo = _testEchoMessage("PRIVMSG", True, True)
testEchoMessagePrivmsg = _testEchoMessage("PRIVMSG", False, True)

View File

@ -23,12 +23,12 @@ LUSERME_REGEX = re.compile(r"^.*( [-0-9]* ).*( [-0-9]* ).*$")
@dataclass
class LusersResult:
GlobalVisible: int = None
GlobalInvisible: int = None
Servers: int = None
Opers: int = None
GlobalVisible: Optional[int] = None
GlobalInvisible: Optional[int] = None
Servers: Optional[int] = None
Opers: Optional[int] = None
Unregistered: Optional[int] = None
Channels: int = None
Channels: Optional[int] = None
LocalTotal: Optional[int] = None
LocalMax: Optional[int] = None
GlobalTotal: Optional[int] = None

View File

@ -34,6 +34,7 @@ class Capabilities(enum.Enum):
MULTILINE = "draft/multiline"
MULTI_PREFIX = "multi-prefix"
SERVER_TIME = "server-time"
STS = "sts"
@classmethod
def from_name(cls, name):

View File

@ -22,6 +22,7 @@ markers =
draft/multiline
multi-prefix
server-time
sts
# isupport tokens
MONITOR