23
23
from anemoi .datasets .data .stores import open_zarr
24
24
25
25
# HTTPS base URL under which test data for the dataset-creation tests is published.
TEST_DATA_ROOT = "https://object-store.os-api.cci1.ecmwf.int/ml-tests/test-data/anemoi-datasets/create"
26
# S3 URI root for the reference datasets; test_run appends "/<name>.zarr" to it.
# NOTE(review): presumably the same bucket/prefix as TEST_DATA_ROOT, addressed via S3 - confirm.
+ TEST_DATA_S3_ROOT = "s3://ml-tests/test-data/anemoi-datasets/create"
26
27
27
28
28
29
# Directory that contains this test module.
HERE = os .path .dirname (__file__ )
@@ -94,29 +95,44 @@ def __call__(self, name, *args, **kwargs):
94
95
# Single module-level LoadSource instance (the class is defined earlier in this file).
_from_source = LoadSource ()
95
96
96
97
97
def compare_dot_zattrs(a, b, path, errors):
    """Recursively compare two metadata trees and collect every difference.

    Nothing is raised on a mismatch: each difference is appended to
    ``errors`` as a human-readable string so that the caller can report all
    of them at once.

    Parameters
    ----------
    a :
        The "actual" value (dict, list or leaf value).
    b :
        The "expected" value, compared against ``a``.
    path : str
        Label of the current location in the tree, used as a prefix in the
        error messages.
    errors : list
        Output list; one string is appended per difference found.

    Notes
    -----
    * For the keys listed in ``volatile_keys`` only the *type* of the two
      values is compared, not their content.
    * NOTE(review): keys that exist on only one side are not reported; only
      the keys common to both dicts are compared. Confirm this is intended.
    """
    # Keys whose content is not compared; only their type must match.
    volatile_keys = (
        "timestamp",
        "uuid",
        "latest_write_timestamp",
        "yaml_config",
        "history",
        "provenance",
        "provenance_load",
        "description",
        "config_path",
        "dataset_status",
    )

    if isinstance(a, dict):
        if not isinstance(b, dict):
            # Record the mismatch instead of crashing on ``b.keys()``.
            errors.append(f"❌ {path} actual != expected : {a} ({type(a)} ) != {b} ({type(b)} )")
            return
        # Walk the keys of ``a`` in insertion order so that the order of the
        # reported errors is deterministic from one run to the next.
        for k in a:
            if k not in b:
                continue
            if k in volatile_keys:
                if type(a[k]) != type(b[k]):  # noqa: E721
                    errors.append(f"❌ {path} .{k} : type differs {type(a[k])} != {type(b[k])} ")
                continue
            compare_dot_zattrs(a[k], b[k], f"{path} .{k} ", errors)
        return

    if isinstance(a, list):
        if not isinstance(b, list):
            # Record the mismatch instead of crashing on ``len(b)``.
            errors.append(f"❌ {path} actual != expected : {a} ({type(a)} ) != {b} ({type(b)} )")
            return
        if len(a) != len(b):
            errors.append(f"❌ {path} : lengths are different {len(a)} != {len(b)} ")
            return
        for i, (v, w) in enumerate(zip(a, b)):
            compare_dot_zattrs(v, w, f"{path} .{i} ", errors)
        return

    # Leaf values: the types must match exactly (so True != 1), then the values.
    if type(a) != type(b):  # noqa: E721
        msg = f"❌ {path} actual != expected : {a} ({type(a)} ) != {b} ({type(b)} )"
        errors.append(msg)
        return
    if a != b:
        msg = f"❌ {path} actual != expected : {a} != {b} "
        errors.append(msg)
120
136
121
137
122
138
def compare_datasets (a , b ):
@@ -169,19 +185,29 @@ def compare_statistics(ds1, ds2):
169
185
class Comparer:
    """Compare a newly created dataset with a reference dataset.

    Both stores are opened twice: once with ``open_zarr`` (to read the
    metadata attributes) and once with ``open_dataset`` (to compare the data
    and the statistics).

    Parameters
    ----------
    name : str
        Name of the test case; used to build the default paths.
    output_path : str, optional
        Location of the dataset under test. Defaults to ``"<name>.zarr"``.
    reference_path : str, optional
        Location of the reference dataset. Defaults to
        ``"<TEST_DATA_S3_ROOT>/<name>.zarr"``, the location ``test_run`` uses.
    """

    def __init__(self, name, output_path=None, reference_path=None):
        self.name = name
        self.output = output_path or name + ".zarr"
        # ``reference_path`` is optional in the signature, so build a default
        # instead of handing ``None`` to ``open_zarr`` / ``open_dataset``.
        self.reference_path = reference_path or TEST_DATA_S3_ROOT + "/" + name + ".zarr"
        print(f"Comparing {self.output} and {self.reference_path} ")

        self.z_output = open_zarr(self.output)
        self.z_reference = open_zarr(self.reference_path)

        # The result is discarded on purpose.
        # NOTE(review): presumably an early check that the reference store
        # contains a "data" entry - confirm.
        self.z_reference["data"]

        self.ds_output = open_dataset(self.output)
        self.ds_reference = open_dataset(self.reference_path)

    def compare(self):
        """Run all comparisons between the output and the reference.

        The metadata attributes are compared first; every difference is
        printed before failing, so all of them are visible in one run.

        Raises
        ------
        AssertionError
            If the metadata attributes differ.
        """
        errors = []
        compare_dot_zattrs(dict(self.z_output.attrs), dict(self.z_reference.attrs), "metadata", errors)
        if errors:
            print("Comparison failed")
            print("\n ".join(errors))
            raise AssertionError("Comparison failed")

        compare_datasets(self.ds_output, self.ds_reference)

        compare_statistics(self.ds_output, self.ds_reference)
186
212
187
213
@@ -199,8 +225,14 @@ def test_run(name):
199
225
c .additions (delta = [1 , 3 , 6 , 12 ])
200
226
c .cleanup ()
201
227
202
- comparer = Comparer (name , output_path = output )
203
- comparer .compare ()
228
+ # reference_path = os.path.join(HERE, name + "-reference.zarr")
229
+ s3_uri = TEST_DATA_S3_ROOT + "/" + name + ".zarr"
230
+ # if not os.path.exists(reference_path):
231
+ # from anemoi.utils.s3 import download as s3_download
232
+ # s3_download(s3_uri + '/', reference_path, overwrite=True)
233
+
234
+ Comparer (name , output_path = output , reference_path = s3_uri ).compare ()
235
+ # Comparer(name, output_path=output, reference_path=reference_path).compare()
204
236
205
237
206
238
if __name__ == "__main__" :
0 commit comments