Skip to content

Commit 5e58a79

Browse files
committed
Internalization of environment implementations
* Internalize HostEnvironmentBase/ConsoleEnvironment * Limit usage of ConsoleEnvironment
1 parent b82eba9 commit 5e58a79

File tree

69 files changed

+3838
-4099
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+3838
-4099
lines changed

src/Microsoft.ML.Api/TypedCursor.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -563,9 +563,9 @@ public static ICursorable<TRow> AsCursorable<TRow>(this IDataView data, bool ign
563563
SchemaDefinition schemaDefinition = null)
564564
where TRow : class, new()
565565
{
566-
// REVIEW: Take an env as a parameter.
567-
var env = new ConsoleEnvironment();
568-
return data.AsCursorable<TRow>(env, ignoreMissingColumns, schemaDefinition);
566+
// REVIEW: Take this as a parameter, or else make it a method on he context itself.
567+
var ml = new MLContext(42);
568+
return data.AsCursorable<TRow>(ml, ignoreMissingColumns, schemaDefinition);
569569
}
570570

571571
/// <summary>
@@ -604,9 +604,9 @@ public static IEnumerable<TRow> AsEnumerable<TRow>(this IDataView data, bool reu
604604
bool ignoreMissingColumns = false, SchemaDefinition schemaDefinition = null)
605605
where TRow : class, new()
606606
{
607-
// REVIEW: Take an env as a parameter.
608-
var env = new ConsoleEnvironment();
609-
return data.AsEnumerable<TRow>(env, reuseRowObject, ignoreMissingColumns, schemaDefinition);
607+
// REVIEW: Take this as a parameter, or else make it a method on the context itself.
608+
var ml = new MLContext();
609+
return data.AsEnumerable<TRow>(ml, reuseRowObject, ignoreMissingColumns, schemaDefinition);
610610
}
611611
}
612612
}

src/Microsoft.ML.Core/Data/IFileHandle.cs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public sealed class SimpleFileHandle : IFileHandle
6161
// handle has been disposed.
6262
private List<Stream> _streams;
6363

64-
private bool IsDisposed { get { return _streams == null; } }
64+
private bool IsDisposed => _streams == null;
6565

6666
public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bool autoDelete)
6767
{
@@ -84,15 +84,9 @@ public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bo
8484
_streams = new List<Stream>();
8585
}
8686

87-
public bool CanWrite
88-
{
89-
get { return !_wrote && !IsDisposed; }
90-
}
87+
public bool CanWrite => !_wrote && !IsDisposed;
9188

92-
public bool CanRead
93-
{
94-
get { return _wrote && !IsDisposed; }
95-
}
89+
public bool CanRead => _wrote && !IsDisposed;
9690

9791
public void Dispose()
9892
{

src/Microsoft.ML.Core/Data/IHostEnvironment.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
7272
/// The suffix and prefix are optional. A common use for suffix is to specify an extension, eg, ".txt".
7373
/// The use of suffix and prefix, including whether they have any affect, is up to the host environment.
7474
/// </summary>
75+
[Obsolete("The host environment is not disposable, so it is inappropriate to use this method. " +
76+
"Handle your own temporary files. If you cannot, reconsider your life choices.")]
7577
IFileHandle CreateTempFile(string suffix = null, string prefix = null);
7678

7779
/// <summary>

src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
#pragma warning disable 420 // volatile with Interlocked.CompareExchange
66

77
using System;
8-
using System.Collections.Concurrent;
9-
using System.Collections.Generic;
108
using System.IO;
119
using System.Linq;
1210
using System.Threading;
@@ -15,7 +13,12 @@ namespace Microsoft.ML.Runtime.Data
1513
{
1614
using Stopwatch = System.Diagnostics.Stopwatch;
1715

18-
public sealed class ConsoleEnvironment : HostEnvironmentBase<ConsoleEnvironment>
16+
/// <summary>
17+
/// The console environment. As its name suggests, should be limited to those applications that deliberately want
18+
/// console functionality.
19+
/// </summary>
20+
[BestFriend]
21+
internal sealed class ConsoleEnvironment : HostEnvironmentBase<ConsoleEnvironment>
1922
{
2023
public const string ComponentHistoryKey = "ComponentHistory";
2124

src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ public interface IMessageDispatcher : IHostEnvironment
114114
/// AddListener/RemoveListener methods, and exposes the <see cref="ProgressReporting.ProgressTracker"/> to
115115
/// query progress.
116116
/// </summary>
117-
public abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, IHostEnvironment, IDisposable, IChannelProvider, IMessageDispatcher
117+
[BestFriend]
118+
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, IHostEnvironment, IDisposable, IChannelProvider, IMessageDispatcher
118119
where TEnv : HostEnvironmentBase<TEnv>
119120
{
120121
/// <summary>

src/Microsoft.ML.Data/Commands/TrainTestCommand.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System.Collections.Generic;
6+
using System.IO;
67
using Microsoft.ML.Runtime;
78
using Microsoft.ML.Runtime.Command;
89
using Microsoft.ML.Runtime.CommandLine;
@@ -185,11 +186,12 @@ private void RunCore(IChannel ch, string cmd)
185186
Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor, testDataUsedInTrainer);
186187

187188
IDataLoader testPipe;
188-
using (var file = !string.IsNullOrEmpty(Args.OutputModelFile) ?
189-
Host.CreateOutputFile(Args.OutputModelFile) : Host.CreateTempFile(".zip"))
189+
bool hasOutfile = !string.IsNullOrEmpty(Args.OutputModelFile);
190+
var tempFilePath = hasOutfile ? null : Path.GetTempFileName();
191+
192+
using (var file = new SimpleFileHandle(ch, hasOutfile ? Args.OutputModelFile : tempFilePath, true, !hasOutfile))
190193
{
191194
TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
192-
193195
ch.Trace("Constructing the testing pipeline");
194196
using (var stream = file.OpenReadStream())
195197
using (var rep = RepositoryReader.Open(stream, ch))

src/Microsoft.ML.Data/EntryPoints/Cache.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,11 @@ public static CacheOutput CacheData(IHostEnvironment env, CacheInput input)
6868
cols.Add(i);
6969
}
7070

71+
#pragma warning disable CS0618 // This ought to be addressed. See #1287.
7172
// We are not disposing the fileHandle because we want it to stay around for the execution of the graph.
7273
// It will be disposed when the environment is disposed.
7374
var fileHandle = host.CreateTempFile();
75+
#pragma warning restore CS0618
7476

7577
using (var stream = fileHandle.CreateWriteStream())
7678
saver.SaveData(stream, input.Data, cols.ToArray());

src/Microsoft.ML.Legacy/LearningPipeline.cs

Lines changed: 59 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -161,86 +161,84 @@ public PredictionModel<TInput, TOutput> Train<TInput, TOutput>()
161161
where TInput : class
162162
where TOutput : class, new()
163163
{
164-
using (var environment = new ConsoleEnvironment(seed: _seed, conc: _conc))
164+
var environment = new MLContext(seed: _seed, conc: _conc);
165+
Experiment experiment = environment.CreateExperiment();
166+
ILearningPipelineStep step = null;
167+
List<ILearningPipelineLoader> loaders = new List<ILearningPipelineLoader>();
168+
List<Var<ITransformModel>> transformModels = new List<Var<ITransformModel>>();
169+
Var<ITransformModel> lastTransformModel = null;
170+
171+
foreach (ILearningPipelineItem currentItem in this)
165172
{
166-
Experiment experiment = environment.CreateExperiment();
167-
ILearningPipelineStep step = null;
168-
List<ILearningPipelineLoader> loaders = new List<ILearningPipelineLoader>();
169-
List<Var<ITransformModel>> transformModels = new List<Var<ITransformModel>>();
170-
Var<ITransformModel> lastTransformModel = null;
173+
if (currentItem is ILearningPipelineLoader loader)
174+
loaders.Add(loader);
171175

172-
foreach (ILearningPipelineItem currentItem in this)
176+
step = currentItem.ApplyStep(step, experiment);
177+
if (step is ILearningPipelineDataStep dataStep && dataStep.Model != null)
178+
transformModels.Add(dataStep.Model);
179+
else if (step is ILearningPipelinePredictorStep predictorDataStep)
173180
{
174-
if (currentItem is ILearningPipelineLoader loader)
175-
loaders.Add(loader);
181+
if (lastTransformModel != null)
182+
transformModels.Insert(0, lastTransformModel);
176183

177-
step = currentItem.ApplyStep(step, experiment);
178-
if (step is ILearningPipelineDataStep dataStep && dataStep.Model != null)
179-
transformModels.Add(dataStep.Model);
180-
else if (step is ILearningPipelinePredictorStep predictorDataStep)
184+
Var<IPredictorModel> predictorModel;
185+
if (transformModels.Count != 0)
181186
{
182-
if (lastTransformModel != null)
183-
transformModels.Insert(0, lastTransformModel);
184-
185-
Var<IPredictorModel> predictorModel;
186-
if (transformModels.Count != 0)
187+
var localModelInput = new Transforms.ManyHeterogeneousModelCombiner
187188
{
188-
var localModelInput = new Transforms.ManyHeterogeneousModelCombiner
189-
{
190-
PredictorModel = predictorDataStep.Model,
191-
TransformModels = new ArrayVar<ITransformModel>(transformModels.ToArray())
192-
};
193-
var localModelOutput = experiment.Add(localModelInput);
194-
predictorModel = localModelOutput.PredictorModel;
195-
}
196-
else
197-
predictorModel = predictorDataStep.Model;
198-
199-
var scorer = new Transforms.Scorer
200-
{
201-
PredictorModel = predictorModel
189+
PredictorModel = predictorDataStep.Model,
190+
TransformModels = new ArrayVar<ITransformModel>(transformModels.ToArray())
202191
};
203-
204-
var scorerOutput = experiment.Add(scorer);
205-
lastTransformModel = scorerOutput.ScoringTransform;
206-
step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform);
207-
transformModels.Clear();
192+
var localModelOutput = experiment.Add(localModelInput);
193+
predictorModel = localModelOutput.PredictorModel;
208194
}
209-
}
195+
else
196+
predictorModel = predictorDataStep.Model;
210197

211-
if (transformModels.Count > 0)
212-
{
213-
if (lastTransformModel != null)
214-
transformModels.Insert(0, lastTransformModel);
215-
216-
var modelInput = new Transforms.ModelCombiner
198+
var scorer = new Transforms.Scorer
217199
{
218-
Models = new ArrayVar<ITransformModel>(transformModels.ToArray())
200+
PredictorModel = predictorModel
219201
};
220202

221-
var modelOutput = experiment.Add(modelInput);
222-
lastTransformModel = modelOutput.OutputModel;
203+
var scorerOutput = experiment.Add(scorer);
204+
lastTransformModel = scorerOutput.ScoringTransform;
205+
step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform);
206+
transformModels.Clear();
223207
}
208+
}
224209

225-
experiment.Compile();
226-
foreach (ILearningPipelineLoader loader in loaders)
227-
{
228-
loader.SetInput(environment, experiment);
229-
}
230-
experiment.Run();
210+
if (transformModels.Count > 0)
211+
{
212+
if (lastTransformModel != null)
213+
transformModels.Insert(0, lastTransformModel);
231214

232-
ITransformModel model = experiment.GetOutput(lastTransformModel);
233-
BatchPredictionEngine<TInput, TOutput> predictor;
234-
using (var memoryStream = new MemoryStream())
215+
var modelInput = new Transforms.ModelCombiner
235216
{
236-
model.Save(environment, memoryStream);
217+
Models = new ArrayVar<ITransformModel>(transformModels.ToArray())
218+
};
237219

238-
memoryStream.Position = 0;
220+
var modelOutput = experiment.Add(modelInput);
221+
lastTransformModel = modelOutput.OutputModel;
222+
}
239223

240-
predictor = environment.CreateBatchPredictionEngine<TInput, TOutput>(memoryStream);
224+
experiment.Compile();
225+
foreach (ILearningPipelineLoader loader in loaders)
226+
{
227+
loader.SetInput(environment, experiment);
228+
}
229+
experiment.Run();
241230

242-
return new PredictionModel<TInput, TOutput>(predictor, memoryStream);
243-
}
231+
ITransformModel model = experiment.GetOutput(lastTransformModel);
232+
BatchPredictionEngine<TInput, TOutput> predictor;
233+
using (var memoryStream = new MemoryStream())
234+
{
235+
model.Save(environment, memoryStream);
236+
237+
memoryStream.Position = 0;
238+
239+
predictor = environment.CreateBatchPredictionEngine<TInput, TOutput>(memoryStream);
240+
241+
return new PredictionModel<TInput, TOutput>(predictor, memoryStream);
244242
}
245243
}
246244

src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using Microsoft.ML.Runtime;
56
using Microsoft.ML.Runtime.Data;
67
using Microsoft.ML.Legacy.Transforms;
78
using System;
@@ -25,7 +26,7 @@ internal sealed class LearningPipelineDebugProxy
2526
private const int MaxSlotNamesToDisplay = 100;
2627

2728
private readonly LearningPipeline _pipeline;
28-
private readonly ConsoleEnvironment _environment;
29+
private readonly IHostEnvironment _environment;
2930
private IDataView _preview;
3031
private Exception _pipelineExecutionException;
3132
private PipelineItemDebugColumn[] _columns;
@@ -39,7 +40,7 @@ public LearningPipelineDebugProxy(LearningPipeline pipeline)
3940
_pipeline = new LearningPipeline();
4041

4142
// use a ConcurrencyFactor of 1 so other threads don't need to run in the debugger
42-
_environment = new ConsoleEnvironment(conc: 1);
43+
_environment = new MLContext(conc: 1);
4344

4445
foreach (ILearningPipelineItem item in pipeline)
4546
{

src/Microsoft.ML.Legacy/Models/BinaryClassificationEvaluator.cs

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -24,54 +24,52 @@ public sealed partial class BinaryClassificationEvaluator
2424
/// </returns>
2525
public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
2626
{
27-
using (var environment = new ConsoleEnvironment())
28-
{
29-
environment.CheckValue(model, nameof(model));
30-
environment.CheckValue(testData, nameof(testData));
27+
var environment = new MLContext();
28+
environment.CheckValue(model, nameof(model));
29+
environment.CheckValue(testData, nameof(testData));
3130

32-
Experiment experiment = environment.CreateExperiment();
31+
Experiment experiment = environment.CreateExperiment();
3332

34-
ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment);
35-
if (!(testDataStep is ILearningPipelineDataStep testDataOutput))
36-
{
37-
throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
38-
}
33+
ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment);
34+
if (!(testDataStep is ILearningPipelineDataStep testDataOutput))
35+
{
36+
throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
37+
}
3938

40-
var datasetScorer = new DatasetTransformScorer
41-
{
42-
Data = testDataOutput.Data
43-
};
44-
DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer);
39+
var datasetScorer = new DatasetTransformScorer
40+
{
41+
Data = testDataOutput.Data
42+
};
43+
DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer);
4544

46-
Data = scoreOutput.ScoredData;
47-
Output evaluteOutput = experiment.Add(this);
45+
Data = scoreOutput.ScoredData;
46+
Output evaluteOutput = experiment.Add(this);
4847

49-
experiment.Compile();
48+
experiment.Compile();
5049

51-
experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel);
52-
testData.SetInput(environment, experiment);
50+
experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel);
51+
testData.SetInput(environment, experiment);
5352

54-
experiment.Run();
53+
experiment.Run();
5554

56-
IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
57-
if (overallMetrics == null)
58-
{
59-
throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
60-
}
55+
IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
56+
if (overallMetrics == null)
57+
{
58+
throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
59+
}
6160

62-
IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix);
63-
if (confusionMatrix == null)
64-
{
65-
throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
66-
}
61+
IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix);
62+
if (confusionMatrix == null)
63+
{
64+
throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
65+
}
6766

68-
var metric = BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
67+
var metric = BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
6968

70-
if (metric.Count != 1)
71-
throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics");
69+
if (metric.Count != 1)
70+
throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics");
7271

73-
return metric[0];
74-
}
72+
return metric[0];
7573
}
7674
}
7775
}

0 commit comments

Comments
 (0)