Skip to content
Open
242 changes: 132 additions & 110 deletions siriuspy/siriuspy/clientarch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@

import asyncio as _asyncio
import logging as _log
import ssl as _ssl
import urllib as _urllib
from datetime import timedelta as _timedelta
from threading import Thread as _Thread
from urllib.parse import quote as _quote

import numpy as _np
import urllib3 as _urllib3
from aiohttp import ClientSession as _ClientSession
from aiohttp import (
client_exceptions as _aio_exceptions,
ClientSession as _ClientSession,
TCPConnector as _TCPConn
)

try:
from lzstring import LZString as _LZString
except:
Expand All @@ -34,25 +38,57 @@ class ClientArchiver:
SERVER_URL = _envars.SRVURL_ARCHIVER
ENDPOINT = '/mgmt/bpl'

def __del__(self):
    """Turn off background thread when the instance is finalized.

    NOTE(review): this was previously named ``__delete__``, which is the
    descriptor-protocol hook (called for ``del owner.attr``) and never
    fires at object finalization; ``__del__`` is the intended finalizer.
    """
    self.shutdown()

def __init__(self, server_url=None, timeout=None):
    """Initialize client and start the background event loop.

    Args:
        server_url (str, optional): archiver base URL. Defaults to
            ``ClientArchiver.SERVER_URL``.
        timeout (float, optional): request timeout. Defaults to
            ``ClientArchiver.DEFAULT_TIMEOUT``.
    """
    self.session = None  # aiohttp session, created by login()
    self._timeout = timeout or ClientArchiver.DEFAULT_TIMEOUT
    self._url = server_url or self.SERVER_URL
    self._ret = None  # NOTE(review): looks unused by the new async path — confirm
    self._request_url = None
    self._thread = self._loop = None
    self.connect()
    # Server uses self-signed certificates; silence urllib3's warning.
    _urllib3.disable_warnings(_urllib3.exceptions.InsecureRequestWarning)

def connect(self):
    """Start the background event loop in a separate daemon thread.

    Calling it while the loop is already running is a no-op (the
    method returns early; no exception is raised).
    """
    if self._loop_alive():
        return

    self._loop = _asyncio.new_event_loop()
    self._thread = _Thread(target=self._run_event_loop, daemon=True)
    self._thread.start()

def shutdown(self, timeout=5):
    """Safely stop the background loop and wait for the thread to exit.

    Args:
        timeout (float): max seconds to wait for the thread to join.
    """
    if not self._loop_alive():
        return

    # 1. Cancel all pending tasks in the loop (to avoid ResourceWarnings)
    self._loop.call_soon_threadsafe(self._cancel_all_tasks)

    # 2. Schedule the loop to stop processing
    self._loop.call_soon_threadsafe(self._loop.stop)

    # 3. Wait for the thread to actually finish; use the module logger
    # instead of print so the message honors the app's logging config.
    self._thread.join(timeout=timeout)
    if self._thread.is_alive():
        _log.warning('Background thread did not stop in time.')

@property
def connected(self):
    """Return True when the loop is up and the server is reachable."""
    if not self._loop_alive():
        return False
    try:
        # A truthy response from the management endpoint means the
        # archiver server answered the request.
        return bool(self._make_request(self._url + '/mgmt'))
    except _urllib.error.URLError:
        return False

Expand Down Expand Up @@ -94,16 +130,11 @@ def last_requested_url(self):

def login(self, username, password):
"""Open login session."""
headers = {'User-Agent': 'Mozilla/5.0'}
headers = {'User-Agent': 'Mozilla/5.0', 'Host': 'cnpem.br'}
payload = {'username': username, 'password': password}
url = self._create_url(method='login')
ret = self._run_async_event_loop(
self._create_session,
url,
headers=headers,
payload=payload,
ssl=False,
)
coro = self._create_session(url, headers=headers, payload=payload)
ret = self._run_sync_coro(coro)
if ret is not None:
self.session, authenticated = ret
if authenticated:
Expand All @@ -119,7 +150,8 @@ def login(self, username, password):
def logout(self):
    """Close login session.

    Returns the server response, or None when no session was open.
    """
    if self.session:
        coro = self._close_session()
        resp = self._run_sync_coro(coro)
        self.session = None
        return resp
    return None
Expand All @@ -129,15 +161,15 @@ def getPVsInfo(self, pvnames):
if isinstance(pvnames, (list, tuple)):
pvnames = ','.join(pvnames)
url = self._create_url(method='getPVStatus', pv=pvnames)
resp = self._make_request(url, return_json=True)
resp = self._make_request(url)
return None if not resp else resp

def getAllPVs(self, pvnames):
    """Get All PVs.

    Args:
        pvnames (str | list | tuple): PV name pattern(s); sequences are
            joined into a single comma-separated string.

    Returns the server response, or None on an empty/failed response.
    """
    if isinstance(pvnames, (list, tuple)):
        pvnames = ','.join(pvnames)
    url = self._create_url(method='getAllPVs', pv=pvnames, limit='-1')
    resp = self._make_request(url)
    return None if not resp else resp

def deletePVs(self, pvnames):
Expand All @@ -153,7 +185,7 @@ def deletePVs(self, pvnames):
def getPausedPVsReport(self):
    """Get Paused PVs Report.

    Returns the server response, or None on an empty/failed response.
    """
    url = self._create_url(method='getPausedPVsReport')
    resp = self._make_request(url)
    return None if not resp else resp

def getRecentlyModifiedPVs(self, limit=None, epoch_time=True):
Expand All @@ -167,7 +199,7 @@ def getRecentlyModifiedPVs(self, limit=None, epoch_time=True):
if limit is not None:
method += f'?limit={str(limit)}'
url = self._create_url(method=method)
resp = self._make_request(url, return_json=True)
resp = self._make_request(url)

# convert to epoch, if the case
if resp and epoch_time:
Expand Down Expand Up @@ -284,7 +316,7 @@ def getData(
end = len(all_urls)
pvn2idcs[pvname_orig[i]] = _np.arange(ini, end)

resps = self._make_request(all_urls, return_json=True)
resps = self._make_request(all_urls)
if not resps:
return None

Expand Down Expand Up @@ -331,7 +363,7 @@ def getPVDetails(self, pvname, get_request_url=False):
url = self._create_url(method='getPVDetails', pv=pvname)
if get_request_url:
return url
resp = self._make_request(url, return_json=True)
resp = self._make_request(url)
return None if not resp else resp

def switch_to_online_data(self):
Expand All @@ -352,7 +384,7 @@ def gen_archviewer_url_link(
time_ref=None,
pvoptnrpts=None,
pvcolors=None,
pvusediff=False
pvusediff=False,
):
"""Generate a Archiver Viewer URL for the given PVs.

Expand Down Expand Up @@ -396,7 +428,8 @@ def gen_archviewer_url_link(
# Thanks to Rafael Lyra for the basis of this implementation!
archiver_viewer_url = _envars.SRVURL_ARCHIVER_VIEWER + '/?pvConfig='
args = ClientArchiver._process_url_link_args(
pvnames, pvoptnrpts, pvcolors, pvusediff)
pvnames, pvoptnrpts, pvcolors, pvusediff
)
pvoptnrpts, pvcolors, pvusediff = args
pv_search = ''
for idx in range(len(pvnames)):
Expand Down Expand Up @@ -455,16 +488,36 @@ def _process_url_link_args(pvnames, pvoptnrpts, pvcolors, pvusediff):
pvusediff = [pvusediff] * len(pvnames)
return pvoptnrpts, pvcolors, pvusediff

def _make_request(self, url, need_login=False, return_json=False):
def _loop_alive(self):
    """Tell whether the background thread is running its event loop."""
    thread = self._thread
    if thread is None or not thread.is_alive():
        return False
    return self._loop.is_running()

def _cancel_all_tasks(self):
    """Cancel every task still scheduled on the loop.

    Must be invoked from within the loop's own thread.
    """
    task_lister = getattr(_asyncio, 'all_tasks', None)
    if task_lister is None:  # python 3.6 fallback
        task_lister = _asyncio.Task.all_tasks

    for pending in task_lister(loop=self._loop):
        pending.cancel()

def _run_event_loop(self):
    """Thread target: bind the loop to this thread and serve it."""
    _asyncio.set_event_loop(self._loop)
    try:
        # Runs until loop.stop() is scheduled (see shutdown()).
        self._loop.run_forever()
    finally:
        # Always release loop resources, even on unexpected errors.
        self._loop.close()

def _make_request(self, url, need_login=False):
    """Make request.

    Args:
        url (str | list): one URL or a list of URLs to request.
        need_login (bool): when True and no session is open, an
            AuthenticationError is raised by the request handler.

    Returns the decoded response (or list of responses).
    """
    self._request_url = url
    coro = self._handle_request(url, need_login=need_login)
    return self._run_sync_coro(coro)

def _create_url(self, method, **kwargs):
"""Create URL."""
Expand All @@ -479,98 +532,67 @@ def _create_url(self, method, **kwargs):
url += '&'.join(['{}={}'.format(k, v) for k, v in kwargs.items()])
return url

# ---------- async methods ----------

def _run_async_event_loop(self, *args, **kwargs):
# NOTE: Run the asyncio commands in a separated Thread to isolate
# their EventLoop from the external environment (important for class
# to work within jupyter notebook environment).
_thread = _Thread(
target=self._thread_run_async_event_loop,
daemon=True,
args=args,
kwargs=kwargs,
)
_thread.start()
_thread.join()
return self._ret

def _thread_run_async_event_loop(self, func, *args, **kwargs):
"""Get event loop."""
close = False
try:
loop = _asyncio.get_event_loop()
except RuntimeError as error:
if 'no current event loop' in str(error):
loop = _asyncio.new_event_loop()
_asyncio.set_event_loop(loop)
close = True
else:
raise error
try:
self._ret = loop.run_until_complete(func(*args, **kwargs))
except _asyncio.TimeoutError:
raise _exceptions.TimeoutError
def _run_sync_coro(self, coro):
    """Run an async coroutine synchronously, compatible with Jupyter.

    Raises:
        RuntimeError: if the background loop thread is not available.
    """
    # Guard against use before connect() succeeded or after shutdown();
    # checking for None avoids an AttributeError on a bare instance.
    if self._thread is None or not self._thread.is_alive():
        raise RuntimeError('Library is shut down')
    future = _asyncio.run_coroutine_threadsafe(coro, self._loop)
    return future.result(timeout=self._timeout)

if close:
loop.close()
# ---------- async methods ----------

async def _handle_request(self, url, return_json=False, need_login=False):
async def _handle_request(self, url, need_login=False):
    """Handle request.

    Uses the logged-in session when available; otherwise opens a
    throw-away session (TLS verification disabled — self-signed certs).

    Raises:
        AuthenticationError: when `need_login` is True and no session
            is open.
    """
    if self.session is not None:
        response = await self._get_request_response(url, self.session)
    elif need_login:
        raise _exceptions.AuthenticationError('You need to login first.')
    else:
        async with _ClientSession(connector=_TCPConn(ssl=False)) as sess:
            response = await self._get_request_response(url, sess)
    return response

async def _get_request_response(self, url, session, return_json):
async def _get_request_response(self, url, session):
    """Get request response.

    `url` may be a single URL string or a list of URLs; requests are
    issued concurrently and the decoded response(s) returned in input
    order (a single value for a single URL, a list otherwise).

    Raises:
        TimeoutError: when the request exceeds the client timeout.
        PayloadError: when the server payload cannot be read whole.
    """
    single = isinstance(url, str)
    url = [url] if single else url
    try:
        response = await _asyncio.gather(*[
            self._fetch_url(session, u) for u in url
        ])
    except _asyncio.TimeoutError as err:
        raise _exceptions.TimeoutError(
            'Timeout reached. Try to increase `timeout`.'
        ) from err
    except _aio_exceptions.ClientPayloadError as err:
        raise _exceptions.PayloadError(
            "Payload Error. Increasing `timeout` won't help. "
            'Try:\n - decreasing `query_bin_interval`;'
            '\n - or decrease the time interval for the aquisition;'
        ) from err

    if single:
        return response[0]
    return response

async def _create_session(self, url, headers, payload, ssl):
async def _fetch_url(self, session, url):
    """GET `url` with `session` and return the decoded body.

    Returns the parsed JSON payload; falls back to the raw text when
    the response is not JSON, and returns None on a non-200 status or
    a JSON decoding error.
    """
    async with session.get(url, timeout=self._timeout) as response:
        if response.status != 200:
            return None
        try:
            return await response.json()
        except _aio_exceptions.ContentTypeError:
            # for cases where response returns html (self.connected).
            return await response.text()
        except ValueError:
            _log.error('Error with URL %s', response.url)
            return None

async def _create_session(self, url, headers, payload):
"""Create session and handle login."""
session = _ClientSession()
session = _ClientSession(connector=_TCPConn(ssl=False))
async with session.post(
url, headers=headers, data=payload, ssl=ssl, timeout=self._timeout
url, headers=headers, data=payload, timeout=self._timeout
) as response:
content = await response.content.read()
authenticated = b'authenticated' in content
Expand Down
Loading