Skip to content

Commit 5f5c5a9

Browse files
committed
add implementation, refactor select to if clauses
1 parent ebf92d7 commit 5f5c5a9

File tree

1 file changed

+84
-17
lines changed

1 file changed

+84
-17
lines changed

src/stdlib_intrinsics_matmul.fypp

+84-17
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ contains
1313
! Internal use only!
1414
pure function matmul_chain_order(p) result(s)
1515
integer, intent(in) :: p(:)
16-
integer :: s(1:size(p) - 2, 2: size(p) - 1), m(1: size(p) - 1, 1: size(p) - 1)
16+
integer :: s(1:size(p) - 2, 2:size(p) - 1), m(1:size(p) - 1, 1:size(p) - 1)
1717
integer :: n, l, i, j, k, q
1818
n = size(p) - 1
1919
m(:,:) = 0
@@ -40,35 +40,102 @@ contains
4040

4141
pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s) result(r)
4242
${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:)
43-
integer, intent(in) :: start, s(:,:)
43+
integer, intent(in) :: start, s(:,2:)
4444
${t}$, allocatable :: r(:,:)
45+
integer :: tmp
46+
tmp = s(start, start + 2)
47+
48+
if (tmp == start) then
49+
r = matmul(m1, matmul(m2, m3))
50+
else if (tmp == start + 1) then
51+
r = matmul(matmul(m1, m2), m3)
52+
else
53+
error stop "stdlib_matmul: error: unexpected s(i,j)"
54+
end if
4555

46-
select case (s(start, start + 2))
47-
case (1)
48-
r = matmul(m1, matmul(m2, m3))
49-
case (2)
50-
r = matmul(matmul(m1, m2), m3)
51-
case default
52-
error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
53-
end select
5456
end function matmul_chain_mult_${s}$_3
5557

5658
pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r)
5759
${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:)
58-
integer, intent(in) :: start, s(:,:)
60+
integer, intent(in) :: start, s(:,2:)
61+
${t}$, allocatable :: r(:,:)
62+
integer :: tmp
63+
tmp = s(start, start + 3)
64+
65+
if (tmp == start) then
66+
r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s))
67+
else if (tmp == start + 1) then
68+
r = matmul(matmul(m1, m2), matmul(m3, m4))
69+
else if (tmp == start + 2) then
70+
r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
71+
else
72+
error stop "stdlib_matmul: error: unexpected s(i,j)"
73+
end if
74+
75+
end function matmul_chain_mult_${s}$_4
76+
77+
pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
78+
${t}$, intent(in) :: m1(:,:), m2(:,:)
79+
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
5980
${t}$, allocatable :: r(:,:)
81+
integer :: p(6), num_present
82+
integer, allocatable :: s(:,:)
6083

61-
select case (s(start, start + 3))
84+
p(1) = size(m1, 1)
85+
p(2) = size(m2, 1)
86+
p(3) = size(m2, 2)
87+
88+
num_present = 2
89+
if (present(m3)) then
90+
p(3) = size(m3, 1)
91+
p(4) = size(m3, 2)
92+
num_present = num_present + 1
93+
end if
94+
if (present(m4)) then
95+
p(4) = size(m4, 1)
96+
p(5) = size(m4, 2)
97+
num_present = num_present + 1
98+
end if
99+
if (present(m5)) then
100+
p(5) = size(m5, 1)
101+
p(6) = size(m5, 2)
102+
num_present = num_present + 1
103+
end if
104+
105+
if (num_present == 2) then
106+
r = matmul(m1, m2)
107+
return
108+
end if
109+
110+
! Now num_present >= 3
111+
allocate(s(1:num_present - 1, 2:num_present))
112+
113+
s = matmul_chain_order(p(1: num_present + 1))
114+
115+
if (num_present == 3) then
116+
r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s)
117+
return
118+
else if (num_present == 4) then
119+
r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s)
120+
return
121+
end if
122+
123+
! Now num_present is 5
124+
125+
select case (s(1, 5))
62126
case (1)
63-
r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s))
127+
r = matmul(m1, matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s))
64128
case (2)
65-
r = matmul(matmul(m1, m2), matmul(m3, m4))
129+
r = matmul(matmul(m1, m2), matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s))
66130
case (3)
67-
r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
131+
r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s), matmul(m4, m5))
132+
case (4)
133+
r = matmul(matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s), m5)
68134
case default
69-
error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
135+
error stop "stdlib_matmul: error: unexpected s(i,j)"
70136
end select
71-
end function matmul_chain_mult_${s}$_4
137+
138+
end function stdlib_matmul_${s}$
72139

73140
#:endfor
74141
end submodule stdlib_intrinsics_matmul

0 commit comments

Comments
 (0)