Skip to content

Commit 6df24d4

Browse files
committed
Make suggest work in folded space
1 parent 1a6e114 commit 6df24d4

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

ego/src/egor_service.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ impl EgorServiceBuilder {
8787
Xoshiro256Plus::from_entropy()
8888
};
8989
EgorService {
90-
config: self.config.clone(),
90+
config: EgorConfig {
91+
xtypes: Some(continuous_xlimits_to_xtypes(xlimits)),
92+
..self.config.clone()
93+
},
9194
solver: EgorSolver::new(self.config, xlimits, rng),
9295
}
9396
}
@@ -102,7 +105,10 @@ impl EgorServiceBuilder {
102105
Xoshiro256Plus::from_entropy()
103106
};
104107
EgorService {
105-
config: self.config.clone(),
108+
config: EgorConfig {
109+
xtypes: Some(xtypes.to_vec()),
110+
..self.config.clone()
111+
},
106112
solver: EgorSolver::new_with_xtypes(self.config, xtypes, rng),
107113
}
108114
}
@@ -130,7 +136,11 @@ impl<SB: SurrogateBuilder> EgorService<SB> {
130136
x_data: &ArrayBase<impl Data<Elem = f64>, Ix2>,
131137
y_data: &ArrayBase<impl Data<Elem = f64>, Ix2>,
132138
) -> Array2<f64> {
133-
self.solver.suggest(x_data, y_data)
139+
let xtypes = self.config.xtypes.as_ref().unwrap();
140+
let x_data = unfold_with_enum_mask(xtypes, x_data);
141+
let x = self.solver.suggest(&x_data, y_data);
142+
let x = cast_to_discrete_values(xtypes, &x);
143+
fold_with_enum_index(xtypes, &x).to_owned()
134144
}
135145
}
136146

0 commit comments

Comments
 (0)