Skip to content

Commit 6e1427d

Browse files
committed
enforcement of the schema and clear error if schema is violated
1 parent ae9c9a0 commit 6e1427d

File tree

4 files changed

+32
-10
lines changed

4 files changed

+32
-10
lines changed

app.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import json
2+
from jsonschema import validate, ValidationError
23
from flask import Flask, request, jsonify
34
from celery import Celery, Task
45
from kombu import serialization
56

7+
68
import pandas as pd
79
import arviz as az
810
from pymc_marketing.mmm import (
@@ -102,6 +104,16 @@ def __call__(self, *args: object, **kwargs: object) -> object:
102104
# Ensure proper permissions (readable/writable by all users)
103105
os.chmod(DATA_DIR, 0o777)
104106

107+
# Extract the request schema from the OpenAPI spec
108+
def get_mmm_request_schema():
109+
try:
110+
with open('gpt-agent/api_spec.json', 'r') as f:
111+
api_spec = json.load(f)
112+
return api_spec['paths']['/run_mmm_async']['post']['requestBody']['content']['application/json']['schema']
113+
except Exception as e:
114+
logging.error("Failed to load API spec: %s", str(e))
115+
raise e
116+
105117

106118
@celery.task(bind=True)
107119
def run_mmm_task(self, data):
@@ -233,12 +245,19 @@ def run_mmm_async():
233245
data = request.get_json()
234246
logging.debug("run_mmm_async request data: %s", data)
235247

236-
logging.info("checking that the data has file_refs: %s", data)
237-
if ("openaiFileIdRefs" not in data) or (len(data["openaiFileIdRefs"]) == 0): # TODO: do a more thorough schema check here
238-
logging.error("Data does not have openaiFileIdRefs")
239-
return jsonify({"error": "Request must include openaiFileIdRefs"}), 400
240-
else:
241-
logging.info("Data has openaiFileIdRefs")
248+
try:
249+
schema = get_mmm_request_schema()
250+
validate(instance=data, schema=schema)
251+
except ValidationError as e:
252+
logging.error("Schema validation failed: %s", str(e))
253+
return jsonify({
254+
"error": "Invalid request format",
255+
"details": {
256+
"message": str(e),
257+
"path": " -> ".join(str(p) for p in e.path),
258+
"schema_path": " -> ".join(str(p) for p in e.schema_path)
259+
}
260+
}), 400
242261

243262
task = run_mmm_task.apply_async(args=[data])
244263
logging.info("Task submitted with ID: %s", task.id)

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ dependencies:
1414
- procps-ng
1515
- yq=2.12.0
1616
- dill=0.3.9
17+
- jsonschema=4.23.0

gpt-agent/api_spec.json

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"application/json": {
2323
"schema": {
2424
"type": "object",
25-
"required": ["openaiFileIdRefs"],
25+
"required": ["openaiFileIdRefs", "date_column", "channel_columns", "y_column"],
2626
"properties": {
2727
"openaiFileIdRefs": {
2828
"type": "array",
@@ -59,7 +59,6 @@
5959
},
6060
"date_column": {
6161
"type": "string",
62-
"default": "date",
6362
"description": "Name of the date column in data."
6463
},
6564
"channel_columns": {
@@ -69,6 +68,10 @@
6968
},
7069
"description": "List of channel column names."
7170
},
71+
"y_column": {
72+
"type": "string",
73+
"description": "Name of the y column in data."
74+
},
7275
"adstock_max_lag": {
7376
"type": "integer",
7477
"default": 8,

test_mmm_async.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ def test_missing_file_refs(base_url):
5050
}
5151
response = requests.post(run_url, data=json.dumps(payload), headers=headers)
5252
assert response.status_code == 400
53-
assert response.json()["error"] == "Request must include openaiFileIdRefs"
54-
53+
assert response.json()["error"] == "Invalid request format"
5554

5655
def test_async_mmm_run(base_url):
5756
# Payload that includes data

0 commit comments

Comments
 (0)