diff --git a/irctest/server_tests/sasl.py b/irctest/server_tests/sasl.py index e421671..e6bc998 100644 --- a/irctest/server_tests/sasl.py +++ b/irctest/server_tests/sasl.py @@ -1,4 +1,5 @@ import base64 +from typing import List from irctest import cases, runner, scram from irctest.numerics import ERR_SASLFAIL, RPL_LOGGEDIN, RPL_SASLMECHS @@ -11,8 +12,34 @@ class RegistrationTestCase(cases.BaseServerTestCase): self.controller.registerUser(self, "testuser", "mypassword") -@cases.mark_services -class SaslTestCase(cases.BaseServerTestCase): +class _BaseSasl(cases.BaseServerTestCase): + sasl_ir: bool + capabilities: List[str] + + def _doInitialExchange(self, client, mechanism: str, chunk: str): + """Does the initial C->S, S->C, C->S exchange. + + With ``sasl_ir=False``, this is done with the usual three messages exchange + (``AUTHENTICATE ``, ``AUTHENTICATE +``, ``AUTHENTICATE ``) + with ``sasl_ir=True``, this is done in a single C->S message + (``AUTHENTICATE ``) + + See the [sasl-ir spec](https://github.com/ircv3/ircv3-specifications/pull/520) + """ + if self.sasl_ir: + self.sendLine(client, f"AUTHENTICATE {mechanism} {chunk}") + else: + self.sendLine(client, f"AUTHENTICATE {mechanism}") + m = self.getRegistrationMessage(1) + self.assertMessageMatch( + m, + command="AUTHENTICATE", + params=["+"], + fail_msg=f"Sent “AUTHENTICATE {mechanism}”, server should have " + f"replied with “AUTHENTICATE +”, but instead sent: {{msg}}", + ) + self.sendLine(client, f"AUTHENTICATE {chunk}") + @cases.mark_specifications("IRCv3") @cases.skipUnlessHasMechanism("PLAIN") def testPlain(self): @@ -34,17 +61,8 @@ class SaslTestCase(cases.BaseServerTestCase): capabilities["sasl"], fail_msg="Does not have PLAIN mechanism as the controller " "claims", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) - self.sendLine(1, "AUTHENTICATE PLAIN") - m = self.getRegistrationMessage(1) - self.assertMessageMatch( - m, - command="AUTHENTICATE", - params=["+"], - fail_msg="Sent “AUTHENTICATE PLAIN”, server should have " - "replied with “AUTHENTICATE +”, but instead sent: {msg}", - ) - self.sendLine(1, "AUTHENTICATE amlsbGVzAGppbGxlcwBzZXNhbWU=") + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) + self._doInitialExchange(1, "PLAIN", "amlsbGVzAGppbGxlcwBzZXNhbWU=") m = self.getRegistrationMessage(1) self.assertMessageMatch( m, @@ -88,17 +106,8 @@ class SaslTestCase(cases.BaseServerTestCase): ).decode() self.controller.registerUser(self, "foo", password) self.addClient() - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) - self.sendLine(1, "AUTHENTICATE PLAIN") - m = self.getRegistrationMessage(1) - self.assertMessageMatch( - m, - command="AUTHENTICATE", - params=["+"], - fail_msg="Sent “AUTHENTICATE PLAIN”, server should have " - "replied with “AUTHENTICATE +”, but instead sent: {msg}", - ) - self.sendLine(1, "AUTHENTICATE " + authstring) + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) + self._doInitialExchange(1, "PLAIN", authstring) m = self.getRegistrationMessage(1) self.assertMessageMatch( m, @@ -148,17 +157,8 @@ class SaslTestCase(cases.BaseServerTestCase): capabilities["sasl"], fail_msg="Does not have PLAIN mechanism as the controller " "claims", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) - self.sendLine(1, "AUTHENTICATE PLAIN") - m = self.getRegistrationMessage(1) - self.assertMessageMatch( - m, - command="AUTHENTICATE", - params=["+"], - fail_msg="Sent “AUTHENTICATE PLAIN”, server should have " - "replied with “AUTHENTICATE +”, but instead sent: {msg}", - ) - self.sendLine(1, "AUTHENTICATE AGppbGxlcwBzZXNhbWU=") + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) + self._doInitialExchange(1, "PLAIN", "AGppbGxlcwBzZXNhbWU=") m = self.getRegistrationMessage(1) self.assertMessageMatch( m, @@ -184,8 +184,11 @@ class SaslTestCase(cases.BaseServerTestCase): capabilities, fail_msg="Does not have SASL as the controller claims.", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) - self.sendLine(1, "AUTHENTICATE FOO") + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) + if self.sasl_ir: + self.sendLine(1, "AUTHENTICATE FOO AGppbGxlcwBzZXNhbWU=") + else: + self.sendLine(1, "AUTHENTICATE FOO") m = self.getRegistrationMessage(1) while m.command == RPL_SASLMECHS: m = self.getRegistrationMessage(1) @@ -235,17 +238,8 @@ class SaslTestCase(cases.BaseServerTestCase): capabilities["sasl"], fail_msg="Does not have PLAIN mechanism as the controller " "claims", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) - self.sendLine(1, "AUTHENTICATE PLAIN") - m = self.getRegistrationMessage(1) - self.assertMessageMatch( - m, - command="AUTHENTICATE", - params=["+"], - fail_msg="Sent “AUTHENTICATE PLAIN”, expected " - "“AUTHENTICATE +” as a response, but got: {msg}", - ) - self.sendLine(1, "AUTHENTICATE {}".format(authstring[0:400])) + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) + self._doInitialExchange(1, "PLAIN", authstring[0:400]) self.sendLine(1, "AUTHENTICATE {}".format(authstring[400:])) self.confirmSuccessfulAuth() @@ -305,17 +299,8 @@ class SaslTestCase(cases.BaseServerTestCase): capabilities["sasl"], fail_msg="Does not have PLAIN mechanism as the controller " "claims", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) - self.sendLine(1, "AUTHENTICATE PLAIN") - m = self.getRegistrationMessage(1) - self.assertMessageMatch( - m, - command="AUTHENTICATE", - params=["+"], - fail_msg="Sent “AUTHENTICATE PLAIN”, expected " - "“AUTHENTICATE +” as a response, but got: {msg}", - ) - self.sendLine(1, "AUTHENTICATE {}".format(authstring)) + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) + self._doInitialExchange(1, "PLAIN", authstring) self.sendLine(1, "AUTHENTICATE +") self.confirmSuccessfulAuth() @@ -324,6 +309,12 @@ class SaslTestCase(cases.BaseServerTestCase): # I don't know how to do it, because it would make the registration # message's length too big for it to be valid. + +@cases.mark_services +class SaslTestCase(_BaseSasl): + sasl_ir = False + capabilities = ["sasl"] + @cases.mark_specifications("IRCv3") @cases.skipUnlessHasMechanism("SCRAM-SHA-256") def testScramSha256Success(self): @@ -344,7 +335,7 @@ class SaslTestCase(cases.BaseServerTestCase): fail_msg="Does not have SCRAM-SHA-256 mechanism as the " "controller claims", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) self.sendLine(1, "AUTHENTICATE SCRAM-SHA-256") m = self.getRegistrationMessage(1) @@ -400,7 +391,7 @@ class SaslTestCase(cases.BaseServerTestCase): fail_msg="Does not have SCRAM-SHA-256 mechanism as the " "controller claims", ) - self.requestCapabilities(1, ["sasl"], skip_if_cap_nak=False) + self.requestCapabilities(1, self.capabilities, skip_if_cap_nak=False) self.sendLine(1, "AUTHENTICATE SCRAM-SHA-256") m = self.getRegistrationMessage(1) @@ -430,3 +421,36 @@ class SaslTestCase(cases.BaseServerTestCase): ) m = self.getRegistrationMessage(1) self.assertMessageMatch(m, command=ERR_SASLFAIL) + + +@cases.mark_services +class SaslIrTestCase(_BaseSasl): + """Tests SASL with clients requesting the + [sasl-ir](https://github.com/ircv3/ircv3-specifications/pull/520) cap and using it. + """ + + sasl_ir = True + capabilities = ["sasl", "draft/sasl-ir"] + + def setUp(self): + super().setUp() + self.connectClient( + "capgetter", capabilities=["draft/sasl-ir"], skip_if_cap_nak=True + ) + + +@cases.mark_services +class ImplicitSaslIrTestCase(_BaseSasl): + """Tests SASL with clients using the + [sasl-ir](https://github.com/ircv3/ircv3-specifications/pull/520) CAP without + requesting it. + """ + + sasl_ir = True + capabilities = ["sasl"] + + def setUp(self): + super().setUp() + self.connectClient( + "capgetter", capabilities=["draft/sasl-ir"], skip_if_cap_nak=True + )