@@ -23,18 +23,21 @@ struct SqliteLocalState : public LocalTableFunctionState {
23
23
SQLiteStatement stmt;
24
24
bool done = false ;
25
25
vector<column_t > column_ids;
26
+ // ! The amount of rows we scanned as part of this row group
27
+ idx_t scan_count = 1 ;
26
28
27
29
~SqliteLocalState () {
28
30
}
29
31
};
30
32
31
33
struct SqliteGlobalState : public GlobalTableFunctionState {
32
- SqliteGlobalState (idx_t max_threads) : max_threads(max_threads) {
34
+ explicit SqliteGlobalState (idx_t max_threads) : max_threads(max_threads) {
33
35
}
34
36
35
37
mutex lock;
36
38
idx_t position = 0 ;
37
39
idx_t max_threads;
40
+ idx_t rows_per_group = 0 ;
38
41
39
42
idx_t MaxThreads () const override {
40
43
return max_threads;
@@ -72,9 +75,8 @@ static unique_ptr<FunctionData> SqliteBind(ClientContext &context, TableFunction
72
75
throw std::runtime_error (" no columns for table " + result->table_name );
73
76
}
74
77
75
- if (!db.GetMaxRowId (result->table_name , result->max_rowid )) {
76
- result->max_rowid = idx_t (-1 );
77
- result->rows_per_group = idx_t (-1 );
78
+ if (!db.GetRowIdInfo (result->table_name , result->row_id_info )) {
79
+ result->rows_per_group = optional_idx ();
78
80
}
79
81
80
82
result->names = names;
@@ -106,7 +108,7 @@ static void SqliteInitInternal(ClientContext &context, const SqliteBindData &bin
106
108
107
109
auto sql =
108
110
StringUtil::Format (" SELECT %s FROM \" %s\" " , col_names, SQLiteUtils::SanitizeIdentifier (bind_data.table_name ));
109
- if (bind_data.rows_per_group != idx_t (- 1 )) {
111
+ if (bind_data.rows_per_group . IsValid ( )) {
110
112
// we are scanning a subset of the rows - generate a WHERE clause based on
111
113
// the rowid
112
114
auto where_clause = StringUtil::Format (" WHERE ROWID BETWEEN %d AND %d" , rowid_min, rowid_max);
@@ -121,7 +123,11 @@ static void SqliteInitInternal(ClientContext &context, const SqliteBindData &bin
121
123
//! Provides a cardinality estimate for the scan based on the rowid bounds stored in the bind data.
//! Returns nullptr (i.e. no estimate) when no usable rowid bounds are available.
static unique_ptr<NodeStatistics> SqliteCardinality(ClientContext &context, const FunctionData *bind_data_p) {
	D_ASSERT(bind_data_p);
	auto &bind_data = bind_data_p->Cast<SqliteBindData>();
	auto &row_id_info = bind_data.row_id_info;
	if (!row_id_info.max_rowid.IsValid() || !row_id_info.min_rowid.IsValid()) {
		// either bound is missing - GetIndex() on it would be invalid, so give no estimate
		return nullptr;
	}
	auto min_row_id = row_id_info.min_rowid.GetIndex();
	auto max_row_id = row_id_info.max_rowid.GetIndex();
	if (max_row_id < min_row_id) {
		// inconsistent bounds - bail out instead of letting the unsigned subtraction wrap around
		return nullptr;
	}
	// the rowid range is inclusive on both ends, so it covers (max - min + 1) rowids
	// (a table whose only row has rowid N must yield 1, not 0)
	auto row_count = max_row_id - min_row_id + 1;
	return make_uniq<NodeStatistics>(row_count);
}
126
132
127
133
static idx_t SqliteMaxThreads (ClientContext &context, const FunctionData *bind_data_p) {
@@ -130,17 +136,41 @@ static idx_t SqliteMaxThreads(ClientContext &context, const FunctionData *bind_d
130
136
if (bind_data.global_db ) {
131
137
return 1 ;
132
138
}
133
- return bind_data.max_rowid / bind_data.rows_per_group ;
139
+ if (!bind_data.row_id_info .max_rowid .IsValid ()) {
140
+ return 1 ;
141
+ }
142
+ auto row_count = bind_data.row_id_info .max_rowid .GetIndex () - bind_data.row_id_info .min_rowid .GetIndex ();
143
+ return row_count / bind_data.rows_per_group .GetIndex ();
134
144
}
135
145
136
146
static bool SqliteParallelStateNext (ClientContext &context, const SqliteBindData &bind_data, SqliteLocalState &lstate,
137
147
SqliteGlobalState &gstate) {
138
148
lock_guard<mutex> parallel_lock (gstate.lock );
139
- if (gstate.position < bind_data.max_rowid ) {
149
+ if (!bind_data.rows_per_group .IsValid ()) {
150
+ // not doing a parallel scan - scan everything at once
151
+ if (gstate.position > 0 ) {
152
+ // already scanned
153
+ return false ;
154
+ }
155
+ SqliteInitInternal (context, bind_data, lstate, 0 , 0 );
156
+ gstate.position = static_cast <idx_t >(-1 );
157
+ lstate.scan_count = 0 ;
158
+ return true ;
159
+ }
160
+ auto max_row_id = bind_data.row_id_info .max_rowid .GetIndex ();
161
+ if (gstate.position < max_row_id) {
162
+ if (lstate.scan_count == 0 && gstate.rows_per_group < max_row_id) {
163
+ // we scanned no rows in our previous slice - double the rows per group
164
+ gstate.rows_per_group *= 2 ;
165
+ }
166
+ if (gstate.rows_per_group == 0 ) {
167
+ throw InternalException (" SqliteParallelStateNext - gstate.rows_per_group not set" );
168
+ }
140
169
auto start = gstate.position ;
141
- auto end = start + bind_data .rows_per_group - 1 ;
170
+ auto end = MinValue< idx_t >(max_row_id, start + gstate .rows_per_group - 1 ) ;
142
171
SqliteInitInternal (context, bind_data, lstate, start, end);
143
172
gstate.position = end + 1 ;
173
+ lstate.scan_count = 0 ;
144
174
return true ;
145
175
}
146
176
return false ;
@@ -161,8 +191,16 @@ SqliteInitLocalState(ExecutionContext &context, TableFunctionInitInput &input, G
161
191
162
192
//! Creates the global scan state shared by all threads of a single SQLite scan.
//! The thread limit is taken from SqliteMaxThreads. When the bind data carries a valid
//! rows_per_group, the start position is placed one before the smallest rowid (or left at 0
//! when the smallest rowid is 0) and the group size is copied into the state.
static unique_ptr<GlobalTableFunctionState> SqliteInitGlobalState(ClientContext &context,
                                                                  TableFunctionInitInput &input) {
	auto &scan_data = input.bind_data->Cast<SqliteBindData>();
	auto gstate = make_uniq<SqliteGlobalState>(SqliteMaxThreads(context, input.bind_data.get()));
	gstate->position = 0;
	if (!scan_data.rows_per_group.IsValid()) {
		// no row groups configured - keep the default position and group size
		return std::move(gstate);
	}
	const auto first_row_id = scan_data.row_id_info.min_rowid.GetIndex();
	// guard the subtraction so an unsigned 0 does not wrap around
	gstate->position = first_row_id > 0 ? first_row_id - 1 : 0;
	gstate->rows_per_group = scan_data.rows_per_group.GetIndex();
	return std::move(gstate);
}
168
206
@@ -191,6 +229,7 @@ static void SqliteScan(ClientContext &context, TableFunctionInput &data, DataChu
191
229
output.SetCardinality (out_idx);
192
230
break ;
193
231
}
232
+ state.scan_count ++;
194
233
for (idx_t col_idx = 0 ; col_idx < output.ColumnCount (); col_idx++) {
195
234
auto &out_vec = output.data [col_idx];
196
235
auto sqlite_column_type = stmt.GetType (col_idx);
0 commit comments