Commit 80feacc

Merge pull request #1 from nicholas-camarda/no-open-api-key
updated 7_custom.py
2 parents 4df93c2 + b37f156 commit 80feacc

File tree

1 file changed: +40 -27 lines

7_custom.py (+40 -27)
@@ -6,27 +6,29 @@
 # import os
 # os.environ["TRANSFORMERS_CACHE"] = "/media/samuel/UDISK1/transformers_cache"
 import os
-from dotenv import load_dotenv
 import time
+
 import torch
+from dotenv import load_dotenv
 from langchain.llms.base import LLM
-from transformers import pipeline
-
 from llama_index import (
-    SimpleDirectoryReader,
     GPTListIndex,
-    PromptHelper,
     LLMPredictor,
-    ServiceContext
+    PromptHelper,
+    ServiceContext,
+    SimpleDirectoryReader,
 )
+from transformers import pipeline
 
 # load_dotenv()
 os.environ["OPENAI_API_KEY"] = "random"
 
+
 def timeit():
     """
     a utility decoration to time running time
     """
+
     def decorator(func):
         def wrapper(*args, **kwargs):
             start = time.time()
@@ -36,7 +38,9 @@ def wrapper(*args, **kwargs):
 
             print(f"[{(end - start):.8f} seconds]: f({args}) -> {result}")
             return result
+
         return wrapper
+
     return decorator
 
 
@@ -45,24 +49,27 @@ def wrapper(*args, **kwargs):
     max_input_size=2048,
     # number of output tokens
     num_output=256,
-    # the maximum overlap between chunks.
-    max_chunk_overlap=20
+    # the maximum overlap between chunks.
+    max_chunk_overlap=20,
 )
 
+
 class LocalOPT(LLM):
     # model_name = "facebook/opt-iml-max-30b" (this is a 60gb model)
-    model_name = "facebook/opt-iml-1.3b" # ~2.63gb model
+    model_name = "facebook/opt-iml-1.3b"  # ~2.63gb model
     # https://huggingface.co/docs/transformers/main_classes/pipelines
-    pipeline = pipeline("text-generation", model=model_name,
-                        device="cuda:0",
-                        model_kwargs={"torch_dtype": torch.bfloat16}
-                        )
-
-    def _call(self, prompt:str, stop=None) -> str:
-        response = self.pipeline(prompt, max_new_tokens=256)[0]["generated_text"]
-        # only return newly generated tokens
-        return response[len(prompt):]
-
+    pipeline = pipeline(
+        "text-generation",
+        model=model_name,
+        device="cuda:0",
+        model_kwargs={"torch_dtype": torch.bfloat16},
+    )
+
+    def _call(self, prompt: str, stop=None) -> str:
+        response = self.pipeline(prompt, max_new_tokens=256)[0]["generated_text"]
+        # only return newly generated tokens
+        return response[len(prompt) :]
+
     @property
     def _identifying_params(self):
         return {"name_of_model": self.model_name}
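
Note: the hunk above only reformats the LocalOPT wrapper, it does not change behaviour. A Hugging Face text-generation pipeline is built once as a class attribute, and _call() strips the prompt from the generated text so that only newly generated tokens are returned. As a point of reference only (not part of the commit), the wrapper can be smoke-tested on its own, assuming the class instantiates as written, the imports above are in scope, and a CUDA device is available, since the pipeline is pinned to device="cuda:0":

# Hypothetical smoke test, not part of the commit: instantiate the wrapper
# and prompt it directly, bypassing llama_index. First use downloads
# facebook/opt-iml-1.3b and requires a CUDA GPU (device="cuda:0").
llm = LocalOPT()
print(llm._identifying_params)  # -> {'name_of_model': 'facebook/opt-iml-1.3b'}
print(llm._call("Indonesia exports most of its coal to"))
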
@@ -71,6 +78,7 @@ def _identifying_params(self):
     def _llm_type(self):
         return "custom"
 
+
 @timeit()
 def create_index():
     print("Creating index")
@@ -79,30 +87,29 @@ def create_index():
     # Service Context: a container for your llamaindex index and query
     # https://gpt-index.readthedocs.io/en/latest/reference/service_context.html
     service_context = ServiceContext.from_defaults(
-        llm_predictor=llm,
-        prompt_helper=prompt_helper
+        llm_predictor=llm, prompt_helper=prompt_helper
     )
-    docs = SimpleDirectoryReader('news').load_data()
+    docs = SimpleDirectoryReader("news").load_data()
     index = GPTListIndex.from_documents(docs, service_context=service_context)
     print("Done creating index", index)
     return index
 
+
 @timeit()
 def execute_query():
     response = index.query(
         "Who does Indonesia export its coal to in 2023?",
-        # This will preemptively filter out nodes that do not contain required_keywords
+        # This will preemptively filter out nodes that do not contain required_keywords
         # or contain exclude_keywords, reducing the search space and hence time/number of LLM calls/cost.
         exclude_keywords=["petroleum"],
         # required_keywords=["coal"],
-        # exclude_keywords=["oil", "gas", "petroleum"]
-
+        # exclude_keywords=["oil", "gas", "petroleum"]
     )
     return response
 
 
 if __name__ == "__main__":
-    """
+    """
     Check if a local cache of the model exists,
     if not, it will download the model from huggingface
     """
@@ -112,7 +119,13 @@ def execute_query():
         index.save_to_disk("7_custom_opt.json")
     else:
         print("Loading local cache of model")
-        index = GPTListIndex.load_from_disk("7_custom_opt.json")
+        llm = LLMPredictor(llm=LocalOPT())
+        service_context = ServiceContext.from_defaults(
+            llm_predictor=llm, prompt_helper=prompt_helper
+        )
+        index = GPTListIndex.load_from_disk(
+            "7_custom_opt.json", service_context=service_context
+        )
 
     response = execute_query()
     print(response)
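
The final hunk carries the substance of the pull request. Before the change, the cache branch called GPTListIndex.load_from_disk with no service context, so queries against the reloaded index would fall back to the library's default, OpenAI-backed predictor, which appears to be why a real OPENAI_API_KEY was previously needed. After the change, the branch rebuilds the same ServiceContext around LocalOPT before loading, matching what create_index() already does. A minimal sketch of that load-and-query path in isolation, assuming the definitions from 7_custom.py (LocalOPT, prompt_helper) are in scope and the legacy llama_index API the file targets (LLMPredictor, ServiceContext.from_defaults, GPTListIndex.load_from_disk):

# Sketch only; mirrors the patched else-branch of 7_custom.py.
llm = LLMPredictor(llm=LocalOPT())
service_context = ServiceContext.from_defaults(
    llm_predictor=llm, prompt_helper=prompt_helper
)
index = GPTListIndex.load_from_disk(
    "7_custom_opt.json", service_context=service_context
)
# The query should now run against the local OPT model; the OPENAI_API_KEY
# placeholder at the top of the file seemingly remains only to satisfy
# library-level key checks.
print(index.query("Who does Indonesia export its coal to in 2023?"))
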
