@@ -1648,57 +1648,62 @@ public void UseKeyDataViewTypeAsUInt32InOnnxInput()
16481648 Done ( ) ;
16491649 }
16501650
1651- [ Fact ]
1652- public void FeatureSelectionOnnxTest ( )
1651+ [ Theory ]
1652+ [ InlineData ( DataKind . String ) ]
1653+ [ InlineData ( DataKind . Single ) ]
1654+ [ InlineData ( DataKind . Double ) ]
1655+ public void FeatureSelectionOnnxTest ( DataKind dataKind )
16531656 {
16541657 var mlContext = new MLContext ( seed : 1 ) ;
16551658
16561659 string dataPath = GetDataPath ( "breast-cancer.txt" ) ;
16571660
1658- var dataView = ML . Data . LoadFromTextFile ( dataPath , new [ ] {
1659- new TextLoader . Column ( "ScalarFloat" , DataKind . Single , 6 ) ,
1660- new TextLoader . Column ( "VectorFloat" , DataKind . Single , 1 , 4 ) ,
1661- new TextLoader . Column ( "VectorDouble" , DataKind . Double , 4 , 8 ) ,
1661+ var dataView = mlContext . Data . LoadFromTextFile ( dataPath , new [ ] {
1662+ new TextLoader . Column ( "Scalar" , dataKind , 6 ) ,
1663+ new TextLoader . Column ( "Vector" , dataKind , 1 , 6 ) ,
16621664 new TextLoader . Column ( "Label" , DataKind . Boolean , 0 )
16631665 } ) ;
16641666
1665- var columns = new [ ] {
1666- new CountFeatureSelectingEstimator . ColumnOptions ( "FeatureSelectDouble" , "VectorDouble" , count : 1 ) ,
1667- new CountFeatureSelectingEstimator . ColumnOptions ( "ScalFeatureSelectMissing690" , "ScalarFloat" , count : 690 ) ,
1668- new CountFeatureSelectingEstimator . ColumnOptions ( "ScalFeatureSelectMissing100" , "ScalarFloat" , count : 100 ) ,
1669- new CountFeatureSelectingEstimator . ColumnOptions ( "VecFeatureSelectMissing690" , "VectorDouble" , count : 690 ) ,
1670- new CountFeatureSelectingEstimator . ColumnOptions ( "VecFeatureSelectMissing100" , "VectorDouble" , count : 100 )
1671- } ;
1672- var pipeline = ML . Transforms . FeatureSelection . SelectFeaturesBasedOnCount ( "FeatureSelect" , "VectorFloat" , count : 1 )
1673- . Append ( ML . Transforms . FeatureSelection . SelectFeaturesBasedOnCount ( columns ) )
1674- . Append ( ML . Transforms . FeatureSelection . SelectFeaturesBasedOnMutualInformation ( "FeatureSelectMIScalarFloat" , "ScalarFloat" ) )
1675- . Append ( ML . Transforms . FeatureSelection . SelectFeaturesBasedOnMutualInformation ( "FeatureSelectMIVectorFloat" , "VectorFloat" ) ) ;
1667+ IEstimator < ITransformer > [ ] pipelines =
1668+ {
1669+ // one or more features selected
1670+ mlContext . Transforms . FeatureSelection . SelectFeaturesBasedOnCount ( "VectorOutput" , "Vector" , count : 690 ) .
1671+ Append ( mlContext . Transforms . FeatureSelection . SelectFeaturesBasedOnCount ( "ScalarOutput" , "Scalar" , count : 100 ) ) ,
16761672
1677- var model = pipeline . Fit ( dataView ) ;
1678- var transformedData = model . Transform ( dataView ) ;
1679- var onnxModel = mlContext . Model . ConvertToOnnxProtobuf ( model , dataView ) ;
1673+ // no feature selected => column suppressed
1674+ mlContext . Transforms . FeatureSelection . SelectFeaturesBasedOnCount ( "VectorOutput" , "Vector" , count : 800 ) .
1675+ Append ( mlContext . Transforms . FeatureSelection . SelectFeaturesBasedOnCount ( "ScalarOutput" , "Scalar" , count : 800 ) ) ,
16801676
1681- var onnxFileName = "countfeatures.onnx" ;
1682- var onnxModelPath = GetOutputPath ( onnxFileName ) ;
1677+ mlContext . Transforms . FeatureSelection . SelectFeaturesBasedOnMutualInformation ( "VectorOutput" , "Vector" ) .
1678+ Append ( mlContext . Transforms . FeatureSelection . SelectFeaturesBasedOnMutualInformation ( "ScalarOutput" , "Scalar" ) )
1679+ } ;
1680+ for ( int i = 0 ; i < pipelines . Length ; i ++ )
1681+ {
1682+ //There's currently no support for suppressed string columns, since onnx string variable initiation is not supported
1683+ if ( dataKind == DataKind . String && i > 0 )
1684+ break ;
1685+ var model = pipelines [ i ] . Fit ( dataView ) ;
1686+ var transformedData = model . Transform ( dataView ) ;
1687+ var onnxModel = mlContext . Model . ConvertToOnnxProtobuf ( model , dataView ) ;
16831688
1684- SaveOnnxModel ( onnxModel , onnxModelPath , null ) ;
1689+ var onnxFileName = "countfeatures.onnx" ;
1690+ var onnxModelPath = GetOutputPath ( onnxFileName ) ;
16851691
1686- if ( IsOnnxRuntimeSupported ( ) )
1687- {
1688- // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
1689- var onnxEstimator = mlContext . Transforms . ApplyOnnxModel ( onnxModelPath ) ;
1690- var onnxTransformer = onnxEstimator . Fit ( dataView ) ;
1691- var onnxResult = onnxTransformer . Transform ( dataView ) ;
1692- CompareSelectedColumns < float > ( "FeatureSelectMIScalarFloat" , "FeatureSelectMIScalarFloat" , transformedData , onnxResult ) ;
1693- CompareSelectedColumns < float > ( "FeatureSelectMIVectorFloat" , "FeatureSelectMIVectorFloat" , transformedData , onnxResult ) ;
1694- CompareSelectedColumns < float > ( "ScalFeatureSelectMissing690" , "ScalFeatureSelectMissing690" , transformedData , onnxResult ) ;
1695- CompareSelectedColumns < double > ( "VecFeatureSelectMissing690" , "VecFeatureSelectMissing690" , transformedData , onnxResult ) ;
1692+ SaveOnnxModel ( onnxModel , onnxModelPath , null ) ;
1693+
1694+ if ( IsOnnxRuntimeSupported ( ) )
1695+ {
1696+ // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
1697+ var onnxEstimator = mlContext . Transforms . ApplyOnnxModel ( onnxModelPath ) ;
1698+ var onnxTransformer = onnxEstimator . Fit ( dataView ) ;
1699+ var onnxResult = onnxTransformer . Transform ( dataView ) ;
1700+ CompareResults ( "VectorOutput" , "VectorOutput" , transformedData , onnxResult ) ;
1701+ CompareResults ( "ScalarOutput" , "ScalarOutput" , transformedData , onnxResult ) ;
1702+ }
16961703 }
16971704 Done ( ) ;
16981705 }
16991706
1700-
1701-
17021707 [ Fact ]
17031708 public void SelectColumnsOnnxTest ( )
17041709 {
0 commit comments