Skip to content

Commit ab0385d

Browse files
authored
feat(picker): add HighestToScore Picker (#105)
1 parent db1509d commit ab0385d

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ pub mod prelude {
192192
pub use big_brain_derive::{ActionBuilder, ScorerBuilder};
193193
pub use evaluators::{Evaluator, LinearEvaluator, PowerEvaluator, SigmoidEvaluator};
194194
pub use measures::{ChebyshevDistance, Measure, WeightedProduct, WeightedSum};
195-
pub use pickers::{FirstToScore, Highest, Picker};
195+
pub use pickers::{FirstToScore, Highest, HighestToScore, Picker};
196196
pub use scorers::{
197197
AllOrNothing, EvaluatingScorer, FixedScore, MeasuredScorer, ProductOfScorers, Score,
198198
ScorerBuilder, SumOfScorers, WinningScorer,

src/pickers.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,45 @@ impl Picker for Highest {
8383
})
8484
}
8585
}
86+
87+
/// Picker that chooses the highest `Choice` with a [`Score`] higher than its
88+
/// configured `threshold`.
89+
///
90+
/// ### Example
91+
///
92+
/// ```
93+
/// # use big_brain::prelude::*;
94+
/// # fn main() {
95+
/// Thinker::build()
96+
/// .picker(HighestToScore::new(0.8))
97+
/// // .when(...)
98+
/// # ;
99+
/// # }
100+
/// ```
101+
#[derive(Debug, Clone, Default)]
102+
pub struct HighestToScore {
103+
pub threshold: f32,
104+
}
105+
106+
impl HighestToScore {
107+
pub fn new(threshold: f32) -> Self {
108+
Self { threshold }
109+
}
110+
}
111+
112+
impl Picker for HighestToScore {
113+
fn pick<'a>(&self, choices: &'a [Choice], scores: &Query<&Score>) -> Option<&'a Choice> {
114+
let mut highest_score = 0f32;
115+
116+
choices.iter().fold(None, |acc, choice| {
117+
let score = choice.calculate(scores);
118+
119+
if score <= self.threshold || score <= highest_score {
120+
return acc;
121+
}
122+
123+
highest_score = score;
124+
Some(choice)
125+
})
126+
}
127+
}

0 commit comments

Comments
 (0)