Skip to content

Commit 29a2fea

Browse files
ntessorerrjbca
andauthored
simple construction of quantities in config files (#316)
Co-authored-by: Richard R <[email protected]>
1 parent b54dc75 commit 29a2fea

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

skypy/pipeline/_config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import builtins
22
from importlib import import_module
33
import yaml
4+
import re
5+
from astropy.units import Quantity
46

57
__all__ = [
68
'load_skypy_yaml',
@@ -52,10 +54,22 @@ def construct_function(self, name, node):
5254

5355
return (function,) if args == '' else (function, args)
5456

57+
def construct_quantity(self, node):
58+
value = self.construct_scalar(node)
59+
return Quantity(value)
60+
5561

5662
# constructor for generic functions
5763
SkyPyLoader.add_multi_constructor('!', SkyPyLoader.construct_function)
5864

65+
# constructor for quantities
66+
SkyPyLoader.add_constructor('!quantity', SkyPyLoader.construct_quantity)
67+
# Implicitly resolve quantities using the regex from astropy
68+
SkyPyLoader.add_implicit_resolver('!quantity', re.compile(r'''
69+
\s*[+-]?((\d+\.?\d*)|(\.\d+)|([nN][aA][nN])|
70+
([iI][nN][fF]([iI][nN][iI][tT][yY]){0,1}))([eE][+-]?\d+)?[.+-]? \w* \W+
71+
''', re.VERBOSE), list('-+0123456789.'))
72+
5973

6074
def load_skypy_yaml(filename):
6175
'''Read a SkyPy pipeline configuration from a YAML file.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
42_km: !quantity 42.0 km
2+
1_deg2: 1 deg2

skypy/pipeline/tests/test_config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections.abc import Callable
33
import pytest
44
from skypy.pipeline import load_skypy_yaml
5+
from astropy import units
56

67

78
def test_load_skypy_yaml():
@@ -31,3 +32,12 @@ def test_load_skypy_yaml():
3132
filename = get_pkg_data_filename('data/bad_module.yml')
3233
with pytest.raises(ImportError):
3334
load_skypy_yaml(filename)
35+
36+
37+
def test_yaml_quantities():
38+
# config with quantities
39+
filename = get_pkg_data_filename('data/quantities.yml')
40+
config = load_skypy_yaml(filename)
41+
42+
assert config['42_km'] == units.Quantity('42 km')
43+
assert config['1_deg2'] == units.Quantity('1 deg2')

0 commit comments

Comments
 (0)