Commit 1f686c5

Liger Kernel integration (#1861)
* add initial plugin support w Liger kernel patches
* integrate the input args classes
* fix liger plugin and dynamic configuration class
* drop untrainable samples and refactor config plugins integration
* fix incorrect inputs and circular imports
* fix bool comparison
* fix for dropping untrainable tokens
* fix licensing so liger integration is Apache 2.0
* add jamba support
* pylint ignore
1 parent e8ff5d5 commit 1f686c5

12 files changed, +1010 -3 lines changed

.mypy.ini (+3)

@@ -11,6 +11,9 @@ ignore_errors = True
 [mypy-axolotl.models.mixtral.*]
 ignore_errors = True
 
+[mypy-axolotl.integrations.liger.models.*]
+ignore_errors = True
+
 [mypy-axolotl.models.phi.*]
 ignore_errors = True
 

requirements.txt (+2)

@@ -33,6 +33,8 @@ gradio==3.50.2
 tensorboard
 python-dotenv==1.0.1
 autoawq>=0.2.5
+triton>=2.3.0
+liger-kernel
 
 mamba-ssm==1.2.0.post1
 

src/axolotl/cli/__init__.py (+6)

@@ -27,6 +27,7 @@
 from transformers.utils.import_utils import _is_package_available
 
 from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
+from axolotl.integrations.base import PluginManager
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
 from axolotl.utils.config import (
@@ -365,6 +366,11 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
 
     cfg.axolotl_config_path = config
 
+    if cfg.get("plugins"):
+        plugin_manager = PluginManager.get_instance()
+        for plugin_name in cfg["plugins"]:
+            plugin_manager.register(plugin_name)
+
     try:
         device_props = torch.cuda.get_device_properties("cuda")
         gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
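
Note on the load_cfg hunk: the diff shown here only calls PluginManager.get_instance() and plugin_manager.register(plugin_name); the contents of axolotl.integrations.base are not part of this excerpt. The following is a minimal sketch of how such a singleton registry could work, assuming each entry in cfg["plugins"] is a dotted "module.ClassName" path. The plugins attribute and the import-by-path logic are assumptions for illustration, not the code from this commit.

```python
import importlib
from typing import Dict, Optional


class PluginManager:
    """Illustrative singleton registry (assumed design, not the commit's code)."""

    _instance: Optional["PluginManager"] = None

    def __init__(self) -> None:
        # Hypothetical storage: dotted path from the config -> plugin instance.
        self.plugins: Dict[str, object] = {}

    @classmethod
    def get_instance(cls) -> "PluginManager":
        # Create the shared instance lazily on first access.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def register(self, plugin_name: str) -> None:
        # Assumption: plugin_name looks like "package.module.ClassName".
        module_name, _, class_name = plugin_name.rpartition(".")
        if not module_name:
            raise ValueError(f"expected a dotted path, got {plugin_name!r}")
        module = importlib.import_module(module_name)
        plugin_cls = getattr(module, class_name)
        # Instantiate with no arguments and remember it under its config name.
        self.plugins[plugin_name] = plugin_cls()
```

With this sketch, the loop in load_cfg would import and instantiate each configured class once, and every later call to PluginManager.get_instance() would see the same registered plugins. A standard-library class such as "collections.OrderedDict" can stand in for a plugin path when trying it out.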
