Skip to content

Commit 051fd92

Browse files
committed
add jsonl log reloading for continued runs
1 parent 5d308d0 commit 051fd92

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

Diff for: evals/eval.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@
2727
_MAX_SAMPLES = None
2828

2929

30-
def _index_samples(samples: List[Any]) -> List[Tuple[Any, int]]:
30+
def _index_samples(samples: List[Any], completed_ids: List[int] = []) -> List[Tuple[Any, int]]:
3131
"""Shuffle `samples` and pair each sample with its index."""
3232
indices = list(range(len(samples)))
3333
random.Random(SHUFFLE_SEED).shuffle(indices)
3434
if _MAX_SAMPLES is not None:
3535
indices = indices[:_MAX_SAMPLES]
3636
logger.info(f"Evaluating {len(indices)} samples")
37-
work_items = [(samples[i], i) for i in indices]
37+
work_items = [(samples[i], i) for i in indices if not i in completed_ids]
3838
return work_items
3939

4040

@@ -120,7 +120,10 @@ def eval_all_samples(
120120
"""
121121
Evaluate all provided samples in parallel.
122122
"""
123-
work_items = _index_samples(samples)
123+
samples_completed_ids = [int(event.sample_id.split(".")[-1]) for event in recorder.get_events(type="sampling")]
124+
logger.info(f"Completed samples: {samples_completed_ids}")
125+
126+
work_items = _index_samples(samples, completed_ids=samples_completed_ids)
124127
threads = int(os.environ.get("EVALS_THREADS", "10"))
125128
show_progress = bool(os.environ.get("EVALS_SHOW_EVAL_PROGRESS", show_progress))
126129

@@ -207,7 +210,10 @@ def eval_all_samples(
207210
"""
208211
Evaluate all provided samples in parallel.
209212
"""
210-
work_items = _index_samples(samples)
213+
samples_completed_ids = [int(event.sample_id.split(".")[-1]) for event in recorder.get_events(type="sampling")]
214+
logger.info(f"Completed samples: {samples_completed_ids}")
215+
216+
work_items = _index_samples(samples, completed_ids=samples_completed_ids)
211217
threads = int(os.environ.get("EVALS_THREADS", "10"))
212218
show_progress = bool(os.environ.get("EVALS_SHOW_EVAL_PROGRESS", show_progress))
213219

Diff for: evals/record.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import atexit
1010
import contextlib
1111
import dataclasses
12+
import json
1213
import logging
14+
import os.path
1315
import threading
1416
import time
1517
from contextvars import ContextVar
@@ -340,8 +342,13 @@ def __init__(
340342
self.event_file_path = log_path
341343
self.hidden_data_fields = hidden_data_fields
342344
if log_path is not None:
343-
with bf.BlobFile(log_path, "wb") as f:
344-
f.write((jsondumps({"spec": dataclasses.asdict(run_spec)}) + "\n").encode("utf-8"))
345+
if not os.path.exists(log_path):
346+
with bf.BlobFile(log_path, "wb") as f:
347+
f.write((jsondumps({"spec": dataclasses.asdict(run_spec)}) + "\n").encode("utf-8"))
348+
else:
349+
lines = bf.BlobFile(log_path, "rb").readlines()
350+
run_spec.run_id = json.loads(lines[0])["spec"]["run_id"]
351+
self._events = [Event(**json.loads(line)) for line in lines[1:]]
345352

346353
def _flush_events_internal(self, events_to_write: Sequence[Event]):
347354
start = time.time()

0 commit comments

Comments
 (0)