diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs
index 5c01472ed9..6046fabc72 100644
--- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs
+++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs
@@ -221,13 +221,16 @@ private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "SSAMODLR",
- verWrittenCur: 0x00010001, // Initial
- verReadableCur: 0x00010001,
+ //verWrittenCur: 0x00010001, // Initial
+ verWrittenCur: 0x00010002, // Added saving _state and _nextPrediction
+ verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(AdaptiveSingularSpectrumSequenceModeler).Assembly.FullName);
}
+ private const int VersionSavingStateAndPrediction = 0x00010002;
+
///
/// The constructor for Adaptive SSA model.
///
@@ -333,6 +336,7 @@ private AdaptiveSingularSpectrumSequenceModeler(AdaptiveSingularSpectrumSequence
_autoregressionNoiseVariance = model._autoregressionNoiseVariance;
_observationNoiseMean = model._observationNoiseMean;
_autoregressionNoiseMean = model._autoregressionNoiseMean;
+ _nextPrediction = model._nextPrediction;
_maxTrendRatio = model._maxTrendRatio;
_shouldStablize = model._shouldStablize;
_shouldMaintainInfo = model._shouldMaintainInfo;
@@ -368,11 +372,13 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
// RankSelectionMethod: _rankSelectionMethod
// bool: isWeightSet
// float[]: _alpha
+ // float[]: _state
// bool: ShouldComputeForecastIntervals
// float: _observationNoiseVariance
// float: _autoregressionNoiseVariance
// float: _observationNoiseMean
// float: _autoregressionNoiseMean
+ // float: _nextPrediction
// int: _maxRank
// bool: _shouldStablize
// bool: _shouldMaintainInfo
@@ -404,6 +410,14 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
_alpha = ctx.Reader.ReadFloatArray();
_host.CheckDecode(Utils.Size(_alpha) == _windowSize - 1);
+ if (ctx.Header.ModelVerReadable >= VersionSavingStateAndPrediction)
+ {
+ _state = ctx.Reader.ReadFloatArray();
+ _host.CheckDecode(Utils.Size(_state) == _windowSize - 1);
+ }
+ else
+ _state = new Single[_windowSize - 1];
+
ShouldComputeForecastIntervals = ctx.Reader.ReadBoolByte();
_observationNoiseVariance = ctx.Reader.ReadSingle();
@@ -414,6 +428,8 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
_observationNoiseMean = ctx.Reader.ReadSingle();
_autoregressionNoiseMean = ctx.Reader.ReadSingle();
+ if (ctx.Header.ModelVerReadable >= VersionSavingStateAndPrediction)
+ _nextPrediction = ctx.Reader.ReadSingle();
_maxRank = ctx.Reader.ReadInt32();
_host.CheckDecode(1 <= _maxRank && _maxRank <= _windowSize - 1);
@@ -444,7 +460,7 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
}
_buffer = new FixedSizeQueue(_seriesLength);
- _state = new Single[_windowSize - 1];
+
_x = new CpuAlignedVector(_windowSize, SseUtils.CbAlign);
_xSmooth = new CpuAlignedVector(_windowSize, SseUtils.CbAlign);
}
@@ -475,11 +491,13 @@ public void Save(ModelSaveContext ctx)
// RankSelectionMethod: _rankSelectionMethod
// bool: _isWeightSet
// float[]: _alpha
+ // float[]: _state
// bool: ShouldComputeForecastIntervals
// float: _observationNoiseVariance
// float: _autoregressionNoiseVariance
// float: _observationNoiseMean
// float: _autoregressionNoiseMean
+ // float: _nextPrediction
// int: _maxRank
// bool: _shouldStablize
// bool: _shouldMaintainInfo
@@ -494,11 +512,13 @@ public void Save(ModelSaveContext ctx)
ctx.Writer.Write((byte)_rankSelectionMethod);
ctx.Writer.WriteBoolByte(_wTrans != null);
ctx.Writer.WriteFloatArray(_alpha);
+ ctx.Writer.WriteFloatArray(_state);
ctx.Writer.WriteBoolByte(ShouldComputeForecastIntervals);
ctx.Writer.Write(_observationNoiseVariance);
ctx.Writer.Write(_autoregressionNoiseVariance);
ctx.Writer.Write(_observationNoiseMean);
ctx.Writer.Write(_autoregressionNoiseMean);
+ ctx.Writer.Write(_nextPrediction);
ctx.Writer.Write(_maxRank);
ctx.Writer.WriteBoolByte(_shouldStablize);
ctx.Writer.WriteBoolByte(_shouldMaintainInfo);
diff --git a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs
index bb14178827..900407adec 100644
--- a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs
+++ b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs
@@ -15,18 +15,32 @@ namespace Microsoft.ML.Runtime.TimeSeriesProcessing
///
public abstract class IidAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase
{
- public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env, IDataView input)
- : base(args, name, env, input)
+ public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env)
+ : base(args, name, env)
{
InitialWindowSize = 0;
}
- public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name, IDataView input)
- : base(env, ctx, name, input)
+ public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name)
+ : base(env, ctx, name)
{
Host.CheckDecode(InitialWindowSize == 0);
}
+ public override Schema GetOutputSchema(Schema inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+
+ if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName);
+
+ var colType = inputSchema.GetColumnType(col);
+ if (colType != NumberType.R4)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, NumberType.R4.ToString(), colType.ToString());
+
+ return Transform(new EmptyDataView(Host, inputSchema)).Schema;
+ }
+
public override void Save(ModelSaveContext ctx)
{
ctx.CheckAtModel();
diff --git a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
index 909bbc390c..9bb6dcab95 100644
--- a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
@@ -3,24 +3,35 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TimeSeriesProcessing;
+using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase;
-[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IidChangePointDetector), typeof(IidChangePointDetector.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), typeof(IidChangePointDetector.Arguments), typeof(SignatureDataTransform),
IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature, IidChangePointDetector.ShortName)]
-[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IidChangePointDetector), null, typeof(SignatureLoadDataTransform),
+
+[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), null, typeof(SignatureLoadDataTransform),
+ IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)]
+
+[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IidChangePointDetector), null, typeof(SignatureLoadModel),
IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)]
+[assembly: LoadableClass(typeof(IRowMapper), typeof(IidChangePointDetector), null, typeof(SignatureLoadRowMapper),
+ IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)]
+
namespace Microsoft.ML.Runtime.TimeSeriesProcessing
{
///
/// This class implements the change point detector transform for an i.i.d. sequence based on adaptive kernel density estimation and martingales.
///
- public sealed class IidChangePointDetector : IidAnomalyDetectionBase, ITransformTemplate
+ public sealed class IidChangePointDetector : IidAnomalyDetectionBase
{
internal const string Summary = "This transform detects the change-points in an i.i.d. sequence using adaptive kernel density estimation and martingales.";
public const string LoaderSignature = "IidChangePointDetector";
@@ -89,8 +100,18 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(IidChangePointDetector).Assembly.FullName);
}
- public IidChangePointDetector(IHostEnvironment env, Arguments args, IDataView input)
- : base(new BaseArguments(args), LoaderSignature, env, input)
+ // Factory method for SignatureDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(args, nameof(args));
+ env.CheckValue(input, nameof(input));
+
+ return new IidChangePointDetector(env, args).MakeDataTransform(input);
+ }
+
+ internal IidChangePointDetector(IHostEnvironment env, Arguments args)
+ : base(new BaseArguments(args), LoaderSignature, env)
{
switch (Martingale)
{
@@ -109,8 +130,28 @@ public IidChangePointDetector(IHostEnvironment env, Arguments args, IDataView in
}
}
- public IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
- : base(env, ctx, LoaderSignature, input)
+ // Factory method for SignatureLoadDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+ env.CheckValue(input, nameof(input));
+
+ return new IidChangePointDetector(env, ctx).MakeDataTransform(input);
+ }
+
+ // Factory method for SignatureLoadModel.
+ private static IidChangePointDetector Create(IHostEnvironment env, ModelLoadContext ctx)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+
+ return new IidChangePointDetector(env, ctx);
+ }
+
+ internal IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx)
+ : base(env, ctx, LoaderSignature)
{
// *** Binary format ***
//
@@ -119,8 +160,8 @@ public IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx, IDataV
Host.CheckDecode(Side == AnomalySide.TwoSided);
}
- private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector transform, IDataView newSource)
- : base(new BaseArguments(transform), LoaderSignature, env, newSource)
+ private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector transform)
+ : base(new BaseArguments(transform), LoaderSignature, env)
{
}
@@ -139,9 +180,65 @@ public override void Save(ModelSaveContext ctx)
base.Save(ctx);
}
- public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource)
+ // Factory method for SignatureLoadRowMapper.
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
+ }
+
+ ///
+ /// Estimator for
+ ///
+ public sealed class IidChangePointEstimator : TrivialEstimator
+ {
+ ///
+ /// Create a new instance of
+ ///
+ /// Host Environment.
+ /// Name of the input column.
+ /// The name of the new column.
+ /// The confidence for change point detection in the range [0, 100].
+ /// The change history length.
+ /// The martingale used for scoring.
+ /// The epsilon parameter for the Power martingale.
+ public IidChangePointEstimator(IHostEnvironment env, string inputColumn, string outputColumn, int confidence,
+ int changeHistoryLength, MartingaleType martingale = MartingaleType.Power, double eps = 0.1)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)),
+ new IidChangePointDetector(env, new IidChangePointDetector.Arguments
+ {
+ Name = outputColumn,
+ Source = inputColumn,
+ Confidence = confidence,
+ ChangeHistoryLength = changeHistoryLength,
+ Martingale = martingale,
+ PowerMartingaleEpsilon = eps
+ }))
+ {
+ }
+
+ public IidChangePointEstimator(IHostEnvironment env, IidChangePointDetector.Arguments args)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)),
+ new IidChangePointDetector(env, args))
{
- return new IidChangePointDetector(env, this, newSource);
+ }
+
+ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+
+ if (!inputSchema.TryFindColumn(Transformer.InputColumnName, out var col))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName);
+ if (col.ItemType != NumberType.R4)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName, NumberType.R4.ToString(), col.GetTypeString());
+
+ var metadata = new List() {
+ new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
+ };
+ var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
+
+ resultDic[Transformer.OutputColumnName] = new SchemaShape.Column(
+ Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
+
+ return new SchemaShape(resultDic.Values);
}
}
}
diff --git a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs
index f4d30c8696..cd7dafc83c 100644
--- a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs
@@ -2,24 +2,35 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TimeSeriesProcessing;
+using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase;
-[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IidSpikeDetector), typeof(IidSpikeDetector.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IDataTransform), typeof(IidSpikeDetector), typeof(IidSpikeDetector.Arguments), typeof(SignatureDataTransform),
IidSpikeDetector.UserName, IidSpikeDetector.LoaderSignature, IidSpikeDetector.ShortName)]
-[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IidSpikeDetector), null, typeof(SignatureLoadDataTransform),
+
+[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IDataTransform), typeof(IidSpikeDetector), null, typeof(SignatureLoadDataTransform),
+ IidSpikeDetector.UserName, IidSpikeDetector.LoaderSignature)]
+
+[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IidSpikeDetector), null, typeof(SignatureLoadModel),
IidSpikeDetector.UserName, IidSpikeDetector.LoaderSignature)]
+[assembly: LoadableClass(typeof(IRowMapper), typeof(IidSpikeDetector), null, typeof(SignatureLoadRowMapper),
+ IidSpikeDetector.UserName, IidSpikeDetector.LoaderSignature)]
+
namespace Microsoft.ML.Runtime.TimeSeriesProcessing
{
///
/// This class implements the spike detector transform for an i.i.d. sequence based on adaptive kernel density estimation.
///
- public sealed class IidSpikeDetector : IidAnomalyDetectionBase, ITransformTemplate
+ public sealed class IidSpikeDetector : IidAnomalyDetectionBase
{
internal const string Summary = "This transform detects the spikes in a i.i.d. sequence using adaptive kernel density estimation.";
public const string LoaderSignature = "IidSpikeDetector";
@@ -85,22 +96,52 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(IidSpikeDetector).Assembly.FullName);
}
- public IidSpikeDetector(IHostEnvironment env, Arguments args, IDataView input)
- : base(new BaseArguments(args), LoaderSignature, env, input)
+ // Factory method for SignatureDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(args, nameof(args));
+ env.CheckValue(input, nameof(input));
+
+ return new IidSpikeDetector(env, args).MakeDataTransform(input);
+ }
+
+ internal IidSpikeDetector(IHostEnvironment env, Arguments args)
+ : base(new BaseArguments(args), LoaderSignature, env)
{
// This constructor is empty.
}
- public IidSpikeDetector(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
- : base(env, ctx, LoaderSignature, input)
+ // Factory method for SignatureLoadDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+ env.CheckValue(input, nameof(input));
+
+ return new IidSpikeDetector(env, ctx).MakeDataTransform(input);
+ }
+
+ // Factory method for SignatureLoadModel.
+ private static IidSpikeDetector Create(IHostEnvironment env, ModelLoadContext ctx)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+
+ return new IidSpikeDetector(env, ctx);
+ }
+
+ public IidSpikeDetector(IHostEnvironment env, ModelLoadContext ctx)
+ : base(env, ctx, LoaderSignature)
{
// *** Binary format ***
//
Host.CheckDecode(ThresholdScore == AlertingScore.PValueScore);
}
- private IidSpikeDetector(IHostEnvironment env, IidSpikeDetector transform, IDataView newSource)
- : base(new BaseArguments(transform), LoaderSignature, env, newSource)
+ private IidSpikeDetector(IHostEnvironment env, IidSpikeDetector transform)
+ : base(new BaseArguments(transform), LoaderSignature, env)
{
}
@@ -118,9 +159,60 @@ public override void Save(ModelSaveContext ctx)
base.Save(ctx);
}
- public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource)
+ // Factory method for SignatureLoadRowMapper.
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
+ }
+
+ ///
+ /// Estimator for
+ ///
+ public sealed class IidSpikeEstimator : TrivialEstimator
+ {
+ ///
+ /// Create a new instance of
+ ///
+ /// Host Environment.
+ /// Name of the input column.
+ /// The name of the new column.
+ /// The confidence for spike detection in the range [0, 100].
+ /// The size of the sliding window for computing the p-value.
+ /// The argument that determines whether to detect positive or negative anomalies, or both.
+ public IidSpikeEstimator(IHostEnvironment env, string inputColumn, string outputColumn, int confidence, int pvalueHistoryLength, AnomalySide side = AnomalySide.TwoSided)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidSpikeDetector)),
+ new IidSpikeDetector(env, new IidSpikeDetector.Arguments
+ {
+ Name = outputColumn,
+ Source = inputColumn,
+ Confidence = confidence,
+ PvalueHistoryLength = pvalueHistoryLength,
+ Side = side
+ }))
{
- return new IidSpikeDetector(env, this, newSource);
+ }
+
+ public IidSpikeEstimator(IHostEnvironment env, IidSpikeDetector.Arguments args)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidSpikeEstimator)), new IidSpikeDetector(env, args))
+ {
+ }
+
+ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+
+ if (!inputSchema.TryFindColumn(Transformer.InputColumnName, out var col))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName);
+ if (col.ItemType != NumberType.R4)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName, NumberType.R4.ToString(), col.GetTypeString());
+
+ var metadata = new List() {
+ new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
+ };
+ var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
+ resultDic[Transformer.OutputColumnName] = new SchemaShape.Column(
+ Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
+
+ return new SchemaShape(resultDic.Values);
}
}
}
diff --git a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs
index 240d05e1c0..e80f41a985 100644
--- a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs
+++ b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs
@@ -24,7 +24,7 @@ namespace Microsoft.ML.Runtime.TimeSeriesProcessing
///
/// The type of the input sequence
/// The type of the state object for sequential anomaly detection. Must be a class inherited from AnomalyDetectionStateBase
- public abstract class SequentialAnomalyDetectionTransformBase : SequentialTransformBase, TState>
+ public abstract class SequentialAnomalyDetectionTransformBase : SequentialTransformerBase, TState>
where TState : SequentialAnomalyDetectionTransformBase.AnomalyDetectionStateBase, new()
{
///
@@ -144,10 +144,6 @@ public abstract class ArgumentsBase
// The size of the VBuffer in the dst column.
private int _outputLength;
- private readonly SchemaImpl _wrappedSchema;
-
- public override Schema Schema => _wrappedSchema.AsSchema;
-
private static int GetOutputLength(AlertingScore alertingScore, IHostEnvironment host)
{
switch (alertingScore)
@@ -163,23 +159,10 @@ private static int GetOutputLength(AlertingScore alertingScore, IHostEnvironment
}
}
- private static SchemaImpl CreateSchema(ISchema parentSchema, string colName, int length)
- {
- Contracts.AssertValue(parentSchema);
- Contracts.Assert(2 <= length && length <= 4);
-
- string[] names = { "Alert", "Raw Score", "P-Value Score", "Martingale Score" };
- int col;
- bool result = parentSchema.TryGetColumnIndex(colName, out col);
- Contracts.Assert(result);
-
- return new SchemaImpl(parentSchema, col, names, length);
- }
-
- protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env, IDataView input,
+ protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env,
AnomalySide anomalySide, MartingaleType martingale, AlertingScore alertingScore, Double powerMartingaleEpsilon,
Double alertThreshold)
- : base(windowSize, initialWindowSize, inputColumnName, outputColumnName, name, env, input, new VectorType(NumberType.R8, GetOutputLength(alertingScore, env)))
+ : base(Contracts.CheckRef(env, nameof(env)).Register(name), windowSize, initialWindowSize, inputColumnName, outputColumnName, new VectorType(NumberType.R8, GetOutputLength(alertingScore, env)))
{
Host.CheckUserArg(Enum.IsDefined(typeof(MartingaleType), martingale), nameof(ArgumentsBase.Martingale), "Value is undefined.");
Host.CheckUserArg(Enum.IsDefined(typeof(AnomalySide), anomalySide), nameof(ArgumentsBase.Side), "Value is undefined.");
@@ -198,17 +181,16 @@ protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWin
PowerMartingaleEpsilon = powerMartingaleEpsilon;
AlertThreshold = alertThreshold;
_outputLength = GetOutputLength(ThresholdScore, Host);
- _wrappedSchema = CreateSchema(base.Schema, outputColumnName, _outputLength);
}
- protected SequentialAnomalyDetectionTransformBase(ArgumentsBase args, string name, IHostEnvironment env, IDataView input)
- : this(args.WindowSize, args.InitialWindowSize, args.Source, args.Name, name, env, input, args.Side, args.Martingale,
+ protected SequentialAnomalyDetectionTransformBase(ArgumentsBase args, string name, IHostEnvironment env)
+ : this(args.WindowSize, args.InitialWindowSize, args.Source, args.Name, name, env, args.Side, args.Martingale,
args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold)
{
}
- protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name, IDataView input)
- : base(env, ctx, name, input)
+ protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(name), ctx)
{
// *** Binary format ***
//
@@ -242,7 +224,6 @@ protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, ModelLoa
Host.CheckDecode(ThresholdScore != AlertingScore.PValueScore || (0 <= AlertThreshold && AlertThreshold <= 1));
_outputLength = GetOutputLength(ThresholdScore, Host);
- _wrappedSchema = CreateSchema(base.Schema, OutputColumnName, _outputLength);
}
public override void Save(ModelSaveContext ctx)
@@ -557,102 +538,86 @@ protected override sealed void InitializeStateCore()
protected abstract Double ComputeRawAnomalyScore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration);
}
- ///
- /// Schema implementation to add slot name metadata to the produced output column.
- ///
- private sealed class SchemaImpl : ISchema
- {
- private readonly ISchema _parent;
- private readonly int _col;
- private readonly ColumnType _type;
- private readonly string[] _names;
- private readonly int _namesLength;
- private readonly MetadataUtils.MetadataGetter>> _getter;
-
- public Schema AsSchema { get; }
+ protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(Host, this, schema);
- public int ColumnCount { get { return _parent.ColumnCount; } }
+ private sealed class Mapper : IRowMapper
+ {
+ private readonly IHost _host;
+ private readonly SequentialAnomalyDetectionTransformBase _parent;
+ private readonly ISchema _parentSchema;
+ private readonly int _inputColumnIndex;
+ private readonly VBuffer> _slotNames;
- ///
- /// Constructs the schema.
- ///
- /// The schema we will wrap.
- /// Aside from presenting that additional piece of metadata, the constructed schema
- /// will appear identical to this input schema.
- /// The column in that has the metadata.
- ///
- ///
- public SchemaImpl(ISchema schema, int col, string[] names, int length)
+ public Mapper(IHostEnvironment env, SequentialAnomalyDetectionTransformBase parent, ISchema inputSchema)
{
- Contracts.Assert(length > 0);
- Contracts.Assert(Utils.Size(names) >= length);
- Contracts.AssertValue(schema);
- Contracts.Assert(0 <= col && col < schema.ColumnCount);
- _parent = schema;
- _col = col;
-
- _names = names;
- _namesLength = length;
-
- _type = new VectorType(TextType.Instance, _namesLength);
- Contracts.AssertValue(_type);
- _getter = GetSlotNames;
-
- AsSchema = Schema.Create(this);
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(Mapper));
+ _host.CheckValue(inputSchema, nameof(inputSchema));
+ _host.CheckValue(parent, nameof(parent));
+
+ if (!inputSchema.TryGetColumnIndex(parent.InputColumnName, out _inputColumnIndex))
+ throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName);
+
+ var colType = inputSchema.GetColumnType(_inputColumnIndex);
+ if (colType != NumberType.R4)
+ throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName, NumberType.R4.ToString(), colType.ToString());
+
+ _parent = parent;
+ _parentSchema = inputSchema;
+ _slotNames = new VBuffer>(4, new[] { "Alert".AsMemory(), "Raw Score".AsMemory(),
+ "P-Value Score".AsMemory(), "Martingale Score".AsMemory() });
}
- public bool TryGetColumnIndex(string name, out int col)
+ public Schema.Column[] GetOutputColumns()
{
- return _parent.TryGetColumnIndex(name, out col);
+ var meta = new Schema.Metadata.Builder();
+ meta.AddSlotNames(_parent._outputLength, GetSlotNames);
+ var info = new Schema.Column[1];
+ info[0] = new Schema.Column(_parent.OutputColumnName, new VectorType(NumberType.R8, _parent._outputLength), meta.GetMetadata());
+ return info;
}
- public string GetColumnName(int col)
- {
- return _parent.GetColumnName(col);
- }
+ public void GetSlotNames(ref VBuffer> dst) => _slotNames.CopyTo(ref dst, 0, _parent._outputLength);
- public ColumnType GetColumnType(int col)
+ public Func GetDependencies(Func activeOutput)
{
- return _parent.GetColumnType(col);
+ if (activeOutput(0))
+ return col => col == _inputColumnIndex;
+ else
+ return col => false;
}
- public IEnumerable> GetMetadataTypes(int col)
- {
- var result = _parent.GetMetadataTypes(col);
- if (col == _col)
- return result.Prepend(_type.GetPair(MetadataUtils.Kinds.SlotNames));
- return result;
- }
+ public void Save(ModelSaveContext ctx) => _parent.Save(ctx);
- public ColumnType GetMetadataTypeOrNull(string kind, int col)
+ public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer)
{
- if (col == _col && kind == MetadataUtils.Kinds.SlotNames)
- return _type;
- return _parent.GetMetadataTypeOrNull(kind, col);
+ disposer = null;
+ var getters = new Delegate[1];
+ if (activeOutput(0))
+ {
+ TState state = new TState();
+ state.InitState(_parent.WindowSize, _parent.InitialWindowSize, _parent, _host);
+ getters[0] = MakeGetter(input, state);
+ }
+ return getters;
}
- public void GetSlotNames(int col, ref VBuffer> slotNames)
- {
- Contracts.Assert(col == _col);
-
- var result = slotNames.Values;
- if (Utils.Size(result) < _namesLength)
- result = new ReadOnlyMemory[_namesLength];
+ private delegate void ProcessData(ref TInput src, ref VBuffer dst);
- for (int i = 0; i < _namesLength; ++i)
- result[i] = _names[i].AsMemory();
-
- slotNames = new VBuffer>(_namesLength, result, slotNames.Indices);
- }
-
- public void GetMetadata(string kind, int col, ref TValue value)
+ private Delegate MakeGetter(IRow input, TState state)
{
- if (col == _col && kind == MetadataUtils.Kinds.SlotNames)
+ _host.AssertValue(input);
+ var srcGetter = input.GetGetter(_inputColumnIndex);
+ ProcessData processData = _parent.WindowSize > 0 ?
+ (ProcessData) state.Process : state.ProcessWithoutBuffer;
+ ValueGetter > valueGetter = (ref VBuffer dst) =>
{
- _getter.Marshal(col, ref value);
- return;
- }
- _parent.GetMetadata(kind, col, ref value);
+ TInput src = default;
+ srcGetter(ref src);
+ processData(ref src, ref dst);
+ };
+
+ return valueGetter;
}
}
}
diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs
new file mode 100644
index 0000000000..5e2b32a81b
--- /dev/null
+++ b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs
@@ -0,0 +1,406 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Data.IO;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Core.Data;
+
+namespace Microsoft.ML.Runtime.TimeSeriesProcessing
+{
+ ///
+ /// The base class for sequential processing transforms. This class implements the basic sliding window buffering. The derived classes need to specify the transform logic,
+ /// the initialization logic and the learning logic via implementing the abstract methods TransformCore(), InitializeStateCore() and LearnStateFromDataCore(), respectively
+ ///
+ /// The input type of the sequential processing.
+ /// The dst type of the sequential processing.
+ /// The state type of the sequential processing. Must be a class inherited from StateBase
+ public abstract class SequentialTransformerBase : ITransformer, ICanSaveModel
+ where TState : SequentialTransformerBase.StateBase, new()
+ {
+ ///
+ /// The base class for encapsulating the State object for sequential processing. This class implements a windowed buffer.
+ ///
+ public abstract class StateBase
+ {
+ // Ideally this class should be private. However, due to the current constraints with the LambdaTransform, we need to have
+ // access to the state class when inheriting from SequentialTransformerBase.
+ protected IHost Host;
+
+ ///
+ /// A reference to the parent transform that operates on the state object.
+ ///
+ protected SequentialTransformerBase ParentTransform;
+
+ ///
+ /// The internal windowed buffer for buffering the values in the input sequence.
+ ///
+ protected FixedSizeQueue WindowedBuffer;
+
+ ///
+ /// The buffer used to buffer the training data points.
+ ///
+ protected FixedSizeQueue InitialWindowedBuffer;
+
+ protected int WindowSize { get; private set; }
+
+ protected int InitialWindowSize { get; private set; }
+
+ ///
+ /// Counts the number of rows observed by the transform so far.
+ ///
+ protected long RowCounter { get; private set; }
+
+ protected long IncrementRowCounter()
+ {
+ RowCounter++;
+ return RowCounter;
+ }
+
+ private bool _isIniatilized;
+
+ ///
+ /// This method sets the window size and initializes the buffer only once.
+ /// Since the class needs to implement a default constructor, this methods provides a mechanism to initialize the window size and buffer.
+ ///
+ /// The size of the windowed buffer
+ /// The size of the windowed initial buffer used for training
+ /// The parent transform of this state object
+ /// The host
+ public void InitState(int windowSize, int initialWindowSize, SequentialTransformerBase parentTransform, IHost host)
+ {
+ Contracts.CheckValue(host, nameof(host), "The host cannot be null.");
+ host.Check(!_isIniatilized, "The window size can be set only once.");
+ host.CheckValue(parentTransform, nameof(parentTransform));
+ host.CheckParam(windowSize >= 0, nameof(windowSize), "Must be non-negative.");
+ host.CheckParam(initialWindowSize >= 0, nameof(initialWindowSize), "Must be non-negative.");
+
+ Host = host;
+ WindowSize = windowSize;
+ InitialWindowSize = initialWindowSize;
+ ParentTransform = parentTransform;
+ WindowedBuffer = (WindowSize > 0) ? new FixedSizeQueue(WindowSize) : new FixedSizeQueue(1);
+ InitialWindowedBuffer = (InitialWindowSize > 0) ? new FixedSizeQueue(InitialWindowSize) : new FixedSizeQueue(1);
+ RowCounter = 0;
+
+ InitializeStateCore();
+ _isIniatilized = true;
+ }
+
+ ///
+ /// This method implements the basic resetting mechanism for a state object and clears the buffer.
+ ///
+ public virtual void Reset()
+ {
+ Host.Assert(_isIniatilized);
+ Host.Assert(WindowedBuffer != null);
+ Host.Assert(InitialWindowedBuffer != null);
+
+ RowCounter = 0;
+ WindowedBuffer.Clear();
+ InitialWindowedBuffer.Clear();
+ }
+
+ public void Process(ref TInput input, ref TOutput output)
+ {
+ if (InitialWindowedBuffer.Count < InitialWindowSize)
+ {
+ InitialWindowedBuffer.AddLast(input);
+ SetNaOutput(ref output);
+
+ if (InitialWindowedBuffer.Count >= InitialWindowSize - WindowSize)
+ WindowedBuffer.AddLast(input);
+
+ if (InitialWindowedBuffer.Count == InitialWindowSize)
+ LearnStateFromDataCore(InitialWindowedBuffer);
+ }
+ else
+ {
+ TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output);
+ WindowedBuffer.AddLast(input);
+ IncrementRowCounter();
+ }
+ }
+
+ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output)
+ {
+ if (InitialWindowedBuffer.Count < InitialWindowSize)
+ {
+ InitialWindowedBuffer.AddLast(input);
+ SetNaOutput(ref output);
+
+ if (InitialWindowedBuffer.Count == InitialWindowSize)
+ LearnStateFromDataCore(InitialWindowedBuffer);
+ }
+ else
+ {
+ TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output);
+ IncrementRowCounter();
+ }
+ }
+
+ ///
+ /// The abstract method that specifies the NA value for the dst type.
+ ///
+ ///
+ protected abstract void SetNaOutput(ref TOutput dst);
+
+ ///
+ /// The abstract method that realizes the main logic for the transform.
+ ///
+ /// A reference to the input object.
+ /// A reference to the dst object.
+ /// A reference to the windowed buffer.
+ /// A long number that indicates the number of times TransformCore has been called so far (starting value = 0).
+ protected abstract void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref TOutput dst);
+
+ ///
+ /// The abstract method that realizes the logic for initializing the state object.
+ ///
+ protected abstract void InitializeStateCore();
+
+ ///
+ /// The abstract method that realizes the logic for learning the parameters and the initial state object from data.
+ ///
+ /// A queue of data points used for training
+ protected abstract void LearnStateFromDataCore(FixedSizeQueue data);
+ }
+
+ protected readonly IHost Host;
+
+ ///
+ /// The window size for buffering.
+ ///
+ protected readonly int WindowSize;
+
+ ///
+ /// The number of datapoints from the beginning of the sequence that are used for learning the initial state.
+ ///
+ protected int InitialWindowSize;
+
+ public string InputColumnName;
+ public string OutputColumnName;
+ protected ColumnType OutputColumnType;
+
+ public bool IsRowToRowMapper => false;
+
+ ///
+ /// The main constructor for the sequential transform
+ ///
+ /// The host.
+ /// The size of buffer used for windowed buffering.
+ /// The number of datapoints picked from the beginning of the series for training the transform parameters if needed.
+ /// The name of the input column.
+ /// The name of the dst column.
+ ///
+ protected SequentialTransformerBase(IHost host, int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, ColumnType outputColType)
+ {
+ Host = host;
+ Host.CheckParam(initialWindowSize >= 0, nameof(initialWindowSize), "Must be non-negative.");
+ Host.CheckParam(windowSize >= 0, nameof(windowSize), "Must be non-negative.");
+ // REVIEW: Very bad design. This base class is responsible for reporting errors on
+ // the arguments, but the arguments themselves are not derived form any base class.
+ Host.CheckNonEmpty(inputColumnName, nameof(PercentileThresholdTransform.Arguments.Source));
+ Host.CheckNonEmpty(outputColumnName, nameof(PercentileThresholdTransform.Arguments.Source));
+
+ InputColumnName = inputColumnName;
+ OutputColumnName = outputColumnName;
+ OutputColumnType = outputColType;
+ InitialWindowSize = initialWindowSize;
+ WindowSize = windowSize;
+ }
+
+ protected SequentialTransformerBase(IHost host, ModelLoadContext ctx)
+ {
+ Host = host;
+ Host.CheckValue(ctx, nameof(ctx));
+
+ // *** Binary format ***
+ // int: _windowSize
+ // int: _initialWindowSize
+ // int (string ID): _inputColumnName
+ // int (string ID): _outputColumnName
+ // ColumnType: _transform.Schema.GetColumnType(0)
+
+ var windowSize = ctx.Reader.ReadInt32();
+ Host.CheckDecode(windowSize >= 0);
+
+ var initialWindowSize = ctx.Reader.ReadInt32();
+ Host.CheckDecode(initialWindowSize >= 0);
+
+ var inputColumnName = ctx.LoadNonEmptyString();
+ var outputColumnName = ctx.LoadNonEmptyString();
+
+ InputColumnName = inputColumnName;
+ OutputColumnName = outputColumnName;
+ InitialWindowSize = initialWindowSize;
+ WindowSize = windowSize;
+
+ BinarySaver bs = new BinarySaver(Host, new BinarySaver.Arguments());
+ OutputColumnType = bs.LoadTypeDescriptionOrNull(ctx.Reader.BaseStream);
+ }
+
+ public virtual void Save(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ Host.Assert(InitialWindowSize >= 0);
+ Host.Assert(WindowSize >= 0);
+
+ // *** Binary format ***
+ // int: _windowSize
+ // int: _initialWindowSize
+ // int (string ID): _inputColumnName
+ // int (string ID): _outputColumnName
+ // ColumnType: _transform.Schema.GetColumnType(0)
+
+ ctx.Writer.Write(WindowSize);
+ ctx.Writer.Write(InitialWindowSize);
+ ctx.SaveNonEmptyString(InputColumnName);
+ ctx.SaveNonEmptyString(OutputColumnName);
+
+ var bs = new BinarySaver(Host, new BinarySaver.Arguments());
+ bs.TryWriteTypeDescription(ctx.Writer.BaseStream, OutputColumnType, out int byteWritten);
+ }
+
+ public abstract Schema GetOutputSchema(Schema inputSchema);
+
+ protected abstract IRowMapper MakeRowMapper(ISchema schema);
+
+ protected SequentialDataTransform MakeDataTransform(IDataView input)
+ {
+ Host.CheckValue(input, nameof(input));
+ return new SequentialDataTransform(Host, this, input, MakeRowMapper(input.Schema));
+ }
+
+ public IDataView Transform(IDataView input) => MakeDataTransform(input);
+
+ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema)
+ {
+ throw new InvalidOperationException("Not a RowToRowMapper.");
+ }
+
+ public sealed class SequentialDataTransform : TransformBase, ITransformTemplate
+ {
+ private readonly IRowMapper _mapper;
+ private readonly SequentialTransformerBase _parent;
+ private readonly IDataTransform _transform;
+ private readonly ColumnBindings _bindings;
+
+ public SequentialDataTransform(IHost host, SequentialTransformerBase parent, IDataView input, IRowMapper mapper)
+ :base(parent.Host, input)
+ {
+ _parent = parent;
+ _transform = CreateLambdaTransform(_parent.Host, input, _parent.InputColumnName,
+ _parent.OutputColumnName, InitFunction, _parent.WindowSize > 0, _parent.OutputColumnType);
+ _mapper = mapper;
+ _bindings = new ColumnBindings(Schema.Create(input.Schema), _mapper.GetOutputColumns());
+ }
+
+ private static IDataTransform CreateLambdaTransform(IHost host, IDataView input, string inputColumnName, string outputColumnName,
+ Action initFunction, bool hasBuffer, ColumnType outputColTypeOverride)
+ {
+ var inputSchema = SchemaDefinition.Create(typeof(DataBox));
+ inputSchema[0].ColumnName = inputColumnName;
+
+ var outputSchema = SchemaDefinition.Create(typeof(DataBox));
+ outputSchema[0].ColumnName = outputColumnName;
+
+ if (outputColTypeOverride != null)
+ outputSchema[0].ColumnType = outputColTypeOverride;
+
+ Action, DataBox, TState> lambda;
+ if (hasBuffer)
+ lambda = MapFunction;
+ else
+ lambda = MapFunctionWithoutBuffer;
+
+ return LambdaTransform.CreateMap(host, input, lambda, initFunction, inputSchema, outputSchema);
+ }
+
+ private static void MapFunction(DataBox input, DataBox output, TState state)
+ {
+ state.Process(ref input.Value, ref output.Value);
+ }
+
+ private static void MapFunctionWithoutBuffer(DataBox input, DataBox output, TState state)
+ {
+ state.ProcessWithoutBuffer(ref input.Value, ref output.Value);
+ }
+
+ private void InitFunction(TState state)
+ {
+ state.InitState(_parent.WindowSize, _parent.InitialWindowSize, _parent, _parent.Host);
+ }
+
+ public override bool CanShuffle { get { return false; } }
+
+ protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null)
+ {
+ var srcCursor = _transform.GetRowCursor(predicate, rand);
+ return new Cursor(Host, this, srcCursor);
+ }
+
+ protected override bool? ShouldUseParallelCursors(Func predicate)
+ {
+ Host.AssertValue(predicate);
+ return false;
+ }
+
+ public override Schema Schema => _bindings.Schema;
+
+ public override long? GetRowCount(bool lazy = true)
+ {
+ return _transform.GetRowCount(lazy);
+ }
+
+ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null)
+ {
+ consolidator = null;
+ return new IRowCursor[] { GetRowCursorCore(predicate, rand) };
+ }
+
+ public override void Save(ModelSaveContext ctx)
+ {
+ _parent.Save(ctx);
+ }
+
+ public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource)
+ {
+ return new SequentialDataTransform(Contracts.CheckRef(env, nameof(env)).Register("SequentialDataTransform"), _parent, newSource, _mapper);
+ }
+ }
+
+ ///
+ /// A wrapper around the cursor which replaces the schema.
+ ///
+ private sealed class Cursor : SynchronizedCursorBase, IRowCursor
+ {
+ private readonly SequentialDataTransform _parent;
+
+ public Cursor(IHost host, SequentialDataTransform parent, IRowCursor input)
+ : base(host, input)
+ {
+ Ch.Assert(input.Schema.ColumnCount == parent.Schema.ColumnCount);
+ _parent = parent;
+ }
+
+ public Schema Schema { get { return _parent.Schema; } }
+
+ public bool IsColumnActive(int col)
+ {
+ Ch.Check(0 <= col && col < Schema.ColumnCount, "col");
+ return Input.IsColumnActive(col);
+ }
+
+ public ValueGetter GetGetter(int col)
+ {
+ Ch.Check(IsColumnActive(col), "col");
+ return Input.GetGetter(col);
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs
index f4f7023b81..86d422e484 100644
--- a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs
+++ b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs
@@ -106,8 +106,8 @@ public abstract class SsaArguments : ArgumentsBase
protected readonly Func ErrorFunc;
protected readonly ISequenceModeler Model;
- public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment env, IDataView input)
- : base(args.WindowSize, 0, args.Source, args.Name, name, env, input, args.Side, args.Martingale, args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold)
+ public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment env)
+ : base(args.WindowSize, 0, args.Source, args.Name, name, env, args.Side, args.Martingale, args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold)
{
Host.CheckUserArg(2 <= args.SeasonalWindowSize, nameof(args.SeasonalWindowSize), "Must be at least 2.");
Host.CheckUserArg(0 <= args.DiscountFactor && args.DiscountFactor <= 1, nameof(args.DiscountFactor), "Must be in the range [0, 1].");
@@ -118,18 +118,13 @@ public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment
ErrorFunction = args.ErrorFunction;
ErrorFunc = ErrorFunctionUtils.GetErrorFunction(ErrorFunction);
IsAdaptive = args.IsAdaptive;
-
// Creating the master SSA model
Model = new AdaptiveSingularSpectrumSequenceModeler(Host, args.InitialWindowSize, SeasonalWindowSize + 1, SeasonalWindowSize,
DiscountFactor, null, AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalWindowSize / 2, false, false);
-
- // Training the master SSA model
- var data = new RoleMappedData(input, null, InputColumnName);
- Model.Train(data);
}
- public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name, IDataView input)
- : base(env, ctx, name, input)
+ public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name)
+ : base(env, ctx, name)
{
// *** Binary format ***
//
@@ -159,6 +154,20 @@ public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, strin
Host.CheckDecode(Model != null);
}
+ public override Schema GetOutputSchema(Schema inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+
+ if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col))
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName);
+
+ var colType = inputSchema.GetColumnType(col);
+ if (colType != NumberType.R4)
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, NumberType.R4.ToString(), colType.ToString());
+
+ return Transform(new EmptyDataView(Host, inputSchema)).Schema;
+ }
+
public override void Save(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
@@ -200,7 +209,6 @@ protected override void InitializeAnomalyDetector()
{
_parentAnomalyDetector = (SsaAnomalyDetectionBase)Parent;
_model = _parentAnomalyDetector.Model.Clone();
- _model.InitState();
}
protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration)
diff --git a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs
index 8eca18a034..0aa1386bcc 100644
--- a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs
@@ -3,18 +3,29 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TimeSeriesProcessing;
+using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase;
-[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(SsaChangePointDetector), typeof(SsaChangePointDetector.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(IDataTransform), typeof(SsaChangePointDetector), typeof(SsaChangePointDetector.Arguments), typeof(SignatureDataTransform),
SsaChangePointDetector.UserName, SsaChangePointDetector.LoaderSignature, SsaChangePointDetector.ShortName)]
-[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(SsaChangePointDetector), null, typeof(SignatureLoadDataTransform),
+
+[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(IDataTransform), typeof(SsaChangePointDetector), null, typeof(SignatureLoadDataTransform),
+ SsaChangePointDetector.UserName, SsaChangePointDetector.LoaderSignature)]
+
+[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(SsaChangePointDetector), null, typeof(SignatureLoadModel),
SsaChangePointDetector.UserName, SsaChangePointDetector.LoaderSignature)]
+[assembly: LoadableClass(typeof(IRowMapper), typeof(SsaChangePointDetector), null, typeof(SignatureLoadRowMapper),
+ SsaChangePointDetector.UserName, SsaChangePointDetector.LoaderSignature)]
+
namespace Microsoft.ML.Runtime.TimeSeriesProcessing
{
///
@@ -93,8 +104,24 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(SsaChangePointDetector).Assembly.FullName);
}
- public SsaChangePointDetector(IHostEnvironment env, Arguments args, IDataView input)
- : base(new BaseArguments(args), LoaderSignature, env, input)
+ internal SsaChangePointDetector(IHostEnvironment env, Arguments args, IDataView input)
+ : this(env, args)
+ {
+ Model.Train(new RoleMappedData(input, null, InputColumnName));
+ }
+
+ // Factory method for SignatureDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(args, nameof(args));
+ env.CheckValue(input, nameof(input));
+
+ return new SsaChangePointDetector(env, args, input).MakeDataTransform(input);
+ }
+
+ internal SsaChangePointDetector(IHostEnvironment env, Arguments args)
+ : base(new BaseArguments(args), LoaderSignature, env)
{
switch (Martingale)
{
@@ -113,8 +140,28 @@ public SsaChangePointDetector(IHostEnvironment env, Arguments args, IDataView in
}
}
- public SsaChangePointDetector(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
- : base(env, ctx, LoaderSignature, input)
+ // Factory method for SignatureLoadDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+ env.CheckValue(input, nameof(input));
+
+ return new SsaChangePointDetector(env, ctx).MakeDataTransform(input);
+ }
+
+ // Factory method for SignatureLoadModel.
+ private static SsaChangePointDetector Create(IHostEnvironment env, ModelLoadContext ctx)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+
+ return new SsaChangePointDetector(env, ctx);
+ }
+
+ internal SsaChangePointDetector(IHostEnvironment env, ModelLoadContext ctx)
+ : base(env, ctx, LoaderSignature)
{
// *** Binary format ***
//
@@ -141,5 +188,86 @@ public override void Save(ModelSaveContext ctx)
base.Save(ctx);
}
+
+ // Factory method for SignatureLoadRowMapper.
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
+ }
+
+ ///
+ /// Estimator for
+ ///
+ public sealed class SsaChangePointEstimator : IEstimator
+ {
+ private readonly IHost _host;
+ private readonly SsaChangePointDetector.Arguments _args;
+
+ ///
+ /// Create a new instance of
+ ///
+ /// Host Environment.
+ /// Name of the input column.
+ /// The name of the new column.
+ /// The confidence for change point detection in the range [0, 100].
+ /// The change history length.
+ /// The change history length.
+ /// The change history length.
+ /// The function used to compute the error between the expected and the observed value.
+ /// The martingale used for scoring.
+ /// The epsilon parameter for the Power martingale.
+ public SsaChangePointEstimator(IHostEnvironment env, string inputColumn, string outputColumn,
+ int confidence, int changeHistoryLength, int trainingWindowSize, int seasonalityWindowSize,
+ ErrorFunctionUtils.ErrorFunction errorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference,
+ MartingaleType martingale = MartingaleType.Power, double eps = 0.1)
+ : this(env, new SsaChangePointDetector.Arguments
+ {
+ Name = outputColumn,
+ Source = inputColumn,
+ Confidence = confidence,
+ ChangeHistoryLength = changeHistoryLength,
+ TrainingWindowSize = trainingWindowSize,
+ SeasonalWindowSize = seasonalityWindowSize,
+ Martingale = martingale,
+ PowerMartingaleEpsilon = eps,
+ ErrorFunction = errorFunction
+ })
+ {
+ }
+
+ public SsaChangePointEstimator(IHostEnvironment env, SsaChangePointDetector.Arguments args)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(SsaChangePointEstimator));
+
+ _host.CheckNonEmpty(args.Name, nameof(args.Name));
+ _host.CheckNonEmpty(args.Source, nameof(args.Source));
+
+ _args = args;
+ }
+
+ public SsaChangePointDetector Fit(IDataView input)
+ {
+ _host.CheckValue(input, nameof(input));
+ return new SsaChangePointDetector(_host, _args, input);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ _host.CheckValue(inputSchema, nameof(inputSchema));
+
+ if (!inputSchema.TryFindColumn(_args.Source, out var col))
+ throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source);
+ if (col.ItemType != NumberType.R4)
+ throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source, NumberType.R4.ToString(), col.GetTypeString());
+
+ var metadata = new List() {
+ new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
+ };
+ var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
+ resultDic[_args.Name] = new SchemaShape.Column(
+ _args.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
+
+ return new SchemaShape(resultDic.Values);
+ }
}
}
diff --git a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs
index 79d6622ff5..44238e5a71 100644
--- a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs
@@ -2,18 +2,29 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TimeSeriesProcessing;
+using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase;
-[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(SsaSpikeDetector), typeof(SsaSpikeDetector.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(IDataTransform), typeof(SsaSpikeDetector), typeof(SsaSpikeDetector.Arguments), typeof(SignatureDataTransform),
SsaSpikeDetector.UserName, SsaSpikeDetector.LoaderSignature, SsaSpikeDetector.ShortName)]
-[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(SsaSpikeDetector), null, typeof(SignatureLoadDataTransform),
+
+[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(IDataTransform), typeof(SsaSpikeDetector), null, typeof(SignatureLoadDataTransform),
+ SsaSpikeDetector.UserName, SsaSpikeDetector.LoaderSignature)]
+
+[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(SsaSpikeDetector), null, typeof(SignatureLoadModel),
SsaSpikeDetector.UserName, SsaSpikeDetector.LoaderSignature)]
+[assembly: LoadableClass(typeof(IRowMapper), typeof(SsaSpikeDetector), null, typeof(SignatureLoadRowMapper),
+ SsaSpikeDetector.UserName, SsaSpikeDetector.LoaderSignature)]
+
namespace Microsoft.ML.Runtime.TimeSeriesProcessing
{
///
@@ -90,14 +101,50 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(SsaSpikeDetector).Assembly.FullName);
}
- public SsaSpikeDetector(IHostEnvironment env, Arguments args, IDataView input)
- : base(new BaseArguments(args), LoaderSignature, env, input)
+ internal SsaSpikeDetector(IHostEnvironment env, Arguments args, IDataView input)
+ : base(new BaseArguments(args), LoaderSignature, env)
+ {
+ Model.Train(new RoleMappedData(input, null, InputColumnName));
+ }
+
+ // Factory method for SignatureDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(args, nameof(args));
+ env.CheckValue(input, nameof(input));
+
+ return new SsaSpikeDetector(env, args, input).MakeDataTransform(input);
+ }
+
+ internal SsaSpikeDetector(IHostEnvironment env, Arguments args)
+ : base(new BaseArguments(args), LoaderSignature, env)
{
// This constructor is empty.
}
- public SsaSpikeDetector(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
- : base(env, ctx, LoaderSignature, input)
+ // Factory method for SignatureLoadDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+ env.CheckValue(input, nameof(input));
+
+ return new SsaSpikeDetector(env, ctx).MakeDataTransform(input);
+ }
+
+ // Factory method for SignatureLoadModel.
+ private static SsaSpikeDetector Create(IHostEnvironment env, ModelLoadContext ctx)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+
+ return new SsaSpikeDetector(env, ctx);
+ }
+
+ internal SsaSpikeDetector(IHostEnvironment env, ModelLoadContext ctx)
+ : base(env, ctx, LoaderSignature)
{
// *** Binary format ***
//
@@ -122,5 +169,83 @@ public override void Save(ModelSaveContext ctx)
base.Save(ctx);
}
+
+ // Factory method for SignatureLoadRowMapper.
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
+ }
+
+ ///
+ /// Estimator for
+ ///
+ public sealed class SsaSpikeEstimator : IEstimator
+ {
+ private readonly IHost _host;
+ private readonly SsaSpikeDetector.Arguments _args;
+
+ ///
+ /// Create a new instance of
+ ///
+ /// Host Environment.
+ /// Name of the input column.
+ /// The name of the new column.
+ /// The confidence for spike detection in the range [0, 100].
+ /// The size of the sliding window for computing the p-value.
+ /// The change history length.
+ /// The change history length.
+ /// The argument that determines whether to detect positive or negative anomalies, or both.
+ /// The function used to compute the error between the expected and the observed value.
+ public SsaSpikeEstimator(IHostEnvironment env, string inputColumn, string outputColumn, int confidence,
+ int pvalueHistoryLength, int trainingWindowSize, int seasonalityWindowSize, AnomalySide side = AnomalySide.TwoSided,
+ ErrorFunctionUtils.ErrorFunction errorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference)
+ : this(env, new SsaSpikeDetector.Arguments
+ {
+ Name = outputColumn,
+ Source = inputColumn,
+ Confidence = confidence,
+ PvalueHistoryLength = pvalueHistoryLength,
+ TrainingWindowSize = trainingWindowSize,
+ SeasonalWindowSize = seasonalityWindowSize,
+ Side = side,
+ ErrorFunction = errorFunction
+ })
+ {
+ }
+
+ public SsaSpikeEstimator(IHostEnvironment env, SsaSpikeDetector.Arguments args)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(SsaSpikeEstimator));
+
+ _host.CheckNonEmpty(args.Name, nameof(args.Name));
+ _host.CheckNonEmpty(args.Source, nameof(args.Source));
+
+ _args = args;
+ }
+
+ public SsaSpikeDetector Fit(IDataView input)
+ {
+ _host.CheckValue(input, nameof(input));
+ return new SsaSpikeDetector(_host, _args, input);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ _host.CheckValue(inputSchema, nameof(inputSchema));
+
+ if (!inputSchema.TryFindColumn(_args.Source, out var col))
+ throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source);
+ if (col.ItemType != NumberType.R4)
+ throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source, NumberType.R4.ToString(), col.GetTypeString());
+
+ var metadata = new List() {
+ new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)
+ };
+ var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);
+ resultDic[_args.Name] = new SchemaShape.Column(
+ _args.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata));
+
+ return new SchemaShape(resultDic.Values);
+ }
}
}
diff --git a/src/Microsoft.ML.TimeSeries/TimeSeriesProcessing.cs b/src/Microsoft.ML.TimeSeries/TimeSeriesProcessing.cs
index 6b4275ffa3..65c7346d55 100644
--- a/src/Microsoft.ML.TimeSeries/TimeSeriesProcessing.cs
+++ b/src/Microsoft.ML.TimeSeries/TimeSeriesProcessing.cs
@@ -30,7 +30,7 @@ public static CommonOutputs.TransformOutput ExponentialAverage(IHostEnvironment
public static CommonOutputs.TransformOutput IidChangePointDetector(IHostEnvironment env, IidChangePointDetector.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "IidChangePointDetector", input);
- var view = new IidChangePointDetector(h, input, input.Data);
+ var view = new IidChangePointEstimator(h, input).Fit(input.Data).Transform(input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, view, input.Data),
@@ -42,7 +42,7 @@ public static CommonOutputs.TransformOutput IidChangePointDetector(IHostEnvironm
public static CommonOutputs.TransformOutput IidSpikeDetector(IHostEnvironment env, IidSpikeDetector.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "IidSpikeDetector", input);
- var view = new IidSpikeDetector(h, input, input.Data);
+ var view = new IidSpikeEstimator(h, input).Fit(input.Data).Transform(input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, view, input.Data),
@@ -90,7 +90,7 @@ public static CommonOutputs.TransformOutput SlidingWindowTransform(IHostEnvironm
public static CommonOutputs.TransformOutput SsaChangePointDetector(IHostEnvironment env, SsaChangePointDetector.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "SsaChangePointDetector", input);
- var view = new SsaChangePointDetector(h, input, input.Data);
+ var view = new SsaChangePointEstimator(h, input).Fit(input.Data).Transform(input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, view, input.Data),
@@ -102,7 +102,7 @@ public static CommonOutputs.TransformOutput SsaChangePointDetector(IHostEnvironm
public static CommonOutputs.TransformOutput SsaSpikeDetector(IHostEnvironment env, SsaSpikeDetector.Arguments input)
{
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "SsaSpikeDetector", input);
- var view = new SsaSpikeDetector(h, input, input.Data);
+ var view = new SsaSpikeEstimator(h, input).Fit(input.Data).Transform(input.Data);
return new CommonOutputs.TransformOutput()
{
Model = new TransformModel(h, view, input.Data),
diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeIidChangePoint-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeIidChangePoint-Schema.txt
index ff1d284359..dd8d7959bf 100644
--- a/test/BaselineOutput/Common/SavePipe/SavePipeIidChangePoint-Schema.txt
+++ b/test/BaselineOutput/Common/SavePipe/SavePipeIidChangePoint-Schema.txt
@@ -1,7 +1,7 @@
---- BoundLoader ----
1 columns:
Features: R4
----- IidChangePointDetector ----
+---- SequentialDataTransform ----
2 columns:
Features: R4
Anomaly: Vec
@@ -16,7 +16,7 @@
fAnomaly: Vec
Metadata 'SlotNames': Vec: Length=4, Count=4
[0] 'Alert', [1] 'Raw Score', [2] 'P-Value Score', [3] 'Martingale Score'
----- IidChangePointDetector ----
+---- SequentialDataTransform ----
4 columns:
Features: R4
Anomaly: Vec
diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeIidSpike-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeIidSpike-Schema.txt
index 25d32dd015..659c20133c 100644
--- a/test/BaselineOutput/Common/SavePipe/SavePipeIidSpike-Schema.txt
+++ b/test/BaselineOutput/Common/SavePipe/SavePipeIidSpike-Schema.txt
@@ -1,7 +1,7 @@
---- BoundLoader ----
1 columns:
Features: R4
----- IidSpikeDetector ----
+---- SequentialDataTransform ----
2 columns:
Features: R4
Anomaly: Vec
@@ -16,7 +16,7 @@
fAnomaly: Vec
Metadata 'SlotNames': Vec: Length=3, Count=3
[0] 'Alert', [1] 'Raw Score', [2] 'P-Value Score'
----- IidSpikeDetector ----
+---- SequentialDataTransform ----
4 columns:
Features: R4
Anomaly: Vec
diff --git a/test/BaselineOutput/SingleDebug/SavePipe/SavePipeSsaSpikeNoData-Schema.txt b/test/BaselineOutput/SingleDebug/SavePipe/SavePipeSsaSpikeNoData-Schema.txt
index 764b0f98c0..f0dcf9138b 100644
--- a/test/BaselineOutput/SingleDebug/SavePipe/SavePipeSsaSpikeNoData-Schema.txt
+++ b/test/BaselineOutput/SingleDebug/SavePipe/SavePipeSsaSpikeNoData-Schema.txt
@@ -1,7 +1,7 @@
---- BoundLoader ----
1 columns:
Features: R4
----- SsaSpikeDetector ----
+---- SequentialDataTransform ----
2 columns:
Features: R4
Anomaly: Vec
diff --git a/test/BaselineOutput/SingleRelease/SavePipe/SavePipeSsaSpikeNoData-Schema.txt b/test/BaselineOutput/SingleRelease/SavePipe/SavePipeSsaSpikeNoData-Schema.txt
index 764b0f98c0..f0dcf9138b 100644
--- a/test/BaselineOutput/SingleRelease/SavePipe/SavePipeSsaSpikeNoData-Schema.txt
+++ b/test/BaselineOutput/SingleRelease/SavePipe/SavePipeSsaSpikeNoData-Schema.txt
@@ -1,7 +1,7 @@
---- BoundLoader ----
1 columns:
Features: R4
----- SsaSpikeDetector ----
+---- SequentialDataTransform ----
2 columns:
Features: R4
Anomaly: Vec
diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs
index b0abe45a8a..125105138c 100644
--- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs
+++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs
@@ -52,12 +52,13 @@ public void ChangeDetection()
Confidence = 80,
Source = "Value",
Name = "Change",
- ChangeHistoryLength = size,
- Data = dataView
+ ChangeHistoryLength = size
};
-
- var detector = TimeSeriesProcessing.IidChangePointDetector(env, args);
- var output = detector.Model.Apply(env, dataView);
+ // Train
+ var detector = new IidChangePointEstimator(env, args).Fit(dataView);
+ // Transform
+ var output = detector.Transform(dataView);
+ // Get predictions
var enumerator = output.AsEnumerable(env, true).GetEnumerator();
Prediction row = null;
List expectedValues = new List() { 0, 5, 0.5, 5.1200000000000114E-08, 0, 5, 0.4999999995, 5.1200000046080209E-08, 0, 5, 0.4999999995, 5.1200000092160303E-08,
@@ -80,8 +81,8 @@ public void ChangePointDetectionWithSeasonality()
{
using (var env = new ConsoleEnvironment(conc: 1))
{
- const int ChangeHistorySize = 2000;
- const int SeasonalitySize = 1000;
+ const int ChangeHistorySize = 10;
+ const int SeasonalitySize = 10;
const int NumberOfSeasonsInTraining = 5;
const int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize;
@@ -94,7 +95,6 @@ public void ChangePointDetectionWithSeasonality()
Source = "Value",
Name = "Change",
ChangeHistoryLength = ChangeHistorySize,
- Data = dataView,
TrainingWindowSize = MaxTrainingSize,
SeasonalWindowSize = SeasonalitySize
};
@@ -106,21 +106,24 @@ public void ChangePointDetectionWithSeasonality()
for (int i = 0; i < ChangeHistorySize; i++)
data.Add(new Data(i * 100));
- var detector = TimeSeriesProcessing.SsaChangePointDetector(env, args);
- var output = detector.Model.Apply(env, dataView);
+ // Train
+ var detector = new SsaChangePointEstimator(env, args).Fit(dataView);
+ // Transform
+ var output = detector.Transform(dataView);
+ // Get predictions
var enumerator = output.AsEnumerable(env, true).GetEnumerator();
Prediction row = null;
- List expectedValues = new List() { 0, 0, 0.5, 0, 0, 1, 0.15865526383236372,
- 0, 0, 1.6069464981555939, 0.05652458872960725, 0, 0, 2.0183047652244568, 0.11021633531076747, 0};
+ List expectedValues = new List() { 0, -3.31410598754883, 0.5, 5.12000000000001E-08, 0, 1.5700820684432983, 5.2001145245395008E-07,
+ 0.012414560443710681, 0, 1.2854313254356384, 0.28810801662678009, 0.02038940454467935, 0, -1.0950627326965332, 0.36663890634019225, 0.026956459625565483};
int index = 0;
while (enumerator.MoveNext() && index < expectedValues.Count)
{
row = enumerator.Current;
- Assert.Equal(expectedValues[index++], row.Change[0], precision: 7);
- Assert.Equal(expectedValues[index++], row.Change[1], precision: 7);
- Assert.Equal(expectedValues[index++], row.Change[2], precision: 7);
- Assert.Equal(expectedValues[index++], row.Change[3], precision: 7);
+ Assert.Equal(expectedValues[index++], row.Change[0], precision: 7); // Alert
+ Assert.Equal(expectedValues[index++], row.Change[1], precision: 7); // Raw score
+ Assert.Equal(expectedValues[index++], row.Change[2], precision: 7); // P-Value score
+ Assert.Equal(expectedValues[index++], row.Change[3], precision: 7); // Martingale score
}
}
}
diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs
new file mode 100644
index 0000000000..3d7cfd5d32
--- /dev/null
+++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs
@@ -0,0 +1,166 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.RunTests;
+using Microsoft.ML.Runtime.TimeSeriesProcessing;
+using System.Collections.Generic;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.Tests
+{
+ public class TimeSeriesEstimatorTests : TestDataPipeBase
+ {
+ private const int inputSize = 150528;
+
+ private class Data
+ {
+ public float Value;
+
+ public Data(float value)
+ {
+ Value = value;
+ }
+ }
+
+ private class TestDataXY
+ {
+ [VectorType(inputSize)]
+ public float[] A;
+ }
+ private class TestDataDifferntType
+ {
+ [VectorType(inputSize)]
+ public string[] data_0;
+ }
+
+ public TimeSeriesEstimatorTests(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ [Fact]
+ void TestSsaChangePointEstimator()
+ {
+ int Confidence = 95;
+ int ChangeHistorySize = 10;
+ int SeasonalitySize = 10;
+ int NumberOfSeasonsInTraining = 5;
+ int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize;
+
+ List data = new List();
+ var dataView = Env.CreateStreamingDataView(data);
+
+ for (int j = 0; j < NumberOfSeasonsInTraining; j++)
+ for (int i = 0; i < SeasonalitySize; i++)
+ data.Add(new Data(i));
+
+ for (int i = 0; i < ChangeHistorySize; i++)
+ data.Add(new Data(i * 100));
+
+ var pipe = new SsaChangePointEstimator(Env, "Value", "Change",
+ Confidence, ChangeHistorySize, MaxTrainingSize, SeasonalitySize);
+
+ var xyData = new List { new TestDataXY() { A = new float[inputSize] } };
+ var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } };
+
+ var invalidDataWrongNames = ComponentCreation.CreateDataView(Env, xyData);
+ var invalidDataWrongTypes = ComponentCreation.CreateDataView(Env, stringData);
+
+ TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongTypes);
+ TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongNames);
+
+ Done();
+ }
+
+ [Fact]
+ void TestSsaSpikeEstimator()
+ {
+ int Confidence = 95;
+ int PValueHistorySize = 10;
+ int SeasonalitySize = 10;
+ int NumberOfSeasonsInTraining = 5;
+ int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize;
+
+ List data = new List();
+ var dataView = Env.CreateStreamingDataView(data);
+
+ for (int j = 0; j < NumberOfSeasonsInTraining; j++)
+ for (int i = 0; i < SeasonalitySize; i++)
+ data.Add(new Data(i));
+
+ for (int i = 0; i < PValueHistorySize; i++)
+ data.Add(new Data(i * 100));
+
+ var pipe = new SsaSpikeEstimator(Env, "Value", "Change",
+ Confidence, PValueHistorySize, MaxTrainingSize, SeasonalitySize);
+
+ var xyData = new List { new TestDataXY() { A = new float[inputSize] } };
+ var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } };
+
+ var invalidDataWrongNames = ComponentCreation.CreateDataView(Env, xyData);
+ var invalidDataWrongTypes = ComponentCreation.CreateDataView(Env, stringData);
+
+ TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongTypes);
+ TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongNames);
+
+ Done();
+ }
+
+ [Fact]
+ void TestIidChangePointEstimator()
+ {
+ int Confidence = 95;
+ int ChangeHistorySize = 10;
+
+ List data = new List();
+ var dataView = Env.CreateStreamingDataView(data);
+
+ for (int i = 0; i < ChangeHistorySize; i++)
+ data.Add(new Data(i * 100));
+
+ var pipe = new IidChangePointEstimator(Env,
+ "Value", "Change", Confidence, ChangeHistorySize);
+
+ var xyData = new List { new TestDataXY() { A = new float[inputSize] } };
+ var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } };
+
+ var invalidDataWrongNames = ComponentCreation.CreateDataView(Env, xyData);
+ var invalidDataWrongTypes = ComponentCreation.CreateDataView(Env, stringData);
+
+ TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongTypes);
+ TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongNames);
+
+ Done();
+ }
+
+ [Fact]
+ void TestIidSpikeEstimator()
+ {
+ int Confidence = 95;
+ int PValueHistorySize = 10;
+
+ List data = new List();
+ var dataView = Env.CreateStreamingDataView(data);
+
+ for (int i = 0; i < PValueHistorySize; i++)
+ data.Add(new Data(i * 100));
+
+ var pipe = new IidSpikeEstimator(Env,
+ "Value", "Change", Confidence, PValueHistorySize);
+
+ var xyData = new List { new TestDataXY() { A = new float[inputSize] } };
+ var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } };
+
+ var invalidDataWrongNames = ComponentCreation.CreateDataView(Env, xyData);
+ var invalidDataWrongTypes = ComponentCreation.CreateDataView(Env, stringData);
+
+ TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongTypes);
+ TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongNames);
+
+ Done();
+ }
+ }
+}