Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/infinicore/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ def __init__(self, type=None, index=None):
def __getattr__(self, name):
    """Lazily construct and cache the ``_underlying`` device handle.

    Python only calls ``__getattr__`` after normal attribute lookup has
    failed, so the handle is built on first access and then stored on the
    instance; later reads find the stored attribute directly.

    Args:
        name: Name of the attribute that normal lookup could not find.

    Returns:
        The result of ``device._to_infinicore_device(self.type, self.index)``
        when ``name`` is ``"_underlying"``.

    Raises:
        AttributeError: For every other name. ``AttributeError`` (rather
            than ``KeyError``) is what ``hasattr()``, three-argument
            ``getattr()`` and the ``copy``/``pickle`` machinery expect
            from a failed attribute lookup.
    """
    # Reject unknown names *before* caching anything, so a typo can never
    # leave a bogus attribute holding the device handle on the instance.
    if name != "_underlying":
        raise AttributeError(f"device does not support '{name}' attribute.")

    underlying = device._to_infinicore_device(self.type, self.index)
    setattr(self, name, underlying)
    return underlying

def __repr__(self):
Expand Down
80 changes: 74 additions & 6 deletions python/infinicore/nn/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,15 +481,12 @@ def _load_from_state_dict(
f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}"
)

if (
(param.shape == input_param.shape)
and (param.dtype == input_param.dtype)
and (param.device == input_param.device)
if (param.shape == input_param.shape) and (
param.dtype == input_param.dtype
):
param.copy_(input_param)
else:
print(f"param '{name}' don't match input_param '{key}'")
setattr(self, name, input_param)
raise KeyError("param don't match input_param.")

elif strict:
missing_keys.append(key)
Expand Down Expand Up @@ -842,6 +839,77 @@ def named_children(
memo.add(module)
yield name, module

def get_submodule(self, target: str) -> "InfiniCoreModule":
    """Look up a nested submodule by its dotted path.

    Args:
        target: Fully-qualified, dot-separated name of the submodule.
            The empty string refers to this module itself.

    Returns:
        infinicore.nn.Module: The submodule that ``target`` resolves to.

    Raises:
        AttributeError: If any path component is missing on the module
            reached so far, or resolves to an object that is not an
            instance of ``nn.Module``.
    """
    current = self

    if target != "":
        # Descend one path component at a time, validating each hop.
        for attr_name in target.split("."):
            if not hasattr(current, attr_name):
                message = current._get_name() + " has no attribute `" + attr_name + "`"
                raise AttributeError(message)

            current = getattr(current, attr_name)

            if not isinstance(current, infinicore.nn.Module):
                raise AttributeError("`" + attr_name + "` is not an nn.Module")

    return current

def get_parameter(self, target: str) -> "Parameter":
    """Look up a parameter by its dotted path.

    The path is split at its last ``"."``: everything before it names the
    owning submodule (resolved through ``get_submodule``), and the final
    component names the parameter attribute on that submodule.

    Args:
        target: Fully-qualified, dot-separated name of the Parameter.

    Returns:
        infinicore.nn.Parameter: The Parameter that ``target`` resolves to.

    Raises:
        AttributeError: If the path is invalid, or the final attribute
            exists but is not an ``nn.Parameter``.
    """
    owner_path, _sep, leaf = target.rpartition(".")
    owner = self.get_submodule(owner_path)

    if hasattr(owner, leaf):
        candidate = getattr(owner, leaf)
        if isinstance(candidate, Parameter):
            return candidate
        raise AttributeError("`" + leaf + "` is not an nn.Parameter")

    raise AttributeError(owner._get_name() + " has no attribute `" + leaf + "`")

def load_parameter(self, target: str, input_param: Tensor):
    """Copy one checkpoint tensor into the parameter named by ``target``.

    Args:
        target: The fully-qualified string name of the Parameter to look for.
        input_param: The tensor obtained from the model.safetensors file.

    Raises:
        AttributeError: Raised by ``get_parameter`` when ``target`` does
            not resolve to a parameter.
        KeyError: If the shape or dtype of ``input_param`` differs from
            the existing parameter. The message names the parameter and
            both shape/dtype pairs so a failing weight can be identified.
    """
    param = self.get_parameter(target)

    matches = (param.shape == input_param.shape) and (
        param.dtype == input_param.dtype
    )
    if not matches:
        raise KeyError(
            f"param '{target}' does not match input_param: "
            f"expected shape {param.shape} and dtype {param.dtype}, "
            f"got shape {input_param.shape} and dtype {input_param.dtype}."
        )

    # Copy in place so the existing parameter object is kept.
    param.copy_(input_param)

def eval(self: T) -> T:
r"""Sets the module in evaluation mode.

Expand Down
7 changes: 5 additions & 2 deletions python/infinicore/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def __getattr__(self, name):
getattr(self._underlying, name)
),
)
else:
raise KeyError(f"Tensor does not support '{name}' attribute.")

return getattr(self, name)

@property
Expand Down Expand Up @@ -116,8 +119,8 @@ def __matmul__(self, other):
def __mul__(self, other):
    """Implement ``self * other`` by delegating to ``infinicore.mul``."""
    product = infinicore.mul(self, other)
    return product

def narrow(self, dim: int, start: int, length: int):
    """Return a ``Tensor`` wrapping the underlying object's ``narrow`` result.

    Only one definition is kept: an earlier ``narrow`` that delegated to
    ``infinicore.narrow`` was immediately redefined and therefore never
    took effect.

    Args:
        dim: Forwarded unchanged to ``self._underlying.narrow``.
        start: Forwarded unchanged to ``self._underlying.narrow``.
        length: Forwarded unchanged to ``self._underlying.narrow``.

    Returns:
        A new ``Tensor`` constructed from the underlying ``narrow`` result.
    """
    # NOTE(review): presumably selects ``length`` elements starting at
    # ``start`` along dimension ``dim`` -- confirm against the binding.
    return Tensor(self._underlying.narrow(dim, start, length))


def empty(size, *, dtype=None, device=None, pin_memory=False):
Expand Down
Loading