@@ -109,15 +109,37 @@ impl<T: RelNodeTyp> Memo<T> {
109
109
ExprId ( id)
110
110
}
111
111
112
- fn merge_group ( & mut self , group_a : ReducedGroupId , group_b : ReducedGroupId ) -> ReducedGroupId {
112
+ fn merge_group_inner (
113
+ & mut self ,
114
+ group_a : ReducedGroupId ,
115
+ group_b : ReducedGroupId ,
116
+ ) -> ReducedGroupId {
113
117
if group_a == group_b {
114
118
return group_a;
115
119
}
116
120
self . merged_groups
117
121
. insert ( group_a. as_group_id ( ) , group_b. as_group_id ( ) ) ;
122
+
123
+ // Copy all expressions from group a to group b
124
+ let group_a_exprs = self . get_all_exprs_in_group ( group_a. as_group_id ( ) ) ;
125
+ for expr_id in group_a_exprs {
126
+ let expr_node = self . expr_id_to_expr_node . get ( & expr_id) . unwrap ( ) ;
127
+ self . add_expr_to_group ( expr_id, group_b, expr_node. as_ref ( ) . clone ( ) ) ;
128
+ }
129
+
130
+ // Remove all expressions from group a (so we don't accidentally access it)
131
+ self . clear_exprs_in_group ( group_a) ;
132
+
118
133
group_b
119
134
}
120
135
136
+ pub fn merge_group ( & mut self , group_a : GroupId , group_b : GroupId ) -> GroupId {
137
+ let group_a_reduced = self . get_reduced_group_id ( group_a) ;
138
+ let group_b_reduced = self . get_reduced_group_id ( group_b) ;
139
+ self . merge_group_inner ( group_a_reduced, group_b_reduced)
140
+ . as_group_id ( )
141
+ }
142
+
121
143
fn get_group_id_of_expr_id ( & self , expr_id : ExprId ) -> GroupId {
122
144
self . expr_id_to_group_id [ & expr_id]
123
145
}
@@ -136,9 +158,11 @@ impl<T: RelNodeTyp> Memo<T> {
136
158
rel_node : RelNodeRef < T > ,
137
159
add_to_group_id : Option < GroupId > ,
138
160
) -> ( GroupId , ExprId ) {
139
- if rel_node. typ . extract_group ( ) . is_some ( ) {
140
- unreachable ! ( ) ;
141
- }
161
+ let node_current_group = rel_node. typ . extract_group ( ) ;
162
+ if let ( Some ( grp_a) , Some ( grp_b) ) = ( add_to_group_id, node_current_group) {
163
+ self . merge_group ( grp_a, grp_b) ;
164
+ } ;
165
+
142
166
let ( group_id, expr_id) = self . add_new_group_expr_inner (
143
167
rel_node,
144
168
add_to_group_id. map ( |x| self . get_reduced_group_id ( x) ) ,
@@ -198,6 +222,10 @@ impl<T: RelNodeTyp> Memo<T> {
198
222
props
199
223
}
200
224
225
+ fn clear_exprs_in_group ( & mut self , group_id : ReducedGroupId ) {
226
+ self . groups . remove ( & group_id) ;
227
+ }
228
+
201
229
fn add_expr_to_group (
202
230
& mut self ,
203
231
expr_id : ExprId ,
@@ -243,7 +271,7 @@ impl<T: RelNodeTyp> Memo<T> {
243
271
let group_id = self . get_group_id_of_expr_id ( expr_id) ;
244
272
let group_id = self . get_reduced_group_id ( group_id) ;
245
273
if let Some ( add_to_group_id) = add_to_group_id {
246
- self . merge_group ( add_to_group_id, group_id) ;
274
+ self . merge_group_inner ( add_to_group_id, group_id) ;
247
275
}
248
276
return ( group_id, expr_id) ;
249
277
}
0 commit comments