irctest/irctest/basecontrollers.py

363 lines
12 KiB
Python

from __future__ import annotations
import dataclasses
import os
import shutil
import socket
import subprocess
import tempfile
import time
from typing import IO, Any, Callable, Dict, List, Optional, Set, Tuple, Type
import irctest
from . import authentication, tls
from .client_mock import ClientMock
from .irc_utils.junkdrawer import find_hostname_and_port
from .irc_utils.message_parser import Message
from .runner import NotImplementedByController
class ProcessStopped(Exception):
"""Raised when the controlled process stopped unexpectedly"""
pass
@dataclasses.dataclass
class TestCaseControllerConfig:
"""Test-case-specific configuration passed to the controller.
This is usually used to ask controllers to enable a feature;
but should not be an issue if controllers enable it all the time."""
chathistory: bool = False
"""Whether to enable chathistory features."""
ergo_roleplay: bool = False
"""Whether to enable the Ergo role-play commands."""
ergo_config: Optional[Callable[[Dict], Any]] = None
"""Oragono-specific configuration function that alters the dict in-place
This should be used as little as possible, using the other attributes instead;
as they are work with any controller."""
class _BaseController:
"""Base class for software controllers.
A software controller is an object that handles configuring and running
a process (eg. a server or a client), as well as sending it instructions
that are not part of the IRC specification."""
# set by conftest.py
openssl_bin: str
supports_sts: bool
supported_sasl_mechanisms: Set[str]
proc: Optional[subprocess.Popen]
def __init__(self, test_config: TestCaseControllerConfig):
self.test_config = test_config
self.proc = None
def check_is_alive(self) -> None:
assert self.proc
self.proc.poll()
if self.proc.returncode is not None:
raise ProcessStopped()
def kill_proc(self) -> None:
"""Terminates the controlled process, waits for it to exit, and
eventually kills it."""
assert self.proc
self.proc.terminate()
try:
self.proc.wait(5)
except subprocess.TimeoutExpired:
self.proc.kill()
self.proc = None
def kill(self) -> None:
"""Calls `kill_proc` and cleans the configuration."""
if self.proc:
self.kill_proc()
class DirectoryBasedController(_BaseController):
"""Helper for controllers whose software configuration is based on an
arbitrary directory."""
directory: Optional[str]
def __init__(self, test_config: TestCaseControllerConfig):
super().__init__(test_config)
self.directory = None
def kill(self) -> None:
"""Calls `kill_proc` and cleans the configuration."""
super().kill()
if self.directory:
shutil.rmtree(self.directory)
def terminate(self) -> None:
"""Stops the process gracefully, and does not clean its config."""
assert self.proc
self.proc.terminate()
self.proc.wait()
self.proc = None
def open_file(self, name: str, mode: str = "a") -> IO:
"""Open a file in the configuration directory."""
assert self.directory
if os.sep in name:
dir_ = os.path.join(self.directory, os.path.dirname(name))
if not os.path.isdir(dir_):
os.makedirs(dir_)
assert os.path.isdir(dir_)
return open(os.path.join(self.directory, name), mode)
def create_config(self) -> None:
if not self.directory:
self.directory = tempfile.mkdtemp()
def gen_ssl(self) -> None:
assert self.directory
self.csr_path = os.path.join(self.directory, "ssl.csr")
self.key_path = os.path.join(self.directory, "ssl.key")
self.pem_path = os.path.join(self.directory, "ssl.pem")
self.dh_path = os.path.join(self.directory, "dh.pem")
subprocess.check_output(
[
self.openssl_bin,
"req",
"-new",
"-newkey",
"rsa",
"-nodes",
"-out",
self.csr_path,
"-keyout",
self.key_path,
"-batch",
],
stderr=subprocess.DEVNULL,
)
subprocess.check_output(
[
self.openssl_bin,
"x509",
"-req",
"-in",
self.csr_path,
"-signkey",
self.key_path,
"-out",
self.pem_path,
],
stderr=subprocess.DEVNULL,
)
subprocess.check_output(
[self.openssl_bin, "dhparam", "-out", self.dh_path, "128"],
stderr=subprocess.DEVNULL,
)
class BaseClientController(_BaseController):
"""Base controller for IRC clients."""
def run(
self,
hostname: str,
port: int,
auth: Optional[authentication.Authentication],
tls_config: Optional[tls.TlsConfig] = None,
) -> None:
raise NotImplementedError()
class BaseServerController(_BaseController):
"""Base controller for IRC server."""
software_name: str # Class property
_port_wait_interval = 0.1
port_open = False
port: int
hostname: str
services_controller: Optional[BaseServicesController] = None
services_controller_class: Type[BaseServicesController]
extban_mute_char: Optional[str] = None
"""Character used for the 'mute' extban"""
nickserv = "NickServ"
def get_hostname_and_port(self) -> Tuple[str, int]:
return find_hostname_and_port()
def run(
self,
hostname: str,
port: int,
*,
password: Optional[str],
ssl: bool,
run_services: bool,
valid_metadata_keys: Optional[Set[str]],
invalid_metadata_keys: Optional[Set[str]],
) -> None:
raise NotImplementedError()
def registerUser(
self,
case: irctest.cases.BaseServerTestCase, # type: ignore
username: str,
password: Optional[str] = None,
**kwargs: Any,
) -> None:
if self.services_controller is not None:
self.services_controller.registerUser(case, username, password, **kwargs)
else:
raise NotImplementedByController("account registration")
def wait_for_port(self) -> None:
while not self.port_open:
self.check_is_alive()
time.sleep(self._port_wait_interval)
try:
c = socket.create_connection(("localhost", self.port), timeout=1.0)
c.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
# Make sure the server properly processes the disconnect.
# Otherwise, it may still count it in LUSER and fail tests in
# test_lusers.py (eg. this happens with Charybdis 3.5.0)
c.sendall(b"QUIT :chkport\r\n")
data = b""
try:
while b"chkport" not in data and b"ERROR" not in data:
data += c.recv(4096)
time.sleep(0.01)
c.send(b" ") # Triggers BrokenPipeError
except BrokenPipeError:
# ircu2 cuts the connection without a message if registration
# is not complete.
pass
c.close()
self.port_open = True
except Exception:
continue
def wait_for_services(self) -> None:
assert self.services_controller
self.services_controller.wait_for_services()
def terminate(self) -> None:
if self.services_controller is not None:
self.services_controller.terminate() # type: ignore
super().terminate() # type: ignore
def kill(self) -> None:
if self.services_controller is not None:
self.services_controller.kill() # type: ignore
super().kill()
class BaseServicesController(_BaseController):
def __init__(
self,
test_config: TestCaseControllerConfig,
server_controller: BaseServerController,
):
super().__init__(test_config)
self.test_config = test_config
self.server_controller = server_controller
self.services_up = False
def run(self, protocol: str, server_hostname: str, server_port: int) -> None:
raise NotImplementedError("BaseServerController.run()")
def wait_for_services(self) -> None:
if self.services_up:
# Don't check again if they are already available
return
self.server_controller.wait_for_port()
c = ClientMock(name="chkNS", show_io=True)
c.connect(self.server_controller.hostname, self.server_controller.port)
c.sendLine("NICK chkNS")
c.sendLine("USER chk chk chk chk")
for msg in c.getMessages(synchronize=False):
if msg.command == "PING":
# Hi Unreal
c.sendLine("PONG :" + msg.params[0])
c.getMessages()
timeout = time.time() + 5
while True:
c.sendLine(f"PRIVMSG {self.server_controller.nickserv} :HELP")
msgs = self.getServiceResponse(c)
for msg in msgs:
if msg.command == "401":
# NickServ not available yet
pass
elif msg.command == "NOTICE":
# NickServ is available
assert "nickserv" in (msg.prefix or "").lower(), msg
print("breaking")
break
else:
assert False, f"unexpected reply from NickServ: {msg}"
else:
if time.time() > timeout:
raise Exception("Timeout while waiting for NickServ")
continue
# If we're here, it means we broke from the for loop, so NickServ
# is available and we can break again
break
c.sendLine("QUIT")
c.getMessages()
c.disconnect()
self.services_up = True
def getServiceResponse(self, client: Any) -> List[Message]:
"""Wrapper aroung getMessages() that waits longer, because NickServ
is queried asynchronously."""
msgs: List[Message] = []
while not msgs:
time.sleep(0.05)
msgs = client.getMessages()
return msgs
def registerUser(
self,
case: irctest.cases.BaseServerTestCase, # type: ignore
username: str,
password: Optional[str] = None,
**kwargs: Any,
) -> None:
if not case.run_services:
raise ValueError(
"Attempted to register a nick, but `run_services` it not True."
)
if kwargs:
raise NotImplementedByController(", ".join(kwargs))
assert password
client = case.addClient(show_io=True)
case.sendLine(client, "NICK " + username)
case.sendLine(client, "USER r e g :user")
while case.getRegistrationMessage(client).command != "001":
pass
case.getMessages(client)
case.sendLine(
client,
f"PRIVMSG {self.server_controller.nickserv} "
f":REGISTER {password} foo@example.org",
)
msgs = self.getServiceResponse(case.clients[client])
if self.server_controller.software_name == "inspircd":
assert "900" in {msg.command for msg in msgs}, msgs
assert "NOTICE" in {msg.command for msg in msgs}, msgs
case.sendLine(client, "QUIT")
case.assertDisconnected(client)