Skip to content

Commit 8cc9737

Browse files
committed
tests pass locally
1 parent af56460 commit 8cc9737

File tree

2 files changed

+47
-66
lines changed

2 files changed

+47
-66
lines changed

src/rdivl.jl

Lines changed: 40 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -273,34 +273,18 @@ end
273273
end
274274
@generated function ldivu_solve_W_u!(
275275
spa,
276-
spu,
276+
spl,
277277
n,
278278
::StaticInt{W},
279279
::StaticInt{U},
280280
::Val{UNIT}
281281
) where {W,U,UNIT}
282282
z = static(0)
283-
# B_{n,m} = (A_{n,m} - \sum_{i=n+1}^N U_{n,i}B_{i,m})/U_{n,n}
283+
# Actually a row-major rdivl
284+
# B_{m,n} = (A_{m,n} - \sum_{i=n+1}^N B_{m,i}L_{i,n})/L_{n,n}
284285
Aind = Unroll{1,1,W,2,W,zero(UInt),1}(Unroll{2,W,U,2,W,zero(UInt),1}((z, z)))
285286
q = quote
286287
# $(Expr(:meta, :inline))
287-
# C = U \ A; U * C = A
288-
# A_{i,j} = U_{i,i}*C_{i,j} + \sum_{k=i+1}^{N}U_{i,k}C_{k,j}
289-
# C_{i,j} = U_{i,i} \ (A_{i,j} - \sum_{k=i+1}^{N}U_{i,k}C_{k,j})
290-
# The inputs here are transposed, as the library was formulated in terms of `rdiv!`,
291-
# so we have
292-
# C_{j,i} = (A_{j,i} - \sum_{k=i+1}^{N}C_{j,k}U_{k,i}) / L_{i,i}
293-
# This solves for the block: C_{j+[0,W],i+[0,W*U)}
294-
# This can be viewed as `U` blocks that are each `W`x`W`
295-
# E.g. U=3, rough alg:
296-
# r=[0,W); c=[0,WU)
297-
# X = A_{j+r,i+c} - \sum_{k=1}^{i-1}C_{j+r,k}*U_{k,i+c}
298-
# C_{j+r,i+r} = X[:, r] / U_{i+r,i+r}
299-
# C_{j+r,i+W+r} = (X[:, W+r] - C_{j+r,i+r}*U_{i+r,i+W+r}) / U_{i+W+r,i+W+r}
300-
# C_{j+r,i+2W+r} = (X[:, 2W+r] - C_{j+r,i+r}*U_{i+r,i+2W+r} - C_{j+r,i+W+r}*U_{i+W+r,i+2W+r}) / U_{i+2W+r,i+2W+r}
301-
#
302-
# outer unroll are `W` rows
303-
# Inner unroll are `W*U` columns (U simd vecs)
304288
#
305289
A11 = getfield(vload(spa, $Aind), :data)
306290
# The `W` rows
@@ -310,8 +294,8 @@ end
310294
# Each iter:
311295
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
312296
for nk ∈ SafeCloseOpen(n) # nmuladd
313-
nkw = nk + $W
314-
U_ki = vload(spu, $(Unroll{2,W,U,2,W,zero(UInt),1})((nkw, $z)))
297+
nkw = nk + $(W * U)
298+
U_ki = vload(spl, $(Unroll{2,W,U,2,W,zero(UInt),1})((nkw, $z)))
315299
Base.Cartesian.@nexprs $W c ->
316300
A11_c = vfnmadd_fast(U_ki, vload(spa, (static(c - 1), nkw)), A11_c)
317301
end
@@ -326,6 +310,7 @@ end
326310
Xu = Vector{Symbol}(undef, W)
327311
Csym = Vector{Symbol}(undef, U)
328312
for u = 1:U
313+
# X_u are future
329314
X_u = Symbol(:X_, u)
330315
push!(
331316
q.args,
@@ -341,18 +326,19 @@ end
341326
)
342327
)
343328
)
329+
# push!(q.args, :(println($X_u)))
344330
for c = 1:W
345331
X_u_c = Xu[c] = Symbol(:X_, u, :_, c)
346332
push!(q.args, Expr(:(=), X_u_c, Expr(:call, getfield, X_u, c)))
347333
end
348334
# take A[(U-u+1)*W,u*W), [0,W)]
349-
for j = 1:u-1
350-
for k = 1:W
351-
for c = 1:W
352-
urow = ((W - k) + ((j - 1) * W))
335+
for j = 1:u-1 # iter over all blocks ordered after
336+
for k = 1:W # reduction dimension, reverse order
337+
for c = 1:W # columns of C
338+
urow = ((W - k) + ((U - j) * W))
353339
ucol = ((c - 1) + ((U - u) * W))
354-
push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
355-
Uexpr = :(vload(spu, ($urow, $ucol)))
340+
# push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
341+
Uexpr = :(vload(spl, ($urow, $ucol)))
356342
X_u_c = Xu[c]
357343
C_j_k = Symbol(:C_, j, :_, W + 1 - k)
358344
Xucexpr = Expr(:call, vfnmadd_fast, C_j_k, Uexpr, X_u_c)
@@ -361,7 +347,7 @@ push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
361347
end
362348
end
363349
o = (U - u) * W
364-
sp = Expr(:call, gesp, :spu, (o, o))
350+
sp = Expr(:call, gesp, :spl, (o, o))
365351
Xut = Expr(:tuple)
366352
for c = 1:W
367353
push!(Xut.args, Xu[c])
@@ -379,6 +365,7 @@ push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
379365
end
380366
end
381367
for u = 1:U
368+
# u = 1 is last, first processed (reverse order)
382369
ui = Unroll{2,1,W,1,W,zero(UInt),1}((z, (U - u) * W))
383370
C_u = Csym[u]
384371
push!(q.args, :(vstore!(spa, $C_u, $ui)))
@@ -387,7 +374,7 @@ push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
387374
end
388375
@generated function ldivu_solve_W!(
389376
spa,
390-
spu,
377+
spl,
391378
n,
392379
::StaticInt{W},
393380
::Val{UNIT},
@@ -425,7 +412,7 @@ end
425412
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
426413
for nk ∈ SafeCloseOpen(n) # nmuladd
427414
nkw = nk + $W
428-
U_ki = vload(spu, (nkw, $(MM{W}(z))))
415+
U_ki = vload(spl, (nkw, $(MM{W}(z))))
429416
Base.Cartesian.@nexprs $R r ->
430417
A11_r = vfnmadd_fast(U_ki, vload(spa, (static(r - 1), nkw)), A11_r)
431418
end
@@ -444,7 +431,7 @@ end
444431
# We then have column-major multiplies
445432
# take A[(u-1)*W,u*W), [0,W)]
446433
X = VectorizationBase.transpose_vecunroll(VecUnroll($t))
447-
C_u = solve_AL(X, spu, $(Val(UNIT)))
434+
C_u = solve_AL(X, spl, $(Val(UNIT)))
448435
end
449436
push!(q.args, q2)
450437
q3 = if R == Wpad
@@ -465,7 +452,7 @@ end
465452

466453
@generated function _ldivu_remainder!(
467454
spa,
468-
spu,
455+
spl,
469456
N,
470457
Nr,
471458
::StaticInt{W},
@@ -486,7 +473,7 @@ end
486473
vlxj = :(xj = $vlxj)
487474
else
488475
vlxj = quote
489-
xj = $div($vlxj, vload(spu, (j, j)))
476+
xj = $div($vlxj, vload(spl, (j, j)))
490477
vstore!(spa, xj, ($z, j))
491478
end
492479
end
@@ -500,7 +487,7 @@ end
500487
while i > 0
501488
i -= 1
502489
xi = vload(spa, ($z, i))
503-
Uji = vload(spu, (j, i))
490+
Uji = vload(spl, (j, i))
504491
vstore!(spa, $sub(xi, $mul(xj, Uji)), ($z, i))
505492
end
506493
j == 0 && break
@@ -514,19 +501,17 @@ end
514501
mask = $(getfield(_mask(WS, r), :u) % UInt32)
515502
n = N - Nr
516503
if Nr > 0
517-
@show pointer(spa), pointer(spu), n, Nr
518-
let t = (gesp(spa, ($z, n)), gesp(spu, (n, n))), ft = flatten_to_tup(t)
504+
let t = (gesp(spa, ($z, n)), gesp(spl, (n, n))), ft = flatten_to_tup(t)
519505
BdivL_small_kern!(Nr, mask, $WS, $(Val(UNIT)), typeof(t), ft...)
520506
end
521507
end
522508
# non-U, order first as matmul kern is smaller than optimal
523509
while n != 0
524510
k = N - n
525511
n -= W
526-
@show pointer(spa), pointer(spu), k, n
527512
ldivu_solve_W!(
528513
gesp(spa, ($z, n)),
529-
gesp(spu, (n, n)),
514+
gesp(spl, (n, n)),
530515
k,
531516
$WS,
532517
Val(UNIT),
@@ -551,25 +536,25 @@ end
551536
if W == 2
552537
quote
553538
$(Expr(:meta, :inline))
554-
spa, spu = reassemble_tup(Args, args)
555-
_ldivu_remainder!(spa, spu, N, Nrr, Nru, $WS, $(Val(UNIT)), $(static(1)))
539+
spa, spl = reassemble_tup(Args, args)
540+
_ldivu_remainder!(spa, spl, N, Nrr, Nru, $WS, $(Val(UNIT)), $(static(1)))
556541
nothing
557542
end
558543
elseif W == 8
559544
s8 = StaticInt(8)
560545
quote
561546
# $(Expr(:meta, :inline))
562-
spa, spu = reassemble_tup(Args, args)
547+
spa, spl = reassemble_tup(Args, args)
563548
if m == M - 1
564-
_ldivu_remainder!(spa, spu, N, Nr, $s8, $(Val(UNIT)), $(StaticInt(1)))
549+
_ldivu_remainder!(spa, spl, N, Nr, $s8, $(Val(UNIT)), $(StaticInt(1)))
565550
else
566551
if m == M - 2
567-
_ldivu_remainder!(spa, spu, N, Nr, $s8, $(Val(UNIT)), $(StaticInt(2)))
552+
_ldivu_remainder!(spa, spl, N, Nr, $s8, $(Val(UNIT)), $(StaticInt(2)))
568553
else
569554
if m == M - 3
570555
_ldivu_remainder!(
571556
spa,
572-
spu,
557+
spl,
573558
N,
574559
Nr,
575560
$s8,
@@ -580,7 +565,7 @@ end
580565
if m == M - 4
581566
_ldivu_remainder!(
582567
spa,
583-
spu,
568+
spl,
584569
N,
585570
Nr,
586571
$s8,
@@ -591,7 +576,7 @@ end
591576
if m == M - 5
592577
_ldivu_remainder!(
593578
spa,
594-
spu,
579+
spl,
595580
N,
596581
Nr,
597582
$s8,
@@ -602,7 +587,7 @@ end
602587
if m == M - 6
603588
_ldivu_remainder!(
604589
spa,
605-
spu,
590+
spl,
606591
N,
607592
Nr,
608593
$s8,
@@ -612,7 +597,7 @@ end
612597
else
613598
_ldivu_remainder!(
614599
spa,
615-
spu,
600+
spl,
616601
N,
617602
Nr,
618603
$s8,
@@ -630,9 +615,9 @@ end
630615
else
631616
quote
632617
# $(Expr(:meta, :inline))
633-
spa, spu = reassemble_tup(Args, args)
618+
spa, spl = reassemble_tup(Args, args)
634619
Base.Cartesian.@nif $(W - 1) w -> m == M - w w ->
635-
_ldivu_remainder!(spa, spu, N, Nr, $WS, $(Val(UNIT)), static(w))
620+
_ldivu_remainder!(spa, spl, N, Nr, $WS, $(Val(UNIT)), static(w))
636621
nothing
637622
end
638623
end
@@ -646,7 +631,7 @@ function _ldivu_L!(
646631
args::Vararg{Any,K}
647632
) where {UNIT,Args,K}
648633
# B_{n,m} = (A_{n,m} - \sum_{i=n+1}^N U_{n,i}B_{i,m})/U_{n,n}
649-
spa, spu = reassemble_tup(Args, args)
634+
spa, spl = reassemble_tup(Args, args)
650635
T = eltype(spa)
651636
WS = pick_vector_width(T)
652637
W = Int(WS)
@@ -664,8 +649,7 @@ function _ldivu_L!(
664649
while m < M - WS + 1
665650
n::Int = nstart
666651
if Nrr > 0
667-
let t = (gesp(spa, (z, n)), gesp(spu, (n, n))), ft = flatten_to_tup(t)
668-
@show 0, n
652+
let t = (gesp(spa, (z, n)), gesp(spl, (n, n))), ft = flatten_to_tup(t)
669653
compute && BdivL_small_kern_u!(
670654
Nrr,
671655
StaticInt(1),
@@ -680,10 +664,9 @@ function _ldivu_L!(
680664
for _ ∈ 1:Ndr
681665
k = N - n
682666
n -= W
683-
@show 1, n, k
684667
compute && ldivu_solve_W!(
685668
gesp(spa, (z, n)),
686-
gesp(spu, (n, n)),
669+
gesp(spl, (n, n)),
687670
k,
688671
WS,
689672
Val(UNIT),
@@ -694,10 +677,9 @@ function _ldivu_L!(
694677
while n != 0
695678
k = N - n
696679
n -= Int(WU)
697-
@show 2, n, k
698680
compute && ldivu_solve_W_u!(
699681
gesp(spa, (z, n)),
700-
gesp(spu, (n, n)),
682+
gesp(spl, (n, n)),
701683
k,
702684
WS,
703685
UF,
@@ -709,8 +691,7 @@ function _ldivu_L!(
709691
end
710692
# remainder on `m`
711693
if m < M
712-
let tup = (spa, spu), ftup = flatten_to_tup(tup)
713-
@show m, Nrr, M
694+
let tup = (spa, spl), ftup = flatten_to_tup(tup)
714695
compute &&
715696
ldivu_remainder!(M, N, m, Nrr, WS, Val(UNIT), typeof(tup), ftup...)
716697
end

test/runtests.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function test_solve(::Type{T}) where {T}
1818
for n ∈ 1:maxN
1919
@show n
2020
for m ∈ max(1, n - 10):n+10
21-
@show m
21+
# @show m
2222
A = @view AA[17:16+m, 17:16+n]
2323
res = @view RR[17:16+m, 17:16+n]
2424
B = @view BB[17:16+n, 17:16+n]
@@ -30,10 +30,10 @@ function test_solve(::Type{T}) where {T}
3030
for C in (
3131
UpperTriangular(B),
3232
UnitUpperTriangular(B),
33-
# LowerTriangular(B),
34-
# UnitLowerTriangular(B)
33+
LowerTriangular(B),
34+
UnitLowerTriangular(B)
3535
)
36-
@show typeof(C)
36+
# @show typeof(C)
3737
@test TriangularSolve.rdiv!(res, A, C) * C ≈ A
3838
check_box_for_nans(RR, m, n)
3939
@test TriangularSolve.rdiv!(res, A, C, Val(false)) * C ≈ A
@@ -48,12 +48,12 @@ function test_solve(::Type{T}) where {T}
4848
A .= rand.(T)
4949

5050
for C in (
51-
# UpperTriangular(B),
52-
# UnitUpperTriangular(B),
51+
UpperTriangular(B),
52+
UnitUpperTriangular(B),
5353
LowerTriangular(B),
5454
UnitLowerTriangular(B)
5555
)
56-
@show typeof(C)
56+
# @show typeof(C)
5757
@test C * TriangularSolve.ldiv!(res, C, A) ≈ A
5858
check_box_for_nans(RR, n, m)
5959
@test C * TriangularSolve.ldiv!(res, C, A, Val(false)) ≈ A

0 commit comments

Comments
 (0)