chathistory: Validate BATCH commands more strictly (#208)

This commit is contained in:
Val Lorentz 2023-06-03 19:32:05 +02:00 committed by GitHub
parent 5a5dbdb50d
commit e5f22e8080
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,7 +10,7 @@ import pytest
from irctest import cases, runner
from irctest.irc_utils.junkdrawer import random_name
from irctest.patma import ANYSTR
from irctest.patma import ANYSTR, StrRe
CHATHISTORY_CAP = "draft/chathistory"
EVENT_PLAYBACK_CAP = "draft/event-playback"
@ -21,28 +21,6 @@ SUBCOMMANDS = ["LATEST", "BEFORE", "AFTER", "BETWEEN", "AROUND"]
MYSQL_PASSWORD = ""
def validate_chathistory_batch(msgs):
batch_tag = None
closed_batch_tag = None
result = []
for msg in msgs:
if msg.command == "BATCH":
batch_param = msg.params[0]
if batch_tag is None and batch_param[0] == "+":
batch_tag = batch_param[1:]
elif batch_param[0] == "-":
closed_batch_tag = batch_param[1:]
elif (
msg.command == "PRIVMSG"
and batch_tag is not None
and msg.tags.get("batch") == batch_tag
):
if not msg.prefix.startswith("HistServ!"): # FIXME: ergo-specific
result.append(msg.to_history_message())
assert batch_tag == closed_batch_tag
return result
def skip_ngircd(f):
@functools.wraps(f)
def newf(self, *args, **kwargs):
@ -56,6 +34,26 @@ def skip_ngircd(f):
@cases.mark_specifications("IRCv3")
@cases.mark_services
class ChathistoryTestCase(cases.BaseServerTestCase):
def validate_chathistory_batch(self, msgs, target):
(start, *inner_msgs, end) = msgs
self.assertMessageMatch(
start, command="BATCH", params=[StrRe(r"\+.*"), "chathistory", target]
)
batch_tag = start.params[0][1:]
self.assertMessageMatch(end, command="BATCH", params=["-" + batch_tag])
result = []
for msg in inner_msgs:
if (
msg.command == "PRIVMSG"
and batch_tag is not None
and msg.tags.get("batch") == batch_tag
):
if not msg.prefix.startswith("HistServ!"): # FIXME: ergo-specific
result.append(msg.to_history_message())
return result
@staticmethod
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(chathistory=True)
@ -308,6 +306,9 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
)
time.sleep(0.002)
self.getMessages(1)
self.getMessages(2)
self.validate_echo_messages(NUM_MESSAGES, echo_messages)
self.validate_chathistory(subcommand, echo_messages, 1, c2)
self.validate_chathistory(subcommand, echo_messages, 2, c1)
@ -401,15 +402,15 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
def _validate_chathistory_LATEST(self, echo_messages, user, chname):
INCLUSIVE_LIMIT = len(echo_messages) * 2
self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, INCLUSIVE_LIMIT))
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages, result)
self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, 5))
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[-5:], result)
self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, 1))
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[-1:], result)
self.sendLine(
@ -417,7 +418,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY LATEST %s msgid=%s %d"
% (chname, echo_messages[4].msgid, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[5:], result)
self.sendLine(
@ -425,7 +426,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY LATEST %s timestamp=%s %d"
% (chname, echo_messages[4].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[5:], result)
def _validate_chathistory_BEFORE(self, echo_messages, user, chname):
@ -435,7 +436,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY BEFORE %s msgid=%s %d"
% (chname, echo_messages[6].msgid, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[:6], result)
self.sendLine(
@ -443,7 +444,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY BEFORE %s timestamp=%s %d"
% (chname, echo_messages[6].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[:6], result)
self.sendLine(
@ -451,7 +452,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY BEFORE %s timestamp=%s %d"
% (chname, echo_messages[6].time, 2),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[4:6], result)
def _validate_chathistory_AFTER(self, echo_messages, user, chname):
@ -461,7 +462,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY AFTER %s msgid=%s %d"
% (chname, echo_messages[3].msgid, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[4:], result)
self.sendLine(
@ -469,14 +470,14 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY AFTER %s timestamp=%s %d"
% (chname, echo_messages[3].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[4:], result)
self.sendLine(
user,
"CHATHISTORY AFTER %s timestamp=%s %d" % (chname, echo_messages[3].time, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[4:7], result)
def _validate_chathistory_BETWEEN(self, echo_messages, user, chname):
@ -492,7 +493,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
INCLUSIVE_LIMIT,
),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:-1], result)
self.sendLine(
@ -505,7 +506,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
INCLUSIVE_LIMIT,
),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:-1], result)
# BETWEEN forwards and backwards with a limit, should get
@ -515,7 +516,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d"
% (chname, echo_messages[0].msgid, echo_messages[-1].msgid, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:4], result)
self.sendLine(
@ -523,7 +524,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d"
% (chname, echo_messages[-1].msgid, echo_messages[0].msgid, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[-4:-1], result)
# same stuff again but with timestamps
@ -532,28 +533,28 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[0].time, echo_messages[-1].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:-1], result)
self.sendLine(
user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[-1].time, echo_messages[0].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:-1], result)
self.sendLine(
user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[0].time, echo_messages[-1].time, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:4], result)
self.sendLine(
user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[-1].time, echo_messages[0].time, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[-4:-1], result)
def _validate_chathistory_AROUND(self, echo_messages, user, chname):
@ -561,14 +562,14 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
user,
"CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 1),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual([echo_messages[7]], result)
self.sendLine(
user,
"CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[6:9], result)
self.sendLine(
@ -576,7 +577,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY AROUND %s timestamp=%s %d"
% (chname, echo_messages[7].time, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertIn(echo_messages[7], result)
@pytest.mark.arbitrary_client_tags