Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion cc_bindings_from_rs/generate_bindings/database/code_snippet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,24 @@ pub struct CcPrerequisites<'tcx> {
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum TemplateSpecialization<'tcx> {
RsStdEnum(RsStdEnumTemplateSpecialization<'tcx>),
TraitImpl(TraitImplTemplateSpecialization),
}

#[derive(Clone, Debug)]
pub struct TraitImplTemplateSpecialization {
pub self_ty_cc_name: TokenStream,
pub trait_impl: DefId,
}
impl PartialEq for TraitImplTemplateSpecialization {
fn eq(&self, other: &Self) -> bool {
self.trait_impl == other.trait_impl
}
}
impl Eq for TraitImplTemplateSpecialization {}
impl Hash for TraitImplTemplateSpecialization {
fn hash<H: Hasher>(&self, state: &mut H) {
self.trait_impl.hash(state);
}
}

#[derive(Clone, Debug)]
Expand All @@ -92,11 +110,19 @@ pub struct RsStdEnumTemplateSpecializationCore<'tcx> {
pub tag_type_cc: CcSnippet<'tcx>,
}

#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct RsStdEnumTemplateSpecialization<'tcx> {
pub core: RsStdEnumTemplateSpecializationCore<'tcx>,
pub kind: TemplateSpecializationKind<'tcx>,
}
impl std::fmt::Debug for RsStdEnumTemplateSpecialization<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RsStdEnumTemplateSpecialization")
.field("core", &self.core.self_ty_rs)
.field("kind", &self.kind)
.finish()
}
}

impl PartialEq for RsStdEnumTemplateSpecialization<'_> {
fn eq(&self, other: &Self) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

use crate::generate_function_thunk::replace_all_regions_with_static;
use crate::generate_struct_and_union::{generate_relocating_ctor, scalar_value_to_string};
use arc_anyhow::Result;
use crate::generate_struct_and_union::{
generate_associated_item, generate_relocating_ctor, has_type_or_const_vars,
scalar_value_to_string,
};
use crate::generate_unsupported_def;
use arc_anyhow::{Error, Result};
use code_gen_utils::{escape_non_identifier_chars, CcInclude};
use database::code_snippet::{
ApiSnippets, CcPrerequisites, CcSnippet, FormattedTy, RsStdEnumTemplateSpecialization,
RsStdEnumTemplateSpecializationCore, TemplateSpecialization,
RsStdEnumTemplateSpecializationCore, TemplateSpecialization, TraitImplTemplateSpecialization,
};
use database::{BindingsGenerator, TypeLocation};
use error_report::anyhow;
Expand All @@ -19,6 +23,8 @@ use quote::{format_ident, quote};
use rustc_abi::VariantIdx;
use rustc_middle::ty::layout::PrimitiveExt;
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
use rustc_span::def_id::DefId;
use std::collections::HashSet;
use std::rc::Rc;

fn is_cpp_movable<'tcx>(db: &BindingsGenerator<'tcx>, ty: Ty<'tcx>) -> bool {
Expand Down Expand Up @@ -1142,13 +1148,167 @@ impl<'tcx> TemplateSpecializationExt<'tcx> for RsStdEnumTemplateSpecialization<'
}
}

/// Collect trait implementations and map them to `TemplateSpecialization::TraitImpl`.
pub(crate) fn collect_trait_impls<'a, 'tcx>(
db: &'a BindingsGenerator<'tcx>,
) -> impl Iterator<Item = TemplateSpecialization<'tcx>> + use<'a, 'tcx> {
let tcx = db.tcx();
let supported_traits: Vec<DefId> = db.supported_traits().iter().copied().collect();
// TyCtxt makes it easy to get all the implementations of a trait, but there isn't an easy way
// to say give me all the trait implementations for this type. This is by design. The compiler
// lazily determines conformance to traits as needed for types and never computes every trait
// for a type in a single data structure.
//
// We, however, want every implementation for a supported type, so we can emit bindings to them.
// We achieve this by walking every supported trait, walking every implementation of that trait,
// and paring down to the implementations that receive bindings.
//
// A serendipitous side effect of this approach is that our implementations are emitted as a
// single list containing just implementations. We want to emit all of our implementations in a
// separate portion of our header from the rest of our bindings. Our impls become template
// specializaitons, which are required to be in an enclosing namespace of the template they
// specialize. This prevents them from living in the same namespace as our other bindings, as
// they may implement a trait that is not enclosed by that namespace.
supported_traits.into_iter().flat_map(move |trait_def_id| {
use rustc_middle::ty::fast_reject::SimplifiedType;
tcx.trait_impls_of(trait_def_id)
.non_blanket_impls()
.into_iter()
.filter_map(move |(simple_ty, impl_def_ids)| match simple_ty {
SimplifiedType::Adt(did) => {
// Only bind implementations for supported ADTs.
let canonical_name = db.symbol_canonical_name(*did)?;
// We explicitly want to allow ADTs that specify cpp_type.
// These are typically C++ types that have generated Rust bindings.
if canonical_name.unqualified.cpp_type.is_none()
&& db.adt_needs_bindings(*did).is_err()
{
return None;
}
let crate_name = tcx.crate_name(did.krate);
// TODO: b/391443811 - Add support for implementations of stdlib types once
// we have headers that can be included for those types.
if ["std", "core", "alloc"].contains(&crate_name.as_str()) {
return None;
}
let adt_cc_name = canonical_name.format_for_cc(db).ok()?;
Some((adt_cc_name, impl_def_ids))
}
// TODO: b/457803426 - Support trait implementations for non-adt types.
_ => None,
})
.flat_map(move |(adt_cc_name, impl_def_ids)| {
impl_def_ids
.iter()
.copied()
// TODO: b/458768435 - This is technically suboptimal. We could instead only
// query for the impls from this crate, but the logic is complicated by
// supporting LOCAL_CRATE. Revisit once we've migrated to rmetas.
.filter(|impl_def_id| impl_def_id.krate == db.source_crate_num())
.map(move |impl_def_id| {
TemplateSpecialization::TraitImpl(TraitImplTemplateSpecialization {
self_ty_cc_name: adt_cc_name.clone(),
trait_impl: impl_def_id,
})
})
})
})
}

fn generate_trait_impl_specialization<'tcx>(
db: &BindingsGenerator<'tcx>,
trait_impl: &TraitImplTemplateSpecialization,
) -> std::result::Result<ApiSnippets<'tcx>, (DefId, Error)> {
let tcx = db.tcx();
let impl_def_id = trait_impl.trait_impl;
let trait_header = tcx.impl_trait_header(impl_def_id);
#[rustversion::before(2025-10-17)]
let trait_header = trait_header.expect("Trait impl should have a trait header");
let trait_ref = trait_header.trait_ref.instantiate_identity();
let trait_def_id = trait_ref.def_id;

let canonical_trait_name = db.symbol_canonical_name(trait_def_id).expect(
"symbol_canonical_name was unexpectedly called on a trait without a canonical name",
);
let trait_name = canonical_trait_name.format_for_cc(db).map_err(|err| (impl_def_id, err))?;

let mut prereqs = CcPrerequisites::default();
let trait_args: Vec<_> = trait_ref
.args
.iter()
// Skip self type.
.skip(1)
.filter_map(|arg| arg.as_type())
.map(|arg| {
if arg.flags().contains(has_type_or_const_vars()) {
return Err((impl_def_id, anyhow!("Implementation of traits must specify all types to receive bindings.")));
}
if arg.walk().any(|arg| arg.as_type().is_some_and(|ty| ty.is_ptr_sized_integral())) {
return Err((
impl_def_id,
anyhow!(
"b/491106325 - isize and usize types are not yet supported as trait type arguments."
),
));
}
db.format_ty_for_cc(arg, TypeLocation::Other)
.map(|snippet| snippet.into_tokens(&mut prereqs))
.map_err(|err| (impl_def_id, err))
})
.collect::<std::result::Result<Vec<_>, _>>()?;

let type_args = if trait_args.is_empty() {
quote! {}
} else {
quote! { <#(#trait_args),*> }
};

let trait_name_with_args = quote! { #trait_name #type_args };

prereqs.depend_on_def(db, trait_def_id).map_err(|err| (impl_def_id, err))?;

let mut member_function_names = HashSet::new();
let assoc_items: ApiSnippets = tcx
.associated_items(impl_def_id)
.in_definition_order()
.flat_map(|assoc_item| generate_associated_item(db, assoc_item, &mut member_function_names))
.collect();

let main_api = assoc_items.main_api.into_tokens(&mut prereqs);
prereqs.includes.insert(db.support_header("rs_std/traits.h"));

let self_ty_cc_name = &trait_impl.self_ty_cc_name;
Ok(ApiSnippets {
main_api: CcSnippet {
tokens: quote! {
__NEWLINE__
template<>
struct rs_std::impl<#self_ty_cc_name, #trait_name_with_args> {
static constexpr bool kIsImplemented = true;

#main_api
};
__NEWLINE__
},
prereqs,
},
cc_details: assoc_items.cc_details,
rs_details: assoc_items.rs_details,
})
}

/// Generate a template specialization.
pub fn generate_template_specialization<'tcx>(
db: &BindingsGenerator<'tcx>,
specialization: TemplateSpecialization<'tcx>,
) -> ApiSnippets<'tcx> {
let mut snippets = match &specialization {
TemplateSpecialization::RsStdEnum(rs_std_enum) => rs_std_enum.clone().api_snippets(db),
TemplateSpecialization::TraitImpl(trait_impl) => {
generate_trait_impl_specialization(db, trait_impl).unwrap_or_else(|(def_id, err)| {
generate_unsupported_def(db, def_id, err).into_main_api()
})
}
};
// Because we reuse logic from generate_struct_and_union here, we will add our `self_ty` as a template specialization of it's own specialization creating a depedency cycle.
// We break that loop manually here to avoid that.
Expand Down
Loading