diff --git a/pfft/core.pyx b/pfft/core.pyx index 7bb6262..8da50ee 100644 --- a/pfft/core.pyx +++ b/pfft/core.pyx @@ -376,26 +376,31 @@ cdef class ProcMesh(object): else: raise ValueError("only comm=MPI.COMM_WORLD is supported, " + " update mpi4py to 2.0, with MPI._addressof") + else: + raise TypeError("comm is not a MPI.Comm") + self.comm = comm self.rank = comm.rank - cdef int [::1] np_ = numpy.array(np, 'int32') + cdef int [::1] np_ + np_ = numpy.array(np, 'int32') + self.np = numpy.array(np_) + rt = pfft_create_procmesh(np_.shape[0], ccomm, &np_[0], &self.ccart) if rt != 0: self.ccart = NULL raise RuntimeError("Failed to create proc mesh") - self.np = numpy.array(np_) self.ndim = len(self.np) # a buffer used for various purposes - cdef int[::1] junk = numpy.empty(self.ndim, 'int32') + cdef int[::1] junk = numpy.empty_like(np_) # now fill `this' - self.this = numpy.array(np, 'int32') - cMPI.MPI_Cart_get(self.ccart, 2, + self.this = numpy.array(np_, 'int32') + cMPI.MPI_Cart_get(self.ccart, len(self.this), &junk[0], &junk[0], self.this.data); @@ -482,13 +487,14 @@ cdef class Partition(object): local_ni, local_no, local_i_start, local_o_start = numpy.empty((4, n_.shape[0]), 'intp') - if len(n_) <= len(procmesh.np): + if (len(n_) > 1 and len(n_) <= len(procmesh.np)) \ + or (len(n_) == 1 and len(procmesh.np) > 1): raise ValueError("ProcMesh (%d) shall have less dimentions than Mesh (%d)" % (len(procmesh.np), len(n_))) self.type = Type(type) self.flags = Flags(flags) cdef pfft_local_size_func func = PFFT_LOCAL_SIZE_FUNC[self.type] - + print(numpy.array(n_), procmesh.np, procmesh.ndim) rt = func(n_.shape[0], &n_[0], diff --git a/pfft/tests/test_pfft.py b/pfft/tests/test_pfft.py index 75c7a5f..20fbd74 100644 --- a/pfft/tests/test_pfft.py +++ b/pfft/tests/test_pfft.py @@ -8,6 +8,7 @@ from mpi4py_test import MPIWorld from mpi4py import MPI + def test_world(): world = MPI.COMM_WORLD @@ -19,6 +20,15 @@ def test_world(): assert_array_equal(pfft.ProcMesh.split(2, None), pfft.ProcMesh.split(2, world)) assert_array_equal(pfft.ProcMesh.split(1, None), pfft.ProcMesh.split(1, world)) +@MPIWorld(NTask=1) +def test_1d(comm): + procmesh = pfft.ProcMesh(np=[comm.size,], comm=comm) + partition = pfft.Partition(pfft.Type.PFFT_C2C, + [4], procmesh, + pfft.Flags.PFFT_TRANSPOSED_OUT) + + assert_array_equal(partition.i_edges[0], [0]) + @MPIWorld(NTask=3, required=3, optional=True) def test_edges(comm): procmesh = pfft.ProcMesh(np=[comm.size,], comm=comm) diff --git a/pfft/version.py b/pfft/version.py index 4274e06..0c5c300 100644 --- a/pfft/version.py +++ b/pfft/version.py @@ -1 +1 @@ -__version__ = "0.1.11dev0" +__version__ = "0.1.11"