Add docstrings.

This commit is contained in:
Valentin Lorentz
2015-12-20 13:47:30 +01:00
parent 8d337bb7bd
commit 900f18492c
6 changed files with 44 additions and 0 deletions

View File

@ -3,6 +3,7 @@ import collections
@enum.unique
class Mechanisms(enum.Enum):
"""Enumeration for representing possible mechanisms."""
@classmethod
def as_string(cls, mech):
return {cls.plain: 'PLAIN',

View File

@ -7,15 +7,24 @@ import subprocess
from .optional_extensions import NotImplementedByController
class _BaseController:
"""Base class for software controllers.
A software controller is an object that handles configuring and running
a process (eg. a server or a client), as well as sending it instructions
that are not part of the IRC specification."""
pass
class DirectoryBasedController(_BaseController):
"""Helper for controllers whose software configuration is based on an
arbitrary directory."""
def __init__(self):
super().__init__()
self.directory = None
self.proc = None
def kill_proc(self):
"""Terminates the controlled process, waits for it to exit, and
eventually kills it."""
self.proc.terminate()
try:
self.proc.wait(5)
@ -23,11 +32,13 @@ class DirectoryBasedController(_BaseController):
self.proc.kill()
self.proc = None
def kill(self):
"""Calls `kill_proc` and cleans the configuration."""
if self.proc:
self.kill_proc()
if self.directory:
shutil.rmtree(self.directory)
def open_file(self, name, mode='a'):
"""Open a file in the configuration directory."""
assert self.directory
if os.sep in name:
dir_ = os.path.join(self.directory, os.path.dirname(name))
@ -39,10 +50,12 @@ class DirectoryBasedController(_BaseController):
self.directory = tempfile.mkdtemp()
class BaseClientController(_BaseController):
"""Base controller for IRC clients."""
def run(self, hostname, port, auth):
raise NotImplementedError()
class BaseServerController(_BaseController):
"""Base controller for IRC server."""
def run(self, hostname, port, start_wait):
raise NotImplementedError()
def registerUser(self, case, username):

View File

@ -7,6 +7,7 @@ from . import authentication
from .irc_utils import message_parser
class _IrcTestCase(unittest.TestCase):
"""Base class for test cases."""
controllerClass = None # Will be set by __main__.py
def setUp(self):
@ -29,6 +30,13 @@ class _IrcTestCase(unittest.TestCase):
return msg
def assertMessageEqual(self, msg, subcommand=None, subparams=None,
target=None, **kwargs):
"""Helper for partially comparing a message.
Takes the message as first arguments, and comparisons to be made
as keyword arguments.
Deals with subcommands (eg. `CAP`) if any of `subcommand`,
`subparams`, and `target` are given."""
for (key, value) in kwargs.items():
with self.subTest(key=key):
self.assertEqual(getattr(msg, key), value, msg)
@ -122,6 +130,8 @@ class ClientNegociationHelper:
return True
def negotiateCapabilities(self, capabilities, cap_ls=True, auth=None):
"""Performes a complete capability negociation process, without
ending it, so the caller can continue the negociation."""
if cap_ls:
self.readCapLs(auth)
if not self.protocol_version:
@ -187,6 +197,7 @@ class BaseServerTestCase(_IrcTestCase):
return name
def removeClient(self, name):
"""Disconnects the client, without QUIT."""
assert name in self.clients
self.clients[name].conn.close()
del self.clients[name]
@ -222,6 +233,8 @@ class BaseServerTestCase(_IrcTestCase):
print('{} -> S: {}'.format(client, line.strip()))
def getCapLs(self, client):
"""Waits for a CAP LS block, parses all CAP LS messages, and return
the list of capabilities."""
capabilities = []
while True:
m = self.getMessage(client,

View File

@ -25,6 +25,7 @@ class SaslMechanismCheck:
class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
SaslMechanismCheck):
def testPlain(self):
"""Test PLAIN authentication."""
auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.plain],
username='jilles',
@ -43,6 +44,8 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
self.assertEqual(m, Message([], None, 'CAP', ['END']))
def testPlainNotAvailable(self):
"""Test the client handles gracefully servers that don't provide a
mechanism it could use."""
auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.plain],
username='jilles',
@ -61,6 +64,8 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
def testPlainLarge(self):
"""Test the client splits large AUTHENTICATE messages whose payload
is not a multiple of 400."""
# TODO: authzid is optional
auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.plain],
@ -88,6 +93,8 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
self.assertEqual(m, Message([], None, 'CAP', ['END']))
def testPlainLargeMultiple(self):
"""Test the client splits large AUTHENTICATE messages whose payload
is a multiple of 400."""
# TODO: authzid is optional
auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.plain],
@ -115,6 +122,7 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
self.assertEqual(m, Message([], None, 'CAP', ['END']))
def testEcdsa(self):
"""Test ECDSA authentication."""
auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.ecdsa_nist256p_challenge],
username='jilles',
@ -145,6 +153,8 @@ class SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
class Irc302SaslTestCase(cases.BaseClientTestCase, cases.ClientNegociationHelper,
SaslMechanismCheck):
def testPlainNotAvailable(self):
"""Test the client does not try to authenticate using a mechanism the
server does not advertise."""
auth = authentication.Authentication(
mechanisms=[authentication.Mechanisms.plain],
username='jilles',

View File

@ -15,6 +15,9 @@ class OptionalSaslMechanismNotSupported(unittest.SkipTest):
return 'Unsupported SASL mechanism: {}'.format(self.args[0])
class OptionalityReportingTextTestRunner(unittest.TextTestRunner):
"""Small wrapper around unittest.TextTestRunner that reports the
number of tests that were skipped because the software does not support
an optional feature."""
def run(self, test):
result = super().run(test)
if result.skipped:

View File

@ -3,6 +3,8 @@ from irctest.irc_utils.message_parser import Message
class CapTestCase(cases.BaseServerTestCase):
def testNoReq(self):
"""Test the server handles gracefully clients which do not send
REQs."""
self.addClient(1)
self.sendLine(1, 'CAP LS 302')
self.getCapLs(1)
@ -13,6 +15,8 @@ class CapTestCase(cases.BaseServerTestCase):
self.assertMessageEqual(m, command='001')
def testReqUnavailable(self):
"""Test the server handles gracefully clients which request
capabilities that are not available"""
self.addClient(1)
self.sendLine(1, 'CAP LS 302')
self.getCapLs(1)