diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3e89078..3b720a2 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,10 +13,10 @@ jobs: - uses: actions/checkout@v2 - - name: Set up Python 3.7 + - name: Set up Python 3.11 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.11 - name: Cache dependencies uses: actions/cache@v2 diff --git a/.github/workflows/test-stable.yml b/.github/workflows/test-stable.yml index b8c5f1d..92461d3 100644 --- a/.github/workflows/test-stable.yml +++ b/.github/workflows/test-stable.yml @@ -1111,7 +1111,7 @@ jobs: uses: actions/checkout@v3 with: path: sable - ref: ff1179512a79eba57ca468a5f83af84ecce08a5b + ref: 1e05b0ebaa76cf7aa6ce2c34ba50744d6abbe9b0 repository: Libera-Chat/sable - name: Install rust toolchain uses: actions-rs/toolchain@v1 diff --git a/irctest/patma.py b/irctest/patma.py index 42fa4ae..c121414 100644 --- a/irctest/patma.py +++ b/irctest/patma.py @@ -1,6 +1,7 @@ """Pattern-matching utilities""" import dataclasses +import itertools import re from typing import Dict, List, Optional, Union @@ -27,6 +28,14 @@ class _AnyOptStr(Operator): return "ANYOPTSTR" +@dataclasses.dataclass(frozen=True) +class OptStrRe(Operator): + regexp: str + + def __repr__(self) -> str: + return f"OptStrRe(r'{self.regexp}')" + + @dataclasses.dataclass(frozen=True) class StrRe(Operator): regexp: str @@ -99,6 +108,11 @@ def match_string(got: Optional[str], expected: Union[str, Operator, None]) -> bo elif isinstance(expected, StrRe): if got is None or not re.match(expected.regexp, got): return False + elif isinstance(expected, OptStrRe): + if got is None: + return True + if not re.match(expected.regexp, got): + return False elif isinstance(expected, NotStrRe): if got is None or re.match(expected.regexp, got): return False @@ -128,11 +142,19 @@ def match_list( nb_remaining_items = len(got) - len(expected) expected += [remainder.item] * max(nb_remaining_items, remainder.min_length) - if len(got) != len(expected): + nb_optionals = 0 + for expected_value in expected: + if isinstance(expected_value, (_AnyOptStr, OptStrRe)): + nb_optionals += 1 + else: + if nb_optionals > 0: + raise NotImplementedError("Optional values in non-final position") + + if not (len(expected) - nb_optionals <= len(got) <= len(expected)): return False return all( match_string(got_value, expected_value) - for (got_value, expected_value) in zip(got, expected) + for (got_value, expected_value) in itertools.zip_longest(got, expected) ) diff --git a/irctest/self_tests/cases.py b/irctest/self_tests/cases.py index 6f8920b..428a73c 100644 --- a/irctest/self_tests/cases.py +++ b/irctest/self_tests/cases.py @@ -13,6 +13,7 @@ from irctest.patma import ( ANYSTR, ListRemainder, NotStrRe, + OptStrRe, RemainingKeys, StrRe, ) @@ -240,6 +241,28 @@ MESSAGE_SPECS: List[Tuple[Dict, List[str], List[str], List[str]]] = [ "expected tags to match {'tag1': 'bar', RemainingKeys(NotStrRe(r'tag2')): ANYOPTSTR}, got {'tag1': 'bar', 'tag2': 'baz'}", ] ), + ( + # the specification: + dict( + command="004", + params=["nick", "...", OptStrRe("[a-zA-Z]+")], + ), + # matches: + [ + "004 nick ... abc", + "004 nick ...", + ], + # and does not match: + [ + "004 nick ... 123", + "004 nick ... :", + ], + # and they each error with: + [ + "expected params to match ['nick', '...', OptStrRe(r'[a-zA-Z]+')], got ['nick', '...', '123']", + "expected params to match ['nick', '...', OptStrRe(r'[a-zA-Z]+')], got ['nick', '...', '']", + ] + ), ( # the specification: dict( diff --git a/irctest/server_tests/connection_registration.py b/irctest/server_tests/connection_registration.py index a3d5038..7cfdddd 100644 --- a/irctest/server_tests/connection_registration.py +++ b/irctest/server_tests/connection_registration.py @@ -10,7 +10,7 @@ import time from irctest import cases from irctest.client_mock import ConnectionClosed from irctest.numerics import ERR_NEEDMOREPARAMS, ERR_PASSWDMISMATCH -from irctest.patma import ANYSTR, StrRe +from irctest.patma import ANYLIST, ANYSTR, OptStrRe, StrRe class PasswordedConnectionRegistrationTestCase(cases.BaseServerTestCase): @@ -85,6 +85,92 @@ class PasswordedConnectionRegistrationTestCase(cases.BaseServerTestCase): class ConnectionRegistrationTestCase(cases.BaseServerTestCase): + def testConnectionRegistration(self): + self.addClient() + self.sendLine(1, "NICK foo") + self.sendLine(1, "USER foo * * :foo") + + for numeric in ("001", "002", "003"): + self.assertMessageMatch( + self.getRegistrationMessage(1), + command=numeric, + params=["foo", ANYSTR], + ) + + self.assertMessageMatch( + self.getRegistrationMessage(1), + command="004", # RPL_MYINFO + params=[ + "foo", + "My.Little.Server", + ANYSTR, # version + StrRe("[a-zA-Z]+"), # user modes + StrRe("[a-zA-Z]+"), # channel modes + OptStrRe("[a-zA-Z]+"), # channel modes with parameter + ], + ) + + # ISUPPORT + m = self.getRegistrationMessage(1) + while True: + self.assertMessageMatch( + m, + command="005", + params=["foo", *ANYLIST], + ) + m = self.getRegistrationMessage(1) + if m.command != "005": + break + + if m.command in ("042", "396"): # RPL_YOURID / RPL_VISIBLEHOST, non-standard + m = self.getRegistrationMessage(1) + + # LUSERS + while m.command in ("250", "251", "252", "253", "254", "255", "265", "266"): + m = self.getRegistrationMessage(1) + + if m.command == "375": # RPL_MOTDSTART + self.assertMessageMatch( + m, + command="375", + params=["foo", ANYSTR], + ) + while (m := self.getRegistrationMessage(1)).command == "372": + self.assertMessageMatch( + m, + command="372", # RPL_MOTD + params=["foo", ANYSTR], + ) + self.assertMessageMatch( + m, + command="376", # RPL_ENDOFMOTD + params=["foo", ANYSTR], + ) + else: + self.assertMessageMatch( + m, + command="422", # ERR_NOMOTD + params=["foo", ANYSTR], + ) + + # User mode + if m.command == "MODE": + self.assertMessageMatch( + m, + command="MODE", + params=["foo", ANYSTR, *ANYLIST], + ) + m = self.getRegistrationMessage(1) + elif m.command == "221": # RPL_UMODEIS + self.assertMessageMatch( + m, + command="221", + params=["foo", ANYSTR, *ANYLIST], + ) + m = self.getRegistrationMessage(1) + else: + print("Warning: missing MODE") + @cases.mark_specifications("RFC1459") def testQuitDisconnects(self): """“The server must close the connection to a client which sends a diff --git a/mypy.ini b/mypy.ini index 840ecdc..f9ef983 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,5 @@ [mypy] -python_version = 3.7 +python_version = 3.8 warn_return_any = True warn_unused_configs = True diff --git a/pyproject.toml b/pyproject.toml index 583bb91..2ddd6ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.black] -target-version = ['py37'] +target-version = ['py38'] exclude = 'irctest/scram/*' [tool.isort] diff --git a/workflows.yml b/workflows.yml index b739260..e49e99e 100644 --- a/workflows.yml +++ b/workflows.yml @@ -254,7 +254,7 @@ software: name: Sable repository: Libera-Chat/sable refs: - stable: ff1179512a79eba57ca468a5f83af84ecce08a5b + stable: 1e05b0ebaa76cf7aa6ce2c34ba50744d6abbe9b0 release: null devel: master devel_release: null