Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ Issues = "https://github.com/ukaea/process/issues"
[project.optional-dependencies]
test = [
"pytest>=5.4.1",
"requests>=2.30",
"testbook>=0.4",
"pytest-cov>=3.0.0",
"pytest-xdist>=2.5.0",
"platformdirs~=4.5.0",
"filelock",
"process[examples]"
]
examples = ["jupyter==1.0.0", "jupytext"]
Expand Down
74 changes: 33 additions & 41 deletions tests/regression/regression_test_assets.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Provides the classes to find, download, and access tracked MFiles on
a remote data repository
a remote data repository.
"""

import dataclasses
import logging
import re
import subprocess
from pathlib import Path

import requests
from platformdirs import user_cache_path

logger = logging.getLogger(__name__)
Expand All @@ -19,17 +19,36 @@
class TrackedMFile:
hash: str
scenario_name: str
download_link: str
location: Path


class RegressionTestAssetCollector:
remote_repository_owner = "timothy-nunn"
remote_repository_repo = "process-tracking-data"

def __init__(self):
def __init__(self, cache_location: Path | None = None):
self._cache_location = cache_location or TEST_ASSET_CACHE_DIR
self._hashes = self._git_commit_hashes()
self._repo_dir = self._get_regression_assets()
self._tracked_mfiles = self._get_tracked_mfiles()

def _get_regression_assets(self):
    """Ensures the user has an up-to-date local copy of the regression
    references by cloning/pulling the remote repository to a local cache.

    The clone lives in a directory named after the remote repository inside
    the cache location this collector was constructed with.

    :returns: the path to the local clone of the tracking data repository.
    :rtype: Path
    :raises subprocess.CalledProcessError: if `git clone` or `git pull`
        exits with a non-zero status.
    """
    # Honour the cache location given to the constructor rather than always
    # using the module-level default.
    repo_dir = self._cache_location / self.remote_repository_repo
    remote_url = (
        "https://github.com/"
        f"{self.remote_repository_owner}/{self.remote_repository_repo}.git"
    )

    # Test for the `.git` directory rather than the directory itself: a
    # previously failed clone can leave an empty directory behind, in which
    # `git pull` would either fail or act on an enclosing repository.
    if (repo_dir / ".git").exists():
        subprocess.run(
            ["git", "pull"], check=True, cwd=repo_dir, capture_output=True
        )
    else:
        repo_dir.mkdir(parents=True, exist_ok=True)
        # Argument list (no shell) so the path needs no quoting and behaves
        # the same on every platform.
        subprocess.run(["git", "clone", remote_url, str(repo_dir)], check=True)

    return repo_dir

def get_reference_mfile(self, scenario_name: str, target_hash: str | None = None):
"""Finds the most recent reference MFile for `<scenario_name>.IN.DAT`
and downloads it to the `directory` with the name `ref.<scenario_name>.MFILE.DAT`.
Expand All @@ -55,24 +74,7 @@ def get_reference_mfile(self, scenario_name: str, target_hash: str | None = None
if (mf.scenario_name == scenario_name and target_hash is None) or (
mf.scenario_name == scenario_name and target_hash == mf.hash
):
cache_directory = TEST_ASSET_CACHE_DIR / mf.hash
cached_location = cache_directory / f"ref.{scenario_name}.MFILE.DAT"

if cached_location.exists():
logger.info(
f"Using cached reference MFile ({cached_location}) found for commit {mf.hash}."
)
return cached_location

cache_directory.mkdir(parents=True, exist_ok=True)
cached_location.write_text(
requests.get(mf.download_link).content.decode()
)

logger.info(
f"Reference MFile found for commit {mf.hash}. Writing to {cached_location}"
)
return cached_location
return mf.location

return None

Expand All @@ -94,27 +96,20 @@ def _git_commit_hashes(self):
)

def _get_tracked_mfiles(self):
"""Gets a list of tracked MFiles from the remote repository.
"""Gets a list of tracked MFiles.

:returns: a list of tracked MFiles sorted to match the order of
hashes returned from `_git_commit_hashes`.
:rtype: list[TrackedMFile]
"""
repository_files_request = requests.get(
f"https://api.github.com/repos/"
f"{self.remote_repository_owner}/{self.remote_repository_repo}/git/trees/master"
)
repository_files_request.raise_for_status()
repository_files = repository_files_request.json()["tree"]

# create a list of tracked MFiles from the list of all files
# in the remote repository.
# in the repository.
# Only keep TrackedMFiles that are tracked for a commit on the
# current branch. This stops issues arising from main being
# ahead of the feature branch and having newer tracks.
tracked_mfiles = [
mfile
for f in repository_files
for f in self._repo_dir.glob("*.DAT")
if (mfile := self._get_tracked_mfile(f)) is not None
and mfile.hash in self._hashes
]
Expand All @@ -125,7 +120,7 @@ def _get_tracked_mfiles(self):
key=lambda m: self._hashes.index(m.hash),
)

def _get_tracked_mfile(self, json_data):
def _get_tracked_mfile(self, file: Path):
"""Converts JSON data of a file tracked on GitHub into a
`TrackedMFile`, if appropriate

Expand All @@ -137,13 +132,10 @@ def _get_tracked_mfile(self, json_data):
tracked mfile.
:rtype: TrackedMFile | None
"""
rematch = re.match(r"([a-zA-Z0-9_.]+)_MFILE_([a-z0-9]+).DAT", json_data["path"])
rematch = re.match(r"([a-zA-Z0-9_.]+)_MFILE_([a-z0-9]+).DAT", file.name)

if rematch is None:
return None
return TrackedMFile(
hash=rematch.group(2),
scenario_name=rematch.group(1),
download_link=f"https://raw.githubusercontent.com/"
f"{self.remote_repository_owner}/{self.remote_repository_repo}/master/{json_data['path']}",
hash=rematch.group(2), scenario_name=rematch.group(1), location=file
)
21 changes: 14 additions & 7 deletions tests/regression/test_process_input_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pathlib import Path

import pytest
from filelock import FileLock
from regression_test_assets import RegressionTestAssetCollector

from process.core.io.mfile import MFile
Expand Down Expand Up @@ -222,15 +223,21 @@ def mfile_value_changes(
return diffs


@pytest.fixture(scope="session")
def tracked_regression_test_assets():
@pytest.fixture(scope="module")
def tracked_regression_test_assets(tmp_path_factory, worker_id):
"""Session fixture providing a RegressionTestAssetCollector
for finding remote tracked MFiles.
for finding tracked MFiles.

This fixture creates one asset collector that is shared
between all regression tests and reduces the number of
API calls made to the remote repository."""
return RegressionTestAssetCollector()
When running using pytest-xdist this fixture stops multiple workers operating on
the asset directory at once using a file lock.
"""
if worker_id == "master":
return RegressionTestAssetCollector()

tmpdir = tmp_path_factory.getbasetemp().parent

with FileLock(tmpdir / "regression_tests.lock"):
return RegressionTestAssetCollector()


@pytest.mark.parametrize(
Expand Down