Skip to content

ML Context to create them all #1252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 93 additions & 119 deletions docs/code/MlNetCookBook.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/samples/Microsoft.ML.Samples/Trainers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// the alignment of the usings with the methods is intentional so they can display on the same level in the docs site.
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Trainers;
using Microsoft.ML.StaticPipe;
using System;

// NOTE: WHEN ADDING TO THE FILE, ALWAYS APPEND TO THE END OF IT.
Expand Down
20 changes: 20 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/DataLoadSaveCatalog.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.ML.Runtime
{
/// <summary>
/// A catalog of operations to load and save data. The type itself only carries the host
/// environment; operations are attached to it as extension methods.
/// </summary>
public sealed class DataLoadSaveOperations
{
    /// <summary>
    /// Creates the catalog on top of the given environment.
    /// </summary>
    /// <param name="env">The host environment; asserted to be non-null.</param>
    internal DataLoadSaveOperations(IHostEnvironment env)
    {
        Contracts.AssertValue(env);
        Environment = env;
    }

    /// <summary>
    /// The host environment this catalog was created with.
    /// </summary>
    internal IHostEnvironment Environment { get; }
}
}
19 changes: 19 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ public Column() { }
/// <summary>
/// Creates a column backed by a single source index, by delegating to the range-based
/// constructor with a one-element <see cref="Range"/> array.
/// </summary>
/// <param name="name">The column name.</param>
/// <param name="type">The data kind of the column; may be <c>null</c>.</param>
/// <param name="index">The single source index, wrapped in a <see cref="Range"/>.</param>
public Column(string name, DataKind? type, int index)
: this(name, type, new[] { new Range(index) }) { }

/// <summary>
/// Creates a column backed by one span of source indices, by delegating to the range-based
/// constructor with a single <see cref="Range"/> built from the two bounds.
/// </summary>
/// <param name="name">The column name.</param>
/// <param name="type">The data kind of the column; may be <c>null</c>.</param>
/// <param name="minIndex">The first bound passed to <see cref="Range"/>.</param>
/// <param name="maxIndex">The second bound passed to <see cref="Range"/>.</param>
public Column(string name, DataKind? type, int minIndex, int maxIndex)
: this(name, type, new[] { new Range(minIndex, maxIndex) })
{
}

public Column(string name, DataKind? type, Range[] source, KeyRange keyRange = null)
{
Contracts.CheckValue(name, nameof(name));
Expand Down Expand Up @@ -1003,6 +1008,18 @@ private bool HasHeader
private readonly IHost _host;
private const string RegistrationName = "TextLoader";

/// <summary>
/// Creates a text loader from a column list and an optional settings callback. The two are
/// combined into an <see cref="Arguments"/> object by <see cref="MakeArgs"/> and passed on to
/// the <see cref="Arguments"/>-based constructor.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="columns">The columns of the schema.</param>
/// <param name="advancedSettings">Optional delegate that adjusts the arguments after the columns are set; may be <c>null</c>.</param>
/// <param name="dataSample">Optional data sample, forwarded unchanged.</param>
public TextLoader(IHostEnvironment env, Column[] columns, Action<Arguments> advancedSettings, IMultiStreamSource dataSample = null)
: this(env, MakeArgs(columns, advancedSettings), dataSample)
{
}

/// <summary>
/// Builds the loader arguments: starts from default <see cref="Arguments"/>, assigns the
/// columns, then lets the caller-supplied delegate (if any) adjust the result.
/// </summary>
/// <param name="columns">The columns to store in <see cref="Arguments.Column"/>.</param>
/// <param name="advancedSettings">Optional delegate applied to the arguments; skipped when <c>null</c>.</param>
/// <returns>The populated arguments object.</returns>
private static Arguments MakeArgs(Column[] columns, Action<Arguments> advancedSettings)
{
    var args = new Arguments();
    args.Column = columns;

    // The delegate runs after the columns are assigned, so it can see (and override) them.
    if (advancedSettings != null)
    {
        advancedSettings(args);
    }

    return args;
}

public TextLoader(IHostEnvironment env, Arguments args, IMultiStreamSource dataSample = null)
{
Contracts.CheckValue(env, nameof(env));
Expand Down Expand Up @@ -1320,6 +1337,8 @@ public void Save(ModelSaveContext ctx)

/// <summary>
/// Binds this loader to the given source by wrapping both in a new <see cref="BoundLoader"/>.
/// </summary>
/// <param name="source">The multi-stream source to read from.</param>
/// <returns>The resulting data view.</returns>
public IDataView Read(IMultiStreamSource source) => new BoundLoader(this, source);

/// <summary>
/// Convenience overload: wraps <paramref name="path"/> in a <see cref="MultiFileSource"/> and
/// forwards to <see cref="Read(IMultiStreamSource)"/>.
/// </summary>
/// <param name="path">The path handed to <see cref="MultiFileSource"/>.</param>
/// <returns>The resulting data view.</returns>
public IDataView Read(string path) => Read(new MultiFileSource(path));
Copy link
Member

@eerhardt eerhardt Oct 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️ #Closed


private sealed class BoundLoader : IDataLoader
{
private readonly TextLoader _reader;
Expand Down
85 changes: 85 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;

namespace Microsoft.ML
{
/// <summary>
/// Extension methods on <see cref="DataLoadSaveOperations"/> for creating text readers,
/// reading data views from text files, and saving data views as text.
/// </summary>
public static class TextLoaderSaverCatalog
{
    /// <summary>
    /// Create a text reader.
    /// </summary>
    /// <param name="catalog">The catalog.</param>
    /// <param name="args">The arguments to text reader, describing the data schema.</param>
    /// <param name="dataSample">The optional location of a data sample.</param>
    /// <returns>The newly created <see cref="TextLoader"/>.</returns>
    public static TextLoader TextReader(this DataLoadSaveOperations catalog,
        TextLoader.Arguments args, IMultiStreamSource dataSample = null)
        => new TextLoader(catalog.GetEnvironment(), args, dataSample);

    /// <summary>
    /// Create a text reader.
    /// </summary>
    /// <param name="catalog">The catalog.</param>
    /// <param name="columns">The columns of the schema.</param>
    /// <param name="advancedSettings">The delegate to set additional settings.</param>
    /// <param name="dataSample">The optional location of a data sample.</param>
    /// <returns>The newly created <see cref="TextLoader"/>.</returns>
    public static TextLoader TextReader(this DataLoadSaveOperations catalog,
        TextLoader.Column[] columns, Action<TextLoader.Arguments> advancedSettings = null, IMultiStreamSource dataSample = null)
        => new TextLoader(catalog.GetEnvironment(), columns, advancedSettings, dataSample);

    /// <summary>
    /// Read a data view from a text file using <see cref="TextLoader"/>.
    /// </summary>
    /// <param name="catalog">The catalog.</param>
    /// <param name="columns">The columns of the schema.</param>
    /// <param name="path">The path to the file. Must be non-empty.</param>
    /// <param name="advancedSettings">The delegate to set additional settings.</param>
    /// <returns>The data view.</returns>
    public static IDataView ReadFromTextFile(this DataLoadSaveOperations catalog,
        TextLoader.Column[] columns, string path, Action<TextLoader.Arguments> advancedSettings = null)
    {
        // Validate the receiver explicitly, the same way SaveAsText does, so a null catalog
        // is reported against the 'catalog' parameter rather than surfacing further down.
        Contracts.CheckValue(catalog, nameof(catalog));
        Contracts.CheckNonEmpty(path, nameof(path));

        var env = catalog.GetEnvironment();

        // REVIEW: it is almost always a mistake to have a 'trainable' text loader here.
        // Therefore, we are going to disallow data sample.
        var reader = new TextLoader(env, columns, advancedSettings, dataSample: null);
        return reader.Read(new MultiFileSource(path));
    }

    /// <summary>
    /// Save the data view as text.
    /// </summary>
    /// <param name="catalog">The catalog.</param>
    /// <param name="data">The data view to save.</param>
    /// <param name="stream">The stream to write to.</param>
    /// <param name="separator">The column separator.</param>
    /// <param name="headerRow">Whether to write the header row.</param>
    /// <param name="schema">Whether to write the header comment with the schema.</param>
    /// <param name="keepHidden">Whether to keep hidden columns in the dataset.</param>
    public static void SaveAsText(this DataLoadSaveOperations catalog, IDataView data, Stream stream,
        char separator = '\t', bool headerRow = true, bool schema = true, bool keepHidden = false)
    {
        Contracts.CheckValue(catalog, nameof(catalog));
        Contracts.CheckValue(data, nameof(data));
        Contracts.CheckValue(stream, nameof(stream));

        var env = catalog.GetEnvironment();
        var saver = new TextSaver(env, new TextSaver.Arguments { Separator = separator.ToString(), OutputHeader = headerRow, OutputSchema = schema });

        // The channel is only needed for the duration of the save; dispose it right after.
        using (var ch = env.Start("Saving data"))
        {
            DataSaverUtils.SaveDataView(ch, saver, data, stream, keepHidden);
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ public static RegressionEvaluator.Result Evaluate<T>(
/// <param name="score">The index delegate for predicted score column.</param>
/// <returns>The evaluation metrics.</returns>
public static RankerEvaluator.Result Evaluate<T, TVal>(
this RankerContext ctx,
this RankingContext ctx,
DataView<T> data,
Func<T, Scalar<float>> label,
Func<T, Key<uint, TVal>> groupId,
Expand Down
108 changes: 108 additions & 0 deletions src/Microsoft.ML.Data/MLContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using System;

namespace Microsoft.ML
{
/// <summary>
/// The <see cref="MLContext"/> is the starting point for all ML.NET operations. It is instantiated by the user
/// and provides mechanisms for logging, as well as entry points for training, prediction, model operations, etc.
/// </summary>
public sealed class MLContext : IHostEnvironment
{
// REVIEW: consider making LocalEnvironment and MLContext the same class instead of encapsulation.
Copy link
Member

@eerhardt eerhardt Oct 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't fully understand our plan here with LocalEnvironment and ConsoleEnvironment. How could I ever use a ConsoleEnvironment with this new MLContext class?

And if the answer is "you can't - you'd use ConsoleEnvironment directly", I think that is a pretty poor answer because then my whole API paradigm changes. I can no longer use BinaryClassificationContext, Transforms, etc. from the MLContext. #Closed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about the answer: "you can't use ConsoleEnvironment because it is internal and only used by MAML"?


In reply to: 225986842 [](ancestors = 225986842)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be fine with me. For now, I would think it's completely fine to have our custom console printing internal to the commandline tool


In reply to: 225989180 [](ancestors = 225989180,225986842)

Copy link
Member

@eerhardt eerhardt Oct 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#1284 #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.


In reply to: 226052093 [](ancestors = 226052093)

// The wrapped environment. It is created in the constructor, handed to every catalog,
// and is the target of all explicit interface members of this class.
private readonly LocalEnvironment _env;

/// <summary>
/// Trainers and tasks specific to binary classification problems.
/// </summary>
public BinaryClassificationContext BinaryClassification { get; }
/// <summary>
/// Trainers and tasks specific to multiclass classification problems.
/// </summary>
public MulticlassClassificationContext MulticlassClassification { get; }
/// <summary>
/// Trainers and tasks specific to regression problems.
/// </summary>
public RegressionContext Regression { get; }
/// <summary>
/// Trainers and tasks specific to clustering problems.
/// </summary>
public ClusteringContext Clustering { get; }
/// <summary>
/// Trainers and tasks specific to ranking problems.
/// </summary>
public RankingContext Ranking { get; }

/// <summary>
/// Data processing operations.
/// </summary>
public TransformsCatalog Transforms { get; }

/// <summary>
/// Operations with trained models.
/// </summary>
public ModelOperationsCatalog Model { get; }

/// <summary>
/// Data loading and saving.
/// </summary>
public DataLoadSaveOperations Data { get; }

// REVIEW: I think it's valuable to have the simplest possible interface for logging interception here,
// and expand if and when necessary. Exposing classes like ChannelMessage, MessageSensitivity and so on
// looks premature at this point.
/// <summary>
/// The handler for the log messages. Each message is delivered as a single string that
/// already contains the source name, the message kind and the message text.
/// When this is <c>null</c>, messages are dropped.
/// </summary>
public Action<string> Log { get; set; }

/// <summary>
/// Create the ML context: builds the underlying <see cref="LocalEnvironment"/>, subscribes
/// to its messages, and creates every task context and catalog on top of that environment.
/// </summary>
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
public MLContext(int? seed = null, int conc = 0)
{
_env = new LocalEnvironment(seed, conc);
// Route environment messages to ProcessMessage, which forwards them to Log when it is set.
_env.AddListener(ProcessMessage);

// All contexts and catalogs share the same wrapped environment.
// NOTE(review): if any of these constructors interact with the environment (for example by
// registering a host), the creation order may be observable - confirm before reordering.
BinaryClassification = new BinaryClassificationContext(_env);
MulticlassClassification = new MulticlassClassificationContext(_env);
Regression = new RegressionContext(_env);
Clustering = new ClusteringContext(_env);
Ranking = new RankingContext(_env);
Transforms = new TransformsCatalog(_env);
Model = new ModelOperationsCatalog(_env);
Data = new DataLoadSaveOperations(_env);
}

/// <summary>
/// Listener attached to the wrapped environment: formats an incoming channel message and
/// hands it to <see cref="Log"/>, if a handler is currently set.
/// </summary>
/// <param name="source">The source of the message; its full name is included in the output.</param>
/// <param name="message">The message; its kind and text are included in the output.</param>
private void ProcessMessage(IMessageSource source, ChannelMessage message)
{
    // Only pay for formatting when a handler is present.
    if (Log != null)
    {
        string text = $"[Source={source.FullName}, Kind={message.Kind}] {message.Message}";

        // Log may have been reset from another thread since the check above.
        // We don't care which logger receives the message, only that we don't crash.
        Log?.Invoke(text);
    }
}

// Explicit interface implementations. Each member forwards to the wrapped _env with the
// same arguments; being explicit, they are reachable only through the interface types.
int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor;
bool IHostEnvironment.IsCancelled => _env.IsCancelled;
ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog;
string IExceptionContext.ContextDescription => _env.ContextDescription;
IFileHandle IHostEnvironment.CreateOutputFile(string path) => _env.CreateOutputFile(path);
IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix);
IFileHandle IHostEnvironment.OpenInputFile(string path) => _env.OpenInputFile(path);
TException IExceptionContext.Process<TException>(TException ex) => _env.Process(ex);
IHost IHostEnvironment.Register(string name, int? seed, bool? verbose, int? conc) => _env.Register(name, seed, verbose, conc);
IChannel IChannelProvider.Start(string name) => _env.Start(name);
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
}
}
38 changes: 38 additions & 0 deletions src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using System.IO;

namespace Microsoft.ML.Runtime
{
/// <summary>
/// An object serving as a 'catalog' of available model operations.
/// </summary>
public sealed class ModelOperationsCatalog
{
/// <summary>
/// Creates the catalog on top of the given environment.
/// </summary>
/// <param name="env">The host environment; asserted to be non-null.</param>
internal ModelOperationsCatalog(IHostEnvironment env)
{
    Contracts.AssertValue(env);
    Environment = env;
}

/// <summary>
/// The host environment passed to the save and load operations of this catalog.
/// </summary>
internal IHostEnvironment Environment { get; }

/// <summary>
/// Writes a trained model to a stream, using this catalog's environment.
/// </summary>
/// <param name="model">The trained model to be saved.</param>
/// <param name="stream">A writeable, seekable stream to save to.</param>
public void Save(ITransformer model, Stream stream)
{
    model.SaveTo(Environment, stream);
}

/// <summary>
/// Load the model from the stream.
Copy link
Member

@sfilipi sfilipi Oct 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model [](start = 21, length = 5)

If we are saving and loading an ITransformer, should we call the class TransformerOperationsCatalog? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the name should reflect the user-facing concept, rather than the implementation fact that 'trained model is a transformer'.


In reply to: 224950106 [](ancestors = 224950106)

/// </summary>
/// <param name="stream">A readable, seekable stream to load from.</param>
/// <returns>The loaded model.</returns>
public ITransformer Load(Stream stream)
{
    // Deserialization is delegated to TransformerChain, using this catalog's environment.
    return TransformerChain.LoadFrom(Environment, stream);
}
}
}
Loading