|
7 | 7 | import pandas as pd
|
8 | 8 |
|
9 | 9 | import xarray as xr
|
| 10 | +from xarray.backends.api import open_datatree |
| 11 | +from xarray.core.datatree import DataTree |
10 | 12 |
|
11 | 13 | from . import _skip_slow, parameterized, randint, randn, requires_dask
|
12 | 14 |
|
|
16 | 18 | except ImportError:
|
17 | 19 | pass
|
18 | 20 |
|
19 |
| - |
20 | 21 | os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
21 | 22 |
|
22 | 23 | _ENGINES = tuple(xr.backends.list_engines().keys() - {"store"})
|
@@ -469,6 +470,116 @@ def create_delayed_write():
|
469 | 470 | return ds.to_netcdf("file.nc", engine="netcdf4", compute=False)
|
470 | 471 |
|
471 | 472 |
|
class IONestedDataTree:
    """
    Base class for benchmarks that read/write a heavily nested netCDF datatree
    with xarray.

    Subclasses are expected to call :meth:`make_datatree` from their ``setup``
    and then use ``self.dtree`` (and the other attributes it sets).
    """

    # Benchmark-runner settings (presumably consumed by asv -- confirm).
    timeout = 300.0
    repeat = 1
    number = 5

    def make_datatree(self, nchildren=10):
        """Build the benchmark dataset and a nested ``DataTree`` around it.

        Parameters
        ----------
        nchildren : int, default 10
            Number of top-level ``group_<i>`` nodes. Each of them gets the
            same set of nested sub-groups.

        Sets ``self.ds``, ``self.nt``/``self.nx``/``self.ny``,
        ``self.nchildren``, ``self.block_chunks``, ``self.time_chunks``,
        ``self.oinds``, ``self.vinds`` and ``self.dtree``.
        """
        # Single Dataset that is reused as the payload of several tree nodes.
        self.ds = xr.Dataset()
        self.nt = 1000
        self.nx = 90
        self.ny = 45
        self.nchildren = nchildren

        # Chunk sizes are element counts, so use floor division to keep them
        # integers (true division would yield floats such as 250.0).
        self.block_chunks = {
            "time": self.nt // 4,
            "lon": self.nx // 3,
            "lat": self.ny // 3,
        }

        self.time_chunks = {"time": self.nt // 36}

        times = pd.date_range("1970-01-01", periods=self.nt, freq="D")
        lons = xr.DataArray(
            np.linspace(0, 360, self.nx),
            dims=("lon",),
            attrs={"units": "degrees east", "long_name": "longitude"},
        )
        lats = xr.DataArray(
            np.linspace(-90, 90, self.ny),
            dims=("lat",),
            attrs={"units": "degrees north", "long_name": "latitude"},
        )

        # "foo" and "bar" are 3-D (time, lon, lat) variables; "baz" is a 2-D
        # (lon, lat) float32 variable. All values come from the `randn`
        # helper called with frac_nan=0.2.
        for name in ("foo", "bar"):
            self.ds[name] = xr.DataArray(
                randn((self.nt, self.nx, self.ny), frac_nan=0.2),
                coords={"lon": lons, "lat": lats, "time": times},
                dims=("time", "lon", "lat"),
                name=name,
                attrs={"units": f"{name} units", "description": "a description"},
            )
        self.ds["baz"] = xr.DataArray(
            randn((self.nx, self.ny), frac_nan=0.2).astype(np.float32),
            coords={"lon": lons, "lat": lats},
            dims=("lon", "lat"),
            name="baz",
            attrs={"units": "baz units", "description": "a description"},
        )

        self.ds.attrs = {"history": "created for xarray benchmarking"}

        # Integer-array indexers, one independent array per dimension.
        self.oinds = {
            "time": randint(0, self.nt, 120),
            "lon": randint(0, self.nx, 20),
            "lat": randint(0, self.ny, 10),
        }
        # Indexers sharing a new "x" dimension for time/lon, plus a slice
        # along lat.
        self.vinds = {
            "time": xr.DataArray(randint(0, self.nt, 120), dims="x"),
            "lon": xr.DataArray(randint(0, self.nx, 120), dims="x"),
            "lat": slice(3, 20),
        }

        # Tree layout, repeated for every i in range(nchildren):
        #   group_<i>                              -> self.ds
        #   group_<i>/subgroup_1                   -> empty Dataset
        #   group_<i>/subgroup_2                   -> Dataset with variable "a"
        #   group_<i>/subgroup_2/sub-subgroup_1    -> self.ds
        root = {f"group_{group}": self.ds for group in range(self.nchildren)}
        nested_tree1 = {
            f"group_{group}/subgroup_1": xr.Dataset() for group in range(self.nchildren)
        }
        nested_tree2 = {
            f"group_{group}/subgroup_2": xr.DataArray(np.arange(1, 10)).to_dataset(
                name="a"
            )
            for group in range(self.nchildren)
        }
        nested_tree3 = {
            f"group_{group}/subgroup_2/sub-subgroup_1": self.ds
            for group in range(self.nchildren)
        }
        dtree = root | nested_tree1 | nested_tree2 | nested_tree3
        self.dtree = DataTree.from_dict(dtree)
| 560 | + |
| 561 | + |
class IOReadDataTreeNetCDF4(IONestedDataTree):
    """Benchmarks for opening and loading the nested datatree from a netCDF file."""

    def setup(self):
        """Build the nested datatree and write it to ``self.filepath``.

        The file is written here so that the timed methods only measure
        reading it back.
        """
        # TODO: Lazily skipped in CI as it is very demanding and slow.
        # Improve times and remove errors.
        _skip_slow()

        requires_dask()

        self.make_datatree()
        self.format = "NETCDF4"
        self.filepath = "datatree.nc4.nc"
        # Pass format and engine explicitly so the file is written with the
        # declared format and with the same backend the timed reads use.
        self.dtree.to_netcdf(
            filepath=self.filepath, format=self.format, engine="netcdf4"
        )

    def time_load_datatree_netcdf4(self):
        """Time opening the file with the netcdf4 engine and calling ``load()``."""
        open_datatree(self.filepath, engine="netcdf4").load()

    def time_open_datatree_netcdf4(self):
        """Time opening the file with the netcdf4 engine, without ``load()``."""
        open_datatree(self.filepath, engine="netcdf4")
| 581 | + |
| 582 | + |
472 | 583 | class IOWriteNetCDFDask:
|
473 | 584 | timeout = 60
|
474 | 585 | repeat = 1
|
|
0 commit comments