Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 37 additions & 19 deletions rs_bindings_from_cc/generate_bindings/database/code_snippet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,8 @@ pub fn generated_items_to_tokens<'db>(
owned_type_name,
member_methods,
lifetime_params,
is_thread_safe,
size,
} = record_item.as_ref();

let type_param_tokens = if !lifetime_params.is_empty() {
Expand All @@ -615,15 +617,37 @@ pub fn generated_items_to_tokens<'db>(
quote! { align(#align) }
}));

let head_padding = head_padding.map(|n| {
let n = Literal::usize_unsuffixed(n);
// TODO(b/481405536): Do this unconditionally.
if *internally_mutable_unknown_fields {
quote! { __non_field_data: [::core::cell::Cell<::core::mem::MaybeUninit<u8>>; #n], }
} else {
quote! { __non_field_data: [::core::mem::MaybeUninit<u8>; #n], }
// For thread-safe types, generate an opaque UnsafeCell body instead of
// individual fields. This enables interior mutability, allowing non-const
// C++ methods to be called through shared references (&self).
let struct_body = if *is_thread_safe {
let size_literal = Literal::usize_unsuffixed(*size);
quote! {
__opaque: ::core::cell::UnsafeCell<[::core::mem::MaybeUninit<u8>; #size_literal]>,
}
});
} else {
let head_padding = head_padding.map(|n| {
let n = Literal::usize_unsuffixed(n);
// TODO(b/481405536): Do this unconditionally.
if *internally_mutable_unknown_fields {
quote! { __non_field_data: [::core::cell::Cell<::core::mem::MaybeUninit<u8>>; #n], }
} else {
quote! { __non_field_data: [::core::mem::MaybeUninit<u8>; #n], }
}
});
let lifetime_markers: Vec<TokenStream> = lifetime_params
.iter()
.map(|lt| {
let field_name = format_ident!("__marker_{}", lt.ident);
quote! { #field_name: ::core::marker::PhantomData<& #lt ()> }
})
.collect();
quote! {
#head_padding
#( #field_definitions )*
#( #lifetime_markers )*
}
};

let send_impl = match implements_send {
true => {
Expand Down Expand Up @@ -681,14 +705,6 @@ pub fn generated_items_to_tokens<'db>(
None
};

let lifetime_markers: Vec<TokenStream> = lifetime_params
.iter()
.map(|lt| {
let field_name = format_ident!("__marker_{}", lt.ident);
quote! { #field_name: ::core::marker::PhantomData<& #lt ()> }
})
.collect();

quote! {
#doc_comment_attr
#derive_attr
Expand All @@ -697,9 +713,7 @@ pub fn generated_items_to_tokens<'db>(
#[repr(#(#repr_attrs),*)]
#crubit_annotation
#visibility #struct_or_union #ident #type_param_tokens {
#head_padding
#( #field_definitions )*
#( #lifetime_markers )*
#struct_body
}

#send_impl
Expand Down Expand Up @@ -1015,6 +1029,10 @@ pub struct Record {
pub owned_type_name: Option<Ident>,
pub member_methods: Vec<TokenStream>,
pub lifetime_params: Vec<syn::Lifetime>,
/// Whether this type is annotated as thread-safe (CRUBIT_THREAD_SAFE).
pub is_thread_safe: bool,
/// The size of the type in bytes (needed for opaque UnsafeCell wrapping).
pub size: usize,
}

#[derive(Clone, Debug)]
Expand Down
5 changes: 5 additions & 0 deletions rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ pub fn format_generic_params_replacing_by_self<'db, 'a>(
// Otherwise, these functions should be moved into a separate module.

pub fn should_derive_clone(record: &Record) -> bool {
// Thread-safe types wrap their fields in UnsafeCell<[MaybeUninit<u8>; N]>,
// which prevents them from deriving Clone.
if record.is_thread_safe {
return false;
}
match record.trait_derives.clone {
TraitImplPolarity::Positive => true,
TraitImplPolarity::Negative => false,
Expand Down
32 changes: 23 additions & 9 deletions rs_bindings_from_cc/generate_bindings/generate_struct_and_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,9 @@ pub fn generate_record(db: &BindingsGenerator, record: Rc<Record>) -> Result<Api
derive_attr: generate_derives(&record),
recursively_pinned_attr,
must_use_attr: record.nodiscard.clone().map(MustUseAttr),
align: if override_alignment && record.size_align.alignment > 1 {
// Thread-safe types always need explicit alignment because the opaque
// UnsafeCell<[MaybeUninit<u8>; N]> body has alignment 1.
align: if (override_alignment || record.is_thread_safe) && record.size_align.alignment > 1 {
Some(record.size_align.alignment)
} else {
None
Expand Down Expand Up @@ -745,6 +747,8 @@ pub fn generate_record(db: &BindingsGenerator, record: Rc<Record>) -> Result<Api
member_methods,
delete: operator_delete_impl,
lifetime_params,
is_thread_safe: record.is_thread_safe,
size: record.size_align.size,
};

api_snippets.features |= Feature::negative_impls;
Expand Down Expand Up @@ -776,14 +780,18 @@ pub fn generate_record(db: &BindingsGenerator, record: Rc<Record>) -> Result<Api

api_snippets.assertions.push(rs_size_align_assertions(qualified_ident, &record.size_align));
api_snippets.assertions.push(record_trait_assertions);
api_snippets.assertions.push(field_offset_assertions);
api_snippets.assertions.extend(fields_that_must_be_copy.into_iter().map(
|formatted_field_type| Assertion::Impls {
type_name: formatted_field_type,
all_of: AssertableTrait::Copy.into(),
none_of: FlagSet::empty(),
},
));
// For thread-safe types, the struct body is opaque (UnsafeCell), so the original
// field names don't exist. Skip field offset assertions and field copy assertions.
if !record.is_thread_safe {
api_snippets.assertions.push(field_offset_assertions);
api_snippets.assertions.extend(fields_that_must_be_copy.into_iter().map(
|formatted_field_type| Assertion::Impls {
type_name: formatted_field_type,
all_of: AssertableTrait::Copy.into(),
none_of: FlagSet::empty(),
},
));
}

api_snippets.generated_items.insert(record.id, GeneratedItem::Record(Box::new(record_tokens)));
Ok(api_snippets)
Expand Down Expand Up @@ -816,6 +824,12 @@ pub fn rs_size_align_assertions(type_name: TokenStream, size_align: &ir::SizeAli
}

pub fn generate_derives(record: &Record) -> DeriveAttr {
// Thread-safe types wrap their fields in UnsafeCell<[MaybeUninit<u8>; N]>.
// This opaque byte array doesn't support useful standard derives, and Clone/Copy
// are explicitly prevented to support interior mutability anyway.
if record.is_thread_safe {
return DeriveAttr(vec![]);
}
let mut derives = vec![];
if should_derive_clone(record) {
derives.push(quote! { Clone });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1943,3 +1943,77 @@ fn test_display() -> Result<()> {
);
Ok(())
}

#[gtest]
fn test_thread_safe_annotation_generates_send_sync() -> Result<()> {
let ir = ir_from_cc(
r#"
struct [[clang::annotate("crubit_thread_safe")]] ThreadSafeStruct final {
int field;
};
"#,
)?;

let rs_api = generate_bindings_tokens_for_test(ir)?.rs_api;

// Thread-safe types should get `unsafe impl Send` and `unsafe impl Sync`.
assert_rs_matches!(rs_api, quote! { unsafe impl Send for ThreadSafeStruct {} });
assert_rs_matches!(rs_api, quote! { unsafe impl Sync for ThreadSafeStruct {} });

Ok(())
}

#[gtest]
fn test_thread_safe_annotation_generates_unsafe_cell_body() -> Result<()> {
let ir = ir_from_cc(
r#"
struct [[clang::annotate("crubit_thread_safe")]] ThreadSafeStruct final {
int field;
};
"#,
)?;

let rs_api = generate_bindings_tokens_for_test(ir)?.rs_api;

// Thread-safe types should have an opaque UnsafeCell body instead of individual fields.
assert_rs_matches!(
rs_api,
quote! {
pub struct ThreadSafeStruct {
__opaque: ::core::cell::UnsafeCell<[::core::mem::MaybeUninit<u8>; 4]>,
}
}
);

// Individual fields should NOT appear.
assert_rs_not_matches!(rs_api, quote! { pub field });

Ok(())
}

#[gtest]
fn test_non_thread_safe_struct_has_negative_send_sync() -> Result<()> {
let ir = ir_from_cc(
r#"
struct RegularStruct final {
int field;
};
"#,
)?;

let rs_api = generate_bindings_tokens_for_test(ir)?.rs_api;

// Non-thread-safe types should NOT have `unsafe impl Send/Sync`.
assert_rs_not_matches!(rs_api, quote! { unsafe impl Send for RegularStruct {} });
assert_rs_not_matches!(rs_api, quote! { unsafe impl Sync for RegularStruct {} });

// They should have negative impls instead.
assert_rs_matches!(rs_api, quote! { impl !Send for RegularStruct {} });
assert_rs_matches!(rs_api, quote! { impl !Sync for RegularStruct {} });

// And should have normal fields, not UnsafeCell wrapping.
assert_rs_matches!(rs_api, quote! { pub field: ::ffi_11::c_int });
assert_rs_not_matches!(rs_api, quote! { __opaque });

Ok(())
}
13 changes: 13 additions & 0 deletions rs_bindings_from_cc/importers/cxx_record.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1121,12 +1121,24 @@ std::optional<IR::Item> CXXRecordDeclImporter::Import(
const clang::TypedefNameDecl* anon_typedef =
record_decl->getTypedefNameForAnonDecl();

absl::StatusOr<bool> is_thread_safe =
HasAnnotationWithoutArgs(*record_decl, "crubit_thread_safe");
if (!is_thread_safe.ok()) {
return unsupported(
FormattedError::FromStatus(std::move(is_thread_safe).status()));
}

absl::StatusOr<TraitDerives> trait_derives = GetTraitDerives(*record_decl);
if (!trait_derives.ok()) {
return unsupported(
FormattedError::FromStatus(std::move(trait_derives).status()));
}

if (*is_thread_safe) {
trait_derives->send = true;
trait_derives->sync = true;
}

absl::StatusOr<SafetyAnnotation> safety_annotation =
GetSafetyAnnotation(*record_decl);
if (!safety_annotation.ok()) {
Expand Down Expand Up @@ -1203,6 +1215,7 @@ std::optional<IR::Item> CXXRecordDeclImporter::Import(
.enclosing_item_id = std::move(enclosing_item_id),
.overloads_operator_delete = MayOverloadOperatorDelete(*record_decl),
.detected_formatter = *detected_formatter,
.is_thread_safe = *is_thread_safe,
.lifetime_inputs = std::move(lifetime_inputs),
};

Expand Down
1 change: 1 addition & 0 deletions rs_bindings_from_cc/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,7 @@ llvm::json::Value Record::ToJson() const {
{"must_bind", must_bind},
{"overloads_operator_delete", overloads_operator_delete},
{"detected_formatter", detected_formatter},
{"is_thread_safe", is_thread_safe},
};

if (!lifetime_inputs.empty()) {
Expand Down
5 changes: 5 additions & 0 deletions rs_bindings_from_cc/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,11 @@ struct Record {
bool overloads_operator_delete = false;
bool detected_formatter = false;

// Whether this type is annotated as thread-safe (CRUBIT_THREAD_SAFE).
// Thread-safe types implement Send+Sync and wrap their internals in
// UnsafeCell, allowing non-const C++ methods to be called via &self.
bool is_thread_safe = false;

// Lifetime variable names bound by this record.
std::vector<std::string> lifetime_inputs;

Expand Down
8 changes: 8 additions & 0 deletions rs_bindings_from_cc/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1214,6 +1214,9 @@ pub struct Record {
/// string is used.
#[serde(default)]
pub deprecated: Option<Rc<str>>,
/// Whether this type is annotated as thread-safe (CRUBIT_THREAD_SAFE).
#[serde(default)]
pub is_thread_safe: bool,
}

impl GenericItem for Record {
Expand Down Expand Up @@ -1299,6 +1302,11 @@ impl Record {
}

pub fn should_derive_copy(&self) -> bool {
// Thread-safe types wrap their fields in UnsafeCell<[MaybeUninit<u8>; N]>,
// which prevents them from deriving Copy.
if self.is_thread_safe {
return false;
}
match self.trait_derives.copy {
TraitImplPolarity::Positive => true,
TraitImplPolarity::Negative => false,
Expand Down
43 changes: 43 additions & 0 deletions rs_bindings_from_cc/ir_from_cc_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,49 @@ fn test_conflicting_unsafe_annotation() {
);
}

#[gtest]
fn test_struct_with_thread_safe_annotation() {
let ir = ir_from_cc(
r#"
struct [[clang::annotate("crubit_thread_safe")]]
ThreadSafeType {
int foo;
};"#,
)
.unwrap();

assert_ir_matches!(
ir,
quote! {
Record {
rs_name: "ThreadSafeType", ...
is_thread_safe: true, ...
}
}
);
}

#[gtest]
fn test_struct_without_thread_safe_annotation() {
let ir = ir_from_cc(
r#"
struct NotThreadSafe {
int foo;
};"#,
)
.unwrap();

assert_ir_matches!(
ir,
quote! {
Record {
rs_name: "NotThreadSafe", ...
is_thread_safe: false, ...
}
}
);
}

#[gtest]
fn test_struct_with_unnamed_struct_and_union_members() {
// This test input causes `field_decl->getName()` to return an empty string.
Expand Down
27 changes: 27 additions & 0 deletions rs_bindings_from_cc/test/annotations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,30 @@ crubit_rust_test(
"@crate_index//:googletest",
],
)

crubit_test_cc_library(
name = "thread_safe",
hdrs = ["thread_safe.h"],
deps = [
"//support:annotations",
],
)

crubit_rust_test(
name = "thread_safe_test",
srcs = ["thread_safe_test.rs"],
cc_deps = [
":thread_safe",
],
deps = [
"@crate_index//:googletest",
],
)

golden_test(
name = "thread_safe_golden_test",
basename = "thread_safe",
cc_library = "thread_safe",
golden_cc = "thread_safe_api_impl.cc",
golden_rs = "thread_safe_rs_api.rs",
)
Loading