Skip to content

Commit 4bf2ace

Browse files
authored
Added onnx export for NaiveBayesMulticlassTrainer (#4636)
* onnx export for naive bayes * adds double vector onnx support * resolving comments
1 parent 795559d commit 4bf2ace

File tree

5 files changed

+190
-1
lines changed

5 files changed

+190
-1
lines changed

src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,16 @@ public OnnxNode CreateNode(string opType, string input, string output, string na
170170
/// <returns>The initializer's ONNX name</returns>
171171
public abstract string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
172172

173+
/// <summary>
174+
/// Call this function can declare a global double tensor
175+
/// </summary>
176+
/// <param name="values">The doubles which are going to be added into the ONNX graph</param>
177+
/// <param name="dims">The shape that the doubles</param>
178+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
179+
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
180+
/// <returns>The initializer's ONNX name</returns>
181+
public abstract string AddInitializer(IEnumerable<double> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
182+
173183
/// <summary>
174184
/// Call this function can declare a global string tensor
175185
/// </summary>

src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,17 @@ public override string AddInitializer(IEnumerable<long> values, IEnumerable<long
313313
return name;
314314
}
315315

316+
public override string AddInitializer(IEnumerable<double> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
317+
{
318+
_host.CheckValue(values, nameof(values));
319+
if (dims != null)
320+
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");
321+
322+
name = AddVariable(name ?? "double", makeUniqueName);
323+
_initializers.Add(OnnxUtils.MakeDouble(name, values, dims));
324+
return name;
325+
}
326+
316327
public override string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
317328
{
318329
_host.CheckValue(values, nameof(values));

src/Microsoft.ML.OnnxConverter/OnnxUtils.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,20 @@ public static TensorProto MakeInt64s(string name, IEnumerable<long> values, IEnu
405405
return tensor;
406406
}
407407

408+
// Make double vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor.
409+
public static TensorProto MakeDouble(string name, IEnumerable<double> values, IEnumerable<long> dims = null)
410+
{
411+
var tensor = new TensorProto();
412+
tensor.Name = name;
413+
tensor.DataType = (int)TensorProto.Types.DataType.Double;
414+
tensor.DoubleData.AddRange(values);
415+
if (dims != null)
416+
tensor.Dims.AddRange(dims);
417+
else
418+
tensor.Dims.Add(values.Count());
419+
return tensor;
420+
}
421+
408422
// Make float scalar in ONNX from native C# number
409423
public static TensorProto MakeFloat(string name, float value)
410424
{

src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using Microsoft.ML.EntryPoints;
1111
using Microsoft.ML.Internal.Utilities;
1212
using Microsoft.ML.Model;
13+
using Microsoft.ML.Model.OnnxConverter;
1314
using Microsoft.ML.Runtime;
1415
using Microsoft.ML.Trainers;
1516

@@ -222,7 +223,8 @@ internal static CommonOutputs.MulticlassClassificationOutput TrainMulticlassNaiv
222223
/// </summary>
223224
public sealed class NaiveBayesMulticlassModelParameters :
224225
ModelParametersBase<VBuffer<float>>,
225-
IValueMapper
226+
IValueMapper,
227+
ISingleCanSaveOnnx
226228
{
227229
internal const string LoaderSignature = "MultiClassNaiveBayesPred";
228230
private static VersionInfo GetVersionInfo()
@@ -252,6 +254,8 @@ private static VersionInfo GetVersionInfo()
252254

253255
DataViewType IValueMapper.OutputType => _outputType;
254256

257+
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
258+
255259
/// <summary>
256260
/// Get the label histogram.
257261
/// </summary>
@@ -383,6 +387,155 @@ ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
383387
return (ValueMapper<TIn, TOut>)(Delegate)del;
384388
}
385389

390+
/// <summary>
391+
/// Creates an Onnx inferencing model by vectorizing and following the logic found in <see cref="Map"/>
392+
/// </summary>
393+
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
394+
{
395+
float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length];
396+
float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];
397+
398+
for (int i = 0; i < _featureHistogram.Length; i++)
399+
{
400+
Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length);
401+
}
402+
for (int i = 0; i < _featureHistogram[0].Length; i++)
403+
{
404+
Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length);
405+
}
406+
407+
var one = ctx.AddInitializer(1.0f, "one");
408+
var zero = ctx.AddInitializer(0.0f, "zero");
409+
var labelCount = ctx.AddInitializer((float)_labelCount, "labelCount");
410+
var trainingCount = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");
411+
var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram");
412+
413+
var featureHistogramName = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram");
414+
var labelHistogramName = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
415+
var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");
416+
417+
var greaterOutput = ctx.AddIntermediateVariable(null, "greaterOutput", true);
418+
var opType = "Greater";
419+
ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");
420+
421+
opType = "Cast";
422+
var isFeaturePresent = ctx.AddIntermediateVariable(null, "isFeaturePresent", true);
423+
var node = ctx.CreateNode(opType, greaterOutput, isFeaturePresent, ctx.GetNodeName(opType), "");
424+
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
425+
node.AddAttribute("to", t);
426+
427+
//initialize logProb
428+
opType = "Div";
429+
var divOutput = ctx.AddIntermediateVariable(null, "DivOutput", true);
430+
ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");
431+
432+
opType = "Log";
433+
var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true);
434+
ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");
435+
436+
//log1
437+
opType = "Sum";
438+
var sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
439+
ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
440+
441+
var logOutput1 = ctx.AddIntermediateVariable(null, "LogOutput", true);
442+
LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);
443+
444+
//log2
445+
opType = "Transpose";
446+
var labelHistogramTrans = ctx.AddIntermediateVariable(null, "transpose", true);
447+
ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");
448+
449+
opType = "Sub";
450+
var absentFeatureCount = ctx.AddIntermediateVariable(null, "AbsentFeatureCounts", true);
451+
ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");
452+
453+
opType = "Sum";
454+
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
455+
ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
456+
457+
var logOutput2 = ctx.AddIntermediateVariable(null, "LogOutput", true);
458+
LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);
459+
460+
//log3
461+
opType = "Sum";
462+
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
463+
ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
464+
465+
var logOutput3 = ctx.AddIntermediateVariable(null, "LogOutput", true);
466+
LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);
467+
468+
//result
469+
opType = "Sub";
470+
var logProb = ctx.AddIntermediateVariable(null, "LogProb", true);
471+
ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");
472+
473+
opType = "Sub";
474+
var absentFeatureLogProb = ctx.AddIntermediateVariable(null, "AbsentFeatureLogProb", true);
475+
ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");
476+
477+
opType = "ReduceSum";
478+
var logProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);
479+
node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
480+
long[] list = { 1 };
481+
node.AddAttribute("axes", list);
482+
483+
opType = "ReduceSum";
484+
var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);
485+
node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
486+
node.AddAttribute("axes", list);
487+
488+
opType = "Cast";
489+
var castOutput = ctx.AddIntermediateVariable(null, "CastOutput2", true);
490+
node = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
491+
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
492+
node.AddAttribute("to", t);
493+
494+
opType = "Sub";
495+
var subOutput = ctx.AddIntermediateVariable(null, "SubOutput", true);
496+
ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");
497+
498+
opType = "Sum";
499+
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
500+
ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
501+
502+
opType = "Transpose";
503+
var transposeOutput = ctx.AddIntermediateVariable(null, "TransposeOutput", true);
504+
ctx.CreateNode(opType, new[] { sumOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), "");
505+
506+
opType = "ArgMax";
507+
var scoreIndex = ctx.AddIntermediateVariable(null, "ScoreIndex", true);
508+
ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
509+
510+
opType = "Cast";
511+
castOutput = ctx.AddIntermediateVariable(null, "CastOutput3", true);
512+
node = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
513+
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
514+
node.AddAttribute("to", t);
515+
516+
//log3
517+
opType = "Sum";
518+
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
519+
ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
520+
521+
opType = "Cast";
522+
node = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), "");
523+
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
524+
node.AddAttribute("to", t);
525+
526+
return true;
527+
}
528+
529+
private void LogMul(OnnxContext ctx, string input, string isFeaturePresent, string output)
530+
{
531+
var opType = "Log";
532+
var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true);
533+
ctx.CreateNode(opType, input, logOutput, ctx.GetNodeName(opType), "");
534+
535+
opType = "Mul";
536+
ctx.CreateNode(opType, new[] { logOutput, isFeaturePresent }, new[] { output }, ctx.GetNodeName(opType), "");
537+
}
538+
386539
private void ComputeLabelProbabilityFromFeature(double labelOccurrenceCount, int labelIndex, int featureIndex,
387540
float featureValue, ref double logProb, ref double absentFeatureLogProb)
388541
{

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,7 @@ public void MulticlassTrainersOnnxConversionTest()
11771177
List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>()
11781178
{
11791179
mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(),
1180+
mlContext.MulticlassClassification.Trainers.NaiveBayes(),
11801181
mlContext.MulticlassClassification.Trainers.OneVersusAll(
11811182
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), useProbabilities:false),
11821183
mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(),

0 commit comments

Comments
 (0)