#include <chrono>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <thread>

using namespace std;
8+
/// Prints an N x N row-major matrix to std::cout.
///
/// Every element is right-aligned in a field of width 6 and followed by a
/// single space. Each row ends with a newline, and one additional blank line
/// is written after the last row. The stream is flushed once at the end.
///
/// @param matrix  pointer to at least N * N doubles, stored row by row
/// @param N       number of rows and columns
void PrintMatrix(double* matrix, int N)
{
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
        {
            // setw() applies only to the next formatted insertion, so it is
            // set right before the element. Queuing another setw(6) after the
            // separator (as the previous version did) left a width of 6
            // pending on std::cout when the function returned, which padded
            // whatever the caller printed next.
            std::cout << std::setw(6) << matrix[i * N + j] << " ";
        }
        std::cout << '\n';
    }
    std::cout << std::endl;
}
23+
/// Fills two N x N row-major matrices with pseudo-random values.
///
/// Every element is (rand() % 101 - 50) / 10.0, i.e. a multiple of 0.1 in
/// the range [-5.0, 5.0]. For each position the value for matrix1 is drawn
/// first and the value for matrix2 second. The generator is not seeded here.
///
/// @param matrix1  first output matrix, room for N * N doubles
/// @param matrix2  second output matrix, room for N * N doubles
/// @param N        number of rows and columns
void GenerateRandomMatrix(double* matrix1, double* matrix2, int N)
{
    for (int row = 0; row < N; ++row)
    {
        double* const firstRow = matrix1 + row * N;
        double* const secondRow = matrix2 + row * N;
        for (int col = 0; col < N; ++col)
        {
            const int firstDraw = rand() % 101 - 50;
            firstRow[col] = firstDraw / 10.0;

            const int secondDraw = rand() % 101 - 50;
            secondRow[col] = secondDraw / 10.0;
        }
    }
}
37+
/// Reports whether two N x N row-major matrices are element-wise identical.
///
/// Elements are compared with exact floating-point inequality; no tolerance
/// is applied.
///
/// @param matrix1  first matrix, N * N doubles
/// @param matrix2  second matrix, N * N doubles
/// @param N        number of rows and columns
/// @return false as soon as a differing pair of elements is found,
///         true when none differs
bool matComparison(double* matrix1, double* matrix2, int N)
{
    for (int row = 0; row < N; ++row)
    {
        const int rowStart = row * N;
        for (int col = 0; col < N; ++col)
        {
            const int idx = rowStart + col;
            if (matrix1[idx] != matrix2[idx])
                return false;
        }
    }
    return true;
}
53+
/// Allocates storage for an N x N matrix of doubles.
///
/// The elements are not initialized. The returned array comes from new[],
/// so it has to be released with delete[].
///
/// @param N  number of rows and columns
/// @return pointer to N * N doubles
double* CreateMatrix(int N)
{
    return new double[N * N];
}
59+
60+ double * defaultMult (double * matrix1, double * matrix2, int N)
61+ {
62+ double * tmp = CreateMatrix (N);
63+ double sum;
64+ for (int i = 0 ; i < N; i++)
65+ {
66+ for (int j = 0 ; j < N; j++) {
67+ sum = 0 ;
68+ for (int k = 0 ; k < N; k++)
69+ {
70+ sum += matrix1[i * N + k] * matrix2[k * N + j];
71+ }
72+ tmp[i * N + j] = sum;
73+ }
74+ }
75+ return tmp;
76+ }
77+
78+ double * Add (double * matrix1, double * matrix2, int N)
79+ {
80+ double * tmp = CreateMatrix (N);
81+ for (int i = 0 ; i < N; i++)
82+ {
83+ for (int j = 0 ; j < N; j++)
84+ {
85+ tmp[i * N + j] = matrix1[i * N + j] + matrix2[i * N + j];
86+ }
87+ }
88+ return tmp;
89+ }
90+
91+ double * Add (double * matrix1, double * matrix2, double * matrix3, double * matrix4, int N)
92+ {
93+ double * tmp = CreateMatrix (N);
94+ for (int i = 0 ; i < N; i++)
95+ {
96+ for (int j = 0 ; j < N; j++)
97+ {
98+ tmp[i * N + j] = matrix1[i * N + j] + matrix2[i * N + j] + matrix3[i * N + j] + matrix4[i * N + j];
99+ }
100+ }
101+ return tmp;
102+ }
103+
104+ double * Sub (double * matrix1, double * matrix2, int N)
105+ {
106+ double * tmp = CreateMatrix (N);
107+ for (int i = 0 ; i < N; i++)
108+ {
109+ for (int j = 0 ; j < N; j++)
110+ {
111+ tmp[i * N + j] = matrix1[i * N + j] - matrix2[i * N + j];
112+ }
113+ }
114+ return tmp;
115+ }
116+
117+ double * Sub (double * matrix1, double * matrix2, double * matrix3, double * matrix4, int N)
118+ {
119+ double * tmp = CreateMatrix (N);
120+ for (int i = 0 ; i < N; i++)
121+ {
122+ for (int j = 0 ; j < N; j++)
123+ {
124+ tmp[i * N + j] = matrix1[i * N + j] + matrix2[i * N + j] + matrix3[i * N + j] - matrix4[i * N + j];
125+ }
126+ }
127+ return tmp;
128+ }
129+
130+ double * Str_alg (double * matrix1, double * matrix2, int N, int threshold);
131+
132+ double * Strassen_Threads (double * matrix1, double * matrix2, int N, int threshold)
133+ {
134+ double * Rez;
135+
136+ if (N <= threshold)
137+ Rez = defaultMult (matrix1, matrix2, N);
138+ else
139+ {
140+ Rez = CreateMatrix (N);
141+ N = N / 2 ;
142+
143+ double * A[4 ]; double * B[4 ]; double * C[4 ]; double * P[7 ];
144+
145+ double * TMP1; double * TMP2; double * TMP3; double * TMP4; double * TMP5;
146+ double * TMP6; double * TMP7; double * TMP8; double * TMP9; double * TMP10;
147+
148+ for (int i = 0 ; i < 4 ; i++)
149+ {
150+ A[i] = CreateMatrix (N);
151+ B[i] = CreateMatrix (N);
152+ }
153+
154+ for (int i = 0 ; i < N; i++)
155+ {
156+ for (int j = 0 ; j < N; j++)
157+ {
158+ A[0 ][i * N + j] = matrix1[2 * i * N + j];
159+ A[1 ][i * N + j] = matrix1[2 * i * N + j + N];
160+ A[2 ][i * N + j] = matrix1[2 * i * N + j + 2 * N * N];
161+ A[3 ][i * N + j] = matrix1[2 * i * N + j + 2 * N * N + N];
162+
163+ B[0 ][i * N + j] = matrix2[2 * i * N + j];
164+ B[1 ][i * N + j] = matrix2[2 * i * N + j + N];
165+ B[2 ][i * N + j] = matrix2[2 * i * N + j + 2 * N * N];
166+ B[3 ][i * N + j] = matrix2[2 * i * N + j + 2 * N * N + N];
167+ }
168+ }
169+
170+ thread* TMP_array = new thread[10 ];
171+ TMP_array[0 ] = thread ([&]() {TMP1 = Add (A[0 ], A[3 ], N); });
172+ TMP_array[1 ] = thread ([&]() {TMP2 = Add (B[0 ], B[3 ], N); });
173+ TMP_array[2 ] = thread ([&]() {TMP3 = Add (A[2 ], A[3 ], N); });
174+ TMP_array[3 ] = thread ([&]() {TMP4 = Sub (B[1 ], B[3 ], N); });
175+ TMP_array[4 ] = thread ([&]() {TMP5 = Sub (B[2 ], B[0 ], N); });
176+ TMP_array[5 ] = thread ([&]() {TMP6 = Add (A[0 ], A[1 ], N); });
177+ TMP_array[6 ] = thread ([&]() {TMP7 = Sub (A[2 ], A[0 ], N); });
178+ TMP_array[7 ] = thread ([&]() {TMP8 = Add (B[0 ], B[1 ], N); });
179+ TMP_array[8 ] = thread ([&]() {TMP9 = Sub (A[1 ], A[3 ], N); });
180+ TMP_array[9 ] = thread ([&]() {TMP10 = Add (B[2 ], B[3 ], N); });
181+
182+ for (size_t i = 0 ; i < 10 ; i++)
183+ TMP_array[i].join ();
184+
185+ thread* P_array = new thread[7 ];
186+ P_array[0 ] = thread ([&]() {P[0 ] = Str_alg (TMP1, TMP2, N, threshold); });
187+ P_array[1 ] = thread ([&]() {P[1 ] = Str_alg (TMP3, B[0 ], N, threshold); });
188+ P_array[2 ] = thread ([&]() {P[2 ] = Str_alg (A[0 ], TMP4, N, threshold); });
189+ P_array[3 ] = thread ([&]() {P[3 ] = Str_alg (A[3 ], TMP5, N, threshold); });
190+ P_array[4 ] = thread ([&]() {P[4 ] = Str_alg (TMP6, B[3 ], N, threshold); });
191+ P_array[5 ] = thread ([&]() {P[5 ] = Str_alg (TMP7, TMP8, N, threshold); });
192+ P_array[6 ] = thread ([&]() {P[6 ] = Str_alg (TMP9, TMP10, N, threshold); });
193+
194+ for (size_t i = 0 ; i < 7 ; i++)
195+ P_array[i].join ();
196+
197+ thread* C_array = new thread[4 ];
198+ C_array[0 ] = thread ([&]() {C[0 ] = Sub (P[0 ], P[3 ], P[6 ], P[4 ], N); });
199+ C_array[1 ] = thread ([&]() {C[1 ] = Add (P[2 ], P[4 ], N); });
200+ C_array[2 ] = thread ([&]() {C[2 ] = Add (P[1 ], P[3 ], N); });
201+ C_array[3 ] = thread ([&]() {C[3 ] = Sub (P[0 ], P[2 ], P[5 ], P[1 ], N); });
202+
203+ for (size_t i = 0 ; i < 4 ; i++)
204+ C_array[i].join ();
205+
206+ for (int i = 0 ; i < N; i++)
207+ {
208+ for (int j = 0 ; j < N; j++) {
209+ Rez[i * 2 * N + j] = C[0 ][i * N + j];
210+ Rez[i * 2 * N + j + N] = C[1 ][i * N + j];
211+ Rez[i * 2 * N + j + 2 * N * N] = C[2 ][i * N + j];
212+ Rez[i * 2 * N + j + 2 * N * N + N] = C[3 ][i * N + j];
213+ }
214+ }
215+
216+ for (int i = 0 ; i < 4 ; i++) {
217+ delete[] A[i];
218+ delete[] B[i];
219+ delete[] C[i];
220+ }
221+
222+ for (int i = 0 ; i < 7 ; i++) {
223+ delete[] P[i];
224+ }
225+
226+ delete[] TMP_array;
227+ delete[] P_array;
228+ delete[] C_array;
229+
230+ delete[] TMP1; delete[] TMP2; delete[] TMP3; delete[] TMP4; delete[] TMP5;
231+ delete[] TMP6; delete[] TMP7; delete[] TMP8; delete[] TMP9; delete[] TMP10;
232+ }
233+
234+ return Rez;
235+ }
236+
237+ int main ()
238+ {
239+ int degree;
240+ int Size;
241+
242+ std::cout << " Degree of matrix size: " ;
243+ std::cin >> degree;
244+ Size = (int )pow (2 , degree);
245+ std::cout << " Size of matrix: " << Size << std::endl<<std::endl;
246+
247+ double * matA = nullptr ;
248+ double * matB = nullptr ;
249+ double * matResStrassen = nullptr ;
250+ double * matResParallel = nullptr ;
251+
252+ matA = CreateMatrix (Size);
253+ matB = CreateMatrix (Size);
254+ matResStrassen = CreateMatrix (Size);
255+ matResParallel = CreateMatrix (Size);
256+
257+ GenerateRandomMatrix (matA, matB, Size);
258+
259+ auto start_Strassen = std::chrono::high_resolution_clock::now ();
260+ matResStrassen = Str_alg (matA, matB, Size, 64 );
261+ auto end_Strassen = std::chrono::high_resolution_clock::now ();
262+ std::chrono::duration<double > duration_Strassen = end_Strassen - start_Strassen;
263+
264+ auto start_Parallel = std::chrono::high_resolution_clock::now ();
265+ matResParallel = Strassen_Threads (matA, matB, Size, 64 );
266+ auto end_Parallel = std::chrono::high_resolution_clock::now ();
267+ std::chrono::duration<double > duration_par = end_Parallel - start_Parallel;
268+
269+ if (matComparison (matResStrassen, matResParallel, Size) != true ) { std::cout << " Mats are not equal" << std::endl << std::endl; }
270+ else { std::cout << " Mats are equal" << std::endl << std::endl; }
271+
272+ std::cout << " Strassen algorithm: " << duration_Strassen.count () << std::endl;
273+ std::cout << " Strassen parallel: " << duration_par.count () << std::endl;
274+
275+ delete[] matA;
276+ delete[] matB;
277+ delete[] matResStrassen;
278+ delete[] matResParallel;
279+ return 0 ;
280+ }
281+
282+ double * Str_alg (double * matrix1, double * matrix2, int N, int threshold)
283+ {
284+ double * Rez;
285+
286+ if (N <= threshold)
287+ Rez = defaultMult (matrix1, matrix2, N);
288+ else
289+ {
290+ Rez = CreateMatrix (N);
291+ N = N / 2 ;
292+
293+ double * A[4 ]; double * B[4 ]; double * C[4 ]; double * P[7 ];
294+
295+ double * TMP1; double * TMP2; double * TMP3; double * TMP4; double * TMP5;
296+ double * TMP6; double * TMP7; double * TMP8; double * TMP9; double * TMP10;
297+
298+ for (int i = 0 ; i < 4 ; i++)
299+ {
300+ A[i] = CreateMatrix (N);
301+ B[i] = CreateMatrix (N);
302+ }
303+
304+ for (int i = 0 ; i < N; i++)
305+ {
306+ for (int j = 0 ; j < N; j++)
307+ {
308+ A[0 ][i * N + j] = matrix1[2 * i * N + j];
309+ A[1 ][i * N + j] = matrix1[2 * i * N + j + N];
310+ A[2 ][i * N + j] = matrix1[2 * i * N + j + 2 * N * N];
311+ A[3 ][i * N + j] = matrix1[2 * i * N + j + 2 * N * N + N];
312+
313+ B[0 ][i * N + j] = matrix2[2 * i * N + j];
314+ B[1 ][i * N + j] = matrix2[2 * i * N + j + N];
315+ B[2 ][i * N + j] = matrix2[2 * i * N + j + 2 * N * N];
316+ B[3 ][i * N + j] = matrix2[2 * i * N + j + 2 * N * N + N];
317+ }
318+ }
319+
320+ TMP1 = Add (A[0 ], A[3 ], N);
321+ TMP2 = Add (B[0 ], B[3 ], N);
322+ TMP3 = Add (A[2 ], A[3 ], N);
323+ TMP4 = Sub (B[1 ], B[3 ], N);
324+ TMP5 = Sub (B[2 ], B[0 ], N);
325+ TMP6 = Add (A[0 ], A[1 ], N);
326+ TMP7 = Sub (A[2 ], A[0 ], N);
327+ TMP8 = Add (B[0 ], B[1 ], N);
328+ TMP9 = Sub (A[1 ], A[3 ], N);
329+ TMP10 = Add (B[2 ], B[3 ], N);
330+
331+ P[0 ] = Str_alg (TMP1, TMP2, N, threshold); // (A11 + A22)*(B11 + B22)
332+ P[1 ] = Str_alg (TMP3, B[0 ], N, threshold); // (A21 + A22)*B11
333+ P[2 ] = Str_alg (A[0 ], TMP4, N, threshold); // A11*(B12 - B22)
334+ P[3 ] = Str_alg (A[3 ], TMP5, N, threshold); // A22*(B21 - B11)
335+ P[4 ] = Str_alg (TMP6, B[3 ], N, threshold); // (A11 + A12)*B22
336+ P[5 ] = Str_alg (TMP7, TMP8, N, threshold); // (A21 - A11)*(B11 + B12)
337+ P[6 ] = Str_alg (TMP9, TMP10, N, threshold); // (A12 - A22)*(B21 + B22)
338+
339+ C[0 ] = Sub (P[0 ], P[3 ], P[6 ], P[4 ], N); // P1 + P4 - P5 + P7
340+ C[1 ] = Add (P[2 ], P[4 ], N); // P3 + P5
341+ C[2 ] = Add (P[1 ], P[3 ], N); // P2 + P4
342+ C[3 ] = Sub (P[0 ], P[2 ], P[5 ], P[1 ], N); // P1 - P2 + P3 + P6
343+
344+ for (int i = 0 ; i < N; i++)
345+ {
346+ for (int j = 0 ; j < N; j++) {
347+ Rez[i * 2 * N + j] = C[0 ][i * N + j];
348+ Rez[i * 2 * N + j + N] = C[1 ][i * N + j];
349+ Rez[i * 2 * N + j + 2 * N * N] = C[2 ][i * N + j];
350+ Rez[i * 2 * N + j + 2 * N * N + N] = C[3 ][i * N + j];
351+ }
352+ }
353+ for (int i = 0 ; i < 4 ; i++) {
354+ delete[] A[i];
355+ delete[] B[i];
356+ delete[] C[i];
357+ }
358+
359+ for (int i = 0 ; i < 7 ; i++) {
360+ delete[] P[i];
361+ }
362+
363+ delete[] TMP1; delete[] TMP2; delete[] TMP3; delete[] TMP4; delete[] TMP5;
364+ delete[] TMP6; delete[] TMP7; delete[] TMP8; delete[] TMP9; delete[] TMP10;
365+ }
366+
367+ return Rez;
368+ }
0 commit comments