diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl
index a5edd6cd5..11f0c202b 100644
--- a/src/rulesets/LinearAlgebra/dense.jl
+++ b/src/rulesets/LinearAlgebra/dense.jl
@@ -394,3 +394,94 @@ function rrule(
     end
     return Ω, lyap_pullback
 end
+
+#####
+##### `kron`
+#####
+
+# NOTE(review): the reverse rules call `eachslice` with a tuple of dimensions; the version
+# bound presumably matches the Julia build that introduced that form — confirm.
+@static if VERSION ≥ v"1.9.0-DEV.1267"
+    # Forward rule: product-rule sum `kron(Δx, y) + kron(x, Δy)`. Each term goes through
+    # `_kron_tangent` so that an `AbstractZero` input tangent (e.g. `ZeroTangent()`)
+    # short-circuits instead of being passed to `kron`.
+    function frule(
+        (_, Δx, Δy),
+        ::typeof(kron),
+        x::AbstractVecOrMat{<:Number},
+        y::AbstractVecOrMat{<:Number},
+    )
+        return kron(x, y), _kron_tangent(Δx, y) + _kron_tangent(x, Δy)
+    end
+
+    # One term of the forward tangent. The `AbstractZero` methods return the zero tangent
+    # unchanged; every other combination falls through to `kron`.
+    _kron_tangent(a, b) = kron(a, b)
+    _kron_tangent(Δ::AbstractZero, ::AbstractVecOrMat) = Δ
+    _kron_tangent(::AbstractVecOrMat, Δ::AbstractZero) = Δ
+
+    # Reverse rule for `kron(::AbstractVector, ::AbstractVector)`.
+    function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number})
+        project_x = ProjectTo(x)
+        project_y = ProjectTo(y)
+        function kron_pullback(z̄)
+            # Cotangent reshaped to `(length(y), length(x))`: axis 1 pairs with `y`,
+            # axis 2 with `x`.
+            dz = reshape(unthunk(z̄), length(y), length(x))
+            x̄ = @thunk(project_x(conj.(dz' * y)))
+            ȳ = @thunk(project_y(dz * conj.(x)))
+            return NoTangent(), x̄, ȳ
+        end
+        return kron(x, y), kron_pullback
+    end
+
+    # Reverse rule for `kron(::AbstractMatrix, ::AbstractVector)`.
+    function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number})
+        project_x = ProjectTo(x)
+        project_y = ProjectTo(y)
+        function kron_pullback(z̄)
+            # Cotangent reshaped to `(length(y), size(x)...)`. `x̄` is `dot(y, slice)` per
+            # slice selected by `dims = (2, 3)`; `ȳ` is `dot(x, slice)` per slice of axis 1.
+            dz = reshape(unthunk(z̄), length(y), size(x)...)
+            x̄ = @thunk(project_x(dot.(Ref(y), eachslice(dz; dims = (2, 3)))))
+            ȳ = @thunk(project_y(dot.(Ref(x), eachslice(dz; dims = 1))))
+            return NoTangent(), x̄, ȳ
+        end
+        return kron(x, y), kron_pullback
+    end
+
+    # Reverse rule for `kron(::AbstractVector, ::AbstractMatrix)`.
+    function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number})
+        project_x = ProjectTo(x)
+        project_y = ProjectTo(y)
+        function kron_pullback(z̄)
+            # Cotangent reshaped to `(size(y, 1), length(x), size(y, 2))`. `x̄` is
+            # `dot(y, slice)` per slice of axis 2; `ȳ` uses the slices of `dims = (1, 3)`.
+            dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2))
+            x̄ = @thunk(project_x(dot.(Ref(y), eachslice(dz; dims = 2))))
+            ȳ = @thunk(project_y(dot.(Ref(x), eachslice(dz; dims = (1, 3)))))
+            return NoTangent(), x̄, ȳ
+        end
+        return kron(x, y), kron_pullback
+    end
+
+    # Reverse rule for `kron(::AbstractMatrix, ::AbstractMatrix)`.
+    function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number})
+        project_x = ProjectTo(x)
+        project_y = ProjectTo(y)
+        function kron_pullback(z̄)
+            # Cotangent reshaped to `(size(y, 1), size(x, 1), size(y, 2), size(x, 2))`:
+            # axes 1 and 3 are sized like `y`, axes 2 and 4 like `x`.
+            dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2))
+            x̄ = @thunk(project_x(_dot_collect.(Ref(y), eachslice(dz; dims = (2, 4)))))
+            ȳ = @thunk(project_y(_dot_collect.(Ref(x), eachslice(dz; dims = (1, 3)))))
+            return NoTangent(), x̄, ȳ
+        end
+        return kron(x, y), kron_pullback
+    end
+
+    # `dot` of a matrix with a slice; for a `Diagonal` first argument the slice is
+    # materialized with `collect` first. NOTE(review): the reason is not visible here.
+    _dot_collect(A::AbstractMatrix, B::SubArray) = dot(A, B)
+    _dot_collect(A::Diagonal, B::SubArray) = dot(A, collect(B))
+end
diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl
index 5f5efa8d2..c145dfdbf 100644
--- a/test/rulesets/LinearAlgebra/dense.jl
+++ b/test/rulesets/LinearAlgebra/dense.jl
@@ -159,4 +159,36 @@
             test_rrule(lyap, A, C)
         end
     end
+    # Same version bound as the `kron` rule definitions, so the rules are exercised on
+    # every build where they are defined.
+    VERSION ≥ v"1.9.0-DEV.1267" && @testset "kron" begin
+        @testset "AbstractVecOrMat{$T1}, AbstractVecOrMat{$T2}" for T1 in (Float64, ComplexF64), T2 in (Float64, ComplexF64)
+            @testset "frule" begin
+                test_frule(kron, randn(T1, 2), randn(T2, 3))
+                test_frule(kron, randn(T1, 2, 3), randn(T2, 5))
+                test_frule(kron, randn(T1, 2), randn(T2, 3, 5))
+                test_frule(kron, randn(T1, 2, 3), randn(T2, 5, 7))
+            end
+            @testset "frule with zero tangents" begin
+                x, y = randn(T1, 2, 3), randn(T2, 5)
+                Δx, Δy = randn(T1, 2, 3), randn(T2, 5)
+                @test frule((NoTangent(), Δx, ZeroTangent()), kron, x, y)[2] == kron(Δx, y)
+                @test frule((NoTangent(), ZeroTangent(), Δy), kron, x, y)[2] == kron(x, Δy)
+                both_zero = (NoTangent(), ZeroTangent(), ZeroTangent())
+                @test frule(both_zero, kron, x, y)[2] isa AbstractZero
+            end
+            @testset "rrule" begin
+                test_rrule(kron, randn(T1, 2), randn(T2, 3))
+
+                test_rrule(kron, Diagonal(randn(T1, 2)), randn(T2, 3))
+                test_rrule(kron, randn(T1, 2, 3), randn(T2, 5))
+
+                test_rrule(kron, randn(T1, 2), randn(T2, 3, 5))
+                test_rrule(kron, randn(T1, 2), Diagonal(randn(T2, 3)))
+
+                test_rrule(kron, randn(T1, 2, 3), randn(T2, 5, 7))
+                test_rrule(kron, Diagonal(randn(T1, 2)), Diagonal(randn(T2, 3)))
+            end
+        end
+    end
 end