diff --git a/.gitignore b/.gitignore index d230542..01066c7 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,6 @@ eggs/ core* report* CMakeCache.txt +.nextflow +*.bdm.db diff --git a/Dockerfile.skyweavercpp b/Dockerfile.skyweavercpp index 90df0fc..2b3f093 100644 --- a/Dockerfile.skyweavercpp +++ b/Dockerfile.skyweavercpp @@ -60,7 +60,7 @@ WORKDIR /usr/src/skyweaver COPY . . RUN cmake -S . -B build/ -DARCH=native -DPSRDADA_INCLUDE_DIR=/usr/local/include/psrdada \ -DPSRDADACPP_INCLUDE_DIR=/usr/local/include/psrdada_cpp -DSKYWEAVER_NANTENNAS=64 \ - -DSKYWEAVER_NBEAMS=128 -DSKYWEAVER_NCHANS=64 -DSKYWEAVER_IB_SUBTRACTION=1 -DBUILD_SUBMODULES=1 \ - -DENABLE_TESTING=1 -DENABLE_BENCHMARK=1 &&\ + -DSKYWEAVER_NBEAMS=800 -DSKYWEAVER_NCHANS=64 -DSKYWEAVER_IB_SUBTRACTION=1 -DBUILD_SUBMODULES=1 \ + -DENABLE_TESTING=1 -DENABLE_BENCHMARK=1 -DSKYWEAVER_CB_TSCRUNCH=4 -DSKYWEAVER_IB_TSCRUNCH=4 -DSKYWEAVER_CB_FSCRUNCH=1 -DSKYWEAVER_IB_FSCRUNCH=1 &&\ make -C build/ -j 16 && make -C build/ install diff --git a/Dockerfile.skyweaverpy b/Dockerfile.skyweaverpy index 27077fd..ac7686a 100644 --- a/Dockerfile.skyweaverpy +++ b/Dockerfile.skyweaverpy @@ -20,7 +20,8 @@ RUN apt-get update && apt-get install -y \ # Install Python dependencies RUN pip install --upgrade pip && \ - pip install pytest + pip install pytest && \ + pip install pandas # Copy the rest of your application code into the container diff --git a/README.md b/README.md index ffc3e94..038f3bd 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,177 @@ # skyweaver Implementation of an offline FBFUSE beamformer for the MeerKAT telescope + +# Installation + +It is easiest to use the software inside a Docker container. Two Dockerfiles are included as part of this repository, to compile the C++ and Python parts respectively. 
+ +# Usage + +## Step 1: Get delays for the beamformer + +### Start with the skyweaver python CLI + +```bash +alias sw="python /path/to/python/skyweaver/cli.py" +sw --help +``` +This will print: +```console +usage: skyweaver [-h] {metadata,delays} ... + +positional arguments: + {metadata,delays} sub-command help + metadata Tools for observation metadata files + delays Tools for delay files + +optional arguments: + -h, --help show this help message and exit +``` + +### Get the metadata for the corresponding observation + +This is done outside this repository + +### Obtain metadata information + +```bash +sw metadata show +``` + +This will produce an output like the following: +```console +sw metadata show bvrmetadata_2024-02-16T10\:50\:46_72275.hdf5 +---------------Array configuration---------------- +Nantennas: 57 +Subarray: m000,m002,m003,m004,m005,m007,m008,m009,m010,m011,m012,m014,m015,m016, + m017,m018,m019,m020,m021,m022,m023,m024,m025,m026,m027,m029,m030,m031, + m032,m033,m034,m035,m036,m037,m038,m039,m040,m041,m042,m043,m044,m045, + m046,m048,m049,m050,m051,m053,m054,m056,m057,m058,m059,m060,m061,m062, + m063 +Centre frequency: 1284000000.0 Hz +Bandwidth: 856000000.0 Hz +Nchannels: 4096 +Sync epoch (UNIX): 1708039531.0 +Project ID: - +Schedule block ID: - +CBF version: cbf_dev +--------------------Pointings--------------------- +#0 J1644-4559 2024-02-16T11:16:08.957000000 until 2024-02-16T11:21:04.865000000 (UTC) + 1708082168.957 until 1708082464.865 (UNIX) + 72996182384029 until 73502776880016 (SAMPLE CLOCK) +#1 J1644-4559_Offset1 2024-02-16T11:21:22.092000000 until 2024-02-16T11:26:15.682000000 (UTC) + 1708082482.092 until 1708082775.682 (UNIX) + 73532269504013 until 74034895583866 (SAMPLE CLOCK) +#2 J1644-4559_Offset2 2024-02-16T11:26:34.536000000 until 2024-02-16T11:31:25.723000000 (UTC) + 1708082794.536 until 1708083085.723 (UNIX) + 74067173632022 until 74565685776084 (SAMPLE CLOCK) +#3 M28 2024-02-16T11:31:52.503000000 until 
2024-02-16T12:01:51.651000000 (UTC) + 1708083112.503 until 1708084911.651 (UNIX) + 74611533136035 until 77691674512039 (SAMPLE CLOCK) +#4 J0437-4715 2024-02-16T12:03:06.117000000 until 2024-02-16T12:08:04.364000000 (UTC) + 1708084986.117 until 1708085284.364 (UNIX) + 77819160304176 until 78329759168140 (SAMPLE CLOCK) +``` + +### Create a config file in .yml format +Here is an example - there are comments to explain each parameter. + +```.yml +created_by: Vivek +beamformer_config: + # The total number of beams to be produced (must be a multiple of 32). This needs to be <= the number that SKYWEAVER is compiled for. + total_nbeams: 800 + # The number of time samples that will be accumulated after detection, inside the beamformer + tscrunch: 4 + # The number of frequency channels that will be accumulated after detection, inside the beamformer + # Will be coerced to 1 if coherent dedispersion is specified. + fscrunch: 1 + # The Stokes product to be calculated in the beamformer (I=0, Q=1, U=2, V=3) + stokes_mode: 0 + # Enable CB-IB subtraction in the beamformer + subtract_ib: True + + # Dispersion measure for coherent / incoherent dedispersion in pc cm^-3 + # A dispersion plan definition string " + # "(::::) or " + # "(:) " + # "or ()") +# Each DD plan is a "Stream" with zero indexed stream-ids + +ddplan: + - "478.6:478.6:478.6:1:1" #stream-id=0 + - "0.00:478.6:478.6:1:1" #stream-id=1 + +# every beamset can contain arbitrary set of antennas, corresponding targeted beams, and tiled beams +# total number of beams across all beamsets should be <= the number of beams that SKYWEAVER is compiled for. 
+beam_sets: + + - antenna_set: ['m000','m002','m003','m004','m005','m007','m008','m009','m010','m011', + 'm012','m014','m015','m016','m017','m018','m019','m020','m021','m022', + 'm023','m024','m025','m026','m027','m029','m030','m031','m032','m033', + 'm034','m035','m036','m037','m038','m039','m040','m041','m042','m043', + 'm044','m045','m046','m048','m049','m050','m051','m053','m054','m056', + 'm057','m058','m059','m060','m061','m062','m063'] + beams: [] + tilings: + - nbeams: 32 + reference_frequency: null + target: "J1644-4559,radec,16:44:49.273,-45:59:09.71" + overlap: 0.9 +``` + + + +### Create delay file for the corresponding pointing + +```bash +sw delays create --pointing-idx 0 --outfile J1644-4559_pointing_0.delays --step 4 bvrmetadata_2024-02-16T10\:50\:46_72275.hdf5 J1644-4559_boresight.yaml +``` + +This produces a `.delays` file used for beamforming, and a `.targets` file that contains beam metadata. There are also other files produced here for reproducibility and for visualisation. + +## Step 2: Initialise input and compile skyweaver + +### Create a list of dada files that correspond to the pointing + +```console +ls /b/u/vivek/00_DADA_FILES/J1644-4559/2024-02-16-11\:16\:08/L/48/*dada -1 > /bscratch/vivek/skyweaver_tests/J1644-4559_boresight_dadafiles.list +``` + +### Compile skyweaver + +This is done inside the dockerfile too. Either edit that to produce a docker image that has the software precompiled, or compile separately. + +```bash + cmake -S . 
-B $cmake_tmp_dir -DENABLE_TESTING=0 -DCMAKE_INSTALL_PREFIX=$install_dir -DARCH=native -DPSRDADA_INCLUDE_DIR=/usr/local/include/psrdada -DPSRDADACPP_INCLUDE_DIR=/usr/local/include/psrdada_cpp -DSKYWEAVER_NANTENNAS=64 -DSKYWEAVER_NBEAMS=${nbeams} -DSKYWEAVER_NCHANS=64 -DSKYWEAVER_IB_SUBTRACTION=1 -DCMAKE_BUILD_TYPE=RELEASE -DSKYWEAVER_CB_TSCRUNCH=${tscrunch} -DSKYWEAVER_IB_TSCRUNCH=${tscrunch}; + cd $cmake_tmp_dir + make -j 16 +``` +This compilation produces two binaries: `skyweavercpp` and `skycleaver` +## Step 3: Run the beamformer + +```bash + +/path/to/skyweavercpp --input-file J1644-4559_boresight_dadafiles.list --delay-file J1644-4559_pointing_0.delays --output-dir=/bscratch/vivek/skyweaver_out --gulp-size=32768 --log-level=warning --output-level=12 --stokes-mode I +``` + +Change the output level to 7 for bright pulsars like J1644-4559. + +This will produce `.tdb` files for the corresponding bridge. Run Step 3 for ALL 64 bridges with their corresponding dada file lists. These are DADA format files whose dimensions are ordered TIME, INCOHERENT DM, BEAM. For Stokes I mode, the datatype is `int8_t`. For IQUV it is `char4`. + +## Step 4: Cleave all bridges to form Filterbanks + +Here we cleave the 64 TDB[I/Q/U/V/IV/QU/IQUV] files to produce `NDM*NBEAMS*NSTOKES` number of T(F=64) files. + +To run this, do: + +```bash + +/path/to/skycleaver -r /bscratch/vivek/skyweaver_out --output-dir /bscratch/vivek/skycleaver_out --nsamples-per-block 65536 --nthreads 32 --stream-id 0 --targets_file /bscratch/vivek/skyweaver_out/swdlays_J1644-4559.targets --out-stokes I --required_beams 0 + +``` + +This will produce a standard sigproc format `.fil` file that can be used for traditional processing. 
+ + + diff --git a/cmake/compiler_settings.cmake b/cmake/compiler_settings.cmake index 879cf70..500e4f0 100644 --- a/cmake/compiler_settings.cmake +++ b/cmake/compiler_settings.cmake @@ -6,9 +6,9 @@ if (NOT CMAKE_BUILD_TYPE) endif () #set(CMAKE_VERBOSE_MAKEFILE 1) - +set (CMAKE_CXX_STANDARD 20) # Set compiler flags -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -fopenmp -std=gnu++20") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -fopenmp") set(ARCH "broadwell" CACHE STRING "target architecture (-march=native, x86-64), defautls to broadwell") diff --git a/cpp/skyweaver/CMakeLists.txt b/cpp/skyweaver/CMakeLists.txt index 0f49287..e2a605a 100644 --- a/cpp/skyweaver/CMakeLists.txt +++ b/cpp/skyweaver/CMakeLists.txt @@ -18,7 +18,6 @@ set(skyweaver_src src/Timer.cpp src/Transposer.cu src/WeightsManager.cu - src/SkyCleaver.cu src/SigprocHeader.cpp ) @@ -43,6 +42,8 @@ set(skyweaver_inc WeightsManager.cuh MultiFileWriter.cuh SigprocHeader.hpp + SkyCleaver.hpp + SkyCleaverConfig.hpp ) set(SKYWEAVER_LIBRARIES ${CMAKE_PROJECT_NAME} ${DEPENDENCY_LIBRARIES}) @@ -111,7 +112,7 @@ target_link_libraries(skyweavercpp OpenMP::OpenMP_CXX) install(TARGETS skyweavercpp DESTINATION bin) -cuda_add_executable(skycleaver src/skycleaver_cli.cu) +add_executable(skycleaver src/skycleaver_cli.cpp) target_link_libraries(skycleaver ${SKYWEAVER_LIBRARIES} ${DEPENDENCY_LIBRARIES} @@ -138,4 +139,4 @@ if(ENABLE_BENCHMARK) COMMAND beamformer_bench --benchmark_counters_tabular=true WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) add_subdirectory(benchmark) -endif() \ No newline at end of file +endif() diff --git a/cpp/skyweaver/DescribedVector.hpp b/cpp/skyweaver/DescribedVector.hpp index 6fcdd0a..9b11578 100644 --- a/cpp/skyweaver/DescribedVector.hpp +++ b/cpp/skyweaver/DescribedVector.hpp @@ -366,12 +366,13 @@ struct DescribedVector { * @param freqs */ void frequencies(FrequenciesType const& freqs) - { + { if(freqs.size() != get_dim_extent()) { throw std::runtime_error("Invalid number of frequecies 
passed."); } _frequencies_stale = false; _frequencies = freqs; + } /** @@ -386,6 +387,7 @@ struct DescribedVector { } _frequencies_stale = false; _frequencies.resize(1, freq); + _frequencies[0] = freq; } /** @@ -690,10 +692,16 @@ template using FPAStatsD = DescribedVector, FreqDim, PolnDim, AntennaDim>; -//skycleaver output vectors +//skycleaver vectors +template +using TDBPowersStdH = DescribedVector, + TimeDim, + DispersionDim, + BeamDim, + PolnDim>; template -using TFPowersH = DescribedVector>, +using TFPowersStdH = DescribedVector, TimeDim, FreqDim>; diff --git a/cpp/skyweaver/FileOutputStream.hpp b/cpp/skyweaver/FileOutputStream.hpp index 53dc988..968cc92 100644 --- a/cpp/skyweaver/FileOutputStream.hpp +++ b/cpp/skyweaver/FileOutputStream.hpp @@ -48,6 +48,8 @@ class FileOutputStream std::string _full_path; std::size_t _bytes_requested; std::size_t _bytes_written; + std::string _temporary_suffix; + std::string _temporary_path; std::ofstream _stream; }; diff --git a/cpp/skyweaver/MultiFileWriter.cuh b/cpp/skyweaver/MultiFileWriter.cuh index 15825c6..be49c88 100644 --- a/cpp/skyweaver/MultiFileWriter.cuh +++ b/cpp/skyweaver/MultiFileWriter.cuh @@ -13,43 +13,45 @@ namespace skyweaver { - - - -struct MultiFileWriterConfig{ - - std::size_t header_size; - std::size_t max_file_size; - std::string stokes_mode; - std::string output_dir; - std::string base_output_dir; - std::string prefix; - std::string extension; - std::string output_basename; - - - - MultiFileWriterConfig() : header_size(4096), max_file_size(2147483647), stokes_mode("I"), output_dir("default/"), prefix(""), extension(""){}; - MultiFileWriterConfig(std::size_t header_size, std::size_t max_file_size, std::string stokes_mode, std::string output_dir, std::string prefix, std::string extension) : header_size(header_size), max_file_size(max_file_size), stokes_mode(stokes_mode), output_dir(output_dir), prefix(prefix), extension(extension), output_basename(""){ }; - MultiFileWriterConfig(MultiFileWriterConfig 
const& other) : header_size(other.header_size), max_file_size(other.max_file_size), - stokes_mode(other.stokes_mode), output_dir(other.output_dir), base_output_dir(other.base_output_dir), prefix(other.prefix), extension(other.extension), output_basename(other.output_basename){}; - MultiFileWriterConfig& operator=(MultiFileWriterConfig const& other){ - header_size = other.header_size; - max_file_size = other.max_file_size; - stokes_mode = other.stokes_mode; - output_dir = other.output_dir; - prefix = other.prefix; - extension = other.extension; - output_basename = other.output_basename; - base_output_dir = other.base_output_dir; - return *this; - } - - std::string to_string(){ - return "header_size: " + std::to_string(header_size) + ", max_file_size: " + std::to_string(max_file_size) - + ", stokes_mode: " + stokes_mode + ", output_dir: " + output_dir + ", prefix: " + prefix + ", extension: " + extension - + ", output_basename: " + output_basename + ", base_output_dir: " + base_output_dir; - } +struct MultiFileWriterConfig { + std::size_t header_size; + std::size_t max_file_size; + std::string stokes_mode; + std::string output_dir; + std::string base_output_dir; + std::string inner_dir; + std::string prefix; + std::string extension; + std::string output_basename; + std::string suffix; + PreWriteConfig pre_write; + + MultiFileWriterConfig() + : header_size(4096), max_file_size(2147483647), stokes_mode("I"), + output_dir("default/"), base_output_dir("default_base/"), inner_dir(""), prefix(""), extension(""), output_basename(""), suffix("") {}; + + MultiFileWriterConfig(std::size_t header_size, + std::size_t max_file_size, + std::string stokes_mode, + std::string output_dir, + std::string prefix, + std::string extension, + std::string suffix) + : header_size(header_size), max_file_size(max_file_size), + stokes_mode(stokes_mode), output_dir(output_dir), prefix(prefix), + extension(extension), suffix(suffix), output_basename("") {}; + + + std::string to_string() + { + 
return "header_size: " + std::to_string(header_size) + + ", max_file_size: " + std::to_string(max_file_size) + + ", stokes_mode: " + stokes_mode + ", output_dir: " + output_dir + + ", prefix: " + prefix + ", extension: " + extension + + ", output_basename: " + output_basename + + ", base_output_dir: " + base_output_dir + + ", inner_dir: " + inner_dir + ", suffix: " + suffix; + } }; /** * @brief A class for handling writing of DescribedVectors @@ -58,12 +60,14 @@ struct MultiFileWriterConfig{ template class MultiFileWriter { -public: - - using CreateStreamCallBackType = std::function(MultiFileWriterConfig const&, - ObservationHeader const&, - VectorType const&, - std::size_t)>; + public: + using PreWriteCallback = std::function; + using CreateStreamCallBackType = + std::function( + MultiFileWriterConfig const&, + ObservationHeader const&, + VectorType const&, + std::size_t)>; public: /** @@ -74,8 +78,20 @@ public: * (used to avoid clashing file names). */ // MultiFileWriter(PipelineConfig const& config, std::string tag = ""); - MultiFileWriter(PipelineConfig const& config, std::string tag, CreateStreamCallBackType create_stream_callback); - MultiFileWriter(MultiFileWriterConfig config, std::string tag, CreateStreamCallBackType create_stream_callback); + MultiFileWriter(PipelineConfig const& config, + std::string tag, + CreateStreamCallBackType create_stream_callback); + MultiFileWriter(PipelineConfig const& config, + std::string tag, + CreateStreamCallBackType create_stream_callback, + PreWriteCallback pre_write_callback); + MultiFileWriter(MultiFileWriterConfig config, + std::string tag, + CreateStreamCallBackType create_stream_callback); + MultiFileWriter(MultiFileWriterConfig config, + std::string tag, + CreateStreamCallBackType create_stream_callback, + PreWriteCallback pre_write_callback); MultiFileWriter(MultiFileWriter const&) = delete; /** @@ -110,13 +126,12 @@ public: */ bool operator()(VectorType const& stream_data, std::size_t stream_idx = 0); - bool 
write(VectorType const& stream_data, - std::size_t stream_idx = 0); + bool write(VectorType const& stream_data, std::size_t stream_idx = 0); private: bool has_stream(std::size_t stream_idx); FileOutputStream& create_stream(VectorType const& stream_data, - std::size_t stream_idx); + std::size_t stream_idx); std::string get_output_dir(VectorType const& stream_data, std::size_t stream_idx); std::string get_basefilename(VectorType const& stream_data, @@ -124,6 +139,7 @@ public: std::string get_extension(VectorType const& stream_data); CreateStreamCallBackType _create_stream_callback; MultiFileWriterConfig _config; + PreWriteCallback _pre_write_callback; std::string _tag; ObservationHeader _header; std::map> _file_streams; @@ -136,4 +152,4 @@ public: #include "skyweaver/detail/MultiFileWriter.cu" #include "skyweaver/detail/file_writer_callbacks.cpp" -#endif // SKYWEAVER_MULTIFILEWRITER_CUH \ No newline at end of file +#endif // SKYWEAVER_MULTIFILEWRITER_CUH diff --git a/cpp/skyweaver/ObservationHeader.hpp b/cpp/skyweaver/ObservationHeader.hpp index 356d71e..6e2fbb9 100644 --- a/cpp/skyweaver/ObservationHeader.hpp +++ b/cpp/skyweaver/ObservationHeader.hpp @@ -4,12 +4,12 @@ #include "psrdada_cpp/raw_bytes.hpp" #include "skyweaver/Header.hpp" #include "skyweaver/PipelineConfig.hpp" + #include namespace skyweaver { - struct ObservationHeader { std::size_t nchans = 0; // Number of frequency channels in the subband std::size_t npol = 0; // Number of polarisations @@ -30,37 +30,38 @@ struct ObservationHeader { long double sync_time = 0.0; // The UNIX epoch of the sampler zero long double utc_start = 0.0; // The UTC start time of the data long double mjd_start = 0.0; // The MJD start time of the data - std::size_t obs_offset = 0; // The offset of the current file from UTC_START in bytesß - long double refdm = 0.0; // Reference DM - std::size_t ibeam = 0.0; // Beam number - std::size_t nbeams = 0; // Number of beams - std::string source_name; // Name of observation target - 
std::string ra; // Right ascension - std::string dec; // Declination - std::string telescope; // Telescope name - std::string instrument; // Name of the recording instrument - std::string order; // Order of the dimensions in the data - std::string ndms; // Number of DMs - std::vector dms; // DMs - - std::string to_string() const; // Convert the header to a string - - long double az; // Azimuth - long double za; // Zenith angle - std::size_t machineid = 0; // Machine ID - std::size_t nifs = 0; // Number of IFs - std::size_t telescopeid = 0; // Telescope ID - std::size_t datatype = 0; // Data type - std::size_t barycentric = 0; // Barycentric correction - std::string rawfile; // Raw file name - double fch1 = 0.0; // Centre frequency of the first channel - double foff = 0.0; // Frequency offset between channels - - bool sigproc_params = false; // Whether to include sigproc parameters + std::size_t obs_offset = + 0; // The offset of the current file from UTC_START in bytesß + long double refdm = 0.0; // Reference DM + std::size_t ibeam = 0.0; // Beam number + std::size_t nbeams = 0; // Number of beams + std::string source_name; // Name of observation target + std::string ra; // Right ascension + std::string dec; // Declination + std::string telescope; // Telescope name + std::string instrument; // Name of the recording instrument + std::string order; // Order of the dimensions in the data + std::size_t ndms; // Number of DMs + std::vector dms; // DMs + std::string stokes_mode; // Stokes mode + + std::string to_string() const; // Convert the header to a string + + long double az; // Azimuth + long double za; // Zenith angle + std::size_t machineid = 0; // Machine ID + std::size_t nifs = 0; // Number of IFs + std::size_t telescopeid = 0; // Telescope ID + std::size_t datatype = 0; // Data type + std::size_t barycentric = 0; // Barycentric correction + std::string rawfile; // Raw file name + double fch1 = 0.0; // Centre frequency of the first channel + double foff = 0.0; 
// Frequency offset between channels + + bool sigproc_params = false; // Whether to include sigproc parameters ObservationHeader() = default; - ObservationHeader(ObservationHeader const&) = default; + ObservationHeader(ObservationHeader const&) = default; ObservationHeader& operator=(ObservationHeader const&) = default; - }; // template for comparing two floating point objects @@ -72,7 +73,6 @@ is_close(T a, T b, T tolerance = 1e-12) return std::fabs(a - b) < tolerance; } - /** * @brief Parse header information for a DADA header block * @@ -90,15 +90,6 @@ void update_config(PipelineConfig& config, ObservationHeader const& header); bool are_headers_similar(ObservationHeader const& header1, ObservationHeader const& header2); - - - - - - } // namespace skyweaver - - - #endif // SKYWEAVER_OBSERVATIONHEADER_HPP \ No newline at end of file diff --git a/cpp/skyweaver/PipelineConfig.hpp b/cpp/skyweaver/PipelineConfig.hpp index 1656b0c..79becc2 100644 --- a/cpp/skyweaver/PipelineConfig.hpp +++ b/cpp/skyweaver/PipelineConfig.hpp @@ -9,6 +9,18 @@ namespace skyweaver { + struct WaitConfig + { + int iterations; + int sleep_time; + std::size_t min_free_space; + }; + + struct PreWriteConfig + { + bool is_enabled; + WaitConfig wait; + }; /** * @brief Class for wrapping the skyweaver pipeline configuration. 
@@ -153,6 +165,11 @@ class PipelineConfig */ DedispersionPlan const& ddplan() const; + /** + * @brief configures wait for filesystem space + */ + void configure_wait(std::string argument); + /** * @brief Enable/disable incoherent dedispersion based fscrunch after * beamforming @@ -165,6 +182,26 @@ class PipelineConfig */ bool enable_incoherent_dedispersion() const; + /** + * @brief Enable/disable calculation and writing of voltage statistics + */ + void output_statistics(bool enable); + + /** + * @brief Check if calculation of voltage statistics is enabled + */ + bool output_statistics() const; + + /** + * @brief Enable/disable writing out the incoherent beam + */ + void output_incoherent_beam(bool enable); + + /** + * @brief Check if outputting incoherent beam is enabled + */ + bool output_incoherent_beam() const; + /** * @brief Return the number of time samples to be integrated * in the coherent beamformer. @@ -209,6 +246,11 @@ class PipelineConfig return SKYWEAVER_CB_NSAMPLES_PER_BLOCK; } + PreWriteConfig pre_write_config() const + { + return _pre_write_config; + } + /** * @brief Return the total number of samples to read from file in each gulp. 
* @@ -323,6 +365,7 @@ class PipelineConfig } private: + std::size_t convertMemorySize(const std::string& str) const; void calculate_channel_frequencies() const; void update_power_offsets_and_scalings(); @@ -339,6 +382,8 @@ class PipelineConfig std::size_t _max_output_filesize; std::string _output_file_prefix; bool _enable_incoherent_dedispersion; + bool _output_statistics; + bool _output_incoherent_beam; double _cfreq; double _bw; mutable bool _channel_frequencies_stale; @@ -350,6 +395,7 @@ class PipelineConfig float _output_level; DedispersionPlan _ddplan; mutable std::vector _channel_frequencies; + PreWriteConfig _pre_write_config; }; } // namespace skyweaver diff --git a/cpp/skyweaver/SkyCleaver.cuh b/cpp/skyweaver/SkyCleaver.cuh deleted file mode 100644 index 8c91290..0000000 --- a/cpp/skyweaver/SkyCleaver.cuh +++ /dev/null @@ -1,78 +0,0 @@ -#ifndef SKYWEAVER_SKYCLEAVER_HPP -#define SKYWEAVER_SKYCLEAVER_HPP -#include "skyweaver/DescribedVector.hpp" -#include "boost/log/trivial.hpp" -#include "psrdada_cpp/psrdadaheader.hpp" -#include "psrdada_cpp/raw_bytes.hpp" -#include "skyweaver/ObservationHeader.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include "skyweaver/MultiFileReader.cuh" -#include "skyweaver/SkyCleaverConfig.hpp" -#include "skyweaver/MultiFileWriter.cuh" -#include "skyweaver/ObservationHeader.hpp" -#include "skyweaver/Timer.hpp" - -namespace skyweaver -{ - -struct BridgeReader -{ - public: - std::vector _tdb_filenames; - std::unique_ptr _tdb_reader; - std::string freq; - -}; // BridgeReader -class SkyCleaver -{ - public: - using InputType = int8_t; - using OutputType = uint8_t; - using InputVectorType = TDBPowersH; - using OutputVectorType = TFPowersH; - using FreqType = std::size_t; // up to the nearest Hz - using BeamNumberType = std::size_t; - using DMNumberType = std::size_t; - - - private: - SkyCleaverConfig const& _config; - std::map> _bridge_readers; - std::map> _bridge_data; - - std::vector 
_expected_freqs; - std::vector _available_freqs; - std::size_t _nsamples_to_read; - ObservationHeader _header; - - std::vector _beam_filenames; - std::map>>> _beam_writers; - std::map>> _beam_data; - std::size_t _total_beam_writers; - - - void init_readers(); - void init_writers(); - - Timer _timer; - - - public: - SkyCleaver(SkyCleaverConfig const& config); - SkyCleaver(SkyCleaver const&) = delete; - void operator=(SkyCleaver const&) = delete; - - void cleave(); - -}; // class SkyCleaver -} // namespace skyweaver - -#endif // SKYWEAVER_SKYCLEAVER_HPP \ No newline at end of file diff --git a/cpp/skyweaver/SkyCleaver.hpp b/cpp/skyweaver/SkyCleaver.hpp index c22ae3d..fad34cc 100644 --- a/cpp/skyweaver/SkyCleaver.hpp +++ b/cpp/skyweaver/SkyCleaver.hpp @@ -1,8 +1,14 @@ -#include "SigprocFileWriter.hpp" +#ifndef SKYWEAVER_SKYCLEAVER_HPP +#define SKYWEAVER_SKYCLEAVER_HPP #include "boost/log/trivial.hpp" #include "psrdada_cpp/psrdadaheader.hpp" #include "psrdada_cpp/raw_bytes.hpp" +#include "skyweaver/DescribedVector.hpp" +#include "skyweaver/MultiFileReader.cuh" +#include "skyweaver/MultiFileWriter.cuh" #include "skyweaver/ObservationHeader.hpp" +#include "skyweaver/SkyCleaverConfig.hpp" +#include "skyweaver/Timer.hpp" #include #include @@ -12,32 +18,76 @@ #include #include #include + namespace skyweaver { -struct Bridge -{ - private: - using PowerType = std::vector; +struct BridgeReader { public: std::vector _tdb_filenames; + std::unique_ptr _tdb_reader; std::string freq; -}; // class Bridge +}; // BridgeReader + +struct BeamInfo { + std::string beam_name; + std::string beam_ra; + std::string beam_dec; +}; -class MultiBeamWriter +template +class SkyCleaver { + public: + using FreqType = std::size_t; // up to the nearest Hz + using BeamNumberType = std::size_t; + using DMNumberType = std::size_t; + using StokesNumberType = std::size_t; + private: - using FreqType = unsigned int; // up to the nearest Hz - std::map> _bridges; + SkyCleaverConfig& _config; + std::map> 
_bridge_readers; + std::map> _bridge_data; + + std::vector _expected_freqs; + std::vector _available_freqs; + std::size_t _nsamples_to_read; + ObservationHeader _header; + std::vector _beam_infos; + std::vector _beam_filenames; + std::map< + StokesNumberType, + std::map>>>> + _beam_writers; + + std::map >>> + _beam_data; + std::size_t _total_beam_writers; + + int _nthreads_read; + void init_readers(); + void init_writers(); + void read(std::size_t gulp_samples); + void write(); + + Timer _timer; + + public: + SkyCleaver(SkyCleaverConfig& config); + SkyCleaver(SkyCleaver const&) = delete; + void operator=(SkyCleaver const&) = delete; + void cleave(); +}; // class SkyCleaver - public - : add_bridge(FreqType freq, std::vector tdb_filenames); - remove_bridge(FreqType freq); - init(); +} // namespace skyweaver +#include "skyweaver/detail/SkyCleaver.cpp" -} // class MultiFileWriter -} // namespace skyweaver \ No newline at end of file +#endif // SKYWEAVER_SKYCLEAVER_HPP \ No newline at end of file diff --git a/cpp/skyweaver/SkyCleaverConfig.hpp b/cpp/skyweaver/SkyCleaverConfig.hpp index a1ae8c7..924d0d2 100644 --- a/cpp/skyweaver/SkyCleaverConfig.hpp +++ b/cpp/skyweaver/SkyCleaverConfig.hpp @@ -1,6 +1,7 @@ #ifndef SKYCLEAVERCONFIG_HPP #define SKYCLEAVERCONFIG_HPP -namespace skyweaver { +namespace skyweaver +{ class SkyCleaverConfig { @@ -19,51 +20,94 @@ class SkyCleaverConfig std::size_t _ndms; std::string _stokes_mode; std::size_t _dada_header_size; + std::size_t _start_sample; + std::size_t _nsamples_to_read; + std::vector _required_beams; + std::vector _required_dms; + std::string _targets; - - + std::string _out_stokes; + std::vector> _stokes_positions; - public: + public: + SkyCleaverConfig() + : _output_dir(""), _root_dir(""), _root_prefix(""), _out_prefix(""), + _nthreads(0), _nsamples_per_block(32768), _nchans(0), _nbeams(0), + _max_ram_gb(0), _max_output_filesize(2147483647), _stream_id(0), + _nbridges(64), _ndms(0), _stokes_mode("I"), _dada_header_size(4096), + 
_start_sample(0), _nsamples_to_read(0), _required_beams({}), _required_dms({}), + _targets(""), _out_stokes("I"), _stokes_positions({}) + { + } + SkyCleaverConfig(SkyCleaverConfig const&) = delete; - SkyCleaverConfig() : _output_dir(""), _root_dir(""), _root_prefix(""), _out_prefix(""), _nthreads(0), _nsamples_per_block(0), _nchans(0), _nbeams(0), _max_ram_gb(0), _max_output_filesize(2147483647), _stream_id(0), _nbridges(64), _ndms(0), _stokes_mode("I"), _dada_header_size(4096) {} - SkyCleaverConfig(SkyCleaverConfig const&) = delete; + void output_dir(std::string output_dir) { _output_dir = output_dir; } + void root_dir(std::string root_dir) { _root_dir = root_dir; } + void root_prefix(std::string root_prefix) { _root_prefix = root_prefix; } + void out_prefix(std::string out_prefix) { _out_prefix = out_prefix; } + void nthreads(std::size_t nthreads) { _nthreads = nthreads; } + void nsamples_per_block(std::size_t nsamples_per_block) + { + _nsamples_per_block = nsamples_per_block; + } + void nchans(std::size_t nchans) { _nchans = nchans; } + void nbeams(std::size_t nbeams) { _nbeams = nbeams; } + void max_ram_gb(std::size_t max_ram_gb) { _max_ram_gb = max_ram_gb; } + void max_output_filesize(std::size_t max_output_filesize) + { + _max_output_filesize = max_output_filesize; + } + void stream_id(std::size_t stream_id) { _stream_id = stream_id; } + void nbridges(std::size_t nbridges) { _nbridges = nbridges; } + void ndms(std::size_t ndms) { _ndms = ndms; } + void stokes_mode(std::string stokes_mode) { _stokes_mode = stokes_mode; } + void dada_header_size(std::size_t dada_header_size) + { + _dada_header_size = dada_header_size; + } + void start_sample(std::size_t start_sample) { _start_sample = start_sample; } + void nsamples_to_read(std::size_t nsamples_to_read) + { + _nsamples_to_read = nsamples_to_read; + } + void required_beams(std::vector required_beams) { _required_beams = required_beams; } + void required_dms(std::vector required_dms) { _required_dms = 
required_dms; } + void out_stokes(std::string out_stokes) { _out_stokes = out_stokes; } + void stokes_positions(std::vector> stokes_positions) + { + _stokes_positions = stokes_positions; + } + void targets_file(std::string targets) { _targets = targets; } - void output_dir(std::string output_dir) { _output_dir = output_dir; } - void root_dir(std::string root_dir) { _root_dir = root_dir; } - void root_prefix(std::string root_prefix) { _root_prefix = root_prefix; } - void out_prefix(std::string out_prefix) { _out_prefix = out_prefix; } - void nthreads(std::size_t nthreads) { _nthreads = nthreads; } - void nsamples_per_block(std::size_t nsamples_per_block) { _nsamples_per_block = nsamples_per_block; } - void nchans(std::size_t nchans) { _nchans = nchans; } - void nbeams(std::size_t nbeams) { _nbeams = nbeams; } - void max_ram_gb(std::size_t max_ram_gb) { _max_ram_gb = max_ram_gb; } - void max_output_filesize(std::size_t max_output_filesize) { _max_output_filesize = max_output_filesize; } - void stream_id(std::size_t stream_id) { _stream_id = stream_id; } - void nbridges(std::size_t nbridges) { _nbridges = nbridges; } - void ndms(std::size_t ndms) { _ndms = ndms; } - void stokes_mode(std::string stokes_mode) { _stokes_mode = stokes_mode; } - void dada_header_size(std::size_t dada_header_size) { _dada_header_size = dada_header_size; } - - - std::string output_dir() const { return _output_dir; } - std::string root_dir() const { return _root_dir; } - std::string root_prefix() const { return _root_prefix; } - std::string out_prefix() const { return _out_prefix; } - std::size_t nthreads() const { return _nthreads; } - std::size_t nsamples_per_block() const { return _nsamples_per_block; } - std::size_t nchans() const { return _nchans; } - std::size_t nbeams() const { return _nbeams; } - std::size_t max_ram_gb() const { return _max_ram_gb; } - std::size_t max_output_filesize() const { return _max_output_filesize; } - std::size_t stream_id() const { return _stream_id; } - 
std::size_t nbridges() const { return _nbridges; } - std::size_t ndms() const { return _ndms; } - std::string stokes_mode() const { return _stokes_mode; } - std::size_t dada_header_size() const { return _dada_header_size; } + std::string output_dir() const { return _output_dir; } + std::string root_dir() const { return _root_dir; } + std::string root_prefix() const { return _root_prefix; } + std::string out_prefix() const { return _out_prefix; } + std::size_t nthreads() const { return _nthreads; } + std::size_t nsamples_per_block() const { return _nsamples_per_block; } + std::size_t nchans() const { return _nchans; } + std::size_t nbeams() const { return _nbeams; } + std::size_t max_ram_gb() const { return _max_ram_gb; } + std::size_t max_output_filesize() const { return _max_output_filesize; } + std::size_t stream_id() const { return _stream_id; } + std::size_t nbridges() const { return _nbridges; } + std::size_t ndms() const { return _ndms; } + std::string stokes_mode() const { return _stokes_mode; } + std::size_t dada_header_size() const { return _dada_header_size; } + std::size_t start_sample() const { return _start_sample; } + std::size_t nsamples_to_read() const { return _nsamples_to_read; } + std::vector required_beams() const { return _required_beams; } + std::vector required_dms() const { return _required_dms; } + std::string out_stokes() const { return _out_stokes; } + std::vector> stokes_positions() const + { + return _stokes_positions; + } + std::string targets_file() const { return _targets; } }; diff --git a/cpp/skyweaver/detail/BeamformerPipeline.cu b/cpp/skyweaver/detail/BeamformerPipeline.cu index 7c44422..18617b3 100644 --- a/cpp/skyweaver/detail/BeamformerPipeline.cu +++ b/cpp/skyweaver/detail/BeamformerPipeline.cu @@ -114,8 +114,14 @@ void BeamformerPipeline:: _header = header; _utc_offset = utc_offset; _cb_handler.init(_header); - _ib_handler.init(_header); - _stats_handler.init(_header); + if (_config.output_incoherent_beam()) + { + 
_ib_handler.init(_header); + } + if (_config.output_statistics()) + { + _stats_handler.init(_header); + } NVTX_RANGE_POP(); } @@ -161,20 +167,23 @@ void BeamformerPipeline:: _timer.stop("transpose TAFTP to FTPA"); NVTX_RANGE_POP(); - NVTX_RANGE_PUSH("Calculate statistics"); - BOOST_LOG_TRIVIAL(debug) << "Checking if channel statistics update request"; - _timer.start("calculate statistics"); - _stats_manager->calculate_statistics(_ftpa_post_transpose); - _timer.stop("calculate statistics"); - NVTX_RANGE_POP(); - NVTX_RANGE_PUSH("Update scalings"); - if(_call_count == 0) { + if ((_call_count == 0) || _config.output_statistics()) + { + NVTX_RANGE_PUSH("Calculate statistics"); + BOOST_LOG_TRIVIAL(debug) << "Checking if channel statistics update request"; + _timer.start("calculate statistics"); + _stats_manager->calculate_statistics(_ftpa_post_transpose); + _timer.stop("calculate statistics"); + NVTX_RANGE_POP(); + if(_call_count == 0) { + NVTX_RANGE_PUSH("Update scalings"); _timer.start("update scalings"); _stats_manager->update_scalings(_delay_manager->beamset_weights(), _delay_manager->nbeamsets()); _timer.stop("update scalings"); + NVTX_RANGE_POP(); + } } - NVTX_RANGE_POP(); // BOOST_LOG_TRIVIAL(debug) << "Peeking the statistics"; // peek(_stats_manager->statistics(), 64); @@ -244,18 +253,24 @@ void BeamformerPipeline:: _timer.stop("coherent beam handler"); NVTX_RANGE_POP(); - NVTX_RANGE_PUSH("Incoherent beamformer handler"); - _timer.start("incoherent beam handler"); - _ib_handler(_tf_ib, dm_idx); - _timer.stop("incoherent beam handler"); - NVTX_RANGE_POP(); + if (_config.output_incoherent_beam()) + { + NVTX_RANGE_PUSH("Incoherent beamformer handler"); + _timer.start("incoherent beam handler"); + _ib_handler(_tf_ib, dm_idx); + _timer.stop("incoherent beam handler"); + NVTX_RANGE_POP(); + } } NVTX_RANGE_POP(); - NVTX_RANGE_PUSH("Stats handler"); - _timer.start("statistics handler"); - _stats_handler(_stats_manager->statistics()); - _timer.stop("statistics handler"); 
- NVTX_RANGE_POP(); + if (_config.output_statistics()) + { + NVTX_RANGE_PUSH("Stats handler"); + _timer.start("statistics handler"); + _stats_handler(_stats_manager->statistics()); + _timer.stop("statistics handler"); + NVTX_RANGE_POP(); + } NVTX_RANGE_POP(); } diff --git a/cpp/skyweaver/detail/IncoherentDedispersionPipeline.cu b/cpp/skyweaver/detail/IncoherentDedispersionPipeline.cu index 62f722b..f295f5b 100644 --- a/cpp/skyweaver/detail/IncoherentDedispersionPipeline.cu +++ b/cpp/skyweaver/detail/IncoherentDedispersionPipeline.cu @@ -73,10 +73,9 @@ void IncoherentDedispersionPipeline:: // Set the correct DMs on the block _output_buffers[ref_dm_idx].dms(plan[ref_dm_idx].incoherent_dms); _output_buffers[ref_dm_idx].reference_dm(plan[ref_dm_idx].coherent_dm); - _output_buffers[ref_dm_idx].frequencies({_config.centre_frequency() - _config.bandwidth() / 2.0}); - - BOOST_LOG_TRIVIAL(debug) << "setting centre frequency to " << _output_buffers[ref_dm_idx].frequencies()[0]; + _output_buffers[ref_dm_idx].frequencies(_config.channel_frequencies().front()); + BOOST_LOG_TRIVIAL(debug) << "setting centre frequency to " << std::setprecision(15) << _output_buffers[ref_dm_idx].frequencies()[0]; BOOST_LOG_TRIVIAL(debug) << "Passing output buffer to handler: " << _output_buffers[ref_dm_idx].describe(); _timer.start("file writing"); diff --git a/cpp/skyweaver/detail/MultiFileWriter.cu b/cpp/skyweaver/detail/MultiFileWriter.cu index 7529621..d5743f0 100644 --- a/cpp/skyweaver/detail/MultiFileWriter.cu +++ b/cpp/skyweaver/detail/MultiFileWriter.cu @@ -7,8 +7,6 @@ #include #include - - /** * Now write a DADA file per DM * with optional time splitting @@ -20,7 +18,6 @@ namespace skyweaver namespace { - std::string get_formatted_time(long double unix_timestamp) { char formatted_time[80]; @@ -40,29 +37,59 @@ std::string get_formatted_time(long double unix_timestamp) // } template -MultiFileWriter::MultiFileWriter(PipelineConfig const& config, - std::string tag, - CreateStreamCallBackType 
create_stream_callback) +MultiFileWriter::MultiFileWriter( + PipelineConfig const& config, + std::string tag, + CreateStreamCallBackType create_stream_callback) : _tag(tag), _create_stream_callback(create_stream_callback) { - MultiFileWriterConfig writer_config; - writer_config.header_size = config.dada_header_size(); - writer_config.max_file_size = config.max_output_filesize(); - writer_config.stokes_mode = config.stokes_mode(); - writer_config.base_output_dir = config.output_dir(); + _config.header_size = config.dada_header_size(); + _config.max_file_size = config.max_output_filesize(); + _config.stokes_mode = config.stokes_mode(); + _config.base_output_dir = config.output_dir(); + _config.inner_dir = ""; - _config = writer_config; } template -MultiFileWriter::MultiFileWriter(MultiFileWriterConfig config, - std::string tag, - CreateStreamCallBackType create_stream_callback) - : _config(config), _tag(tag), _create_stream_callback(create_stream_callback) +MultiFileWriter::MultiFileWriter( + PipelineConfig const& config, + std::string tag, + CreateStreamCallBackType create_stream_callback, + PreWriteCallback pre_write_callback) + : _tag(tag), _create_stream_callback(create_stream_callback), _pre_write_callback(pre_write_callback) { + _config.header_size = config.dada_header_size(); + _config.max_file_size = config.max_output_filesize(); + _config.stokes_mode = config.stokes_mode(); + _config.base_output_dir = config.output_dir(); + _config.pre_write = config.pre_write_config(); + _config.inner_dir = ""; + } +template +MultiFileWriter::MultiFileWriter( + MultiFileWriterConfig config, + std::string tag, + CreateStreamCallBackType create_stream_callback) + : _config(config), _tag(tag), + _create_stream_callback(create_stream_callback) +{ + _pre_write_callback = nullptr; +} +template +MultiFileWriter::MultiFileWriter( + MultiFileWriterConfig config, + std::string tag, + CreateStreamCallBackType create_stream_callback, + PreWriteCallback pre_write_callback) + : 
_config(config), _tag(tag), + _create_stream_callback(create_stream_callback), + _pre_write_callback(pre_write_callback) +{ +} template MultiFileWriter::~MultiFileWriter(){}; @@ -81,14 +108,10 @@ bool MultiFileWriter::has_stream(std::size_t stream_idx) } template -FileOutputStream& +FileOutputStream& MultiFileWriter::create_stream(VectorType const& stream_data, std::size_t stream_idx) { - - - - _config.output_dir = get_output_dir(stream_data, stream_idx); if(_config.extension.empty()) { @@ -97,10 +120,8 @@ MultiFileWriter::create_stream(VectorType const& stream_data, _config.output_basename = get_basefilename(stream_data, stream_idx); - - _file_streams[stream_idx] = _create_stream_callback(_config, _header, stream_data, stream_idx); - - + _file_streams[stream_idx] = + _create_stream_callback(_config, _header, stream_data, stream_idx); return *_file_streams[stream_idx]; } @@ -114,9 +135,12 @@ MultiFileWriter::get_output_dir(VectorType const& stream_data, // // std::stringstream output_dir; output_dir << _config.base_output_dir << "/" - << get_formatted_time(_header.utc_start) << "/" - << stream_idx; - + << get_formatted_time(_header.utc_start) << "/" << stream_idx; + + if(!_config.inner_dir.empty()) { + output_dir << "/" << _config.inner_dir; + } + return output_dir.str(); } @@ -132,9 +156,14 @@ MultiFileWriter::get_basefilename(VectorType const& stream_data, base_filename << _config.prefix << "_"; } base_filename << get_formatted_time(_header.utc_start) << "_" << stream_idx - << "_" << std::fixed << std::setprecision(3) + << "_cdm_" << std::fixed << std::setprecision(3) << std::setfill('0') << std::setw(9) - << stream_data.reference_dm(); + << stream_data.reference_dm(); + + if(!_config.suffix.empty()) { + base_filename << "_" << _config.suffix; + } + if(!_tag.empty()) { base_filename << "_" << _tag; } @@ -147,10 +176,9 @@ MultiFileWriter::get_extension(VectorType const& stream_data) { std::string dims = stream_data.dims_as_string(); for(auto& c: dims) { c = 
std::tolower(static_cast(c)); } - if(dims =="t") { + if(dims == "t") { return ".dat"; - } - else if(dims == "tf") { + } else if(dims == "tf") { return ".fil"; } return "." + dims; @@ -160,6 +188,10 @@ template bool MultiFileWriter::operator()(VectorType const& stream_data, std::size_t stream_idx) { + if (_pre_write_callback != nullptr && _config.pre_write.is_enabled) + { + _pre_write_callback(_config); + } if(!has_stream(stream_idx)) { create_stream(stream_data, stream_idx); } @@ -174,19 +206,16 @@ bool MultiFileWriter::operator()(VectorType const& stream_data, _file_streams.at(stream_idx) ->write(reinterpret_cast( thrust::raw_pointer_cast(stream_data.data())), - stream_data.size() * - sizeof(typename VectorType::value_type)); + stream_data.size() * sizeof(typename VectorType::value_type)); } return false; } template bool MultiFileWriter::write(VectorType const& stream_data, - std::size_t stream_idx) + std::size_t stream_idx) { - - return this->operator()(stream_data, stream_idx); - + return this->operator()(stream_data, stream_idx); } -} // namespace skyweaver \ No newline at end of file +} // namespace skyweaver diff --git a/cpp/skyweaver/detail/SkyCleaver.cpp b/cpp/skyweaver/detail/SkyCleaver.cpp new file mode 100644 index 0000000..9f86c58 --- /dev/null +++ b/cpp/skyweaver/detail/SkyCleaver.cpp @@ -0,0 +1,748 @@ + +#include "skyweaver/SkyCleaver.hpp" + +#include "skyweaver/beamformer_utils.cuh" +#include "skyweaver/types.cuh" +#include "skyweaver/skycleaver_utils.hpp" + + +#include +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +using BridgeReader = skyweaver::BridgeReader; +using MultiFileReader = skyweaver::MultiFileReader; +using BeamInfo = skyweaver::BeamInfo; + +namespace +{ + +std::string trim(const std::string& str) { + auto start = str.find_first_not_of(" \t\n\r"); + if (start == std::string::npos) { + return ""; // String is all whitespace + } + auto end = str.find_last_not_of(" \t\n\r"); + return 
str.substr(start, end - start + 1); +} + +template +std::string to_string_with_padding(T num, int width, int precision = -1) +{ + std::ostringstream oss; + oss << std::setw(width) << std::setfill('0'); + if(precision >= + 0) { // Check if precision is specified for floating-point numbers + oss << std::fixed << std::setprecision(precision); + } + oss << num; + return oss.str(); +} +std::vector +get_subdirs(std::string directory_path, + std::regex numeric_regex = std::regex("^[0-9]+$")) +{ + std::vector subdirs; + try { + if(fs::exists(directory_path) && fs::is_directory(directory_path)) { + for(const auto& entry: fs::directory_iterator(directory_path)) { + if(fs::is_directory(entry.status())) { + std::string folder_name = entry.path().filename().string(); + if(std::regex_match(folder_name, numeric_regex)) { + BOOST_LOG_TRIVIAL(debug) + << "Found subdirectory: " << folder_name; + subdirs.push_back(folder_name); + } + } + } + } else { + std::runtime_error( + "Root directory does not exist or is not a directory."); + } + } catch(const fs::filesystem_error& e) { + std::cerr << "Filesystem error: " << e.what() << std::endl; + std::runtime_error("Error reading subdirectories in root directory: " + + directory_path); + } + + return subdirs; +} + +std::vector get_files(std::string directory_path, + std::string extension) +{ + std::vector files; + try { + if(fs::exists(directory_path) && fs::is_directory(directory_path)) { + for(const auto& entry: fs::directory_iterator(directory_path)) { + if(fs::is_regular_file(entry.status())) { + std::string file_name = entry.path().string(); + if(file_name.find(extension) != std::string::npos) { + //check if .tmp not in filename + if(file_name.find(".tmp") != std::string::npos) { + continue; + } + files.push_back(file_name); + } + } + } + } else { + std::runtime_error("No files in bridge directory: " + + directory_path); + } + } catch(const fs::filesystem_error& e) { + std::cerr << "Filesystem error: " << e.what() << std::endl; + 
std::runtime_error("Error reading files in bridge directory: " + + directory_path); + } + + return files; +} + + +} // namespace + +void parse_target_file(std::string file_name, std::vector& beam_infos){ + std::ifstream targets_file(file_name); + if(!targets_file.is_open()) { + std::runtime_error("Error opening target file: " + file_name); + } + // the file is in csv format. First read the header to know the positions of name, ra and dec + std::string header; + do{ + std::getline(targets_file, header); + } while(header.empty() || header.find("#") != std::string::npos); + + std::vector header_tokens; + std::stringstream header_stream(header); + std::string token; + while(std::getline(header_stream, token, ',')) { + header_tokens.push_back(token); + } + std::size_t name_pos = std::distance(header_tokens.begin(), std::find(header_tokens.begin(), header_tokens.end(), "name")); + std::size_t ra_pos = std::distance(header_tokens.begin(), std::find(header_tokens.begin(), header_tokens.end(), "ra")); + std::size_t dec_pos = std::distance(header_tokens.begin(), std::find(header_tokens.begin(), header_tokens.end(), "dec")); + + if(name_pos == header_tokens.size() || ra_pos == header_tokens.size() || dec_pos == header_tokens.size()) { + std::runtime_error("Invalid header in target file: " + file_name); + } + + std::string line; + while(std::getline(targets_file, line)) { + + line = trim(line); + + //if empty line or # anywhere in line, continue + if(line.empty() || line.find("#") != std::string::npos) { + BOOST_LOG_TRIVIAL(debug) << "Skipping line: " << line; + continue; + } + + std::vector tokens; + std::stringstream line_stream(line); + std::string token; + while(std::getline(line_stream, token, ',')) { + tokens.push_back(token); + } + if(tokens.size() != header_tokens.size()) { + std::runtime_error("Invalid number of columns in target file: " + file_name); + } + BeamInfo beam_info; + beam_info.beam_name = tokens[name_pos]; + beam_info.beam_ra = tokens[ra_pos]; + 
beam_info.beam_dec = tokens[dec_pos]; + beam_infos.push_back(beam_info); + } +} + +void compare_bridge_headers(const skyweaver::ObservationHeader& first, + const skyweaver::ObservationHeader& second) +{ + if(first.nchans != second.nchans) { + throw std::runtime_error("Number of channels in bridge readers do not " + "match. Expected: " + + std::to_string(first.nchans) + + " Found: " + std::to_string(second.nchans)); + } + if(first.nbeams != second.nbeams) { + throw std::runtime_error( + "Number of beams in bridge readers do not match. " + "Expected: " + + std::to_string(first.nbeams) + + " Found: " + std::to_string(second.nbeams)); + } + if(first.nbits != second.nbits) { + throw std::runtime_error( + "Number of bits in bridge readers do not match. " + "Expected: " + + std::to_string(first.nbits) + + " Found: " + std::to_string(second.nbits)); + } + if(first.tsamp != second.tsamp) { + throw std::runtime_error( + "Sampling time in bridge readers do not match. " + "Expected: " + + std::to_string(first.tsamp) + + " Found: " + std::to_string(second.tsamp)); + } + if(first.stokes_mode != second.stokes_mode) { + throw std::runtime_error("Stokes mode in bridge readers do not match. 
" + "Expected: " + + first.stokes_mode + + " Found: " + second.stokes_mode); + } +} + +template +void skyweaver::SkyCleaver::init_readers() +{ + BOOST_LOG_NAMED_SCOPE("SkyCleaver::init_readers") + + std::string root_dir = _config.root_dir(); + std::size_t stream_id = _config.stream_id(); + + // get the list of directories in root/stream_id(for the nex) + std::vector freq_dirs = + get_subdirs(root_dir + "/" + std::to_string(stream_id)); + + BOOST_LOG_TRIVIAL(info) + << "Found " << freq_dirs.size() + << " frequency directories in root directory: " << root_dir; + + std::map::FreqType, + long double> + bridge_timestamps; + long double latest_timestamp = 0.0; + + for(const auto& freq_dir: freq_dirs) { + std::vector tdb_files = get_files( + root_dir + "/" + std::to_string(stream_id) + "/" + freq_dir, + ".tdb"); + BOOST_LOG_TRIVIAL(info) << "Found " << tdb_files.size() + << " TDB files for frequency: " << freq_dir; + if(tdb_files.empty()) { + BOOST_LOG_TRIVIAL(warning) + << "No TDB files found for frequency: " << freq_dir; + continue; + } + + std::size_t freq = static_cast(std::stoul(freq_dir)); + + _bridge_readers[freq] = + std::make_unique(tdb_files, + _config.dada_header_size(), + false); + long double timestamp = _bridge_readers[freq]->get_header().utc_start; + bridge_timestamps.insert({freq, timestamp}); + if(timestamp > latest_timestamp) { + latest_timestamp = timestamp; + } + _available_freqs.push_back(freq); + + BOOST_LOG_TRIVIAL(debug) + << "Added bridge reader for frequency: " << freq_dir; + } + + int nbridges = _config.nbridges(); + + _header = _bridge_readers[_available_freqs[0]]->get_header(); + for(const auto& [freq, reader]: _bridge_readers) { + compare_bridge_headers(_header, reader->get_header()); + } + BOOST_LOG_TRIVIAL(info) + << "Number of beams: " << _header.nbeams + << " Number of DMS: " << _header.ndms + << " Stokes mode: " << _header.stokes_mode + << " Number of channels: " << _header.nchans; + + + _config.nbeams(_header.nbeams); + 
_config.ndms(_header.ndms); + _config.stokes_mode(_header.stokes_mode); + _config.nchans(_header.nchans); + + BOOST_LOG_TRIVIAL(info) + << "Number of beams: " << _config.nbeams() + << " Number of DMS: " << _config.ndms() + << " Stokes mode: " << _config.stokes_mode() + << " Number of channels: " << _config.nchans(); + + std::vector> stokes_positions; + for(const auto stokes: _config.out_stokes()) { + std::size_t pos = _config.stokes_mode().find(stokes); + if(pos == std::string::npos) { + + if(stokes == 'L') { + std::size_t pos1 = _config.stokes_mode().find("Q"); + std::size_t pos2 = _config.stokes_mode().find("U"); + if(pos1 == std::string::npos || pos2 == std::string::npos) { + throw std::runtime_error("Asked for L, but beamformed data does not have Q and/or U"); + } + stokes_positions.push_back({pos1, pos2}); + continue; + } + else { + throw std::runtime_error("Requested stokes not found in beamformed data: " + stokes); + } + } + stokes_positions.push_back({pos}); + } + + _config.stokes_positions(stokes_positions); + + + long double obs_centre_freq = _header.obs_frequency; + long double obs_bandwidth = _header.obs_bandwidth; + + long double start_freq = obs_centre_freq - obs_bandwidth / 2; + + for(int i = 0; i < _config.nbridges(); i++) { + std::size_t ifreq = std::lround(std::floor( + start_freq + (i + 0.5) * obs_bandwidth / _config.nbridges())); + _expected_freqs.push_back(ifreq); + BOOST_LOG_TRIVIAL(info) + << "Expected frequency [" << i << "]: " << ifreq; + + if(_bridge_readers.find(ifreq) == _bridge_readers.end()) { + BOOST_LOG_TRIVIAL(warning) + << "Frequency " << ifreq + << " not found in bridge readers, will write zeros"; + } + _bridge_data[ifreq] = std::make_unique( + std::initializer_list{_config.nsamples_per_block(), + _config.ndms(), + _config.nbeams(), + _config.stokes_mode().size()}, + 0); + } + + std::size_t smallest_data_size = std::numeric_limits::max(); + std::size_t dbp_factor = + _config.ndms() * _config.nbeams() * 
_config.stokes_mode().size(); + + for(const auto& [freq, reader]: _bridge_readers) { + // at this point, all non-existed frequencies have been added with zero + // data now check if there are any unexpected frequencies in the bridge + // readers. + if(std::find(_expected_freqs.begin(), _expected_freqs.end(), freq) == + _expected_freqs.end()) { + throw std::runtime_error("Frequency " + std::to_string(freq) + + " not found in expected frequencies"); + } + + // now time align all the bridges to the latest timestamp + long double timestamp = bridge_timestamps[freq]; + long double time_diff = latest_timestamp - timestamp; + long double tsamp = + reader->get_header().tsamp * + 1e-6; // Header has it in microseconds, converting to seconds + std::size_t nsamples = std::floor(time_diff / tsamp); + + BOOST_LOG_TRIVIAL(debug) + << "Frequency: " << freq << " Timestamp: " << timestamp + << "tsamp: " << tsamp << " Latest timestamp: " << latest_timestamp + << " Time difference: " << time_diff + << " Number of samples to skip: " << nsamples; + + BOOST_LOG_TRIVIAL(debug) + << "Seeking " << nsamples * dbp_factor + << " bytes in bridge reader for frequency: " << freq; + + std::size_t bytes_seeking = + (nsamples * dbp_factor * + sizeof(typename InputVectorType::value_type)); + + _bridge_readers[freq]->seekg(bytes_seeking, std::ios_base::beg); + + std::size_t data_size = + _bridge_readers[freq]->get_total_size() - bytes_seeking; + BOOST_LOG_TRIVIAL(debug) + << "Data size for frequency: " << freq << " is " << data_size; + if(data_size < smallest_data_size) { + smallest_data_size = data_size; + } + } + + if(smallest_data_size % dbp_factor != 0) { + std::runtime_error( + "Data size is not a multiple of ndms * nbeams * nstokes"); + } + + std::size_t smallest_nsamples = std::floor(smallest_data_size / dbp_factor); + + if(smallest_nsamples < _config.start_sample()) { + std::runtime_error( + "start_sample is greater than the smallest_nsamples in the data."); + } + + smallest_nsamples = 
smallest_nsamples - _config.start_sample(); + + BOOST_LOG_TRIVIAL(info) + << "Smallest data size: " << smallest_data_size + << " Smallest number of samples: " << smallest_nsamples; + + if(smallest_nsamples < _config.nsamples_per_block()) { + std::runtime_error( + "Smallest data size is less than nsamples_per_block"); + } + + if(_config.nsamples_to_read() > 0) { + if(smallest_nsamples < _config.nsamples_to_read()) { + std::runtime_error( + "Smallest data size is less than nsamples_to_read"); + } + + _nsamples_to_read = _config.nsamples_to_read(); + } else { + _nsamples_to_read = smallest_nsamples; + } + + std::size_t bytes_seeking = (_config.start_sample() * dbp_factor * + sizeof(typename InputVectorType::value_type)); + + if(bytes_seeking > 0) { + BOOST_LOG_TRIVIAL(info) << "Seeking " << bytes_seeking + << " bytes in bridge readers to start sample: " + << _config.start_sample(); + for(const auto& [freq, reader]: _bridge_readers) { + _bridge_readers[freq]->seekg(bytes_seeking, std::ios_base::cur); + } + } + + BOOST_LOG_TRIVIAL(info) + << "Added " << _bridge_data.size() << " bridge readers to SkyCleaver"; + + _header = _bridge_readers[_available_freqs[0]]->get_header(); + BOOST_LOG_TRIVIAL(info) + << "Adding first header to SkyCleaver: " << _header.to_string(); + _header.nchans = _header.nchans * _config.nbridges(); + _header.nbeams = _config.nbeams(); + + _nthreads_read = _config.nthreads() > _config.nbridges() + ? 
_config.nbridges() + : _config.nthreads(); +} +template +void skyweaver::SkyCleaver::init_writers() +{ + BOOST_LOG_NAMED_SCOPE("SkyCleaver::init_writers") + + BOOST_LOG_TRIVIAL(debug) + << "_config.output_dir(); " << _config.output_dir(); + std::string output_dir = _config.output_dir(); + + if(!fs::exists(output_dir)) { + fs::create_directories(output_dir); + } + + for(std::size_t istokes = 0; istokes < _config.out_stokes().size(); + istokes++) { + for(int idm = 0; idm < _config.ndms(); idm++) { + if(!_config.required_dms().empty()) { + const auto& required_dms = _config.required_dms(); + if(std::ranges::find(required_dms, _header.dms[idm]) == + required_dms.end()) { + BOOST_LOG_TRIVIAL(info) + << "DM " << _header.dms[idm] + << " is not required, skipping from writing"; + continue; + } + } + for(int ibeam = 0; ibeam < _config.nbeams(); ibeam++) { + // skip if beam is not used + if(!_config.required_beams().empty()) { + const auto& required_beams = _config.required_beams(); + std::cerr << "required_beams: " << required_beams.size() + << "ibeam: " << ibeam << std::endl; + if(std::ranges::find(required_beams, ibeam) == + required_beams.end()) { + BOOST_LOG_TRIVIAL(info) + << "Beam " << ibeam + << " is not required, skipping from writing"; + continue; + } + } + BeamInfo beam_info = _beam_infos[ibeam]; + MultiFileWriterConfig writer_config; + writer_config.header_size = _config.dada_header_size(); + writer_config.max_file_size = _config.max_output_filesize(); + writer_config.stokes_mode = _config.out_stokes().at(istokes); + writer_config.base_output_dir = output_dir; + writer_config.inner_dir = beam_info.beam_name; + writer_config.prefix = _config.out_prefix(); + std::string suffix = "idm_" + + to_string_with_padding(_header.dms[idm], 9, 3); + + writer_config.suffix = suffix + "_" + + beam_info.beam_name + "_" + + _config.out_stokes().at(istokes); + writer_config.extension = ".fil"; + + BOOST_LOG_TRIVIAL(info) + << "Writer config: " << writer_config.to_string(); + + 
typename MultiFileWriter:: + CreateStreamCallBackType create_stream_callback_sigproc = + skyweaver::detail::create_sigproc_file_stream< + OutputVectorType>; + _beam_writers[istokes][idm][ibeam] = + std::make_unique>( + writer_config, + "", + create_stream_callback_sigproc); + _header.ibeam = ibeam; + + _header.ra = beam_info.beam_ra; + _header.dec = beam_info.beam_dec; + _beam_writers[istokes][idm][ibeam]->init(_header); + + _beam_data[istokes][idm][ibeam] = + std::make_shared( + std::initializer_list{_config.nsamples_per_block(), + _config.nbridges()}, + 0); + + _beam_data[istokes][idm][ibeam]->reference_dm(_header.refdm); + _total_beam_writers++; + } + } + } + + BOOST_LOG_TRIVIAL(info) + << "Added " << _total_beam_writers << " beam writers to SkyCleaver"; +} +template +skyweaver::SkyCleaver::SkyCleaver( + SkyCleaverConfig& config) + : _config(config) +{ + BOOST_LOG_TRIVIAL(info) << "Reading and initialising beam details from file: " + << _config.targets_file(); + + parse_target_file(_config.targets_file(), _beam_infos); + + BOOST_LOG_TRIVIAL(info) << "Number of beams in target file: " << _beam_infos.size(); + + _timer.start("skycleaver::init_readers"); + init_readers(); + _timer.stop("skycleaver::init_readers"); + + + if(_beam_infos.size() < _config.nbeams()){ // there are some null beams with zeros, do not create filterbanks for them. + std::string required_beams = "0:" + std::to_string(_beam_infos.size()-1); + std::vector required_beam_numbers = skyweaver::get_list_from_string(required_beams); + + if(_config.required_beams().empty()) { // if nothing given, set the valid beams to required beams + BOOST_LOG_TRIVIAL(warning) << "Number of beams in target file is less than the number of beams in the header. 
" + << "Setting required beams to: " << required_beams; + _config.required_beams(skyweaver::get_list_from_string(required_beams)); + } + else{ + for(auto beam_num: _config.required_beams()){ // if given, check if all requested beams are valid beams + if(std::find(required_beam_numbers.begin(), required_beam_numbers.end(), beam_num) == required_beam_numbers.end()){ + std::runtime_error("Beam number " + std::to_string(beam_num) + " not found in target file."); + } + } + } + } + + + _timer.start("skycleaver::init_writers"); + init_writers(); + _timer.stop("skycleaver::init_writers"); +} +template +void skyweaver::SkyCleaver::read( + std::size_t gulp_samples) +{ + std::size_t gulp_size = gulp_samples * _config.ndms() * _config.nbeams() * + _config.stokes_mode().size(); + BOOST_LOG_TRIVIAL(info) << "Reading gulp samples: " << gulp_samples + << " with size: " << gulp_size; + + omp_set_num_threads(_nthreads_read); + + std::vector read_failures( + _available_freqs.size(), + false); // since we cannot throw exceptions in parallel regions +#pragma omp parallel for + for(std::size_t i = 0; i < _available_freqs.size(); i++) { + skyweaver::SkyCleaver::FreqType + freq = _available_freqs[i]; + if(_bridge_readers.find(freq) == _bridge_readers.end()) { + read_failures[i] = true; + } + const auto& reader = _bridge_readers[freq]; + if(reader->eof()) { + BOOST_LOG_TRIVIAL(warning) + << "End of file reached for bridge " << freq; + read_failures[i] = true; + } + + std::streamsize read_size = + reader->read(reinterpret_cast(thrust::raw_pointer_cast( + _bridge_data[freq]->data())), + gulp_size); // read a big chunk of data + BOOST_LOG_TRIVIAL(debug) + << "Read " << read_size << " bytes from bridge" << freq; + if(read_size < + gulp_size * sizeof(typename InputVectorType::value_type)) { + BOOST_LOG_TRIVIAL(warning) + << "Read less data than expected from bridge " << freq; + read_failures[i] = true; + } + } + bool failed = false; + for(int i = 0; i < read_failures.size(); i++) { + 
if(read_failures[i]) { + BOOST_LOG_TRIVIAL(error) + << "Reading bridge [" << i << "]: failed " << std::endl; + failed = true; + } + } + if(failed) { + std::runtime_error("Failed to read data from bridge readers."); + } + + BOOST_LOG_TRIVIAL(info) << "Read data from bridge readers"; +} + +template +void skyweaver::SkyCleaver::cleave() +{ + BOOST_LOG_NAMED_SCOPE("SkyCleaver::cleave") + + for(std::size_t nsamples_read = 0; nsamples_read < _nsamples_to_read; + nsamples_read += _config.nsamples_per_block()) { + std::size_t gulp_samples = + _nsamples_to_read - nsamples_read < _config.nsamples_per_block() + ? _nsamples_to_read - nsamples_read + : _config.nsamples_per_block(); + + BOOST_LOG_TRIVIAL(info) << "Cleaving samples: " << nsamples_read + << " to " << nsamples_read + gulp_samples; + + _timer.start("skyweaver::read_data"); + read(gulp_samples); + _timer.stop("skyweaver::read_data"); + + _timer.start("skyweaver::process_data"); + omp_set_num_threads(_config.nthreads()); + + std::size_t nbridges = _config.nbridges(); + std::size_t ndms = _config.ndms(); + std::size_t nbeams = _config.nbeams(); + std::size_t nstokes_out = _config.out_stokes().size(); + std::size_t nstokes_in = _config.stokes_mode().size(); + +#pragma omp parallel for schedule(static) collapse(3) + for(std::size_t istokes = 0; istokes < nstokes_out; istokes++) { + for(std::size_t ibeam = 0; ibeam < nbeams; ibeam++) { + for(std::size_t idm = 0; idm < ndms; + idm++) { // cannot separate loops, so do checks later + + if(_beam_data.find(istokes) == _beam_data.end()) { + continue; + } + if(_beam_data[istokes].find(idm) == + _beam_data[istokes].end()) { + continue; + } + if(_beam_data[istokes][idm].find(ibeam) == + _beam_data[istokes][idm].end()) { + continue; + } + + + const std::vector stokes_positions = + _config.stokes_positions()[istokes]; + +#pragma omp simd + for(std::size_t isample = 0; isample < gulp_samples; + isample++) { + const std::size_t out_offset = isample * nbridges; + + // This is stupid 
but preferred over a more elegant solution that is not fast, this can be easily vectorised + if(stokes_positions.size() == 1) { + const std::size_t base_index = + isample * ndms * nbeams * nstokes_in + + idm * nbeams * nstokes_in + ibeam * nstokes_in + + stokes_positions[0]; + + std::size_t ifreq = 0; + for(const auto& [freq, ifreq_data]: + _bridge_data) { // for each frequency + _beam_data[istokes][idm][ibeam]->at( + out_offset + nbridges - 1 - ifreq) = + clamp(127 + + ifreq_data->at(base_index)); + ++ifreq; + } + } + else{ + std::size_t ifreq = 0; + for(const auto& [freq, ifreq_data]: _bridge_data) { + float value = 0; + for(int stokes_position=0; stokes_positionat(base_index) * ifreq_data->at(base_index)); + } + // for each frequency + _beam_data[istokes][idm][ibeam]->at( + out_offset + nbridges - 1 - ifreq) = + clamp(127 + + sqrt(value)); + ++ifreq; + } + } + + } + + } + } + } + BOOST_LOG_TRIVIAL(info) << "Processed data"; + _timer.stop("skyweaver::process_data"); + _timer.start("skyweaver::write_data"); + write(); + _timer.stop("skyweaver::write_data"); + } + _timer.show_all_timings(); + +} + +template +void skyweaver::SkyCleaver::write() +{ + omp_set_num_threads(_config.nthreads()); +#pragma omp parallel for schedule(static) collapse(3) + for(std::size_t istokes = 0; istokes < _config.out_stokes().size(); + istokes++) { + for(std::size_t idm = 0; idm < _config.ndms(); idm++) { + for(std::size_t ibeam = 0; ibeam < _config.nbeams(); ibeam++) { + if(_beam_data.find(istokes) == _beam_data.end()) { + continue; + } + if(_beam_data[istokes].find(idm) == _beam_data[istokes].end()) { + continue; + } + if(_beam_data[istokes][idm].find(ibeam) == + _beam_data[istokes][idm].end()) { + continue; + } + _beam_writers[istokes][idm][ibeam]->write( + *_beam_data[istokes][idm][ibeam], + _config.stream_id()); + } + } + } +} diff --git a/cpp/skyweaver/detail/file_writer_callbacks.cpp b/cpp/skyweaver/detail/file_writer_callbacks.cpp index e560aac..ffb2c42 100644 --- 
a/cpp/skyweaver/detail/file_writer_callbacks.cpp +++ b/cpp/skyweaver/detail/file_writer_callbacks.cpp @@ -4,22 +4,23 @@ #include "skyweaver/MultiFileWriter.cuh" #include "skyweaver/ObservationHeader.hpp" #include "skyweaver/SigprocHeader.hpp" + #include -#include -#include -#include -#include #include +#include +#include +#include +#include #include +#include #include -#include -#include #include -#include - +#include +#include -namespace { - /** +namespace +{ +/** * The expectation is that each file contains data in * TBTF order. Need to explicitly update: * INNER_T --> the number of timesamples per block that was processed @@ -37,13 +38,10 @@ HEADER DADA HDR_VERSION 1.0 HDR_SIZE 4096 DADA_VERSION 1.0 - FILE_SIZE 100000000000 FILE_NUMBER 0 - UTC_START 1708082229.000020336 MJD_START 60356.47024305579093 - SOURCE J1644-4559 RA 16:44:49.27 DEC -45:59:09.7 @@ -53,29 +51,27 @@ RECEIVER L-band FREQ 1284000000.000000 BW 856000000.000000 TSAMP 4.7850467290 -STOKES I - NBIT 8 NDIM 1 NPOL 1 NCHAN 64 NBEAM 800 ORDER TFB - CHAN0_IDX 2688 - )"; +)"; #define MJD_UNIX_EPOCH 40587.0 -} +} // namespace namespace skyweaver { -namespace detail +namespace detail { -template inline -std::unique_ptr create_dada_file_stream(MultiFileWriterConfig const& config, - ObservationHeader const& header, - VectorType const& stream_data, - std::size_t stream_idx) +template +inline std::unique_ptr +create_dada_file_stream(MultiFileWriterConfig const& config, + ObservationHeader const& header, + VectorType const& stream_data, + std::size_t stream_idx) { BOOST_LOG_TRIVIAL(debug) << "Creating stream based on stream prototype: " << stream_data.describe(); @@ -85,113 +81,120 @@ std::unique_ptr create_dada_file_stream(MultiFileWriterConfig config.max_file_size / stream_data.size() / sizeof(typename VectorType::value_type)) * stream_data.size(); - + BOOST_LOG_TRIVIAL(debug) << "Maximum allowed file size = " << filesize << " bytes (+header)"; std::stringstream output_dir; - output_dir << config.output_dir 
<< "/" - << std::fixed << std::setfill('0') << std::setw(9) - << static_cast(header.frequency); - - std::stringstream output_basename; - output_basename << config.output_basename << "_" - << std::fixed << std::setfill('0') << std::setw(9) - << static_cast(header.frequency); - - std::unique_ptr file_stream = std::make_unique( - output_dir.str(), - output_basename.str(), - config.extension, - filesize, - [&, header, stream_data, stream_idx, filesize]( - std::size_t& header_size, - std::size_t bytes_written, - std::size_t file_idx) -> std::shared_ptr { - header_size = config.header_size; - char* temp_header = new char[header_size]; - std::fill(temp_header, temp_header + header_size, 0); - std::memcpy(temp_header, - default_dada_header.c_str(), - default_dada_header.size()); - psrdada_cpp::RawBytes bytes(temp_header, - header_size, - header_size, - false); - Header header_writer(bytes); - header_writer.set("SOURCE", header.source_name); - header_writer.set("RA", header.ra); - header_writer.set("DEC", header.dec); - header_writer.set("NBEAM", stream_data.nbeams()); - header_writer.set("NCHAN", stream_data.nchannels()); - header_writer.set("OBS_NCHAN", header.obs_nchans); - header_writer.set("OBS_FREQUENCY", - header.obs_frequency); - header_writer.set("OBS_BW", header.obs_bandwidth); - header_writer.set("NSAMP", stream_data.nsamples()); - if(stream_data.ndms()) { - header_writer.set("NDMS", stream_data.ndms()); - header_writer.set("DMS", stream_data.dms(), 7); - } - header_writer.set( - "COHERENT_DM", - static_cast(stream_data.reference_dm())); - try{ - header_writer.set("FREQ", std::accumulate( - stream_data.frequencies().begin(), - stream_data.frequencies().end(), - 0.0)/stream_data.frequencies().size()); - } catch(std::runtime_error& ){ - BOOST_LOG_TRIVIAL(warning) << "Warning: Frequencies array was stale, using the centre frequency from the header"; - header_writer.set("FREQ", header.frequency); - } - - header_writer.set("BW", header.bandwidth); - 
header_writer.set("TSAMP", stream_data.tsamp() * 1e6); - if(config.stokes_mode == "IQUV") { - header_writer.set("NPOL", 4); - } else { - header_writer.set("NPOL", 1); - } - header_writer.set("STOKES_MODE", - config.stokes_mode); - header_writer.set("ORDER", - stream_data.dims_as_string()); - header_writer.set("CHAN0_IDX", header.chan0_idx); - header_writer.set("FILE_SIZE", filesize); - header_writer.set("FILE_NUMBER", file_idx); - header_writer.set("OBS_OFFSET", bytes_written); - header_writer.set("OBS_OVERLAP", 0); - - double tstart = header.utc_start + stream_data.utc_offset(); - - header_writer.set("UTC_START", tstart); - header_writer.set("MJD_START", MJD_UNIX_EPOCH + tstart / 86400.0); - std::shared_ptr header_ptr( - temp_header, - std::default_delete()); - - return header_ptr; - }); + output_dir << config.output_dir << "/" << std::fixed << std::setfill('0') + << std::setw(9) << static_cast(header.frequency); + + std::stringstream output_basename; + output_basename << config.output_basename << "_" << std::fixed + << std::setfill('0') << std::setw(9) + << static_cast(header.frequency); + + std::unique_ptr file_stream = + std::make_unique( + output_dir.str(), + output_basename.str(), + config.extension, + filesize, + [&, header, stream_data, stream_idx, filesize]( + std::size_t& header_size, + std::size_t bytes_written, + std::size_t file_idx) -> std::shared_ptr { + header_size = config.header_size; + char* temp_header = new char[header_size]; + std::fill(temp_header, temp_header + header_size, 0); + std::memcpy(temp_header, + default_dada_header.c_str(), + default_dada_header.size()); + psrdada_cpp::RawBytes bytes(temp_header, + header_size, + header_size, + false); + Header header_writer(bytes); + header_writer.set("SOURCE", header.source_name); + header_writer.set("RA", header.ra); + header_writer.set("DEC", header.dec); + header_writer.set("NBEAM", stream_data.nbeams()); + header_writer.set("OBS_NCHAN", header.obs_nchans); + header_writer.set("NCHAN", + 
stream_data.nchannels()); + header_writer.set("OBS_FREQUENCY", + header.obs_frequency); + header_writer.set("OBS_BW", header.obs_bandwidth); + header_writer.set("NSAMP", stream_data.nsamples()); + if(stream_data.ndms()) { + header_writer.set("NDMS", stream_data.ndms()); + header_writer.set("DMS", stream_data.dms(), 7); + } + header_writer.set( + "COHERENT_DM", + static_cast(stream_data.reference_dm())); + try { + header_writer.set( + "FREQ", + std::accumulate(stream_data.frequencies().begin(), + stream_data.frequencies().end(), + 0.0) / + stream_data.frequencies().size()); + } catch(std::runtime_error&) { + BOOST_LOG_TRIVIAL(warning) + << "Warning: Frequencies array was stale, using the " + "centre frequency from the header"; + header_writer.set("FREQ", header.frequency); + } + + header_writer.set("BW", header.bandwidth); + header_writer.set("TSAMP", + stream_data.tsamp() * 1e6); + if(config.stokes_mode == "IQUV") { + header_writer.set("NPOL", 4); + } else { + header_writer.set("NPOL", 1); + } + header_writer.set("STOKES_MODE", + config.stokes_mode); + header_writer.set("ORDER", + stream_data.dims_as_string()); + header_writer.set("CHAN0_IDX", header.chan0_idx); + header_writer.set("FILE_SIZE", filesize); + header_writer.set("FILE_NUMBER", file_idx); + header_writer.set("OBS_OFFSET", bytes_written); + header_writer.set("OBS_OVERLAP", 0); + + long double tstart = + header.utc_start + stream_data.utc_offset(); + + header_writer.set("UTC_START", tstart); + header_writer.set("MJD_START", + MJD_UNIX_EPOCH + + tstart / 86400.0); + std::shared_ptr header_ptr( + temp_header, + std::default_delete()); + + return header_ptr; + }); return file_stream; } -template inline -std::unique_ptr create_sigproc_file_stream(MultiFileWriterConfig const& config, - ObservationHeader const& obs_header, - VectorType const& stream_data, - std::size_t stream_idx) +template +inline std::unique_ptr +create_sigproc_file_stream(MultiFileWriterConfig const& config, + ObservationHeader const& 
obs_header, + VectorType const& stream_data, + std::size_t stream_idx) { - - - BOOST_LOG_TRIVIAL(debug) << "Creating stream based on stream prototype: " << stream_data.describe(); ObservationHeader header = obs_header; BOOST_LOG_TRIVIAL(info) << "Header: " << header.to_string(); - + // Here we round the file size to a multiple of the stream prototype std::size_t filesize = std::max(1ul, @@ -199,36 +202,40 @@ std::unique_ptr create_sigproc_file_stream(MultiFileWriterCon sizeof(typename VectorType::value_type)) * stream_data.size(); BOOST_LOG_TRIVIAL(debug) - << "Maximum allowed file size = " << filesize << " bytes (+header)"; + << "Maximum allowed file size = " << filesize << " bytes (+header)"; - - double foff = -1* static_cast(header.obs_bandwidth / header.nchans)/1e6;// MHz - // 1* foff instead of 0.5* foff below because the dedispersion causes all the frequencies to change by half the bandwidth to refer to the bottom of the channel - double fch1 = static_cast(header.obs_frequency + header.obs_bandwidth / 2.0)/1e6 + foff; // MHz + double foff = -1 * + static_cast(header.obs_bandwidth / header.nchans) / + 1e6; // MHz + // 1* foff instead of 0.5* foff below because the dedispersion causes all + // the frequencies to change by half the bandwidth to refer to the bottom of + // the channel + double fch1 = + static_cast(header.obs_frequency + header.obs_bandwidth / 2.0) / + 1e6 + + foff; // MHz double utc_start = static_cast(header.utc_start); header.mjd_start = (utc_start / 86400.0) + MJD_UNIX_EPOCH; - uint32_t datatype = 0; + uint32_t datatype = 0; uint32_t barycentric = 0; // uint32_t ibeam = 0; - double az = 0.0; - double za = 0.0; - uint32_t nifs = header.npol; + double az = 0.0; + double za = 0.0; + uint32_t nifs = 1; header.sigproc_params = true; - header.rawfile = std::string("unset"); - header.fch1 = fch1; - header.foff = foff; - header.tsamp = header.tsamp/1e6; - header.az = az; - header.za = za; - header.datatype = datatype; - header.barycentric = 
barycentric; - header.nifs = nifs; - header.telescopeid = 64; - - + header.rawfile = std::string("unset"); + header.fch1 = fch1; + header.foff = foff; + header.tsamp = header.tsamp / 1e6; + header.az = az; + header.za = za; + header.datatype = datatype; + header.barycentric = barycentric; + header.nifs = nifs; + header.telescopeid = 64; BOOST_LOG_TRIVIAL(info) << "Creating Sigproc file stream"; @@ -246,42 +253,41 @@ std::unique_ptr create_sigproc_file_stream(MultiFileWriterCon // // causing compiler bugs prior to g++ 5.x // char formatted_time[80]; // strftime(formatted_time, 80, "%Y-%m-%d-%H:%M:%S", ptm); - // base_filename << formatted_time; - - //make config.output_dir if it does not exist - - - std::unique_ptr file_stream = std::make_unique( - config.output_dir, - config.output_basename, - config.extension, - filesize, - [header](std::size_t& header_size, - std::size_t bytes_written, - std::size_t file_idx) -> std::shared_ptr { - // We do not explicitly delete[] this array - // Cleanup is handled by the shared pointer - // created below - std::ostringstream header_stream; - // get ostream from temp_header - - - - SigprocHeader sigproc_header(header); - double mjd_offset = (((bytes_written / (header.nbits / 8.0)) / (header.nchans)) * - header.tsamp) / - (86400.0); - sigproc_header.add_time_offset(mjd_offset); - sigproc_header.write_header(header_stream); - std::string header_str = header_stream.str(); - header_size = header_str.size(); - char* header_cstr = new char[header_size]; - std::copy(header_str.begin(), header_str.end(), header_cstr); - std::shared_ptr header_ptr( - header_cstr, - std::default_delete()); - return header_ptr; - }); + // base_filename << formatted_time; + + // make config.output_dir if it does not exist + + std::unique_ptr file_stream = + std::make_unique( + config.output_dir, + config.output_basename, + config.extension, + filesize, + [header](std::size_t& header_size, + std::size_t bytes_written, + std::size_t file_idx) -> 
std::shared_ptr { + // We do not explicitly delete[] this array + // Cleanup is handled by the shared pointer + // created below + std::ostringstream header_stream; + // get ostream from temp_header + + SigprocHeader sigproc_header(header); + double mjd_offset = (((bytes_written / (header.nbits / 8.0)) / + (header.nchans)) * + header.tsamp) / + (86400.0); + sigproc_header.add_time_offset(mjd_offset); + sigproc_header.write_header(header_stream); + std::string header_str = header_stream.str(); + header_size = header_str.size(); + char* header_cstr = new char[header_size]; + std::copy(header_str.begin(), header_str.end(), header_cstr); + std::shared_ptr header_ptr( + header_cstr, + std::default_delete()); + return header_ptr; + }); return file_stream; } diff --git a/cpp/skyweaver/skycleaver_utils.hpp b/cpp/skyweaver/skycleaver_utils.hpp new file mode 100644 index 0000000..d0ad9e9 --- /dev/null +++ b/cpp/skyweaver/skycleaver_utils.hpp @@ -0,0 +1,53 @@ +#ifndef SKYWEAVER_SKYCLEAVER_UTILS_HPP +#define SKYWEAVER_SKYCLEAVER_UTILS_HPP +namespace skyweaver +{ + +template +std::vector +get_list_from_string(const std::string& value, + T epsilon = std::numeric_limits::epsilon()) +{ + std::vector output; + std::vector comma_chunks; + + // Split the input string by commas + std::stringstream ss(value); + std::string token; + while(std::getline(ss, token, ',')) { comma_chunks.push_back(token); } + + for(const auto& comma_chunk: comma_chunks) { + // Check if the chunk contains a colon (indicating a range) + if(comma_chunk.find(':') == std::string::npos) { + output.push_back(static_cast(std::atof(comma_chunk.c_str()))); + continue; + } + + // Split the range chunk by colons + std::stringstream ss_chunk(comma_chunk); + std::vector colon_chunks; + std::string colon_token; + while(std::getline(ss_chunk, colon_token, ':')) { + colon_chunks.push_back( + static_cast(std::atof(colon_token.c_str()))); + } + + // Determine the step size + T step = colon_chunks.size() == 3 ? 
colon_chunks[2] : static_cast(1); + T start = colon_chunks[0]; + T stop = colon_chunks[1]; + + // Loop and add values to the output vector + if constexpr(std::is_floating_point::value) { + for(T k = start; k <= stop + epsilon; k += step) { + output.push_back(k); + } + } else { + for(T k = start; k <= stop; k += step) { output.push_back(k); } + } + } + return output; +} + +} +#endif // SKYWEAVER_SKYCLEAVER_UTILS_HPP \ No newline at end of file diff --git a/cpp/skyweaver/src/CoherentDedisperser.cu b/cpp/skyweaver/src/CoherentDedisperser.cu index 52e4f41..94b0d31 100644 --- a/cpp/skyweaver/src/CoherentDedisperser.cu +++ b/cpp/skyweaver/src/CoherentDedisperser.cu @@ -13,22 +13,64 @@ #include #include #include + +namespace +{ + +#define NCHANS_PER_BLOCK 128 + +// Function to check if a number's prime factors are only 2, 3, 5, or +bool has_small_prime_factors(size_t n) +{ + if(n == 0) + return false; + while(n % 2 == 0) n /= 2; + while(n % 3 == 0) n /= 3; + while(n % 5 == 0) n /= 5; + while(n % 7 == 0) n /= 7; + return n == 1; +} + +// Function to find the next optimal FFT size +size_t next_optimal_fft_size(size_t N) +{ + size_t n = N; + while(!has_small_prime_factors(n)) { n++; } + return n; +} + +// Function to compute padding amount +size_t compute_padding(size_t N) +{ + size_t optimal_size = 0; + size_t n = N - 1; + + do { + n = n + 1; + optimal_size = next_optimal_fft_size(n); + BOOST_LOG_TRIVIAL(info) << "Trying optimal size: " << optimal_size; + } while(optimal_size % NCHANS_PER_BLOCK != 0); + return optimal_size - N; +} + +} // namespace + namespace skyweaver { void create_coherent_dedisperser_config(CoherentDedisperserConfig& config, PipelineConfig const& pipeline_config) { - // the centre frequency and bandwidth are for the bridge. This is taken from Observation Header (not from the user) + // the centre frequency and bandwidth are for the bridge. 
This is taken from + // Observation Header (not from the user) float f_low = pipeline_config.centre_frequency() - pipeline_config.bandwidth() / 2.0f; - float f_high = f_low + pipeline_config.bandwidth()/pipeline_config.nchans(); - - // pipeline_config.centre_frequency() + pipeline_config.bandwidth() / 2.0f; - float tsamp = pipeline_config.nchans() / pipeline_config.bandwidth(); + float f_high = + f_low + pipeline_config.bandwidth() / pipeline_config.nchans(); - + // pipeline_config.centre_frequency() + pipeline_config.bandwidth() / 2.0f; + float tsamp = pipeline_config.nchans() / pipeline_config.bandwidth(); if(pipeline_config.coherent_dms().empty()) { throw std::runtime_error("No coherent DMs specified"); @@ -38,14 +80,15 @@ void create_coherent_dedisperser_config(CoherentDedisperserConfig& config, pipeline_config.coherent_dms().end()); float max_dm = *it; BOOST_LOG_TRIVIAL(debug) << "Constructing coherent dedisperser plan"; - std::size_t max_dm_delay_samps = DMSampleDelay(max_dm, f_low, tsamp)(f_high); + std::size_t max_dm_delay_samps = + DMSampleDelay(max_dm, f_low, tsamp)(f_high); if(max_dm_delay_samps > 2 * pipeline_config.gulp_length_samps()) { throw std::runtime_error( "Gulp length must be at least 2 times the maximum DM delay"); } - if(max_dm_delay_samps %2 !=0) { + if((pipeline_config.gulp_length_samps() + max_dm_delay_samps) % 2 != 0) { max_dm_delay_samps++; } @@ -61,7 +104,7 @@ void create_coherent_dedisperser_config(CoherentDedisperserConfig& config, pipeline_config.coherent_dms()); } /* - * @brief Create a new CoherentDedis§rser object, mostly used only for + * @brief Create a new CoherentDedisperser object, mostly used only for * testing * * @param config The config reference @@ -78,20 +121,27 @@ void create_coherent_dedisperser_config(CoherentDedisperserConfig& config, std::vector dms) { config.gulp_samps = gulp_samps; - config.overlap_samps = overlap_samps; config.num_coarse_chans = num_coarse_chans; config.npols = npols; config.nantennas = 
nantennas; config.tsamp = tsamp; - config.low_freq = low_freq; - config.bw = bw; - config.high_freq = low_freq + bw; - config.coarse_chan_bw = bw / num_coarse_chans; - config.filter_delay = tsamp * overlap_samps / 2.0; - BOOST_LOG_TRIVIAL(warning) << "tsamp in create_coherent_dedisperser_config: " << config.tsamp; - BOOST_LOG_TRIVIAL(warning) << "overlap_samps in create_coherent_dedisperser_config: " << config.overlap_samps; - BOOST_LOG_TRIVIAL(warning) << "Filter delay: " << config.filter_delay; + config.low_freq = low_freq; + config.bw = bw; + config.high_freq = low_freq + bw; + config.coarse_chan_bw = bw / num_coarse_chans; + config.filter_delay = tsamp * overlap_samps / 2.0; + BOOST_LOG_TRIVIAL(debug) + << "tsamp in create_coherent_dedisperser_config: " << config.tsamp; + BOOST_LOG_TRIVIAL(info) + << "overlap_samps just due to dedispersion: " << overlap_samps; + BOOST_LOG_TRIVIAL(info) << "Filter delay: " << config.filter_delay; + + config.overlap_samps = + compute_padding(gulp_samps + overlap_samps) + overlap_samps; + BOOST_LOG_TRIVIAL(info) + << "overlap_samps after padding for optimal FFT size " + << config.overlap_samps; /* Precompute DM constants */ config._h_dms = dms; @@ -99,7 +149,8 @@ void create_coherent_dedisperser_config(CoherentDedisperserConfig& config, config._d_dm_prefactor.resize(dms.size()); config._d_ism_responses.resize(dms.size()); for(int i = 0; i < dms.size(); i++) { - config._d_ism_responses[i].resize(num_coarse_chans * gulp_samps); + config._d_ism_responses[i].resize( + num_coarse_chans * (config.gulp_samps + config.overlap_samps)); } thrust::transform(config._d_dms.begin(), @@ -118,16 +169,16 @@ void create_coherent_dedisperser_config(CoherentDedisperserConfig& config, // data is FTPA order, we will loop over F, so we are left with TPA order. // Let's fuse PA to X, so TX order. 
// We stride and batch over X and transform T - std::size_t X = config.npols * config.nantennas; - std::size_t fft_size = config.gulp_samps + config.overlap_samps; - int n[1] = {static_cast(fft_size)}; // FFT size - int inembed[1] = {static_cast(fft_size)}; - int onembed[1] = {static_cast(fft_size)}; - int istride = X; - int ostride = X; - int idist = 1; - int odist = 1; - int batch = X; + std::size_t X = config.npols * config.nantennas; + std::size_t fft_size = config.gulp_samps + config.overlap_samps; + int n[1] = {static_cast(fft_size)}; // FFT size + int inembed[1] = {static_cast(fft_size)}; + int onembed[1] = {static_cast(fft_size)}; + int istride = X; + int ostride = X; + int idist = 1; + int odist = 1; + int batch = X; if(cufftPlanMany(&config._fft_plan, 1, @@ -146,12 +197,6 @@ void create_coherent_dedisperser_config(CoherentDedisperserConfig& config, BOOST_LOG_TRIVIAL(debug) << "FFT plan created"; } - -namespace -{ -#define NCHANS_PER_BLOCK 128 -} // namespace - void CoherentDedisperser::dedisperse( TPAVoltagesD const& d_tpa_voltages_in, FTPAVoltagesD& d_ftpa_voltages_out, @@ -159,6 +204,7 @@ void CoherentDedisperser::dedisperse( unsigned int dm_idx) { BOOST_LOG_NAMED_SCOPE("CoherentDedisperser::dedisperse"); + // d_tpa_voltages_in.size() is with overlap _d_fpa_spectra.resize(d_tpa_voltages_in.size(), {0.0f, 0.0f}); _d_tpa_voltages_in_cufft.resize(d_tpa_voltages_in.size(), {0.0f, 0.0f}); _d_tpa_voltages_dedispersed.resize(d_tpa_voltages_in.size(), {0.0f, 0.0f}); @@ -206,8 +252,8 @@ void CoherentDedisperser::dedisperse( BOOST_LOG_TRIVIAL(debug) << "Executed inverse FFT"; - std::size_t out_offset = freq_idx * _config.nantennas * _config.npols * - (_config.gulp_samps); + std::size_t out_offset = + freq_idx * _config.nantennas * _config.npols * (_config.gulp_samps); std::size_t discard_size = _config.nantennas * _config.npols * _config.overlap_samps / 2; @@ -216,26 +262,30 @@ void CoherentDedisperser::dedisperse( BOOST_LOG_TRIVIAL(debug) << "copying from input 
from " << discard_size << " to " << _d_tpa_voltages_dedispersed.size() - discard_size; + BOOST_LOG_TRIVIAL(debug) + << "Total elements copied: " + << _d_tpa_voltages_dedispersed.size() - 2 * discard_size; + BOOST_LOG_TRIVIAL(debug) << "Remaining space in output: " + << d_ftpa_voltages_out.size() - out_offset; BOOST_LOG_TRIVIAL(debug) << "copying to output from " << out_offset << " to " << out_offset + _d_tpa_voltages_dedispersed.size() - 2 * discard_size; - - std::size_t fft_size = _config.gulp_samps + _config.overlap_samps; - - - // transform: divide by d_tpa_voltages_in.size() - thrust::transform(_d_tpa_voltages_dedispersed.begin() + discard_size, - _d_tpa_voltages_dedispersed.end() - discard_size, - d_ftpa_voltages_out.begin() + out_offset, - [=] __device__(cufftComplex const& val) { - char2 char2_val; - char2_val.x = static_cast( - __float2int_rn(val.x / fft_size)); // scale the data back - char2_val.y = - static_cast(__float2int_rn(val.y / fft_size)); - return char2_val; - }); + std::size_t fft_size = _config.gulp_samps + _config.overlap_samps; + + thrust::transform( + _d_tpa_voltages_dedispersed.begin() + discard_size, + _d_tpa_voltages_dedispersed.end() - discard_size, + d_ftpa_voltages_out.begin() + out_offset, + [=] __device__(cufftComplex const& val) { + char2 char2_val; + char2_val.x = static_cast( + __float2int_rn(val.x / fft_size)); // scale the data back + char2_val.y = static_cast(__float2int_rn(val.y / fft_size)); + return char2_val; + }); + + BOOST_LOG_TRIVIAL(debug) << "Transformed dedispersed voltages to char2"; d_ftpa_voltages_out.reference_dm(_config._h_dms[dm_idx]); } @@ -246,15 +296,25 @@ void CoherentDedisperser::multiply_by_chirp( unsigned int freq_idx, unsigned int dm_idx) { - std::size_t total_chans = _config._d_ism_responses[dm_idx].size(); - std::size_t response_offset = freq_idx * _config.gulp_samps; + std::size_t total_chans = + _config._d_ism_responses[dm_idx].size(); // all coarse + fine chans + std::size_t fft_size = + 
_config.gulp_samps + _config.overlap_samps; // ONLY FINE CHANS + std::size_t response_offset = freq_idx * fft_size; BOOST_LOG_TRIVIAL(debug) << "Freq idx: " << freq_idx; BOOST_LOG_TRIVIAL(debug) << "_config.gulp_samps: " << _config.gulp_samps; + BOOST_LOG_TRIVIAL(debug) + << "_config.overlap_samps: " << _config.overlap_samps; BOOST_LOG_TRIVIAL(debug) << "response_offset: " << response_offset; + BOOST_LOG_TRIVIAL(debug) << "total_chans: " << total_chans; + BOOST_LOG_TRIVIAL(debug) + << "chirp multiply input size: " << _d_fpa_spectra_in.size(); + BOOST_LOG_TRIVIAL(debug) + << "chirp multiply output size: " << _d_fpa_spectra_out.size(); dim3 blockSize(_config.nantennas * _config.npols); - dim3 gridSize(_config.gulp_samps / NCHANS_PER_BLOCK); + dim3 gridSize(fft_size / NCHANS_PER_BLOCK); kernels::dedisperse<<>>( thrust::raw_pointer_cast(_config._d_ism_responses[dm_idx].data() + response_offset), @@ -273,14 +333,16 @@ __global__ void dedisperse(cufftComplex const* __restrict__ _d_ism_response, cufftComplex* out, unsigned total_chans) { - const unsigned pa_size = blockDim.x; + const unsigned pa_size = blockDim.x; // NANT * NPOL volatile __shared__ cufftComplex response[NCHANS_PER_BLOCK]; - const unsigned block_start_chan_idx = blockIdx.x * NCHANS_PER_BLOCK; + const unsigned block_start_chan_idx = + blockIdx.x * NCHANS_PER_BLOCK; // coarse chan idx const unsigned remainder = - min(total_chans - block_start_chan_idx, NCHANS_PER_BLOCK); + min(total_chans - block_start_chan_idx, + NCHANS_PER_BLOCK); // how many channels to process for(int idx = threadIdx.x; idx < remainder; idx += pa_size) { cufftComplex const temp = _d_ism_response[block_start_chan_idx + idx]; @@ -330,8 +392,7 @@ struct DMResponse { int chan = tid / num_fine_chans; // Coarse channel int fine_chan = tid % num_fine_chans; // fine channel - double nu_0 = low_freq + chan * coarse_chan_bw - - 0.5f * coarse_chan_bw; // + fine_chan * fine_chan_bw; + double nu_0 = low_freq + chan * coarse_chan_bw; double nu = 
fine_chan * fine_chan_bw; // fine_chan_freq @@ -341,7 +402,7 @@ struct DMResponse { cufftDoubleComplex weight; sincos(phase, &weight.y, - &weight.x); // TO DO: test if it is not approximate + &weight.x); // TODO: test if it is not approximate cufftComplex float_weight; float_weight.x = static_cast(weight.x); float_weight.y = static_cast(weight.y); @@ -357,20 +418,39 @@ void get_dm_responses(CoherentDedisperserConfig& config, thrust::device_vector& response) { BOOST_LOG_TRIVIAL(debug) << "Generating DM responses"; - thrust::device_vector indices(config.num_coarse_chans * - config.gulp_samps); + std::size_t fft_size = config.gulp_samps + config.overlap_samps; + + thrust::device_vector indices(config.num_coarse_chans * fft_size); thrust::sequence(indices.begin(), indices.end()); + // store raw responses in a temporary variable + thrust::device_vector temp_response(response.size()); + BOOST_LOG_TRIVIAL(warning) << "DOING FFTSHIFT" << std::endl; // Apply the DMResponse functor using thrust's transform thrust::transform(indices.begin(), indices.end(), - response.begin(), + temp_response.begin(), kernels::DMResponse(config.num_coarse_chans, - config.gulp_samps, + fft_size, config.low_freq, config.coarse_chan_bw, config.fine_chan_bw, dm_prefactor)); + + // rotate the response to match cufft output + + if(fft_size %2 != 0) { + throw std::runtime_error("FFT size must be even."); + } + + std::size_t shift = fft_size / 2; + + for(auto it = temp_response.begin(), it_out = response.begin(); + it < temp_response.end(); + it += fft_size, it_out += fft_size) { + thrust::copy(it, it + shift, it_out + shift); + thrust::copy(it + shift, it + fft_size, it_out); + } } } // namespace skyweaver diff --git a/cpp/skyweaver/src/FileOutputStream.cpp b/cpp/skyweaver/src/FileOutputStream.cpp index f4d2ffa..40ca0ad 100644 --- a/cpp/skyweaver/src/FileOutputStream.cpp +++ b/cpp/skyweaver/src/FileOutputStream.cpp @@ -38,15 +38,16 @@ void create_directories(const fs::path& path) } 
FileOutputStream::File::File(std::string const& fname, std::size_t bytes) - : _full_path(fname), _bytes_requested(bytes), _bytes_written(0) + : _full_path(fname), _bytes_requested(bytes), _bytes_written(0), _temporary_suffix(".tmp") { + _temporary_path = _full_path + _temporary_suffix; _stream.exceptions(std::ofstream::failbit | std::ofstream::badbit); - _stream.open(_full_path, std::ofstream::out | std::ofstream::binary); + _stream.open(_temporary_path, std::ofstream::out | std::ofstream::binary); if(_stream.is_open()) { - BOOST_LOG_TRIVIAL(info) << "Opened output file " << _full_path; + BOOST_LOG_TRIVIAL(info) << "Opened output file " << _temporary_path; } else { std::stringstream error_message; - error_message << "Could not open file " << _full_path; + error_message << "Could not open file " << _temporary_path; BOOST_LOG_TRIVIAL(error) << error_message.str(); throw std::runtime_error(error_message.str()); } @@ -55,15 +56,29 @@ FileOutputStream::File::File(std::string const& fname, std::size_t bytes) FileOutputStream::File::~File() { if(_stream.is_open()) { - BOOST_LOG_TRIVIAL(info) << "Closing file " << _full_path; + BOOST_LOG_TRIVIAL(info) << "Closing file " << _temporary_path; _stream.close(); } + BOOST_LOG_TRIVIAL(info) << "Renaming file " << _temporary_path + << " to " << _full_path; + int res = std::rename(_temporary_path.c_str(), _full_path.c_str()); + + // Can't throw an exception from a destructor, + // but we'll at least put a warning in the log! 
+ if (res) + { + std::stringstream error_message; + error_message << "Error renaming file " << _temporary_path + << " to " << _full_path + << " (" << res << ")."; + BOOST_LOG_TRIVIAL(error) << error_message.str(); + } } std::size_t FileOutputStream::File::write(char const* ptr, std::size_t bytes) { BOOST_LOG_TRIVIAL(debug) - << "Writing " << bytes << " bytes to " << _full_path; + << "Writing " << bytes << " bytes to " << _temporary_path; std::size_t bytes_remaining = _bytes_requested - _bytes_written; try { if(bytes > bytes_remaining) { @@ -89,7 +104,7 @@ std::size_t FileOutputStream::File::write(char const* ptr, std::size_t bytes) reason = "eofbit set."; } - BOOST_LOG_TRIVIAL(error) << "Error while writing to " << _full_path + BOOST_LOG_TRIVIAL(error) << "Error while writing to " << _temporary_path << " (" << e.what() << ") because of reason: " << reason; throw; diff --git a/cpp/skyweaver/src/ObservationHeader.cpp b/cpp/skyweaver/src/ObservationHeader.cpp index f554a59..e17c282 100644 --- a/cpp/skyweaver/src/ObservationHeader.cpp +++ b/cpp/skyweaver/src/ObservationHeader.cpp @@ -10,7 +10,7 @@ std::vector parse_float_list(std::string const& str) { std::vector values; std::size_t start = 0; - std::size_t end = 0; + std::size_t end = 0; while(end != std::string::npos) { end = str.find(',', start); values.push_back(std::stof(str.substr(start, end - start))); @@ -25,7 +25,7 @@ void read_dada_header(psrdada_cpp::RawBytes& raw_header, Header parser(raw_header); header.order = parser.get("ORDER"); - if(header.order.find("A") != std::string::npos) { + if(header.order.find("A") != std::string::npos) { header.nantennas = parser.get("NANT"); header.sample_clock = parser.get("SAMPLE_CLOCK"); @@ -35,16 +35,16 @@ void read_dada_header(psrdada_cpp::RawBytes& raw_header, header.sync_time = parser.get("SYNC_TIME"); } - - header.refdm = parser.get_or_default("COHERENT_DM", 0.0); + header.refdm = + parser.get_or_default("COHERENT_DM", 0.0); - header.npol = parser.get("NPOL"); - 
header.nbits = parser.get("NBIT"); - header.nchans = parser.get("NCHAN"); + header.npol = parser.get("NPOL"); + header.nbits = parser.get("NBIT"); + header.nchans = parser.get("NCHAN"); - header.bandwidth = parser.get("BW"); - header.frequency = parser.get("FREQ"); + header.bandwidth = parser.get("BW"); + header.frequency = parser.get("FREQ"); header.tsamp = parser.get("TSAMP"); @@ -58,18 +58,24 @@ void read_dada_header(psrdada_cpp::RawBytes& raw_header, header.chan0_idx = parser.get("CHAN0_IDX"); header.obs_offset = parser.get("OBS_OFFSET"); - header.obs_bandwidth = parser.get_or_default("OBS_BW", header.bandwidth); - header.obs_nchans = parser.get_or_default("OBS_NCHAN", header.nchans); - header.obs_frequency = parser.get_or_default("OBS_FREQUENCY", - parser.get_or_default("OBS_FREQ", - header.frequency)); - - header.ndms = parser.get_or_default("NDMS", "0"); - if(header.ndms != "0") { + header.obs_bandwidth = + parser.get_or_default("OBS_BW", 856e6); + header.obs_nchans = + parser.get_or_default("OBS_NCHAN", 4096); + header.obs_frequency = + parser.get_or_default( + "OBS_FREQUENCY", + parser.get_or_default( + "OBS_FREQ", + header.frequency)); + + header.ndms = parser.get_or_default("NDMS", 0); + header.nbeams = parser.get_or_default("NBEAM", 1); + header.stokes_mode = + parser.get_or_default("STOKES_MODE", "I"); + if(header.ndms != 0) { header.dms = parse_float_list(parser.get("DMS")); } - - } void validate_header(ObservationHeader const& header, PipelineConfig const& config) @@ -104,7 +110,6 @@ void update_config(PipelineConfig& config, ObservationHeader const& header) // TO DO: might need to add other variables in the future. 
} - bool are_headers_similar(ObservationHeader const& header1, ObservationHeader const& header2) { @@ -150,17 +155,13 @@ std::string ObservationHeader::to_string() const << " instrument: " << instrument << "\n" << " chan0_idx: " << chan0_idx << "\n" << " obs_offset: " << obs_offset << "\n"; - if(ndms != "0") { + if(ndms != 0) { oss << " ndms: " << ndms << "\n"; oss << " dms: "; - for(auto dm : dms) { - oss << dm << " "; - } + for(auto dm: dms) { oss << dm << " "; } oss << "\n"; } - if(sigproc_params) - { - + if(sigproc_params) { oss << " Sigproc parameters:\n" << " az: " << az << "\n" << " za: " << za << "\n" diff --git a/cpp/skyweaver/src/PipelineConfig.cpp b/cpp/skyweaver/src/PipelineConfig.cpp index c2cdf3b..79fafb1 100644 --- a/cpp/skyweaver/src/PipelineConfig.cpp +++ b/cpp/skyweaver/src/PipelineConfig.cpp @@ -15,7 +15,8 @@ PipelineConfig::PipelineConfig() _bw(13375000.0), _channel_frequencies_stale(true), _gulp_length_samps(4096), _start_time(0.0f), _duration(std::numeric_limits::infinity()), _total_nchans(4096), - _stokes_mode("I"), _output_level(24.0f) + _stokes_mode("I"), _output_level(24.0f), _output_statistics(true), _output_incoherent_beam(true), + _pre_write_config({0, {false, 0, 0}}) { } @@ -157,7 +158,64 @@ DedispersionPlan const& PipelineConfig::ddplan() const return _ddplan; } - +std::size_t PipelineConfig::convertMemorySize(const std::string& str) const { + std::size_t lastCharPos = str.find_last_not_of("0123456789"); + std::string numberPart = str.substr(0, lastCharPos); + std::string unitPart = str.substr(lastCharPos); + + std::size_t number = std::stoull(numberPart); + + if (unitPart.empty()) + return number; + else if (unitPart == "K" || unitPart == "k") + return number * 1024; + else if (unitPart == "M" || unitPart == "m") + return number * 1024 * 1024; + else if (unitPart == "G" || unitPart == "g") + return number * 1024 * 1024 * 1024; + else + throw std::runtime_error("Invalid memory unit!"); +} + +void 
PipelineConfig::configure_wait(std::string argument) +{ + std::vector tokens; + std::string token; + std::istringstream tokenStream(argument); + int indx = 0; + _pre_write_config.is_enabled = true; + while (std::getline(tokenStream, token, ':')) { + if(indx == 0) + { + errno = 0; + _pre_write_config.wait.iterations = std::stoi(token); + if (errno == ERANGE) { + throw std::runtime_error("Wait iteration number out of range!"); + } + if (_pre_write_config.wait.iterations < 0) _pre_write_config.wait.iterations = 0; + } else if(indx == 1) { + errno = 0; + _pre_write_config.wait.sleep_time = std::stoi(token); + if (errno == ERANGE) { + throw std::runtime_error("Sleep time out of range!"); + } + if (_pre_write_config.wait.sleep_time < 1) _pre_write_config.wait.sleep_time = 1; + } else if(indx == 2) { + if (!token.empty() && std::all_of(token.begin(), token.end(), ::isdigit)) + { + _pre_write_config.wait.min_free_space = std::stoull(token); + } else { + try { + _pre_write_config.wait.min_free_space = convertMemorySize(token); + } catch (std::runtime_error& e) { + std::cout << "Memory conversion error: " << e.what() << std::endl; + throw; + } + } + } + indx++; + } +} void PipelineConfig::enable_incoherent_dedispersion(bool enable) { @@ -169,6 +227,26 @@ bool PipelineConfig::enable_incoherent_dedispersion() const return _enable_incoherent_dedispersion; } +void PipelineConfig::output_statistics(bool enable) +{ + _output_statistics = enable; +} + +bool PipelineConfig::output_statistics() const +{ + return _output_statistics; +} + +void PipelineConfig::output_incoherent_beam(bool enable) +{ + _output_incoherent_beam = enable; +} + +bool PipelineConfig::output_incoherent_beam() const +{ + return _output_incoherent_beam; +} + std::vector const& PipelineConfig::channel_frequencies() const { if(_channel_frequencies_stale) { diff --git a/cpp/skyweaver/src/SkyCleaver.cu b/cpp/skyweaver/src/SkyCleaver.cu deleted file mode 100644 index 9df09e3..0000000 --- 
a/cpp/skyweaver/src/SkyCleaver.cu +++ /dev/null @@ -1,424 +0,0 @@ - -#include -#include -#include -#include -#include -#include -#include "skyweaver/types.cuh" -#include "skyweaver/SkyCleaver.cuh" - -namespace fs = std::filesystem; -using SkyCleaver = skyweaver::SkyCleaver; - - -using FreqType = skyweaver::SkyCleaver::FreqType; -using BridgeReader = skyweaver::BridgeReader; -using MultiFileReader = skyweaver::MultiFileReader; -using OutputVectorType = skyweaver::SkyCleaver::OutputVectorType; -using InputVectorType = skyweaver::SkyCleaver::InputVectorType; - -namespace -{ -template -std::string to_string_with_padding(T num, int width, int precision = -1) -{ - std::ostringstream oss; - oss << std::setw(width) << std::setfill('0'); - if(precision >= - 0) { // Check if precision is specified for floating-point numbers - oss << std::fixed << std::setprecision(precision); - } - oss << num; - return oss.str(); -} -std::vector -get_subdirs(std::string directory_path, - std::regex numeric_regex = std::regex("^[0-9]+$")) -{ - std::vector subdirs; - try { - if(fs::exists(directory_path) && fs::is_directory(directory_path)) { - for(const auto& entry: fs::directory_iterator(directory_path)) { - if(fs::is_directory(entry.status())) { - std::string folder_name = entry.path().filename().string(); - if(std::regex_match(folder_name, numeric_regex)) { - BOOST_LOG_TRIVIAL(debug) - << "Found subdirectory: " << folder_name; - subdirs.push_back(folder_name); - } - } - } - } else { - std::runtime_error( - "Root directory does not exist or is not a directory."); - } - } catch(const fs::filesystem_error& e) { - std::cerr << "Filesystem error: " << e.what() << std::endl; - std::runtime_error("Error reading subdirectories in root directory: " + - directory_path); - } - - return subdirs; -} - -std::vector get_files(std::string directory_path, - std::string extension) -{ - std::vector files; - try { - if(fs::exists(directory_path) && fs::is_directory(directory_path)) { - for(const auto& entry: 
fs::directory_iterator(directory_path)) { - if(fs::is_regular_file(entry.status())) { - std::string file_name = entry.path().string(); - if(file_name.find(extension) != std::string::npos) { - files.push_back(file_name); - } - } - } - } else { - std::runtime_error("No files in bridge directory: " + - directory_path); - } - } catch(const fs::filesystem_error& e) { - std::cerr << "Filesystem error: " << e.what() << std::endl; - std::runtime_error("Error reading files in bridge directory: " + - directory_path); - } - - return files; -} - -} // namespace - -void SkyCleaver::init_readers() -{ - BOOST_LOG_NAMED_SCOPE("SkyCleaver::init_readers") - - std::string root_dir = _config.root_dir(); - std::string root_prefix = _config.root_prefix(); - std::size_t stream_id = _config.stream_id(); - - // get the list of directories in root/stream_id(for the nex) - std::vector freq_dirs = - get_subdirs(root_dir + "/" + std::to_string(stream_id)); - - BOOST_LOG_TRIVIAL(info) - << "Found " << freq_dirs.size() - << " frequency directories in root directory: " << root_dir; - - std::map bridge_timestamps; - long double latest_timestamp = 0.0; - - for(const auto& freq_dir: freq_dirs) { - std::vector tdb_files = get_files( - root_dir + "/" + std::to_string(stream_id) + "/" + freq_dir, - ".tdb"); - BOOST_LOG_TRIVIAL(info) << "Found " << tdb_files.size() - << " TDB files for frequency: " << freq_dir; - if(tdb_files.empty()) { - BOOST_LOG_TRIVIAL(warning) - << "No TDB files found for frequency: " << freq_dir; - continue; - } - - std::size_t freq = static_cast(std::stoul(freq_dir)); - - _bridge_readers[freq] = - std::make_unique(tdb_files, - _config.dada_header_size(), - false); - long double timestamp = _bridge_readers[freq]->get_header().utc_start; - bridge_timestamps.insert({freq, timestamp}); - if(timestamp > latest_timestamp) { - latest_timestamp = timestamp; - } - _available_freqs.push_back(freq); - - BOOST_LOG_TRIVIAL(debug) - << "Added bridge reader for frequency: " << freq_dir; - } - - 
int nbridges = _config.nbridges(); - - std::size_t gulp_size = - _config.nsamples_per_block() * _config.ndms() * _config.nbeams(); - - ObservationHeader const& header = - (*_bridge_readers.begin()).second->get_header(); - BOOST_LOG_TRIVIAL(info) - << "Read header from first file: " << header.to_string(); - - float obs_centre_freq = header.obs_frequency; - float obs_bandwidth = header.obs_bandwidth; - for(int i = 0; i < nbridges; i++) { - int ifreq = std::lround(obs_centre_freq - obs_bandwidth / 2 + - (i + 0.5) * obs_bandwidth / nbridges); - _expected_freqs.push_back(ifreq); - BOOST_LOG_TRIVIAL(info) - << "Expected frequency [" << i << "]: " << ifreq; - - if(_bridge_readers.find(ifreq) == _bridge_readers.end()) { - BOOST_LOG_TRIVIAL(warning) - << "Frequency " << ifreq - << " not found in bridge readers, will write zeros"; - } - _bridge_data[ifreq] = std::make_unique( - std::initializer_list{_config.nsamples_per_block(), - _config.ndms(), - _config.nbeams()}, - 0); - } - - std::size_t smallest_data_size = std::numeric_limits::max(); - - for(const auto& [freq, reader]: _bridge_readers) { - // at this point, all non-existed frequencies have been added with zero - // data now check if there are any unexpected frequencies in the bridge - // readers. 
- if(std::find(_expected_freqs.begin(), _expected_freqs.end(), freq) == - _expected_freqs.end()) { - throw std::runtime_error("Frequency " + std::to_string(freq) + - " not found in expected frequencies"); - } - - // now time align all the bridges to the latest timestamp - long double timestamp = bridge_timestamps[freq]; - long double time_diff = latest_timestamp - timestamp; - long double tsamp = - reader->get_header().tsamp * - 1e-6; // Header has it in microseconds, converting to seconds - std::size_t nsamples = std::floor(time_diff / tsamp); - - BOOST_LOG_TRIVIAL(info) - << "Frequency: " << freq << " Timestamp: " << timestamp - << "tsamp: " << tsamp << " Latest timestamp: " << latest_timestamp - << " Time difference: " << time_diff - << " Number of samples to skip: " << nsamples; - - BOOST_LOG_TRIVIAL(info) - << "Seeking " << nsamples * _config.ndms() * _config.nbeams() - << " bytes in bridge reader for frequency: " << freq; - - std::size_t bytes_seeking = (nsamples * _config.ndms() * - _config.nbeams() * - sizeof(InputVectorType::value_type)); - - _bridge_readers[freq]->seekg(bytes_seeking, - std::ios_base::beg); - - std::size_t data_size = - _bridge_readers[freq]->get_total_size() - bytes_seeking; - BOOST_LOG_TRIVIAL(debug) - << "Data size for frequency: " << freq << " is " << data_size; - if(data_size < smallest_data_size) { - smallest_data_size = data_size; - } - } - - BOOST_LOG_TRIVIAL(debug) << "Smallest data size: " << smallest_data_size; - BOOST_LOG_TRIVIAL(debug) << "ndm: " << _config.ndms(); - BOOST_LOG_TRIVIAL(debug) << "nbeams: " << _config.nbeams(); - - if(smallest_data_size % (_config.ndms() * _config.nbeams()) != 0) { - std::runtime_error("Data size is not a multiple of ndms * nbeams"); - } - - std::size_t smallest_nsamples = - std::floor(smallest_data_size / _config.ndms() / _config.nbeams()); - - BOOST_LOG_TRIVIAL(info) - << "Smallest data size: " << smallest_data_size - << " Smallest number of samples: " << smallest_nsamples; - - 
if(smallest_nsamples < _config.nsamples_per_block()) { - std::runtime_error( - "Smallest data size is less than nsamples_per_block"); - } - - _nsamples_to_read = smallest_nsamples; - - BOOST_LOG_TRIVIAL(info) - << "Added " << _bridge_data.size() << " bridge readers to SkyCleaver"; - - _header = _bridge_readers[_available_freqs[0]]->get_header(); - BOOST_LOG_TRIVIAL(info) << "Adding first header to SkyCleaver"; - BOOST_LOG_TRIVIAL(info) << "Header: " << _header.to_string(); - _header.nchans = _header.nchans * _config.nbridges(); - _header.nbeams = _config.nbeams(); -} - -void SkyCleaver::init_writers() -{ - BOOST_LOG_NAMED_SCOPE("SkyCleaver::init_writers") - BOOST_LOG_TRIVIAL(debug) - << "_config.output_dir(); " << _config.output_dir(); - if(!fs::exists(_config.output_dir())) { - fs::create_directories(_config.output_dir()); - } - std::string out_prefix = _config.out_prefix().empty() - ? "" - : _config.out_prefix() + "_"; - std::string output_dir = _config.output_dir(); - - for(int idm = 0; idm < _config.ndms(); idm++) { - - std::string prefix = _config.ndms() > 1 ? 
out_prefix + "idm_" + - to_string_with_padding(idm, 9, 3) + "_": out_prefix; - - for(int ibeam = 0; ibeam < _config.nbeams(); ibeam++) { - - MultiFileWriterConfig writer_config; - writer_config.header_size = _config.dada_header_size(); - writer_config.max_file_size = _config.max_output_filesize(); - writer_config.stokes_mode = _config.stokes_mode(); - writer_config.base_output_dir = output_dir; - writer_config.prefix = prefix + "cb_" + to_string_with_padding(ibeam, 5);; - writer_config.extension = ".fil"; - - BOOST_LOG_TRIVIAL(info) - << "Writer config: " << writer_config.to_string(); - - typename MultiFileWriter::CreateStreamCallBackType - create_stream_callback_sigproc = - skyweaver::detail::create_sigproc_file_stream< - OutputVectorType>; - _beam_writers[idm][ibeam] = - std::make_unique>( - writer_config, - "", - create_stream_callback_sigproc); - _header.ibeam = ibeam; - _beam_writers[idm][ibeam]->init(_header); - - _beam_data[idm][ibeam] = std::make_shared( - std::initializer_list{_config.nsamples_per_block(), - _config.nbridges()}, - 0); - - _beam_data[idm][ibeam]->reference_dm(_header.refdm); - - _total_beam_writers++; - } - } - - BOOST_LOG_TRIVIAL(info) - << "Added " << _total_beam_writers << " beam writers to SkyCleaver"; -} - -SkyCleaver::SkyCleaver(SkyCleaverConfig const& config): _config(config) -{ - _timer.start("skycleaver::init_readers"); - init_readers(); - _timer.stop("skycleaver::init_readers"); - _timer.start("skycleaver::init_writers"); - init_writers(); - _timer.stop("skycleaver::init_writers"); -} - -void SkyCleaver::cleave() -{ - BOOST_LOG_NAMED_SCOPE("SkyCleaver::cleave") - - for(std::size_t nsamples_read = 0; nsamples_read < _nsamples_to_read; - nsamples_read += _config.nsamples_per_block()) { - std::size_t gulp_samples = - _nsamples_to_read - nsamples_read < _config.nsamples_per_block() - ? 
_nsamples_to_read - nsamples_read - : _config.nsamples_per_block(); - - BOOST_LOG_TRIVIAL(info) << "Cleaving samples: " << nsamples_read - << " to " << nsamples_read + gulp_samples; - - std::size_t gulp_size = - gulp_samples * _config.ndms() * _config.nbeams(); - - int nthreads_read = _config.nthreads() > _config.nbridges() - ? _config.nbridges() - : _config.nthreads(); - - omp_set_num_threads(nthreads_read); - - _timer.start("skyweaver::read_data"); - - std::vector read_status( - _available_freqs.size(), - false); // since we cannot throw exceptions in parallel regions -#pragma omp parallel for - for(std::size_t i = 0; i < _available_freqs.size(); i++) { - FreqType freq = _available_freqs[i]; - if(_bridge_readers.find(freq) == _bridge_readers.end()) { - read_status[i] = true; - } - const auto& reader = _bridge_readers[freq]; - if(reader->eof()) { - read_status[i] = true; - } - - std::streamsize read_size = - reader->read(reinterpret_cast(thrust::raw_pointer_cast( - _bridge_data[freq]->data())), - gulp_size); // read a big chunk of data - BOOST_LOG_TRIVIAL(info) - << "Read " << read_size << " bytes from bridge" << freq; - if(read_size < gulp_size * sizeof(InputVectorType::value_type)) { - BOOST_LOG_TRIVIAL(warning) - << "Read less data than expected from bridge " << freq; - read_status[i] = true; - } - } - - if(std::any_of(read_status.begin(), read_status.end(), [](bool status) { - return status; - })) { - std::runtime_error("Some bridges have had unexpected reads"); - } - - BOOST_LOG_TRIVIAL(info) << "Read data from bridge readers"; - - _timer.stop("skyweaver::read_data"); - _timer.start("skyweaver::process_data"); - - omp_set_num_threads(_config.nthreads()); - - std::size_t nbridges = _config.nbridges(); - std::size_t nsamples_per_block = _config.nsamples_per_block(); - std::size_t ndms = _config.ndms(); - std::size_t nbeams = _config.nbeams(); - -#pragma omp parallel for schedule(static) collapse(2) - for(std::size_t ibeam = 0; ibeam < nbeams; ibeam++) { - 
for(std::size_t idm = 0; idm < ndms; idm++) { -#pragma omp simd - for(std::size_t isample = 0; isample < nsamples_per_block; isample++) { - const std::size_t base_index = - isample * ndms * nbeams + idm * nbeams + ibeam; - - std::size_t ifreq = 0; - const std::size_t out_offset = isample * nbridges; - for(const auto& [freq, ifreq_data]: - _bridge_data) { // for each frequency - _beam_data[idm][ibeam]->at(out_offset + nbridges - 1 - - ifreq) = clamp(127 + ifreq_data->at(base_index)); - ++ifreq; - } - } - } - } - BOOST_LOG_TRIVIAL(info) << "Processed data"; - _timer.stop("skyweaver::process_data"); - - _timer.start("skyweaver::write_data"); - -#pragma omp parallel for schedule(static) collapse(2) - for(int idm = 0; idm < _config.ndms(); idm++) { - for(int ibeam = 0; ibeam < _config.nbeams(); ibeam++) { - _beam_writers[idm][ibeam]->write(*_beam_data[idm][ibeam], - _config.stream_id()); - } - } - - _timer.stop("skyweaver::write_data"); - } - _timer.show_all_timings(); -} diff --git a/cpp/skyweaver/src/skycleaver_cli.cpp b/cpp/skyweaver/src/skycleaver_cli.cpp new file mode 100644 index 0000000..88b4e55 --- /dev/null +++ b/cpp/skyweaver/src/skycleaver_cli.cpp @@ -0,0 +1,279 @@ +#include "skyweaver/DescribedVector.hpp" +#include "skyweaver/SkyCleaver.hpp" +#include "skyweaver/SkyCleaverConfig.hpp" +#include "skyweaver/logging.hpp" +#include "skyweaver/skycleaver_utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define BOOST_LOG_DYN_LINK 1 + +namespace +{ + +std::string skycleaver_splash = R"( + __ __ +.-----.| |--.--.--.----.| |.-----.---.-.--.--.-----.----. 
+|__ --|| <| | | __|| || -__| _ | | | -__| _| +|_____||__|__|___ |____||__||_____|___._|\___/|_____|__| + |_____| + +)"; +const size_t ERROR_IN_COMMAND_LINE = 1; +const size_t SUCCESS = 0; +const size_t ERROR_UNHANDLED_EXCEPTION = 2; + +const char* build_time = __DATE__ " " __TIME__; + + + +} // namespace + +namespace std +{ +std::ostream& operator<<(std::ostream& os, const std::vector& vec) +{ + for(auto item: vec) { os << item << " "; } + return os; +} +} // namespace std +template +void run_pipeline(skyweaver::SkyCleaverConfig& config) +{ + skyweaver::SkyCleaver skycleaver(config); + skycleaver.cleave(); +} + +int main(int argc, char** argv) +{ + std::cout << skycleaver_splash; + std::cout << "Build time: " << build_time << std::endl; + // skyweaver::init_logging("warning"); + + skyweaver::SkyCleaverConfig config; + + namespace po = boost::program_options; + + po::options_description generic("Generic options"); + generic.add_options()("cfg,c", + po::value()->default_value(""), + "Skycleaver configuration file"); + + po::options_description main_options("Main options"); + main_options.add_options()("help,h", "Produce help message")( + "root-dir,r", + po::value()->required()->notifier( + [&config](std::string key) { config.root_dir(key); }), + "The output directory for all results")( + "output-dir", + po::value() + ->default_value(config.output_dir()) + ->notifier([&config](std::string key) { config.output_dir(key); }), + "The output directory for all results")( + "out-prefix", + po::value() + ->default_value(config.out_prefix()) + ->notifier([&config](std::string key) { config.out_prefix(key); }), + "The prefix for all output files")( + "nthreads", + po::value() + ->default_value(config.nthreads()) + ->notifier([&config](unsigned int key) { config.nthreads(key); }), + "The number of threads to use for processing")( + "nsamples-per-block", + po::value() + ->default_value(config.nsamples_per_block()) + ->notifier( + [&config](std::size_t key) { 
config.nsamples_per_block(key); }), + "The number of samples per block")( + "nbridges", + po::value() + ->default_value(config.nbridges()) + ->notifier([&config](std::size_t key) { config.nbridges(key); }), + "The number of bridges")( + "stream-id", + po::value() + ->default_value(config.stream_id()) + ->notifier([&config](std::size_t key) { config.stream_id(key); }), + "The stream id")( + "max-ram-gb", + po::value() + ->default_value(config.max_ram_gb()) + ->notifier([&config](std::size_t key) { config.max_ram_gb(key); }), + "The maximum amount of RAM to use in GB")( + "max-output-filesize", + po::value() + ->default_value(config.max_output_filesize()) + ->notifier([&config](std::size_t key) { + config.max_output_filesize(key); + }), + "The maximum output file size in bytes")( + "dada-header-size", + po::value() + ->default_value(config.dada_header_size()) + ->notifier( + [&config](std::size_t key) { config.dada_header_size(key); }), + "The size of the DADA header")( + "log-level", + po::value()->default_value("info")->notifier( + [](std::string level) { skyweaver::init_logging(level); }), + "The logging level to use (debug, info, warning, error)")( + "start_sample", + po::value() + ->default_value(config.start_sample()) + ->notifier( + [&config](std::size_t key) { config.start_sample(key); }), + "Start from this sample")( + "nsamples_to_read", + po::value() + ->default_value(config.start_sample()) + ->notifier( + [&config](std::size_t key) { config.nsamples_to_read(key); }), + "total number of samples to read from start_sample")( + "targets_file", + po::value()->default_value("")->notifier( + [&config](std::string key) { + config.targets_file(key); + }), "update beam names and positions from this file" + )( + "required_beams", + po::value()->default_value("")->notifier( + [&config](std::string key) { + config.required_beams(skyweaver::get_list_from_string(key)); + }), + "Comma separated list of beams to process. 
Syntax - beam1, beam2, " + "beam3:beam4:step, beam5 etc..")( + "required_dms", + po::value()->default_value("")->notifier( + [&config](std::string key) { + config.required_dms(skyweaver::get_list_from_string(key)); + }), + "Comma separated list of DMs to process. Syntax - dm1, dm2, " + "dm1:dm2:step, etc..")( + "out-stokes", + po::value() + ->default_value(config.out_stokes()) + ->notifier([&config](std::string key) { + if(key.find_first_not_of("IQUVL") != std::string::npos) { + throw std::runtime_error("Invalid Stokes mode: " + key); + } + config.out_stokes(key); + }), + "The list of stokes needed - these will be separte output files"); + + po::options_description cmdline_options; + cmdline_options.add(generic).add(main_options); + + // set options allowed in config file + po::options_description config_file_options; + config_file_options.add(main_options); + po::variables_map variable_map; + try { + po::store( + po::command_line_parser(argc, argv).options(cmdline_options).run(), + variable_map); + if(variable_map.count("help")) { + std::cout + << "skycleaver -- A pipeline that cleaves input TDB files, " + "and cleaves them to form output Sigproc Filterbank files." 
+ << std::endl + << cmdline_options << std::endl; + return SUCCESS; + } + } catch(po::error& e) { + std::cerr << "ERROR: " << e.what() << std::endl << std::endl; + return ERROR_IN_COMMAND_LINE; + } + + auto config_file = variable_map.at("cfg").as(); + + if(config_file != "") { + std::ifstream config_fs(config_file.c_str()); + if(!config_fs.is_open()) { + std::cerr << "Unable to open configuration file: " << config_file + << " (" << std::strerror(errno) << ")\n"; + return ERROR_UNHANDLED_EXCEPTION; + } else { + po::store(po::parse_config_file(config_fs, config_file_options), + variable_map); + } + } + po::notify(variable_map); + + BOOST_LOG_NAMED_SCOPE("skycleaver_cli"); + BOOST_LOG_TRIVIAL(info) << "Configuration: " << config_file; + BOOST_LOG_TRIVIAL(info) << "root_dir: " << config.root_dir(); + BOOST_LOG_TRIVIAL(info) << "output_dir: " << config.output_dir(); + BOOST_LOG_TRIVIAL(info) << "out_prefix: " << config.out_prefix(); + BOOST_LOG_TRIVIAL(info) << "nthreads: " << config.nthreads(); + BOOST_LOG_TRIVIAL(info) + << "nsamples_per_block: " << config.nsamples_per_block(); + BOOST_LOG_TRIVIAL(info) << "nchans: " << config.nchans(); + BOOST_LOG_TRIVIAL(info) << "nbeams: " << config.nbeams(); + BOOST_LOG_TRIVIAL(info) << "nbridges: " << config.nbridges(); + BOOST_LOG_TRIVIAL(info) << "ndms: " << config.ndms(); + BOOST_LOG_TRIVIAL(info) << "max_ram_gb: " << config.max_ram_gb(); + BOOST_LOG_TRIVIAL(info) + << "max_output_filesize: " << config.max_output_filesize(); + BOOST_LOG_TRIVIAL(info) + << "dada_header_size: " << config.dada_header_size(); + BOOST_LOG_TRIVIAL(info) << "start_sample: " << config.start_sample(); + BOOST_LOG_TRIVIAL(info) + << "nsamples_to_read: " << config.nsamples_to_read(); + BOOST_LOG_TRIVIAL(info) << "out_stokes: " << config.out_stokes(); + + if(config.required_beams().size() > 0) { + for(auto beam: config.required_beams()) { + BOOST_LOG_TRIVIAL(info) << "required_beam: " << beam; + } + } + if(config.required_dms().size() > 0) { + for(auto 
dm: config.required_dms()) { + BOOST_LOG_TRIVIAL(info) << "required_dm: " << dm; + } + } + + + run_pipeline, + skyweaver::TFPowersStdH>(config); + + + // skyweaver::SkyCleaver skycleaver(config); + // skycleaver.cleave(); + + // if(config.stokes_mode() == "I" || config.stokes_mode() == "Q" || + // config.stokes_mode() == "U" || config.stokes_mode() == "V") { + // run_pipeline>, + // skyweaver::TFPowersStdH>(config); + // } else if(config.stokes_mode() == "IV" || config.stokes_mode() == "QU") { + // run_pipeline, + // skyweaver::TFPowersStdH>(config); + + // } else if(config.stokes_mode() == "IQUV") { + // run_pipeline, + // skyweaver::TFPowersStdH>(config); + // } else { + // throw std::runtime_error("Invalid Stokes mode: " + + // config.stokes_mode()); + // } + +} diff --git a/cpp/skyweaver/src/skycleaver_cli.cu b/cpp/skyweaver/src/skycleaver_cli.cu deleted file mode 100644 index ff7ad8f..0000000 --- a/cpp/skyweaver/src/skycleaver_cli.cu +++ /dev/null @@ -1,228 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "skyweaver/logging.hpp" -#include "skyweaver/SkyCleaverConfig.hpp" -#include "skyweaver/SkyCleaver.cuh" -#define BOOST_LOG_DYN_LINK 1 - - -namespace -{ - - -std::string skycleaver_splash=R"( - __ __ -.-----.| |--.--.--.----.| |.-----.---.-.--.--.-----.----. 
-|__ --|| <| | | __|| || -__| _ | | | -__| _| -|_____||__|__|___ |____||__||_____|___._|\___/|_____|__| - |_____| - -)"; -const size_t ERROR_IN_COMMAND_LINE = 1; -const size_t SUCCESS = 0; -const size_t ERROR_UNHANDLED_EXCEPTION = 2; - -const char* build_time = __DATE__ " " __TIME__; - - -} - -namespace std -{ -std::ostream& operator<<(std::ostream& os, const std::vector& vec) -{ - for(auto item: vec) { os << item << " "; } - return os; -} -} // namespace std - - - - -int main(int argc, char** argv) -{ - std::cout << skycleaver_splash; - std::cout << "Build time: " << build_time << std::endl; - // skyweaver::init_logging("warning"); - - skyweaver::SkyCleaverConfig config; - - - namespace po = boost::program_options; - - po::options_description generic("Generic options"); - generic.add_options()("cfg,c", - po::value()->default_value(""), - "Skycleaver configuration file"); - - po::options_description main_options("Main options"); - main_options.add_options() - ("help,h", "Produce help message") - ("root-dir,r", - po::value() - -> required() - ->notifier( - [&config](std::string key) { config.root_dir(key); }), - "The output directory for all results") - ("output-dir", - po::value() - ->default_value(config.output_dir()) - ->notifier( - [&config](std::string key) { config.output_dir(key); }), - "The output directory for all results") - ("root-prefix", - po::value() - ->default_value(config.root_prefix()) - ->notifier( - [&config](std::string key) { config.root_prefix(key); }), - "The prefix for all output files") - ("out-prefix", - po::value() - ->default_value(config.out_prefix()) - ->notifier( - [&config](std::string key) { config.out_prefix(key); }), - "The prefix for all output files") - ("nthreads", - po::value() - ->default_value(config.nthreads()) - ->notifier( - [&config](unsigned int key) { config.nthreads(key); }), - "The number of threads to use for processing") - ("nsamples-per-block", - po::value() - ->default_value(config.nsamples_per_block()) - 
->notifier( - [&config](std::size_t key) { config.nsamples_per_block(key); }), - "The number of samples per block") - ("nchans", - po::value() - ->default_value(config.nchans()) - ->notifier( - [&config](std::size_t key) { config.nchans(key); }), - "The number of channels") - ("nbridges", - po::value() - ->default_value(config.nbridges()) - ->notifier( - [&config](std::size_t key) { config.nbridges(key); }), - "The number of bridges") - ("nbeams", - po::value() - ->default_value(config.nbeams()) - ->notifier( - [&config](std::size_t key) { config.nbeams(key); }), - "The number of beams") - ("ndms", - po::value() - ->default_value(config.ndms()) - ->notifier( - [&config](std::size_t key) { config.ndms(key); }), - "The number of DMs") - ("stokes-mode", - po::value() - ->default_value(config.stokes_mode()) - ->notifier( - [&config](std::string key) { config.stokes_mode(key); }), - "The stokes mode") - ("stream-id", - po::value() - ->default_value(config.stream_id()) - ->notifier( - [&config](std::size_t key) { config.stream_id(key); }), - "The stream id") - ("max-ram-gb", - po::value() - ->default_value(config.max_ram_gb()) - ->notifier( - [&config](std::size_t key) { config.max_ram_gb(key); }), - "The maximum amount of RAM to use in GB") - ("max-output-filesize", - po::value() - ->default_value(config.max_output_filesize()) - ->notifier( - [&config](std::size_t key) { config.max_output_filesize(key); }), - "The maximum output file size in bytes") - ("dada-header-size", - po::value() - ->default_value(config.dada_header_size()) - ->notifier( - [&config](std::size_t key) { config.dada_header_size(key); }), - "The size of the DADA header") - ("log-level", - po::value()->default_value("info")->notifier( - [](std::string level) { skyweaver::init_logging(level); }), - "The logging level to use (debug, info, warning, error)"); - - po::options_description cmdline_options; - cmdline_options.add(generic).add(main_options); - - // set options allowed in config file - 
po::options_description config_file_options; - config_file_options.add(main_options); - po::variables_map variable_map; - try { - po::store(po::command_line_parser(argc, argv) - .options(cmdline_options) - .run(), - variable_map); - if(variable_map.count("help")) { - std::cout << "skycleaver -- A pipeline that cleaves input TDB files, " - "and cleaves them to form output Sigproc Filterbank files." - << std::endl - << cmdline_options << std::endl; - return SUCCESS; - } - } catch(po::error& e) { - std::cerr << "ERROR: " << e.what() << std::endl << std::endl; - return ERROR_IN_COMMAND_LINE; - } - - auto config_file = variable_map.at("cfg").as(); - - if(config_file != "") { - std::ifstream config_fs(config_file.c_str()); - if(!config_fs.is_open()) { - std::cerr << "Unable to open configuration file: " - << config_file << " (" << std::strerror(errno) - << ")\n"; - return ERROR_UNHANDLED_EXCEPTION; - } else { - po::store(po::parse_config_file(config_fs, config_file_options), - variable_map); - } - } - po::notify(variable_map); - BOOST_LOG_NAMED_SCOPE("skycleaver_cli"); - BOOST_LOG_TRIVIAL(info) << "Configuration: " << config_file; - BOOST_LOG_TRIVIAL(info) << "root_dir: " << config.root_dir(); - BOOST_LOG_TRIVIAL(info) << "output_dir: " << config.output_dir(); - BOOST_LOG_TRIVIAL(info) << "root_prefix: " << config.root_prefix(); - BOOST_LOG_TRIVIAL(info) << "out_prefix: " << config.out_prefix(); - BOOST_LOG_TRIVIAL(info) << "nthreads: " << config.nthreads(); - BOOST_LOG_TRIVIAL(info) << "nsamples_per_block: " << config.nsamples_per_block(); - BOOST_LOG_TRIVIAL(info) << "nchans: " << config.nchans(); - BOOST_LOG_TRIVIAL(info) << "nbeams: " << config.nbeams(); - BOOST_LOG_TRIVIAL(info) << "max_ram_gb: " << config.max_ram_gb(); - BOOST_LOG_TRIVIAL(info) << "max_output_filesize: " << config.max_output_filesize(); - - - skyweaver::SkyCleaver skycleaver(config); - skycleaver.cleave(); - - -} \ No newline at end of file diff --git a/cpp/skyweaver/src/skyweaver_cli.cu 
b/cpp/skyweaver/src/skyweaver_cli.cu index e168f63..39d44c5 100644 --- a/cpp/skyweaver/src/skyweaver_cli.cu +++ b/cpp/skyweaver/src/skyweaver_cli.cu @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -52,7 +53,7 @@ class NullHandler { public: template - void init(Args... args){}; + void init(Args... args) {}; template bool operator()(Args... args) @@ -239,6 +240,34 @@ void run_pipeline(Pipeline& pipeline, stopwatch.show_all_timings(); } +auto pre_write_callback = [] (skyweaver::MultiFileWriterConfig const& config) +{ + if (!config.pre_write.is_enabled) return; + std::filesystem::space_info space = std::filesystem::space(config.base_output_dir); + size_t limit = config.pre_write.wait.min_free_space; + if(space.available >= config.pre_write.wait.min_free_space) + return; + + BOOST_LOG_TRIVIAL(info) + << space.available + << " bytes available space is not enough. Need at least " + << limit + << " bytes in " << config.base_output_dir << "."; + BOOST_LOG_TRIVIAL(warning) << "Start pausing."; + int max_iterations = (config.pre_write.wait.iterations == 0) ? INT_MAX : config.pre_write.wait.iterations; + for (int i = 0; i < max_iterations; i++) + { + sleep(config.pre_write.wait.sleep_time); + space = std::filesystem::space(config.base_output_dir); + if (space.available >= limit) + { + BOOST_LOG_TRIVIAL(info) << "Space has been freed up. 
Will proceed."; + return; + } + } + throw std::runtime_error("Space for writing hasn't been freed up in time."); +}; + template void setup_pipeline(skyweaver::PipelineConfig& config) { @@ -260,48 +289,75 @@ void setup_pipeline(skyweaver::PipelineConfig& config) typename IBWriterType::CreateStreamCallBackType create_stream_callback_ib = skyweaver::detail::create_dada_file_stream< skyweaver::BTFPowersH>; - IBWriterType ib_handler(config, "ib", create_stream_callback_ib); + + std::unique_ptr ib_handler; using StatsWriterType = - skyweaver::MultiFileWriter>; + skyweaver::MultiFileWriter>; + + std::unique_ptr stats_handler; + typename StatsWriterType::CreateStreamCallBackType create_stream_callback_stats = skyweaver::detail::create_dada_file_stream< skyweaver::FPAStatsD>; - StatsWriterType stats_handler(config, "stats", create_stream_callback_stats); - if constexpr(enable_incoherent_dedispersion) { - using CBWriterType = skyweaver::MultiFileWriter>; + using CBWriterType = + skyweaver::MultiFileWriter>; typename CBWriterType::CreateStreamCallBackType - create_stream_callback_cb = - skyweaver::detail::create_dada_file_stream>; - skyweaver::MultiFileWriter> - cb_file_writer(config, "cb", create_stream_callback_cb); + create_stream_callback_cb = + skyweaver::detail::create_dada_file_stream< + skyweaver::TDBPowersH>; + + std::unique_ptr cb_file_writer; + if (config.pre_write_config().is_enabled) + { + ib_handler.reset(new IBWriterType(config, "ib", create_stream_callback_ib, pre_write_callback)); + stats_handler.reset(new StatsWriterType(config, "stats", create_stream_callback_stats, pre_write_callback)); + cb_file_writer.reset(new CBWriterType(config, "cb", create_stream_callback_cb, pre_write_callback)); + }else{ + ib_handler.reset(new IBWriterType(config, "ib", create_stream_callback_ib)); + stats_handler.reset(new StatsWriterType(config, "stats", create_stream_callback_stats)); + cb_file_writer.reset(new CBWriterType(config, "cb", create_stream_callback_cb)); + } + 
skyweaver::IncoherentDedispersionPipeline - incoherent_dispersion_pipeline(config, cb_file_writer); + decltype(* cb_file_writer.get())> + incoherent_dispersion_pipeline(config, * cb_file_writer.get()); skyweaver::BeamformerPipeline pipeline(config, incoherent_dispersion_pipeline, - ib_handler, - stats_handler); + * ib_handler.get(), + * stats_handler.get()); run_pipeline(pipeline, config, file_reader, header); } else { - using CBWriterType = skyweaver::MultiFileWriter>; + using CBWriterType = + skyweaver::MultiFileWriter>; + std::unique_ptr cb_file_writer; typename CBWriterType::CreateStreamCallBackType - create_stream_callback_cb = - skyweaver::detail::create_dada_file_stream>; - CBWriterType cb_file_writer(config, "cb", create_stream_callback_cb); - skyweaver::BeamformerPipeline>; + if (config.pre_write_config().is_enabled) + { + ib_handler.reset(new IBWriterType(config, "ib", create_stream_callback_ib, pre_write_callback)); + cb_file_writer.reset(new CBWriterType(config, "cb", create_stream_callback_cb, pre_write_callback)); + stats_handler.reset(new StatsWriterType(config, "stats", create_stream_callback_stats, pre_write_callback)); + }else{ + ib_handler.reset(new IBWriterType(config, "ib", create_stream_callback_ib)); + cb_file_writer.reset(new CBWriterType(config, "cb", create_stream_callback_cb)); + stats_handler.reset(new StatsWriterType(config, "stats", create_stream_callback_stats)); + } + skyweaver::BeamformerPipeline - pipeline(config, cb_file_writer, ib_handler, stats_handler); + pipeline(config, * cb_file_writer.get(), * ib_handler.get(), * stats_handler.get()); run_pipeline(pipeline, config, file_reader, header); } } @@ -467,11 +523,32 @@ int main(int argc, char** argv) [](std::size_t nthreads) { omp_set_num_threads(nthreads); }), "The number of threads to use for incoherent dedispersion") + // Waiting options + ("wait-for-space", + po::value() + ->notifier( + [&config](std::string key) { config.configure_wait(key); }), + "Wait for enough disk space 
for the output. " + "::") + // Logging options ("log-level", po::value()->default_value("info")->notifier( [](std::string level) { skyweaver::set_log_level(level); }), - "The logging level to use (debug, info, warning, error)"); + "The logging level to use (debug, info, warning, error)") + + ("statistics", + po::value()->default_value(true)->notifier( + [&config](bool const& enable) { + config.output_statistics(enable); }), + "Turn on/off calculation and output of voltage statistics") + + ("write-incoherent-beam", + po::value()->default_value(true)->notifier( + [&config](bool const& enable) { + config.output_incoherent_beam(enable); }), + "Turn on/off output of incoherent beam" + "Turning off does not disable incoherent beam subtraction"); // set options allowed on command line po::options_description cmdline_options; @@ -552,13 +629,15 @@ int main(int argc, char** argv) skyweaver::StokesParameter::V>, true>(config); } else if(config.stokes_mode() == "QU") { - setup_pipeline, - true>(config); + setup_pipeline< + skyweaver::StokesTraits, + true>(config); } else if(config.stokes_mode() == "IV") { - setup_pipeline, - true>(config); + setup_pipeline< + skyweaver::StokesTraits, + true>(config); } else if(config.stokes_mode() == "IQUV") { setup_pipeline( config); @@ -586,13 +665,15 @@ int main(int argc, char** argv) skyweaver::StokesParameter::V>, false>(config); } else if(config.stokes_mode() == "QU") { - setup_pipeline, - false>(config); + setup_pipeline< + skyweaver::StokesTraits, + false>(config); } else if(config.stokes_mode() == "IV") { - setup_pipeline, - false>(config); + setup_pipeline< + skyweaver::StokesTraits, + false>(config); } else if(config.stokes_mode() == "IQUV") { setup_pipeline( config); @@ -608,4 +689,4 @@ int main(int argc, char** argv) return ERROR_UNHANDLED_EXCEPTION; } return SUCCESS; -} \ No newline at end of file +} diff --git a/python/skyweaver/cli.py b/python/skyweaver/cli.py index c29fbfb..184b898 100644 --- a/python/skyweaver/cli.py +++ 
b/python/skyweaver/cli.py @@ -150,7 +150,8 @@ def delays_create( bfconfig: str, pointing_idx: int = None, step: float = 4.0, - outfile: str = None): + outfile: str = None, + hex: bool = True): """Create a delay file Args: @@ -160,6 +161,7 @@ def delays_create( step (float, optional): The time step between delay solutions. Defaults to 4.0 seconds. outfile (str, optional): The file to write delay models to. Defaults to a standard output filename. + hex (bool, optional): Use hex encoding for the output filename. Defaults to True. """ sm = skyweaver.SessionMetadata.from_file(metafile) bc = skyweaver.BeamformerConfig.from_file(bfconfig) @@ -173,23 +175,30 @@ def delays_create( raise ValueError("Pointing idx {} requested but only {} pointings in session") step = step * u.s pointing = pointings[pointing_idx] - delays, targets, _ = skyweaver.create_delays(sm, bc, pointing, step=step) + if outfile is None: - fname = "swdelays_{}_{}_to_{}_{}.bin".format( - pointing.phase_centre.name, - int(pointing.start_epoch.unix), - int(pointing.end_epoch.unix), - secrets.token_hex(3) - ) + if hex: + fname = "swdelays_{}_{}_to_{}_{}".format( + pointing.phase_centre.name, + int(pointing.start_epoch.unix), + int(pointing.end_epoch.unix), + secrets.token_hex(3) + ) + else: + fname = "swdelays_{}_{}_to_{}".format( + pointing.phase_centre.name, + int(pointing.start_epoch.unix), + int(pointing.end_epoch.unix) + ) else: fname = outfile - log.info("Writing delay model to file %s", fname) - with open(fname, "wb") as fo: + delays, targets, _ = skyweaver.create_delays(sm, bc, pointing, step=step, outfile=fname) + + log.info(f"Writing delay model to file {fname}.bin") + with open(fname + ".bin", "wb") as fo: for delay_model in delays: fo.write(delay_model.to_bytes()) - with open(fname + ".targets", "w") as fo: - for target in targets: - fo.write(target.format_katcp() +"\n") + def parse_default_args(args): @@ -255,13 +264,16 @@ def cli(): help="An HDF5 FBFUSE-BVR metadata file") 
delays_create_parser.add_argument("bfconfig", metavar="BFCONFIG", help="A YAML beamformer configuration file") + delays_create_parser.add_argument("--no-hex", action="store_false", dest="hex", + help="Do not use hex encoding for the output filename") delays_create_parser.set_defaults( func=lambda args: delays_create( args.metafile, args.bfconfig, args.pointing_idx, args.step, - args.outfile)) + args.outfile, + args.hex)) # parse and execute args = parser.parse_args() diff --git a/python/skyweaver/skyweaver.py b/python/skyweaver/skyweaver.py index 55ff79f..263999a 100644 --- a/python/skyweaver/skyweaver.py +++ b/python/skyweaver/skyweaver.py @@ -7,18 +7,26 @@ import logging import textwrap import ctypes -from typing import Any +from typing import Any, Tuple from dataclasses import dataclass from typing_extensions import Self - +import sys +from collections import defaultdict +from matplotlib.patches import Ellipse +import matplotlib.patches as mpatches # 3rd party imports import h5py import yaml import numpy as np +import pandas as pd +import random from rich.progress import track from astropy.time import Time, TimeDelta from astropy import units as u +from astropy.coordinates import SkyCoord from astropy.units import Quantity +from astropy import wcs +import matplotlib.pyplot as plt from katpoint import Target, Antenna from mosaic.beamforming import ( DelayPolynomial, @@ -228,6 +236,8 @@ def add_tiling( for ii, (ra, dec) in enumerate(coordinates): self._targets.append( (Target(f"{prefix}_{ii:04d},radec,{ra},{dec}"), sub_array_idx)) + + def add_beam(self, target: Target, subarray: Subarray = None) -> None: """Add a single beam to the engine @@ -245,6 +255,7 @@ def add_beam(self, target: Target, subarray: Subarray = None) -> None: sub_array_idx = len(self._subarray_sets) - 1 self._targets.append((target, sub_array_idx)) + def _extract_weights(self) -> np.ndarray: # Here we extract the weights for each beam/antenna # as an optimisation we cache the antenna mask per @@ 
-649,6 +660,7 @@ class BeamSet: A beam set is a collection of beams which are formed from a common subarray. """ + name: str anntenna_names: list[str] beams: list[Target] tilings: list[dict] @@ -689,8 +701,10 @@ def from_file(cls, config_file: str) -> Self: bfc = data["beamformer_config"] beam_sets = [] for bs in data["beam_sets"]: + print(bs) beam_sets.append( BeamSet( + bs["name"], bs["antenna_set"], bs["beams"], bs["tilings"], @@ -706,11 +720,10 @@ def from_file(cls, config_file: str) -> Self: beam_sets ) - def make_tiling( pointing: PointingMetadata, subarray: Subarray, - tiling_desc: dict) -> Tiling: + tiling_desc: dict) -> Tuple[Tiling, BeamShape]: """Make a tiling using the complete mosaic tiling options Args: @@ -779,15 +792,37 @@ def make_tiling( antenna_strings: list[str] = [ ant.format_katcp() for ant in subarray.antenna_positions ] + psfsim: PsfSim = PsfSim(antenna_strings, ref_freq) - beam_shape: BeamShape = psfsim.get_beam_shape(target, epoch.unix) + psf_beam_shape: BeamShape = psfsim.get_beam_shape(target, epoch.unix) + #Build Mosaic command here. 
Remove T + mosaic_epoch = epoch.iso.replace("T", " ") + mosaic_epoch = mosaic_epoch.replace("-", ".") + mosaic_antenna_string = ','.join([item.replace('m', '') for item in subarray.names]) + mosaic_command=f"python maketiling.py --freq {ref_freq} --source {target.body._ra} {target.body._dec} --datetime {mosaic_epoch} --subarray {mosaic_antenna_string} --verbose --tiling_method {method} --tiling_shape {shape} --ants antenna.csv --beamnum {nbeams} --overlap {overlap}" + + tiling: Tiling = generate_nbeams_tiling( - beam_shape, nbeams, overlap, + psf_beam_shape, nbeams, overlap, method, shape, parameter=shape_params, coordinate_type=coordinate_type) - return tiling - + return tiling, psf_beam_shape, mosaic_command + +def pad_ra_dec(ra, dec): + # Split RA and Dec into components + ra_parts = ra.split(':') + dec_parts = dec.split(':') + + # Pad RA hours and Dec degrees to two digits + ra_parts[0] = ra_parts[0].zfill(2) # Ensure two digits for RA hours + dec_parts[0] = dec_parts[0].zfill(2) # Ensure two digits for Dec degrees + + # Join the parts back together + padded_ra = ':'.join(ra_parts) + padded_dec = ':'.join(dec_parts) + + return padded_ra, padded_dec def create_delays( session_metadata: SessionMetadata, @@ -795,7 +830,8 @@ def create_delays( pointing: PointingMetadata, start_epoch: Time = None, end_epoch: Time = None, - step: TimeDelta = 4 * u.s) -> list[DelayModel]: + step: TimeDelta = 4 * u.s, + outfile: str = None) -> list[DelayModel]: """Create a set of delay models Args: @@ -808,6 +844,7 @@ def create_delays( Defaults to the end of the pointing. step (TimeDelta, optional): The step size between consequtive solutions. Defaults to 4*u.s. + outfile (str, optional): The path to write the delay models to. Defaults is None. 
Returns: list[DelayModel]: A list of delay models @@ -824,27 +861,381 @@ def create_delays( log.info("Step size: %s", step.to(u.s)) full_subarray = om.get_subarray() de = DelayEngine(full_subarray, pointing.phase_centre) - bs_tilings = [] + # Initialize beam_set_lookup dictionary to track unique beam sets + beam_set_lookup = {} + beam_set_id = 0 + plot_beams = [] + neighbouring_beams = [] + #Initialise the known beams. PSF size is calculated for 50% overlap + nbeams_requested = 1 + overlap = 0.5 + #Iterate beams first through all beam sets for bs in bc.beam_sets: - tilings = [] - subarray_subset = om.get_subarray(bs.anntenna_names) - # add beams first as these are likely more important than tilings + sorted_antennas = sorted(bs.anntenna_names) + subarray_subset = om.get_subarray(sorted_antennas) + antenna_string = ','.join(sorted_antennas) if bs.beams is not None: for target_desc in bs.beams: - de.add_beam(Target(target_desc), subarray_subset) + target = Target(target_desc) + #Add the beam to the delay engine + de.add_beam(target, subarray_subset) + beam_key = (antenna_string, overlap, nbeams_requested) + if beam_key not in beam_set_lookup: + beam_set_lookup[beam_key] = beam_set_id + beam_set_id += 1 + ra = str(target.body._ra) + dec = str(target.body._dec) + # Pad RA and Dec to two digits + ra, dec = pad_ra_dec(ra, dec) + name = target.name + current_beam_set_id = beam_set_lookup[beam_key] + #Get the PSF Beam shape at 50 % overlap for beams defined in yaml. 
+ tiling_desc = { + "nbeams": nbeams_requested, + "overlap": overlap, + "target": target_desc, + } + _, psf_beam_shape,_ = make_tiling(pointing, subarray_subset, tiling_desc) + #Add the beam to the plot beams list + plot_beams.append((name, ra, dec, round(psf_beam_shape.axisH, 5), round(psf_beam_shape.axisV, 5), round(psf_beam_shape.angle, 5), current_beam_set_id, 0.5, len(sorted_antennas), 'known')) + neighbouring_beams.append((name, ra, dec, round(psf_beam_shape.axisH, 5), round(psf_beam_shape.axisV, 5), round(psf_beam_shape.angle, 5), current_beam_set_id, 0.5, len(sorted_antennas), 'known')) + + #Iterate through all tilings + bs_tilings = [] + for bs in bc.beam_sets: + tilings = [] + sorted_antennas = sorted(bs.anntenna_names) + subarray_subset = om.get_subarray(sorted_antennas) + antenna_string = ','.join(sorted_antennas) if bs.tilings is not None: + output_prefix = f"{outfile}_{bs.name}" if outfile is not None else None for tiling_desc in bs.tilings: - tiling = make_tiling(pointing, subarray_subset, tiling_desc) + tiling, psf_beamshape, mosaic_command = make_tiling(pointing, subarray_subset, tiling_desc) + #Add the tiling to the delay engine de.add_tiling(tiling, subarray_subset) tilings.append(tiling) + cb_beamshape = tiling.meta["axis"][:3] #axisH, axisV, angle + overlap = tiling.meta["axis"][-1] + nbeams_requested = tiling_desc['nbeams'] + beam_key = (antenna_string, overlap, nbeams_requested) + if beam_key not in beam_set_lookup: + beam_set_lookup[beam_key] = beam_set_id + beam_set_id += 1 + current_beam_set_id = beam_set_lookup[beam_key] + coords = tiling.get_equatorial_coordinates() + coords = SkyCoord(coords, unit=u.deg) + mosaic_command+=f" --tiling_plot {output_prefix}_bid_{beam_set_id}.png --tiling_coordinate {output_prefix}_bid_{beam_set_id}.csv" + log.info("Mosaic command written to %s", f"{outfile}.mosaic") + log.info(f"Writing PSF of BeamSet {bs.name} to {output_prefix}_bid_{beam_set_id}.fits") + 
psf_beamshape.psf.write_fits(f"{output_prefix}_bid_{beam_set_id}.fits") + log.info(f"Writing PSF Plot of BeamSet {bs.name} to {output_prefix}_bid_{beam_set_id}.png") + psf_beamshape.plot_psf(f"{output_prefix}_bid_{beam_set_id}.png") + with open(f"{outfile}.mosaic", "a") as f: + f.write(mosaic_command + "\n") + for index, coord in enumerate(coords): + ra_hms = coord.ra.to_string(unit=u.hour, sep=':', precision=2, pad=True) + dec_dms = coord.dec.to_string(unit=u.degree, sep=':', precision=1, alwayssign=True, pad=True) + plot_beams.append((f"{bs.name}_{index:03d}", ra_hms, dec_dms, round(cb_beamshape[0], 5), round(cb_beamshape[1], 5), round(cb_beamshape[2], 5), current_beam_set_id, overlap, len(sorted_antennas), 'tiling')) + neighbouring_beams.append((f"{bs.name}_{index:03d}", ra_hms, dec_dms, round(psf_beamshape.axisH, 5), round(psf_beamshape.axisV, 5), round(psf_beamshape.angle, 5), current_beam_set_id, 0.5, len(sorted_antennas), 'tiling')) bs_tilings.append(tilings) - log.info( - "Calculating solutions for %d antennas and %d beams", - full_subarray.nantennas, de.nbeams) + + target = tiling_desc.get("target", None) + if target is None: + print("No target specified") + target = pointing.phase_centre + else: + target = Target(target) + column_list = ['name', 'ra', 'dec', 'x', 'y', 'angle', 'beam_set_id', 'overlap', 'nantennas', 'type'] + plot_beams_df = pd.DataFrame(plot_beams, columns=column_list) + #Plot beams has beam shape for the overlap requested, wheras neighbouring beams has the PSF shape (50% overlap) + neighbouring_beams_df = pd.DataFrame(neighbouring_beams, columns=column_list) + log.info("Calculating solutions for %d antennas and %d beams", full_subarray.nantennas, de.nbeams) delays = de.calculate_delays(start_epoch, end_epoch, step) + log.info("Beams and tilings written to %s.targets", outfile) + plot_beams_df.to_csv(outfile + ".targets", index=False) + boresight_coords = psf_beamshape.bore_sight.equatorial + plot_multiple_tilings(pointing, 
neighbouring_beams_df, plot_beams_df, boresight_coords, outfile) return delays, de.targets, bs_tilings +def plot_multiple_tilings(pointing, neighbouring_beams_df, plot_beams_df, boresight_coords, outfile, HD=True, beam_size_scaling=1.0, annotate_beam_names=False): + # Initialize WCS projection + wcs_properties = wcs.WCS(naxis=2) + wcs_properties.wcs.crpix = [0, 0] + wcs_properties.wcs.ctype = ["RA---TAN", "DEC--TAN"] + wcs_properties.wcs.crval = boresight_coords + center = boresight_coords + + # Use neutral scaling for overlap detection + wcs_properties.wcs.cdelt = [1, 1] # Neutral pixel scale + + detector = BeamOverlapDetector(neighbouring_beams_df, wcs_properties) + overlaps_df = detector.find_overlapping_beams() + overlaps_df.to_csv(outfile + "_overlapping_beams.csv", index=False) + + # Update WCS projection for plotting + step = 1 / 10000000000.0 + wcs_properties.wcs.cdelt = [-step, step] + resolution = step + + thisDPI = 300 + if HD: + width = 3200. + extra_source_text_size = 8 + else: + width = 800. 
+ extra_source_text_size = 3 + + fig = plt.figure(figsize=(width/thisDPI, width/thisDPI), dpi=thisDPI) + axis = fig.add_subplot(111, aspect='equal', projection=wcs_properties) + + # Define color palette + color_palette = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33', '#a65628', '#f781bf'] + + # Extract relevant data from DataFrame in one step for efficiency and clarity + beam_ra, beam_dec, beam_name, beam_set_id, beam_x, beam_y, beam_angle, beam_type, nantennas, overlap = ( + plot_beams_df['ra'].astype(str), + plot_beams_df['dec'].astype(str), + plot_beams_df['name'].astype(str), + plot_beams_df['beam_set_id'].astype(int), + plot_beams_df['x'].astype(float), + plot_beams_df['y'].astype(float), + plot_beams_df['angle'].astype(float), + plot_beams_df['type'].astype(str), + plot_beams_df['nantennas'].astype(int), + plot_beams_df['overlap'].astype(float) + ) + + # Get equatorial coordinates + equatorialCoordinates = SkyCoord(beam_ra, beam_dec, frame='fk5', unit=(u.hourangle, u.deg)) + beam_coordinate = np.array(wcs_properties.wcs_world2pix(np.array([equatorialCoordinates.ra.deg, equatorialCoordinates.dec.deg]).T, 0)) + + # Plot the boresight for reference + axis.plot(0, 0, marker='+', markersize=15, color='black') + + # Store labels for legend creation based on beam_set_id + labels = {} + + # Loop through all beams to create ellipses and prepare legend text + for idx in range(len(beam_coordinate)): + coord = beam_coordinate[idx] + + # Create the ellipse for each beam + ellipse = Ellipse( + xy=coord, + width=2.0 * beam_x.iloc[idx] * beam_size_scaling / resolution, + height=2.0 * beam_y.iloc[idx] * beam_size_scaling / resolution, + angle=beam_angle.iloc[idx] + ) + ellipse.fill = False + color = color_palette[beam_set_id.iloc[idx] % len(color_palette)] + ellipse.set_edgecolor(color) + axis.add_artist(ellipse) + + # Prepare legend text based on the beam type criteria + if beam_type.iloc[idx] == 'known': + legend_text = f"Beam set{beam_set_id.iloc[idx]}: 
{beam_type.iloc[idx]}, {nantennas.iloc[idx]} antennas" + axis.annotate( + beam_name.iloc[idx], + xy=coord, + xytext=(coord[0] + 5, coord[1] + 5), # Text position, slightly offset + fontsize=8 + ) + else: + legend_text = f"Beam set{beam_set_id.iloc[idx]}: {beam_name.iloc[idx].split('_')[0]}, {nantennas.iloc[idx]} antennas, {overlap.iloc[idx]} overlap" + # Add an annotation with the beam name. Be careful with the font size, as it can be too large for the plot + if annotate_beam_names: + axis.annotate(beam_name.iloc[idx], xy=coord, xytext=(coord[0], coord[1]), fontsize=extra_source_text_size) + # If the legend entry for this beam_set_id does not exist yet, create it + if beam_set_id.iloc[idx] not in labels: + labels[beam_set_id.iloc[idx]] = (ellipse, legend_text) + + # Create a legend for all beam categories (all types) + handles = [labels[k][0] for k in labels] + legend_texts = [labels[k][1] for k in labels] + axis.legend(handles, legend_texts, loc='upper right') + margin = 1.1 * max(np.sqrt(np.sum(np.square(beam_coordinate), axis=1))) + axis.set_xlim(center[0] - margin, center[0] + margin) + axis.set_ylim(center[1] - margin, center[1] + margin) + axis.set_xlabel('RA', fontsize=30) + axis.set_ylabel('Dec', fontsize=30) + + # Save the original, unzoomed plot + # output_filename_unzoomed = f"{outfile}_original.png" + # log.info(f"Saving original unzoomed plot to {output_filename_unzoomed}") + # plt.tight_layout() + # plt.savefig(output_filename_unzoomed) + + # Now zoom into the tiling beam sets and save separate plots + tiling_df = plot_beams_df[plot_beams_df['type'] == 'tiling'] + + for beam_set, group_df in tiling_df.groupby('beam_set_id'): + equatorialCoordinates_tiling = SkyCoord(group_df['ra'].astype(str), group_df['dec'].astype(str), frame='fk5', unit=(u.hourangle, u.deg)) + beam_coordinate_tiling = np.array(wcs_properties.wcs_world2pix(np.array([equatorialCoordinates_tiling.ra.deg, equatorialCoordinates_tiling.dec.deg]).T, 0)) + # Recalculate margin for this 
specific beam set + margin = 1.3 * max(np.sqrt(np.sum(np.square(beam_coordinate_tiling), axis=1))) + + # Set the new axis limits for zooming into this region + axis.set_xlim(center[0] - margin, center[0] + margin) + axis.set_ylim(center[1] - margin, center[1] + margin) + + # Save the zoomed plot for the current beam_set_id + output_filename_zoomed = f"{outfile}_tiling_beamset_{beam_set}.png" + log.info(f"Saving zoomed plot for beam_set_id {beam_set} to {output_filename_zoomed}") + plt.title(f"Boresight: {pointing.phase_centre.name}, UTC Start {pointing.start_epoch.isot}, Zoomed into Beam set {beam_set}") + plt.tight_layout() + plt.savefig(output_filename_zoomed) + + +class BeamOverlapDetector: + def __init__(self, neighbour_df, wcs_properties): + self.neighbour_df = neighbour_df + self.wcs_properties = wcs_properties + self._convert_coordinates() + self._convert_to_pixel_coordinates() + + def _convert_coordinates(self): + # Convert RA and Dec from strings to degrees + coords = SkyCoord(ra=self.neighbour_df['ra'].values, dec=self.neighbour_df['dec'].values, unit=(u.hourangle, u.deg)) + self.neighbour_df['ra_deg'] = coords.ra.deg + self.neighbour_df['dec_deg'] = coords.dec.deg + + def _convert_to_pixel_coordinates(self): + # Convert RA and Dec to pixel coordinates using WCS + sky_coords = SkyCoord(ra=self.neighbour_df['ra_deg'].values, dec=self.neighbour_df['dec_deg'].values, unit='deg') + pixel_coords = np.array(self.wcs_properties.wcs_world2pix(np.array([sky_coords.ra.deg, sky_coords.dec.deg]).T, 0)) + self.neighbour_df['x_pix'], self.neighbour_df['y_pix'] = pixel_coords[:, 0], pixel_coords[:, 1] + + + def ellipse_parametric(self, t, a, b, x0, y0, theta): + cos_t = np.cos(t) + sin_t = np.sin(t) + x = x0 + a * cos_t * np.cos(theta) - b * sin_t * np.sin(theta) + y = y0 + a * cos_t * np.sin(theta) + b * sin_t * np.cos(theta) + return x, y + + def point_in_ellipse(self, x, y, ellipse): + x0, y0, a, b, theta = ellipse["x0"], ellipse["y0"], ellipse["a"], ellipse["b"], 
ellipse["theta"] + cos_theta = np.cos(-theta) + sin_theta = np.sin(-theta) + xr = cos_theta * (x - x0) - sin_theta * (y - y0) + yr = sin_theta * (x - x0) + cos_theta * (y - y0) + return (xr**2 / a**2) + (yr**2 / b**2) <= 1 + + def check_containment(self, ellipse1, ellipse2): + # Check if the center of ellipse1 is inside ellipse2 + x0, y0 = ellipse1["x0"], ellipse1["y0"] + return self.point_in_ellipse(x0, y0, ellipse2) + + def discrete_overlap(self, ellipse1, ellipse2, num_points=100): + # First, check if one ellipse contains the center of the other + if self.check_containment(ellipse1, ellipse2) or self.check_containment(ellipse2, ellipse1): + return True + + # Check points on the perimeters of both ellipses + t_values = np.linspace(0, 2 * np.pi, num_points) + #psf_x, psf_y -> semi-major axis (a), semi-minor axis (b) of the ellipse + x1, y1 = self.ellipse_parametric(t_values, ellipse1["a"], ellipse1["b"], ellipse1["x0"], ellipse1["y0"], ellipse1["theta"]) + + # Check if any points on the perimeter of ellipse1 lie inside ellipse2 + for x, y in zip(x1, y1): + if self.point_in_ellipse(x, y, ellipse2): + return True + return False + + def find_nearest_neighbours(self, beam, beams, n=6): + # Create a SkyCoord object for the target beam + target_coord = SkyCoord(ra=beam['ra_deg'] * u.deg, dec=beam['dec_deg'] * u.deg) + + # Create SkyCoord objects for all other beams + all_coords = SkyCoord(ra=[b['ra_deg'] for b in beams] * u.deg, + dec=[b['dec_deg'] for b in beams] * u.deg) + + # Calculate separations + separations = target_coord.separation(all_coords) + + # Sort by separation and get the nearest neighbours + nearest_indices = np.argsort(separations)[1:n+1] # Skip the first one (itself) + + # Return the list of nearest beam names in descending order of separation + nearest_neighbours = [beams[i]['name'] for i in nearest_indices] + return nearest_neighbours + + def find_overlapping_beams(self): + beams = [] + for _, row in self.neighbour_df.iterrows(): + beams.append({ + 
"name": row['name'], + "x0": row['x_pix'], + "y0": row['y_pix'], + "a": row['x'], + "b": row['y'], + "theta": np.radians(row['angle']), + "beam_set_id": row['beam_set_id'], + "ra_deg": row['ra_deg'], + "dec_deg": row['dec_deg'] + }) + beam_overlap_dict = {} + for i, beam1 in enumerate(beams): + # Initialize overlap list for beam1 if not already in the dict + if beam1['name'] not in beam_overlap_dict: + beam_overlap_dict[beam1['name']] = { + "name": beam1['name'], + "x_pix": beam1['x0'], + "y_pix": beam1['y0'], + "a": beam1['a'], + "b": beam1['b'], + "angle": np.degrees(beam1['theta']), + "beam_set_id": beam1['beam_set_id'], + "overlapping_beams": [], # Empty list initially + "neighbouring_beams": [] # Empty list for nearest neighbours + } + + # Inner loop to check overlaps + for j in range(i + 1, len(beams)): # Only check pairs once + beam2 = beams[j] + + # Avoid checking beams with the same beam_set_id. They have same PSF shape + if beam1['beam_set_id'] == beam2['beam_set_id']: + continue + + # Initialize overlap list for beam2 if not already in the dict + if beam2['name'] not in beam_overlap_dict: + beam_overlap_dict[beam2['name']] = { + "name": beam2['name'], + "x_pix": beam2['x0'], + "y_pix": beam2['y0'], + "a": beam2['a'], + "b": beam2['b'], + "angle": np.degrees(beam2['theta']), + "beam_set_id": beam2['beam_set_id'], + "overlapping_beams": [], # Empty list initially + "neighbouring_beams": [] # Empty list for nearest neighbours + } + + # Check if beam1 and beam2 overlap + if self.discrete_overlap(beam1, beam2): + # Update overlap lists for both beams + beam_overlap_dict[beam1['name']]['overlapping_beams'].append(beam2['name']) + beam_overlap_dict[beam2['name']]['overlapping_beams'].append(beam1['name']) + + # Find the 6 nearest neighbours for beam1 only if beam set IDs match + matching_beams = [b for b in beams if b['beam_set_id'] == beam1['beam_set_id']] + nearest_neighbours = self.find_nearest_neighbours(beam1, matching_beams) + 
beam_overlap_dict[beam1['name']]['neighbouring_beams'] = nearest_neighbours + + # Convert the dictionary to a DataFrame + overlap_results = pd.DataFrame(beam_overlap_dict.values()) + overlap_results = overlap_results.sort_values(by=['beam_set_id', 'name']) + + return overlap_results + + + + + + + def main(): """ What does this thing actually do?