Refactor serialize_data with better param names and logic #222

Closed · wants to merge 2 commits
Changes from 1 commit
93 changes: 49 additions & 44 deletions src/diffpy/utils/parsers/serialization.py
@@ -14,7 +14,7 @@
 ##############################################################################

 import json
-import pathlib
+from pathlib import Path
 import warnings

 import numpy
@@ -26,10 +26,10 @@


 def serialize_data(
-    filename,
+    filepath,
     hdata: dict,
     data_table,
-    dt_colnames=None,
+    dt_col_names=None,
     show_path=True,
     serial_file=None,
 ):
@@ -40,13 +40,13 @@ def serialize_data(

     Parameters
     ----------
-    filename
-        Name of the file whose data is being serialized.
+    filepath
+        The file path whose data is being serialized.
     hdata: dict
-        File metadata (generally related to data table).
+        The file metadata (generally related to data table).
Contributor commented: Is there a reason we are not calling this variable metadata, or header_data? Calling it hdata, I am having to carry around too many pointers in my brain that I have to cognitively resolve as I read.

Collaborator replied: The hdata name came from loadData, which uses hdata for header data. Good to change to metadata.

     data_table: list or ndarray
-        Data table.
-    dt_colnames: list
+        The data table.
Contributor commented: What is a data table? A table with data on it? I think this is something like the column data read from the file.

Collaborator replied: These serialization functions were coded in tandem with loadData, so their inputs match the outputs of loadData. This assumes files structured like .chi/.gr/etc. files.
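To make that loadData pairing concrete, here is a minimal sketch of the intended flow; the file name is hypothetical, the parameter names follow this PR, and the import paths are assumed from this repository's layout:

```python
# Minimal sketch: loadData outputs feeding serialize_data (sample.gr is hypothetical).
from diffpy.utils.parsers.loaddata import loadData
from diffpy.utils.parsers.serialization import serialize_data

hdata = loadData("sample.gr", headers=True)  # header metadata as a dict
data_table = loadData("sample.gr")           # numeric columns as an ndarray

entry = serialize_data(
    "sample.gr",
    hdata,
    data_table,
    dt_col_names=["r", "gr"],  # one name per column; None skips a column
    show_path=False,
)
# entry maps the file name to its metadata plus "r" and "gr" column lists
```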

+    dt_col_names: list
         Names of each column in data_table. Every name in data_table_cols will be put into the Dictionary
         as a key with a value of that column in data_table (stored as a List). Put None for columns
         without names. If dt_cols has less non-None entries than columns in data_table,
@@ -64,64 +64,69 @@
         Returns the dictionary loaded from/into the updated database file.
     """

-    # compile data_table and hddata together
+    # Combine data_table and hdata together in data
     data = {}

-    # handle getting name of file for variety of filename types
-    abs_path = pathlib.Path(filename).resolve()
-    # add path to start of data if requested
+    # Handle getting name of file for variety of filename types
+    abs_path = Path(filepath).resolve()
+
+    # Add path to start of data if show_path is True
     if show_path and "path" not in hdata.keys():
         data.update({"path": abs_path.as_posix()})
-    # title the entry with name of file (taken from end of path)
-    title = abs_path.name

-    # first add data in hddata dict
+    # Add hdata to data
     data.update(hdata)

-    # second add named columns in dt_cols
-    # performed second to prioritize overwriting hdata entries with data_table column entries
-    named_columns = 0  # initial value
-    max_columns = 1  # higher than named_columns to trigger 'data table' entry
-    if dt_colnames is not None:
-        num_columns = [len(row) for row in data_table]
-        max_columns = max(num_columns)
-        num_col_names = len(dt_colnames)
-        if max_columns < num_col_names:  # assume numpy.loadtxt gives non-irregular array
-            raise ImproperSizeError("More entries in dt_colnames than columns in data_table.")
-        named_columns = 0
-        for idx in range(num_col_names):
-            colname = dt_colnames[idx]
-            if colname is not None:
-                if colname in hdata.keys():
+    # Prioritize overwriting hdata entries with data_table column entries
+    col_counter = 0
+
+    # Get a list of column counts in each entry in data table
+    dt_col_counts = [len(row) for row in data_table]
+    dt_max_col_count = max(dt_col_counts)
+
+    if dt_col_names is not None:
+        dt_col_names_count = len(dt_col_names)
+        if dt_max_col_count < dt_col_names_count:  # assume numpy.loadtxt gives non-irregular array
Contributor (author) commented: Compare the number of col_names values provided as an input to this function against the number of columns determined from the data table.
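As a concrete illustration of this guard (mirroring test_exceptions below; the file name is hypothetical and the exception's import path is assumed):

```python
# Three column names against a two-column table triggers ImproperSizeError.
import numpy
from diffpy.utils.parsers.custom_exceptions import ImproperSizeError
from diffpy.utils.parsers.serialization import serialize_data

data_table = numpy.loadtxt("sample.gr")  # two columns: r and g(r)
try:
    serialize_data("sample.gr", {}, data_table, dt_col_names=["r", "gr", "extra"])
except ImproperSizeError:
    print("more column names than columns in data_table")
```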

+            raise ImproperSizeError("More entries in dt_col_names_count than columns in data_table.")
+
+        for idx in range(dt_col_names_count):
+            col_name = dt_col_names[idx]
+
+            if col_name is not None:
+
+                # Check if column name already exists in hdata
+                if col_name in hdata.keys():
                     warnings.warn(
-                        f"Entry '{colname}' in hdata has been overwritten by a data_table entry.",
+                        f"Entry '{col_name}' in hdata has been overwritten by a data_table entry.",
                         RuntimeWarning,
                     )
-                data.update({colname: list(data_table[:, idx])})
-                named_columns += 1
+
+                # Add row data per column to data
+                data.update({col_name: list(data_table[:, idx])})
+                col_counter += 1

-    # finally add data_table as an entry named 'data table' if not all columns were parsed
-    if named_columns < max_columns:
+    # Add data_table as an entry named 'data table' if not all columns were parsed
+    if col_counter < dt_max_col_count:
         if "data table" in data.keys():
             warnings.warn(
                 "Entry 'data table' in hdata has been overwritten by data_table.",
                 RuntimeWarning,
             )
         data.update({"data table": data_table})

-    # parse name using pathlib and generate dictionary entry
-    entry = {title: data}
+    # Parse name using pathlib and generate dictionary entry
+    data_key = abs_path.name
+    entry = {data_key: data}

     # no save
     if serial_file is None:
         return entry

     # saving/updating file
     # check if supported type
-    sf = pathlib.Path(serial_file)
-    sf_name = sf.name
-    extension = sf.suffix
-    if extension not in supported_formats:
+    sf_path = Path(serial_file)
+    sf_name = sf_path.name
+    sf_ext = sf_path.suffix
+    if sf_ext not in supported_formats:
         raise UnsupportedTypeError(sf_name, supported_formats)

     # new file or update
@@ -133,7 +138,7 @@
         pass

     # json
-    if extension == ".json":
+    if sf_ext == ".json":
         # cannot serialize numpy arrays
         class NumpyEncoder(json.JSONEncoder):
             def default(self, data_obj):
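The encoder body is collapsed in this view; a typical completion, assumed here for illustration, converts ndarrays to plain lists before falling back to the default encoder:

```python
# Assumed completion of the collapsed NumpyEncoder body: json cannot
# serialize ndarrays, so convert them to lists first.
import json
import numpy

class NumpyEncoder(json.JSONEncoder):
    def default(self, data_obj):
        if isinstance(data_obj, numpy.ndarray):
            return data_obj.tolist()
        return super().default(data_obj)

print(json.dumps({"data table": numpy.eye(2)}, cls=NumpyEncoder))
```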
@@ -177,7 +182,7 @@ def deserialize_data(filename, filetype=None):
"""

# check if supported type
f = pathlib.Path(filename)
f = Path(filename)
f_name = f.name

if filetype is None:
25 changes: 12 additions & 13 deletions tests/test_serialization.py
@@ -17,30 +17,29 @@ def test_load_multiple(tmp_path, datafile):
     tlm_list = sorted(dbload_dir.glob("*.gr"))

     generated_data = None

     for headerfile in tlm_list:
         # gather data using loadData
         hdata = loadData(headerfile, headers=True)
         data_table = loadData(headerfile)

         # check path extraction
-        generated_data = serialize_data(headerfile, hdata, data_table, dt_colnames=["r", "gr"], show_path=True)
+        generated_data = serialize_data(headerfile, hdata, data_table, dt_col_names=["r", "gr"], show_path=True)
Contributor (author) commented: dt_col_names and dt_colnames are both used across the codebase, so sticking with dt_col_names.

         assert headerfile == Path(generated_data[headerfile.name].pop("path"))

         # rerun without path information and save to file
         generated_data = serialize_data(
             headerfile,
             hdata,
             data_table,
-            dt_colnames=["r", "gr"],
+            dt_col_names=["r", "gr"],
             show_path=False,
             serial_file=generatedjson,
         )

         # compare to target
-        target_data = deserialize_data(targetjson)
+        target_data = deserialize_data(targetjson, filetype=".json")
         assert target_data == generated_data
         # ensure file saved properly
         assert target_data == deserialize_data(generatedjson, filetype=".json")
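The save-and-reload round trip this test exercises, as a standalone sketch (file names are hypothetical):

```python
# Serialize into a .json database file, then deserialize and compare.
from diffpy.utils.parsers.loaddata import loadData
from diffpy.utils.parsers.serialization import deserialize_data, serialize_data

hdata = loadData("sample.gr", headers=True)
data_table = loadData("sample.gr")

generated = serialize_data(
    "sample.gr",
    hdata,
    data_table,
    dt_col_names=["r", "gr"],
    show_path=False,
    serial_file="generated.json",  # creates or updates this file
)
assert generated == deserialize_data("generated.json", filetype=".json")
```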


def test_exceptions(datafile):
@@ -60,24 +59,24 @@ def test_exceptions(datafile):

     # various dt_colnames inputs
     with pytest.raises(ImproperSizeError):
-        serialize_data(loadfile, hdata, data_table, dt_colnames=["one", "two", "three is too many"])
+        serialize_data(loadfile, hdata, data_table, dt_col_names=["one", "two", "three is too many"])
     # check proper output
-    normal = serialize_data(loadfile, hdata, data_table, dt_colnames=["r", "gr"])
+    normal = serialize_data(loadfile, hdata, data_table, dt_col_names=["r", "gr"])
     data_name = list(normal.keys())[0]
     r_list = normal[data_name]["r"]
     gr_list = normal[data_name]["gr"]
     # three equivalent ways to denote no column names
     missing_parameter = serialize_data(loadfile, hdata, data_table, show_path=False)
-    empty_parameter = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=[])
-    none_entry_parameter = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=[None, None])
+    empty_parameter = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=[])
+    none_entry_parameter = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=[None, None])
     # check equivalence
     assert missing_parameter == empty_parameter
     assert missing_parameter == none_entry_parameter
     assert numpy.allclose(missing_parameter[data_name]["data table"], data_table)
     # extract a single column
-    r_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=["r"])
-    gr_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=[None, "gr"])
-    incorrect_r_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=[None, "r"])
+    r_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=["r"])
+    gr_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=[None, "gr"])
+    incorrect_r_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=[None, "r"])
     # check proper columns extracted
     assert numpy.allclose(gr_extract[data_name]["gr"], incorrect_r_extract[data_name]["r"])
     assert "r" not in gr_extract[data_name]
@@ -101,7 +100,7 @@ def test_exceptions(datafile):
         hdata,
         data_table,
         show_path=False,
-        dt_colnames=["c1", "c2", "c3"],
+        dt_col_names=["c1", "c2", "c3"],
     )
     assert len(record) == 4
     for msg in record: