Skip to content

Commit dc691ac

Browse files
fixup
1 parent 2b7d7cf commit dc691ac

File tree

7 files changed

+84
-98
lines changed

7 files changed

+84
-98
lines changed

example/define_custom_local_operator.py

Lines changed: 24 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,32 @@
22
import numpy as np
33

44

5-
class CustomLocalOperator(Htool.VirtualLocalOperator):
5+
class CustomLocalOperator(Htool.LocalOperator):
66
def __init__(
77
self,
88
generator: Htool.VirtualGenerator,
9-
target_cluster: Htool.Cluster,
10-
global_source_cluster: Htool.Cluster,
11-
local_source_offset: int,
12-
local_source_size: int,
9+
target_local_renumbering: Htool.LocalRenumbering,
10+
source_local_renumbering: Htool.LocalRenumbering,
11+
target_use_permutation_to_mvprod: bool = False,
12+
source_use_permutation_to_mvprod: bool = False,
1313
) -> None:
1414
super().__init__(
15-
target_cluster,
16-
global_source_cluster,
17-
local_source_offset,
18-
local_source_size,
15+
target_local_renumbering,
16+
source_local_renumbering,
17+
target_use_permutation_to_mvprod,
18+
source_use_permutation_to_mvprod,
19+
)
20+
self.data = np.zeros(
21+
(target_local_renumbering.size, source_local_renumbering.size)
1922
)
20-
target_offset = target_cluster.get_offset()
21-
target_size = target_cluster.get_size()
22-
self.data = np.zeros((target_size, local_source_size))
2323
generator.build_submatrix(
24-
target_cluster.get_permutation()[
25-
target_offset : target_offset + target_size
24+
target_local_renumbering.permutation[
25+
target_local_renumbering.offset : target_local_renumbering.offset
26+
+ target_local_renumbering.size
2627
],
27-
global_source_cluster.get_permutation()[
28-
local_source_offset : local_source_offset + local_source_size
28+
source_local_renumbering.permutation[
29+
source_local_renumbering.offset : source_local_renumbering.offset
30+
+ source_local_renumbering.size
2931
],
3032
self.data,
3133
)
@@ -36,60 +38,20 @@ def add_vector_product(
3638
# Beware, inplace operation needed for output to keep the underlying data
3739
output *= beta
3840
if trans == "N":
39-
output += alpha * self.data.dot(
40-
input[
41-
self.local_source_offset : self.local_source_offset
42-
+ self.local_source_size
43-
]
44-
)
41+
output += alpha * self.data.dot(input)
4542
elif trans == "T":
46-
output += alpha * np.transpose(self.data).dot(
47-
input[
48-
self.local_source_offset : self.local_source_offset
49-
+ self.local_source_size
50-
]
51-
)
43+
output += alpha * np.transpose(self.data).dot(input)
5244
elif trans == "C":
53-
output += alpha * np.vdot(
54-
np.transpose(self.data),
55-
input[
56-
self.local_source_offset : self.local_source_offset
57-
+ self.local_source_size
58-
],
59-
)
45+
output += alpha * np.vdot(np.transpose(self.data), input)
6046

6147
def add_matrix_product_row_major(
6248
self, trans, alpha, input: np.array, beta, output: np.array
6349
) -> None:
6450
output *= beta
6551
if trans == "N":
66-
output += (
67-
alpha
68-
* self.data
69-
@ input[
70-
self.local_source_offset : self.local_source_offset
71-
+ self.local_source_size,
72-
:,
73-
]
74-
)
52+
output += alpha * self.data @ input
7553
elif trans == "T":
76-
output += (
77-
alpha
78-
* np.transpose(self.data)
79-
@ input[
80-
self.local_source_offset : self.local_source_offset
81-
+ self.local_source_size,
82-
:,
83-
]
84-
)
54+
output += alpha * np.transpose(self.data) @ input
8555
elif trans == "C":
86-
output += (
87-
alpha
88-
* np.matrix.H(self.data)
89-
@ input[
90-
self.local_source_offset : self.local_source_offset
91-
+ self.local_source_size,
92-
:,
93-
]
94-
)
56+
output += alpha * np.matrix.H(self.data) @ input
9557
output = np.asfortranarray(output)

example/use_local_hmatrix_compression.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@
7676
if local_source_cluster.get_offset() > 0:
7777
local_operator_1 = CustomLocalOperator(
7878
generator,
79-
local_target_cluster,
80-
source_cluster,
81-
0,
82-
local_source_cluster.get_offset(),
79+
Htool.LocalRenumbering(local_target_cluster),
80+
Htool.LocalRenumbering(
81+
0, local_source_cluster.get_offset(), source_cluster.get_permutation()
82+
),
8383
)
8484

8585
local_operator_2 = None
@@ -91,12 +91,14 @@
9191
):
9292
local_operator_2 = CustomLocalOperator(
9393
generator,
94-
local_target_cluster,
95-
source_cluster,
96-
local_source_cluster.get_size() + local_source_cluster.get_offset(),
97-
source_cluster.get_size()
98-
- local_source_cluster.get_size()
99-
- local_source_cluster.get_offset(),
94+
Htool.LocalRenumbering(local_target_cluster),
95+
Htool.LocalRenumbering(
96+
local_source_cluster.get_size() + local_source_cluster.get_offset(),
97+
source_cluster.get_size()
98+
- local_source_cluster.get_size()
99+
- local_source_cluster.get_offset(),
100+
source_cluster.get_permutation(),
101+
),
100102
)
101103

102104
if local_operator_1:

src/htool/local_operator/local_operator.hpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,25 @@
44
#include <htool/distributed_operator/implementations/local_operators/local_operator.hpp>
55
#include <pybind11/pybind11.h>
66

7-
template <typename CoefficientPrecision, typename CoordinatePrecision = CoefficientPrecision>
8-
class LocalOperatorPython : public htool::LocalOperator<CoefficientPrecision, CoordinatePrecision> {
7+
template <typename CoefficientPrecision>
8+
class LocalOperatorPython : public htool::LocalOperator<CoefficientPrecision> {
99
public:
10-
using htool::LocalOperator<CoefficientPrecision, CoordinatePrecision>::LocalOperator;
10+
using htool::LocalOperator<CoefficientPrecision>::LocalOperator;
1111

12-
LocalOperatorPython(const Cluster<CoordinatePrecision> &cluster_tree_target, const Cluster<CoordinatePrecision> &cluster_tree_source, bool target_use_permutation_to_mvprod = false, bool source_use_permutation_to_mvprod = false) : LocalOperator<CoefficientPrecision, CoordinatePrecision>(cluster_tree_target, cluster_tree_source, target_use_permutation_to_mvprod, source_use_permutation_to_mvprod) {}
12+
LocalOperatorPython(LocalRenumbering target_local_renumbering, LocalRenumbering source_local_renumbering, bool target_use_permutation_to_mvprod = false, bool source_use_permutation_to_mvprod = false) : LocalOperator<CoefficientPrecision>(target_local_renumbering, source_local_renumbering, target_use_permutation_to_mvprod, source_use_permutation_to_mvprod) {}
1313

1414
void local_add_vector_product(char trans, CoefficientPrecision alpha, const CoefficientPrecision *in, CoefficientPrecision beta, CoefficientPrecision *out) const override {
1515

16-
py::array_t<CoefficientPrecision> input(std::array<long int, 1>{this->m_source_cluster.get_size()}, in, py::capsule(in));
17-
py::array_t<CoefficientPrecision> output(std::array<long int, 1>{this->m_target_cluster.get_size()}, out, py::capsule(out));
16+
py::array_t<CoefficientPrecision> input(std::array<long int, 1>{this->m_local_source_renumbering.get_size()}, in, py::capsule(in));
17+
py::array_t<CoefficientPrecision> output(std::array<long int, 1>{this->m_local_target_renumbering.get_size()}, out, py::capsule(out));
1818

1919
add_vector_product(trans, alpha, input, beta, output);
2020
}
2121

2222
void local_add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const CoefficientPrecision *in, CoefficientPrecision beta, CoefficientPrecision *out, int mu) const override {
2323

24-
py::array_t<CoefficientPrecision, py::array::c_style> input(std::array<long int, 2>{this->m_source_cluster.get_size(), mu}, in, py::capsule(in));
25-
py::array_t<CoefficientPrecision, py::array::c_style> output(std::array<long int, 2>{this->m_target_cluster.get_size(), mu}, out, py::capsule(out));
24+
py::array_t<CoefficientPrecision, py::array::c_style> input(std::array<long int, 2>{this->m_local_source_renumbering.get_size(), mu}, in, py::capsule(in));
25+
py::array_t<CoefficientPrecision, py::array::c_style> output(std::array<long int, 2>{this->m_local_target_renumbering.get_size(), mu}, out, py::capsule(out));
2626

2727
add_matrix_product_row_major(trans, alpha, input, beta, output);
2828
}
@@ -34,10 +34,10 @@ class LocalOperatorPython : public htool::LocalOperator<CoefficientPrecision, Co
3434
virtual void add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision, py::array::c_style> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision, py::array::c_style> &out) const = 0; // LCOV_EXCL_LINE
3535
};
3636

37-
template <typename CoefficientPrecision, typename CoordinatePrecision>
38-
class PyLocalOperator : public LocalOperatorPython<CoefficientPrecision, CoordinatePrecision> {
37+
template <typename CoefficientPrecision>
38+
class PyLocalOperator : public LocalOperatorPython<CoefficientPrecision> {
3939
public:
40-
using LocalOperatorPython<CoefficientPrecision, CoordinatePrecision>::LocalOperatorPython;
40+
using LocalOperatorPython<CoefficientPrecision>::LocalOperatorPython;
4141

4242
/* Trampoline (need one for each virtual function) */
4343
virtual void add_vector_product(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision> &out) const override {
@@ -66,17 +66,17 @@ class PyLocalOperator : public LocalOperatorPython<CoefficientPrecision, Coordin
6666
}
6767
};
6868

69-
template <typename CoefficientPrecision, typename CoordinatePrecision>
69+
template <typename CoefficientPrecision>
7070
void declare_local_operator(py::module &m, const std::string &class_name) {
7171
using VirtualClass = htool::VirtualLocalOperator<CoefficientPrecision>;
7272
py::class_<VirtualClass>(m, ("Virtual" + class_name).c_str());
7373

74-
using BaseClass = LocalOperator<CoefficientPrecision, CoordinatePrecision>;
74+
using BaseClass = LocalOperator<CoefficientPrecision>;
7575
py::class_<BaseClass, VirtualClass> py_base_class(m, ("Base" + class_name).c_str());
7676

77-
using Class = LocalOperatorPython<CoefficientPrecision, CoordinatePrecision>;
78-
py::class_<Class, PyLocalOperator<CoefficientPrecision, CoordinatePrecision>, BaseClass> py_class(m, class_name.c_str());
79-
py_class.def(py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> &, bool, bool>());
77+
using Class = LocalOperatorPython<CoefficientPrecision>;
78+
py::class_<Class, PyLocalOperator<CoefficientPrecision>, BaseClass> py_class(m, class_name.c_str());
79+
py_class.def(py::init<LocalRenumbering, LocalRenumbering, bool, bool>());
8080
py_class.def("add_vector_product", &Class::add_vector_product, py::arg("trans"), py::arg("alpha"), py::arg("in").noconvert(true), py::arg("beta"), py::arg("out").noconvert(true));
8181
py_class.def("add_matrix_product_row_major", &Class::add_matrix_product_row_major);
8282
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef HTOOL_VIRTUAL_LOCAL_RENUMBERING_CPP
2+
#define HTOOL_VIRTUAL_LOCAL_RENUMBERING_CPP
3+
4+
#include <htool/distributed_operator/local_renumbering.hpp>
5+
#include <pybind11/pybind11.h>
6+
7+
template <typename CoordinatePrecision>
8+
void declare_local_renumbering(py::module &m, const std::string &className) {
9+
10+
using Class = LocalRenumbering;
11+
py::class_<Class> py_class(m, className.c_str());
12+
py_class.def(py::init([](int offset, int size, py::array_t<int> permutation) {
13+
return std::unique_ptr<Class>(new Class(offset, size, permutation.size(), permutation.data()));
14+
}));
15+
py_class.def(py::init<const Cluster<CoordinatePrecision> &>());
16+
py_class.def_property_readonly("offset", &Class::get_offset);
17+
py_class.def_property_readonly("size", &Class::get_size);
18+
py_class.def_property_readonly("permutation", [](const Class &self) { return py::array_t<int>(std::array<long int, 1>{self.get_global_size()}, self.get_permutation(), py::capsule(self.get_permutation())); });
19+
}
20+
21+
#endif

src/htool/main.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "hmatrix/lrmat.hpp"
1717

1818
#include "local_operator/local_operator.hpp"
19+
#include "local_operator/local_renumbering.hpp"
1920
#include "local_operator/virtual_local_operator.hpp"
2021

2122
#include "distributed_operator/distributed_operator.hpp"
@@ -58,7 +59,8 @@ PYBIND11_MODULE(Htool, m) {
5859
declare_custom_VirtualLowRankGenerator<double>(m, "VirtualLowRankGenerator");
5960
declare_custom_VirtualDenseBlocksGenerator<double>(m, "VirtualDenseBlocksGenerator");
6061

61-
declare_virtual_local_operator<double>(m, "VirtualLocalOperator", "ILocalOperator");
62+
declare_local_renumbering<double>(m, "LocalRenumbering");
63+
declare_local_operator<double>(m, "LocalOperator");
6264
// declare_local_operator<double, double>(m, "LocalOperator");
6365
declare_distributed_operator<double>(m, "DistributedOperator");
6466
declare_distributed_operator_utility<double, double>(m);

tests/conftest.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def cluster(geometry, symmetry):
151151
mpi4py.MPI.COMM_WORLD.size,
152152
partition=None,
153153
radii=None,
154-
weights=None
154+
weights=None,
155155
)
156156

157157
if target_partition is not None:
@@ -161,7 +161,7 @@ def cluster(geometry, symmetry):
161161
mpi4py.MPI.COMM_WORLD.size,
162162
partition=target_partition,
163163
radii=None,
164-
weights=None
164+
weights=None,
165165
)
166166
else:
167167
target_cluster: Htool.Cluster = cluster_builder.create_cluster_tree(
@@ -170,7 +170,7 @@ def cluster(geometry, symmetry):
170170
mpi4py.MPI.COMM_WORLD.size,
171171
partition=None,
172172
radii=None,
173-
weights=None
173+
weights=None,
174174
)
175175

176176
if symmetry == "S" or symmetry == "H":
@@ -205,11 +205,10 @@ def local_operator(request, generator, cluster, geometry):
205205
[target_cluster, source_cluster] = cluster
206206
return CustomLocalOperator(
207207
generator,
208-
target_cluster.get_cluster_on_partition(mpi4py.MPI.COMM_WORLD.rank),
209-
source_cluster,
210-
source_cluster.get_offset(),
211-
source_cluster.get_size(),
212-
208+
Htool.LocalRenumbering(
209+
target_cluster.get_cluster_on_partition(mpi4py.MPI.COMM_WORLD.rank)
210+
),
211+
Htool.LocalRenumbering(source_cluster),
213212
)
214213
else:
215214
return None

0 commit comments

Comments
 (0)