Use dataclasses instead of dicts/namedtuples

This commit is contained in:
Valentin Lorentz 2021-02-28 13:40:08 +01:00 committed by Valentin Lorentz
parent 12da7e1e3b
commit ac2a37362c
15 changed files with 124 additions and 97 deletions

View File

@ -7,13 +7,8 @@ from typing import Optional, Tuple
class Mechanisms(enum.Enum): class Mechanisms(enum.Enum):
"""Enumeration for representing possible mechanisms.""" """Enumeration for representing possible mechanisms."""
@classmethod def to_string(self) -> str:
def as_string(cls, mech): return self.name.upper().replace("_", "-")
return {
cls.plain: "PLAIN",
cls.ecdsa_nist256p_challenge: "ECDSA-NIST256P-CHALLENGE",
cls.scram_sha_256: "SCRAM-SHA-256",
}[mech]
plain = 1 plain = 1
ecdsa_nist256p_challenge = 2 ecdsa_nist256p_challenge = 2

View File

@ -1,10 +1,11 @@
import dataclasses
import os import os
import shutil import shutil
import socket import socket
import subprocess import subprocess
import tempfile import tempfile
import time import time
from typing import Set from typing import Any, Callable, Dict, Optional, Set
from .runner import NotImplementedByController from .runner import NotImplementedByController
@ -15,6 +16,24 @@ class ProcessStopped(Exception):
pass pass
@dataclasses.dataclass
class TestCaseControllerConfig:
"""Test-case-specific configuration passed to the controller.
This is usually used to ask controllers to enable a feature;
but should not be an issue if controllers enable it all the time."""
chathistory: bool = False
"""Whether to enable chathistory features."""
oragono_roleplay: bool = False
"""Whether to enable the Oragono role-play commands."""
oragono_config: Optional[Callable[[Dict], Any]] = None
"""Oragono-specific configuration function that alters the dict in-place
This should be used as little as possible, using the other attributes instead;
as they are work with any controller."""
class _BaseController: class _BaseController:
"""Base class for software controllers. """Base class for software controllers.
@ -22,7 +41,7 @@ class _BaseController:
a process (eg. a server or a client), as well as sending it instructions a process (eg. a server or a client), as well as sending it instructions
that are not part of the IRC specification.""" that are not part of the IRC specification."""
def __init__(self, test_config): def __init__(self, test_config: TestCaseControllerConfig):
self.test_config = test_config self.test_config = test_config
self.proc = None self.proc = None
@ -36,7 +55,7 @@ class DirectoryBasedController(_BaseController):
"""Helper for controllers whose software configuration is based on an """Helper for controllers whose software configuration is based on an
arbitrary directory.""" arbitrary directory."""
def __init__(self, test_config): def __init__(self, test_config: TestCaseControllerConfig):
super().__init__(test_config) super().__init__(test_config)
self.directory = None self.directory = None

View File

@ -9,6 +9,7 @@ import unittest
import pytest import pytest
from . import basecontrollers, client_mock, runner from . import basecontrollers, client_mock, runner
from .basecontrollers import TestCaseControllerConfig
from .exceptions import ConnectionClosed from .exceptions import ConnectionClosed
from .irc_utils import capabilities, message_parser from .irc_utils import capabilities, message_parser
from .irc_utils.junkdrawer import normalizeWhitespace from .irc_utils.junkdrawer import normalizeWhitespace
@ -48,12 +49,12 @@ class _IrcTestCase(unittest.TestCase):
controllerClass = None # Will be set by __main__.py controllerClass = None # Will be set by __main__.py
@staticmethod @staticmethod
def config(): def config() -> TestCaseControllerConfig:
"""Some configuration to pass to the controllers. """Some configuration to pass to the controllers.
For example, Oragono only enables its MySQL support if For example, Oragono only enables its MySQL support if
config()["chathistory"]=True. config()["chathistory"]=True.
""" """
return {} return TestCaseControllerConfig()
def description(self): def description(self):
method_doc = self._testMethodDoc method_doc = self._testMethodDoc

View File

@ -1,7 +1,7 @@
import os import os
import subprocess import subprocess
from irctest import authentication, tls from irctest import tls
from irctest.basecontrollers import BaseClientController, DirectoryBasedController from irctest.basecontrollers import BaseClientController, DirectoryBasedController
TEMPLATE_CONFIG = """ TEMPLATE_CONFIG = """
@ -50,9 +50,7 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
assert self.proc is None assert self.proc is None
self.create_config() self.create_config()
if auth: if auth:
mechanisms = " ".join( mechanisms = " ".join(mech.to_string() for mech in auth.mechanisms)
map(authentication.Mechanisms.as_string, auth.mechanisms)
)
if auth.ecdsa_key: if auth.ecdsa_key:
with self.open_file("ecdsa_key.pem") as fd: with self.open_file("ecdsa_key.pem") as fd:
fd.write(auth.ecdsa_key) fd.write(auth.ecdsa_key)

View File

@ -162,16 +162,16 @@ class OragonoController(BaseServerController, DirectoryBasedController):
if config is None: if config is None:
config = copy.deepcopy(BASE_CONFIG) config = copy.deepcopy(BASE_CONFIG)
enable_chathistory = self.test_config.get("chathistory") enable_chathistory = self.test_config.chathistory
enable_roleplay = self.test_config.get("oragono_roleplay") enable_roleplay = self.test_config.oragono_roleplay
if enable_chathistory or enable_roleplay: if enable_chathistory or enable_roleplay:
config = self.addMysqlToConfig(config) config = self.addMysqlToConfig(config)
if enable_roleplay: if enable_roleplay:
config["roleplay"] = {"enabled": True} config["roleplay"] = {"enabled": True}
if "oragono_config" in self.test_config: if self.test_config.oragono_config:
self.test_config["oragono_config"](config) self.test_config.oragono_config(config)
self.port = port self.port = port
bind_address = "127.0.0.1:%s" % (port,) bind_address = "127.0.0.1:%s" % (port,)

View File

@ -1,20 +1,7 @@
from collections import namedtuple
import datetime import datetime
import re import re
import secrets import secrets
HistoryMessage = namedtuple("HistoryMessage", ["time", "msgid", "target", "text"])
def to_history_message(msg):
return HistoryMessage(
time=msg.tags.get("time"),
msgid=msg.tags.get("msgid"),
target=msg.params[0],
text=msg.params[1],
)
# thanks jess! # thanks jess!
IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z" IRCV3_FORMAT_STRFTIME = "%Y-%m-%dT%H:%M:%S.%f%z"

View File

@ -1,5 +1,6 @@
import collections import dataclasses
import re import re
from typing import Any, Dict, List, Optional
from .junkdrawer import MultipleReplacer from .junkdrawer import MultipleReplacer
@ -29,7 +30,28 @@ def parse_tags(s):
return tags return tags
Message = collections.namedtuple("Message", "tags prefix command params") @dataclasses.dataclass(frozen=True)
class HistoryMessage:
time: Any
msgid: Optional[str]
target: str
text: str
@dataclasses.dataclass(frozen=True)
class Message:
tags: Dict[str, Optional[str]]
prefix: Optional[str]
command: str
params: List[str]
def to_history_message(self) -> HistoryMessage:
return HistoryMessage(
time=self.tags.get("time"),
msgid=self.tags.get("msgid"),
target=self.params[0],
text=self.params[1],
)
def parse_message(s): def parse_message(s):

View File

@ -2,7 +2,7 @@ import secrets
import time import time
from irctest import cases from irctest import cases
from irctest.irc_utils.junkdrawer import random_name, to_history_message from irctest.irc_utils.junkdrawer import random_name
CHATHISTORY_CAP = "draft/chathistory" CHATHISTORY_CAP = "draft/chathistory"
EVENT_PLAYBACK_CAP = "draft/event-playback" EVENT_PLAYBACK_CAP = "draft/event-playback"
@ -27,15 +27,15 @@ def validate_chathistory_batch(msgs):
and batch_tag is not None and batch_tag is not None
and msg.tags.get("batch") == batch_tag and msg.tags.get("batch") == batch_tag
): ):
result.append(to_history_message(msg)) result.append(msg.to_history_message())
assert batch_tag == closed_batch_tag assert batch_tag == closed_batch_tag
return result return result
class ChathistoryTestCase(cases.BaseServerTestCase): class ChathistoryTestCase(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config() -> cases.TestCaseControllerConfig:
return {"chathistory": True} return cases.TestCaseControllerConfig(chathistory=True)
@cases.mark_specifications("Oragono") @cases.mark_specifications("Oragono")
def testInvalidTargets(self): def testInvalidTargets(self):
@ -93,7 +93,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.assertEqual(len(replies), 1) self.assertEqual(len(replies), 1)
msg = replies[0] msg = replies[0]
self.assertEqual(msg.params, [bar, "this is a privmsg sent to myself"]) self.assertEqual(msg.params, [bar, "this is a privmsg sent to myself"])
messages.append(to_history_message(msg)) messages.append(msg.to_history_message())
self.sendLine(bar, "CAP REQ echo-message") self.sendLine(bar, "CAP REQ echo-message")
self.getMessages(bar) self.getMessages(bar)
@ -106,9 +106,11 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.assertEqual( self.assertEqual(
replies[0].params, [bar, "this is a second privmsg sent to myself"] replies[0].params, [bar, "this is a second privmsg sent to myself"]
) )
messages.append(to_history_message(replies[0])) messages.append(replies[0].to_history_message())
# messages should be otherwise identical # messages should be otherwise identical
self.assertEqual(to_history_message(replies[0]), to_history_message(replies[1])) self.assertEqual(
replies[0].to_history_message(), replies[1].to_history_message()
)
self.sendLine( self.sendLine(
bar, bar,
@ -120,17 +122,17 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
echo = [msg for msg in replies if msg.tags.get("label") == "xyz"][0] echo = [msg for msg in replies if msg.tags.get("label") == "xyz"][0]
delivery = [msg for msg in replies if msg.tags.get("label") is None][0] delivery = [msg for msg in replies if msg.tags.get("label") is None][0]
self.assertEqual(echo.params, [bar, "this is a third privmsg sent to myself"]) self.assertEqual(echo.params, [bar, "this is a third privmsg sent to myself"])
messages.append(to_history_message(echo)) messages.append(echo.to_history_message())
self.assertEqual(to_history_message(echo), to_history_message(delivery)) self.assertEqual(echo.to_history_message(), delivery.to_history_message())
# should receive exactly 3 messages in the correct order, no duplicates # should receive exactly 3 messages in the correct order, no duplicates
self.sendLine(bar, "CHATHISTORY LATEST * * 10") self.sendLine(bar, "CHATHISTORY LATEST * * 10")
replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"] replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"]
self.assertEqual([to_history_message(msg) for msg in replies], messages) self.assertEqual([msg.to_history_message() for msg in replies], messages)
self.sendLine(bar, "CHATHISTORY LATEST %s * 10" % (bar,)) self.sendLine(bar, "CHATHISTORY LATEST %s * 10" % (bar,))
replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"] replies = [msg for msg in self.getMessages(bar) if msg.command == "PRIVMSG"]
self.assertEqual([to_history_message(msg) for msg in replies], messages) self.assertEqual([msg.to_history_message() for msg in replies], messages)
def validate_echo_messages(self, num_messages, echo_messages): def validate_echo_messages(self, num_messages, echo_messages):
# sanity checks: should have received the correct number of echo messages, # sanity checks: should have received the correct number of echo messages,
@ -161,7 +163,9 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
echo_messages = [] echo_messages = []
for i in range(NUM_MESSAGES): for i in range(NUM_MESSAGES):
self.sendLine(1, "PRIVMSG %s :this is message %d" % (chname, i)) self.sendLine(1, "PRIVMSG %s :this is message %d" % (chname, i))
echo_messages.extend(to_history_message(msg) for msg in self.getMessages(1)) echo_messages.extend(
msg.to_history_message() for msg in self.getMessages(1)
)
time.sleep(0.002) time.sleep(0.002)
self.validate_echo_messages(NUM_MESSAGES, echo_messages) self.validate_echo_messages(NUM_MESSAGES, echo_messages)
@ -213,7 +217,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.getMessages(user) self.getMessages(user)
self.sendLine(user, "PRIVMSG %s :this is message %d" % (target, i)) self.sendLine(user, "PRIVMSG %s :this is message %d" % (target, i))
echo_messages.extend( echo_messages.extend(
to_history_message(msg) for msg in self.getMessages(user) msg.to_history_message() for msg in self.getMessages(user)
) )
time.sleep(0.002) time.sleep(0.002)
@ -245,7 +249,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
) )
# 3 received the first message as a delivery and the second as an echo # 3 received the first message as a delivery and the second as an echo
new_convo = [ new_convo = [
to_history_message(msg) msg.to_history_message()
for msg in self.getMessages(3) for msg in self.getMessages(3)
if msg.command == "PRIVMSG" if msg.command == "PRIVMSG"
] ]
@ -262,7 +266,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.getMessages(1) self.getMessages(1)
self.sendLine(1, "CHATHISTORY LATEST %s * 10" % (c3,)) self.sendLine(1, "CHATHISTORY LATEST %s * 10" % (c3,))
results = [ results = [
to_history_message(msg) msg.to_history_message()
for msg in self.getMessages(1) for msg in self.getMessages(1)
if msg.command == "PRIVMSG" if msg.command == "PRIVMSG"
] ]
@ -296,7 +300,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
self.getMessages(c3) self.getMessages(c3)
self.sendLine(c3, "CHATHISTORY LATEST %s * 10" % (c1,)) self.sendLine(c3, "CHATHISTORY LATEST %s * 10" % (c1,))
results = [ results = [
to_history_message(msg) msg.to_history_message()
for msg in self.getMessages(c3) for msg in self.getMessages(c3)
if msg.command == "PRIVMSG" if msg.command == "PRIVMSG"
] ]

View File

@ -4,12 +4,12 @@ from irctest.numerics import ERR_NICKNAMEINUSE, RPL_WELCOME
class ConfusablesTestCase(cases.BaseServerTestCase): class ConfusablesTestCase(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config() -> cases.TestCaseControllerConfig:
return { return cases.TestCaseControllerConfig(
"oragono_config": lambda config: config["accounts"].update( oragono_config=lambda config: config["accounts"].update(
{"nick-reservation": {"enabled": True, "method": "strict"}} {"nick-reservation": {"enabled": True, "method": "strict"}}
) )
} )
@cases.mark_specifications("Oragono") @cases.mark_specifications("Oragono")
def testConfusableNicks(self): def testConfusableNicks(self):

View File

@ -178,12 +178,12 @@ class LusersUnregisteredDefaultInvisibleTest(LusersUnregisteredTestCase):
"""Same as above but with +i as the default.""" """Same as above but with +i as the default."""
@staticmethod @staticmethod
def config(): def config() -> cases.TestCaseControllerConfig:
return { return cases.TestCaseControllerConfig(
"oragono_config": lambda config: config["accounts"].update( oragono_config=lambda config: config["accounts"].update(
{"default-user-modes": "+i"} {"default-user-modes": "+i"}
) )
} )
@cases.mark_specifications("Oragono") @cases.mark_specifications("Oragono")
def testLusers(self): def testLusers(self):
@ -236,12 +236,12 @@ class LuserOpersTest(LusersTestCase):
class OragonoInvisibleDefaultTest(LusersTestCase): class OragonoInvisibleDefaultTest(LusersTestCase):
@staticmethod @staticmethod
def config(): def config() -> cases.TestCaseControllerConfig:
return { return cases.TestCaseControllerConfig(
"oragono_config": lambda config: config["accounts"].update( oragono_config=lambda config: config["accounts"].update(
{"default-user-modes": "+i"} {"default-user-modes": "+i"}
) )
} )
@cases.mark_specifications("Oragono") @cases.mark_specifications("Oragono")
def testLusers(self): def testLusers(self):

View File

@ -5,12 +5,12 @@ REGISTER_CAP_NAME = "draft/register"
class TestRegisterBeforeConnect(cases.BaseServerTestCase): class TestRegisterBeforeConnect(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config() -> cases.TestCaseControllerConfig:
return { return cases.TestCaseControllerConfig(
"oragono_config": lambda config: config["accounts"]["registration"].update( oragono_config=lambda config: config["accounts"]["registration"].update(
{"allow-before-connect": True} {"allow-before-connect": True}
) )
} )
@cases.mark_specifications("Oragono") @cases.mark_specifications("Oragono")
def testBeforeConnect(self): def testBeforeConnect(self):
@ -28,12 +28,12 @@ class TestRegisterBeforeConnect(cases.BaseServerTestCase):
class TestRegisterBeforeConnectDisallowed(cases.BaseServerTestCase): class TestRegisterBeforeConnectDisallowed(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config() -> cases.TestCaseControllerConfig:
return { return cases.TestCaseControllerConfig(
"oragono_config": lambda config: config["accounts"]["registration"].update( oragono_config=lambda config: config["accounts"]["registration"].update(
{"allow-before-connect": False} {"allow-before-connect": False}
) )
} )
@cases.mark_specifications("Oragono") @cases.mark_specifications("Oragono")
def testBeforeConnect(self): def testBeforeConnect(self):
@ -51,9 +51,9 @@ class TestRegisterBeforeConnectDisallowed(cases.BaseServerTestCase):
class TestRegisterEmailVerified(cases.BaseServerTestCase): class TestRegisterEmailVerified(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config() -> cases.TestCaseControllerConfig:
return { return cases.TestCaseControllerConfig(
"oragono_config": lambda config: config["accounts"]["registration"].update( oragono_config=lambda config: config["accounts"]["registration"].update(
{ {
"email-verification": { "email-verification": {
"enabled": True, "enabled": True,
@ -64,7 +64,7 @@ class TestRegisterEmailVerified(cases.BaseServerTestCase):
"allow-before-connect": True, "allow-before-connect": True,
} }
) )
} )
@cases.mark_specifications("Oragono") @cases.mark_specifications("Oragono")
def testBeforeConnect(self): def testBeforeConnect(self):
@ -93,12 +93,12 @@ class TestRegisterEmailVerified(cases.BaseServerTestCase):
class TestRegisterNoLandGrabs(cases.BaseServerTestCase): class TestRegisterNoLandGrabs(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config() -> cases.TestCaseControllerConfig:
return { return cases.TestCaseControllerConfig(
"oragono_config": lambda config: config["accounts"]["registration"].update( oragono_config=lambda config: config["accounts"]["registration"].update(
{"allow-before-connect": True} {"allow-before-connect": True}
) )
} )
@cases.mark_specifications("Oragono") @cases.mark_specifications("Oragono")
def testBeforeConnect(self): def testBeforeConnect(self):

View File

@ -8,8 +8,8 @@ RELAYMSG_TAG_NAME = "draft/relaymsg"
class RelaymsgTestCase(cases.BaseServerTestCase): class RelaymsgTestCase(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config() -> cases.TestCaseControllerConfig:
return {"chathistory": True} return cases.TestCaseControllerConfig(chathistory=True)
@cases.mark_specifications("Oragono") @cases.mark_specifications("Oragono")
def testRelaymsg(self): def testRelaymsg(self):

View File

@ -5,8 +5,8 @@ from irctest.numerics import ERR_CANNOTSENDRP
class RoleplayTestCase(cases.BaseServerTestCase): class RoleplayTestCase(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config() -> cases.TestCaseControllerConfig:
return {"oragono_roleplay": True} return cases.TestCaseControllerConfig(oragono_roleplay=True)
@cases.mark_specifications("Oragono") @cases.mark_specifications("Oragono")
def testRoleplay(self): def testRoleplay(self):

View File

@ -1,11 +1,7 @@
import time import time
from irctest import cases from irctest import cases
from irctest.irc_utils.junkdrawer import ( from irctest.irc_utils.junkdrawer import ircv3_timestamp_to_unixtime, random_name
ircv3_timestamp_to_unixtime,
random_name,
to_history_message,
)
def extract_playback_privmsgs(messages): def extract_playback_privmsgs(messages):
@ -13,14 +9,14 @@ def extract_playback_privmsgs(messages):
result = [] result = []
for msg in messages: for msg in messages:
if msg.command == "PRIVMSG" and msg.params[0].lower() != "*playback": if msg.command == "PRIVMSG" and msg.params[0].lower() != "*playback":
result.append(to_history_message(msg)) result.append(msg.to_history_message())
return result return result
class ZncPlaybackTestCase(cases.BaseServerTestCase): class ZncPlaybackTestCase(cases.BaseServerTestCase):
@staticmethod @staticmethod
def config(): def config() -> cases.TestCaseControllerConfig:
return {"chathistory": True} return cases.TestCaseControllerConfig(chathistory=True)
@cases.mark_specifications("Oragono") @cases.mark_specifications("Oragono")
def testZncPlayback(self): def testZncPlayback(self):
@ -58,9 +54,9 @@ class ZncPlaybackTestCase(cases.BaseServerTestCase):
self.joinChannel(qux, chname) self.joinChannel(qux, chname)
self.sendLine(qux, "PRIVMSG %s :hi there" % (bar,)) self.sendLine(qux, "PRIVMSG %s :hi there" % (bar,))
dm = to_history_message( dm = [msg for msg in self.getMessages(qux) if msg.command == "PRIVMSG"][
[msg for msg in self.getMessages(qux) if msg.command == "PRIVMSG"][0] 0
) ].to_history_message()
self.assertEqual(dm.text, "hi there") self.assertEqual(dm.text, "hi there")
NUM_MESSAGES = 10 NUM_MESSAGES = 10
@ -68,7 +64,7 @@ class ZncPlaybackTestCase(cases.BaseServerTestCase):
for i in range(NUM_MESSAGES): for i in range(NUM_MESSAGES):
self.sendLine(qux, "PRIVMSG %s :this is message %d" % (chname, i)) self.sendLine(qux, "PRIVMSG %s :this is message %d" % (chname, i))
echo_messages.extend( echo_messages.extend(
to_history_message(msg) msg.to_history_message()
for msg in self.getMessages(qux) for msg in self.getMessages(qux)
if msg.command == "PRIVMSG" if msg.command == "PRIVMSG"
) )

View File

@ -1,3 +1,8 @@
import collections import dataclasses
from typing import List
TlsConfig = collections.namedtuple("TlsConfig", "enable trusted_fingerprints")
@dataclasses.dataclass
class TlsConfig:
enable: bool
trusted_fingerprints: List[str]