diff --git a/HACKING b/HACKING index bba23cd..cda300b 100644 --- a/HACKING +++ b/HACKING @@ -79,7 +79,7 @@ addressing so you don't have to worry about row major order and column major order. The rwt_printf, mat_offset, offset_row, and offset_col macros will be very useful if you need to change any of the code that uses the mat() macro. -To understand the code for the transforms themselves, start with lib/src/dwt.c +To understand the code for the transforms themselves, start with lib/src/dwt.cc which is the best documented of the transforms. The rest of them are written and structured in a very similar fashion. @@ -87,7 +87,7 @@ The flow of the code is as follows. One of the transforms is called from MATLAB. This invokes one of the wrappers from the mex directory. The function here calls rwt_matlab_init in lib/src/init.c which calls other init functions. From here the mex wrapper calls the transform in lib/src. For example, the mdwt function -for the discrete wavlet transform calls dwt() in the lib/src/dwt.c file. This +for the discrete wavlet transform calls dwt() in the lib/src/dwt.cc file. This function has a few helpers in the same file. It allocates memory necessary for the transform in dwt_allocate(), calculates the high and low pass coefficients in dwt_coefficients(), performs the convolution in dwt_convolution, and frees diff --git a/bin/compile.m b/bin/compile.m index 800b861..b3f6e50 100644 --- a/bin/compile.m +++ b/bin/compile.m @@ -2,20 +2,20 @@ % if exist('OCTAVE_VERSION', 'builtin') - mkoctfile --mex -v -DOCTAVE_MEX_FILE ../mex/mdwt.c ../lib/src/dwt.c ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -o omdwt.mex - mkoctfile --mex -v -DOCTAVE_MEX_FILE ../mex/midwt.c ../lib/src/idwt.c ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -o omidwt.mex + mkoctfile --mex -v -DOCTAVE_MEX_FILE ../mex/mdwt.c ../lib/src/dwt.cc ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -o omdwt.mex + mkoctfile --mex -v -DOCTAVE_MEX_FILE ../mex/midwt.c ../lib/src/idwt.cc ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -o omidwt.mex mkoctfile --mex -v -DOCTAVE_MEX_FILE ../mex/mrdwt.c ../lib/src/rdwt.c ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -o omrdwt.mex mkoctfile --mex -v -DOCTAVE_MEX_FILE ../mex/mirdwt.c ../lib/src/irdwt.c ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -o omirdwt.mex else x = computer(); if (x(length(x)-1:length(x)) == '64') - mex -v -largeArrayDims ../mex/mdwt.c ../lib/src/dwt.c ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -outdir ../bin - mex -v -largeArrayDims ../mex/midwt.c ../lib/src/idwt.c ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -outdir ../bin + mex -v -largeArrayDims ../mex/mdwt.c ../lib/src/dwt.cc ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -outdir ../bin + mex -v -largeArrayDims ../mex/midwt.c ../lib/src/idwt.cc ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -outdir ../bin mex -v -largeArrayDims ../mex/mrdwt.c ../lib/src/rdwt.c ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -outdir ../bin mex -v -largeArrayDims ../mex/mirdwt.c ../lib/src/irdwt.c ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -outdir ../bin else - mex -v -compatibleArrayDims ../mex/mdwt.c ../lib/src/dwt.c ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -outdir ../bin - mex -v -compatibleArrayDims ../mex/midwt.c ../lib/src/idwt.c ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -outdir ../bin + mex -v -compatibleArrayDims ../mex/mdwt.c ../lib/src/dwt.cc ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -outdir ../bin + mex -v -compatibleArrayDims ../mex/midwt.c ../lib/src/idwt.cc ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -outdir ../bin mex -v -compatibleArrayDims ../mex/mrdwt.c ../lib/src/rdwt.c ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -outdir ../bin mex -v -compatibleArrayDims ../mex/mirdwt.c ../lib/src/irdwt.c ../lib/src/init.c ../lib/src/platform.c -I../lib/inc -outdir ../bin end diff --git a/bin/mdwt.m b/bin/mdwt.m index 2f6967a..f7772fd 100644 --- a/bin/mdwt.m +++ b/bin/mdwt.m @@ -1,19 +1,19 @@ function [y,L] = mdwt(x,h,L) -% [y,L] = mdwt(x,h,L); +% [y,L] = mdwt(x,h[,L[,transdims]]); % % Function computes the discrete wavelet transform y for a 1D or 2D input % signal x using the scaling filter h. % % Input: -% x : finite length 1D or 2D signal (implicitly periodized) +% x : input matrix or nd array ( transform will be done on leading 1 or 2 dims) % h : scaling filter -% L : number of levels. In the case of a 1D signal, length(x) must be -% divisible by 2^L; in the case of a 2D signal, the row and the -% column dimension must be divisible by 2^L. If no argument is -% specified, a full DWT is returned for maximal possible L. +% L : number of levels. In the case of a 1D transform, size(x,1) must be +% divisible by 2^L; for a 2D transform, size(x,2) must also be +% divisible by 2^L. The default is the maximal possible L. +% transdims: 1 or 2 dimensional transform, default is 2 (if size allows ) % % Output: -% y : the wavelet transform of the signal +% y : the wavelet transform of the signal % (see example to understand the coefficients) % L : number of decomposition levels % @@ -36,7 +36,7 @@ % % 2D Example: % -% load test_image +% load test_image % h = daubcqf(4,'min'); % L = 1; % [y,L] = mdwt(test_image,h,L); @@ -58,16 +58,21 @@ % | L,H | H,H | % | | | % `------------------' -% -% where +% +% where % 1 : High pass vertically and high pass horizontally % 2 : Low pass vertically and high pass horizontally % 3 : High pass vertically and low pass horizontally -% 4 : Low pass vertically and Low pass horizontally +% 4 : Low pass vertically and Low pass horizontally % (scaling coefficients) % +% 4D Tensor with 2D transforms across leading Example: % -% +% x=randn(64,64,3,2); +% h=[1,1]*sqrt(.5); +% y=mdwt(x,h); +% y22 = mdwt(x(:,:,2,2),h); +% assert (norm(y(:,:,2,2) - y22 ) < 1e-9) % % See also: midwt, mrdwt, mirdwt % @@ -76,7 +81,7 @@ x = x * 1.0; if (exist('L')) [y,L] = omdwt(x,h,L); - else + else [y,L] = omdwt(x,h); end else diff --git a/bin/midwt.m b/bin/midwt.m index 1d1ecd3..0fcaaff 100644 --- a/bin/midwt.m +++ b/bin/midwt.m @@ -1,18 +1,17 @@ function [y,L] = midwt(x,h,L) -% [x,L] = midwt(y,h,L); -% -% Function computes the inverse discrete wavelet transform x for a 1D or -% 2D input signal y using the scaling filter h. +% [x,L] = midwt(y,h[,L[,transdims]]) +% +% Compute the inverse discrete wavelet transform x for a 1D or +% 2D input signal (or tensor of signals) y using the scaling filter h. % % Input: -% y : finite length 1D or 2D input signal (implicitly periodized) +% y : input wavelet domain coefficients (transform is done over leading dims) % (see function mdwt to find the structure of y) % h : scaling filter -% L : number of levels. In the case of a 1D signal, length(x) must be -% divisible by 2^L; in the case of a 2D signal, the row and the -% column dimension must be divisible by 2^L. If no argument is -% specified, a full inverse DWT is returned for maximal possible -% L. +% L : number of levels. In the case of a 1D transform, size(x,1) must be +% divisible by 2^L; for a 2D transform, size(x,2) must also be +% divisible by 2^L. The default is the maximal possible L. +% transdims: 1 or 2 dimensional transform, default is 2 (if size allows ) % % Output: % x : periodic reconstructed signal @@ -36,7 +35,7 @@ if exist('OCTAVE_VERSION', 'builtin') if (exist('L')) [y,L] = omidwt(x,h,L); - else + else [y,L] = omidwt(x,h); end else diff --git a/lib/inc/rwt_init.h b/lib/inc/rwt_init.h index 310fb18..abe14fd 100644 --- a/lib/inc/rwt_init.h +++ b/lib/inc/rwt_init.h @@ -6,19 +6,22 @@ #include "rwt_platform.h" +typedef struct { +size_t nrows; /*!< The number of rows in the input matrix. Output matrix will match. */ +size_t ncols; /*!< The number of columns in the input matrix. Output matrix will match. */ +size_t nmats; /*!< The number of matrices (actually just the product of any dimensions past 2) */ +int levels; /*!< L, the number of levels for the transform. */ +int ncoeff; /*!< Length of h / the number of scaling coefficients */ +double *scalings; /*!< Wavelet scaling coefficients */ +} rwt_init_params; + +typedef enum {NORMAL_DWT, REDUNDANT_DWT, INVERSE_DWT, INVERSE_REDUNDANT_DWT} transform_t; + #if defined(MATLAB_MEX_FILE) || defined(OCTAVE_MEX_FILE) #include "mex.h" #ifndef OCTAVE_MEX_FILE #include "matrix.h" #endif - typedef struct { - size_t nrows; /*!< The number of rows in the input matrix. Output matrix will match. */ - size_t ncols; /*!< The number of columns in the input matrix. Output matrix will match. */ - int levels; /*!< L, the number of levels for the transform. */ - int ncoeff; /*!< Length of h / the number of scaling coefficients */ - double *scalings; /*!< Wavelet scaling coefficients */ - } rwt_init_params; - typedef enum {NORMAL_DWT, REDUNDANT_DWT, INVERSE_DWT, INVERSE_REDUNDANT_DWT} transform_t; #endif #ifdef __cplusplus diff --git a/lib/inc/rwt_platform.h b/lib/inc/rwt_platform.h index 407eafa..93fc391 100644 --- a/lib/inc/rwt_platform.h +++ b/lib/inc/rwt_platform.h @@ -40,11 +40,11 @@ #define rwt_errormsg(msg) printf("\033[91m%s\033[0m\n", msg); #endif -#ifndef max - #define max(A,B) (A > B ? A : B) +#ifndef MAX +# define MAX(A,B) (A > B ? A : B) #endif -#ifndef min - #define min(A,B) (A < B ? A : B) +#ifndef MIN +# define MIN(A,B) (A < B ? A : B) #endif #define even(x) ((x & 1) ? 0 : 1) diff --git a/lib/inc/rwt_transforms.h b/lib/inc/rwt_transforms.h index 534c7d9..75dcb36 100644 --- a/lib/inc/rwt_transforms.h +++ b/lib/inc/rwt_transforms.h @@ -5,6 +5,7 @@ #define TRANSFORMS_H_ #include +#include "rwt_init.h" #ifdef __cplusplus extern "C" { @@ -15,8 +16,10 @@ extern "C" { * In all cases it is expected that the output array has already been * allocated prior to calling the transform function. */ -void dwt(double *x, size_t nrows, size_t ncols, double *h, int ncoeff, int levels, double *y); -void idwt(double *x, size_t nrows, size_t ncols, double *h, int ncoeff, int levels, double *y); +void dwt_double(const double *x, double *y,const rwt_init_params * parms); +void dwt_float(const float *x, float * y,const rwt_init_params * parms); +void idwt_double(double *x, const double *y,const rwt_init_params * parms); +void idwt_float(float *x, const float *y,const rwt_init_params * parms); void rdwt(double *x, size_t nrows, size_t ncols, double *h, int ncoeff, int levels, double *yl, double *yh); void irdwt(double *x, size_t nrows, size_t ncols, double *h, int ncoeff, int levels, double *yl, double *yh); diff --git a/lib/src/CMakeLists.txt b/lib/src/CMakeLists.txt index cf242af..096b44f 100644 --- a/lib/src/CMakeLists.txt +++ b/lib/src/CMakeLists.txt @@ -1,6 +1,6 @@ include_directories ("${PROJECT_SOURCE_DIR}/lib/inc") -add_library(dwt dwt.c) -add_library(idwt idwt.c) +add_library(dwt dwt.cc) +add_library(idwt idwt.cc) add_library(irdwt irdwt.c) add_library(rdwt rdwt.c) add_library(platform platform.c) diff --git a/lib/src/dwt.c b/lib/src/dwt.c deleted file mode 100644 index 50a970e..0000000 --- a/lib/src/dwt.c +++ /dev/null @@ -1,192 +0,0 @@ -/*! \file dwt.c - \brief Implementation of the discrete wavelet transform - -*/ - -#include "rwt_platform.h" - -/*! - * Perform convolution for dwt - * - * @param x_in input signal values - * @param lx the length of x_in - * @param coeff_low the low pass coefficients - * @param coeff_high the high pass coefficients - * @param ncoeff_minus_one one less than the number of scaling coefficients - * @param x_out_low low pass results - * @param x_out_high high pass results - * - * For the convolution we will calculate the output of the lowpass and highpass filters in parallel - * - * Normally we can describe the calculation of a convolution as - * \f$ (\textbf{w} * \textbf{z})_k = \frac{1}{N} \sum\limits_{l=0}^{2N-1} w_{k-l} \cdot z_{l} \f$ - * - * Our actual implementation resembles this - * - */ -void dwt_convolution(double *x_in, size_t lx, double *coeff_low, double *coeff_high, int ncoeff_minus_one, double *x_out_low, double *x_out_high) { - size_t i, j, ind; - double x0, x1; - for (i=lx; i1) { - for (idx_columns=0; idx_columns +#include +#include +using namespace std; + +namespace { +template +class DWT +{ + private: + rwt_init_params p; + std::vector x_dummy; + std::vector y_dummy_low; // low pass results of convolution + std::vector y_dummy_high;// high pass results of convolution + std::vector coeff_low;///< the low pass coefficients - reversed h + std::vector coeff_high;///< the high pass coefficients - forward h, alternate values are sign flipped + + public: + DWT( const rwt_init_params & params) + :p(params), + x_dummy(std::max(p.nrows,p.ncols)+p.ncoeff-1), + y_dummy_low(std::max(p.nrows,p.ncols)), + y_dummy_high(std::max(p.nrows,p.ncols)), + coeff_low(p.ncoeff), + coeff_high(p.ncoeff) + { + dwt_coefficients( (const data*)p.scalings); + } + + /*! + * Perform convolution for dwt + * + * @param x_in input signal values + * @param lx the length of x_in + * + * For the convolution we will calculate the output of the lowpass and highpass filters in parallel + * + * Normally we can describe the calculation of a convolution as + * \f$ (\textbf{w} * \textbf{z})_k = \frac{1}{N} \sum\limits_{l=0}^{2N-1} w_{k-l} \cdot z_{l} \f$ + */ + void dwt_convolution(data *x_in, size_t lx) + { + int ncoeff_minus_one = p.ncoeff-1; + size_t i, j, ind; + data x0, x1; + for (i=lx; i1) { + for (size_t idx_columns=0; idx_columns -1; k--) { - x_in_low[k] = x_in_low[lx+k]; - x_in_high[k] = x_in_high[lx+k]; - } - - ind = 0; - for (i=0; i<(lx); i++) { - x0 = 0; - x1 = 0; - tj = 0; - for (j=0; j<=ncoeff_halved_minus_one; j++) { - x0 = x0 + (x_in_low[i+j] * coeff_low[ncoeff_minus_one-1-tj]) + (x_in_high[i+j] * coeff_high[ncoeff_minus_one-1-tj]); - x1 = x1 + (x_in_low[i+j] * coeff_low[ncoeff_minus_one-tj]) + (x_in_high[i+j] * coeff_high[ncoeff_minus_one-tj]); - tj += 2; - } - x_out[ind++] = x0; - x_out[ind++] = x1; - } -} - - -/*! - * Allocate memory for idwt - * - * @param m the number of rows of the input matrix - * @param n the number of columns of the input matrix - * @param ncoeff the number of scaling coefficients - * @param x_dummy - * @param y_dummy_low - * @param y_dummy_high - * @param coeff_low - * @param coeff_high - * - */ -void idwt_allocate(size_t m, size_t n, int ncoeff, double **x_dummy, double **y_dummy_low, double **y_dummy_high, double **coeff_low, double **coeff_high) { - *x_dummy = (double *) rwt_calloc(max(m,n), sizeof(double)); - *y_dummy_low = (double *) rwt_calloc(max(m,n)+ncoeff/2-1, sizeof(double)); - *y_dummy_high = (double *) rwt_calloc(max(m,n)+ncoeff/2-1, sizeof(double)); - *coeff_low = (double *) rwt_calloc(ncoeff, sizeof(double)); - *coeff_high = (double *) rwt_calloc(ncoeff, sizeof(double)); -} - - -/*! - * Free memory we allocated for idwt - * - * @param x_dummy - * @param y_dummy_low - * @param y_dummy_high - * @param coeff_low - * @param coeff_high - * - */ -void idwt_free(double **x_dummy, double **y_dummy_low, double **y_dummy_high, double **coeff_low, double **coeff_high) { - rwt_free(*x_dummy); - rwt_free(*y_dummy_low); - rwt_free(*y_dummy_high); - rwt_free(*coeff_low); - rwt_free(*coeff_high); -} - - -/*! - * Put the scaling coeffients into a form ready for use in the convolution function - * - * @param ncoeff length of h / the number of scaling coefficients - * @param h the wavelet scaling coefficients - * @param coeff_low same as h - * @param coeff_high reversed h, even values are sign flipped - * - */ -void idwt_coefficients(int ncoeff, double *h, double **coeff_low, double **coeff_high) { - int i; - for (i=0; i1) - current_rows = nrows/sample_f; - else - current_rows = 1; - current_cols = ncols/sample_f; - - for (i=0; i<(nrows*ncols); i++) - x[i] = y[i]; - - /* main loop */ - for (current_level=levels; current_level >= 1; current_level--) { - row_cursor = current_rows/2; - column_cursor = current_cols/2; - - /* go by columns in case of a 2D signal*/ - if (nrows>1) { - for (idx_cols=0; idx_cols +#include + +namespace { +template +class IDWT +{ + private: + rwt_init_params p; + std::vector x_dummy; + std::vector y_dummy_low; // low pass results of convolution + std::vector y_dummy_high;// high pass results of convolution + std::vector coeff_low;///< the low pass coefficients - reversed h + std::vector coeff_high;///< the high pass coefficients - forward h, alternate values are sign flipped + + public: + IDWT( const rwt_init_params & params) + :p(params), + x_dummy(std::max(p.nrows,p.ncols)), + y_dummy_low(std::max(p.nrows,p.ncols)+p.ncoeff/2-1), + y_dummy_high(std::max(p.nrows,p.ncols)+p.ncoeff/2-1), + coeff_low(p.ncoeff), + coeff_high(p.ncoeff) + { + idwt_coefficients((const data*)p.scalings ); + } + + /*! + * Put the scaling coeffients into a form ready for use in the convolution function + * @param h the wavelet scaling coefficients + */ + void idwt_coefficients(const data *h) + { + for (int i=0; i -1; k--) { + y_dummy_low[k] = y_dummy_low[lx+k]; + y_dummy_high[k] = y_dummy_high[lx+k]; + } + + ind = 0; + for (i=0; i<(lx); i++) { + x0 = 0; + x1 = 0; + tj = 0; + for (j=0; j<=ncoeff_halved_minus_one; j++) { + x0 = x0 + (y_dummy_low[i+j] * coeff_low[ncoeff_minus_one-1-tj]) + (y_dummy_high[i+j] * coeff_high[ncoeff_minus_one-1-tj]); + x1 = x1 + (y_dummy_low[i+j] * coeff_low[ncoeff_minus_one-tj]) + (y_dummy_high[i+j] * coeff_high[ncoeff_minus_one-tj]); + tj += 2; + } + x_dummy[ind++] = x0; + x_dummy[ind++] = x1; + } + } + + /*! + * Perform the inverse discrete wavelet transform + * + * @param x the output signal with the inverse wavelet transform applied + * @param y the input signal + */ + void process(data *x, const data *y) + { + for (int m=0;m1) + current_rows = p.nrows/sample_f; + else + current_rows = 1; + current_cols = p.ncols/sample_f; + + for (i=0; i<(p.nrows*p.ncols); i++) + x[i] = y[i]; + + /* main loop */ + for (current_level=p.levels; current_level >= 1; current_level--) { + row_cursor = current_rows/2; + column_cursor = current_cols/2; + + /* go by columns in case of a 2D signal*/ + if (p.nrows>1) { + for (idx_cols=0; idx_cols(*parms).process(x,y); +} + +void idwt_float(float *x, const float *y,const rwt_init_params * parms) +{ + IDWT(*parms).process(x,y); +} diff --git a/lib/src/init.c b/lib/src/init.c index fa20199..42ac2d1 100644 --- a/lib/src/init.c +++ b/lib/src/init.c @@ -26,8 +26,8 @@ int rwt_check_parameter_count(int nrhs, transform_t transform_type) { } } else { - if (nrhs > 3) { - rwt_errormsg("There are at most 3 input parameters allowed!"); + if (nrhs > 4) { + rwt_errormsg("There are at most 4 input parameters allowed!"); return 1; } if (nrhs < 2) { @@ -39,6 +39,18 @@ int rwt_check_parameter_count(int nrhs, transform_t transform_type) { } +int rwt_numel( const mxArray * mtx) +{ + mwSize ndims = mxGetNumberOfDimensions(mtx); + if (ndims==0) + return 0; + int i,d=1; + const mwSize * dims = mxGetDimensions(mtx); + for (i=0;i 1) { + if (MIN(nrows, ncols) > 1) { if ((nrows != mh) | (3 * ncols * levels != nh)) { return 0; } @@ -85,10 +97,10 @@ int rwt_find_levels(size_t m, size_t n) { L = (L >> 1); i++; } - if (min(m, n) == 1) - L = max(i, j); + if (MIN(m, n) == 1) + L = MAX(i, j); else - L = min(i, j); + L = MIN(i, j); if (L == 0) { rwt_errormsg("Maximum number of levels is zero; no decomposition can be performed!"); return -1; @@ -150,18 +162,21 @@ int rwt_check_levels(int levels, size_t rows, size_t cols) { */ rwt_init_params rwt_matlab_init(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[], transform_t transform_type) { rwt_init_params params; - int argNumL; + int i; /*! Check for correct # of input parameters */ if (rwt_check_parameter_count(nrhs, transform_type) != 0) return params; - /*! Check that we don't have more than two dimensions in the input since that is currently unsupported. */ - if (mxGetNumberOfDimensions(prhs[0]) > 2) { - rwt_errormsg("Matrix must have fewer than 3 dimensions!"); - return params; - } + /*! Get the number of rows and columns in the input matrix. */ - params.nrows = mxGetM(prhs[0]); - params.ncols = mxGetN(prhs[0]); + const mwSize * dims = mxGetDimensions(prhs[0]); + mwSize ndims = mxGetNumberOfDimensions(prhs[0]); + params.nrows = dims[0]; + params.ncols = dims[1]; + + /*allow multiple matrices to be transformed at once*/ + params.nmats = 1; + for (i=2;i= (argnumTransDims + 1) && rwt_numel( prhs[argnumTransDims] )!=0 ) { + transDims = (int)*mxGetPr(prhs[argnumTransDims]); + }else{ + transDims = MIN(params.nrows,params.ncols) > 1 ? 2:1; + /* legacy defaults is 2d if there are at least 2 dimensions */ + } + + if ( transDims < 2) { + /* 1D transform */ + if (params.nrows ==1) { + /* OK -- nrows==1,ncols>1 */ + }else if ( params.ncols == 1) { + params.ncols = params.nrows; + params.nrows = 1; + }else{ + /* both leading dimensions >1, push (via view) the second into the 3rd */ + params.nmats *= params.ncols; + params.ncols = params.nrows; + params.nrows = 1; + } + }else { + /*2D across first two dimensions */ + } + + if ( nrhs >= (argNumL + 1) && rwt_numel( prhs[argNumL] )!=0 ) { params.levels = (int) *mxGetPr(prhs[argNumL]); - else + }else{ params.levels = rwt_find_levels(params.nrows, params.ncols); + } if (rwt_check_levels(params.levels, params.nrows, params.ncols)) { - return params; + return params; } /*! Read the scaling coefficients, h, from the input and find their length, ncoeff. @@ -185,18 +227,21 @@ rwt_init_params rwt_matlab_init(int nlhs, mxArray *plhs[], int nrhs, const mxArr */ if (transform_type == INVERSE_REDUNDANT_DWT) { params.scalings = mxGetPr(prhs[2]); - params.ncoeff = max(mxGetM(prhs[2]), mxGetN(prhs[2])); + params.ncoeff = MAX(mxGetM(prhs[2]), mxGetN(prhs[2])); if (!rwt_check_yl_matches_yh(prhs, params.nrows, params.ncols, params.levels)) { rwt_errormsg("Dimensions of first two input matrices not consistent!"); return params; } } else { + if ( mxGetClassID(prhs[0]) != mxGetClassID(prhs[1]) ) + rwt_errormsg("x and h must have same type"); params.scalings = mxGetPr(prhs[1]); - params.ncoeff = max(mxGetM(prhs[1]), mxGetN(prhs[1])); + params.ncoeff = MAX(mxGetM(prhs[1]), mxGetN(prhs[1])); } /*! Create the first item in the output array as a double matrix with the same dimensions as the input. */ - plhs[0] = mxCreateDoubleMatrix(params.nrows, params.ncols, mxREAL); + + plhs[0] = mxCreateNumericArray( ndims,dims, mxGetClassID(prhs[0]), mxIsComplex(prhs[0]) ? mxCOMPLEX : mxREAL); return params; } #endif diff --git a/lib/src/irdwt.c b/lib/src/irdwt.c index ba9efa6..e54e06e 100644 --- a/lib/src/irdwt.c +++ b/lib/src/irdwt.c @@ -27,12 +27,12 @@ void irdwt_convolution(double *x_out, size_t lx, double *coeff_low, double *coef void irdwt_allocate(size_t m, size_t n, int ncoeff, double **x_high, double **x_dummy_low, double **x_dummy_high, double **y_dummy_low_low, double **y_dummy_low_high, double **y_dummy_high_low, double **y_dummy_high_high, double **coeff_low, double **coeff_high) { *x_high = (double *) rwt_calloc(m*n, sizeof(double)); - *x_dummy_low = (double *) rwt_calloc(max(m,n), sizeof(double)); - *x_dummy_high = (double *) rwt_calloc(max(m,n), sizeof(double)); - *y_dummy_low_low = (double *) rwt_calloc(max(m,n)+ncoeff-1, sizeof(double)); - *y_dummy_low_high = (double *) rwt_calloc(max(m,n)+ncoeff-1, sizeof(double)); - *y_dummy_high_low = (double *) rwt_calloc(max(m,n)+ncoeff-1, sizeof(double)); - *y_dummy_high_high = (double *) rwt_calloc(max(m,n)+ncoeff-1, sizeof(double)); + *x_dummy_low = (double *) rwt_calloc(MAX(m,n), sizeof(double)); + *x_dummy_high = (double *) rwt_calloc(MAX(m,n), sizeof(double)); + *y_dummy_low_low = (double *) rwt_calloc(MAX(m,n)+ncoeff-1, sizeof(double)); + *y_dummy_low_high = (double *) rwt_calloc(MAX(m,n)+ncoeff-1, sizeof(double)); + *y_dummy_high_low = (double *) rwt_calloc(MAX(m,n)+ncoeff-1, sizeof(double)); + *y_dummy_high_high = (double *) rwt_calloc(MAX(m,n)+ncoeff-1, sizeof(double)); *coeff_low = (double *) rwt_calloc(ncoeff, sizeof(double)); *coeff_high = (double *) rwt_calloc(ncoeff, sizeof(double)); } diff --git a/lib/src/rdwt.c b/lib/src/rdwt.c index d4531c2..0af235f 100644 --- a/lib/src/rdwt.c +++ b/lib/src/rdwt.c @@ -54,12 +54,12 @@ void rdwt_convolution(double *x_in, size_t lx, double *coeff_low, double *coeff_ */ void rdwt_allocate(size_t m, size_t n, int ncoeff, double **x_dummy_low, double **x_dummy_high, double **y_dummy_low_low, double **y_dummy_low_high, double **y_dummy_high_low, double **y_dummy_high_high, double **coeff_low, double **coeff_high) { - *x_dummy_low = (double *) rwt_calloc(max(m,n)+ncoeff-1, sizeof(double)); - *x_dummy_high = (double *) rwt_calloc(max(m,n)+ncoeff-1, sizeof(double)); - *y_dummy_low_low = (double *) rwt_calloc(max(m,n), sizeof(double)); - *y_dummy_low_high = (double *) rwt_calloc(max(m,n), sizeof(double)); - *y_dummy_high_low = (double *) rwt_calloc(max(m,n), sizeof(double)); - *y_dummy_high_high = (double *) rwt_calloc(max(m,n), sizeof(double)); + *x_dummy_low = (double *) rwt_calloc(MAX(m,n)+ncoeff-1, sizeof(double)); + *x_dummy_high = (double *) rwt_calloc(MAX(m,n)+ncoeff-1, sizeof(double)); + *y_dummy_low_low = (double *) rwt_calloc(MAX(m,n), sizeof(double)); + *y_dummy_low_high = (double *) rwt_calloc(MAX(m,n), sizeof(double)); + *y_dummy_high_low = (double *) rwt_calloc(MAX(m,n), sizeof(double)); + *y_dummy_high_high = (double *) rwt_calloc(MAX(m,n), sizeof(double)); *coeff_low = (double *) rwt_calloc(ncoeff, sizeof(double)); *coeff_high = (double *) rwt_calloc(ncoeff, sizeof(double)); } diff --git a/mex/mdwt.c b/mex/mdwt.c index c3051cb..7cacc4c 100644 --- a/mex/mdwt.c +++ b/mex/mdwt.c @@ -4,7 +4,7 @@ This file is used to produce a MATLAB MEX binary for the discrete wavelet transform %y = mdwt(x,h,L); -% +% % function computes the discrete wavelet transform y for a 1D or 2D input % signal x. % @@ -35,6 +35,16 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) { rwt_init_params params = rwt_matlab_init(nlhs, plhs, nrhs, prhs, NORMAL_DWT); /*! Check input and determine the parameters for dwt() */ plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL); /*! Create the output matrix */ *mxGetPr(plhs[1]) = params.levels; /*! The second returned item is the number of levels */ - dwt(mxGetPr(prhs[0]), params.nrows, params.ncols, params.scalings, params.ncoeff, params.levels, mxGetPr(plhs[0])); /*! Perform the DWT */ + if ( mxIsDouble(prhs[0]) ) { + dwt_double(mxGetPr(prhs[0]), mxGetPr(plhs[0]),¶ms); + if ( mxIsComplex(prhs[0]) ) + dwt_double(mxGetPi(prhs[0]), mxGetPi(plhs[0]),¶ms); + }else if (mxIsSingle(prhs[0] ) ) { + dwt_float((float*)mxGetData(prhs[0]), (float*)mxGetData(plhs[0]),¶ms); /*! Perform the DWT */ + if ( mxIsComplex(prhs[0]) ) + dwt_float((float*)mxGetImagData(prhs[0]), (float*)mxGetImagData(plhs[0]),¶ms); + }else{ + rwt_errormsg("unsupported data type"); + } } diff --git a/mex/midwt.c b/mex/midwt.c index a87083c..d830f56 100644 --- a/mex/midwt.c +++ b/mex/midwt.c @@ -32,12 +32,20 @@ * */ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) { - double *x, *y; rwt_init_params params = rwt_matlab_init(nlhs, plhs, nrhs, prhs, INVERSE_DWT); - y = mxGetPr(prhs[0]); - x = mxGetPr(plhs[0]); plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL); *mxGetPr(plhs[1]) = params.levels; - idwt(x, params.nrows, params.ncols, params.scalings, params.ncoeff, params.levels, y); + + if ( mxIsDouble(prhs[0]) ) { + idwt_double((double*)mxGetData(plhs[0]), (double*)mxGetData(prhs[0]), ¶ms); + if ( mxIsComplex(prhs[0]) ) + idwt_double((double*)mxGetImagData(plhs[0]), (double*)mxGetImagData(prhs[0]),¶ms); + }else if (mxIsSingle(prhs[0])){ + idwt_float((float*)mxGetData(plhs[0]), (float*)mxGetData(prhs[0]),¶ms); + if ( mxIsComplex(prhs[0]) ) + idwt_float((float*)mxGetImagData(plhs[0]), (float*)mxGetImagData(prhs[0]),¶ms); + }else{ + rwt_errormsg("unsupported data type"); + } } diff --git a/mex/mrdwt.c b/mex/mrdwt.c index b28dfc2..9c56ead 100644 --- a/mex/mrdwt.c +++ b/mex/mrdwt.c @@ -41,7 +41,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) { double *x, *yl, *yh; rwt_init_params params = rwt_matlab_init(nlhs, plhs, nrhs, prhs, REDUNDANT_DWT); - if (min(params.nrows, params.ncols) == 1) + if (MIN(params.nrows, params.ncols) == 1) plhs[1] = mxCreateDoubleMatrix(params.nrows, params.levels*params.ncols, mxREAL); else plhs[1] = mxCreateDoubleMatrix(params.nrows, 3*params.levels*params.ncols, mxREAL); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index b955f78..f41a5cf 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -22,7 +22,7 @@ SET(CMAKE_SWIG_FLAGS "") SET_SOURCE_FILES_PROPERTIES(rwt.i PROPERTIES CPLUSPLUS ON) #SET_SOURCE_FILES_PROPERTIES(rwt.i PROPERTIES SWIG_FLAGS "-includeall") -SWIG_ADD_MODULE(rwt python rwt.i ../lib/src/dwt.c ../lib/src/idwt.c ../lib/src/rdwt.c ../lib/src/irdwt.c ../lib/src/platform.c ../lib/src/init.c) +SWIG_ADD_MODULE(rwt python rwt.i ../lib/src/dwt.cc ../lib/src/idwt.cc ../lib/src/rdwt.c ../lib/src/irdwt.c ../lib/src/platform.c ../lib/src/init.c) SWIG_LINK_LIBRARIES(rwt ${PYTHON_LIBRARIES}) execute_process(COMMAND python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())" OUTPUT_VARIABLE PYTHON_SITE_PACKAGES OUTPUT_STRIP_TRAILING_WHITESPACE) diff --git a/tests/test_mdwt.m b/tests/test_mdwt.m index 77165e4..3e6b056 100644 --- a/tests/test_mdwt.m +++ b/tests/test_mdwt.m @@ -11,6 +11,27 @@ assertVectorsAlmostEqual(y, y_corr, 'relative', 0.001); assertEqual(L, L_corr); +function test_mdwt_1Dcpx + x = randn(8,2)*[1;1j]; + h = daubcqf(4, 'min'); + L = 2; % For 8 values in x we would normally be L=2 + [y, L] = mdwt(x, h, L); + yr = mdwt(real(x), h, L); + yi = mdwt(imag(x), h, L); + L_corr = 2; +assertVectorsAlmostEqual(y, yr+1j*yi, 'relative', 0.001); +assertEqual(L, L_corr); + +function test_mdwt_1Ds + x = single(makesig('LinChirp', 8)); + h = single(daubcqf(4, 'min')); + L = 2; % For 8 values in x we would normally be L=2 + [y, L] = mdwt(x, h, L); + y_corr = [1.1097 0.8767 0.8204 -0.5201 -0.0339 0.1001 0.2201 -0.1401]; + L_corr = 2; +assertVectorsAlmostEqual(y, y_corr, 'relative', 0.001); +assertEqual(L, L_corr); + function test_mdwt_2D x = [1 2 3 4; 5 6 7 8 ; 9 10 11 12; 13 14 15 16]; h = daubcqf(4); @@ -36,6 +57,31 @@ [y, L] = mdwt(x, h); assertEqual(L, 3); +function test_tensor_mdwt_1D + x = randn(8,3,2); + h = daubcqf(4); + y1 = mdwt(x, h,[],1); + y2 = nan(size(x)); + for i3=1:size(x,3) + for i4=1:size(x,4) + y2(:,:,i3,i4) = mdwt( squeeze(x(:,:,i3,i4)),h,[],1); + end + end +assertVectorsAlmostEqual(y1, y2, 'relative', 0.001); + +function test_tensor_mdwt_2D + x = randn(8,16,3,2); + h = daubcqf(4); + y1 = mdwt(x, h); + y2 = nan(size(x)); + for i3=1:size(x,3) + for i4=1:size(x,4) + y2(:,:,i3,i4) = mdwt( squeeze(x(:,:,i3,i4)),h); + end + end +assertVectorsAlmostEqual(y1, y2, 'relative', 0.001); + + function test_mdwt_compute_bad_L L = -1; x = [1 2 3 4 5 6 7 8 9]; diff --git a/tests/test_midwt.m b/tests/test_midwt.m index 745fef6..f5dd83c 100644 --- a/tests/test_midwt.m +++ b/tests/test_midwt.m @@ -2,7 +2,6 @@ initTestSuite; - function test_midwt_1D x = makesig('LinChirp',8); h = daubcqf(4,'min'); @@ -11,8 +10,25 @@ [x_new,L] = midwt(y,h,L); assertVectorsAlmostEqual(x, x_new,'relative',0.0001); +function test_midwt_1Dc + x = randn(8,2)*[1;1j]; + h = daubcqf(4,'min'); + L = 2; + [y,L] = midwt(x,h,L); + [yr,L] = midwt(real(x),h,L); + [yi,L] = midwt(imag(x),h,L); +assertVectorsAlmostEqual(y, yr+yi*1j,'relative',0.0001); + +function test_midwt_1Ds + x = single(makesig('LinChirp',8)); + h = single(daubcqf(4,'min')); + L = 2; + [y,L] = mdwt(x,h,L); + [x_new,L] = midwt(y,h,L); +assertVectorsAlmostEqual(x, x_new,'relative',0.0001); + function test_midwt_2D - load lena512; + load lena512; x = lena512; h = daubcqf(6); [y,L] = mdwt(x,h); @@ -20,4 +36,42 @@ assertEqual(L,9); assertVectorsAlmostEqual(x, x_new,'relative',0.0001); +function test_midwt_2Ds + load lena512; + x = lena512; + h = daubcqf(6); + [y,L] = mdwt(x,h); + [x_new,L] = midwt(single(y),single(h)); +assertEqual(L,9); +assertVectorsAlmostEqual(x, x_new,'relative',0.0001); + +function test_midwt_mat1D + x = randn(16,4); + h = daubcqf(4,'min'); + y1 = mdwt(x,h,[],1); + y2 = nan(size(y1)); + for i2=1:size(x,2) + y2(:,i2) = mdwt(squeeze(x(:,i2)),h,[],1); + end +assertVectorsAlmostEqual(y1, y2,'relative',0.0001); + +function test_midwt_tensor2D + x = randn(16,4,3,2); + h = daubcqf(4,'min'); + y1 = mdwt(x,h); % this should default to a 2d transform across the first 2 dimensions + y2 = nan(size(y1)); + for i3=1:size(x,3) + for i4=1:size(x,4) + y2(:,:,i3,i4) = mdwt(squeeze(x(:,:,i3,i4)),h,[],2); + end + end +assertVectorsAlmostEqual(y1, y2,'relative',0.0001); +function test_midwt_tensor1Dc + sz = [16,5,3,2]; + x = randn(sz) + 1j*randn(sz); + h = daubcqf(4,'min'); + L = 2; + y = mdwt(x,h,L,1); + [x2,L] = midwt(y,h,L,1); +assertVectorsAlmostEqual(x, x2,'relative',0.0001);