Skip to content

Commit 2180cfb

Browse files
authored
Modify API for advanced settings (LightGBM) (#2261)
* lightgbm tests work fine * Options renaming * review comments * update tests to exercise the catalog entries
1 parent 620ca89 commit 2180cfb

File tree

11 files changed

+335
-158
lines changed

11 files changed

+335
-158
lines changed

src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs

+185-25
Large diffs are not rendered by default.

src/Microsoft.ML.LightGBM/LightGbmArguments.cs

+10-10
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@
1111
using Microsoft.ML.Internal.Internallearn;
1212
using Microsoft.ML.LightGBM;
1313

14-
[assembly: LoadableClass(typeof(LightGbmArguments.TreeBooster), typeof(LightGbmArguments.TreeBooster.Arguments),
15-
typeof(SignatureLightGBMBooster), LightGbmArguments.TreeBooster.FriendlyName, LightGbmArguments.TreeBooster.Name)]
16-
[assembly: LoadableClass(typeof(LightGbmArguments.DartBooster), typeof(LightGbmArguments.DartBooster.Arguments),
17-
typeof(SignatureLightGBMBooster), LightGbmArguments.DartBooster.FriendlyName, LightGbmArguments.DartBooster.Name)]
18-
[assembly: LoadableClass(typeof(LightGbmArguments.GossBooster), typeof(LightGbmArguments.GossBooster.Arguments),
19-
typeof(SignatureLightGBMBooster), LightGbmArguments.GossBooster.FriendlyName, LightGbmArguments.GossBooster.Name)]
14+
[assembly: LoadableClass(typeof(Options.TreeBooster), typeof(Options.TreeBooster.Arguments),
15+
typeof(SignatureLightGBMBooster), Options.TreeBooster.FriendlyName, Options.TreeBooster.Name)]
16+
[assembly: LoadableClass(typeof(Options.DartBooster), typeof(Options.DartBooster.Arguments),
17+
typeof(SignatureLightGBMBooster), Options.DartBooster.FriendlyName, Options.DartBooster.Name)]
18+
[assembly: LoadableClass(typeof(Options.GossBooster), typeof(Options.GossBooster.Arguments),
19+
typeof(SignatureLightGBMBooster), Options.GossBooster.FriendlyName, Options.GossBooster.Name)]
2020

21-
[assembly: EntryPointModule(typeof(LightGbmArguments.TreeBooster.Arguments))]
22-
[assembly: EntryPointModule(typeof(LightGbmArguments.DartBooster.Arguments))]
23-
[assembly: EntryPointModule(typeof(LightGbmArguments.GossBooster.Arguments))]
21+
[assembly: EntryPointModule(typeof(Options.TreeBooster.Arguments))]
22+
[assembly: EntryPointModule(typeof(Options.DartBooster.Arguments))]
23+
[assembly: EntryPointModule(typeof(Options.GossBooster.Arguments))]
2424

2525
namespace Microsoft.ML.LightGBM
2626
{
@@ -39,7 +39,7 @@ public interface IBoosterParameter
3939
/// Parameter names come from the LightGBM library.
4040
/// See https://github.com/Microsoft/LightGBM/blob/master/docs/Parameters.rst.
4141
/// </summary>
42-
public sealed class LightGbmArguments : LearnerInputBaseWithGroupId
42+
public sealed class Options : LearnerInputBaseWithGroupId
4343
{
4444
public abstract class BoosterParameter<TArgs> : IBoosterParameter
4545
where TArgs : class, new()

src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs

+8-13
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
using Microsoft.ML.Trainers.FastTree.Internal;
1717
using Microsoft.ML.Training;
1818

19-
[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(LightGbmArguments),
19+
[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(Options),
2020
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) },
2121
LightGbmBinaryTrainer.UserName, LightGbmBinaryTrainer.LoadNameValue, LightGbmBinaryTrainer.ShortName, DocName = "trainer/LightGBM.md")]
2222

@@ -95,8 +95,8 @@ public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase<float, BinaryPre
9595

9696
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
9797

98-
internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
99-
: base(env, LoadNameValue, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
98+
internal LightGbmBinaryTrainer(IHostEnvironment env, Options options)
99+
: base(env, LoadNameValue, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn))
100100
{
101101
}
102102

@@ -111,20 +111,15 @@ internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
111111
/// <param name="numBoostRound">Number of iterations.</param>
112112
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
113113
/// <param name="learningRate">The learning rate.</param>
114-
/// <param name="advancedSettings">A delegate to set more settings.
115-
/// The settings here will override the ones provided in the direct signature,
116-
/// if both are present and have different values.
117-
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
118-
public LightGbmBinaryTrainer(IHostEnvironment env,
114+
internal LightGbmBinaryTrainer(IHostEnvironment env,
119115
string labelColumn = DefaultColumnNames.Label,
120116
string featureColumn = DefaultColumnNames.Features,
121117
string weights = null,
122118
int? numLeaves = null,
123119
int? minDataPerLeaf = null,
124120
double? learningRate = null,
125-
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
126-
Action<LightGbmArguments> advancedSettings = null)
127-
: base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings)
121+
int numBoostRound = LightGBM.Options.Defaults.NumBoostRound)
122+
: base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound)
128123
{
129124
}
130125

@@ -186,14 +181,14 @@ public static partial class LightGbm
186181
ShortName = LightGbmBinaryTrainer.ShortName,
187182
XmlInclude = new[] { @"<include file='../Microsoft.ML.LightGBM/doc.xml' path='doc/members/member[@name=""LightGBM""]/*' />",
188183
@"<include file='../Microsoft.ML.LightGBM/doc.xml' path='doc/members/example[@name=""LightGbmBinaryClassifier""]/*' />"})]
189-
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, LightGbmArguments input)
184+
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
190185
{
191186
Contracts.CheckValue(env, nameof(env));
192187
var host = env.Register("TrainLightGBM");
193188
host.CheckValue(input, nameof(input));
194189
EntryPointUtils.CheckInputArgs(host, input);
195190

196-
return LearnerEntryPointsUtils.Train<LightGbmArguments, CommonOutputs.BinaryClassificationOutput>(host, input,
191+
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
197192
() => new LightGbmBinaryTrainer(host, input),
198193
getLabel: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
199194
getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));

src/Microsoft.ML.LightGBM/LightGbmCatalog.cs

+60-31
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,31 @@ public static class LightGbmExtensions
2424
/// <param name="numBoostRound">Number of iterations.</param>
2525
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
2626
/// <param name="learningRate">The learning rate.</param>
27-
/// <param name="advancedSettings">A delegate to set more settings.
28-
/// The settings here will override the ones provided in the direct signature,
29-
/// if both are present and have different values.
30-
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
3127
public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog,
3228
string labelColumn = DefaultColumnNames.Label,
3329
string featureColumn = DefaultColumnNames.Features,
3430
string weights = null,
3531
int? numLeaves = null,
3632
int? minDataPerLeaf = null,
3733
double? learningRate = null,
38-
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
39-
Action<LightGbmArguments> advancedSettings = null)
34+
int numBoostRound = Options.Defaults.NumBoostRound)
4035
{
4136
Contracts.CheckValue(catalog, nameof(catalog));
4237
var env = CatalogUtils.GetEnvironment(catalog);
43-
return new LightGbmRegressorTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings);
38+
return new LightGbmRegressorTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound);
39+
}
40+
41+
/// <summary>
42+
/// Predict a target using a decision tree regression model trained with the <see cref="LightGbmRegressorTrainer"/>.
43+
/// </summary>
44+
/// <param name="catalog">The <see cref="RegressionCatalog"/>.</param>
45+
/// <param name="options">Advanced options to the algorithm.</param>
46+
public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog,
47+
Options options)
48+
{
49+
Contracts.CheckValue(catalog, nameof(catalog));
50+
var env = CatalogUtils.GetEnvironment(catalog);
51+
return new LightGbmRegressorTrainer(env, options);
4452
}
4553

4654
/// <summary>
@@ -54,28 +62,35 @@ public static LightGbmRegressorTrainer LightGbm(this RegressionCatalog.Regressio
5462
/// <param name="numBoostRound">Number of iterations.</param>
5563
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
5664
/// <param name="learningRate">The learning rate.</param>
57-
/// <param name="advancedSettings">A delegate to set more settings.
58-
/// The settings here will override the ones provided in the direct signature,
59-
/// if both are present and have different values.
60-
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
6165
public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
6266
string labelColumn = DefaultColumnNames.Label,
6367
string featureColumn = DefaultColumnNames.Features,
6468
string weights = null,
6569
int? numLeaves = null,
6670
int? minDataPerLeaf = null,
6771
double? learningRate = null,
68-
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
69-
Action<LightGbmArguments> advancedSettings = null)
72+
int numBoostRound = Options.Defaults.NumBoostRound)
7073
{
7174
Contracts.CheckValue(catalog, nameof(catalog));
7275
var env = CatalogUtils.GetEnvironment(catalog);
73-
return new LightGbmBinaryTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings);
76+
return new LightGbmBinaryTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound);
77+
}
7478

79+
/// <summary>
80+
/// Predict a target using a decision tree binary classification model trained with the <see cref="LightGbmBinaryTrainer"/>.
81+
/// </summary>
82+
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
83+
/// <param name="options">Advanced options to the algorithm.</param>
84+
public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
85+
Options options)
86+
{
87+
Contracts.CheckValue(catalog, nameof(catalog));
88+
var env = CatalogUtils.GetEnvironment(catalog);
89+
return new LightGbmBinaryTrainer(env, options);
7590
}
7691

7792
/// <summary>
78-
/// Predict a target using a decision tree binary classification model trained with the <see cref="LightGbmRankingTrainer"/>.
93+
/// Predict a target using a decision tree ranking model trained with the <see cref="LightGbmRankingTrainer"/>.
7994
/// </summary>
8095
/// <param name="catalog">The <see cref="RankingCatalog"/>.</param>
8196
/// <param name="labelColumn">The label column.</param>
@@ -86,10 +101,6 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi
86101
/// <param name="numBoostRound">Number of iterations.</param>
87102
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
88103
/// <param name="learningRate">The learning rate.</param>
89-
/// <param name="advancedSettings">A delegate to set more settings.
90-
/// The settings here will override the ones provided in the direct signature,
91-
/// if both are present and have different values.
92-
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
93104
public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainers catalog,
94105
string labelColumn = DefaultColumnNames.Label,
95106
string featureColumn = DefaultColumnNames.Features,
@@ -98,44 +109,62 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer
98109
int? numLeaves = null,
99110
int? minDataPerLeaf = null,
100111
double? learningRate = null,
101-
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
102-
Action<LightGbmArguments> advancedSettings = null)
112+
int numBoostRound = Options.Defaults.NumBoostRound)
103113
{
104114
Contracts.CheckValue(catalog, nameof(catalog));
105115
var env = CatalogUtils.GetEnvironment(catalog);
106-
return new LightGbmRankingTrainer(env, labelColumn, featureColumn, groupIdColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings);
107-
116+
return new LightGbmRankingTrainer(env, labelColumn, featureColumn, groupIdColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound);
108117
}
109118

110119
/// <summary>
111-
/// Predict a target using a decision tree binary classification model trained with the <see cref="LightGbmRankingTrainer"/>.
120+
/// Predict a target using a decision tree ranking model trained with the <see cref="LightGbmRankingTrainer"/>.
112121
/// </summary>
113122
/// <param name="catalog">The <see cref="RankingCatalog"/>.</param>
123+
/// <param name="options">Advanced options to the algorithm.</param>
124+
public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainers catalog,
125+
Options options)
126+
{
127+
Contracts.CheckValue(catalog, nameof(catalog));
128+
var env = CatalogUtils.GetEnvironment(catalog);
129+
return new LightGbmRankingTrainer(env, options);
130+
}
131+
132+
/// <summary>
133+
/// Predict a target using a decision tree multiclass classification model trained with the <see cref="LightGbmMulticlassTrainer"/>.
134+
/// </summary>
135+
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog"/>.</param>
114136
/// <param name="labelColumn">The label column.</param>
115137
/// <param name="featureColumn">The features column.</param>
116138
/// <param name="weights">The weights column.</param>
117139
/// <param name="numLeaves">The number of leaves to use.</param>
118140
/// <param name="numBoostRound">Number of iterations.</param>
119141
/// <param name="minDataPerLeaf">The minimal number of documents allowed in a leaf of the tree, out of the subsampled data.</param>
120142
/// <param name="learningRate">The learning rate.</param>
121-
/// <param name="advancedSettings">A delegate to set more settings.
122-
/// The settings here will override the ones provided in the direct signature,
123-
/// if both are present and have different values.
124-
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
125143
public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
126144
string labelColumn = DefaultColumnNames.Label,
127145
string featureColumn = DefaultColumnNames.Features,
128146
string weights = null,
129147
int? numLeaves = null,
130148
int? minDataPerLeaf = null,
131149
double? learningRate = null,
132-
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
133-
Action<LightGbmArguments> advancedSettings = null)
150+
int numBoostRound = Options.Defaults.NumBoostRound)
134151
{
135152
Contracts.CheckValue(catalog, nameof(catalog));
136153
var env = CatalogUtils.GetEnvironment(catalog);
137-
return new LightGbmMulticlassTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings);
154+
return new LightGbmMulticlassTrainer(env, labelColumn, featureColumn, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound);
155+
}
138156

157+
/// <summary>
158+
/// Predict a target using a decision tree multiclass classification model trained with the <see cref="LightGbmMulticlassTrainer"/>.
159+
/// </summary>
160+
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog"/>.</param>
161+
/// <param name="options">Advanced options to the algorithm.</param>
162+
public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
163+
Options options)
164+
{
165+
Contracts.CheckValue(catalog, nameof(catalog));
166+
var env = CatalogUtils.GetEnvironment(catalog);
167+
return new LightGbmMulticlassTrainer(env, options);
139168
}
140169
}
141170
}

0 commit comments

Comments
 (0)