Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Cleanups
  • Loading branch information
Panos committed Nov 30, 2020
commit a8b1ad39f66fb30e6a5b5f8cd7ef76f2c307f72f
90 changes: 20 additions & 70 deletions pssh/clients/native/tunnel_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from gevent import socket, spawn, joinall, get_hub, sleep, Timeout as GTimeout
from gevent.pool import Pool
from gevent.server import StreamServer
from gevent.select import poll, select, POLLIN, POLLOUT
from gevent.select import poll, POLLIN, POLLOUT
from ssh2.session import Session, LIBSSH2_SESSION_BLOCK_INBOUND, LIBSSH2_SESSION_BLOCK_OUTBOUND
from ssh2.error_codes import LIBSSH2_ERROR_EAGAIN

Expand All @@ -37,24 +37,20 @@ class ThreadedServer(Thread):

def __init__(self, client):
Thread.__init__(self)
self.servers = []
self.clients = []
self.server = None
self._hub = None
self.client = client
self.started = Event()

def run(self):
"""Thread run target. Starts tunnel client and waits for incoming
tunnel connection requests from ``Tunnel.in_q``."""
self._hub = get_hub()
assert self._hub.main_hub is False
self.server = TunnelServer(self.client)
self.started.set()
# try:
# self._init_tunnel_client()
# except Exception as ex:
# logger.error("Tunnel initilisation failed - %s", ex)
# self.exception = ex
# return
logger.debug("Hub in server runner is main hub: %s", self._hub.main_hub)
try:
self.server.serve_forever()
Expand Down Expand Up @@ -85,34 +81,6 @@ def read_rw(self, socket, address):
self.client.host, self.client.port, ex)
self.exception = ex
return
# events = POLLIN
# sockets = [socket, self.session.sock]
# while True:
# # import ipdb; ipdb.set_trace()
# read_s, write_s, x_sock = select([socket, self.session.sock], [], [], self.timeout)
# # self.poll()
# # self._poll_sockets(sockets, events)
# size, data = channel.read()
# if size > 0:
# socket.sendall(data)
# # self._poll_sockets(sockets, events)
# if socket in read_s:
# with GTimeout(seconds=1):
# try:
# data = socket.recv(1024)
# except GTimeout:
# continue
# data_len = len(data)
# total_written = 0
# if data_len > 0:
# while total_written < data_len:
# rc, written = channel.write(data[total_written:])
# total_written += written
# if rc == LIBSSH2_ERROR_EAGAIN:
# self.poll()
# while channel.close() == LIBSSH2_ERROR_EAGAIN:
# self.poll()
# socket.close()
source = spawn(self._read_forward_sock, socket, channel)
dest = spawn(self._read_channel, socket, channel)
logger.debug("Waiting for read/write greenlets")
Expand All @@ -127,101 +95,83 @@ def _wait_send_receive_lets(self, source, dest, channel, forward_sock):
logger.error(ex)
finally:
logger.debug("Closing channel and forward socket")
channel.close()
while channel.close() == LIBSSH2_ERROR_EAGAIN:
self.poll(timeout=.5)
forward_sock.close()

def _open_channel(self, fw_host, fw_port, local_port):
channel = self.session.direct_tcpip_ex(
fw_host, fw_port, '127.0.0.1',
local_port)
while channel == LIBSSH2_ERROR_EAGAIN:
select((self.client.sock,), (self.client.sock,), ())
self.poll()
channel = self.session.direct_tcpip_ex(
fw_host, fw_port, '127.0.0.1',
local_port)
return channel

def _read_forward_sock(self, forward_sock, channel):
while True:
if channel.eof():
logger.debug("Channel closed")
return
try:
logger.debug("Trying to read from socket")
with GTimeout(seconds=1):
data = forward_sock.recv(1024)
data = forward_sock.recv(1024)
except Exception as ex:
logger.error("Forward socket read error: %s", ex)
sleep(1)
continue
except GTimeout:
logger.debug("Timeout socket read")
self.poll()
continue
data_len = len(data)
logger.debug("Read %s data from forward socket" % (data_len,))
logger.debug("Read %s data from forward socket", data_len,)
if data_len == 0:
continue
data_written = 0
while data_written < data_len:
try:
with GTimeout(seconds=1):
rc, bytes_written = channel.write(data[data_written:])
rc, bytes_written = channel.write(data[data_written:])
except Exception as ex:
logger.error("Channel write error: %s", ex)
sleep(1)
continue
except GTimeout:
logger.debug("Timeout writing to channel")
self.poll()
continue
data_written += bytes_written
if rc == LIBSSH2_ERROR_EAGAIN:
sleep(.2)
logger.debug("Data remaining %s", (data_len - data_written,))
# select((), ((self.client.sock,)), (), timeout=0.001)
self.poll()
logger.debug("Wrote all data to channel")

def _read_channel(self, forward_sock, channel):
while True:
if channel.eof():
logger.debug("Channel closed")
return
try:
with GTimeout(seconds=1):
size, data = channel.read()
size, data = channel.read()
except Exception as ex:
logger.error("Error reading from channel - %s", ex)
sleep(1)
continue
except GTimeout:
logger.debug("Timeout channel read")
self.poll()
continue
logger.debug("Read %s data from channel" % (size,))
if size == LIBSSH2_ERROR_EAGAIN:
sleep(.2)
self.poll()
continue
try:
with GTimeout(seconds=1):
forward_sock.sendall(data)
forward_sock.sendall(data)
except Exception as ex:
logger.error(
"Error sending data to forward socket - %s", ex)
sleep(.5)
continue
except GTimeout:
logger.debug("Timeout writing to socket")
self.poll()
continue
logger.debug("Wrote %s data to forward socket", len(data))

def _open_channel(self, fw_host, fw_port, local_port):
channel = self.session.direct_tcpip_ex(
fw_host, fw_port, '127.0.0.1',
local_port)
while channel == LIBSSH2_ERROR_EAGAIN:
sleep(.1)
self.poll()
# self.session.set_blocking(1)
channel = self.session.direct_tcpip_ex(
fw_host, fw_port, '127.0.0.1',
local_port)
self.session.set_blocking(0)
return channel

def _open_channel_retries(self, fw_host, fw_port, local_port,
Expand Down