-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprompton_helpers.py
123 lines (90 loc) · 3.9 KB
/
prompton_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from datetime import datetime, tzinfo
from typing import List, Tuple

import streamlit as st
from tzlocal import get_localzone

from prompton import client as prompton_client
from prompton import types as prompton_types
def get_prompton():
    """Build a Prompton API client for the current Streamlit session.

    Reads ``prompton_env`` and ``auth_token`` from ``st.session_state``
    (a missing key raises ``KeyError``) and returns a new
    ``PromptonApi`` instance on every call.
    """
    env = st.session_state["prompton_env"]
    token = st.session_state["auth_token"]
    return prompton_client.PromptonApi(environment=env, token=token)
def load_user(access_token: str):
    """Log the session in with *access_token* and load its user and org.

    Side effects on ``st.session_state``:
        * ``auth_token``   <- *access_token* (set first, so ``get_prompton``
          picks it up)
        * ``current_user`` <- result of ``users.get_current_user()``
        * ``current_org``  <- result of ``orgs.get_current_user_org()``
    """
    state = st.session_state
    state["auth_token"] = access_token
    api = get_prompton()
    state["current_user"] = api.users.get_current_user()
    state["current_org"] = api.orgs.get_current_user_org()
@st.cache_data(ttl="1d")
def _fetch_prompt_version_list(prompt_id: str, prompton_env, auth_token: str):
    """Cached worker for ``get_prompt_version_list``.

    ``prompton_env`` and ``auth_token`` are not read here. They are arguments
    only so that ``st.cache_data`` (shared by every session) keys the result
    per environment and credentials; ``get_prompton`` reads the same values
    from the calling session. They must not start with an underscore, or
    ``st.cache_data`` would leave them out of the cache key.
    """
    prompton = get_prompton()
    return prompton.prompt_versions.get_prompt_versions_list(prompt_id=prompt_id)


def get_prompt_version_list(prompt_id: str):
    """Return the versions of prompt *prompt_id* from the Prompton API.

    Results are cached for one day per (prompt_id, environment, auth token),
    so a response fetched with one session's credentials is never served to
    another session.
    """
    return _fetch_prompt_version_list(
        prompt_id,
        st.session_state["prompton_env"],
        st.session_state["auth_token"],
    )


# Keep the ``.clear()`` hook this name had when it was itself the cached
# function; it drops the cached entries of every session.
get_prompt_version_list.clear = (  # type: ignore[attr-defined]
    lambda *_args, **_kwargs: _fetch_prompt_version_list.clear()
)
@st.cache_data(ttl="1d")
def _fetch_prompt_version_by_id(prompt_version_id: str, prompton_env, auth_token: str):
    """Cached worker for ``get_prompt_version_by_id``.

    ``prompton_env`` and ``auth_token`` are not read here. They are arguments
    only so that ``st.cache_data`` (shared by every session) keys the result
    per environment and credentials; ``get_prompton`` reads the same values
    from the calling session. They must not start with an underscore, or
    ``st.cache_data`` would leave them out of the cache key.
    """
    prompton = get_prompton()
    return prompton.prompt_versions.get_prompt_version_by_id(id=prompt_version_id)


def get_prompt_version_by_id(prompt_version_id: str):
    """Return the prompt version *prompt_version_id* from the Prompton API.

    Results are cached for one day per (id, environment, auth token), so a
    response fetched with one session's credentials is never served to
    another session.
    """
    return _fetch_prompt_version_by_id(
        prompt_version_id,
        st.session_state["prompton_env"],
        st.session_state["auth_token"],
    )


# Keep the ``.clear()`` hook this name had when it was itself the cached
# function; it drops the cached entries of every session.
get_prompt_version_by_id.clear = (  # type: ignore[attr-defined]
    lambda *_args, **_kwargs: _fetch_prompt_version_by_id.clear()
)
@st.cache_data(ttl="1d")
def _fetch_prompt_by_id(prompt_id: str, prompton_env, auth_token: str):
    """Cached worker for ``get_prompt_by_id``.

    ``prompton_env`` and ``auth_token`` are not read here. They are arguments
    only so that ``st.cache_data`` (shared by every session) keys the result
    per environment and credentials; ``get_prompton`` reads the same values
    from the calling session. They must not start with an underscore, or
    ``st.cache_data`` would leave them out of the cache key.
    """
    prompton = get_prompton()
    return prompton.prompts.get_prompt_by_id(id=prompt_id)


def get_prompt_by_id(prompt_id: str):
    """Return the prompt *prompt_id* from the Prompton API.

    Results are cached for one day per (id, environment, auth token), so a
    response fetched with one session's credentials is never served to
    another session.
    """
    return _fetch_prompt_by_id(
        prompt_id,
        st.session_state["prompton_env"],
        st.session_state["auth_token"],
    )


# Keep the ``.clear()`` hook this name had when it was itself the cached
# function; it drops the cached entries of every session.
get_prompt_by_id.clear = (  # type: ignore[attr-defined]
    lambda *_args, **_kwargs: _fetch_prompt_by_id.clear()
)
@st.cache_data(ttl="1d")
def _fetch_prompts(prompton_env, auth_token: str):
    """Cached worker for ``get_prompts``.

    ``prompton_env`` and ``auth_token`` are not read here. They are arguments
    only so that ``st.cache_data`` (shared by every session) keys the result
    per environment and credentials; ``get_prompton`` reads the same values
    from the calling session. They must not start with an underscore, or
    ``st.cache_data`` would leave them out of the cache key.
    """
    prompton = get_prompton()
    return prompton.prompts.get_prompts_list()


def get_prompts():
    """Return the prompt list visible to the current session's credentials.

    Results are cached for one day per (environment, auth token). Without
    arguments in the key, every session would otherwise share one entry.
    """
    return _fetch_prompts(
        st.session_state["prompton_env"],
        st.session_state["auth_token"],
    )


# Keep the ``.clear()`` hook this name had when it was itself the cached
# function; it drops the cached entries of every session.
get_prompts.clear = (  # type: ignore[attr-defined]
    lambda *_args, **_kwargs: _fetch_prompts.clear()
)
@st.cache_data(ttl="1hr")
def _fetch_inferences(prompt_version_id: str, prompton_env, auth_token: str):
    """Cached worker for ``get_inferences``.

    ``prompton_env`` and ``auth_token`` are not read here. They are arguments
    only so that ``st.cache_data`` (shared by every session) keys the result
    per environment and credentials; ``get_prompton`` reads the same values
    from the calling session. They must not start with an underscore, or
    ``st.cache_data`` would leave them out of the cache key.
    """
    prompton = get_prompton()
    return prompton.inferences.get_inferences_list(prompt_version_id=prompt_version_id)


def get_inferences(prompt_version_id: str):
    """Return the inferences recorded for *prompt_version_id*.

    Results are cached for one hour per (prompt_version_id, environment,
    auth token), so a response fetched with one session's credentials is
    never served to another session.
    """
    return _fetch_inferences(
        prompt_version_id,
        st.session_state["prompton_env"],
        st.session_state["auth_token"],
    )


# Keep the ``.clear()`` hook this name had when it was itself the cached
# function; it drops the cached entries of every session.
get_inferences.clear = (  # type: ignore[attr-defined]
    lambda *_args, **_kwargs: _fetch_inferences.clear()
)
@st.cache_data(ttl="1d")
def _fetch_my_feedbacks(
    prompt_version_id: str,
    prompton_user_id,
    prompton_env,
    auth_token: str,
):
    """Cached worker for ``get_my_feedbacks``.

    ``prompton_user_id`` selects whose feedbacks are fetched and is part of
    the cache key. ``prompton_env`` and ``auth_token`` are not read here; they
    only make ``st.cache_data`` (shared by every session) key the result per
    environment and credentials, matching what ``get_prompton`` reads from
    the calling session. None of these may start with an underscore, or
    ``st.cache_data`` would leave them out of the key.
    """
    prompton = get_prompton()
    return prompton.feedbacks.get_feedbacks_list(
        prompt_version_id=prompt_version_id,
        prompton_user_id=prompton_user_id,
    )


def get_my_feedbacks(prompt_version_id: str):
    """Return the current user's feedbacks for *prompt_version_id*.

    The user is ``st.session_state["current_user"]`` (set by ``load_user``).
    Results are cached for one day per (prompt_version_id, user id,
    environment, auth token), so one user's feedbacks are never returned to
    another session.
    """
    state = st.session_state
    return _fetch_my_feedbacks(
        prompt_version_id,
        state["current_user"].id,
        state["prompton_env"],
        state["auth_token"],
    )


# Keep the ``.clear()`` hook this name had when it was itself the cached
# function; it drops the cached entries of every session.
get_my_feedbacks.clear = (  # type: ignore[attr-defined]
    lambda *_args, **_kwargs: _fetch_my_feedbacks.clear()
)
def get_inferences_to_evaluate(prompt_version_id: str):
    """Return the inferences the current user has not yet given overall feedback on.

    "Overall" feedback is feedback whose ``feedback_for_part`` is ``None``.
    This delegates to ``get_inference_parts_to_evaluate`` with ``[None]`` and
    keeps only the full inference object from each returned tuple.
    """
    pending = get_inference_parts_to_evaluate(prompt_version_id, [None])
    return [full_inference for _response, _parts, full_inference in pending]
def _find_unevaluated_parts(inferences, my_feedbacks, parts_to_evaluate):
    """Pair each processed inference with the parts that still lack feedback.

    Returns ``(inference.response, remaining_parts, inference)`` tuples in the
    order of *inferences*. Inferences with no remaining parts are omitted.
    """
    wanted = list(parts_to_evaluate)

    # Index the relevant feedback parts by inference id once, keeping feedback
    # order, instead of rescanning every feedback for every inference.
    parts_given_by_inference: dict = {}
    for fb in my_feedbacks:
        if fb.feedback_for_part in wanted:
            parts_given_by_inference.setdefault(fb.inference_id, []).append(
                fb.feedback_for_part
            )

    inferences_to_eval: List[
        Tuple[
            prompton_types.InferenceResponseData,
            List[str | None],
            prompton_types.InferenceRead,
        ]
    ] = []
    for inf in inferences:
        # Only processed inferences whose response is an
        # InferenceResponseData are offered for evaluation.
        if not (
            inf.status == prompton_types.InferenceResponseStatus.PROCESSED
            and isinstance(inf.response, prompton_types.InferenceResponseData)
        ):
            continue
        remaining = list(wanted)
        for part in parts_given_by_inference.get(inf.id, []):
            # Each feedback removes at most one matching entry; extra
            # feedbacks for an already-removed part are ignored.
            if part in remaining:
                remaining.remove(part)
        if remaining:
            inferences_to_eval.append((inf.response, remaining, inf))
    return inferences_to_eval


@st.cache_data(ttl="1hr")
def _fetch_inference_parts_to_evaluate(
    prompt_version_id: str,
    parts_to_evaluate: List[str | None],
    prompton_user_id,
    prompton_env,
    auth_token: str,
):
    """Cached worker for ``get_inference_parts_to_evaluate``.

    ``prompton_user_id``, ``prompton_env`` and ``auth_token`` are not read
    here. They only make ``st.cache_data`` (shared by every session) key the
    result per user and credentials; the helpers called below read the same
    values from the calling session. They must not start with an underscore,
    or ``st.cache_data`` would leave them out of the cache key.
    """
    inferences = get_inferences(prompt_version_id=prompt_version_id)
    my_feedbacks = get_my_feedbacks(prompt_version_id=prompt_version_id)
    return _find_unevaluated_parts(inferences, my_feedbacks, parts_to_evaluate)


def get_inference_parts_to_evaluate(
    prompt_version_id: str, parts_to_evaluate: List[str | None] | None = None
):
    """Return what the current user still has to evaluate for a prompt version.

    Args:
        prompt_version_id: Prompt version whose inferences are inspected.
        parts_to_evaluate: Values of ``feedback.feedback_for_part`` to check
            for. ``None`` inside the list stands for overall feedback. When
            the argument itself is omitted or ``None``, only overall feedback
            is considered (same as passing ``[None]``).

    Returns:
        A list of ``(inference response, parts not yet evaluated, full
        inference)`` tuples. Only inferences with status PROCESSED and an
        ``InferenceResponseData`` response are included. Results are cached
        for one hour per (prompt version, parts, user, environment, token).
    """
    if parts_to_evaluate is None:
        parts_to_evaluate = [None]
    state = st.session_state
    return _fetch_inference_parts_to_evaluate(
        prompt_version_id,
        list(parts_to_evaluate),
        state["current_user"].id,
        state["prompton_env"],
        state["auth_token"],
    )


# Keep the ``.clear()`` hook this name had when it was itself the cached
# function; it drops the cached entries of every session.
get_inference_parts_to_evaluate.clear = (  # type: ignore[attr-defined]
    lambda *_args, **_kwargs: _fetch_inference_parts_to_evaluate.clear()
)
def format_datetime(dt_str: str) -> str:
"""Formats an iso datetime string to a human readable format in clients local TimeZone."""
dt = datetime.fromisoformat(dt_str).astimezone(get_localzone())
_dt_str = dt.strftime("%d-%b-%y %H:%M:%S")
return _dt_str