DeepAuto-AI/automl-agent


AutoML-Agent

This is the official implementation of AutoML-Agent: A Multi-Agent LLM Framework for Full-Pipeline AutoML (ICML 2025).

[Paper][Poster][Website]

Setup

Benchmark Datasets

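Main Datasets

| Data Modality | Downstream Task | Dataset Name | # Features | # Train | # Valid | # Test | # Classes | Source | License | Evaluation Metric |
|---|---|---|---|---|---|---|---|---|---|---|
| Image (Computer Vision) | Image Classification | Butterfly Image | 224x224 | 4,549 | 1,299 | 651 | 75 | Kaggle Dataset | CC0 | Accuracy |
| Image (Computer Vision) | Image Classification | Shopee-IET | Varying | 640 | 160 | 80 | 4 | Kaggle Competition | Custom | Accuracy |
| Text (Natural Language Processing) | Text Classification | Ecommerce Text | N/A | 35,296 | 10,084 | 5,044 | 4 | Kaggle Dataset | CC BY 4.0 | Accuracy |
| Text (Natural Language Processing) | Text Classification | Textual Entailment | N/A | 3,925 | 982 | 4,908 | 3 | Kaggle Dataset | N/A | Accuracy |
| Tabular (Classic Machine Learning) | Tabular Classification | Banana Quality | 7 | 5,600 | 1,600 | 800 | 2 | Kaggle Dataset | Apache 2.0 | F1 |
| Tabular (Classic Machine Learning) | Tabular Classification | Software Defects | 21 | 73,268 | 18,318 | 91,587 | 2 | Kaggle Competition | N/A | F1 |
| Tabular (Classic Machine Learning) | Tabular Clustering | Smoker Status | 22 | 100,331 | 28,666 | 14,334 | 2 | Kaggle Competition | N/A | RI |
| Tabular (Classic Machine Learning) | Tabular Clustering | Higher Education Students Performance | 31 | 101 | 29 | 15 | 8 | Research Dataset (UCI ML) | CC BY 4.0 | RI |
| Tabular (Classic Machine Learning) | Tabular Regression | Crab Age | 8 | 53,316 | 13,329 | 66,646 | N/A | Kaggle Competition | CC0 | RMSLE |
| Tabular (Classic Machine Learning) | Tabular Regression | Crop Price | 8 | 1,540 | 440 | 220 | N/A | Kaggle Dataset | MIT | RMSLE |
| Graph (Graph Learning) | Node Classification | Cora | 1,433 | 2,708 | 2,708 | 2,708 | 7 | Research Dataset (Planetoid) | CC BY 4.0 | Accuracy |
| Graph (Graph Learning) | Node Classification | Citeseer | 3,703 | 3,327 | 3,327 | 3,327 | 6 | Research Dataset (Planetoid) | N/A | Accuracy |
| Time Series (Time Series Analysis) | Time-Series Forecasting | Weather | 21 | 36,887 | 10,539 | 5,270 | N/A | Research Dataset (TSLib) | CC BY 4.0 | RMSLE |
| Time Series (Time Series Analysis) | Time-Series Forecasting | Electricity | 321 | 18,412 | 5,260 | 2,632 | N/A | Research Dataset (TSLib) | CC BY 4.0 | RMSLE |

Additional Datasets for SELA

| Data Modality | Downstream Task | Dataset Name | # Features | # Train | # Valid | # Test | # Classes | Source | License | Evaluation Metric |
|---|---|---|---|---|---|---|---|---|---|---|
| Tabular (Classic Machine Learning) | Binary Classification | Smoker Status | 22 | 85,997 | 21,500 | 143,331 | 2 | Kaggle Competition | N/A | F1 |
| Tabular (Classic Machine Learning) | Binary Classification | Click Prediction Small | 11 | 19,174 | 4,794 | 7,990 | 2 | OpenML | N/A | F1 |
| Tabular (Classic Machine Learning) | Multi-Class Classification | MFeat Factors | 216 | 960 | 240 | 400 | 10 | OpenML | N/A | F1 |
| Tabular (Classic Machine Learning) | Multi-Class Classification | Wine Quality White | 11 | 2,350 | 588 | 980 | 7 | OpenML | N/A | F1 |
| Tabular (Classic Machine Learning) | Regression | Colleges | 44 | 3,389 | 848 | 1,413 | N/A | OpenML | N/A | RMSE |
| Tabular (Classic Machine Learning) | Regression | House Prices | 80 | 700 | 176 | 292 | N/A | Kaggle Competition | N/A | RMSE |

Metric abbreviations: RI = Rand index; RMSLE = root mean squared logarithmic error; RMSE = root mean squared error.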
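The evaluation metrics listed for the benchmarks are standard, though the README does not say which variant each benchmark uses (for example, binary vs. macro-averaged F1). For reference, here is a minimal pure-Python sketch of three of them: binary F1, the Rand index (RI) for comparing two clusterings, and RMSLE. The function names are illustrative and not part of this repository; in practice, scikit-learn's `f1_score`, `rand_score`, and `mean_squared_log_error` (take the square root for RMSLE) cover the same ground.

```python
import math
from itertools import combinations
from typing import Hashable, Sequence


def binary_f1(y_true: Sequence[Hashable], y_pred: Sequence[Hashable], positive: Hashable = 1) -> float:
    """Harmonic mean of precision and recall for the `positive` class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    if tp == 0:
        return 0.0  # no true positives: precision or recall is zero (or undefined)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def rand_index(labels_a: Sequence[Hashable], labels_b: Sequence[Hashable]) -> float:
    """Fraction of sample pairs on which two clusterings agree
    (both put the pair in the same cluster, or both keep it apart)."""
    if len(labels_a) != len(labels_b) or len(labels_a) < 2:
        raise ValueError("need two labelings of equal length >= 2")
    agree = total = 0
    for i, j in combinations(range(len(labels_a)), 2):
        same_a = labels_a[i] == labels_a[j]
        same_b = labels_b[i] == labels_b[j]
        if same_a == same_b:
            agree += 1
        total += 1
    return agree / total


def rmsle(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Root mean squared logarithmic error; all values must be > -1."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need non-empty sequences of equal length")
    squared = [(math.log1p(p) - math.log1p(t)) ** 2 for t, p in zip(y_true, y_pred)]
    return math.sqrt(sum(squared) / len(squared))
```

For example, `binary_f1([1, 0, 1, 1], [1, 0, 0, 1])` has precision 1 and recall 2/3, so F1 = 0.8. The Rand index ignores cluster IDs, so `rand_index([0, 0, 1, 1], [1, 1, 0, 0])` is 1.0 even though the labels are swapped.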

Usage

We recommend using a conda environment.

conda create --name amla python=3.11
conda activate amla
pip install -r requirements.txt

Run AutoML Development

  1. Serve the instruction-tuned LoRA adapter (Download Link) for the Prompt Agent via vLLM. vllm==0.4.1 is strictly required to get correctly parsed results.
HF_TOKEN="Your HuggingFace Token" CUDA_VISIBLE_DEVICES="0,1,2,3" python -m vllm.entrypoints.openai.api_server --model mistralai/Mixtral-8x7B-Instruct-v0.1 --enable-lora --lora-modules prompt-llama=./adapter/adapter-mixtral/ --tensor-parallel-size 4
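Before configuring the agent, it can help to confirm that the server is up and the adapter is registered under the expected name. The sketch below uses only the standard library and assumes the server listens on localhost:8000 and that vLLM's OpenAI-compatible `/v1/models` endpoint lists LoRA modules alongside the base model (verify this for your vLLM version). Both helper names are hypothetical.

```python
import json
import urllib.request
from typing import Any


def model_ids(payload: dict[str, Any]) -> list[str]:
    """Extract model ids from an OpenAI-style list response: {"data": [{"id": ...}, ...]}."""
    return [entry["id"] for entry in payload.get("data", [])]


def list_served_models(base_url: str = "http://localhost:8000/v1", timeout: float = 5.0) -> list[str]:
    """Query GET {base_url}/models on a running OpenAI-compatible server."""
    url = base_url.rstrip("/") + "/models"
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        payload = json.load(resp)
    return model_ids(payload)
```

Once the server has finished loading, `"prompt-llama" in list_served_models()` should be `True`. If it is `False`, check the name passed to `--lora-modules`.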
  2. Set up the Prompt Agent and LLM backbone(s) in ./configs.py.
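AVAILABLE_LLMs = {
    # Prompt Agent: the LoRA adapter served locally by vLLM. "model" must match
    # the name given to --lora-modules (prompt-llama).
    "prompt-llm": {
        "api_key": "empty",
        "model": "prompt-llama",
        "base_url": "http://localhost:8000/v1",
    },
    # LLM backbones accessed through the OpenAI API.
    "gpt-4": {"api_key": "YOUR OPENAI KEY", "model": "gpt-4o"},
    "gpt-3.5": {"api_key": "YOUR OPENAI KEY", "model": "gpt-3.5-turbo"},
}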
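The config entries differ only in whether they carry a `base_url`. The hypothetical helper below (not the repository's own loader) sketches one way such an entry could be resolved into client settings. It assumes that entries without `base_url` target OpenAI's public endpoint, which is what the official OpenAI Python client uses by default.

```python
from typing import Any

# OpenAI's public API endpoint (the OpenAI Python client's default base_url).
OPENAI_DEFAULT_BASE_URL = "https://api.openai.com/v1"

# Same shape as the example in ./configs.py.
AVAILABLE_LLMs: dict[str, dict[str, Any]] = {
    "prompt-llm": {
        "api_key": "empty",
        "model": "prompt-llama",
        "base_url": "http://localhost:8000/v1",
    },
    "gpt-4": {"api_key": "YOUR OPENAI KEY", "model": "gpt-4o"},
    "gpt-3.5": {"api_key": "YOUR OPENAI KEY", "model": "gpt-3.5-turbo"},
}


def resolve_llm(name: str, registry: dict[str, dict[str, Any]] = AVAILABLE_LLMs) -> dict[str, str]:
    """Hypothetical helper: turn a config key into api_key/model/base_url settings."""
    if name not in registry:
        raise KeyError(f"unknown LLM {name!r}; expected one of {sorted(registry)}")
    entry = registry[name]
    missing = [key for key in ("api_key", "model") if key not in entry]
    if missing:
        raise ValueError(f"config for {name!r} is missing {missing}")
    return {
        "api_key": entry["api_key"],
        "model": entry["model"],
        # Entries without base_url are assumed to target the OpenAI API.
        "base_url": entry.get("base_url", OPENAI_DEFAULT_BASE_URL),
    }
```

With this sketch, `resolve_llm("prompt-llm")` points at the local vLLM server with model `prompt-llama`, while `resolve_llm("gpt-4")` points at the OpenAI API with model `gpt-4o`.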
  3. Start a chat with AutoML-Agent's Manager 🕴🏻!
from agent_manager import AgentManager

data_path = "agent_workspace/datasets/banana_quality.csv" # assuming the data is uploaded via web interface / API
user_prompt = "Build a model to classify banana quality as good or bad based on their numerical information about bananas of different quality (size, weight, sweetness, softness, harvest time, ripeness, and acidity). We have uploaded the entire dataset for you here in the banana_quality.csv file."
manager = AgentManager(llm='gpt-4', interactive=False, data_path=data_path)  # llm: a key in AVAILABLE_LLMs (configs.py)

manager.initiate_chat(user_prompt)

We recommend running this in a Jupyter notebook. The generated .py file is saved in the agent_workspace directory.
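The README does not name the generated script. Here is a minimal sketch for locating it after a run. It assumes the output is the most recently modified `.py` file anywhere under `agent_workspace`, and the function name is hypothetical.

```python
import os
from pathlib import Path
from typing import Optional


def latest_generated_script(workspace: str = "agent_workspace") -> Optional[Path]:
    """Return the most recently modified .py file under `workspace`, or None.

    Assumption: the agent's output is the newest .py file anywhere in the tree.
    """
    root = Path(workspace)
    if not root.is_dir():
        return None
    # rglob searches subdirectories too; non-.py files such as datasets are ignored.
    scripts = [path for path in root.rglob("*.py") if path.is_file()]
    if not scripts:
        return None
    return max(scripts, key=os.path.getmtime)
```

In a notebook, `print(latest_generated_script().read_text())` would show the newest script, provided one exists.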

Citation

@inproceedings{AutoML_Agent,
  title={Auto{ML}-Agent: A Multi-Agent {LLM} Framework for Full-Pipeline Auto{ML}},
  author={Trirat, Patara and Jeong, Wonyong and Hwang, Sung Ju},
  booktitle={Forty-second International Conference on Machine Learning},
  year={2025},
  url={https://openreview.net/forum?id=p1UBWkOvZm}
}

License

This project is licensed under the CC BY-NC 4.0 license. Commercial use is prohibited.