Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/changes/2921.api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
As a consequence of fixing the bug #2921, `ctapipe.core.ExpressionEngine`
converts all input tables to `astropy.table.QTable` internally, which has a
small side effect on what is allowed in expressions: all columns with units are
now of type `astropy.units.Quantity`, instead of `astropy.table.Column`. Before,
an expression like ``"some_column.quantity.to(u.m)"`` would work if a ``Table``
was passed (but would fail for a ``QTable``). Now, that expression should be
``some_column.to(u.m)``
4 changes: 4 additions & 0 deletions docs/changes/2921.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fixed bug where units were incorrect in the output table of an
`ctapipe.core.FeatureGenerator` if a table of class `astropy.table.Table` was
passed to the call method. This bug did not affect calls using an
`astropy.table.QTable`.
2 changes: 2 additions & 0 deletions src/ctapipe/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .component import Component, non_abstract_children
from .container import Container, DeprecatedField, Field, FieldValidationError, Map
from .expression_engine import ExpressionEngine
from .feature_generator import FeatureGenerator
from .provenance import Provenance, get_module_version
from .qualityquery import QualityCriteriaError, QualityQuery
Expand All @@ -28,4 +29,5 @@
"QualityQuery",
"QualityCriteriaError",
"FieldValidationError",
"ExpressionEngine",
]
56 changes: 44 additions & 12 deletions src/ctapipe/core/feature_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
"""

from collections import ChainMap
from copy import deepcopy

from astropy.table import QTable, Table

from .component import Component
from .expression_engine import ExpressionEngine
Expand All @@ -11,19 +14,31 @@
__all__ = [
"FeatureGenerator",
"FeatureGeneratorException",
"shallow_copy_table",
]


def _shallow_copy_table(table):
def shallow_copy_table(
table, output_cls: type[Table] | type[QTable] | None = None
) -> Table | QTable:
"""
Make a shallow copy of the table.

Data of the existing columns will be shared between shallow
copies, but adding / removing columns won't be seen in
the original table.
Data of the existing columns will be shared between shallow copies, but
adding / removing columns won't be seen in the original table. Metadata for
the new table will be a copy (not shallow) of the original metadata, so that
new metadata can be added without affecting the original table.

Parameters
----------
output_cls: type[Table] | type[QTable] | None
type of the output table. If None, use the input table type
"""
# automatically return Table or QTable depending on input
return table.__class__({col: table[col] for col in table.colnames}, copy=False)
output_cls = output_cls or table.__class__

new_table = output_cls({col: table[col] for col in table.colnames}, copy=False)
new_table.meta = deepcopy(table.meta)
return new_table


class FeatureGeneratorException(TypeError):
Expand Down Expand Up @@ -54,26 +69,43 @@ def __init__(self, config=None, parent=None, **kwargs):
self.engine = ExpressionEngine(expressions=self.features)
self._feature_names = [name for name, _ in self.features]

def __call__(self, table, **kwargs):
def __call__(self, table: Table | QTable, **kwargs) -> Table:
"""
Apply feature generation to the input table.

This method returns a shallow copy of the input table with the
new features added. Existing columns will share the underlying data,
however the new columns won't be visible in the input table.

Parameters
----------
table: QTable | Table
Input table. Internally a Table will be converted to a QTable so that
unit propagation works, so expressions should only rely on properties of QTables.
**kwargs:
Other objects that should be available in expressions. For example,
if a you pass ``subarray=subarray``, expressions can use that
object. This can also be special functions like ``f=my_function``,
which would allow an expression like ``"f(col1)"``.

Returns
-------
QTable|Table:
A new table with the same columns as the input, but with new columns
for each feature. The returned class depends on what was passed in.
"""
table = _shallow_copy_table(table)
lookup = ChainMap(table, kwargs)
table_copy = shallow_copy_table(table, output_cls=QTable)
lookup = ChainMap(table_copy, kwargs)

for result, name in zip(self.engine(lookup), self._feature_names):
if name in table.colnames:
if name in table_copy.colnames:
raise FeatureGeneratorException(f"{name} is already a column of table.")
try:
table[name] = result
table_copy[name] = result
except Exception as err:
raise err

return table
return table.__class__(table_copy) # ensure the return type is what is expected

def __len__(self):
return len(self.features)
49 changes: 46 additions & 3 deletions src/ctapipe/core/tests/test_feature_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
import pytest
from astropy.table import Table
from astropy.table import QTable, Table

from ctapipe.core.expression_engine import ExpressionError
from ctapipe.core.feature_generator import FeatureGenerator, FeatureGeneratorException
Expand Down Expand Up @@ -60,14 +60,15 @@ def test_to_unit():

expressions = [
("length_meter", "length.to(u.m)"),
("log_length_meter", "log10(length.quantity.to_value(u.m))"),
("log_length_meter", "log10(length.to_value(u.m))"),
]
generator = FeatureGenerator(features=expressions)
table = Table({"length": [1 * u.km]})

table = generator(table)
assert table["length_meter"] == 1000
assert table["length_meter"] == 1000 * u.m
assert table["log_length_meter"] == 3
assert table["length_meter"].unit == u.m


def test_multiplicity(subarray_prod5_paranal):
Expand Down Expand Up @@ -102,3 +103,45 @@ def test_multiplicity(subarray_prod5_paranal):
np.testing.assert_equal(table["n_lsts"], [1, 2])
np.testing.assert_equal(table["n_msts"], [2, 1])
np.testing.assert_equal(table["n_ssts"], [0, 1])


@pytest.mark.parametrize("table_class", [QTable, Table])
def test_unit_propegation(table_class):
Comment thread
kosack marked this conversation as resolved.
Outdated
"""
Check that units propagate to features.

If a column in the input table has a unit, and a feature does math on that
unit, the feature should have the appropriate unit.
"""

import astropy.units as u

table = table_class(dict(x=np.arange(11) * u.cm, E=np.linspace(-2, 2, 11) * u.TeV))
features = [
("x2", "x**2"),
("E_per_area", "E/x**2"),
]

feature_gen = FeatureGenerator(features=features)
new_table = feature_gen(table)

assert new_table["x2"].unit.is_equivalent("cm2")
assert new_table["E_per_area"].unit.is_equivalent("TeV cm-2")


@pytest.mark.parametrize("table_class", [QTable, Table])
def test_input_output_class(table_class):
"""Ensure output table class is same as input."""

import astropy.units as u

table = table_class(dict(x=np.arange(11) * u.cm, E=np.linspace(-2, 2, 11) * u.TeV))
features = [
("x2", "x**2"),
("E_per_area", "E/x**2"),
]

feature_gen = FeatureGenerator(features=features)
new_table = feature_gen(table)

assert new_table.__class__ == table.__class__