Skip to content

Commit

Permalink
fix: continue to compare passing samples to live beam in tree search
Browse files Browse the repository at this point in the history
  • Loading branch information
benlipkin committed Oct 23, 2024
1 parent de2c330 commit 14ead5e
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 55 deletions.
2 changes: 1 addition & 1 deletion TUTORIAL.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def run(prompt: str) -> list[Sample[str]]:
stop_cond_pass=stop_pass,
sync_str=sync_str,
n=10, # number of complete samples to search for
beam_width=20, # size of active beam to maintain
beam_width=25, # size of beam to maintain
beam_factor=5, # number of ways to split each particle at each step
seed=42,
)
Expand Down
76 changes: 41 additions & 35 deletions decoding/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,26 +155,26 @@ def TreeSearch( # noqa: PLR0913
step_scorer: The scorer to rank the samples at each sync step.
final_scorer: The scorer to rank the final beam.
stop_cond_pass: A function that returns `True` if the sample should pass.
This adds a sample to the final beam.
This stops the sample from being extended.
stop_cond_fail: A function that returns `True` if the sample should fail.
This filters the sample from the active beam.
This filters the sample from the live beam.
n: The number of passing samples to generate before returning.
beam_width: The width of the active beam. This is the number of samples to
beam_width: The width of the beam. This is the number of samples to
keep at each step.
beam_factor: The branching factor of the active beam. This is the number of
new samples to generate from each active sample at each sync step.
beam_factor: The branching factor of the beam. This is the number of
new samples to generate from each live sample at each sync step.
max_steps: The maximum number of sync steps to take.
min_tokens_per_step: The minimum number of tokens in each step's extension.
max_tokens_per_step: The maximum number of tokens in each step's extension.
sync_str: A string or list of strings that, if generated, will stop each sample
in the active beam and await the sync step scoring, ranking, and filtering.
sync_token_ids: A list of token IDs that, if generated, will stop each sample
in the active beam and await the sync step scoring, ranking, and filtering.
sync_str: A string or list of strings that, if generated, will stop extending
each sample in the live beam and await scoring, ranking, and filtering.
sync_token_ids: A list of token IDs that, if generated, will stop extending
each sample in the live beam and await scoring, ranking, and filtering.
A string can also be passed, which will specify all token IDs that contain
that substring.
include_sync_str_in_output: Whether to include the stop string in the output.
track_logprobs: Whether to track log probabilities. This comes at a performance
cost, so it is off by default. In most cases, as you are alrady sampling
cost, so it is off by default. In most cases, as you are already sampling
from the model, you do not want to double count the probabilities in the
scorer anyways.
temperature: The temperature for sampling.
Expand All @@ -187,7 +187,7 @@ def TreeSearch( # noqa: PLR0913
Raises:
ValueError: If any of the argument configurations are invalid
RuntimeError: if all active samples in the beam fail,
RuntimeError: if all live samples in the beam fail,
or if max steps is reached before any samples pass.
Examples:
Expand Down Expand Up @@ -268,26 +268,28 @@ def _TreeSearch(
sampling_params: SamplingParams,
) -> list[Sample[str]]:
beam = [Sample(item=p, utility=-float("inf")) for p in prompts]
finished = set()
passing = []
for _ in range(search_params.max_steps):
prompts = []
stop_pass = [search_params.stop_pass(s.item) for s in beam]
stop_fail = [search_params.stop_fail(s.item) for s in beam]
if all(stop_fail):
return _handle_failed_beam(finished)
passing = []
prompts = []
for sample, passed, failed in zip(beam, stop_pass, stop_fail, strict=True):
if passed:
finished.add(sample)
continue
if failed:
continue
prompts.append(sample.item)
if len(finished) >= search_params.n:
return sort_samples(finished)[: search_params.n]
beam = _BestOfN(prompts, llm, scorer, sampling_params)
if passed and not failed:
passing.append(sample)
elif not failed:
prompts.append(sample.item)
else: # failed
pass
if len(passing) >= search_params.n:
return sort_samples(passing)[: search_params.n]
if len(prompts) == 0:
return _handle_failed_beam(passing)
live = _BestOfN(prompts, llm, scorer, sampling_params)
beam = passing + live
if len(beam) > search_params.width:
beam = beam[: search_params.width]
return _handle_maxsteps(finished)
beam = sort_samples(beam)[: search_params.width]
return _handle_maxsteps(passing)


def _prepare_token_ids(
Expand Down Expand Up @@ -350,27 +352,31 @@ def _guard_positive_int(n: int) -> int:
return n


def _handle_failed_beam(finished: set[Sample[str]]) -> list[Sample[str]]:
if len(finished) == 0:
msg = "All live samples failed."
def _handle_failed_beam(passing: list[Sample[str]]) -> list[Sample[str]]:
if len(passing) == 0:
msg = "All live samples failed before any passed stop conditions."
msg += " Check compatibility of stop conditions or expand search."
raise RuntimeError(msg)
import warnings

msg = "All live samples failed. Returning available finished samples."
msg = "All live samples failed before completing search,"
msg += " but some completed samples have already passed stopping conditions."
msg += " Returning available passing samples."
warnings.warn(msg, stacklevel=2)
return sort_samples(finished)
return sort_samples(passing)


def _handle_maxsteps(finished: set[Sample[str]]) -> list[Sample[str]]:
if len(finished) == 0:
def _handle_maxsteps(passing: list[Sample[str]]) -> list[Sample[str]]:
if len(passing) == 0:
msg = "Max steps reached, and no samples passed stop conditions."
raise RuntimeError(msg)
import warnings

msg = "Max steps reached. Returning available finished samples."
msg = "Max steps reached before completing search,"
msg += " but some samples have already passed stopping conditions."
msg += " Returning available passing samples."
warnings.warn(msg, stacklevel=2)
return sort_samples(finished)
return sort_samples(passing)


_default_sampling_kwargs = {
Expand Down
2 changes: 1 addition & 1 deletion examples/thm_proving_treesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def run(prompt: str) -> str:
final_scorer=final_scorer,
stop_cond_pass=stop_pass,
n=10,
beam_width=20,
beam_width=25,
beam_factor=5,
sync_str="\n",
seed=42,
Expand Down
45 changes: 27 additions & 18 deletions tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ def test_treesearch_basic() -> None:
delim = " "
end = "."

def utility(s: str) -> int:
return -len(s)

def stop(s: str) -> bool:
return end in s

def utility(s: str) -> int:
if stop(s):
return 1
return -len(s)

scorer = Scorer.from_f_str_to_num(utility)
sentence = TreeSearch(
llm=llm,
Expand Down Expand Up @@ -88,7 +90,12 @@ def test_treesearch_step() -> None:
delim = " "
end = "."

def stop(s: str) -> bool:
return end in s

def utility_step(s: str) -> int:
if stop(s):
return 1
ws = s.split(delim)
check_words = 2
if len(ws) < check_words:
Expand All @@ -98,9 +105,6 @@ def utility_step(s: str) -> int:
def utility_final(s: str) -> int:
return -len(s)

def stop(s: str) -> bool:
return end in s

step_scorer = Scorer.from_f_str_to_num(utility_step)
final_scorer = Scorer.from_f_str_to_num(utility_final)
sentence = TreeSearch(
Expand All @@ -125,15 +129,17 @@ def test_treesearch_fail() -> None:
delim = " "
end = "."

def stop(s: str) -> bool:
return end in s

def utility(s: str) -> int:
if stop(s):
return 1
return -len(s)

def fail(s: str) -> bool:
return len(s) > max_len_constraint

def stop(s: str) -> bool:
return end in s

scorer = Scorer.from_f_str_to_num(utility)

def beam_search(n: int, beam_width: int, beam_factor: int) -> list[Sample[str]]:
Expand All @@ -157,16 +163,15 @@ def beam_search(n: int, beam_width: int, beam_factor: int) -> list[Sample[str]]:
assert end in sentence

n_requested = 5
msg = "All live samples failed. Returning available finished samples."
msg = "All live samples failed before completing search"
with pytest.warns(UserWarning, match=msg):
out = beam_search(n_requested, 30, 6)
assert 0 < len(out) < n_requested
assert all(len(s.item) <= max_len_constraint for s in out)
assert all(s.item.startswith(start) for s in out)
assert all(end in s.item for s in out)

msg = "All live samples failed."
msg += " Check compatibility of stop conditions or expand search."
msg = "All live samples failed before any passed stop conditions"
with pytest.raises(RuntimeError, match=msg):
beam_search(1, 30, 2)

Expand All @@ -176,13 +181,16 @@ def test_treesearch_maxsteps() -> None:
delim = " "
end = "."

def utility(s: str) -> int:
return -len(s)

def stop(s: str) -> bool:
return end in s

def utility(s: str) -> int:
if stop(s):
return 1
return -len(s)

scorer = Scorer.from_f_str_to_num(utility)
n_requested = 3

def beam_search(max_steps: int) -> list[Sample[str]]:
return TreeSearch(
Expand All @@ -192,24 +200,25 @@ def beam_search(max_steps: int) -> list[Sample[str]]:
sync_token_ids=delim,
stop_cond_pass=stop,
max_steps=max_steps,
n=3,
n=n_requested,
beam_width=30,
beam_factor=6,
seed=0,
)

out = beam_search(5)
sentence = out[0].item
assert len(out) == n_requested
assert sentence.startswith(start)
assert end in sentence

msg = "Max steps reached. Returning available finished samples."
msg = "Max steps reached before completing search"
with pytest.warns(UserWarning, match=msg):
out = beam_search(3)
assert all(s.item.startswith(start) for s in out)
assert all(end in s.item for s in out)

msg = "Max steps reached, and no samples passed stop conditions."
msg = "Max steps reached, and no samples passed stop conditions"
with pytest.raises(RuntimeError, match=msg):
beam_search(1)

Expand Down

0 comments on commit 14ead5e

Please sign in to comment.