Skip to content

Commit f27ee77

Browse files
gussmith23AD1024
authored andcommitted
[Rust] Update Rust bindings (apache#9808)
* Update Rust bindings * fmt Co-authored-by: AD1024 <dh63@cs.washington.edu>
1 parent e2c93b9 commit f27ee77

7 files changed

Lines changed: 175 additions & 5 deletions

File tree

include/tvm/relay/attrs/nn.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1232,7 +1232,7 @@ struct UpSampling3DAttrs : public tvm::AttrsNode<UpSampling3DAttrs> {
12321232
/*! \brief Attributes used for the padding operator */
12331233
struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
12341234
Array<Array<Integer>> pad_width;
1235-
std::string pad_mode;
1235+
tvm::String pad_mode;
12361236

12371237
TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
12381238
TVM_ATTR_FIELD(pad_width).describe(

include/tvm/relay/attrs/transform.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
173173
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
174174
Integer batch_dims;
175175
Integer axis;
176-
std::string mode;
176+
tvm::String mode;
177177

178178
TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
179179
TVM_ATTR_FIELD(batch_dims)
@@ -329,7 +329,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
329329
Optional<Array<Integer>> begin;
330330
Optional<Array<Integer>> end;
331331
Optional<Array<Integer>> strides;
332-
std::string slice_mode;
332+
tvm::String slice_mode;
333333
Optional<Array<Integer>> axes;
334334

335335
TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {

rust/tvm/src/ir/relay/attrs/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
*/
1919

2020
pub mod nn;
21+
pub mod reduce;
2122
pub mod transform;

rust/tvm/src/ir/relay/attrs/nn.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,35 @@ use tvm_macros::Object;
2626

2727
type IndexExpr = PrimExpr;
2828

29+
#[repr(C)]
30+
#[derive(Object, Debug)]
31+
#[ref_name = "PadAttrs"]
32+
#[type_key = "relay.attrs.PadAttrs"]
33+
pub struct PadAttrsNode {
34+
pub base: BaseAttrsNode,
35+
pub pad_width: Array<Array<IndexExpr>>,
36+
pub pad_mode: TString,
37+
}
38+
39+
#[repr(C)]
40+
#[derive(Object, Debug)]
41+
#[ref_name = "Conv1DAttrs"]
42+
#[type_key = "relay.attrs.Conv1DAttrs"]
43+
pub struct Conv1DAttrsNode {
44+
pub base: BaseAttrsNode,
45+
pub strides: Array<IndexExpr>,
46+
pub padding: Array<IndexExpr>,
47+
pub dilation: Array<IndexExpr>,
48+
// TODO(@gussmith23) groups is "int", what should it be here?
49+
pub groups: i32,
50+
pub channels: IndexExpr,
51+
pub kernel_size: Array<IndexExpr>,
52+
pub data_layout: TString,
53+
pub kernel_layout: TString,
54+
pub out_layout: TString,
55+
pub out_dtype: DataType,
56+
}
57+
2958
#[repr(C)]
3059
#[derive(Object, Debug)]
3160
#[ref_name = "Conv2DAttrs"]
@@ -42,6 +71,7 @@ pub struct Conv2DAttrsNode {
4271
pub data_layout: TString,
4372
pub kernel_layout: TString,
4473
pub out_layout: TString,
74+
pub auto_scheduler_rewritten_layout: TString,
4575
pub out_dtype: DataType,
4676
}
4777

@@ -138,6 +168,7 @@ pub struct AvgPool2DAttrsNode {
138168
pub pool_size: Array<IndexExpr>,
139169
pub strides: Array<IndexExpr>,
140170
pub padding: Array<IndexExpr>,
171+
pub dilation: Array<IndexExpr>,
141172
pub layout: TString,
142173
pub ceil_mode: bool,
143174
pub count_include_pad: bool,
@@ -155,3 +186,34 @@ pub struct UpSamplingAttrsNode {
155186
pub method: TString,
156187
pub align_corners: bool,
157188
}
189+
190+
#[repr(C)]
191+
#[derive(Object, Debug)]
192+
#[ref_name = "DropoutAttrs"]
193+
#[type_key = "relay.attrs.DropoutAttrs"]
194+
pub struct DropoutAttrsNode {
195+
pub base: BaseAttrsNode,
196+
pub rate: f64,
197+
}
198+
199+
#[repr(C)]
200+
#[derive(Object, Debug)]
201+
#[ref_name = "BatchMatmulAttrs"]
202+
#[type_key = "relay.attrs.BatchMatmulAttrs"]
203+
pub struct BatchMatmulAttrsNode {
204+
pub base: BaseAttrsNode,
205+
pub auto_scheduler_rewritten_layout: TString,
206+
pub out_dtype: DataType,
207+
}
208+
209+
#[repr(C)]
210+
#[derive(Object, Debug)]
211+
#[ref_name = "LayerNormAttrs"]
212+
#[type_key = "relay.attrs.LayerNormAttrs"]
213+
pub struct LayerNormAttrsNode {
214+
pub base: BaseAttrsNode,
215+
pub axis: i32,
216+
pub epsilon: f64,
217+
pub center: bool,
218+
pub scale: bool,
219+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
use crate::ir::attrs::BaseAttrsNode;
21+
use crate::ir::PrimExpr;
22+
use crate::runtime::array::Array;
23+
use tvm_macros::Object;
24+
25+
type IndexExpr = PrimExpr;
26+
27+
#[repr(C)]
28+
#[derive(Object, Debug)]
29+
#[ref_name = "ReduceAttrs"]
30+
#[type_key = "relay.attrs.ReduceAttrs"]
31+
pub struct ReduceAttrsNode {
32+
pub base: BaseAttrsNode,
33+
pub axis: Array<IndexExpr>,
34+
pub keepdims: bool,
35+
pub exclude: bool,
36+
}
37+
38+
#[repr(C)]
39+
#[derive(Object, Debug)]
40+
#[ref_name = "VarianceAttrs"]
41+
#[type_key = "relay.attrs.ReduceAttrs"]
42+
pub struct VarianceAttrsNode {
43+
pub base: BaseAttrsNode,
44+
pub axis: Array<IndexExpr>,
45+
pub keepdims: bool,
46+
pub exclude: bool,
47+
pub unbiased: bool,
48+
}

rust/tvm/src/ir/relay/attrs/transform.rs

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,35 @@
1818
*/
1919

2020
use crate::ir::attrs::BaseAttrsNode;
21+
use crate::ir::relay::TString;
22+
use crate::ir::tir::IntImm;
2123
use crate::ir::PrimExpr;
2224
use crate::runtime::array::Array;
2325
use crate::runtime::ObjectRef;
2426
use tvm_macros::Object;
27+
use tvm_rt::DataType;
2528

2629
type IndexExpr = PrimExpr;
2730

31+
#[repr(C)]
32+
#[derive(Object, Debug)]
33+
#[ref_name = "ClipAttrs"]
34+
#[type_key = "relay.attrs.ClipAttrs"]
35+
pub struct ClipAttrsNode {
36+
pub base: BaseAttrsNode,
37+
pub a_min: f64,
38+
pub a_max: f64,
39+
}
40+
41+
#[repr(C)]
42+
#[derive(Object, Debug)]
43+
#[ref_name = "CastAttrs"]
44+
#[type_key = "relay.attrs.CastAttrs"]
45+
pub struct CastAttrsNode {
46+
pub base: BaseAttrsNode,
47+
pub dtype: DataType,
48+
}
49+
2850
#[repr(C)]
2951
#[derive(Object, Debug)]
3052
#[ref_name = "ExpandDimsAttrs"]
@@ -79,5 +101,40 @@ pub struct TransposeAttrsNode {
79101
#[type_key = "relay.attrs.SqueezeAttrs"]
80102
pub struct SqueezeAttrsNode {
81103
pub base: BaseAttrsNode,
82-
pub axis: Array<IndexExpr>,
104+
pub axis: Array<IntImm>,
105+
}
106+
107+
#[repr(C)]
108+
#[derive(Object, Debug)]
109+
#[ref_name = "TakeAttrs"]
110+
#[type_key = "relay.attrs.TakeAttrs"]
111+
pub struct TakeAttrsNode {
112+
pub base: BaseAttrsNode,
113+
pub batch_dims: IntImm,
114+
pub axis: IntImm,
115+
pub mode: TString,
116+
}
117+
118+
#[repr(C)]
119+
#[derive(Object, Debug)]
120+
#[ref_name = "StackAttrs"]
121+
#[type_key = "relay.attrs.StackAttrs"]
122+
pub struct StackAttrsNode {
123+
pub base: BaseAttrsNode,
124+
pub axis: IntImm,
125+
}
126+
127+
// TODO(@gussmith23) How to support Optional type? This "just works" when values
128+
// are provided for begin/end/strides, but I'm not sure what happens if None is
129+
// passed from the C++ side.
130+
#[repr(C)]
131+
#[derive(Object, Debug)]
132+
#[ref_name = "StridedSliceAttrs"]
133+
#[type_key = "relay.attrs.StridedSliceAttrs"]
134+
pub struct StridedSliceAttrsNode {
135+
pub base: BaseAttrsNode,
136+
pub begin: Array<IntImm>,
137+
pub end: Array<IntImm>,
138+
pub strides: Array<IntImm>,
139+
pub slice_mode: TString,
83140
}

rust/tvm/src/ir/relay/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
* under the License.
1818
*/
1919
use crate::runtime::array::Array;
20-
use crate::runtime::{object::*, IsObjectRef, String as TString};
20+
use crate::runtime::{self, object::*, IsObjectRef, String as TString};
2121

2222
use super::attrs::Attrs;
2323
use super::expr::BaseExprNode;
@@ -150,6 +150,7 @@ impl Var {
150150
#[type_key = "relay.Call"]
151151
pub struct CallNode {
152152
pub base: ExprNode,
153+
deleter: ObjectRef,
153154
pub op: Expr,
154155
pub args: Array<Expr>,
155156
pub attrs: Attrs,
@@ -166,6 +167,7 @@ impl Call {
166167
) -> Call {
167168
let node = CallNode {
168169
base: ExprNode::base::<CallNode>(span),
170+
deleter: todo!("Don't know how to construct this"),
169171
op: op,
170172
args: args,
171173
attrs: attrs,

0 commit comments

Comments
 (0)