|
3 | 3 | import resource |
4 | 4 | import sys |
5 | 5 | from abc import abstractmethod |
6 | | -from collections.abc import Sequence |
| 6 | +from collections.abc import Generator, Sequence |
7 | 7 | from functools import cached_property |
8 | 8 | from os import PathLike |
9 | 9 | from pathlib import Path |
@@ -744,7 +744,7 @@ def get_symbol(self, name: str) -> Symbol | None: |
744 | 744 | Returns: |
745 | 745 | Symbol | None: The found symbol, or None if not found. |
746 | 746 | """ |
747 | | - if symbol := self.resolve_name(name, self.end_byte): |
| 747 | + if symbol := next(self.resolve_name(name, self.end_byte), None): |
748 | 748 | if isinstance(symbol, Symbol): |
749 | 749 | return symbol |
750 | 750 | return next((x for x in self.symbols if x.name == name), None) |
@@ -819,7 +819,7 @@ def get_class(self, name: str) -> TClass | None: |
819 | 819 | Returns: |
820 | 820 | TClass | None: The matching Class object if found, None otherwise. |
821 | 821 | """ |
822 | | - if symbol := self.resolve_name(name, self.end_byte): |
| 822 | + if symbol := next(self.resolve_name(name, self.end_byte), None): |
823 | 823 | if isinstance(symbol, Class): |
824 | 824 | return symbol |
825 | 825 |
|
@@ -880,13 +880,41 @@ def valid_symbol_names(self) -> dict[str, Symbol | TImport | WildcardImport[TImp |
880 | 880 |
|
881 | 881 | @noapidoc |
882 | 882 | @reader |
883 | | - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: |
| 883 | + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: |
| 884 | + """Resolves a name to a symbol, import, or wildcard import within the file's scope. |
| 885 | +
|
| 886 | + Performs name resolution by first checking the file's valid symbols and imports. When a start_byte |
| 887 | + is provided, ensures proper scope handling by only resolving to symbols that are defined before |
| 888 | + that position in the file. |
| 889 | +
|
| 890 | + Args: |
| 891 | + name (str): The name to resolve. |
| 892 | + start_byte (int | None): If provided, only resolves to symbols defined before this byte position |
| 893 | + in the file. Used for proper scope handling. Defaults to None. |
| 894 | + strict (bool): When True and using start_byte, only yields symbols if found in the correct scope. |
| 895 | + When False, allows falling back to global scope. Defaults to True. |
| 896 | +
|
| 897 | + Yields: |
| 898 | + Symbol | Import | WildcardImport: The resolved symbol, import, or wildcard import that matches |
| 899 | + the name and scope requirements. Yields at most one result. |
| 900 | + """ |
884 | 901 | if resolved := self.valid_symbol_names.get(name): |
| 902 | + # If we have a start_byte and the resolved symbol is after it, |
| 903 | + # we need to look for earlier definitions of the symbol |
885 | 904 | if start_byte is not None and resolved.end_byte > start_byte: |
886 | | - for symbol in self.symbols: |
| 905 | + # Search backwards through symbols to find the most recent definition |
| 906 | + # that comes before our start_byte position |
| 907 | + for symbol in reversed(self.symbols): |
887 | 908 | if symbol.start_byte <= start_byte and symbol.name == name: |
888 | | - return symbol |
889 | | - return resolved |
| 909 | + yield symbol |
| 910 | + return |
| 911 | + # If strict mode and no valid symbol found, return nothing |
| 912 | + if not strict: |
| 913 | + return |
| 914 | + # Either no start_byte constraint or symbol is before start_byte |
| 915 | + yield resolved |
| 916 | + return |
| 917 | + return |
890 | 918 |
|
891 | 919 | @property |
892 | 920 | @reader |
|
0 commit comments