Skip to content

Commit 626b6ba

Browse files
committed
Adding more relevant metadata to match what's provided in our other splitters
1 parent 1003904 commit 626b6ba

File tree

2 files changed

+47
-8
lines changed

2 files changed

+47
-8
lines changed

haystack/components/preprocessors/csv_document_splitter.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
6363
:return:
6464
A dictionary with a key `"documents"`, mapping to a list of new `Document` objects,
6565
each representing an extracted sub-table from the original CSV.
66+
The metadata of each document includes:
67+
- A field `source_id` to track the original document.
68+
- A field `row_idx_start` to indicate the starting row index of the sub-table in the original table.
69+
- A field `col_idx_start` to indicate the starting column index of the sub-table in the original table.
70+
- A field `split_id` to indicate the order of the split in the original document.
71+
- All other metadata copied from the original document.
6672
6773
- If a document cannot be processed, it is returned unchanged.
6874
- The `meta` field from the original document is preserved in the split documents.
@@ -93,11 +99,20 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
9399
column_split_threshold=self.column_split_threshold,
94100
)
95101

96-
for split_df in split_dfs:
102+
# Sort split_dfs first by row index, then by column index
103+
split_dfs.sort(key=lambda dataframe: (dataframe.index[0], dataframe.columns[0]))
104+
105+
for split_id, split_df in enumerate(split_dfs):
97106
split_documents.append(
98107
Document(
99108
content=split_df.to_csv(index=False, header=False, lineterminator="\n"),
100-
meta=document.meta.copy(),
109+
meta={
110+
**document.meta.copy(),
111+
"source_id": document.id,
112+
"row_idx_start": int(split_df.index[0]),
113+
"col_idx_start": int(split_df.columns[0]),
114+
"split_id": split_id,
115+
},
101116
)
102117
)
103118

test/components/preprocessors/test_csv_document_splitter.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,26 +107,37 @@ def test_single_table_no_split(self, splitter: CSVDocumentSplitter) -> None:
107107
1,2,3
108108
4,5,6
109109
"""
110-
doc = Document(content=csv_content)
110+
doc = Document(content=csv_content, id="test_id")
111111
result = splitter.run([doc])["documents"]
112112
assert len(result) == 1
113113
assert result[0].content == csv_content
114+
assert result[0].meta == {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}
114115

115116
def test_row_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_rows: str) -> None:
116-
doc = Document(content=two_tables_sep_by_two_empty_rows)
117+
doc = Document(content=two_tables_sep_by_two_empty_rows, id="test_id")
117118
result = splitter.run([doc])["documents"]
118119
assert len(result) == 2
119120
expected_tables = ["A,B,C\n1,2,3\n", "X,Y,Z\n7,8,9\n"]
121+
expected_meta = [
122+
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
123+
{"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 1},
124+
]
120125
for i, table in enumerate(result):
121126
assert table.content == expected_tables[i]
127+
assert table.meta == expected_meta[i]
122128

123129
def test_column_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_columns: str) -> None:
124-
doc = Document(content=two_tables_sep_by_two_empty_columns)
130+
doc = Document(content=two_tables_sep_by_two_empty_columns, id="test_id")
125131
result = splitter.run([doc])["documents"]
126132
assert len(result) == 2
127133
expected_tables = ["A,B\n1,2\n3,4\n", "X,Y\n7,8\n9,10\n"]
134+
expected_meta = [
135+
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
136+
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
137+
]
128138
for i, table in enumerate(result):
129139
assert table.content == expected_tables[i]
140+
assert table.meta == expected_meta[i]
130141

131142
def test_recursive_split_one_level(self, splitter: CSVDocumentSplitter) -> None:
132143
csv_content = """A,B,,,X,Y
@@ -136,12 +147,19 @@ def test_recursive_split_one_level(self, splitter: CSVDocumentSplitter) -> None:
136147
P,Q,,,M,N
137148
3,4,,,9,10
138149
"""
139-
doc = Document(content=csv_content)
150+
doc = Document(content=csv_content, id="test_id")
140151
result = splitter.run([doc])["documents"]
141152
assert len(result) == 4
142153
expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\n", "P,Q\n3,4\n", "M,N\n9,10\n"]
154+
expected_meta = [
155+
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
156+
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
157+
{"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
158+
{"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 4, "split_id": 3},
159+
]
143160
for i, table in enumerate(result):
144161
assert table.content == expected_tables[i]
162+
assert table.meta == expected_meta[i]
145163

146164
def test_recursive_split_two_levels(self, splitter: CSVDocumentSplitter) -> None:
147165
csv_content = """A,B,,,X,Y
@@ -151,12 +169,18 @@ def test_recursive_split_two_levels(self, splitter: CSVDocumentSplitter) -> None
151169
P,Q,,,,
152170
3,4,,,,
153171
"""
154-
doc = Document(content=csv_content)
172+
doc = Document(content=csv_content, id="test_id")
155173
result = splitter.run([doc])["documents"]
156174
assert len(result) == 3
157-
expected_tables = ["A,B\n1,2\n", "P,Q\n3,4\n", "X,Y\n7,8\nM,N\n9,10\n"]
175+
expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\nM,N\n9,10\n", "P,Q\n3,4\n"]
176+
expected_meta = [
177+
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
178+
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
179+
{"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
180+
]
158181
for i, table in enumerate(result):
159182
assert table.content == expected_tables[i]
183+
assert table.meta == expected_meta[i]
160184

161185
def test_threshold_no_effect(self, two_tables_sep_by_two_empty_rows: str) -> None:
162186
splitter = CSVDocumentSplitter(row_split_threshold=3)

0 commit comments

Comments
 (0)