Skip to content

Commit fb6ce54

Browse files
authored
Fixing renmants of argument keyword in public API (#2636)
* 1st pass. builds work locally * fixes inside Ensemble, FastTree, StandardLearners, Sweeper * fixes to Data and Sweeper assemblies
1 parent 96ec842 commit fb6ce54

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+617
-617
lines changed

src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs

+33-33
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ public sealed class Arguments : DataCommand.ArgumentsBase
107107
public CrossValidationCommand(IHostEnvironment env, Arguments args)
108108
: base(env, args, RegistrationName)
109109
{
110-
Host.CheckUserArg(Args.NumFolds >= 2, nameof(Args.NumFolds), "Number of folds must be greater than or equal to 2.");
110+
Host.CheckUserArg(ImplOptions.NumFolds >= 2, nameof(ImplOptions.NumFolds), "Number of folds must be greater than or equal to 2.");
111111
TrainUtils.CheckTrainer(Host, args.Trainer, args.DataFile);
112-
Utils.CheckOptionalUserDirectory(Args.SummaryFilename, nameof(Args.SummaryFilename));
113-
Utils.CheckOptionalUserDirectory(Args.OutputDataFile, nameof(Args.OutputDataFile));
112+
Utils.CheckOptionalUserDirectory(ImplOptions.SummaryFilename, nameof(ImplOptions.SummaryFilename));
113+
Utils.CheckOptionalUserDirectory(ImplOptions.OutputDataFile, nameof(ImplOptions.OutputDataFile));
114114
}
115115

116116
// This is for "forking" the host environment.
@@ -124,7 +124,7 @@ public override void Run()
124124
using (var ch = Host.Start(LoadName))
125125
using (var server = InitServer(ch))
126126
{
127-
var settings = CmdParser.GetSettings(Host, Args, new Arguments());
127+
var settings = CmdParser.GetSettings(Host, ImplOptions, new Arguments());
128128
string cmd = string.Format("maml.exe {0} {1}", LoadName, settings);
129129
ch.Info(cmd);
130130

@@ -139,7 +139,7 @@ public override void Run()
139139

140140
protected override void SendTelemetryCore(IPipe<TelemetryMessage> pipe)
141141
{
142-
SendTelemetryComponent(pipe, Args.Trainer);
142+
SendTelemetryComponent(pipe, ImplOptions.Trainer);
143143
base.SendTelemetryCore(pipe);
144144
}
145145

@@ -148,17 +148,17 @@ private void RunCore(IChannel ch, string cmd)
148148
Host.AssertValue(ch);
149149

150150
IPredictor inputPredictor = null;
151-
if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
151+
if (ImplOptions.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor))
152152
ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
153153

154154
ch.Trace("Constructing data pipeline");
155155
IDataLoader loader = CreateRawLoader();
156156

157157
// If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
158-
var preXf = Args.PreTransforms;
159-
if (!string.IsNullOrEmpty(Args.OutputDataFile))
158+
var preXf = ImplOptions.PreTransforms;
159+
if (!string.IsNullOrEmpty(ImplOptions.OutputDataFile))
160160
{
161-
string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
161+
string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(ImplOptions.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);
162162
if (name == null)
163163
{
164164
preXf = preXf.Concat(
@@ -182,24 +182,24 @@ private void RunCore(IChannel ch, string cmd)
182182

183183
IDataView pipe = loader;
184184
var stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
185-
var scorer = Args.Scorer;
186-
var evaluator = Args.Evaluator;
185+
var scorer = ImplOptions.Scorer;
186+
var evaluator = ImplOptions.Evaluator;
187187

188188
Func<IDataView> validDataCreator = null;
189-
if (Args.ValidationFile != null)
189+
if (ImplOptions.ValidationFile != null)
190190
{
191191
validDataCreator =
192192
() =>
193193
{
194194
// Fork the command.
195195
var impl = new CrossValidationCommand(this);
196-
return impl.CreateRawLoader(dataFile: Args.ValidationFile);
196+
return impl.CreateRawLoader(dataFile: ImplOptions.ValidationFile);
197197
};
198198
}
199199

200200
FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn,
201-
Args, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
202-
validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(Args.OutputDataFile));
201+
ImplOptions, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
202+
validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(ImplOptions.OutputDataFile));
203203
var tasks = fold.GetCrossValidationTasks();
204204

205205
var eval = evaluator?.CreateComponent(Host) ??
@@ -218,32 +218,32 @@ private void RunCore(IChannel ch, string cmd)
218218
throw ch.Except("No overall metrics found");
219219

220220
var overall = eval.GetOverallResults(overallList.ToArray());
221-
MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, Args.NumFolds);
221+
MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, ImplOptions.NumFolds);
222222
eval.PrintAdditionalMetrics(ch, tasks.Select(t => t.Result.Metrics).ToArray());
223223
Dictionary<string, IDataView>[] metricValues = tasks.Select(t => t.Result.Metrics).ToArray();
224224
SendTelemetryMetric(metricValues);
225225

226226
// Save the per-instance results.
227-
if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
227+
if (!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
228228
{
229-
var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, Args.CollateMetrics,
230-
Args.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
229+
var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, ImplOptions.CollateMetrics,
230+
ImplOptions.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
231231
if (variableSizeVectorColumnNames.Length > 0)
232232
{
233233
ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.",
234234
string.Join(", ", variableSizeVectorColumnNames));
235235
}
236-
if (Args.CollateMetrics)
236+
if (ImplOptions.CollateMetrics)
237237
{
238238
ch.Assert(perInstance.Length == 1);
239-
MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInstance[0]);
239+
MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, perInstance[0]);
240240
}
241241
else
242242
{
243243
int i = 0;
244244
foreach (var idv in perInstance)
245245
{
246-
MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, i), idv);
246+
MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(ImplOptions.OutputDataFile, i), idv);
247247
i++;
248248
}
249249
}
@@ -265,20 +265,20 @@ private RoleMappedData ApplyAllTransformsToData(IHostEnvironment env, IChannel c
265265
/// </summary>
266266
private RoleMappedData CreateRoleMappedData(IHostEnvironment env, IChannel ch, IDataView data, ITrainer trainer)
267267
{
268-
foreach (var kvp in Args.Transforms)
268+
foreach (var kvp in ImplOptions.Transforms)
269269
data = kvp.Value.CreateComponent(env, data);
270270

271271
var schema = data.Schema;
272-
string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label);
273-
string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
274-
string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.WeightColumn), Args.WeightColumn, DefaultColumnNames.Weight);
275-
string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
276-
string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
272+
string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.LabelColumn), ImplOptions.LabelColumn, DefaultColumnNames.Label);
273+
string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features);
274+
string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.WeightColumn), ImplOptions.WeightColumn, DefaultColumnNames.Weight);
275+
string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);
276+
string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
277277

278-
TrainUtils.AddNormalizerIfNeeded(env, ch, trainer, ref data, features, Args.NormalizeFeatures);
278+
TrainUtils.AddNormalizerIfNeeded(env, ch, trainer, ref data, features, ImplOptions.NormalizeFeatures);
279279

280280
// Training pipe and examples.
281-
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumns);
281+
var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
282282

283283
return new RoleMappedData(data, label, features, group, weight, name, customCols);
284284
}
@@ -291,11 +291,11 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output
291291
// If no stratification column was specified, but we have a group column of type Single, Double or
292292
// Key (contiguous) use it.
293293
string stratificationColumn = null;
294-
if (!string.IsNullOrWhiteSpace(Args.StratificationColumn))
295-
stratificationColumn = Args.StratificationColumn;
294+
if (!string.IsNullOrWhiteSpace(ImplOptions.StratificationColumn))
295+
stratificationColumn = ImplOptions.StratificationColumn;
296296
else
297297
{
298-
string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
298+
string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
299299
int index;
300300
if (group != null && schema.TryGetColumnIndex(group, out index))
301301
{

0 commit comments

Comments
 (0)