From 69c5dca4b9b2bcff9dc25dde488db4463e68f88f Mon Sep 17 00:00:00 2001
From: Val Lorentz <progval+git@progval.net>
Date: Sat, 19 Mar 2022 16:09:27 +0100
Subject: [PATCH] Add client tests for SASL with non-ASCII passwords (#137)

---
 irctest/client_tests/sasl.py    | 14 ++++++++------
 irctest/controllers/limnoria.py | 14 ++++++++++----
 2 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/irctest/client_tests/sasl.py b/irctest/client_tests/sasl.py
index c157059..7cc3781 100644
--- a/irctest/client_tests/sasl.py
+++ b/irctest/client_tests/sasl.py
@@ -84,8 +84,9 @@ class SaslTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
         m = self.getMessage()
         self.assertMessageMatch(m, command="CAP")
 
+    @pytest.mark.parametrize("pattern", ["barbaz", "éèà"])
     @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
-    def testPlainLarge(self):
+    def testPlainLarge(self, pattern):
         """Test the client splits large AUTHENTICATE messages whose payload
         is not a multiple of 400.
         <http://ircv3.net/specs/extensions/sasl-3.1.html#the-authenticate-command>
@@ -94,10 +95,10 @@ class SaslTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
         auth = authentication.Authentication(
             mechanisms=[authentication.Mechanisms.plain],
             username="foo",
-            password="bar" * 200,
+            password=pattern * 100,
         )
         authstring = base64.b64encode(
-            b"\x00".join([b"foo", b"foo", b"bar" * 200])
+            b"\x00".join([b"foo", b"foo", pattern.encode() * 100])
         ).decode()
         m = self.negotiateCapabilities(["sasl"], auth=auth)
         self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["PLAIN"]))
@@ -114,7 +115,8 @@ class SaslTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
         self.assertEqual(m, Message({}, None, "CAP", ["END"]))
 
     @cases.OptionalityHelper.skipUnlessHasMechanism("PLAIN")
-    def testPlainLargeMultiple(self):
+    @pytest.mark.parametrize("pattern", ["quux", "éè"])
+    def testPlainLargeMultiple(self, pattern):
         """Test the client splits large AUTHENTICATE messages whose payload
         is a multiple of 400.
         <http://ircv3.net/specs/extensions/sasl-3.1.html#the-authenticate-command>
@@ -123,10 +125,10 @@ class SaslTestCase(cases.BaseClientTestCase, cases.OptionalityHelper):
         auth = authentication.Authentication(
             mechanisms=[authentication.Mechanisms.plain],
             username="foo",
-            password="quux" * 148,
+            password=pattern * 148,
         )
         authstring = base64.b64encode(
-            b"\x00".join([b"foo", b"foo", b"quux" * 148])
+            b"\x00".join([b"foo", b"foo", pattern.encode() * 148])
         ).decode()
         m = self.negotiateCapabilities(["sasl"], auth=auth)
         self.assertEqual(m, Message({}, None, "AUTHENTICATE", ["PLAIN"]))
diff --git a/irctest/controllers/limnoria.py b/irctest/controllers/limnoria.py
index 78b005f..5b38e21 100644
--- a/irctest/controllers/limnoria.py
+++ b/irctest/controllers/limnoria.py
@@ -55,13 +55,19 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
         # Runs a client with the config given as arguments
         assert self.proc is None
         self.create_config()
+
+        username = password = ""
+        mechanisms = ""
         if auth:
             mechanisms = " ".join(mech.to_string() for mech in auth.mechanisms)
             if auth.ecdsa_key:
                 with self.open_file("ecdsa_key.pem") as fd:
                     fd.write(auth.ecdsa_key)
-        else:
-            mechanisms = ""
+
+            if auth.username:
+                username = auth.username.encode("unicode_escape").decode()
+            if auth.password:
+                password = auth.password.encode("unicode_escape").decode()
         with self.open_file("bot.conf") as fd:
             fd.write(
                 TEMPLATE_CONFIG.format(
@@ -69,8 +75,8 @@ class LimnoriaController(BaseClientController, DirectoryBasedController):
                     loglevel="CRITICAL",
                     hostname=hostname,
                     port=port,
-                    username=auth.username if auth else "",
-                    password=auth.password if auth else "",
+                    username=username,
+                    password=password,
                     mechanisms=mechanisms.lower(),
                     enable_tls=tls_config.enable if tls_config else "False",
                     trusted_fingerprints=" ".join(tls_config.trusted_fingerprints)