-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_gemini_model.py
38 lines (27 loc) · 988 Bytes
/
test_gemini_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from dotenv import load_dotenv
from dotenv import find_dotenv
load_dotenv(find_dotenv())
from google.cloud import aiplatform
import os
import vertexai
from vertexai.generative_models import GenerativeModel
# from vertexai.preview.tuning import sft
from vertexai.generative_models import GenerationConfig
def list_all_models():
    """Print the endpoint of the first deployment of every listed model.

    Iterates over ``aiplatform.Model.list()`` and, for each model, prints
    the ``endpoint`` field of the first entry in
    ``model.gca_resource.deployed_models``.

    Models that have no deployed-model entries are skipped, so a single
    undeployed model no longer aborts the listing with ``IndexError``.
    """
    for model in aiplatform.Model.list():
        deployed_models = model.gca_resource.deployed_models
        if not deployed_models:
            # Nothing is deployed for this model, so there is no endpoint
            # to report; move on to the next one.
            continue
        print(deployed_models[0].endpoint)
def chat(model, content):
    """Send ``content`` to the model's first deployed endpoint and print the reply.

    Args:
        model: Object exposing ``gca_resource.deployed_models``; the
            ``endpoint`` of the first entry is passed to ``GenerativeModel``.
        content: Prompt forwarded unchanged to ``generate_content``.

    The response's ``candidates`` are printed; nothing is returned.
    Indexing ``deployed_models[0]`` fails if the model has no deployments.
    """
    # NOTE: the endpoint could alternatively come from a supervised tuning
    # job, i.e. sft.SupervisedTuningJob(job_name).tuned_model_endpoint_name.
    endpoint = model.gca_resource.deployed_models[0].endpoint

    # Temperature is pinned to 0 for every request made through this helper.
    response = GenerativeModel(endpoint).generate_content(
        content,
        generation_config=GenerationConfig(temperature=0),
    )
    print(response.candidates)
if __name__ == '__main__':
    # Script entry point: the model identifier is read from the GEMINI_MODEL
    # environment variable (a missing variable raises KeyError), wrapped in
    # an aiplatform.Model, and sent a single fixed prompt.
    chat(aiplatform.Model(os.environ['GEMINI_MODEL']), 'whats your name?')