-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsim_utils.py
More file actions
481 lines (372 loc) · 16 KB
/
sim_utils.py
File metadata and controls
481 lines (372 loc) · 16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
"""Utility functions for simulation setup, runs and handling data.
"""
import numpy as np
from itertools import zip_longest
import h5py
import json
import os
os.environ["NEURON_MODULE_OPTIONS"] = "-nogui" #Stops no gui warnings
import argparse
import importlib
import logging
def network_intialize(params):
    """Initialize the network and set up instrumentation.

    Builds the network from an adjacency matrix that is either freshly
    generated (read from the run cache) or loaded from a previously saved
    matrix file, then applies the instrumentation module named in the
    parameters.

    Args:
        params (dict or Param): Parameter dictionary. Keys read here:
            'build_conn_matrix', 'conn_id', 'sim_id', 'sim_num',
            'matrix_id', 'instr_id'.

    Returns:
        network: Network object that includes cells and instrumentations.

    Raises:
        FileNotFoundError: If the saved matrix file cannot be located
            (a note with the attempted path is attached; requires
            Python 3.11+ for ``add_note``).
    """
    from network import Network
    # If asked to build matrix, load from cache. Otherwise load a previously saved matrix.
    if params['build_conn_matrix']:
        file = h5py.File(
            f"cache/matrix_{params['conn_id']}_{params['sim_id']}_{params['sim_num']}.hdf5", "r"
        )
        adj_matrix = file["matrix"]
    else:
        try:
            file = h5py.File(
                f"network_configs/connections/saved_matrices/matrix_"
                f"{params['conn_id']}_{params['matrix_id']}.hdf5", "r"
            )
        except FileNotFoundError as err:
            err.add_note(f"Cannot find saved matrix network_configs/connections/"
                         f"saved_matrices/matrix_{params['conn_id']}_{params['matrix_id']}.hdf5")
            raise err
        adj_matrix = file["matrix"]
    network = Network(0, adj_matrix, params)  # initialize grid cells
    # NOTE(review): closing here assumes Network has copied what it needs
    # from the (lazy) h5py dataset during __init__ — confirm.
    file.close()
    # Add instrumentation: module is selected by the 'instr_id' parameter.
    setup_instrumentation = importlib.import_module(f"network_configs.instrument"
                                                    f"ations.{params['instr_id']}_instr").setup_instrumentation
    setup_instrumentation(network)
    return network
def load_sim_params(sim_id: str, file_path: str = None) -> dict:
    """Load simulation parameters.

    Args:
        sim_id (str): Simulation ID to load the parameters for.
        file_path (str, optional): Direct path to the JSON file containing
            the simulation parameters. If not provided, the parameters are
            located for the given `sim_id`.

    Returns:
        dict: The simulation parameters loaded from the JSON file.
    """
    # PEP 8: compare against None with `is`, not `==`.
    if file_path is None:
        data_dir = locate_data_dir(sim_id)
        file_path = f"{data_dir}{sim_id}/{sim_id}.json"
    return json_read(file_path)
def locate_data_dir(sim_id: str) -> str:
    """Find the location of the data directory for a given simulation ID.

    The following predefined locations are probed in order:

    1. "data/" (local)
    2. "/data/{user}/data/" (global)
    3. "/data/{user}/" (ada)

    The {user} placeholder is the current username taken from the
    "USERNAME" or "USER" environment variable.

    Args:
        sim_id (str): The sim ID to locate the data directory for.

    Returns:
        str: The location of the data directory.

    Raises:
        FileNotFoundError: If no candidate directory contains the sim ID.
    """
    user = os.getenv("USERNAME") or os.getenv("USER")
    # Candidate data roots, checked in priority order.
    candidates = ("data/", f"/data/{user}/data/", f"/data/{user}/")
    for root in candidates:
        if os.path.exists(f"{root}{sim_id}"):
            return root
    raise FileNotFoundError(f"Cannot locate data for sim ID: {sim_id}")
def json_save(obj: dict, fname: str):
    """Wrapper that serializes a dictionary to a JSON file.

    Args:
        obj (dict): The dictionary object to save.
        fname (str): The file name to save the dictionary to.
    """
    # dumps + write produces the same bytes as json.dump(obj, fh, indent=0).
    serialized = json.dumps(obj, indent=0)
    with open(fname, "w") as fh:
        fh.write(serialized)
def json_read(fname: str) -> dict:
    """Wrapper that parses a JSON file into a dictionary.

    Args:
        fname (str): The file name to read the dictionary from.

    Returns:
        dict: The dictionary object read from the JSON file.
    """
    with open(fname, "r") as fh:
        return json.load(fh)
def json_modify(obj: dict, fname: str):
    """Merge a dictionary into a JSON file.

    If the file exists, its contents are updated with the key-value pairs
    from *obj*; otherwise a new file containing *obj* is created.

    Args:
        obj (dict): Key-value pairs to be added or updated in the JSON file.
        fname (str): The file name (including path) of the JSON file.
    """
    if not os.path.isfile(fname):
        json_save(obj, fname)
        return
    merged = json_read(fname)
    merged.update(obj)  # same effect as assigning each key in a loop
    json_save(merged, fname)
def list_to_numpy(LoL: list, fill: float = np.nan) -> np.ndarray:
    """Convert a list of lists (LoL) to a rectangular NumPy array.

    Rows shorter than the longest row are padded with *fill*. Used to
    convert ragged lists of spike times to an array for saving in HDF5.

    Args:
        LoL (list of lists): The input list of lists to convert.
        fill (float, optional): Padding value for missing entries.
            Defaults to np.nan.

    Returns:
        numpy.ndarray: Array with one row per inner list, padded to the
        longest row with *fill*.
    """
    padded_columns = zip_longest(*LoL, fillvalue=fill)
    return np.array([col for col in padded_columns]).T
def check_sim_dup(sim_id: str, sim_num: int) -> bool:
    """Check whether a simulation number exists in the HDF5 file for a
    given simulation ID.

    Args:
        sim_id (str): The ID of the simulation.
        sim_num (int): The simulation number to check for duplication.

    Returns:
        bool: True if the simulation number exists in the file, False otherwise.
    """
    fname = f"data/sim_spikes_data_m_{sim_id}.hdf5"
    if os.path.exists(fname):
        # Open read-only: the previous "a" mode acquired a write lock and
        # could create/modify the file, which an existence check must not do.
        with h5py.File(fname, "r") as file:
            if str(sim_num) in file.keys():
                return True
    return False
def find_sim_num(params: dict, param_check: dict) -> dict:
    """Find simulation numbers that were run with a given set of parameters.

    Useful for analysis, e.g. finding the simulation number that was run
    with ``si_peak=1.0``. Values are compared by string prefix:
    ``str(value).startswith(str(check_value))``. Keys in *param_check*
    that a simulation does not define are ignored for that simulation.

    Args:
        params (dict or Param): Dictionary of all simulations generated
            after a run of multi simulation.
        param_check (dict): Parameters to check for. E.g: {si_peak: 1}

    Returns:
        matches (dict): Subset of the global dictionary containing only the
            simulations that match the given parameters in param_check.
    """
    matches = {}
    for sim_num, sim_param in params.items():
        # A sim matches when every check key it defines passes the
        # string-prefix comparison.
        if all(
            str(sim_param[key]).startswith(str(expected))
            for key, expected in param_check.items()
            if key in sim_param
        ):
            matches[sim_num] = sim_param
    return matches
def get_sim_num(iters: tuple, n_iters: tuple) -> int:
    """Calculate the simulation number for the running iterator indices.

    Useful for analysis.

    Args:
        iters (tuple): Current iteration indices.
        n_iters (tuple): Total size of each iterator.

    Returns:
        int: Simulation number matching the given iterator indices.
    """
    assert len(iters) == len(n_iters)  # inputs must have equal dimensions
    sizes = np.array(n_iters)
    sim_num = 0
    # Row-major (C-order) flattening: each dimension's stride is the
    # product of the sizes that follow it.
    for dim, idx in enumerate(iters):
        sim_num += idx * np.prod(sizes[dim + 1:])
    return sim_num
def remove_nodes_from_params(params):
    """Strip the trailing "-node<digits>" suffix from each sim_id in a
    multi params dict.

    Args:
        params (dict): Multi params dict (modified in place).

    Returns:
        dict: The same params dict with the suffix removed from every sim_id.
    """
    import re
    node_suffix = re.compile(r"-node\d+$")
    for sim_param in params.values():
        sim_param["sim_id"] = node_suffix.sub("", sim_param["sim_id"])
    return params
def sim_setup_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for simulation setup.

    The parser is initialized and processed here to avoid cluttering
    the main setup file.

    Returns:
        argparse.ArgumentParser: The argument parser with the arguments added.
    """
    parser = argparse.ArgumentParser(description="Run a single simulation")
    parser.add_argument(
        "specs_file",
        type=str,
        help="specificatons file",
    )
    # -v stores logging.DEBUG; without it the level defaults to logging.INFO.
    parser.add_argument(
        "-v", "--verbose",
        action="store_const",
        const=logging.DEBUG,
        default=logging.INFO,
        help="show verbose output",
    )
    parser.add_argument(
        "-o", "--overwrite_data",
        action="store_true",
        help="overwrite data",
    )
    return parser
def sim_run_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the simulation run.

    The parser is initialized and processed here to avoid cluttering
    the main run file.

    Returns:
        argparse.ArgumentParser: The argument parser with the arguments added.
    """
    parser = argparse.ArgumentParser(description="Run a single simulation")
    parser.add_argument(
        "-i", "--sim_id",
        type=str,
        required=True,
        help="simulation ID",
    )
    # -v stores logging.DEBUG; without it the level defaults to logging.INFO.
    parser.add_argument(
        "-v", "--verbose",
        action="store_const",
        const=logging.DEBUG,
        default=logging.INFO,
        help="show verbose output",
    )
    return parser
def log_from_rank_0(logger: logging.Logger, rank: int, msg: str, level: int = logging.INFO):
    """Log a message only from the rank-0 process.

    Args:
        logger (logging.Logger): The logger object to use for logging.
        rank (int): The rank of the process.
        msg (str): The message to be logged.
        level (int, optional): Logging level, defaults to logging.INFO.
    """
    if rank != 0:
        return
    logger.log(level, msg)
def process_data_root(data_root: str) -> str:
    """Ensure the given data root path ends with a slash.

    Args:
        data_root (str): The root path to the data directory.

    Returns:
        str: The data root path ending with a slash.
    """
    # endswith handles the empty string safely; the previous
    # data_root[-1] indexing raised IndexError for "".
    return data_root if data_root.endswith("/") else data_root + "/"
def load_spikes(sim_id: str, sim_num: int = 0) -> tuple:
    """Load spike data for a given simulation ID.

    Data is looked up in the directories probed by ``locate_data_dir``
    and both stellate and interneuron spikes are returned.

    #. "data/" (local)
    #. "/data/{user}/data/" (global)

    Args:
        sim_id (str): Simulation ID to load spikes from.
        sim_num (int, optional): Simulation number to load spikes from,
            by default 0 (single sim).

    Returns:
        tuple
            A tuple containing two lists of lists:
            - stell_spikes_l: List of spike times for stellate cells.
            - intrnrn_spikes_l: List of spike times for interneurons.
    """
    data_loc = f"{locate_data_dir(sim_id)}{sim_id}/"
    group = str(sim_num)  # sim_num are stored as string keys in .hdf5

    def _read_ragged(path, dset):
        # NaN entries are padding added when the ragged spike lists were
        # saved; drop them to recover per-cell spike-time lists.
        with h5py.File(path, "r") as fh:
            arr = np.array(fh[f"{group}/{dset}"][:])
        return [list(row[~np.isnan(row)]) for row in arr]

    stell_spikes_l = _read_ragged(data_loc + f"stell_spks_{sim_id}.hdf5", "stell_spks")
    intrnrn_spikes_l = _read_ragged(data_loc + f"intrnrn_spks_{sim_id}.hdf5", "intrnrn_spks")
    return stell_spikes_l, intrnrn_spikes_l
def get_git_commit_hash():
    """Retrieve the current short Git commit hash.

    Returns:
        str: The Git commit hash as a string if the command is successful.
        None: If git exits with an error (e.g. not a repository) or the
            git executable is not installed.
    """
    import subprocess
    try:
        out = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            capture_output=True,
            encoding="utf-8",
            check=True,  # raises CalledProcessError on non-zero exit
        )
        return out.stdout.strip()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError: git executable missing — previously this
        # escaped uncaught, contradicting the documented None return.
        return None
class ProgressBar:
    """Progress bar for simulations.

    Not tested for multiple simulations.

    :meta private:
    """

    def __init__(self, total, pc=None):
        # Windows terminals may not render ANSI colour codes, so use '#'.
        if os.name == 'nt':
            self.marker = "#"
        else:
            self.marker = '\x1b[31m█\x1b[39m'
        if self._check_rank(pc):
            self.total = total
            self.length = 50  # bar width in characters
            # NOTE(review): curr_progress appears unused by the other
            # methods (they track self.iteration) — kept for compatibility.
            self.curr_progress = 0

    def finish(self, pc=None):
        """Draw the bar at 100% and terminate the line with a newline."""
        if self._check_rank(pc):
            self.iteration = self.total
            self._render(end="\n", flush=True)

    def increment(self, iteration, pc=None, flush=False):
        """Advance the bar to *iteration* and redraw it in place."""
        if self._check_rank(pc):
            self.iteration = iteration
            self._render(end="\r", flush=flush)

    def _render(self, end, flush):
        # Shared drawing code for finish()/increment() — previously
        # duplicated verbatim in both methods.
        percent = 100 * (self.iteration / float(self.total))
        filled_length = int(self.length * self.iteration // self.total)
        bar = self.marker * filled_length + '-' * (self.length - filled_length)
        print(f"Progress ({self.iteration} of {self.total} ms): |{bar}| {percent:.2f}%",
              end=end, flush=flush)

    def _check_rank(self, pc):
        # Only rank 0 draws; with no parallel context assume a single process.
        if pc is None:
            return True
        return pc.id() == 0
def get_module_from_path(file_path: str) -> str:
    """Convert a file path to a module name.

    Used to import a specs file that is passed as an argument to the
    simulation setup scripts.

    Args:
        file_path (str): The relative file path (e.g., "specs/s_template.py")

    Returns:
        str: The corresponding module name (e.g., "specs.s_template")
    """
    # Drop the extension, then turn path separators into dots.
    stem = os.path.splitext(file_path)[0]
    return stem.replace(os.path.sep, ".")
def load_data(sim_id: str, data_id: str, cell_n: int = 0, sim_num: str = 0) -> np.ndarray:
    """Load non-spiking data of one cell for a given simulation ID.

    Args:
        sim_id (str): Simulation ID to load data from.
        data_id (str): Data ID. e.g. 'stell_v'
        cell_n (int, optional): Cell number to load data for, by default 0.
        sim_num (str, optional): Simulation number to load data from, by default 0.

    Returns:
        np.ndarray: The data array loaded from the HDF5 file
    """
    group = str(sim_num)  # sim numbers are stored as string keys in .hdf5
    data_dir = locate_data_dir(sim_id)
    with h5py.File(f'{data_dir}{sim_id}/{data_id}_{sim_id}.hdf5', 'r') as fh:
        return np.array(fh[group][data_id][cell_n])
def load_full_data(sim_id, data_id, sim_num=0):
    """Load non-spiking data of all cells for a given simulation ID and number.

    Args:
        sim_id (str): Simulation ID to load data from.
        data_id (str): Data ID. e.g. 'stell_v'
        sim_num (str, optional): Simulation number to load data from, by default 0.

    Returns:
        np.ndarray: The data array loaded from the HDF5 file.
    """
    data_dir = locate_data_dir(sim_id)
    with h5py.File(f'{data_dir}{sim_id}/{data_id}_{sim_id}.hdf5', 'r') as fh:
        return np.array(fh[str(sim_num)][data_id])
def get_multiples_with_remainder(N, k):
    """Return the multiples of `k` up to and including `N`.

    If `N` is not itself a multiple of `k`, it is appended as the final
    entry, so the sequence always ends at `N` (for positive `N`).

    Args:
        N (int): The upper limit (inclusive).
        k (int): The step size.

    Returns:
        np.ndarray: Integer array of multiples, with `N` appended when it
        is not already the last element.
    """
    N = int(N)
    k = int(k)
    multiples = list(range(k, N + 1, k))
    # Append N when it is not the last multiple already (or when no
    # multiple fits below N at all).
    remainder = N - multiples[-1] if multiples else N
    if remainder:
        multiples.append(N)
    return np.array(multiples, dtype="int")