Skip to content

Commit 727a4aa

Browse files
Use index in set_input
1 parent f9100e3 commit 727a4aa

1 file changed

Lines changed: 58 additions & 18 deletions

File tree

crates/wasi-nn/src/backend/pytorch.rs

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ impl BackendGraph for PytorchGraph {
7979
module: self.module.clone(),
8080
inputs: Vec::new(),
8181
output: TchTensor::new(),
82+
id_type: None,
8283
});
8384
Ok(box_.into())
8485
}
@@ -87,41 +88,80 @@ impl BackendGraph for PytorchGraph {
8788
unsafe impl Sync for PytorchExecutionContext {}
8889
struct PytorchExecutionContext {
8990
module: Arc<Mutex<tch::CModule>>,
90-
inputs: Vec<tch::Tensor>,
91+
inputs: Vec<Option<tch::Tensor>>,
9192
output: tch::Tensor,
93+
id_type: Option<Id>,
9294
}
9395

9496
impl BackendExecutionContext for PytorchExecutionContext {
95-
fn set_input(&mut self, _index: Id, input_tensor: &Tensor) -> Result<(), BackendError> {
96-
// Input index is not used in pytorch models. The forward method to a model passes the tensor/data to
97-
// the appropriate layer of the model.
97+
fn set_input(&mut self, id: Id, input_tensor: &Tensor) -> Result<(), BackendError> {
9898
let kind = input_tensor.ty.try_into()?;
9999
let dimensions = input_tensor
100100
.dimensions
101101
.iter()
102102
.map(|&dim| dim as i64)
103103
.collect::<Vec<_>>();
104-
self.inputs.push(TchTensor::from_data_size(
105-
&input_tensor.data,
106-
&dimensions,
107-
kind,
108-
));
109-
Ok(())
104+
let tensor = TchTensor::from_data_size(&input_tensor.data, &dimensions, kind);
105+
match id {
106+
Id::Index(i) => {
107+
// Check if id_type is already set and if it matches the current id type
108+
if let Some(Id::Name(_)) = self.id_type {
109+
return Err(BackendError::BackendAccess(anyhow::anyhow!(
110+
"Cannot mix u32 and str indexes"
111+
)));
112+
}
113+
// Set id_type if not already set
114+
if self.id_type.is_none() {
115+
self.id_type = Some(Id::Index(0)); // Provide a u32 value for Index
116+
}
117+
let i = i as usize;
118+
if i >= self.inputs.len() {
119+
self.inputs.resize_with(i + 1, || None);
120+
}
121+
self.inputs[i] = Some(tensor);
122+
Ok(())
123+
}
124+
Id::Name(_) => {
125+
// Check if id_type is already set and if it matches the current id type
126+
if let Some(Id::Index(_)) = self.id_type {
127+
return Err(BackendError::BackendAccess(anyhow::anyhow!(
128+
"Cannot mix u32 and str indexes"
129+
)));
130+
}
131+
// Set id_type if not already set
132+
if self.id_type.is_none() {
133+
self.id_type = Some(Id::Name(String::new())); // Provide a str value for Name
134+
}
135+
if self.inputs.get(0).is_some() {
136+
return Err(BackendError::BackendAccess(anyhow::anyhow!(
137+
"The pytorch backend does not support multiple named inputs"
138+
)));
139+
} else {
140+
self.inputs.push(Some(tensor));
141+
}
142+
Ok(())
143+
}
144+
}
110145
}
111146

112147
fn compute(&mut self) -> Result<(), BackendError> {
113-
// Use forward method on the compiled module/model after locking the mutex, and pass the input tensor to it
114-
self.output = self
115-
.module
116-
.lock()
117-
.unwrap()
118-
.forward_ts(&self.inputs)
119-
.unwrap();
148+
let inputs: Vec<tch::Tensor> = self
149+
.inputs
150+
.iter()
151+
.enumerate()
152+
.map(|(index, opt)| {
153+
opt.as_ref()
154+
.expect(&format!("Input tensor at index {} not set up", index))
155+
.shallow_clone()
156+
})
157+
.collect();
158+
// Use forward_ts method on the compiled module/model after locking the mutex, and pass the input tensor to it
159+
self.output = self.module.lock().unwrap().forward_ts(&inputs).unwrap();
120160
Ok(())
121161
}
122162

123163
fn get_output(&mut self, _index: Id) -> Result<Tensor, BackendError> {
124-
// Output index is not used in pytorch models. The forward method to a model returns the output tensor.
164+
// Output index is not used. The forward_ts method to a model returns a single output tensor.
125165
let numel = self.output.numel();
126166
let dimensions = self.output.size();
127167
let ty = self.output.kind().try_into()?;

0 commit comments

Comments
 (0)