diff --git a/src/functions/aggregate/llm_first_or_last/implementation.cpp b/src/functions/aggregate/llm_first_or_last/implementation.cpp
index db81038a..8cf8906e 100644
--- a/src/functions/aggregate/llm_first_or_last/implementation.cpp
+++ b/src/functions/aggregate/llm_first_or_last/implementation.cpp
@@ -5,11 +5,13 @@ namespace flockmtl {
 int LlmFirstOrLast::GetFirstOrLastTupleId(const nlohmann::json& tuples) {
     nlohmann::json data;
     const auto prompt = PromptManager::Render(user_query, tuples, function_type, model.GetModelDetails().tuple_format);
-    auto response = model.CallComplete(prompt, true, OutputType::INTEGER);
+    model.AddCompletionRequest(prompt, true, OutputType::INTEGER);
+    auto response = model.CollectCompletions()[0];
     return response["items"][0];
 }
 
-nlohmann::json LlmFirstOrLast::Evaluate(nlohmann::json& tuples) {
+template<>
+nlohmann::json LlmFirstOrLast::Evaluate<ExecutionMode::SYNC>(nlohmann::json& tuples) {
     auto batch_tuples = nlohmann::json::array();
     int start_index = 0;
     auto batch_size = std::min(model.GetModelDetails().batch_size, static_cast<int>(tuples.size()));
@@ -46,9 +48,43 @@ nlohmann::json LlmFirstOrLast::Evaluate(nlohmann::json& tuples) {
     return batch_tuples[0];
 }
 
+template<>
+nlohmann::json LlmFirstOrLast::Evaluate<ExecutionMode::ASYNC>(nlohmann::json& tuples) {
+    auto batch_size = std::min(model.GetModelDetails().batch_size, static_cast<int>(tuples.size()));
+    if (batch_size <= 0) {
+        throw std::runtime_error("Batch size must be greater than zero");
+    }
+
+    std::vector<nlohmann::json> current_tuples = tuples;
+
+    do {
+        auto start_index = 0;
+        const auto n = static_cast<int>(current_tuples.size());
+        while (start_index < n) {
+            auto this_batch_size = std::min(batch_size, n - start_index);
+            nlohmann::json batch = nlohmann::json::array();
+            for (int i = 0; i < this_batch_size; ++i) {
+                batch.push_back(current_tuples[start_index + i]);
+            }
+            const auto prompt = PromptManager::Render(user_query, batch, function_type, model.GetModelDetails().tuple_format);
+            model.AddCompletionRequest(prompt, true, OutputType::INTEGER);
+            start_index += this_batch_size;
+        }
+        std::vector<nlohmann::json> new_tuples;
+        auto responses = model.CollectCompletions();
+        for (size_t i = 0; i < responses.size(); ++i) {
+            int result_idx = responses[i]["items"][0];
+            new_tuples.push_back(current_tuples[result_idx]);
+        }
+        current_tuples = std::move(new_tuples);
+    } while (current_tuples.size() > 1);
+    current_tuples[0].erase("flockmtl_tuple_id");
+    return current_tuples[0];
+}
+
 void LlmFirstOrLast::FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data,
                                      duckdb::Vector& result, idx_t count, idx_t offset,
-                                     AggregateFunctionType function_type) {
+                                     AggregateFunctionType function_type, ExecutionMode mode) {
     const auto states_vector = reinterpret_cast<AggregateFunctionState**>(duckdb::FlatVector::GetData<duckdb::data_ptr_t>(states));
     auto function_instance = AggregateFunctionBase::GetInstance<LlmFirstOrLast>();
     function_instance->function_type = function_type;
@@ -64,7 +100,17 @@ void LlmFirstOrLast::FinalizeResults(duckdb::Vector& states, duckdb::AggregateIn
             tuple_with_id["flockmtl_tuple_id"] = j;
             tuples_with_ids.push_back(tuple_with_id);
         }
-        auto response = function_instance->Evaluate(tuples_with_ids);
+        nlohmann::json response;
+        switch (mode) {
+            case ExecutionMode::ASYNC:
+                response = function_instance->Evaluate<ExecutionMode::ASYNC>(tuples_with_ids);
+                break;
+            case ExecutionMode::SYNC:
+                response = function_instance->Evaluate<ExecutionMode::SYNC>(tuples_with_ids);
+                break;
+            default:
+                break;
+        }
         result.SetValue(idx, response.dump());
     } else {
         result.SetValue(idx, "{}");// Empty JSON object for null/empty states
diff --git a/src/functions/aggregate/llm_first_or_last/instantiations.cpp b/src/functions/aggregate/llm_first_or_last/instantiations.cpp
index e85c541a..f2308eb7 100644
--- a/src/functions/aggregate/llm_first_or_last/instantiations.cpp
+++ b/src/functions/aggregate/llm_first_or_last/instantiations.cpp
@@ -10,9 +10,13 @@ template void AggregateFunctionBase::SimpleUpdate<LlmFirstOrLast>(duckdb::Vector
                                                                   duckdb::data_ptr_t, idx_t);
 template void 
AggregateFunctionBase::Combine(duckdb::Vector&, duckdb::Vector&, duckdb::AggregateInputData&, idx_t); -template void LlmFirstOrLast::Finalize(duckdb::Vector&, duckdb::AggregateInputData&, - duckdb::Vector&, idx_t, idx_t); -template void LlmFirstOrLast::Finalize(duckdb::Vector&, duckdb::AggregateInputData&, - duckdb::Vector&, idx_t, idx_t); +template void LlmFirstOrLast::Finalize(duckdb::Vector&, duckdb::AggregateInputData&, + duckdb::Vector&, idx_t, idx_t); +template void LlmFirstOrLast::Finalize(duckdb::Vector&, duckdb::AggregateInputData&, + duckdb::Vector&, idx_t, idx_t); +template void LlmFirstOrLast::Finalize(duckdb::Vector&, duckdb::AggregateInputData&, + duckdb::Vector&, idx_t, idx_t); +template void LlmFirstOrLast::Finalize(duckdb::Vector&, duckdb::AggregateInputData&, + duckdb::Vector&, idx_t, idx_t); -} // namespace flockmtl +}// namespace flockmtl diff --git a/src/functions/aggregate/llm_first_or_last/registry.cpp b/src/functions/aggregate/llm_first_or_last/registry.cpp index c464b25f..0b77441b 100644 --- a/src/functions/aggregate/llm_first_or_last/registry.cpp +++ b/src/functions/aggregate/llm_first_or_last/registry.cpp @@ -4,25 +4,33 @@ namespace flockmtl { void AggregateRegistry::RegisterLlmFirst(duckdb::DatabaseInstance& db) { - auto string_concat = duckdb::AggregateFunction( - "llm_first", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, - duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize, - LlmFirstOrLast::Initialize, LlmFirstOrLast::Operation, LlmFirstOrLast::Combine, - LlmFirstOrLast::Finalize, LlmFirstOrLast::SimpleUpdate, - nullptr, LlmFirstOrLast::Destroy); - - duckdb::ExtensionUtil::RegisterFunction(db, string_concat); + duckdb::ExtensionUtil::RegisterFunction(db, duckdb::AggregateFunction( + "llm_first", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, + duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize, + LlmFirstOrLast::Initialize, 
LlmFirstOrLast::Operation, LlmFirstOrLast::Combine, + LlmFirstOrLast::Finalize, LlmFirstOrLast::SimpleUpdate, + nullptr, LlmFirstOrLast::Destroy)); + duckdb::ExtensionUtil::RegisterFunction(db, duckdb::AggregateFunction( + "llm_first_s", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, + duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize, + LlmFirstOrLast::Initialize, LlmFirstOrLast::Operation, LlmFirstOrLast::Combine, + LlmFirstOrLast::Finalize, LlmFirstOrLast::SimpleUpdate, + nullptr, LlmFirstOrLast::Destroy)); } void AggregateRegistry::RegisterLlmLast(duckdb::DatabaseInstance& db) { - auto string_concat = duckdb::AggregateFunction( - "llm_last", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, - duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize, - LlmFirstOrLast::Initialize, LlmFirstOrLast::Operation, LlmFirstOrLast::Combine, - LlmFirstOrLast::Finalize, LlmFirstOrLast::SimpleUpdate, - nullptr, LlmFirstOrLast::Destroy); - - duckdb::ExtensionUtil::RegisterFunction(db, string_concat); + duckdb::ExtensionUtil::RegisterFunction(db, duckdb::AggregateFunction( + "llm_last", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, + duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize, + LlmFirstOrLast::Initialize, LlmFirstOrLast::Operation, LlmFirstOrLast::Combine, + LlmFirstOrLast::Finalize, LlmFirstOrLast::SimpleUpdate, + nullptr, LlmFirstOrLast::Destroy)); + duckdb::ExtensionUtil::RegisterFunction(db, duckdb::AggregateFunction( + "llm_last_s", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, + duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize, + LlmFirstOrLast::Initialize, LlmFirstOrLast::Operation, LlmFirstOrLast::Combine, + LlmFirstOrLast::Finalize, LlmFirstOrLast::SimpleUpdate, + nullptr, LlmFirstOrLast::Destroy)); } }// namespace flockmtl \ No newline at end of file diff --git 
a/src/functions/aggregate/llm_reduce/implementation.cpp b/src/functions/aggregate/llm_reduce/implementation.cpp
index 3c7ced63..e1431126 100644
--- a/src/functions/aggregate/llm_reduce/implementation.cpp
+++ b/src/functions/aggregate/llm_reduce/implementation.cpp
@@ -6,12 +6,14 @@ nlohmann::json LlmReduce::ReduceBatch(const nlohmann::json& tuples, const Aggreg
     nlohmann::json data;
     const auto prompt = PromptManager::Render(user_query, tuples, function_type, model.GetModelDetails().tuple_format);
     OutputType output_type = OutputType::STRING;
-    auto response = model.CallComplete(prompt, true, output_type);
+    model.AddCompletionRequest(prompt, true, output_type);
+    auto response = model.CollectCompletions()[0];
     return response["items"][0];
 };
 
-nlohmann::json LlmReduce::ReduceLoop(const std::vector<nlohmann::json>& tuples,
-                                     const AggregateFunctionType& function_type) {
+template<>
+nlohmann::json LlmReduce::ReduceLoop<ExecutionMode::SYNC>(const std::vector<nlohmann::json>& tuples,
+                                                          const AggregateFunctionType& function_type) {
     auto batch_tuples = nlohmann::json::array();
     int start_index = 0;
     auto batch_size = std::min(model.GetModelDetails().batch_size, static_cast<int>(tuples.size()));
@@ -46,9 +48,49 @@ nlohmann::json LlmReduce::ReduceLoop(const std::vector<nlohmann::json>& tuples,
     return batch_tuples[0];
 }
 
+template<>
+nlohmann::json LlmReduce::ReduceLoop<ExecutionMode::ASYNC>(const std::vector<nlohmann::json>& tuples,
+                                                           const AggregateFunctionType& function_type) {
+
+    auto batch_size = std::min(model.GetModelDetails().batch_size, static_cast<int>(tuples.size()));
+    if (batch_size <= 0) {
+        throw std::runtime_error("Batch size must be greater than zero");
+    }
+
+    std::vector<nlohmann::json> current_tuples = tuples;
+
+    do {
+        auto start_index = 0;
+        const auto n = static_cast<int>(current_tuples.size());
+
+        // Prepare all batches and add all completion requests
+        while (start_index < n) {
+            auto this_batch_size = std::min(batch_size, n - start_index);
+            nlohmann::json batch = nlohmann::json::array();
+            for (int i = 0; i < this_batch_size; ++i) {
+                batch.push_back(current_tuples[start_index + i]);
+            }
+            const auto prompt = PromptManager::Render(user_query, batch, function_type, model.GetModelDetails().tuple_format);
+            OutputType output_type = OutputType::STRING;
+            model.AddCompletionRequest(prompt, true, output_type);
+            start_index += this_batch_size;
+        }
+
+        // Collect all completions at once
+        std::vector<nlohmann::json> new_tuples;
+        auto responses = model.CollectCompletions();
+        for (size_t i = 0; i < responses.size(); ++i) {
+            new_tuples.push_back(responses[i]["items"][0]);
+        }
+        current_tuples = std::move(new_tuples);
+    } while (current_tuples.size() > 1);
+
+    return current_tuples[0];
+}
+
 void LlmReduce::FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data,
                                 duckdb::Vector& result, idx_t count, idx_t offset,
-                                const AggregateFunctionType function_type) {
+                                const AggregateFunctionType function_type, ExecutionMode mode) {
     const auto states_vector = reinterpret_cast<AggregateFunctionState**>(duckdb::FlatVector::GetData<duckdb::data_ptr_t>(states));
     auto function_instance = AggregateFunctionBase::GetInstance<LlmReduce>();
@@ -57,7 +99,17 @@ void LlmReduce::FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputDa
         auto* state = states_vector[idx];
 
         if (state && state->value) {
-            auto response = function_instance->ReduceLoop(*state->value, function_type);
+            nlohmann::json response;
+            switch (mode) {
+                case ExecutionMode::SYNC:
+                    response = function_instance->ReduceLoop<ExecutionMode::SYNC>(*state->value, function_type);
+                    break;
+                case ExecutionMode::ASYNC:
+                    response = function_instance->ReduceLoop<ExecutionMode::ASYNC>(*state->value, function_type);
+                    break;
+                default:
+                    break;
+            }
             if (response.is_string()) {
                 result.SetValue(idx, response.get<std::string>());
             } else {
diff --git a/src/functions/aggregate/llm_reduce/instantiations.cpp b/src/functions/aggregate/llm_reduce/instantiations.cpp
index c8b8e30f..ac57526d 100644
--- a/src/functions/aggregate/llm_reduce/instantiations.cpp
+++ b/src/functions/aggregate/llm_reduce/instantiations.cpp
@@ -10,7 +10,9 @@ template void AggregateFunctionBase::SimpleUpdate<LlmReduce>(duckdb::Vector[], duckdb::data_ptr_t, 
idx_t); template void AggregateFunctionBase::Combine(duckdb::Vector&, duckdb::Vector&, duckdb::AggregateInputData&, idx_t); -template void LlmReduce::Finalize(duckdb::Vector&, duckdb::AggregateInputData&, - duckdb::Vector&, idx_t, idx_t); +template void LlmReduce::Finalize(duckdb::Vector&, duckdb::AggregateInputData&, + duckdb::Vector&, idx_t, idx_t); +template void LlmReduce::Finalize(duckdb::Vector&, duckdb::AggregateInputData&, + duckdb::Vector&, idx_t, idx_t); }// namespace flockmtl diff --git a/src/functions/aggregate/llm_reduce/registry.cpp b/src/functions/aggregate/llm_reduce/registry.cpp index d93d0420..2df2f9f9 100644 --- a/src/functions/aggregate/llm_reduce/registry.cpp +++ b/src/functions/aggregate/llm_reduce/registry.cpp @@ -4,14 +4,18 @@ namespace flockmtl { void AggregateRegistry::RegisterLlmReduce(duckdb::DatabaseInstance& db) { - auto string_concat = duckdb::AggregateFunction( - "llm_reduce", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, - duckdb::LogicalType::JSON(), duckdb::AggregateFunction::StateSize, - LlmReduce::Initialize, LlmReduce::Operation, LlmReduce::Combine, - LlmReduce::Finalize, LlmReduce::SimpleUpdate, - nullptr, LlmReduce::Destroy); - - duckdb::ExtensionUtil::RegisterFunction(db, string_concat); + duckdb::ExtensionUtil::RegisterFunction(db, duckdb::AggregateFunction( + "llm_reduce", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, + duckdb::LogicalType::JSON(), duckdb::AggregateFunction::StateSize, + LlmReduce::Initialize, LlmReduce::Operation, LlmReduce::Combine, + LlmReduce::Finalize, LlmReduce::SimpleUpdate, + nullptr, LlmReduce::Destroy)); + duckdb::ExtensionUtil::RegisterFunction(db, duckdb::AggregateFunction( + "llm_reduce_s", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, + duckdb::LogicalType::JSON(), duckdb::AggregateFunction::StateSize, + LlmReduce::Initialize, LlmReduce::Operation, LlmReduce::Combine, + LlmReduce::Finalize, 
LlmReduce::SimpleUpdate, + nullptr, LlmReduce::Destroy)); } }// namespace flockmtl \ No newline at end of file diff --git a/src/functions/aggregate/llm_rerank/implementation.cpp b/src/functions/aggregate/llm_rerank/implementation.cpp index e193f06d..1b7708a5 100644 --- a/src/functions/aggregate/llm_rerank/implementation.cpp +++ b/src/functions/aggregate/llm_rerank/implementation.cpp @@ -6,8 +6,9 @@ std::vector LlmRerank::RerankBatch(const nlohmann::json& tuples) { nlohmann::json data; auto prompt = PromptManager::Render(user_query, tuples, AggregateFunctionType::RERANK, model.GetModelDetails().tuple_format); - auto response = model.CallComplete(prompt, true, OutputType::INTEGER); - return response["items"]; + model.AddCompletionRequest(prompt, true, OutputType::INTEGER); + auto responses = model.CollectCompletions(); + return responses[0]["items"]; }; nlohmann::json LlmRerank::SlidingWindow(nlohmann::json& tuples) { diff --git a/src/functions/aggregate/llm_rerank/registry.cpp b/src/functions/aggregate/llm_rerank/registry.cpp index 859a9b03..007c4822 100644 --- a/src/functions/aggregate/llm_rerank/registry.cpp +++ b/src/functions/aggregate/llm_rerank/registry.cpp @@ -4,13 +4,11 @@ namespace flockmtl { void AggregateRegistry::RegisterLlmRerank(duckdb::DatabaseInstance& db) { - auto string_concat = duckdb::AggregateFunction( - "llm_rerank", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, - duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize, - LlmRerank::Initialize, LlmRerank::Operation, LlmRerank::Combine, LlmRerank::Finalize, LlmRerank::SimpleUpdate, - nullptr, LlmRerank::Destroy); - - duckdb::ExtensionUtil::RegisterFunction(db, string_concat); + duckdb::ExtensionUtil::RegisterFunction(db, duckdb::AggregateFunction( + "llm_rerank", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, + duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize, + LlmRerank::Initialize, 
LlmRerank::Operation, LlmRerank::Combine, LlmRerank::Finalize, LlmRerank::SimpleUpdate, + nullptr, LlmRerank::Destroy)); } }// namespace flockmtl diff --git a/src/functions/scalar/llm_complete/implementation.cpp b/src/functions/scalar/llm_complete/implementation.cpp index c05bd899..e06a3cdb 100644 --- a/src/functions/scalar/llm_complete/implementation.cpp +++ b/src/functions/scalar/llm_complete/implementation.cpp @@ -21,7 +21,7 @@ void LlmComplete::ValidateArguments(duckdb::DataChunk& args) { } } -std::vector LlmComplete::Operation(duckdb::DataChunk& args) { +std::vector LlmComplete::Operation(duckdb::DataChunk& args, ExecutionMode mode) { // LlmComplete::ValidateArguments(args); auto model_details_json = CastVectorOfStructsToJson(args.data[0], 1)[0]; @@ -32,7 +32,8 @@ std::vector LlmComplete::Operation(duckdb::DataChunk& args) { std::vector results; if (args.ColumnCount() == 2) { auto template_str = prompt_details.prompt; - auto response = model.CallComplete(template_str, false); + model.AddCompletionRequest(template_str, false); + auto response = model.CollectCompletions()[0]; if (response.is_string()) { results.push_back(response.get()); } else { @@ -40,9 +41,17 @@ std::vector LlmComplete::Operation(duckdb::DataChunk& args) { } } else { auto tuples = CastVectorOfStructsToJson(args.data[2], args.size()); - - auto responses = BatchAndComplete(tuples, prompt_details.prompt, ScalarFunctionType::COMPLETE, model); - + nlohmann::json responses; + switch (mode) { + case ExecutionMode::SYNC: + responses = BatchAndComplete(tuples, prompt_details.prompt, ScalarFunctionType::COMPLETE, model); + break; + case ExecutionMode::ASYNC: + responses = BatchAndComplete(tuples, prompt_details.prompt, ScalarFunctionType::COMPLETE, model); + break; + default: + break; + } results.reserve(responses.size()); for (const auto& response: responses) { if (response.is_string()) { @@ -55,9 +64,10 @@ std::vector LlmComplete::Operation(duckdb::DataChunk& args) { return results; } +template void 
LlmComplete::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result) { - if (const auto results = LlmComplete::Operation(args); static_cast(results.size()) == 1) { + if (const auto results = LlmComplete::Operation(args, MODE); static_cast(results.size()) == 1) { auto empty_vec = duckdb::Vector(std::string()); duckdb::UnaryExecutor::Execute( empty_vec, result, args.size(), @@ -70,4 +80,7 @@ void LlmComplete::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& stat } } +template void LlmComplete::Execute(duckdb::DataChunk&, duckdb::ExpressionState&, duckdb::Vector&); +template void LlmComplete::Execute(duckdb::DataChunk&, duckdb::ExpressionState&, duckdb::Vector&); + }// namespace flockmtl diff --git a/src/functions/scalar/llm_complete/registry.cpp b/src/functions/scalar/llm_complete/registry.cpp index e90799ab..7d8d9c3e 100644 --- a/src/functions/scalar/llm_complete/registry.cpp +++ b/src/functions/scalar/llm_complete/registry.cpp @@ -5,7 +5,10 @@ namespace flockmtl { void ScalarRegistry::RegisterLlmComplete(duckdb::DatabaseInstance& db) { duckdb::ExtensionUtil::RegisterFunction(db, duckdb::ScalarFunction("llm_complete", {}, duckdb::LogicalType::JSON(), - LlmComplete::Execute, nullptr, nullptr, nullptr, + LlmComplete::Execute, nullptr, nullptr, nullptr, + nullptr, duckdb::LogicalType::ANY)); + duckdb::ExtensionUtil::RegisterFunction(db, duckdb::ScalarFunction("llm_complete_s", {}, duckdb::LogicalType::JSON(), + LlmComplete::Execute, nullptr, nullptr, nullptr, nullptr, duckdb::LogicalType::ANY)); } diff --git a/src/functions/scalar/llm_embedding/implementation.cpp b/src/functions/scalar/llm_embedding/implementation.cpp index 121723aa..b48bd582 100644 --- a/src/functions/scalar/llm_embedding/implementation.cpp +++ b/src/functions/scalar/llm_embedding/implementation.cpp @@ -14,8 +14,8 @@ void LlmEmbedding::ValidateArguments(duckdb::DataChunk& args) { } } -std::vector> LlmEmbedding::Operation(duckdb::DataChunk& args) { - // 
LlmEmbedding::ValidateArguments(args); +std::vector> LlmEmbedding::Operation(duckdb::DataChunk& args, ExecutionMode mode) { + LlmEmbedding::ValidateArguments(args); auto inputs = CastVectorOfStructsToJson(args.data[1], args.size()); auto model_details_json = CastVectorOfStructsToJson(args.data[0], 1)[0]; @@ -36,19 +36,23 @@ std::vector> LlmEmbedding::Operation(duckdb::DataC batch_size = static_cast(prepared_inputs.size()); } - std::vector> results; for (size_t i = 0; i < prepared_inputs.size(); i += batch_size) { std::vector batch_inputs; for (size_t j = i; j < i + batch_size && j < prepared_inputs.size(); j++) { batch_inputs.push_back(prepared_inputs[j]); } - auto embeddings = model.CallEmbedding(batch_inputs); - for (size_t index = 0; index < embeddings.size(); index++) { - duckdb::vector embedding; - for (auto& value: embeddings[index]) { - embedding.push_back(duckdb::Value(static_cast(value))); + model.AddEmbeddingRequest(batch_inputs); + } + + std::vector> results; + auto all_embeddings = model.CollectEmbeddings(); + for (size_t index = 0; index < all_embeddings.size(); index++) { + for (auto& embedding: all_embeddings[index]) { + duckdb::vector formatted_embedding; + for (auto& value: embedding) { + formatted_embedding.push_back(duckdb::Value(static_cast(value))); } - results.push_back(embedding); + results.push_back(formatted_embedding); } } return results; diff --git a/src/functions/scalar/llm_filter/implementation.cpp b/src/functions/scalar/llm_filter/implementation.cpp index 267c1331..e29fc8d1 100644 --- a/src/functions/scalar/llm_filter/implementation.cpp +++ b/src/functions/scalar/llm_filter/implementation.cpp @@ -19,7 +19,7 @@ void LlmFilter::ValidateArguments(duckdb::DataChunk& args) { } } -std::vector LlmFilter::Operation(duckdb::DataChunk& args) { +std::vector LlmFilter::Operation(duckdb::DataChunk& args, ExecutionMode mode) { // LlmFilter::ValidateArguments(args); auto model_details_json = CastVectorOfStructsToJson(args.data[0], 1)[0]; @@ -29,7 
+29,17 @@ std::vector LlmFilter::Operation(duckdb::DataChunk& args) { auto tuples = CastVectorOfStructsToJson(args.data[2], args.size()); - auto responses = BatchAndComplete(tuples, prompt_details.prompt, ScalarFunctionType::FILTER, model); + nlohmann::json responses; + switch (mode) { + case ExecutionMode::SYNC: + responses = BatchAndComplete(tuples, prompt_details.prompt, ScalarFunctionType::FILTER, model); + break; + case ExecutionMode::ASYNC: + responses = BatchAndComplete(tuples, prompt_details.prompt, ScalarFunctionType::FILTER, model); + break; + default: + break; + } std::vector results; results.reserve(responses.size()); @@ -44,8 +54,9 @@ std::vector LlmFilter::Operation(duckdb::DataChunk& args) { return results; } +template void LlmFilter::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result) { - const auto results = LlmFilter::Operation(args); + const auto results = LlmFilter::Operation(args, MODE); auto index = 0; for (const auto& res: results) { @@ -53,4 +64,7 @@ void LlmFilter::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, } } +template void LlmFilter::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result); +template void LlmFilter::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result); + }// namespace flockmtl diff --git a/src/functions/scalar/llm_filter/registry.cpp b/src/functions/scalar/llm_filter/registry.cpp index f80a1625..664211b6 100644 --- a/src/functions/scalar/llm_filter/registry.cpp +++ b/src/functions/scalar/llm_filter/registry.cpp @@ -1,13 +1,17 @@ -#include "flockmtl/functions/scalar/llm_filter.hpp" #include "flockmtl/registry/registry.hpp" +#include "flockmtl/functions/scalar/llm_filter.hpp" namespace flockmtl { void ScalarRegistry::RegisterLlmFilter(duckdb::DatabaseInstance& db) { duckdb::ExtensionUtil::RegisterFunction( - db, duckdb::ScalarFunction("llm_filter", - {duckdb::LogicalType::ANY, 
duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, - duckdb::LogicalType::VARCHAR, LlmFilter::Execute)); + db, duckdb::ScalarFunction("llm_filter", + {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, + duckdb::LogicalType::VARCHAR, LlmFilter::Execute)); + duckdb::ExtensionUtil::RegisterFunction( + db, duckdb::ScalarFunction("llm_filter_s", + {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, + duckdb::LogicalType::VARCHAR, LlmFilter::Execute)); } -} // namespace flockmtl +}// namespace flockmtl diff --git a/src/functions/scalar/scalar.cpp b/src/functions/scalar/scalar.cpp index 85b9ea39..abe8eb08 100644 --- a/src/functions/scalar/scalar.cpp +++ b/src/functions/scalar/scalar.cpp @@ -10,13 +10,15 @@ nlohmann::json ScalarFunctionBase::Complete(const nlohmann::json& tuples, const if (function_type == ScalarFunctionType::FILTER) { output_type = OutputType::BOOL; } - auto response = model.CallComplete(prompt, true, output_type); - return response["items"]; + model.AddCompletionRequest(prompt, true, output_type); + auto response = model.CollectCompletions(); + return response[0]["items"]; }; -nlohmann::json ScalarFunctionBase::BatchAndComplete(const std::vector& tuples, - const std::string& user_prompt, - const ScalarFunctionType function_type, Model& model) { +template<> +nlohmann::json ScalarFunctionBase::BatchAndComplete(const std::vector& tuples, + const std::string& user_prompt, + const ScalarFunctionType function_type, Model& model) { const auto llm_template = PromptManager::GetTemplate(function_type); const auto model_details = model.GetModelDetails(); @@ -65,4 +67,55 @@ nlohmann::json ScalarFunctionBase::BatchAndComplete(const std::vector +nlohmann::json ScalarFunctionBase::BatchAndComplete(const std::vector& tuples, + const std::string& user_prompt, + const ScalarFunctionType function_type, Model& model) { + const auto llm_template = PromptManager::GetTemplate(function_type); + const auto model_details 
= model.GetModelDetails(); + auto batch_size = std::min(model.GetModelDetails().batch_size, static_cast(tuples.size())); + if (batch_size <= 0) { + throw std::runtime_error("Batch size must be greater than zero"); + } + + std::vector> all_batches; + int start_index = 0; + // Create all batches first + while (start_index < static_cast(tuples.size())) { + std::vector batch; + for (auto i = 0; i < batch_size && start_index + i < static_cast(tuples.size()); i++) { + batch.push_back(tuples[start_index + i]); + } + all_batches.push_back(batch); + start_index += batch_size; + } + + // For each batch, call Complete (which should AddRequest internally) + for (const auto& batch: all_batches) { + AddCompletionRequest(batch, user_prompt, function_type, model); + } + + // After all requests are queued, send them all and collect results + auto responses = model.CollectCompletions(); + auto results = nlohmann::json::array(); + for (auto& response: responses) { + auto items = response["items"]; + for (const auto& item: items) { + results.push_back(item); + } + } + return results; +} + }// namespace flockmtl diff --git a/src/include/flockmtl/functions/aggregate/llm_first_or_last.hpp b/src/include/flockmtl/functions/aggregate/llm_first_or_last.hpp index 8f0606b6..930ffae2 100644 --- a/src/include/flockmtl/functions/aggregate/llm_first_or_last.hpp +++ b/src/include/flockmtl/functions/aggregate/llm_first_or_last.hpp @@ -12,6 +12,7 @@ class LlmFirstOrLast : public AggregateFunctionBase { explicit LlmFirstOrLast() = default; int GetFirstOrLastTupleId(const nlohmann::json& tuples); + template nlohmann::json Evaluate(nlohmann::json& tuples); public: @@ -33,13 +34,13 @@ class LlmFirstOrLast : public AggregateFunctionBase { static void Destroy(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data, idx_t count) { AggregateFunctionBase::Destroy(states, aggr_input_data, count); } - template + template static void Finalize(duckdb::Vector& states, duckdb::AggregateInputData& 
aggr_input_data, duckdb::Vector& result, idx_t count, idx_t offset) { - FinalizeResults(states, aggr_input_data, result, count, offset, function_type); + FinalizeResults(states, aggr_input_data, result, count, offset, function_type, mode); } static void FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data, - duckdb::Vector& result, idx_t count, idx_t offset, AggregateFunctionType function_type); + duckdb::Vector& result, idx_t count, idx_t offset, AggregateFunctionType function_type, ExecutionMode mode); }; }// namespace flockmtl diff --git a/src/include/flockmtl/functions/aggregate/llm_reduce.hpp b/src/include/flockmtl/functions/aggregate/llm_reduce.hpp index 30ac9e3e..755fc7cc 100644 --- a/src/include/flockmtl/functions/aggregate/llm_reduce.hpp +++ b/src/include/flockmtl/functions/aggregate/llm_reduce.hpp @@ -9,6 +9,7 @@ class LlmReduce : public AggregateFunctionBase { explicit LlmReduce() = default; nlohmann::json ReduceBatch(const nlohmann::json& tuples, const AggregateFunctionType& function_type); + template nlohmann::json ReduceLoop(const std::vector& tuples, const AggregateFunctionType& function_type); public: @@ -32,11 +33,11 @@ class LlmReduce : public AggregateFunctionBase { } static void FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data, duckdb::Vector& result, idx_t count, idx_t offset, - const AggregateFunctionType function_type); - template + const AggregateFunctionType function_type, ExecutionMode mode); + template static void Finalize(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data, duckdb::Vector& result, idx_t count, idx_t offset) { - FinalizeResults(states, aggr_input_data, result, count, offset, function_type); + FinalizeResults(states, aggr_input_data, result, count, offset, function_type, mode); }; }; diff --git a/src/include/flockmtl/functions/scalar/llm_complete.hpp b/src/include/flockmtl/functions/scalar/llm_complete.hpp index 21b65b04..3023e510 100644 --- 
a/src/include/flockmtl/functions/scalar/llm_complete.hpp +++ b/src/include/flockmtl/functions/scalar/llm_complete.hpp @@ -7,8 +7,9 @@ namespace flockmtl { class LlmComplete : public ScalarFunctionBase { public: static void ValidateArguments(duckdb::DataChunk& args); - static std::vector Operation(duckdb::DataChunk& args); + static std::vector Operation(duckdb::DataChunk& args, ExecutionMode mode); + template static void Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result); }; -} // namespace flockmtl +}// namespace flockmtl diff --git a/src/include/flockmtl/functions/scalar/llm_embedding.hpp b/src/include/flockmtl/functions/scalar/llm_embedding.hpp index 24afdb70..1ce0adce 100644 --- a/src/include/flockmtl/functions/scalar/llm_embedding.hpp +++ b/src/include/flockmtl/functions/scalar/llm_embedding.hpp @@ -7,8 +7,8 @@ namespace flockmtl { class LlmEmbedding : public ScalarFunctionBase { public: static void ValidateArguments(duckdb::DataChunk& args); - static std::vector> Operation(duckdb::DataChunk& args); + static std::vector> Operation(duckdb::DataChunk& args, ExecutionMode mode = ExecutionMode::ASYNC); static void Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result); }; -} // namespace flockmtl +}// namespace flockmtl diff --git a/src/include/flockmtl/functions/scalar/llm_filter.hpp b/src/include/flockmtl/functions/scalar/llm_filter.hpp index d528639f..beb2bfd8 100644 --- a/src/include/flockmtl/functions/scalar/llm_filter.hpp +++ b/src/include/flockmtl/functions/scalar/llm_filter.hpp @@ -7,8 +7,9 @@ namespace flockmtl { class LlmFilter : public ScalarFunctionBase { public: static void ValidateArguments(duckdb::DataChunk& args); - static std::vector Operation(duckdb::DataChunk& args); + static std::vector Operation(duckdb::DataChunk& args, ExecutionMode mode); + template static void Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result); }; -} // namespace 
flockmtl +}// namespace flockmtl diff --git a/src/include/flockmtl/functions/scalar/scalar.hpp b/src/include/flockmtl/functions/scalar/scalar.hpp index 9edd5402..5af86f10 100644 --- a/src/include/flockmtl/functions/scalar/scalar.hpp +++ b/src/include/flockmtl/functions/scalar/scalar.hpp @@ -17,9 +17,12 @@ class ScalarFunctionBase { static void ValidateArguments(duckdb::DataChunk& args); static std::vector Operation(duckdb::DataChunk& args); static void Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result); - static nlohmann::json Complete(const nlohmann::json& tuples, const std::string& user_prompt, ScalarFunctionType function_type, Model& model); + + static void AddCompletionRequest(const nlohmann::json& tuples, const std::string& user_prompt, + ScalarFunctionType function_type, Model& model); + template static nlohmann::json BatchAndComplete(const std::vector& tuples, const std::string& user_prompt_name, ScalarFunctionType function_type, Model& model); diff --git a/src/include/flockmtl/model_manager/model.hpp b/src/include/flockmtl/model_manager/model.hpp index 0ce6c6c3..b4d38c01 100644 --- a/src/include/flockmtl/model_manager/model.hpp +++ b/src/include/flockmtl/model_manager/model.hpp @@ -17,13 +17,20 @@ namespace flockmtl { +enum class ExecutionMode { + SYNC, + ASYNC +}; + class Model { public: explicit Model(const nlohmann::json& model_json); explicit Model() = default; - nlohmann::json CallComplete(const std::string& prompt, bool json_response = true, - OutputType output_type = OutputType::STRING); - nlohmann::json CallEmbedding(const std::vector& inputs); + void AddCompletionRequest(const std::string& prompt, bool json_response = true, + OutputType output_type = OutputType::STRING); + void AddEmbeddingRequest(const std::vector& inputs); + std::vector CollectCompletions(const std::string& contentType = "application/json"); + std::vector CollectEmbeddings(const std::string& contentType = "application/json"); ModelDetails 
GetModelDetails(); static void SetMockProvider(const std::shared_ptr& mock_provider) { @@ -33,9 +40,10 @@ class Model { mock_provider_ = nullptr; } -private: std::shared_ptr provider_; + +private: ModelDetails model_details_; inline static std::shared_ptr mock_provider_ = nullptr; void ConstructProvider(); diff --git a/src/include/flockmtl/model_manager/providers/adapters/azure.hpp b/src/include/flockmtl/model_manager/providers/adapters/azure.hpp index c751dbbc..c0377d8d 100644 --- a/src/include/flockmtl/model_manager/providers/adapters/azure.hpp +++ b/src/include/flockmtl/model_manager/providers/adapters/azure.hpp @@ -7,10 +7,14 @@ namespace flockmtl { class AzureProvider : public IProvider { public: - AzureProvider(const ModelDetails& model_details) : IProvider(model_details) {} + AzureProvider(const ModelDetails& model_details) : IProvider(model_details) { + model_handler_ = + std::make_unique(model_details_.secret["api_key"], model_details_.secret["resource_name"], + model_details_.model, model_details_.secret["api_version"], true); + } - nlohmann::json CallComplete(const std::string& prompt, bool json_response, OutputType output_type) override; - nlohmann::json CallEmbedding(const std::vector& inputs) override; + void AddCompletionRequest(const std::string& prompt, bool json_response, OutputType output_type) override; + void AddEmbeddingRequest(const std::vector& inputs) override; }; }// namespace flockmtl diff --git a/src/include/flockmtl/model_manager/providers/adapters/ollama.hpp b/src/include/flockmtl/model_manager/providers/adapters/ollama.hpp index 079229bb..5cce8cff 100644 --- a/src/include/flockmtl/model_manager/providers/adapters/ollama.hpp +++ b/src/include/flockmtl/model_manager/providers/adapters/ollama.hpp @@ -7,10 +7,12 @@ namespace flockmtl { class OllamaProvider : public IProvider { public: - OllamaProvider(const ModelDetails& model_details) : IProvider(model_details) {} + OllamaProvider(const ModelDetails& model_details) : 
IProvider(model_details) { + model_handler_ = std::make_unique(model_details_.secret["api_url"], true); + } - nlohmann::json CallComplete(const std::string& prompt, bool json_response, OutputType output_type) override; - nlohmann::json CallEmbedding(const std::vector& inputs) override; + void AddCompletionRequest(const std::string& prompt, bool json_response, OutputType output_type) override; + void AddEmbeddingRequest(const std::vector& inputs) override; }; }// namespace flockmtl diff --git a/src/include/flockmtl/model_manager/providers/adapters/openai.hpp b/src/include/flockmtl/model_manager/providers/adapters/openai.hpp index fbbb363e..0ebfeee7 100644 --- a/src/include/flockmtl/model_manager/providers/adapters/openai.hpp +++ b/src/include/flockmtl/model_manager/providers/adapters/openai.hpp @@ -7,10 +7,17 @@ namespace flockmtl { class OpenAIProvider : public IProvider { public: - OpenAIProvider(const ModelDetails& model_details) : IProvider(model_details) {} + OpenAIProvider(const ModelDetails& model_details) : IProvider(model_details) { + auto base_url = std::string(""); + if (const auto it = model_details_.secret.find("base_url"); it != model_details_.secret.end()) { + base_url = it->second; + } + model_handler_ = std::make_unique( + model_details_.secret["api_key"], base_url, true); + } - nlohmann::json CallComplete(const std::string& prompt, bool json_response, OutputType output_type) override; - nlohmann::json CallEmbedding(const std::vector& inputs) override; + void AddCompletionRequest(const std::string& prompt, bool json_response, OutputType output_type) override; + void AddEmbeddingRequest(const std::vector& inputs) override; }; }// namespace flockmtl diff --git a/src/include/flockmtl/model_manager/providers/handlers/azure.hpp b/src/include/flockmtl/model_manager/providers/handlers/azure.hpp index 1add800d..b24358ac 100644 --- a/src/include/flockmtl/model_manager/providers/handlers/azure.hpp +++ 
b/src/include/flockmtl/model_manager/providers/handlers/azure.hpp @@ -1,20 +1,14 @@ -#pragma once - -#include "session.hpp" - -#include -#include -#include -#include +#include "flockmtl/model_manager/providers/handlers/base_handler.hpp" namespace flockmtl { -class AzureModelManager { +class AzureModelManager : public BaseModelProviderHandler { public: AzureModelManager(std::string token, std::string resource_name, std::string deployment_model_name, std::string api_version, bool throw_exception) - : _token(token), _resource_name(resource_name), _deployment_model_name(deployment_model_name), - _api_version(api_version), _session("Azure", throw_exception), _throw_exception(throw_exception) { + : BaseModelProviderHandler(throw_exception), + _token(token), _resource_name(resource_name), _deployment_model_name(deployment_model_name), + _api_version(api_version), _session("Azure", throw_exception) { _session.setToken(token, ""); } @@ -23,128 +17,60 @@ class AzureModelManager { AzureModelManager(AzureModelManager&&) = delete; AzureModelManager& operator=(AzureModelManager&&) = delete; - nlohmann::json CallComplete(const nlohmann::json& json, const std::string& contentType = "application/json") { - std::string url = "https://" + _resource_name + ".openai.azure.com/openai/deployments/" + - _deployment_model_name + "/chat/completions?api-version=" + _api_version; - _session.setUrl(url); - return execute_post(json.dump(), contentType); +protected: + void checkProviderSpecificResponse(const nlohmann::json& response, bool is_completion) override { + if (is_completion) { + if (response.contains("choices") && response["choices"].is_array() && !response["choices"].empty()) { + const auto& choice = response["choices"][0]; + if (choice.contains("finish_reason") && !choice["finish_reason"].is_null()) { + std::string finish_reason = choice["finish_reason"].get(); + if (finish_reason != "stop" && finish_reason != "length") { + throw std::runtime_error("Azure API did not finish 
successfully. finish_reason: " + finish_reason); + } + } + } + } else { + // Embedding-specific checks (if any) can be added here + if (response.contains("data") && response["data"].is_array() && response["data"].empty()) { + throw std::runtime_error("Azure API returned empty embedding data."); + } + } } - - nlohmann::json CallEmbedding(const nlohmann::json& json, const std::string& contentType = "application/json") { - std::string url = "https://" + _resource_name + ".openai.azure.com/openai/deployments/" + - _deployment_model_name + "/embeddings?api-version=" + _api_version; + std::string getCompletionUrl() const override { + return "https://" + _resource_name + ".openai.azure.com/openai/deployments/" + + _deployment_model_name + "/chat/completions?api-version=" + _api_version; + } + std::string getEmbedUrl() const override { + return "https://" + _resource_name + ".openai.azure.com/openai/deployments/" + + _deployment_model_name + "/embeddings?api-version=" + _api_version; + } + void prepareSessionForRequest(const std::string& url) override { _session.setUrl(url); - return execute_post(json.dump(), contentType); } - - // I am adding it here since I want to keep provider specific calls - // inside same file - static const char* get_azure_api_key() { - static int check_done = -1; - static const char* api_key = nullptr; - - if (check_done == -1) { - api_key = std::getenv("AZURE_API_KEY"); - check_done = 1; - } - - if (!api_key) { - throw std::runtime_error("AZURE_API_KEY environment variable is not set."); + void setParameters(const std::string& data, const std::string& contentType = "") override { + if (contentType != "multipart/form-data") { + _session.setBody(data); } - - return api_key; } - - static const char* get_azure_resource_name() { - static int check_done = -1; - static const char* rname = nullptr; - - if (check_done == -1) { - rname = std::getenv("AZURE_RESOURCE_NAME"); - check_done = 1; - } - - if (!rname) { - throw 
std::runtime_error("AZURE_RESOURCE_NAME environment variable is not set."); - } - - return rname; + auto postRequest(const std::string& contentType) -> decltype(((Session*) nullptr)->postPrepare(contentType)) override { + return _session.postPrepare(contentType); } - static const char* get_azure_api_version() { - static int check_done = -1; - static const char* api_version = nullptr; - - if (check_done == -1) { - api_version = std::getenv("AZURE_API_VERSION"); - check_done = 1; - } - - if (!api_version) { - throw std::runtime_error("AZURE_VERSION environment variable is not set."); + nlohmann::json ExtractCompletionOutput(const nlohmann::json& response) const override { + if (response.contains("choices") && response["choices"].is_array() && !response["choices"].empty()) { + const auto& choice = response["choices"][0]; + if (choice.contains("message") && choice["message"].contains("content")) { + return choice["message"]["content"].get(); + } } - - return api_version; + return {}; } -private: std::string _token; std::string _resource_name; std::string _deployment_model_name; std::string _api_version; Session _session; - bool _throw_exception; - - nlohmann::json execute_post(const std::string& data, const std::string& contentType) { - setParameters(data, contentType); - auto response = _session.postPrepare(contentType); - if (response.is_error) { - std::cout << ">> response error :\n" << response.text << "\n"; - trigger_error(response.error_message); - } - - nlohmann::json json {}; - if (isJson(response.text)) { - json = nlohmann::json::parse(response.text); - checkResponse(json); - } else { - trigger_error("Response is not a valid JSON"); - } - - return json; - } - - void trigger_error(const std::string& msg) { - if (_throw_exception) { - throw std::runtime_error("[Azure] error. Reason: " + msg); - } else { - std::cerr << "[Azure] error. 
Reason: " << msg << '\n'; - } - } - - void checkResponse(const nlohmann::json& json) { - if (json.contains("error")) { - auto reason = json["error"].dump(); - trigger_error(reason); - std::cerr << ">> response error :\n" << json.dump(2) << "\n"; - } - } - - bool isJson(const std::string& data) { - bool rc = true; - try { - auto json = nlohmann::json::parse(data); // throws if no json - } catch (std::exception&) { - rc = false; - } - return (rc); - } - - void setParameters(const std::string& data, const std::string& contentType = "") { - if (contentType != "multipart/form-data") { - _session.setBody(data); - } - } }; -} // namespace flockmtl \ No newline at end of file +}// namespace flockmtl diff --git a/src/include/flockmtl/model_manager/providers/handlers/base_handler.hpp b/src/include/flockmtl/model_manager/providers/handlers/base_handler.hpp new file mode 100644 index 00000000..c384171b --- /dev/null +++ b/src/include/flockmtl/model_manager/providers/handlers/base_handler.hpp @@ -0,0 +1,145 @@ +#pragma once + +#include "flockmtl/model_manager/providers/handlers/handler.hpp" +#include "session.hpp" +#include +#include +#include +#include +#include + +namespace flockmtl { + +class BaseModelProviderHandler : public IModelProviderHandler { +public: + explicit BaseModelProviderHandler(bool throw_exception) + : _throw_exception(throw_exception) {} + virtual ~BaseModelProviderHandler() = default; + + // AddRequest: just add the json to the batch (type is ignored, kept for interface compatibility) + void AddRequest(const nlohmann::json& json, RequestType type = RequestType::Completion) { + _request_batch.push_back(json); + } + + // CollectCompletions: process all as completions, then clear + std::vector CollectCompletions(const std::string& contentType = "application/json") { + std::vector completions; + if (!_request_batch.empty()) completions = ExecuteBatch(_request_batch, true, contentType, true); + _request_batch.clear(); + return completions; + } + + // 
CollectEmbeddings: process all as embeddings, then clear + std::vector CollectEmbeddings(const std::string& contentType = "application/json") { + std::vector embeddings; + if (!_request_batch.empty()) embeddings = ExecuteBatch(_request_batch, true, contentType, false); + _request_batch.clear(); + return embeddings; + } + + // Unified batch implementation with customizable headers + std::vector ExecuteBatch(const std::vector& jsons, bool async = true, const std::string& contentType = "application/json", bool is_completion = true) { + struct CurlRequestData { + std::string response; + CURL* easy = nullptr; + std::string payload; + }; + std::vector requests(jsons.size()); + CURLM* multi_handle = curl_multi_init(); + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + for (const auto& h: getExtraHeaders()) { + headers = curl_slist_append(headers, h.c_str()); + } + std::string url = is_completion ? getCompletionUrl() : getEmbedUrl(); + for (size_t i = 0; i < jsons.size(); ++i) { + requests[i].payload = jsons[i].dump(); + requests[i].easy = curl_easy_init(); + curl_easy_setopt(requests[i].easy, CURLOPT_URL, url.c_str()); + curl_easy_setopt(requests[i].easy, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(requests[i].easy, CURLOPT_WRITEFUNCTION, +[](char* ptr, size_t size, size_t nmemb, void* userdata) -> size_t { + std::string* resp = static_cast(userdata); + resp->append(ptr, size * nmemb); + return size * nmemb; }); + curl_easy_setopt(requests[i].easy, CURLOPT_WRITEDATA, &requests[i].response); + curl_easy_setopt(requests[i].easy, CURLOPT_POST, 1L); + curl_easy_setopt(requests[i].easy, CURLOPT_POSTFIELDS, requests[i].payload.c_str()); + curl_multi_add_handle(multi_handle, requests[i].easy); + } + int still_running = 0; + curl_multi_perform(multi_handle, &still_running); + while (still_running) { + int numfds; + curl_multi_wait(multi_handle, NULL, 0, 1000, &numfds); + curl_multi_perform(multi_handle, 
&still_running); + } + std::vector results(jsons.size()); + for (size_t i = 0; i < requests.size(); ++i) { + curl_easy_getinfo(requests[i].easy, CURLINFO_RESPONSE_CODE, NULL); + if (!requests[i].response.empty() && isJson(requests[i].response)) { + try { + nlohmann::json parsed = nlohmann::json::parse(requests[i].response); + checkResponse(parsed, is_completion); + if (is_completion) { + results[i] = ExtractCompletionOutput(parsed); + } else { + results[i] = ExtractEmbeddingVector(parsed); + } + } catch (const std::exception& e) { + trigger_error(std::string("JSON parse error: ") + e.what()); + } + } else { + trigger_error("Empty or invalid response in batch"); + } + curl_multi_remove_handle(multi_handle, requests[i].easy); + curl_easy_cleanup(requests[i].easy); + } + curl_slist_free_all(headers); + curl_multi_cleanup(multi_handle); + return results; + } + + virtual void setParameters(const std::string& data, const std::string& contentType = "") = 0; + virtual auto postRequest(const std::string& contentType) -> decltype(((Session*) nullptr)->postPrepare(contentType)) = 0; + +protected: + bool _throw_exception; + std::vector _request_batch; + + virtual std::string getCompletionUrl() const = 0; + virtual std::string getEmbedUrl() const = 0; + virtual void prepareSessionForRequest(const std::string& url) = 0; + virtual std::vector getExtraHeaders() const { return {}; } + virtual void checkProviderSpecificResponse(const nlohmann::json&, bool is_completion) {} + virtual nlohmann::json ExtractCompletionOutput(const nlohmann::json&) const { return {}; } + virtual nlohmann::json ExtractEmbeddingVector(const nlohmann::json&) const { return {}; } + + void trigger_error(const std::string& msg) { + if (_throw_exception) { + throw std::runtime_error("[ModelProvider] error. Reason: " + msg); + } else { + std::cerr << "[ModelProvider] error. 
Reason: " << msg << '\n'; + } + } + + void checkResponse(const nlohmann::json& json, bool is_completion) { + if (json.contains("error")) { + auto reason = json["error"].dump(); + trigger_error(reason); + std::cerr << ">> response error :\n" + << json.dump(2) << "\n"; + } + checkProviderSpecificResponse(json, is_completion); + } + + bool isJson(const std::string& data) { + try { + nlohmann::json::parse(data); + } catch (...) { + return false; + } + return true; + } +}; + +}// namespace flockmtl diff --git a/src/include/flockmtl/model_manager/providers/handlers/handler.hpp b/src/include/flockmtl/model_manager/providers/handlers/handler.hpp new file mode 100644 index 00000000..c97b9b91 --- /dev/null +++ b/src/include/flockmtl/model_manager/providers/handlers/handler.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include + +namespace flockmtl { + +class IModelProviderHandler { +public: + enum class RequestType { Completion, + Embedding }; + + virtual ~IModelProviderHandler() = default; + // AddRequest: type distinguishes between completion and embedding (default: Completion) + virtual void AddRequest(const nlohmann::json& json, RequestType type = RequestType::Completion) = 0; + + // CollectCompletions: process all as completions, then clear + virtual std::vector CollectCompletions(const std::string& contentType = "application/json") = 0; + // CollectEmbeddings: process all as embeddings, then clear + virtual std::vector CollectEmbeddings(const std::string& contentType = "application/json") = 0; +}; + +}// namespace flockmtl diff --git a/src/include/flockmtl/model_manager/providers/handlers/ollama.hpp b/src/include/flockmtl/model_manager/providers/handlers/ollama.hpp index 689d8c1e..1352f626 100644 --- a/src/include/flockmtl/model_manager/providers/handlers/ollama.hpp +++ b/src/include/flockmtl/model_manager/providers/handlers/ollama.hpp @@ -1,126 +1,69 @@ -#ifndef _FLOCK_MTL_MODEL_MANAGER_OLLAMA_H -#define _FLOCK_MTL_MODEL_MANAGER_OLLAMA_H +#pragma once +#include 
"flockmtl/model_manager/providers/handlers/base_handler.hpp" #include "session.hpp" - +#include +#include #include #include #include #include +#include namespace flockmtl { -class OllamaModelManager { +class OllamaModelManager : public BaseModelProviderHandler { public: OllamaModelManager(const std::string& url, const bool throw_exception) - : _session("Ollama", throw_exception), _throw_exception(throw_exception), _url(url) {} + : BaseModelProviderHandler(throw_exception), _session("Ollama", throw_exception), _url(url) {} + OllamaModelManager(const OllamaModelManager&) = delete; OllamaModelManager& operator=(const OllamaModelManager&) = delete; OllamaModelManager(OllamaModelManager&&) = delete; OllamaModelManager& operator=(OllamaModelManager&&) = delete; - std::string GetChatUrl() const { return _url + "/api/generate"; } - - std::string GetEmbedUrl() const { return _url + "/api/embeddings"; } - - std::string GetAvailableOllamaModelsUrl() { - static int check_done = -1; - static const char* url = nullptr; - - if (check_done == -1) { - url = std::getenv("OLLAMA_AVAILABLE_MODELS_URL"); - check_done = 1; - } - - if (!url) { - throw std::runtime_error("OLLAMA_AVAILABLE_MODELS_URL environment variable is not set."); +protected: + std::string getCompletionUrl() const override { return _url + "/api/generate"; } + std::string getEmbedUrl() const override { return _url + "/api/embeddings"; } + void prepareSessionForRequest(const std::string& url) override { _session.setUrl(url); } + void setParameters(const std::string& data, const std::string& contentType = "") override { + if (contentType != "multipart/form-data") { + _session.setBody(data); } - - return url; } - - nlohmann::json CallComplete(const nlohmann::json& json, const std::string& contentType = "application/json") { - const std::string url = GetChatUrl(); - _session.setUrl(url); - return execute_post(json.dump(), contentType); + auto postRequest(const std::string& contentType) -> decltype(((Session*) 
nullptr)->postPrepareOllama(contentType)) override { + return _session.postPrepareOllama(contentType); } - - nlohmann::json CallEmbedding(const nlohmann::json& json, const std::string& contentType = "application/json") { - const std::string url = GetEmbedUrl(); - _session.setUrl(url); - return execute_post(json.dump(), contentType); - } - - bool validModel(const std::string& user_model_name) { - std::string url = GetAvailableOllamaModelsUrl(); - auto response = _session.validOllamaModelsJson(url); - auto json = nlohmann::json::parse(response.text); - bool res = false; - for (const auto& model : json["models"]) { - if (model.contains("name")) { - const auto& available_model = model["name"].get(); - res |= available_model.find(user_model_name) != std::string::npos; + void checkProviderSpecificResponse(const nlohmann::json& response, bool is_completion) override { + if (is_completion) { + if ((response.contains("done_reason") && response["done_reason"] != "stop") || + (response.contains("done") && !response["done"].is_null() && response["done"].get() != true)) { + throw std::runtime_error("The request was refused due to some internal error with Ollama API"); } - } - return res; - } - -private: - Session _session; - bool _throw_exception; - std::string _url; - - nlohmann::json execute_post(const std::string& data, const std::string& contentType) { - setParameters(data, contentType); - auto response = _session.postPrepareOllama(contentType); - if (response.is_error) { - trigger_error(response.error_message); - } - - nlohmann::json json {}; - if (isJson(response.text)) { - - json = nlohmann::json::parse(response.text); - checkResponse(json); } else { - trigger_error("Response is not a valid JSON"); - } - - return json; - } - - void trigger_error(const std::string& msg) { - if (_throw_exception) { - throw std::runtime_error(msg); - } else { - std::cerr << "[Ollama] error. 
Reason: " << msg << '\n'; + // Embedding-specific checks (if any) can be added here + if (response.contains("embedding") && (!response["embedding"].is_array() || response["embedding"].empty())) { + throw std::runtime_error("Ollama API returned empty or invalid embedding data."); + } } } - void checkResponse(const nlohmann::json& json) { - if (json.count("error")) { - auto reason = json["error"].dump(); - trigger_error(reason); - std::cerr << ">> response error :\n" << json.dump(2) << "\n"; + nlohmann::json ExtractCompletionOutput(const nlohmann::json& response) const override { + if (response.contains("response")) { + return nlohmann::json::parse(response["response"].get()); } + return {}; } - bool isJson(const std::string& data) { - bool rc = true; - try { - auto json = nlohmann::json::parse(data); // throws if no json - } catch (std::exception&) { - rc = false; + nlohmann::json ExtractEmbeddingVector(const nlohmann::json& response) const override { + if (response.contains("embedding") && response["embedding"].is_array()) { + return response["embedding"]; } - return (rc); + return {}; } - void setParameters(const std::string& data, const std::string& contentType = "") { - if (contentType != "multipart/form-data") { - _session.setBody(data); - } - } + Session _session; + std::string _url; }; -} // namespace flockmtl -#endif \ No newline at end of file +}// namespace flockmtl diff --git a/src/include/flockmtl/model_manager/providers/handlers/openai.hpp b/src/include/flockmtl/model_manager/providers/handlers/openai.hpp index 524721d3..ed101d72 100644 --- a/src/include/flockmtl/model_manager/providers/handlers/openai.hpp +++ b/src/include/flockmtl/model_manager/providers/handlers/openai.hpp @@ -1,871 +1,96 @@ -// The MIT License (MIT) -// -// Copyright (c) 2023 Olrea, Florian Dang -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software 
without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -#ifndef OPENAI_HPP_ -#define OPENAI_HPP_ - -#if OPENAI_VERBOSE_OUTPUT -#pragma message("OPENAI_VERBOSE_OUTPUT is ON") -#endif +#pragma once +#include "flockmtl/model_manager/providers/handlers/base_handler.hpp" +#include "session.hpp" #include #include -#include -#include +#include #include #include -#include "session.hpp" -#include // nlohmann/json - -namespace openai { - -namespace _detail { - -// Json alias -using Json = nlohmann::json; - -// forward declaration for category structures -class OpenAI; - -// https://platform.openai.com/docs/api-reference/models -// List and describe the various models available in the API. You can refer to -// the Models documentation to understand what models are available and the -// differences between them. -struct CategoryModel { - Json list(); - Json retrieve(const std::string &model); - - CategoryModel(OpenAI &openai) : openai_ {openai} {} - -private: - OpenAI &openai_; -}; - -// https://platform.openai.com/docs/api-reference/assistants -// Build assistants that can call models and use tools to perform tasks. 
-struct CategoryAssistants { - Json create(Json input); - Json retrieve(const std::string &assistants); - Json modify(const std::string &assistants, Json input); - Json del(const std::string &assistants); - Json list(); - Json createFile(const std::string &assistants, Json input); - Json retrieveFile(const std::string &assistants, const std::string &files); - Json delFile(const std::string &assistants, const std::string &files); - Json listFile(const std::string &assistants); - - CategoryAssistants(OpenAI &openai) : openai_ {openai} {} - -private: - OpenAI &openai_; -}; - -// https://platform.openai.com/docs/api-reference/threads -// Create threads that assistants can interact with. -struct CategoryThreads { - Json create(); - Json retrieve(const std::string &threads); - Json modify(const std::string &threads, Json input); - Json del(const std::string &threads); - Json list(); - - // https://platform.openai.com/docs/api-reference/messages - // Create messages within threads - Json createMessage(const std::string &threads, Json input); - Json retrieveMessage(const std::string &threads, const std::string &messages); - Json modifyMessage(const std::string &threads, const std::string &messages, Json input); - Json listMessage(const std::string &threads); - Json retrieveMessageFile(const std::string &threads, const std::string &messages, const std::string &files); - Json listMessageFile(const std::string &threads, const std::string &messages); - - // https://platform.openai.com/docs/api-reference/runs - // Represents an execution run on a thread. 
- Json createRun(const std::string &threads, Json input); - Json retrieveRun(const std::string &threads, const std::string &runs); - Json modifyRun(const std::string &threads, const std::string &runs, Json input); - Json listRun(const std::string &threads); - Json submitToolOutputsToRun(const std::string &threads, const std::string &runs, Json input); - Json cancelRun(const std::string &threads, const std::string &runs); - Json createThreadAndRun(Json input); - Json retrieveRunStep(const std::string &threads, const std::string &runs, const std::string &steps); - Json listRunStep(const std::string &threads, const std::string &runs); - - CategoryThreads(OpenAI &openai) : openai_ {openai} {} - -private: - OpenAI &openai_; -}; - -// https://platform.openai.com/docs/api-reference/completions -// Given a prompt, the model will return one or more predicted completions, and -// can also return the probabilities of alternative tokens at each position. -struct CategoryCompletion { - Json create(Json input); - - CategoryCompletion(OpenAI &openai) : openai_ {openai} {} - -private: - OpenAI &openai_; -}; - -// https://platform.openai.com/docs/api-reference/chat -// Given a prompt, the model will return one or more predicted chat completions. -struct CategoryChat { - Json create(Json input); - - CategoryChat(OpenAI &openai) : openai_ {openai} {} - -private: - OpenAI &openai_; -}; - -// https://platform.openai.com/docs/api-reference/audio -// Learn how to turn audio into text. -struct CategoryAudio { - Json transcribe(Json input); - Json translate(Json input); - - CategoryAudio(OpenAI &openai) : openai_ {openai} {} - -private: - OpenAI &openai_; -}; - -// https://platform.openai.com/docs/api-reference/edits -// Given a prompt and an instruction, the model will return an edited version of -// the prompt. 
-struct CategoryEdit { - Json create(Json input); - - CategoryEdit(OpenAI &openai) : openai_ {openai} {} - -private: - OpenAI &openai_; -}; - -// https://platform.openai.com/docs/api-reference/images -// Given a prompt and/or an input image, the model will generate a new image. -struct CategoryImage { - Json create(Json input); - Json edit(Json input); - Json variation(Json input); - - CategoryImage(OpenAI &openai) : openai_ {openai} {} - -private: - OpenAI &openai_; -}; - -// https://platform.openai.com/docs/api-reference/embeddings -// Get a vector representation of a given input that can be easily consumed by -// machine learning models and algorithms. -struct CategoryEmbedding { - Json create(Json input); - CategoryEmbedding(OpenAI &openai) : openai_ {openai} {} - -private: - OpenAI &openai_; -}; - -struct FileRequest { - std::string file; - std::string purpose; -}; - -// https://platform.openai.com/docs/api-reference/files -// Files are used to upload documents that can be used with features like -// Fine-tuning. -struct CategoryFile { - Json list(); - Json upload(Json input); - Json del(const std::string &file); // TODO - Json retrieve(const std::string &file_id); - Json content(const std::string &file_id); - - CategoryFile(OpenAI &openai) : openai_ {openai} {} - -private: - OpenAI &openai_; -}; - -// https://platform.openai.com/docs/api-reference/fine-tunes -// Manage fine-tuning jobs to tailor a model to your specific training data. 
-struct CategoryFineTune { - Json create(Json input); - Json list(); - Json retrieve(const std::string &fine_tune_id); - Json content(const std::string &fine_tune_id); - Json cancel(const std::string &fine_tune_id); - Json events(const std::string &fine_tune_id); - Json del(const std::string &model); - - CategoryFineTune(OpenAI &openai) : openai_ {openai} {} - -private: - OpenAI &openai_; -}; - -// https://platform.openai.com/docs/api-reference/moderations -// Given a input text, outputs if the model classifies it as violating OpenAI's -// content policy. -struct CategoryModeration { - Json create(Json input); - CategoryModeration(OpenAI &openai) : openai_ {openai} {} - -private: - OpenAI &openai_; -}; +namespace flockmtl { -// OpenAI -class OpenAI { +class OpenAIModelManager : public BaseModelProviderHandler { public: - OpenAI(const std::string &token = "", const std::string &organization = "", bool throw_exception = true, - const std::string &api_base_url = "", const std::string &beta = "") - : session_ {"OpenAI", throw_exception}, token_ {token}, organization_ {organization}, - throw_exception_ {throw_exception} { - if (token.empty()) { - if (const char *env_p = std::getenv("OPENAI_API_KEY")) { - token_ = std::string {env_p}; - } - } + OpenAIModelManager(std::string token, std::string api_base_url, bool throw_exception) + : BaseModelProviderHandler(throw_exception), _token(token), _session("OpenAI", throw_exception) { + _session.setToken(token, ""); if (api_base_url.empty()) { - if (const char *env_p = std::getenv("OPENAI_API_BASE")) { - base_url = std::string {env_p} + "/"; - } else { - base_url = "https://api.openai.com/v1/"; - } + _api_base_url = "https://api.openai.com/v1/"; } else { - base_url = api_base_url; + _api_base_url = api_base_url; } - session_.setUrl(base_url); - session_.setToken(token_, organization_); - session_.setBeta(beta); + _session.setUrl(_api_base_url); } - OpenAI(const OpenAI &) = delete; - OpenAI &operator=(const OpenAI &) = delete; - 
OpenAI(OpenAI &&) = delete; - OpenAI &operator=(OpenAI &&) = delete; - - void setToken(const std::string &token = "", const std::string &organization = "") { - session_.setToken(token, organization); - } - - static const char *get_openai_api_key() { - static int check_done = -1; - static const char *key = nullptr; - - if (check_done == -1) { - key = std::getenv("OPENAI_API_KEY"); - check_done = 1; - } - - if (key == nullptr) { - throw std::runtime_error("OPENAI_API_KEY environment variable is not set."); - } - - return key; - } - - void setProxy(const std::string &url) { session_.setProxyUrl(url); } - - void setBeta(const std::string &beta) { session_.setBeta(beta); } - - // void change_token(const std::string& token) { token_ = token; }; - void setThrowException(bool throw_exception) { throw_exception_ = throw_exception; } - - void setMultiformPart(const std::pair &filefield_and_filepath, - const std::map &fields) { - session_.setMultiformPart(filefield_and_filepath, fields); - } - - Json post(const std::string &suffix, const std::string &data, const std::string &contentType) { - setParameters(suffix, data, contentType); - auto response = session_.postPrepare(contentType); - if (response.is_error) { - trigger_error(response.error_message); - } - - Json json {}; - if (isJson(response.text)) { - - json = Json::parse(response.text); - checkResponse(json); - } else { -#if OPENAI_VERBOSE_OUTPUT - std::cerr << "Response is not a valid JSON"; - std::cout << "<< " << response.text << "\n"; -#endif - } - - return json; - } + OpenAIModelManager(const OpenAIModelManager&) = delete; + OpenAIModelManager& operator=(const OpenAIModelManager&) = delete; + OpenAIModelManager(OpenAIModelManager&&) = delete; + OpenAIModelManager& operator=(OpenAIModelManager&&) = delete; - Json get(const std::string &suffix, const std::string &data = "") { - setParameters(suffix, data); - auto response = session_.getPrepare(); - if (response.is_error) { - trigger_error(response.error_message); - } 
+protected: + std::string _token; + std::string _api_base_url; + Session _session; - Json json {}; - if (isJson(response.text)) { - json = Json::parse(response.text); - checkResponse(json); - } else { -#if OPENAI_VERBOSE_OUTPUT - std::cerr << "Response is not a valid JSON\n"; - std::cout << "<< " << response.text << "\n"; -#endif - json = Json {{"Result", response.text}}; - } - return json; + std::string getCompletionUrl() const override { + return _api_base_url + "chat/completions"; } - - Json post(const std::string &suffix, const Json &json, const std::string &contentType = "application/json") { - return post(suffix, json.dump(), contentType); + std::string getEmbedUrl() const override { + return _api_base_url + "embeddings"; } - - Json del(const std::string &suffix) { - setParameters(suffix, ""); - auto response = session_.deletePrepare(); - if (response.is_error) { - trigger_error(response.error_message); - } - - Json json {}; - if (isJson(response.text)) { - json = Json::parse(response.text); - checkResponse(json); - } else { -#if OPENAI_VERBOSE_OUTPUT - std::cerr << "Response is not a valid JSON\n"; - std::cout << "<< " << response.text << "\n"; -#endif - } - return json; + void prepareSessionForRequest(const std::string& url) override { + _session.setUrl(url); } - - std::string easyEscape(const std::string &text) { return session_.easyEscape(text); } - - void debug() const { std::cout << token_ << '\n'; } - - void setBaseUrl(const std::string &url) { base_url = url; } - - std::string getBaseUrl() const { return base_url; } - -private: - std::string base_url; - - void setParameters(const std::string &suffix, const std::string &data, const std::string &contentType = "") { - auto complete_url = base_url + suffix; - session_.setUrl(complete_url); - + void setParameters(const std::string& data, const std::string& contentType = "") override { if (contentType != "multipart/form-data") { - session_.setBody(data); + _session.setBody(data); } - -#if 
OPENAI_VERBOSE_OUTPUT - std::cout << "<< request: " << complete_url << " " << data << '\n'; -#endif } - - void checkResponse(const Json &json) { - if (json.count("error")) { - auto reason = json["error"].dump(); - trigger_error(reason); - -#if OPENAI_VERBOSE_OUTPUT - std::cerr << ">> response error :\n" << json.dump(2) << "\n"; -#endif - } + auto postRequest(const std::string& contentType) -> decltype(((Session*) nullptr)->postPrepare(contentType)) override { + return _session.postPrepare(contentType); } - - // as of now the only way - bool isJson(const std::string &data) { - bool rc = true; - try { - auto json = Json::parse(data); // throws if no json - } catch (std::exception &) { - rc = false; - } - return (rc); + std::vector getExtraHeaders() const override { + return {"Authorization: Bearer " + _token}; } - - void trigger_error(const std::string &msg) { - if (throw_exception_) { - throw std::runtime_error(msg); + void checkProviderSpecificResponse(const nlohmann::json& response, bool is_completion) override { + if (is_completion) { + if (response.contains("choices") && response["choices"].is_array() && !response["choices"].empty()) { + const auto& choice = response["choices"][0]; + if (choice.contains("finish_reason") && !choice["finish_reason"].is_null()) { + std::string finish_reason = choice["finish_reason"].get(); + if (finish_reason != "stop" && finish_reason != "length") { + throw std::runtime_error("OpenAI API did not finish successfully. finish_reason: " + finish_reason); + } + } + } } else { - std::cerr << "[OpenAI] error. 
Reason: " << msg << '\n'; + // Embedding-specific checks (if any) can be added here + if (response.contains("data") && response["data"].is_array() && response["data"].empty()) { + throw std::runtime_error("OpenAI API returned empty embedding data."); + } } } - -public: - CategoryModel model {*this}; - CategoryAssistants assistant {*this}; - CategoryThreads thread {*this}; - CategoryCompletion completion {*this}; - CategoryEdit edit {*this}; - CategoryImage image {*this}; - CategoryEmbedding embedding {*this}; - CategoryFile file {*this}; - CategoryFineTune fine_tune {*this}; - CategoryModeration moderation {*this}; - CategoryChat chat {*this}; - CategoryAudio audio {*this}; - // CategoryEngine engine{*this}; // Not handled since deprecated (use - // Model instead) - -private: - Session session_; - std::string token_; - std::string organization_; - bool throw_exception_; -}; - -inline std::string bool_to_string(const bool b) { - std::ostringstream ss; - ss << std::boolalpha << b; - return ss.str(); -} - -inline OpenAI &start(const std::string &token = "", const std::string &organization = "", bool throw_exception = true, - const std::string &api_base_url = "") { - static OpenAI instance {token, organization, throw_exception, api_base_url}; - return instance; -} - -inline OpenAI &instance() { return start(); } - -inline Json post(const std::string &suffix, const Json &json) { return instance().post(suffix, json); } - -inline Json get(const std::string &suffix /*, const Json& json*/) { return instance().get(suffix); } - -// Helper functions to get category structures instance() - -inline CategoryModel &model() { return instance().model; } - -inline CategoryAssistants &assistant() { return instance().assistant; } - -inline CategoryThreads &thread() { return instance().thread; } - -inline CategoryCompletion &completion() { return instance().completion; } - -inline CategoryChat &chat() { return instance().chat; } - -inline CategoryAudio &audio() { return 
instance().audio; } - -inline CategoryEdit &edit() { return instance().edit; } - -inline CategoryImage &image() { return instance().image; } - -inline CategoryEmbedding &embedding() { return instance().embedding; } - -inline CategoryFile &file() { return instance().file; } - -inline CategoryFineTune &fineTune() { return instance().fine_tune; } - -inline CategoryModeration &moderation() { return instance().moderation; } - -// Definitions of category methods - -// GET https://api.openai.com/v1/models -// Lists the currently available models, and provides basic information about -// each one such as the owner and availability. -inline Json CategoryModel::list() { return openai_.get("models"); } - -// GET https://api.openai.com/v1/models/{model} -// Retrieves a model instance, providing basic information about the model such -// as the owner and permissioning. -inline Json CategoryModel::retrieve(const std::string &model) { return openai_.get("models/" + model); } - -// POST https://api.openai.com/v1/assistants -// Create an assistant with a model and instructions. -inline Json CategoryAssistants::create(Json input) { return openai_.post("assistants", input); } - -// GET https://api.openai.com/v1/assistants/{assistant_id} -// Retrieves an assistant. -inline Json CategoryAssistants::retrieve(const std::string &assistants) { - return openai_.get("assistants/" + assistants); -} - -// POST https://api.openai.com/v1/assistants/{assistant_id} -// Modifies an assistant. -inline Json CategoryAssistants::modify(const std::string &assistants, Json input) { - return openai_.post("assistants/" + assistants, input); -} - -// DELETE https://api.openai.com/v1/assistants/{assistant_id} -// Delete an assistant. -inline Json CategoryAssistants::del(const std::string &assistants) { return openai_.del("assistants/" + assistants); } - -// GET https://api.openai.com/v1/assistants -// Returns a list of assistants. 
-inline Json CategoryAssistants::list() { return openai_.get("assistants"); } - -// POST https://api.openai.com/v1/assistants/{assistant_id}/files -// Create an assistant file by attaching a File to an assistant. -inline Json CategoryAssistants::createFile(const std::string &assistants, Json input) { - return openai_.post("assistants/" + assistants + "/files", input); -} - -// GET https://api.openai.com/v1/assistants/{assistant_id}/files/{file_id} -// Retrieves an AssistantFile. -inline Json CategoryAssistants::retrieveFile(const std::string &assistants, const std::string &files) { - return openai_.get("assistants/" + assistants + "/files/" + files); -} - -// DELETE https://api.openai.com/v1/assistants/{assistant_id}/files/{file_id} -// Delete an assistant file. -inline Json CategoryAssistants::delFile(const std::string &assistants, const std::string &files) { - return openai_.del("assistants/" + assistants + "/files/" + files); -} - -// GET https://api.openai.com/v1/assistants/{assistant_id}/files -// Returns a list of assistant files. -inline Json CategoryAssistants::listFile(const std::string &assistants) { - return openai_.get("assistants/" + assistants + "/files"); -} - -// POST https://api.openai.com/v1/threads -// Create a thread. -inline Json CategoryThreads::create() { - Json input; - return openai_.post("threads", input); -} - -// GET https://api.openai.com/v1/threads/{thread_id} -// Retrieves a thread. -inline Json CategoryThreads::retrieve(const std::string &threads) { return openai_.get("threads/" + threads); } - -// POST https://api.openai.com/v1/threads/{thread_id} -// Modifies a thread. -inline Json CategoryThreads::modify(const std::string &threads, Json input) { - return openai_.post("threads/" + threads, input); -} - -// DELETE https://api.openai.com/v1/threads/{thread_id} -// Delete a thread. 
-inline Json CategoryThreads::del(const std::string &threads) { return openai_.del("threads/" + threads); } - -// POST https://api.openai.com/v1/threads/{thread_id}/messages -// Create a message. -inline Json CategoryThreads::createMessage(const std::string &threads, Json input) { - return openai_.post("threads/" + threads + "/messages", input); -} - -// GET https://api.openai.com/v1/threads/{thread_id}/messages/{message_id} -// Retrieve a message. -inline Json CategoryThreads::retrieveMessage(const std::string &threads, const std::string &messages) { - return openai_.get("threads/" + threads + "/messages/" + messages); -} - -// POST https://api.openai.com/v1/threads/{thread_id}/messages/{message_id} -// Modifies a message. -inline Json CategoryThreads::modifyMessage(const std::string &threads, const std::string &messages, Json input) { - return openai_.post("threads/" + threads + "/messages/" + messages, input); -} - -// GET https://api.openai.com/v1/threads/{thread_id}/messages -// Returns a list of messages for a given thread. -inline Json CategoryThreads::listMessage(const std::string &threads) { - return openai_.get("threads/" + threads + "/messages"); -} - -// GET -// https://api.openai.com/v1/threads/{thread_id}/messages/{message_id}/files/{file_id} -// Retrieves a message file. -inline Json CategoryThreads::retrieveMessageFile(const std::string &threads, const std::string &messages, - const std::string &files) { - return openai_.get("threads/" + threads + "/messages/" + messages + "/files/" + files); -} - -// GET https://api.openai.com/v1/threads/{thread_id}/messages/{message_id}/files -// Returns a list of message files. -inline Json CategoryThreads::listMessageFile(const std::string &threads, const std::string &messages) { - return openai_.get("threads/" + threads + "/messages/" + messages + "/files"); -} - -// POST https://api.openai.com/v1/threads/{thread_id}/runs -// Create a run. 
-inline Json CategoryThreads::createRun(const std::string &threads, Json input) { - return openai_.post("threads/" + threads + "/runs", input); -} - -// GET https://api.openai.com/v1/threads/{thread_id}/runs/{run_id} -// Retrieves a run. -inline Json CategoryThreads::retrieveRun(const std::string &threads, const std::string &runs) { - return openai_.get("threads/" + threads + "/runs/" + runs); -} - -// POST https://api.openai.com/v1/threads/{thread_id}/runs/{run_id} -// Modifies a run. -inline Json CategoryThreads::modifyRun(const std::string &threads, const std::string &runs, Json input) { - return openai_.post("threads/" + threads + "/runs/" + runs, input); -} - -// GET https://api.openai.com/v1/threads/{thread_id}/runs -// Returns a list of runs belonging to a thread. -inline Json CategoryThreads::listRun(const std::string &threads) { return openai_.get("threads/" + threads + "/runs"); } - -// POST -// https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs -// When a run has the status: "requires_action" and required_action.type is -// submit_tool_outputs, this endpoint can be used to submit the outputs from the -// tool calls once they're all completed. All outputs must be submitted in a -// single request. -inline Json CategoryThreads::submitToolOutputsToRun(const std::string &threads, const std::string &runs, Json input) { - return openai_.post("threads/" + threads + "/runs/" + runs + "submit_tool_outputs", input); -} - -// POST https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}/cancel -// Cancels a run that is in_progress. -inline Json CategoryThreads::cancelRun(const std::string &threads, const std::string &runs) { - Json input; - return openai_.post("threads/" + threads + "/runs/" + runs + "/cancel", input); -} - -// POST https://api.openai.com/v1/threads/runs -// Create a thread and run it in one request. 
-inline Json CategoryThreads::createThreadAndRun(Json input) { return openai_.post("threads/runs", input); } - -// GET -// https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id} -// Retrieves a run step. -inline Json CategoryThreads::retrieveRunStep(const std::string &threads, const std::string &runs, - const std::string &steps) { - return openai_.get("threads/" + threads + "/runs/" + runs + "/steps/" + steps); -} - -// GET https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}/steps -// Returns a list of run steps belonging to a run. -inline Json CategoryThreads::listRunStep(const std::string &threads, const std::string &runs) { - return openai_.get("threads/" + threads + "/runs/" + runs + "/steps"); -} - -// POST https://api.openai.com/v1/completions -// Creates a completion for the provided prompt and parameters -inline Json CategoryCompletion::create(Json input) { return openai_.post("completions", input); } - -// POST https://api.openai.com/v1/chat/completions -// Creates a chat completion for the provided prompt and parameters -inline Json CategoryChat::create(Json input) { return openai_.post("chat/completions", input); } - -// POST https://api.openai.com/v1/audio/transcriptions -// Transcribes audio into the input language. 
-inline Json CategoryAudio::transcribe(Json input) { - auto lambda = [input]() -> std::map { - std::map temp; - temp.insert({"model", input["model"].get()}); - if (input.contains("language")) { - temp.insert({"language", input["language"].get()}); - } - if (input.contains("prompt")) { - temp.insert({"prompt", input["prompt"].get()}); - } - if (input.contains("response_format")) { - temp.insert({"response_format", input["response_format"].get()}); - } - if (input.contains("temperature")) { - temp.insert({"temperature", std::to_string(input["temperature"].get())}); - } - return temp; - }; - openai_.setMultiformPart({"file", input["file"].get()}, lambda()); - - return openai_.post("audio/transcriptions", std::string {""}, "multipart/form-data"); -} - -// POST https://api.openai.com/v1/audio/translations -// Translates audio into into English.. -inline Json CategoryAudio::translate(Json input) { - auto lambda = [input]() -> std::map { - std::map temp; - temp.insert({"model", input["model"].get()}); - if (input.contains("language")) { - temp.insert({"language", input["language"].get()}); - } - if (input.contains("prompt")) { - temp.insert({"prompt", input["prompt"].get()}); - } - if (input.contains("response_format")) { - temp.insert({"response_format", input["response_format"].get()}); - } - if (input.contains("temperature")) { - temp.insert({"temperature", std::to_string(input["temperature"].get())}); + nlohmann::json ExtractCompletionOutput(const nlohmann::json& response) const override { + if (response.contains("choices") && response["choices"].is_array() && !response["choices"].empty()) { + const auto& choice = response["choices"][0]; + if (choice.contains("message") && choice["message"].contains("content")) { + return nlohmann::json::parse(choice["message"]["content"].get()); + } } - return temp; - }; - openai_.setMultiformPart({"file", input["file"].get()}, lambda()); - - return openai_.post("audio/translations", std::string {""}, "multipart/form-data"); -} - -// 
POST https://api.openai.com/v1/translations -// Creates a new edit for the provided input, instruction, and parameters -inline Json CategoryEdit::create(Json input) { return openai_.post("edits", input); } - -// POST https://api.openai.com/v1/images/generations -// Given a prompt and/or an input image, the model will generate a new image. -inline Json CategoryImage::create(Json input) { return openai_.post("images/generations", input); } - -// POST https://api.openai.com/v1/images/edits -// Creates an edited or extended image given an original image and a prompt. -inline Json CategoryImage::edit(Json input) { - std::string prompt = input["prompt"].get(); // required - // Default values - std::string mask = ""; - int n = 1; - std::string size = "1024x1024"; - std::string response_format = "url"; - std::string user = ""; - - if (input.contains("mask")) { - mask = input["mask"].get(); - } - if (input.contains("n")) { - n = input["n"].get(); - } - if (input.contains("size")) { - size = input["size"].get(); - } - if (input.contains("response_format")) { - response_format = input["response_format"].get(); - } - if (input.contains("user")) { - user = input["user"].get(); + return {}; } - openai_.setMultiformPart({"image", input["image"].get()}, - std::map {{"prompt", prompt}, - {"mask", mask}, - {"n", std::to_string(n)}, - {"size", size}, - {"response_format", response_format}, - {"user", user}}); - return openai_.post("images/edits", std::string {""}, "multipart/form-data"); -} - -// POST https://api.openai.com/v1/images/variations -// Creates a variation of a given image. 
-inline Json CategoryImage::variation(Json input) { - // Default values - int n = 1; - std::string size = "1024x1024"; - std::string response_format = "url"; - std::string user = ""; - - if (input.contains("n")) { - n = input["n"].get(); - } - if (input.contains("size")) { - size = input["size"].get(); - } - if (input.contains("response_format")) { - response_format = input["response_format"].get(); - } - if (input.contains("user")) { - user = input["user"].get(); + nlohmann::json ExtractEmbeddingVector(const nlohmann::json& response) const override { + auto results = nlohmann::json::array(); + if (response.contains("data") && response["data"].is_array() && !response["data"].empty()) { + const auto& embeddings = response["data"]; + for (const auto& embedding: embeddings) { + results.push_back(embedding["embedding"]); + } + return results; + } } - openai_.setMultiformPart( - {"image", input["image"].get()}, - std::map { - {"n", std::to_string(n)}, {"size", size}, {"response_format", response_format}, {"user", user}}); - - return openai_.post("images/variations", std::string {""}, "multipart/form-data"); -} - -inline Json CategoryEmbedding::create(Json input) { return openai_.post("embeddings", input); } - -inline Json CategoryFile::list() { return openai_.get("files"); } - -inline Json CategoryFile::upload(Json input) { - openai_.setMultiformPart({"file", input["file"].get()}, - std::map {{"purpose", input["purpose"].get()}}); - - return openai_.post("files", std::string {""}, "multipart/form-data"); -} - -inline Json CategoryFile::del(const std::string &file_id) { return openai_.del("files/" + file_id); } - -inline Json CategoryFile::retrieve(const std::string &file_id) { return openai_.get("files/" + file_id); } - -inline Json CategoryFile::content(const std::string &file_id) { return openai_.get("files/" + file_id + "/content"); } - -inline Json CategoryFineTune::create(Json input) { return openai_.post("fine-tunes", input); } - -inline Json 
CategoryFineTune::list() { return openai_.get("fine-tunes"); } - -inline Json CategoryFineTune::retrieve(const std::string &fine_tune_id) { - return openai_.get("fine-tunes/" + fine_tune_id); -} - -inline Json CategoryFineTune::content(const std::string &fine_tune_id) { - return openai_.get("fine-tunes/" + fine_tune_id + "/content"); -} - -inline Json CategoryFineTune::cancel(const std::string &fine_tune_id) { - return openai_.post("fine-tunes/" + fine_tune_id + "/cancel", Json {}); -} - -inline Json CategoryFineTune::events(const std::string &fine_tune_id) { - return openai_.get("fine-tunes/" + fine_tune_id + "/events"); -} - -inline Json CategoryFineTune::del(const std::string &model) { return openai_.del("models/" + model); } - -inline Json CategoryModeration::create(Json input) { return openai_.post("moderations", input); } - -} // namespace _detail - -// Public interface - -using _detail::OpenAI; - -// instance -using _detail::instance; -using _detail::start; - -// Generic methods -using _detail::get; -using _detail::post; - -// Helper categories access -using _detail::assistant; -using _detail::audio; -using _detail::chat; -using _detail::completion; -using _detail::edit; -using _detail::embedding; -using _detail::file; -using _detail::fineTune; -using _detail::image; -using _detail::model; -using _detail::moderation; -using _detail::thread; - -using _detail::Json; - -} // namespace openai +}; -#endif // OPENAI_HPP_ +}// namespace flockmtl diff --git a/src/include/flockmtl/model_manager/providers/provider.hpp b/src/include/flockmtl/model_manager/providers/provider.hpp index 5591648a..de5ccf90 100644 --- a/src/include/flockmtl/model_manager/providers/provider.hpp +++ b/src/include/flockmtl/model_manager/providers/provider.hpp @@ -3,6 +3,7 @@ #include "fmt/format.h" #include +#include "flockmtl/model_manager/providers/handlers/handler.hpp" #include "flockmtl/model_manager/repository.hpp" namespace flockmtl { @@ -17,12 +18,20 @@ enum class OutputType { class 
IProvider { public: ModelDetails model_details_; + std::unique_ptr model_handler_; explicit IProvider(const ModelDetails& model_details) : model_details_(model_details) {}; virtual ~IProvider() = default; - virtual nlohmann::json CallComplete(const std::string& prompt, bool json_response, OutputType output_type) = 0; - virtual nlohmann::json CallEmbedding(const std::vector& inputs) = 0; + virtual void AddCompletionRequest(const std::string& prompt, bool json_response, OutputType output_type) = 0; + virtual void AddEmbeddingRequest(const std::vector& inputs) = 0; + + virtual std::vector CollectCompletions(const std::string& contentType = "application/json") { + return model_handler_->CollectCompletions(contentType); + } + virtual std::vector CollectEmbeddings(const std::string& contentType = "application/json") { + return model_handler_->CollectEmbeddings(contentType); + } static std::string GetOutputTypeString(const OutputType output_type) { switch (output_type) { diff --git a/src/model_manager/model.cpp b/src/model_manager/model.cpp index 8b69857d..bca2589a 100644 --- a/src/model_manager/model.cpp +++ b/src/model_manager/model.cpp @@ -91,10 +91,20 @@ void Model::ConstructProvider() { ModelDetails Model::GetModelDetails() { return model_details_; } -nlohmann::json Model::CallComplete(const std::string& prompt, const bool json_response, const OutputType output_type) { - return provider_->CallComplete(prompt, json_response, output_type); +void Model::AddCompletionRequest(const std::string& prompt, bool json_response, OutputType output_type) { + provider_->AddCompletionRequest(prompt, json_response, output_type); } -nlohmann::json Model::CallEmbedding(const std::vector& inputs) { return provider_->CallEmbedding(inputs); } +void Model::AddEmbeddingRequest(const std::vector& inputs) { + provider_->AddEmbeddingRequest(inputs); +} + +std::vector Model::CollectCompletions(const std::string& contentType) { + return provider_->CollectCompletions(contentType); +} + 
+std::vector Model::CollectEmbeddings(const std::string& contentType) { + return provider_->CollectEmbeddings(contentType); +} }// namespace flockmtl diff --git a/src/model_manager/providers/adapters/azure.cpp b/src/model_manager/providers/adapters/azure.cpp index 94d6e91f..f38e485c 100644 --- a/src/model_manager/providers/adapters/azure.cpp +++ b/src/model_manager/providers/adapters/azure.cpp @@ -2,11 +2,7 @@ namespace flockmtl { -nlohmann::json AzureProvider::CallComplete(const std::string& prompt, const bool json_response, OutputType output_type) { - auto azure_model_manager_uptr = - std::make_unique(model_details_.secret["api_key"], model_details_.secret["resource_name"], - model_details_.model, model_details_.secret["api_version"], true); - +void AzureProvider::AddCompletionRequest(const std::string& prompt, const bool json_response, OutputType output_type) { // Create a JSON request payload with the provided parameters nlohmann::json request_payload = {{"messages", {{{"role", "user"}, {"content", prompt}}}}}; @@ -35,65 +31,18 @@ nlohmann::json AzureProvider::CallComplete(const std::string& prompt, const bool } } - // Make a request to the Azure API - auto completion = azure_model_manager_uptr->CallComplete(request_payload); - - // Check if the conversation was too long for the context window - if (completion["choices"][0]["finish_reason"] == "length") { - // Handle the error when the context window is too long - throw ExceededMaxOutputTokensError(); - } - - // Check if the safety system refused the request - if (completion["choices"][0]["message"]["refusal"] != nullptr) { - // Handle refusal error - throw std::runtime_error( - duckdb_fmt::format("The request was refused due to Azure's safety system.{{\"refusal\": \"{}\"}}", - completion["choices"][0]["message"]["refusal"].get())); - } - - // Check if the model's output included restricted content - if (completion["choices"][0]["finish_reason"] == "content_filter") { - // Handle content filtering - throw 
std::runtime_error("The content filter was triggered, resulting in incomplete JSON."); - } - - std::string content_str = completion["choices"][0]["message"]["content"]; - - if (json_response) { - return nlohmann::json::parse(content_str); - } - - return content_str; + model_handler_->AddRequest(request_payload, IModelProviderHandler::RequestType::Completion); } -nlohmann::json AzureProvider::CallEmbedding(const std::vector& inputs) { - auto azure_model_manager_uptr = - std::make_unique(model_details_.secret["api_key"], model_details_.secret["resource_name"], - model_details_.model, model_details_.secret["api_version"], true); - - // Create a JSON request payload with the provided parameters - nlohmann::json request_payload = { - {"model", model_details_.model}, - {"input", inputs}, - }; +void AzureProvider::AddEmbeddingRequest(const std::vector& inputs) { + for (const auto& input: inputs) { + nlohmann::json request_payload = { + {"model", model_details_.model}, + {"prompt", input}, + }; - // Make a request to the Azure API - auto completion = azure_model_manager_uptr->CallEmbedding(request_payload); - - // Check if the conversation was too long for the context window - if (completion["choices"][0]["finish_reason"] == "length") { - // Handle the error when the context window is too long - throw ExceededMaxOutputTokensError(); - // Add error handling code here - } - - auto embeddings = nlohmann::json::array(); - for (auto& item: completion["data"]) { - embeddings.push_back(item["embedding"]); + model_handler_->AddRequest(request_payload, IModelProviderHandler::RequestType::Embedding); } - - return embeddings; } }// namespace flockmtl \ No newline at end of file diff --git a/src/model_manager/providers/adapters/ollama.cpp b/src/model_manager/providers/adapters/ollama.cpp index 0a999bf5..feb8424c 100644 --- a/src/model_manager/providers/adapters/ollama.cpp +++ b/src/model_manager/providers/adapters/ollama.cpp @@ -2,12 +2,10 @@ namespace flockmtl { -nlohmann::json 
OllamaProvider::CallComplete(const std::string& prompt, const bool json_response, OutputType output_type) { - auto ollama_model_manager_uptr = std::make_unique(model_details_.secret["api_url"], true); - - // Create a JSON request payload with the provided parameters +void OllamaProvider::AddCompletionRequest(const std::string& prompt, const bool json_response, OutputType output_type) { nlohmann::json request_payload = {{"model", model_details_.model}, - {"prompt", prompt}}; + {"prompt", prompt}, + {"stream", false}}; if (!model_details_.model_parameters.empty()) { request_payload.update(model_details_.model_parameters); @@ -16,56 +14,30 @@ nlohmann::json OllamaProvider::CallComplete(const std::string& prompt, const boo if (json_response) { if (model_details_.model_parameters.contains("format")) { auto schema = model_details_.model_parameters["format"]; - request_payload["format"] = {{"type", "object"}, {"properties", {"items", {{"type", "array"}, {"items", schema}}}}}; + request_payload["format"] = { + {"type", "object"}, + {"properties", {{"items", {{"type", "array"}, {"items", schema}}}}}, + {"required", {"items"}}}; } else { - request_payload["format"] = {{"type", "object"}, {"properties", {"items", {{"type", "array"}, {"items", {{"type", GetOutputTypeString(output_type)}}}}}}}; + request_payload["format"] = { + {"type", "object"}, + {"properties", {{"items", {{"type", "array"}, {"items", {{"type", GetOutputTypeString(output_type)}}}}}}}, + {"required", {"items"}}}; } } - nlohmann::json completion; - try { - completion = ollama_model_manager_uptr->CallComplete(request_payload); - } catch (const std::exception& e) { - throw std::runtime_error(duckdb_fmt::format("Error in making request to Ollama API: {}", e.what())); - } - - // Check if the call was not succesfull - if ((completion.contains("done_reason") && completion["done_reason"] != "stop") || - (completion.contains("done") && !completion["done"].is_null() && completion["done"].get() != true)) { - // Handle 
refusal error - throw std::runtime_error("The request was refused due to some internal error with Ollama API"); - } - - std::string content_str = completion["response"]; - - if (json_response) { - return nlohmann::json::parse(content_str); - } - - return content_str; + model_handler_->AddRequest(request_payload); } -nlohmann::json OllamaProvider::CallEmbedding(const std::vector& inputs) { - auto ollama_model_manager_uptr = std::make_unique(model_details_.secret["api_url"], true); - - auto embeddings = nlohmann::json::array(); +void OllamaProvider::AddEmbeddingRequest(const std::vector& inputs) { for (const auto& input: inputs) { - // Create a JSON request payload with the provided parameters nlohmann::json request_payload = { {"model", model_details_.model}, - {"prompt", input}, + {"input", input}, }; - nlohmann::json completion; - try { - completion = ollama_model_manager_uptr->CallEmbedding(request_payload); - } catch (const std::exception& e) { - throw std::runtime_error(duckdb_fmt::format("Error in making request to Ollama API: {}", e.what())); - } - - embeddings.push_back(completion["embedding"]); + model_handler_->AddRequest(request_payload, IModelProviderHandler::RequestType::Embedding); } - return embeddings; } }// namespace flockmtl \ No newline at end of file diff --git a/src/model_manager/providers/adapters/openai.cpp b/src/model_manager/providers/adapters/openai.cpp index 3092da04..60b48dbd 100644 --- a/src/model_manager/providers/adapters/openai.cpp +++ b/src/model_manager/providers/adapters/openai.cpp @@ -2,14 +2,7 @@ namespace flockmtl { -nlohmann::json OpenAIProvider::CallComplete(const std::string& prompt, bool json_response, OutputType output_type) { - auto base_url = std::string(""); - if (const auto it = model_details_.secret.find("base_url"); it != model_details_.secret.end()) { - base_url = it->second; - } - auto openai = openai::OpenAI(model_details_.secret["api_key"], "", true, base_url); - - // Create a JSON request payload with the 
provided parameters +void OpenAIProvider::AddCompletionRequest(const std::string& prompt, bool json_response, OutputType output_type) { nlohmann::json request_payload = {{"model", model_details_.model}, {"messages", {{{"role", "user"}, {"content", prompt}}}}}; @@ -38,70 +31,16 @@ nlohmann::json OpenAIProvider::CallComplete(const std::string& prompt, bool json } } - // Make a request to the OpenAI API - nlohmann::json completion; - try { - completion = openai.chat.create(request_payload); - } catch (const std::exception& e) { - throw std::runtime_error("Error in making request to OpenAI API: " + std::string(e.what())); - } - // Check if the conversation was too long for the context window - if (completion["choices"][0]["finish_reason"] == "length") { - // Handle the error when the context window is too long - throw ExceededMaxOutputTokensError(); - } - - // Check if the OpenAI safety system refused the request - if (completion["choices"][0]["message"]["refusal"] != nullptr) { - // Handle refusal error - throw std::runtime_error( - duckdb_fmt::format("The request was refused due to OpenAI's safety system.{{\"refusal\": \"{}\"}}", - completion["choices"][0]["message"]["refusal"].get())); - } - - // Check if the model's output included restricted content - if (completion["choices"][0]["finish_reason"] == "content_filter") { - // Handle content filtering - throw std::runtime_error("The content filter was triggered, resulting in incomplete JSON."); - } - - std::string content_str = completion["choices"][0]["message"]["content"]; - - if (json_response) { - return nlohmann::json::parse(content_str); - } - - return content_str; + model_handler_->AddRequest(request_payload); } -nlohmann::json OpenAIProvider::CallEmbedding(const std::vector& inputs) { - auto base_url = std::string(""); - if (const auto it = model_details_.secret.find("base_url"); it != model_details_.secret.end()) { - base_url = it->second; - } - auto openai = openai::OpenAI(model_details_.secret["api_key"], 
"", true, base_url); - - // Create a JSON request payload with the provided parameters +void OpenAIProvider::AddEmbeddingRequest(const std::vector& inputs) { nlohmann::json request_payload = { {"model", model_details_.model}, {"input", inputs}, }; - // Make a request to the OpenAI API - auto completion = openai.embedding.create(request_payload); - - // Check if the conversation was too long for the context window - if (completion["choices"][0]["finish_reason"] == "length") { - // Handle the error when the context window is too long - throw ExceededMaxOutputTokensError(); - } - - auto embeddings = nlohmann::json::array(); - for (auto& item: completion["data"]) { - embeddings.push_back(item["embedding"]); - } - - return embeddings; + model_handler_->AddRequest(request_payload, IModelProviderHandler::RequestType::Embedding); } }// namespace flockmtl diff --git a/src/prompt_manager/prompt_manager.cpp b/src/prompt_manager/prompt_manager.cpp index 4f22bd5f..592042d5 100644 --- a/src/prompt_manager/prompt_manager.cpp +++ b/src/prompt_manager/prompt_manager.cpp @@ -53,13 +53,13 @@ std::string PromptManager::ConstructInputTuplesHeader(const nlohmann::json& tupl std::string PromptManager::ConstructInputTuplesHeaderXML(const nlohmann::json& tuples) { if (tuples.empty()) { - return "Empty\n"; + return "
\n"; } - auto header = std::string(""); + auto header = std::string("
"); for (const auto& key: tuples[0].items()) { header += "" + key.key() + ""; } - header += "\n"; + header += "
\n"; return header; } diff --git a/test/functions/aggregate/llm_aggregate_function_test_base.hpp b/test/functions/aggregate/llm_aggregate_function_test_base.hpp index c2904b02..3459753d 100644 --- a/test/functions/aggregate/llm_aggregate_function_test_base.hpp +++ b/test/functions/aggregate/llm_aggregate_function_test_base.hpp @@ -22,7 +22,7 @@ class LLMAggregateTestBase : public ::testing::Test { static constexpr const char* DEFAULT_MODEL = "gpt-4o"; static constexpr const char* TEST_PROMPT = "Summarize the following data"; - std::shared_ptr mock_provider; + std::shared_ptr mock_provider; void SetUp() override { auto con = Config::GetConnection(); @@ -30,7 +30,7 @@ class LLMAggregateTestBase : public ::testing::Test { " TYPE OPENAI," " API_KEY 'your-api-key');"); - mock_provider = std::make_shared(); + mock_provider = std::make_shared(ModelDetails{}); Model::SetMockProvider(mock_provider); } diff --git a/test/functions/aggregate/llm_first.cpp b/test/functions/aggregate/llm_first.cpp index ea212a56..2823aa87 100644 --- a/test/functions/aggregate/llm_first.cpp +++ b/test/functions/aggregate/llm_first.cpp @@ -41,8 +41,10 @@ class LLMFirstTest : public LLMAggregateTestBase { // Test llm_first with SQL queries without GROUP BY TEST_F(LLMFirstTest, LLMFirstWithoutGroupBy) { - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(GetExpectedJsonResponse())); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{GetExpectedJsonResponse()})); auto con = Config::GetConnection(); @@ -59,9 +61,11 @@ TEST_F(LLMFirstTest, LLMFirstWithoutGroupBy) { // Test llm_first with SQL queries with GROUP BY TEST_F(LLMFirstTest, LLMFirstWithGroupBy) { - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) + EXPECT_CALL(*mock_provider, 
AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(3); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) .Times(3) - .WillRepeatedly(::testing::Return(GetExpectedJsonResponse())); + .WillRepeatedly(::testing::Return(std::vector{GetExpectedJsonResponse()})); auto con = Config::GetConnection(); @@ -95,9 +99,11 @@ TEST_F(LLMFirstTest, Operation_InvalidArguments_ThrowsException) { TEST_F(LLMFirstTest, Operation_MultipleInputs_ProcessesCorrectly) { const nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(3); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) .Times(3) - .WillRepeatedly(::testing::Return(expected_response)); + .WillRepeatedly(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); @@ -122,8 +128,11 @@ TEST_F(LLMFirstTest, Operation_LargeInputSet_ProcessesCorrectly) { constexpr size_t input_count = 100; const nlohmann::json expected_response = PrepareExpectedResponseForLargeInput(input_count); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillRepeatedly(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(100); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .Times(100) + .WillRepeatedly(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); diff --git a/test/functions/aggregate/llm_last.cpp b/test/functions/aggregate/llm_last.cpp index 58060905..4fdbb655 100644 --- a/test/functions/aggregate/llm_last.cpp +++ b/test/functions/aggregate/llm_last.cpp @@ -41,8 +41,10 @@ class LLMLastTest : public LLMAggregateTestBase { // Test llm_last with SQL queries without GROUP BY TEST_F(LLMLastTest, 
LLMLastWithoutGroupBy) { - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(GetExpectedJsonResponse())); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{GetExpectedJsonResponse()})); auto con = Config::GetConnection(); @@ -59,9 +61,11 @@ TEST_F(LLMLastTest, LLMLastWithoutGroupBy) { // Test llm_last with SQL queries with GROUP BY TEST_F(LLMLastTest, LLMLastWithGroupBy) { - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(3); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) .Times(3) - .WillRepeatedly(::testing::Return(GetExpectedJsonResponse())); + .WillRepeatedly(::testing::Return(std::vector{GetExpectedJsonResponse()})); auto con = Config::GetConnection(); @@ -95,9 +99,11 @@ TEST_F(LLMLastTest, Operation_InvalidArguments_ThrowsException) { TEST_F(LLMLastTest, Operation_MultipleInputs_ProcessesCorrectly) { const nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(3); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) .Times(3) - .WillRepeatedly(::testing::Return(expected_response)); + .WillRepeatedly(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); @@ -122,9 +128,11 @@ TEST_F(LLMLastTest, Operation_LargeInputSet_ProcessesCorrectly) { constexpr size_t input_count = 100; const nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) + 
EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(100); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) .Times(100) - .WillRepeatedly(::testing::Return(expected_response)); + .WillRepeatedly(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); diff --git a/test/functions/aggregate/llm_reduce.cpp b/test/functions/aggregate/llm_reduce.cpp index 594077f5..4d4abe0c 100644 --- a/test/functions/aggregate/llm_reduce.cpp +++ b/test/functions/aggregate/llm_reduce.cpp @@ -41,8 +41,10 @@ class LLMReduceTest : public LLMAggregateTestBase { // Test llm_reduce with SQL queries without GROUP BY TEST_F(LLMReduceTest, LLMReduceWithoutGroupBy) { - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(GetExpectedJsonResponse())); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{GetExpectedJsonResponse()})); auto con = Config::GetConnection(); @@ -59,9 +61,11 @@ TEST_F(LLMReduceTest, LLMReduceWithoutGroupBy) { // Test llm_reduce with SQL queries with GROUP BY TEST_F(LLMReduceTest, LLMReduceWithGroupBy) { - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(3); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) .Times(3) - .WillRepeatedly(::testing::Return(GetExpectedJsonResponse())); + .WillRepeatedly(::testing::Return(std::vector{GetExpectedJsonResponse()})); auto con = Config::GetConnection(); @@ -92,9 +96,11 @@ TEST_F(LLMReduceTest, Operation_InvalidArguments_ThrowsException) { TEST_F(LLMReduceTest, Operation_MultipleInputs_ProcessesCorrectly) { const nlohmann::json expected_response = GetExpectedJsonResponse(); - 
EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(3); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) .Times(3) - .WillRepeatedly(::testing::Return(expected_response)); + .WillRepeatedly(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); @@ -116,8 +122,11 @@ TEST_F(LLMReduceTest, Operation_LargeInputSet_ProcessesCorrectly) { constexpr size_t input_count = 100; const nlohmann::json expected_response = PrepareExpectedResponseForLargeInput(input_count); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillRepeatedly(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(100); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .Times(100) + .WillRepeatedly(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); diff --git a/test/functions/aggregate/llm_reduce_json.cpp b/test/functions/aggregate/llm_reduce_json.cpp index 62a053ed..76aaf261 100644 --- a/test/functions/aggregate/llm_reduce_json.cpp +++ b/test/functions/aggregate/llm_reduce_json.cpp @@ -42,8 +42,10 @@ class LLMReduceJsonTest : public LLMAggregateTestBase { // Test llm_reduce_json with SQL queries without GROUP BY TEST_F(LLMReduceJsonTest, LLMReduceJsonWithoutGroupBy) { - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(GetExpectedJsonResponse())); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{GetExpectedJsonResponse()})); auto con = Config::GetConnection(); @@ -61,9 +63,11 @@ TEST_F(LLMReduceJsonTest, 
LLMReduceJsonWithoutGroupBy) { // Test llm_reduce_json with SQL queries with GROUP BY TEST_F(LLMReduceJsonTest, LLMReduceJsonWithGroupBy) { - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(3); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) .Times(3) - .WillRepeatedly(::testing::Return(GetExpectedJsonResponse())); + .WillRepeatedly(::testing::Return(std::vector{GetExpectedJsonResponse()})); auto con = Config::GetConnection(); @@ -95,9 +99,11 @@ TEST_F(LLMReduceJsonTest, Operation_InvalidArguments_ThrowsException) { TEST_F(LLMReduceJsonTest, Operation_MultipleInputs_ProcessesCorrectly) { const nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(3); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) .Times(3) - .WillRepeatedly(::testing::Return(expected_response)); + .WillRepeatedly(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); @@ -120,8 +126,11 @@ TEST_F(LLMReduceJsonTest, Operation_LargeInputSet_ProcessesCorrectly) { constexpr size_t input_count = 100; const nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillRepeatedly(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(100); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .Times(100) + .WillRepeatedly(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); @@ -143,8 +152,10 @@ TEST_F(LLMReduceJsonTest, Operation_LargeInputSet_ProcessesCorrectly) { TEST_F(LLMReduceJsonTest, 
Operation_ValidJsonOutput_ParsesCorrectly) { const nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); @@ -181,8 +192,10 @@ TEST_F(LLMReduceJsonTest, Operation_ComplexJsonStructure_HandlesCorrectly) { complex_response["items"].push_back(item); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(complex_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{complex_response})); auto con = Config::GetConnection(); diff --git a/test/functions/aggregate/llm_rerank.cpp b/test/functions/aggregate/llm_rerank.cpp index cd2c1284..dad39dda 100644 --- a/test/functions/aggregate/llm_rerank.cpp +++ b/test/functions/aggregate/llm_rerank.cpp @@ -47,8 +47,10 @@ class LLMRerankTest : public LLMAggregateTestBase { // Test llm_rerank with SQL queries without GROUP BY TEST_F(LLMRerankTest, LLMRerankWithoutGroupBy) { - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(GetExpectedJsonResponse())); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{GetExpectedJsonResponse()})); auto con = Config::GetConnection(); @@ -69,9 +71,11 @@ TEST_F(LLMRerankTest, LLMRerankWithoutGroupBy) { // Test llm_rerank with SQL queries with GROUP 
BY TEST_F(LLMRerankTest, LLMRerankWithGroupBy) { - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(3); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) .Times(3) - .WillRepeatedly(::testing::Return(nlohmann::json::parse(LLM_RESPONSE_WITH_GROUP_BY))); + .WillRepeatedly(::testing::Return(std::vector{nlohmann::json::parse(LLM_RESPONSE_WITH_GROUP_BY)})); auto con = Config::GetConnection(); @@ -105,9 +109,11 @@ TEST_F(LLMRerankTest, Operation_InvalidArguments_ThrowsException) { TEST_F(LLMRerankTest, Operation_MultipleInputs_ProcessesCorrectly) { const nlohmann::json expected_response = nlohmann::json::parse(LLM_RESPONSE_WITH_GROUP_BY); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(3); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) .Times(3) - .WillRepeatedly(::testing::Return(expected_response)); + .WillRepeatedly(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); @@ -132,8 +138,11 @@ TEST_F(LLMRerankTest, Operation_LargeInputSet_ProcessesCorrectly) { constexpr size_t input_count = 100; const nlohmann::json expected_response = PrepareExpectedResponseForLargeInput(input_count); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillRepeatedly(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(100); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .Times(100) + .WillRepeatedly(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); diff --git a/test/functions/mock_provider.hpp b/test/functions/mock_provider.hpp index f63e0cd7..e151328e 100644 --- 
a/test/functions/mock_provider.hpp +++ b/test/functions/mock_provider.hpp @@ -1,19 +1,17 @@ #pragma once - -#include "flockmtl/model_manager/providers/adapters/openai.hpp" -#include "nlohmann/json.hpp" +#include "flockmtl/model_manager/providers/provider.hpp" #include namespace flockmtl { -// Mock class for OpenAI API to avoid real API calls during tests -class MockOpenAIProvider : public OpenAIProvider { +class MockProvider : public IProvider { public: - explicit MockOpenAIProvider() : OpenAIProvider(ModelDetails()) {} + explicit MockProvider(const ModelDetails& model_details) : IProvider(model_details) {} - // Override the API call methods for testing - MOCK_METHOD(nlohmann::json, CallComplete, (const std::string& prompt, bool json_response, OutputType output_type), (override)); - MOCK_METHOD(nlohmann::json, CallEmbedding, (const std::vector& inputs), (override)); + MOCK_METHOD(void, AddCompletionRequest, (const std::string& prompt, bool json_response, OutputType output_type), (override)); + MOCK_METHOD(void, AddEmbeddingRequest, (const std::vector& inputs), (override)); + MOCK_METHOD(std::vector, CollectCompletions, (const std::string& contentType), (override)); + MOCK_METHOD(std::vector, CollectEmbeddings, (const std::string& contentType), (override)); }; }// namespace flockmtl diff --git a/test/functions/scalar/llm_complete.cpp b/test/functions/scalar/llm_complete.cpp index acb69010..638e4de6 100644 --- a/test/functions/scalar/llm_complete.cpp +++ b/test/functions/scalar/llm_complete.cpp @@ -41,8 +41,10 @@ class LLMCompleteTest : public LLMFunctionTestBase { // Test llm_complete with SQL queries TEST_F(LLMCompleteTest, LLMCompleteWithoutInputColumns) { - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(GetExpectedResponse())); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + 
.WillOnce(::testing::Return(std::vector{GetExpectedResponse()})); auto con = Config::GetConnection(); const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'gpt-4o'},{'prompt': 'Explain the purpose of FlockMTL.'}) AS flockmtl_purpose;"); @@ -52,8 +54,10 @@ TEST_F(LLMCompleteTest, LLMCompleteWithoutInputColumns) { TEST_F(LLMCompleteTest, LLMCompleteWithInputColumns) { const nlohmann::json expected_response = {{"items", {"The capital of Canada is Ottawa."}}}; - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'gpt-4o'},{'prompt': 'What is the capital of France?'}, {'input': 'France'}) AS flockmtl_capital;"); @@ -66,8 +70,10 @@ TEST_F(LLMCompleteTest, ValidateArguments) { } TEST_F(LLMCompleteTest, Operation_TwoArguments_SimplePrompt) { - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(GetExpectedResponse())); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{GetExpectedResponse()})); duckdb::DataChunk chunk; CreateBasicChunk(chunk); @@ -75,7 +81,7 @@ TEST_F(LLMCompleteTest, Operation_TwoArguments_SimplePrompt) { SetStructStringData(chunk.data[0], {{{"model_name", DEFAULT_MODEL}}}); SetStructStringData(chunk.data[1], {{{"prompt", TEST_PROMPT}}}); - auto results = LlmComplete::Operation(chunk); + auto results = LlmComplete::Operation(chunk, ExecutionMode::SYNC); EXPECT_EQ(results.size(), 1); 
EXPECT_EQ(results[0], GetExpectedResponse()); @@ -84,8 +90,10 @@ TEST_F(LLMCompleteTest, Operation_TwoArguments_SimplePrompt) { TEST_F(LLMCompleteTest, Operation_ThreeArguments_BatchProcessing) { const std::vector responses = {"response 1", "response 2"}; const nlohmann::json expected_response = PrepareExpectedResponseForBatch(responses); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; auto model_struct = CreateModelStruct(); @@ -101,7 +109,7 @@ TEST_F(LLMCompleteTest, Operation_ThreeArguments_BatchProcessing) { SetStructStringData(chunk.data[2], {{{"variable1", "Hello"}, {"variable2", "World"}}, {{"variable1", "Good"}, {"variable2", "Morning"}}}); - auto results = LlmComplete::Operation(chunk); + auto results = LlmComplete::Operation(chunk, ExecutionMode::ASYNC); EXPECT_EQ(results.size(), 2); EXPECT_EQ(results, expected_response["items"]); @@ -120,8 +128,10 @@ TEST_F(LLMCompleteTest, Operation_LargeInputSet_ProcessesCorrectly) { const nlohmann::json expected_response = PrepareExpectedResponseForLargeInput(input_count); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; auto model_struct = CreateModelStruct(); @@ -144,7 +154,7 @@ TEST_F(LLMCompleteTest, Operation_LargeInputSet_ProcessesCorrectly) { SetStructStringData(chunk.data[2], large_input); - const auto results = 
LlmComplete::Operation(chunk); + const auto results = LlmComplete::Operation(chunk, ExecutionMode::ASYNC); EXPECT_EQ(results.size(), input_count); EXPECT_EQ(results, expected_response["items"]); diff --git a/test/functions/scalar/llm_complete_json.cpp b/test/functions/scalar/llm_complete_json.cpp index 3a521f4c..9e90a52b 100644 --- a/test/functions/scalar/llm_complete_json.cpp +++ b/test/functions/scalar/llm_complete_json.cpp @@ -45,8 +45,10 @@ class LLMCompleteJsonTest : public LLMFunctionTestBase { TEST_F(LLMCompleteJsonTest, LLMCompleteJsonWithoutInputColumns) { nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'gpt-4o'},{'prompt': 'Explain the purpose of FlockMTL.'}) AS flockmtl_purpose;"); @@ -56,8 +58,10 @@ TEST_F(LLMCompleteJsonTest, LLMCompleteJsonWithoutInputColumns) { TEST_F(LLMCompleteJsonTest, LLMCompleteJsonWithInputColumns) { const nlohmann::json expected_response = {{"items", {nlohmann::json::parse(R"({"capital": "Ottawa", "country": "Canada"})")}}}; - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'gpt-4o'},{'prompt': 'What is 
the capital of France?'}, {'input': 'France'}) AS flockmtl_capital;"); @@ -72,8 +76,10 @@ TEST_F(LLMCompleteJsonTest, ValidateArguments) { TEST_F(LLMCompleteJsonTest, Operation_TwoArguments_SimplePrompt) { nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; CreateBasicChunk(chunk); @@ -81,7 +87,7 @@ TEST_F(LLMCompleteJsonTest, Operation_TwoArguments_SimplePrompt) { SetStructStringData(chunk.data[0], {{{"model_name", DEFAULT_MODEL}}}); SetStructStringData(chunk.data[1], {{{"prompt", TEST_PROMPT}}}); - auto results = LlmComplete::Operation(chunk); + auto results = LlmComplete::Operation(chunk, ExecutionMode::ASYNC); EXPECT_EQ(results.size(), 1); EXPECT_EQ(results[0], expected_response.dump()); @@ -90,8 +96,10 @@ TEST_F(LLMCompleteJsonTest, Operation_TwoArguments_SimplePrompt) { TEST_F(LLMCompleteJsonTest, Operation_ThreeArguments_BatchProcessing) { const std::vector responses = {"response 1", "response 2"}; const nlohmann::json expected_response = PrepareExpectedResponseForBatch(responses); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; auto model_struct = CreateModelStruct(); @@ -107,7 +115,7 @@ TEST_F(LLMCompleteJsonTest, Operation_ThreeArguments_BatchProcessing) { SetStructStringData(chunk.data[2], {{{"variable1", 
"Hello"}, {"variable2", "World"}}, {{"variable1", "Good"}, {"variable2", "Morning"}}}); - auto results = LlmComplete::Operation(chunk); + auto results = LlmComplete::Operation(chunk, ExecutionMode::ASYNC); EXPECT_EQ(results.size(), 2); std::vector expected_strings; @@ -130,8 +138,10 @@ TEST_F(LLMCompleteJsonTest, Operation_LargeInputSet_ProcessesCorrectly) { const nlohmann::json expected_response = PrepareExpectedResponseForLargeInput(input_count); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; auto model_struct = CreateModelStruct(); @@ -154,7 +164,7 @@ TEST_F(LLMCompleteJsonTest, Operation_LargeInputSet_ProcessesCorrectly) { SetStructStringData(chunk.data[2], large_input); - const auto results = LlmComplete::Operation(chunk); + const auto results = LlmComplete::Operation(chunk, ExecutionMode::ASYNC); EXPECT_EQ(results.size(), input_count); std::vector expected_strings; diff --git a/test/functions/scalar/llm_embedding.cpp b/test/functions/scalar/llm_embedding.cpp index 3f4471bc..ba7fd311 100644 --- a/test/functions/scalar/llm_embedding.cpp +++ b/test/functions/scalar/llm_embedding.cpp @@ -99,8 +99,10 @@ const std::vector> LLMEmbeddingTest::EXPECTED_EMBEDDINGS = { // Test llm_embedding with SQL queries TEST_F(LLMEmbeddingTest, LLMEmbeddingWithTextInput) { const nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallEmbedding(::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddEmbeddingRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectEmbeddings(::testing::_)) + 
.WillOnce(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'text-embedding-3-small'}, {'text': 'This is a test document'}) AS embedding;"); @@ -114,8 +116,10 @@ TEST_F(LLMEmbeddingTest, LLMEmbeddingWithTextInput) { TEST_F(LLMEmbeddingTest, LLMEmbeddingWithMultipleTextFields) { const nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallEmbedding(::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddEmbeddingRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectEmbeddings(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'text-embedding-3-small'}, {'title': 'Document Title', 'content': 'Document content here'}) AS embedding;"); @@ -164,8 +168,10 @@ TEST_F(LLMEmbeddingTest, ValidateArguments) { TEST_F(LLMEmbeddingTest, Operation_TwoArguments_RequiredStructure) { const nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallEmbedding(::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddEmbeddingRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectEmbeddings(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; CreateEmbeddingChunk(chunk); @@ -187,8 +193,10 @@ TEST_F(LLMEmbeddingTest, Operation_TwoArguments_RequiredStructure) { TEST_F(LLMEmbeddingTest, Operation_BatchProcessing) { const std::vector> batch_embeddings = {{0.1, 0.2, 0.3, 0.4, 0.5}, {0.2, 0.3, 0.4, 0.5, 0.6}}; const nlohmann::json expected_response = PrepareExpectedResponseForBatch(batch_embeddings); - EXPECT_CALL(*mock_provider, CallEmbedding(::testing::_)) - 
.WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddEmbeddingRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectEmbeddings(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; CreateEmbeddingChunk(chunk, 2); @@ -217,8 +225,10 @@ TEST_F(LLMEmbeddingTest, Operation_BatchProcessing) { TEST_F(LLMEmbeddingTest, Operation_MultipleInputFields) { const nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallEmbedding(::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddEmbeddingRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectEmbeddings(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; auto model_struct = CreateModelStruct(); @@ -269,8 +279,10 @@ TEST_F(LLMEmbeddingTest, Operation_LargeInputSet_ProcessesCorrectly) { const nlohmann::json expected_response = PrepareExpectedResponseForLargeInput(input_count); - EXPECT_CALL(*mock_provider, CallEmbedding(::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddEmbeddingRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectEmbeddings(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; CreateEmbeddingChunk(chunk, input_count); @@ -305,8 +317,10 @@ TEST_F(LLMEmbeddingTest, Operation_LargeInputSet_ProcessesCorrectly) { TEST_F(LLMEmbeddingTest, Operation_ConcatenatedFields_ProcessesCorrectly) { const nlohmann::json expected_response = GetExpectedJsonResponse(); - EXPECT_CALL(*mock_provider, CallEmbedding(::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddEmbeddingRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectEmbeddings(::testing::_)) + 
.WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; auto model_struct = CreateModelStruct(); diff --git a/test/functions/scalar/llm_filter.cpp b/test/functions/scalar/llm_filter.cpp index a0120bae..6bfd2a72 100644 --- a/test/functions/scalar/llm_filter.cpp +++ b/test/functions/scalar/llm_filter.cpp @@ -60,8 +60,10 @@ class LLMFilterTest : public LLMFunctionTestBase { // Test llm_filter with SQL queries TEST_F(LLMFilterTest, LLMFilterWithInputColumns) { const nlohmann::json expected_response = {{"items", {true}}}; - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillRepeatedly(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(::testing::AtLeast(1)); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillRepeatedly(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'gpt-4o'},{'prompt': 'Is positive?'}, {'input': i}) AS filter FROM range(1) AS tbl(i);"); @@ -71,8 +73,10 @@ TEST_F(LLMFilterTest, LLMFilterWithInputColumns) { TEST_F(LLMFilterTest, LLMFilterWithMultipleInputs) { const nlohmann::json expected_response = {{"items", {false}}}; - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'gpt-4o'},{'prompt': 'Is positive?'}, {'input1': i, 'input2': i}) AS filter FROM range(1) AS tbl(i);"); @@ -86,8 +90,10 @@ TEST_F(LLMFilterTest, 
ValidateArguments) { TEST_F(LLMFilterTest, Operation_ThreeArguments_RequiredStructure) { const nlohmann::json expected_response = {{"items", {true}}}; - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; auto model_struct = CreateModelStruct(); @@ -101,7 +107,7 @@ TEST_F(LLMFilterTest, Operation_ThreeArguments_RequiredStructure) { SetStructStringData(chunk.data[1], {{{"prompt", "Does this text express positive sentiment?"}}}); SetStructStringData(chunk.data[2], {{{"sentiment_text", "I love this product!"}}}); - auto results = LlmFilter::Operation(chunk); + auto results = LlmFilter::Operation(chunk, ExecutionMode::ASYNC); EXPECT_EQ(results.size(), 1); EXPECT_EQ(results[0], FormatExpectedResult(expected_response)); @@ -110,8 +116,10 @@ TEST_F(LLMFilterTest, Operation_ThreeArguments_RequiredStructure) { TEST_F(LLMFilterTest, Operation_BatchProcessing) { const std::vector filter_responses = {true, false}; const nlohmann::json expected_response = PrepareExpectedResponseForBatch(filter_responses); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; auto model_struct = CreateModelStruct(); @@ -127,7 +135,7 @@ TEST_F(LLMFilterTest, Operation_BatchProcessing) { SetStructStringData(chunk.data[2], {{{"review_text", "Great product"}, {"rating", "5"}}, {{"review_text", "Terrible quality"}, {"rating", 
"1"}}}); - auto results = LlmFilter::Operation(chunk); + auto results = LlmFilter::Operation(chunk, ExecutionMode::ASYNC); EXPECT_EQ(results.size(), 2); std::vector expected_results; @@ -142,8 +150,10 @@ TEST_F(LLMFilterTest, Operation_BatchProcessing_StringVector) { // Test with vector of strings (for compatibility with base class interface) const std::vector filter_responses = {"true", "false"}; const nlohmann::json expected_response = PrepareExpectedResponseForBatch(filter_responses); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; auto model_struct = CreateModelStruct(); @@ -159,7 +169,7 @@ TEST_F(LLMFilterTest, Operation_BatchProcessing_StringVector) { SetStructStringData(chunk.data[2], {{{"text_content", "Great offer!"}}, {{"text_content", "Click here now!"}}}); - auto results = LlmFilter::Operation(chunk); + auto results = LlmFilter::Operation(chunk, ExecutionMode::ASYNC); EXPECT_EQ(results.size(), 2); std::vector expected_results; @@ -183,8 +193,10 @@ TEST_F(LLMFilterTest, Operation_LargeInputSet_ProcessesCorrectly) { const nlohmann::json expected_response = PrepareExpectedResponseForLargeInput(input_count); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); duckdb::DataChunk chunk; auto model_struct = CreateModelStruct(); @@ -207,7 +219,7 @@ TEST_F(LLMFilterTest, 
Operation_LargeInputSet_ProcessesCorrectly) { SetStructStringData(chunk.data[2], large_input); - const auto results = LlmFilter::Operation(chunk); + const auto results = LlmFilter::Operation(chunk, ExecutionMode::ASYNC); EXPECT_EQ(results.size(), input_count); std::vector expected_strings; @@ -225,31 +237,7 @@ TEST_F(LLMFilterTest, Operation_TwoArguments_ThrowsException) { SetStructStringData(chunk.data[0], {{{"model_name", DEFAULT_MODEL}}}); SetStructStringData(chunk.data[1], {{{"prompt", TEST_PROMPT}}}); - EXPECT_THROW(LlmFilter::Operation(chunk), std::runtime_error); -} - -TEST_F(LLMFilterTest, Operation_NullResponse_HandlesAsTrue) { - // Test that null responses from the model are handled as "True" - const nlohmann::json expected_response = nlohmann::json(nullptr); - EXPECT_CALL(*mock_provider, CallComplete(::testing::_, ::testing::_, ::testing::_)) - .WillOnce(::testing::Return(expected_response)); - - duckdb::DataChunk chunk; - auto model_struct = CreateModelStruct(); - auto prompt_struct = CreatePromptStruct(); - auto input_struct = CreateInputStruct({"text"}); - - chunk.Initialize(duckdb::Allocator::DefaultAllocator(), {model_struct, prompt_struct, input_struct}); - chunk.SetCardinality(1); - - SetStructStringData(chunk.data[0], {{{"model_name", DEFAULT_MODEL}}}); - SetStructStringData(chunk.data[1], {{{"prompt", "Filter this content"}}}); - SetStructStringData(chunk.data[2], {{{"text", "Some content"}}}); - - auto results = LlmFilter::Operation(chunk); - - EXPECT_EQ(results.size(), 1); - EXPECT_EQ(results[0], "true");// Null responses should default to "true" + EXPECT_THROW(LlmFilter::Operation(chunk, ExecutionMode::ASYNC), std::runtime_error); } }// namespace flockmtl diff --git a/test/functions/scalar/llm_function_test_base.hpp b/test/functions/scalar/llm_function_test_base.hpp index 873a0581..34851c88 100644 --- a/test/functions/scalar/llm_function_test_base.hpp +++ b/test/functions/scalar/llm_function_test_base.hpp @@ -3,7 +3,7 @@ #include 
"../mock_provider.hpp" #include "flockmtl/core/config.hpp" #include "flockmtl/model_manager/model.hpp" -#include "flockmtl/model_manager/providers/adapters/openai.hpp" +#include "flockmtl/model_manager/providers/provider.hpp" #include "nlohmann/json.hpp" #include #include @@ -20,7 +20,7 @@ class LLMFunctionTestBase : public ::testing::Test { static constexpr const char* DEFAULT_MODEL = "gpt-4o"; static constexpr const char* TEST_PROMPT = "Explain the purpose of FlockMTL."; - std::shared_ptr mock_provider; + std::shared_ptr mock_provider; void SetUp() override; void TearDown() override; diff --git a/test/functions/scalar/llm_function_test_base_instantiations.cpp b/test/functions/scalar/llm_function_test_base_instantiations.cpp index 68c9543f..4c8c5d93 100644 --- a/test/functions/scalar/llm_function_test_base_instantiations.cpp +++ b/test/functions/scalar/llm_function_test_base_instantiations.cpp @@ -13,7 +13,7 @@ void LLMFunctionTestBase::SetUp() { " TYPE OPENAI," " API_KEY 'your-api-key');"); - mock_provider = std::make_shared(); + mock_provider = std::make_shared(ModelDetails{}); Model::SetMockProvider(mock_provider); } @@ -147,7 +147,7 @@ void LLMFunctionTestBase::TestOperationInvalidArguments() { chunk.Initialize(duckdb::Allocator::DefaultAllocator(), {duckdb::LogicalType(duckdb::LogicalTypeId::VARCHAR)}); chunk.SetCardinality(1); - EXPECT_THROW(FunctionClass::Operation(chunk), std::runtime_error); + EXPECT_THROW(FunctionClass::Operation(chunk, ExecutionMode::ASYNC), std::runtime_error); } template @@ -167,7 +167,7 @@ void LLMFunctionTestBase::TestOperationEmptyPrompt() { SetStructStringData(chunk.data[1], {{{"prompt", ""}}}); SetStructStringData(chunk.data[2], {{{"test", "value"}}}); - EXPECT_THROW(FunctionClass::Operation(chunk), std::runtime_error); + EXPECT_THROW(FunctionClass::Operation(chunk, ExecutionMode::ASYNC), std::runtime_error); return; } catch (const std::runtime_error&) { // Function might not accept 3 arguments, try with 2 @@ -179,7 +179,7 @@ 
void LLMFunctionTestBase::TestOperationEmptyPrompt() { SetStructStringData(chunk.data[0], {{{"model_name", DEFAULT_MODEL}}}); SetStructStringData(chunk.data[1], {{{"prompt", ""}}}); - EXPECT_THROW(FunctionClass::Operation(chunk), std::runtime_error); + EXPECT_THROW(FunctionClass::Operation(chunk, ExecutionMode::ASYNC), std::runtime_error); } // Explicit instantiations for all used function classes diff --git a/test/model_manager/model_providers_test.cpp b/test/model_manager/model_providers_test.cpp index b094d9cb..8bb24a8e 100644 --- a/test/model_manager/model_providers_test.cpp +++ b/test/model_manager/model_providers_test.cpp @@ -1,3 +1,4 @@ +#include "../functions/mock_provider.hpp" #include "flockmtl/model_manager/providers/adapters/azure.hpp" #include "flockmtl/model_manager/providers/adapters/ollama.hpp" #include "flockmtl/model_manager/providers/adapters/openai.hpp" @@ -9,36 +10,6 @@ namespace flockmtl { using json = nlohmann::json; -// Mock class for OpenAI API to avoid real API calls during tests -class MockOpenAIProvider : public OpenAIProvider { -public: - explicit MockOpenAIProvider(const ModelDetails& model_details) : OpenAIProvider(model_details) {} - - // Override the API call methods for testing - MOCK_METHOD(nlohmann::json, CallComplete, (const std::string& prompt, bool json_response, OutputType output_type), (override)); - MOCK_METHOD(nlohmann::json, CallEmbedding, (const std::vector& inputs), (override)); -}; - -// Mock class for Azure API to avoid real API calls during tests -class MockAzureProvider : public AzureProvider { -public: - explicit MockAzureProvider(const ModelDetails& model_details) : AzureProvider(model_details) {} - - // Override the API call methods for testing - MOCK_METHOD(nlohmann::json, CallComplete, (const std::string& prompt, bool json_response, OutputType output_type), (override)); - MOCK_METHOD(nlohmann::json, CallEmbedding, (const std::vector& inputs), (override)); -}; - -// Mock class for Ollama API to avoid real API 
calls during tests -class MockOllamaProvider : public OllamaProvider { -public: - explicit MockOllamaProvider(const ModelDetails& model_details) : OllamaProvider(model_details) {} - - // Override the API call methods for testing - MOCK_METHOD(nlohmann::json, CallComplete, (const std::string& prompt, bool json_response, OutputType output_type), (override)); - MOCK_METHOD(nlohmann::json, CallEmbedding, (const std::vector& inputs), (override)); -}; - // Test OpenAI provider behavior TEST(ModelProvidersTest, OpenAIProviderTest) { ModelDetails model_details; @@ -49,28 +20,36 @@ TEST(ModelProvidersTest, OpenAIProviderTest) { model_details.secret = {{"api_key", "test_api_key"}}; // Create a mock provider - MockOpenAIProvider mock_provider(model_details); + MockProvider mock_provider(model_details); - // Set up mock behavior for CallComplete + // Set up mock behavior for AddCompletionRequest and CollectCompletions const std::string test_prompt = "Test prompt for completion"; const json expected_complete_response = {{"response", "This is a test response"}}; - EXPECT_CALL(mock_provider, CallComplete(test_prompt, true, OutputType::STRING)) - .WillOnce(::testing::Return(expected_complete_response)); + EXPECT_CALL(mock_provider, AddCompletionRequest(test_prompt, true, OutputType::STRING)) + .Times(1); + EXPECT_CALL(mock_provider, CollectCompletions("application/json")) + .WillOnce(::testing::Return(std::vector{expected_complete_response})); - // Set up mock behavior for CallEmbedding + // Set up mock behavior for AddEmbeddingRequest and CollectEmbeddings const std::vector test_inputs = {"Test input for embedding"}; const json expected_embedding_response = json::array({{0.1, 0.2, 0.3, 0.4, 0.5}}); - EXPECT_CALL(mock_provider, CallEmbedding(test_inputs)) - .WillOnce(::testing::Return(expected_embedding_response)); + EXPECT_CALL(mock_provider, AddEmbeddingRequest(test_inputs)) + .Times(1); + EXPECT_CALL(mock_provider, CollectEmbeddings("application/json")) + 
.WillOnce(::testing::Return(std::vector{expected_embedding_response})); // Test the mocked methods - auto complete_result = mock_provider.CallComplete(test_prompt, true, OutputType::STRING); - EXPECT_EQ(complete_result, expected_complete_response); - - auto embedding_result = mock_provider.CallEmbedding(test_inputs); - EXPECT_EQ(embedding_result, expected_embedding_response); + mock_provider.AddCompletionRequest(test_prompt, true, OutputType::STRING); + auto complete_results = mock_provider.CollectCompletions("application/json"); + ASSERT_EQ(complete_results.size(), 1); + EXPECT_EQ(complete_results[0], expected_complete_response); + + mock_provider.AddEmbeddingRequest(test_inputs); + auto embedding_results = mock_provider.CollectEmbeddings("application/json"); + ASSERT_EQ(embedding_results.size(), 1); + EXPECT_EQ(embedding_results[0], expected_embedding_response); } // Test Azure provider behavior @@ -86,28 +65,36 @@ TEST(ModelProvidersTest, AzureProviderTest) { {"api_version", "2023-05-15"}}; // Create a mock provider - MockAzureProvider mock_provider(model_details); + MockProvider mock_provider(model_details); - // Set up mock behavior for CallComplete + // Set up mock behavior for AddCompletionRequest and CollectCompletions const std::string test_prompt = "Test prompt for completion"; const json expected_complete_response = {{"response", "This is a test response from Azure"}}; - EXPECT_CALL(mock_provider, CallComplete(test_prompt, true, OutputType::STRING)) - .WillOnce(::testing::Return(expected_complete_response)); + EXPECT_CALL(mock_provider, AddCompletionRequest(test_prompt, true, OutputType::STRING)) + .Times(1); + EXPECT_CALL(mock_provider, CollectCompletions("application/json")) + .WillOnce(::testing::Return(std::vector{expected_complete_response})); - // Set up mock behavior for CallEmbedding + // Set up mock behavior for AddEmbeddingRequest and CollectEmbeddings const std::vector test_inputs = {"Test input for embedding"}; const json 
expected_embedding_response = json::array({{0.5, 0.4, 0.3, 0.2, 0.1}}); - EXPECT_CALL(mock_provider, CallEmbedding(test_inputs)) - .WillOnce(::testing::Return(expected_embedding_response)); + EXPECT_CALL(mock_provider, AddEmbeddingRequest(test_inputs)) + .Times(1); + EXPECT_CALL(mock_provider, CollectEmbeddings("application/json")) + .WillOnce(::testing::Return(std::vector{expected_embedding_response})); // Test the mocked methods - auto complete_result = mock_provider.CallComplete(test_prompt, true, OutputType::STRING); - EXPECT_EQ(complete_result, expected_complete_response); - - auto embedding_result = mock_provider.CallEmbedding(test_inputs); - EXPECT_EQ(embedding_result, expected_embedding_response); + mock_provider.AddCompletionRequest(test_prompt, true, OutputType::STRING); + auto complete_results = mock_provider.CollectCompletions("application/json"); + ASSERT_EQ(complete_results.size(), 1); + EXPECT_EQ(complete_results[0], expected_complete_response); + + mock_provider.AddEmbeddingRequest(test_inputs); + auto embedding_results = mock_provider.CollectEmbeddings("application/json"); + ASSERT_EQ(embedding_results.size(), 1); + EXPECT_EQ(embedding_results[0], expected_embedding_response); } // Test Ollama provider behavior @@ -120,28 +107,36 @@ TEST(ModelProvidersTest, OllamaProviderTest) { model_details.secret = {{"api_url", "http://localhost:11434"}}; // Create a mock provider - MockOllamaProvider mock_provider(model_details); + MockProvider mock_provider(model_details); - // Set up mock behavior for CallComplete + // Set up mock behavior for AddCompletionRequest and CollectCompletions const std::string test_prompt = "Test prompt for Ollama completion"; const json expected_complete_response = {{"response", "This is a test response from Ollama"}}; - EXPECT_CALL(mock_provider, CallComplete(test_prompt, true, OutputType::STRING)) - .WillOnce(::testing::Return(expected_complete_response)); + EXPECT_CALL(mock_provider, AddCompletionRequest(test_prompt, true, 
OutputType::STRING)) + .Times(1); + EXPECT_CALL(mock_provider, CollectCompletions("application/json")) + .WillOnce(::testing::Return(std::vector{expected_complete_response})); - // Set up mock behavior for CallEmbedding + // Set up mock behavior for AddEmbeddingRequest and CollectEmbeddings const std::vector test_inputs = {"Test input for Ollama embedding"}; const json expected_embedding_response = json::array({{0.7, 0.6, 0.5, 0.4, 0.3}}); - EXPECT_CALL(mock_provider, CallEmbedding(test_inputs)) - .WillOnce(::testing::Return(expected_embedding_response)); + EXPECT_CALL(mock_provider, AddEmbeddingRequest(test_inputs)) + .Times(1); + EXPECT_CALL(mock_provider, CollectEmbeddings("application/json")) + .WillOnce(::testing::Return(std::vector{expected_embedding_response})); // Test the mocked methods - auto complete_result = mock_provider.CallComplete(test_prompt, true, OutputType::STRING); - EXPECT_EQ(complete_result, expected_complete_response); - - auto embedding_result = mock_provider.CallEmbedding(test_inputs); - EXPECT_EQ(embedding_result, expected_embedding_response); + mock_provider.AddCompletionRequest(test_prompt, true, OutputType::STRING); + auto complete_results = mock_provider.CollectCompletions("application/json"); + ASSERT_EQ(complete_results.size(), 1); + EXPECT_EQ(complete_results[0], expected_complete_response); + + mock_provider.AddEmbeddingRequest(test_inputs); + auto embedding_results = mock_provider.CollectEmbeddings("application/json"); + ASSERT_EQ(embedding_results.size(), 1); + EXPECT_EQ(embedding_results[0], expected_embedding_response); } }// namespace flockmtl \ No newline at end of file diff --git a/test/prompt_manager/prompt_manager_test.cpp b/test/prompt_manager/prompt_manager_test.cpp index 5f732aba..fc094c1e 100644 --- a/test/prompt_manager/prompt_manager_test.cpp +++ b/test/prompt_manager/prompt_manager_test.cpp @@ -60,13 +60,12 @@ TEST(PromptManager, ReplaceSectionString) { EXPECT_EQ(result, "Replace and but not [that]."); } -// Test 
cases for PromptManager::ConstructInputTuplesHeader TEST(PromptManager, ConstructInputTuplesHeader) { json tuple = {{{"col1", "val1"}, {"col2", 123}}}; // XML auto xml_header = PromptManager::ConstructInputTuplesHeader(tuple, "xml"); - EXPECT_EQ(xml_header, "col1col2\n"); + EXPECT_EQ(xml_header, "
col1col2
\n"); // Markdown auto md_header = PromptManager::ConstructInputTuplesHeader(tuple, "markdown"); @@ -114,7 +113,7 @@ TEST(PromptManager, ConstructInputTuples) { // XML auto xml_expected = std::string("- The Number of Tuples to Generate Responses for: 2\n\n"); - xml_expected += "colAcolB\n"; + xml_expected += "
colAcolB
\n"; xml_expected += "\"row1A\"1\n"; xml_expected += "\"row2A\"2\n"; EXPECT_EQ(PromptManager::ConstructInputTuples(tuples, "xml"), xml_expected); @@ -143,7 +142,7 @@ TEST(PromptManager, ConstructInputTuplesEmpty) { // Empty tuples - XML auto xml_expected = std::string("- The Number of Tuples to Generate Responses for: 0\n\n"); - xml_expected += "Empty\n"; + xml_expected += "
\n"; EXPECT_EQ(PromptManager::ConstructInputTuples(empty_tuples, "xml"), xml_expected); // Empty tuples - Markdown