@@ -37,99 +37,38 @@ contains
37
37
end function matmul_chain_order
38
38
39
39
#:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES
40
+
41
+ pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s) result(r)
42
+ ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:)
43
+ integer, intent(in) :: start, s(:,:)
44
+ ${t}$, allocatable :: r(:,:)
40
45
41
- pure module function stdlib_matmul_${s}$_3 (a, b, c) result(d)
42
- ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:)
43
- ${t}$, allocatable :: d(:,:)
44
- integer :: sa(2), sb(2), sc(2), cost1, cost2
45
- sa = shape(a)
46
- sb = shape(b)
47
- sc = shape(c)
48
-
49
- if ((sa(2) /= sb(1)) .or. (sb(2) /= sc(1))) then
50
- error stop "stdlib_matmul: Incompatible array shapes"
51
- end if
52
-
53
- ! computes the cost (number of scalar multiplications required)
54
- ! cost(A, B) = shape(A)(1) * shape(A)(2) * shape(B)(2)
55
- cost1 = sa(1) * sa(2) * sb(2) + sa(1) * sb(2) * sc(2) ! ((AB)C)
56
- cost2 = sb(1) * sb(2) * sc(2) + sa(1) * sa(2) * sc(2) ! (A(BC))
57
-
58
- if (cost1 < cost2) then
59
- d = matmul(matmul(a, b), c)
60
- else
61
- d = matmul(a, matmul(b, c))
62
- end if
63
- end function stdlib_matmul_${s}$_3
64
-
65
- pure module function stdlib_matmul_${s}$_4 (a, b, c, d) result(e)
66
- ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:)
67
- ${t}$, allocatable :: e(:,:)
68
- integer :: p(5), i
69
- integer :: s(3,2:4)
70
-
71
- p(1) = size(a, 1)
72
- p(2) = size(b, 1)
73
- p(3) = size(c, 1)
74
- p(4) = size(d, 1)
75
- p(5) = size(d, 2)
76
-
77
- s = matmul_chain_order(p)
78
-
79
- select case (s(1,4))
46
+ select case (s(start, start + 2))
80
47
case (1)
81
- select case (s(2, 4))
82
- case (2)
83
- e = matmul(a, matmul(b, matmul(c, d)))
84
- case (3)
85
- e = matmul(a, matmul(matmul(b, c), d))
86
- case default
87
- error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
88
- end select
48
+ r = matmul(m1, matmul(m2, m3))
89
49
case (2)
90
- e = matmul(matmul(a, b), matmul(c, d))
91
- case (3)
92
- select case (s(1, 3))
93
- case (1)
94
- e = matmul(matmul(a, matmul(b, c)), d)
95
- case (2)
96
- e = matmul(matmul(matmul(a, b), c), d)
97
- case default
98
- error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
99
- end select
50
+ r = matmul(matmul(m1, m2), m3)
100
51
case default
101
52
error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
102
53
end select
103
- end function stdlib_matmul_${s}$_4
104
-
105
- pure module function stdlib_matmul_${s}$_5 (a, b, c, d, e) result(f)
106
- ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:), e(:,:)
107
- ${t}$, allocatable :: f(:,:)
108
- integer :: p(6), i
109
- integer :: s(4,2:5)
110
-
111
- p(1) = size(a, 1)
112
- p(2) = size(b, 1)
113
- p(3) = size(c, 1)
114
- p(4) = size(d, 1)
115
- p(5) = size(e, 1)
116
- p(6) = size(e, 2)
54
+ end function matmul_chain_mult_${s}$_3
117
55
118
- s = matmul_chain_order(p)
56
+ pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r)
57
+ ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:)
58
+ integer, intent(in) :: start, s(:,:)
59
+ ${t}$, allocatable :: r(:,:)
119
60
120
- select case (s(1,5 ))
61
+ select case (s(start, start + 3 ))
121
62
case (1)
122
- f = matmul(a, stdlib_matmul(b, c, d, e ))
63
+ r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s ))
123
64
case (2)
124
- f = matmul(matmul(a, b ), stdlib_matmul(c, d, e ))
65
+ r = matmul(matmul(m1, m2 ), matmul(m3, m4 ))
125
66
case (3)
126
- f = matmul(stdlib_matmul(a, b ,c), matmul(d, e))
127
- case (4)
128
- f = matmul(stdlib_matmul(a, b, c, d), e)
67
+ r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
129
68
case default
130
69
error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
131
70
end select
132
- end function stdlib_matmul_ ${s}$_5
71
+ end function matmul_chain_mult_ ${s}$_4
133
72
134
73
#:endfor
135
74
end submodule stdlib_intrinsics_matmul
0 commit comments