Add support for using a Llama.cpp binary and model from TurnkeyML (#234)
gabeweisz authored Nov 19, 2024
1 parent b65d5d7 commit d165e22
Showing 4 changed files with 254 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/turnkeyml/llm/cli.py
@@ -12,6 +12,8 @@
AdaptHuggingface,
)

from turnkeyml.llm.tools.llamacpp import LoadLlamaCpp

import turnkeyml.llm.cache as cache
from turnkeyml.llm.tools.mmlu import AccuracyMMLU
from turnkeyml.llm.tools.perplexity import AccuracyPerplexity
@@ -23,6 +25,7 @@ def main():
# List the available tools
tools = [
HuggingfaceLoad,
LoadLlamaCpp,
AccuracyMMLU,
AccuracyPerplexity,
LLMPrompt,
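With `LoadLlamaCpp` registered in the tool list above, the new tool becomes available to the `lemonade` CLI. Assuming the CLI exposes per-tool help the way other TurnkeyML tools do (an assumption, not shown in this diff), its arguments can be inspected with:

```bash
# Show the arguments accepted by the new load-llama-cpp tool
lemonade load-llama-cpp -h
```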
48 changes: 48 additions & 0 deletions src/turnkeyml/llm/docs/llamacpp.md
@@ -0,0 +1,48 @@
# LLAMA.CPP

Run transformer models using a Llama.cpp binary and checkpoint. The loaded model can then be used for chatting or with benchmarks such as MMLU.

## Prerequisites

This flow has been verified with a generic Llama.cpp model.

These instructions are only for Linux or for Windows with WSL. It may be necessary to run WSL from an Administrator command prompt.

These instructions also assume that TurnkeyML's LLM extensions have been installed (for example, with `pip install -e .[llm]`).
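As a sketch of that setup step (assuming the TurnkeyML repository is checked out from https://github.com/onnx/turnkeyml; adjust the path if your clone lives elsewhere):

```bash
# Assumed repository location for TurnkeyML
git clone https://github.com/onnx/turnkeyml
cd turnkeyml

# Install TurnkeyML together with the LLM extras used by this commit
pip install -e .[llm]
```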


### Set up Environment (Assumes TurnkeyML is already installed)

Build or obtain the Llama.cpp binary and download the desired model checkpoint.
For example (see the [llama.cpp build documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) for more details; a quick sanity check of the resulting binary follows the list below):
1. `cd ~`
1. `git clone https://github.com/ggerganov/llama.cpp`
1. `cd llama.cpp`
1. `make`
1. `cd models`
1. `wget https://huggingface.co/TheBloke/Dolphin-Llama2-7B-GGUF/resolve/main/dolphin-llama2-7b.Q5_K_M.gguf`
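Before wiring the build into lemonade, it can help to confirm that the binary runs at all. This is just a sanity check, assuming the `llama-cli` binary landed in the repository root as in the steps above:

```bash
# Print llama.cpp's usage text; seeing the option list confirms the binary runs
~/llama.cpp/llama-cli --help
```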


## Usage

The Llama.cpp tool currently supports the following parameters:

| Parameter        | Definition                                                                   | Default |
| ---------------- | ---------------------------------------------------------------------------- | ------- |
| `--executable`   | Path to the Llama.cpp-generated application binary                           | None    |
| `--model-binary` | Model checkpoint (do not use if `--input` is passed to lemonade)              | None    |
| `--threads`      | Number of threads to use for computation                                      | 1       |
| `--context-size` | Maximum context length                                                        | 512     |
| `--temp`         | Temperature to use for inference (leave out to use the application default)   | None    |

### Example (assuming Llama.cpp is built and a checkpoint downloaded as above)

```bash
lemonade --input ~/llama.cpp/models/dolphin-llama2-7b.Q5_K_M.gguf load-llama-cpp --executable ~/llama.cpp/llama-cli accuracy-mmlu --ntrain 5
```

On Windows, the llama.cpp binary might be in a different location (such as `llama.cpp\build\bin\Release\`), in which case the command might be something like:
```bash
lemonade --input ~\llama.cpp\models\dolphin-llama2-7b.Q5_K_M.gguf load-llama-cpp --executable ~\llama.cpp\build\bin\Release\llama-cli accuracy-mmlu --ntrain 5
```
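Beyond MMLU, the loaded model can also be used for the chatting mentioned at the top of this document. Assuming the `LLMPrompt` tool registered in `cli.py` is exposed on the command line as `llm-prompt` with a `--prompt` argument (an assumption based on the tool list, not shown in this diff), an interactive-style invocation might look like:

```bash
lemonade --input ~/llama.cpp/models/dolphin-llama2-7b.Q5_K_M.gguf load-llama-cpp --executable ~/llama.cpp/llama-cli llm-prompt --prompt "What is the capital of France?"
```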
183 changes: 183 additions & 0 deletions src/turnkeyml/llm/tools/llamacpp.py
@@ -0,0 +1,183 @@
import argparse
import os
import subprocess
from typing import Optional

from turnkeyml.state import State
from turnkeyml.tools import FirstTool

import turnkeyml.common.build as build
from .adapter import PassthroughTokenizer, ModelAdapter


def llamacpp_dir(state: State):
return os.path.join(build.output_dir(state.cache_dir, state.build_name), "llamacpp")

class LlamaCppAdapter(ModelAdapter):
unique_name = "llama-cpp-adapter"

def __init__(self, executable, model, tool_dir, context_size, threads, temp):
super().__init__()

self.executable = executable
self.model = model
self.tool_dir = tool_dir
self.context_size = context_size
self.threads = threads
self.temp = temp

def generate(self, input_ids: str, max_new_tokens: Optional[int] = None):
"""
Pass a text prompt into the llamacpp inference CLI.
The input_ids arg here should receive the original text that
would normally be encoded by a tokenizer.
"""

cmd = [
self.executable,
"-e",
]

optional_params = {
"ctx-size": self.context_size,
"n-predict": max_new_tokens,
"threads": self.threads,
"model": self.model,
"prompt": input_ids,
"temp": self.temp
}

        for flag, value in optional_params.items():
            if value is not None:
                # Pass the flag and its value as separate argv entries so that
                # llama.cpp's argument parser receives them correctly
                cmd.extend([f"--{flag}", str(value)])

cmd = [str(m) for m in cmd]

process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
)

        raw_output, raw_err = process.communicate()

if process.returncode != 0:
raise subprocess.CalledProcessError(
process.returncode, process.args, raw_output, raw_err)

        # llama.cpp echoes the prompt before the completion; find the first line
        # of the prompt and keep everything from that point onward
        prompt_found = False
        output_text = ""
        prompt_first_line = input_ids.split("\n")[0]
        for line in raw_output.splitlines():
            if prompt_first_line in line:
                prompt_found = True
            if prompt_found:
                # Strip llama.cpp's end-of-text marker from the output
                line = line.replace("</s> [end of text]", "")
                output_text = output_text + line

        if not prompt_found:
            raise Exception("Prompt not found in result, this is a bug in lemonade.")

return [output_text]

class LoadLlamaCpp(FirstTool):
unique_name = "load-llama-cpp"

def __init__(self):
super().__init__(monitor_message="Running llama.cpp model")

@staticmethod
def parser(add_help: bool = True) -> argparse.ArgumentParser:
parser = __class__.helpful_parser(
short_description="Wrap Llamacpp models with an API",
add_help=add_help,
)

parser.add_argument(
"--executable",
required=True,
type=str,
help="Executable name",
)

default_threads = 1
parser.add_argument(
"--threads",
required=False,
type=int,
default=default_threads,
help=f"Number of threads to use for generation (default: {default_threads})",
)

context_size = 512
parser.add_argument(
"--context-size",
required=False,
type=int,
default=context_size,
help=f"Context size of the prompt (default: {context_size})",
)

parser.add_argument(
"--model-binary",
required=False,
help="Path to a .gguf model to use with benchmarking.",
)

parser.add_argument(
"--temp",
type=float,
required=False,
help="Temperature",
)

return parser

def run(
self,
state: State,
input: str = None,
context_size: int = None,
threads: int = None,
executable: str = None,
model_binary: str = None,
temp: float = None,
) -> State:
"""
Create a tokenizer instance and model instance in `state` that support:
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
response = model.generate(input_ids, max_new_tokens=1)
response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip()
"""

if executable is None:
raise Exception(f"{self.__class__.unique_name} requires an executable")

        if input is not None and input != "":
model_binary = input

# Save execution parameters
state.save_stat("context_size", context_size)
state.save_stat("threads", threads)

if model_binary is None:
raise Exception(
f"{self.__class__.unique_name} requires the preceding tool to pass a "
"Llamacpp model, "
"or for the user to supply a model with `--model-binary`"
)

state.model = LlamaCppAdapter(
            executable=executable,
model=model_binary,
tool_dir=llamacpp_dir(state),
context_size=context_size,
threads=threads,
temp=temp,
)
state.tokenizer = PassthroughTokenizer()

return state
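To make the docstring's contract concrete, here is a minimal sketch of driving `LlamaCppAdapter` and `PassthroughTokenizer` directly, outside of a lemonade sequence. The executable and checkpoint paths are the ones assumed in the documentation above, and the thread count is arbitrary; this is an illustration, not part of the commit:

```python
import os

from turnkeyml.llm.tools.llamacpp import LlamaCppAdapter
from turnkeyml.llm.tools.adapter import PassthroughTokenizer

# Paths assumed from the docs above; adjust to your own build and checkpoint
adapter = LlamaCppAdapter(
    executable=os.path.expanduser("~/llama.cpp/llama-cli"),
    model=os.path.expanduser("~/llama.cpp/models/dolphin-llama2-7b.Q5_K_M.gguf"),
    tool_dir="/tmp/llamacpp",  # scratch location; not used by generate()
    context_size=512,
    threads=4,
    temp=None,  # leave unset to use llama.cpp's default temperature
)

# PassthroughTokenizer hands the raw prompt string through as the "input_ids"
tokenizer = PassthroughTokenizer()
prompt = "Q: What is the capital of France?\nA:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
response = adapter.generate(input_ids, max_new_tokens=8)
print(tokenizer.decode(response[0], skip_special_tokens=True).strip())
```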
22 changes: 20 additions & 2 deletions src/turnkeyml/llm/tools/mmlu.py
@@ -3,6 +3,7 @@
import tarfile
from pathlib import Path
from typing import List, Optional
import subprocess
import tqdm
import numpy as np
import pandas as pd
@@ -42,6 +43,12 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser:
default=5,
help="Number of training examples to use. Default set to 5 for `5 Shot`",
)
parser.add_argument(
"--max-evals",
type=int,
default=None,
help="Maximum evaluations to run per test",
)
parser.add_argument(
"--data-dir",
type=str,
@@ -62,6 +69,7 @@ def run(
self,
state: State,
ntrain: int = 5,
max_evals: int = None,
data_dir: Optional[str] = None,
tests: List[str] = None,
) -> State:
@@ -106,7 +114,7 @@ def run(
)

detailed_results, acc = _eval_model(
ntrain, subject, model, tokenizer, dev_df, test_df
ntrain, max_evals, subject, model, tokenizer, dev_df, test_df
)
subject_results_df = pd.DataFrame(detailed_results)
subject_csv_path = os.path.join(
@@ -118,11 +126,15 @@
correct_answers_count = sum(
result["Correct"] for result in detailed_results
)

summary_data.append(
{
"Subject": subject,
"Accuracy": acc,
"Total Questions": len(test_df),
"Evaluated Questions": (max_evals
if max_evals is not None and max_evals < len(test_df)
else len(test_df)),
"Correct Answers": correct_answers_count,
}
)
@@ -197,6 +209,10 @@ def _generate_response(tokenizer, model, input_ids):
try:
response = model.generate(input_ids, max_new_tokens=1)
return tokenizer.decode(response[0], skip_special_tokens=True).strip()
except subprocess.CalledProcessError as e:
printing.log_warning(
f"Subprocess failed with command: {e} and error message: {e.stderr}"
)
except Exception as e: # pylint: disable=broad-except
printing.log_warning(f"Error during model generation: {e}")
return "" # Return an empty string on failure
@@ -238,7 +254,7 @@ def download_and_extract_dataset(data_cache_dir: str, dataset_url: str):
return os.path.join(data_cache_dir, "data")


def _eval_model(ntrain, subject, model, tokenizer, dev_df, test_df):
def _eval_model(ntrain, max_evals, subject, model, tokenizer, dev_df, test_df):
"""Evaluates the model on the test data for a given subject."""
detailed_results = []

@@ -265,6 +281,8 @@ def _eval_model(ntrain, subject, model, tokenizer, dev_df, test_df):
"Correct": pred_label == label,
}
)
        if max_evals is not None and i >= max_evals - 1:
            break

acc = np.mean([res["Correct"] for res in detailed_results])
return detailed_results, acc
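The new `--max-evals` flag caps how many questions are evaluated per subject, which keeps MMLU runs short when the llama.cpp model is slow. Building on the documentation example above (same assumed paths), a capped run might look like:

```bash
# Evaluate at most 10 questions per MMLU subject
lemonade --input ~/llama.cpp/models/dolphin-llama2-7b.Q5_K_M.gguf load-llama-cpp --executable ~/llama.cpp/llama-cli accuracy-mmlu --ntrain 5 --max-evals 10
```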
