-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathhelpers.py
223 lines (163 loc) · 6.34 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# stdlib
import base64
import json
import urllib.parse
from typing import List
# third party
import pyarrow as pa
import streamlit as st
# first party
from chart import create_chart
from client import ConnAttr
from schema import Query
def keys_exist_in_dict(keys_list, dct):
    """Return True when every key in *keys_list* is present in *dct*."""
    for required_key in keys_list:
        if required_key not in dct:
            return False
    return True
def get_shared_elements(all_elements: List[List]):
    """Return the elements common to every sub-list of *all_elements*.

    Parameters
    ----------
    all_elements : List[List]
        Lists of hashable items to intersect.

    Returns
    -------
    list
        The intersection, in arbitrary (set) order; [] for empty input.

    Note: the original wrapped the intersection in ``try/except IndexError``,
    but that branch was unreachable — slicing a non-empty list with ``[1:]``
    never raises, and ``set.intersection()`` with no arguments returns the
    base set — so the dead handler is removed.
    """
    if not all_elements:
        return []
    first, *rest = all_elements
    return list(set(first).intersection(*rest))
def to_arrow_table(byte_string: str, to_pandas: bool = True) -> pa.Table:
    """Decode a base64-encoded Arrow IPC stream into a table.

    Parameters
    ----------
    byte_string : str
        Base64-encoded Arrow IPC stream (as returned by the Semantic Layer API).
    to_pandas : bool
        When True (default) convert the Arrow table to a pandas DataFrame.
    """
    raw_bytes = base64.b64decode(byte_string)
    with pa.ipc.open_stream(raw_bytes) as reader:
        table = pa.Table.from_batches(reader, reader.schema)
    return table.to_pandas() if to_pandas else table
def create_graphql_code(query: Query) -> str:
    """Render a copy-pastable `requests` snippet that replays *query* against
    the GraphQL API of the active connection."""
    host = st.session_state.conn.host
    snippet_lines = [
        "",
        "import requests",
        f"url = '{host}/api/graphql'",
        f"query = '''{query.gql}'''",
        f"payload = {{'query': query, 'variables': {query.variables}}}",
        "response = requests.post(url, json=payload, headers={'Authorization': 'Bearer ***'})",
        "",
    ]
    return "\n".join(snippet_lines)
def create_python_sdk_code(query: Query) -> str:
    """Render a copy-pastable dbtsl (Python SDK) snippet equivalent to *query*.

    Only truthy SDK arguments are emitted, one per line, indented to sit
    inside the generated ``client.query(...)`` call.
    """
    # NOTE(review): the exact in-snippet indentation was mangled in the copy
    # this was reviewed from — confirm the rendered snippet against the app.
    kwargs_block = ",\n".join(
        f"            {name}={value}" for name, value in query.sdk.items() if value
    )
    environment_id = st.session_state.conn.params['environmentid']
    host = st.session_state.conn.host.replace('https://', '')
    return f"""
from dbtsl import SemanticLayerClient
client = SemanticLayerClient(
    environment_id={environment_id},
    auth_token="<your-semantic-layer-api-token>",
    host="{host}",
)
def main():
    with client.session():
        table = client.query(\n{kwargs_block}\n        )
        print(table)
main()
"""
def convert_df(df, to="to_csv", index=False):
    """Serialize *df* with the named DataFrame export method and UTF-8 encode it.

    Parameters
    ----------
    df : DataFrame-like object exposing the export method named by *to*.
    to : str
        Name of the export method to call (e.g. "to_csv").
    index : bool
        Forwarded to the export method; include the index when True.
    """
    export = getattr(df, to)
    return export(index=index).encode("utf8")
def create_explorer_link(query):
    """Render a dbt Explorer page link for the query's metrics, but only when
    an account id is available in session state."""
    if "account_id" not in st.session_state:
        return
    _, link_col = st.columns([4, 1.5])
    explorer_url = url_for_explorer(query.metric_names)
    link_col.page_link(explorer_url, label="View from dbt Explorer", icon="🕵️")
def create_tabs(state: st.session_state, suffix: str) -> None:
    """Render Chart / Data / SQL tabs for the results stored under *suffix*.

    Expects ``query_<suffix>``, ``df_<suffix>`` and ``compiled_sql_<suffix>``
    in *state*; renders nothing when any of them is missing.
    """
    required = [f"{name}_{suffix}" for name in ("query", "df", "compiled_sql")]
    if not all(key in state for key in required):
        return
    sql = getattr(state, f"compiled_sql_{suffix}")
    df = getattr(state, f"df_{suffix}")
    query = getattr(state, f"query_{suffix}")
    chart_tab, data_tab, sql_tab = st.tabs(["Chart", "Data", "SQL"])
    with chart_tab:
        create_chart(df, query, suffix)
    with data_tab:
        st.dataframe(df, use_container_width=True)
    with sql_tab:
        st.code(sql, language="sql")
        # NOTE(review): original indentation was lost; the explorer link is
        # assumed to render inside the SQL tab — confirm against the app.
        create_explorer_link(query)
def encode_dictionary(d):
    """Serialize *d* to JSON and return it as a Base64 string.

    Inverse of decode_string; used to stash context in a URL query param.
    """
    json_bytes = json.dumps(d).encode("utf-8")
    return base64.b64encode(json_bytes).decode("utf-8")
def decode_string(s: str):
    """Decode a Base64-encoded JSON string back into a Python object.

    Returns None when *s* has no ``encode`` method (e.g. it is None, the
    missing-query-param case); JSON/Base64 errors propagate unchanged.
    """
    try:
        base64_bytes = s.encode("utf-8")
    except AttributeError:
        # Not a string — treat the value as absent.
        return None
    json_text = base64.b64decode(base64_bytes).decode("utf-8")
    return json.loads(json_text)
def set_context_query_param(params: List[str]):
    """Persist the named session-state values into the URL ``context`` param
    as a Base64-encoded JSON blob (missing keys are skipped)."""
    present = {name: st.session_state[name] for name in params if name in st.session_state}
    st.query_params = {"context": encode_dictionary(present)}
def retrieve_context_query_param():
    """Read back the dict stored by set_context_query_param.

    Returns None when no ``context`` query param is present.
    """
    encoded = st.query_params.get("context", None)
    return decode_string(encoded)
def get_access_url(conn: ConnAttr = None):
    """Derive the dbt Cloud access URL by stripping the ``semantic-layer``
    marker from the connection host.

    Falls back to the active connection in session state when *conn* is None.
    """
    if conn is None:
        conn = st.session_state.conn
    host = conn.host
    if host.endswith(".semantic-layer"):
        return host.replace(".semantic-layer", "")
    if host.startswith("https://semantic-layer."):
        return host.replace("https://semantic-layer.", "https://")
    return host.replace(".semantic-layer.", ".")
def url_for_disco(conn: ConnAttr = None):
    """Build the metadata (Discovery API) base URL for the connection host."""
    access_url = get_access_url(conn)
    netloc = urllib.parse.urlparse(access_url).netloc
    labels = netloc.split(".")
    # Multi-cell hosts (us1/us2) keep the account prefix ahead of `metadata.`.
    if "us1" in labels or "us2" in labels:
        account_prefix = labels[0]
        remainder = ".".join(labels[1:])
        return f"https://{account_prefix}.metadata.{remainder}"
    return f"https://metadata.{netloc}"
def url_for_explorer(metrics: List[str], *, conn: ConnAttr = None):
    """Build a dbt Explorer lineage URL selecting the given metrics.

    Returns None (implicitly) unless both ``account_id`` and ``project_id``
    are present in session state.
    """
    state = st.session_state
    if "account_id" not in state or "project_id" not in state:
        return
    select_expr = " ".join(f"+metric:{name}" for name in metrics)
    encoded_select = urllib.parse.quote_plus(select_expr)
    host = get_access_url(conn)
    return (
        f"{host}/explore/{state.account_id}/projects/{state.project_id}"
        f"/environments/production/lineage/?select={encoded_select}"
    )
def construct_cli_command(query: Query):
    """Translate *query* into the equivalent ``dbt sl query`` CLI command.

    Emits ``--metrics`` always, and ``--group-by`` / ``--where`` / ``--limit``
    / ``--order-by`` only when the query carries them.  Order-by columns are
    prefixed with ``-`` when descending, and group-by columns gain a
    ``__<grain>`` suffix when a time grain is set.
    """
    metrics = ",".join(query.metric_names)
    command = f"dbt sl query --metrics {metrics}"
    group_by = ",".join(query.dimension_names)
    if group_by:
        command += f" --group-by {group_by}"
    if query.where:
        where_str = " AND ".join(w.sql for w in query.where)
        command += f' --where "{where_str}"'
    if query.limit:
        command += f" --limit {query.limit}"
    if query.orderBy:
        order_by_inputs = []
        for order_by_input in query.orderBy:
            if order_by_input.metric:
                col = order_by_input.metric.name
            else:
                col = order_by_input.groupBy.name
                # Grain only applies to group-by orderings; checking it on the
                # metric branch would dereference a missing groupBy.
                if order_by_input.groupBy.grain:
                    col += f"__{order_by_input.groupBy.grain}"
            if order_by_input.descending:
                col = f"-{col}"
            order_by_inputs.append(col)
        order_by = ",".join(order_by_inputs)
        command += f" --order-by {order_by}"
    return command