Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 40e24e0

Browse files
authored
Merge pull request #533 from datafold/compare-bigquery-arrays-and-structs
Compare JSON, ARRAY, STRUCT types in BigQuery (simplistically)
2 parents d198647 + 7c8d058 commit 40e24e0

File tree

6 files changed

+119
-12
lines changed

6 files changed

+119
-12
lines changed

data_diff/sqeleton/abcs/database_types.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
DbTime = datetime
1414

1515

16+
@dataclass
1617
class ColType:
1718
supported = True
1819

@@ -140,6 +141,21 @@ class JSON(ColType):
140141
pass
141142

142143

144+
@dataclass
145+
class Array(ColType):
146+
item_type: ColType
147+
148+
149+
# Unlike JSON, structs are not free-form and have a very specific set of fields and their types.
150+
# We do not parse & use those fields now, but we can do this later.
151+
# For example, in BigQuery:
152+
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#struct_type
153+
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#struct_literals
154+
@dataclass
155+
class Struct(ColType):
156+
pass
157+
158+
143159
@dataclass
144160
class Integer(NumericType, IKey):
145161
precision: int = 0
@@ -227,6 +243,10 @@ def parse_type(
227243
) -> ColType:
228244
"Parse type info as returned by the database"
229245

246+
@abstractmethod
247+
def to_comparable(self, value: str, coltype: ColType) -> str:
248+
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
249+
230250

231251
from typing import TypeVar, Generic
232252

data_diff/sqeleton/abcs/mixins.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from .database_types import TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID, JSON
2+
from .database_types import Array, TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID, JSON, Struct
33
from .compiler import Compilable
44

55

@@ -8,6 +8,11 @@ class AbstractMixin(ABC):
88

99

1010
class AbstractMixin_NormalizeValue(AbstractMixin):
11+
12+
@abstractmethod
13+
def to_comparable(self, value: str, coltype: ColType) -> str:
14+
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
15+
1116
@abstractmethod
1217
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
1318
"""Creates an SQL expression, that converts 'value' to a normalized timestamp.
@@ -51,7 +56,15 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
5156

5257
def normalize_json(self, value: str, _coltype: JSON) -> str:
5358
"""Creates an SQL expression, that converts 'value' to its minified json string representation."""
54-
raise NotImplementedError()
59+
return self.to_string(value)
60+
61+
def normalize_array(self, value: str, _coltype: Array) -> str:
62+
"""Creates an SQL expression, that serializes an array into a JSON string."""
63+
return self.to_string(value)
64+
65+
def normalize_struct(self, value: str, _coltype: Struct) -> str:
66+
"""Creates an SQL expression, that serializes a typed struct into a JSON string."""
67+
return self.to_string(value)
5568

5669
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
5770
"""Creates an SQL expression, that converts 'value' to a normalized representation.
@@ -79,6 +92,10 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
7992
return self.normalize_boolean(value, coltype)
8093
elif isinstance(coltype, JSON):
8194
return self.normalize_json(value, coltype)
95+
elif isinstance(coltype, Array):
96+
return self.normalize_array(value, coltype)
97+
elif isinstance(coltype, Struct):
98+
return self.normalize_struct(value, coltype)
8299
return self.to_string(value)
83100

84101

data_diff/sqeleton/databases/base.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from ..queries.ast_classes import Random
1818
from ..abcs.database_types import (
1919
AbstractDatabase,
20-
T_Dialect,
20+
Array,
21+
Struct,
2122
AbstractDialect,
2223
AbstractTable,
2324
ColType,
@@ -165,6 +166,10 @@ def concat(self, items: List[str]) -> str:
165166
joined_exprs = ", ".join(items)
166167
return f"concat({joined_exprs})"
167168

169+
def to_comparable(self, value: str, coltype: ColType) -> str:
170+
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
171+
return value
172+
168173
def is_distinct_from(self, a: str, b: str) -> str:
169174
return f"{a} is distinct from {b}"
170175

@@ -229,7 +234,7 @@ def parse_type(
229234
""" """
230235

231236
cls = self._parse_type_repr(type_repr)
232-
if not cls:
237+
if cls is None:
233238
return UnknownColType(type_repr)
234239

235240
if issubclass(cls, TemporalType):
@@ -257,10 +262,7 @@ def parse_type(
257262
)
258263
)
259264

260-
elif issubclass(cls, (Text, Native_UUID)):
261-
return cls()
262-
263-
elif issubclass(cls, JSON):
265+
elif issubclass(cls, (JSON, Array, Struct, Text, Native_UUID)):
264266
return cls()
265267

266268
raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")

data_diff/sqeleton/databases/bigquery.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
from typing import List, Union
1+
import re
2+
from typing import Any, List, Union
23
from ..abcs.database_types import (
4+
ColType,
5+
Array,
6+
JSON,
7+
Struct,
38
Timestamp,
49
Datetime,
510
Integer,
@@ -10,6 +15,7 @@
1015
FractionalType,
1116
TemporalType,
1217
Boolean,
18+
UnknownColType,
1319
)
1420
from ..abcs.mixins import (
1521
AbstractMixin_MD5,
@@ -36,6 +42,7 @@ def md5_as_int(self, s: str) -> str:
3642

3743

3844
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
45+
3946
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
4047
if coltype.rounds:
4148
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
@@ -57,6 +64,27 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
5764
def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
5865
return self.to_string(f"cast({value} as int)")
5966

67+
def normalize_json(self, value: str, _coltype: JSON) -> str:
68+
# BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.:
69+
# Got error: 400 Grouping is not defined for arguments of type ARRAY<INT64> at …
70+
# So we make a best effort and compare them as strings, hoping that the JSON forms
71+
# match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc.
72+
return f"to_json_string({value})"
73+
74+
def normalize_array(self, value: str, _coltype: Array) -> str:
75+
# BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.:
76+
# Got error: 400 Grouping is not defined for arguments of type ARRAY<INT64> at …
77+
# So we make a best effort and compare them as strings, hoping that the JSON forms
78+
# match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc.
79+
return f"to_json_string({value})"
80+
81+
def normalize_struct(self, value: str, _coltype: Struct) -> str:
82+
# BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.:
83+
# Got error: 400 Grouping is not defined for arguments of type ARRAY<INT64> at …
84+
# So we make a best effort and compare them as strings, hoping that the JSON forms
85+
# match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc.
86+
return f"to_json_string({value})"
87+
6088

6189
class Mixin_Schema(AbstractMixin_Schema):
6290
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
@@ -112,11 +140,12 @@ class Dialect(BaseDialect, Mixin_Schema):
112140
"BIGNUMERIC": Decimal,
113141
"FLOAT64": Float,
114142
"FLOAT32": Float,
115-
# Text
116143
"STRING": Text,
117-
# Boolean
118144
"BOOL": Boolean,
145+
"JSON": JSON,
119146
}
147+
TYPE_ARRAY_RE = re.compile(r'ARRAY<(.+)>')
148+
TYPE_STRUCT_RE = re.compile(r'STRUCT<(.+)>')
120149
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}
121150

122151
def random(self) -> str:
@@ -134,6 +163,40 @@ def type_repr(self, t) -> str:
134163
except KeyError:
135164
return super().type_repr(t)
136165

166+
def parse_type(
167+
self,
168+
table_path: DbPath,
169+
col_name: str,
170+
type_repr: str,
171+
*args: Any, # pass-through args
172+
**kwargs: Any, # pass-through args
173+
) -> ColType:
174+
col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs)
175+
if isinstance(col_type, UnknownColType):
176+
177+
m = self.TYPE_ARRAY_RE.fullmatch(type_repr)
178+
if m:
179+
item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs)
180+
col_type = Array(item_type=item_type)
181+
182+
# We currently ignore structs' structure, but later can parse it too. Examples:
183+
# - STRUCT<INT64, STRING(10)> (unnamed)
184+
# - STRUCT<foo INT64, bar STRING(10)> (named)
185+
# - STRUCT<foo INT64, bar ARRAY<INT64>> (with complex fields)
186+
# - STRUCT<foo INT64, bar STRUCT<a INT64, b INT64>> (nested)
187+
m = self.TYPE_STRUCT_RE.fullmatch(type_repr)
188+
if m:
189+
col_type = Struct()
190+
191+
return col_type
192+
193+
def to_comparable(self, value: str, coltype: ColType) -> str:
194+
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
195+
if isinstance(coltype, (JSON, Array, Struct)):
196+
return self.normalize_value_by_type(value, coltype)
197+
else:
198+
return super().to_comparable(value, coltype)
199+
137200
def set_timezone_to_utc(self) -> str:
138201
raise NotImplementedError()
139202

data_diff/sqeleton/queries/ast_classes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,9 @@ class IsDistinctFrom(ExprNode, LazyOps):
352352
type = bool
353353

354354
def compile(self, c: Compiler) -> str:
355-
return c.dialect.is_distinct_from(c.compile(self.a), c.compile(self.b))
355+
a = c.dialect.to_comparable(c.compile(self.a), self.a.type)
356+
b = c.dialect.to_comparable(c.compile(self.b), self.b.type)
357+
return c.dialect.is_distinct_from(a, b)
356358

357359

358360
@dataclass(eq=False, order=False)

tests/sqeleton/test_query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def concat(self, l: List[str]) -> str:
2626
s = ", ".join(l)
2727
return f"concat({s})"
2828

29+
def to_comparable(self, s: str) -> str:
30+
return s
31+
2932
def to_string(self, s: str) -> str:
3033
return f"cast({s} as varchar)"
3134

0 commit comments

Comments
 (0)