Skip to content

Commit a1b66ac

Browse files
authored
Hide much of Microsoft.ML.Model namespace. (#2649)
* Internalize ModelLoadContext * Internalize VersionInfo * Internalize members of ModelSaveContext while keeping the class itself public * Internalize Repository and subclasses
1 parent 30df61a commit a1b66ac

File tree

17 files changed

+88
-49
lines changed

17 files changed

+88
-49
lines changed

src/Microsoft.ML.Core/Data/ModelHeader.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,8 @@ public static string GetLoaderSigAlt(ref ModelHeader header)
621621
/// This is used to simplify version checking boiler-plate code. It is an optional
622622
/// utility type.
623623
/// </summary>
624-
public readonly struct VersionInfo
624+
[BestFriend]
625+
internal readonly struct VersionInfo
625626
{
626627
public readonly ulong ModelSignature;
627628
public readonly uint VerWrittenCur;

src/Microsoft.ML.Core/Data/ModelLoadContext.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ namespace Microsoft.ML.Model
1515
/// amount of boiler plate code. It can also be used when loading from a single stream,
1616
/// for implementors of ICanSaveInBinaryFormat.
1717
/// </summary>
18-
public sealed partial class ModelLoadContext : IDisposable
18+
[BestFriend]
19+
internal sealed partial class ModelLoadContext : IDisposable
1920
{
2021
/// <summary>
2122
/// When in repository mode, this is the repository we're reading from. It is null when

src/Microsoft.ML.Core/Data/ModelLoading.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace Microsoft.ML.Model
1616
[BestFriend]
1717
internal delegate void SignatureLoadModel(ModelLoadContext ctx);
1818

19-
public sealed partial class ModelLoadContext : IDisposable
19+
internal sealed partial class ModelLoadContext : IDisposable
2020
{
2121
public const string ModelStreamName = "Model.key";
2222
internal const string NameBinary = "Model.bin";

src/Microsoft.ML.Core/Data/ModelSaveContext.cs

+30-15
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,21 @@ public sealed partial class ModelSaveContext : IDisposable
2121
/// When in repository mode, this is the repository we're writing to. It is null when
2222
/// in single-stream mode.
2323
/// </summary>
24-
public readonly RepositoryWriter Repository;
24+
[BestFriend]
25+
internal readonly RepositoryWriter Repository;
2526

2627
/// <summary>
2728
/// When in repository mode, this is the directory we're reading from. Null means the root
2829
/// of the repository. It is always null in single-stream mode.
2930
/// </summary>
30-
public readonly string Directory;
31+
[BestFriend]
32+
internal readonly string Directory;
3133

3234
/// <summary>
3335
/// The main stream writer.
3436
/// </summary>
35-
public readonly BinaryWriter Writer;
37+
[BestFriend]
38+
internal readonly BinaryWriter Writer;
3639

3740
/// <summary>
3841
/// The strings that will be saved in the main stream's string table.
@@ -49,7 +52,8 @@ public sealed partial class ModelSaveContext : IDisposable
4952
/// <summary>
5053
/// The min file position of the main stream.
5154
/// </summary>
52-
public readonly long FpMin;
55+
[BestFriend]
56+
internal readonly long FpMin;
5357

5458
/// <summary>
5559
/// The wrapped entry.
@@ -69,7 +73,8 @@ public sealed partial class ModelSaveContext : IDisposable
6973
/// <summary>
7074
/// Returns whether this context is in repository mode (true) or single-stream mode (false).
7175
/// </summary>
72-
public bool InRepository { get { return Repository != null; } }
76+
[BestFriend]
77+
internal bool InRepository => Repository != null;
7378

7479
/// <summary>
7580
/// Create a <see cref="ModelSaveContext"/> supporting saving to a repository, for implementors of <see cref="ICanSaveModel"/>.
@@ -125,7 +130,8 @@ internal ModelSaveContext(BinaryWriter writer, IExceptionContext ectx = null)
125130
ModelHeader.BeginWrite(Writer, out FpMin, out Header);
126131
}
127132

128-
public void CheckAtModel()
133+
[BestFriend]
134+
internal void CheckAtModel()
129135
{
130136
_ectx.Check(Writer.BaseStream.Position == FpMin + Header.FpModel);
131137
}
@@ -135,13 +141,15 @@ public void CheckAtModel()
135141
/// <see cref="Done"/> is called.
136142
/// </summary>
137143
/// <param name="ver"></param>
138-
public void SetVersionInfo(VersionInfo ver)
144+
[BestFriend]
145+
internal void SetVersionInfo(VersionInfo ver)
139146
{
140147
ModelHeader.SetVersionInfo(ref Header, ver);
141148
_loaderAssemblyName = ver.LoaderAssemblyName;
142149
}
143150

144-
public void SaveTextStream(string name, Action<TextWriter> action)
151+
[BestFriend]
152+
internal void SaveTextStream(string name, Action<TextWriter> action)
145153
{
146154
_ectx.Check(InRepository, "Can't save a text stream when writing to a single stream");
147155
_ectx.CheckNonEmpty(name, nameof(name));
@@ -156,7 +164,8 @@ public void SaveTextStream(string name, Action<TextWriter> action)
156164
}
157165
}
158166

159-
public void SaveBinaryStream(string name, Action<BinaryWriter> action)
167+
[BestFriend]
168+
internal void SaveBinaryStream(string name, Action<BinaryWriter> action)
160169
{
161170
_ectx.Check(InRepository, "Can't save a text stream when writing to a single stream");
162171
_ectx.CheckNonEmpty(name, nameof(name));
@@ -175,7 +184,8 @@ public void SaveBinaryStream(string name, Action<BinaryWriter> action)
175184
/// Puts a string into the context pool, and writes the integer code of the string ID
176185
/// to the write stream. If str is null, this writes -1 and doesn't add it to the pool.
177186
/// </summary>
178-
public void SaveStringOrNull(string str)
187+
[BestFriend]
188+
internal void SaveStringOrNull(string str)
179189
{
180190
if (str == null)
181191
Writer.Write(-1);
@@ -187,13 +197,15 @@ public void SaveStringOrNull(string str)
187197
/// Puts a string into the context pool, and writes the integer code of the string ID
188198
/// to the write stream. Checks that str is not null.
189199
/// </summary>
190-
public void SaveString(string str)
200+
[BestFriend]
201+
internal void SaveString(string str)
191202
{
192203
_ectx.CheckValue(str, nameof(str));
193204
Writer.Write(Strings.Add(str).Id);
194205
}
195206

196-
public void SaveString(ReadOnlyMemory<char> str)
207+
[BestFriend]
208+
internal void SaveString(ReadOnlyMemory<char> str)
197209
{
198210
Writer.Write(Strings.Add(str).Id);
199211
}
@@ -202,13 +214,15 @@ public void SaveString(ReadOnlyMemory<char> str)
202214
/// Puts a string into the context pool, and writes the integer code of the string ID
203215
/// to the write stream.
204216
/// </summary>
205-
public void SaveNonEmptyString(string str)
217+
[BestFriend]
218+
internal void SaveNonEmptyString(string str)
206219
{
207220
_ectx.CheckParam(!string.IsNullOrEmpty(str), nameof(str));
208221
Writer.Write(Strings.Add(str).Id);
209222
}
210223

211-
public void SaveNonEmptyString(ReadOnlyMemory<Char> str)
224+
[BestFriend]
225+
internal void SaveNonEmptyString(ReadOnlyMemory<char> str)
212226
{
213227
Writer.Write(Strings.Add(str).Id);
214228
}
@@ -217,7 +231,8 @@ public void SaveNonEmptyString(ReadOnlyMemory<Char> str)
217231
/// Commit the save operation. This completes writing of the main stream. When in repository
218232
/// mode, it disposes <see cref="Writer"/> (but not <see cref="Repository"/>).
219233
/// </summary>
220-
public void Done()
234+
[BestFriend]
235+
internal void Done()
221236
{
222237
_ectx.Check(Header.ModelSignature != 0, "ModelSignature not specified!");
223238
ModelHeader.EndWrite(Writer, FpMin, ref Header, Strings, _loaderAssemblyName);

src/Microsoft.ML.Core/Data/ModelSaving.cs

+11-6
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ namespace Microsoft.ML.Model
1111
public sealed partial class ModelSaveContext : IDisposable
1212
{
1313
/// <summary>
14-
/// Save a sub model to the given sub directory. This requires InRepository to be true.
14+
/// Save a sub model to the given sub directory. This requires <see cref="InRepository"/> to be <see langword="true"/>.
1515
/// </summary>
16-
public void SaveModel<T>(T value, string name)
16+
[BestFriend]
17+
internal void SaveModel<T>(T value, string name)
1718
where T : class
1819
{
1920
_ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream");
@@ -23,7 +24,8 @@ public void SaveModel<T>(T value, string name)
2324
/// <summary>
2425
/// Save the object by calling TrySaveModel then falling back to .net serialization.
2526
/// </summary>
26-
public static void SaveModel<T>(RepositoryWriter rep, T value, string path)
27+
[BestFriend]
28+
internal static void SaveModel<T>(RepositoryWriter rep, T value, string path)
2729
where T : class
2830
{
2931
if (value == null)
@@ -55,7 +57,8 @@ public static void SaveModel<T>(RepositoryWriter rep, T value, string path)
5557
/// <summary>
5658
/// Save to a single-stream by invoking the given action.
5759
/// </summary>
58-
public static void Save(BinaryWriter writer, Action<ModelSaveContext> fn)
60+
[BestFriend]
61+
internal static void Save(BinaryWriter writer, Action<ModelSaveContext> fn)
5962
{
6063
Contracts.CheckValue(writer, nameof(writer));
6164
Contracts.CheckValue(fn, nameof(fn));
@@ -68,9 +71,11 @@ public static void Save(BinaryWriter writer, Action<ModelSaveContext> fn)
6871
}
6972

7073
/// <summary>
71-
/// Save to the given sub directory by invoking the given action. This requires InRepository to be true.
74+
/// Save to the given sub directory by invoking the given action. This requires
75+
/// <see cref="InRepository"/> to be <see langword="true"/>.
7276
/// </summary>
73-
public void SaveSubModel(string dir, Action<ModelSaveContext> fn)
77+
[BestFriend]
78+
internal void SaveSubModel(string dir, Action<ModelSaveContext> fn)
7479
{
7580
_ectx.Check(InRepository, "Can't save a sub-model when writing to a single stream");
7681
_ectx.CheckNonEmpty(dir, nameof(dir));

src/Microsoft.ML.Core/Data/Repository.cs

+7-4
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ internal interface ICanSaveInBinaryFormat
3232
}
3333

3434
/// <summary>
35-
/// Abstraction around a ZipArchive or other hierarchical storage.
35+
/// Abstraction around a <see cref="ZipArchive"/> or other hierarchical storage.
3636
/// </summary>
37-
public abstract class Repository : IDisposable
37+
[BestFriend]
38+
internal abstract class Repository : IDisposable
3839
{
3940
public sealed class Entry : IDisposable
4041
{
@@ -289,7 +290,8 @@ protected Entry AddEntry(string pathEnt, Stream stream)
289290
}
290291
}
291292

292-
public sealed class RepositoryWriter : Repository
293+
[BestFriend]
294+
internal sealed class RepositoryWriter : Repository
293295
{
294296
private const string DirTrainingInfo = "TrainingInfo";
295297

@@ -429,7 +431,8 @@ public void Commit()
429431
}
430432
}
431433

432-
public sealed class RepositoryReader : Repository
434+
[BestFriend]
435+
internal sealed class RepositoryReader : Repository
433436
{
434437
private ZipArchive _archive;
435438

src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ private static VersionInfo GetVersionInfo()
8787
}
8888

8989
// Factory for SignatureLoadModel.
90-
public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
90+
private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
9191
{
9292
Contracts.CheckValue(env, nameof(env));
9393
_host = env.Register(nameof(TransformWrapper));

src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ public static class TransformerChain
239239
{
240240
public const string LoaderSignature = "TransformerChain";
241241

242-
public static TransformerChain<ITransformer> Create(IHostEnvironment env, ModelLoadContext ctx)
242+
private static TransformerChain<ITransformer> Create(IHostEnvironment env, ModelLoadContext ctx)
243243
=> new TransformerChain<ITransformer>(env, ctx);
244244

245245
/// <summary>

src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs

+7-4
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,23 @@ namespace Microsoft.ML.Internal.Internallearn
1414
/// </summary>
1515
public abstract class ModelParametersBase<TOutput> : ICanSaveModel, IPredictorProducing<TOutput>
1616
{
17-
public const string NormalizerWarningFormat =
17+
private const string NormalizerWarningFormat =
1818
"Ignoring integrated normalizer while loading a predictor of type {0}.{1}" +
1919
" Please refer to https://aka.ms/MLNetIssue for assistance with converting legacy models.";
2020

21-
protected readonly IHost Host;
21+
[BestFriend]
22+
private protected readonly IHost Host;
2223

23-
protected ModelParametersBase(IHostEnvironment env, string name)
24+
[BestFriend]
25+
private protected ModelParametersBase(IHostEnvironment env, string name)
2426
{
2527
Contracts.CheckValue(env, nameof(env));
2628
env.CheckNonWhiteSpace(name, nameof(name));
2729
Host = env.Register(name);
2830
}
2931

30-
protected ModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
32+
[BestFriend]
33+
private protected ModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
3134
{
3235
Contracts.CheckValue(env, nameof(env));
3336
env.CheckNonWhiteSpace(name, nameof(name));

src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -998,7 +998,7 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
998998
// Multi-class evaluator adds four per-instance columns: "Assigned", "Top scores", "Top classes" and "Log-loss".
999999
private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
10001000
{
1001-
// If the label column is a key without text key values, convert it to I8, just for saving the per-instance
1001+
// If the label column is a key without text key values, convert it to double, just for saving the per-instance
10021002
// text file, since if there are different key counts the columns cannot be appended.
10031003
string labelName = schema.Label.Value.Name;
10041004
if (!perInst.Schema.TryGetColumnIndex(labelName, out int labelColIndex))

src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,15 @@ public abstract class CalibratorTransformer<TICalibrator> : RowToRowTransformerB
153153
private TICalibrator _calibrator;
154154
private readonly string _loaderSignature;
155155

156-
internal CalibratorTransformer(IHostEnvironment env, TICalibrator calibrator, string loaderSignature)
156+
private protected CalibratorTransformer(IHostEnvironment env, TICalibrator calibrator, string loaderSignature)
157157
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer<TICalibrator>)))
158158
{
159159
_loaderSignature = loaderSignature;
160160
_calibrator = calibrator;
161161
}
162162

163163
// Factory method for SignatureLoadModel.
164-
internal CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, string loaderSignature)
164+
private protected CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, string loaderSignature)
165165
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer<TICalibrator>)))
166166
{
167167
Contracts.AssertValue(ctx);
@@ -195,7 +195,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
195195

196196
private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper<TICalibrator>(this, _calibrator, schema);
197197

198-
protected VersionInfo GetVersionInfo()
198+
private protected VersionInfo GetVersionInfo()
199199
{
200200
return new VersionInfo(
201201
modelSignature: "CALTRANS",

src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,8 @@ public void Add(uint hash, T key)
321321
/// Simple utility class for saving a <see cref="VBuffer{T}"/> of ReadOnlyMemory
322322
/// as a model, both in a binary and more easily human readable form.
323323
/// </summary>
324-
public static class TextModelHelper
324+
[BestFriend]
325+
internal static class TextModelHelper
325326
{
326327
private const string LoaderSignature = "TextSpanBuffer";
327328

src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ protected OneToOneTransformerBase(IHost host, params (string outputColumnName, s
3232
ColumnPairs = columns;
3333
}
3434

35-
protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx) : base(host)
35+
[BestFriend]
36+
private protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx) : base(host)
3637
{
3738
// *** Binary format ***
3839
// int: number of added columns

src/Microsoft.ML.Data/Transforms/ValueMapping.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ private static bool CheckModelVersion(ModelLoadContext ctx, VersionInfo versionI
666666
}
667667
}
668668

669-
protected static ValueMappingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
669+
private protected static ValueMappingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
670670
{
671671
Contracts.CheckValue(env, nameof(env));
672672
env.CheckValue(ctx, nameof(ctx));

src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs

+5-3
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ public LinearModelParameters(IHostEnvironment env, string name, in VBuffer<float
131131
_weightsDenseLock = new object();
132132
}
133133

134-
protected LinearModelParameters(IHostEnvironment env, string name, ModelLoadContext ctx)
134+
private protected LinearModelParameters(IHostEnvironment env, string name, ModelLoadContext ctx)
135135
: base(env, name, ctx)
136136
{
137137
// *** Binary format ***
@@ -547,12 +547,14 @@ private protected override void SaveAsIni(TextWriter writer, RoleMappedSchema sc
547547

548548
public abstract class RegressionModelParameters : LinearModelParameters
549549
{
550-
public RegressionModelParameters(IHostEnvironment env, string name, in VBuffer<float> weights, float bias)
550+
[BestFriend]
551+
private protected RegressionModelParameters(IHostEnvironment env, string name, in VBuffer<float> weights, float bias)
551552
: base(env, name, in weights, bias)
552553
{
553554
}
554555

555-
protected RegressionModelParameters(IHostEnvironment env, string name, ModelLoadContext ctx)
556+
[BestFriend]
557+
private protected RegressionModelParameters(IHostEnvironment env, string name, ModelLoadContext ctx)
556558
: base(env, name, ctx)
557559
{
558560
}

test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs

-3
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,6 @@ public async Task ContractsCheck()
3939
VerifyCS.Diagnostic(ContractsCheckAnalyzer.SimpleMessageDiagnostic.Rule).WithLocation(basis + 32, 35).WithArguments("Check", "\"Less fine: \" + env.GetType().Name"),
4040
VerifyCS.Diagnostic(ContractsCheckAnalyzer.NameofDiagnostic.Rule).WithLocation(basis + 34, 17).WithArguments("CheckUserArg", "name", "\"p\""),
4141
VerifyCS.Diagnostic(ContractsCheckAnalyzer.DecodeMessageWithLoadContextDiagnostic.Rule).WithLocation(basis + 39, 41).WithArguments("CheckDecode", "\"This message is suspicious\""),
42-
new DiagnosticResult("CS0117", DiagnosticSeverity.Error).WithLocation("Test1.cs", 220, 70).WithMessage("'MessageSensitivity' does not contain a definition for 'UserData'"),
43-
new DiagnosticResult("CS0117", DiagnosticSeverity.Error).WithLocation("Test1.cs", 231, 70).WithMessage("'MessageSensitivity' does not contain a definition for 'Schema'"),
44-
new DiagnosticResult("CS1061", DiagnosticSeverity.Error).WithLocation("Test1.cs", 747, 21).WithMessage("'IHostEnvironment' does not contain a definition for 'IsCancelled' and no accessible extension method 'IsCancelled' accepting a first argument of type 'IHostEnvironment' could be found (are you missing a using directive or an assembly reference?)"),
4542
};
4643

4744
var test = new VerifyCS.Test

0 commit comments

Comments
 (0)