
Commit d165e22

Add support for using a Llama.cpp binary and model from TurnkeyML (#234)
1 parent b65d5d7 commit d165e22

4 files changed: +254, -2 lines changed

src/turnkeyml/llm/cli.py

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,8 @@
     AdaptHuggingface,
 )
 
+from turnkeyml.llm.tools.llamacpp import LoadLlamaCpp
+
 import turnkeyml.llm.cache as cache
 from turnkeyml.llm.tools.mmlu import AccuracyMMLU
 from turnkeyml.llm.tools.perplexity import AccuracyPerplexity
@@ -23,6 +25,7 @@ def main():
     # List the available tools
     tools = [
         HuggingfaceLoad,
+        LoadLlamaCpp,
         AccuracyMMLU,
         AccuracyPerplexity,
         LLMPrompt,
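The import and the new entry in the tools list are all that is needed to expose the tool on the command line. A small sketch of how that registration presumably surfaces the tool by its `unique_name` (defined in `llamacpp.py` below); this is an illustration, not code from the commit, and assumes only that the package's `[llm]` extras are installed:

```python
# Sketch (not part of the commit): the tools list presumably maps each tool
# class to its CLI name via the class's `unique_name` attribute.
from turnkeyml.llm.tools.llamacpp import LoadLlamaCpp

tools = [LoadLlamaCpp]  # subset of the list registered in cli.py
tool_by_name = {tool.unique_name: tool for tool in tools}

# LoadLlamaCpp.unique_name is "load-llama-cpp", matching the docs below
assert "load-llama-cpp" in tool_by_name
```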

src/turnkeyml/llm/docs/llamacpp.md

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
# LLAMA.CPP

Run transformer models using a Llama.cpp binary and checkpoint. The loaded model can then be used for chat or for benchmarks such as MMLU.

## Prerequisites

This flow has been verified with a generic Llama.cpp model.

These instructions apply only to Linux, or to Windows with WSL. It may be necessary to run WSL from an Administrator command prompt.

These instructions also assume that TurnkeyML's LLM extensions have been installed (for example, with `pip install -e .[llm]`).

### Set up Environment (Assumes TurnkeyML is already installed)

Build or obtain the Llama.cpp binary and the desired model checkpoint.
For example (see the [llama.cpp build instructions](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) for more details):
1. cd ~
2. git clone https://github.com/ggerganov/llama.cpp
3. cd llama.cpp
4. make
5. cd models
6. wget https://huggingface.co/TheBloke/Dolphin-Llama2-7B-GGUF/resolve/main/dolphin-llama2-7b.Q5_K_M.gguf

## Usage

The Llama.cpp tool currently supports the following parameters:

| Parameter    | Definition                                                                   | Default |
| ------------ | ---------------------------------------------------------------------------- | ------- |
| executable   | Path to the Llama.cpp-generated application binary                           | None    |
| model-binary | Model checkpoint (do not use if `--input` is passed to lemonade)             | None    |
| threads      | Number of threads to use for computation                                     | 1       |
| context-size | Maximum context length                                                       | 512     |
| temp         | Temperature to use for inference (leave out to use the application default)  | None    |

### Example (assuming Llama.cpp is built and a checkpoint downloaded as above)

```bash
lemonade --input ~/llama.cpp/models/dolphin-llama2-7b.Q5_K_M.gguf load-llama-cpp --executable ~/llama.cpp/llama-cli accuracy-mmlu --ntrain 5
```

On Windows, the llama.cpp binary might be in a different location (such as llama.cpp\build\bin\Release\), in which case the command might be something like:

```bash
lemonade --input ~\llama.cpp\models\dolphin-llama2-7b.Q5_K_M.gguf load-llama-cpp --executable ~\llama.cpp\build\bin\Release\llama-cli accuracy-mmlu --ntrain 5
```
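The parameters in the table above map directly onto llama.cpp CLI flags inside `LlamaCppAdapter.generate()` (shown in `llamacpp.py` below). A rough sketch of that mapping, not part of the commit; the flag names are taken from the adapter code, and the values here are placeholders rather than verified llama.cpp defaults:

```python
# Illustrative sketch of how the tool's parameters become llama.cpp CLI flags,
# mirroring the optional_params dict in LlamaCppAdapter.generate() below.
params = {
    "ctx-size": 512,       # from --context-size
    "n-predict": 1,        # from max_new_tokens at generate() time
    "threads": 1,          # from --threads
    "model": "dolphin-llama2-7b.Q5_K_M.gguf",  # from --input or --model-binary
    "prompt": "What is the capital of France?",
    "temp": None,          # from --temp; None means "use the application default"
}

# Only parameters that were actually provided are passed to the binary,
# joined as "--flag value" strings just as the adapter does.
flags = [f"--{flag} {value}" for flag, value in params.items() if value is not None]
print(flags)
```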

src/turnkeyml/llm/tools/llamacpp.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
import argparse
import os
import subprocess
from typing import Optional

from turnkeyml.state import State
from turnkeyml.tools import FirstTool

import turnkeyml.common.build as build
from .adapter import PassthroughTokenizer, ModelAdapter


def llamacpp_dir(state: State):
    return os.path.join(build.output_dir(state.cache_dir, state.build_name), "llamacpp")


class LlamaCppAdapter(ModelAdapter):
    unique_name = "llama-cpp-adapter"

    def __init__(self, executable, model, tool_dir, context_size, threads, temp):
        super().__init__()

        self.executable = executable
        self.model = model
        self.tool_dir = tool_dir
        self.context_size = context_size
        self.threads = threads
        self.temp = temp

    def generate(self, input_ids: str, max_new_tokens: Optional[int] = None):
        """
        Pass a text prompt into the llamacpp inference CLI.

        The input_ids arg here should receive the original text that
        would normally be encoded by a tokenizer.
        """

        cmd = [
            self.executable,
            "-e",
        ]

        optional_params = {
            "ctx-size": self.context_size,
            "n-predict": max_new_tokens,
            "threads": self.threads,
            "model": self.model,
            "prompt": input_ids,
            "temp": self.temp,
        }

        for flag, value in optional_params.items():
            if value is not None:
                cmd.append(f"--{flag} {value}")

        cmd = [str(m) for m in cmd]

        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )

        raw_output, raw_err = process.communicate()

        if process.returncode != 0:
            raise subprocess.CalledProcessError(
                process.returncode, process.args, raw_output, raw_err
            )

        prompt_found = False
        output_text = ""
        prompt_first_line = input_ids.split("\n")[0]
        for line in raw_output.splitlines():
            if prompt_first_line in line:
                prompt_found = True
            if prompt_found:
                line = line.replace("</s> [end of text]", "")
                output_text = output_text + line

        if not prompt_found:
            raise Exception("Prompt not found in result, this is a bug in lemonade.")

        return [output_text]


class LoadLlamaCpp(FirstTool):
    unique_name = "load-llama-cpp"

    def __init__(self):
        super().__init__(monitor_message="Running llama.cpp model")

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        parser = __class__.helpful_parser(
            short_description="Wrap Llamacpp models with an API",
            add_help=add_help,
        )

        parser.add_argument(
            "--executable",
            required=True,
            type=str,
            help="Executable name",
        )

        default_threads = 1
        parser.add_argument(
            "--threads",
            required=False,
            type=int,
            default=default_threads,
            help=f"Number of threads to use for generation (default: {default_threads})",
        )

        context_size = 512
        parser.add_argument(
            "--context-size",
            required=False,
            type=int,
            default=context_size,
            help=f"Context size of the prompt (default: {context_size})",
        )

        parser.add_argument(
            "--model-binary",
            required=False,
            help="Path to a .gguf model to use with benchmarking.",
        )

        parser.add_argument(
            "--temp",
            type=float,
            required=False,
            help="Temperature",
        )

        return parser

    def run(
        self,
        state: State,
        input: str = None,
        context_size: int = None,
        threads: int = None,
        executable: str = None,
        model_binary: str = None,
        temp: float = None,
    ) -> State:
        """
        Create a tokenizer instance and model instance in `state` that support:

        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        response = model.generate(input_ids, max_new_tokens=1)
        response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip()
        """

        if executable is None:
            raise Exception(f"{self.__class__.unique_name} requires an executable")

        if input is not None and input != "":
            model_binary = input

        # Save execution parameters
        state.save_stat("context_size", context_size)
        state.save_stat("threads", threads)

        if model_binary is None:
            raise Exception(
                f"{self.__class__.unique_name} requires the preceding tool to pass a "
                "Llamacpp model, "
                "or for the user to supply a model with `--model-binary`"
            )

        state.model = LlamaCppAdapter(
            executable=executable,
            model=model_binary,
            tool_dir=llamacpp_dir(state),
            context_size=context_size,
            threads=threads,
            temp=temp,
        )
        state.tokenizer = PassthroughTokenizer()

        return state
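For reference, a minimal sketch of how the pieces above fit together outside of the lemonade CLI, following the pattern in the `run()` docstring. This is not part of the commit: the executable and checkpoint paths are placeholders, and `tool_dir` would normally be derived from the lemonade cache via `llamacpp_dir()`.

```python
# Hypothetical usage sketch (not part of the commit). Paths are placeholders.
from turnkeyml.llm.tools.adapter import PassthroughTokenizer
from turnkeyml.llm.tools.llamacpp import LlamaCppAdapter

model = LlamaCppAdapter(
    executable="/home/user/llama.cpp/llama-cli",
    model="/home/user/llama.cpp/models/dolphin-llama2-7b.Q5_K_M.gguf",
    tool_dir="/tmp/llamacpp",   # normally llamacpp_dir(state)
    context_size=512,
    threads=1,
    temp=None,
)
tokenizer = PassthroughTokenizer()

# PassthroughTokenizer is assumed to hand the raw prompt string through to
# generate(), which shells out to the llama.cpp binary and returns its text.
prompt = "What is the capital of France?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
response = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(response[0], skip_special_tokens=True).strip())
```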

src/turnkeyml/llm/tools/mmlu.py

Lines changed: 20 additions & 2 deletions
@@ -3,6 +3,7 @@
 import tarfile
 from pathlib import Path
 from typing import List, Optional
+import subprocess
 import tqdm
 import numpy as np
 import pandas as pd
@@ -42,6 +43,12 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser:
             default=5,
             help="Number of training examples to use. Default set to 5 for `5 Shot`",
         )
+        parser.add_argument(
+            "--max-evals",
+            type=int,
+            default=None,
+            help="Maximum evaluations to run per test",
+        )
         parser.add_argument(
             "--data-dir",
             type=str,
@@ -62,6 +69,7 @@ def run(
         self,
         state: State,
         ntrain: int = 5,
+        max_evals: int = None,
         data_dir: Optional[str] = None,
         tests: List[str] = None,
     ) -> State:
@@ -106,7 +114,7 @@ def run(
             )
 
             detailed_results, acc = _eval_model(
-                ntrain, subject, model, tokenizer, dev_df, test_df
+                ntrain, max_evals, subject, model, tokenizer, dev_df, test_df
             )
             subject_results_df = pd.DataFrame(detailed_results)
             subject_csv_path = os.path.join(
@@ -118,11 +126,15 @@
             correct_answers_count = sum(
                 result["Correct"] for result in detailed_results
             )
+
             summary_data.append(
                 {
                     "Subject": subject,
                     "Accuracy": acc,
                     "Total Questions": len(test_df),
+                    "Evaluated Questions": (max_evals
+                                            if max_evals is not None and max_evals < len(test_df)
+                                            else len(test_df)),
                     "Correct Answers": correct_answers_count,
                 }
             )
@@ -197,6 +209,10 @@ def _generate_response(tokenizer, model, input_ids):
     try:
         response = model.generate(input_ids, max_new_tokens=1)
         return tokenizer.decode(response[0], skip_special_tokens=True).strip()
+    except subprocess.CalledProcessError as e:
+        printing.log_warning(
+            f"Subprocess failed with command: {e} and error message: {e.stderr}"
+        )
     except Exception as e:  # pylint: disable=broad-except
         printing.log_warning(f"Error during model generation: {e}")
         return ""  # Return an empty string on failure
@@ -238,7 +254,7 @@ def download_and_extract_dataset(data_cache_dir: str, dataset_url: str):
     return os.path.join(data_cache_dir, "data")
 
 
-def _eval_model(ntrain, subject, model, tokenizer, dev_df, test_df):
+def _eval_model(ntrain, max_evals, subject, model, tokenizer, dev_df, test_df):
     """Evaluates the model on the test data for a given subject."""
     detailed_results = []
 
@@ -265,6 +281,8 @@ def _eval_model(ntrain, subject, model, tokenizer, dev_df, test_df):
                 "Correct": pred_label == label,
             }
         )
+        if max_evals is not None and i >= max_evals - 1:
+            break
 
     acc = np.mean([res["Correct"] for res in detailed_results])
     return detailed_results, acc
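The new `--max-evals` option simply caps how many questions per subject are evaluated and records that cap in the per-subject summary. A small standalone sketch of the bookkeeping, equivalent to the conditional expression added to `summary_data` above (the helper name is illustrative, not from the commit):

```python
# Sketch of the "Evaluated Questions" bookkeeping added above: the evaluation
# loop breaks after max_evals questions, so the summary records the smaller of
# max_evals and the number of questions in the subject's test set.
def evaluated_questions(max_evals, total_questions):
    if max_evals is not None and max_evals < total_questions:
        return max_evals
    return total_questions

assert evaluated_questions(None, 100) == 100   # no cap: evaluate everything
assert evaluated_questions(10, 100) == 10      # cap below the test size
assert evaluated_questions(200, 100) == 100    # cap above the test size has no effect
```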
