Skip to content

Commit

Permalink
Add new stage to remove columns in a more convenient way
Browse files Browse the repository at this point in the history
  • Loading branch information
benkrikler committed May 7, 2020
1 parent 9e279bc commit 3816f28
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 0 deletions.
24 changes: 24 additions & 0 deletions fast_plotter/postproc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,30 @@ def keep_specific_bins(df, axis, keep, expansions={}):
return out_df


def filter_cols(df, items=None, like=None, regex=None, drop_not_keep=False):
    """Keep (or drop) the columns of a dataframe that match a selection.

    Wraps ``pandas.DataFrame.filter`` and additionally accepts a list of
    substrings for ``like``, in which case a column is selected when any of
    the substrings occurs in its name.

    Parameters:
        df (pandas.DataFrame): The dataframe whose columns are filtered
        items (list-like): A list of column names to filter with
        like (str, list[string]): A string or list of strings which will filter
            columns where they are found in the column name
        regex (str): A regular expression to match column names to
        drop_not_keep (bool): Inverts the selection if true so that matched
            columns are dropped

    Returns:
        pandas.DataFrame: ``df`` restricted to the matched columns, or with
        the matched columns removed when ``drop_not_keep`` is true.

    Raises:
        RuntimeError: If ``items`` is given together with a list of ``like``
            strings.
    """
    if not like or not isinstance(like, (tuple, list)):
        # A single string (or nothing) for `like` is handled by pandas itself.
        df_filtered = df.filter(items=items, like=like, regex=regex)
    else:
        if items:
            raise RuntimeError("Can only use one of 'items', 'like', or 'regex'")
        # Walk the columns in their existing order so the result has a
        # deterministic column order (a set of names would not guarantee it).
        matched = [
            col for col in df.columns
            if any(pattern in str(col) for pattern in like)
        ]
        # `regex` is still forwarded so that pandas rejects a conflicting
        # combination of selection keywords.
        df_filtered = df.filter(items=matched, regex=regex)

    if drop_not_keep:
        # Must name the axis explicitly: a positional argument to `drop`
        # refers to row labels, not column labels.
        return df.drop(columns=df_filtered.columns)
    return df_filtered


def combine_cols(df, format_strings, as_index=[]):
"""Combine columns together using format strings"""
logger.info("Combining columns based on: %s", str(format_strings))
Expand Down
5 changes: 5 additions & 0 deletions fast_plotter/postproc/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ class AssignCol(BaseManipulator):
func = "assign_col"


class FilterCols(BaseManipulator):
    """Post-processing stage for keeping or removing dataframe columns.

    NOTE(review): presumably ``BaseManipulator`` resolves ``func`` to the
    ``filter_cols`` function in ``fast_plotter/postproc/functions.py`` and
    forwards the stage's configuration to it as keyword arguments -- confirm
    against ``BaseManipulator``.
    """
    # Same cardinality value as the neighbouring AssignCol / AssignDim stages.
    cardinality = "one-to-one"
    # Name of the function this stage applies.
    func = "filter_cols"


class AssignDim(BaseManipulator):
cardinality = "one-to-one"
func = "assign_dim"
Expand Down
24 changes: 24 additions & 0 deletions tests/postproc/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,30 @@ def test_split(binned_df):
assert all([r[0].index.nlevels == 3 for r in results])


def test_filter_cols(binned_df):
    """Check column selection by exact names, substring(s) and regex."""
    frame = binned_df.index.to_frame()

    # Each case pairs the keyword arguments for filter_cols with the column
    # names that are expected to survive the filter.
    cases = (
        (dict(items=["int"]), ("int",)),
        (dict(items=["int", "cat"]), ("int", "cat")),
        (dict(like="int"), ("int", "interval")),
        (dict(like=["int", "cat"]), ("int", "cat", "interval")),
        (dict(regex="^int.*"), ("int", "interval")),
    )
    for kwargs, expected in cases:
        selected = funcs.filter_cols(frame, **kwargs).columns
        assert len(selected) == len(expected)
        assert set(selected) == set(expected)


# def test_reorder_dimensions():
# #def reorder_dimensions(df, order):
# pass
Expand Down

0 comments on commit 3816f28

Please sign in to comment.