Skip to content

Commit 425ee30

Browse files
anforowiczcopybara-github
authored andcommitted
Substitute impl Into<T> => T in bindings of generic functions.
PiperOrigin-RevId: 884700041
1 parent 96c6797 commit 425ee30

12 files changed

Lines changed: 450 additions & 18 deletions

File tree

cc_bindings_from_rs/generate_bindings/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ rust_library(
2323
"generate_function_thunk.rs",
2424
"generate_struct_and_union.rs",
2525
"generate_template_specialization.rs",
26+
"get_generic_args.rs",
2627
"lib.rs",
2728
],
2829
edition = "2021",

cc_bindings_from_rs/generate_bindings/database/db.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use code_gen_utils::CcInclude;
1616
use dyn_format::Format;
1717
use error_report::{ErrorReporting, ReportFatalError};
1818
use proc_macro2::{Ident, TokenStream};
19-
use rustc_middle::ty::{Ty, TyCtxt};
19+
use rustc_middle::ty::{self, Ty, TyCtxt};
2020
use rustc_span::def_id::{CrateNum, DefId};
2121
use rustc_span::Symbol;
2222
use std::collections::HashMap;
@@ -307,5 +307,10 @@ memoized::query_group! {
307307
///
308308
/// Implementation: cc_bindings_from_rs/generate_bindings/generate_struct_and_union.rs?q=function:local_from_trait_impls_by_argument
309309
fn from_trait_impls_by_argument(&self, crate_num: CrateNum) -> Rc<HashMap<Ty<'tcx>, Vec<DefId>>>;
310+
311+
/// Given a function identified by `fn_def_id` (generic or non-generic) tries to return
312+
/// the generic arguments that should be used in the generated Crubit bindings.
313+
/// Fails if any of the generic parameters cannot be replaced with a concrete type.
314+
fn get_generic_args(&self, fn_def_id: DefId) -> Result<ty::GenericArgsRef<'tcx>>;
310315
}
311316
}

cc_bindings_from_rs/generate_bindings/generate_bindings_test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1327,7 +1327,7 @@ fn test_format_item_static_method_with_generic_type_parameters() {
13271327
let main_api = &result.main_api;
13281328
let unsupported_msg = "Error generating bindings for `SomeStruct::generic_method` \
13291329
defined at <crubit_unittests.rs>;l=12: \
1330-
Generic functions are not supported yet (b/259749023)";
1330+
No valid non-generic replacement for generic type param `T`";
13311331
assert_cc_matches!(
13321332
main_api.tokens,
13331333
quote! {

cc_bindings_from_rs/generate_bindings/generate_function.rs

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use code_gen_utils::{
2121
use crubit_abi_type::{CrubitAbiTypeToCppExprTokens, CrubitAbiTypeToCppTokens};
2222
use database::code_snippet::{ApiSnippets, CcPrerequisites, CcSnippet};
2323
use database::{BindingsGenerator, TypeLocation};
24-
use error_report::{anyhow, bail, ensure};
24+
use error_report::{anyhow, bail};
2525
use itertools::Itertools;
2626
use proc_macro2::{Ident, Literal, TokenStream};
2727
use query_compiler::{is_copy, post_analysis_typing_env};
@@ -103,8 +103,12 @@ fn thunk_name(
103103
.unwrap_or(def_name)
104104
}
105105
} else {
106-
// Call to `mono` is ok - `generics_of` have been checked above.
107-
let instance = ty::Instance::mono(tcx, def_id);
106+
// `expect` and `expect_resolve` are used because `fn get_generic_args`
107+
// should be called earlier to reject cases with unsupported generics.
108+
let typing_env = ty::TypingEnv::non_body_analysis(tcx, def_id);
109+
let args = db.get_generic_args(def_id).expect("Generics should be checked earlier");
110+
let span = tcx.def_span(def_id);
111+
let instance = ty::Instance::expect_resolve(tcx, typing_env, def_id, args, span);
108112
tcx.symbol_name(instance).name.to_string()
109113
};
110114
let target_path_mangled_hash = if db.no_thunk_name_mangling() {
@@ -756,16 +760,11 @@ pub fn generate_function<'tcx>(
756760
) -> Result<ApiSnippets<'tcx>> {
757761
let tcx = db.tcx();
758762

759-
// TODO(b/281542952): Add support for `impl Into<T>` => `T` and similar substitutions.
760-
ensure!(
761-
!tcx.generics_of(def_id).requires_monomorphization(tcx),
762-
"Generic functions are not supported yet (b/259749023)"
763-
);
764-
let sig_mid = liberate_and_deanonymize_late_bound_regions(
765-
tcx,
766-
tcx.fn_sig(def_id).instantiate_identity(),
767-
def_id,
768-
);
763+
let sig_mid = {
764+
let generic_args = db.get_generic_args(def_id)?;
765+
let early_bound_fn_sig = tcx.fn_sig(def_id).instantiate(tcx, generic_args);
766+
liberate_and_deanonymize_late_bound_regions(tcx, early_bound_fn_sig, def_id)
767+
};
769768
check_fn_sig(&sig_mid)?;
770769

771770
let trait_ref = tcx

cc_bindings_from_rs/generate_bindings/generate_function_test.rs

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -920,16 +920,81 @@ fn test_format_item_lifetime_generic_fn_with_various_lifetimes() {
920920
}
921921

922922
#[test]
923-
fn test_format_item_unsupported_type_generic_fn() {
923+
fn test_format_item_generic_fn_into_trait_basic_replacement() {
924+
let test_src = r#"
925+
#![allow(unused)]
926+
pub fn generic_function(arg: impl Into<i32>) { todo!() }
927+
"#;
928+
test_format_item(test_src, "generic_function", |result| {
929+
let result = result.unwrap().unwrap();
930+
assert_cc_matches!(
931+
result.main_api.tokens,
932+
quote! {
933+
void generic_function(std::int32_t arg);
934+
}
935+
);
936+
assert_rs_matches!(
937+
result.rs_details.tokens,
938+
quote! {
939+
unsafe extern "C" fn __crubit_thunk_generic_ufunction(arg: i32) -> () {
940+
unsafe { ::rust_out::generic_function(arg) }
941+
}
942+
}
943+
);
944+
});
945+
}
946+
947+
/// This test was initally added to provide coverage/verification that
948+
/// _all_ generic parameters need to have valid replacements.
949+
#[test]
950+
fn test_format_item_generic_fn_into_trait_and_unsupported_trait() {
951+
let test_src = r#"
952+
pub trait MyTrait {}
953+
pub fn generic_function(_i: impl Into<i32>, _t: impl MyTrait) { todo!() }
954+
"#;
955+
test_format_item(test_src, "generic_function", |result| {
956+
let err = result.unwrap_err();
957+
assert_eq!(err, "No valid non-generic replacement for generic type param `impl MyTrait`");
958+
});
959+
}
960+
961+
/// This test was initally added to provide coverage/verification that
962+
/// _all_ clauses/constraints of `T` have to be considered in
963+
/// `is_valid_replacement_for_generic_type_param`.
964+
#[test]
965+
fn test_format_item_generic_fn_into_trait_when_failed_substitiution() {
966+
let test_src = r#"
967+
pub trait MyTrait {}
968+
pub fn generic_function<T>(_t: T) where T: Into<i32>, T: MyTrait { todo!() }
969+
"#;
970+
test_format_item(test_src, "generic_function", |result| {
971+
let err = result.unwrap_err();
972+
assert_eq!(err, "No valid non-generic replacement for generic type param `T`");
973+
});
974+
}
975+
976+
#[test]
977+
fn test_format_item_generic_fn_unsupported_const_param() {
978+
let test_src = r#"
979+
pub fn generic_function<const N: usize>() { todo!() }
980+
"#;
981+
test_format_item(test_src, "generic_function", |result| {
982+
let err = result.unwrap_err();
983+
assert_eq!(err, "`const`-generic functions are not supported (b/259749023)");
984+
});
985+
}
986+
987+
#[test]
988+
fn test_format_item_generic_fn_unsupported_type_param() {
924989
let test_src = r#"
925990
use std::fmt::Display;
926-
pub fn generic_function<T: Default + Display>() {
991+
pub fn generic_function<T: Default + Display>(_: T) {
927992
println!("{}", T::default());
928993
}
929994
"#;
930995
test_format_item(test_src, "generic_function", |result| {
931996
let err = result.unwrap_err();
932-
assert_eq!(err, "Generic functions are not supported yet (b/259749023)");
997+
assert_eq!(err, "No valid non-generic replacement for generic type param `T`");
933998
});
934999
}
9351000

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
// Part of the Crubit project, under the Apache License v2.0 with LLVM
2+
// Exceptions. See /LICENSE for license information.
3+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
use crate::liberate_and_deanonymize_late_bound_regions;
6+
use arc_anyhow::{anyhow, bail, ensure, Result};
7+
use database::BindingsGenerator;
8+
use rustc_infer::traits::{Obligation, ObligationCause};
9+
use rustc_middle::ty::{self, Ty, TyCtxt};
10+
use rustc_span::def_id::DefId;
11+
use rustc_span::symbol::sym;
12+
use rustc_trait_selection::infer::canonical::ir::TypingMode;
13+
use rustc_trait_selection::infer::TyCtxtInferExt;
14+
use rustc_trait_selection::traits::ObligationCtxt;
15+
use std::collections::{HashMap, HashSet};
16+
17+
/// Implementation of `BindingsGenerator::get_generic_args`.
18+
pub fn get_generic_args<'tcx>(
19+
db: &BindingsGenerator<'tcx>,
20+
fn_def_id: DefId,
21+
) -> Result<ty::GenericArgsRef<'tcx>> {
22+
let tcx = db.tcx();
23+
let generics = tcx.generics_of(fn_def_id);
24+
let predicates = tcx.predicates_of(fn_def_id);
25+
26+
// See the doc comment for `unused_generic_param` in
27+
// `test/functions/functions.rs` for an explanation why we currently don't
28+
// support unused generic params.
29+
let indices_of_actually_used_generic_params = {
30+
let mut finder = GenericParamsFinder::default();
31+
let fn_sig = tcx.fn_sig(fn_def_id).instantiate_identity();
32+
let fn_sig = liberate_and_deanonymize_late_bound_regions(tcx, fn_sig, fn_def_id);
33+
use rustc_type_ir::TypeVisitable;
34+
fn_sig.visit_with(&mut finder);
35+
finder.generic_param_indices
36+
};
37+
38+
let replacements: HashMap<usize, ty::GenericArg<'tcx>> = (0..generics.count())
39+
.map(|idx| {
40+
let param_def = generics.param_at(idx, tcx);
41+
let replacement = match param_def.kind {
42+
ty::GenericParamDefKind::Const { .. } => {
43+
bail!("`const`-generic functions are not supported (b/259749023)");
44+
}
45+
ty::GenericParamDefKind::Lifetime => tcx.mk_param_from_def(param_def),
46+
ty::GenericParamDefKind::Type { .. } => {
47+
ensure!(
48+
indices_of_actually_used_generic_params.contains(&param_def.index),
49+
"No support for replacing an _unused_ generic type param: `{}`",
50+
param_def.name,
51+
);
52+
get_replacement_for_generic_type_param(tcx, fn_def_id, predicates, param_def)
53+
.map(|ty| ty.into())
54+
.ok_or_else(|| {
55+
anyhow!(
56+
"No valid non-generic replacement for generic type param `{}`",
57+
param_def.name,
58+
)
59+
})?
60+
}
61+
};
62+
Ok((idx, replacement))
63+
})
64+
.collect::<Result<Vec<_>>>()?
65+
.into_iter()
66+
.collect();
67+
68+
Ok(ty::GenericArgs::for_item(tcx, fn_def_id, |param_def, _old_generic_args| {
69+
*replacements
70+
.get(&(param_def.index as usize))
71+
.expect("All errors should have been handled above")
72+
}))
73+
}
74+
75+
/// Given a generic constraint of the form `T: Trait`, returns the type that can potentially
76+
/// replace `T` in the generated bindings.
77+
///
78+
/// If the returned type needs to use a new anonymous lifetime, then it will be generated
79+
/// using the given `def_id` as its scope.
80+
fn get_replacement_for_trait_predicate<'tcx>(
81+
tcx: TyCtxt<'tcx>,
82+
trait_predicate: ty::TraitPredicate<'tcx>,
83+
) -> Option<Ty<'tcx>> {
84+
if trait_predicate.polarity != ty::PredicatePolarity::Positive {
85+
return None;
86+
}
87+
let trait_ref = trait_predicate.trait_ref;
88+
89+
// `args[0]` is `Self` / `T`. And when working with `Into<U>`, `AsRef<U>`, etc.
90+
// we typically want the first and only other generic argument - `U`.
91+
let ty1 = trait_ref.args.get(1).and_then(|generic_arg| generic_arg.as_type())?;
92+
93+
// `T: Into<U>` => `U`
94+
if tcx.is_diagnostic_item(sym::Into, trait_ref.def_id) {
95+
return Some(ty1);
96+
}
97+
98+
// TODO(b/281542952): Implement other replacements as needed.
99+
None
100+
}
101+
102+
/// Returns `true` if `new_ty` can be used as a replacement for `generic_param`
103+
/// in a generic item identified by `def_id` and constrained by the given `predicates`.
104+
fn is_valid_replacement_for_generic_type_param<'tcx>(
105+
tcx: TyCtxt<'tcx>,
106+
def_id: DefId,
107+
predicates: ty::GenericPredicates<'tcx>,
108+
generic_param: &ty::GenericParamDef,
109+
new_ty: Ty<'tcx>,
110+
) -> bool {
111+
let generic_args = ty::GenericArgs::for_item(tcx, def_id, |param_def, _old_generic_args| {
112+
if param_def.index == generic_param.index {
113+
new_ty.into()
114+
} else {
115+
tcx.mk_param_from_def(param_def)
116+
}
117+
});
118+
119+
let infcx = tcx.infer_ctxt().build(TypingMode::non_body_analysis());
120+
let ocx = ObligationCtxt::new(&infcx);
121+
let param_env = tcx.param_env(def_id);
122+
for (predicate, _span) in predicates.instantiate(tcx, generic_args) {
123+
let cause = ObligationCause::dummy();
124+
let predicate = ocx.normalize(&cause, param_env, predicate);
125+
ocx.register_obligation(Obligation::new(tcx, cause, param_env, predicate));
126+
}
127+
let errors = ocx.evaluate_obligations_error_on_ambiguity();
128+
errors.is_empty()
129+
}
130+
131+
/// Given a `generic_type_param` (e.g. `T` in `fn foo<T>(...)`) tries to find
132+
/// a non-generic type which can be used instead. For example, `T: Into<U>` may
133+
/// be potentially replaced with `U`, if `U` meets all the other `predicates`
134+
/// that may be constraining `T`. When multiple answers are possible, returns
135+
/// the first one.
136+
fn get_replacement_for_generic_type_param<'tcx>(
137+
tcx: TyCtxt<'tcx>,
138+
def_id: DefId,
139+
predicates: ty::GenericPredicates<'tcx>,
140+
generic_type_param: &ty::GenericParamDef,
141+
) -> Option<Ty<'tcx>> {
142+
// Look only at trait predicates involving this param (e.g. `T: SomeTrait`).
143+
let trait_predicates_for_this_generic_param = predicates
144+
.predicates
145+
.iter()
146+
.filter_map(|(clause, _)| match clause.kind().skip_binder() {
147+
ty::ClauseKind::Trait(trait_predicate) => Some(trait_predicate),
148+
_ => None,
149+
})
150+
.filter(|trait_predicate| match trait_predicate.trait_ref.self_ty().kind() {
151+
ty::Param(p) => p.index == generic_type_param.index,
152+
_ => false,
153+
});
154+
155+
// Find the first replacement that fits all the constraints.
156+
trait_predicates_for_this_generic_param
157+
.filter_map(|trait_predicate| get_replacement_for_trait_predicate(tcx, trait_predicate))
158+
.find(|new_ty| {
159+
is_valid_replacement_for_generic_type_param(
160+
tcx,
161+
def_id,
162+
predicates,
163+
generic_type_param,
164+
*new_ty,
165+
)
166+
})
167+
}
168+
169+
#[derive(Default)]
170+
struct GenericParamsFinder {
171+
generic_param_indices: HashSet<u32>,
172+
}
173+
174+
impl<'tcx> ty::TypeVisitor<TyCtxt<'tcx>> for GenericParamsFinder {
175+
fn visit_ty(&mut self, t: Ty<'tcx>) {
176+
if let ty::Param(p) = t.kind() {
177+
self.generic_param_indices.insert(p.index);
178+
}
179+
180+
// Visit nested types (e.g., `&T` or `&[T]`)
181+
use ty::TypeSuperVisitable;
182+
t.super_visit_with(self)
183+
}
184+
}

cc_bindings_from_rs/generate_bindings/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub mod generate_function;
2222
mod generate_function_thunk;
2323
mod generate_struct_and_union;
2424
mod generate_template_specialization;
25+
mod get_generic_args;
2526

2627
use crate::format_type::{
2728
crubit_abi_type_from_ty, ensure_ty_is_pointer_like, format_cc_ident, format_cc_ident_symbol,
@@ -230,6 +231,7 @@ pub fn new_database<'db>(
230231
generate_adt_core,
231232
crubit_abi_type_from_ty,
232233
from_trait_impls_by_argument,
234+
get_generic_args::get_generic_args,
233235
)
234236
}
235237

0 commit comments

Comments
 (0)