Use dataclasses instead of dicts/namedtuples

This commit is contained in:
Valentin Lorentz 2021-02-28 13:40:08 +01:00 committed by Valentin Lorentz
parent 12da7e1e3b
commit ac2a37362c
15 changed files with 124 additions and 97 deletions

View File

@ -7,13 +7,8 @@ from typing import Optional, Tuple
class Mechanisms(enum.Enum):
"""Enumeration for representing possible mechanisms."""
@classmethod
def as_string(cls, mech):
return {
cls.plain: "PLAIN",
cls.ecdsa_nist256p_challenge: "ECDSA-NIST256P-CHALLENGE",
cls.scram_sha_256: "SCRAM-SHA-256",
}[mech]
def to_string(self) -> str:
return self.name.upper().replace("_", "-")
plain = 1
ecdsa_nist256p_challenge = 2

View File

@ -1,10 +1,11 @@
import dataclasses
import os
import shutil
import socket
import subprocess
import tempfile
import time
from typing import Set
from typing import Any, Callable, Dict, Optional, Set
from .runner import NotImplementedByController
@ -15,6 +16,24 @@ class ProcessStopped(Exception):
pass
@dataclasses.dataclass
class TestCaseControllerConfig:
"""Test-case-specific configuration passed to the controller.
This is usually used to ask controllers to enable a feature;
but should not be an issue if controllers enable it all the time."""
chathistory: bool = False
"""Whether to enable chathistory features."""
oragono_roleplay: bool = False
"""Whether to enable the Oragono role-play commands."""
oragono_config: Optional[Callable[[Dict], Any]] = None
"""Oragono-specific configuration function that alters the dict in-place
This should be used as little as possible, using the other attributes instead;
as they are work with any controller."""
class _BaseController:
"""Base class for software controllers.
@ -22,7 +41,7 @@ class _BaseController:
a process (eg. a server or a client), as well as sending it instructions
that are not part of the IRC specification."""
def __init__(self, test_config):
def __init__(self, test_config: TestCaseControllerConfig):
self.test_config = test_config
self.proc = None
@ -36,7 +55,7 @@ class DirectoryBasedController(_BaseController):
"""Helper for controllers whose software configuration is based on an
arbitrary directory."""
def __init__(self, test_config):
def __init__(self, test_config: TestCaseControllerConfig):
super().__init__(test_config)
self.directory = None

View File

@ -9,6 +9,7 @@ import unittest
import pytest
from . import basecontrollers, client_mock, runner
from .basecontrollers import TestCaseControllerConfig
from .exceptions import ConnectionClosed
from .irc_utils import capabilities, message_parser
from .irc_utils.junkdrawer import normalizeWhitespace
@ -48,12 +49,12 @@ class _IrcTestCase(unittest.TestCase):
controllerClass = None # Will be set by __main__.py
@staticmethod
def config():
def config() -> TestCaseControllerConfig:
"""Some configuration to pass to the controllers.
For example, Oragono only enables its MySQL support if
config()["chathistory"]=True.
"""
return {}
return TestCaseControllerConfig()
def description(self):
method_doc = self._testMethodDoc

View File

@ -1,7 +1,7 @@
import os
import subprocess
from irctest import authentication, tls
from irctest import tls
from irctest.basecontrollers import BaseClientController, DirectoryBasedController
TEMPLATE_CONFIG = """
@ -50,9 +50,7 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
assert self.proc is None
self.create_config()
if auth:
mechanisms = " ".join(
map(authentication.Mechanisms.as_string, auth.mechanisms)
)
mechanisms = " ".join(mech.to_string() for mech in auth.mechanisms)
if auth.ecdsa_key:
with self.open_file("ecdsa_key.pem") as fd:
fd.write(auth.ecdsa_key)

View File

@ -162,16 +162,16 @@ class OragonoController(BaseServerController, DirectoryBasedController):
if config is None:
config = copy.deepcopy(BASE_CONFIG)
enable_chathistory = self.test_config.get("chathistory")
enable_roleplay = self.test_config.get("oragono_roleplay")
enable_chathistory = self.test_config.chathistory
enable_roleplay = self.test_config.oragono_roleplay
if enable_chathistory or enable_roleplay:
config = self.addMysqlToConfig(config)
if enable_roleplay:
config["roleplay"] = {"enabled": True}
if "oragono_config" in self.test_config:
self.test_config["oragono_config"](config)
if self.test_config.oragono_config:
self.test_config.oragono_config(config)
self.port = port
bind_address = "127.0.0.1:%s" % (port,)

View File

@ -1,20 +1,7 @@
from collections import namedtuple
import datetime
import re
import secrets
HistoryMessage = namedtuple("HistoryMessage", ["time", "msgid", "target", "text"])
def to_history_message(msg):
return HistoryMessage(
time=msg.tags.get("time"),
msgid=msg.tags.get("msgid"),
target=msg.params[0],
text=msg.params[1],
)
# thanks jess!
IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z"

View File

@ -1,5 +1,6 @@
import collections
import dataclasses
import re
from typing import Any, Dict, List, Optional
from .junkdrawer import MultipleReplacer
@ -29,7 +30,28 @@ def parse_tags(s):
return tags
Message = collections.namedtuple("Message", "tags prefix command params")
@dataclasses.dataclass(frozen=True)
class HistoryMessage:
time: Any
msgid: Optional[str]
target: str
text: str
@dataclasses.dataclass(frozen=True)
class Message:
tags: Dict[str, Optional[str]]
prefix: Optional[str]
command: str
params: List[str]
def to_history_message(self) -> HistoryMessage:
return HistoryMessage(
time=self.tags.get("time"),
msgid=self.tags.get("msgid"),
target=self.params[0],
text=self.params[1],
)
def parse_message(s):

View File

@ -2,7 +2,7 @@ import secrets
import time
from irctest import cases
from irctest.irc_utils.junkdrawer import random_name, to_history_message
from irctest.irc_utils.junkdrawer import random_name
CHATHISTORY_CAP = "draft/chathistory"
EVENT_PLAYBACK_CAP = "draft/event-playback"
@ -27,15 +27,15 @@ def validate_chathistory_batch(msgs):
and batch_tag is not None
and msg.tags.get("batch") == batch_tag
):
result.append(to_history_message(msg))
result.append(msg.to_history_message())
assert batch_tag == closed_batch_tag
return result
class ChathistoryTestCase(cases.BaseServerTestCase):
@staticmethod
def config():
return {"chathistory": True}
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(chathistory=True)
@cases.mark_specifications("Oragono")
def testInvalidTargets(self):
@ -93,7 +93,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.assertEqual(len(replies), 1)
msg = replies[0]
self.assertEqual(msg.params, [bar, "this is a privmsg sent to myself"])
messages.append(to_history_message(msg))
messages.append(msg.to_history_message())
self.sendLine(bar, "CAP REQ echo-message")
self.getMessages(bar)
@ -106,9 +106,11 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.assertEqual(
replies[0].params, [bar, "this is a second privmsg sent to myself"]
)
messages.append(to_history_message(replies[0]))
messages.append(replies[0].to_history_message())
# messages should be otherwise identical
self.assertEqual(to_history_message(replies[0]), to_history_message(replies[1]))
self.assertEqual(
replies[0].to_history_message(), replies[1].to_history_message()
)
self.sendLine(
bar,
@ -120,17 +122,17 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
echo = [msg for msg in replies if msg.tags.get("label") == "xyz"][0]
delivery = [msg for msg in replies if msg.tags.get("label") is None][0]
self.assertEqual(echo.params, [bar, "this is a third privmsg sent to myself"])
messages.append(to_history_message(echo))
self.assertEqual(to_history_message(echo), to_history_message(delivery))
messages.append(echo.to_history_message())
self.assertEqual(echo.to_history_message(), delivery.to_history_message())
# should receive exactly 3 messages in the correct order, no duplicates
self.sendLine(bar, "CHATHISTORY LATEST * * 10")
replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"]
self.assertEqual([to_history_message(msg) for msg in replies], messages)
self.assertEqual([msg.to_history_message() for msg in replies], messages)
self.sendLine(bar, "CHATHISTORY LATEST %s * 10" % (bar,))
replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"]
self.assertEqual([to_history_message(msg) for msg in replies], messages)
self.assertEqual([msg.to_history_message() for msg in replies], messages)
def validate_echo_messages(self, num_messages, echo_messages):
# sanity checks: should have received the correct number of echo messages,
@ -161,7 +163,9 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
echo_messages = []
for i in range(NUM_MESSAGES):
self.sendLine(1, "PRIVMSG %s :this is message %d" % (chname, i))
echo_messages.extend(to_history_message(msg) for msg in self.getMessages(1))
echo_messages.extend(
msg.to_history_message() for msg in self.getMessages(1)
)
time.sleep(0.002)
self.validate_echo_messages(NUM_MESSAGES, echo_messages)
@ -213,7 +217,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.getMessages(user)
self.sendLine(user, "PRIVMSG %s :this is message %d" % (target, i))
echo_messages.extend(
to_history_message(msg) for msg in self.getMessages(user)
msg.to_history_message() for msg in self.getMessages(user)
)
time.sleep(0.002)
@ -245,7 +249,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
)
# 3 received the first message as a delivery and the second as an echo
new_convo = [
to_history_message(msg)
msg.to_history_message()
for msg in self.getMessages(3)
if msg.command == "PRIVMSG"
]
@ -262,7 +266,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.getMessages(1)
self.sendLine(1, "CHATHISTORY LATEST %s * 10" % (c3,))
results = [
to_history_message(msg)
msg.to_history_message()
for msg in self.getMessages(1)
if msg.command == "PRIVMSG"
]
@ -296,7 +300,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.getMessages(c3)
self.sendLine(c3, "CHATHISTORY LATEST %s * 10" % (c1,))
results = [
to_history_message(msg)
msg.to_history_message()
for msg in self.getMessages(c3)
if msg.command == "PRIVMSG"
]

View File

@ -4,12 +4,12 @@ from irctest.numerics import ERR_NICKNAMEINUSE, RPL_WELCOME
class ConfusablesTestCase(cases.BaseServerTestCase):
@staticmethod
def config():
return {
"oragono_config": lambda config: config["accounts"].update(
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(
oragono_config=lambda config: config["accounts"].update(
{"nick-reservation": {"enabled": True, "method": "strict"}}
)
}
)
@cases.mark_specifications("Oragono")
def testConfusableNicks(self):

View File

@ -178,12 +178,12 @@ class LusersUnregisteredDefaultInvisibleTest(LusersUnregisteredTestCase):
"""Same as above but with +i as the default."""
@staticmethod
def config():
return {
"oragono_config": lambda config: config["accounts"].update(
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(
oragono_config=lambda config: config["accounts"].update(
{"default-user-modes": "+i"}
)
}
)
@cases.mark_specifications("Oragono")
def testLusers(self):
@ -236,12 +236,12 @@ class LuserOpersTest(LusersTestCase):
class OragonoInvisibleDefaultTest(LusersTestCase):
@staticmethod
def config():
return {
"oragono_config": lambda config: config["accounts"].update(
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(
oragono_config=lambda config: config["accounts"].update(
{"default-user-modes": "+i"}
)
}
)
@cases.mark_specifications("Oragono")
def testLusers(self):

View File

@ -5,12 +5,12 @@ REGISTER_CAP_NAME = "draft/register"
class TestRegisterBeforeConnect(cases.BaseServerTestCase):
@staticmethod
def config():
return {
"oragono_config": lambda config: config["accounts"]["registration"].update(
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(
oragono_config=lambda config: config["accounts"]["registration"].update(
{"allow-before-connect": True}
)
}
)
@cases.mark_specifications("Oragono")
def testBeforeConnect(self):
@ -28,12 +28,12 @@ class TestRegisterBeforeConnect(cases.BaseServerTestCase):
class TestRegisterBeforeConnectDisallowed(cases.BaseServerTestCase):
@staticmethod
def config():
return {
"oragono_config": lambda config: config["accounts"]["registration"].update(
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(
oragono_config=lambda config: config["accounts"]["registration"].update(
{"allow-before-connect": False}
)
}
)
@cases.mark_specifications("Oragono")
def testBeforeConnect(self):
@ -51,9 +51,9 @@ class TestRegisterBeforeConnectDisallowed(cases.BaseServerTestCase):
class TestRegisterEmailVerified(cases.BaseServerTestCase):
@staticmethod
def config():
return {
"oragono_config": lambda config: config["accounts"]["registration"].update(
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(
oragono_config=lambda config: config["accounts"]["registration"].update(
{
"email-verification": {
"enabled": True,
@ -64,7 +64,7 @@ class TestRegisterEmailVerified(cases.BaseServerTestCase):
"allow-before-connect": True,
}
)
}
)
@cases.mark_specifications("Oragono")
def testBeforeConnect(self):
@ -93,12 +93,12 @@ class TestRegisterEmailVerified(cases.BaseServerTestCase):
class TestRegisterNoLandGrabs(cases.BaseServerTestCase):
@staticmethod
def config():
return {
"oragono_config": lambda config: config["accounts"]["registration"].update(
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(
oragono_config=lambda config: config["accounts"]["registration"].update(
{"allow-before-connect": True}
)
}
)
@cases.mark_specifications("Oragono")
def testBeforeConnect(self):

View File

@ -8,8 +8,8 @@ RELAYMSG_TAG_NAME = "draft/relaymsg"
class RelaymsgTestCase(cases.BaseServerTestCase):
@staticmethod
def config():
return {"chathistory": True}
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(chathistory=True)
@cases.mark_specifications("Oragono")
def testRelaymsg(self):

View File

@ -5,8 +5,8 @@ from irctest.numerics import ERR_CANNOTSENDRP
class RoleplayTestCase(cases.BaseServerTestCase):
@staticmethod
def config():
return {"oragono_roleplay": True}
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(oragono_roleplay=True)
@cases.mark_specifications("Oragono")
def testRoleplay(self):

View File

@ -1,11 +1,7 @@
import time
from irctest import cases
from irctest.irc_utils.junkdrawer import (
ircv3_timestamp_to_unixtime,
random_name,
to_history_message,
)
from irctest.irc_utils.junkdrawer import ircv3_timestamp_to_unixtime, random_name
def extract_playback_privmsgs(messages):
@ -13,14 +9,14 @@ def extract_playback_privmsgs(messages):
result = []
for msg in messages:
if msg.command == "PRIVMSG" and msg.params[0].lower() != "*playback":
result.append(to_history_message(msg))
result.append(msg.to_history_message())
return result
class ZncPlaybackTestCase(cases.BaseServerTestCase):
@staticmethod
def config():
return {"chathistory": True}
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(chathistory=True)
@cases.mark_specifications("Oragono")
def testZncPlayback(self):
@ -58,9 +54,9 @@ class ZncPlaybackTestCase(cases.BaseServerTestCase):
self.joinChannel(qux, chname)
self.sendLine(qux, "PRIVMSG %s :hi there" % (bar,))
dm = to_history_message(
[msg for msg in self.getMessages(qux) if msg.command == "PRIVMSG"][0]
)
dm = [msg for msg in self.getMessages(qux) if msg.command == "PRIVMSG"][
0
].to_history_message()
self.assertEqual(dm.text, "hi there")
NUM_MESSAGES = 10
@ -68,7 +64,7 @@ class ZncPlaybackTestCase(cases.BaseServerTestCase):
for i in range(NUM_MESSAGES):
self.sendLine(qux, "PRIVMSG %s :this is message %d" % (chname, i))
echo_messages.extend(
to_history_message(msg)
msg.to_history_message()
for msg in self.getMessages(qux)
if msg.command == "PRIVMSG"
)

View File

@ -1,3 +1,8 @@
import collections
import dataclasses
from typing import List
TlsConfig = collections.namedtuple("TlsConfig", "enable trusted_fingerprints")
@dataclasses.dataclass
class TlsConfig:
enable: bool
trusted_fingerprints: List[str]