Skip to content

Commit 399196a

Browse files
authored
Merge pull request #478 from grlee77/pad
ENH: update the pad utility function and make it part of the public API
2 parents b904642 + bcb7949 commit 399196a

File tree

7 files changed

+211
-68
lines changed

7 files changed

+211
-68
lines changed

doc/source/conf.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
extensions = ['sphinx.ext.doctest', 'sphinx.ext.autodoc', 'sphinx.ext.todo',
3737
'sphinx.ext.extlinks', 'sphinx.ext.mathjax',
3838
'sphinx.ext.autosummary', 'numpydoc',
39+
'sphinx.ext.intersphinx',
3940
'matplotlib.sphinxext.plot_directive']
4041

4142
# Add any paths that contain templates here, relative to this directory.
@@ -224,3 +225,9 @@
224225
plot_formats = [('png', 96), 'pdf']
225226
plot_html_show_formats = False
226227
plot_html_show_source_link = False
228+
229+
# -- Options for intersphinx extension ---------------------------------------
230+
231+
# Intersphinx to get Numpy and other targets
232+
intersphinx_mapping = {
233+
'numpy': ('https://docs.scipy.org/doc/numpy/', None)}

doc/source/pyplots/plot_boundary_modes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
In practice, which signal extension mode is beneficial will depend on the
77
signal characteristics. For this particular signal, some modes such as
8-
"periodic", "antisymmetric" and "zeros" result in large discontinuities that
8+
"periodic", "antisymmetric" and "zero" result in large discontinuities that
99
would lead to large amplitude boundary coefficients in the detail coefficients
1010
of a discrete wavelet transform.
1111
"""
@@ -28,5 +28,5 @@
2828
boundary_mode_subplot(x, 'periodization', axes[5], symw=False)
2929
boundary_mode_subplot(x, 'smooth', axes[6], symw=False)
3030
boundary_mode_subplot(x, 'constant', axes[7], symw=False)
31-
boundary_mode_subplot(x, 'zeros', axes[8], symw=False)
31+
boundary_mode_subplot(x, 'zero', axes[8], symw=False)
3232
plt.show()

doc/source/ref/signal-extension-modes.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,15 @@ periodization per N/A
136136
antisymmetric asym, asymh N/A
137137
antireflect asymw reflect, reflect_type='odd'
138138
================== ============= ===========================
139+
140+
Padding using PyWavelets Signal Extension Modes - ``pad``
141+
---------------------------------------------------------
142+
143+
.. autofunction:: pad
144+
145+
Pywavelets provides a function, :func:`pad`, that operate like
146+
:func:`numpy.pad`, but supporting the PyWavelets signal extension modes
147+
discussed above. For efficiency, the DWT routines in PyWavelets do not
148+
expclitly create padded signals using this function. It can be used to manually
149+
prepad signals to reduce boundary effects in functions such as :func:`cwt` and
150+
:func:`swt` that do not currently support all of these signal extension modes.

pywt/_doc_utils.py

Lines changed: 4 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import numpy as np
55
from matplotlib import pyplot as plt
66

7+
from ._dwt import pad
8+
79
__all__ = ['wavedec_keys', 'wavedec2_keys', 'draw_2d_wp_basis',
8-
'draw_2d_fswavedecn_basis', 'pad', 'boundary_mode_subplot']
10+
'draw_2d_fswavedecn_basis', 'boundary_mode_subplot']
911

1012

1113
def wavedec_keys(level):
@@ -149,63 +151,6 @@ def draw_2d_fswavedecn_basis(shape, levels, fmt='k', plot_kwargs={}, ax=None,
149151
return fig, ax
150152

151153

152-
def pad(x, pad_widths, mode):
153-
"""Extend a 1D signal using a given boundary mode.
154-
155-
Like numpy.pad but supports all PyWavelets boundary modes.
156-
"""
157-
if np.isscalar(pad_widths):
158-
pad_widths = (pad_widths, pad_widths)
159-
160-
if x.ndim > 1:
161-
raise ValueError("This padding function is only for 1D signals.")
162-
163-
if mode in ['symmetric', 'reflect']:
164-
xp = np.pad(x, pad_widths, mode=mode)
165-
elif mode in ['periodic', 'periodization']:
166-
if mode == 'periodization' and x.size % 2 == 1:
167-
raise ValueError("periodization expects an even length signal.")
168-
xp = np.pad(x, pad_widths, mode='wrap')
169-
elif mode == 'zeros':
170-
xp = np.pad(x, pad_widths, mode='constant', constant_values=0)
171-
elif mode == 'constant':
172-
xp = np.pad(x, pad_widths, mode='edge')
173-
elif mode == 'smooth':
174-
xp = np.pad(x, pad_widths, mode='linear_ramp',
175-
end_values=(x[0] + pad_widths[0] * (x[0] - x[1]),
176-
x[-1] + pad_widths[1] * (x[-1] - x[-2])))
177-
elif mode == 'antisymmetric':
178-
# implement by flipping portions symmetric padding
179-
npad_l, npad_r = pad_widths
180-
xp = np.pad(x, pad_widths, mode='symmetric')
181-
r_edge = npad_l + x.size - 1
182-
l_edge = npad_l
183-
# width of each reflected segment
184-
seg_width = x.size
185-
# flip reflected segments on the right of the original signal
186-
n = 1
187-
while r_edge <= xp.size:
188-
segment_slice = slice(r_edge + 1,
189-
min(r_edge + 1 + seg_width, xp.size))
190-
if n % 2:
191-
xp[segment_slice] *= -1
192-
r_edge += seg_width
193-
n += 1
194-
195-
# flip reflected segments on the left of the original signal
196-
n = 1
197-
while l_edge >= 0:
198-
segment_slice = slice(max(0, l_edge - seg_width), l_edge)
199-
if n % 2:
200-
xp[segment_slice] *= -1
201-
l_edge -= seg_width
202-
n += 1
203-
elif mode == 'antireflect':
204-
npad_l, npad_r = pad_widths
205-
xp = np.pad(x, pad_widths, mode='reflect', reflect_type='odd')
206-
return xp
207-
208-
209154
def boundary_mode_subplot(x, mode, ax, symw=True):
210155
"""Plot an illustration of the boundary mode in a subplot axis."""
211156

@@ -236,7 +181,7 @@ def boundary_mode_subplot(x, mode, ax, symw=True):
236181
left -= 0.5
237182
step = len(x)
238183
rng = range(-2, 4)
239-
if mode in ['smooth', 'constant', 'zeros']:
184+
if mode in ['smooth', 'constant', 'zero']:
240185
rng = range(0, 2)
241186
for rep in rng:
242187
ax.plot((left + rep * step) * o2, [xp.min() - .5, xp.max() + .5], 'k-')

pywt/_dwt.py

Lines changed: 117 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
__all__ = ["dwt", "idwt", "downcoef", "upcoef", "dwt_max_level",
15-
"dwt_coeff_len"]
15+
"dwt_coeff_len", "pad"]
1616

1717

1818
def dwt_max_level(data_len, filter_len):
@@ -135,7 +135,6 @@ def dwt(data, wavelet, mode='symmetric', axis=-1):
135135
Axis over which to compute the DWT. If not given, the
136136
last axis is used.
137137
138-
139138
Returns
140139
-------
141140
(cA, cD) : tuple
@@ -211,7 +210,6 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
211210
Axis over which to compute the inverse DWT. If not given, the
212211
last axis is used.
213212
214-
215213
Returns
216214
-------
217215
rec: array_like
@@ -401,3 +399,119 @@ def upcoef(part, coeffs, wavelet, level=1, take=0):
401399
if part not in 'ad':
402400
raise ValueError("Argument 1 must be 'a' or 'd', not '%s'." % part)
403401
return np.asarray(_upcoef(part == 'a', coeffs, wavelet, level, take))
402+
403+
404+
def pad(x, pad_widths, mode):
405+
"""Extend a 1D signal using a given boundary mode.
406+
407+
This function operates like :func:`numpy.pad` but supports all signal
408+
extension modes that can be used by PyWavelets discrete wavelet transforms.
409+
410+
Parameters
411+
----------
412+
x : ndarray
413+
The array to pad
414+
pad_widths : {sequence, array_like, int}
415+
Number of values padded to the edges of each axis.
416+
``((before_1, after_1), … (before_N, after_N))`` unique pad widths for
417+
each axis. ``((before, after),)`` yields same before and after pad for
418+
each axis. ``(pad,)`` or int is a shortcut for
419+
``before = after = pad width`` for all axes.
420+
mode : str, optional
421+
Signal extension mode, see :ref:`Modes <ref-modes>`.
422+
423+
Returns
424+
-------
425+
pad : ndarray
426+
Padded array of rank equal to array with shape increased according to
427+
``pad_widths``.
428+
429+
Notes
430+
-----
431+
The performance of padding in dimensions > 1 may be substantially slower
432+
for modes ``'smooth'`` and ``'antisymmetric'`` as these modes are not
433+
supported efficiently by the underlying :func:`numpy.pad` function.
434+
435+
Note that the behavior of the ``'constant'`` mode here follows the
436+
PyWavelets convention which is different from NumPy (it is equivalent to
437+
``mode='edge'`` in :func:`numpy.pad`).
438+
"""
439+
x = np.asanyarray(x)
440+
441+
# process pad_widths exactly as in numpy.pad
442+
pad_widths = np.array(pad_widths)
443+
pad_widths = np.round(pad_widths).astype(np.intp, copy=False)
444+
if pad_widths.min() < 0:
445+
raise ValueError("pad_widths must be > 0")
446+
pad_widths = np.broadcast_to(pad_widths, (x.ndim, 2)).tolist()
447+
448+
if mode in ['symmetric', 'reflect']:
449+
xp = np.pad(x, pad_widths, mode=mode)
450+
elif mode in ['periodic', 'periodization']:
451+
if mode == 'periodization':
452+
# Promote odd-sized dimensions to even length by duplicating the
453+
# last value.
454+
edge_pad_widths = [(0, x.shape[ax] % 2)
455+
for ax in range(x.ndim)]
456+
x = np.pad(x, edge_pad_widths, mode='edge')
457+
xp = np.pad(x, pad_widths, mode='wrap')
458+
elif mode == 'zero':
459+
xp = np.pad(x, pad_widths, mode='constant', constant_values=0)
460+
elif mode == 'constant':
461+
xp = np.pad(x, pad_widths, mode='edge')
462+
elif mode == 'smooth':
463+
def pad_smooth(vector, pad_width, iaxis, kwargs):
464+
# smooth extension to left
465+
left = vector[pad_width[0]]
466+
slope_left = (left - vector[pad_width[0] + 1])
467+
vector[:pad_width[0]] = \
468+
left + np.arange(pad_width[0], 0, -1) * slope_left
469+
470+
# smooth extension to right
471+
right = vector[-pad_width[1] - 1]
472+
slope_right = (right - vector[-pad_width[1] - 2])
473+
vector[-pad_width[1]:] = \
474+
right + np.arange(1, pad_width[1] + 1) * slope_right
475+
return vector
476+
xp = np.pad(x, pad_widths, pad_smooth)
477+
elif mode == 'antisymmetric':
478+
def pad_antisymmetric(vector, pad_width, iaxis, kwargs):
479+
# smooth extension to left
480+
# implement by flipping portions symmetric padding
481+
npad_l, npad_r = pad_width
482+
vsize_nonpad = vector.size - npad_l - npad_r
483+
# Note: must modify vector in-place
484+
vector[:] = np.pad(vector[pad_width[0]:-pad_width[-1]],
485+
pad_width, mode='symmetric')
486+
vp = vector
487+
r_edge = npad_l + vsize_nonpad - 1
488+
l_edge = npad_l
489+
# width of each reflected segment
490+
seg_width = vsize_nonpad
491+
# flip reflected segments on the right of the original signal
492+
n = 1
493+
while r_edge <= vp.size:
494+
segment_slice = slice(r_edge + 1,
495+
min(r_edge + 1 + seg_width, vp.size))
496+
if n % 2:
497+
vp[segment_slice] *= -1
498+
r_edge += seg_width
499+
n += 1
500+
501+
# flip reflected segments on the left of the original signal
502+
n = 1
503+
while l_edge >= 0:
504+
segment_slice = slice(max(0, l_edge - seg_width), l_edge)
505+
if n % 2:
506+
vp[segment_slice] *= -1
507+
l_edge -= seg_width
508+
n += 1
509+
return vector
510+
xp = np.pad(x, pad_widths, pad_antisymmetric)
511+
elif mode == 'antireflect':
512+
xp = np.pad(x, pad_widths, mode='reflect', reflect_type='odd')
513+
else:
514+
raise ValueError(
515+
("unsupported mode: {}. The supported modes are {}").format(
516+
mode, Modes.modes))
517+
return xp

pywt/_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# <https://github.com/PyWavelets/pywt>
33
# See COPYING for license details.
44
import inspect
5+
import numpy as np
56
import sys
67
from collections.abc import Iterable
78

@@ -17,7 +18,7 @@
1718

1819

1920
def _as_wavelet(wavelet):
20-
"""Convert wavelet name to a Wavelet object"""
21+
"""Convert wavelet name to a Wavelet object."""
2122
if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
2223
wavelet = DiscreteContinuousWavelet(wavelet)
2324
if isinstance(wavelet, ContinuousWavelet):

pywt/tests/test_dwt_idwt.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from __future__ import division, print_function, absolute_import
33

44
import numpy as np
5-
from numpy.testing import assert_allclose, assert_, assert_raises
6-
5+
from numpy.testing import (assert_allclose, assert_, assert_raises,
6+
assert_array_equal)
77
import pywt
88

99
# Check that float32, float64, complex64, complex128 are preserved.
@@ -228,8 +228,72 @@ def test_error_on_continuous_wavelet():
228228
def test_dwt_zero_size_axes():
229229
# raise on empty input array
230230
assert_raises(ValueError, pywt.dwt, [], 'db2')
231-
231+
232232
# >1D case uses a different code path so check there as well
233233
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
234234
assert_raises(ValueError, pywt.dwt, x, 'db2', axis=0)
235235

236+
237+
def test_pad_1d():
238+
x = [1, 2, 3]
239+
assert_array_equal(pywt.pad(x, (4, 6), 'periodization'),
240+
[1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2])
241+
assert_array_equal(pywt.pad(x, (4, 6), 'periodic'),
242+
[3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
243+
assert_array_equal(pywt.pad(x, (4, 6), 'constant'),
244+
[1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3])
245+
assert_array_equal(pywt.pad(x, (4, 6), 'zero'),
246+
[0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0])
247+
assert_array_equal(pywt.pad(x, (4, 6), 'smooth'),
248+
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
249+
assert_array_equal(pywt.pad(x, (4, 6), 'symmetric'),
250+
[3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, 2, 3])
251+
assert_array_equal(pywt.pad(x, (4, 6), 'antisymmetric'),
252+
[3, -3, -2, -1, 1, 2, 3, -3, -2, -1, 1, 2, 3])
253+
assert_array_equal(pywt.pad(x, (4, 6), 'reflect'),
254+
[1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1])
255+
assert_array_equal(pywt.pad(x, (4, 6), 'antireflect'),
256+
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
257+
258+
# equivalence of various pad_width formats
259+
assert_array_equal(pywt.pad(x, 4, 'periodic'),
260+
pywt.pad(x, (4, 4), 'periodic'))
261+
262+
assert_array_equal(pywt.pad(x, (4, ), 'periodic'),
263+
pywt.pad(x, (4, 4), 'periodic'))
264+
265+
assert_array_equal(pywt.pad(x, [(4, 4)], 'periodic'),
266+
pywt.pad(x, (4, 4), 'periodic'))
267+
268+
269+
def test_pad_errors():
270+
# negative pad width
271+
x = [1, 2, 3]
272+
assert_raises(ValueError, pywt.pad, x, -2, 'periodic')
273+
274+
# wrong length pad width
275+
assert_raises(ValueError, pywt.pad, x, (1, 1, 1), 'periodic')
276+
277+
# invalid mode name
278+
assert_raises(ValueError, pywt.pad, x, 2, 'bad_mode')
279+
280+
281+
def test_pad_nd():
282+
for ndim in [2, 3]:
283+
x = np.arange(4**ndim).reshape((4, ) * ndim)
284+
if ndim == 2:
285+
pad_widths = [(2, 1), (2, 3)]
286+
else:
287+
pad_widths = [(2, 1), ] * ndim
288+
for mode in pywt.Modes.modes:
289+
xp = pywt.pad(x, pad_widths, mode)
290+
291+
# expected result is the same as applying along axes separably
292+
xp_expected = x.copy()
293+
for ax in range(ndim):
294+
xp_expected = np.apply_along_axis(pywt.pad,
295+
ax,
296+
xp_expected,
297+
pad_widths=[pad_widths[ax]],
298+
mode=mode)
299+
assert_array_equal(xp, xp_expected)

0 commit comments

Comments
 (0)