
Commit ea1e7ed

Improvements to WarpedGP class
1 parent 57c82be commit ea1e7ed

File tree

2 files changed: 75 additions (+), 37 deletions (-)

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/gp/AbstractGPRegressionModel.scala

15 additions (+)

@@ -28,6 +28,7 @@ import io.github.mandar2812.dynaml.algebra.PartitionedMatrixSolvers._
 import io.github.mandar2812.dynaml.kernels.{DiracKernel, LocalScalarKernel, SVMKernel}
 import io.github.mandar2812.dynaml.models.{ContinuousProcess, SecondOrderProcess}
 import io.github.mandar2812.dynaml.optimization.GloballyOptWithGrad
+import io.github.mandar2812.dynaml.pipes.DataPipe
 import io.github.mandar2812.dynaml.probability.MultGaussianPRV
 import org.apache.log4j.Logger

@@ -391,4 +392,18 @@ object AbstractGPRegressionModel {
     else if(order > 0 && ex == 0) new GPNarModel(order, cov, noise, data).asInstanceOf[M]
     else new GPNarXModel(order, ex, cov, noise, data).asInstanceOf[M]
   }
+
+  def apply[T, I](
+    cov: LocalScalarKernel[I],
+    noise: LocalScalarKernel[I])(
+    trainingdata: T, num: Int)(
+    implicit transform: DataPipe[T, Seq[(I, Double)]], ct: ClassTag[I]) =
+    new AbstractGPRegressionModel[T, I](cov, noise, trainingdata, num) {
+      /**
+        * Convert from the underlying data structure to
+        * Seq[(I, Y)] where I is the index set of the GP
+        * and Y is the value/label type.
+        **/
+      override def dataAsSeq(data: T) = transform(data)
+    }
 }
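
For context, a minimal usage sketch of the new apply factory (not part of the commit). It assumes the training data is already a Seq of (input, label) pairs, so the implicit DataPipe is just the identity, and that RBFKernel and DiracKernel in dynaml-core each take a single bandwidth / noise-level constructor argument:

import breeze.linalg.DenseVector
import io.github.mandar2812.dynaml.kernels.{DiracKernel, RBFKernel}
import io.github.mandar2812.dynaml.pipes.DataPipe

// Toy training set: two (input, label) pairs indexed by DenseVector[Double].
val data: Seq[(DenseVector[Double], Double)] =
  Seq((DenseVector(0.0), 0.1), (DenseVector(1.0), 0.9))

// The container is already Seq[(I, Double)], so the transform pipe is the identity.
implicit val transform = DataPipe((s: Seq[(DenseVector[Double], Double)]) => s)

// Construct an anonymous AbstractGPRegressionModel via the new factory.
val gp = AbstractGPRegressionModel(
  new RBFKernel(1.5), new DiracKernel(0.05))(
  data, data.length)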

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/models/gp/WarpedGP.scala

60 additions (+), 37 deletions (-)
@@ -5,7 +5,7 @@ import io.github.mandar2812.dynaml.algebra.{PartitionedMatrix, PartitionedVector
 import io.github.mandar2812.dynaml.analysis.{DifferentiableMap, PartitionedVectorField, PushforwardMap}
 import io.github.mandar2812.dynaml.models.{ContinuousProcess, SecondOrderProcess}
 import io.github.mandar2812.dynaml.optimization.GloballyOptWithGrad
-import io.github.mandar2812.dynaml.pipes.DataPipe
+import io.github.mandar2812.dynaml.pipes.{DataPipe, Encoder}
 import io.github.mandar2812.dynaml.probability.{E, MeasurableDistrRV}
 import io.github.mandar2812.dynaml.utils

@@ -15,8 +15,10 @@ import scala.reflect.ClassTag
  * Created by mandar on 02/01/2017.
  */
 abstract class WarpedGP[T, I](p: AbstractGPRegressionModel[T, I])(
-  wFuncT: PushforwardMap[Double, Double, Double])(
-  implicit ev: ClassTag[I], pf: PartitionedVectorField)
+  warpingFunc: PushforwardMap[Double, Double, Double])(
+  implicit ev: ClassTag[I],
+  pf: PartitionedVectorField,
+  transform: Encoder[T, Seq[(I, Double)]])
   extends ContinuousProcess[
     T, I, Double,
     MeasurableDistrRV[PartitionedVector, PartitionedVector, PartitionedMatrix]]
@@ -25,6 +27,41 @@ abstract class WarpedGP[T, I](p: AbstractGPRegressionModel[T, I])(
     MeasurableDistrRV[PartitionedVector, PartitionedVector, PartitionedMatrix]]
   with GloballyOptWithGrad {

+  /**
+    * The training data
+    **/
+  override protected val g: T = p.data
+
+  private val dataProcessPipe = transform >
+    DataPipe((s: Seq[(I, Double)]) => s.map(pattern => (pattern._1, warpingFunc.i(pattern._2)))) >
+    transform.i
+
+  val underlyingProcess =
+    AbstractGPRegressionModel[T, I](
+      p.covariance, p.noiseModel)(
+      dataProcessPipe(p.data), p.npoints)(transform, ev)
+
+
+  /**
+    * Mean Function: Takes a member of the index set (input)
+    * and returns the corresponding mean of the distribution
+    * corresponding to input.
+    **/
+  override val mean = p.mean
+  /**
+    * Underlying covariance function of the
+    * Gaussian Processes.
+    **/
+  override val covariance = p.covariance
+  /**
+    * Stores the names of the hyper-parameters
+    **/
+  override protected var hyper_parameters: List[String] = underlyingProcess._hyper_parameters
+  /**
+    * A Map which stores the current state of
+    * the system.
+    **/
+  override protected var current_state: Map[String, Double] = underlyingProcess._current_state

   //Define the default determinant implementation
   implicit val detImpl = DataPipe(
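
The reason transform is now an Encoder[T, Seq[(I, Double)]] rather than a plain DataPipe is visible in dataProcessPipe above: the data container must be decoded into (input, label) pairs, the labels pulled back through warpingFunc.i, and the result re-encoded with transform.i before being handed to the underlying GP. A hypothetical identity encoder for the case where the container already is Seq[(I, Double)], assuming Encoder can be built from a forward/reverse function pair in the same way as DataPipe:

import breeze.linalg.DenseVector
import io.github.mandar2812.dynaml.pipes.Encoder

type Data = Seq[(DenseVector[Double], Double)]

// Hypothetical identity encoder: decode and encode are both the identity,
// so transform.i simply returns the warped Seq[(I, Double)] unchanged.
implicit val transform: Encoder[Data, Data] =
  Encoder((s: Data) => s, (s: Data) => s)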
@@ -33,11 +70,11 @@ abstract class WarpedGP[T, I](p: AbstractGPRegressionModel[T, I])(
   //Define the push forward map for the multivariate case
   val wFuncPredDistr: PushforwardMap[PartitionedVector, PartitionedVector, PartitionedMatrix] =
     PushforwardMap(
-      DataPipe((v: PartitionedVector) => v.map(c => (c._1, c._2.map(wFuncT.run)))),
+      DataPipe((v: PartitionedVector) => v.map(c => (c._1, c._2.map(warpingFunc.run)))),
       DifferentiableMap(
-        (v: PartitionedVector) => v.map(c => (c._1, c._2.map(wFuncT.i.run))),
+        (v: PartitionedVector) => v.map(c => (c._1, c._2.map(warpingFunc.i.run))),
         (v: PartitionedVector) => new PartitionedMatrix(
-          v._data.map(l => ((l._1, l._1), diag(l._2.map(wFuncT.i.J)))) ++
+          v._data.map(l => ((l._1, l._1), diag(l._2.map(warpingFunc.i.J)))) ++
           utils.combine(Seq((0 until v.rowBlocks.toInt).toList, (0 until v.rowBlocks.toInt).toList))
             .map(c =>
               (c.head.toLong, c.last.toLong))
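
The constructor argument is renamed from wFuncT to the more descriptive warpingFunc: a scalar PushforwardMap whose forward map sends latent GP values to observed labels and whose inverse carries a Jacobian. As a hypothetical illustration (not from the commit), an exponential warp, mirroring the PushforwardMap(DataPipe, DifferentiableMap) construction used for wFuncPredDistr above and assuming any implicit algebra PushforwardMap requires for Double is in scope:

import io.github.mandar2812.dynaml.analysis.{DifferentiableMap, PushforwardMap}
import io.github.mandar2812.dynaml.pipes.DataPipe

// Forward map: latent z -> observed y = exp(z).
// Inverse map: y -> log(y), with Jacobian d(log y)/dy = 1/y.
val expWarp: PushforwardMap[Double, Double, Double] =
  PushforwardMap(
    DataPipe((z: Double) => math.exp(z)),
    DifferentiableMap(
      (y: Double) => math.log(y),
      (y: Double) => 1.0/y)
  )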
@@ -52,28 +89,10 @@ abstract class WarpedGP[T, I](p: AbstractGPRegressionModel[T, I])(
     * 2) Y- : The lower error bar estimate (mean - sigma*stdDeviation)
     * 3) Y+ : The upper error bar. (mean + sigma*stdDeviation)
     **/
-  override def predictionWithErrorBars[U <: Seq[I]](testData: U, sigma: Int) = ???
-
-  /**
-    * Mean Function: Takes a member of the index set (input)
-    * and returns the corresponding mean of the distribution
-    * corresponding to input.
-    **/
-  override val mean = p.mean
-  /**
-    * Underlying covariance function of the
-    * Gaussian Processes.
-    **/
-  override val covariance = p.covariance
-  /**
-    * Stores the names of the hyper-parameters
-    **/
-  override protected var hyper_parameters: List[String] = p._hyper_parameters
-  /**
-    * A Map which stores the current state of
-    * the system.
-    **/
-  override protected var current_state: Map[String, Double] = p._current_state
+  override def predictionWithErrorBars[U <: Seq[I]](testData: U, sigma: Int) =
+    underlyingProcess
+      .predictionWithErrorBars(testData, sigma)
+      .map(d => (d._1, warpingFunc(d._2), warpingFunc(d._3), warpingFunc(d._4)))

   /**
     * Calculates the energy of the configuration,
@@ -86,33 +105,37 @@ abstract class WarpedGP[T, I](p: AbstractGPRegressionModel[T, I])(
     * @param options Optional parameters about configuration
     * @return Configuration Energy E(h)
     **/
-  override def energy(h: Map[String, Double], options: Map[String, String]) = p.energy(h, options)
+  override def energy(h: Map[String, Double], options: Map[String, String]) = {
+    val trainingLabels = PartitionedVector(
+      dataAsSeq(g).toStream.map(_._2),
+      underlyingProcess.npoints.toLong, underlyingProcess._blockSize
+    )
+
+    detImpl(wFuncPredDistr.i.J(trainingLabels))*underlyingProcess.energy(h, options)
+  }
+

   /** Calculates posterior predictive distribution for
     * a particular set of test data points.
     *
     * @param test A Sequence or Sequence like data structure
     * storing the values of the input patters.
     **/
-  override def predictiveDistribution[U <: Seq[I]](test: U) = wFuncPredDistr -> p.predictiveDistribution(test)
+  override def predictiveDistribution[U <: Seq[I]](test: U) =
+    wFuncPredDistr -> underlyingProcess.predictiveDistribution(test)

   /**
     * Convert from the underlying data structure to
     * Seq[(I, Y)] where I is the index set of the GP
     * and Y is the value/label type.
     **/
-  override def dataAsSeq(data: T) = p.dataAsSeq(data)
-
-  /**
-    * The training data
-    **/
-  override protected val g: T = p.data
+  override def dataAsSeq(data: T) = transform(data)

   /**
     * Predict the value of the
     * target variable given a
     * point.
     *
     **/
-  override def predict(point: I) = wFuncT(p.predictionWithErrorBars(Seq(point), 1).head._2)
+  override def predict(point: I) = warpingFunc(underlyingProcess.predictionWithErrorBars(Seq(point), 1).head._2)
 }
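
For reference, the Jacobian factor that now appears in energy comes from the usual change of variables for a warped GP: if the observed labels are y = g(z), with the latent z drawn from the underlying Gaussian process with mean m and covariance K, then

p(\mathbf{y}) \;=\; \mathcal{N}\!\left(g^{-1}(\mathbf{y}) \,\middle|\, \mathbf{m}, \mathbf{K}\right)\, \left|\det J_{g^{-1}}(\mathbf{y})\right|

i.e. the likelihood of the observations is the latent GP likelihood scaled by the determinant of the inverse warp's Jacobian evaluated at the training labels. The expression detImpl(wFuncPredDistr.i.J(trainingLabels)) computes this determinant, and the new energy multiplies it with underlyingProcess.energy(h, options).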
