Skip to content

Commit 634eaa3

Browse files
committed
Replace reflection with Contracts
1 parent 0da937d commit 634eaa3

File tree

1 file changed

+5
-23
lines changed

1 file changed

+5
-23
lines changed

src/Microsoft.ML.Data/Utilities/ColumnCursor.cs

+5-23
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@ public static class ColumnCursorExtensions
2323
/// <param name="columnName">The name of the column to extract.</param>
2424
public static IEnumerable<T> GetColumn<T>(this IDataView data, string columnName)
2525
{
26-
var env = RetrieveHost(data);
27-
env.CheckValue(data, nameof(data));
28-
env.CheckNonEmpty(columnName, nameof(columnName));
26+
Contracts.CheckValue(data, nameof(data));
27+
Contracts.CheckNonEmpty(columnName, nameof(columnName));
2928

3029
if (!data.Schema.TryGetColumnIndex(columnName, out int col))
31-
throw env.ExceptSchemaMismatch(nameof(columnName), "input", columnName);
30+
throw Contracts.ExceptSchemaMismatch(nameof(columnName), "input", columnName);
3231

3332
// There are two decisions that we make here:
3433
// - Is the T an array type?
@@ -56,7 +55,7 @@ public static IEnumerable<T> GetColumn<T>(this IDataView data, string columnName
5655
{
5756
// Output is an array type.
5857
if (!(colType is VectorType colVectorType))
59-
throw env.ExceptSchemaMismatch(nameof(columnName), "input", columnName, "vector", "scalar");
58+
throw Contracts.ExceptSchemaMismatch(nameof(columnName), "input", columnName, "vector", "scalar");
6059
var elementType = typeof(T).GetElementType();
6160
if (elementType == colVectorType.ItemType.RawType)
6261
{
@@ -75,24 +74,7 @@ public static IEnumerable<T> GetColumn<T>(this IDataView data, string columnName
7574
}
7675
// Fall through to the failure.
7776
}
78-
throw env.Except($"Could not map a data view column '{columnName}' of type {colType} to {typeof(T)}.");
79-
}
80-
81-
/// <summary>
82-
/// Return a <see cref="IHost"/> assigned in a given <see cref="IDataView"/> implementation.
83-
/// </summary>
84-
/// <param name="data">an <see cref="IDataView"/> implementation.</param>
85-
private static IHost RetrieveHost(IDataView data)
86-
{
87-
// Search for the first (if there are multiples) field typed to IHost.
88-
var fields = data.GetType().GetFields(System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
89-
var fieldInfo = fields.Where((field, index) => field.FieldType == typeof(IHost)).FirstOrDefault();
90-
91-
// Check if a IHost really gets retrieved.
92-
string errorMessage = nameof(data) + " should contains a field of " + typeof(IHost);
93-
Contracts.CheckValue(fieldInfo, nameof(data), errorMessage);
94-
95-
return (IHost)fieldInfo.GetValue(data);
77+
throw Contracts.Except($"Could not map a data view column '{columnName}' of type {colType} to {typeof(T)}.");
9678
}
9779

9880
private static IEnumerable<T> GetColumnDirect<T>(IDataView data, int col)

0 commit comments

Comments
 (0)