mirror of
https://github.com/progval/irctest.git
synced 2025-04-06 07:19:54 +00:00
Instead of hanging forever. Hopefully there shouldn't be any false positive.
167 lines
6.2 KiB
Python
167 lines
6.2 KiB
Python
import socket
|
|
import ssl
|
|
import sys
|
|
import time
|
|
from typing import Any, Callable, List, Optional, Union
|
|
|
|
from .exceptions import ConnectionClosed, NoMessageException
|
|
from .irc_utils import message_parser
|
|
|
|
|
|
class ClientMock:
    """Scripted IRC client that talks to a server over a plain TCP socket.

    Attributes:
        name: Label identifying this client in the I/O trace.
        show_io: When true, every connection event and every line sent or
            received is printed with a timestamp.
        inbuffer: Messages already read from the socket but not yet handed
            to the caller.
        ssl: Whether the connection has been upgraded by :meth:`starttls`.
        conn: The underlying socket; only set once :meth:`connect` is called.
    """

    def __init__(self, name: Any, show_io: bool):
        self.name = name
        self.show_io = show_io
        self.inbuffer: List[message_parser.Message] = []
        self.ssl = False

    def connect(self, hostname: str, port: int) -> None:
        """Open a TCP connection to ``hostname:port`` and store it in ``conn``."""
        self.conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

        # probably useful for test_buffering, as it relies on chunking
        # the packets to be useful
        self.conn.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)

        self.conn.settimeout(1)  # TODO: configurable
        self.conn.connect((hostname, port))
        if self.show_io:
            print("{:.3f} {}: connects to server.".format(time.time(), self.name))

    def disconnect(self) -> None:
        """Close the socket opened by :meth:`connect`."""
        if self.show_io:
            print("{:.3f} {}: disconnects from server.".format(time.time(), self.name))
        self.conn.close()

    def starttls(self) -> None:
        """Upgrade the existing connection to TLS, without verifying the peer.

        Raises:
            AssertionError: if the connection was already upgraded.
        """
        assert not self.ssl, "SSL already active."
        # ssl.wrap_socket() was deprecated in Python 3.7 and removed in 3.12.
        # This context reproduces its defaults: client side, no certificate
        # validation and no hostname check.  check_hostname has to be cleared
        # before verify_mode can be set to CERT_NONE.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        self.conn = context.wrap_socket(self.conn)
        self.ssl = True

    def getMessages(
        self, synchronize: bool = True, assert_get_one: bool = False, raw: bool = False
    ) -> List[message_parser.Message]:
        """Read and return the messages currently available from the server.

        Messages left in ``inbuffer`` by a previous call are returned first.
        Actually returns List[str] in the rare case where raw=True.

        Args:
            synchronize: Send a ``PING`` with a unique token and keep reading
                until the matching ``PONG`` arrives; the ``PONG`` itself is
                not returned.
            assert_get_one: When ``synchronize`` is false, keep waiting across
                socket timeouts instead of returning when nothing was received.
            raw: Collect the undecoded-to-Message text lines instead of
                parsed messages.

        Raises:
            ConnectionClosed: the server closed or reset the connection and
                no message had been collected.
            ValueError: the server answered the synchronization ``PING`` with
                numeric 451.
        """
        __tracebackhide__ = True  # Hide from pytest tracebacks on test failure.
        token: Optional[str]
        if synchronize:
            token = "synchronize{}".format(time.monotonic())
            self.sendLine("PING {}".format(token))
        else:
            token = None
        got_pong = False
        data = b""
        # Take ownership of whatever was buffered; new messages are appended.
        (self.inbuffer, messages) = ([], self.inbuffer)
        conn = self.conn
        try:
            while not got_pong:
                try:
                    new_data = conn.recv(4096)
                except socket.timeout:
                    if not assert_get_one and not synchronize and data == b"":
                        # Received nothing new: hand back what was already
                        # buffered (possibly nothing) instead of dropping it.
                        return messages
                    if self.show_io:
                        print("{:.3f} {}: waiting…".format(time.time(), self.name))
                    continue
                except ConnectionResetError:
                    raise ConnectionClosed()
                else:
                    if not new_data:
                        # Connection closed
                        raise ConnectionClosed()
                data += new_data
                if not new_data.endswith(b"\r\n"):
                    # Last line is incomplete; wait for the rest of it.
                    continue
                if not synchronize:
                    # Without a PING/PONG round trip, one complete read is
                    # all we wait for.
                    got_pong = True
                for line in data.decode().split("\r\n"):
                    if line:
                        if self.show_io:
                            print(
                                "{time:.3f}{ssl} S -> {client}: {line}".format(
                                    time=time.time(),
                                    ssl=" (ssl)" if self.ssl else "",
                                    client=self.name,
                                    line=line,
                                )
                            )
                        message = message_parser.parse_message(line)
                        if message.command == "PONG" and token in message.params:
                            got_pong = True
                        elif (
                            synchronize
                            and message.command == "451"
                            and len(message.params) > 1
                            and message.params[1] == "PING"
                        ):
                            # The server rejected our PING; the PONG will
                            # never come, so fail instead of hanging forever.
                            raise ValueError(
                                "Got '451 * PING'. Did you forget synchronize=False?"
                            )
                        else:
                            if raw:
                                messages.append(line)  # type: ignore
                            else:
                                messages.append(message)
                data = b""
        except ConnectionClosed:
            if messages:
                return messages
            else:
                raise
        else:
            return messages

    def getMessage(
        self,
        filter_pred: Optional[Callable[[message_parser.Message], bool]] = None,
        synchronize: bool = True,
        raw: bool = False,
    ) -> message_parser.Message:
        """Return the next message, optionally skipping those rejected by a filter.

        Returns str in the rare case where raw=True.

        Args:
            filter_pred: If given, messages for which it returns a false
                value are discarded and the next one is tried.
            synchronize: Forwarded to :meth:`getMessages` when the buffer
                needs refilling.
            raw: Forwarded to :meth:`getMessages` when the buffer needs
                refilling.

        Raises:
            NoMessageException: refilling the buffer yielded no message.
        """
        __tracebackhide__ = True  # Hide from pytest tracebacks on test failure.
        while True:
            if not self.inbuffer:
                self.inbuffer = self.getMessages(
                    synchronize=synchronize, assert_get_one=True, raw=raw
                )
            if not self.inbuffer:
                raise NoMessageException()
            message = self.inbuffer.pop(0)  # TODO: use dequeue
            if not filter_pred or filter_pred(message):
                return message

    def sendLine(self, line: Union[str, bytes]) -> None:
        """Send one line to the server, appending ``\\r\\n`` if it is missing.

        Args:
            line: The line to send; ``str`` is encoded with the default
                codec, ``bytes`` is sent as-is.

        Raises:
            ValueError: ``line`` is neither ``str`` nor ``bytes``.
            ConnectionClosed: the socket raised ``BrokenPipeError``.
        """
        if isinstance(line, str):
            encoded_line = line.encode()
        elif isinstance(line, bytes):
            encoded_line = line
        else:
            raise ValueError(line)
        if not encoded_line.endswith(b"\r\n"):
            encoded_line += b"\r\n"
        try:
            ret = self.conn.sendall(encoded_line)  # type: ignore
        except BrokenPipeError:
            raise ConnectionClosed()
        if (
            sys.version_info <= (3, 6) and self.ssl
        ):  # https://bugs.python.org/issue25951
            assert ret == len(encoded_line), (ret, repr(encoded_line))
        else:
            assert ret is None, ret
        if self.show_io:
            if isinstance(line, str):
                escaped_line = line
                escaped = ""
            else:
                # Bytes may not be printable; show their repr instead.
                escaped_line = repr(line)
                escaped = " (escaped)"
            print(
                "{time:.3f}{escaped}{ssl} {client} -> S: {line}".format(
                    time=time.time(),
                    escaped=escaped,
                    ssl=" (ssl)" if self.ssl else "",
                    client=self.name,
                    line=escaped_line.strip("\r\n"),
                )
            )