Skip to content

Commit ee9a264

Browse files
Merge pull request #9 from pymc-labs/check-for-fileref
enforce the schema
2 parents 7b19c6b + 344766d commit ee9a264

File tree

5 files changed

+93
-34
lines changed

5 files changed

+93
-34
lines changed

app.py

+26
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import json
2+
from jsonschema import validate, ValidationError
23
from flask import Flask, request, jsonify
34
from celery import Celery, Task
45
from kombu import serialization
56

7+
68
import pandas as pd
79
import arviz as az
810
from pymc_marketing.mmm import (
@@ -102,6 +104,16 @@ def __call__(self, *args: object, **kwargs: object) -> object:
102104
# Ensure proper permissions (readable/writable by all users)
103105
os.chmod(DATA_DIR, 0o777)
104106

107+
# Extract the request schema from the OpenAPI spec
108+
def get_mmm_request_schema():
    """Return the JSON schema of the ``/run_mmm_async`` request body.

    The schema is read from ``gpt-agent/api_spec.json`` (a path relative to
    the current working directory) and is the object stored under
    ``paths -> /run_mmm_async -> post -> requestBody -> content ->
    application/json -> schema``.

    Returns:
        The parsed schema object exactly as stored in the spec file.

    Raises:
        Exception: Any error raised while opening, parsing or indexing the
            spec (missing file, invalid JSON, missing key, ...) is logged
            and then re-raised unchanged.
    """
    try:
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open('gpt-agent/api_spec.json', 'r', encoding='utf-8') as f:
            api_spec = json.load(f)
        return api_spec['paths']['/run_mmm_async']['post']['requestBody']['content']['application/json']['schema']
    except Exception as e:
        logging.error("Failed to load API spec: %s", str(e))
        # Bare ``raise`` re-raises the active exception as-is.
        raise
116+
105117

106118
@celery.task(bind=True)
107119
def run_mmm_task(self, data):
@@ -233,6 +245,20 @@ def run_mmm_async():
233245
data = request.get_json()
234246
logging.debug("run_mmm_async request data: %s", data)
235247

248+
try:
249+
schema = get_mmm_request_schema()
250+
validate(instance=data, schema=schema)
251+
except ValidationError as e:
252+
logging.error("Schema validation failed: %s", str(e))
253+
return jsonify({
254+
"error": "Invalid request format",
255+
"details": {
256+
"message": str(e),
257+
"path": " -> ".join(str(p) for p in e.path),
258+
"schema_path": " -> ".join(str(p) for p in e.schema_path)
259+
}
260+
}), 400
261+
236262
task = run_mmm_task.apply_async(args=[data])
237263
logging.info("Task submitted with ID: %s", task.id)
238264

environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ dependencies:
1414
- procps-ng
1515
- yq=2.12.0
1616
- dill=0.3.9
17+
- jsonschema=4.23.0

gpt-agent/api_spec.json

+34-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"info": {
44
"title": "PyMC-Marketing MMM API",
55
"description": "Asynchronous API for running Marketing Mix Modeling.",
6-
"version": "v0.0.3"
6+
"version": "v0.0.4"
77
},
88
"servers": [
99
{
@@ -22,17 +22,43 @@
2222
"application/json": {
2323
"schema": {
2424
"type": "object",
25+
"required": ["openaiFileIdRefs", "date_column", "channel_columns", "y_column"],
2526
"properties": {
2627
"openaiFileIdRefs": {
2728
"type": "array",
2829
"items": {
29-
"type": "string"
30+
"type": "object",
31+
"required": [
32+
"name",
33+
"id",
34+
"mime_type",
35+
"download_link"
36+
],
37+
"properties": {
38+
"name": {
39+
"type": "string",
40+
"description": "Name of the file"
41+
},
42+
"id": {
43+
"type": "string",
44+
"description": "OpenAI file ID"
45+
},
46+
"mime_type": {
47+
"type": "string",
48+
"description": "MIME type of the file"
49+
},
50+
"download_link": {
51+
"type": "string",
52+
"format": "uri",
53+
"description": "URL to download the file"
54+
}
55+
}
3056
},
31-
"description": "List of OpenAI file IDs to be used as references."
57+
"minItems": 1,
58+
"description": "List of OpenAI file references"
3259
},
3360
"date_column": {
3461
"type": "string",
35-
"default": "date",
3662
"description": "Name of the date column in data."
3763
},
3864
"channel_columns": {
@@ -42,6 +68,10 @@
4268
},
4369
"description": "List of channel column names."
4470
},
71+
"y_column": {
72+
"type": "string",
73+
"description": "Name of the y column in data."
74+
},
4575
"adstock_max_lag": {
4676
"type": "integer",
4777
"default": 8,

gpt-agent/gpt_prompt.md

+9-21
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ It leverages the `dev-nextgen-mmm.pymc-labs.com` API to run MMM models and retri
1111

1212
As BayesMMM, your main role is to:
1313

14-
1. Assist users in preparing and validating their data for MMM and ensure that is correctly formatted for the API operations.
14+
1. Assist users in validating their data for MMM and ensure that it is correctly formatted for the API operations.
1515
2. Run the model asynchronously using `runMMMAsync`.
1616
3. Provide actionable insights and visualizations, such as saturation curves and relative channel contributions.
1717
4. Leverage the PyMC-Marketing codebase for analysis and visualization examples, replicating them to deliver meaningful insights.
@@ -20,29 +20,26 @@ Throughout your interactions provide concise responses using bullet points and f
2020

2121
## Running an MMM Analysis
2222

23-
### 1. Data Preparation
23+
### 1. Data Validation
2424

2525
Before starting, ensure the data includes:
2626

2727
- Date: Column with dates in `%Y-%m-%d` format.
2828
- Sales: Column with the target variable (renamed to `sales` if necessary).
2929
- Marketing Spend: Columns representing marketing channel spends (e.g., TV, online).
3030

31-
Handle missing values appropriately and convert the date column to the required format:
32-
33-
```python
34-
# Code example to convert date column to %Y-%m-%d format
35-
data['date_column_name'] = pd.to_datetime(data['date_column_name']).dt.strftime('%Y-%m-%d')
36-
```
37-
3831
**Very Important:**
39-
- Always confirm with the user that the data is correctly formatted before proceeding to initiate the model run.
32+
Validate the data, but do not attempt to fix it. Provide the user with code that they can run to fix the data. Instruct them to reupload the file to the GPT when the data is correctly formatted.
4033

4134
### 2. Initiating the Model Run
4235

4336
When asked to run the Bayesian MMM model you must use the `runMMMAsync` API operation with the correctly formatted data. **Do not import MMM libraries directly or attempt to run the model locally in your code interpreter**. The payload to the API should include the reference to the data file and the following parameters:
4437

45-
- **df**: The data as a CSV string.
38+
- **openaiFileIdRefs**: An array of objects with the following fields:
39+
- **name**: Name of the file.
40+
- **id**: OpenAI file ID.
41+
- **mime_type**: MIME type of the file.
42+
- **download_link**: URL to download the file.
4643
- **date_column**: Name of the date column.
4744
- **channel_columns**: List of channel spend columns.
4845
- **y_column**: Name of the y column.
@@ -96,15 +93,6 @@ The most important parameters are:
9693
* intercept: Intercept parameter
9794
* (optional) gamma_control: Control parameters that multiply the control variables
9895

99-
You can retrieve the return on ad spend from the `return_on_ad_spend` field in the payload returned by `getReturnOnAdSpend`. This is a JSON object with the following fields:
100-
101-
- `channel_columns`: List of channel columns.
102-
- `roas_mean`: Mean of the return on ad spend.
103-
- `roas_hdi_lower`: Lower bound of the 94% confidence interval of the return on ad spend.
104-
- `roas_hdi_upper`: Upper bound of the 94% confidence interval of the return on ad spend.
105-
106-
Plot the return on ad spend using the `roas_mean` and the `roas_hdi_lower` and `roas_hdi_upper` to plot the confidence interval.
107-
10896
### 6. Analysis Workflow
10997

11098
While waiting for results, you can suggest that the user perform exploratory data analysis. Here are some ideas:
@@ -120,6 +108,6 @@ After retrieving results here are some ideas:
120108

121109
- Spend with Saturation: Overlay total spend as a dashed line on the saturation plot.
122110

123-
** Important Reminder **
111+
**Very Important Reminders**
124112

125113
- Throughout your interactions provide **concise responses** using bullet points and formulas when appropriate.

test_mmm_async.py

+23-9
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,25 @@
1111

1212
API_KEY = os.environ.get('API_KEY', None)
1313

14-
def create_payload():
14+
def create_payload(include_file_refs=True):
15+
openaiFileIdRefs = []
16+
if include_file_refs:
17+
openaiFileIdRefs = [
18+
{
19+
"name": "mmm_example.csv",
20+
"id": "file-1234567890",
21+
"mime_type": "text/csv",
22+
"download_link": "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/refs/heads/main/data/mmm_example.csv"
23+
}
24+
]
1525
payload = {
1626
"domain": "dev-nextgen-mmm.pymc-labs.com",
1727
"method": "post",
1828
"path": "/run_mmm_async",
1929
"operation": "runMMMAsync",
2030
"operation_hash": "0c869884cb92378e2dfe2ae377cac236cbc2b9d0",
2131
"is_consequential": True,
22-
"openaiFileIdRefs": [
23-
{
24-
"name": "mmm_example.csv",
25-
"id": "file-1234567890",
26-
"mime_type": "text/csv",
27-
"download_link": "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/refs/heads/main/data/mmm_example.csv"
28-
}
29-
],
32+
"openaiFileIdRefs": openaiFileIdRefs,
3033
"date_column": "date_week",
3134
"channel_columns": [
3235
"x1",
@@ -38,6 +41,16 @@ def create_payload():
3841
}
3942
return payload
4043

44+
def test_missing_file_refs(base_url):
    """Check the response to a payload whose ``openaiFileIdRefs`` list is empty.

    Builds the payload with ``create_payload(include_file_refs=False)``,
    POSTs it as JSON to ``{base_url}/run_mmm_async`` and asserts that the
    server answers with HTTP 400 and the error ``"Invalid request format"``.

    Args:
        base_url: Root URL of the API under test, without a trailing slash.
    """
    body = json.dumps(create_payload(include_file_refs=False))
    request_headers = {'Content-Type': 'application/json', 'X-API-Key': API_KEY}
    response = requests.post(
        f"{base_url}/run_mmm_async",
        data=body,
        headers=request_headers,
    )
    assert response.status_code == 400
    assert response.json()["error"] == "Invalid request format"
4154

4255
def test_async_mmm_run(base_url):
4356
# Payload that includes data
@@ -104,4 +117,5 @@ def test_async_mmm_run(base_url):
104117
print("Invalid argument. Use 'local' or 'deployed-production' or 'deployed-development'.")
105118
sys.exit(1)
106119

120+
test_missing_file_refs(base_url)
107121
test_async_mmm_run(base_url)

0 commit comments

Comments
 (0)