@@ -2,12 +2,71 @@ use crate::correlation_models::{CorrelationModel, SquaredExponentialCorr};
2
2
use crate :: errors:: { GpError , Result } ;
3
3
use crate :: mean_models:: { ConstantMean , RegressionModel } ;
4
4
use linfa:: { Float , ParamGuard } ;
5
+ use std:: convert:: TryFrom ;
6
+
7
+ #[ cfg( feature = "serializable" ) ]
8
+ use serde:: { Deserialize , Serialize } ;
9
+
10
+ /// A structure to represent a n-dim parameter estimation
11
+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
12
+ #[ cfg_attr( feature = "serializable" , derive( Serialize , Deserialize ) ) ]
13
+ pub struct ParamTuning < F : Float > {
14
+ pub guess : Vec < F > ,
15
+ pub bounds : Vec < ( F , F ) > ,
16
+ }
17
+
18
+ impl < F : Float > TryFrom < ParamTuning < F > > for ThetaTuning < F > {
19
+ type Error = GpError ;
20
+ fn try_from ( pt : ParamTuning < F > ) -> Result < ThetaTuning < F > > {
21
+ if pt. guess . len ( ) != pt. bounds . len ( ) && ( pt. guess . len ( ) != 1 || pt. bounds . len ( ) != 1 ) {
22
+ return Err ( GpError :: InvalidValueError (
23
+ "Bad theta tuning specification" . to_string ( ) ,
24
+ ) ) ;
25
+ }
26
+ // TODO: check if guess in bounds
27
+ Ok ( ThetaTuning ( pt) )
28
+ }
29
+ }
30
+
31
+ /// As structure for theta hyperparameters guess
32
+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
33
+ #[ cfg_attr( feature = "serializable" , derive( Serialize , Deserialize ) ) ]
34
+
35
+ pub struct ThetaTuning < F : Float > ( ParamTuning < F > ) ;
36
+ impl < F : Float > Default for ThetaTuning < F > {
37
+ fn default ( ) -> ThetaTuning < F > {
38
+ ThetaTuning ( ParamTuning {
39
+ guess : vec ! [ F :: cast( 0.01 ) ] ,
40
+ bounds : vec ! [ ( F :: cast( 1e-6 ) , F :: cast( 1e2 ) ) ] ,
41
+ } )
42
+ }
43
+ }
44
+
45
+ impl < F : Float > From < ThetaTuning < F > > for ParamTuning < F > {
46
+ fn from ( tt : ThetaTuning < F > ) -> ParamTuning < F > {
47
+ ParamTuning {
48
+ guess : tt. 0 . guess ,
49
+ bounds : tt. 0 . bounds ,
50
+ }
51
+ }
52
+ }
53
+
54
+ impl < F : Float > ThetaTuning < F > {
55
+ pub fn theta0 ( & self ) -> & [ F ] {
56
+ & self . 0 . guess
57
+ }
58
+ pub fn bounds ( & self ) -> & [ ( F , F ) ] {
59
+ & self . 0 . bounds
60
+ }
61
+ }
5
62
6
63
/// A set of validated GP parameters.
7
64
#[ derive( Clone , Debug , PartialEq , Eq ) ]
8
65
pub struct GpValidParams < F : Float , Mean : RegressionModel < F > , Corr : CorrelationModel < F > > {
9
66
/// Parameter of the autocorrelation model
10
67
pub ( crate ) theta : Option < Vec < F > > ,
68
+ /// Parameter guess of the autocorrelation model
69
+ pub ( crate ) theta_tuning : ThetaTuning < F > ,
11
70
/// Regression model representing the mean(x)
12
71
pub ( crate ) mean : Mean ,
13
72
/// Correlation model representing the spatial correlation between errors at e(x) and e(x')
@@ -24,6 +83,7 @@ impl<F: Float> Default for GpValidParams<F, ConstantMean, SquaredExponentialCorr
24
83
fn default ( ) -> GpValidParams < F , ConstantMean , SquaredExponentialCorr > {
25
84
GpValidParams {
26
85
theta : None ,
86
+ theta_tuning : ThetaTuning :: default ( ) ,
27
87
mean : ConstantMean ( ) ,
28
88
corr : SquaredExponentialCorr ( ) ,
29
89
kpls_dim : None ,
@@ -34,11 +94,6 @@ impl<F: Float> Default for GpValidParams<F, ConstantMean, SquaredExponentialCorr
34
94
}
35
95
36
96
impl < F : Float , Mean : RegressionModel < F > , Corr : CorrelationModel < F > > GpValidParams < F , Mean , Corr > {
37
- /// Get starting theta value for optimization
38
- pub fn initial_theta ( & self ) -> & Option < Vec < F > > {
39
- & self . theta
40
- }
41
-
42
97
/// Get mean model
43
98
pub fn mean ( & self ) -> & Mean {
44
99
& self . mean
@@ -49,6 +104,11 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> GpValidParam
49
104
& self . corr
50
105
}
51
106
107
+ /// Get starting theta value for optimization
108
+ pub fn theta_tuning ( & self ) -> & ThetaTuning < F > {
109
+ & self . theta_tuning
110
+ }
111
+
52
112
/// Get number of components used by PLS
53
113
pub fn kpls_dim ( & self ) -> & Option < usize > {
54
114
& self . kpls_dim
@@ -77,6 +137,7 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> GpParams<F,
77
137
pub fn new ( mean : Mean , corr : Corr ) -> GpParams < F , Mean , Corr > {
78
138
Self ( GpValidParams {
79
139
theta : None ,
140
+ theta_tuning : ThetaTuning :: default ( ) ,
80
141
mean,
81
142
corr,
82
143
kpls_dim : None ,
@@ -85,14 +146,6 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> GpParams<F,
85
146
} )
86
147
}
87
148
88
- /// Set initial value for theta hyper parameter.
89
- ///
90
- /// During training process, the internal optimization is started from `initial_theta`.
91
- pub fn initial_theta ( mut self , theta : Option < Vec < F > > ) -> Self {
92
- self . 0 . theta = theta;
93
- self
94
- }
95
-
96
149
/// Set mean model.
97
150
pub fn mean ( mut self , mean : Mean ) -> Self {
98
151
self . 0 . mean = mean;
@@ -112,6 +165,36 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> GpParams<F,
112
165
self
113
166
}
114
167
168
+ /// Set initial value for theta hyper parameter.
169
+ ///
170
+ /// During training process, the internal optimization is started from `theta_guess`.
171
+ pub fn theta_guess ( mut self , theta_guess : Vec < F > ) -> Self {
172
+ self . 0 . theta_tuning = ParamTuning {
173
+ guess : theta_guess,
174
+ ..ThetaTuning :: default ( ) . into ( )
175
+ }
176
+ . try_into ( )
177
+ . unwrap ( ) ;
178
+ self
179
+ }
180
+
181
+ /// Set theta hyper parameter search space.
182
+ pub fn theta_bounds ( mut self , theta_bounds : Vec < ( F , F ) > ) -> Self {
183
+ self . 0 . theta_tuning = ParamTuning {
184
+ bounds : theta_bounds,
185
+ ..ThetaTuning :: default ( ) . into ( )
186
+ }
187
+ . try_into ( )
188
+ . unwrap ( ) ;
189
+ self
190
+ }
191
+
192
+ /// Set theta hyper parameter tuning
193
+ pub fn theta_tuning ( mut self , theta_tuning : ThetaTuning < F > ) -> Self {
194
+ self . 0 . theta_tuning = theta_tuning;
195
+ self
196
+ }
197
+
115
198
/// Set the number of internal GP hyperparameter theta optimization restarts
116
199
pub fn n_start ( mut self , n_start : usize ) -> Self {
117
200
self . 0 . n_start = n_start;
@@ -136,18 +219,19 @@ impl<F: Float, Mean: RegressionModel<F>, Corr: CorrelationModel<F>> ParamGuard
136
219
fn check_ref ( & self ) -> Result < & Self :: Checked > {
137
220
if let Some ( d) = self . 0 . kpls_dim {
138
221
if d == 0 {
139
- return Err ( GpError :: InvalidValue ( "`kpls_dim` canot be 0!" . to_string ( ) ) ) ;
222
+ return Err ( GpError :: InvalidValueError (
223
+ "`kpls_dim` canot be 0!" . to_string ( ) ,
224
+ ) ) ;
140
225
}
141
- if let Some ( theta) = self . 0 . initial_theta ( ) {
142
- if theta. len ( ) > 1 && d > theta. len ( ) {
143
- return Err ( GpError :: InvalidValue ( format ! (
144
- "Dimension reduction ({}) should be smaller than expected
226
+ let theta = self . 0 . theta_tuning ( ) . theta0 ( ) ;
227
+ if theta. len ( ) > 1 && d > theta. len ( ) {
228
+ return Err ( GpError :: InvalidValueError ( format ! (
229
+ "Dimension reduction ({}) should be smaller than expected
145
230
training input size infered from given initial theta length ({})" ,
146
- d,
147
- theta. len( )
148
- ) ) ) ;
149
- } ;
150
- }
231
+ d,
232
+ theta. len( )
233
+ ) ) ) ;
234
+ } ;
151
235
}
152
236
Ok ( & self . 0 )
153
237
}
0 commit comments