16
16
using Microsoft . ML . Internal . Utilities ;
17
17
using Microsoft . ML . Model ;
18
18
19
- [ assembly: LoadableClass ( typeof ( RankerEvaluator ) , typeof ( RankerEvaluator ) , typeof ( RankerEvaluator . Arguments ) , typeof ( SignatureEvaluator ) ,
20
- "Ranking Evaluator" , RankerEvaluator . LoadName , "Ranking" , "rank" ) ]
19
+ [ assembly: LoadableClass ( typeof ( RankingEvaluator ) , typeof ( RankingEvaluator ) , typeof ( RankingEvaluator . Arguments ) , typeof ( SignatureEvaluator ) ,
20
+ "Ranking Evaluator" , RankingEvaluator . LoadName , "Ranking" , "rank" ) ]
21
21
22
- [ assembly: LoadableClass ( typeof ( RankerMamlEvaluator ) , typeof ( RankerMamlEvaluator ) , typeof ( RankerMamlEvaluator . Arguments ) , typeof ( SignatureMamlEvaluator ) ,
23
- "Ranking Evaluator" , RankerEvaluator . LoadName , "Ranking" , "rank" ) ]
22
+ [ assembly: LoadableClass ( typeof ( RankingMamlEvaluator ) , typeof ( RankingMamlEvaluator ) , typeof ( RankingMamlEvaluator . Arguments ) , typeof ( SignatureMamlEvaluator ) ,
23
+ "Ranking Evaluator" , RankingEvaluator . LoadName , "Ranking" , "rank" ) ]
24
24
25
- [ assembly: LoadableClass ( typeof ( RankerPerInstanceTransform ) , null , typeof ( SignatureLoadDataTransform ) ,
26
- "" , RankerPerInstanceTransform . LoaderSignature ) ]
25
+ [ assembly: LoadableClass ( typeof ( RankingPerInstanceTransform ) , null , typeof ( SignatureLoadDataTransform ) ,
26
+ "" , RankingPerInstanceTransform . LoaderSignature ) ]
27
27
28
28
namespace Microsoft . ML . Data
29
29
{
30
30
[ BestFriend ]
31
- internal sealed class RankerEvaluator : EvaluatorBase < RankerEvaluator . Aggregator >
31
+ internal sealed class RankingEvaluator : EvaluatorBase < RankingEvaluator . Aggregator >
32
32
{
33
33
public sealed class Arguments
34
34
{
@@ -61,7 +61,7 @@ public sealed class Arguments
61
61
private readonly bool _groupSummary ;
62
62
private readonly Double [ ] _labelGains ;
63
63
64
- public RankerEvaluator ( IHostEnvironment env , Arguments args )
64
+ public RankingEvaluator ( IHostEnvironment env , Arguments args )
65
65
: base ( env , LoadName )
66
66
{
67
67
// REVIEW: What kind of checking should be applied to labelGains?
@@ -89,13 +89,13 @@ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
89
89
var t = schema . Label . Value . Type ;
90
90
if ( t != NumberDataViewType . Single && ! ( t is KeyType ) )
91
91
{
92
- throw Host . ExceptSchemaMismatch ( nameof ( RankerMamlEvaluator . Arguments . LabelColumn ) ,
92
+ throw Host . ExceptSchemaMismatch ( nameof ( RankingMamlEvaluator . Arguments . LabelColumn ) ,
93
93
"label" , schema . Label . Value . Name , "R4 or a key" , t . ToString ( ) ) ;
94
94
}
95
95
var scoreCol = schema . GetUniqueColumn ( MetadataUtils . Const . ScoreValueKind . Score ) ;
96
96
if ( scoreCol . Type != NumberDataViewType . Single )
97
97
{
98
- throw Host . ExceptSchemaMismatch ( nameof ( RankerMamlEvaluator . Arguments . ScoreColumn ) ,
98
+ throw Host . ExceptSchemaMismatch ( nameof ( RankingMamlEvaluator . Arguments . ScoreColumn ) ,
99
99
"score" , scoreCol . Name , "R4" , t . ToString ( ) ) ;
100
100
}
101
101
}
@@ -105,7 +105,7 @@ private protected override void CheckCustomColumnTypesCore(RoleMappedSchema sche
105
105
var t = schema . Group . Value . Type ;
106
106
if ( ! ( t is KeyType ) )
107
107
{
108
- throw Host . ExceptSchemaMismatch ( nameof ( RankerMamlEvaluator . Arguments . GroupIdColumn ) ,
108
+ throw Host . ExceptSchemaMismatch ( nameof ( RankingMamlEvaluator . Arguments . GroupIdColumn ) ,
109
109
"group" , schema . Group . Value . Name , "key" , t . ToString ( ) ) ;
110
110
}
111
111
}
@@ -129,7 +129,7 @@ internal override IDataTransform GetPerInstanceMetricsCore(RoleMappedData data)
129
129
var scoreInfo = data . Schema . GetUniqueColumn ( MetadataUtils . Const . ScoreValueKind . Score ) ;
130
130
Host . CheckParam ( data . Schema . Group . HasValue , nameof ( data ) , "Schema must contain a group column" ) ;
131
131
132
- return new RankerPerInstanceTransform ( Host , data . Data ,
132
+ return new RankingPerInstanceTransform ( Host , data . Data ,
133
133
data . Schema . Label . Value . Name , scoreInfo . Name , data . Schema . Group . Value . Name , _truncationLevel , _labelGains ) ;
134
134
}
135
135
@@ -242,7 +242,7 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
242
242
/// <param name="groupId">The name of the groupId column.</param>
243
243
/// <param name="score">The name of the predicted score column.</param>
244
244
/// <returns>The evaluation metrics for these outputs.</returns>
245
- public RankerMetrics Evaluate ( IDataView data , string label , string groupId , string score )
245
+ public RankingMetrics Evaluate ( IDataView data , string label , string groupId , string score )
246
246
{
247
247
Host . CheckValue ( data , nameof ( data ) ) ;
248
248
Host . CheckNonEmpty ( label , nameof ( label ) ) ;
@@ -256,12 +256,12 @@ public RankerMetrics Evaluate(IDataView data, string label, string groupId, stri
256
256
Host . Assert ( resultDict . ContainsKey ( MetricKinds . OverallMetrics ) ) ;
257
257
var overall = resultDict [ MetricKinds . OverallMetrics ] ;
258
258
259
- RankerMetrics result ;
259
+ RankingMetrics result ;
260
260
using ( var cursor = overall . GetRowCursorForAllColumns ( ) )
261
261
{
262
262
var moved = cursor . MoveNext ( ) ;
263
263
Host . Assert ( moved ) ;
264
- result = new RankerMetrics ( Host , cursor ) ;
264
+ result = new RankingMetrics ( Host , cursor ) ;
265
265
moved = cursor . MoveNext ( ) ;
266
266
Host . Assert ( ! moved ) ;
267
267
}
@@ -374,15 +374,15 @@ public void Update(short label, Single output)
374
374
375
375
public void UpdateGroup ( Single weight )
376
376
{
377
- RankerUtils . QueryMaxDcg ( _labelGains , TruncationLevel , _queryLabels , _queryOutputs , _groupMaxDcgCur ) ;
377
+ RankingUtils . QueryMaxDcg ( _labelGains , TruncationLevel , _queryLabels , _queryOutputs , _groupMaxDcgCur ) ;
378
378
if ( _groupMaxDcg != null )
379
379
{
380
380
var maxDcg = new Double [ TruncationLevel ] ;
381
381
Array . Copy ( _groupMaxDcgCur , maxDcg , TruncationLevel ) ;
382
382
_groupMaxDcg . Add ( maxDcg ) ;
383
383
}
384
384
385
- RankerUtils . QueryDcg ( _labelGains , TruncationLevel , _queryLabels , _queryOutputs , _groupDcgCur ) ;
385
+ RankingUtils . QueryDcg ( _labelGains , TruncationLevel , _queryLabels , _queryOutputs , _groupDcgCur ) ;
386
386
if ( _groupDcg != null )
387
387
{
388
388
var groupDcg = new Double [ TruncationLevel ] ;
@@ -539,7 +539,7 @@ public void GetSlotNames(ref VBuffer<ReadOnlyMemory<char>> slotNames)
539
539
}
540
540
}
541
541
542
- internal sealed class RankerPerInstanceTransform : IDataTransform
542
+ internal sealed class RankingPerInstanceTransform : IDataTransform
543
543
{
544
544
public const string LoaderSignature = "RankerPerInstTransform" ;
545
545
private const string RegistrationName = LoaderSignature ;
@@ -552,7 +552,7 @@ private static VersionInfo GetVersionInfo()
552
552
verReadableCur : 0x00010001 ,
553
553
verWeCanReadBack : 0x00010001 ,
554
554
loaderSignature : LoaderSignature ,
555
- loaderAssemblyName : typeof ( RankerPerInstanceTransform ) . Assembly . FullName ) ;
555
+ loaderAssemblyName : typeof ( RankingPerInstanceTransform ) . Assembly . FullName ) ;
556
556
}
557
557
558
558
public const string Ndcg = "NDCG" ;
@@ -576,25 +576,25 @@ private static VersionInfo GetVersionInfo()
576
576
/// </summary>
577
577
public DataViewSchema OutputSchema => _transform . OutputSchema ;
578
578
579
- public RankerPerInstanceTransform ( IHostEnvironment env , IDataView input , string labelCol , string scoreCol , string groupCol ,
579
+ public RankingPerInstanceTransform ( IHostEnvironment env , IDataView input , string labelCol , string scoreCol , string groupCol ,
580
580
int truncationLevel , Double [ ] labelGains )
581
581
{
582
582
_transform = new Transform ( env , input , labelCol , scoreCol , groupCol , truncationLevel , labelGains ) ;
583
583
}
584
584
585
- private RankerPerInstanceTransform ( IHostEnvironment env , ModelLoadContext ctx , IDataView input )
585
+ private RankingPerInstanceTransform ( IHostEnvironment env , ModelLoadContext ctx , IDataView input )
586
586
{
587
587
_transform = new Transform ( env , ctx , input ) ;
588
588
}
589
589
590
- public static RankerPerInstanceTransform Create ( IHostEnvironment env , ModelLoadContext ctx , IDataView input )
590
+ public static RankingPerInstanceTransform Create ( IHostEnvironment env , ModelLoadContext ctx , IDataView input )
591
591
{
592
592
Contracts . CheckValue ( env , nameof ( env ) ) ;
593
593
var h = env . Register ( RegistrationName ) ;
594
594
h . CheckValue ( ctx , nameof ( ctx ) ) ;
595
595
ctx . CheckAtModel ( GetVersionInfo ( ) ) ;
596
596
h . CheckValue ( input , nameof ( input ) ) ;
597
- return h . Apply ( "Loading Model" , ch => new RankerPerInstanceTransform ( h , ctx , input ) ) ;
597
+ return h . Apply ( "Loading Model" , ch => new RankingPerInstanceTransform ( h , ctx , input ) ) ;
598
598
}
599
599
600
600
void ICanSaveModel . Save ( ModelSaveContext ctx )
@@ -801,9 +801,9 @@ protected override void ProcessExample(RowCursorState state, short label, Single
801
801
protected override void UpdateState ( RowCursorState state )
802
802
{
803
803
// Calculate the current group DCG, NDCG and MaxDcg.
804
- RankerUtils . QueryMaxDcg ( _labelGains , _truncationLevel , state . QueryLabels , state . QueryOutputs ,
804
+ RankingUtils . QueryMaxDcg ( _labelGains , _truncationLevel , state . QueryLabels , state . QueryOutputs ,
805
805
state . MaxDcgCur ) ;
806
- RankerUtils . QueryDcg ( _labelGains , _truncationLevel , state . QueryLabels , state . QueryOutputs , state . DcgCur ) ;
806
+ RankingUtils . QueryDcg ( _labelGains , _truncationLevel , state . QueryLabels , state . QueryOutputs , state . DcgCur ) ;
807
807
for ( int t = 0 ; t < _truncationLevel ; t ++ )
808
808
{
809
809
Double ndcg = state . MaxDcgCur [ t ] > 0 ? state . DcgCur [ t ] / state . MaxDcgCur [ t ] * 100 : 0 ;
@@ -838,7 +838,7 @@ public RowCursorState(int truncationLevel)
838
838
}
839
839
840
840
[ BestFriend ]
841
- internal sealed class RankerMamlEvaluator : MamlEvaluatorBase
841
+ internal sealed class RankingMamlEvaluator : MamlEvaluatorBase
842
842
{
843
843
public sealed class Arguments : ArgumentsBase
844
844
{
@@ -855,25 +855,25 @@ public sealed class Arguments : ArgumentsBase
855
855
public string GroupSummaryFilename ;
856
856
}
857
857
858
- private readonly RankerEvaluator _evaluator ;
858
+ private readonly RankingEvaluator _evaluator ;
859
859
private readonly string _groupIdCol ;
860
860
861
861
private readonly string _groupSummaryFilename ;
862
862
863
863
private protected override IEvaluator Evaluator => _evaluator ;
864
864
865
- public RankerMamlEvaluator ( IHostEnvironment env , Arguments args )
865
+ public RankingMamlEvaluator ( IHostEnvironment env , Arguments args )
866
866
: base ( args , env , MetadataUtils . Const . ScoreColumnKind . Ranking , "RankerMamlEvaluator" )
867
867
{
868
868
Host . CheckValue ( args , nameof ( args ) ) ;
869
869
Utils . CheckOptionalUserDirectory ( args . GroupSummaryFilename , nameof ( args . GroupSummaryFilename ) ) ;
870
870
871
- var evalArgs = new RankerEvaluator . Arguments ( ) ;
871
+ var evalArgs = new RankingEvaluator . Arguments ( ) ;
872
872
evalArgs . DcgTruncationLevel = args . DcgTruncationLevel ;
873
873
evalArgs . LabelGains = args . LabelGains ;
874
874
evalArgs . OutputGroupSummary = ! string . IsNullOrEmpty ( args . GroupSummaryFilename ) ;
875
875
876
- _evaluator = new RankerEvaluator ( Host , evalArgs ) ;
876
+ _evaluator = new RankingEvaluator ( Host , evalArgs ) ;
877
877
_groupSummaryFilename = args . GroupSummaryFilename ;
878
878
_groupIdCol = args . GroupIdColumn ;
879
879
}
@@ -908,14 +908,14 @@ private bool TryGetGroupSummaryMetrics(Dictionary<string, IDataView>[] metrics,
908
908
Host . AssertNonEmpty ( metrics ) ;
909
909
910
910
if ( metrics . Length == 1 )
911
- return metrics [ 0 ] . TryGetValue ( RankerEvaluator . GroupSummary , out gs ) ;
911
+ return metrics [ 0 ] . TryGetValue ( RankingEvaluator . GroupSummary , out gs ) ;
912
912
913
913
gs = null ;
914
914
var gsList = new List < IDataView > ( ) ;
915
915
for ( int i = 0 ; i < metrics . Length ; i ++ )
916
916
{
917
917
IDataView idv ;
918
- if ( ! metrics [ i ] . TryGetValue ( RankerEvaluator . GroupSummary , out idv ) )
918
+ if ( ! metrics [ i ] . TryGetValue ( RankingEvaluator . GroupSummary , out idv ) )
919
919
return false ;
920
920
921
921
idv = EvaluateUtils . AddFoldIndex ( Host , idv , i , metrics . Length ) ;
@@ -939,13 +939,13 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
939
939
yield return scoreCol . Name ;
940
940
941
941
// Return the output columns.
942
- yield return RankerPerInstanceTransform . Ndcg ;
943
- yield return RankerPerInstanceTransform . Dcg ;
944
- yield return RankerPerInstanceTransform . MaxDcg ;
942
+ yield return RankingPerInstanceTransform . Ndcg ;
943
+ yield return RankingPerInstanceTransform . Dcg ;
944
+ yield return RankingPerInstanceTransform . MaxDcg ;
945
945
}
946
946
}
947
947
948
- internal static class RankerUtils
948
+ internal static class RankingUtils
949
949
{
950
950
private static volatile Double [ ] _discountMap ;
951
951
public static Double [ ] DiscountMap
@@ -1054,8 +1054,8 @@ private static Comparison<int> GetCompareItems(List<short> queryLabels, List<Sin
1054
1054
1055
1055
internal static partial class Evaluate
1056
1056
{
1057
- [ TlcModule . EntryPoint ( Name = "Models.RankerEvaluator " , Desc = "Evaluates a ranking scored dataset." ) ]
1058
- public static CommonOutputs . CommonEvaluateOutput Ranking ( IHostEnvironment env , RankerMamlEvaluator . Arguments input )
1057
+ [ TlcModule . EntryPoint ( Name = "Models.RankingEvaluator " , Desc = "Evaluates a ranking scored dataset." ) ]
1058
+ public static CommonOutputs . CommonEvaluateOutput Ranking ( IHostEnvironment env , RankingMamlEvaluator . Arguments input )
1059
1059
{
1060
1060
Contracts . CheckValue ( env , nameof ( env ) ) ;
1061
1061
var host = env . Register ( "EvaluateRanker" ) ;
@@ -1068,9 +1068,9 @@ public static CommonOutputs.CommonEvaluateOutput Ranking(IHostEnvironment env, R
1068
1068
MatchColumns ( host , input , out label , out weight , out name ) ;
1069
1069
var schema = input . Data . Schema ;
1070
1070
string groupId = TrainUtils . MatchNameOrDefaultOrNull ( host , schema ,
1071
- nameof ( RankerMamlEvaluator . Arguments . GroupIdColumn ) ,
1071
+ nameof ( RankingMamlEvaluator . Arguments . GroupIdColumn ) ,
1072
1072
input . GroupIdColumn , DefaultColumnNames . GroupId ) ;
1073
- IMamlEvaluator evaluator = new RankerMamlEvaluator ( host , input ) ;
1073
+ IMamlEvaluator evaluator = new RankingMamlEvaluator ( host , input ) ;
1074
1074
var data = new RoleMappedData ( input . Data , label , null , groupId , weight , name ) ;
1075
1075
var metrics = evaluator . Evaluate ( data ) ;
1076
1076
0 commit comments