Skip to content

Commit

Permalink
Merge pull request #33264 from vespa-engine/toregge/adust-query-envir…
Browse files Browse the repository at this point in the history
…onment-api-to-return-field-length-info-instead-of-average-field-length

Adjust query environment API to return search::index::FieldLengthInfo
  • Loading branch information
toregge authored Feb 5, 2025
2 parents 4d3c2c0 + b69fdb3 commit 5b05733
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ QueryEnvironment::getAttributeContext() const
return _attrContext;
}

double
QueryEnvironment::get_average_field_length(const std::string &field_name) const
search::index::FieldLengthInfo
QueryEnvironment::get_field_length_info(const std::string &field_name) const
{
return _field_length_inspector.get_field_length_info(field_name).get_average_field_length();
return _field_length_inspector.get_field_length_info(field_name);
}

const search::fef::IIndexEnvironment &
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class QueryEnvironment : public search::fef::IQueryEnvironment
// inherited from search::fef::IQueryEnvironment
const search::attribute::IAttributeContext & getAttributeContext() const override;

double get_average_field_length(const std::string &field_name) const override;
search::index::FieldLengthInfo get_field_length_info(const std::string &field_name) const override;

// inherited from search::fef::IQueryEnvironment
const search::fef::IIndexEnvironment & getIndexEnvironment() const override;
Expand Down
5 changes: 3 additions & 2 deletions searchlib/src/tests/features/bm25/bm25_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,11 @@ TEST_F(Bm25BlueprintTest, blueprint_can_prepare_shared_state_with_average_field_
{
auto blueprint = expect_setup_succeed({"is"});
test::QueryEnvironment query_env;
query_env.get_avg_field_lengths()["is"] = 10;
query_env.get_field_length_info_map()["is"] =
search::index::FieldLengthInfo(10.0, 10.0, 1);
ObjectStore store;
blueprint->prepareSharedState(query_env, store);
EXPECT_DOUBLE_EQ(10, as_value<double>(*store.get("bm25.afl.is")));
EXPECT_DOUBLE_EQ(10.0, as_value<double>(*store.get("bm25.afl.is")));
}

TEST_F(Bm25BlueprintTest, dump_features_for_all_index_fields)
Expand Down
11 changes: 9 additions & 2 deletions searchlib/src/vespa/searchlib/features/bm25_feature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,21 @@ make_avg_field_length_key(const std::string& base_name, const std::string& field
return base_name + ".afl." + field_name;
}

/**
 * Looks up the field length info for the given field in the query
 * environment and returns its average field length.
 */
double
get_average_field_length(const search::fef::IQueryEnvironment& env, const std::string& field_name)
{
    return env.get_field_length_info(field_name).get_average_field_length();
}

}

void
Bm25Blueprint::prepareSharedState(const fef::IQueryEnvironment& env, fef::IObjectStore& store) const
{
std::string key = make_avg_field_length_key(getBaseName(), _field->name());
if (store.get(key) == nullptr) {
double avg_field_length = _avg_field_length.value_or(env.get_average_field_length(_field->name()));
double avg_field_length = _avg_field_length.value_or(get_average_field_length(env, _field->name()));
store.add(key, std::make_unique<AnyWrapper<double>>(avg_field_length));
}
}
Expand All @@ -219,7 +226,7 @@ Bm25Blueprint::createExecutor(const fef::IQueryEnvironment& env, vespalib::Stash
const auto* lookup_result = env.getObjectStore().get(make_avg_field_length_key(getBaseName(), _field->name()));
double avg_field_length = lookup_result != nullptr ?
as_value<double>(*lookup_result) :
_avg_field_length.value_or(env.get_average_field_length(_field->name()));
_avg_field_length.value_or(get_average_field_length(env, _field->name()));
return stash.create<Bm25Executor>(*_field, env, avg_field_length, _k1_param, _b_param);
}

Expand Down
3 changes: 2 additions & 1 deletion searchlib/src/vespa/searchlib/fef/iqueryenvironment.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "iindexenvironment.h"
#include "objectstore.h"
#include <vespa/searchcommon/attribute/iattributecontext.h>
#include <vespa/searchlib/index/field_length_info.h>

namespace search::common { struct GeoLocationSpec; }

Expand Down Expand Up @@ -80,7 +81,7 @@ class IQueryEnvironment
*
* @return field length info for the given field
**/
virtual double get_average_field_length(const std::string &field_name) const = 0;
virtual index::FieldLengthInfo get_field_length_info(const std::string &field_name) const = 0;

/**
* Returns a const view of the index environment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class PhraseSplitterQueryEnv : public IQueryEnvironment
return _queryEnv.getAllLocations();
}
const attribute::IAttributeContext & getAttributeContext() const override { return _queryEnv.getAttributeContext(); }
double get_average_field_length(const std::string &field_name) const override { return _queryEnv.get_average_field_length(field_name); }
index::FieldLengthInfo get_field_length_info(const std::string &field_name) const override { return _queryEnv.get_field_length_info(field_name); }
const IIndexEnvironment & getIndexEnvironment() const override { return _queryEnv.getIndexEnvironment(); }

// Accessor methods used by PhraseSplitter
Expand Down
12 changes: 6 additions & 6 deletions searchlib/src/vespa/searchlib/fef/test/queryenvironment.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class QueryEnvironment : public IQueryEnvironment
Properties _properties;
std::vector<GeoLocationSpec> _locations;
search::attribute::IAttributeContext::UP _attrCtx;
std::unordered_map<std::string, double> _avg_field_lengths;
std::unordered_map<std::string, index::FieldLengthInfo> _field_length_info;

public:
/**
Expand All @@ -48,12 +48,12 @@ class QueryEnvironment : public IQueryEnvironment
return locations;
}
const search::attribute::IAttributeContext &getAttributeContext() const override { return *_attrCtx; }
double get_average_field_length(const std::string& field_name) const override {
auto itr = _avg_field_lengths.find(field_name);
if (itr != _avg_field_lengths.end()) {
index::FieldLengthInfo get_field_length_info(const std::string& field_name) const override {
auto itr = _field_length_info.find(field_name);
if (itr != _field_length_info.end()) {
return itr->second;
}
return 1.0;
return index::FieldLengthInfo(1.0, 1.0, 1);
}
const IIndexEnvironment &getIndexEnvironment() const override { assert(_indexEnv != NULL); return *_indexEnv; }

Expand Down Expand Up @@ -92,7 +92,7 @@ class QueryEnvironment : public IQueryEnvironment
/** Adds a location to this. */
void addLocation(const GeoLocationSpec &spec) { _locations.push_back(spec); }

std::unordered_map<std::string, double>& get_avg_field_lengths() { return _avg_field_lengths; }
std::unordered_map<std::string, index::FieldLengthInfo>& get_field_length_info_map() { return _field_length_info; }
};

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ QueryEnvironmentBuilder::add_node(const FieldInfo &info)
QueryEnvironmentBuilder&
QueryEnvironmentBuilder::set_avg_field_length(const std::string& field_name, double avg_field_length)
{
_queryEnv.get_avg_field_lengths()[field_name] = avg_field_length;
_queryEnv.get_field_length_info_map()[field_name] = index::FieldLengthInfo(avg_field_length, avg_field_length, 1);
return *this;
}

Expand Down
4 changes: 3 additions & 1 deletion streamingvisitors/src/vespa/searchvisitor/queryenvironment.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ class QueryEnvironment : public search::fef::IQueryEnvironment

GeoLocationSpecPtrs getAllLocations() const override;
const search::attribute::IAttributeContext & getAttributeContext() const override { return *_attrCtx; }
double get_average_field_length(const std::string &) const override { return 100.0; }
search::index::FieldLengthInfo get_field_length_info(const std::string &) const override {
return search::index::FieldLengthInfo(100.0, 100.0, 1);
}
const search::fef::IIndexEnvironment & getIndexEnvironment() const override { return _indexEnv; }
void addTerm(const search::fef::ITermData *term) { _queryTerms.push_back(term); }

Expand Down

0 comments on commit 5b05733

Please sign in to comment.