Skip to content

Commit a3d51b9

Browse files
authored
Merge pull request #37 from vAporInside/master
Added Strassen’s Algorithm: CPP
2 parents e29ba6d + 733eb48 commit a3d51b9

File tree

1 file changed

+280
-0
lines changed
  • Divide and Conquer/Strassen’s Matrix Multiplication Algorithm/CPP

1 file changed

+280
-0
lines changed
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
// CPP program to implement Strassen’s Matrix
2+
// Multiplication Algorithm
3+
#include <bits/stdc++.h>
4+
using namespace std;
5+
typedef long long lld;
6+
7+
/* Strassen's Algorithm for matrix multiplication
8+
Complexity: O(n^2.808) */
9+
10+
inline lld** MatrixMultiply(lld** a, lld** b, int n, int l, int m)
11+
{
12+
lld** c = new lld*[n];
13+
for (int i = 0; i < n; i++)
14+
c[i] = new lld[m];
15+
16+
for (int i = 0; i < n; i++) {
17+
for (int j = 0; j < m; j++) {
18+
c[i][j] = 0;
19+
for (int k = 0; k < l; k++) {
20+
c[i][j] += a[i][k] * b[k][j];
21+
}
22+
}
23+
}
24+
return c;
25+
}
26+
27+
inline lld** Strassen(lld** a, lld** b, int n, int l, int m)
28+
{
29+
if (n == 1 || l == 1 || m == 1)
30+
return MatrixMultiply(a, b, n, l, m);
31+
32+
lld** c = new lld*[n];
33+
for (int i = 0; i < n; i++)
34+
c[i] = new lld[m];
35+
36+
int adjN = (n >> 1) + (n & 1);
37+
int adjL = (l >> 1) + (l & 1);
38+
int adjM = (m >> 1) + (m & 1);
39+
40+
lld**** As = new lld***[2];
41+
for (int x = 0; x < 2; x++) {
42+
As[x] = new lld**[2];
43+
for (int y = 0; y < 2; y++) {
44+
As[x][y] = new lld*[adjN];
45+
for (int i = 0; i < adjN; i++) {
46+
As[x][y][i] = new lld[adjL];
47+
for (int j = 0; j < adjL; j++) {
48+
int I = i + (x & 1) * adjN;
49+
int J = j + (y & 1) * adjL;
50+
As[x][y][i][j] = (I < n && J < l) ? a[I][J] : 0;
51+
}
52+
}
53+
}
54+
}
55+
56+
lld**** Bs = new lld***[2];
57+
for (int x = 0; x < 2; x++) {
58+
Bs[x] = new lld**[2];
59+
for (int y = 0; y < 2; y++) {
60+
Bs[x][y] = new lld*[adjN];
61+
for (int i = 0; i < adjL; i++) {
62+
Bs[x][y][i] = new lld[adjM];
63+
for (int j = 0; j < adjM; j++) {
64+
int I = i + (x & 1) * adjL;
65+
int J = j + (y & 1) * adjM;
66+
Bs[x][y][i][j] = (I < l && J < m) ? b[I][J] : 0;
67+
}
68+
}
69+
}
70+
}
71+
72+
lld*** s = new lld**[10];
73+
for (int i = 0; i < 10; i++) {
74+
switch (i) {
75+
case 0:
76+
s[i] = new lld*[adjL];
77+
for (int j = 0; j < adjL; j++) {
78+
s[i][j] = new lld[adjM];
79+
for (int k = 0; k < adjM; k++) {
80+
s[i][j][k] = Bs[0][1][j][k] - Bs[1][1][j][k];
81+
}
82+
}
83+
break;
84+
case 1:
85+
s[i] = new lld*[adjN];
86+
for (int j = 0; j < adjN; j++) {
87+
s[i][j] = new lld[adjL];
88+
for (int k = 0; k < adjL; k++) {
89+
s[i][j][k] = As[0][0][j][k] + As[0][1][j][k];
90+
}
91+
}
92+
break;
93+
case 2:
94+
s[i] = new lld*[adjN];
95+
for (int j = 0; j < adjN; j++) {
96+
s[i][j] = new lld[adjL];
97+
for (int k = 0; k < adjL; k++) {
98+
s[i][j][k] = As[1][0][j][k] + As[1][1][j][k];
99+
}
100+
}
101+
break;
102+
case 3:
103+
s[i] = new lld*[adjL];
104+
for (int j = 0; j < adjL; j++) {
105+
s[i][j] = new lld[adjM];
106+
for (int k = 0; k < adjM; k++) {
107+
s[i][j][k] = Bs[1][0][j][k] - Bs[0][0][j][k];
108+
}
109+
}
110+
break;
111+
case 4:
112+
s[i] = new lld*[adjN];
113+
for (int j = 0; j < adjN; j++) {
114+
s[i][j] = new lld[adjL];
115+
for (int k = 0; k < adjL; k++) {
116+
s[i][j][k] = As[0][0][j][k] + As[1][1][j][k];
117+
}
118+
}
119+
break;
120+
case 5:
121+
s[i] = new lld*[adjL];
122+
for (int j = 0; j < adjL; j++) {
123+
s[i][j] = new lld[adjM];
124+
for (int k = 0; k < adjM; k++) {
125+
s[i][j][k] = Bs[0][0][j][k] + Bs[1][1][j][k];
126+
}
127+
}
128+
break;
129+
case 6:
130+
s[i] = new lld*[adjN];
131+
for (int j = 0; j < adjN; j++) {
132+
s[i][j] = new lld[adjL];
133+
for (int k = 0; k < adjL; k++) {
134+
s[i][j][k] = As[0][1][j][k] - As[1][1][j][k];
135+
}
136+
}
137+
break;
138+
case 7:
139+
s[i] = new lld*[adjL];
140+
for (int j = 0; j < adjL; j++) {
141+
s[i][j] = new lld[adjM];
142+
for (int k = 0; k < adjM; k++) {
143+
s[i][j][k] = Bs[1][0][j][k] + Bs[1][1][j][k];
144+
}
145+
}
146+
break;
147+
case 8:
148+
s[i] = new lld*[adjN];
149+
for (int j = 0; j < adjN; j++) {
150+
s[i][j] = new lld[adjL];
151+
for (int k = 0; k < adjL; k++) {
152+
s[i][j][k] = As[0][0][j][k] - As[1][0][j][k];
153+
}
154+
}
155+
break;
156+
case 9:
157+
s[i] = new lld*[adjL];
158+
for (int j = 0; j < adjL; j++) {
159+
s[i][j] = new lld[adjM];
160+
for (int k = 0; k < adjM; k++) {
161+
s[i][j][k] = Bs[0][0][j][k] + Bs[0][1][j][k];
162+
}
163+
}
164+
break;
165+
}
166+
}
167+
168+
lld*** p = new lld**[7];
169+
p[0] = Strassen(As[0][0], s[0], adjN, adjL, adjM);
170+
p[1] = Strassen(s[1], Bs[1][1], adjN, adjL, adjM);
171+
p[2] = Strassen(s[2], Bs[0][0], adjN, adjL, adjM);
172+
p[3] = Strassen(As[1][1], s[3], adjN, adjL, adjM);
173+
p[4] = Strassen(s[4], s[5], adjN, adjL, adjM);
174+
p[5] = Strassen(s[6], s[7], adjN, adjL, adjM);
175+
p[6] = Strassen(s[8], s[9], adjN, adjL, adjM);
176+
177+
for (int i = 0; i < adjN; i++) {
178+
for (int j = 0; j < adjM; j++) {
179+
c[i][j] = p[4][i][j] + p[3][i][j] - p[1][i][j] + p[5][i][j];
180+
if (j + adjM < m)
181+
c[i][j + adjM] = p[0][i][j] + p[1][i][j];
182+
if (i + adjN < n)
183+
c[i + adjN][j] = p[2][i][j] + p[3][i][j];
184+
if (i + adjN < n && j + adjM < m)
185+
c[i + adjN][j + adjM] = p[4][i][j] + p[0][i][j] - p[2][i][j] - p[6][i][j];
186+
}
187+
}
188+
189+
for (int x = 0; x < 2; x++) {
190+
for (int y = 0; y < 2; y++) {
191+
for (int i = 0; i < adjN; i++) {
192+
delete[] As[x][y][i];
193+
}
194+
delete[] As[x][y];
195+
}
196+
delete[] As[x];
197+
}
198+
delete[] As;
199+
200+
for (int x = 0; x < 2; x++) {
201+
for (int y = 0; y < 2; y++) {
202+
for (int i = 0; i < adjL; i++) {
203+
delete[] Bs[x][y][i];
204+
}
205+
delete[] Bs[x][y];
206+
}
207+
delete[] Bs[x];
208+
}
209+
delete[] Bs;
210+
211+
for (int i = 0; i < 10; i++) {
212+
switch (i) {
213+
case 0:
214+
case 3:
215+
case 5:
216+
case 7:
217+
case 9:
218+
for (int j = 0; j < adjL; j++) {
219+
delete[] s[i][j];
220+
}
221+
break;
222+
case 1:
223+
case 2:
224+
case 4:
225+
case 6:
226+
case 8:
227+
for (int j = 0; j < adjN; j++) {
228+
delete[] s[i][j];
229+
}
230+
break;
231+
}
232+
delete[] s[i];
233+
}
234+
delete[] s;
235+
236+
for (int i = 0; i < 7; i++) {
237+
for (int j = 0; j < (n >> 1); j++) {
238+
delete[] p[i][j];
239+
}
240+
delete[] p[i];
241+
}
242+
delete[] p;
243+
244+
return c;
245+
}
246+
247+
int main()
248+
{
249+
lld** matA;
250+
matA = new lld*[2];
251+
for (int i = 0; i < 2; i++)
252+
matA[i] = new lld[3];
253+
matA[0][0] = 1;
254+
matA[0][1] = 2;
255+
matA[0][2] = 3;
256+
matA[1][0] = 4;
257+
matA[1][1] = 5;
258+
matA[1][2] = 6;
259+
260+
lld** matB;
261+
matB = new lld*[3];
262+
for (int i = 0; i < 3; i++)
263+
matB[i] = new lld[2];
264+
matB[0][0] = 7;
265+
matB[0][1] = 8;
266+
matB[1][0] = 9;
267+
matB[1][1] = 10;
268+
matB[2][0] = 11;
269+
matB[2][1] = 12;
270+
271+
lld** matC = Strassen(matA, matB, 2, 3, 2);
272+
for (int i = 0; i < 2; i++) {
273+
for (int j = 0; j < 2; j++) {
274+
printf("%lld ", matC[i][j]);
275+
}
276+
printf("\n");
277+
}
278+
279+
return 0;
280+
}

0 commit comments

Comments
 (0)