From b62878de49f83ac29fb0b1a18026e20b0abcffd4 Mon Sep 17 00:00:00 2001 From: Chao Yin Date: Mon, 18 Mar 2024 15:51:58 -0500 Subject: [PATCH] Fix the marray bugs --- ltlt/Makefile | 12 ++++++-- ltlt/pivBlockedRight.cpp | 6 +--- ltlt/pivBlockedleft.cpp | 3 -- ltlt/pivUnblockedRight.cpp | 2 -- ltlt/pivUnblockedleft.cpp | 1 - ltlt/test.cpp | 5 ++-- marray/marray/blas.h | 5 ++++ marray/marray/flame.hpp | 57 +++++++++++++++++++++----------------- 8 files changed, 48 insertions(+), 43 deletions(-) diff --git a/ltlt/Makefile b/ltlt/Makefile index ddc722d..46b576f 100644 --- a/ltlt/Makefile +++ b/ltlt/Makefile @@ -15,12 +15,14 @@ endif LIBS=-lm -ldl exe=../bin/ltlt +testexe=../bin/test objs=$(patsubst %.cpp,%.o,$(wildcard *.cpp)) deps:=$(join $(addsuffix .deps/,$(dir $(objs))),$(notdir $(objs:.o=.d))) +objs:=$(filter-out ltlt.o test.o,$(objs)) .PHONY: all -all: $(exe) +all: $(exe) $(testexe) -include $(deps) @@ -29,9 +31,13 @@ all: $(exe) clean: rm -f $(objs) $(exe) -$(exe): $(objs) +$(exe): $(objs) ltlt.o @mkdir -p $(dir $(exe)) - $(CXX) $(LDFLAGS) -o $(exe) $(objs) $(LIBS) + $(CXX) $(LDFLAGS) -o $(exe) $(objs) ltlt.o $(LIBS) + +$(testexe): $(objs) test.o + @mkdir -p $(dir $(testexe)) + $(CXX) $(LDFLAGS) -o $(testexe) $(objs) test.o $(LIBS) %.o: %.cpp @mkdir -p $(dir $@).deps diff --git a/ltlt/pivBlockedRight.cpp b/ltlt/pivBlockedRight.cpp index 38f7d03..ad2a3b2 100644 --- a/ltlt/pivBlockedRight.cpp +++ b/ltlt/pivBlockedRight.cpp @@ -1,6 +1,4 @@ -#include "flame.hpp" -#include "fwd/marray_fwd.hpp" -#include "ltlt.hpp" +#include void ltlt_pivot_blockRL(const matrix_view& X, const row_view& pi, len_type block_size, const std::function&, const row_view&,len_type,bool)>& LTLT_UNB) { @@ -34,8 +32,6 @@ void ltlt_pivot_blockRL(const matrix_view& X, const row_view& pi, l LTLT_UNB(X[r1 | R2 | r3 | R4][r1 | R2 | r3 | R4], pi[R2 | r3], (r1 | R2 | r3 | R4).size() + 1, false); - pi[R2 | r3] = X[R2 | r3]; - pivot_rows(L[R2 | r3 | R4][R0 | r1], pi[R2 | r3]); diff --git a/ltlt/pivBlockedleft.cpp b/ltlt/pivBlockedleft.cpp index 382551a..aa97058 100644 --- a/ltlt/pivBlockedleft.cpp +++ b/ltlt/pivBlockedleft.cpp @@ -1,4 +1,3 @@ -#include "flame.hpp" #include "ltlt.hpp" void ltlt_pivot_blockLL(const matrix_view& X, const row_view& pi, len_type block_size, const std::function&,len_type,bool)>& LTLT_UNB) @@ -31,8 +30,6 @@ void ltlt_pivot_blockLL(const matrix_view& X, const row_view& pi, l LTLT_UNB(X[r1 | R2 | r3 | R4][r1 | R2 | r3 | R4], (r1 | R2).size(), true); - pi[R2 | r3] = X[R2 | r3]; - pivot_rows(L[R2 | r3 | R4][R0 | r1], pi[R2 | r3]); // ( R0 | r1 | R2 || r3 | R4 ) diff --git a/ltlt/pivUnblockedRight.cpp b/ltlt/pivUnblockedRight.cpp index e454a07..2562c8d 100644 --- a/ltlt/pivUnblockedRight.cpp +++ b/ltlt/pivUnblockedRight.cpp @@ -1,5 +1,3 @@ -#include "flame.hpp" -#include "fwd/marray_fwd.hpp" #include "ltlt.hpp" void ltlt_pivot_unblockRL(const matrix_view& X, const row_view& pi, len_type k, bool first_column, bool first_row) diff --git a/ltlt/pivUnblockedleft.cpp b/ltlt/pivUnblockedleft.cpp index 38fede7..0e5cd63 100644 --- a/ltlt/pivUnblockedleft.cpp +++ b/ltlt/pivUnblockedleft.cpp @@ -1,4 +1,3 @@ -#include "flame.hpp" #include "ltlt.hpp" void ltlt_pivot_unblockLL(const matrix_view& X, const row_view& pi, len_type k, bool first_column) diff --git a/ltlt/test.cpp b/ltlt/test.cpp index d1b0d2d..0a62b96 100644 --- a/ltlt/test.cpp +++ b/ltlt/test.cpp @@ -1,8 +1,7 @@ #include # include #include -#include "../marray/marray/flame.hpp" -#include "fwd/marray_fwd.hpp" +#include "test.hpp" #include using namespace MArray; @@ -63,4 +62,4 @@ int main() } } -} \ No newline at end of file +} diff --git a/marray/marray/blas.h b/marray/marray/blas.h index 1cc917e..d9c2261 100644 --- a/marray/marray/blas.h +++ b/marray/marray/blas.h @@ -732,6 +732,11 @@ using value_type = std::remove_cv_t::value_type>; namespace blas { +using std::conj; +inline float conj(float x) { return x; } +inline double conj(double x) { return x; } +inline long double conj(long double x) { return x; } + /****************************************************************************** * * Level 1 BLAS, C++ overloads diff --git a/marray/marray/flame.hpp b/marray/marray/flame.hpp index 9a729b9..c8efb24 100644 --- a/marray/marray/flame.hpp +++ b/marray/marray/flame.hpp @@ -7,7 +7,7 @@ #include #include -#include "bli_type_defs.h" +//#include "bli_type_defs.h" #include "fwd/marray_fwd.hpp" #include "marray_view.hpp" #include "expression.hpp" @@ -363,11 +363,11 @@ auto continue_with(const Args&... args) // Diagonal extraction template -auto diag(const MArray& A, len_type off=0) +auto diag(MArray&& A, len_type off=0) { MARRAY_ASSERT(A.dimension() == 2); - using T = typename MArray::value_type; + using T = typename std::decay_t::value_type; auto m = A.length(0); auto n = A.length(1); @@ -384,13 +384,13 @@ auto diag(const MArray& A, len_type off=0) } template -auto subdiag(const MArray& A) +auto subdiag(MArray&& A) { return diag(A, 1); } template -void pivot_rows(const MArray& A, len_type pi) +void pivot_rows(MArray&& A, len_type pi) { MARRAY_ASSERT(A.dimension() == 2); MARRAY_ASSERT(pi >= 0 && pi < A.length(0)); @@ -402,7 +402,7 @@ void pivot_rows(const MArray& A, len_type pi) } template -void pivot_columns(const MArray& A, len_type pi) +void pivot_columns(MArray&& A, len_type pi) { MARRAY_ASSERT(A.dimension() == 2); MARRAY_ASSERT(pi >= 0 && pi < A.length(1)); @@ -414,7 +414,7 @@ void pivot_columns(const MArray& A, len_type pi) } template -void pivot_both(const MArray& A, len_type pi, struc_t struc) +void pivot_both(MArray&& A, len_type pi, struc_t struc) { auto n = A.length(0); MARRAY_ASSERT(A.length(1) == n); @@ -429,11 +429,11 @@ void pivot_both(const MArray& A, len_type pi, struc_t struc) { case BLIS_GENERAL: pivot_rows(A, pi); - pivot_colums(A, pi); + pivot_columns(A, pi); break; case BLIS_SYMMETRIC: - blas::swap(A[tail][0], A[tail][pi]); + blas::swap(tail.size(), A[tail][0].data(),A.stride(0), A[tail][pi].data(), A.stride(0)); for (auto i : head) { @@ -450,24 +450,24 @@ void pivot_both(const MArray& A, len_type pi, struc_t struc) break; case BLIS_HERMITIAN: - blas::swap(A[tail][0], A[tail][pi]); + blas::swap(tail.size(), A[tail][0].data(),A.stride(0), A[tail][pi].data(), A.stride(0)); for (auto i : head) { auto Ai0 = A[i][0]; auto Apii = A[pi][i]; - A[i][0] = conj(Apii); - A[pi][i] = conj(Ai0); + A[i][0] = blas::conj(Apii); + A[pi][i] = blas::conj(Ai0); } std::swap(A[0][0], A[pi][pi]); - A[pi][0] = conj(A[pi][0]); + A[pi][0] = blas::conj(A[pi][0]); break; case BLIS_SKEW_SYMMETRIC: - blas::swap(A[tail][0], A[tail][pi]); + blas::swap(tail.size(), A[tail][0].data(),A.stride(0), A[tail][pi].data(), A.stride(0)); for (auto i : head) { @@ -484,19 +484,19 @@ void pivot_both(const MArray& A, len_type pi, struc_t struc) break; case BLIS_SKEW_HERMITIAN: - blas::swap(A[tail][0], A[tail][pi]); + blas::swap(tail.size(), A[tail][0].data(),A.stride(0), A[tail][pi].data(), A.stride(0)); for (auto i : head) { auto Ai0 = A[i][0]; auto Apii = A[pi][i]; - A[i][0] = -conj(Apii); - A[pi][i] = -conj(Ai0); + A[i][0] = -blas::conj(Apii); + A[pi][i] = -blas::conj(Ai0); } std::swap(A[0][0], A[pi][pi]); - A[pi][0] = -conj(A[pi][0]); + A[pi][0] = -blas::conj(A[pi][0]); break; } @@ -504,7 +504,7 @@ void pivot_both(const MArray& A, len_type pi, struc_t struc) template std::enable_if_t> -pivot_both(const MArray& A, const Pivot& p, struc_t struc) +pivot_both(MArray&& A, const Pivot& p, struc_t struc) { auto [T, B] = partition_rows(A); @@ -529,10 +529,12 @@ pivot_both(const MArray& A, const Pivot& p, struc_t struc) template std::enable_if_t> -pivot_rows(const MArray& A, const Pivot& p) +pivot_rows(MArray&& A_, const Pivot& p_) { - auto [T, B] = partition_rows(A); - + auto [T, B] = partition_rows(A_); + + auto A = A_.view(); + auto p = p_.view(); MARRAY_ASSERT(A.dimension() == 2); MARRAY_ASSERT(p.dimension() == 1); MARRAY_ASSERT(A.length(0) == p.length(0)); @@ -547,15 +549,18 @@ pivot_rows(const MArray& A, const Pivot& p) // ( R0 | r1 || R2 ) // ( T || B ) - tie(T, B) = continue_with(R0, r1, R2); + std::tie(T, B) = continue_with(R0, r1, R2); } } template std::enable_if_t> -pivot_columns(const MArray& A, const Pivot& p) +pivot_columns(MArray&& A_, const Pivot& p_) { - auto [T, B] = partition_columns(A); + auto [T, B] = partition_columns(A_); + + auto A = A_.view(); + auto p = p_.view(); MARRAY_ASSERT(A.dimension() == 2); MARRAY_ASSERT(p.dimension() == 1); @@ -571,7 +576,7 @@ pivot_columns(const MArray& A, const Pivot& p) // ( R0 | r1 || R2 ) // ( T || B ) - tie(T, B) = continue_with(R0, r1, R2); + std::tie(T, B) = continue_with(R0, r1, R2); } }