Skip to content

Commit 6f021e8

Browse files
authored
Create MVP AI console (#934)
1 parent 4d9f803 commit 6f021e8

37 files changed

+2342
-11
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ __pycache__
66

77
*.log
88

9+
# This is a git submodule
10+
/ai/data/
11+
912
# Do not save generated files
1013
/ai/ft/outputs/
1114
/ai/outputs/

.gitmodules

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
[submodule "third_party/angular_components"]
22
path = third_party/angular_components
33
url = https://github.com/angular/components.git
4+
5+
[submodule "ai/data"]
6+
path = ai/data
7+
url = git@hf.co:datasets/wwwillchen/mesop-data

ai/README.md

+18
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,24 @@ All the commands should be run from the `ai/` directory.
99
- All entry-points are in `src/*.py` - this includes the AI service and scripts.
1010
- `src/common` contains code that's shared between offline scripts and the online service.
1111

12+
## AI Console
13+
14+
**Setup**:
15+
16+
```sh
17+
git clone git@hf.co:datasets/wwwillchen/mesop-data data
18+
```
19+
20+
**Running**:
21+
22+
Inside `ai/src/`, run the following command:
23+
24+
```sh
25+
mesop console.py --port=32124
26+
```
27+
28+
> Note: you can run this on a separate port to avoid conflicting with the main Mesop development app.
29+
1230
## Scripts
1331

1432
These are scripts used to generate and process data for offline evaluation.

ai/src/ai/common/diff.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import re
from typing import NamedTuple

# Sentinel appended to the line the user wants edited; it is stripped back
# out of both sides of a patch block before matching.
EDIT_HERE_MARKER = " # <--- EDIT HERE"

# One conflict-marker-style patch block: ORIGINAL section, then UPDATED section.
_DIFF_PATTERN = r"<<<<<<< ORIGINAL(.*?)=======\n(.*?)>>>>>>> UPDATED"


class ApplyPatchResult(NamedTuple):
  """Outcome of applying a patch.

  When has_error is True, `result` holds a user-facing error message
  rather than code.
  """

  has_error: bool
  result: str


def apply_patch(original_code: str, patch: str) -> ApplyPatchResult:
  """Apply conflict-marker-style patch blocks in `patch` to `original_code`.

  Each ORIGINAL section must occur verbatim in the progressively patched
  code; its first occurrence is replaced by the matching UPDATED section.
  Returns an error result if no patch blocks are found or a block fails
  to apply.
  """
  blocks = re.findall(_DIFF_PATTERN, patch, re.DOTALL)
  if not blocks:
    print("[WARN] No diff found:", patch)
    return ApplyPatchResult(
      True,
      "[AI-001] Sorry! AI output was mis-formatted. Please try again.",
    )

  code = original_code
  for original_section, updated_section in blocks:
    search_text = original_section.strip().replace(EDIT_HERE_MARKER, "")
    replacement = updated_section.strip().replace(EDIT_HERE_MARKER, "")
    next_code = code.replace(search_text, replacement, 1)
    if next_code == code:
      # The ORIGINAL section matched nothing (or the replacement was a
      # no-op), so this patch cannot be applied.
      return ApplyPatchResult(
        True,
        "[AI-002] Sorry! AI output could not be used. Please try again.",
      )
    code = next_code
  return ApplyPatchResult(False, code)

ai/src/ai/common/entity_store.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import os
2+
from typing import Generic, TypeVar
3+
4+
from pydantic import BaseModel
5+
6+
T = TypeVar("T", bound=BaseModel)
7+
8+
9+
def get_data_path(dirname: str) -> str:
  """Return the path of `dirname` inside the repo's ai/data directory.

  Resolved relative to this module's location (ai/src/ai/common/),
  three levels up and then into data/.
  """
  base_dir = os.path.dirname(__file__)
  return os.path.join(base_dir, "..", "..", "..", "data", dirname)
13+
14+
15+
class EntityStore(Generic[T]):
  """Flat-file store: one `<id>.json` per entity under ai/data/<dirname>/.

  Entities are pydantic models serialized with model_dump_json and
  deserialized with model_validate_json.
  """

  def __init__(self, entity_type: type[T], *, dirname: str):
    self.entity_type = entity_type
    self.directory_path = get_data_path(dirname)

  def _path_for(self, entity_id: str) -> str:
    # All entities live directly in the store directory as <id>.json.
    return os.path.join(self.directory_path, f"{entity_id}.json")

  def get(self, id: str) -> T:
    """Load a single entity by id."""
    with open(self._path_for(id)) as f:
      return self.entity_type.model_validate_json(f.read())

  def get_all(self) -> list[T]:
    """Load every entity, newest id first."""
    results: list[T] = []
    for name in os.listdir(self.directory_path):
      if not name.endswith(".json"):
        continue
      with open(os.path.join(self.directory_path, name)) as f:
        results.append(self.entity_type.model_validate_json(f.read()))
    results.sort(key=lambda e: e.id, reverse=True)
    return results

  def save(self, entity: T, overwrite: bool = False):
    """Write an entity to disk.

    Raises:
      ValueError: if an entity with the same id exists and overwrite is False.
    """
    id = entity.id  # type: ignore
    target = self._path_for(id)
    if not overwrite and os.path.exists(target):
      raise ValueError(
        f"{self.entity_type.__name__} with id {id} already exists"
      )
    with open(target, "w") as f:
      f.write(entity.model_dump_json(indent=4))

  def delete(self, entity_id: str):
    """Remove the entity's JSON file."""
    os.remove(self._path_for(entity_id))

ai/src/ai/common/example.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""
2+
An example is a single input/output pair.
3+
- Examples are used for fine-tuning a model (i.e. golden example) or running an eval (i.e. expected example).
4+
- There are two types of examples:
5+
- **Golden Example**: A golden example is an example that is used to create a golden dataset.
6+
- **Expected Example**: An expected example is an example that is used to evaluate a producer.
7+
Internally, once an expected example has been run through an eval, we create an **evaluated example**, but you don't need to create this manually in the UI.
8+
"""
9+
10+
import os
11+
import shutil
12+
from typing import Generic, Literal, TypeVar
13+
14+
from pydantic import BaseModel
15+
16+
17+
class ExampleInput(BaseModel):
  """What a producer is asked to do: a prompt plus optional existing code."""

  prompt: str  # natural-language description of the desired app/change
  input_code: str | None = None  # existing code to edit; None when generating from scratch
  line_number_target: int | None = None  # 1-indexed line the edit should focus on, if any


class BaseExample(BaseModel):
  """Common shape shared by golden and expected examples."""

  id: str  # unique id; also used as the on-disk directory name
  input: ExampleInput


class ExampleOutput(BaseModel):
  """A producer's output for one example."""

  output_code: str | None = None  # final code (after applying a diff, if any)
  raw_output: str | None = None  # unprocessed model response
  output_type: Literal["full", "diff"] = "diff"  # whether the model emits full code or a patch


class ExpectedExample(BaseExample):
  """An example used to evaluate a producer."""

  expect_executable: bool = True  # eval should verify the output code runs
  expect_type_checkable: bool = True  # eval should verify the output code type-checks


class ExpectResult(BaseModel):
  """Outcome of a single expectation check during an eval."""

  name: Literal["executable", "type_checkable", "patchable"]
  score: int  # 0 or 1
  message: str | None = None  # failure details, when the check did not pass


class EvaluatedExampleOutput(BaseModel):
  """One eval run's output plus its expectation results and run stats."""

  time_spent_secs: float  # wall-clock time spent producing this output
  tokens: int
  output: ExampleOutput
  expect_results: list[ExpectResult]


class EvaluatedExample(BaseModel):
  """An expected example together with the outputs produced by eval runs."""

  expected: ExpectedExample
  outputs: list[EvaluatedExampleOutput]


class GoldenExample(BaseExample):
  """An example with a known-good output, used for golden datasets."""

  output: ExampleOutput


# Any example type storable by ExampleStore.
T = TypeVar("T", bound=BaseExample)
62+
63+
64+
class ExampleStore(Generic[T]):
  """Directory-per-example store under ai/data/<dirname>/.

  Each example is a directory named after its id containing:
  - example_input.json: the serialized model, with code fields nulled out
  - input.py: the input code, when present
  - output.py / raw_output.txt: golden-example outputs, when present
  """

  def __init__(self, entity_type: type[T], *, dirname: str):
    self.entity_type = entity_type
    self.directory_path = os.path.join(
      os.path.dirname(__file__), "..", "..", "..", "data", dirname
    )

  def get(self, id: str) -> T:
    """Load an example by id, re-attaching code stored in sidecar files."""
    dir_path = os.path.join(self.directory_path, id)
    json_path = os.path.join(dir_path, "example_input.json")
    with open(json_path) as f:
      entity = self.entity_type.model_validate_json(f.read())
    input_py_path = os.path.join(dir_path, "input.py")
    if os.path.exists(input_py_path):
      with open(input_py_path) as f:
        entity.input.input_code = f.read()
    if isinstance(entity, GoldenExample):
      output_py_path = os.path.join(dir_path, "output.py")
      if os.path.exists(output_py_path):
        with open(output_py_path) as f:
          entity.output.output_code = f.read()
      raw_output_path = os.path.join(dir_path, "raw_output.txt")
      if os.path.exists(raw_output_path):
        with open(raw_output_path) as f:
          entity.output.raw_output = f.read()
    return entity

  def get_all(self) -> list[T]:
    """Load every example in the store's directory."""
    entities: list[T] = []
    for filename in os.listdir(self.directory_path):
      # Skip stray files (e.g. .DS_Store); each example is a directory.
      if os.path.isdir(os.path.join(self.directory_path, filename)):
        entities.append(self.get(filename))
    return entities

  def save(self, entity: T, overwrite: bool = False):
    """Persist an example, splitting code fields into sidecar files.

    NOTE: mutates `entity` — code fields are set to None after being written
    to sidecar files so they are not duplicated inside the JSON.

    Raises:
      ValueError: if the example already exists and overwrite is False.
    """
    id = entity.id
    dir_path = os.path.join(self.directory_path, id)
    if not overwrite and os.path.exists(dir_path):
      raise ValueError(
        f"{self.entity_type.__name__} with id {id} already exists"
      )
    # Fix: previously the directory was only created on the not-overwrite
    # path, so save(..., overwrite=True) crashed when the directory was
    # missing. makedirs(exist_ok=True) covers both cases.
    os.makedirs(dir_path, exist_ok=True)
    json_path = os.path.join(dir_path, "example_input.json")
    input_code = entity.input.input_code
    if input_code:
      with open(os.path.join(dir_path, "input.py"), "w") as f:
        f.write(input_code)
      entity.input.input_code = None

    if isinstance(entity, GoldenExample):
      # Guard against None so a partially-populated output doesn't raise a
      # TypeError from f.write(None).
      if entity.output.output_code is not None:
        with open(os.path.join(dir_path, "output.py"), "w") as f:
          f.write(entity.output.output_code)
      if entity.output.raw_output is not None:
        with open(os.path.join(dir_path, "raw_output.txt"), "w") as f:
          f.write(entity.output.raw_output)
      entity.output.output_code = None
      entity.output.raw_output = None
    with open(json_path, "w") as f:
      f.write(entity.model_dump_json(indent=4))

  def delete(self, entity_id: str):
    """Remove the example's entire directory tree."""
    shutil.rmtree(os.path.join(self.directory_path, entity_id))
132+
133+
134+
# Module-level singleton stores; both persist under ai/data/.
expected_example_store = ExampleStore(
  ExpectedExample, dirname="expected_examples"
)
golden_example_store = ExampleStore(GoldenExample, dirname="golden_examples")

ai/src/ai/common/executor.py

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from os import getenv
2+
from typing import Iterator
3+
4+
from openai import OpenAI
5+
from openai.types.chat import (
6+
ChatCompletionMessageParam,
7+
)
8+
9+
from ai.common.diff import EDIT_HERE_MARKER, ApplyPatchResult, apply_patch
10+
from ai.common.entity_store import get_data_path
11+
from ai.common.example import ExampleInput
12+
from ai.common.model import model_store
13+
from ai.common.producer import producer_store
14+
from ai.common.prompt_context import prompt_context_store
15+
from ai.common.prompt_fragment import PromptFragment, prompt_fragment_store
16+
17+
18+
class ProviderExecutor:
  """Base executor: turns an ExampleInput into chat messages for a model.

  Subclasses implement `execute` / `execute_stream` against a concrete
  provider API.
  """

  def __init__(self, model_name: str, prompt_fragments: list[PromptFragment]):
    self.model_name = model_name
    # Resolve file-backed fragments up front so content_path is read once
    # and every fragment carries an inline content_value.
    resolved: list[PromptFragment] = []
    for fragment in prompt_fragments:
      resolved.append(
        PromptFragment(
          id=fragment.id,
          role=fragment.role,
          chain_of_thought=fragment.chain_of_thought,
          content_value=get_content_value(fragment),
          content_path=None,
        )
      )
    self.prompt_fragments = resolved

  def format_messages(
    self, input: ExampleInput
  ) -> list[ChatCompletionMessageParam]:
    """Substitute the app code and prompt into each fragment's template."""
    code = input.input_code or ""
    target = input.line_number_target
    # Append the sentinel marker to the targeted line (1-indexed), when valid.
    if target is not None:
      lines = code.splitlines()
      if 1 <= target <= len(lines):
        lines[target - 1] += EDIT_HERE_MARKER
        code = "\n".join(lines)

    messages: list[ChatCompletionMessageParam] = []
    for fragment in self.prompt_fragments:
      content = fragment.content_value.replace("<APP_CODE>", code).replace(  # type: ignore
        "<APP_CHANGES>", input.prompt
      )
      messages.append({"role": fragment.role, "content": content})
    return messages

  def execute(self, input: ExampleInput) -> str: ...

  def execute_stream(self, input: ExampleInput) -> Iterator[str]: ...
57+
58+
59+
class OpenaiExecutor(ProviderExecutor):
  """ProviderExecutor backed by the OpenAI chat completions API."""

  def __init__(self, model_name: str, prompt_fragments: list[PromptFragment]):
    super().__init__(model_name, prompt_fragments)
    # API key comes from the environment; OpenAI() would read it anyway,
    # but being explicit keeps the dependency visible.
    self.client = OpenAI(api_key=getenv("OPENAI_API_KEY"))

  def execute(self, input: ExampleInput) -> str:
    """Run a single blocking completion and return the message text."""
    completion = self.client.chat.completions.create(
      model=self.model_name,
      max_tokens=10_000,
      messages=self.format_messages(input),
    )
    message = completion.choices[0].message
    return message.content or ""

  def execute_stream(self, input: ExampleInput) -> Iterator[str]:
    """Run a streaming completion, yielding each content delta as it arrives."""
    response = self.client.chat.completions.create(
      model=self.model_name,
      max_tokens=10_000,
      messages=self.format_messages(input),
      stream=True,
    )
    for event in response:
      delta = event.choices[0].delta.content
      yield delta or ""
84+
85+
86+
# Registry mapping a model's `provider` string to the executor class that
# knows how to call that provider's API.
provider_executors: dict[str, type[ProviderExecutor]] = {
  "openai": OpenaiExecutor,
}
89+
90+
91+
class ProducerExecutor:
  """Runs a stored producer: resolves its prompt/model config, executes the
  provider, and post-processes the raw output."""

  def __init__(self, producer_id: str):
    self.producer = producer_store.get(producer_id)

  def get_provider_executor(self) -> ProviderExecutor:
    """Build a ProviderExecutor from the producer's prompt context and model.

    Raises:
      ValueError: if the model's provider has no registered executor.
    """
    context = prompt_context_store.get(self.producer.prompt_context_id)
    fragments = [
      prompt_fragment_store.get(fragment_id)
      for fragment_id in context.fragment_ids
    ]
    model = model_store.get(self.producer.mesop_model_id)
    executor_cls = provider_executors.get(model.provider)
    if executor_cls is None:
      raise ValueError(f"Provider {model.provider} not supported")
    return executor_cls(model.name, fragments)

  def execute(self, input: ExampleInput):
    return self.get_provider_executor().execute(input)

  def execute_stream(self, input: ExampleInput):
    return self.get_provider_executor().execute_stream(input)

  def transform_output(self, input_code: str, output: str):
    """Convert raw model output into final code, applying a patch if needed."""
    output_format = self.producer.output_format
    if output_format == "diff":
      return apply_patch(input_code, output)
    if output_format == "full":
      # NOTE(review): has_error=True on a successful "full" output looks
      # suspicious — confirm whether callers treat this flag as "no patch
      # was applied" rather than as an error.
      return ApplyPatchResult(True, output)
    raise ValueError(f"Unknown output format: {output_format}")
120+
121+
122+
def get_content_value(pf: PromptFragment) -> str | None:
  """Return a fragment's inline content, reading it from disk when only a
  content_path is set; None when the fragment has neither."""
  if pf.content_value is not None:
    return pf.content_value
  path = pf.content_path
  if path is None:
    return None
  # Paths are stored with a leading "//" repo-root prefix; strip it before
  # resolving against the data directory.
  with open(get_data_path(path.replace("//", ""))) as f:
    return f.read()

0 commit comments

Comments
 (0)