function varargout = spm_gmm(X, varargin)
%__________________________________________________________________________
%
% Fit a [Bayesian] Gaussian mixture model to observed [weighted] data.
%
% FORMAT [Z,MU,A,PI,...] = spm_gmm(X,...)
%
% MANDATORY
% ---------
% X - NxP matrix of observed values
%
% OPTIONAL
% --------
% K         - Number of clusters [0=guess from options, 2 if cannot guess]
% W - [N]x1 Vector of weights associated with each observation [1]
%
% KEYWORD
% -------
% PropPrior - 1x[K] vector of Dirichlet priors [0=ML]
% or NxK matrix of fixed observation-wise proportions.
% GaussPrior - {MU(PxK),b(1x[K]),V([PxP]x[K]),n(1x[K])} Gauss-Wishart prior
% [{}=ML]
% Prune - Threshold on proportions to prune uninformative clusters
% [0=no pruning]
% Missing - Infer missing data [true]
% Start - Starting method: METHOD or {METHOD, PRECISION} with
% METHOD = ['kmeans'],'linspace','prior','sample','uniform',
% or provided: MU(PxK) or {MU(PxK),b}
% PRECISION = [kmeans2gmm or diag(a) with a = (range/(2K))^(-2)],
% or provided: A([PxP]x[K]) or {V([PxP]x[K]),n}
% KMeans - Cell of KMeans options [{}].
% IterMax - Maximum number of EM iterations [1000]
% Tolerance - Convergence criterion (~ lower bound gain) [1e-4]
% BinWidth - 1x[P] Bin width (histogram mode: add bits of variance) [0]
% InputDim - Input space dimension [0=try to guess]
% Verbose - Verbosity level: [0]= quiet
% 1 = write (lower bound)
% 2 = plot (lower bound)
% 3 = plot more (gmm fit)
%
% OUTPUT
% ------
% Z  - NxK cluster responsibilities
% MU - PxK means [E[MU] if Bayesian]
% A - PxPxK precision matrices [E[A] if Bayesian]
% PI - 1xK proportions [E[PI] if Bayesian]
% b - 1xK posterior mean degrees of freedom [if Bayesian]
% V - PxPxK posterior scale matrix [if Bayesian]
% n - 1xK posterior precision degrees of freedom [if Bayesian]
% a - 1xK posterior Dirichlet [if Bayesian]
% X - NxP obs and inferred values [if infer]
%__________________________________________________________________________
%
% Use a learned mixture to segment an image.
%
% FORMAT [Z, X] = spm_gmm('apply',X,MU,A,PI,...) > Classical
% FORMAT [Z, X] = spm_gmm('apply',X,{MU,b},{V,n},a,...) > Bayesian
%
% KEYWORD
% -------
% Missing - Infer missing data: ['infer']/'remove'
% BinWidth - 1x[P] Bin width (histogram mode: add bits of variance) [0]
%__________________________________________________________________________
%
% help spm_gmm>Options
% help spm_gmm>TellMeMore
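%__________________________________________________________________________
%
% EXAMPLE
% -------
% A minimal usage sketch (simulated 2D data; values are illustrative):
%
%   X = [randn(500,2); bsxfun(@plus, randn(500,2), [4 4])];
%   [Z,MU,A,PI] = spm_gmm(X, 2, 'Verbose', 1);
%   [~,labels]  = max(Z, [], 2);   % hard cluster assignments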
%__________________________________________________________________________
% Copyright (C) 2018 Wellcome Centre for Human Neuroimaging
% TODO
% - Recheck the different Start options
% Convention:
% N: Number of observations
% P: Dimension of observation space
% K: Number of clusters
% -------------------------------------------------------------------------
% Special case: Apply model
% > Here, we use a learned GMM to segment an image
if ischar(X) && strcmpi(X, 'apply')
[varargout{1:nargout}] = gmm_apply(varargin{:});
return
end
% -------------------------------------------------------------------------
% Parse inputs
p = inputParser;
p.FunctionName = 'spm_gmm';
p.addRequired('X', @isnumeric);
p.addOptional('K', 0, @(X) isscalar(X) && isnumeric(X));
p.addOptional('W', 1, @isnumeric);
p.addParameter('PropPrior', 0, @isnumeric);
p.addParameter('GaussPrior', {}, @iscell);
p.addParameter('Prune', 0, @(X) (isscalar(X) && isnumeric(X)) || islogical(X));
p.addParameter('Missing', true, @islogical);
p.addParameter('Start', 'kmeans', @(X) ischar(X) || isnumeric(X));
p.addParameter('KMeans', {}, @iscell);
p.addParameter('IterMax', 1000, @(X) isscalar(X) && isnumeric(X));
p.addParameter('Tolerance', 1e-4, @(X) isscalar(X) && isnumeric(X));
p.addParameter('BinWidth', 0, @isnumeric);
p.addParameter('InputDim', 0, @(X) isscalar(X) && isnumeric(X));
p.addParameter('Verbose', 0, @(X) (numel(X) <= 2) && (isnumeric(X) || islogical(X)));
p.parse(X, varargin{:});
W = p.Results.W;
K = p.Results.K;
P = p.Results.InputDim;
E = p.Results.BinWidth;
Start = p.Results.Start;
KMeans = p.Results.KMeans;
PropPrior = p.Results.PropPrior;
GaussPrior = p.Results.GaussPrior;
Prune = p.Results.Prune;
Missing = p.Results.Missing;
% -------------------------------------------------------------------------
% A bit of formatting
if ~iscell(Start)
Start = {Start};
end
if ~iscell(GaussPrior)
GaussPrior = {GaussPrior};
end
if islogical(Prune)
Prune = 1e-7 * Prune;
end
% -------------------------------------------------------------------------
% Guess dimension/clusters from provided initial values
[K,P,Start] = dimFromStart(K,P,Start);
[K,P,GaussPrior] = dimFromGaussPrior(K,P,GaussPrior);
[K,PropPrior] = dimFromPropPrior(K,PropPrior);
[P,dimX] = dimFromObservations(P, X);
% ---
% default value
if K == 0, K = 2; end
% -------------------------------------------------------------------------
% Proportion / Dirichlet
if size(PropPrior, 1) > 1
% Class probability map (fixed proportions)
PI = reshape(PropPrior, [], K);
logPI = log(max(PI,eps));
a0 = zeros(1,K, 'like', PI);
else
% Dirichlet prior
a0 = PropPrior(:)';
if numel(a0) < K
a0 = spm_padarray(a0, [0 K - numel(a0)], 'replicate', 'post');
end
PI = [];
logPI = [];
end
% -------------------------------------------------------------------------
% Reshape X (observations)
if dimX(1) == 1 && numel(dimX) == 2
% row-vector case
latX = dimX;
else
latX = dimX(1:end-1);
end
X = reshape(X, [], P);
N = size(X, 1); % Number of observations
N0 = N; % Original number of observations
% -------------------------------------------------------------------------
% Reshape W (weights)
W = W(:);
% -------------------------------------------------------------------------
% "Bin" variance
% > When we work with histograms, a bit of variance is lost due to the
% binning. Here, we assume some kind of uniform distribution inside the bin
% and consequently add the corresponding variance to the 2nd order moment.
if numel(E) < P
E = spm_padarray(E, [0 P - numel(E)], 'replicate', 'post');
end
E = (E.^2)/12;
% -------------------------------------------------------------------------
% Initialise Gauss-Wishart prior
[MU0,b0,V0,n0] = initialise_prior(GaussPrior, K, P);
pr = struct('MU', MU0, 'b', b0, 'V', V0, 'n', n0);
% -------------------------------------------------------------------------
% Initialise mixture
[~, MU, b, A, V, n, PI00,logPI00] = start(Start, X, W, K, a0, pr, KMeans);
if isempty(PI)
PI = PI00;
logPI = logPI00;
end
clear PI00 logPI00
% -------------------------------------------------------------------------
% Prepare missing data stuff (code image, mask, ...)
if ~any(any(isnan(X)))
Missing = false;
end
if Missing
% Deal with missing data
[X,code_image,msk_obs] = spm_gmm_lib('obs2cell', X);
if size(W,1) > 1
W = spm_gmm_lib('obs2cell', W, code_image, false);
end
if size(PI,1) > 1
PI = spm_gmm_lib('obs2cell', PI, code_image, false);
end
if size(E,1) > 1
E = spm_gmm_lib('obs2cell', E, code_image, true);
end
missmsk = [];
else
% Discard rows with missing values
missmsk = any(isnan(X),2);
if any(missmsk)
X = X(~missmsk,:);
if size(W,1) > 1
W = W(~missmsk);
end
if size(PI,1) > 1
PI0 = PI(missmsk,:);
PI = PI(~missmsk,:);
end
N = sum(~missmsk);
end
missmsk = find(missmsk); % saves a bit of memory
msk_obs = logical([]);
end
% -------------------------------------------------------------------------
% Default prior mean/precision if needed
if isempty(MU0) && sum(b0) > 0
MU0 = MU;
end
if isempty(V0) && sum(n0) > 0
if sum(n) > 0
V0 = V;
else
V0 = bsxfun(@rdivide, A, reshape(n0, [1 1 K]));
end
end
% -------------------------------------------------------------------------
% Initialise posterior
if sum(b) == 0
b = b0;
end
if sum(n) == 0
n = n0;
if sum(n) > 0, V = bsxfun(@rdivide,A,reshape(n,[1 1 K])); end
end
if sum(b) > 0, mean = {MU, b};
else, mean = MU; end
if sum(n) > 0, prec = {V, n};
else, prec = {A}; end
% -------------------------------------------------------------------------
% MAIN LOOP
[Z,cluster,prop] = spm_gmm_lib('loop', X, W, {mean,prec}, ...
{'LogProp', logPI, 'Prop', PI}, ...
'GaussPrior', {MU0, b0, V0, n0}, ...
'PropPrior', a0, ...
'Missing', msk_obs, ...
'IterMax', p.Results.IterMax, ...
'Tolerance', p.Results.Tolerance, ...
'SubIterMax', p.Results.IterMax, ...
'SubTolerance', p.Results.Tolerance, ...
'ObsUncertainty', E, ...
'Verbose', p.Results.Verbose);
MU = cluster.MU;
b = cluster.b;
A = cluster.A;
V = cluster.V;
n = cluster.n;
PI = prop.Prop;
a = prop.Dir;
% -------------------------------------------------------------------------
% -------------------------------------------------------------------------
% Cell 2 Matrix
if Missing
Z = spm_gmm_lib('cell2obs', Z, code_image, msk_obs);
if size(PI,1) > 1 && nargout >= 4
PI = spm_gmm_lib('cell2obs', PI, code_image, msk_obs);
end
if nargout >= 9
X = spm_gmm_lib('cell2obs', X, code_image, msk_obs);
end
end
% -------------------------------------------------------------------------
% Infer missing values
if nargout >= 9 && Missing
X = spm_gmm_lib('InferMissing', X, Z, {MU,A}, code_image);
end
% -------------------------------------------------------------------------
% Prune clusters
if Prune > 0
[K,PI,Z,MU,b,A,V,n] = prune(Prune, PI, Z, {MU,b}, {A,V,n});
end
% -------------------------------------------------------------------------
% Replace discarded missing values
if ~isempty(missmsk)
present = true(N0, 1);
present(missmsk) = false;
clear missmsk
Z = expand(Z, present, N0, K, 0);
if size(PI,1) > 1 && nargout >= 4
PI = expand(PI, present, N0, K, PI0);
end
if nargout >= 9
X = expand(X, present, N0, P, NaN);
end
end
% -------------------------------------------------------------------------
% Reshape everything (use input lattice)
Z = reshape(Z, [latX K]);
if size(PI,1) > 1 && nargout >= 4
PI = reshape(PI, [latX K]);
end
if nargout >= 9
X = reshape(X, [latX P]);
end
% -------------------------------------------------------------------------
% Push results in output object
if nargout >= 1, varargout{1} = Z;
if nargout >= 2, varargout{2} = MU;
if nargout >= 3, varargout{3} = A;
if nargout >= 4, varargout{4} = PI;
if nargout >= 5, varargout{5} = b;
if nargout >= 6, varargout{6} = V;
if nargout >= 7, varargout{7} = n;
if nargout >= 8, varargout{8} = a;
if nargout >= 9, varargout{9} = X;
end;end;end;end;end;end;end;end;end
% =========================================================================
function [Z,X] = gmm_apply(X, mean, prec, prop, varargin)
% FORMAT [Z,X] = gmm_apply(X, MU, A, PI, ...)
% FORMAT [Z,X] = gmm_apply(X, {MU,b}, {V,n}, a, ...)
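%
% For instance (a sketch; MU, A and PI come from a previous spm_gmm fit,
% and Xnew is a new NxP set of observations):
%   [Z,Xnew] = spm_gmm('apply', Xnew, MU, A, PI);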
% -------------------------------------------------------------------------
% Parse inputs
p = inputParser;
p.FunctionName = 'gmm_apply';
p.addRequired('X', @isnumeric);
p.addRequired('mean', @(X) isnumeric(X) || iscell(X));
p.addRequired('precision', @(X) isnumeric(X) || iscell(X));
p.addRequired('prop', @(X) isnumeric(X) || iscell(X));
p.addParameter('Missing', true, @islogical);
p.addParameter('BinWidth', 0, @isnumeric);
p.parse(X, mean, prec, prop, varargin{:});
E = p.Results.BinWidth;
Missing = p.Results.Missing;
MU = [];
b = [];
A = [];
V = [];
n = [];
PI = [];
a = [];
logPI = [];
% -------------------------------------------------------------------------
% Read arguments
if ~iscell(mean)
MU = mean;
elseif ~isempty(mean)
MU = mean{1};
if numel(mean) > 1
b = mean{2};
end
end
if ~iscell(prec)
A = prec;
elseif ~isempty(prec)
A = prec{1};
if numel(prec) > 1
n = prec{2};
if sum(n) > 0
V = A;
A = bsxfun(@times, V, reshape(n,1,1,[]));
prec = {V n};
end
end
end
if ~iscell(prop)
PI = prop;
elseif ~isempty(prop)
PI = prop{1};
end
% -------------------------------------------------------------------------
% Dimensions
P = size(MU,1);
K = size(MU,2);
% -------------------------------------------------------------------------
% Proportions/Dirichlet
if numel(PI) <= K
PI = PI(:)';
PI = spm_padarray(PI, [0 K - numel(PI)], 'replicate', 'post');
if abs(sum(PI)-1) > eps('single')
% Dirichlet prior
a = PI;
logPI = psi(a) - psi(sum(a));
else
logPI = log(max(PI,eps));
end
else
logPI = log(max(PI,eps));
end
clear PI
% -------------------------------------------------------------------------
% Reshape X (observations)
dimX = size(X);
if P == 1
latX = dimX;
if latX(2)==1
latX = latX(1);
end
else
latX = dimX(1:end-1);
end
X = reshape(X, [], P);
N = size(X, 1); % Number of observations
N0 = N; % Original number of observations
% -------------------------------------------------------------------------
% Prepare missing data stuff (code image, mask, ...)
if ~any(any(isnan(X)))
Missing = false;
end
if Missing
% Deal with missing data
[X,code_image,msk_obs] = spm_gmm_lib('obs2cell', X);
if size(E,1) > 1
E = spm_gmm_lib('obs2cell', E, code_image, true);
end
missmsk = [];
else
% Discard rows with missing values
missmsk = any(isnan(X),2);
if any(missmsk)
X = X(~missmsk,:);
N = sum(~missmsk);
end
missmsk = find(missmsk); % saves a bit of memory
msk_obs = logical([]);
end
% -------------------------------------------------------------------------
% "Bin" variance
% > When we work with histograms, a bit of variance is lost due to the
% binning. Here, we assume some kind of uniform distribution inside the bin
% and consequently add the corresponding variance to the 2nd order moment.
if numel(E) < P
E = spm_padarray(E, [0 P - numel(E)], 'replicate', 'post');
end
E = (E.^2)/12;
% -------------------------------------------------------------------------
% Compute marginal log-likelihood
const = spm_gmm_lib('Normalisation', mean, prec, msk_obs);
logpX = spm_gmm_lib('Marginal', X, [{MU} prec], const, msk_obs, E);
% -------------------------------------------------------------------------
% Compute responsibilities
Z = spm_gmm_lib('Responsibility', logpX, logPI);
clear logpX logPI
% -------------------------------------------------------------------------
% Cell 2 Matrix
if Missing
Z = spm_gmm_lib('cell2obs', Z, code_image, msk_obs);
if nargout >= 2
X = spm_gmm_lib('cell2obs', X, code_image, msk_obs);
end
end
% -------------------------------------------------------------------------
% Infer missing values
if nargout >= 2 && Missing
X = spm_gmm_lib('InferMissing', X, Z, {MU,A}, code_image);
end
% -------------------------------------------------------------------------
% Replace discarded missing values
if ~isempty(missmsk)
present = true(N0, 1);
present(missmsk) = false;
clear missmsk
Z = expand(Z, present, N0, K, 0);
if nargout >= 2
X = expand(X, present, N0, P, NaN);
end
end
% -------------------------------------------------------------------------
% Reshape everything (use input lattice)
Z = reshape(Z, [latX K]);
if nargout >= 2
X = reshape(X, [latX P]);
end
% =========================================================================
function TellMeMore
% _________________________________________________________________________
%
% Gaussian mixture model (GMM)
% ----------------------------
% The Gaussian Mixture relies on a generative model of the data.
% Each observation is assumed to stem from one of K clusters, and each
% cluster possesses a Gaussian density.
%
% Classical GMM
% -------------
% With the most basic (and well known) model, we look for maximum
% likelihood values for the model parameters, which are the mean and
% precision matrix of each cluster: {Mu, A}_k = argmax p({x}_n | {Mu, A}_k)
% To compute this probability, we need to integrate over all possible
% cluster responsibilities (which cluster a given observation belongs
% to), which are unknown:
% p({x}_n | {Mu, A}_k) = int p({x}_n | {z}_n, {Mu, A}_k) p({z}_n) d{z}_n
% Since this integral is intractable, we use the Expectation-Maximisation
% algorithm: we alternate between computing the posterior distribution of
% responsibilities (given known means and precision matrices) and updating
% the mean and precisions (given the known posterior).
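%
% Concretely, the E-step evaluates the responsibilities in closed form:
%     z_nk = Pi_k N(x_n | Mu_k, A_k^{-1})
%            / sum_j Pi_j N(x_n | Mu_j, A_j^{-1})
% and the M-step re-estimates each mean and precision from the moments
% of the data weighted by these responsibilities.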
%
% Bayesian GMM
% ------------
% When we have some idea of what these means and precisions look like, we
% can take it into account in the form of Bayesian beliefs, i.e., prior
% probability distributions over the parameters we want to estimate. What
% we are now looking for are the posterior distributions (given some
% observed data) of these parameters:
% p({Mu, A}_k | {x}_n) = p({x}_n | {Mu, A}_k) p({Mu, A}_k) / p({x}_n)
% To make everything tractable, these prior beliefs are chosen to be
% conjugate priors. It is not very important to know what this means,
% except that it makes computing the posterior probabilities easier. In
% our case, we can use a Gaussian prior distribution for the means, a
% Wishart prior distribution for the precision matrices (we talk of a
% Gauss-Wishart distribution when they are combined) and a Dirichlet
% distribution for the cluster proportions.
% Despite all that, we still cannot compute these posteriors. We make an
% additional assumption, which is that there is some sort of independence
% between the parameters to estimate; we say that the posterior can be
% factorised:
% q({Mu, A, Pi}_k, {z}_n) = q({Mu, A}_k) q({Pi}_k) q({z}_n)
% We can then use a technique called variational Bayesian inference to
% estimate, in turn, these distributions.
%
% Histogram GMM
% -------------
% To speed up computation, we sometimes prefer to use a histogram (that
% is, binned observations) rather than the full set of observations. A
% first way of doing this is to assume that each bin centre corresponds
% to several identical observations (the bin count). In other words, we
% now have weighted observations.
% However, using weighted observations makes estimating the precisions less
% robust, as we artificially reduce the variance by assigning different
% values to the same bin centre. To lower this effect, we can assume that
% each observation is actually a distribution, i.e., there is a bit of
% uncertainty about the "true" value. When computing the model, we want to
% integrate over all possible values.
% Here, we assume that these distributions are uniform in each bin. This
% makes the implementation easy and efficient: the expected value of each
% observation is the bin centre, and their variance is (w^2)/12, where w is
% the bin width. This amounts to adding a bit of jitter when computing
% the second-order statistics used to update the precision matrices.
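%
% For instance, a sketch of fitting a 1D histogram (names are
% illustrative; histcounts is standard MATLAB):
%     w     = 0.1;                          % bin width
%     edges = min(x):w:max(x)+w;
%     W     = histcounts(x, edges)';        % bin counts = weights
%     ctrs  = edges(1:end-1)' + w/2;        % bin centres = observations
%     [Z,MU,A,PI] = spm_gmm(ctrs, 2, W, 'BinWidth', w);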
%
% Missing values
% --------------
% In the classical GMM case, in the presence of partial observations (in a
% given voxel, some modalities might be missing), it is easy to compute the
% conditional likelihood of a given voxel, but computing ML mean and
% precision estimates is trickier.
% With one Gaussian, an EM scheme can be designed to obtain ML parameters
% by alternating between inferring missing values and updating the
% parameters. This scheme can be extended to the mixture case, in which
% case the EM algorithm relies on a joint posterior distribution over class
% responsibilities and missing values.
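%
% For instance (a sketch; NaNs mark the missing entries):
%     X(rand(size(X)) < 0.1) = NaN;         % knock out 10% of values
%     [Z,MU,A,PI,b,V,n,a,Xi] = spm_gmm(X, 2, 'Missing', true);
%     % Xi holds the observations with NaNs replaced by inferred values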
% _________________________________________________________________________
% Copyright (C) 2018 Wellcome Centre for Human Neuroimaging
% =========================================================================
function Options
% _________________________________________________________________________
%
% FORMAT [Z,MU,A,PI,...] = spm_gmm(X,...)
%
% MANDATORY
% ---------
% X is a set of multidimensional observations. It usually takes the form of
% a matrix, in which case its first dimension (N) is the number of
% observations and its second dimension (P) is their dimension. However,
% it is possible to provide a multidimensional array, in which case the
% dimension P must be provided in some way. This might be through the
% size of user-provided starting estimates or priors, or directly using
% the 'InputDim' option. If the 'Missing' option is activated, X might
% contain missing values (NaN), which will be inferred by the model. If
% the 'Missing' option is not activated, all rows containing missing
% values will be excluded.
%
% OPTIONAL
% --------
% K is the number of clusters in the mixture. If not provided, we will
% first try to guess it from user-provided options (starting estimates or
% priors). If this is not possible, the default value is 2.
%
% W is a vector of weights associated with each observation. Usually,
% weights are used when the input dataset is a histogram. Suitable X and
% W arrays can be obtained with the spm_imbasics('hist') function. In
% this case, it is advised to also use the 'BinWidth' option, which
% prevents variances from collapsing due to the pooling of values. If W is
% not provided, all observations have weight 1.
%
% KEYWORD
% -------
% PropPrior may take two forms (+ the default):
% 0) It can be zero (the default), in which case the algorithm
% searches for maximum-likelihood proportions.
% 1) It can be a vector of concentration parameters for a
% Dirichlet prior. Dirichlet distributions are conjugate priors
% for proportions, e.g. Categorical parameters. In this case,
% it must consist of K strictly positive values. Their
% normalised value is the prior expected proportion, while their
% sum corresponds to the precision of the distribution (the
% larger the precision, the stronger the prior). A usage sketch is
% given below this list.
% 2) It can be a matrix (or multidimensional array) of fixed
% observation-wise proportions. We often talk about
% non-stationary class proportions. In this case, it must
% contain N (the number of observations) times K (the number of
% clusters) elements. Elements must sum to one along the cluster
% dimension.
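%
% For instance, a mildly informative Dirichlet prior over three
% clusters (a sketch; values are illustrative):
%     Z = spm_gmm(X, 3, 'PropPrior', [5 5 5]);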
%
% GaussPrior {Mu(PxK),b(1x[K]),V([PxP]x[K]),n(1x[K])}
% The Gauss-Wishart distribution is a conjugate prior for the
% parameters (mean and precision matrix) of a multivariate
% Gaussian distribution. Its parameters are Mu, the expected
% mean, b, the prior degrees of freedom for the mean, V, the
% scale matrix and n, the prior degrees of freedom for the
% precision. The expected precision matrix is n*V. This option
% allows these parameters to be defined for each cluster: it
% should be a cell of arrays with the dimensions stated above.
% Note that some parameters might be left empty, in which case
% they will be automatically determined. Typically, one could
% provide only the degrees of freedom (b and n), in which case the
% starting estimates will be used as prior expected means and
% precisions. Also, dimensions that are written in brackets are
% automatically expanded if needed.
% By default, this option is empty, and the algorithm
% searches for maximum-likelihood parameters.
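%
% For instance, keeping the starting estimates as prior expected
% means and precisions, with moderately strong beliefs (a sketch;
% values are illustrative, relying on the bracket auto-expansion
% described above):
%     [Z,MU,A,PI,b,V,n] = spm_gmm(X, 4, 'GaussPrior', {[], 10, [], 10});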
%
% Prune contains a threshold for the final estimated proportions.
% Classes that are under this threshold will be considered as
% non-informative and pruned out. By default, the threshold is 0
% and no pruning is performed.
%
% Missing 'infer': (default) missing values are inferred by making use
% of some properties of the Gaussian distribution.
% However, the fit is slower due to this inference
% scheme.
% 'remove': all rows with missing values are excluded, and the
% fit is performed without inference. For
% computational efficiency, when no values are missing
% from the input set (i.e., there are no NaNs), this
% option is activated by default.
%
% Start Method used to select starting estimates:
% 'kmeans': A K-means clustering is used to obtain a first
% classification of the observations. Centroids are
% used to initialise the means, intra-cluster
% co-variance is used to initialise precision
% matrices (unless initial precision matrices are
% user-provided) and cluster size is used to
% initialise proportions. Options can be provided to
% the K-means algorithm using the 'KMeans' option.
% 'linspace': Centroids are chosen so that they are linearly
% spaced along the input range of values. Precision
% matrices are set as explained below.
% 'prior': The prior expected means and precision matrices
% are used as initial estimates.
% 'sample': Random samples are selected from the input dataset
% and used as initial means. Precision matrices are
% set as explained below.
% 'uniform': Random values are uniformly sampled from the input
% range of values and used as initial means.
% Precision matrices are set as explained below.
% MU(PxK): Initial means are user-provided.
% The method can be provided along with initial precision
% matrices, in a cell: {METHOD, A([PxP]x[K])}. By default,
% the initial precision matrix is common to all classes and
% chosen so that the input range is well covered:
% -> A = diag(a) with a(p) = (range(p)/(2K))^(-2)
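%
% For instance, user-provided means with the default precision
% initialisation (a sketch; P=2, K=2, values are illustrative):
%     MU0 = [0 4; 0 4];
%     [Z,MU,A] = spm_gmm(X, 'Start', MU0);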
%
% KMeans is a cell of options for the K-means algorithm.
%
% IterMax is the maximum number of EM iterations. Default is 1000.
%
% Tolerance is the convergence threshold below which the algorithm is
% stopped. The convergence criterion is a normalised lower bound
% gain (lcur-lprev)/(lmax-lmin). The default tolerance is 1e-4.
%
% BinWidth is a vector of bin widths. It is useful when the input is a
% histogram (e.g. weighted observations) to regularise the
% variance estimation. If only one width is provided, it is
% automatically expanded to all dimensions.
%
% InputDim Number of dimensions in the input space. If not provided, we
% will try to infer it from the input arrays, e.g. starting
% estimates, last dimension of the input array, etc.
%
% Verbose 0: Quiet mode. Nothing is written or plotted.
% 1: Verbose mode. The lower bound is written after each
% iteration. This slows the algorithm down significantly.
% 2: Graphical mode. The evolution of the lower bound and its
% different components is plotted.
% 3: Graphical mode +. A representation of the fit (marginal
% distribution, joint 2D distributions, proportions) is
% plotted. This slows things down dramatically and should
% only be used for education or debugging purposes.
% _________________________________________________________________________
% Copyright (C) 2018 Wellcome Centre for Human Neuroimaging
% =========================================================================
function [K,P,Start] = dimFromStart(K, P, Start)
% FORMAT [K,P,Start] = dimFromStart(K, P, Start)
% K - Number of clusters (previous guess)
% P     - Number of dimensions (previous guess)
% Start - Cell of "starting estimates" options
%
% Guess input dimension and number of clusters from starting estimates.
% ---
% Guess from starting mean
if ~ischar(Start{1})
if size(Start{1}, 1) > 0
if P == 0
P = size(Start{1}, 1);
elseif P ~= size(Start{1}, 1)
warning(['Input space dimension does not agree with starting ' ...
'mean: %d vs. %d'], P, size(Start{1}, 1));
Start{1} = Start{1}(1:P,:);
end
end
if size(Start{1}, 2) > 0
if K == 0
K = size(Start{1}, 2);
elseif K ~= size(Start{1}, 2)
warning(['Number of clusters does not agree with starting ' ...
'mean: %d vs. %d'], K, size(Start{1}, 2));
Start{1} = Start{1}(:,1:K);
end
end
end
% ---
% Guess from starting precision
if numel(Start) > 1 && ~ischar(Start{2})
if size(Start{2}, 1) > 1
if P == 0
P = size(Start{2}, 1);
elseif P ~= size(Start{2}, 1)
warning(['Input space dimension does not agree with starting ' ...
'precision: %d vs. %d'], P, size(Start{2}, 1));
Start{2} = Start{2}(1:P,1:P,:);
end
end
if size(Start{2}, 3) > 1
if K == 0
K = size(Start{2}, 3);
elseif K ~= size(Start{2}, 3)
warning(['Number of clusters does not agree with starting ' ...
'precision: %d vs. %d'], K, size(Start{2}, 3));
Start{2} = Start{2}(:,:,1:K);
end
end
end
% =========================================================================
function [K,P,GaussPrior] = dimFromGaussPrior(K, P, GaussPrior)
% FORMAT [K,P,GaussPrior] = dimFromGaussPrior(K, P, GaussPrior)
% K - Number of clusters (previous guess)
% P          - Number of dimensions (previous guess)
% GaussPrior - Cell of "prior" options
%
% Guess input dimension and number of clusters from Gauss-Wishart prior
% ---
% Guess from prior mean
if numel(GaussPrior) > 1
if size(GaussPrior{1}, 1) > 0
if P == 0
P = size(GaussPrior{1}, 1);
elseif P ~= size(GaussPrior{1}, 1)
warning(['Input space dimension does not agree with prior ' ...
'mean: %d vs. %d'], P, size(GaussPrior{1}, 1));
GaussPrior{1} = GaussPrior{1}(1:P,:);
end
end
if size(GaussPrior{1}, 2) > 0
if K == 0
K = size(GaussPrior{1}, 2);
elseif K ~= size(GaussPrior{1}, 2)
warning(['Number of clusters does not agree with prior ' ...
'mean: %d vs. %d'], K, size(GaussPrior{1}, 2));
GaussPrior{1} = GaussPrior{1}(:,1:K);
end
end
end
% ---
% Guess from prior precision
if numel(GaussPrior) > 3
if size(GaussPrior{2}, 1) > 1
if P == 0
P = size(GaussPrior{2}, 1);
elseif P ~= size(GaussPrior{2}, 1)
warning(['Input space dimension does not agree with prior ' ...
'precision: %d vs. %d'], P, size(GaussPrior{2}, 1));
GaussPrior{2} = GaussPrior{2}(1:P,1:P,:);
end
end
if size(GaussPrior{2}, 3) > 1
if K == 0
K = size(GaussPrior{2}, 3);
elseif K ~= size(GaussPrior{2}, 3)
warning(['Number of clusters does not agree with prior ' ...
'precision: %d vs. %d'], K, size(GaussPrior{2}, 3));
GaussPrior{2} = GaussPrior{2}(:,:,1:K);
end
end
end
% =========================================================================
function [K,PropPrior] = dimFromPropPrior(K, PropPrior)
% Guess number of clusters from Dirichlet prior
% ---
% Guess from prior proportion
if size(PropPrior, 2) > 1
if K == 0
K = size(PropPrior, 2);
elseif K ~= size(PropPrior, 2)
warning(['Number of clusters does not agree with prior ' ...
'proportion: %d vs. %d'], K, size(PropPrior, 2));
PropPrior = PropPrior(:,1:K);
end
end
% =========================================================================
function [P,dimX] = dimFromObservations(P, X)
dimX = size(X);
if P == 0
if numel(dimX) == 2
if dimX(1) == 1
% row-vector case
P = dimX(1);
else
% matrix case
P = dimX(2);
end
else
% N-dimensional array case
P = dimX(end);
end
end
% =========================================================================
function varargout = prune(threshold, PI, Z, mean, prec)
% FORMAT prune(threshold, PI, Z, {MU,b}, {A,V,n})
%
% Remove classes whose normalised proportion falls below the threshold
kept = sum(PI,1)/sum(PI(:)) >= threshold;
K = sum(kept);
Z = Z(:,kept);
PI = PI(:,kept);
if ~iscell(mean)
mean = {mean(:,kept)};
else
if numel(mean) >= 1
mean{1} = mean{1}(:,kept);
if numel(mean) >= 2 && sum(mean{2}) > 0
mean{2} = mean{2}(kept);
end
end
end
if ~iscell(prec)
prec = {prec(:,kept)};
else
if numel(prec) >= 1
if ~isempty(prec{1})
prec{1} = prec{1}(:,:,kept);
end
if numel(prec) >= 2
if ~isempty(prec{2})
prec{2} = prec{2}(:,:,kept);
end
if numel(prec) >= 3 && sum(prec{3}) > 0
prec{3} = prec{3}(kept);
end
end
end
end
varargout = [{K} {PI} {Z} mean prec];
% =========================================================================
function X = expand(X, msk, N, P, val)
X1 = X;
val = cast(val, 'like', X1);
if isscalar(val)
% Scalar default (covers 0, 1, Inf and NaN alike)
X = val * ones(N, P, 'like', X1);
elseif numel(val) == P
% One default value per column
X = repmat(reshape(val, 1, P), [N 1]);
else
% Otherwise, assume one row of defaults per absent entry
% (e.g. a fixed proportion map)
X = zeros(N, P, 'like', X1);
X(~msk,:) = val;
end
X(msk,:) = X1; clear X1
% =========================================================================
function [Z,MU,b,A,V,n,PI,logPI] = start(method, X, W, K, a0, pr, kmeans)
% FORMAT [Z,MU,b,A,V,n,PI,logPI] = start(method, X, W, K, a0, pr, kmeans)
%
% method - Method to use to select starting centroids
%          'kmeans', 'linspace', 'prior', 'sample', 'uniform' or
%          provided matrix
% X      - NxP matrix of observed values
% W      - Nx1 vector of weights
% K      - Number of clusters
% a0     - 1xK Dirichlet priors (or empty)
% pr     - Structure of Gauss-Wishart prior parameters
% kmeans - Options for spm_kmeans
%
% Z      - NxK initial responsibilities
% MU     - PxK means
% b      - 1xK mean degrees of freedom (empty if not Bayesian)
% A      - PxPxK precision matrices
% V      - PxPxK scale matrices (empty if not Bayesian)
% n      - 1xK precision degrees of freedom (empty if not Bayesian)
% PI     - 1xK proportions
% logPI  - 1xK log-proportions
%
% Compute starting estimates
if ~iscell(method)
method = {method};
end
b = [];
n = [];
V = [];
switch method{1}
case 'kmeans'
% Use kmeans to produce a first clustering of the data
P = size(X,2);
N = size(X,1);
% Get labelling and centroids from K-means
[L,MU] = spm_kmeans(X, K, W, kmeans{:});