
Commit dc6a06e

Refactor serialize_data with better param names and logic
1 parent 5f67f6c commit dc6a06e

File tree

2 files changed: +61 −57 lines changed

src/diffpy/utils/parsers/serialization.py

+49-44
@@ -14,7 +14,7 @@
 ##############################################################################
 
 import json
-import pathlib
+from pathlib import Path
 import warnings
 
 import numpy
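
This first hunk swaps the module import for a direct class import, so the call sites below shorten from pathlib.Path(...) to Path(...). A minimal sketch of the equivalence:

    import pathlib
    from pathlib import Path

    # Both spellings produce the same resolved path object.
    assert Path("data.gr").resolve() == pathlib.Path("data.gr").resolve()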
@@ -26,10 +26,10 @@
 
 
 def serialize_data(
-    filename,
+    filepath,
     hdata: dict,
     data_table,
-    dt_colnames=None,
+    dt_col_names=None,
     show_path=True,
     serial_file=None,
 ):
@@ -40,13 +40,13 @@ def serialize_data(
 
     Parameters
     ----------
-    filename
-        Name of the file whose data is being serialized.
+    filepath
+        The file path whose data is being serialized.
     hdata: dict
-        File metadata (generally related to data table).
+        The file metadata (generally related to data table).
     data_table: list or ndarray
-        Data table.
-    dt_colnames: list
+        The data table.
+    dt_col_names: list
         Names of each column in data_table. Every name in data_table_cols will be put into the Dictionary
         as a key with a value of that column in data_table (stored as a List). Put None for columns
         without names. If dt_cols has less non-None entries than columns in data_table,
@@ -64,64 +64,69 @@ def serialize_data(
         Returns the dictionary loaded from/into the updated database file.
     """
 
-    # compile data_table and hddata together
+    # Combine data_table and hdata together in data
     data = {}
 
-    # handle getting name of file for variety of filename types
-    abs_path = pathlib.Path(filename).resolve()
-    # add path to start of data if requested
+    # Handle getting name of file for variety of filename types
+    abs_path = Path(filepath).resolve()
+
+    # Add path to start of data if show_path is True
     if show_path and "path" not in hdata.keys():
         data.update({"path": abs_path.as_posix()})
-    # title the entry with name of file (taken from end of path)
-    title = abs_path.name
 
-    # first add data in hddata dict
+    # Add hdata to data
     data.update(hdata)
 
-    # second add named columns in dt_cols
-    # performed second to prioritize overwriting hdata entries with data_table column entries
-    named_columns = 0  # initial value
-    max_columns = 1  # higher than named_columns to trigger 'data table' entry
-    if dt_colnames is not None:
-        num_columns = [len(row) for row in data_table]
-        max_columns = max(num_columns)
-        num_col_names = len(dt_colnames)
-        if max_columns < num_col_names:  # assume numpy.loadtxt gives non-irregular array
-            raise ImproperSizeError("More entries in dt_colnames than columns in data_table.")
-        named_columns = 0
-        for idx in range(num_col_names):
-            colname = dt_colnames[idx]
-            if colname is not None:
-                if colname in hdata.keys():
+    # Prioritize overwriting hdata entries with data_table column entries
+    col_counter = 0
+
+    # Get a list of column counts in each entry in data table
+    dt_col_counts = [len(row) for row in data_table]
+    dt_max_col_count = max(dt_col_counts)
+
+    if dt_col_names is not None:
+        dt_col_names_count = len(dt_col_names)
+        if dt_max_col_count < dt_col_names_count:  # assume numpy.loadtxt gives non-irregular array
+            raise ImproperSizeError("More entries in dt_col_names than columns in data_table.")
+
+        for idx in range(dt_col_names_count):
+            col_name = dt_col_names[idx]
+
+            if col_name is not None:
+
+                # Check if column name already exists in hdata
+                if col_name in hdata.keys():
                     warnings.warn(
-                        f"Entry '{colname}' in hdata has been overwritten by a data_table entry.",
+                        f"Entry '{col_name}' in hdata has been overwritten by a data_table entry.",
                         RuntimeWarning,
                     )
-                data.update({colname: list(data_table[:, idx])})
-                named_columns += 1
 
-    # finally add data_table as an entry named 'data table' if not all columns were parsed
-    if named_columns < max_columns:
+                # Add row data per column to data
+                data.update({col_name: list(data_table[:, idx])})
+                col_counter += 1
+
+    # Add data_table as an entry named 'data table' if not all columns were parsed
+    if col_counter < dt_max_col_count:
         if "data table" in data.keys():
             warnings.warn(
                 "Entry 'data table' in hdata has been overwritten by data_table.",
                 RuntimeWarning,
             )
         data.update({"data table": data_table})
 
-    # parse name using pathlib and generate dictionary entry
-    entry = {title: data}
+    # Parse name using pathlib and generate dictionary entry
+    data_key = abs_path.name
+    entry = {data_key: data}
 
-    # no save
     if serial_file is None:
         return entry
 
     # saving/updating file
     # check if supported type
-    sf = pathlib.Path(serial_file)
-    sf_name = sf.name
-    extension = sf.suffix
-    if extension not in supported_formats:
+    sf_path = Path(serial_file)
+    sf_name = sf_path.name
+    sf_ext = sf_path.suffix
+    if sf_ext not in supported_formats:
         raise UnsupportedTypeError(sf_name, supported_formats)
 
     # new file or update
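
The core of the refactor is in this hunk: named columns are copied out of data_table one by one, and a bulk 'data table' entry is written only when col_counter never reaches dt_max_col_count. A rough sketch of that behavior, assuming the refactored diffpy.utils is installed (the array, names, and file name here are illustrative, not from the commit):

    import numpy
    from diffpy.utils.parsers.serialization import serialize_data

    data_table = numpy.array([[1.0, 0.1], [2.0, 0.4], [3.0, 0.9]])

    # All columns named: each name keys its column, no 'data table' fallback.
    full = serialize_data("example.gr", {}, data_table, dt_col_names=["r", "gr"], show_path=False)
    assert full["example.gr"]["r"] == [1.0, 2.0, 3.0]

    # Only one column named: the unnamed remainder lands under 'data table'.
    partial = serialize_data("example.gr", {}, data_table, dt_col_names=["r"], show_path=False)
    assert "data table" in partial["example.gr"]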
@@ -133,7 +138,7 @@ def serialize_data(
         pass
 
     # json
-    if extension == ".json":
+    if sf_ext == ".json":
         # cannot serialize numpy arrays
         class NumpyEncoder(json.JSONEncoder):
             def default(self, data_obj):
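
The hunk is truncated before the body of default; the usual shape of this encoder pattern (a sketch of the idiom, not necessarily the exact body in this file) is to turn ndarrays into plain lists before falling back to the stock encoder:

    import json

    import numpy

    class NumpyEncoder(json.JSONEncoder):
        def default(self, data_obj):
            # json cannot serialize numpy arrays; convert them to plain lists.
            if isinstance(data_obj, numpy.ndarray):
                return data_obj.tolist()
            return json.JSONEncoder.default(self, data_obj)

    print(json.dumps({"data table": numpy.arange(3)}, cls=NumpyEncoder))  # {"data table": [0, 1, 2]}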
@@ -177,7 +182,7 @@ def deserialize_data(filename, filetype=None):
     """
 
     # check if supported type
-    f = pathlib.Path(filename)
+    f = Path(filename)
     f_name = f.name
 
     if filetype is None:
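
Putting the renamed pieces together, a round trip through a JSON serial file mirrors what the updated tests below check (sample.gr and db.json are illustrative names, and db.json is assumed not to exist yet):

    import numpy
    from diffpy.utils.parsers.serialization import deserialize_data, serialize_data

    hdata = {"temperature": 300}
    data_table = numpy.array([[1.0, 0.1], [2.0, 0.4]])

    # Write the entry into a fresh db.json, then read the database back.
    entry = serialize_data(
        "sample.gr", hdata, data_table, dt_col_names=["r", "gr"], show_path=False, serial_file="db.json"
    )
    assert deserialize_data("db.json", filetype=".json") == entry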

tests/test_serialization.py

+12-13
@@ -17,30 +17,29 @@ def test_load_multiple(tmp_path, datafile):
     tlm_list = sorted(dbload_dir.glob("*.gr"))
 
     generated_data = None
+
     for headerfile in tlm_list:
         # gather data using loadData
         hdata = loadData(headerfile, headers=True)
         data_table = loadData(headerfile)
 
         # check path extraction
-        generated_data = serialize_data(headerfile, hdata, data_table, dt_colnames=["r", "gr"], show_path=True)
+        generated_data = serialize_data(headerfile, hdata, data_table, dt_col_names=["r", "gr"], show_path=True)
         assert headerfile == Path(generated_data[headerfile.name].pop("path"))
 
         # rerun without path information and save to file
         generated_data = serialize_data(
             headerfile,
             hdata,
             data_table,
-            dt_colnames=["r", "gr"],
+            dt_col_names=["r", "gr"],
             show_path=False,
             serial_file=generatedjson,
         )
 
         # compare to target
-        target_data = deserialize_data(targetjson)
+        target_data = deserialize_data(targetjson, filetype=".json")
         assert target_data == generated_data
-        # ensure file saved properly
-        assert target_data == deserialize_data(generatedjson, filetype=".json")
 
 
 def test_exceptions(datafile):
@@ -60,24 +59,24 @@ def test_exceptions(datafile):
 
     # various dt_colnames inputs
     with pytest.raises(ImproperSizeError):
-        serialize_data(loadfile, hdata, data_table, dt_colnames=["one", "two", "three is too many"])
+        serialize_data(loadfile, hdata, data_table, dt_col_names=["one", "two", "three is too many"])
     # check proper output
-    normal = serialize_data(loadfile, hdata, data_table, dt_colnames=["r", "gr"])
+    normal = serialize_data(loadfile, hdata, data_table, dt_col_names=["r", "gr"])
     data_name = list(normal.keys())[0]
     r_list = normal[data_name]["r"]
     gr_list = normal[data_name]["gr"]
     # three equivalent ways to denote no column names
     missing_parameter = serialize_data(loadfile, hdata, data_table, show_path=False)
-    empty_parameter = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=[])
-    none_entry_parameter = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=[None, None])
+    empty_parameter = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=[])
+    none_entry_parameter = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=[None, None])
     # check equivalence
     assert missing_parameter == empty_parameter
     assert missing_parameter == none_entry_parameter
     assert numpy.allclose(missing_parameter[data_name]["data table"], data_table)
     # extract a single column
-    r_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=["r"])
-    gr_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=[None, "gr"])
-    incorrect_r_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_colnames=[None, "r"])
+    r_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=["r"])
+    gr_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=[None, "gr"])
+    incorrect_r_extract = serialize_data(loadfile, hdata, data_table, show_path=False, dt_col_names=[None, "r"])
     # check proper columns extracted
     assert numpy.allclose(gr_extract[data_name]["gr"], incorrect_r_extract[data_name]["r"])
     assert "r" not in gr_extract[data_name]
@@ -101,7 +100,7 @@ def test_exceptions(datafile):
             hdata,
             data_table,
             show_path=False,
-            dt_colnames=["c1", "c2", "c3"],
+            dt_col_names=["c1", "c2", "c3"],
         )
         assert len(record) == 4
         for msg in record: