|
2 | 2 | from __future__ import division, print_function, absolute_import |
3 | 3 |
|
4 | 4 | import numpy as np |
5 | | -from numpy.testing import assert_allclose, assert_, assert_raises |
6 | | - |
| 5 | +from numpy.testing import (assert_allclose, assert_, assert_raises, |
| 6 | + assert_array_equal) |
7 | 7 | import pywt |
8 | 8 |
|
9 | 9 | # Check that float32, float64, complex64, complex128 are preserved. |
@@ -228,8 +228,72 @@ def test_error_on_continuous_wavelet(): |
228 | 228 | def test_dwt_zero_size_axes(): |
229 | 229 | # raise on empty input array |
230 | 230 | assert_raises(ValueError, pywt.dwt, [], 'db2') |
231 | | - |
| 231 | + |
232 | 232 | # >1D case uses a different code path so check there as well |
233 | 233 | x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis |
234 | 234 | assert_raises(ValueError, pywt.dwt, x, 'db2', axis=0) |
235 | 235 |
|
| 236 | + |
| 237 | +def test_pad_1d(): |
| 238 | + x = [1, 2, 3] |
| 239 | + assert_array_equal(pywt.pad(x, (4, 6), 'periodization'), |
| 240 | + [1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2]) |
| 241 | + assert_array_equal(pywt.pad(x, (4, 6), 'periodic'), |
| 242 | + [3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]) |
| 243 | + assert_array_equal(pywt.pad(x, (4, 6), 'constant'), |
| 244 | + [1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3]) |
| 245 | + assert_array_equal(pywt.pad(x, (4, 6), 'zero'), |
| 246 | + [0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0]) |
| 247 | + assert_array_equal(pywt.pad(x, (4, 6), 'smooth'), |
| 248 | + [-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) |
| 249 | + assert_array_equal(pywt.pad(x, (4, 6), 'symmetric'), |
| 250 | + [3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, 2, 3]) |
| 251 | + assert_array_equal(pywt.pad(x, (4, 6), 'antisymmetric'), |
| 252 | + [3, -3, -2, -1, 1, 2, 3, -3, -2, -1, 1, 2, 3]) |
| 253 | + assert_array_equal(pywt.pad(x, (4, 6), 'reflect'), |
| 254 | + [1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1]) |
| 255 | + assert_array_equal(pywt.pad(x, (4, 6), 'antireflect'), |
| 256 | + [-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) |
| 257 | + |
| 258 | + # equivalence of various pad_width formats |
| 259 | + assert_array_equal(pywt.pad(x, 4, 'periodic'), |
| 260 | + pywt.pad(x, (4, 4), 'periodic')) |
| 261 | + |
| 262 | + assert_array_equal(pywt.pad(x, (4, ), 'periodic'), |
| 263 | + pywt.pad(x, (4, 4), 'periodic')) |
| 264 | + |
| 265 | + assert_array_equal(pywt.pad(x, [(4, 4)], 'periodic'), |
| 266 | + pywt.pad(x, (4, 4), 'periodic')) |
| 267 | + |
| 268 | + |
| 269 | +def test_pad_errors(): |
| 270 | + # negative pad width |
| 271 | + x = [1, 2, 3] |
| 272 | + assert_raises(ValueError, pywt.pad, x, -2, 'periodic') |
| 273 | + |
| 274 | + # wrong length pad width |
| 275 | + assert_raises(ValueError, pywt.pad, x, (1, 1, 1), 'periodic') |
| 276 | + |
| 277 | + # invalid mode name |
| 278 | + assert_raises(ValueError, pywt.pad, x, 2, 'bad_mode') |
| 279 | + |
| 280 | + |
| 281 | +def test_pad_nd(): |
| 282 | + for ndim in [2, 3]: |
| 283 | + x = np.arange(4**ndim).reshape((4, ) * ndim) |
| 284 | + if ndim == 2: |
| 285 | + pad_widths = [(2, 1), (2, 3)] |
| 286 | + else: |
| 287 | + pad_widths = [(2, 1), ] * ndim |
| 288 | + for mode in pywt.Modes.modes: |
| 289 | + xp = pywt.pad(x, pad_widths, mode) |
| 290 | + |
| 291 | + # expected result is the same as applying along axes separably |
| 292 | + xp_expected = x.copy() |
| 293 | + for ax in range(ndim): |
| 294 | + xp_expected = np.apply_along_axis(pywt.pad, |
| 295 | + ax, |
| 296 | + xp_expected, |
| 297 | + pad_widths=[pad_widths[ax]], |
| 298 | + mode=mode) |
| 299 | + assert_array_equal(xp, xp_expected) |
0 commit comments