Skip to content

Commit b4e81f7

Browse files
committed
Fix dict extend validation
1 parent 2afab37 commit b4e81f7

2 files changed

Lines changed: 55 additions & 3 deletions

File tree

src/blosc2/ctable.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5699,6 +5699,8 @@ def extend(self, data: list | CTable | Any, *, validate: bool | None = None) ->
56995699
if self.base is not None:
57005700
raise TypeError("Cannot extend view.")
57015701
if len(data) <= 0:
5702+
if isinstance(data, dict):
5703+
raise ValueError("No columns provided for extend().")
57025704
return
57035705

57045706
# Resolve effective validate flag: per-call override takes precedence
@@ -5720,9 +5722,25 @@ def extend(self, data: list | CTable | Any, *, validate: bool | None = None) ->
57205722
provided_names.add(name)
57215723
else:
57225724
if isinstance(data, dict):
5723-
provided_names = set(data) & set(current_col_names)
5724-
new_nrows = len(next(iter(data.values())))
5725-
raw_columns = {name: data[name] for name in provided_names}
5725+
known_names = [name for name in current_col_names if name in data]
5726+
if not known_names:
5727+
raise ValueError("No known stored columns provided for extend().")
5728+
column_lengths = {}
5729+
for name in known_names:
5730+
try:
5731+
column_lengths[name] = len(data[name])
5732+
except TypeError as exc:
5733+
raise TypeError(f"Column {name!r} does not have a length.") from exc
5734+
new_nrows = column_lengths[known_names[0]]
5735+
mismatched = {name: n for name, n in column_lengths.items() if n != new_nrows}
5736+
if mismatched:
5737+
details = ", ".join(f"{name}={n}" for name, n in mismatched.items())
5738+
raise ValueError(
5739+
f"All provided columns must have the same length; "
5740+
f"expected {new_nrows}, got {details}."
5741+
)
5742+
provided_names = set(known_names)
5743+
raw_columns = {name: data[name] for name in known_names}
57265744
elif isinstance(data, np.ndarray) and data.dtype.names is not None:
57275745
new_nrows = len(data)
57285746
raw_columns = {name: data[name] for name in data.dtype.names if name in current_col_names}

tests/ctable/test_extend_delete.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,40 @@ def assert_data_at_positions(table: CTable, positions: list, expected_ids: list)
5050
# -------------------------------------------------------------------
5151

5252

53+
def test_extend_dict_rejects_mismatched_column_lengths():
    """extend() raises ValueError when the dict's columns differ in length."""
    table = CTable(Row, expected_size=8)
    # "c_val" carries a single value while every other column carries two.
    ragged_columns = {
        "id": [1, 2],
        "c_val": [0j],
        "score": [10.0, 20.0],
        "active": [True, False],
    }
    with pytest.raises(ValueError, match="same length"):
        table.extend(ragged_columns)
64+
65+
66+
def test_extend_dict_rejects_no_known_columns():
    """extend() raises ValueError when the dict holds only an unknown column."""
    table = CTable(Row, expected_size=8)
    only_unknown = {"unknown": [1, 2]}
    with pytest.raises(ValueError, match="No known stored columns"):
        table.extend(only_unknown)
70+
71+
72+
def test_extend_dict_uses_known_column_lengths():
    """An extra "unknown" key of a different length does not affect the row count."""
    table = CTable(Row, expected_size=8)
    # "unknown" is deliberately the first key and has length 1, while the
    # remaining columns all have length 2.
    columns = {
        "unknown": [0],
        "id": [1, 2],
        "c_val": [0j, 1j],
        "score": [10.0, 20.0],
        "active": [True, False],
    }
    table.extend(columns)
    assert len(table) == 2
    assert list(table["id"][:]) == [1, 2]
85+
86+
5387
def test_gap_fill_mask_and_positions():
5488
"""extend and append fill from last valid position; mask is updated correctly."""
5589
# extend after deletions: mask and physical positions

0 commit comments

Comments
 (0)