diff --git a/src/utils.jl b/src/utils.jl index a562c48..c34d4b2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,5 @@ export @esc, isexpr, isline, iscall, rmlines, unblock, block, inexpr, namify, isdef, - longdef, shortdef, @expand, makeif, prettify, combinedef, splitdef, splitarg, combinearg + longdef, shortdef, @expand, makeif, prettify, combinedef, splitdef, splitarg, combinearg, splitargdef, combineargdef """ assoc!(d, k, v) @@ -403,7 +403,8 @@ end """ combinearg(arg_name, arg_type, is_splat, default) -`combinearg` is the inverse of [`splitarg`](@ref). +`combinearg` is the inverse of [`splitarg`](@ref). `combinearg` may oneday +be deprecated in favor of `combineargdef` """ function combinearg(arg_name, arg_type, is_splat, default) @assert arg_name !== nothing || arg_type !== nothing @@ -430,6 +431,8 @@ julia> map(splitarg, (:(f(a=2, x::Int=nothing, y, args...))).args[2:end]) (:args, :Any, true, nothing) ``` +`splitarg` may oneday be deprecated in favor of `splitargdef` + See also: [`combinearg`](@ref) """ function splitarg(arg_expr) @@ -450,6 +453,63 @@ function splitarg(arg_expr) return (arg_name, arg_type, is_splat, default) end +""" + combineargdef(dict) + +`combineargdef` is the inverse of [`splitargdef`](@ref). +""" +function combineargdef(dict) + @assert haskey(dict, :name) || haskey(dict, :type) + if haskey(dict, :name) + a = dict[:name] + a = haskey(dict, :type) ? :($a::$(dict[:type])) : a + else + a = :(::$(dict[:type])) + end + a = dict[:is_splat] ? Expr(:..., a) : a + return haskey(dict, :default) ? Expr(:kw, a, dict[:default]) : a +end + +""" + splitargdef(arg) + +Match function arguments (whther from a definition or a function call) such as +`x::Int=2` and return `Dict(:name=>..., :is_splat=>..., etc.)`. The definition can be rebuilt by +calling `MacroTools.combineargdef(dict)`. +For example: + +```julia +julia> map(splitarg, (:(f(a=2, x::Int=nothing, y, args...))).args[2:end]) +4-element Array{Tuple{Symbol,Symbol,Bool,Any},1}: + Dict(:name=>:a, :is_splat=false, :default=>2) + Dict(:name=>:x, :type=>:Int, :is_splat=>false) + Dict(:name=>:y, :is_splat=>false) + Dict(:name=>:args, :is_splat=>true) +``` + +See also: [`combineargdef`](@ref) +""" +function splitargdef(arg_expr) + dict=Dict() + if @capture(arg_expr, arg_expr2_ = default_) + dict[:default] = default + else + arg_expr2 = arg_expr + end + is_splat = @capture(arg_expr2, arg_expr3_...) + is_splat || (arg_expr3 = arg_expr2) + dict[:is_splat] = is_splat + if @capture(arg_expr3, ::T_) + dict[:type]=T + elseif @capture(arg_expr3, name_::T_) + dict[:name]=name + dict[:type]=T + else + dict[:name] = arg_expr3 + end + return dict +end + function flatten1(ex) isexpr(ex, :block) || return ex diff --git a/test/split.jl b/test/split.jl index 9fc7801..ec827dc 100644 --- a/test/split.jl +++ b/test/split.jl @@ -11,6 +11,13 @@ macro splitcombine(fundef) # should be a no-op esc(MacroTools.combinedef(dict)) end +macro splitcombine2(fundef) # should be a no-op + dict = splitdef(fundef) + dict[:args] = map(arg->combineargdef(splitargdef(arg)), dict[:args]) + dict[:kwargs] = map(arg->combineargdef(splitargdef(arg)), dict[:kwargs]) + esc(MacroTools.combinedef(dict)) +end + # Macros for testing that splitcombine doesn't break # macrocalls in bodies macro zeroarg() @@ -29,14 +36,22 @@ let @test map(splitarg, (:(f(a=2, x::Int=nothing, y::Any, args...))).args[2:end]) == [(:a, :Any, false, 2), (:x, :Int, false, :nothing), (:y, :Any, false, nothing), (:args, :Any, true, nothing)] + @test map(splitargdef, (:(f(a=2, x::Int=nothing, y::Any, args...))).args[2:end]) == + [Dict(:name=>:a, :is_splat=>false, :default=>2), Dict(:name=>:x, :type=>:Int, :is_splat=>false, :default=>:nothing), + Dict(:name=>:y, :type=>:Any, :is_splat=>false), Dict(:name=>:args, :is_splat=>true)] @test splitarg(:(::Int)) == (nothing, :Int, false, nothing) + @test splitargdef(:(::Int)) == Dict(:type=>:Int, :is_splat=>false) kwargs = splitdef(:(f(; a::Int = 1, b...) = 1))[:kwargs] @test map(splitarg, kwargs) == [(:a, :Int, false, 1), (:b, :Any, true, nothing)] + @test map(splitargdef, kwargs) == + [Dict(:name=>:a, :type=>:Int, :is_splat=>false, :default=>1), Dict(:name=>:b, :is_splat=>true)] args = splitdef(:(f(a::Int = 1) = 1))[:args] @test map(splitarg, args) == [(:a, :Int, false, 1)] + @test map(splitargdef, args) == [Dict(:name=>:a, :type=>:Int, :is_splat=>false, :default=>1)] args = splitdef(:(f(a::Int ... = 1) = 1))[:args] @test map(splitarg, args) == [(:a, :Int, true, 1)] # issue 165 + @test map(splitargdef, args) == [Dict(:name=>:a, :type=>:Int, :is_splat=>true, :default=>1)] # issue 165 @splitcombine foo(x) = x+2 @test foo(10) == 12 @@ -85,6 +100,54 @@ let end)(1, Number[2.0]) == (Int, Number) end +let + @splitcombine2 foo(x) = x+2 + @test foo(10) == 12 + @splitcombine2 add(a, b=2; c=3, d=4)::Float64 = a+b+c+d + @test add(1; d=10) === 16.0 + @splitcombine2 fparam(a::T) where {T} = T + @test fparam([]) == Vector{Any} + struct Orange end + @splitcombine2 (::Orange)(x) = x+2 + @test Orange()(10) == 12 + @splitcombine2 fwhere(a::T) where T = T + @test fwhere(10) == Int + @splitcombine2 manywhere(x::T, y::Vector{U}) where T <: U where U = (T, U) + @test manywhere(1, Number[2.0]) == (Int, Number) + @splitcombine2 fmacro0() = @zeroarg + @test fmacro0() == 1 + @splitcombine2 fmacro1() = @onearg 1 + @test fmacro1() == 2 + + @splitcombine2 bar(; a::Int = 1, b...) = 2 + @test bar(a=3, x = 1, y = 2) == 2 + @splitcombine2 qux(a::Int... = 0) = sum(a) + @test qux(1, 2, 3) == 6 + @test qux() == 0 + + struct Foo{A, B} + a::A + b::B + end + # Parametric outer constructor + @splitcombine2 Foo{A}(a::A) where A = Foo{A, A}(a,a) + @test Foo{Int}(2) == Foo{Int, Int}(2, 2) + + @test (@splitcombine2 x -> x + 2)(10) === 12 + @test (@splitcombine2 (a, b=2; c=3, d=4) -> a+b+c+d)(1; d=10) === 16 + @test (@splitcombine2 ((a, b)::Tuple{Int,Int} -> a + b))((1, 2)) == 3 + @test (@splitcombine2 ((a::T) where {T}) -> T)([]) === Vector{Any} + @test (@splitcombine2 ((x::T, y::Vector{U}) where T <: U where U) -> (T, U))(1, Number[2.0]) == + (Int, Number) + @test (@splitcombine2 () -> @zeroarg)() == 1 + @test (@splitcombine2 () -> @onearg 1)() == 2 + @test (@splitcombine2 function (x) x + 2 end)(10) === 12 + @test (@splitcombine2 function (a::T) where {T} T end)([]) === Vector{Any} + @test (@splitcombine2 function (x::T, y::Vector{U}) where T <: U where U + (T, U) + end)(1, Number[2.0]) == (Int, Number) +end + @testset "combinestructdef, splitstructdef" begin ex = :(struct S end) @test ex |> splitstructdef |> combinestructdef |> Base.remove_linenums! ==