@@ -101,6 +101,33 @@ def merge_includes(hash1, hash2)
101
101
end
102
102
end
103
103
104
+ def assert_klass_has_instance_method ( klass , instance_method )
105
+ klass . instance_method ( instance_method )
106
+ rescue NameError => err
107
+ msg = "#{ klass } is missing the method our prepended code is expecting to patch. Was the undefined method removed or renamed upstream?\n See: #{ __FILE__ } .\n The NameError was: #{ err } . "
108
+ raise NameError , msg
109
+ end
110
+
111
+ if ActiveRecord . version >= Gem ::Version . new ( 7.0 ) # Rails 7.0 expected methods to patch
112
+ %w[
113
+ grouped_records
114
+ ] . each { |method | assert_klass_has_instance_method ( ActiveRecord ::Associations ::Preloader ::Branch , method ) }
115
+ elsif ActiveRecord . version >= Gem ::Version . new ( 6.1 ) # Rails 6.1 methods to patch
116
+ %w[
117
+ preloaders_for_reflection
118
+ preloaders_for_hash
119
+ preloaders_for_one
120
+ grouped_records
121
+ ] . each { |method | assert_klass_has_instance_method ( ActiveRecord ::Associations ::Preloader , method ) }
122
+ end
123
+
124
+ # Expected methods to patch on any version
125
+ %w[
126
+ build_select
127
+ arel_column
128
+ construct_join_dependency
129
+ ] . each { |method | assert_klass_has_instance_method ( ActiveRecord ::Relation , method ) }
130
+
104
131
module ActiveRecord
105
132
class Base
106
133
include ActiveRecord ::VirtualAttributes ::VirtualFields
@@ -178,34 +205,42 @@ def grouped_records(orig_association, records, polymorphic_parent)
178
205
end
179
206
# rubocop:enable Style/BlockDelimiters, Lint/AmbiguousBlockAssociation, Style/MethodCallWithArgsParentheses
180
207
} )
208
+ class Branch
209
+ prepend ( Module . new {
210
+ def grouped_records
211
+ h = { }
212
+ polymorphic_parent = !root? && parent . polymorphic?
213
+ source_records . each do |record |
214
+ # each class can resolve virtual_{attributes,includes} differently
215
+ @association = record . class . replace_virtual_fields ( association )
216
+
217
+ # 1 line optimization for single element array:
218
+ @association = association . first if association . kind_of? ( Array ) # && association.size == 1
219
+
220
+ case association
221
+ when Symbol , String
222
+ reflection = record . class . _reflect_on_association ( association )
223
+ next if polymorphic_parent && !reflection || !record . association ( association ) . klass
224
+ when nil
225
+ next
226
+ else # need parent (preloaders_for_{hash,one}) to handle this Array/Hash
227
+ reflection = association
228
+ end
229
+ ( h [ reflection ] ||= [ ] ) << record
230
+ end
231
+ h
232
+ end
233
+ } )
234
+ end if ActiveRecord . version >= Gem ::Version . new ( 7.0 )
181
235
end
182
236
end
183
237
184
238
class Relation
185
- def without_virtual_includes
186
- filtered_includes = includes_values && klass . replace_virtual_fields ( includes_values )
187
- if filtered_includes != includes_values
188
- spawn . tap { |other | other . includes_values = filtered_includes }
189
- else
190
- self
191
- end
192
- end
193
-
194
239
include ( Module . new {
195
- # From ActiveRecord::FinderMethods
196
- def apply_join_dependency ( *args , **kargs , &block )
197
- real = without_virtual_includes
198
- if real . equal? ( self )
199
- super
200
- else
201
- real . apply_join_dependency ( *args , **kargs , &block )
202
- end
203
- end
204
-
205
240
# From ActiveRecord::QueryMethods (rails 5.2 - 6.1)
206
241
def build_select ( arel )
207
242
if select_values . any?
208
- cols = arel_columns ( select_values . uniq ) . map do |col |
243
+ cols = arel_columns ( select_values ) . map do |col |
209
244
# if it is a virtual attribute, then add aliases to those columns
210
245
if col . kind_of? ( Arel ::Nodes ::Grouping ) && col . name
211
246
col . as ( connection . quote_column_name ( col . name ) )
@@ -233,16 +268,6 @@ def construct_join_dependency(associations, join_type) # :nodoc:
233
268
associations = klass . replace_virtual_fields ( associations )
234
269
super
235
270
end
236
-
237
- # From ActiveRecord::Calculations
238
- # introduces virtual includes support for calculate (we mostly use COUNT(*))
239
- def calculate ( operation , attribute_name )
240
- # allow calculate to work with includes and a virtual attribute
241
- real = without_virtual_includes
242
- return super if real . equal? ( self )
243
-
244
- real . calculate ( operation , attribute_name )
245
- end
246
271
} )
247
272
end
248
273
end
0 commit comments