@@ -13,7 +13,7 @@ contains
13
13
! Internal use only!
14
14
pure function matmul_chain_order(p) result(s)
15
15
integer, intent(in) :: p(:)
16
- integer :: s(1:size(p) - 2, 2: size(p) - 1), m(1: size(p) - 1, 1: size(p) - 1)
16
+ integer :: s(1:size(p) - 2, 2:size(p) - 1), m(1:size(p) - 1, 1:size(p) - 1)
17
17
integer :: n, l, i, j, k, q
18
18
n = size(p) - 1
19
19
m(:,:) = 0
@@ -40,35 +40,102 @@ contains
40
40
41
41
pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s) result(r)
42
42
${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:)
43
- integer, intent(in) :: start, s(:,:)
43
+ integer, intent(in) :: start, s(:,2 :)
44
44
${t}$, allocatable :: r(:,:)
45
+ integer :: tmp
46
+ tmp = s(start, start + 2)
47
+
48
+ if (tmp == start) then
49
+ r = matmul(m1, matmul(m2, m3))
50
+ else if (tmp == start + 1) then
51
+ r = matmul(matmul(m1, m2), m3)
52
+ else
53
+ error stop "stdlib_matmul: error: unexpected s(i,j)"
54
+ end if
45
55
46
- select case (s(start, start + 2))
47
- case (1)
48
- r = matmul(m1, matmul(m2, m3))
49
- case (2)
50
- r = matmul(matmul(m1, m2), m3)
51
- case default
52
- error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
53
- end select
54
56
end function matmul_chain_mult_${s}$_3
55
57
56
58
pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r)
57
59
${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:)
58
- integer, intent(in) :: start, s(:,:)
60
+ integer, intent(in) :: start, s(:,2:)
61
+ ${t}$, allocatable :: r(:,:)
62
+ integer :: tmp
63
+ tmp = s(start, start + 3)
64
+
65
+ if (tmp == start) then
66
+ r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s))
67
+ else if (tmp == start + 1) then
68
+ r = matmul(matmul(m1, m2), matmul(m3, m4))
69
+ else if (tmp == start + 2) then
70
+ r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
71
+ else
72
+ error stop "stdlib_matmul: error: unexpected s(i,j)"
73
+ end if
74
+
75
+ end function matmul_chain_mult_${s}$_4
76
+
77
+ pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
78
+ ${t}$, intent(in) :: m1(:,:), m2(:,:)
79
+ ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
59
80
${t}$, allocatable :: r(:,:)
81
+ integer :: p(6), num_present
82
+ integer, allocatable :: s(:,:)
60
83
61
- select case (s(start, start + 3))
84
+ p(1) = size(m1, 1)
85
+ p(2) = size(m2, 1)
86
+ p(3) = size(m2, 2)
87
+
88
+ num_present = 2
89
+ if (present(m3)) then
90
+ p(3) = size(m3, 1)
91
+ p(4) = size(m3, 2)
92
+ num_present = num_present + 1
93
+ end if
94
+ if (present(m4)) then
95
+ p(4) = size(m4, 1)
96
+ p(5) = size(m4, 2)
97
+ num_present = num_present + 1
98
+ end if
99
+ if (present(m5)) then
100
+ p(5) = size(m5, 1)
101
+ p(6) = size(m5, 2)
102
+ num_present = num_present + 1
103
+ end if
104
+
105
+ if (num_present == 2) then
106
+ r = matmul(m1, m2)
107
+ return
108
+ end if
109
+
110
+ ! Now num_present >= 3
111
+ allocate(s(1:num_present - 1, 2:num_present))
112
+
113
+ s = matmul_chain_order(p(1: num_present + 1))
114
+
115
+ if (num_present == 3) then
116
+ r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s)
117
+ return
118
+ else if (num_present == 4) then
119
+ r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s)
120
+ return
121
+ end if
122
+
123
+ ! Now num_present is 5
124
+
125
+ select case (s(1, 5))
62
126
case (1)
63
- r = matmul(m1, matmul_chain_mult_${s}$_3 (m2, m3, m4, start + 1 , s))
127
+ r = matmul(m1, matmul_chain_mult_${s}$_4 (m2, m3, m4, m5, 2 , s))
64
128
case (2)
65
- r = matmul(matmul(m1, m2), matmul (m3, m4))
129
+ r = matmul(matmul(m1, m2), matmul_chain_mult_${s}$_3 (m3, m4, m5, 3, s ))
66
130
case (3)
67
- r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
131
+ r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s), matmul(m4, m5))
132
+ case (4)
133
+ r = matmul(matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s), m5)
68
134
case default
69
- error stop "stdlib_matmul: unexpected error unexpected s(i,j)"
135
+ error stop "stdlib_matmul: error: unexpected s(i,j)"
70
136
end select
71
- end function matmul_chain_mult_${s}$_4
137
+
138
+ end function stdlib_matmul_${s}$
72
139
73
140
#:endfor
74
141
end submodule stdlib_intrinsics_matmul
0 commit comments