diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 00000000..81bdfee9 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,127 @@ +# Copilot Instructions for Flock + +## Overview + +**Flock** is a C++ DuckDB extension that integrates LLMs (Large Language Models) and RAG (Retrieval-Augmented Generation) into DuckDB through declarative SQL queries. It supports OpenAI, Azure, and Ollama providers and enables semantic functions such as `llm_complete`, `llm_filter`, `llm_embedding`, and hybrid search directly from SQL. + +- **Language**: C++17 +- **Build system**: CMake (3.5+) with DuckDB's extension CI tools (`extension-ci-tools/`) +- **Dependency manager**: vcpkg (managed via `vcpkg.json`) +- **Key dependencies**: `nlohmann-json`, `curl`, `gtest` (see `vcpkg.json`) +- **DuckDB version targeted**: v1.4.4 (see `MainDistributionPipeline.yml`) + +## Repository Layout + +``` +. +├── CMakeLists.txt # Top-level CMake (builds static + loadable extension, unit tests) +├── extension_config.cmake # DuckDB extension load config (references this repo) +├── vcpkg.json # vcpkg dependency manifest +├── Makefile # Convenience targets (lf_setup, lf_setup_dev) +├── scripts/ +│ ├── build_and_run.sh # Interactive guided build + run script +│ ├── build_project.sh # Non-interactive build script +│ └── setup_vcpkg.sh # vcpkg bootstrap +├── src/ +│ ├── flock_extension.cpp # Extension entry point (LoadInternal, FlockExtension::Load) +│ ├── include/flock/ # Public headers +│ ├── core/ # Config, common utilities +│ ├── functions/ # Scalar and aggregate SQL functions +│ ├── model_manager/ # Provider integrations (OpenAI, Azure, Ollama) +│ ├── prompt_manager/ # Prompt management +│ ├── secret_manager/ # API key/secret handling +│ ├── registry/ # Model/prompt registries +│ ├── metrics/ # Metrics collection +│ └── custom_parser/ # Custom SQL parser extension +├── test/ +│ ├── unit/ # C++ unit tests (GTest), built via CMake +│ └── integration/ # Python integration tests (pytest, uv) +├── duckdb/ # DuckDB source submodule +├── extension-ci-tools/ # DuckDB extension CI/build helpers submodule +├── .clang-format # clang-format config (LLVM style, IndentWidth=4, ColumnLimit=0) +├── .cmake-format # cmake-format config +└── .pre-commit-config.yaml # Pre-commit hooks: clang-format v18.1.8, cmake-format v0.6.13 +``` + +## Building + +Always ensure submodules are initialised before building: + +```bash +git submodule update --init --recursive +``` + +### Setup vcpkg (first time or after clean) + +```bash +bash scripts/setup_vcpkg.sh +export VCPKG_TOOLCHAIN_PATH="$(pwd)/vcpkg/scripts/buildsystems/vcpkg.cmake" +``` + +### Release build + +```bash +mkdir -p build/release +cmake -G Ninja -DCMAKE_BUILD_TYPE=Release \ + -DEXTENSION_STATIC_BUILD=1 \ + -DVCPKG_BUILD=1 \ + -DCMAKE_TOOLCHAIN_FILE="$VCPKG_TOOLCHAIN_PATH" \ + -DVCPKG_MANIFEST_DIR="$(pwd)" \ + -DDUCKDB_EXTENSION_CONFIGS="$(pwd)/extension_config.cmake" \ + -S duckdb -B build/release +cmake --build build/release --config Release +``` + +Use `-G "Unix Makefiles"` if Ninja is not available. The DuckDB binary will be at `build/release/duckdb`. + +### Debug build + +Replace `Release` with `Debug` and `build/release` with `build/debug` in the commands above. Debug builds enable AddressSanitizer automatically. + +## Running Unit Tests + +Unit tests use GTest and are built as part of the CMake build. After building: + +```bash +cd build/release # or build/debug +ctest --output-on-failure +``` + +Or run the test binary directly: `./flock_tests` + +## Running Integration Tests + +Integration tests use Python/pytest and are in `test/integration/`. They require a running DuckDB binary with the Flock extension loaded and provider credentials set in a `.env` file (see `test/integration/.env-example`). + +```bash +cd test/integration +uv sync # install Python deps (requires uv) +uv run pytest +``` + +## Code Style & Linting + +- **C++**: `clang-format` v18.1.8 (config in `.clang-format`, LLVM-based, indent=4, no column limit) +- **CMake**: `cmake-format` v0.6.13 (config in `.cmake-format`) +- Run pre-commit on staged files: `pre-commit run` or `pre-commit run --all-files` +- Install dev tools: `make lf_setup_dev` + +Always run `clang-format` on modified C++ files before committing. The CI pipeline enforces both `format` and `tidy` checks (`code-quality-check` job in `MainDistributionPipeline.yml`). + +## CI Pipeline + +Defined in `.github/workflows/MainDistributionPipeline.yml`: + +- **duckdb-stable-build**: Builds extension binaries for all platforms using DuckDB v1.4.4 CI tools. +- **code-quality-check**: Runs `clang-format` and `clang-tidy` checks. + +Triggered on push to `main`/`dev` when `src/`, `test/`, `CMakeLists.txt`, or workflow files change, and on `workflow_dispatch`. + +## Key Notes + +- The extension entry point is `src/flock_extension.cpp` → `FlockExtension::Load` → `LoadInternal`. +- All SQL functions are registered via `flock::Config::Configure(loader)` in `src/core/config/`. +- New scalar functions go in `src/functions/scalar/`; aggregate functions in `src/functions/aggregate/`. +- Provider implementations live in `src/model_manager/providers/`. +- Public headers for the extension are under `src/include/flock/`. +- The `duckdb/` and `extension-ci-tools/` directories are git submodules — do not modify them. diff --git a/.github/workflows/MainDistributionPipeline.yml b/.github/workflows/MainDistributionPipeline.yml index 5a6cbe1c..9f5749d6 100644 --- a/.github/workflows/MainDistributionPipeline.yml +++ b/.github/workflows/MainDistributionPipeline.yml @@ -29,7 +29,6 @@ jobs: duckdb_version: v1.4.4 ci_tools_version: v1.4.4 extension_name: flock - exclude_archs: 'wasm_mvp;wasm_threads;wasm_eh' code-quality-check: name: Code Quality Check diff --git a/CMakeLists.txt b/CMakeLists.txt index 9bfa1810..383772e3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,7 +16,9 @@ include_directories(src/include) add_subdirectory(src) # Find dependencies -find_package(CURL REQUIRED) +if(NOT EMSCRIPTEN) + find_package(CURL REQUIRED) +endif() find_package(nlohmann_json CONFIG REQUIRED) # Build the DuckDB static and loadable extensions @@ -43,12 +45,18 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug") endif() # Link libraries for the static extension -target_link_libraries(${EXTENSION_NAME} CURL::libcurl - nlohmann_json::nlohmann_json) +if(NOT EMSCRIPTEN) + target_link_libraries(${EXTENSION_NAME} CURL::libcurl) +endif() +target_link_libraries(${EXTENSION_NAME} nlohmann_json::nlohmann_json) # Link libraries for the loadable extension -target_link_libraries(${LOADABLE_EXTENSION_NAME} CURL::libcurl - nlohmann_json::nlohmann_json) +if(NOT EMSCRIPTEN) + target_link_libraries(${LOADABLE_EXTENSION_NAME} CURL::libcurl) +endif() +target_link_libraries(${LOADABLE_EXTENSION_NAME} nlohmann_json::nlohmann_json) + +# WASM builds use EM_JS with synchronous XMLHttpRequest for HTTP # Install the extension install( @@ -63,6 +71,8 @@ if(CMAKE_BUILD_TYPE STREQUAL "Coverage") add_link_options(-fprofile-instr-generate -fcoverage-mapping) endif() -# Add the test directory -enable_testing() -add_subdirectory(test/unit) +if(NOT EMSCRIPTEN) + # Add the test directory if not on WASM + enable_testing() + add_subdirectory(test/unit) +endif() diff --git a/_codeql_detected_source_root b/_codeql_detected_source_root new file mode 120000 index 00000000..945c9b46 --- /dev/null +++ b/_codeql_detected_source_root @@ -0,0 +1 @@ +. \ No newline at end of file diff --git a/src/core/config/config.cpp b/src/core/config/config.cpp index e36eac90..0fe50d2a 100644 --- a/src/core/config/config.cpp +++ b/src/core/config/config.cpp @@ -1,7 +1,8 @@ #include "flock/core/config.hpp" -#include "filesystem.hpp" +#include "duckdb/common/file_system.hpp" #include "flock/secret_manager/secret_manager.hpp" #include +#include #include namespace flock { @@ -11,15 +12,15 @@ duckdb::DatabaseInstance* Config::db; std::string Config::get_schema_name() { return "flock_config"; } std::filesystem::path Config::get_global_storage_path() { -#ifdef _WIN32 - const char* homeDir = getenv("USERPROFILE"); +#ifdef __EMSCRIPTEN__ + return std::filesystem::path("opfs://flock_data/flock.db"); #else - const char* homeDir = getenv("HOME"); -#endif - if (homeDir == nullptr) { + const auto& home = duckdb::FileSystem::GetHomeDirectory(nullptr); + if (home.empty()) { throw std::runtime_error("Could not find home directory"); } - return std::filesystem::path(homeDir) / ".duckdb" / "flock_storage" / "flock.db"; + return std::filesystem::path(home) / ".duckdb" / "flock_storage" / "flock.db"; +#endif } duckdb::Connection Config::GetConnection(duckdb::DatabaseInstance* db) { @@ -30,42 +31,58 @@ duckdb::Connection Config::GetConnection(duckdb::DatabaseInstance* db) { return con; } -duckdb::Connection Config::GetGlobalConnection() { - const duckdb::DuckDB db(Config::get_global_storage_path().string()); - duckdb::Connection con(*db.instance); - return con; -} -void Config::SetupGlobalStorageLocation() { - const auto flock_global_path = get_global_storage_path(); - const auto flockDir = flock_global_path.parent_path(); - if (!std::filesystem::exists(flockDir)) { - try { - std::filesystem::create_directories(flockDir); - } catch (const std::filesystem::filesystem_error& e) { - std::cerr << "Error creating directories: " << e.what() << std::endl; +void Config::SetupGlobalStorageLocation(duckdb::DatabaseInstance* db_instance) { + if (!db_instance) { + return; + } +#ifdef __EMSCRIPTEN__ + // WASM: Client registers OPFS files before loading extension + return; +#endif + auto& fs = duckdb::FileSystem::GetFileSystem(*db_instance); + const std::string dir_path = get_global_storage_path().parent_path().string(); + try { + if (!dir_path.empty() && !fs.DirectoryExists(dir_path)) { + fs.CreateDirectory(dir_path); } + } catch (const std::exception& e) { + std::cerr << "Error creating directory " << dir_path << ": " << e.what() << std::endl; } } void Config::ConfigSchema(duckdb::Connection& con, std::string& schema_name) { - auto result = con.Query(duckdb_fmt::format(" SELECT * " - " FROM information_schema.schemata " - " WHERE schema_name = '{}'; ", - schema_name)); - if (result->RowCount() == 0) { - con.Query(duckdb_fmt::format("CREATE SCHEMA {};", schema_name)); - } + con.Query(duckdb_fmt::format("CREATE SCHEMA IF NOT EXISTS {};", schema_name)); } -void Config::ConfigureGlobal() { - auto con = Config::GetGlobalConnection(); +void Config::ConfigureGlobal(duckdb::DatabaseInstance* db_instance) { + if (!db_instance) { + return; + } + // Use the already-attached flock_storage database + auto con = Config::GetConnection(db_instance); + // Switch to flock_storage so ConfigureTables creates tables there. + // We switch back to memory afterward to avoid leaving the connection + // pointing at flock_storage, which would affect subsequent queries. + auto use_result = con.Query("USE flock_storage;"); + if (use_result->HasError()) { + std::cerr << "Failed to USE flock_storage: " << use_result->GetError() << std::endl; + return; + } ConfigureTables(con, ConfigType::GLOBAL); + con.Query("USE memory;"); } void Config::ConfigureLocal(duckdb::DatabaseInstance& db) { auto con = Config::GetConnection(&db); ConfigureTables(con, ConfigType::LOCAL); + + const std::string global_path = get_global_storage_path().string(); + auto result = con.Query( + duckdb_fmt::format("ATTACH DATABASE '{}' AS flock_storage;", global_path)); + if (result->HasError()) { + std::cerr << "Failed to attach flock_storage: " << result->GetError() << std::endl; + } } void Config::ConfigureTables(duckdb::Connection& con, const ConfigType type) { @@ -81,11 +98,23 @@ void Config::Configure(duckdb::ExtensionLoader& loader) { Registry::Register(loader); SecretManager::Register(loader); auto& db = loader.GetDatabaseInstance(); - if (const auto db_path = db.config.options.database_path; db_path != get_global_storage_path().string()) { - SetupGlobalStorageLocation(); - ConfigureGlobal(); + const auto db_path = db.config.options.database_path; + const std::string global_path = get_global_storage_path().string(); + + // If the main database is already at the global storage path, still attach for WASM :memory: case + if (db_path == global_path) { + auto con = GetConnection(&db); + ConfigureTables(con, ConfigType::LOCAL); + ConfigureTables(con, ConfigType::GLOBAL); +#ifdef __EMSCRIPTEN__ ConfigureLocal(db); +#endif + return; } + + SetupGlobalStorageLocation(&db); + ConfigureLocal(db); + ConfigureGlobal(&db); } void Config::AttachToGlobalStorage(duckdb::Connection& con, bool read_only) { @@ -116,11 +145,7 @@ bool Config::StorageAttachmentGuard::TryDetach() { } void Config::StorageAttachmentGuard::Wait(int milliseconds) { - auto start = std::chrono::steady_clock::now(); - auto duration = std::chrono::milliseconds(milliseconds); - while (std::chrono::steady_clock::now() - start < duration) { - // Busy-wait until the specified duration has elapsed - } + std::this_thread::sleep_for(std::chrono::milliseconds(milliseconds)); } Config::StorageAttachmentGuard::StorageAttachmentGuard(duckdb::Connection& con, bool read_only) @@ -130,7 +155,9 @@ Config::StorageAttachmentGuard::StorageAttachmentGuard(duckdb::Connection& con, attached = true; return; } - Wait(RETRY_DELAY_MS); + if (attempt < MAX_RETRIES - 1) { + Wait(RETRY_DELAY_MS); + } } Config::AttachToGlobalStorage(connection, read_only); attached = true; diff --git a/src/core/config/model.cpp b/src/core/config/model.cpp index 310b08f4..b85d558e 100644 --- a/src/core/config/model.cpp +++ b/src/core/config/model.cpp @@ -9,56 +9,38 @@ std::string Config::get_user_defined_models_table_name() { return "FLOCKMTL_MODE void Config::SetupDefaultModelsConfig(duckdb::Connection& con, std::string& schema_name) { const std::string table_name = Config::get_default_models_table_name(); - // Ensure schema exists - auto result = con.Query(duckdb_fmt::format(" SELECT table_name " - " FROM information_schema.tables " - " WHERE table_schema = '{}' " - " AND table_name = '{}'; ", - schema_name, table_name)); - if (result->RowCount() == 0) { - con.Query(duckdb_fmt::format(" INSTALL JSON; " - " LOAD JSON; " - " CREATE TABLE {}.{} ( " - " model_name VARCHAR NOT NULL PRIMARY KEY, " - " model VARCHAR NOT NULL, " - " provider_name VARCHAR NOT NULL, " - " model_args JSON DEFAULT '{{}}'" - " ); ", - schema_name, table_name)); - - con.Query(duckdb_fmt::format( - "INSERT INTO {}.{} (model_name, model, provider_name) " - "VALUES " - "('default', 'gpt-4o-mini', 'openai'), " - "('gpt-4o-mini', 'gpt-4o-mini', 'openai'), " - "('gpt-4o', 'gpt-4o', 'openai'), " - "('gpt-4o-transcribe', 'gpt-4o-transcribe', 'openai')," - "('gpt-4o-mini-transcribe', 'gpt-4o-mini-transcribe', 'openai')," - "('text-embedding-3-large', 'text-embedding-3-large', 'openai'), " - "('text-embedding-3-small', 'text-embedding-3-small', 'openai');", - schema_name, table_name)); - } + con.Query("INSTALL JSON; LOAD JSON;"); + con.Query(duckdb_fmt::format(" CREATE TABLE IF NOT EXISTS {}.{} ( " + " model_name VARCHAR NOT NULL PRIMARY KEY, " + " model VARCHAR NOT NULL, " + " provider_name VARCHAR NOT NULL, " + " model_args JSON DEFAULT '{{}}'" + " ); ", + schema_name, table_name)); + + con.Query(duckdb_fmt::format( + "INSERT OR IGNORE INTO {}.{} (model_name, model, provider_name) " + "VALUES " + "('default', 'gpt-4o-mini', 'openai'), " + "('gpt-4o-mini', 'gpt-4o-mini', 'openai'), " + "('gpt-4o', 'gpt-4o', 'openai'), " + "('gpt-4o-transcribe', 'gpt-4o-transcribe', 'openai')," + "('gpt-4o-mini-transcribe', 'gpt-4o-mini-transcribe', 'openai')," + "('text-embedding-3-large', 'text-embedding-3-large', 'openai'), " + "('text-embedding-3-small', 'text-embedding-3-small', 'openai');", + schema_name, table_name)); } void Config::SetupUserDefinedModelsConfig(duckdb::Connection& con, std::string& schema_name) { const std::string table_name = Config::get_user_defined_models_table_name(); - // Ensure schema exists - auto result = con.Query(duckdb_fmt::format(" SELECT table_name " - " FROM information_schema.tables " - " WHERE table_schema = '{}' " - " AND table_name = '{}'; ", - schema_name, table_name)); - if (result->RowCount() == 0) { - con.Query(duckdb_fmt::format(" INSTALL JSON; " - " LOAD JSON; " - " CREATE TABLE {}.{} ( " - " model_name VARCHAR NOT NULL PRIMARY KEY, " - " model VARCHAR NOT NULL, " - " provider_name VARCHAR NOT NULL, " - " model_args JSON NOT NULL" - " ); ", - schema_name, table_name)); - } + con.Query("INSTALL JSON; LOAD JSON;"); + con.Query(duckdb_fmt::format(" CREATE TABLE IF NOT EXISTS {}.{} ( " + " model_name VARCHAR NOT NULL PRIMARY KEY, " + " model VARCHAR NOT NULL, " + " provider_name VARCHAR NOT NULL, " + " model_args JSON NOT NULL" + " ); ", + schema_name, table_name)); } void Config::ConfigModelTable(duckdb::Connection& con, std::string& schema_name, const ConfigType type) { diff --git a/src/core/config/prompt.cpp b/src/core/config/prompt.cpp index a3efbd99..0d574f8f 100644 --- a/src/core/config/prompt.cpp +++ b/src/core/config/prompt.cpp @@ -8,25 +8,18 @@ std::string Config::get_prompts_table_name() { return "FLOCKMTL_PROMPT_INTERNAL_ void Config::ConfigPromptTable(duckdb::Connection& con, std::string& schema_name, const ConfigType type) { const std::string table_name = "FLOCKMTL_PROMPT_INTERNAL_TABLE"; - auto result = con.Query(duckdb_fmt::format(" SELECT table_name " - " FROM information_schema.tables " - " WHERE table_schema = '{}' " - " AND table_name = '{}'; ", - schema_name, table_name)); - if (result->RowCount() == 0) { - con.Query(duckdb_fmt::format(" CREATE TABLE {}.{} ( " - " prompt_name VARCHAR NOT NULL, " - " prompt VARCHAR NOT NULL, " - " updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, " - " version INT DEFAULT 1, " - " PRIMARY KEY (prompt_name, version) " - " ); ", + con.Query(duckdb_fmt::format(" CREATE TABLE IF NOT EXISTS {}.{} ( " + " prompt_name VARCHAR NOT NULL, " + " prompt VARCHAR NOT NULL, " + " updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, " + " version INT DEFAULT 1, " + " PRIMARY KEY (prompt_name, version) " + " ); ", + schema_name, table_name)); + if (type == ConfigType::GLOBAL) { + con.Query(duckdb_fmt::format(" INSERT OR IGNORE INTO {}.{} (prompt_name, prompt) " + " VALUES ('hello-world', 'Tell me hello world'); ", schema_name, table_name)); - if (type == ConfigType::GLOBAL) { - con.Query(duckdb_fmt::format(" INSERT INTO {}.{} (prompt_name, prompt) " - " VALUES ('hello-world', 'Tell me hello world'); ", - schema_name, table_name)); - } } } diff --git a/src/functions/aggregate/aggregate.cpp b/src/functions/aggregate/aggregate.cpp index d7389d25..5ed19f9c 100644 --- a/src/functions/aggregate/aggregate.cpp +++ b/src/functions/aggregate/aggregate.cpp @@ -130,8 +130,6 @@ AggregateFunctionBase::CastInputsToJson(duckdb::Vector inputs[], idx_t count) { if (prompt_context_json.contains("context_columns")) { context_columns = prompt_context_json["context_columns"]; prompt_context_json.erase("context_columns"); - } else { - throw std::runtime_error("Missing 'context_columns' in second argument. The prompt struct must include context_columns."); } return std::make_tuple(prompt_context_json, context_columns); diff --git a/src/include/flock/core/config.hpp b/src/include/flock/core/config.hpp index 88ce1581..be38b08e 100644 --- a/src/include/flock/core/config.hpp +++ b/src/include/flock/core/config.hpp @@ -1,8 +1,8 @@ #pragma once -#include "filesystem.hpp" #include "flock/core/common.hpp" #include "flock/registry/registry.hpp" +#include #include namespace flock { @@ -15,9 +15,9 @@ class Config { static duckdb::DatabaseInstance* db; static duckdb::DatabaseInstance* global_db; static duckdb::Connection GetConnection(duckdb::DatabaseInstance* db = nullptr); - static duckdb::Connection GetGlobalConnection(); + static void Configure(duckdb::ExtensionLoader& loader); - static void ConfigureGlobal(); + static void ConfigureGlobal(duckdb::DatabaseInstance* db_instance); static void ConfigureTables(duckdb::Connection& con, ConfigType type); static void ConfigureLocal(duckdb::DatabaseInstance& db); @@ -51,7 +51,7 @@ class Config { }; private: - static void SetupGlobalStorageLocation(); + static void SetupGlobalStorageLocation(duckdb::DatabaseInstance* db_instance); static void ConfigSchema(duckdb::Connection& con, std::string& schema_name); static void ConfigPromptTable(duckdb::Connection& con, std::string& schema_name, ConfigType type); static void ConfigModelTable(duckdb::Connection& con, std::string& schema_name, ConfigType type); diff --git a/src/include/flock/model_manager/providers/handlers/base_handler.hpp b/src/include/flock/model_manager/providers/handlers/base_handler.hpp index 3af445e1..b5a43e47 100644 --- a/src/include/flock/model_manager/providers/handlers/base_handler.hpp +++ b/src/include/flock/model_manager/providers/handlers/base_handler.hpp @@ -67,6 +67,36 @@ class BaseModelProviderHandler : public IModelProviderHandler { public: protected: std::vector ExecuteBatch(const std::vector& jsons, bool async = true, const std::string& contentType = "application/json", RequestType request_type = RequestType::Completion) { +#ifdef __EMSCRIPTEN__ + // WASM: Process requests sequentially using emscripten fetch + std::vector results(jsons.size()); + bool is_completion = (request_type == RequestType::Completion); + auto url = is_completion ? getCompletionUrl() : getEmbedUrl(); + + for (size_t i = 0; i < jsons.size(); ++i) { + prepareSessionForRequest(url); + setParameters(jsons[i].dump(), contentType); + auto response = postRequest(contentType); + + if (!response.is_error && !response.text.empty() && isJson(response.text)) { + try { + nlohmann::json parsed = nlohmann::json::parse(response.text); + checkResponse(parsed, request_type); + if (is_completion) { + results[i] = ExtractCompletionOutput(parsed); + } else { + results[i] = ExtractEmbeddingVector(parsed); + } + } catch (const std::exception& e) { + trigger_error(std::string("JSON parse error: ") + e.what()); + } + } else { + trigger_error("Empty or invalid response: " + response.error_message); + } + } + return results; +#else + // Native: Use curl multi-handle for parallel requests struct CurlRequestData { std::string response; CURL* easy = nullptr; @@ -188,9 +218,12 @@ class BaseModelProviderHandler : public IModelProviderHandler { std::remove(requests[i].temp_file_path.c_str()); } - curl_easy_getinfo(requests[i].easy, CURLINFO_RESPONSE_CODE, NULL); + long http_code = 0; + curl_easy_getinfo(requests[i].easy, CURLINFO_RESPONSE_CODE, &http_code); - if (isJson(requests[i].response)) { + if (requests[i].response.empty()) { + trigger_error("Empty response from provider (HTTP " + std::to_string(http_code) + ", URL: " + url + ")"); + } else if (isJson(requests[i].response)) { try { nlohmann::json parsed = nlohmann::json::parse(requests[i].response); checkResponse(parsed, request_type); @@ -206,13 +239,21 @@ class BaseModelProviderHandler : public IModelProviderHandler { try { results[i] = ExtractOutput(parsed, request_type); } catch (const std::exception& e) { - trigger_error(std::string("Output extraction error: ") + e.what()); + std::string msg = e.what(); + if (msg.rfind("[ModelProvider]", 0) == 0) { + throw; + } + trigger_error(std::string("Output extraction error: ") + msg); } } catch (const std::exception& e) { - trigger_error(std::string("Response processing error: ") + e.what()); + std::string msg = e.what(); + if (msg.rfind("[ModelProvider]", 0) == 0) { + throw; + } + trigger_error(std::string("Response processing error: ") + msg); } } else { - trigger_error("Invalid JSON response: " + requests[i].response); + trigger_error("Invalid JSON response (HTTP " + std::to_string(http_code) + ", URL: " + url + "): " + requests[i].response); } // Clean up mime form for transcriptions @@ -233,6 +274,7 @@ class BaseModelProviderHandler : public IModelProviderHandler { curl_multi_cleanup(multi_handle); return results; +#endif } virtual void setParameters(const std::string& data, const std::string& contentType = "") = 0; @@ -266,17 +308,39 @@ class BaseModelProviderHandler : public IModelProviderHandler { virtual std::pair ExtractTokenUsage(const nlohmann::json& response) const = 0; void trigger_error(const std::string& msg) { + const std::string prefix = "[ModelProvider] "; + std::string full_message; + if (msg.rfind(prefix, 0) == 0) { + full_message = msg; + } else { + full_message = prefix + msg; + } + if (_throw_exception) { - throw std::runtime_error("[ModelProvider] error. Reason: " + msg); + throw std::runtime_error(full_message); } else { - std::cerr << "[ModelProvider] error. Reason: " << msg << '\n'; + std::cerr << full_message << '\n'; } } void checkResponse(const nlohmann::json& json, RequestType request_type) { if (json.contains("error")) { - auto reason = json["error"].dump(); - trigger_error(reason); + const auto& err = json["error"]; + std::string reason; + + if (err.is_object()) { + if (err.contains("message") && err["message"].is_string()) { + reason = err["message"].get(); + } else { + reason = err.dump(); + } + } else if (err.is_string()) { + reason = err.get(); + } else { + reason = err.dump(); + } + + trigger_error("Provider error: " + reason); std::cerr << ">> response error :\n" << json.dump(2) << "\n"; } diff --git a/src/include/flock/model_manager/providers/handlers/openai.hpp b/src/include/flock/model_manager/providers/handlers/openai.hpp index 8162e341..d3dbf870 100644 --- a/src/include/flock/model_manager/providers/handlers/openai.hpp +++ b/src/include/flock/model_manager/providers/handlers/openai.hpp @@ -90,8 +90,8 @@ class OpenAIModelManager : public BaseModelProviderHandler { for (const auto& embedding: embeddings) { results.push_back(embedding["embedding"]); } - return results; } + return results; } std::pair ExtractTokenUsage(const nlohmann::json& response) const override { diff --git a/src/include/flock/model_manager/providers/handlers/session.hpp b/src/include/flock/model_manager/providers/handlers/session.hpp index 35158973..a8407229 100644 --- a/src/include/flock/model_manager/providers/handlers/session.hpp +++ b/src/include/flock/model_manager/providers/handlers/session.hpp @@ -1,12 +1,19 @@ #pragma once -#include +#include #include #include -#include #include #include +#ifdef __EMSCRIPTEN__ +#include "wasm_http.hpp" +#include +#else +#include +#include +#endif + struct Response { std::string text; bool is_error; @@ -16,123 +23,234 @@ struct Response { // Simple curl Session inspired by CPR class Session { public: - Session(const std::string& provider, bool throw_exception) - : provider_(provider), throw_exception_{throw_exception} { - initCurl(); - ignoreSSL(); - } - - Session(const std::string& provider, bool throw_exception, std::string proxy_url) - : provider_(provider), throw_exception_{throw_exception} { - initCurl(); - ignoreSSL(); - setProxyUrl(proxy_url); - } - - ~Session() { - curl_easy_cleanup(curl_); - curl_global_cleanup(); - if (mime_form_ != nullptr) { - curl_mime_free(mime_form_); - } - } - - void initCurl() { - curl_global_init(CURL_GLOBAL_ALL); - curl_ = curl_easy_init(); - if (curl_ == nullptr) { - throw std::runtime_error("curl cannot initialize");// here we throw it shouldn't happen - } - curl_easy_setopt(curl_, CURLOPT_NOSIGNAL, 1L); - } - - void ignoreSSL() { curl_easy_setopt(curl_, CURLOPT_SSL_VERIFYPEER, 0L); } - - void setUrl(const std::string& url) { url_ = url; } - - void setToken(const std::string& token, const std::string& organization) { - token_ = token; - organization_ = organization; - } - void setProxyUrl(const std::string& url) { - proxy_url_ = url; - curl_easy_setopt(curl_, CURLOPT_PROXY, proxy_url_.c_str()); - } - - void setBeta(const std::string& beta) { beta_ = beta; } - + // Constructor/Destructor + Session(const std::string& provider, bool throw_exception); + Session(const std::string& provider, bool throw_exception, std::string proxy_url); + ~Session(); + + // Common interface + void ignoreSSL(); + void setUrl(const std::string& url); + void setToken(const std::string& token, const std::string& organization); + void setProxyUrl(const std::string& url); + void setBeta(const std::string& beta); void setBody(const std::string& data); void setMultiformPart(const std::pair& filefield_and_filepath, const std::map& fields); Response getPrepare(); Response postPrepare(const std::string& contentType = ""); - Response postPrepareOllama(const std::string& contentType = ""); Response deletePrepare(); - Response makeRequest(const std::string& contentType = ""); - void set_auth_header(struct curl_slist** headers_ptr); - std::string easyEscape(const std::string& text); + Response postPrepareOllama(const std::string& contentType = ""); Response validOllamaModelsJson(const std::string& url); + std::string easyEscape(const std::string& text); private: - static size_t writeFunction(void* ptr, size_t size, size_t nmemb, std::string* data) { - data->append((char*) ptr, size * nmemb); - return size * nmemb; - } - -private: - CURL* curl_; - CURLcode res_; - curl_mime* mime_form_ = nullptr; std::string url_; - std::string proxy_url_; + std::string body_; std::string token_; std::string organization_; std::string beta_; std::string provider_; - bool throw_exception_; + +#ifdef __EMSCRIPTEN__ + Response makeWasmRequest(const char* method, const std::string& contentType); +#else + // Native-only members + CURL* curl_; + CURLcode res_; + curl_mime* mime_form_ = nullptr; + std::string proxy_url_; std::mutex mutex_request_; + + void initCurl(); + // Native-specific helpers + static size_t writeFunction(void* ptr, size_t size, size_t nmemb, std::string* data) { + data->append((char*) ptr, size * nmemb); + return size * nmemb; + } + Response makeRequest(const std::string& contentType = ""); + void set_auth_header(struct curl_slist** headers_ptr); +#endif }; -inline Response Session::validOllamaModelsJson(const std::string& url) { - std::lock_guard lock(mutex_request_); - struct curl_slist* headers = NULL; - curl_easy_setopt(curl_, CURLOPT_HTTPHEADER, headers); - curl_easy_setopt(curl_, CURLOPT_URL, url.c_str()); +inline void Session::setUrl(const std::string& url) { url_ = url; } +inline void Session::setBeta(const std::string& beta) { beta_ = beta; } +inline void Session::setToken(const std::string& token, const std::string& organization) { + token_ = token; + organization_ = organization; +} - std::string response_string; - std::string header_string; - curl_easy_setopt(curl_, CURLOPT_WRITEFUNCTION, writeFunction); - curl_easy_setopt(curl_, CURLOPT_WRITEDATA, &response_string); - curl_easy_setopt(curl_, CURLOPT_HEADERDATA, &header_string); +inline Session::Session(const std::string& provider, bool throw_exception) + : provider_(provider), throw_exception_(throw_exception) { +#ifndef __EMSCRIPTEN__ + initCurl(); + ignoreSSL(); +#endif +} - res_ = curl_easy_perform(curl_); - bool is_error = false; - std::string error_msg{}; - if (res_ != CURLE_OK) { - is_error = true; - error_msg = " curl_easy_perform() failed: " + std::string{curl_easy_strerror(res_)}; - if (throw_exception_) { - throw std::runtime_error(error_msg); +inline Session::Session(const std::string& provider, bool throw_exception, std::string proxy_url) + : provider_(provider), throw_exception_(throw_exception) { + // Proxy is not supported in WASM +#ifndef __EMSCRIPTEN__ + initCurl(); +#else + ignoreSSL(); + setProxyUrl(proxy_url); +#endif +} + +inline Session::~Session() { +#ifndef __EMSCRIPTEN__ + curl_easy_cleanup(curl_); + curl_global_cleanup(); + if (mime_form_ != nullptr) { + curl_mime_free(mime_form_); + } +#endif +} + + +inline std::string Session::easyEscape(const std::string& text) { +#ifndef __EMSCRIPTEN__ + + char* encoded_output = curl_easy_escape(curl_, text.c_str(), static_cast(text.length())); + std::string str = std::string{encoded_output}; + curl_free(encoded_output); + return str; +#else + std::string result; + for (char c: text) { + if (isalnum(c) || c == '-' || c == '_' || c == '.' || c == '~') { + result += c; } else { - std::cerr << error_msg << '\n'; + char buf[4]; + snprintf(buf, sizeof(buf), "%%%02X", (unsigned char) c); + result += buf; } } - return {response_string, is_error, error_msg}; + return result; +#endif +} + +#ifdef __EMSCRIPTEN__ +inline Response Session::makeWasmRequest(const char* method, const std::string& contentType) { + // Build headers as JSON object + std::string headers_json = "{"; + bool first = true; + + const auto addHeader = [&](const std::string& key, const std::string& value) { + if (!first) headers_json += ","; + headers_json += "\"" + key + "\":\"" + value + "\""; + first = false; + }; + + if (!contentType.empty()) { + addHeader("Content-Type", contentType); + } + + if (!token_.empty()) { + if (provider_ == "OpenAI") { + addHeader("Authorization", "Bearer " + token_); + } else if (provider_ == "Azure") { + addHeader("api-key", token_); + } else if (provider_ == "Anthropic") { + addHeader("x-api-key", token_); + addHeader("anthropic-version", "2023-06-01"); + } + } + + if (!organization_.empty()) { + addHeader(provider_ + "-Organization", organization_); + } + + if (!beta_.empty()) { + addHeader(provider_ + "-Beta", beta_); + } + + headers_json += "}"; + + // Make the request via JavaScript + char* result = wasm_http_request(method, url_.c_str(), body_.c_str(), headers_json.c_str()); + std::string result_str(result); + free(result); + + Response response; + try { + auto json_result = nlohmann::json::parse(result_str); + + int status = json_result.value("status", 0); + std::string response_text = json_result.value("response", ""); + bool has_error = json_result.contains("error"); + + if (has_error || status == 0) { + response.is_error = true; + response.error_message = provider_ + " HTTP request failed: " + result_str; + if (throw_exception_) { + throw std::runtime_error(response.error_message); + } + } else if (status >= 200 && status < 300) { + response.text = response_text; + response.is_error = false; + } else { + response.is_error = true; + response.error_message = provider_ + " HTTP " + std::to_string(status) + ": " + response_text; + if (throw_exception_) { + throw std::runtime_error(response.error_message); + } + } + } catch (const std::exception& e) { + response.is_error = true; + response.error_message = provider_ + " Error parsing response: " + e.what(); + if (throw_exception_) { + throw; + } + } + + return response; +} +#endif + + +#ifndef __EMSCRIPTEN__ +inline void Session::initCurl() { + curl_ = curl_easy_init(); + if (curl_ == nullptr) { + throw std::runtime_error("curl cannot initialize"); + } + curl_easy_setopt(curl_, CURLOPT_NOSIGNAL, 1); +} +#endif + +inline void Session::ignoreSSL() { +#ifndef __EMSCRIPTEN__ + curl_easy_setopt(curl_, CURLOPT_SSL_VERIFYPEER, 0L); +#endif +} + +inline void Session::setProxyUrl(const std::string& url) { +#ifndef __EMSCRIPTEN__ + proxy_url_ = url; + curl_easy_setopt(curl_, CURLOPT_PROXY, proxy_url_.c_str()); +#endif } inline void Session::setBody(const std::string& data) { +#ifndef __EMSCRIPTEN__ if (curl_) { curl_easy_setopt(curl_, CURLOPT_POSTFIELDSIZE, data.length()); curl_easy_setopt(curl_, CURLOPT_POSTFIELDS, data.data()); } +#else + body_ = data; +#endif } -inline void Session::setMultiformPart(const std::pair& fieldfield_and_filepath, +inline void Session::setMultiformPart(const std::pair& filefield_and_filepath, const std::map& fields) { - // https://curl.se/libcurl/c/curl_mime_init.html +#ifndef __EMSCRIPTEN__ if (curl_) { if (mime_form_ != nullptr) { curl_mime_free(mime_form_); @@ -143,8 +261,8 @@ inline void Session::setMultiformPart(const std::pair& mime_form_ = curl_mime_init(curl_); field = curl_mime_addpart(mime_form_); - curl_mime_name(field, fieldfield_and_filepath.first.c_str()); - curl_mime_filedata(field, fieldfield_and_filepath.second.c_str()); + curl_mime_name(field, filefield_and_filepath.first.c_str()); + curl_mime_filedata(field, filefield_and_filepath.second.c_str()); for (const auto& field_pair: fields) { field = curl_mime_addpart(mime_form_); @@ -154,20 +272,49 @@ inline void Session::setMultiformPart(const std::pair& curl_easy_setopt(curl_, CURLOPT_MIMEPOST, mime_form_); } +#else + throw std::runtime_error("Multipart form data not supported in WASM"); +#endif } inline Response Session::getPrepare() { +#ifndef __EMSCRIPTEN__ if (curl_) { curl_easy_setopt(curl_, CURLOPT_HTTPGET, 1L); curl_easy_setopt(curl_, CURLOPT_POST, 0L); curl_easy_setopt(curl_, CURLOPT_NOBODY, 0L); } return makeRequest(); +#else + + return makeWasmRequest("GET", ""); + +#endif +} + +inline Response Session::postPrepare(const std::string& contentType) { +#ifndef __EMSCRIPTEN__ + return makeRequest(contentType); +#else + return makeWasmRequest("POST", contentType.empty() ? "application/json" : contentType); +#endif } -inline Response Session::postPrepare(const std::string& contentType) { return makeRequest(contentType); } +inline Response Session::deletePrepare() { +#ifndef __EMSCRIPTEN__ + if (curl_) { + curl_easy_setopt(curl_, CURLOPT_HTTPGET, 0L); + curl_easy_setopt(curl_, CURLOPT_NOBODY, 0L); + curl_easy_setopt(curl_, CURLOPT_CUSTOMREQUEST, "DELETE"); + } + return makeRequest(); +#else + return makeWasmRequest("DELETE", ""); +#endif +} inline Response Session::postPrepareOllama(const std::string& contentType) { +#ifndef __EMSCRIPTEN__ std::lock_guard lock(mutex_request_); struct curl_slist* headers = NULL; @@ -194,28 +341,45 @@ inline Response Session::postPrepareOllama(const std::string& contentType) { } } return {response_string, is_error, error_msg}; +#else + return makeWasmRequest("POST", contentType.empty() ? "application/json" : contentType); +#endif } -inline Response Session::deletePrepare() { - if (curl_) { - curl_easy_setopt(curl_, CURLOPT_HTTPGET, 0L); - curl_easy_setopt(curl_, CURLOPT_NOBODY, 0L); - curl_easy_setopt(curl_, CURLOPT_CUSTOMREQUEST, "DELETE"); - } - return makeRequest(); -} +inline Response Session::validOllamaModelsJson(const std::string& url) { +#ifndef __EMSCRIPTEN__ + std::lock_guard lock(mutex_request_); -inline void Session::set_auth_header(struct curl_slist** headers_ptr) { - auto headers = *headers_ptr; - if (provider_ == "OpenAI") { - std::string auth_str = "Authorization: Bearer " + token_; - headers = curl_slist_append(headers, auth_str.c_str()); - } else if (provider_ == "Azure") { - std::string auth_str = "api-key: " + token_; - headers = curl_slist_append(headers, auth_str.c_str()); + struct curl_slist* headers = NULL; + curl_easy_setopt(curl_, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl_, CURLOPT_URL, url.c_str()); + + std::string response_string; + std::string header_string; + curl_easy_setopt(curl_, CURLOPT_WRITEFUNCTION, writeFunction); + curl_easy_setopt(curl_, CURLOPT_WRITEDATA, &response_string); + curl_easy_setopt(curl_, CURLOPT_HEADERDATA, &header_string); + + res_ = curl_easy_perform(curl_); + bool is_error = false; + std::string error_msg{}; + if (res_ != CURLE_OK) { + is_error = true; + error_msg = " curl_easy_perform() failed: " + std::string{curl_easy_strerror(res_)}; + if (throw_exception_) { + throw std::runtime_error(error_msg); + } else { + std::cerr << error_msg << '\n'; + } } + return {response_string, is_error, error_msg}; +#else + url_ = url; + return makeWasmRequest("GET", ""); +#endif } +#ifndef __EMSCRIPTEN__ inline Response Session::makeRequest(const std::string& contentType) { std::lock_guard lock(mutex_request_); @@ -232,6 +396,10 @@ inline Response Session::makeRequest(const std::string& contentType) { set_auth_header(&headers); + if (provider_ == "Anthropic") { + headers = curl_slist_append(headers, "anthropic-version: 2023-06-01"); + } + std::string organization_str = provider_ + "-Organization: "; if (!organization_.empty()) { organization_str += organization_; @@ -241,7 +409,7 @@ inline Response Session::makeRequest(const std::string& contentType) { std::string beta_str = provider_ + "-Beta: "; if (!beta_.empty()) { beta_str += beta_; - headers = curl_slist_append(headers, beta_.c_str()); + headers = curl_slist_append(headers, beta_str.c_str()); } curl_easy_setopt(curl_, CURLOPT_HTTPHEADER, headers); @@ -270,9 +438,16 @@ inline Response Session::makeRequest(const std::string& contentType) { return {response_string, is_error, error_msg}; } -inline std::string Session::easyEscape(const std::string& text) { - char* encoded_output = curl_easy_escape(curl_, text.c_str(), static_cast(text.length())); - const auto str = std::string{encoded_output}; - curl_free(encoded_output); - return str; +inline void Session::set_auth_header(struct curl_slist** headers_ptr) { + if (provider_ == "OpenAI") { + std::string auth_str = "Authorization: Bearer " + token_; + *headers_ptr = curl_slist_append(*headers_ptr, auth_str.c_str()); + } else if (provider_ == "Azure") { + std::string auth_str = "api-key: " + token_; + *headers_ptr = curl_slist_append(*headers_ptr, auth_str.c_str()); + } else if (provider_ == "Anthropic") { + std::string auth_str = "x-api-key: " + token_; + *headers_ptr = curl_slist_append(*headers_ptr, auth_str.c_str()); + } } +#endif diff --git a/src/include/flock/model_manager/providers/handlers/wasm_http.hpp b/src/include/flock/model_manager/providers/handlers/wasm_http.hpp new file mode 100644 index 00000000..5c268662 --- /dev/null +++ b/src/include/flock/model_manager/providers/handlers/wasm_http.hpp @@ -0,0 +1,7 @@ +#pragma once + +#ifdef __EMSCRIPTEN__ + +extern "C" char* wasm_http_request(const char* method, const char* url, const char* body, const char* headers_json); + +#endif // __EMSCRIPTEN__ diff --git a/src/model_manager/CMakeLists.txt b/src/model_manager/CMakeLists.txt index cfc0af5c..abf66785 100644 --- a/src/model_manager/CMakeLists.txt +++ b/src/model_manager/CMakeLists.txt @@ -6,5 +6,6 @@ set(EXTENSION_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/providers/adapters/azure.cpp ${CMAKE_CURRENT_SOURCE_DIR}/providers/adapters/openai.cpp ${CMAKE_CURRENT_SOURCE_DIR}/providers/adapters/ollama.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/providers/handlers/wasm_http.cpp ${EXTENSION_SOURCES} PARENT_SCOPE) diff --git a/src/model_manager/providers/handlers/wasm_http.cpp b/src/model_manager/providers/handlers/wasm_http.cpp new file mode 100644 index 00000000..1ed9f085 --- /dev/null +++ b/src/model_manager/providers/handlers/wasm_http.cpp @@ -0,0 +1,63 @@ +#ifdef __EMSCRIPTEN__ + +#include +#include +#include + +// JavaScript XMLHttpRequest wrapper - works synchronously in web workers +EM_JS(char*, wasm_http_request_impl, (const char* method, const char* url, const char* body, const char* headers_json), { + try { + var xhr = new XMLHttpRequest(); + xhr.open(UTF8ToString(method), UTF8ToString(url), false); // false = synchronous + + // Parse and set headers + var headersStr = UTF8ToString(headers_json); + if (headersStr && headersStr !== "{}") { + try { + var headers = JSON.parse(headersStr); + for (var key in headers) { + if (headers.hasOwnProperty(key)) { + xhr.setRequestHeader(key, headers[key]); + } + } + } catch (e) { + // Ignore header parsing errors + } + } + + var bodyStr = UTF8ToString(body); + if (bodyStr && bodyStr.length > 0) { + xhr.send(bodyStr); + } else { + xhr.send(); + } + + // Return JSON with status and response + var result = JSON.stringify({ + status: xhr.status, + response: xhr.responseText + }); + + var lengthBytes = lengthBytesUTF8(result) + 1; + var stringOnWasmHeap = _malloc(lengthBytes); + stringToUTF8(result, stringOnWasmHeap, lengthBytes); + return stringOnWasmHeap; + } catch (e) { + var errorResult = JSON.stringify({ + status: 0, + response: "", + error: e.toString() + }); + var lengthBytes = lengthBytesUTF8(errorResult) + 1; + var stringOnWasmHeap = _malloc(lengthBytes); + stringToUTF8(errorResult, stringOnWasmHeap, lengthBytes); + return stringOnWasmHeap; + } +}); + +// C++ wrapper function +extern "C" char* wasm_http_request(const char* method, const char* url, const char* body, const char* headers_json) { + return wasm_http_request_impl(method, url, body, headers_json); +} + +#endif // __EMSCRIPTEN__