@@ -273,34 +273,18 @@ end
273
273
end
274
274
@generated function ldivu_solve_W_u! (
275
275
spa,
276
- spu ,
276
+ spl ,
277
277
n,
278
278
:: StaticInt{W} ,
279
279
:: StaticInt{U} ,
280
280
:: Val{UNIT}
281
281
) where {W,U,UNIT}
282
282
z = static (0 )
283
- # B_{n,m} = (A_{n,m} - \sum_{i=n+1}^N U_{n,i}B_{i,m})/U_{n,n}
283
+ # Actually a row-major rdivl
284
+ # B_{m,n} = (A_{m,n} - \sum_{i=n+1}^N B_{m,i}L_{i,n})/L_{n,n}
284
285
Aind = Unroll {1,1,W,2,W,zero(UInt),1} (Unroll {2,W,U,2,W,zero(UInt),1} ((z, z)))
285
286
q = quote
286
287
# $(Expr(:meta, :inline))
287
- # C = U \ A; U * C = A
288
- # A_{i,j} = U_{i,i}*C_{i,j} + \sum_{k=i+1}^{N}U_{i,k}C_{k,j}
289
- # C_{i,j} = U_{i,i} \ (A_{i,j} - \sum_{k=i+1}^{N}U_{i,k}C_{k,j})
290
- # The inputs here are transposed, as the library was formulated in terms of `rdiv!`,
291
- # so we have
292
- # C_{j,i} = (A_{j,i} - \sum_{k=i+1}^{N}C_{j,k}U_{k,i}) / L_{i,i}
293
- # This solves for the block: C_{j+[0,W],i+[0,W*U)}
294
- # This can be viewed as `U` blocks that are each `W`x`W`
295
- # E.g. U=3, rough alg:
296
- # r=[0,W); c=[0,WU)
297
- # X = A_{j+r,i+c} - \sum_{k=1}^{i-1}C_{j+r,k}*U_{k,i+c}
298
- # C_{j+r,i+r} = X[:, r] / U_{i+r,i+r}
299
- # C_{j+r,i+W+r} = (X[:, W+r] - C_{j+r,i+r}*U_{i+r,i+W+r}) / U_{i+W+r,i+W+r}
300
- # C_{j+r,i+2W+r} = (X[:, 2W+r] - C_{j+r,i+r}*U_{i+r,i+2W+r} - C_{j+r,i+W+r}*U_{i+W+r,i+2W+r}) / U_{i+2W+r,i+2W+r}
301
- #
302
- # outer unroll are `W` rows
303
- # Inner unroll are `W*U` columns (U simd vecs)
304
288
#
305
289
A11 = getfield (vload (spa, $ Aind), :data )
306
290
# The `W` rows
310
294
# Each iter:
311
295
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
312
296
for nk ∈ SafeCloseOpen (n) # nmuladd
313
- nkw = nk + $ W
314
- U_ki = vload (spu , $ (Unroll{2 ,W,U,2 ,W,zero (UInt),1 })((nkw, $ z)))
297
+ nkw = nk + $ (W * U)
298
+ U_ki = vload (spl , $ (Unroll{2 ,W,U,2 ,W,zero (UInt),1 })((nkw, $ z)))
315
299
Base. Cartesian. @nexprs $ W c ->
316
300
A11_c = vfnmadd_fast (U_ki, vload (spa, (static (c - 1 ), nkw)), A11_c)
317
301
end
326
310
Xu = Vector {Symbol} (undef, W)
327
311
Csym = Vector {Symbol} (undef, U)
328
312
for u = 1 : U
313
+ # X_u are future
329
314
X_u = Symbol (:X_ , u)
330
315
push! (
331
316
q. args,
@@ -341,18 +326,19 @@ end
341
326
)
342
327
)
343
328
)
329
+ # push!(q.args, :(println($X_u)))
344
330
for c = 1 : W
345
331
X_u_c = Xu[c] = Symbol (:X_ , u, :_ , c)
346
332
push! (q. args, Expr (:(= ), X_u_c, Expr (:call , getfield, X_u, c)))
347
333
end
348
334
# take A[(U-u+1)*W,u*W), [0,W)]
349
- for j = 1 : u- 1
350
- for k = 1 : W
351
- for c = 1 : W
352
- urow = ((W - k) + ((j - 1 ) * W))
335
+ for j = 1 : u- 1 # iter over all blocks ordered after
336
+ for k = 1 : W # reduction dimension, reverse order
337
+ for c = 1 : W # columns of C
338
+ urow = ((W - k) + ((U - j ) * W))
353
339
ucol = ((c - 1 ) + ((U - u) * W))
354
- push! (q. args, Expr (:call , println, " Row = $urow ; Col = $ucol " ))
355
- Uexpr = :(vload (spu , ($ urow, $ ucol)))
340
+ # push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
341
+ Uexpr = :(vload (spl , ($ urow, $ ucol)))
356
342
X_u_c = Xu[c]
357
343
C_j_k = Symbol (:C_ , j, :_ , W + 1 - k)
358
344
Xucexpr = Expr (:call , vfnmadd_fast, C_j_k, Uexpr, X_u_c)
@@ -361,7 +347,7 @@ push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
361
347
end
362
348
end
363
349
o = (U - u) * W
364
- sp = Expr (:call , gesp, :spu , (o, o))
350
+ sp = Expr (:call , gesp, :spl , (o, o))
365
351
Xut = Expr (:tuple )
366
352
for c = 1 : W
367
353
push! (Xut. args, Xu[c])
@@ -379,6 +365,7 @@ push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
379
365
end
380
366
end
381
367
for u = 1 : U
368
+ # u = 1 is last, first processed (reverse order)
382
369
ui = Unroll {2,1,W,1,W,zero(UInt),1} ((z, (U - u) * W))
383
370
C_u = Csym[u]
384
371
push! (q. args, :(vstore! (spa, $ C_u, $ ui)))
@@ -387,7 +374,7 @@ push!(q.args, Expr(:call, println, "Row = $urow; Col = $ucol"))
387
374
end
388
375
@generated function ldivu_solve_W! (
389
376
spa,
390
- spu ,
377
+ spl ,
391
378
n,
392
379
:: StaticInt{W} ,
393
380
:: Val{UNIT} ,
425
412
# A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)}
426
413
for nk ∈ SafeCloseOpen (n) # nmuladd
427
414
nkw = nk + $ W
428
- U_ki = vload (spu , (nkw, $ (MM {W} (z))))
415
+ U_ki = vload (spl , (nkw, $ (MM {W} (z))))
429
416
Base. Cartesian. @nexprs $ R r ->
430
417
A11_r = vfnmadd_fast (U_ki, vload (spa, (static (r - 1 ), nkw)), A11_r)
431
418
end
444
431
# We then have column-major multiplies
445
432
# take A[(u-1)*W,u*W), [0,W)]
446
433
X = VectorizationBase. transpose_vecunroll (VecUnroll ($ t))
447
- C_u = solve_AL (X, spu , $ (Val (UNIT)))
434
+ C_u = solve_AL (X, spl , $ (Val (UNIT)))
448
435
end
449
436
push! (q. args, q2)
450
437
q3 = if R == Wpad
465
452
466
453
@generated function _ldivu_remainder! (
467
454
spa,
468
- spu ,
455
+ spl ,
469
456
N,
470
457
Nr,
471
458
:: StaticInt{W} ,
486
473
vlxj = :(xj = $ vlxj)
487
474
else
488
475
vlxj = quote
489
- xj = $ div ($ vlxj, vload (spu , (j, j)))
476
+ xj = $ div ($ vlxj, vload (spl , (j, j)))
490
477
vstore! (spa, xj, ($ z, j))
491
478
end
492
479
end
500
487
while i > 0
501
488
i -= 1
502
489
xi = vload (spa, ($ z, i))
503
- Uji = vload (spu , (j, i))
490
+ Uji = vload (spl , (j, i))
504
491
vstore! (spa, $ sub (xi, $ mul (xj, Uji)), ($ z, i))
505
492
end
506
493
j == 0 && break
@@ -514,19 +501,17 @@ end
514
501
mask = $ (getfield (_mask (WS, r), :u ) % UInt32)
515
502
n = N - Nr
516
503
if Nr > 0
517
- @show pointer (spa), pointer (spu), n, Nr
518
- let t = (gesp (spa, ($ z, n)), gesp (spu, (n, n))), ft = flatten_to_tup (t)
504
+ let t = (gesp (spa, ($ z, n)), gesp (spl, (n, n))), ft = flatten_to_tup (t)
519
505
BdivL_small_kern! (Nr, mask, $ WS, $ (Val (UNIT)), typeof (t), ft... )
520
506
end
521
507
end
522
508
# non-U, order first as matmul kern is smaller than optimal
523
509
while n != 0
524
510
k = N - n
525
511
n -= W
526
- @show pointer (spa), pointer (spu), k, n
527
512
ldivu_solve_W! (
528
513
gesp (spa, ($ z, n)),
529
- gesp (spu , (n, n)),
514
+ gesp (spl , (n, n)),
530
515
k,
531
516
$ WS,
532
517
Val (UNIT),
@@ -551,25 +536,25 @@ end
551
536
if W == 2
552
537
quote
553
538
$ (Expr (:meta , :inline ))
554
- spa, spu = reassemble_tup (Args, args)
555
- _ldivu_remainder! (spa, spu , N, Nrr, Nru, $ WS, $ (Val (UNIT)), $ (static (1 )))
539
+ spa, spl = reassemble_tup (Args, args)
540
+ _ldivu_remainder! (spa, spl , N, Nrr, Nru, $ WS, $ (Val (UNIT)), $ (static (1 )))
556
541
nothing
557
542
end
558
543
elseif W == 8
559
544
s8 = StaticInt (8 )
560
545
quote
561
546
# $(Expr(:meta, :inline))
562
- spa, spu = reassemble_tup (Args, args)
547
+ spa, spl = reassemble_tup (Args, args)
563
548
if m == M - 1
564
- _ldivu_remainder! (spa, spu , N, Nr, $ s8, $ (Val (UNIT)), $ (StaticInt (1 )))
549
+ _ldivu_remainder! (spa, spl , N, Nr, $ s8, $ (Val (UNIT)), $ (StaticInt (1 )))
565
550
else
566
551
if m == M - 2
567
- _ldivu_remainder! (spa, spu , N, Nr, $ s8, $ (Val (UNIT)), $ (StaticInt (2 )))
552
+ _ldivu_remainder! (spa, spl , N, Nr, $ s8, $ (Val (UNIT)), $ (StaticInt (2 )))
568
553
else
569
554
if m == M - 3
570
555
_ldivu_remainder! (
571
556
spa,
572
- spu ,
557
+ spl ,
573
558
N,
574
559
Nr,
575
560
$ s8,
580
565
if m == M - 4
581
566
_ldivu_remainder! (
582
567
spa,
583
- spu ,
568
+ spl ,
584
569
N,
585
570
Nr,
586
571
$ s8,
591
576
if m == M - 5
592
577
_ldivu_remainder! (
593
578
spa,
594
- spu ,
579
+ spl ,
595
580
N,
596
581
Nr,
597
582
$ s8,
602
587
if m == M - 6
603
588
_ldivu_remainder! (
604
589
spa,
605
- spu ,
590
+ spl ,
606
591
N,
607
592
Nr,
608
593
$ s8,
612
597
else
613
598
_ldivu_remainder! (
614
599
spa,
615
- spu ,
600
+ spl ,
616
601
N,
617
602
Nr,
618
603
$ s8,
630
615
else
631
616
quote
632
617
# $(Expr(:meta, :inline))
633
- spa, spu = reassemble_tup (Args, args)
618
+ spa, spl = reassemble_tup (Args, args)
634
619
Base. Cartesian. @nif $ (W - 1 ) w -> m == M - w w ->
635
- _ldivu_remainder! (spa, spu , N, Nr, $ WS, $ (Val (UNIT)), static (w))
620
+ _ldivu_remainder! (spa, spl , N, Nr, $ WS, $ (Val (UNIT)), static (w))
636
621
nothing
637
622
end
638
623
end
@@ -646,7 +631,7 @@ function _ldivu_L!(
646
631
args:: Vararg{Any,K}
647
632
) where {UNIT,Args,K}
648
633
# B_{n,m} = (A_{n,m} - \sum_{i=n+1}^N U_{n,i}B_{i,m})/U_{n,n}
649
- spa, spu = reassemble_tup (Args, args)
634
+ spa, spl = reassemble_tup (Args, args)
650
635
T = eltype (spa)
651
636
WS = pick_vector_width (T)
652
637
W = Int (WS)
@@ -664,8 +649,7 @@ function _ldivu_L!(
664
649
while m < M - WS + 1
665
650
n:: Int = nstart
666
651
if Nrr > 0
667
- let t = (gesp (spa, (z, n)), gesp (spu, (n, n))), ft = flatten_to_tup (t)
668
- @show 0 , n
652
+ let t = (gesp (spa, (z, n)), gesp (spl, (n, n))), ft = flatten_to_tup (t)
669
653
compute && BdivL_small_kern_u! (
670
654
Nrr,
671
655
StaticInt (1 ),
@@ -680,10 +664,9 @@ function _ldivu_L!(
680
664
for _ ∈ 1 : Ndr
681
665
k = N - n
682
666
n -= W
683
- @show 1 , n, k
684
667
compute && ldivu_solve_W! (
685
668
gesp (spa, (z, n)),
686
- gesp (spu , (n, n)),
669
+ gesp (spl , (n, n)),
687
670
k,
688
671
WS,
689
672
Val (UNIT),
@@ -694,10 +677,9 @@ function _ldivu_L!(
694
677
while n != 0
695
678
k = N - n
696
679
n -= Int (WU)
697
- @show 2 , n, k
698
680
compute && ldivu_solve_W_u! (
699
681
gesp (spa, (z, n)),
700
- gesp (spu , (n, n)),
682
+ gesp (spl , (n, n)),
701
683
k,
702
684
WS,
703
685
UF,
@@ -709,8 +691,7 @@ function _ldivu_L!(
709
691
end
710
692
# remainder on `m`
711
693
if m < M
712
- let tup = (spa, spu), ftup = flatten_to_tup (tup)
713
- @show m, Nrr, M
694
+ let tup = (spa, spl), ftup = flatten_to_tup (tup)
714
695
compute &&
715
696
ldivu_remainder! (M, N, m, Nrr, WS, Val (UNIT), typeof (tup), ftup... )
716
697
end
0 commit comments