@@ -608,12 +608,22 @@ for f in [
608608 end
609609end
610610
611- for f in (:eig, :eigh, :lq, :qr, :polar, :svd)
612- ff = Symbol(" default_" , f, " _algorithm" )
611+ for f in [
612+ :default_eig_algorithm,
613+ :default_eigh_algorithm,
614+ :default_lq_algorithm,
615+ :default_qr_algorithm,
616+ :default_polar_algorithm,
617+ :default_svd_algorithm,
618+ ]
613619 @eval begin
614- function MatrixAlgebraKit.$ ff(A:: Type{<:KroneckerMatrix} ; kwargs... )
620+ function MatrixAlgebraKit.$ f(
621+ A:: Type{<:KroneckerMatrix} ; kwargs1= (;), kwargs2= (;), kwargs...
622+ )
615623 A1, A2 = argument_types(A)
616- return KroneckerAlgorithm($ ff(A1; kwargs... ), $ ff(A2; kwargs... ))
624+ return KroneckerAlgorithm(
625+ $ f(A1; kwargs... , kwargs1... ), $ f(A2; kwargs... , kwargs2... )
626+ )
617627 end
618628 end
619629end
@@ -631,7 +641,7 @@ function MatrixAlgebraKit.default_algorithm(
631641 return default_qr_algorithm(A; kwargs... )
632642end
633643
634- for f in (
644+ for f in [
635645 :eig_full!,
636646 :eigh_full!,
637647 :qr_compact!,
@@ -642,22 +652,24 @@ for f in (
642652 :right_polar!,
643653 :svd_compact!,
644654 :svd_full!,
645- )
655+ ]
646656 @eval begin
647657 function MatrixAlgebraKit. initialize_output(
648658 :: typeof ($ f), a:: KroneckerMatrix , alg:: KroneckerAlgorithm
649659 )
650660 return initialize_output($ f, a. a, alg. a) .⊗ initialize_output($ f, a. b, alg. b)
651661 end
652- function MatrixAlgebraKit.$ f(a:: KroneckerMatrix , F, alg:: KroneckerAlgorithm ; kwargs... )
653- $ f(a. a, Base. Fix2(getfield, :a). (F), alg. a; kwargs... )
654- $ f(a. b, Base. Fix2(getfield, :b). (F), alg. b; kwargs... )
662+ function MatrixAlgebraKit.$ f(
663+ a:: KroneckerMatrix , F, alg:: KroneckerAlgorithm ; kwargs1= (;), kwargs2= (;), kwargs...
664+ )
665+ $ f(a. a, Base. Fix2(getfield, :a). (F), alg. a; kwargs... , kwargs1... )
666+ $ f(a. b, Base. Fix2(getfield, :b). (F), alg. b; kwargs... , kwargs2... )
655667 return F
656668 end
657669 end
658670end
659671
660- for f in ( :eig_vals!, :eigh_vals!, :svd_vals!)
672+ for f in [ :eig_vals!, :eigh_vals!, :svd_vals!]
661673 @eval begin
662674 function MatrixAlgebraKit. initialize_output(
663675 :: typeof ($ f), a:: KroneckerMatrix , alg:: KroneckerAlgorithm
@@ -672,7 +684,7 @@ for f in (:eig_vals!, :eigh_vals!, :svd_vals!)
672684 end
673685end
674686
675- for f in ( :eig_trunc!, :eigh_trunc!, :svd_trunc!)
687+ for f in [ :eig_trunc!, :eigh_trunc!, :svd_trunc!]
676688 @eval begin
677689 function MatrixAlgebraKit. truncate!(
678690 :: typeof ($ f),
@@ -684,25 +696,146 @@ for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!)
684696 end
685697end
686698
687- for f in ( :left_orth!, :right_orth!)
699+ for f in [ :left_orth!, :right_orth!]
688700 @eval begin
689701 function MatrixAlgebraKit. initialize_output(:: typeof ($ f), a:: KroneckerMatrix )
690702 return initialize_output($ f, a. a) .⊗ initialize_output($ f, a. b)
691703 end
692704 end
693705end
694706
695- for f in ( :left_null!, :right_null!)
707+ for f in [ :left_null!, :right_null!]
696708 @eval begin
697709 function MatrixAlgebraKit. initialize_output(:: typeof ($ f), a:: KroneckerMatrix )
698710 return initialize_output($ f, a. a) ⊗ initialize_output($ f, a. b)
699711 end
700- function MatrixAlgebraKit.$ f(a:: KroneckerMatrix , F; kwargs... )
701- $ f(a. a, F. a; kwargs... )
702- $ f(a. b, F. b; kwargs... )
712+ function MatrixAlgebraKit.$ f(a:: KroneckerMatrix , F; kwargs1 = (;), kwargs2 = (;), kwargs... )
713+ $ f(a. a, F. a; kwargs... , kwargs1 ... )
714+ $ f(a. b, F. b; kwargs... , kwargs2 ... )
703715 return F
704716 end
705717 end
706718end
707719
720+ # ###################################################################################
721+ # Special cases for MatrixAlgebraKit factorizations of `Eye(n) ⊗ A` and
722+ # `A ⊗ Eye(n)` where `A`.
723+ # TODO : Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/34
724+ # is merged.
725+
726+ using FillArrays: SquareEye
727+ const SquareEyeKronecker{T,A<: SquareEye{T} ,B<: AbstractMatrix{T} } = KroneckerMatrix{T,A,B}
728+ const KroneckerSquareEye{T,A<: AbstractMatrix{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
729+ const SquareEyeSquareEye{T,A<: SquareEye{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
730+
731+ struct SquareEyeAlgorithm <: AbstractAlgorithm end
732+
733+ # Defined to avoid type piracy.
734+ _copy_input_squareeye(f:: F , a) where {F} = copy_input(f, a)
735+ _copy_input_squareeye(f:: F , a:: SquareEye ) where {F} = a
736+
737+ for f in [
738+ :eig_full,
739+ :eigh_full,
740+ :qr_compact,
741+ :qr_full,
742+ :left_polar,
743+ :lq_compact,
744+ :lq_full,
745+ :right_polar,
746+ :svd_compact,
747+ :svd_full,
748+ ]
749+ for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
750+ @eval begin
751+ function MatrixAlgebraKit. copy_input(:: typeof ($ f), a:: $T )
752+ return _copy_input_squareeye($ f, a. a) ⊗ _copy_input_squareeye($ f, a. b)
753+ end
754+ end
755+ end
756+ end
757+
758+ for f in [
759+ :default_eig_algorithm,
760+ :default_eigh_algorithm,
761+ :default_lq_algorithm,
762+ :default_qr_algorithm,
763+ :default_polar_algorithm,
764+ :default_svd_algorithm,
765+ ]
766+ f′ = Symbol(" _" , f, " _squareeye" )
767+ @eval begin
768+ $ f′(a) = $ f(a)
769+ $ f′(a:: Type{<:SquareEye} ) = SquareEyeAlgorithm()
770+ end
771+ for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
772+ @eval begin
773+ function MatrixAlgebraKit.$ f(A:: Type{<:$T} ; kwargs1= (;), kwargs2= (;), kwargs... )
774+ A1, A2 = argument_types(A)
775+ return KroneckerAlgorithm(
776+ $ f′(A1; kwargs... , kwargs1... ), $ f′(A2; kwargs... , kwargs2... )
777+ )
778+ end
779+ end
780+ end
781+ end
782+
783+ # Defined to avoid type piracy.
784+ _initialize_output_squareeye(f:: F , a, alg) where {F} = initialize_output(f, a, alg)
785+ for f in [
786+ :eig_full!,
787+ :eigh_full!,
788+ :qr_compact!,
789+ :qr_full!,
790+ :left_polar!,
791+ :lq_compact!,
792+ :lq_full!,
793+ :right_polar!,
794+ ]
795+ @eval begin
796+ _initialize_output_squareeye(:: typeof ($ f), a:: SquareEye , alg) = (a, a)
797+ end
798+ end
799+ for f in [:svd_compact!, :svd_full!]
800+ @eval begin
801+ _initialize_output_squareeye(:: typeof ($ f), a:: SquareEye , alg) = (a, a, a)
802+ end
803+ end
804+
805+ for f in [
806+ :eig_full!,
807+ :eigh_full!,
808+ :qr_compact!,
809+ :qr_full!,
810+ :left_polar!,
811+ :lq_compact!,
812+ :lq_full!,
813+ :right_polar!,
814+ :svd_compact!,
815+ :svd_full!,
816+ ]
817+ f′ = Symbol(" _" , f, " _squareeye" )
818+ @eval begin
819+ $ f′(a, F, alg; kwargs... ) = $ f(a, F, alg; kwargs... )
820+ $ f′(a, F, alg:: SquareEyeAlgorithm ) = F
821+ end
822+ for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
823+ @eval begin
824+ function MatrixAlgebraKit. initialize_output(
825+ :: typeof ($ f), a:: $T , alg:: KroneckerAlgorithm
826+ )
827+ return _initialize_output_squareeye($ f, a. a, alg. a) .⊗
828+ _initialize_output_squareeye($ f, a. b, alg. b)
829+ end
830+ function MatrixAlgebraKit.$ f(
831+ a:: $T , F, alg:: KroneckerAlgorithm ; kwargs1= (;), kwargs2= (;), kwargs...
832+ )
833+ $ f′(a. a, Base. Fix2(getfield, :a). (F), alg. a; kwargs... , kwargs1... )
834+ $ f′(a. b, Base. Fix2(getfield, :b). (F), alg. b; kwargs... , kwargs2... )
835+ return F
836+ end
837+ end
838+ end
839+ end
840+
708841end
0 commit comments