# import os
# os.environ["TRANSFORMERS_CACHE"] = "/media/samuel/UDISK1/transformers_cache"
import os
import time

import torch
from dotenv import load_dotenv
from langchain.llms.base import LLM
from llama_index import (
    GPTListIndex,
    LLMPredictor,
    PromptHelper,
    ServiceContext,
    SimpleDirectoryReader,
)
from transformers import pipeline

# load_dotenv()
# "random" is only a placeholder: the local OPT model below handles all
# generation, but llama_index still expects OPENAI_API_KEY to be set.
os.environ["OPENAI_API_KEY"] = "random"


def timeit():
    """
    a utility decorator to measure a function's running time
    """

    def decorator(func):
        def wrapper(*args, **kwargs):
            start = time.time()
            result = func(*args, **kwargs)
            end = time.time()

            print(f"[{(end - start):.8f} seconds]: f({args}) -> {result}")
            return result

        return wrapper

    return decorator
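# timeit is a decorator *factory*, so it is applied with parentheses (@timeit(),
# as on create_index/execute_query below); each decorated call prints a line of
# the form "[<seconds> seconds]: f(<args>) -> <result>" before returning the result.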


prompt_helper = PromptHelper(
    # maximum input size
    max_input_size=2048,
    # number of output tokens
    num_output=256,
    # the maximum overlap between chunks
    max_chunk_overlap=20,
)
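# Rough budget implied by these settings (my reading, not stated in the original
# code): with max_input_size=2048 and num_output=256 reserved for generation,
# roughly 2048 - 256 = 1792 tokens remain per LLM call for the prompt template
# plus retrieved context.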


class LocalOPT(LLM):
    # model_name = "facebook/opt-iml-max-30b" (this is a 60gb model)
    model_name = "facebook/opt-iml-1.3b"  # ~2.63gb model
    # https://huggingface.co/docs/transformers/main_classes/pipelines
    pipeline = pipeline(
        "text-generation",
        model=model_name,
        device="cuda:0",
        model_kwargs={"torch_dtype": torch.bfloat16},
    )

    def _call(self, prompt: str, stop=None) -> str:
        response = self.pipeline(prompt, max_new_tokens=256)[0]["generated_text"]
        # only return newly generated tokens
        return response[len(prompt) :]

    @property
    def _identifying_params(self):
        return {"name_of_model": self.model_name}

    @property
    def _llm_type(self):
        return "custom"
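# Hypothetical sanity check, not part of the original script: calling the custom
# LLM on its own should return only freshly generated text, because _call()
# slices the prompt off the front of the pipeline output, e.g.:
#
#   opt = LocalOPT()
#   print(opt._call("Indonesia exports most of its coal to"))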


@timeit()
def create_index():
    print("Creating index")
    # wrap the local model so llama_index can drive it through the langchain LLM interface
    llm = LLMPredictor(llm=LocalOPT())
    # Service Context: a container for your llamaindex index and query
    # https://gpt-index.readthedocs.io/en/latest/reference/service_context.html
    service_context = ServiceContext.from_defaults(
        llm_predictor=llm, prompt_helper=prompt_helper
    )
    docs = SimpleDirectoryReader("news").load_data()
    index = GPTListIndex.from_documents(docs, service_context=service_context)
    print("Done creating index", index)
    return index
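# Note: GPTListIndex keeps the document chunks as a flat list and walks through
# them at query time, so the keyword filters used in execute_query() below are
# the main lever for keeping the number of LLM calls (and hence runtime) down.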


@timeit()
def execute_query():
    # note: queries the module-level `index` created in the __main__ block below
    response = index.query(
        "Who does Indonesia export its coal to in 2023?",
        # This will preemptively filter out nodes that do not contain required_keywords
        # or that contain exclude_keywords, reducing the search space and hence the
        # time, number of LLM calls, and cost.
        exclude_keywords=["petroleum"],
        # required_keywords=["coal"],
        # exclude_keywords=["oil", "gas", "petroleum"],
    )
    return response


if __name__ == "__main__":
    """
    Check whether a cached index already exists on disk;
    if not, build it from the documents (the first run also downloads
    the OPT weights from huggingface).
    """
    if not os.path.exists("7_custom_opt.json"):
        # no cached index yet: build it and persist it for later runs
        index = create_index()
        index.save_to_disk("7_custom_opt.json")
    else:
        print("Loading local cache of model")
        llm = LLMPredictor(llm=LocalOPT())
        service_context = ServiceContext.from_defaults(
            llm_predictor=llm, prompt_helper=prompt_helper
        )
        index = GPTListIndex.load_from_disk(
            "7_custom_opt.json", service_context=service_context
        )

    response = execute_query()
    print(response)
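# Assumed usage, pieced together from the code above: put the news documents you
# want to query in a ./news folder relative to the working directory and run the
# script on a machine with a CUDA-capable GPU (the pipeline is pinned to
# device="cuda:0"). The first run builds the index and saves it to
# 7_custom_opt.json; later runs reload it from disk.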