Skip to content

Commit

Permalink
impr: define - cache compiled statements
Browse files Browse the repository at this point in the history
  • Loading branch information
nalgeon committed Sep 19, 2022
1 parent 200a232 commit 17eac88
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 11 deletions.
25 changes: 17 additions & 8 deletions docs/define.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,21 @@ select name, body from sqlean_define where type = 'scalar';
└───────┴───────────────────┘
```

Scalar functions are compiled into prepared statements. SQLite requires these statements to be freed before the connection is closed. Unfortunately, there is no way to free them automatically. Therefore, always execute `define_free()` before disconnecting:

```
sqlite> .load dist/define
sqlite> select define('subxy', '? - ?');
...
sqlite> select define_free();
sqlite> .exit
```

To delete a scalar function, execute `undefine()`, then reconnect to the database:

```
sqlite> select undefine('sumn');
sqlite> select define_free();
... reconnect
sqlite> select sumn(5);
Parse error: no such function: sumn
Expand Down Expand Up @@ -146,7 +157,7 @@ Parse error: no such table: strcut

## Performance

Due to their dynamic nature, user-defined functions run rather slowly on large datasets.
User-defined functions are compiled into prepared statements, so they are pretty fast even on large datasets.

Given 1M rows table with random data:

Expand All @@ -160,27 +171,25 @@ Regular SQL query:

```sql
select max(x+1) from data;
Run Time: real 0.114 user 0.096000 sys 0.020000
Run Time: real 0.130 user 0.123171 sys 0.006865
```

Scalar function is 30x slower:
Scalar function is 2x slower:

```sql
select define('plus', ':x + 1');
select max(plus(x)) from data;
Run Time: real 3.747 user 3.720000 sys 0.024000
Run Time: real 0.249 user 0.243840 sys 0.005304
```

Table-valued function is 5x slower:
Table-valued function is 2.5x slower:

```sql
create virtual table plus using define((select :x + 1 as value));
select max(value) from data, plus(data.x);
Run Time: real 0.512 user 0.508000 sys 0.004000
Run Time: real 0.336 user 0.330145 sys 0.005352
```

The scalar function is so slow because it has to prepare an SQL statement for each value from scratch. There may be some way around this, but I haven't figured it out yet.

## Usage

```
Expand Down
148 changes: 145 additions & 3 deletions src/sqlite3-define.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,66 @@
#include "sqlite3ext.h"
SQLITE_EXTENSION_INIT1

#pragma region statement cache

typedef struct cache_node {
sqlite3_stmt* stmt;
struct cache_node* next;
} cache_node;

cache_node* cache_head = NULL;
cache_node* cache_tail = NULL;

static int cache_add(sqlite3_stmt* stmt) {
if (cache_head == NULL) {
cache_head = (cache_node*)malloc(sizeof(cache_node));
if (cache_head == NULL) {
return SQLITE_ERROR;
}
cache_head->stmt = stmt;
cache_head->next = NULL;
cache_tail = cache_head;
return SQLITE_OK;
}
cache_tail->next = (cache_node*)malloc(sizeof(cache_node));
if (cache_tail->next == NULL) {
return SQLITE_ERROR;
}
cache_tail = cache_tail->next;
cache_tail->stmt = stmt;
cache_tail->next = NULL;
return SQLITE_OK;
}

static void cache_print() {
if (cache_head == NULL) {
printf("cache is empty");
return;
}
cache_node* curr = cache_head;
while (curr != NULL) {
printf("%s\n", sqlite3_sql(curr->stmt));
curr = curr->next;
}
}

static void cache_free() {
if (cache_head == NULL) {
return;
}
cache_node* prev;
cache_node* curr = cache_head;
while (curr != NULL) {
sqlite3_finalize(curr->stmt);
prev = curr;
curr = curr->next;
free(prev);
}
cache_head = cache_tail = NULL;
}

#pragma endregion

#pragma region define scalar function

/*
Expand Down Expand Up @@ -44,6 +104,31 @@ static void exec_function(sqlite3_context* ctx, int argc, sqlite3_value** argv)
sqlite3_result_error_code(ctx, ret);
}

/*
* Executes compiled prepared statement from the context.
*/
static void exec_compiled(sqlite3_context* ctx, int argc, sqlite3_value** argv) {
int ret = SQLITE_OK;
sqlite3_stmt* stmt = sqlite3_user_data(ctx);
for (int i = 0; i < argc; i++) {
if ((ret = sqlite3_bind_value(stmt, i + 1, argv[i])) != SQLITE_OK) {
sqlite3_reset(stmt);
sqlite3_result_error_code(ctx, ret);
return;
}
}
if ((ret = sqlite3_step(stmt)) != SQLITE_ROW) {
if (ret == SQLITE_DONE) {
ret = SQLITE_MISUSE;
}
sqlite3_reset(stmt);
sqlite3_result_error_code(ctx, ret);
return;
}
sqlite3_result_value(ctx, sqlite3_column_value(stmt, 0));
sqlite3_reset(stmt);
}

/*
* Saves user-defined function into the database.
*/
Expand All @@ -68,7 +153,7 @@ static int save_function(sqlite3* db, const char* name, const char* type, const
}

/*
* Creates user-defined function.
* Creates user-defined function without caching the prepared statement.
*/
static int create_function(sqlite3* db, const char* name, const char* body) {
char* sql = sqlite3_mprintf("select %s", body);
Expand All @@ -89,6 +174,29 @@ static int create_function(sqlite3* db, const char* name, const char* body) {
NULL, sqlite3_free);
}

/*
* Creates user-defined function and caches the prepared statement.
*/
static int create_compiled(sqlite3* db, const char* name, const char* body) {
char* sql = sqlite3_mprintf("select %s", body);
if (!sql) {
return SQLITE_NOMEM;
}

sqlite3_stmt* stmt;
int ret = sqlite3_prepare_v3(db, sql, -1, SQLITE_PREPARE_PERSISTENT, &stmt, NULL);
sqlite3_free(sql);
if (ret != SQLITE_OK) {
return ret;
}
int nparams = sqlite3_bind_parameter_count(stmt);
if ((ret = cache_add(stmt)) != SQLITE_OK) {
return ret;
}

return sqlite3_create_function(db, name, nparams, SQLITE_UTF8, stmt, exec_compiled, NULL, NULL);
}

/*
* Loads user-defined functions from the database.
*/
Expand All @@ -112,7 +220,7 @@ static int load_functions(sqlite3* db) {
while (sqlite3_step(stmt) != SQLITE_DONE) {
name = (const char*)sqlite3_column_text(stmt, 0);
body = (const char*)sqlite3_column_text(stmt, 1);
ret = create_function(db, name, body);
ret = create_compiled(db, name, body);
if (ret != SQLITE_OK) {
break;
}
Expand All @@ -138,6 +246,38 @@ static void define_function(sqlite3_context* ctx, int argc, sqlite3_value** argv
}
}

/*
* Creates compiled user-defined function and saves it to the database.
*/
static void define_compiled(sqlite3_context* ctx, int argc, sqlite3_value** argv) {
sqlite3* db = sqlite3_context_db_handle(ctx);
const char* name = (const char*)sqlite3_value_text(argv[0]);
const char* body = (const char*)sqlite3_value_text(argv[1]);
int ret;
if ((ret = create_compiled(db, name, body)) != SQLITE_OK) {
sqlite3_result_error_code(ctx, ret);
return;
}
if ((ret = save_function(db, name, "scalar", body)) != SQLITE_OK) {
sqlite3_result_error_code(ctx, ret);
return;
}
}

/*
* Frees prepared statements compiled by user-defined functions.
*/
static void free_compiled(sqlite3_context* ctx, int argc, sqlite3_value** argv) {
cache_free();
}

/*
* Prints prepared statements cache contents.
*/
static void print_cache(sqlite3_context* ctx, int argc, sqlite3_value** argv) {
cache_print();
}

#pragma endregion

#pragma region define table valued function
Expand Down Expand Up @@ -461,7 +601,9 @@ __declspec(dllexport)
int sqlite3_define_init(sqlite3* db, char** pzErrMsg, const sqlite3_api_routines* pApi) {
SQLITE_EXTENSION_INIT2(pApi);
const int flags = SQLITE_UTF8 | SQLITE_DETERMINISTIC;
sqlite3_create_function(db, "define", 2, flags, NULL, define_function, NULL, NULL);
sqlite3_create_function(db, "define", 2, flags, NULL, define_compiled, NULL, NULL);
sqlite3_create_function(db, "define_free", 0, flags, NULL, free_compiled, NULL, NULL);
sqlite3_create_function(db, "define_cache", 0, flags, NULL, print_cache, NULL, NULL);
sqlite3_create_function(db, "undefine", 1, flags, NULL, undefine_function, NULL, NULL);
sqlite3_create_module(db, "define", &define_module, NULL);
return load_functions(db);
Expand Down
2 changes: 2 additions & 0 deletions test/define.sql
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,5 @@ select undefine('strcut');
select '53', count(*) = 0 from sqlean_define where name = 'strcut';
select '54', count(*) = 0 from sqlite_master where type = 'table' and name = 'strcut';
select '55', count(*) = 5 from sqlean_define;

select define_free();

0 comments on commit 17eac88

Please sign in to comment.