From 14ead5e0a4ca423a60c09beb852972eef2311d88 Mon Sep 17 00:00:00 2001 From: benlipkin Date: Tue, 22 Oct 2024 20:36:48 -0400 Subject: [PATCH] fix: continue to compare passing samples to live beam in tree search --- TUTORIAL.md | 2 +- decoding/generators.py | 76 ++++++++++++++++-------------- examples/thm_proving_treesearch.py | 2 +- tests/test_generators.py | 45 +++++++++++------- 4 files changed, 70 insertions(+), 55 deletions(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 882ba88..4f506ba 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -213,7 +213,7 @@ def run(prompt: str) -> list[Sample[str]]: stop_cond_pass=stop_pass, sync_str=sync_str, n=10, # number of complete samples to search for - beam_width=20, # size of active beam to maintain + beam_width=25, # size of beam to maintain beam_factor=5, # number of ways to split each particle at each step seed=42, ) diff --git a/decoding/generators.py b/decoding/generators.py index e58ac71..34c8bd6 100644 --- a/decoding/generators.py +++ b/decoding/generators.py @@ -155,26 +155,26 @@ def TreeSearch( # noqa: PLR0913 step_scorer: The scorer to rank the samples at each sync step. final_scorer: The scorer to rank the final beam. stop_cond_pass: A function that returns `True` if the sample should pass. - This adds a sample to the final beam. + This stops the sample from being extended. stop_cond_fail: A function that returns `True` if the sample should fail. - This filters the sample from the active beam. + This filters the sample from the live beam. n: The number of passing samples to generate before returning. - beam_width: The width of the active beam. This is the number of samples to + beam_width: The width of the beam. This is the number of samples to keep at each step. - beam_factor: The branching factor of the active beam. This is the number of - new samples to generate from each active sample at each sync step. + beam_factor: The branching factor of the beam. This is the number of + new samples to generate from each live sample at each sync step. max_steps: The maximum number of sync steps to take. min_tokens_per_step: The minimum number of tokens in each step's extension. max_tokens_per_step: The maximum number of tokens in each step's extension. - sync_str: A string or list of strings that, if generated, will stop each sample - in the active beam and await the sync step scoring, ranking, and filtering. - sync_token_ids: A list of token IDs that, if generated, will stop each sample - in the active beam and await the sync step scoring, ranking, and filtering. + sync_str: A string or list of strings that, if generated, will stop extending + each sample in the live beam and await scoring, ranking, and filtering. + sync_token_ids: A list of token IDs that, if generated, will stop extending + each sample in the live beam and await scoring, ranking, and filtering. A string can also be passed, which will specify all token IDs that contain that substring. include_sync_str_in_output: Whether to include the stop string in the output. track_logprobs: Whether to track log probabilities. This comes at a performance - cost, so it is off by default. In most cases, as you are alrady sampling + cost, so it is off by default. In most cases, as you are already sampling from the model, you do not want to double count the probabilities in the scorer anyways. temperature: The temperature for sampling. @@ -187,7 +187,7 @@ def TreeSearch( # noqa: PLR0913 Raises: ValueError: If any of the argument configurations are invalid - RuntimeError: if all active samples in the beam fail, + RuntimeError: if all live samples in the beam fail, or if max steps is reached before any samples pass. Examples: @@ -268,26 +268,28 @@ def _TreeSearch( sampling_params: SamplingParams, ) -> list[Sample[str]]: beam = [Sample(item=p, utility=-float("inf")) for p in prompts] - finished = set() + passing = [] for _ in range(search_params.max_steps): - prompts = [] stop_pass = [search_params.stop_pass(s.item) for s in beam] stop_fail = [search_params.stop_fail(s.item) for s in beam] - if all(stop_fail): - return _handle_failed_beam(finished) + passing = [] + prompts = [] for sample, passed, failed in zip(beam, stop_pass, stop_fail, strict=True): - if passed: - finished.add(sample) - continue - if failed: - continue - prompts.append(sample.item) - if len(finished) >= search_params.n: - return sort_samples(finished)[: search_params.n] - beam = _BestOfN(prompts, llm, scorer, sampling_params) + if passed and not failed: + passing.append(sample) + elif not failed: + prompts.append(sample.item) + else: # failed + pass + if len(passing) >= search_params.n: + return sort_samples(passing)[: search_params.n] + if len(prompts) == 0: + return _handle_failed_beam(passing) + live = _BestOfN(prompts, llm, scorer, sampling_params) + beam = passing + live if len(beam) > search_params.width: - beam = beam[: search_params.width] - return _handle_maxsteps(finished) + beam = sort_samples(beam)[: search_params.width] + return _handle_maxsteps(passing) def _prepare_token_ids( @@ -350,27 +352,31 @@ def _guard_positive_int(n: int) -> int: return n -def _handle_failed_beam(finished: set[Sample[str]]) -> list[Sample[str]]: - if len(finished) == 0: - msg = "All live samples failed." +def _handle_failed_beam(passing: list[Sample[str]]) -> list[Sample[str]]: + if len(passing) == 0: + msg = "All live samples failed before any passed stop conditions." msg += " Check compatibility of stop conditions or expand search." raise RuntimeError(msg) import warnings - msg = "All live samples failed. Returning available finished samples." + msg = "All live samples failed before completing search," + msg += " but some completed samples have already passed stopping conditions." + msg += " Returning available passing samples." warnings.warn(msg, stacklevel=2) - return sort_samples(finished) + return sort_samples(passing) -def _handle_maxsteps(finished: set[Sample[str]]) -> list[Sample[str]]: - if len(finished) == 0: +def _handle_maxsteps(passing: list[Sample[str]]) -> list[Sample[str]]: + if len(passing) == 0: msg = "Max steps reached, and no samples passed stop conditions." raise RuntimeError(msg) import warnings - msg = "Max steps reached. Returning available finished samples." + msg = "Max steps reached before completing search," + msg += "but some samples have already passed stopping conditions." + msg += " Returning available passing samples." warnings.warn(msg, stacklevel=2) - return sort_samples(finished) + return sort_samples(passing) _default_sampling_kwargs = { diff --git a/examples/thm_proving_treesearch.py b/examples/thm_proving_treesearch.py index 1db8274..af74dd9 100644 --- a/examples/thm_proving_treesearch.py +++ b/examples/thm_proving_treesearch.py @@ -79,7 +79,7 @@ def run(prompt: str) -> str: final_scorer=final_scorer, stop_cond_pass=stop_pass, n=10, - beam_width=20, + beam_width=25, beam_factor=5, sync_str="\n", seed=42, diff --git a/tests/test_generators.py b/tests/test_generators.py index c62969f..b329d8d 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -38,12 +38,14 @@ def test_treesearch_basic() -> None: delim = " " end = "." - def utility(s: str) -> int: - return -len(s) - def stop(s: str) -> bool: return end in s + def utility(s: str) -> int: + if stop(s): + return 1 + return -len(s) + scorer = Scorer.from_f_str_to_num(utility) sentence = TreeSearch( llm=llm, @@ -88,7 +90,12 @@ def test_treesearch_step() -> None: delim = " " end = "." + def stop(s: str) -> bool: + return end in s + def utility_step(s: str) -> int: + if stop(s): + return 1 ws = s.split(delim) check_words = 2 if len(ws) < check_words: @@ -98,9 +105,6 @@ def utility_step(s: str) -> int: def utility_final(s: str) -> int: return -len(s) - def stop(s: str) -> bool: - return end in s - step_scorer = Scorer.from_f_str_to_num(utility_step) final_scorer = Scorer.from_f_str_to_num(utility_final) sentence = TreeSearch( @@ -125,15 +129,17 @@ def test_treesearch_fail() -> None: delim = " " end = "." + def stop(s: str) -> bool: + return end in s + def utility(s: str) -> int: + if stop(s): + return 1 return -len(s) def fail(s: str) -> bool: return len(s) > max_len_constraint - def stop(s: str) -> bool: - return end in s - scorer = Scorer.from_f_str_to_num(utility) def beam_search(n: int, beam_width: int, beam_factor: int) -> list[Sample[str]]: @@ -157,7 +163,7 @@ def beam_search(n: int, beam_width: int, beam_factor: int) -> list[Sample[str]]: assert end in sentence n_requested = 5 - msg = "All live samples failed. Returning available finished samples." + msg = "All live samples failed before completing search" with pytest.warns(UserWarning, match=msg): out = beam_search(n_requested, 30, 6) assert 0 < len(out) < n_requested @@ -165,8 +171,7 @@ def beam_search(n: int, beam_width: int, beam_factor: int) -> list[Sample[str]]: assert all(s.item.startswith(start) for s in out) assert all(end in s.item for s in out) - msg = "All live samples failed." - msg += " Check compatibility of stop conditions or expand search." + msg = "All live samples failed before any passed stop conditions" with pytest.raises(RuntimeError, match=msg): beam_search(1, 30, 2) @@ -176,13 +181,16 @@ def test_treesearch_maxsteps() -> None: delim = " " end = "." - def utility(s: str) -> int: - return -len(s) - def stop(s: str) -> bool: return end in s + def utility(s: str) -> int: + if stop(s): + return 1 + return -len(s) + scorer = Scorer.from_f_str_to_num(utility) + n_requested = 3 def beam_search(max_steps: int) -> list[Sample[str]]: return TreeSearch( @@ -192,7 +200,7 @@ def beam_search(max_steps: int) -> list[Sample[str]]: sync_token_ids=delim, stop_cond_pass=stop, max_steps=max_steps, - n=3, + n=n_requested, beam_width=30, beam_factor=6, seed=0, @@ -200,16 +208,17 @@ def beam_search(max_steps: int) -> list[Sample[str]]: out = beam_search(5) sentence = out[0].item + assert len(out) == n_requested assert sentence.startswith(start) assert end in sentence - msg = "Max steps reached. Returning available finished samples." + msg = "Max steps reached before completing search" with pytest.warns(UserWarning, match=msg): out = beam_search(3) assert all(s.item.startswith(start) for s in out) assert all(end in s.item for s in out) - msg = "Max steps reached, and no samples passed stop conditions." + msg = "Max steps reached, and no samples passed stop conditions" with pytest.raises(RuntimeError, match=msg): beam_search(1)