From 22d915aecf871c55128efb22f2b789026813bd3c Mon Sep 17 00:00:00 2001 From: Devin Jeanpierre Date: Wed, 18 Mar 2026 10:22:52 -0700 Subject: [PATCH] Crubit support for `shared_ptr`. Unlike unique_ptr, I think we don't need a separate virtual_shared_ptr, because the destructor is type-erased. But the other consequence of this is that, dang, it's pretty hard to create one from pure Rust. For now, there's no constructors at all, just the ability to consume and pass around a shared_ptr. (The code for non-const T would be similar.) The CL was _nearly_ one-shot with Gemini, but it overdid it in a few places and I've pared it back to what I think the CL should be. In particular, the invasive use of internal details of libc++ is _entirely_ my idea. The fact that we're relying on shared_ptr having the layout we expect is already digging into the implementation details: we need careful deployment/etc. of this with collaboration of libc++ deployment owners for this to be kept working. In the google monorepo, we got it, but probably this will need to be turned off outside Google... PiperOrigin-RevId: 885665170 --- .../generate_bindings/database/rs_snippet.rs | 21 +++++ rs_bindings_from_cc/importers/cxx_record.cc | 16 +++- rs_bindings_from_cc/ir.cc | 6 ++ rs_bindings_from_cc/ir.h | 8 +- rs_bindings_from_cc/ir.rs | 19 +++- support/cc_std_impl/BUILD | 1 + support/cc_std_impl/shared_ptr.rs | 47 ++++++++++ support/cc_std_impl/std_allocator.h | 11 +++ support/cc_std_impl/test/shared_ptr/BUILD | 28 ++++++ .../test/shared_ptr/shared_ptr_test.rs | 37 ++++++++ .../test/shared_ptr/test_helpers.h | 89 +++++++++++++++++++ 11 files changed, 274 insertions(+), 9 deletions(-) create mode 100644 support/cc_std_impl/shared_ptr.rs create mode 100644 support/cc_std_impl/test/shared_ptr/BUILD create mode 100644 support/cc_std_impl/test/shared_ptr/shared_ptr_test.rs create mode 100644 support/cc_std_impl/test/shared_ptr/test_helpers.h diff --git a/rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs b/rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs index 9b5ecd39c..47658ebbb 100644 --- a/rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs +++ b/rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs @@ -298,6 +298,10 @@ pub enum UniformReprTemplateType { StdUniquePtr { element_type: RsTypeKind, }, + /// std::shared_ptr + StdSharedPtr { + element_type: RsTypeKind, + }, AbslSpan { is_const: bool, include_lifetime: bool, @@ -338,6 +342,18 @@ impl UniformReprTemplateType { Ok(arg_type_kind) }; match template_specialization_kind { + Some(TemplateSpecializationKind::StdSharedPtr { element_type }) => { + let element_type_kind = type_arg(element_type)?; + ensure!(element_type.is_const, "b/485328340: Crubit does not yet support std::shared_ptr, got: {}", element_type_kind.display(db)); + ensure!( + element_type_kind.is_complete(), + "Crubit does not support std::shared_ptr, got: `{}`", + element_type_kind.display(db) + ); + Ok(Some(Rc::new(UniformReprTemplateType::StdSharedPtr { + element_type: element_type_kind, + }))) + } Some(TemplateSpecializationKind::StdUniquePtr { element_type }) => { let element_type = type_arg(element_type)?; ensure!(element_type.is_complete(), "Rust std::unique_ptr cannot be used with incomplete types, and `{}` is incomplete", element_type.display(db)); @@ -401,6 +417,10 @@ impl UniformReprTemplateType { quote! { ::cc_std::std::unique_ptr::<#element_type_tokens> } } } + Self::StdSharedPtr { element_type } => { + let element_type_tokens = element_type.to_token_stream(db); + quote! { ::cc_std::std::shared_ptr_const::<#element_type_tokens> } + } Self::AbslSpan { is_const, include_lifetime, element_type } => { let element_type_tokens = element_type.to_token_stream(db); @@ -429,6 +449,7 @@ impl UniformReprTemplateType { match self { Self::StdVector { .. } => None, Self::StdUniquePtr { .. } => None, + Self::StdSharedPtr { .. } => None, Self::AbslSpan { include_lifetime: true, .. } => Some(Lifetime::elided()), Self::AbslSpan { include_lifetime: false, .. } => None, Self::StdStringView { lifetime, .. } => Some(lifetime.clone()), diff --git a/rs_bindings_from_cc/importers/cxx_record.cc b/rs_bindings_from_cc/importers/cxx_record.cc index 7b59e4646..0c92334c9 100644 --- a/rs_bindings_from_cc/importers/cxx_record.cc +++ b/rs_bindings_from_cc/importers/cxx_record.cc @@ -309,9 +309,7 @@ std::optional GetBridgeTypeAnnotation( for (const clang::TemplateArgument& template_arg : specialization_decl->getTemplateArgs().asArray()) { if (template_arg.getKind() == clang::TemplateArgument::ArgKind::Type) { - // TODO(b/454627672): is record_decl the right decl to check for - // assumed_lifetimes? - template_args.emplace_back( + template_args.push_back( ictx.ConvertQualType(template_arg.getAsType(), /*lifetimes=*/nullptr, /*nullable=*/true, ictx.AreAssumedLifetimesEnabledForTarget( @@ -601,6 +599,18 @@ absl::StatusOr GetTemplateSpecializationKind( ictx.ConvertQualType(t, /*lifetimes=*/nullptr, /*nullable=*/true, ictx.AreAssumedLifetimesEnabledForTarget( ictx.GetOwningTarget(specialization_decl)))); + } else if (templated_decl->getName() == "shared_ptr") { + if (specialization_decl->getTemplateArgs().size() != 1) { + return absl::InvalidArgumentError( + "std::shared_ptr must have exactly one template argument"); + } + clang::QualType t = specialization_decl->getTemplateArgs()[0].getAsType(); + return TemplateSpecialization::StdSharedPtr( + // TODO(b/454627672): is specialization_decl the right decl to check + // for assumed_lifetimes? + ictx.ConvertQualType(t, /*lifetimes=*/nullptr, /*nullable=*/true, + ictx.AreAssumedLifetimesEnabledForTarget( + ictx.GetOwningTarget(specialization_decl)))); } else if (templated_decl->getName() == "vector") { CRUBIT_ASSIGN_OR_RETURN( clang::QualType t, diff --git a/rs_bindings_from_cc/ir.cc b/rs_bindings_from_cc/ir.cc index 2a9e7c653..d84718b0e 100644 --- a/rs_bindings_from_cc/ir.cc +++ b/rs_bindings_from_cc/ir.cc @@ -619,6 +619,12 @@ llvm::json::Value TemplateSpecialization::ToJson() const { llvm::json::Object{ {"element_type", std_vector.element_type}}}}; }, + [&](const StdSharedPtr& std_shared_ptr) { + return llvm::json::Object{ + {"StdSharedPtr", + llvm::json::Object{ + {"element_type", std_shared_ptr.element_type}}}}; + }, [&](const StdUniquePtr& std_unique_ptr) { return llvm::json::Object{ {"StdUniquePtr", diff --git a/rs_bindings_from_cc/ir.h b/rs_bindings_from_cc/ir.h index a0ba4932a..36247ff23 100644 --- a/rs_bindings_from_cc/ir.h +++ b/rs_bindings_from_cc/ir.h @@ -677,6 +677,9 @@ struct TemplateSpecialization { struct StdVector { CcType element_type; }; + struct StdSharedPtr { + CcType element_type; + }; struct StdUniquePtr { CcType element_type; }; @@ -688,8 +691,9 @@ struct TemplateSpecialization { }; struct NonSpecial {}; - using Kind = std::variant; + using Kind = + std::variant; BazelLabel defining_target; Kind kind = NonSpecial{}; diff --git a/rs_bindings_from_cc/ir.rs b/rs_bindings_from_cc/ir.rs index 1a13cf4c4..66cf6f1cb 100644 --- a/rs_bindings_from_cc/ir.rs +++ b/rs_bindings_from_cc/ir.rs @@ -1122,13 +1122,24 @@ pub enum TemplateSpecializationKind { /// std::basic_string_view> StdWStringView, /// std::vector> - StdVector { element_type: CcType }, + StdVector { + element_type: CcType, + }, /// std::unique_ptr> - StdUniquePtr { element_type: CcType }, + StdSharedPtr { + element_type: CcType, + }, + StdUniquePtr { + element_type: CcType, + }, /// c9::Co - C9Co { element_type: CcType }, + C9Co { + element_type: CcType, + }, /// absl::Span - AbslSpan { element_type: CcType }, + AbslSpan { + element_type: CcType, + }, /// Some other template specialization. NonSpecial, } diff --git a/support/cc_std_impl/BUILD b/support/cc_std_impl/BUILD index ca9845835..d0fa17f81 100644 --- a/support/cc_std_impl/BUILD +++ b/support/cc_std_impl/BUILD @@ -54,6 +54,7 @@ additional_rust_srcs_for_crubit_bindings( additional_rust_srcs_for_crubit_bindings( name = "manually_bridged_types", srcs = [ + "shared_ptr.rs", "string.rs", "string_view.rs", "unique_ptr.rs", diff --git a/support/cc_std_impl/shared_ptr.rs b/support/cc_std_impl/shared_ptr.rs new file mode 100644 index 000000000..69c91c48c --- /dev/null +++ b/support/cc_std_impl/shared_ptr.rs @@ -0,0 +1,47 @@ +// Part of the Crubit project, under the Apache License v2.0 with LLVM +// Exceptions. See /LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +use crate::crubit_cc_std_internal::std_allocator; + +#[allow(non_snake_case)] +#[repr(C)] +pub struct shared_ptr_const { + ptr: *const T, + cntrl: *mut core::ffi::c_void, +} + +impl shared_ptr_const { + pub unsafe fn from_raw_parts(ptr: *const T, cntrl: *mut core::ffi::c_void) -> Self { + Self { ptr, cntrl } + } + + pub fn get(&self) -> *const T { + self.ptr + } + + pub fn as_ref(&self) -> Option<&T> { + unsafe { self.ptr.as_ref() } + } +} + +impl Clone for shared_ptr_const { + fn clone(&self) -> Self { + if !self.cntrl.is_null() { + unsafe { + std_allocator::shared_ptr_add_shared(self.cntrl); + } + } + Self { ptr: self.ptr, cntrl: self.cntrl } + } +} + +impl Drop for shared_ptr_const { + fn drop(&mut self) { + if !self.cntrl.is_null() { + unsafe { + std_allocator::shared_ptr_release_shared(self.cntrl); + } + } + } +} diff --git a/support/cc_std_impl/std_allocator.h b/support/cc_std_impl/std_allocator.h index d842089e4..37424c2be 100644 --- a/support/cc_std_impl/std_allocator.h +++ b/support/cc_std_impl/std_allocator.h @@ -5,6 +5,7 @@ #ifndef THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_STD_ALLOCATOR_H_ #define THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_STD_ALLOCATOR_H_ +#include <__memory/shared_count.h> #include #include @@ -40,6 +41,16 @@ inline void cpp_delete(void* ptr, size_t n, size_t align) { #endif } +inline void shared_ptr_add_shared(void* cntrl) { + auto* weak_count = static_cast(cntrl); + weak_count->__add_shared(); +} + +inline void shared_ptr_release_shared(void* cntrl) { + auto* weak_count = static_cast(cntrl); + weak_count->__release_shared(); +} + } // namespace crubit_cc_std_internal::std_allocator #endif // THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_STD_ALLOCATOR_H_ diff --git a/support/cc_std_impl/test/shared_ptr/BUILD b/support/cc_std_impl/test/shared_ptr/BUILD new file mode 100644 index 000000000..8581b0151 --- /dev/null +++ b/support/cc_std_impl/test/shared_ptr/BUILD @@ -0,0 +1,28 @@ +# Part of the Crubit project, under the Apache License v2.0 with LLVM +# Exceptions. See /LICENSE for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//common:crubit_wrapper_macros_oss.bzl", "crubit_rust_test") +load("//rs_bindings_from_cc/test:test_bindings.bzl", "crubit_test_cc_library") + +package(default_applicable_licenses = ["//:license"]) + +crubit_test_cc_library( + name = "test_helpers", + testonly = 1, + hdrs = ["test_helpers.h"], + aspect_hints = ["//features:experimental"], + deps = ["//support:annotations"], +) + +crubit_rust_test( + name = "shared_ptr_test", + srcs = ["shared_ptr_test.rs"], + cc_deps = [ + ":test_helpers", + "//support/public:cc_std", + ], + deps = [ + "@crate_index//:googletest", + ], +) diff --git a/support/cc_std_impl/test/shared_ptr/shared_ptr_test.rs b/support/cc_std_impl/test/shared_ptr/shared_ptr_test.rs new file mode 100644 index 000000000..540480c8d --- /dev/null +++ b/support/cc_std_impl/test/shared_ptr/shared_ptr_test.rs @@ -0,0 +1,37 @@ +// Part of the Crubit project, under the Apache License v2.0 with LLVM +// Exceptions. See /LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +use cc_std::std::shared_ptr_const; +use googletest::prelude::*; + +#[gtest] +fn test_layout() { + // Testing that the layout matches C++ shared_ptr + assert_eq!( + core::mem::size_of::>(), + test_helpers::shared_ptr_test::get_shared_ptr_size() as usize + ); + assert_eq!( + core::mem::align_of::>(), + test_helpers::shared_ptr_test::get_shared_ptr_alignment() as usize + ); +} + +#[gtest] +fn test_polymorphic_destructor() { + // Initial destructor count should be 0. + assert_eq!(test_helpers::shared_ptr_test::get_derived_destructor_count(), 0); + { + let _shared_base = test_helpers::shared_ptr_test::create_virtual_base(); + assert_eq!(core::mem::size_of_val(&_shared_base), 16); + } + // After dropping shared_ptr which points to Derived, count should be 1. + assert_eq!(test_helpers::shared_ptr_test::get_derived_destructor_count(), 1); +} + +#[gtest] +fn test_deref() { + let shared = test_helpers::shared_ptr_test::create_shared_ptr(); + assert_eq!(*shared.as_ref().unwrap(), 1); +} diff --git a/support/cc_std_impl/test/shared_ptr/test_helpers.h b/support/cc_std_impl/test/shared_ptr/test_helpers.h new file mode 100644 index 000000000..e8487cdea --- /dev/null +++ b/support/cc_std_impl/test/shared_ptr/test_helpers.h @@ -0,0 +1,89 @@ +// Part of the Crubit project, under the Apache License v2.0 with LLVM +// Exceptions. See /LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#ifndef THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_IMPL_TEST_SHARED_PTR_TEST_HELPERS_H_ +#define THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_IMPL_TEST_SHARED_PTR_TEST_HELPERS_H_ + +#include + +#include "support/annotations.h" + +namespace shared_ptr_test { + +struct TwoWords { + void* ptr1; + void* ptr2; +}; + +CRUBIT_MUST_BIND inline size_t get_shared_ptr_size() { + return sizeof(std::shared_ptr); +} + +CRUBIT_MUST_BIND inline size_t get_shared_ptr_alignment() { + return alignof(std::shared_ptr); +} + +CRUBIT_MUST_BIND inline std::shared_ptr create_shared_ptr() { + return std::make_shared(1); +} +CRUBIT_MUST_BIND inline void destroy_shared_ptr(std::shared_ptr) {} +CRUBIT_MUST_BIND inline std::shared_ptr create_shared_ptr_char() { + return std::make_shared('a'); +} +CRUBIT_MUST_BIND inline std::shared_ptr create_shared_ptr_short() { + return std::make_shared(static_cast(1)); +} +CRUBIT_MUST_BIND inline std::shared_ptr +create_shared_ptr_void_ptr() { + return std::shared_ptr(nullptr); +} +CRUBIT_MUST_BIND inline std::shared_ptr +create_shared_ptr_two_words() { + return std::make_shared(); +} + +// Since shared_ptr uses a control block for type erasure, we can get a pointer +// to the base class and the control block will correctly destroy the derived +// type. +struct Base { + virtual ~Base() = default; + static inline int derived_destructor_count = 0; + virtual bool is_derived() const { return false; } +}; +struct Derived : public Base { + ~Derived() override { derived_destructor_count++; } + bool is_derived() const override { return true; } +}; + +CRUBIT_MUST_BIND inline std::shared_ptr create_virtual_base() { + return std::make_shared(); +} + +CRUBIT_MUST_BIND inline int get_derived_destructor_count() { + return Base::derived_destructor_count; +} + +// Using a custom deleter via shared_ptr constructors. +struct CustomDelete { + static inline int custom_delete_count = 0; +}; +struct CustomDeleter { + void operator()(const CustomDelete* p) const { + CustomDelete::custom_delete_count++; + delete p; + } +}; + +CRUBIT_MUST_BIND inline std::shared_ptr +create_custom_delete() { + return std::shared_ptr(new CustomDelete(), + CustomDeleter()); +} + +CRUBIT_MUST_BIND inline int get_custom_delete_count() { + return CustomDelete::custom_delete_count; +} + +} // namespace shared_ptr_test + +#endif // THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_IMPL_TEST_SHARED_PTR_TEST_HELPERS_H_