From 7e3291b250cac0823286d1c1f2071aad2c2805ec Mon Sep 17 00:00:00 2001 From: rodrigo Date: Fri, 12 Sep 2025 20:13:11 +0100 Subject: [PATCH 1/2] resolve_getitem_source implementation --- hls4ml/converters/pytorch_to_hls.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index 4bc3fbe854..ffc0af7666 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -238,13 +238,27 @@ def parse_pytorch_model(config, verbose=True): # if a 'getitem' is the input to a node, step back in the graph to find the real source of the input elif "getitem" in node.args[0].name: - for tmp_node in traced_model.nodes: - if tmp_node.name == node.args[0].name: - if "getitem" in tmp_node.args[0].name: - raise Exception('Nested getitem calles not resolved at the moment.') - input_names = [inputs_map.get(str(tmp_node.args[0]), str(tmp_node.args[0]))] - input_shapes = [output_shapes[str(tmp_node.args[0])]] - node.args = [tmp_node.args[0]] + def resolve_getitem_source(node_name, visited=None): + """Recursively resolve nested getitem calls to find the actual source node.""" + if visited is None: + visited = set() + + if node_name in visited: + raise Exception(f'Circular reference detected in getitem chain: {node_name}') + visited.add(node_name) + + for tmp_node in traced_model.nodes: + if tmp_node.name == node_name: + if "getitem" in tmp_node.args[0].name: + return resolve_getitem_source(tmp_node.args[0].name, visited) + else: + return tmp_node.args[0] + raise Exception(f'Could not find source node for getitem: {node_name}') + + source_node = resolve_getitem_source(node.args[0].name) + input_names = [inputs_map.get(str(source_node), str(source_node))] + input_shapes = [output_shapes[str(source_node)]] + node.args = [source_node] else: input_shapes = [output_shapes[str(i)] for i in node.args] # for Conv layers @@ -426,3 +440,4 @@ def parse_pytorch_model(config, verbose=True): def pytorch_to_hls(config): layer_list, input_layers, output_layers = parse_pytorch_model(config) return ModelGraph.from_layer_list(config, layer_list, inputs=input_layers, outputs=output_layers) + From eee2ea8992bb3412f5bbbab3ddcf443a1c43529a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Sep 2025 14:00:11 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit hooks --- hls4ml/converters/pytorch_to_hls.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index ffc0af7666..4414816fd3 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -238,15 +238,16 @@ def parse_pytorch_model(config, verbose=True): # if a 'getitem' is the input to a node, step back in the graph to find the real source of the input elif "getitem" in node.args[0].name: + def resolve_getitem_source(node_name, visited=None): """Recursively resolve nested getitem calls to find the actual source node.""" if visited is None: visited = set() - + if node_name in visited: raise Exception(f'Circular reference detected in getitem chain: {node_name}') visited.add(node_name) - + for tmp_node in traced_model.nodes: if tmp_node.name == node_name: if "getitem" in tmp_node.args[0].name: @@ -254,7 +255,7 @@ def resolve_getitem_source(node_name, visited=None): else: return tmp_node.args[0] raise Exception(f'Could not find source node for getitem: {node_name}') - + source_node = resolve_getitem_source(node.args[0].name) input_names = [inputs_map.get(str(source_node), str(source_node))] input_shapes = [output_shapes[str(source_node)]] @@ -440,4 +441,3 @@ def resolve_getitem_source(node_name, visited=None): def pytorch_to_hls(config): layer_list, input_layers, output_layers = parse_pytorch_model(config) return ModelGraph.from_layer_list(config, layer_list, inputs=input_layers, outputs=output_layers) -