diff --git a/src/c/CMakeLists.txt b/src/c/CMakeLists.txt index f1c1be9..232069b 100644 --- a/src/c/CMakeLists.txt +++ b/src/c/CMakeLists.txt @@ -1,7 +1,10 @@ cmake_minimum_required(VERSION 3.10) -project(pyraview C) +project(pyraview C CXX) -find_package(OpenMP) +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(Threads REQUIRED) # Use absolute path for include to avoid relative path headaches include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../include) @@ -11,13 +14,8 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) -add_library(pyraview SHARED pyraview.c) - -if(OpenMP_C_FOUND) - target_link_libraries(pyraview PUBLIC OpenMP::OpenMP_C) -else() - message(WARNING "OpenMP not found. Compiling without parallel support.") -endif() +add_library(pyraview SHARED pyraview.cpp) +target_link_libraries(pyraview PRIVATE Threads::Threads) if(NOT WIN32) target_link_libraries(pyraview PRIVATE m) diff --git a/src/c/pyraview.c b/src/c/pyraview.cpp similarity index 61% rename from src/c/pyraview.c rename to src/c/pyraview.cpp index feafb1e..39faaa9 100644 --- a/src/c/pyraview.c +++ b/src/c/pyraview.cpp @@ -5,6 +5,12 @@ #include #include +#include +#include +#include +#include +#include + #ifdef _WIN32 #include #define pv_fseek _fseeki64 @@ -15,12 +21,6 @@ #define pv_ftell ftello #endif -#ifdef _OPENMP -#include -#else -#define omp_get_max_threads() 1 -#endif - #include // Utility: Write header @@ -63,13 +63,7 @@ static int pv_validate_or_create(FILE** f_out, const char* filename, int channel return 0; // Mismatch } // Verify startTime is valid (not necessarily matching, just valid double) - // But usually for appending, start time should be consistent or we accept the existing one. - // The prompt says "verify that the startTime in the existing file is valid". - // We'll check for NaN or Inf as a basic validity check. if (isnan(h.startTime) || isinf(h.startTime)) { - // If it's invalid, maybe fail? Or just proceed? - // Given "primary check remains channelCount and dataType", maybe just warn or ignore? - // The prompt implies a check. Let's return error if invalid. fclose(f); return -1; // Invalid start time in existing file } @@ -137,95 +131,131 @@ static int pv_internal_execute_##SUFFIX( \ int64_t input_stride = (layout == 0) ? C : 1; \ int64_t channel_step = (layout == 0) ? 1 : R; \ \ - /* Determine max threads */ \ - int max_threads = (nThreads > 0) ? nThreads : omp_get_max_threads(); \ + /* Determine effective threads */ \ + int effective_threads = (nThreads > 0) ? nThreads : std::thread::hardware_concurrency(); \ + if (effective_threads < 1) effective_threads = 1; \ \ - /* Parallel Loop */ \ - int64_t ch; \ - _Pragma("omp parallel for ordered num_threads(max_threads)") \ - for (ch = 0; ch < C; ch++) { \ - const T* ch_data = data + (ch * channel_step); \ - \ - /* Allocate buffers for this channel's output */ \ - /* Using malloc for buffers to avoid stack overflow */ \ - T* buffers[16]; \ - for(int i=0; i<16; i++) buffers[i] = NULL; \ - int64_t sizes[16]; \ - int64_t prev_len = R; \ - int alloc_failed = 0; \ - \ - for (int i = 0; i < nLevels; i++) { \ - int64_t out_len = prev_len / steps[i]; \ - /* We output Min/Max pairs, so 2 * out_len */ \ - sizes[i] = out_len; \ - if (out_len > 0) { \ - buffers[i] = (T*)malloc(out_len * 2 * sizeof(T)); \ - if (!buffers[i]) { alloc_failed = 1; break; } \ - } \ - prev_len = out_len; \ - } \ - \ - if (!alloc_failed) { \ - /* Compute L1 from Raw */ \ - if (sizes[0] > 0) { \ - int step = steps[0]; \ - T* out = buffers[0]; \ - int64_t count = sizes[0]; \ - for (int64_t i = 0; i < count; i++) { \ - T min_val = ch_data[i * step * input_stride]; \ - T max_val = min_val; \ - for (int j = 1; j < step; j++) { \ - T val = ch_data[(i * step + j) * input_stride]; \ - if (val < min_val) min_val = val; \ - if (val > max_val) max_val = val; \ + /* Synchronization primitives */ \ + std::atomic next_channel(0); \ + std::atomic next_write_ticket(0); \ + std::mutex write_mutex; \ + std::condition_variable write_cv; \ + std::atomic error_occurred(0); \ + \ + auto worker = [&]() { \ + while (true) { \ + /* Atomic fetch of work */ \ + int64_t ch = next_channel.fetch_add(1); \ + if (ch >= C) break; \ + \ + /* If global error, we skip work but must still process ticket */ \ + int skip_work = error_occurred.load(); \ + \ + const T* ch_data = data + (ch * channel_step); \ + T* buffers[16]; \ + for(int i=0; i<16; i++) buffers[i] = NULL; \ + int64_t sizes[16]; \ + int64_t prev_len = R; \ + int alloc_failed = 0; \ + \ + if (!skip_work) { \ + /* Buffer Allocation */ \ + for (int i = 0; i < nLevels; i++) { \ + int64_t out_len = prev_len / steps[i]; \ + sizes[i] = out_len; \ + if (out_len > 0) { \ + buffers[i] = (T*)malloc(out_len * 2 * sizeof(T)); \ + if (!buffers[i]) { alloc_failed = 1; break; } \ } \ - out[2*i] = min_val; \ - out[2*i+1] = max_val; \ + prev_len = out_len; \ } \ - } \ - \ - /* Compute L2..Ln from previous level */ \ - for (int lvl = 1; lvl < nLevels; lvl++) { \ - if (sizes[lvl] > 0) { \ - int step = steps[lvl]; \ - T* prev_buf = buffers[lvl-1]; \ - T* out = buffers[lvl]; \ - int64_t count = sizes[lvl]; \ - for (int64_t i = 0; i < count; i++) { \ - T min_val = prev_buf[i * step * 2]; \ - T max_val = prev_buf[i * step * 2 + 1]; \ - for (int j = 1; j < step; j++) { \ - T p_min = prev_buf[(i * step + j) * 2]; \ - T p_max = prev_buf[(i * step + j) * 2 + 1]; \ - if (p_min < min_val) min_val = p_min; \ - if (p_max > max_val) max_val = p_max; \ + \ + if (!alloc_failed) { \ + /* Compute L1 */ \ + if (sizes[0] > 0) { \ + int step = steps[0]; \ + T* out = buffers[0]; \ + int64_t count = sizes[0]; \ + for (int64_t i = 0; i < count; i++) { \ + T min_val = ch_data[i * step * input_stride]; \ + T max_val = min_val; \ + for (int j = 1; j < step; j++) { \ + T val = ch_data[(i * step + j) * input_stride]; \ + if (val < min_val) min_val = val; \ + if (val > max_val) max_val = val; \ + } \ + out[2*i] = min_val; \ + out[2*i+1] = max_val; \ + } \ + } \ + /* Compute L2..Ln */ \ + for (int lvl = 1; lvl < nLevels; lvl++) { \ + if (sizes[lvl] > 0) { \ + int step = steps[lvl]; \ + T* prev_buf = buffers[lvl-1]; \ + T* out = buffers[lvl]; \ + int64_t count = sizes[lvl]; \ + for (int64_t i = 0; i < count; i++) { \ + T min_val = prev_buf[i * step * 2]; \ + T max_val = prev_buf[i * step * 2 + 1]; \ + for (int j = 1; j < step; j++) { \ + T p_min = prev_buf[(i * step + j) * 2]; \ + T p_max = prev_buf[(i * step + j) * 2 + 1]; \ + if (p_min < min_val) min_val = p_min; \ + if (p_max > max_val) max_val = p_max; \ + } \ + out[2*i] = min_val; \ + out[2*i+1] = max_val; \ + } \ } \ - out[2*i] = min_val; \ - out[2*i+1] = max_val; \ } \ + } else { \ + /* Allocation failed */ \ + error_occurred.store(1); \ } \ } \ \ - /* Write to files sequentially */ \ - _Pragma("omp ordered") \ + /* Ordered Write Section */ \ { \ - for (int i = 0; i < nLevels; i++) { \ - if (sizes[i] > 0 && buffers[i]) { \ - fwrite(buffers[i], sizeof(T), sizes[i] * 2, files[i]); \ + std::unique_lock lock(write_mutex); \ + write_cv.wait(lock, [&]{ return next_write_ticket.load() == ch; }); \ + \ + if (!skip_work && !alloc_failed && !error_occurred.load()) { \ + for (int i = 0; i < nLevels; i++) { \ + if (sizes[i] > 0 && buffers[i]) { \ + if (fwrite(buffers[i], sizeof(T), sizes[i] * 2, files[i]) != (size_t)(sizes[i] * 2)) { \ + error_occurred.store(1); \ + } \ + } \ } \ } \ + \ + next_write_ticket.store(ch + 1); \ + write_cv.notify_all(); \ + } \ + \ + /* Cleanup buffers */ \ + for (int i = 0; i < nLevels; i++) { \ + if(buffers[i]) free(buffers[i]); \ } \ } \ - \ - /* Cleanup buffers */ \ - for (int i = 0; i < nLevels; i++) { \ - if(buffers[i]) free(buffers[i]); \ - } \ + }; \ + \ + /* Spawn Threads */ \ + std::vector threads; \ + for (int i = 0; i < effective_threads; ++i) { \ + threads.emplace_back(worker); \ + } \ + \ + /* Join Threads */ \ + for (auto& t : threads) { \ + t.join(); \ } \ \ /* Close files */ \ for (int i = 0; i < nLevels; i++) fclose(files[i]); \ - return ret; \ + \ + return error_occurred.load() ? -1 : 0; \ } // Instantiate workers @@ -240,7 +270,9 @@ DEFINE_WORKER(uint64_t, u64) DEFINE_WORKER(float, f32) DEFINE_WORKER(double, f64) -// Master Dispatcher +// Master Dispatcher (Extern C for ABI compatibility) +extern "C" { + int pyraview_process_chunk( const void* dataArray, int64_t numRows, @@ -302,3 +334,5 @@ int pyraview_get_header(const char* filename, PyraviewHeader* header) { if (memcmp(header->magic, "PYRA", 4) != 0) return -1; return 0; } + +} // extern "C" diff --git a/src/matlab/build_pyraview.m b/src/matlab/build_pyraview.m index 6a3ab15..20e5b8b 100644 --- a/src/matlab/build_pyraview.m +++ b/src/matlab/build_pyraview.m @@ -2,34 +2,24 @@ % Build script for Pyraview MEX % Paths relative to src/matlab/ -src_path = '../../src/c/pyraview.c'; +src_path = '../../src/c/pyraview.cpp'; include_path = '-I../../include'; % Source files inside +pyraview mex_src = '+pyraview/pyraview_mex.c'; header_src = '+pyraview/pyraview_get_header_mex.c'; -% OpenMP flags (adjust for OS/Compiler) -if ispc - % Windows MSVC usually supports /openmp - omp_flags = {'COMPFLAGS="$COMPFLAGS /openmp"'}; -elseif ismac - % MacOS (Clang) usually requires libomp installed and -Xpreprocessor flags. - % For simplicity in CI, we disable OpenMP on Mac. - fprintf('MacOS detected: Disabling OpenMP.\n'); - omp_flags = {}; -else - % Linux (GCC) - % Pass as separate arguments to avoid quoting issues - omp_flags = {'CFLAGS="$CFLAGS -fopenmp"', 'LDFLAGS="$LDFLAGS -fopenmp"'}; -end - % Output directory: +pyraview/ out_dir = '+pyraview'; fprintf('Building Pyraview MEX...\n'); try - mex('-v', '-outdir', out_dir, '-output', 'pyraview_mex', include_path, src_path, mex_src, omp_flags{:}); + % Build the main engine MEX + % Note: mex will compile .cpp file as C++. + % It will link with .c file (compiled as C). + % No OpenMP flags needed as we use C++11 std::thread + fprintf('Building pyraview_mex...\n'); + mex('-v', '-outdir', out_dir, '-output', 'pyraview_mex', include_path, src_path, mex_src); fprintf('Build pyraview_mex successful.\n'); fprintf('Building pyraview_get_header_mex...\n');