1313
1414logger = logging .getLogger (__name__ )
1515
16+ # Maps the 'split_by' argument to the actual char used to split the Documents.
17+ # 'function' is not in the mapping because it doesn't split on characters.
18+ _SPLIT_BY_MAPPING = {"page" : "\f " , "passage" : "\n \n " , "sentence" : "." , "word" : " " , "line" : "\n " }
19+
1620
1721@component
1822class DocumentSplitter :
@@ -73,7 +77,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
7377
7478 self .split_by = split_by
7579 if split_by not in ["function" , "page" , "passage" , "sentence" , "word" , "line" ]:
76- raise ValueError ("split_by must be one of 'word', 'sentence', 'page', 'passage' or 'line'." )
80+ raise ValueError ("split_by must be one of 'function', ' word', 'sentence', 'page', 'passage' or 'line'." )
7781 if split_by == "function" and splitting_function is None :
7882 raise ValueError ("When 'split_by' is set to 'function', a valid 'splitting_function' must be provided." )
7983 if split_length <= 0 :
@@ -108,7 +112,7 @@ def run(self, documents: List[Document]):
108112 if not isinstance (documents , list ) or (documents and not isinstance (documents [0 ], Document )):
109113 raise TypeError ("DocumentSplitter expects a List of Documents as input." )
110114
111- split_docs = []
115+ split_docs : List [ Document ] = []
112116 for doc in documents :
113117 if doc .content is None :
114118 raise ValueError (
@@ -117,42 +121,38 @@ def run(self, documents: List[Document]):
117121 if doc .content == "" :
118122 logger .warning ("Document ID {doc_id} has an empty content. Skipping this document." , doc_id = doc .id )
119123 continue
120- units = self ._split_into_units (doc .content , self .split_by )
121- text_splits , splits_pages , splits_start_idxs = self ._concatenate_units (
122- units , self .split_length , self .split_overlap , self .split_threshold
123- )
124- metadata = deepcopy (doc .meta )
125- metadata ["source_id" ] = doc .id
126- split_docs += self ._create_docs_from_splits (
127- text_splits = text_splits , splits_pages = splits_pages , splits_start_idxs = splits_start_idxs , meta = metadata
128- )
124+ split_docs += self ._split (doc )
129125 return {"documents" : split_docs }
130126
131- def _split_into_units (
132- self , text : str , split_by : Literal ["function" , "page" , "passage" , "sentence" , "word" , "line" ]
133- ) -> List [str ]:
134- if split_by == "page" :
135- self .split_at = "\f "
136- elif split_by == "passage" :
137- self .split_at = "\n \n "
138- elif split_by == "sentence" :
139- self .split_at = "."
140- elif split_by == "word" :
141- self .split_at = " "
142- elif split_by == "line" :
143- self .split_at = "\n "
144- elif split_by == "function" and self .splitting_function is not None :
145- return self .splitting_function (text )
146- else :
147- raise NotImplementedError (
148- """DocumentSplitter only supports 'function', 'line', 'page',
149- 'passage', 'sentence' or 'word' split_by options."""
150- )
151- units = text .split (self .split_at )
def _split(self, to_split: Document) -> List[Document]:
    """
    Splits a single Document into a list of smaller Documents.

    When `split_by` is "function" and a `splitting_function` is set, each string returned by that
    function becomes one Document. Otherwise the content is cut at the delimiter mapped to `split_by`,
    the resulting units are grouped by `_concatenate_units`, and the groups are turned into Documents
    by `_create_docs_from_splits`.

    Every produced Document gets its own deep copy of the source metadata plus a `source_id` entry
    holding the id of the source Document.

    :param to_split: The Document to split.
    :returns: The Documents created from `to_split`; an empty list if its content is None.
    """
    text = to_split.content
    # run() rejects None content before calling us; this guard only narrows the type for linters.
    if text is None:
        return []

    def source_meta():
        # A fresh deep copy per call, so produced Documents never share mutable metadata.
        copied = deepcopy(to_split.meta)
        copied["source_id"] = to_split.id
        return copied

    if self.split_by == "function" and self.splitting_function is not None:
        return [Document(content=piece, meta=source_meta()) for piece in self.splitting_function(text)]

    delimiter = _SPLIT_BY_MAPPING[self.split_by]
    # str.split drops the delimiter: re-attach it to every unit except the last one.
    *leading, trailing = text.split(delimiter)
    units = [part + delimiter for part in leading] + [trailing]

    text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
        units, self.split_length, self.split_overlap, self.split_threshold
    )
    return self._create_docs_from_splits(
        text_splits=text_splits,
        splits_pages=splits_pages,
        splits_start_idxs=splits_start_idxs,
        meta=source_meta(),
    )
156156
157157 def _concatenate_units (
158158 self , elements : List [str ], split_length : int , split_overlap : int , split_threshold : int
@@ -166,8 +166,8 @@ def _concatenate_units(
166166 """
167167
168168 text_splits : List [str ] = []
169- splits_pages = []
170- splits_start_idxs = []
169+ splits_pages : List [ int ] = []
170+ splits_start_idxs : List [ int ] = []
171171 cur_start_idx = 0
172172 cur_page = 1
173173 segments = windowed (elements , n = split_length , step = split_length - split_overlap )
@@ -200,7 +200,7 @@ def _concatenate_units(
200200 return text_splits , splits_pages , splits_start_idxs
201201
202202 def _create_docs_from_splits (
203- self , text_splits : List [str ], splits_pages : List [int ], splits_start_idxs : List [int ], meta : Dict
203+ self , text_splits : List [str ], splits_pages : List [int ], splits_start_idxs : List [int ], meta : Dict [ str , Any ]
204204 ) -> List [Document ]:
205205 """
206206 Creates Document objects from splits enriching them with page number and the metadata of the original document.
0 commit comments