-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_mmm_async.py
121 lines (104 loc) · 3.69 KB
/
test_mmm_async.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import requests
import json
import time
import pandas as pd
import os
import sys
import io
import dotenv
# Populate os.environ from a .env file (python-dotenv) so the API key does not
# have to be hard-coded in this script.
dotenv.load_dotenv()
# Value sent in the X-API-Key header of every request; None when not configured.
API_KEY = os.environ.get("API_KEY")
def create_payload(include_file_refs=True):
    """Build the JSON body for a ``POST /run_mmm_async`` request.

    The envelope fields (``domain``, ``method``, ``path``, ``operation``,
    ``operation_hash``, ``is_consequential``, ``openaiFileIdRefs``) appear to
    mirror what an OpenAI GPT Action sends -- NOTE(review): confirm against the
    server's request schema. The remaining fields configure the MMM run on the
    pymc-marketing ``mmm_example.csv`` dataset.

    Args:
        include_file_refs: When true, ``openaiFileIdRefs`` holds one reference
            to the example CSV; when false it is an empty list (used to
            exercise the server's validation path).

    Returns:
        A freshly built ``dict`` on every call, so callers may mutate it.
    """
    example_csv_url = (
        "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/"
        "refs/heads/main/data/mmm_example.csv"
    )
    file_refs = (
        [
            dict(
                name="mmm_example.csv",
                id="file-1234567890",
                mime_type="text/csv",
                download_link=example_csv_url,
            )
        ]
        if include_file_refs
        else []
    )
    return dict(
        domain="dev-nextgen-mmm.pymc-labs.com",
        method="post",
        path="/run_mmm_async",
        operation="runMMMAsync",
        operation_hash="0c869884cb92378e2dfe2ae377cac236cbc2b9d0",
        is_consequential=True,
        openaiFileIdRefs=file_refs,
        date_column="date_week",
        channel_columns=["x1", "x2"],
        adstock_max_lag=8,
        yearly_seasonality=2,
        y_column="y",
    )
def test_missing_file_refs(base_url, request_timeout=30):
    """Check that ``/run_mmm_async`` rejects a payload without file references.

    Sends ``create_payload(include_file_refs=False)`` (an empty
    ``openaiFileIdRefs`` list) and expects HTTP 400 with the JSON error
    ``"Invalid request format"``.

    Args:
        base_url: Root URL of the API, e.g. ``"http://localhost:5001"``.
        request_timeout: Seconds to wait for the server. Without a timeout
            ``requests`` may block indefinitely on an unresponsive server.

    Raises:
        AssertionError: If the status code or error message is not as expected.
        requests.RequestException: On connection failure or timeout.
    """
    payload = create_payload(include_file_refs=False)
    run_url = f"{base_url}/run_mmm_async"
    headers = {
        'Content-Type': 'application/json',
        'X-API-Key': API_KEY,
    }
    response = requests.post(
        run_url,
        data=json.dumps(payload),
        headers=headers,
        timeout=request_timeout,
    )
    # Include the response body in failure messages so a mismatch is diagnosable.
    assert response.status_code == 400, (
        f"Expected HTTP 400, got {response.status_code}: {response.text}"
    )
    assert response.json()["error"] == "Invalid request format", response.text
def test_async_mmm_run(base_url, poll_interval=10, max_wait=3600, request_timeout=30):
    """Submit an async MMM run and poll until its summary statistics are ready.

    Posts ``create_payload()`` to ``/run_mmm_async`` and reads the returned
    ``task_id``. It then polls ``/get_summary_statistics`` until the task
    reports ``"completed"`` or ``"failed"``. On completion the ``summary``
    field is parsed with ``pandas.read_json(..., orient='split')`` and printed.

    Args:
        base_url: Root URL of the API, e.g. ``"http://localhost:5001"``.
        poll_interval: Seconds to sleep between polls while the task is not
            finished.
        max_wait: Upper bound in seconds on total polling time, so the test
            fails instead of looping forever if the task never finishes.
        request_timeout: Per-request timeout in seconds passed to ``requests``.

    Raises:
        AssertionError: If any of the following happens:
            the submission is not HTTP 200;
            the task reports ``"failed"``;
            the completed summary is empty;
            ``max_wait`` elapses before the task finishes.
        requests.RequestException: On connection failure or timeout.
    """
    payload = create_payload()
    run_url = f"{base_url}/run_mmm_async"
    headers = {
        'Content-Type': 'application/json',
        'X-API-Key': API_KEY,
    }

    # Initiate the model run.
    response = requests.post(
        run_url,
        data=json.dumps(payload),
        headers=headers,
        timeout=request_timeout,
    )
    assert response.status_code == 200, (
        f"Expected HTTP 200, got {response.status_code}: {response.text}"
    )
    task_id = response.json()["task_id"]
    print(f"Got task_id {task_id}")

    # Poll for results; `params` URL-encodes task_id instead of raw interpolation.
    results_url = f"{base_url}/get_summary_statistics"
    deadline = time.monotonic() + max_wait
    while True:
        result_response = requests.get(
            results_url,
            params={"task_id": task_id},
            headers=headers,
            timeout=request_timeout,
        )
        result_data = result_response.json()
        status = result_data.get("status")

        if status == "completed":
            summary = pd.read_json(io.StringIO(result_data["summary"]), orient='split')
            print('--------------------------------')
            print(summary)
            print('--------------------------------')
            assert not summary.empty, "Task completed but returned an empty summary"
            print("Task completed:!!!")
            return

        if status == "failed":
            # A failed run must fail the test; printing alone would let it pass.
            raise AssertionError(f"Task failed: {result_data}")

        # "pending" or any other non-terminal status: always sleep before the next
        # poll so an unexpected status cannot turn this into a tight request loop.
        if time.monotonic() >= deadline:
            raise AssertionError(
                f"Task {task_id} did not finish within {max_wait}s; "
                f"last response: {result_data}"
            )
        if status == "pending":
            print("Pending...")
        else:
            print(f"Status {status!r}, waiting...")
        time.sleep(poll_interval)
if __name__ == "__main__":
    # Each accepted CLI argument mapped to the API root it targets.
    environments = {
        "local": "http://localhost:5001",
        "deployed-production": "https://nextgen-mmm.pymc-labs.com",
        "deployed-development": "https://dev-nextgen-mmm.pymc-labs.com",
    }
    # Build the usage line from the real script name and the real choices so it
    # cannot drift from what is accepted below.
    usage = (
        f"Usage: python {os.path.basename(sys.argv[0])} "
        f"[{'|'.join(environments)}]"
    )
    if len(sys.argv) < 2:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(usage)

    environment = sys.argv[1]
    base_url = environments.get(environment)
    if base_url is None:
        sys.exit(
            "Invalid argument. Use 'local' or 'deployed-production' "
            "or 'deployed-development'."
        )

    test_missing_file_refs(base_url)
    test_async_mmm_run(base_url)