doc: update fine-tuning config & unify lora/qlora/dora.
Signed-off-by: Electronic-Waste <[email protected]>
Electronic-Waste committed Feb 19, 2025
1 parent 49f5f02 commit 8e5f256
Showing 1 changed file with 28 additions and 58 deletions: docs/proposals/2401-llm-trainer-v2/README.md
@@ -73,6 +73,24 @@ INFO:torchtune.utils.logging:Learning rate scheduler is initialized.

By adopting `torchtune` as the low-level runtime for LLM fine-tuning, we gain the flexibility, efficiency, and scalability of its "recipe-config" design, which streamlines LLM fine-tuning on Kubernetes and makes it straightforward to scale.

To shield users from complex Kubernetes configurations, we will provide a simple yet flexible Python SDK that wraps all specifications of models, datasets, training runtime, and fine-tuning configs. For example:

```python
job_id = TrainingClient().train(
    dataset_config=HuggingFaceDatasetConfig(
        storage_uri="tatsu-lab/alpaca",
    ),
    trainer=Trainer(
        fine_tuning_config=FineTuningConfig(
            peft_config=LoraConfig(r=4),
            dtype="bf16",
        ),
        num_nodes=5,
    ),
    runtime_ref=llm_runtime,
)
```

## Design Details

### `torchtune` Plugin
@@ -91,83 +109,35 @@

We will add the fine-tuning configurations in the `fine_tuning_config` field in the `Trainer`:

| Parameters | What is it? |
| - | - |
| framework | Framework for fine-tuning. |
| dataset_class | Dataset class adopted to fine-tune the LLM. |
| dtype | The underlying data type used to represent the model and optimizer parameters. Currently, we support `bf16` and `fp32`. |
| peft_config | Configuration for PEFT (Parameter-Efficient Fine-Tuning), including LoRA, AdapterPrompt, PrefixTuning, etc. |
| sharding_config | Configuration of the sharding policy for distributed training, such as FSDP (Fully Sharded Data Parallel) and ZeRO (Zero Redundancy Optimizer). |
| kwargs | Other backend-specific and launch-CLI-specific parameters. |

```python
# FineTuningConfig DataClass
@dataclass
class FineTuningConfig:
    framework: str = "huggingface"
    dataset_class: Union[str, Dataset] = "InstructionDataset"
    dtype: str = "bf16"
    peft_config: Optional[Union[LoraConfig, QLoraConfig, AdapterConfig, PrefixConfig]] = None
    sharding_config: Optional[Union[FsdpConfig, ZeroConfig]] = None
    kwargs: Optional[Dict[str, str]] = None
```

A full Python SDK invocation, with every field set explicitly, will then look like:

```python
job_id = TrainingClient().train(
    dataset_config=HuggingFaceDatasetConfig(
        storage_uri="tatsu-lab/alpaca",
    ),
    trainer=Trainer(
        fine_tuning_config=FineTuningConfig(
            framework="huggingface",
            dataset_class="InstructionDataset",
            peft_config=LoraConfig(r=4),
            sharding_config=FsdpConfig(...),
            kwargs={},
        ),
        num_nodes=5,
    ),
    runtime_ref=llm_runtime,
)
```
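
The *FsdpConfig* used above is not defined in this excerpt. Purely as an illustration of what a sharding config might carry, here is a hypothetical sketch; every field name below is an assumption, not part of the proposal:

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical FsdpConfig sketch -- field names are illustrative assumptions only.
@dataclass
class FsdpConfig:
    sharding_strategy: str = "FULL_SHARD"   # e.g. FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
    cpu_offload: bool = False               # offload sharded parameters/gradients to CPU
    auto_wrap_policy: Optional[str] = None  # name of the policy used to wrap transformer blocks
```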

#### LoRA Config

The *LoraConfig* represents the LoRA configuration used to fine-tune the model; a dataclass sketch follows the parameter table below.

| Parameters | What is it? |
| - | - |
| apply_lora_to_mlp | Whether to apply LoRA to the MLP in each transformer layer. |
| apply_lora_to_output | Whether to apply LoRA to the model's final output projection. |
| lora_attn_modules | A list of strings specifying which layers of the model to apply LoRA to: 1. `q_proj` applies LoRA to the query projection layer. 2. `k_proj` applies LoRA to the key projection layer. 3. `v_proj` applies LoRA to the value projection layer. 4. `output_proj` applies LoRA to the attention output projection layer. |
| lora_rank | The rank of the low-rank decomposition. |
| lora_alpha | The scaling factor that adjusts the magnitude of the low-rank matrices' output. |
| lora_dropout | The probability of applying dropout to the low-rank updates. |
| quantize_base | Whether to quantize the base model weights (enables QLoRA). |
| use_dora | Whether to enable DoRA (Weight-Decomposed Low-Rank Adaptation). |
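
Mirroring the table above, the unified *LoraConfig* can be expressed as a dataclass. A minimal sketch, with defaults that are illustrative assumptions rather than part of the proposal:

```python
from dataclasses import dataclass, field
from typing import List


# LoraConfig DataClass (sketch): fields mirror the parameter table; defaults are illustrative.
@dataclass
class LoraConfig:
    lora_attn_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    apply_lora_to_mlp: bool = False      # apply LoRA to the MLP in each transformer layer
    apply_lora_to_output: bool = False   # apply LoRA to the model's final output projection
    lora_rank: int = 8                   # rank of the low-rank decomposition
    lora_alpha: int = 16                 # scaling factor for the low-rank update
    lora_dropout: float = 0.0            # dropout probability on the low-rank updates
    quantize_base: bool = False          # quantize the base model weights (QLoRA)
    use_dora: bool = False               # enable DoRA
```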

#### QLoRA Config

The *QLoraConfig* represents the QLoRA configuration used to fine-tune the model.

| Parameters | What is it? |
| - | - |
| r | The rank of the low-rank decomposition. |
| lora_alpha | The scaling factor that adjusts the magnitude of the low-rank matrices' output. |
| lora_dropout | The probability of applying dropout to the low-rank updates. |
| quant_type | The quantization type, either `nf4` or `fp4`. |
| use_double_quant | Whether to enable double quantization. |
| compute_dtype | The data type used during computation. |
| quant_storage | The data type used to store the quantized parameters. |

```python
# QLoraConfig DataClass
@dataclass
class QLoraConfig:
    r: Optional[int] = None
    lora_alpha: Optional[int] = None
    lora_dropout: Optional[float] = None
    quant_type: str = "fp4"  # fp4 or nf4
    use_double_quant: bool = False
    compute_dtype: torch.dtype = torch.bfloat16
    quant_storage: torch.dtype = torch.bfloat16
```
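
To illustrate how these pieces compose, the earlier SDK example can be adapted to fine-tune with QLoRA instead of plain LoRA; a sketch reusing the proposal's own API surface, with illustrative parameter values:

```python
# Fine-tune with QLoRA: pass a QLoraConfig as the peft_config.
job_id = TrainingClient().train(
    dataset_config=HuggingFaceDatasetConfig(
        storage_uri="tatsu-lab/alpaca",
    ),
    trainer=Trainer(
        fine_tuning_config=FineTuningConfig(
            peft_config=QLoraConfig(r=4, quant_type="nf4"),
            dtype="bf16",
        ),
        num_nodes=5,
    ),
    runtime_ref=llm_runtime,
)
```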

## Implementation History

