From ac2a37362cf724b6a0f3ac0cd5a297d614e2962c Mon Sep 17 00:00:00 2001 From: Valentin Lorentz Date: Sun, 28 Feb 2021 13:40:08 +0100 Subject: [PATCH] Use dataclasses instead of dicts/namedtuples --- irctest/authentication.py | 9 ++--- irctest/basecontrollers.py | 25 ++++++++++++-- irctest/cases.py | 5 +-- irctest/controllers/limnoria.py | 6 ++-- irctest/controllers/oragono.py | 8 ++--- irctest/irc_utils/junkdrawer.py | 13 ------- irctest/irc_utils/message_parser.py | 26 ++++++++++++-- irctest/server_tests/test_chathistory.py | 36 +++++++++++--------- irctest/server_tests/test_confusables.py | 8 ++--- irctest/server_tests/test_lusers.py | 16 ++++----- irctest/server_tests/test_register_verify.py | 32 ++++++++--------- irctest/server_tests/test_relaymsg.py | 4 +-- irctest/server_tests/test_roleplay.py | 4 +-- irctest/server_tests/test_znc_playback.py | 20 +++++------ irctest/tls.py | 9 +++-- 15 files changed, 124 insertions(+), 97 deletions(-) diff --git a/irctest/authentication.py b/irctest/authentication.py index 712fef6..12c9daf 100644 --- a/irctest/authentication.py +++ b/irctest/authentication.py @@ -7,13 +7,8 @@ from typing import Optional, Tuple class Mechanisms(enum.Enum): """Enumeration for representing possible mechanisms.""" - @classmethod - def as_string(cls, mech): - return { - cls.plain: "PLAIN", - cls.ecdsa_nist256p_challenge: "ECDSA-NIST256P-CHALLENGE", - cls.scram_sha_256: "SCRAM-SHA-256", - }[mech] + def to_string(self) -> str: + return self.name.upper().replace("_", "-") plain = 1 ecdsa_nist256p_challenge = 2 diff --git a/irctest/basecontrollers.py b/irctest/basecontrollers.py index 002f4e6..73b4611 100644 --- a/irctest/basecontrollers.py +++ b/irctest/basecontrollers.py @@ -1,10 +1,11 @@ +import dataclasses import os import shutil import socket import subprocess import tempfile import time -from typing import Set +from typing import Any, Callable, Dict, Optional, Set from .runner import NotImplementedByController @@ -15,6 +16,24 @@ class ProcessStopped(Exception): pass +@dataclasses.dataclass +class TestCaseControllerConfig: + """Test-case-specific configuration passed to the controller. + This is usually used to ask controllers to enable a feature; + but should not be an issue if controllers enable it all the time.""" + + chathistory: bool = False + """Whether to enable chathistory features.""" + + oragono_roleplay: bool = False + """Whether to enable the Oragono role-play commands.""" + + oragono_config: Optional[Callable[[Dict], Any]] = None + """Oragono-specific configuration function that alters the dict in-place + This should be used as little as possible, using the other attributes instead; + as they are work with any controller.""" + + class _BaseController: """Base class for software controllers. @@ -22,7 +41,7 @@ class _BaseController: a process (eg. a server or a client), as well as sending it instructions that are not part of the IRC specification.""" - def __init__(self, test_config): + def __init__(self, test_config: TestCaseControllerConfig): self.test_config = test_config self.proc = None @@ -36,7 +55,7 @@ class DirectoryBasedController(_BaseController): """Helper for controllers whose software configuration is based on an arbitrary directory.""" - def __init__(self, test_config): + def __init__(self, test_config: TestCaseControllerConfig): super().__init__(test_config) self.directory = None diff --git a/irctest/cases.py b/irctest/cases.py index 3d5b11b..aa62f01 100644 --- a/irctest/cases.py +++ b/irctest/cases.py @@ -9,6 +9,7 @@ import unittest import pytest from . import basecontrollers, client_mock, runner +from .basecontrollers import TestCaseControllerConfig from .exceptions import ConnectionClosed from .irc_utils import capabilities, message_parser from .irc_utils.junkdrawer import normalizeWhitespace @@ -48,12 +49,12 @@ class _IrcTestCase(unittest.TestCase): controllerClass = None # Will be set by __main__.py @staticmethod - def config(): + def config() -> TestCaseControllerConfig: """Some configuration to pass to the controllers. For example, Oragono only enables its MySQL support if config()["chathistory"]=True. """ - return {} + return TestCaseControllerConfig() def description(self): method_doc = self._testMethodDoc diff --git a/irctest/controllers/limnoria.py b/irctest/controllers/limnoria.py index aad91eb..4a38d5d 100644 --- a/irctest/controllers/limnoria.py +++ b/irctest/controllers/limnoria.py @@ -1,7 +1,7 @@ import os import subprocess -from irctest import authentication, tls +from irctest import tls from irctest.basecontrollers import BaseClientController, DirectoryBasedController TEMPLATE_CONFIG = """ @@ -50,9 +50,7 @@ class LimnoriaController(BaseClientController, DirectoryBasedController): assert self.proc is None self.create_config() if auth: - mechanisms = " ".join( - map(authentication.Mechanisms.as_string, auth.mechanisms) - ) + mechanisms = " ".join(mech.to_string() for mech in auth.mechanisms) if auth.ecdsa_key: with self.open_file("ecdsa_key.pem") as fd: fd.write(auth.ecdsa_key) diff --git a/irctest/controllers/oragono.py b/irctest/controllers/oragono.py index 1a22a89..b2082e2 100644 --- a/irctest/controllers/oragono.py +++ b/irctest/controllers/oragono.py @@ -162,16 +162,16 @@ class OragonoController(BaseServerController, DirectoryBasedController): if config is None: config = copy.deepcopy(BASE_CONFIG) - enable_chathistory = self.test_config.get("chathistory") - enable_roleplay = self.test_config.get("oragono_roleplay") + enable_chathistory = self.test_config.chathistory + enable_roleplay = self.test_config.oragono_roleplay if enable_chathistory or enable_roleplay: config = self.addMysqlToConfig(config) if enable_roleplay: config["roleplay"] = {"enabled": True} - if "oragono_config" in self.test_config: - self.test_config["oragono_config"](config) + if self.test_config.oragono_config: + self.test_config.oragono_config(config) self.port = port bind_address = "127.0.0.1:%s" % (port,) diff --git a/irctest/irc_utils/junkdrawer.py b/irctest/irc_utils/junkdrawer.py index 2c4bf22..dc39738 100644 --- a/irctest/irc_utils/junkdrawer.py +++ b/irctest/irc_utils/junkdrawer.py @@ -1,20 +1,7 @@ -from collections import namedtuple import datetime import re import secrets -HistoryMessage = namedtuple("HistoryMessage", ["time", "msgid", "target", "text"]) - - -def to_history_message(msg): - return HistoryMessage( - time=msg.tags.get("time"), - msgid=msg.tags.get("msgid"), - target=msg.params[0], - text=msg.params[1], - ) - - # thanks jess! IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z" diff --git a/irctest/irc_utils/message_parser.py b/irctest/irc_utils/message_parser.py index ea1f482..22f2b99 100644 --- a/irctest/irc_utils/message_parser.py +++ b/irctest/irc_utils/message_parser.py @@ -1,5 +1,6 @@ -import collections +import dataclasses import re +from typing import Any, Dict, List, Optional from .junkdrawer import MultipleReplacer @@ -29,7 +30,28 @@ def parse_tags(s): return tags -Message = collections.namedtuple("Message", "tags prefix command params") +@dataclasses.dataclass(frozen=True) +class HistoryMessage: + time: Any + msgid: Optional[str] + target: str + text: str + + +@dataclasses.dataclass(frozen=True) +class Message: + tags: Dict[str, Optional[str]] + prefix: Optional[str] + command: str + params: List[str] + + def to_history_message(self) -> HistoryMessage: + return HistoryMessage( + time=self.tags.get("time"), + msgid=self.tags.get("msgid"), + target=self.params[0], + text=self.params[1], + ) def parse_message(s): diff --git a/irctest/server_tests/test_chathistory.py b/irctest/server_tests/test_chathistory.py index 4895c87..6c69c0a 100644 --- a/irctest/server_tests/test_chathistory.py +++ b/irctest/server_tests/test_chathistory.py @@ -2,7 +2,7 @@ import secrets import time from irctest import cases -from irctest.irc_utils.junkdrawer import random_name, to_history_message +from irctest.irc_utils.junkdrawer import random_name CHATHISTORY_CAP = "draft/chathistory" EVENT_PLAYBACK_CAP = "draft/event-playback" @@ -27,15 +27,15 @@ def validate_chathistory_batch(msgs): and batch_tag is not None and msg.tags.get("batch") == batch_tag ): - result.append(to_history_message(msg)) + result.append(msg.to_history_message()) assert batch_tag == closed_batch_tag return result class ChathistoryTestCase(cases.BaseServerTestCase): @staticmethod - def config(): - return {"chathistory": True} + def config() -> cases.TestCaseControllerConfig: + return cases.TestCaseControllerConfig(chathistory=True) @cases.mark_specifications("Oragono") def testInvalidTargets(self): @@ -93,7 +93,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): self.assertEqual(len(replies), 1) msg = replies[0] self.assertEqual(msg.params, [bar, "this is a privmsg sent to myself"]) - messages.append(to_history_message(msg)) + messages.append(msg.to_history_message()) self.sendLine(bar, "CAP REQ echo-message") self.getMessages(bar) @@ -106,9 +106,11 @@ class ChathistoryTestCase(cases.BaseServerTestCase): self.assertEqual( replies[0].params, [bar, "this is a second privmsg sent to myself"] ) - messages.append(to_history_message(replies[0])) + messages.append(replies[0].to_history_message()) # messages should be otherwise identical - self.assertEqual(to_history_message(replies[0]), to_history_message(replies[1])) + self.assertEqual( + replies[0].to_history_message(), replies[1].to_history_message() + ) self.sendLine( bar, @@ -120,17 +122,17 @@ class ChathistoryTestCase(cases.BaseServerTestCase): echo = [msg for msg in replies if msg.tags.get("label") == "xyz"][0] delivery = [msg for msg in replies if msg.tags.get("label") is None][0] self.assertEqual(echo.params, [bar, "this is a third privmsg sent to myself"]) - messages.append(to_history_message(echo)) - self.assertEqual(to_history_message(echo), to_history_message(delivery)) + messages.append(echo.to_history_message()) + self.assertEqual(echo.to_history_message(), delivery.to_history_message()) # should receive exactly 3 messages in the correct order, no duplicates self.sendLine(bar, "CHATHISTORY LATEST * * 10") replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"] - self.assertEqual([to_history_message(msg) for msg in replies], messages) + self.assertEqual([msg.to_history_message() for msg in replies], messages) self.sendLine(bar, "CHATHISTORY LATEST %s * 10" % (bar,)) replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"] - self.assertEqual([to_history_message(msg) for msg in replies], messages) + self.assertEqual([msg.to_history_message() for msg in replies], messages) def validate_echo_messages(self, num_messages, echo_messages): # sanity checks: should have received the correct number of echo messages, @@ -161,7 +163,9 @@ class ChathistoryTestCase(cases.BaseServerTestCase): echo_messages = [] for i in range(NUM_MESSAGES): self.sendLine(1, "PRIVMSG %s :this is message %d" % (chname, i)) - echo_messages.extend(to_history_message(msg) for msg in self.getMessages(1)) + echo_messages.extend( + msg.to_history_message() for msg in self.getMessages(1) + ) time.sleep(0.002) self.validate_echo_messages(NUM_MESSAGES, echo_messages) @@ -213,7 +217,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): self.getMessages(user) self.sendLine(user, "PRIVMSG %s :this is message %d" % (target, i)) echo_messages.extend( - to_history_message(msg) for msg in self.getMessages(user) + msg.to_history_message() for msg in self.getMessages(user) ) time.sleep(0.002) @@ -245,7 +249,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): ) # 3 received the first message as a delivery and the second as an echo new_convo = [ - to_history_message(msg) + msg.to_history_message() for msg in self.getMessages(3) if msg.command == "PRIVMSG" ] @@ -262,7 +266,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): self.getMessages(1) self.sendLine(1, "CHATHISTORY LATEST %s * 10" % (c3,)) results = [ - to_history_message(msg) + msg.to_history_message() for msg in self.getMessages(1) if msg.command == "PRIVMSG" ] @@ -296,7 +300,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): self.getMessages(c3) self.sendLine(c3, "CHATHISTORY LATEST %s * 10" % (c1,)) results = [ - to_history_message(msg) + msg.to_history_message() for msg in self.getMessages(c3) if msg.command == "PRIVMSG" ] diff --git a/irctest/server_tests/test_confusables.py b/irctest/server_tests/test_confusables.py index 2523f57..4818878 100644 --- a/irctest/server_tests/test_confusables.py +++ b/irctest/server_tests/test_confusables.py @@ -4,12 +4,12 @@ from irctest.numerics import ERR_NICKNAMEINUSE, RPL_WELCOME class ConfusablesTestCase(cases.BaseServerTestCase): @staticmethod - def config(): - return { - "oragono_config": lambda config: config["accounts"].update( + def config() -> cases.TestCaseControllerConfig: + return cases.TestCaseControllerConfig( + oragono_config=lambda config: config["accounts"].update( {"nick-reservation": {"enabled": True, "method": "strict"}} ) - } + ) @cases.mark_specifications("Oragono") def testConfusableNicks(self): diff --git a/irctest/server_tests/test_lusers.py b/irctest/server_tests/test_lusers.py index 2f233e6..25f35a9 100644 --- a/irctest/server_tests/test_lusers.py +++ b/irctest/server_tests/test_lusers.py @@ -178,12 +178,12 @@ class LusersUnregisteredDefaultInvisibleTest(LusersUnregisteredTestCase): """Same as above but with +i as the default.""" @staticmethod - def config(): - return { - "oragono_config": lambda config: config["accounts"].update( + def config() -> cases.TestCaseControllerConfig: + return cases.TestCaseControllerConfig( + oragono_config=lambda config: config["accounts"].update( {"default-user-modes": "+i"} ) - } + ) @cases.mark_specifications("Oragono") def testLusers(self): @@ -236,12 +236,12 @@ class LuserOpersTest(LusersTestCase): class OragonoInvisibleDefaultTest(LusersTestCase): @staticmethod - def config(): - return { - "oragono_config": lambda config: config["accounts"].update( + def config() -> cases.TestCaseControllerConfig: + return cases.TestCaseControllerConfig( + oragono_config=lambda config: config["accounts"].update( {"default-user-modes": "+i"} ) - } + ) @cases.mark_specifications("Oragono") def testLusers(self): diff --git a/irctest/server_tests/test_register_verify.py b/irctest/server_tests/test_register_verify.py index 149bba4..7d3d8b2 100644 --- a/irctest/server_tests/test_register_verify.py +++ b/irctest/server_tests/test_register_verify.py @@ -5,12 +5,12 @@ REGISTER_CAP_NAME = "draft/register" class TestRegisterBeforeConnect(cases.BaseServerTestCase): @staticmethod - def config(): - return { - "oragono_config": lambda config: config["accounts"]["registration"].update( + def config() -> cases.TestCaseControllerConfig: + return cases.TestCaseControllerConfig( + oragono_config=lambda config: config["accounts"]["registration"].update( {"allow-before-connect": True} ) - } + ) @cases.mark_specifications("Oragono") def testBeforeConnect(self): @@ -28,12 +28,12 @@ class TestRegisterBeforeConnect(cases.BaseServerTestCase): class TestRegisterBeforeConnectDisallowed(cases.BaseServerTestCase): @staticmethod - def config(): - return { - "oragono_config": lambda config: config["accounts"]["registration"].update( + def config() -> cases.TestCaseControllerConfig: + return cases.TestCaseControllerConfig( + oragono_config=lambda config: config["accounts"]["registration"].update( {"allow-before-connect": False} ) - } + ) @cases.mark_specifications("Oragono") def testBeforeConnect(self): @@ -51,9 +51,9 @@ class TestRegisterBeforeConnectDisallowed(cases.BaseServerTestCase): class TestRegisterEmailVerified(cases.BaseServerTestCase): @staticmethod - def config(): - return { - "oragono_config": lambda config: config["accounts"]["registration"].update( + def config() -> cases.TestCaseControllerConfig: + return cases.TestCaseControllerConfig( + oragono_config=lambda config: config["accounts"]["registration"].update( { "email-verification": { "enabled": True, @@ -64,7 +64,7 @@ class TestRegisterEmailVerified(cases.BaseServerTestCase): "allow-before-connect": True, } ) - } + ) @cases.mark_specifications("Oragono") def testBeforeConnect(self): @@ -93,12 +93,12 @@ class TestRegisterEmailVerified(cases.BaseServerTestCase): class TestRegisterNoLandGrabs(cases.BaseServerTestCase): @staticmethod - def config(): - return { - "oragono_config": lambda config: config["accounts"]["registration"].update( + def config() -> cases.TestCaseControllerConfig: + return cases.TestCaseControllerConfig( + oragono_config=lambda config: config["accounts"]["registration"].update( {"allow-before-connect": True} ) - } + ) @cases.mark_specifications("Oragono") def testBeforeConnect(self): diff --git a/irctest/server_tests/test_relaymsg.py b/irctest/server_tests/test_relaymsg.py index a84cdd8..7d7ace9 100644 --- a/irctest/server_tests/test_relaymsg.py +++ b/irctest/server_tests/test_relaymsg.py @@ -8,8 +8,8 @@ RELAYMSG_TAG_NAME = "draft/relaymsg" class RelaymsgTestCase(cases.BaseServerTestCase): @staticmethod - def config(): - return {"chathistory": True} + def config() -> cases.TestCaseControllerConfig: + return cases.TestCaseControllerConfig(chathistory=True) @cases.mark_specifications("Oragono") def testRelaymsg(self): diff --git a/irctest/server_tests/test_roleplay.py b/irctest/server_tests/test_roleplay.py index 70f0e7b..7837005 100644 --- a/irctest/server_tests/test_roleplay.py +++ b/irctest/server_tests/test_roleplay.py @@ -5,8 +5,8 @@ from irctest.numerics import ERR_CANNOTSENDRP class RoleplayTestCase(cases.BaseServerTestCase): @staticmethod - def config(): - return {"oragono_roleplay": True} + def config() -> cases.TestCaseControllerConfig: + return cases.TestCaseControllerConfig(oragono_roleplay=True) @cases.mark_specifications("Oragono") def testRoleplay(self): diff --git a/irctest/server_tests/test_znc_playback.py b/irctest/server_tests/test_znc_playback.py index a065e69..88a2100 100644 --- a/irctest/server_tests/test_znc_playback.py +++ b/irctest/server_tests/test_znc_playback.py @@ -1,11 +1,7 @@ import time from irctest import cases -from irctest.irc_utils.junkdrawer import ( - ircv3_timestamp_to_unixtime, - random_name, - to_history_message, -) +from irctest.irc_utils.junkdrawer import ircv3_timestamp_to_unixtime, random_name def extract_playback_privmsgs(messages): @@ -13,14 +9,14 @@ def extract_playback_privmsgs(messages): result = [] for msg in messages: if msg.command == "PRIVMSG" and msg.params[0].lower() != "*playback": - result.append(to_history_message(msg)) + result.append(msg.to_history_message()) return result class ZncPlaybackTestCase(cases.BaseServerTestCase): @staticmethod - def config(): - return {"chathistory": True} + def config() -> cases.TestCaseControllerConfig: + return cases.TestCaseControllerConfig(chathistory=True) @cases.mark_specifications("Oragono") def testZncPlayback(self): @@ -58,9 +54,9 @@ class ZncPlaybackTestCase(cases.BaseServerTestCase): self.joinChannel(qux, chname) self.sendLine(qux, "PRIVMSG %s :hi there" % (bar,)) - dm = to_history_message( - [msg for msg in self.getMessages(qux) if msg.command == "PRIVMSG"][0] - ) + dm = [msg for msg in self.getMessages(qux) if msg.command == "PRIVMSG"][ + 0 + ].to_history_message() self.assertEqual(dm.text, "hi there") NUM_MESSAGES = 10 @@ -68,7 +64,7 @@ class ZncPlaybackTestCase(cases.BaseServerTestCase): for i in range(NUM_MESSAGES): self.sendLine(qux, "PRIVMSG %s :this is message %d" % (chname, i)) echo_messages.extend( - to_history_message(msg) + msg.to_history_message() for msg in self.getMessages(qux) if msg.command == "PRIVMSG" ) diff --git a/irctest/tls.py b/irctest/tls.py index 89ff454..4ef18ef 100644 --- a/irctest/tls.py +++ b/irctest/tls.py @@ -1,3 +1,8 @@ -import collections +import dataclasses +from typing import List -TlsConfig = collections.namedtuple("TlsConfig", "enable trusted_fingerprints") + +@dataclasses.dataclass +class TlsConfig: + enable: bool + trusted_fingerprints: List[str]