@@ -27,25 +27,42 @@ function testcat(f, args::Tuple, type, kwargs=NamedTuple())
27
27
@test value (x) == f (args... ; kwargs... )
28
28
else
29
29
@assert length (args) == 2
30
- x = f (track (args[1 ]), args[2 ]; kwargs... )
31
- @test x isa type
32
- @test value (x) == f (args... ; kwargs... )
33
30
34
- x = f (args[1 ], track (args[2 ]); kwargs... )
35
- @test x isa type
36
- @test value (x) == f (args... ; kwargs... )
31
+ broken = f == hcat && (args[2 ] isa AbstractMatrix)
32
+ if broken && VERSION >= v " 1.4"
33
+ @test_broken f (track (args[1 ]), args[2 ]; kwargs... ) isa type
34
+ @test_broken value (f (track (args[1 ]), args[2 ]; kwargs... )) == f (args... ; kwargs... )
35
+ else
36
+ @test f (track (args[1 ]), args[2 ]; kwargs... ) isa type
37
+ @test value (f (track (args[1 ]), args[2 ]; kwargs... )) == f (args... ; kwargs... )
38
+ end
39
+
40
+ broken = f == hcat && (args[1 ] isa AbstractMatrix)
41
+ if broken && VERSION >= v " 1.4"
42
+ @test_broken f (args[1 ], track (args[2 ]); kwargs... ) isa type
43
+ @test_broken value (f (args[1 ], track (args[2 ]); kwargs... )) == f (args... ; kwargs... )
44
+ else
45
+ @test f (args[1 ], track (args[2 ]); kwargs... ) isa type
46
+ @test value (f (args[1 ], track (args[2 ]); kwargs... )) == f (args... ; kwargs... )
47
+ end
37
48
end
38
49
39
50
args = (args... , args... )
40
- x = f (track .(args)... ; kwargs... )
41
- @test x isa type
42
- @test value (x) == f (args... ; kwargs... )
43
-
44
51
sizes = size .(args)
52
+ broken = (f in (vcat, hcat) && (args[2 ] isa AbstractArray))
53
+ if broken && VERSION >= v " 1.4"
54
+ @test_broken f (track .(args)... ; kwargs... ) isa type
55
+ @test_broken value (f (track .(args)... ; kwargs... )) == f (args... ; kwargs... )
56
+ else
57
+ @test f (track .(args)... ; kwargs... ) isa type
58
+ @test value (f (track .(args)... ; kwargs... )) == f (args... ; kwargs... )
59
+ end
60
+
45
61
F = vecx -> sum (f (unpack (sizes, vecx)... ; kwargs... ))
46
62
X = pack (args)
47
63
@test ForwardDiff. gradient (F, X) == gradient (F, X)
48
64
end
65
+
49
66
function pack (xs)
50
67
return mapreduce (vcat, xs) do x
51
68
x isa Number ? x : vec (x)
0 commit comments