diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 1c904f6adec3..d224eb96174a 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -755,7 +755,7 @@ def save_pretrained( output_config_file = os.path.join(save_directory, config_file_name) - self.to_json_file(output_config_file, use_diff=True) + self.to_json_file(output_config_file, use_diff=True, keys_to_pop=["compile_config"]) logger.info(f"Configuration saved in {output_config_file}") if push_to_hub: @@ -1022,8 +1022,6 @@ def to_dict(self) -> dict[str, Any]: del output["_commit_hash"] if "_original_object_hash" in output: del output["_original_object_hash"] - if "compile_config" in output: - del output["compile_config"] # Transformers version when serializing this file output["transformers_version"] = __version__ @@ -1031,7 +1029,9 @@ def to_dict(self) -> dict[str, Any]: self.dict_dtype_to_str(output) return output - def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str: + def to_json_string( + self, use_diff: bool = True, ignore_metadata: bool = False, keys_to_pop: list[str] | None = None + ) -> str: """ Serializes this instance to a JSON string. @@ -1041,6 +1041,8 @@ def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) - is serialized to JSON string. ignore_metadata (`bool`, *optional*, defaults to `False`): Whether to ignore the metadata fields present in the instance + keys_to_pop (`list[str]`, *optional*): + Keys to pop from the config dictionary before serializing Returns: `str`: String containing all the attributes that make up this configuration instance in JSON format. @@ -1050,6 +1052,10 @@ def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) - else: config_dict = self.to_dict() + if keys_to_pop is not None: + for key in keys_to_pop: + config_dict.pop(key, None) + if ignore_metadata: for metadata_field in METADATA_FIELDS: config_dict.pop(metadata_field, None) @@ -1075,7 +1081,9 @@ def convert_dataclass_to_dict(obj): return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True): + def to_json_file( + self, json_file_path: str | os.PathLike, use_diff: bool = True, keys_to_pop: list[str] | None = None + ) -> None: """ Save this instance to a JSON file. @@ -1085,9 +1093,11 @@ def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True) use_diff (`bool`, *optional*, defaults to `True`): If set to `True`, only the difference between the config instance and the default `GenerationConfig()` is serialized to JSON file. + keys_to_pop (`list[str]`, *optional*): + Keys to pop from the config dictionary before serializing """ with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string(use_diff=use_diff)) + writer.write(self.to_json_string(use_diff=use_diff, keys_to_pop=keys_to_pop)) @classmethod def from_model_config(cls, model_config: PreTrainedConfig | dict) -> "GenerationConfig":