chathistory: Validate BATCH commands more strictly (#208)

This commit is contained in:
Val Lorentz 2023-06-03 19:32:05 +02:00 committed by GitHub
parent 5a5dbdb50d
commit e5f22e8080
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,7 +10,7 @@ import pytest
from irctest import cases, runner from irctest import cases, runner
from irctest.irc_utils.junkdrawer import random_name from irctest.irc_utils.junkdrawer import random_name
from irctest.patma import ANYSTR from irctest.patma import ANYSTR, StrRe
CHATHISTORY_CAP = "draft/chathistory" CHATHISTORY_CAP = "draft/chathistory"
EVENT_PLAYBACK_CAP = "draft/event-playback" EVENT_PLAYBACK_CAP = "draft/event-playback"
@ -21,28 +21,6 @@ SUBCOMMANDS = ["LATEST", "BEFORE", "AFTER", "BETWEEN", "AROUND"]
MYSQL_PASSWORD = "" MYSQL_PASSWORD = ""
def validate_chathistory_batch(msgs):
batch_tag = None
closed_batch_tag = None
result = []
for msg in msgs:
if msg.command == "BATCH":
batch_param = msg.params[0]
if batch_tag is None and batch_param[0] == "+":
batch_tag = batch_param[1:]
elif batch_param[0] == "-":
closed_batch_tag = batch_param[1:]
elif (
msg.command == "PRIVMSG"
and batch_tag is not None
and msg.tags.get("batch") == batch_tag
):
if not msg.prefix.startswith("HistServ!"): # FIXME: ergo-specific
result.append(msg.to_history_message())
assert batch_tag == closed_batch_tag
return result
def skip_ngircd(f): def skip_ngircd(f):
@functools.wraps(f) @functools.wraps(f)
def newf(self, *args, **kwargs): def newf(self, *args, **kwargs):
@ -56,6 +34,26 @@ def skip_ngircd(f):
@cases.mark_specifications("IRCv3") @cases.mark_specifications("IRCv3")
@cases.mark_services @cases.mark_services
class ChathistoryTestCase(cases.BaseServerTestCase): class ChathistoryTestCase(cases.BaseServerTestCase):
def validate_chathistory_batch(self, msgs, target):
(start, *inner_msgs, end) = msgs
self.assertMessageMatch(
start, command="BATCH", params=[StrRe(r"\+.*"), "chathistory", target]
)
batch_tag = start.params[0][1:]
self.assertMessageMatch(end, command="BATCH", params=["-" + batch_tag])
result = []
for msg in inner_msgs:
if (
msg.command == "PRIVMSG"
and batch_tag is not None
and msg.tags.get("batch") == batch_tag
):
if not msg.prefix.startswith("HistServ!"): # FIXME: ergo-specific
result.append(msg.to_history_message())
return result
@staticmethod @staticmethod
def config() -> cases.TestCaseControllerConfig: def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(chathistory=True) return cases.TestCaseControllerConfig(chathistory=True)
@ -308,6 +306,9 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
) )
time.sleep(0.002) time.sleep(0.002)
self.getMessages(1)
self.getMessages(2)
self.validate_echo_messages(NUM_MESSAGES, echo_messages) self.validate_echo_messages(NUM_MESSAGES, echo_messages)
self.validate_chathistory(subcommand, echo_messages, 1, c2) self.validate_chathistory(subcommand, echo_messages, 1, c2)
self.validate_chathistory(subcommand, echo_messages, 2, c1) self.validate_chathistory(subcommand, echo_messages, 2, c1)
@ -401,15 +402,15 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
def _validate_chathistory_LATEST(self, echo_messages, user, chname): def _validate_chathistory_LATEST(self, echo_messages, user, chname):
INCLUSIVE_LIMIT = len(echo_messages) * 2 INCLUSIVE_LIMIT = len(echo_messages) * 2
self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, INCLUSIVE_LIMIT)) self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, INCLUSIVE_LIMIT))
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages, result) self.assertEqual(echo_messages, result)
self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, 5)) self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, 5))
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[-5:], result) self.assertEqual(echo_messages[-5:], result)
self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, 1)) self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, 1))
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[-1:], result) self.assertEqual(echo_messages[-1:], result)
self.sendLine( self.sendLine(
@ -417,7 +418,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY LATEST %s msgid=%s %d" "CHATHISTORY LATEST %s msgid=%s %d"
% (chname, echo_messages[4].msgid, INCLUSIVE_LIMIT), % (chname, echo_messages[4].msgid, INCLUSIVE_LIMIT),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[5:], result) self.assertEqual(echo_messages[5:], result)
self.sendLine( self.sendLine(
@ -425,7 +426,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY LATEST %s timestamp=%s %d" "CHATHISTORY LATEST %s timestamp=%s %d"
% (chname, echo_messages[4].time, INCLUSIVE_LIMIT), % (chname, echo_messages[4].time, INCLUSIVE_LIMIT),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[5:], result) self.assertEqual(echo_messages[5:], result)
def _validate_chathistory_BEFORE(self, echo_messages, user, chname): def _validate_chathistory_BEFORE(self, echo_messages, user, chname):
@ -435,7 +436,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY BEFORE %s msgid=%s %d" "CHATHISTORY BEFORE %s msgid=%s %d"
% (chname, echo_messages[6].msgid, INCLUSIVE_LIMIT), % (chname, echo_messages[6].msgid, INCLUSIVE_LIMIT),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[:6], result) self.assertEqual(echo_messages[:6], result)
self.sendLine( self.sendLine(
@ -443,7 +444,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY BEFORE %s timestamp=%s %d" "CHATHISTORY BEFORE %s timestamp=%s %d"
% (chname, echo_messages[6].time, INCLUSIVE_LIMIT), % (chname, echo_messages[6].time, INCLUSIVE_LIMIT),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[:6], result) self.assertEqual(echo_messages[:6], result)
self.sendLine( self.sendLine(
@ -451,7 +452,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY BEFORE %s timestamp=%s %d" "CHATHISTORY BEFORE %s timestamp=%s %d"
% (chname, echo_messages[6].time, 2), % (chname, echo_messages[6].time, 2),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[4:6], result) self.assertEqual(echo_messages[4:6], result)
def _validate_chathistory_AFTER(self, echo_messages, user, chname): def _validate_chathistory_AFTER(self, echo_messages, user, chname):
@ -461,7 +462,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY AFTER %s msgid=%s %d" "CHATHISTORY AFTER %s msgid=%s %d"
% (chname, echo_messages[3].msgid, INCLUSIVE_LIMIT), % (chname, echo_messages[3].msgid, INCLUSIVE_LIMIT),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[4:], result) self.assertEqual(echo_messages[4:], result)
self.sendLine( self.sendLine(
@ -469,14 +470,14 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY AFTER %s timestamp=%s %d" "CHATHISTORY AFTER %s timestamp=%s %d"
% (chname, echo_messages[3].time, INCLUSIVE_LIMIT), % (chname, echo_messages[3].time, INCLUSIVE_LIMIT),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[4:], result) self.assertEqual(echo_messages[4:], result)
self.sendLine( self.sendLine(
user, user,
"CHATHISTORY AFTER %s timestamp=%s %d" % (chname, echo_messages[3].time, 3), "CHATHISTORY AFTER %s timestamp=%s %d" % (chname, echo_messages[3].time, 3),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[4:7], result) self.assertEqual(echo_messages[4:7], result)
def _validate_chathistory_BETWEEN(self, echo_messages, user, chname): def _validate_chathistory_BETWEEN(self, echo_messages, user, chname):
@ -492,7 +493,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
INCLUSIVE_LIMIT, INCLUSIVE_LIMIT,
), ),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:-1], result) self.assertEqual(echo_messages[1:-1], result)
self.sendLine( self.sendLine(
@ -505,7 +506,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
INCLUSIVE_LIMIT, INCLUSIVE_LIMIT,
), ),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:-1], result) self.assertEqual(echo_messages[1:-1], result)
# BETWEEN forwards and backwards with a limit, should get # BETWEEN forwards and backwards with a limit, should get
@ -515,7 +516,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d"
% (chname, echo_messages[0].msgid, echo_messages[-1].msgid, 3), % (chname, echo_messages[0].msgid, echo_messages[-1].msgid, 3),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:4], result) self.assertEqual(echo_messages[1:4], result)
self.sendLine( self.sendLine(
@ -523,7 +524,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d"
% (chname, echo_messages[-1].msgid, echo_messages[0].msgid, 3), % (chname, echo_messages[-1].msgid, echo_messages[0].msgid, 3),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[-4:-1], result) self.assertEqual(echo_messages[-4:-1], result)
# same stuff again but with timestamps # same stuff again but with timestamps
@ -532,28 +533,28 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[0].time, echo_messages[-1].time, INCLUSIVE_LIMIT), % (chname, echo_messages[0].time, echo_messages[-1].time, INCLUSIVE_LIMIT),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:-1], result) self.assertEqual(echo_messages[1:-1], result)
self.sendLine( self.sendLine(
user, user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[-1].time, echo_messages[0].time, INCLUSIVE_LIMIT), % (chname, echo_messages[-1].time, echo_messages[0].time, INCLUSIVE_LIMIT),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:-1], result) self.assertEqual(echo_messages[1:-1], result)
self.sendLine( self.sendLine(
user, user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[0].time, echo_messages[-1].time, 3), % (chname, echo_messages[0].time, echo_messages[-1].time, 3),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:4], result) self.assertEqual(echo_messages[1:4], result)
self.sendLine( self.sendLine(
user, user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[-1].time, echo_messages[0].time, 3), % (chname, echo_messages[-1].time, echo_messages[0].time, 3),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[-4:-1], result) self.assertEqual(echo_messages[-4:-1], result)
def _validate_chathistory_AROUND(self, echo_messages, user, chname): def _validate_chathistory_AROUND(self, echo_messages, user, chname):
@ -561,14 +562,14 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
user, user,
"CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 1), "CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 1),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual([echo_messages[7]], result) self.assertEqual([echo_messages[7]], result)
self.sendLine( self.sendLine(
user, user,
"CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 3), "CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 3),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[6:9], result) self.assertEqual(echo_messages[6:9], result)
self.sendLine( self.sendLine(
@ -576,7 +577,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase):
"CHATHISTORY AROUND %s timestamp=%s %d" "CHATHISTORY AROUND %s timestamp=%s %d"
% (chname, echo_messages[7].time, 3), % (chname, echo_messages[7].time, 3),
) )
result = validate_chathistory_batch(self.getMessages(user)) result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertIn(echo_messages[7], result) self.assertIn(echo_messages[7], result)
@pytest.mark.arbitrary_client_tags @pytest.mark.arbitrary_client_tags