diff --git a/irctest/basecontrollers.py b/irctest/basecontrollers.py index f4bd2a4..8ea9494 100644 --- a/irctest/basecontrollers.py +++ b/irctest/basecontrollers.py @@ -7,11 +7,13 @@ import socket import subprocess import tempfile import time -from typing import IO, Any, Callable, Dict, Optional, Set +from typing import IO, Any, Callable, Dict, List, Optional, Set import irctest from . import authentication, tls +from .client_mock import ClientMock +from .irc_utils.message_parser import Message from .runner import NotImplementedByController @@ -179,6 +181,7 @@ class BaseServerController(_BaseController): port_open = False port: int hostname: str + services_controller: BaseServicesController def run( self, @@ -199,7 +202,10 @@ class BaseServerController(_BaseController): username: str, password: Optional[str] = None, ) -> None: - raise NotImplementedByController("account registration") + if self.services_controller: + self.services_controller.registerUser(case, username, password) + else: + raise NotImplementedByController("account registration") def wait_for_port(self) -> None: while not self.port_open: @@ -223,4 +229,73 @@ class BaseServerController(_BaseController): continue def wait_for_services(self) -> None: - pass + self.services_controller.wait_for_services() + + +class BaseServicesController(_BaseController): + def __init__( + self, + test_config: TestCaseControllerConfig, + server_controller: BaseServerController, + ): + super().__init__(test_config) + self.test_config = test_config + self.server_controller = server_controller + + def wait_for_services(self) -> None: + self.server_controller.wait_for_port() + + c = ClientMock(name="chkNS", show_io=True) + c.connect(self.server_controller.hostname, self.server_controller.port) + c.sendLine("NICK chkNS") + c.sendLine("USER chk chk chk chk") + c.getMessages(synchronize=False) + + msgs: List[Message] = [] + while not msgs: + c.sendLine("PRIVMSG NickServ :HELP") + msgs = self.getNickServResponse(c) + if msgs[0].command == "401": + # NickServ not available yet + pass + elif msgs[0].command == "NOTICE": + # NickServ is available + assert "nickserv" in (msgs[0].prefix or "").lower(), msgs + else: + assert False, f"unexpected reply from NickServ: {msgs[0]}" + + c.sendLine("QUIT") + c.getMessages() + c.disconnect() + + def getNickServResponse(self, client: Any) -> List[Message]: + """Wrapper aroung getMessages() that waits longer, because NickServ + is queried asynchronously.""" + msgs: List[Message] = [] + while not msgs: + time.sleep(0.05) + msgs = client.getMessages() + return msgs + + def registerUser( + self, + case: irctest.cases.BaseServerTestCase, # type: ignore + username: str, + password: Optional[str] = None, + ) -> None: + if not case.run_services: + raise ValueError( + "Attempted to register a nick, but `run_services` it not True." + ) + assert password + client = case.addClient(show_io=True) + case.sendLine(client, "NICK " + username) + case.sendLine(client, "USER r e g :user") + while case.getRegistrationMessage(client).command != "001": + pass + case.getMessages(client) + case.sendLine(client, f"PRIVMSG NickServ :REGISTER {password} foo@example.org") + msgs = self.getNickServResponse(case.clients[client]) + assert "900" in {msg.command for msg in msgs}, msgs + case.sendLine(client, "QUIT") + case.assertDisconnected(client) diff --git a/irctest/controllers/atheme_services.py b/irctest/controllers/atheme_services.py index 9454092..bb96d63 100644 --- a/irctest/controllers/atheme_services.py +++ b/irctest/controllers/atheme_services.py @@ -1,19 +1,10 @@ import os import subprocess -import time -from typing import IO, Any, List, Optional - -try: - from typing import Protocol -except ImportError: - # Python < 3.8 - from typing_extensions import Protocol # type: ignore +from typing import Optional import irctest -from irctest.basecontrollers import DirectoryBasedController +from irctest.basecontrollers import BaseServicesController, DirectoryBasedController import irctest.cases -from irctest.client_mock import ClientMock -from irctest.irc_utils.message_parser import Message import irctest.runner TEMPLATE_CONFIG = """ @@ -61,32 +52,12 @@ saslserv {{ """ -class _Controller(Protocol): - # Magic class to make mypy accept AthemeServices as a mixin without actually - # inheriting. - directory: Optional[str] - hostname: str - port: int - services_proc: subprocess.Popen - - def wait_for_port(self) -> None: - ... - - def open_file(self, name: str, mode: str = "a") -> IO: - ... - - def getNickServResponse(self, client: Any) -> List[Message]: - ... - - -class AthemeServices(DirectoryBasedController): +class AthemeServices(BaseServicesController, DirectoryBasedController): """Mixin for server controllers that rely on Atheme""" - def __init__(self, *args, **kwargs): # type: ignore - super().__init__(*args, **kwargs) - self.services_proc = None + def run(self, server_hostname: str, server_port: int) -> None: + self.create_config() - def run_services(self: _Controller, server_hostname: str, server_port: int) -> None: with self.open_file("services.conf") as fd: fd.write( TEMPLATE_CONFIG.format( @@ -96,7 +67,7 @@ class AthemeServices(DirectoryBasedController): ) assert self.directory - self.services_proc = subprocess.Popen( + self.proc = subprocess.Popen( [ "atheme-services", "-n", # don't fork @@ -113,70 +84,16 @@ class AthemeServices(DirectoryBasedController): stderr=subprocess.DEVNULL, ) - def kill_proc(self) -> None: - super().kill_proc() - if self.services_proc is not None: - self.services_proc.kill() - self.services_proc = None - - def wait_for_services(self: _Controller) -> None: - self.wait_for_port() - - c = ClientMock(name="chkNS", show_io=True) - c.connect(self.hostname, self.port) - c.sendLine("NICK chkNS") - c.sendLine("USER chk chk chk chk") - c.getMessages(synchronize=False) - - msgs: List[Message] = [] - while not msgs: - c.sendLine("PRIVMSG NickServ :HELP") - msgs = self.getNickServResponse(c) - if msgs[0].command == "401": - # NickServ not available yet - pass - elif msgs[0].command == "NOTICE": - # NickServ is available - assert "nickserv" in (msgs[0].prefix or "").lower(), msgs - else: - assert False, f"unexpected reply from NickServ: {msgs[0]}" - - c.sendLine("QUIT") - c.getMessages() - c.disconnect() - - def getNickServResponse(self, client: Any) -> List[Message]: - """Wrapper aroung getMessages() that waits longer, because NickServ - is queried asynchronously.""" - msgs: List[Message] = [] - while not msgs: - time.sleep(0.05) - msgs = client.getMessages() - return msgs - def registerUser( self, case: irctest.cases.BaseServerTestCase, username: str, password: Optional[str] = None, ) -> None: - if not case.run_services: - raise ValueError( - "Attempted to register a nick, but `run_services` it not True." - ) assert password if len(password.encode()) > 288: # It's hardcoded at compile-time :( # https://github.com/atheme/atheme/blob/4fa0e03bd3ce2cb6041a339f308616580c5aac29/include/atheme/constants.h#L51 raise irctest.runner.NotImplementedByController("Passwords over 288 bytes") - client = case.addClient(show_io=True) - case.sendLine(client, "NICK " + username) - case.sendLine(client, "USER r e g :user") - while case.getRegistrationMessage(client).command != "001": - pass - case.getMessages(client) - case.sendLine(client, f"PRIVMSG NickServ :REGISTER {password} foo@example.org") - msgs = self.getNickServResponse(case.clients[client]) - assert "900" in {msg.command for msg in msgs}, msgs - case.sendLine(client, "QUIT") - case.assertDisconnected(client) + + super().registerUser(case, username, password) diff --git a/irctest/controllers/ergo.py b/irctest/controllers/ergo.py index 58dfa22..3bf77b7 100644 --- a/irctest/controllers/ergo.py +++ b/irctest/controllers/ergo.py @@ -206,6 +206,10 @@ class ErgoController(BaseServerController, DirectoryBasedController): ["ergo", "run", "--conf", self._config_path, "--quiet"] ) + def wait_for_services(self) -> None: + # Nothing to wait for, they start at the same time as Ergo. + pass + def registerUser( self, case: BaseServerTestCase, diff --git a/irctest/controllers/inspircd.py b/irctest/controllers/inspircd.py index fccdb54..809823b 100644 --- a/irctest/controllers/inspircd.py +++ b/irctest/controllers/inspircd.py @@ -63,9 +63,7 @@ TEMPLATE_SSL_CONFIG = """ """ -class InspircdController( - AthemeServices, BaseServerController, DirectoryBasedController -): +class InspircdController(BaseServerController, DirectoryBasedController): software_name = "InspIRCd" supported_sasl_mechanisms = {"PLAIN"} supports_sts = False @@ -130,7 +128,8 @@ class InspircdController( ) if run_services: - self.run_services( + self.services_controller = AthemeServices(self.test_config, self) + self.services_controller.run( server_hostname=services_hostname, server_port=services_port )