Skip to content

Commit

Permalink
Adding more relevant metadata to match what's provided in our other sp…
Browse files Browse the repository at this point in the history
…litters
  • Loading branch information
sjrl committed Feb 7, 2025
1 parent 1003904 commit 626b6ba
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 8 deletions.
19 changes: 17 additions & 2 deletions haystack/components/preprocessors/csv_document_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
:return:
A dictionary with a key `"documents"`, mapping to a list of new `Document` objects,
each representing an extracted sub-table from the original CSV.
The metadata of each document includes:
- A field `source_id` to track the original document.
- A field `row_idx_start` to indicate the starting row index of the sub-table in the original table.
- A field `col_idx_start` to indicate the starting column index of the sub-table in the original table.
- A field `split_id` to indicate the order of the split in the original document.
- All other metadata copied from the original document.
- If a document cannot be processed, it is returned unchanged.
- The `meta` field from the original document is preserved in the split documents.
Expand Down Expand Up @@ -93,11 +99,20 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
column_split_threshold=self.column_split_threshold,
)

for split_df in split_dfs:
# Sort split_dfs first by row index, then by column index
split_dfs.sort(key=lambda dataframe: (dataframe.index[0], dataframe.columns[0]))

for split_id, split_df in enumerate(split_dfs):
split_documents.append(
Document(
content=split_df.to_csv(index=False, header=False, lineterminator="\n"),
meta=document.meta.copy(),
meta={
**document.meta.copy(),
"source_id": document.id,
"row_idx_start": int(split_df.index[0]),
"col_idx_start": int(split_df.columns[0]),
"split_id": split_id,
},
)
)

Expand Down
36 changes: 30 additions & 6 deletions test/components/preprocessors/test_csv_document_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,26 +107,37 @@ def test_single_table_no_split(self, splitter: CSVDocumentSplitter) -> None:
1,2,3
4,5,6
"""
doc = Document(content=csv_content)
doc = Document(content=csv_content, id="test_id")
result = splitter.run([doc])["documents"]
assert len(result) == 1
assert result[0].content == csv_content
assert result[0].meta == {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}

def test_row_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_rows: str) -> None:
doc = Document(content=two_tables_sep_by_two_empty_rows)
doc = Document(content=two_tables_sep_by_two_empty_rows, id="test_id")
result = splitter.run([doc])["documents"]
assert len(result) == 2
expected_tables = ["A,B,C\n1,2,3\n", "X,Y,Z\n7,8,9\n"]
expected_meta = [
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
{"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 1},
]
for i, table in enumerate(result):
assert table.content == expected_tables[i]
assert table.meta == expected_meta[i]

def test_column_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_columns: str) -> None:
doc = Document(content=two_tables_sep_by_two_empty_columns)
doc = Document(content=two_tables_sep_by_two_empty_columns, id="test_id")
result = splitter.run([doc])["documents"]
assert len(result) == 2
expected_tables = ["A,B\n1,2\n3,4\n", "X,Y\n7,8\n9,10\n"]
expected_meta = [
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
]
for i, table in enumerate(result):
assert table.content == expected_tables[i]
assert table.meta == expected_meta[i]

def test_recursive_split_one_level(self, splitter: CSVDocumentSplitter) -> None:
csv_content = """A,B,,,X,Y
Expand All @@ -136,12 +147,19 @@ def test_recursive_split_one_level(self, splitter: CSVDocumentSplitter) -> None:
P,Q,,,M,N
3,4,,,9,10
"""
doc = Document(content=csv_content)
doc = Document(content=csv_content, id="test_id")
result = splitter.run([doc])["documents"]
assert len(result) == 4
expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\n", "P,Q\n3,4\n", "M,N\n9,10\n"]
expected_meta = [
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
{"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
{"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 4, "split_id": 3},
]
for i, table in enumerate(result):
assert table.content == expected_tables[i]
assert table.meta == expected_meta[i]

def test_recursive_split_two_levels(self, splitter: CSVDocumentSplitter) -> None:
csv_content = """A,B,,,X,Y
Expand All @@ -151,12 +169,18 @@ def test_recursive_split_two_levels(self, splitter: CSVDocumentSplitter) -> None
P,Q,,,,
3,4,,,,
"""
doc = Document(content=csv_content)
doc = Document(content=csv_content, id="test_id")
result = splitter.run([doc])["documents"]
assert len(result) == 3
expected_tables = ["A,B\n1,2\n", "P,Q\n3,4\n", "X,Y\n7,8\nM,N\n9,10\n"]
expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\nM,N\n9,10\n", "P,Q\n3,4\n"]
expected_meta = [
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
{"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
{"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
]
for i, table in enumerate(result):
assert table.content == expected_tables[i]
assert table.meta == expected_meta[i]

def test_threshold_no_effect(self, two_tables_sep_by_two_empty_rows: str) -> None:
splitter = CSVDocumentSplitter(row_split_threshold=3)
Expand Down

0 comments on commit 626b6ba

Please sign in to comment.