From a8fa3bb37850e89500261024ff47da0c626ab75f Mon Sep 17 00:00:00 2001
From: Austin Liu
Date: Mon, 20 Jan 2025 10:44:44 +0800
Subject: [PATCH] Fix HF `transformers` Breaking Changes (#526)

## Summary

1. Add a general `version_dispatch` utility function that selects constructors and their args based on version comparisons (see the usage sketch after the diff below).
2. Update `LlamaRotaryEmbedding`.

Closes Issue https://github.com/linkedin/Liger-Kernel/issues/525 and PR https://github.com/linkedin/Liger-Kernel/pull/523 (Thanks to @hebiao064's nice work).

## Testing Done

- Hardware Type:
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu
---
 benchmark/scripts/benchmark_rope.py | 18 +++++++++--
 test/transformers/test_rope.py      | 18 +++++++++--
 test/utils.py                       | 47 +++++++++++++++++++++++++++++
 3 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/benchmark/scripts/benchmark_rope.py b/benchmark/scripts/benchmark_rope.py
index f0c2a4f02..9f852db6a 100644
--- a/benchmark/scripts/benchmark_rope.py
+++ b/benchmark/scripts/benchmark_rope.py
@@ -1,6 +1,8 @@
 import torch
 import triton
 
+from test.utils import transformers_version_dispatch
+from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 from utils import QUANTILES
@@ -30,7 +32,13 @@ def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
     seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
 
     head_dim = hidden_size // num_q_heads
-    rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
+    rotary_emb = transformers_version_dispatch(
+        "4.48.0",
+        LlamaRotaryEmbedding,
+        LlamaRotaryEmbedding,
+        before_kwargs={"dim": head_dim, "device": device},
+        after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
+    )
     q = torch.randn(
         (1, seq_len, num_q_heads, head_dim),
         device=device,
@@ -105,7 +113,13 @@ def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu
     seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
 
     head_dim = hidden_size // num_q_heads
-    rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
+    rotary_emb = transformers_version_dispatch(
+        "4.48.0",
+        LlamaRotaryEmbedding,
+        LlamaRotaryEmbedding,
+        before_kwargs={"dim": head_dim, "device": device},
+        after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
+    )
     q = torch.randn(
         (1, seq_len, num_q_heads, head_dim),
         device=device,
diff --git a/test/transformers/test_rope.py b/test/transformers/test_rope.py
index 2670e8c81..752c68219 100644
--- a/test/transformers/test_rope.py
+++ b/test/transformers/test_rope.py
@@ -2,6 +2,8 @@
 import torch
 
 from test.utils import supports_bfloat16
+from test.utils import transformers_version_dispatch
+from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
@@ -57,7 +59,13 @@ def test_correctness(
     atol,
     rtol,
 ):
-    rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
+    rotary_emb = transformers_version_dispatch(
+        "4.48.0",
+        LlamaRotaryEmbedding,
+        LlamaRotaryEmbedding,
+        before_kwargs={"dim": head_dim, "device": device},
+        after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
+    )
 
     _tensor_q = torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device).transpose(1, 2).to(dtype)
 
@@ -133,7 +141,13 @@ def test_functional_correctness(
     k1 = _k.clone().requires_grad_(True)
     k2 = _k.clone().requires_grad_(True)
 
-    rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
+    rotary_emb = transformers_version_dispatch(
+        "4.48.0",
+        LlamaRotaryEmbedding,
+        LlamaRotaryEmbedding,
+        before_kwargs={"dim": head_dim, "device": device},
+        after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
+    )
 
     pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
     if expand_position_ids:
diff --git a/test/utils.py b/test/utils.py
index 283d04cb5..31294cc09 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -213,6 +213,53 @@ def supports_bfloat16():
         return False
 
 
+def transformers_version_dispatch(
+    required_version: str,
+    before_fn,
+    after_fn,
+    before_args: tuple = (),
+    after_args: tuple = (),
+    before_kwargs: dict = None,
+    after_kwargs: dict = None,
+):
+    """
+    Dispatches to different functions based on package version comparison.
+
+    Args:
+        required_version: Version to compare against (e.g. "4.48.0")
+        before_fn: Function to call if package_version < required_version
+        after_fn: Function to call if package_version >= required_version
+        before_args: Positional arguments for before_fn
+        after_args: Positional arguments for after_fn
+        before_kwargs: Keyword arguments for before_fn
+        after_kwargs: Keyword arguments for after_fn
+
+    Returns:
+        Result from either before_fn or after_fn
+
+    Example:
+        >>> rotary_emb = transformers_version_dispatch(
+        ...     "4.48.0",
+        ...     LlamaRotaryEmbedding,
+        ...     LlamaRotaryEmbedding,
+        ...     before_args=(head_dim,),
+        ...     after_args=(LlamaConfig(head_dim=head_dim),),
+        ...     before_kwargs={'device': device},
+        ...     after_kwargs={'device': device}
+        ... )
+    """
+    from packaging import version
+    from transformers import __version__ as transformers_version
+
+    before_kwargs = before_kwargs or {}
+    after_kwargs = after_kwargs or {}
+
+    if version.parse(transformers_version) < version.parse(required_version):
+        return before_fn(*before_args, **before_kwargs)
+    else:
+        return after_fn(*after_args, **after_kwargs)
+
+
 def revert_liger_kernel_to_llama(model_config: MiniModelConfig):
     """
     Revert all Liger kernel patches applied to Llama.
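Below is a minimal usage sketch of the `transformers_version_dispatch` helper added in `test/utils.py`, mirroring the call sites in this patch. It assumes the repository's `test/utils.py` is importable from the working directory; the values for `head_dim` and `device` are illustrative, not part of the diff.

```python
# Minimal usage sketch: pick the LlamaRotaryEmbedding constructor call that
# matches the installed `transformers` version, as done at the call sites above.
from test.utils import transformers_version_dispatch  # helper added in this patch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim, device = 64, "cpu"  # illustrative values

rotary_emb = transformers_version_dispatch(
    "4.48.0",
    LlamaRotaryEmbedding,  # used when transformers < 4.48.0
    LlamaRotaryEmbedding,  # used when transformers >= 4.48.0
    before_kwargs={"dim": head_dim, "device": device},
    after_kwargs={"config": LlamaConfig(head_dim=head_dim), "device": device},
)
```

On `transformers` < 4.48.0 this resolves to `LlamaRotaryEmbedding(dim=head_dim, device=device)`; on >= 4.48.0 it resolves to `LlamaRotaryEmbedding(config=LlamaConfig(head_dim=head_dim), device=device)`, matching the constructor change that broke the old call.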