diff --git a/examples/onnx2xla.py b/examples/onnx2xla.py index 196ab1c03b53..d5105f714f31 100644 --- a/examples/onnx2xla.py +++ b/examples/onnx2xla.py @@ -94,8 +94,7 @@ def onnx_add(a, b, axis=None, broadcast=True): def interpret_onnx(graph, *args): - vals = dict({n.name: a for n, a in zip(graph.input, args)}, - **{n.name: _asarray(n) for n in graph.initializer}) + vals = {n.name: a for n, a in zip(graph.input, args)} for node in graph.node: args = (vals[name] for name in node.input) attrs = {a.name: attribute_handlers[a.type](a) for a in node.attribute}