From 1a4ed34c93956ad01221cfc264e23a54f941e789 Mon Sep 17 00:00:00 2001 From: Juncheng Fang Date: Tue, 23 Dec 2025 11:30:39 +0800 Subject: [PATCH] fix(cargo_provider): support workspace virtual manifests --- commitizen/providers/cargo_provider.py | 157 ++++++++++++++----------- 1 file changed, 87 insertions(+), 70 deletions(-) diff --git a/commitizen/providers/cargo_provider.py b/commitizen/providers/cargo_provider.py index ca00f05e7..a4e18e3e6 100644 --- a/commitizen/providers/cargo_provider.py +++ b/commitizen/providers/cargo_provider.py @@ -1,41 +1,39 @@ from __future__ import annotations -import fnmatch -import glob +import fnmatch, glob from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable, Any, cast from tomlkit import TOMLDocument, dumps, parse from tomlkit.exceptions import NonExistentKey - from commitizen.providers.base_provider import TomlProvider if TYPE_CHECKING: from tomlkit.items import AoT -class CargoProvider(TomlProvider): - """ - Cargo version management +DictLike = dict[str, Any] - With support for `workspaces` - """ + +class CargoProvider(TomlProvider): + """Cargo version management for virtual workspace manifests + version.workspace=true members.""" filename = "Cargo.toml" lock_filename = "Cargo.lock" @property def lock_file(self) -> Path: - return Path() / self.lock_filename + return Path(self.lock_filename) def get(self, document: TOMLDocument) -> str: - out = _try_get_workspace(document)["package"]["version"] - if TYPE_CHECKING: - assert isinstance(out, str) - return out + t = _root_version_table(document) + v = t.get("version") + if not isinstance(v, str): + raise TypeError("expected root version to be a string") + return v def set(self, document: TOMLDocument, version: str) -> None: - _try_get_workspace(document)["package"]["version"] = version + _root_version_table(document)["version"] = version def set_version(self, version: str) -> None: super().set_version(version) @@ -43,63 +41,82 @@ def set_version(self, version: str) -> None: self.set_lock_version(version) def set_lock_version(self, version: str) -> None: - cargo_toml_content = parse(self.file.read_text()) - cargo_lock_content = parse(self.lock_file.read_text()) - packages = cargo_lock_content["package"] - + cargo_toml = parse(self.file.read_text()) + cargo_lock = parse(self.lock_file.read_text()) + packages = cargo_lock["package"] if TYPE_CHECKING: assert isinstance(packages, AoT) - try: - cargo_package_name = cargo_toml_content["package"]["name"] # type: ignore[index] - if TYPE_CHECKING: - assert isinstance(cargo_package_name, str) - for i, package in enumerate(packages): - if package["name"] == cargo_package_name: - cargo_lock_content["package"][i]["version"] = version # type: ignore[index] - break - except NonExistentKey: - workspace = cargo_toml_content.get("workspace", {}) - if TYPE_CHECKING: - assert isinstance(workspace, dict) - workspace_members = workspace.get("members", []) - excluded_workspace_members = workspace.get("exclude", []) - members_inheriting: list[str] = [] - - for member in workspace_members: - for path in glob.glob(member, recursive=True): - if any( - fnmatch.fnmatch(path, pattern) - for pattern in excluded_workspace_members - ): - continue - - cargo_file = Path(path) / "Cargo.toml" - package_content = parse(cargo_file.read_text()).get("package", {}) - if TYPE_CHECKING: - assert isinstance(package_content, dict) - try: - version_workspace = package_content["version"]["workspace"] - if version_workspace is True: - package_name = package_content["name"] - if TYPE_CHECKING: - assert isinstance(package_name, str) - members_inheriting.append(package_name) - except NonExistentKey: - pass - - for i, package in enumerate(packages): - if package["name"] in members_inheriting: - cargo_lock_content["package"][i]["version"] = version # type: ignore[index] - - self.lock_file.write_text(dumps(cargo_lock_content)) - - -def _try_get_workspace(document: TOMLDocument) -> dict: + root_pkg = _table_get(cargo_toml, "package") + if root_pkg is not None: + name = root_pkg.get("name") + if isinstance(name, str): + _lock_set_versions(packages, {name}, version) + self.lock_file.write_text(dumps(cargo_lock)) + return + + ws = _table_get(cargo_toml, "workspace") or {} + members = cast(list[str], ws.get("members", []) or []) + excludes = cast(list[str], ws.get("exclude", []) or []) + inheriting = _workspace_inheriting_member_names(members, excludes) + _lock_set_versions(packages, inheriting, version) + self.lock_file.write_text(dumps(cargo_lock)) + + +def _table_get(doc: TOMLDocument, key: str) -> DictLike | None: + """Return a dict-like table for `key` if present, else None (type-safe for Pylance).""" try: - workspace = document["workspace"] - if TYPE_CHECKING: - assert isinstance(workspace, dict) - return workspace + v = doc[key] # tomlkit returns Container/Table-like; typing is loose except NonExistentKey: - return document + return None + return cast(DictLike, v) if hasattr(v, "get") else None + + +def _root_version_table(doc: TOMLDocument) -> DictLike: + """Prefer [workspace.package]; fallback to [package].""" + ws = _table_get(doc, "workspace") + if ws is not None: + pkg = ws.get("package") + if hasattr(pkg, "get"): + return cast(DictLike, pkg) + pkg = _table_get(doc, "package") + if pkg is None: + raise NonExistentKey('expected either [workspace.package] or [package]') + return pkg + + +def _is_workspace_inherited_version(v: Any) -> bool: + return hasattr(v, "get") and cast(DictLike, v).get("workspace") is True + + +def _iter_member_dirs(members: Iterable[str], excludes: Iterable[str]) -> Iterable[Path]: + for pat in members: + for p in glob.glob(pat, recursive=True): + if any(fnmatch.fnmatch(p, ex) for ex in excludes): + continue + yield Path(p) + + +def _workspace_inheriting_member_names(members: Iterable[str], excludes: Iterable[str]) -> set[str]: + out: set[str] = set() + for d in _iter_member_dirs(members, excludes): + cargo_file = d / "Cargo.toml" + if not cargo_file.exists(): + continue + pkg = parse(cargo_file.read_text()).get("package") + if not hasattr(pkg, "get"): + continue + pkgd = cast(DictLike, pkg) + if _is_workspace_inherited_version(pkgd.get("version")): + name = pkgd.get("name") + if isinstance(name, str): + out.add(name) + return out + + +def _lock_set_versions(packages: Any, names: set[str], version: str) -> None: + if not names: + return + for i, p in enumerate(packages): + if getattr(p, "get", None) and p.get("name") in names: + packages[i]["version"] = version # type: ignore[index]