Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 60 additions & 74 deletions tinygrad/nn/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,19 @@ def _decode_end_pos(self) -> int:

def _parse_ModelProto(self) -> dict:
"""Entry point for parsing the ONNX model."""
obj: dict[str, Any] = {"opset_import": []}
graph: dict|None = None
opset_imports: list[OpSetId] = []
for fid, wire_type in self._parse_message(self.reader.len):
match fid:
case 4: obj["domain"] = self.reader.read_string()
case 5: obj["model_version"] = self.reader.read_int64()
case 7: obj["graph"] = self._parse_GraphProto()
case 8: obj["opset_import"].append(self._parse_OperatorSetIdProto())
case 7: graph = self._parse_GraphProto()
case 8: opset_imports.append(self._parse_OperatorSetIdProto())
case _: self.reader.skip_field(wire_type)

assert graph is not None
# update opset version
opset_imports = {Domain.from_onnx(x.get('domain')):x.get('version', 1) for x in obj["opset_import"]}
for n in obj["graph"]["node"]:
n_ = n["parsed_node"]
n["parsed_node"] = OnnxNode(n_.op, OpSetId(n_.opset_id.domain, opset_imports.get(n_.opset_id.domain, 1)), n_.inputs, n_.outputs, n_.opts)
return obj
versions = {opset.domain: opset.version for opset in opset_imports}
graph["node"] = [OnnxNode(n.op, OpSetId(n.opset_id.domain, versions.get(n.opset_id.domain, 1)), n.inputs, n.outputs, n.opts)
for n in graph["node"]]
return graph

def _parse_GraphProto(self) -> dict:
obj: dict[str, Any] = {"node": [], "initializer": [], "input": [], "output": []}
Expand All @@ -181,26 +179,23 @@ def _parse_GraphProto(self) -> dict:
case _: self.reader.skip_field(wire_type)
return obj

def _parse_NodeProto(self) -> dict:
obj: dict[str, Any] = {"input": [], "output": [], "attribute": [], "domain": None}
def _parse_NodeProto(self) -> OnnxNode:
inputs: list[str] = []
outputs: list[str] = []
attributes: list[tuple[str, Any]] = []
domain: str|None = None
op_type = ""
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["input"].append(self.reader.read_string())
case 2: obj["output"].append(self.reader.read_string())
case 3: obj["name"] = self.reader.read_string()
case 4: obj["op_type"] = self.reader.read_string()
case 5: obj["attribute"].append(self._parse_AttributeProto())
case 6: obj["doc_string"] = self.reader.read_string()
case 7: obj["domain"] = self.reader.read_string()
case 1: inputs.append(self.reader.read_string())
case 2: outputs.append(self.reader.read_string())
case 4: op_type = self.reader.read_string()
case 5: attributes.append(self._parse_AttributeProto())
case 7: domain = self.reader.read_string()
case _: self.reader.skip_field(wire_type)
return OnnxNode(op_type, OpSetId(Domain.from_onnx(domain), 1), tuple(inputs), tuple(outputs), dict(attributes))

# parse node
attributes = {attr_dict["name"]: attr_dict[AttributeType(attr_dict["type"]).to_field_name()] for attr_dict in obj["attribute"]}
opset_id = OpSetId(Domain.from_onnx(obj.get('domain')), 1) # default version, to be updated later in _parse_ModelProto
obj["parsed_node"] = OnnxNode(obj["op_type"], opset_id, tuple(obj["input"]), tuple(obj["output"]), attributes)
return obj

def _parse_TensorProto(self) -> dict:
def _parse_TensorProto(self) -> tuple[str, Tensor]:
obj: dict[str, Any] = {"dims": []}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
Expand All @@ -220,18 +215,16 @@ def _parse_TensorProto(self) -> dict:
# load external data
if self.load_external_data and obj.get("data_location", 0) == 1:
if "external_data" not in obj: raise ValueError("no external_data")
location, length, offset = None, None, 0
for kv in obj["external_data"]:
if kv["key"] == "location": location = kv["value"]
elif kv["key"] == "offset": offset = int(kv["value"])
elif kv["key"] == "length": length = int(kv["value"])
if location is None: raise ValueError("no location in external_data")
ext = dict(obj["external_data"])
if "location" not in ext: raise ValueError("no location in external_data")
offset = int(ext.get("offset", "0"))
length = int(ext["length"]) if "length" in ext else None

if self.file_path is None:
if isinstance(self.tensor.device, str) and self.tensor.device.startswith("DISK:"):
self.file_path = pathlib.Path(self.tensor.device[5:])
else: raise ValueError("onnx external_data needs the origin file path, try passing onnx file path to onnx_load")
ext_path = self.file_path.parent.joinpath(location)
ext_path = self.file_path.parent.joinpath(ext["location"])
if not ext_path.exists(): raise FileNotFoundError(f"external location not exists: {ext_path}")

ext_tensor = Tensor(ext_path)
Expand All @@ -241,58 +234,51 @@ def _parse_TensorProto(self) -> dict:
# parse tensor
to_dtype = dtype_fallback(true_dtype := OnnxDataType(obj['data_type']).to_dtype(), "buffer parse")
shape = tuple(obj['dims'])
present_fields = [field for field in ['float_data', 'int32_data', 'int64_data', 'double_data', 'uint64_data', 'raw_data'] if field in obj]
assert len(present_fields) == 1, f"only 1 data field is allowed from {obj=}"
data = obj[present_fields[0]]
if not isinstance(data, Tensor):
obj["parsed_tensor"] = Tensor(data, dtype=to_dtype).reshape(shape)
return obj
assert isinstance(data, Tensor) and data.dtype == dtypes.uint8, data
data_fields = [f for f in ('float_data','int32_data','int64_data','double_data','uint64_data','raw_data') if f in obj]
data = obj[get_single_element(data_fields)]
name = obj.get("name", "")
if not isinstance(data, Tensor): return name, Tensor(data, dtype=to_dtype).reshape(shape)
assert data.dtype == dtypes.uint8, data
data = data.bitcast(true_dtype).reshape(shape)
data = data.to(Device.DEFAULT) if true_dtype is to_dtype else data.to("cpu").cast(to_dtype).to(Device.DEFAULT)
# const folding
if shape == ():
if data.dtype == dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32)
data = Tensor(data.item(), dtype=to_dtype).reshape(shape)
obj["parsed_tensor"] = data
return obj
return name, data

def _parse_AttributeProto(self) -> dict:
def _parse_AttributeProto(self) -> tuple[str, Any]:
obj: dict[str, Any] = {"floats": [], "ints": [], "strings": []}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["name"] = self.reader.read_string()
case 2: obj["f"] = self.reader.read_float()
case 3: obj["i"] = self.reader.read_int64()
case 4: obj["s"] = self.reader.read_bytes().data().tobytes().decode("utf8")
case 5: obj["t"] = self._parse_TensorProto()['parsed_tensor']
case 5: obj["t"] = self._parse_TensorProto()[1]
case 6: obj["g"] = OnnxRunner._from_subgraph(self._parse_GraphProto())
case 7: obj["floats"].append(self.reader.read_float())
case 8: obj["ints"].append(self.reader.read_int64())
case 9: obj["strings"].append(self.reader.read_bytes().data().tobytes().decode("utf8"))
case 20: obj["type"] = self.reader.read_int64()
case _: self.reader.skip_field(wire_type)
obj["floats"], obj["ints"], obj["strings"] = tuple(obj["floats"]), tuple(obj["ints"]), tuple(obj["strings"])
return obj
return obj["name"], obj[AttributeType(obj["type"]).to_field_name()]

def _parse_ValueInfoProto(self) -> dict:
obj: dict[str, Any] = {}
def _parse_ValueInfoProto(self) -> tuple[str, OnnxValue|None]:
name, type_obj = "", None
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["name"] = self.reader.read_string()
case 2: obj["type"] = self._parse_TypeProto()
case 1: name = self.reader.read_string()
case 2: type_obj = self._parse_TypeProto()
case _: self.reader.skip_field(wire_type)

# parse type
if "type" not in obj: return {**obj, "parsed_type": None}
type_obj = obj["type"]
if type_obj is None: return name, None
if is_optional := "optional_type" in type_obj: type_obj = type_obj["optional_type"]["elem_type"]
if is_sequence := "sequence_type" in type_obj: type_obj = type_obj["sequence_type"]["elem_type"]
assert "tensor_type" in type_obj, type_obj
shape_dims = type_obj['tensor_type'].get('shape', {}).get('dim', [])
obj['parsed_type'] = OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims),
OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence)
return obj
return name, OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims),
OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence)

def _parse_TypeProto(self) -> dict:
obj: dict[str, Any] = {}
Expand Down Expand Up @@ -338,23 +324,24 @@ def _parse_TensorShapeProtoDimension(self) -> dict:
case _: self.reader.skip_field(wire_type)
return obj

def _parse_StringStringEntryProto(self) -> dict:
obj: dict[str, Any] = {}
def _parse_StringStringEntryProto(self) -> tuple[str, str]:
key, value = "", ""
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["key"] = self.reader.read_string()
case 2: obj["value"] = self.reader.read_string()
case 1: key = self.reader.read_string()
case 2: value = self.reader.read_string()
case _: self.reader.skip_field(wire_type)
return obj
return key, value

def _parse_OperatorSetIdProto(self) -> dict:
obj: dict[str, Any] = {}
def _parse_OperatorSetIdProto(self) -> OpSetId:
domain: str|None = None
version = 1
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: obj["domain"] = self.reader.read_string()
case 2: obj["version"] = self.reader.read_int64()
case 1: domain = self.reader.read_string()
case 2: version = self.reader.read_int64()
case _: self.reader.skip_field(wire_type)
return obj
return OpSetId(Domain.from_onnx(domain), version)

# ***** python const *****
required_input_python_consts: dict[str, tuple[int, ...]] = {
Expand All @@ -380,16 +367,15 @@ class OnnxRunner:
model_path: The ONNX model, provided as a file path (a string or Path object) or a Tensor.
"""
def __init__(self, model_path: Tensor | str | pathlib.Path):
  """Load the model at `model_path` (with external data enabled) and initialise the runner from its graph."""
  # NOTE(review): assumes OnnxPBParser.parse() returns the graph dict produced by _parse_ModelProto -- confirm
  self._init_from_graph(OnnxPBParser(model_path, load_external_data=True).parse())

def _init_from_graph(self, graph: dict, is_subgraph: bool = False):
self.is_training = any(n['parsed_node'].opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
self.is_training = any(n.opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
self.graph_name = graph["name"] if is_subgraph else ""
self.graph_values = {"": None, **{i["name"]: i["parsed_tensor"] for i in graph["initializer"]}}
self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values}
self.graph_outputs = tuple(o["name"] for o in graph["output"])
self.graph_nodes = tuple(n["parsed_node"] for n in graph["node"])
self.graph_values: dict[str, Any] = {"": None, **dict(graph["initializer"])}
self.graph_inputs = {name: typ for name, typ in graph["input"] if name not in self.graph_values}
self.graph_outputs = tuple(name for name, _ in graph["output"])
self.graph_nodes = tuple(graph["node"])
# track names from initializers and Constant nodes for fast path optimizations
self.const_names: set[str] = set(self.graph_values.keys()) | {o for n in self.graph_nodes if n.op == "Constant" for o in n.outputs}

Expand Down
Loading