23
23
# Module-level logger, named after this module via ``__name__``.
LOG = logging .getLogger (__name__ )
24
24
25
25
26
+ def _tidy (v ):
27
+ if isinstance (v , (list , tuple , set )):
28
+ return [_tidy (i ) for i in v ]
29
+ if isinstance (v , dict ):
30
+ return {k : _tidy (v ) for k , v in v .items ()}
31
+ if isinstance (v , str ) and v .startswith ("/" ):
32
+ return os .path .basename (v )
33
+ if isinstance (v , datetime .datetime ):
34
+ return v .isoformat ()
35
+ if isinstance (v , datetime .date ):
36
+ return v .isoformat ()
37
+ if isinstance (v , datetime .timedelta ):
38
+ return frequency_to_string (v )
39
+
40
+ if isinstance (v , Dataset ):
41
+ # That can happen in the `arguments`
42
+ # if a dataset is passed as an argument
43
+ return repr (v )
44
+
45
+ if isinstance (v , slice ):
46
+ return (v .start , v .stop , v .step )
47
+
48
+ return v
49
+
50
+
26
51
class Dataset :
27
52
arguments = {}
53
+ _name = None
28
54
29
55
def mutate (self ) -> "Dataset" :
30
56
"""Give an opportunity to a subclass to return a new Dataset
@@ -41,6 +67,21 @@ def _len(self):
41
67
return len (self )
42
68
43
69
def _subset (self , ** kwargs ):
70
+
71
+ if not kwargs :
72
+ return self .mutate ()
73
+
74
+ name = kwargs .pop ("name" , None )
75
+ result = self .__subset (** kwargs )
76
+ result ._name = name
77
+
78
+ return result
79
+
80
+ @property
81
+ def name (self ):
82
+ return self ._name
83
+
84
+ def __subset (self , ** kwargs ):
44
85
if not kwargs :
45
86
return self .mutate ()
46
87
@@ -254,41 +295,32 @@ def typed_variables(self):
254
295
255
296
return result
256
297
298
+ def _input_sources (self ):
299
+ sources = []
300
+ self .collect_input_sources (sources )
301
+ return sources
302
+
257
303
def metadata (self ):
258
304
import anemoi
259
305
260
- def tidy (v ):
261
- if isinstance (v , (list , tuple , set )):
262
- return [tidy (i ) for i in v ]
263
- if isinstance (v , dict ):
264
- return {k : tidy (v ) for k , v in v .items ()}
265
- if isinstance (v , str ) and v .startswith ("/" ):
266
- return os .path .basename (v )
267
- if isinstance (v , datetime .datetime ):
268
- return v .isoformat ()
269
- if isinstance (v , datetime .date ):
270
- return v .isoformat ()
271
- if isinstance (v , datetime .timedelta ):
272
- return frequency_to_string (v )
273
-
274
- if isinstance (v , Dataset ):
275
- # That can happen in the `arguments`
276
- # if a dataset is passed as an argument
277
- return repr (v )
278
-
279
- if isinstance (v , slice ):
280
- return (v .start , v .stop , v .step )
281
-
282
- return v
306
+ _ , source_to_arrays = self ._supporting_arrays_and_sources ()
307
+
308
+ sources = []
309
+ for i , source in enumerate (self ._input_sources ()):
310
+ source_metadata = source .dataset_metadata ().copy ()
311
+ source_metadata ["supporting_arrays" ] = source_to_arrays [id (source )]
312
+ sources .append (source_metadata )
283
313
284
314
md = dict (
285
315
version = anemoi .datasets .__version__ ,
286
316
arguments = self .arguments ,
287
317
** self .dataset_metadata (),
318
+ sources = sources ,
319
+ supporting_arrays = source_to_arrays [id (self )],
288
320
)
289
321
290
322
try :
291
- return json .loads (json .dumps (tidy (md )))
323
+ return json .loads (json .dumps (_tidy (md )))
292
324
except Exception :
293
325
LOG .exception ("Failed to serialize metadata" )
294
326
pprint .pprint (md )
@@ -313,8 +345,67 @@ def dataset_metadata(self):
313
345
dtype = str (self .dtype ),
314
346
start_date = self .start_date .astype (str ),
315
347
end_date = self .end_date .astype (str ),
348
+ name = self .name ,
316
349
)
317
350
351
+ def _supporting_arrays (self , * path ):
352
+
353
+ import numpy as np
354
+
355
+ def _path (path , name ):
356
+ return "/" .join (str (_ ) for _ in [* path , name ])
357
+
358
+ result = {
359
+ _path (path , "latitudes" ): self .latitudes ,
360
+ _path (path , "longitudes" ): self .longitudes ,
361
+ }
362
+ collected = []
363
+
364
+ self .collect_supporting_arrays (collected , * path )
365
+
366
+ for path , name , array in collected :
367
+ assert isinstance (path , tuple ) and isinstance (name , str )
368
+ assert isinstance (array , np .ndarray )
369
+
370
+ name = _path (path , name )
371
+
372
+ if name in result :
373
+ raise ValueError (f"Duplicate key { name } " )
374
+
375
+ result [name ] = array
376
+
377
+ return result
378
+
379
def supporting_arrays(self):
    """Arrays to be saved in the checkpoints.

    This is the first element of the pair returned by
    ``self._supporting_arrays_and_sources()``; the second element is
    discarded.
    """
    arrays_and_sources = self._supporting_arrays_and_sources()
    arrays, _ = arrays_and_sources
    return arrays
383
+
384
+ def _supporting_arrays_and_sources (self ):
385
+
386
+ source_to_arrays = {}
387
+
388
+ # Top levels arrays
389
+ result = self ._supporting_arrays ()
390
+ source_to_arrays [id (self )] = sorted (result .keys ())
391
+
392
+ # Arrays from the input sources
393
+ for i , source in enumerate (self ._input_sources ()):
394
+ name = source .name if source .name is not None else i
395
+ src_arrays = source ._supporting_arrays (name )
396
+ source_to_arrays [id (source )] = sorted (src_arrays .keys ())
397
+
398
+ for k in src_arrays :
399
+ assert k not in result
400
+
401
+ result .update (src_arrays )
402
+
403
+ return result , source_to_arrays
404
+
405
def collect_supporting_arrays(self, collected, *path):
    """Hook to override in order to add more supporting arrays.

    ``_supporting_arrays`` passes an empty list as ``collected`` and then
    unpacks each entry as ``(path, name, array)``, so an override should
    append tuples of that form. This base implementation adds nothing.
    """
    return None
408
+
318
409
def metadata_specific (self , ** kwargs ):
319
410
action = self .__class__ .__name__ .lower ()
320
411
# assert isinstance(self.frequency, datetime.timedelta), (self.frequency, self, action)
0 commit comments