mirror of
https://github.com/progval/irctest.git
synced 2025-04-05 14:59:49 +00:00
Enable mypy, and do the minimal changes to make it pass
This commit is contained in:
@ -14,3 +14,8 @@ repos:
|
||||
rev: 3.8.3
|
||||
hooks:
|
||||
- id: flake8
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v0.812
|
||||
hooks:
|
||||
- id: mypy
|
||||
|
@ -1,5 +1,6 @@
|
||||
import collections
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@enum.unique
|
||||
@ -19,7 +20,9 @@ class Mechanisms(enum.Enum):
|
||||
scram_sha_256 = 3
|
||||
|
||||
|
||||
Authentication = collections.namedtuple(
|
||||
"Authentication", "mechanisms username password ecdsa_key"
|
||||
)
|
||||
Authentication.__new__.__defaults__ = ([Mechanisms.plain], None, None, None)
|
||||
@dataclasses.dataclass
|
||||
class Authentication:
|
||||
mechanisms: Tuple[Mechanisms] = (Mechanisms.plain,)
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
ecdsa_key: Optional[str] = None
|
||||
|
@ -4,6 +4,7 @@ import socket
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Set
|
||||
|
||||
from .runner import NotImplementedByController
|
||||
|
||||
@ -135,6 +136,9 @@ class BaseServerController(_BaseController):
|
||||
_port_wait_interval = 0.1
|
||||
port_open = False
|
||||
|
||||
supports_sts: bool
|
||||
supported_sasl_mechanisms: Set[str]
|
||||
|
||||
def run(self, hostname, port, password, valid_metadata_keys, invalid_metadata_keys):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -3,11 +3,12 @@ import socket
|
||||
import ssl
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Optional, Set
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from . import client_mock, runner
|
||||
from . import basecontrollers, client_mock, runner
|
||||
from .exceptions import ConnectionClosed
|
||||
from .irc_utils import capabilities, message_parser
|
||||
from .irc_utils.junkdrawer import normalizeWhitespace
|
||||
@ -350,10 +351,10 @@ class BaseServerTestCase(_IrcTestCase):
|
||||
"""Basic class for server tests. Handles spawning a server and exchanging
|
||||
messages with it."""
|
||||
|
||||
password = None
|
||||
password: Optional[str] = None
|
||||
ssl = False
|
||||
valid_metadata_keys = frozenset()
|
||||
invalid_metadata_keys = frozenset()
|
||||
valid_metadata_keys: Set[str] = set()
|
||||
invalid_metadata_keys: Set[str] = set()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -536,6 +537,8 @@ class BaseServerTestCase(_IrcTestCase):
|
||||
|
||||
|
||||
class OptionalityHelper:
|
||||
controller: basecontrollers.BaseServerController
|
||||
|
||||
def checkSaslSupport(self):
|
||||
if self.controller.supported_sasl_mechanisms:
|
||||
return
|
||||
@ -546,6 +549,7 @@ class OptionalityHelper:
|
||||
return
|
||||
raise runner.OptionalSaslMechanismNotSupported(mechanism)
|
||||
|
||||
@staticmethod
|
||||
def skipUnlessHasMechanism(mech):
|
||||
def decorator(f):
|
||||
@functools.wraps(f)
|
||||
@ -565,22 +569,6 @@ class OptionalityHelper:
|
||||
|
||||
return newf
|
||||
|
||||
def checkCapabilitySupport(self, cap):
|
||||
if cap in self.controller.supported_capabilities:
|
||||
return
|
||||
raise runner.CapabilityNotSupported(cap)
|
||||
|
||||
def skipUnlessSupportsCapability(cap):
|
||||
def decorator(f):
|
||||
@functools.wraps(f)
|
||||
def newf(self):
|
||||
self.checkCapabilitySupport(cap)
|
||||
return f(self)
|
||||
|
||||
return newf
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def mark_specifications(*specifications, deprecated=False, strict=False):
|
||||
specifications = frozenset(
|
||||
|
@ -1,7 +1,7 @@
|
||||
import socket
|
||||
import ssl
|
||||
|
||||
from irctest import cases, tls
|
||||
from irctest import cases, runner, tls
|
||||
from irctest.exceptions import ConnectionClosed
|
||||
|
||||
BAD_CERT = """
|
||||
@ -146,8 +146,10 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
|
||||
self.insecure_server.close()
|
||||
super().tearDown()
|
||||
|
||||
@cases.OptionalityHelper.skipUnlessSupportsCapability("sts")
|
||||
@cases.mark_capabilities("sts")
|
||||
def testSts(self):
|
||||
if not self.controller.supports_sts:
|
||||
raise runner.CapabilityNotSupported("sts")
|
||||
tls_config = tls.TlsConfig(
|
||||
enable=False, trusted_fingerprints=[GOOD_FINGERPRINT]
|
||||
)
|
||||
@ -191,8 +193,11 @@ class StsTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
|
||||
# server
|
||||
self.acceptClient()
|
||||
|
||||
@cases.OptionalityHelper.skipUnlessSupportsCapability("sts")
|
||||
@cases.mark_capabilities("sts")
|
||||
def testStsInvalidCertificate(self):
|
||||
if not self.controller.supports_sts:
|
||||
raise runner.CapabilityNotSupported("sts")
|
||||
|
||||
# Connect client to insecure server
|
||||
(hostname, port) = self.insecure_server.getsockname()
|
||||
self.controller.run(hostname=hostname, port=port, auth=None)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Set
|
||||
|
||||
from irctest.basecontrollers import (
|
||||
BaseServerController,
|
||||
@ -43,8 +44,8 @@ TEMPLATE_SSL_CONFIG = """
|
||||
class CharybdisController(BaseServerController, DirectoryBasedController):
|
||||
software_name = "Charybdis"
|
||||
binary_name = "charybdis"
|
||||
supported_sasl_mechanisms = set()
|
||||
supported_capabilities = set() # Not exhaustive
|
||||
supported_sasl_mechanisms: Set[str] = set()
|
||||
supports_sts = False
|
||||
|
||||
def create_config(self):
|
||||
super().create_config()
|
||||
|
@ -6,7 +6,6 @@ from irctest.basecontrollers import BaseClientController, NotImplementedByContro
|
||||
class GircController(BaseClientController):
|
||||
software_name = "gIRC"
|
||||
supported_sasl_mechanisms = ["PLAIN"]
|
||||
supported_capabilities = set() # Not exhaustive
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Set
|
||||
|
||||
from irctest.basecontrollers import (
|
||||
BaseServerController,
|
||||
@ -40,8 +41,8 @@ TEMPLATE_SSL_CONFIG = """
|
||||
|
||||
class HybridController(BaseServerController, DirectoryBasedController):
|
||||
software_name = "Hybrid"
|
||||
supported_sasl_mechanisms = set()
|
||||
supported_capabilities = set() # Not exhaustive
|
||||
supports_sts = False
|
||||
supported_sasl_mechanisms: Set[str] = set()
|
||||
|
||||
def create_config(self):
|
||||
super().create_config()
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Set
|
||||
|
||||
from irctest.basecontrollers import (
|
||||
BaseServerController,
|
||||
@ -38,8 +39,8 @@ TEMPLATE_SSL_CONFIG = """
|
||||
|
||||
class InspircdController(BaseServerController, DirectoryBasedController):
|
||||
software_name = "InspIRCd"
|
||||
supported_sasl_mechanisms = set()
|
||||
supported_capabilities = set() # Not exhaustive
|
||||
supported_sasl_mechanisms: Set[str] = set()
|
||||
supports_str = False
|
||||
|
||||
def create_config(self):
|
||||
super().create_config()
|
||||
|
@ -33,7 +33,7 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
|
||||
"SCRAM-SHA-256",
|
||||
"EXTERNAL",
|
||||
}
|
||||
supported_capabilities = set(["sts"]) # Not exhaustive
|
||||
supports_sts = True
|
||||
|
||||
def create_config(self):
|
||||
create_config = super().create_config()
|
||||
|
@ -68,7 +68,6 @@ def make_list(list_):
|
||||
class MammonController(BaseServerController, DirectoryBasedController):
|
||||
software_name = "Mammon"
|
||||
supported_sasl_mechanisms = {"PLAIN", "ECDSA-NIST256P-CHALLENGE"}
|
||||
supported_capabilities = set() # Not exhaustive
|
||||
|
||||
def create_config(self):
|
||||
super().create_config()
|
||||
|
@ -130,9 +130,9 @@ def hash_password(password):
|
||||
|
||||
class OragonoController(BaseServerController, DirectoryBasedController):
|
||||
software_name = "Oragono"
|
||||
supported_sasl_mechanisms = {"PLAIN"}
|
||||
_port_wait_interval = 0.01
|
||||
supported_capabilities = set() # Not exhaustive
|
||||
supported_sasl_mechanisms = {"PLAIN"}
|
||||
supports_sts = True
|
||||
|
||||
def create_config(self):
|
||||
super().create_config()
|
||||
|
@ -22,7 +22,7 @@ auth_password = {password}
|
||||
class SopelController(BaseClientController):
|
||||
software_name = "Sopel"
|
||||
supported_sasl_mechanisms = {"PLAIN"}
|
||||
supported_capabilities = set() # Not exhaustive
|
||||
supports_sts = False
|
||||
|
||||
def __init__(self, test_config):
|
||||
super().__init__(test_config)
|
||||
|
@ -709,8 +709,9 @@ class JoinTestCase(cases.BaseServerTestCase):
|
||||
)
|
||||
|
||||
|
||||
class testChannelCaseSensitivity(cases.BaseServerTestCase):
|
||||
def _testChannelsEquivalent(casemapping, name1, name2):
|
||||
def _testChannelsEquivalent(casemapping, name1, name2):
|
||||
"""Generates test functions"""
|
||||
|
||||
@cases.mark_specifications("RFC1459", "RFC2812", strict=True)
|
||||
def f(self):
|
||||
self.connectClient("foo")
|
||||
@ -732,7 +733,10 @@ class testChannelCaseSensitivity(cases.BaseServerTestCase):
|
||||
f.__name__ = "testEquivalence__{}__{}".format(name1, name2)
|
||||
return f
|
||||
|
||||
def _testChannelsNotEquivalent(casemapping, name1, name2):
|
||||
|
||||
def _testChannelsNotEquivalent(casemapping, name1, name2):
|
||||
"""Generates test functions"""
|
||||
|
||||
@cases.mark_specifications("RFC1459", "RFC2812", strict=True)
|
||||
def f(self):
|
||||
self.connectClient("foo")
|
||||
@ -758,6 +762,8 @@ class testChannelCaseSensitivity(cases.BaseServerTestCase):
|
||||
f.__name__ = "testEquivalence__{}__{}".format(name1, name2)
|
||||
return f
|
||||
|
||||
|
||||
class testChannelCaseSensitivity(cases.BaseServerTestCase):
|
||||
testAsciiSimpleEquivalent = _testChannelsEquivalent("ascii", "#Foo", "#foo")
|
||||
testAsciiSimpleNotEquivalent = _testChannelsNotEquivalent("ascii", "#Foo", "#fooa")
|
||||
|
||||
|
@ -7,51 +7,9 @@ from irctest.basecontrollers import NotImplementedByController
|
||||
from irctest.irc_utils.junkdrawer import random_name
|
||||
|
||||
|
||||
class EchoMessageTestCase(cases.BaseServerTestCase):
|
||||
@cases.mark_capabilities("labeled-response", "echo-message", "message-tags")
|
||||
def testDirectMessageEcho(self):
|
||||
bar = random_name("bar")
|
||||
self.connectClient(
|
||||
bar,
|
||||
name=bar,
|
||||
capabilities=["labeled-response", "echo-message", "message-tags"],
|
||||
skip_if_cap_nak=True,
|
||||
)
|
||||
self.getMessages(bar)
|
||||
def _testEchoMessage(command, solo, server_time):
|
||||
"""Generates test functions"""
|
||||
|
||||
qux = random_name("qux")
|
||||
self.connectClient(
|
||||
qux,
|
||||
name=qux,
|
||||
capabilities=["labeled-response", "echo-message", "message-tags"],
|
||||
)
|
||||
self.getMessages(qux)
|
||||
|
||||
self.sendLine(
|
||||
bar,
|
||||
"@label=xyz;+example-client-tag=example-value PRIVMSG %s :hi there"
|
||||
% (qux,),
|
||||
)
|
||||
echo = self.getMessages(bar)[0]
|
||||
delivery = self.getMessages(qux)[0]
|
||||
|
||||
self.assertEqual(delivery.params, [qux, "hi there"])
|
||||
self.assertEqual(delivery.params, echo.params)
|
||||
|
||||
# Either both messages have a msgid, or neither does
|
||||
self.assertEqual(delivery.tags.get("msgid"), echo.tags.get("msgid"))
|
||||
|
||||
self.assertEqual(
|
||||
echo.tags.get("label"),
|
||||
"xyz",
|
||||
fail_msg="expected message label 'xyz', but got {got!r}",
|
||||
)
|
||||
self.assertEqual(delivery.tags["+example-client-tag"], "example-value")
|
||||
self.assertEqual(
|
||||
delivery.tags["+example-client-tag"], echo.tags["+example-client-tag"]
|
||||
)
|
||||
|
||||
def _testEchoMessage(command, solo, server_time):
|
||||
@cases.mark_capabilities("echo-message")
|
||||
def f(self):
|
||||
"""<http://ircv3.net/specs/extensions/echo-message-3.2.html>"""
|
||||
@ -117,8 +75,7 @@ class EchoMessageTestCase(cases.BaseServerTestCase):
|
||||
self.assertEqual(
|
||||
m1.params,
|
||||
m2.params,
|
||||
fail_msg="Parameters of forwarded and echoed "
|
||||
"messages differ: {} {}",
|
||||
fail_msg="Parameters of forwarded and echoed " "messages differ: {} {}",
|
||||
extra_format=(m1, m2),
|
||||
)
|
||||
if server_time:
|
||||
@ -137,6 +94,51 @@ class EchoMessageTestCase(cases.BaseServerTestCase):
|
||||
|
||||
return f
|
||||
|
||||
|
||||
class EchoMessageTestCase(cases.BaseServerTestCase):
|
||||
@cases.mark_capabilities("labeled-response", "echo-message", "message-tags")
|
||||
def testDirectMessageEcho(self):
|
||||
bar = random_name("bar")
|
||||
self.connectClient(
|
||||
bar,
|
||||
name=bar,
|
||||
capabilities=["labeled-response", "echo-message", "message-tags"],
|
||||
skip_if_cap_nak=True,
|
||||
)
|
||||
self.getMessages(bar)
|
||||
|
||||
qux = random_name("qux")
|
||||
self.connectClient(
|
||||
qux,
|
||||
name=qux,
|
||||
capabilities=["labeled-response", "echo-message", "message-tags"],
|
||||
)
|
||||
self.getMessages(qux)
|
||||
|
||||
self.sendLine(
|
||||
bar,
|
||||
"@label=xyz;+example-client-tag=example-value PRIVMSG %s :hi there"
|
||||
% (qux,),
|
||||
)
|
||||
echo = self.getMessages(bar)[0]
|
||||
delivery = self.getMessages(qux)[0]
|
||||
|
||||
self.assertEqual(delivery.params, [qux, "hi there"])
|
||||
self.assertEqual(delivery.params, echo.params)
|
||||
|
||||
# Either both messages have a msgid, or neither does
|
||||
self.assertEqual(delivery.tags.get("msgid"), echo.tags.get("msgid"))
|
||||
|
||||
self.assertEqual(
|
||||
echo.tags.get("label"),
|
||||
"xyz",
|
||||
fail_msg="expected message label 'xyz', but got {got!r}",
|
||||
)
|
||||
self.assertEqual(delivery.tags["+example-client-tag"], "example-value")
|
||||
self.assertEqual(
|
||||
delivery.tags["+example-client-tag"], echo.tags["+example-client-tag"]
|
||||
)
|
||||
|
||||
testEchoMessagePrivmsgNoServerTime = _testEchoMessage("PRIVMSG", False, False)
|
||||
testEchoMessagePrivmsgSolo = _testEchoMessage("PRIVMSG", True, True)
|
||||
testEchoMessagePrivmsg = _testEchoMessage("PRIVMSG", False, True)
|
||||
|
@ -23,12 +23,12 @@ LUSERME_REGEX = re.compile(r"^.*( [-0-9]* ).*( [-0-9]* ).*$")
|
||||
|
||||
@dataclass
|
||||
class LusersResult:
|
||||
GlobalVisible: int = None
|
||||
GlobalInvisible: int = None
|
||||
Servers: int = None
|
||||
Opers: int = None
|
||||
GlobalVisible: Optional[int] = None
|
||||
GlobalInvisible: Optional[int] = None
|
||||
Servers: Optional[int] = None
|
||||
Opers: Optional[int] = None
|
||||
Unregistered: Optional[int] = None
|
||||
Channels: int = None
|
||||
Channels: Optional[int] = None
|
||||
LocalTotal: Optional[int] = None
|
||||
LocalMax: Optional[int] = None
|
||||
GlobalTotal: Optional[int] = None
|
||||
|
@ -34,6 +34,7 @@ class Capabilities(enum.Enum):
|
||||
MULTILINE = "draft/multiline"
|
||||
MULTI_PREFIX = "multi-prefix"
|
||||
SERVER_TIME = "server-time"
|
||||
STS = "sts"
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name):
|
||||
|
@ -22,6 +22,7 @@ markers =
|
||||
draft/multiline
|
||||
multi-prefix
|
||||
server-time
|
||||
sts
|
||||
|
||||
# isupport tokens
|
||||
MONITOR
|
||||
|
Reference in New Issue
Block a user