diff --git a/irctest/cases.py b/irctest/cases.py index 18b4a38..7b37815 100644 --- a/irctest/cases.py +++ b/irctest/cases.py @@ -709,6 +709,12 @@ class BaseServerTestCase( self.server_support[param] = None welcome.append(m) + self.targmax: Dict[str, Optional[str]] = dict( + item.split(":", 1) # type: ignore + for item in (self.server_support.get("TARGMAX") or "").split(",") + if item + ) + return welcome def joinClient(self, client: TClientName, channel: str) -> None: diff --git a/irctest/numerics.py b/irctest/numerics.py index 8aafb08..8dc9736 100644 --- a/irctest/numerics.py +++ b/irctest/numerics.py @@ -142,6 +142,7 @@ ERR_USERONCHANNEL = "443" ERR_NOLOGIN = "444" ERR_SUMMONDISABLED = "445" ERR_USERSDISABLED = "446" +ERR_FORBIDDENCHANNEL = "448" ERR_NOTREGISTERED = "451" ERR_NEEDMOREPARAMS = "461" ERR_ALREADYREGISTRED = "462" diff --git a/irctest/server_tests/join.py b/irctest/server_tests/join.py index e3433e6..833f9c1 100644 --- a/irctest/server_tests/join.py +++ b/irctest/server_tests/join.py @@ -5,8 +5,26 @@ The JOIN command (`RFC 1459 `Modern `__) """ -from irctest import cases +from irctest import cases, runner from irctest.irc_utils import ambiguities +from irctest.numerics import ( + ERR_BADCHANMASK, + ERR_FORBIDDENCHANNEL, + ERR_NOSUCHCHANNEL, + RPL_ENDOFNAMES, + RPL_NAMREPLY, +) +from irctest.patma import ANYSTR, StrRe + +ERR_BADCHANNAME = "479" # Hybrid only, and conflicts with others + + +JOIN_ERROR_NUMERICS = { + ERR_BADCHANMASK, + ERR_NOSUCHCHANNEL, + ERR_FORBIDDENCHANNEL, + ERR_BADCHANNAME, +} class JoinTestCase(cases.BaseServerTestCase): @@ -26,13 +44,22 @@ class JoinTestCase(cases.BaseServerTestCase): self.connectClient("foo") self.sendLine(1, "JOIN #chan") received_commands = {m.command for m in self.getMessages(1)} - expected_commands = {"353", "366"} # RPL_NAMREPLY # RPL_ENDOFNAMES - self.assertTrue( - expected_commands.issubset(received_commands), + expected_commands = {RPL_NAMREPLY, RPL_ENDOFNAMES, "JOIN"} + acceptable_commands = expected_commands | {"MODE"} + self.assertLessEqual( # set inclusion + expected_commands, + received_commands, "Server sent {} commands, but at least {} were expected.".format( received_commands, expected_commands ), ) + self.assertLessEqual( # ditto + received_commands, + acceptable_commands, + "Server sent {} commands, but only {} were expected.".format( + received_commands, acceptable_commands + ), + ) @cases.mark_specifications("RFC2812") def testJoinNamreply(self): @@ -117,3 +144,95 @@ class JoinTestCase(cases.BaseServerTestCase): '"foo" with an optional "+" or "@" prefix, but got: ' "{msg}", ) + + def testJoinPartiallyInvalid(self): + """TODO: specify this in Modern""" + self.connectClient("foo") + if int(self.targmax.get("JOIN") or "4") < 2: + raise runner.OptionalExtensionNotSupported("multi-channel JOIN") + + self.sendLine(1, "JOIN #valid,inv@lid") + messages = self.getMessages(1) + received_commands = {m.command for m in messages} + expected_commands = {RPL_NAMREPLY, RPL_ENDOFNAMES, "JOIN"} + acceptable_commands = expected_commands | JOIN_ERROR_NUMERICS | {"MODE"} + self.assertLessEqual( + expected_commands, + received_commands, + "Server sent {} commands, but at least {} were expected.".format( + received_commands, expected_commands + ), + ) + self.assertLessEqual( + received_commands, + acceptable_commands, + "Server sent {} commands, but only {} were expected.".format( + received_commands, acceptable_commands + ), + ) + + nb_errors = 0 + for m in messages: + if m.command in JOIN_ERROR_NUMERICS: + nb_errors += 1 + self.assertMessageMatch(m, params=["foo", "inv@lid", ANYSTR]) + + self.assertEqual( + nb_errors, + 1, + fail_msg="Expected 1 error when joining channels '#valid' and 'inv@lid', " + "got {got}", + ) + + @cases.mark_capabilities("batch", "labeled-response") + def testJoinPartiallyInvalidLabeledResponse(self): + """TODO: specify this in Modern""" + self.connectClient( + "foo", capabilities=["batch", "labeled-response"], skip_if_cap_nak=True + ) + if int(self.targmax.get("JOIN") or "4") < 2: + raise runner.OptionalExtensionNotSupported("multi-channel JOIN") + + self.sendLine(1, "@label=label1 JOIN #valid,inv@lid") + messages = self.getMessages(1) + + first_msg = messages.pop(0) + last_msg = messages.pop(-1) + self.assertMessageMatch( + first_msg, command="BATCH", params=[StrRe(r"\+.*"), "labeled-response"] + ) + batch_id = first_msg.params[0][1:] + self.assertMessageMatch(last_msg, command="BATCH", params=["-" + batch_id]) + + received_commands = {m.command for m in messages} + expected_commands = {RPL_NAMREPLY, RPL_ENDOFNAMES, "JOIN"} + acceptable_commands = expected_commands | JOIN_ERROR_NUMERICS | {"MODE"} + self.assertLessEqual( + expected_commands, + received_commands, + "Server sent {} commands, but at least {} were expected.".format( + received_commands, expected_commands + ), + ) + self.assertLessEqual( + received_commands, + acceptable_commands, + "Server sent {} commands, but only {} were expected.".format( + received_commands, acceptable_commands + ), + ) + + nb_errors = 0 + for m in messages: + self.assertIn("batch", m.tags) + self.assertEqual(m.tags["batch"], batch_id) + if m.command in JOIN_ERROR_NUMERICS: + nb_errors += 1 + self.assertMessageMatch(m, params=["foo", "inv@lid", ANYSTR]) + + self.assertEqual( + nb_errors, + 1, + fail_msg="Expected 1 error when joining channels '#valid' and 'inv@lid', " + "got {got}", + ) diff --git a/irctest/server_tests/kick.py b/irctest/server_tests/kick.py index 06fda59..ac511de 100644 --- a/irctest/server_tests/kick.py +++ b/irctest/server_tests/kick.py @@ -230,12 +230,7 @@ class KickTestCase(cases.BaseServerTestCase): self.connectClient("qux") self.joinChannel(4, "#chan") - targmax = dict( - item.split(":", 1) - for item in self.server_support.get("TARGMAX", "").split(",") - if item - ) - if targmax.get("KICK", "1") == "1": + if self.targmax.get("KICK", "1") == "1": raise runner.OptionalExtensionNotSupported("Multi-target KICK") # TODO: check foo is an operator diff --git a/irctest/server_tests/names.py b/irctest/server_tests/names.py index 597bf8b..f45731a 100644 --- a/irctest/server_tests/names.py +++ b/irctest/server_tests/names.py @@ -62,12 +62,7 @@ class NamesTestCase(cases.BaseServerTestCase): def _testNamesMultipleChannels(self, symbol): self.connectClient("nick1") - targmax = dict( - item.split(":", 1) - for item in self.server_support.get("TARGMAX", "").split(",") - if item - ) - if targmax.get("NAMES", "1") == "1": + if self.targmax.get("NAMES", "1") == "1": raise runner.OptionalExtensionNotSupported("Multi-target NAMES") self.sendLine(1, "JOIN #chan1") diff --git a/irctest/server_tests/whowas.py b/irctest/server_tests/whowas.py index 79da7a4..76b250c 100644 --- a/irctest/server_tests/whowas.py +++ b/irctest/server_tests/whowas.py @@ -403,12 +403,7 @@ class WhowasTestCase(cases.BaseServerTestCase): self.connectClient("nick1") - targmax = dict( - item.split(":", 1) - for item in self.server_support.get("TARGMAX", "").split(",") - if item - ) - if targmax.get("WHOWAS", "1") == "1": + if self.targmax.get("WHOWAS", "1") == "1": raise runner.OptionalExtensionNotSupported("Multi-target WHOWAS") self.connectClient("nick2", ident="ident2")