Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,59 @@
use crate::util::macros::parse_macro_input2;
use proc_macro2::TokenStream as MacroStream;

// TODO: move
pub fn inject_send_sync_static(generics: &mut syn::Generics) {
use syn::{Lifetime, Path, TraitBound, TraitBoundModifier, TypeParamBound, parse_quote};
fn path_is_ident(path: &Path, name: &str) -> bool {
path.segments.len() == 1 && path.segments[0].ident == name
}
for tp in generics.type_params_mut() {
// Scan existing bounds so we don't duplicate them.
let mut has_send = false;
let mut has_sync = false;
let mut has_static = false;

for b in &tp.bounds {
match b {
TypeParamBound::Trait(TraitBound {
modifier: TraitBoundModifier::None,
path,
..
}) => {
if path_is_ident(path, "Send") {
has_send = true;
}
if path_is_ident(path, "Sync") {
has_sync = true;
}
}
TypeParamBound::Trait(TraitBound {
modifier: TraitBoundModifier::Maybe(_),
..
}) => {
// e.g. ?Sized — ignore
}
TypeParamBound::Lifetime(lt) => {
if lt == &Lifetime::new("'static", lt.apostrophe) {
has_static = true;
}
}
_ => {}
}
}

if !has_send {
tp.bounds.push(parse_quote!(::core::marker::Send));
}
if !has_sync {
tp.bounds.push(parse_quote!(::core::marker::Sync));
}
if !has_static {
tp.bounds.push(parse_quote!('static));
}
}
}

pub fn expand_derive_auto_plugin(input: MacroStream) -> MacroStream {
use crate::macro_api::derives::auto_plugin::AutoPluginDeriveArgs;
use darling::FromDeriveInput;
Expand All @@ -9,10 +62,15 @@ pub fn expand_derive_auto_plugin(input: MacroStream) -> MacroStream {
use syn::DeriveInput;

let derive_input = parse_macro_input2!(input as DeriveInput);
let params = match AutoPluginDeriveArgs::from_derive_input(&derive_input) {
Ok(params) => params,
Err(err) => return err.write_errors(),
let params = {
let mut params = match AutoPluginDeriveArgs::from_derive_input(&derive_input) {
Ok(params) => params,
Err(err) => return err.write_errors(),
};
inject_send_sync_static(&mut params.generics);
params
};

let ident = &params.ident; // `Test`
let generics = &params.generics; // `<T1, T2>`
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
Expand Down
11 changes: 6 additions & 5 deletions tests/e2e/auto_plugin_with_generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ use std::hash::Hash;
use std::ops::{AddAssign, Deref};

#[derive(AutoPlugin, Default)]
#[auto_plugin(impl_generic_plugin_trait, impl_generic_auto_plugin_trait)]
struct Test<T1, T2>(T1, T2)
where
T1: Default + Send + Sync + 'static,
T2: Default + Send + Sync + 'static;
#[auto_plugin(
impl_generic_plugin_trait,
impl_generic_auto_plugin_trait,
generics(usize, bool)
)]
struct Test<T1, T2>(T1, T2);

#[derive(Component, Default, Reflect)]
#[reflect(Component)]
Expand Down