|
31 | 31 | import shutil |
32 | 32 | import struct |
33 | 33 | import sys |
| 34 | +from itertools import chain |
34 | 35 |
|
35 | 36 | import gdaltest |
36 | 37 | import pytest |
@@ -1230,3 +1231,271 @@ def test_stats_minmax_all_invalid_mask(datatype): |
1230 | 1231 | ds.CreateMaskBand(gdal.GMF_PER_DATASET) |
1231 | 1232 | with pytest.raises(Exception, match="Failed to compute min/max"): |
1232 | 1233 | ds.GetRasterBand(1).ComputeRasterMinMax() |
| 1234 | + |
| 1235 | + |
| 1236 | +############################################################################### |
| 1237 | + |
| 1238 | + |
| 1239 | +def test_stats_GetInterBandCovarianceMatrix(tmp_vsimem): |
| 1240 | + |
| 1241 | + gdal.FileFromMemBuffer( |
| 1242 | + tmp_vsimem / "test.tif", open("data/rgbsmall.tif", "rb").read() |
| 1243 | + ) |
| 1244 | + |
| 1245 | + with gdal.Open(tmp_vsimem / "test.tif") as ds: |
| 1246 | + assert ds.GetInterBandCovarianceMatrix() is None |
| 1247 | + assert gdal.VSIStatL(tmp_vsimem / "test.tif.aux.xml") is None |
| 1248 | + |
| 1249 | + expected_cov_matrix = [ |
| 1250 | + [2241.7045363745387, 2898.8196128051163, 1009.979953581434], |
| 1251 | + [2898.8196128051163, 3900.269159023618, 1248.65396718687], |
| 1252 | + [1009.979953581434, 1248.65396718687, 602.4703641456648], |
| 1253 | + ] |
| 1254 | + |
| 1255 | + with gdal.Open(tmp_vsimem / "test.tif") as ds: |
| 1256 | + assert list( |
| 1257 | + chain.from_iterable(ds.GetInterBandCovarianceMatrix(force=True)) |
| 1258 | + ) == pytest.approx(list(chain.from_iterable(expected_cov_matrix))) |
| 1259 | + assert gdal.VSIStatL(tmp_vsimem / "test.tif.aux.xml") is not None |
| 1260 | + |
| 1261 | + with gdal.Open(tmp_vsimem / "test.tif") as ds: |
| 1262 | + assert list( |
| 1263 | + chain.from_iterable(ds.GetInterBandCovarianceMatrix()) |
| 1264 | + ) == pytest.approx(list(chain.from_iterable(expected_cov_matrix))) |
| 1265 | + |
| 1266 | + gdal.Unlink(tmp_vsimem / "test.tif.aux.xml") |
| 1267 | + |
| 1268 | + tab_pct = [0] |
| 1269 | + |
| 1270 | + def my_progress(pct, msg, user_data): |
| 1271 | + assert pct >= tab_pct[0] |
| 1272 | + tab_pct[0] = pct |
| 1273 | + return True |
| 1274 | + |
| 1275 | + with gdal.Open(tmp_vsimem / "test.tif") as ds: |
| 1276 | + assert list( |
| 1277 | + chain.from_iterable( |
| 1278 | + ds.GetInterBandCovarianceMatrix( |
| 1279 | + force=True, write_into_metadata=False, callback=my_progress |
| 1280 | + ) |
| 1281 | + ) |
| 1282 | + ) == pytest.approx(list(chain.from_iterable(expected_cov_matrix))) |
| 1283 | + |
| 1284 | + assert tab_pct[0] == 1.0 |
| 1285 | + |
| 1286 | + with gdal.Open(tmp_vsimem / "test.tif") as ds: |
| 1287 | + assert ds.GetInterBandCovarianceMatrix() is None |
| 1288 | + |
| 1289 | + |
| 1290 | +############################################################################### |
| 1291 | + |
| 1292 | + |
| 1293 | +def test_stats_GetInterBandCovarianceMatrix_edge_cases(): |
| 1294 | + |
| 1295 | + ds = gdal.GetDriverByName("MEM").Create("", 0, 0, 0) |
| 1296 | + with pytest.raises(Exception, match="Zero band dataset"): |
| 1297 | + ds.GetInterBandCovarianceMatrix(force=True) |
| 1298 | + |
| 1299 | + |
| 1300 | +############################################################################### |
| 1301 | + |
| 1302 | + |
| 1303 | +def test_stats_ComputeInterBandCovarianceMatrix(tmp_vsimem): |
| 1304 | + |
| 1305 | + gdal.FileFromMemBuffer( |
| 1306 | + tmp_vsimem / "test.tif", open("data/rgbsmall.tif", "rb").read() |
| 1307 | + ) |
| 1308 | + |
| 1309 | + expected_cov_matrix = [ |
| 1310 | + [2241.7045363745387, 2898.8196128051163, 1009.979953581434], |
| 1311 | + [2898.8196128051163, 3900.269159023618, 1248.65396718687], |
| 1312 | + [1009.979953581434, 1248.65396718687, 602.4703641456648], |
| 1313 | + ] |
| 1314 | + |
| 1315 | + with gdal.Open(tmp_vsimem / "test.tif") as ds: |
| 1316 | + assert list( |
| 1317 | + chain.from_iterable(ds.ComputeInterBandCovarianceMatrix()) |
| 1318 | + ) == pytest.approx(list(chain.from_iterable(expected_cov_matrix))) |
| 1319 | + |
| 1320 | + try: |
| 1321 | + import numpy |
| 1322 | + |
| 1323 | + has_numpy = True |
| 1324 | + except ImportError: |
| 1325 | + has_numpy = False |
| 1326 | + if has_numpy: |
| 1327 | + numpy_cov = numpy.cov( |
| 1328 | + [ |
| 1329 | + ds.GetRasterBand(n + 1).ReadAsArray().ravel() |
| 1330 | + for n in range(ds.RasterCount) |
| 1331 | + ] |
| 1332 | + ) |
| 1333 | + assert list(chain.from_iterable(numpy_cov)) == pytest.approx( |
| 1334 | + list(chain.from_iterable(expected_cov_matrix)) |
| 1335 | + ) |
| 1336 | + |
| 1337 | + assert gdal.VSIStatL(tmp_vsimem / "test.tif.aux.xml") is not None |
| 1338 | + |
| 1339 | + with gdal.Open(tmp_vsimem / "test.tif") as ds: |
| 1340 | + assert list( |
| 1341 | + chain.from_iterable(ds.GetInterBandCovarianceMatrix()) |
| 1342 | + ) == pytest.approx(list(chain.from_iterable(expected_cov_matrix))) |
| 1343 | + |
| 1344 | + tab_pct = [0] |
| 1345 | + |
| 1346 | + def my_progress(pct, msg, user_data): |
| 1347 | + assert pct >= tab_pct[0] |
| 1348 | + tab_pct[0] = pct |
| 1349 | + return True |
| 1350 | + |
| 1351 | + gdal.Unlink(tmp_vsimem / "test.tif.aux.xml") |
| 1352 | + |
| 1353 | + with gdal.Open(tmp_vsimem / "test.tif") as ds: |
| 1354 | + assert list( |
| 1355 | + chain.from_iterable( |
| 1356 | + ds.ComputeInterBandCovarianceMatrix( |
| 1357 | + write_into_metadata=False, callback=my_progress |
| 1358 | + ) |
| 1359 | + ) |
| 1360 | + ) == pytest.approx(list(chain.from_iterable(expected_cov_matrix))) |
| 1361 | + |
| 1362 | + assert tab_pct[0] == 1.0 |
| 1363 | + |
| 1364 | + with gdal.Open(tmp_vsimem / "test.tif") as ds: |
| 1365 | + assert ds.GetInterBandCovarianceMatrix() is None |
| 1366 | + |
| 1367 | + |
| 1368 | +############################################################################### |
| 1369 | + |
| 1370 | + |
| 1371 | +def test_stats_ComputeInterBandCovarianceMatrix_overviews(tmp_vsimem): |
| 1372 | + |
| 1373 | + ds = gdal.Translate(tmp_vsimem / "test.tif", "data/rgbsmall.tif", width=1024) |
| 1374 | + ds.BuildOverviews("NEAR", [2]) |
| 1375 | + ds.Close() |
| 1376 | + |
| 1377 | + expected_cov_matrix = [ |
| 1378 | + [2241.7474621754936, 2898.723016592399, 1010.1568325910154], |
| 1379 | + [2898.723016592399, 3900.259907723227, 1248.630954306772], |
| 1380 | + [1010.1568325910154, 1248.630954306772, 601.9344108086958], |
| 1381 | + ] |
| 1382 | + |
| 1383 | + with gdal.Open(tmp_vsimem / "test.tif") as ds: |
| 1384 | + cov_matrix = ds.ComputeInterBandCovarianceMatrix(approx_ok=True) |
| 1385 | + assert list(chain.from_iterable(cov_matrix)) == pytest.approx( |
| 1386 | + list(chain.from_iterable(expected_cov_matrix)) |
| 1387 | + ) |
| 1388 | + |
| 1389 | + |
| 1390 | +############################################################################### |
| 1391 | + |
| 1392 | + |
| 1393 | +def test_stats_ComputeInterBandCovarianceMatrix_edge_cases(): |
| 1394 | + |
| 1395 | + ds = gdal.GetDriverByName("MEM").Create("", 0, 0, 0) |
| 1396 | + with pytest.raises(Exception, match="Zero band dataset"): |
| 1397 | + ds.ComputeInterBandCovarianceMatrix() |
| 1398 | + |
| 1399 | + |
| 1400 | +############################################################################### |
| 1401 | + |
| 1402 | + |
| 1403 | +def test_stats_ComputeInterBandCovarianceMatrix_nodata(): |
| 1404 | + |
| 1405 | + ds = gdal.GetDriverByName("MEM").Create("", 4, 1, 2) |
| 1406 | + ds.GetRasterBand(1).WriteRaster(0, 0, 4, 1, b"\x01\x02\03\xFF") |
| 1407 | + ds.GetRasterBand(1).SetNoDataValue(255) |
| 1408 | + ds.GetRasterBand(2).WriteRaster(0, 0, 4, 1, b"\x02\x01\xFE\03") |
| 1409 | + ds.GetRasterBand(2).SetNoDataValue(254) |
| 1410 | + |
| 1411 | + expected_cov_matrix = [[1, 0], [0, 1]] |
| 1412 | + |
| 1413 | + cov_matrix = ds.ComputeInterBandCovarianceMatrix() |
| 1414 | + assert list(chain.from_iterable(cov_matrix)) == pytest.approx( |
| 1415 | + list(chain.from_iterable(expected_cov_matrix)) |
| 1416 | + ) |
| 1417 | + |
| 1418 | + |
| 1419 | +############################################################################### |
| 1420 | + |
| 1421 | + |
| 1422 | +def test_stats_ComputeInterBandCovarianceMatrix_nan_value(): |
| 1423 | + |
| 1424 | + ds = gdal.GetDriverByName("MEM").Create("", 4, 1, 2, gdal.GDT_Float32) |
| 1425 | + ds.GetRasterBand(1).WriteRaster(0, 0, 4, 1, struct.pack("f" * 4, 1, 2, 3, math.nan)) |
| 1426 | + ds.GetRasterBand(2).WriteRaster(0, 0, 4, 1, struct.pack("f" * 4, 2, 1, math.nan, 3)) |
| 1427 | + |
| 1428 | + expected_cov_matrix = [[1, 0], [0, 1]] |
| 1429 | + |
| 1430 | + cov_matrix = ds.ComputeInterBandCovarianceMatrix() |
| 1431 | + assert list(chain.from_iterable(cov_matrix)) == pytest.approx( |
| 1432 | + list(chain.from_iterable(expected_cov_matrix)) |
| 1433 | + ) |
| 1434 | + |
| 1435 | + |
| 1436 | +############################################################################### |
| 1437 | + |
| 1438 | + |
| 1439 | +def test_stats_ComputeInterBandCovarianceMatrix_nan_result(): |
| 1440 | + |
| 1441 | + ds = gdal.GetDriverByName("MEM").Create("", 2, 1, 2, gdal.GDT_Float32) |
| 1442 | + ds.GetRasterBand(1).WriteRaster(0, 0, 2, 1, struct.pack("f" * 2, 1, math.nan)) |
| 1443 | + ds.GetRasterBand(2).WriteRaster(0, 0, 2, 1, struct.pack("f" * 2, math.nan, 1)) |
| 1444 | + |
| 1445 | + cov_matrix = ds.ComputeInterBandCovarianceMatrix() |
| 1446 | + assert math.isnan(cov_matrix[0][0]) |
| 1447 | + assert math.isnan(cov_matrix[0][1]) |
| 1448 | + assert math.isnan(cov_matrix[1][0]) |
| 1449 | + assert math.isnan(cov_matrix[1][1]) |
| 1450 | + |
| 1451 | + |
| 1452 | +############################################################################### |
| 1453 | + |
| 1454 | + |
| 1455 | +def test_stats_ComputeInterBandCovarianceMatrix_failed_to_compute_stats(): |
| 1456 | + |
| 1457 | + ds = gdal.GetDriverByName("MEM").Create("", 2, 1, 1, gdal.GDT_Float32) |
| 1458 | + ds.GetRasterBand(1).WriteRaster( |
| 1459 | + 0, 0, 2, 1, struct.pack("f" * 2, math.nan, math.nan) |
| 1460 | + ) |
| 1461 | + |
| 1462 | + with pytest.raises(Exception, match="Failed to compute statistics"): |
| 1463 | + ds.ComputeInterBandCovarianceMatrix() |
| 1464 | + |
| 1465 | + |
| 1466 | +############################################################################### |
| 1467 | + |
| 1468 | + |
| 1469 | +def test_stats_ComputeInterBandCovarianceMatrix_mask_band(): |
| 1470 | + |
| 1471 | + ds = gdal.GetDriverByName("MEM").Create("", 4, 1, 2) |
| 1472 | + ds.GetRasterBand(1).WriteRaster(0, 0, 4, 1, b"\x01\x02\x03\xFF") |
| 1473 | + ds.GetRasterBand(1).CreateMaskBand(0) |
| 1474 | + ds.GetRasterBand(1).GetMaskBand().WriteRaster(0, 0, 4, 1, b"\xFF\xFF\xFF\x00") |
| 1475 | + ds.GetRasterBand(2).WriteRaster(0, 0, 4, 1, b"\x02\x01\xFE\x03") |
| 1476 | + ds.GetRasterBand(2).CreateMaskBand(0) |
| 1477 | + ds.GetRasterBand(2).GetMaskBand().WriteRaster(0, 0, 4, 1, b"\xFF\xFF\x00\xFF") |
| 1478 | + |
| 1479 | + expected_cov_matrix = [[1, 0], [0, 1]] |
| 1480 | + |
| 1481 | + cov_matrix = ds.ComputeInterBandCovarianceMatrix() |
| 1482 | + assert list(chain.from_iterable(cov_matrix)) == pytest.approx( |
| 1483 | + list(chain.from_iterable(expected_cov_matrix)) |
| 1484 | + ) |
| 1485 | + |
| 1486 | + |
| 1487 | +############################################################################### |
| 1488 | + |
| 1489 | + |
| 1490 | +@pytest.mark.slow() |
| 1491 | +def test_stats_ComputeInterBandCovarianceMatrix_huge_mem_alloc(): |
| 1492 | + |
| 1493 | + ds = gdal.GetDriverByName("MEM").Create("", 1, 1, 50000) |
| 1494 | + |
| 1495 | + def my_progress(pct, msg, user_data): |
| 1496 | + return pct < 1e-2 |
| 1497 | + |
| 1498 | + with pytest.raises(Exception): |
| 1499 | + ds.ComputeInterBandCovarianceMatrix( |
| 1500 | + write_into_metadata=False, callback=my_progress |
| 1501 | + ) |
0 commit comments