mirror of
https://github.com/progval/irctest.git
synced 2025-04-05 06:49:47 +00:00
type-annotate all functions outside the tests themselves.
This commit is contained in:
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
@ -5,8 +7,11 @@ import socket
|
|||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Dict, Optional, Set
|
from typing import IO, Any, Callable, Dict, Optional, Set
|
||||||
|
|
||||||
|
import irctest
|
||||||
|
|
||||||
|
from . import authentication, tls
|
||||||
from .runner import NotImplementedByController
|
from .runner import NotImplementedByController
|
||||||
|
|
||||||
|
|
||||||
@ -41,27 +46,27 @@ class _BaseController:
|
|||||||
a process (eg. a server or a client), as well as sending it instructions
|
a process (eg. a server or a client), as well as sending it instructions
|
||||||
that are not part of the IRC specification."""
|
that are not part of the IRC specification."""
|
||||||
|
|
||||||
|
# set by conftest.py
|
||||||
|
openssl_bin: str
|
||||||
|
|
||||||
|
supports_sts: bool
|
||||||
|
supported_sasl_mechanisms: Set[str]
|
||||||
|
proc: Optional[subprocess.Popen]
|
||||||
|
|
||||||
def __init__(self, test_config: TestCaseControllerConfig):
|
def __init__(self, test_config: TestCaseControllerConfig):
|
||||||
self.test_config = test_config
|
self.test_config = test_config
|
||||||
self.proc = None
|
self.proc = None
|
||||||
|
|
||||||
def check_is_alive(self):
|
def check_is_alive(self) -> None:
|
||||||
|
assert self.proc
|
||||||
self.proc.poll()
|
self.proc.poll()
|
||||||
if self.proc.returncode is not None:
|
if self.proc.returncode is not None:
|
||||||
raise ProcessStopped()
|
raise ProcessStopped()
|
||||||
|
|
||||||
|
def kill_proc(self) -> None:
|
||||||
class DirectoryBasedController(_BaseController):
|
|
||||||
"""Helper for controllers whose software configuration is based on an
|
|
||||||
arbitrary directory."""
|
|
||||||
|
|
||||||
def __init__(self, test_config: TestCaseControllerConfig):
|
|
||||||
super().__init__(test_config)
|
|
||||||
self.directory = None
|
|
||||||
|
|
||||||
def kill_proc(self):
|
|
||||||
"""Terminates the controlled process, waits for it to exit, and
|
"""Terminates the controlled process, waits for it to exit, and
|
||||||
eventually kills it."""
|
eventually kills it."""
|
||||||
|
assert self.proc
|
||||||
self.proc.terminate()
|
self.proc.terminate()
|
||||||
try:
|
try:
|
||||||
self.proc.wait(5)
|
self.proc.wait(5)
|
||||||
@ -69,20 +74,36 @@ class DirectoryBasedController(_BaseController):
|
|||||||
self.proc.kill()
|
self.proc.kill()
|
||||||
self.proc = None
|
self.proc = None
|
||||||
|
|
||||||
def kill(self):
|
def kill(self) -> None:
|
||||||
"""Calls `kill_proc` and cleans the configuration."""
|
"""Calls `kill_proc` and cleans the configuration."""
|
||||||
if self.proc:
|
if self.proc:
|
||||||
self.kill_proc()
|
self.kill_proc()
|
||||||
|
|
||||||
|
|
||||||
|
class DirectoryBasedController(_BaseController):
|
||||||
|
"""Helper for controllers whose software configuration is based on an
|
||||||
|
arbitrary directory."""
|
||||||
|
|
||||||
|
directory: Optional[str]
|
||||||
|
|
||||||
|
def __init__(self, test_config: TestCaseControllerConfig):
|
||||||
|
super().__init__(test_config)
|
||||||
|
self.directory = None
|
||||||
|
|
||||||
|
def kill(self) -> None:
|
||||||
|
"""Calls `kill_proc` and cleans the configuration."""
|
||||||
|
super().kill()
|
||||||
if self.directory:
|
if self.directory:
|
||||||
shutil.rmtree(self.directory)
|
shutil.rmtree(self.directory)
|
||||||
|
|
||||||
def terminate(self):
|
def terminate(self) -> None:
|
||||||
"""Stops the process gracefully, and does not clean its config."""
|
"""Stops the process gracefully, and does not clean its config."""
|
||||||
|
assert self.proc
|
||||||
self.proc.terminate()
|
self.proc.terminate()
|
||||||
self.proc.wait()
|
self.proc.wait()
|
||||||
self.proc = None
|
self.proc = None
|
||||||
|
|
||||||
def open_file(self, name, mode="a"):
|
def open_file(self, name: str, mode: str = "a") -> IO:
|
||||||
"""Open a file in the configuration directory."""
|
"""Open a file in the configuration directory."""
|
||||||
assert self.directory
|
assert self.directory
|
||||||
if os.sep in name:
|
if os.sep in name:
|
||||||
@ -92,16 +113,12 @@ class DirectoryBasedController(_BaseController):
|
|||||||
assert os.path.isdir(dir_)
|
assert os.path.isdir(dir_)
|
||||||
return open(os.path.join(self.directory, name), mode)
|
return open(os.path.join(self.directory, name), mode)
|
||||||
|
|
||||||
def create_config(self):
|
def create_config(self) -> None:
|
||||||
"""If there is no config dir, creates it and returns True.
|
if not self.directory:
|
||||||
Else returns False."""
|
|
||||||
if self.directory:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
self.directory = tempfile.mkdtemp()
|
self.directory = tempfile.mkdtemp()
|
||||||
return True
|
|
||||||
|
|
||||||
def gen_ssl(self):
|
def gen_ssl(self) -> None:
|
||||||
|
assert self.directory
|
||||||
self.csr_path = os.path.join(self.directory, "ssl.csr")
|
self.csr_path = os.path.join(self.directory, "ssl.csr")
|
||||||
self.key_path = os.path.join(self.directory, "ssl.key")
|
self.key_path = os.path.join(self.directory, "ssl.key")
|
||||||
self.pem_path = os.path.join(self.directory, "ssl.pem")
|
self.pem_path = os.path.join(self.directory, "ssl.pem")
|
||||||
@ -145,7 +162,13 @@ class DirectoryBasedController(_BaseController):
|
|||||||
class BaseClientController(_BaseController):
|
class BaseClientController(_BaseController):
|
||||||
"""Base controller for IRC clients."""
|
"""Base controller for IRC clients."""
|
||||||
|
|
||||||
def run(self, hostname, port, auth):
|
def run(
|
||||||
|
self,
|
||||||
|
hostname: str,
|
||||||
|
port: int,
|
||||||
|
auth: Optional[authentication.Authentication],
|
||||||
|
tls_config: Optional[tls.TlsConfig] = None,
|
||||||
|
) -> None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
@ -154,17 +177,29 @@ class BaseServerController(_BaseController):
|
|||||||
|
|
||||||
_port_wait_interval = 0.1
|
_port_wait_interval = 0.1
|
||||||
port_open = False
|
port_open = False
|
||||||
|
port: int
|
||||||
|
|
||||||
supports_sts: bool
|
def run(
|
||||||
supported_sasl_mechanisms: Set[str]
|
self,
|
||||||
|
hostname: str,
|
||||||
def run(self, hostname, port, password, valid_metadata_keys, invalid_metadata_keys):
|
port: int,
|
||||||
|
*,
|
||||||
|
password: Optional[str],
|
||||||
|
ssl: bool,
|
||||||
|
valid_metadata_keys: Optional[Set[str]],
|
||||||
|
invalid_metadata_keys: Optional[Set[str]],
|
||||||
|
) -> None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def registerUser(self, case, username, password=None):
|
def registerUser(
|
||||||
|
self,
|
||||||
|
case: irctest.cases.BaseServerTestCase, # type: ignore
|
||||||
|
username: str,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
raise NotImplementedByController("account registration")
|
raise NotImplementedByController("account registration")
|
||||||
|
|
||||||
def wait_for_port(self):
|
def wait_for_port(self) -> None:
|
||||||
while not self.port_open:
|
while not self.port_open:
|
||||||
self.check_is_alive()
|
self.check_is_alive()
|
||||||
time.sleep(self._port_wait_interval)
|
time.sleep(self._port_wait_interval)
|
||||||
|
353
irctest/cases.py
353
irctest/cases.py
@ -3,16 +3,34 @@ import socket
|
|||||||
import ssl
|
import ssl
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Set
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Container,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
Hashable,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from . import basecontrollers, client_mock, runner
|
from . import basecontrollers, client_mock, runner, tls
|
||||||
|
from .authentication import Authentication
|
||||||
from .basecontrollers import TestCaseControllerConfig
|
from .basecontrollers import TestCaseControllerConfig
|
||||||
from .exceptions import ConnectionClosed
|
from .exceptions import ConnectionClosed
|
||||||
from .irc_utils import capabilities, message_parser
|
from .irc_utils import capabilities, message_parser
|
||||||
from .irc_utils.junkdrawer import normalizeWhitespace
|
from .irc_utils.junkdrawer import normalizeWhitespace
|
||||||
|
from .irc_utils.message_parser import Message
|
||||||
from .irc_utils.sasl import sasl_plain_blob
|
from .irc_utils.sasl import sasl_plain_blob
|
||||||
from .numerics import (
|
from .numerics import (
|
||||||
ERR_BADCHANNELKEY,
|
ERR_BADCHANNELKEY,
|
||||||
@ -35,18 +53,33 @@ CHANNEL_JOIN_FAIL_NUMERICS = frozenset(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# typevar for decorators
|
||||||
|
TCallable = TypeVar("TCallable", bound=Callable)
|
||||||
|
|
||||||
|
# typevar for the client name used by tests (usually int or str)
|
||||||
|
TClientName = TypeVar("TClientName", bound=Union[Hashable, int])
|
||||||
|
|
||||||
|
TController = TypeVar("TController", bound=basecontrollers._BaseController)
|
||||||
|
|
||||||
|
# general-purpose typevar
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class ChannelJoinException(Exception):
|
class ChannelJoinException(Exception):
|
||||||
def __init__(self, code, params):
|
def __init__(self, code: str, params: List[str]):
|
||||||
super().__init__(f"Failed to join channel ({code}): {params}")
|
super().__init__(f"Failed to join channel ({code}): {params}")
|
||||||
self.code = code
|
self.code = code
|
||||||
self.params = params
|
self.params = params
|
||||||
|
|
||||||
|
|
||||||
class _IrcTestCase(unittest.TestCase):
|
class _IrcTestCase(unittest.TestCase, Generic[TController]):
|
||||||
"""Base class for test cases."""
|
"""Base class for test cases."""
|
||||||
|
|
||||||
controllerClass = None # Will be set by __main__.py
|
# Will be set by __main__.py
|
||||||
|
controllerClass: Type[TController]
|
||||||
|
show_io: bool
|
||||||
|
|
||||||
|
controller: TController
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def config() -> TestCaseControllerConfig:
|
def config() -> TestCaseControllerConfig:
|
||||||
@ -56,7 +89,7 @@ class _IrcTestCase(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
return TestCaseControllerConfig()
|
return TestCaseControllerConfig()
|
||||||
|
|
||||||
def description(self):
|
def description(self) -> str:
|
||||||
method_doc = self._testMethodDoc
|
method_doc = self._testMethodDoc
|
||||||
if not method_doc:
|
if not method_doc:
|
||||||
return ""
|
return ""
|
||||||
@ -64,14 +97,13 @@ class _IrcTestCase(unittest.TestCase):
|
|||||||
method_doc, removeNewline=False
|
method_doc, removeNewline=False
|
||||||
).strip().replace("\n ", "\n\t")
|
).strip().replace("\n ", "\n\t")
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.controller = self.controllerClass(self.config())
|
self.controller = self.controllerClass(self.config())
|
||||||
self.inbuffer = []
|
|
||||||
if self.show_io:
|
if self.show_io:
|
||||||
print("---- new test ----")
|
print("---- new test ----")
|
||||||
|
|
||||||
def assertMessageEqual(self, msg, **kwargs):
|
def assertMessageEqual(self, msg: Message, **kwargs: Any) -> None:
|
||||||
"""Helper for partially comparing a message.
|
"""Helper for partially comparing a message.
|
||||||
|
|
||||||
Takes the message as first arguments, and comparisons to be made
|
Takes the message as first arguments, and comparisons to be made
|
||||||
@ -83,21 +115,21 @@ class _IrcTestCase(unittest.TestCase):
|
|||||||
if error:
|
if error:
|
||||||
raise self.failureException(error)
|
raise self.failureException(error)
|
||||||
|
|
||||||
def messageEqual(self, msg, **kwargs):
|
def messageEqual(self, msg: Message, **kwargs: Any) -> bool:
|
||||||
"""Boolean negation of `messageDiffers` (returns a boolean,
|
"""Boolean negation of `messageDiffers` (returns a boolean,
|
||||||
not an optional string)."""
|
not an optional string)."""
|
||||||
return not self.messageDiffers(msg, **kwargs)
|
return not self.messageDiffers(msg, **kwargs)
|
||||||
|
|
||||||
def messageDiffers(
|
def messageDiffers(
|
||||||
self,
|
self,
|
||||||
msg,
|
msg: Message,
|
||||||
params=None,
|
params: Optional[List[Any]] = None,
|
||||||
target=None,
|
target: Optional[str] = None,
|
||||||
nick=None,
|
nick: Optional[str] = None,
|
||||||
fail_msg=None,
|
fail_msg: Optional[str] = None,
|
||||||
extra_format=(),
|
extra_format: Tuple = (),
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
):
|
) -> Optional[str]:
|
||||||
"""Returns an error message if the message doesn't match the given arguments,
|
"""Returns an error message if the message doesn't match the given arguments,
|
||||||
or None if it matches."""
|
or None if it matches."""
|
||||||
for (key, value) in kwargs.items():
|
for (key, value) in kwargs.items():
|
||||||
@ -120,7 +152,7 @@ class _IrcTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if nick:
|
if nick:
|
||||||
got_nick = msg.prefix.split("!")[0]
|
got_nick = msg.prefix.split("!")[0] if msg.prefix else None
|
||||||
if msg.prefix is None:
|
if msg.prefix is None:
|
||||||
fail_msg = (
|
fail_msg = (
|
||||||
fail_msg or "expected nick to be {expects}, got {got} prefix: {msg}"
|
fail_msg or "expected nick to be {expects}, got {got} prefix: {msg}"
|
||||||
@ -131,7 +163,7 @@ class _IrcTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def listMatch(self, got, expected):
|
def listMatch(self, got: List[str], expected: List[Any]) -> bool:
|
||||||
"""Returns True iff the list are equal.
|
"""Returns True iff the list are equal.
|
||||||
The ellipsis (aka. "..." aka triple dots) can be used on the 'expected'
|
The ellipsis (aka. "..." aka triple dots) can be used on the 'expected'
|
||||||
side as a wildcard, matching any *single* value."""
|
side as a wildcard, matching any *single* value."""
|
||||||
@ -145,62 +177,124 @@ class _IrcTestCase(unittest.TestCase):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def assertIn(self, item, list_, msg=None, fail_msg=None, extra_format=()):
|
def assertIn(
|
||||||
|
self,
|
||||||
|
member: Any,
|
||||||
|
container: Union[Iterable[Any], Container[Any]],
|
||||||
|
msg: Optional[str] = None,
|
||||||
|
fail_msg: Optional[str] = None,
|
||||||
|
extra_format: Tuple = (),
|
||||||
|
) -> None:
|
||||||
if fail_msg:
|
if fail_msg:
|
||||||
fail_msg = fail_msg.format(*extra_format, item=item, list=list_, msg=msg)
|
fail_msg = fail_msg.format(
|
||||||
super().assertIn(item, list_, fail_msg)
|
*extra_format, item=member, list=container, msg=msg
|
||||||
|
)
|
||||||
|
super().assertIn(member, container, fail_msg)
|
||||||
|
|
||||||
def assertNotIn(self, item, list_, msg=None, fail_msg=None, extra_format=()):
|
def assertNotIn(
|
||||||
|
self,
|
||||||
|
member: Any,
|
||||||
|
container: Union[Iterable[Any], Container[Any]],
|
||||||
|
msg: Optional[str] = None,
|
||||||
|
fail_msg: Optional[str] = None,
|
||||||
|
extra_format: Tuple = (),
|
||||||
|
) -> None:
|
||||||
if fail_msg:
|
if fail_msg:
|
||||||
fail_msg = fail_msg.format(*extra_format, item=item, list=list_, msg=msg)
|
fail_msg = fail_msg.format(
|
||||||
super().assertNotIn(item, list_, fail_msg)
|
*extra_format, item=member, list=container, msg=msg
|
||||||
|
)
|
||||||
|
super().assertNotIn(member, container, fail_msg)
|
||||||
|
|
||||||
def assertEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()):
|
def assertEqual(
|
||||||
|
self,
|
||||||
|
got: T,
|
||||||
|
expects: T,
|
||||||
|
msg: Any = None,
|
||||||
|
fail_msg: Optional[str] = None,
|
||||||
|
extra_format: Tuple = (),
|
||||||
|
) -> None:
|
||||||
if fail_msg:
|
if fail_msg:
|
||||||
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
|
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
|
||||||
super().assertEqual(got, expects, fail_msg)
|
super().assertEqual(got, expects, fail_msg)
|
||||||
|
|
||||||
def assertNotEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()):
|
def assertNotEqual(
|
||||||
|
self,
|
||||||
|
got: T,
|
||||||
|
expects: T,
|
||||||
|
msg: Any = None,
|
||||||
|
fail_msg: Optional[str] = None,
|
||||||
|
extra_format: Tuple = (),
|
||||||
|
) -> None:
|
||||||
if fail_msg:
|
if fail_msg:
|
||||||
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
|
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
|
||||||
super().assertNotEqual(got, expects, fail_msg)
|
super().assertNotEqual(got, expects, fail_msg)
|
||||||
|
|
||||||
def assertGreater(self, got, expects, msg=None, fail_msg=None, extra_format=()):
|
def assertGreater(
|
||||||
|
self,
|
||||||
|
got: T,
|
||||||
|
expects: T,
|
||||||
|
msg: Any = None,
|
||||||
|
fail_msg: Optional[str] = None,
|
||||||
|
extra_format: Tuple = (),
|
||||||
|
) -> None:
|
||||||
if fail_msg:
|
if fail_msg:
|
||||||
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
|
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
|
||||||
super().assertGreater(got, expects, fail_msg)
|
super().assertGreater(got, expects, fail_msg)
|
||||||
|
|
||||||
def assertGreaterEqual(
|
def assertGreaterEqual(
|
||||||
self, got, expects, msg=None, fail_msg=None, extra_format=()
|
self,
|
||||||
):
|
got: T,
|
||||||
|
expects: T,
|
||||||
|
msg: Any = None,
|
||||||
|
fail_msg: Optional[str] = None,
|
||||||
|
extra_format: Tuple = (),
|
||||||
|
) -> None:
|
||||||
if fail_msg:
|
if fail_msg:
|
||||||
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
|
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
|
||||||
super().assertGreaterEqual(got, expects, fail_msg)
|
super().assertGreaterEqual(got, expects, fail_msg)
|
||||||
|
|
||||||
def assertLess(self, got, expects, msg=None, fail_msg=None, extra_format=()):
|
def assertLess(
|
||||||
|
self,
|
||||||
|
got: T,
|
||||||
|
expects: T,
|
||||||
|
msg: Any = None,
|
||||||
|
fail_msg: Optional[str] = None,
|
||||||
|
extra_format: Tuple = (),
|
||||||
|
) -> None:
|
||||||
if fail_msg:
|
if fail_msg:
|
||||||
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
|
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
|
||||||
super().assertLess(got, expects, fail_msg)
|
super().assertLess(got, expects, fail_msg)
|
||||||
|
|
||||||
def assertLessEqual(self, got, expects, msg=None, fail_msg=None, extra_format=()):
|
def assertLessEqual(
|
||||||
|
self,
|
||||||
|
got: T,
|
||||||
|
expects: T,
|
||||||
|
msg: Any = None,
|
||||||
|
fail_msg: Optional[str] = None,
|
||||||
|
extra_format: Tuple = (),
|
||||||
|
) -> None:
|
||||||
if fail_msg:
|
if fail_msg:
|
||||||
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
|
fail_msg = fail_msg.format(*extra_format, got=got, expects=expects, msg=msg)
|
||||||
super().assertLessEqual(got, expects, fail_msg)
|
super().assertLessEqual(got, expects, fail_msg)
|
||||||
|
|
||||||
|
|
||||||
class BaseClientTestCase(_IrcTestCase):
|
class BaseClientTestCase(_IrcTestCase[basecontrollers.BaseClientController]):
|
||||||
"""Basic class for client tests. Handles spawning a client and exchanging
|
"""Basic class for client tests. Handles spawning a client and exchanging
|
||||||
messages with it."""
|
messages with it."""
|
||||||
|
|
||||||
nick = None
|
conn: Optional[socket.socket]
|
||||||
user = None
|
nick: Optional[str] = None
|
||||||
|
user: Optional[List[str]] = None
|
||||||
|
server: socket.socket
|
||||||
|
protocol_version = Optional[str]
|
||||||
|
acked_capabilities = Optional[Set[str]]
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.conn = None
|
self.conn = None
|
||||||
self._setUpServer()
|
self._setUpServer()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self) -> None:
|
||||||
if self.conn:
|
if self.conn:
|
||||||
try:
|
try:
|
||||||
self.conn.sendall(b"QUIT :end of test.")
|
self.conn.sendall(b"QUIT :end of test.")
|
||||||
@ -214,7 +308,7 @@ class BaseClientTestCase(_IrcTestCase):
|
|||||||
self.conn.close()
|
self.conn.close()
|
||||||
self.server.close()
|
self.server.close()
|
||||||
|
|
||||||
def _setUpServer(self):
|
def _setUpServer(self) -> None:
|
||||||
"""Creates the server and make it listen."""
|
"""Creates the server and make it listen."""
|
||||||
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
self.server.bind(("", 0)) # Bind any free port
|
self.server.bind(("", 0)) # Bind any free port
|
||||||
@ -223,9 +317,15 @@ class BaseClientTestCase(_IrcTestCase):
|
|||||||
# Used to check if the client is alive from time to time
|
# Used to check if the client is alive from time to time
|
||||||
self.server.settimeout(1)
|
self.server.settimeout(1)
|
||||||
|
|
||||||
def acceptClient(self, tls_cert=None, tls_key=None, server=None):
|
def acceptClient(
|
||||||
|
self,
|
||||||
|
tls_cert: Optional[str] = None,
|
||||||
|
tls_key: Optional[str] = None,
|
||||||
|
server: Optional[socket.socket] = None,
|
||||||
|
) -> None:
|
||||||
"""Make the server accept a client connection. Blocking."""
|
"""Make the server accept a client connection. Blocking."""
|
||||||
server = server or self.server
|
server = server or self.server
|
||||||
|
assert server
|
||||||
# Wait for the client to connect
|
# Wait for the client to connect
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -252,17 +352,17 @@ class BaseClientTestCase(_IrcTestCase):
|
|||||||
self.conn = context.wrap_socket(self.conn, server_side=True)
|
self.conn = context.wrap_socket(self.conn, server_side=True)
|
||||||
self.conn_file = self.conn.makefile(newline="\r\n", encoding="utf8")
|
self.conn_file = self.conn.makefile(newline="\r\n", encoding="utf8")
|
||||||
|
|
||||||
def getLine(self):
|
def getLine(self) -> str:
|
||||||
line = self.conn_file.readline()
|
line = self.conn_file.readline()
|
||||||
if self.show_io:
|
if self.show_io:
|
||||||
print("{:.3f} C: {}".format(time.time(), line.strip()))
|
print("{:.3f} C: {}".format(time.time(), line.strip()))
|
||||||
return line
|
return line
|
||||||
|
|
||||||
def getMessages(self, *args):
|
def getMessage(
|
||||||
lines = self.getLines(*args)
|
self,
|
||||||
return map(message_parser.parse_message, lines)
|
*args: Any,
|
||||||
|
filter_pred: Optional[Callable[[Message], bool]] = None,
|
||||||
def getMessage(self, *args, filter_pred=None):
|
) -> Message:
|
||||||
"""Gets a message and returns it. If a filter predicate is given,
|
"""Gets a message and returns it. If a filter predicate is given,
|
||||||
fetches messages until the predicate returns a False on a message,
|
fetches messages until the predicate returns a False on a message,
|
||||||
and returns this message."""
|
and returns this message."""
|
||||||
@ -274,18 +374,17 @@ class BaseClientTestCase(_IrcTestCase):
|
|||||||
if not filter_pred or filter_pred(msg):
|
if not filter_pred or filter_pred(msg):
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def sendLine(self, line):
|
def sendLine(self, line: str) -> None:
|
||||||
|
assert self.conn
|
||||||
self.conn.sendall(line.encode())
|
self.conn.sendall(line.encode())
|
||||||
if not line.endswith("\r\n"):
|
if not line.endswith("\r\n"):
|
||||||
self.conn.sendall(b"\r\n")
|
self.conn.sendall(b"\r\n")
|
||||||
if self.show_io:
|
if self.show_io:
|
||||||
print("{:.3f} S: {}".format(time.time(), line.strip()))
|
print("{:.3f} S: {}".format(time.time(), line.strip()))
|
||||||
|
|
||||||
|
def readCapLs(
|
||||||
class ClientNegociationHelper:
|
self, auth: Optional[Authentication] = None, tls_config: tls.TlsConfig = None
|
||||||
"""Helper class for tests handling capabilities negociation."""
|
) -> None:
|
||||||
|
|
||||||
def readCapLs(self, auth=None, tls_config=None):
|
|
||||||
(hostname, port) = self.server.getsockname()
|
(hostname, port) = self.server.getsockname()
|
||||||
self.controller.run(
|
self.controller.run(
|
||||||
hostname=hostname, port=port, auth=auth, tls_config=tls_config
|
hostname=hostname, port=port, auth=auth, tls_config=tls_config
|
||||||
@ -302,28 +401,33 @@ class ClientNegociationHelper:
|
|||||||
else:
|
else:
|
||||||
raise AssertionError("Unknown CAP params: {}".format(m.params))
|
raise AssertionError("Unknown CAP params: {}".format(m.params))
|
||||||
|
|
||||||
def userNickPredicate(self, msg):
|
def userNickPredicate(self, msg: Message) -> bool:
|
||||||
"""Predicate to be used with getMessage to handle NICK/USER
|
"""Predicate to be used with getMessage to handle NICK/USER
|
||||||
transparently."""
|
transparently."""
|
||||||
if msg.command == "NICK":
|
if msg.command == "NICK":
|
||||||
self.assertEqual(len(msg.params), 1, msg)
|
self.assertEqual(len(msg.params), 1, msg=msg)
|
||||||
self.nick = msg.params[0]
|
self.nick = msg.params[0]
|
||||||
return False
|
return False
|
||||||
elif msg.command == "USER":
|
elif msg.command == "USER":
|
||||||
self.assertEqual(len(msg.params), 4, msg)
|
self.assertEqual(len(msg.params), 4, msg=msg)
|
||||||
self.user = msg.params
|
self.user = msg.params
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def negotiateCapabilities(self, caps, cap_ls=True, auth=None):
|
def negotiateCapabilities(
|
||||||
|
self,
|
||||||
|
caps: List[str],
|
||||||
|
cap_ls: bool = True,
|
||||||
|
auth: Optional[Authentication] = None,
|
||||||
|
) -> Optional[Message]:
|
||||||
"""Performes a complete capability negociation process, without
|
"""Performes a complete capability negociation process, without
|
||||||
ending it, so the caller can continue the negociation."""
|
ending it, so the caller can continue the negociation."""
|
||||||
if cap_ls:
|
if cap_ls:
|
||||||
self.readCapLs(auth)
|
self.readCapLs(auth)
|
||||||
if not self.protocol_version:
|
if not self.protocol_version:
|
||||||
# No negotiation.
|
# No negotiation.
|
||||||
return
|
return None
|
||||||
self.sendLine("CAP * LS :{}".format(" ".join(caps)))
|
self.sendLine("CAP * LS :{}".format(" ".join(caps)))
|
||||||
capability_names = frozenset(capabilities.cap_list_to_dict(caps))
|
capability_names = frozenset(capabilities.cap_list_to_dict(caps))
|
||||||
self.acked_capabilities = set()
|
self.acked_capabilities = set()
|
||||||
@ -343,21 +447,25 @@ class ClientNegociationHelper:
|
|||||||
self.sendLine(
|
self.sendLine(
|
||||||
"CAP {} ACK :{}".format(self.nick or "*", m.params[1])
|
"CAP {} ACK :{}".format(self.nick or "*", m.params[1])
|
||||||
)
|
)
|
||||||
self.acked_capabilities.update(requested)
|
self.acked_capabilities.update(requested) # type: ignore
|
||||||
else:
|
else:
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
class BaseServerTestCase(_IrcTestCase):
|
class BaseServerTestCase(
|
||||||
|
_IrcTestCase[basecontrollers.BaseServerController], Generic[TClientName]
|
||||||
|
):
|
||||||
"""Basic class for server tests. Handles spawning a server and exchanging
|
"""Basic class for server tests. Handles spawning a server and exchanging
|
||||||
messages with it."""
|
messages with it."""
|
||||||
|
|
||||||
|
show_io: bool # set by conftest.py
|
||||||
|
|
||||||
password: Optional[str] = None
|
password: Optional[str] = None
|
||||||
ssl = False
|
ssl = False
|
||||||
valid_metadata_keys: Set[str] = set()
|
valid_metadata_keys: Set[str] = set()
|
||||||
invalid_metadata_keys: Set[str] = set()
|
invalid_metadata_keys: Set[str] = set()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.server_support = None
|
self.server_support = None
|
||||||
self.find_hostname_and_port()
|
self.find_hostname_and_port()
|
||||||
@ -369,53 +477,64 @@ class BaseServerTestCase(_IrcTestCase):
|
|||||||
invalid_metadata_keys=self.invalid_metadata_keys,
|
invalid_metadata_keys=self.invalid_metadata_keys,
|
||||||
ssl=self.ssl,
|
ssl=self.ssl,
|
||||||
)
|
)
|
||||||
self.clients = {}
|
self.clients: Dict[TClientName, client_mock.ClientMock] = {}
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self) -> None:
|
||||||
self.controller.kill()
|
self.controller.kill()
|
||||||
for client in list(self.clients):
|
for client in list(self.clients):
|
||||||
self.removeClient(client)
|
self.removeClient(client)
|
||||||
|
|
||||||
def find_hostname_and_port(self):
|
def find_hostname_and_port(self) -> None:
|
||||||
"""Find available hostname/port to listen on."""
|
"""Find available hostname/port to listen on."""
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
s.bind(("", 0))
|
s.bind(("", 0))
|
||||||
(self.hostname, self.port) = s.getsockname()
|
(self.hostname, self.port) = s.getsockname()
|
||||||
s.close()
|
s.close()
|
||||||
|
|
||||||
def addClient(self, name=None, show_io=None):
|
def addClient(
|
||||||
|
self, name: Optional[TClientName] = None, show_io: Optional[bool] = None
|
||||||
|
) -> TClientName:
|
||||||
"""Connects a client to the server and adds it to the dict.
|
"""Connects a client to the server and adds it to the dict.
|
||||||
If 'name' is not given, uses the lowest unused non-negative integer."""
|
If 'name' is not given, uses the lowest unused non-negative integer."""
|
||||||
self.controller.wait_for_port()
|
self.controller.wait_for_port()
|
||||||
if not name:
|
if not name:
|
||||||
name = max(map(int, list(self.clients) + [0])) + 1
|
new_name: int = (
|
||||||
|
max(
|
||||||
|
[int(name) for name in self.clients if isinstance(name, (int, str))]
|
||||||
|
+ [0]
|
||||||
|
)
|
||||||
|
+ 1
|
||||||
|
)
|
||||||
|
name = cast(TClientName, new_name)
|
||||||
show_io = show_io if show_io is not None else self.show_io
|
show_io = show_io if show_io is not None else self.show_io
|
||||||
self.clients[name] = client_mock.ClientMock(name=name, show_io=show_io)
|
self.clients[name] = client_mock.ClientMock(name=name, show_io=show_io)
|
||||||
self.clients[name].connect(self.hostname, self.port)
|
self.clients[name].connect(self.hostname, self.port)
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def removeClient(self, name):
|
def removeClient(self, name: TClientName) -> None:
|
||||||
"""Disconnects the client, without QUIT."""
|
"""Disconnects the client, without QUIT."""
|
||||||
assert name in self.clients
|
assert name in self.clients
|
||||||
self.clients[name].disconnect()
|
self.clients[name].disconnect()
|
||||||
del self.clients[name]
|
del self.clients[name]
|
||||||
|
|
||||||
def getMessages(self, client, **kwargs):
|
def getMessages(self, client: TClientName, **kwargs: Any) -> List[Message]:
|
||||||
return self.clients[client].getMessages(**kwargs)
|
return self.clients[client].getMessages(**kwargs)
|
||||||
|
|
||||||
def getMessage(self, client, **kwargs):
|
def getMessage(self, client: TClientName, **kwargs: Any) -> Message:
|
||||||
return self.clients[client].getMessage(**kwargs)
|
return self.clients[client].getMessage(**kwargs)
|
||||||
|
|
||||||
def getRegistrationMessage(self, client):
|
def getRegistrationMessage(self, client: TClientName) -> Message:
|
||||||
"""Filter notices, do not send pings."""
|
"""Filter notices, do not send pings."""
|
||||||
return self.getMessage(
|
return self.getMessage(
|
||||||
client, synchronize=False, filter_pred=lambda m: m.command != "NOTICE"
|
client, synchronize=False, filter_pred=lambda m: m.command != "NOTICE"
|
||||||
)
|
)
|
||||||
|
|
||||||
def sendLine(self, client, line):
|
def sendLine(self, client: TClientName, line: Union[str, bytes]) -> None:
|
||||||
return self.clients[client].sendLine(line)
|
return self.clients[client].sendLine(line)
|
||||||
|
|
||||||
def getCapLs(self, client, as_list=False):
|
def getCapLs(
|
||||||
|
self, client: TClientName, as_list: bool = False
|
||||||
|
) -> Union[List[str], Dict[str, Optional[str]]]:
|
||||||
"""Waits for a CAP LS block, parses all CAP LS messages, and return
|
"""Waits for a CAP LS block, parses all CAP LS messages, and return
|
||||||
the dict capabilities, with their values.
|
the dict capabilities, with their values.
|
||||||
|
|
||||||
@ -431,10 +550,10 @@ class BaseServerTestCase(_IrcTestCase):
|
|||||||
else:
|
else:
|
||||||
caps.extend(m.params[2].split())
|
caps.extend(m.params[2].split())
|
||||||
if not as_list:
|
if not as_list:
|
||||||
caps = capabilities.cap_list_to_dict(caps)
|
return capabilities.cap_list_to_dict(caps)
|
||||||
return caps
|
return caps
|
||||||
|
|
||||||
def assertDisconnected(self, client):
|
def assertDisconnected(self, client: TClientName) -> None:
|
||||||
try:
|
try:
|
||||||
self.getMessages(client)
|
self.getMessages(client)
|
||||||
self.getMessages(client)
|
self.getMessages(client)
|
||||||
@ -444,7 +563,7 @@ class BaseServerTestCase(_IrcTestCase):
|
|||||||
else:
|
else:
|
||||||
raise AssertionError("Client not disconnected.")
|
raise AssertionError("Client not disconnected.")
|
||||||
|
|
||||||
def skipToWelcome(self, client):
|
def skipToWelcome(self, client: TClientName) -> List[Message]:
|
||||||
"""Skip to the point where we are registered
|
"""Skip to the point where we are registered
|
||||||
<https://tools.ietf.org/html/rfc2812#section-3.1>
|
<https://tools.ietf.org/html/rfc2812#section-3.1>
|
||||||
"""
|
"""
|
||||||
@ -457,15 +576,19 @@ class BaseServerTestCase(_IrcTestCase):
|
|||||||
|
|
||||||
def connectClient(
|
def connectClient(
|
||||||
self,
|
self,
|
||||||
nick,
|
nick: str,
|
||||||
name=None,
|
name: TClientName = None,
|
||||||
capabilities=None,
|
capabilities: Optional[List[str]] = None,
|
||||||
skip_if_cap_nak=False,
|
skip_if_cap_nak: bool = False,
|
||||||
show_io=None,
|
show_io: Optional[bool] = None,
|
||||||
account=None,
|
account: Optional[str] = None,
|
||||||
password=None,
|
password: Optional[str] = None,
|
||||||
ident="username",
|
ident: str = "username",
|
||||||
):
|
) -> List[Message]:
|
||||||
|
"""Connections a new client, does the cap negotiation
|
||||||
|
and connection registration, and skips to the end of the MOTD.
|
||||||
|
Returns the list of all messages received after registration,
|
||||||
|
just like `skipToWelcome`."""
|
||||||
client = self.addClient(name, show_io=show_io)
|
client = self.addClient(name, show_io=show_io)
|
||||||
if capabilities is not None and 0 < len(capabilities):
|
if capabilities is not None and 0 < len(capabilities):
|
||||||
self.sendLine(client, "CAP REQ :{}".format(" ".join(capabilities)))
|
self.sendLine(client, "CAP REQ :{}".format(" ".join(capabilities)))
|
||||||
@ -502,14 +625,14 @@ class BaseServerTestCase(_IrcTestCase):
|
|||||||
for param in m.params[1:-1]:
|
for param in m.params[1:-1]:
|
||||||
if "=" in param:
|
if "=" in param:
|
||||||
(key, value) = param.split("=")
|
(key, value) = param.split("=")
|
||||||
|
self.server_support[key] = value
|
||||||
else:
|
else:
|
||||||
(key, value) = (param, None)
|
self.server_support[param] = None
|
||||||
self.server_support[key] = value
|
|
||||||
welcome.append(m)
|
welcome.append(m)
|
||||||
|
|
||||||
return welcome
|
return welcome
|
||||||
|
|
||||||
def joinClient(self, client, channel):
|
def joinClient(self, client: TClientName, channel: str) -> None:
|
||||||
self.sendLine(client, "JOIN {}".format(channel))
|
self.sendLine(client, "JOIN {}".format(channel))
|
||||||
received = {m.command for m in self.getMessages(client)}
|
received = {m.command for m in self.getMessages(client)}
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
@ -520,7 +643,7 @@ class BaseServerTestCase(_IrcTestCase):
|
|||||||
extra_format=(channel,),
|
extra_format=(channel,),
|
||||||
)
|
)
|
||||||
|
|
||||||
def joinChannel(self, client, channel):
|
def joinChannel(self, client: TClientName, channel: str) -> None:
|
||||||
self.sendLine(client, "JOIN {}".format(channel))
|
self.sendLine(client, "JOIN {}".format(channel))
|
||||||
# wait until we see them join the channel
|
# wait until we see them join the channel
|
||||||
joined = False
|
joined = False
|
||||||
@ -537,24 +660,34 @@ class BaseServerTestCase(_IrcTestCase):
|
|||||||
raise ChannelJoinException(msg.command, msg.params)
|
raise ChannelJoinException(msg.command, msg.params)
|
||||||
|
|
||||||
|
|
||||||
class OptionalityHelper:
|
_TSelf = TypeVar("_TSelf", bound="OptionalityHelper")
|
||||||
controller: basecontrollers.BaseServerController
|
_TReturn = TypeVar("_TReturn")
|
||||||
|
|
||||||
def checkSaslSupport(self):
|
|
||||||
|
class OptionalityHelper(Generic[TController]):
|
||||||
|
controller: TController
|
||||||
|
|
||||||
|
def checkSaslSupport(self) -> None:
|
||||||
if self.controller.supported_sasl_mechanisms:
|
if self.controller.supported_sasl_mechanisms:
|
||||||
return
|
return
|
||||||
raise runner.NotImplementedByController("SASL")
|
raise runner.NotImplementedByController("SASL")
|
||||||
|
|
||||||
def checkMechanismSupport(self, mechanism):
|
def checkMechanismSupport(self, mechanism: str) -> None:
|
||||||
if mechanism in self.controller.supported_sasl_mechanisms:
|
if mechanism in self.controller.supported_sasl_mechanisms:
|
||||||
return
|
return
|
||||||
raise runner.OptionalSaslMechanismNotSupported(mechanism)
|
raise runner.OptionalSaslMechanismNotSupported(mechanism)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def skipUnlessHasMechanism(mech):
|
def skipUnlessHasMechanism(
|
||||||
def decorator(f):
|
mech: str,
|
||||||
|
) -> Callable[[Callable[[_TSelf], _TReturn]], Callable[[_TSelf], _TReturn]]:
|
||||||
|
# Just a function returning a function that takes functions and
|
||||||
|
# returns functions, nothing to see here.
|
||||||
|
# If Python didn't have such an awful syntax for callables, it would be:
|
||||||
|
# str -> ((TSelf -> TReturn) -> (TSelf -> TReturn))
|
||||||
|
def decorator(f: Callable[[_TSelf], _TReturn]) -> Callable[[_TSelf], _TReturn]:
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def newf(self):
|
def newf(self: _TSelf) -> _TReturn:
|
||||||
self.checkMechanismSupport(mech)
|
self.checkMechanismSupport(mech)
|
||||||
return f(self)
|
return f(self)
|
||||||
|
|
||||||
@ -562,23 +695,29 @@ class OptionalityHelper:
|
|||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def skipUnlessHasSasl(f):
|
@staticmethod
|
||||||
|
def skipUnlessHasSasl(
|
||||||
|
f: Callable[[_TSelf], _TReturn]
|
||||||
|
) -> Callable[[_TSelf], _TReturn]:
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def newf(self):
|
def newf(self: _TSelf) -> _TReturn:
|
||||||
self.checkSaslSupport()
|
self.checkSaslSupport()
|
||||||
return f(self)
|
return f(self)
|
||||||
|
|
||||||
return newf
|
return newf
|
||||||
|
|
||||||
|
|
||||||
def mark_specifications(*specifications, deprecated=False, strict=False):
|
def mark_specifications(
|
||||||
|
*specifications_str: str, deprecated: bool = False, strict: bool = False
|
||||||
|
) -> Callable[[TCallable], TCallable]:
|
||||||
specifications = frozenset(
|
specifications = frozenset(
|
||||||
Specifications.from_name(s) if isinstance(s, str) else s for s in specifications
|
Specifications.from_name(s) if isinstance(s, str) else s
|
||||||
|
for s in specifications_str
|
||||||
)
|
)
|
||||||
if None in specifications:
|
if None in specifications:
|
||||||
raise ValueError("Invalid set of specifications: {}".format(specifications))
|
raise ValueError("Invalid set of specifications: {}".format(specifications))
|
||||||
|
|
||||||
def decorator(f):
|
def decorator(f: TCallable) -> TCallable:
|
||||||
for specification in specifications:
|
for specification in specifications:
|
||||||
f = getattr(pytest.mark, specification.value)(f)
|
f = getattr(pytest.mark, specification.value)(f)
|
||||||
if strict:
|
if strict:
|
||||||
@ -590,14 +729,16 @@ def mark_specifications(*specifications, deprecated=False, strict=False):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def mark_capabilities(*capabilities, deprecated=False, strict=False):
|
def mark_capabilities(
|
||||||
|
*capabilities_str: str, deprecated: bool = False, strict: bool = False
|
||||||
|
) -> Callable[[TCallable], TCallable]:
|
||||||
capabilities = frozenset(
|
capabilities = frozenset(
|
||||||
Capabilities.from_name(c) if isinstance(c, str) else c for c in capabilities
|
Capabilities.from_name(c) if isinstance(c, str) else c for c in capabilities_str
|
||||||
)
|
)
|
||||||
if None in capabilities:
|
if None in capabilities:
|
||||||
raise ValueError("Invalid set of capabilities: {}".format(capabilities))
|
raise ValueError("Invalid set of capabilities: {}".format(capabilities))
|
||||||
|
|
||||||
def decorator(f):
|
def decorator(f: TCallable) -> TCallable:
|
||||||
for capability in capabilities:
|
for capability in capabilities:
|
||||||
f = getattr(pytest.mark, capability.value)(f)
|
f = getattr(pytest.mark, capability.value)(f)
|
||||||
# Support for any capability implies IRCv3
|
# Support for any capability implies IRCv3
|
||||||
@ -607,14 +748,16 @@ def mark_capabilities(*capabilities, deprecated=False, strict=False):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def mark_isupport(*tokens, deprecated=False, strict=False):
|
def mark_isupport(
|
||||||
|
*tokens_str: str, deprecated: bool = False, strict: bool = False
|
||||||
|
) -> Callable[[TCallable], TCallable]:
|
||||||
tokens = frozenset(
|
tokens = frozenset(
|
||||||
IsupportTokens.from_name(c) if isinstance(c, str) else c for c in tokens
|
IsupportTokens.from_name(c) if isinstance(c, str) else c for c in tokens_str
|
||||||
)
|
)
|
||||||
if None in tokens:
|
if None in tokens:
|
||||||
raise ValueError("Invalid set of isupport tokens: {}".format(tokens))
|
raise ValueError("Invalid set of isupport tokens: {}".format(tokens))
|
||||||
|
|
||||||
def decorator(f):
|
def decorator(f: TCallable) -> TCallable:
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
f = getattr(pytest.mark, token.value)(f)
|
f = getattr(pytest.mark, token.value)(f)
|
||||||
return f
|
return f
|
||||||
|
@ -2,36 +2,41 @@ import socket
|
|||||||
import ssl
|
import ssl
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
from .exceptions import ConnectionClosed, NoMessageException
|
from .exceptions import ConnectionClosed, NoMessageException
|
||||||
from .irc_utils import message_parser
|
from .irc_utils import message_parser
|
||||||
|
|
||||||
|
|
||||||
class ClientMock:
|
class ClientMock:
|
||||||
def __init__(self, name, show_io):
|
def __init__(self, name: Any, show_io: bool):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.show_io = show_io
|
self.show_io = show_io
|
||||||
self.inbuffer = []
|
self.inbuffer: List[message_parser.Message] = []
|
||||||
self.ssl = False
|
self.ssl = False
|
||||||
|
|
||||||
def connect(self, hostname, port):
|
def connect(self, hostname: str, port: int) -> None:
|
||||||
self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
self.conn.settimeout(1) # TODO: configurable
|
self.conn.settimeout(1) # TODO: configurable
|
||||||
self.conn.connect((hostname, port))
|
self.conn.connect((hostname, port))
|
||||||
if self.show_io:
|
if self.show_io:
|
||||||
print("{:.3f} {}: connects to server.".format(time.time(), self.name))
|
print("{:.3f} {}: connects to server.".format(time.time(), self.name))
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self) -> None:
|
||||||
if self.show_io:
|
if self.show_io:
|
||||||
print("{:.3f} {}: disconnects from server.".format(time.time(), self.name))
|
print("{:.3f} {}: disconnects from server.".format(time.time(), self.name))
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
|
|
||||||
def starttls(self):
|
def starttls(self) -> None:
|
||||||
assert not self.ssl, "SSL already active."
|
assert not self.ssl, "SSL already active."
|
||||||
self.conn = ssl.wrap_socket(self.conn)
|
self.conn = ssl.wrap_socket(self.conn)
|
||||||
self.ssl = True
|
self.ssl = True
|
||||||
|
|
||||||
def getMessages(self, synchronize=True, assert_get_one=False, raw=False):
|
def getMessages(
|
||||||
|
self, synchronize: bool = True, assert_get_one: bool = False, raw: bool = False
|
||||||
|
) -> List[message_parser.Message]:
|
||||||
|
"""actually returns List[str] in the rare case where raw=True."""
|
||||||
|
token: Optional[str]
|
||||||
if synchronize:
|
if synchronize:
|
||||||
token = "synchronize{}".format(time.monotonic())
|
token = "synchronize{}".format(time.monotonic())
|
||||||
self.sendLine("PING {}".format(token))
|
self.sendLine("PING {}".format(token))
|
||||||
@ -79,7 +84,7 @@ class ClientMock:
|
|||||||
got_pong = True
|
got_pong = True
|
||||||
else:
|
else:
|
||||||
if raw:
|
if raw:
|
||||||
messages.append(line)
|
messages.append(line) # type: ignore
|
||||||
else:
|
else:
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
data = b""
|
data = b""
|
||||||
@ -91,7 +96,13 @@ class ClientMock:
|
|||||||
else:
|
else:
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def getMessage(self, filter_pred=None, synchronize=True, raw=False):
|
def getMessage(
|
||||||
|
self,
|
||||||
|
filter_pred: Optional[Callable[[message_parser.Message], bool]] = None,
|
||||||
|
synchronize: bool = True,
|
||||||
|
raw: bool = False,
|
||||||
|
) -> message_parser.Message:
|
||||||
|
"""Returns str in the rare case where raw=True"""
|
||||||
while True:
|
while True:
|
||||||
if not self.inbuffer:
|
if not self.inbuffer:
|
||||||
self.inbuffer = self.getMessages(
|
self.inbuffer = self.getMessages(
|
||||||
@ -103,7 +114,7 @@ class ClientMock:
|
|||||||
if not filter_pred or filter_pred(message):
|
if not filter_pred or filter_pred(message):
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def sendLine(self, line):
|
def sendLine(self, line: Union[str, bytes]) -> None:
|
||||||
if isinstance(line, str):
|
if isinstance(line, str):
|
||||||
encoded_line = line.encode()
|
encoded_line = line.encode()
|
||||||
elif isinstance(line, bytes):
|
elif isinstance(line, bytes):
|
||||||
@ -113,7 +124,7 @@ class ClientMock:
|
|||||||
if not encoded_line.endswith(b"\r\n"):
|
if not encoded_line.endswith(b"\r\n"):
|
||||||
encoded_line += b"\r\n"
|
encoded_line += b"\r\n"
|
||||||
try:
|
try:
|
||||||
ret = self.conn.sendall(encoded_line)
|
ret = self.conn.sendall(encoded_line) # type: ignore
|
||||||
except BrokenPipeError:
|
except BrokenPipeError:
|
||||||
raise ConnectionClosed()
|
raise ConnectionClosed()
|
||||||
if (
|
if (
|
||||||
|
@ -2,7 +2,7 @@ from irctest import cases
|
|||||||
from irctest.irc_utils.message_parser import Message
|
from irctest.irc_utils.message_parser import Message
|
||||||
|
|
||||||
|
|
||||||
class CapTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper):
|
class CapTestCase(cases.BaseClientTestCase):
|
||||||
@cases.mark_specifications("IRCv3")
|
@cases.mark_specifications("IRCv3")
|
||||||
def testSendCap(self):
|
def testSendCap(self):
|
||||||
"""Send CAP LS 302 and read the result."""
|
"""Send CAP LS 302 and read the result."""
|
||||||
|
@ -39,9 +39,7 @@ class IdentityHash:
|
|||||||
return self._data
|
return self._data
|
||||||
|
|
||||||
|
|
||||||
class SaslTestCase(
|
class SaslTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
|
||||||
cases.BaseClientTestCase, cases.ClientNegociationHelper, cases.OptionalityHelper
|
|
||||||
):
|
|
||||||
@cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
|
@cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
|
||||||
def testPlain(self):
|
def testPlain(self):
|
||||||
"""Test PLAIN authentication with correct username/password."""
|
"""Test PLAIN authentication with correct username/password."""
|
||||||
@ -263,9 +261,7 @@ class SaslTestCase(
|
|||||||
authenticator.response(msg)
|
authenticator.response(msg)
|
||||||
|
|
||||||
|
|
||||||
class Irc302SaslTestCase(
|
class Irc302SaslTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
|
||||||
cases.BaseClientTestCase, cases.ClientNegociationHelper, cases.OptionalityHelper
|
|
||||||
):
|
|
||||||
@cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
|
@cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
|
||||||
def testPlainNotAvailable(self):
|
def testPlainNotAvailable(self):
|
||||||
"""Test the client does not try to authenticate using a mechanism the
|
"""Test the client does not try to authenticate using a mechanism the
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import Set
|
from typing import Optional, Set, Type
|
||||||
|
|
||||||
from irctest.basecontrollers import (
|
from irctest.basecontrollers import (
|
||||||
BaseServerController,
|
BaseServerController,
|
||||||
@ -47,20 +47,21 @@ class CharybdisController(BaseServerController, DirectoryBasedController):
|
|||||||
supported_sasl_mechanisms: Set[str] = set()
|
supported_sasl_mechanisms: Set[str] = set()
|
||||||
supports_sts = False
|
supports_sts = False
|
||||||
|
|
||||||
def create_config(self):
|
def create_config(self) -> None:
|
||||||
super().create_config()
|
super().create_config()
|
||||||
with self.open_file("server.conf"):
|
with self.open_file("server.conf"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
hostname,
|
hostname: str,
|
||||||
port,
|
port: int,
|
||||||
password=None,
|
*,
|
||||||
ssl=False,
|
password: Optional[str],
|
||||||
valid_metadata_keys=None,
|
ssl: bool,
|
||||||
invalid_metadata_keys=None,
|
valid_metadata_keys: Optional[Set[str]] = None,
|
||||||
):
|
invalid_metadata_keys: Optional[Set[str]] = None,
|
||||||
|
) -> None:
|
||||||
if valid_metadata_keys or invalid_metadata_keys:
|
if valid_metadata_keys or invalid_metadata_keys:
|
||||||
raise NotImplementedByController(
|
raise NotImplementedByController(
|
||||||
"Defining valid and invalid METADATA keys."
|
"Defining valid and invalid METADATA keys."
|
||||||
@ -85,6 +86,7 @@ class CharybdisController(BaseServerController, DirectoryBasedController):
|
|||||||
ssl_config=ssl_config,
|
ssl_config=ssl_config,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
assert self.directory
|
||||||
self.proc = subprocess.Popen(
|
self.proc = subprocess.Popen(
|
||||||
[
|
[
|
||||||
self.binary_name,
|
self.binary_name,
|
||||||
@ -98,5 +100,5 @@ class CharybdisController(BaseServerController, DirectoryBasedController):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_irctest_controller_class():
|
def get_irctest_controller_class() -> Type[CharybdisController]:
|
||||||
return CharybdisController
|
return CharybdisController
|
||||||
|
@ -1,33 +1,25 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
from irctest.basecontrollers import BaseClientController, NotImplementedByController
|
from irctest import authentication, tls
|
||||||
|
from irctest.basecontrollers import (
|
||||||
|
BaseClientController,
|
||||||
|
DirectoryBasedController,
|
||||||
|
NotImplementedByController,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GircController(BaseClientController):
|
class GircController(BaseClientController, DirectoryBasedController):
|
||||||
software_name = "gIRC"
|
software_name = "gIRC"
|
||||||
supported_sasl_mechanisms = ["PLAIN"]
|
supported_sasl_mechanisms = {"PLAIN"}
|
||||||
|
|
||||||
def __init__(self):
|
def run(
|
||||||
super().__init__()
|
self,
|
||||||
self.directory = None
|
hostname: str,
|
||||||
self.proc = None
|
port: int,
|
||||||
|
auth: Optional[authentication.Authentication],
|
||||||
def kill(self):
|
tls_config: Optional[tls.TlsConfig] = None,
|
||||||
if self.proc:
|
) -> None:
|
||||||
self.proc.terminate()
|
|
||||||
try:
|
|
||||||
self.proc.wait(5)
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
self.proc.kill()
|
|
||||||
self.proc = None
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
if self.proc:
|
|
||||||
self.proc.kill()
|
|
||||||
if self.directory:
|
|
||||||
self.directory.cleanup()
|
|
||||||
|
|
||||||
def run(self, hostname, port, auth, tls_config):
|
|
||||||
if tls_config:
|
if tls_config:
|
||||||
print(tls_config)
|
print(tls_config)
|
||||||
raise NotImplementedByController("TLS options")
|
raise NotImplementedByController("TLS options")
|
||||||
@ -42,5 +34,5 @@ class GircController(BaseClientController):
|
|||||||
self.proc = subprocess.Popen(["girc_test", "connect"] + args)
|
self.proc = subprocess.Popen(["girc_test", "connect"] + args)
|
||||||
|
|
||||||
|
|
||||||
def get_irctest_controller_class():
|
def get_irctest_controller_class() -> Type[GircController]:
|
||||||
return GircController
|
return GircController
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import Set
|
from typing import Optional, Set, Type
|
||||||
|
|
||||||
from irctest.basecontrollers import (
|
from irctest.basecontrollers import (
|
||||||
BaseServerController,
|
BaseServerController,
|
||||||
@ -44,20 +44,21 @@ class HybridController(BaseServerController, DirectoryBasedController):
|
|||||||
supports_sts = False
|
supports_sts = False
|
||||||
supported_sasl_mechanisms: Set[str] = set()
|
supported_sasl_mechanisms: Set[str] = set()
|
||||||
|
|
||||||
def create_config(self):
|
def create_config(self) -> None:
|
||||||
super().create_config()
|
super().create_config()
|
||||||
with self.open_file("server.conf"):
|
with self.open_file("server.conf"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
hostname,
|
hostname: str,
|
||||||
port,
|
port: int,
|
||||||
password=None,
|
*,
|
||||||
ssl=False,
|
password: Optional[str],
|
||||||
valid_metadata_keys=None,
|
ssl: bool,
|
||||||
invalid_metadata_keys=None,
|
valid_metadata_keys: Optional[Set[str]] = None,
|
||||||
):
|
invalid_metadata_keys: Optional[Set[str]] = None,
|
||||||
|
) -> None:
|
||||||
if valid_metadata_keys or invalid_metadata_keys:
|
if valid_metadata_keys or invalid_metadata_keys:
|
||||||
raise NotImplementedByController(
|
raise NotImplementedByController(
|
||||||
"Defining valid and invalid METADATA keys."
|
"Defining valid and invalid METADATA keys."
|
||||||
@ -82,6 +83,7 @@ class HybridController(BaseServerController, DirectoryBasedController):
|
|||||||
ssl_config=ssl_config,
|
ssl_config=ssl_config,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
assert self.directory
|
||||||
self.proc = subprocess.Popen(
|
self.proc = subprocess.Popen(
|
||||||
[
|
[
|
||||||
"ircd",
|
"ircd",
|
||||||
@ -96,5 +98,5 @@ class HybridController(BaseServerController, DirectoryBasedController):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_irctest_controller_class():
|
def get_irctest_controller_class() -> Type[HybridController]:
|
||||||
return HybridController
|
return HybridController
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import Set
|
from typing import Optional, Set, Type
|
||||||
|
|
||||||
from irctest.basecontrollers import (
|
from irctest.basecontrollers import (
|
||||||
BaseServerController,
|
BaseServerController,
|
||||||
@ -42,21 +42,22 @@ class InspircdController(BaseServerController, DirectoryBasedController):
|
|||||||
supported_sasl_mechanisms: Set[str] = set()
|
supported_sasl_mechanisms: Set[str] = set()
|
||||||
supports_str = False
|
supports_str = False
|
||||||
|
|
||||||
def create_config(self):
|
def create_config(self) -> None:
|
||||||
super().create_config()
|
super().create_config()
|
||||||
with self.open_file("server.conf"):
|
with self.open_file("server.conf"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
hostname,
|
hostname: str,
|
||||||
port,
|
port: int,
|
||||||
password=None,
|
*,
|
||||||
ssl=False,
|
password: Optional[str],
|
||||||
restricted_metadata_keys=None,
|
ssl: bool,
|
||||||
valid_metadata_keys=None,
|
valid_metadata_keys: Optional[Set[str]] = None,
|
||||||
invalid_metadata_keys=None,
|
invalid_metadata_keys: Optional[Set[str]] = None,
|
||||||
):
|
restricted_metadata_keys: Optional[Set[str]] = None,
|
||||||
|
) -> None:
|
||||||
if valid_metadata_keys or invalid_metadata_keys:
|
if valid_metadata_keys or invalid_metadata_keys:
|
||||||
raise NotImplementedByController(
|
raise NotImplementedByController(
|
||||||
"Defining valid and invalid METADATA keys."
|
"Defining valid and invalid METADATA keys."
|
||||||
@ -81,6 +82,7 @@ class InspircdController(BaseServerController, DirectoryBasedController):
|
|||||||
ssl_config=ssl_config,
|
ssl_config=ssl_config,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
assert self.directory
|
||||||
self.proc = subprocess.Popen(
|
self.proc = subprocess.Popen(
|
||||||
[
|
[
|
||||||
"inspircd",
|
"inspircd",
|
||||||
@ -92,5 +94,5 @@ class InspircdController(BaseServerController, DirectoryBasedController):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_irctest_controller_class():
|
def get_irctest_controller_class() -> Type[InspircdController]:
|
||||||
return InspircdController
|
return InspircdController
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
from .charybdis import CharybdisController
|
from .charybdis import CharybdisController
|
||||||
|
|
||||||
|
|
||||||
@ -6,5 +8,5 @@ class IrcdSevenController(CharybdisController):
|
|||||||
binary_name = "ircd-seven"
|
binary_name = "ircd-seven"
|
||||||
|
|
||||||
|
|
||||||
def get_irctest_controller_class():
|
def get_irctest_controller_class() -> Type[IrcdSevenController]:
|
||||||
return IrcdSevenController
|
return IrcdSevenController
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
from irctest import tls
|
from irctest import authentication, tls
|
||||||
from irctest.basecontrollers import BaseClientController, DirectoryBasedController
|
from irctest.basecontrollers import BaseClientController, DirectoryBasedController
|
||||||
|
|
||||||
TEMPLATE_CONFIG = """
|
TEMPLATE_CONFIG = """
|
||||||
@ -35,15 +36,20 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
|
|||||||
}
|
}
|
||||||
supports_sts = True
|
supports_sts = True
|
||||||
|
|
||||||
def create_config(self):
|
def create_config(self) -> None:
|
||||||
create_config = super().create_config()
|
super().create_config()
|
||||||
if create_config:
|
with self.open_file("bot.conf"):
|
||||||
with self.open_file("bot.conf"):
|
pass
|
||||||
pass
|
with self.open_file("conf/users.conf"):
|
||||||
with self.open_file("conf/users.conf"):
|
pass
|
||||||
pass
|
|
||||||
|
|
||||||
def run(self, hostname, port, auth, tls_config=None):
|
def run(
|
||||||
|
self,
|
||||||
|
hostname: str,
|
||||||
|
port: int,
|
||||||
|
auth: Optional[authentication.Authentication],
|
||||||
|
tls_config: Optional[tls.TlsConfig] = None,
|
||||||
|
) -> None:
|
||||||
if tls_config is None:
|
if tls_config is None:
|
||||||
tls_config = tls.TlsConfig(enable=False, trusted_fingerprints=[])
|
tls_config = tls.TlsConfig(enable=False, trusted_fingerprints=[])
|
||||||
# Runs a client with the config given as arguments
|
# Runs a client with the config given as arguments
|
||||||
@ -72,10 +78,11 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
|
|||||||
else "",
|
else "",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
assert self.directory
|
||||||
self.proc = subprocess.Popen(
|
self.proc = subprocess.Popen(
|
||||||
["supybot", os.path.join(self.directory, "bot.conf")]
|
["supybot", os.path.join(self.directory, "bot.conf")]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_irctest_controller_class():
|
def get_irctest_controller_class() -> Type[LimnoriaController]:
|
||||||
return LimnoriaController
|
return LimnoriaController
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from typing import Optional, Set, Type
|
||||||
|
|
||||||
from irctest.basecontrollers import (
|
from irctest.basecontrollers import (
|
||||||
BaseServerController,
|
BaseServerController,
|
||||||
DirectoryBasedController,
|
DirectoryBasedController,
|
||||||
NotImplementedByController,
|
NotImplementedByController,
|
||||||
)
|
)
|
||||||
|
from irctest.cases import BaseServerTestCase
|
||||||
|
|
||||||
TEMPLATE_CONFIG = """
|
TEMPLATE_CONFIG = """
|
||||||
clients:
|
clients:
|
||||||
@ -61,7 +63,7 @@ server:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def make_list(list_):
|
def make_list(list_: Set[str]) -> str:
|
||||||
return "\n".join(map(" - {}".format, list_))
|
return "\n".join(map(" - {}".format, list_))
|
||||||
|
|
||||||
|
|
||||||
@ -69,25 +71,27 @@ class MammonController(BaseServerController, DirectoryBasedController):
|
|||||||
software_name = "Mammon"
|
software_name = "Mammon"
|
||||||
supported_sasl_mechanisms = {"PLAIN", "ECDSA-NIST256P-CHALLENGE"}
|
supported_sasl_mechanisms = {"PLAIN", "ECDSA-NIST256P-CHALLENGE"}
|
||||||
|
|
||||||
def create_config(self):
|
def create_config(self) -> None:
|
||||||
super().create_config()
|
super().create_config()
|
||||||
with self.open_file("server.conf"):
|
with self.open_file("server.conf"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def kill_proc(self):
|
def kill_proc(self) -> None:
|
||||||
# Mammon does not seem to handle SIGTERM very well
|
# Mammon does not seem to handle SIGTERM very well
|
||||||
|
assert self.proc
|
||||||
self.proc.kill()
|
self.proc.kill()
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
hostname,
|
hostname: str,
|
||||||
port,
|
port: int,
|
||||||
password=None,
|
*,
|
||||||
ssl=False,
|
password: Optional[str],
|
||||||
restricted_metadata_keys=(),
|
ssl: bool,
|
||||||
valid_metadata_keys=(),
|
valid_metadata_keys: Optional[Set[str]] = None,
|
||||||
invalid_metadata_keys=(),
|
invalid_metadata_keys: Optional[Set[str]] = None,
|
||||||
):
|
restricted_metadata_keys: Optional[Set[str]] = None,
|
||||||
|
) -> None:
|
||||||
if password is not None:
|
if password is not None:
|
||||||
raise NotImplementedByController("PASS command")
|
raise NotImplementedByController("PASS command")
|
||||||
if ssl:
|
if ssl:
|
||||||
@ -101,12 +105,13 @@ class MammonController(BaseServerController, DirectoryBasedController):
|
|||||||
directory=self.directory,
|
directory=self.directory,
|
||||||
hostname=hostname,
|
hostname=hostname,
|
||||||
port=port,
|
port=port,
|
||||||
authorized_keys=make_list(valid_metadata_keys),
|
authorized_keys=make_list(valid_metadata_keys or set()),
|
||||||
restricted_keys=make_list(restricted_metadata_keys),
|
restricted_keys=make_list(restricted_metadata_keys or set()),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# with self.open_file('server.yml', 'r') as fd:
|
# with self.open_file('server.yml', 'r') as fd:
|
||||||
# print(fd.read())
|
# print(fd.read())
|
||||||
|
assert self.directory
|
||||||
self.proc = subprocess.Popen(
|
self.proc = subprocess.Popen(
|
||||||
[
|
[
|
||||||
"mammond",
|
"mammond",
|
||||||
@ -116,7 +121,12 @@ class MammonController(BaseServerController, DirectoryBasedController):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def registerUser(self, case, username, password=None):
|
def registerUser(
|
||||||
|
self,
|
||||||
|
case: BaseServerTestCase,
|
||||||
|
username: str,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
# XXX: Move this somewhere else when
|
# XXX: Move this somewhere else when
|
||||||
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes
|
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes
|
||||||
# part of the specification
|
# part of the specification
|
||||||
@ -135,5 +145,5 @@ class MammonController(BaseServerController, DirectoryBasedController):
|
|||||||
case.removeClient(client)
|
case.removeClient(client)
|
||||||
|
|
||||||
|
|
||||||
def get_irctest_controller_class():
|
def get_irctest_controller_class() -> Type[MammonController]:
|
||||||
return MammonController
|
return MammonController
|
||||||
|
@ -2,12 +2,14 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from typing import Any, Dict, Optional, Set, Type, Union
|
||||||
|
|
||||||
from irctest.basecontrollers import (
|
from irctest.basecontrollers import (
|
||||||
BaseServerController,
|
BaseServerController,
|
||||||
DirectoryBasedController,
|
DirectoryBasedController,
|
||||||
NotImplementedByController,
|
NotImplementedByController,
|
||||||
)
|
)
|
||||||
|
from irctest.cases import BaseServerTestCase
|
||||||
|
|
||||||
OPER_PWD = "frenchfries"
|
OPER_PWD = "frenchfries"
|
||||||
|
|
||||||
@ -116,7 +118,7 @@ BASE_CONFIG = {
|
|||||||
LOGGING_CONFIG = {"logging": [{"method": "stderr", "level": "debug", "type": "*"}]}
|
LOGGING_CONFIG = {"logging": [{"method": "stderr", "level": "debug", "type": "*"}]}
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password):
|
def hash_password(password: Union[str, bytes]) -> str:
|
||||||
if isinstance(password, str):
|
if isinstance(password, str):
|
||||||
password = password.encode("utf-8")
|
password = password.encode("utf-8")
|
||||||
# simulate entry of password and confirmation:
|
# simulate entry of password and confirmation:
|
||||||
@ -134,25 +136,23 @@ class OragonoController(BaseServerController, DirectoryBasedController):
|
|||||||
supported_sasl_mechanisms = {"PLAIN"}
|
supported_sasl_mechanisms = {"PLAIN"}
|
||||||
supports_sts = True
|
supports_sts = True
|
||||||
|
|
||||||
def create_config(self):
|
def create_config(self) -> None:
|
||||||
super().create_config()
|
super().create_config()
|
||||||
with self.open_file("ircd.yaml"):
|
with self.open_file("ircd.yaml"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def kill_proc(self):
|
|
||||||
self.proc.kill()
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
hostname,
|
hostname: str,
|
||||||
port,
|
port: int,
|
||||||
password=None,
|
*,
|
||||||
ssl=False,
|
password: Optional[str],
|
||||||
restricted_metadata_keys=None,
|
ssl: bool,
|
||||||
valid_metadata_keys=None,
|
valid_metadata_keys: Optional[Set[str]] = None,
|
||||||
invalid_metadata_keys=None,
|
invalid_metadata_keys: Optional[Set[str]] = None,
|
||||||
config=None,
|
restricted_metadata_keys: Optional[Set[str]] = None,
|
||||||
):
|
config: Optional[Any] = None,
|
||||||
|
) -> None:
|
||||||
if valid_metadata_keys or invalid_metadata_keys:
|
if valid_metadata_keys or invalid_metadata_keys:
|
||||||
raise NotImplementedByController(
|
raise NotImplementedByController(
|
||||||
"Defining valid and invalid METADATA keys."
|
"Defining valid and invalid METADATA keys."
|
||||||
@ -162,6 +162,8 @@ class OragonoController(BaseServerController, DirectoryBasedController):
|
|||||||
if config is None:
|
if config is None:
|
||||||
config = copy.deepcopy(BASE_CONFIG)
|
config = copy.deepcopy(BASE_CONFIG)
|
||||||
|
|
||||||
|
assert self.directory
|
||||||
|
|
||||||
enable_chathistory = self.test_config.chathistory
|
enable_chathistory = self.test_config.chathistory
|
||||||
enable_roleplay = self.test_config.oragono_roleplay
|
enable_roleplay = self.test_config.oragono_roleplay
|
||||||
if enable_chathistory or enable_roleplay:
|
if enable_chathistory or enable_roleplay:
|
||||||
@ -180,12 +182,14 @@ class OragonoController(BaseServerController, DirectoryBasedController):
|
|||||||
self.key_path = os.path.join(self.directory, "ssl.key")
|
self.key_path = os.path.join(self.directory, "ssl.key")
|
||||||
self.pem_path = os.path.join(self.directory, "ssl.pem")
|
self.pem_path = os.path.join(self.directory, "ssl.pem")
|
||||||
listener_conf = {"tls": {"cert": self.pem_path, "key": self.key_path}}
|
listener_conf = {"tls": {"cert": self.pem_path, "key": self.key_path}}
|
||||||
config["server"]["listeners"][bind_address] = listener_conf
|
config["server"]["listeners"][bind_address] = listener_conf # type: ignore
|
||||||
|
|
||||||
config["datastore"]["path"] = os.path.join(self.directory, "ircd.db")
|
config["datastore"]["path"] = os.path.join( # type: ignore
|
||||||
|
self.directory, "ircd.db"
|
||||||
|
)
|
||||||
|
|
||||||
if password is not None:
|
if password is not None:
|
||||||
config["server"]["password"] = hash_password(password)
|
config["server"]["password"] = hash_password(password) # type: ignore
|
||||||
|
|
||||||
assert self.proc is None
|
assert self.proc is None
|
||||||
|
|
||||||
@ -198,7 +202,12 @@ class OragonoController(BaseServerController, DirectoryBasedController):
|
|||||||
["oragono", "run", "--conf", self._config_path, "--quiet"]
|
["oragono", "run", "--conf", self._config_path, "--quiet"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def registerUser(self, case, username, password=None):
|
def registerUser(
|
||||||
|
self,
|
||||||
|
case: BaseServerTestCase,
|
||||||
|
username: str,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
# XXX: Move this somewhere else when
|
# XXX: Move this somewhere else when
|
||||||
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes
|
# https://github.com/ircv3/ircv3-specifications/pull/152 becomes
|
||||||
# part of the specification
|
# part of the specification
|
||||||
@ -210,34 +219,35 @@ class OragonoController(BaseServerController, DirectoryBasedController):
|
|||||||
while case.getRegistrationMessage(client).command != "001":
|
while case.getRegistrationMessage(client).command != "001":
|
||||||
pass
|
pass
|
||||||
case.getMessages(client)
|
case.getMessages(client)
|
||||||
|
assert password
|
||||||
case.sendLine(client, "NS REGISTER " + password)
|
case.sendLine(client, "NS REGISTER " + password)
|
||||||
msg = case.getMessage(client)
|
msg = case.getMessage(client)
|
||||||
assert msg.params == [username, "Account created"]
|
assert msg.params == [username, "Account created"]
|
||||||
case.sendLine(client, "QUIT")
|
case.sendLine(client, "QUIT")
|
||||||
case.assertDisconnected(client)
|
case.assertDisconnected(client)
|
||||||
|
|
||||||
def _write_config(self):
|
def _write_config(self) -> None:
|
||||||
with open(self._config_path, "w") as fd:
|
with open(self._config_path, "w") as fd:
|
||||||
json.dump(self._config, fd)
|
json.dump(self._config, fd)
|
||||||
|
|
||||||
def baseConfig(self):
|
def baseConfig(self) -> Dict:
|
||||||
return copy.deepcopy(BASE_CONFIG)
|
return copy.deepcopy(BASE_CONFIG)
|
||||||
|
|
||||||
def getConfig(self):
|
def getConfig(self) -> Dict:
|
||||||
return copy.deepcopy(self._config)
|
return copy.deepcopy(self._config)
|
||||||
|
|
||||||
def addLoggingToConfig(self, config=None):
|
def addLoggingToConfig(self, config: Optional[Dict] = None) -> Dict:
|
||||||
if config is None:
|
if config is None:
|
||||||
config = self.baseConfig()
|
config = self.baseConfig()
|
||||||
config.update(LOGGING_CONFIG)
|
config.update(LOGGING_CONFIG)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def addMysqlToConfig(self, config=None):
|
def addMysqlToConfig(self, config: Optional[Dict] = None) -> Dict:
|
||||||
mysql_password = os.getenv("MYSQL_PASSWORD")
|
mysql_password = os.getenv("MYSQL_PASSWORD")
|
||||||
if not mysql_password:
|
|
||||||
return config
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = self.baseConfig()
|
config = self.baseConfig()
|
||||||
|
if not mysql_password:
|
||||||
|
return config
|
||||||
config["datastore"]["mysql"] = {
|
config["datastore"]["mysql"] = {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"host": "localhost",
|
"host": "localhost",
|
||||||
@ -259,7 +269,7 @@ class OragonoController(BaseServerController, DirectoryBasedController):
|
|||||||
}
|
}
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def rehash(self, case, config):
|
def rehash(self, case: BaseServerTestCase, config: Dict) -> None:
|
||||||
self._config = config
|
self._config = config
|
||||||
self._write_config()
|
self._write_config()
|
||||||
client = "operator_for_rehash"
|
client = "operator_for_rehash"
|
||||||
@ -270,11 +280,11 @@ class OragonoController(BaseServerController, DirectoryBasedController):
|
|||||||
case.sendLine(client, "QUIT")
|
case.sendLine(client, "QUIT")
|
||||||
case.assertDisconnected(client)
|
case.assertDisconnected(client)
|
||||||
|
|
||||||
def enable_debug_logging(self, case):
|
def enable_debug_logging(self, case: BaseServerTestCase) -> None:
|
||||||
config = self.getConfig()
|
config = self.getConfig()
|
||||||
config.update(LOGGING_CONFIG)
|
config.update(LOGGING_CONFIG)
|
||||||
self.rehash(case, config)
|
self.rehash(case, config)
|
||||||
|
|
||||||
|
|
||||||
def get_irctest_controller_class():
|
def get_irctest_controller_class() -> Type[OragonoController]:
|
||||||
return OragonoController
|
return OragonoController
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
from .charybdis import CharybdisController
|
from .charybdis import CharybdisController
|
||||||
|
|
||||||
|
|
||||||
@ -6,5 +8,5 @@ class SolanumController(CharybdisController):
|
|||||||
binary_name = "solanum"
|
binary_name = "solanum"
|
||||||
|
|
||||||
|
|
||||||
def get_irctest_controller_class():
|
def get_irctest_controller_class() -> Type[SolanumController]:
|
||||||
return SolanumController
|
return SolanumController
|
||||||
|
@ -1,8 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from typing import Optional, TextIO, Type, cast
|
||||||
|
|
||||||
from irctest.basecontrollers import BaseClientController, NotImplementedByController
|
from irctest import authentication, tls
|
||||||
|
from irctest.basecontrollers import (
|
||||||
|
BaseClientController,
|
||||||
|
NotImplementedByController,
|
||||||
|
TestCaseControllerConfig,
|
||||||
|
)
|
||||||
|
|
||||||
TEMPLATE_CONFIG = """
|
TEMPLATE_CONFIG = """
|
||||||
[core]
|
[core]
|
||||||
@ -24,30 +30,34 @@ class SopelController(BaseClientController):
|
|||||||
supported_sasl_mechanisms = {"PLAIN"}
|
supported_sasl_mechanisms = {"PLAIN"}
|
||||||
supports_sts = False
|
supports_sts = False
|
||||||
|
|
||||||
def __init__(self, test_config):
|
def __init__(self, test_config: TestCaseControllerConfig):
|
||||||
super().__init__(test_config)
|
super().__init__(test_config)
|
||||||
self.filename = next(tempfile._get_candidate_names()) + ".cfg"
|
self.filename = next(tempfile._get_candidate_names()) + ".cfg" # type: ignore
|
||||||
self.proc = None
|
|
||||||
|
|
||||||
def kill(self):
|
def kill(self) -> None:
|
||||||
if self.proc:
|
super().kill()
|
||||||
self.proc.kill()
|
|
||||||
if self.filename:
|
if self.filename:
|
||||||
try:
|
try:
|
||||||
os.unlink(os.path.join(os.path.expanduser("~/.sopel/"), self.filename))
|
os.unlink(os.path.join(os.path.expanduser("~/.sopel/"), self.filename))
|
||||||
except OSError: # File does not exist
|
except OSError: # File does not exist
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def open_file(self, filename, mode="a"):
|
def open_file(self, filename: str, mode: str = "a") -> TextIO:
|
||||||
dir_path = os.path.expanduser("~/.sopel/")
|
dir_path = os.path.expanduser("~/.sopel/")
|
||||||
os.makedirs(dir_path, exist_ok=True)
|
os.makedirs(dir_path, exist_ok=True)
|
||||||
return open(os.path.join(dir_path, filename), mode)
|
return cast(TextIO, open(os.path.join(dir_path, filename), mode))
|
||||||
|
|
||||||
def create_config(self):
|
def create_config(self) -> None:
|
||||||
with self.open_file(self.filename):
|
with self.open_file(self.filename):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def run(self, hostname, port, auth, tls_config):
|
def run(
|
||||||
|
self,
|
||||||
|
hostname: str,
|
||||||
|
port: int,
|
||||||
|
auth: Optional[authentication.Authentication],
|
||||||
|
tls_config: Optional[tls.TlsConfig] = None,
|
||||||
|
) -> None:
|
||||||
# Runs a client with the config given as arguments
|
# Runs a client with the config given as arguments
|
||||||
if tls_config is not None:
|
if tls_config is not None:
|
||||||
raise NotImplementedByController("TLS configuration")
|
raise NotImplementedByController("TLS configuration")
|
||||||
@ -66,5 +76,5 @@ class SopelController(BaseClientController):
|
|||||||
self.proc = subprocess.Popen(["sopel", "--quiet", "-c", self.filename])
|
self.proc = subprocess.Popen(["sopel", "--quiet", "-c", self.filename])
|
||||||
|
|
||||||
|
|
||||||
def get_irctest_controller_class():
|
def get_irctest_controller_class() -> Type[SopelController]:
|
||||||
return SopelController
|
return SopelController
|
||||||
|
@ -2,8 +2,10 @@
|
|||||||
Handles ambiguities of RFCs.
|
Handles ambiguities of RFCs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
def normalize_namreply_params(params):
|
|
||||||
|
def normalize_namreply_params(params: List[str]) -> List[str]:
|
||||||
# So… RFC 2812 says:
|
# So… RFC 2812 says:
|
||||||
# "( "=" / "*" / "@" ) <channel>
|
# "( "=" / "*" / "@" ) <channel>
|
||||||
# :[ "@" / "+" ] <nick> *( " " [ "@" / "+" ] <nick> )
|
# :[ "@" / "+" ] <nick> *( " " [ "@" / "+" ] <nick> )
|
||||||
@ -12,6 +14,7 @@ def normalize_namreply_params(params):
|
|||||||
# prefix.
|
# prefix.
|
||||||
# So let's normalize this to “with space”, and strip spaces at the
|
# So let's normalize this to “with space”, and strip spaces at the
|
||||||
# end of the nick list.
|
# end of the nick list.
|
||||||
|
params = list(params) # copy the list
|
||||||
if len(params) == 3:
|
if len(params) == 3:
|
||||||
assert params[1][0] in "=*@", params
|
assert params[1][0] in "=*@", params
|
||||||
params.insert(1, params[1][0])
|
params.insert(1, params[1][0])
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
def cap_list_to_dict(caps):
|
from typing import Dict, List, Optional
|
||||||
d = {}
|
|
||||||
|
|
||||||
|
def cap_list_to_dict(caps: List[str]) -> Dict[str, Optional[str]]:
|
||||||
|
d: Dict[str, Optional[str]] = {}
|
||||||
for cap in caps:
|
for cap in caps:
|
||||||
if "=" in cap:
|
if "=" in cap:
|
||||||
(key, value) = cap.split("=", 1)
|
(key, value) = cap.split("=", 1)
|
||||||
|
d[key] = value
|
||||||
else:
|
else:
|
||||||
key = cap
|
d[cap] = None
|
||||||
value = None
|
|
||||||
d[key] = value
|
|
||||||
return d
|
return d
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
# thanks jess!
|
# thanks jess!
|
||||||
IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z"
|
IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z"
|
||||||
|
|
||||||
|
|
||||||
def ircv3_timestamp_to_unixtime(timestamp):
|
def ircv3_timestamp_to_unixtime(timestamp: str) -> float:
|
||||||
return datetime.datetime.strptime(timestamp, IRCV3_FORMAT_STRFTIME).timestamp()
|
return datetime.datetime.strptime(timestamp, IRCV3_FORMAT_STRFTIME).timestamp()
|
||||||
|
|
||||||
|
|
||||||
def random_name(base):
|
def random_name(base: str) -> str:
|
||||||
return base + "-" + secrets.token_hex(8)
|
return base + "-" + secrets.token_hex(8)
|
||||||
|
|
||||||
|
|
||||||
@ -26,16 +27,16 @@ class MultipleReplacer:
|
|||||||
# We use an object instead of a lambda function because it avoids the
|
# We use an object instead of a lambda function because it avoids the
|
||||||
# need for using the staticmethod() on the lambda function if assigning
|
# need for using the staticmethod() on the lambda function if assigning
|
||||||
# it to a class in Python 3.
|
# it to a class in Python 3.
|
||||||
def __init__(self, dict_):
|
def __init__(self, dict_: Dict[str, str]):
|
||||||
self._dict = dict_
|
self._dict = dict_
|
||||||
dict_ = dict([(re.escape(key), val) for key, val in dict_.items()])
|
dict_ = dict([(re.escape(key), val) for key, val in dict_.items()])
|
||||||
self._matcher = re.compile("|".join(dict_.keys()))
|
self._matcher = re.compile("|".join(dict_.keys()))
|
||||||
|
|
||||||
def __call__(self, s):
|
def __call__(self, s: str) -> str:
|
||||||
return self._matcher.sub(lambda m: self._dict[m.group(0)], s)
|
return self._matcher.sub(lambda m: self._dict[m.group(0)], s)
|
||||||
|
|
||||||
|
|
||||||
def normalizeWhitespace(s, removeNewline=True):
|
def normalizeWhitespace(s: str, removeNewline: bool = True) -> str:
|
||||||
r"""Normalizes the whitespace in a string; \s+ becomes one space."""
|
r"""Normalizes the whitespace in a string; \s+ becomes one space."""
|
||||||
if not s:
|
if not s:
|
||||||
return str(s) # not the same reference
|
return str(s) # not the same reference
|
||||||
|
@ -18,8 +18,8 @@ unescape_tag_value = MultipleReplacer(dict(map(lambda x: (x[1], x[0]), TAG_ESCAP
|
|||||||
tag_key_validator = re.compile(r"\+?(\S+/)?[a-zA-Z0-9-]+")
|
tag_key_validator = re.compile(r"\+?(\S+/)?[a-zA-Z0-9-]+")
|
||||||
|
|
||||||
|
|
||||||
def parse_tags(s):
|
def parse_tags(s: str) -> Dict[str, Optional[str]]:
|
||||||
tags = {}
|
tags: Dict[str, Optional[str]] = {}
|
||||||
for tag in s.split(";"):
|
for tag in s.split(";"):
|
||||||
if "=" not in tag:
|
if "=" not in tag:
|
||||||
tags[tag] = None
|
tags[tag] = None
|
||||||
@ -54,15 +54,15 @@ class Message:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_message(s):
|
def parse_message(s: str) -> Message:
|
||||||
"""Parse a message according to
|
"""Parse a message according to
|
||||||
http://tools.ietf.org/html/rfc1459#section-2.3.1
|
http://tools.ietf.org/html/rfc1459#section-2.3.1
|
||||||
and
|
and
|
||||||
http://ircv3.net/specs/core/message-tags-3.2.html"""
|
http://ircv3.net/specs/core/message-tags-3.2.html"""
|
||||||
s = s.rstrip("\r\n")
|
s = s.rstrip("\r\n")
|
||||||
if s.startswith("@"):
|
if s.startswith("@"):
|
||||||
(tags, s) = s.split(" ", 1)
|
(tags_str, s) = s.split(" ", 1)
|
||||||
tags = parse_tags(tags[1:])
|
tags = parse_tags(tags_str[1:])
|
||||||
else:
|
else:
|
||||||
tags = {}
|
tags = {}
|
||||||
if " :" in s:
|
if " :" in s:
|
||||||
@ -70,10 +70,7 @@ def parse_message(s):
|
|||||||
tokens = list(filter(bool, other_tokens.split(" "))) + [trailing_param]
|
tokens = list(filter(bool, other_tokens.split(" "))) + [trailing_param]
|
||||||
else:
|
else:
|
||||||
tokens = list(filter(bool, s.split(" ")))
|
tokens = list(filter(bool, s.split(" ")))
|
||||||
if tokens[0].startswith(":"):
|
prefix = prefix = tokens.pop(0)[1:] if tokens[0].startswith(":") else None
|
||||||
prefix = tokens.pop(0)[1:]
|
|
||||||
else:
|
|
||||||
prefix = None
|
|
||||||
command = tokens.pop(0)
|
command = tokens.pop(0)
|
||||||
params = tokens
|
params = tokens
|
||||||
return Message(tags=tags, prefix=prefix, command=command, params=params)
|
return Message(tags=tags, prefix=prefix, command=command, params=params)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
|
|
||||||
|
|
||||||
def sasl_plain_blob(username, passphrase):
|
def sasl_plain_blob(username: str, passphrase: str) -> str:
|
||||||
blob = base64.b64encode(
|
blob = base64.b64encode(
|
||||||
b"\x00".join(
|
b"\x00".join(
|
||||||
(
|
(
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
import collections
|
import collections
|
||||||
|
from typing import Dict, Union
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
class NotImplementedByController(unittest.SkipTest, NotImplementedError):
|
class NotImplementedByController(unittest.SkipTest, NotImplementedError):
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return "Not implemented by controller: {}".format(self.args[0])
|
return "Not implemented by controller: {}".format(self.args[0])
|
||||||
|
|
||||||
|
|
||||||
class ImplementationChoice(unittest.SkipTest):
|
class ImplementationChoice(unittest.SkipTest):
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Choice in the implementation makes it impossible to "
|
"Choice in the implementation makes it impossible to "
|
||||||
"perform a test: {}".format(self.args[0])
|
"perform a test: {}".format(self.args[0])
|
||||||
@ -16,49 +17,49 @@ class ImplementationChoice(unittest.SkipTest):
|
|||||||
|
|
||||||
|
|
||||||
class OptionalExtensionNotSupported(unittest.SkipTest):
|
class OptionalExtensionNotSupported(unittest.SkipTest):
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return "Unsupported extension: {}".format(self.args[0])
|
return "Unsupported extension: {}".format(self.args[0])
|
||||||
|
|
||||||
|
|
||||||
class OptionalSaslMechanismNotSupported(unittest.SkipTest):
|
class OptionalSaslMechanismNotSupported(unittest.SkipTest):
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return "Unsupported SASL mechanism: {}".format(self.args[0])
|
return "Unsupported SASL mechanism: {}".format(self.args[0])
|
||||||
|
|
||||||
|
|
||||||
class CapabilityNotSupported(unittest.SkipTest):
|
class CapabilityNotSupported(unittest.SkipTest):
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return "Unsupported capability: {}".format(self.args[0])
|
return "Unsupported capability: {}".format(self.args[0])
|
||||||
|
|
||||||
|
|
||||||
class IsupportTokenNotSupported(unittest.SkipTest):
|
class IsupportTokenNotSupported(unittest.SkipTest):
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return "Unsupported ISUPPORT token: {}".format(self.args[0])
|
return "Unsupported ISUPPORT token: {}".format(self.args[0])
|
||||||
|
|
||||||
|
|
||||||
class ChannelModeNotSupported(unittest.SkipTest):
|
class ChannelModeNotSupported(unittest.SkipTest):
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return "Unsupported channel mode: {} ({})".format(self.args[0], self.args[1])
|
return "Unsupported channel mode: {} ({})".format(self.args[0], self.args[1])
|
||||||
|
|
||||||
|
|
||||||
class ExtbanNotSupported(unittest.SkipTest):
|
class ExtbanNotSupported(unittest.SkipTest):
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return "Unsupported extban: {} ({})".format(self.args[0], self.args[1])
|
return "Unsupported extban: {} ({})".format(self.args[0], self.args[1])
|
||||||
|
|
||||||
|
|
||||||
class NotRequiredBySpecifications(unittest.SkipTest):
|
class NotRequiredBySpecifications(unittest.SkipTest):
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return "Tests not required by the set of tested specification(s)."
|
return "Tests not required by the set of tested specification(s)."
|
||||||
|
|
||||||
|
|
||||||
class SkipStrictTest(unittest.SkipTest):
|
class SkipStrictTest(unittest.SkipTest):
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return "Tests not required because strict tests are disabled."
|
return "Tests not required because strict tests are disabled."
|
||||||
|
|
||||||
|
|
||||||
class TextTestResult(unittest.TextTestResult):
|
class TextTestResult(unittest.TextTestResult):
|
||||||
def getDescription(self, test):
|
def getDescription(self, test: unittest.TestCase) -> str:
|
||||||
if hasattr(test, "description"):
|
if hasattr(test, "description"):
|
||||||
doc_first_lines = test.description()
|
doc_first_lines = test.description() # type: ignore
|
||||||
else:
|
else:
|
||||||
doc_first_lines = test.shortDescription()
|
doc_first_lines = test.shortDescription()
|
||||||
return "\n".join((str(test), doc_first_lines or ""))
|
return "\n".join((str(test), doc_first_lines or ""))
|
||||||
@ -71,7 +72,9 @@ class TextTestRunner(unittest.TextTestRunner):
|
|||||||
|
|
||||||
resultclass = TextTestResult
|
resultclass = TextTestResult
|
||||||
|
|
||||||
def run(self, test):
|
def run(
|
||||||
|
self, test: Union[unittest.TestSuite, unittest.TestCase]
|
||||||
|
) -> unittest.TestResult:
|
||||||
result = super().run(test)
|
result = super().run(test)
|
||||||
assert self.resultclass is TextTestResult
|
assert self.resultclass is TextTestResult
|
||||||
if result.skipped:
|
if result.skipped:
|
||||||
@ -80,7 +83,7 @@ class TextTestRunner(unittest.TextTestRunner):
|
|||||||
"Some tests were skipped because the following optional "
|
"Some tests were skipped because the following optional "
|
||||||
"specifications/mechanisms are not supported:"
|
"specifications/mechanisms are not supported:"
|
||||||
)
|
)
|
||||||
msg_to_count = collections.defaultdict(lambda: 0)
|
msg_to_count: Dict[str, int] = collections.defaultdict(lambda: 0)
|
||||||
for (test, msg) in result.skipped:
|
for (test, msg) in result.skipped:
|
||||||
msg_to_count[msg] += 1
|
msg_to_count[msg] += 1
|
||||||
for (msg, count) in sorted(msg_to_count.items()):
|
for (msg, count) in sorted(msg_to_count.items()):
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
|
||||||
|
|
||||||
@ -15,7 +17,7 @@ class Specifications(enum.Enum):
|
|||||||
Modern = "modern"
|
Modern = "modern"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_name(cls, name):
|
def from_name(cls, name: str) -> Specifications:
|
||||||
name = name.upper()
|
name = name.upper()
|
||||||
for spec in cls:
|
for spec in cls:
|
||||||
if spec.value.upper() == name:
|
if spec.value.upper() == name:
|
||||||
@ -37,7 +39,7 @@ class Capabilities(enum.Enum):
|
|||||||
STS = "sts"
|
STS = "sts"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_name(cls, name):
|
def from_name(cls, name: str) -> Capabilities:
|
||||||
try:
|
try:
|
||||||
return cls(name.lower())
|
return cls(name.lower())
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -50,7 +52,7 @@ class IsupportTokens(enum.Enum):
|
|||||||
STATUSMSG = "STATUSMSG"
|
STATUSMSG = "STATUSMSG"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_name(cls, name):
|
def from_name(cls, name: str) -> IsupportTokens:
|
||||||
try:
|
try:
|
||||||
return cls(name.upper())
|
return cls(name.upper())
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
Reference in New Issue
Block a user