Add unittests and fix bugs

* Improved error messages
* Improved checking of user input

Signed-off-by: Jhon Honce <jhonce@redhat.com>

Closes: #978
Approved by: mheon
This commit is contained in:
Jhon Honce
2018-06-20 19:14:27 -07:00
committed by Atomic Bot
parent 3092d20847
commit 2f0f9944b6
6 changed files with 181 additions and 44 deletions

View File

@ -1,4 +1,5 @@
"""A client for communicating with a Podman varlink service.""" """A client for communicating with a Podman varlink service."""
import errno
import os import os
from urllib.parse import urlparse from urllib.parse import urlparse
@ -32,33 +33,47 @@ class BaseClient(object):
if interface is None: if interface is None:
raise ValueError('interface is required and cannot be None') raise ValueError('interface is required and cannot be None')
unsupported = set(kwargs.keys()).difference(
('uri', 'interface', 'remote_uri', 'identity_file'))
if unsupported:
raise ValueError('Unknown keyword arguments: {}'.format(
', '.join(unsupported)))
local_path = urlparse(uri).path local_path = urlparse(uri).path
if local_path == '': if local_path == '':
raise ValueError('path is required for uri, format' raise ValueError('path is required for uri,'
' "unix://path_to_socket"') ' expected format "unix://path_to_socket"')
if kwargs.get('remote_uri') or kwargs.get('identity_file'): if kwargs.get('remote_uri') or kwargs.get('identity_file'):
# Remote access requires the full tuple of information # Remote access requires the full tuple of information
if kwargs.get('remote_uri') is None: if kwargs.get('remote_uri') is None:
raise ValueError('remote is required, format' raise ValueError(
' "ssh://user@hostname/path_to_socket".') 'remote is required,'
' expected format "ssh://user@hostname/path_to_socket".')
remote = urlparse(kwargs['remote_uri']) remote = urlparse(kwargs['remote_uri'])
if remote.username is None: if remote.username is None:
raise ValueError('username is required for remote_uri, format' raise ValueError(
' "ssh://user@hostname/path_to_socket".') 'username is required for remote_uri,'
' expected format "ssh://user@hostname/path_to_socket".')
if remote.path == '': if remote.path == '':
raise ValueError('path is required for remote_uri, format' raise ValueError(
' "ssh://user@hostname/path_to_socket".') 'path is required for remote_uri,'
' expected format "ssh://user@hostname/path_to_socket".')
if remote.hostname is None: if remote.hostname is None:
raise ValueError('hostname is required for remote_uri, format' raise ValueError(
' "ssh://user@hostname/path_to_socket".') 'hostname is required for remote_uri,'
' expected format "ssh://user@hostname/path_to_socket".')
if kwargs.get('identity_file') is None: if kwargs.get('identity_file') is None:
raise ValueError('identity_file is required.') raise ValueError('identity_file is required.')
if not os.path.isfile(kwargs['identity_file']): if not os.path.isfile(kwargs['identity_file']):
raise ValueError('identity_file "{}" not found.'.format( raise FileNotFoundError(
kwargs['identity_file'])) errno.ENOENT,
os.strerror(errno.ENOENT),
kwargs['identity_file'],
)
return RemoteClient( return RemoteClient(
Context(uri, interface, local_path, remote.path, Context(uri, interface, local_path, remote.path,
remote.username, remote.hostname, remote.username, remote.hostname,
@ -111,7 +126,7 @@ class RemoteClient(BaseClient):
self._iface = self._client.open(self._context.interface) self._iface = self._client.open(self._context.interface)
return self._iface return self._iface
except Exception: except Exception:
self._close_tunnel(self._context.uri) tunnel.close(self._context.uri)
raise raise
def __exit__(self, e_type, e, e_traceback): def __exit__(self, e_type, e, e_traceback):
@ -154,15 +169,18 @@ class Client(object):
""" """
self._client = BaseClient.factory(uri, interface, **kwargs) self._client = BaseClient.factory(uri, interface, **kwargs)
address = "{}-{}".format(uri, interface)
# Quick validation of connection data provided # Quick validation of connection data provided
try: try:
if not System(self._client).ping(): if not System(self._client).ping():
raise ValueError('Failed varlink connection "{}/{}"'.format( raise ConnectionRefusedError(
uri, interface)) errno.ECONNREFUSED,
'Failed varlink connection "{}"'.format(address), address)
except FileNotFoundError: except FileNotFoundError:
raise ValueError('Failed varlink connection "{}/{}".' raise ConnectionError(
' Is podman service running?'.format( errno.ECONNREFUSED,
uri, interface)) ('Failed varlink connection "{}".'
' Is podman service running?').format(address), address)
def __enter__(self): def __enter__(self):
"""Return `self` upon entering the runtime context.""" """Return `self` upon entering the runtime context."""

View File

@ -21,6 +21,9 @@ class Mixin:
stderr is ignored. stderr is ignored.
""" """
if not self.containerrunning:
raise Exception('you can only attach to running containers')
if stdin is None: if stdin is None:
stdin = sys.stdin.fileno() stdin = sys.stdin.fileno()
@ -48,7 +51,7 @@ class Mixin:
packed = fcntl.ioctl(stdout, termios.TIOCGWINSZ, packed = fcntl.ioctl(stdout, termios.TIOCGWINSZ,
struct.pack('HHHH', 0, 0, 0, 0)) struct.pack('HHHH', 0, 0, 0, 0))
rows, cols, _, _ = struct.unpack('HHHH', packed) rows, cols, _, _ = struct.unpack('HHHH', packed)
# TODO: Need some kind of timeout in case pipe is blocked
with open(ctl_socket, 'w') as skt: with open(ctl_socket, 'w') as skt:
# send conmon window resize message # send conmon window resize message
skt.write('1 {} {}\n'.format(rows, cols)) skt.write('1 {} {}\n'.format(rows, cols))
@ -73,38 +76,37 @@ class Mixin:
# catch any resizing events and send the resize info # catch any resizing events and send the resize info
# to the control fifo "socket" # to the control fifo "socket"
signal.signal(signal.SIGWINCH, resize_handler) signal.signal(signal.SIGWINCH, resize_handler)
except termios.error: except termios.error:
original_attr = None original_attr = None
try: try:
# Prepare socket for communicating with conmon/container # TODO: socket.SOCK_SEQPACKET may not be supported in Windows
with socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET) as skt: with socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET) as skt:
# Prepare socket for communicating with conmon/container
skt.connect(io_socket) skt.connect(io_socket)
skt.sendall(b'\n') skt.sendall(b'\n')
sources = [skt, stdin] sources = [skt, stdin]
while sources: while sources:
readable, _, _ = select.select(sources, [], []) readable, _, _ = select.select(sources, [], [])
for r in readable: if skt in readable:
if r is skt: data = skt.recv(CONMON_BUFSZ)
data = r.recv(CONMON_BUFSZ) if not data:
if not data: sources.remove(skt)
sources.remove(skt)
# Remove source marker when writing # Remove source marker when writing
os.write(stdout, data[1:]) os.write(stdout, data[1:])
elif r is stdin:
data = os.read(stdin, CONMON_BUFSZ)
if not data:
sources.remove(stdin)
skt.sendall(data) if stdin in readable:
data = os.read(stdin, CONMON_BUFSZ)
if not data:
sources.remove(stdin)
if eot in data: skt.sendall(data)
sources.clear()
break if eot in data:
else: sources.clear()
raise ValueError('Unknown source in select')
finally: finally:
if original_attr: if original_attr:
termios.tcsetattr(stdout, termios.TCSADRAIN, original_attr) termios.tcsetattr(stdout, termios.TCSADRAIN, original_attr)

View File

@ -26,6 +26,7 @@ class Portal(collections.MutableMapping):
self.sweap = sweap self.sweap = sweap
self.ttl = sweap * 2 self.ttl = sweap * 2
self.lock = threading.RLock() self.lock = threading.RLock()
self._schedule_reaper()
def __getitem__(self, key): def __getitem__(self, key):
"""Given uri return tunnel and update TTL.""" """Given uri return tunnel and update TTL."""
@ -73,11 +74,12 @@ class Portal(collections.MutableMapping):
def reap(self): def reap(self):
"""Remove tunnels who's TTL has expired.""" """Remove tunnels who's TTL has expired."""
now = time.time()
with self.lock: with self.lock:
now = time.time() reaped_data = self.data.copy()
for entry, timeout in self.data: for entry in reaped_data.items():
if timeout < now: if entry[1][1] < now:
self.__delitem__(entry) del self.data[entry[0]]
else: else:
# StopIteration as soon as possible # StopIteration as soon as possible
break break
@ -121,7 +123,6 @@ class Tunnel(object):
def close(self, id): def close(self, id):
"""Close SSH tunnel.""" """Close SSH tunnel."""
print('Tunnel collapsed!')
if self._tunnel is None: if self._tunnel is None:
return return

View File

@ -0,0 +1,37 @@
from __future__ import absolute_import
import unittest
from unittest.mock import patch
import podman
from podman.client import BaseClient, Client, LocalClient, RemoteClient
class TestClient(unittest.TestCase):
def setUp(self):
pass
@patch('podman.libs.system.System.ping', return_value=True)
def test_local(self, mock_ping):
p = Client(
uri='unix:/run/podman',
interface='io.projectatomic.podman',
)
self.assertIsInstance(p._client, LocalClient)
self.assertIsInstance(p._client, BaseClient)
mock_ping.assert_called_once()
@patch('os.path.isfile', return_value=True)
@patch('podman.libs.system.System.ping', return_value=True)
def test_remote(self, mock_ping, mock_isfile):
p = Client(
uri='unix:/run/podman',
interface='io.projectatomic.podman',
remote_uri='ssh://user@hostname/run/podmain/podman',
identity_file='~/.ssh/id_rsa')
self.assertIsInstance(p._client, BaseClient)
mock_ping.assert_called_once()
mock_isfile.assert_called_once()

View File

@ -32,8 +32,8 @@ class TestSystem(unittest.TestCase):
uri=local_uri, uri=local_uri,
remote_uri=remote_uri, remote_uri=remote_uri,
identity_file=os.path.expanduser('~/.ssh/id_rsa'), identity_file=os.path.expanduser('~/.ssh/id_rsa'),
) as pclient: ) as remote_client:
pclient.system.ping() remote_client.system.ping()
def test_versions(self): def test_versions(self):
with podman.Client(self.host) as pclient: with podman.Client(self.host) as pclient:

View File

@ -0,0 +1,79 @@
from __future__ import absolute_import
import time
import unittest
from unittest.mock import MagicMock, patch
import podman
from podman.libs.tunnel import Context, Portal, Tunnel
class TestTunnel(unittest.TestCase):
def setUp(self):
self.tunnel_01 = MagicMock(spec=Tunnel)
self.tunnel_02 = MagicMock(spec=Tunnel)
def test_portal_ops(self):
portal = Portal(sweap=500)
portal['unix:/01'] = self.tunnel_01
portal['unix:/02'] = self.tunnel_02
self.assertEqual(portal.get('unix:/01'), self.tunnel_01)
self.assertEqual(portal.get('unix:/02'), self.tunnel_02)
del portal['unix:/02']
with self.assertRaises(KeyError):
portal['unix:/02']
self.assertEqual(len(portal), 1)
def test_portal_reaping(self):
portal = Portal(sweap=0.5)
portal['unix:/01'] = self.tunnel_01
portal['unix:/02'] = self.tunnel_02
self.assertEqual(len(portal), 2)
for entry in portal:
self.assertIn(entry, (self.tunnel_01, self.tunnel_02))
time.sleep(1)
portal.reap()
self.assertEqual(len(portal), 0)
def test_portal_no_reaping(self):
portal = Portal(sweap=500)
portal['unix:/01'] = self.tunnel_01
portal['unix:/02'] = self.tunnel_02
portal.reap()
self.assertEqual(len(portal), 2)
for entry in portal:
self.assertIn(entry, (self.tunnel_01, self.tunnel_02))
@patch('subprocess.Popen')
@patch('os.path.exists', return_value=True)
@patch('weakref.finalize')
def test_tunnel(self, mock_finalize, mock_exists, mock_Popen):
context = Context(
'unix:/01',
'io.projectatomic.podman',
'/tmp/user/socket',
'/run/podman/socket',
'user',
'hostname',
'~/.ssh/id_rsa',
)
tunnel = Tunnel(context).bore('unix:/01')
cmd = [
'ssh',
'-nNT',
'-L',
'{}:{}'.format(context.local_socket, context.remote_socket),
'-i',
context.identity_file,
'ssh://{}@{}'.format(context.username, context.hostname),
]
mock_finalize.assert_called_once_with(tunnel, tunnel.close, 'unix:/01')
mock_exists.assert_called_once_with(context.local_socket)
mock_Popen.assert_called_once_with(cmd, close_fds=True)