diff --git a/src/diffpy/utils/parsers/serialization.py b/src/diffpy/utils/parsers/serialization.py
index 46d4b8ff..93062f9b 100644
--- a/src/diffpy/utils/parsers/serialization.py
+++ b/src/diffpy/utils/parsers/serialization.py
@@ -14,8 +14,8 @@
 ##############################################################################
 
 import json
-import pathlib
 import warnings
+from pathlib import Path
 
 import numpy
 
@@ -26,10 +26,10 @@
 
 
 def serialize_data(
-    filename,
+    filepath,
     hdata: dict,
     data_table,
-    dt_colnames=None,
+    dt_col_names=None,
     show_path=True,
     serial_file=None,
 ):
@@ -40,13 +40,13 @@ def serialize_data(
 
     Parameters
     ----------
-    filename
-        Name of the file whose data is being serialized.
+    filepath
+        The file path whose data is being serialized.
     hdata: dict
-        File metadata (generally related to data table).
+        The file metadata (generally related to data table).
     data_table: list or ndarray
-        Data table.
-    dt_colnames: list
+        The data table.
+    dt_col_names: list
         Names of each column in data_table. Every name in data_table_cols will be put into the Dictionary as a key
         with a value of that column in data_table (stored as a List). Put None for columns without names. If
         dt_cols has less non-None entries than columns in data_table,
@@ -64,44 +64,49 @@ def serialize_data(
     Returns the dictionary loaded from/into the updated database file.
     """
 
-    # compile data_table and hddata together
+    # Combine data_table and hdata into data
    data = {}
 
-    # handle getting name of file for variety of filename types
-    abs_path = pathlib.Path(filename).resolve()
-    # add path to start of data if requested
+    # Handle getting name of file for variety of filename types
+    abs_path = Path(filepath).resolve()
+
+    # Add path to start of data if show_path is True
     if show_path and "path" not in hdata.keys():
         data.update({"path": abs_path.as_posix()})
 
-    # title the entry with name of file (taken from end of path)
-    title = abs_path.name
-    # first add data in hddata dict
+    # Add hdata to data
     data.update(hdata)
 
-    # second add named columns in dt_cols
-    # performed second to prioritize overwriting hdata entries with data_table column entries
-    named_columns = 0  # initial value
-    max_columns = 1  # higher than named_columns to trigger 'data table' entry
-    if dt_colnames is not None:
-        num_columns = [len(row) for row in data_table]
-        max_columns = max(num_columns)
-        num_col_names = len(dt_colnames)
-        if max_columns < num_col_names:  # assume numpy.loadtxt gives non-irregular array
-            raise ImproperSizeError("More entries in dt_colnames than columns in data_table.")
-        named_columns = 0
-        for idx in range(num_col_names):
-            colname = dt_colnames[idx]
-            if colname is not None:
-                if colname in hdata.keys():
+    # Prioritize overwriting hdata entries with data_table column entries
+    col_counter = 0
+
+    # Get a list of column counts in each entry in data table
+    dt_col_counts = [len(row) for row in data_table]
+    dt_max_col_count = max(dt_col_counts)
+
+    if dt_col_names is not None:
+        dt_col_names_count = len(dt_col_names)
+        if dt_max_col_count < dt_col_names_count:  # assume numpy.loadtxt gives non-irregular array
+            raise ImproperSizeError("More entries in dt_col_names than columns in data_table.")
+
+        for idx in range(dt_col_names_count):
+            col_name = dt_col_names[idx]
+
+            if col_name is not None:
+
+                # Check if column name already exists in hdata
+                if col_name in hdata.keys():
                     warnings.warn(
-                        f"Entry '{colname}' in hdata has been overwritten by a data_table entry.",
+                        f"Entry '{col_name}' in hdata has been overwritten by a data_table entry.",
                         RuntimeWarning,
                     )
 
-                data.update({colname: list(data_table[:, idx])})
-                named_columns += 1
-    # finally add data_table as an entry named 'data table' if not all columns were parsed
-    if named_columns < max_columns:
+                # Add row data per column to data
+                data.update({col_name: list(data_table[:, idx])})
+                col_counter += 1
+
+    # Add data_table as an entry named 'data table' if not all columns were parsed
+    if col_counter < dt_max_col_count:
         if "data table" in data.keys():
             warnings.warn(
                 "Entry 'data table' in hdata has been overwritten by data_table.",
@@ -109,19 +114,19 @@ def serialize_data(
             )
         data.update({"data table": data_table})
 
-    # parse name using pathlib and generate dictionary entry
-    entry = {title: data}
+    # Parse name using pathlib and generate dictionary entry
+    data_key = abs_path.name
+    entry = {data_key: data}
 
-    # no save
     if serial_file is None:
         return entry
 
     # saving/updating file
     # check if supported type
-    sf = pathlib.Path(serial_file)
-    sf_name = sf.name
-    extension = sf.suffix
-    if extension not in supported_formats:
+    sf_path = Path(serial_file)
+    sf_name = sf_path.name
+    sf_ext = sf_path.suffix
+    if sf_ext not in supported_formats:
         raise UnsupportedTypeError(sf_name, supported_formats)
 
     # new file or update
@@ -133,7 +138,7 @@ def serialize_data(
         pass
 
     # json
-    if extension == ".json":
+    if sf_ext == ".json":
         # cannot serialize numpy arrays
         class NumpyEncoder(json.JSONEncoder):
             def default(self, data_obj):
@@ -177,7 +182,7 @@ def deserialize_data(filename, filetype=None):
     """
 
     # check if supported type
-    f = pathlib.Path(filename)
+    f = Path(filename)
     f_name = f.name
 
     if filetype is None:
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index 0ba397ad..17e00843 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -17,13 +17,14 @@ def test_load_multiple(tmp_path, datafile):
     tlm_list = sorted(dbload_dir.glob("*.gr"))
 
     generated_data = None
+
     for headerfile in tlm_list:
         # gather data using loadData
         hdata = loadData(headerfile, headers=True)
         data_table = loadData(headerfile)
 
         # check path extraction
-        generated_data = serialize_data(headerfile, hdata, data_table, dt_colnames=["r", "gr"], show_path=True)
+        generated_data = serialize_data(headerfile, hdata, data_table, dt_col_names=["r", "gr"], show_path=True)
         assert headerfile == Path(generated_data[headerfile.name].pop("path"))
 
         # rerun without path information and save to file
@@ -31,16 +32,14 @@ def test_load_multiple(tmp_path, datafile):
            headerfile,
            hdata,
            data_table,
-            dt_colnames=["r", "gr"],
+            dt_col_names=["r", "gr"],
            show_path=False,
            serial_file=generatedjson,
        )
 
     # compare to target
-    target_data = deserialize_data(targetjson)
+    target_data = deserialize_data(targetjson, filetype=".json")
     assert target_data == generated_data
-    # ensure file saved properly
-    assert target_data == deserialize_data(generatedjson, filetype=".json")
 
 
 def test_exceptions(datafile):
@@ -60,24 +59,24 @@ def test_exceptions(datafile):
     # various dt_colnames inputs
     with pytest.raises(ImproperSizeError):
-        serialize_data(loadfile, hdata, data_table, dt_colnames=["one", "two", "three is too many"])
+        serialize_data(loadfile, hdata, data_table, dt_col_names=["one", "two", "three is too many"])
 
     # check proper output
-    normal = serialize_data(loadfile, hdata, data_table, dt_colnames=["r", "gr"])
+    normal = serialize_data(loadfile, hdata, data_table, dt_col_names=["r", "gr"])
     data_name = list(normal.keys())[0]
     r_list = normal[data_name]["r"]
     gr_list = normal[data_name]["gr"]
     # three equivalent ways to denote no column names
     missing_parameter = serialize_data(loadfile, hdata, data_table, show_path=False)
-    empty_parameter = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=[])
-    none_entry_parameter = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=[None, None])
+    empty_parameter = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=[])
+    none_entry_parameter = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=[None, None])
     # check equivalence
     assert missing_parameter == empty_parameter
     assert missing_parameter == none_entry_parameter
     assert numpy.allclose(missing_parameter[data_name]["data table"], data_table)
     # extract a single column
-    r_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=["r"])
-    gr_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=[None, "gr"])
-    incorrect_r_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=[None, "r"])
+    r_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=["r"])
+    gr_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=[None, "gr"])
+    incorrect_r_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=[None, "r"])
     # check proper columns extracted
     assert numpy.allclose(gr_extract[data_name]["gr"], incorrect_r_extract[data_name]["r"])
     assert "r" not in gr_extract[data_name]
@@ -101,7 +100,7 @@ def test_exceptions(datafile):
            hdata,
            data_table,
            show_path=False,
-            dt_colnames=["c1", "c2", "c3"],
+            dt_col_names=["c1", "c2", "c3"],
        )
     assert len(record) == 4
     for msg in record:
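
For review context, a minimal round-trip sketch of the renamed API. This is illustrative only: the path sample.gr, the metadata dict, and the db.json output file are invented for the example (in practice hdata and data_table come from loadData, as in the tests above), and the final asserts assume db.json does not already exist.

import numpy

from diffpy.utils.parsers.serialization import deserialize_data, serialize_data

# Illustrative inputs; real metadata and tables come from loadData.
hdata = {"temperature": 300}
data_table = numpy.array([[0.1, 1.0], [0.2, 2.0], [0.3, 3.0]])

# Both columns are named, so col_counter == dt_max_col_count and no
# fallback "data table" ndarray entry is created; the result stays
# JSON-serializable.
entry = serialize_data(
    "sample.gr",  # illustrative path; it is only resolved, never opened
    hdata,
    data_table,
    dt_col_names=["r", "gr"],
    show_path=False,
    serial_file="db.json",  # .json is a supported serial format
)

# The entry is keyed by the file name taken from the end of the path.
assert list(entry["sample.gr"]["r"]) == [0.1, 0.2, 0.3]

# Round-trip through the serial file; assumes a fresh db.json.
assert deserialize_data("db.json", filetype=".json") == entry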