Skip to content

Commit a44386d

Browse files
authored
CountFeatureSelection transform doesn't work with text (dotnet#1365)
* Make CountFeatureSelection work with text data * Add baseline files
1 parent 76d1203 commit a44386d

File tree

4 files changed

+816
-5
lines changed

4 files changed

+816
-5
lines changed

src/Microsoft.ML.Transforms/CountFeatureSelection.cs

+4-5
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ private static CountAggregator GetOneAggregator(IRow row, ColumnType colType, in
212212
}
213213

214214
private static CountAggregator GetOneAggregator<T>(IRow row, ColumnType colType, int colSrc)
215-
where T : IEquatable<T>
216215
{
217216
return new CountAggregator<T>(colType, row.GetGetter<T>(colSrc));
218217
}
@@ -225,7 +224,6 @@ private static CountAggregator GetVecAggregator(IRow row, ColumnType colType, in
225224
}
226225

227226
private static CountAggregator GetVecAggregator<T>(IRow row, ColumnType colType, int colSrc)
228-
where T : IEquatable<T>
229227
{
230228
return new CountAggregator<T>(colType, row.GetGetter<VBuffer<T>>(colSrc));
231229
}
@@ -237,7 +235,6 @@ private abstract class CountAggregator
237235
}
238236

239237
private sealed class CountAggregator<T> : CountAggregator, IColumnAggregator<VBuffer<T>>
240-
where T : IEquatable<T>
241238
{
242239
private readonly long[] _count;
243240
private readonly Action _fillBuffer;
@@ -258,7 +255,8 @@ public CountAggregator(ColumnType type, ValueGetter<T> getter)
258255
_buffer.Values[0] = t;
259256
};
260257
_isDefault = Conversions.Instance.GetIsDefaultPredicate<T>(type);
261-
_isMissing = Conversions.Instance.GetIsNAPredicate<T>(type);
258+
if (!Conversions.Instance.TryGetIsNAPredicate<T>(type, out _isMissing))
259+
_isMissing = (ref T value) => false;
262260
}
263261

264262
public CountAggregator(ColumnType type, ValueGetter<VBuffer<T>> getter)
@@ -268,7 +266,8 @@ public CountAggregator(ColumnType type, ValueGetter<VBuffer<T>> getter)
268266
_count = new long[size];
269267
_fillBuffer = () => getter(ref _buffer);
270268
_isDefault = Conversions.Instance.GetIsDefaultPredicate<T>(type.ItemType);
271-
_isMissing = Conversions.Instance.GetIsNAPredicate<T>(type.ItemType);
269+
if (!Conversions.Instance.TryGetIsNAPredicate<T>(type.ItemType, out _isMissing))
270+
_isMissing = (ref T value) => false;
272271
}
273272

274273
public override long[] Count

0 commit comments

Comments
 (0)