Finetune fix #7629

Merged (8 commits) on Feb 20, 2025
126 changes: 126 additions & 0 deletions docs/docs/tutorials/classification_finetuning/index.ipynb
@@ -24,6 +24,48 @@
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<details>\n",
"<summary>Recommended: Set up MLflow Tracing to understand what's happening under the hood.</summary>\n",
"\n",
"### MLflow DSPy Integration\n",
"\n",
"<a href=\"https://mlflow.org/\">MLflow</a> is an LLMOps tool that natively integrates with DSPy and offer explainability and experiment tracking. In this tutorial, you can use MLflow to visualize prompts and optimization progress as traces to understand the DSPy's behavior better. You can set up MLflow easily by following the four steps below.\n",
"\n",
"![MLflow Trace](./mlflow-tracing-classification.png)\n",
"\n",
"1. Install MLflow\n",
"\n",
"```bash\n",
"%pip install mlflow>=2.20\n",
"```\n",
"\n",
"2. Start MLflow UI in a separate terminal\n",
"```bash\n",
"mlflow ui --port 5000\n",
"```\n",
"\n",
"3. Connect the notebook to MLflow\n",
"```python\n",
"import mlflow\n",
"\n",
"mlflow.set_tracking_uri(\"http://localhost:5000\")\n",
"mlflow.set_experiment(\"DSPy\")\n",
"```\n",
"\n",
"4. Enabling tracing.\n",
"```python\n",
"mlflow.dspy.autolog()\n",
"```\n",
"\n",
"\n",
"To learn more about the integration, visit [MLflow DSPy Documentation](https://mlflow.org/docs/latest/llms/dspy/index.html) as well.\n",
"</details>"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -472,6 +514,54 @@
"evaluate(classify_ft)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<details>\n",
"<summary>Tracking Evaluation Results in MLflow Experiment</summary>\n",
"\n",
"<br/>\n",
"\n",
"To track and visualize the evaluation results over time, you can record the results in MLflow Experiment.\n",
"\n",
"\n",
"```python\n",
"import mlflow\n",
"\n",
"with mlflow.start_run(run_name=\"classifier_evaluation\"):\n",
" evaluate_correctness = dspy.Evaluate(\n",
" devset=devset,\n",
" metric=extraction_correctness_metric,\n",
" num_threads=16,\n",
" display_progress=True,\n",
" # To record the outputs and detailed scores to MLflow\n",
" return_all_scores=True,\n",
" return_outputs=True,\n",
" )\n",
"\n",
" # Evaluate the program as usual\n",
" aggregated_score, outputs, all_scores = evaluate_correctness(people_extractor)\n",
"\n",
" # Log the aggregated score\n",
" mlflow.log_metric(\"exact_match\", aggregated_score)\n",
" # Log the detailed evaluation results as a table\n",
" mlflow.log_table(\n",
" {\n",
" \"Text\": [example.text for example in devset],\n",
" \"Expected\": [example.example_label for example in devset],\n",
" \"Predicted\": outputs,\n",
" \"Exact match\": all_scores,\n",
" },\n",
" artifact_file=\"eval_results.json\",\n",
" )\n",
"```\n",
"\n",
"To learn more about the integration, visit [MLflow DSPy Documentation](https://mlflow.org/docs/latest/llms/dspy/index.html) as well.\n",
"\n",
"</details>"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -944,6 +1034,42 @@
"classify_ft(text=\"why hasnt my card come in yet?\")\n",
"dspy.inspect_history()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<details>\n",
"<summary>Saving fine-tuned programs in MLflow Experiment</summary>\n",
"\n",
"<br/>\n",
"\n",
"To deploy the fine-tuned program in production or share it with your team, you can save it in MLflow Experiment. Compared to simply saving it to a local file, MLflow offers the following benefits:\n",
"\n",
"1. **Dependency Management**: MLflow automatically save the frozen environment metadata along with the program to ensure reproducibility.\n",
"2. **Experiment Tracking**: With MLflow, you can track the program's performance and cost along with the program itself.\n",
"3. **Collaboration**: You can share the program and results with your team members by sharing the MLflow experiment.\n",
"\n",
"To save the program in MLflow, run the following code:\n",
"\n",
"```python\n",
"import mlflow\n",
"\n",
"# Start an MLflow Run and save the program\n",
"with mlflow.start_run(run_name=\"optimized_classifier\"):\n",
" model_info = mlflow.dspy.log_model(\n",
" classify_ft,\n",
" artifact_path=\"model\", # Any name to save the program in MLflow\n",
" )\n",
"\n",
"# Load the program back from MLflow\n",
"loaded = mlflow.dspy.load_model(model_info.model_uri)\n",
"```\n",
"\n",
"To learn more about the integration, visit [MLflow DSPy Documentation](https://mlflow.org/docs/latest/llms/dspy/index.html) as well.\n",
"\n",
"</details>"
]
}
],
"metadata": {
24 changes: 12 additions & 12 deletions dspy/clients/databricks.py
@@ -8,7 +8,7 @@
import ujson

from dspy.clients.provider import Provider, TrainingJob
from dspy.clients.utils_finetune import DataFormat, get_finetune_directory
from dspy.clients.utils_finetune import TrainDataFormat, get_finetune_directory

if TYPE_CHECKING:
from databricks.sdk import WorkspaceClient
@@ -50,7 +50,7 @@ def is_provider_model(model: str) -> bool:
@staticmethod
def deploy_finetuned_model(
model: str,
data_format: Optional[DataFormat] = None,
data_format: Optional[TrainDataFormat] = None,
databricks_host: Optional[str] = None,
databricks_token: Optional[str] = None,
deploy_timeout: int = 900,
@@ -148,11 +148,11 @@ def deploy_finetuned_model(
num_retries = deploy_timeout // 60
for _ in range(num_retries):
try:
if data_format == DataFormat.chat:
if data_format == TrainDataFormat.CHAT:
client.chat.completions.create(
messages=[{"role": "user", "content": "hi"}], model=model_name, max_tokens=1
)
elif data_format == DataFormat.completion:
elif data_format == TrainDataFormat.COMPLETION:
client.completions.create(prompt="hi", model=model_name, max_tokens=1)
logger.info(f"Databricks model serving endpoint {model_name} is ready!")
return
@@ -169,17 +169,17 @@ def finetune(
job: TrainingJobDatabricks,
model: str,
train_data: List[Dict[str, Any]],
train_data_format: Optional[Union[TrainDataFormat, str]] = "chat",
train_kwargs: Optional[Dict[str, Any]] = None,
data_format: Optional[Union[DataFormat, str]] = "chat",
) -> str:
if isinstance(data_format, str):
if data_format == "chat":
data_format = DataFormat.chat
data_format = TrainDataFormat.CHAT
elif data_format == "completion":
data_format = DataFormat.completion
data_format = TrainDataFormat.COMPLETION
else:
raise ValueError(
f"String `data_format` must be one of 'chat' or 'completion', but received: {data_format}."
f"String `train_data_format` must be one of 'chat' or 'completion', but received: {data_format}."
)

if "train_data_path" not in train_kwargs:
@@ -243,7 +243,7 @@ def finetune(
return f"databricks/{job.endpoint_name}"

@staticmethod
def upload_data(train_data: List[Dict[str, Any]], databricks_unity_catalog_path: str, data_format: DataFormat):
def upload_data(train_data: List[Dict[str, Any]], databricks_unity_catalog_path: str, data_format: TrainDataFormat):
logger.info("Uploading finetuning data to Databricks Unity Catalog...")
file_path = _save_data_to_local_file(train_data, data_format)

@@ -303,7 +303,7 @@ def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databric
logger.info(f"Successfully created directory {databricks_unity_catalog_path} in Databricks Unity Catalog!")


def _save_data_to_local_file(train_data: List[Dict[str, Any]], data_format: DataFormat):
def _save_data_to_local_file(train_data: List[Dict[str, Any]], data_format: TrainDataFormat):
import uuid

file_name = f"finetuning_{uuid.uuid4()}.jsonl"
@@ -313,9 +313,9 @@ def _save_data_to_local_file(train_data: List[Dict[str, Any]], data_format: Data
file_path = os.path.abspath(file_path)
with open(file_path, "w") as f:
for item in train_data:
if data_format == DataFormat.chat:
if data_format == TrainDataFormat.CHAT:
_validate_chat_data(item)
elif data_format == DataFormat.completion:
elif data_format == TrainDataFormat.COMPLETION:
_validate_completion_data(item)

f.write(ujson.dumps(item) + "\n")
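For orientation, here is a minimal standalone sketch of the string-to-enum normalization that the `finetune` hunk above performs; the helper name `resolve_train_data_format` is illustrative only and is not part of this PR.

```python
from typing import Union

from dspy.clients.utils_finetune import TrainDataFormat


def resolve_train_data_format(value: Union[TrainDataFormat, str]) -> TrainDataFormat:
    # Mirror the branch added in Databricks' `finetune`: accept either a
    # TrainDataFormat member or its string alias and return the enum member.
    if isinstance(value, str):
        if value == "chat":
            return TrainDataFormat.CHAT
        if value == "completion":
            return TrainDataFormat.COMPLETION
        raise ValueError(
            f"String `train_data_format` must be one of 'chat' or 'completion', but received: {value}."
        )
    return value


# Example: resolve_train_data_format("chat") returns TrainDataFormat.CHAT.
```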
27 changes: 9 additions & 18 deletions dspy/clients/lm.py
@@ -19,7 +19,7 @@
from dspy.adapters.base import Adapter
from dspy.clients.openai import OpenAIProvider
from dspy.clients.provider import Provider, TrainingJob
from dspy.clients.utils_finetune import DataFormat, infer_data_format, validate_data_format
from dspy.clients.utils_finetune import TrainDataFormat
from dspy.utils.callback import BaseCallback, with_callbacks

from .base_lm import BaseLM
@@ -45,6 +45,7 @@ def __init__(
provider=None,
finetuning_model: Optional[str] = None,
launch_kwargs: Optional[dict[str, Any]] = None,
train_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
):
"""
@@ -79,7 +80,8 @@ def __init__(
self.callbacks = callbacks or []
self.num_retries = num_retries
self.finetuning_model = finetuning_model
self.launch_kwargs = launch_kwargs
self.launch_kwargs = launch_kwargs or {}
self.train_kwargs = train_kwargs or {}

# TODO(bug): Arbitrary model strings could include the substring "o1-".
# We should find a more robust way to check for the "o1-" family models.
@@ -148,18 +150,16 @@ def __call__(self, prompt=None, messages=None, **kwargs):
return outputs

def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
launch_kwargs = launch_kwargs or self.launch_kwargs
self.provider.launch(self, launch_kwargs)

def kill(self, launch_kwargs: Optional[Dict[str, Any]] = None):
launch_kwargs = launch_kwargs or self.launch_kwargs
self.provider.kill(self, launch_kwargs)

def finetune(
self,
train_data: List[Dict[str, Any]],
train_data_format: Optional[TrainDataFormat],
train_kwargs: Optional[Dict[str, Any]] = None,
data_format: Optional[DataFormat] = None,
) -> TrainingJob:
from dspy import settings as settings

@@ -170,27 +170,18 @@ def finetune(
err = f"Provider {self.provider} does not support fine-tuning."
assert self.provider.finetunable, err

# Perform data validation before starting the thread to fail early
train_kwargs = train_kwargs or {}
if not data_format:
adapter = self.infer_adapter()
data_format = infer_data_format(adapter)
validate_data_format(data=train_data, data_format=data_format)

# TODO(PR): We can quickly add caching, but doing so requires
# adding functions that just call other functions as we had in the last
# iteration, unless people have other ideas.
def thread_function_wrapper():
return self._run_finetune_job(job)

thread = threading.Thread(target=thread_function_wrapper)
model_to_finetune = self.finetuning_model or self.model
train_kwargs = train_kwargs or self.train_kwargs
model_to_finetune = self.finetuning_model or self.model
job = self.provider.TrainingJob(
thread=thread,
model=model_to_finetune,
train_data=train_data,
train_data_format=train_data_format,
train_kwargs=train_kwargs,
data_format=data_format,
)
thread.start()

@@ -204,8 +195,8 @@ def _run_finetune_job(self, job: TrainingJob):
job=job,
model=job.model,
train_data=job.train_data,
train_data_format=job.train_data_format,
train_kwargs=job.train_kwargs,
data_format=job.data_format,
)
lm = self.copy(model=model)
job.set_result(lm)
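As a usage note, here is a minimal sketch of how a caller would drive the updated `LM.finetune` signature shown above. The model name, training rows, and hyperparameters are placeholders, and it assumes `TrainingJob` exposes the `Future`-style `result()` that pairs with the `set_result` call in the diff.

```python
import dspy
from dspy.clients.utils_finetune import TrainDataFormat

# Fine-tuning is gated behind DSPy's experimental flag (assumption based on
# the tutorial this PR touches).
dspy.settings.experimental = True

# Placeholder chat-format training rows; not part of this PR.
train_data = [
    {
        "messages": [
            {"role": "user", "content": "why hasnt my card come in yet?"},
            {"role": "assistant", "content": "card_arrival"},
        ]
    },
]

# `train_kwargs` can now be set on the LM itself and is used as the default
# when `finetune` is called without an explicit `train_kwargs` argument.
lm = dspy.LM("openai/gpt-4o-mini-2024-07-18", train_kwargs={"n_epochs": 1})

# `train_data_format` is passed explicitly instead of being inferred from the adapter.
job = lm.finetune(train_data=train_data, train_data_format=TrainDataFormat.CHAT)
finetuned_lm = job.result()  # assumed Future-style API; blocks until the background thread finishes
```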