Skip to content

Commit 9ea1ca9

Browse files
author
Abhinab Bhattacharjee
committed
Added KL Div
1 parent 1dc2a83 commit 9ea1ca9

File tree

7 files changed

+270
-75
lines changed

7 files changed

+270
-75
lines changed

src/+datools/+examples/l63_RHF.m

+4-4
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@
4545

4646
ensembleGenerator = @(N) randn(natureODE.NumVars, N);
4747

48-
ensNs = 20:5:25;
49-
infs = 1.10:.01:1.10;
48+
ensNs = 5:5:50;
49+
infs = 1.01:.01:1.10;
5050
serveindicies = 1:1:natureODE.NumVars;
5151
rmses = inf*ones(numel(ensNs), numel(infs));
5252

53-
maxallowerr = 20;
53+
maxallowerr = 100;
5454

5555
mm = min(rmses(:));
5656

@@ -88,7 +88,7 @@
8888
'EnsembleGenerator', ensembleGenerator, ...
8989
'Inflation', inflation, ...
9090
'Parallel', false, ...
91-
'Tail', 'Gaussian');
91+
'Tail', 'Flat');
9292

9393
% rhf.setMean(natureODE.Y0);
9494
% rhf.scaleAnomalies(1/10);

src/+datools/+examples/l63_enkf.m

+121-41
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
clear; close all;
2-
figure;
3-
drawnow;
2+
% figure;
3+
% drawnow;
44

55
% time steps
66
Delta_t = 0.12;
77

8-
% Time Stepping Methods
9-
solvermodel = @(f, t, y) datools.utils.rk4(f, t, y, 50);
10-
solvernature = @(f, t, y) datools.utils.rk4(f, t, y, 50);
8+
% Time Stepping Methods (Use ode45 or write your own)
9+
% solvermodel = @(f, t, y) datools.utils.rk4(f, t, y, 50);
10+
% solvernature = @(f, t, y) datools.utils.rk4(f, t, y, 50);
11+
solvermodel = @(f, t, y) ode45(f, t, y);
12+
solvernature = @(f, t, y) ode45(f, t, y);
1113

12-
% ODE
14+
% Define ODE
1315
natureODE = otp.lorenz63.presets.Canonical;
1416
nature0 = randn(natureODE.NumVars, 1);
1517
natureODE.TimeSpan = [0, Delta_t];
@@ -21,19 +23,23 @@
2123
[tt, yy] = ode45(natureODE.Rhs.F, [0 10], nature0);
2224
natureODE.Y0 = yy(end, :).';
2325

26+
% initialize model
2427
model = datools.Model('Solver', solvermodel, 'ODEModel', modelODE);
2528
nature = datools.Model('Solver', solvernature, 'ODEModel', natureODE);
2629

30+
% Observation Model
2731
naturetomodel = datools.observation.Linear(numel(nature0), 'H',...
2832
speye(natureODE.NumVars));
2933

30-
%observeindicies = 1:natureODE.NumVars;
34+
35+
% observe these variables
3136
observeindicies = 1:1:natureODE.NumVars;
3237

3338
nobsvars = numel(observeindicies);
3439

3540
R = (1/1)*speye(nobsvars);
3641

42+
% Observaton model (Gaussian here)
3743
obserrormodel = datools.error.Gaussian('CovarianceSqrt', sqrtm(R));
3844
%obserrormodel = datools.error.Tent;
3945
observation = datools.observation.Indexed(model.NumVars, ...
@@ -45,93 +51,101 @@
4551

4652
ensembleGenerator = @(N) randn(natureODE.NumVars, N);
4753

48-
ensNs = 20:5:25;
49-
infs = 1.05:.01:1.05;
50-
histvar = 1:1:1;
54+
% number of ensembles and inflation
55+
% ensNs = 10:5:25;
56+
% infs = 1.01:0.01:1.04;
57+
58+
%ensNs = [5 15 25 50];
59+
ensNs = [50 100 150 200];
60+
infs = [1.01 1.02 1.05 1.10];
61+
rejs = [0.10 0.12 0.15 0.20];
62+
63+
% variables for which you need the rank histogram plot
64+
histvar = 1:1:3;
5165
serveindicies = 1:1:natureODE.NumVars;
5266
rmses = inf*ones(numel(ensNs), numel(infs));
67+
rhplotval = inf*ones(numel(ensNs), numel(infs));
5368

54-
maxallowerr = 2;
69+
maxallowerr = 10;
5570

5671
mm = min(rmses(:));
5772

5873
if mm >= maxallowerr
5974
mm = 0;
6075
end
6176

62-
imagesc(ensNs, infs, rmses.'); caxis([mm, 1]); colorbar; set(gca,'YDir','normal');
63-
axis square; title('EnKF_l63'); colormap('hot');
64-
xlabel('Ensemble Size'); ylabel('Inflation');
77+
% imagesc(ensNs, infs, rmses.'); caxis([mm, 1]); colorbar; set(gca,'YDir','normal');
78+
% axis square; title('EnKF_l63'); colormap('hot');
79+
% xlabel('Ensemble Size'); ylabel('Inflation');
6580

6681
runsleft = find(rmses == inf);
6782

83+
f1 = figure; f2 = figure; f3 = figure; f4 = figure;
84+
6885
for runn = runsleft.'
6986
[ensNi, infi] = ind2sub([numel(ensNs), numel(infs)], runn);
87+
[ensNi, reji] = ind2sub([numel(ensNs), numel(rejs)], runn);
7088

7189
fprintf('N: %d, inf: %.3f\n', ensNs(ensNi), infs(infi));
7290

7391
ns = 1;
7492
sE = zeros(ns, 1);
7593

7694
inflationAll = infs(infi);
95+
rejAll = rejs(reji);
7796
ensN = ensNs(ensNi);
7897

7998
for sample = 1:ns
8099
% Set rng for standard experiments
81100
rng(17 + sample - 1);
82101

83102
inflation = inflationAll;
103+
rejuvenation = rejAll;
84104

85-
% No localization
86-
% r = 5;
87-
% d = @(t, y, i, j) modelODE.DistanceFunction(t, y, i, j);
88-
%localization = [];
89-
90-
%localization = @(t, y, H) datools.tapering.gc(t, y, r, d, H);
91-
92-
% localization = @(t, y, Hi, k) datools.tapering.gcCTilde(t, y, Hi, r, d, k);
93-
%localization = @(t, y, Hi, k) datools.tapering.cutoffCTilde(t, y, r, d, Hi, k);
94-
95-
enkf = datools.statistical.ensemble.EnKF(model, ...
105+
% define the statistical/variational model here
106+
enkf = datools.statistical.ensemble.SIR(model, ...
96107
'Observation', observation, ...
97108
'NumEnsemble', ensN, ...
98109
'ModelError', modelerror, ...
99110
'EnsembleGenerator', ensembleGenerator, ...
100111
'Inflation', inflation, ...
101112
'Parallel', false, ...
102-
'RankHistogram', histvar);
113+
'RankHistogram', histvar, ...
114+
'Rejuvenation', rejuvenation);
103115

104116
enkf.setMean(natureODE.Y0);
105117
enkf.scaleAnomalies(1/10);
106118

119+
% define steps and spinups
107120
spinup = 100;
108121
times = 11*spinup;
109122

110123
mses = zeros(times - spinup, 1);
111124

112-
rmse = nan;
125+
% imagesc(ensNs, infs, rmses.'); caxis([mm, 1]); colorbar; set(gca,'YDir','normal');
126+
% axis square; title('EnKF'); colormap('pink');
127+
% xlabel('Ensemble Size'); ylabel('Inflation');
113128

129+
rmse = nan;
114130
ps = '';
115-
116131
do_enkf = true;
117132

133+
rmstempval = NaN * ones(1, times-spinup);
118134

119-
135+
% Assimilation
120136
for i = 1:times
121137
% forecast
122-
123138
nature.evolve();
124139

125140
if do_enkf
126141
enkf.forecast();
127142
end
128143

129-
130144
% observe
131145
xt = naturetomodel.observeWithoutError(nature.TimeSpan(1), nature.State);
132146
y = enkf.Observation.observeWithError(model.TimeSpan(1), xt);
133147

134-
% try RH
148+
% Rank histogram (if needed)
135149
datools.utils.stat.RH(enkf, xt);
136150

137151
% analysis
@@ -153,17 +167,19 @@
153167
mses(i - spinup) = mean((xa - xt).^2);
154168
rmse = sqrt(mean(mses(1:(i - spinup))));
155169

156-
if rmse > maxallowerr || isnan(rmse) || mses(i - spinup) > 2*maxallowerr
157-
do_enkf = false;
158-
end
170+
rmstempval(i - spinup) = rmse;
171+
172+
% if rmse > maxallowerr || isnan(rmse) || mses(i - spinup) > 2*maxallowerr
173+
% do_enkf = false;
174+
% end
159175
end
160176

161-
162177
if ~do_enkf
163178
break;
164179
end
165180

166181
end
182+
hold off;
167183

168184
if isnan(rmse)
169185
rmse = 1000;
@@ -176,7 +192,6 @@
176192
end
177193

178194
end
179-
rmse
180195
resE = mean(sE);
181196

182197
if isnan(resE)
@@ -185,19 +200,84 @@
185200

186201
rmses(ensNi, infi) = resE;
187202

203+
[xs, pval, rhplotval(ensNi, infi)] = datools.utils.stat.KLDiv(enkf.RankValue(1,1:end-1),...
204+
(1/ensN) * ones(1, ensN+1));
205+
188206
mm = min(rmses(:));
189207
mm = 0;
190208

191209
if mm >= maxallowerr
192210
mm = 0;
193211
end
194212

195-
imagesc(ensNs, infs, rmses.'); caxis([mm, 1]); colorbar; set(gca,'YDir','normal');
196-
axis square; title('EnKF'); colormap('pink');
213+
figure(f1);
214+
subplot(numel(infs), numel(ensNs), runn);
215+
hold all;
216+
z = enkf.RankValue(1,1:end-1);
217+
maxz = max(z);
218+
z = z/sum(z);
219+
NN = numel(z);
220+
z = NN*z;
221+
bar(xs, z);
222+
plot(xs, pval, '-*r');
223+
set(gca,'XTick',[xs(1) xs(end)]);
224+
set(gca,'XTickLabel',[1, ensN+1]);
225+
xlabel('bins');
226+
drawnow;
227+
228+
figure(f2);
229+
%imagesc(ensNs.', infs.', flipud(rmses.')); caxis([0, 1]); colorbar; set(gca,'YDir','normal');
230+
imagesc(ensNs.', rejs.', flipud(rmses.')); caxis([0, 1]); colorbar; set(gca,'YDir','normal');
231+
axis square; title('ETPF'); colormap('pink');
197232
xlabel('Ensemble Size'); ylabel('Inflation');
233+
%ytics = max(infs) - min(infs);
234+
ytics = max(rejs) - min(rejs);
235+
%ytics = min(infs):ytics/(length(infs) - 1):max(infs);
236+
ytics = min(rejs):ytics/(length(rejs) - 1):max(rejs);
237+
set(gca,'YTick', ytics);
238+
%set(gca,'YTickLabel', fliplr(infs));
239+
set(gca,'YTickLabel', fliplr(rejs));
198240
drawnow;
241+
242+
figure(f3);
243+
%imagesc(ensNs.', infs.', flipud(rhplotval.')); caxis([-0.09 0.09]); colorbar; set(gca, 'YDir', 'normal');
244+
imagesc(ensNs, rejs, flipud(rhplotval.')); caxis([-0.09 0.09]); colorbar; set(gca, 'YDir', 'normal');
245+
axis square; title('KLDiv'); colormap('summer');
246+
xlabel('Ensemble Size'); ylabel('Inflation');
247+
set(gca,'YTick', ytics);
248+
%set(gca,'YTickLabel', fliplr(infs));
249+
set(gca,'YTickLabel', fliplr(rejs));
250+
drawnow;
251+
252+
figure(f4);
253+
subplot( numel(infs), numel(ensNs), runn);
254+
plot(spinup+1:1:times, rmstempval);
255+
xlim([spinup+1 times]); ylim([0 1]);
256+
set(gca, 'XTick', [spinup+1 times])
257+
set(gca, 'XTickLabel', [spinup+1 times])
258+
han=axes(f4,'visible','off');
259+
han.Title.Visible='on';
260+
han.XLabel.Visible='on';
261+
han.YLabel.Visible='on';
262+
ylabel(han,'Value');
263+
xlabel(han,'Time Step');
264+
title(han,'RMSE');
265+
drawnow;
266+
199267
end
268+
% step = 0;
269+
% Rank histogram
200270
% figure;
201-
% bar(enkf.RankValue(1,1:end-1));
202-
return;
271+
% for i = 1: length(enkf.RankValue(:,1))
272+
% subplot(1,3,i);
273+
% bar(enkf.RankValue(i,1:end-1));
274+
% end
203275

276+
% Kldiv + poly approx (for the variable being observed)
277+
% figure;
278+
% for i = 1:length(histvar)
279+
%
280+
% end
281+
282+
283+
return;

0 commit comments

Comments
 (0)