src/transformers/core_model_loading.py (18 additions, 13 deletions)
@@ -133,7 +133,7 @@ def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[S
 
     def convert(
         self,
-        value: dict[str, list[torch.Tensor]],
+        value: dict[str, dict[str, torch.Tensor]],
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
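Across this hunk and the ones below, the `value` payload handed to each `convert` override changes shape: every source pattern now maps to a dict keyed by the concrete source key, rather than to an ordered list of tensors. A minimal sketch of the before/after shapes, with invented keys:

```python
import torch

# Before this PR: tensors collected per pattern as an order-dependent list.
value_old = {"experts.*.w1": [torch.zeros(2, 2), torch.zeros(2, 2)]}

# After: each tensor stays associated with the source key it came from.
value_new = {
    "experts.*.w1": {
        "experts.0.w1": torch.zeros(2, 2),
        "experts.1.w1": torch.zeros(2, 2),
    },
}
```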
@@ -158,7 +158,7 @@ def __init__(self, dim: int = 0):
     @torch.no_grad
     def convert(
         self,
-        value: dict[str, list[torch.Tensor]],
+        value: dict[str, torch.Tensor],
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
@@ -186,15 +186,15 @@ def __init__(self, dim: int = 0):
     @torch.no_grad
     def convert(
         self,
-        value: dict[str, list[torch.Tensor]],
+        value: dict[str, dict[str, torch.Tensor]],
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
         config,
     ) -> dict[str, torch.Tensor]:
         merged: dict[str, torch.Tensor] = {}
         for idx, key in enumerate(value.keys()):
-            tensors = value.get(key, [])
+            tensors = list(value.get(key, {}).values())
             if len(source_keys) == 1:
                 key = full_layer_name
             stacked = torch.stack(tensors, dim=self.dim)
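Since the inner collection is now a dict, the merge path has to pull `.values()` before stacking. Python dicts preserve insertion order, so the stack order still follows the order in which `add_tensor` collected the keys. A small sketch with made-up tensors:

```python
import torch

bucket = {"experts.0.w1": torch.ones(2), "experts.1.w1": torch.zeros(2)}
tensors = list(bucket.values())        # replaces the old list bucket
stacked = torch.stack(tensors, dim=0)  # order preserved by dict insertion order
assert stacked.shape == (2, 2)
```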
@@ -215,7 +215,7 @@ def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0):
     @torch.no_grad
     def convert(
         self,
-        value: dict[str, list[torch.Tensor]],
+        value: dict[str, dict[str, torch.Tensor]],
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
@@ -253,7 +253,7 @@ def _apply(self, tensor: torch.Tensor) -> torch.Tensor:
     @torch.no_grad
     def convert(
         self,
-        value: dict[str, list[torch.Tensor]],
+        value: dict[str, dict[str, torch.Tensor]],
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
@@ -276,7 +276,7 @@ class WeightTransform:
     distributed_operation: Optional[TensorParallelLayer] = None
     quantization_operation: Optional[ConversionOps] = None
 
-    collected_tensors: dict[str, list[Future]] = field(default_factory=dict, init=False)
+    collected_tensors: dict[str, dict[str, Future]] = field(default_factory=dict, init=False)
     layer_targets: dict[str, set[str]] = field(default_factory=dict, init=False)
 
     def __post_init__(self):
@@ -286,8 +286,8 @@ def __post_init__(self):
             self.target_keys = [self.target_keys]
 
     def add_tensor(self, target_key: str, source_key: str, source_pattern: str, future: Future):
-        bucket = self.collected_tensors.setdefault(source_pattern, [])
-        bucket += [future]
+        bucket = self.collected_tensors.setdefault(source_pattern, {})
+        bucket[source_key] = future
 
         bucket = self.layer_targets.setdefault(target_key, set())
         bucket.add(source_key)
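Keying the bucket by `source_key` makes collection idempotent: registering the same source key twice overwrites its future instead of appending a duplicate, and each tensor remains traceable to the key it came from. A sketch of the new semantics (keys hypothetical):

```python
from concurrent.futures import Future

collected: dict[str, dict[str, Future]] = {}

def add_tensor(source_pattern: str, source_key: str, future: Future) -> None:
    bucket = collected.setdefault(source_pattern, {})
    bucket[source_key] = future  # overwrites on re-add; the old list appended

f1, f2 = Future(), Future()
add_tensor("experts.*.w1", "experts.0.w1", f1)
add_tensor("experts.*.w1", "experts.0.w1", f2)
assert collected["experts.*.w1"] == {"experts.0.w1": f2}
```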
@@ -296,11 +296,17 @@ def add_tensor(self, target_key: str, source_key: str, source_pattern: str, futu
 @dataclass(slots=True)
 class WeightRenaming(WeightTransform):
     # Special case of WeightTransform that only renames keys without any conversion.
+    collected_tensors: dict[str, Future] = field(default_factory=dict, init=False)
+
+    def add_tensor(self, target_key: str, source_key: str, source_pattern: str, future: Future):
+        self.collected_tensors[target_key] = future
+
+        bucket = self.layer_targets.setdefault(target_key, set())
+        bucket.add(source_key)
 
     def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Optional[MutableSet[str]] = None):
         misc = {}
-        for pattern, futures in self.collected_tensors.items():
-            self.collected_tensors[pattern] = [future.result() for future in futures]
+        for pattern, future in self.collected_tensors.items():
+            self.collected_tensors[pattern] = future.result()
 
         collected_tensors = self.collected_tensors
         if quantizer is not None and self.quantization_operation is not None:
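`WeightRenaming` now narrows both the field type and `add_tensor`: a rename is one-to-one, so it stores a single future per target key in a flat dict, and its `convert` resolves each future directly instead of walking an inner dict. Roughly (names invented):

```python
from concurrent.futures import Future

# General WeightTransform: source_pattern -> {source_key: Future}
# WeightRenaming (1:1 rename): target_key -> Future
renames: dict[str, Future] = {}
f = Future()
f.set_result("tensor placeholder")
renames["model.embed_tokens.weight"] = f

resolved = {k: fut.result() for k, fut in renames.items()}
assert resolved["model.embed_tokens.weight"] == "tensor placeholder"
```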
@@ -334,7 +340,7 @@ def __post_init__(self):
     def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Optional[MutableSet[str]] = None):
         misc = {}
         for pattern, futures in self.collected_tensors.items():
-            self.collected_tensors[pattern] = [future.result() for future in futures]
+            self.collected_tensors[pattern] = {k: future.result() for k, future in futures.items()}
 
         collected_tensors = self.collected_tensors
         for op in self.operations:
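The general `convert` path resolves the nested futures in place while keeping the source-key association, so every downstream operation receives the `dict[str, dict[str, torch.Tensor]]` that the signatures above were updated to declare. A minimal sketch of that resolution step:

```python
from concurrent.futures import Future
import torch

f = Future()
f.set_result(torch.zeros(2, 2))
collected = {"experts.*.w1": {"experts.0.w1": f}}

for pattern, futures in collected.items():
    collected[pattern] = {k: fut.result() for k, fut in futures.items()}
# collected is now dict[str, dict[str, torch.Tensor]]
```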
@@ -752,7 +758,6 @@ def convert_and_load_state_dict_in_model(
                 first_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys
             )
             for target_name, param in realized_value.items():
-                param = param[0] if isinstance(param, list) else param
                 device_match = device_map_regex.match(target_name)
                 param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
                 # Offloading support
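The loading loop drops its defensive unwrap, presumably because `realized_value` now maps target names straight to tensors and never to one-element lists. Illustration (hypothetical target name):

```python
import torch

realized_value = {"model.layers.0.mlp.w1.weight": torch.zeros(2, 2)}
for target_name, param in realized_value.items():
    # old code: param = param[0] if isinstance(param, list) else param
    assert isinstance(param, torch.Tensor)
```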