 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-tokenizer = AutoTokenizer.from_pretrained("gpt2")
-model = AutoModelForCausalLM.from_pretrained("trained/gpt2-728")
+tokenizer = AutoTokenizer.from_pretrained(
+    "trained/mini-copilot-tokenizer/tokenizer_10M")
+model = AutoModelForCausalLM.from_pretrained("trained/mini-copilot/gpt2-large")
 model.to(device)
 model.eval()

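+# Build the text-generation pipeline once at module load so warm Lambda invocations reuse it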
+pipe = pipeline("text-generation", model=model,
+                tokenizer=tokenizer, device=device, max_new_tokens=32)

-# Function to predict the next token
-def predict_next_token(input_text, model=model, tokenizer=tokenizer, max_length=50):
-    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
-
-    predicted_text = pipe(input_text)[0]["generated_text"]
-
-    input_text_len = len(input_text)
-
-    return predicted_text[input_text_len:]
+# Generate a completion and strip the echoed prompt from the output
+def get_completion(inp: str) -> str:
+    return pipe(inp)[0]["generated_text"][len(inp):]


+# Lambda handler
 def handler(event, context):
     try:
         if event.get('isBase64Encoded', False):
@@ -29,4 +28,4 @@ def handler(event, context):
             body = event['body']
     except (KeyError, json.JSONDecodeError) as e:
         return {"statusCode": 400, "body": f"Error processing request: {str(e)}"}
-    return {"statusCode": 200, "body": predict_next_token(body)}
+    return {"statusCode": 200, "body": get_completion(body)}
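
For a quick local sanity check of the completion helper, a sketch like the one below should work once the module loads; the prompt string and the __main__ guard are illustrative, not part of the commit:

if __name__ == "__main__":
    # Hypothetical smoke test: any short code prompt will do.
    prompt = "def fibonacci(n):"
    completion = get_completion(prompt)
    # The pipeline echoes the prompt in "generated_text"; get_completion
    # slices it off, so only the newly generated tokens are printed.
    print(repr(completion))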
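The handler expects an API-Gateway-style event with an optional base64-encoded body. A minimal sketch of invoking it directly, assuming the branch elided by the hunk gap decodes event['body'] when isBase64Encoded is set (these event payloads are made up for illustration):

import base64

# Plain-text body: expect {"statusCode": 200, "body": "<completion>"}.
plain_event = {"isBase64Encoded": False, "body": "import numpy as np\n"}
print(handler(plain_event, None))

# Base64-encoded body, as API Gateway sends when isBase64Encoded is true.
encoded_event = {
    "isBase64Encoded": True,
    "body": base64.b64encode(b"def main():\n").decode("utf-8"),
}
print(handler(encoded_event, None))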