Skip to content

add dict interface splitargdef and combineargdef #179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export @esc, isexpr, isline, iscall, rmlines, unblock, block, inexpr, namify, isdef,
longdef, shortdef, @expand, makeif, prettify, combinedef, splitdef, splitarg, combinearg
longdef, shortdef, @expand, makeif, prettify, combinedef, splitdef, splitarg, combinearg, splitargdef, combineargdef

"""
assoc!(d, k, v)
Expand Down Expand Up @@ -403,7 +403,8 @@ end
"""
combinearg(arg_name, arg_type, is_splat, default)

`combinearg` is the inverse of [`splitarg`](@ref).
`combinearg` is the inverse of [`splitarg`](@ref). `combinearg` may oneday
be deprecated in favor of `combineargdef`
"""
function combinearg(arg_name, arg_type, is_splat, default)
@assert arg_name !== nothing || arg_type !== nothing
Expand All @@ -430,6 +431,8 @@ julia> map(splitarg, (:(f(a=2, x::Int=nothing, y, args...))).args[2:end])
(:args, :Any, true, nothing)
```

`splitarg` may oneday be deprecated in favor of `splitargdef`

See also: [`combinearg`](@ref)
"""
function splitarg(arg_expr)
Expand All @@ -450,6 +453,63 @@ function splitarg(arg_expr)
return (arg_name, arg_type, is_splat, default)
end

"""
combineargdef(dict)

`combineargdef` is the inverse of [`splitargdef`](@ref).
"""
function combineargdef(dict)
@assert haskey(dict, :name) || haskey(dict, :type)
if haskey(dict, :name)
a = dict[:name]
a = haskey(dict, :type) ? :($a::$(dict[:type])) : a
else
a = :(::$(dict[:type]))
end
a = dict[:is_splat] ? Expr(:..., a) : a
return haskey(dict, :default) ? Expr(:kw, a, dict[:default]) : a
end

"""
splitargdef(arg)

Match function arguments (whther from a definition or a function call) such as
`x::Int=2` and return `Dict(:name=>..., :is_splat=>..., etc.)`. The definition can be rebuilt by
calling `MacroTools.combineargdef(dict)`.
For example:

```julia
julia> map(splitarg, (:(f(a=2, x::Int=nothing, y, args...))).args[2:end])
4-element Array{Tuple{Symbol,Symbol,Bool,Any},1}:
Dict(:name=>:a, :is_splat=false, :default=>2)
Dict(:name=>:x, :type=>:Int, :is_splat=>false)
Dict(:name=>:y, :is_splat=>false)
Dict(:name=>:args, :is_splat=>true)
```

See also: [`combineargdef`](@ref)
"""
function splitargdef(arg_expr)
dict=Dict()
if @capture(arg_expr, arg_expr2_ = default_)
dict[:default] = default
else
arg_expr2 = arg_expr
end
is_splat = @capture(arg_expr2, arg_expr3_...)
is_splat || (arg_expr3 = arg_expr2)
dict[:is_splat] = is_splat
if @capture(arg_expr3, ::T_)
dict[:type]=T
elseif @capture(arg_expr3, name_::T_)
dict[:name]=name
dict[:type]=T
else
dict[:name] = arg_expr3
end
return dict
end


function flatten1(ex)
isexpr(ex, :block) || return ex
Expand Down
63 changes: 63 additions & 0 deletions test/split.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ macro splitcombine(fundef) # should be a no-op
esc(MacroTools.combinedef(dict))
end

macro splitcombine2(fundef) # should be a no-op
dict = splitdef(fundef)
dict[:args] = map(arg->combineargdef(splitargdef(arg)), dict[:args])
dict[:kwargs] = map(arg->combineargdef(splitargdef(arg)), dict[:kwargs])
esc(MacroTools.combinedef(dict))
end

# Macros for testing that splitcombine doesn't break
# macrocalls in bodies
macro zeroarg()
Expand All @@ -29,14 +36,22 @@ let
@test map(splitarg, (:(f(a=2, x::Int=nothing, y::Any, args...))).args[2:end]) ==
[(:a, :Any, false, 2), (:x, :Int, false, :nothing),
(:y, :Any, false, nothing), (:args, :Any, true, nothing)]
@test map(splitargdef, (:(f(a=2, x::Int=nothing, y::Any, args...))).args[2:end]) ==
[Dict(:name=>:a, :is_splat=>false, :default=>2), Dict(:name=>:x, :type=>:Int, :is_splat=>false, :default=>:nothing),
Dict(:name=>:y, :type=>:Any, :is_splat=>false), Dict(:name=>:args, :is_splat=>true)]
@test splitarg(:(::Int)) == (nothing, :Int, false, nothing)
@test splitargdef(:(::Int)) == Dict(:type=>:Int, :is_splat=>false)
kwargs = splitdef(:(f(; a::Int = 1, b...) = 1))[:kwargs]
@test map(splitarg, kwargs) ==
[(:a, :Int, false, 1), (:b, :Any, true, nothing)]
@test map(splitargdef, kwargs) ==
[Dict(:name=>:a, :type=>:Int, :is_splat=>false, :default=>1), Dict(:name=>:b, :is_splat=>true)]
args = splitdef(:(f(a::Int = 1) = 1))[:args]
@test map(splitarg, args) == [(:a, :Int, false, 1)]
@test map(splitargdef, args) == [Dict(:name=>:a, :type=>:Int, :is_splat=>false, :default=>1)]
args = splitdef(:(f(a::Int ... = 1) = 1))[:args]
@test map(splitarg, args) == [(:a, :Int, true, 1)] # issue 165
@test map(splitargdef, args) == [Dict(:name=>:a, :type=>:Int, :is_splat=>true, :default=>1)] # issue 165

@splitcombine foo(x) = x+2
@test foo(10) == 12
Expand Down Expand Up @@ -85,6 +100,54 @@ let
end)(1, Number[2.0]) == (Int, Number)
end

let
@splitcombine2 foo(x) = x+2
@test foo(10) == 12
@splitcombine2 add(a, b=2; c=3, d=4)::Float64 = a+b+c+d
@test add(1; d=10) === 16.0
@splitcombine2 fparam(a::T) where {T} = T
@test fparam([]) == Vector{Any}
struct Orange end
@splitcombine2 (::Orange)(x) = x+2
@test Orange()(10) == 12
@splitcombine2 fwhere(a::T) where T = T
@test fwhere(10) == Int
@splitcombine2 manywhere(x::T, y::Vector{U}) where T <: U where U = (T, U)
@test manywhere(1, Number[2.0]) == (Int, Number)
@splitcombine2 fmacro0() = @zeroarg
@test fmacro0() == 1
@splitcombine2 fmacro1() = @onearg 1
@test fmacro1() == 2

@splitcombine2 bar(; a::Int = 1, b...) = 2
@test bar(a=3, x = 1, y = 2) == 2
@splitcombine2 qux(a::Int... = 0) = sum(a)
@test qux(1, 2, 3) == 6
@test qux() == 0

struct Foo{A, B}
a::A
b::B
end
# Parametric outer constructor
@splitcombine2 Foo{A}(a::A) where A = Foo{A, A}(a,a)
@test Foo{Int}(2) == Foo{Int, Int}(2, 2)

@test (@splitcombine2 x -> x + 2)(10) === 12
@test (@splitcombine2 (a, b=2; c=3, d=4) -> a+b+c+d)(1; d=10) === 16
@test (@splitcombine2 ((a, b)::Tuple{Int,Int} -> a + b))((1, 2)) == 3
@test (@splitcombine2 ((a::T) where {T}) -> T)([]) === Vector{Any}
@test (@splitcombine2 ((x::T, y::Vector{U}) where T <: U where U) -> (T, U))(1, Number[2.0]) ==
(Int, Number)
@test (@splitcombine2 () -> @zeroarg)() == 1
@test (@splitcombine2 () -> @onearg 1)() == 2
@test (@splitcombine2 function (x) x + 2 end)(10) === 12
@test (@splitcombine2 function (a::T) where {T} T end)([]) === Vector{Any}
@test (@splitcombine2 function (x::T, y::Vector{U}) where T <: U where U
(T, U)
end)(1, Number[2.0]) == (Int, Number)
end

@testset "combinestructdef, splitstructdef" begin
ex = :(struct S end)
@test ex |> splitstructdef |> combinestructdef |> Base.remove_linenums! ==
Expand Down