@@ -79,6 +79,7 @@ impl BackendGraph for PytorchGraph {
7979 module : self . module . clone ( ) ,
8080 inputs : Vec :: new ( ) ,
8181 output : TchTensor :: new ( ) ,
82+ id_type : None ,
8283 } ) ;
8384 Ok ( box_. into ( ) )
8485 }
@@ -87,41 +88,80 @@ impl BackendGraph for PytorchGraph {
8788unsafe impl Sync for PytorchExecutionContext { }
8889struct PytorchExecutionContext {
8990 module : Arc < Mutex < tch:: CModule > > ,
90- inputs : Vec < tch:: Tensor > ,
91+ inputs : Vec < Option < tch:: Tensor > > ,
9192 output : tch:: Tensor ,
93+ id_type : Option < Id > ,
9294}
9395
9496impl BackendExecutionContext for PytorchExecutionContext {
95- fn set_input ( & mut self , _index : Id , input_tensor : & Tensor ) -> Result < ( ) , BackendError > {
96- // Input index is not used in pytorch models. The forward method to a model passes the tensor/data to
97- // the appropriate layer of the model.
97+ fn set_input ( & mut self , id : Id , input_tensor : & Tensor ) -> Result < ( ) , BackendError > {
9898 let kind = input_tensor. ty . try_into ( ) ?;
9999 let dimensions = input_tensor
100100 . dimensions
101101 . iter ( )
102102 . map ( |& dim| dim as i64 )
103103 . collect :: < Vec < _ > > ( ) ;
104- self . inputs . push ( TchTensor :: from_data_size (
105- & input_tensor. data ,
106- & dimensions,
107- kind,
108- ) ) ;
109- Ok ( ( ) )
104+ let tensor = TchTensor :: from_data_size ( & input_tensor. data , & dimensions, kind) ;
105+ match id {
106+ Id :: Index ( i) => {
107+ // Check if id_type is already set and if it matches the current id type
108+ if let Some ( Id :: Name ( _) ) = self . id_type {
109+ return Err ( BackendError :: BackendAccess ( anyhow:: anyhow!(
110+ "Cannot mix u32 and str indexes"
111+ ) ) ) ;
112+ }
113+ // Set id_type if not already set
114+ if self . id_type . is_none ( ) {
115+ self . id_type = Some ( Id :: Index ( 0 ) ) ; // Provide a u32 value for Index
116+ }
117+ let i = i as usize ;
118+ if i >= self . inputs . len ( ) {
119+ self . inputs . resize_with ( i + 1 , || None ) ;
120+ }
121+ self . inputs [ i] = Some ( tensor) ;
122+ Ok ( ( ) )
123+ }
124+ Id :: Name ( _) => {
125+ // Check if id_type is already set and if it matches the current id type
126+ if let Some ( Id :: Index ( _) ) = self . id_type {
127+ return Err ( BackendError :: BackendAccess ( anyhow:: anyhow!(
128+ "Cannot mix u32 and str indexes"
129+ ) ) ) ;
130+ }
131+ // Set id_type if not already set
132+ if self . id_type . is_none ( ) {
133+ self . id_type = Some ( Id :: Name ( String :: new ( ) ) ) ; // Provide a str value for Name
134+ }
135+ if self . inputs . get ( 0 ) . is_some ( ) {
136+ return Err ( BackendError :: BackendAccess ( anyhow:: anyhow!(
137+ "The pytorch backend does not support multiple named inputs"
138+ ) ) ) ;
139+ } else {
140+ self . inputs . push ( Some ( tensor) ) ;
141+ }
142+ Ok ( ( ) )
143+ }
144+ }
110145 }
111146
112147 fn compute ( & mut self ) -> Result < ( ) , BackendError > {
113- // Use forward method on the compiled module/model after locking the mutex, and pass the input tensor to it
114- self . output = self
115- . module
116- . lock ( )
117- . unwrap ( )
118- . forward_ts ( & self . inputs )
119- . unwrap ( ) ;
148+ let inputs: Vec < tch:: Tensor > = self
149+ . inputs
150+ . iter ( )
151+ . enumerate ( )
152+ . map ( |( index, opt) | {
153+ opt. as_ref ( )
154+ . expect ( & format ! ( "Input tensor at index {} not set up" , index) )
155+ . shallow_clone ( )
156+ } )
157+ . collect ( ) ;
158+ // Use forward_ts method on the compiled module/model after locking the mutex, and pass the input tensor to it
159+ self . output = self . module . lock ( ) . unwrap ( ) . forward_ts ( & inputs) . unwrap ( ) ;
120160 Ok ( ( ) )
121161 }
122162
123163 fn get_output ( & mut self , _index : Id ) -> Result < Tensor , BackendError > {
124- // Output index is not used in pytorch models . The forward method to a model returns the output tensor.
164+ // Output index is not used. The forward_ts method to a model returns a single output tensor.
125165 let numel = self . output . numel ( ) ;
126166 let dimensions = self . output . size ( ) ;
127167 let ty = self . output . kind ( ) . try_into ( ) ?;
0 commit comments