Skip to content

Commit

Permalink
Fix everest fixture for caching testcase results
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-eq committed Feb 3, 2025
1 parent 050077c commit 783d311
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
16 changes: 11 additions & 5 deletions tests/everest/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,17 @@ def create_evaluator_server_config(run_model):


@pytest.fixture
def cached_example(pytestconfig, evaluator_server_config_generator):
def cached_example(pytestconfig: pytest.Config, evaluator_server_config_generator):
cache = pytestconfig.cache

def run_config(test_data_case: str):
if cache.get(f"cached_example:{test_data_case}", None) is None:
if (
cache.get(
f"cached_example:{test_data_case}:{os.environ.get('PYTEST_XDIST_WORKER', '')}",
None,
)
is None
):
my_tmpdir = Path(tempfile.mkdtemp())
config_path = (
Path(__file__) / f"../../../test-data/everest/{test_data_case}"
Expand All @@ -187,14 +193,14 @@ def run_config(test_data_case: str):
}

cache.set(
f"cached_example:{test_data_case}",
f"cached_example:{test_data_case}:{os.environ.get('PYTEST_XDIST_WORKER', '')}",
(str(result_path), config_file, optimal_result_json),
)

result_path, config_file, optimal_result_json = cache.get(
f"cached_example:{test_data_case}", (None, None, None)
f"cached_example:{test_data_case}:{os.environ.get('PYTEST_XDIST_WORKER', '')}",
(None, None, None),
)

copied_tmpdir = tempfile.mkdtemp()
shutil.copytree(result_path, Path(copied_tmpdir) / "everest")
copied_path = str(Path(copied_tmpdir) / "everest")
Expand Down
6 changes: 6 additions & 0 deletions tests/everest/test_api_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def test_api_snapshots(config_file, snapshot, cached_example):
.strip()
+ "\n"
)
snapshot.snapshot_dir = (
Path(str(snapshot.snapshot_dir).split(config_file)[0]) / config_file
)
snapshot.assert_match(snapshot_str, "snapshot.json")


Expand Down Expand Up @@ -119,6 +122,9 @@ def test_api_summary_snapshot(config_file, snapshot, cached_example):

api = EverestDataAPI(config)
dicts = api.summary_values().to_dicts()
snapshot.snapshot_dir = (
Path(str(snapshot.snapshot_dir).split(config_file)[0]) / config_file
)
snapshot.assert_match(
orjson.dumps(dicts, option=orjson.OPT_INDENT_2).decode("utf-8").strip() + "\n",
"snapshot.json",
Expand Down

0 comments on commit 783d311

Please sign in to comment.