From 4fa1fea494436035f62e996710c6cc021c68cf6b Mon Sep 17 00:00:00 2001 From: Hannah Bast Date: Wed, 19 Feb 2025 13:39:16 +0100 Subject: [PATCH 1/7] Implement comparison of two geo points (#1801) This was not implemented so far and crashed the server when two geo points were compared. Also add `static_assert`s that will make such crashes less likely when datatypes are added in the future. Fixes #1791 --- src/global/ValueIdComparators.h | 22 ++++++++++++++-------- src/index/IndexImpl.cpp | 2 ++ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/global/ValueIdComparators.h b/src/global/ValueIdComparators.h index 62338ef4ed..7f127b1887 100644 --- a/src/global/ValueIdComparators.h +++ b/src/global/ValueIdComparators.h @@ -505,17 +505,23 @@ ComparisonResult compareIdsImpl(ValueId a, ValueId b, auto comparator) { return fromBool(std::invoke(comparator, a, b)); } - auto visitor = [comparator]( + // If both are geo points, compare the raw IDs. + if (a.getDatatype() == Datatype::GeoPoint && + b.getDatatype() == Datatype::GeoPoint) { + return fromBool(std::invoke(comparator, a.getBits(), b.getBits())); + } + + auto visitor = [comparator, &a, &b]( const A& aValue, const B& bValue) -> ComparisonResult { - if constexpr (std::is_same_v && - std::is_same_v) { - // We have handled this case outside the visitor. 
- AD_FAIL(); - } else if constexpr (requires() { - std::invoke(comparator, aValue, bValue); - }) { + if constexpr (requires() { std::invoke(comparator, aValue, bValue); }) { return fromBool(std::invoke(comparator, aValue, bValue)); } else { + static_assert((!std::is_same_v) || + ad_utility::SameAsAny); + AD_LOG_ERROR << "Comparison not implemented for types " + << toString(a.getDatatype()) << " and " + << toString(b.getDatatype()) << std::endl; AD_FAIL(); } }; diff --git a/src/index/IndexImpl.cpp b/src/index/IndexImpl.cpp index 24a8f1c77e..3b2cc58d1a 100644 --- a/src/index/IndexImpl.cpp +++ b/src/index/IndexImpl.cpp @@ -551,6 +551,8 @@ IndexBuilderDataAsStxxlVector IndexImpl::passFileForVocabulary( AD_LOG_INFO << "Number of triples created (including QLever-internal ones): " << (*idTriples.wlock())->size() << " [may contain duplicates]" << std::endl; + AD_LOG_INFO << "Number of partial vocabularies created: " << numFiles + << std::endl; size_t sizeInternalVocabulary = 0; std::vector prefixes; From dfcd08799542900e7dc061acd8e9cd2653d6fde7 Mon Sep 17 00:00:00 2001 From: Johannes Kalmbach Date: Wed, 19 Feb 2025 15:23:32 +0100 Subject: [PATCH 2/7] Use mold for the coverage build. Signed-off-by: Johannes Kalmbach --- .github/workflows/code-coverage.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/code-coverage.yml b/.github/workflows/code-coverage.yml index fddf5f2174..c9f46032dd 100644 --- a/.github/workflows/code-coverage.yml +++ b/.github/workflows/code-coverage.yml @@ -51,6 +51,7 @@ jobs: - name: Install coverage tools run: | sudo apt install -y llvm-16 + sudo apt install mold - name: Show path run: | which llvm-profdata-16 @@ -60,7 +61,7 @@ jobs: - name: Configure CMake # Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make. 
# See https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html?highlight=cmake_build_type - run: cmake -B ${{github.workspace}}/build ${{env.cmake-flags}} -DCMAKE_BUILD_TYPE=${{env.build-type}} -DLOGLEVEL=TIMING -DADDITIONAL_COMPILER_FLAGS="${{env.warnings}} ${{env.asan-flags}} ${{env.ubsan-flags}} ${{env.coverage-flags}}" -DADDITIONAL_LINKER_FLAGS="${{env.coverage-flags}}" -DUSE_PARALLEL=false -DRUN_EXPENSIVE_TESTS=false -DSINGLE_TEST_BINARY=ON -DENABLE_EXPENSIVE_CHECKS=true + run: cmake -B ${{github.workspace}}/build ${{env.cmake-flags}} -DCMAKE_BUILD_TYPE=${{env.build-type}} -DLOGLEVEL=TIMING -DADDITIONAL_COMPILER_FLAGS="${{env.warnings}} ${{env.asan-flags}} ${{env.ubsan-flags}} ${{env.coverage-flags}}" -DADDITIONAL_LINKER_FLAGS="${{env.coverage-flags}}" -DUSE_PARALLEL=false -DRUN_EXPENSIVE_TESTS=false -DSINGLE_TEST_BINARY=ON -DENABLE_EXPENSIVE_CHECKS=true -DADDITIONAL_LINKER_FLAGS="-fuse-ld=mold" - name: Build # Build your program with the given configuration From b93b1010bfefa76769b4360eb0af606ae1acc0ca Mon Sep 17 00:00:00 2001 From: Johannes Kalmbach Date: Wed, 19 Feb 2025 15:35:41 +0100 Subject: [PATCH 3/7] Fix the Conan MacOS build (#1814) For some time now, the CI build on MacOS using the Conan package manager has been broken, because the building of the ICU dependency failed. This PR fixes this build again. The background is, that we have to link against a custom (brew installed) standard library, which makes the whole build process a lot more brittle as several flags have to be set manually. 
--- .github/workflows/macos.yml | 20 +++++++++++++++----- conanfile.txt | 2 +- conanprofiles/clang-16-macos | 3 +++ test/engine/ExistsJoinTest.cpp | 2 +- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 3caa0d94f1..598c16591f 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -37,6 +37,9 @@ jobs: run: | pip3 install pyaml pyicu + - name: Install conan + run: | + brew install conan@2 - name: Install dependencies run: | brew install llvm@16 @@ -46,7 +49,13 @@ jobs: echo 'export LDFLAGS="-L/usr/local/opt/llvm@16/lib -L/usr/local/opt/llvm@16/lib/c++ -Wl,-rpath,/usr/local/opt/llvm@16/lib/c++"' >> ~/.bash_profile echo LDFLAGS="-L/usr/local/opt/llvm@16/lib -L/usr/local/opt/llvm@16/lib/c++ -Wl,-rpath,/usr/local/opt/llvm@16/lib/c++" >> $GITHUB_ENV echo 'export CPPFLAGS="-I/usr/local/opt/llvm@16/include"' >> ~/.bash_profile - echo CPPFLAGS="/usr/local/opt/llvm@16/include" >> $GITHUB_ENV + echo CPPFLAGS="-I/usr/local/opt/llvm@16/include" >> $GITHUB_ENV + echo 'export CFLAGS="-I/usr/local/opt/llvm@16/include"' >> ~/.bash_profile + echo CFLAGS="-I/usr/local/opt/llvm@16/include" >> $GITHUB_ENV + echo 'export CC="/usr/local/opt/llvm@16/bin/clang"' >> ~/.bash_profile + echo CC="/usr/local/opt/llvm@16/bin/clang" >> $GITHUB_ENV + echo 'export CXX="/usr/local/opt/llvm@16/bin/clang++"' >> ~/.bash_profile + echo CXX="/usr/local/opt/llvm@16/bin/clang++" >> $GITHUB_ENV source ~/.bash_profile - name: Print clang version run: clang++ --version @@ -54,7 +63,7 @@ jobs: - name: Cache for conan uses: actions/cache@v3 env: - cache-name: cache-conan-modules + cache-name: cache-conan-modules-macos-13 with: path: ~/.conan2 key: ${{ runner.os }}-build-${{ env.cache-name }}-${{ hashFiles('conanfile.txt') }} @@ -62,11 +71,12 @@ jobs: run: mkdir ${{github.workspace}}/build - name: Install and run conan working-directory: ${{github.workspace}}/build - run: > - conan install .. 
-pr:b=../conanprofiles/clang-16-macos -pr:h=../conanprofiles/clang-16-macos -of=. --build=missing; + run: | + conan install .. -pr:b=../conanprofiles/clang-16-macos -pr:h=../conanprofiles/clang-16-macos -of=. --build=missing - name: Configure CMake # For std::ranges::join_view we need the -fexperimental-library flag on libc++16, which on Mac requires to manually tinker with the linking flags. - run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build-type}} -DCMAKE_TOOLCHAIN_FILE="$(pwd)/build/conan_toolchain.cmake" -DUSE_PARALLEL=true -DRUN_EXPENSIVE_TESTS=false -DENABLE_EXPENSIVE_CHECKS=true -DCMAKE_CXX_COMPILER=clang++ -DADDITIONAL_COMPILER_FLAGS="-fexperimental-library" -D_NO_TIMING_TESTS=ON -DADDITIONAL_LINKER_FLAGS="-L$(brew --prefix llvm)/lib/c++" + # We currently cannot use the parallel algorithms, as the parallel sort requires a GNU-extension, and we build with `libc++`. + run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build-type}} -DCMAKE_TOOLCHAIN_FILE="$(pwd)/build/conan_toolchain.cmake" -DUSE_PARALLEL=false -DRUN_EXPENSIVE_TESTS=false -DENABLE_EXPENSIVE_CHECKS=true -DCMAKE_CXX_COMPILER=clang++ -DADDITIONAL_COMPILER_FLAGS="-fexperimental-library" -D_NO_TIMING_TESTS=ON -DADDITIONAL_LINKER_FLAGS="-L$(brew --prefix llvm)/lib/c++" - name: Build # Build your program with the given configuration diff --git a/conanfile.txt b/conanfile.txt index 7001b9f64b..def2099676 100644 --- a/conanfile.txt +++ b/conanfile.txt @@ -1,6 +1,6 @@ [requires] boost/1.81.0 -icu/73.1 +icu/76.1 openssl/3.1.1 zstd/1.5.5 # The jemalloc recipe for Conan2 is currently broken, uncomment this line as soon as this is fixed. 
diff --git a/conanprofiles/clang-16-macos b/conanprofiles/clang-16-macos index d1c5ccd365..c1273bacf1 100644 --- a/conanprofiles/clang-16-macos +++ b/conanprofiles/clang-16-macos @@ -6,3 +6,6 @@ compiler.cppstd=gnu17 compiler.libcxx=libc++ compiler.version=16 os=Macos + +[conf] +tools.build:compiler_executables={ "c": "clang", "cpp": "clang++"} diff --git a/test/engine/ExistsJoinTest.cpp b/test/engine/ExistsJoinTest.cpp index e16d9b3ba7..627d3283a4 100644 --- a/test/engine/ExistsJoinTest.cpp +++ b/test/engine/ExistsJoinTest.cpp @@ -32,7 +32,7 @@ void testExistsFromIdTable(IdTable left, IdTable right, // was applied auto permuteColumns = [](auto& table) { auto colsView = ad_utility::integerRange(table.numColumns()); - std::vector permutation; + std::vector permutation; ql::ranges::copy(colsView, std::back_inserter(permutation)); table.setColumnSubset(permutation); return permutation; From cc0e35b4c86a0c61b24bed8a0f99e5e48fdbef16 Mon Sep 17 00:00:00 2001 From: RobinTF <83676088+RobinTF@users.noreply.github.com> Date: Thu, 20 Feb 2025 09:40:09 +0100 Subject: [PATCH 4/7] Implement deep copies for the `Operation` subclasses. (#1815) The deep copies are implemented via explicit `clone` methods. The reason behind this explicit method is two-fold: * Currently the children of an operation are stored as `shared_ptr`s to make the query planning cheaper, so we need explicit logic to recursively deep clone all the children. * We need an explicit `clone` method anyway, because `Operation` is a virtual base class, where we need an explicit `clone` method (the so-called `virtual copy constructor` idiom) to create a copy of a pointer to the virtual base class. 
--- src/engine/Bind.cpp | 5 ++ src/engine/Bind.h | 1 + src/engine/CartesianProductJoin.cpp | 11 +++ src/engine/CartesianProductJoin.h | 2 + src/engine/CountAvailablePredicates.cpp | 7 ++ src/engine/CountAvailablePredicates.h | 21 +++--- src/engine/Describe.cpp | 6 ++ src/engine/Describe.h | 1 + src/engine/Distinct.cpp | 6 ++ src/engine/Distinct.h | 1 + src/engine/ExistsJoin.cpp | 8 +++ src/engine/ExistsJoin.h | 2 + src/engine/ExportQueryExecutionTrees.cpp | 1 + src/engine/ExportQueryExecutionTrees.h | 1 - src/engine/Filter.cpp | 6 ++ src/engine/Filter.h | 3 +- src/engine/GroupBy.cpp | 6 ++ src/engine/GroupBy.h | 3 + src/engine/HasPredicateScan.cpp | 9 +++ src/engine/HasPredicateScan.h | 2 + src/engine/IndexScan.cpp | 14 +++- src/engine/IndexScan.h | 4 +- src/engine/Join.cpp | 8 +++ src/engine/Join.h | 2 + src/engine/Minus.cpp | 8 +++ src/engine/Minus.h | 8 ++- src/engine/MultiColumnJoin.cpp | 8 +++ src/engine/MultiColumnJoin.h | 2 + src/engine/NeutralElementOperation.h | 4 ++ src/engine/Operation.cpp | 28 ++++++++ src/engine/Operation.h | 8 +++ src/engine/OptionalJoin.cpp | 8 +++ src/engine/OptionalJoin.h | 2 + src/engine/OrderBy.cpp | 6 ++ src/engine/OrderBy.h | 2 + src/engine/PathSearch.cpp | 15 ++++ src/engine/PathSearch.h | 2 + src/engine/QueryExecutionTree.h | 6 ++ src/engine/Service.cpp | 6 ++ src/engine/Service.h | 2 + src/engine/Sort.cpp | 6 ++ src/engine/Sort.h | 2 + src/engine/SpatialJoin.cpp | 8 +++ src/engine/SpatialJoin.h | 2 + src/engine/TextIndexScanForEntity.cpp | 5 ++ src/engine/TextIndexScanForEntity.h | 9 ++- src/engine/TextIndexScanForWord.cpp | 5 ++ src/engine/TextIndexScanForWord.h | 9 ++- src/engine/TextLimit.cpp | 11 ++- src/engine/TextLimit.h | 2 + src/engine/TransitivePathBase.h | 8 +++ src/engine/TransitivePathBinSearch.cpp | 9 +++ src/engine/TransitivePathBinSearch.h | 2 + src/engine/TransitivePathHashMap.cpp | 9 +++ src/engine/TransitivePathHashMap.h | 4 +- src/engine/Union.cpp | 9 +++ src/engine/Union.h | 9 +-- src/engine/Values.cpp | 
5 ++ src/engine/Values.h | 6 +- test/EngineTest.cpp | 15 ++++ test/FilterTest.cpp | 23 ++++++ test/GroupByTest.cpp | 21 ++++++ test/HasPredicateScanTest.cpp | 47 +++++++++++++ test/JoinTest.cpp | 19 +++++ test/MinusTest.cpp | 23 ++++++ test/MultiColumnJoinTest.cpp | 29 ++++++++ test/OrderByTest.cpp | 15 ++++ test/PathSearchTest.cpp | 35 ++++++++++ test/ServiceTest.cpp | 23 ++++++ test/SortTest.cpp | 11 +++ test/TextLimitOperationTest.cpp | 13 ++++ test/TransitivePathTest.cpp | 41 +++++++++++ test/UnionTest.cpp | 17 ++++- test/ValuesTest.cpp | 14 ++++ test/engine/BindTest.cpp | 20 ++++++ test/engine/CartesianProductJoinTest.cpp | 17 +++++ test/engine/DescribeTest.cpp | 16 +++++ test/engine/DistinctTest.cpp | 14 ++++ test/engine/ExistsJoinTest.cpp | 21 +++++- test/engine/IndexScanTest.cpp | 30 ++++++++ test/engine/SpatialJoinTest.cpp | 81 ++++++++++++++++++++++ test/engine/TextIndexScanForEntityTest.cpp | 14 ++++ test/engine/TextIndexScanForWordTest.cpp | 13 ++++ test/engine/ValuesForTesting.h | 26 +++++++ test/util/OperationTestHelpers.h | 34 +++++++++ 85 files changed, 958 insertions(+), 39 deletions(-) diff --git a/src/engine/Bind.cpp b/src/engine/Bind.cpp index bdf9132f6a..8e9ad5d77b 100644 --- a/src/engine/Bind.cpp +++ b/src/engine/Bind.cpp @@ -220,3 +220,8 @@ IdTable Bind::computeExpressionBind( return idTable; } + +// _____________________________________________________________________________ +std::unique_ptr Bind::cloneImpl() const { + return std::make_unique(_executionContext, _subtree->clone(), _bind); +} diff --git a/src/engine/Bind.h b/src/engine/Bind.h index 5613f8cd6f..86f271d3ec 100644 --- a/src/engine/Bind.h +++ b/src/engine/Bind.h @@ -33,6 +33,7 @@ class Bind : public Operation { bool supportsLimit() const override; private: + std::unique_ptr cloneImpl() const override; uint64_t getSizeEstimateBeforeLimit() override; public: diff --git a/src/engine/CartesianProductJoin.cpp b/src/engine/CartesianProductJoin.cpp index fdd22d6542..b5c3d69ab7 100644 
--- a/src/engine/CartesianProductJoin.cpp +++ b/src/engine/CartesianProductJoin.cpp @@ -365,3 +365,14 @@ Result::Generator CartesianProductJoin::createLazyConsumer( idTables.pop_back(); } } + +// _____________________________________________________________________________ +std::unique_ptr CartesianProductJoin::cloneImpl() const { + Children copy; + copy.reserve(children_.size()); + for (const auto& operation : children_) { + copy.push_back(operation->clone()); + } + return std::make_unique(_executionContext, + std::move(copy), chunkSize_); +} diff --git a/src/engine/CartesianProductJoin.h b/src/engine/CartesianProductJoin.h index 9988cb72e6..817c620ddb 100644 --- a/src/engine/CartesianProductJoin.h +++ b/src/engine/CartesianProductJoin.h @@ -61,6 +61,8 @@ class CartesianProductJoin : public Operation { VariableToColumnMap computeVariableToColumnMap() const override; + std::unique_ptr cloneImpl() const override; + public: float getMultiplicity([[maybe_unused]] size_t col) override; diff --git a/src/engine/CountAvailablePredicates.cpp b/src/engine/CountAvailablePredicates.cpp index e78fcca694..1c9d0d7e0f 100644 --- a/src/engine/CountAvailablePredicates.cpp +++ b/src/engine/CountAvailablePredicates.cpp @@ -360,3 +360,10 @@ void CountAvailablePredicates::computePatternTrick( runtimeInfo.addDetail("costRatio", costRatio * 100); *dynResult = std::move(result).toDynamic(); } + +// _____________________________________________________________________________ +std::unique_ptr CountAvailablePredicates::cloneImpl() const { + return std::make_unique( + _executionContext, subtree_->clone(), subjectColumnIndex_, + predicateVariable_, countVariable_); +} diff --git a/src/engine/CountAvailablePredicates.h b/src/engine/CountAvailablePredicates.h index 1a804099f8..38d5a4a022 100644 --- a/src/engine/CountAvailablePredicates.h +++ b/src/engine/CountAvailablePredicates.h @@ -5,16 +5,11 @@ #include #include -#include #include -#include "../global/Pattern.h" -#include 
"../parser/ParsedQuery.h" -#include "./Operation.h" -#include "./QueryExecutionTree.h" - -using std::string; -using std::vector; +#include "engine/Operation.h" +#include "engine/QueryExecutionTree.h" +#include "global/Pattern.h" // This Operation takes a Result with at least one column containing ids, // and a column index referring to such a column. It then creates a Result @@ -43,16 +38,16 @@ class CountAvailablePredicates : public Operation { Variable predicateVariable, Variable countVariable); protected: - [[nodiscard]] string getCacheKeyImpl() const override; + [[nodiscard]] std::string getCacheKeyImpl() const override; public: - [[nodiscard]] string getDescriptor() const override; + [[nodiscard]] std::string getDescriptor() const override; [[nodiscard]] size_t getResultWidth() const override; - [[nodiscard]] vector resultSortedOn() const override; + [[nodiscard]] std::vector resultSortedOn() const override; - vector getChildren() override { + std::vector getChildren() override { using R = vector; return subtree_ != nullptr ? 
R{subtree_.get()} : R{}; } @@ -69,6 +64,8 @@ class CountAvailablePredicates : public Operation { private: uint64_t getSizeEstimateBeforeLimit() override; + std::unique_ptr cloneImpl() const override; + public: size_t getCostEstimate() override; diff --git a/src/engine/Describe.cpp b/src/engine/Describe.cpp index a0c43222d2..bb31d00bcb 100644 --- a/src/engine/Describe.cpp +++ b/src/engine/Describe.cpp @@ -244,3 +244,9 @@ ProtoResult Describe::computeResult([[maybe_unused]] bool requestLaziness) { return {std::move(resultTable), resultSortedOn(), std::move(localVocab)}; } + +// _____________________________________________________________________________ +std::unique_ptr Describe::cloneImpl() const { + return std::make_unique(_executionContext, subtree_->clone(), + describe_); +} diff --git a/src/engine/Describe.h b/src/engine/Describe.h index da0fa95bf6..ab566ed286 100644 --- a/src/engine/Describe.h +++ b/src/engine/Describe.h @@ -47,6 +47,7 @@ class Describe : public Operation { bool knownEmptyResult() override; private: + std::unique_ptr cloneImpl() const override; [[nodiscard]] vector resultSortedOn() const override; ProtoResult computeResult(bool requestLaziness) override; VariableToColumnMap computeVariableToColumnMap() const override; diff --git a/src/engine/Distinct.cpp b/src/engine/Distinct.cpp index 244e5d5c7c..a8faad3705 100644 --- a/src/engine/Distinct.cpp +++ b/src/engine/Distinct.cpp @@ -182,3 +182,9 @@ IdTable Distinct::outOfPlaceDistinct(const IdTable& dynInput) const { LOG(DEBUG) << "Distinct done.\n"; return std::move(output).toDynamic(); } + +// _____________________________________________________________________________ +std::unique_ptr Distinct::cloneImpl() const { + return std::make_unique(_executionContext, subtree_->clone(), + keepIndices_); +} diff --git a/src/engine/Distinct.h b/src/engine/Distinct.h index dba3e60b15..8257980b0b 100644 --- a/src/engine/Distinct.h +++ b/src/engine/Distinct.h @@ -52,6 +52,7 @@ class Distinct : public 
Operation { [[nodiscard]] string getCacheKeyImpl() const override; private: + std::unique_ptr cloneImpl() const override; ProtoResult computeResult(bool requestLaziness) override; VariableToColumnMap computeVariableToColumnMap() const override; diff --git a/src/engine/ExistsJoin.cpp b/src/engine/ExistsJoin.cpp index 902f551ddb..c751085916 100644 --- a/src/engine/ExistsJoin.cpp +++ b/src/engine/ExistsJoin.cpp @@ -205,3 +205,11 @@ std::shared_ptr ExistsJoin::addExistsJoinsToSubtree( } return subtree; } + +// _____________________________________________________________________________ +std::unique_ptr ExistsJoin::cloneImpl() const { + auto newJoin = std::make_unique(*this); + newJoin->left_ = left_->clone(); + newJoin->right_ = right_->clone(); + return newJoin; +} diff --git a/src/engine/ExistsJoin.h b/src/engine/ExistsJoin.h index 43dbbe074f..82c0e5a318 100644 --- a/src/engine/ExistsJoin.h +++ b/src/engine/ExistsJoin.h @@ -76,6 +76,8 @@ class ExistsJoin : public Operation { } private: + std::unique_ptr cloneImpl() const override; + ProtoResult computeResult([[maybe_unused]] bool requestLaziness) override; VariableToColumnMap computeVariableToColumnMap() const override; diff --git a/src/engine/ExportQueryExecutionTrees.cpp b/src/engine/ExportQueryExecutionTrees.cpp index f50bca9f61..40349c2d65 100644 --- a/src/engine/ExportQueryExecutionTrees.cpp +++ b/src/engine/ExportQueryExecutionTrees.cpp @@ -14,6 +14,7 @@ #include "parser/RdfEscaping.h" #include "util/ConstexprUtils.h" #include "util/http/MediaTypes.h" +#include "util/json.h" // Return true iff the `result` is nonempty. 
bool getResultForAsk(const std::shared_ptr& result) { diff --git a/src/engine/ExportQueryExecutionTrees.h b/src/engine/ExportQueryExecutionTrees.h index 93eb05a5b4..c50ef2792a 100644 --- a/src/engine/ExportQueryExecutionTrees.h +++ b/src/engine/ExportQueryExecutionTrees.h @@ -10,7 +10,6 @@ #include "parser/data/LimitOffsetClause.h" #include "util/CancellationHandle.h" #include "util/http/MediaTypes.h" -#include "util/json.h" // Class for computing the result of an already parsed and planned query and // exporting it in different formats (TSV, CSV, Turtle, JSON, Binary). diff --git a/src/engine/Filter.cpp b/src/engine/Filter.cpp index da868520b8..9513c475be 100644 --- a/src/engine/Filter.cpp +++ b/src/engine/Filter.cpp @@ -240,3 +240,9 @@ size_t Filter::getCostEstimate() { _subtree->getRootOperation()->getPrimarySortKeyVariable()) .costEstimate; } + +// _____________________________________________________________________________ +std::unique_ptr Filter::cloneImpl() const { + return std::make_unique(_executionContext, _subtree->clone(), + _expression); +} diff --git a/src/engine/Filter.h b/src/engine/Filter.h index 35700c4cb0..c8e55c195b 100644 --- a/src/engine/Filter.h +++ b/src/engine/Filter.h @@ -10,7 +10,6 @@ #include "engine/Operation.h" #include "engine/QueryExecutionTree.h" -#include "parser/ParsedQuery.h" class Filter : public Operation { using PrefilterVariablePair = sparqlExpression::PrefilterExprVariablePair; @@ -55,6 +54,8 @@ class Filter : public Operation { } private: + std::unique_ptr cloneImpl() const override; + VariableToColumnMap computeVariableToColumnMap() const override { return _subtree->getVariableColumns(); } diff --git a/src/engine/GroupBy.cpp b/src/engine/GroupBy.cpp index 28e756d502..1d865c2847 100644 --- a/src/engine/GroupBy.cpp +++ b/src/engine/GroupBy.cpp @@ -1624,3 +1624,9 @@ GroupBy::getVariableForCountOfSingleAlias() const { bool GroupBy::isVariableBoundInSubtree(const Variable& variable) const { return 
_subtree->getVariableColumnOrNullopt(variable).has_value(); } + +// _____________________________________________________________________________ +std::unique_ptr GroupBy::cloneImpl() const { + return std::make_unique(_executionContext, _groupByVariables, + _aliases, _subtree->clone()); +} diff --git a/src/engine/GroupBy.h b/src/engine/GroupBy.h index c5502d32b9..f1c5e44f74 100644 --- a/src/engine/GroupBy.h +++ b/src/engine/GroupBy.h @@ -576,6 +576,9 @@ class GroupBy : public Operation { // GROUP BY. This is used by some of the optimizations above. bool isVariableBoundInSubtree(const Variable& variable) const; + private: + std::unique_ptr cloneImpl() const override; + // TODO implement optimization when *additional* Variables are // grouped. diff --git a/src/engine/HasPredicateScan.cpp b/src/engine/HasPredicateScan.cpp index c27d879021..6b822eef0b 100644 --- a/src/engine/HasPredicateScan.cpp +++ b/src/engine/HasPredicateScan.cpp @@ -393,3 +393,12 @@ const TripleComponent& HasPredicateScan::getObject() const { return object_; } // ___________________________________________________________________________ HasPredicateScan::ScanType HasPredicateScan::getType() const { return type_; } + +// _____________________________________________________________________________ +std::unique_ptr HasPredicateScan::cloneImpl() const { + auto copy = std::make_unique(*this); + if (subtree_.has_value()) { + copy->subtree_.value().subtree_ = subtree().clone(); + } + return copy; +} diff --git a/src/engine/HasPredicateScan.h b/src/engine/HasPredicateScan.h index bfc1249858..6668c6fcdb 100644 --- a/src/engine/HasPredicateScan.h +++ b/src/engine/HasPredicateScan.h @@ -109,6 +109,8 @@ class HasPredicateScan : public Operation { const CompactVectorOfStrings& patterns); private: + std::unique_ptr cloneImpl() const override; + ProtoResult computeResult([[maybe_unused]] bool requestLaziness) override; [[nodiscard]] VariableToColumnMap computeVariableToColumnMap() const override; diff --git 
a/src/engine/IndexScan.cpp b/src/engine/IndexScan.cpp index faa9cda3c5..b07687ca1a 100644 --- a/src/engine/IndexScan.cpp +++ b/src/engine/IndexScan.cpp @@ -7,7 +7,6 @@ #include #include -#include #include #include @@ -660,3 +659,16 @@ std::pair IndexScan::prefilterTables( return {createPrefilteredJoinSide(state), createPrefilteredIndexScanSide(state)}; } + +// _____________________________________________________________________________ +std::unique_ptr IndexScan::cloneImpl() const { + auto prefilter = + prefilter_.has_value() + ? std::optional{std::pair{prefilter_.value().first->clone(), + prefilter_.value().second}} + : std::nullopt; + return std::make_unique(_executionContext, permutation_, subject_, + predicate_, object_, additionalColumns_, + additionalVariables_, graphsToFilter_, + std::move(prefilter)); +} diff --git a/src/engine/IndexScan.h b/src/engine/IndexScan.h index 72d377cfc3..13e1a9955e 100644 --- a/src/engine/IndexScan.h +++ b/src/engine/IndexScan.h @@ -5,7 +5,7 @@ #include -#include "./Operation.h" +#include "engine/Operation.h" #include "util/HashMap.h" class SparqlTriple; @@ -176,6 +176,8 @@ class IndexScan final : public Operation { const CompressedRelationReader::LazyScanMetadata& metadata); private: + std::unique_ptr cloneImpl() const override; + ProtoResult computeResult(bool requestLaziness) override; vector getChildren() override { return {}; } diff --git a/src/engine/Join.cpp b/src/engine/Join.cpp index 512a65beda..66f98f34f3 100644 --- a/src/engine/Join.cpp +++ b/src/engine/Join.cpp @@ -839,3 +839,11 @@ ad_utility::AddCombinedRowToIdTable Join::makeRowAdder( 1, IdTable{getResultWidth(), allocator()}, cancellationHandle_, CHUNK_SIZE, std::move(callback)}; } + +// _____________________________________________________________________________ +std::unique_ptr Join::cloneImpl() const { + auto copy = std::make_unique(*this); + copy->_left = _left->clone(); + copy->_right = _right->clone(); + return copy; +} diff --git a/src/engine/Join.h 
b/src/engine/Join.h index d0627bcfc2..6a90a1f438 100644 --- a/src/engine/Join.h +++ b/src/engine/Join.h @@ -146,6 +146,8 @@ class Join : public Operation { virtual string getCacheKeyImpl() const override; private: + std::unique_ptr cloneImpl() const override; + ProtoResult computeResult(bool requestLaziness) override; VariableToColumnMap computeVariableToColumnMap() const override; diff --git a/src/engine/Minus.cpp b/src/engine/Minus.cpp index 32aa9eca3a..eba7b700c1 100644 --- a/src/engine/Minus.cpp +++ b/src/engine/Minus.cpp @@ -220,3 +220,11 @@ Minus::RowComparison Minus::isRowEqSkipFirst( } return RowComparison::EQUAL; } + +// _____________________________________________________________________________ +std::unique_ptr Minus::cloneImpl() const { + auto copy = std::make_unique(*this); + copy->_left = _left->clone(); + copy->_right = _right->clone(); + return copy; +} diff --git a/src/engine/Minus.h b/src/engine/Minus.h index 92c4a49a2f..4774de1335 100644 --- a/src/engine/Minus.h +++ b/src/engine/Minus.h @@ -6,8 +6,8 @@ #include #include -#include "./Operation.h" -#include "./QueryExecutionTree.h" +#include "engine/Operation.h" +#include "engine/QueryExecutionTree.h" class Minus : public Operation { private: @@ -25,7 +25,7 @@ class Minus : public Operation { // Uninitialized Object for testing the computeMinus method struct OnlyForTestingTag {}; - explicit Minus(OnlyForTestingTag){}; + explicit Minus(OnlyForTestingTag) {} protected: string getCacheKeyImpl() const override; @@ -63,6 +63,8 @@ class Minus : public Operation { IdTable* result) const; private: + std::unique_ptr cloneImpl() const override; + /** * @brief Compares the two rows under the assumption that the first * entries of the rows are equal. 
diff --git a/src/engine/MultiColumnJoin.cpp b/src/engine/MultiColumnJoin.cpp index a831c4cd55..5cc63bf707 100644 --- a/src/engine/MultiColumnJoin.cpp +++ b/src/engine/MultiColumnJoin.cpp @@ -288,3 +288,11 @@ void MultiColumnJoin::computeMultiColumnJoin( result->setColumnSubset(joinColumnData.permutationResult()); checkCancellation(); } + +// _____________________________________________________________________________ +std::unique_ptr MultiColumnJoin::cloneImpl() const { + auto copy = std::make_unique(*this); + copy->_left = _left->clone(); + copy->_right = _right->clone(); + return copy; +} diff --git a/src/engine/MultiColumnJoin.h b/src/engine/MultiColumnJoin.h index ff5e784718..6ea855c721 100644 --- a/src/engine/MultiColumnJoin.h +++ b/src/engine/MultiColumnJoin.h @@ -63,6 +63,8 @@ class MultiColumnJoin : public Operation { IdTable* resultMightBeUnsorted); private: + std::unique_ptr cloneImpl() const override; + ProtoResult computeResult([[maybe_unused]] bool requestLaziness) override; VariableToColumnMap computeVariableToColumnMap() const override; diff --git a/src/engine/NeutralElementOperation.h b/src/engine/NeutralElementOperation.h index a6d88d9b9a..33b8ffa247 100644 --- a/src/engine/NeutralElementOperation.h +++ b/src/engine/NeutralElementOperation.h @@ -34,6 +34,10 @@ class NeutralElementOperation : public Operation { float getMultiplicity(size_t) override { return 0; }; bool knownEmptyResult() override { return false; }; + std::unique_ptr cloneImpl() const override { + return std::make_unique(_executionContext); + } + protected: [[nodiscard]] vector resultSortedOn() const override { return {}; diff --git a/src/engine/Operation.cpp b/src/engine/Operation.cpp index 0f94ff7886..771c3c4b51 100644 --- a/src/engine/Operation.cpp +++ b/src/engine/Operation.cpp @@ -625,3 +625,31 @@ uint64_t Operation::getSizeEstimate() { return getSizeEstimateBeforeLimit(); } } + +// _____________________________________________________________________________ +std::unique_ptr 
Operation::clone() const { + auto result = cloneImpl(); + auto compareTypes = [this, &result]() { + const auto& reference = *result; + return typeid(*this) == typeid(reference); + }; + AD_CORRECTNESS_CHECK(compareTypes()); + AD_CORRECTNESS_CHECK(result->_executionContext == _executionContext); + auto areChildrenDifferent = [this, &result]() { + auto ownChildren = getChildren(); + auto otherChildren = result->getChildren(); + if (ownChildren.size() != otherChildren.size()) { + return false; + } + for (size_t i = 0; i < ownChildren.size(); i++) { + if (ownChildren.at(i) == otherChildren.at(i)) { + return false; + } + } + return true; + }; + AD_CORRECTNESS_CHECK(areChildrenDifferent()); + AD_CORRECTNESS_CHECK(variableToColumnMap_ == result->variableToColumnMap_); + AD_EXPENSIVE_CHECK(getCacheKey() == result->getCacheKey()); + return result; +} diff --git a/src/engine/Operation.h b/src/engine/Operation.h index 1a3f68b83d..5ef65ae0f0 100644 --- a/src/engine/Operation.h +++ b/src/engine/Operation.h @@ -316,6 +316,14 @@ class Operation { const auto& getLimit() const { return _limit; } + private: + // Actual implementation of `clone()` without extra checks. + virtual std::unique_ptr cloneImpl() const = 0; + + public: + // Create a deep copy of this operation. + std::unique_ptr clone() const; + protected: // The QueryExecutionContext for this particular element. // No ownership. 
diff --git a/src/engine/OptionalJoin.cpp b/src/engine/OptionalJoin.cpp index 8f009d963e..8ef98e3254 100644 --- a/src/engine/OptionalJoin.cpp +++ b/src/engine/OptionalJoin.cpp @@ -379,3 +379,11 @@ void OptionalJoin::optionalJoin( result->setColumnSubset(joinColumnData.permutationResult()); checkCancellation(); } + +// _____________________________________________________________________________ +std::unique_ptr OptionalJoin::cloneImpl() const { + auto copy = std::make_unique(*this); + copy->_left = _left->clone(); + copy->_right = _right->clone(); + return copy; +} diff --git a/src/engine/OptionalJoin.h b/src/engine/OptionalJoin.h index 7685bd8708..1771db1f37 100644 --- a/src/engine/OptionalJoin.h +++ b/src/engine/OptionalJoin.h @@ -67,6 +67,8 @@ class OptionalJoin : public Operation { Implementation implementation = Implementation::GeneralCase); private: + std::unique_ptr cloneImpl() const override; + void computeSizeEstimateAndMultiplicities(); ProtoResult computeResult([[maybe_unused]] bool requestLaziness) override; diff --git a/src/engine/OrderBy.cpp b/src/engine/OrderBy.cpp index 5d999e62bc..0fa6656dca 100644 --- a/src/engine/OrderBy.cpp +++ b/src/engine/OrderBy.cpp @@ -139,3 +139,9 @@ OrderBy::SortedVariables OrderBy::getSortedVariables() const { } return result; } + +// _____________________________________________________________________________ +std::unique_ptr OrderBy::cloneImpl() const { + return std::make_unique(_executionContext, subtree_->clone(), + sortIndices_); +} diff --git a/src/engine/OrderBy.h b/src/engine/OrderBy.h index a04d187ce4..c7ed7c27bf 100644 --- a/src/engine/OrderBy.h +++ b/src/engine/OrderBy.h @@ -78,6 +78,8 @@ class OrderBy : public Operation { } private: + std::unique_ptr cloneImpl() const override; + ProtoResult computeResult([[maybe_unused]] bool requestLaziness) override; VariableToColumnMap computeVariableToColumnMap() const override { diff --git a/src/engine/PathSearch.cpp b/src/engine/PathSearch.cpp index 
0a60197341..8acdd2c2fd 100644 --- a/src/engine/PathSearch.cpp +++ b/src/engine/PathSearch.cpp @@ -471,3 +471,18 @@ void PathSearch::pathsToResultTable(IdTable& tableDyn, PathsLimited& paths, tableDyn = std::move(table).toDynamic(); } + +// _____________________________________________________________________________ +std::unique_ptr PathSearch::cloneImpl() const { + auto copy = std::make_unique(*this); + copy->subtree_ = subtree_->clone(); + auto cloneIfNonEmpty = [](auto& tree) { + if (tree.has_value()) { + tree = tree.value()->clone(); + } + }; + cloneIfNonEmpty(copy->sourceTree_); + cloneIfNonEmpty(copy->targetTree_); + cloneIfNonEmpty(copy->sourceAndTargetTree_); + return copy; +} diff --git a/src/engine/PathSearch.h b/src/engine/PathSearch.h index b42f277eb7..e9dc8b537f 100644 --- a/src/engine/PathSearch.h +++ b/src/engine/PathSearch.h @@ -253,6 +253,8 @@ class PathSearch : public Operation { VariableToColumnMap computeVariableToColumnMap() const override; private: + std::unique_ptr cloneImpl() const override; + std::pair, std::span> handleSearchSides() const; /** diff --git a/src/engine/QueryExecutionTree.h b/src/engine/QueryExecutionTree.h index c8bd4f5fc9..398aa02f1d 100644 --- a/src/engine/QueryExecutionTree.h +++ b/src/engine/QueryExecutionTree.h @@ -243,6 +243,12 @@ class QueryExecutionTree { predicate_{std::move(predicate)}, object_{std::move(object)} {} }; + + std::shared_ptr clone() const { + return rootOperation_ ? 
std::make_shared( + qec_, rootOperation_->clone()) + : std::make_shared(qec_); + } }; namespace ad_utility { diff --git a/src/engine/Service.cpp b/src/engine/Service.cpp index a2791055c0..f2ef5a959f 100644 --- a/src/engine/Service.cpp +++ b/src/engine/Service.cpp @@ -621,3 +621,9 @@ void Service::precomputeSiblingResult(std::shared_ptr left, service->siblingInfo_->precomputedResult_; addRuntimeInfo(true); } + +// _____________________________________________________________________________ +std::unique_ptr Service::cloneImpl() const { + return std::make_unique(_executionContext, parsedServiceClause_, + getResultFunction_); +} diff --git a/src/engine/Service.h b/src/engine/Service.h index 8fef6f5d0e..6a914f66c6 100644 --- a/src/engine/Service.h +++ b/src/engine/Service.h @@ -103,6 +103,8 @@ class Service : public Operation { bool rightOnly, bool requestLaziness); private: + std::unique_ptr cloneImpl() const override; + // The string returned by this function is used as cache key. std::string getCacheKeyImpl() const override; diff --git a/src/engine/Sort.cpp b/src/engine/Sort.cpp index f66a8c8bc9..0b6bc67e15 100644 --- a/src/engine/Sort.cpp +++ b/src/engine/Sort.cpp @@ -74,3 +74,9 @@ ProtoResult Sort::computeResult([[maybe_unused]] bool requestLaziness) { LOG(DEBUG) << "Sort result computation done." 
<< endl; return {std::move(idTable), resultSortedOn(), subRes->getSharedLocalVocab()}; } + +// _____________________________________________________________________________ +std::unique_ptr Sort::cloneImpl() const { + return std::make_unique(_executionContext, subtree_->clone(), + sortColumnIndices_); +} diff --git a/src/engine/Sort.h b/src/engine/Sort.h index d94a69c199..5075d79c46 100644 --- a/src/engine/Sort.h +++ b/src/engine/Sort.h @@ -67,6 +67,8 @@ class Sort : public Operation { } private: + std::unique_ptr cloneImpl() const override; + virtual ProtoResult computeResult( [[maybe_unused]] bool requestLaziness) override; diff --git a/src/engine/SpatialJoin.cpp b/src/engine/SpatialJoin.cpp index 767e5d4ac9..95bac69bc8 100644 --- a/src/engine/SpatialJoin.cpp +++ b/src/engine/SpatialJoin.cpp @@ -472,3 +472,11 @@ VariableToColumnMap SpatialJoin::computeVariableToColumnMap() const { return variableToColumnMap; } + +// _____________________________________________________________________________ +std::unique_ptr SpatialJoin::cloneImpl() const { + return std::make_unique( + _executionContext, config_, + childLeft_ ? std::optional{childLeft_->clone()} : std::nullopt, + childRight_ ? std::optional{childRight_->clone()} : std::nullopt); +} diff --git a/src/engine/SpatialJoin.h b/src/engine/SpatialJoin.h index af76c49a81..fc94a51866 100644 --- a/src/engine/SpatialJoin.h +++ b/src/engine/SpatialJoin.h @@ -175,6 +175,8 @@ class SpatialJoin : public Operation { } private: + std::unique_ptr cloneImpl() const override; + // helper function to generate a variable to column map from `childRight_` // that only contains the columns selected by `config_.payloadVariables_` // and (automatically added) the `config_.right_` variable. 
diff --git a/src/engine/TextIndexScanForEntity.cpp b/src/engine/TextIndexScanForEntity.cpp index 276bd4af75..49cf50e616 100644 --- a/src/engine/TextIndexScanForEntity.cpp +++ b/src/engine/TextIndexScanForEntity.cpp @@ -113,3 +113,8 @@ string TextIndexScanForEntity::getCacheKeyImpl() const { << (hasFixedEntity() ? fixedEntity() : "no fixed-entity") << " \""; return std::move(os).str(); } + +// _____________________________________________________________________________ +std::unique_ptr TextIndexScanForEntity::cloneImpl() const { + return std::make_unique(*this); +} diff --git a/src/engine/TextIndexScanForEntity.h b/src/engine/TextIndexScanForEntity.h index f1b11f9018..fd68377c33 100644 --- a/src/engine/TextIndexScanForEntity.h +++ b/src/engine/TextIndexScanForEntity.h @@ -6,7 +6,7 @@ #include -#include "./Operation.h" +#include "engine/Operation.h" // This operation retrieves all text records and their corresponding // entities from the fulltext index that contain a certain word or prefix. 
@@ -84,10 +84,7 @@ class TextIndexScanForEntity : public Operation { uint64_t getSizeEstimateBeforeLimit() override; - float getMultiplicity(size_t col) override { - (void)col; - return 1; - } + float getMultiplicity(size_t) override { return 1; } bool knownEmptyResult() override; @@ -96,6 +93,8 @@ class TextIndexScanForEntity : public Operation { VariableToColumnMap computeVariableToColumnMap() const override; private: + std::unique_ptr cloneImpl() const override; + const VocabIndex& getVocabIndexOfFixedEntity() const { AD_CONTRACT_CHECK(hasFixedEntity()); return std::get(varOrFixed_.entity_).second; diff --git a/src/engine/TextIndexScanForWord.cpp b/src/engine/TextIndexScanForWord.cpp index 7c3f931f8f..01f73aaf2a 100644 --- a/src/engine/TextIndexScanForWord.cpp +++ b/src/engine/TextIndexScanForWord.cpp @@ -81,3 +81,8 @@ string TextIndexScanForWord::getCacheKeyImpl() const { << " with word: \"" << word_ << "\""; return std::move(os).str(); } + +// _____________________________________________________________________________ +std::unique_ptr TextIndexScanForWord::cloneImpl() const { + return std::make_unique(*this); +} diff --git a/src/engine/TextIndexScanForWord.h b/src/engine/TextIndexScanForWord.h index 3628e27d9c..cfe5dff415 100644 --- a/src/engine/TextIndexScanForWord.h +++ b/src/engine/TextIndexScanForWord.h @@ -6,7 +6,7 @@ #include -#include "./Operation.h" +#include "engine/Operation.h" // This operation retrieves all text records from the fulltext index that // contain a certain word or prefix. 
@@ -36,10 +36,7 @@ class TextIndexScanForWord : public Operation { uint64_t getSizeEstimateBeforeLimit() override; - float getMultiplicity(size_t col) override { - (void)col; - return 1; - } + float getMultiplicity(size_t) override { return 1; } bool knownEmptyResult() override { return getSizeEstimateBeforeLimit() == 0; } @@ -48,6 +45,8 @@ class TextIndexScanForWord : public Operation { VariableToColumnMap computeVariableToColumnMap() const override; private: + std::unique_ptr cloneImpl() const override; + // Returns a Result containing an IdTable with the columns being // the text variable and the completed word (if it was prefixed) ProtoResult computeResult([[maybe_unused]] bool requestLaziness) override; diff --git a/src/engine/TextLimit.cpp b/src/engine/TextLimit.cpp index 6ec7b7868f..75bcd7724a 100644 --- a/src/engine/TextLimit.cpp +++ b/src/engine/TextLimit.cpp @@ -15,7 +15,9 @@ TextLimit::TextLimit(QueryExecutionContext* qec, const size_t limit, child_(std::move(child)), textRecordColumn_(textRecordColumn), entityColumns_(entityColumns), - scoreColumns_(scoreColumns) {} + scoreColumns_(scoreColumns) { + AD_CONTRACT_CHECK(child_); +} // _____________________________________________________________________________ ProtoResult TextLimit::computeResult([[maybe_unused]] bool requestLaziness) { @@ -188,3 +190,10 @@ string TextLimit::getCacheKeyImpl() const { os << "}"; return std::move(os).str(); } + +// _____________________________________________________________________________ +std::unique_ptr TextLimit::cloneImpl() const { + return std::make_unique(_executionContext, limit_, child_->clone(), + textRecordColumn_, entityColumns_, + scoreColumns_); +} diff --git a/src/engine/TextLimit.h b/src/engine/TextLimit.h index cbda207f5d..552e19923a 100644 --- a/src/engine/TextLimit.h +++ b/src/engine/TextLimit.h @@ -62,6 +62,8 @@ class TextLimit : public Operation { VariableToColumnMap computeVariableToColumnMap() const override; private: + std::unique_ptr cloneImpl() 
const override; + ProtoResult computeResult([[maybe_unused]] bool requestLaziness) override; vector getChildren() override { return {child_.get()}; } diff --git a/src/engine/TransitivePathBase.h b/src/engine/TransitivePathBase.h index 71607eed5e..633d465cf8 100644 --- a/src/engine/TransitivePathBase.h +++ b/src/engine/TransitivePathBase.h @@ -55,6 +55,14 @@ struct TransitivePathSide { // TODO use ql::ranges::starts_with return (!sortedOn.empty() && sortedOn[0] == col); } + + TransitivePathSide clone() const { + TransitivePathSide copy = *this; + if (copy.treeAndCol_.has_value()) { + copy.treeAndCol_.value().first = copy.treeAndCol_.value().first->clone(); + } + return copy; + } }; // We deliberately use the `std::` variants of a hash set and hash map because diff --git a/src/engine/TransitivePathBinSearch.cpp b/src/engine/TransitivePathBinSearch.cpp index d68b420407..970f5ee940 100644 --- a/src/engine/TransitivePathBinSearch.cpp +++ b/src/engine/TransitivePathBinSearch.cpp @@ -31,3 +31,12 @@ BinSearchMap TransitivePathBinSearch::setupEdgesMap( return BinSearchMap{dynSub.getColumn(startSide.subCol_), dynSub.getColumn(targetSide.subCol_)}; } + +// _____________________________________________________________________________ +std::unique_ptr TransitivePathBinSearch::cloneImpl() const { + auto copy = std::make_unique(*this); + copy->subtree_ = subtree_->clone(); + copy->lhs_ = lhs_.clone(); + copy->rhs_ = rhs_.clone(); + return copy; +} diff --git a/src/engine/TransitivePathBinSearch.h b/src/engine/TransitivePathBinSearch.h index 4973d9da5e..faeda642ce 100644 --- a/src/engine/TransitivePathBinSearch.h +++ b/src/engine/TransitivePathBinSearch.h @@ -70,6 +70,8 @@ class TransitivePathBinSearch : public TransitivePathImpl { size_t maxDist); private: + std::unique_ptr cloneImpl() const override; + // initialize the map from the subresult BinSearchMap setupEdgesMap( const IdTable& dynSub, const TransitivePathSide& startSide, diff --git a/src/engine/TransitivePathHashMap.cpp 
b/src/engine/TransitivePathHashMap.cpp index a48bcc39a8..7d03f751e8 100644 --- a/src/engine/TransitivePathHashMap.cpp +++ b/src/engine/TransitivePathHashMap.cpp @@ -44,3 +44,12 @@ HashMapWrapper TransitivePathHashMap::setupEdgesMap( } return HashMapWrapper{std::move(edges), allocator()}; } + +// _____________________________________________________________________________ +std::unique_ptr TransitivePathHashMap::cloneImpl() const { + auto copy = std::make_unique(*this); + copy->subtree_ = subtree_->clone(); + copy->lhs_ = lhs_.clone(); + copy->rhs_ = rhs_.clone(); + return copy; +} diff --git a/src/engine/TransitivePathHashMap.h b/src/engine/TransitivePathHashMap.h index 3bda2c2117..0cae09e5bd 100644 --- a/src/engine/TransitivePathHashMap.h +++ b/src/engine/TransitivePathHashMap.h @@ -23,7 +23,7 @@ struct HashMapWrapper { Set emptySet_; HashMapWrapper(Map map, ad_utility::AllocatorWithLimit allocator) - : map_(std::move(map)), emptySet_(allocator){}; + : map_(std::move(map)), emptySet_(allocator) {} /** * @brief Return the successors for the given Id. The successors are all ids, @@ -59,6 +59,8 @@ class TransitivePathHashMap : public TransitivePathImpl { size_t maxDist); private: + std::unique_ptr cloneImpl() const override; + /** * @brief Prepare a Map and a nodes vector for the transitive hull * computation. 
diff --git a/src/engine/Union.cpp b/src/engine/Union.cpp index ec1afcba05..fa5a07427b 100644 --- a/src/engine/Union.cpp +++ b/src/engine/Union.cpp @@ -283,3 +283,12 @@ Result::Generator Union::computeResultLazily( } } } + +// _____________________________________________________________________________ +std::unique_ptr Union::cloneImpl() const { + auto copy = std::make_unique(*this); + for (auto& subtree : copy->_subtrees) { + subtree = subtree->clone(); + } + return copy; +} diff --git a/src/engine/Union.h b/src/engine/Union.h index e71702315f..3d87664380 100644 --- a/src/engine/Union.h +++ b/src/engine/Union.h @@ -9,10 +9,9 @@ #include #include -#include "../parser/ParsedQuery.h" -#include "../util/HashMap.h" -#include "Operation.h" -#include "QueryExecutionTree.h" +#include "engine/Operation.h" +#include "engine/QueryExecutionTree.h" +#include "util/HashMap.h" class Union : public Operation { private: @@ -63,6 +62,8 @@ class Union : public Operation { } private: + std::unique_ptr cloneImpl() const override; + ProtoResult computeResult(bool requestLaziness) override; VariableToColumnMap computeVariableToColumnMap() const override; diff --git a/src/engine/Values.cpp b/src/engine/Values.cpp index 7a4535ceed..aae997b162 100644 --- a/src/engine/Values.cpp +++ b/src/engine/Values.cpp @@ -147,3 +147,8 @@ void Values::writeValues(IdTable* idTablePtr, LocalVocab* localVocab) { << absl::StrJoin(numLocalVocabPerColumn, ", ") << std::endl; *idTablePtr = std::move(idTable).toDynamic(); } + +// _____________________________________________________________________________ +std::unique_ptr Values::cloneImpl() const { + return std::make_unique(*this); +} diff --git a/src/engine/Values.h b/src/engine/Values.h index 71d25b7a56..f490fa4993 100644 --- a/src/engine/Values.h +++ b/src/engine/Values.h @@ -6,8 +6,8 @@ #pragma once -#include "../parser/ParsedQuery.h" -#include "Operation.h" +#include "engine/Operation.h" +#include "parser/ParsedQuery.h" class Values : public Operation { 
using SparqlValues = parsedQuery::SparqlValues; @@ -54,6 +54,8 @@ class Values : public Operation { VariableToColumnMap computeVariableToColumnMap() const override; private: + std::unique_ptr cloneImpl() const override; + // Compute the per-column multiplicity of the parsed values. void computeMultiplicities(); diff --git a/test/EngineTest.cpp b/test/EngineTest.cpp index 860fa5521e..65910f7a09 100644 --- a/test/EngineTest.cpp +++ b/test/EngineTest.cpp @@ -19,6 +19,7 @@ #include "util/IdTableHelpers.h" #include "util/IdTestHelpers.h" #include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" using ad_utility::testing::makeAllocator; using namespace ad_utility::testing; @@ -368,3 +369,17 @@ TEST(Engine, countDistinct) { ::testing::HasSubstr("must be sorted")); } } + +// _____________________________________________________________________________ +TEST(OptionalJoin, clone) { + auto qec = ad_utility::testing::getQec(); + auto a = makeIdTableFromVector({{0}}); + auto left = idTableToExecutionTree(qec, a); + auto right = idTableToExecutionTree(qec, a); + OptionalJoin opt{qec, left, right}; + + auto clone = opt.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(opt, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), opt.getDescriptor()); +} diff --git a/test/FilterTest.cpp b/test/FilterTest.cpp index 027c5bbbbf..b7886275a1 100644 --- a/test/FilterTest.cpp +++ b/test/FilterTest.cpp @@ -13,6 +13,7 @@ #include "engine/sparqlExpressions/SparqlExpression.h" #include "util/IdTableHelpers.h" #include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" using ::testing::ElementsAre; using ::testing::Eq; @@ -223,3 +224,25 @@ TEST(Filter, lazyChildMaterializedResultBinaryFilter) { EXPECT_EQ(result->idTable(), makeIdTableFromVector({{5}, {6}, {7}, {8}, {8}}, I)); } + +// _____________________________________________________________________________ +TEST(Filter, clone) { + using namespace makeSparqlExpression; + QueryExecutionContext* qec = 
ad_utility::testing::getQec(); + std::vector idTables; + auto I = ad_utility::testing::IntId; + idTables.push_back(makeIdTableFromVector({{1}}, I)); + + ValuesForTesting values{ + qec, std::move(idTables), {Variable{"?x"}}, false, {0}}; + QueryExecutionTree subTree{ + qec, std::make_shared(std::move(values))}; + Filter filter{qec, + std::make_shared(std::move(subTree)), + {ltSprql(Variable{"?x"}, I(5)), "!?x < 5"}}; + + auto clone = filter.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(filter, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), filter.getDescriptor()); +} diff --git a/test/GroupByTest.cpp b/test/GroupByTest.cpp index 8f89f9ff11..0f949c23ee 100644 --- a/test/GroupByTest.cpp +++ b/test/GroupByTest.cpp @@ -30,6 +30,7 @@ #include "index/ConstantsIndexBuilding.h" #include "parser/SparqlParser.h" #include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" using namespace ad_utility::testing; using ::testing::Eq; @@ -116,6 +117,26 @@ TEST_F(GroupByTest, getDescriptor) { ASSERT_EQ(groupBy.getDescriptor(), "GroupBy on ?a"); } +// _____________________________________________________________________________ +TEST_F(GroupByTest, clone) { + auto expr = + std::make_unique(Variable{"?a"}); + auto alias = + Alias{sparqlExpression::SparqlExpressionPimpl{std::move(expr), "?a"}, + Variable{"?a"}}; + + parsedQuery::SparqlValues input; + input._variables = {Variable{"?a"}}; + auto values = ad_utility::makeExecutionTree(getQec(), input); + + GroupBy groupBy{getQec(), {Variable{"?a"}}, {alias}, values}; + + auto clone = groupBy.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(groupBy, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), groupBy.getDescriptor()); +} + TEST_F(GroupByTest, doGroupBy) { using std::string; using std::vector; diff --git a/test/HasPredicateScanTest.cpp b/test/HasPredicateScanTest.cpp index ca7f00fcbc..c0228bf265 100644 --- a/test/HasPredicateScanTest.cpp +++ b/test/HasPredicateScanTest.cpp @@ -16,6 +16,7 @@ #include 
"engine/IndexScan.h" #include "engine/ValuesForTesting.h" #include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" namespace { using ad_utility::testing::makeAllocator; @@ -79,6 +80,37 @@ TEST_F(HasPredicateScanTest, freeO) { Variable{"?p"}}}; runTest(scan, {{p}, {p2}}); } +// _____________________________________________________________ +TEST_F(HasPredicateScanTest, clone) { + { + HasPredicateScan scan{ + qec, SparqlTriple{Variable{"?x"}, std::string{HAS_PREDICATE_PREDICATE}, + iri("

")}}; + + auto clone = scan.clone(); + ASSERT_TRUE(clone); + const auto& cloneReference = *clone; + EXPECT_EQ(typeid(scan), typeid(cloneReference)); + EXPECT_EQ(cloneReference.getDescriptor(), scan.getDescriptor()); + + EXPECT_EQ(scan.getChildren().empty(), cloneReference.getChildren().empty()); + } + { + HasPredicateScan scan{qec, + ad_utility::makeExecutionTree( + qec, makeIdTableFromVector({{0}}), + std::vector>{{V{"?p"}}}), + 0, V{"?x"}}; + + auto clone = scan.clone(); + ASSERT_TRUE(clone); + const auto& cloneReference = *clone; + EXPECT_EQ(typeid(scan), typeid(cloneReference)); + EXPECT_EQ(cloneReference.getDescriptor(), scan.getDescriptor()); + + EXPECT_NE(scan.getChildren().at(0), cloneReference.getChildren().at(0)); + } +} // _____________________________________________________________ TEST_F(HasPredicateScanTest, fullScan) { @@ -138,6 +170,21 @@ TEST_F(HasPredicateScanTest, patternTrickWithSubtree) { runTestUnordered(patternTrick, {{p3, Int(2)}, {p, Int(1)}}); } +// ____________________________________________________________ +TEST_F(HasPredicateScanTest, cloneCountAvailablePredicates) { + auto triple = SparqlTriple{V{"?x"}, "", V{"?y"}}; + triple.additionalScanColumns_.emplace_back( + ADDITIONAL_COLUMN_INDEX_SUBJECT_PATTERN, V{"?predicate"}); + auto indexScan = ad_utility::makeExecutionTree( + qec, Permutation::Enum::PSO, triple); + CountAvailablePredicates patternTrick{qec, indexScan, 1, V{"?predicate"}, + V{"?count"}}; + + auto clone = patternTrick.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(patternTrick, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), patternTrick.getDescriptor()); +} // ____________________________________________________________ TEST_F(HasPredicateScanTest, patternTrickWithSubtreeTwoFixedElements) { diff --git a/test/JoinTest.cpp b/test/JoinTest.cpp index 6ab090214f..e72029c814 100644 --- a/test/JoinTest.cpp +++ b/test/JoinTest.cpp @@ -764,3 +764,22 @@ TEST(JoinTest, verifyColumnPermutationsAreAppliedCorrectly) { 
testJoinOperation(join, expectedColumns, false); } } + +// _____________________________________________________________________________ +TEST(JoinTest, clone) { + auto qec = ad_utility::testing::getQec(); + auto leftTree = ad_utility::makeExecutionTree( + qec, makeIdTableFromVector({{I(1), I(1), I(1)}}), + Vars{Variable{"?t"}, Variable{"?s"}, Variable{"?u"}}, false, + std::vector{1}); + auto rightTree = ad_utility::makeExecutionTree( + qec, makeIdTableFromVector({{I(1), I(1), I(1)}}), + Vars{Variable{"?v"}, Variable{"?w"}, Variable{"?s"}}, false, + std::vector{2}); + Join join{qec, leftTree, rightTree, 1, 2}; + + auto clone = join.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(join, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), join.getDescriptor()); +} diff --git a/test/MinusTest.cpp b/test/MinusTest.cpp index a27a23c092..95cc44bef4 100644 --- a/test/MinusTest.cpp +++ b/test/MinusTest.cpp @@ -10,7 +10,11 @@ #include "./util/IdTestHelpers.h" #include "engine/CallFixedSize.h" #include "engine/Minus.h" +#include "engine/ValuesForTesting.h" #include "util/AllocatorTestHelpers.h" +#include "util/IdTableHelpers.h" +#include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" namespace { auto table(size_t cols) { @@ -100,3 +104,22 @@ TEST(EngineTest, minusTest) { ASSERT_EQ(wantedRes[0], vres[0]); } + +// _____________________________________________________________________________ +TEST(Minus, clone) { + auto* qec = ad_utility::testing::getQec(); + Minus minus{ + qec, + ad_utility::makeExecutionTree( + qec, makeIdTableFromVector({{0, 1}}), + std::vector>{Variable{"?x"}, Variable{"?y"}}), + ad_utility::makeExecutionTree( + qec, makeIdTableFromVector({{0, 1}}), + std::vector>{Variable{"?x"}, + Variable{"?z"}})}; + + auto clone = minus.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(minus, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), minus.getDescriptor()); +} diff --git a/test/MultiColumnJoinTest.cpp b/test/MultiColumnJoinTest.cpp index 
f4b138f23c..6186b780d0 100644 --- a/test/MultiColumnJoinTest.cpp +++ b/test/MultiColumnJoinTest.cpp @@ -2,6 +2,7 @@ // Chair of Algorithms and Data Structures. // Author: Florian Kramer (florian.kramer@netpun.uni-freiburg.de) +#include #include #include @@ -12,6 +13,7 @@ #include "util/IdTableHelpers.h" #include "util/IdTestHelpers.h" #include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" using ad_utility::testing::makeAllocator; namespace { @@ -76,3 +78,30 @@ TEST(EngineTest, multiColumnJoinTest) { ASSERT_EQ(wantedRes[2], vres[2]); ASSERT_EQ(wantedRes[3], vres[3]); } + +// _____________________________________________________________________________ +TEST(MultiColumnJoin, clone) { + auto* qec = ad_utility::testing::getQec(); + IdTable a = makeIdTableFromVector({{4, 1, 2}}); + MultiColumnJoin join{qec, idTableToExecutionTree(qec, a), + idTableToExecutionTree(qec, a)}; + + auto clone = join.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(join, IsDeepCopy(*clone)); + + std::string_view prefix = "MultiColumnJoin on "; + EXPECT_THAT(join.getDescriptor(), ::testing::StartsWith(prefix)); + EXPECT_THAT(clone->getDescriptor(), ::testing::StartsWith(prefix)); + // Order of join columns is not deterministic. 
+ auto getVars = [prefix](std::string_view string) { + string.remove_prefix(prefix.length()); + std::vector vars; + for (const auto& split : absl::StrSplit(string, ' ', absl::SkipEmpty())) { + vars.emplace_back(split); + } + ql::ranges::sort(vars); + return vars; + }; + EXPECT_EQ(getVars(clone->getDescriptor()), getVars(join.getDescriptor())); +} diff --git a/test/OrderByTest.cpp b/test/OrderByTest.cpp index 046f55d1f6..28a28648c4 100644 --- a/test/OrderByTest.cpp +++ b/test/OrderByTest.cpp @@ -11,6 +11,7 @@ #include "engine/ValuesForTesting.h" #include "global/ValueIdComparators.h" #include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" using namespace std::string_literals; using namespace std::chrono_literals; @@ -257,3 +258,17 @@ TEST(OrderBy, verifyOperationIsPreemptivelyAbortedWithNoRemainingTime) { orderBy.getResult(true), ::testing::HasSubstr("time estimate exceeded"), ad_utility::CancellationException); } + +// _____________________________________________________________________________ +TEST(OrderBy, clone) { + auto* qec = ad_utility::testing::getQec(); + IdTable permutedInput{2, qec->getAllocator()}; + + OrderBy orderBy = + makeOrderBy(permutedInput.clone(), OrderBy::SortIndices{{0, true}}); + + auto clone = orderBy.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(orderBy, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), orderBy.getDescriptor()); +} diff --git a/test/PathSearchTest.cpp b/test/PathSearchTest.cpp index 30ca2b42cf..03a168b798 100644 --- a/test/PathSearchTest.cpp +++ b/test/PathSearchTest.cpp @@ -12,6 +12,7 @@ #include "util/IdTableHelpers.h" #include "util/IdTestHelpers.h" #include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" using ad_utility::testing::getQec; namespace { @@ -753,3 +754,37 @@ TEST(PathSearchTest, sourceAndTargetBound) { ASSERT_THAT(resultTable.idTable(), ::testing::UnorderedElementsAreArray(expected)); } + +// 
_____________________________________________________________________________ +TEST(PathSearchTest, clone) { + auto sub = makeIdTableFromVector({{0, 1}}); + + Vars vars = {Variable{"?start"}, Variable{"?end"}}; + PathSearchConfiguration config{PathSearchAlgorithm::ALL_PATHS, + Var{"?source"}, + Var{"?target"}, + Var{"?start"}, + Var{"?end"}, + Var{"?edgeIndex"}, + Var{"?pathIndex"}, + {}}; + + auto qec = getQec(); + auto subtree = ad_utility::makeExecutionTree( + qec, std::move(sub), vars); + PathSearch pathSearch{qec, subtree, std::move(config)}; + + auto clone = pathSearch.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(pathSearch, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), pathSearch.getDescriptor()); + + pathSearch.bindSourceSide(subtree, 0); + pathSearch.bindTargetSide(subtree, 0); + pathSearch.bindSourceAndTargetSide(subtree, 0, 0); + + clone = pathSearch.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(pathSearch, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), pathSearch.getDescriptor()); +} diff --git a/test/ServiceTest.cpp b/test/ServiceTest.cpp index 7f53202dbb..696dd0f811 100644 --- a/test/ServiceTest.cpp +++ b/test/ServiceTest.cpp @@ -22,6 +22,7 @@ #include "util/GTestHelpers.h" #include "util/IdTableHelpers.h" #include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" #include "util/TripleComponentTestHelpers.h" #include "util/http/HttpUtils.h" @@ -795,3 +796,25 @@ TEST_F(ServiceTest, precomputeSiblingResult) { ->idTables()) { } } + +// ____________________________________________________________________________ +TEST_F(ServiceTest, clone) { + Service service{ + testQec, + parsedQuery::Service{ + {Variable{"?x"}, Variable{"?y"}}, + TripleComponent::Iri::fromIriref(""), + "PREFIX doof: ", + "{ }", + true}, + getResultFunctionFactory( + "http://localhorst:80/api", + "PREFIX doof: SELECT ?x ?y WHERE { }", + genJsonResult({"x", "y"}, {{"a", "b"}}), + boost::beast::http::status::ok, "application/sparql-results+json")}; + + 
auto clone = service.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(service, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), service.getDescriptor()); +} diff --git a/test/SortTest.cpp b/test/SortTest.cpp index ab51138815..b1d4ab3961 100644 --- a/test/SortTest.cpp +++ b/test/SortTest.cpp @@ -10,6 +10,7 @@ #include "engine/ValuesForTesting.h" #include "global/ValueIdComparators.h" #include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" using namespace std::string_literals; using namespace std::chrono_literals; @@ -196,3 +197,13 @@ TEST(Sort, verifyOperationIsPreemptivelyAbortedWithNoRemainingTime) { sort.getResult(true), ::testing::HasSubstr("time estimate exceeded"), ad_utility::CancellationException); } + +// _____________________________________________________________________________ +TEST(Sort, clone) { + Sort sort = makeSort(makeIdTableFromVector({{0, 0}}), {0}); + + auto clone = sort.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(sort, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), sort.getDescriptor()); +} diff --git a/test/TextLimitOperationTest.cpp b/test/TextLimitOperationTest.cpp index ccf5d09088..ee52d25f44 100644 --- a/test/TextLimitOperationTest.cpp +++ b/test/TextLimitOperationTest.cpp @@ -11,6 +11,7 @@ #include "engine/TextLimit.h" #include "engine/ValuesForTesting.h" #include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" namespace { TextLimit makeTextLimit(IdTable input, const size_t& n, @@ -502,3 +503,15 @@ TEST(TextLimit, CacheKey) { // The input is different. 
ASSERT_NE(textLimit1.getCacheKey(), textLimit7.getCacheKey()); } + +// _____________________________________________________________________________ +TEST(TextLimit, clone) { + VectorTable input{{1, 2, 3}, {1, 2, 3}, {1, 2, 3}}; + IdTable inputTable = makeIdTableFromVector(input, &Id::makeFromInt); + TextLimit textLimit = makeTextLimit(inputTable.clone(), 4, 0, {1}, {2}); + + auto clone = textLimit.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(static_cast(textLimit), IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), textLimit.getDescriptor()); +} diff --git a/test/TransitivePathTest.cpp b/test/TransitivePathTest.cpp index 2e6da1855b..3c4879e3b8 100644 --- a/test/TransitivePathTest.cpp +++ b/test/TransitivePathTest.cpp @@ -16,6 +16,7 @@ #include "util/GTestHelpers.h" #include "util/IdTableHelpers.h" #include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" using ad_utility::testing::getQec; namespace { @@ -651,6 +652,46 @@ TEST_P(TransitivePathTest, zeroLengthException) { "not supported")); } +// _____________________________________________________________________________ +TEST_P(TransitivePathTest, clone) { + auto sub = makeIdTableFromVector({{0, 2}}); + + TransitivePathSide left(std::nullopt, 0, Variable{"?start"}, 0); + TransitivePathSide right(std::nullopt, 1, Variable{"?target"}, 1); + { + auto transitivePath = + makePathUnbound(sub.clone(), {Variable{"?start"}, Variable{"?target"}}, + left, right, 0, std::numeric_limits::max()); + + auto clone = transitivePath->clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(*transitivePath, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), transitivePath->getDescriptor()); + } + { + auto transitivePath = makePathBound( + false, std::move(sub), {Variable{"?start"}, Variable{"?target"}}, + sub.clone(), 0, {Variable{"?start"}, Variable{"?other"}}, left, right, + 0, std::numeric_limits::max()); + + auto clone = transitivePath->clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(*transitivePath, 
IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), transitivePath->getDescriptor()); + } + { + auto transitivePath = makePathBound( + true, std::move(sub), {Variable{"?start"}, Variable{"?target"}}, + sub.clone(), 0, {Variable{"?target"}, Variable{"?other"}}, left, right, + 0, std::numeric_limits::max()); + + auto clone = transitivePath->clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(*transitivePath, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), transitivePath->getDescriptor()); + } +} + // _____________________________________________________________________________ INSTANTIATE_TEST_SUITE_P( TransitivePathTestSuite, TransitivePathTest, diff --git a/test/UnionTest.cpp b/test/UnionTest.cpp index 265aa92ec8..5a2036a9a1 100644 --- a/test/UnionTest.cpp +++ b/test/UnionTest.cpp @@ -4,15 +4,16 @@ #include -#include #include #include "./engine/ValuesForTesting.h" #include "./util/IdTableHelpers.h" #include "./util/IdTestHelpers.h" +#include "engine/NeutralElementOperation.h" #include "engine/Union.h" #include "global/Id.h" #include "util/IndexTestHelpers.h" +#include "util/OperationTestHelpers.h" namespace { auto V = ad_utility::testing::VocabId; @@ -181,3 +182,17 @@ TEST(Union, ensurePermutationIsAppliedCorrectly) { EXPECT_EQ(resultTable.idTable(), expected); } } + +// _____________________________________________________________________________ +TEST(Union, clone) { + auto* qec = ad_utility::testing::getQec(); + + Union unionOperation{ + qec, ad_utility::makeExecutionTree(qec), + ad_utility::makeExecutionTree(qec)}; + + auto clone = unionOperation.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(unionOperation, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), unionOperation.getDescriptor()); +} diff --git a/test/ValuesTest.cpp b/test/ValuesTest.cpp index 901f06d332..58c41f2e6a 100644 --- a/test/ValuesTest.cpp +++ b/test/ValuesTest.cpp @@ -13,6 +13,7 @@ #include "engine/Values.h" #include "engine/idTable/IdTable.h" #include "util/IndexTestHelpers.h" 
+#include "util/OperationTestHelpers.h" using TC = TripleComponent; using ValuesComponents = std::vector>; @@ -90,3 +91,16 @@ TEST(Values, illegalInput) { ValuesComponents values{{TC{12}, TC{""}}, {TC::UNDEF{}}}; ASSERT_ANY_THROW(Values(qec, {{Variable{"?x"}, Variable{"?y"}}, values})); } + +// _____________________________________________________________________________ +TEST(Values, clone) { + auto testQec = ad_utility::testing::getQec(" ."); + ValuesComponents values{{TC{12}, TC{iri("")}}, + {TC::UNDEF{}, TC{iri("")}}}; + Values valuesOperation(testQec, {{Variable{"?x"}, Variable{"?y"}}, values}); + + auto clone = valuesOperation.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(valuesOperation, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), valuesOperation.getDescriptor()); +} diff --git a/test/engine/BindTest.cpp b/test/engine/BindTest.cpp index 34ef0eb370..a51c57aa77 100644 --- a/test/engine/BindTest.cpp +++ b/test/engine/BindTest.cpp @@ -6,6 +6,7 @@ #include "../util/IdTableHelpers.h" #include "../util/IndexTestHelpers.h" +#include "../util/OperationTestHelpers.h" #include "./ValuesForTesting.h" #include "engine/Bind.h" #include "engine/sparqlExpressions/LiteralExpression.h" @@ -131,3 +132,22 @@ TEST( EXPECT_EQ(++iterator, idTables.end()); } } + +// _____________________________________________________________________________ +TEST(Bind, clone) { + auto* qec = ad_utility::testing::getQec(); + auto valuesTree = ad_utility::makeExecutionTree( + qec, IdTable{1, qec->getAllocator()}, Vars{Variable{"?a"}}, false, + std::vector{}, LocalVocab{}, std::nullopt, true); + Bind bind{ + qec, + std::move(valuesTree), + {SparqlExpressionPimpl{ + std::make_unique(Id::makeFromInt(42)), "42 as ?b"}, + Variable{"?b"}}}; + + auto clone = bind.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(bind, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), bind.getDescriptor()); +} diff --git a/test/engine/CartesianProductJoinTest.cpp b/test/engine/CartesianProductJoinTest.cpp 
index 8727aa223a..91f7809c09 100644 --- a/test/engine/CartesianProductJoinTest.cpp +++ b/test/engine/CartesianProductJoinTest.cpp @@ -8,6 +8,7 @@ #include "../util/GTestHelpers.h" #include "../util/IdTableHelpers.h" #include "../util/IndexTestHelpers.h" +#include "../util/OperationTestHelpers.h" #include "engine/CartesianProductJoin.h" #include "engine/QueryExecutionTree.h" @@ -629,3 +630,19 @@ INSTANTIATE_TEST_SUITE_P( } return std::move(stream).str(); }); + +// _____________________________________________________________________________ +TEST(CartesianProductJoin, clone) { + auto qec = getQec(); + std::vector> subtrees; + using Vars = std::vector>; + subtrees.push_back(ad_utility::makeExecutionTree( + qec, makeIdTableFromVector({{3, 4}}), + Vars{Variable{"?x"}, std::nullopt})); + CartesianProductJoin join{qec, std::move(subtrees)}; + + auto clone = join.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(join, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), join.getDescriptor()); +} diff --git a/test/engine/DescribeTest.cpp b/test/engine/DescribeTest.cpp index 76618a0f40..8bde2093e4 100644 --- a/test/engine/DescribeTest.cpp +++ b/test/engine/DescribeTest.cpp @@ -6,6 +6,7 @@ #include "../util/GTestHelpers.h" #include "../util/IndexTestHelpers.h" +#include "../util/OperationTestHelpers.h" #include "engine/Describe.h" #include "engine/IndexScan.h" #include "engine/NeutralElementOperation.h" @@ -178,3 +179,18 @@ TEST(Describe, simpleMembers) { EXPECT_THAT(children.at(0)->getRootOperation()->getDescriptor(), Eq("NeutralElement")); } + +// _____________________________________________________________________________ +TEST(Describe, clone) { + auto qec = getQec(); + parsedQuery::Describe parsedDescribe; + parsedDescribe.resources_.push_back(TripleComponent::Iri::fromIriref("")); + Describe describe{qec, + ad_utility::makeExecutionTree(qec), + parsedDescribe}; + + auto clone = describe.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(describe, IsDeepCopy(*clone)); + 
EXPECT_EQ(clone->getDescriptor(), describe.getDescriptor()); +} diff --git a/test/engine/DistinctTest.cpp b/test/engine/DistinctTest.cpp index 5ce219b60f..dbf623572c 100644 --- a/test/engine/DistinctTest.cpp +++ b/test/engine/DistinctTest.cpp @@ -6,6 +6,7 @@ #include "../util/IdTableHelpers.h" #include "../util/IndexTestHelpers.h" +#include "../util/OperationTestHelpers.h" #include "engine/Distinct.h" #include "engine/NeutralElementOperation.h" @@ -223,3 +224,16 @@ TEST(Distinct, lazyWithLazyInputs) { {{6, 7, 0, 6}, {2, 7, 1, 5}, {3, 7, 2, 4}, {1, 7, 3, 1}})), m(makeIdTableFromVector({{6, 7, 4, 6}})))); } + +// _____________________________________________________________________________ +TEST(Distinct, clone) { + auto qec = ad_utility::testing::getQec(); + Distinct distinct{ad_utility::testing::getQec(), + ad_utility::makeExecutionTree(qec), + std::vector{0, 1}}; + + auto clone = distinct.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(distinct, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), distinct.getDescriptor()); +} diff --git a/test/engine/ExistsJoinTest.cpp b/test/engine/ExistsJoinTest.cpp index 627d3283a4..93fc5fb2c8 100644 --- a/test/engine/ExistsJoinTest.cpp +++ b/test/engine/ExistsJoinTest.cpp @@ -7,9 +7,9 @@ #include "../util/GTestHelpers.h" #include "../util/IdTableHelpers.h" #include "../util/IndexTestHelpers.h" +#include "../util/OperationTestHelpers.h" #include "engine/ExistsJoin.h" #include "engine/IndexScan.h" -#include "engine/NeutralElementOperation.h" #include "engine/QueryExecutionTree.h" using namespace ad_utility::testing; @@ -133,3 +133,22 @@ TEST(Exists, computeResult) { testExistsFromIdTable(makeIdTableFromVector({{U, U}, {3, 7}}), IdTable(2, alloc), {false, false}, 2); } + +// _____________________________________________________________________________ +TEST(Exists, clone) { + auto* qec = getQec(); + ExistsJoin existsJoin{ + qec, + ad_utility::makeExecutionTree( + qec, makeIdTableFromVector({{0, 1}}), + 
std::vector>{Variable{"?x"}, Variable{"?y"}}), + ad_utility::makeExecutionTree( + qec, makeIdTableFromVector({{0, 1}}), + std::vector>{Variable{"?x"}, Variable{"?y"}}), + Variable{"?z"}}; + + auto clone = existsJoin.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(existsJoin, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), existsJoin.getDescriptor()); +} diff --git a/test/engine/IndexScanTest.cpp b/test/engine/IndexScanTest.cpp index 395a21261a..b99d32078e 100644 --- a/test/engine/IndexScanTest.cpp +++ b/test/engine/IndexScanTest.cpp @@ -1082,3 +1082,33 @@ TEST(IndexScan, prefilterTablesWithEmptyIndexScanReturnsEmptyGenerators) { EXPECT_EQ(leftGenerator.begin(), leftGenerator.end()); EXPECT_EQ(rightGenerator.begin(), rightGenerator.end()); } +// _____________________________________________________________________________ +TEST(IndexScan, clone) { + auto* qec = getQec(); + { + SparqlTriple xpy{Tc{Var{"?x"}}, "", Tc{Var{"?y"}}}; + IndexScan scan{qec, Permutation::PSO, xpy}; + + auto clone = scan.clone(); + ASSERT_TRUE(clone); + const auto& cloneReference = *clone; + EXPECT_EQ(typeid(scan), typeid(cloneReference)); + EXPECT_EQ(cloneReference.getDescriptor(), scan.getDescriptor()); + } + { + using namespace makeFilterExpression; + SparqlTriple xpy{Tc{Var{"?x"}}, "", Tc{Var{"?y"}}}; + IndexScan scan{ + qec, + Permutation::PSO, + xpy, + std::nullopt, + {{filterHelper::pr(ge(IntId(10)), Variable{"?price"}).first, 0}}}; + + auto clone = scan.clone(); + ASSERT_TRUE(clone); + const auto& cloneReference = *clone; + EXPECT_EQ(typeid(scan), typeid(cloneReference)); + EXPECT_EQ(cloneReference.getDescriptor(), scan.getDescriptor()); + } +} diff --git a/test/engine/SpatialJoinTest.cpp b/test/engine/SpatialJoinTest.cpp index 766801ee3b..3b6ad521d6 100644 --- a/test/engine/SpatialJoinTest.cpp +++ b/test/engine/SpatialJoinTest.cpp @@ -706,6 +706,87 @@ TEST(SpatialJoin, getCacheKeyImpl) { ASSERT_TRUE(cacheKeyString.find(rightCacheKeyString) != std::string::npos); } +// 
_____________________________________________________________________________ +TEST(SpatialJoin, clone) { + auto qec = buildTestQEC(); + auto numTriples = qec->getIndex().numTriples().normal; + ASSERT_EQ(numTriples, 15); + auto leftChild = + buildIndexScan(qec, {"?obj1", std::string{""}, "?point1"}); + auto rightChild = + buildIndexScan(qec, {"?obj2", std::string{""}, "?point2"}); + + { + SpatialJoin spatialJoin{ + qec, + SpatialJoinConfiguration{MaxDistanceConfig{1000}, Variable{"?point1"}, + Variable{"?point2"}}, + std::nullopt, std::nullopt}; + + auto clone = spatialJoin.clone(); + ASSERT_TRUE(clone); + const auto& cloneReference = *clone; + EXPECT_EQ(typeid(spatialJoin), typeid(cloneReference)); + EXPECT_EQ(cloneReference.getDescriptor(), spatialJoin.getDescriptor()); + + EXPECT_EQ(spatialJoin.getChildren().empty(), + cloneReference.getChildren().empty()); + } + + { + SpatialJoin spatialJoin{ + qec, + SpatialJoinConfiguration{MaxDistanceConfig{1000}, Variable{"?point1"}, + Variable{"?point2"}}, + leftChild, std::nullopt}; + + auto clone = spatialJoin.clone(); + ASSERT_TRUE(clone); + const auto& cloneReference = *clone; + EXPECT_EQ(typeid(spatialJoin), typeid(cloneReference)); + EXPECT_EQ(cloneReference.getDescriptor(), spatialJoin.getDescriptor()); + + EXPECT_NE(spatialJoin.getChildren().at(0), + cloneReference.getChildren().at(0)); + } + + { + SpatialJoin spatialJoin{ + qec, + SpatialJoinConfiguration{MaxDistanceConfig{1000}, Variable{"?point1"}, + Variable{"?point2"}}, + std::nullopt, rightChild}; + + auto clone = spatialJoin.clone(); + ASSERT_TRUE(clone); + const auto& cloneReference = *clone; + EXPECT_EQ(typeid(spatialJoin), typeid(cloneReference)); + EXPECT_EQ(cloneReference.getDescriptor(), spatialJoin.getDescriptor()); + + EXPECT_NE(spatialJoin.getChildren().at(0), + cloneReference.getChildren().at(0)); + } + + { + SpatialJoin spatialJoin{ + qec, + SpatialJoinConfiguration{MaxDistanceConfig{1000}, Variable{"?point1"}, + Variable{"?point2"}}, + leftChild, 
rightChild}; + + auto clone = spatialJoin.clone(); + ASSERT_TRUE(clone); + const auto& cloneReference = *clone; + EXPECT_EQ(typeid(spatialJoin), typeid(cloneReference)); + EXPECT_EQ(cloneReference.getDescriptor(), spatialJoin.getDescriptor()); + + EXPECT_NE(spatialJoin.getChildren().at(0), + cloneReference.getChildren().at(0)); + EXPECT_NE(spatialJoin.getChildren().at(1), + cloneReference.getChildren().at(1)); + } +} + } // namespace stringRepresentation namespace getMultiplicityAndSizeEstimate { diff --git a/test/engine/TextIndexScanForEntityTest.cpp b/test/engine/TextIndexScanForEntityTest.cpp index f2e8bcd611..9263ff8e39 100644 --- a/test/engine/TextIndexScanForEntityTest.cpp +++ b/test/engine/TextIndexScanForEntityTest.cpp @@ -7,6 +7,7 @@ #include "../util/GTestHelpers.h" #include "../util/IdTableHelpers.h" #include "../util/IndexTestHelpers.h" +#include "../util/OperationTestHelpers.h" #include "./TextIndexScanTestHelpers.h" #include "engine/IndexScan.h" #include "engine/TextIndexScanForEntity.h" @@ -152,4 +153,17 @@ TEST(TextIndexScanForEntity, KnownEmpty) { ASSERT_TRUE(!s3.knownEmptyResult()); } +// _____________________________________________________________________________ +TEST(TextIndexScanForEntity, clone) { + auto qec = getQec(); + + TextIndexScanForEntity scan{qec, Variable{"?text"}, Variable{"?entityVar"}, + "nonExistentWord*"}; + + auto clone = scan.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(scan, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), scan.getDescriptor()); +} + } // namespace diff --git a/test/engine/TextIndexScanForWordTest.cpp b/test/engine/TextIndexScanForWordTest.cpp index cc9b685ec8..c6e3fe6e87 100644 --- a/test/engine/TextIndexScanForWordTest.cpp +++ b/test/engine/TextIndexScanForWordTest.cpp @@ -10,6 +10,7 @@ #include "../util/GTestHelpers.h" #include "../util/IdTableHelpers.h" #include "../util/IndexTestHelpers.h" +#include "../util/OperationTestHelpers.h" #include "./TextIndexScanTestHelpers.h" #include 
"engine/IndexScan.h" #include "engine/TextIndexScanForWord.h" @@ -256,4 +257,16 @@ TEST(TextIndexScanForWord, KnownEmpty) { TextIndexScanForWord s5{qec, Variable{"?text1"}, "testing"}; ASSERT_TRUE(!s5.knownEmptyResult()); } + +// _____________________________________________________________________________ +TEST(TextIndexScanForWord, clone) { + auto qec = getQec(); + + TextIndexScanForWord scan{qec, Variable{"?text1"}, "nonExistentWord*"}; + + auto clone = scan.clone(); + ASSERT_TRUE(clone); + EXPECT_THAT(scan, IsDeepCopy(*clone)); + EXPECT_EQ(clone->getDescriptor(), scan.getDescriptor()); +} } // namespace diff --git a/test/engine/ValuesForTesting.h b/test/engine/ValuesForTesting.h index ec9b55cd30..7ffb22cebe 100644 --- a/test/engine/ValuesForTesting.h +++ b/test/engine/ValuesForTesting.h @@ -75,6 +75,9 @@ class ValuesForTesting : public Operation { costEstimate_ = totalRows; } + ValuesForTesting(ValuesForTesting&&) = default; + ValuesForTesting& operator=(ValuesForTesting&&) = default; + // Accessors for the estimates for manual testing. 
size_t& sizeEstimate() { return sizeEstimate_; } size_t& costEstimate() { return costEstimate_; } @@ -210,6 +213,29 @@ class ValuesForTesting : public Operation { return m; } + // _____________________________________________________________________________ + ValuesForTesting(const ValuesForTesting& other) + : Operation{other._executionContext}, + variables_{other.variables_}, + supportsLimit_{other.supportsLimit_}, + sizeEstimate_{other.sizeEstimate_}, + costEstimate_{other.costEstimate_}, + unlikelyToFitInCache_{other.unlikelyToFitInCache_}, + resultSortedColumns_{other.resultSortedColumns_}, + localVocab_{other.localVocab_.clone()}, + multiplicity_{other.multiplicity_}, + forceFullyMaterialized_{other.forceFullyMaterialized_} { + for (const auto& idTable : other.tables_) { + tables_.push_back(idTable.clone()); + } + } + + ValuesForTesting& operator=(const ValuesForTesting&) = delete; + + std::unique_ptr cloneImpl() const override { + return std::make_unique(ValuesForTesting{*this}); + } + std::vector resultSortedColumns_; LocalVocab localVocab_; std::optional multiplicity_; diff --git a/test/util/OperationTestHelpers.h b/test/util/OperationTestHelpers.h index 14c455960c..f0832ed6cc 100644 --- a/test/util/OperationTestHelpers.h +++ b/test/util/OperationTestHelpers.h @@ -43,6 +43,11 @@ class StallForeverOperation : public Operation { std::chrono::milliseconds publicRemainingTime() const { return remainingTime(); } + + // _____________________________________________________________________________ + std::unique_ptr cloneImpl() const override { + AD_THROW("Clone not implemented"); + } }; // _____________________________________________________________________________ @@ -83,6 +88,11 @@ class ShallowParentOperation : public Operation { std::chrono::milliseconds publicRemainingTime() const { return remainingTime(); } + + // _____________________________________________________________________________ + std::unique_ptr cloneImpl() const override { + AD_THROW("Clone 
not implemented"); + } }; // Operation that will throw on `computeResult` for testing. @@ -127,6 +137,11 @@ class AlwaysFailOperation : public Operation { }(), resultSortedOn()}; } + + // _____________________________________________________________________________ + std::unique_ptr cloneImpl() const override { + AD_THROW("Clone not implemented"); + } }; // Lazy operation that will yield a result with a custom generator you can @@ -154,6 +169,25 @@ class CustomGeneratorOperation : public Operation { AD_CONTRACT_CHECK(requestLaziness); return {std::move(generator_), resultSortedOn()}; } + + // _____________________________________________________________________________ + std::unique_ptr cloneImpl() const override { + AD_THROW("Clone not implemented"); + } }; +MATCHER_P(SameTypeId, ptr, "has the same type id") { + return typeid(*arg) == typeid(*ptr); +} + +inline auto IsDeepCopy(const Operation& other) { + using namespace ::testing; + return AllOf( + Address(SameTypeId(&other)), + AD_PROPERTY(Operation, getChildren, Pointwise(Ne(), other.getChildren())), + AD_PROPERTY(Operation, getCacheKey, Eq(other.getCacheKey())), + AD_PROPERTY(Operation, getExternallyVisibleVariableColumns, + Eq(other.getExternallyVisibleVariableColumns()))); +} + #endif // QLEVER_OPERATIONTESTHELPERS_H From 36cdbd4b07d781987af17822e623067d28d41500 Mon Sep 17 00:00:00 2001 From: Hannah Bast Date: Fri, 21 Feb 2025 03:18:10 +0100 Subject: [PATCH 5/7] Revert "Better error message on parallel turtle parsing ... 
(#1807)" (#1827) This reverts commit 8678731edf09615c066f682b53d9660784388529, which breaks the index build, see https://github.com/ad-freiburg/qlever-control/issues/139 --- src/engine/GraphStoreProtocol.cpp | 1 - src/index/IndexImpl.cpp | 2 - src/parser/RdfEscaping.h | 3 ++ src/parser/RdfParser.cpp | 62 +++++++++---------------------- src/parser/RdfParser.h | 57 ++++++++++++++-------------- src/parser/Tokenizer.h | 7 ++-- src/parser/TokenizerCtre.h | 2 - src/util/TaskQueue.h | 5 --- test/DeltaTriplesTest.cpp | 3 +- test/IndexTest.cpp | 2 +- test/RdfParserTest.cpp | 45 +--------------------- 11 files changed, 55 insertions(+), 134 deletions(-) diff --git a/src/engine/GraphStoreProtocol.cpp b/src/engine/GraphStoreProtocol.cpp index a708f5e948..fc46ec6fa0 100644 --- a/src/engine/GraphStoreProtocol.cpp +++ b/src/engine/GraphStoreProtocol.cpp @@ -4,7 +4,6 @@ #include "engine/GraphStoreProtocol.h" -#include "parser/Tokenizer.h" #include "util/http/beast.h" // ____________________________________________________________________________ diff --git a/src/index/IndexImpl.cpp b/src/index/IndexImpl.cpp index 3b2cc58d1a..1b0c762256 100644 --- a/src/index/IndexImpl.cpp +++ b/src/index/IndexImpl.cpp @@ -21,8 +21,6 @@ #include "index/IndexFormatVersion.h" #include "index/VocabularyMerger.h" #include "parser/ParallelParseBuffer.h" -#include "parser/Tokenizer.h" -#include "parser/TokenizerCtre.h" #include "util/BatchedPipeline.h" #include "util/CachingMemoryResource.h" #include "util/HashMap.h" diff --git a/src/parser/RdfEscaping.h b/src/parser/RdfEscaping.h index 29d69aa2b4..36bbcd74f7 100644 --- a/src/parser/RdfEscaping.h +++ b/src/parser/RdfEscaping.h @@ -5,12 +5,15 @@ #ifndef QLEVER_RDFESCAPING_H #define QLEVER_RDFESCAPING_H +#include + #include #include #include "global/TypedIndex.h" #include "parser/NormalizedString.h" #include "util/Exception.h" +#include "util/HashSet.h" #include "util/StringUtils.h" namespace RdfEscaping { diff --git a/src/parser/RdfParser.cpp 
b/src/parser/RdfParser.cpp index fdca43e3e1..c100a4b1dc 100644 --- a/src/parser/RdfParser.cpp +++ b/src/parser/RdfParser.cpp @@ -15,11 +15,10 @@ #include "global/Constants.h" #include "parser/GeoPoint.h" #include "parser/NormalizedString.h" -#include "parser/Tokenizer.h" -#include "parser/TokenizerCtre.h" +#include "parser/RdfEscaping.h" +#include "util/Conversions.h" #include "util/DateYearDuration.h" #include "util/OnDestructionDontThrowDuringStackUnwinding.h" -#include "util/TransparentFunctors.h" using namespace std::chrono_literals; // _______________________________________________________________ @@ -32,17 +31,7 @@ bool TurtleParser::statement() { // ______________________________________________________________ template bool TurtleParser::directive() { - bool successfulParse = prefixID() || base() || sparqlPrefix() || sparqlBase(); - if (successfulParse && prefixAndBaseDisabled_) { - raise( - "@prefix or @base directives need to be at the beginning of the file " - "when using the parallel parser. Use '--parse-parallel false' if you " - "can't guarantee this. If the reason for this error is that the input " - "is a concatenation of Turtle files, each of which has the prefixes at " - "the beginning, you should feed the files to QLever separately instead " - "of concatenated"); - } - return successfulParse; + return prefixID() || base() || sparqlPrefix() || sparqlBase(); } // ________________________________________________________________ @@ -641,7 +630,7 @@ bool TurtleParser::iri() { // _____________________________________________________________________ template bool TurtleParser::prefixedName() { - if constexpr (T::UseRelaxedParsing) { + if constexpr (UseRelaxedParsing) { if (!(pnameLnRelaxed() || pnameNS())) { return false; } @@ -756,7 +745,7 @@ bool TurtleParser::iriref() { // In relaxed mode, that is all we check. Otherwise, we check if the IRI is // standard-compliant. If not, we output a warning and try to parse it in a // more relaxed way. 
- if constexpr (T::UseRelaxedParsing) { + if constexpr (UseRelaxedParsing) { tok_.remove_prefix(endPos + 1); lastParseResult_ = TripleComponent::Iri::fromIrirefConsiderBase( view.substr(0, endPos + 1), baseForRelativeIri(), baseForAbsoluteIri()); @@ -959,20 +948,20 @@ bool RdfStreamParser::getLineImpl(TurtleTriple* triple) { // `parallelParser_` have been fully processed. After the last batch we will // push another call to this lambda to the `parallelParser_` which will then // finish the `tripleCollector_` as soon as all batches have been computed. -template -void RdfParallelParser::finishTripleCollectorIfLastBatch() { +template +void RdfParallelParser::finishTripleCollectorIfLastBatch() { if (batchIdx_.fetch_add(1) == numBatchesTotal_) { tripleCollector_.finish(); } } // __________________________________________________________________________________ -template -void RdfParallelParser::parseBatch(size_t parsePosition, auto batch) { +template +void RdfParallelParser::parseBatch(size_t parsePosition, + auto batch) { try { - RdfStringParser parser{defaultGraphIri_}; + RdfStringParser parser{defaultGraphIri_}; parser.prefixMap_ = this->prefixMap_; - parser.disablePrefixParsing(); parser.setPositionOffset(parsePosition); parser.setInputStream(std::move(batch)); // TODO: raise error message if a prefix parsing fails; @@ -983,15 +972,14 @@ void RdfParallelParser::parseBatch(size_t parsePosition, auto batch) { }); finishTripleCollectorIfLastBatch(); } catch (std::exception& e) { - errorMessages_.wlock()->emplace_back(parsePosition, e.what()); tripleCollector_.pushException(std::current_exception()); parallelParser_.finish(); } }; // _______________________________________________________________________ -template -void RdfParallelParser::feedBatchesToParser( +template +void RdfParallelParser::feedBatchesToParser( auto remainingBatchFromInitialization) { bool first = true; size_t parsePosition = 0; @@ -1031,15 +1019,14 @@ void RdfParallelParser::feedBatchesToParser( } } 
} catch (std::exception& e) { - errorMessages_.wlock()->emplace_back(parsePosition, e.what()); tripleCollector_.pushException(std::current_exception()); } }; // _______________________________________________________________________ -template -void RdfParallelParser::initialize(const string& filename, - ad_utility::MemorySize bufferSize) { +template +void RdfParallelParser::initialize( + const string& filename, ad_utility::MemorySize bufferSize) { fileBuffer_ = std::make_unique( bufferSize.getBytes(), "\\.[\\t ]*([\\r\\n]+)"); ParallelBuffer::BufferType remainingBatchFromInitialization; @@ -1048,7 +1035,7 @@ void RdfParallelParser::initialize(const string& filename, LOG(WARN) << "Empty input to the TURTLE parser, is this what you intended?" << std::endl; } else { - RdfStringParser declarationParser{}; + RdfStringParser declarationParser{}; declarationParser.setInputStream(std::move(batch.value())); while (declarationParser.parseDirectiveManually()) { } @@ -1075,20 +1062,7 @@ bool RdfParallelParser::getLineImpl(TurtleTriple* triple) { // contains no triples. (Theoretically this might happen, and it is safer this // way) while (triples_.empty()) { - auto optionalTripleTask = [&]() { - try { - return tripleCollector_.pop(); - } catch (const std::exception&) { - // In case of multiple errors in parallel batches, we always report the - // first error. 
- parallelParser_.waitUntilFinished(); - auto errors = std::move(*errorMessages_.wlock()); - const auto& firstError = - ql::ranges::min_element(errors, {}, ad_utility::first); - AD_CORRECTNESS_CHECK(firstError != errors.end()); - throw std::runtime_error{firstError->second}; - } - }(); + auto optionalTripleTask = tripleCollector_.pop(); if (!optionalTripleTask) { // Everything has been parsed return false; diff --git a/src/parser/RdfParser.h b/src/parser/RdfParser.h index d8aced692e..1fb36871c9 100644 --- a/src/parser/RdfParser.h +++ b/src/parser/RdfParser.h @@ -4,29 +4,35 @@ #pragma once -#include #include +#include +#include +#include #include #include -#include #include +#include "absl/strings/str_cat.h" #include "global/Constants.h" #include "global/SpecialIds.h" #include "index/ConstantsIndexBuilding.h" #include "index/InputFileSpecification.h" #include "parser/ParallelBuffer.h" +#include "parser/Tokenizer.h" +#include "parser/TokenizerCtre.h" #include "parser/TripleComponent.h" -#include "parser/TurtleTokenId.h" #include "parser/data/BlankNode.h" #include "util/Exception.h" +#include "util/File.h" #include "util/HashMap.h" #include "util/Log.h" #include "util/ParseException.h" #include "util/TaskQueue.h" #include "util/ThreadSafeQueue.h" +using std::string; + enum class TurtleParserIntegerOverflowBehavior { Error, OverflowingToDouble, @@ -120,6 +126,10 @@ class TurtleParser : public RdfParserBase { public: using ParseException = ::ParseException; + // The CTRE Tokenizer implies relaxed parsing. + static constexpr bool UseRelaxedParsing = + std::is_same_v; + // Get the result of the single rule that was parsed most recently. Used for // testing. const TripleComponent& getLastParseResult() const { return lastParseResult_; } @@ -194,10 +204,10 @@ class TurtleParser : public RdfParserBase { // Getters for the two base prefixes. Without BASE declaration, these will // both return the empty IRI. 
- const TripleComponent::Iri& baseForRelativeIri() const { + const TripleComponent::Iri& baseForRelativeIri() { return prefixMap_.at(baseForRelativeIriKey_); } - const TripleComponent::Iri& baseForAbsoluteIri() const { + const TripleComponent::Iri& baseForAbsoluteIri() { return prefixMap_.at(baseForAbsoluteIriKey_); } @@ -216,8 +226,6 @@ class TurtleParser : public RdfParserBase { static inline std::atomic numParsers_ = 0; size_t blankNodePrefix_ = numParsers_.fetch_add(1); - bool prefixAndBaseDisabled_ = false; - public: TurtleParser() = default; explicit TurtleParser(TripleComponent defaultGraphIri) @@ -392,7 +400,7 @@ class TurtleParser : public RdfParserBase { } // create a new, unused, unique blank node string - std::string createAnonNode() { + string createAnonNode() { return BlankNode{true, absl::StrCat(blankNodePrefix_, "_", numBlankNodes_++)} .toSparql(); @@ -471,7 +479,9 @@ CPP_template(typename Parser)( return positionOffset_ + tmpToParse_.size() - this->tok_.data().size(); } - void initialize(const std::string&, ad_utility::MemorySize) { + void initialize(const string& filename, ad_utility::MemorySize bufferSize) { + (void)filename; + (void)bufferSize; throw std::runtime_error( "RdfStringParser doesn't support calls to initialize. Only use " "parseUtf8String() for unit tests\n"); @@ -524,7 +534,7 @@ CPP_template(typename Parser)( // testing interface for reusing a parser // only specifies the tokenizers input stream. // Does not alter the tokenizers state - void setInputStream(const std::string& toParse) { + void setInputStream(const string& toParse) { tmpToParse_.clear(); tmpToParse_.reserve(toParse.size()); tmpToParse_.insert(tmpToParse_.end(), toParse.begin(), toParse.end()); @@ -545,9 +555,6 @@ CPP_template(typename Parser)( // as expected size_t getPosition() const { return this->tok_.begin() - tmpToParse_.data(); } - // Disable prefix parsing for turtle parsers during parallel parsing. 
- void disablePrefixParsing() { this->prefixAndBaseDisabled_ = true; } - FRIEND_TEST(RdfParserTest, prefixedName); FRIEND_TEST(RdfParserTest, prefixID); FRIEND_TEST(RdfParserTest, stringParse); @@ -583,7 +590,7 @@ class RdfStreamParser : public Parser { // Default construction needed for tests RdfStreamParser() = default; explicit RdfStreamParser( - const std::string& filename, + const string& filename, ad_utility::MemorySize bufferSize = DEFAULT_PARSER_BUFFER_SIZE, TripleComponent defaultGraphIri = qlever::specialIds().at(DEFAULT_GRAPH_IRI)) @@ -595,8 +602,7 @@ class RdfStreamParser : public Parser { bool getLineImpl(TurtleTriple* triple) override; - void initialize(const std::string& filename, - ad_utility::MemorySize bufferSize); + void initialize(const string& filename, ad_utility::MemorySize bufferSize); size_t getParsePosition() const override { return numBytesBeforeCurrentBatch_ + (tok_.data().data() - byteVec_.data()); @@ -638,7 +644,7 @@ class RdfStreamParser : public Parser { template class RdfParallelParser : public Parser { public: - using Triple = std::array; + using Triple = std::array; // Default construction needed for tests RdfParallelParser() = default; @@ -646,7 +652,7 @@ class RdfParallelParser : public Parser { // parser will sleep for the specified time before parsing each batch s.t. // certain corner cases can be tested. explicit RdfParallelParser( - const std::string& filename, + const string& filename, ad_utility::MemorySize bufferSize = DEFAULT_PARSER_BUFFER_SIZE, std::chrono::milliseconds sleepTimeForTesting = std::chrono::milliseconds{0}) @@ -659,8 +665,7 @@ class RdfParallelParser : public Parser { } // Construct a parser from a file and a given default graph iri. 
- RdfParallelParser(const std::string& filename, - ad_utility::MemorySize bufferSize, + RdfParallelParser(const string& filename, ad_utility::MemorySize bufferSize, const TripleComponent& defaultGraphIri) : Parser{defaultGraphIri}, defaultGraphIri_{defaultGraphIri} { initialize(filename, bufferSize); @@ -678,8 +683,7 @@ class RdfParallelParser : public Parser { parallelParser_.resetTimers(); } - void initialize(const std::string& filename, - ad_utility::MemorySize bufferSize); + void initialize(const string& filename, ad_utility::MemorySize bufferSize); size_t getParsePosition() const override { // TODO: can we really define this position here? @@ -716,12 +720,6 @@ class RdfParallelParser : public Parser { QUEUE_SIZE_BEFORE_PARALLEL_PARSING, NUM_PARALLEL_PARSER_THREADS, "parallel parser"}; std::future parseFuture_; - - // Collect error messages in case of multiple failures. The `size_t` is the - // start position of the corresponding batch, used to order the errors in case - // the batches are finished out of order. - ad_utility::Synchronized>> - errorMessages_; // The parallel parsers need to know when the last batch has been parsed, s.t. // the parser threads can be destroyed. The following two members are needed // for keeping track of this condition. @@ -781,8 +779,7 @@ class RdfMultifileParser : public RdfParserBase { // `parsingQueue_` is declared *after* the `finishedBatchQueue_`, s.t. when // destroying the parser, the threads from the `parsingQueue_` are all joined // before the `finishedBatchQueue_` (which they are using!) is destroyed. - ad_utility::TaskQueue parsingQueue_{QUEUE_SIZE_BEFORE_PARALLEL_PARSING, - NUM_PARALLEL_PARSER_THREADS}; + ad_utility::TaskQueue parsingQueue_{10, NUM_PARALLEL_PARSER_THREADS}; // The number of parsers that have started, but not yet finished. This is // needed to detect the complete parsing. 
diff --git a/src/parser/Tokenizer.h b/src/parser/Tokenizer.h index 6d971ccba8..a8dc50d0ac 100644 --- a/src/parser/Tokenizer.h +++ b/src/parser/Tokenizer.h @@ -7,7 +7,10 @@ #include #include +#include + #include "parser/TurtleTokenId.h" +#include "util/Exception.h" #include "util/Log.h" using re2::RE2; @@ -237,7 +240,7 @@ struct SkipWhitespaceAndCommentsMixin { auto v = self().view(); if (v.starts_with('#')) { auto pos = v.find('\n'); - if (pos == std::string::npos) { + if (pos == string::npos) { // TODO: This should rather yield an error. LOG(INFO) << "Warning, unfinished comment found while parsing" << std::endl; @@ -270,8 +273,6 @@ class Tokenizer : public SkipWhitespaceAndCommentsMixin { Tokenizer(std::string_view input) : _tokens(), _data(input.data(), input.size()) {} - static constexpr bool UseRelaxedParsing = false; - // if a prefix of the input stream matches the regex argument, // return true and that prefix and move the input stream forward // by the length of the match. If no match is found, diff --git a/src/parser/TokenizerCtre.h b/src/parser/TokenizerCtre.h index cd1f81cbe5..28c2e48731 100644 --- a/src/parser/TokenizerCtre.h +++ b/src/parser/TokenizerCtre.h @@ -154,8 +154,6 @@ class TokenizerCtre : public SkipWhitespaceAndCommentsMixin { */ explicit TokenizerCtre(std::string_view data) : _data(data) {} - static constexpr bool UseRelaxedParsing = true; - /// iterator to the next character that we have not yet consumed [[nodiscard]] auto begin() const { return _data.begin(); } diff --git a/src/util/TaskQueue.h b/src/util/TaskQueue.h index caf1147649..c18b0a4e8d 100644 --- a/src/util/TaskQueue.h +++ b/src/util/TaskQueue.h @@ -120,11 +120,6 @@ class TaskQueue { std::to_string(popTime_) + "ms (pop)"; } - // Block the current thread until `finish()` on the queue has been called and - // successfully completed. This function may NOT be called from inside a queue - // thread, otherwise there will be a deadlock. 
- void waitUntilFinished() const { finishedFinishing_.wait(false); } - ~TaskQueue() { if (startedFinishing_.test_and_set()) { // Someone has already called `finish`, we have to wait for the finishing diff --git a/test/DeltaTriplesTest.cpp b/test/DeltaTriplesTest.cpp index eb906be6d8..3b4564f483 100644 --- a/test/DeltaTriplesTest.cpp +++ b/test/DeltaTriplesTest.cpp @@ -4,18 +4,17 @@ // 2023 Hannah Bast // 2024 Julian Mundhahs -#include #include #include "./DeltaTriplesTestHelpers.h" #include "./util/GTestHelpers.h" #include "./util/IndexTestHelpers.h" +#include "absl/strings/str_split.h" #include "engine/ExportQueryExecutionTrees.h" #include "index/DeltaTriples.h" #include "index/IndexImpl.h" #include "index/Permutation.h" #include "parser/RdfParser.h" -#include "parser/Tokenizer.h" using namespace deltaTriplesTestHelpers; diff --git a/test/IndexTest.cpp b/test/IndexTest.cpp index 3c81f18cf3..68cee86b02 100644 --- a/test/IndexTest.cpp +++ b/test/IndexTest.cpp @@ -14,12 +14,12 @@ #include "./util/IdTableHelpers.h" #include "./util/IdTestHelpers.h" #include "./util/TripleComponentTestHelpers.h" +#include "global/Pattern.h" #include "index/Index.h" #include "index/IndexImpl.h" #include "util/IndexTestHelpers.h" using namespace ad_utility::testing; -using namespace std::string_literals; using ::testing::UnorderedElementsAre; diff --git a/test/RdfParserTest.cpp b/test/RdfParserTest.cpp index 7323deef9d..780952acfe 100644 --- a/test/RdfParserTest.cpp +++ b/test/RdfParserTest.cpp @@ -13,9 +13,8 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "parser/RdfParser.h" -#include "parser/Tokenizer.h" -#include "parser/TokenizerCtre.h" #include "parser/TripleComponent.h" +#include "util/Conversions.h" #include "util/MemorySize/MemorySize.h" using std::string; @@ -1003,48 +1002,6 @@ TEST(RdfParserTest, exceptionPropagationFileBufferReading) { forAllParallelParsers(testWithParser, 40_B, inputWithLongTriple); } -// Test that in parallel parsing scattered prefixes or base 
declarations lead to -// an exception -TEST(RdfParserTest, exceptionOnScatteredPrefixOrBaseInParallelParser) { - std::string filename{"turtleParserExceptionPropagationFileBufferReading.dat"}; - auto testWithParser = [&](bool useBatchInterface, - ad_utility::MemorySize bufferSize, - std::string_view input) { - { - auto of = ad_utility::makeOfstream(filename); - of << input; - } - AD_EXPECT_THROW_WITH_MESSAGE( - (parseFromFile(filename, useBatchInterface, bufferSize)), - ::testing::HasSubstr("'--parse-parallel false'")); - ad_utility::deleteFile(filename); - }; - std::string inputWithScatteredPrefix = - "@prefix ex: . \n" - " . \n " - " . \n" - "@prefix ex: . \n"; - forAllParallelParsers(testWithParser, 40_B, inputWithScatteredPrefix); - std::string inputWithScatteredSparqlPrefix = - "PREFIX ex: . \n" - " . \n " - " . \n" - "PREFIX ex: . \n"; - forAllParallelParsers(testWithParser, 40_B, inputWithScatteredPrefix); - std::string inputWithScatteredBase = - "@base . \n" - " . \n " - " . \n" - "@base . \n"; - forAllParallelParsers(testWithParser, 40_B, inputWithScatteredPrefix); - std::string inputWithScatteredSparqlBase = - "BASE . \n" - " . \n " - " . \n" - "BASE . \n"; - forAllParallelParsers(testWithParser, 40_B, inputWithScatteredPrefix); -} - // Test that the parallel parser's destructor can be run quickly and without // blocking, even when there are still lots of blocks in the pipeline that are // currently being parsed. From dffbc2babf8c95df77be877b46f34e7832f795ae Mon Sep 17 00:00:00 2001 From: Julian <14220769+Qup42@users.noreply.github.com> Date: Fri, 21 Feb 2025 20:25:35 +0100 Subject: [PATCH 6/7] adapt AsioHelpers to Boost Asio 1.87.0 (#1831) Boost Asio 1.87.0 requires some minor changes in QLever. 
Closes #1822 --- src/util/AsioHelpers.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/util/AsioHelpers.h b/src/util/AsioHelpers.h index 31456c823e..0a40bb03ef 100644 --- a/src/util/AsioHelpers.h +++ b/src/util/AsioHelpers.h @@ -52,9 +52,9 @@ CPP_template(typename Executor, typename Function, typename Handler)( // exception_ptr and the return value as the second argument. if constexpr (isVoid) { function_(); - callHandler(nullptr); + callHandler(std::exception_ptr{}); } else { - callHandler(nullptr, function_()); + callHandler(std::exception_ptr{}, function_()); } } catch (...) { // If `function_()` throws, we propagate the exception to the From 8fe06428ee1dddbb3ebcb41a1d93525075571bc1 Mon Sep 17 00:00:00 2001 From: Julian <14220769+Qup42@users.noreply.github.com> Date: Fri, 21 Feb 2025 23:07:07 +0100 Subject: [PATCH 7/7] Implement Graph Store HTTP Protocol operations `GET` and `POST` (#1727) Integrate the Graph Store HTTP Protocol backend code from #1668. With this change, QLever now supports the graph management operations `GET` (get all triples from specified graph) and `POST` (insert triples from payload into specified graph). 
Here are four example calls (the first two do a `GET`, the last two do a `POST`, the `-s` stands for silent, the `?default` stands for the default graph): ``` curl -Gs "localhost:7001/?graph=http://example.org/42" -H "Accept: text/turtle" curl -Gs "localhost:7001/?default" -H "Accept: text/turtle" curl -s "localhost:7001/?graph=http://example.org/42&access-token=bla" -H "Content-type: text/turtle" -d " 42" curl -s "localhost:7001/?default&access-token=bla" -H "Content-type: text/turtle" -d " 42" ``` --- src/engine/CMakeLists.txt | 2 +- src/engine/GraphStoreProtocol.cpp | 40 +- src/engine/GraphStoreProtocol.h | 27 +- src/engine/ParsedRequestBuilder.cpp | 163 ++++++++ src/engine/ParsedRequestBuilder.h | 83 ++++ src/engine/SPARQLProtocol.cpp | 170 ++++++++ src/engine/SPARQLProtocol.h | 243 ++--------- src/engine/Server.cpp | 239 +++++------ src/engine/Server.h | 15 +- src/parser/SparqlParser.cpp | 13 + src/parser/SparqlParser.h | 3 + src/parser/data/GraphRef.h | 2 +- src/util/http/UrlParser.h | 14 +- test/CMakeLists.txt | 4 +- test/GraphStoreProtocolTest.cpp | 86 ++-- test/ParsedRequestBuilderTest.cpp | 296 ++++++++++++++ test/SPARQLProtocolTest.cpp | 606 +++++++++++++++++----------- test/SparqlParserTest.cpp | 23 ++ 18 files changed, 1359 insertions(+), 670 deletions(-) create mode 100644 src/engine/ParsedRequestBuilder.cpp create mode 100644 src/engine/ParsedRequestBuilder.h create mode 100644 src/engine/SPARQLProtocol.cpp create mode 100644 test/ParsedRequestBuilderTest.cpp diff --git a/src/engine/CMakeLists.txt b/src/engine/CMakeLists.txt index 7e8cbbc953..d1eba3a437 100644 --- a/src/engine/CMakeLists.txt +++ b/src/engine/CMakeLists.txt @@ -15,5 +15,5 @@ add_library(engine TextLimit.cpp LazyGroupBy.cpp GroupByHashMapOptimization.cpp SpatialJoin.cpp CountConnectedSubgraphs.cpp SpatialJoinAlgorithms.cpp PathSearch.cpp ExecuteUpdate.cpp Describe.cpp GraphStoreProtocol.cpp - QueryExecutionContext.cpp ExistsJoin.cpp) + QueryExecutionContext.cpp ExistsJoin.cpp
SPARQLProtocol.cpp ParsedRequestBuilder.cpp) qlever_target_link_libraries(engine util index parser sparqlExpressions http SortPerformanceEstimator Boost::iostreams s2) diff --git a/src/engine/GraphStoreProtocol.cpp b/src/engine/GraphStoreProtocol.cpp index fc46ec6fa0..a22463b562 100644 --- a/src/engine/GraphStoreProtocol.cpp +++ b/src/engine/GraphStoreProtocol.cpp @@ -4,28 +4,9 @@ #include "engine/GraphStoreProtocol.h" +#include "parser/SparqlParser.h" #include "util/http/beast.h" -// ____________________________________________________________________________ -GraphOrDefault GraphStoreProtocol::extractTargetGraph( - const ad_utility::url_parser::ParamValueMap& params) { - const std::optional graphIri = - ad_utility::url_parser::checkParameter(params, "graph", std::nullopt); - const bool isDefault = - ad_utility::url_parser::checkParameter(params, "default", "").has_value(); - if (graphIri.has_value() == isDefault) { - throw std::runtime_error( - "Exactly one of the query parameters default or graph must be set to " - "identify the graph for the graph store protocol request."); - } - if (graphIri.has_value()) { - return GraphRef::fromIrirefWithoutBrackets(graphIri.value()); - } else { - AD_CORRECTNESS_CHECK(isDefault); - return DEFAULT{}; - } -} - // ____________________________________________________________________________ void GraphStoreProtocol::throwUnsupportedMediatype( const string_view& mediatype) { @@ -84,18 +65,15 @@ std::vector GraphStoreProtocol::convertTriples( // ____________________________________________________________________________ ParsedQuery GraphStoreProtocol::transformGet(const GraphOrDefault& graph) { - ParsedQuery res; - res._clause = parsedQuery::ConstructClause( - {{Variable("?s"), Variable("?p"), Variable("?o")}}); - res._rootGraphPattern = {}; - parsedQuery::GraphPattern selectSPO; - selectSPO._graphPatterns.emplace_back(parsedQuery::BasicGraphPattern{ - {SparqlTriple(Variable("?s"), "?p", Variable("?o"))}}); + // Construct the 
parsed query from its short equivalent SPARQL Update string. + // This is easier and also provides e.g. the `_originalString` field. + std::string query; if (const auto* iri = std::get_if(&graph)) { - res.datasetClauses_ = - parsedQuery::DatasetClauses::fromClauses({DatasetClause{*iri, false}}); + query = absl::StrCat("CONSTRUCT { ?s ?p ?o } WHERE { GRAPH ", + iri->toStringRepresentation(), " { ?s ?p ?o } }"); + } else { + query = "CONSTRUCT { ?s ?p ?o } WHERE { ?s ?p ?o }"; } - res._rootGraphPattern = std::move(selectSPO); - return res; + return SparqlParser::parseQuery(query); } diff --git a/src/engine/GraphStoreProtocol.h b/src/engine/GraphStoreProtocol.h index 4a8ee974cc..0d06187a8b 100644 --- a/src/engine/GraphStoreProtocol.h +++ b/src/engine/GraphStoreProtocol.h @@ -86,6 +86,10 @@ class GraphStoreProtocol { updateClause::GraphUpdate up{std::move(convertedTriples), {}}; ParsedQuery res; res._clause = parsedQuery::UpdateClause{std::move(up)}; + // Graph store protocol POST requests might have a very large body. Limit + // the length used for the string representation (arbitrarily) to 5000. + string_view body = string_view(rawRequest.body()).substr(0, 5000); + res._originalString = absl::StrCat("Graph Store POST Operation\n", body); return res; } FRIEND_TEST(GraphStoreProtocolTest, transformPost); @@ -101,25 +105,22 @@ class GraphStoreProtocol { // Update. CPP_template_2(typename RequestT)( requires ad_utility::httpUtils::HttpRequest) static ParsedQuery - transformGraphStoreProtocol(const RequestT& rawRequest) { + transformGraphStoreProtocol( + ad_utility::url_parser::sparqlOperation::GraphStoreOperation + operation, + const RequestT& rawRequest) { ad_utility::url_parser::ParsedUrl parsedUrl = ad_utility::url_parser::parseRequestTarget(rawRequest.target()); - // We only support passing the target graph as a query parameter (`Indirect - // Graph Identification`). `Direct Graph Identification` (the URL is the - // graph) is not supported. 
See also - // https://www.w3.org/TR/2013/REC-sparql11-http-rdf-update-20130321/#graph-identification. - GraphOrDefault graph = extractTargetGraph(parsedUrl.parameters_); - using enum boost::beast::http::verb; auto method = rawRequest.method(); if (method == get) { - return transformGet(graph); + return transformGet(operation.graph_); } else if (method == put) { throwUnsupportedHTTPMethod("PUT"); } else if (method == delete_) { throwUnsupportedHTTPMethod("DELETE"); } else if (method == post) { - return transformPost(rawRequest, graph); + return transformPost(rawRequest, operation.graph_); } else if (method == head) { throwUnsupportedHTTPMethod("HEAD"); } else if (method == patch) { @@ -131,12 +132,4 @@ class GraphStoreProtocol { "\" for the SPARQL Graph Store HTTP Protocol.")); } } - - private: - // Extract the graph to be acted upon using from the URL query parameters - // (`Indirect Graph Identification`). See - // https://www.w3.org/TR/2013/REC-sparql11-http-rdf-update-20130321/#indirect-graph-identification - static GraphOrDefault extractTargetGraph( - const ad_utility::url_parser::ParamValueMap& params); - FRIEND_TEST(GraphStoreProtocolTest, extractTargetGraph); }; diff --git a/src/engine/ParsedRequestBuilder.cpp b/src/engine/ParsedRequestBuilder.cpp new file mode 100644 index 0000000000..699cdbc3de --- /dev/null +++ b/src/engine/ParsedRequestBuilder.cpp @@ -0,0 +1,163 @@ +// Copyright 2025, University of Freiburg +// Chair of Algorithms and Data Structures +// Authors: Julian Mundhahs + +#include "ParsedRequestBuilder.h" + +using namespace ad_utility::url_parser::sparqlOperation; + +// ____________________________________________________________________________ +ParsedRequestBuilder::ParsedRequestBuilder(const RequestType& request) { + using namespace ad_utility::url_parser::sparqlOperation; + // For an HTTP request, `request.target()` yields the HTTP Request-URI. + // This is a concatenation of the URL path and the query strings. 
+ auto parsedUrl = ad_utility::url_parser::parseRequestTarget(request.target()); + parsedRequest_ = {std::move(parsedUrl.path_), std::nullopt, + std::move(parsedUrl.parameters_), None{}}; +} + +// ____________________________________________________________________________ +void ParsedRequestBuilder::extractAccessToken(const RequestType& request) { + parsedRequest_.accessToken_ = + determineAccessToken(request, parsedRequest_.parameters_); +} + +// ____________________________________________________________________________ +void ParsedRequestBuilder::extractDatasetClauses() { + extractDatasetClauseIfOperationIs("default-graph-uri", false); + extractDatasetClauseIfOperationIs("named-graph-uri", true); + extractDatasetClauseIfOperationIs("using-graph-uri", false); + extractDatasetClauseIfOperationIs("using-named-graph-uri", true); +} + +// ____________________________________________________________________________ +bool ParsedRequestBuilder::parameterIsContainedExactlyOnce( + std::string_view key) const { + return ad_utility::url_parser::getParameterCheckAtMostOnce( + parsedRequest_.parameters_, key) + .has_value(); +} + +// ____________________________________________________________________________ +bool ParsedRequestBuilder::isGraphStoreOperation() const { + return parameterIsContainedExactlyOnce("graph") || + parameterIsContainedExactlyOnce("default"); +} + +// ____________________________________________________________________________ +void ParsedRequestBuilder::extractGraphStoreOperation() { + // SPARQL Graph Store HTTP Protocol with indirect graph identification + if (parameterIsContainedExactlyOnce("graph") && + parameterIsContainedExactlyOnce("default")) { + throw std::runtime_error( + R"(Parameters "graph" and "default" must not be set at the same time.)"); + } + AD_CORRECTNESS_CHECK(std::holds_alternative(parsedRequest_.operation_)); + // We only support passing the target graph as a query parameter + // (`Indirect Graph Identification`). 
`Direct Graph Identification` (the + // URL is the graph) is not supported. See also + // https://www.w3.org/TR/2013/REC-sparql11-http-rdf-update-20130321/#graph-identification. + parsedRequest_.operation_ = + GraphStoreOperation{extractTargetGraph(parsedRequest_.parameters_)}; +} + +// ____________________________________________________________________________ +bool ParsedRequestBuilder::parametersContain(std::string_view param) const { + return parsedRequest_.parameters_.contains(param); +} + +// ____________________________________________________________________________ +ad_utility::url_parser::ParsedRequest ParsedRequestBuilder::build() && { + return std::move(parsedRequest_); +} + +// ____________________________________________________________________________ +void ParsedRequestBuilder::reportUnsupportedContentTypeIfGraphStore( + std::string_view contentType) const { + if (isGraphStoreOperation()) { + throw std::runtime_error(absl::StrCat("Unsupported Content type \"", + contentType, + "\" for Graph Store protocol.")); + } +} + +// ____________________________________________________________________________ +template +void ParsedRequestBuilder::extractDatasetClauseIfOperationIs( + const std::string& key, bool isNamed) { + if (Operation* op = std::get_if(&parsedRequest_.operation_)) { + ad_utility::appendVector(op->datasetClauses_, + ad_utility::url_parser::parseDatasetClausesFrom( + parsedRequest_.parameters_, key, isNamed)); + } +} + +// ____________________________________________________________________________ +template +void ParsedRequestBuilder::extractOperationIfSpecified(string_view paramName) { + auto operation = ad_utility::url_parser::getParameterCheckAtMostOnce( + parsedRequest_.parameters_, paramName); + if (operation.has_value()) { + AD_CORRECTNESS_CHECK( + std::holds_alternative(parsedRequest_.operation_)); + parsedRequest_.operation_ = Operation{operation.value(), {}}; + parsedRequest_.parameters_.erase(paramName); + } +} + +template void 
ParsedRequestBuilder::extractOperationIfSpecified( + string_view paramName); +template void ParsedRequestBuilder::extractOperationIfSpecified( + string_view paramName); + +// ____________________________________________________________________________ +GraphOrDefault ParsedRequestBuilder::extractTargetGraph( + const ad_utility::url_parser::ParamValueMap& params) { + const std::optional graphIri = + ad_utility::url_parser::checkParameter(params, "graph", std::nullopt); + const bool isDefault = + ad_utility::url_parser::checkParameter(params, "default", "").has_value(); + if (graphIri.has_value() == isDefault) { + throw std::runtime_error( + R"(Exactly one of the query parameters "default" or "graph" must be set to identify the graph for the graph store protocol request.)"); + } + if (graphIri.has_value()) { + return GraphRef::fromIrirefWithoutBrackets(graphIri.value()); + } else { + AD_CORRECTNESS_CHECK(isDefault); + return DEFAULT{}; + } +} + +// ____________________________________________________________________________ +std::optional ParsedRequestBuilder::determineAccessToken( + const RequestType& request, + const ad_utility::url_parser::ParamValueMap& params) { + namespace http = boost::beast::http; + std::optional tokenFromAuthorizationHeader; + std::optional tokenFromParameter; + if (request.find(http::field::authorization) != request.end()) { + string_view authorization = request[http::field::authorization]; + const std::string prefix = "Bearer "; + if (!authorization.starts_with(prefix)) { + throw std::runtime_error(absl::StrCat( + "Authorization header doesn't start with \"", prefix, "\".")); + } + authorization.remove_prefix(prefix.length()); + tokenFromAuthorizationHeader = std::string(authorization); + } + if (params.contains("access-token")) { + tokenFromParameter = ad_utility::url_parser::getParameterCheckAtMostOnce( + params, "access-token"); + } + // If both are specified, they must be equal. This way there is no hidden + // precedence. 
+ if (tokenFromAuthorizationHeader && tokenFromParameter && + tokenFromAuthorizationHeader != tokenFromParameter) { + throw std::runtime_error( + "Access token is specified both in the `Authorization` header and by " + "the `access-token` parameter, but they are not the same"); + } + return tokenFromAuthorizationHeader ? std::move(tokenFromAuthorizationHeader) + : std::move(tokenFromParameter); +} diff --git a/src/engine/ParsedRequestBuilder.h b/src/engine/ParsedRequestBuilder.h new file mode 100644 index 0000000000..66f8cfc4ad --- /dev/null +++ b/src/engine/ParsedRequestBuilder.h @@ -0,0 +1,83 @@ +// Copyright 2025, University of Freiburg +// Chair of Algorithms and Data Structures +// Authors: Julian Mundhahs + +#pragma once + +#include "util/http/UrlParser.h" +#include "util/http/beast.h" + +// Helper for parsing `HttpRequest` into `ParsedRequest`. The parsing has many +// common patterns but the details are slightly different. This struct +// stores the partially parsed `ParsedRequest` and methods for common +// operations used while parsing. +struct ParsedRequestBuilder { + FRIEND_TEST(ParsedRequestBuilderTest, extractTargetGraph); + FRIEND_TEST(ParsedRequestBuilderTest, determineAccessToken); + FRIEND_TEST(ParsedRequestBuilderTest, parameterIsContainedExactlyOnce); + + using RequestType = + boost::beast::http::request; + + ad_utility::url_parser::ParsedRequest parsedRequest_; + + // Initialize a `ParsedRequestBuilder`, parsing the request target into the + // `ParsedRequest`. + explicit ParsedRequestBuilder(const RequestType& request); + + // Extract the access token from the access-token parameter or the + // Authorization header and set it for `ParsedRequest`. If both are given, + // then they must be the same. + void extractAccessToken(const RequestType& request); + + // If applicable extract the dataset clauses from the parameters and set them + // on the Query or Update. 
+ void extractDatasetClauses(); + + // If the parameter is set, set the operation with the parameter's value as + // operation string and empty dataset clauses. Setting an operation when one + // is already set is an error. Note: processed parameters are removed from the + // parameter map. + template + void extractOperationIfSpecified(string_view paramName); + + // Returns whether the request is a Graph Store operation. + bool isGraphStoreOperation() const; + + // Set the operation to the parsed Graph Store operation. + void extractGraphStoreOperation(); + + // Returns whether the parameters contain a parameter with the given key. + bool parametersContain(std::string_view param) const; + + // Check that requests don't both have these content types and are Graph + // Store operations. + void reportUnsupportedContentTypeIfGraphStore( + std::string_view contentType) const; + + // Move the `ParsedRequest` out when parsing is finished. + ad_utility::url_parser::ParsedRequest build() &&; + + private: + // Adds a dataset clause to the operation if it is of the given type. The + // dataset clause's IRI is the value of parameter `key`. The `isNamed_` of the + // dataset clause is as given. + template + void extractDatasetClauseIfOperationIs(const std::string& key, bool isNamed); + + // Check that a parameter is contained exactly once. An exception is thrown if + // a parameter is contained more than once. + bool parameterIsContainedExactlyOnce(std::string_view key) const; + + // Extract the graph to be acted upon using from the URL query parameters + // (`Indirect Graph Identification`). See + // https://www.w3.org/TR/2013/REC-sparql11-http-rdf-update-20130321/#indirect-graph-identification + static GraphOrDefault extractTargetGraph( + const ad_utility::url_parser::ParamValueMap& params); + + // Determine the access token from the parameters and the requests + // Authorization header. 
+ static std::optional determineAccessToken( + const RequestType& request, + const ad_utility::url_parser::ParamValueMap& params); +}; diff --git a/src/engine/SPARQLProtocol.cpp b/src/engine/SPARQLProtocol.cpp new file mode 100644 index 0000000000..902092dbbf --- /dev/null +++ b/src/engine/SPARQLProtocol.cpp @@ -0,0 +1,170 @@ +// Copyright 2025, University of Freiburg +// Chair of Algorithms and Data Structures +// Authors: Julian Mundhahs + +#include "engine/SPARQLProtocol.h" + +using namespace ad_utility::url_parser::sparqlOperation; +namespace http = boost::beast::http; + +// ____________________________________________________________________________ +ad_utility::url_parser::ParsedRequest SPARQLProtocol::parseGET( + const RequestType& request) { + auto parsedRequestBuilder = ParsedRequestBuilder(request); + parsedRequestBuilder.extractAccessToken(request); + const bool isQuery = parsedRequestBuilder.parametersContain("query"); + if (parsedRequestBuilder.parametersContain("update")) { + throw std::runtime_error("SPARQL Update is not allowed as GET request."); + } + if (parsedRequestBuilder.isGraphStoreOperation()) { + if (isQuery) { + throw std::runtime_error( + R"(Request contains parameters for both a SPARQL Query ("query") and a Graph Store Protocol operation ("graph" or "default").)"); + } + // SPARQL Graph Store HTTP Protocol with indirect graph identification + parsedRequestBuilder.extractGraphStoreOperation(); + } else if (isQuery) { + // SPARQL Query + parsedRequestBuilder.extractOperationIfSpecified("query"); + parsedRequestBuilder.extractDatasetClauses(); + } + return std::move(parsedRequestBuilder).build(); +} + +// ____________________________________________________________________________ +ad_utility::url_parser::ParsedRequest SPARQLProtocol::parseUrlencodedPOST( + const RequestType& request) { + auto parsedRequestBuilder = ParsedRequestBuilder(request); + // All parameters must be included in the request body for URL-encoded + // POST. 
The HTTP query string parameters must be empty. See SPARQL 1.1 + // Protocol Sections 2.1.2 + if (!parsedRequestBuilder.parsedRequest_.parameters_.empty()) { + throw std::runtime_error( + "URL-encoded POST requests must not contain query parameters in " + "the URL."); + } + + // Set the url-encoded parameters from the request body. + // Note: previously we used `boost::urls::parse_query`, but that + // function doesn't unescape the `+` which encodes a space character. + // The following workaround of making the url-encoded parameters a + // complete relative url and parsing this URL seems to work. Note: We + // have to bind the result of `StrCat` to an explicit variable, as the + // `boost::urls` parsing routines only give back a view, which otherwise + // would be dangling. + auto bodyAsQuery = absl::StrCat("/?", request.body()); + auto query = boost::urls::parse_origin_form(bodyAsQuery); + if (!query) { + throw std::runtime_error("Invalid URL-encoded POST request, body was: " + + request.body()); + } + parsedRequestBuilder.parsedRequest_.parameters_ = + ad_utility::url_parser::paramsToMap(query->params()); + parsedRequestBuilder.reportUnsupportedContentTypeIfGraphStore( + contentTypeUrlEncoded); + if (parsedRequestBuilder.parametersContain("query") && + parsedRequestBuilder.parametersContain("update")) { + throw std::runtime_error( + R"(Request must only contain one of "query" and "update".)"); + } + parsedRequestBuilder.extractOperationIfSpecified("query"); + parsedRequestBuilder.extractOperationIfSpecified("update"); + parsedRequestBuilder.extractDatasetClauses(); + // We parse the access token from the url-encoded parameters in the + // body. The URL parameters must be empty for URL-encoded POST (see + // above). 
+ parsedRequestBuilder.extractAccessToken(request); + + return std::move(parsedRequestBuilder).build(); +} + +// ____________________________________________________________________________ +template +ad_utility::url_parser::ParsedRequest SPARQLProtocol::parseSPARQLPOST( + const RequestType& request, std::string_view contentType) { + using namespace ad_utility::url_parser::sparqlOperation; + auto parsedRequestBuilder = ParsedRequestBuilder(request); + parsedRequestBuilder.reportUnsupportedContentTypeIfGraphStore(contentType); + parsedRequestBuilder.parsedRequest_.operation_ = + Operation{request.body(), {}}; + parsedRequestBuilder.extractDatasetClauses(); + parsedRequestBuilder.extractAccessToken(request); + return std::move(parsedRequestBuilder).build(); +} + +// ____________________________________________________________________________ +ad_utility::url_parser::ParsedRequest SPARQLProtocol::parsePOST( + const RequestType& request) { + // For a POST request, the content type must be either + // "application/x-www-form-urlencoded" (1), "application/sparql-query" + // (2) or "application/sparql-update" (3). If no content type applies, then + // the request must be a graph store request (4). + // + // (1) Section 2.1.2: The body of the POST request contains *all* + // parameters (including the query or update) in an encoded form (just + // like in the part of a GET request after the "?"). + // + // (2) Section 2.1.3: The body of the POST request contains *only* the + // unencoded SPARQL query. There may be additional HTTP query parameters. + // + // (3) Section 2.2.2: The body of the POST request contains *only* the + // unencoded SPARQL update. There may be additional HTTP query parameters. + // + // (4) Graph Store requests must contain the graph to be acted on as a query + // parameter (indirect graph identification). For POST requests the body + // contains an RDF payload that should be parsed according to the content type + // and inserted into the graph. 
+ // + // Reference: https://www.w3.org/TR/2013/REC-sparql11-protocol-20130321 + std::string_view contentType = + request.base()[boost::beast::http::field::content_type]; + LOG(DEBUG) << "Content-type: \"" << contentType << "\"" << std::endl; + + // Note: For simplicity we only check via `starts_with`. This ignores + // additional parameters like `application/sparql-query;charset=utf8`. We + // currently always expect UTF-8. + // TODO Implement more complete parsing that allows the checking + // of these parameters. + if (contentType.starts_with(contentTypeUrlEncoded)) { + return parseUrlencodedPOST(request); + } + if (contentType.starts_with(contentTypeSparqlQuery)) { + return parseSPARQLPOST(request, contentTypeSparqlQuery); + } + if (contentType.starts_with(contentTypeSparqlUpdate)) { + return parseSPARQLPOST(request, contentTypeSparqlUpdate); + } + // No content type applies, we expect the request to be a graph store + // request. Checking if the content type is supported by the Graph Store HTTP + // Protocol implementation is done later. 
+ auto parsedRequestBuilder = ParsedRequestBuilder(request); + if (parsedRequestBuilder.isGraphStoreOperation()) { + parsedRequestBuilder.extractGraphStoreOperation(); + parsedRequestBuilder.extractAccessToken(request); + return std::move(parsedRequestBuilder).build(); + } + + throw std::runtime_error(absl::StrCat( + "POST request with content type \"", contentType, + "\" not supported (must be Query/Update with content type \"", + contentTypeUrlEncoded, "\", \"", contentTypeSparqlQuery, "\" or \"", + contentTypeSparqlUpdate, + "\" or a valid graph store protocol POST request)")); +} + +// ____________________________________________________________________________ +ad_utility::url_parser::ParsedRequest SPARQLProtocol::parseHttpRequest( + const RequestType& request) { + if (request.method() == http::verb::get) { + return parseGET(request); + } + if (request.method() == http::verb::post) { + return parsePOST(request); + } + std::ostringstream requestMethodName; + requestMethodName << request.method(); + throw std::runtime_error(absl::StrCat( + "Request method \"", requestMethodName.str(), + "\" not supported (only GET and POST are supported; PUT, DELETE, HEAD " + "and PATCH for graph store protocol are not yet supported)")); +} diff --git a/src/engine/SPARQLProtocol.h b/src/engine/SPARQLProtocol.h index 956ab5bca1..b863b073de 100644 --- a/src/engine/SPARQLProtocol.h +++ b/src/engine/SPARQLProtocol.h @@ -4,210 +4,49 @@ #pragma once -#include "util/Algorithm.h" -#include "util/TypeIdentity.h" -#include "util/TypeTraits.h" -#include "util/http/HttpUtils.h" -#include "util/http/UrlParser.h" -#include "util/http/beast.h" +#include "engine/ParsedRequestBuilder.h" +// Parses HTTP requests to `ParsedRequests` (a representation of Query, Update, +// Graph Store and internal operations) according to the SPARQL specifications. 
class SPARQLProtocol { - FRIEND_TEST(SPARQLProtocolTest, extractAccessToken); + FRIEND_TEST(SPARQLProtocolTest, parseGET); + FRIEND_TEST(SPARQLProtocolTest, parseUrlencodedPOST); + FRIEND_TEST(SPARQLProtocolTest, parseQueryPOST); + FRIEND_TEST(SPARQLProtocolTest, parseUpdatePOST); + FRIEND_TEST(SPARQLProtocolTest, parsePOST); + + static constexpr std::string_view contentTypeUrlEncoded = + "application/x-www-form-urlencoded"; + static constexpr std::string_view contentTypeSparqlQuery = + "application/sparql-query"; + static constexpr std::string_view contentTypeSparqlUpdate = + "application/sparql-update"; + + using RequestType = ParsedRequestBuilder::RequestType; + + // Parse an HTTP GET request into a `ParsedRequest`. The + // `ParsedRequestBuilder` must have already extracted the access token. + static ad_utility::url_parser::ParsedRequest parseGET( + const RequestType& request); + + // Parse an HTTP POST request with content-type + // `application/x-www-form-urlencoded` into a `ParsedRequest`. + static ad_utility::url_parser::ParsedRequest parseUrlencodedPOST( + const RequestType& request); + + // Parse an HTTP POST request with a SPARQL operation in its body + // into a `ParsedRequest`. This is used for the content types + // `application/sparql-query` and `application/sparql-update`. + template + static ad_utility::url_parser::ParsedRequest parseSPARQLPOST( + const RequestType& request, std::string_view contentType); + + // Parse an HTTP POST request into a `ParsedRequest`. + static ad_utility::url_parser::ParsedRequest parsePOST( + const RequestType& request); public: - /// Parse the path and URL parameters from the given request. Supports both - /// GET and POST request according to the SPARQL 1.1 standard. 
- CPP_template_2(typename RequestT)( - requires ad_utility::httpUtils::HttpRequest< - RequestT>) static ad_utility::url_parser::ParsedRequest - parseHttpRequest(const RequestT& request) { - using namespace ad_utility::url_parser::sparqlOperation; - using namespace ad_utility::use_type_identity; - namespace http = boost::beast::http; - // For an HTTP request, `request.target()` yields the HTTP Request-URI. - // This is a concatenation of the URL path and the query strings. - auto parsedUrl = - ad_utility::url_parser::parseRequestTarget(request.target()); - ad_utility::url_parser::ParsedRequest parsedRequest{ - std::move(parsedUrl.path_), std::nullopt, - std::move(parsedUrl.parameters_), None{}}; - - // Some valid requests (e.g. QLever's custom commands like retrieving index - // statistics) don't have a query. So an empty operation is not necessarily - // an error. - auto setOperationIfSpecifiedInParams = [&parsedRequest]( - TI, - string_view paramName) { - auto operation = ad_utility::url_parser::getParameterCheckAtMostOnce( - parsedRequest.parameters_, paramName); - if (operation.has_value()) { - parsedRequest.operation_ = Operation{operation.value(), {}}; - parsedRequest.parameters_.erase(paramName); - } - }; - auto addToDatasetClausesIfOperationIs = - [&parsedRequest]( - TI, const std::string& key, bool isNamed) { - if (Operation* op = - std::get_if(&parsedRequest.operation_)) { - ad_utility::appendVector( - op->datasetClauses_, - ad_utility::url_parser::parseDatasetClausesFrom( - parsedRequest.parameters_, key, isNamed)); - } - }; - auto addDatasetClauses = [&addToDatasetClausesIfOperationIs] { - addToDatasetClausesIfOperationIs(ti, "default-graph-uri", false); - addToDatasetClausesIfOperationIs(ti, "named-graph-uri", true); - addToDatasetClausesIfOperationIs(ti, "using-graph-uri", false); - addToDatasetClausesIfOperationIs(ti, "using-named-graph-uri", - true); - }; - auto extractAccessTokenFromRequest = [&parsedRequest, &request]() { - 
parsedRequest.accessToken_ = - extractAccessToken(request, parsedRequest.parameters_); - }; - - if (request.method() == http::verb::get) { - setOperationIfSpecifiedInParams(ti, "query"); - addDatasetClauses(); - extractAccessTokenFromRequest(); - - if (parsedRequest.parameters_.contains("update")) { - throw std::runtime_error( - "SPARQL Update is not allowed as GET request."); - } - return parsedRequest; - } - if (request.method() == http::verb::post) { - // For a POST request, the content type *must* be either - // "application/x-www-form-urlencoded" (1), "application/sparql-query" - // (2) or "application/sparql-update" (3). - // - // (1) Section 2.1.2: The body of the POST request contains *all* - // parameters (including the query or update) in an encoded form (just - // like in the part of a GET request after the "?"). - // - // (2) Section 2.1.3: The body of the POST request contains *only* the - // unencoded SPARQL query. There may be additional HTTP query parameters. - // - // (3) Section 2.2.2: The body of the POST request contains *only* the - // unencoded SPARQL update. There may be additional HTTP query parameters. - // - // Reference: https://www.w3.org/TR/2013/REC-sparql11-protocol-20130321 - std::string_view contentType = request.base()[http::field::content_type]; - LOG(DEBUG) << "Content-type: \"" << contentType << "\"" << std::endl; - static constexpr std::string_view contentTypeUrlEncoded = - "application/x-www-form-urlencoded"; - static constexpr std::string_view contentTypeSparqlQuery = - "application/sparql-query"; - static constexpr std::string_view contentTypeSparqlUpdate = - "application/sparql-update"; - - // Note: For simplicity we only check via `starts_with`. This ignores - // additional parameters like `application/sparql-query;charset=utf8`. We - // currently always expect UTF-8. - // TODO Implement more complete parsing that allows the checking - // of these parameters. 
- if (contentType.starts_with(contentTypeUrlEncoded)) { - // All parameters must be included in the request body for URL-encoded - // POST. The HTTP query string parameters must be empty. See SPARQL 1.1 - // Protocol Sections 2.1.2 - if (!parsedRequest.parameters_.empty()) { - throw std::runtime_error( - "URL-encoded POST requests must not contain query parameters in " - "the URL."); - } - - // Set the url-encoded parameters from the request body. - // Note: previously we used `boost::urls::parse_query`, but that - // function doesn't unescape the `+` which encodes a space character. - // The following workaround of making the url-encoded parameters a - // complete relative url and parsing this URL seems to work. Note: We - // have to bind the result of `StrCat` to an explicit variable, as the - // `boost::urls` parsing routines only give back a view, which otherwise - // would be dangling. - auto bodyAsQuery = absl::StrCat("/?", request.body()); - auto query = boost::urls::parse_origin_form(bodyAsQuery); - if (!query) { - throw std::runtime_error( - "Invalid URL-encoded POST request, body was: " + request.body()); - } - parsedRequest.parameters_ = - ad_utility::url_parser::paramsToMap(query->params()); - - if (parsedRequest.parameters_.contains("query") && - parsedRequest.parameters_.contains("update")) { - throw std::runtime_error( - R"(Request must only contain one of "query" and "update".)"); - } - setOperationIfSpecifiedInParams(ti, "query"); - setOperationIfSpecifiedInParams(ti, "update"); - addDatasetClauses(); - // We parse the access token from the url-encoded parameters in the - // body. The URL parameters must be empty for URL-encoded POST (see - // above). 
- extractAccessTokenFromRequest(); - - return parsedRequest; - } - if (contentType.starts_with(contentTypeSparqlQuery)) { - parsedRequest.operation_ = Query{request.body(), {}}; - addDatasetClauses(); - extractAccessTokenFromRequest(); - return parsedRequest; - } - if (contentType.starts_with(contentTypeSparqlUpdate)) { - parsedRequest.operation_ = Update{request.body(), {}}; - addDatasetClauses(); - extractAccessTokenFromRequest(); - return parsedRequest; - } - throw std::runtime_error(absl::StrCat( - "POST request with content type \"", contentType, - "\" not supported (must be \"", contentTypeUrlEncoded, "\", \"", - contentTypeSparqlQuery, "\" or \"", contentTypeSparqlUpdate, "\")")); - } - std::ostringstream requestMethodName; - requestMethodName << request.method(); - throw std::runtime_error( - absl::StrCat("Request method \"", requestMethodName.str(), - "\" not supported (has to be GET or POST)")); - }; - - private: - CPP_template_2(typename RequestT)( - requires ad_utility::httpUtils::HttpRequest) static std:: - optional extractAccessToken( - const RequestT& request, - const ad_utility::url_parser::ParamValueMap& params) { - namespace http = boost::beast::http; - std::optional tokenFromAuthorizationHeader; - std::optional tokenFromParameter; - if (request.find(http::field::authorization) != request.end()) { - string_view authorization = request[http::field::authorization]; - const std::string prefix = "Bearer "; - if (!authorization.starts_with(prefix)) { - throw std::runtime_error(absl::StrCat( - "Authorization header doesn't start with \"", prefix, "\".")); - } - authorization.remove_prefix(prefix.length()); - tokenFromAuthorizationHeader = std::string(authorization); - } - if (params.contains("access-token")) { - tokenFromParameter = ad_utility::url_parser::getParameterCheckAtMostOnce( - params, "access-token"); - } - // If both are specified, they must be equal. This way there is no hidden - // precedence. 
- if (tokenFromAuthorizationHeader && tokenFromParameter && - tokenFromAuthorizationHeader != tokenFromParameter) { - throw std::runtime_error( - "Access token is specified both in the `Authorization` header and by " - "the `access-token` parameter, but they are not the same"); - } - return tokenFromAuthorizationHeader - ? std::move(tokenFromAuthorizationHeader) - : std::move(tokenFromParameter); - }; + // Parse a HTTP request. + static ad_utility::url_parser::ParsedRequest parseHttpRequest( + const RequestType& request); }; diff --git a/src/engine/Server.cpp b/src/engine/Server.cpp index f1c80ac15d..697192fa69 100644 --- a/src/engine/Server.cpp +++ b/src/engine/Server.cpp @@ -214,6 +214,60 @@ class QueryAlreadyInUseError : public std::runtime_error { "' is already in use!"} {} }; +// _____________________________________________________________________________ +auto Server::cancelAfterDeadline( + std::weak_ptr> cancellationHandle, + TimeLimit timeLimit) + -> QL_CONCEPT_OR_NOTHING( + ad_utility::InvocableWithExactReturnType) auto { + net::steady_timer timer{timerExecutor_, timeLimit}; + + timer.async_wait([cancellationHandle = std::move(cancellationHandle)]( + const boost::system::error_code&) { + if (auto pointer = cancellationHandle.lock()) { + pointer->cancel(ad_utility::CancellationState::TIMEOUT); + } + }); + return [timer = std::move(timer)]() mutable { timer.cancel(); }; +} + +// _____________________________________________________________________________ +auto Server::setupCancellationHandle( + const ad_utility::websocket::QueryId& queryId, TimeLimit timeLimit) + -> QL_CONCEPT_OR_NOTHING(ad_utility::isInstantiation< + CancellationHandleAndTimeoutTimerCancel>) auto { + auto cancellationHandle = queryRegistry_.getCancellationHandle(queryId); + AD_CORRECTNESS_CHECK(cancellationHandle); + cancellationHandle->startWatchDog(); + absl::Cleanup cancelCancellationHandle{ + cancelAfterDeadline(cancellationHandle, timeLimit)}; + return 
CancellationHandleAndTimeoutTimerCancel{ + std::move(cancellationHandle), std::move(cancelCancellationHandle)}; +} + +// ____________________________________________________________________________ +auto Server::prepareOperation( + std::string_view operationName, std::string_view operationSPARQL, + ad_utility::websocket::MessageSender& messageSender, + const ad_utility::url_parser::ParamValueMap& params, TimeLimit timeLimit) { + auto [cancellationHandle, cancelTimeoutOnDestruction] = + setupCancellationHandle(messageSender.getQueryId(), timeLimit); + + // Do the query planning. This creates a `QueryExecutionTree`, which will + // then be used to process the query. + auto [pinSubtrees, pinResult] = determineResultPinning(params); + LOG(INFO) << "Processing the following " << operationName << ":" + << (pinResult ? " [pin result]" : "") + << (pinSubtrees ? " [pin subresults]" : "") << "\n" + << operationSPARQL << std::endl; + QueryExecutionContext qec(index_, &cache_, allocator_, + sortPerformanceEstimator_, std::ref(messageSender), + pinSubtrees, pinResult); + + return std::tuple{std::move(qec), std::move(cancellationHandle), + std::move(cancelTimeoutOnDestruction)}; +} + // _____________________________________________________________________________ CPP_template_2(typename RequestT, typename ResponseT)( requires ad_utility::httpUtils::HttpRequest) @@ -368,57 +422,80 @@ CPP_template_2(typename RequestT, typename ResponseT)( } } - auto visitOperation = [&checkParameter, &accessTokenOk, &request, &send, - ¶meters, &requestTimer, - this]( - const Operation& op, auto opFieldString, - std::function pred, - std::string msg) -> Awaitable { - if (auto timeLimit = co_await verifyUserSubmittedQueryTimeout( - checkParameter("timeout", std::nullopt), accessTokenOk, request, - send)) { - ad_utility::websocket::MessageSender messageSender = createMessageSender( - queryHub_, request, std::invoke(opFieldString, op)); - auto [parsedOperation, qec, cancellationHandle, - 
cancelTimeoutOnDestruction] = - parseOperation(messageSender, parameters, op, timeLimit.value()); - if (pred(parsedOperation)) { - throw std::runtime_error( - absl::StrCat(msg, parsedOperation._originalString)); - } - if constexpr (std::is_same_v) { - co_return co_await processQuery(parameters, std::move(parsedOperation), - requestTimer, cancellationHandle, qec, - std::move(request), send, - timeLimit.value()); - } else { - static_assert(std::is_same_v); - co_return co_await processUpdate( - std::move(parsedOperation), requestTimer, cancellationHandle, qec, - std::move(request), send, timeLimit.value()); - } - } else { + auto visitOperation = + [&checkParameter, &accessTokenOk, &request, &send, ¶meters, + &requestTimer, + this](ParsedQuery parsedOperation, std::string operationName, + std::function expectedOperation, + const std::string msg) -> Awaitable { + auto timeLimit = co_await verifyUserSubmittedQueryTimeout( + checkParameter("timeout", std::nullopt), accessTokenOk, request, send); + if (!timeLimit.has_value()) { // If the optional is empty, this indicates an error response has been // sent to the client already. We can stop here. 
co_return; } + ad_utility::websocket::MessageSender messageSender = createMessageSender( + queryHub_, request, parsedOperation._originalString); + + auto [qec, cancellationHandle, cancelTimeoutOnDestruction] = + prepareOperation(operationName, parsedOperation._originalString, + messageSender, parameters, timeLimit.value()); + if (!expectedOperation(parsedOperation)) { + throw std::runtime_error( + absl::StrCat(msg, parsedOperation._originalString)); + } + if (parsedOperation.hasUpdateClause()) { + co_return co_await processUpdate( + std::move(parsedOperation), requestTimer, cancellationHandle, qec, + std::move(request), send, timeLimit.value()); + } else { + AD_CORRECTNESS_CHECK(parsedOperation.hasSelectClause() || + parsedOperation.hasAskClause() || + parsedOperation.hasConstructClause()); + co_return co_await processQuery( + parameters, std::move(parsedOperation), requestTimer, + cancellationHandle, qec, std::move(request), send, timeLimit.value()); + } }; - auto visitQuery = [&visitOperation](const Query& query) -> Awaitable { + auto visitQuery = [&visitOperation](Query query) -> Awaitable { + auto parsedQuery = SparqlParser::parseQuery(std::move(query.query_), + query.datasetClauses_); return visitOperation( - query, &Query::query_, &ParsedQuery::hasUpdateClause, + parsedQuery, "SPARQL Query", std::not_fn(&ParsedQuery::hasUpdateClause), "SPARQL QUERY was request via the HTTP request, but the " "following update was sent instead of an query: "); }; auto visitUpdate = [&visitOperation, &requireValidAccessToken]( - const Update& update) -> Awaitable { + Update update) -> Awaitable { requireValidAccessToken("SPARQL Update"); + auto parsedUpdate = SparqlParser::parseQuery(std::move(update.update_), + update.datasetClauses_); return visitOperation( - update, &Update::update_, std::not_fn(&ParsedQuery::hasUpdateClause), + parsedUpdate, "SPARQL Update", &ParsedQuery::hasUpdateClause, "SPARQL UPDATE was request via the HTTP request, but the " "following query was sent 
instead of an update: "); }; - auto visitNone = [&response, &send, - &request](const None&) -> Awaitable { + auto visitGraphStore = [&request, &visitOperation, &requireValidAccessToken]( + GraphStoreOperation operation) -> Awaitable { + ParsedQuery parsedOperation = + GraphStoreProtocol::transformGraphStoreProtocol(std::move(operation), + request); + + if (parsedOperation.hasUpdateClause()) { + requireValidAccessToken("Update from Graph Store Protocol"); + } + + // Don't check for the `ParsedQuery`s actual type (Query or Update) here + // because graph store operations can result in both. + auto trueFunc = [](const ParsedQuery&) { return true; }; + std::string_view queryType = + parsedOperation.hasUpdateClause() ? "Update" : "Query"; + return visitOperation(parsedOperation, + absl::StrCat("Graph Store (", queryType, ")"), + trueFunc, "Unused dummy message"); + }; + auto visitNone = [&response, &send, &request](None) -> Awaitable { // If there was no "query", but any of the URL parameters processed before // produced a `response`, send that now. Note that if multiple URL // parameters were processed, only the `response` from the last one is sent. 
@@ -440,7 +517,8 @@ CPP_template_2(typename RequestT, typename ResponseT)( co_return co_await processOperation( std::move(parsedHttpRequest.operation_), - ad_utility::OverloadCallOperator{visitQuery, visitUpdate, visitNone}, + ad_utility::OverloadCallOperator{visitQuery, visitUpdate, visitGraphStore, + visitNone}, requestTimer, request, send); } @@ -456,81 +534,6 @@ std::pair Server::determineResultPinning( return {pinSubtrees, pinResult}; } -// _____________________________________________________________________________ -auto Server::cancelAfterDeadline( - std::weak_ptr> cancellationHandle, - TimeLimit timeLimit) - -> QL_CONCEPT_OR_NOTHING( - ad_utility::InvocableWithExactReturnType) auto { - net::steady_timer timer{timerExecutor_, timeLimit}; - - timer.async_wait([cancellationHandle = std::move(cancellationHandle)]( - const boost::system::error_code&) { - if (auto pointer = cancellationHandle.lock()) { - pointer->cancel(ad_utility::CancellationState::TIMEOUT); - } - }); - return [timer = std::move(timer)]() mutable { timer.cancel(); }; -} - -// _____________________________________________________________________________ -auto Server::setupCancellationHandle( - const ad_utility::websocket::QueryId& queryId, TimeLimit timeLimit) - -> QL_CONCEPT_OR_NOTHING(ad_utility::isInstantiation< - CancellationHandleAndTimeoutTimerCancel>) auto { - auto cancellationHandle = queryRegistry_.getCancellationHandle(queryId); - AD_CORRECTNESS_CHECK(cancellationHandle); - cancellationHandle->startWatchDog(); - absl::Cleanup cancelCancellationHandle{ - cancelAfterDeadline(cancellationHandle, timeLimit)}; - return CancellationHandleAndTimeoutTimerCancel{ - std::move(cancellationHandle), std::move(cancelCancellationHandle)}; -} - -// ____________________________________________________________________________ -CPP_template_2(typename Operation)( - requires QueryOrUpdate) auto Server:: - parseOperation(ad_utility::websocket::MessageSender& messageSender, - const 
ad_utility::url_parser::ParamValueMap& params, - const Operation& operation, TimeLimit timeLimit) { - // The operation string was to be copied, do it here at the beginning. - const auto [operationName, operationSPARQL] = - [&operation]() -> std::pair { - if constexpr (std::is_same_v) { - return {"SPARQL Query", operation.query_}; - } else { - static_assert(std::is_same_v); - return {"SPARQL Update", operation.update_}; - } - }(); - - auto [cancellationHandle, cancelTimeoutOnDestruction] = - setupCancellationHandle(messageSender.getQueryId(), timeLimit); - - // Do the query planning. This creates a `QueryExecutionTree`, which will - // then be used to process the query. - auto [pinSubtrees, pinResult] = determineResultPinning(params); - LOG(INFO) << "Processing the following " << operationName << ":" - << (pinResult ? " [pin result]" : "") - << (pinSubtrees ? " [pin subresults]" : "") << "\n" - << operationSPARQL << std::endl; - QueryExecutionContext qec(index_, &cache_, allocator_, - sortPerformanceEstimator_, std::ref(messageSender), - pinSubtrees, pinResult); - ParsedQuery parsedQuery = - SparqlParser::parseQuery(std::move(operationSPARQL)); - // SPARQL Protocol 2.1.4 specifies that the dataset from the query - // parameters overrides the dataset from the query itself. 
- if (!operation.datasetClauses_.empty()) { - parsedQuery.datasetClauses_ = - parsedQuery::DatasetClauses::fromClauses(operation.datasetClauses_); - } - - return std::tuple{std::move(parsedQuery), std::move(qec), - std::move(cancellationHandle), - std::move(cancelTimeoutOnDestruction)}; -} - // ____________________________________________________________________________ Awaitable Server::planQuery( net::static_thread_pool& threadPool, ParsedQuery&& operation, @@ -941,14 +944,22 @@ CPP_template_2(typename VisitorT, typename RequestT, typename ResponseT)( ad_utility::url_parser::sparqlOperation::Operation operation, VisitorT visitor, const ad_utility::Timer& requestTimer, const RequestT& request, ResponseT& send) { - auto operationString = [&operation] { + // Copy the operation string for the error case before processing the + // operation, because processing moves it. + const std::string operationString = [&operation] { if (auto* q = std::get_if(&operation)) { return q->query_; } if (auto* u = std::get_if(&operation)) { return u->update_; } - return std::string("No operation string available."); + if (std::holds_alternative(operation)) { + return std::string( + "No operation string available for Graph Store Operation"); + } + AD_CORRECTNESS_CHECK(std::holds_alternative(operation)); + return std::string( + "No operation string available, because operation type is unknown."); }(); using namespace ad_utility::httpUtils; http::status responseStatus = http::status::ok; diff --git a/src/engine/Server.h b/src/engine/Server.h index 8c483c8fbe..9e8a11eb59 100644 --- a/src/engine/Server.h +++ b/src/engine/Server.h @@ -175,15 +175,12 @@ class Server { static std::pair determineResultPinning( const ad_utility::url_parser::ParamValueMap& params); FRIEND_TEST(ServerTest, determineResultPinning); - // Parse an operation - CPP_template_2(typename Operation)( - requires QueryOrUpdate< - Operation>) auto parseOperation(ad_utility::websocket::MessageSender& - messageSender, - const 
ad_utility::url_parser:: - ParamValueMap& params, - const Operation& operation, - TimeLimit timeLimit); + // Prepare the execution of an operation + auto prepareOperation(std::string_view operationName, + std::string_view operationSPARQL, + ad_utility::websocket::MessageSender& messageSender, + const ad_utility::url_parser::ParamValueMap& params, + TimeLimit timeLimit); // Plan a parsed query. Awaitable planQuery(net::static_thread_pool& thread_pool, diff --git a/src/parser/SparqlParser.cpp b/src/parser/SparqlParser.cpp index e75855fe91..c8be808b3a 100644 --- a/src/parser/SparqlParser.cpp +++ b/src/parser/SparqlParser.cpp @@ -26,3 +26,16 @@ ParsedQuery SparqlParser::parseQuery(std::string query) { AD_CONTRACT_CHECK(resultOfParseAndRemainingText.remainingText_.empty()); return std::move(resultOfParseAndRemainingText.resultOfParse_); } + +// _____________________________________________________________________________ +ParsedQuery SparqlParser::parseQuery( + std::string operation, const std::vector& datasets) { + auto parsedOperation = parseQuery(std::move(operation)); + // SPARQL Protocol 2.1.4 specifies that the dataset from the query + // parameters overrides the dataset from the query itself. + if (!datasets.empty()) { + parsedOperation.datasetClauses_ = + parsedQuery::DatasetClauses::fromClauses(datasets); + } + return parsedOperation; +} diff --git a/src/parser/SparqlParser.h b/src/parser/SparqlParser.h index 26c9fd0b57..b713a7a3bc 100644 --- a/src/parser/SparqlParser.h +++ b/src/parser/SparqlParser.h @@ -14,4 +14,7 @@ class SparqlParser { public: static ParsedQuery parseQuery(std::string query); + // A convenience function for parsing the query and setting the datasets. 
+ static ParsedQuery parseQuery(std::string operation, + const std::vector& datasets); }; diff --git a/src/parser/data/GraphRef.h b/src/parser/data/GraphRef.h index 06b7e3ab59..94f4e229bb 100644 --- a/src/parser/data/GraphRef.h +++ b/src/parser/data/GraphRef.h @@ -8,7 +8,7 @@ #include "parser/Iri.h" -using GraphRef = TripleComponent::Iri; +using GraphRef = ad_utility::triple_component::Iri; // Denotes the target graph for an operation. Here the target is the default // graph. struct DEFAULT { diff --git a/src/util/http/UrlParser.h b/src/util/http/UrlParser.h index 0ca24127b7..dd441695b5 100644 --- a/src/util/http/UrlParser.h +++ b/src/util/http/UrlParser.h @@ -11,6 +11,7 @@ #include #include +#include "parser/data/GraphRef.h" #include "parser/sparqlParser/DatasetClause.h" #include "util/HashMap.h" @@ -44,7 +45,9 @@ struct ParsedUrl { ParamValueMap parameters_; }; -// The different SPARQL operations that a `ParsedRequest` can represent. +// The different SPARQL operations that a `ParsedRequest` can represent. The +// operations represent the detected operation type and can contain additional +// information that the operation needs. namespace sparqlOperation { // A SPARQL 1.1 Query struct Query { @@ -62,13 +65,20 @@ struct Update { bool operator==(const Update& rhs) const = default; }; +// A Graph Store HTTP Protocol operation. We only store the graph on which the +// operation acts. The actual operation is extracted later. +struct GraphStoreOperation { + GraphOrDefault graph_; + bool operator==(const GraphStoreOperation& rhs) const = default; +}; + // No operation. This can happen for QLever's custom operations (e.g. // `cache-stats`). These requests have no operation but are still valid. struct None { bool operator==(const None& rhs) const = default; }; -using Operation = std::variant; +using Operation = std::variant; } // namespace sparqlOperation // Representation of parsed HTTP request. 
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d588e74e55..87fddacffc 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -455,4 +455,6 @@ addLinkAndDiscoverTest(ExecuteUpdateTest engine) addLinkAndDiscoverTest(GraphStoreProtocolTest engine) -addLinkAndDiscoverTest(SPARQLProtocolTest) +addLinkAndDiscoverTest(SPARQLProtocolTest engine) + +addLinkAndDiscoverTest(ParsedRequestBuilderTest engine) diff --git a/test/GraphStoreProtocolTest.cpp b/test/GraphStoreProtocolTest.cpp index 24eece54c4..9c514d4415 100644 --- a/test/GraphStoreProtocolTest.cpp +++ b/test/GraphStoreProtocolTest.cpp @@ -14,58 +14,24 @@ namespace m = matchers; using namespace ad_utility::testing; +using namespace ad_utility::url_parser::sparqlOperation; using Var = Variable; using TC = TripleComponent; -// _____________________________________________________________________________________________ -TEST(GraphStoreProtocolTest, extractTargetGraph) { - // Equivalent to `/?default` - EXPECT_THAT(GraphStoreProtocol::extractTargetGraph({{"default", {""}}}), - DEFAULT{}); - // Equivalent to `/?graph=foo` - EXPECT_THAT(GraphStoreProtocol::extractTargetGraph({{"graph", {"foo"}}}), - iri("")); - // Equivalent to `/?graph=foo&graph=bar` - AD_EXPECT_THROW_WITH_MESSAGE( - GraphStoreProtocol::extractTargetGraph({{"graph", {"foo", "bar"}}}), - testing::HasSubstr( - "Parameter \"graph\" must be given exactly once. 
Is: 2")); - const std::string eitherDefaultOrGraphErrorMsg = - "Exactly one of the query parameters default or graph must be set to " - "identify the graph for the graph store protocol request."; - // Equivalent to `/` or `/?` - AD_EXPECT_THROW_WITH_MESSAGE( - GraphStoreProtocol::extractTargetGraph({}), - testing::HasSubstr(eitherDefaultOrGraphErrorMsg)); - // Equivalent to `/?unrelated=a&unrelated=b` - AD_EXPECT_THROW_WITH_MESSAGE( - GraphStoreProtocol::extractTargetGraph({{"unrelated", {"a", "b"}}}), - testing::HasSubstr(eitherDefaultOrGraphErrorMsg)); - // Equivalent to `/?default&graph=foo` - AD_EXPECT_THROW_WITH_MESSAGE( - GraphStoreProtocol::extractTargetGraph( - {{"default", {""}}, {"graph", {"foo"}}}), - testing::HasSubstr(eitherDefaultOrGraphErrorMsg)); -} - // _____________________________________________________________________________________________ TEST(GraphStoreProtocolTest, transformPost) { auto expectTransformPost = CPP_template_lambda()(typename RequestT)( - const RequestT& request, + const RequestT& request, const GraphOrDefault& graph, const testing::Matcher& matcher, ad_utility::source_location l = ad_utility::source_location::current())( requires ad_utility::httpUtils::HttpRequest) { auto trace = generateLocationTrace(l); - const ad_utility::url_parser::ParsedUrl parsedUrl = - ad_utility::url_parser::parseRequestTarget(request.target()); - const GraphOrDefault graph = - GraphStoreProtocol::extractTargetGraph(parsedUrl.parameters_); EXPECT_THAT(GraphStoreProtocol::transformPost(request, graph), matcher); }; expectTransformPost( - makePostRequest("/?default", "text/turtle", " ."), + makePostRequest("/?default", "text/turtle", " ."), DEFAULT{}, m::UpdateClause( m::GraphUpdate( {}, {{iri(""), iri(""), iri(""), std::monostate{}}}, @@ -73,6 +39,7 @@ TEST(GraphStoreProtocolTest, transformPost) { m::GraphPattern())); expectTransformPost( makePostRequest("/?default", "application/n-triples", " ."), + DEFAULT{}, m::UpdateClause( m::GraphUpdate( {}, 
{{iri(""), iri(""), iri(""), std::monostate{}}}, @@ -80,6 +47,7 @@ TEST(GraphStoreProtocolTest, transformPost) { m::GraphPattern())); expectTransformPost( makePostRequest("/?graph=bar", "application/n-triples", " ."), + iri(""), m::UpdateClause( m::GraphUpdate({}, {{iri(""), iri(""), iri(""), Iri("")}}, @@ -111,41 +79,39 @@ TEST(GraphStoreProtocolTest, transformPost) { // _____________________________________________________________________________________________ TEST(GraphStoreProtocolTest, transformGet) { - auto expectTransformGet = CPP_template_lambda()(typename RequestT)( - const RequestT& request, - const testing::Matcher& matcher, - ad_utility::source_location l = ad_utility::source_location::current())( - requires ad_utility::httpUtils::HttpRequest) { - auto trace = generateLocationTrace(l); - const ad_utility::url_parser::ParsedUrl parsedUrl = - ad_utility::url_parser::parseRequestTarget(request.target()); - const GraphOrDefault graph = - GraphStoreProtocol::extractTargetGraph(parsedUrl.parameters_); - EXPECT_THAT(GraphStoreProtocol::transformGet(graph), matcher); - }; + auto expectTransformGet = + [](const GraphOrDefault& graph, + const testing::Matcher& matcher, + ad_utility::source_location l = + ad_utility::source_location::current()) { + auto trace = generateLocationTrace(l); + EXPECT_THAT(GraphStoreProtocol::transformGet(graph), matcher); + }; expectTransformGet( - makeGetRequest("/?default"), - m::ConstructQuery({{Var{"?s"}, Var{"?p"}, Var{"?o"}}}, - m::GraphPattern(matchers::Triples({SparqlTriple( - TC(Var{"?s"}), "?p", TC(Var{"?o"}))})))); + DEFAULT{}, m::ConstructQuery( + {{Var{"?s"}, Var{"?p"}, Var{"?o"}}}, + m::GraphPattern(matchers::Triples( + {SparqlTriple(TC(Var{"?s"}), "?p", TC(Var{"?o"}))})))); expectTransformGet( - makeGetRequest("/?graph=foo"), - m::ConstructQuery({{Var{"?s"}, Var{"?p"}, Var{"?o"}}}, - m::GraphPattern(matchers::Triples({SparqlTriple( - TC(Var{"?s"}), "?p", TC(Var{"?o"}))})), - ScanSpecificationAsTripleComponent::Graphs{ 
- {TripleComponent(iri(""))}})); + iri(""), + m::ConstructQuery( + {{Var{"?s"}, Var{"?p"}, Var{"?o"}}}, + m::GraphPattern(m::GroupGraphPatternWithGraph( + iri(""), m::Triples({SparqlTriple(TC(Var{"?s"}), "?p", + TC(Var{"?o"}))}))))); } // _____________________________________________________________________________________________ TEST(GraphStoreProtocolTest, transformGraphStoreProtocol) { EXPECT_THAT(GraphStoreProtocol::transformGraphStoreProtocol( + GraphStoreOperation{DEFAULT{}}, ad_utility::testing::makeGetRequest("/?default")), m::ConstructQuery({{Var{"?s"}, Var{"?p"}, Var{"?o"}}}, m::GraphPattern(matchers::Triples({SparqlTriple( TC(Var{"?s"}), "?p", TC(Var{"?o"}))})))); EXPECT_THAT( GraphStoreProtocol::transformGraphStoreProtocol( + GraphStoreOperation{DEFAULT{}}, ad_utility::testing::makePostRequest( "/?default", "application/n-triples", " .")), m::UpdateClause(m::GraphUpdate({}, @@ -159,6 +125,7 @@ TEST(GraphStoreProtocolTest, transformGraphStoreProtocol) { auto trace = generateLocationTrace(l); AD_EXPECT_THROW_WITH_MESSAGE( GraphStoreProtocol::transformGraphStoreProtocol( + GraphStoreOperation{DEFAULT{}}, ad_utility::testing::makeRequest(method, "/?default")), testing::HasSubstr( absl::StrCat(std::string{boost::beast::http::to_string(method)}, @@ -170,6 +137,7 @@ TEST(GraphStoreProtocolTest, transformGraphStoreProtocol) { expectUnsupportedMethod(http::verb::patch); AD_EXPECT_THROW_WITH_MESSAGE( GraphStoreProtocol::transformGraphStoreProtocol( + GraphStoreOperation{DEFAULT{}}, ad_utility::testing::makeRequest(boost::beast::http::verb::connect, "/?default")), testing::HasSubstr("Unsupported HTTP method \"CONNECT\"")); diff --git a/test/ParsedRequestBuilderTest.cpp b/test/ParsedRequestBuilderTest.cpp new file mode 100644 index 0000000000..76e10069fa --- /dev/null +++ b/test/ParsedRequestBuilderTest.cpp @@ -0,0 +1,296 @@ +// Copyright 2024-2025, University of Freiburg, +// Chair of Algorithms and Data Structures. 
+// Author: Julian Mundhahs (mundhahj@tf.uni-freiburg.de) + +#include +#include +#include + +#include "util/GTestHelpers.h" +#include "util/HttpRequestHelpers.h" +#include "util/TypeIdentity.h" +#include "util/http/HttpUtils.h" +#include "util/http/UrlParser.h" + +using namespace ad_utility::use_type_identity; +using namespace ad_utility::url_parser; +using namespace ad_utility::url_parser::sparqlOperation; +using namespace ad_utility::testing; + +// _____________________________________________________________________________________________ +TEST(ParsedRequestBuilderTest, Constructor) { + auto expect = [](const auto& request, const std::string& path, + const ParamValueMap& params, + const ad_utility::source_location l = + ad_utility::source_location::current()) { + auto t = generateLocationTrace(l); + const auto builder = ParsedRequestBuilder(request); + EXPECT_THAT( + builder.parsedRequest_, + AllOf(AD_FIELD(ParsedRequest, path_, testing::Eq(path)), + AD_FIELD(ParsedRequest, accessToken_, testing::Eq(std::nullopt)), + AD_FIELD(ParsedRequest, parameters_, testing::Eq(params)), + AD_FIELD(ParsedRequest, operation_, + testing::VariantWith(None{})))); + }; + expect(makeGetRequest("/"), "/", {}); + expect(makeGetRequest("/default?graph=bar"), "/default", + {{"graph", {"bar"}}}); + expect(makeGetRequest("/api/foo?graph=bar&query=foo&graph=baz"), "/api/foo", + {{"graph", {"bar", "baz"}}, {"query", {"foo"}}}); +} + +// _____________________________________________________________________________________________ +TEST(ParsedRequestBuilderTest, extractAccessToken) { + auto expect = [](const auto& request, const std::optional& expected, + const ad_utility::source_location l = + ad_utility::source_location::current()) { + auto t = generateLocationTrace(l); + auto builder = ParsedRequestBuilder(request); + EXPECT_THAT(builder.parsedRequest_.accessToken_, testing::Eq(std::nullopt)); + builder.extractAccessToken(request); + EXPECT_THAT(builder.parsedRequest_.accessToken_, 
testing::Eq(expected)); + }; + expect(makeGetRequest("/"), std::nullopt); + expect(makeGetRequest("/?query=foo"), std::nullopt); + expect(makeGetRequest("/?query=foo&access-token=bar"), "bar"); + expect(makePostRequest("/?access-token=bar", + "application/x-www-form-urlencoded", "query=foo"), + "bar"); + expect( + makePostRequest("/?access-token=bar", "application/sparql-update", "foo"), + "bar"); + expect(makeRequest(http::verb::get, "/", + {{http::field::authorization, "Bearer bar"}}, ""), + "bar"); + expect(makeRequest(http::verb::post, "/", + {{http::field::authorization, "Bearer bar"}}, ""), + "bar"); +} + +// _____________________________________________________________________________________________ +TEST(ParsedRequestBuilderTest, extractDatasetClause) { + auto expect = [](const auto& request, TI, + const std::vector& expected, + const ad_utility::source_location l = + ad_utility::source_location::current()) { + auto t = generateLocationTrace(l); + auto builder = ParsedRequestBuilder(request); + // Initialize an empty operation with no dataset clauses set. 
+ builder.parsedRequest_.operation_ = T{"", {}}; + builder.extractDatasetClauses(); + EXPECT_THAT(builder.parsedRequest_.operation_, + testing::VariantWith( + AD_FIELD(T, datasetClauses_, testing::Eq(expected)))); + }; + auto Iri = ad_utility::triple_component::Iri::fromIriref; + expect(makeGetRequest("/"), ti, {}); + expect(makeGetRequest("/?default-graph-uri=foo"), ti, + {{Iri(""), false}}); + expect(makeGetRequest("/?named-graph-uri=bar"), ti, + {{Iri(""), true}}); + expect(makeGetRequest("/?default-graph-uri=foo&named-graph-uri=bar&using-" + "graph-uri=baz&using-named-graph-uri=abc"), + ti, {{Iri(""), false}, {Iri(""), true}}); + expect(makePostRequest("/?default-graph-uri=foo&named-graph-uri=bar&using-" + "graph-uri=baz&using-named-graph-uri=abc", + "", ""), + ti, {{Iri(""), false}, {Iri(""), true}}); +} + +// _____________________________________________________________________________________________ +TEST(ParsedRequestBuilderTest, extractOperationIfSpecified) { + auto expect = [](const auto& request, TI, + std::string_view paramName, + const Operation& expected, + const ad_utility::source_location l = + ad_utility::source_location::current()) { + auto t = generateLocationTrace(l); + auto builder = ParsedRequestBuilder(request); + EXPECT_THAT(builder.parsedRequest_.operation_, + testing::VariantWith(None{})); + // Initialize an empty operation with no dataset clauses set. 
+ builder.extractOperationIfSpecified(paramName); + EXPECT_THAT(builder.parsedRequest_.operation_, testing::Eq(expected)); + }; + expect(makeGetRequest("/"), ti, "query", None{}); + expect(makeGetRequest("/?query=foo"), ti, "update", None{}); + expect(makeGetRequest("/?query=foo"), ti, "query", Query{"foo", {}}); + expect(makePostRequest("/", "", ""), ti, "update", None{}); + expect(makePostRequest("/?update=bar", "", ""), ti, "update", + Update{"bar", {}}); +} + +// _____________________________________________________________________________________________ +TEST(ParsedRequestBuilderTest, isGraphStoreOperation) { + auto isGraphStoreOperation = + CPP_template_lambda()(typename RequestT)(const RequestT& request)( + requires ad_utility::httpUtils::HttpRequest) { + const auto builder = ParsedRequestBuilder(request); + return builder.isGraphStoreOperation(); + }; + EXPECT_THAT(isGraphStoreOperation(makeGetRequest("/")), testing::IsFalse()); + EXPECT_THAT( + isGraphStoreOperation(makeGetRequest("/?query=foo&access-token=bar")), + testing::IsFalse()); + EXPECT_THAT(isGraphStoreOperation(makeGetRequest("/?default")), + testing::IsTrue()); + EXPECT_THAT(isGraphStoreOperation(makeGetRequest("/?graph=foo")), + testing::IsTrue()); + EXPECT_THAT(isGraphStoreOperation( + makeGetRequest("/default?query=foo&access-token=bar")), + testing::IsFalse()); +} + +// _____________________________________________________________________________________________ +TEST(ParsedRequestBuilderTest, extractGraphStoreOperation) { + auto Iri = ad_utility::triple_component::Iri::fromIriref; + auto expect = [](const auto& request, const GraphOrDefault& graph, + const ad_utility::source_location l = + ad_utility::source_location::current()) { + auto t = generateLocationTrace(l); + auto builder = ParsedRequestBuilder(request); + EXPECT_THAT(builder.parsedRequest_.operation_, + testing::VariantWith(None{})); + builder.extractGraphStoreOperation(); + EXPECT_THAT(builder.parsedRequest_.operation_, + 
testing::VariantWith( + AD_FIELD(GraphStoreOperation, graph_, testing::Eq(graph)))); + }; + expect(makeGetRequest("/?default"), DEFAULT{}); + expect(makeGetRequest("/?graph=foo"), Iri("")); + expect(makePostRequest("/?default", "", ""), DEFAULT{}); + expect(makePostRequest("/?graph=bar", "", ""), Iri("")); + { + auto builder = ParsedRequestBuilder(makeGetRequest("/?default&graph=foo")); + AD_EXPECT_THROW_WITH_MESSAGE( + builder.extractGraphStoreOperation(), + testing::HasSubstr( + R"(Parameters "graph" and "default" must not be set at the same time.)")); + } + { + auto builder = ParsedRequestBuilder(makeGetRequest("/default")); + builder.parsedRequest_.operation_ = Query{"foo", {}}; + AD_EXPECT_THROW_WITH_MESSAGE(builder.extractGraphStoreOperation(), + testing::HasSubstr("")); + } +} + +// _____________________________________________________________________________________________ +TEST(ParsedRequestBuilderTest, parametersContain) { + auto builder = + ParsedRequestBuilder(makeGetRequest("/?query=foo&access-token=bar&baz")); + EXPECT_THAT(builder.parametersContain("query"), testing::IsTrue()); + EXPECT_THAT(builder.parametersContain("access-token"), testing::IsTrue()); + EXPECT_THAT(builder.parametersContain("baz"), testing::IsTrue()); + EXPECT_THAT(builder.parametersContain("default"), testing::IsFalse()); + EXPECT_THAT(builder.parametersContain("graph"), testing::IsFalse()); + builder.parsedRequest_.parameters_ = {{"graph", {"foo"}}}; + EXPECT_THAT(builder.parametersContain("query"), testing::IsFalse()); + EXPECT_THAT(builder.parametersContain("access-token"), testing::IsFalse()); + EXPECT_THAT(builder.parametersContain("baz"), testing::IsFalse()); + EXPECT_THAT(builder.parametersContain("default"), testing::IsFalse()); + EXPECT_THAT(builder.parametersContain("graph"), testing::IsTrue()); +} + +// _____________________________________________________________________________________________ +TEST(ParsedRequestBuilderTest, reportUnsupportedContentTypeIfGraphStore) 
{ + auto builderGraphStore = ParsedRequestBuilder(makeGetRequest("/?default")); + AD_EXPECT_THROW_WITH_MESSAGE( + builderGraphStore.reportUnsupportedContentTypeIfGraphStore( + "application/x-www-form-urlencoded"), + testing::HasSubstr("")); + auto builderQuery = ParsedRequestBuilder(makeGetRequest("/?query=foo")); + EXPECT_NO_THROW(builderQuery.reportUnsupportedContentTypeIfGraphStore( + "application/sparql-query")); +} + +// _____________________________________________________________________________________________ +TEST(ParsedRequestBuilderTest, parameterIsContainedExactlyOnce) { + auto builder = ParsedRequestBuilder( + makeGetRequest("/?query=foo&access-token=bar&baz&query=baz")); + EXPECT_THAT(builder.parameterIsContainedExactlyOnce("does-not-exist"), + testing::IsFalse()); + EXPECT_THAT(builder.parameterIsContainedExactlyOnce("access-token"), + testing::IsTrue()); + AD_EXPECT_THROW_WITH_MESSAGE( + builder.parameterIsContainedExactlyOnce("query"), + testing::HasSubstr( + "Parameter \"query\" must be given exactly once. Is: 2")); +} + +// _____________________________________________________________________________________________ +TEST(ParsedRequestBuilderTest, extractTargetGraph) { + auto Iri = ad_utility::triple_component::Iri::fromIriref; + const auto extractTargetGraph = ParsedRequestBuilder::extractTargetGraph; + // Equivalent to `/?default` + EXPECT_THAT(extractTargetGraph({{"default", {""}}}), DEFAULT{}); + // Equivalent to `/?graph=foo` + EXPECT_THAT(extractTargetGraph({{"graph", {"foo"}}}), Iri("")); + // Equivalent to `/?graph=foo&graph=bar` + AD_EXPECT_THROW_WITH_MESSAGE( + extractTargetGraph({{"graph", {"foo", "bar"}}}), + testing::HasSubstr( + "Parameter \"graph\" must be given exactly once. 
Is: 2")); + const std::string eitherDefaultOrGraphErrorMsg = + R"(Exactly one of the query parameters "default" or "graph" must be set to identify the graph for the graph store protocol request.)"; + // Equivalent to `/` or `/?` + AD_EXPECT_THROW_WITH_MESSAGE( + extractTargetGraph({}), testing::HasSubstr(eitherDefaultOrGraphErrorMsg)); + // Equivalent to `/?unrelated=a&unrelated=b` + AD_EXPECT_THROW_WITH_MESSAGE( + extractTargetGraph({{"unrelated", {"a", "b"}}}), + testing::HasSubstr(eitherDefaultOrGraphErrorMsg)); + // Equivalent to `/?default&graph=foo` + AD_EXPECT_THROW_WITH_MESSAGE( + extractTargetGraph({{"default", {""}}, {"graph", {"foo"}}}), + testing::HasSubstr(eitherDefaultOrGraphErrorMsg)); +} + +// _____________________________________________________________________________________________ +TEST(ParsedRequestBuilderTest, determineAccessToken) { + auto extract = + CPP_template_lambda()(typename RequestT)(const RequestT& request)( + requires ad_utility::httpUtils::HttpRequest) { + auto parsedUrl = parseRequestTarget(request.target()); + return ParsedRequestBuilder::determineAccessToken(request, + parsedUrl.parameters_); + }; + EXPECT_THAT(extract(makeGetRequest("/")), testing::Eq(std::nullopt)); + EXPECT_THAT(extract(makeGetRequest("/?access-token=foo")), + testing::Optional(testing::Eq("foo"))); + EXPECT_THAT( + extract(makeRequest(http::verb::get, "/", + {{http::field::authorization, "Bearer foo"}})), + testing::Optional(testing::Eq("foo"))); + EXPECT_THAT( + extract(makeRequest(http::verb::get, "/?access-token=foo", + {{http::field::authorization, "Bearer foo"}})), + testing::Optional(testing::Eq("foo"))); + AD_EXPECT_THROW_WITH_MESSAGE( + extract(makeRequest(http::verb::get, "/?access-token=bar", + {{http::field::authorization, "Bearer foo"}})), + testing::HasSubstr( + "Access token is specified both in the `Authorization` header and " + "by the `access-token` parameter, but they are not the same")); + AD_EXPECT_THROW_WITH_MESSAGE( + 
extract(makeRequest(http::verb::get, "/", + {{http::field::authorization, "foo"}})), + testing::HasSubstr( + "Authorization header doesn't start with \"Bearer \".")); + EXPECT_THAT(extract(makePostRequest("/", "text/turtle", "")), + testing::Eq(std::nullopt)); + EXPECT_THAT(extract(makePostRequest("/?access-token=foo", "text/turtle", "")), + testing::Optional(testing::Eq("foo"))); + AD_EXPECT_THROW_WITH_MESSAGE( + extract(makeRequest(http::verb::post, "/?access-token=bar", + {{http::field::authorization, "Bearer foo"}})), + testing::HasSubstr( + "Access token is specified both in the `Authorization` header and " + "by the `access-token` parameter, but they are not the same")); + AD_EXPECT_THROW_WITH_MESSAGE( + extract(makeRequest(http::verb::post, "/?access-token=bar", + {{http::field::authorization, "foo"}})), + testing::HasSubstr( + "Authorization header doesn't start with \"Bearer \".")); +} diff --git a/test/SPARQLProtocolTest.cpp b/test/SPARQLProtocolTest.cpp index fff0af6c8e..6fd6a3ec73 100644 --- a/test/SPARQLProtocolTest.cpp +++ b/test/SPARQLProtocolTest.cpp @@ -18,11 +18,10 @@ using namespace ad_utility::url_parser::sparqlOperation; using namespace ad_utility::testing; namespace { -auto ParsedRequestIs = [](const std::string& path, - const std::optional& accessToken, - const ParamValueMap& parameters, - const std::variant& operation) - -> testing::Matcher { +auto ParsedRequestIs = + [](const std::string& path, const std::optional& accessToken, + const ParamValueMap& parameters, + const Operation& operation) -> testing::Matcher { return testing::AllOf( AD_FIELD(ad_utility::url_parser::ParsedRequest, path_, testing::Eq(path)), AD_FIELD(ad_utility::url_parser::ParsedRequest, accessToken_, @@ -32,56 +31,203 @@ auto ParsedRequestIs = [](const std::string& path, AD_FIELD(ad_utility::url_parser::ParsedRequest, operation_, testing::Eq(operation))); }; +auto Iri = ad_utility::triple_component::Iri::fromIriref; + +const std::string URLENCODED_PLAIN = 
"application/x-www-form-urlencoded"; +const std::string URLENCODED = URLENCODED_PLAIN + ";charset=UTF-8"; +const std::string QUERY = "application/sparql-query"; +const std::string UPDATE = "application/sparql-update"; +const std::string TURTLE = "text/turtle"; + +auto testAccessTokenCombinations = [](auto parse, const http::verb& method, + std::string_view pathBase, + const Operation& expectedOperation, + const ad_utility::HashMap& + headers = {}, + const std::optional& body = + std::nullopt, + ad_utility::source_location l = + ad_utility::source_location:: + current()) { + auto t = generateLocationTrace(l); + // Test the cases: + // 1. No access token + // 2. Access token in query + // 3. Access token in `Authorization` header + // 4. Different access tokens + // 5. Same access token + boost::urls::url pathWithAccessToken{pathBase}; + pathWithAccessToken.params().append({"access-token", "foo"}); + ad_utility::HashMap headersWithDifferentAccessToken{ + headers}; + headersWithDifferentAccessToken.insert( + {http::field::authorization, "Bearer bar"}); + ad_utility::HashMap headersWithSameAccessToken{ + headers}; + headersWithSameAccessToken.insert({http::field::authorization, "Bearer foo"}); + EXPECT_THAT(parse(makeRequest(method, pathBase, headers, body)), + ParsedRequestIs("/", std::nullopt, {}, expectedOperation)); + EXPECT_THAT( + parse(makeRequest(method, pathWithAccessToken.buffer(), headers, body)), + ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, + expectedOperation)); + EXPECT_THAT(parse(makeRequest(method, pathBase, + headersWithDifferentAccessToken, body)), + ParsedRequestIs("/", "bar", {}, expectedOperation)); + EXPECT_THAT(parse(makeRequest(method, pathWithAccessToken.buffer(), + headersWithSameAccessToken, body)), + ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, + expectedOperation)); + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makeRequest(method, pathWithAccessToken.buffer(), + headersWithDifferentAccessToken, body)), + 
testing::HasSubstr("Access token is specified both in the " + "`Authorization` header and by the `access-token` " + "parameter, but they are not the same")); +}; +auto testAccessTokenCombinationsUrlEncoded = + [](auto parse, const std::string& bodyBase, + const Operation& expectedOperation, + ad_utility::source_location l = ad_utility::source_location::current()) { + auto t = generateLocationTrace(l); + // Test the cases: + // 1. No access token + // 2. Access token in query + // 3. Access token in `Authorization` header + // 4. Different access tokens + // 5. Same access token + boost::urls::url paramsWithAccessToken{absl::StrCat("/?", bodyBase)}; + paramsWithAccessToken.params().append({"access-token", "foo"}); + std::string bodyWithAccessToken{ + paramsWithAccessToken.encoded_params().buffer()}; + ad_utility::HashMap headers{ + {http::field::content_type, {URLENCODED}}}; + ad_utility::HashMap + headersWithDifferentAccessToken{ + {http::field::content_type, {URLENCODED}}, + {http::field::authorization, "Bearer bar"}}; + ad_utility::HashMap headersWithSameAccessToken{ + {http::field::content_type, {URLENCODED}}, + {http::field::authorization, "Bearer foo"}}; + EXPECT_THAT(parse(makeRequest(http::verb::post, "/", headers, bodyBase)), + ParsedRequestIs("/", std::nullopt, {}, expectedOperation)); + EXPECT_THAT(parse(makeRequest(http::verb::post, "/", headers, + bodyWithAccessToken)), + ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, + expectedOperation)); + EXPECT_THAT(parse(makeRequest(http::verb::post, "/", + headersWithDifferentAccessToken, bodyBase)), + ParsedRequestIs("/", "bar", {}, expectedOperation)); + EXPECT_THAT(parse(makeRequest(http::verb::post, "/", + headersWithSameAccessToken, bodyBase)), + ParsedRequestIs("/", "foo", {}, expectedOperation)); + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makeRequest(http::verb::post, "/", + headersWithDifferentAccessToken, + bodyWithAccessToken)), + testing::HasSubstr("Access token is specified both in the " + 
"`Authorization` header and by the `access-token` " + "parameter, but they are not the same")); + }; + } // namespace -TEST(SPARQLProtocolTest, parseHttpRequest) { +// _____________________________________________________________________________________________ +TEST(SPARQLProtocolTest, parseGET) { auto parse = CPP_template_lambda()(typename RequestT)(const RequestT& request)( requires ad_utility::httpUtils::HttpRequest) { - return SPARQLProtocol::parseHttpRequest(request); + return SPARQLProtocol::parseGET(request); }; - const std::string URLENCODED = - "application/x-www-form-urlencoded;charset=UTF-8"; - const std::string QUERY = "application/sparql-query"; - const std::string UPDATE = "application/sparql-update"; + // No SPARQL Operation EXPECT_THAT(parse(makeGetRequest("/")), ParsedRequestIs("/", std::nullopt, {}, None{})); EXPECT_THAT(parse(makeGetRequest("/ping")), ParsedRequestIs("/ping", std::nullopt, {}, None{})); EXPECT_THAT(parse(makeGetRequest("/?cmd=stats")), ParsedRequestIs("/", std::nullopt, {{"cmd", {"stats"}}}, None{})); + // Query EXPECT_THAT(parse(makeGetRequest( "/?query=SELECT+%2A%20WHERE%20%7B%7D&action=csv_export")), ParsedRequestIs("/", std::nullopt, {{"action", {"csv_export"}}}, Query{"SELECT * WHERE {}", {}})); + // Check that the correct datasets for the method (GET or POST) are added EXPECT_THAT( - parse(makePostRequest("/", URLENCODED, - "query=SELECT+%2A%20WHERE%20%7B%7D&send=100")), - ParsedRequestIs("/", std::nullopt, {{"send", {"100"}}}, - Query{"SELECT * WHERE {}", {}})); - AD_EXPECT_THROW_WITH_MESSAGE( - parse(makePostRequest("/", URLENCODED, - "ääär y=SELECT+%2A%20WHERE%20%7B%7D&send=100")), - ::testing::HasSubstr("Invalid URL-encoded POST request")); + parse(makeGetRequest("/?query=SELECT%20%2A%20WHERE%20%7B%7D&default-" + "graph-uri=foo&named-graph-uri=bar&using-graph-uri=" + "baz&using-named-graph-uri=cat")), + ParsedRequestIs("/", std::nullopt, + {{"default-graph-uri", {"foo"}}, + {"named-graph-uri", {"bar"}}, + 
{"using-graph-uri", {"baz"}}, + {"using-named-graph-uri", {"cat"}}}, + Query{"SELECT * WHERE {}", + {DatasetClause{Iri(""), false}, + DatasetClause{Iri(""), true}}})); + // Access token is read correctly + testAccessTokenCombinations(parse, http::verb::get, "/?query=a", + Query{"a", {}}); AD_EXPECT_THROW_WITH_MESSAGE( parse(makeGetRequest("/?query=SELECT%20%2A%20WHERE%20%7B%7D&query=SELECT%" "20%3Ffoo%20WHERE%20%7B%7D")), ::testing::StrEq( "Parameter \"query\" must be given exactly once. Is: 2")); + // Update (not allowed) AD_EXPECT_THROW_WITH_MESSAGE( - parse(makePostRequest("/", URLENCODED, - "query=SELECT%20%2A%20WHERE%20%7B%7D&update=DELETE%" - "20%7B%7D%20WHERE%20%7B%7D")), - ::testing::HasSubstr( - "Request must only contain one of \"query\" and \"update\".")); + parse(makeGetRequest("/?update=DELETE%20%2A%20WHERE%20%7B%7D")), + testing::StrEq("SPARQL Update is not allowed as GET request.")); + // Graph Store Operation + EXPECT_THAT(parse(makeGetRequest("/?graph=foo")), + ParsedRequestIs("/", std::nullopt, {{"graph", {"foo"}}}, + GraphStoreOperation{Iri("")})); + EXPECT_THAT(parse(makeGetRequest("/?default")), + ParsedRequestIs("/", std::nullopt, {{"default", {""}}}, + GraphStoreOperation{DEFAULT{}})); + EXPECT_THAT( + parse(makeGetRequest("/?default&access-token=foo&timeout=120s")), + ParsedRequestIs( + "/", "foo", + {{"access-token", {"foo"}}, {"default", {""}}, {"timeout", {"120s"}}}, + GraphStoreOperation{DEFAULT{}})); + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makeGetRequest("/?default&default")), + testing::HasSubstr("Parameter \"default\" must be " + "given exactly once. Is: 2")); + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makeGetRequest("/?graph=%3Cfoo%3E&graph=%3Cbar%3E")), + testing::HasSubstr("Parameter \"graph\" must be " + "given exactly once. 
Is: 2")); AD_EXPECT_THROW_WITH_MESSAGE( + parse(makeGetRequest("/?query=SELECT+%2A%20WHERE%20%7B%7D&graph=foo")), + testing::HasSubstr( + R"(Request contains parameters for both a SPARQL Query ("query") and a Graph Store Protocol operation ("graph" or "default").)")); + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makeGetRequest("/?query=SELECT+%2A%20WHERE%20%7B%7D&default")), + testing::HasSubstr( + R"(Request contains parameters for both a SPARQL Query ("query") and a Graph Store Protocol operation ("graph" or "default").)")); +} + +// _____________________________________________________________________________________________ +TEST(SPARQLProtocolTest, parseUrlencodedPOST) { + auto parse = + CPP_template_lambda()(typename RequestT)(const RequestT& request)( + requires ad_utility::httpUtils::HttpRequest) { + return SPARQLProtocol::parseUrlencodedPOST(request); + }; + + // No SPARQL Operation + EXPECT_THAT( + parse(makePostRequest("/", URLENCODED, "cmd=clear-cache")), + ParsedRequestIs("/", std::nullopt, {{"cmd", {"clear-cache"}}}, None{})); + // Query + EXPECT_THAT( parse(makePostRequest("/", URLENCODED, - "update=DELETE%20%7B%7D%20WHERE%20%7B%7D&update=" - "DELETE%20%7B%7D%20WHERE%20%7B%7D")), - ::testing::StrEq( - "Parameter \"update\" must be given exactly once. 
Is: 2")); + "query=SELECT+%2A%20WHERE%20%7B%7D&send=100")), + ParsedRequestIs("/", std::nullopt, {{"send", {"100"}}}, + Query{"SELECT * WHERE {}", {}})); EXPECT_THAT( - parse(makePostRequest("/", "application/x-www-form-urlencoded", + parse(makePostRequest("/", URLENCODED, "query=SELECT%20%2A%20WHERE%20%7B%7D&send=100")), ParsedRequestIs("/", std::nullopt, {{"send", {"100"}}}, Query{"SELECT * WHERE {}", {}})); @@ -89,7 +235,6 @@ TEST(SPARQLProtocolTest, parseHttpRequest) { parse(makePostRequest("/", URLENCODED, "query=SELECT%20%2A%20WHERE%20%7B%7D")), ParsedRequestIs("/", std::nullopt, {}, Query{"SELECT * WHERE {}", {}})); - auto Iri = ad_utility::triple_component::Iri::fromIriref; EXPECT_THAT( parse(makePostRequest( "/", URLENCODED, @@ -104,71 +249,6 @@ TEST(SPARQLProtocolTest, parseHttpRequest) { {DatasetClause{Iri(""), false}, DatasetClause{Iri(""), true}, DatasetClause{Iri(""), true}}})); - ; - AD_EXPECT_THROW_WITH_MESSAGE( - parse(makePostRequest("/?send=100", URLENCODED, - "query=SELECT%20%2A%20WHERE%20%7B%7D")), - testing::StrEq("URL-encoded POST requests must not contain query " - "parameters in the URL.")); - EXPECT_THAT( - parse(makePostRequest("/", URLENCODED, "cmd=clear-cache")), - ParsedRequestIs("/", std::nullopt, {{"cmd", {"clear-cache"}}}, None{})); - EXPECT_THAT( - parse(makePostRequest("/", QUERY, "SELECT * WHERE {}")), - ParsedRequestIs("/", std::nullopt, {}, Query{"SELECT * WHERE {}", {}})); - EXPECT_THAT(parse(makePostRequest("/?send=100", QUERY, "SELECT * WHERE {}")), - ParsedRequestIs("/", std::nullopt, {{"send", {"100"}}}, - Query{"SELECT * WHERE {}", {}})); - AD_EXPECT_THROW_WITH_MESSAGE( - parse(makeRequest(http::verb::patch, "/")), - testing::StrEq( - "Request method \"PATCH\" not supported (has to be GET or POST)")); - AD_EXPECT_THROW_WITH_MESSAGE( - parse(makePostRequest("/", "invalid/content-type", "")), - testing::StrEq( - "POST request with content type \"invalid/content-type\" not " - "supported (must be 
\"application/x-www-form-urlencoded\", " - "\"application/sparql-query\" or \"application/sparql-update\")")); - AD_EXPECT_THROW_WITH_MESSAGE( - parse(makeGetRequest("/?update=DELETE%20%2A%20WHERE%20%7B%7D")), - testing::StrEq("SPARQL Update is not allowed as GET request.")); - EXPECT_THAT( - parse(makePostRequest("/", UPDATE, "DELETE * WHERE {}")), - ParsedRequestIs("/", std::nullopt, {}, Update{"DELETE * WHERE {}", {}})); - EXPECT_THAT( - parse(makePostRequest("/", URLENCODED, - "update=DELETE%20%2A%20WHERE%20%7B%7D")), - ParsedRequestIs("/", std::nullopt, {}, Update{"DELETE * WHERE {}", {}})); - EXPECT_THAT( - parse( - makePostRequest("/", URLENCODED, "update=DELETE+%2A+WHERE%20%7B%7D")), - ParsedRequestIs("/", std::nullopt, {}, Update{"DELETE * WHERE {}", {}})); - // Check that the correct datasets for the method (GET or POST) are added - EXPECT_THAT( - parse(makeGetRequest("/?query=SELECT%20%2A%20WHERE%20%7B%7D&default-" - "graph-uri=foo&named-graph-uri=bar&using-graph-uri=" - "baz&using-named-graph-uri=cat")), - ParsedRequestIs("/", std::nullopt, - {{"default-graph-uri", {"foo"}}, - {"named-graph-uri", {"bar"}}, - {"using-graph-uri", {"baz"}}, - {"using-named-graph-uri", {"cat"}}}, - Query{"SELECT * WHERE {}", - {DatasetClause{Iri(""), false}, - DatasetClause{Iri(""), true}}})); - EXPECT_THAT( - parse(makePostRequest("/?default-" - "graph-uri=foo&named-graph-uri=bar&using-graph-uri=" - "baz&using-named-graph-uri=cat", - QUERY, "SELECT * WHERE {}")), - ParsedRequestIs("/", std::nullopt, - {{"default-graph-uri", {"foo"}}, - {"named-graph-uri", {"bar"}}, - {"using-graph-uri", {"baz"}}, - {"using-named-graph-uri", {"cat"}}}, - Query{"SELECT * WHERE {}", - {DatasetClause{Iri(""), false}, - DatasetClause{Iri(""), true}}})); EXPECT_THAT( parse(makePostRequest("/", URLENCODED, "query=SELECT%20%2A%20WHERE%20%7B%7D&default-graph-" @@ -182,6 +262,18 @@ TEST(SPARQLProtocolTest, parseHttpRequest) { Query{"SELECT * WHERE {}", {DatasetClause{Iri(""), false}, 
DatasetClause{Iri(""), true}}})); + testAccessTokenCombinationsUrlEncoded(parse, + "query=SELECT%20%2A%20WHERE%20%7B%7D", + Query{"SELECT * WHERE {}", {}}); + // Update + EXPECT_THAT( + parse(makePostRequest("/", URLENCODED, + "update=DELETE%20%2A%20WHERE%20%7B%7D")), + ParsedRequestIs("/", std::nullopt, {}, Update{"DELETE * WHERE {}", {}})); + EXPECT_THAT( + parse( + makePostRequest("/", URLENCODED, "update=DELETE+%2A+WHERE%20%7B%7D")), + ParsedRequestIs("/", std::nullopt, {}, Update{"DELETE * WHERE {}", {}})); EXPECT_THAT( parse(makePostRequest("/", URLENCODED, "update=INSERT%20DATA%20%7B%7D&default-graph-uri=" @@ -197,6 +289,102 @@ TEST(SPARQLProtocolTest, parseHttpRequest) { Update{"INSERT DATA {}", {DatasetClause{Iri(""), false}, DatasetClause{Iri(""), true}}})); + testAccessTokenCombinationsUrlEncoded(parse, "update=DELETE%20WHERE%20%7B%7D", + Update{"DELETE WHERE {}", {}}); + // Error conditions + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makePostRequest("/?send=100", URLENCODED, + "query=SELECT%20%2A%20WHERE%20%7B%7D")), + testing::StrEq("URL-encoded POST requests must not contain query " + "parameters in the URL.")); + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makePostRequest("/", URLENCODED, + "ääär y=SELECT+%2A%20WHERE%20%7B%7D&send=100")), + ::testing::HasSubstr("Invalid URL-encoded POST request")); + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makePostRequest("/", URLENCODED, + "query=SELECT%20%2A%20WHERE%20%7B%7D&update=DELETE%" + "20%7B%7D%20WHERE%20%7B%7D")), + ::testing::HasSubstr( + "Request must only contain one of \"query\" and \"update\".")); + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makePostRequest("/", URLENCODED, + "update=DELETE%20%7B%7D%20WHERE%20%7B%7D&update=" + "DELETE%20%7B%7D%20WHERE%20%7B%7D")), + ::testing::StrEq( + "Parameter \"update\" must be given exactly once. 
Is: 2")); + // Graph Store Protocol (not allowed) + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makePostRequest("/", URLENCODED, "graph=foo")), + testing::HasSubstr(absl::StrCat("Unsupported Content type \"", + URLENCODED_PLAIN, + "\" for Graph Store protocol."))); + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makePostRequest("/", URLENCODED, "default")), + testing::HasSubstr(absl::StrCat("Unsupported Content type \"", + URLENCODED_PLAIN, + "\" for Graph Store protocol."))); +} + +// _____________________________________________________________________________________________ +TEST(SPARQLProtocolTest, parseQueryPOST) { + auto parse = + CPP_template_lambda()(typename RequestT)(const RequestT& request)( + requires ad_utility::httpUtils::HttpRequest) { + return SPARQLProtocol::parseSPARQLPOST( + request, SPARQLProtocol::contentTypeSparqlQuery); + }; + + // Query + EXPECT_THAT( + parse(makePostRequest("/", QUERY, "SELECT * WHERE {}")), + ParsedRequestIs("/", std::nullopt, {}, Query{"SELECT * WHERE {}", {}})); + EXPECT_THAT(parse(makePostRequest("/?send=100", QUERY, "SELECT * WHERE {}")), + ParsedRequestIs("/", std::nullopt, {{"send", {"100"}}}, + Query{"SELECT * WHERE {}", {}})); + // Check that the correct datasets for the method (GET or POST) are added + EXPECT_THAT( + parse(makePostRequest("/?default-" + "graph-uri=foo&named-graph-uri=bar&using-graph-uri=" + "baz&using-named-graph-uri=cat", + QUERY, "SELECT * WHERE {}")), + ParsedRequestIs("/", std::nullopt, + {{"default-graph-uri", {"foo"}}, + {"named-graph-uri", {"bar"}}, + {"using-graph-uri", {"baz"}}, + {"using-named-graph-uri", {"cat"}}}, + Query{"SELECT * WHERE {}", + {DatasetClause{Iri(""), false}, + DatasetClause{Iri(""), true}}})); + // Access token is read correctly + testAccessTokenCombinations(parse, http::verb::post, "/", Query{"a", {}}, + {{http::field::content_type, QUERY}}, "a"); + // Graph Store Operation (not allowed) + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makePostRequest("/?graph=foo", QUERY, "")), + 
testing::HasSubstr( + "Unsupported Content type \"application/sparql-query\" for " + "Graph Store protocol.")); + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makePostRequest("/?default", QUERY, "")), + testing::HasSubstr( + "Unsupported Content type \"application/sparql-query\" for " + "Graph Store protocol.")); +} + +// _____________________________________________________________________________________________ +TEST(SPARQLProtocolTest, parseUpdatePOST) { + auto parse = + CPP_template_lambda()(typename RequestT)(const RequestT& request)( + requires ad_utility::httpUtils::HttpRequest) { + return SPARQLProtocol::parseSPARQLPOST( + request, SPARQLProtocol::contentTypeSparqlUpdate); + }; + + // Update + EXPECT_THAT( + parse(makePostRequest("/", UPDATE, "DELETE * WHERE {}")), + ParsedRequestIs("/", std::nullopt, {}, Update{"DELETE * WHERE {}", {}})); EXPECT_THAT( parse(makePostRequest( "/?default-graph-uri=foo&named-graph-uri=bar&using-graph-uri=baz&" @@ -212,153 +400,105 @@ TEST(SPARQLProtocolTest, parseHttpRequest) { Update{"INSERT DATA {}", {DatasetClause{Iri(""), false}, DatasetClause{Iri(""), true}}})); - auto testAccessTokenCombinations = - [&](const http::verb& method, std::string_view pathBase, - const std::variant& expectedOperation, - const ad_utility::HashMap& headers = {}, - const std::optional& body = std::nullopt, - ad_utility::source_location l = - ad_utility::source_location::current()) { - auto t = generateLocationTrace(l); - // Test the cases: - // 1. No access token - // 2. Access token in query - // 3. Access token in `Authorization` header - // 4. Different access tokens - // 5. 
Same access token - boost::urls::url pathWithAccessToken{pathBase}; - pathWithAccessToken.params().append({"access-token", "foo"}); - ad_utility::HashMap - headersWithDifferentAccessToken{headers}; - headersWithDifferentAccessToken.insert( - {http::field::authorization, "Bearer bar"}); - ad_utility::HashMap - headersWithSameAccessToken{headers}; - headersWithSameAccessToken.insert( - {http::field::authorization, "Bearer foo"}); - EXPECT_THAT(parse(makeRequest(method, pathBase, headers, body)), - ParsedRequestIs("/", std::nullopt, {}, expectedOperation)); - EXPECT_THAT(parse(makeRequest(method, pathWithAccessToken.buffer(), - headers, body)), - ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, - expectedOperation)); - EXPECT_THAT(parse(makeRequest(method, pathBase, - headersWithDifferentAccessToken, body)), - ParsedRequestIs("/", "bar", {}, expectedOperation)); - EXPECT_THAT(parse(makeRequest(method, pathWithAccessToken.buffer(), - headersWithSameAccessToken, body)), - ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, - expectedOperation)); - AD_EXPECT_THROW_WITH_MESSAGE( - parse(makeRequest(method, pathWithAccessToken.buffer(), - headersWithDifferentAccessToken, body)), - testing::HasSubstr( - "Access token is specified both in the " - "`Authorization` header and by the `access-token` " - "parameter, but they are not the same")); - }; - testAccessTokenCombinations(http::verb::get, "/?query=a", Query{"a", {}}); - testAccessTokenCombinations(http::verb::post, "/", Query{"a", {}}, - {{http::field::content_type, QUERY}}, "a"); - testAccessTokenCombinations(http::verb::post, "/", Update{"a", {}}, + // Access token is read correctly + testAccessTokenCombinations(parse, http::verb::post, "/", Update{"a", {}}, {{http::field::content_type, UPDATE}}, "a"); - auto testAccessTokenCombinationsUrlEncoded = - [&](const std::string& bodyBase, - const std::variant& expectedOperation, - ad_utility::source_location l = - ad_utility::source_location::current()) { - auto 
t = generateLocationTrace(l); - // Test the cases: - // 1. No access token - // 2. Access token in query - // 3. Access token in `Authorization` header - // 4. Different access tokens - // 5. Same access token - boost::urls::url paramsWithAccessToken{absl::StrCat("/?", bodyBase)}; - paramsWithAccessToken.params().append({"access-token", "foo"}); - std::string bodyWithAccessToken{ - paramsWithAccessToken.encoded_params().buffer()}; - ad_utility::HashMap headers{ - {http::field::content_type, {URLENCODED}}}; - ad_utility::HashMap - headersWithDifferentAccessToken{ - {http::field::content_type, {URLENCODED}}, - {http::field::authorization, "Bearer bar"}}; - ad_utility::HashMap - headersWithSameAccessToken{ - {http::field::content_type, {URLENCODED}}, - {http::field::authorization, "Bearer foo"}}; - EXPECT_THAT( - parse(makeRequest(http::verb::post, "/", headers, bodyBase)), - ParsedRequestIs("/", std::nullopt, {}, expectedOperation)); - EXPECT_THAT(parse(makeRequest(http::verb::post, "/", headers, - bodyWithAccessToken)), - ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, - expectedOperation)); - EXPECT_THAT( - parse(makeRequest(http::verb::post, "/", - headersWithDifferentAccessToken, bodyBase)), - ParsedRequestIs("/", "bar", {}, expectedOperation)); - EXPECT_THAT(parse(makeRequest(http::verb::post, "/", - headersWithSameAccessToken, bodyBase)), - ParsedRequestIs("/", "foo", {}, expectedOperation)); - AD_EXPECT_THROW_WITH_MESSAGE( - parse(makeRequest(http::verb::post, "/", - headersWithDifferentAccessToken, - bodyWithAccessToken)), - testing::HasSubstr( - "Access token is specified both in the " - "`Authorization` header and by the `access-token` " - "parameter, but they are not the same")); - }; - testAccessTokenCombinationsUrlEncoded("query=SELECT%20%2A%20WHERE%20%7B%7D", - Query{"SELECT * WHERE {}", {}}); - testAccessTokenCombinationsUrlEncoded("update=DELETE%20WHERE%20%7B%7D", - Update{"DELETE WHERE {}", {}}); + // Graph Store Protocol (not allowed) + 
AD_EXPECT_THROW_WITH_MESSAGE( + parse(makePostRequest("/?graph=foo", UPDATE, "")), + testing::HasSubstr( + "Unsupported Content type \"application/sparql-update\" for " + "Graph Store protocol.")); + AD_EXPECT_THROW_WITH_MESSAGE( + parse(makePostRequest("/?default", UPDATE, "")), + testing::HasSubstr( + "Unsupported Content type \"application/sparql-update\" for " + "Graph Store protocol.")); } -TEST(SPARQLProtocolTest, extractAccessToken) { - auto extract = +// _____________________________________________________________________________________________ +TEST(SPARQLProtocolTest, parsePOST) { + auto parse = CPP_template_lambda()(typename RequestT)(const RequestT& request)( requires ad_utility::httpUtils::HttpRequest) { - auto parsedUrl = parseRequestTarget(request.target()); - return SPARQLProtocol::extractAccessToken(request, parsedUrl.parameters_); + return SPARQLProtocol::parsePOST(request); }; - EXPECT_THAT(extract(makeGetRequest("/")), testing::Eq(std::nullopt)); - EXPECT_THAT(extract(makeGetRequest("/?access-token=foo")), - testing::Optional(testing::Eq("foo"))); + + // Query + EXPECT_THAT(parse(makePostRequest("/?access-token=foo", QUERY, + "SELECT * WHERE { ?s ?p ?o }")), + ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, + Query{"SELECT * WHERE { ?s ?p ?o }", {}})); EXPECT_THAT( - extract(makeRequest(http::verb::get, "/", - {{http::field::authorization, "Bearer foo"}})), - testing::Optional(testing::Eq("foo"))); + parse(makePostRequest("/", URLENCODED, "access-token=foo&query=bar")), + ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, + Query{"bar", {}})); + // Update + EXPECT_THAT(parse(makePostRequest("/?access-token=foo", UPDATE, + "INSERT DATA { }")), + ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, + Update{"INSERT DATA { }", {}})); + // Update + // Graph Store Operation EXPECT_THAT( - extract(makeRequest(http::verb::get, "/?access-token=foo", - {{http::field::authorization, "Bearer foo"}})), - 
testing::Optional(testing::Eq("foo"))); - AD_EXPECT_THROW_WITH_MESSAGE( - extract(makeRequest(http::verb::get, "/?access-token=bar", - {{http::field::authorization, "Bearer foo"}})), - testing::HasSubstr( - "Access token is specified both in the `Authorization` header and by " - "the `access-token` parameter, but they are not the same")); - AD_EXPECT_THROW_WITH_MESSAGE( - extract(makeRequest(http::verb::get, "/", - {{http::field::authorization, "foo"}})), - testing::HasSubstr( - "Authorization header doesn't start with \"Bearer \".")); - EXPECT_THAT(extract(makePostRequest("/", "text/turtle", "")), - testing::Eq(std::nullopt)); - EXPECT_THAT(extract(makePostRequest("/?access-token=foo", "text/turtle", "")), - testing::Optional(testing::Eq("foo"))); + parse(makePostRequest("/?default", TURTLE, " .")), + ParsedRequestIs("/", std::nullopt, {{"default", {""}}}, + GraphStoreOperation{DEFAULT{}})); + EXPECT_THAT( + parse(makePostRequest("/?graph=foo", TURTLE, " .")), + ParsedRequestIs("/", std::nullopt, {{"graph", {"foo"}}}, + GraphStoreOperation{Iri("")})); + EXPECT_THAT( + parse(makePostRequest("/?graph=foo&access-token=secret", TURTLE, + " .")), + ParsedRequestIs("/", {"secret"}, + {{"graph", {"foo"}}, {"access-token", {"secret"}}}, + GraphStoreOperation{Iri("")})); + EXPECT_THAT(parse(makePostRequest("/?default&access-token=foo", TURTLE, + " ")), + ParsedRequestIs("/", "foo", + {{"access-token", {"foo"}}, {"default", {""}}}, + GraphStoreOperation{DEFAULT{}})); + EXPECT_THAT(parse(makeRequest(http::verb::post, "/?default", + {{http::field::authorization, {"Bearer foo"}}, + {http::field::content_type, {TURTLE}}}, + " ")), + ParsedRequestIs("/", "foo", {{"default", {""}}}, + GraphStoreOperation{DEFAULT{}})); + // Unsupported content type AD_EXPECT_THROW_WITH_MESSAGE( - extract(makeRequest(http::verb::post, "/?access-token=bar", - {{http::field::authorization, "Bearer foo"}})), + parse(makeRequest( + http::verb::post, "/", + {{http::field::content_type, 
{"unsupported/content-type"}}}, "")), testing::HasSubstr( - "Access token is specified both in the `Authorization` header and by " - "the `access-token` parameter, but they are not the same")); + R"(POST request with content type "unsupported/content-type" not supported (must be Query/Update with content type "application/x-www-form-urlencoded", "application/sparql-query" or "application/sparql-update" or a valid graph store protocol POST request)")); +} + +// _____________________________________________________________________________________________ +TEST(SPARQLProtocolTest, parseHttpRequest) { + auto parse = + CPP_template_lambda()(typename RequestT)(const RequestT& request)( + requires ad_utility::httpUtils::HttpRequest) { + return SPARQLProtocol::parseHttpRequest(request); + }; + + // Query + EXPECT_THAT(parse(makeGetRequest("/?query=foo")), + ParsedRequestIs("/", std::nullopt, {}, Query{"foo", {}})); + EXPECT_THAT( + parse(makePostRequest("/", URLENCODED, "access-token=foo&query=bar")), + ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, + Query{"bar", {}})); + // Update + EXPECT_THAT(parse(makePostRequest("/?access-token=foo", UPDATE, + "INSERT DATA { }")), + ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, + Update{"INSERT DATA { }", {}})); + + // Unsupported HTTP Method AD_EXPECT_THROW_WITH_MESSAGE( - extract(makeRequest(http::verb::post, "/?access-token=bar", - {{http::field::authorization, "foo"}})), - testing::HasSubstr( - "Authorization header doesn't start with \"Bearer \".")); + parse(makeRequest(http::verb::patch, "/")), + testing::StrEq("Request method \"PATCH\" not supported (only GET and " + "POST are supported; PUT, DELETE, HEAD and PATCH for " + "graph store protocol are not yet supported)")); } diff --git a/test/SparqlParserTest.cpp b/test/SparqlParserTest.cpp index ee65899b72..d19fc8d196 100644 --- a/test/SparqlParserTest.cpp +++ b/test/SparqlParserTest.cpp @@ -1384,3 +1384,26 @@ TEST(ParserTest, BaseDeclaration) { "SELECT * 
WHERE { ?s ?p ?o }"), ::testing::HasSubstr("absolute IRI"), InvalidSparqlQueryException); } + +TEST(ParserTest, parseWithDatasets) { + auto Iri = ad_utility::triple_component::Iri::fromIriref; + auto query = "SELECT * WHERE { ?s ?p ?o }"; + auto queryGraphPatternMatcher = + m::GraphPattern(m::Triples({{Var("?s"), "?p", Var("?o")}})); + EXPECT_THAT(SparqlParser::parseQuery(query, {}), + m::SelectQuery(m::AsteriskSelect(), queryGraphPatternMatcher)); + EXPECT_THAT( + SparqlParser::parseQuery(query, {DatasetClause{Iri(""), true}}), + m::SelectQuery(m::AsteriskSelect(), queryGraphPatternMatcher, + std::nullopt, {{Iri("")}})); + EXPECT_THAT( + SparqlParser::parseQuery(query, {DatasetClause{Iri(""), false}}), + m::SelectQuery(m::AsteriskSelect(), queryGraphPatternMatcher, + {{Iri("")}}, std::nullopt)); + EXPECT_THAT( + SparqlParser::parseQuery(query, {DatasetClause{Iri(""), false}, + DatasetClause{Iri(""), true}, + DatasetClause{Iri(""), false}}), + m::SelectQuery(m::AsteriskSelect(), queryGraphPatternMatcher, + {{Iri(""), Iri("")}}, {{Iri("")}})); +}