-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support uses of BACK that cause correlated references: fix remaining decorrelation edge cases #254
Changes from all commits
deeb914
5d6c513
170492e
8b8c098
cb83f1f
ef39e6e
3c79543
3db13d1
8dc17b3
044c217
0252cee
98c722b
aa86d31
b03a5ec
b569bfb
7607fea
215735b
0b8e1fb
a5190f8
a0d43f3
24c1e4a
8009e05
5d87b59
a3ee536
9945752
4c6aa79
e60dcc3
a23c395
2515683
d34c64c
cd59dc2
92778e3
6003ed0
fb86e3a
ca6b653
22dfb32
e550ae4
f7af1a9
3e0b62c
6a66c5f
19cfd47
acba2c6
3f0b703
fb25054
eb81179
0f5df29
9752b92
223c1b5
7b142eb
947d405
df7e1b0
3f167a4
09f2bfe
30536e8
261f869
d5b315e
3794707
c04f34d
39ce2ac
28065c0
d88192f
fed91e5
7d9492d
2ef58e3
b904f3a
7f69536
fb19842
d1ffdb5
91e2a74
8b9db1e
2c29a3d
e8b0cdb
a5d8c1f
836adca
9c72887
49f46d8
81d130e
60fea67
9980a2c
4c54d57
7a3fd30
b519cd0
4033936
c2bbb59
8fcc609
8a8e49f
4eeadd9
9427cce
397f40d
71ca0e2
719eec5
2289aaf
52c2af5
19b3628
600616e
b6c9b85
afd241b
ceff105
64311bc
26b21a2
8d6878f
06dd69f
6949cdd
7afe8ad
8184a7e
318eb1b
45fea5b
7bb353f
0cf4c11
3a1516e
746eed5
4818c59
17116d4
9ab9bed
a8f6535
cf8cc50
7556b5e
37ebdb4
8ff5157
9191ec4
9bc3564
67a342e
bb24b30
8f83903
b096443
e9f2606
9772f61
ec93596
0a00cea
d1e1520
fd37df1
a75bba5
8d75668
2cfe999
322c722
1e44262
4203c6c
5b6e9b3
b480c3b
c75b9fd
5a55ac5
ad61d27
07d771d
ba095ec
e82c9a1
639a56b
01e6d97
eab09e9
7ef0476
1246ba5
1bd9b07
c83b67e
3fbcc1f
fdbf909
ec77c0d
a5531e3
4becac1
c4af826
f44ba0e
11ac86d
39da637
8b74502
32074d9
2105bc8
d6ba5a1
50bad40
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,7 @@ class Decorrelater: | |
|
||
def make_decorrelate_parent( | ||
self, hybrid: HybridTree, child_idx: int, required_steps: int | ||
) -> HybridTree: | ||
) -> tuple[HybridTree, int]: | ||
""" | ||
Creates a snapshot of the ancestry of the hybrid tree that contains | ||
a correlated child, without any of its children, its descendants, or | ||
|
@@ -50,21 +50,27 @@ def make_decorrelate_parent( | |
derivable. | ||
|
||
Returns: | ||
A snapshot of `hybrid` and its ancestry in the hybrid tree, without | ||
without any of its children or pipeline operators that occur during | ||
or after the derivation of the correlated child, or without any of | ||
its descendants. | ||
A tuple where the first entry is a snapshot of `hybrid` and its | ||
ancestry in the hybrid tree, without any of its children or | ||
pipeline operators that occur during or after the derivation of the | ||
correlated child, or without any of its descendants. The second | ||
entry is the number of ancestor layers that should be skipped due | ||
to the PARTITION edge case. | ||
""" | ||
if isinstance(hybrid.pipeline[0], HybridPartition) and child_idx == 0: | ||
# Special case: if the correlated child is the data argument of a | ||
# partition operation, then the parent to snapshot is actually the | ||
# parent of the level containing the partition operation. In this | ||
# case, all of the parent's children & pipeline operators should be | ||
# included in the snapshot. | ||
assert hybrid.parent is not None | ||
return self.make_decorrelate_parent( | ||
if hybrid.parent is None: | ||
raise ValueError( | ||
"Malformed hybrid tree: partition data input to a partition node cannot contain a correlated reference to the partition node." | ||
) | ||
result = self.make_decorrelate_parent( | ||
hybrid.parent, len(hybrid.parent.children), len(hybrid.pipeline) | ||
) | ||
return result[0], result[1] + 1 | ||
# Temporarily detach the successor of the current level, then create a | ||
# deep copy of the current level (which will include its ancestors), | ||
# then reattach the successor back to the original. This ensures that | ||
|
@@ -78,7 +84,7 @@ def make_decorrelate_parent( | |
# that it has to. | ||
new_hybrid._children = new_hybrid._children[:child_idx] | ||
new_hybrid._pipeline = new_hybrid._pipeline[: required_steps + 1] | ||
return new_hybrid | ||
return new_hybrid, 0 | ||
|
||
def remove_correl_refs( | ||
self, expr: HybridExpr, parent: HybridTree, child_height: int | ||
|
@@ -221,6 +227,7 @@ def decorrelate_child( | |
new_parent: HybridTree, | ||
child: HybridConnection, | ||
is_aggregate: bool, | ||
skipped_levels: int, | ||
) -> None: | ||
""" | ||
Runs the logic to de-correlate a child of a hybrid tree that contains | ||
|
@@ -230,6 +237,17 @@ def decorrelate_child( | |
child can now replace correlated references with BACK references that | ||
point to terms in its newly expanded ancestry, and the original hybrid | ||
tree can now join onto this child using its uniqueness keys. | ||
|
||
Args: | ||
`old_parent`: The correlated ancestor hybrid tree that the correlated | ||
references should point to when they are targeted for removal. | ||
`new_parent`: The ancestor of `level` that removal should stop at. | ||
`child`: The child of the hybrid tree that contains the correlated | ||
nodes to be removed. | ||
`is_aggregate`: Whether the child is being aggregated with regard | ||
to its parent. | ||
`skipped_levels`: The number of ancestor layers that should be | ||
ignored when deriving backshifts of join/agg keys. | ||
""" | ||
# First, find the height of the child subtree & its top-most level. | ||
child_root: HybridTree = child.subtree | ||
|
@@ -245,28 +263,41 @@ def decorrelate_child( | |
new_join_keys: list[tuple[HybridExpr, HybridExpr]] = [] | ||
additional_levels: int = 0 | ||
current_level: HybridTree | None = old_parent | ||
new_agg_keys: list[HybridExpr] = [] | ||
while current_level is not None: | ||
for unique_key in current_level.pipeline[0].unique_exprs: | ||
skip_join: bool = ( | ||
isinstance(current_level.pipeline[0], HybridPartition) | ||
and child is current_level.children[0] | ||
) | ||
for unique_key in sorted(current_level.pipeline[0].unique_exprs, key=str): | ||
lhs_key: HybridExpr | None = unique_key.shift_back(additional_levels) | ||
rhs_key: HybridExpr | None = unique_key.shift_back( | ||
additional_levels + child_height | ||
additional_levels + child_height - skipped_levels | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This "skipped_levels" comes into play in TPCH query 22: the PARTITION should not be counted in the height, because it is a level that gets skipped by the parent snapshotting function. |
||
) | ||
assert lhs_key is not None and rhs_key is not None | ||
new_join_keys.append((lhs_key, rhs_key)) | ||
if not skip_join: | ||
new_join_keys.append((lhs_key, rhs_key)) | ||
new_agg_keys.append(rhs_key) | ||
current_level = current_level.parent | ||
additional_levels += 1 | ||
child.subtree.join_keys = new_join_keys | ||
# If aggregating, do the same with the aggregation keys. | ||
# If aggregating, update the aggregation keys accordingly. | ||
if is_aggregate: | ||
new_agg_keys: list[HybridExpr] = [] | ||
assert child.subtree.join_keys is not None | ||
for _, rhs_key in child.subtree.join_keys: | ||
new_agg_keys.append(rhs_key) | ||
child.subtree.agg_keys = new_agg_keys | ||
|
||
def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: | ||
""" | ||
TODO | ||
The recursive procedure to remove unwanted correlated references from | ||
the entire hybrid tree, called from the bottom and working upwards | ||
to the top layer, and having each layer also de-correlate its children. | ||
|
||
Args: | ||
`hybrid`: The hybrid tree to remove correlated references from. | ||
|
||
Returns: | ||
The hybrid tree with all invalid correlated references removed as the | ||
tree structure is re-written to allow them to be replaced with BACK | ||
references. The transformation is also done in-place. | ||
""" | ||
# Recursively decorrelate the ancestors of the current level of the | ||
# hybrid tree. | ||
|
@@ -282,18 +313,22 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: | |
for idx, child in enumerate(hybrid.children): | ||
if idx not in hybrid.correlated_children: | ||
continue | ||
new_parent: HybridTree = self.make_decorrelate_parent( | ||
hybrid, idx, hybrid.children[idx].required_steps | ||
) | ||
match child.connection_type: | ||
case ( | ||
ConnectionType.SINGULAR | ||
| ConnectionType.SINGULAR_ONLY_MATCH | ||
| ConnectionType.AGGREGATION | ||
| ConnectionType.AGGREGATION_ONLY_MATCH | ||
): | ||
new_parent, skipped_levels = self.make_decorrelate_parent( | ||
hybrid, idx, hybrid.children[idx].required_steps | ||
) | ||
self.decorrelate_child( | ||
hybrid, new_parent, child, child.connection_type.is_aggregation | ||
hybrid, | ||
new_parent, | ||
child, | ||
child.connection_type.is_aggregation, | ||
skipped_levels, | ||
) | ||
case ConnectionType.NDISTINCT | ConnectionType.NDISTINCT_ONLY_MATCH: | ||
raise NotImplementedError( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -494,17 +494,26 @@ def __init__( | |
for name, expr in predecessor.terms.items(): | ||
terms[name] = HybridRefExpr(name, expr.typ) | ||
renamings.update(predecessor.renamings) | ||
new_renamings: dict[str, str] = {} | ||
for name, expr in new_expressions.items(): | ||
if name in terms and terms[name] == expr: | ||
continue | ||
expr = expr.apply_renamings(predecessor.renamings) | ||
used_name: str = name | ||
idx: int = 0 | ||
while used_name in terms or used_name in renamings: | ||
while ( | ||
used_name in terms | ||
or used_name in renamings | ||
or used_name in new_renamings | ||
): | ||
used_name = f"{name}_{idx}" | ||
idx += 1 | ||
terms[used_name] = expr | ||
renamings[name] = used_name | ||
new_renamings[name] = used_name | ||
renamings.update(new_renamings) | ||
for old_name, new_name in new_renamings.items(): | ||
expr = new_expressions.pop(old_name) | ||
new_expressions[new_name] = expr | ||
Comment on lines
+504
to
+516
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change allows us to ensure that the terms in |
||
super().__init__(terms, renamings, orderings, predecessor.unique_exprs) | ||
self.calc = Calc | ||
self.new_expressions = new_expressions | ||
|
@@ -520,7 +529,10 @@ class HybridFilter(HybridOperation): | |
|
||
def __init__(self, predecessor: HybridOperation, condition: HybridExpr): | ||
super().__init__( | ||
predecessor.terms, {}, predecessor.orderings, predecessor.unique_exprs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use of |
||
predecessor.terms, | ||
predecessor.renamings, | ||
predecessor.orderings, | ||
predecessor.unique_exprs, | ||
) | ||
self.predecessor: HybridOperation = predecessor | ||
self.condition: HybridExpr = condition | ||
|
@@ -566,7 +578,10 @@ def __init__( | |
records_to_keep: int, | ||
): | ||
super().__init__( | ||
predecessor.terms, {}, predecessor.orderings, predecessor.unique_exprs | ||
predecessor.terms, | ||
predecessor.renamings, | ||
predecessor.orderings, | ||
predecessor.unique_exprs, | ||
) | ||
self.predecessor: HybridOperation = predecessor | ||
self.records_to_keep: int = records_to_keep | ||
|
@@ -908,13 +923,13 @@ def __repr__(self): | |
lines.append(" -> ".join(repr(operation) for operation in self.pipeline)) | ||
prefix = " " if self.successor is None else "↓" | ||
for idx, child in enumerate(self.children): | ||
lines.append(f"{prefix} child #{idx}:") | ||
lines.append(f"{prefix} child #{idx} ({child.connection_type.name}):") | ||
if child.subtree.agg_keys is not None: | ||
lines.append( | ||
f"{prefix} aggregate: {child.subtree.agg_keys} -> {child.aggs}:" | ||
) | ||
lines.append(f"{prefix} aggregate: {child.subtree.agg_keys}") | ||
if len(child.aggs): | ||
lines.append(f"{prefix} aggs: {child.aggs}:") | ||
if child.subtree.join_keys is not None: | ||
lines.append(f"{prefix} join: {child.subtree.join_keys}:") | ||
lines.append(f"{prefix} join: {child.subtree.join_keys}") | ||
for line in repr(child.subtree).splitlines(): | ||
lines.append(f"{prefix} {line}") | ||
return "\n".join(lines) | ||
|
@@ -1964,6 +1979,24 @@ def make_hybrid_tree( | |
rhs_expr.name, 0, rhs_expr.typ | ||
) | ||
join_key_exprs.append((lhs_expr, rhs_expr)) | ||
|
||
case PartitionBy(): | ||
partition = HybridPartition() | ||
successor_hybrid = HybridTree(partition) | ||
self.populate_children( | ||
successor_hybrid, node.child_access, child_ref_mapping | ||
) | ||
partition_child_idx = child_ref_mapping[0] | ||
for key_name in node.calc_terms: | ||
key = node.get_expr(key_name) | ||
expr = self.make_hybrid_expr( | ||
successor_hybrid, key, child_ref_mapping, False | ||
) | ||
partition.add_key(key_name, expr) | ||
key_exprs.append(HybridRefExpr(key_name, expr.typ)) | ||
successor_hybrid.children[ | ||
partition_child_idx | ||
].subtree.agg_keys = key_exprs | ||
case _: | ||
raise NotImplementedError( | ||
f"{node.__class__.__name__} (child is {node.child_access.__class__.__name__})" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -885,11 +885,26 @@ def rel_translation( | |
if isinstance(operation.collection, TableCollection): | ||
result = self.build_simple_table_scan(operation) | ||
if context is not None: | ||
# If the collection access is the child of something | ||
# else, join it onto that something else. Use the | ||
# uniqueness keys of the ancestor, which should also be | ||
# present in the collection (e.g. joining a partition | ||
# onto the original data using the partition keys). | ||
assert preceding_hybrid is not None | ||
join_keys: list[tuple[HybridExpr, HybridExpr]] = [] | ||
for unique_column in sorted( | ||
preceding_hybrid[0].pipeline[0].unique_exprs, key=str | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This case arises from the |
||
): | ||
if unique_column not in result.expressions: | ||
raise ValueError( | ||
f"Cannot connect parent context to child {operation.collection} because {unique_column} is not in the child's expressions." | ||
) | ||
join_keys.append((unique_column, unique_column)) | ||
result = self.join_outputs( | ||
context, | ||
result, | ||
JoinType.INNER, | ||
[], | ||
join_keys, | ||
None, | ||
) | ||
else: | ||
|
@@ -913,7 +928,8 @@ def rel_translation( | |
assert context is not None, "Malformed HybridTree pattern." | ||
result = self.translate_filter(operation, context) | ||
case HybridPartition(): | ||
assert context is not None, "Malformed HybridTree pattern." | ||
if context is None: | ||
context = TranslationOutput(EmptySingleton(), {}) | ||
result = self.translate_partition( | ||
operation, context, hybrid, pipeline_idx | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,8 @@ | |
|
||
from .abstract_node import RelationalNode | ||
from .aggregate import Aggregate | ||
from .join import Join | ||
from .empty_singleton import EmptySingleton | ||
from .join import Join, JoinType | ||
from .project import Project | ||
from .relational_expression_dispatcher import RelationalExpressionDispatcher | ||
from .relational_root import RelationalRoot | ||
|
@@ -141,7 +142,23 @@ def _prune_node_columns( | |
|
||
# Determine the new node. | ||
output = new_node.copy(inputs=new_inputs) | ||
return self._prune_identity_project(output), correl_refs | ||
output = self._prune_identity_project(output) | ||
# Special case: replace empty aggregation with VALUES () if possible. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These special cases can arise because the decorrelation can cause the LHS of the join to have 100% of its columns unused, which is problematic. |
||
if ( | ||
isinstance(output, Aggregate) | ||
and len(output.keys) == 0 | ||
and len(output.aggregations) == 0 | ||
): | ||
return EmptySingleton(), correl_refs | ||
# Special case: replace join where LHS is VALUES () with the RHS if | ||
# possible. | ||
if ( | ||
isinstance(output, Join) | ||
and isinstance(output.inputs[0], EmptySingleton) | ||
and output.join_types in ([JoinType.INNER], [JoinType.LEFT]) | ||
): | ||
return output.inputs[1], correl_refs | ||
return output, correl_refs | ||
|
||
def prune_unused_columns(self, root: RelationalRoot) -> RelationalRoot: | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -431,6 +431,7 @@ def visit_filter(self, filter: Filter) -> None: | |
# TODO: (gh #151) Refactor a simpler way to check dependent expressions. | ||
if ( | ||
"group" in input_expr.args | ||
or "distinct" in input_expr.args | ||
or "where" in input_expr.args | ||
or "qualify" in input_expr.args | ||
or "order" in input_expr.args | ||
|
@@ -462,6 +463,7 @@ def visit_aggregate(self, aggregate: Aggregate) -> None: | |
query: Select | ||
if ( | ||
"group" in input_expr.args | ||
or "distinct" in input_expr.args | ||
or "qualify" in input_expr.args | ||
or "order" in input_expr.args | ||
or "limit" in input_expr.args | ||
|
@@ -472,7 +474,10 @@ def visit_aggregate(self, aggregate: Aggregate) -> None: | |
select_cols, input_expr, find_identifiers_in_list(select_cols) | ||
) | ||
if keys: | ||
query = query.group_by(*keys) | ||
if aggregations: | ||
query = query.group_by(*keys) | ||
else: | ||
query = query.distinct() | ||
Comment on lines
+477
to
+480
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is so if we have an aggregate on keys |
||
self._stack.append(query) | ||
|
||
def visit_limit(self, limit: Limit) -> None: | ||
|
@@ -511,7 +516,7 @@ def visit_limit(self, limit: Limit) -> None: | |
self._stack.append(query) | ||
|
||
def visit_empty_singleton(self, singleton: EmptySingleton) -> None: | ||
self._stack.append(Select().from_(values([()]))) | ||
self._stack.append(Select().select(SQLGlotStar()).from_(values([()]))) | ||
Comment on lines
-514
to
+519
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Without this change, you can end up with |
||
|
||
def visit_root(self, root: RelationalRoot) -> None: | ||
self.visit_inputs(root) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changing this to keep track of how many times we had to recursively step upward, because those levels should not be counted later.