Files
2020-04-08 10:39:44 -07:00

61 lines
1.5 KiB
Python

import collections
import socket
class MockSocket(object):
    """In-memory stand-in for a socket object that records what is done to it.

    ``recv_bufs`` is an iterable of scripted results for successive ``recv``
    calls; an entry that is an ``Exception`` instance is raised instead of
    returned.  If ``connect_failure`` is an ``Exception`` instance, every
    ``connect`` call raises it.
    """

    def __init__(self, recv_bufs, connect_failure=None):
        self.connect_failure = connect_failure
        # Scripted recv results, consumed from the left one per call.
        self.recv_bufs = collections.deque(recv_bufs)
        # Everything below is a log of the calls made on this mock.
        self.send_bufs = []
        self.timeouts = []
        self.connections = []
        self.socket_options = []
        self.closed = False

    def sendall(self, value):
        """Record ``value`` in ``send_bufs``."""
        self.send_bufs.append(value)

    def close(self):
        """Flag this mock as closed."""
        self.closed = True

    def recv(self, size):
        """Return (or raise) the next scripted entry; ``size`` is ignored."""
        chunk = self.recv_bufs.popleft()
        if not isinstance(chunk, Exception):
            return chunk
        raise chunk

    def settimeout(self, timeout):
        """Record ``timeout`` in ``timeouts``."""
        self.timeouts.append(timeout)

    def connect(self, server):
        """Raise the configured failure, otherwise record ``server``."""
        failure = self.connect_failure
        if isinstance(failure, Exception):
            raise failure
        self.connections.append(server)

    def setsockopt(self, level, option, value):
        """Record the ``(level, option, value)`` triple in ``socket_options``."""
        entry = (level, option, value)
        self.socket_options.append(entry)
class MockSocketModule(object):
    """Stand-in for the ``socket`` module that hands out ``MockSocket`` objects.

    Each socket created through ``socket()`` is appended to ``sockets``.
    Attributes not defined on this object are looked up on the real
    ``socket`` module.
    """

    def __init__(self, connect_failure=None):
        self.sockets = []
        # Passed on to every MockSocket this module creates.
        self.connect_failure = connect_failure

    def socket(self, family, type):  # noqa: A002
        """Create, record and return a ``MockSocket`` with no recv data.

        ``family`` and ``type`` are accepted but not used.
        """
        mock = MockSocket([], connect_failure=self.connect_failure)
        self.sockets.append(mock)
        return mock

    def __getattr__(self, name):
        """Resolve attributes missing on this object from the real ``socket`` module."""
        return getattr(socket, name)
# Compatibility to get a string back from a request
def _str(s):
if type(s) is str:
return s
elif type(s) is bytes:
return s.decode()
else:
return str(s)