From e7f7bc1788ef2f6c562b08733e6345475b02c97a Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 10 Apr 2024 16:43:45 +0300 Subject: [PATCH] Extract mlapdv --- .../nate_optoBiasedChoiceWorld/task.py | 92 ++++++++++++------- .../task_parameters.yaml | 1 + projects/nate_optoBiasedChoiceWorld.py | 31 +++++++ pyproject.toml | 2 +- 4 files changed, 92 insertions(+), 34 deletions(-) diff --git a/iblrig_custom_tasks/nate_optoBiasedChoiceWorld/task.py b/iblrig_custom_tasks/nate_optoBiasedChoiceWorld/task.py index 4994554..a89f0af 100644 --- a/iblrig_custom_tasks/nate_optoBiasedChoiceWorld/task.py +++ b/iblrig_custom_tasks/nate_optoBiasedChoiceWorld/task.py @@ -7,9 +7,11 @@ Additionally the state machine is modified to add output TTLs for optogenetic stimulation """ import logging -import time +import sys +from argparse import ArgumentTypeError from pathlib import Path from typing import Literal +import warnings import numpy as np import yaml @@ -20,14 +22,18 @@ from importlib import reload import random -import sys -sys.path.append('C:\zapit-tcp-bridge\python') -import Python_TCP_Utils as ptu -from TCPclient import TCPclient +ZAPIT_PYTHON = r'C:\zapit-tcp-bridge\python' -num_cond = 52 #will need to change later - is there a function to automatically detect this?> +try: + assert Path(ZAPIT_PYTHON).exists() + sys.path.append(ZAPIT_PYTHON) + import Python_TCP_Utils as ptu + from TCPclient import TCPclient +except (AssertionError, ModuleNotFoundError): + warnings.warn( + 'Please clone https://github.com/Zapit-Optostim/zapit-tcp-bridge to ' + f'{Path(ZAPIT_PYTHON).parents[1]}', RuntimeWarning) -stim_location_history = [] log = logging.getLogger('iblrig.task') @@ -74,9 +80,6 @@ class Session(BiasedChoiceWorldSession): protocol_name = 'nate_optoBiasedChoiceWorld' extractor_tasks = ['TrialRegisterRaw', 'ChoiceWorldTrials', 'TrainingStatus'] - - - def __init__( self, *args, @@ -98,13 +101,23 @@ def __init__( p=[1 - probability_opto_stim, probability_opto_stim], size=NTRIALS_INIT, ).astype(bool) + self.trials_table['laser_location_idx'] = np.zeros(NTRIALS_INIT, dtype=int) + + def draw_next_trial_info(self, **kwargs): + """Draw next trial variables. + + This is called by the `next_trial` method before updating the Bpod state machine. This + subclass method generates the stimulation index which is sent to Zapit when arming the + laser on stimulation trials. + """ + if self.trials_table.at[self.trial_num, 'opto_stimulation']: + N = int(self.task_params.get('NUM_OPTO_COND', 52)) + self.trials_table.at[self.trial_num, 'laser_location_idx'] = random.randrange(1, N) def start_hardware(self): - - self.client = TCPclient(tcp_port=1488, tcp_ip='127.0.0.1') - self.client.close() # need to ensure is closed first; currently nowhere that this is defined at end of task! + self.client.close() # need to ensure is closed first; currently nowhere that this is defined at end of task! self.client.connect() super().start_hardware() # add the softcodes for the zapit opto stimulation @@ -112,44 +125,46 @@ def start_hardware(self): soft_code_dict.update({SOFTCODE_STOP_ZAPIT: self.zapit_stop_laser}) soft_code_dict.update({SOFTCODE_FIRE_ZAPIT: self.zapit_fire_laser}) self.bpod.register_softcodes(soft_code_dict) - def zapit_arm_laser(self): log.warning('Arming laser') - #this is where you define the laser stim (i.e., arm the laser) + # this is where you define the laser stim (i.e., arm the laser) - self.current_location_idx = random.randrange(1,int(num_cond)) + current_location_idx = self.trials_table.at[self.trial_num, 'laser_location_idx'] #hZP.send_samples( # conditionNum=current_location_idx, hardwareTriggered=True, logging=True #) - zapit_byte_tuple, zapit_int_tuple = ptu.gen_Zapit_byte_tuple(trial_state_command = 1, - arg_keys_dict = {'conditionNum_channel': True, 'laser_channel': True, - 'hardwareTriggered_channel': True, 'logging_channel': False, - 'verbose_channel': False}, - arg_values_dict = {'conditionNum': self.current_location_idx, 'laser_ON': True, - 'hardwareTriggered_ON': True, 'logging_ON': False, - 'verbose_ON': False}) + zapit_byte_tuple, zapit_int_tuple = ptu.gen_Zapit_byte_tuple( + trial_state_command=1, + arg_keys_dict={'conditionNum_channel': True, 'laser_channel': True, + 'hardwareTriggered_channel': True, 'logging_channel': False, + 'verbose_channel': False}, + arg_values_dict={'conditionNum': current_location_idx, 'laser_ON': True, + 'hardwareTriggered_ON': True, 'logging_ON': False, + 'verbose_ON': False} + ) response = self.client.send_receive(zapit_byte_tuple) log.warning(response) - stim_location_history.append(self.current_location_idx) def zapit_fire_laser(self): # just logging - actual firing will be triggered by the state machine via TTL - #this really only triggers a ttl and sends a log entry - no need to plug in code here + # this really only triggers a ttl and sends a log entry - no need to plug in code here log.warning('Firing laser') - def zapit_stop_laser(self): log.warning('Stopping laser') - zapit_byte_tuple, zapit_int_tuple = ptu.gen_Zapit_byte_tuple(trial_state_command = 0, - arg_keys_dict = {'conditionNum_channel': True, 'laser_channel': True, - 'hardwareTriggered_channel': True, 'logging_channel': False, - 'verbose_channel': False}, - arg_values_dict = {'conditionNum': self.current_location_idx, 'laser_ON': True, - 'hardwareTriggered_ON': False, 'logging_ON': False, - 'verbose_ON': False}) + current_location_idx = self.trials_table.at[self.trial_num, 'laser_location_idx'] + zapit_byte_tuple, zapit_int_tuple = ptu.gen_Zapit_byte_tuple( + trial_state_command=0, + arg_keys_dict={'conditionNum_channel': True, 'laser_channel': True, + 'hardwareTriggered_channel': True, 'logging_channel': False, + 'verbose_channel': False}, + arg_values_dict={'conditionNum': current_location_idx, 'laser_ON': True, + 'hardwareTriggered_ON': False, 'logging_ON': False, + 'verbose_ON': False} + ) response = self.client.send_receive(zapit_byte_tuple) def _instantiate_state_machine(self, trial_number=None): @@ -172,6 +187,11 @@ def _instantiate_state_machine(self, trial_number=None): @staticmethod def extra_parser(): """:return: argparse.parser()""" + def positive_int(value): + if (value := int(value)) <= 0: + raise ArgumentTypeError(f'"{value}" is an invalid positive int value') + return value + parser = super(Session, Session).extra_parser() parser.add_argument( '--probability_opto_stim', @@ -208,6 +228,12 @@ def extra_parser(): type=str, help='list of the state machine states where opto stim should be stopped', ) + parser.add_argument( + '--n_opto_cond', + default=DEFAULTS['NUM_OPTO_COND'], + type=positive_int, + help='the number (N) of preset conditions to draw from, where N > x > 0', + ) return parser diff --git a/iblrig_custom_tasks/nate_optoBiasedChoiceWorld/task_parameters.yaml b/iblrig_custom_tasks/nate_optoBiasedChoiceWorld/task_parameters.yaml index eaf8396..673f821 100644 --- a/iblrig_custom_tasks/nate_optoBiasedChoiceWorld/task_parameters.yaml +++ b/iblrig_custom_tasks/nate_optoBiasedChoiceWorld/task_parameters.yaml @@ -6,3 +6,4 @@ - error - reward 'PROBABILITY_OPTO_STIM': 0.2 # probability of optogenetic stimulation +'NUM_OPTO_COND': 52 # the number (N) of preset conditions to draw from, where N > x > 0 diff --git a/projects/nate_optoBiasedChoiceWorld.py b/projects/nate_optoBiasedChoiceWorld.py index 0d97e15..e4fd55f 100644 --- a/projects/nate_optoBiasedChoiceWorld.py +++ b/projects/nate_optoBiasedChoiceWorld.py @@ -6,7 +6,9 @@ The pipeline task subclasses, OptoTrialsBpod and OptoTrialsNidq, aren't strictly necessary. They simply assert that the laserStimulation datasets were indeed saved and registered by the Bpod extractor class. """ +import yaml import numpy as np +from packaging import version import ibllib.io.raw_data_loaders as raw from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes from ibllib.io.extractors.bpod_trials import BiasedTrials @@ -37,6 +39,13 @@ class TrialsOpto(BaseBpodTrialsExtractor): var_names = BiasedTrials.var_names + ('laser_intervals',) save_names = BiasedTrials.save_names + ('_ibl_laserStimulation.intervals.npy',) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.proj_version = version.parse(self.settings.get('PROJECT_EXTRACTION_VERSION', '0.0.0')) + if self.proj_version >= version.parse('0.3.0'): + self.var_names = BaseBpodTrialsExtractor.var_names + ('laser_mplapdv',) + self.save_names = BaseBpodTrialsExtractor.var_names + ('_ibl_laserStimulation.mlapdv.npy',) + def _extract(self, extractor_classes=None, **kwargs) -> dict: settings = self.settings.copy() if 'OPTO_STIM_STATES' in settings: @@ -53,7 +62,9 @@ def _extract(self, extractor_classes=None, **kwargs) -> dict: # Extract laser dataset laser_intervals = [] + location_index = [] for trial in filter(lambda t: t['opto_stimulation'], self.bpod_trials): + location_index.append(trial.get('laser_location_idx', 0)) states = trial['behavior_data']['States timestamps'] # Assumes one of these states per trial: takes the timestamp of the first matching state start = next((v[0][0] for k, v in states.items() if k in settings['OPTO_TTL_STATES']), np.nan) @@ -61,4 +72,24 @@ def _extract(self, extractor_classes=None, **kwargs) -> dict: laser_intervals.append((start, stop)) out['laser_intervals'] = np.array(laser_intervals, dtype=np.float64) + # Extract laser coordinates + if self.proj_version >= version.parse('0.3.0'): + location_index = np.fromiter(filter(None, location_index), dtype=int) + assert len(location_index) == out['laser_intervals'].shape[0] + out['laser_mplapdv'] = np.full((out['laser_intervals'].shape[0], 3), np.NaN) + # Load lookup table + try: + zapit_file = next(self.alf_path.glob('zapit_log_*.yml')) + except StopIteration: + raise FileNotFoundError('Failed to load zapit log file.') + + with open(zapit_file, 'r') as fp: + zapit = yaml.safe_load(fp) + if any(x['Type'] != 'unilateral_points' for x in (v for k, v in zapit.items() if k.startswith('stimLocations'))): + raise NotImplementedError # TODO verify and document + for i in np.unique(location_index): + location = zapit[f'stimLocations{i:02}'] + mlapdv = (location['ML'][0], location['AP'][0], 0.) # TODO ensure len == 3 + out['laser_mplapdv'][location_index == i, :] = mlapdv + return {k: out[k] for k in self.var_names} # Ensures all datasets present and ordered diff --git a/pyproject.toml b/pyproject.toml index e1e2320..797da0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "project_extraction" -version = "0.4.1" +version = "0.3.0" description = "Custom extractors for satellite tasks" dynamic = [ "readme" ] keywords = [ "IBL", "neuro-science" ]