From 6243908ecc2d11864b183101429f36121f4fe378 Mon Sep 17 00:00:00 2001 From: Valentin Lorentz Date: Tue, 21 Mar 2023 19:58:39 +0100 Subject: [PATCH] Add tests for SASL-IR --- irctest/server_tests/sasl.py | 146 ++++++++++++++++++++--------------- 1 file changed, 85 insertions(+), 61 deletions(-) diff --git a/irctest/server_tests/sasl.py b/irctest/server_tests/sasl.py index 600f959..1e460e9 100644 --- a/irctest/server_tests/sasl.py +++ b/irctest/server_tests/sasl.py @@ -1,4 +1,5 @@ import base64 +from typing import List from irctest import cases, runner, scram from irctest.numerics import ERR_SASLFAIL @@ -11,8 +12,34 @@ class RegistrationTestCase(cases.BaseServerTestCase): self.controller.registerUser(self, "testuser", "mypassword") -@cases.mark_services -class SaslTestCase(cases.BaseServerTestCase): +class _BaseSasl(cases.BaseServerTestCase): + sasl_ir: bool + capabilities: List[str] + + def _doInitialExchange(self, client, mechanism: str, chunk: str): + """Does the initial C->S, S->C, C->S exchange. + + With ``sasl_ir=False``, this is done with the usual three messages exchange + (``AUTHENTICATE ``, ``AUTHENTICATE +``, ``AUTHENTICATE ``) + with ``sasl_ir=True``, this is done in a single C->S message + (``AUTHENTICATE ``) + + See the [sasl-ir spec](https://github.com/ircv3/ircv3-specifications/pull/520) + """ + if self.sasl_ir: + self.sendLine(client, f"AUTHENTICATE {mechanism} {chunk}") + else: + self.sendLine(client, f"AUTHENTICATE {mechanism}") + m = self.getRegistrationMessage(1) + self.assertMessageMatch( + m, + command="AUTHENTICATE", + params=["+"], + fail_msg=f"Sent “AUTHENTICATE {mechanism}”, server should have " + f"replied with “AUTHENTICATE +”, but instead sent: {{msg}}", + ) + self.sendLine(client, f"AUTHENTICATE {chunk}") + @cases.mark_specifications("IRCv3") @cases.skipUnlessHasMechanism("PLAIN") def testPlain(self): @@ -34,17 +61,8 @@ class SaslTestCase(cases.BaseServerTestCase): capabilities["sasl"], fail_msg="Does not have PLAIN mechanism as the controller " "claims", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) - self.sendLine(1, "AUTHENTICATE PLAIN") - m = self.getRegistrationMessage(1) - self.assertMessageMatch( - m, - command="AUTHENTICATE", - params=["+"], - fail_msg="Sent “AUTHENTICATE PLAIN”, server should have " - "replied with “AUTHENTICATE +”, but instead sent: {msg}", - ) - self.sendLine(1, "AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=") + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) + self._doInitialExchange(1, "PLAIN", "amlsbGVzAGppbGxlcwBzZXNhbWU=") m = self.getRegistrationMessage(1) self.assertMessageMatch( m, @@ -62,17 +80,8 @@ class SaslTestCase(cases.BaseServerTestCase): ).decode() self.controller.registerUser(self, "foo", password) self.addClient() - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) - self.sendLine(1, "AUTHENTICATE PLAIN") - m = self.getRegistrationMessage(1) - self.assertMessageMatch( - m, - command="AUTHENTICATE", - params=["+"], - fail_msg="Sent “AUTHENTICATE PLAIN”, server should have " - "replied with “AUTHENTICATE +”, but instead sent: {msg}", - ) - self.sendLine(1, "AUTHENTICATE " + authstring) + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) + self._doInitialExchange(1, "PLAIN", authstring) m = self.getRegistrationMessage(1) self.assertMessageMatch( m, @@ -122,17 +131,8 @@ class SaslTestCase(cases.BaseServerTestCase): capabilities["sasl"], fail_msg="Does not have PLAIN mechanism as the controller " "claims", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) - self.sendLine(1, "AUTHENTICATE PLAIN") - m = self.getRegistrationMessage(1) - self.assertMessageMatch( - m, - command="AUTHENTICATE", - params=["+"], - fail_msg="Sent “AUTHENTICATE PLAIN”, server should have " - "replied with “AUTHENTICATE +”, but instead sent: {msg}", - ) - self.sendLine(1, "AUTHENTICATE AGppbGxlcwBzZXNhbWU=") + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) + self._doInitialExchange(1, "PLAIN", "AGppbGxlcwBzZXNhbWU=") m = self.getRegistrationMessage(1) self.assertMessageMatch( m, @@ -158,8 +158,11 @@ class SaslTestCase(cases.BaseServerTestCase): capabilities, fail_msg="Does not have SASL as the controller claims.", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) - self.sendLine(1, "AUTHENTICATE FOO") + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) + if self.sasl_ir: + self.sendLine(1, "AUTHENTICATE FOO AGppbGxlcwBzZXNhbWU=") + else: + self.sendLine(1, "AUTHENTICATE FOO") m = self.getRegistrationMessage(1) while m.command == "908": # RPL_SASLMECHS m = self.getRegistrationMessage(1) @@ -209,17 +212,8 @@ class SaslTestCase(cases.BaseServerTestCase): capabilities["sasl"], fail_msg="Does not have PLAIN mechanism as the controller " "claims", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) - self.sendLine(1, "AUTHENTICATE PLAIN") - m = self.getRegistrationMessage(1) - self.assertMessageMatch( - m, - command="AUTHENTICATE", - params=["+"], - fail_msg="Sent “AUTHENTICATE PLAIN”, expected " - "“AUTHENTICATE +” as a response, but got: {msg}", - ) - self.sendLine(1, "AUTHENTICATE {}".format(authstring[0:400])) + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) + self._doInitialExchange(1, "PLAIN", authstring[0:400]) self.sendLine(1, "AUTHENTICATE {}".format(authstring[400:])) self.confirmSuccessfulAuth() @@ -279,17 +273,8 @@ class SaslTestCase(cases.BaseServerTestCase): capabilities["sasl"], fail_msg="Does not have PLAIN mechanism as the controller " "claims", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) - self.sendLine(1, "AUTHENTICATE PLAIN") - m = self.getRegistrationMessage(1) - self.assertMessageMatch( - m, - command="AUTHENTICATE", - params=["+"], - fail_msg="Sent “AUTHENTICATE PLAIN”, expected " - "“AUTHENTICATE +” as a response, but got: {msg}", - ) - self.sendLine(1, "AUTHENTICATE {}".format(authstring)) + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) + self._doInitialExchange(1, "PLAIN", authstring) self.sendLine(1, "AUTHENTICATE +") self.confirmSuccessfulAuth() @@ -298,6 +283,12 @@ class SaslTestCase(cases.BaseServerTestCase): # I don't know how to do it, because it would make the registration # message's length too big for it to be valid. + +@cases.mark_services +class SaslTestCase(_BaseSasl): + sasl_ir = False + capabilities = ["sasl"] + @cases.mark_specifications("IRCv3") @cases.skipUnlessHasMechanism("SCRAM-SHA-256") def testScramSha256Success(self): @@ -318,7 +309,7 @@ class SaslTestCase(cases.BaseServerTestCase): fail_msg="Does not have SCRAM-SHA-256 mechanism as the " "controller claims", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) self.sendLine(1, "AUTHENTICATE SCRAM-SHA-256") m = self.getRegistrationMessage(1) @@ -374,7 +365,7 @@ class SaslTestCase(cases.BaseServerTestCase): fail_msg="Does not have SCRAM-SHA-256 mechanism as the " "controller claims", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) self.sendLine(1, "AUTHENTICATE SCRAM-SHA-256") m = self.getRegistrationMessage(1) @@ -404,3 +395,36 @@ class SaslTestCase(cases.BaseServerTestCase): ) m = self.getRegistrationMessage(1) self.assertMessageMatch(m, command=ERR_SASLFAIL) + + +@cases.mark_services +class SaslIrTestCase(_BaseSasl): + """Tests SASL with clients requesting the + [sasl-ir](https://github.com/ircv3/ircv3-specifications/pull/520) cap and using it. + """ + + sasl_ir = True + capabilities = ["sasl", "draft/sasl-ir"] + + def setUp(self): + super().setUp() + self.connectClient( + "capgetter", capabilities=["draft/sasl-ir"], skip_if_cap_nak=True + ) + + +@cases.mark_services +class ImplicitSaslIrTestCase(_BaseSasl): + """Tests SASL with clients using the + [sasl-ir](https://github.com/ircv3/ircv3-specifications/pull/520) CAP without + requesting it. + """ + + sasl_ir = True + capabilities = ["sasl"] + + def setUp(self): + super().setUp() + self.connectClient( + "capgetter", capabilities=["draft/sasl-ir"], skip_if_cap_nak=True + )