forked from cancan101/airtable-db-api
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdialect.py
120 lines (94 loc) · 3.63 KB
/
dialect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from __future__ import annotations
import urllib.parse
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
)
from shillelagh.backends.apsw.dialects.base import APSWDialect
from .types import BaseMetadata
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from sqlalchemy.engine.url import URL
# -----------------------------------------------------------------------------
# Name of the single adapter this dialect works with: it is passed in the
# `adapters` list given to the base dialect and is the key under which the
# adapter's keyword arguments are stored in `adapter_kwargs`.
ADAPTER_NAME = "airtable"
# -----------------------------------------------------------------------------
def extract_query_host(
    url: URL,
) -> Tuple[Dict[str, Union[str, Sequence[str]]], Optional[str]]:
    """
    Extract the query parameters and the real host from the SQLAlchemy URL.

    Returns a ``(query, host)`` tuple:

    - ``query`` maps each parameter name to its value. When the query string
      had to be recovered from the host (see below), a parameter given once
      maps to a string and a repeated parameter (``?tables=a&tables=b``) maps
      to a tuple of all its values, in order. When ``url.query`` is populated
      it is copied as-is.
    - ``host`` is ``url.host`` with any embedded query string removed.
    """
    if url.query:
        return dict(url.query), url.host

    # there's a bug in how SQLAlchemy <1.4 handles URLs without trailing / in hosts,
    # putting the query string in the host; handle that case here
    if url.host and "?" in url.host:
        real_host, query_str = url.host.split("?", 1)
        # Group repeated keys rather than letting the last occurrence win, so
        # multi-valued parameters such as ``tables`` survive this code path.
        query: Dict[str, Union[str, Sequence[str]]] = {
            key: values[0] if len(values) == 1 else tuple(values)
            for key, values in urllib.parse.parse_qs(query_str).items()
        }
        return query, real_host

    return {}, url.host
# -----------------------------------------------------------------------------
class APSWAirtableDialect(APSWDialect):
    """
    Shillelagh APSW dialect restricted to the ``airtable`` adapter.

    The URL host is forwarded to the adapter as ``base_id`` and the URL
    password (or the ``airtable_api_key`` argument) as ``api_key``.
    """

    name = "airtable"
    supports_statement_cache = True

    def __init__(
        self,
        airtable_api_key: Optional[str] = None,
        base_metadata: Optional[BaseMetadata] = None,
        # Ick:
        date_columns: Optional[Dict[str, Collection[str]]] = None,
        **kwargs: Any,
    ):
        """
        Initialise the base dialect and remember the Airtable options.

        ``airtable_api_key``, ``base_metadata`` and ``date_columns`` are stored
        on the instance and handed to the adapter by ``create_connect_args``.
        Remaining keyword arguments go to the parent constructor.
        """
        # Only one adapter is registered with Shillelagh for this dialect.
        super().__init__(safe=True, adapters=[ADAPTER_NAME], **kwargs)

        self.date_columns = date_columns
        self.base_metadata = base_metadata
        self.airtable_api_key = airtable_api_key

    def get_table_names(
        self, connection: Connection, schema: Optional[str] = None, **kwargs: Any
    ) -> List[str]:
        """
        List the table names available on ``connection``.

        The ``tables`` query parameter of the engine URL wins when present
        (one value or several). Otherwise the ``name`` of every entry in
        ``base_metadata`` is returned, or an empty list when no metadata was
        supplied. ``schema`` and ``kwargs`` are not used.
        """
        query, _host = extract_query_host(connection.engine.url)
        requested = query.get("tables")

        if requested is None:
            if self.base_metadata is None:
                return []
            return [meta["name"] for meta in self.base_metadata.values()]

        # A lone value arrives as a plain string; wrap it instead of
        # splitting it into characters.
        return [requested] if isinstance(requested, str) else list(requested)

    def create_connect_args(
        self,
        url: URL,
    ) -> Tuple[Tuple[()], Dict[str, Any]]:
        """
        Build the positional and keyword arguments used to open a connection.

        Starts from the parent's result, then forces ``path`` to ``:memory:``
        and replaces ``adapter_kwargs`` with the Airtable adapter settings.

        Raises ``ValueError`` when the parent already produced non-empty
        ``adapter_kwargs``, when an API key is given both as the URL password
        and as ``airtable_api_key``, or (via ``int``) when ``peek_rows`` is
        not an integer literal.
        """
        args, kwargs = super().create_connect_args(url)

        inherited = kwargs.get("adapter_kwargs", {})
        if inherited != {}:
            raise ValueError(
                f"Unexpected adapter_kwargs found: {kwargs['adapter_kwargs']}"
            )

        if url.password and self.airtable_api_key:
            raise ValueError("Both password and airtable_api_key were provided")

        url_query, url_host = extract_query_host(url)

        peek_rows: Optional[int]
        if "peek_rows" in url_query:
            raw = url_query["peek_rows"]
            # With a repeated parameter the last occurrence is used.
            peek_rows = int(raw if isinstance(raw, str) else raw[-1])
        else:
            peek_rows = None

        airtable_options = {
            "api_key": self.airtable_api_key or url.password,
            "base_id": url_host,
            "base_metadata": self.base_metadata,
            "peek_rows": peek_rows,
            "date_columns": self.date_columns,
        }

        # The original author flagged the `path` override as a wart: it is
        # unclear why the in-memory path has to be set at this point.
        connect_kwargs = dict(kwargs)
        connect_kwargs["path"] = ":memory:"
        connect_kwargs["adapter_kwargs"] = {ADAPTER_NAME: airtable_options}

        # `args` is passed through untouched (currently expected to be empty).
        return args, connect_kwargs