Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/StructFieldParamsTesting.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module StructFieldParamsTesting

export test_all_structs_have_fully_specified_fields
export test_all_fields_fully_specified, field_is_fully_specified
export test_all_structs_have_fully_specified_fields, print_all_structs_have_fully_specified_fields
export test_all_fields_fully_specified, field_is_fully_specified, print_all_fields_fully_specified

using MacroTools: @capture
using Markdown: Markdown
Expand Down
204 changes: 177 additions & 27 deletions src/check_struct_fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,147 @@ function field_is_fully_specified(pkg::Module, struct_expr, field_name; location
report_error = false, location = location)
end

function print_all_fields_fully_specified(pkg::Module, struct_expr; location = nothing)
(struct_name, T, fields_dict) = _extract_struct_field_types(pkg::Module, struct_expr)

# Extract field order from original struct expression
field_order = _extract_field_order(struct_expr)

# Check each field and collect unspecified ones
unspecified_fields = Dict{Symbol, Any}()
fully_specified_fields = Dict{Symbol, Any}()
corrected_fields = Dict{Symbol, Any}()

for (field_name, field_type_expr) in fields_dict
is_specified = check_field_type_fully_specified(
pkg, struct_name, field_name, T, field_type_expr,
report_error = false, location = location)

if is_specified
fully_specified_fields[field_name] = field_type_expr
corrected_fields[field_name] = field_type_expr
else
unspecified_fields[field_name] = field_type_expr
# Generate corrected type
try
TypeObj = Base.eval(pkg, quote
$(field_type_expr) where {$(T...)}
end)
complete_type = single_unwrap_unionall(TypeObj)
corrected_fields[field_name] = complete_type
catch e
corrected_fields[field_name] = field_type_expr
end
end
end

# Print original struct definition
println("Original struct definition:")
if location !== nothing
printstyled("#= $location =#\n"; color=:light_black)
end
_print_struct_definition_colored(struct_name, T, fields_dict, field_order, struct_expr, unspecified_fields, :red)
println()

# Only show corrected version if there are issues
if !isempty(unspecified_fields)
println("Fully-specified version:")
_print_struct_definition_colored(struct_name, T, corrected_fields, field_order, struct_expr, unspecified_fields, :green)
end
end

function _extract_field_order(struct_expr)
@capture(
struct_expr,
struct name_{T__} <: S_ fields__ end | struct name_ <: S_ fields__ end |
struct name_{T__} fields__ end | struct name_ fields__ end |
mutable struct name_{T__} <: S_ fields__ end | mutable struct name_ <: S_ fields__ end |
mutable struct name_{T__} fields__ end | mutable struct name_ fields__ end
) || error("Invalid struct expression: $(struct_expr)")

fields_split = split_field.(fields)
filter!(x -> x !== nothing, fields_split)
return [f[1] for f in fields_split]
end

function _print_struct_definition_colored(struct_name, T, fields_dict, field_order, original_expr, unspecified_fields, color)
# Determine if it's mutable by looking at the original expression
is_mutable = false
if original_expr isa Expr && original_expr.head == :struct && length(original_expr.args) >= 1
is_mutable = original_expr.args[1] isa Bool ? original_expr.args[1] : false
end

# Build struct definition string
prefix = is_mutable ? "mutable struct " : "struct "

# Handle type parameters
if isempty(T)
struct_header = "$prefix$struct_name"
else
struct_header = "$prefix$struct_name{$(join(T, ", "))}"
end

println(struct_header)

# Print fields in original order with color for changed fields
for field_name in field_order
field_type_expr = fields_dict[field_name]
field_line = if field_type_expr == Any
" $field_name"
else
" $field_name::$field_type_expr"
end

# Color the line if this field was unspecified
if field_name in keys(unspecified_fields)
if color == :red
printstyled(field_line, "\n"; color=:red)
elseif color == :green
printstyled(field_line, "\n"; color=:green)
else
println(field_line)
end
else
println(field_line)
end
end

println("end")
end

function _print_struct_definition(struct_name, T, fields_dict, field_order, original_expr)
# Determine if it's mutable by looking at the original expression
is_mutable = false
if original_expr isa Expr && original_expr.head == :struct && length(original_expr.args) >= 1
is_mutable = original_expr.args[1] isa Bool ? original_expr.args[1] : false
end

# Build struct definition string
prefix = is_mutable ? "mutable struct " : "struct "

# Handle type parameters
if isempty(T)
struct_header = "$prefix$struct_name"
else
struct_header = "$prefix$struct_name{$(join(T, ", "))}"
end

println(struct_header)

# Print fields in original order
for field_name in field_order
field_type_expr = fields_dict[field_name]
if field_type_expr == Any
println(" $field_name")
else
println(" $field_name::$field_type_expr")
end
end

println("end")
end


function _extract_struct_field_types(pkg::Module, struct_expr)
@capture(
struct_expr,
Expand Down Expand Up @@ -66,16 +207,14 @@ function check_field_type_fully_specified(
@debug "Type is a DataType: $(TypeObj)"
return true
end
if typeof(TypeObj) == Union
# TODO: Handle every branch of the union.
# For now, just skip these fields.
return true
end
if TypeObj == Type
# TODO: FOR NOW, to avoid noisy result
return true
end
@assert typeof(TypeObj) === UnionAll "$(TypeObj) is not a UnionAll. Got $(typeof(TypeObj))."
@assert (
typeof(TypeObj) === UnionAll ||
typeof(TypeObj) === Union
) "$(TypeObj) is not a UnionAll. Got $(typeof(TypeObj))."

num_type_params = _count_unionall_free_parameters(TypeObj)
num_expr_args = _count_type_expr_params(mod, field_type_expr)
Expand All @@ -90,6 +229,14 @@ function check_field_type_fully_specified(
end
return success
end
function check_field_type_fully_specified(mod, TypeObj, field_type_expr)
num_type_params = _count_unionall_free_parameters(TypeObj)
num_expr_args = _count_type_expr_params(mod, field_type_expr)
# "Less than or equal to" in order to support fully constrained parameters in the expr.
# E.g.: `Vector{T} where T<:Int` has 0 free type params but 1 param in the expression.
success = num_type_params <= num_expr_args
return success
end

# TODO(type-alias): What do we actually want to do for alias types?
# E.g.
Expand All @@ -103,9 +250,14 @@ recursive_unwrap_unionall(@nospecialize(T)) = Base.unwrap_unionall(T)
recursive_unwrap_unionall(T::UnionAll) = recursive_unwrap_unionall(Base.unwrap_unionall(T))
recursive_unwrap_unionall(T::Union) = Union{recursive_unwrap_unionall(T.a), recursive_unwrap_unionall(T.b)}

single_unwrap_unionall(@nospecialize(T)) = Base.unwrap_unionall(T)
single_unwrap_unionall(T::Union) = Union{single_unwrap_unionall(T.a), single_unwrap_unionall(T.b)}

# Get free TypeVar names (without constraints):
# Foo{Int, X<:Integer, Y} where {X, Y} => [:X, :Y]
function type_param_names(TypeObj)
type_param_names(TypeObj) = Symbol[]
type_param_names(TypeObj::Union) = Symbol[type_param_names(TypeObj.a)..., type_param_names(TypeObj.b)...]
function type_param_names(TypeObj::UnionAll)
names = Symbol[]
while typeof(TypeObj) === UnionAll
push!(names, TypeObj.var.name)
Expand Down Expand Up @@ -162,13 +314,13 @@ function field_type_not_complete_message(
@assert num_type_params <= length(type_params)
type_params = type_params[1:num_type_params]
expr_args = num_expr_args == 0 ? Symbol[] : type_params[1:num_expr_args]
complete_type = Base.unwrap_unionall(TypeObj)
complete_type = single_unwrap_unionall(TypeObj)
# TODO(type-alias): see comment on recursive_unwrap_unionall
complete_type_recursive = recursive_unwrap_unionall(TypeObj)
s = num_type_params == 1 ? "" : "s"
print(io, """
In struct `$(mod).$(struct_name)`, the field `$(field_name)` does not have a fully \
specified type:\n
print(io, """\n
In struct `$(mod).$(struct_name)`,
the field `$(field_name)` does not have a fully specified type:\n
\t$(field_name)::$(field_type_expr)\n
"""
)
Expand All @@ -181,44 +333,42 @@ function field_type_not_complete_message(
print(io, """
The complete type is:\n
\t$(complete_type)\n
which expects $(num_type_params) type parameter$(s): `$(join(type_params, ", "))`.\n
which expects $(num_type_params) additional type parameter$(s): `$(join(type_params, ", "))`.\n
""")
if string(complete_type) != string(complete_type_recursive)
print(io, """
And which is an alias for:\n
\t$(complete_type_recursive)\n
""")
end

print(io, """
The current definition `$(field_type_expr)` specifies \
$(num_expr_args == 0 ? "no type parameters." : "only $(num_expr_args) type parameters: \
`$(join(expr_args, ", "))`.")
""")

print(io, """
This means the `$(field_name)` field currently has an abstract type,
and any access to it is type unstable will therefore cause a dynamic dispatch.
This means the `$(field_name)` field currently has an abstract type, and any access to it,
like `x.$(field_name)`, is type unstable and will therefore cause a dynamic dispatch.

If this was a mistake, possibly caused by a change to the `$(typename)` type that \
introduced new parameters to it, please make sure that your field `$(field_name)` is \
If this was a mistake, possibly caused by a change to the `$(typename)` type that
introduced new parameters to it, please make sure that your field `$(field_name)` is
fully concrete, with all parameters specified.

If, instead, this type instability is on purpose, please fully specify the omitted \
type parameters to silence this message. You can write that as `$(complete_type)`, or \
possibly in a shorter alias form which this message can't always detect. (E.g. you can \
If, instead, this type instability is on purpose, please fully specify the omitted
type parameters to silence this message. You can write that as `$(complete_type)`, or
possibly in a shorter alias form which this message can't always detect. (E.g. you can
write `Vector{T} where T` instead of `Array{T, 1} where T`.)
""")
return io
end

_count_unionall_free_parameters(@nospecialize(::Any)) = 0
function _count_unionall_free_parameters(TypeObj::Union)
aa = TypeObj.a
bb = TypeObj.b
return _count_unionall_free_parameters(aa) + _count_unionall_free_parameters(bb)
end
function _count_unionall_free_parameters(TypeObj::UnionAll)
return _count_unionall_free_parameters(Base.unwrap_unionall(TypeObj))
end
function _count_unionall_free_parameters(TypeObj::DataType)
count = 0
for param in @show TypeObj.parameters
for param in TypeObj.parameters
# only `TypeVars` can be free parameters, but in `T<:ConcreteType`
# don't consider `T` as a free parameter
if param isa TypeVar && !isconcretetype(param.ub)
Expand Down
Loading
Loading