Skip to content

Commit ebf92d7

Browse files
committed
add helper functions
1 parent 35a5a28 commit ebf92d7

File tree

1 file changed

+18
-79
lines changed

1 file changed

+18
-79
lines changed

src/stdlib_intrinsics_matmul.fypp

+18-79
Original file line numberDiff line numberDiff line change
@@ -37,99 +37,38 @@ contains
3737
end function matmul_chain_order
3838

3939
#:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES
40+
41+
pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s) result(r)
42+
${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:)
43+
integer, intent(in) :: start, s(:,:)
44+
${t}$, allocatable :: r(:,:)
4045

41-
pure module function stdlib_matmul_${s}$_3 (a, b, c) result(d)
42-
${t}$, intent(in) :: a(:,:), b(:,:), c(:,:)
43-
${t}$, allocatable :: d(:,:)
44-
integer :: sa(2), sb(2), sc(2), cost1, cost2
45-
sa = shape(a)
46-
sb = shape(b)
47-
sc = shape(c)
48-
49-
if ((sa(2) /= sb(1)) .or. (sb(2) /= sc(1))) then
50-
error stop "stdlib_matmul: Incompatible array shapes"
51-
end if
52-
53-
! computes the cost (number of scalar multiplications required)
54-
! cost(A, B) = shape(A)(1) * shape(A)(2) * shape(B)(2)
55-
cost1 = sa(1) * sa(2) * sb(2) + sa(1) * sb(2) * sc(2) ! ((AB)C)
56-
cost2 = sb(1) * sb(2) * sc(2) + sa(1) * sa(2) * sc(2) ! (A(BC))
57-
58-
if (cost1 < cost2) then
59-
d = matmul(matmul(a, b), c)
60-
else
61-
d = matmul(a, matmul(b, c))
62-
end if
63-
end function stdlib_matmul_${s}$_3
64-
65-
pure module function stdlib_matmul_${s}$_4 (a, b, c, d) result(e)
66-
${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:)
67-
${t}$, allocatable :: e(:,:)
68-
integer :: p(5), i
69-
integer :: s(3,2:4)
70-
71-
p(1) = size(a, 1)
72-
p(2) = size(b, 1)
73-
p(3) = size(c, 1)
74-
p(4) = size(d, 1)
75-
p(5) = size(d, 2)
76-
77-
s = matmul_chain_order(p)
78-
79-
select case (s(1,4))
46+
select case (s(start, start + 2))
8047
case (1)
81-
select case (s(2, 4))
82-
case (2)
83-
e = matmul(a, matmul(b, matmul(c, d)))
84-
case (3)
85-
e = matmul(a, matmul(matmul(b, c), d))
86-
case default
87-
error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
88-
end select
48+
r = matmul(m1, matmul(m2, m3))
8949
case (2)
90-
e = matmul(matmul(a, b), matmul(c, d))
91-
case (3)
92-
select case (s(1, 3))
93-
case (1)
94-
e = matmul(matmul(a, matmul(b, c)), d)
95-
case (2)
96-
e = matmul(matmul(matmul(a, b), c), d)
97-
case default
98-
error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
99-
end select
50+
r = matmul(matmul(m1, m2), m3)
10051
case default
10152
error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
10253
end select
103-
end function stdlib_matmul_${s}$_4
104-
105-
pure module function stdlib_matmul_${s}$_5 (a, b, c, d, e) result(f)
106-
${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:), e(:,:)
107-
${t}$, allocatable :: f(:,:)
108-
integer :: p(6), i
109-
integer :: s(4,2:5)
110-
111-
p(1) = size(a, 1)
112-
p(2) = size(b, 1)
113-
p(3) = size(c, 1)
114-
p(4) = size(d, 1)
115-
p(5) = size(e, 1)
116-
p(6) = size(e, 2)
54+
end function matmul_chain_mult_${s}$_3
11755

118-
s = matmul_chain_order(p)
56+
pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r)
57+
${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:)
58+
integer, intent(in) :: start, s(:,:)
59+
${t}$, allocatable :: r(:,:)
11960

120-
select case (s(1,5))
61+
select case (s(start, start + 3))
12162
case (1)
122-
f = matmul(a, stdlib_matmul(b, c, d, e))
63+
r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s))
12364
case (2)
124-
f = matmul(matmul(a, b), stdlib_matmul(c, d, e))
65+
r = matmul(matmul(m1, m2), matmul(m3, m4))
12566
case (3)
126-
f = matmul(stdlib_matmul(a, b ,c), matmul(d, e))
127-
case (4)
128-
f = matmul(stdlib_matmul(a, b, c, d), e)
67+
r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
12968
case default
13069
error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
13170
end select
132-
end function stdlib_matmul_${s}$_5
71+
end function matmul_chain_mult_${s}$_4
13372

13473
#:endfor
13574
end submodule stdlib_intrinsics_matmul

0 commit comments

Comments
 (0)