-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Internalization of TensorFlowUtils.cs and refactored TensorFlowCatalog. #2672
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
e5eef19
b170f52
ee9b7ae
6cc3f1c
1abb719
7b4c08c
fc188cd
963b9cd
a78ba89
ffd534e
d1c0dd8
e742885
7cd88ed
509bb12
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,11 @@ public sealed class TransformsCatalog | |
/// </summary> | ||
public FeatureSelectionTransforms FeatureSelection { get; } | ||
|
||
/// <summary> | ||
/// List of operations for using a TensorFlow model. | ||
/// </summary> | ||
public TensorFlowTransforms TensorFlow { get; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Please do not do this. Otherwise we will have an empty property unless someone imports the nuget, which is confusing and undesirable. Please follow instead the pattern that we see in image processing. You'll note that we do not have an empty image processing nuget. Rather they are added to this catalog. Similarly with ONNX scoring. You'll note that these both have extensions on `TransformsCatalog`. This is defensible since we can take someone directly importing a nuget as a strong signal that they want to actually use those transforms. #Resolved |
||
|
||
internal TransformsCatalog(IHostEnvironment env) | ||
{ | ||
Contracts.AssertValue(env); | ||
|
@@ -47,6 +52,7 @@ internal TransformsCatalog(IHostEnvironment env) | |
Text = new TextTransforms(this); | ||
Projection = new ProjectionTransforms(this); | ||
FeatureSelection = new FeatureSelectionTransforms(this); | ||
TensorFlow = new TensorFlowTransforms(this); | ||
} | ||
|
||
public abstract class SubCatalogBase | ||
|
@@ -109,5 +115,15 @@ internal FeatureSelectionTransforms(TransformsCatalog owner) : base(owner) | |
{ | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// The catalog of TensorFlow operations. | ||
/// </summary> | ||
public sealed class TensorFlowTransforms : SubCatalogBase | ||
{ | ||
internal TensorFlowTransforms(TransformsCatalog owner) : base(owner) | ||
{ | ||
} | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ namespace Microsoft.ML.Transforms | |
/// </item> | ||
/// </list> | ||
/// </summary> | ||
public class TensorFlowModelInfo | ||
public sealed class TensorFlowModelInfo | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Would it be possible to rename this to |
||
{ | ||
internal TFSession Session { get; } | ||
public string ModelPath { get; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think this can also be internal. #Resolved |
||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -2,9 +2,11 @@ | |||
// The .NET Foundation licenses this file to you under the MIT license. | ||||
// See the LICENSE file in the project root for more information. | ||||
|
||||
using System.Collections.Generic; | ||||
using Microsoft.Data.DataView; | ||||
using Microsoft.ML.Data; | ||||
using Microsoft.ML.Transforms; | ||||
using Microsoft.ML.Transforms.TensorFlow; | ||||
|
||||
namespace Microsoft.ML | ||||
{ | ||||
|
@@ -25,7 +27,7 @@ public static class TensorflowCatalog | |||
/// ]]> | ||||
/// </format> | ||||
/// </example> | ||||
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog, | ||||
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog, | ||||
string modelLocation, | ||||
string outputColumnName, | ||||
string inputColumnName) | ||||
|
@@ -45,7 +47,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca | |||
/// ]]> | ||||
/// </format> | ||||
/// </example> | ||||
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog, | ||||
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog, | ||||
string modelLocation, | ||||
string[] outputColumnNames, | ||||
string[] inputColumnNames) | ||||
|
@@ -58,7 +60,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca | |||
/// <param name="tensorFlowModel">The pre-loaded TensorFlow model.</param> | ||||
/// <param name="inputColumnName"> The name of the model input.</param> | ||||
/// <param name="outputColumnName">The name of the requested model output.</param> | ||||
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog, | ||||
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog, | ||||
TensorFlowModelInfo tensorFlowModel, | ||||
string outputColumnName, | ||||
string inputColumnName) | ||||
|
@@ -78,7 +80,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca | |||
/// ]]> | ||||
/// </format> | ||||
/// </example> | ||||
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog, | ||||
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog, | ||||
TensorFlowModelInfo tensorFlowModel, | ||||
string[] outputColumnNames, | ||||
string[] inputColumnNames) | ||||
|
@@ -90,7 +92,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca | |||
/// </summary> | ||||
/// <param name="catalog">The transform's catalog.</param> | ||||
/// <param name="options">The <see cref="TensorFlowEstimator.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param> | ||||
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog, | ||||
public static TensorFlowEstimator TensorFlow(this TransformsCatalog.TensorFlowTransforms catalog, | ||||
TensorFlowEstimator.Options options) | ||||
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options); | ||||
|
||||
|
@@ -100,9 +102,42 @@ public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog, | |||
/// <param name="catalog">The transform's catalog.</param> | ||||
/// <param name="options">The <see cref="TensorFlowEstimator.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param> | ||||
/// <param name="tensorFlowModel">The pre-loaded TensorFlow model.</param> | ||||
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog, | ||||
public static TensorFlowEstimator TensorFlow(this TransformsCatalog.TensorFlowTransforms catalog, | ||||
TensorFlowEstimator.Options options, | ||||
TensorFlowModelInfo tensorFlowModel) | ||||
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options, tensorFlowModel); | ||||
|
||||
/// <summary> | ||||
/// This method retrieves the information about the graph nodes of a TensorFlow model as a <see cref="DataViewSchema"/>. | ||||
/// For every node in the graph that has an output type that is compatible with the types supported by | ||||
/// <see cref="TensorFlowEstimator"/>, the output schema contains a column with the name of that node, and the | ||||
/// type of its output (including the item type and the shape, if it is known). Every column also contains metadata | ||||
/// of kind <see cref="TensorFlowUtils.TensorflowOperatorTypeKind"/>, indicating the operation type of the node, and if that node has inputs in the graph, | ||||
/// it contains metadata of kind <see cref="TensorFlowUtils.TensorflowUpstreamOperatorsKind"/>, indicating the names of the input nodes. | ||||
/// </summary> | ||||
/// <param name="catalog">The transform's catalog.</param> | ||||
/// <param name="modelLocation">Location of the TensorFlow model.</param> | ||||
public static DataViewSchema GetModelSchema(this TransformsCatalog.TensorFlowTransforms catalog, string modelLocation) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think |
||||
=> TensorFlowUtils.GetModelSchema(CatalogUtils.GetEnvironment(catalog), modelLocation); | ||||
|
||||
/// <summary> | ||||
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It | ||||
/// iterates over the columns of the <see cref="DataViewSchema"/> returned by <see cref="GetModelSchema(TransformsCatalog.TensorFlowTransforms, string)"/>, | ||||
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names. | ||||
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type. | ||||
/// </summary> | ||||
/// <param name="catalog">The transform's catalog.</param> | ||||
/// <param name="modelLocation">Location of the TensorFlow model.</param> | ||||
public static IEnumerable<(string, string, DataViewType, string[])> GetModelNodes(this TransformsCatalog.TensorFlowTransforms catalog, string modelLocation) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This doesn't need to be a part of the public API. #Resolved There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It is being used here. machinelearning/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs Line 20 in fb6ce54
If I don't expose it like this then I will have to make the internal one public. What do you suggest? In reply to: 258728605 [](ancestors = 258728605) |
||||
=> TensorFlowUtils.GetModelNodes(CatalogUtils.GetEnvironment(catalog), modelLocation); | ||||
|
||||
/// <summary> | ||||
/// Loads a TensorFlow model into memory. This is a convenience method that allows the model to be loaded once and subsequently used for querying the schema and for the creation of | ||||
/// <see cref="TensorFlowEstimator"/> using <see cref="TensorFlow(TransformsCatalog.TensorFlowTransforms, TensorFlowEstimator.Options, TensorFlowModelInfo)"/>. | ||||
/// </summary> | ||||
/// <param name="catalog">The transform's catalog.</param> | ||||
/// <param name="modelLocation">Location of the TensorFlow model.</param> | ||||
public static TensorFlowModelInfo LoadTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog, string modelLocation) | ||||
=> TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation); | ||||
} | ||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a sample that uses
modelInfo.GetInputSchema()
to find out what the name of the input node is?#Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see it's being used in a couple of places in the tests, e.g.
machinelearning/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
Line 894 in eb959c3
machinelearning/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
Line 849 in eb959c3