diff --git a/commitizen/providers/cargo_provider.py b/commitizen/providers/cargo_provider.py index ca00f05e7..02ea41095 100644 --- a/commitizen/providers/cargo_provider.py +++ b/commitizen/providers/cargo_provider.py @@ -3,7 +3,7 @@ import fnmatch import glob from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast from tomlkit import TOMLDocument, dumps, parse from tomlkit.exceptions import NonExistentKey @@ -11,31 +11,33 @@ from commitizen.providers.base_provider import TomlProvider if TYPE_CHECKING: + from collections.abc import Iterable + from tomlkit.items import AoT -class CargoProvider(TomlProvider): - """ - Cargo version management +DictLike = dict[str, Any] - With support for `workspaces` - """ + +class CargoProvider(TomlProvider): + """Cargo version management for virtual workspace manifests + version.workspace=true members.""" filename = "Cargo.toml" lock_filename = "Cargo.lock" @property def lock_file(self) -> Path: - return Path() / self.lock_filename + return Path(self.lock_filename) def get(self, document: TOMLDocument) -> str: - out = _try_get_workspace(document)["package"]["version"] - if TYPE_CHECKING: - assert isinstance(out, str) - return out + t = _root_version_table(document) + v = t.get("version") + if not isinstance(v, str): + raise TypeError("expected root version to be a string") + return v def set(self, document: TOMLDocument, version: str) -> None: - _try_get_workspace(document)["package"]["version"] = version + _root_version_table(document)["version"] = version def set_version(self, version: str) -> None: super().set_version(version) @@ -50,56 +52,134 @@ def set_lock_version(self, version: str) -> None: if TYPE_CHECKING: assert isinstance(packages, AoT) - try: - cargo_package_name = cargo_toml_content["package"]["name"] # type: ignore[index] - if TYPE_CHECKING: - assert isinstance(cargo_package_name, str) - for i, package in enumerate(packages): - if package["name"] == cargo_package_name: - cargo_lock_content["package"][i]["version"] = version # type: ignore[index] - break - except NonExistentKey: - workspace = cargo_toml_content.get("workspace", {}) - if TYPE_CHECKING: - assert isinstance(workspace, dict) - workspace_members = workspace.get("members", []) - excluded_workspace_members = workspace.get("exclude", []) - members_inheriting: list[str] = [] - - for member in workspace_members: - for path in glob.glob(member, recursive=True): - if any( - fnmatch.fnmatch(path, pattern) - for pattern in excluded_workspace_members - ): - continue - - cargo_file = Path(path) / "Cargo.toml" - package_content = parse(cargo_file.read_text()).get("package", {}) - if TYPE_CHECKING: - assert isinstance(package_content, dict) - try: - version_workspace = package_content["version"]["workspace"] - if version_workspace is True: - package_name = package_content["name"] - if TYPE_CHECKING: - assert isinstance(package_name, str) - members_inheriting.append(package_name) - except NonExistentKey: - pass - - for i, package in enumerate(packages): - if package["name"] in members_inheriting: - cargo_lock_content["package"][i]["version"] = version # type: ignore[index] - + root_pkg = _table_get(cargo_toml_content, "package") + if root_pkg is not None: + name = root_pkg.get("name") + if isinstance(name, str): + _lock_set_versions(packages, {name}, version) + self.lock_file.write_text(dumps(cargo_lock_content)) + return + + ws = _table_get(cargo_toml_content, "workspace") or {} + member_globs = cast("list[str]", ws.get("members", []) or []) + exclude_globs = cast("list[str]", ws.get("exclude", []) or []) + inheriting = _workspace_inheriting_member_names(member_globs, exclude_globs) + _lock_set_versions(packages, inheriting, version) self.lock_file.write_text(dumps(cargo_lock_content)) -def _try_get_workspace(document: TOMLDocument) -> dict: +def _table_get(doc: TOMLDocument, key: str) -> DictLike | None: + """Get a TOML table by key as a dict-like object. + + Returns: + The value at `doc[key]` cast to a dict-like table (supports `.get`) if it + exists and is table/container-like; otherwise returns None. + + Rationale: + tomlkit returns loosely-typed Container/Table objects; using a small + helper keeps call sites readable and makes type-checkers happier. + """ try: - workspace = document["workspace"] - if TYPE_CHECKING: - assert isinstance(workspace, dict) - return workspace + value = doc[key] except NonExistentKey: - return document + return None + return cast("DictLike", value) if hasattr(value, "get") else None + + +def _root_version_table(doc: TOMLDocument) -> DictLike: + """Return the table that owns the "root" version field. + + This provider supports two layouts: + + 1) Workspace virtual manifests: + [workspace.package] + version = "x.y.z" + + 2) Regular crate(non-workspace root manifest): + [package] + version = "x.y.z" + + The selected table is where `get()` reads from and `set()` writes to. + """ + workspace_table = _table_get(doc, "workspace") + if workspace_table is not None: + workspace_package_table = workspace_table.get("package") + if hasattr(workspace_package_table, "get"): + return cast("DictLike", workspace_package_table) + + package_table = _table_get(doc, "package") + if package_table is None: + raise NonExistentKey("expected either [workspace.package] or [package]") + return package_table + + +def _is_workspace_inherited_version(v: Any) -> bool: + return hasattr(v, "get") and cast("DictLike", v).get("workspace") is True + + +def _iter_member_dirs( + member_globs: Iterable[str], exclude_globs: Iterable[str] +) -> Iterable[Path]: + """Yield workspace member directories matched by `member_globs`, excluding `exclude_globs`. + + Cargo workspaces define members/exclude as glob patterns (e.g. "crates/*"). + This helper expands those patterns and yields the corresponding directories + as `Path` objects, skipping any matches that satisfy an exclude glob. + + Kept as a helper to make call sites read as domain logic ("iterate member dirs") + rather than glob/filter plumbing. + """ + for member_glob in member_globs: + for match in glob.glob(member_glob, recursive=True): + if any(fnmatch.fnmatch(match, ex) for ex in exclude_globs): + continue + yield Path(match) + + +def _workspace_inheriting_member_names( + members: Iterable[str], excludes: Iterable[str] +) -> set[str]: + """Return workspace member crate names that inherit the workspace version. + + A member is considered "inheriting" when its Cargo.toml has: + [package] + version.workspace = true + + This scans `members` globs (respecting `excludes`) and returns the set of + `[package].name` values for matching crates. Missing/invalid Cargo.toml files + are ignored. + """ + inheriting_member_names: set[str] = set() + for d in _iter_member_dirs(members, excludes): + cargo_file = d / "Cargo.toml" + if not cargo_file.exists(): + continue + pkg = parse(cargo_file.read_text()).get("package") + if not hasattr(pkg, "get"): + continue + pkgd = cast("DictLike", pkg) + if _is_workspace_inherited_version(pkgd.get("version")): + name = pkgd.get("name") + if isinstance(name, str): + inheriting_member_names.add(name) + return inheriting_member_names + + +def _lock_set_versions(packages: Any, package_names: set[str], version: str) -> None: + """Update Cargo.lock package entries in-place. + + Args: + packages: `Cargo.lock` parsed TOML "package" array (AoT-like). Mutated in-place. + package_names: Set of package names whose `version` field should be updated. + version: New version string to write. + + Notes: + We use `enumerate` + index assignment because tomlkit AoT entries may be + Container-like and direct mutation patterns vary; indexed assignment is + reliable for updating the underlying document. + """ + if not package_names: + return + for i, pkg_entry in enumerate(packages): + if getattr(pkg_entry, "get", None) and pkg_entry.get("name") in package_names: + packages[i]["version"] = version diff --git a/tests/providers/test_cargo_provider.py b/tests/providers/test_cargo_provider.py index ea15fdbf3..ba30270f4 100644 --- a/tests/providers/test_cargo_provider.py +++ b/tests/providers/test_cargo_provider.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING import pytest +from tomlkit.exceptions import NonExistentKey from commitizen.providers import get_provider from commitizen.providers.cargo_provider import CargoProvider @@ -451,3 +452,145 @@ def test_cargo_provider_workspace_member_without_workspace_key( assert file.read_text() == dedent(expected_workspace_toml) # The lock file should remain unchanged since the member doesn't inherit workspace version assert lock_file.read_text() == dedent(expected_lock_content) + + +def test_cargo_provider_get_raises_when_version_is_not_string( + config: BaseConfig, + chdir: Path, +) -> None: + """CargoProvider.get should raise when root version is not a string.""" + (chdir / CargoProvider.filename).write_text( + dedent( + """\ + [package] + name = "whatever" + version = 1 + """ + ) + ) + config.settings["version_provider"] = "cargo" + + provider = get_provider(config) + assert isinstance(provider, CargoProvider) + + with pytest.raises(TypeError, match=r"expected root version to be a string"): + provider.get_version() + + +def test_cargo_provider_get_raises_when_no_package_tables( + config: BaseConfig, + chdir: Path, +) -> None: + """_root_version_table should raise when neither [workspace.package] nor [package] exists.""" + (chdir / CargoProvider.filename).write_text( + dedent( + """\ + [workspace] + members = [] + """ + ) + ) + config.settings["version_provider"] = "cargo" + + provider = get_provider(config) + assert isinstance(provider, CargoProvider) + + with pytest.raises( + NonExistentKey, match=r"expected either \[workspace\.package\] or \[package\]" + ): + provider.get_version() + + +def test_workspace_member_dir_without_cargo_toml_is_ignored( + config: BaseConfig, + chdir: Path, +) -> None: + """Cover: if not cargo_file.exists(): continue""" + workspace_toml = """\ + [workspace] + members = ["missing_manifest"] + + [workspace.package] + version = "0.1.0" + """ + lock_content = """\ + [[package]] + name = "missing_manifest" + version = "0.1.0" + source = "registry+https://github.com/rust-lang/crates.io-index" + checksum = "123abc" + """ + expected_workspace_toml = """\ + [workspace] + members = ["missing_manifest"] + + [workspace.package] + version = "42.1" + """ + + (chdir / CargoProvider.filename).write_text(dedent(workspace_toml)) + os.mkdir(chdir / "missing_manifest") # directory exists, but Cargo.toml does NOT + + (chdir / CargoProvider.lock_filename).write_text(dedent(lock_content)) + + config.settings["version_provider"] = "cargo" + provider = get_provider(config) + assert isinstance(provider, CargoProvider) + + provider.set_version("42.1") + assert (chdir / CargoProvider.filename).read_text() == dedent( + expected_workspace_toml + ) + # lock should remain unchanged since member cannot be inspected => not inheriting + assert (chdir / CargoProvider.lock_filename).read_text() == dedent(lock_content) + + +def test_workspace_member_with_non_table_package_is_ignored( + config: BaseConfig, + chdir: Path, +) -> None: + """Cover: if not hasattr(pkg, "get"): continue""" + workspace_toml = """\ + [workspace] + members = ["bad_package"] + + [workspace.package] + version = "0.1.0" + """ + member_toml = """\ + package = "oops" + """ + lock_content = """\ + [[package]] + name = "bad_package" + version = "0.1.0" + source = "registry+https://github.com/rust-lang/crates.io-index" + checksum = "123abc" + """ + expected_workspace_toml = """\ + [workspace] + members = ["bad_package"] + + [workspace.package] + version = "42.1" + """ + + (chdir / CargoProvider.filename).write_text(dedent(workspace_toml)) + + os.mkdir(chdir / "bad_package") + (chdir / "bad_package" / "Cargo.toml").write_text( + dedent(member_toml) + ) # package is str, not table + + (chdir / CargoProvider.lock_filename).write_text(dedent(lock_content)) + + config.settings["version_provider"] = "cargo" + provider = get_provider(config) + assert isinstance(provider, CargoProvider) + + provider.set_version("42.1") + assert (chdir / CargoProvider.filename).read_text() == dedent( + expected_workspace_toml + ) + # lock should remain unchanged since package is not a table => not inheriting + assert (chdir / CargoProvider.lock_filename).read_text() == dedent(lock_content)