Skip to content

Commit 53f5680

Browse files
authored
Break infinite recursion in Struct.discoverMembers() (#1137)
1 parent c404452 commit 53f5680

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

cuda_bindings/setup.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,28 @@ def __init__(self, name, members):
114114
self._member_names += [var_name]
115115
self._member_types += [var_type]
116116

117-
def discoverMembers(self, memberDict, prefix):
117+
def discoverMembers(self, memberDict, prefix, seen=None):
118+
# Prevent infinite recursion on self- or mutually-referential types
119+
if seen is None:
120+
seen = set()
121+
elif self._name in seen:
122+
return []
123+
118124
discovered = []
125+
next_seen = set(seen)
126+
next_seen.add(self._name)
127+
119128
for memberName, memberType in zip(self._member_names, self._member_types):
120129
if memberName:
121-
discovered += [".".join([prefix, memberName])]
122-
if memberType in memberDict:
123-
discovered += memberDict[memberType].discoverMembers(
124-
memberDict, discovered[-1] if memberName else prefix
130+
discovered.append(".".join([prefix, memberName]))
131+
132+
# Normalize to base type for lookup (strip qualifiers/pointers)
133+
t = memberType.replace("const ", "").replace("volatile ", "").strip().rstrip(" *")
134+
if t in memberDict and t != self._name:
135+
discovered += memberDict[t].discoverMembers(
136+
memberDict, discovered[-1] if memberName else prefix, next_seen
125137
)
138+
126139
return discovered
127140

128141
def __repr__(self):
@@ -153,16 +166,16 @@ def parse_headers(header_dict):
153166
r"char reserved\[52 - sizeof\(CUcheckpointGpuPair \*\)\];": rf"char reserved[{52 - 8}];",
154167
}
155168

156-
print(f'Parsing headers in "{include_path_list}" (Caching = {PARSER_CACHING})')
169+
print(f'Parsing headers in "{include_path_list}" (Caching = {PARSER_CACHING})', flush=True)
157170
for library, header_paths in header_dict.items():
158-
print(f"Parsing {library} headers")
171+
print(f"Parsing {library} headers", flush=True)
159172
parser = CParser(
160173
header_paths, cache="./cache_{}".format(library.split(".")[0]) if PARSER_CACHING else None, replace=replace
161174
)
162175

163176
if library == "driver":
164177
CUDA_VERSION = parser.defs["macros"].get("CUDA_VERSION", "Unknown")
165-
print(f"Found CUDA_VERSION: {CUDA_VERSION}")
178+
print(f"Found CUDA_VERSION: {CUDA_VERSION}", flush=True)
166179

167180
# Combine types with others since they sometimes get tangled
168181
found_types += {key for key in parser.defs["types"]}
@@ -211,10 +224,10 @@ def generate_output(infile, local):
211224
if os.path.exists(outfile):
212225
with open(outfile) as f:
213226
if f.read() == pxdcontent:
214-
print(f"Skipping {infile} (No change)")
227+
print(f"Skipping {infile} (No change)", flush=True)
215228
return
216229
with open(outfile, "w") as f:
217-
print(f"Generating {infile}")
230+
print(f"Generating {infile}", flush=True)
218231
f.write(pxdcontent)
219232

220233

0 commit comments

Comments
 (0)