Skip to content

Commit 4420cc7

Browse files
authored
Changed Ranker to Ranking in evaluation related files. (#2675)
1 parent f6d55f3 commit 4420cc7

File tree

14 files changed

+72
-72
lines changed

14 files changed

+72
-72
lines changed

src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -850,8 +850,8 @@ public static class PipelineSweeperSupportedMetrics
850850
public const string RSquared = RegressionLossEvaluatorBase<MultiOutputRegressionEvaluator.Aggregator>.RSquared;
851851
public const string LogLoss = BinaryClassifierEvaluator.LogLoss;
852852
public const string LogLossReduction = BinaryClassifierEvaluator.LogLossReduction;
853-
public const string Ndcg = RankerEvaluator.Ndcg;
854-
public const string Dcg = RankerEvaluator.Dcg;
853+
public const string Ndcg = RankingEvaluator.Ndcg;
854+
public const string Dcg = RankingEvaluator.Dcg;
855855
public const string PositivePrecision = BinaryClassifierEvaluator.PosPrecName;
856856
public const string PositiveRecall = BinaryClassifierEvaluator.PosRecallName;
857857
public const string NegativePrecision = BinaryClassifierEvaluator.NegPrecName;

src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public static Dictionary<string, Func<IHostEnvironment, IMamlEvaluator>> Instanc
4444
{ MetadataUtils.Const.ScoreColumnKind.Regression, env => new RegressionMamlEvaluator(env, new RegressionMamlEvaluator.Arguments()) },
4545
{ MetadataUtils.Const.ScoreColumnKind.MultiOutputRegression, env => new MultiOutputRegressionMamlEvaluator(env, new MultiOutputRegressionMamlEvaluator.Arguments()) },
4646
{ MetadataUtils.Const.ScoreColumnKind.QuantileRegression, env => new QuantileRegressionMamlEvaluator(env, new QuantileRegressionMamlEvaluator.Arguments()) },
47-
{ MetadataUtils.Const.ScoreColumnKind.Ranking, env => new RankerMamlEvaluator(env, new RankerMamlEvaluator.Arguments()) },
47+
{ MetadataUtils.Const.ScoreColumnKind.Ranking, env => new RankingMamlEvaluator(env, new RankingMamlEvaluator.Arguments()) },
4848
{ MetadataUtils.Const.ScoreColumnKind.Clustering, env => new ClusteringMamlEvaluator(env, new ClusteringMamlEvaluator.Arguments()) },
4949
{ MetadataUtils.Const.ScoreColumnKind.AnomalyDetection, env => new AnomalyDetectionMamlEvaluator(env, new AnomalyDetectionMamlEvaluator.Arguments()) }
5050
};

src/Microsoft.ML.Data/Evaluators/Metrics/RankerMetrics.cs renamed to src/Microsoft.ML.Data/Evaluators/Metrics/RankingMetrics.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
namespace Microsoft.ML.Data
88
{
9-
public sealed class RankerMetrics
9+
public sealed class RankingMetrics
1010
{
1111
/// <summary>
1212
/// Array of normalized discounted cumulative gains where i-th element represent NDCG@i.
@@ -32,15 +32,15 @@ private static T Fetch<T>(IExceptionContext ectx, DataViewRow row, string name)
3232
return val;
3333
}
3434

35-
internal RankerMetrics(IExceptionContext ectx, DataViewRow overallResult)
35+
internal RankingMetrics(IExceptionContext ectx, DataViewRow overallResult)
3636
{
3737
VBuffer<double> Fetch(string name) => Fetch<VBuffer<double>>(ectx, overallResult, name);
3838

39-
Dcg = Fetch(RankerEvaluator.Dcg).GetValues().ToArray();
40-
Ndcg = Fetch(RankerEvaluator.Ndcg).GetValues().ToArray();
39+
Dcg = Fetch(RankingEvaluator.Dcg).GetValues().ToArray();
40+
Ndcg = Fetch(RankingEvaluator.Ndcg).GetValues().ToArray();
4141
}
4242

43-
internal RankerMetrics(double[] dcg, double[] ndcg)
43+
internal RankingMetrics(double[] dcg, double[] ndcg)
4444
{
4545
Dcg = new double[dcg.Length];
4646
dcg.CopyTo(Dcg, 0);

src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs renamed to src/Microsoft.ML.Data/Evaluators/RankingEvaluator.cs

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,19 @@
1616
using Microsoft.ML.Internal.Utilities;
1717
using Microsoft.ML.Model;
1818

19-
[assembly: LoadableClass(typeof(RankerEvaluator), typeof(RankerEvaluator), typeof(RankerEvaluator.Arguments), typeof(SignatureEvaluator),
20-
"Ranking Evaluator", RankerEvaluator.LoadName, "Ranking", "rank")]
19+
[assembly: LoadableClass(typeof(RankingEvaluator), typeof(RankingEvaluator), typeof(RankingEvaluator.Arguments), typeof(SignatureEvaluator),
20+
"Ranking Evaluator", RankingEvaluator.LoadName, "Ranking", "rank")]
2121

22-
[assembly: LoadableClass(typeof(RankerMamlEvaluator), typeof(RankerMamlEvaluator), typeof(RankerMamlEvaluator.Arguments), typeof(SignatureMamlEvaluator),
23-
"Ranking Evaluator", RankerEvaluator.LoadName, "Ranking", "rank")]
22+
[assembly: LoadableClass(typeof(RankingMamlEvaluator), typeof(RankingMamlEvaluator), typeof(RankingMamlEvaluator.Arguments), typeof(SignatureMamlEvaluator),
23+
"Ranking Evaluator", RankingEvaluator.LoadName, "Ranking", "rank")]
2424

25-
[assembly: LoadableClass(typeof(RankerPerInstanceTransform), null, typeof(SignatureLoadDataTransform),
26-
"", RankerPerInstanceTransform.LoaderSignature)]
25+
[assembly: LoadableClass(typeof(RankingPerInstanceTransform), null, typeof(SignatureLoadDataTransform),
26+
"", RankingPerInstanceTransform.LoaderSignature)]
2727

2828
namespace Microsoft.ML.Data
2929
{
3030
[BestFriend]
31-
internal sealed class RankerEvaluator : EvaluatorBase<RankerEvaluator.Aggregator>
31+
internal sealed class RankingEvaluator : EvaluatorBase<RankingEvaluator.Aggregator>
3232
{
3333
public sealed class Arguments
3434
{
@@ -61,7 +61,7 @@ public sealed class Arguments
6161
private readonly bool _groupSummary;
6262
private readonly Double[] _labelGains;
6363

64-
public RankerEvaluator(IHostEnvironment env, Arguments args)
64+
public RankingEvaluator(IHostEnvironment env, Arguments args)
6565
: base(env, LoadName)
6666
{
6767
// REVIEW: What kind of checking should be applied to labelGains?
@@ -89,13 +89,13 @@ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
8989
var t = schema.Label.Value.Type;
9090
if (t != NumberDataViewType.Single && !(t is KeyType))
9191
{
92-
throw Host.ExceptSchemaMismatch(nameof(RankerMamlEvaluator.Arguments.LabelColumn),
92+
throw Host.ExceptSchemaMismatch(nameof(RankingMamlEvaluator.Arguments.LabelColumn),
9393
"label", schema.Label.Value.Name, "R4 or a key", t.ToString());
9494
}
9595
var scoreCol = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
9696
if (scoreCol.Type != NumberDataViewType.Single)
9797
{
98-
throw Host.ExceptSchemaMismatch(nameof(RankerMamlEvaluator.Arguments.ScoreColumn),
98+
throw Host.ExceptSchemaMismatch(nameof(RankingMamlEvaluator.Arguments.ScoreColumn),
9999
"score", scoreCol.Name, "R4", t.ToString());
100100
}
101101
}
@@ -105,7 +105,7 @@ private protected override void CheckCustomColumnTypesCore(RoleMappedSchema sche
105105
var t = schema.Group.Value.Type;
106106
if (!(t is KeyType))
107107
{
108-
throw Host.ExceptSchemaMismatch(nameof(RankerMamlEvaluator.Arguments.GroupIdColumn),
108+
throw Host.ExceptSchemaMismatch(nameof(RankingMamlEvaluator.Arguments.GroupIdColumn),
109109
"group", schema.Group.Value.Name, "key", t.ToString());
110110
}
111111
}
@@ -129,7 +129,7 @@ internal override IDataTransform GetPerInstanceMetricsCore(RoleMappedData data)
129129
var scoreInfo = data.Schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
130130
Host.CheckParam(data.Schema.Group.HasValue, nameof(data), "Schema must contain a group column");
131131

132-
return new RankerPerInstanceTransform(Host, data.Data,
132+
return new RankingPerInstanceTransform(Host, data.Data,
133133
data.Schema.Label.Value.Name, scoreInfo.Name, data.Schema.Group.Value.Name, _truncationLevel, _labelGains);
134134
}
135135

@@ -242,7 +242,7 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
242242
/// <param name="groupId">The name of the groupId column.</param>
243243
/// <param name="score">The name of the predicted score column.</param>
244244
/// <returns>The evaluation metrics for these outputs.</returns>
245-
public RankerMetrics Evaluate(IDataView data, string label, string groupId, string score)
245+
public RankingMetrics Evaluate(IDataView data, string label, string groupId, string score)
246246
{
247247
Host.CheckValue(data, nameof(data));
248248
Host.CheckNonEmpty(label, nameof(label));
@@ -256,12 +256,12 @@ public RankerMetrics Evaluate(IDataView data, string label, string groupId, stri
256256
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
257257
var overall = resultDict[MetricKinds.OverallMetrics];
258258

259-
RankerMetrics result;
259+
RankingMetrics result;
260260
using (var cursor = overall.GetRowCursorForAllColumns())
261261
{
262262
var moved = cursor.MoveNext();
263263
Host.Assert(moved);
264-
result = new RankerMetrics(Host, cursor);
264+
result = new RankingMetrics(Host, cursor);
265265
moved = cursor.MoveNext();
266266
Host.Assert(!moved);
267267
}
@@ -374,15 +374,15 @@ public void Update(short label, Single output)
374374

375375
public void UpdateGroup(Single weight)
376376
{
377-
RankerUtils.QueryMaxDcg(_labelGains, TruncationLevel, _queryLabels, _queryOutputs, _groupMaxDcgCur);
377+
RankingUtils.QueryMaxDcg(_labelGains, TruncationLevel, _queryLabels, _queryOutputs, _groupMaxDcgCur);
378378
if (_groupMaxDcg != null)
379379
{
380380
var maxDcg = new Double[TruncationLevel];
381381
Array.Copy(_groupMaxDcgCur, maxDcg, TruncationLevel);
382382
_groupMaxDcg.Add(maxDcg);
383383
}
384384

385-
RankerUtils.QueryDcg(_labelGains, TruncationLevel, _queryLabels, _queryOutputs, _groupDcgCur);
385+
RankingUtils.QueryDcg(_labelGains, TruncationLevel, _queryLabels, _queryOutputs, _groupDcgCur);
386386
if (_groupDcg != null)
387387
{
388388
var groupDcg = new Double[TruncationLevel];
@@ -539,7 +539,7 @@ public void GetSlotNames(ref VBuffer<ReadOnlyMemory<char>> slotNames)
539539
}
540540
}
541541

542-
internal sealed class RankerPerInstanceTransform : IDataTransform
542+
internal sealed class RankingPerInstanceTransform : IDataTransform
543543
{
544544
public const string LoaderSignature = "RankerPerInstTransform";
545545
private const string RegistrationName = LoaderSignature;
@@ -552,7 +552,7 @@ private static VersionInfo GetVersionInfo()
552552
verReadableCur: 0x00010001,
553553
verWeCanReadBack: 0x00010001,
554554
loaderSignature: LoaderSignature,
555-
loaderAssemblyName: typeof(RankerPerInstanceTransform).Assembly.FullName);
555+
loaderAssemblyName: typeof(RankingPerInstanceTransform).Assembly.FullName);
556556
}
557557

558558
public const string Ndcg = "NDCG";
@@ -576,25 +576,25 @@ private static VersionInfo GetVersionInfo()
576576
/// </summary>
577577
public DataViewSchema OutputSchema => _transform.OutputSchema;
578578

579-
public RankerPerInstanceTransform(IHostEnvironment env, IDataView input, string labelCol, string scoreCol, string groupCol,
579+
public RankingPerInstanceTransform(IHostEnvironment env, IDataView input, string labelCol, string scoreCol, string groupCol,
580580
int truncationLevel, Double[] labelGains)
581581
{
582582
_transform = new Transform(env, input, labelCol, scoreCol, groupCol, truncationLevel, labelGains);
583583
}
584584

585-
private RankerPerInstanceTransform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
585+
private RankingPerInstanceTransform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
586586
{
587587
_transform = new Transform(env, ctx, input);
588588
}
589589

590-
public static RankerPerInstanceTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
590+
public static RankingPerInstanceTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
591591
{
592592
Contracts.CheckValue(env, nameof(env));
593593
var h = env.Register(RegistrationName);
594594
h.CheckValue(ctx, nameof(ctx));
595595
ctx.CheckAtModel(GetVersionInfo());
596596
h.CheckValue(input, nameof(input));
597-
return h.Apply("Loading Model", ch => new RankerPerInstanceTransform(h, ctx, input));
597+
return h.Apply("Loading Model", ch => new RankingPerInstanceTransform(h, ctx, input));
598598
}
599599

600600
void ICanSaveModel.Save(ModelSaveContext ctx)
@@ -801,9 +801,9 @@ protected override void ProcessExample(RowCursorState state, short label, Single
801801
protected override void UpdateState(RowCursorState state)
802802
{
803803
// Calculate the current group DCG, NDCG and MaxDcg.
804-
RankerUtils.QueryMaxDcg(_labelGains, _truncationLevel, state.QueryLabels, state.QueryOutputs,
804+
RankingUtils.QueryMaxDcg(_labelGains, _truncationLevel, state.QueryLabels, state.QueryOutputs,
805805
state.MaxDcgCur);
806-
RankerUtils.QueryDcg(_labelGains, _truncationLevel, state.QueryLabels, state.QueryOutputs, state.DcgCur);
806+
RankingUtils.QueryDcg(_labelGains, _truncationLevel, state.QueryLabels, state.QueryOutputs, state.DcgCur);
807807
for (int t = 0; t < _truncationLevel; t++)
808808
{
809809
Double ndcg = state.MaxDcgCur[t] > 0 ? state.DcgCur[t] / state.MaxDcgCur[t] * 100 : 0;
@@ -838,7 +838,7 @@ public RowCursorState(int truncationLevel)
838838
}
839839

840840
[BestFriend]
841-
internal sealed class RankerMamlEvaluator : MamlEvaluatorBase
841+
internal sealed class RankingMamlEvaluator : MamlEvaluatorBase
842842
{
843843
public sealed class Arguments : ArgumentsBase
844844
{
@@ -855,25 +855,25 @@ public sealed class Arguments : ArgumentsBase
855855
public string GroupSummaryFilename;
856856
}
857857

858-
private readonly RankerEvaluator _evaluator;
858+
private readonly RankingEvaluator _evaluator;
859859
private readonly string _groupIdCol;
860860

861861
private readonly string _groupSummaryFilename;
862862

863863
private protected override IEvaluator Evaluator => _evaluator;
864864

865-
public RankerMamlEvaluator(IHostEnvironment env, Arguments args)
865+
public RankingMamlEvaluator(IHostEnvironment env, Arguments args)
866866
: base(args, env, MetadataUtils.Const.ScoreColumnKind.Ranking, "RankerMamlEvaluator")
867867
{
868868
Host.CheckValue(args, nameof(args));
869869
Utils.CheckOptionalUserDirectory(args.GroupSummaryFilename, nameof(args.GroupSummaryFilename));
870870

871-
var evalArgs = new RankerEvaluator.Arguments();
871+
var evalArgs = new RankingEvaluator.Arguments();
872872
evalArgs.DcgTruncationLevel = args.DcgTruncationLevel;
873873
evalArgs.LabelGains = args.LabelGains;
874874
evalArgs.OutputGroupSummary = !string.IsNullOrEmpty(args.GroupSummaryFilename);
875875

876-
_evaluator = new RankerEvaluator(Host, evalArgs);
876+
_evaluator = new RankingEvaluator(Host, evalArgs);
877877
_groupSummaryFilename = args.GroupSummaryFilename;
878878
_groupIdCol = args.GroupIdColumn;
879879
}
@@ -908,14 +908,14 @@ private bool TryGetGroupSummaryMetrics(Dictionary<string, IDataView>[] metrics,
908908
Host.AssertNonEmpty(metrics);
909909

910910
if (metrics.Length == 1)
911-
return metrics[0].TryGetValue(RankerEvaluator.GroupSummary, out gs);
911+
return metrics[0].TryGetValue(RankingEvaluator.GroupSummary, out gs);
912912

913913
gs = null;
914914
var gsList = new List<IDataView>();
915915
for (int i = 0; i < metrics.Length; i++)
916916
{
917917
IDataView idv;
918-
if (!metrics[i].TryGetValue(RankerEvaluator.GroupSummary, out idv))
918+
if (!metrics[i].TryGetValue(RankingEvaluator.GroupSummary, out idv))
919919
return false;
920920

921921
idv = EvaluateUtils.AddFoldIndex(Host, idv, i, metrics.Length);
@@ -939,13 +939,13 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
939939
yield return scoreCol.Name;
940940

941941
// Return the output columns.
942-
yield return RankerPerInstanceTransform.Ndcg;
943-
yield return RankerPerInstanceTransform.Dcg;
944-
yield return RankerPerInstanceTransform.MaxDcg;
942+
yield return RankingPerInstanceTransform.Ndcg;
943+
yield return RankingPerInstanceTransform.Dcg;
944+
yield return RankingPerInstanceTransform.MaxDcg;
945945
}
946946
}
947947

948-
internal static class RankerUtils
948+
internal static class RankingUtils
949949
{
950950
private static volatile Double[] _discountMap;
951951
public static Double[] DiscountMap
@@ -1054,8 +1054,8 @@ private static Comparison<int> GetCompareItems(List<short> queryLabels, List<Sin
10541054

10551055
internal static partial class Evaluate
10561056
{
1057-
[TlcModule.EntryPoint(Name = "Models.RankerEvaluator", Desc = "Evaluates a ranking scored dataset.")]
1058-
public static CommonOutputs.CommonEvaluateOutput Ranking(IHostEnvironment env, RankerMamlEvaluator.Arguments input)
1057+
[TlcModule.EntryPoint(Name = "Models.RankingEvaluator", Desc = "Evaluates a ranking scored dataset.")]
1058+
public static CommonOutputs.CommonEvaluateOutput Ranking(IHostEnvironment env, RankingMamlEvaluator.Arguments input)
10591059
{
10601060
Contracts.CheckValue(env, nameof(env));
10611061
var host = env.Register("EvaluateRanker");
@@ -1068,9 +1068,9 @@ public static CommonOutputs.CommonEvaluateOutput Ranking(IHostEnvironment env, R
10681068
MatchColumns(host, input, out label, out weight, out name);
10691069
var schema = input.Data.Schema;
10701070
string groupId = TrainUtils.MatchNameOrDefaultOrNull(host, schema,
1071-
nameof(RankerMamlEvaluator.Arguments.GroupIdColumn),
1071+
nameof(RankingMamlEvaluator.Arguments.GroupIdColumn),
10721072
input.GroupIdColumn, DefaultColumnNames.GroupId);
1073-
IMamlEvaluator evaluator = new RankerMamlEvaluator(host, input);
1073+
IMamlEvaluator evaluator = new RankingMamlEvaluator(host, input);
10741074
var data = new RoleMappedData(input.Data, label, null, groupId, weight, name);
10751075
var metrics = evaluator.Evaluate(data);
10761076

src/Microsoft.ML.Data/TrainCatalog.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ internal RankingTrainers(RankingCatalog catalog)
623623
/// <param name="groupId">The name of the groupId column in <paramref name="data"/>.</param>
624624
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
625625
/// <returns>The evaluation results for these calibrated outputs.</returns>
626-
public RankerMetrics Evaluate(IDataView data,
626+
public RankingMetrics Evaluate(IDataView data,
627627
string label = DefaultColumnNames.Label,
628628
string groupId = DefaultColumnNames.GroupId,
629629
string score = DefaultColumnNames.Score)
@@ -633,7 +633,7 @@ public RankerMetrics Evaluate(IDataView data,
633633
Environment.CheckNonEmpty(score, nameof(score));
634634
Environment.CheckNonEmpty(groupId, nameof(groupId));
635635

636-
var eval = new RankerEvaluator(Environment, new RankerEvaluator.Arguments() { });
636+
var eval = new RankingEvaluator(Environment, new RankingEvaluator.Arguments() { });
637637
return eval.Evaluate(data, label, groupId, score);
638638
}
639639
}

src/Microsoft.ML.EntryPoints/CrossValidationMacro.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ private static IMamlEvaluator GetEvaluator(IHostEnvironment env, MacroUtils.Trai
430430
case MacroUtils.TrainerKinds.SignatureRegressorTrainer:
431431
return new RegressionMamlEvaluator(env, new RegressionMamlEvaluator.Arguments());
432432
case MacroUtils.TrainerKinds.SignatureRankerTrainer:
433-
return new RankerMamlEvaluator(env, new RankerMamlEvaluator.Arguments());
433+
return new RankingMamlEvaluator(env, new RankingMamlEvaluator.Arguments());
434434
case MacroUtils.TrainerKinds.SignatureAnomalyDetectorTrainer:
435435
return new AnomalyDetectionMamlEvaluator(env, new AnomalyDetectionMamlEvaluator.Arguments());
436436
case MacroUtils.TrainerKinds.SignatureClusteringTrainer:

0 commit comments

Comments
 (0)