diff --git a/dspy/teleprompt/simba.py b/dspy/teleprompt/simba.py index 17c1e29d08..d8509a522e 100644 --- a/dspy/teleprompt/simba.py +++ b/dspy/teleprompt/simba.py @@ -1,6 +1,7 @@ import dspy import random import logging +import sys import numpy as np from typing import Callable @@ -22,6 +23,7 @@ def __init__( max_demos=4, demo_input_field_maxlen=100_000, num_threads=None, + max_errors: int = 10, temperature_for_sampling=0.2, temperature_for_candidates=0.2, ): @@ -43,6 +45,7 @@ def __init__( self.max_demos = max_demos self.demo_input_field_maxlen = demo_input_field_maxlen self.num_threads = num_threads + self.max_errors = max_errors self.temperature_for_sampling = temperature_for_sampling self.temperature_for_candidates = temperature_for_candidates @@ -118,7 +121,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]): instance_idx = 0 # Parallel runner - run_parallel = dspy.Parallel(access_examples=False, num_threads=self.num_threads) + run_parallel = dspy.Parallel(access_examples=False, num_threads=self.num_threads, max_errors=self.max_errors) trial_logs = {} for batch_idx in range(self.max_steps):