Skip to content

Commit 10a9f9e

Browse files
jianjunzabrown
and authored
WinML backend for wasi-nn (#7807)
* Add WinML backend for wasi-nn. * Log execution time. * WinML backend supports execution target selection. ExecutionTarget::Gpu is mapped to LearningModelDeviceKind::DirectX. * Limit WinML backend on Windows only. * Move wasi-nn WinML example to new test infra. * Scale tensor data in test app. App knows input and target data range, so it's better to let app to handle scaling. * Remove old example for wasi-nn WinML backend. * Update image2tensor link. * Format code. * Upgrade image2tensor to 0.3.1. * Upgrade windows to 0.52.0 * Use tensor data as input for wasi-nn WinML backend test. To avoid involving too many external dependencies, input image is converted to tensor data offline. * Restore trailing new line for Cargo.toml. * Remove unnecessary features for windows crate. * Check input tensor types. Only FP32 is supported right now. Reject other tensor types. * Rename default model name to model.onnx. It aligns with openvino backend. prtest:full * Run nn_image_classification_winml only when winml is enabled. * vet: add trusted `windows` crate to lockfile * Fix wasi-nn tests when both openvino and winml are enabled. * Add check for WinML availability. * vet: reapply vet lock --------- Co-authored-by: Andrew Brown <andrew.brown@intel.com>
1 parent 8eab5f8 commit 10a9f9e

9 files changed

Lines changed: 459 additions & 39 deletions

File tree

Cargo.lock

Lines changed: 45 additions & 25 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
use anyhow::Result;
2+
use std::fs;
3+
use std::time::Instant;
4+
use wasi_nn::*;
5+
6+
pub fn main() -> Result<()> {
7+
// Graph is supposed to be preloaded by `nn-graph` argument. The path ends with "mobilenet".
8+
let graph =
9+
wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Onnx, wasi_nn::ExecutionTarget::CPU)
10+
.build_from_cache("mobilenet")
11+
.unwrap();
12+
13+
let mut context = graph.init_execution_context().unwrap();
14+
println!("Created an execution context.");
15+
16+
// Convert image to tensor data.
17+
let tensor_data = fs::read("fixture/kitten.rgb")?;
18+
context
19+
.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)
20+
.unwrap();
21+
22+
// Execute the inference.
23+
let before_compute = Instant::now();
24+
context.compute().unwrap();
25+
println!(
26+
"Executed graph inference, took {} ms.",
27+
before_compute.elapsed().as_millis()
28+
);
29+
30+
// Retrieve the output.
31+
let mut output_buffer = vec![0f32; 1000];
32+
context.get_output(0, &mut output_buffer[..]).unwrap();
33+
34+
let result = sort_results(&output_buffer);
35+
println!("Found results, sorted top 5: {:?}", &result[..5]);
36+
assert_eq!(result[0].0, 284);
37+
Ok(())
38+
}
39+
// Sort the buffer of probabilities. The graph places the match probability for each class at the
// index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert
// to a wrapping InferenceResult and sort the results. It is unclear why the MobileNet output
// indices are "off by one" but the `.skip(1)` below seems necessary to get results that make sense
// (e.g. 763 = "revolver" vs 762 = "restaurant")
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    // `f32::total_cmp` imposes a total order and therefore cannot panic, unlike
    // `partial_cmp(..).unwrap()` which would abort if the model ever emitted a NaN.
    results.sort_by(|a, b| b.1.total_cmp(&a.1));
    results
}

// A wrapper for class ID and match probabilities.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);

crates/wasi-nn/Cargo.toml

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,19 @@ wasmtime = { workspace = true, features = ["component-model", "runtime"] }
2424

2525
# These dependencies are necessary for the wasi-nn implementation:
2626
tracing = { workspace = true }
27-
openvino = { version = "0.6.0", features = ["runtime-linking"] }
2827
thiserror = { workspace = true }
28+
openvino = { version = "0.6.0", features = [
29+
"runtime-linking",
30+
], optional = true }
31+
32+
[target.'cfg(windows)'.dependencies.windows]
33+
version = "0.52"
34+
features = [
35+
"AI_MachineLearning",
36+
"Storage_Streams",
37+
"Foundation_Collections",
38+
]
39+
optional = true
2940

3041
[build-dependencies]
3142
walkdir = { workspace = true }
@@ -35,3 +46,10 @@ cap-std = { workspace = true }
3546
test-programs-artifacts = { workspace = true }
3647
wasi-common = { workspace = true, features = ["sync"] }
3748
wasmtime = { workspace = true, features = ["cranelift"] }
49+
50+
[features]
51+
default = ["openvino"]
52+
# openvino is available on all platforms, it requires openvino installed.
53+
openvino = ["dep:openvino"]
54+
# winml is only available on Windows 10 1809 and later.
55+
winml = ["dep:windows"]

crates/wasi-nn/src/backend/mod.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,15 @@
22
//! this crate. The `Box<dyn ...>` types returned by these interfaces allow
33
//! implementations to maintain backend-specific state between calls.
44
5+
#[cfg(feature = "openvino")]
56
pub mod openvino;
7+
#[cfg(feature = "winml")]
8+
pub mod winml;
69

10+
#[cfg(feature = "openvino")]
711
use self::openvino::OpenvinoBackend;
12+
#[cfg(feature = "winml")]
13+
use self::winml::WinMLBackend;
814
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor};
915
use crate::{Backend, ExecutionContext, Graph};
1016
use std::path::Path;
@@ -13,7 +19,16 @@ use wiggle::GuestError;
1319

1420
/// Return a list of all available backend frameworks.
1521
pub fn list() -> Vec<crate::Backend> {
16-
vec![Backend::from(OpenvinoBackend::default())]
22+
let mut backends = vec![];
23+
#[cfg(feature = "openvino")]
24+
{
25+
backends.push(Backend::from(OpenvinoBackend::default()));
26+
}
27+
#[cfg(feature = "winml")]
28+
{
29+
backends.push(Backend::from(WinMLBackend::default()));
30+
}
31+
backends
1732
}
1833

1934
/// A [Backend] contains the necessary state to load [Graph]s.

0 commit comments

Comments
 (0)