From 78579367ad944ff90fd60259ec3d735e54f4d270 Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Fri, 6 Mar 2026 21:06:29 -0500 Subject: [PATCH] don't use intermediate dict in onnx parse also don't parse fields that are never used --- tinygrad/nn/onnx.py | 134 ++++++++++++++++++++------------------------ 1 file changed, 60 insertions(+), 74 deletions(-) diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py index 90d98d7e851e8..e9174b4f99e37 100644 --- a/tinygrad/nn/onnx.py +++ b/tinygrad/nn/onnx.py @@ -153,21 +153,19 @@ def _decode_end_pos(self) -> int: def _parse_ModelProto(self) -> dict: """Entry point for parsing the ONNX model.""" - obj: dict[str, Any] = {"opset_import": []} + graph: dict|None = None + opset_imports: list[OpSetId] = [] for fid, wire_type in self._parse_message(self.reader.len): match fid: - case 4: obj["domain"] = self.reader.read_string() - case 5: obj["model_version"] = self.reader.read_int64() - case 7: obj["graph"] = self._parse_GraphProto() - case 8: obj["opset_import"].append(self._parse_OperatorSetIdProto()) + case 7: graph = self._parse_GraphProto() + case 8: opset_imports.append(self._parse_OperatorSetIdProto()) case _: self.reader.skip_field(wire_type) - + assert graph is not None # update opset version - opset_imports = {Domain.from_onnx(x.get('domain')):x.get('version', 1) for x in obj["opset_import"]} - for n in obj["graph"]["node"]: - n_ = n["parsed_node"] - n["parsed_node"] = OnnxNode(n_.op, OpSetId(n_.opset_id.domain, opset_imports.get(n_.opset_id.domain, 1)), n_.inputs, n_.outputs, n_.opts) - return obj + versions = {opset.domain: opset.version for opset in opset_imports} + graph["node"] = [OnnxNode(n.op, OpSetId(n.opset_id.domain, versions.get(n.opset_id.domain, 1)), n.inputs, n.outputs, n.opts) + for n in graph["node"]] + return graph def _parse_GraphProto(self) -> dict: obj: dict[str, Any] = {"node": [], "initializer": [], "input": [], "output": []} @@ -181,26 +179,23 @@ def _parse_GraphProto(self) -> dict: case _: self.reader.skip_field(wire_type) return obj - def _parse_NodeProto(self) -> dict: - obj: dict[str, Any] = {"input": [], "output": [], "attribute": [], "domain": None} + def _parse_NodeProto(self) -> OnnxNode: + inputs: list[str] = [] + outputs: list[str] = [] + attributes: list[tuple[str, Any]] = [] + domain: str|None = None + op_type = "" for fid, wire_type in self._parse_message(self._decode_end_pos()): match fid: - case 1: obj["input"].append(self.reader.read_string()) - case 2: obj["output"].append(self.reader.read_string()) - case 3: obj["name"] = self.reader.read_string() - case 4: obj["op_type"] = self.reader.read_string() - case 5: obj["attribute"].append(self._parse_AttributeProto()) - case 6: obj["doc_string"] = self.reader.read_string() - case 7: obj["domain"] = self.reader.read_string() + case 1: inputs.append(self.reader.read_string()) + case 2: outputs.append(self.reader.read_string()) + case 4: op_type = self.reader.read_string() + case 5: attributes.append(self._parse_AttributeProto()) + case 7: domain = self.reader.read_string() case _: self.reader.skip_field(wire_type) + return OnnxNode(op_type, OpSetId(Domain.from_onnx(domain), 1), tuple(inputs), tuple(outputs), dict(attributes)) - # parse node - attributes = {attr_dict["name"]: attr_dict[AttributeType(attr_dict["type"]).to_field_name()] for attr_dict in obj["attribute"]} - opset_id = OpSetId(Domain.from_onnx(obj.get('domain')), 1) # default version, to be updated later in _parse_ModelProto - obj["parsed_node"] = OnnxNode(obj["op_type"], opset_id, tuple(obj["input"]), tuple(obj["output"]), attributes) - return obj - - def _parse_TensorProto(self) -> dict: + def _parse_TensorProto(self) -> tuple[str, Tensor]: obj: dict[str, Any] = {"dims": []} for fid, wire_type in self._parse_message(self._decode_end_pos()): match fid: @@ -220,18 +215,16 @@ def _parse_TensorProto(self) -> dict: # load external data if self.load_external_data and obj.get("data_location", 0) == 1: if "external_data" not in obj: raise ValueError("no external_data") - location, length, offset = None, None, 0 - for kv in obj["external_data"]: - if kv["key"] == "location": location = kv["value"] - elif kv["key"] == "offset": offset = int(kv["value"]) - elif kv["key"] == "length": length = int(kv["value"]) - if location is None: raise ValueError("no location in external_data") + ext = dict(obj["external_data"]) + if "location" not in ext: raise ValueError("no location in external_data") + offset = int(ext.get("offset", "0")) + length = int(ext["length"]) if "length" in ext else None if self.file_path is None: if isinstance(self.tensor.device, str) and self.tensor.device.startswith("DISK:"): self.file_path = pathlib.Path(self.tensor.device[5:]) else: raise ValueError("onnx external_data needs the origin file path, try passing onnx file path to onnx_load") - ext_path = self.file_path.parent.joinpath(location) + ext_path = self.file_path.parent.joinpath(ext["location"]) if not ext_path.exists(): raise FileNotFoundError(f"external location not exists: {ext_path}") ext_tensor = Tensor(ext_path) @@ -241,23 +234,20 @@ def _parse_TensorProto(self) -> dict: # parse tensor to_dtype = dtype_fallback(true_dtype := OnnxDataType(obj['data_type']).to_dtype(), "buffer parse") shape = tuple(obj['dims']) - present_fields = [field for field in ['float_data', 'int32_data', 'int64_data', 'double_data', 'uint64_data', 'raw_data'] if field in obj] - assert len(present_fields) == 1, f"only 1 data field is allowed from {obj=}" - data = obj[present_fields[0]] - if not isinstance(data, Tensor): - obj["parsed_tensor"] = Tensor(data, dtype=to_dtype).reshape(shape) - return obj - assert isinstance(data, Tensor) and data.dtype == dtypes.uint8, data + data_fields = [f for f in ('float_data','int32_data','int64_data','double_data','uint64_data','raw_data') if f in obj] + data = obj[get_single_element(data_fields)] + name = obj.get("name", "") + if not isinstance(data, Tensor): return name, Tensor(data, dtype=to_dtype).reshape(shape) + assert data.dtype == dtypes.uint8, data data = data.bitcast(true_dtype).reshape(shape) data = data.to(Device.DEFAULT) if true_dtype is to_dtype else data.to("cpu").cast(to_dtype).to(Device.DEFAULT) # const folding if shape == (): if data.dtype == dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32) data = Tensor(data.item(), dtype=to_dtype).reshape(shape) - obj["parsed_tensor"] = data - return obj + return name, data - def _parse_AttributeProto(self) -> dict: + def _parse_AttributeProto(self) -> tuple[str, Any]: obj: dict[str, Any] = {"floats": [], "ints": [], "strings": []} for fid, wire_type in self._parse_message(self._decode_end_pos()): match fid: @@ -265,7 +255,7 @@ def _parse_AttributeProto(self) -> dict: case 2: obj["f"] = self.reader.read_float() case 3: obj["i"] = self.reader.read_int64() case 4: obj["s"] = self.reader.read_bytes().data().tobytes().decode("utf8") - case 5: obj["t"] = self._parse_TensorProto()['parsed_tensor'] + case 5: obj["t"] = self._parse_TensorProto()[1] case 6: obj["g"] = OnnxRunner._from_subgraph(self._parse_GraphProto()) case 7: obj["floats"].append(self.reader.read_float()) case 8: obj["ints"].append(self.reader.read_int64()) @@ -273,26 +263,22 @@ def _parse_AttributeProto(self) -> dict: case 20: obj["type"] = self.reader.read_int64() case _: self.reader.skip_field(wire_type) obj["floats"], obj["ints"], obj["strings"] = tuple(obj["floats"]), tuple(obj["ints"]), tuple(obj["strings"]) - return obj + return obj["name"], obj[AttributeType(obj["type"]).to_field_name()] - def _parse_ValueInfoProto(self) -> dict: - obj: dict[str, Any] = {} + def _parse_ValueInfoProto(self) -> tuple[str, OnnxValue|None]: + name, type_obj = "", None for fid, wire_type in self._parse_message(self._decode_end_pos()): match fid: - case 1: obj["name"] = self.reader.read_string() - case 2: obj["type"] = self._parse_TypeProto() + case 1: name = self.reader.read_string() + case 2: type_obj = self._parse_TypeProto() case _: self.reader.skip_field(wire_type) - - # parse type - if "type" not in obj: return {**obj, "parsed_type": None} - type_obj = obj["type"] + if type_obj is None: return name, None if is_optional := "optional_type" in type_obj: type_obj = type_obj["optional_type"]["elem_type"] if is_sequence := "sequence_type" in type_obj: type_obj = type_obj["sequence_type"]["elem_type"] assert "tensor_type" in type_obj, type_obj shape_dims = type_obj['tensor_type'].get('shape', {}).get('dim', []) - obj['parsed_type'] = OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims), - OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence) - return obj + return name, OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims), + OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence) def _parse_TypeProto(self) -> dict: obj: dict[str, Any] = {} @@ -338,23 +324,24 @@ def _parse_TensorShapeProtoDimension(self) -> dict: case _: self.reader.skip_field(wire_type) return obj - def _parse_StringStringEntryProto(self) -> dict: - obj: dict[str, Any] = {} + def _parse_StringStringEntryProto(self) -> tuple[str, str]: + key, value = "", "" for fid, wire_type in self._parse_message(self._decode_end_pos()): match fid: - case 1: obj["key"] = self.reader.read_string() - case 2: obj["value"] = self.reader.read_string() + case 1: key = self.reader.read_string() + case 2: value = self.reader.read_string() case _: self.reader.skip_field(wire_type) - return obj + return key, value - def _parse_OperatorSetIdProto(self) -> dict: - obj: dict[str, Any] = {} + def _parse_OperatorSetIdProto(self) -> OpSetId: + domain: str|None = None + version = 1 for fid, wire_type in self._parse_message(self._decode_end_pos()): match fid: - case 1: obj["domain"] = self.reader.read_string() - case 2: obj["version"] = self.reader.read_int64() + case 1: domain = self.reader.read_string() + case 2: version = self.reader.read_int64() case _: self.reader.skip_field(wire_type) - return obj + return OpSetId(Domain.from_onnx(domain), version) # ***** python const ***** required_input_python_consts: dict[str, tuple[int, ...]] = { @@ -380,16 +367,15 @@ class OnnxRunner: model_path: The ONNX model, provided as a file path (a string or Path object) or a Tensor. """ def __init__(self, model_path: Tensor | str | pathlib.Path): - model = OnnxPBParser(model_path, load_external_data=True).parse() - self._init_from_graph(model["graph"]) + self._init_from_graph(OnnxPBParser(model_path, load_external_data=True).parse()) def _init_from_graph(self, graph: dict, is_subgraph: bool = False): - self.is_training = any(n['parsed_node'].opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"]) + self.is_training = any(n.opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"]) self.graph_name = graph["name"] if is_subgraph else "" - self.graph_values = {"": None, **{i["name"]: i["parsed_tensor"] for i in graph["initializer"]}} - self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values} - self.graph_outputs = tuple(o["name"] for o in graph["output"]) - self.graph_nodes = tuple(n["parsed_node"] for n in graph["node"]) + self.graph_values: dict[str, Any] = {"": None, **dict(graph["initializer"])} + self.graph_inputs = {name: typ for name, typ in graph["input"] if name not in self.graph_values} + self.graph_outputs = tuple(name for name, _ in graph["output"]) + self.graph_nodes = tuple(graph["node"]) # track names from initializers and Constant nodes for fast path optimizations self.const_names: set[str] = set(self.graph_values.keys()) | {o for n in self.graph_nodes if n.op == "Constant" for o in n.outputs}