Skip to content

Commit 989f2b5

Browse files
committed
Add field counters to allow you to check if a field was referenced
1 parent 7a0c414 commit 989f2b5

File tree

5 files changed

+169
-11
lines changed

5 files changed

+169
-11
lines changed

lib/ar_query_matchers.rb

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
require 'ar_query_matchers/queries/create_counter'
44
require 'ar_query_matchers/queries/load_counter'
55
require 'ar_query_matchers/queries/update_counter'
6+
require 'ar_query_matchers/queries/field_counter'
67
require 'bigdecimal'
78

89
module ArQueryMatchers
@@ -225,6 +226,60 @@ def failure_text
225226
end
226227
end
227228

229+
class FieldModels
230+
# The following will succeed:
231+
#
232+
# expect {
233+
# WcRiskClass.last.update_attributes(id: 9999)
234+
# WcRiskClass.last.update_attributes(id: 1234)
235+
# }.to query_by_field(
236+
# 'id' => [9999, 1234],
237+
# )
238+
#
239+
RSpec::Matchers.define(:query_by_field) do |expected = {}|
240+
include MatcherConfiguration
241+
include MatcherErrors
242+
243+
# Convert the map of expected values to a Hash of all arrays
244+
expected = expected.map { |k, v| [k, v.kind_of?(Array) ? v : [v]] }.to_h
245+
246+
match do |block|
247+
@query_stats = Queries::FieldCounter.instrument(&block)
248+
expected == @query_stats.query_values
249+
end
250+
251+
def failure_text
252+
expectation_failed_message('query_by', true)
253+
end
254+
end
255+
256+
# The following will succeed:
257+
#
258+
# expect {
259+
# WcRiskClass.last.update_attributes(id: 9999)
260+
# WcRiskClass.last.update_attributes(id: 1234)
261+
# }.to query_by_field_at_least(
262+
# 'id' => 9999,
263+
# )
264+
#
265+
RSpec::Matchers.define(:query_by_field_at_least) do |expected = {}|
266+
include MatcherConfiguration
267+
include MatcherErrors
268+
269+
# Convert the map of expected values to a Hash of all arrays
270+
_expected = expected.map { |k, v| [k, v.kind_of?(Array) ? v : [v]] }.to_h
271+
272+
match do |block|
273+
@query_stats = Queries::FieldCounter.instrument(&block)
274+
_expected == @query_stats.query_values.select{ |k,v| _expected.keys.include?(k) }
275+
end
276+
277+
def failure_text
278+
expectation_failed_message('query_by', true, true)
279+
end
280+
end
281+
end
282+
228283
# Shared methods that are included in the matchers.
229284
# They configure it and ensure we get consistent and human readable error messages
230285
module MatcherConfiguration
@@ -249,16 +304,26 @@ module MatcherErrors
249304
# Show the difference between expected and actual values with one value
250305
# per line. This is done by hand because as of this writing the author
251306
# doesn't understand how RSpec does its nice hash diff printing.
252-
def difference(keys)
307+
def difference(keys, values_only=false)
253308
max_key_length = keys.reduce(0) { |max, key| [max, key.size].max }
254309

255310
keys.map do |key|
256-
left = expected.fetch(key, 0)
257-
right = @query_stats.queries.fetch(key, {}).fetch(:count, 0)
311+
left = expected.fetch(key, values_only ? [] : 0)
312+
right = @query_stats.queries.fetch(key, {})
313+
if values_only
314+
left = [left] if not left.kind_of?(Array)
315+
right = right.fetch(:values, [])
316+
else
317+
right = right.fetch(:count, 0)
318+
end
258319

259-
diff = "#{'+' if right > left}#{right - left}"
320+
if values_only
321+
"#{key.rjust(max_key_length, ' ')} – expected: #{left}, got: #{right}"
322+
else
323+
diff = "#{'+' if right > left}#{right - left}"
324+
"#{key.rjust(max_key_length, ' ')} – expected: #{left}, got: #{right} (#{diff})"
325+
end
260326

261-
"#{key.rjust(max_key_length, ' ')} – expected: #{left}, got: #{right} (#{diff})"
262327
end.compact
263328
end
264329

@@ -281,10 +346,24 @@ def no_queries_fail_message(crud_operation)
281346
"Expected ActiveRecord to not #{crud_operation} any records, got #{@query_stats.query_counts}\n\nWhere unexpected queries came from:\n\n#{source_lines(@query_stats.query_counts.keys).join("\n")}"
282347
end
283348

284-
def expectation_failed_message(crud_operation)
349+
def expectation_failed_message(crud_operation, values_only=false, subset=false)
350+
if values_only
351+
_expected = expected.map { |k, v| [k, v.kind_of?(Array) ? v : [v]] }.to_h
352+
expected = _expected
353+
end
285354
all_model_names = expected.keys + @query_stats.queries.keys
286-
model_names_with_wrong_count = all_model_names.reject { |key| expected[key] == @query_stats.queries[key][:count] }.uniq
287-
"Expected ActiveRecord to #{crud_operation} #{expected}, got #{@query_stats.query_counts}\nExpectations that differed:\n#{difference(model_names_with_wrong_count).join("\n")}\n\nWhere unexpected queries came from:\n\n#{source_lines(model_names_with_wrong_count).join("\n")}"
355+
if values_only
356+
model_names_with_wrong_count = all_model_names.reject { |key|
357+
if subset && !expected[key].nil?
358+
(expected[key] - @query_stats.queries[key][:values]).empty?
359+
else
360+
@query_stats.queries[key][:values] == expected[key]
361+
end
362+
}.uniq
363+
else
364+
model_names_with_wrong_count = all_model_names.reject { |key| expected[key] == @query_stats.queries[key][:count] }.uniq
365+
end
366+
"Expected ActiveRecord to #{crud_operation} #{expected}, got #{values_only ? @query_stats.query_values : @query_stats.query_counts}\nExpectations that differed:\n#{difference(model_names_with_wrong_count, values_only).join("\n")}\n\nWhere unexpected queries came from:\n\n#{source_lines(model_names_with_wrong_count).join("\n")}"
288367
end
289368
end
290369
end
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# frozen_string_literal: true
2+
3+
require_relative './query_counter'
4+
require_relative './field_name'
5+
require_relative './query_filter'
6+
7+
module ArQueryMatchers
8+
module Queries
9+
# A specialized QueryCounter for "any action that involves field IDs".
10+
# For more information, see the QueryCounter class.
11+
class FieldCounter
12+
def self.instrument(&block)
13+
QueryCounter.new(FieldCounterFilter.new).instrument(&block)
14+
end
15+
16+
class FieldCounterFilter < Queries::QueryFilter
17+
# We need to look for a few things:
18+
# Anything with ` {field} = {value}` (this could be a select, update, delete)
19+
MODEL_FIELDS_PATTERN = /\.`(?<field_name>[\w]+)` = (?<field_value>[\w"`]+)/
20+
21+
# Anything with ` {field} IN ({value})` (this could be a select, update, delete)
22+
MODEL_FIELDS_IN_PATTERN = /\.`(?<field_name>[\w]+)` IN \((?<field_value>[\w"`]+)\)/
23+
24+
# Anything with `, field,` in an INSERT (we need to check the values)
25+
MODEL_INSERT_PATTERN = /INSERT INTO (?<table_name>[^`"]+) ... VALUES .../
26+
27+
def cleanup(value)
28+
cleaned_value = value.gsub '`', ''
29+
30+
# If this is an integer, we'll cast it automatically
31+
if cleaned_value == value
32+
cleaned_value = value.to_i
33+
end
34+
35+
cleaned_value
36+
end
37+
38+
def filter_map(name, sql)
39+
# We need to look for a few things:
40+
# - Anything with ` {field} = ` (this could be a select, update, delete)
41+
# - Anything with `, field,` in an INSERT (we need to check the values)
42+
select_field_query = sql.match(MODEL_FIELDS_PATTERN)
43+
if sql.match(/INSERT/)
44+
debugger
45+
end
46+
47+
FieldName.new(select_field_query[:field_name], cleanup(select_field_query[:field_value])) if select_field_query
48+
end
49+
end
50+
end
51+
end
52+
end
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# frozen_string_literal: true
2+
3+
module ArQueryMatchers
4+
module Queries
5+
# An instance of this class is one of the values that could be returned from the QueryFilter#filter_map.
6+
# its accepts a name of an ActiveRecord model, for example: 'Company'.
7+
class FieldName
8+
attr_reader(:model_name)
9+
attr_reader(:model_value)
10+
11+
def initialize(model_name, model_value)
12+
@model_name = model_name
13+
@model_value = model_value
14+
end
15+
end
16+
end
17+
end

lib/ar_query_matchers/queries/load_counter.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def self.instrument(&block)
1717
class LoadQueryFilter < Queries::QueryFilter
1818
# Matches named SQL operations like the following:
1919
# 'User Load'
20-
MODEL_LOAD_PATTERN = /\A(?<model_name>[\w:]+) (Load|Exists)\Z/
20+
MODEL_LOAD_PATTERN = /\A(?<field_name>[\w:]+)/
2121

2222
# Matches unnamed SQL operations like the following:
2323
# "SELECT COUNT(*) FROM `users` ..."

lib/ar_query_matchers/queries/query_counter.rb

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@ def query_counts
4444
Hash[*queries.reduce({}) { |acc, (model_name, data)| acc.update model_name => data[:count] }.sort_by(&:first).flatten]
4545
end
4646

47+
def query_values
48+
result = {}
49+
queries.each do |model_name, data|
50+
result[model_name] = data[:values]
51+
end
52+
result
53+
end
54+
4755
# @return [Hash] of line in the source code to its frequency
4856
def query_lines_by_frequency
4957
queries.reduce({}) do |lines, (model_name, data)|
@@ -62,7 +70,7 @@ def initialize(query_filter)
6270
# @param [block] block to instrument
6371
# @return [QueryStats] stats about all the SQL queries executed during the block
6472
def instrument(&block)
65-
queries = Hash.new { |h, k| h[k] = { count: 0, lines: [], time: BigDecimal(0) } }
73+
queries = Hash.new { |h, k| h[k] = { count: 0, lines: [], values: [], time: BigDecimal(0) } }
6674
ActiveSupport::Notifications.subscribed(to_proc(queries), 'sql.active_record', &block)
6775
QueryStats.new(queries)
6876
end
@@ -81,12 +89,14 @@ def to_proc(queries)
8189
# Given a `sql.active_record` event, figure out which model is being
8290
# accessed. Some of the simpler queries have a :name key that makes this
8391
# really easy. Others require parsing the SQL by hand.
84-
model_name = @query_filter.filter_map(payload[:name] || '', payload[:sql] || '')&.model_name
92+
model_obj = @query_filter.filter_map(payload[:name] || '', payload[:sql] || '')
93+
model_name = model_obj&.model_name
8594

8695
if model_name
8796
comment = payload[:sql].match(MARGINALIA_SQL_COMMENT_PATTERN)
8897
queries[model_name][:lines] << comment[:line] if comment
8998
queries[model_name][:count] += 1
99+
queries[model_name][:values].append(model_obj&.model_value) if model_obj.respond_to?(:model_value)
90100
queries[model_name][:time] += (finish - start).round(6) # Round to microseconds
91101
end
92102
end

0 commit comments

Comments
 (0)