logger = logging.getLogger(__name__)

+# Maps the 'split_by' argument to the actual char used to split the Documents.
+# 'function' is not in the mapping because it doesn't split on chars.
+_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "sentence": ".", "word": " ", "line": "\n"}
+

@component
class DocumentSplitter:
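Not part of the patch: a quick sketch of what the new module-level mapping resolves to, using only the values added above. The assertions are illustrative assumptions, not test code from this commit.

# Illustration only: the delimiter lookup that replaces the old if/elif chain.
_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "sentence": ".", "word": " ", "line": "\n"}

assert _SPLIT_BY_MAPPING["sentence"] == "."
assert "one two three".split(_SPLIT_BY_MAPPING["word"]) == ["one", "two", "three"]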
@@ -73,7 +77,7 @@ def __init__(  # pylint: disable=too-many-positional-arguments
        self.split_by = split_by
        if split_by not in ["function", "page", "passage", "sentence", "word", "line"]:
-            raise ValueError("split_by must be one of 'word', 'sentence', 'page', 'passage' or 'line'.")
+            raise ValueError("split_by must be one of 'function', 'word', 'sentence', 'page', 'passage' or 'line'.")
        if split_by == "function" and splitting_function is None:
            raise ValueError("When 'split_by' is set to 'function', a valid 'splitting_function' must be provided.")
        if split_length <= 0:
@@ -108,7 +112,7 @@ def run(self, documents: List[Document]):
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            raise TypeError("DocumentSplitter expects a List of Documents as input.")

-        split_docs = []
+        split_docs: List[Document] = []
        for doc in documents:
            if doc.content is None:
                raise ValueError(
@@ -117,42 +121,38 @@ def run(self, documents: List[Document]):
            if doc.content == "":
                logger.warning("Document ID {doc_id} has an empty content. Skipping this document.", doc_id=doc.id)
                continue
-            units = self._split_into_units(doc.content, self.split_by)
-            text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
-                units, self.split_length, self.split_overlap, self.split_threshold
-            )
-            metadata = deepcopy(doc.meta)
-            metadata["source_id"] = doc.id
-            split_docs += self._create_docs_from_splits(
-                text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata
-            )
+            split_docs += self._split(doc)
        return {"documents": split_docs}

-    def _split_into_units(
-        self, text: str, split_by: Literal["function", "page", "passage", "sentence", "word", "line"]
-    ) -> List[str]:
-        if split_by == "page":
-            self.split_at = "\f"
-        elif split_by == "passage":
-            self.split_at = "\n\n"
-        elif split_by == "sentence":
-            self.split_at = "."
-        elif split_by == "word":
-            self.split_at = " "
-        elif split_by == "line":
-            self.split_at = "\n"
-        elif split_by == "function" and self.splitting_function is not None:
-            return self.splitting_function(text)
-        else:
-            raise NotImplementedError(
-                """DocumentSplitter only supports 'function', 'line', 'page',
-                'passage', 'sentence' or 'word' split_by options."""
-            )
-        units = text.split(self.split_at)
+    def _split(self, to_split: Document) -> List[Document]:
+        # We already check this before calling _split but
+        # we need to make linters happy
+        if to_split.content is None:
+            return []
+
+        if self.split_by == "function" and self.splitting_function is not None:
+            splits = self.splitting_function(to_split.content)
+            docs: List[Document] = []
+            for s in splits:
+                meta = deepcopy(to_split.meta)
+                meta["source_id"] = to_split.id
+                docs.append(Document(content=s, meta=meta))
+            return docs
+
+        split_at = _SPLIT_BY_MAPPING[self.split_by]
+        units = to_split.content.split(split_at)
        # Add the delimiter back to all units except the last one
        for i in range(len(units) - 1):
-            units[i] += self.split_at
-        return units
+            units[i] += split_at
+
+        text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
+            units, self.split_length, self.split_overlap, self.split_threshold
+        )
+        metadata = deepcopy(to_split.meta)
+        metadata["source_id"] = to_split.id
+        return self._create_docs_from_splits(
+            text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata
+        )

    def _concatenate_units(
        self, elements: List[str], split_length: int, split_overlap: int, split_threshold: int
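Not part of the patch: a hedged sketch of how the split_by="function" branch of the new _split method behaves when driven through run(). The import paths are assumptions (typical Haystack 2.x layout) and are not shown in this diff.

# Illustration only: exercising the splitting_function branch of _split via run().
# Import paths below are assumed, not taken from this commit.
from haystack import Document
from haystack.components.preprocessors import DocumentSplitter

splitter = DocumentSplitter(
    split_by="function",
    splitting_function=lambda text: text.split("###"),  # any callable returning a list of strings
)
result = splitter.run(documents=[Document(content="intro###body###outro")])

# One output Document per split; each carries the parent's meta plus its "source_id".
assert len(result["documents"]) == 3
assert all("source_id" in d.meta for d in result["documents"])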
@@ -166,8 +166,8 @@ def _concatenate_units(
        """

        text_splits: List[str] = []
-        splits_pages = []
-        splits_start_idxs = []
+        splits_pages: List[int] = []
+        splits_start_idxs: List[int] = []
        cur_start_idx = 0
        cur_page = 1
        segments = windowed(elements, n=split_length, step=split_length - split_overlap)
@@ -200,7 +200,7 @@ def _concatenate_units(
        return text_splits, splits_pages, splits_start_idxs

    def _create_docs_from_splits(
-        self, text_splits: List[str], splits_pages: List[int], splits_start_idxs: List[int], meta: Dict
+        self, text_splits: List[str], splits_pages: List[int], splits_start_idxs: List[int], meta: Dict[str, Any]
    ) -> List[Document]:
        """
        Creates Document objects from splits enriching them with page number and the metadata of the original document.
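Closing the section, a minimal end-to-end sketch (not part of the patch) of the word-splitting path that now flows through _SPLIT_BY_MAPPING and _split. Again, the import paths are assumed rather than taken from this diff, and the parameter values are only illustrative.

# Illustration only: word-based splitting through the public run() API.
# Import paths below are assumed, not taken from this commit.
from haystack import Document
from haystack.components.preprocessors import DocumentSplitter

splitter = DocumentSplitter(split_by="word", split_length=5, split_overlap=1)
doc = Document(content="one two three four five six seven eight nine ten")
result = splitter.run(documents=[doc])

for split in result["documents"]:
    # Each split keeps the parent's meta, gains a "source_id", and (per
    # _create_docs_from_splits) is enriched with its page number.
    print(split.meta["source_id"] == doc.id, repr(split.content))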