diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs index 5c01472ed9..6046fabc72 100644 --- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs +++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs @@ -221,13 +221,16 @@ private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "SSAMODLR", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Added saving _state and _nextPrediction + verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, loaderAssemblyName: typeof(AdaptiveSingularSpectrumSequenceModeler).Assembly.FullName); } + private const int VersionSavingStateAndPrediction = 0x00010002; + /// /// The constructor for Adaptive SSA model. /// @@ -333,6 +336,7 @@ private AdaptiveSingularSpectrumSequenceModeler(AdaptiveSingularSpectrumSequence _autoregressionNoiseVariance = model._autoregressionNoiseVariance; _observationNoiseMean = model._observationNoiseMean; _autoregressionNoiseMean = model._autoregressionNoiseMean; + _nextPrediction = model._nextPrediction; _maxTrendRatio = model._maxTrendRatio; _shouldStablize = model._shouldStablize; _shouldMaintainInfo = model._shouldMaintainInfo; @@ -368,11 +372,13 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo // RankSelectionMethod: _rankSelectionMethod // bool: isWeightSet // float[]: _alpha + // float[]: _state // bool: ShouldComputeForecastIntervals // float: _observationNoiseVariance // float: _autoregressionNoiseVariance // float: _observationNoiseMean // float: _autoregressionNoiseMean + // float: _nextPrediction // int: _maxRank // bool: _shouldStablize // bool: _shouldMaintainInfo @@ -404,6 +410,14 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo _alpha = ctx.Reader.ReadFloatArray(); _host.CheckDecode(Utils.Size(_alpha) == _windowSize - 1); + if (ctx.Header.ModelVerReadable >= VersionSavingStateAndPrediction) + { + _state = ctx.Reader.ReadFloatArray(); + _host.CheckDecode(Utils.Size(_state) == _windowSize - 1); + } + else + _state = new Single[_windowSize - 1]; + ShouldComputeForecastIntervals = ctx.Reader.ReadBoolByte(); _observationNoiseVariance = ctx.Reader.ReadSingle(); @@ -414,6 +428,8 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo _observationNoiseMean = ctx.Reader.ReadSingle(); _autoregressionNoiseMean = ctx.Reader.ReadSingle(); + if (ctx.Header.ModelVerReadable >= VersionSavingStateAndPrediction) + _nextPrediction = ctx.Reader.ReadSingle(); _maxRank = ctx.Reader.ReadInt32(); _host.CheckDecode(1 <= _maxRank && _maxRank <= _windowSize - 1); @@ -444,7 +460,7 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo } _buffer = new FixedSizeQueue(_seriesLength); - _state = new Single[_windowSize - 1]; + _x = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); _xSmooth = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); } @@ -475,11 +491,13 @@ public void Save(ModelSaveContext ctx) // RankSelectionMethod: _rankSelectionMethod // bool: _isWeightSet // float[]: _alpha + // float[]: _state // bool: ShouldComputeForecastIntervals // float: _observationNoiseVariance // float: _autoregressionNoiseVariance // float: _observationNoiseMean // float: _autoregressionNoiseMean + // float: _nextPrediction // int: _maxRank // bool: _shouldStablize // bool: _shouldMaintainInfo @@ -494,11 +512,13 @@ public void Save(ModelSaveContext ctx) ctx.Writer.Write((byte)_rankSelectionMethod); ctx.Writer.WriteBoolByte(_wTrans != null); ctx.Writer.WriteFloatArray(_alpha); + ctx.Writer.WriteFloatArray(_state); ctx.Writer.WriteBoolByte(ShouldComputeForecastIntervals); ctx.Writer.Write(_observationNoiseVariance); ctx.Writer.Write(_autoregressionNoiseVariance); ctx.Writer.Write(_observationNoiseMean); ctx.Writer.Write(_autoregressionNoiseMean); + ctx.Writer.Write(_nextPrediction); ctx.Writer.Write(_maxRank); ctx.Writer.WriteBoolByte(_shouldStablize); ctx.Writer.WriteBoolByte(_shouldMaintainInfo); diff --git a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs index bb14178827..900407adec 100644 --- a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs +++ b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs @@ -15,18 +15,32 @@ namespace Microsoft.ML.Runtime.TimeSeriesProcessing /// public abstract class IidAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase { - public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env, IDataView input) - : base(args, name, env, input) + public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment env) + : base(args, name, env) { InitialWindowSize = 0; } - public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name, IDataView input) - : base(env, ctx, name, input) + public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name) + : base(env, ctx, name) { Host.CheckDecode(InitialWindowSize == 0); } + public override Schema GetOutputSchema(Schema inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + + if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName); + + var colType = inputSchema.GetColumnType(col); + if (colType != NumberType.R4) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, NumberType.R4.ToString(), colType.ToString()); + + return Transform(new EmptyDataView(Host, inputSchema)).Schema; + } + public override void Save(ModelSaveContext ctx) { ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs index 909bbc390c..9bb6dcab95 100644 --- a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs @@ -3,24 +3,35 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TimeSeriesProcessing; +using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; -[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IidChangePointDetector), typeof(IidChangePointDetector.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), typeof(IidChangePointDetector.Arguments), typeof(SignatureDataTransform), IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature, IidChangePointDetector.ShortName)] -[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IidChangePointDetector), null, typeof(SignatureLoadDataTransform), + +[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), null, typeof(SignatureLoadDataTransform), + IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)] + +[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IidChangePointDetector), null, typeof(SignatureLoadModel), IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)] +[assembly: LoadableClass(typeof(IRowMapper), typeof(IidChangePointDetector), null, typeof(SignatureLoadRowMapper), + IidChangePointDetector.UserName, IidChangePointDetector.LoaderSignature)] + namespace Microsoft.ML.Runtime.TimeSeriesProcessing { /// /// This class implements the change point detector transform for an i.i.d. sequence based on adaptive kernel density estimation and martingales. /// - public sealed class IidChangePointDetector : IidAnomalyDetectionBase, ITransformTemplate + public sealed class IidChangePointDetector : IidAnomalyDetectionBase { internal const string Summary = "This transform detects the change-points in an i.i.d. sequence using adaptive kernel density estimation and martingales."; public const string LoaderSignature = "IidChangePointDetector"; @@ -89,8 +100,18 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(IidChangePointDetector).Assembly.FullName); } - public IidChangePointDetector(IHostEnvironment env, Arguments args, IDataView input) - : base(new BaseArguments(args), LoaderSignature, env, input) + // Factory method for SignatureDataTransform. + private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + return new IidChangePointDetector(env, args).MakeDataTransform(input); + } + + internal IidChangePointDetector(IHostEnvironment env, Arguments args) + : base(new BaseArguments(args), LoaderSignature, env) { switch (Martingale) { @@ -109,8 +130,28 @@ public IidChangePointDetector(IHostEnvironment env, Arguments args, IDataView in } } - public IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - : base(env, ctx, LoaderSignature, input) + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(input, nameof(input)); + + return new IidChangePointDetector(env, ctx).MakeDataTransform(input); + } + + // Factory method for SignatureLoadModel. + private static IidChangePointDetector Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + return new IidChangePointDetector(env, ctx); + } + + internal IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx) + : base(env, ctx, LoaderSignature) { // *** Binary format *** // @@ -119,8 +160,8 @@ public IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx, IDataV Host.CheckDecode(Side == AnomalySide.TwoSided); } - private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector transform, IDataView newSource) - : base(new BaseArguments(transform), LoaderSignature, env, newSource) + private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector transform) + : base(new BaseArguments(transform), LoaderSignature, env) { } @@ -139,9 +180,65 @@ public override void Save(ModelSaveContext ctx) base.Save(ctx); } - public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + } + + /// + /// Estimator for + /// + public sealed class IidChangePointEstimator : TrivialEstimator + { + /// + /// Create a new instance of + /// + /// Host Environment. + /// Name of the input column. + /// The name of the new column. + /// The confidence for change point detection in the range [0, 100]. + /// The change history length. + /// The martingale used for scoring. + /// The epsilon parameter for the Power martingale. + public IidChangePointEstimator(IHostEnvironment env, string inputColumn, string outputColumn, int confidence, + int changeHistoryLength, MartingaleType martingale = MartingaleType.Power, double eps = 0.1) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)), + new IidChangePointDetector(env, new IidChangePointDetector.Arguments + { + Name = outputColumn, + Source = inputColumn, + Confidence = confidence, + ChangeHistoryLength = changeHistoryLength, + Martingale = martingale, + PowerMartingaleEpsilon = eps + })) + { + } + + public IidChangePointEstimator(IHostEnvironment env, IidChangePointDetector.Arguments args) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)), + new IidChangePointDetector(env, args)) { - return new IidChangePointDetector(env, this, newSource); + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + + if (!inputSchema.TryFindColumn(Transformer.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName); + if (col.ItemType != NumberType.R4) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName, NumberType.R4.ToString(), col.GetTypeString()); + + var metadata = new List() { + new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) + }; + var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + + resultDic[Transformer.OutputColumnName] = new SchemaShape.Column( + Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); + + return new SchemaShape(resultDic.Values); } } } diff --git a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs index f4d30c8696..cd7dafc83c 100644 --- a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs @@ -2,24 +2,35 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TimeSeriesProcessing; +using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; -[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IidSpikeDetector), typeof(IidSpikeDetector.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IDataTransform), typeof(IidSpikeDetector), typeof(IidSpikeDetector.Arguments), typeof(SignatureDataTransform), IidSpikeDetector.UserName, IidSpikeDetector.LoaderSignature, IidSpikeDetector.ShortName)] -[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IidSpikeDetector), null, typeof(SignatureLoadDataTransform), + +[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IDataTransform), typeof(IidSpikeDetector), null, typeof(SignatureLoadDataTransform), + IidSpikeDetector.UserName, IidSpikeDetector.LoaderSignature)] + +[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IidSpikeDetector), null, typeof(SignatureLoadModel), IidSpikeDetector.UserName, IidSpikeDetector.LoaderSignature)] +[assembly: LoadableClass(typeof(IRowMapper), typeof(IidSpikeDetector), null, typeof(SignatureLoadRowMapper), + IidSpikeDetector.UserName, IidSpikeDetector.LoaderSignature)] + namespace Microsoft.ML.Runtime.TimeSeriesProcessing { /// /// This class implements the spike detector transform for an i.i.d. sequence based on adaptive kernel density estimation. /// - public sealed class IidSpikeDetector : IidAnomalyDetectionBase, ITransformTemplate + public sealed class IidSpikeDetector : IidAnomalyDetectionBase { internal const string Summary = "This transform detects the spikes in a i.i.d. sequence using adaptive kernel density estimation."; public const string LoaderSignature = "IidSpikeDetector"; @@ -85,22 +96,52 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(IidSpikeDetector).Assembly.FullName); } - public IidSpikeDetector(IHostEnvironment env, Arguments args, IDataView input) - : base(new BaseArguments(args), LoaderSignature, env, input) + // Factory method for SignatureDataTransform. + private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + return new IidSpikeDetector(env, args).MakeDataTransform(input); + } + + internal IidSpikeDetector(IHostEnvironment env, Arguments args) + : base(new BaseArguments(args), LoaderSignature, env) { // This constructor is empty. } - public IidSpikeDetector(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - : base(env, ctx, LoaderSignature, input) + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(input, nameof(input)); + + return new IidSpikeDetector(env, ctx).MakeDataTransform(input); + } + + // Factory method for SignatureLoadModel. + private static IidSpikeDetector Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + return new IidSpikeDetector(env, ctx); + } + + public IidSpikeDetector(IHostEnvironment env, ModelLoadContext ctx) + : base(env, ctx, LoaderSignature) { // *** Binary format *** // Host.CheckDecode(ThresholdScore == AlertingScore.PValueScore); } - private IidSpikeDetector(IHostEnvironment env, IidSpikeDetector transform, IDataView newSource) - : base(new BaseArguments(transform), LoaderSignature, env, newSource) + private IidSpikeDetector(IHostEnvironment env, IidSpikeDetector transform) + : base(new BaseArguments(transform), LoaderSignature, env) { } @@ -118,9 +159,60 @@ public override void Save(ModelSaveContext ctx) base.Save(ctx); } - public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + } + + /// + /// Estimator for + /// + public sealed class IidSpikeEstimator : TrivialEstimator + { + /// + /// Create a new instance of + /// + /// Host Environment. + /// Name of the input column. + /// The name of the new column. + /// The confidence for spike detection in the range [0, 100]. + /// The size of the sliding window for computing the p-value. + /// The argument that determines whether to detect positive or negative anomalies, or both. + public IidSpikeEstimator(IHostEnvironment env, string inputColumn, string outputColumn, int confidence, int pvalueHistoryLength, AnomalySide side = AnomalySide.TwoSided) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidSpikeDetector)), + new IidSpikeDetector(env, new IidSpikeDetector.Arguments + { + Name = outputColumn, + Source = inputColumn, + Confidence = confidence, + PvalueHistoryLength = pvalueHistoryLength, + Side = side + })) { - return new IidSpikeDetector(env, this, newSource); + } + + public IidSpikeEstimator(IHostEnvironment env, IidSpikeDetector.Arguments args) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidSpikeEstimator)), new IidSpikeDetector(env, args)) + { + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + + if (!inputSchema.TryFindColumn(Transformer.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName); + if (col.ItemType != NumberType.R4) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InputColumnName, NumberType.R4.ToString(), col.GetTypeString()); + + var metadata = new List() { + new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) + }; + var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + resultDic[Transformer.OutputColumnName] = new SchemaShape.Column( + Transformer.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); + + return new SchemaShape(resultDic.Values); } } } diff --git a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs index 240d05e1c0..e80f41a985 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs @@ -24,7 +24,7 @@ namespace Microsoft.ML.Runtime.TimeSeriesProcessing /// /// The type of the input sequence /// The type of the state object for sequential anomaly detection. Must be a class inherited from AnomalyDetectionStateBase - public abstract class SequentialAnomalyDetectionTransformBase : SequentialTransformBase, TState> + public abstract class SequentialAnomalyDetectionTransformBase : SequentialTransformerBase, TState> where TState : SequentialAnomalyDetectionTransformBase.AnomalyDetectionStateBase, new() { /// @@ -144,10 +144,6 @@ public abstract class ArgumentsBase // The size of the VBuffer in the dst column. private int _outputLength; - private readonly SchemaImpl _wrappedSchema; - - public override Schema Schema => _wrappedSchema.AsSchema; - private static int GetOutputLength(AlertingScore alertingScore, IHostEnvironment host) { switch (alertingScore) @@ -163,23 +159,10 @@ private static int GetOutputLength(AlertingScore alertingScore, IHostEnvironment } } - private static SchemaImpl CreateSchema(ISchema parentSchema, string colName, int length) - { - Contracts.AssertValue(parentSchema); - Contracts.Assert(2 <= length && length <= 4); - - string[] names = { "Alert", "Raw Score", "P-Value Score", "Martingale Score" }; - int col; - bool result = parentSchema.TryGetColumnIndex(colName, out col); - Contracts.Assert(result); - - return new SchemaImpl(parentSchema, col, names, length); - } - - protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env, IDataView input, + protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env, AnomalySide anomalySide, MartingaleType martingale, AlertingScore alertingScore, Double powerMartingaleEpsilon, Double alertThreshold) - : base(windowSize, initialWindowSize, inputColumnName, outputColumnName, name, env, input, new VectorType(NumberType.R8, GetOutputLength(alertingScore, env))) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), windowSize, initialWindowSize, inputColumnName, outputColumnName, new VectorType(NumberType.R8, GetOutputLength(alertingScore, env))) { Host.CheckUserArg(Enum.IsDefined(typeof(MartingaleType), martingale), nameof(ArgumentsBase.Martingale), "Value is undefined."); Host.CheckUserArg(Enum.IsDefined(typeof(AnomalySide), anomalySide), nameof(ArgumentsBase.Side), "Value is undefined."); @@ -198,17 +181,16 @@ protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWin PowerMartingaleEpsilon = powerMartingaleEpsilon; AlertThreshold = alertThreshold; _outputLength = GetOutputLength(ThresholdScore, Host); - _wrappedSchema = CreateSchema(base.Schema, outputColumnName, _outputLength); } - protected SequentialAnomalyDetectionTransformBase(ArgumentsBase args, string name, IHostEnvironment env, IDataView input) - : this(args.WindowSize, args.InitialWindowSize, args.Source, args.Name, name, env, input, args.Side, args.Martingale, + protected SequentialAnomalyDetectionTransformBase(ArgumentsBase args, string name, IHostEnvironment env) + : this(args.WindowSize, args.InitialWindowSize, args.Source, args.Name, name, env, args.Side, args.Martingale, args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold) { } - protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name, IDataView input) - : base(env, ctx, name, input) + protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), ctx) { // *** Binary format *** // @@ -242,7 +224,6 @@ protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, ModelLoa Host.CheckDecode(ThresholdScore != AlertingScore.PValueScore || (0 <= AlertThreshold && AlertThreshold <= 1)); _outputLength = GetOutputLength(ThresholdScore, Host); - _wrappedSchema = CreateSchema(base.Schema, OutputColumnName, _outputLength); } public override void Save(ModelSaveContext ctx) @@ -557,102 +538,86 @@ protected override sealed void InitializeStateCore() protected abstract Double ComputeRawAnomalyScore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration); } - /// - /// Schema implementation to add slot name metadata to the produced output column. - /// - private sealed class SchemaImpl : ISchema - { - private readonly ISchema _parent; - private readonly int _col; - private readonly ColumnType _type; - private readonly string[] _names; - private readonly int _namesLength; - private readonly MetadataUtils.MetadataGetter>> _getter; - - public Schema AsSchema { get; } + protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(Host, this, schema); - public int ColumnCount { get { return _parent.ColumnCount; } } + private sealed class Mapper : IRowMapper + { + private readonly IHost _host; + private readonly SequentialAnomalyDetectionTransformBase _parent; + private readonly ISchema _parentSchema; + private readonly int _inputColumnIndex; + private readonly VBuffer> _slotNames; - /// - /// Constructs the schema. - /// - /// The schema we will wrap. - /// Aside from presenting that additional piece of metadata, the constructed schema - /// will appear identical to this input schema. - /// The column in that has the metadata. - /// - /// - public SchemaImpl(ISchema schema, int col, string[] names, int length) + public Mapper(IHostEnvironment env, SequentialAnomalyDetectionTransformBase parent, ISchema inputSchema) { - Contracts.Assert(length > 0); - Contracts.Assert(Utils.Size(names) >= length); - Contracts.AssertValue(schema); - Contracts.Assert(0 <= col && col < schema.ColumnCount); - _parent = schema; - _col = col; - - _names = names; - _namesLength = length; - - _type = new VectorType(TextType.Instance, _namesLength); - Contracts.AssertValue(_type); - _getter = GetSlotNames; - - AsSchema = Schema.Create(this); + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(Mapper)); + _host.CheckValue(inputSchema, nameof(inputSchema)); + _host.CheckValue(parent, nameof(parent)); + + if (!inputSchema.TryGetColumnIndex(parent.InputColumnName, out _inputColumnIndex)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName); + + var colType = inputSchema.GetColumnType(_inputColumnIndex); + if (colType != NumberType.R4) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName, NumberType.R4.ToString(), colType.ToString()); + + _parent = parent; + _parentSchema = inputSchema; + _slotNames = new VBuffer>(4, new[] { "Alert".AsMemory(), "Raw Score".AsMemory(), + "P-Value Score".AsMemory(), "Martingale Score".AsMemory() }); } - public bool TryGetColumnIndex(string name, out int col) + public Schema.Column[] GetOutputColumns() { - return _parent.TryGetColumnIndex(name, out col); + var meta = new Schema.Metadata.Builder(); + meta.AddSlotNames(_parent._outputLength, GetSlotNames); + var info = new Schema.Column[1]; + info[0] = new Schema.Column(_parent.OutputColumnName, new VectorType(NumberType.R8, _parent._outputLength), meta.GetMetadata()); + return info; } - public string GetColumnName(int col) - { - return _parent.GetColumnName(col); - } + public void GetSlotNames(ref VBuffer> dst) => _slotNames.CopyTo(ref dst, 0, _parent._outputLength); - public ColumnType GetColumnType(int col) + public Func GetDependencies(Func activeOutput) { - return _parent.GetColumnType(col); + if (activeOutput(0)) + return col => col == _inputColumnIndex; + else + return col => false; } - public IEnumerable> GetMetadataTypes(int col) - { - var result = _parent.GetMetadataTypes(col); - if (col == _col) - return result.Prepend(_type.GetPair(MetadataUtils.Kinds.SlotNames)); - return result; - } + public void Save(ModelSaveContext ctx) => _parent.Save(ctx); - public ColumnType GetMetadataTypeOrNull(string kind, int col) + public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) { - if (col == _col && kind == MetadataUtils.Kinds.SlotNames) - return _type; - return _parent.GetMetadataTypeOrNull(kind, col); + disposer = null; + var getters = new Delegate[1]; + if (activeOutput(0)) + { + TState state = new TState(); + state.InitState(_parent.WindowSize, _parent.InitialWindowSize, _parent, _host); + getters[0] = MakeGetter(input, state); + } + return getters; } - public void GetSlotNames(int col, ref VBuffer> slotNames) - { - Contracts.Assert(col == _col); - - var result = slotNames.Values; - if (Utils.Size(result) < _namesLength) - result = new ReadOnlyMemory[_namesLength]; + private delegate void ProcessData(ref TInput src, ref VBuffer dst); - for (int i = 0; i < _namesLength; ++i) - result[i] = _names[i].AsMemory(); - - slotNames = new VBuffer>(_namesLength, result, slotNames.Indices); - } - - public void GetMetadata(string kind, int col, ref TValue value) + private Delegate MakeGetter(IRow input, TState state) { - if (col == _col && kind == MetadataUtils.Kinds.SlotNames) + _host.AssertValue(input); + var srcGetter = input.GetGetter(_inputColumnIndex); + ProcessData processData = _parent.WindowSize > 0 ? + (ProcessData) state.Process : state.ProcessWithoutBuffer; + ValueGetter > valueGetter = (ref VBuffer dst) => { - _getter.Marshal(col, ref value); - return; - } - _parent.GetMetadata(kind, col, ref value); + TInput src = default; + srcGetter(ref src); + processData(ref src, ref dst); + }; + + return valueGetter; } } } diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs new file mode 100644 index 0000000000..5e2b32a81b --- /dev/null +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs @@ -0,0 +1,406 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Core.Data; + +namespace Microsoft.ML.Runtime.TimeSeriesProcessing +{ + /// + /// The base class for sequential processing transforms. This class implements the basic sliding window buffering. The derived classes need to specify the transform logic, + /// the initialization logic and the learning logic via implementing the abstract methods TransformCore(), InitializeStateCore() and LearnStateFromDataCore(), respectively + /// + /// The input type of the sequential processing. + /// The dst type of the sequential processing. + /// The state type of the sequential processing. Must be a class inherited from StateBase + public abstract class SequentialTransformerBase : ITransformer, ICanSaveModel + where TState : SequentialTransformerBase.StateBase, new() + { + /// + /// The base class for encapsulating the State object for sequential processing. This class implements a windowed buffer. + /// + public abstract class StateBase + { + // Ideally this class should be private. However, due to the current constraints with the LambdaTransform, we need to have + // access to the state class when inheriting from SequentialTransformerBase. + protected IHost Host; + + /// + /// A reference to the parent transform that operates on the state object. + /// + protected SequentialTransformerBase ParentTransform; + + /// + /// The internal windowed buffer for buffering the values in the input sequence. + /// + protected FixedSizeQueue WindowedBuffer; + + /// + /// The buffer used to buffer the training data points. + /// + protected FixedSizeQueue InitialWindowedBuffer; + + protected int WindowSize { get; private set; } + + protected int InitialWindowSize { get; private set; } + + /// + /// Counts the number of rows observed by the transform so far. + /// + protected long RowCounter { get; private set; } + + protected long IncrementRowCounter() + { + RowCounter++; + return RowCounter; + } + + private bool _isIniatilized; + + /// + /// This method sets the window size and initializes the buffer only once. + /// Since the class needs to implement a default constructor, this methods provides a mechanism to initialize the window size and buffer. + /// + /// The size of the windowed buffer + /// The size of the windowed initial buffer used for training + /// The parent transform of this state object + /// The host + public void InitState(int windowSize, int initialWindowSize, SequentialTransformerBase parentTransform, IHost host) + { + Contracts.CheckValue(host, nameof(host), "The host cannot be null."); + host.Check(!_isIniatilized, "The window size can be set only once."); + host.CheckValue(parentTransform, nameof(parentTransform)); + host.CheckParam(windowSize >= 0, nameof(windowSize), "Must be non-negative."); + host.CheckParam(initialWindowSize >= 0, nameof(initialWindowSize), "Must be non-negative."); + + Host = host; + WindowSize = windowSize; + InitialWindowSize = initialWindowSize; + ParentTransform = parentTransform; + WindowedBuffer = (WindowSize > 0) ? new FixedSizeQueue(WindowSize) : new FixedSizeQueue(1); + InitialWindowedBuffer = (InitialWindowSize > 0) ? new FixedSizeQueue(InitialWindowSize) : new FixedSizeQueue(1); + RowCounter = 0; + + InitializeStateCore(); + _isIniatilized = true; + } + + /// + /// This method implements the basic resetting mechanism for a state object and clears the buffer. + /// + public virtual void Reset() + { + Host.Assert(_isIniatilized); + Host.Assert(WindowedBuffer != null); + Host.Assert(InitialWindowedBuffer != null); + + RowCounter = 0; + WindowedBuffer.Clear(); + InitialWindowedBuffer.Clear(); + } + + public void Process(ref TInput input, ref TOutput output) + { + if (InitialWindowedBuffer.Count < InitialWindowSize) + { + InitialWindowedBuffer.AddLast(input); + SetNaOutput(ref output); + + if (InitialWindowedBuffer.Count >= InitialWindowSize - WindowSize) + WindowedBuffer.AddLast(input); + + if (InitialWindowedBuffer.Count == InitialWindowSize) + LearnStateFromDataCore(InitialWindowedBuffer); + } + else + { + TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output); + WindowedBuffer.AddLast(input); + IncrementRowCounter(); + } + } + + public void ProcessWithoutBuffer(ref TInput input, ref TOutput output) + { + if (InitialWindowedBuffer.Count < InitialWindowSize) + { + InitialWindowedBuffer.AddLast(input); + SetNaOutput(ref output); + + if (InitialWindowedBuffer.Count == InitialWindowSize) + LearnStateFromDataCore(InitialWindowedBuffer); + } + else + { + TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output); + IncrementRowCounter(); + } + } + + /// + /// The abstract method that specifies the NA value for the dst type. + /// + /// + protected abstract void SetNaOutput(ref TOutput dst); + + /// + /// The abstract method that realizes the main logic for the transform. + /// + /// A reference to the input object. + /// A reference to the dst object. + /// A reference to the windowed buffer. + /// A long number that indicates the number of times TransformCore has been called so far (starting value = 0). + protected abstract void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref TOutput dst); + + /// + /// The abstract method that realizes the logic for initializing the state object. + /// + protected abstract void InitializeStateCore(); + + /// + /// The abstract method that realizes the logic for learning the parameters and the initial state object from data. + /// + /// A queue of data points used for training + protected abstract void LearnStateFromDataCore(FixedSizeQueue data); + } + + protected readonly IHost Host; + + /// + /// The window size for buffering. + /// + protected readonly int WindowSize; + + /// + /// The number of datapoints from the beginning of the sequence that are used for learning the initial state. + /// + protected int InitialWindowSize; + + public string InputColumnName; + public string OutputColumnName; + protected ColumnType OutputColumnType; + + public bool IsRowToRowMapper => false; + + /// + /// The main constructor for the sequential transform + /// + /// The host. + /// The size of buffer used for windowed buffering. + /// The number of datapoints picked from the beginning of the series for training the transform parameters if needed. + /// The name of the input column. + /// The name of the dst column. + /// + protected SequentialTransformerBase(IHost host, int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, ColumnType outputColType) + { + Host = host; + Host.CheckParam(initialWindowSize >= 0, nameof(initialWindowSize), "Must be non-negative."); + Host.CheckParam(windowSize >= 0, nameof(windowSize), "Must be non-negative."); + // REVIEW: Very bad design. This base class is responsible for reporting errors on + // the arguments, but the arguments themselves are not derived form any base class. + Host.CheckNonEmpty(inputColumnName, nameof(PercentileThresholdTransform.Arguments.Source)); + Host.CheckNonEmpty(outputColumnName, nameof(PercentileThresholdTransform.Arguments.Source)); + + InputColumnName = inputColumnName; + OutputColumnName = outputColumnName; + OutputColumnType = outputColType; + InitialWindowSize = initialWindowSize; + WindowSize = windowSize; + } + + protected SequentialTransformerBase(IHost host, ModelLoadContext ctx) + { + Host = host; + Host.CheckValue(ctx, nameof(ctx)); + + // *** Binary format *** + // int: _windowSize + // int: _initialWindowSize + // int (string ID): _inputColumnName + // int (string ID): _outputColumnName + // ColumnType: _transform.Schema.GetColumnType(0) + + var windowSize = ctx.Reader.ReadInt32(); + Host.CheckDecode(windowSize >= 0); + + var initialWindowSize = ctx.Reader.ReadInt32(); + Host.CheckDecode(initialWindowSize >= 0); + + var inputColumnName = ctx.LoadNonEmptyString(); + var outputColumnName = ctx.LoadNonEmptyString(); + + InputColumnName = inputColumnName; + OutputColumnName = outputColumnName; + InitialWindowSize = initialWindowSize; + WindowSize = windowSize; + + BinarySaver bs = new BinarySaver(Host, new BinarySaver.Arguments()); + OutputColumnType = bs.LoadTypeDescriptionOrNull(ctx.Reader.BaseStream); + } + + public virtual void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + Host.Assert(InitialWindowSize >= 0); + Host.Assert(WindowSize >= 0); + + // *** Binary format *** + // int: _windowSize + // int: _initialWindowSize + // int (string ID): _inputColumnName + // int (string ID): _outputColumnName + // ColumnType: _transform.Schema.GetColumnType(0) + + ctx.Writer.Write(WindowSize); + ctx.Writer.Write(InitialWindowSize); + ctx.SaveNonEmptyString(InputColumnName); + ctx.SaveNonEmptyString(OutputColumnName); + + var bs = new BinarySaver(Host, new BinarySaver.Arguments()); + bs.TryWriteTypeDescription(ctx.Writer.BaseStream, OutputColumnType, out int byteWritten); + } + + public abstract Schema GetOutputSchema(Schema inputSchema); + + protected abstract IRowMapper MakeRowMapper(ISchema schema); + + protected SequentialDataTransform MakeDataTransform(IDataView input) + { + Host.CheckValue(input, nameof(input)); + return new SequentialDataTransform(Host, this, input, MakeRowMapper(input.Schema)); + } + + public IDataView Transform(IDataView input) => MakeDataTransform(input); + + public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) + { + throw new InvalidOperationException("Not a RowToRowMapper."); + } + + public sealed class SequentialDataTransform : TransformBase, ITransformTemplate + { + private readonly IRowMapper _mapper; + private readonly SequentialTransformerBase _parent; + private readonly IDataTransform _transform; + private readonly ColumnBindings _bindings; + + public SequentialDataTransform(IHost host, SequentialTransformerBase parent, IDataView input, IRowMapper mapper) + :base(parent.Host, input) + { + _parent = parent; + _transform = CreateLambdaTransform(_parent.Host, input, _parent.InputColumnName, + _parent.OutputColumnName, InitFunction, _parent.WindowSize > 0, _parent.OutputColumnType); + _mapper = mapper; + _bindings = new ColumnBindings(Schema.Create(input.Schema), _mapper.GetOutputColumns()); + } + + private static IDataTransform CreateLambdaTransform(IHost host, IDataView input, string inputColumnName, string outputColumnName, + Action initFunction, bool hasBuffer, ColumnType outputColTypeOverride) + { + var inputSchema = SchemaDefinition.Create(typeof(DataBox)); + inputSchema[0].ColumnName = inputColumnName; + + var outputSchema = SchemaDefinition.Create(typeof(DataBox)); + outputSchema[0].ColumnName = outputColumnName; + + if (outputColTypeOverride != null) + outputSchema[0].ColumnType = outputColTypeOverride; + + Action, DataBox, TState> lambda; + if (hasBuffer) + lambda = MapFunction; + else + lambda = MapFunctionWithoutBuffer; + + return LambdaTransform.CreateMap(host, input, lambda, initFunction, inputSchema, outputSchema); + } + + private static void MapFunction(DataBox input, DataBox output, TState state) + { + state.Process(ref input.Value, ref output.Value); + } + + private static void MapFunctionWithoutBuffer(DataBox input, DataBox output, TState state) + { + state.ProcessWithoutBuffer(ref input.Value, ref output.Value); + } + + private void InitFunction(TState state) + { + state.InitState(_parent.WindowSize, _parent.InitialWindowSize, _parent, _parent.Host); + } + + public override bool CanShuffle { get { return false; } } + + protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + { + var srcCursor = _transform.GetRowCursor(predicate, rand); + return new Cursor(Host, this, srcCursor); + } + + protected override bool? ShouldUseParallelCursors(Func predicate) + { + Host.AssertValue(predicate); + return false; + } + + public override Schema Schema => _bindings.Schema; + + public override long? GetRowCount(bool lazy = true) + { + return _transform.GetRowCount(lazy); + } + + public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) + { + consolidator = null; + return new IRowCursor[] { GetRowCursorCore(predicate, rand) }; + } + + public override void Save(ModelSaveContext ctx) + { + _parent.Save(ctx); + } + + public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) + { + return new SequentialDataTransform(Contracts.CheckRef(env, nameof(env)).Register("SequentialDataTransform"), _parent, newSource, _mapper); + } + } + + /// + /// A wrapper around the cursor which replaces the schema. + /// + private sealed class Cursor : SynchronizedCursorBase, IRowCursor + { + private readonly SequentialDataTransform _parent; + + public Cursor(IHost host, SequentialDataTransform parent, IRowCursor input) + : base(host, input) + { + Ch.Assert(input.Schema.ColumnCount == parent.Schema.ColumnCount); + _parent = parent; + } + + public Schema Schema { get { return _parent.Schema; } } + + public bool IsColumnActive(int col) + { + Ch.Check(0 <= col && col < Schema.ColumnCount, "col"); + return Input.IsColumnActive(col); + } + + public ValueGetter GetGetter(int col) + { + Ch.Check(IsColumnActive(col), "col"); + return Input.GetGetter(col); + } + } + } +} diff --git a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs index f4f7023b81..86d422e484 100644 --- a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs +++ b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs @@ -106,8 +106,8 @@ public abstract class SsaArguments : ArgumentsBase protected readonly Func ErrorFunc; protected readonly ISequenceModeler Model; - public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment env, IDataView input) - : base(args.WindowSize, 0, args.Source, args.Name, name, env, input, args.Side, args.Martingale, args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold) + public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment env) + : base(args.WindowSize, 0, args.Source, args.Name, name, env, args.Side, args.Martingale, args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold) { Host.CheckUserArg(2 <= args.SeasonalWindowSize, nameof(args.SeasonalWindowSize), "Must be at least 2."); Host.CheckUserArg(0 <= args.DiscountFactor && args.DiscountFactor <= 1, nameof(args.DiscountFactor), "Must be in the range [0, 1]."); @@ -118,18 +118,13 @@ public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment ErrorFunction = args.ErrorFunction; ErrorFunc = ErrorFunctionUtils.GetErrorFunction(ErrorFunction); IsAdaptive = args.IsAdaptive; - // Creating the master SSA model Model = new AdaptiveSingularSpectrumSequenceModeler(Host, args.InitialWindowSize, SeasonalWindowSize + 1, SeasonalWindowSize, DiscountFactor, null, AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalWindowSize / 2, false, false); - - // Training the master SSA model - var data = new RoleMappedData(input, null, InputColumnName); - Model.Train(data); } - public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name, IDataView input) - : base(env, ctx, name, input) + public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name) + : base(env, ctx, name) { // *** Binary format *** // @@ -159,6 +154,20 @@ public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, strin Host.CheckDecode(Model != null); } + public override Schema GetOutputSchema(Schema inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + + if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName); + + var colType = inputSchema.GetColumnType(col); + if (colType != NumberType.R4) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, NumberType.R4.ToString(), colType.ToString()); + + return Transform(new EmptyDataView(Host, inputSchema)).Schema; + } + public override void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); @@ -200,7 +209,6 @@ protected override void InitializeAnomalyDetector() { _parentAnomalyDetector = (SsaAnomalyDetectionBase)Parent; _model = _parentAnomalyDetector.Model.Clone(); - _model.InitState(); } protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration) diff --git a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs index 8eca18a034..0aa1386bcc 100644 --- a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs @@ -3,18 +3,29 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TimeSeriesProcessing; +using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; -[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(SsaChangePointDetector), typeof(SsaChangePointDetector.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(IDataTransform), typeof(SsaChangePointDetector), typeof(SsaChangePointDetector.Arguments), typeof(SignatureDataTransform), SsaChangePointDetector.UserName, SsaChangePointDetector.LoaderSignature, SsaChangePointDetector.ShortName)] -[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(SsaChangePointDetector), null, typeof(SignatureLoadDataTransform), + +[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(IDataTransform), typeof(SsaChangePointDetector), null, typeof(SignatureLoadDataTransform), + SsaChangePointDetector.UserName, SsaChangePointDetector.LoaderSignature)] + +[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(SsaChangePointDetector), null, typeof(SignatureLoadModel), SsaChangePointDetector.UserName, SsaChangePointDetector.LoaderSignature)] +[assembly: LoadableClass(typeof(IRowMapper), typeof(SsaChangePointDetector), null, typeof(SignatureLoadRowMapper), + SsaChangePointDetector.UserName, SsaChangePointDetector.LoaderSignature)] + namespace Microsoft.ML.Runtime.TimeSeriesProcessing { /// @@ -93,8 +104,24 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(SsaChangePointDetector).Assembly.FullName); } - public SsaChangePointDetector(IHostEnvironment env, Arguments args, IDataView input) - : base(new BaseArguments(args), LoaderSignature, env, input) + internal SsaChangePointDetector(IHostEnvironment env, Arguments args, IDataView input) + : this(env, args) + { + Model.Train(new RoleMappedData(input, null, InputColumnName)); + } + + // Factory method for SignatureDataTransform. + private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + return new SsaChangePointDetector(env, args, input).MakeDataTransform(input); + } + + internal SsaChangePointDetector(IHostEnvironment env, Arguments args) + : base(new BaseArguments(args), LoaderSignature, env) { switch (Martingale) { @@ -113,8 +140,28 @@ public SsaChangePointDetector(IHostEnvironment env, Arguments args, IDataView in } } - public SsaChangePointDetector(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - : base(env, ctx, LoaderSignature, input) + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(input, nameof(input)); + + return new SsaChangePointDetector(env, ctx).MakeDataTransform(input); + } + + // Factory method for SignatureLoadModel. + private static SsaChangePointDetector Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + return new SsaChangePointDetector(env, ctx); + } + + internal SsaChangePointDetector(IHostEnvironment env, ModelLoadContext ctx) + : base(env, ctx, LoaderSignature) { // *** Binary format *** // @@ -141,5 +188,86 @@ public override void Save(ModelSaveContext ctx) base.Save(ctx); } + + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + } + + /// + /// Estimator for + /// + public sealed class SsaChangePointEstimator : IEstimator + { + private readonly IHost _host; + private readonly SsaChangePointDetector.Arguments _args; + + /// + /// Create a new instance of + /// + /// Host Environment. + /// Name of the input column. + /// The name of the new column. + /// The confidence for change point detection in the range [0, 100]. + /// The change history length. + /// The change history length. + /// The change history length. + /// The function used to compute the error between the expected and the observed value. + /// The martingale used for scoring. + /// The epsilon parameter for the Power martingale. + public SsaChangePointEstimator(IHostEnvironment env, string inputColumn, string outputColumn, + int confidence, int changeHistoryLength, int trainingWindowSize, int seasonalityWindowSize, + ErrorFunctionUtils.ErrorFunction errorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference, + MartingaleType martingale = MartingaleType.Power, double eps = 0.1) + : this(env, new SsaChangePointDetector.Arguments + { + Name = outputColumn, + Source = inputColumn, + Confidence = confidence, + ChangeHistoryLength = changeHistoryLength, + TrainingWindowSize = trainingWindowSize, + SeasonalWindowSize = seasonalityWindowSize, + Martingale = martingale, + PowerMartingaleEpsilon = eps, + ErrorFunction = errorFunction + }) + { + } + + public SsaChangePointEstimator(IHostEnvironment env, SsaChangePointDetector.Arguments args) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(SsaChangePointEstimator)); + + _host.CheckNonEmpty(args.Name, nameof(args.Name)); + _host.CheckNonEmpty(args.Source, nameof(args.Source)); + + _args = args; + } + + public SsaChangePointDetector Fit(IDataView input) + { + _host.CheckValue(input, nameof(input)); + return new SsaChangePointDetector(_host, _args, input); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + + if (!inputSchema.TryFindColumn(_args.Source, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source); + if (col.ItemType != NumberType.R4) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source, NumberType.R4.ToString(), col.GetTypeString()); + + var metadata = new List() { + new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) + }; + var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + resultDic[_args.Name] = new SchemaShape.Column( + _args.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); + + return new SchemaShape(resultDic.Values); + } } } diff --git a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs index 79d6622ff5..44238e5a71 100644 --- a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs @@ -2,18 +2,29 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TimeSeriesProcessing; +using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; -[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(SsaSpikeDetector), typeof(SsaSpikeDetector.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(IDataTransform), typeof(SsaSpikeDetector), typeof(SsaSpikeDetector.Arguments), typeof(SignatureDataTransform), SsaSpikeDetector.UserName, SsaSpikeDetector.LoaderSignature, SsaSpikeDetector.ShortName)] -[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(SsaSpikeDetector), null, typeof(SignatureLoadDataTransform), + +[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(IDataTransform), typeof(SsaSpikeDetector), null, typeof(SignatureLoadDataTransform), + SsaSpikeDetector.UserName, SsaSpikeDetector.LoaderSignature)] + +[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(SsaSpikeDetector), null, typeof(SignatureLoadModel), SsaSpikeDetector.UserName, SsaSpikeDetector.LoaderSignature)] +[assembly: LoadableClass(typeof(IRowMapper), typeof(SsaSpikeDetector), null, typeof(SignatureLoadRowMapper), + SsaSpikeDetector.UserName, SsaSpikeDetector.LoaderSignature)] + namespace Microsoft.ML.Runtime.TimeSeriesProcessing { /// @@ -90,14 +101,50 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(SsaSpikeDetector).Assembly.FullName); } - public SsaSpikeDetector(IHostEnvironment env, Arguments args, IDataView input) - : base(new BaseArguments(args), LoaderSignature, env, input) + internal SsaSpikeDetector(IHostEnvironment env, Arguments args, IDataView input) + : base(new BaseArguments(args), LoaderSignature, env) + { + Model.Train(new RoleMappedData(input, null, InputColumnName)); + } + + // Factory method for SignatureDataTransform. + private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + return new SsaSpikeDetector(env, args, input).MakeDataTransform(input); + } + + internal SsaSpikeDetector(IHostEnvironment env, Arguments args) + : base(new BaseArguments(args), LoaderSignature, env) { // This constructor is empty. } - public SsaSpikeDetector(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - : base(env, ctx, LoaderSignature, input) + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(input, nameof(input)); + + return new SsaSpikeDetector(env, ctx).MakeDataTransform(input); + } + + // Factory method for SignatureLoadModel. + private static SsaSpikeDetector Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + return new SsaSpikeDetector(env, ctx); + } + + internal SsaSpikeDetector(IHostEnvironment env, ModelLoadContext ctx) + : base(env, ctx, LoaderSignature) { // *** Binary format *** // @@ -122,5 +169,83 @@ public override void Save(ModelSaveContext ctx) base.Save(ctx); } + + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + } + + /// + /// Estimator for + /// + public sealed class SsaSpikeEstimator : IEstimator + { + private readonly IHost _host; + private readonly SsaSpikeDetector.Arguments _args; + + /// + /// Create a new instance of + /// + /// Host Environment. + /// Name of the input column. + /// The name of the new column. + /// The confidence for spike detection in the range [0, 100]. + /// The size of the sliding window for computing the p-value. + /// The change history length. + /// The change history length. + /// The argument that determines whether to detect positive or negative anomalies, or both. + /// The function used to compute the error between the expected and the observed value. + public SsaSpikeEstimator(IHostEnvironment env, string inputColumn, string outputColumn, int confidence, + int pvalueHistoryLength, int trainingWindowSize, int seasonalityWindowSize, AnomalySide side = AnomalySide.TwoSided, + ErrorFunctionUtils.ErrorFunction errorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference) + : this(env, new SsaSpikeDetector.Arguments + { + Name = outputColumn, + Source = inputColumn, + Confidence = confidence, + PvalueHistoryLength = pvalueHistoryLength, + TrainingWindowSize = trainingWindowSize, + SeasonalWindowSize = seasonalityWindowSize, + Side = side, + ErrorFunction = errorFunction + }) + { + } + + public SsaSpikeEstimator(IHostEnvironment env, SsaSpikeDetector.Arguments args) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(SsaSpikeEstimator)); + + _host.CheckNonEmpty(args.Name, nameof(args.Name)); + _host.CheckNonEmpty(args.Source, nameof(args.Source)); + + _args = args; + } + + public SsaSpikeDetector Fit(IDataView input) + { + _host.CheckValue(input, nameof(input)); + return new SsaSpikeDetector(_host, _args, input); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + + if (!inputSchema.TryFindColumn(_args.Source, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source); + if (col.ItemType != NumberType.R4) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _args.Source, NumberType.R4.ToString(), col.GetTypeString()); + + var metadata = new List() { + new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) + }; + var resultDic = inputSchema.Columns.ToDictionary(x => x.Name); + resultDic[_args.Name] = new SchemaShape.Column( + _args.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R8, false, new SchemaShape(metadata)); + + return new SchemaShape(resultDic.Values); + } } } diff --git a/src/Microsoft.ML.TimeSeries/TimeSeriesProcessing.cs b/src/Microsoft.ML.TimeSeries/TimeSeriesProcessing.cs index 6b4275ffa3..65c7346d55 100644 --- a/src/Microsoft.ML.TimeSeries/TimeSeriesProcessing.cs +++ b/src/Microsoft.ML.TimeSeries/TimeSeriesProcessing.cs @@ -30,7 +30,7 @@ public static CommonOutputs.TransformOutput ExponentialAverage(IHostEnvironment public static CommonOutputs.TransformOutput IidChangePointDetector(IHostEnvironment env, IidChangePointDetector.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "IidChangePointDetector", input); - var view = new IidChangePointDetector(h, input, input.Data); + var view = new IidChangePointEstimator(h, input).Fit(input.Data).Transform(input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, view, input.Data), @@ -42,7 +42,7 @@ public static CommonOutputs.TransformOutput IidChangePointDetector(IHostEnvironm public static CommonOutputs.TransformOutput IidSpikeDetector(IHostEnvironment env, IidSpikeDetector.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "IidSpikeDetector", input); - var view = new IidSpikeDetector(h, input, input.Data); + var view = new IidSpikeEstimator(h, input).Fit(input.Data).Transform(input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, view, input.Data), @@ -90,7 +90,7 @@ public static CommonOutputs.TransformOutput SlidingWindowTransform(IHostEnvironm public static CommonOutputs.TransformOutput SsaChangePointDetector(IHostEnvironment env, SsaChangePointDetector.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "SsaChangePointDetector", input); - var view = new SsaChangePointDetector(h, input, input.Data); + var view = new SsaChangePointEstimator(h, input).Fit(input.Data).Transform(input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, view, input.Data), @@ -102,7 +102,7 @@ public static CommonOutputs.TransformOutput SsaChangePointDetector(IHostEnvironm public static CommonOutputs.TransformOutput SsaSpikeDetector(IHostEnvironment env, SsaSpikeDetector.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "SsaSpikeDetector", input); - var view = new SsaSpikeDetector(h, input, input.Data); + var view = new SsaSpikeEstimator(h, input).Fit(input.Data).Transform(input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, view, input.Data), diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeIidChangePoint-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeIidChangePoint-Schema.txt index ff1d284359..dd8d7959bf 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeIidChangePoint-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeIidChangePoint-Schema.txt @@ -1,7 +1,7 @@ ---- BoundLoader ---- 1 columns: Features: R4 ----- IidChangePointDetector ---- +---- SequentialDataTransform ---- 2 columns: Features: R4 Anomaly: Vec @@ -16,7 +16,7 @@ fAnomaly: Vec Metadata 'SlotNames': Vec: Length=4, Count=4 [0] 'Alert', [1] 'Raw Score', [2] 'P-Value Score', [3] 'Martingale Score' ----- IidChangePointDetector ---- +---- SequentialDataTransform ---- 4 columns: Features: R4 Anomaly: Vec diff --git a/test/BaselineOutput/Common/SavePipe/SavePipeIidSpike-Schema.txt b/test/BaselineOutput/Common/SavePipe/SavePipeIidSpike-Schema.txt index 25d32dd015..659c20133c 100644 --- a/test/BaselineOutput/Common/SavePipe/SavePipeIidSpike-Schema.txt +++ b/test/BaselineOutput/Common/SavePipe/SavePipeIidSpike-Schema.txt @@ -1,7 +1,7 @@ ---- BoundLoader ---- 1 columns: Features: R4 ----- IidSpikeDetector ---- +---- SequentialDataTransform ---- 2 columns: Features: R4 Anomaly: Vec @@ -16,7 +16,7 @@ fAnomaly: Vec Metadata 'SlotNames': Vec: Length=3, Count=3 [0] 'Alert', [1] 'Raw Score', [2] 'P-Value Score' ----- IidSpikeDetector ---- +---- SequentialDataTransform ---- 4 columns: Features: R4 Anomaly: Vec diff --git a/test/BaselineOutput/SingleDebug/SavePipe/SavePipeSsaSpikeNoData-Schema.txt b/test/BaselineOutput/SingleDebug/SavePipe/SavePipeSsaSpikeNoData-Schema.txt index 764b0f98c0..f0dcf9138b 100644 --- a/test/BaselineOutput/SingleDebug/SavePipe/SavePipeSsaSpikeNoData-Schema.txt +++ b/test/BaselineOutput/SingleDebug/SavePipe/SavePipeSsaSpikeNoData-Schema.txt @@ -1,7 +1,7 @@ ---- BoundLoader ---- 1 columns: Features: R4 ----- SsaSpikeDetector ---- +---- SequentialDataTransform ---- 2 columns: Features: R4 Anomaly: Vec diff --git a/test/BaselineOutput/SingleRelease/SavePipe/SavePipeSsaSpikeNoData-Schema.txt b/test/BaselineOutput/SingleRelease/SavePipe/SavePipeSsaSpikeNoData-Schema.txt index 764b0f98c0..f0dcf9138b 100644 --- a/test/BaselineOutput/SingleRelease/SavePipe/SavePipeSsaSpikeNoData-Schema.txt +++ b/test/BaselineOutput/SingleRelease/SavePipe/SavePipeSsaSpikeNoData-Schema.txt @@ -1,7 +1,7 @@ ---- BoundLoader ---- 1 columns: Features: R4 ----- SsaSpikeDetector ---- +---- SequentialDataTransform ---- 2 columns: Features: R4 Anomaly: Vec diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs index b0abe45a8a..125105138c 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs @@ -52,12 +52,13 @@ public void ChangeDetection() Confidence = 80, Source = "Value", Name = "Change", - ChangeHistoryLength = size, - Data = dataView + ChangeHistoryLength = size }; - - var detector = TimeSeriesProcessing.IidChangePointDetector(env, args); - var output = detector.Model.Apply(env, dataView); + // Train + var detector = new IidChangePointEstimator(env, args).Fit(dataView); + // Transform + var output = detector.Transform(dataView); + // Get predictions var enumerator = output.AsEnumerable(env, true).GetEnumerator(); Prediction row = null; List expectedValues = new List() { 0, 5, 0.5, 5.1200000000000114E-08, 0, 5, 0.4999999995, 5.1200000046080209E-08, 0, 5, 0.4999999995, 5.1200000092160303E-08, @@ -80,8 +81,8 @@ public void ChangePointDetectionWithSeasonality() { using (var env = new ConsoleEnvironment(conc: 1)) { - const int ChangeHistorySize = 2000; - const int SeasonalitySize = 1000; + const int ChangeHistorySize = 10; + const int SeasonalitySize = 10; const int NumberOfSeasonsInTraining = 5; const int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize; @@ -94,7 +95,6 @@ public void ChangePointDetectionWithSeasonality() Source = "Value", Name = "Change", ChangeHistoryLength = ChangeHistorySize, - Data = dataView, TrainingWindowSize = MaxTrainingSize, SeasonalWindowSize = SeasonalitySize }; @@ -106,21 +106,24 @@ public void ChangePointDetectionWithSeasonality() for (int i = 0; i < ChangeHistorySize; i++) data.Add(new Data(i * 100)); - var detector = TimeSeriesProcessing.SsaChangePointDetector(env, args); - var output = detector.Model.Apply(env, dataView); + // Train + var detector = new SsaChangePointEstimator(env, args).Fit(dataView); + // Transform + var output = detector.Transform(dataView); + // Get predictions var enumerator = output.AsEnumerable(env, true).GetEnumerator(); Prediction row = null; - List expectedValues = new List() { 0, 0, 0.5, 0, 0, 1, 0.15865526383236372, - 0, 0, 1.6069464981555939, 0.05652458872960725, 0, 0, 2.0183047652244568, 0.11021633531076747, 0}; + List expectedValues = new List() { 0, -3.31410598754883, 0.5, 5.12000000000001E-08, 0, 1.5700820684432983, 5.2001145245395008E-07, + 0.012414560443710681, 0, 1.2854313254356384, 0.28810801662678009, 0.02038940454467935, 0, -1.0950627326965332, 0.36663890634019225, 0.026956459625565483}; int index = 0; while (enumerator.MoveNext() && index < expectedValues.Count) { row = enumerator.Current; - Assert.Equal(expectedValues[index++], row.Change[0], precision: 7); - Assert.Equal(expectedValues[index++], row.Change[1], precision: 7); - Assert.Equal(expectedValues[index++], row.Change[2], precision: 7); - Assert.Equal(expectedValues[index++], row.Change[3], precision: 7); + Assert.Equal(expectedValues[index++], row.Change[0], precision: 7); // Alert + Assert.Equal(expectedValues[index++], row.Change[1], precision: 7); // Raw score + Assert.Equal(expectedValues[index++], row.Change[2], precision: 7); // P-Value score + Assert.Equal(expectedValues[index++], row.Change[3], precision: 7); // Martingale score } } } diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs new file mode 100644 index 0000000000..3d7cfd5d32 --- /dev/null +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs @@ -0,0 +1,166 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.TimeSeriesProcessing; +using System.Collections.Generic; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests +{ + public class TimeSeriesEstimatorTests : TestDataPipeBase + { + private const int inputSize = 150528; + + private class Data + { + public float Value; + + public Data(float value) + { + Value = value; + } + } + + private class TestDataXY + { + [VectorType(inputSize)] + public float[] A; + } + private class TestDataDifferntType + { + [VectorType(inputSize)] + public string[] data_0; + } + + public TimeSeriesEstimatorTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + void TestSsaChangePointEstimator() + { + int Confidence = 95; + int ChangeHistorySize = 10; + int SeasonalitySize = 10; + int NumberOfSeasonsInTraining = 5; + int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize; + + List data = new List(); + var dataView = Env.CreateStreamingDataView(data); + + for (int j = 0; j < NumberOfSeasonsInTraining; j++) + for (int i = 0; i < SeasonalitySize; i++) + data.Add(new Data(i)); + + for (int i = 0; i < ChangeHistorySize; i++) + data.Add(new Data(i * 100)); + + var pipe = new SsaChangePointEstimator(Env, "Value", "Change", + Confidence, ChangeHistorySize, MaxTrainingSize, SeasonalitySize); + + var xyData = new List { new TestDataXY() { A = new float[inputSize] } }; + var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } }; + + var invalidDataWrongNames = ComponentCreation.CreateDataView(Env, xyData); + var invalidDataWrongTypes = ComponentCreation.CreateDataView(Env, stringData); + + TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongTypes); + TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongNames); + + Done(); + } + + [Fact] + void TestSsaSpikeEstimator() + { + int Confidence = 95; + int PValueHistorySize = 10; + int SeasonalitySize = 10; + int NumberOfSeasonsInTraining = 5; + int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize; + + List data = new List(); + var dataView = Env.CreateStreamingDataView(data); + + for (int j = 0; j < NumberOfSeasonsInTraining; j++) + for (int i = 0; i < SeasonalitySize; i++) + data.Add(new Data(i)); + + for (int i = 0; i < PValueHistorySize; i++) + data.Add(new Data(i * 100)); + + var pipe = new SsaSpikeEstimator(Env, "Value", "Change", + Confidence, PValueHistorySize, MaxTrainingSize, SeasonalitySize); + + var xyData = new List { new TestDataXY() { A = new float[inputSize] } }; + var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } }; + + var invalidDataWrongNames = ComponentCreation.CreateDataView(Env, xyData); + var invalidDataWrongTypes = ComponentCreation.CreateDataView(Env, stringData); + + TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongTypes); + TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongNames); + + Done(); + } + + [Fact] + void TestIidChangePointEstimator() + { + int Confidence = 95; + int ChangeHistorySize = 10; + + List data = new List(); + var dataView = Env.CreateStreamingDataView(data); + + for (int i = 0; i < ChangeHistorySize; i++) + data.Add(new Data(i * 100)); + + var pipe = new IidChangePointEstimator(Env, + "Value", "Change", Confidence, ChangeHistorySize); + + var xyData = new List { new TestDataXY() { A = new float[inputSize] } }; + var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } }; + + var invalidDataWrongNames = ComponentCreation.CreateDataView(Env, xyData); + var invalidDataWrongTypes = ComponentCreation.CreateDataView(Env, stringData); + + TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongTypes); + TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongNames); + + Done(); + } + + [Fact] + void TestIidSpikeEstimator() + { + int Confidence = 95; + int PValueHistorySize = 10; + + List data = new List(); + var dataView = Env.CreateStreamingDataView(data); + + for (int i = 0; i < PValueHistorySize; i++) + data.Add(new Data(i * 100)); + + var pipe = new IidSpikeEstimator(Env, + "Value", "Change", Confidence, PValueHistorySize); + + var xyData = new List { new TestDataXY() { A = new float[inputSize] } }; + var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } }; + + var invalidDataWrongNames = ComponentCreation.CreateDataView(Env, xyData); + var invalidDataWrongTypes = ComponentCreation.CreateDataView(Env, stringData); + + TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongTypes); + TestEstimatorCore(pipe, dataView, invalidInput: invalidDataWrongNames); + + Done(); + } + } +}