Skip to content

Commit 60a30a5

Browse files
committed
minor progress
1 parent 7645568 commit 60a30a5

File tree

3 files changed

+131
-30
lines changed

3 files changed

+131
-30
lines changed

src/rdivl.jl

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,13 @@ end
268268
for _ 1:Nd
269269
k = N - n
270270
n -= W
271-
rdivl_solve_W!(gesp(spa, (z, n)), gesp(spl, (n, n)), k, mask, Val(UNIT))
271+
rdivl_solve_W!(
272+
gesp(spa, (z, n)),
273+
gesp(spl, (n, n)),
274+
k,
275+
Mask{W}(mask),
276+
Val(UNIT)
277+
)
272278
end
273279
spa = gesp(spa, (WS, StaticInt(0)))
274280
m = ubm
@@ -602,6 +608,7 @@ end
602608
end
603609
end
604610
end
611+
# B_{n,m} = (A_{n,m} - \sum_{i=n+1}^N U_{n,i}B_{i,m})/U_{n,n}
605612
function _ldivu_L!(
606613
M,
607614
N,
@@ -616,10 +623,35 @@ function _ldivu_L!(
616623
W = Int(WS)
617624
UF = unroll_factor(WS)
618625
WU = UF * WS
619-
Nr = VectorizationBase.vrem(N, WS)
626+
# for ldiv, we unroll over `n`
627+
Nd, Nr = VectorizationBase.vdivrem(N, WS)
628+
z = StaticInt(0)
620629
m = 0
621630
# m, no remainder
622631
while m < M - WS + 1
632+
n = Int(Nd * W)::Int
633+
if Nr > 0
634+
let t = (gesp(spa, (n, z)), gesp(spl, (n, n))), ft = flatten_to_tup(t)
635+
BdivL_small_kern_u!(Nr, StaticInt(1), Val(UNIT), WS, typeof(t), ft...)
636+
end
637+
end
638+
for _ 1:Nd
639+
k = N - n
640+
n -= W
641+
ldivu_solve_W_u!(
642+
gesp(spa, (n, z)),
643+
gesp(spl, (n, n)),
644+
k,
645+
WS,
646+
UF,
647+
Val(UNIT)
648+
)
649+
end
650+
while n < N - (WU - 1)
651+
ldivu_solve_W_u!(spa, spl, n, WS, UF, Val(UNIT))
652+
n += WU
653+
end
654+
623655
n = Nr # non factor of W remainder
624656
if n > 0
625657
let t = (spa, spl), ft = flatten_to_tup(t)
@@ -635,7 +667,7 @@ function _ldivu_L!(
635667
n += W
636668
end
637669
m += W
638-
spa = gesp(spa, (W, StaticInt(0)))
670+
spa = gesp(spa, (W, z))
639671
end
640672
# remainder on `m`
641673
if m < M
@@ -803,8 +835,7 @@ function ldiv!(
803835
end
804836
function ldiv!(
805837
U::UnitUpperTriangular{T},
806-
A::AbstractMatrix{T},
807-
::Val{false}
838+
A::AbstractMatrix{T}
808839
) where {T<:Union{Float32,Float64}}
809840
rdivl_dispatch!(transpose(A), transpose(parent(U)), Val(true))
810841
return A
@@ -817,3 +848,72 @@ function ldiv!(
817848
rdivl_dispatch!(transpose(copyto!(C, A)), transpose(parent(U)), Val(true))
818849
return C
819850
end
851+
852+
function rdiv!(
853+
A::AbstractMatrix{T},
854+
U::LowerTriangular{T},
855+
::Val
856+
) where {T<:Union{Float32,Float64}}
857+
rdivl_dispatch!(A, parent(U), Val(false))
858+
return A
859+
end
860+
function rdiv!(
861+
C::AbstractMatrix{T},
862+
A::AbstractMatrix{T},
863+
U::LowerTriangular{T},
864+
::Val
865+
) where {T<:Union{Float32,Float64}}
866+
rdivl_dispatch!(copyto!(C, A), parent(U), Val(false))
867+
return C
868+
end
869+
function rdiv!(
870+
A::AbstractMatrix{T},
871+
U::UnitLowerTriangular{T},
872+
::Val
873+
) where {T<:Union{Float32,Float64}}
874+
rdivl_dispatch!(A, parent(U), Val(true))
875+
return A
876+
end
877+
function rdiv!(
878+
C::AbstractMatrix{T},
879+
A::AbstractMatrix{T},
880+
U::UnitLowerTriangular{T},
881+
::Val
882+
) where {T<:Union{Float32,Float64}}
883+
rdivl_dispatch!(copyto!(C, A), parent(U), Val(true))
884+
return C
885+
end
886+
function ldiv!(
887+
U::UpperTriangular{T},
888+
A::AbstractMatrix{T},
889+
::Val
890+
) where {T<:Union{Float32,Float64}}
891+
rdivl_dispatch!(transpose(A), transpose(parent(U)), Val(false))
892+
return A
893+
end
894+
function ldiv!(
895+
C::AbstractMatrix{T},
896+
U::UpperTriangular{T},
897+
A::AbstractMatrix{T},
898+
::Val
899+
) where {T<:Union{Float32,Float64}}
900+
rdivl_dispatch!(transpose(copyto!(C, A)), transpose(parent(U)), Val(false))
901+
return C
902+
end
903+
function ldiv!(
904+
U::UnitUpperTriangular{T},
905+
A::AbstractMatrix{T},
906+
::Val
907+
) where {T<:Union{Float32,Float64}}
908+
rdivl_dispatch!(transpose(A), transpose(parent(U)), Val(true))
909+
return A
910+
end
911+
function ldiv!(
912+
C::AbstractMatrix{T},
913+
U::UnitUpperTriangular{T},
914+
A::AbstractMatrix{T},
915+
::Val
916+
) where {T<:Union{Float32,Float64}}
917+
rdivl_dispatch!(transpose(copyto!(C, A)), transpose(parent(U)), Val(true))
918+
return C
919+
end

src/rdivu.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ end
256256
# So, we'll use `U = 1`, and transpose blocks
257257
# We then have column-major multiplies
258258
Base.Cartesian.@nexprs $U u -> begin
259-
# take A[(u-1)*W,u*W), [0,W)]
259+
# take A[[(u-1)*W,u*W), [0,W)]
260260
X_u = getfield(
261261
VectorizationBase.transpose_vecunroll(
262262
VecUnroll(
@@ -460,15 +460,12 @@ end
460460
mask = nomaskiter ? maxmask : finalmask
461461
n = Nr
462462
if n > 0
463-
let t = (spa, spu),
464-
ft = flatten_to_tup(t),
465-
mask = getfield(mask, :u) % UInt32
466-
463+
let t = (spa, spu), ft = flatten_to_tup(t)
467464
BdivU_small_kern!(n, mask, WS, Val(UNIT), typeof(t), ft...)
468465
end
469466
end
470467
for _ 1:Nd
471-
rdivu_solve_W!(spa, spu, n, mask, Val(UNIT))
468+
rdivu_solve_W!(spa, spu, n, Mask{W}(mask), Val(UNIT))
472469
n += W
473470
end
474471
spa = gesp(spa, (WS, StaticInt(0)))

test/runtests.jl

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,33 +26,37 @@ function test_solve(::Type{T}) where {T}
2626
B .= rand.(T)
2727
@view(B[diagind(B)]) .+= one(T)
2828

29-
@test TriangularSolve.rdiv!(res, A, UpperTriangular(B)) *
30-
UpperTriangular(B) A
31-
@test TriangularSolve.rdiv!(res, A, UnitUpperTriangular(B)) *
32-
UnitUpperTriangular(B) A
33-
@test TriangularSolve.rdiv!(res, A, UpperTriangular(B), Val(false)) *
34-
UpperTriangular(B) A
35-
@test TriangularSolve.rdiv!(res, A, UnitUpperTriangular(B), Val(false)) *
36-
UnitUpperTriangular(B) A
29+
for C in (
30+
UpperTriangular(B),
31+
UnitUpperTriangular(B),
32+
LowerTriangular(B),
33+
UnitLowerTriangular(B)
34+
)
35+
@test TriangularSolve.rdiv!(res, A, C) * C A
36+
check_box_for_nans(RR, m, n)
37+
@test TriangularSolve.rdiv!(res, A, C, Val(false)) * C A
38+
check_box_for_nans(RR, m, n)
39+
end
3740

38-
check_box_for_nans(RR, m, n)
3941
res .= NaN
4042
A .= NaN
4143

4244
A = @view AA[17:16+n, 17:16+m]
4345
res = @view RR[17:16+n, 17:16+m]
4446
A .= rand.(T)
4547

46-
@test LowerTriangular(B) *
47-
TriangularSolve.ldiv!(res, LowerTriangular(B), A) A
48-
@test UnitLowerTriangular(B) *
49-
TriangularSolve.ldiv!(res, UnitLowerTriangular(B), A) A
50-
@test LowerTriangular(B) *
51-
TriangularSolve.ldiv!(res, LowerTriangular(B), A, Val(false)) A
52-
@test UnitLowerTriangular(B) *
53-
TriangularSolve.ldiv!(res, UnitLowerTriangular(B), A, Val(false))
54-
A
55-
check_box_for_nans(RR, n, m)
48+
for C in (
49+
UpperTriangular(B),
50+
UnitUpperTriangular(B),
51+
LowerTriangular(B),
52+
UnitLowerTriangular(B)
53+
)
54+
@test C * TriangularSolve.ldiv!(res, C, A) A
55+
check_box_for_nans(RR, n, m)
56+
@test C * TriangularSolve.ldiv!(res, C, A, Val(false)) A
57+
check_box_for_nans(RR, n, m)
58+
end
59+
5660
res .= NaN
5761
A .= NaN
5862
B .= NaN

0 commit comments

Comments
 (0)