diff --git a/siriuspy/siriuspy/clientarch/client.py b/siriuspy/siriuspy/clientarch/client.py index dbd7387d5..902584a50 100644 --- a/siriuspy/siriuspy/clientarch/client.py +++ b/siriuspy/siriuspy/clientarch/client.py @@ -8,7 +8,6 @@ import asyncio as _asyncio import logging as _log -import ssl as _ssl import urllib as _urllib from datetime import timedelta as _timedelta from threading import Thread as _Thread @@ -16,7 +15,12 @@ import numpy as _np import urllib3 as _urllib3 -from aiohttp import ClientSession as _ClientSession +from aiohttp import ( + client_exceptions as _aio_exceptions, + ClientSession as _ClientSession, + TCPConnector as _TCPConn +) + try: from lzstring import LZString as _LZString except: @@ -34,25 +38,57 @@ class ClientArchiver: SERVER_URL = _envars.SRVURL_ARCHIVER ENDPOINT = '/mgmt/bpl' + def __delete__(self): + """Turn off thread when deleting.""" + self.shutdown() + def __init__(self, server_url=None, timeout=None): """Initialize.""" timeout = timeout or ClientArchiver.DEFAULT_TIMEOUT self.session = None self._timeout = timeout self._url = server_url or self.SERVER_URL - self._ret = None self._request_url = None - # print('urllib3 InsecureRequestWarning disabled!') + self._thread = self._loop = None + self.connect() _urllib3.disable_warnings(_urllib3.exceptions.InsecureRequestWarning) + def connect(self): + """Starts bg. event loop in a separate thread. + + Raises: + RuntimeError: when library is alread connected. + """ + if self._loop_alive(): + return + + self._loop = _asyncio.new_event_loop() + self._thread = _Thread(target=self._run_event_loop, daemon=True) + self._thread.start() + + def shutdown(self, timeout=5): + """Safely stops the bg. loop and waits for the thread to exit.""" + if not self._loop_alive(): + return + + # 1. Cancel all pending tasks in the loop (to avoid ResourceWarnings) + self._loop.call_soon_threadsafe(self._cancel_all_tasks) + + # 2. Schedule the loop to stop processing + self._loop.call_soon_threadsafe(self._loop.stop) + + # 3. Wait for the thread to actually finish + self._thread.join(timeout=timeout) + if self._thread.is_alive(): + print('Warning: Background thread did not stop in time.') + @property def connected(self): """Connected.""" + if not self._loop_alive(): + return False try: - status = _urllib.request.urlopen( - self._url, timeout=self._timeout, context=_ssl.SSLContext() - ).status - return status == 200 + return bool(self._make_request(self._url + '/mgmt')) except _urllib.error.URLError: return False @@ -94,16 +130,11 @@ def last_requested_url(self): def login(self, username, password): """Open login session.""" - headers = {'User-Agent': 'Mozilla/5.0'} + headers = {'User-Agent': 'Mozilla/5.0', 'Host': 'cnpem.br'} payload = {'username': username, 'password': password} url = self._create_url(method='login') - ret = self._run_async_event_loop( - self._create_session, - url, - headers=headers, - payload=payload, - ssl=False, - ) + coro = self._create_session(url, headers=headers, payload=payload) + ret = self._run_sync_coro(coro) if ret is not None: self.session, authenticated = ret if authenticated: @@ -119,7 +150,8 @@ def login(self, username, password): def logout(self): """Close login session.""" if self.session: - resp = self._run_async_event_loop(self._close_session) + coro = self._close_session() + resp = self._run_sync_coro(coro) self.session = None return resp return None @@ -129,7 +161,7 @@ def getPVsInfo(self, pvnames): if isinstance(pvnames, (list, tuple)): pvnames = ','.join(pvnames) url = self._create_url(method='getPVStatus', pv=pvnames) - resp = self._make_request(url, return_json=True) + resp = self._make_request(url) return None if not resp else resp def getAllPVs(self, pvnames): @@ -137,7 +169,7 @@ def getAllPVs(self, pvnames): if isinstance(pvnames, (list, tuple)): pvnames = ','.join(pvnames) url = self._create_url(method='getAllPVs', pv=pvnames, limit='-1') - resp = self._make_request(url, return_json=True) + resp = self._make_request(url) return None if not resp else resp def deletePVs(self, pvnames): @@ -153,7 +185,7 @@ def deletePVs(self, pvnames): def getPausedPVsReport(self): """Get Paused PVs Report.""" url = self._create_url(method='getPausedPVsReport') - resp = self._make_request(url, return_json=True) + resp = self._make_request(url) return None if not resp else resp def getRecentlyModifiedPVs(self, limit=None, epoch_time=True): @@ -167,7 +199,7 @@ def getRecentlyModifiedPVs(self, limit=None, epoch_time=True): if limit is not None: method += f'?limit={str(limit)}' url = self._create_url(method=method) - resp = self._make_request(url, return_json=True) + resp = self._make_request(url) # convert to epoch, if the case if resp and epoch_time: @@ -284,7 +316,7 @@ def getData( end = len(all_urls) pvn2idcs[pvname_orig[i]] = _np.arange(ini, end) - resps = self._make_request(all_urls, return_json=True) + resps = self._make_request(all_urls) if not resps: return None @@ -331,7 +363,7 @@ def getPVDetails(self, pvname, get_request_url=False): url = self._create_url(method='getPVDetails', pv=pvname) if get_request_url: return url - resp = self._make_request(url, return_json=True) + resp = self._make_request(url) return None if not resp else resp def switch_to_online_data(self): @@ -352,7 +384,7 @@ def gen_archviewer_url_link( time_ref=None, pvoptnrpts=None, pvcolors=None, - pvusediff=False + pvusediff=False, ): """Generate a Archiver Viewer URL for the given PVs. @@ -396,7 +428,8 @@ def gen_archviewer_url_link( # Thanks to Rafael Lyra for the basis of this implementation! archiver_viewer_url = _envars.SRVURL_ARCHIVER_VIEWER + '/?pvConfig=' args = ClientArchiver._process_url_link_args( - pvnames, pvoptnrpts, pvcolors, pvusediff) + pvnames, pvoptnrpts, pvcolors, pvusediff + ) pvoptnrpts, pvcolors, pvusediff = args pv_search = '' for idx in range(len(pvnames)): @@ -455,16 +488,36 @@ def _process_url_link_args(pvnames, pvoptnrpts, pvcolors, pvusediff): pvusediff = [pvusediff] * len(pvnames) return pvoptnrpts, pvcolors, pvusediff - def _make_request(self, url, need_login=False, return_json=False): + def _loop_alive(self): + """Check if thread is alive and loop is running.""" + return ( + self._thread is not None + and self._thread.is_alive() + and self._loop.is_running() + ) + + def _cancel_all_tasks(self): + """Helper to cancel tasks (must be called from the loop's thread).""" + if hasattr(_asyncio, 'all_tasks'): + all_tasks = _asyncio.all_tasks(loop=self._loop) + else: # python 3.6 + all_tasks = _asyncio.Task.all_tasks(loop=self._loop) + + for task in all_tasks: + task.cancel() + + def _run_event_loop(self): + _asyncio.set_event_loop(self._loop) + try: + self._loop.run_forever() + finally: + self._loop.close() + + def _make_request(self, url, need_login=False): """Make request.""" self._request_url = url - response = self._run_async_event_loop( - self._handle_request, - url, - return_json=return_json, - need_login=need_login, - ) - return response + coro = self._handle_request(url, need_login=need_login) + return self._run_sync_coro(coro) def _create_url(self, method, **kwargs): """Create URL.""" @@ -479,98 +532,67 @@ def _create_url(self, method, **kwargs): url += '&'.join(['{}={}'.format(k, v) for k, v in kwargs.items()]) return url - # ---------- async methods ---------- - - def _run_async_event_loop(self, *args, **kwargs): - # NOTE: Run the asyncio commands in a separated Thread to isolate - # their EventLoop from the external environment (important for class - # to work within jupyter notebook environment). - _thread = _Thread( - target=self._thread_run_async_event_loop, - daemon=True, - args=args, - kwargs=kwargs, - ) - _thread.start() - _thread.join() - return self._ret - - def _thread_run_async_event_loop(self, func, *args, **kwargs): - """Get event loop.""" - close = False - try: - loop = _asyncio.get_event_loop() - except RuntimeError as error: - if 'no current event loop' in str(error): - loop = _asyncio.new_event_loop() - _asyncio.set_event_loop(loop) - close = True - else: - raise error - try: - self._ret = loop.run_until_complete(func(*args, **kwargs)) - except _asyncio.TimeoutError: - raise _exceptions.TimeoutError + def _run_sync_coro(self, coro): + """Run an async coroutine synchronously, compatible with Jupyter.""" + if not self._thread.is_alive(): + raise RuntimeError('Library is shut down') + future = _asyncio.run_coroutine_threadsafe(coro, self._loop) + return future.result(timeout=self._timeout) - if close: - loop.close() + # ---------- async methods ---------- - async def _handle_request(self, url, return_json=False, need_login=False): + async def _handle_request(self, url, need_login=False): """Handle request.""" if self.session is not None: - response = await self._get_request_response( - url, self.session, return_json - ) + response = await self._get_request_response(url, self.session) elif need_login: raise _exceptions.AuthenticationError('You need to login first.') else: - async with _ClientSession() as sess: - response = await self._get_request_response( - url, sess, return_json - ) + async with _ClientSession(connector=_TCPConn(ssl=False)) as sess: + response = await self._get_request_response(url, sess) return response - async def _get_request_response(self, url, session, return_json): + async def _get_request_response(self, url, session): """Get request response.""" + single = isinstance(url, str) + url = [url] if single else url try: - if isinstance(url, list): - response = await _asyncio.gather(*[ - session.get(u, ssl=False, timeout=self._timeout) - for u in url - ]) - if any([not r.ok for r in response]): - return None - if return_json: - jsons = list() - for res in response: - try: - data = await res.json() - jsons.append(data) - except ValueError: - _log.error(f'Error with URL {res.url}') - jsons.append(None) - response = jsons - else: - response = await session.get( - url, ssl=False, timeout=self._timeout - ) - if not response.ok: - return None - if return_json: - try: - response = await response.json() - except ValueError: - _log.error(f'Error with URL {response.url}') - response = None - except _asyncio.TimeoutError as err_msg: - raise _exceptions.TimeoutError(err_msg) + response = await _asyncio.gather(*[ + self._fetch_url(session, u) for u in url + ]) + except _asyncio.TimeoutError as err: + raise _exceptions.TimeoutError( + 'Timeout reached. Try to increase `timeout`.' + ) from err + except _aio_exceptions.ClientPayloadError as err: + raise _exceptions.PayloadError( + "Payload Error. Increasing `timeout` won't help. " + 'Try:\n - decreasing `query_bin_interval`;' + '\n - or decrease the time interval for the aquisition;' + ) from err + + if single: + return response[0] return response - async def _create_session(self, url, headers, payload, ssl): + async def _fetch_url(self, session, url): + async with session.get(url, timeout=self._timeout) as response: + if response.status != 200: + return None + try: + return await response.json() + except _aio_exceptions.ContentTypeError: + # for cases where response returns html (self.connected). + return await response.text() + except ValueError: + _log.error('Error with URL %s', response.url) + return None + + async def _create_session(self, url, headers, payload): """Create session and handle login.""" - session = _ClientSession() + session = _ClientSession(connector=_TCPConn(ssl=False)) async with session.post( - url, headers=headers, data=payload, ssl=ssl, timeout=self._timeout + url, headers=headers, data=payload, timeout=self._timeout ) as response: content = await response.content.read() authenticated = b'authenticated' in content