Skip to content

Commit 82cabf8

Browse files
Jenny WongJenny Wong
authored andcommitted
Modify set function in message to accept dictionary of key values and add check on values
1 parent 06adfa6 commit 82cabf8

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

eccodes/highlevel/message.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,42 @@ def get(self, name, default=None, ktype=None):
6767
except KeyError:
6868
return default
6969

70-
def set(self, name, value):
71-
"""Set the value of the given key
70+
def set(self, *args):
71+
"""If two arguments are given, assumes this takes form of a single key
72+
value pair and sets the value of the given key. If a dictionary is passed
73+
then sets the values of all keys in the dictionary. Note, ordering
74+
if the keys is important. Finally, checks if value
75+
has been set correctly
7276
7377
Raises
7478
------
79+
TypeError
80+
If arguments do not take one of the two expected forms
7581
KeyError
7682
If the key does not exist
83+
ValueError
84+
If the set value of one of the keys is not the expected value
7785
"""
78-
with raise_keyerror(name):
79-
return eccodes.codes_set(self._handle, name, value)
86+
if isinstance(args[0], str) and len(args) == 2:
87+
key_values = {args[0]: args[1]}
88+
elif isinstance(args[0], dict):
89+
key_values = args[0]
90+
else:
91+
raise TypeError(f"Unsupported argument type. Expects two arguments consisting \
92+
of key and value pair, or a dictionary of key value pairs")
93+
94+
for name, value in key_values.items():
95+
with raise_keyerror(name):
96+
eccodes.codes_set(self._handle, name, value)
97+
98+
# Check values just set
99+
for name, value in key_values.items():
100+
saved_value = self.get(name)
101+
cast_value = value
102+
if not isinstance(value, type(saved_value)):
103+
cast_value = type(saved_value)(value)
104+
if saved_value != cast_value:
105+
raise ValueError(f"Unexpected retrieved value {saved_value}. Expected {cast_value}")
80106

81107
def get_array(self, name):
82108
"""Get the value of the given key as an array

tests/test_highlevel.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,22 @@ def test_message_set():
6464
assert np.all(message.get("values") == vals)
6565
assert message.is_missing(missing_key)
6666

67+
def test_message_set_multiple():
68+
with eccodes.FileReader(TEST_GRIB_DATA) as reader:
69+
message = next(reader)
70+
message.set({
71+
"centre": "ecmf",
72+
"numberOfValues": 10,
73+
"shortName": "z",
74+
})
75+
with pytest.raises(TypeError):
76+
message.set("centre", "ecmwf", 2)
77+
with pytest.raises(ValueError):
78+
message.set("stepRange", "0-12")
79+
message.set({
80+
"stepType": "max",
81+
"stepRange": "0-12"
82+
})
6783

6884
def test_message_iter():
6985
with eccodes.FileReader(TEST_GRIB_DATA2) as reader:

0 commit comments

Comments
 (0)