-
Notifications
You must be signed in to change notification settings - Fork 1.9k
ML Context to create them all #1252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
26913ac
c6b5a67
e41fc21
e548c63
f7e99fd
65906d2
bbb7be4
8c485fc
3062cb8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
namespace Microsoft.ML.Runtime | ||
{ | ||
/// <summary> | ||
/// A catalog of operations to load and save data. | ||
/// </summary> | ||
public sealed class DataLoadSaveOperations | ||
{ | ||
internal IHostEnvironment Environment { get; } | ||
|
||
internal DataLoadSaveOperations(IHostEnvironment env) | ||
{ | ||
Contracts.AssertValue(env); | ||
Environment = env; | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Data.IO; | ||
using Microsoft.ML.Runtime.Internal.Utilities; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.IO; | ||
using System.Linq; | ||
using System.Text; | ||
|
||
namespace Microsoft.ML | ||
{ | ||
public static class TextLoaderSaverCatalog | ||
{ | ||
/// <summary> | ||
/// Create a text reader. | ||
/// </summary> | ||
/// <param name="catalog">The catalog.</param> | ||
/// <param name="args">The arguments to text reader, describing the data schema.</param> | ||
/// <param name="dataSample">The optional location of a data sample.</param> | ||
public static TextLoader TextReader(this DataLoadSaveOperations catalog, | ||
TextLoader.Arguments args, IMultiStreamSource dataSample = null) | ||
=> new TextLoader(CatalogUtils.GetEnvironment(catalog), args, dataSample); | ||
|
||
/// <summary> | ||
/// Create a text reader. | ||
/// </summary> | ||
/// <param name="catalog">The catalog.</param> | ||
/// <param name="columns">The columns of the schema.</param> | ||
/// <param name="advancedSettings">The delegate to set additional settings.</param> | ||
/// <param name="dataSample">The optional location of a data sample.</param> | ||
public static TextLoader TextReader(this DataLoadSaveOperations catalog, | ||
TextLoader.Column[] columns, Action<TextLoader.Arguments> advancedSettings = null, IMultiStreamSource dataSample = null) | ||
=> new TextLoader(CatalogUtils.GetEnvironment(catalog), columns, advancedSettings, dataSample); | ||
|
||
/// <summary> | ||
/// Read a data view from a text file using <see cref="TextLoader"/>. | ||
/// </summary> | ||
/// <param name="catalog">The catalog.</param> | ||
/// <param name="columns">The columns of the schema.</param> | ||
/// <param name="advancedSettings">The delegate to set additional settings</param> | ||
/// <param name="path">The path to the file</param> | ||
/// <returns>The data view.</returns> | ||
public static IDataView ReadFromTextFile(this DataLoadSaveOperations catalog, | ||
TextLoader.Column[] columns, string path, Action<TextLoader.Arguments> advancedSettings = null) | ||
{ | ||
Contracts.CheckNonEmpty(path, nameof(path)); | ||
|
||
var env = catalog.GetEnvironment(); | ||
|
||
// REVIEW: it is almost always a mistake to have a 'trainable' text loader here. | ||
// Therefore, we are going to disallow data sample. | ||
var reader = new TextLoader(env, columns, advancedSettings, dataSample: null); | ||
return reader.Read(new MultiFileSource(path)); | ||
} | ||
|
||
/// <summary> | ||
/// Save the data view as text. | ||
/// </summary> | ||
/// <param name="catalog">The catalog.</param> | ||
/// <param name="data">The data view to save.</param> | ||
/// <param name="stream">The stream to write to.</param> | ||
/// <param name="separator">The column separator.</param> | ||
/// <param name="headerRow">Whether to write the header row.</param> | ||
/// <param name="schema">Whether to write the header comment with the schema.</param> | ||
/// <param name="keepHidden">Whether to keep hidden columns in the dataset.</param> | ||
public static void SaveAsText(this DataLoadSaveOperations catalog, IDataView data, Stream stream, | ||
char separator = '\t', bool headerRow = true, bool schema = true, bool keepHidden = false) | ||
{ | ||
Contracts.CheckValue(catalog, nameof(catalog)); | ||
Contracts.CheckValue(data, nameof(data)); | ||
Contracts.CheckValue(stream, nameof(stream)); | ||
|
||
var env = catalog.GetEnvironment(); | ||
var saver = new TextSaver(env, new TextSaver.Arguments { Separator = separator.ToString(), OutputHeader = headerRow, OutputSchema = schema }); | ||
|
||
using (var ch = env.Start("Saving data")) | ||
DataSaverUtils.SaveDataView(ch, saver, data, stream, keepHidden); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using System; | ||
|
||
namespace Microsoft.ML | ||
{ | ||
/// <summary> | ||
/// The <see cref="MLContext"/> is a starting point for all ML.NET operations. It is instantiated by user, | ||
/// provides mechanisms for logging and entry points for training, prediction, model operations etc. | ||
/// </summary> | ||
public sealed class MLContext : IHostEnvironment | ||
{ | ||
// REVIEW: consider making LocalEnvironment and MLContext the same class instead of encapsulation. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still don't fully understand our plan here with And if the answer is "you can't - you'd use ConsoleEnvironment directly", I think that is a pretty poor answer because then my whole API paradigm changes. I can no longer use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you think about the answer: "you can't use ConsoleEnvironment because it is internal and only used by MAML"? In reply to: 225986842 [](ancestors = 225986842) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That would be fine with me. For now, I would think it's completely fine to have our custom console printing internal to the commandline tool In reply to: 225989180 [](ancestors = 225989180,225986842) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. #1284 #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
private readonly LocalEnvironment _env; | ||
|
||
/// <summary> | ||
/// Trainers and tasks specific to binary classification problems. | ||
/// </summary> | ||
public BinaryClassificationContext BinaryClassification { get; } | ||
/// <summary> | ||
/// Trainers and tasks specific to multiclass classification problems. | ||
/// </summary> | ||
public MulticlassClassificationContext MulticlassClassification { get; } | ||
/// <summary> | ||
/// Trainers and tasks specific to regression problems. | ||
/// </summary> | ||
public RegressionContext Regression { get; } | ||
/// <summary> | ||
/// Trainers and tasks specific to clustering problems. | ||
/// </summary> | ||
public ClusteringContext Clustering { get; } | ||
/// <summary> | ||
/// Trainers and tasks specific to ranking problems. | ||
/// </summary> | ||
public RankingContext Ranking { get; } | ||
|
||
/// <summary> | ||
/// Data processing operations. | ||
/// </summary> | ||
public TransformsCatalog Transforms { get; } | ||
|
||
/// <summary> | ||
/// Operations with trained models. | ||
/// </summary> | ||
public ModelOperationsCatalog Model { get; } | ||
|
||
/// <summary> | ||
/// Data loading and saving. | ||
/// </summary> | ||
public DataLoadSaveOperations Data { get; } | ||
|
||
// REVIEW: I think it's valuable to have the simplest possible interface for logging interception here, | ||
// and expand if and when necessary. Exposing classes like ChannelMessage, MessageSensitivity and so on | ||
// looks premature at this point. | ||
/// <summary> | ||
/// The handler for the log messages. | ||
/// </summary> | ||
public Action<string> Log { get; set; } | ||
|
||
/// <summary> | ||
/// Create the ML context. | ||
/// </summary> | ||
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param> | ||
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param> | ||
public MLContext(int? seed = null, int conc = 0) | ||
{ | ||
_env = new LocalEnvironment(seed, conc); | ||
_env.AddListener(ProcessMessage); | ||
|
||
BinaryClassification = new BinaryClassificationContext(_env); | ||
MulticlassClassification = new MulticlassClassificationContext(_env); | ||
Regression = new RegressionContext(_env); | ||
Clustering = new ClusteringContext(_env); | ||
Ranking = new RankingContext(_env); | ||
Transforms = new TransformsCatalog(_env); | ||
Model = new ModelOperationsCatalog(_env); | ||
Data = new DataLoadSaveOperations(_env); | ||
} | ||
|
||
private void ProcessMessage(IMessageSource source, ChannelMessage message) | ||
{ | ||
if (Log == null) | ||
return; | ||
|
||
var msg = $"[Source={source.FullName}, Kind={message.Kind}] {message.Message}"; | ||
// Log may have been reset from another thread. | ||
// We don't care which logger we send the message to, just making sure we don't crash. | ||
Log?.Invoke(msg); | ||
} | ||
|
||
int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor; | ||
bool IHostEnvironment.IsCancelled => _env.IsCancelled; | ||
ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog; | ||
string IExceptionContext.ContextDescription => _env.ContextDescription; | ||
IFileHandle IHostEnvironment.CreateOutputFile(string path) => _env.CreateOutputFile(path); | ||
IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix); | ||
IFileHandle IHostEnvironment.OpenInputFile(string path) => _env.OpenInputFile(path); | ||
TException IExceptionContext.Process<TException>(TException ex) => _env.Process(ex); | ||
IHost IHostEnvironment.Register(string name, int? seed, bool? verbose, int? conc) => _env.Register(name, seed, verbose, conc); | ||
IChannel IChannelProvider.Start(string name) => _env.Start(name); | ||
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name); | ||
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Runtime.Data; | ||
using System.IO; | ||
|
||
namespace Microsoft.ML.Runtime | ||
{ | ||
/// <summary> | ||
/// An object serving as a 'catalog' of available model operations. | ||
/// </summary> | ||
public sealed class ModelOperationsCatalog | ||
{ | ||
internal IHostEnvironment Environment { get; } | ||
|
||
internal ModelOperationsCatalog(IHostEnvironment env) | ||
{ | ||
Contracts.AssertValue(env); | ||
Environment = env; | ||
} | ||
|
||
/// <summary> | ||
/// Save the model to the stream. | ||
/// </summary> | ||
/// <param name="model">The trained model to be saved.</param> | ||
/// <param name="stream">A writeable, seekable stream to save to.</param> | ||
public void Save(ITransformer model, Stream stream) => model.SaveTo(Environment, stream); | ||
|
||
/// <summary> | ||
/// Load the model from the stream. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If we are saving and loading an ITransformer, should we call the class TransformerOperationsCatalog? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the name should reflect the user-facing concept, rather than the implementation fact that 'trained model is a transformer'. In reply to: 224950106 [](ancestors = 224950106) |
||
/// </summary> | ||
/// <param name="stream">A readable, seekable stream to load from.</param> | ||
/// <returns>The loaded model.</returns> | ||
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(Environment, stream); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❤️ #Closed