Enable mypy, and do the minimal changes to make it pass

This commit is contained in:
Valentin Lorentz 2021-02-28 12:23:06 +01:00 committed by Valentin Lorentz
parent 1c1b8214a0
commit 12da7e1e3b
18 changed files with 195 additions and 179 deletions

View File

@ -14,3 +14,8 @@ repos:
rev: 3.8.3
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
hooks:
- id: mypy

View File

@ -1,5 +1,6 @@
import collections
import dataclasses
import enum
from typing import Optional, Tuple
@enum.unique
@ -19,7 +20,9 @@ class Mechanisms(enum.Enum):
scram_sha_256 = 3
Authentication = collections.namedtuple(
"Authentication", "mechanisms username password ecdsa_key"
)
Authentication.__new__.__defaults__ = ([Mechanisms.plain], None, None, None)
@dataclasses.dataclass
class Authentication:
mechanisms: Tuple[Mechanisms] = (Mechanisms.plain,)
username: Optional[str] = None
password: Optional[str] = None
ecdsa_key: Optional[str] = None

View File

@ -4,6 +4,7 @@ import socket
import subprocess
import tempfile
import time
from typing import Set
from .runner import NotImplementedByController
@ -135,6 +136,9 @@ class BaseServerController(_BaseController):
_port_wait_interval = 0.1
port_open = False
supports_sts: bool
supported_sasl_mechanisms: Set[str]
def run(self, hostname, port, password, valid_metadata_keys, invalid_metadata_keys):
raise NotImplementedError()

View File

@ -3,11 +3,12 @@ import socket
import ssl
import tempfile
import time
from typing import Optional, Set
import unittest
import pytest
from . import client_mock, runner
from . import basecontrollers, client_mock, runner
from .exceptions import ConnectionClosed
from .irc_utils import capabilities, message_parser
from .irc_utils.junkdrawer import normalizeWhitespace
@ -350,10 +351,10 @@ class BaseServerTestCase(_IrcTestCase):
"""Basic class for server tests. Handles spawning a server and exchanging
messages with it."""
password = None
password: Optional[str] = None
ssl = False
valid_metadata_keys = frozenset()
invalid_metadata_keys = frozenset()
valid_metadata_keys: Set[str] = set()
invalid_metadata_keys: Set[str] = set()
def setUp(self):
super().setUp()
@ -536,6 +537,8 @@ class BaseServerTestCase(_IrcTestCase):
class OptionalityHelper:
controller: basecontrollers.BaseServerController
def checkSaslSupport(self):
if self.controller.supported_sasl_mechanisms:
return
@ -546,6 +549,7 @@ class OptionalityHelper:
return
raise runner.OptionalSaslMechanismNotSupported(mechanism)
@staticmethod
def skipUnlessHasMechanism(mech):
def decorator(f):
@functools.wraps(f)
@ -565,22 +569,6 @@ class OptionalityHelper:
return newf
def checkCapabilitySupport(self, cap):
if cap in self.controller.supported_capabilities:
return
raise runner.CapabilityNotSupported(cap)
def skipUnlessSupportsCapability(cap):
def decorator(f):
@functools.wraps(f)
def newf(self):
self.checkCapabilitySupport(cap)
return f(self)
return newf
return decorator
def mark_specifications(*specifications, deprecated=False, strict=False):
specifications = frozenset(

View File

@ -1,7 +1,7 @@
import socket
import ssl
from irctest import cases, tls
from irctest import cases, runner, tls
from irctest.exceptions import ConnectionClosed
BAD_CERT = """
@ -146,8 +146,10 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
self.insecure_server.close()
super().tearDown()
@cases.OptionalityHelper.skipUnlessSupportsCapability("sts")
@cases.mark_capabilities("sts")
def testSts(self):
if not self.controller.supports_sts:
raise runner.CapabilityNotSupported("sts")
tls_config = tls.TlsConfig(
enable=False, trusted_fingerprints=[GOOD_FINGERPRINT]
)
@ -191,8 +193,11 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
# server
self.acceptClient()
@cases.OptionalityHelper.skipUnlessSupportsCapability("sts")
@cases.mark_capabilities("sts")
def testStsInvalidCertificate(self):
if not self.controller.supports_sts:
raise runner.CapabilityNotSupported("sts")
# Connect client to insecure server
(hostname, port) = self.insecure_server.getsockname()
self.controller.run(hostname=hostname, port=port, auth=None)

View File

@ -1,5 +1,6 @@
import os
import subprocess
from typing import Set
from irctest.basecontrollers import (
BaseServerController,
@ -43,8 +44,8 @@ TEMPLATE_SSL_CONFIG = """
class CharybdisController(BaseServerController, DirectoryBasedController):
software_name = "Charybdis"
binary_name = "charybdis"
supported_sasl_mechanisms = set()
supported_capabilities = set() # Not exhaustive
supported_sasl_mechanisms: Set[str] = set()
supports_sts = False
def create_config(self):
super().create_config()

View File

@ -6,7 +6,6 @@ from irctest.basecontrollers import BaseClientController, NotImplementedByContro
class GircController(BaseClientController):
software_name = "gIRC"
supported_sasl_mechanisms = ["PLAIN"]
supported_capabilities = set() # Not exhaustive
def __init__(self):
super().__init__()

View File

@ -1,5 +1,6 @@
import os
import subprocess
from typing import Set
from irctest.basecontrollers import (
BaseServerController,
@ -40,8 +41,8 @@ TEMPLATE_SSL_CONFIG = """
class HybridController(BaseServerController, DirectoryBasedController):
software_name = "Hybrid"
supported_sasl_mechanisms = set()
supported_capabilities = set() # Not exhaustive
supports_sts = False
supported_sasl_mechanisms: Set[str] = set()
def create_config(self):
super().create_config()

View File

@ -1,5 +1,6 @@
import os
import subprocess
from typing import Set
from irctest.basecontrollers import (
BaseServerController,
@ -38,8 +39,8 @@ TEMPLATE_SSL_CONFIG = """
class InspircdController(BaseServerController, DirectoryBasedController):
software_name = "InspIRCd"
supported_sasl_mechanisms = set()
supported_capabilities = set() # Not exhaustive
supported_sasl_mechanisms: Set[str] = set()
supports_str = False
def create_config(self):
super().create_config()

View File

@ -33,7 +33,7 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
"SCRAM-SHA-256",
"EXTERNAL",
}
supported_capabilities = set(["sts"]) # Not exhaustive
supports_sts = True
def create_config(self):
create_config = super().create_config()

View File

@ -68,7 +68,6 @@ def make_list(list_):
class MammonController(BaseServerController, DirectoryBasedController):
software_name = "Mammon"
supported_sasl_mechanisms = {"PLAIN", "ECDSA-NIST256P-CHALLENGE"}
supported_capabilities = set() # Not exhaustive
def create_config(self):
super().create_config()

View File

@ -130,9 +130,9 @@ def hash_password(password):
class OragonoController(BaseServerController, DirectoryBasedController):
software_name = "Oragono"
supported_sasl_mechanisms = {"PLAIN"}
_port_wait_interval = 0.01
supported_capabilities = set() # Not exhaustive
supported_sasl_mechanisms = {"PLAIN"}
supports_sts = True
def create_config(self):
super().create_config()

View File

@ -22,7 +22,7 @@ auth_password = {password}
class SopelController(BaseClientController):
software_name = "Sopel"
supported_sasl_mechanisms = {"PLAIN"}
supported_capabilities = set() # Not exhaustive
supports_sts = False
def __init__(self, test_config):
super().__init__(test_config)

View File

@ -709,8 +709,9 @@ class JoinTestCase(cases.BaseServerTestCase):
)
class testChannelCaseSensitivity(cases.BaseServerTestCase):
def _testChannelsEquivalent(casemapping, name1, name2):
def _testChannelsEquivalent(casemapping, name1, name2):
"""Generates test functions"""
@cases.mark_specifications("RFC1459", "RFC2812", strict=True)
def f(self):
self.connectClient("foo")
@ -732,7 +733,10 @@ class testChannelCaseSensitivity(cases.BaseServerTestCase):
f.__name__ = "testEquivalence__{}__{}".format(name1, name2)
return f
def _testChannelsNotEquivalent(casemapping, name1, name2):
def _testChannelsNotEquivalent(casemapping, name1, name2):
"""Generates test functions"""
@cases.mark_specifications("RFC1459", "RFC2812", strict=True)
def f(self):
self.connectClient("foo")
@ -758,6 +762,8 @@ class testChannelCaseSensitivity(cases.BaseServerTestCase):
f.__name__ = "testEquivalence__{}__{}".format(name1, name2)
return f
class testChannelCaseSensitivity(cases.BaseServerTestCase):
testAsciiSimpleEquivalent = _testChannelsEquivalent("ascii", "#Foo", "#foo")
testAsciiSimpleNotEquivalent = _testChannelsNotEquivalent("ascii", "#Foo", "#fooa")

View File

@ -7,51 +7,9 @@ from irctest.basecontrollers import NotImplementedByController
from irctest.irc_utils.junkdrawer import random_name
class EchoMessageTestCase(cases.BaseServerTestCase):
@cases.mark_capabilities("labeled-response", "echo-message", "message-tags")
def testDirectMessageEcho(self):
bar = random_name("bar")
self.connectClient(
bar,
name=bar,
capabilities=["labeled-response", "echo-message", "message-tags"],
skip_if_cap_nak=True,
)
self.getMessages(bar)
def _testEchoMessage(command, solo, server_time):
"""Generates test functions"""
qux = random_name("qux")
self.connectClient(
qux,
name=qux,
capabilities=["labeled-response", "echo-message", "message-tags"],
)
self.getMessages(qux)
self.sendLine(
bar,
"@label=xyz;+example-client-tag=example-value PRIVMSG %s :hi there"
% (qux,),
)
echo = self.getMessages(bar)[0]
delivery = self.getMessages(qux)[0]
self.assertEqual(delivery.params, [qux, "hi there"])
self.assertEqual(delivery.params, echo.params)
# Either both messages have a msgid, or neither does
self.assertEqual(delivery.tags.get("msgid"), echo.tags.get("msgid"))
self.assertEqual(
echo.tags.get("label"),
"xyz",
fail_msg="expected message label 'xyz', but got {got!r}",
)
self.assertEqual(delivery.tags["+example-client-tag"], "example-value")
self.assertEqual(
delivery.tags["+example-client-tag"], echo.tags["+example-client-tag"]
)
def _testEchoMessage(command, solo, server_time):
@cases.mark_capabilities("echo-message")
def f(self):
"""<http://ircv3.net/specs/extensions/echo-message-3.2.html>"""
@ -117,8 +75,7 @@ class EchoMessageTestCase(cases.BaseServerTestCase):
self.assertEqual(
m1.params,
m2.params,
fail_msg="Parameters of forwarded and echoed "
"messages differ: {} {}",
fail_msg="Parameters of forwarded and echoed " "messages differ: {} {}",
extra_format=(m1, m2),
)
if server_time:
@ -137,6 +94,51 @@ class EchoMessageTestCase(cases.BaseServerTestCase):
return f
class EchoMessageTestCase(cases.BaseServerTestCase):
@cases.mark_capabilities("labeled-response", "echo-message", "message-tags")
def testDirectMessageEcho(self):
bar = random_name("bar")
self.connectClient(
bar,
name=bar,
capabilities=["labeled-response", "echo-message", "message-tags"],
skip_if_cap_nak=True,
)
self.getMessages(bar)
qux = random_name("qux")
self.connectClient(
qux,
name=qux,
capabilities=["labeled-response", "echo-message", "message-tags"],
)
self.getMessages(qux)
self.sendLine(
bar,
"@label=xyz;+example-client-tag=example-value PRIVMSG %s :hi there"
% (qux,),
)
echo = self.getMessages(bar)[0]
delivery = self.getMessages(qux)[0]
self.assertEqual(delivery.params, [qux, "hi there"])
self.assertEqual(delivery.params, echo.params)
# Either both messages have a msgid, or neither does
self.assertEqual(delivery.tags.get("msgid"), echo.tags.get("msgid"))
self.assertEqual(
echo.tags.get("label"),
"xyz",
fail_msg="expected message label 'xyz', but got {got!r}",
)
self.assertEqual(delivery.tags["+example-client-tag"], "example-value")
self.assertEqual(
delivery.tags["+example-client-tag"], echo.tags["+example-client-tag"]
)
testEchoMessagePrivmsgNoServerTime = _testEchoMessage("PRIVMSG", False, False)
testEchoMessagePrivmsgSolo = _testEchoMessage("PRIVMSG", True, True)
testEchoMessagePrivmsg = _testEchoMessage("PRIVMSG", False, True)

View File

@ -23,12 +23,12 @@ LUSERME_REGEX = re.compile(r"^.*( [-0-9]* ).*( [-0-9]* ).*$")
@dataclass
class LusersResult:
GlobalVisible: int = None
GlobalInvisible: int = None
Servers: int = None
Opers: int = None
GlobalVisible: Optional[int] = None
GlobalInvisible: Optional[int] = None
Servers: Optional[int] = None
Opers: Optional[int] = None
Unregistered: Optional[int] = None
Channels: int = None
Channels: Optional[int] = None
LocalTotal: Optional[int] = None
LocalMax: Optional[int] = None
GlobalTotal: Optional[int] = None

View File

@ -34,6 +34,7 @@ class Capabilities(enum.Enum):
MULTILINE = "draft/multiline"
MULTI_PREFIX = "multi-prefix"
SERVER_TIME = "server-time"
STS = "sts"
@classmethod
def from_name(cls, name):

View File

@ -22,6 +22,7 @@ markers =
draft/multiline
multi-prefix
server-time
sts
# isupport tokens
MONITOR