Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,10 @@ pub enum UniformReprTemplateType {
StdUniquePtr {
element_type: RsTypeKind,
},
/// std::shared_ptr<const T>
StdSharedPtr {
element_type: RsTypeKind,
},
AbslSpan {
is_const: bool,
include_lifetime: bool,
Expand Down Expand Up @@ -338,6 +342,18 @@ impl UniformReprTemplateType {
Ok(arg_type_kind)
};
match template_specialization_kind {
Some(TemplateSpecializationKind::StdSharedPtr { element_type }) => {
let element_type_kind = type_arg(element_type)?;
ensure!(element_type.is_const, "b/485328340: Crubit does not yet support std::shared_ptr<non-const T>, got: {}", element_type_kind.display(db));
ensure!(
element_type_kind.is_complete(),
"Crubit does not support std::shared_ptr<incomplete T>, got: `{}`",
element_type_kind.display(db)
);
Ok(Some(Rc::new(UniformReprTemplateType::StdSharedPtr {
element_type: element_type_kind,
})))
}
Some(TemplateSpecializationKind::StdUniquePtr { element_type }) => {
let element_type = type_arg(element_type)?;
ensure!(element_type.is_complete(), "Rust std::unique_ptr<T> cannot be used with incomplete types, and `{}` is incomplete", element_type.display(db));
Expand Down Expand Up @@ -401,6 +417,10 @@ impl UniformReprTemplateType {
quote! { ::cc_std::std::unique_ptr::<#element_type_tokens> }
}
}
Self::StdSharedPtr { element_type } => {
let element_type_tokens = element_type.to_token_stream(db);
quote! { ::cc_std::std::shared_ptr_const::<#element_type_tokens> }
}
Self::AbslSpan { is_const, include_lifetime, element_type } => {
let element_type_tokens = element_type.to_token_stream(db);

Expand Down Expand Up @@ -429,6 +449,7 @@ impl UniformReprTemplateType {
match self {
Self::StdVector { .. } => None,
Self::StdUniquePtr { .. } => None,
Self::StdSharedPtr { .. } => None,
Self::AbslSpan { include_lifetime: true, .. } => Some(Lifetime::elided()),
Self::AbslSpan { include_lifetime: false, .. } => None,
Self::StdStringView { lifetime, .. } => Some(lifetime.clone()),
Expand Down
16 changes: 13 additions & 3 deletions rs_bindings_from_cc/importers/cxx_record.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,7 @@ std::optional<BridgeType> GetBridgeTypeAnnotation(
for (const clang::TemplateArgument& template_arg :
specialization_decl->getTemplateArgs().asArray()) {
if (template_arg.getKind() == clang::TemplateArgument::ArgKind::Type) {
// TODO(b/454627672): is record_decl the right decl to check for
// assumed_lifetimes?
template_args.emplace_back(
template_args.push_back(
ictx.ConvertQualType(template_arg.getAsType(),
/*lifetimes=*/nullptr, /*nullable=*/true,
ictx.AreAssumedLifetimesEnabledForTarget(
Expand Down Expand Up @@ -601,6 +599,18 @@ absl::StatusOr<TemplateSpecialization::Kind> GetTemplateSpecializationKind(
ictx.ConvertQualType(t, /*lifetimes=*/nullptr, /*nullable=*/true,
ictx.AreAssumedLifetimesEnabledForTarget(
ictx.GetOwningTarget(specialization_decl))));
} else if (templated_decl->getName() == "shared_ptr") {
if (specialization_decl->getTemplateArgs().size() != 1) {
return absl::InvalidArgumentError(
"std::shared_ptr must have exactly one template argument");
}
clang::QualType t = specialization_decl->getTemplateArgs()[0].getAsType();
return TemplateSpecialization::StdSharedPtr(
// TODO(b/454627672): is specialization_decl the right decl to check
// for assumed_lifetimes?
ictx.ConvertQualType(t, /*lifetimes=*/nullptr, /*nullable=*/true,
ictx.AreAssumedLifetimesEnabledForTarget(
ictx.GetOwningTarget(specialization_decl))));
} else if (templated_decl->getName() == "vector") {
CRUBIT_ASSIGN_OR_RETURN(
clang::QualType t,
Expand Down
6 changes: 6 additions & 0 deletions rs_bindings_from_cc/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,12 @@ llvm::json::Value TemplateSpecialization::ToJson() const {
llvm::json::Object{
{"element_type", std_vector.element_type}}}};
},
[&](const StdSharedPtr& std_shared_ptr) {
return llvm::json::Object{
{"StdSharedPtr",
llvm::json::Object{
{"element_type", std_shared_ptr.element_type}}}};
},
[&](const StdUniquePtr& std_unique_ptr) {
return llvm::json::Object{
{"StdUniquePtr",
Expand Down
8 changes: 6 additions & 2 deletions rs_bindings_from_cc/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,9 @@ struct TemplateSpecialization {
struct StdVector {
CcType element_type;
};
struct StdSharedPtr {
CcType element_type;
};
struct StdUniquePtr {
CcType element_type;
};
Expand All @@ -688,8 +691,9 @@ struct TemplateSpecialization {
};
struct NonSpecial {};

using Kind = std::variant<StdStringView, StdWStringView, StdVector,
StdUniquePtr, AbslSpan, C9Co, NonSpecial>;
using Kind =
std::variant<StdStringView, StdWStringView, StdVector, StdSharedPtr,
StdUniquePtr, AbslSpan, C9Co, NonSpecial>;

BazelLabel defining_target;
Kind kind = NonSpecial{};
Expand Down
19 changes: 15 additions & 4 deletions rs_bindings_from_cc/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1122,13 +1122,24 @@ pub enum TemplateSpecializationKind {
/// std::basic_string_view<wchar_t, std::char_traits<wchar_t>>
StdWStringView,
/// std::vector<T, std::allocator<T>>
StdVector { element_type: CcType },
StdVector {
element_type: CcType,
},
/// std::unique_ptr<T, std::default_delete<T>>
StdUniquePtr { element_type: CcType },
StdSharedPtr {
element_type: CcType,
},
StdUniquePtr {
element_type: CcType,
},
/// c9::Co<T>
C9Co { element_type: CcType },
C9Co {
element_type: CcType,
},
/// absl::Span<T>
AbslSpan { element_type: CcType },
AbslSpan {
element_type: CcType,
},
/// Some other template specialization.
NonSpecial,
}
Expand Down
1 change: 1 addition & 0 deletions support/cc_std_impl/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ additional_rust_srcs_for_crubit_bindings(
additional_rust_srcs_for_crubit_bindings(
name = "manually_bridged_types",
srcs = [
"shared_ptr.rs",
"string.rs",
"string_view.rs",
"unique_ptr.rs",
Expand Down
47 changes: 47 additions & 0 deletions support/cc_std_impl/shared_ptr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Part of the Crubit project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

use crate::crubit_cc_std_internal::std_allocator;

#[allow(non_snake_case)]
#[repr(C)]
pub struct shared_ptr_const<T: Sized> {
ptr: *const T,
cntrl: *mut core::ffi::c_void,
}

impl<T: Sized> shared_ptr_const<T> {
pub unsafe fn from_raw_parts(ptr: *const T, cntrl: *mut core::ffi::c_void) -> Self {
Self { ptr, cntrl }
}

pub fn get(&self) -> *const T {
self.ptr
}

pub fn as_ref(&self) -> Option<&T> {
unsafe { self.ptr.as_ref() }
}
}

impl<T: Sized> Clone for shared_ptr_const<T> {
fn clone(&self) -> Self {
if !self.cntrl.is_null() {
unsafe {
std_allocator::shared_ptr_add_shared(self.cntrl);
}
}
Self { ptr: self.ptr, cntrl: self.cntrl }
}
}

impl<T: Sized> Drop for shared_ptr_const<T> {
fn drop(&mut self) {
if !self.cntrl.is_null() {
unsafe {
std_allocator::shared_ptr_release_shared(self.cntrl);
}
}
}
}
11 changes: 11 additions & 0 deletions support/cc_std_impl/std_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#ifndef THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_STD_ALLOCATOR_H_
#define THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_STD_ALLOCATOR_H_

#include <__memory/shared_count.h>
#include <stdio.h>

#include <cstddef>
Expand Down Expand Up @@ -40,6 +41,16 @@ inline void cpp_delete(void* ptr, size_t n, size_t align) {
#endif
}

inline void shared_ptr_add_shared(void* cntrl) {
auto* weak_count = static_cast<std::__shared_weak_count*>(cntrl);
weak_count->__add_shared();
}

inline void shared_ptr_release_shared(void* cntrl) {
auto* weak_count = static_cast<std::__shared_weak_count*>(cntrl);
weak_count->__release_shared();
}

} // namespace crubit_cc_std_internal::std_allocator

#endif // THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_STD_ALLOCATOR_H_
28 changes: 28 additions & 0 deletions support/cc_std_impl/test/shared_ptr/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Part of the Crubit project, under the Apache License v2.0 with LLVM
# Exceptions. See /LICENSE for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//common:crubit_wrapper_macros_oss.bzl", "crubit_rust_test")
load("//rs_bindings_from_cc/test:test_bindings.bzl", "crubit_test_cc_library")

package(default_applicable_licenses = ["//:license"])

crubit_test_cc_library(
name = "test_helpers",
testonly = 1,
hdrs = ["test_helpers.h"],
aspect_hints = ["//features:experimental"],
deps = ["//support:annotations"],
)

crubit_rust_test(
name = "shared_ptr_test",
srcs = ["shared_ptr_test.rs"],
cc_deps = [
":test_helpers",
"//support/public:cc_std",
],
deps = [
"@crate_index//:googletest",
],
)
37 changes: 37 additions & 0 deletions support/cc_std_impl/test/shared_ptr/shared_ptr_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Part of the Crubit project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

use cc_std::std::shared_ptr_const;
use googletest::prelude::*;

#[gtest]
fn test_layout() {
// Testing that the layout matches C++ shared_ptr<const int32_t>
assert_eq!(
core::mem::size_of::<shared_ptr_const<i32>>(),
test_helpers::shared_ptr_test::get_shared_ptr_size() as usize
);
assert_eq!(
core::mem::align_of::<shared_ptr_const<i32>>(),
test_helpers::shared_ptr_test::get_shared_ptr_alignment() as usize
);
}

#[gtest]
fn test_polymorphic_destructor() {
// Initial destructor count should be 0.
assert_eq!(test_helpers::shared_ptr_test::get_derived_destructor_count(), 0);
{
let _shared_base = test_helpers::shared_ptr_test::create_virtual_base();
assert_eq!(core::mem::size_of_val(&_shared_base), 16);
}
// After dropping shared_ptr<const Base> which points to Derived, count should be 1.
assert_eq!(test_helpers::shared_ptr_test::get_derived_destructor_count(), 1);
}

#[gtest]
fn test_deref() {
let shared = test_helpers::shared_ptr_test::create_shared_ptr();
assert_eq!(*shared.as_ref().unwrap(), 1);
}
89 changes: 89 additions & 0 deletions support/cc_std_impl/test/shared_ptr/test_helpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Part of the Crubit project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_IMPL_TEST_SHARED_PTR_TEST_HELPERS_H_
#define THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_IMPL_TEST_SHARED_PTR_TEST_HELPERS_H_

#include <memory>

#include "support/annotations.h"

namespace shared_ptr_test {

struct TwoWords {
void* ptr1;
void* ptr2;
};

CRUBIT_MUST_BIND inline size_t get_shared_ptr_size() {
return sizeof(std::shared_ptr<const int>);
}

CRUBIT_MUST_BIND inline size_t get_shared_ptr_alignment() {
return alignof(std::shared_ptr<const int>);
}

CRUBIT_MUST_BIND inline std::shared_ptr<const int> create_shared_ptr() {
return std::make_shared<int>(1);
}
CRUBIT_MUST_BIND inline void destroy_shared_ptr(std::shared_ptr<const int>) {}
CRUBIT_MUST_BIND inline std::shared_ptr<const char> create_shared_ptr_char() {
return std::make_shared<char>('a');
}
CRUBIT_MUST_BIND inline std::shared_ptr<const short> create_shared_ptr_short() {
return std::make_shared<short>(static_cast<short>(1));
}
CRUBIT_MUST_BIND inline std::shared_ptr<const void>
create_shared_ptr_void_ptr() {
return std::shared_ptr<const void>(nullptr);
}
CRUBIT_MUST_BIND inline std::shared_ptr<const TwoWords>
create_shared_ptr_two_words() {
return std::make_shared<TwoWords>();
}

// Since shared_ptr uses a control block for type erasure, we can get a pointer
// to the base class and the control block will correctly destroy the derived
// type.
struct Base {
virtual ~Base() = default;
static inline int derived_destructor_count = 0;
virtual bool is_derived() const { return false; }
};
struct Derived : public Base {
~Derived() override { derived_destructor_count++; }
bool is_derived() const override { return true; }
};

CRUBIT_MUST_BIND inline std::shared_ptr<const Base> create_virtual_base() {
return std::make_shared<Derived>();
}

CRUBIT_MUST_BIND inline int get_derived_destructor_count() {
return Base::derived_destructor_count;
}

// Using a custom deleter via shared_ptr constructors.
struct CustomDelete {
static inline int custom_delete_count = 0;
};
struct CustomDeleter {
void operator()(const CustomDelete* p) const {
CustomDelete::custom_delete_count++;
delete p;
}
};

CRUBIT_MUST_BIND inline std::shared_ptr<const CustomDelete>
create_custom_delete() {
return std::shared_ptr<const CustomDelete>(new CustomDelete(),
CustomDeleter());
}

CRUBIT_MUST_BIND inline int get_custom_delete_count() {
return CustomDelete::custom_delete_count;
}

} // namespace shared_ptr_test

#endif // THIRD_PARTY_CRUBIT_SUPPORT_CC_STD_IMPL_TEST_SHARED_PTR_TEST_HELPERS_H_