Skip to content

Commit fef6e2b

Browse files
author
Abhinab Bhattacharjee
committed
Rank Histogram(Observation)
1 parent c5136e8 commit fef6e2b

File tree

6 files changed

+182
-96
lines changed

6 files changed

+182
-96
lines changed

src/+datools/+examples/+sandu/l63_experiments.m

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
% Use this to run Lorenz 63 experiments
55
modelname = 'Lorenz63';
66
% uncomment the filter you want to run
7-
%filtername = 'EnKF';
7+
filtername = 'EnKF';
88
%filtername = 'ETKF';
99
%filtername = 'ETPF';
1010
%filtername = 'SIR';
11-
filtername = 'RHF';
11+
%filtername = 'RHF';
1212

1313
% oservation variance
1414
variance = 1;
@@ -29,7 +29,7 @@
2929

3030
%% Remaining code%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
3131
% rank histogram for the 1st state only (for now)
32-
histvar = 1:1:1;
32+
histvar = 1:1:2;
3333
% decide the type of filter
3434
switch filtername
3535
case 'EnKF'
@@ -224,7 +224,7 @@
224224
y = filter.Observation.observeWithError(model.TimeSpan(1), xt);
225225

226226
% Rank histogram (if needed)
227-
datools.utils.stat.RH(filter, xt);
227+
datools.utils.stat.RH(filter, xt, y);
228228

229229
% analysis
230230
% try

src/+datools/+examples/+sandu/runexperiments2.m

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ function runexperiments2(user)
3333
%localization radius
3434
r = user.localizationradius;
3535

36+
% the states for which the histograms are saved
37+
histvar = user.histvar;
38+
RHmeasure = user.RHmeasure;
39+
3640
% plot indices
3741
rankhistogramplotindex = user.rankhistogramplotindex;
3842
rmseplotindex = user.rmseplotindex;
@@ -41,7 +45,7 @@ function runexperiments2(user)
4145

4246
%% Remaining code%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
4347
% rank histogram for the 1st state only (for now)
44-
histvar = 1:1:1;
48+
%histvar = 1:1:1;
4549

4650
% decide the type of filter
4751
switch filtername
@@ -59,6 +63,8 @@ function runexperiments2(user)
5963
filtertype = 'Particle';
6064
case 'SIR'
6165
filtertype = 'Particle';
66+
case 'SIS_EnKF'
67+
filtertype = 'Particle';
6268
case 'RHF'
6369
filtertype = 'Ensemble';
6470
case 'EnGMF'
@@ -161,7 +167,7 @@ function runexperiments2(user)
161167
model = datools.Model('Solver', solvermodel, 'ODEModel', modelODE);
162168
nature = datools.Model('Solver', solvernature, 'ODEModel', natureODE);
163169

164-
% Observation Model
170+
% Observable model
165171
naturetomodel = datools.observation.Linear(numel(nature0), 'H', ...
166172
speye(natureODE.NumVars));
167173

@@ -226,7 +232,7 @@ function runexperiments2(user)
226232

227233
ns = numsamples;
228234
sE = zeros(ns, 1);
229-
rankvaluesample = zeros(histvar, ensN+1, numsamples);
235+
rankvaluesample = zeros(numel(histvar), ensN+1, numsamples);
230236
rmstempvalsample = nan * ones(steps-spinup, numsamples);
231237

232238
parfor sample = 1:ns
@@ -260,6 +266,8 @@ function runexperiments2(user)
260266
localization = [];
261267
case 'SIR'
262268
localization = [];
269+
case 'SIS_EnKF'
270+
localization = [];
263271
case 'ETPF'
264272
localization = [];
265273
case 'ETPF2'
@@ -317,6 +325,16 @@ function runexperiments2(user)
317325
'Parallel', false, ...
318326
'RankHistogram', histvar, ...
319327
'Rejuvenation', rejuvenation);
328+
case 'SIS_EnKF'
329+
filter = datools.statistical.ensemble.(filtername)(model, ...
330+
'Observation', observation, ...
331+
'NumEnsemble', ensN, ...
332+
'ModelError', modelerror, ...
333+
'EnsembleGenerator', ensembleGenerator, ...
334+
'Parallel', false, ...
335+
'RankHistogram', histvar, ...
336+
'Rejuvenation', rejuvenation, ...
337+
'Inflation', 1.05);
320338
case 'RHF'
321339
filter = datools.statistical.ensemble.(filtername)(model, ...
322340
'Observation', observation, ...
@@ -380,17 +398,18 @@ function runexperiments2(user)
380398
filter.forecast();
381399
end
382400
catch e
383-
dofilter = false
401+
dofilter = false;
384402
end
385403

386404
% observe
387-
xt = naturetomodel.observeWithoutError(nature.TimeSpan(1), nature.State);
388-
y = filter.Observation.observeWithError(model.TimeSpan(1), xt);
405+
xt = naturetomodel.observeWithoutError(nature.TimeSpan(1), nature.State); % H(x_true)
406+
y = filter.Observation.observeWithError(model.TimeSpan(1), xt); % H(x_true) + noise
389407

390408
% Rank histogram (if needed)
391-
datools.utils.stat.RH(filter, xt);
392-
rankvaluesample(:, :, sample) = filter.RankValue(1, 1:end-1);
409+
% datools.utils.stat.RH(filter, xt, y, Hxa, RHmeasure);
410+
% rankvaluesample(:, :, sample) = filter.RankValue(:, 1:end-1);
393411

412+
394413
% analysis
395414
try
396415
if dofilter
@@ -401,6 +420,15 @@ function runexperiments2(user)
401420
end
402421

403422
xa = filter.BestEstimate;
423+
xaensembles = filter.Ensemble;
424+
425+
% observable
426+
hxa = naturetomodel.observeWithoutError(model.TimeSpan(1), xaensembles); % H(x_analysis)
427+
Hxa = filter.Observation.observeWithError(model.TimeSpan(1), hxa); % H(x_analysis) + noise
428+
429+
% Rank histogram (if needed)
430+
datools.utils.stat.RH(filter, xt, y, Hxa, RHmeasure);
431+
rankvaluesample(:, :, sample) = filter.RankValue(:, 1:end-1);
404432

405433
err = xt - xa;
406434

@@ -446,13 +474,22 @@ function runexperiments2(user)
446474
if isnan(resE)
447475
resE = 1000;
448476
end
477+
478+
% totalklval = 0;
479+
% for klcounter = 1:numel(histvar)
480+
% klvalue = datools.utils.stat.KLDivergence(rankvalue(klcounter,:), (1 / (ensN+1))*ones(1, ensN+1));
481+
% totalklval = totalklval + klvalue;
482+
% end
483+
484+
rankvalue = sum(rankvalue,1);
449485

450486
switch filtertype
451487
case 'Ensemble'
452488
rmses(ensNi, infi) = resE;
453489

454490
[xs, pval, rhplotval(ensNi, infi)] = datools.utils.stat.KLDiv(rankvalue, ...
455-
(1 / ensN)*ones(1, ensN+1));
491+
(1 / (ensN+1))*ones(1, ensN+1));
492+
%rhplotval(ensNi, infi) = totalklval;
456493
case 'Particle'
457494
rmses(ensNi, reji) = resE;
458495

src/+datools/+examples/+sandu/runnerscript.m

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,39 +24,46 @@
2424
% filtername = 'LETKF';
2525

2626
%% User inputs
27-
options.modelname = 'Lorenz63';
27+
options.modelname = 'QG';
2828

29-
options.filtername = 'EnKF';
29+
options.filtername = 'ETKF';
30+
31+
%options.ensNs = [8, 12, 16, 20, 24, 28, 32];%[25, 50, 75, 100];
32+
options.ensNs = [16, 32, 48, 64, 80, 96, 112];
33+
%options.ensNs = [32, 48, 64, 80, 96, 112, 128];
34+
%options.ensNs = [32, 64, 128, 256, 512, 1024, 2048];
3035

31-
options.ensNs = [4, 8, 12, 16, 20, 24, 28];%[25, 50, 75, 100];
3236

3337
options.infs = round(linspace(1, 1.10, 7), 2); % [1, 1.025, 1.05, 1.075, 1.10];
3438

3539
options.variance = 1; % observation variance
3640

37-
options.observeindicies = 1:1:3; % observation indices
41+
options.observeindicies = linspace(0,8001,150); % observation indices
3842

3943
options.rejs = round(logspace(-1.5, -0.25, 7), 2); % round(2*logspace(-1.5, -0.5, 6), 2)
4044

41-
options.spinups = 500;
45+
options.spinups = 50;
4246

4347
options.steps = 11 * options.spinups; % change as deemed fit
4448

45-
options.Dt = 0.12; % 0.12(L63), 0.05(L96), 0.0109(QG)
49+
options.Dt = 0.0109; % 0.12(L63), 0.05(L96), 0.0109(QG)
4650

4751
options.odesolver = 'ode45'; % ode45 , RK4
4852

4953
options.localize = false; % set to true if localization is needed
5054

51-
options.localizationradius = 4; % localization rdius
55+
options.localizationradius = 4; % localization radius
56+
57+
options.ns = 10;
5258

53-
options.ns = 20;
59+
options.histvar = 1:1:1;
60+
options.RHmeasure = 'Truth';
5461

5562
%plotting parameters
5663
options.rankhistogramplotindex = 1:2:numel(options.ensNs);
5764
options.rmseplotindex = 1:2:numel(options.ensNs);
58-
options.rmseheatmapplotindex = 1:1:numel(options.ensNs);
59-
options.kldivergenceplotindex = 1:1:numel(options.ensNs);
65+
options.rmseheatmapplotindex = 1:2:numel(options.ensNs);
66+
options.kldivergenceplotindex = 1:2:numel(options.ensNs);
6067

6168
%%
6269
% call the experiment file

src/+datools/+utils/+stat/KLDiv.m

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
function [xs, pval, KLVal] = KLDiv(pdf1, pdf2)
1111

1212
if numel(pdf1) ~= numel(pdf2)
13-
fprintf('Both vectors must be of same length');
14-
return;
13+
error('Both vectors must be of same length');
1514
end
1615

1716
% check if the distributions are normalized

src/+datools/+utils/+stat/RH.m

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
function RH(model, xt)
1+
function RH(filter, xt, y, Hxa, observationvariable)
22
%function to plot the rank histogram
33

4-
flag = isa(model, 'datools.statistical.ensemble.EnF');
4+
flag = isa(filter, 'datools.statistical.ensemble.EnF');
55
if ~flag
66
frpintf('Model should be ensemble type');
77
exit;
@@ -10,25 +10,46 @@ function RH(model, xt)
1010
fprintf('need to pass the truth at this time step');
1111
exit;
1212
end
13-
var = length(model.RankHistogram);
13+
var = length(filter.RankHistogram);
14+
varobservtion = filter.Observation.Indices;
15+
xa = filter.Ensemble;
16+
[~, ensN] = size(xa);
1417

15-
xf = model.Ensemble;
16-
[~, ensN] = size(xf);
17-
18-
for i = 1:var
19-
tempVar = model.RankHistogram(i);
20-
model.RankValue(i, ensN+2) = tempVar;
21-
ensFor = sort(xf(tempVar, :));
22-
23-
if xt(tempVar) < ensFor(1)
24-
model.RankValue(i, 1) = model.RankValue(i, 1) + 1;
25-
elseif xt(tempVar) > ensFor(length(ensFor))
26-
model.RankValue(i, ensN+1) = model.RankValue(i, ensN+1) + 1;
27-
else
28-
index = find(ensFor > xt(tempVar), 1);
29-
model.RankValue(i, index) = model.RankValue(i, index) + 1;
30-
end
3118

19+
switch observationvariable
20+
case 'Truth'
21+
for i = 1:var
22+
tempVar = filter.RankHistogram(i);
23+
filter.RankValue(i, ensN+2) = tempVar;
24+
ensFor = sort(xa(tempVar, :));
25+
26+
if xt(tempVar) < ensFor(1)
27+
filter.RankValue(i, 1) = filter.RankValue(i, 1) + 1;
28+
elseif xt(tempVar) > ensFor(length(ensFor))
29+
filter.RankValue(i, ensN+1) = filter.RankValue(i, ensN+1) + 1;
30+
else
31+
index = find(ensFor > xt(tempVar), 1);
32+
filter.RankValue(i, index) = filter.RankValue(i, index) + 1;
33+
end
34+
35+
end
36+
case 'Observation'
37+
for i = 1:numel(varobservtion)
38+
tempVar = filter.RankHistogram(varobservtion(i));
39+
filter.RankValue(i, ensN+2) = tempVar;
40+
ensFor = sort(Hxa(i, :));
41+
42+
if y(i) < ensFor(1)
43+
filter.RankValue(i, 1) = filter.RankValue(i, 1) + 1;
44+
elseif y(i) > ensFor(length(ensFor))
45+
filter.RankValue(i, ensN+1) = filter.RankValue(i, ensN+1) + 1;
46+
else
47+
index = find(ensFor > y(i), 1);
48+
filter.RankValue(i, index) = filter.RankValue(i, index) + 1;
49+
end
50+
51+
end
3252
end
3353

54+
3455
end

0 commit comments

Comments
 (0)