|
10 | 10 | using Microsoft.ML.EntryPoints; |
11 | 11 | using Microsoft.ML.Internal.Utilities; |
12 | 12 | using Microsoft.ML.Model; |
| 13 | +using Microsoft.ML.Model.OnnxConverter; |
13 | 14 | using Microsoft.ML.Runtime; |
14 | 15 | using Microsoft.ML.Trainers; |
15 | 16 |
|
@@ -222,7 +223,8 @@ internal static CommonOutputs.MulticlassClassificationOutput TrainMulticlassNaiv |
222 | 223 | /// </summary> |
223 | 224 | public sealed class NaiveBayesMulticlassModelParameters : |
224 | 225 | ModelParametersBase<VBuffer<float>>, |
225 | | - IValueMapper |
| 226 | + IValueMapper, |
| 227 | + ISingleCanSaveOnnx |
226 | 228 | { |
227 | 229 | internal const string LoaderSignature = "MultiClassNaiveBayesPred"; |
228 | 230 | private static VersionInfo GetVersionInfo() |
@@ -252,6 +254,8 @@ private static VersionInfo GetVersionInfo() |
252 | 254 |
|
253 | 255 | DataViewType IValueMapper.OutputType => _outputType; |
254 | 256 |
|
| 257 | + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; |
| 258 | + |
255 | 259 | /// <summary> |
256 | 260 | /// Get the label histogram. |
257 | 261 | /// </summary> |
@@ -383,6 +387,155 @@ ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>() |
383 | 387 | return (ValueMapper<TIn, TOut>)(Delegate)del; |
384 | 388 | } |
385 | 389 |
|
| 390 | + /// <summary> |
| 391 | + /// Creates an Onnx inferencing model by vectorizing and following the logic found in <see cref="Map"/> |
| 392 | + /// </summary> |
| 393 | + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) |
| 394 | + { |
| 395 | + float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length]; |
| 396 | + float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length]; |
| 397 | + |
| 398 | + for (int i = 0; i < _featureHistogram.Length; i++) |
| 399 | + { |
| 400 | + Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length); |
| 401 | + } |
| 402 | + for (int i = 0; i < _featureHistogram[0].Length; i++) |
| 403 | + { |
| 404 | + Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length); |
| 405 | + } |
| 406 | + |
| 407 | + var one = ctx.AddInitializer(1.0f, "one"); |
| 408 | + var zero = ctx.AddInitializer(0.0f, "zero"); |
| 409 | + var labelCount = ctx.AddInitializer((float)_labelCount, "labelCount"); |
| 410 | + var trainingCount = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount"); |
| 411 | + var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram"); |
| 412 | + |
| 413 | + var featureHistogramName = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram"); |
| 414 | + var labelHistogramName = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded"); |
| 415 | + var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb"); |
| 416 | + |
| 417 | + var greaterOutput = ctx.AddIntermediateVariable(null, "greaterOutput", true); |
| 418 | + var opType = "Greater"; |
| 419 | + ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), ""); |
| 420 | + |
| 421 | + opType = "Cast"; |
| 422 | + var isFeaturePresent = ctx.AddIntermediateVariable(null, "isFeaturePresent", true); |
| 423 | + var node = ctx.CreateNode(opType, greaterOutput, isFeaturePresent, ctx.GetNodeName(opType), ""); |
| 424 | + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); |
| 425 | + node.AddAttribute("to", t); |
| 426 | + |
| 427 | + //initialize logProb |
| 428 | + opType = "Div"; |
| 429 | + var divOutput = ctx.AddIntermediateVariable(null, "DivOutput", true); |
| 430 | + ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), ""); |
| 431 | + |
| 432 | + opType = "Log"; |
| 433 | + var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true); |
| 434 | + ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), ""); |
| 435 | + |
| 436 | + //log1 |
| 437 | + opType = "Sum"; |
| 438 | + var sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true); |
| 439 | + ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), ""); |
| 440 | + |
| 441 | + var logOutput1 = ctx.AddIntermediateVariable(null, "LogOutput", true); |
| 442 | + LogMul(ctx, sumOutput, isFeaturePresent, logOutput1); |
| 443 | + |
| 444 | + //log2 |
| 445 | + opType = "Transpose"; |
| 446 | + var labelHistogramTrans = ctx.AddIntermediateVariable(null, "transpose", true); |
| 447 | + ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), ""); |
| 448 | + |
| 449 | + opType = "Sub"; |
| 450 | + var absentFeatureCount = ctx.AddIntermediateVariable(null, "AbsentFeatureCounts", true); |
| 451 | + ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), ""); |
| 452 | + |
| 453 | + opType = "Sum"; |
| 454 | + sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true); |
| 455 | + ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), ""); |
| 456 | + |
| 457 | + var logOutput2 = ctx.AddIntermediateVariable(null, "LogOutput", true); |
| 458 | + LogMul(ctx, sumOutput, isFeaturePresent, logOutput2); |
| 459 | + |
| 460 | + //log3 |
| 461 | + opType = "Sum"; |
| 462 | + sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true); |
| 463 | + ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), ""); |
| 464 | + |
| 465 | + var logOutput3 = ctx.AddIntermediateVariable(null, "LogOutput", true); |
| 466 | + LogMul(ctx, sumOutput, isFeaturePresent, logOutput3); |
| 467 | + |
| 468 | + //result |
| 469 | + opType = "Sub"; |
| 470 | + var logProb = ctx.AddIntermediateVariable(null, "LogProb", true); |
| 471 | + ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), ""); |
| 472 | + |
| 473 | + opType = "Sub"; |
| 474 | + var absentFeatureLogProb = ctx.AddIntermediateVariable(null, "AbsentFeatureLogProb", true); |
| 475 | + ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), ""); |
| 476 | + |
| 477 | + opType = "ReduceSum"; |
| 478 | + var logProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true); |
| 479 | + node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), ""); |
| 480 | + long[] list = { 1 }; |
| 481 | + node.AddAttribute("axes", list); |
| 482 | + |
| 483 | + opType = "ReduceSum"; |
| 484 | + var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true); |
| 485 | + node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), ""); |
| 486 | + node.AddAttribute("axes", list); |
| 487 | + |
| 488 | + opType = "Cast"; |
| 489 | + var castOutput = ctx.AddIntermediateVariable(null, "CastOutput2", true); |
| 490 | + node = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), ""); |
| 491 | + t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); |
| 492 | + node.AddAttribute("to", t); |
| 493 | + |
| 494 | + opType = "Sub"; |
| 495 | + var subOutput = ctx.AddIntermediateVariable(null, "SubOutput", true); |
| 496 | + ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), ""); |
| 497 | + |
| 498 | + opType = "Sum"; |
| 499 | + sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true); |
| 500 | + ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), ""); |
| 501 | + |
| 502 | + opType = "Transpose"; |
| 503 | + var transposeOutput = ctx.AddIntermediateVariable(null, "TransposeOutput", true); |
| 504 | + ctx.CreateNode(opType, new[] { sumOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), ""); |
| 505 | + |
| 506 | + opType = "ArgMax"; |
| 507 | + var scoreIndex = ctx.AddIntermediateVariable(null, "ScoreIndex", true); |
| 508 | + ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), ""); |
| 509 | + |
| 510 | + opType = "Cast"; |
| 511 | + castOutput = ctx.AddIntermediateVariable(null, "CastOutput3", true); |
| 512 | + node = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), ""); |
| 513 | + t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); |
| 514 | + node.AddAttribute("to", t); |
| 515 | + |
| 516 | + //log3 |
| 517 | + opType = "Sum"; |
| 518 | + sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true); |
| 519 | + ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), ""); |
| 520 | + |
| 521 | + opType = "Cast"; |
| 522 | + node = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), ""); |
| 523 | + t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType(); |
| 524 | + node.AddAttribute("to", t); |
| 525 | + |
| 526 | + return true; |
| 527 | + } |
| 528 | + |
| 529 | + private void LogMul(OnnxContext ctx, string input, string isFeaturePresent, string output) |
| 530 | + { |
| 531 | + var opType = "Log"; |
| 532 | + var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true); |
| 533 | + ctx.CreateNode(opType, input, logOutput, ctx.GetNodeName(opType), ""); |
| 534 | + |
| 535 | + opType = "Mul"; |
| 536 | + ctx.CreateNode(opType, new[] { logOutput, isFeaturePresent }, new[] { output }, ctx.GetNodeName(opType), ""); |
| 537 | + } |
| 538 | + |
386 | 539 | private void ComputeLabelProbabilityFromFeature(double labelOccurrenceCount, int labelIndex, int featureIndex, |
387 | 540 | float featureValue, ref double logProb, ref double absentFeatureLogProb) |
388 | 541 | { |
|
0 commit comments