Skip to content

Commit 5f7098e

Browse files
authored
Add field counters to allow you to check if a field was referenced (#37)
1 parent 7a0c414 commit 5f7098e

File tree

7 files changed

+223
-24
lines changed

7 files changed

+223
-24
lines changed

.github/workflows/run_tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ jobs:
77
runs-on: ubuntu-latest
88
strategy:
99
matrix:
10-
ruby: ['2.7', '3.2']
10+
ruby: ['3.2', '3.3']
1111
appraisal: ["rails-6", "rails-7"]
1212
steps:
1313
- uses: actions/checkout@v1
@@ -24,4 +24,4 @@ jobs:
2424
gem install bundler
2525
bundle install --jobs 4 --retry 3
2626
bundle exec appraisal install
27-
bundle exec appraisal ${{ matrix.appraisal }} rake
27+
bundle exec appraisal ${{ matrix.appraisal }} rake

lib/ar_query_matchers.rb

Lines changed: 131 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
require 'ar_query_matchers/queries/create_counter'
44
require 'ar_query_matchers/queries/load_counter'
55
require 'ar_query_matchers/queries/update_counter'
6+
require 'ar_query_matchers/queries/field_counter'
67
require 'bigdecimal'
78

89
module ArQueryMatchers
@@ -225,6 +226,86 @@ def failure_text
225226
end
226227
end
227228

229+
class FieldModels
230+
# The following will succeed:
231+
#
232+
# expect {
233+
# WcRiskClass.last.update_attributes(id: 9999)
234+
# WcRiskClass.last.update_attributes(id: 1234)
235+
# }.to query_by_field(
236+
# 'id' => [9999, 1234],
237+
# )
238+
#
239+
RSpec::Matchers.define(:query_by_field) do |expected = {}|
240+
include MatcherConfiguration
241+
include MatcherErrors
242+
243+
# Convert the map of expected values to a Hash of all arrays
244+
expected = expected.transform_values { |v| v.is_a(Array) ? v : [v] }
245+
246+
match do |block|
247+
@query_stats = Queries::FieldCounter.instrument(&block)
248+
expected == @query_stats.query_values
249+
end
250+
251+
def failure_text
252+
expectation_failed_message('query_by', show_values: true)
253+
end
254+
end
255+
256+
# The following will succeed:
257+
#
258+
# expect {
259+
# WcRiskClass.last.update_attributes(id: 9999)
260+
# WcRiskClass.last.update_attributes(id: 1234)
261+
# }.to query_by_field_at_least(
262+
# 'id' => 9999,
263+
# )
264+
#
265+
RSpec::Matchers.define(:query_by_field_at_least) do |expected = {}|
266+
include MatcherConfiguration
267+
include MatcherErrors
268+
269+
# Convert the map of expected values to a Hash of all arrays
270+
expected = expected.transform_values { |v| v.is_a?(Array) ? v : [v] }
271+
272+
match do |block|
273+
@query_stats = Queries::FieldCounter.instrument(&block)
274+
expected == @query_stats.query_values.select { |k, _| expected.keys.include?(k) }
275+
end
276+
277+
def failure_text
278+
expectation_failed_message('query_by', show_values: true, subset: true)
279+
end
280+
end
281+
282+
# The following will succeed:
283+
#
284+
# expect {
285+
# WcRiskClass.last.update_attributes(id: 9999)
286+
# WcRiskClass.last.update_attributes(id: 1234)
287+
# }.to query_by_field_at_least_ignore_notfound(
288+
# 'id' => 6666,
289+
# )
290+
#
291+
RSpec::Matchers.define(:query_by_field_at_least_ignore_notfound) do |expected = {}|
292+
include MatcherConfiguration
293+
include MatcherErrors
294+
295+
# Convert the map of expected values to a Hash of all arrays
296+
expected = expected.transform_values { |v| v.is_a?(Array) ? v : [v] }
297+
298+
match do |block|
299+
@query_stats = Queries::FieldCounter.instrument(&block)
300+
expected.select { |k, _| @query_stats.query_values.keys.include?(k) } == @query_stats.query_values.select { |k, _| expected.keys.include?(k) }
301+
end
302+
303+
def failure_text
304+
expectation_failed_message('query_by', show_values: true, subset: true, ignore_missing: true)
305+
end
306+
end
307+
end
308+
228309
# Shared methods that are included in the matchers.
229310
# They configure it and ensure we get consistent and human readable error messages
230311
module MatcherConfiguration
@@ -246,22 +327,37 @@ def supports_block_expectations?
246327
end
247328

248329
module MatcherErrors
249-
# Show the difference between expected and actual values with one value
250-
# per line. This is done by hand because as of this writing the author
251-
# doesn't understand how RSpec does its nice hash diff printing.
252-
def difference(keys)
253-
max_key_length = keys.reduce(0) { |max, key| [max, key.size].max }
330+
def create_display_string(max_key_length, key, left, right, show_values)
331+
diff_array = right - left
332+
diff_array = left - right if diff_array.empty?
333+
"#{key.rjust(max_key_length, ' ')} – expected: #{left}, got: #{right} #{"(difference: #{diff_array})" if show_values}"
334+
end
254335

336+
def loop_through_keys(keys, transformed_expected, show_values)
337+
max_key_length = keys.reduce(0) { |max, key| [max, key.size].max }
255338
keys.map do |key|
256-
left = expected.fetch(key, 0)
257-
right = @query_stats.queries.fetch(key, {}).fetch(:count, 0)
339+
left = transformed_expected.fetch(key, show_values ? [] : 0)
340+
left = [left] unless left.is_a?(Array) || show_values
258341

259-
diff = "#{'+' if right > left}#{right - left}"
342+
right = @query_stats.queries.fetch(key, {})
343+
right = show_values ? right.fetch(:values, []) : right.fetch(:count, 0)
260344

261-
"#{key.rjust(max_key_length, ' ')} – expected: #{left}, got: #{right} (#{diff})"
345+
create_display_string(max_key_length, key, left, right, show_values)
262346
end.compact
263347
end
264348

349+
# Show the difference between expected and actual values with one value
350+
# per line. This is done by hand because as of this writing the author
351+
# doesn't understand how RSpec does its nice hash diff printing.
352+
def difference(keys, show_values: false)
353+
transformed_expected = expected
354+
if show_values
355+
transformed_expected = expected.transform_values { |v| v.is_a?(Array) ? v : [v] }
356+
end
357+
358+
loop_through_keys keys, transformed_expected, show_values
359+
end
360+
265361
def source_lines(keys)
266362
line_frequency = @query_stats.query_lines_by_frequency
267363
keys_with_source_lines = keys.select { |key| line_frequency[key].present? }
@@ -281,10 +377,33 @@ def no_queries_fail_message(crud_operation)
281377
"Expected ActiveRecord to not #{crud_operation} any records, got #{@query_stats.query_counts}\n\nWhere unexpected queries came from:\n\n#{source_lines(@query_stats.query_counts.keys).join("\n")}"
282378
end
283379

284-
def expectation_failed_message(crud_operation)
380+
def reject_record(subset, current_expected, key, ignore_missing)
381+
if subset && !current_expected[key].nil?
382+
if ignore_missing
383+
@query_stats.queries[key][:values].empty? || (current_expected[key] - @query_stats.queries[key][:values]).empty?
384+
else
385+
(current_expected[key] - @query_stats.queries[key][:values]).empty?
386+
end
387+
else
388+
ignore_missing || @query_stats.queries[key][:values] == current_expected[key]
389+
end
390+
end
391+
392+
def filter_model_names(subset, show_values, ignore_missing)
285393
all_model_names = expected.keys + @query_stats.queries.keys
286-
model_names_with_wrong_count = all_model_names.reject { |key| expected[key] == @query_stats.queries[key][:count] }.uniq
287-
"Expected ActiveRecord to #{crud_operation} #{expected}, got #{@query_stats.query_counts}\nExpectations that differed:\n#{difference(model_names_with_wrong_count).join("\n")}\n\nWhere unexpected queries came from:\n\n#{source_lines(model_names_with_wrong_count).join("\n")}"
394+
if show_values
395+
transformed_expected = expected.transform_values { |v| v.is_a?(Array) ? v : [v] }
396+
all_model_names.reject { |key| reject_record(subset, transformed_expected, key, ignore_missing) }.uniq
397+
else
398+
all_model_names.reject { |key| expected[key] == @query_stats.queries[key][:count] }.uniq
399+
end
400+
end
401+
402+
def expectation_failed_message(crud_operation, show_values: false, subset: false, ignore_missing: false)
403+
model_names_with_wrong_count = filter_model_names(subset, show_values, ignore_missing)
404+
message = "Expected ActiveRecord to #{crud_operation} #{expected}, got #{show_values ? @query_stats.query_values : @query_stats.query_counts}\n"
405+
message += "Expectations that differed:\n#{difference(model_names_with_wrong_count, show_values: show_values).join("\n")}" if show_values
406+
message + "\n\nWhere unexpected queries came from:\n\n#{source_lines(model_names_with_wrong_count).join("\n")}"
288407
end
289408
end
290409
end
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# frozen_string_literal: true
2+
3+
require_relative './query_counter'
4+
require_relative './field_name'
5+
require_relative './query_filter'
6+
7+
module ArQueryMatchers
8+
module Queries
9+
# A specialized QueryCounter for "any action that involves field IDs".
10+
# For more information, see the QueryCounter class.
11+
class FieldCounter
12+
def self.instrument(&block)
13+
QueryCounter.new(FieldCounterFilter.new).instrument(&block)
14+
end
15+
16+
# Filters queries for counting purposes
17+
class FieldCounterFilter < Queries::QueryFilter
18+
# We need to look for a few things:
19+
# Anything with ` {field} = {value}` (this could be a select, update, delete)
20+
MODEL_FIELDS_PATTERN = /\.`(?<field_name>\w+)` = (?<field_value>[\w"`]+)/
21+
22+
# Anything with ` {field} IN ({value})` (this could be a select, update, delete)
23+
MODEL_FIELDS_IN_PATTERN = /\.`(?<field_name>\w+)` IN \((?<field_value>[\w"`]+)\)/
24+
25+
# Anything with `, field,` in an INSERT (we need to check the values)
26+
MODEL_INSERT_PATTERN = /INSERT INTO (?<table_name>[^`"]+) ... VALUES .../
27+
28+
def cleanup(value)
29+
cleaned_value = value.gsub '`', ''
30+
31+
# If this is an integer, we'll cast it automatically
32+
cleaned_value = value.to_i if cleaned_value == value
33+
34+
cleaned_value
35+
end
36+
37+
def filter_map(_name, sql)
38+
# We need to look for a few things:
39+
# - Anything with ` {field} = ` (this could be a select, update, delete)
40+
# - Anything with `, field,` in an INSERT (we need to check the values)
41+
select_field_query = sql.match(MODEL_FIELDS_PATTERN)
42+
# debugger if sql.match(/INSERT/)
43+
# TODO: MODEL_FIELDS_IN_PATTERN and MODEL_INSERT_PATTERN need to be handled
44+
45+
FieldName.new(select_field_query[:field_name], cleanup(select_field_query[:field_value])) if select_field_query
46+
end
47+
end
48+
end
49+
end
50+
end
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# frozen_string_literal: true
2+
3+
module ArQueryMatchers
4+
module Queries
5+
# An instance of this class is one of the values that could be returned from the QueryFilter#filter_map.
6+
# it accepts a name of an ActiveRecord model, for example: 'Company'.
7+
class FieldName
8+
attr_reader(:model_name, :model_value)
9+
10+
def initialize(model_name, model_value)
11+
@model_name = model_name
12+
@model_value = model_value
13+
end
14+
end
15+
end
16+
end

lib/ar_query_matchers/queries/load_counter.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def self.instrument(&block)
1717
class LoadQueryFilter < Queries::QueryFilter
1818
# Matches named SQL operations like the following:
1919
# 'User Load'
20-
MODEL_LOAD_PATTERN = /\A(?<model_name>[\w:]+) (Load|Exists)\Z/
20+
MODEL_LOAD_PATTERN = /\A(?<field_name>[\w:]+)/
2121

2222
# Matches unnamed SQL operations like the following:
2323
# "SELECT COUNT(*) FROM `users` ..."
@@ -27,7 +27,7 @@ def filter_map(name, sql)
2727
# First check for a `SELECT * FROM` query that ActiveRecord has
2828
# helpfully named for us in the payload
2929
match = name.match(MODEL_LOAD_PATTERN)
30-
return ModelName.new(match[:model_name]) if match
30+
return ModelName.new(match[:model_name]) if match&.names&.include? :model_name
3131

3232
# Fall back to pattern-matching on the table name in a COUNT and looking
3333
# up the table name from ActiveRecord's loaded descendants.

lib/ar_query_matchers/queries/query_counter.rb

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@ def query_counts
4444
Hash[*queries.reduce({}) { |acc, (model_name, data)| acc.update model_name => data[:count] }.sort_by(&:first).flatten]
4545
end
4646

47+
def query_values
48+
result = {}
49+
queries.each do |model_name, data|
50+
result[model_name] = data[:values]
51+
end
52+
result
53+
end
54+
4755
# @return [Hash] of line in the source code to its frequency
4856
def query_lines_by_frequency
4957
queries.reduce({}) do |lines, (model_name, data)|
@@ -62,7 +70,7 @@ def initialize(query_filter)
6270
# @param [block] block to instrument
6371
# @return [QueryStats] stats about all the SQL queries executed during the block
6472
def instrument(&block)
65-
queries = Hash.new { |h, k| h[k] = { count: 0, lines: [], time: BigDecimal(0) } }
73+
queries = Hash.new { |h, k| h[k] = { count: 0, lines: [], values: [], time: BigDecimal(0) } }
6674
ActiveSupport::Notifications.subscribed(to_proc(queries), 'sql.active_record', &block)
6775
QueryStats.new(queries)
6876
end
@@ -74,21 +82,26 @@ def instrument(&block)
7482
MARGINALIA_SQL_COMMENT_PATTERN = %r{/*line:(?<line>.*)'*/}
7583
private_constant :MARGINALIA_SQL_COMMENT_PATTERN
7684

85+
def add_to_query(queries, payload, model_obj, finish, start)
86+
model_name = model_obj&.model_name
87+
comment = payload[:sql].match(MARGINALIA_SQL_COMMENT_PATTERN)
88+
queries[model_name][:lines] << comment[:line] if comment
89+
queries[model_name][:count] += 1
90+
queries[model_name][:values].append(model_obj&.model_value) if model_obj.respond_to?(:model_value) && !queries[model_name][:values].include?(model_obj&.model_value)
91+
queries[model_name][:time] += (finish - start).round(6) # Round to microseconds
92+
end
93+
7794
def to_proc(queries)
7895
lambda do |_name, start, finish, _message_id, payload|
7996
return if payload[:cached]
8097

8198
# Given a `sql.active_record` event, figure out which model is being
8299
# accessed. Some of the simpler queries have a :name key that makes this
83100
# really easy. Others require parsing the SQL by hand.
84-
model_name = @query_filter.filter_map(payload[:name] || '', payload[:sql] || '')&.model_name
101+
model_obj = @query_filter.filter_map(payload[:name] || '', payload[:sql] || '')
102+
model_name = model_obj&.model_name
85103

86-
if model_name
87-
comment = payload[:sql].match(MARGINALIA_SQL_COMMENT_PATTERN)
88-
queries[model_name][:lines] << comment[:line] if comment
89-
queries[model_name][:count] += 1
90-
queries[model_name][:time] += (finish - start).round(6) # Round to microseconds
91-
end
104+
add_to_query(queries, payload, model_obj, finish, start) if model_name
92105
end
93106
end
94107
end

spec/ar_query_matchers/mock_data_model.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# @mission Infrastructure
44
# @team DEx
55

6+
require 'logger'
67
require 'active_record'
78

89
RSpec.shared_context('mock_data_model') do

0 commit comments

Comments
 (0)