From 5ab206b2207988363a5fc93b34df318fa285d1f6 Mon Sep 17 00:00:00 2001 From: pengcheng888 Date: Fri, 5 Dec 2025 16:52:30 +0800 Subject: [PATCH] =?UTF-8?q?issue/719=20-=20=E4=B8=BApython=E7=9A=84?= =?UTF-8?q?=E5=8D=95=E4=B8=AAtensor=E6=A8=A1=E5=9E=8B=E5=8A=A0=E8=BD=BD?= =?UTF-8?q?=E6=8F=90=E4=BE=9B=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/infinicore/device.py | 6 +- python/infinicore/nn/modules/module.py | 80 ++++++++++++++++++++++++-- python/infinicore/tensor.py | 7 ++- 3 files changed, 84 insertions(+), 9 deletions(-) diff --git a/python/infinicore/device.py b/python/infinicore/device.py index 220865601..010c92e57 100644 --- a/python/infinicore/device.py +++ b/python/infinicore/device.py @@ -34,7 +34,11 @@ def __init__(self, type=None, index=None): def __getattr__(self, name): # Lazily construct and cache an attribute. # such as, self._underlying . - setattr(self, name, device._to_infinicore_device(self.type, self.index)) + if name == "_underlying": + setattr(self, name, device._to_infinicore_device(self.type, self.index)) + else: + raise KeyError(f"device does not support '{name}' attribute.") + return getattr(self, name) def __repr__(self): diff --git a/python/infinicore/nn/modules/module.py b/python/infinicore/nn/modules/module.py index d21223903..ff9da5bd9 100644 --- a/python/infinicore/nn/modules/module.py +++ b/python/infinicore/nn/modules/module.py @@ -481,15 +481,12 @@ def _load_from_state_dict( f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}" ) - if ( - (param.shape == input_param.shape) - and (param.dtype == input_param.dtype) - and (param.device == input_param.device) + if (param.shape == input_param.shape) and ( + param.dtype == input_param.dtype ): param.copy_(input_param) else: - print(f"param '{name}' don't match input_param '{key}'") - setattr(self, name, input_param) + raise KeyError("param don't match input_param.") elif strict: missing_keys.append(key) @@ -842,6 +839,77 @@ def named_children( memo.add(module) yield name, module + def get_submodule(self, target: str) -> "InfiniCoreModule": + """Return the submodule given by ``target`` if it exists, otherwise throw an error. + + Args: + target: The fully-qualified string name of the submodule to look for. + Returns: + infinicore.nn.Module: The submodule referenced by ``target`` + + Raises: + AttributeError: If at any point along the path resulting from + the target string the (sub)path resolves to a non-existent + attribute name or an object that is not an instance of ``nn.Module``. + """ + if target == "": + return self + + atoms: list[str] = target.split(".") + mod: infinicore.nn.Module = self + + for item in atoms: + if not hasattr(mod, item): + raise AttributeError( + mod._get_name() + " has no attribute `" + item + "`" + ) + mod = getattr(mod, item) + if not isinstance(mod, infinicore.nn.Module): + raise AttributeError("`" + item + "` is not an nn.Module") + + return mod + + def get_parameter(self, target: str) -> "Parameter": + """Return the parameter given by ``target`` if it exists, otherwise throw an error. + + Args: + target: The fully-qualified string name of the Parameter to look for. + + Returns: + infinicore.nn.Parameter: The Parameter referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an``nn.Parameter`` + """ + module_path, _, param_name = target.rpartition(".") + mod: infinicore.nn.Module = self.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError( + mod._get_name() + " has no attribute `" + param_name + "`" + ) + + param: Parameter = getattr(mod, param_name) + + if not isinstance(param, Parameter): + raise AttributeError("`" + param_name + "` is not an nn.Parameter") + + return param + + def load_parameter(self, target: str, input_param: Tensor): + """ + load one parameter into Module. + Args: + target: The fully-qualified string name of the Parameter to look for. + input_param: The tensor obtained from the model.safetensors file + """ + param = self.get_parameter(target) + if (param.shape == input_param.shape) and (param.dtype == input_param.dtype): + param.copy_(input_param) + else: + raise KeyError("param don't match input_param.") + def eval(self: T) -> T: r"""Sets the module in evaluation mode. diff --git a/python/infinicore/tensor.py b/python/infinicore/tensor.py index b72c72e77..b3120e4c0 100644 --- a/python/infinicore/tensor.py +++ b/python/infinicore/tensor.py @@ -42,6 +42,9 @@ def __getattr__(self, name): getattr(self._underlying, name) ), ) + else: + raise KeyError(f"Tensor does not support '{name}' attribute.") + return getattr(self, name) @property @@ -116,8 +119,8 @@ def __matmul__(self, other): def __mul__(self, other): return infinicore.mul(self, other) - def narrow(self, dim, start, length): - return infinicore.narrow(self, dim, start, length) + def narrow(self, dim: int, start: int, length: int): + return Tensor(self._underlying.narrow(dim, start, length)) def empty(size, *, dtype=None, device=None, pin_memory=False):