irctest/irctest/client_mock.py

167 lines
6.2 KiB
Python

import socket
import ssl
import sys
import time
from typing import Any, Callable, List, Optional, Union
from .exceptions import ConnectionClosed, NoMessageException
from .irc_utils import message_parser
class ClientMock:
    """Minimal blocking IRC client that talks to a server over a raw socket.

    Lines received from the server are parsed with ``message_parser`` and
    handed back by :meth:`getMessages` / :meth:`getMessage`; lines are sent
    with :meth:`sendLine`. When ``show_io`` is true, every connection event
    and every line exchanged is printed with a timestamp.
    """

    def __init__(self, name: Any, show_io: bool):
        """Store the client's display name and whether to print traffic."""
        self.name = name
        self.show_io = show_io
        # Messages already read from the socket but not yet handed out by
        # getMessage(). The annotation is a string so that it is never
        # evaluated at definition time.
        self.inbuffer: "List[message_parser.Message]" = []
        # Becomes True once starttls() has wrapped the socket.
        self.ssl = False

    def connect(self, hostname: str, port: int) -> None:
        """Open a TCP connection to ``(hostname, port)`` with a 1s timeout."""
        self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # probably useful for test_buffering, as it relies on chunking
        # the packets to be useful
        self.conn.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
        self.conn.settimeout(1)  # TODO: configurable
        self.conn.connect((hostname, port))
        if self.show_io:
            print("{:.3f} {}: connects to server.".format(time.time(), self.name))

    def disconnect(self) -> None:
        """Close the socket opened by :meth:`connect`."""
        if self.show_io:
            print("{:.3f} {}: disconnects from server.".format(time.time(), self.name))
        self.conn.close()

    def starttls(self) -> None:
        """Wrap the current socket in TLS, without certificate verification.

        ``ssl.wrap_socket`` was removed in Python 3.12, so an explicit
        context is built instead. Hostname checking and certificate
        verification are disabled to keep the behaviour the module-level
        ``ssl.wrap_socket(sock)`` call had with its default arguments.
        """
        assert not self.ssl, "SSL already active."
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        # check_hostname must be cleared before verify_mode can be CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        self.conn = context.wrap_socket(self.conn)
        self.ssl = True

    def getMessages(
        self, synchronize: bool = True, assert_get_one: bool = False, raw: bool = False
    ) -> "List[message_parser.Message]":
        """Read pending lines from the server and return them as messages.

        Actually returns List[str] in the rare case where raw=True.

        Messages left in ``self.inbuffer`` are returned first, followed by
        whatever is read from the socket during this call.

        :param synchronize: send ``PING <token>`` first and keep reading until
            a ``PONG`` carrying that token arrives; that ``PONG`` is not
            included in the result.
        :param assert_get_one: when true (and ``synchronize`` is false), keep
            waiting across socket timeouts instead of returning when nothing
            has been received.
        :param raw: append the undecoded-to-message line strings instead of
            parsed messages.
        :raises ConnectionClosed: the peer closed or reset the connection and
            there is nothing to return.
        :raises ValueError: a ``451`` reply to the synchronizing ``PING`` was
            received.
        """
        __tracebackhide__ = True  # Hide from pytest tracebacks on test failure.
        token: Optional[str]
        if synchronize:
            token = "synchronize{}".format(time.monotonic())
            self.sendLine("PING {}".format(token))
        else:
            token = None
        got_pong = False
        data = b""
        # Take ownership of the buffered messages; new ones are appended.
        (self.inbuffer, messages) = ([], self.inbuffer)
        conn = self.conn
        try:
            while not got_pong:
                try:
                    new_data = conn.recv(4096)
                except socket.timeout:
                    if not assert_get_one and not synchronize and data == b"":
                        # Received nothing new: hand back what was already
                        # buffered rather than dropping it.
                        return messages
                    if self.show_io:
                        print("{:.3f} {}: waiting…".format(time.time(), self.name))
                    continue
                except ConnectionResetError:
                    raise ConnectionClosed()
                else:
                    if not new_data:
                        # Connection closed
                        raise ConnectionClosed()
                data += new_data
                # Check the accumulated buffer, not just the last chunk, so
                # that a CRLF split across two recv() calls is still seen.
                if not data.endswith(b"\r\n"):
                    continue
                if not synchronize:
                    # Without a PING/PONG marker, one complete batch of lines
                    # ends the read loop.
                    got_pong = True
                for line in data.decode().split("\r\n"):
                    if line:
                        if self.show_io:
                            print(
                                "{time:.3f}{ssl} S -> {client}: {line}".format(
                                    time=time.time(),
                                    ssl=" (ssl)" if self.ssl else "",
                                    client=self.name,
                                    line=line,
                                )
                            )
                        message = message_parser.parse_message(line)
                        if message.command == "PONG" and token in message.params:
                            got_pong = True
                        elif (
                            synchronize
                            and message.command == "451"
                            and len(message.params) >= 2
                            and message.params[1] == "PING"
                        ):
                            raise ValueError(
                                "Got '451 * PING'. Did you forget synchronize=False?"
                            )
                        else:
                            if raw:
                                messages.append(line)  # type: ignore
                            else:
                                messages.append(message)
                data = b""
        except ConnectionClosed:
            # Deliver what was collected before the connection went away;
            # only propagate when there is nothing to return.
            if messages:
                return messages
            else:
                raise
        else:
            return messages

    def getMessage(
        self,
        filter_pred: "Optional[Callable[[message_parser.Message], bool]]" = None,
        synchronize: bool = True,
        raw: bool = False,
    ) -> "message_parser.Message":
        """Return the next message, optionally the next one matching a filter.

        Returns str in the rare case where raw=True.

        Messages are taken from ``self.inbuffer``, which is refilled through
        :meth:`getMessages` whenever it is empty. Messages rejected by
        ``filter_pred`` are discarded.

        :raises NoMessageException: refilling the buffer produced no message.
        """
        __tracebackhide__ = True  # Hide from pytest tracebacks on test failure.
        while True:
            if not self.inbuffer:
                self.inbuffer = self.getMessages(
                    synchronize=synchronize, assert_get_one=True, raw=raw
                )
            if not self.inbuffer:
                raise NoMessageException()
            message = self.inbuffer.pop(0)  # TODO: use a deque
            if not filter_pred or filter_pred(message):
                return message

    def sendLine(self, line: Union[str, bytes]) -> None:
        """Send one line to the server, appending CRLF if it is missing.

        :param line: text (encoded with ``str.encode()``) or raw bytes.
        :raises ValueError: ``line`` is neither ``str`` nor ``bytes``.
        :raises ConnectionClosed: the socket reported a broken pipe.
        """
        if isinstance(line, str):
            encoded_line = line.encode()
        elif isinstance(line, bytes):
            encoded_line = line
        else:
            raise ValueError(line)
        if not encoded_line.endswith(b"\r\n"):
            encoded_line += b"\r\n"
        try:
            ret = self.conn.sendall(encoded_line)  # type: ignore
        except BrokenPipeError:
            raise ConnectionClosed()
        # sys.version_info is a 5-tuple, so the former ``<= (3, 6)`` test was
        # only ever true for interpreters older than 3.6; spell that out.
        if sys.version_info < (3, 6) and self.ssl:
            # https://bugs.python.org/issue25951
            assert ret == len(encoded_line), (ret, repr(encoded_line))
        else:
            assert ret is None, ret
        if self.show_io:
            if isinstance(line, str):
                escaped_line = line
                escaped = ""
            else:
                escaped_line = repr(line)
                escaped = " (escaped)"
            print(
                "{time:.3f}{escaped}{ssl} {client} -> S: {line}".format(
                    time=time.time(),
                    escaped=escaped,
                    ssl=" (ssl)" if self.ssl else "",
                    client=self.name,
                    line=escaped_line.strip("\r\n"),
                )
            )