mirror of
https://github.com/progval/irctest.git
synced 2025-04-05 23:09:48 +00:00
Add tests for JOIN with some invalid channels in the target param (#163)
This commit is contained in:
@ -709,6 +709,12 @@ class BaseServerTestCase(
|
|||||||
self.server_support[param] = None
|
self.server_support[param] = None
|
||||||
welcome.append(m)
|
welcome.append(m)
|
||||||
|
|
||||||
|
self.targmax: Dict[str, Optional[str]] = dict(
|
||||||
|
item.split(":", 1) # type: ignore
|
||||||
|
for item in (self.server_support.get("TARGMAX") or "").split(",")
|
||||||
|
if item
|
||||||
|
)
|
||||||
|
|
||||||
return welcome
|
return welcome
|
||||||
|
|
||||||
def joinClient(self, client: TClientName, channel: str) -> None:
|
def joinClient(self, client: TClientName, channel: str) -> None:
|
||||||
|
@ -142,6 +142,7 @@ ERR_USERONCHANNEL = "443"
|
|||||||
ERR_NOLOGIN = "444"
|
ERR_NOLOGIN = "444"
|
||||||
ERR_SUMMONDISABLED = "445"
|
ERR_SUMMONDISABLED = "445"
|
||||||
ERR_USERSDISABLED = "446"
|
ERR_USERSDISABLED = "446"
|
||||||
|
ERR_FORBIDDENCHANNEL = "448"
|
||||||
ERR_NOTREGISTERED = "451"
|
ERR_NOTREGISTERED = "451"
|
||||||
ERR_NEEDMOREPARAMS = "461"
|
ERR_NEEDMOREPARAMS = "461"
|
||||||
ERR_ALREADYREGISTRED = "462"
|
ERR_ALREADYREGISTRED = "462"
|
||||||
|
@ -5,8 +5,26 @@ The JOIN command (`RFC 1459
|
|||||||
`Modern <https://modern.ircdocs.horse/#join-message>`__)
|
`Modern <https://modern.ircdocs.horse/#join-message>`__)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from irctest import cases
|
from irctest import cases, runner
|
||||||
from irctest.irc_utils import ambiguities
|
from irctest.irc_utils import ambiguities
|
||||||
|
from irctest.numerics import (
|
||||||
|
ERR_BADCHANMASK,
|
||||||
|
ERR_FORBIDDENCHANNEL,
|
||||||
|
ERR_NOSUCHCHANNEL,
|
||||||
|
RPL_ENDOFNAMES,
|
||||||
|
RPL_NAMREPLY,
|
||||||
|
)
|
||||||
|
from irctest.patma import ANYSTR, StrRe
|
||||||
|
|
||||||
|
ERR_BADCHANNAME = "479" # Hybrid only, and conflicts with others
|
||||||
|
|
||||||
|
|
||||||
|
JOIN_ERROR_NUMERICS = {
|
||||||
|
ERR_BADCHANMASK,
|
||||||
|
ERR_NOSUCHCHANNEL,
|
||||||
|
ERR_FORBIDDENCHANNEL,
|
||||||
|
ERR_BADCHANNAME,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class JoinTestCase(cases.BaseServerTestCase):
|
class JoinTestCase(cases.BaseServerTestCase):
|
||||||
@ -26,13 +44,22 @@ class JoinTestCase(cases.BaseServerTestCase):
|
|||||||
self.connectClient("foo")
|
self.connectClient("foo")
|
||||||
self.sendLine(1, "JOIN #chan")
|
self.sendLine(1, "JOIN #chan")
|
||||||
received_commands = {m.command for m in self.getMessages(1)}
|
received_commands = {m.command for m in self.getMessages(1)}
|
||||||
expected_commands = {"353", "366"} # RPL_NAMREPLY # RPL_ENDOFNAMES
|
expected_commands = {RPL_NAMREPLY, RPL_ENDOFNAMES, "JOIN"}
|
||||||
self.assertTrue(
|
acceptable_commands = expected_commands | {"MODE"}
|
||||||
expected_commands.issubset(received_commands),
|
self.assertLessEqual( # set inclusion
|
||||||
|
expected_commands,
|
||||||
|
received_commands,
|
||||||
"Server sent {} commands, but at least {} were expected.".format(
|
"Server sent {} commands, but at least {} were expected.".format(
|
||||||
received_commands, expected_commands
|
received_commands, expected_commands
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
self.assertLessEqual( # ditto
|
||||||
|
received_commands,
|
||||||
|
acceptable_commands,
|
||||||
|
"Server sent {} commands, but only {} were expected.".format(
|
||||||
|
received_commands, acceptable_commands
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@cases.mark_specifications("RFC2812")
|
@cases.mark_specifications("RFC2812")
|
||||||
def testJoinNamreply(self):
|
def testJoinNamreply(self):
|
||||||
@ -117,3 +144,95 @@ class JoinTestCase(cases.BaseServerTestCase):
|
|||||||
'"foo" with an optional "+" or "@" prefix, but got: '
|
'"foo" with an optional "+" or "@" prefix, but got: '
|
||||||
"{msg}",
|
"{msg}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def testJoinPartiallyInvalid(self):
|
||||||
|
"""TODO: specify this in Modern"""
|
||||||
|
self.connectClient("foo")
|
||||||
|
if int(self.targmax.get("JOIN") or "4") < 2:
|
||||||
|
raise runner.OptionalExtensionNotSupported("multi-channel JOIN")
|
||||||
|
|
||||||
|
self.sendLine(1, "JOIN #valid,inv@lid")
|
||||||
|
messages = self.getMessages(1)
|
||||||
|
received_commands = {m.command for m in messages}
|
||||||
|
expected_commands = {RPL_NAMREPLY, RPL_ENDOFNAMES, "JOIN"}
|
||||||
|
acceptable_commands = expected_commands | JOIN_ERROR_NUMERICS | {"MODE"}
|
||||||
|
self.assertLessEqual(
|
||||||
|
expected_commands,
|
||||||
|
received_commands,
|
||||||
|
"Server sent {} commands, but at least {} were expected.".format(
|
||||||
|
received_commands, expected_commands
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertLessEqual(
|
||||||
|
received_commands,
|
||||||
|
acceptable_commands,
|
||||||
|
"Server sent {} commands, but only {} were expected.".format(
|
||||||
|
received_commands, acceptable_commands
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
nb_errors = 0
|
||||||
|
for m in messages:
|
||||||
|
if m.command in JOIN_ERROR_NUMERICS:
|
||||||
|
nb_errors += 1
|
||||||
|
self.assertMessageMatch(m, params=["foo", "inv@lid", ANYSTR])
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nb_errors,
|
||||||
|
1,
|
||||||
|
fail_msg="Expected 1 error when joining channels '#valid' and 'inv@lid', "
|
||||||
|
"got {got}",
|
||||||
|
)
|
||||||
|
|
||||||
|
@cases.mark_capabilities("batch", "labeled-response")
|
||||||
|
def testJoinPartiallyInvalidLabeledResponse(self):
|
||||||
|
"""TODO: specify this in Modern"""
|
||||||
|
self.connectClient(
|
||||||
|
"foo", capabilities=["batch", "labeled-response"], skip_if_cap_nak=True
|
||||||
|
)
|
||||||
|
if int(self.targmax.get("JOIN") or "4") < 2:
|
||||||
|
raise runner.OptionalExtensionNotSupported("multi-channel JOIN")
|
||||||
|
|
||||||
|
self.sendLine(1, "@label=label1 JOIN #valid,inv@lid")
|
||||||
|
messages = self.getMessages(1)
|
||||||
|
|
||||||
|
first_msg = messages.pop(0)
|
||||||
|
last_msg = messages.pop(-1)
|
||||||
|
self.assertMessageMatch(
|
||||||
|
first_msg, command="BATCH", params=[StrRe(r"\+.*"), "labeled-response"]
|
||||||
|
)
|
||||||
|
batch_id = first_msg.params[0][1:]
|
||||||
|
self.assertMessageMatch(last_msg, command="BATCH", params=["-" + batch_id])
|
||||||
|
|
||||||
|
received_commands = {m.command for m in messages}
|
||||||
|
expected_commands = {RPL_NAMREPLY, RPL_ENDOFNAMES, "JOIN"}
|
||||||
|
acceptable_commands = expected_commands | JOIN_ERROR_NUMERICS | {"MODE"}
|
||||||
|
self.assertLessEqual(
|
||||||
|
expected_commands,
|
||||||
|
received_commands,
|
||||||
|
"Server sent {} commands, but at least {} were expected.".format(
|
||||||
|
received_commands, expected_commands
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertLessEqual(
|
||||||
|
received_commands,
|
||||||
|
acceptable_commands,
|
||||||
|
"Server sent {} commands, but only {} were expected.".format(
|
||||||
|
received_commands, acceptable_commands
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
nb_errors = 0
|
||||||
|
for m in messages:
|
||||||
|
self.assertIn("batch", m.tags)
|
||||||
|
self.assertEqual(m.tags["batch"], batch_id)
|
||||||
|
if m.command in JOIN_ERROR_NUMERICS:
|
||||||
|
nb_errors += 1
|
||||||
|
self.assertMessageMatch(m, params=["foo", "inv@lid", ANYSTR])
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nb_errors,
|
||||||
|
1,
|
||||||
|
fail_msg="Expected 1 error when joining channels '#valid' and 'inv@lid', "
|
||||||
|
"got {got}",
|
||||||
|
)
|
||||||
|
@ -230,12 +230,7 @@ class KickTestCase(cases.BaseServerTestCase):
|
|||||||
self.connectClient("qux")
|
self.connectClient("qux")
|
||||||
self.joinChannel(4, "#chan")
|
self.joinChannel(4, "#chan")
|
||||||
|
|
||||||
targmax = dict(
|
if self.targmax.get("KICK", "1") == "1":
|
||||||
item.split(":", 1)
|
|
||||||
for item in self.server_support.get("TARGMAX", "").split(",")
|
|
||||||
if item
|
|
||||||
)
|
|
||||||
if targmax.get("KICK", "1") == "1":
|
|
||||||
raise runner.OptionalExtensionNotSupported("Multi-target KICK")
|
raise runner.OptionalExtensionNotSupported("Multi-target KICK")
|
||||||
|
|
||||||
# TODO: check foo is an operator
|
# TODO: check foo is an operator
|
||||||
|
@ -62,12 +62,7 @@ class NamesTestCase(cases.BaseServerTestCase):
|
|||||||
def _testNamesMultipleChannels(self, symbol):
|
def _testNamesMultipleChannels(self, symbol):
|
||||||
self.connectClient("nick1")
|
self.connectClient("nick1")
|
||||||
|
|
||||||
targmax = dict(
|
if self.targmax.get("NAMES", "1") == "1":
|
||||||
item.split(":", 1)
|
|
||||||
for item in self.server_support.get("TARGMAX", "").split(",")
|
|
||||||
if item
|
|
||||||
)
|
|
||||||
if targmax.get("NAMES", "1") == "1":
|
|
||||||
raise runner.OptionalExtensionNotSupported("Multi-target NAMES")
|
raise runner.OptionalExtensionNotSupported("Multi-target NAMES")
|
||||||
|
|
||||||
self.sendLine(1, "JOIN #chan1")
|
self.sendLine(1, "JOIN #chan1")
|
||||||
|
@ -403,12 +403,7 @@ class WhowasTestCase(cases.BaseServerTestCase):
|
|||||||
|
|
||||||
self.connectClient("nick1")
|
self.connectClient("nick1")
|
||||||
|
|
||||||
targmax = dict(
|
if self.targmax.get("WHOWAS", "1") == "1":
|
||||||
item.split(":", 1)
|
|
||||||
for item in self.server_support.get("TARGMAX", "").split(",")
|
|
||||||
if item
|
|
||||||
)
|
|
||||||
if targmax.get("WHOWAS", "1") == "1":
|
|
||||||
raise runner.OptionalExtensionNotSupported("Multi-target WHOWAS")
|
raise runner.OptionalExtensionNotSupported("Multi-target WHOWAS")
|
||||||
|
|
||||||
self.connectClient("nick2", ident="ident2")
|
self.connectClient("nick2", ident="ident2")
|
||||||
|
Reference in New Issue
Block a user