Commit 8c29b37

Tile-level partitioning in jr/ir loops (ex-trsm). (#695)

Details:
- Reimplemented parallelization of the JR loop in gemmt (which is recycled
  for herk, her2k, syrk, and syr2k). Previously, the rectangular region of
  the current MC x NC panel of C would be parallelized separately from the
  diagonal region of that same submatrix, with the rectangular portion being
  assigned to threads via slab or round-robin (rr) partitioning (as
  determined at configure-time) and the diagonal region being assigned via
  round-robin. This approach did not work well when extracting lots of
  parallelism from the JR loop and was often suboptimal even for smaller
  degrees of parallelism. This commit implements tile-level load balancing
  (tlb), in which the IR loop is effectively subjugated in service of more
  equitably dividing work in the JR loop. This approach is especially potent
  for certain situations where the diagonal region of the MC x NC panel of C
  is significant relative to the entire panel. However, it also seems to
  benefit many problem sizes of other level-3 operations (excluding trsm,
  which has an inherent algorithmic dependency in the IR loop that prevents
  the application of tlb). For now, tlb is implemented as _var2b.c
  macrokernels for gemm (which forms the basis for gemm, hemm, and symm),
  gemmt (which forms the basis of herk, her2k, syrk, and syr2k), and trmm
  (which forms the basis of trmm and trmm3). Which function pointers
  (_var2() or _var2b()) are embedded in the control tree will depend on
  whether the BLIS_ENABLE_JRIR_TLB cpp macro is defined, which is controlled
  by the value passed to the existing --thread-part-jrir=METHOD (or
  -r METHOD) configure option. The configure script now accepts 'tlb' as a
  valid value alongside the previously supported values of 'slab' and 'rr'.
  ('slab' is still the default.) (A sketch of the basic tlb idea follows
  this list.) Thanks to Leick Robinson for abstractly inspiring this work,
  and to Minh Quan Ho for inquiring (in PR #562, and before that in Issue
  #437) about the possibility of improved load balance in macrokernel loops,
  and even prototyping what it might look like, long before I fully
  understood the problem.
- In bli_thread_range_weighted_sub(), tweaked the way we compute the area of
  the current MC x NC trapezoidal panel of C by better taking into account
  the microtile structure along the diagonal. Previously, the computation
  was an underestimate, as it assumed MR = NR = 1 (that is, it assumed that
  the boundary of the microtiles of C that overlap the diagonal coincided
  exactly with the diagonal). Now, we only assume MR = NR. This is still a
  slight underestimate when MR != NR, so the additional area is scaled by
  1.5 in a hackish attempt to compensate for this, as well as for other
  effects that are difficult to model (such as the increased cost of writing
  to temporary tiles before finally updating C). The net effect of this
  better estimation of the trapezoidal area should be (on average) slightly
  larger regions assigned to threads that have little or no overlap with the
  diagonal region (and correspondingly slightly smaller regions assigned to
  threads within the diagonal region), which we expect will lead to slightly
  better load balancing in most situations.
- Spun off the contents of bli_thread.[ch] that relate to computing thread
  ranges into three source/header file pairs:
  - bli_thread_range.[ch], which define functions that are not specific to
    the jr/ir loops;
  - bli_thread_range_slab_rr.[ch], which define functions that implement
    slab or round-robin partitioning for the jr/ir loops;
  - bli_thread_range_tlb.[ch], which define functions that implement tlb for
    the jr/ir loops.
- Fixed the computation of a_next in the last iteration of the IR loop in
  bli_gemmt_l_ker_var2(). Previously, it always "wrapped" back around to the
  first micropanel of the current MC x KC packed block of A. However, this
  is almost never the micropanel that is actually used next. A new macro,
  bli_gemmt_l_wrap_a_upanel(), computes a_next correctly, with a similarly
  named bli_gemmt_u_wrap_a_upanel() for the upper-stored case (which *does*
  always choose the first micropanel of A as its a_next at the end of the
  IR loop).
- Removed adjustments for a_next/b_next (a2/b2) for the diagonal-intersecting
  case of gemmt_l_ker_var2() and the above-diagonal case of
  gemmt_u_ker_var2(), since these cases only coincide with the last iteration
  of the IR loop for very small problems.
- Defined bli_is_last_iter_l() and bli_is_last_iter_u(), the latter of which
  explicitly considers whether the current microtile is the last tile that
  intersects the diagonal. (The former does the same, but its computation
  coincides with the original bli_is_last_iter().) These functions are now
  used in gemmt to test when a_next (or a2) should "wrap" (as discussed
  above). Also defined bli_is_last_iter_tlb_l() and bli_is_last_iter_tlb_u(),
  which are similar to the aforementioned functions but are used when
  employing tlb in gemmt.
- Redefined the macros in bli_packm_thrinfo.h, which test whether an
  iteration of work is assigned to a thread, as static inline functions in
  bli_param_macro_defs.h (and then deleted bli_packm_thrinfo.h). In the
  process, they were also renamed from bli_packm_my_iter_rr/sl() to
  bli_is_my_iter_rr/sl().
- Renamed:
  - bli_thread_range_jrir_rr() -> bli_thread_range_rr()
  - bli_thread_range_jrir_sl() -> bli_thread_range_sl()
  - bli_thread_range_jrir()    -> bli_thread_range_slrr()
- Renamed bli_is_last_iter() -> bli_is_last_iter_slrr().
- Defined bli_info_get_thread_jrir_tlb() and renamed:
  - bli_info_get_thread_part_jrir_slab() -> bli_info_get_thread_jrir_slab()
  - bli_info_get_thread_part_jrir_rr()   -> bli_info_get_thread_jrir_rr()
- Modified bli_rntm_set_ways_for_op() to redirect IR-loop parallelism into
  the JR loop when tlb is enabled for non-trsm level-3 operations.
- Added a sanity check to prevent bli_prune_unref_mparts() from being used
  on packed objects. This prohibition is necessary because the current
  implementation does not take into account the atomicity of packed
  micropanel widths relative to the diagonal of structured matrices. That
  is, the function prunes greedily, without regard to whether doing so would
  prune off part of a micropanel *which has already been packed* and
  assigned to a thread for inclusion in the computation.
- Further restricted early returns in bli_prune_unref_mparts() to situations
  where the primary matrix is not only of general structure but also dense
  (in terms of its uplo_t value). The matrix's dense-ness must be part of
  the conditional because gemmt is somewhat unusual: its C matrix has
  general structure but is marked as lower- or upper-stored via its uplo_t.
  By checking only for general structure, attempts to prune gemmt C matrices
  would incorrectly result in early returns, even though that operation
  effectively treats the matrix as symmetric (and stored in only one
  triangle).
- Fixed a latent bug in bli_thread_range_rr() wherein incorrect ranges were
  computed when 1 < bf. Thankfully, this bug was not yet manifesting since
  all current invocations use bf == 1.
- Fixed a latent bug in some unexercised code in bli_?gemmt_l_ker_var2()
  that would perform incorrect pruning of unreferenced regions above where
  the diagonal of a lower-stored matrix intersects the right edge.
  Thankfully, the bug was not harming anything since those unreferenced
  regions were already being pruned prior to the macrokernel.
- Rewrote the slab/rr-based gemmt macrokernels so that they no longer carve
  C into rectangular and diagonal regions prior to parallelizing each
  separately. The new macrokernels use a unified loop structure in which
  quadratic (slab) partitioning is used.
- Updated all level-3 macrokernels to use a more uniform coding style, for
  example with respect to combining variable declarations with
  initializations and the use of const.
- Updated bls_l3_packm_var[123].c to use bli_thrinfo_n_way() and
  bli_thrinfo_work_id() instead of bli_thrinfo_num_threads() and
  bli_thrinfo_thread_id(), respectively. This change probably should have
  been included in aeb5f0c.
- Removed old prototypes in bli_gemmt_var.h and bli_trmm_var.h that
  corresponded to functions that were removed in aeb5f0c.
- Other very minor cleanups.
- Comment updates.
- (cherry picked from commit 2e1ba9d)
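
To make the tlb idea above concrete, here is a minimal, self-contained sketch
of tile-level partitioning for the simplest (rectangular, gemm-like) case.
The helper name tlb_range() and the example dimensions are hypothetical and
are not part of the BLIS API; the actual implementation (see
bli_thread_range_tlb.c) additionally weights microtiles that intersect the
diagonal for gemmt and trmm.

/* Sketch: divide the unified jr+ir iteration space (all microtiles of the
   current MC x NC panel of C) as evenly as possible among nt threads. */

#include <stdio.h>

typedef long dim_t;

/* Give thread 'tid' (of 'nt') a contiguous range [t_start, t_end) of linear
   microtile indices. */
static void tlb_range( dim_t n_ut, dim_t nt, dim_t tid,
                       dim_t* t_start, dim_t* t_end )
{
    const dim_t lo  = n_ut / nt;  /* minimum number of tiles per thread    */
    const dim_t rem = n_ut % nt;  /* first 'rem' threads receive one extra */

    *t_start = tid * lo + ( tid < rem ? tid : rem );
    *t_end   = *t_start + lo + ( tid < rem ? 1 : 0 );
}

int main( void )
{
    const dim_t m_iter = 7;             /* microtile rows (IR iterations)    */
    const dim_t n_iter = 5;             /* microtile columns (JR iterations) */
    const dim_t n_ut   = m_iter * n_iter;
    const dim_t nt     = 4;             /* threads sharing the jr+ir loops   */

    for ( dim_t tid = 0; tid < nt; ++tid )
    {
        dim_t t0, t1;
        tlb_range( n_ut, nt, tid, &t0, &t1 );

        /* Unlike slab or rr, a thread's range may begin and end mid-column:
           tile t maps to column jr = t / m_iter and row ir = t % m_iter. */
        printf( "thread %ld: tiles [%ld,%ld), first (jr,ir) = (%ld,%ld)\n",
                ( long )tid, ( long )t0, ( long )t1,
                ( long )( t0 / m_iter ), ( long )( t0 % m_iter ) );
    }

    return 0;
}

With 7 x 5 = 35 microtiles and 4 threads, the threads receive 9, 9, 9, and 8
tiles, and a thread's share can start or stop partway down a column of
microtiles: exactly the flexibility that whole-column slab and rr assignment
lacks.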
1 parent 6564639 commit 8c29b37


67 files changed, +10608 -2022 lines changed

build/bli_config.h.in  (+4)

@@ -80,6 +80,10 @@
 #define BLIS_ENABLE_JRIR_RR
 #endif
 
+#if @enable_jrir_tlb@
+#define BLIS_ENABLE_JRIR_TLB
+#endif
+
 #if @enable_pba_pools@
 #define BLIS_ENABLE_PBA_POOLS
 #else
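
For reference, the @enable_jrir_tlb@ placeholder above is substituted by the
configure script (see the sed lines further down this page) with a 0/1 flag,
so configuring with --thread-part-jrir=tlb should leave the generated
bli_config.h with something like the following (a sketch of the expected
output, not copied from a generated file):

#if 1
#define BLIS_ENABLE_JRIR_TLB
#endif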

configure  (+43 -18)

@@ -340,16 +340,36 @@ print_usage()
 echo " "
 echo " -r METHOD, --thread-part-jrir=METHOD"
 echo " "
-echo " Request a method of assigning micropanels to threads in"
-echo " the JR and IR loops. Valid values for METHOD are 'slab'"
-echo " and 'rr'. Using 'slab' assigns (as much as possible)"
-echo " contiguous regions of micropanels to each thread while"
-echo " using 'rr' assigns micropanels to threads in a round-"
-echo " robin fashion. The chosen method also applies during"
-echo " the packing of A and B. The default method is 'slab'."
-echo " NOTE: Specifying this option constitutes a request,"
-echo " which may be ignored in select situations if the"
-echo " implementation has a good reason to do so."
+echo " Select a strategy for partitioning computation in JR and"
+echo " IR loops and assigning that computation to threads. Valid"
+echo " values for METHOD are 'rr', 'slab', and 'tlb':"
+echo " 'rr': Assign the computation associated with whole"
+echo " columns of microtiles to threads in a round-"
+echo " robin fashion. When selected, round-robin"
+echo " assignment is also employed during packing."
+echo " 'slab': Partition the computation into N contiguous"
+echo " regions, where each region contains a whole"
+echo " number of microtile columns, and assign one"
+echo " region to each thread. For some operations, the"
+echo " number of microtile columns contained within a"
+echo " given region may differ from that of other"
+echo " regions, depending on how much work is implied"
+echo " by each region. When selected, slab assignment"
+echo " is also employed during packing."
+echo " 'tlb': Tile-level load balancing is similar to slab,"
+echo " except that regions will be divided at a more"
+echo " granular level (individual microtiles instead"
+echo " of whole columns of microtiles) to ensure more"
+echo " equitable assignment of work to threads. When"
+echo " selected, tlb will only be employed for level-3"
+echo " operations except trsm; due to practical and"
+echo " algorithmic limitations, slab partitioning will"
+echo " be used instead during packing and for trsm."
+echo " The default strategy is 'slab'. NOTE: Specifying this"
+echo " option constitutes a request, which may be ignored in"
+echo " select situations if implementation has a good reason to"
+echo " do so. (See description of 'tlb' above for an example of"
+echo " this.)"
 echo " "
 echo " --disable-trsm-preinversion, --enable-trsm-preinversion"
 echo " "
@@ -3731,16 +3751,20 @@ main()
 
 # Check the method of assigning micropanels to threads in the JR and IR
 # loops.
-enable_jrir_slab_01=0
 enable_jrir_rr_01=0
-if [ "x${thread_part_jrir}" = "xslab" ]; then
-echo "${script_name}: requesting slab threading in jr and ir loops."
-enable_jrir_slab_01=1
-elif [ "x${thread_part_jrir}" = "xrr" ]; then
-echo "${script_name}: requesting round-robin threading in jr and ir loops."
+enable_jrir_slab_01=0
+enable_jrir_tlb_01=0
+if [ "x${thread_part_jrir}" = "xrr" ]; then
+echo "${script_name}: requesting round-robin (rr) work partitioning in jr and/or ir loops."
 enable_jrir_rr_01=1
+elif [ "x${thread_part_jrir}" = "xslab" ]; then
+echo "${script_name}: requesting slab work partitioning in jr and/or ir loops."
+enable_jrir_slab_01=1
+elif [ "x${thread_part_jrir}" = "xtlb" ]; then
+echo "${script_name}: requesting tile-level load balancing (tlb) in unified jr+ir loop."
+enable_jrir_tlb_01=1
 else
-echo "${script_name}: *** Unsupported method of thread partitioning in jr and ir loops: ${thread_part_jrir}."
+echo "${script_name}: *** Unsupported method of work partitioning in jr/ir loops: ${thread_part_jrir}."
 exit 1
 fi
 
@@ -4177,8 +4201,9 @@ main()
 | sed -e "s/@enable_pthreads_as_def@/${enable_pthreads_as_def_01}/g" \
 | sed -e "s/@enable_hpx@/${enable_hpx_01}/g" \
 | sed -e "s/@enable_hpx_as_def@/${enable_hpx_as_def_01}/g" \
-| sed -e "s/@enable_jrir_slab@/${enable_jrir_slab_01}/g" \
 | sed -e "s/@enable_jrir_rr@/${enable_jrir_rr_01}/g" \
+| sed -e "s/@enable_jrir_slab@/${enable_jrir_slab_01}/g" \
+| sed -e "s/@enable_jrir_tlb@/${enable_jrir_tlb_01}/g" \
 | sed -e "s/@enable_pba_pools@/${enable_pba_pools_01}/g" \
 | sed -e "s/@enable_sba_pools@/${enable_sba_pools_01}/g" \
 | sed -e "s/@enable_mem_tracing@/${enable_mem_tracing_01}/g" \
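
As a rough illustration of how the 'rr' and 'slab' strategies described in
the help text above translate a number of microtile columns into per-thread
loop ranges, here is a small stand-alone sketch; the variable names and
sizes are hypothetical and this is not the BLIS source.

#include <stdio.h>

int main( void )
{
    const long n_cols = 10;   /* microtile columns in the current panel */
    const long nt     = 4;    /* threads sharing the JR loop            */

    for ( long tid = 0; tid < nt; ++tid )
    {
        /* slab: one contiguous block of whole columns per thread. */
        const long lo    = n_cols / nt;
        const long rem   = n_cols % nt;
        const long start = tid * lo + ( tid < rem ? tid : rem );
        const long end   = start + lo + ( tid < rem ? 1 : 0 );

        /* rr: thread tid visits columns tid, tid+nt, tid+2*nt, ... */
        printf( "thread %ld  slab: [%ld,%ld) step 1   rr: [%ld,%ld) step %ld\n",
                tid, start, end, tid, n_cols, nt );
    }

    return 0;
}

Either way, threads are handed whole columns of microtiles; 'tlb' instead
divides the same work at the granularity of individual microtiles, as the
help text explains.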

frame/1m/packm/bli_packm.h  (-1)

@@ -39,7 +39,6 @@
 #include "bli_packm_init.h"
 #include "bli_packm_int.h"
 #include "bli_packm_scalar.h"
-#include "bli_packm_thrinfo.h"
 
 #include "bli_packm_part.h"
 
frame/1m/packm/bli_packm_blk_var1.c  (+8 -8)

@@ -170,11 +170,11 @@ void bli_packm_blk_var1
 const dim_t tid = bli_thrinfo_work_id( thread );
 
 // Determine the thread range and increment using the current thread's
-// packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
+// packm thrinfo_t node. NOTE: The definition of bli_thread_range_slrr()
 // will depend on whether slab or round-robin partitioning was requested
 // at configure-time.
 dim_t it_start, it_end, it_inc;
-bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );
+bli_thread_range_slrr( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );
 
 char* p_begin = p_cast;
 
@@ -195,10 +195,10 @@ void bli_packm_blk_var1
 
 char* c_begin = c_cast + (ic )*incc*dt_c_size;
 
-// Hermitian/symmetric and general packing may use slab or
-// round-robin (bli_packm_my_iter()), depending on which was
-// selected at configure-time.
-if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
+// Hermitian/symmetric and general packing may use slab or round-
+// robin (bli_is_my_iter()), depending on which was selected at
+// configure-time.
+if ( bli_is_my_iter( it, it_start, it_end, tid, nt ) )
 {
 packm_ker_cast( bli_is_triangular( strucc ) ? BLIS_GENERAL : strucc,
 diagc,
@@ -286,9 +286,9 @@ void bli_packm_blk_var1
 // We nudge the imaginary stride up by one if it is odd.
 is_p_use += ( bli_is_odd( is_p_use ) ? 1 : 0 );
 
-// NOTE: We MUST use round-robin work allocation (bli_packm_my_iter_rr())
+// NOTE: We MUST use round-robin work allocation (bli_is_my_iter_rr())
 // when packing micropanels of a triangular matrix.
-if ( bli_packm_my_iter_rr( it, it_start, it_end, tid, nt ) )
+if ( bli_is_my_iter_rr( it, tid, nt ) )
 {
 packm_ker_cast( strucc,
 diagc,

frame/3/bli_l3_sup_packm_var.c  (+8 -8)

@@ -155,10 +155,10 @@ void PASTEMAC(ch,varname) \
 dim_t it_start, it_end, it_inc; \
 \
 /* Determine the thread range and increment using the current thread's
-packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
+packm thrinfo_t node. NOTE: The definition of bli_thread_range_slrr()
 will depend on whether slab or round-robin partitioning was requested
 at configure-time. */ \
-bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \
+bli_thread_range_slrr( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \
 \
 /* Iterate over every logical micropanel in the source matrix. */ \
 for ( ic = ic0, it = 0; it < n_iter; \
@@ -175,9 +175,9 @@ void PASTEMAC(ch,varname) \
 panel_len_i = panel_len_full; \
 panel_len_max_i = panel_len_max; \
 \
-/* The definition of bli_packm_my_iter() will depend on whether slab
+/* The definition of bli_is_my_iter() will depend on whether slab
 or round-robin partitioning was requested at configure-time. */ \
-if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \
+if ( bli_is_my_iter( it, it_start, it_end, tid, nt ) ) \
 { \
 f \
 ( \
@@ -398,10 +398,10 @@ void PASTEMAC(ch,varname) \
 dim_t it_start, it_end, it_inc; \
 \
 /* Determine the thread range and increment using the current thread's
-packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
+packm thrinfo_t node. NOTE: The definition of bli_thread_range_slrr()
 will depend on whether slab or round-robin partitioning was requested
 at configure-time. */ \
-bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \
+bli_thread_range_slrr( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \
 \
 /* Iterate over every logical micropanel in the source matrix. */ \
 for ( it = 0; it < n_iter; it += 1 ) \
@@ -412,9 +412,9 @@ void PASTEMAC(ch,varname) \
 ctype* p_use = p_begin; \
 \
 { \
-/* The definition of bli_packm_my_iter() will depend on whether slab
+/* The definition of bli_is_my_iter() will depend on whether slab
 or round-robin partitioning was requested at configure-time. */ \
-if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \
+if ( bli_is_my_iter( it, it_start, it_end, tid, nt ) ) \
 { \
 PASTEMAC2(ch,scal2v,BLIS_TAPI_EX_SUF) \
 ( \

frame/3/bli_l3_sup_var12.c  (+2 -2)

@@ -357,11 +357,11 @@ void PASTEMAC(ch,varname) \
 object. */ \
 /*
 ctype* a2 = bli_gemm_get_next_a_upanel( a_ir, irstep_a, ir_inc ); \
-if ( bli_is_last_iter( i, ir_iter, 0, 1 ) ) \
+if ( bli_is_last_iter_slrr( i, ir_iter, 0, 1 ) ) \
 { \
 a2 = a_00; \
 b2 = bli_gemm_get_next_b_upanel( b_jr, jrstep_b, jr_inc ); \
-if ( bli_is_last_iter( j, jr_iter, 0, 1 ) ) \
+if ( bli_is_last_iter_slrr( j, jr_iter, 0, 1 ) ) \
 b2 = b_00; \
 } \
 \

frame/3/bli_l3_thrinfo.h  (+6 -6)

@@ -39,22 +39,22 @@
 
 // gemm
 
-// NOTE: The definition of bli_gemm_get_next_?_upanel() does not need to
-// change depending on BLIS_ENABLE_JRIR_SLAB / BLIS_ENABLE_JRIR_RR.
 #define bli_gemm_get_next_a_upanel( a1, step, inc ) ( a1 + step * inc )
 #define bli_gemm_get_next_b_upanel( b1, step, inc ) ( b1 + step * inc )
 
 // gemmt
 
-// NOTE: The definition of bli_gemmt_get_next_?_upanel() does not need to
-// change depending on BLIS_ENABLE_JRIR_SLAB / BLIS_ENABLE_JRIR_RR.
 #define bli_gemmt_get_next_a_upanel( a1, step, inc ) ( a1 + step * inc )
 #define bli_gemmt_get_next_b_upanel( b1, step, inc ) ( b1 + step * inc )
 
+// NOTE: Here, we assume NO parallelism in the IR loop.
+#define bli_gemmt_l_wrap_a_upanel( a0, step, doff_j, mr, nr ) \
+( a0 + ( (-doff_j + 1*nr) / mr ) * step )
+#define bli_gemmt_u_wrap_a_upanel( a0, step, doff_j, mr, nr ) \
+( a0 )
+
 // trmm
 
-// NOTE: The definition of bli_trmm_get_next_?_upanel() does not need to
-// change depending on BLIS_ENABLE_JRIR_SLAB / BLIS_ENABLE_JRIR_RR.
 #define bli_trmm_get_next_a_upanel( a1, step, inc ) ( a1 + step * inc )
 #define bli_trmm_get_next_b_upanel( b1, step, inc ) ( b1 + step * inc )
 
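
For a rough feel of the arithmetic in the new bli_gemmt_l_wrap_a_upanel()
macro above, the following stand-alone snippet evaluates the same expression
on hypothetical values of mr, nr, doff_j, and the micropanel step (the values
are purely illustrative):

#include <stdio.h>

int main( void )
{
    const long mr     = 6;    /* micropanel height of A                   */
    const long nr     = 8;    /* micropanel width of B                    */
    const long doff_j = -18;  /* diagonal offset of the current JR column */
    const long step   = 1;    /* stride between micropanels of packed A   */

    /* Same expression as in bli_gemmt_l_wrap_a_upanel(). */
    const long adv = ( ( -doff_j + 1 * nr ) / mr ) * step;

    printf( "a_next = a0 + %ld micropanel(s)\n", adv );  /* prints 4 */

    return 0;
}

Loosely speaking, the lower-stored macro advances a0 by a whole number of
micropanels determined by the diagonal offset of the current microtile
column, while the upper-stored macro always wraps back to a0 itself.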

frame/3/gemm/bli_gemm_cntl.c  (+19 -4)

@@ -61,10 +61,25 @@ cntl_t* bli_gemmbp_cntl_create
 void_fp macro_kernel_fp;
 
 // Choose the default macrokernel based on the operation family...
-if ( family == BLIS_GEMM ) macro_kernel_fp = bli_gemm_ker_var2;
-else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_x_ker_var2;
-else if ( family == BLIS_TRMM ) macro_kernel_fp = bli_trmm_xx_ker_var2;
-else /* should never execute */ macro_kernel_fp = NULL;
+if ( family == BLIS_GEMM ) macro_kernel_fp =
+#ifdef BLIS_ENABLE_JRIR_TLB
+bli_gemm_ker_var2b;
+#else // ifdef ( _SLAB || _RR )
+bli_gemm_ker_var2;
+#endif
+else if ( family == BLIS_GEMMT ) macro_kernel_fp =
+#ifdef BLIS_ENABLE_JRIR_TLB
+bli_gemmt_x_ker_var2b;
+#else // ifdef ( _SLAB || _RR )
+bli_gemmt_x_ker_var2;
+#endif
+else if ( family == BLIS_TRMM ) macro_kernel_fp =
+#ifdef BLIS_ENABLE_JRIR_TLB
+bli_trmm_xx_ker_var2b;
+#else // ifdef ( _SLAB || _RR )
+bli_trmm_xx_ker_var2;
+#endif
+else /* should never execute */ macro_kernel_fp = NULL;
 
 // ...unless a non-NULL kernel function pointer is passed in, in which
 // case we use that instead.
