Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 65 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ sync_pyproject_version = "pyproject_version_sync.sync_pyproject_version:main"

[tool.poetry.dependencies]
python = "^3.10"
tomli = "~2.0.1"
tomlkit = "^0.14.0"


[tool.poetry.group.dev.dependencies]
Expand Down
117 changes: 88 additions & 29 deletions pyproject_version_sync/sync_pyproject_version.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
"""This pre-commit hook ensures that the version in pyproject.toml matches the latest git tag."""

import argparse
import re
import subprocess
import sys
from pathlib import Path
from typing import TypeVar

import tomlkit
from tomlkit.items import Table
from tomlkit.toml_document import TOMLDocument

import tomli
PATHS = ["tool.poetry.version", "project.version"]
T = TypeVar("T")


def _execute_in_shell(cmd: str) -> subprocess.CompletedProcess:
def _execute_in_shell(cmd: str) -> subprocess.CompletedProcess[bytes]:
return subprocess.run(cmd.split(), check=True, capture_output=True) # noqa: S603


Expand Down Expand Up @@ -44,45 +51,92 @@ def find_latest_tag() -> str:
return re.findall(r"^v?(\d+\.\d+\.\d+).*", latest_tag)[0]


def find_version_in_toml(toml_file: Path) -> str:
def extract_prefix_and_tail(path: str) -> tuple[list[str], str]:
"""
Find the project version in pyproject.toml.
Split a given dotted path into prefix and the last element.

Args:
toml_file: Path to pyproject.toml.
path: Dotted path, like "a.b.c".

Returns:
Project version.
Tuple of prefix and the last element.
"""
parts = path.split(".")
return parts[:-1], parts[-1]


def traverse(toml: TOMLDocument, path: str, cls: type[T]) -> T | None:
"""
with Path.open(toml_file, "rb") as f:
pyproject = tomli.load(f)
Traverse given toml by a given dotted path and verify found type.

return pyproject["tool"]["poetry"]["version"]
Args:
toml: Toml documet.
path: Dotted path, like "a.b.c".
cls: Expected class.

Returns:
Object at a given path or None if not found.
"""
prefix, tail = extract_prefix_and_tail(path)

root: Table | TOMLDocument = toml
for part in prefix:
next_root = root.get(part)
if not isinstance(next_root, Table):
return None
root = next_root

def write_new_version_to_toml(toml_file: Path, version_pyproject: str, version_git: str) -> None:
result = root.get(tail)
if not isinstance(result, cls):
return None
return result


def traverse_set(toml: TOMLDocument, path: str, value: object) -> bool:
"""
Write the new version to pyproject.toml.
Traverse given toml by a given dotted path and set the value, overwrite if it exists already.

Args:
toml_file: Path to pyproject.toml.
version_pyproject: Version in pyproject.toml.
version_git: Latest git tag.
toml: Toml documet.
path: Dotted path, like "a.b.c".
value: string or int.

Returns:
Success status.
"""
pyproject_raw = Path.open(toml_file).read()
prefix, tail = extract_prefix_and_tail(path)

root: Table | TOMLDocument = toml
for part in prefix:
next_root = root.setdefault(part, {})
if not isinstance(next_root, Table):
return False
root = next_root

root[tail] = value
return True


def find_version_in_toml(pyproject: TOMLDocument) -> tuple[str, str]:
"""
Find the project version in pyproject.toml.

# Ignore all the stuff after the block of interest
# Avoids edge case where we may overwrite the version of something else in the file by mistake
block = re.findall(
rf"\[tool\.poetry\][^\n]*.*\nversion\s?=\s?[\"|\']{re.escape(version_pyproject)}[\"|\']\n",
pyproject_raw,
flags=re.DOTALL,
)[0]
new_block = block.replace(version_pyproject, version_git)
Args:
pyproject: Parsed pyproject.toml.

pyproject_new = pyproject_raw.replace(block, new_block)
with Path.open(toml_file, "w") as f_out:
f_out.write(pyproject_new)
Returns:
Project version.
"""
# If user has invalid values it will error out
versions: list[tuple[str, str]] = []
for path in PATHS:
version = traverse(pyproject, path, str)
if version:
versions.append((path, version))

if len(versions) != 1:
sys.exit(f"Expected exactly one version in pyproject.toml, got: {versions}")
return versions[0]


def parse_args() -> argparse.Namespace:
Expand All @@ -109,9 +163,11 @@ def main() -> None:
"""Run the pre-commit hook."""
args = parse_args()
fix = args.fix
toml_file = args.toml_file

version_pyproject = find_version_in_toml(toml_file)
path = Path(args.toml_file)
pyproject = tomlkit.parse(path.read_bytes())

version_path, version_pyproject = find_version_in_toml(pyproject)

if version_pyproject != (version_git := find_latest_tag()):
if not fix:
Expand All @@ -120,7 +176,10 @@ def main() -> None:
f"Run with the `--fix` option to automatically sync.",
)

write_new_version_to_toml(toml_file, version_pyproject, version_git)
# This path should exist
assert traverse_set(pyproject, version_path, version_git) # noqa: S101
tomlkit.dump(pyproject, path.open("w"))

sys.exit("Syncing version in pyproject.toml to match latest git tag.")

sys.exit()
Expand Down