diff --git a/lldb/include/lldb/Target/Language.h b/lldb/include/lldb/Target/Language.h index ffa729753a5fa..694fc44312b4a 100644 --- a/lldb/include/lldb/Target/Language.h +++ b/lldb/include/lldb/Target/Language.h @@ -409,8 +409,7 @@ class Language : public PluginInterface { /// returns the name of the parent function where the input closure was /// defined. Returns an empty string if there is no such parent, or if the /// query does not make sense for this language. - virtual std::string - GetParentNameIfClosure(llvm::StringRef mangled_name) const { + virtual std::string GetParentNameIfClosure(Function &function) const { return ""; } diff --git a/lldb/source/Plugins/Language/Swift/SwiftLanguage.cpp b/lldb/source/Plugins/Language/Swift/SwiftLanguage.cpp index d4da1ac3b0ae4..7c118ecdb1620 100644 --- a/lldb/source/Plugins/Language/Swift/SwiftLanguage.cpp +++ b/lldb/source/Plugins/Language/Swift/SwiftLanguage.cpp @@ -2167,9 +2167,8 @@ SwiftLanguage::AreEqualForFrameComparison(const SymbolContext &sc1, llvm_unreachable("unhandled enumeration in AreEquivalentFunctions"); } -std::string -SwiftLanguage::GetParentNameIfClosure(llvm::StringRef mangled_name) const { - return SwiftLanguageRuntime::GetParentNameIfClosure(mangled_name); +std::string SwiftLanguage::GetParentNameIfClosure(Function &func) const { + return SwiftLanguageRuntime::GetParentNameIfClosure(func); } //------------------------------------------------------------------ diff --git a/lldb/source/Plugins/Language/Swift/SwiftLanguage.h b/lldb/source/Plugins/Language/Swift/SwiftLanguage.h index 9a78a3dd85fc0..248895f5b5e92 100644 --- a/lldb/source/Plugins/Language/Swift/SwiftLanguage.h +++ b/lldb/source/Plugins/Language/Swift/SwiftLanguage.h @@ -113,8 +113,7 @@ class SwiftLanguage : public Language { AreEqualForFrameComparison(const SymbolContext &sc1, const SymbolContext &sc2) const override; - std::string - GetParentNameIfClosure(llvm::StringRef mangled_name) const override; + std::string GetParentNameIfClosure(Function &func) const override; //------------------------------------------------------------------ // Static Functions //------------------------------------------------------------------ diff --git a/lldb/source/Plugins/LanguageRuntime/Swift/SwiftLanguageRuntime.h b/lldb/source/Plugins/LanguageRuntime/Swift/SwiftLanguageRuntime.h index 4dfcfad996094..f6edfeae40c71 100644 --- a/lldb/source/Plugins/LanguageRuntime/Swift/SwiftLanguageRuntime.h +++ b/lldb/source/Plugins/LanguageRuntime/Swift/SwiftLanguageRuntime.h @@ -205,7 +205,7 @@ class SwiftLanguageRuntime : public LanguageRuntime { const SymbolContext *sc = nullptr, const ExecutionContext *exe_ctx = nullptr); - static std::string GetParentNameIfClosure(llvm::StringRef mangled_name); + static std::string GetParentNameIfClosure(Function &func); /// Demangle a symbol to a swift::Demangle node tree. /// diff --git a/lldb/source/Plugins/LanguageRuntime/Swift/SwiftLanguageRuntimeNames.cpp b/lldb/source/Plugins/LanguageRuntime/Swift/SwiftLanguageRuntimeNames.cpp index 7e13523f2a802..1bfffe70c181b 100644 --- a/lldb/source/Plugins/LanguageRuntime/Swift/SwiftLanguageRuntimeNames.cpp +++ b/lldb/source/Plugins/LanguageRuntime/Swift/SwiftLanguageRuntimeNames.cpp @@ -1499,9 +1499,52 @@ SwiftLanguageRuntime::GetGenericSignature(llvm::StringRef function_name, return signature; } -std::string SwiftLanguageRuntime::GetParentNameIfClosure(StringRef name) { +/// Returns true if a function called `symbol_name` exists in `module`. +static bool SymbolExists(StringRef symbol_name, Module &module) { + SymbolContextList sc_list; + Module::LookupInfo lookup_info(ConstString(symbol_name), + lldb::FunctionNameType::eFunctionNameTypeFull, + lldb::eLanguageTypeSwift); + module.FindFunctions(lookup_info, CompilerDeclContext(), + ModuleFunctionSearchOptions(), sc_list); + return !sc_list.IsEmpty(); +} + +/// There are two possible variants of constructors: allocating +/// (Kind::Allocator) and initializing (Kind::Constructor). When mangling a +/// closure, the "context" is arbitrarily chosen to be the Kind::Constructor +/// variant. This function checks for the existence of the Kind::Constructor, +/// given by the symbol `ctor_kind_mangled_name` in the module. If it exists, +/// `ctor_kind_mangled_name` is returned. Otherwise, creates and returns the +/// symbol name for the `Kind::Allocating` variant. +static std::string +HandleCtorAndAllocatorVariants(StringRef ctor_kind_mangled_name, Module &module, + Node &top_level_node, Node &ctor_node) { + if (SymbolExists(ctor_kind_mangled_name, module)) + return ctor_kind_mangled_name.str(); + + using Kind = Node::Kind; + NodeFactory factory; + + // Create a kind::Allocator node with all the children of the + // kind::Constructor node. + Node *allocating_ctor = factory.createNode(Kind::Allocator); + if (!allocating_ctor) + return ""; + for (auto *child : ctor_node) + allocating_ctor->addChild(child, factory); + swift_demangle::ReplaceChildWith(top_level_node, ctor_node, *allocating_ctor); + + if (auto mangled = swift::Demangle::mangleNode(&top_level_node); + mangled.isSuccess()) + return std::move(mangled.result()); + return ""; +} + +std::string SwiftLanguageRuntime::GetParentNameIfClosure(Function &func) { using Kind = Node::Kind; swift::Demangle::Context ctx; + ConstString name = func.GetMangled().GetMangledName(); auto *node = SwiftLanguageRuntime::DemangleSymbolAsNode(name, ctx); if (!node || node->getKind() != Node::Kind::Global) return ""; @@ -1521,9 +1564,14 @@ std::string SwiftLanguageRuntime::GetParentNameIfClosure(StringRef name) { return ""; swift_demangle::ReplaceChildWith(*node, *closure_node, *parent_func_node); - if (ManglingErrorOr mangled = swift::Demangle::mangleNode(node); - mangled.isSuccess()) - return mangled.result(); - return ""; + ManglingErrorOr mangled = swift::Demangle::mangleNode(node); + if (!mangled.isSuccess()) + return ""; + + if (parent_func_node->getKind() == Kind::Constructor) + if (Module *module = func.CalculateSymbolContextModule().get()) + return HandleCtorAndAllocatorVariants(mangled.result(), *module, *node, + *parent_func_node); + return mangled.result(); } } // namespace lldb_private diff --git a/lldb/source/Target/Language.cpp b/lldb/source/Target/Language.cpp index 942b4140c4fc3..f61c372d5d3f0 100644 --- a/lldb/source/Target/Language.cpp +++ b/lldb/source/Target/Language.cpp @@ -601,8 +601,7 @@ GetParentFunctionsWhileClosure(const SymbolContext &sc, parents.push_back(root); for (int idx = 0; idx < upper_limit; idx++) { - ConstString mangled = root->GetMangled().GetMangledName(); - std::string parent = language.GetParentNameIfClosure(mangled); + std::string parent = language.GetParentNameIfClosure(*root); if (parent.empty()) break; diff --git a/lldb/test/API/lang/swift/closures_var_not_captured/TestSwiftClosureVarNotCaptured.py b/lldb/test/API/lang/swift/closures_var_not_captured/TestSwiftClosureVarNotCaptured.py index 38a78f084b04a..38b138c8f69c2 100644 --- a/lldb/test/API/lang/swift/closures_var_not_captured/TestSwiftClosureVarNotCaptured.py +++ b/lldb/test/API/lang/swift/closures_var_not_captured/TestSwiftClosureVarNotCaptured.py @@ -110,6 +110,35 @@ def test_async_closure(self): def test_ctor_class_closure(self): self.build() (target, process, thread) = self.get_to_bkpt("break_ctor_class") + check_not_captured_error( + self, thread.frames[0], "input", "MY_CLASS.init(input:)" + ) + check_not_captured_error( + self, thread.frames[0], "find_me", "MY_CLASS.init(input:)" + ) + check_no_enhanced_diagnostic(self, thread.frames[0], "dont_find_me") + + lldbutil.continue_to_source_breakpoint( + self, process, "break_static_member_class", lldb.SBFileSpec("main.swift") + ) + check_not_captured_error( + self, + thread.frames[0], + "input_static", + "static MY_CLASS.static_func(input_static:)", + ) + check_not_captured_error( + self, + thread.frames[0], + "find_me_static", + "static MY_CLASS.static_func(input_static:)", + ) + check_no_enhanced_diagnostic(self, thread.frames[0], "dont_find_me_static") + + @swiftTest + def test_ctor_struct_closure(self): + self.build() + (target, process, thread) = self.get_to_bkpt("break_ctor_struct") check_not_captured_error( self, thread.frames[0], "input", "MY_STRUCT.init(input:)" ) @@ -119,7 +148,7 @@ def test_ctor_class_closure(self): check_no_enhanced_diagnostic(self, thread.frames[0], "dont_find_me") lldbutil.continue_to_source_breakpoint( - self, process, "break_static_member", lldb.SBFileSpec("main.swift") + self, process, "break_static_member_struct", lldb.SBFileSpec("main.swift") ) check_not_captured_error( self, @@ -134,3 +163,32 @@ def test_ctor_class_closure(self): "static MY_STRUCT.static_func(input_static:)", ) check_no_enhanced_diagnostic(self, thread.frames[0], "dont_find_me_static") + + @swiftTest + def test_ctor_enum_closure(self): + self.build() + (target, process, thread) = self.get_to_bkpt("break_ctor_enum") + check_not_captured_error( + self, thread.frames[0], "input", "MY_ENUM.init(input:)" + ) + check_not_captured_error( + self, thread.frames[0], "find_me", "MY_ENUM.init(input:)" + ) + check_no_enhanced_diagnostic(self, thread.frames[0], "dont_find_me") + + lldbutil.continue_to_source_breakpoint( + self, process, "break_static_member_enum", lldb.SBFileSpec("main.swift") + ) + check_not_captured_error( + self, + thread.frames[0], + "input_static", + "static MY_ENUM.static_func(input_static:)", + ) + check_not_captured_error( + self, + thread.frames[0], + "find_me_static", + "static MY_ENUM.static_func(input_static:)", + ) + check_no_enhanced_diagnostic(self, thread.frames[0], "dont_find_me_static") diff --git a/lldb/test/API/lang/swift/closures_var_not_captured/main.swift b/lldb/test/API/lang/swift/closures_var_not_captured/main.swift index a3c69f63f200a..f04e36e9faff6 100644 --- a/lldb/test/API/lang/swift/closures_var_not_captured/main.swift +++ b/lldb/test/API/lang/swift/closures_var_not_captured/main.swift @@ -87,7 +87,7 @@ func func_3(arg: Int) async { print(dont_find_me) } -class MY_STRUCT { +class MY_CLASS { init(input: [Int]) { let find_me = "hello" let _ = input.map { @@ -99,7 +99,52 @@ class MY_STRUCT { static func static_func(input_static: [Int]) { let find_me_static = "hello" let _ = input_static.map { - return $0 // break_static_member + return $0 // break_static_member_class + } + let dont_find_me_static = "hello" + } +} + +struct MY_STRUCT { + init(input: [Int]) { + let find_me = "hello" + let _ = input.map { + print("break_ctor_struct") + return $0 + } + let dont_find_me = "hello" + } + + static func static_func(input_static: [Int]) { + let find_me_static = "hello" + let _ = input_static.map { + print("break_static_member_struct") + return $0 + } + let dont_find_me_static = "hello" + } +} + +enum MY_ENUM { + case case1(Double) + case case2(Double) + + init(input: [Int]) { + let find_me = "hello" + let _ = input.map { + print("break_ctor_enum") + return $0 + } + + let dont_find_me = "hello" + self = .case1(42.0) + } + + static func static_func(input_static: [Int]) { + let find_me_static = "hello" + let _ = input_static.map { + print("break_static_member_enum") + return $0 } let dont_find_me_static = "hello" } @@ -108,5 +153,9 @@ class MY_STRUCT { func_1(arg: 42) func_2(arg: 42) await func_3(arg: 42) +let _ = MY_CLASS(input: [1, 2]) +MY_CLASS.static_func(input_static: [42]) let _ = MY_STRUCT(input: [1, 2]) MY_STRUCT.static_func(input_static: [42]) +let _ = MY_ENUM(input: [1,2]) +MY_ENUM.static_func(input_static: [42])