From 626b6ba0f47db9fc5312beabe30379b74b743711 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Fri, 7 Feb 2025 14:29:11 +0100 Subject: [PATCH] Adding more relevant metadata to match whats provided in our other splitters --- .../preprocessors/csv_document_splitter.py | 19 ++++++++-- .../test_csv_document_splitter.py | 36 +++++++++++++++---- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/haystack/components/preprocessors/csv_document_splitter.py b/haystack/components/preprocessors/csv_document_splitter.py index c62f1c2655..98719a3c58 100644 --- a/haystack/components/preprocessors/csv_document_splitter.py +++ b/haystack/components/preprocessors/csv_document_splitter.py @@ -63,6 +63,12 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]: :return: A dictionary with a key `"documents"`, mapping to a list of new `Document` objects, each representing an extracted sub-table from the original CSV. + The metadata of each document includes: + - A field `source_id` to track the original document. + - A field `row_idx_start` to indicate the starting row index of the sub-table in the original table. + - A field `col_idx_start` to indicate the starting column index of the sub-table in the original table. + - A field `split_id` to indicate the order of the split in the original document. + - All other metadata copied from the original document. - If a document cannot be processed, it is returned unchanged. - The `meta` field from the original document is preserved in the split documents. @@ -93,11 +99,20 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]: column_split_threshold=self.column_split_threshold, ) - for split_df in split_dfs: + # Sort split_dfs first by row index, then by column index + split_dfs.sort(key=lambda dataframe: (dataframe.index[0], dataframe.columns[0])) + + for split_id, split_df in enumerate(split_dfs): split_documents.append( Document( content=split_df.to_csv(index=False, header=False, lineterminator="\n"), - meta=document.meta.copy(), + meta={ + **document.meta.copy(), + "source_id": document.id, + "row_idx_start": int(split_df.index[0]), + "col_idx_start": int(split_df.columns[0]), + "split_id": split_id, + }, ) ) diff --git a/test/components/preprocessors/test_csv_document_splitter.py b/test/components/preprocessors/test_csv_document_splitter.py index bf44357e88..697178539a 100644 --- a/test/components/preprocessors/test_csv_document_splitter.py +++ b/test/components/preprocessors/test_csv_document_splitter.py @@ -107,26 +107,37 @@ def test_single_table_no_split(self, splitter: CSVDocumentSplitter) -> None: 1,2,3 4,5,6 """ - doc = Document(content=csv_content) + doc = Document(content=csv_content, id="test_id") result = splitter.run([doc])["documents"] assert len(result) == 1 assert result[0].content == csv_content + assert result[0].meta == {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0} def test_row_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_rows: str) -> None: - doc = Document(content=two_tables_sep_by_two_empty_rows) + doc = Document(content=two_tables_sep_by_two_empty_rows, id="test_id") result = splitter.run([doc])["documents"] assert len(result) == 2 expected_tables = ["A,B,C\n1,2,3\n", "X,Y,Z\n7,8,9\n"] + expected_meta = [ + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}, + {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 1}, + ] for i, table in enumerate(result): assert table.content == expected_tables[i] + assert table.meta == expected_meta[i] def test_column_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_columns: str) -> None: - doc = Document(content=two_tables_sep_by_two_empty_columns) + doc = Document(content=two_tables_sep_by_two_empty_columns, id="test_id") result = splitter.run([doc])["documents"] assert len(result) == 2 expected_tables = ["A,B\n1,2\n3,4\n", "X,Y\n7,8\n9,10\n"] + expected_meta = [ + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}, + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1}, + ] for i, table in enumerate(result): assert table.content == expected_tables[i] + assert table.meta == expected_meta[i] def test_recursive_split_one_level(self, splitter: CSVDocumentSplitter) -> None: csv_content = """A,B,,,X,Y @@ -136,12 +147,19 @@ def test_recursive_split_one_level(self, splitter: CSVDocumentSplitter) -> None: P,Q,,,M,N 3,4,,,9,10 """ - doc = Document(content=csv_content) + doc = Document(content=csv_content, id="test_id") result = splitter.run([doc])["documents"] assert len(result) == 4 expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\n", "P,Q\n3,4\n", "M,N\n9,10\n"] + expected_meta = [ + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}, + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1}, + {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2}, + {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 4, "split_id": 3}, + ] for i, table in enumerate(result): assert table.content == expected_tables[i] + assert table.meta == expected_meta[i] def test_recursive_split_two_levels(self, splitter: CSVDocumentSplitter) -> None: csv_content = """A,B,,,X,Y @@ -151,12 +169,18 @@ def test_recursive_split_two_levels(self, splitter: CSVDocumentSplitter) -> None P,Q,,,, 3,4,,,, """ - doc = Document(content=csv_content) + doc = Document(content=csv_content, id="test_id") result = splitter.run([doc])["documents"] assert len(result) == 3 - expected_tables = ["A,B\n1,2\n", "P,Q\n3,4\n", "X,Y\n7,8\nM,N\n9,10\n"] + expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\nM,N\n9,10\n", "P,Q\n3,4\n"] + expected_meta = [ + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}, + {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1}, + {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2}, + ] for i, table in enumerate(result): assert table.content == expected_tables[i] + assert table.meta == expected_meta[i] def test_threshold_no_effect(self, two_tables_sep_by_two_empty_rows: str) -> None: splitter = CSVDocumentSplitter(row_split_threshold=3)