From e5f22e8080d215fa9b24b51752fbb8f5d51fbbc9 Mon Sep 17 00:00:00 2001 From: Val Lorentz Date: Sat, 3 Jun 2023 19:32:05 +0200 Subject: [PATCH] chathistory: Validate BATCH commands more strictly (#208) --- irctest/server_tests/chathistory.py | 91 +++++++++++++++-------------- 1 file changed, 46 insertions(+), 45 deletions(-) diff --git a/irctest/server_tests/chathistory.py b/irctest/server_tests/chathistory.py index 28a201a..873661e 100644 --- a/irctest/server_tests/chathistory.py +++ b/irctest/server_tests/chathistory.py @@ -10,7 +10,7 @@ import pytest from irctest import cases, runner from irctest.irc_utils.junkdrawer import random_name -from irctest.patma import ANYSTR +from irctest.patma import ANYSTR, StrRe CHATHISTORY_CAP = "draft/chathistory" EVENT_PLAYBACK_CAP = "draft/event-playback" @@ -21,28 +21,6 @@ SUBCOMMANDS = ["LATEST", "BEFORE", "AFTER", "BETWEEN", "AROUND"] MYSQL_PASSWORD = "" -def validate_chathistory_batch(msgs): - batch_tag = None - closed_batch_tag = None - result = [] - for msg in msgs: - if msg.command == "BATCH": - batch_param = msg.params[0] - if batch_tag is None and batch_param[0] == "+": - batch_tag = batch_param[1:] - elif batch_param[0] == "-": - closed_batch_tag = batch_param[1:] - elif ( - msg.command == "PRIVMSG" - and batch_tag is not None - and msg.tags.get("batch") == batch_tag - ): - if not msg.prefix.startswith("HistServ!"): # FIXME: ergo-specific - result.append(msg.to_history_message()) - assert batch_tag == closed_batch_tag - return result - - def skip_ngircd(f): @functools.wraps(f) def newf(self, *args, **kwargs): @@ -56,6 +34,26 @@ def skip_ngircd(f): @cases.mark_specifications("IRCv3") @cases.mark_services class ChathistoryTestCase(cases.BaseServerTestCase): + def validate_chathistory_batch(self, msgs, target): + (start, *inner_msgs, end) = msgs + + self.assertMessageMatch( + start, command="BATCH", params=[StrRe(r"\+.*"), "chathistory", target] + ) + batch_tag = start.params[0][1:] + self.assertMessageMatch(end, command="BATCH", params=["-" + batch_tag]) + + result = [] + for msg in inner_msgs: + if ( + msg.command == "PRIVMSG" + and batch_tag is not None + and msg.tags.get("batch") == batch_tag + ): + if not msg.prefix.startswith("HistServ!"): # FIXME: ergo-specific + result.append(msg.to_history_message()) + return result + @staticmethod def config() -> cases.TestCaseControllerConfig: return cases.TestCaseControllerConfig(chathistory=True) @@ -308,6 +306,9 @@ class ChathistoryTestCase(cases.BaseServerTestCase): ) time.sleep(0.002) + self.getMessages(1) + self.getMessages(2) + self.validate_echo_messages(NUM_MESSAGES, echo_messages) self.validate_chathistory(subcommand, echo_messages, 1, c2) self.validate_chathistory(subcommand, echo_messages, 2, c1) @@ -401,15 +402,15 @@ class ChathistoryTestCase(cases.BaseServerTestCase): def _validate_chathistory_LATEST(self, echo_messages, user, chname): INCLUSIVE_LIMIT = len(echo_messages) * 2 self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, INCLUSIVE_LIMIT)) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages, result) self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, 5)) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[-5:], result) self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, 1)) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[-1:], result) self.sendLine( @@ -417,7 +418,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): "CHATHISTORY LATEST %s msgid=%s %d" % (chname, echo_messages[4].msgid, INCLUSIVE_LIMIT), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[5:], result) self.sendLine( @@ -425,7 +426,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): "CHATHISTORY LATEST %s timestamp=%s %d" % (chname, echo_messages[4].time, INCLUSIVE_LIMIT), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[5:], result) def _validate_chathistory_BEFORE(self, echo_messages, user, chname): @@ -435,7 +436,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): "CHATHISTORY BEFORE %s msgid=%s %d" % (chname, echo_messages[6].msgid, INCLUSIVE_LIMIT), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[:6], result) self.sendLine( @@ -443,7 +444,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): "CHATHISTORY BEFORE %s timestamp=%s %d" % (chname, echo_messages[6].time, INCLUSIVE_LIMIT), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[:6], result) self.sendLine( @@ -451,7 +452,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): "CHATHISTORY BEFORE %s timestamp=%s %d" % (chname, echo_messages[6].time, 2), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[4:6], result) def _validate_chathistory_AFTER(self, echo_messages, user, chname): @@ -461,7 +462,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): "CHATHISTORY AFTER %s msgid=%s %d" % (chname, echo_messages[3].msgid, INCLUSIVE_LIMIT), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[4:], result) self.sendLine( @@ -469,14 +470,14 @@ class ChathistoryTestCase(cases.BaseServerTestCase): "CHATHISTORY AFTER %s timestamp=%s %d" % (chname, echo_messages[3].time, INCLUSIVE_LIMIT), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[4:], result) self.sendLine( user, "CHATHISTORY AFTER %s timestamp=%s %d" % (chname, echo_messages[3].time, 3), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[4:7], result) def _validate_chathistory_BETWEEN(self, echo_messages, user, chname): @@ -492,7 +493,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): INCLUSIVE_LIMIT, ), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[1:-1], result) self.sendLine( @@ -505,7 +506,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): INCLUSIVE_LIMIT, ), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[1:-1], result) # BETWEEN forwards and backwards with a limit, should get @@ -515,7 +516,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" % (chname, echo_messages[0].msgid, echo_messages[-1].msgid, 3), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[1:4], result) self.sendLine( @@ -523,7 +524,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): "CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d" % (chname, echo_messages[-1].msgid, echo_messages[0].msgid, 3), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[-4:-1], result) # same stuff again but with timestamps @@ -532,28 +533,28 @@ class ChathistoryTestCase(cases.BaseServerTestCase): "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" % (chname, echo_messages[0].time, echo_messages[-1].time, INCLUSIVE_LIMIT), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[1:-1], result) self.sendLine( user, "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" % (chname, echo_messages[-1].time, echo_messages[0].time, INCLUSIVE_LIMIT), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[1:-1], result) self.sendLine( user, "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" % (chname, echo_messages[0].time, echo_messages[-1].time, 3), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[1:4], result) self.sendLine( user, "CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d" % (chname, echo_messages[-1].time, echo_messages[0].time, 3), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[-4:-1], result) def _validate_chathistory_AROUND(self, echo_messages, user, chname): @@ -561,14 +562,14 @@ class ChathistoryTestCase(cases.BaseServerTestCase): user, "CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 1), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual([echo_messages[7]], result) self.sendLine( user, "CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 3), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertEqual(echo_messages[6:9], result) self.sendLine( @@ -576,7 +577,7 @@ class ChathistoryTestCase(cases.BaseServerTestCase): "CHATHISTORY AROUND %s timestamp=%s %d" % (chname, echo_messages[7].time, 3), ) - result = validate_chathistory_batch(self.getMessages(user)) + result = self.validate_chathistory_batch(self.getMessages(user), chname) self.assertIn(echo_messages[7], result) @pytest.mark.arbitrary_client_tags