Skip to content

Commit 8660ecc

Browse files
authored
CountFeatureSelectingEstimator no selection support (#5000)
* slotdroppingtransformer fix
1 parent 8fb8420 commit 8660ecc

File tree

3 files changed

+53
-51
lines changed

3 files changed

+53
-51
lines changed

src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,6 @@ private static readonly FuncInstanceMethodInfo1<Mapper, DataViewRow, int, Delega
460460
private readonly SlotsDroppingTransformer _parent;
461461
private readonly int[] _cols;
462462
private readonly DataViewType[] _srcTypes;
463-
private readonly DataViewType[] _rawTypes;
464463
private readonly DataViewType[] _dstTypes;
465464
private readonly SlotDropper[] _slotDropper;
466465
// Track if all the slots of the column are to be dropped.
@@ -473,7 +472,6 @@ public Mapper(SlotsDroppingTransformer parent, DataViewSchema inputSchema)
473472
_parent = parent;
474473
_cols = new int[_parent.ColumnPairs.Length];
475474
_srcTypes = new DataViewType[_parent.ColumnPairs.Length];
476-
_rawTypes = new DataViewType[_parent.ColumnPairs.Length];
477475
_dstTypes = new DataViewType[_parent.ColumnPairs.Length];
478476
_slotDropper = new SlotDropper[_parent.ColumnPairs.Length];
479477
_suppressed = new bool[_parent.ColumnPairs.Length];
@@ -486,8 +484,8 @@ public Mapper(SlotsDroppingTransformer parent, DataViewSchema inputSchema)
486484
_srcTypes[i] = inputSchema[_cols[i]].Type;
487485
VectorDataViewType srcVectorType = _srcTypes[i] as VectorDataViewType;
488486

489-
_rawTypes[i] = srcVectorType?.ItemType ?? _srcTypes[i];
490-
if (!IsValidColumnType(_rawTypes[i]))
487+
var rawType = srcVectorType?.ItemType ?? _srcTypes[i];
488+
if (!IsValidColumnType(rawType))
491489
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName);
492490

493491
int valueCount = srcVectorType?.Size ?? 1;
@@ -898,27 +896,26 @@ public void SaveAsOnnx(OnnxContext ctx)
898896
public bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
899897
{
900898
string opType;
901-
if (_srcTypes[iinfo] is VectorDataViewType)
899+
var slots = _slotDropper[iinfo].GetPreservedSlots();
900+
// vector column is not suppressed
901+
if (slots.Count() > 0)
902902
{
903903
opType = "GatherElements";
904-
IEnumerable<long> slots = _slotDropper[iinfo].GetPreservedSlots();
905904
var slotsVar = ctx.AddInitializer(slots, new long[] { 1, slots.Count() }, "PreservedSlots");
906905
var node = ctx.CreateNode(opType, new[] { srcVariableName, slotsVar }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
907906
node.AddAttribute("axis", 1);
908907
}
908+
// When the vector/scalar columnn is suppressed, we simply create an empty output vector
909909
else
910910
{
911911
string constVal;
912-
long[] dims = { 1, 1 };
913-
float[] floatVals = { 0.0f };
914-
long[] keyVals = { 0 };
915-
string[] stringVals = { "" };
916-
if (_rawTypes[iinfo] is TextDataViewType)
917-
constVal = ctx.AddInitializer(stringVals, dims);
918-
else if (_rawTypes[iinfo] is KeyDataViewType)
919-
constVal = ctx.AddInitializer(keyVals, dims);
912+
var type = _srcTypes[iinfo].GetItemType();
913+
if (type == TextDataViewType.Instance)
914+
constVal = ctx.AddInitializer(new string[] { "" }, new long[] { 1, 1 });
915+
else if (type == NumberDataViewType.Single)
916+
constVal = ctx.AddInitializer(new float[] { 0 }, new long[] { 1, 1 });
920917
else
921-
constVal = ctx.AddInitializer(floatVals, dims);
918+
constVal = ctx.AddInitializer(new double[] { 0 }, new long[] { 1, 1 });
922919

923920
opType = "Identity";
924921
ctx.CreateNode(opType, constVal, dstVariableName, ctx.GetNodeName(opType), "");

src/Microsoft.ML.Transforms/CountFeatureSelection.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace Microsoft.ML.Transforms
2828
/// | | |
2929
/// | -- | -- |
3030
/// | Does this estimator need to look at the data to train its parameters? | Yes |
31-
/// | Input column data type | Vector or scalar of numeric, [text](xref:Microsoft.ML.Data.TextDataViewType) or [key](xref:Microsoft.ML.Data.KeyDataViewType) data types|
31+
/// | Input column data type | Vector or scalar of <xref:System.Single>, <xref:System.Double> or [text](xref:Microsoft.ML.Data.TextDataViewType) data types|
3232
/// | Output column data type | Same as the input column|
3333
/// | Exportable to ONNX | Yes |
3434
///

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,57 +1648,62 @@ public void UseKeyDataViewTypeAsUInt32InOnnxInput()
16481648
Done();
16491649
}
16501650

1651-
[Fact]
1652-
public void FeatureSelectionOnnxTest()
1651+
[Theory]
1652+
[InlineData(DataKind.String)]
1653+
[InlineData(DataKind.Single)]
1654+
[InlineData(DataKind.Double)]
1655+
public void FeatureSelectionOnnxTest(DataKind dataKind)
16531656
{
16541657
var mlContext = new MLContext(seed: 1);
16551658

16561659
string dataPath = GetDataPath("breast-cancer.txt");
16571660

1658-
var dataView = ML.Data.LoadFromTextFile(dataPath, new[] {
1659-
new TextLoader.Column("ScalarFloat", DataKind.Single, 6),
1660-
new TextLoader.Column("VectorFloat", DataKind.Single, 1, 4),
1661-
new TextLoader.Column("VectorDouble", DataKind.Double, 4, 8),
1661+
var dataView = mlContext.Data.LoadFromTextFile(dataPath, new[] {
1662+
new TextLoader.Column("Scalar", dataKind, 6),
1663+
new TextLoader.Column("Vector", dataKind, 1, 6),
16621664
new TextLoader.Column("Label", DataKind.Boolean, 0)
16631665
});
16641666

1665-
var columns = new[] {
1666-
new CountFeatureSelectingEstimator.ColumnOptions("FeatureSelectDouble", "VectorDouble", count: 1),
1667-
new CountFeatureSelectingEstimator.ColumnOptions("ScalFeatureSelectMissing690", "ScalarFloat", count: 690),
1668-
new CountFeatureSelectingEstimator.ColumnOptions("ScalFeatureSelectMissing100", "ScalarFloat", count: 100),
1669-
new CountFeatureSelectingEstimator.ColumnOptions("VecFeatureSelectMissing690", "VectorDouble", count: 690),
1670-
new CountFeatureSelectingEstimator.ColumnOptions("VecFeatureSelectMissing100", "VectorDouble", count: 100)
1671-
};
1672-
var pipeline = ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("FeatureSelect", "VectorFloat", count: 1)
1673-
.Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount(columns))
1674-
.Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("FeatureSelectMIScalarFloat", "ScalarFloat"))
1675-
.Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("FeatureSelectMIVectorFloat", "VectorFloat"));
1667+
IEstimator<ITransformer>[] pipelines =
1668+
{
1669+
// one or more features selected
1670+
mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("VectorOutput", "Vector", count: 690).
1671+
Append(mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("ScalarOutput", "Scalar", count: 100)),
16761672

1677-
var model = pipeline.Fit(dataView);
1678-
var transformedData = model.Transform(dataView);
1679-
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
1673+
// no feature selected => column suppressed
1674+
mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("VectorOutput", "Vector", count: 800).
1675+
Append(mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("ScalarOutput", "Scalar", count: 800)),
16801676

1681-
var onnxFileName = "countfeatures.onnx";
1682-
var onnxModelPath = GetOutputPath(onnxFileName);
1677+
mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("VectorOutput", "Vector").
1678+
Append(mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("ScalarOutput", "Scalar"))
1679+
};
1680+
for (int i = 0; i < pipelines.Length; i++)
1681+
{
1682+
//There's currently no support for suppressed string columns, since onnx string variable initiation is not supported
1683+
if (dataKind == DataKind.String && i > 0)
1684+
break;
1685+
var model = pipelines[i].Fit(dataView);
1686+
var transformedData = model.Transform(dataView);
1687+
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
16831688

1684-
SaveOnnxModel(onnxModel, onnxModelPath, null);
1689+
var onnxFileName = "countfeatures.onnx";
1690+
var onnxModelPath = GetOutputPath(onnxFileName);
16851691

1686-
if (IsOnnxRuntimeSupported())
1687-
{
1688-
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
1689-
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
1690-
var onnxTransformer = onnxEstimator.Fit(dataView);
1691-
var onnxResult = onnxTransformer.Transform(dataView);
1692-
CompareSelectedColumns<float>("FeatureSelectMIScalarFloat", "FeatureSelectMIScalarFloat", transformedData, onnxResult);
1693-
CompareSelectedColumns<float>("FeatureSelectMIVectorFloat", "FeatureSelectMIVectorFloat", transformedData, onnxResult);
1694-
CompareSelectedColumns<float>("ScalFeatureSelectMissing690", "ScalFeatureSelectMissing690", transformedData, onnxResult);
1695-
CompareSelectedColumns<double>("VecFeatureSelectMissing690", "VecFeatureSelectMissing690", transformedData, onnxResult);
1692+
SaveOnnxModel(onnxModel, onnxModelPath, null);
1693+
1694+
if (IsOnnxRuntimeSupported())
1695+
{
1696+
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
1697+
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
1698+
var onnxTransformer = onnxEstimator.Fit(dataView);
1699+
var onnxResult = onnxTransformer.Transform(dataView);
1700+
CompareResults("VectorOutput", "VectorOutput", transformedData, onnxResult);
1701+
CompareResults("ScalarOutput", "ScalarOutput", transformedData, onnxResult);
1702+
}
16961703
}
16971704
Done();
16981705
}
16991706

1700-
1701-
17021707
[Fact]
17031708
public void SelectColumnsOnnxTest()
17041709
{

0 commit comments

Comments
 (0)