@@ -23,18 +23,21 @@ struct SqliteLocalState : public LocalTableFunctionState {
2323 SQLiteStatement stmt;
2424 bool done = false ;
2525 vector<column_t > column_ids;
26+ // ! The amount of rows we scanned as part of this row group
27+ idx_t scan_count = 1 ;
2628
2729 ~SqliteLocalState () {
2830 }
2931};
3032
3133struct SqliteGlobalState : public GlobalTableFunctionState {
32- SqliteGlobalState (idx_t max_threads) : max_threads(max_threads) {
34+ explicit SqliteGlobalState (idx_t max_threads) : max_threads(max_threads) {
3335 }
3436
3537 mutex lock;
3638 idx_t position = 0 ;
3739 idx_t max_threads;
40+ idx_t rows_per_group = 0 ;
3841
3942 idx_t MaxThreads () const override {
4043 return max_threads;
@@ -72,9 +75,8 @@ static unique_ptr<FunctionData> SqliteBind(ClientContext &context, TableFunction
7275 throw std::runtime_error (" no columns for table " + result->table_name );
7376 }
7477
75- if (!db.GetMaxRowId (result->table_name , result->max_rowid )) {
76- result->max_rowid = idx_t (-1 );
77- result->rows_per_group = idx_t (-1 );
78+ if (!db.GetRowIdInfo (result->table_name , result->row_id_info )) {
79+ result->rows_per_group = optional_idx ();
7880 }
7981
8082 result->names = names;
@@ -106,7 +108,7 @@ static void SqliteInitInternal(ClientContext &context, const SqliteBindData &bin
106108
107109 auto sql =
108110 StringUtil::Format (" SELECT %s FROM \" %s\" " , col_names, SQLiteUtils::SanitizeIdentifier (bind_data.table_name ));
109- if (bind_data.rows_per_group != idx_t (- 1 )) {
111+ if (bind_data.rows_per_group . IsValid ( )) {
110112 // we are scanning a subset of the rows - generate a WHERE clause based on
111113 // the rowid
112114 auto where_clause = StringUtil::Format (" WHERE ROWID BETWEEN %d AND %d" , rowid_min, rowid_max);
@@ -121,7 +123,11 @@ static void SqliteInitInternal(ClientContext &context, const SqliteBindData &bin
121123static unique_ptr<NodeStatistics> SqliteCardinality (ClientContext &context, const FunctionData *bind_data_p) {
122124 D_ASSERT (bind_data_p);
123125 auto &bind_data = bind_data_p->Cast <SqliteBindData>();
124- return make_uniq<NodeStatistics>(bind_data.max_rowid );
126+ if (!bind_data.row_id_info .max_rowid .IsValid ()) {
127+ return nullptr ;
128+ }
129+ auto row_count = bind_data.row_id_info .max_rowid .GetIndex () - bind_data.row_id_info .min_rowid .GetIndex ();
130+ return make_uniq<NodeStatistics>(row_count);
125131}
126132
127133static idx_t SqliteMaxThreads (ClientContext &context, const FunctionData *bind_data_p) {
@@ -130,17 +136,41 @@ static idx_t SqliteMaxThreads(ClientContext &context, const FunctionData *bind_d
130136 if (bind_data.global_db ) {
131137 return 1 ;
132138 }
133- return bind_data.max_rowid / bind_data.rows_per_group ;
139+ if (!bind_data.row_id_info .max_rowid .IsValid ()) {
140+ return 1 ;
141+ }
142+ auto row_count = bind_data.row_id_info .max_rowid .GetIndex () - bind_data.row_id_info .min_rowid .GetIndex ();
143+ return row_count / bind_data.rows_per_group .GetIndex ();
134144}
135145
136146static bool SqliteParallelStateNext (ClientContext &context, const SqliteBindData &bind_data, SqliteLocalState &lstate,
137147 SqliteGlobalState &gstate) {
138148 lock_guard<mutex> parallel_lock (gstate.lock );
139- if (gstate.position < bind_data.max_rowid ) {
149+ if (!bind_data.rows_per_group .IsValid ()) {
150+ // not doing a parallel scan - scan everything at once
151+ if (gstate.position > 0 ) {
152+ // already scanned
153+ return false ;
154+ }
155+ SqliteInitInternal (context, bind_data, lstate, 0 , 0 );
156+ gstate.position = static_cast <idx_t >(-1 );
157+ lstate.scan_count = 0 ;
158+ return true ;
159+ }
160+ auto max_row_id = bind_data.row_id_info .max_rowid .GetIndex ();
161+ if (gstate.position < max_row_id) {
162+ if (lstate.scan_count == 0 && gstate.rows_per_group < max_row_id) {
163+ // we scanned no rows in our previous slice - double the rows per group
164+ gstate.rows_per_group *= 2 ;
165+ }
166+ if (gstate.rows_per_group == 0 ) {
167+ throw InternalException (" SqliteParallelStateNext - gstate.rows_per_group not set" );
168+ }
140169 auto start = gstate.position ;
141- auto end = start + bind_data .rows_per_group - 1 ;
170+ auto end = MinValue< idx_t >(max_row_id, start + gstate .rows_per_group - 1 ) ;
142171 SqliteInitInternal (context, bind_data, lstate, start, end);
143172 gstate.position = end + 1 ;
173+ lstate.scan_count = 0 ;
144174 return true ;
145175 }
146176 return false ;
@@ -161,8 +191,16 @@ SqliteInitLocalState(ExecutionContext &context, TableFunctionInitInput &input, G
161191
162192static unique_ptr<GlobalTableFunctionState> SqliteInitGlobalState (ClientContext &context,
163193 TableFunctionInitInput &input) {
194+ auto &bind_data = input.bind_data ->Cast <SqliteBindData>();
164195 auto result = make_uniq<SqliteGlobalState>(SqliteMaxThreads (context, input.bind_data .get ()));
165196 result->position = 0 ;
197+ if (bind_data.rows_per_group .IsValid ()) {
198+ auto min_row_id = bind_data.row_id_info .min_rowid .GetIndex ();
199+ if (min_row_id > 0 ) {
200+ result->position = min_row_id - 1 ;
201+ }
202+ result->rows_per_group = bind_data.rows_per_group .GetIndex ();
203+ }
166204 return std::move (result);
167205}
168206
@@ -191,6 +229,7 @@ static void SqliteScan(ClientContext &context, TableFunctionInput &data, DataChu
191229 output.SetCardinality (out_idx);
192230 break ;
193231 }
232+ state.scan_count ++;
194233 for (idx_t col_idx = 0 ; col_idx < output.ColumnCount (); col_idx++) {
195234 auto &out_vec = output.data [col_idx];
196235 auto sqlite_column_type = stmt.GetType (col_idx);
0 commit comments