Skip to content

Estimators for Timeseries SSA / IID ChangepointDetection and SpikeDetection transforms #1254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 45 commits into from
Oct 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
f5e421e
Merge pull request #1 from dotnet/master
ganik Oct 10, 2018
facc77f
Made estimators for Timeseries SSA and IID ChangePointDetection, SSA …
ganik Oct 13, 2018
a586e83
Add more unit tests
ganik Oct 13, 2018
477f4eb
fix tests
ganik Oct 15, 2018
6b1b119
fix unit tests
ganik Oct 15, 2018
8647659
fix tests
ganik Oct 16, 2018
13f5f76
Merge pull request #2 from dotnet/master
ganik Oct 16, 2018
db25cb9
Merge branch 'master' into ganik/ssa
ganik Oct 16, 2018
e9c2d15
fix build
ganik Oct 16, 2018
da77379
fix comments
ganik Oct 17, 2018
ee97924
fix build
ganik Oct 17, 2018
a4853ec
fix unit tests
ganik Oct 17, 2018
663ca21
disabling SavePipeIidSpike and SavePipeIidChangePoint tests as they r…
ganik Oct 18, 2018
23aa3e3
fix typo
ganik Oct 18, 2018
8488192
fix comments
ganik Oct 18, 2018
1f61237
fix comments
ganik Oct 18, 2018
fad2656
fix comments
ganik Oct 18, 2018
ff90160
fix comments
ganik Oct 18, 2018
0c0b0ea
unit tests with invalid schema
ganik Oct 19, 2018
5ef10f3
remove unused type
ganik Oct 19, 2018
e67284c
fix comments
ganik Oct 20, 2018
99dd0e1
fix build
ganik Oct 21, 2018
4e67857
fix comments
ganik Oct 21, 2018
89a6d6a
fix build
ganik Oct 22, 2018
d75ec11
fix typo
ganik Oct 22, 2018
a563462
disable unit tests temporarily
ganik Oct 22, 2018
b9c9c49
fix comments
ganik Oct 28, 2018
3fa9f95
indent
ganik Oct 28, 2018
e140152
indent
ganik Oct 28, 2018
1aa7f55
Merge pull request #3 from dotnet/master
ganik Oct 28, 2018
1ccd234
Merge branch 'master' into ganik/ssa
ganik Oct 28, 2018
721e04f
fix tests
ganik Oct 29, 2018
6f92e4d
fix tests
ganik Oct 29, 2018
1b3c263
fix tests
ganik Oct 29, 2018
3f1f478
enable tests back
ganik Oct 29, 2018
28cc331
make Mapper private. Implement ITransformTemplate
ganik Oct 29, 2018
a6a03c7
Batch prediction test
ganik Oct 29, 2018
65108fd
fix comments
ganik Oct 29, 2018
48b0ddd
fix test
ganik Oct 30, 2018
78a0d72
fix comments
ganik Oct 30, 2018
58622f8
fix backward compat
ganik Oct 30, 2018
ba876d2
fix comments
ganik Oct 30, 2018
8320f24
fix comments
ganik Oct 30, 2018
546f2fa
Merge pull request #4 from dotnet/master
ganik Oct 31, 2018
c5c262e
Merge branch 'master' into ganik/ssa
ganik Oct 31, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,16 @@ private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "SSAMODLR",
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
//verWrittenCur: 0x00010001, // Initial
verWrittenCur: 0x00010002, // Added saving _state and _nextPrediction
Copy link
Contributor

@Zruty0 Zruty0 Oct 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

saving _state [](start = 52, length = 13)

Why do you need to do it now? Is this for the stateful row and spawning off new models? #Closed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is so that the pre-existing test TestEstimatorCore() passes on the new time-series estimators. One of the steps in TestEstimatorCore() is to save the transformer to disk and re-load it. Without this fix, saving and reloading a time-series SSA transformer changes its predictions, because _state and _nextPrediction are lost.


In reply to: 229477366 [](ancestors = 229477366)

verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(AdaptiveSingularSpectrumSequenceModeler).Assembly.FullName);
}

// Model version (0x00010002) from which _state and _nextPrediction are part of the binary format.
// The loading constructor compares ctx.Header.ModelVerReadable against this value to decide
// whether those two fields are present in the stream.
private const int VersionSavingStateAndPrediction = 0x00010002;

/// <summary>
/// The constructor for Adaptive SSA model.
/// </summary>
Expand Down Expand Up @@ -333,6 +336,7 @@ private AdaptiveSingularSpectrumSequenceModeler(AdaptiveSingularSpectrumSequence
_autoregressionNoiseVariance = model._autoregressionNoiseVariance;
_observationNoiseMean = model._observationNoiseMean;
_autoregressionNoiseMean = model._autoregressionNoiseMean;
_nextPrediction = model._nextPrediction;
_maxTrendRatio = model._maxTrendRatio;
_shouldStablize = model._shouldStablize;
_shouldMaintainInfo = model._shouldMaintainInfo;
Expand Down Expand Up @@ -368,11 +372,13 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
// RankSelectionMethod: _rankSelectionMethod
// bool: isWeightSet
// float[]: _alpha
// float[]: _state
// bool: ShouldComputeForecastIntervals
// float: _observationNoiseVariance
// float: _autoregressionNoiseVariance
// float: _observationNoiseMean
// float: _autoregressionNoiseMean
// float: _nextPrediction
// int: _maxRank
// bool: _shouldStablize
// bool: _shouldMaintainInfo
Expand Down Expand Up @@ -404,6 +410,14 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
_alpha = ctx.Reader.ReadFloatArray();
_host.CheckDecode(Utils.Size(_alpha) == _windowSize - 1);

if (ctx.Header.ModelVerReadable >= VersionSavingStateAndPrediction)
{
_state = ctx.Reader.ReadFloatArray();
_host.CheckDecode(Utils.Size(_state) == _windowSize - 1);
}
else
_state = new Single[_windowSize - 1];

ShouldComputeForecastIntervals = ctx.Reader.ReadBoolByte();

_observationNoiseVariance = ctx.Reader.ReadSingle();
Expand All @@ -414,6 +428,8 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo

_observationNoiseMean = ctx.Reader.ReadSingle();
_autoregressionNoiseMean = ctx.Reader.ReadSingle();
if (ctx.Header.ModelVerReadable >= VersionSavingStateAndPrediction)
_nextPrediction = ctx.Reader.ReadSingle();

_maxRank = ctx.Reader.ReadInt32();
_host.CheckDecode(1 <= _maxRank && _maxRank <= _windowSize - 1);
Expand Down Expand Up @@ -444,7 +460,7 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
}

_buffer = new FixedSizeQueue<Single>(_seriesLength);
_state = new Single[_windowSize - 1];

_x = new CpuAlignedVector(_windowSize, SseUtils.CbAlign);
_xSmooth = new CpuAlignedVector(_windowSize, SseUtils.CbAlign);
}
Expand Down Expand Up @@ -475,11 +491,13 @@ public void Save(ModelSaveContext ctx)
// RankSelectionMethod: _rankSelectionMethod
// bool: _isWeightSet
// float[]: _alpha
// float[]: _state
// bool: ShouldComputeForecastIntervals
// float: _observationNoiseVariance
// float: _autoregressionNoiseVariance
// float: _observationNoiseMean
// float: _autoregressionNoiseMean
// float: _nextPrediction
// int: _maxRank
// bool: _shouldStablize
// bool: _shouldMaintainInfo
Expand All @@ -494,11 +512,13 @@ public void Save(ModelSaveContext ctx)
ctx.Writer.Write((byte)_rankSelectionMethod);
ctx.Writer.WriteBoolByte(_wTrans != null);
ctx.Writer.WriteFloatArray(_alpha);
ctx.Writer.WriteFloatArray(_state);
ctx.Writer.WriteBoolByte(ShouldComputeForecastIntervals);
ctx.Writer.Write(_observationNoiseVariance);
ctx.Writer.Write(_autoregressionNoiseVariance);
ctx.Writer.Write(_observationNoiseMean);
ctx.Writer.Write(_autoregressionNoiseMean);
ctx.Writer.Write(_nextPrediction);
ctx.Writer.Write(_maxRank);
ctx.Writer.WriteBoolByte(_shouldStablize);
ctx.Writer.WriteBoolByte(_shouldMaintainInfo);
Expand Down
22 changes: 18 additions & 4 deletions src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,32 @@ namespace Microsoft.ML.Runtime.TimeSeriesProcessing
/// </summary>
public abstract class IidAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase<Single, IidAnomalyDetectionBase.State>
{
public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env, IDataView input)
: base(args, name, env, input)
public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env)
: base(args, name, env)
{
InitialWindowSize = 0;
}

public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name, IDataView input)
: base(env, ctx, name, input)
public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name)
: base(env, ctx, name)
{
Host.CheckDecode(InitialWindowSize == 0);
}

/// <summary>
/// Checks that <paramref name="inputSchema"/> contains the input column with type
/// <see cref="NumberType.R4"/>, then returns the schema obtained by applying this
/// transform to an empty data view built over that schema.
/// </summary>
/// <param name="inputSchema">The schema of the data the transform would be applied to.</param>
/// <returns>The schema of the transformed (empty) data view.</returns>
public override Schema GetOutputSchema(Schema inputSchema)
{
    Host.CheckValue(inputSchema, nameof(inputSchema));

    // The input column must exist ...
    bool hasInputColumn = inputSchema.TryGetColumnIndex(InputColumnName, out var inputIndex);
    if (!hasInputColumn)
    {
        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName);
    }

    // ... and must be of type R4.
    var actualType = inputSchema.GetColumnType(inputIndex);
    if (actualType != NumberType.R4)
    {
        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName,
            NumberType.R4.ToString(), actualType.ToString());
    }

    // Run the transform over an empty view purely to read the resulting schema.
    var emptyInput = new EmptyDataView(Host, inputSchema);
    return Transform(emptyInput).Schema;
}

public override void Save(ModelSaveContext ctx)
{
ctx.CheckAtModel();
Expand Down
119 changes: 108 additions & 11 deletions src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,35 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TimeSeriesProcessing;
using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase<System.Single, Microsoft.ML.Runtime.TimeSeriesProcessing.IidAnomalyDetectionBase.State>;

[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IidChangePointDetector), typeof(IidChangePointDetector.Arguments), typeof(SignatureDataTransform),
[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), typeof(IidChangePointDetector.Arguments), typeof(SignatureDataTransform),
IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature, IidChangePointDetector.ShortName)]
[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IidChangePointDetector), null, typeof(SignatureLoadDataTransform),

[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), null, typeof(SignatureLoadDataTransform),
IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)]

[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IidChangePointDetector), null, typeof(SignatureLoadModel),
IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)]

[assembly: LoadableClass(typeof(IRowMapper), typeof(IidChangePointDetector), null, typeof(SignatureLoadRowMapper),
IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)]

namespace Microsoft.ML.Runtime.TimeSeriesProcessing
{
/// <summary>
/// This class implements the change point detector transform for an i.i.d. sequence based on adaptive kernel density estimation and martingales.
/// </summary>
public sealed class IidChangePointDetector : IidAnomalyDetectionBase, ITransformTemplate
public sealed class IidChangePointDetector : IidAnomalyDetectionBase
{
internal const string Summary = "This transform detects the change-points in an i.i.d. sequence using adaptive kernel density estimation and martingales.";
public const string LoaderSignature = "IidChangePointDetector";
Expand Down Expand Up @@ -89,8 +100,18 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(IidChangePointDetector).Assembly.FullName);
}

public IidChangePointDetector(IHostEnvironment env, Arguments args, IDataView input)
: base(new BaseArguments(args), LoaderSignature, env, input)
// Factory method for SignatureDataTransform.
private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
    // env is validated first because the remaining checks are issued through it.
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(args, nameof(args));
    env.CheckValue(input, nameof(input));

    // Build the detector from its arguments, then bind it to the supplied data.
    var detector = new IidChangePointDetector(env, args);
    return detector.MakeDataTransform(input);
}

internal IidChangePointDetector(IHostEnvironment env, Arguments args)
: base(new BaseArguments(args), LoaderSignature, env)
{
switch (Martingale)
{
Expand All @@ -109,8 +130,28 @@ public IidChangePointDetector(IHostEnvironment env, Arguments args, IDataView in
}
}

public IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
: base(env, ctx, LoaderSignature, input)
// Factory method for SignatureLoadDataTransform.
private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
{
    // env is validated first because the remaining checks are issued through it.
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    env.CheckValue(input, nameof(input));

    // Deserialize the detector from the load context, then bind it to the supplied data.
    var detector = new IidChangePointDetector(env, ctx);
    return detector.MakeDataTransform(input);
}

// Factory method for SignatureLoadModel.
private static IidChangePointDetector Create(IHostEnvironment env, ModelLoadContext ctx)
{
    // env is validated first because the next check is issued through it.
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));

    // Verify the context against this class's version info before reading from it.
    ctx.CheckAtModel(GetVersionInfo());

    var detector = new IidChangePointDetector(env, ctx);
    return detector;
}

internal IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx)
: base(env, ctx, LoaderSignature)
{
// *** Binary format ***
// <base>
Expand All @@ -119,8 +160,8 @@ public IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx, IDataV
Host.CheckDecode(Side == AnomalySide.TwoSided);
}

private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector transform, IDataView newSource)
: base(new BaseArguments(transform), LoaderSignature, env, newSource)
private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector transform)
: base(new BaseArguments(transform), LoaderSignature, env)
{
}

Expand All @@ -139,9 +180,65 @@ public override void Save(ModelSaveContext ctx)
base.Save(ctx);
}

public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource)
// Factory method for SignatureLoadRowMapper.
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
{
    // Load the detector through the two-argument factory, then wrap it as a row mapper
    // for the given input schema.
    IidChangePointDetector detector = Create(env, ctx);
    return detector.MakeRowMapper(inputSchema);
}
}

/// <summary>
/// Estimator for <see cref="IidChangePointDetector"/>
/// </summary>
public sealed class IidChangePointEstimator : TrivialEstimator<IidChangePointDetector>
{
/// <summary>
/// Create a new instance of <see cref="IidChangePointEstimator"/> from individual settings.
/// The settings are packed into an <see cref="IidChangePointDetector.Arguments"/> object, which is used to
/// construct the <see cref="IidChangePointDetector"/> passed to the base class.
/// </summary>
/// <param name="env">Host Environment. Checked by <c>Contracts.CheckRef</c> before it is used to register the host.</param>
/// <param name="inputColumn">Name of the input column (assigned to <c>Arguments.Source</c>).</param>
/// <param name="outputColumn">The name of the new column (assigned to <c>Arguments.Name</c>).</param>
/// <param name="confidence">The confidence for change point detection in the range [0, 100].</param>
/// <param name="changeHistoryLength">The change history length.</param>
/// <param name="martingale">The martingale used for scoring. Defaults to <c>MartingaleType.Power</c>.</param>
/// <param name="eps">The epsilon parameter for the Power martingale (assigned to <c>Arguments.PowerMartingaleEpsilon</c>). Defaults to 0.1.</param>
public IidChangePointEstimator(IHostEnvironment env, string inputColumn, string outputColumn, int confidence,
int changeHistoryLength, MartingaleType martingale = MartingaleType.Power, double eps = 0.1)
// The host is registered under the estimator's own type name.
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)),
new IidChangePointDetector(env, new IidChangePointDetector.Arguments
{
Name = outputColumn,
Source = inputColumn,
Confidence = confidence,
ChangeHistoryLength = changeHistoryLength,
Martingale = martingale,
PowerMartingaleEpsilon = eps
}))
{
}

public IidChangePointEstimator(IHostEnvironment env, IidChangePointDetector.Arguments args)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)),
new IidChangePointDetector(env, args))
{
return new IidChangePointDetector(env, this, newSource);
}

/// <summary>
/// Computes the output schema shape: the input columns plus the transformer's output column,
/// declared as a vector of <see cref="NumberType.R8"/> carrying slot-name metadata.
/// An existing column with the output name is replaced.
/// </summary>
/// <param name="inputSchema">The schema shape of the incoming data.</param>
/// <returns>The schema shape after the transformer's output column has been added.</returns>
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    Host.CheckValue(inputSchema, nameof(inputSchema));

    // The input column must be present and of type R4.
    if (!inputSchema.TryFindColumn(Transformer.InputColumnName, out var inputColumn))
    {
        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName);
    }
    if (inputColumn.ItemType != NumberType.R4)
    {
        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName,
            NumberType.R4.ToString(), inputColumn.GetTypeString());
    }

    // Metadata attached to the output column: text-typed slot names.
    var slotNames = new SchemaShape.Column(
        MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false);
    var outputMetadata = new List<SchemaShape.Column> { slotNames };

    // Start from the input columns keyed by name, then add/overwrite the output column.
    var columnsByName = inputSchema.Columns.ToDictionary(column => column.Name);
    var outputColumn = new SchemaShape.Column(
        Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false,
        new SchemaShape(outputMetadata));
    columnsByName[Transformer.OutputColumnName] = outputColumn;

    return new SchemaShape(columnsByName.Values);
}
}
}
Loading