diff --git a/pyproject.toml b/pyproject.toml index 8653807cc..700cd9e7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,11 +52,11 @@ Issues = "https://github.com/ukaea/process/issues" [project.optional-dependencies] test = [ "pytest>=5.4.1", - "requests>=2.30", "testbook>=0.4", "pytest-cov>=3.0.0", "pytest-xdist>=2.5.0", "platformdirs~=4.5.0", + "filelock", "process[examples]" ] examples = ["jupyter==1.0.0", "jupytext"] diff --git a/tests/regression/regression_test_assets.py b/tests/regression/regression_test_assets.py index 90ea110c9..3920fe490 100644 --- a/tests/regression/regression_test_assets.py +++ b/tests/regression/regression_test_assets.py @@ -1,13 +1,13 @@ """Provides the classes to find, download, and access tracked MFiles on -a remote data repository +a remote data repository. """ import dataclasses import logging import re import subprocess +from pathlib import Path -import requests from platformdirs import user_cache_path logger = logging.getLogger(__name__) @@ -19,17 +19,36 @@ class TrackedMFile: hash: str scenario_name: str - download_link: str + location: Path class RegressionTestAssetCollector: - remote_repository_owner = "timothy-nunn" - remote_repository_repo = "process-tracking-data" - - def __init__(self): + def __init__(self, cache_location: Path | None = None): + self._cache_location = cache_location or TEST_ASSET_CACHE_DIR self._hashes = self._git_commit_hashes() + self._repo_dir = self._get_regression_assets() self._tracked_mfiles = self._get_tracked_mfiles() + def _get_regression_assets(self): + """Ensures the user has an up-to-date local copy of the regression references by cloning/pulling + the remote repository to a local cache. + """ + repo_dir = TEST_ASSET_CACHE_DIR / "process-tracking-data" + if not repo_dir.exists(): + repo_dir.mkdir(parents=True) + + subprocess.run( + f"git clone https://github.com/timothy-nunn/process-tracking-data.git '{repo_dir.as_posix()}'", + shell=True, + check=True, + ) + else: + subprocess.run( + "git pull", shell=True, check=True, cwd=repo_dir, capture_output=True + ) + + return repo_dir + def get_reference_mfile(self, scenario_name: str, target_hash: str | None = None): """Finds the most recent reference MFile for `.IN.DAT` and downloads it to the `directory` with the name `ref..MFILE.DAT`. @@ -55,24 +74,7 @@ def get_reference_mfile(self, scenario_name: str, target_hash: str | None = None if (mf.scenario_name == scenario_name and target_hash is None) or ( mf.scenario_name == scenario_name and target_hash == mf.hash ): - cache_directory = TEST_ASSET_CACHE_DIR / mf.hash - cached_location = cache_directory / f"ref.{scenario_name}.MFILE.DAT" - - if cached_location.exists(): - logger.info( - f"Using cached reference MFile ({cached_location}) found for commit {mf.hash}." - ) - return cached_location - - cache_directory.mkdir(parents=True, exist_ok=True) - cached_location.write_text( - requests.get(mf.download_link).content.decode() - ) - - logger.info( - f"Reference MFile found for commit {mf.hash}. Writing to {cached_location}" - ) - return cached_location + return mf.location return None @@ -94,27 +96,20 @@ def _git_commit_hashes(self): ) def _get_tracked_mfiles(self): - """Gets a list of tracked MFiles from the remote repository. + """Gets a list of tracked MFiles. :returns: a list of tracked MFiles sorted to match the order of hashes returned from `_git_commit_hashes`. :rtype: list[TrackedMFile] """ - repository_files_request = requests.get( - f"https://api.github.com/repos/" - f"{self.remote_repository_owner}/{self.remote_repository_repo}/git/trees/master" - ) - repository_files_request.raise_for_status() - repository_files = repository_files_request.json()["tree"] - # create a list of tracked MFiles from the list of all files - # in the remote repository. + # in the repository. # Only keep TrackedMFiles that are tracked for a commit on the # current branch. This stops issues arising from main being # ahead of the feature branch and having newer tracks. tracked_mfiles = [ mfile - for f in repository_files + for f in self._repo_dir.glob("*.DAT") if (mfile := self._get_tracked_mfile(f)) is not None and mfile.hash in self._hashes ] @@ -125,7 +120,7 @@ def _get_tracked_mfiles(self): key=lambda m: self._hashes.index(m.hash), ) - def _get_tracked_mfile(self, json_data): + def _get_tracked_mfile(self, file: Path): """Converts JSON data of a file tracked on GitHub into a `TrackedMFile`, if appropriate @@ -137,13 +132,10 @@ def _get_tracked_mfile(self, json_data): tracked mfile. :rtype: TrackedMFile | None """ - rematch = re.match(r"([a-zA-Z0-9_.]+)_MFILE_([a-z0-9]+).DAT", json_data["path"]) + rematch = re.match(r"([a-zA-Z0-9_.]+)_MFILE_([a-z0-9]+).DAT", file.name) if rematch is None: return None return TrackedMFile( - hash=rematch.group(2), - scenario_name=rematch.group(1), - download_link=f"https://raw.githubusercontent.com/" - f"{self.remote_repository_owner}/{self.remote_repository_repo}/master/{json_data['path']}", + hash=rematch.group(2), scenario_name=rematch.group(1), location=file ) diff --git a/tests/regression/test_process_input_files.py b/tests/regression/test_process_input_files.py index 416375984..80d9f2cd8 100644 --- a/tests/regression/test_process_input_files.py +++ b/tests/regression/test_process_input_files.py @@ -13,6 +13,7 @@ from pathlib import Path import pytest +from filelock import FileLock from regression_test_assets import RegressionTestAssetCollector from process.core.io.mfile import MFile @@ -222,15 +223,21 @@ def mfile_value_changes( return diffs -@pytest.fixture(scope="session") -def tracked_regression_test_assets(): +@pytest.fixture(scope="module") +def tracked_regression_test_assets(tmp_path_factory, worker_id): """Session fixture providing a RegressionTestAssetCollector - for finding remote tracked MFiles. + for finding tracked MFiles. - This fixture creates one asset collector that is shared - between all regression tests and reduces the number of - API calls made to the remote repository.""" - return RegressionTestAssetCollector() + When running using pytest-xdist this fixture stops multiple workers operating on + the asset directory at once using a file lock. + """ + if worker_id == "master": + return RegressionTestAssetCollector() + + tmpdir = tmp_path_factory.getbasetemp().parent + + with FileLock(tmpdir / "regression_tests.lock"): + return RegressionTestAssetCollector() @pytest.mark.parametrize(