assertMessageMatch: Add pattern-matching on tags, and start using it.

This commit is contained in:
Valentin Lorentz 2021-03-01 20:18:09 +01:00
parent 3c2db1531a
commit 1e0de7aefb
11 changed files with 251 additions and 394 deletions

View File

@ -123,8 +123,11 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
def messageDiffers( def messageDiffers(
self, self,
msg: Message, msg: Message,
params: Optional[List[Union[str, patma.Operator]]] = None, params: Optional[List[Union[str, None, patma.Operator]]] = None,
target: Optional[str] = None, target: Optional[str] = None,
tags: Optional[
Dict[Union[str, patma.Operator], Union[str, patma.Operator, None]]
] = None,
nick: Optional[str] = None, nick: Optional[str] = None,
fail_msg: Optional[str] = None, fail_msg: Optional[str] = None,
extra_format: Tuple = (), extra_format: Tuple = (),
@ -145,12 +148,18 @@ class _IrcTestCase(unittest.TestCase, Generic[TController]):
msg=msg, msg=msg,
) )
if params and not patma.match_list(msg.params, params): if params and not patma.match_list(list(msg.params), params):
fail_msg = fail_msg or "params to be {expects}, got {got}: {msg}" fail_msg = (
fail_msg or "expected params to match {expects}, got {got}: {msg}"
)
return fail_msg.format( return fail_msg.format(
*extra_format, got=msg.params, expects=params, msg=msg *extra_format, got=msg.params, expects=params, msg=msg
) )
if tags and not patma.match_dict(msg.tags, tags):
fail_msg = fail_msg or "expected tags to match {expects}, got {got}: {msg}"
return fail_msg.format(*extra_format, got=msg.tags, expects=tags, msg=msg)
if nick: if nick:
got_nick = msg.prefix.split("!")[0] if msg.prefix else None got_nick = msg.prefix.split("!")[0] if msg.prefix else None
if msg.prefix is None: if msg.prefix is None:

View File

@ -2,7 +2,7 @@
import dataclasses import dataclasses
import re import re
from typing import List, Union from typing import Dict, List, Optional, Union
class Operator: class Operator:
@ -20,7 +20,14 @@ class AnyStr(Operator):
return "AnyStr" return "AnyStr"
@dataclasses.dataclass class AnyOptStr(Operator):
"""Wildcard matching any string as well as None"""
def __repr__(self) -> str:
return "AnyOptStr"
@dataclasses.dataclass(frozen=True)
class StrRe(Operator): class StrRe(Operator):
regexp: str regexp: str
@ -28,24 +35,94 @@ class StrRe(Operator):
return f"StrRe(r'{self.regexp}')" return f"StrRe(r'{self.regexp}')"
@dataclasses.dataclass(frozen=True)
class RemainingKeys(Operator):
"""Used in a dict pattern to match all remaining keys.
May only be present once."""
key: Operator
def __repr__(self) -> str:
return f"Keys({self.key!r})"
ANYSTR = AnyStr() ANYSTR = AnyStr()
"""Singleton, spares two characters""" """Singleton, spares two characters"""
ANYDICT = {RemainingKeys(ANYSTR): AnyOptStr()}
"""Matches any dictionary; useful to compare tags dict, eg.
`match_dict(got_tags, {"label": "foo", **ANYDICT})`"""
def match_list(got: List[str], expected: List[Union[str, Operator]]) -> bool:
def match_string(got: Optional[str], expected: Union[str, Operator, None]) -> bool:
if isinstance(expected, AnyOptStr):
return True
elif isinstance(expected, AnyStr) and got is not None:
return True
elif isinstance(expected, StrRe):
if got is None or not re.match(expected.regexp, got):
return False
elif isinstance(expected, Operator):
raise NotImplementedError(f"Unsupported operator: {expected}")
elif got != expected:
return False
return True
def match_list(
got: List[Optional[str]], expected: List[Union[str, None, Operator]]
) -> bool:
"""Returns True iff the list are equal. """Returns True iff the list are equal.
The ellipsis (aka. "..." aka triple dots) can be used on the 'expected'
side as a wildcard, matching any *single* value.""" The ANYSTR operator can be used on the 'expected' side as a wildcard,
matching any *single* value; and StrRe("<regexp>") can be used to match regular
expressions"""
if len(got) != len(expected): if len(got) != len(expected):
return False return False
for (got_value, expected_value) in zip(got, expected): return all(
if isinstance(expected_value, AnyStr): match_string(got_value, expected_value)
# wildcard for (got_value, expected_value) in zip(got, expected)
continue )
elif isinstance(expected_value, StrRe):
if not re.match(expected_value.regexp, got_value):
return False def match_dict(
got: Dict[str, Optional[str]],
expected: Dict[Union[str, Operator], Union[str, Operator, None]],
) -> bool:
"""Returns True iff the list are equal.
The ANYSTR operator can be used on the 'expected' side as a wildcard,
matching any *single* value; and StrRe("<regexp>") can be used to match regular
expressions
Additionally, the Keys() operator can be used to match remaining keys, and
ANYDICT to match any remaining dict"""
got = dict(got) # shallow copy, as we will remove keys
# Set to not-None if we find a Keys() operator in the dict keys
remaining_keys_wildcard = None
for (expected_key, expected_value) in expected.items():
if isinstance(expected_key, RemainingKeys):
remaining_keys_wildcard = (expected_key.key, expected_value)
elif isinstance(expected_key, Operator):
raise NotImplementedError(f"Unsupported operator: {expected_key}")
else: else:
if got_value != expected_value: if expected_key not in got:
return False return False
return True got_value = got.pop(expected_key)
if not match_string(got_value, expected_value):
return False
if remaining_keys_wildcard:
(expected_key, expected_value) = remaining_keys_wildcard
for (key, value) in got.items():
if not match_string(key, expected_key):
return False
if not match_string(value, expected_value):
return False
return True
else:
# There should be nothing left unmatched in the dict
return got == {}

View File

@ -43,23 +43,5 @@ class AccountTagTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
self.getMessages(2) self.getMessages(2)
m = self.getMessage(1) m = self.getMessage(1)
self.assertMessageMatch( self.assertMessageMatch(
m, m, command="PRIVMSG", params=["foo", "hi"], tags={"account": "jilles"}
command="PRIVMSG",
params=["foo", "hi"],
fail_msg="Expected a private PRIVMSG from 'bar', got: {msg}",
)
self.assertIn(
"account",
m.tags,
m,
fail_msg="PRIVMSG by logged-in nick "
"does not contain an account tag: {msg}",
)
self.assertEqual(
m.tags["account"],
"jilles",
m,
fail_msg="PRIVMSG by logged-in nick "
"does not contain the correct account tag (should be "
"“jilles”): {msg}",
) )

View File

@ -5,6 +5,7 @@
from irctest import cases from irctest import cases
from irctest.basecontrollers import NotImplementedByController from irctest.basecontrollers import NotImplementedByController
from irctest.irc_utils.junkdrawer import random_name from irctest.irc_utils.junkdrawer import random_name
from irctest.patma import ANYDICT
def _testEchoMessage(command, solo, server_time): def _testEchoMessage(command, solo, server_time):
@ -123,22 +124,22 @@ class EchoMessageTestCase(cases.BaseServerTestCase):
echo = self.getMessages(bar)[0] echo = self.getMessages(bar)[0]
delivery = self.getMessages(qux)[0] delivery = self.getMessages(qux)[0]
self.assertEqual(delivery.params, [qux, "hi there"]) self.assertMessageMatch(
self.assertEqual(delivery.params, echo.params) echo,
command="PRIVMSG",
params=[qux, "hi there"],
tags={"label": "xyz", "+example-client-tag": "example-value", **ANYDICT},
)
self.assertMessageMatch(
delivery,
command="PRIVMSG",
params=[qux, "hi there"],
tags={"+example-client-tag": "example-value", **ANYDICT},
)
# Either both messages have a msgid, or neither does # Either both messages have a msgid, or neither does
self.assertEqual(delivery.tags.get("msgid"), echo.tags.get("msgid")) self.assertEqual(delivery.tags.get("msgid"), echo.tags.get("msgid"))
self.assertEqual(
echo.tags.get("label"),
"xyz",
fail_msg="expected message label 'xyz', but got {got!r}",
)
self.assertEqual(delivery.tags["+example-client-tag"], "example-value")
self.assertEqual(
delivery.tags["+example-client-tag"], echo.tags["+example-client-tag"]
)
testEchoMessagePrivmsgNoServerTime = _testEchoMessage("PRIVMSG", False, False) testEchoMessagePrivmsgNoServerTime = _testEchoMessage("PRIVMSG", False, False)
testEchoMessagePrivmsgSolo = _testEchoMessage("PRIVMSG", True, True) testEchoMessagePrivmsgSolo = _testEchoMessage("PRIVMSG", True, True)
testEchoMessagePrivmsg = _testEchoMessage("PRIVMSG", False, True) testEchoMessagePrivmsg = _testEchoMessage("PRIVMSG", False, True)

View File

@ -8,7 +8,7 @@ so there may be many false positives.
import re import re
from irctest import cases from irctest import cases
from irctest.patma import StrRe from irctest.patma import ANYDICT, StrRe
class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@ -53,51 +53,13 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper
m4 = self.getMessage(4) m4 = self.getMessage(4)
# ensure the label isn't sent to recipients # ensure the label isn't sent to recipients
self.assertMessageMatch( self.assertMessageMatch(m2, command="PRIVMSG", tags={})
m2,
command="PRIVMSG",
fail_msg="No PRIVMSG received by target 1 after sending one out",
)
self.assertNotIn(
"label",
m2.tags,
m2,
fail_msg=(
"When sending a PRIVMSG with a label, "
"the target users shouldn't receive the label "
"(only the sending user should): {msg}"
),
)
self.assertMessageMatch( self.assertMessageMatch(
m3, m3,
command="PRIVMSG", command="PRIVMSG",
fail_msg="No PRIVMSG received by target 1 after sending one out", tags={},
)
self.assertNotIn(
"label",
m3.tags,
m3,
fail_msg=(
"When sending a PRIVMSG with a label, "
"the target users shouldn't receive the label "
"(only the sending user should): {msg}"
),
)
self.assertMessageMatch(
m4,
command="PRIVMSG",
fail_msg="No PRIVMSG received by target 1 after sending one out",
)
self.assertNotIn(
"label",
m4.tags,
m4,
fail_msg=(
"When sending a PRIVMSG with a label, "
"the target users shouldn't receive the label "
"(only the sending user should): {msg}"
),
) )
self.assertMessageMatch(m4, command="PRIVMSG", tags={})
self.assertMessageMatch( self.assertMessageMatch(
m, command="BATCH", fail_msg="No BATCH echo received after sending one out" m, command="BATCH", fail_msg="No BATCH echo received after sending one out"
@ -123,45 +85,9 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper
m2 = self.getMessage(2) m2 = self.getMessage(2)
# ensure the label isn't sent to recipient # ensure the label isn't sent to recipient
self.assertMessageMatch( self.assertMessageMatch(m2, command="PRIVMSG", tags={})
m2,
command="PRIVMSG",
fail_msg="No PRIVMSG received by the target after sending one out",
)
self.assertNotIn(
"label",
m2.tags,
m2,
fail_msg=(
"When sending a PRIVMSG with a label, "
"the target user shouldn't receive the label "
"(only the sending user should): {msg}"
),
)
self.assertMessageMatch( self.assertMessageMatch(m, command="PRIVMSG", tags={"label": "12345"})
m,
command="PRIVMSG",
fail_msg="No PRIVMSG echo received after sending one out",
)
self.assertIn(
"label",
m.tags,
m,
fail_msg=(
"When sending a PRIVMSG with a label, "
"the echo'd message didn't contain the label at all: {msg}"
),
)
self.assertEqual(
m.tags["label"],
"12345",
m,
fail_msg=(
"Echo'd PRIVMSG to a client did not contain the same label "
"we sent it with(should be '12345'): {msg}"
),
)
@cases.mark_capabilities("echo-message", "labeled-response") @cases.mark_capabilities("echo-message", "labeled-response")
def testLabeledPrivmsgResponsesToChannel(self): def testLabeledPrivmsgResponsesToChannel(self):
@ -192,44 +118,10 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper
mt = self.getMessage(2) mt = self.getMessage(2)
# ensure the label isn't sent to recipient # ensure the label isn't sent to recipient
self.assertMessageMatch( self.assertMessageMatch(mt, command="PRIVMSG", tags={})
mt,
command="PRIVMSG",
fail_msg="No PRIVMSG received by the target after sending one out",
)
self.assertNotIn(
"label",
mt.tags,
mt,
fail_msg=(
"When sending a PRIVMSG with a label, "
"the target user shouldn't receive the label "
"(only the sending user should): {msg}"
),
)
# ensure sender correctly receives msg # ensure sender correctly receives msg
self.assertMessageMatch( self.assertMessageMatch(ms, command="PRIVMSG", tags={"label": "12345"})
ms, command="PRIVMSG", fail_msg="Got a message back that wasn't a PRIVMSG"
)
self.assertIn(
"label",
ms.tags,
ms,
fail_msg=(
"When sending a PRIVMSG with a label, "
"the source user should receive the label but didn't: {msg}"
),
)
self.assertEqual(
ms.tags["label"],
"12345",
ms,
fail_msg=(
"Echo'd label doesn't match the label we sent "
"(should be '12345'): {msg}"
),
)
@cases.mark_capabilities("echo-message", "labeled-response") @cases.mark_capabilities("echo-message", "labeled-response")
def testLabeledPrivmsgResponsesToSelf(self): def testLabeledPrivmsgResponsesToSelf(self):
@ -294,45 +186,9 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper
m2 = self.getMessage(2) m2 = self.getMessage(2)
# ensure the label isn't sent to recipient # ensure the label isn't sent to recipient
self.assertMessageMatch( self.assertMessageMatch(m2, command="NOTICE", tags={})
m2,
command="NOTICE",
fail_msg="No NOTICE received by the target after sending one out",
)
self.assertNotIn(
"label",
m2.tags,
m2,
fail_msg=(
"When sending a NOTICE with a label, "
"the target user shouldn't receive the label "
"(only the sending user should): {msg}"
),
)
self.assertMessageMatch( self.assertMessageMatch(m, command="NOTICE", tags={"label": "12345"})
m,
command="NOTICE",
fail_msg="No NOTICE echo received after sending one out",
)
self.assertIn(
"label",
m.tags,
m,
fail_msg=(
"When sending a NOTICE with a label, "
"the echo'd message didn't contain the label at all: {msg}"
),
)
self.assertEqual(
m.tags["label"],
"12345",
m,
fail_msg=(
"Echo'd NOTICE to a client did not contain the same label "
"we sent it with (should be '12345'): {msg}"
),
)
@cases.mark_capabilities("echo-message", "labeled-response") @cases.mark_capabilities("echo-message", "labeled-response")
def testLabeledNoticeResponsesToChannel(self): def testLabeledNoticeResponsesToChannel(self):
@ -363,44 +219,10 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper
mt = self.getMessage(2) mt = self.getMessage(2)
# ensure the label isn't sent to recipient # ensure the label isn't sent to recipient
self.assertMessageMatch( self.assertMessageMatch(mt, command="NOTICE", tags={})
mt,
command="NOTICE",
fail_msg="No NOTICE received by the target after sending one out",
)
self.assertNotIn(
"label",
mt.tags,
mt,
fail_msg=(
"When sending a NOTICE with a label, "
"the target user shouldn't receive the label "
"(only the sending user should): {msg}"
),
)
# ensure sender correctly receives msg # ensure sender correctly receives msg
self.assertMessageMatch( self.assertMessageMatch(ms, command="NOTICE", tags={"label": "12345"})
ms, command="NOTICE", fail_msg="Got a message back that wasn't a NOTICE"
)
self.assertIn(
"label",
ms.tags,
ms,
fail_msg=(
"When sending a NOTICE with a label, "
"the source user should receive the label but didn't: {msg}"
),
)
self.assertEqual(
ms.tags["label"],
"12345",
ms,
fail_msg=(
"Echo'd label doesn't match the label we sent "
"(should be '12345'): {msg}"
),
)
@cases.mark_capabilities("echo-message", "labeled-response") @cases.mark_capabilities("echo-message", "labeled-response")
def testLabeledNoticeResponsesToSelf(self): def testLabeledNoticeResponsesToSelf(self):
@ -466,7 +288,7 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper
self.assertMessageMatch( self.assertMessageMatch(
m2, m2,
command="TAGMSG", command="TAGMSG",
fail_msg="No TAGMSG received by the target after sending one out", tags={"+draft/reply": "123", "+draft/react": "l😃l", **ANYDICT},
) )
self.assertNotIn( self.assertNotIn(
"label", "label",
@ -478,77 +300,16 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper
"(only the sending user should): {msg}" "(only the sending user should): {msg}"
), ),
) )
self.assertIn(
"+draft/reply",
m2.tags,
m2,
fail_msg="Reply tag wasn't present on the target user's TAGMSG: {msg}",
)
self.assertEqual(
m2.tags["+draft/reply"],
"123",
m2,
fail_msg="Reply tag wasn't the same on the target user's TAGMSG: {msg}",
)
self.assertIn(
"+draft/react",
m2.tags,
m2,
fail_msg="React tag wasn't present on the target user's TAGMSG: {msg}",
)
self.assertEqual(
m2.tags["+draft/react"],
"l😃l",
m2,
fail_msg="React tag wasn't the same on the target user's TAGMSG: {msg}",
)
self.assertMessageMatch( self.assertMessageMatch(
m, m,
command="TAGMSG", command="TAGMSG",
fail_msg="No TAGMSG echo received after sending one out", tags={
) "label": "12345",
self.assertIn( "+draft/reply": "123",
"label", "+draft/react": "l😃l",
m.tags, **ANYDICT,
m, },
fail_msg=(
"When sending a TAGMSG with a label, "
"the echo'd message didn't contain the label at all: {msg}"
),
)
self.assertEqual(
m.tags["label"],
"12345",
m,
fail_msg=(
"Echo'd TAGMSG to a client did not contain the same label "
"we sent it with (should be '12345'): {msg}"
),
)
self.assertIn(
"+draft/reply",
m.tags,
m,
fail_msg="Reply tag wasn't present on the source user's TAGMSG: {msg}",
)
self.assertEqual(
m2.tags["+draft/reply"],
"123",
m,
fail_msg="Reply tag wasn't the same on the source user's TAGMSG: {msg}",
)
self.assertIn(
"+draft/react",
m.tags,
m,
fail_msg="React tag wasn't present on the source user's TAGMSG: {msg}",
)
self.assertEqual(
m2.tags["+draft/react"],
"l😃l",
m,
fail_msg="React tag wasn't the same on the source user's TAGMSG: {msg}",
) )
@cases.mark_capabilities("echo-message", "labeled-response", "message-tags") @cases.mark_capabilities("echo-message", "labeled-response", "message-tags")
@ -596,25 +357,7 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper
# ensure sender correctly receives msg # ensure sender correctly receives msg
self.assertMessageMatch( self.assertMessageMatch(
ms, command="TAGMSG", fail_msg="Got a message back that wasn't a TAGMSG" ms, command="TAGMSG", tags={"label": "12345", **ANYDICT}
)
self.assertIn(
"label",
ms.tags,
ms,
fail_msg=(
"When sending a TAGMSG with a label, "
"the source user should receive the label but didn't: {msg}"
),
)
self.assertEqual(
ms.tags["label"],
"12345",
ms,
fail_msg=(
"Echo'd label doesn't match the label we sent "
"(should be '12345'): {msg}"
),
) )
@cases.mark_capabilities("echo-message", "labeled-response", "message-tags") @cases.mark_capabilities("echo-message", "labeled-response", "message-tags")
@ -706,15 +449,10 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper
self.sendLine(1, "@label=98765 PING adhoctestline") self.sendLine(1, "@label=98765 PING adhoctestline")
# no BATCH should be initiated for a one-line response, # no BATCH should be initiated for a one-line response,
# it should just be labeled # it should just be labeled
ms = self.getMessages(1) m = self.getMessage(1)
self.assertEqual(len(ms), 1) self.assertMessageMatch(m, command="PONG", tags={"label": "98765"})
m = ms[0]
self.assertEqual(m.command, "PONG")
self.assertEqual(m.params[-1], "adhoctestline") self.assertEqual(m.params[-1], "adhoctestline")
# check the label
self.assertEqual(m.tags.get("label"), "98765")
@cases.mark_capabilities("labeled-response") @cases.mark_capabilities("labeled-response")
def testEmptyBatchForNoResponse(self): def testEmptyBatchForNoResponse(self):
self.connectClient( self.connectClient(
@ -732,5 +470,4 @@ class LabeledResponsesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper
self.assertEqual(len(ms), 1) self.assertEqual(len(ms), 1)
ack = ms[0] ack = ms[0]
self.assertEqual(ack.command, "ACK") self.assertMessageMatch(ack, command="ACK", tags={"label": "98765"})
self.assertEqual(ack.tags.get("label"), "98765")

View File

@ -5,6 +5,7 @@ https://ircv3.net/specs/extensions/message-tags.html
from irctest import cases from irctest import cases
from irctest.irc_utils.message_parser import parse_message from irctest.irc_utils.message_parser import parse_message
from irctest.numerics import ERR_INPUTTOOLONG from irctest.numerics import ERR_INPUTTOOLONG
from irctest.patma import ANYDICT, ANYSTR, StrRe
class MessageTagsTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): class MessageTagsTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@ -40,9 +41,12 @@ class MessageTagsTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
self.getMessages("alice") self.getMessages("alice")
bob_msg = self.getMessage("bob") bob_msg = self.getMessage("bob")
carol_line = self.getMessage("carol", raw=True) carol_line = self.getMessage("carol", raw=True)
self.assertMessageMatch(bob_msg, command="PRIVMSG", params=["#test", "hi"]) self.assertMessageMatch(
self.assertEqual(bob_msg.tags["+baz"], "bat") bob_msg,
self.assertIn("msgid", bob_msg.tags) command="PRIVMSG",
params=["#test", "hi"],
tags={"+baz": "bat", "msgid": ANYSTR, **ANYDICT},
)
# should not relay a non-client-only tag # should not relay a non-client-only tag
self.assertNotIn("fizz", bob_msg.tags) self.assertNotIn("fizz", bob_msg.tags)
# carol MUST NOT receive tags # carol MUST NOT receive tags
@ -50,7 +54,12 @@ class MessageTagsTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
self.assertMessageMatch(carol_msg, command="PRIVMSG", params=["#test", "hi"]) self.assertMessageMatch(carol_msg, command="PRIVMSG", params=["#test", "hi"])
# dave SHOULD receive server-time tag # dave SHOULD receive server-time tag
dave_msg = self.getMessage("dave") dave_msg = self.getMessage("dave")
self.assertIn("time", dave_msg.tags) self.assertMessageMatch(
dave_msg,
command="PRIVMSG",
params=["#test", "hi"],
tags={"time": ANYSTR, **ANYDICT},
)
# dave MUST NOT receive client-only tags # dave MUST NOT receive client-only tags
self.assertNotIn("+baz", dave_msg.tags) self.assertNotIn("+baz", dave_msg.tags)
getAllMessages() getAllMessages()
@ -60,14 +69,18 @@ class MessageTagsTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
alice_msg = self.getMessage("alice") alice_msg = self.getMessage("alice")
carol_line = self.getMessage("carol", raw=True) carol_line = self.getMessage("carol", raw=True)
carol_msg = assertNoTags(carol_line) carol_msg = assertNoTags(carol_line)
for msg in [alice_msg, bob_msg, carol_msg]:
self.assertMessageMatch(
msg, command="PRIVMSG", params=["#test", "hi yourself"]
)
for msg in [alice_msg, bob_msg]: for msg in [alice_msg, bob_msg]:
self.assertEqual(msg.tags["+bat"], "baz") self.assertMessageMatch(
self.assertEqual(msg.tags["+fizz"], "buzz") msg,
self.assertTrue(alice_msg.tags["msgid"]) command="PRIVMSG",
params=["#test", "hi yourself"],
tags={"+bat": "baz", "+fizz": "buzz", "msgid": ANYSTR, **ANYDICT},
)
self.assertMessageMatch(
carol_msg,
command="PRIVMSG",
params=["#test", "hi yourself"],
)
self.assertEqual(alice_msg.tags["msgid"], bob_msg.tags["msgid"]) self.assertEqual(alice_msg.tags["msgid"], bob_msg.tags["msgid"])
getAllMessages() getAllMessages()
@ -80,11 +93,18 @@ class MessageTagsTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
# dave MUST NOT receive TAGMSG either, despite having server-time # dave MUST NOT receive TAGMSG either, despite having server-time
self.assertEqual(self.getMessages("dave"), []) self.assertEqual(self.getMessages("dave"), [])
for msg in [alice_msg, bob_msg]: for msg in [alice_msg, bob_msg]:
self.assertMessageMatch(alice_msg, command="TAGMSG", params=["#test"]) self.assertMessageMatch(
self.assertEqual(msg.tags["+buzz"], "fizz;buzz") alice_msg,
self.assertEqual(msg.tags["+steel"], "wootz") command="TAGMSG",
params=["#test"],
tags={
"+buzz": "fizz;buzz",
"+steel": "wootz",
"msgid": ANYSTR,
**ANYDICT,
},
)
self.assertNotIn("cat", msg.tags) self.assertNotIn("cat", msg.tags)
self.assertTrue(alice_msg.tags["msgid"])
self.assertEqual(alice_msg.tags["msgid"], bob_msg.tags["msgid"]) self.assertEqual(alice_msg.tags["msgid"], bob_msg.tags["msgid"])
@cases.mark_capabilities("message-tags") @cases.mark_capabilities("message-tags")
@ -108,12 +128,19 @@ class MessageTagsTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
self.sendLine("alice", max_tagmsg) self.sendLine("alice", max_tagmsg)
echo = self.getMessage("alice") echo = self.getMessage("alice")
relay = self.getMessage("bob") relay = self.getMessage("bob")
self.assertMessageMatch(echo, command="TAGMSG", params=["#test"]) self.assertMessageMatch(
self.assertMessageMatch(relay, command="TAGMSG", params=["#test"]) echo,
self.assertNotEqual(echo.tags["msgid"], "") command="TAGMSG",
params=["#test"],
tags={"+baz": "a" * 4081, "msgid": StrRe(".+"), **ANYDICT},
)
self.assertMessageMatch(
relay,
command="TAGMSG",
params=["#test"],
tags={"+baz": "a" * 4081, "msgid": StrRe(".+"), **ANYDICT},
)
self.assertEqual(echo.tags["msgid"], relay.tags["msgid"]) self.assertEqual(echo.tags["msgid"], relay.tags["msgid"])
self.assertEqual(echo.tags["+baz"], "a" * 4081)
self.assertEqual(relay.tags["+baz"], echo.tags["+baz"])
excess_tagmsg = "@foo=bar;+baz=%s TAGMSG #test" % ("a" * 4082,) excess_tagmsg = "@foo=bar;+baz=%s TAGMSG #test" % ("a" * 4082,)
self.assertEqual(excess_tagmsg.index("TAGMSG"), 4097) self.assertEqual(excess_tagmsg.index("TAGMSG"), 4097)
@ -128,10 +155,19 @@ class MessageTagsTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
self.sendLine("alice", max_privmsg) self.sendLine("alice", max_privmsg)
echo = self.getMessage("alice") echo = self.getMessage("alice")
relay = self.getMessage("bob") relay = self.getMessage("bob")
self.assertNotEqual(echo.tags["msgid"], "") self.assertMessageMatch(
echo,
command="PRIVMSG",
params=["#test", StrRe("b{400,496}")],
tags={"+baz": "a" * 4081, "msgid": StrRe(".+"), **ANYDICT},
)
self.assertMessageMatch(
relay,
command="PRIVMSG",
params=["#test", StrRe("b{400,496}")],
tags={"+baz": "a" * 4081, "msgid": StrRe(".+"), **ANYDICT},
)
self.assertEqual(echo.tags["msgid"], relay.tags["msgid"]) self.assertEqual(echo.tags["msgid"], relay.tags["msgid"])
self.assertEqual(echo.tags["+baz"], "a" * 4081)
self.assertEqual(relay.tags["+baz"], echo.tags["+baz"])
# message may have been truncated # message may have been truncated
self.assertIn("b" * 400, echo.params[1]) self.assertIn("b" * 400, echo.params[1])
self.assertEqual(echo.params[1].rstrip("b"), "") self.assertEqual(echo.params[1].rstrip("b"), "")

View File

@ -3,7 +3,7 @@ draft/multiline
""" """
from irctest import cases from irctest import cases
from irctest.patma import StrRe from irctest.patma import ANYDICT, StrRe
CAP_NAME = "draft/multiline" CAP_NAME = "draft/multiline"
BATCH_TYPE = "draft/multiline" BATCH_TYPE = "draft/multiline"
@ -37,7 +37,10 @@ class MultilineTestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
echo = self.getMessages(1) echo = self.getMessages(1)
batchStart, batchEnd = echo[0], echo[-1] batchStart, batchEnd = echo[0], echo[-1]
self.assertMessageMatch( self.assertMessageMatch(
batchStart, command="BATCH", params=[StrRe(r"\+.*"), BATCH_TYPE, "#test"] batchStart,
command="BATCH",
params=[StrRe(r"\+.*"), BATCH_TYPE, "#test"],
tags={"label": "xyz", **ANYDICT},
) )
self.assertEqual(batchStart.tags.get("label"), "xyz") self.assertEqual(batchStart.tags.get("label"), "xyz")
self.assertMessageMatch(batchEnd, command="BATCH", params=[StrRe("-.*")]) self.assertMessageMatch(batchEnd, command="BATCH", params=[StrRe("-.*")])

View File

@ -4,6 +4,7 @@ Regression tests for bugs in oragono.
from irctest import cases from irctest import cases
from irctest.numerics import ERR_ERRONEUSNICKNAME, ERR_NICKNAMEINUSE, RPL_WELCOME from irctest.numerics import ERR_ERRONEUSNICKNAME, ERR_NICKNAMEINUSE, RPL_WELCOME
from irctest.patma import ANYDICT
class RegressionsTestCase(cases.BaseServerTestCase): class RegressionsTestCase(cases.BaseServerTestCase):
@ -67,19 +68,19 @@ class RegressionsTestCase(cases.BaseServerTestCase):
self.sendLine( self.sendLine(
1, "@+draft/reply=ct95w3xemz8qj9du2h74wp8pee PRIVMSG bob :hey yourself" 1, "@+draft/reply=ct95w3xemz8qj9du2h74wp8pee PRIVMSG bob :hey yourself"
) )
ms = self.getMessages(1)
self.assertEqual(len(ms), 1)
self.assertMessageMatch( self.assertMessageMatch(
ms[0], command="PRIVMSG", params=["bob", "hey yourself"] self.getMessage(1),
command="PRIVMSG",
params=["bob", "hey yourself"],
tags={"+draft/reply": "ct95w3xemz8qj9du2h74wp8pee", **ANYDICT},
) )
self.assertEqual(ms[0].tags.get("+draft/reply"), "ct95w3xemz8qj9du2h74wp8pee")
ms = self.getMessages(2)
self.assertEqual(len(ms), 1)
self.assertMessageMatch( self.assertMessageMatch(
ms[0], command="PRIVMSG", params=["bob", "hey yourself"] self.getMessage(2),
command="PRIVMSG",
params=["bob", "hey yourself"],
tags={},
) )
self.assertEqual(ms[0].tags, {})
self.sendLine(2, "CAP REQ :message-tags server-time") self.sendLine(2, "CAP REQ :message-tags server-time")
self.getMessages(2) self.getMessages(2)
@ -87,11 +88,13 @@ class RegressionsTestCase(cases.BaseServerTestCase):
1, "@+draft/reply=tbxqauh9nykrtpa3n6icd9whan PRIVMSG bob :hey again" 1, "@+draft/reply=tbxqauh9nykrtpa3n6icd9whan PRIVMSG bob :hey again"
) )
self.getMessages(1) self.getMessages(1)
ms = self.getMessages(2)
# now bob has the tags cap, so he should receive the tags # now bob has the tags cap, so he should receive the tags
self.assertEqual(len(ms), 1) self.assertMessageMatch(
self.assertMessageMatch(ms[0], command="PRIVMSG", params=["bob", "hey again"]) self.getMessage(2),
self.assertEqual(ms[0].tags.get("+draft/reply"), "tbxqauh9nykrtpa3n6icd9whan") command="PRIVMSG",
params=["bob", "hey again"],
tags={"+draft/reply": "tbxqauh9nykrtpa3n6icd9whan", **ANYDICT},
)
@cases.mark_specifications("RFC1459") @cases.mark_specifications("RFC1459")
def testStarNick(self): def testStarNick(self):

View File

@ -18,8 +18,6 @@ class RelaymsgTestCase(cases.BaseServerTestCase):
"baz", "baz",
name="baz", name="baz",
capabilities=[ capabilities=[
"server-time",
"message-tags",
"batch", "batch",
"labeled-response", "labeled-response",
"echo-message", "echo-message",
@ -31,8 +29,6 @@ class RelaymsgTestCase(cases.BaseServerTestCase):
"qux", "qux",
name="qux", name="qux",
capabilities=[ capabilities=[
"server-time",
"message-tags",
"batch", "batch",
"labeled-response", "labeled-response",
"echo-message", "echo-message",
@ -74,9 +70,12 @@ class RelaymsgTestCase(cases.BaseServerTestCase):
self.sendLine("baz", "@label=x RELAYMSG %s smt/discord :hi again" % (chname,)) self.sendLine("baz", "@label=x RELAYMSG %s smt/discord :hi again" % (chname,))
response = self.getMessages("baz")[0] response = self.getMessages("baz")[0]
self.assertMessageMatch( self.assertMessageMatch(
response, nick="smt/discord", command="PRIVMSG", params=[chname, "hi again"] response,
nick="smt/discord",
command="PRIVMSG",
params=[chname, "hi again"],
tags={"label": "x"},
) )
self.assertEqual(response.tags.get("label"), "x")
relayed_msg = self.getMessages("qux")[0] relayed_msg = self.getMessages("qux")[0]
self.assertMessageMatch( self.assertMessageMatch(
relayed_msg, relayed_msg,
@ -118,8 +117,8 @@ class RelaymsgTestCase(cases.BaseServerTestCase):
nick="smt/discord", nick="smt/discord",
command="PRIVMSG", command="PRIVMSG",
params=[chname, "hi a third time"], params=[chname, "hi a third time"],
tags={RELAYMSG_TAG_NAME: "qux"},
) )
self.assertEqual(relayed_msg.tags.get(RELAYMSG_TAG_NAME), "qux")
self.sendLine("baz", "CHATHISTORY LATEST %s * 10" % (chname,)) self.sendLine("baz", "CHATHISTORY LATEST %s * 10" % (chname,))
messages = self.getMessages("baz") messages = self.getMessages("baz")

View File

@ -6,6 +6,7 @@ import secrets
from irctest import cases from irctest import cases
from irctest.numerics import RPL_AWAY from irctest.numerics import RPL_AWAY
from irctest.patma import ANYDICT, ANYSTR
ANCIENT_TIMESTAMP = "2006-01-02T15:04:05.999Z" ANCIENT_TIMESTAMP = "2006-01-02T15:04:05.999Z"
@ -59,7 +60,10 @@ class ResumeTestCase(cases.BaseServerTestCase):
privmsgs[0], command="PRIVMSG", params=[chname, "hello friends"] privmsgs[0], command="PRIVMSG", params=[chname, "hello friends"]
) )
self.assertMessageMatch( self.assertMessageMatch(
privmsgs[1], command="PRIVMSG", params=["mainnick", "hello friend singular"] privmsgs[1],
command="PRIVMSG",
params=["mainnick", "hello friend singular"],
tags={"time": ANYSTR, **ANYDICT},
) )
channelMsgTime = privmsgs[0].tags.get("time") channelMsgTime = privmsgs[0].tags.get("time")
@ -121,15 +125,17 @@ class ResumeTestCase(cases.BaseServerTestCase):
self.assertEqual(len(privmsgs), 2) self.assertEqual(len(privmsgs), 2)
privmsgs.sort(key=lambda m: m.params[0]) privmsgs.sort(key=lambda m: m.params[0])
self.assertMessageMatch( self.assertMessageMatch(
privmsgs[0], command="PRIVMSG", params=[chname, "hello friends"] privmsgs[0],
) command="PRIVMSG",
self.assertMessageMatch( params=[chname, "hello friends"],
privmsgs[1], command="PRIVMSG", params=["mainnick", "hello friend singular"] tags={"time": channelMsgTime, **ANYDICT},
) )
# should replay with the original server-time # should replay with the original server-time
# TODO this probably isn't testing anything because the timestamp only # TODO this probably isn't testing anything because the timestamp only
# has second resolution, hence will typically match by accident # has second resolution, hence will typically match by accident
self.assertEqual(privmsgs[0].tags.get("time"), channelMsgTime) self.assertMessageMatch(
privmsgs[1], command="PRIVMSG", params=["mainnick", "hello friend singular"]
)
# legacy client should receive a QUIT and a JOIN # legacy client should receive a QUIT and a JOIN
quit, join = [m for m in self.getMessages(1) if m.command in ("QUIT", "JOIN")] quit, join = [m for m in self.getMessages(1) if m.command in ("QUIT", "JOIN")]

View File

@ -1,4 +1,5 @@
from irctest import cases from irctest import cases
from irctest.patma import ANYSTR
class Utf8TestCase(cases.BaseServerTestCase, cases.OptionalityHelper): class Utf8TestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
@ -6,7 +7,7 @@ class Utf8TestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
def testUtf8Validation(self): def testUtf8Validation(self):
self.connectClient( self.connectClient(
"bar", "bar",
capabilities=["batch", "echo-message", "labeled-response", "message-tags"], capabilities=["batch", "echo-message", "labeled-response"],
) )
self.joinChannel(1, "#qux") self.joinChannel(1, "#qux")
self.sendLine(1, "PRIVMSG #qux hi") self.sendLine(1, "PRIVMSG #qux hi")
@ -16,14 +17,17 @@ class Utf8TestCase(cases.BaseServerTestCase, cases.OptionalityHelper):
) )
self.sendLine(1, b"PRIVMSG #qux hi\xaa") self.sendLine(1, b"PRIVMSG #qux hi\xaa")
ms = self.getMessages(1) self.assertMessageMatch(
self.assertEqual(len(ms), 1) self.getMessage(1),
self.assertEqual(ms[0].command, "FAIL") command="FAIL",
self.assertEqual(ms[0].params[:2], ["PRIVMSG", "INVALID_UTF8"]) params=["PRIVMSG", "INVALID_UTF8", ANYSTR],
tags={},
)
self.sendLine(1, b"@label=xyz PRIVMSG #qux hi\xaa") self.sendLine(1, b"@label=xyz PRIVMSG #qux hi\xaa")
ms = self.getMessages(1) self.assertMessageMatch(
self.assertEqual(len(ms), 1) self.getMessage(1),
self.assertEqual(ms[0].command, "FAIL") command="FAIL",
self.assertEqual(ms[0].params[:2], ["PRIVMSG", "INVALID_UTF8"]) params=["PRIVMSG", "INVALID_UTF8", ANYSTR],
self.assertEqual(ms[0].tags.get("label"), "xyz") tags={"label": "xyz"},
)