3
3
// See the LICENSE file in the project root for more information.
4
4
5
5
using System ;
6
+ using System . Collections . Generic ;
7
+ using System . Linq ;
8
+ using Microsoft . ML . Core . Data ;
6
9
using Microsoft . ML . Runtime ;
7
10
using Microsoft . ML . Runtime . CommandLine ;
8
11
using Microsoft . ML . Runtime . Data ;
9
12
using Microsoft . ML . Runtime . EntryPoints ;
10
13
using Microsoft . ML . Runtime . Model ;
11
14
using Microsoft . ML . Runtime . TimeSeriesProcessing ;
15
+ using static Microsoft . ML . Runtime . TimeSeriesProcessing . SequentialAnomalyDetectionTransformBase < System . Single , Microsoft . ML . Runtime . TimeSeriesProcessing . IidAnomalyDetectionBase . State > ;
12
16
13
- [ assembly: LoadableClass ( IidChangePointDetector . Summary , typeof ( IidChangePointDetector ) , typeof ( IidChangePointDetector . Arguments ) , typeof ( SignatureDataTransform ) ,
17
+ [ assembly: LoadableClass ( IidChangePointDetector . Summary , typeof ( IDataTransform ) , typeof ( IidChangePointDetector ) , typeof ( IidChangePointDetector . Arguments ) , typeof ( SignatureDataTransform ) ,
14
18
IidChangePointDetector . UserName , IidChangePointDetector . LoaderSignature , IidChangePointDetector . ShortName ) ]
15
- [ assembly: LoadableClass ( IidChangePointDetector . Summary , typeof ( IidChangePointDetector ) , null , typeof ( SignatureLoadDataTransform ) ,
19
+
20
+ [ assembly: LoadableClass ( IidChangePointDetector . Summary , typeof ( IDataTransform ) , typeof ( IidChangePointDetector ) , null , typeof ( SignatureLoadDataTransform ) ,
21
+ IidChangePointDetector . UserName , IidChangePointDetector . LoaderSignature ) ]
22
+
23
+ [ assembly: LoadableClass ( IidChangePointDetector . Summary , typeof ( IidChangePointDetector ) , null , typeof ( SignatureLoadModel ) ,
16
24
IidChangePointDetector . UserName , IidChangePointDetector . LoaderSignature ) ]
17
25
26
+ [ assembly: LoadableClass ( typeof ( IRowMapper ) , typeof ( IidChangePointDetector ) , null , typeof ( SignatureLoadRowMapper ) ,
27
+ IidChangePointDetector . UserName , IidChangePointDetector . LoaderSignature ) ]
28
+
18
29
namespace Microsoft . ML . Runtime . TimeSeriesProcessing
19
30
{
20
31
/// <summary>
21
32
/// This class implements the change point detector transform for an i.i.d. sequence based on adaptive kernel density estimation and martingales.
22
33
/// </summary>
23
- public sealed class IidChangePointDetector : IidAnomalyDetectionBase , ITransformTemplate
34
+ public sealed class IidChangePointDetector : IidAnomalyDetectionBase
24
35
{
25
36
internal const string Summary = "This transform detects the change-points in an i.i.d. sequence using adaptive kernel density estimation and martingales." ;
26
37
public const string LoaderSignature = "IidChangePointDetector" ;
@@ -89,8 +100,18 @@ private static VersionInfo GetVersionInfo()
89
100
loaderAssemblyName : typeof ( IidChangePointDetector ) . Assembly . FullName ) ;
90
101
}
91
102
92
- public IidChangePointDetector ( IHostEnvironment env , Arguments args , IDataView input )
93
- : base ( new BaseArguments ( args ) , LoaderSignature , env , input )
103
+ // Factory method for SignatureDataTransform.
104
+ private static IDataTransform Create ( IHostEnvironment env , Arguments args , IDataView input )
105
+ {
106
+ Contracts . CheckValue ( env , nameof ( env ) ) ;
107
+ env . CheckValue ( args , nameof ( args ) ) ;
108
+ env . CheckValue ( input , nameof ( input ) ) ;
109
+
110
+ return new IidChangePointDetector ( env , args ) . MakeDataTransform ( input ) ;
111
+ }
112
+
113
+ internal IidChangePointDetector ( IHostEnvironment env , Arguments args )
114
+ : base ( new BaseArguments ( args ) , LoaderSignature , env )
94
115
{
95
116
switch ( Martingale )
96
117
{
@@ -109,8 +130,28 @@ public IidChangePointDetector(IHostEnvironment env, Arguments args, IDataView in
109
130
}
110
131
}
111
132
112
- public IidChangePointDetector ( IHostEnvironment env , ModelLoadContext ctx , IDataView input )
113
- : base ( env , ctx , LoaderSignature , input )
133
+ // Factory method for SignatureLoadDataTransform.
134
+ private static IDataTransform Create ( IHostEnvironment env , ModelLoadContext ctx , IDataView input )
135
+ {
136
+ Contracts . CheckValue ( env , nameof ( env ) ) ;
137
+ env . CheckValue ( ctx , nameof ( ctx ) ) ;
138
+ env . CheckValue ( input , nameof ( input ) ) ;
139
+
140
+ return new IidChangePointDetector ( env , ctx ) . MakeDataTransform ( input ) ;
141
+ }
142
+
143
+ // Factory method for SignatureLoadModel.
144
+ private static IidChangePointDetector Create ( IHostEnvironment env , ModelLoadContext ctx )
145
+ {
146
+ Contracts . CheckValue ( env , nameof ( env ) ) ;
147
+ env . CheckValue ( ctx , nameof ( ctx ) ) ;
148
+ ctx . CheckAtModel ( GetVersionInfo ( ) ) ;
149
+
150
+ return new IidChangePointDetector ( env , ctx ) ;
151
+ }
152
+
153
+ internal IidChangePointDetector ( IHostEnvironment env , ModelLoadContext ctx )
154
+ : base ( env , ctx , LoaderSignature )
114
155
{
115
156
// *** Binary format ***
116
157
// <base>
@@ -119,8 +160,8 @@ public IidChangePointDetector(IHostEnvironment env, ModelLoadContext ctx, IDataV
119
160
Host . CheckDecode ( Side == AnomalySide . TwoSided ) ;
120
161
}
121
162
122
- private IidChangePointDetector ( IHostEnvironment env , IidChangePointDetector transform , IDataView newSource )
123
- : base ( new BaseArguments ( transform ) , LoaderSignature , env , newSource )
163
+ private IidChangePointDetector ( IHostEnvironment env , IidChangePointDetector transform )
164
+ : base ( new BaseArguments ( transform ) , LoaderSignature , env )
124
165
{
125
166
}
126
167
@@ -139,9 +180,65 @@ public override void Save(ModelSaveContext ctx)
139
180
base . Save ( ctx ) ;
140
181
}
141
182
142
- public IDataTransform ApplyToData ( IHostEnvironment env , IDataView newSource )
183
+ // Factory method for SignatureLoadRowMapper.
184
+ private static IRowMapper Create ( IHostEnvironment env , ModelLoadContext ctx , ISchema inputSchema )
185
+ => Create ( env , ctx ) . MakeRowMapper ( inputSchema ) ;
186
+ }
187
+
188
+ /// <summary>
189
+ /// Estimator for <see cref="IidChangePointDetector"/>
190
+ /// </summary>
191
+ public sealed class IidChangePointEstimator : TrivialEstimator < IidChangePointDetector >
192
+ {
193
+ /// <summary>
194
+ /// Create a new instance of <see cref="IidChangePointEstimator"/>
195
+ /// </summary>
196
+ /// <param name="env">Host Environment.</param>
197
+ /// <param name="inputColumn">Name of the input column.</param>
198
+ /// <param name="outputColumn">The name of the new column.</param>
199
+ /// <param name="confidence">The confidence for change point detection in the range [0, 100].</param>
200
+ /// <param name="changeHistoryLength">The change history length.</param>
201
+ /// <param name="martingale">The martingale used for scoring.</param>
202
+ /// <param name="eps">The epsilon parameter for the Power martingale.</param>
203
+ public IidChangePointEstimator ( IHostEnvironment env , string inputColumn , string outputColumn , int confidence ,
204
+ int changeHistoryLength , MartingaleType martingale = MartingaleType . Power , double eps = 0.1 )
205
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( IidChangePointEstimator ) ) ,
206
+ new IidChangePointDetector ( env , new IidChangePointDetector . Arguments
207
+ {
208
+ Name = outputColumn ,
209
+ Source = inputColumn ,
210
+ Confidence = confidence ,
211
+ ChangeHistoryLength = changeHistoryLength ,
212
+ Martingale = martingale ,
213
+ PowerMartingaleEpsilon = eps
214
+ } ) )
215
+ {
216
+ }
217
+
218
+ public IidChangePointEstimator ( IHostEnvironment env , IidChangePointDetector . Arguments args )
219
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( IidChangePointEstimator ) ) ,
220
+ new IidChangePointDetector ( env , args ) )
143
221
{
144
- return new IidChangePointDetector ( env , this , newSource ) ;
222
+ }
223
+
224
+ public override SchemaShape GetOutputSchema ( SchemaShape inputSchema )
225
+ {
226
+ Host . CheckValue ( inputSchema , nameof ( inputSchema ) ) ;
227
+
228
+ if ( ! inputSchema . TryFindColumn ( Transformer . InputColumnName , out var col ) )
229
+ throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , Transformer . InputColumnName ) ;
230
+ if ( col . ItemType != NumberType . R4 )
231
+ throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , Transformer . InputColumnName , NumberType . R4 . ToString ( ) , col . GetTypeString ( ) ) ;
232
+
233
+ var metadata = new List < SchemaShape . Column > ( ) {
234
+ new SchemaShape . Column ( MetadataUtils . Kinds . SlotNames , SchemaShape . Column . VectorKind . Vector , TextType . Instance , false )
235
+ } ;
236
+ var resultDic = inputSchema . Columns . ToDictionary ( x => x . Name ) ;
237
+
238
+ resultDic [ Transformer . OutputColumnName ] = new SchemaShape . Column (
239
+ Transformer . OutputColumnName , SchemaShape . Column . VectorKind . Vector , NumberType . R8 , false , new SchemaShape ( metadata ) ) ;
240
+
241
+ return new SchemaShape ( resultDic . Values ) ;
145
242
}
146
243
}
147
244
}
0 commit comments