Add a bunch of GetSet tests

This commit is contained in:
2023-04-15 19:02:04 +02:00
parent f18162c02d
commit 34f078e04a

View File

@ -11,13 +11,18 @@ import pytest
from irctest import cases, runner from irctest import cases, runner
from irctest.patma import ANYDICT, ANYSTR, StrRe from irctest.patma import ANYDICT, ANYSTR, StrRe
CLIENT_NICKS = {
1: "foo",
2: "bar",
}
class MetadataTestCase(cases.BaseServerTestCase): class MetadataTestCase(cases.BaseServerTestCase):
valid_metadata_keys = {"valid_key1", "valid_key2"} valid_metadata_keys = {"valid_key1", "valid_key2"}
invalid_metadata_keys = {"invalid_key1", "invalid_key2"} invalid_metadata_keys = {"invalid_key1", "invalid_key2"}
def getBatchMessages(self, client): def getBatchMessages(self, client):
messages = self.getMessages(1) messages = self.getMessages(client)
first_msg = messages.pop(0) first_msg = messages.pop(0)
last_msg = messages.pop(-1) last_msg = messages.pop(-1)
@ -123,47 +128,104 @@ class MetadataTestCase(cases.BaseServerTestCase):
"and then 762 (RPL_METADATAEND)", "and then 762 (RPL_METADATAEND)",
) )
def assertSetValue(self, target, key, value): def assertSetValue(self, client, target, key, value):
self.sendLine(1, "METADATA {} SET {} :{}".format(target, key, value)) self.sendLine(client, "METADATA {} SET {} :{}".format(target, key, value))
if target == "*": if target == "*":
target = StrRe(r"(\*|foo)") target = StrRe(r"(\*|" + CLIENT_NICKS[client] + ")")
self.assertMessageMatch( self.assertMessageMatch(
self.getMessage(1), self.getMessage(client),
command="761", # RPL_KEYVALUE command="761", # RPL_KEYVALUE
params=["foo", target, key, ANYSTR, value], params=[CLIENT_NICKS[client], target, key, ANYSTR, value],
) )
def assertGetValue(self, target, key, value): def assertGetValue(self, client, target, key, value):
self.sendLine(1, "METADATA {} GET {}".format(target, key)) self.sendLine(client, "METADATA {} GET {}".format(target, key))
if target == "*": if target == "*":
target = StrRe(r"(\*|foo)") target = StrRe(r"(\*|" + CLIENT_NICKS[client] + ")")
(batch_id, messages) = self.getBatchMessages(1) (batch_id, messages) = self.getBatchMessages(client)
self.assertEqual(len(messages), 1, fail_msg="Expected one RPL_KEYVALUE") self.assertEqual(len(messages), 1, fail_msg="Expected one RPL_KEYVALUE")
self.assertMessageMatch( self.assertMessageMatch(
messages[0], messages[0],
command="761", # RPL_KEYVALUE command="761", # RPL_KEYVALUE
params=["foo", target, key, ANYSTR, value], params=[CLIENT_NICKS[client], target, key, ANYSTR, value],
) )
def assertSetGetValue(self, target, key, value): def assertSetGetValue(self, client, target, key, value):
self.assertSetValue(target, key, value) self.assertSetValue(client, target, key, value)
self.assertGetValue(target, key, value) self.assertGetValue(client, target, key, value)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"set_target,get_target", itertools.product(["*", "foo"], ["*", "foo"]) "set_target,get_target", itertools.product(["*", "foo"], ["*", "foo"])
) )
@cases.mark_specifications("IRCv3") @cases.mark_specifications("IRCv3")
def testSetGetValid(self, set_target, get_target): def testSetGet(self, set_target, get_target):
"""<http://ircv3.net/specs/core/metadata-3.2.html>""" """<http://ircv3.net/specs/core/metadata-3.2.html>"""
self.connectClient( self.connectClient(
"foo", capabilities=["draft/metadata-2", "batch"], skip_if_cap_nak=True "foo", capabilities=["draft/metadata-2", "batch"], skip_if_cap_nak=True
) )
self.assertSetValue(set_target, "valid_key1", "myvalue") self.assertSetValue(1, set_target, "valid_key1", "myvalue")
self.assertGetValue(get_target, "valid_key1", "myvalue") self.assertGetValue(1, get_target, "valid_key1", "myvalue")
@cases.mark_specifications("IRCv3")
def testSetGetAgain(self):
"""<http://ircv3.net/specs/core/metadata-3.2.html>"""
self.connectClient(
"foo", capabilities=["draft/metadata-2", "batch"], skip_if_cap_nak=True
)
self.assertSetGetValue(1, "*", "valid_key1", "myvalue1")
self.assertSetGetValue(1, "*", "valid_key1", "myvalue2")
@cases.mark_specifications("IRCv3")
def testSetGetChannel(self):
self.connectClient(
"foo", capabilities=["draft/metadata-2", "batch"], skip_if_cap_nak=True
)
self.connectClient(
"bar", capabilities=["draft/metadata-2", "batch"], skip_if_cap_nak=True
)
self.sendLine(1, "JOIN #chan")
self.sendLine(2, "JOIN #chan")
self.getMessages(1)
self.getMessages(2)
self.getMessages(1)
self.assertSetGetValue(1, "#chan", "valid_key1", "myvalue1")
self.assertEqual(
self.getMessages(2),
[],
fail_msg="Unexpected messages after other user used METADATA SET: {got}",
)
self.assertGetValue(2, "#chan", "valid_key1", "myvalue1")
@cases.mark_specifications("IRCv3")
def testSetGetOtherUser(self):
self.connectClient(
"foo", capabilities=["draft/metadata-2", "batch"], skip_if_cap_nak=True
)
self.connectClient(
"bar", capabilities=["draft/metadata-2", "batch"], skip_if_cap_nak=True
)
# As of 2023-04-15, the Unreal module requires users to share a channel for
# metadata to be visible to each other
self.sendLine(1, "JOIN #chan")
self.sendLine(2, "JOIN #chan")
self.getMessages(1)
self.getMessages(2)
self.getMessages(1)
self.assertSetValue(1, "*", "valid_key1", "myvalue1")
self.assertEqual(
self.getMessages(2),
[],
fail_msg="Unexpected messages after other user used METADATA SET: {got}",
)
self.assertGetValue(2, "foo", "valid_key1", "myvalue1")
@cases.mark_specifications("IRCv3") @cases.mark_specifications("IRCv3")
def testSetGetValidBeforeConnect(self): def testSetGetValidBeforeConnect(self):
@ -179,14 +241,14 @@ class MetadataTestCase(cases.BaseServerTestCase):
self.requestCapabilities(1, ["draft/metadata-2", "batch"], skip_if_cap_nak=True) self.requestCapabilities(1, ["draft/metadata-2", "batch"], skip_if_cap_nak=True)
self.assertSetValue("*", "valid_key1", "myvalue") self.assertSetValue(1, "*", "valid_key1", "myvalue")
self.sendLine(1, "NICK foo") self.sendLine(1, "NICK foo")
self.sendLine(1, "USER foo 0 * :foo") self.sendLine(1, "USER foo 0 * :foo")
self.sendLine(1, "CAP END") self.sendLine(1, "CAP END")
self.skipToWelcome(1) self.skipToWelcome(1)
self.assertGetValue("*", "valid_key1", "myvalue") self.assertGetValue(1, "*", "valid_key1", "myvalue")
@cases.mark_specifications("IRCv3") @cases.mark_specifications("IRCv3")
def testSetGetHeartInValue(self): def testSetGetHeartInValue(self):
@ -198,6 +260,7 @@ class MetadataTestCase(cases.BaseServerTestCase):
"foo", capabilities=["draft/metadata-2", "batch"], skip_if_cap_nak=True "foo", capabilities=["draft/metadata-2", "batch"], skip_if_cap_nak=True
) )
self.assertSetGetValue( self.assertSetGetValue(
1,
"*", "*",
"valid_key1", "valid_key1",
"->{}<-".format(heart), "->{}<-".format(heart),