Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions src/functions/aggregate/llm_first_or_last/implementation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ namespace flockmtl {
/**
 * Asks the model which tuple in `tuples` answers the first/last query.
 *
 * Renders a prompt from `user_query`, `tuples`, `function_type` and the model's tuple format,
 * queues a single completion request with OutputType::INTEGER and collects its response.
 *
 * @param tuples JSON tuples handed to the prompt renderer.
 * @return The first entry of the response's "items" array, converted to int.
 * @throws std::runtime_error if the model returns no completion response.
 */
int LlmFirstOrLast::GetFirstOrLastTupleId(const nlohmann::json& tuples) {
    const auto prompt = PromptManager::Render(user_query, tuples, function_type, model.GetModelDetails().tuple_format);

    // Exactly one request is queued, so exactly one response is expected back.
    model.AddCompletionRequest(prompt, true, OutputType::INTEGER);
    auto responses = model.CollectCompletions();
    if (responses.empty()) {
        throw std::runtime_error("No completion response was returned");
    }
    return responses[0]["items"][0];
}

nlohmann::json LlmFirstOrLast::Evaluate(nlohmann::json& tuples) {
template<>
nlohmann::json LlmFirstOrLast::Evaluate<ExecutionMode::SYNC>(nlohmann::json& tuples) {
auto batch_tuples = nlohmann::json::array();
int start_index = 0;
auto batch_size = std::min<int>(model.GetModelDetails().batch_size, static_cast<int>(tuples.size()));
Expand Down Expand Up @@ -46,9 +48,43 @@ nlohmann::json LlmFirstOrLast::Evaluate(nlohmann::json& tuples) {
return batch_tuples[0];
}

/**
 * Asynchronous evaluation: narrows `tuples` down to a single winner in rounds.
 *
 * Each round splits the remaining candidates into batches, queues one completion request per
 * batch, collects all responses together, and keeps the tuple elected by each response as a
 * candidate for the next round. The loop ends once at most one candidate is left.
 *
 * @param tuples JSON array of candidate tuples. The value returned by the model is used as a
 *               position in this array (NOTE(review): this assumes the model answers with the
 *               "flockmtl_tuple_id" attached by the caller, which equals the tuple's position —
 *               confirm against the prompt template).
 * @return The selected tuple with its "flockmtl_tuple_id" key removed.
 * @throws std::runtime_error if the batch size is not positive, if a round yields no response,
 *                            or if a returned id does not address a tuple of `tuples`.
 */
template<>
nlohmann::json LlmFirstOrLast::Evaluate<ExecutionMode::ASYNC>(nlohmann::json& tuples) {
    const auto total_tuples = static_cast<int>(tuples.size());
    const auto batch_size = std::min<int>(model.GetModelDetails().batch_size, total_tuples);
    if (batch_size <= 0) {
        throw std::runtime_error("Batch size must be greater than zero");
    }

    std::vector<nlohmann::json> current_tuples = tuples;

    do {
        const auto n = static_cast<int>(current_tuples.size());
        // A batch holding a single tuple can only elect that tuple, so with a batch size of 1
        // a round would never shrink the candidate set and the loop would not terminate.
        // Compare at least two candidates per request while more than one remains.
        const auto round_batch_size = n > 1 ? std::max<int>(batch_size, 2) : batch_size;

        // Queue one completion request per batch of the current candidates.
        for (int start_index = 0; start_index < n; start_index += round_batch_size) {
            const auto this_batch_size = std::min<int>(round_batch_size, n - start_index);
            nlohmann::json batch = nlohmann::json::array();
            for (int i = 0; i < this_batch_size; ++i) {
                batch.push_back(current_tuples[start_index + i]);
            }
            const auto prompt =
                PromptManager::Render(user_query, batch, function_type, model.GetModelDetails().tuple_format);
            model.AddCompletionRequest(prompt, true, OutputType::INTEGER);
        }

        // Collect every queued completion at once.
        auto responses = model.CollectCompletions();
        if (responses.empty()) {
            throw std::runtime_error("No completion responses were returned");
        }

        std::vector<nlohmann::json> new_tuples;
        new_tuples.reserve(responses.size());
        for (auto& response : responses) {
            const int result_id = response["items"][0];
            if (result_id < 0 || result_id >= total_tuples) {
                throw std::runtime_error("Model returned a tuple id that is out of range");
            }
            // Resolve the id against the original array: after the first round the surviving
            // candidates no longer sit at the position their id names.
            new_tuples.push_back(tuples[static_cast<size_t>(result_id)]);
        }
        current_tuples = std::move(new_tuples);
    } while (current_tuples.size() > 1);

    current_tuples[0].erase("flockmtl_tuple_id");
    return current_tuples[0];
}

void LlmFirstOrLast::FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data,
duckdb::Vector& result, idx_t count, idx_t offset,
AggregateFunctionType function_type) {
AggregateFunctionType function_type, ExecutionMode mode) {
const auto states_vector = reinterpret_cast<AggregateFunctionState**>(duckdb::FlatVector::GetData<duckdb::data_ptr_t>(states));
auto function_instance = AggregateFunctionBase::GetInstance<LlmFirstOrLast>();
function_instance->function_type = function_type;
Expand All @@ -64,7 +100,17 @@ void LlmFirstOrLast::FinalizeResults(duckdb::Vector& states, duckdb::AggregateIn
tuple_with_id["flockmtl_tuple_id"] = j;
tuples_with_ids.push_back(tuple_with_id);
}
auto response = function_instance->Evaluate(tuples_with_ids);
nlohmann::json response;
switch (mode) {
case ExecutionMode::ASYNC:
response = function_instance->Evaluate<ExecutionMode::ASYNC>(tuples_with_ids);
break;
case ExecutionMode::SYNC:
response = function_instance->Evaluate<ExecutionMode::SYNC>(tuples_with_ids);
break;
default:
break;
}
result.SetValue(idx, response.dump());
} else {
result.SetValue(idx, "{}");// Empty JSON object for null/empty states
Expand Down
14 changes: 9 additions & 5 deletions src/functions/aggregate/llm_first_or_last/instantiations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ template void AggregateFunctionBase::SimpleUpdate<LlmFirstOrLast>(duckdb::Vector
duckdb::data_ptr_t, idx_t);
template void AggregateFunctionBase::Combine<LlmFirstOrLast>(duckdb::Vector&, duckdb::Vector&,
duckdb::AggregateInputData&, idx_t);
template void LlmFirstOrLast::Finalize<AggregateFunctionType::FIRST>(duckdb::Vector&, duckdb::AggregateInputData&,
duckdb::Vector&, idx_t, idx_t);
template void LlmFirstOrLast::Finalize<AggregateFunctionType::LAST>(duckdb::Vector&, duckdb::AggregateInputData&,
duckdb::Vector&, idx_t, idx_t);
template void LlmFirstOrLast::Finalize<AggregateFunctionType::FIRST, ExecutionMode::ASYNC>(duckdb::Vector&, duckdb::AggregateInputData&,
duckdb::Vector&, idx_t, idx_t);
template void LlmFirstOrLast::Finalize<AggregateFunctionType::FIRST, ExecutionMode::SYNC>(duckdb::Vector&, duckdb::AggregateInputData&,
duckdb::Vector&, idx_t, idx_t);
template void LlmFirstOrLast::Finalize<AggregateFunctionType::LAST, ExecutionMode::ASYNC>(duckdb::Vector&, duckdb::AggregateInputData&,
duckdb::Vector&, idx_t, idx_t);
template void LlmFirstOrLast::Finalize<AggregateFunctionType::LAST, ExecutionMode::SYNC>(duckdb::Vector&, duckdb::AggregateInputData&,
duckdb::Vector&, idx_t, idx_t);

} // namespace flockmtl
}// namespace flockmtl
40 changes: 24 additions & 16 deletions src/functions/aggregate/llm_first_or_last/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,33 @@
namespace flockmtl {

/**
 * Registers the `llm_first` aggregate in two variants on `db`:
 *  - `llm_first`   finalizes with ExecutionMode::ASYNC
 *  - `llm_first_s` finalizes with ExecutionMode::SYNC
 *
 * Both variants take three arguments of any type and return VARCHAR. Each name is registered
 * exactly once.
 *
 * @param db Database instance the functions are registered on.
 */
void AggregateRegistry::RegisterLlmFirst(duckdb::DatabaseInstance& db) {
    duckdb::ExtensionUtil::RegisterFunction(
        db, duckdb::AggregateFunction(
                "llm_first", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY},
                duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize<AggregateFunctionState>,
                LlmFirstOrLast::Initialize, LlmFirstOrLast::Operation, LlmFirstOrLast::Combine,
                LlmFirstOrLast::Finalize<AggregateFunctionType::FIRST, ExecutionMode::ASYNC>,
                LlmFirstOrLast::SimpleUpdate, nullptr, LlmFirstOrLast::Destroy));
    duckdb::ExtensionUtil::RegisterFunction(
        db, duckdb::AggregateFunction(
                "llm_first_s", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY},
                duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize<AggregateFunctionState>,
                LlmFirstOrLast::Initialize, LlmFirstOrLast::Operation, LlmFirstOrLast::Combine,
                LlmFirstOrLast::Finalize<AggregateFunctionType::FIRST, ExecutionMode::SYNC>,
                LlmFirstOrLast::SimpleUpdate, nullptr, LlmFirstOrLast::Destroy));
}

/**
 * Registers the `llm_last` aggregate in two variants on `db`:
 *  - `llm_last`   finalizes with ExecutionMode::ASYNC
 *  - `llm_last_s` finalizes with ExecutionMode::SYNC
 *
 * Both variants take three arguments of any type and return VARCHAR. Each name is registered
 * exactly once.
 *
 * @param db Database instance the functions are registered on.
 */
void AggregateRegistry::RegisterLlmLast(duckdb::DatabaseInstance& db) {
    duckdb::ExtensionUtil::RegisterFunction(
        db, duckdb::AggregateFunction(
                "llm_last", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY},
                duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize<AggregateFunctionState>,
                LlmFirstOrLast::Initialize, LlmFirstOrLast::Operation, LlmFirstOrLast::Combine,
                LlmFirstOrLast::Finalize<AggregateFunctionType::LAST, ExecutionMode::ASYNC>,
                LlmFirstOrLast::SimpleUpdate, nullptr, LlmFirstOrLast::Destroy));
    duckdb::ExtensionUtil::RegisterFunction(
        db, duckdb::AggregateFunction(
                "llm_last_s", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY},
                duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize<AggregateFunctionState>,
                LlmFirstOrLast::Initialize, LlmFirstOrLast::Operation, LlmFirstOrLast::Combine,
                LlmFirstOrLast::Finalize<AggregateFunctionType::LAST, ExecutionMode::SYNC>,
                LlmFirstOrLast::SimpleUpdate, nullptr, LlmFirstOrLast::Destroy));
}

}// namespace flockmtl
62 changes: 57 additions & 5 deletions src/functions/aggregate/llm_reduce/implementation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ nlohmann::json LlmReduce::ReduceBatch(const nlohmann::json& tuples, const Aggreg
nlohmann::json data;
const auto prompt = PromptManager::Render(user_query, tuples, function_type, model.GetModelDetails().tuple_format);
OutputType output_type = OutputType::STRING;
auto response = model.CallComplete(prompt, true, output_type);
model.AddCompletionRequest(prompt, true, output_type);
auto response = model.CollectCompletions()[0];
return response["items"][0];
};

nlohmann::json LlmReduce::ReduceLoop(const std::vector<nlohmann::json>& tuples,
const AggregateFunctionType& function_type) {
template<>
nlohmann::json LlmReduce::ReduceLoop<ExecutionMode::SYNC>(const std::vector<nlohmann::json>& tuples,
const AggregateFunctionType& function_type) {
auto batch_tuples = nlohmann::json::array();
int start_index = 0;
auto batch_size = std::min<int>(model.GetModelDetails().batch_size, static_cast<int>(tuples.size()));
Expand Down Expand Up @@ -46,9 +48,49 @@ nlohmann::json LlmReduce::ReduceLoop(const std::vector<nlohmann::json>& tuples,
return batch_tuples[0];
}

/**
 * Asynchronous reduction: collapses `tuples` into a single value in rounds.
 *
 * Each round splits the current values into batches, queues one completion request per batch
 * (OutputType::STRING), collects all responses together, and uses the first "items" entry of
 * each response as an input of the next round. The loop ends once at most one value is left.
 *
 * @param tuples        Values to reduce.
 * @param function_type Aggregate function type forwarded to the prompt renderer.
 * @return The single value remaining after the last round.
 * @throws std::runtime_error if the batch size is not positive or a round yields no response.
 */
template<>
nlohmann::json LlmReduce::ReduceLoop<ExecutionMode::ASYNC>(const std::vector<nlohmann::json>& tuples,
                                                           const AggregateFunctionType& function_type) {
    const auto batch_size = std::min<int>(model.GetModelDetails().batch_size, static_cast<int>(tuples.size()));
    if (batch_size <= 0) {
        throw std::runtime_error("Batch size must be greater than zero");
    }

    const OutputType output_type = OutputType::STRING;
    std::vector<nlohmann::json> current_tuples = tuples;

    do {
        const auto n = static_cast<int>(current_tuples.size());
        // A batch of one value produces one value, so with a batch size of 1 a round would
        // never shrink the input and the loop would not terminate. Merge at least two values
        // per request while more than one remains.
        const auto round_batch_size = n > 1 ? std::max<int>(batch_size, 2) : batch_size;

        // Prepare all batches and queue one completion request for each.
        for (int start_index = 0; start_index < n; start_index += round_batch_size) {
            const auto this_batch_size = std::min<int>(round_batch_size, n - start_index);
            nlohmann::json batch = nlohmann::json::array();
            for (int i = 0; i < this_batch_size; ++i) {
                batch.push_back(current_tuples[start_index + i]);
            }
            const auto prompt =
                PromptManager::Render(user_query, batch, function_type, model.GetModelDetails().tuple_format);
            model.AddCompletionRequest(prompt, true, output_type);
        }

        // Collect all completions at once.
        auto responses = model.CollectCompletions();
        if (responses.empty()) {
            throw std::runtime_error("No completion responses were returned");
        }

        std::vector<nlohmann::json> new_tuples;
        new_tuples.reserve(responses.size());
        for (auto& response : responses) {
            new_tuples.push_back(response["items"][0]);
        }
        current_tuples = std::move(new_tuples);
    } while (current_tuples.size() > 1);

    return current_tuples[0];
}

void LlmReduce::FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data,
duckdb::Vector& result, idx_t count, idx_t offset,
const AggregateFunctionType function_type) {
const AggregateFunctionType function_type, ExecutionMode mode) {
const auto states_vector = reinterpret_cast<AggregateFunctionState**>(duckdb::FlatVector::GetData<duckdb::data_ptr_t>(states));

auto function_instance = AggregateFunctionBase::GetInstance<LlmReduce>();
Expand All @@ -57,7 +99,17 @@ void LlmReduce::FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputDa
auto* state = states_vector[idx];

if (state && state->value) {
auto response = function_instance->ReduceLoop(*state->value, function_type);
// Dispatch to the ReduceLoop specialization that matches the requested execution mode.
// The SYNC case must run the SYNC loop and the ASYNC case the ASYNC loop; for any other
// mode value `response` is left default-constructed.
nlohmann::json response;
switch (mode) {
case ExecutionMode::SYNC:
    response = function_instance->ReduceLoop<ExecutionMode::SYNC>(*state->value, function_type);
    break;
case ExecutionMode::ASYNC:
    response = function_instance->ReduceLoop<ExecutionMode::ASYNC>(*state->value, function_type);
    break;
default:
    break;
}
if (response.is_string()) {
result.SetValue(idx, response.get<std::string>());
} else {
Expand Down
6 changes: 4 additions & 2 deletions src/functions/aggregate/llm_reduce/instantiations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ template void AggregateFunctionBase::SimpleUpdate<LlmReduce>(duckdb::Vector[], d
duckdb::data_ptr_t, idx_t);
template void AggregateFunctionBase::Combine<LlmReduce>(duckdb::Vector&, duckdb::Vector&, duckdb::AggregateInputData&,
idx_t);
template void LlmReduce::Finalize<AggregateFunctionType::REDUCE>(duckdb::Vector&, duckdb::AggregateInputData&,
duckdb::Vector&, idx_t, idx_t);
template void LlmReduce::Finalize<AggregateFunctionType::REDUCE, ExecutionMode::ASYNC>(duckdb::Vector&, duckdb::AggregateInputData&,
duckdb::Vector&, idx_t, idx_t);
template void LlmReduce::Finalize<AggregateFunctionType::REDUCE, ExecutionMode::SYNC>(duckdb::Vector&, duckdb::AggregateInputData&,
duckdb::Vector&, idx_t, idx_t);

}// namespace flockmtl
20 changes: 12 additions & 8 deletions src/functions/aggregate/llm_reduce/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
namespace flockmtl {

/**
 * Registers the `llm_reduce` aggregate in two variants on `db`:
 *  - `llm_reduce`   finalizes with ExecutionMode::ASYNC
 *  - `llm_reduce_s` finalizes with ExecutionMode::SYNC
 *
 * Both variants take three arguments of any type and return JSON. Each name is registered
 * exactly once.
 *
 * @param db Database instance the functions are registered on.
 */
void AggregateRegistry::RegisterLlmReduce(duckdb::DatabaseInstance& db) {
    duckdb::ExtensionUtil::RegisterFunction(
        db, duckdb::AggregateFunction(
                "llm_reduce", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY},
                duckdb::LogicalType::JSON(), duckdb::AggregateFunction::StateSize<AggregateFunctionState>,
                LlmReduce::Initialize, LlmReduce::Operation, LlmReduce::Combine,
                LlmReduce::Finalize<AggregateFunctionType::REDUCE, ExecutionMode::ASYNC>,
                LlmReduce::SimpleUpdate, nullptr, LlmReduce::Destroy));
    duckdb::ExtensionUtil::RegisterFunction(
        db, duckdb::AggregateFunction(
                "llm_reduce_s", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY},
                duckdb::LogicalType::JSON(), duckdb::AggregateFunction::StateSize<AggregateFunctionState>,
                LlmReduce::Initialize, LlmReduce::Operation, LlmReduce::Combine,
                LlmReduce::Finalize<AggregateFunctionType::REDUCE, ExecutionMode::SYNC>,
                LlmReduce::SimpleUpdate, nullptr, LlmReduce::Destroy));
}

}// namespace flockmtl
5 changes: 3 additions & 2 deletions src/functions/aggregate/llm_rerank/implementation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ std::vector<int> LlmRerank::RerankBatch(const nlohmann::json& tuples) {
nlohmann::json data;
auto prompt =
PromptManager::Render(user_query, tuples, AggregateFunctionType::RERANK, model.GetModelDetails().tuple_format);
auto response = model.CallComplete(prompt, true, OutputType::INTEGER);
return response["items"];
model.AddCompletionRequest(prompt, true, OutputType::INTEGER);
auto responses = model.CollectCompletions();
return responses[0]["items"];
};

nlohmann::json LlmRerank::SlidingWindow(nlohmann::json& tuples) {
Expand Down
12 changes: 5 additions & 7 deletions src/functions/aggregate/llm_rerank/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
namespace flockmtl {

/**
 * Registers the `llm_rerank` aggregate on `db`.
 *
 * The function takes three arguments of any type and returns VARCHAR. It is registered
 * exactly once.
 *
 * @param db Database instance the function is registered on.
 */
void AggregateRegistry::RegisterLlmRerank(duckdb::DatabaseInstance& db) {
    duckdb::ExtensionUtil::RegisterFunction(
        db, duckdb::AggregateFunction(
                "llm_rerank", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY},
                duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize<AggregateFunctionState>,
                LlmRerank::Initialize, LlmRerank::Operation, LlmRerank::Combine, LlmRerank::Finalize,
                LlmRerank::SimpleUpdate, nullptr, LlmRerank::Destroy));
}

}// namespace flockmtl
Loading