Skip to content

Commit cf128c2

Browse files
committed
Add GDALDataset::[Get|Compute]InterBandCovarianceMatrix() and corresponding C and Python API
This generates and reads STATISTICS_COVARIANCES band metadata items as done by ArcGIS: https://pro.arcgis.com/en/pro-app/3.4/tool-reference/spatial-analyst/how-band-collection-statistics-works.htm which is consistant with the results of numpy.cov()
1 parent dff7013 commit cf128c2

File tree

13 files changed

+1521
-4
lines changed

13 files changed

+1521
-4
lines changed

autotest/gcore/gdal_stats.py

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import shutil
3232
import struct
3333
import sys
34+
from itertools import chain
3435

3536
import gdaltest
3637
import pytest
@@ -1230,3 +1231,271 @@ def test_stats_minmax_all_invalid_mask(datatype):
12301231
ds.CreateMaskBand(gdal.GMF_PER_DATASET)
12311232
with pytest.raises(Exception, match="Failed to compute min/max"):
12321233
ds.GetRasterBand(1).ComputeRasterMinMax()
1234+
1235+
1236+
###############################################################################
1237+
1238+
1239+
def test_stats_GetInterBandCovarianceMatrix(tmp_vsimem):
1240+
1241+
gdal.FileFromMemBuffer(
1242+
tmp_vsimem / "test.tif", open("data/rgbsmall.tif", "rb").read()
1243+
)
1244+
1245+
with gdal.Open(tmp_vsimem / "test.tif") as ds:
1246+
assert ds.GetInterBandCovarianceMatrix() is None
1247+
assert gdal.VSIStatL(tmp_vsimem / "test.tif.aux.xml") is None
1248+
1249+
expected_cov_matrix = [
1250+
[2241.7045363745387, 2898.8196128051163, 1009.979953581434],
1251+
[2898.8196128051163, 3900.269159023618, 1248.65396718687],
1252+
[1009.979953581434, 1248.65396718687, 602.4703641456648],
1253+
]
1254+
1255+
with gdal.Open(tmp_vsimem / "test.tif") as ds:
1256+
assert list(
1257+
chain.from_iterable(ds.GetInterBandCovarianceMatrix(force=True))
1258+
) == pytest.approx(list(chain.from_iterable(expected_cov_matrix)))
1259+
assert gdal.VSIStatL(tmp_vsimem / "test.tif.aux.xml") is not None
1260+
1261+
with gdal.Open(tmp_vsimem / "test.tif") as ds:
1262+
assert list(
1263+
chain.from_iterable(ds.GetInterBandCovarianceMatrix())
1264+
) == pytest.approx(list(chain.from_iterable(expected_cov_matrix)))
1265+
1266+
gdal.Unlink(tmp_vsimem / "test.tif.aux.xml")
1267+
1268+
tab_pct = [0]
1269+
1270+
def my_progress(pct, msg, user_data):
1271+
assert pct >= tab_pct[0]
1272+
tab_pct[0] = pct
1273+
return True
1274+
1275+
with gdal.Open(tmp_vsimem / "test.tif") as ds:
1276+
assert list(
1277+
chain.from_iterable(
1278+
ds.GetInterBandCovarianceMatrix(
1279+
force=True, write_into_metadata=False, callback=my_progress
1280+
)
1281+
)
1282+
) == pytest.approx(list(chain.from_iterable(expected_cov_matrix)))
1283+
1284+
assert tab_pct[0] == 1.0
1285+
1286+
with gdal.Open(tmp_vsimem / "test.tif") as ds:
1287+
assert ds.GetInterBandCovarianceMatrix() is None
1288+
1289+
1290+
###############################################################################
1291+
1292+
1293+
def test_stats_GetInterBandCovarianceMatrix_edge_cases():
1294+
1295+
ds = gdal.GetDriverByName("MEM").Create("", 0, 0, 0)
1296+
with pytest.raises(Exception, match="Zero band dataset"):
1297+
ds.GetInterBandCovarianceMatrix(force=True)
1298+
1299+
1300+
###############################################################################
1301+
1302+
1303+
def test_stats_ComputeInterBandCovarianceMatrix(tmp_vsimem):
1304+
1305+
gdal.FileFromMemBuffer(
1306+
tmp_vsimem / "test.tif", open("data/rgbsmall.tif", "rb").read()
1307+
)
1308+
1309+
expected_cov_matrix = [
1310+
[2241.7045363745387, 2898.8196128051163, 1009.979953581434],
1311+
[2898.8196128051163, 3900.269159023618, 1248.65396718687],
1312+
[1009.979953581434, 1248.65396718687, 602.4703641456648],
1313+
]
1314+
1315+
with gdal.Open(tmp_vsimem / "test.tif") as ds:
1316+
assert list(
1317+
chain.from_iterable(ds.ComputeInterBandCovarianceMatrix())
1318+
) == pytest.approx(list(chain.from_iterable(expected_cov_matrix)))
1319+
1320+
try:
1321+
import numpy
1322+
1323+
has_numpy = True
1324+
except ImportError:
1325+
has_numpy = False
1326+
if has_numpy:
1327+
numpy_cov = numpy.cov(
1328+
[
1329+
ds.GetRasterBand(n + 1).ReadAsArray().ravel()
1330+
for n in range(ds.RasterCount)
1331+
]
1332+
)
1333+
assert list(chain.from_iterable(numpy_cov)) == pytest.approx(
1334+
list(chain.from_iterable(expected_cov_matrix))
1335+
)
1336+
1337+
assert gdal.VSIStatL(tmp_vsimem / "test.tif.aux.xml") is not None
1338+
1339+
with gdal.Open(tmp_vsimem / "test.tif") as ds:
1340+
assert list(
1341+
chain.from_iterable(ds.GetInterBandCovarianceMatrix())
1342+
) == pytest.approx(list(chain.from_iterable(expected_cov_matrix)))
1343+
1344+
tab_pct = [0]
1345+
1346+
def my_progress(pct, msg, user_data):
1347+
assert pct >= tab_pct[0]
1348+
tab_pct[0] = pct
1349+
return True
1350+
1351+
gdal.Unlink(tmp_vsimem / "test.tif.aux.xml")
1352+
1353+
with gdal.Open(tmp_vsimem / "test.tif") as ds:
1354+
assert list(
1355+
chain.from_iterable(
1356+
ds.ComputeInterBandCovarianceMatrix(
1357+
write_into_metadata=False, callback=my_progress
1358+
)
1359+
)
1360+
) == pytest.approx(list(chain.from_iterable(expected_cov_matrix)))
1361+
1362+
assert tab_pct[0] == 1.0
1363+
1364+
with gdal.Open(tmp_vsimem / "test.tif") as ds:
1365+
assert ds.GetInterBandCovarianceMatrix() is None
1366+
1367+
1368+
###############################################################################
1369+
1370+
1371+
def test_stats_ComputeInterBandCovarianceMatrix_overviews(tmp_vsimem):
1372+
1373+
ds = gdal.Translate(tmp_vsimem / "test.tif", "data/rgbsmall.tif", width=1024)
1374+
ds.BuildOverviews("NEAR", [2])
1375+
ds.Close()
1376+
1377+
expected_cov_matrix = [
1378+
[2241.7474621754936, 2898.723016592399, 1010.1568325910154],
1379+
[2898.723016592399, 3900.259907723227, 1248.630954306772],
1380+
[1010.1568325910154, 1248.630954306772, 601.9344108086958],
1381+
]
1382+
1383+
with gdal.Open(tmp_vsimem / "test.tif") as ds:
1384+
cov_matrix = ds.ComputeInterBandCovarianceMatrix(approx_ok=True)
1385+
assert list(chain.from_iterable(cov_matrix)) == pytest.approx(
1386+
list(chain.from_iterable(expected_cov_matrix))
1387+
)
1388+
1389+
1390+
###############################################################################
1391+
1392+
1393+
def test_stats_ComputeInterBandCovarianceMatrix_edge_cases():
1394+
1395+
ds = gdal.GetDriverByName("MEM").Create("", 0, 0, 0)
1396+
with pytest.raises(Exception, match="Zero band dataset"):
1397+
ds.ComputeInterBandCovarianceMatrix()
1398+
1399+
1400+
###############################################################################
1401+
1402+
1403+
def test_stats_ComputeInterBandCovarianceMatrix_nodata():
1404+
1405+
ds = gdal.GetDriverByName("MEM").Create("", 4, 1, 2)
1406+
ds.GetRasterBand(1).WriteRaster(0, 0, 4, 1, b"\x01\x02\03\xFF")
1407+
ds.GetRasterBand(1).SetNoDataValue(255)
1408+
ds.GetRasterBand(2).WriteRaster(0, 0, 4, 1, b"\x02\x01\xFE\03")
1409+
ds.GetRasterBand(2).SetNoDataValue(254)
1410+
1411+
expected_cov_matrix = [[1, 0], [0, 1]]
1412+
1413+
cov_matrix = ds.ComputeInterBandCovarianceMatrix()
1414+
assert list(chain.from_iterable(cov_matrix)) == pytest.approx(
1415+
list(chain.from_iterable(expected_cov_matrix))
1416+
)
1417+
1418+
1419+
###############################################################################
1420+
1421+
1422+
def test_stats_ComputeInterBandCovarianceMatrix_nan_value():
1423+
1424+
ds = gdal.GetDriverByName("MEM").Create("", 4, 1, 2, gdal.GDT_Float32)
1425+
ds.GetRasterBand(1).WriteRaster(0, 0, 4, 1, struct.pack("f" * 4, 1, 2, 3, math.nan))
1426+
ds.GetRasterBand(2).WriteRaster(0, 0, 4, 1, struct.pack("f" * 4, 2, 1, math.nan, 3))
1427+
1428+
expected_cov_matrix = [[1, 0], [0, 1]]
1429+
1430+
cov_matrix = ds.ComputeInterBandCovarianceMatrix()
1431+
assert list(chain.from_iterable(cov_matrix)) == pytest.approx(
1432+
list(chain.from_iterable(expected_cov_matrix))
1433+
)
1434+
1435+
1436+
###############################################################################
1437+
1438+
1439+
def test_stats_ComputeInterBandCovarianceMatrix_nan_result():
1440+
1441+
ds = gdal.GetDriverByName("MEM").Create("", 2, 1, 2, gdal.GDT_Float32)
1442+
ds.GetRasterBand(1).WriteRaster(0, 0, 2, 1, struct.pack("f" * 2, 1, math.nan))
1443+
ds.GetRasterBand(2).WriteRaster(0, 0, 2, 1, struct.pack("f" * 2, math.nan, 1))
1444+
1445+
cov_matrix = ds.ComputeInterBandCovarianceMatrix()
1446+
assert math.isnan(cov_matrix[0][0])
1447+
assert math.isnan(cov_matrix[0][1])
1448+
assert math.isnan(cov_matrix[1][0])
1449+
assert math.isnan(cov_matrix[1][1])
1450+
1451+
1452+
###############################################################################
1453+
1454+
1455+
def test_stats_ComputeInterBandCovarianceMatrix_failed_to_compute_stats():
1456+
1457+
ds = gdal.GetDriverByName("MEM").Create("", 2, 1, 1, gdal.GDT_Float32)
1458+
ds.GetRasterBand(1).WriteRaster(
1459+
0, 0, 2, 1, struct.pack("f" * 2, math.nan, math.nan)
1460+
)
1461+
1462+
with pytest.raises(Exception, match="Failed to compute statistics"):
1463+
ds.ComputeInterBandCovarianceMatrix()
1464+
1465+
1466+
###############################################################################
1467+
1468+
1469+
def test_stats_ComputeInterBandCovarianceMatrix_mask_band():
1470+
1471+
ds = gdal.GetDriverByName("MEM").Create("", 4, 1, 2)
1472+
ds.GetRasterBand(1).WriteRaster(0, 0, 4, 1, b"\x01\x02\x03\xFF")
1473+
ds.GetRasterBand(1).CreateMaskBand(0)
1474+
ds.GetRasterBand(1).GetMaskBand().WriteRaster(0, 0, 4, 1, b"\xFF\xFF\xFF\x00")
1475+
ds.GetRasterBand(2).WriteRaster(0, 0, 4, 1, b"\x02\x01\xFE\x03")
1476+
ds.GetRasterBand(2).CreateMaskBand(0)
1477+
ds.GetRasterBand(2).GetMaskBand().WriteRaster(0, 0, 4, 1, b"\xFF\xFF\x00\xFF")
1478+
1479+
expected_cov_matrix = [[1, 0], [0, 1]]
1480+
1481+
cov_matrix = ds.ComputeInterBandCovarianceMatrix()
1482+
assert list(chain.from_iterable(cov_matrix)) == pytest.approx(
1483+
list(chain.from_iterable(expected_cov_matrix))
1484+
)
1485+
1486+
1487+
###############################################################################
1488+
1489+
1490+
@pytest.mark.slow()
1491+
def test_stats_ComputeInterBandCovarianceMatrix_huge_mem_alloc():
1492+
1493+
ds = gdal.GetDriverByName("MEM").Create("", 1, 1, 50000)
1494+
1495+
def my_progress(pct, msg, user_data):
1496+
return pct < 1e-2
1497+
1498+
with pytest.raises(Exception):
1499+
ds.ComputeInterBandCovarianceMatrix(
1500+
write_into_metadata=False, callback=my_progress
1501+
)

doc/source/api/gdaldataset_cpp.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,6 @@ GDALRelationship class
5252
.. spelling:word-list::
5353
GetMetadataItem
5454
CREATIONOPTIONLIST
55+
nDeltaDegreeOfFreedom
56+
padfCovMatrix
57+

doc/source/api/python/raster_api.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ Band Algebra
7272
.. autofunction:: osgeo.gdal.abs
7373

7474
.. autofunction:: osgeo.gdal.log
75-
75+
7676
.. autofunction:: osgeo.gdal.log10
7777

7878
.. autofunction:: osgeo.gdal.logical_and
@@ -119,3 +119,9 @@ Other
119119
:members:
120120
:undoc-members:
121121
:exclude-members: thisown
122+
123+
124+
.. below is an allow-list for spelling checker.
125+
126+
.. spelling:word-list::
127+
RasterCount

doc/source/api/raster_c_api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,6 @@ gdal.h: Raster C API
2222
pDst
2323
SetMetadataItem
2424
nDatasetTypeFlag
25+
nDeltaDegreeOfFreedom
26+
padfCovMatrix
27+

gcore/gdal.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,6 +1359,16 @@ CPLErr CPL_DLL GDALDatasetGeolocationToPixelLine(
13591359
GDALDatasetH, double dfGeolocX, double dfGeolocY, OGRSpatialReferenceH hSRS,
13601360
double *pdfPixel, double *pdfLine, CSLConstList papszTransformerOptions);
13611361

1362+
CPLErr CPL_DLL GDALDatasetGetInterBandCovarianceMatrix(
1363+
GDALDatasetH hDS, double *padfCovMatrix, size_t nSize, bool bApproxOK,
1364+
bool bForce, bool bWriteIntoMetadata, int nDeltaDegreeOfFreedom,
1365+
GDALProgressFunc pfnProgress, void *pProgressData);
1366+
1367+
CPLErr CPL_DLL GDALDatasetComputeInterBandCovarianceMatrix(
1368+
GDALDatasetH hDS, double *padfCovMatrix, size_t nSize, bool bApproxOK,
1369+
bool bWriteIntoMetadata, int nDeltaDegreeOfFreedom,
1370+
GDALProgressFunc pfnProgress, void *pProgressData);
1371+
13621372
int CPL_DLL CPL_STDCALL GDALGetGCPCount(GDALDatasetH);
13631373
const char CPL_DLL *CPL_STDCALL GDALGetGCPProjection(GDALDatasetH);
13641374
OGRSpatialReferenceH CPL_DLL GDALGetGCPSpatialRef(GDALDatasetH);

gcore/gdal_dataset.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,29 @@ class CPL_DLL GDALDataset : public GDALMajorObject
511511
GDALProgressFunc pfnProgress,
512512
void *pProgressData, CSLConstList papszOptions);
513513

514+
CPLErr GetInterBandCovarianceMatrix(double *padfCovMatrix, size_t nSize,
515+
bool bApproxOK = false,
516+
bool bForce = false,
517+
bool bWriteIntoMetadata = true,
518+
int nDeltaDegreeOfFreedom = 1,
519+
GDALProgressFunc pfnProgress = nullptr,
520+
void *pProgressData = nullptr);
521+
522+
std::vector<double> GetInterBandCovarianceMatrix(
523+
bool bApproxOK = false, bool bForce = false,
524+
bool bWriteIntoMetadata = true, int nDeltaDegreeOfFreedom = 1,
525+
GDALProgressFunc pfnProgress = nullptr, void *pProgressData = nullptr);
526+
527+
CPLErr ComputeInterBandCovarianceMatrix(
528+
double *padfCovMatrix, size_t nSize, bool bApproxOK = false,
529+
bool bWriteIntoMetadata = true, int nDeltaDegreeOfFreedom = 1,
530+
GDALProgressFunc pfnProgress = nullptr, void *pProgressData = nullptr);
531+
532+
std::vector<double> ComputeInterBandCovarianceMatrix(
533+
bool bApproxOK = false, bool bWriteIntoMetadata = true,
534+
int nDeltaDegreeOfFreedom = 1, GDALProgressFunc pfnProgress = nullptr,
535+
void *pProgressData = nullptr);
536+
514537
#ifndef DOXYGEN_XML
515538
void ReportError(CPLErr eErrClass, CPLErrorNum err_no, const char *fmt,
516539
...) const CPL_PRINT_FUNC_FORMAT(4, 5);

gcore/gdal_rasterband.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,14 @@ class CPL_DLL GDALRasterBand : public GDALMajorObject
502502
double *pdfMax, double *pdfMean,
503503
double *pdfStdDev, GDALProgressFunc,
504504
void *pProgressData);
505+
506+
double ComputeInterBandCovariance(
507+
GDALRasterBand *poOtherBand, bool bApproxOK = false,
508+
int nDeltaDegreeOfFreedom = 1, const double *pdfThisBandMean = nullptr,
509+
const double *pdfOtherBandMean = nullptr,
510+
std::vector<double> *padfTempVector = nullptr,
511+
GDALProgressFunc = nullptr, void *pProgressData = nullptr);
512+
505513
virtual CPLErr SetStatistics(double dfMin, double dfMax, double dfMean,
506514
double dfStdDev);
507515
virtual CPLErr ComputeRasterMinMax(int bApproxOK, double *adfMinMax);

0 commit comments

Comments
 (0)