|
| 1 | +function [P orderstruct] = mmtimes(varargin) |
| 2 | +% P = mmtimes(M1, M2, ... Mn) |
| 3 | +% return a chain matrix product P = M1*M2* ... *Mn |
| 4 | +% |
| 5 | +% {Mi} are matrices with compatible dimension: size(Mi,2) = size(Mi+1,1) |
| 6 | +% |
| 7 | +% Because the matrix multiplication is associative; the chain product can |
| 8 | +% be carried out with different order, leading to the same result (up to |
| 9 | +% round-off error). MMTIMES uses "optimal" order of binary product to |
| 10 | +% reduce the computational effort (probably the accuracy is also improved). |
| 11 | +% |
| 12 | +% The function assumes the cost of the product of (m x n) with (n x p) |
| 13 | +% matrices is (m*n*p). This assumption is typically true for full matrix. |
| 14 | +% |
| 15 | +% Notes: |
| 16 | +% Scalar matrix are groupped together, and the rest will be |
| 17 | +% multiplied with optimal order. |
| 18 | +% |
| 19 | +% To get the the structure that stores the best order, call with the |
| 20 | +% second outputs: |
| 21 | +% >> [P orderstruct] = mmtimes(M1, M2, ... Mn); |
| 22 | +% % This structure can be used later if the input matrices have the |
| 23 | +% % same sizes as those in the first call (but with different contents) |
| 24 | +% >> P = mmtimes(M1, M2, ... Mn, orderstruct); |
| 25 | +% |
| 26 | +% See also: mtimes |
| 27 | +% |
| 28 | +% Author: Bruno Luong <[email protected]> |
| 29 | +% Orginal: 19-Jun-2010 |
| 30 | +% 20-Jun-2010: quicker top-down algorithm |
| 31 | +% 23-Jun-2010: treat the case of scalars |
| 32 | +% 16-Aug-2010: passing optimal order as output/input argument |
| 33 | + |
| 34 | +Matrices = varargin; |
| 35 | + |
| 36 | +buildexpr = false; |
| 37 | +if ~isempty(Matrices) && isstruct(Matrices{end}) |
| 38 | + orderstruct = Matrices{end}; |
| 39 | + Matrices(end) = []; |
| 40 | +else |
| 41 | + % Detect scalars |
| 42 | + iscst = cellfun('length',Matrices) == 1; |
| 43 | + if any(iscst) |
| 44 | + % scalars are multiplied apart |
| 45 | + cst = prod([Matrices{iscst}]); |
| 46 | + Matrices = Matrices(~iscst); |
| 47 | + else |
| 48 | + cst = 1; |
| 49 | + end |
| 50 | + % Size of matrices |
| 51 | + szmats = [cellfun('size',Matrices,1) size(Matrices{end},2)]; |
| 52 | + s = MatrixChainOrder(szmats); |
| 53 | + |
| 54 | + orderstruct = struct('cst', cst, ... |
| 55 | + 's', s, ... |
| 56 | + 'szmats', szmats); |
| 57 | + |
| 58 | + if nargout>=2 |
| 59 | + % Prepare to build the string expression |
| 60 | + vnames = arrayfun(@inputname, 1:nargin, 'UniformOutput', false); |
| 61 | + % Default names, e.g., M1, M2, ..., for inputs that is not single variable |
| 62 | + noname = cellfun('isempty', vnames); |
| 63 | + vnames(noname) = arrayfun(@(i) sprintf('M%d', i), find(noname), 'UniformOutput', false); |
| 64 | + if any(iscst) |
| 65 | + % String '(M1*M2*...)' for constants |
| 66 | + cstexpr = strcat(vnames(iscst),'*'); |
| 67 | + cstexpr = strcat(cstexpr{:}); |
| 68 | + cstexpr = ['(' cstexpr(1:end-1) ')']; |
| 69 | + else |
| 70 | + cstexpr = ''; |
| 71 | + end |
| 72 | + vnames = vnames(~iscst); |
| 73 | + buildexpr = true; |
| 74 | + end |
| 75 | +end |
| 76 | + |
| 77 | +if ~isempty(Matrices) |
| 78 | + P = ProdEngine(1,length(Matrices),orderstruct.s,Matrices); |
| 79 | + if orderstruct.cst~=1 |
| 80 | + P = orderstruct.cst*P; |
| 81 | + end |
| 82 | + if buildexpr |
| 83 | + expr = Prodexpr(1,length(Matrices),orderstruct.s,vnames); |
| 84 | + if ~isempty(cstexpr) |
| 85 | + % Concatenate the constant expression in front |
| 86 | + expr = [cstexpr '*' expr]; |
| 87 | + end |
| 88 | + orderstruct.expr = expr; |
| 89 | + end |
| 90 | +else |
| 91 | + P = orderstruct.cst; |
| 92 | + if nargout>=2 |
| 93 | + orderstruct.expr = cstexpr; |
| 94 | + end |
| 95 | +end |
| 96 | + |
| 97 | +end % mmtimes |
| 98 | + |
| 99 | + |
| 100 | +%% |
| 101 | +function [s qmin] = MatrixChainOrder(szmats) |
| 102 | +% Find the best ordered chain-product, the best splitting index |
| 103 | +% of M(i)*...*M(j) is stored in s(j,i) of the array s (only the lower |
| 104 | +% part is filled) |
| 105 | +% Top-down dynamic programming, complexity O(n^3) |
| 106 | + |
| 107 | +n = length(szmats)-1; |
| 108 | +s = zeros(n); |
| 109 | + |
| 110 | +pk = szmats(2:n); |
| 111 | +ij = (0:n-1)*(n+1)+1; |
| 112 | +left = zeros(1,n-1); |
| 113 | +right = zeros(1,n-1); |
| 114 | +L = 1; |
| 115 | +while true % off-diagonal offset |
| 116 | + q = zeros(size(pk)); |
| 117 | + for j=1:n-L % this is faster and BSXFUN or product with DIAGONAL matrix |
| 118 | + q(:,j) = (szmats(j)*szmats(j+L+1))*pk(:,j); |
| 119 | + end |
| 120 | + q = q + left + right; |
| 121 | + [qmin loc] = min(q, [], 1); |
| 122 | + s(ij(1:end-L)+L) = (1:n-L)+loc; |
| 123 | + |
| 124 | + if L<n-1 |
| 125 | + pk = [pk(:,1:end-1); |
| 126 | + pk(end,2:end)]; |
| 127 | + left = [left(:,1:end-1); |
| 128 | + qmin(1:end-1)]; |
| 129 | + right = [qmin(2:end); |
| 130 | + right(:,2:end)]; |
| 131 | + L = L+1; |
| 132 | + else |
| 133 | + break |
| 134 | + end % if |
| 135 | +end % while-loop |
| 136 | + |
| 137 | +end % MatrixChainOrder |
| 138 | + |
| 139 | +%% |
| 140 | +function P = ProdEngine(i,j,s,Matrices) |
| 141 | +% Perform matrix product from the optimal order, recursive engine |
| 142 | +if i==j |
| 143 | + P = Matrices{i}; |
| 144 | +else |
| 145 | + k = s(j,i); |
| 146 | + P = ProdEngine(i,k-1,s,Matrices)*ProdEngine(k,j,s,Matrices); |
| 147 | +end |
| 148 | + |
| 149 | +end |
| 150 | + |
| 151 | +%% |
| 152 | +function expr = Prodexpr(i,j,s,vnames) |
| 153 | +% Return the string expression of the optimal order |
| 154 | +if i==j |
| 155 | + expr = vnames{i}; |
| 156 | +else |
| 157 | + k = s(j,i); |
| 158 | + expr = ['(' Prodexpr(i,k-1,s,vnames) '*' Prodexpr(k,j,s,vnames) ')']; |
| 159 | +end |
| 160 | + |
| 161 | +end % Prodexpr |
| 162 | + |
| 163 | + |
0 commit comments