diff --git a/rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs b/rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs index 9b5ecd39c..47658ebbb 100644 --- a/rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs +++ b/rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs @@ -298,6 +298,10 @@ pub enum UniformReprTemplateType { StdUniquePtr { element_type: RsTypeKind, }, + /// std::shared_ptr + StdSharedPtr { + element_type: RsTypeKind, + }, AbslSpan { is_const: bool, include_lifetime: bool, @@ -338,6 +342,18 @@ impl UniformReprTemplateType { Ok(arg_type_kind) }; match template_specialization_kind { + Some(TemplateSpecializationKind::StdSharedPtr { element_type }) => { + let element_type_kind = type_arg(element_type)?; + ensure!(element_type.is_const, "b/485328340: Crubit does not yet support std::shared_ptr, got: {}", element_type_kind.display(db)); + ensure!( + element_type_kind.is_complete(), + "Crubit does not support std::shared_ptr, got: `{}`", + element_type_kind.display(db) + ); + Ok(Some(Rc::new(UniformReprTemplateType::StdSharedPtr { + element_type: element_type_kind, + }))) + } Some(TemplateSpecializationKind::StdUniquePtr { element_type }) => { let element_type = type_arg(element_type)?; ensure!(element_type.is_complete(), "Rust std::unique_ptr cannot be used with incomplete types, and `{}` is incomplete", element_type.display(db)); @@ -401,6 +417,10 @@ impl UniformReprTemplateType { quote! { ::cc_std::std::unique_ptr::<#element_type_tokens> } } } + Self::StdSharedPtr { element_type } => { + let element_type_tokens = element_type.to_token_stream(db); + quote! { ::cc_std::std::shared_ptr_const::<#element_type_tokens> } + } Self::AbslSpan { is_const, include_lifetime, element_type } => { let element_type_tokens = element_type.to_token_stream(db); @@ -429,6 +449,7 @@ impl UniformReprTemplateType { match self { Self::StdVector { .. } => None, Self::StdUniquePtr { .. } => None, + Self::StdSharedPtr { .. } => None, Self::AbslSpan { include_lifetime: true, .. } => Some(Lifetime::elided()), Self::AbslSpan { include_lifetime: false, .. } => None, Self::StdStringView { lifetime, .. } => Some(lifetime.clone()), diff --git a/rs_bindings_from_cc/importers/cxx_record.cc b/rs_bindings_from_cc/importers/cxx_record.cc index 7b59e4646..0c92334c9 100644 --- a/rs_bindings_from_cc/importers/cxx_record.cc +++ b/rs_bindings_from_cc/importers/cxx_record.cc @@ -309,9 +309,7 @@ std::optional GetBridgeTypeAnnotation( for (const clang::TemplateArgument& template_arg : specialization_decl->getTemplateArgs().asArray()) { if (template_arg.getKind() == clang::TemplateArgument::ArgKind::Type) { - // TODO(b/454627672): is record_decl the right decl to check for - // assumed_lifetimes? - template_args.emplace_back( + template_args.push_back( ictx.ConvertQualType(template_arg.getAsType(), /*lifetimes=*/nullptr, /*nullable=*/true, ictx.AreAssumedLifetimesEnabledForTarget( @@ -601,6 +599,18 @@ absl::StatusOr GetTemplateSpecializationKind( ictx.ConvertQualType(t, /*lifetimes=*/nullptr, /*nullable=*/true, ictx.AreAssumedLifetimesEnabledForTarget( ictx.GetOwningTarget(specialization_decl)))); + } else if (templated_decl->getName() == "shared_ptr") { + if (specialization_decl->getTemplateArgs().size() != 1) { + return absl::InvalidArgumentError( + "std::shared_ptr must have exactly one template argument"); + } + clang::QualType t = specialization_decl->getTemplateArgs()[0].getAsType(); + return TemplateSpecialization::StdSharedPtr( + // TODO(b/454627672): is specialization_decl the right decl to check + // for assumed_lifetimes? + ictx.ConvertQualType(t, /*lifetimes=*/nullptr, /*nullable=*/true, + ictx.AreAssumedLifetimesEnabledForTarget( + ictx.GetOwningTarget(specialization_decl)))); } else if (templated_decl->getName() == "vector") { CRUBIT_ASSIGN_OR_RETURN( clang::QualType t, diff --git a/rs_bindings_from_cc/ir.cc b/rs_bindings_from_cc/ir.cc index 2a9e7c653..d84718b0e 100644 --- a/rs_bindings_from_cc/ir.cc +++ b/rs_bindings_from_cc/ir.cc @@ -619,6 +619,12 @@ llvm::json::Value TemplateSpecialization::ToJson() const { llvm::json::Object{ {"element_type", std_vector.element_type}}}}; }, + [&](const StdSharedPtr& std_shared_ptr) { + return llvm::json::Object{ + {"StdSharedPtr", + llvm::json::Object{ + {"element_type", std_shared_ptr.element_type}}}}; + }, [&](const StdUniquePtr& std_unique_ptr) { return llvm::json::Object{ {"StdUniquePtr", diff --git a/rs_bindings_from_cc/ir.h b/rs_bindings_from_cc/ir.h index a0ba4932a..36247ff23 100644 --- a/rs_bindings_from_cc/ir.h +++ b/rs_bindings_from_cc/ir.h @@ -677,6 +677,9 @@ struct TemplateSpecialization { struct StdVector { CcType element_type; }; + struct StdSharedPtr { + CcType element_type; + }; struct StdUniquePtr { CcType element_type; }; @@ -688,8 +691,9 @@ struct TemplateSpecialization { }; struct NonSpecial {}; - using Kind = std::variant; + using Kind = + std::variant; BazelLabel defining_target; Kind kind = NonSpecial{}; diff --git a/rs_bindings_from_cc/ir.rs b/rs_bindings_from_cc/ir.rs index 1a13cf4c4..66cf6f1cb 100644 --- a/rs_bindings_from_cc/ir.rs +++ b/rs_bindings_from_cc/ir.rs @@ -1122,13 +1122,24 @@ pub enum TemplateSpecializationKind { /// std::basic_string_view> StdWStringView, /// std::vector> - StdVector { element_type: CcType }, + StdVector { + element_type: CcType, + }, /// std::unique_ptr> - StdUniquePtr { element_type: CcType }, + StdSharedPtr { + element_type: CcType, + }, + StdUniquePtr { + element_type: CcType, + }, /// c9::Co - C9Co { element_type: CcType }, + C9Co { + element_type: CcType, + }, /// absl::Span - AbslSpan { element_type: CcType }, + AbslSpan { + element_type: CcType, + }, /// Some other template specialization. NonSpecial, } diff --git a/support/cc_std_impl/BUILD b/support/cc_std_impl/BUILD index ca9845835..d0fa17f81 100644 --- a/support/cc_std_impl/BUILD +++ b/support/cc_std_impl/BUILD @@ -54,6 +54,7 @@ additional_rust_srcs_for_crubit_bindings( additional_rust_srcs_for_crubit_bindings( name = "manually_bridged_types", srcs = [ + "shared_ptr.rs", "string.rs", "string_view.rs", "unique_ptr.rs", diff --git a/support/cc_std_impl/shared_ptr.rs b/support/cc_std_impl/shared_ptr.rs new file mode 100644 index 000000000..69c91c48c --- /dev/null +++ b/support/cc_std_impl/shared_ptr.rs @@ -0,0 +1,47 @@ +// Part of the Crubit project, under the Apache License v2.0 with LLVM +// Exceptions. See /LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +use crate::crubit_cc_std_internal::std_allocator; + +#[allow(non_snake_case)] +#[repr(C)] +pub struct shared_ptr_const { + ptr: *const T, + cntrl: *mut core::ffi::c_void, +} + +impl shared_ptr_const { + pub unsafe fn from_raw_parts(ptr: *const T, cntrl: *mut core::ffi::c_void) -> Self { + Self { ptr, cntrl } + } + + pub fn get(&self) -> *const T { + self.ptr + } + + pub fn as_ref(&self) -> Option<&T> { + unsafe { self.ptr.as_ref() } + } +} + +impl Clone for shared_ptr_const { + fn clone(&self) -> Self { + if !self.cntrl.is_null() { + unsafe { + std_allocator::shared_ptr_add_shared(self.cntrl); + } + } + Self { ptr: self.ptr, cntrl: self.cntrl } + } +} + +impl Drop for shared_ptr_const { + fn drop(&mut self) { + if !self.cntrl.is_null() { + unsafe { + std_allocator::shared_ptr_release_shared(self.cntrl); + } + } + } +} diff --git a/support/cc_std_impl/std_allocator.h b/support/cc_std_impl/std_allocator.h index d842089e4..37424c2be 100644 --- a/support/cc_std_impl/std_allocator.h +++ b/support/cc_std_impl/std_allocator.h @@ -5,6 +5,7 @@ #ifndef THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_STD_ALLOCATOR_H_ #define THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_STD_ALLOCATOR_H_ +#include <__memory/shared_count.h> #include #include @@ -40,6 +41,16 @@ inline void cpp_delete(void* ptr, size_t n, size_t align) { #endif } +inline void shared_ptr_add_shared(void* cntrl) { + auto* weak_count = static_cast(cntrl); + weak_count->__add_shared(); +} + +inline void shared_ptr_release_shared(void* cntrl) { + auto* weak_count = static_cast(cntrl); + weak_count->__release_shared(); +} + } // namespace crubit_cc_std_internal::std_allocator #endif // THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_STD_ALLOCATOR_H_ diff --git a/support/cc_std_impl/test/shared_ptr/BUILD b/support/cc_std_impl/test/shared_ptr/BUILD new file mode 100644 index 000000000..8581b0151 --- /dev/null +++ b/support/cc_std_impl/test/shared_ptr/BUILD @@ -0,0 +1,28 @@ +# Part of the Crubit project, under the Apache License v2.0 with LLVM +# Exceptions. See /LICENSE for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//common:crubit_wrapper_macros_oss.bzl", "crubit_rust_test") +load("//rs_bindings_from_cc/test:test_bindings.bzl", "crubit_test_cc_library") + +package(default_applicable_licenses = ["//:license"]) + +crubit_test_cc_library( + name = "test_helpers", + testonly = 1, + hdrs = ["test_helpers.h"], + aspect_hints = ["//features:experimental"], + deps = ["//support:annotations"], +) + +crubit_rust_test( + name = "shared_ptr_test", + srcs = ["shared_ptr_test.rs"], + cc_deps = [ + ":test_helpers", + "//support/public:cc_std", + ], + deps = [ + "@crate_index//:googletest", + ], +) diff --git a/support/cc_std_impl/test/shared_ptr/shared_ptr_test.rs b/support/cc_std_impl/test/shared_ptr/shared_ptr_test.rs new file mode 100644 index 000000000..540480c8d --- /dev/null +++ b/support/cc_std_impl/test/shared_ptr/shared_ptr_test.rs @@ -0,0 +1,37 @@ +// Part of the Crubit project, under the Apache License v2.0 with LLVM +// Exceptions. See /LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +use cc_std::std::shared_ptr_const; +use googletest::prelude::*; + +#[gtest] +fn test_layout() { + // Testing that the layout matches C++ shared_ptr + assert_eq!( + core::mem::size_of::>(), + test_helpers::shared_ptr_test::get_shared_ptr_size() as usize + ); + assert_eq!( + core::mem::align_of::>(), + test_helpers::shared_ptr_test::get_shared_ptr_alignment() as usize + ); +} + +#[gtest] +fn test_polymorphic_destructor() { + // Initial destructor count should be 0. + assert_eq!(test_helpers::shared_ptr_test::get_derived_destructor_count(), 0); + { + let _shared_base = test_helpers::shared_ptr_test::create_virtual_base(); + assert_eq!(core::mem::size_of_val(&_shared_base), 16); + } + // After dropping shared_ptr which points to Derived, count should be 1. + assert_eq!(test_helpers::shared_ptr_test::get_derived_destructor_count(), 1); +} + +#[gtest] +fn test_deref() { + let shared = test_helpers::shared_ptr_test::create_shared_ptr(); + assert_eq!(*shared.as_ref().unwrap(), 1); +} diff --git a/support/cc_std_impl/test/shared_ptr/test_helpers.h b/support/cc_std_impl/test/shared_ptr/test_helpers.h new file mode 100644 index 000000000..e8487cdea --- /dev/null +++ b/support/cc_std_impl/test/shared_ptr/test_helpers.h @@ -0,0 +1,89 @@ +// Part of the Crubit project, under the Apache License v2.0 with LLVM +// Exceptions. See /LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#ifndef THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_IMPL_TEST_SHARED_PTR_TEST_HELPERS_H_ +#define THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_IMPL_TEST_SHARED_PTR_TEST_HELPERS_H_ + +#include + +#include "support/annotations.h" + +namespace shared_ptr_test { + +struct TwoWords { + void* ptr1; + void* ptr2; +}; + +CRUBIT_MUST_BIND inline size_t get_shared_ptr_size() { + return sizeof(std::shared_ptr); +} + +CRUBIT_MUST_BIND inline size_t get_shared_ptr_alignment() { + return alignof(std::shared_ptr); +} + +CRUBIT_MUST_BIND inline std::shared_ptr create_shared_ptr() { + return std::make_shared(1); +} +CRUBIT_MUST_BIND inline void destroy_shared_ptr(std::shared_ptr) {} +CRUBIT_MUST_BIND inline std::shared_ptr create_shared_ptr_char() { + return std::make_shared('a'); +} +CRUBIT_MUST_BIND inline std::shared_ptr create_shared_ptr_short() { + return std::make_shared(static_cast(1)); +} +CRUBIT_MUST_BIND inline std::shared_ptr +create_shared_ptr_void_ptr() { + return std::shared_ptr(nullptr); +} +CRUBIT_MUST_BIND inline std::shared_ptr +create_shared_ptr_two_words() { + return std::make_shared(); +} + +// Since shared_ptr uses a control block for type erasure, we can get a pointer +// to the base class and the control block will correctly destroy the derived +// type. +struct Base { + virtual ~Base() = default; + static inline int derived_destructor_count = 0; + virtual bool is_derived() const { return false; } +}; +struct Derived : public Base { + ~Derived() override { derived_destructor_count++; } + bool is_derived() const override { return true; } +}; + +CRUBIT_MUST_BIND inline std::shared_ptr create_virtual_base() { + return std::make_shared(); +} + +CRUBIT_MUST_BIND inline int get_derived_destructor_count() { + return Base::derived_destructor_count; +} + +// Using a custom deleter via shared_ptr constructors. +struct CustomDelete { + static inline int custom_delete_count = 0; +}; +struct CustomDeleter { + void operator()(const CustomDelete* p) const { + CustomDelete::custom_delete_count++; + delete p; + } +}; + +CRUBIT_MUST_BIND inline std::shared_ptr +create_custom_delete() { + return std::shared_ptr(new CustomDelete(), + CustomDeleter()); +} + +CRUBIT_MUST_BIND inline int get_custom_delete_count() { + return CustomDelete::custom_delete_count; +} + +} // namespace shared_ptr_test + +#endif // THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_IMPL_TEST_SHARED_PTR_TEST_HELPERS_H_