 # import os
 # os.environ["TRANSFORMERS_CACHE"] = "/media/samuel/UDISK1/transformers_cache"
 import os
-from dotenv import load_dotenv
 import time
+
 import torch
+from dotenv import load_dotenv
 from langchain.llms.base import LLM
-from transformers import pipeline
-
 from llama_index import (
-    SimpleDirectoryReader,
     GPTListIndex,
-    PromptHelper,
     LLMPredictor,
-    ServiceContext
+    PromptHelper,
+    ServiceContext,
+    SimpleDirectoryReader,
 )
+from transformers import pipeline
 
 # load_dotenv()
 os.environ["OPENAI_API_KEY"] = "random"
 
+
 def timeit():
     """
     a utility decoration to time running time
     """
+
     def decorator(func):
         def wrapper(*args, **kwargs):
             start = time.time()
@@ -36,7 +38,9 @@ def wrapper(*args, **kwargs):
 
             print(f"[{(end - start):.8f} seconds]: f({args}) -> {result}")
             return result
+
         return wrapper
+
     return decorator
 
 
@@ -45,24 +49,27 @@ def wrapper(*args, **kwargs):
     max_input_size=2048,
     # number of output tokens
     num_output=256,
-    # the maximum overlap between chunks.
-    max_chunk_overlap=20
+    # the maximum overlap between chunks.
+    max_chunk_overlap=20,
 )
 
+
 class LocalOPT(LLM):
     # model_name = "facebook/opt-iml-max-30b" (this is a 60gb model)
-    model_name = "facebook/opt-iml-1.3b" # ~2.63gb model
+    model_name = "facebook/opt-iml-1.3b"  # ~2.63gb model
     # https://huggingface.co/docs/transformers/main_classes/pipelines
-    pipeline = pipeline("text-generation", model=model_name,
-                        device="cuda:0",
-                        model_kwargs={"torch_dtype": torch.bfloat16}
-                        )
-
-    def _call(self, prompt:str, stop=None) -> str:
-        response = self.pipeline(prompt, max_new_tokens=256)[0]["generated_text"]
-        # only return newly generated tokens
-        return response[len(prompt):]
-
+    pipeline = pipeline(
+        "text-generation",
+        model=model_name,
+        device="cuda:0",
+        model_kwargs={"torch_dtype": torch.bfloat16},
+    )
+
+    def _call(self, prompt: str, stop=None) -> str:
+        response = self.pipeline(prompt, max_new_tokens=256)[0]["generated_text"]
+        # only return newly generated tokens
+        return response[len(prompt) :]
+
     @property
     def _identifying_params(self):
         return {"name_of_model": self.model_name}
@@ -71,6 +78,7 @@ def _identifying_params(self):
     def _llm_type(self):
         return "custom"
 
+
 @timeit()
 def create_index():
     print("Creating index")
@@ -79,30 +87,29 @@ def create_index():
     # Service Context: a container for your llamaindex index and query
     # https://gpt-index.readthedocs.io/en/latest/reference/service_context.html
     service_context = ServiceContext.from_defaults(
-        llm_predictor=llm,
-        prompt_helper=prompt_helper
+        llm_predictor=llm, prompt_helper=prompt_helper
     )
-    docs = SimpleDirectoryReader('news').load_data()
+    docs = SimpleDirectoryReader("news").load_data()
     index = GPTListIndex.from_documents(docs, service_context=service_context)
     print("Done creating index", index)
     return index
 
+
 @timeit()
 def execute_query():
     response = index.query(
         "Who does Indonesia export its coal to in 2023?",
-        # This will preemptively filter out nodes that do not contain required_keywords
+        # This will preemptively filter out nodes that do not contain required_keywords
         # or contain exclude_keywords, reducing the search space and hence time/number of LLM calls/cost.
         exclude_keywords=["petroleum"],
         # required_keywords=["coal"],
-        # exclude_keywords=["oil", "gas", "petroleum"]
-
+        # exclude_keywords=["oil", "gas", "petroleum"]
     )
     return response
 
 
 if __name__ == "__main__":
-    """
+    """
     Check if a local cache of the model exists,
     if not, it will download the model from huggingface
     """
@@ -112,7 +119,13 @@ def execute_query():
         index.save_to_disk("7_custom_opt.json")
     else:
         print("Loading local cache of model")
-        index = GPTListIndex.load_from_disk("7_custom_opt.json")
+        llm = LLMPredictor(llm=LocalOPT())
+        service_context = ServiceContext.from_defaults(
+            llm_predictor=llm, prompt_helper=prompt_helper
+        )
+        index = GPTListIndex.load_from_disk(
+            "7_custom_opt.json", service_context=service_context
+        )
 
     response = execute_query()
     print(response)
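
The substantive change in the last hunk is that the index reloaded from disk is rebuilt around the same custom service context. A minimal sketch of that pattern, reusing only names and calls that appear in the script above (pre-0.6 llama_index API; the os.path.exists check is an assumption, since the original condition is outside this diff). Without the explicit service_context, load_from_disk would fall back to the library's default OpenAI-backed predictor, which the placeholder OPENAI_API_KEY cannot serve.

# Sketch only: reload the cached index so queries keep using the local OPT model.
llm = LLMPredictor(llm=LocalOPT())
service_context = ServiceContext.from_defaults(
    llm_predictor=llm, prompt_helper=prompt_helper
)
if os.path.exists("7_custom_opt.json"):  # assumed cache check, mirroring the main block
    # Pass the service context explicitly; otherwise the loaded index
    # would default to an OpenAI-backed LLMPredictor.
    index = GPTListIndex.load_from_disk(
        "7_custom_opt.json", service_context=service_context
    )
else:
    index = create_index()
    index.save_to_disk("7_custom_opt.json")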