Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
159f5a2
Various fixes, new tests and replacements of raises with logger warnings
Masara Aug 16, 2024
56c0a39
Fixing several bugs where return variables cant be parsed correctly
Masara Aug 17, 2024
f53d93b
Removed unused EnumType and BoundaryType; Added short Docstring descr…
Masara Aug 18, 2024
84c7237
Fixing a few bugs and replacing some raises with warnings
Masara Aug 18, 2024
091868b
added analysis of "not" statements as bool type
Masara Aug 21, 2024
32036c7
fixed a bug where, if stub files already existed, content that was al…
Masara Aug 21, 2024
16a589c
fixed a bug where some functions would not have "None" types even tho…
Masara Aug 23, 2024
702344f
Added TypeAliasType parsing for parameters
Masara Aug 25, 2024
b384f99
fixed a bug for type parsing where subclasses of aliases couldn't be …
Masara Aug 25, 2024
fe3ef13
fixed a bug where some class attributes defined in the __init__ funct…
Masara Aug 25, 2024
3921da7
fixed a bug where some class attributes would wrongly be labeled as T…
Masara Aug 25, 2024
fc46724
Added handling for results with operations (boolean results) and fixe…
Masara Aug 25, 2024
2fd3e5d
fixed a bug where type names would not be checked for naming conventi…
Masara Aug 26, 2024
54ae851
fixed the way reexported paths are searched and found; multiple text …
Masara Aug 27, 2024
881a1b6
reversed the change concerning None from commit #16a589c1
Masara Aug 27, 2024
ad2a11e
fixed a bug for parsing parameter types; replacing some logging.warni…
Masara Aug 27, 2024
9ba8221
Merge branch 'main' into various_fixes
Masara Oct 3, 2024
a6e640b
Big performance fix for the docstring parser by creating an indexer f…
Masara Oct 3, 2024
adc0a08
Trying to reduce the runtime for the "_add_to_imports" function in th…
Masara Oct 4, 2024
9c69b79
Merge remote-tracking branch 'origin/fix-docstring-runtime' into fix-…
Masara Oct 16, 2024
5a47d30
Runtime fix for the O(n^3) bug
Masara Oct 17, 2024
26577a1
Optimizing the _check_publicity_in_reexports function
Masara Oct 17, 2024
ee06dc2
Removed unused code
Masara Nov 10, 2024
7bf2ed1
Merge branch 'various_fixes' into fix-docstring-runtime
Masara Nov 10, 2024
9235d3d
Merge branch 'main' into various_fixes
Masara Nov 10, 2024
5655120
Merge branch 'various_fixes' into fix-docstring-runtime
Masara Nov 10, 2024
14771cb
Merge branch 'main' into fix-docstring-runtime
lars-reimann Mar 7, 2025
bcbc7d0
style: fix ruff errors
lars-reimann Mar 7, 2025
806b7e9
fix: mypy error
lars-reimann Mar 7, 2025
a7db44c
Fix for the import generation
Masara Apr 3, 2025
2a3cb09
Added test for bytes types conversion to Safe-DS stubs from docstrings
Masara Apr 3, 2025
3e5962f
Merge branch 'main' into fix-docstring-runtime
Masara Apr 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions src/safeds_stubgen/api_analyzer/_ast_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def _create_inferred_results(
if type__ not in result_array[i]:
result_array[i].append(type__)

longest_inner_list = max(len(result_array[i]), longest_inner_list)
longest_inner_list = max(longest_inner_list, len(result_array[i]))
else:
result_array.append([type__])

Expand Down Expand Up @@ -1452,18 +1452,12 @@ def _check_publicity_in_reexports(self, name: str, qname: str, parent: Module |
continue

# If the whole module was reexported we have to check if the name or alias is intern
if module_is_reexported:
if module_is_reexported and not_internal and (isinstance(parent, Module) or parent.is_public):

# Check the wildcard imports of the source
for wildcard_import in reexport_source.wildcard_imports:
if (
(
(is_from_same_package and wildcard_import.module_name == module_name)
or (is_from_another_package and wildcard_import.module_name == module_qname)
)
and not_internal
and (isinstance(parent, Module) or parent.is_public)
):
if ((is_from_same_package and wildcard_import.module_name == module_name)
or (is_from_another_package and wildcard_import.module_name == module_qname)):
return True

# Check the qualified imports of the source
Expand All @@ -1474,11 +1468,9 @@ def _check_publicity_in_reexports(self, name: str, qname: str, parent: Module |
if (
qualified_import.qualified_name in {module_name, module_qname}
and (
(qualified_import.alias is None and not_internal)
qualified_import.alias is None
or (qualified_import.alias is not None and not is_internal(qualified_import.alias))
)
and not_internal
and (isinstance(parent, Module) or parent.is_public)
):
# If the module name or alias is not internal, check if the parent is public
return True
Expand Down
117 changes: 48 additions & 69 deletions src/safeds_stubgen/docstring_parsing/_docstring_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import logging
from typing import TYPE_CHECKING, Literal

from _griffe import models as griffe_models
from griffe import load
from griffe.dataclasses import Docstring
from griffe.docstrings.dataclasses import DocstringAttribute, DocstringParameter
from griffe.docstrings.utils import parse_annotation
from griffe.enumerations import DocstringSectionKind, Parser
Expand All @@ -25,34 +25,49 @@
if TYPE_CHECKING:
from pathlib import Path

from griffe.dataclasses import Object
from mypy import nodes


class DocstringParser(AbstractDocstringParser):
def __init__(self, parser: Parser, package_path: Path):
self.parser = parser

while True:
# If a package has no __init__.py file Griffe can't parse it, therefore we check the parent
try:
self.griffe_build = load(package_path, docstring_parser=parser)
griffe_build = load(package_path, docstring_parser=parser)
break
except KeyError:
package_path = package_path.parent

self.parser = parser
self.__cached_node: str | None = None
self.__cached_docstring: Docstring | None = None
self.griffe_index: dict[str, griffe_models.Object] = {}
self._recursive_griffe_indexer(griffe_build)

def _recursive_griffe_indexer(self, griffe_build: griffe_models.Object | griffe_models.Alias) -> None:
for member in griffe_build.all_members.values():
if isinstance(
member,
griffe_models.Class | griffe_models.Function | griffe_models.Attribute | griffe_models.Alias,
):
self.griffe_index[member.path] = member

if isinstance(member, griffe_models.Module | griffe_models.Class):
self._recursive_griffe_indexer(member)

def get_class_documentation(self, class_node: nodes.ClassDef) -> ClassDocstring:
griffe_node = self._get_griffe_node(class_node.fullname)
griffe_node = self.griffe_index.get(class_node.fullname, None)

if griffe_node is None: # pragma: no cover
raise TypeError(f"Expected a griffe node for {class_node.fullname}, got None.")
msg = (
f"Something went wrong while searching for the docstring for {class_node.fullname}. Please make sure"
" that all directories with python files have an __init__.py file."
)
logging.warning(msg)

description = ""
docstring = ""
examples = []
if griffe_node.docstring is not None:
if griffe_node is not None and griffe_node.docstring is not None:
docstring = griffe_node.docstring.value.strip("\n")

try:
Expand All @@ -76,7 +91,7 @@ def get_function_documentation(self, function_node: nodes.FuncDef) -> FunctionDo
docstring = ""
description = ""
examples = []
griffe_docstring = self.__get_cached_docstring(function_node.fullname)
griffe_docstring = self._get_griffe_docstring(function_node.fullname)
if griffe_docstring is not None:
docstring = griffe_docstring.value.strip("\n")

Expand Down Expand Up @@ -110,9 +125,9 @@ def get_parameter_documentation(
# For constructors (__init__ functions) the parameters are described on the class
if function_name == "__init__" and parent_class_qname:
parent_qname = parent_class_qname.replace("/", ".")
griffe_docstring = self.__get_cached_docstring(parent_qname)
griffe_docstring = self._get_griffe_docstring(parent_qname)
else:
griffe_docstring = self.__get_cached_docstring(function_qname)
griffe_docstring = self._get_griffe_docstring(function_qname)

# Find matching parameter docstrings
matching_parameters = []
Expand All @@ -123,7 +138,7 @@ def get_parameter_documentation(
# https://github.com/Safe-DS/Library-Analyzer/issues/10)
if self.parser == Parser.numpy and len(matching_parameters) == 0 and function_name == "__init__":
# Get constructor docstring & find matching parameter docstrings
constructor_docstring = self.__get_cached_docstring(function_qname)
constructor_docstring = self._get_griffe_docstring(function_qname)
if constructor_docstring is not None:
matching_parameters = self._get_matching_docstrings(constructor_docstring, parameter_name, "param")

Expand All @@ -136,7 +151,7 @@ def get_parameter_documentation(
raise TypeError(f"Expected parameter docstring, got {type(last_parameter)}.")

if griffe_docstring is None: # pragma: no cover
griffe_docstring = Docstring("")
griffe_docstring = griffe_models.Docstring("")

annotation = last_parameter.annotation
if annotation is None:
Expand All @@ -154,27 +169,23 @@ def get_parameter_documentation(
description=last_parameter.description.strip("\n") or "",
)

def get_attribute_documentation(
self,
parent_class_qname: str,
attribute_name: str,
) -> AttributeDocstring:
def get_attribute_documentation(self, parent_class_qname: str, attribute_name: str) -> AttributeDocstring:
parent_class_qname = parent_class_qname.replace("/", ".")

# Find matching attribute docstrings
parent_qname = parent_class_qname
griffe_docstring = self.__get_cached_docstring(parent_qname)
griffe_docstring = self._get_griffe_docstring(parent_qname)
if griffe_docstring is None:
matching_attributes = []
griffe_docstring = Docstring("")
griffe_docstring = griffe_models.Docstring("")
else:
matching_attributes = self._get_matching_docstrings(griffe_docstring, attribute_name, "attr")

# For Numpydoc, if the class has a constructor we have to check both the class and then the constructor
# (see issue https://github.com/Safe-DS/Library-Analyzer/issues/10)
if self.parser == Parser.numpy and len(matching_attributes) == 0:
constructor_qname = f"{parent_class_qname}.__init__"
constructor_docstring = self.__get_cached_docstring(constructor_qname)
constructor_docstring = self._get_griffe_docstring(constructor_qname)

# Find matching parameter docstrings
if constructor_docstring is not None:
Expand All @@ -198,7 +209,7 @@ def get_attribute_documentation(

def get_result_documentation(self, function_qname: str) -> list[ResultDocstring]:
# Find matching parameter docstrings
griffe_docstring = self.__get_cached_docstring(function_qname)
griffe_docstring = self._get_griffe_docstring(function_qname)

if griffe_docstring is None:
return []
Expand Down Expand Up @@ -251,7 +262,7 @@ def get_result_documentation(self, function_qname: str) -> list[ResultDocstring]

@staticmethod
def _get_matching_docstrings(
function_doc: Docstring,
function_doc: griffe_models.Docstring,
name: str,
type_: Literal["attr", "param"],
) -> list[DocstringAttribute | DocstringParameter]:
Expand All @@ -278,7 +289,7 @@ def _get_matching_docstrings(
def _griffe_annotation_to_api_type(
self,
annotation: Expr | str,
docstring: Docstring,
docstring: griffe_models.Docstring,
) -> sds_types.AbstractType | None:
if isinstance(annotation, ExprName | ExprAttribute):
if annotation.canonical_path == "typing.Any":
Expand All @@ -291,6 +302,8 @@ def _griffe_annotation_to_api_type(
return sds_types.NamedType(name="float", qname="builtins.float")
elif annotation.canonical_path == "str":
return sds_types.NamedType(name="str", qname="builtins.str")
elif annotation.canonical_path == "bytes":
return sds_types.NamedType(name="bytes", qname="builtins.bytes")
elif annotation.canonical_path == "list":
return sds_types.ListType(types=[])
elif annotation.canonical_path == "tuple":
Expand Down Expand Up @@ -403,49 +416,15 @@ def _remove_default_from_griffe_annotation(self, annotation: str) -> str:
return annotation.split(", default")[0]
return annotation

def _get_griffe_node(self, qname: str) -> Object | None:
node_qname_parts = qname.split(".")
griffe_node = self.griffe_build
for part in node_qname_parts:
if part in griffe_node.modules:
griffe_node = griffe_node.modules[part]
elif part in griffe_node.classes:
griffe_node = griffe_node.classes[part]
elif part in griffe_node.functions:
griffe_node = griffe_node.functions[part]
elif part in griffe_node.attributes:
griffe_node = griffe_node.attributes[part]
elif part == "__init__" and griffe_node.is_class:
return None
elif griffe_node.name == part:
continue
else: # pragma: no cover
msg = (
f"Something went wrong while searching for the docstring for {qname}. Please make sure"
" that all directories with python files have an __init__.py file.",
)
logging.warning(msg)

return griffe_node

def __get_cached_docstring(self, qname: str) -> Docstring | None:
"""
Return the Docstring for the given function node.
def _get_griffe_docstring(self, qname: str) -> griffe_models.Docstring | None:
griffe_node = self.griffe_index.get(qname, None)

It is only recomputed when the function node differs from the previous one that was passed to this function.
This avoids reparsing the docstring for the function itself and all of its parameters.
if griffe_node is not None:
return griffe_node.docstring

On Lars's system this caused a significant performance improvement: Previously, 8.382s were spent inside the
function get_parameter_documentation when parsing sklearn. Afterward, it was only 2.113s.
"""
if self.__cached_node != qname or qname.endswith("__init__"):
self.__cached_node = qname

griffe_node = self._get_griffe_node(qname)
if griffe_node is not None:
griffe_docstring = griffe_node.docstring
self.__cached_docstring = griffe_docstring
else:
self.__cached_docstring = None

return self.__cached_docstring
msg = (
f"Something went wrong while searching for the docstring for {qname}. Please make sure"
" that all directories with python files have an __init__.py file.",
)
logging.warning(msg)
return None
70 changes: 46 additions & 24 deletions src/safeds_stubgen/stubs_generator/_stub_string_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ def __init__(self, api: API, convert_identifiers: bool) -> None:
self.class_generics: list = []
self.module_imports: set[str] = set()
self.currently_creating_reexport_data: bool = False
self.import_index: dict[str, str] = {}

self.api = api
self.api: API = api
self.naming_convention = NamingConvention.SAFE_DS if convert_identifiers else NamingConvention.PYTHON
self.classes_outside_package: set[str] = set()
self.reexport_modules: dict[str, list[Class | Function]] = defaultdict(list)
Expand Down Expand Up @@ -1012,20 +1013,20 @@ def _has_node_shorter_reexport(self, node: Class | Function) -> bool:
return False

def _is_path_connected_to_class(self, path: str, class_path: str) -> bool:
if class_path.endswith(path):
if class_path.endswith(f"/{path}") or class_path == path:
return True

name = path.split("/")[-1]
class_name = class_path.split("/")[-1]
for reexport in self.api.reexport_map:
if reexport.endswith(name):
for module in self.api.reexport_map[reexport]:
if reexport.endswith(f"/{name}") or reexport == name:
for module in self.api.reexport_map[reexport]: # pragma: no cover
# Added "no cover" since I can't recreate this in the tests
if (
path.startswith(module.id)
and class_path.startswith(module.id)
and path.lstrip(module.id).lstrip("/") == name == class_name
): # pragma: no cover
):
return True

return False
Expand All @@ -1047,28 +1048,49 @@ def _add_to_imports(self, import_qname: str) -> None:

module_id = self._get_module_id(get_actual_id=True).replace("/", ".")
if module_id not in import_qname:
# We need the full path for an import from the same package, but we sometimes don't get enough information,
# therefore we have to search for the class and get its id
import_qname_path = import_qname.replace(".", "/")
in_package = False
qname = ""
for class_id in self.api.classes:
if self._is_path_connected_to_class(import_qname_path, class_id):
qname = class_id.replace("/", ".")

name = qname.split(".")[-1]
shortest_qname, _ = _get_shortest_public_reexport_and_alias(
reexport_map=self.api.reexport_map,
name=name,
qname=qname,
is_module=False,
)
module_id_parts = module_id.split(".")

# First we hope that we already found and indexed the type we are searching
if import_qname in self.import_index:
qname = self.import_index[import_qname]

# To save performance we next try to build the possible paths the type could originate from
if not qname:
for i in range(1, len(module_id_parts)):
test_id = ".".join(module_id_parts[:-i]) + "." + import_qname
if test_id.replace(".", "/") in self.api.classes:
qname = test_id
break

# If the tries above did not work we have to use this performance heavy way.
# We need the full path for an import from the same package, but we sometimes don't get enough
# information, therefore we have to search for the class and get its id
if not qname:
import_qname_path = import_qname.replace(".", "/")
import_path_name = import_qname_path.split("/")[-1]
for class_id in self.api.classes:
if (import_path_name == class_id.split("/")[-1] and
self._is_path_connected_to_class(import_qname_path, class_id)):
qname = class_id.replace("/", ".")
break

in_package = False
if qname:
self.import_index[import_qname] = qname

name = qname.split(".")[-1]
shortest_qname, _ = _get_shortest_public_reexport_and_alias(
reexport_map=self.api.reexport_map,
name=name,
qname=qname,
is_module=False,
)

if shortest_qname:
qname = f"{shortest_qname}.{name}"
if shortest_qname:
qname = f"{shortest_qname}.{name}"

in_package = True
break
in_package = True

qname = qname or import_qname

Expand Down
Loading