diff --git a/wmpl/Rebound/REBOUND.py b/wmpl/Rebound/REBOUND.py index 92ef3330..75620897 100644 --- a/wmpl/Rebound/REBOUND.py +++ b/wmpl/Rebound/REBOUND.py @@ -14,7 +14,7 @@ REBOUND_FOUND = True except ImportError: - print("REBOUND package not found. Install REBOUND and reboundx packages to use the REBOUND functions.") + # don't print a message here as its already printed whenever REBOUND_FOUND is False REBOUND_FOUND = False from wmpl.Utils.TrajConversions import ( diff --git a/wmpl/Trajectory/CorrelateDB.py b/wmpl/Trajectory/CorrelateDB.py new file mode 100644 index 00000000..fa0c2354 --- /dev/null +++ b/wmpl/Trajectory/CorrelateDB.py @@ -0,0 +1,840 @@ +# The MIT License + +# Copyright (c) 2024 Mark McIntyre + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
def __init__(self, db_path, db_name='observations.db', purge_records=False, verbose=False):
    """
    Create an observations database instance.

    Opens (creating it if necessary) the SQLite file, switches it to
    write-ahead logging and makes sure the paired_obs table exists.

    Parameters:
        db_path : path to the location of the database
        db_name : name to use, typically observations.db
        purge_records : boolean, if true then delete any existing records
        verbose : boolean, if true then log the name of the database being opened
    """
    db_full_name = os.path.join(db_path, f'{db_name}')
    if verbose:
        log.info(f'opening database {db_full_name}')
    con = sqlite3.connect(db_full_name)
    con.execute('pragma journal_mode=wal')
    if purge_records:
        # "if exists" so that purging a brand-new database, which has no
        # paired_obs table yet, does not raise an OperationalError
        con.execute('drop table if exists paired_obs')
    res = con.execute("SELECT name FROM sqlite_master WHERE name='paired_obs'")
    if res.fetchone() is None:
        con.execute("CREATE TABLE paired_obs(obs_id VARCHAR(36) UNIQUE, obs_dt REAL, status INTEGER)")
    con.commit()
    self.dbhandle = con
def addPairedObservations(self, obs_ids, jdt_refs, verbose=False):
    """
    Add or update a list of observations paired, setting status = 1

    Parameters:
        obs_ids : list of observation IDs
        jdt_refs : list of julian reference dates of the observations
        verbose : boolean, if true then log the IDs being added

    Returns:
        True if the records were written (an empty list is a successful no-op),
        False if the values could not be converted or the database rejected them
    """

    if verbose:
        log.info(f'adding {obs_ids} to paired_obs table')
    try:
        # Bind the values instead of formatting them into the statement: the text
        # form of a tuple is not valid SQL for every value (quotes inside an ID,
        # NumPy scalars whose repr is not a plain number) and is empty for an
        # empty list.
        rows = [(obs_id, float(jdt_ref), 1) for obs_id, jdt_ref in zip(obs_ids, jdt_refs)]
        self.dbhandle.executemany("insert or replace into paired_obs values (?, ?, ?)", rows)
        self.dbhandle.commit()
        return True
    except Exception:
        # deliberately best-effort: the caller only needs to know it did not work
        log.warning(f'failed to add {obs_ids} to paired_obs table')
        return False
def unpairObs(self, obs_ids, verbose=False):
    """
    Mark observations unpaired.
    If an entry exists in the database, update the status to 0.
    ** Currently unused. **

    Parameters:
        obs_ids : a list of observation IDs
        verbose : boolean, if true then log the IDs being unpaired

    Returns:
        True if the update was committed, False if the database rejected it
    """
    obs_ids_str = ','.join(obs_ids)

    if verbose:
        log.info(f'unpairing {obs_ids_str}')
    try:
        log.info('update: write to obsdb')
        # The IDs are text, so they must be bound (or quoted) rather than pasted
        # bare into an IN (...) list. One bound statement per ID also avoids
        # any limit on the number of placeholders in a single statement.
        self.dbhandle.executemany("update paired_obs set status = 0 where obs_id = ?",
                                  [(obs_id,) for obs_id in obs_ids])
        self.dbhandle.commit()
        return True
    except Exception:
        log.warning(f'failed to unpair {obs_ids_str}')
        return False
def copyObsJsonRecords(self, paired_obs, dt_range):
    """
    Copy recent data from the legacy Json database to the new database.
    By design this only copies at most the last seven days, but a date-range can be
    provided so that relevant data is copied.

    Parameters:
        paired_obs : a json dict of paired observations from the old database,
                     keyed by station ID, each value a list of observation IDs
        dt_range : a date range to operate on - at most seven days duration
    """
    # only copy recent observations: never go back more than seven days from the end
    dt_end = dt_range[1]
    dt_beg = max(dt_range[0], dt_end + datetime.timedelta(days=-7))

    log.info('-----------------------------')
    log.info('moving recent observations to sqlite - this may take some time....')
    log.info(f'observation date range {dt_beg.isoformat()} to {dt_end.isoformat()}')

    i = 0
    for stat_id in paired_obs:
        for obs_id in paired_obs[stat_id]:
            # the observation time is the part of the ID after the first underscore
            try:
                obs_date = datetime.datetime.strptime(obs_id.split('_')[1], '%Y%m%d-%H%M%S.%f')
            except Exception:
                # unparseable IDs get a date far in the past so they are never copied
                obs_date = datetime.datetime(2000,1,1,0,0,0)
            obs_date = obs_date.replace(tzinfo=datetime.timezone.utc)

            if obs_date >= dt_beg and obs_date < dt_end:
                # addPairedObs needs a julian date as well as the ID. The legacy
                # records only carry the ID, so the date parsed from it is used.
                # NOTE(review): confirm this is the value wanted in obs_dt.
                self.addPairedObs(obs_id, datetime2JD(obs_date))
                i += 1
                if not i % 100000 and i != 0:
                    log.info(f'moved {i} observations')
    self.dbhandle.commit()
    log.info(f'done - moved {i} observations')
    log.info('-----------------------------')
    return
records from another observation database 'source_db_path', for example from a remote node + + Parameters: + source_db_path : full name and path to the source database to merge from + """ + + if not os.path.isfile(source_db_path): + log.warning(f'source database missing: {source_db_path}') + return + # attach the other db, copy the records then detach it + self.dbhandle.execute(f"attach database '{source_db_path}' as sourcedb") + res = self.dbhandle.execute("SELECT name FROM sourcedb.sqlite_master WHERE name='paired_obs'") + if res.fetchone() is None: + # table is missing so nothing to do + status = True + else: + try: + log.info('insert: write to obsdb') + self.dbhandle.execute('insert or replace into paired_obs select * from sourcedb.paired_obs') + status = True + except Exception as e: + log.info(f'unable to merge child observations from {source_db_path}') + log.info(e) + status = False + + self.dbhandle.commit() + self.dbhandle.execute("detach database 'sourcedb'") + return status + + +############################################################ + + +class TrajectoryDatabase(): + """ + A class to handle the sqlite trajectory database transparently. 
def __init__(self, db_path, db_name='trajectories.db', purge_records=False, verbose=False):
    """
    Open the trajectory database, creating the file and its two tables when needed

    Parameters:
        db_path : path to the location to store the database
        db_name : database name
        purge_records : boolean, if true, delete any existing records
        verbose : boolean, if true, log each table creation
    """

    # table name -> statement used to create it, in creation order
    create_statements = {
        'trajectories': """CREATE TABLE trajectories(
                            jdt_ref REAL UNIQUE,
                            traj_id VARCHAR UNIQUE,
                            traj_file_path VARCHAR,
                            participating_stations VARCHAR,
                            ignored_stations VARCHAR,
                            radiant_eci_mini VARCHAR,
                            state_vect_mini VARCHAR,
                            phase_1_only INTEGER,
                            v_init REAL,
                            gravity_factor REAL,
                            v0z REAL,
                            v_avg REAL,
                            rbeg_jd REAL,
                            rend_jd REAL,
                            rbeg_lat REAL,
                            rbeg_lon REAL,
                            rbeg_ele REAL,
                            rend_lat REAL,
                            rend_lon REAL,
                            rend_ele REAL,
                            status INTEGER) """,
        # note: traj_id not set as unique as some fails will have traj-id None
        'failed_trajectories': """CREATE TABLE failed_trajectories(
                            jdt_ref REAL UNIQUE,
                            traj_id VARCHAR,
                            traj_file_path VARCHAR,
                            participating_stations VARCHAR,
                            ignored_stations VARCHAR,
                            radiant_eci_mini VARCHAR,
                            state_vect_mini VARCHAR,
                            phase_1_only INTEGER,
                            v_init REAL,
                            gravity_factor REAL,
                            status INTEGER) """,
    }

    target = os.path.join(db_path, f'{db_name}')
    log.info(f'opening database {target}')
    conn = sqlite3.connect(target)

    if purge_records:
        log.info('purge: write to trajdb')
        for table in create_statements:
            conn.execute(f'drop table if exists {table}')
        conn.commit()

    for table, ddl in create_statements.items():
        present = conn.execute(f"SELECT name FROM sqlite_master WHERE name='{table}'").fetchone()
        if present is not None:
            continue
        if verbose:
            log.info('create table: write to trajdb')
        conn.execute(ddl)

    conn.commit()
    self.dbhandle = conn
    return
def checkCandIfProcessed(self, jdt_ref, station_list, verbose=False):
    """
    Check if a candidate was already processed into the database.
    This function is not currently used.

    The failed_trajectories table is checked first, then trajectories. Only
    rows with status=1 are considered.

    Parameters:
        jdt_ref : candidate's julian reference date
        station_list : candidate's list of stations
        verbose : unused

    Returns:
        True if there is a trajectory with the same jdt_ref and the same set of
        stations (participating plus ignored) as the candidate, False otherwise
    """

    # Compare as sets: the order of the stored lists and of station_list is
    # arbitrary, and a station may appear more than once.
    wanted_stations = set(station_list)

    for table_name in ('failed_trajectories', 'trajectories'):
        res = self.dbhandle.execute(
            f"SELECT traj_id,participating_stations, ignored_stations FROM {table_name} WHERE jdt_ref=? and status=1",
            (float(jdt_ref),))
        row = res.fetchone()
        if row is None:
            continue
        # the station lists are stored as JSON text
        traj_stations = set(json.loads(row[1]) + json.loads(row[2]))
        if traj_stations == wanted_stations:
            return True
    return False
'ignored_stations'): + return False + + found = False + station_list = list(set(traj_reduced.participating_stations + traj_reduced.ignored_stations)) + res = self.dbhandle.execute(f"SELECT traj_id,participating_stations, ignored_stations FROM failed_trajectories WHERE jdt_ref={traj_reduced.jdt_ref} and status=1") + row = res.fetchone() + if row is None: + found = False + else: + traj_stations = list(set(json.loads(row[1]) + json.loads(row[2]))) + found = True if (traj_stations == station_list) else False + return found + + def addTrajectory(self, traj_reduced, failed=False, force_add=True, verbose=False): + """ + add or update an entry in the database, setting status = 1 + + Parameters: + traj_reduced : a TrajReduced object + failed : boolean, if true, add the traj to the fails list + + """ + + tblname = 'failed_trajectories' if failed else 'trajectories' + + # if force_add is false, don't replace any existing entry + if not force_add: + res = self.dbhandle.execute(f'select traj_id from {tblname} where status =1') + row = res.fetchone() + if row is not None and row[0] !='None': + return True + + if verbose: + log.info(f'adding jdt {traj_reduced.jdt_ref} to {tblname}') + + # remove the output_dir part from the path so that the data are location-independent + traj_file_path = traj_reduced.traj_file_path[traj_reduced.traj_file_path.find('trajectories'):] + + # and remove windows-style path separators + traj_file_path = traj_file_path.replace('\\','/') + + if failed: + # fixup possible bad values + traj_id = 'None' if not hasattr(traj_reduced, 'traj_id') or traj_reduced.traj_id is None else traj_reduced.traj_id + v_init = 0 if traj_reduced.v_init is None else traj_reduced.v_init + radiant_eci_mini = [0,0,0] if traj_reduced.radiant_eci_mini is None else traj_reduced.radiant_eci_mini + state_vect_mini = [0,0,0] if traj_reduced.state_vect_mini is None else traj_reduced.state_vect_mini + + sql_str = (f'insert or replace into failed_trajectories values (' + 
def getTrajectories(self, output_dir, jdt_range, failed=False, verbose=False):
    """
    Get a list of trajectories between two julian dates

    Parameters:
        output_dir : output_dir specified when invoking CorrelateRMS - will be prepended to the trajectory path
        jdt_range : tuple of julian dates to retrieve data between. If the 2nd date is
                    None (or 0), only records whose jdt_ref equals the 1st date are returned.
                    NOTE(review): the previous docstring said "retrieve all data to today",
                    but the query has always been an exact match - confirm the intent.
        failed : boolean - if true, retrieve failed traj rather than successful ones
        verbose : boolean - if true, log the range being retrieved

    Returns:
        trajs: list of dicts, one per row, keyed by column name. The status column
               is omitted and the JSON text columns are decoded. Failed trajectories
               only carry the columns that exist in failed_trajectories.
    """

    jdt_start, jdt_end = jdt_range

    table_name = 'failed_trajectories' if failed else 'trajectories'
    if verbose:
        dt_fmt = "%Y%m%d_%H%M%S.%f"
        start_str = jd2Date(jdt_start, dt_obj=True).strftime(dt_fmt)
        if jdt_end:
            end_str = jd2Date(jdt_end, dt_obj=True).strftime(dt_fmt)
            log.info(f'getting trajectories between {start_str} and {end_str}')
        else:
            log.info(f'getting trajectories at {start_str}')

    # table_name is one of two fixed strings; the dates are bound as parameters
    if not jdt_end:
        cur = self.dbhandle.execute(f"SELECT * FROM {table_name} WHERE jdt_ref=?", (jdt_start,))
    else:
        cur = self.dbhandle.execute(f"SELECT * FROM {table_name} WHERE jdt_ref>=? and jdt_ref<=?",
                                    (jdt_start, jdt_end))

    # Use the column names reported by the cursor rather than fixed positions,
    # because the two tables do not have the same number of columns.
    col_names = [col[0] for col in cur.description]
    json_cols = ('participating_stations', 'ignored_stations', 'radiant_eci_mini', 'state_vect_mini')

    trajs = []
    for rw in cur.fetchall():
        json_dict = {}
        for col_name, value in zip(col_names, rw):
            if col_name == 'status':
                continue
            # NaN values are stored as the text 'NaN'
            if isinstance(value, str) and value == 'NaN':
                value = np.nan
            elif col_name in json_cols:
                value = json.loads(value)
            elif col_name == 'traj_file_path':
                value = os.path.join(output_dir, value)
            json_dict[col_name] = value
        trajs.append(json_dict)
    return trajs
def archiveTrajDatabase(self, db_path, arch_prefix, archdate_jd):
    """
    Archive records older than archdate_jd to a database {arch_prefix}_trajectories.db

    Both the trajectories and failed_trajectories tables are processed. For each
    table the rows are copied to the archive and then deleted from this database
    in one transaction, so a failure leaves that table untouched.

    Parameters:
        db_path : path to the location of the archive database
        arch_prefix : prefix to apply - typically of the form yyyymm
        archdate_jd : julian date before which to archive data
    """

    # create the archive database if it doesnt exist
    archdb_name = f'{arch_prefix}_trajectories.db'
    archdb = TrajectoryDatabase(db_path, archdb_name)
    archdb.closeTrajDatabase()

    # SQLite refuses to attach a database inside an open transaction, so flush
    # any lazily-written changes first
    self.dbhandle.commit()

    # attach the arch db, copy the records then delete them
    archdb_fullname = os.path.join(db_path, f'{archdb_name}')
    self.dbhandle.execute("attach database ? as archdb", (archdb_fullname,))
    log.info('delete: write to trajdb')
    try:
        for table_name in ['trajectories', 'failed_trajectories']:
            try:
                # bulk-copy, then remove what was copied
                self.dbhandle.execute(
                    f'insert or replace into archdb.{table_name} select * from {table_name} where jdt_ref < ?',
                    (archdate_jd,))
                self.dbhandle.execute(f'delete from {table_name} where jdt_ref < ?', (archdate_jd,))
                self.dbhandle.commit()
            except sqlite3.Error:
                # undo a copy whose matching delete did not happen
                self.dbhandle.rollback()
                log.warning(f'unable to archive {table_name}')
    finally:
        # always detach, otherwise the next archive run cannot reuse the name
        self.dbhandle.execute("detach database archdb")
    return
failed traj records since if we ever run for an historic date + its likely we will want to reanalyse all available data + + Parameters: + + trajectories : json list of trajetories extracted from the old Json DB + dt_range: : date range to use, at most seven days at a time + failed : boolean, default true to move failed traj + + """ + jd_end = datetime2JD(dt_range[1]) + jd_beg = max(datetime2JD(dt_range[0]), jd_end - 7) + + log.info('moving recent failed trajectories to sqlite - this may take some time....') + log.info(f'observation date range {jd2Date(jd_beg, dt_obj=True).isoformat()} to {dt_range[1].isoformat()}') + + keylist = [k for k in trajectories.keys() if float(k) >= jd_beg and float(k) <= jd_end] + i = 0 # just in case there aren't any trajectories to move + for i,jdt_ref in enumerate(keylist): + self.addTrajectory(trajectories[jdt_ref], failed=failed) + i += 1 + if not i % 10000: + self._commitTrajDatabase() + log.info(f'moved {i} failed_trajectories') + self._commitTrajDatabase() + log.info(f'done - moved {i} failed_trajectories') + + return + + def mergeTrajDatabase(self, source_db_path): + """ + merge in records from another observation database, for example from a remote node + + Parameters: + source_db_path : the full name of the source database from which to merge in records + + """ + + if not os.path.isfile(source_db_path): + log.warning(f'source database missing: {source_db_path}') + return + # attach the other db, copy the records then detach it + log.info('insert: write to trajdb') + cur = self.dbhandle.execute(f"attach database '{source_db_path}' as sourcedb") + + # TODO need to correct the traj_file_path to account for server locations + + status = True + for table_name in ['trajectories', 'failed_trajectories']: + try: + # bulk-copy if possible + cur.execute(f'insert or replace into {table_name} select * from sourcedb.{table_name}') + except Exception: + log.warning(f'unable to merge data from {source_db_path}') + status = False + 
class DummyTrajReduced():
    """
    a dummy class for handling TrajReduced objects.
    We can't import CorrelateRMS as that would create a circular dependency
    """
    def __init__(self, jdt_ref=None, traj_id=None, traj_file_path=None, json_dict=None):
        """
        Parameters:
            jdt_ref : julian reference date, used when json_dict is not given
            traj_id : trajectory ID, used when json_dict is not given
            traj_file_path : path to the trajectory file, used when json_dict is not given
            json_dict : dict of attribute names and values; when given, its entries
                        become the attributes and the other arguments are ignored
        """
        if json_dict is None:
            self.jdt_ref = jdt_ref
            self.traj_id = traj_id
            self.traj_file_path = traj_file_path
        else:
            # copy the entries so that later changes to this object do not
            # write through to the caller's dict
            self.__dict__.update(json_dict)


class dummyDatabaseJSON():
    """
    Dummy class to handle the old Json data format
    We can't import CorrelateRMS as that would create a circular dependency
    """
    def __init__(self, db_dir, dt_range=None):
        """
        Load processed_trajectories.json from db_dir if it exists.

        Parameters:
            db_dir : directory holding the legacy Json database
            dt_range : unused
        """
        self.db_file_path = os.path.join(db_dir, 'processed_trajectories.json')
        self.paired_obs = {}
        self.failed_trajectories = {}
        if os.path.exists(self.db_file_path):
            with open(self.db_file_path) as json_file:
                # update rather than replace, so attributes missing from the
                # file keep the defaults set above
                self.__dict__.update(json.load(json_file))

        # Convert the failed trajectories from JSON dicts to objects, keyed by
        # their julian reference date
        trajectories_obj_dict = {}
        for traj_json in (self.failed_trajectories or {}).values():
            traj_reduced_tmp = DummyTrajReduced(json_dict=traj_json)
            trajectories_obj_dict[traj_reduced_tmp.jdt_ref] = traj_reduced_tmp
        self.failed_trajectories = trajectories_obj_dict
type=str, default=None, help='Database to process, either observations or trajectories') + + arg_parser.add_argument('--action', type=str, default=None, help='Action to take on the database') + + arg_parser.add_argument('--stmt', type=str, default=None, help='statement to execute eg "select * from paired_obs"') + + arg_parser.add_argument("--logdir", type=str, default=None, + help="Path to the directory where the log files will be stored. If not given, a logs folder will be created in the database folder") + + arg_parser.add_argument('-r', '--timerange', metavar='TIME_RANGE', + help="""Apply action to this date range in the format: "(YYYYMMDD-HHMMSS,YYYYMMDD-HHMMSS)".""", type=str) + + cml_args = arg_parser.parse_args() + # Find the log directory + log_dir = cml_args.logdir + if log_dir is None: + log_dir = os.path.join(cml_args.dir_path, 'logs') + if not os.path.isdir(log_dir): + os.makedirs(log_dir) + log.setLevel(logging.DEBUG) + + # Init the log formatter + log_formatter = logging.Formatter( + fmt='%(asctime)s-%(levelname)-5s-%(module)-15s:%(lineno)-5d- %(message)s', + datefmt='%Y/%m/%d %H:%M:%S') + + # Init the file handler + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = os.path.join(log_dir, f"correlate_db_{timestamp}.log") + file_handler = logging.handlers.TimedRotatingFileHandler(log_file, when="midnight", backupCount=7) + file_handler.setFormatter(log_formatter) + log.addHandler(file_handler) + + # Init the console handler (i.e. 
print to console) + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_formatter) + log.addHandler(console_handler) + + if cml_args.database: + dbname = cml_args.database.lower() + action = cml_args.action.lower() + + stmt = cml_args.stmt + + dt_range = None + if cml_args.timerange is not None: + time_beg, time_end = cml_args.timerange.strip("(").strip(")").split(",") + dt_beg = datetime.datetime.strptime(time_beg, "%Y%m%d-%H%M%S").replace(tzinfo=datetime.timezone.utc) + dt_end = datetime.datetime.strptime(time_end, "%Y%m%d-%H%M%S").replace(tzinfo=datetime.timezone.utc) + log.info("Custom time range:") + log.info(" BEG: {:s}".format(str(dt_beg))) + log.info(" END: {:s}".format(str(dt_end))) + dt_range = [dt_beg, dt_end] + + + if action == 'copy': + if dt_range is None: + log.info('Date range must be provided for copy operation') + else: + dt_range_jd = [datetime2JD(dt_range[0]),datetime2JD(dt_range[1])] + jsondb = dummyDatabaseJSON(db_dir=cml_args.dir_path) + obsdb = ObservationsDatabase(cml_args.dir_path) + obsdb.copyObsJsonRecords(jsondb.paired_obs, dt_range) + obsdb.closeObsDatabase() + trajdb = TrajectoryDatabase(cml_args.dir_path) + trajdb.copyTrajJsonRecords(jsondb.failed_trajectories, dt_range, failed=True) + trajdb.closeTrajDatabase() + else: + if dbname == 'observations': + obsdb = ObservationsDatabase(cml_args.dir_path) + if action == 'status': + cur = obsdb.dbhandle.execute('select * from paired_obs where status=1') + print(f'there are {len(cur.fetchall())} paired obs') + cur = obsdb.dbhandle.execute('select * from paired_obs where status=0') + print(f'and {len(cur.fetchall())} unpaired obs') + if action == 'execute': + print(stmt) + cur = obsdb.dbhandle.execute(stmt) + for rw in cur.fetchall(): + print(rw) + obsdb.closeObsDatabase() + + elif dbname == 'trajectories': + trajdb = TrajectoryDatabase(cml_args.dir_path) + if action == 'status': + cur = trajdb.dbhandle.execute('select * from trajectories where status=1') + 
def getMcModeStr(mcmode, strtype=0):
    """ Translate a numeric mcmode value into a name.

    Arguments:
        mcmode: [int] the mode value to look up.

    Keyword arguments:
        strtype: [int] 0 (default) selects the long upper-case names, any other value
            selects the short lower-case names.

    Return:
        [str] the name for the mode. An unrecognised mode gives 'MIXED' for the long
        names and False for the short names.
    """

    if strtype == 0:
        long_names = {4:'CANDIDATE STAGE', 1:'SIMPLE STAGE', 2:'MONTE CARLO STAGE', 7:'FULL', 0:'FULL'}
        return long_names.get(mcmode, 'MIXED')

    short_names = {4:'cands', 1:'simple', 2:'mcphase', 5:'candsimple', 3:'simplemc', 7:'full', 0:'full'}
    return short_names.get(mcmode, False)
data_in_j2000=Tru # enable OS style ground maps if true self.enableOSM = enableOSM + self.candidatemode = None + def trajectoryRangeCheck(self, traj_reduced, platepar): """ Check that the trajectory is within the range limits. @@ -601,7 +626,7 @@ def initTrajectory(self, jdt_ref, mc_runs, verbose=False): return traj - def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=None): + def solveTrajectory(self, traj, mc_runs, mcmode=MCMODE_ALL, matched_obs=None, orig_traj=None, verbose=False): """ Given an initialized Trajectory object with observation, run the solver and automatically reject bad observations. @@ -630,9 +655,10 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N # make a note of how many observations are already marked ignored. initial_ignore_count = len([obs for obs in traj.observations if obs.ignore_station]) log.info(f'initially ignoring {initial_ignore_count} stations...') + successful_traj_fit = False - # run the first phase of the solver if mcmode is 0 or 1 - if mcmode < 2: + # run the first phase of the solver if mcmode is MCMODE_PHASE1 + if mcmode & MCMODE_PHASE1: # Disable Monte Carlo runs until an initial stable set of observations is found traj.monte_carlo = False @@ -644,6 +670,7 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N except ValueError as e: log.info("Error during trajectory estimation!") print(e) + # TODO do we need to add the trajectory to the failed traj database here? 
return False @@ -707,7 +734,8 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N max_rejections_possible = int(np.ceil(0.5*len(traj_status.observations))) + initial_ignore_count log.info(f'max stations allowed to be rejected is {max_rejections_possible}') for i, obs in enumerate(traj_status.observations): - + if obs.ignore_station: + continue # Compute the median angular uncertainty of all other non-ignored stations ang_res_list = [obstmp.ang_res_std for j, obstmp in enumerate(traj_status.observations) if (i != j) and not obstmp.ignore_station] @@ -718,10 +746,6 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N ang_res_median = np.median(ang_res_list) - # ### DEBUG PRINT - # print(obs.station_id, 'ang res:', np.degrees(obs.ang_res_std)*3600, \ - # np.degrees(ang_res_median)*3600) - # Check if the current observations is larger than the minimum limit, and # outside the median limit or larger than the maximum limit if (obs.ang_res_std > np.radians(self.traj_constraints.min_arcsec_err/3600)) \ @@ -795,19 +819,26 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N # Init a new trajectory object (make sure to use the new reference Julian date) - traj = self.initTrajectory(traj_status.jdt_ref, mc_runs, verbose=False) + traj = self.initTrajectory(traj_status.jdt_ref, mc_runs, verbose=verbose) # Disable Monte Carlo runs until an initial stable set of observations is found traj.monte_carlo = False - # Reinitialize the observations, rejecting the ignored stations + # Reinitialize the observations. 
Note we *include* the ignored obs as they're internally marked ignored + # and so will be skipped, but to avoid confusion in the logs we only print the names of the non-ignored ones for obs in traj_status.observations: + traj.infillWithObs(obs) if not obs.ignore_station: log.info(f'Adding {obs.station_id}') - traj.infillWithObs(obs) log.info("") - log.info(f'Rerunning the trajectory solution with {len(traj.observations)} stations...') + active_stns = len([obs for obs in traj.observations if not obs.ignore_station]) + if active_stns < 2: + log.info(f"Only {active_stns} stations left - trajectory estimation failed!") + skip_trajectory = True + break + + log.info(f'Rerunning the trajectory solution with {active_stns} stations...') # Re-run the trajectory solution try: traj_status = traj.run() @@ -816,7 +847,8 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N except ValueError as e: log.info("Error during trajectory estimation!") print(e) - return False + skip_trajectory = True + break # If the trajectory estimation failed, skip this trajectory @@ -835,33 +867,23 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N # Skip the trajectory if no good solution was found if skip_trajectory: - # Add the trajectory to the list of failed trajectories - self.dh.addTrajectory(traj, failed_jdt_ref=jdt_ref) - log.info("Trajectory skipped and added to fails!") + ref_dt = jd2Date(min([met_obs.jdt_ref for met_obs in traj.observations]), dt_obj=True) + log.info(f"Trajectory at {ref_dt.isoformat()} skipped and added to fails!") + self.dh.addTrajectory(traj, failed_jdt_ref=jdt_ref, verbose=verbose) return False - # # If the trajectory solutions was not done at any point, skip the trajectory completely - # if traj_best is None: - # return False - - # # Otherwise, use the best trajectory solution until the solving failed - # else: - # log.info("Using previously estimated best trajectory...") - # traj_status = traj_best - - # 
If there are only two stations, make sure to reject solutions which have stations with # residuals higher than the maximum limit if len(traj_status.observations) == 2: if np.any([(obstmp.ang_res_std > np.radians(self.traj_constraints.max_arcsec_err/3600)) for obstmp in traj_status.observations]): + ref_dt = jd2Date(min([met_obs.jdt_ref for met_obs in traj.observations]), dt_obj=True) log.info("2 station only solution, one station has an error above the maximum limit, skipping!") - # Add the trajectory to the list of failed trajectories - self.dh.addTrajectory(traj_status, failed_jdt_ref=jdt_ref) - + log.info(f"Trajectory at {ref_dt.isoformat()} skipped and added to fails!") + self.dh.addTrajectory(traj_status, failed_jdt_ref=jdt_ref, verbose=verbose) return False @@ -869,7 +891,7 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N traj = traj_status # if we're only doing the simple solution, then print the results - if mcmode == 1: + if mcmode == MCMODE_PHASE1: # Only proceed if the orbit could be computed if traj.orbit.ra_g is not None: # Update trajectory file name @@ -885,18 +907,16 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N else: shower_code = shower_obj.IAU_code log.info("Shower: {:s}".format(shower_code)) + + if mcmode & MCMODE_PHASE1: successful_traj_fit = True log.info('finished initial solution') ##### end of simple soln phase ##### now run the Monte-carlo phase, if the mcmode is 0 (do both) or 2 (mc-only) - if mcmode == 0 or mcmode == 2: - if mcmode == 2: - traj_status = traj + if mcmode & MCMODE_PHASE2: + traj_status = traj - # save the traj in case we need to clean it up - save_traj = traj - # Only proceed if the orbit could be computed if traj.orbit.ra_g is not None: @@ -905,7 +925,7 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N log.info("Stable set of observations found, computing uncertainties using Monte Carlo...") # Init a new trajectory 
object (make sure to use the new reference Julian date) - traj = self.initTrajectory(traj_status.jdt_ref, mc_runs, verbose=False) + traj = self.initTrajectory(traj_status.jdt_ref, mc_runs, verbose=verbose) # Enable Monte Carlo traj.monte_carlo = True @@ -918,7 +938,7 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N # Don't do this in mc-only mode since phase1 has already selected the stations and we could # create duplicate orbits if we now exclude some stations from the solution # TODO should we do this here *at all* ? - if len(non_ignored_observations) > self.traj_constraints.max_stations and mcmode != 2: + if len(non_ignored_observations) > self.traj_constraints.max_stations and mcmode != MCMODE_PHASE2: # Sort the observations by residuals (smallest first) # TODO: implement better sorting algorithm @@ -951,7 +971,6 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N except ValueError as e: log.info("Error during trajectory estimation!") print(e) - self.dh.cleanupPhase2TempPickle(save_traj) return False @@ -959,10 +978,10 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N if traj_status is None: # Add the trajectory to the list of failed trajectories - if mcmode != 2: - self.dh.addTrajectory(traj, failed_jdt_ref=jdt_ref) - log.info('Trajectory failed to solve') - self.dh.cleanupPhase2TempPickle(save_traj) + if mcmode != MCMODE_PHASE2: + self.dh.addTrajectory(traj, failed_jdt_ref=jdt_ref, verbose=verbose) + ref_dt = jd2Date(min([met_obs.jdt_ref for met_obs in traj.observations]), dt_obj=True) + log.info(f"Trajectory at {ref_dt.isoformat()} skipped and added to fails!") return False @@ -975,7 +994,6 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N log.info("Average velocity outside range: {:.1f} < {:.1f} < {:.1f} km/s, skipping...".format(self.traj_constraints.v_avg_min, traj.orbit.v_avg/1000, self.traj_constraints.v_avg_max)) - 
self.dh.cleanupPhase2TempPickle(save_traj) return False @@ -983,14 +1001,12 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N for obs in traj.observations: if (obs.rbeg_ele is None) and (not obs.ignore_station): log.info("Heights from observations failed to be estimated!") - self.dh.cleanupPhase2TempPickle(save_traj) return False # Check that the orbit could be computed if traj.orbit.ra_g is None: log.info("The orbit could not be computed!") - self.dh.cleanupPhase2TempPickle(save_traj) return False # Set the trajectory fit as successful @@ -1015,7 +1031,6 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N else: log.info("The orbit could not be computed!") - self.dh.cleanupPhase2TempPickle(save_traj) return False @@ -1023,77 +1038,193 @@ def solveTrajectory(self, traj, mc_runs, mcmode=0, matched_obs=None, orig_traj=N # Save the trajectory if successful. if successful_traj_fit: # restore the original traj_id so that the phase1 and phase 2 results use the same ID - if mcmode == 2: + if mcmode == MCMODE_PHASE2: traj.traj_id = saved_traj_id traj.phase_1_only = False - if mcmode == 1: + if mcmode == MCMODE_PHASE1: traj.phase_1_only = True if orig_traj: log.info(f"Removing the previous solution {os.path.dirname(orig_traj.traj_file_path)} ...") - self.dh.removeTrajectory(orig_traj) + remove_phase1 = True if abs(round((traj.jdt_ref-orig_traj.jdt_ref)*86400000,0)) > 0 else False + self.dh.removeTrajectory(orig_traj, remove_phase1=remove_phase1) traj.pre_mc_longname = os.path.split(self.dh.generateTrajOutputDirectoryPath(orig_traj, make_dirs=False))[-1] log.info('Saving trajectory....') self.dh.saveTrajectoryResults(traj, self.traj_constraints.save_plots) - if mcmode != 2: - # we do not need to update the database for phase2 - log.info('Updating database....') - self.dh.addTrajectory(traj) - # Mark observations as paired in a trajectory if fit successful - if mcmode != 2 and matched_obs is not None: - for _, 
met_obs_temp, _ in matched_obs: - self.dh.markObservationAsPaired(met_obs_temp) + # we do not need to update the database for phase2 + if mcmode != MCMODE_PHASE2: + log.info('Updating database....') + self.dh.addTrajectory(traj, verbose=verbose) + if matched_obs is not None: + self.dh.addPairedObs(matched_obs, traj.jdt_ref, verbose=verbose) else: log.info('unable to fit trajectory') return successful_traj_fit + def mergeBrokenCandidates(self, candidate_trajectories): + ### Merge all candidate trajectories which share the same observations ### + log.info("") + log.info("---------------------------") + log.info("3) MERGING BROKEN OBSERVATIONS") + log.info("---------------------------") + log.info(f"Initially {len(candidate_trajectories)} candidates") + merged_candidate_trajectories = [] + merged_indices = [] + total_obs_used = 0 + for i, traj_cand_ref in enumerate(candidate_trajectories): + + # Skip candidate trajectories that have already been merged + if i in merged_indices: + continue + + # Stop the search if the end has been reached + if (i + 1) == len(candidate_trajectories): + merged_candidate_trajectories.append(traj_cand_ref) + total_obs_used += len(traj_cand_ref) + break - def run(self, event_time_range=None, bin_time_range=None, mcmode=0): - """ Run meteor corellation using available data. - Keyword arguments: - event_time_range: [list] A list of two datetime objects. These are times between which - events should be used. None by default, which uses all available events. 
- mcmode: [int] flag to indicate whether or not to run monte-carlos - """ + # Get the mean time of the reference observation + ref_mean_dt = traj_cand_ref[0][1].mean_dt - # a bit of logging to let readers know what we're doing - if mcmode == 2: - mcmodestr = ' - MONTE CARLO STAGE' - elif mcmode == 1: - mcmodestr = ' - SIMPLE STAGE' - else: - mcmodestr = ' ' + obs_list_ref = [entry[1] for entry in traj_cand_ref] + merged_candidate = [] + + # Compute the mean radiant of the reference solution + plane_radiants_ref = [entry[2].radiant_eq for entry in traj_cand_ref] + ra_mean_ref = meanAngle([ra for ra, _ in plane_radiants_ref]) + dec_mean_ref = np.mean([dec for _, dec in plane_radiants_ref]) - if mcmode != 2: - # Get unpaired observations, filter out observations with too little points and sort them by time - unpaired_observations_all = self.dh.getUnpairedObservations() - unpaired_observations_all = [mettmp for mettmp in unpaired_observations_all - if len(mettmp.data) >= self.traj_constraints.min_meas_pts] - unpaired_observations_all = sorted(unpaired_observations_all, key=lambda x: x.reference_dt) - # Remove all observations done prior to 2000, to weed out those with bad time - unpaired_observations_all = [met_obs for met_obs in unpaired_observations_all - if met_obs.reference_dt > datetime.datetime(2000, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)] + # Check for pairs + found_first_pair = False + for j, traj_cand_test in enumerate(candidate_trajectories[(i + 1):]): + # Skip same observations + if traj_cand_ref[0] == traj_cand_test[0]: + continue - # Normalize all reference times and time data so that the reference time is at t = 0 s - for met_obs in unpaired_observations_all: - # Correct the reference time - t_zero = met_obs.data[0].time_rel - met_obs.reference_dt = met_obs.reference_dt + datetime.timedelta(seconds=t_zero) + # Get the mean time of the test observation + test_mean_dt = traj_cand_test[0][1].mean_dt - # Normalize all observation times so that the first 
time is t = 0 s - for i in range(len(met_obs.data)): - met_obs.data[i].time_rel -= t_zero + # Make sure the observations that are being compared are within the time window + time_diff = (test_mean_dt - ref_mean_dt).total_seconds() + if abs(time_diff) > self.traj_constraints.max_toffset: + continue + + + # Break the search if the time went beyond the search. This can be done as observations + # are ordered in time + if time_diff > self.traj_constraints.max_toffset: + break + + + + # Create a list of observations + obs_list_test = [entry[1] for entry in traj_cand_test] + + # Check if there any any common observations between candidate trajectories and merge them + # if that is the case + found_match = False + test_ids = [x.id for x in obs_list_test] + for obs1 in obs_list_ref: + if obs1.id in test_ids: + found_match = True + break + # Compute the mean radiant of the reference solution + plane_radiants_test = [entry[2].radiant_eq for entry in traj_cand_test] + ra_mean_test = meanAngle([ra for ra, _ in plane_radiants_test]) + dec_mean_test = np.mean([dec for _, dec in plane_radiants_test]) + + # Skip the merging attempt if the estimated radiants are too far off + if np.degrees(angleBetweenSphericalCoords(dec_mean_ref, ra_mean_ref, dec_mean_test, ra_mean_test)) > self.traj_constraints.max_merge_radiant_angle: + continue + + + # Add the candidate trajectory to the common list if a match has been found + if found_match: + + ref_stations = [obs.station_code for obs in obs_list_ref] + + # Add observations that weren't present in the reference candidate + for entry in traj_cand_test: + + # Make sure the added observation is not already added + if entry[1] not in obs_list_ref: + + # Print the reference and the merged radiants + if not found_first_pair: + log.info("") + log.info("------") + log.info("Reference time: {:s}".format(str(ref_mean_dt))) + log.info("Reference stations: {:s}".format(", ".join(sorted(ref_stations)))) + log.info("Reference radiant: RA = {:.2f}, Dec = 
{:.2f}".format(np.degrees(ra_mean_ref), np.degrees(dec_mean_ref))) + log.info("") + found_first_pair = True + + log.info("Merging: {:s} {:s}".format(str(entry[1].mean_dt), str(entry[1].station_code))) + traj_cand_ref.append(entry) + + log.info("Merged radiant: RA = {:.2f}, Dec = {:.2f}".format(np.degrees(ra_mean_test), np.degrees(dec_mean_test))) + + # Mark that the current index has been processed + merged_indices.append(i + j + 1) + + # Add the reference candidate observations to the list + merged_candidate += traj_cand_ref + total_obs_used += len(traj_cand_ref) + + # Add the merged observation to the final list + merged_candidate_trajectories.append(merged_candidate) + + log.info(f"After merging, there are {len(merged_candidate_trajectories)} candidates") + return merged_candidate_trajectories, total_obs_used + + + def run(self, event_time_range=None, bin_time_range=None, mcmode=MCMODE_ALL, verbose=False): + """ Run meteor corellation using available data. + + Keyword arguments: + event_time_range: [list] A list of two datetime objects. These are times between which + events should be used. None by default, which uses all available events. 
+ mcmode: [int] flag to indicate whether or not to run monte-carlos + """ + + # a bit of logging to let readers know what we're doing + mcmodestr = getMcModeStr(mcmode, strtype=1) + + if mcmode != MCMODE_PHASE2: + if mcmode & MCMODE_CANDS: + # Get unpaired observations, filter out observations with too few points and sort them by time + unpaired_observations_all = self.dh.getUnpairedObservations() + unpaired_observations_all = [mettmp for mettmp in unpaired_observations_all + if len(mettmp.data) >= self.traj_constraints.min_meas_pts] + unpaired_observations_all = sorted(unpaired_observations_all, key=lambda x: x.reference_dt) + + # Remove all observations done prior to 2000, to weed out those with bad time + unpaired_observations_all = [met_obs for met_obs in unpaired_observations_all + if met_obs.reference_dt > datetime.datetime(2000, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)] + + # Normalize all reference times and time data so that the reference time is at t = 0 s + for met_obs in unpaired_observations_all: + + # Correct the reference time + t_zero = met_obs.data[0].time_rel + met_obs.reference_dt = met_obs.reference_dt + datetime.timedelta(seconds=t_zero) + + # Normalize all observation times so that the first time is t = 0 s + for i in range(len(met_obs.data)): + met_obs.data[i].time_rel -= t_zero + else: + event_time_range = self.dh.dt_range # If the time range was given, only use the events in that time range if event_time_range: @@ -1104,11 +1235,17 @@ def run(self, event_time_range=None, bin_time_range=None, mcmode=0): # Data will be divided into time bins, so the pairing function doesn't have to go pair many # observations at once and keep all pairs in memory else: - dt_beg = unpaired_observations_all[0].reference_dt - dt_end = unpaired_observations_all[-1].reference_dt + if mcmode & MCMODE_CANDS: + dt_beg = unpaired_observations_all[0].reference_dt + dt_end = unpaired_observations_all[-1].reference_dt + bin_days = 0.25 + else: + dt_beg, dt_end = 
self.dh.dt_range + bin_days = 1 + dt_bin_list = generateDatetimeBins( dt_beg, dt_end, - bin_days=1, utc_hour_break=12, tzinfo=datetime.timezone.utc, reverse=False + bin_days=bin_days, utc_hour_break=12, tzinfo=datetime.timezone.utc, reverse=False ) else: @@ -1126,6 +1263,7 @@ def run(self, event_time_range=None, bin_time_range=None, mcmode=0): log.info("---------------------------------") log.info("") + log.info(f'mcmode is {mcmodestr}') # Go though all time bins and split the list of observations for bin_beg, bin_end in dt_bin_list: @@ -1133,426 +1271,381 @@ def run(self, event_time_range=None, bin_time_range=None, mcmode=0): traj_solved_count = 0 # if we're in MC mode 0 or 1 we have to find the candidate trajectories - if mcmode < 2: - log.info("") - log.info("-----------------------------------") - log.info(" PAIRING TRAJECTORIES IN TIME BIN:") - log.info(" BIN BEG: {:s} UTC".format(str(bin_beg))) - log.info(" BIN END: {:s} UTC".format(str(bin_end))) - log.info("-----------------------------------") - log.info("") - - - # Select observations in the given time bin - unpaired_observations = [met_obs for met_obs in unpaired_observations_all - if (met_obs.reference_dt >= bin_beg) and (met_obs.reference_dt <= bin_end)] - - log.info(f'Analysing {len(unpaired_observations)} observations...') - - ### CHECK FOR PAIRING WITH PREVIOUSLY ESTIMATED TRAJECTORIES ### - - log.info("") - log.info("--------------------------------------------------------------------------") - log.info(" 1) CHECKING IF PREVIOUSLY ESTIMATED TRAJECTORIES HAVE NEW OBSERVATIONS") - log.info("--------------------------------------------------------------------------") - log.info("") - - # Get a list of all already computed trajectories within the given time bin - # Reducted trajectory objects are returned - - if bin_time_range: - # restrict checks to the bin range supplied to run() plus a day to allow for data upload times - log.info(f'Getting computed trajectories for bin {str(bin_time_range[0])} to 
{str(bin_time_range[1])}') - computed_traj_list = self.dh.getComputedTrajectories(datetime2JD(bin_time_range[0]), datetime2JD(bin_time_range[1])+1) - else: - # use the current bin. - log.info(f'Getting computed trajectories for {str(bin_beg)} to {str(bin_end)}') - computed_traj_list = self.dh.getComputedTrajectories(datetime2JD(bin_beg), datetime2JD(bin_end)) - - # Find all unpaired observations that match already existing trajectories - for traj_reduced in computed_traj_list: - - # If the trajectory already has more than the maximum number of stations, skip it - if len(traj_reduced.participating_stations) >= self.traj_constraints.max_stations: - - log.info( - "Trajectory {:s} has already reached the maximum number of stations, " - "skipping...".format( - str(jd2Date(traj_reduced.jdt_ref, dt_obj=True, tzinfo=datetime.timezone.utc)))) - - # TODO DECIDE WHETHER WE ACTUALLY WANT TO DO THIS - # the problem is that we could end up with unpaired observations that form a new trajectory instead of - # being added to an existing one - continue - - # Get all unprocessed observations which are close in time to the reference trajectory - traj_time_pairs = self.dh.getTrajTimePairs(traj_reduced, unpaired_observations, - self.traj_constraints.max_toffset) - - # Skip trajectory if there are no new obervations - if not traj_time_pairs: - continue - - + if mcmode != MCMODE_PHASE2: + ## we are in candidatemode mode 0 or 1 and want to find candidates + if mcmode & MCMODE_CANDS: + log.info("") + log.info("-----------------------------------") + log.info("0) PAIRING TRAJECTORIES IN TIME BIN:") + log.info(" BIN BEG: {:s} UTC".format(str(bin_beg))) + log.info(" BIN END: {:s} UTC".format(str(bin_end))) + log.info("-----------------------------------") log.info("") - log.info("Checking trajectory at {:s} in countries: {:s}".format( - str(jd2Date(traj_reduced.jdt_ref, dt_obj=True, tzinfo=datetime.timezone.utc)), - ", ".join(list(set([stat_id[:2] for stat_id in 
traj_reduced.participating_stations]))))) - log.info("--------") - - - # Filter out bad matches and only keep the good ones - candidate_observations = [] - traj_full = None - skip_traj_check = False - for met_obs in traj_time_pairs: - - log.info("Candidate observation: {:s}".format(met_obs.station_code)) - - platepar = self.dh.getPlatepar(met_obs) - - # Check that the trajectory beginning and end are within the distance limit - if not self.trajectoryRangeCheck(traj_reduced, platepar): - continue - - - # Check that the trajectory is within the field of view - if not self.trajectoryInFOV(traj_reduced, platepar): - continue - - - # Load the full trajectory object - if traj_full is None: - traj_full = self.dh.loadFullTraj(traj_reduced) - - # If the full trajectory couldn't be loaded, skip checking this trajectory - if traj_full is None: - - skip_traj_check = True - break - - - ### Do a rough trajectory solution and perform a quick quality control ### - - # Init observation object using the new meteor observation - obs_new = self.initObservationsObject(met_obs, platepar, - ref_dt=jd2Date(traj_reduced.jdt_ref, dt_obj=True, tzinfo=datetime.timezone.utc)) - - - # Get an observation from the trajectory object with the maximum convergence angle to - # the reference observations - obs_traj_best = None - qc_max = 0.0 - for obs_tmp in traj_full.observations: - - # Compute the plane intersection between the new and one of trajectory observations - pi = PlaneIntersection(obs_new, obs_tmp) - - # Take the observation with the maximum convergence angle - if (obs_traj_best is None) or (pi.conv_angle > qc_max): - qc_max = pi.conv_angle - obs_traj_best = obs_tmp - - - # Do a quick trajectory solution and perform sanity checks - plane_intersection = self.quickTrajectorySolution(obs_traj_best, obs_new) - if plane_intersection is None: - continue - - ### ### - - candidate_observations.append([obs_new, met_obs]) - - - # Skip the candidate trajectory if it couldn't be loaded from disk - if 
skip_traj_check: - continue - # If there are any good new observations, add them to the trajectory and re-run the solution - if candidate_observations: + # Select observations in the given time bin + unpaired_observations = [met_obs for met_obs in unpaired_observations_all + if (met_obs.reference_dt >= bin_beg) and (met_obs.reference_dt <= bin_end)] - log.info("Recomputing trajectory with new observations from stations:") + total_unpaired = len(unpaired_observations) + log.info(f'Analysing {total_unpaired} observations in this bucket...') + num_obs_paired = 0 - # Add new observations to the trajectory object - for obs_new, _ in candidate_observations: - log.info(obs_new.station_id) - traj_full.infillWithObs(obs_new) + # List of all candidate trajectories + candidate_trajectories = [] + ### CHECK FOR PAIRING WITH PREVIOUSLY ESTIMATED TRAJECTORIES ### + if total_unpaired > 0: + log.info("") + log.info("--------------------------------------------------------------------------") + log.info(" 1) CHECKING IF PREVIOUSLY ESTIMATED TRAJECTORIES HAVE NEW OBSERVATIONS") + log.info("--------------------------------------------------------------------------") + log.info("") - # Re-run the trajectory fit - # pass in orig_traj here so that it can be deleted from disk if the new solution succeeds - successful_traj_fit = self.solveTrajectory(traj_full, traj_full.mc_runs, mcmode=mcmode, orig_traj=traj_reduced) + # Get a list of all already computed trajectories within the given time bin + # Reducted trajectory objects are returned - # If the new trajectory solution succeeded, remove the now-paired observations - if successful_traj_fit: - - log.info("Remove paired observations from the processing list...") - for _, met_obs_temp in candidate_observations: - self.dh.markObservationAsPaired(met_obs_temp) - unpaired_observations.remove(met_obs_temp) - + if bin_time_range: + # restrict checks to the bin range supplied to run() plus a day to allow for data upload times + 
log.info(f'Getting computed trajectories for bin {str(bin_time_range[0])} to {str(bin_time_range[1])}') + computed_traj_list = self.dh.getComputedTrajectories(datetime2JD(bin_time_range[0]), datetime2JD(bin_time_range[1])+1) else: - log.info("New trajectory solution failed, keeping the old trajectory...") + # use the current bin. + log.info(f'Getting computed trajectories for {str(bin_beg)} to {str(bin_end)}') + computed_traj_list = self.dh.getComputedTrajectories(datetime2JD(bin_beg), datetime2JD(bin_end)) - ### ### + # Find all unpaired observations that match already existing trajectories + for traj_reduced in computed_traj_list: + # If the trajectory already has more than the maximum number of stations, skip it + if len(traj_reduced.participating_stations) >= self.traj_constraints.max_stations: - log.info("") - log.info("-------------------------------------------------") - log.info(" 2) PAIRING OBSERVATIONS INTO NEW TRAJECTORIES") - log.info("-------------------------------------------------") - log.info("") + log.info( + "Trajectory {:s} has already reached the maximum number of stations, " + "skipping...".format( + str(jd2Date(traj_reduced.jdt_ref, dt_obj=True, tzinfo=datetime.timezone.utc)))) - # List of all candidate trajectories - candidate_trajectories = [] + # TODO DECIDE WHETHER WE ACTUALLY WANT TO DO THIS + # the problem is that we could end up with unpaired observations that form a new trajectory instead of + # being added to an existing one + continue + + # Get all unprocessed observations which are close in time to the reference trajectory + traj_time_pairs = self.dh.getTrajTimePairs(traj_reduced, unpaired_observations, + self.traj_constraints.max_toffset) - # Go through all unpaired and unprocessed meteor observations - for met_obs in unpaired_observations: + # Skip trajectory if there are no new obervations + if not traj_time_pairs: + continue - # Skip observations that were processed in the meantime - if met_obs.processed: - continue - # Get 
station platepar - reference_platepar = self.dh.getPlatepar(met_obs) - obs1 = self.initObservationsObject(met_obs, reference_platepar) + log.info("") + log.info("Checking trajectory at {:s} in countries: {:s}".format( + str(jd2Date(traj_reduced.jdt_ref, dt_obj=True, tzinfo=datetime.timezone.utc)), + ", ".join(list(set([stat_id[:2] for stat_id in traj_reduced.participating_stations]))))) + log.info("--------") - # Keep a list of observations which matched the reference observation - matched_observations = [] + # Filter out bad matches and only keep the good ones + candidate_observations = [] + traj_full = None + skip_traj_check = False + for met_obs in traj_time_pairs: - # Find all meteors from other stations that are close in time to this meteor - plane_intersection_good = None - time_pairs = self.dh.findTimePairs(met_obs, unpaired_observations, - self.traj_constraints.max_toffset) - for met_pair_candidate in time_pairs: + log.info("Candidate observation: {:s}".format(met_obs.station_code)) - log.info("") - log.info("Processing pair:") - log.info("{:s} and {:s}".format(met_obs.station_code, met_pair_candidate.station_code)) - log.info("{:s} and {:s}".format(str(met_obs.reference_dt), str(met_pair_candidate.reference_dt))) - log.info("-----------------------") + platepar = self.dh.getPlatepar(met_obs) - ### Check if the stations are close enough and have roughly overlapping fields of view ### + # Check that the trajectory beginning and end are within the distance limit + if not self.trajectoryRangeCheck(traj_reduced, platepar): + continue - # Get candidate station platepar - candidate_platepar = self.dh.getPlatepar(met_pair_candidate) - # Check if the stations are within range - if not self.stationRangeCheck(reference_platepar, candidate_platepar): - continue + # Check that the trajectory is within the field of view + if not self.trajectoryInFOV(traj_reduced, platepar): + continue - # Check the FOV overlap - if not self.checkFOVOverlap(reference_platepar, 
candidate_platepar): - log.info("Station FOV does not overlap: {:s} and {:s}".format(met_obs.station_code, - met_pair_candidate.station_code)) - continue - ### ### + # Load the full trajectory object + if traj_full is None: + traj_full = self.dh.loadFullTraj(traj_reduced) + # If the full trajectory couldn't be loaded, skip checking this trajectory + if traj_full is None: + + skip_traj_check = True + break - ### Do a rough trajectory solution and perform a quick quality control ### + ### Do a rough trajectory solution and perform a quick quality control ### - # Init observations - obs2 = self.initObservationsObject(met_pair_candidate, candidate_platepar, - ref_dt=met_obs.reference_dt) + # Init observation object using the new meteor observation + obs_new = self.initObservationsObject(met_obs, platepar, + ref_dt=jd2Date(traj_reduced.jdt_ref, dt_obj=True, tzinfo=datetime.timezone.utc)) + obs_new.id = met_obs.id + obs_new.station_code = met_obs.station_code + obs_new.mean_dt = met_obs.mean_dt - # Do a quick trajectory solution and perform sanity checks - plane_intersection = self.quickTrajectorySolution(obs1, obs2) - if plane_intersection is None: - continue - - else: - plane_intersection_good = plane_intersection + # Get an observation from the trajectory object with the maximum convergence angle to + # the reference observations + obs_traj_best = None + qc_max = 0.0 + for obs_tmp in traj_full.observations: + + # Compute the plane intersection between the new and one of trajectory observations + pi = PlaneIntersection(obs_new, obs_tmp) - ### ### + # Take the observation with the maximum convergence angle + if (obs_traj_best is None) or (pi.conv_angle > qc_max): + qc_max = pi.conv_angle + obs_traj_best = obs_tmp - matched_observations.append([obs2, met_pair_candidate, plane_intersection]) + # Do a quick trajectory solution and perform sanity checks + plane_intersection = self.quickTrajectorySolution(obs_traj_best, obs_new) + if plane_intersection is None: + continue + 
### ### - # If there are no matched observations, skip it - if len(matched_observations) == 0: + candidate_observations.append([obs_new, met_obs]) - if len(time_pairs) > 0: - log.info("") - log.info(" --- NO MATCH ---") - continue + # Skip the candidate trajectory if it couldn't be loaded from disk + if skip_traj_check: + continue - # Skip if there are not good plane intersections - if plane_intersection_good is None: - continue - # Add the first observation to matched observations - matched_observations.append([obs1, met_obs, plane_intersection_good]) + # If there are any good new observations, add them to the trajectory and re-run the solution + if candidate_observations: + log.info("Recomputing trajectory with new observations from stations:") - # Mark observations as processed - for _, met_obs_temp, _ in matched_observations: - met_obs_temp.processed = True - self.dh.markObservationAsProcessed(met_obs_temp) + # Add new observations to the trajectory object + for obs_new, _ in candidate_observations: + log.info(obs_new.station_id) + traj_full.infillWithObs(obs_new) - # Store candidate trajectories - log.info("") - log.info(" --- ADDING CANDIDATE ---") - candidate_trajectories.append(matched_observations) + # Re-run the trajectory fit + # pass in orig_traj here so that it can be deleted from disk if the new solution succeeds + # pass the new candidates in so that they can be marked paired if the new soln succeeds + # Note: mcmode must be phase1 here to force a recompute + successful_traj_fit = self.solveTrajectory(traj_full, traj_full.mc_runs, mcmode=MCMODE_PHASE1, + matched_obs=candidate_observations, orig_traj=traj_reduced, verbose=verbose) + + # If the new trajectory solution succeeded, remove the now-paired observations from the in memory list + if successful_traj_fit: + log.info("Remove paired observations from the processing list...") + for _, met_obs_temp in candidate_observations: + unpaired_observations.remove(met_obs_temp) + else: + log.info("New 
trajectory solution failed, keeping the old trajectory...") - ### Merge all candidate trajectories which share the same observations ### - log.info("") - log.info("---------------------------") - log.info("MERGING BROKEN OBSERVATIONS") - log.info("---------------------------") - merged_candidate_trajectories = [] - merged_indices = [] - for i, traj_cand_ref in enumerate(candidate_trajectories): - - # Skip candidate trajectories that have already been merged - if i in merged_indices: - continue + ### ### - - # Stop the search if the end has been reached - if (i + 1) == len(candidate_trajectories): - merged_candidate_trajectories.append(traj_cand_ref) - break + log.info("") + log.info("-------------------------------------------------") + log.info(" 2) PAIRING OBSERVATIONS INTO NEW TRAJECTORIES") + log.info("-------------------------------------------------") + log.info("") - # Get the mean time of the reference observation - ref_mean_dt = traj_cand_ref[0][1].mean_dt - obs_list_ref = [entry[1] for entry in traj_cand_ref] - merged_candidate = [] + # Go through all unpaired and unprocessed meteor observations + for met_obs in unpaired_observations: - # Compute the mean radiant of the reference solution - plane_radiants_ref = [entry[2].radiant_eq for entry in traj_cand_ref] - ra_mean_ref = meanAngle([ra for ra, _ in plane_radiants_ref]) - dec_mean_ref = np.mean([dec for _, dec in plane_radiants_ref]) + # Skip observations that were processed in the meantime + if met_obs.processed: + continue + if self.dh.checkIfObsPaired(met_obs.id, verbose=verbose): + continue - # Check for pairs - found_first_pair = False - for j, traj_cand_test in enumerate(candidate_trajectories[(i + 1):]): + # Get station platepar + reference_platepar = self.dh.getPlatepar(met_obs) + obs1 = self.initObservationsObject(met_obs, reference_platepar) - # Skip same observations - if traj_cand_ref[0] == traj_cand_test[0]: - continue + # Keep a list of observations which matched the reference observation 
+ matched_observations = [] - # Get the mean time of the test observation - test_mean_dt = traj_cand_test[0][1].mean_dt + # Find all meteors from other stations that are close in time to this meteor + plane_intersection_good = None + time_pairs = self.dh.findTimePairs(met_obs, unpaired_observations, + self.traj_constraints.max_toffset) + for met_pair_candidate in time_pairs: - # Make sure the observations that are being compared are within the time window - time_diff = (test_mean_dt - ref_mean_dt).total_seconds() - if abs(time_diff) > self.traj_constraints.max_toffset: - continue + log.info("") + log.info("Processing pair:") + log.info("{:s} and {:s}".format(met_obs.station_code, met_pair_candidate.station_code)) + log.info("{:s} and {:s}".format(str(met_obs.reference_dt), str(met_pair_candidate.reference_dt))) + log.info("-----------------------") + ### Check if the stations are close enough and have roughly overlapping fields of view ### - # Break the search if the time went beyond the search. 
This can be done as observations - # are ordered in time - if time_diff > self.traj_constraints.max_toffset: - break + # Get candidate station platepar + candidate_platepar = self.dh.getPlatepar(met_pair_candidate) + # Check if the stations are within range + if not self.stationRangeCheck(reference_platepar, candidate_platepar): + continue + # Check the FOV overlap + if not self.checkFOVOverlap(reference_platepar, candidate_platepar): + log.info("Station FOV does not overlap: {:s} and {:s}".format(met_obs.station_code, + met_pair_candidate.station_code)) + continue - # Create a list of observations - obs_list_test = [entry[1] for entry in traj_cand_test] + ### ### - # Check if there any any common observations between candidate trajectories and merge them - # if that is the case - found_match = False - for obs1 in obs_list_ref: - if obs1 in obs_list_test: - found_match = True - break - # Compute the mean radiant of the reference solution - plane_radiants_test = [entry[2].radiant_eq for entry in traj_cand_test] - ra_mean_test = meanAngle([ra for ra, _ in plane_radiants_test]) - dec_mean_test = np.mean([dec for _, dec in plane_radiants_test]) + ### Do a rough trajectory solution and perform a quick quality control ### - # Skip the mergning attempt if the estimated radiants are too far off - if np.degrees(angleBetweenSphericalCoords(dec_mean_ref, ra_mean_ref, dec_mean_test, ra_mean_test)) > self.traj_constraints.max_merge_radiant_angle: + # Init observations + obs2 = self.initObservationsObject(met_pair_candidate, candidate_platepar, + ref_dt=met_obs.reference_dt) - continue + # Do a quick trajectory solution and perform sanity checks + plane_intersection = self.quickTrajectorySolution(obs1, obs2) + if plane_intersection is None: + continue + else: + plane_intersection_good = plane_intersection - # Add the candidate trajectory to the common list if a match has been found - if found_match: + ### ### - ref_stations = [obs.station_code for obs in obs_list_ref] + 
matched_observations.append([obs2, met_pair_candidate, plane_intersection]) - # Add observations that weren't present in the reference candidate - for entry in traj_cand_test: - # Make sure the added observation is not from a station that's already added - if entry[1].station_code in ref_stations: - continue - if entry[1] not in obs_list_ref: + # If there are no matched observations, skip it + if len(matched_observations) == 0: - # Print the reference and the merged radiants - if not found_first_pair: - log.info("") - log.info("------") - log.info("Reference time: {:s}".format(str(ref_mean_dt))) - log.info("Reference stations: {:s}".format(", ".join(sorted(ref_stations)))) - log.info("Reference radiant: RA = {:.2f}, Dec = {:.2f}".format(np.degrees(ra_mean_ref), np.degrees(dec_mean_ref))) - log.info("") - found_first_pair = True + if len(time_pairs) > 0: + log.info("") + log.info(" --- NO MATCH ---") - log.info("Merging: {:s} {:s}".format(str(entry[1].mean_dt), str(entry[1].station_code))) - traj_cand_ref.append(entry) + continue - log.info("Merged radiant: RA = {:.2f}, Dec = {:.2f}".format(np.degrees(ra_mean_test), np.degrees(dec_mean_test))) + # Skip if there are not good plane intersections + if plane_intersection_good is None: + continue - + # Add the first observation to matched observations + matched_observations.append([obs1, met_obs, plane_intersection_good]) - # Mark that the current index has been processed - merged_indices.append(i + j + 1) + # Mark observations as processed + for _, met_obs_temp, _ in matched_observations: + met_obs_temp.processed = True + # Store candidate trajectory group + # Note that this will include candidate groups that already failed on previous runs. + # We will exclude these later - we can't do it just yet as if new data has arrived, then + # in the next step, the group might be merged with another group creating a solvable set. 
+ log.info("") + ref_dt = min([met_obs.reference_dt for _, met_obs, _ in matched_observations]) + log.info(f" --- ADDING CANDIDATE at {ref_dt.isoformat()} ---") + candidate_trajectories.append(matched_observations) - # Add the reference candidate observations to the list - merged_candidate += traj_cand_ref + # Check for mergeable candidate combinations + merged_candidate_trajectories, num_obs_paired = self.mergeBrokenCandidates(candidate_trajectories) + # Now check and exclude already-processed candidates + # We can't do this earlier as we need to check mergeability first + candidate_trajectories = self.dh.checkAlreadyProcessed(merged_candidate_trajectories, verbose=verbose) - # Add the merged observation to the final list - merged_candidate_trajectories.append(merged_candidate) + log.info("-----------------------") + log.info(f'There are {total_unpaired - num_obs_paired} remaining unpaired observations in this bucket.') + log.info("-----------------------") + # in candidate mode we want to save the candidates to disk + if mcmode == MCMODE_CANDS: + log.info("-----------------------") + if bin_time_range: + log.info(f'5) SAVING {len(candidate_trajectories)} CANDIDATES for {str(bin_time_range[0])} to {str(bin_time_range[1])}') + else: + log.info(f'5) SAVING {len(candidate_trajectories)} CANDIDATES for {str(bin_beg)} to {str(bin_end)}') + log.info("-----------------------") + # Save candidates. 
This will check and skip over already-processed + # combinations + self.dh.saveCandidates(candidate_trajectories, verbose=verbose) - candidate_trajectories = merged_candidate_trajectories + return len(candidate_trajectories) + + else: + log.info("-----------------------") + log.info('5) PROCESSING {} CANDIDATES'.format(len(candidate_trajectories))) + log.info("-----------------------") + # end of 'if mcmode & MCMODE_CANDS' ### ### - else: + else: + # candidatemode is LOAD so load any available candidates for processing + traj_solved_count = 0 + candidate_trajectories = [] + log.info("-----------------------") + log.info('LOADING CANDIDATES') + log.info("-----------------------") + + save_path = self.dh.candidate_dir + procpath = os.path.join(save_path, 'processed') + os.makedirs(procpath, exist_ok=True) + # TODO use glob.glob here + for fil in os.listdir(save_path): + if '.pickle' not in fil: + continue + try: + procfile = os.path.join(procpath, fil) + if os.path.isfile(procfile): + # Skip the trajectory if we already processed it. 
+ # To force reprocessing, move the candidate from 'candidates/processed' to 'candidates' + log.info(f'Candidate {fil} already processed') + os.remove(os.path.join(save_path, fil)) + continue + loadedpickle = loadPickle(save_path, fil) + candidate_trajectories.append(loadedpickle) + # now move the loaded file so we don't try to reprocess it + os.rename(os.path.join(save_path, fil), procfile) + except Exception: + log.info(f'Candidate {fil} went away, probably picked up by another process') + log.info("-----------------------") + log.info('LOADED {} CANDIDATES'.format(len(candidate_trajectories))) + log.info("-----------------------") + # end of 'self.candidatemode == CANDMODE_LOAD' + # end of 'if mcmode != MCMODE_PHASE2' + else: + # mcmode == MCMODE_PHASE2 so we need to load the phase1 solutions log.info("-----------------------") log.info('LOADING PHASE1 SOLUTIONS') log.info("-----------------------") candidate_trajectories = self.dh.phase1Trajectories - # end of "if mcmode < 2" + # end of "if mcmode == MCMODE_PHASE2" + + # avoid reprocessing candidates that were already processed + num_traj = len(candidate_trajectories) log.info("") log.info("-----------------------") - log.info(f'SOLVING {len(candidate_trajectories)} TRAJECTORIES {mcmodestr}') + log.info(f'SOLVING {num_traj} TRAJECTORIES {mcmodestr}') log.info("-----------------------") log.info("") # Go through all candidate trajectories and compute the complete trajectory solution - for matched_observations in candidate_trajectories: + for i, matched_observations in enumerate(candidate_trajectories): log.info("") log.info("-----------------------") - + log.info(f'processing {"candidate" if mcmode==MCMODE_PHASE1 else "trajectory"} {i+1}/{num_traj}') # if mcmode is not 2, prepare to calculate the intersecting planes solutions - if mcmode != 2: + if mcmode != MCMODE_PHASE2: # Find unique station counts station_counts = np.unique([entry[1].station_code for entry in matched_observations], return_counts=True) @@ 
-1622,6 +1715,21 @@ def run(self, event_time_range=None, bin_time_range=None, mcmode=0): log.info("Max convergence angle too small: {:.1f} < {:.1f} deg".format(qc_max, self.traj_constraints.min_qc)) + # create a traj object to add to the failed database so we don't try to recompute this one again + ref_dt = min([met_obs.reference_dt for _, met_obs, _ in matched_observations]) + jdt_ref = datetime2JD(ref_dt) + + failed_traj = self.initTrajectory(jdt_ref, 0, verbose=verbose) + for obs_temp, met_obs, _ in matched_observations: + failed_traj.infillWithObs(obs_temp) + + t0 = min([obs.time_data[0] for obs in failed_traj.observations if (not obs.ignore_station) + or (not np.all(obs.ignore_list))]) + if t0 != 0.0: + failed_traj.jdt_ref = failed_traj.jdt_ref + t0/86400.0 + + log.info(f"Trajectory at {ref_dt.isoformat()} skipped and added to fails!") + self.dh.addTrajectory(failed_traj, failed_traj.jdt_ref, verbose=verbose) continue @@ -1649,17 +1757,20 @@ def run(self, event_time_range=None, bin_time_range=None, mcmode=0): # Init the solver (use the earliest date as the reference) - ref_dt = min([met_obs.reference_dt for _, met_obs, _ in matched_observations]) - jdt_ref = datetime2JD(ref_dt) - traj = self.initTrajectory(jdt_ref, mc_runs, verbose=False) + jdt_ref = min([obs_temp.jdt_ref for obs_temp, _, _ in matched_observations]) + + #log.info(f'ref_dt {jd2Date(jdt_ref, dt_obj=True)}') + traj = self.initTrajectory(jdt_ref, mc_runs, verbose=verbose) # Feed the observations into the trajectory solver for obs_temp, met_obs, _ in matched_observations: # Normalize the observations to the reference Julian date - jdt_ref_curr = datetime2JD(met_obs.reference_dt) + jdt_ref_curr = obs_temp.jdt_ref # datetime2JD(met_obs.reference_dt) obs_temp.time_data += (jdt_ref_curr - jdt_ref)*86400 + # we have normalised the time data to jdt_ref, now we need to reset jdt_ref for each obs too + obs_temp.jdt_ref = jdt_ref traj.infillWithObs(obs_temp) @@ -1671,29 +1782,30 @@ def run(self, 
event_time_range=None, bin_time_range=None, mcmode=0): # If the first time is not 0, normalize times so that the earliest time is 0 if t0 != 0.0: - + #log.info(f'adjusting by {t0}') # Offset all times by t0 for i in range(len(traj.observations)): traj.observations[i].time_data -= t0 - + # log.info(f'obs jdt_ref is {jd2Date(traj.observations[i].jdt_ref, dt_obj=True)}') # Recompute the reference JD to corresponds with t0 traj.jdt_ref = traj.jdt_ref + t0/86400.0 + #log.info(f'ref_dt {jd2Date(traj.jdt_ref, dt_obj=True)}') # If this trajectory already failed to be computed, don't try to recompute it again unless # new observations are added if self.dh.checkTrajIfFailed(traj): log.info("The same trajectory already failed to be computed in previous runs!") continue - # pass in matched_observations here so that solveTrajectory can mark them paired if they're used - result = self.solveTrajectory(traj, mc_runs, mcmode=mcmode, matched_obs=matched_observations) + # pass in matched_observations here so that we can mark them paired if they're used + result = self.solveTrajectory(traj, mc_runs, mcmode=mcmode, matched_obs=matched_observations, verbose=verbose) traj_solved_count += int(result) - # end of if mcmode != 2 + # end of if mcmode != MCMODE_PHASE2 else: - # mcmode is 2 and so we have a list of trajectories that were solved in phase 1 + # mcmode is MCMODE_PHASE2 and so we have a list of trajectories that were solved in phase 1 # to prepare for monte-carlo solutions traj = matched_observations @@ -1717,18 +1829,18 @@ def run(self, event_time_range=None, bin_time_range=None, mcmode=0): # This will increase the number of MC runs while keeping the processing time the same mc_runs = int(np.ceil(mc_runs/self.traj_constraints.mc_cores)*self.traj_constraints.mc_cores) - # pass in matched_observations here so that solveTrajectory can mark them paired if they're used - result = self.solveTrajectory(traj, mc_runs, mcmode=mcmode, matched_obs=matched_observations, orig_traj=traj) + # 
pass in matched_observations here so that we can mark them unpaired if the solver fails + result = self.solveTrajectory(traj, mc_runs, mcmode=mcmode, matched_obs=matched_observations, orig_traj=traj, verbose=verbose) traj_solved_count += int(result) # end of "for matched_observations in candidate_trajectories" outcomes = [traj_solved_count] - # Finish the correlation run (update the database with new values) - self.dh.saveDatabase() log.info(f'SOLVED {sum(outcomes)} TRAJECTORIES') log.info("") log.info("-----------------") log.info("SOLVING RUN DONE!") log.info("-----------------") + + return sum(outcomes) diff --git a/wmpl/Trajectory/CorrelateRMS.py b/wmpl/Trajectory/CorrelateRMS.py index 88c11292..2abea12c 100644 --- a/wmpl/Trajectory/CorrelateRMS.py +++ b/wmpl/Trajectory/CorrelateRMS.py @@ -12,22 +12,27 @@ import datetime import shutil import time -import signal import multiprocessing import logging import logging.handlers import glob -import pandas as pd from dateutil.relativedelta import relativedelta import numpy as np +import sys +import signal +import secrets from wmpl.Formats.CAMS import loadFTPDetectInfo -from wmpl.Trajectory.CorrelateEngine import TrajectoryCorrelator, TrajectoryConstraints +from wmpl.Trajectory.CorrelateEngine import TrajectoryCorrelator, TrajectoryConstraints, getMcModeStr from wmpl.Utils.Math import generateDatetimeBins from wmpl.Utils.OSTools import mkdirP from wmpl.Utils.Pickling import loadPickle, savePickle from wmpl.Utils.TrajConversions import datetime2JD, jd2Date -from wmpl.Utils.remoteDataHandling import collectRemoteTrajectories, moveRemoteTrajectories, uploadTrajToRemote +from wmpl.Utils.remoteDataHandling import RemoteDataHandler +from wmpl.Trajectory.CorrelateDB import ObservationsDatabase, TrajectoryDatabase +# from wmpl.Trajectory.Trajectory import Trajectory + +from wmpl.Trajectory.CorrelateEngine import MCMODE_CANDS, MCMODE_PHASE1, MCMODE_PHASE2, MCMODE_ALL, MCMODE_BOTH ### CONSTANTS ### @@ -77,6 +82,10 @@ def 
__init__(self, traj_file_path, json_dict=None, traj_obj=None): except FileNotFoundError: log.info("Pickle file not found: " + traj_file_path) return None + + except: + log.info("Pickle file could not be loaded: " + traj_file_path) + return None else: @@ -84,7 +93,6 @@ def __init__(self, traj_file_path, json_dict=None, traj_obj=None): traj = traj_obj self.traj_file_path = os.path.join(traj.output_dir, traj.file_name + "_trajectory.pickle") - # Reference Julian date (beginning of the meteor) self.jdt_ref = traj.jdt_ref @@ -149,10 +157,6 @@ def __init__(self, db_file_path, verbose=False): self.db_file_path = db_file_path - # List of processed directories (keys are station codes, values are relative paths to night - # directories) - self.processed_dirs = {} - # List of paired observations as a part of a trajectory (keys are station codes, values are unique # observation IDs) self.paired_obs = {} @@ -168,7 +172,6 @@ def __init__(self, db_file_path, verbose=False): # Load the database from a JSON file self.load(verbose=verbose) - def load(self, verbose=False): """ Load the database from a JSON file. """ @@ -202,7 +205,8 @@ def load(self, verbose=False): # Overwrite the database path with the saved one self.db_file_path = db_file_path_saved - if db_is_ok: + # if the trajectories attribute is not present, then the database has been converted to sqlite + if db_is_ok and hasattr(self, 'trajectories'): # Convert trajectories from JSON to TrajectoryReduced objects for traj_dict_str in ["trajectories", "failed_trajectories"]: traj_dict = getattr(self, traj_dict_str) @@ -219,159 +223,6 @@ def load(self, verbose=False): self.verbose = verbose - def save(self): - """ Save the database of processed meteors to disk. 
""" - - # Back up the existing data base - db_bak_file_path = self.db_file_path + ".bak" - if os.path.exists(self.db_file_path): - shutil.copy2(self.db_file_path, db_bak_file_path) - - # Save the data base - try: - with open(self.db_file_path, 'w') as f: - self2 = copy.deepcopy(self) - - # Convert reduced trajectory objects to JSON objects - self2.trajectories = {key: self.trajectories[key].__dict__ for key in self.trajectories} - self2.failed_trajectories = {key: self.failed_trajectories[key].__dict__ - for key in self.failed_trajectories} - if hasattr(self2, 'phase1Trajectories'): - delattr(self2, 'phase1Trajectories') - - f.write(json.dumps(self2, default=lambda o: o.__dict__, indent=4, sort_keys=True)) - - # Remove the backup file - if os.path.exists(db_bak_file_path): - os.remove(db_bak_file_path) - - except Exception as e: - log.warning('unable to save the database, likely corrupt data') - shutil.copy2(db_bak_file_path, self.db_file_path) - log.warning(e) - - def addProcessedDir(self, station_name, rel_proc_path): - """ Add the processed directory to the list. """ - - if station_name in self.processed_dirs: - if rel_proc_path not in self.processed_dirs[station_name]: - self.processed_dirs[station_name].append(rel_proc_path) - - - def addPairedObservation(self, met_obs): - """ Mark the given meteor observation as paired in a trajectory. """ - - if met_obs.station_code not in self.paired_obs: - self.paired_obs[met_obs.station_code] = [] - - if met_obs.id not in self.paired_obs[met_obs.station_code]: - self.paired_obs[met_obs.station_code].append(met_obs.id) - - - def checkObsIfPaired(self, met_obs): - """ Check if the given observation has been paired to a trajectory or not. 
""" - - if met_obs.station_code in self.paired_obs: - return (met_obs.id in self.paired_obs[met_obs.station_code]) - - else: - return False - - - def checkTrajIfFailed(self, traj): - """ Check if the given trajectory has been computed with the same observations and has failed to be - computed before. - - """ - - # Check if the reference time is in the list of failed trajectories - if traj.jdt_ref in self.failed_trajectories: - - # Get the failed trajectory object - failed_traj = self.failed_trajectories[traj.jdt_ref] - - # Check if the same observations participate in the failed trajectory as in the trajectory that - # is being tested - all_match = True - for obs in traj.observations: - - if not ((obs.station_id in failed_traj.participating_stations) or (obs.station_id in failed_traj.ignored_stations)): - - all_match = False - break - - # If the same stations were used, the trajectory estimation failed before - if all_match: - return True - - - return False - - - def addTrajectory(self, traj_file_path, traj_obj=None, failed=False): - """ Add a computed trajectory to the list. - - Arguments: - traj_file_path: [str] Full path the trajectory object. - - Keyword arguments: - traj_obj: [bool] Instead of loading a traj object from disk, use the given object. - failed: [bool] Add as a failed trajectory. False by default. 
- """ - - # Load the trajectory from disk - if traj_obj is None: - - # Init the reduced trajectory object - traj_reduced = TrajectoryReduced(traj_file_path) - if self.verbose: - log.info(f' loaded {traj_file_path}, traj_id {traj_reduced.traj_id}') - # Skip if failed - if traj_reduced is None: - return None - - if not hasattr(traj_reduced, "jdt_ref"): - return None - - else: - # Use the provided trajectory object - traj_reduced = traj_obj - if self.verbose: - log.info(f' loaded {traj_obj.traj_file_path}, traj_id {traj_reduced.traj_id}') - - - # Choose to which dictionary the trajectory will be added - if failed: - traj_dict = self.failed_trajectories - - else: - traj_dict = self.trajectories - - - # Add the trajectory to the list (key is the reference JD) - if traj_reduced.jdt_ref not in traj_dict: - traj_dict[traj_reduced.jdt_ref] = traj_reduced - else: - traj_dict[traj_reduced.jdt_ref].traj_id = traj_reduced.traj_id - - - - def removeTrajectory(self, traj_reduced, keepFolder=False): - """ Remove the trajectory from the data base and disk. """ - - # Remove the trajectory data base entry - if traj_reduced.jdt_ref in self.trajectories: - del self.trajectories[traj_reduced.jdt_ref] - - # Remove the trajectory folder on the disk - if not keepFolder and os.path.isfile(traj_reduced.traj_file_path): - traj_dir = os.path.dirname(traj_reduced.traj_file_path) - shutil.rmtree(traj_dir, ignore_errors=True) - if os.path.isfile(traj_reduced.traj_file_path): - log.info(f'unable to remove {traj_dir}') - - - class MeteorPointRMS(object): def __init__(self, frame, time_rel, x, y, ra, dec, azim, alt, mag): """ Container for individual meteor picks. 
""" @@ -399,7 +250,6 @@ def __init__(self, frame, time_rel, x, y, ra, dec, azim, alt, mag): self.mag = mag - class MeteorObsRMS(object): def __init__(self, station_code, reference_dt, platepar, data, rel_proc_path, ff_name=None): """ Container for meteor observations with the interface compatible with the trajectory correlator @@ -517,7 +367,7 @@ def __init__(self, **entries): class RMSDataHandle(object): - def __init__(self, dir_path, dt_range=None, db_dir=None, output_dir=None, mcmode=0, max_trajs=1000, remotehost=None, verbose=False): + def __init__(self, dir_path, dt_range=None, db_dir=None, output_dir=None, mcmode=MCMODE_ALL, max_trajs=1000, verbose=False, archivemonths=3): """ Handles data interfacing between the trajectory correlator and RMS data files on disk. Arguments: @@ -530,12 +380,16 @@ def __init__(self, dir_path, dt_range=None, db_dir=None, output_dir=None, mcmode database file will be loaded from the dir_path. output_dir: [str] Path to the directory where the output files will be saved. None by default, in which case the output files will be saved in the dir_path. + mcmode: [int] the operation mode, candidates, phase1 simple solns, mc phase or a combination max_trajs: [int] maximum number of phase1 trajectories to load at a time when adding uncertainties. Improves throughput. """ self.mc_mode = mcmode self.dir_path = dir_path + # create the data directory. Of course, if the folder doesnt exist there is nothing to process + # but by creating it we avoid an Exception later. And we can always copy data in. + mkdirP(dir_path) self.dt_range = dt_range @@ -559,15 +413,25 @@ def __init__(self, dir_path, dt_range=None, db_dir=None, output_dir=None, mcmode # Create the output directory if it doesn't exist mkdirP(self.output_dir) - # Phase 1 trajectory pickle directory needed to reload previous results. 
- self.phase1_dir = os.path.join(self.output_dir, 'phase1') + if dt_range is None or dt_range[0] == datetime.datetime(2000,1,1,0,0,0).replace(tzinfo=datetime.timezone.utc): + daysback = 14 + else: + daysback = (datetime.datetime.now().replace(tzinfo=datetime.timezone.utc) - dt_range[0]).days + 1 + + # Candidate directory, if running in create or load cands modes + self.candidate_dir = os.path.join(self.output_dir, 'candidates') + mkdirP(os.path.join(self.candidate_dir, 'processed')) + num_removed_cands = self.purgeProcessedData(os.path.join(self.candidate_dir, 'processed'), days_back=daysback, verbose=verbose) + log.info(f'removed {num_removed_cands} processed candidates') - # create the directory for phase1 simple trajectories, if needed - if self.mc_mode > 0: - mkdirP(os.path.join(self.phase1_dir, 'processed')) - self.purgePhase1ProcessedData(os.path.join(self.phase1_dir, 'processed')) + # Phase 1 trajectory pickle directory needed to reload previous results when running phase2. + self.phase1_dir = os.path.join(self.output_dir, 'phase1') + mkdirP(os.path.join(self.phase1_dir, 'processed')) + num_removed_ph1 = self.purgeProcessedData(os.path.join(self.phase1_dir, 'processed'), days_back=daysback, verbose=verbose) + log.info(f'removed {num_removed_ph1} processed phase1') - self.remotehost = remotehost + # In a previous incarnation, if the solver crashed it could leave some `.pickle_processing files`. 
+ self.cleanupPartialProcessing() self.verbose = verbose @@ -575,40 +439,65 @@ def __init__(self, dir_path, dt_range=None, db_dir=None, output_dir=None, mcmode # Load database of processed folders database_path = os.path.join(self.db_dir, JSON_DB_NAME) + + # create an empty processing list + self.processing_list = [] + log.info("") - # move any remotely calculated pickles to their target locations - if os.path.isdir(os.path.join(self.output_dir, 'remoteuploads')): - moveRemoteTrajectories(self.output_dir) - - if mcmode != 2: - log.info("Loading database: {:s}".format(database_path)) - self.db = DatabaseJSON(database_path, verbose=self.verbose) - log.info('Archiving older entries....') - try: - self.archiveOldRecords(older_than=3) - except: - pass - log.info(" ... done!") - # Load the list of stations - station_list = self.loadStations() + if mcmode != MCMODE_PHASE2: - # Find unprocessed meteor files - log.info("") - log.info("Finding unprocessed data...") - self.processing_list = self.findUnprocessedFolders(station_list) - log.info(" ... 
done!") + # no need to load the legacy JSON file if we already have the sqlite databases + if not os.path.isfile(os.path.join(db_dir, 'observations.db')) and \ + not os.path.isfile(os.path.join(db_dir, 'trajectories.db')): + log.info("Loading old JSON database: {:s}".format(database_path)) + self.old_db = DatabaseJSON(database_path, verbose=self.verbose) + else: + self.old_db = None + + self.observations_db = ObservationsDatabase(db_dir) + if hasattr(self.old_db, 'paired_obs'): + # copy any legacy paired obs data into sqlite + self.observations_db.copyObsJsonRecords(self.old_db.paired_obs, dt_range) + + self.trajectory_db = TrajectoryDatabase(db_dir) + if hasattr(self.old_db, 'failed_trajectories'): + # copy any legacy failed traj data into sqlite, so we avoid recomputing them + self.trajectory_db.copyTrajJsonRecords(self.old_db.failed_trajectories, dt_range, failed=True) + + if self.old_db: + del self.old_db + + if archivemonths != 0: + log.info('Archiving older entries....') + try: + self.archiveOldRecords(older_than=archivemonths) + except: + pass + log.info(" ... done!") + + if mcmode & MCMODE_CANDS: + # Load the list of stations + station_list = self.loadStations() + + # Find unprocessed meteor files + log.info("") + log.info("Finding unprocessed data...") + self.processing_list = self.findUnprocessedFolders(station_list) + log.info(" ... 
done!") + + # in phase 1, initialise and collect data second as we load candidates dynamically + self.initialiseRemoteDataHandling() else: - # retrieve pickles from a remote host, if configured - if self.remotehost is not None: - collectRemoteTrajectories(remotehost, max_trajs, self.phase1_dir) + # in phase 2, initialise and collect data first as we need the phase1 traj on disk already + self.trajectory_db = None + self.observations_db = None + self.initialiseRemoteDataHandling() - # reload the phase1 trajectories dt_beg, dt_end = self.loadPhase1Trajectories(max_trajs=max_trajs) self.processing_list = None self.dt_range=[dt_beg, dt_end] - self.db = None ### Define country groups to speed up the proceessing ### @@ -632,41 +521,65 @@ def __init__(self, dir_path, dt_range=None, db_dir=None, output_dir=None, mcmode ### ### + def checkRemoteDataMode(self): + remote_cfg = os.path.join(self.db_dir, 'wmpl_remote.cfg') + if os.path.isfile(remote_cfg): + self.RemoteDatahandler = RemoteDataHandler(remote_cfg) + return self.RemoteDatahandler.mode + else: + return 'none' + + + def initialiseRemoteDataHandling(self): + # Initialise remote data handling, if the config file is present + remote_cfg = os.path.join(self.db_dir, 'wmpl_remote.cfg') + if os.path.isfile(remote_cfg): + log.info('remote data management requested, initialising') + self.RemoteDatahandler = RemoteDataHandler(remote_cfg) + if self.RemoteDatahandler.mode == 'child': + self.RemoteDatahandler.clearStopFlag() + status = self.getRemoteData(verbose=True) + else: + status = self.moveUploadedData(verbose=False) + if not status: + log.info('no remote data yet') + else: + self.RemoteDatahandler = None - def purgePhase1ProcessedData(self, dir_path): - """ Purge old phase1 processed data if it is older than 90 days. 
""" - - refdt = time.time() - 90*86400 - result = [] - for path, _, files in os.walk(dir_path): - - for file in files: - - file_path = os.path.join(path, file) - - # Check if the file is older than the reference date - try: - file_dt = os.stat(file_path).st_mtime - except FileNotFoundError: - log.warning(f"File not found: {file_path}") - continue - - if ( - os.path.exists(file_path) and (file_dt < refdt) and os.path.isfile(file_path) - ): - - try: - os.remove(file_path) - result.append(file_path) - - except FileNotFoundError: - log.warning(f"File not found: {file_path}") + def purgeProcessedData(self, dir_path, days_back=14, verbose=False): + """ Purge processed candidate or phase1 data if it is older than 30 days. """ - except Exception as e: - log.error(f"Error removing file {file_path}: {e}") - - return result + refdt = time.time() - days_back*86400 + num_removed = 0 + log.info(f'purging processed data from {dir_path} thats older than {days_back} days') + for file_name in glob.glob(os.path.join(dir_path,'*.pickle')): + try: + file_dt = os.stat(file_name).st_mtime + if file_dt < refdt: + if verbose: + log.info(f'removing {file_name}') + os.remove(file_name) + num_removed += 1 + except FileNotFoundError: + log.warning(f"File disappeared: {file_name}") + continue + except Exception as e: + log.error(f"Error removing file {file_name}: {e}") + + return num_removed + + def cleanupPartialProcessing(self): + log.info('checking for partially-processed phase1 files') + i=0 + for i, file_name in enumerate(glob.glob(os.path.join(self.phase1_dir, '*.pickle_processing'))): + new_name = file_name.replace('_processing','') + if os.path.isfile(new_name): + os.remove(file_name) + else: + os.rename(file_name, new_name) + log.info(f'updated {i} partially-processed files') + return def archiveOldRecords(self, older_than=3): """ @@ -682,43 +595,26 @@ def __init__(self, station, obs_id): archdate = datetime.datetime.now(datetime.timezone.utc) - relativedelta(months=older_than) 
archdate_jd = datetime2JD(archdate) + arch_prefix = archdate.strftime("%Y%m") + + # TODO check if this works + self.observations_db.archiveObsDatabase(self.db_dir, arch_prefix, archdate_jd) + self.trajectory_db.archiveTrajDatabase(self.db_dir, arch_prefix, archdate_jd) - arch_db_path = os.path.join(self.db_dir, f'{archdate.strftime("%Y%m")}_{JSON_DB_NAME}') - archdb = DatabaseJSON(arch_db_path, verbose=self.verbose) - log.info(f'Archiving db records to {arch_db_path}...') - - for traj in [t for t in self.db.trajectories if t < archdate_jd]: - if traj < archdate_jd: - archdb.addTrajectory(None, self.db.trajectories[traj], False) - self.db.removeTrajectory(self.db.trajectories[traj], keepFolder=True) - - for traj in [t for t in self.db.failed_trajectories if t < archdate_jd]: - if traj < archdate_jd: - archdb.addTrajectory(None, self.db.failed_trajectories[traj], True) - self.db.removeTrajectory(self.db.failed_trajectories[traj], keepFolder=True) - - for station in self.db.processed_dirs: - arch_processed = [dirname for dirname in self.db.processed_dirs[station] if - datetime.datetime.strptime(dirname[14:22], '%Y%m%d').replace(tzinfo=datetime.timezone.utc) < archdate] - for dirname in arch_processed: - archdb.addProcessedDir(station, dirname) - self.db.processed_dirs[station].remove(dirname) - - for station in self.db.paired_obs: - arch_processed = [obs_id for obs_id in self.db.paired_obs[station] if - datetime.datetime.strptime(obs_id[7:15], '%Y%m%d').replace(tzinfo=datetime.timezone.utc) < archdate] - for obs_id in arch_processed: - archdb.addPairedObservation(DummyMetObs(station, obs_id)) - self.db.paired_obs[station].remove(obs_id) - - archdb.save() - self.db.save() + return + + def closeObservationsDatabase(self): + self.observations_db.closeObsDatabase() + return + + def closeTrajectoryDatabase(self): + self.trajectory_db.closeTrajDatabase() return def loadStations(self): """ Load the station names in the processing folder. 
""" - station_list = [] + avail_station_list = [] for dir_name in sorted(os.listdir(self.dir_path)): @@ -726,14 +622,12 @@ def loadStations(self): if os.path.isdir(os.path.join(self.dir_path, dir_name)): if re.match("^[A-Z]{2}[A-Z0-9]{4}$", dir_name): log.info("Using station: " + dir_name) - station_list.append(dir_name) + avail_station_list.append(dir_name) else: log.info("Skipping directory: " + dir_name) - return station_list - - + return avail_station_list def findUnprocessedFolders(self, station_list): """ Go through directories and find folders with unprocessed data. """ @@ -747,10 +641,6 @@ def findUnprocessedFolders(self, station_list): station_path = os.path.join(self.dir_path, station_name) - # Add the station name to the database if it doesn't exist - if station_name not in self.db.processed_dirs: - self.db.processed_dirs[station_name] = [] - # Go through all directories in stations for night_name in os.listdir(station_path): @@ -770,10 +660,6 @@ def findUnprocessedFolders(self, station_list): night_path = os.path.join(station_path, night_name) night_path_rel = os.path.join(station_name, night_name) - # # If the night path is not in the processed list, add it to the processing list - # if night_path_rel not in self.db.processed_dirs[station_name]: - # processing_list.append([station_name, night_path_rel, night_path, night_dt]) - processing_list.append([station_name, night_path_rel, night_path, night_dt]) # else: @@ -785,8 +671,6 @@ def findUnprocessedFolders(self, station_list): return processing_list - - def initMeteorObs(self, station_code, ftpdetectinfo_path, platepars_recalibrated_dict): """ Init meteor observations from the FTPdetectinfo file and recalibrated platepars. """ @@ -806,8 +690,6 @@ def initMeteorObs(self, station_code, ftpdetectinfo_path, platepars_recalibrated return meteor_list - - def loadUnpairedObservations(self, processing_list, dt_range=None): """ Load unpaired meteor observations, i.e. 
observations that are not a part of any trajectory. """ @@ -862,21 +744,12 @@ def loadUnpairedObservations(self, processing_list, dt_range=None): # Skip these observations if no data files were found inside if (ftpdetectinfo_name is None) or (platepar_recalibrated_name is None): log.info(" Skipping {:s} due to missing data files...".format(rel_proc_path)) - - # Add the folder to the list of processed folders - self.db.addProcessedDir(station_code, rel_proc_path) - continue if station_code != prev_station: station_count += 1 prev_station = station_code - # Save database to mark those with missing data files (only every 250th station, to speed things up) - if (station_count % 250 == 0) and (station_code != prev_station): - self.saveDatabase() - - # Load platepars with open(os.path.join(proc_path, platepar_recalibrated_name)) as f: platepars_recalibrated_dict = json.load(f) @@ -934,11 +807,9 @@ def loadUnpairedObservations(self, processing_list, dt_range=None): continue # Add only unpaired observations - if not self.db.checkObsIfPaired(met_obs): - + if not self.checkIfObsPaired(met_obs.id, verbose=verbose): # print(" ", station_code, met_obs.reference_dt, rel_proc_path) added_count += 1 - unpaired_met_obs_list.append(met_obs) log.info(" Added {:d} observations!".format(added_count)) @@ -946,10 +817,8 @@ def loadUnpairedObservations(self, processing_list, dt_range=None): log.info("") log.info(" Finished loading unpaired observations!") - self.saveDatabase() return unpaired_met_obs_list - def yearMonthDayDirInDtRange(self, dir_name): """ Given a directory name which is either YYYY, YYYYMM or YYYYMMDD, check if it is in the given @@ -1039,8 +908,7 @@ def yearMonthDayDirInDtRange(self, dir_name): return True else: - return False - + return False def trajectoryFileInDtRange(self, file_name, dt_range=None): """ Check if the trajectory file is in the given datetime range. 
""" @@ -1069,15 +937,14 @@ def trajectoryFileInDtRange(self, file_name, dt_range=None): else: return False - - def removeDeletedTrajectories(self): + def removeDeletedTrajectories(self, verbose=True): """ Purge the database of any trajectories that no longer exist on disk. These can arise because the monte-carlo stage may update the data. """ if not os.path.isdir(self.output_dir): return - if self.db is None: + if self.trajectory_db is None: return log.info(" Removing deleted trajectories from: " + self.output_dir) @@ -1086,48 +953,32 @@ def removeDeletedTrajectories(self): self.dt_range[0].strftime("%Y-%m-%d %H:%M:%S"), self.dt_range[1].strftime("%Y-%m-%d %H:%M:%S"))) - jdt_start = datetime2JD(self.dt_range[0]) - jdt_end = datetime2JD(self.dt_range[1]) - - trajs_to_remove = [] - - keys = [k for k in self.db.trajectories.keys() if k >= jdt_start and k <= jdt_end] - for trajkey in keys: - traj_reduced = self.db.trajectories[trajkey] - # Update the trajectory path to make sure we're working with the correct filesystem - traj_path = self.generateTrajOutputDirectoryPath(traj_reduced) - traj_file_name = os.path.split(traj_reduced.traj_file_path)[1] - traj_path = os.path.join(traj_path, traj_file_name) - - if self.verbose: - log.info(f' testing {traj_path}') - - if not os.path.isfile(traj_path): - traj_reduced.traj_file_path = traj_path - trajs_to_remove.append(traj_reduced) - - for traj in trajs_to_remove: - log.info(f' removing deleted {traj.traj_file_path}') - - # remove from the database but not from the disk: they're already not on the disk and this avoids - # accidentally deleting a different traj with a timestamp which is within a millisecond - self.db.removeTrajectory(traj, keepFolder=True) + jdt_range = [datetime2JD(self.dt_range[0]), datetime2JD(self.dt_range[1])] + + traj_list = self.trajectory_db.getTrajBasics(self.output_dir, jdt_range) + i = 0 + for traj in traj_list: + if not os.path.isfile(os.path.join(self.output_dir, traj['traj_file_path'])): + if 
verbose: + log.info(f'removing traj {jd2Date(traj["jdt_ref"],dt_obj=True).strftime("%Y%m%d_%H%M%S.%f")} {traj["traj_file_path"]} from database') + self.removeTrajectory(TrajectoryReduced(None, json_dict=traj)) + i += 1 + log.info(f'removed {i} deleted trajectories') return - - def loadComputedTrajectories(self, traj_dir_path, dt_range=None): + def loadComputedTrajectories(self, dt_range=None): """ Load already estimated trajectories from disk within a date range. Arguments: - traj_dir_path: [str] Full path to a directory with trajectory pickles. + dt_range: [datetime, datetime] range of dates to load data for """ - + traj_dir_path = os.path.join(self.output_dir, OUTPUT_TRAJ_DIR) # defend against the case where there are no existing trajectories and traj_dir_path doesn't exist if not os.path.isdir(traj_dir_path): return - if self.db is None: + if self.trajectory_db is None: return if dt_range is None: @@ -1135,7 +986,7 @@ def loadComputedTrajectories(self, traj_dir_path, dt_range=None): else: dt_beg, dt_end = dt_range - log.info(" Loading trajectories from: " + traj_dir_path) + log.info(" Loading found trajectories from: " + traj_dir_path) if self.dt_range is not None: log.info(" Datetime range: {:s} - {:s}".format( dt_beg.strftime("%Y-%m-%d %H:%M:%S"), @@ -1160,17 +1011,17 @@ def loadComputedTrajectories(self, traj_dir_path, dt_range=None): curr_dt = jd2Date(jdt, dt_obj=True) if curr_dt.year != yyyy: yyyy = curr_dt.year - log.info("- year " + str(yyyy)) + #log.info("- year " + str(yyyy)) if curr_dt.month != mm: mm = curr_dt.month yyyymm = f'{yyyy}{mm:02d}' - log.info(" - month " + str(yyyymm)) + #log.info(" - month " + str(yyyymm)) if curr_dt.day != dd: dd = curr_dt.day yyyymmdd = f'{yyyy}{mm:02d}{dd:02d}' - log.info(" - day " + str(yyyymmdd)) + #log.info(" - day " + str(yyyymmdd)) yyyymmdd_dir_path = os.path.join(traj_dir_path, f'{yyyy}', f'{yyyymm}', f'{yyyymmdd}') @@ -1187,105 +1038,36 @@ def loadComputedTrajectories(self, traj_dir_path, dt_range=None): if 
self.trajectoryFileInDtRange(file_name, dt_range=dt_range): - self.db.addTrajectory(os.path.join(full_traj_dir, file_name)) + self.trajectory_db.addTrajectory(TrajectoryReduced(os.path.join(full_traj_dir, file_name)), force_add=False) # Print every 1000th trajectory if counter % 1000 == 0: - log.info(f" Loaded {counter:6d} trajectories, currently on {file_name}") + log.info(f" Loaded {counter:6d} trajectories") counter += 1 dir_paths.append(full_traj_dir) dur = (datetime.datetime.now() - start_time).total_seconds() log.info(f" Loaded {counter:6d} trajectories in {dur:.0f} seconds") - - def getComputedTrajectories(self, jd_beg, jd_end): """ Returns a list of computed trajectories between the Julian dates. """ - - return [self.db.trajectories[key] for key in self.db.trajectories - if (self.db.trajectories[key].jdt_ref >= jd_beg) - and (self.db.trajectories[key].jdt_ref <= jd_end)] - - - def removeDuplicateTrajectories(self, dt_range): - """ Remove trajectories with duplicate IDs - keeping the one with the most station observations - """ - - log.info('removing duplicate trajectories') - - tr_in_scope = self.getComputedTrajectories(datetime2JD(dt_range[0]), datetime2JD(dt_range[1])) - tr_to_check = [{'jdt_ref':traj.jdt_ref,'traj_id':traj.traj_id, 'traj': traj} for traj in tr_in_scope if hasattr(traj,'traj_id')] - - if len(tr_to_check) == 0: - log.info('no trajectories in range') - return - - tr_df = pd.DataFrame(tr_to_check) - tr_df['dupe']=tr_df.duplicated(subset=['traj_id']) - dupeids = tr_df[tr_df.dupe].sort_values(by=['traj_id']).traj_id - duperows = tr_df[tr_df.traj_id.isin(dupeids)] - - log.info(f'there are {len(duperows.traj_id.unique())} duplicate trajectories') - - # iterate over the duplicates, finding the best and removing the others - for traj_id in duperows.traj_id.unique(): - num_stats = 0 - best_traj_dt = None - best_traj_path = None - # find duplicate with largest number of observations - for testdt in duperows[duperows.traj_id==traj_id].jdt_ref.values: 
- - if len(dh.db.trajectories[testdt].participating_stations) > num_stats: - - best_traj_dt = testdt - num_stats = len(dh.db.trajectories[testdt].participating_stations) - # sometimes the database contains duplicates that differ by microseconds in jdt. These - # will have overwritten each other in the folder so make a note of the location. - best_traj_path = dh.db.trajectories[testdt].traj_file_path - - # now remove all except the best - for testdt in duperows[duperows.traj_id==traj_id].jdt_ref.values: - - traj = dh.db.trajectories[testdt] - if testdt != best_traj_dt: - - # get the current trajectory's location. If its the same as that of the best trajectory - # don't try to delete the solution from disk even if there's a small difference in jdt_ref - keepFolder = False - if traj.traj_file_path == best_traj_path: - keepFolder = True - # Update the trajectory path to make sure we're working with the correct filesystem - traj_path = self.generateTrajOutputDirectoryPath(traj) - traj_file_name = os.path.split(traj.traj_file_path)[1] - traj.traj_file_path = os.path.join(traj_path, traj_file_name) - log.info(f'removing duplicate {traj.traj_id} keep {traj_file_name} {keepFolder}') - - self.db.removeTrajectory(traj, keepFolder=keepFolder) - - else: - if self.verbose: - log.info(f'keeping {traj.traj_id} {traj.traj_file_path}') - - return - + jd_range = [jd_beg, jd_end] + json_dicts = self.trajectory_db.getTrajectories(self.output_dir, jd_range) + trajs = [TrajectoryReduced(None, json_dict=j) for j in json_dicts] + return trajs def getPlatepar(self, met_obs): """ Return the platepar of the meteor observation. """ return met_obs.platepar - - def getUnpairedObservations(self): """ Returns a list of unpaired meteor observations. """ return self.unpaired_observations - def countryFilter(self, station_code1, station_code2): """ Only pair observations if they are in proximity to a given country. 
""" @@ -1300,9 +1082,30 @@ def countryFilter(self, station_code1, station_code2): # If a given country is not in any of the groups, allow it to be paired return True + + def checkIfObsPaired(self, obs_id, verbose=False): + return self.observations_db.checkObsPaired(obs_id, verbose) + + def addPairedObs(self, matched_obs, jdt_ref, verbose=False): + """ + mark a list of observations as paired + + parameters: + matched_obs : a tuple containing the observations. + jdt_ref : the julian date of the Trajectory they are paired with. + """ + if len(matched_obs[0])==3: + obs_ids = [met_obs.id for _, met_obs, _ in matched_obs] + else: + obs_ids = [met_obs.id for _, met_obs in matched_obs] + jdt_refs = [jdt_ref] * len(obs_ids) - def findTimePairs(self, met_obs, unpaired_observations, max_toffset): + self.observations_db.addPairedObservations(obs_ids, jdt_refs, verbose=verbose) + + return + + def findTimePairs(self, met_obs, unpaired_observations, max_toffset, verbose=False): """ Finds pairs in time between the given meteor observations and all other observations from different stations. @@ -1322,6 +1125,9 @@ def findTimePairs(self, met_obs, unpaired_observations, max_toffset): # Go through all meteors from other stations for met_obs2 in unpaired_observations: + if self.checkIfObsPaired(met_obs2.id, verbose=verbose): + continue + # Take only observations from different stations if met_obs.station_code == met_obs2.station_code: continue @@ -1337,7 +1143,6 @@ def findTimePairs(self, met_obs, unpaired_observations, max_toffset): return found_pairs - def getTrajTimePairs(self, traj_reduced, unpaired_observations, max_toffset): """ Find unpaired observations which are close in time to the given trajectory. """ @@ -1366,7 +1171,6 @@ def getTrajTimePairs(self, traj_reduced, unpaired_observations, max_toffset): return found_traj_obs_pairs - def generateTrajOutputDirectoryPath(self, traj, make_dirs=False): """ Generate a path to the trajectory output directory. 
@@ -1377,11 +1181,11 @@ def generateTrajOutputDirectoryPath(self, traj, make_dirs=False): # Generate a list of station codes if isinstance(traj, TrajectoryReduced): # If the reducted trajectory object is given - station_list = traj.participating_stations + traj_station_list = traj.participating_stations else: # If the full trajectory object is given - station_list = [obs.station_id for obs in traj.observations if obs.ignore_station is False] + traj_station_list = [obs.station_id for obs in traj.observations if obs.ignore_station is False] # Datetime of the reference trajectory time @@ -1399,7 +1203,7 @@ def generateTrajOutputDirectoryPath(self, traj, make_dirs=False): # Name of the trajectory directory # sort the list of country codes otherwise we can end up with duplicate trajectories - ctry_list = list(set([stat_id[:2] for stat_id in station_list])) + ctry_list = list(set([stat_id[:2] for stat_id in traj_station_list])) ctry_list.sort() traj_dir = dt.strftime("%Y%m%d_%H%M%S.%f")[:-3] + "_" + "_".join(ctry_list) @@ -1411,8 +1215,7 @@ def generateTrajOutputDirectoryPath(self, traj, make_dirs=False): return out_path - - def saveTrajectoryResults(self, traj, save_plots): + def saveTrajectoryResults(self, traj, save_plots, verbose=False): """ Save trajectory results to the disk. """ @@ -1427,7 +1230,7 @@ def saveTrajectoryResults(self, traj, save_plots): # if additional observations are found then the refdt or country list may change quite a bit traj.longname = os.path.split(output_dir)[-1] - if self.mc_mode == 1: + if self.mc_mode & MCMODE_PHASE1: # The MC phase may change the refdt so save a copy of the the original name. 
traj.pre_mc_longname = traj.longname @@ -1438,18 +1241,14 @@ def saveTrajectoryResults(self, traj, save_plots): savePickle(traj, output_dir, traj.file_name + '_trajectory.pickle') log.info(f'saved {traj.traj_id} to {output_dir}') - if self.mc_mode == 1: - savePickle(traj, self.phase1_dir, traj.pre_mc_longname + '_trajectory.pickle') - elif self.mc_mode == 2: - # we save this in MC mode the MC phase may alter the trajectory details and if later on + if self.mc_mode & MCMODE_PHASE1 and not self.mc_mode & MCMODE_PHASE2: + self.saveCandOrTraj(traj, traj.pre_mc_longname + '_trajectory.pickle', verbose=verbose) + + elif self.mc_mode & MCMODE_PHASE2: + # the MC phase may alter the trajectory details and if later on # we're including additional observations we need to use the most recent version of the trajectory savePickle(traj, os.path.join(self.phase1_dir, 'processed'), traj.pre_mc_longname + '_trajectory.pickle') - if self.remotehost is not None: - log.info('saving to remote host') - uploadTrajToRemote(remotehost, traj.file_name + '_trajectory.pickle', output_dir) - log.info(' ...done') - # Save the plots if save_plots: traj.save_results = True @@ -1459,27 +1258,7 @@ def saveTrajectoryResults(self, traj, save_plots): pass traj.save_results = False - - - def markObservationAsProcessed(self, met_obs): - """ Mark the given meteor observation as processed. """ - - if self.db is None: - return - self.db.addProcessedDir(met_obs.station_code, met_obs.rel_proc_path) - - - - def markObservationAsPaired(self, met_obs): - """ Mark the given meteor observation as paired in a trajectory. """ - - if self.db is None: - return - self.db.addPairedObservation(met_obs) - - - - def addTrajectory(self, traj, failed_jdt_ref=None): + def addTrajectory(self, traj, failed_jdt_ref=None, verbose=False): """ Add the resulting trajectory to the database. 
Arguments: @@ -1487,7 +1266,7 @@ def addTrajectory(self, traj, failed_jdt_ref=None): failed_jdt_ref: [float] Reference Julian date of the failed trajectory. None by default. """ - if self.db is None: + if self.trajectory_db is None: return # Set the correct output path traj.output_dir = self.generateTrajOutputDirectoryPath(traj) @@ -1500,15 +1279,13 @@ def addTrajectory(self, traj, failed_jdt_ref=None): if failed_jdt_ref is not None: traj_reduced.jdt_ref = failed_jdt_ref - self.db.addTrajectory(None, traj_obj=traj_reduced, failed=(failed_jdt_ref is not None)) + self.trajectory_db.addTrajectory(traj_reduced, failed=(failed_jdt_ref is not None), verbose=verbose) - - - def removeTrajectory(self, traj_reduced): + def removeTrajectory(self, traj_reduced, remove_phase1=False): """ Remove the trajectory from the data base and disk. """ # in mcmode 2 the database isn't loaded but we still need to delete updated trajectories - if self.mc_mode == 2: + if self.mc_mode & MCMODE_PHASE2: if os.path.isfile(traj_reduced.traj_file_path): traj_dir = os.path.dirname(traj_reduced.traj_file_path) shutil.rmtree(traj_dir, ignore_errors=True) @@ -1518,50 +1295,74 @@ def removeTrajectory(self, traj_reduced): traj_dir = os.path.join(base_dir, traj_reduced.pre_mc_longname) if os.path.isdir(traj_dir): shutil.rmtree(traj_dir, ignore_errors=True) - else: - log.warning(f'unable to find {traj_dir}') - else: - log.warning(f'unable to find {traj_reduced.traj_file_path}') + return - # remove the processed pickle now we're done with it - self.cleanupPhase2TempPickle(traj_reduced, True) + if self.mc_mode & MCMODE_PHASE1 and remove_phase1: + # remove any solution from the phase1 folder + phase1_traj = os.path.join(self.phase1_dir, os.path.basename(traj_reduced.traj_file_path)) + if os.path.isfile(phase1_traj): + try: + os.remove(phase1_traj) + except Exception: + pass - return - self.db.removeTrajectory(traj_reduced) + # Remove the trajectory folder from the disk + if 
os.path.isfile(traj_reduced.traj_file_path): + traj_dir = os.path.dirname(traj_reduced.traj_file_path) + shutil.rmtree(traj_dir, ignore_errors=True) + if os.path.isfile(traj_reduced.traj_file_path): + log.warning(f'unable to remove {traj_dir}') + self.trajectory_db.removeTrajectory(traj_reduced) - def cleanupPhase2TempPickle(self, traj, success=False): + def checkAlreadyProcessed(self, matched_observations, verbose=False): """ - At the start of phase 2 monte-carlo sim calculation, the phase1 pickles are renamed to indicate they're being processed. - Once each one is processed (fail or succeed) we need to clean up the file. If the MC step failed, we still want to keep - the pickle, because we might later on get new data and it might become solvable. Otherwise, we can just delete the file - since the MC solver will have saved an updated one already. + Check if a list of candidates has already been processed, and return only the new ones """ - if self.mc_mode != 2: - return - fldr_name = os.path.split(self.generateTrajOutputDirectoryPath(traj, make_dirs=False))[-1] - pick = os.path.join(self.phase1_dir, fldr_name + '_trajectory.pickle_processing') - if os.path.isfile(pick): - os.remove(pick) - else: - log.warning(f'unable to find _processing file {pick}') - if not success: - # save the pickle in case we get new data later and can solve it - savePickle(traj, os.path.join(self.phase1_dir, 'processed'), fldr_name + '_trajectory.pickle') - return + # go through the candidates and skip any whose candidate pickle already exists on disk + candidate_trajectories=[] + for cand in matched_observations: + ref_dt = min([met_obs.reference_dt for _, met_obs, _ in cand]) + ctry_list = list(set([met_obs.station_code[:2] for _, met_obs, _ in cand])) + ctry_list.sort() + ctries = '_'.join(ctry_list) + file_name = f'{ref_dt.timestamp():.6f}_{ctries}.pickle' + save_dir = self.candidate_dir + if verbose: + log.info(f'Candidate {file_name} contains {len(cand)} observations') + + if 
os.path.isfile(os.path.join(save_dir, file_name)) or os.path.isfile(os.path.join(save_dir, 'processed', file_name)): + if verbose: + log.info(f'candidate {file_name} already processed') + continue + + else: + candidate_trajectories.append(cand) + + return candidate_trajectories + + def checkCandIfFailed(self, candidate): + """ Check if the given candidate has been processed with the same observations and has failed to be + computed before. + """ + jdt_ref = min([obs.jdt_ref for obs, _, _ in candidate]) + stations = [obs.station_id for obs, _, _ in candidate] + return self.trajectory_db.checkCandIfFailed(jdt_ref, stations) def checkTrajIfFailed(self, traj): """ Check if the given trajectory has been computed with the same observations and has failed to be computed before. - """ - - if self.db is None: - return - return self.db.checkTrajIfFailed(traj) + Parameters: + traj: full trajectory object + """ + if self.trajectory_db is None: + return + traj_reduced = TrajectoryReduced(None, traj_obj=traj) + return self.trajectory_db.checkTrajIfFailed(traj_reduced) def loadFullTraj(self, traj_reduced): """ Load the full trajectory object. 
def moveUploadedData(self, verbose=False):
    """
    Used in 'master' mode: this moves uploaded data to the target locations on the server
    and merges in the databases.

    For every remote node that has a 'files' folder this
      1) merges any uploaded observations*.db and trajectories*.db files and removes them once merged
      2) copies uploaded trajectory folders into the dated output tree and removes the upload
      3) for nodes in mode 1, moves uploaded phase1 pickles and processed candidate pickles

    Parameters:
        verbose : boolean, currently unused, kept for interface compatibility

    Returns:
        True
    """
    log.info('merging in any remotely processed data')

    for node in self.RemoteDatahandler.nodes:
        if node.nodename == 'localhost' or self.observations_db is None or self.trajectory_db is None:
            continue

        # if the remote node upload path doesn't exist skip it
        files_dir = os.path.join(node.dirpath, 'files')
        if not os.path.isdir(files_dir):
            continue

        # merge the databases
        for obsdb_path in glob.glob(os.path.join(files_dir, 'observations*.db')):
            if self.observations_db.mergeObsDatabase(obsdb_path):
                os.remove(obsdb_path)
                # the sqlite side files may or may not exist, so remove each one independently
                for suffix in ('-wal', '-shm'):
                    try:
                        os.remove(f'{obsdb_path}{suffix}')
                    except OSError:
                        pass

        for trajdb_path in glob.glob(os.path.join(files_dir, 'trajectories*.db')):
            if self.trajectory_db.mergeTrajDatabase(trajdb_path):
                os.remove(trajdb_path)

        # move the uploaded trajectory folders into the yyyy/yyyymm/yyyymmdd output tree
        num_moved = 0
        remote_trajdir = os.path.join(files_dir, 'trajectories')
        if os.path.isdir(remote_trajdir):
            for traj in os.listdir(remote_trajdir):
                src_path = os.path.join(remote_trajdir, traj)
                if not os.path.isdir(src_path):
                    continue

                targ_path = os.path.join(self.output_dir, 'trajectories', traj[:4], traj[:6], traj[:8], traj)
                for src_name in os.listdir(src_path):
                    src_name = os.path.join(src_path, src_name)
                    if not os.path.isfile(src_name):
                        log.info(f'{src_name} missing')
                    else:
                        os.makedirs(targ_path, exist_ok=True)
                        shutil.copy(src_name, targ_path)

                shutil.rmtree(src_path, ignore_errors=True)
                num_moved += 1

        if num_moved > 0:
            log.info(f'moved {num_moved} trajectories')

        # if the node was in mode 1 then move any uploaded phase1 solutions
        remote_ph1dir = os.path.join(files_dir, 'phase1')
        if os.path.isdir(remote_ph1dir) and node.mode == 1:
            os.makedirs(self.phase1_dir, exist_ok=True)
            num_moved = 0
            for fil in [x for x in os.listdir(remote_ph1dir) if '.pickle' in x]:
                full_name = os.path.join(remote_ph1dir, fil)
                shutil.copy(full_name, self.phase1_dir)
                os.remove(full_name)
                num_moved += 1

            if num_moved > 0:
                log.info(f'moved {num_moved} phase 1 solutions from {node.nodename}')

        # if the node was in mode 1 then move any uploaded processed candidates
        remote_canddir = os.path.join(files_dir, 'candidates', 'processed')
        if os.path.isdir(remote_canddir) and node.mode == 1:
            targ_dir = os.path.join(self.candidate_dir, 'processed')
            # the folder must exist, otherwise shutil.copy would create a *file* called 'processed'
            os.makedirs(targ_dir, exist_ok=True)
            num_moved = 0
            for fil in [x for x in os.listdir(remote_canddir) if '.pickle' in x]:
                full_name = os.path.join(remote_canddir, fil)
                shutil.copy(full_name, targ_dir)
                os.remove(full_name)
                num_moved += 1

            if num_moved > 0:
                log.info(f'moved {num_moved} processed candidates from {node.nodename}')

    return True


def checkAndRedistribCands(self, wait_time=6, verbose=False):
    """
    Check child nodes and
    1) if the stop flag has appeared, move any pending data to prevent it getting stuck
    2) move data if it has been waiting more than wait_time hours
    3) if the node is idle, assign it extra data

    Parameters:
        wait_time : time in hours to wait before data is considered stale
        verbose : boolean, currently unused, kept for interface compatibility
    """

    def _movePickles(src_dir, targ_dir, limit=-1, older_than=None):
        """ Move pickles from src_dir to targ_dir, at most 'limit' of them if limit is positive and
            only those last modified before 'older_than' if that is given. Returns the number moved.
        """
        num_moved = 0
        for full_name in glob.glob(os.path.join(src_dir, '*.pickle')):
            if older_than is not None and os.stat(full_name).st_mtime >= older_than:
                continue
            os.makedirs(targ_dir, exist_ok=True)
            # copy + remove rather than a rename so the file gets a fresh modification time,
            # which is what the stale-data check below is measured against
            shutil.copy(full_name, targ_dir)
            os.remove(full_name)
            num_moved += 1
            if num_moved == limit:
                break
        return num_moved

    for node in self.RemoteDatahandler.nodes:
        if node.nodename == 'localhost' or self.observations_db is None or self.trajectory_db is None:
            continue

        # if the remote node upload path doesn't exist skip it
        files_dir = os.path.join(node.dirpath, 'files')
        if not os.path.isdir(files_dir):
            continue

        node_cand_dir = os.path.join(files_dir, 'candidates')
        node_ph1_dir = os.path.join(files_dir, 'phase1')

        if os.path.isfile(os.path.join(files_dir, 'stop')):
            # if the stop file has appeared, then move any pending candidates or phase1 files
            log.info(f'{node.nodename} stopfile has appeared, moving data')
            _movePickles(node_cand_dir, self.candidate_dir)
            _movePickles(node_ph1_dir, self.phase1_dir)

        elif node.capacity != 0:
            # if the stop file isnt present and the node is idle, give it something to do.
            # A capacity of zero marks a disabled node, which must not be given any data;
            # a negative capacity means no limit.
            if len(glob.glob(os.path.join(node_cand_dir, '*.pickle'))) == 0 and node.mode == MCMODE_PHASE1:
                log.info(f'{node.nodename} idle, giving it extra candidates')
                _movePickles(self.candidate_dir, node_cand_dir, limit=node.capacity)

            if len(glob.glob(os.path.join(node_ph1_dir, '*.pickle'))) == 0 and node.mode == MCMODE_PHASE2:
                log.info(f'{node.nodename} idle, giving it extra phase1 data')
                _movePickles(self.phase1_dir, node_ph1_dir, limit=node.capacity)

        # if the files have been in the nodes folder for more than wait_time hours, move them
        refdt = time.time() - wait_time*3600
        log.info(f'moving any stale data assigned to {node.nodename}')
        _movePickles(node_cand_dir, self.candidate_dir, older_than=refdt)
        _movePickles(node_ph1_dir, self.phase1_dir, older_than=refdt)

    return


def getRemoteData(self, verbose=False):
    """
    Used in 'child' mode: this downloads data from the master for local processing.

    Parameters:
        verbose : boolean, passed through to the remote data handler

    Returns:
        the status reported by the remote data handler, or False if there is no handler
        or the current mode does not consume remote data
    """
    if not self.RemoteDatahandler:
        log.info('remote data handler not initialised')
        return False

    # collect candidates or phase1 solutions from the master node
    if self.mc_mode == MCMODE_PHASE1 or self.mc_mode == MCMODE_BOTH:
        status = self.RemoteDatahandler.collectRemoteData('candidates', self.output_dir, verbose=verbose)
    elif self.mc_mode == MCMODE_PHASE2:
        status = self.RemoteDatahandler.collectRemoteData('phase1', self.output_dir, verbose=verbose)
    else:
        status = False
    return status


def saveCandidates(self, candidate_trajectories, verbose=False):
    """
    Save each candidate as a pickle named after its earliest reference time and the
    two-letter prefixes of the stations involved.

    Parameters:
        candidate_trajectories : list of candidates, each a list of 3-tuples whose second
            element exposes reference_dt (a datetime) and station_code
        verbose : boolean, if true then log details of each candidate
    """
    num_saved = 0
    for matched_observations in candidate_trajectories:
        ref_dt = min(met_obs.reference_dt for _, met_obs, _ in matched_observations)
        ctries = '_'.join(sorted({met_obs.station_code[:2] for _, met_obs, _ in matched_observations}))
        picklename = f'{ref_dt.timestamp():.6f}_{ctries}.pickle'

        if verbose:
            log.info(f'Candidate {picklename} contains {len(matched_observations)} observations')

        if self.saveCandOrTraj(matched_observations, picklename, 'candidates', verbose=verbose):
            num_saved += 1

    log.info("-----------------------")
    log.info(f'Saved {num_saved} candidates')
    log.info("-----------------------")


def saveCandOrTraj(self, traj, file_name, savetype='phase1', verbose=False):
    """
    in mcmode MCMODE_PHASE1 or MCMODE_SIMPLE , save the candidates or phase 1 trajectories
    and distribute as appropriate

    Parameters:
        traj : the object to pickle (a candidate or a phase 1 trajectory)
        file_name : name of the pickle file
        savetype : 'phase1' to save into the phase1 area, anything else for the candidates area
        verbose : boolean, if true then log where the file is saved

    Returns:
        True
    """
    if savetype == 'phase1':
        save_dir = self.phase1_dir
        required_mode = MCMODE_PHASE2
    else:
        save_dir = self.candidate_dir
        required_mode = MCMODE_PHASE1

    if self.RemoteDatahandler and self.RemoteDatahandler.mode == 'master' and self.RemoteDatahandler.nodes:

        # Try the buckets in random order and save to the first one that is in the right mode,
        # has no stop flag and is not already full. Every bucket is tried at most once.
        # Fallback/default is to use the local dir.
        bucket_list = self.RemoteDatahandler.nodes
        # the last node is the 'localhost' entry, which saves into the local folder
        bucket_list[-1].dirpath = save_dir

        untested = list(range(len(bucket_list)))
        while untested:
            bucket = bucket_list[untested.pop(secrets.randbelow(len(untested)))]

            # if the child isn't the right mode, skip it
            if bucket.mode != required_mode and bucket.mode != -1:
                continue

            # a node that has raised its stop flag is not accepting data
            if os.path.isfile(os.path.join(bucket.dirpath, 'files', 'stop')):
                continue

            is_local = bucket.nodename == 'localhost'
            tmp_save_dir = save_dir if is_local else os.path.join(bucket.dirpath, 'files', savetype)

            # negative capacity means unlimited
            if bucket.capacity < 0 or len(glob.glob(os.path.join(tmp_save_dir, '*.pickle'))) < bucket.capacity:
                if not is_local:
                    save_dir = tmp_save_dir
                break

    os.makedirs(save_dir, exist_ok=True)
    if verbose:
        log.info(f'saving {file_name} to {save_dir}')
    savePickle(traj, save_dir, file_name)
    return True
# signal handler created inline here as it needs access to db_dir
def signal_handler(sig, frame):
    """ Shut down gracefully when the given signal is received.

    If a remote config exists in db_dir and this is a child node, the stop flag is raised on the
    master. The pid file is removed if one was created, then the process exits with status 0.

    Parameters:
        sig : the signal number that was received
        frame : the current stack frame (unused)
    """
    signal.signal(sig, signal.SIG_IGN)  # ignore additional signals
    log.info('======================================')
    log.info('CTRL-C pressed, exiting gracefully....')
    log.info('======================================')

    remote_cfg = os.path.join(db_dir, 'wmpl_remote.cfg')
    if os.path.isfile(remote_cfg):
        # a problem with the remote side must not stop the local cleanup and exit
        try:
            rdh = RemoteDataHandler(remote_cfg)
            # the handler may not have a mode if its config could not be read
            if rdh and getattr(rdh, 'mode', None) == 'child':
                rdh.setStopFlag()
        except Exception as e:
            log.warning('unable to notify the master node')
            log.warning(e)

    # pid_file is None when no pid file was created for this mode
    if pid_file and os.path.isfile(pid_file):
        os.remove(pid_file)

    log.info('DONE')
    log.info('======================================')
    sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)
log.info('Saving Candidates only') + elif mcmode == MCMODE_PHASE1: + log.info('Loading Candidates if needed') + elif mcmode == MCMODE_ALL: + log.info('Full processing mode') + + if cml_args.verbose: + log.info('verbose flag set') + verbose = True + else: + verbose = False + # Run processing. If the auto run more is not on, the loop will break after one run previous_start_time = None + while True: # Clock for measuring script time @@ -1949,10 +2008,10 @@ def _breakHandler(signum, frame): dh = RMSDataHandle( cml_args.dir_path, dt_range=event_time_range, db_dir=cml_args.dbdir, output_dir=cml_args.outdir, - mcmode=cml_args.mcmode, max_trajs=max_trajs, remotehost=remotehost, verbose=cml_args.verbose) + mcmode=mcmode, max_trajs=max_trajs, verbose=verbose, archivemonths=cml_args.archivemonths) - # If there is nothing to process, stop, unless we're in mcmode 2 (processing_list is not used in this case) - if not dh.processing_list and cml_args.mcmode < 2: + # If there is nothing to process and we're in Candidate mode, stop + if not dh.processing_list and (mcmode & MCMODE_CANDS): log.info("") log.info("Nothing to process!") log.info("Probably everything is already processed.") @@ -1962,7 +2021,7 @@ def _breakHandler(signum, frame): ### GENERATE DAILY TIME BINS ### - if cml_args.mcmode != 2: + if mcmode != MCMODE_PHASE2: # Find the range of datetimes of all folders (take only those after the year 2000) proc_dir_dts = [entry[3] for entry in dh.processing_list if entry[3] is not None] proc_dir_dts = [dt for dt in proc_dir_dts if dt > datetime.datetime(2000, 1, 1, 0, 0, 0, @@ -1984,10 +2043,12 @@ def _breakHandler(signum, frame): proc_dir_dt_beg = min(proc_dir_dts) proc_dir_dt_end = max(proc_dir_dts) + bin_length = 0.25 if mcmode == MCMODE_CANDS else 1.0 + # Split the processing into daily chunks dt_bins = generateDatetimeBins( proc_dir_dt_beg, proc_dir_dt_end, - bin_days=1, tzinfo=datetime.timezone.utc, reverse=False) + bin_days=bin_length, tzinfo=datetime.timezone.utc, 
reverse=False) # check if we've created an extra bucket (might happen if requested timeperiod is less than 24h) if event_time_range is not None: @@ -1998,12 +2059,13 @@ def _breakHandler(signum, frame): dt_bins = [(dh.dt_range[0], dh.dt_range[1])] if dh.dt_range is not None: - # there's some data to process - log.info("") - log.info("ALL TIME BINS:") - log.info("----------") - for bin_beg, bin_end in dt_bins: - log.info("{:s}, {:s}".format(str(bin_beg), str(bin_end))) + # there's some data to process and we're in candidate mode + if mcmode & MCMODE_CANDS: + log.info("") + log.info("ALL TIME BINS:") + log.info("----------") + for bin_beg, bin_end in dt_bins: + log.info("{:s}, {:s}".format(str(bin_beg), str(bin_end))) ### ### @@ -2012,27 +2074,57 @@ def _breakHandler(signum, frame): # Go through all chunks in time for bin_beg, bin_end in dt_bins: - log.info("") - log.info("PROCESSING TIME BIN:") - log.info("{:s}, {:s}".format(str(bin_beg), str(bin_end))) - log.info("-----------------------------") - log.info("") + if mcmode & MCMODE_CANDS: + log.info("") + log.info("PROCESSING TIME BIN:") + log.info("{:s}, {:s}".format(str(bin_beg), str(bin_end))) + log.info("-----------------------------") + log.info("") - # Load data of unprocessed observations - if cml_args.mcmode != 2: dh.unpaired_observations = dh.loadUnpairedObservations(dh.processing_list, dt_range=(bin_beg, bin_end)) + log.info(f'loaded {len(dh.unpaired_observations)} observations') + + if mcmode != MCMODE_PHASE2: + # remove any trajectories that no longer exist on disk + dh.removeDeletedTrajectories() + # load computed trajectories from disk into sqlite + dh.loadComputedTrajectories(dt_range=(bin_beg, bin_end)) + # move any legacy failed traj into sqlite - # refresh list of calculated trajectories from disk - dh.removeDeletedTrajectories() - dh.loadComputedTrajectories(os.path.join(dh.output_dir, OUTPUT_TRAJ_DIR), dt_range=[bin_beg, bin_end]) - if cml_args.mcmode != 2: - 
dh.removeDuplicateTrajectories(dt_range=[bin_beg, bin_end]) # Run the trajectory correlator tc = TrajectoryCorrelator(dh, trajectory_constraints, cml_args.velpart, data_in_j2000=True, enableOSM=cml_args.enableOSM) bin_time_range = [bin_beg, bin_end] - tc.run(event_time_range=event_time_range, mcmode=cml_args.mcmode, bin_time_range=bin_time_range) + num_done = tc.run(event_time_range=event_time_range, mcmode=mcmode, bin_time_range=bin_time_range, verbose=verbose) + + if dh.RemoteDatahandler and dh.RemoteDatahandler.mode == 'child' and num_done > 0: + log.info('uploading to master node') + # close the databases and upload the data to the master node + if mcmode != MCMODE_PHASE2: + dh.closeTrajectoryDatabase() + dh.closeObservationsDatabase() + + dh.RemoteDatahandler.uploadToMaster(dh.output_dir, verbose=verbose) + + # truncate the tables here so they are clean for the next run + if mcmode != MCMODE_PHASE2: + dh.trajectory_db = TrajectoryDatabase(dh.db_dir, purge_records=True) + dh.observations_db = ObservationsDatabase(dh.db_dir, purge_records=True) + + if dh.RemoteDatahandler and dh.RemoteDatahandler.mode == 'master': + # move any uploaded data and then check and rebalance any pending cands or phase1s + dh.moveUploadedData(verbose=verbose) + dh.checkAndRedistribCands(wait_time=6, verbose=verbose) + + # If we're in either of these modes, the correlator will have scooped up available data + # from candidates or phase1 folders so no need to keep looping. 
+ if mcmode == MCMODE_PHASE1 or mcmode == MCMODE_PHASE2: + break + + if mcmode & MCMODE_CANDS: + dh.closeObservationsDatabase() + else: # there were no datasets to process log.info('no data to process yet') @@ -2042,8 +2134,15 @@ def _breakHandler(signum, frame): # Store the previous start time previous_start_time = copy.deepcopy(t1) + + # Break after one loop if auto mode is not on if cml_args.auto is None: + # clear the remote data ready flag to indicate we're shutting down + if dh.RemoteDatahandler and dh.RemoteDatahandler.mode == 'child': + dh.RemoteDatahandler.setStopFlag() + if pid_file and os.path.isfile(pid_file): + os.remove(pid_file) break else: @@ -2052,6 +2151,10 @@ def _breakHandler(signum, frame): wait_time = (datetime.timedelta(hours=AUTO_RUN_FREQUENCY) - (datetime.datetime.now(datetime.timezone.utc) - t1)).total_seconds() + # remove the remote data stop flag to indicate we're open for business + if dh.RemoteDatahandler and dh.RemoteDatahandler.mode == 'child': + dh.RemoteDatahandler.clearStopFlag() + # Run immediately if the wait time has elapsed if wait_time < 0: continue @@ -2070,4 +2173,4 @@ def _breakHandler(signum, frame): while next_run_time > datetime.datetime.now(datetime.timezone.utc): print("Waiting {:s} to run the trajectory solver... 
".format(str(next_run_time - datetime.datetime.now(datetime.timezone.utc)))) - time.sleep(2) + time.sleep(10) diff --git a/wmpl/Utils/Math.py b/wmpl/Utils/Math.py index bb6069b5..d916bc28 100644 --- a/wmpl/Utils/Math.py +++ b/wmpl/Utils/Math.py @@ -1113,11 +1113,13 @@ def generateDatetimeBins(dt_beg, dt_end, bin_days=7, utc_hour_break=12, tzinfo=N else: bin_beg = dt_beg + datetime.timedelta(days=i * bin_days) - bin_beg = bin_beg.replace(hour=int(utc_hour_break), minute=0, second=0, microsecond=0) + if bin_days > 0.999: + bin_beg = bin_beg.replace(hour=int(utc_hour_break), minute=0, second=0, microsecond=0) # Generate the bin ending edge bin_end = bin_beg + datetime.timedelta(days=bin_days) - bin_end = bin_end.replace(hour=int(utc_hour_break), minute=0, second=0, microsecond=0) + if bin_days > 0.999: + bin_end = bin_end.replace(hour=int(utc_hour_break), minute=0, second=0, microsecond=0) # Check that the ending bin is not beyond the end dt end_reached = False diff --git a/wmpl/Utils/remoteDataHandling.py b/wmpl/Utils/remoteDataHandling.py index 59f59a19..b4125642 100644 --- a/wmpl/Utils/remoteDataHandling.py +++ b/wmpl/Utils/remoteDataHandling.py @@ -23,176 +23,296 @@ import os import paramiko import logging -import glob import shutil +import uuid +import time -from wmpl.Utils.OSTools import mkdirP -from wmpl.Utils.Pickling import loadPickle +from configparser import ConfigParser log = logging.getLogger("traj_correlator") -def collectRemoteTrajectories(remotehost, max_trajs, output_dir): - """ - Collect trajectory pickles from a remote server for local phase2 (monte-carlo) processing - NB: do NOT use os.path.join here, as it will break on Windows - """ +class RemoteNode(): + def __init__(self, nodename, dirpath, capacity, mode, active=False): + self.nodename = nodename + self.dirpath = dirpath + self.capacity = int(capacity) + self.mode = int(mode) + self.active = active - ftpcli, remote_dir, sshcli = getSFTPConnection(remotehost) - if ftpcli is None: - return - - 
class RemoteDataHandler():
    """ Handles the exchange of data between a master node and its child nodes.

    In 'master' mode the config file lists the child nodes and the handler only holds that list.
    In 'child' mode the config file holds the sftp details of the master, and the handler is used
    to download work from the master and upload results to it.
    """

    # folders that must exist in the child's upload area on the master
    REMOTE_DIRS = ['files', 'files/candidates', 'files/phase1', 'files/trajectories',
                   'files/candidates/processed', 'files/phase1/processed']

    def __init__(self, cfg_file):
        """
        Create a remote data handler from a config file. If the file is missing or malformed
        a warning is logged and the handler is left with initialised = False.

        Parameters:
            cfg_file : full path to the remote processing config file
        """
        self.initialised = False

        # give every attribute a value up front so that a handler that failed to
        # initialise can still be inspected safely, eg by testing its mode
        self.mode = None
        self.nodenames = None
        self.nodes = None
        self.capacity = None

        self.host = None
        self.user = None
        self.key = None
        self.port = 22

        self.ssh_client = None
        self.sftp_client = None

        if not os.path.isfile(cfg_file):
            log.warning(f'unable to find {cfg_file}, not enabling remote processing')
            return

        cfg = ConfigParser()
        cfg.read(cfg_file)
        if 'mode' not in cfg.sections() or 'mode' not in cfg['mode']:
            log.warning('remote cfg: mode must be master or child, not enabling remote processing')
            return

        self.mode = cfg['mode']['mode'].lower()
        if self.mode not in ['master', 'child']:
            log.warning('remote cfg: mode must be master or child, not enabling remote processing')
            return

        if self.mode == 'master':
            if 'children' not in cfg.sections():
                log.warning('remote cfg: children section missing, not enabling remote processing')
                return

            # create a list of available nodes, disabling any that are malformed in the config file
            self.nodenames = list(cfg['children'].keys())
            self.nodes = []
            for nodename, value in cfg['children'].items():
                fields = [x.strip() for x in value.split(',')]
                if len(fields) != 3:
                    continue
                try:
                    self.nodes.append(RemoteNode(nodename, fields[0], fields[1], fields[2]))
                except ValueError:
                    log.warning(f'remote cfg: capacity or mode of {nodename} is not a number, node disabled')

            # the local machine is always the last entry in the list
            self.nodes.append(RemoteNode('localhost', None, -1, -1))
            activenodes = [n.nodename for n in self.nodes if n.capacity != 0]
            log.info(f' using nodes {activenodes}')

        else:
            # 'child' mode
            if 'sftp' not in cfg.sections() or 'key' not in cfg['sftp'] or 'host' not in cfg['sftp'] or 'user' not in cfg['sftp']:
                log.warning('remote cfg: sftp user, key or host missing, not enabling remote processing')
                return

            self.host = cfg['sftp']['host']
            self.user = cfg['sftp']['user']
            self.key = os.path.normpath(os.path.expanduser(cfg['sftp']['key']))
            if 'port' in cfg['sftp']:
                self.port = int(cfg['sftp']['port'])

        self.initialised = True
        return

    def getSFTPConnection(self, verbose=False):
        """
        Open the ssh and sftp connections to the master if they are not already open.

        Parameters:
            verbose : boolean, if true then log each step of the connection

        Returns:
            True if an sftp connection is available, False otherwise
        """
        if not self.initialised:
            return False

        if self.sftp_client:
            return True

        log.info(f'Connecting to {self.host}:{self.port} as {self.user}....')

        # the key is only configured in child mode
        if self.key is None or not os.path.isfile(os.path.expanduser(self.key)):
            log.warning(f'ssh keyfile {self.key} missing')
            return False

        self.ssh_client = paramiko.SSHClient()
        if verbose:
            log.info('created paramiko ssh client....')
        # NOTE(review): AutoAddPolicy accepts unknown host keys without verification
        self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            # loading the key can fail too, eg if it is not an RSA key, so do it inside the try
            pkey = paramiko.RSAKey.from_private_key_file(self.key)
            if verbose:
                log.info('connecting....')
            self.ssh_client.connect(hostname=self.host, username=self.user, port=self.port,
                pkey=pkey, look_for_keys=False, timeout=10)
            if verbose:
                log.info('connected....')
            self.sftp_client = self.ssh_client.open_sftp()
            if verbose:
                log.info('created client')
            return True

        except Exception as e:
            log.warning('sftp connection to remote host failed')
            log.warning(e)
            self.closeSFTPConnection()
            return False

    def closeSFTPConnection(self):
        """ Close the sftp and ssh connections if they are open. """
        if self.sftp_client:
            self.sftp_client.close()
            self.sftp_client = None
        if self.ssh_client:
            self.ssh_client.close()
            self.ssh_client = None
        return

    def putWithRetry(self, local_name, remname, retries=10, delay=1):
        """
        Upload a file, retrying if the upload fails.

        Parameters:
            local_name : full path of the local file
            remname : name of the file on the remote host
            retries : maximum number of attempts
            delay : time in seconds to wait between attempts

        Returns:
            True if the file was uploaded, False if every attempt failed
        """
        for _ in range(retries):
            try:
                self.sftp_client.put(local_name, remname)
                return True
            except Exception:
                time.sleep(delay)

        log.warning(f'upload of {local_name} failed after {retries} retries')
        return False

    def _ensureRemoteDirs(self):
        """ Make sure the expected folders exist on the remote host and are writeable. """
        for pth in self.REMOTE_DIRS:
            try:
                self.sftp_client.mkdir(pth)
            except Exception:
                # mkdir fails if the folder already exists, which is fine
                pass
            self.sftp_client.chmod(pth, 0o777)

    ########################################################
    # functions used by the client nodes

    def collectRemoteData(self, datatype, output_dir, verbose=False):
        """
        Collect trajectory or candidate pickles from a remote server for local processing

        Parameters:
            datatype : 'candidates' or 'phase1'
            output_dir : folder to put the pickles into, generally dh.output_dir
            verbose : boolean, if true then log each file downloaded

        Returns:
            False if no connection could be made or there was nothing to collect, otherwise True
        """

        if not self.initialised or not self.getSFTPConnection(verbose=verbose):
            return False

        self._ensureRemoteDirs()

        try:
            rem_dir = f'files/{datatype}'
            files = self.sftp_client.listdir(rem_dir)
            files = [f for f in files if '.pickle' in f and 'processing' not in f]
            if len(files) == 0:
                log.info('no data available at this time')
                self.closeSFTPConnection()
                return False

            local_dir = os.path.join(output_dir, datatype)
            os.makedirs(local_dir, exist_ok=True)

            num_obtained = 0
            for trajfile in files:
                fullname = f'{rem_dir}/{trajfile}'
                localname = os.path.join(local_dir, trajfile)
                if verbose:
                    log.info(f'downloading {fullname} to {localname}')

                for _ in range(10):
                    try:
                        self.sftp_client.get(fullname, localname)
                        break
                    except Exception:
                        time.sleep(1)
                else:
                    # never downloaded, so leave the remote file where it is to be collected
                    # next time rather than marking it as processed
                    log.warning(f'download of {fullname} failed after 10 retries')
                    if os.path.isfile(localname):
                        os.remove(localname)
                    continue

                # mark the file as taken so that it is not collected again
                try:
                    self.sftp_client.rename(fullname, f'{rem_dir}/processed/{trajfile}')
                except Exception:
                    try:
                        self.sftp_client.remove(fullname)
                    except Exception:
                        log.info(f'unable to rename or remove {fullname}')
                num_obtained += 1

            log.info(f'Obtained {num_obtained} {"trajectories" if datatype=="phase1" else "candidates"}')

        except Exception as e:
            log.warning('Problem with download')
            log.info(e)

        self.closeSFTPConnection()
        return True

    def uploadToMaster(self, source_dir, verbose=False):
        """
        upload the trajectory pickle and report to a remote host for integration
        into the solved dataset

        Local files are only moved or deleted once they have been uploaded successfully.

        Parameters:
            source_dir : root folder containing data, generally dh.output_dir
            verbose : boolean, if true then log each file uploaded
        """

        if not self.initialised or not self.getSFTPConnection(verbose=verbose):
            return

        self._ensureRemoteDirs()

        # upload any phase1 trajectories, moving each to the local 'processed' folder afterwards
        phase1_dir = os.path.join(source_dir, 'phase1')
        if os.path.isdir(phase1_dir):
            num_uploaded = 0
            proc_dir = os.path.join(phase1_dir, 'processed')
            os.makedirs(proc_dir, exist_ok=True)
            for fil in os.listdir(phase1_dir):
                local_name = os.path.join(phase1_dir, fil)
                if os.path.isdir(local_name):
                    continue
                remname = f'files/phase1/{fil}'
                if verbose:
                    log.info(f'uploading {local_name} to {remname}')
                if not self.putWithRetry(local_name, remname):
                    # leave the file in place so that it is uploaded on the next attempt
                    continue
                if os.path.isfile(os.path.join(proc_dir, fil)):
                    os.remove(os.path.join(proc_dir, fil))
                shutil.move(local_name, proc_dir)
                num_uploaded += 1
            if num_uploaded > 0:
                log.info(f'uploaded {num_uploaded} phase1 solutions')

        # now upload any data in the 'trajectories' folder, flattening it to make it simpler
        traj_dir = os.path.join(source_dir, 'trajectories')
        if os.path.isdir(traj_dir):
            num_uploaded = 0
            all_uploaded = True
            for (dirpath, dirnames, filenames) in os.walk(traj_dir):
                if len(filenames) == 0:
                    continue
                rem_path = f'files/trajectories/{os.path.basename(dirpath)}'
                try:
                    self.sftp_client.mkdir(rem_path)
                except Exception:
                    pass
                self.sftp_client.chmod(rem_path, 0o777)
                for fil in filenames:
                    local_name = os.path.join(dirpath, fil)
                    rem_file = f'{rem_path}/{fil}'
                    if verbose:
                        log.info(f'uploading {local_name} to {rem_file}')
                    if self.putWithRetry(local_name, rem_file):
                        num_uploaded += 1
                    else:
                        all_uploaded = False

            if all_uploaded:
                shutil.rmtree(traj_dir, ignore_errors=True)
            else:
                log.warning(f'some uploads failed, keeping {traj_dir}')

            if num_uploaded > 0:
                # NOTE(review): the division by two presumably reflects two files per trajectory - confirm
                log.info(f'uploaded {int(num_uploaded/2)} trajectories')

        # finally the databases, given a unique name so uploads from different nodes can't clash
        uuid_str = str(uuid.uuid4())
        for fname in ['observations', 'trajectories']:
            local_name = os.path.join(source_dir, f'{fname}.db')
            if os.path.isfile(local_name):
                rem_file = f'files/{fname}-{uuid_str}.db'
                if verbose:
                    log.info(f'uploading {local_name} to {rem_file}')
                self.putWithRetry(local_name, rem_file)

        log.info('uploaded databases')
        self.closeSFTPConnection()
        return

    def setStopFlag(self, verbose=False):
        """
        Create the 'stop' file in the upload area on the master.

        Parameters:
            verbose : boolean, currently unused
        """
        if not self.initialised or not self.getSFTPConnection():
            return
        try:
            readyfile = os.path.join(os.getenv('TMP', default='/tmp'), 'stop')
            with open(readyfile, 'w') as f:
                f.write('stop')
            self.sftp_client.put(readyfile, 'files/stop')
            log.info('set stop flag')
        except Exception:
            log.warning('unable to set stop flag, master may continue to assign data')
        # NOTE(review): short pause kept from the original before closing the connection - confirm it is needed
        time.sleep(2)
        self.closeSFTPConnection()
        return

    def clearStopFlag(self, verbose=False):
        """
        Remove the 'stop' file from the upload area on the master, if it is there.

        Parameters:
            verbose : boolean, currently unused
        """
        if not self.initialised or not self.getSFTPConnection():
            return
        try:
            self.sftp_client.remove('files/stop')
            log.info('removed stop flag')
        except Exception:
            # nothing to do if the flag was not set
            pass
        self.closeSFTPConnection()
        return