Skip to content

Commit 1391107

Browse files
ganikshauheen
authored and committed
Estimators for Timeseries SSA / IID ChangepointDetection and SpikeDetection transforms (#1254)
1 parent a039462 commit 1391107

16 files changed

+1204
-180
lines changed

src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs

+23 -3
Original file line number | Diff line number | Diff line change
@@ -221,13 +221,16 @@ private static VersionInfo GetVersionInfo()
221221
{
222222
return new VersionInfo(
223223
modelSignature: "SSAMODLR",
224-
verWrittenCur: 0x00010001, // Initial
225-
verReadableCur: 0x00010001,
224+
//verWrittenCur: 0x00010001, // Initial
225+
verWrittenCur: 0x00010002, // Added saving _state and _nextPrediction
226+
verReadableCur: 0x00010002,
226227
verWeCanReadBack: 0x00010001,
227228
loaderSignature: LoaderSignature,
228229
loaderAssemblyName: typeof(AdaptiveSingularSpectrumSequenceModeler).Assembly.FullName);
229230
}
230231

232+
private const int VersionSavingStateAndPrediction = 0x00010002;
233+
231234
/// <summary>
232235
/// The constructor for Adaptive SSA model.
233236
/// </summary>
@@ -333,6 +336,7 @@ private AdaptiveSingularSpectrumSequenceModeler(AdaptiveSingularSpectrumSequence
333336
_autoregressionNoiseVariance = model._autoregressionNoiseVariance;
334337
_observationNoiseMean = model._observationNoiseMean;
335338
_autoregressionNoiseMean = model._autoregressionNoiseMean;
339+
_nextPrediction = model._nextPrediction;
336340
_maxTrendRatio = model._maxTrendRatio;
337341
_shouldStablize = model._shouldStablize;
338342
_shouldMaintainInfo = model._shouldMaintainInfo;
@@ -368,11 +372,13 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
368372
// RankSelectionMethod: _rankSelectionMethod
369373
// bool: isWeightSet
370374
// float[]: _alpha
375+
// float[]: _state
371376
// bool: ShouldComputeForecastIntervals
372377
// float: _observationNoiseVariance
373378
// float: _autoregressionNoiseVariance
374379
// float: _observationNoiseMean
375380
// float: _autoregressionNoiseMean
381+
// float: _nextPrediction
376382
// int: _maxRank
377383
// bool: _shouldStablize
378384
// bool: _shouldMaintainInfo
@@ -404,6 +410,14 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
404410
_alpha = ctx.Reader.ReadFloatArray();
405411
_host.CheckDecode(Utils.Size(_alpha) == _windowSize - 1);
406412

413+
if (ctx.Header.ModelVerReadable >= VersionSavingStateAndPrediction)
414+
{
415+
_state = ctx.Reader.ReadFloatArray();
416+
_host.CheckDecode(Utils.Size(_state) == _windowSize - 1);
417+
}
418+
else
419+
_state = new Single[_windowSize - 1];
420+
407421
ShouldComputeForecastIntervals = ctx.Reader.ReadBoolByte();
408422

409423
_observationNoiseVariance = ctx.Reader.ReadSingle();
@@ -414,6 +428,8 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
414428

415429
_observationNoiseMean = ctx.Reader.ReadSingle();
416430
_autoregressionNoiseMean = ctx.Reader.ReadSingle();
431+
if (ctx.Header.ModelVerReadable >= VersionSavingStateAndPrediction)
432+
_nextPrediction = ctx.Reader.ReadSingle();
417433

418434
_maxRank = ctx.Reader.ReadInt32();
419435
_host.CheckDecode(1 <= _maxRank && _maxRank <= _windowSize - 1);
@@ -444,7 +460,7 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
444460
}
445461

446462
_buffer = new FixedSizeQueue<Single>(_seriesLength);
447-
_state = new Single[_windowSize - 1];
463+
448464
_x = new CpuAlignedVector(_windowSize, SseUtils.CbAlign);
449465
_xSmooth = new CpuAlignedVector(_windowSize, SseUtils.CbAlign);
450466
}
@@ -475,11 +491,13 @@ public void Save(ModelSaveContext ctx)
475491
// RankSelectionMethod: _rankSelectionMethod
476492
// bool: _isWeightSet
477493
// float[]: _alpha
494+
// float[]: _state
478495
// bool: ShouldComputeForecastIntervals
479496
// float: _observationNoiseVariance
480497
// float: _autoregressionNoiseVariance
481498
// float: _observationNoiseMean
482499
// float: _autoregressionNoiseMean
500+
// float: _nextPrediction
483501
// int: _maxRank
484502
// bool: _shouldStablize
485503
// bool: _shouldMaintainInfo
@@ -494,11 +512,13 @@ public void Save(ModelSaveContext ctx)
494512
ctx.Writer.Write((byte)_rankSelectionMethod);
495513
ctx.Writer.WriteBoolByte(_wTrans != null);
496514
ctx.Writer.WriteFloatArray(_alpha);
515+
ctx.Writer.WriteFloatArray(_state);
497516
ctx.Writer.WriteBoolByte(ShouldComputeForecastIntervals);
498517
ctx.Writer.Write(_observationNoiseVariance);
499518
ctx.Writer.Write(_autoregressionNoiseVariance);
500519
ctx.Writer.Write(_observationNoiseMean);
501520
ctx.Writer.Write(_autoregressionNoiseMean);
521+
ctx.Writer.Write(_nextPrediction);
502522
ctx.Writer.Write(_maxRank);
503523
ctx.Writer.WriteBoolByte(_shouldStablize);
504524
ctx.Writer.WriteBoolByte(_shouldMaintainInfo);

src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs

+18 -4
Original file line number | Diff line number | Diff line change
@@ -15,18 +15,32 @@ namespace Microsoft.ML.Runtime.TimeSeriesProcessing
1515
/// </summary>
1616
public abstract class IidAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase<Single, IidAnomalyDetectionBase.State>
1717
{
18-
public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env, IDataView input)
19-
: base(args, name, env, input)
18+
public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env)
19+
: base(args, name, env)
2020
{
2121
InitialWindowSize = 0;
2222
}
2323

24-
public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name, IDataView input)
25-
: base(env, ctx, name, input)
24+
public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name)
25+
: base(env, ctx, name)
2626
{
2727
Host.CheckDecode(InitialWindowSize == 0);
2828
}
2929

30+
public override Schema GetOutputSchema(Schema inputSchema)
31+
{
32+
Host.CheckValue(inputSchema, nameof(inputSchema));
33+
34+
if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col))
35+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName);
36+
37+
var colType = inputSchema.GetColumnType(col);
38+
if (colType != NumberType.R4)
39+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, NumberType.R4.ToString(), colType.ToString());
40+
41+
return Transform(new EmptyDataView(Host, inputSchema)).Schema;
42+
}
43+
3044
public override void Save(ModelSaveContext ctx)
3145
{
3246
ctx.CheckAtModel();

src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs

+108 -11
Original file line number | Diff line number | Diff line change
@@ -3,24 +3,35 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
using Microsoft.ML.Core.Data;
69
using Microsoft.ML.Runtime;
710
using Microsoft.ML.Runtime.CommandLine;
811
using Microsoft.ML.Runtime.Data;
912
using Microsoft.ML.Runtime.EntryPoints;
1013
using Microsoft.ML.Runtime.Model;
1114
using Microsoft.ML.Runtime.TimeSeriesProcessing;
15+
using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase<System.Single, Microsoft.ML.Runtime.TimeSeriesProcessing.IidAnomalyDetectionBase.State>;
1216

13-
[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IidChangePointDetector), typeof(IidChangePointDetector.Arguments), typeof(SignatureDataTransform),
17+
[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), typeof(IidChangePointDetector.Arguments), typeof(SignatureDataTransform),
1418
IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature, IidChangePointDetector.ShortName)]
15-
[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IidChangePointDetector), null, typeof(SignatureLoadDataTransform),
19+
20+
[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), null, typeof(SignatureLoadDataTransform),
21+
IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)]
22+
23+
[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IidChangePointDetector), null, typeof(SignatureLoadModel),
1624
IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)]
1725

26+
[assembly: LoadableClass(typeof(IRowMapper), typeof(IidChangePointDetector), null, typeof(SignatureLoadRowMapper),
27+
IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)]
28+
1829
namespace Microsoft.ML.Runtime.TimeSeriesProcessing
1930
{
2031
/// <summary>
2132
/// This class implements the change point detector transform for an i.i.d. sequence based on adaptive kernel density estimation and martingales.
2233
/// </summary>
23-
public sealed class IidChangePointDetector : IidAnomalyDetectionBase, ITransformTemplate
34+
public sealed class IidChangePointDetector : IidAnomalyDetectionBase
2435
{
2536
internal const string Summary = "This transform detects the change-points in an i.i.d. sequence using adaptive kernel density estimation and martingales.";
2637
public const string LoaderSignature = "IidChangePointDetector";
@@ -89,8 +100,18 @@ private static VersionInfo GetVersionInfo()
89100
loaderAssemblyName: typeof(IidChangePointDetector).Assembly.FullName);
90101
}
91102

92-
public IidChangePointDetector(IHostEnvironment env, Arguments args, IDataView input)
93-
: base(new BaseArguments(args), LoaderSignature, env, input)
103+
// Factory method for SignatureDataTransform.
104+
private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
105+
{
106+
Contracts.CheckValue(env, nameof(env));
107+
env.CheckValue(args, nameof(args));
108+
env.CheckValue(input, nameof(input));
109+
110+
return new IidChangePointDetector(env, args).MakeDataTransform(input);
111+
}
112+
113+
internal IidChangePointDetector(IHostEnvironment env, Arguments args)
114+
: base(new BaseArguments(args), LoaderSignature, env)
94115
{
95116
switch (Martingale)
96117
{
@@ -109,8 +130,28 @@ public IidChangePointDetector(IHostEnvironment env, Arguments args, IDataView in
109130
}
110131
}
111132

112-
public IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
113-
: base(env, ctx, LoaderSignature, input)
133+
// Factory method for SignatureLoadDataTransform.
134+
private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
135+
{
136+
Contracts.CheckValue(env, nameof(env));
137+
env.CheckValue(ctx, nameof(ctx));
138+
env.CheckValue(input, nameof(input));
139+
140+
return new IidChangePointDetector(env, ctx).MakeDataTransform(input);
141+
}
142+
143+
// Factory method for SignatureLoadModel.
144+
private static IidChangePointDetector Create(IHostEnvironment env, ModelLoadContext ctx)
145+
{
146+
Contracts.CheckValue(env, nameof(env));
147+
env.CheckValue(ctx, nameof(ctx));
148+
ctx.CheckAtModel(GetVersionInfo());
149+
150+
return new IidChangePointDetector(env, ctx);
151+
}
152+
153+
internal IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx)
154+
: base(env, ctx, LoaderSignature)
114155
{
115156
// *** Binary format ***
116157
// <base>
@@ -119,8 +160,8 @@ public IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx, IDataV
119160
Host.CheckDecode(Side == AnomalySide.TwoSided);
120161
}
121162

122-
private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector transform, IDataView newSource)
123-
: base(new BaseArguments(transform), LoaderSignature, env, newSource)
163+
private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector transform)
164+
: base(new BaseArguments(transform), LoaderSignature, env)
124165
{
125166
}
126167

@@ -139,9 +180,65 @@ public override void Save(ModelSaveContext ctx)
139180
base.Save(ctx);
140181
}
141182

142-
public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource)
183+
// Factory method for SignatureLoadRowMapper.
184+
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
185+
=> Create(env, ctx).MakeRowMapper(inputSchema);
186+
}
187+
188+
/// <summary>
189+
/// Estimator for <see cref="IidChangePointDetector"/>
190+
/// </summary>
191+
public sealed class IidChangePointEstimator : TrivialEstimator<IidChangePointDetector>
192+
{
193+
/// <summary>
194+
/// Create a new instance of <see cref="IidChangePointEstimator"/>
195+
/// </summary>
196+
/// <param name="env">Host Environment.</param>
197+
/// <param name="inputColumn">Name of the input column.</param>
198+
/// <param name="outputColumn">The name of the new column.</param>
199+
/// <param name="confidence">The confidence for change point detection in the range [0, 100].</param>
200+
/// <param name="changeHistoryLength">The change history length.</param>
201+
/// <param name="martingale">The martingale used for scoring.</param>
202+
/// <param name="eps">The epsilon parameter for the Power martingale.</param>
203+
public IidChangePointEstimator(IHostEnvironment env, string inputColumn, string outputColumn, int confidence,
204+
int changeHistoryLength, MartingaleType martingale = MartingaleType.Power, double eps = 0.1)
205+
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)),
206+
new IidChangePointDetector(env, new IidChangePointDetector.Arguments
207+
{
208+
Name = outputColumn,
209+
Source = inputColumn,
210+
Confidence = confidence,
211+
ChangeHistoryLength = changeHistoryLength,
212+
Martingale = martingale,
213+
PowerMartingaleEpsilon = eps
214+
}))
215+
{
216+
}
217+
218+
public IidChangePointEstimator(IHostEnvironment env, IidChangePointDetector.Arguments args)
219+
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)),
220+
new IidChangePointDetector(env, args))
143221
{
144-
return new IidChangePointDetector(env, this, newSource);
222+
}
223+
224+
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
225+
{
226+
Host.CheckValue(inputSchema, nameof(inputSchema));
227+
228+
if (!inputSchema.TryFindColumn(Transformer.InputColumnName, out var col))
229+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName);
230+
if (col.ItemType != NumberType.R4)
231+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName, NumberType.R4.ToString(), col.GetTypeString());
232+
233+
var metadata = new List<SchemaShape.Column>() {
234+
new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
235+
};
236+
var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
237+
238+
resultDic[Transformer.OutputColumnName] = new SchemaShape.Column(
239+
Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
240+
241+
return new SchemaShape(resultDic.Values);
145242
}
146243
}
147244
}

0 commit comments

Comments
 (0)