From 8e5f256cf38887fc73d4ae39769f13c42b2c2839 Mon Sep 17 00:00:00 2001
From: Electronic-Waste <2690692950@qq.com>
Date: Wed, 19 Feb 2025 06:26:38 +0000
Subject: [PATCH] doc: update fine-tuning config & unify lora/qlora/dora.

Signed-off-by: Electronic-Waste <2690692950@qq.com>
---
 docs/proposals/2401-llm-trainer-v2/README.md | 86 +++++++-------------
 1 file changed, 28 insertions(+), 58 deletions(-)

diff --git a/docs/proposals/2401-llm-trainer-v2/README.md b/docs/proposals/2401-llm-trainer-v2/README.md
index 5da6f1757d..0f055cacd9 100644
--- a/docs/proposals/2401-llm-trainer-v2/README.md
+++ b/docs/proposals/2401-llm-trainer-v2/README.md
@@ -73,6 +73,24 @@ INFO:torchtune.utils.logging:Learning rate scheduler is initialized.
 
 By adopting `torchtune` as the low-level runtime for LLM fine-tuning, we can easily obtain the flexibility, efficiency and scalability brought by its unique "recipe-config" design, which will surely streamline and scale LLM fine-tuning on Kubernetes.
 
+To hide complex Kubernetes configurations from users, we will provide a simple yet flexible Python SDK that wraps all specifications of models, datasets, training runtime and fine-tuning configs, like this:
+
+```python
+job_id = TrainingClient().train(
+    dataset_config=HuggingFaceDatasetConfig(
+        storage_uri="tatsu-lab/alpaca",
+    ),
+    trainer=Trainer(
+        fine_tuning_config=FineTuningConfig(
+            peft_config=LoraConfig(lora_rank=4),
+            dtype="bf16",
+        ),
+        num_nodes=5,
+    ),
+    runtime_ref=llm_runtime,
+)
+```
+
 ## Design Details
 
 ### `torchtune` Plugin
@@ -91,83 +109,35 @@ We will add the fine-tuning configurations in the `fine_tuning_config` field in
 
 | Parameters | What is it? |
 | - | - |
-| framework | Framework for fine-tuning. |
-| dataset_class | Dataset class adopted to fine-tune the LLM. |
+| dtype | The underlying data type used to represent the model and optimizer parameters. Currently, we support `bf16` and `fp32`. |
 | peft_config | Configuration for the PEFT(Parameter-Efficient Fine-Tuning), including Lora, AdapterPrompt, PrefixTuning, etc. |
-| sharding_config | Configuration for sharding policy for distributed training, such as FSDP(Fully Shared Data Parallel) and ZeRO(Zero Redundancy Optimizer). |
-| kwargs | Some other backend-specific and launch-CLI-specific parameters. |
 
 ```python
 # FineTuningConfig DataClass
 @dataclass
 class FineTuningConfig:
-    framework: str = "huggingface"
-    dataset_class: Union[str, Dataset] = "InstructionDataset"
-    peft_config: Optional[Union[LoraConfig, QLoraConfig, AdapterConfig, PrefixConfig]] = None
-    sharding_config: Optional[Union[FsdpConfig, ZeroConfig]] = None
-    kwargs: Optional[Dict[str, str]] = None
+    dtype: str = "bf16"
+    peft_config: Optional[LoraConfig] = None
 ```
 
-The Python SDK will look like:
-
-```python
-job_id = TrainingClient().train(
-    dataset_config=HuggingFaceDatasetConfig(
-        storage_uri="tatsu-lab/alpaca",
-    ),
-    trainer=Trainer(
-        fine_tuning_config=FineTuningConfig(
-            framework="huggingface",
-            dataset_class="InstructionDataset",
-            peft_config=LoraConfig(r=4),
-            sharding_config=FsdpConfig(...),
-            kwargs={},
-        ),
-        num_nodes=5,
-    ),
-    runtime_ref=llm_runtime,
-)
-
-```
-
-#### LoRA Config
+**LoRA Config**
 
 The *LoraConfig* represents the config of LoRA we use to fine-tune the model.
 
 | Parameters | What is it? |
 | - | - |
-| r | The rank of the low rank decomposition. |
+| apply_lora_to_mlp | Whether to apply LoRA to the MLP in each transformer layer |
+| apply_lora_to_output | Whether to apply LoRA to the model’s final output projection |
+| lora_attn_modules | A list of strings specifying which layers of the model to apply LoRA to: 1. `q_proj` applies LoRA to the query projection layer. 2. `k_proj` applies LoRA to the key projection layer. 3. `v_proj` applies LoRA to the value projection layer. 4. `output_proj` applies LoRA to the attention output projection layer. |
+| lora_rank | The rank of the low rank decomposition. |
 | lora_alpha | The scaling factor that adjusts the magnitude of the low-rank matrices’ output |
 | lora_dropout | The probability of applying Dropout to the low rank updates |
+| quantize_base | Whether to quantize the frozen base model weights, which enables QLoRA |
+| use_dora | Whether to enable DoRA (Weight-Decomposed Low-Rank Adaptation) |
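+
+Following the pattern of the `FineTuningConfig` dataclass above, the unified `LoraConfig` could look like the sketch below; the field names mirror the table, while the imports and default values are illustrative assumptions rather than final API decisions. Because the configs are unified, QLoRA and DoRA no longer need dedicated classes: for example, `LoraConfig(lora_rank=4, quantize_base=True)` would request a QLoRA-style run.
+
+```python
+# LoraConfig DataClass (a sketch; defaults are assumptions, not final values)
+from dataclasses import dataclass, field
+from typing import List
+
+@dataclass
+class LoraConfig:
+    apply_lora_to_mlp: bool = False
+    apply_lora_to_output: bool = False
+    lora_attn_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
+    lora_rank: int = 8
+    lora_alpha: int = 16
+    lora_dropout: float = 0.0
+    quantize_base: bool = False  # True -> QLoRA-style fine-tuning on a quantized base model
+    use_dora: bool = False       # True -> DoRA (Weight-Decomposed Low-Rank Adaptation)
+```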
 
-#### QLoRA Config
-
-The *QLoraConfig* represents the config of QLoRA we use to fine-tune the model.
-
-| Parameters | What is it? |
-| - | - |
-| r | The rank of the low rank decomposition. |
-| lora_alpha | The scaling factor that adjusts the magnitude of the low-rank matrices’ output |
-| lora_dropout | The probability of applying Dropout to the low rank updates |
-| quant_type | The quantization type, supporting nf4 and fp4 |
-| use_double_quant | Whether to enable double quantization |
-| compute_dtype | Actual data type in the computing phase |
-| quant_storage | Actual data type in the storage phase |
-
-```python
-# QLoraConfig DataClass
-@dataclass
-class QLoraConfig:
-    r: Optional[int] = None
-    lora_alpha: Optional[int] = None
-    lora_dropout: Optional[float] = None
-    quant_type: str = "fp4" # fp4 or nf4
-    use_double_quant: bool = False
-    compute_dtype: torch.dtype = torch.bfloat16
-    quant_storage: torch.dtype = torch.bfloat16
-
-```
 
 ## Implementation History