-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
Copy pathapi.py
118 lines (96 loc) · 3.22 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
This file provides common interfaces and utilities used by eval creators to
sample from models and process the results.
"""
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Protocol, Union, runtime_checkable
from evals.prompt.base import OpenAICreateChatPrompt, OpenAICreatePrompt, Prompt
from evals.record import record_match
# Logger for this module; its name is taken from the module's __name__.
logger: logging.Logger = logging.getLogger(__name__)
class CompletionResult(ABC):
    """Abstract holder for the output of a completion call.

    Subclasses must override :meth:`get_completions`; the base class cannot be
    instantiated directly.
    """

    @abstractmethod
    def get_completions(self) -> list[str]:
        """Return the completions held by this result as a list of strings."""
        ...
@runtime_checkable
class CompletionFn(Protocol):
    """Structural interface for callables that sample a completion for a prompt.

    Any object with a compatible ``__call__`` satisfies this protocol. Because it
    is ``runtime_checkable``, ``isinstance(obj, CompletionFn)`` works, but such a
    check only verifies that a ``__call__`` member exists, not its signature.
    """

    def __call__(
        self,
        prompt: Union[str, OpenAICreateChatPrompt],
        **kwargs,
    ) -> CompletionResult:
        """
        ARGS
        ====
        `prompt`: The prompt to complete, as annotated either a raw string or a
            chat-style prompt (`OpenAICreateChatPrompt`). Implementations may
            accept more forms; e.g. `DummyCompletionFn` is also annotated to
            accept `OpenAICreatePrompt` and `Prompt` objects.
        `kwargs`: Other arguments passed to the API.

        RETURNS
        =======
        A single `CompletionResult`; call its `get_completions()` to obtain the
        sampled text(s) as a list of strings.
        """
class DummyCompletionResult(CompletionResult):
    """`CompletionResult` whose only completion is a fixed placeholder string."""

    def get_completions(self) -> list[str]:
        """Build and return a new one-element list holding the placeholder text."""
        placeholder = "This is a dummy response."
        return [placeholder]
class DummyCompletionFn(CompletionFn):
    """Placeholder completion function that never looks at its input."""

    def __call__(
        self,
        prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt],
        **kwargs,
    ) -> CompletionResult:
        """Disregard `prompt` and `kwargs` and hand back a new `DummyCompletionResult`."""
        canned = DummyCompletionResult()
        return canned
def _first_message_content(prompt: Any) -> str:
    """Return a short text preview of ``prompt`` for logging.

    A plain string is returned unchanged. For an indexable prompt whose first
    element is a mapping (e.g. a list of chat-message dicts), that element's
    ``"content"`` value is returned (``""`` if the key is missing). Anything
    else -- empty, non-indexable, or an unexpected element type -- yields ``""``.
    """
    from collections.abc import Mapping

    if isinstance(prompt, str):
        # Indexing a str would yield a single character, which has no .get().
        return prompt
    try:
        first = prompt[0]
    except (TypeError, IndexError, KeyError):
        # Not indexable (e.g. None), empty, or a mapping without key 0.
        return ""
    if isinstance(first, Mapping):
        return first.get("content", "")
    return ""


def _log_match_to_weave(prompt_0_content, sampled, expected, picked, match) -> None:
    """Best-effort: pass one match row through a ``weave`` op named ``row``.

    If ``weave`` cannot be imported, a RuntimeWarning is emitted and nothing
    else happens, so optional telemetry never aborts an eval.
    """
    import warnings

    try:
        import weave
    except ImportError as err:
        warnings.warn(
            f"weave could not be imported ({err}); skipping weave logging of match rows",
            RuntimeWarning,
        )
        return

    # The function body is intentionally empty: it is wrapped in weave.op(),
    # presumably so weave traces the call and its arguments. The function name
    # and parameter names are kept as-is because they likely determine how the
    # row appears in weave -- confirm before renaming.
    @weave.op()
    def row(prompt_0_content, sampled, expected, picked, match):
        return

    row(prompt_0_content, sampled, expected, picked, match)


def record_and_check_match(
    prompt: Any,
    sampled: str,
    expected: Union[str, list[str], tuple[str]],
    separator: Optional[Callable[[str], bool]] = None,
    options: Optional[list[str]] = None,
):
    """
    Records and checks if a sampled response from a CompletionFn matches the expected result.

    An option is "picked" when ``sampled`` starts with it and, if ``separator`` is
    given, the character immediately after the option in ``sampled`` (when there
    is one) satisfies ``separator``. Options are tried in order; the first one that
    qualifies wins. The outcome is recorded via ``record_match`` and then,
    best-effort, logged through ``weave``.

    Args:
        prompt: The input prompt. A raw string or a sequence of chat-message dicts
            gives a meaningful preview in the weave log; other shapes log ``""``.
        sampled: The sampled response from the model.
        expected: The expected response or list/tuple of responses; normalized to a list.
        separator: Optional function to check if a character is a separator.
        options: Optional list of options to match against the sampled response.
            Defaults to ``expected``.

    Returns:
        The matched option or None if no match found.
    """
    if isinstance(expected, tuple):
        expected = list(expected)
    elif not isinstance(expected, list):
        expected = [expected]
    if options is None:
        options = expected

    picked = None
    for option in options:
        if not sampled.startswith(option):
            continue
        if (
            separator is not None
            and len(sampled) > len(option)
            and not separator(sampled[len(option)])
        ):
            # The option is only a prefix of a longer token, so it does not count.
            continue
        picked = option
        break

    match = picked in expected
    record_match(match, expected=expected, picked=picked, sampled=sampled, options=options)

    _log_match_to_weave(_first_message_content(prompt), sampled, expected, picked, match)
    return picked